"""Comprehensive tests for src/memory_search.py — semantic memory search."""

import argparse
import math
import sqlite3
import struct
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from src.memory_search import (
    chunk_file,
    cosine_similarity,
    deserialize_embedding,
    get_db,
    get_embedding,
    index_file,
    reindex,
    search,
    serialize_embedding,
    EMBEDDING_DIM,
    _CHUNK_TARGET,
    _CHUNK_MAX,
    _CHUNK_MIN,
)

# ---------------------------------------------------------------------------
# Helpers / fixtures
# ---------------------------------------------------------------------------

FAKE_EMBEDDING = [0.1] * EMBEDDING_DIM


def _fake_ollama_response(embedding=None):
    """Build a mock httpx.Response for Ollama embeddings."""
    if embedding is None:
        embedding = FAKE_EMBEDDING
    resp = MagicMock()
    resp.status_code = 200
    resp.raise_for_status = MagicMock()
    resp.json.return_value = {"embedding": embedding}
    return resp


@pytest.fixture
def mem_iso(tmp_path, monkeypatch):
    """Isolate memory_search module to use tmp_path for DB and memory dir."""
    mem_dir = tmp_path / "memory"
    mem_dir.mkdir()
    db_path = mem_dir / "echo.sqlite"
    monkeypatch.setattr("src.memory_search.DB_PATH", db_path)
    monkeypatch.setattr("src.memory_search.MEMORY_DIR", mem_dir)
    return {"mem_dir": mem_dir, "db_path": db_path}


def _write_md(mem_dir: Path, name: str, content: str) -> Path:
    """Write a .md file in the memory directory."""
    f = mem_dir / name
    f.write_text(content, encoding="utf-8")
    return f


def _args(**kwargs):
    """Create an argparse.Namespace with given keyword attrs."""
    return argparse.Namespace(**kwargs)


# ---------------------------------------------------------------------------
# chunk_file
# ---------------------------------------------------------------------------


class TestChunkFile:
    def test_normal_md_file(self, tmp_path):
        """Paragraphs separated by blank lines become separate chunks."""
        f = tmp_path / "notes.md"
        f.write_text("# Header\n\nParagraph one.\n\nParagraph two.\n")
        chunks = chunk_file(f)
        assert len(chunks) >= 1
        # All content should be represented
        full = "\n\n".join(chunks)
        assert "Header" in full
        assert "Paragraph one" in full
        assert "Paragraph two" in full

    def test_empty_file(self, tmp_path):
        """Empty file returns empty list."""
        f = tmp_path / "empty.md"
        f.write_text("")
        assert chunk_file(f) == []

    def test_whitespace_only_file(self, tmp_path):
        """File with only whitespace returns empty list."""
        f = tmp_path / "blank.md"
        f.write_text("  \n\n  \n")
        assert chunk_file(f) == []

    def test_single_long_paragraph(self, tmp_path):
        """A single paragraph exceeding _CHUNK_MAX gets split."""
        long_text = "word " * 300  # ~1500 chars, well over _CHUNK_MAX
        f = tmp_path / "long.md"
        f.write_text(long_text)
        chunks = chunk_file(f)
        # Should have been forced into at least one chunk
        assert len(chunks) >= 1
        # All words preserved
        joined = " ".join(chunks)
        assert "word" in joined

    def test_small_paragraphs_merged(self, tmp_path):
        """Small paragraphs below _CHUNK_MIN get merged together."""
        # 5 very small paragraphs
        content = "\n\n".join(["Hi." for _ in range(5)])
        f = tmp_path / "small.md"
        f.write_text(content)
        chunks = chunk_file(f)
        # Should merge them rather than having 5 tiny chunks
        assert len(chunks) < 5

    def test_chunks_within_size_limits(self, tmp_path):
        """All chunks (except maybe the last merged one) stay near target size."""
        paragraphs = [f"Paragraph {i}. " + ("x" * 200) for i in range(20)]
        content = "\n\n".join(paragraphs)
        f = tmp_path / "medium.md"
        f.write_text(content)
        chunks = chunk_file(f)
        assert len(chunks) > 1
        # No chunk should wildly exceed the max (some tolerance for merging)
        for chunk in chunks:
            # After merging the tiny trailing chunk, could be up to 2x max
            assert len(chunk) < _CHUNK_MAX * 2 + 200

    def test_header_splits(self, tmp_path):
        """Headers trigger chunk boundaries."""
        content = "Intro paragraph.\n\n# Section 1\n\nContent one.\n\n# Section 2\n\nContent two.\n"
        f = tmp_path / "headers.md"
        f.write_text(content)
        chunks = chunk_file(f)
        assert len(chunks) >= 1
        full = "\n\n".join(chunks)
        assert "Section 1" in full
        assert "Section 2" in full


