From 339866baa169fa15878ce085ed7e8f43068750b1 Mon Sep 17 00:00:00 2001 From: MoltBot Service Date: Fri, 13 Feb 2026 12:12:07 +0000 Subject: [PATCH] stage-3: claude CLI wrapper with session management Subprocess wrapper for Claude CLI with start/resume/clear sessions, personality system prompt, atomic session tracking. 38 new tests (89 total). Co-Authored-By: Claude Opus 4.6 --- src/claude_session.py | 276 +++++++++++++++++ tests/test_claude_session.py | 576 +++++++++++++++++++++++++++++++++++ 2 files changed, 852 insertions(+) create mode 100644 src/claude_session.py create mode 100644 tests/test_claude_session.py diff --git a/src/claude_session.py b/src/claude_session.py new file mode 100644 index 0000000..2c99875 --- /dev/null +++ b/src/claude_session.py @@ -0,0 +1,276 @@ +""" +Claude CLI session manager for Echo-Core. + +Wraps the Claude Code CLI to provide conversation management. +Each Discord channel maps to at most one active Claude session, +tracked in sessions/active.json. +""" + +import json +import logging +import os +import shutil +import subprocess +import tempfile +from datetime import datetime, timezone +from pathlib import Path + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Constants & configuration +# --------------------------------------------------------------------------- + +PROJECT_ROOT = Path(__file__).resolve().parent.parent +PERSONALITY_DIR = PROJECT_ROOT / "personality" +SESSIONS_DIR = PROJECT_ROOT / "sessions" +_SESSIONS_FILE = SESSIONS_DIR / "active.json" + +VALID_MODELS = {"haiku", "sonnet", "opus"} +DEFAULT_MODEL = "sonnet" +DEFAULT_TIMEOUT = 120 # seconds + +CLAUDE_BIN = os.environ.get("CLAUDE_BIN", "claude") + +PERSONALITY_FILES = [ + "IDENTITY.md", + "SOUL.md", + "USER.md", + "AGENTS.md", + "TOOLS.md", + "HEARTBEAT.md", +] + +# Environment variables allowed through to the Claude subprocess +_ENV_PASSTHROUGH = { + "PATH", "HOME", "USER", "LANG", "LC_ALL", "TERM", + "SHELL", "TMPDIR", "XDG_CONFIG_HOME", "XDG_DATA_HOME", + "CLAUDE_CONFIG_DIR", +} + +# --------------------------------------------------------------------------- +# Module-level binary check +# --------------------------------------------------------------------------- + +if not shutil.which(CLAUDE_BIN): + logger.warning( + "Claude CLI (%s) not found on PATH. " + "Session functions will raise FileNotFoundError.", + CLAUDE_BIN, + ) + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _safe_env() -> dict[str, str]: + """Return a filtered copy of os.environ for subprocess calls.""" + return {k: v for k, v in os.environ.items() if k in _ENV_PASSTHROUGH} + + +def _load_sessions() -> dict: + """Load sessions from active.json. Returns {} if missing or empty.""" + try: + text = _SESSIONS_FILE.read_text(encoding="utf-8") + if not text.strip(): + return {} + return json.loads(text) + except (FileNotFoundError, json.JSONDecodeError): + return {} + + +def _save_sessions(data: dict) -> None: + """Atomically write sessions to active.json via tempfile + os.replace.""" + SESSIONS_DIR.mkdir(parents=True, exist_ok=True) + fd, tmp_path = tempfile.mkstemp( + dir=SESSIONS_DIR, prefix=".active_", suffix=".json" + ) + try: + with os.fdopen(fd, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2, ensure_ascii=False) + f.write("\n") + os.replace(tmp_path, _SESSIONS_FILE) + except BaseException: + # Clean up temp file on any failure + try: + os.unlink(tmp_path) + except OSError: + pass + raise + + +def _run_claude(cmd: list[str], timeout: int) -> dict: + """Run a Claude CLI command and return the parsed JSON output.""" + if not shutil.which(CLAUDE_BIN): + raise FileNotFoundError( + "Claude CLI not found. " + "Install: https://docs.anthropic.com/en/docs/claude-code" + ) + + try: + proc = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=timeout, + env=_safe_env(), + cwd=PROJECT_ROOT, + ) + except subprocess.TimeoutExpired: + raise TimeoutError(f"Claude CLI timed out after {timeout}s") + + if proc.returncode != 0: + raise RuntimeError( + f"Claude CLI error (exit {proc.returncode}): " + f"{proc.stderr[:500]}" + ) + + try: + data = json.loads(proc.stdout) + except json.JSONDecodeError as exc: + raise RuntimeError(f"Failed to parse Claude CLI output: {exc}") + + return data + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def build_system_prompt() -> str: + """Concatenate personality/*.md files into a single system prompt.""" + if not PERSONALITY_DIR.is_dir(): + raise FileNotFoundError( + f"Personality directory not found: {PERSONALITY_DIR}" + ) + + parts: list[str] = [] + for filename in PERSONALITY_FILES: + filepath = PERSONALITY_DIR / filename + if filepath.is_file(): + parts.append(filepath.read_text(encoding="utf-8")) + + return "\n\n---\n\n".join(parts) + + +def start_session( + channel_id: str, + message: str, + model: str = DEFAULT_MODEL, + timeout: int = DEFAULT_TIMEOUT, +) -> tuple[str, str]: + """Start a new Claude CLI session for a channel. + + Returns (response_text, session_id). + """ + if model not in VALID_MODELS: + raise ValueError( + f"Invalid model '{model}'. Must be one of: haiku, sonnet, opus" + ) + + system_prompt = build_system_prompt() + + cmd = [ + CLAUDE_BIN, "-p", message, + "--model", model, + "--output-format", "json", + "--system-prompt", system_prompt, + ] + + data = _run_claude(cmd, timeout) + + for field in ("result", "session_id"): + if field not in data: + raise RuntimeError( + f"Claude CLI response missing required field: {field}" + ) + + response_text = data["result"] + session_id = data["session_id"] + + # Save session metadata + now = datetime.now(timezone.utc).isoformat() + sessions = _load_sessions() + sessions[channel_id] = { + "session_id": session_id, + "model": model, + "created_at": now, + "last_message_at": now, + "message_count": 1, + } + _save_sessions(sessions) + + return response_text, session_id + + +def resume_session( + session_id: str, + message: str, + timeout: int = DEFAULT_TIMEOUT, +) -> str: + """Resume an existing Claude session by ID. Returns response text.""" + cmd = [ + CLAUDE_BIN, "-p", message, + "--resume", session_id, + "--output-format", "json", + ] + + data = _run_claude(cmd, timeout) + + if "result" not in data: + raise RuntimeError( + "Claude CLI response missing required field: result" + ) + + response_text = data["result"] + + # Update session metadata + now = datetime.now(timezone.utc).isoformat() + sessions = _load_sessions() + for channel_id, session in sessions.items(): + if session.get("session_id") == session_id: + session["last_message_at"] = now + session["message_count"] = session.get("message_count", 0) + 1 + break + _save_sessions(sessions) + + return response_text + + +def send_message( + channel_id: str, + message: str, + model: str = DEFAULT_MODEL, + timeout: int = DEFAULT_TIMEOUT, +) -> str: + """High-level convenience: auto start or resume based on channel state.""" + session = get_active_session(channel_id) + if session is not None: + return resume_session(session["session_id"], message, timeout) + response_text, _session_id = start_session( + channel_id, message, model, timeout + ) + return response_text + + +def clear_session(channel_id: str) -> bool: + """Remove a channel's session entry. Returns True if removed.""" + sessions = _load_sessions() + if channel_id not in sessions: + return False + del sessions[channel_id] + _save_sessions(sessions) + return True + + +def get_active_session(channel_id: str) -> dict | None: + """Return session metadata for a channel, or None.""" + sessions = _load_sessions() + return sessions.get(channel_id) + + +def list_sessions() -> dict: + """Return all channel→session mappings from active.json.""" + return _load_sessions() diff --git a/tests/test_claude_session.py b/tests/test_claude_session.py new file mode 100644 index 0000000..b4ea90e --- /dev/null +++ b/tests/test_claude_session.py @@ -0,0 +1,576 @@ +"""Comprehensive tests for src/claude_session.py.""" + +import json +import os +import subprocess +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from src import claude_session +from src.claude_session import ( + _load_sessions, + _run_claude, + _safe_env, + _save_sessions, + build_system_prompt, + clear_session, + get_active_session, + list_sessions, + resume_session, + send_message, + start_session, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +FAKE_CLI_RESPONSE = { + "type": "result", + "subtype": "success", + "session_id": "sess-abc-123", + "result": "Hello from Claude!", + "cost_usd": 0.004, + "duration_ms": 1500, + "num_turns": 1, +} + + +def _make_proc(stdout=None, returncode=0, stderr=""): + """Build a fake subprocess.CompletedProcess.""" + if stdout is None: + stdout = json.dumps(FAKE_CLI_RESPONSE) + proc = MagicMock(spec=subprocess.CompletedProcess) + proc.stdout = stdout + proc.stderr = stderr + proc.returncode = returncode + return proc + + +# --------------------------------------------------------------------------- +# build_system_prompt +# --------------------------------------------------------------------------- + + +class TestBuildSystemPrompt: + def test_returns_non_empty_string(self): + prompt = build_system_prompt() + assert isinstance(prompt, str) + assert len(prompt) > 0 + + def test_contains_personality_content(self): + prompt = build_system_prompt() + # IDENTITY.md is first; should appear in the prompt + identity_path = claude_session.PERSONALITY_DIR / "IDENTITY.md" + identity_content = identity_path.read_text(encoding="utf-8") + assert identity_content in prompt + + def test_correct_order(self): + prompt = build_system_prompt() + expected_order = [ + "IDENTITY.md", + "SOUL.md", + "USER.md", + "AGENTS.md", + "TOOLS.md", + "HEARTBEAT.md", + ] + # Read each file and find its position in the prompt + positions = [] + for filename in expected_order: + filepath = claude_session.PERSONALITY_DIR / filename + if filepath.is_file(): + content = filepath.read_text(encoding="utf-8") + pos = prompt.find(content) + assert pos >= 0, f"{filename} not found in prompt" + positions.append(pos) + + # Positions must be strictly increasing + for i in range(1, len(positions)): + assert positions[i] > positions[i - 1], ( + f"Order violation: {expected_order[i]} appears before " + f"{expected_order[i - 1]}" + ) + + def test_separator_between_files(self): + prompt = build_system_prompt() + assert "\n\n---\n\n" in prompt + + def test_missing_personality_dir_raises(self, tmp_path, monkeypatch): + monkeypatch.setattr( + claude_session, "PERSONALITY_DIR", tmp_path / "nonexistent" + ) + with pytest.raises(FileNotFoundError): + build_system_prompt() + + +# --------------------------------------------------------------------------- +# _safe_env +# --------------------------------------------------------------------------- + + +class TestSafeEnv: + def test_includes_path_and_home(self, monkeypatch): + monkeypatch.setenv("PATH", "/usr/bin") + monkeypatch.setenv("HOME", "/home/test") + env = _safe_env() + assert "PATH" in env + assert "HOME" in env + + def test_excludes_discord_token(self, monkeypatch): + monkeypatch.setenv("DISCORD_TOKEN", "secret-token") + env = _safe_env() + assert "DISCORD_TOKEN" not in env + + def test_excludes_claudecode(self, monkeypatch): + monkeypatch.setenv("CLAUDECODE", "1") + env = _safe_env() + assert "CLAUDECODE" not in env + + def test_excludes_api_keys(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "sk-xxx") + monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-xxx") + env = _safe_env() + assert "OPENAI_API_KEY" not in env + assert "ANTHROPIC_API_KEY" not in env + + def test_only_passthrough_keys(self, monkeypatch): + monkeypatch.setenv("RANDOM_SECRET", "bad") + env = _safe_env() + for key in env: + assert key in claude_session._ENV_PASSTHROUGH + + +# --------------------------------------------------------------------------- +# _run_claude +# --------------------------------------------------------------------------- + + +class TestRunClaude: + @patch("shutil.which", return_value="/usr/bin/claude") + @patch("subprocess.run") + def test_returns_parsed_json(self, mock_run, mock_which): + mock_run.return_value = _make_proc() + result = _run_claude(["claude", "-p", "hi"], timeout=30) + assert result == FAKE_CLI_RESPONSE + + @patch("shutil.which", return_value="/usr/bin/claude") + @patch("subprocess.run") + def test_timeout_raises(self, mock_run, mock_which): + mock_run.side_effect = subprocess.TimeoutExpired(cmd="claude", timeout=30) + with pytest.raises(TimeoutError, match="timed out after 30s"): + _run_claude(["claude", "-p", "hi"], timeout=30) + + @patch("shutil.which", return_value="/usr/bin/claude") + @patch("subprocess.run") + def test_nonzero_exit_raises(self, mock_run, mock_which): + mock_run.return_value = _make_proc( + stdout="", returncode=1, stderr="something went wrong" + ) + with pytest.raises(RuntimeError, match="exit 1"): + _run_claude(["claude", "-p", "hi"], timeout=30) + + @patch("shutil.which", return_value="/usr/bin/claude") + @patch("subprocess.run") + def test_invalid_json_raises(self, mock_run, mock_which): + mock_run.return_value = _make_proc(stdout="not json {{{") + with pytest.raises(RuntimeError, match="Failed to parse"): + _run_claude(["claude", "-p", "hi"], timeout=30) + + @patch("shutil.which", return_value=None) + def test_missing_binary_raises(self, mock_which): + with pytest.raises(FileNotFoundError, match="Claude CLI not found"): + _run_claude(["claude", "-p", "hi"], timeout=30) + + +# --------------------------------------------------------------------------- +# Session file helpers (_load_sessions / _save_sessions) +# --------------------------------------------------------------------------- + + +class TestSessionFileOps: + def test_load_missing_file_returns_empty(self, tmp_path, monkeypatch): + monkeypatch.setattr( + claude_session, "_SESSIONS_FILE", tmp_path / "no_such.json" + ) + assert _load_sessions() == {} + + def test_load_empty_file_returns_empty(self, tmp_path, monkeypatch): + f = tmp_path / "active.json" + f.write_text("") + monkeypatch.setattr(claude_session, "_SESSIONS_FILE", f) + assert _load_sessions() == {} + + def test_save_and_load_roundtrip(self, tmp_path, monkeypatch): + sessions_dir = tmp_path / "sessions" + sessions_dir.mkdir() + sf = sessions_dir / "active.json" + monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir) + monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf) + + data = {"general": {"session_id": "abc", "message_count": 1}} + _save_sessions(data) + loaded = _load_sessions() + assert loaded == data + + def test_save_creates_directory(self, tmp_path, monkeypatch): + sessions_dir = tmp_path / "deep" / "sessions" + sf = sessions_dir / "active.json" + monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir) + monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf) + + _save_sessions({"test": "data"}) + assert sf.exists() + assert json.loads(sf.read_text()) == {"test": "data"} + + def test_atomic_write_no_corruption_on_error(self, tmp_path, monkeypatch): + sessions_dir = tmp_path / "sessions" + sessions_dir.mkdir() + sf = sessions_dir / "active.json" + monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir) + monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf) + + # Write initial good data + _save_sessions({"good": "data"}) + + # Attempt to write non-serializable data → should raise + class Unserializable: + pass + + with pytest.raises(TypeError): + _save_sessions({"bad": Unserializable()}) + + # Original file should still be intact + assert json.loads(sf.read_text()) == {"good": "data"} + + +# --------------------------------------------------------------------------- +# start_session +# --------------------------------------------------------------------------- + + +class TestStartSession: + @patch("shutil.which", return_value="/usr/bin/claude") + @patch("subprocess.run") + def test_returns_response_and_session_id( + self, mock_run, mock_which, tmp_path, monkeypatch + ): + sessions_dir = tmp_path / "sessions" + sessions_dir.mkdir() + monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir) + monkeypatch.setattr( + claude_session, "_SESSIONS_FILE", sessions_dir / "active.json" + ) + mock_run.return_value = _make_proc() + + response, sid = start_session("general", "Hello") + assert response == "Hello from Claude!" + assert sid == "sess-abc-123" + + @patch("shutil.which", return_value="/usr/bin/claude") + @patch("subprocess.run") + def test_saves_to_active_json( + self, mock_run, mock_which, tmp_path, monkeypatch + ): + sessions_dir = tmp_path / "sessions" + sessions_dir.mkdir() + sf = sessions_dir / "active.json" + monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir) + monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf) + mock_run.return_value = _make_proc() + + start_session("general", "Hello") + + data = json.loads(sf.read_text()) + assert "general" in data + assert data["general"]["session_id"] == "sess-abc-123" + assert data["general"]["model"] == "sonnet" + assert data["general"]["message_count"] == 1 + assert "created_at" in data["general"] + assert "last_message_at" in data["general"] + + def test_invalid_model_raises(self): + with pytest.raises(ValueError, match="Invalid model"): + start_session("general", "Hello", model="gpt-4") + + @patch("shutil.which", return_value="/usr/bin/claude") + @patch("subprocess.run") + def test_missing_result_field_raises( + self, mock_run, mock_which, tmp_path, monkeypatch + ): + sessions_dir = tmp_path / "sessions" + sessions_dir.mkdir() + monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir) + monkeypatch.setattr( + claude_session, "_SESSIONS_FILE", sessions_dir / "active.json" + ) + bad_response = {"session_id": "abc"} # missing "result" + mock_run.return_value = _make_proc(stdout=json.dumps(bad_response)) + + with pytest.raises(RuntimeError, match="missing required field"): + start_session("general", "Hello") + + @patch("shutil.which", return_value="/usr/bin/claude") + @patch("subprocess.run") + def test_missing_session_id_field_raises( + self, mock_run, mock_which, tmp_path, monkeypatch + ): + sessions_dir = tmp_path / "sessions" + sessions_dir.mkdir() + monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir) + monkeypatch.setattr( + claude_session, "_SESSIONS_FILE", sessions_dir / "active.json" + ) + bad_response = {"result": "hello"} # missing "session_id" + mock_run.return_value = _make_proc(stdout=json.dumps(bad_response)) + + with pytest.raises(RuntimeError, match="missing required field"): + start_session("general", "Hello") + + +# --------------------------------------------------------------------------- +# resume_session +# --------------------------------------------------------------------------- + + +class TestResumeSession: + @patch("shutil.which", return_value="/usr/bin/claude") + @patch("subprocess.run") + def test_returns_response( + self, mock_run, mock_which, tmp_path, monkeypatch + ): + sessions_dir = tmp_path / "sessions" + sessions_dir.mkdir() + sf = sessions_dir / "active.json" + monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir) + monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf) + + # Pre-populate active.json + sf.write_text(json.dumps({ + "general": { + "session_id": "sess-abc-123", + "model": "sonnet", + "created_at": "2026-01-01T00:00:00+00:00", + "last_message_at": "2026-01-01T00:00:00+00:00", + "message_count": 1, + } + })) + + mock_run.return_value = _make_proc() + response = resume_session("sess-abc-123", "Follow up") + assert response == "Hello from Claude!" + + @patch("shutil.which", return_value="/usr/bin/claude") + @patch("subprocess.run") + def test_updates_message_count_and_timestamp( + self, mock_run, mock_which, tmp_path, monkeypatch + ): + sessions_dir = tmp_path / "sessions" + sessions_dir.mkdir() + sf = sessions_dir / "active.json" + monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir) + monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf) + + old_ts = "2026-01-01T00:00:00+00:00" + sf.write_text(json.dumps({ + "general": { + "session_id": "sess-abc-123", + "model": "sonnet", + "created_at": old_ts, + "last_message_at": old_ts, + "message_count": 3, + } + })) + + mock_run.return_value = _make_proc() + resume_session("sess-abc-123", "Follow up") + + data = json.loads(sf.read_text()) + assert data["general"]["message_count"] == 4 + assert data["general"]["last_message_at"] != old_ts + + @patch("shutil.which", return_value="/usr/bin/claude") + @patch("subprocess.run") + def test_uses_resume_flag(self, mock_run, mock_which, tmp_path, monkeypatch): + sessions_dir = tmp_path / "sessions" + sessions_dir.mkdir() + sf = sessions_dir / "active.json" + monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir) + monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf) + sf.write_text(json.dumps({})) + + mock_run.return_value = _make_proc() + resume_session("sess-abc-123", "Follow up") + + # Verify --resume was in the command + cmd = mock_run.call_args[0][0] + assert "--resume" in cmd + assert "sess-abc-123" in cmd + + +# --------------------------------------------------------------------------- +# send_message +# --------------------------------------------------------------------------- + + +class TestSendMessage: + @patch("shutil.which", return_value="/usr/bin/claude") + @patch("subprocess.run") + def test_starts_new_session_when_none_exists( + self, mock_run, mock_which, tmp_path, monkeypatch + ): + sessions_dir = tmp_path / "sessions" + sessions_dir.mkdir() + sf = sessions_dir / "active.json" + monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir) + monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf) + sf.write_text("{}") + + mock_run.return_value = _make_proc() + response = send_message("general", "Hello") + assert response == "Hello from Claude!" + + # Should have created a session + data = json.loads(sf.read_text()) + assert "general" in data + + @patch("shutil.which", return_value="/usr/bin/claude") + @patch("subprocess.run") + def test_resumes_existing_session( + self, mock_run, mock_which, tmp_path, monkeypatch + ): + sessions_dir = tmp_path / "sessions" + sessions_dir.mkdir() + sf = sessions_dir / "active.json" + monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir) + monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf) + + sf.write_text(json.dumps({ + "general": { + "session_id": "sess-existing", + "model": "sonnet", + "created_at": "2026-01-01T00:00:00+00:00", + "last_message_at": "2026-01-01T00:00:00+00:00", + "message_count": 1, + } + })) + + mock_run.return_value = _make_proc() + response = send_message("general", "Follow up") + assert response == "Hello from Claude!" + + # Should have used --resume + cmd = mock_run.call_args[0][0] + assert "--resume" in cmd + assert "sess-existing" in cmd + + +# --------------------------------------------------------------------------- +# clear_session +# --------------------------------------------------------------------------- + + +class TestClearSession: + def test_returns_true_when_existed(self, tmp_path, monkeypatch): + sessions_dir = tmp_path / "sessions" + sessions_dir.mkdir() + sf = sessions_dir / "active.json" + monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir) + monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf) + + sf.write_text(json.dumps({"general": {"session_id": "abc"}})) + assert clear_session("general") is True + + def test_returns_false_when_not_existed(self, tmp_path, monkeypatch): + sessions_dir = tmp_path / "sessions" + sessions_dir.mkdir() + sf = sessions_dir / "active.json" + monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir) + monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf) + + sf.write_text("{}") + assert clear_session("general") is False + + def test_removes_from_active_json(self, tmp_path, monkeypatch): + sessions_dir = tmp_path / "sessions" + sessions_dir.mkdir() + sf = sessions_dir / "active.json" + monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir) + monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf) + + sf.write_text(json.dumps({ + "general": {"session_id": "abc"}, + "other": {"session_id": "def"}, + })) + clear_session("general") + + data = json.loads(sf.read_text()) + assert "general" not in data + assert "other" in data + + +# --------------------------------------------------------------------------- +# get_active_session +# --------------------------------------------------------------------------- + + +class TestGetActiveSession: + def test_returns_dict_when_exists(self, tmp_path, monkeypatch): + sessions_dir = tmp_path / "sessions" + sessions_dir.mkdir() + sf = sessions_dir / "active.json" + monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir) + monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf) + + session_data = {"session_id": "abc", "model": "sonnet"} + sf.write_text(json.dumps({"general": session_data})) + + result = get_active_session("general") + assert result == session_data + + def test_returns_none_when_not_exists(self, tmp_path, monkeypatch): + sessions_dir = tmp_path / "sessions" + sessions_dir.mkdir() + sf = sessions_dir / "active.json" + monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir) + monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf) + + sf.write_text("{}") + assert get_active_session("general") is None + + def test_returns_none_when_file_missing(self, tmp_path, monkeypatch): + monkeypatch.setattr( + claude_session, "_SESSIONS_FILE", tmp_path / "nonexistent.json" + ) + assert get_active_session("general") is None + + +# --------------------------------------------------------------------------- +# list_sessions +# --------------------------------------------------------------------------- + + +class TestListSessions: + def test_returns_all_sessions(self, tmp_path, monkeypatch): + sessions_dir = tmp_path / "sessions" + sessions_dir.mkdir() + sf = sessions_dir / "active.json" + monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir) + monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf) + + data = { + "general": {"session_id": "abc"}, + "dev": {"session_id": "def"}, + } + sf.write_text(json.dumps(data)) + + result = list_sessions() + assert result == data + + def test_returns_empty_when_none(self, tmp_path, monkeypatch): + monkeypatch.setattr( + claude_session, "_SESSIONS_FILE", tmp_path / "nonexistent.json" + ) + assert list_sessions() == {}