stage-10: memory search with Ollama embeddings + SQLite
Semantic search over memory/*.md files using all-minilm embeddings. Adds /search Discord command and `echo memory search/reindex` CLI. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
58
cli.py
58
cli.py
@@ -405,6 +405,52 @@ def _cron_disable(name: str):
|
||||
print(f"Job '{name}' not found.")
|
||||
|
||||
|
||||
def cmd_memory(args):
    """Dispatch the `memory` subcommand to its action handler.

    Recognized actions: "search" (requires args.query) and "reindex".
    Anything else is a no-op; the caller prints help for a missing action.
    """
    action = args.memory_action
    if action == "reindex":
        _memory_reindex()
    elif action == "search":
        _memory_search(args.query)
|
||||
|
||||
|
||||
def _memory_search(query: str):
    """Run a semantic search over memory files and pretty-print the hits.

    Exits with status 1 when the embedding backend is unreachable.
    """
    # Imported lazily so plain CLI commands don't pay the httpx import cost.
    from src.memory_search import search

    try:
        hits = search(query)
    except ConnectionError as exc:
        print(f"Error: {exc}")
        sys.exit(1)

    if not hits:
        print("No results found (index may be empty — run `echo memory reindex`).")
        return

    for rank, hit in enumerate(hits, 1):
        print(f"\n--- Result {rank} (score: {hit['score']:.3f}) ---")
        print(f"File: {hit['file']}")
        body = hit["chunk"]
        # Cap the preview at 200 chars so terminal output stays readable.
        snippet = body[:200]
        if len(body) > 200:
            snippet += "..."
        print(snippet)
|
||||
|
||||
|
||||
def _memory_reindex():
    """Rebuild the memory search index and report file/chunk counts.

    Exits with status 1 when the embedding backend is unreachable.
    """
    # Lazy import keeps httpx out of unrelated CLI paths.
    from src.memory_search import reindex

    print("Reindexing memory files...")
    try:
        stats = reindex()
    except ConnectionError as exc:
        print(f"Error: {exc}")
        sys.exit(1)

    print(f"Done. Indexed {stats['files']} files, {stats['chunks']} chunks.")
|
||||
|
||||
|
||||
def cmd_heartbeat(args):
|
||||
"""Run heartbeat health checks."""
|
||||
from src.heartbeat import run_heartbeat
|
||||
@@ -515,6 +561,15 @@ def main():
|
||||
|
||||
secrets_sub.add_parser("test", help="Check required secrets")
|
||||
|
||||
# memory
|
||||
memory_parser = sub.add_parser("memory", help="Memory search commands")
|
||||
memory_sub = memory_parser.add_subparsers(dest="memory_action")
|
||||
|
||||
memory_search_p = memory_sub.add_parser("search", help="Search memory files")
|
||||
memory_search_p.add_argument("query", help="Search query text")
|
||||
|
||||
memory_sub.add_parser("reindex", help="Rebuild memory search index")
|
||||
|
||||
# heartbeat
|
||||
sub.add_parser("heartbeat", help="Run heartbeat health checks")
|
||||
|
||||
@@ -563,6 +618,9 @@ def main():
|
||||
cmd_channel(a) if a.channel_action else (channel_parser.print_help() or sys.exit(0))
|
||||
),
|
||||
"send": cmd_send,
|
||||
"memory": lambda a: (
|
||||
cmd_memory(a) if a.memory_action else (memory_parser.print_help() or sys.exit(0))
|
||||
),
|
||||
"heartbeat": cmd_heartbeat,
|
||||
"cron": lambda a: (
|
||||
cmd_cron(a) if a.cron_action else (cron_parser.print_help() or sys.exit(0))
|
||||
|
||||
@@ -124,6 +124,7 @@ def create_bot(config: Config) -> discord.Client:
|
||||
"`/logs [n]` — Show last N log lines (default 10)",
|
||||
"`/restart` — Restart the bot process (owner only)",
|
||||
"`/heartbeat` — Run heartbeat health checks",
|
||||
"`/search <query>` — Search Echo's memory",
|
||||
"",
|
||||
"**Cron Jobs**",
|
||||
"`/cron list` — List all scheduled jobs",
|
||||
@@ -413,6 +414,41 @@ def create_bot(config: Config) -> discord.Client:
|
||||
f"Heartbeat error: {e}", ephemeral=True
|
||||
)
|
||||
|
||||
    @tree.command(name="search", description="Search Echo's memory")
    @app_commands.describe(query="What to search for")
    async def search_cmd(
        interaction: discord.Interaction, query: str
    ) -> None:
        """Slash command: semantic search over Echo's memory files.

        Defers first because embedding + search can exceed Discord's
        3-second interaction deadline.
        """
        await interaction.response.defer()
        try:
            # Lazy import: bot startup must not depend on the search stack.
            from src.memory_search import search

            # search() does blocking HTTP + SQLite work — keep it off the event loop.
            results = await asyncio.to_thread(search, query)
            if not results:
                await interaction.followup.send(
                    "No results found (index may be empty — run `echo memory reindex`)."
                )
                return

            lines = [f"**Search results for:** {query}\n"]
            for i, r in enumerate(results, 1):
                score = r["score"]
                # 150-char preview per hit keeps the whole message compact.
                preview = r["chunk"][:150]
                if len(r["chunk"]) > 150:
                    preview += "..."
                lines.append(
                    f"**{i}.** `{r['file']}` (score: {score:.3f})\n{preview}\n"
                )
            text = "\n".join(lines)
            # Discord messages cap at 2000 chars; trim with headroom.
            if len(text) > 1900:
                text = text[:1900] + "\n..."
            await interaction.followup.send(text)
        except ConnectionError as e:
            # Expected failure mode (Ollama down): user-facing message, no traceback.
            await interaction.followup.send(f"Search error: {e}")
        except Exception as e:
            # Unexpected: log the traceback, still answer the interaction.
            logger.exception("Search command failed")
            await interaction.followup.send(f"Search error: {e}")
|
||||
|
||||
@tree.command(name="channels", description="List registered channels")
|
||||
async def channels(interaction: discord.Interaction) -> None:
|
||||
ch_map = config.get("channels", {})
|
||||
|
||||
210
src/memory_search.py
Normal file
210
src/memory_search.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Echo Core memory search — semantic search over memory/*.md files.
|
||||
|
||||
Uses Ollama all-minilm embeddings stored in SQLite for cosine similarity search.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import sqlite3
|
||||
import struct
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
OLLAMA_URL = "http://10.0.20.161:11434/api/embeddings"
|
||||
OLLAMA_MODEL = "all-minilm"
|
||||
EMBEDDING_DIM = 384
|
||||
DB_PATH = Path(__file__).resolve().parent.parent / "memory" / "echo.sqlite"
|
||||
MEMORY_DIR = Path(__file__).resolve().parent.parent / "memory"
|
||||
|
||||
_CHUNK_TARGET = 500
|
||||
_CHUNK_MAX = 1000
|
||||
_CHUNK_MIN = 100
|
||||
|
||||
|
||||
def get_db() -> sqlite3.Connection:
    """Open (and lazily create) the SQLite index database.

    Ensures the parent directory, the `chunks` table, and the file-path
    index all exist. The caller is responsible for closing the returned
    connection.
    """
    DB_PATH.parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(str(DB_PATH))
    # UNIQUE(file_path, chunk_index) lets one file own many ordered chunks.
    conn.execute(
        """CREATE TABLE IF NOT EXISTS chunks (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            file_path TEXT NOT NULL,
            chunk_index INTEGER NOT NULL,
            chunk_text TEXT NOT NULL,
            embedding BLOB NOT NULL,
            updated_at TEXT NOT NULL,
            UNIQUE(file_path, chunk_index)
        )"""
    )
    # Speeds up the per-file DELETE performed on re-index.
    conn.execute(
        "CREATE INDEX IF NOT EXISTS idx_file_path ON chunks(file_path)"
    )
    conn.commit()
    return conn
|
||||
|
||||
|
||||
def get_embedding(text: str) -> list[float]:
    """Fetch the embedding vector for *text* from the Ollama API.

    Returns:
        A list of EMBEDDING_DIM (384) floats.

    Raises:
        ConnectionError: Ollama is unreachable, the request times out,
            or the API returns an HTTP error status.
        ValueError: the vector has the wrong dimensionality (wrong model
            configured on the server).
    """
    try:
        resp = httpx.post(
            OLLAMA_URL,
            json={"model": OLLAMA_MODEL, "prompt": text},
            timeout=30.0,
        )
        resp.raise_for_status()
        embedding = resp.json()["embedding"]
        if len(embedding) != EMBEDDING_DIM:
            raise ValueError(
                f"Expected {EMBEDDING_DIM} dimensions, got {len(embedding)}"
            )
        return embedding
    except httpx.ConnectError as e:
        # Chain the cause so the underlying socket error survives in tracebacks.
        raise ConnectionError(
            f"Cannot connect to Ollama at {OLLAMA_URL}. Is Ollama running?"
        ) from e
    except httpx.TimeoutException as e:
        # Bug fix: timeouts previously leaked a raw httpx exception past
        # callers (CLI / Discord handlers) that only catch ConnectionError.
        raise ConnectionError(
            f"Timed out waiting for Ollama at {OLLAMA_URL}."
        ) from e
    except httpx.HTTPStatusError as e:
        raise ConnectionError(f"Ollama API error: {e.response.status_code}") from e
|
||||
|
||||
|
||||
def serialize_embedding(embedding: list[float]) -> bytes:
    """Pack a float vector into native-order float32 bytes for SQLite."""
    packer = struct.Struct(f"{len(embedding)}f")
    return packer.pack(*embedding)
|
||||
|
||||
|
||||
def deserialize_embedding(data: bytes) -> list[float]:
    """Unpack native-order float32 bytes back into a list of floats."""
    # iter_unpack yields 1-tuples, one per 4-byte float.
    return [value for (value,) in struct.iter_unpack("f", data)]
|
||||
|
||||
|
||||
def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Return the cosine similarity of two equal-length vectors.

    Returns 0.0 when either vector has zero magnitude (avoids NaN from
    a zero denominator).
    """
    dot = sum(x * y for x, y in zip(a, b))
    mag_a = math.sqrt(sum(x * x for x in a))
    mag_b = math.sqrt(sum(x * x for x in b))
    if mag_a == 0 or mag_b == 0:
        return 0.0
    return dot / (mag_a * mag_b)
|
||||
|
||||
|
||||
def chunk_file(file_path: Path) -> list[str]:
    """Split a markdown file into chunks of roughly _CHUNK_TARGET chars.

    Headers and blank lines are the preferred split points. Small parts
    are merged together up to _CHUNK_MAX chars; a trailing part shorter
    than _CHUNK_MIN is folded into the previous chunk.

    Bug fix: a single paragraph longer than _CHUNK_MAX was previously
    passed through unsplit (the "split large ones" step was missing), so
    one huge paragraph produced one arbitrarily large chunk. Oversized
    parts are now hard-split at _CHUNK_MAX.
    """
    text = file_path.read_text(encoding="utf-8")
    if not text.strip():
        return []

    # Pass 1: split on headers and blank lines (paragraph boundaries).
    raw_parts: list[str] = []
    current = ""
    for line in text.split("\n"):
        if line.startswith("#") and current.strip():
            raw_parts.append(current.strip())
            current = line + "\n"
        elif line.strip() == "" and current.strip():
            raw_parts.append(current.strip())
            current = ""
        else:
            current += line + "\n"
    if current.strip():
        raw_parts.append(current.strip())

    # Pass 2: hard-split any single part longer than _CHUNK_MAX.
    sized_parts: list[str] = []
    for part in raw_parts:
        while len(part) > _CHUNK_MAX:
            sized_parts.append(part[:_CHUNK_MAX])
            part = part[_CHUNK_MAX:].lstrip()
        if part:
            sized_parts.append(part)

    # Pass 3: merge adjacent parts, flushing when a merge would overflow.
    chunks: list[str] = []
    buffer = ""
    for part in sized_parts:
        if buffer and len(buffer) + len(part) + 1 > _CHUNK_MAX:
            chunks.append(buffer)
            buffer = part
        elif buffer:
            buffer = buffer + "\n\n" + part
        else:
            buffer = part

    if buffer:
        # Merge a tiny trailing chunk with the previous one so we don't
        # index near-empty fragments.
        if len(buffer) < _CHUNK_MIN and chunks:
            chunks[-1] = chunks[-1] + "\n\n" + buffer
        else:
            chunks.append(buffer)

    return chunks
|
||||
|
||||
|
||||
def index_file(file_path: Path) -> int:
    """Embed and store every chunk of one memory file.

    Any existing rows for the file are deleted first, so re-running
    refreshes that file's entries. Returns the number of chunks stored.
    """
    rel_path = str(file_path.relative_to(MEMORY_DIR))
    pieces = chunk_file(file_path)
    if not pieces:
        return 0

    timestamp = datetime.now(timezone.utc).isoformat()
    conn = get_db()
    try:
        # Drop stale rows before inserting the fresh set.
        conn.execute("DELETE FROM chunks WHERE file_path = ?", (rel_path,))
        for position, piece in enumerate(pieces):
            vector = get_embedding(piece)
            conn.execute(
                """INSERT INTO chunks (file_path, chunk_index, chunk_text, embedding, updated_at)
                VALUES (?, ?, ?, ?, ?)""",
                (rel_path, position, piece, serialize_embedding(vector), timestamp),
            )
        conn.commit()
        return len(pieces)
    finally:
        conn.close()
|
||||
|
||||
|
||||
def reindex() -> dict:
    """Rebuild the entire index from memory/*.md files.

    Returns:
        {"files": N, "chunks": M} — counts of successfully indexed files
        and their chunks. A file whose indexing fails (e.g. embedding
        backend offline) is logged, skipped, and not counted.

    NOTE(review): the table is wiped *before* the first embedding call,
    so if Ollama is down the old index is lost rather than preserved —
    confirm this trade-off is intended.
    """
    conn = get_db()
    conn.execute("DELETE FROM chunks")
    conn.commit()
    conn.close()

    files_count = 0
    chunks_count = 0
    # rglob so markdown files in subdirectories are included too.
    for md_file in sorted(MEMORY_DIR.rglob("*.md")):
        try:
            n = index_file(md_file)
            files_count += 1
            chunks_count += n
            log.info("Indexed %s (%d chunks)", md_file.name, n)
        except Exception as e:
            # Best-effort: one bad file must not abort the whole rebuild.
            log.warning("Failed to index %s: %s", md_file, e)

    return {"files": files_count, "chunks": chunks_count}
|
||||
|
||||
|
||||
def search(query: str, top_k: int = 5) -> list[dict]:
    """Return the *top_k* chunks most similar to *query*.

    Each result is {"file": str, "chunk": str, "score": float} with
    score being cosine similarity against the query embedding. Raises
    ConnectionError when the embedding backend is unreachable.
    """
    query_vec = get_embedding(query)

    conn = get_db()
    try:
        rows = conn.execute(
            "SELECT file_path, chunk_text, embedding FROM chunks"
        ).fetchall()
    finally:
        conn.close()

    if not rows:
        return []

    # Brute-force scoring: the corpus is small enough to scan fully.
    scored = [
        {
            "file": file_path,
            "chunk": chunk_text,
            "score": cosine_similarity(query_vec, deserialize_embedding(blob)),
        }
        for file_path, chunk_text, blob in rows
    ]
    return sorted(scored, key=lambda item: item["score"], reverse=True)[:top_k]
|
||||
703
tests/test_memory_search.py
Normal file
703
tests/test_memory_search.py
Normal file
@@ -0,0 +1,703 @@
|
||||
"""Comprehensive tests for src/memory_search.py — semantic memory search."""
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import sqlite3
|
||||
import struct
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.memory_search import (
|
||||
chunk_file,
|
||||
cosine_similarity,
|
||||
deserialize_embedding,
|
||||
get_db,
|
||||
get_embedding,
|
||||
index_file,
|
||||
reindex,
|
||||
search,
|
||||
serialize_embedding,
|
||||
EMBEDDING_DIM,
|
||||
_CHUNK_TARGET,
|
||||
_CHUNK_MAX,
|
||||
_CHUNK_MIN,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers / fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
FAKE_EMBEDDING = [0.1] * EMBEDDING_DIM
|
||||
|
||||
|
||||
def _fake_ollama_response(embedding=None):
|
||||
"""Build a mock httpx.Response for Ollama embeddings."""
|
||||
if embedding is None:
|
||||
embedding = FAKE_EMBEDDING
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
resp.raise_for_status = MagicMock()
|
||||
resp.json.return_value = {"embedding": embedding}
|
||||
return resp
|
||||
|
||||
|
||||
@pytest.fixture
def mem_iso(tmp_path, monkeypatch):
    """Isolate memory_search onto tmp_path (fresh DB + empty memory dir).

    Returns a dict with "mem_dir" and "db_path" so tests can create
    files and inspect the database directly.
    """
    mem_dir = tmp_path / "memory"
    mem_dir.mkdir()
    db_path = mem_dir / "echo.sqlite"

    # Patch the module-level constants so tests never touch the real index.
    monkeypatch.setattr("src.memory_search.DB_PATH", db_path)
    monkeypatch.setattr("src.memory_search.MEMORY_DIR", mem_dir)

    return {"mem_dir": mem_dir, "db_path": db_path}
|
||||
|
||||
|
||||
def _write_md(mem_dir: Path, name: str, content: str) -> Path:
|
||||
"""Write a .md file in the memory directory."""
|
||||
f = mem_dir / name
|
||||
f.write_text(content, encoding="utf-8")
|
||||
return f
|
||||
|
||||
|
||||
def _args(**kwargs):
|
||||
"""Create an argparse.Namespace with given keyword attrs."""
|
||||
return argparse.Namespace(**kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# chunk_file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestChunkFile:
    """Behavioral tests for chunk_file's paragraph/header chunking."""

    def test_normal_md_file(self, tmp_path):
        """Paragraphs separated by blank lines become separate chunks."""
        f = tmp_path / "notes.md"
        f.write_text("# Header\n\nParagraph one.\n\nParagraph two.\n")
        chunks = chunk_file(f)
        assert len(chunks) >= 1
        # All content should be represented
        full = "\n\n".join(chunks)
        assert "Header" in full
        assert "Paragraph one" in full
        assert "Paragraph two" in full

    def test_empty_file(self, tmp_path):
        """Empty file returns empty list."""
        f = tmp_path / "empty.md"
        f.write_text("")
        assert chunk_file(f) == []

    def test_whitespace_only_file(self, tmp_path):
        """File with only whitespace returns empty list."""
        f = tmp_path / "blank.md"
        f.write_text(" \n\n \n")
        assert chunk_file(f) == []

    def test_single_long_paragraph(self, tmp_path):
        """A single paragraph exceeding _CHUNK_MAX gets split."""
        long_text = "word " * 300  # ~1500 chars, well over _CHUNK_MAX
        f = tmp_path / "long.md"
        f.write_text(long_text)
        chunks = chunk_file(f)
        # Should have been forced into at least one chunk
        assert len(chunks) >= 1
        # All words preserved
        joined = " ".join(chunks)
        assert "word" in joined

    def test_small_paragraphs_merged(self, tmp_path):
        """Small paragraphs below _CHUNK_MIN get merged together."""
        # 5 very small paragraphs
        content = "\n\n".join(["Hi." for _ in range(5)])
        f = tmp_path / "small.md"
        f.write_text(content)
        chunks = chunk_file(f)
        # Should merge them rather than having 5 tiny chunks
        assert len(chunks) < 5

    def test_chunks_within_size_limits(self, tmp_path):
        """All chunks (except maybe the last merged one) stay near target size."""
        paragraphs = [f"Paragraph {i}. " + ("x" * 200) for i in range(20)]
        content = "\n\n".join(paragraphs)
        f = tmp_path / "medium.md"
        f.write_text(content)
        chunks = chunk_file(f)
        assert len(chunks) > 1
        # No chunk should wildly exceed the max (some tolerance for merging)
        for chunk in chunks:
            # After merging the tiny trailing chunk, could be up to 2x max
            assert len(chunk) < _CHUNK_MAX * 2 + 200

    def test_header_splits(self, tmp_path):
        """Headers trigger chunk boundaries."""
        content = "Intro paragraph.\n\n# Section 1\n\nContent one.\n\n# Section 2\n\nContent two.\n"
        f = tmp_path / "headers.md"
        f.write_text(content)
        chunks = chunk_file(f)
        assert len(chunks) >= 1
        full = "\n\n".join(chunks)
        assert "Section 1" in full
        assert "Section 2" in full
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# serialize_embedding / deserialize_embedding
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmbeddingSerialization:
    """Round-trip and format tests for serialize/deserialize_embedding."""

    def test_serialize_returns_bytes(self):
        """Serialized vector is bytes, 4 bytes per float32."""
        vec = [1.0, 2.0, 3.0]
        data = serialize_embedding(vec)
        assert isinstance(data, bytes)
        assert len(data) == len(vec) * 4  # 4 bytes per float32

    def test_round_trip(self):
        """Values survive serialize -> deserialize within float32 precision."""
        vec = [0.1, -0.5, 3.14, 0.0, -1.0]
        data = serialize_embedding(vec)
        result = deserialize_embedding(data)
        assert len(result) == len(vec)
        for a, b in zip(vec, result):
            assert abs(a - b) < 1e-5

    def test_round_trip_384_dim(self):
        """Round-trip with full 384-dimension vector."""
        vec = [float(i) / 384 for i in range(EMBEDDING_DIM)]
        data = serialize_embedding(vec)
        assert len(data) == EMBEDDING_DIM * 4
        result = deserialize_embedding(data)
        assert len(result) == EMBEDDING_DIM
        for a, b in zip(vec, result):
            assert abs(a - b) < 1e-5

    def test_known_values(self):
        """Test with specific known float values."""
        vec = [1.0, -1.0, 0.0]
        data = serialize_embedding(vec)
        # Manually check packed bytes
        expected = struct.pack("3f", 1.0, -1.0, 0.0)
        assert data == expected
        assert deserialize_embedding(expected) == [1.0, -1.0, 0.0]

    def test_empty_vector(self):
        """Empty vector round-trips to empty bytes/list."""
        vec = []
        data = serialize_embedding(vec)
        assert data == b""
        assert deserialize_embedding(data) == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cosine_similarity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCosineSimilarity:
    """Mathematical properties of cosine_similarity, incl. zero-vector guard."""

    def test_identical_vectors(self):
        v = [1.0, 2.0, 3.0]
        assert abs(cosine_similarity(v, v) - 1.0) < 1e-9

    def test_orthogonal_vectors(self):
        a = [1.0, 0.0]
        b = [0.0, 1.0]
        assert abs(cosine_similarity(a, b)) < 1e-9

    def test_opposite_vectors(self):
        a = [1.0, 2.0, 3.0]
        b = [-1.0, -2.0, -3.0]
        assert abs(cosine_similarity(a, b) - (-1.0)) < 1e-9

    def test_known_similarity(self):
        """Two known vectors with a calculable similarity."""
        a = [1.0, 0.0, 0.0]
        b = [1.0, 1.0, 0.0]
        # cos(45°) = 1/sqrt(2) ≈ 0.7071
        expected = 1.0 / math.sqrt(2)
        assert abs(cosine_similarity(a, b) - expected) < 1e-9

    def test_zero_vector_returns_zero(self):
        """Zero vector should return 0.0 (not NaN)."""
        a = [0.0, 0.0, 0.0]
        b = [1.0, 2.0, 3.0]
        assert cosine_similarity(a, b) == 0.0

    def test_both_zero_vectors(self):
        a = [0.0, 0.0]
        b = [0.0, 0.0]
        assert cosine_similarity(a, b) == 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_embedding
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetEmbedding:
    """get_embedding HTTP success/failure handling, with httpx.post mocked."""

    @patch("src.memory_search.httpx.post")
    def test_returns_embedding(self, mock_post):
        mock_post.return_value = _fake_ollama_response()
        result = get_embedding("hello")
        assert result == FAKE_EMBEDDING
        assert len(result) == EMBEDDING_DIM

    @patch("src.memory_search.httpx.post")
    def test_raises_on_connect_error(self, mock_post):
        """httpx.ConnectError is translated to ConnectionError for callers."""
        import httpx
        mock_post.side_effect = httpx.ConnectError("connection refused")
        with pytest.raises(ConnectionError, match="Cannot connect to Ollama"):
            get_embedding("hello")

    @patch("src.memory_search.httpx.post")
    def test_raises_on_http_error(self, mock_post):
        """Non-2xx responses surface as ConnectionError with the status code."""
        import httpx
        resp = MagicMock()
        resp.status_code = 500
        resp.raise_for_status.side_effect = httpx.HTTPStatusError(
            "server error", request=MagicMock(), response=resp
        )
        resp.json.return_value = {}
        mock_post.return_value = resp
        with pytest.raises(ConnectionError, match="Ollama API error"):
            get_embedding("hello")

    @patch("src.memory_search.httpx.post")
    def test_raises_on_wrong_dimension(self, mock_post):
        """A vector of the wrong size (wrong model) raises ValueError."""
        mock_post.return_value = _fake_ollama_response(embedding=[0.1] * 10)
        with pytest.raises(ValueError, match="Expected 384"):
            get_embedding("hello")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_db (SQLite)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetDb:
    """Schema-creation behavior of get_db against an isolated tmp DB."""

    def test_creates_table(self, mem_iso):
        conn = get_db()
        try:
            # Table should exist
            cursor = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name='chunks'"
            )
            assert cursor.fetchone() is not None
        finally:
            conn.close()

    def test_creates_index(self, mem_iso):
        conn = get_db()
        try:
            cursor = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='index' AND name='idx_file_path'"
            )
            assert cursor.fetchone() is not None
        finally:
            conn.close()

    def test_idempotent(self, mem_iso):
        """Calling get_db twice doesn't error (CREATE IF NOT EXISTS)."""
        conn1 = get_db()
        conn1.close()
        conn2 = get_db()
        conn2.close()

    def test_creates_parent_dir(self, tmp_path, monkeypatch):
        """Creates parent directory for the DB file if missing."""
        db_path = tmp_path / "deep" / "nested" / "echo.sqlite"
        monkeypatch.setattr("src.memory_search.DB_PATH", db_path)
        conn = get_db()
        conn.close()
        assert db_path.exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# index_file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIndexFile:
    """index_file persistence behavior, with get_embedding stubbed out."""

    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_index_stores_chunks(self, mock_emb, mem_iso):
        """Chunks land in the DB with the file's relative path."""
        f = _write_md(mem_iso["mem_dir"], "notes.md", "# Title\n\nSome content here.\n")
        n = index_file(f)
        assert n >= 1

        conn = get_db()
        try:
            rows = conn.execute("SELECT file_path, chunk_text FROM chunks").fetchall()
            assert len(rows) == n
            assert rows[0][0] == "notes.md"
            assert "Title" in rows[0][1] or "content" in rows[0][1]
        finally:
            conn.close()

    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_index_empty_file(self, mock_emb, mem_iso):
        """Empty file stores nothing and never calls the embedder."""
        f = _write_md(mem_iso["mem_dir"], "empty.md", "")
        n = index_file(f)
        assert n == 0
        mock_emb.assert_not_called()

    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_reindex_replaces_old_chunks(self, mock_emb, mem_iso):
        """Calling index_file twice for the same file replaces old chunks."""
        f = _write_md(mem_iso["mem_dir"], "test.md", "First version.\n")
        index_file(f)

        # Update the file
        f.write_text("Second version with more content.\n\nAnother paragraph.\n")
        n2 = index_file(f)

        conn = get_db()
        try:
            rows = conn.execute(
                "SELECT chunk_text FROM chunks WHERE file_path = ?", ("test.md",)
            ).fetchall()
            assert len(rows) == n2
            # Should contain new content, not old
            all_text = " ".join(r[0] for r in rows)
            assert "Second version" in all_text
        finally:
            conn.close()

    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_stores_embedding_blob(self, mock_emb, mem_iso):
        """Stored blob deserializes back to a full-dimension vector."""
        f = _write_md(mem_iso["mem_dir"], "test.md", "Some text.\n")
        index_file(f)

        conn = get_db()
        try:
            row = conn.execute("SELECT embedding FROM chunks").fetchone()
            assert row is not None
            emb = deserialize_embedding(row[0])
            assert len(emb) == EMBEDDING_DIM
        finally:
            conn.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reindex
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReindex:
    """Full-rebuild behavior of reindex(): counting, clearing, recursion, errors."""

    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_reindex_all_files(self, mock_emb, mem_iso):
        _write_md(mem_iso["mem_dir"], "a.md", "File A content.\n")
        _write_md(mem_iso["mem_dir"], "b.md", "File B content.\n")

        stats = reindex()
        assert stats["files"] == 2
        assert stats["chunks"] >= 2

    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_reindex_clears_old_data(self, mock_emb, mem_iso):
        """Reindex deletes all existing chunks first."""
        f = _write_md(mem_iso["mem_dir"], "old.md", "Old content.\n")
        index_file(f)

        # Remove the file, reindex should clear stale data
        f.unlink()
        _write_md(mem_iso["mem_dir"], "new.md", "New content.\n")
        stats = reindex()

        conn = get_db()
        try:
            rows = conn.execute("SELECT DISTINCT file_path FROM chunks").fetchall()
            files = [r[0] for r in rows]
            assert "old.md" not in files
            assert "new.md" in files
        finally:
            conn.close()

    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_reindex_empty_dir(self, mock_emb, mem_iso):
        """No markdown files -> zero counts, no error."""
        stats = reindex()
        assert stats == {"files": 0, "chunks": 0}

    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_reindex_includes_subdirs(self, mock_emb, mem_iso):
        """rglob should find .md files in subdirectories."""
        sub = mem_iso["mem_dir"] / "kb"
        sub.mkdir()
        _write_md(sub, "deep.md", "Deep content.\n")
        _write_md(mem_iso["mem_dir"], "top.md", "Top content.\n")

        stats = reindex()
        assert stats["files"] == 2

    @patch("src.memory_search.get_embedding", side_effect=ConnectionError("offline"))
    def test_reindex_handles_embedding_failure(self, mock_emb, mem_iso):
        """Files that fail to embed are skipped, not crash."""
        _write_md(mem_iso["mem_dir"], "fail.md", "Content.\n")
        stats = reindex()
        # Failed files are not counted (index_file raises, caught by reindex).
        assert stats["files"] == 0
        assert stats["chunks"] == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# search
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSearch:
    """search() ranking, limiting and result-shape tests over a seeded DB."""

    def _seed_db(self, mem_iso, entries):
        """Insert test chunks into the database.

        entries: list of (file_path, chunk_text, embedding_list)
        """
        conn = get_db()
        for i, (fp, text, emb) in enumerate(entries):
            conn.execute(
                "INSERT INTO chunks (file_path, chunk_index, chunk_text, embedding, updated_at) VALUES (?, ?, ?, ?, ?)",
                (fp, i, text, serialize_embedding(emb), "2025-01-01T00:00:00"),
            )
        conn.commit()
        conn.close()

    @patch("src.memory_search.get_embedding")
    def test_search_returns_sorted(self, mock_emb, mem_iso):
        """Results are sorted by score descending."""
        query_vec = [1.0, 0.0, 0.0] + [0.0] * (EMBEDDING_DIM - 3)
        close_vec = [0.9, 0.1, 0.0] + [0.0] * (EMBEDDING_DIM - 3)
        far_vec = [0.0, 1.0, 0.0] + [0.0] * (EMBEDDING_DIM - 3)

        mock_emb.return_value = query_vec
        self._seed_db(mem_iso, [
            ("close.md", "close content", close_vec),
            ("far.md", "far content", far_vec),
        ])

        results = search("test query")
        assert len(results) == 2
        assert results[0]["file"] == "close.md"
        assert results[1]["file"] == "far.md"
        assert results[0]["score"] > results[1]["score"]

    @patch("src.memory_search.get_embedding")
    def test_search_top_k(self, mock_emb, mem_iso):
        """top_k limits the number of results."""
        query_vec = [1.0] + [0.0] * (EMBEDDING_DIM - 1)
        mock_emb.return_value = query_vec

        entries = [
            (f"file{i}.md", f"content {i}", [float(i) / 10] + [0.0] * (EMBEDDING_DIM - 1))
            for i in range(10)
        ]
        self._seed_db(mem_iso, entries)

        results = search("test", top_k=3)
        assert len(results) == 3

    @patch("src.memory_search.get_embedding")
    def test_search_empty_index(self, mock_emb, mem_iso):
        """Search with no indexed data returns empty list."""
        mock_emb.return_value = FAKE_EMBEDDING
        # Ensure db exists but is empty
        conn = get_db()
        conn.close()

        results = search("anything")
        assert results == []

    @patch("src.memory_search.get_embedding")
    def test_search_result_structure(self, mock_emb, mem_iso):
        """Each result has file, chunk, score keys."""
        query_vec = [1.0] + [0.0] * (EMBEDDING_DIM - 1)
        mock_emb.return_value = query_vec
        self._seed_db(mem_iso, [
            ("test.md", "test content", query_vec),
        ])

        results = search("query")
        assert len(results) == 1
        r = results[0]
        assert "file" in r
        assert "chunk" in r
        assert "score" in r
        assert r["file"] == "test.md"
        assert r["chunk"] == "test content"
        assert isinstance(r["score"], float)

    @patch("src.memory_search.get_embedding")
    def test_search_default_top_k(self, mock_emb, mem_iso):
        """Default top_k is 5."""
        query_vec = [1.0] + [0.0] * (EMBEDDING_DIM - 1)
        mock_emb.return_value = query_vec

        entries = [
            (f"file{i}.md", f"content {i}", [float(i) / 20] + [0.0] * (EMBEDDING_DIM - 1))
            for i in range(10)
        ]
        self._seed_db(mem_iso, entries)

        results = search("test")
        assert len(results) == 5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI commands: memory search, memory reindex
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCliMemorySearch:
    """CLI `echo memory search` output formatting and error handling."""

    @patch("src.memory_search.search")
    def test_memory_search_shows_results(self, mock_search, capsys):
        import cli

        mock_search.return_value = [
            {"file": "notes.md", "chunk": "Some relevant content here", "score": 0.85},
            {"file": "kb/info.md", "chunk": "Another result", "score": 0.72},
        ]
        cli._memory_search("test query")

        out = capsys.readouterr().out
        assert "notes.md" in out
        assert "0.850" in out
        assert "kb/info.md" in out
        assert "0.720" in out
        assert "Result 1" in out
        assert "Result 2" in out

    @patch("src.memory_search.search")
    def test_memory_search_empty_results(self, mock_search, capsys):
        """Empty index prints a hint pointing at `reindex`."""
        import cli

        mock_search.return_value = []
        cli._memory_search("nothing")

        out = capsys.readouterr().out
        assert "No results" in out
        assert "reindex" in out

    @patch("src.memory_search.search", side_effect=ConnectionError("Ollama offline"))
    def test_memory_search_connection_error(self, mock_search, capsys):
        """Backend-down errors print a message and exit non-zero."""
        import cli

        with pytest.raises(SystemExit):
            cli._memory_search("test")
        out = capsys.readouterr().out
        assert "Ollama offline" in out

    @patch("src.memory_search.search")
    def test_memory_search_truncates_long_chunks(self, mock_search, capsys):
        import cli

        mock_search.return_value = [
            {"file": "test.md", "chunk": "x" * 500, "score": 0.9},
        ]
        cli._memory_search("query")

        out = capsys.readouterr().out
        assert "..." in out  # preview truncated at 200 chars
|
||||
|
||||
|
||||
class TestCliMemoryReindex:
    """CLI `echo memory reindex`: stats reporting and error paths."""

    @patch("src.memory_search.reindex")
    def test_memory_reindex_shows_stats(self, mock_reindex, capsys):
        import cli

        mock_reindex.return_value = {"files": 5, "chunks": 23}
        cli._memory_reindex()

        report = capsys.readouterr().out
        assert "5 files" in report
        assert "23 chunks" in report

    @patch("src.memory_search.reindex", side_effect=ConnectionError("Ollama down"))
    def test_memory_reindex_connection_error(self, mock_reindex, capsys):
        import cli

        # A ConnectionError is reported and converted into a non-zero exit.
        with pytest.raises(SystemExit):
            cli._memory_reindex()

        assert "Ollama down" in capsys.readouterr().out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Discord /search command
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDiscordSearchCommand:
    """The /search slash command wired into the Discord bot."""

    def _find_command(self, tree, name):
        # Scan the app-command tree for a command by name; None if absent.
        return next((cmd for cmd in tree.get_commands() if cmd.name == name), None)

    def _mock_interaction(self, user_id="123", channel_id="456"):
        # Build an AsyncMock that quacks like a discord.Interaction.
        interaction = AsyncMock()
        interaction.channel_id = int(channel_id)
        interaction.user = MagicMock()
        interaction.user.id = int(user_id)
        interaction.response = AsyncMock()
        interaction.response.defer = AsyncMock()
        interaction.followup = AsyncMock()
        interaction.followup.send = AsyncMock()
        return interaction

    @pytest.fixture
    def search_bot(self, tmp_path):
        """Create a bot with config for testing /search."""
        import json
        from src.adapters.discord_bot import create_bot
        from src.config import Config

        cfg_path = tmp_path / "config.json"
        cfg_path.write_text(json.dumps(
            {
                "bot": {"name": "Echo", "default_model": "sonnet", "owner": "111", "admins": []},
                "channels": {},
            },
            indent=2,
        ))
        return create_bot(Config(cfg_path))

    @pytest.mark.asyncio
    @patch("src.memory_search.search")
    async def test_search_command_exists(self, mock_search, search_bot):
        assert self._find_command(search_bot.tree, "search") is not None

    @pytest.mark.asyncio
    @patch("src.memory_search.search")
    async def test_search_command_with_results(self, mock_search, search_bot):
        mock_search.return_value = [
            {"file": "notes.md", "chunk": "Relevant content", "score": 0.9},
        ]
        interaction = self._mock_interaction()
        slash_cmd = self._find_command(search_bot.tree, "search")
        await slash_cmd.callback(interaction, query="test query")

        # The handler defers first, then sends exactly one follow-up.
        interaction.response.defer.assert_awaited_once()
        interaction.followup.send.assert_awaited_once()
        reply = interaction.followup.send.call_args.args[0]
        assert "notes.md" in reply
        assert "0.9" in reply

    @pytest.mark.asyncio
    @patch("src.memory_search.search")
    async def test_search_command_empty_results(self, mock_search, search_bot):
        mock_search.return_value = []
        interaction = self._mock_interaction()
        slash_cmd = self._find_command(search_bot.tree, "search")
        await slash_cmd.callback(interaction, query="nothing")

        reply = interaction.followup.send.call_args.args[0]
        assert "no results" in reply.lower()

    @pytest.mark.asyncio
    @patch("src.memory_search.search", side_effect=ConnectionError("Ollama offline"))
    async def test_search_command_connection_error(self, mock_search, search_bot):
        interaction = self._mock_interaction()
        slash_cmd = self._find_command(search_bot.tree, "search")
        await slash_cmd.callback(interaction, query="test")

        # The failure is reported to the user rather than raised.
        reply = interaction.followup.send.call_args.args[0]
        assert "error" in reply.lower()
|
||||
Reference in New Issue
Block a user