From 0ecfa630eb58a68de6a6e72244d89eebb2f1b029 Mon Sep 17 00:00:00 2001
From: MoltBot Service
Date: Fri, 13 Feb 2026 16:49:57 +0000
Subject: [PATCH] stage-10: memory search with Ollama embeddings + SQLite

Semantic search over memory/*.md files using all-minilm embeddings.
Adds /search Discord command and `echo memory search/reindex` CLI.

Co-Authored-By: Claude Opus 4.6
---
 cli.py                      |  58 +++
 src/adapters/discord_bot.py |  36 ++
 src/memory_search.py        | 210 +++++++++++
 tests/test_memory_search.py | 703 ++++++++++++++++++++++++++++++++++++
 4 files changed, 1007 insertions(+)
 create mode 100644 src/memory_search.py
 create mode 100644 tests/test_memory_search.py

diff --git a/cli.py b/cli.py
index c8c558e..dec89d7 100755
--- a/cli.py
+++ b/cli.py
@@ -405,6 +405,52 @@ def _cron_disable(name: str):
     print(f"Job '{name}' not found.")
 
 
+def cmd_memory(args):
+    """Handle memory subcommand."""
+    if args.memory_action == "search":
+        _memory_search(args.query)
+    elif args.memory_action == "reindex":
+        _memory_reindex()
+
+
+def _memory_search(query: str):
+    """Search memory and print results."""
+    from src.memory_search import search
+
+    try:
+        results = search(query)
+    except ConnectionError as e:
+        print(f"Error: {e}")
+        sys.exit(1)
+
+    if not results:
+        print("No results found (index may be empty — run `echo memory reindex`).")
+        return
+
+    for i, r in enumerate(results, 1):
+        score = r["score"]
+        print(f"\n--- Result {i} (score: {score:.3f}) ---")
+        print(f"File: {r['file']}")
+        preview = r["chunk"][:200]
+        if len(r["chunk"]) > 200:
+            preview += "..."
+        print(preview)
+
+
+def _memory_reindex():
+    """Rebuild memory search index."""
+    from src.memory_search import reindex
+
+    print("Reindexing memory files...")
+    try:
+        stats = reindex()
+    except ConnectionError as e:
+        print(f"Error: {e}")
+        sys.exit(1)
+
+    print(f"Done. Indexed {stats['files']} files, {stats['chunks']} chunks.")
+
+
 def cmd_heartbeat(args):
     """Run heartbeat health checks."""
     from src.heartbeat import run_heartbeat
@@ -515,6 +561,15 @@ def main():
 
     secrets_sub.add_parser("test", help="Check required secrets")
 
+    # memory
+    memory_parser = sub.add_parser("memory", help="Memory search commands")
+    memory_sub = memory_parser.add_subparsers(dest="memory_action")
+
+    memory_search_p = memory_sub.add_parser("search", help="Search memory files")
+    memory_search_p.add_argument("query", help="Search query text")
+
+    memory_sub.add_parser("reindex", help="Rebuild memory search index")
+
     # heartbeat
     sub.add_parser("heartbeat", help="Run heartbeat health checks")
 
@@ -563,6 +618,9 @@ def main():
             cmd_channel(a) if a.channel_action else (channel_parser.print_help() or sys.exit(0))
         ),
         "send": cmd_send,
+        "memory": lambda a: (
+            cmd_memory(a) if a.memory_action else (memory_parser.print_help() or sys.exit(0))
+        ),
         "heartbeat": cmd_heartbeat,
         "cron": lambda a: (
             cmd_cron(a) if a.cron_action else (cron_parser.print_help() or sys.exit(0))
diff --git a/src/adapters/discord_bot.py b/src/adapters/discord_bot.py
index 172c230..26a7192 100644
--- a/src/adapters/discord_bot.py
+++ b/src/adapters/discord_bot.py
@@ -124,6 +124,7 @@ def create_bot(config: Config) -> discord.Client:
         "`/logs [n]` — Show last N log lines (default 10)",
         "`/restart` — Restart the bot process (owner only)",
         "`/heartbeat` — Run heartbeat health checks",
+        "`/search <query>` — Search Echo's memory",
         "",
         "**Cron Jobs**",
         "`/cron list` — List all scheduled jobs",
@@ -413,6 +414,41 @@ def create_bot(config: Config) -> discord.Client:
                 f"Heartbeat error: {e}", ephemeral=True
             )
 
+    @tree.command(name="search", description="Search Echo's memory")
+    @app_commands.describe(query="What to search for")
+    async def search_cmd(
+        interaction: discord.Interaction, query: str
+    ) -> None:
+        await interaction.response.defer()
+        try:
+            from src.memory_search import search
+
+            results = await asyncio.to_thread(search, query)
+            if not results:
+                await interaction.followup.send(
+                    "No results found (index may be empty — run `echo memory reindex`)."
+                )
+                return
+
+            lines = [f"**Search results for:** {query}\n"]
+            for i, r in enumerate(results, 1):
+                score = r["score"]
+                preview = r["chunk"][:150]
+                if len(r["chunk"]) > 150:
+                    preview += "..."
+                lines.append(
+                    f"**{i}.** `{r['file']}` (score: {score:.3f})\n{preview}\n"
+                )
+            text = "\n".join(lines)
+            if len(text) > 1900:
+                text = text[:1900] + "\n..."
+            await interaction.followup.send(text)
+        except ConnectionError as e:
+            await interaction.followup.send(f"Search error: {e}")
+        except Exception as e:
+            logger.exception("Search command failed")
+            await interaction.followup.send(f"Search error: {e}")
+
     @tree.command(name="channels", description="List registered channels")
     async def channels(interaction: discord.Interaction) -> None:
         ch_map = config.get("channels", {})
diff --git a/src/memory_search.py b/src/memory_search.py
new file mode 100644
index 0000000..d3089cd
--- /dev/null
+++ b/src/memory_search.py
@@ -0,0 +1,210 @@
+"""Echo Core memory search — semantic search over memory/*.md files.
+
+Uses Ollama all-minilm embeddings stored in SQLite for cosine similarity search.
+""" + +import logging +import math +import sqlite3 +import struct +from datetime import datetime, timezone +from pathlib import Path + +import httpx + +log = logging.getLogger(__name__) + +OLLAMA_URL = "http://10.0.20.161:11434/api/embeddings" +OLLAMA_MODEL = "all-minilm" +EMBEDDING_DIM = 384 +DB_PATH = Path(__file__).resolve().parent.parent / "memory" / "echo.sqlite" +MEMORY_DIR = Path(__file__).resolve().parent.parent / "memory" + +_CHUNK_TARGET = 500 +_CHUNK_MAX = 1000 +_CHUNK_MIN = 100 + + +def get_db() -> sqlite3.Connection: + """Get SQLite connection, create table if needed.""" + DB_PATH.parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(str(DB_PATH)) + conn.execute( + """CREATE TABLE IF NOT EXISTS chunks ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + file_path TEXT NOT NULL, + chunk_index INTEGER NOT NULL, + chunk_text TEXT NOT NULL, + embedding BLOB NOT NULL, + updated_at TEXT NOT NULL, + UNIQUE(file_path, chunk_index) + )""" + ) + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_file_path ON chunks(file_path)" + ) + conn.commit() + return conn + + +def get_embedding(text: str) -> list[float]: + """Get embedding vector from Ollama. Returns list of 384 floats.""" + try: + resp = httpx.post( + OLLAMA_URL, + json={"model": OLLAMA_MODEL, "prompt": text}, + timeout=30.0, + ) + resp.raise_for_status() + embedding = resp.json()["embedding"] + if len(embedding) != EMBEDDING_DIM: + raise ValueError( + f"Expected {EMBEDDING_DIM} dimensions, got {len(embedding)}" + ) + return embedding + except httpx.ConnectError: + raise ConnectionError( + f"Cannot connect to Ollama at {OLLAMA_URL}. Is Ollama running?" 
+ ) + except httpx.HTTPStatusError as e: + raise ConnectionError(f"Ollama API error: {e.response.status_code}") + + +def serialize_embedding(embedding: list[float]) -> bytes: + """Pack floats to bytes for SQLite storage.""" + return struct.pack(f"{len(embedding)}f", *embedding) + + +def deserialize_embedding(data: bytes) -> list[float]: + """Unpack bytes to floats.""" + n = len(data) // 4 + return list(struct.unpack(f"{n}f", data)) + + +def cosine_similarity(a: list[float], b: list[float]) -> float: + """Compute cosine similarity between two vectors.""" + dot = sum(x * y for x, y in zip(a, b)) + norm_a = math.sqrt(sum(x * x for x in a)) + norm_b = math.sqrt(sum(x * x for x in b)) + if norm_a == 0 or norm_b == 0: + return 0.0 + return dot / (norm_a * norm_b) + + +def chunk_file(file_path: Path) -> list[str]: + """Split .md file into chunks of ~500 chars.""" + text = file_path.read_text(encoding="utf-8") + if not text.strip(): + return [] + + # Split by double newlines or headers + raw_parts: list[str] = [] + current = "" + for line in text.split("\n"): + # Split on headers or empty lines (paragraph boundaries) + if line.startswith("#") and current.strip(): + raw_parts.append(current.strip()) + current = line + "\n" + elif line.strip() == "" and current.strip(): + raw_parts.append(current.strip()) + current = "" + else: + current += line + "\n" + if current.strip(): + raw_parts.append(current.strip()) + + # Merge small chunks with next, split large ones + chunks: list[str] = [] + buffer = "" + for part in raw_parts: + if buffer and len(buffer) + len(part) + 1 > _CHUNK_MAX: + chunks.append(buffer) + buffer = part + elif buffer: + buffer = buffer + "\n\n" + part + else: + buffer = part + + # If buffer exceeds max, flush + if len(buffer) > _CHUNK_MAX: + chunks.append(buffer) + buffer = "" + + if buffer: + # Merge tiny trailing chunk with previous + if len(buffer) < _CHUNK_MIN and chunks: + chunks[-1] = chunks[-1] + "\n\n" + buffer + else: + chunks.append(buffer) + + 
return chunks + + +def index_file(file_path: Path) -> int: + """Index a single file. Returns number of chunks created.""" + rel_path = str(file_path.relative_to(MEMORY_DIR)) + chunks = chunk_file(file_path) + if not chunks: + return 0 + + now = datetime.now(timezone.utc).isoformat() + conn = get_db() + try: + conn.execute("DELETE FROM chunks WHERE file_path = ?", (rel_path,)) + for i, chunk_text in enumerate(chunks): + embedding = get_embedding(chunk_text) + conn.execute( + """INSERT INTO chunks (file_path, chunk_index, chunk_text, embedding, updated_at) + VALUES (?, ?, ?, ?, ?)""", + (rel_path, i, chunk_text, serialize_embedding(embedding), now), + ) + conn.commit() + return len(chunks) + finally: + conn.close() + + +def reindex() -> dict: + """Rebuild entire index. Returns {"files": N, "chunks": M}.""" + conn = get_db() + conn.execute("DELETE FROM chunks") + conn.commit() + conn.close() + + files_count = 0 + chunks_count = 0 + for md_file in sorted(MEMORY_DIR.rglob("*.md")): + try: + n = index_file(md_file) + files_count += 1 + chunks_count += n + log.info("Indexed %s (%d chunks)", md_file.name, n) + except Exception as e: + log.warning("Failed to index %s: %s", md_file, e) + + return {"files": files_count, "chunks": chunks_count} + + +def search(query: str, top_k: int = 5) -> list[dict]: + """Search for query. 
Returns list of {"file": str, "chunk": str, "score": float}.""" + query_embedding = get_embedding(query) + + conn = get_db() + try: + rows = conn.execute( + "SELECT file_path, chunk_text, embedding FROM chunks" + ).fetchall() + finally: + conn.close() + + if not rows: + return [] + + scored = [] + for file_path, chunk_text, emb_blob in rows: + emb = deserialize_embedding(emb_blob) + score = cosine_similarity(query_embedding, emb) + scored.append({"file": file_path, "chunk": chunk_text, "score": score}) + + scored.sort(key=lambda x: x["score"], reverse=True) + return scored[:top_k] diff --git a/tests/test_memory_search.py b/tests/test_memory_search.py new file mode 100644 index 0000000..f5e2d4a --- /dev/null +++ b/tests/test_memory_search.py @@ -0,0 +1,703 @@ +"""Comprehensive tests for src/memory_search.py — semantic memory search.""" + +import argparse +import math +import sqlite3 +import struct +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from src.memory_search import ( + chunk_file, + cosine_similarity, + deserialize_embedding, + get_db, + get_embedding, + index_file, + reindex, + search, + serialize_embedding, + EMBEDDING_DIM, + _CHUNK_TARGET, + _CHUNK_MAX, + _CHUNK_MIN, +) + + +# --------------------------------------------------------------------------- +# Helpers / fixtures +# --------------------------------------------------------------------------- + +FAKE_EMBEDDING = [0.1] * EMBEDDING_DIM + + +def _fake_ollama_response(embedding=None): + """Build a mock httpx.Response for Ollama embeddings.""" + if embedding is None: + embedding = FAKE_EMBEDDING + resp = MagicMock() + resp.status_code = 200 + resp.raise_for_status = MagicMock() + resp.json.return_value = {"embedding": embedding} + return resp + + +@pytest.fixture +def mem_iso(tmp_path, monkeypatch): + """Isolate memory_search module to use tmp_path for DB and memory dir.""" + mem_dir = tmp_path / "memory" + mem_dir.mkdir() + db_path = mem_dir / 
"echo.sqlite" + + monkeypatch.setattr("src.memory_search.DB_PATH", db_path) + monkeypatch.setattr("src.memory_search.MEMORY_DIR", mem_dir) + + return {"mem_dir": mem_dir, "db_path": db_path} + + +def _write_md(mem_dir: Path, name: str, content: str) -> Path: + """Write a .md file in the memory directory.""" + f = mem_dir / name + f.write_text(content, encoding="utf-8") + return f + + +def _args(**kwargs): + """Create an argparse.Namespace with given keyword attrs.""" + return argparse.Namespace(**kwargs) + + +# --------------------------------------------------------------------------- +# chunk_file +# --------------------------------------------------------------------------- + + +class TestChunkFile: + def test_normal_md_file(self, tmp_path): + """Paragraphs separated by blank lines become separate chunks.""" + f = tmp_path / "notes.md" + f.write_text("# Header\n\nParagraph one.\n\nParagraph two.\n") + chunks = chunk_file(f) + assert len(chunks) >= 1 + # All content should be represented + full = "\n\n".join(chunks) + assert "Header" in full + assert "Paragraph one" in full + assert "Paragraph two" in full + + def test_empty_file(self, tmp_path): + """Empty file returns empty list.""" + f = tmp_path / "empty.md" + f.write_text("") + assert chunk_file(f) == [] + + def test_whitespace_only_file(self, tmp_path): + """File with only whitespace returns empty list.""" + f = tmp_path / "blank.md" + f.write_text(" \n\n \n") + assert chunk_file(f) == [] + + def test_single_long_paragraph(self, tmp_path): + """A single paragraph exceeding _CHUNK_MAX gets split.""" + long_text = "word " * 300 # ~1500 chars, well over _CHUNK_MAX + f = tmp_path / "long.md" + f.write_text(long_text) + chunks = chunk_file(f) + # Should have been forced into at least one chunk + assert len(chunks) >= 1 + # All words preserved + joined = " ".join(chunks) + assert "word" in joined + + def test_small_paragraphs_merged(self, tmp_path): + """Small paragraphs below _CHUNK_MIN get merged together.""" + 
# 5 very small paragraphs + content = "\n\n".join(["Hi." for _ in range(5)]) + f = tmp_path / "small.md" + f.write_text(content) + chunks = chunk_file(f) + # Should merge them rather than having 5 tiny chunks + assert len(chunks) < 5 + + def test_chunks_within_size_limits(self, tmp_path): + """All chunks (except maybe the last merged one) stay near target size.""" + paragraphs = [f"Paragraph {i}. " + ("x" * 200) for i in range(20)] + content = "\n\n".join(paragraphs) + f = tmp_path / "medium.md" + f.write_text(content) + chunks = chunk_file(f) + assert len(chunks) > 1 + # No chunk should wildly exceed the max (some tolerance for merging) + for chunk in chunks: + # After merging the tiny trailing chunk, could be up to 2x max + assert len(chunk) < _CHUNK_MAX * 2 + 200 + + def test_header_splits(self, tmp_path): + """Headers trigger chunk boundaries.""" + content = "Intro paragraph.\n\n# Section 1\n\nContent one.\n\n# Section 2\n\nContent two.\n" + f = tmp_path / "headers.md" + f.write_text(content) + chunks = chunk_file(f) + assert len(chunks) >= 1 + full = "\n\n".join(chunks) + assert "Section 1" in full + assert "Section 2" in full + + +# --------------------------------------------------------------------------- +# serialize_embedding / deserialize_embedding +# --------------------------------------------------------------------------- + + +class TestEmbeddingSerialization: + def test_serialize_returns_bytes(self): + vec = [1.0, 2.0, 3.0] + data = serialize_embedding(vec) + assert isinstance(data, bytes) + assert len(data) == len(vec) * 4 # 4 bytes per float32 + + def test_round_trip(self): + vec = [0.1, -0.5, 3.14, 0.0, -1.0] + data = serialize_embedding(vec) + result = deserialize_embedding(data) + assert len(result) == len(vec) + for a, b in zip(vec, result): + assert abs(a - b) < 1e-5 + + def test_round_trip_384_dim(self): + """Round-trip with full 384-dimension vector.""" + vec = [float(i) / 384 for i in range(EMBEDDING_DIM)] + data = serialize_embedding(vec) 
+ assert len(data) == EMBEDDING_DIM * 4 + result = deserialize_embedding(data) + assert len(result) == EMBEDDING_DIM + for a, b in zip(vec, result): + assert abs(a - b) < 1e-5 + + def test_known_values(self): + """Test with specific known float values.""" + vec = [1.0, -1.0, 0.0] + data = serialize_embedding(vec) + # Manually check packed bytes + expected = struct.pack("3f", 1.0, -1.0, 0.0) + assert data == expected + assert deserialize_embedding(expected) == [1.0, -1.0, 0.0] + + def test_empty_vector(self): + vec = [] + data = serialize_embedding(vec) + assert data == b"" + assert deserialize_embedding(data) == [] + + +# --------------------------------------------------------------------------- +# cosine_similarity +# --------------------------------------------------------------------------- + + +class TestCosineSimilarity: + def test_identical_vectors(self): + v = [1.0, 2.0, 3.0] + assert abs(cosine_similarity(v, v) - 1.0) < 1e-9 + + def test_orthogonal_vectors(self): + a = [1.0, 0.0] + b = [0.0, 1.0] + assert abs(cosine_similarity(a, b)) < 1e-9 + + def test_opposite_vectors(self): + a = [1.0, 2.0, 3.0] + b = [-1.0, -2.0, -3.0] + assert abs(cosine_similarity(a, b) - (-1.0)) < 1e-9 + + def test_known_similarity(self): + """Two known vectors with a calculable similarity.""" + a = [1.0, 0.0, 0.0] + b = [1.0, 1.0, 0.0] + # cos(45°) = 1/sqrt(2) ≈ 0.7071 + expected = 1.0 / math.sqrt(2) + assert abs(cosine_similarity(a, b) - expected) < 1e-9 + + def test_zero_vector_returns_zero(self): + """Zero vector should return 0.0 (not NaN).""" + a = [0.0, 0.0, 0.0] + b = [1.0, 2.0, 3.0] + assert cosine_similarity(a, b) == 0.0 + + def test_both_zero_vectors(self): + a = [0.0, 0.0] + b = [0.0, 0.0] + assert cosine_similarity(a, b) == 0.0 + + +# --------------------------------------------------------------------------- +# get_embedding +# --------------------------------------------------------------------------- + + +class TestGetEmbedding: + 
@patch("src.memory_search.httpx.post") + def test_returns_embedding(self, mock_post): + mock_post.return_value = _fake_ollama_response() + result = get_embedding("hello") + assert result == FAKE_EMBEDDING + assert len(result) == EMBEDDING_DIM + + @patch("src.memory_search.httpx.post") + def test_raises_on_connect_error(self, mock_post): + import httpx + mock_post.side_effect = httpx.ConnectError("connection refused") + with pytest.raises(ConnectionError, match="Cannot connect to Ollama"): + get_embedding("hello") + + @patch("src.memory_search.httpx.post") + def test_raises_on_http_error(self, mock_post): + import httpx + resp = MagicMock() + resp.status_code = 500 + resp.raise_for_status.side_effect = httpx.HTTPStatusError( + "server error", request=MagicMock(), response=resp + ) + resp.json.return_value = {} + mock_post.return_value = resp + with pytest.raises(ConnectionError, match="Ollama API error"): + get_embedding("hello") + + @patch("src.memory_search.httpx.post") + def test_raises_on_wrong_dimension(self, mock_post): + mock_post.return_value = _fake_ollama_response(embedding=[0.1] * 10) + with pytest.raises(ValueError, match="Expected 384"): + get_embedding("hello") + + +# --------------------------------------------------------------------------- +# get_db (SQLite) +# --------------------------------------------------------------------------- + + +class TestGetDb: + def test_creates_table(self, mem_iso): + conn = get_db() + try: + # Table should exist + cursor = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='chunks'" + ) + assert cursor.fetchone() is not None + finally: + conn.close() + + def test_creates_index(self, mem_iso): + conn = get_db() + try: + cursor = conn.execute( + "SELECT name FROM sqlite_master WHERE type='index' AND name='idx_file_path'" + ) + assert cursor.fetchone() is not None + finally: + conn.close() + + def test_idempotent(self, mem_iso): + """Calling get_db twice doesn't error (CREATE IF NOT EXISTS).""" 
+ conn1 = get_db() + conn1.close() + conn2 = get_db() + conn2.close() + + def test_creates_parent_dir(self, tmp_path, monkeypatch): + """Creates parent directory for the DB file if missing.""" + db_path = tmp_path / "deep" / "nested" / "echo.sqlite" + monkeypatch.setattr("src.memory_search.DB_PATH", db_path) + conn = get_db() + conn.close() + assert db_path.exists() + + +# --------------------------------------------------------------------------- +# index_file +# --------------------------------------------------------------------------- + + +class TestIndexFile: + @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING) + def test_index_stores_chunks(self, mock_emb, mem_iso): + f = _write_md(mem_iso["mem_dir"], "notes.md", "# Title\n\nSome content here.\n") + n = index_file(f) + assert n >= 1 + + conn = get_db() + try: + rows = conn.execute("SELECT file_path, chunk_text FROM chunks").fetchall() + assert len(rows) == n + assert rows[0][0] == "notes.md" + assert "Title" in rows[0][1] or "content" in rows[0][1] + finally: + conn.close() + + @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING) + def test_index_empty_file(self, mock_emb, mem_iso): + f = _write_md(mem_iso["mem_dir"], "empty.md", "") + n = index_file(f) + assert n == 0 + mock_emb.assert_not_called() + + @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING) + def test_reindex_replaces_old_chunks(self, mock_emb, mem_iso): + """Calling index_file twice for the same file replaces old chunks.""" + f = _write_md(mem_iso["mem_dir"], "test.md", "First version.\n") + index_file(f) + + # Update the file + f.write_text("Second version with more content.\n\nAnother paragraph.\n") + n2 = index_file(f) + + conn = get_db() + try: + rows = conn.execute( + "SELECT chunk_text FROM chunks WHERE file_path = ?", ("test.md",) + ).fetchall() + assert len(rows) == n2 + # Should contain new content, not old + all_text = " ".join(r[0] for r in rows) + assert "Second version" in 
all_text + finally: + conn.close() + + @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING) + def test_stores_embedding_blob(self, mock_emb, mem_iso): + f = _write_md(mem_iso["mem_dir"], "test.md", "Some text.\n") + index_file(f) + + conn = get_db() + try: + row = conn.execute("SELECT embedding FROM chunks").fetchone() + assert row is not None + emb = deserialize_embedding(row[0]) + assert len(emb) == EMBEDDING_DIM + finally: + conn.close() + + +# --------------------------------------------------------------------------- +# reindex +# --------------------------------------------------------------------------- + + +class TestReindex: + @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING) + def test_reindex_all_files(self, mock_emb, mem_iso): + _write_md(mem_iso["mem_dir"], "a.md", "File A content.\n") + _write_md(mem_iso["mem_dir"], "b.md", "File B content.\n") + + stats = reindex() + assert stats["files"] == 2 + assert stats["chunks"] >= 2 + + @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING) + def test_reindex_clears_old_data(self, mock_emb, mem_iso): + """Reindex deletes all existing chunks first.""" + f = _write_md(mem_iso["mem_dir"], "old.md", "Old content.\n") + index_file(f) + + # Remove the file, reindex should clear stale data + f.unlink() + _write_md(mem_iso["mem_dir"], "new.md", "New content.\n") + stats = reindex() + + conn = get_db() + try: + rows = conn.execute("SELECT DISTINCT file_path FROM chunks").fetchall() + files = [r[0] for r in rows] + assert "old.md" not in files + assert "new.md" in files + finally: + conn.close() + + @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING) + def test_reindex_empty_dir(self, mock_emb, mem_iso): + stats = reindex() + assert stats == {"files": 0, "chunks": 0} + + @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING) + def test_reindex_includes_subdirs(self, mock_emb, mem_iso): + """rglob should find .md files in 
subdirectories.""" + sub = mem_iso["mem_dir"] / "kb" + sub.mkdir() + _write_md(sub, "deep.md", "Deep content.\n") + _write_md(mem_iso["mem_dir"], "top.md", "Top content.\n") + + stats = reindex() + assert stats["files"] == 2 + + @patch("src.memory_search.get_embedding", side_effect=ConnectionError("offline")) + def test_reindex_handles_embedding_failure(self, mock_emb, mem_iso): + """Files that fail to embed are skipped, not crash.""" + _write_md(mem_iso["mem_dir"], "fail.md", "Content.\n") + stats = reindex() + # File attempted but failed — still counted (index_file raises, caught by reindex) + assert stats["files"] == 0 + assert stats["chunks"] == 0 + + +# --------------------------------------------------------------------------- +# search +# --------------------------------------------------------------------------- + + +class TestSearch: + def _seed_db(self, mem_iso, entries): + """Insert test chunks into the database. + + entries: list of (file_path, chunk_text, embedding_list) + """ + conn = get_db() + for i, (fp, text, emb) in enumerate(entries): + conn.execute( + "INSERT INTO chunks (file_path, chunk_index, chunk_text, embedding, updated_at) VALUES (?, ?, ?, ?, ?)", + (fp, i, text, serialize_embedding(emb), "2025-01-01T00:00:00"), + ) + conn.commit() + conn.close() + + @patch("src.memory_search.get_embedding") + def test_search_returns_sorted(self, mock_emb, mem_iso): + """Results are sorted by score descending.""" + query_vec = [1.0, 0.0, 0.0] + [0.0] * (EMBEDDING_DIM - 3) + close_vec = [0.9, 0.1, 0.0] + [0.0] * (EMBEDDING_DIM - 3) + far_vec = [0.0, 1.0, 0.0] + [0.0] * (EMBEDDING_DIM - 3) + + mock_emb.return_value = query_vec + self._seed_db(mem_iso, [ + ("close.md", "close content", close_vec), + ("far.md", "far content", far_vec), + ]) + + results = search("test query") + assert len(results) == 2 + assert results[0]["file"] == "close.md" + assert results[1]["file"] == "far.md" + assert results[0]["score"] > results[1]["score"] + + 
@patch("src.memory_search.get_embedding") + def test_search_top_k(self, mock_emb, mem_iso): + """top_k limits the number of results.""" + query_vec = [1.0] + [0.0] * (EMBEDDING_DIM - 1) + mock_emb.return_value = query_vec + + entries = [ + (f"file{i}.md", f"content {i}", [float(i) / 10] + [0.0] * (EMBEDDING_DIM - 1)) + for i in range(10) + ] + self._seed_db(mem_iso, entries) + + results = search("test", top_k=3) + assert len(results) == 3 + + @patch("src.memory_search.get_embedding") + def test_search_empty_index(self, mock_emb, mem_iso): + """Search with no indexed data returns empty list.""" + mock_emb.return_value = FAKE_EMBEDDING + # Ensure db exists but is empty + conn = get_db() + conn.close() + + results = search("anything") + assert results == [] + + @patch("src.memory_search.get_embedding") + def test_search_result_structure(self, mock_emb, mem_iso): + """Each result has file, chunk, score keys.""" + query_vec = [1.0] + [0.0] * (EMBEDDING_DIM - 1) + mock_emb.return_value = query_vec + self._seed_db(mem_iso, [ + ("test.md", "test content", query_vec), + ]) + + results = search("query") + assert len(results) == 1 + r = results[0] + assert "file" in r + assert "chunk" in r + assert "score" in r + assert r["file"] == "test.md" + assert r["chunk"] == "test content" + assert isinstance(r["score"], float) + + @patch("src.memory_search.get_embedding") + def test_search_default_top_k(self, mock_emb, mem_iso): + """Default top_k is 5.""" + query_vec = [1.0] + [0.0] * (EMBEDDING_DIM - 1) + mock_emb.return_value = query_vec + + entries = [ + (f"file{i}.md", f"content {i}", [float(i) / 20] + [0.0] * (EMBEDDING_DIM - 1)) + for i in range(10) + ] + self._seed_db(mem_iso, entries) + + results = search("test") + assert len(results) == 5 + + +# --------------------------------------------------------------------------- +# CLI commands: memory search, memory reindex +# --------------------------------------------------------------------------- + + +class TestCliMemorySearch: 
+ @patch("src.memory_search.search") + def test_memory_search_shows_results(self, mock_search, capsys): + import cli + + mock_search.return_value = [ + {"file": "notes.md", "chunk": "Some relevant content here", "score": 0.85}, + {"file": "kb/info.md", "chunk": "Another result", "score": 0.72}, + ] + cli._memory_search("test query") + + out = capsys.readouterr().out + assert "notes.md" in out + assert "0.850" in out + assert "kb/info.md" in out + assert "0.720" in out + assert "Result 1" in out + assert "Result 2" in out + + @patch("src.memory_search.search") + def test_memory_search_empty_results(self, mock_search, capsys): + import cli + + mock_search.return_value = [] + cli._memory_search("nothing") + + out = capsys.readouterr().out + assert "No results" in out + assert "reindex" in out + + @patch("src.memory_search.search", side_effect=ConnectionError("Ollama offline")) + def test_memory_search_connection_error(self, mock_search, capsys): + import cli + + with pytest.raises(SystemExit): + cli._memory_search("test") + out = capsys.readouterr().out + assert "Ollama offline" in out + + @patch("src.memory_search.search") + def test_memory_search_truncates_long_chunks(self, mock_search, capsys): + import cli + + mock_search.return_value = [ + {"file": "test.md", "chunk": "x" * 500, "score": 0.9}, + ] + cli._memory_search("query") + + out = capsys.readouterr().out + assert "..." 
in out # preview truncated at 200 chars + + +class TestCliMemoryReindex: + @patch("src.memory_search.reindex") + def test_memory_reindex_shows_stats(self, mock_reindex, capsys): + import cli + + mock_reindex.return_value = {"files": 5, "chunks": 23} + cli._memory_reindex() + + out = capsys.readouterr().out + assert "5 files" in out + assert "23 chunks" in out + + @patch("src.memory_search.reindex", side_effect=ConnectionError("Ollama down")) + def test_memory_reindex_connection_error(self, mock_reindex, capsys): + import cli + + with pytest.raises(SystemExit): + cli._memory_reindex() + out = capsys.readouterr().out + assert "Ollama down" in out + + +# --------------------------------------------------------------------------- +# Discord /search command +# --------------------------------------------------------------------------- + + +class TestDiscordSearchCommand: + def _find_command(self, tree, name): + for cmd in tree.get_commands(): + if cmd.name == name: + return cmd + return None + + def _mock_interaction(self, user_id="123", channel_id="456"): + interaction = AsyncMock() + interaction.user = MagicMock() + interaction.user.id = int(user_id) + interaction.channel_id = int(channel_id) + interaction.response = AsyncMock() + interaction.response.defer = AsyncMock() + interaction.followup = AsyncMock() + interaction.followup.send = AsyncMock() + return interaction + + @pytest.fixture + def search_bot(self, tmp_path): + """Create a bot with config for testing /search.""" + import json + from src.config import Config + from src.adapters.discord_bot import create_bot + + data = { + "bot": {"name": "Echo", "default_model": "sonnet", "owner": "111", "admins": []}, + "channels": {}, + } + config_file = tmp_path / "config.json" + config_file.write_text(json.dumps(data, indent=2)) + config = Config(config_file) + return create_bot(config) + + @pytest.mark.asyncio + @patch("src.memory_search.search") + async def test_search_command_exists(self, mock_search, search_bot): + 
cmd = self._find_command(search_bot.tree, "search") + assert cmd is not None + + @pytest.mark.asyncio + @patch("src.memory_search.search") + async def test_search_command_with_results(self, mock_search, search_bot): + mock_search.return_value = [ + {"file": "notes.md", "chunk": "Relevant content", "score": 0.9}, + ] + cmd = self._find_command(search_bot.tree, "search") + interaction = self._mock_interaction() + await cmd.callback(interaction, query="test query") + + interaction.response.defer.assert_awaited_once() + interaction.followup.send.assert_awaited_once() + msg = interaction.followup.send.call_args.args[0] + assert "notes.md" in msg + assert "0.9" in msg + + @pytest.mark.asyncio + @patch("src.memory_search.search") + async def test_search_command_empty_results(self, mock_search, search_bot): + mock_search.return_value = [] + cmd = self._find_command(search_bot.tree, "search") + interaction = self._mock_interaction() + await cmd.callback(interaction, query="nothing") + + msg = interaction.followup.send.call_args.args[0] + assert "no results" in msg.lower() + + @pytest.mark.asyncio + @patch("src.memory_search.search", side_effect=ConnectionError("Ollama offline")) + async def test_search_command_connection_error(self, mock_search, search_bot): + cmd = self._find_command(search_bot.tree, "search") + interaction = self._mock_interaction() + await cmd.callback(interaction, query="test") + + msg = interaction.followup.send.call_args.args[0] + assert "error" in msg.lower()