# ---------------------------------------------------------------------------
# serialize_embedding / deserialize_embedding
# ---------------------------------------------------------------------------


class TestEmbeddingSerialization:
    def test_serialize_returns_bytes(self):
        vec = [1.0, 2.0, 3.0]
        data = serialize_embedding(vec)
        assert isinstance(data, bytes)
        assert len(data) == len(vec) * 4  # 4 bytes per float32

    def test_round_trip(self):
        vec = [0.1, -0.5, 3.14, 0.0, -1.0]
        data = serialize_embedding(vec)
        result = deserialize_embedding(data)
        assert len(result) == len(vec)
        for a, b in zip(vec, result):
            assert abs(a - b) < 1e-5

    def test_round_trip_384_dim(self):
        """Round-trip with full 384-dimension vector."""
        vec = [float(i) / 384 for i in range(EMBEDDING_DIM)]
        data = serialize_embedding(vec)
        assert len(data) == EMBEDDING_DIM * 4
        result = deserialize_embedding(data)
        assert len(result) == EMBEDDING_DIM
        for a, b in zip(vec, result):
            assert abs(a - b) < 1e-5

    def test_known_values(self):
        """Test with specific known float values."""
        vec = [1.0, -1.0, 0.0]
        data = serialize_embedding(vec)
        # Manually check packed bytes
        expected = struct.pack("3f", 1.0, -1.0, 0.0)
        assert data == expected
        assert deserialize_embedding(expected) == [1.0, -1.0, 0.0]

    def test_empty_vector(self):
        vec = []
        data = serialize_embedding(vec)
        assert data == b""
        assert deserialize_embedding(data) == []


# ---------------------------------------------------------------------------
# cosine_similarity
# ---------------------------------------------------------------------------


class TestCosineSimilarity:
    def test_identical_vectors(self):
        v = [1.0, 2.0, 3.0]
        assert abs(cosine_similarity(v, v) - 1.0) < 1e-9

    def test_orthogonal_vectors(self):
        a = [1.0, 0.0]
        b = [0.0, 1.0]
        assert abs(cosine_similarity(a, b)) < 1e-9

    def test_opposite_vectors(self):
        a = [1.0, 2.0, 3.0]
        b = [-1.0, -2.0, -3.0]
        assert abs(cosine_similarity(a, b) - (-1.0)) < 1e-9

    def test_known_similarity(self):
        """Two known vectors with a calculable similarity."""
        a = [1.0, 0.0, 0.0]
        b = [1.0, 1.0, 0.0]
        # cos(45°) = 1/sqrt(2) ≈ 0.7071
        expected = 1.0 / math.sqrt(2)
        assert abs(cosine_similarity(a, b) - expected) < 1e-9

    def test_zero_vector_returns_zero(self):
        """Zero vector should return 0.0 (not NaN)."""
        a = [0.0, 0.0, 0.0]
        b = [1.0, 2.0, 3.0]
        assert cosine_similarity(a, b) == 0.0

    def test_both_zero_vectors(self):
        a = [0.0, 0.0]
        b = [0.0, 0.0]
        assert cosine_similarity(a, b) == 0.0


