stage-10: memory search with Ollama embeddings + SQLite

Semantic search over memory/*.md files using all-minilm embeddings.
Adds a /search Discord command and `echo memory search` / `echo memory reindex` CLI subcommands.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
MoltBot Service
2026-02-13 16:49:57 +00:00
parent 0bc4b8cb3e
commit 0ecfa630eb
4 changed files with 1007 additions and 0 deletions
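The shape of the new API, end to end (a sketch; the query string and the counts in the comments are illustrative, and the Ollama server must be reachable at the URL configured in src/memory_search.py):

    from src.memory_search import reindex, search

    stats = reindex()  # chunks and embeds every memory/*.md file via Ollama
    print(stats)       # e.g. {"files": 12, "chunks": 87}
    for hit in search("deployment checklist", top_k=3):
        print(f"{hit['score']:.3f}  {hit['file']}")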

cli.py (+58)

@@ -405,6 +405,52 @@ def _cron_disable(name: str):
print(f"Job '{name}' not found.")
def cmd_memory(args):
"""Handle memory subcommand."""
if args.memory_action == "search":
_memory_search(args.query)
elif args.memory_action == "reindex":
_memory_reindex()
def _memory_search(query: str):
"""Search memory and print results."""
from src.memory_search import search
try:
results = search(query)
except ConnectionError as e:
print(f"Error: {e}")
sys.exit(1)
if not results:
print("No results found (index may be empty — run `echo memory reindex`).")
return
for i, r in enumerate(results, 1):
score = r["score"]
print(f"\n--- Result {i} (score: {score:.3f}) ---")
print(f"File: {r['file']}")
preview = r["chunk"][:200]
if len(r["chunk"]) > 200:
preview += "..."
print(preview)
def _memory_reindex():
"""Rebuild memory search index."""
from src.memory_search import reindex
print("Reindexing memory files...")
try:
stats = reindex()
except ConnectionError as e:
print(f"Error: {e}")
sys.exit(1)
print(f"Done. Indexed {stats['files']} files, {stats['chunks']} chunks.")
def cmd_heartbeat(args):
"""Run heartbeat health checks."""
from src.heartbeat import run_heartbeat
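These handlers can be smoke-tested without going through `main()` by handing them a Namespace directly (a sketch; the query string is illustrative, and Ollama must be reachable):

    import argparse

    import cli

    # equivalent to: echo memory search "deploy notes"
    cli.cmd_memory(argparse.Namespace(memory_action="search", query="deploy notes"))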
@@ -515,6 +561,15 @@ def main():
    secrets_sub.add_parser("test", help="Check required secrets")

    # memory
    memory_parser = sub.add_parser("memory", help="Memory search commands")
    memory_sub = memory_parser.add_subparsers(dest="memory_action")
    memory_search_p = memory_sub.add_parser("search", help="Search memory files")
    memory_search_p.add_argument("query", help="Search query text")
    memory_sub.add_parser("reindex", help="Rebuild memory search index")

    # heartbeat
    sub.add_parser("heartbeat", help="Run heartbeat health checks")
@@ -563,6 +618,9 @@ def main():
            cmd_channel(a) if a.channel_action else (channel_parser.print_help() or sys.exit(0))
        ),
        "send": cmd_send,
        "memory": lambda a: (
            cmd_memory(a) if a.memory_action else (memory_parser.print_help() or sys.exit(0))
        ),
        "heartbeat": cmd_heartbeat,
        "cron": lambda a: (
            cmd_cron(a) if a.cron_action else (cron_parser.print_help() or sys.exit(0))
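The `print_help() or sys.exit(0)` idiom in these dispatch entries works because argparse's `print_help()` returns None, so the `or` always falls through to `sys.exit(0)`; a standalone illustration:

    import argparse
    import sys

    p = argparse.ArgumentParser()
    # print_help() returns None (falsy), so this prints usage, then exits 0
    p.print_help() or sys.exit(0)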

src/adapters/discord_bot.py (+36)

@@ -124,6 +124,7 @@ def create_bot(config: Config) -> discord.Client:
"`/logs [n]` — Show last N log lines (default 10)",
"`/restart` — Restart the bot process (owner only)",
"`/heartbeat` — Run heartbeat health checks",
"`/search <query>` — Search Echo's memory",
"",
"**Cron Jobs**",
"`/cron list` — List all scheduled jobs",
@@ -413,6 +414,41 @@ def create_bot(config: Config) -> discord.Client:
f"Heartbeat error: {e}", ephemeral=True
)
@tree.command(name="search", description="Search Echo's memory")
@app_commands.describe(query="What to search for")
async def search_cmd(
interaction: discord.Interaction, query: str
) -> None:
await interaction.response.defer()
try:
from src.memory_search import search
results = await asyncio.to_thread(search, query)
if not results:
await interaction.followup.send(
"No results found (index may be empty — run `echo memory reindex`)."
)
return
lines = [f"**Search results for:** {query}\n"]
for i, r in enumerate(results, 1):
score = r["score"]
preview = r["chunk"][:150]
if len(r["chunk"]) > 150:
preview += "..."
lines.append(
f"**{i}.** `{r['file']}` (score: {score:.3f})\n{preview}\n"
)
text = "\n".join(lines)
if len(text) > 1900:
text = text[:1900] + "\n..."
await interaction.followup.send(text)
except ConnectionError as e:
await interaction.followup.send(f"Search error: {e}")
except Exception as e:
logger.exception("Search command failed")
await interaction.followup.send(f"Search error: {e}")
@tree.command(name="channels", description="List registered channels")
async def channels(interaction: discord.Interaction) -> None:
ch_map = config.get("channels", {})

src/memory_search.py (new file, +210)

@@ -0,0 +1,210 @@
"""Echo Core memory search — semantic search over memory/*.md files.
Uses Ollama all-minilm embeddings stored in SQLite for cosine similarity search.
"""
import logging
import math
import sqlite3
import struct
from datetime import datetime, timezone
from pathlib import Path
import httpx
log = logging.getLogger(__name__)
OLLAMA_URL = "http://10.0.20.161:11434/api/embeddings"
OLLAMA_MODEL = "all-minilm"
EMBEDDING_DIM = 384
DB_PATH = Path(__file__).resolve().parent.parent / "memory" / "echo.sqlite"
MEMORY_DIR = Path(__file__).resolve().parent.parent / "memory"
_CHUNK_TARGET = 500
_CHUNK_MAX = 1000
_CHUNK_MIN = 100


def get_db() -> sqlite3.Connection:
    """Get SQLite connection, create table if needed."""
    DB_PATH.parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(str(DB_PATH))
    conn.execute(
        """CREATE TABLE IF NOT EXISTS chunks (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            file_path TEXT NOT NULL,
            chunk_index INTEGER NOT NULL,
            chunk_text TEXT NOT NULL,
            embedding BLOB NOT NULL,
            updated_at TEXT NOT NULL,
            UNIQUE(file_path, chunk_index)
        )"""
    )
    conn.execute(
        "CREATE INDEX IF NOT EXISTS idx_file_path ON chunks(file_path)"
    )
    conn.commit()
    return conn


def get_embedding(text: str) -> list[float]:
    """Get embedding vector from Ollama. Returns list of 384 floats."""
    try:
        resp = httpx.post(
            OLLAMA_URL,
            json={"model": OLLAMA_MODEL, "prompt": text},
            timeout=30.0,
        )
        resp.raise_for_status()
        embedding = resp.json()["embedding"]
        if len(embedding) != EMBEDDING_DIM:
            raise ValueError(
                f"Expected {EMBEDDING_DIM} dimensions, got {len(embedding)}"
            )
        return embedding
    except httpx.ConnectError:
        raise ConnectionError(
            f"Cannot connect to Ollama at {OLLAMA_URL}. Is Ollama running?"
        )
    except httpx.HTTPStatusError as e:
        raise ConnectionError(f"Ollama API error: {e.response.status_code}")


def serialize_embedding(embedding: list[float]) -> bytes:
    """Pack floats to bytes for SQLite storage."""
    return struct.pack(f"{len(embedding)}f", *embedding)


def deserialize_embedding(data: bytes) -> list[float]:
    """Unpack bytes to floats."""
    n = len(data) // 4
    return list(struct.unpack(f"{n}f", data))
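For intuition about storage size: `struct.pack` writes each value as a 4-byte native-order float32, so a full vector packs to exactly 1536 bytes:

    # quick REPL check
    assert len(serialize_embedding([0.0] * EMBEDDING_DIM)) == EMBEDDING_DIM * 4  # 1536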


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Compute cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)
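A worked example of the formula with small vectors (the same case the tests below check):

    a = [1.0, 0.0, 0.0]
    b = [1.0, 1.0, 0.0]
    # dot = 1, |a| = 1, |b| = sqrt(2)  =>  1/sqrt(2) ≈ 0.7071
    print(cosine_similarity(a, b))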


def chunk_file(file_path: Path) -> list[str]:
    """Split a .md file into paragraph-aligned chunks of up to ~1000 chars."""
    text = file_path.read_text(encoding="utf-8")
    if not text.strip():
        return []
    # Split by double newlines or headers
    raw_parts: list[str] = []
    current = ""
    for line in text.split("\n"):
        # Split on headers or empty lines (paragraph boundaries)
        if line.startswith("#") and current.strip():
            raw_parts.append(current.strip())
            current = line + "\n"
        elif line.strip() == "" and current.strip():
            raw_parts.append(current.strip())
            current = ""
        else:
            current += line + "\n"
    if current.strip():
        raw_parts.append(current.strip())
    # Merge small parts into a buffer; a single oversized part is flushed
    # whole (parts are never split mid-paragraph)
    chunks: list[str] = []
    buffer = ""
    for part in raw_parts:
        if buffer and len(buffer) + len(part) + 1 > _CHUNK_MAX:
            chunks.append(buffer)
            buffer = part
        elif buffer:
            buffer = buffer + "\n\n" + part
        else:
            buffer = part
        # If buffer exceeds max, flush
        if len(buffer) > _CHUNK_MAX:
            chunks.append(buffer)
            buffer = ""
    if buffer:
        # Merge tiny trailing chunk with previous
        if len(buffer) < _CHUNK_MIN and chunks:
            chunks[-1] = chunks[-1] + "\n\n" + buffer
        else:
            chunks.append(buffer)
    return chunks
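To make the merge/flush behavior concrete, a small illustrative run (hypothetical file contents; note that the oversized paragraph is flushed whole, never split):

    from pathlib import Path
    from tempfile import TemporaryDirectory

    from src.memory_search import chunk_file

    with TemporaryDirectory() as d:
        f = Path(d) / "example.md"
        f.write_text("# Notes\n\nShort para.\n\n" + "x" * 1200 + "\n")
        print([len(c) for c in chunk_file(f)])  # -> [20, 1200]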


def index_file(file_path: Path) -> int:
    """Index a single file. Returns number of chunks created."""
    rel_path = str(file_path.relative_to(MEMORY_DIR))
    chunks = chunk_file(file_path)
    if not chunks:
        return 0
    now = datetime.now(timezone.utc).isoformat()
    conn = get_db()
    try:
        conn.execute("DELETE FROM chunks WHERE file_path = ?", (rel_path,))
        for i, chunk_text in enumerate(chunks):
            embedding = get_embedding(chunk_text)
            conn.execute(
                """INSERT INTO chunks (file_path, chunk_index, chunk_text, embedding, updated_at)
                   VALUES (?, ?, ?, ?, ?)""",
                (rel_path, i, chunk_text, serialize_embedding(embedding), now),
            )
        conn.commit()
        return len(chunks)
    finally:
        conn.close()


def reindex() -> dict:
    """Rebuild entire index. Returns {"files": N, "chunks": M}."""
    conn = get_db()
    conn.execute("DELETE FROM chunks")
    conn.commit()
    conn.close()
    files_count = 0
    chunks_count = 0
    for md_file in sorted(MEMORY_DIR.rglob("*.md")):
        try:
            n = index_file(md_file)
            files_count += 1
            chunks_count += n
            log.info("Indexed %s (%d chunks)", md_file.name, n)
        except Exception as e:
            # Skipped files are excluded from the returned counts
            log.warning("Failed to index %s: %s", md_file, e)
    return {"files": files_count, "chunks": chunks_count}


def search(query: str, top_k: int = 5) -> list[dict]:
    """Search for query. Returns list of {"file": str, "chunk": str, "score": float}."""
    query_embedding = get_embedding(query)
    conn = get_db()
    try:
        rows = conn.execute(
            "SELECT file_path, chunk_text, embedding FROM chunks"
        ).fetchall()
    finally:
        conn.close()
    if not rows:
        return []
    scored = []
    for file_path, chunk_text, emb_blob in rows:
        emb = deserialize_embedding(emb_blob)
        score = cosine_similarity(query_embedding, emb)
        scored.append({"file": file_path, "chunk": chunk_text, "score": score})
    scored.sort(key=lambda x: x["score"], reverse=True)
    return scored[:top_k]
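`search()` is a brute-force scan: every stored embedding is deserialized and scored in Python, which is fine at the scale of a memory directory (hundreds of chunks). If the corpus outgrew that, the same math vectorizes; a sketch assuming numpy were added as a dependency (it is not one today):

    import numpy as np

    # `rows` as fetched above: (file_path, chunk_text, embedding_blob) tuples
    blobs = [r[2] for r in rows]
    embs = np.frombuffer(b"".join(blobs), dtype=np.float32).reshape(-1, EMBEDDING_DIM)
    q = np.asarray(query_embedding, dtype=np.float32)
    scores = (embs @ q) / (np.linalg.norm(embs, axis=1) * np.linalg.norm(q) + 1e-12)
    top = np.argsort(-scores)[:top_k]  # row indices of the best matches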

tests/test_memory_search.py (new file, +703)

@@ -0,0 +1,703 @@
"""Comprehensive tests for src/memory_search.py — semantic memory search."""
import argparse
import math
import sqlite3
import struct
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from src.memory_search import (
chunk_file,
cosine_similarity,
deserialize_embedding,
get_db,
get_embedding,
index_file,
reindex,
search,
serialize_embedding,
EMBEDDING_DIM,
_CHUNK_TARGET,
_CHUNK_MAX,
_CHUNK_MIN,
)

# ---------------------------------------------------------------------------
# Helpers / fixtures
# ---------------------------------------------------------------------------

FAKE_EMBEDDING = [0.1] * EMBEDDING_DIM


def _fake_ollama_response(embedding=None):
    """Build a mock httpx.Response for Ollama embeddings."""
    if embedding is None:
        embedding = FAKE_EMBEDDING
    resp = MagicMock()
    resp.status_code = 200
    resp.raise_for_status = MagicMock()
    resp.json.return_value = {"embedding": embedding}
    return resp


@pytest.fixture
def mem_iso(tmp_path, monkeypatch):
    """Isolate memory_search module to use tmp_path for DB and memory dir."""
    mem_dir = tmp_path / "memory"
    mem_dir.mkdir()
    db_path = mem_dir / "echo.sqlite"
    monkeypatch.setattr("src.memory_search.DB_PATH", db_path)
    monkeypatch.setattr("src.memory_search.MEMORY_DIR", mem_dir)
    return {"mem_dir": mem_dir, "db_path": db_path}


def _write_md(mem_dir: Path, name: str, content: str) -> Path:
    """Write a .md file in the memory directory."""
    f = mem_dir / name
    f.write_text(content, encoding="utf-8")
    return f


def _args(**kwargs):
    """Create an argparse.Namespace with given keyword attrs."""
    return argparse.Namespace(**kwargs)

# ---------------------------------------------------------------------------
# chunk_file
# ---------------------------------------------------------------------------

class TestChunkFile:
    def test_normal_md_file(self, tmp_path):
        """Paragraphs separated by blank lines become separate chunks."""
        f = tmp_path / "notes.md"
        f.write_text("# Header\n\nParagraph one.\n\nParagraph two.\n")
        chunks = chunk_file(f)
        assert len(chunks) >= 1
        # All content should be represented
        full = "\n\n".join(chunks)
        assert "Header" in full
        assert "Paragraph one" in full
        assert "Paragraph two" in full

    def test_empty_file(self, tmp_path):
        """Empty file returns empty list."""
        f = tmp_path / "empty.md"
        f.write_text("")
        assert chunk_file(f) == []

    def test_whitespace_only_file(self, tmp_path):
        """File with only whitespace returns empty list."""
        f = tmp_path / "blank.md"
        f.write_text(" \n\n \n")
        assert chunk_file(f) == []

    def test_single_long_paragraph(self, tmp_path):
        """A single paragraph exceeding _CHUNK_MAX is kept as one oversized chunk."""
        long_text = "word " * 300  # ~1500 chars, well over _CHUNK_MAX
        f = tmp_path / "long.md"
        f.write_text(long_text)
        chunks = chunk_file(f)
        # Should have been forced into at least one chunk
        assert len(chunks) >= 1
        # All words preserved
        joined = " ".join(chunks)
        assert "word" in joined

    def test_small_paragraphs_merged(self, tmp_path):
        """Small paragraphs below _CHUNK_MIN get merged together."""
        # 5 very small paragraphs
        content = "\n\n".join(["Hi." for _ in range(5)])
        f = tmp_path / "small.md"
        f.write_text(content)
        chunks = chunk_file(f)
        # Should merge them rather than having 5 tiny chunks
        assert len(chunks) < 5

    def test_chunks_within_size_limits(self, tmp_path):
        """All chunks (except maybe the last merged one) stay near target size."""
        paragraphs = [f"Paragraph {i}. " + ("x" * 200) for i in range(20)]
        content = "\n\n".join(paragraphs)
        f = tmp_path / "medium.md"
        f.write_text(content)
        chunks = chunk_file(f)
        assert len(chunks) > 1
        # No chunk should wildly exceed the max (some tolerance for merging)
        for chunk in chunks:
            # After merging the tiny trailing chunk, could be up to 2x max
            assert len(chunk) < _CHUNK_MAX * 2 + 200

    def test_header_splits(self, tmp_path):
        """Headers trigger chunk boundaries."""
        content = "Intro paragraph.\n\n# Section 1\n\nContent one.\n\n# Section 2\n\nContent two.\n"
        f = tmp_path / "headers.md"
        f.write_text(content)
        chunks = chunk_file(f)
        assert len(chunks) >= 1
        full = "\n\n".join(chunks)
        assert "Section 1" in full
        assert "Section 2" in full

# ---------------------------------------------------------------------------
# serialize_embedding / deserialize_embedding
# ---------------------------------------------------------------------------

class TestEmbeddingSerialization:
    def test_serialize_returns_bytes(self):
        vec = [1.0, 2.0, 3.0]
        data = serialize_embedding(vec)
        assert isinstance(data, bytes)
        assert len(data) == len(vec) * 4  # 4 bytes per float32

    def test_round_trip(self):
        vec = [0.1, -0.5, 3.14, 0.0, -1.0]
        data = serialize_embedding(vec)
        result = deserialize_embedding(data)
        assert len(result) == len(vec)
        for a, b in zip(vec, result):
            assert abs(a - b) < 1e-5

    def test_round_trip_384_dim(self):
        """Round-trip with full 384-dimension vector."""
        vec = [float(i) / 384 for i in range(EMBEDDING_DIM)]
        data = serialize_embedding(vec)
        assert len(data) == EMBEDDING_DIM * 4
        result = deserialize_embedding(data)
        assert len(result) == EMBEDDING_DIM
        for a, b in zip(vec, result):
            assert abs(a - b) < 1e-5

    def test_known_values(self):
        """Test with specific known float values."""
        vec = [1.0, -1.0, 0.0]
        data = serialize_embedding(vec)
        # Manually check packed bytes
        expected = struct.pack("3f", 1.0, -1.0, 0.0)
        assert data == expected
        assert deserialize_embedding(expected) == [1.0, -1.0, 0.0]

    def test_empty_vector(self):
        vec = []
        data = serialize_embedding(vec)
        assert data == b""
        assert deserialize_embedding(data) == []

# ---------------------------------------------------------------------------
# cosine_similarity
# ---------------------------------------------------------------------------

class TestCosineSimilarity:
    def test_identical_vectors(self):
        v = [1.0, 2.0, 3.0]
        assert abs(cosine_similarity(v, v) - 1.0) < 1e-9

    def test_orthogonal_vectors(self):
        a = [1.0, 0.0]
        b = [0.0, 1.0]
        assert abs(cosine_similarity(a, b)) < 1e-9

    def test_opposite_vectors(self):
        a = [1.0, 2.0, 3.0]
        b = [-1.0, -2.0, -3.0]
        assert abs(cosine_similarity(a, b) - (-1.0)) < 1e-9

    def test_known_similarity(self):
        """Two known vectors with a calculable similarity."""
        a = [1.0, 0.0, 0.0]
        b = [1.0, 1.0, 0.0]
        # cos(45°) = 1/sqrt(2) ≈ 0.7071
        expected = 1.0 / math.sqrt(2)
        assert abs(cosine_similarity(a, b) - expected) < 1e-9

    def test_zero_vector_returns_zero(self):
        """Zero vector should return 0.0 (not NaN)."""
        a = [0.0, 0.0, 0.0]
        b = [1.0, 2.0, 3.0]
        assert cosine_similarity(a, b) == 0.0

    def test_both_zero_vectors(self):
        a = [0.0, 0.0]
        b = [0.0, 0.0]
        assert cosine_similarity(a, b) == 0.0

# ---------------------------------------------------------------------------
# get_embedding
# ---------------------------------------------------------------------------

class TestGetEmbedding:
    @patch("src.memory_search.httpx.post")
    def test_returns_embedding(self, mock_post):
        mock_post.return_value = _fake_ollama_response()
        result = get_embedding("hello")
        assert result == FAKE_EMBEDDING
        assert len(result) == EMBEDDING_DIM

    @patch("src.memory_search.httpx.post")
    def test_raises_on_connect_error(self, mock_post):
        import httpx

        mock_post.side_effect = httpx.ConnectError("connection refused")
        with pytest.raises(ConnectionError, match="Cannot connect to Ollama"):
            get_embedding("hello")

    @patch("src.memory_search.httpx.post")
    def test_raises_on_http_error(self, mock_post):
        import httpx

        resp = MagicMock()
        resp.status_code = 500
        resp.raise_for_status.side_effect = httpx.HTTPStatusError(
            "server error", request=MagicMock(), response=resp
        )
        resp.json.return_value = {}
        mock_post.return_value = resp
        with pytest.raises(ConnectionError, match="Ollama API error"):
            get_embedding("hello")

    @patch("src.memory_search.httpx.post")
    def test_raises_on_wrong_dimension(self, mock_post):
        mock_post.return_value = _fake_ollama_response(embedding=[0.1] * 10)
        with pytest.raises(ValueError, match="Expected 384"):
            get_embedding("hello")

# ---------------------------------------------------------------------------
# get_db (SQLite)
# ---------------------------------------------------------------------------

class TestGetDb:
    def test_creates_table(self, mem_iso):
        conn = get_db()
        try:
            # Table should exist
            cursor = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name='chunks'"
            )
            assert cursor.fetchone() is not None
        finally:
            conn.close()

    def test_creates_index(self, mem_iso):
        conn = get_db()
        try:
            cursor = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='index' AND name='idx_file_path'"
            )
            assert cursor.fetchone() is not None
        finally:
            conn.close()

    def test_idempotent(self, mem_iso):
        """Calling get_db twice doesn't error (CREATE IF NOT EXISTS)."""
        conn1 = get_db()
        conn1.close()
        conn2 = get_db()
        conn2.close()

    def test_creates_parent_dir(self, tmp_path, monkeypatch):
        """Creates parent directory for the DB file if missing."""
        db_path = tmp_path / "deep" / "nested" / "echo.sqlite"
        monkeypatch.setattr("src.memory_search.DB_PATH", db_path)
        conn = get_db()
        conn.close()
        assert db_path.exists()

# ---------------------------------------------------------------------------
# index_file
# ---------------------------------------------------------------------------

class TestIndexFile:
    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_index_stores_chunks(self, mock_emb, mem_iso):
        f = _write_md(mem_iso["mem_dir"], "notes.md", "# Title\n\nSome content here.\n")
        n = index_file(f)
        assert n >= 1
        conn = get_db()
        try:
            rows = conn.execute("SELECT file_path, chunk_text FROM chunks").fetchall()
            assert len(rows) == n
            assert rows[0][0] == "notes.md"
            assert "Title" in rows[0][1] or "content" in rows[0][1]
        finally:
            conn.close()

    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_index_empty_file(self, mock_emb, mem_iso):
        f = _write_md(mem_iso["mem_dir"], "empty.md", "")
        n = index_file(f)
        assert n == 0
        mock_emb.assert_not_called()

    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_reindex_replaces_old_chunks(self, mock_emb, mem_iso):
        """Calling index_file twice for the same file replaces old chunks."""
        f = _write_md(mem_iso["mem_dir"], "test.md", "First version.\n")
        index_file(f)
        # Update the file
        f.write_text("Second version with more content.\n\nAnother paragraph.\n")
        n2 = index_file(f)
        conn = get_db()
        try:
            rows = conn.execute(
                "SELECT chunk_text FROM chunks WHERE file_path = ?", ("test.md",)
            ).fetchall()
            assert len(rows) == n2
            # Should contain new content, not old
            all_text = " ".join(r[0] for r in rows)
            assert "Second version" in all_text
        finally:
            conn.close()

    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_stores_embedding_blob(self, mock_emb, mem_iso):
        f = _write_md(mem_iso["mem_dir"], "test.md", "Some text.\n")
        index_file(f)
        conn = get_db()
        try:
            row = conn.execute("SELECT embedding FROM chunks").fetchone()
            assert row is not None
            emb = deserialize_embedding(row[0])
            assert len(emb) == EMBEDDING_DIM
        finally:
            conn.close()

# ---------------------------------------------------------------------------
# reindex
# ---------------------------------------------------------------------------

class TestReindex:
    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_reindex_all_files(self, mock_emb, mem_iso):
        _write_md(mem_iso["mem_dir"], "a.md", "File A content.\n")
        _write_md(mem_iso["mem_dir"], "b.md", "File B content.\n")
        stats = reindex()
        assert stats["files"] == 2
        assert stats["chunks"] >= 2

    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_reindex_clears_old_data(self, mock_emb, mem_iso):
        """Reindex deletes all existing chunks first."""
        f = _write_md(mem_iso["mem_dir"], "old.md", "Old content.\n")
        index_file(f)
        # Remove the file, reindex should clear stale data
        f.unlink()
        _write_md(mem_iso["mem_dir"], "new.md", "New content.\n")
        stats = reindex()
        conn = get_db()
        try:
            rows = conn.execute("SELECT DISTINCT file_path FROM chunks").fetchall()
            files = [r[0] for r in rows]
            assert "old.md" not in files
            assert "new.md" in files
        finally:
            conn.close()

    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_reindex_empty_dir(self, mock_emb, mem_iso):
        stats = reindex()
        assert stats == {"files": 0, "chunks": 0}

    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_reindex_includes_subdirs(self, mock_emb, mem_iso):
        """rglob should find .md files in subdirectories."""
        sub = mem_iso["mem_dir"] / "kb"
        sub.mkdir()
        _write_md(sub, "deep.md", "Deep content.\n")
        _write_md(mem_iso["mem_dir"], "top.md", "Top content.\n")
        stats = reindex()
        assert stats["files"] == 2

    @patch("src.memory_search.get_embedding", side_effect=ConnectionError("offline"))
    def test_reindex_handles_embedding_failure(self, mock_emb, mem_iso):
        """Files that fail to embed are skipped rather than crashing the reindex."""
        _write_md(mem_iso["mem_dir"], "fail.md", "Content.\n")
        stats = reindex()
        # index_file raises, reindex catches and skips, so nothing is counted
        assert stats["files"] == 0
        assert stats["chunks"] == 0

# ---------------------------------------------------------------------------
# search
# ---------------------------------------------------------------------------

class TestSearch:
    def _seed_db(self, mem_iso, entries):
        """Insert test chunks into the database.

        entries: list of (file_path, chunk_text, embedding_list)
        """
        conn = get_db()
        for i, (fp, text, emb) in enumerate(entries):
            conn.execute(
                "INSERT INTO chunks (file_path, chunk_index, chunk_text, embedding, updated_at) VALUES (?, ?, ?, ?, ?)",
                (fp, i, text, serialize_embedding(emb), "2025-01-01T00:00:00"),
            )
        conn.commit()
        conn.close()

    @patch("src.memory_search.get_embedding")
    def test_search_returns_sorted(self, mock_emb, mem_iso):
        """Results are sorted by score descending."""
        query_vec = [1.0, 0.0, 0.0] + [0.0] * (EMBEDDING_DIM - 3)
        close_vec = [0.9, 0.1, 0.0] + [0.0] * (EMBEDDING_DIM - 3)
        far_vec = [0.0, 1.0, 0.0] + [0.0] * (EMBEDDING_DIM - 3)
        mock_emb.return_value = query_vec
        self._seed_db(mem_iso, [
            ("close.md", "close content", close_vec),
            ("far.md", "far content", far_vec),
        ])
        results = search("test query")
        assert len(results) == 2
        assert results[0]["file"] == "close.md"
        assert results[1]["file"] == "far.md"
        assert results[0]["score"] > results[1]["score"]

    @patch("src.memory_search.get_embedding")
    def test_search_top_k(self, mock_emb, mem_iso):
        """top_k limits the number of results."""
        query_vec = [1.0] + [0.0] * (EMBEDDING_DIM - 1)
        mock_emb.return_value = query_vec
        entries = [
            (f"file{i}.md", f"content {i}", [float(i) / 10] + [0.0] * (EMBEDDING_DIM - 1))
            for i in range(10)
        ]
        self._seed_db(mem_iso, entries)
        results = search("test", top_k=3)
        assert len(results) == 3

    @patch("src.memory_search.get_embedding")
    def test_search_empty_index(self, mock_emb, mem_iso):
        """Search with no indexed data returns empty list."""
        mock_emb.return_value = FAKE_EMBEDDING
        # Ensure db exists but is empty
        conn = get_db()
        conn.close()
        results = search("anything")
        assert results == []

    @patch("src.memory_search.get_embedding")
    def test_search_result_structure(self, mock_emb, mem_iso):
        """Each result has file, chunk, score keys."""
        query_vec = [1.0] + [0.0] * (EMBEDDING_DIM - 1)
        mock_emb.return_value = query_vec
        self._seed_db(mem_iso, [
            ("test.md", "test content", query_vec),
        ])
        results = search("query")
        assert len(results) == 1
        r = results[0]
        assert "file" in r
        assert "chunk" in r
        assert "score" in r
        assert r["file"] == "test.md"
        assert r["chunk"] == "test content"
        assert isinstance(r["score"], float)

    @patch("src.memory_search.get_embedding")
    def test_search_default_top_k(self, mock_emb, mem_iso):
        """Default top_k is 5."""
        query_vec = [1.0] + [0.0] * (EMBEDDING_DIM - 1)
        mock_emb.return_value = query_vec
        entries = [
            (f"file{i}.md", f"content {i}", [float(i) / 20] + [0.0] * (EMBEDDING_DIM - 1))
            for i in range(10)
        ]
        self._seed_db(mem_iso, entries)
        results = search("test")
        assert len(results) == 5

# ---------------------------------------------------------------------------
# CLI commands: memory search, memory reindex
# ---------------------------------------------------------------------------

class TestCliMemorySearch:
    @patch("src.memory_search.search")
    def test_memory_search_shows_results(self, mock_search, capsys):
        import cli

        mock_search.return_value = [
            {"file": "notes.md", "chunk": "Some relevant content here", "score": 0.85},
            {"file": "kb/info.md", "chunk": "Another result", "score": 0.72},
        ]
        cli._memory_search("test query")
        out = capsys.readouterr().out
        assert "notes.md" in out
        assert "0.850" in out
        assert "kb/info.md" in out
        assert "0.720" in out
        assert "Result 1" in out
        assert "Result 2" in out

    @patch("src.memory_search.search")
    def test_memory_search_empty_results(self, mock_search, capsys):
        import cli

        mock_search.return_value = []
        cli._memory_search("nothing")
        out = capsys.readouterr().out
        assert "No results" in out
        assert "reindex" in out

    @patch("src.memory_search.search", side_effect=ConnectionError("Ollama offline"))
    def test_memory_search_connection_error(self, mock_search, capsys):
        import cli

        with pytest.raises(SystemExit):
            cli._memory_search("test")
        out = capsys.readouterr().out
        assert "Ollama offline" in out

    @patch("src.memory_search.search")
    def test_memory_search_truncates_long_chunks(self, mock_search, capsys):
        import cli

        mock_search.return_value = [
            {"file": "test.md", "chunk": "x" * 500, "score": 0.9},
        ]
        cli._memory_search("query")
        out = capsys.readouterr().out
        assert "..." in out  # preview truncated at 200 chars


class TestCliMemoryReindex:
    @patch("src.memory_search.reindex")
    def test_memory_reindex_shows_stats(self, mock_reindex, capsys):
        import cli

        mock_reindex.return_value = {"files": 5, "chunks": 23}
        cli._memory_reindex()
        out = capsys.readouterr().out
        assert "5 files" in out
        assert "23 chunks" in out

    @patch("src.memory_search.reindex", side_effect=ConnectionError("Ollama down"))
    def test_memory_reindex_connection_error(self, mock_reindex, capsys):
        import cli

        with pytest.raises(SystemExit):
            cli._memory_reindex()
        out = capsys.readouterr().out
        assert "Ollama down" in out

# ---------------------------------------------------------------------------
# Discord /search command
# ---------------------------------------------------------------------------

class TestDiscordSearchCommand:
    def _find_command(self, tree, name):
        for cmd in tree.get_commands():
            if cmd.name == name:
                return cmd
        return None

    def _mock_interaction(self, user_id="123", channel_id="456"):
        interaction = AsyncMock()
        interaction.user = MagicMock()
        interaction.user.id = int(user_id)
        interaction.channel_id = int(channel_id)
        interaction.response = AsyncMock()
        interaction.response.defer = AsyncMock()
        interaction.followup = AsyncMock()
        interaction.followup.send = AsyncMock()
        return interaction

    @pytest.fixture
    def search_bot(self, tmp_path):
        """Create a bot with config for testing /search."""
        import json

        from src.config import Config
        from src.adapters.discord_bot import create_bot

        data = {
            "bot": {"name": "Echo", "default_model": "sonnet", "owner": "111", "admins": []},
            "channels": {},
        }
        config_file = tmp_path / "config.json"
        config_file.write_text(json.dumps(data, indent=2))
        config = Config(config_file)
        return create_bot(config)

    @pytest.mark.asyncio
    @patch("src.memory_search.search")
    async def test_search_command_exists(self, mock_search, search_bot):
        cmd = self._find_command(search_bot.tree, "search")
        assert cmd is not None

    @pytest.mark.asyncio
    @patch("src.memory_search.search")
    async def test_search_command_with_results(self, mock_search, search_bot):
        mock_search.return_value = [
            {"file": "notes.md", "chunk": "Relevant content", "score": 0.9},
        ]
        cmd = self._find_command(search_bot.tree, "search")
        interaction = self._mock_interaction()
        await cmd.callback(interaction, query="test query")
        interaction.response.defer.assert_awaited_once()
        interaction.followup.send.assert_awaited_once()
        msg = interaction.followup.send.call_args.args[0]
        assert "notes.md" in msg
        assert "0.9" in msg

    @pytest.mark.asyncio
    @patch("src.memory_search.search")
    async def test_search_command_empty_results(self, mock_search, search_bot):
        mock_search.return_value = []
        cmd = self._find_command(search_bot.tree, "search")
        interaction = self._mock_interaction()
        await cmd.callback(interaction, query="nothing")
        msg = interaction.followup.send.call_args.args[0]
        assert "no results" in msg.lower()

    @pytest.mark.asyncio
    @patch("src.memory_search.search", side_effect=ConnectionError("Ollama offline"))
    async def test_search_command_connection_error(self, mock_search, search_bot):
        cmd = self._find_command(search_bot.tree, "search")
        interaction = self._mock_interaction()
        await cmd.callback(interaction, query="test")
        msg = interaction.followup.send.call_args.args[0]
        assert "error" in msg.lower()