stage-10: memory search with Ollama embeddings + SQLite
Semantic search over memory/*.md files using all-minilm embeddings. Adds /search Discord command and `echo memory search/reindex` CLI. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -124,6 +124,7 @@ def create_bot(config: Config) -> discord.Client:
|
||||
"`/logs [n]` — Show last N log lines (default 10)",
|
||||
"`/restart` — Restart the bot process (owner only)",
|
||||
"`/heartbeat` — Run heartbeat health checks",
|
||||
"`/search <query>` — Search Echo's memory",
|
||||
"",
|
||||
"**Cron Jobs**",
|
||||
"`/cron list` — List all scheduled jobs",
|
||||
@@ -413,6 +414,41 @@ def create_bot(config: Config) -> discord.Client:
|
||||
f"Heartbeat error: {e}", ephemeral=True
|
||||
)
|
||||
|
||||
@tree.command(name="search", description="Search Echo's memory")
|
||||
@app_commands.describe(query="What to search for")
|
||||
async def search_cmd(
|
||||
interaction: discord.Interaction, query: str
|
||||
) -> None:
|
||||
await interaction.response.defer()
|
||||
try:
|
||||
from src.memory_search import search
|
||||
|
||||
results = await asyncio.to_thread(search, query)
|
||||
if not results:
|
||||
await interaction.followup.send(
|
||||
"No results found (index may be empty — run `echo memory reindex`)."
|
||||
)
|
||||
return
|
||||
|
||||
lines = [f"**Search results for:** {query}\n"]
|
||||
for i, r in enumerate(results, 1):
|
||||
score = r["score"]
|
||||
preview = r["chunk"][:150]
|
||||
if len(r["chunk"]) > 150:
|
||||
preview += "..."
|
||||
lines.append(
|
||||
f"**{i}.** `{r['file']}` (score: {score:.3f})\n{preview}\n"
|
||||
)
|
||||
text = "\n".join(lines)
|
||||
if len(text) > 1900:
|
||||
text = text[:1900] + "\n..."
|
||||
await interaction.followup.send(text)
|
||||
except ConnectionError as e:
|
||||
await interaction.followup.send(f"Search error: {e}")
|
||||
except Exception as e:
|
||||
logger.exception("Search command failed")
|
||||
await interaction.followup.send(f"Search error: {e}")
|
||||
|
||||
@tree.command(name="channels", description="List registered channels")
|
||||
async def channels(interaction: discord.Interaction) -> None:
|
||||
ch_map = config.get("channels", {})
|
||||
|
||||
210
src/memory_search.py
Normal file
210
src/memory_search.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Echo Core memory search — semantic search over memory/*.md files.
|
||||
|
||||
Uses Ollama all-minilm embeddings stored in SQLite for cosine similarity search.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import sqlite3
|
||||
import struct
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
OLLAMA_URL = "http://10.0.20.161:11434/api/embeddings"
|
||||
OLLAMA_MODEL = "all-minilm"
|
||||
EMBEDDING_DIM = 384
|
||||
DB_PATH = Path(__file__).resolve().parent.parent / "memory" / "echo.sqlite"
|
||||
MEMORY_DIR = Path(__file__).resolve().parent.parent / "memory"
|
||||
|
||||
_CHUNK_TARGET = 500
|
||||
_CHUNK_MAX = 1000
|
||||
_CHUNK_MIN = 100
|
||||
|
||||
|
||||
def get_db() -> sqlite3.Connection:
|
||||
"""Get SQLite connection, create table if needed."""
|
||||
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(str(DB_PATH))
|
||||
conn.execute(
|
||||
"""CREATE TABLE IF NOT EXISTS chunks (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
file_path TEXT NOT NULL,
|
||||
chunk_index INTEGER NOT NULL,
|
||||
chunk_text TEXT NOT NULL,
|
||||
embedding BLOB NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
UNIQUE(file_path, chunk_index)
|
||||
)"""
|
||||
)
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_file_path ON chunks(file_path)"
|
||||
)
|
||||
conn.commit()
|
||||
return conn
|
||||
|
||||
|
||||
def get_embedding(text: str) -> list[float]:
|
||||
"""Get embedding vector from Ollama. Returns list of 384 floats."""
|
||||
try:
|
||||
resp = httpx.post(
|
||||
OLLAMA_URL,
|
||||
json={"model": OLLAMA_MODEL, "prompt": text},
|
||||
timeout=30.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
embedding = resp.json()["embedding"]
|
||||
if len(embedding) != EMBEDDING_DIM:
|
||||
raise ValueError(
|
||||
f"Expected {EMBEDDING_DIM} dimensions, got {len(embedding)}"
|
||||
)
|
||||
return embedding
|
||||
except httpx.ConnectError:
|
||||
raise ConnectionError(
|
||||
f"Cannot connect to Ollama at {OLLAMA_URL}. Is Ollama running?"
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise ConnectionError(f"Ollama API error: {e.response.status_code}")
|
||||
|
||||
|
||||
def serialize_embedding(embedding: list[float]) -> bytes:
|
||||
"""Pack floats to bytes for SQLite storage."""
|
||||
return struct.pack(f"{len(embedding)}f", *embedding)
|
||||
|
||||
|
||||
def deserialize_embedding(data: bytes) -> list[float]:
|
||||
"""Unpack bytes to floats."""
|
||||
n = len(data) // 4
|
||||
return list(struct.unpack(f"{n}f", data))
|
||||
|
||||
|
||||
def cosine_similarity(a: list[float], b: list[float]) -> float:
|
||||
"""Compute cosine similarity between two vectors."""
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
norm_a = math.sqrt(sum(x * x for x in a))
|
||||
norm_b = math.sqrt(sum(x * x for x in b))
|
||||
if norm_a == 0 or norm_b == 0:
|
||||
return 0.0
|
||||
return dot / (norm_a * norm_b)
|
||||
|
||||
|
||||
def chunk_file(file_path: Path) -> list[str]:
|
||||
"""Split .md file into chunks of ~500 chars."""
|
||||
text = file_path.read_text(encoding="utf-8")
|
||||
if not text.strip():
|
||||
return []
|
||||
|
||||
# Split by double newlines or headers
|
||||
raw_parts: list[str] = []
|
||||
current = ""
|
||||
for line in text.split("\n"):
|
||||
# Split on headers or empty lines (paragraph boundaries)
|
||||
if line.startswith("#") and current.strip():
|
||||
raw_parts.append(current.strip())
|
||||
current = line + "\n"
|
||||
elif line.strip() == "" and current.strip():
|
||||
raw_parts.append(current.strip())
|
||||
current = ""
|
||||
else:
|
||||
current += line + "\n"
|
||||
if current.strip():
|
||||
raw_parts.append(current.strip())
|
||||
|
||||
# Merge small chunks with next, split large ones
|
||||
chunks: list[str] = []
|
||||
buffer = ""
|
||||
for part in raw_parts:
|
||||
if buffer and len(buffer) + len(part) + 1 > _CHUNK_MAX:
|
||||
chunks.append(buffer)
|
||||
buffer = part
|
||||
elif buffer:
|
||||
buffer = buffer + "\n\n" + part
|
||||
else:
|
||||
buffer = part
|
||||
|
||||
# If buffer exceeds max, flush
|
||||
if len(buffer) > _CHUNK_MAX:
|
||||
chunks.append(buffer)
|
||||
buffer = ""
|
||||
|
||||
if buffer:
|
||||
# Merge tiny trailing chunk with previous
|
||||
if len(buffer) < _CHUNK_MIN and chunks:
|
||||
chunks[-1] = chunks[-1] + "\n\n" + buffer
|
||||
else:
|
||||
chunks.append(buffer)
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def index_file(file_path: Path) -> int:
|
||||
"""Index a single file. Returns number of chunks created."""
|
||||
rel_path = str(file_path.relative_to(MEMORY_DIR))
|
||||
chunks = chunk_file(file_path)
|
||||
if not chunks:
|
||||
return 0
|
||||
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
conn = get_db()
|
||||
try:
|
||||
conn.execute("DELETE FROM chunks WHERE file_path = ?", (rel_path,))
|
||||
for i, chunk_text in enumerate(chunks):
|
||||
embedding = get_embedding(chunk_text)
|
||||
conn.execute(
|
||||
"""INSERT INTO chunks (file_path, chunk_index, chunk_text, embedding, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?)""",
|
||||
(rel_path, i, chunk_text, serialize_embedding(embedding), now),
|
||||
)
|
||||
conn.commit()
|
||||
return len(chunks)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def reindex() -> dict:
|
||||
"""Rebuild entire index. Returns {"files": N, "chunks": M}."""
|
||||
conn = get_db()
|
||||
conn.execute("DELETE FROM chunks")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
files_count = 0
|
||||
chunks_count = 0
|
||||
for md_file in sorted(MEMORY_DIR.rglob("*.md")):
|
||||
try:
|
||||
n = index_file(md_file)
|
||||
files_count += 1
|
||||
chunks_count += n
|
||||
log.info("Indexed %s (%d chunks)", md_file.name, n)
|
||||
except Exception as e:
|
||||
log.warning("Failed to index %s: %s", md_file, e)
|
||||
|
||||
return {"files": files_count, "chunks": chunks_count}
|
||||
|
||||
|
||||
def search(query: str, top_k: int = 5) -> list[dict]:
|
||||
"""Search for query. Returns list of {"file": str, "chunk": str, "score": float}."""
|
||||
query_embedding = get_embedding(query)
|
||||
|
||||
conn = get_db()
|
||||
try:
|
||||
rows = conn.execute(
|
||||
"SELECT file_path, chunk_text, embedding FROM chunks"
|
||||
).fetchall()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
if not rows:
|
||||
return []
|
||||
|
||||
scored = []
|
||||
for file_path, chunk_text, emb_blob in rows:
|
||||
emb = deserialize_embedding(emb_blob)
|
||||
score = cosine_similarity(query_embedding, emb)
|
||||
scored.append({"file": file_path, "chunk": chunk_text, "score": score})
|
||||
|
||||
scored.sort(key=lambda x: x["score"], reverse=True)
|
||||
return scored[:top_k]
|
||||
Reference in New Issue
Block a user