# ---------------------------------------------------------------------------
# get_embedding
# ---------------------------------------------------------------------------


class TestGetEmbedding:
    @patch("src.memory_search.httpx.post")
    def test_returns_embedding(self, mock_post):
        mock_post.return_value = _fake_ollama_response()
        result = get_embedding("hello")
        assert result == FAKE_EMBEDDING
        assert len(result) == EMBEDDING_DIM

    @patch("src.memory_search.httpx.post")
    def test_raises_on_connect_error(self, mock_post):
        import httpx

        mock_post.side_effect = httpx.ConnectError("connection refused")
        with pytest.raises(ConnectionError, match="Cannot connect to Ollama"):
            get_embedding("hello")

    @patch("src.memory_search.httpx.post")
    def test_raises_on_http_error(self, mock_post):
        import httpx

        resp = MagicMock()
        resp.status_code = 500
        resp.raise_for_status.side_effect = httpx.HTTPStatusError(
            "server error", request=MagicMock(), response=resp
        )
        resp.json.return_value = {}
        mock_post.return_value = resp
        with pytest.raises(ConnectionError, match="Ollama API error"):
            get_embedding("hello")

    @patch("src.memory_search.httpx.post")
    def test_raises_on_wrong_dimension(self, mock_post):
        mock_post.return_value = _fake_ollama_response(embedding=[0.1] * 10)
        with pytest.raises(ValueError, match="Expected 384"):
            get_embedding("hello")


# ---------------------------------------------------------------------------
# get_db (SQLite)
# ---------------------------------------------------------------------------


class TestGetDb:
    def test_creates_table(self, mem_iso):
        conn = get_db()
        try:
            # Table should exist
            cursor = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name='chunks'"
            )
            assert cursor.fetchone() is not None
        finally:
            conn.close()

    def test_creates_index(self, mem_iso):
        conn = get_db()
        try:
            cursor = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='index' AND name='idx_file_path'"
            )
            assert cursor.fetchone() is not None
        finally:
            conn.close()

    def test_idempotent(self, mem_iso):
        """Calling get_db twice doesn't error (CREATE IF NOT EXISTS)."""
        conn1 = get_db()
        conn1.close()
        conn2 = get_db()
        conn2.close()

    def test_creates_parent_dir(self, tmp_path, monkeypatch):
        """Creates parent directory for the DB file if missing."""
        db_path = tmp_path / "deep" / "nested" / "echo.sqlite"
        monkeypatch.setattr("src.memory_search.DB_PATH", db_path)
        conn = get_db()
        conn.close()
        assert db_path.exists()


# ---------------------------------------------------------------------------
# index_file
# ---------------------------------------------------------------------------


class TestIndexFile:
    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_index_stores_chunks(self, mock_emb, mem_iso):
        f = _write_md(mem_iso["mem_dir"], "notes.md", "# Title\n\nSome content here.\n")
        n = index_file(f)
        assert n >= 1
        conn = get_db()
        try:
            rows = conn.execute("SELECT file_path, chunk_text FROM chunks").fetchall()
            assert len(rows) == n
            assert rows[0][0] == "notes.md"
            assert "Title" in rows[0][1] or "content" in rows[0][1]
        finally:
            conn.close()

    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_index_empty_file(self, mock_emb, mem_iso):
        f = _write_md(mem_iso["mem_dir"], "empty.md", "")
        n = index_file(f)
        assert n == 0
        mock_emb.assert_not_called()

    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_reindex_replaces_old_chunks(self, mock_emb, mem_iso):
        """Calling index_file twice for the same file replaces old chunks."""
        f = _write_md(mem_iso["mem_dir"], "test.md", "First version.\n")
        index_file(f)
        # Update the file
        f.write_text("Second version with more content.\n\nAnother paragraph.\n")
        n2 = index_file(f)
        conn = get_db()
        try:
            rows = conn.execute(
                "SELECT chunk_text FROM chunks WHERE file_path = ?", ("test.md",)
            ).fetchall()
            assert len(rows) == n2
            # Should contain new content, not old
            all_text = " ".join(r[0] for r in rows)
            assert "Second version" in all_text
        finally:
            conn.close()

    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_stores_embedding_blob(self, mock_emb, mem_iso):
        f = _write_md(mem_iso["mem_dir"], "test.md", "Some text.\n")
        index_file(f)
        conn = get_db()
        try:
            row = conn.execute("SELECT embedding FROM chunks").fetchone()
            assert row is not None
            emb = deserialize_embedding(row[0])
            assert len(emb) == EMBEDDING_DIM
        finally:
            conn.close()


