Semantic search over memory/*.md files using all-minilm embeddings. Adds /search Discord command and `echo memory search/reindex` CLI. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
704 lines
24 KiB
Python
704 lines
24 KiB
Python
"""Comprehensive tests for src/memory_search.py — semantic memory search."""
|
|
|
|
import argparse
|
|
import math
|
|
import sqlite3
|
|
import struct
|
|
from pathlib import Path
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from src.memory_search import (
|
|
chunk_file,
|
|
cosine_similarity,
|
|
deserialize_embedding,
|
|
get_db,
|
|
get_embedding,
|
|
index_file,
|
|
reindex,
|
|
search,
|
|
serialize_embedding,
|
|
EMBEDDING_DIM,
|
|
_CHUNK_TARGET,
|
|
_CHUNK_MAX,
|
|
_CHUNK_MIN,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers / fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Deterministic stand-in embedding with the model's exact dimensionality,
# used wherever get_embedding is mocked out.
FAKE_EMBEDDING = [0.1] * EMBEDDING_DIM
|
|
|
|
|
|
def _fake_ollama_response(embedding=None):
    """Construct a MagicMock standing in for an httpx.Response from Ollama.

    The mock reports HTTP 200, a no-op ``raise_for_status``, and a JSON
    body of ``{"embedding": <vector>}`` (defaults to FAKE_EMBEDDING).
    """
    vector = FAKE_EMBEDDING if embedding is None else embedding
    response = MagicMock()
    response.status_code = 200
    response.raise_for_status = MagicMock()
    response.json.return_value = {"embedding": vector}
    return response
|
|
|
|
|
|
@pytest.fixture
def mem_iso(tmp_path, monkeypatch):
    """Redirect memory_search to a temporary memory dir and SQLite DB.

    Returns a dict with ``mem_dir`` and ``db_path`` so tests can write
    fixture files and inspect the database directly.
    """
    memory_dir = tmp_path / "memory"
    memory_dir.mkdir()
    sqlite_path = memory_dir / "echo.sqlite"

    monkeypatch.setattr("src.memory_search.DB_PATH", sqlite_path)
    monkeypatch.setattr("src.memory_search.MEMORY_DIR", memory_dir)

    return {"mem_dir": memory_dir, "db_path": sqlite_path}
|
|
|
|
|
|
def _write_md(mem_dir: Path, name: str, content: str) -> Path:
    """Create *name* under *mem_dir* with *content* (UTF-8) and return its path."""
    target = mem_dir / name
    target.write_text(content, encoding="utf-8")
    return target
|
|
|
|
|
|
def _args(**kwargs):
    """Build an argparse.Namespace carrying the supplied attributes."""
    namespace = argparse.Namespace()
    for attr, value in kwargs.items():
        setattr(namespace, attr, value)
    return namespace
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# chunk_file
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestChunkFile:
    def test_normal_md_file(self, tmp_path):
        """Paragraphs separated by blank lines become separate chunks."""
        md = tmp_path / "notes.md"
        md.write_text("# Header\n\nParagraph one.\n\nParagraph two.\n")
        pieces = chunk_file(md)
        assert len(pieces) >= 1
        # Every piece of content must survive chunking.
        combined = "\n\n".join(pieces)
        for fragment in ("Header", "Paragraph one", "Paragraph two"):
            assert fragment in combined

    def test_empty_file(self, tmp_path):
        """An empty file yields no chunks."""
        md = tmp_path / "empty.md"
        md.write_text("")
        assert chunk_file(md) == []

    def test_whitespace_only_file(self, tmp_path):
        """Whitespace-only content yields no chunks."""
        md = tmp_path / "blank.md"
        md.write_text(" \n\n \n")
        assert chunk_file(md) == []

    def test_single_long_paragraph(self, tmp_path):
        """A paragraph exceeding _CHUNK_MAX is still chunked without losing words."""
        md = tmp_path / "long.md"
        md.write_text("word " * 300)  # ~1500 chars, well over _CHUNK_MAX
        pieces = chunk_file(md)
        assert len(pieces) >= 1
        assert "word" in " ".join(pieces)

    def test_small_paragraphs_merged(self, tmp_path):
        """Paragraphs below _CHUNK_MIN are coalesced rather than kept as tiny chunks."""
        md = tmp_path / "small.md"
        md.write_text("\n\n".join(["Hi." for _ in range(5)]))
        assert len(chunk_file(md)) < 5

    def test_chunks_within_size_limits(self, tmp_path):
        """Chunks stay near the target size, with slack for merge behavior."""
        body = "\n\n".join(f"Paragraph {i}. " + "x" * 200 for i in range(20))
        md = tmp_path / "medium.md"
        md.write_text(body)
        pieces = chunk_file(md)
        assert len(pieces) > 1
        for piece in pieces:
            # A tiny trailing chunk may get merged in, so allow up to ~2x max.
            assert len(piece) < _CHUNK_MAX * 2 + 200

    def test_header_splits(self, tmp_path):
        """Markdown headers act as chunk boundaries without dropping content."""
        md = tmp_path / "headers.md"
        md.write_text(
            "Intro paragraph.\n\n# Section 1\n\nContent one.\n\n# Section 2\n\nContent two.\n"
        )
        pieces = chunk_file(md)
        assert len(pieces) >= 1
        combined = "\n\n".join(pieces)
        assert "Section 1" in combined
        assert "Section 2" in combined
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# serialize_embedding / deserialize_embedding
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestEmbeddingSerialization:
    def test_serialize_returns_bytes(self):
        """Serialization packs each value as a 4-byte float32."""
        values = [1.0, 2.0, 3.0]
        packed = serialize_embedding(values)
        assert isinstance(packed, bytes)
        assert len(packed) == 4 * len(values)  # float32 = 4 bytes

    def test_round_trip(self):
        """Values survive a serialize/deserialize cycle within float32 precision."""
        values = [0.1, -0.5, 3.14, 0.0, -1.0]
        restored = deserialize_embedding(serialize_embedding(values))
        assert len(restored) == len(values)
        assert all(abs(a - b) < 1e-5 for a, b in zip(values, restored))

    def test_round_trip_384_dim(self):
        """Round-trip a full EMBEDDING_DIM-sized vector."""
        values = [i / 384 for i in range(EMBEDDING_DIM)]
        packed = serialize_embedding(values)
        assert len(packed) == EMBEDDING_DIM * 4
        restored = deserialize_embedding(packed)
        assert len(restored) == EMBEDDING_DIM
        assert all(abs(a - b) < 1e-5 for a, b in zip(values, restored))

    def test_known_values(self):
        """Packed bytes exactly match struct.pack output for known floats."""
        packed = serialize_embedding([1.0, -1.0, 0.0])
        reference = struct.pack("3f", 1.0, -1.0, 0.0)
        assert packed == reference
        assert deserialize_embedding(reference) == [1.0, -1.0, 0.0]

    def test_empty_vector(self):
        """An empty vector round-trips through empty bytes."""
        assert serialize_embedding([]) == b""
        assert deserialize_embedding(b"") == []
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# cosine_similarity
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCosineSimilarity:
    def test_identical_vectors(self):
        """A vector compared with itself scores exactly 1."""
        vec = [1.0, 2.0, 3.0]
        assert abs(cosine_similarity(vec, vec) - 1.0) < 1e-9

    def test_orthogonal_vectors(self):
        """Perpendicular vectors score 0."""
        assert abs(cosine_similarity([1.0, 0.0], [0.0, 1.0])) < 1e-9

    def test_opposite_vectors(self):
        """Anti-parallel vectors score -1."""
        score = cosine_similarity([1.0, 2.0, 3.0], [-1.0, -2.0, -3.0])
        assert abs(score - (-1.0)) < 1e-9

    def test_known_similarity(self):
        """Vectors 45 degrees apart score cos(45°) = 1/sqrt(2)."""
        score = cosine_similarity([1.0, 0.0, 0.0], [1.0, 1.0, 0.0])
        assert abs(score - 1.0 / math.sqrt(2)) < 1e-9

    def test_zero_vector_returns_zero(self):
        """A zero vector yields 0.0 rather than NaN from division by zero."""
        assert cosine_similarity([0.0, 0.0, 0.0], [1.0, 2.0, 3.0]) == 0.0

    def test_both_zero_vectors(self):
        """Two zero vectors also yield 0.0."""
        assert cosine_similarity([0.0, 0.0], [0.0, 0.0]) == 0.0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# get_embedding
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestGetEmbedding:
    @patch("src.memory_search.httpx.post")
    def test_returns_embedding(self, mock_post):
        """A successful Ollama response yields the embedding vector."""
        mock_post.return_value = _fake_ollama_response()
        vector = get_embedding("hello")
        assert vector == FAKE_EMBEDDING
        assert len(vector) == EMBEDDING_DIM

    @patch("src.memory_search.httpx.post")
    def test_raises_on_connect_error(self, mock_post):
        """A refused connection surfaces as ConnectionError with a clear message."""
        import httpx

        mock_post.side_effect = httpx.ConnectError("connection refused")
        with pytest.raises(ConnectionError, match="Cannot connect to Ollama"):
            get_embedding("hello")

    @patch("src.memory_search.httpx.post")
    def test_raises_on_http_error(self, mock_post):
        """An HTTP error status from Ollama surfaces as ConnectionError."""
        import httpx

        response = MagicMock()
        response.status_code = 500
        response.raise_for_status.side_effect = httpx.HTTPStatusError(
            "server error", request=MagicMock(), response=response
        )
        response.json.return_value = {}
        mock_post.return_value = response
        with pytest.raises(ConnectionError, match="Ollama API error"):
            get_embedding("hello")

    @patch("src.memory_search.httpx.post")
    def test_raises_on_wrong_dimension(self, mock_post):
        """A vector of the wrong length raises ValueError."""
        mock_post.return_value = _fake_ollama_response(embedding=[0.1] * 10)
        with pytest.raises(ValueError, match="Expected 384"):
            get_embedding("hello")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# get_db (SQLite)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestGetDb:
    def test_creates_table(self, mem_iso):
        """get_db creates the 'chunks' table on first use."""
        conn = get_db()
        try:
            row = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name='chunks'"
            ).fetchone()
            assert row is not None
        finally:
            conn.close()

    def test_creates_index(self, mem_iso):
        """get_db creates the idx_file_path index."""
        conn = get_db()
        try:
            row = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='index' AND name='idx_file_path'"
            ).fetchone()
            assert row is not None
        finally:
            conn.close()

    def test_idempotent(self, mem_iso):
        """Repeated get_db calls succeed (CREATE ... IF NOT EXISTS semantics)."""
        for _ in range(2):
            get_db().close()

    def test_creates_parent_dir(self, tmp_path, monkeypatch):
        """Missing parent directories for the DB file are created on demand."""
        db_path = tmp_path / "deep" / "nested" / "echo.sqlite"
        monkeypatch.setattr("src.memory_search.DB_PATH", db_path)
        get_db().close()
        assert db_path.exists()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# index_file
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestIndexFile:
    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_index_stores_chunks(self, mock_emb, mem_iso):
        """Indexing a file writes one row per chunk, keyed by relative path."""
        md = _write_md(mem_iso["mem_dir"], "notes.md", "# Title\n\nSome content here.\n")
        count = index_file(md)
        assert count >= 1

        conn = get_db()
        try:
            rows = conn.execute("SELECT file_path, chunk_text FROM chunks").fetchall()
            assert len(rows) == count
            assert rows[0][0] == "notes.md"
            assert "Title" in rows[0][1] or "content" in rows[0][1]
        finally:
            conn.close()

    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_index_empty_file(self, mock_emb, mem_iso):
        """An empty file indexes zero chunks and never calls the embedder."""
        md = _write_md(mem_iso["mem_dir"], "empty.md", "")
        assert index_file(md) == 0
        mock_emb.assert_not_called()

    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_reindex_replaces_old_chunks(self, mock_emb, mem_iso):
        """Re-running index_file on a changed file drops the stale rows."""
        md = _write_md(mem_iso["mem_dir"], "test.md", "First version.\n")
        index_file(md)

        # Rewrite the file, then index again.
        md.write_text("Second version with more content.\n\nAnother paragraph.\n")
        count = index_file(md)

        conn = get_db()
        try:
            rows = conn.execute(
                "SELECT chunk_text FROM chunks WHERE file_path = ?", ("test.md",)
            ).fetchall()
            assert len(rows) == count
            # Only the updated content should remain.
            assert "Second version" in " ".join(r[0] for r in rows)
        finally:
            conn.close()

    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_stores_embedding_blob(self, mock_emb, mem_iso):
        """The stored BLOB deserializes back to an EMBEDDING_DIM-long vector."""
        md = _write_md(mem_iso["mem_dir"], "test.md", "Some text.\n")
        index_file(md)

        conn = get_db()
        try:
            row = conn.execute("SELECT embedding FROM chunks").fetchone()
            assert row is not None
            assert len(deserialize_embedding(row[0])) == EMBEDDING_DIM
        finally:
            conn.close()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# reindex
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestReindex:
    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_reindex_all_files(self, mock_emb, mem_iso):
        """Every .md file in the memory dir gets indexed."""
        _write_md(mem_iso["mem_dir"], "a.md", "File A content.\n")
        _write_md(mem_iso["mem_dir"], "b.md", "File B content.\n")

        stats = reindex()
        assert stats["files"] == 2
        assert stats["chunks"] >= 2

    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_reindex_clears_old_data(self, mock_emb, mem_iso):
        """Reindex drops chunks belonging to files that no longer exist."""
        stale = _write_md(mem_iso["mem_dir"], "old.md", "Old content.\n")
        index_file(stale)

        # Delete the indexed file; a full reindex must purge its chunks.
        stale.unlink()
        _write_md(mem_iso["mem_dir"], "new.md", "New content.\n")
        reindex()

        conn = get_db()
        try:
            files = [
                row[0]
                for row in conn.execute("SELECT DISTINCT file_path FROM chunks")
            ]
            assert "old.md" not in files
            assert "new.md" in files
        finally:
            conn.close()

    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_reindex_empty_dir(self, mock_emb, mem_iso):
        """An empty memory dir reports zero files and zero chunks."""
        assert reindex() == {"files": 0, "chunks": 0}

    @patch("src.memory_search.get_embedding", return_value=FAKE_EMBEDDING)
    def test_reindex_includes_subdirs(self, mock_emb, mem_iso):
        """rglob picks up .md files nested in subdirectories."""
        nested = mem_iso["mem_dir"] / "kb"
        nested.mkdir()
        _write_md(nested, "deep.md", "Deep content.\n")
        _write_md(mem_iso["mem_dir"], "top.md", "Top content.\n")

        assert reindex()["files"] == 2

    @patch("src.memory_search.get_embedding", side_effect=ConnectionError("offline"))
    def test_reindex_handles_embedding_failure(self, mock_emb, mem_iso):
        """A file whose embedding fails is skipped without crashing reindex."""
        _write_md(mem_iso["mem_dir"], "fail.md", "Content.\n")
        stats = reindex()
        # The failed file contributes nothing to the totals.
        assert stats["files"] == 0
        assert stats["chunks"] == 0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# search
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSearch:
    def _seed_db(self, mem_iso, entries):
        """Insert rows directly into the chunks table.

        entries: list of (file_path, chunk_text, embedding_list)
        """
        conn = get_db()
        for idx, (path, text, vector) in enumerate(entries):
            conn.execute(
                "INSERT INTO chunks (file_path, chunk_index, chunk_text, embedding, updated_at) VALUES (?, ?, ?, ?, ?)",
                (path, idx, text, serialize_embedding(vector), "2025-01-01T00:00:00"),
            )
        conn.commit()
        conn.close()

    @staticmethod
    def _vec(*head):
        """Pad a few leading components out to EMBEDDING_DIM with zeros."""
        return list(head) + [0.0] * (EMBEDDING_DIM - len(head))

    @patch("src.memory_search.get_embedding")
    def test_search_returns_sorted(self, mock_emb, mem_iso):
        """Results come back ordered by descending similarity score."""
        mock_emb.return_value = self._vec(1.0, 0.0, 0.0)
        self._seed_db(mem_iso, [
            ("close.md", "close content", self._vec(0.9, 0.1, 0.0)),
            ("far.md", "far content", self._vec(0.0, 1.0, 0.0)),
        ])

        results = search("test query")
        assert [r["file"] for r in results] == ["close.md", "far.md"]
        assert results[0]["score"] > results[1]["score"]

    @patch("src.memory_search.get_embedding")
    def test_search_top_k(self, mock_emb, mem_iso):
        """top_k caps the number of results returned."""
        mock_emb.return_value = self._vec(1.0)
        self._seed_db(mem_iso, [
            (f"file{i}.md", f"content {i}", self._vec(float(i) / 10))
            for i in range(10)
        ])

        assert len(search("test", top_k=3)) == 3

    @patch("src.memory_search.get_embedding")
    def test_search_empty_index(self, mock_emb, mem_iso):
        """An empty index yields an empty result list."""
        mock_emb.return_value = FAKE_EMBEDDING
        # Create the database, but leave it empty.
        get_db().close()

        assert search("anything") == []

    @patch("src.memory_search.get_embedding")
    def test_search_result_structure(self, mock_emb, mem_iso):
        """Each hit is a dict carrying file, chunk, and a float score."""
        query_vec = self._vec(1.0)
        mock_emb.return_value = query_vec
        self._seed_db(mem_iso, [("test.md", "test content", query_vec)])

        results = search("query")
        assert len(results) == 1
        hit = results[0]
        assert "file" in hit
        assert "chunk" in hit
        assert "score" in hit
        assert hit["file"] == "test.md"
        assert hit["chunk"] == "test content"
        assert isinstance(hit["score"], float)

    @patch("src.memory_search.get_embedding")
    def test_search_default_top_k(self, mock_emb, mem_iso):
        """Without an explicit top_k, search returns at most 5 results."""
        mock_emb.return_value = self._vec(1.0)
        self._seed_db(mem_iso, [
            (f"file{i}.md", f"content {i}", self._vec(float(i) / 20))
            for i in range(10)
        ])

        assert len(search("test")) == 5
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CLI commands: memory search, memory reindex
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCliMemorySearch:
    @patch("src.memory_search.search")
    def test_memory_search_shows_results(self, mock_search, capsys):
        """Results print with file names, 3-decimal scores, and numbering."""
        import cli

        mock_search.return_value = [
            {"file": "notes.md", "chunk": "Some relevant content here", "score": 0.85},
            {"file": "kb/info.md", "chunk": "Another result", "score": 0.72},
        ]
        cli._memory_search("test query")

        out = capsys.readouterr().out
        for expected in ("notes.md", "0.850", "kb/info.md", "0.720", "Result 1", "Result 2"):
            assert expected in out

    @patch("src.memory_search.search")
    def test_memory_search_empty_results(self, mock_search, capsys):
        """No hits prints a 'No results' message with a reindex hint."""
        import cli

        mock_search.return_value = []
        cli._memory_search("nothing")

        out = capsys.readouterr().out
        assert "No results" in out
        assert "reindex" in out

    @patch("src.memory_search.search", side_effect=ConnectionError("Ollama offline"))
    def test_memory_search_connection_error(self, mock_search, capsys):
        """A ConnectionError message is printed and SystemExit is raised."""
        import cli

        with pytest.raises(SystemExit):
            cli._memory_search("test")
        assert "Ollama offline" in capsys.readouterr().out

    @patch("src.memory_search.search")
    def test_memory_search_truncates_long_chunks(self, mock_search, capsys):
        """Long chunk previews are elided with '...' (truncated at 200 chars)."""
        import cli

        mock_search.return_value = [
            {"file": "test.md", "chunk": "x" * 500, "score": 0.9},
        ]
        cli._memory_search("query")

        assert "..." in capsys.readouterr().out
|
|
|
|
|
|
class TestCliMemoryReindex:
    @patch("src.memory_search.reindex")
    def test_memory_reindex_shows_stats(self, mock_reindex, capsys):
        """Reindex stats are echoed as '<n> files' / '<n> chunks'."""
        import cli

        mock_reindex.return_value = {"files": 5, "chunks": 23}
        cli._memory_reindex()

        out = capsys.readouterr().out
        assert "5 files" in out
        assert "23 chunks" in out

    @patch("src.memory_search.reindex", side_effect=ConnectionError("Ollama down"))
    def test_memory_reindex_connection_error(self, mock_reindex, capsys):
        """A ConnectionError message is printed and SystemExit is raised."""
        import cli

        with pytest.raises(SystemExit):
            cli._memory_reindex()
        assert "Ollama down" in capsys.readouterr().out
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Discord /search command
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestDiscordSearchCommand:
    def _find_command(self, tree, name):
        """Return the app command named *name* from the tree, or None."""
        return next((cmd for cmd in tree.get_commands() if cmd.name == name), None)

    def _mock_interaction(self, user_id="123", channel_id="456"):
        """Build an AsyncMock Discord interaction with response/followup mocks."""
        interaction = AsyncMock()
        interaction.user = MagicMock()
        interaction.user.id = int(user_id)
        interaction.channel_id = int(channel_id)
        interaction.response = AsyncMock()
        interaction.response.defer = AsyncMock()
        interaction.followup = AsyncMock()
        interaction.followup.send = AsyncMock()
        return interaction

    @pytest.fixture
    def search_bot(self, tmp_path):
        """Create a bot wired to a minimal JSON config for /search tests."""
        import json
        from src.config import Config
        from src.adapters.discord_bot import create_bot

        config_data = {
            "bot": {"name": "Echo", "default_model": "sonnet", "owner": "111", "admins": []},
            "channels": {},
        }
        config_file = tmp_path / "config.json"
        config_file.write_text(json.dumps(config_data, indent=2))
        return create_bot(Config(config_file))

    @pytest.mark.asyncio
    @patch("src.memory_search.search")
    async def test_search_command_exists(self, mock_search, search_bot):
        """/search is registered on the command tree."""
        assert self._find_command(search_bot.tree, "search") is not None

    @pytest.mark.asyncio
    @patch("src.memory_search.search")
    async def test_search_command_with_results(self, mock_search, search_bot):
        """Hits are deferred then sent via followup, naming file and score."""
        mock_search.return_value = [
            {"file": "notes.md", "chunk": "Relevant content", "score": 0.9},
        ]
        command = self._find_command(search_bot.tree, "search")
        interaction = self._mock_interaction()
        await command.callback(interaction, query="test query")

        interaction.response.defer.assert_awaited_once()
        interaction.followup.send.assert_awaited_once()
        message = interaction.followup.send.call_args.args[0]
        assert "notes.md" in message
        assert "0.9" in message

    @pytest.mark.asyncio
    @patch("src.memory_search.search")
    async def test_search_command_empty_results(self, mock_search, search_bot):
        """No hits produces a 'no results' followup message."""
        mock_search.return_value = []
        command = self._find_command(search_bot.tree, "search")
        interaction = self._mock_interaction()
        await command.callback(interaction, query="nothing")

        assert "no results" in interaction.followup.send.call_args.args[0].lower()

    @pytest.mark.asyncio
    @patch("src.memory_search.search", side_effect=ConnectionError("Ollama offline"))
    async def test_search_command_connection_error(self, mock_search, search_bot):
        """A ConnectionError surfaces as an error message, not a crash."""
        command = self._find_command(search_bot.tree, "search")
        interaction = self._mock_interaction()
        await command.callback(interaction, query="test")

        assert "error" in interaction.followup.send.call_args.args[0].lower()
|