stage-3: claude CLI wrapper with session management
Subprocess wrapper for Claude CLI with start/resume/clear sessions, personality system prompt, atomic session tracking. 38 new tests (89 total). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
276
src/claude_session.py
Normal file
276
src/claude_session.py
Normal file
@@ -0,0 +1,276 @@
|
||||
"""
|
||||
Claude CLI session manager for Echo-Core.
|
||||
|
||||
Wraps the Claude Code CLI to provide conversation management.
|
||||
Each Discord channel maps to at most one active Claude session,
|
||||
tracked in sessions/active.json.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants & configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
PERSONALITY_DIR = PROJECT_ROOT / "personality"
|
||||
SESSIONS_DIR = PROJECT_ROOT / "sessions"
|
||||
_SESSIONS_FILE = SESSIONS_DIR / "active.json"
|
||||
|
||||
VALID_MODELS = {"haiku", "sonnet", "opus"}
|
||||
DEFAULT_MODEL = "sonnet"
|
||||
DEFAULT_TIMEOUT = 120 # seconds
|
||||
|
||||
CLAUDE_BIN = os.environ.get("CLAUDE_BIN", "claude")
|
||||
|
||||
PERSONALITY_FILES = [
|
||||
"IDENTITY.md",
|
||||
"SOUL.md",
|
||||
"USER.md",
|
||||
"AGENTS.md",
|
||||
"TOOLS.md",
|
||||
"HEARTBEAT.md",
|
||||
]
|
||||
|
||||
# Environment variables allowed through to the Claude subprocess
|
||||
_ENV_PASSTHROUGH = {
|
||||
"PATH", "HOME", "USER", "LANG", "LC_ALL", "TERM",
|
||||
"SHELL", "TMPDIR", "XDG_CONFIG_HOME", "XDG_DATA_HOME",
|
||||
"CLAUDE_CONFIG_DIR",
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level binary check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
if not shutil.which(CLAUDE_BIN):
|
||||
logger.warning(
|
||||
"Claude CLI (%s) not found on PATH. "
|
||||
"Session functions will raise FileNotFoundError.",
|
||||
CLAUDE_BIN,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _safe_env() -> dict[str, str]:
|
||||
"""Return a filtered copy of os.environ for subprocess calls."""
|
||||
return {k: v for k, v in os.environ.items() if k in _ENV_PASSTHROUGH}
|
||||
|
||||
|
||||
def _load_sessions() -> dict:
|
||||
"""Load sessions from active.json. Returns {} if missing or empty."""
|
||||
try:
|
||||
text = _SESSIONS_FILE.read_text(encoding="utf-8")
|
||||
if not text.strip():
|
||||
return {}
|
||||
return json.loads(text)
|
||||
except (FileNotFoundError, json.JSONDecodeError):
|
||||
return {}
|
||||
|
||||
|
||||
def _save_sessions(data: dict) -> None:
|
||||
"""Atomically write sessions to active.json via tempfile + os.replace."""
|
||||
SESSIONS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
fd, tmp_path = tempfile.mkstemp(
|
||||
dir=SESSIONS_DIR, prefix=".active_", suffix=".json"
|
||||
)
|
||||
try:
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
f.write("\n")
|
||||
os.replace(tmp_path, _SESSIONS_FILE)
|
||||
except BaseException:
|
||||
# Clean up temp file on any failure
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
|
||||
def _run_claude(cmd: list[str], timeout: int) -> dict:
|
||||
"""Run a Claude CLI command and return the parsed JSON output."""
|
||||
if not shutil.which(CLAUDE_BIN):
|
||||
raise FileNotFoundError(
|
||||
"Claude CLI not found. "
|
||||
"Install: https://docs.anthropic.com/en/docs/claude-code"
|
||||
)
|
||||
|
||||
try:
|
||||
proc = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
env=_safe_env(),
|
||||
cwd=PROJECT_ROOT,
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
raise TimeoutError(f"Claude CLI timed out after {timeout}s")
|
||||
|
||||
if proc.returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"Claude CLI error (exit {proc.returncode}): "
|
||||
f"{proc.stderr[:500]}"
|
||||
)
|
||||
|
||||
try:
|
||||
data = json.loads(proc.stdout)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise RuntimeError(f"Failed to parse Claude CLI output: {exc}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def build_system_prompt() -> str:
|
||||
"""Concatenate personality/*.md files into a single system prompt."""
|
||||
if not PERSONALITY_DIR.is_dir():
|
||||
raise FileNotFoundError(
|
||||
f"Personality directory not found: {PERSONALITY_DIR}"
|
||||
)
|
||||
|
||||
parts: list[str] = []
|
||||
for filename in PERSONALITY_FILES:
|
||||
filepath = PERSONALITY_DIR / filename
|
||||
if filepath.is_file():
|
||||
parts.append(filepath.read_text(encoding="utf-8"))
|
||||
|
||||
return "\n\n---\n\n".join(parts)
|
||||
|
||||
|
||||
def start_session(
|
||||
channel_id: str,
|
||||
message: str,
|
||||
model: str = DEFAULT_MODEL,
|
||||
timeout: int = DEFAULT_TIMEOUT,
|
||||
) -> tuple[str, str]:
|
||||
"""Start a new Claude CLI session for a channel.
|
||||
|
||||
Returns (response_text, session_id).
|
||||
"""
|
||||
if model not in VALID_MODELS:
|
||||
raise ValueError(
|
||||
f"Invalid model '{model}'. Must be one of: haiku, sonnet, opus"
|
||||
)
|
||||
|
||||
system_prompt = build_system_prompt()
|
||||
|
||||
cmd = [
|
||||
CLAUDE_BIN, "-p", message,
|
||||
"--model", model,
|
||||
"--output-format", "json",
|
||||
"--system-prompt", system_prompt,
|
||||
]
|
||||
|
||||
data = _run_claude(cmd, timeout)
|
||||
|
||||
for field in ("result", "session_id"):
|
||||
if field not in data:
|
||||
raise RuntimeError(
|
||||
f"Claude CLI response missing required field: {field}"
|
||||
)
|
||||
|
||||
response_text = data["result"]
|
||||
session_id = data["session_id"]
|
||||
|
||||
# Save session metadata
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
sessions = _load_sessions()
|
||||
sessions[channel_id] = {
|
||||
"session_id": session_id,
|
||||
"model": model,
|
||||
"created_at": now,
|
||||
"last_message_at": now,
|
||||
"message_count": 1,
|
||||
}
|
||||
_save_sessions(sessions)
|
||||
|
||||
return response_text, session_id
|
||||
|
||||
|
||||
def resume_session(
|
||||
session_id: str,
|
||||
message: str,
|
||||
timeout: int = DEFAULT_TIMEOUT,
|
||||
) -> str:
|
||||
"""Resume an existing Claude session by ID. Returns response text."""
|
||||
cmd = [
|
||||
CLAUDE_BIN, "-p", message,
|
||||
"--resume", session_id,
|
||||
"--output-format", "json",
|
||||
]
|
||||
|
||||
data = _run_claude(cmd, timeout)
|
||||
|
||||
if "result" not in data:
|
||||
raise RuntimeError(
|
||||
"Claude CLI response missing required field: result"
|
||||
)
|
||||
|
||||
response_text = data["result"]
|
||||
|
||||
# Update session metadata
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
sessions = _load_sessions()
|
||||
for channel_id, session in sessions.items():
|
||||
if session.get("session_id") == session_id:
|
||||
session["last_message_at"] = now
|
||||
session["message_count"] = session.get("message_count", 0) + 1
|
||||
break
|
||||
_save_sessions(sessions)
|
||||
|
||||
return response_text
|
||||
|
||||
|
||||
def send_message(
|
||||
channel_id: str,
|
||||
message: str,
|
||||
model: str = DEFAULT_MODEL,
|
||||
timeout: int = DEFAULT_TIMEOUT,
|
||||
) -> str:
|
||||
"""High-level convenience: auto start or resume based on channel state."""
|
||||
session = get_active_session(channel_id)
|
||||
if session is not None:
|
||||
return resume_session(session["session_id"], message, timeout)
|
||||
response_text, _session_id = start_session(
|
||||
channel_id, message, model, timeout
|
||||
)
|
||||
return response_text
|
||||
|
||||
|
||||
def clear_session(channel_id: str) -> bool:
|
||||
"""Remove a channel's session entry. Returns True if removed."""
|
||||
sessions = _load_sessions()
|
||||
if channel_id not in sessions:
|
||||
return False
|
||||
del sessions[channel_id]
|
||||
_save_sessions(sessions)
|
||||
return True
|
||||
|
||||
|
||||
def get_active_session(channel_id: str) -> dict | None:
|
||||
"""Return session metadata for a channel, or None."""
|
||||
sessions = _load_sessions()
|
||||
return sessions.get(channel_id)
|
||||
|
||||
|
||||
def list_sessions() -> dict:
|
||||
"""Return all channel→session mappings from active.json."""
|
||||
return _load_sessions()
|
||||
576
tests/test_claude_session.py
Normal file
576
tests/test_claude_session.py
Normal file
@@ -0,0 +1,576 @@
|
||||
"""Comprehensive tests for src/claude_session.py."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src import claude_session
|
||||
from src.claude_session import (
|
||||
_load_sessions,
|
||||
_run_claude,
|
||||
_safe_env,
|
||||
_save_sessions,
|
||||
build_system_prompt,
|
||||
clear_session,
|
||||
get_active_session,
|
||||
list_sessions,
|
||||
resume_session,
|
||||
send_message,
|
||||
start_session,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
FAKE_CLI_RESPONSE = {
|
||||
"type": "result",
|
||||
"subtype": "success",
|
||||
"session_id": "sess-abc-123",
|
||||
"result": "Hello from Claude!",
|
||||
"cost_usd": 0.004,
|
||||
"duration_ms": 1500,
|
||||
"num_turns": 1,
|
||||
}
|
||||
|
||||
|
||||
def _make_proc(stdout=None, returncode=0, stderr=""):
|
||||
"""Build a fake subprocess.CompletedProcess."""
|
||||
if stdout is None:
|
||||
stdout = json.dumps(FAKE_CLI_RESPONSE)
|
||||
proc = MagicMock(spec=subprocess.CompletedProcess)
|
||||
proc.stdout = stdout
|
||||
proc.stderr = stderr
|
||||
proc.returncode = returncode
|
||||
return proc
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_system_prompt
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildSystemPrompt:
|
||||
def test_returns_non_empty_string(self):
|
||||
prompt = build_system_prompt()
|
||||
assert isinstance(prompt, str)
|
||||
assert len(prompt) > 0
|
||||
|
||||
def test_contains_personality_content(self):
|
||||
prompt = build_system_prompt()
|
||||
# IDENTITY.md is first; should appear in the prompt
|
||||
identity_path = claude_session.PERSONALITY_DIR / "IDENTITY.md"
|
||||
identity_content = identity_path.read_text(encoding="utf-8")
|
||||
assert identity_content in prompt
|
||||
|
||||
def test_correct_order(self):
|
||||
prompt = build_system_prompt()
|
||||
expected_order = [
|
||||
"IDENTITY.md",
|
||||
"SOUL.md",
|
||||
"USER.md",
|
||||
"AGENTS.md",
|
||||
"TOOLS.md",
|
||||
"HEARTBEAT.md",
|
||||
]
|
||||
# Read each file and find its position in the prompt
|
||||
positions = []
|
||||
for filename in expected_order:
|
||||
filepath = claude_session.PERSONALITY_DIR / filename
|
||||
if filepath.is_file():
|
||||
content = filepath.read_text(encoding="utf-8")
|
||||
pos = prompt.find(content)
|
||||
assert pos >= 0, f"{filename} not found in prompt"
|
||||
positions.append(pos)
|
||||
|
||||
# Positions must be strictly increasing
|
||||
for i in range(1, len(positions)):
|
||||
assert positions[i] > positions[i - 1], (
|
||||
f"Order violation: {expected_order[i]} appears before "
|
||||
f"{expected_order[i - 1]}"
|
||||
)
|
||||
|
||||
def test_separator_between_files(self):
|
||||
prompt = build_system_prompt()
|
||||
assert "\n\n---\n\n" in prompt
|
||||
|
||||
def test_missing_personality_dir_raises(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
claude_session, "PERSONALITY_DIR", tmp_path / "nonexistent"
|
||||
)
|
||||
with pytest.raises(FileNotFoundError):
|
||||
build_system_prompt()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _safe_env
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSafeEnv:
|
||||
def test_includes_path_and_home(self, monkeypatch):
|
||||
monkeypatch.setenv("PATH", "/usr/bin")
|
||||
monkeypatch.setenv("HOME", "/home/test")
|
||||
env = _safe_env()
|
||||
assert "PATH" in env
|
||||
assert "HOME" in env
|
||||
|
||||
def test_excludes_discord_token(self, monkeypatch):
|
||||
monkeypatch.setenv("DISCORD_TOKEN", "secret-token")
|
||||
env = _safe_env()
|
||||
assert "DISCORD_TOKEN" not in env
|
||||
|
||||
def test_excludes_claudecode(self, monkeypatch):
|
||||
monkeypatch.setenv("CLAUDECODE", "1")
|
||||
env = _safe_env()
|
||||
assert "CLAUDECODE" not in env
|
||||
|
||||
def test_excludes_api_keys(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "sk-xxx")
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-xxx")
|
||||
env = _safe_env()
|
||||
assert "OPENAI_API_KEY" not in env
|
||||
assert "ANTHROPIC_API_KEY" not in env
|
||||
|
||||
def test_only_passthrough_keys(self, monkeypatch):
|
||||
monkeypatch.setenv("RANDOM_SECRET", "bad")
|
||||
env = _safe_env()
|
||||
for key in env:
|
||||
assert key in claude_session._ENV_PASSTHROUGH
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run_claude
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunClaude:
|
||||
@patch("shutil.which", return_value="/usr/bin/claude")
|
||||
@patch("subprocess.run")
|
||||
def test_returns_parsed_json(self, mock_run, mock_which):
|
||||
mock_run.return_value = _make_proc()
|
||||
result = _run_claude(["claude", "-p", "hi"], timeout=30)
|
||||
assert result == FAKE_CLI_RESPONSE
|
||||
|
||||
@patch("shutil.which", return_value="/usr/bin/claude")
|
||||
@patch("subprocess.run")
|
||||
def test_timeout_raises(self, mock_run, mock_which):
|
||||
mock_run.side_effect = subprocess.TimeoutExpired(cmd="claude", timeout=30)
|
||||
with pytest.raises(TimeoutError, match="timed out after 30s"):
|
||||
_run_claude(["claude", "-p", "hi"], timeout=30)
|
||||
|
||||
@patch("shutil.which", return_value="/usr/bin/claude")
|
||||
@patch("subprocess.run")
|
||||
def test_nonzero_exit_raises(self, mock_run, mock_which):
|
||||
mock_run.return_value = _make_proc(
|
||||
stdout="", returncode=1, stderr="something went wrong"
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="exit 1"):
|
||||
_run_claude(["claude", "-p", "hi"], timeout=30)
|
||||
|
||||
@patch("shutil.which", return_value="/usr/bin/claude")
|
||||
@patch("subprocess.run")
|
||||
def test_invalid_json_raises(self, mock_run, mock_which):
|
||||
mock_run.return_value = _make_proc(stdout="not json {{{")
|
||||
with pytest.raises(RuntimeError, match="Failed to parse"):
|
||||
_run_claude(["claude", "-p", "hi"], timeout=30)
|
||||
|
||||
@patch("shutil.which", return_value=None)
|
||||
def test_missing_binary_raises(self, mock_which):
|
||||
with pytest.raises(FileNotFoundError, match="Claude CLI not found"):
|
||||
_run_claude(["claude", "-p", "hi"], timeout=30)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session file helpers (_load_sessions / _save_sessions)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSessionFileOps:
|
||||
def test_load_missing_file_returns_empty(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
claude_session, "_SESSIONS_FILE", tmp_path / "no_such.json"
|
||||
)
|
||||
assert _load_sessions() == {}
|
||||
|
||||
def test_load_empty_file_returns_empty(self, tmp_path, monkeypatch):
|
||||
f = tmp_path / "active.json"
|
||||
f.write_text("")
|
||||
monkeypatch.setattr(claude_session, "_SESSIONS_FILE", f)
|
||||
assert _load_sessions() == {}
|
||||
|
||||
def test_save_and_load_roundtrip(self, tmp_path, monkeypatch):
|
||||
sessions_dir = tmp_path / "sessions"
|
||||
sessions_dir.mkdir()
|
||||
sf = sessions_dir / "active.json"
|
||||
monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir)
|
||||
monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf)
|
||||
|
||||
data = {"general": {"session_id": "abc", "message_count": 1}}
|
||||
_save_sessions(data)
|
||||
loaded = _load_sessions()
|
||||
assert loaded == data
|
||||
|
||||
def test_save_creates_directory(self, tmp_path, monkeypatch):
|
||||
sessions_dir = tmp_path / "deep" / "sessions"
|
||||
sf = sessions_dir / "active.json"
|
||||
monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir)
|
||||
monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf)
|
||||
|
||||
_save_sessions({"test": "data"})
|
||||
assert sf.exists()
|
||||
assert json.loads(sf.read_text()) == {"test": "data"}
|
||||
|
||||
def test_atomic_write_no_corruption_on_error(self, tmp_path, monkeypatch):
|
||||
sessions_dir = tmp_path / "sessions"
|
||||
sessions_dir.mkdir()
|
||||
sf = sessions_dir / "active.json"
|
||||
monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir)
|
||||
monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf)
|
||||
|
||||
# Write initial good data
|
||||
_save_sessions({"good": "data"})
|
||||
|
||||
# Attempt to write non-serializable data → should raise
|
||||
class Unserializable:
|
||||
pass
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
_save_sessions({"bad": Unserializable()})
|
||||
|
||||
# Original file should still be intact
|
||||
assert json.loads(sf.read_text()) == {"good": "data"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# start_session
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStartSession:
|
||||
@patch("shutil.which", return_value="/usr/bin/claude")
|
||||
@patch("subprocess.run")
|
||||
def test_returns_response_and_session_id(
|
||||
self, mock_run, mock_which, tmp_path, monkeypatch
|
||||
):
|
||||
sessions_dir = tmp_path / "sessions"
|
||||
sessions_dir.mkdir()
|
||||
monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir)
|
||||
monkeypatch.setattr(
|
||||
claude_session, "_SESSIONS_FILE", sessions_dir / "active.json"
|
||||
)
|
||||
mock_run.return_value = _make_proc()
|
||||
|
||||
response, sid = start_session("general", "Hello")
|
||||
assert response == "Hello from Claude!"
|
||||
assert sid == "sess-abc-123"
|
||||
|
||||
@patch("shutil.which", return_value="/usr/bin/claude")
|
||||
@patch("subprocess.run")
|
||||
def test_saves_to_active_json(
|
||||
self, mock_run, mock_which, tmp_path, monkeypatch
|
||||
):
|
||||
sessions_dir = tmp_path / "sessions"
|
||||
sessions_dir.mkdir()
|
||||
sf = sessions_dir / "active.json"
|
||||
monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir)
|
||||
monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf)
|
||||
mock_run.return_value = _make_proc()
|
||||
|
||||
start_session("general", "Hello")
|
||||
|
||||
data = json.loads(sf.read_text())
|
||||
assert "general" in data
|
||||
assert data["general"]["session_id"] == "sess-abc-123"
|
||||
assert data["general"]["model"] == "sonnet"
|
||||
assert data["general"]["message_count"] == 1
|
||||
assert "created_at" in data["general"]
|
||||
assert "last_message_at" in data["general"]
|
||||
|
||||
def test_invalid_model_raises(self):
|
||||
with pytest.raises(ValueError, match="Invalid model"):
|
||||
start_session("general", "Hello", model="gpt-4")
|
||||
|
||||
@patch("shutil.which", return_value="/usr/bin/claude")
|
||||
@patch("subprocess.run")
|
||||
def test_missing_result_field_raises(
|
||||
self, mock_run, mock_which, tmp_path, monkeypatch
|
||||
):
|
||||
sessions_dir = tmp_path / "sessions"
|
||||
sessions_dir.mkdir()
|
||||
monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir)
|
||||
monkeypatch.setattr(
|
||||
claude_session, "_SESSIONS_FILE", sessions_dir / "active.json"
|
||||
)
|
||||
bad_response = {"session_id": "abc"} # missing "result"
|
||||
mock_run.return_value = _make_proc(stdout=json.dumps(bad_response))
|
||||
|
||||
with pytest.raises(RuntimeError, match="missing required field"):
|
||||
start_session("general", "Hello")
|
||||
|
||||
@patch("shutil.which", return_value="/usr/bin/claude")
|
||||
@patch("subprocess.run")
|
||||
def test_missing_session_id_field_raises(
|
||||
self, mock_run, mock_which, tmp_path, monkeypatch
|
||||
):
|
||||
sessions_dir = tmp_path / "sessions"
|
||||
sessions_dir.mkdir()
|
||||
monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir)
|
||||
monkeypatch.setattr(
|
||||
claude_session, "_SESSIONS_FILE", sessions_dir / "active.json"
|
||||
)
|
||||
bad_response = {"result": "hello"} # missing "session_id"
|
||||
mock_run.return_value = _make_proc(stdout=json.dumps(bad_response))
|
||||
|
||||
with pytest.raises(RuntimeError, match="missing required field"):
|
||||
start_session("general", "Hello")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resume_session
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResumeSession:
|
||||
@patch("shutil.which", return_value="/usr/bin/claude")
|
||||
@patch("subprocess.run")
|
||||
def test_returns_response(
|
||||
self, mock_run, mock_which, tmp_path, monkeypatch
|
||||
):
|
||||
sessions_dir = tmp_path / "sessions"
|
||||
sessions_dir.mkdir()
|
||||
sf = sessions_dir / "active.json"
|
||||
monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir)
|
||||
monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf)
|
||||
|
||||
# Pre-populate active.json
|
||||
sf.write_text(json.dumps({
|
||||
"general": {
|
||||
"session_id": "sess-abc-123",
|
||||
"model": "sonnet",
|
||||
"created_at": "2026-01-01T00:00:00+00:00",
|
||||
"last_message_at": "2026-01-01T00:00:00+00:00",
|
||||
"message_count": 1,
|
||||
}
|
||||
}))
|
||||
|
||||
mock_run.return_value = _make_proc()
|
||||
response = resume_session("sess-abc-123", "Follow up")
|
||||
assert response == "Hello from Claude!"
|
||||
|
||||
@patch("shutil.which", return_value="/usr/bin/claude")
|
||||
@patch("subprocess.run")
|
||||
def test_updates_message_count_and_timestamp(
|
||||
self, mock_run, mock_which, tmp_path, monkeypatch
|
||||
):
|
||||
sessions_dir = tmp_path / "sessions"
|
||||
sessions_dir.mkdir()
|
||||
sf = sessions_dir / "active.json"
|
||||
monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir)
|
||||
monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf)
|
||||
|
||||
old_ts = "2026-01-01T00:00:00+00:00"
|
||||
sf.write_text(json.dumps({
|
||||
"general": {
|
||||
"session_id": "sess-abc-123",
|
||||
"model": "sonnet",
|
||||
"created_at": old_ts,
|
||||
"last_message_at": old_ts,
|
||||
"message_count": 3,
|
||||
}
|
||||
}))
|
||||
|
||||
mock_run.return_value = _make_proc()
|
||||
resume_session("sess-abc-123", "Follow up")
|
||||
|
||||
data = json.loads(sf.read_text())
|
||||
assert data["general"]["message_count"] == 4
|
||||
assert data["general"]["last_message_at"] != old_ts
|
||||
|
||||
@patch("shutil.which", return_value="/usr/bin/claude")
|
||||
@patch("subprocess.run")
|
||||
def test_uses_resume_flag(self, mock_run, mock_which, tmp_path, monkeypatch):
|
||||
sessions_dir = tmp_path / "sessions"
|
||||
sessions_dir.mkdir()
|
||||
sf = sessions_dir / "active.json"
|
||||
monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir)
|
||||
monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf)
|
||||
sf.write_text(json.dumps({}))
|
||||
|
||||
mock_run.return_value = _make_proc()
|
||||
resume_session("sess-abc-123", "Follow up")
|
||||
|
||||
# Verify --resume was in the command
|
||||
cmd = mock_run.call_args[0][0]
|
||||
assert "--resume" in cmd
|
||||
assert "sess-abc-123" in cmd
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# send_message
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSendMessage:
|
||||
@patch("shutil.which", return_value="/usr/bin/claude")
|
||||
@patch("subprocess.run")
|
||||
def test_starts_new_session_when_none_exists(
|
||||
self, mock_run, mock_which, tmp_path, monkeypatch
|
||||
):
|
||||
sessions_dir = tmp_path / "sessions"
|
||||
sessions_dir.mkdir()
|
||||
sf = sessions_dir / "active.json"
|
||||
monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir)
|
||||
monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf)
|
||||
sf.write_text("{}")
|
||||
|
||||
mock_run.return_value = _make_proc()
|
||||
response = send_message("general", "Hello")
|
||||
assert response == "Hello from Claude!"
|
||||
|
||||
# Should have created a session
|
||||
data = json.loads(sf.read_text())
|
||||
assert "general" in data
|
||||
|
||||
@patch("shutil.which", return_value="/usr/bin/claude")
|
||||
@patch("subprocess.run")
|
||||
def test_resumes_existing_session(
|
||||
self, mock_run, mock_which, tmp_path, monkeypatch
|
||||
):
|
||||
sessions_dir = tmp_path / "sessions"
|
||||
sessions_dir.mkdir()
|
||||
sf = sessions_dir / "active.json"
|
||||
monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir)
|
||||
monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf)
|
||||
|
||||
sf.write_text(json.dumps({
|
||||
"general": {
|
||||
"session_id": "sess-existing",
|
||||
"model": "sonnet",
|
||||
"created_at": "2026-01-01T00:00:00+00:00",
|
||||
"last_message_at": "2026-01-01T00:00:00+00:00",
|
||||
"message_count": 1,
|
||||
}
|
||||
}))
|
||||
|
||||
mock_run.return_value = _make_proc()
|
||||
response = send_message("general", "Follow up")
|
||||
assert response == "Hello from Claude!"
|
||||
|
||||
# Should have used --resume
|
||||
cmd = mock_run.call_args[0][0]
|
||||
assert "--resume" in cmd
|
||||
assert "sess-existing" in cmd
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# clear_session
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClearSession:
|
||||
def test_returns_true_when_existed(self, tmp_path, monkeypatch):
|
||||
sessions_dir = tmp_path / "sessions"
|
||||
sessions_dir.mkdir()
|
||||
sf = sessions_dir / "active.json"
|
||||
monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir)
|
||||
monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf)
|
||||
|
||||
sf.write_text(json.dumps({"general": {"session_id": "abc"}}))
|
||||
assert clear_session("general") is True
|
||||
|
||||
def test_returns_false_when_not_existed(self, tmp_path, monkeypatch):
|
||||
sessions_dir = tmp_path / "sessions"
|
||||
sessions_dir.mkdir()
|
||||
sf = sessions_dir / "active.json"
|
||||
monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir)
|
||||
monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf)
|
||||
|
||||
sf.write_text("{}")
|
||||
assert clear_session("general") is False
|
||||
|
||||
def test_removes_from_active_json(self, tmp_path, monkeypatch):
|
||||
sessions_dir = tmp_path / "sessions"
|
||||
sessions_dir.mkdir()
|
||||
sf = sessions_dir / "active.json"
|
||||
monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir)
|
||||
monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf)
|
||||
|
||||
sf.write_text(json.dumps({
|
||||
"general": {"session_id": "abc"},
|
||||
"other": {"session_id": "def"},
|
||||
}))
|
||||
clear_session("general")
|
||||
|
||||
data = json.loads(sf.read_text())
|
||||
assert "general" not in data
|
||||
assert "other" in data
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_active_session
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetActiveSession:
|
||||
def test_returns_dict_when_exists(self, tmp_path, monkeypatch):
|
||||
sessions_dir = tmp_path / "sessions"
|
||||
sessions_dir.mkdir()
|
||||
sf = sessions_dir / "active.json"
|
||||
monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir)
|
||||
monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf)
|
||||
|
||||
session_data = {"session_id": "abc", "model": "sonnet"}
|
||||
sf.write_text(json.dumps({"general": session_data}))
|
||||
|
||||
result = get_active_session("general")
|
||||
assert result == session_data
|
||||
|
||||
def test_returns_none_when_not_exists(self, tmp_path, monkeypatch):
|
||||
sessions_dir = tmp_path / "sessions"
|
||||
sessions_dir.mkdir()
|
||||
sf = sessions_dir / "active.json"
|
||||
monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir)
|
||||
monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf)
|
||||
|
||||
sf.write_text("{}")
|
||||
assert get_active_session("general") is None
|
||||
|
||||
def test_returns_none_when_file_missing(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
claude_session, "_SESSIONS_FILE", tmp_path / "nonexistent.json"
|
||||
)
|
||||
assert get_active_session("general") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_sessions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestListSessions:
|
||||
def test_returns_all_sessions(self, tmp_path, monkeypatch):
|
||||
sessions_dir = tmp_path / "sessions"
|
||||
sessions_dir.mkdir()
|
||||
sf = sessions_dir / "active.json"
|
||||
monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir)
|
||||
monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf)
|
||||
|
||||
data = {
|
||||
"general": {"session_id": "abc"},
|
||||
"dev": {"session_id": "def"},
|
||||
}
|
||||
sf.write_text(json.dumps(data))
|
||||
|
||||
result = list_sessions()
|
||||
assert result == data
|
||||
|
||||
def test_returns_empty_when_none(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
claude_session, "_SESSIONS_FILE", tmp_path / "nonexistent.json"
|
||||
)
|
||||
assert list_sessions() == {}
|
||||
Reference in New Issue
Block a user