# ---------------------------------------------------------------------------
# reindex
# ---------------------------------------------------------------------------


class TestReindex:
    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_reindex_all_files(self, mock_emb, mem_iso):
        _write_md(mem_iso["mem_dir"], "a.md", "File A content.\n")
        _write_md(mem_iso["mem_dir"], "b.md", "File B content.\n")
        stats = reindex()
        assert stats["files"] == 2
        assert stats["chunks"] >= 2

    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_reindex_clears_old_data(self, mock_emb, mem_iso):
        """Reindex deletes all existing chunks first."""
        f = _write_md(mem_iso["mem_dir"], "old.md", "Old content.\n")
        index_file(f)
        # Remove the file, reindex should clear stale data
        f.unlink()
        _write_md(mem_iso["mem_dir"], "new.md", "New content.\n")
        stats = reindex()
        conn = get_db()
        try:
            rows = conn.execute("SELECT DISTINCT file_path FROM chunks").fetchall()
            files = [r[0] for r in rows]
            assert "old.md" not in files
            assert "new.md" in files
        finally:
            conn.close()

    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_reindex_empty_dir(self, mock_emb, mem_iso):
        stats = reindex()
        assert stats == {"files": 0, "chunks": 0}

    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_reindex_includes_subdirs(self, mock_emb, mem_iso):
        """rglob should find .md files in subdirectories."""
        sub = mem_iso["mem_dir"] / "kb"
        sub.mkdir()
        _write_md(sub, "deep.md", "Deep content.\n")
        _write_md(mem_iso["mem_dir"], "top.md", "Top content.\n")
        stats = reindex()
        assert stats["files"] == 2

    @patch("src.memory_search.get_embedding", side_effect=ConnectionError("offline"))
    def test_reindex_handles_embedding_failure(self, mock_emb, mem_iso):
        """Files that fail to embed are skipped, not crash."""
        _write_md(mem_iso["mem_dir"], "fail.md", "Content.\n")
        stats = reindex()
        # File attempted but failed — still counted (index_file raises, caught by reindex)
        assert stats["files"] == 0
        assert stats["chunks"] == 0


# ---------------------------------------------------------------------------
# search
# ---------------------------------------------------------------------------


class TestSearch:
    def _seed_db(self, mem_iso, entries):
        """Insert test chunks into the database.

        entries: list of (file_path, chunk_text, embedding_list)
        """
        conn = get_db()
        for i, (fp, text, emb) in enumerate(entries):
            conn.execute(
                "INSERT INTO chunks (file_path, chunk_index, chunk_text, embedding, updated_at) VALUES (?, ?, ?, ?, ?)",
                (fp, i, text, serialize_embedding(emb), "2025-01-01T00:00:00"),
            )
        conn.commit()
        conn.close()

    @patch("src.memory_search.get_embedding")
    def test_search_returns_sorted(self, mock_emb, mem_iso):
        """Results are sorted by score descending."""
        query_vec = [1.0, 0.0, 0.0] + [0.0] * (EMBEDDING_DIM - 3)
        close_vec = [0.9, 0.1, 0.0] + [0.0] * (EMBEDDING_DIM - 3)
        far_vec = [0.0, 1.0, 0.0] + [0.0] * (EMBEDDING_DIM - 3)
        mock_emb.return_value = query_vec
        self._seed_db(mem_iso, [
            ("close.md", "close content", close_vec),
            ("far.md", "far content", far_vec),
        ])
        results = search("test query")
        assert len(results) == 2
        assert results[0]["file"] == "close.md"
        assert results[1]["file"] == "far.md"
        assert results[0]["score"] > results[1]["score"]

    @patch("src.memory_search.get_embedding")
    def test_search_top_k(self, mock_emb, mem_iso):
        """top_k limits the number of results."""
        query_vec = [1.0] + [0.0] * (EMBEDDING_DIM - 1)
        mock_emb.return_value = query_vec
        entries = [
            (f"file{i}.md", f"content {i}", [float(i) / 10] + [0.0] * (EMBEDDING_DIM - 1))
            for i in range(10)
        ]
        self._seed_db(mem_iso, entries)
        results = search("test", top_k=3)
        assert len(results) == 3

    @patch("src.memory_search.get_embedding")
    def test_search_empty_index(self, mock_emb, mem_iso):
        """Search with no indexed data returns empty list."""
        mock_emb.return_value = FAKE_EMBEDDING
        # Ensure db exists but is empty
        conn = get_db()
        conn.close()
        results = search("anything")
        assert results == []

    @patch("src.memory_search.get_embedding")
    def test_search_result_structure(self, mock_emb, mem_iso):
        """Each result has file, chunk, score keys."""
        query_vec = [1.0] + [0.0] * (EMBEDDING_DIM - 1)
        mock_emb.return_value = query_vec
        self._seed_db(mem_iso, [
            ("test.md", "test content", query_vec),
        ])
        results = search("query")
        assert len(results) == 1
        r = results[0]
        assert "file" in r
        assert "chunk" in r
        assert "score" in r
        assert r["file"] == "test.md"
        assert r["chunk"] == "test content"
        assert isinstance(r["score"], float)

    @patch("src.memory_search.get_embedding")
    def test_search_default_top_k(self, mock_emb, mem_iso):
        """Default top_k is 5."""
        query_vec = [1.0] + [0.0] * (EMBEDDING_DIM - 1)
        mock_emb.return_value = query_vec
        entries = [
            (f"file{i}.md", f"content {i}", [float(i) / 20] + [0.0] * (EMBEDDING_DIM - 1))
            for i in range(10)
        ]
        self._seed_db(mem_iso, entries)
        results = search("test")
        assert len(results) == 5


# ---------------------------------------------------------------------------
# CLI commands: memory search, memory reindex
# ---------------------------------------------------------------------------


class TestCliMemorySearch:
    @patch("src.memory_search.search")
    def test_memory_search_shows_results(self, mock_search, capsys):
        import cli

        mock_search.return_value = [
            {"file": "notes.md", "chunk": "Some relevant content here", "score": 0.85},
            {"file": "kb/info.md", "chunk": "Another result", "score": 0.72},
        ]
        cli._memory_search("test query")
        out = capsys.readouterr().out
        assert "notes.md" in out
        assert "0.850" in out
        assert "kb/info.md" in out
        assert "0.720" in out
        assert "Result 1" in out
        assert "Result 2" in out

    @patch("src.memory_search.search")
    def test_memory_search_empty_results(self, mock_search, capsys):
        import cli

        mock_search.return_value = []
        cli._memory_search("nothing")
        out = capsys.readouterr().out
        assert "No results" in out
        assert "reindex" in out

    @patch("src.memory_search.search", side_effect=ConnectionError("Ollama offline"))
    def test_memory_search_connection_error(self, mock_search, capsys):
        import cli

        with pytest.raises(SystemExit):
            cli._memory_search("test")
        out = capsys.readouterr().out
        assert "Ollama offline" in out

    @patch("src.memory_search.search")
    def test_memory_search_truncates_long_chunks(self, mock_search, capsys):
        import cli

        mock_search.return_value = [
            {"file": "test.md", "chunk": "x" * 500, "score": 0.9},
        ]
        cli._memory_search("query")
        out = capsys.readouterr().out
        assert "..." in out  # preview truncated at 200 chars


class TestCliMemoryReindex:
    @patch("src.memory_search.reindex")
    def test_memory_reindex_shows_stats(self, mock_reindex, capsys):
        import cli

        mock_reindex.return_value = {"files": 5, "chunks": 23}
        cli._memory_reindex()
        out = capsys.readouterr().out
        assert "5 files" in out
        assert "23 chunks" in out

    @patch("src.memory_search.reindex", side_effect=ConnectionError("Ollama down"))
    def test_memory_reindex_connection_error(self, mock_reindex, capsys):
        import cli

        with pytest.raises(SystemExit):
            cli._memory_reindex()
        out = capsys.readouterr().out
        assert "Ollama down" in out


# ---------------------------------------------------------------------------
# Discord /search command
# ---------------------------------------------------------------------------


class TestDiscordSearchCommand:
    def _find_command(self, tree, name):
        for cmd in tree.get_commands():
            if cmd.name == name:
                return cmd
        return None

    def _mock_interaction(self, user_id="123", channel_id="456"):
        interaction = AsyncMock()
        interaction.user = MagicMock()
        interaction.user.id = int(user_id)
        interaction.channel_id = int(channel_id)
        interaction.response = AsyncMock()
        interaction.response.defer = AsyncMock()
        interaction.followup = AsyncMock()
        interaction.followup.send = AsyncMock()
        return interaction

    @pytest.fixture
    def search_bot(self, tmp_path):
        """Create a bot with config for testing /search."""
        import json
        from src.config import Config
        from src.adapters.discord_bot import create_bot

        data = {
            "bot": {"name": "Echo", "default_model": "sonnet", "owner": "111", "admins": []},
            "channels": {},
        }
        config_file = tmp_path / "config.json"
        config_file.write_text(json.dumps(data, indent=2))
        config = Config(config_file)
        return create_bot(config)

    @pytest.mark.asyncio
    @patch("src.memory_search.search")
    async def test_search_command_exists(self, mock_search, search_bot):
        cmd = self._find_command(search_bot.tree, "search")
        assert cmd is not None

    @pytest.mark.asyncio
    @patch("src.memory_search.search")
    async def test_search_command_with_results(self, mock_search, search_bot):
        mock_search.return_value = [
            {"file": "notes.md", "chunk": "Relevant content", "score": 0.9},
        ]
        cmd = self._find_command(search_bot.tree, "search")
        interaction = self._mock_interaction()
        await cmd.callback(interaction, query="test query")
        interaction.response.defer.assert_awaited_once()
        interaction.followup.send.assert_awaited_once()
        msg = interaction.followup.send.call_args.args[0]
        assert "notes.md" in msg
        assert "0.9" in msg

    @pytest.mark.asyncio
    @patch("src.memory_search.search")
    async def test_search_command_empty_results(self, mock_search, search_bot):
        mock_search.return_value = []
        cmd = self._find_command(search_bot.tree, "search")
        interaction = self._mock_interaction()
        await cmd.callback(interaction, query="nothing")
        msg = interaction.followup.send.call_args.args[0]
        assert "no results" in msg.lower()

    @pytest.mark.asyncio
    @patch("src.memory_search.search", side_effect=ConnectionError("Ollama offline"))
    async def test_search_command_connection_error(self, mock_search, search_bot):
        cmd = self._find_command(search_bot.tree, "search")
        interaction = self._mock_interaction()
        await cmd.callback(interaction, query="test")
        msg = interaction.followup.send.call_args.args[0]
        assert "error" in msg.lower()