# ---------------------------------------------------------------------------
# Per-channel mutex for send_message
# ---------------------------------------------------------------------------
#
# Two paths can hit `send_message(channel_id, ...)` concurrently for the same
# channel: a text adapter (Discord/Telegram/WhatsApp) and the voice adapter
# (`adapter_name="discord-voice"`). The underlying Claude CLI subprocess is
# blocking (`subprocess.Popen` with stream-json read loop) and stateful via
# `--resume <session_id>` — interleaving two concurrent invocations on the
# same channel would corrupt the conversation order.
#
# We use `threading.Lock` (NOT `asyncio.Lock`) because `send_message` is sync
# code typically run from `asyncio.to_thread` in async adapters. asyncio.Lock
# only serializes coroutines, not threads — it would NOT protect this path.
#
# Each channel gets its own lock so DIFFERENT channels still run in parallel.
# Locks are created lazily on first use; creation is guarded by a small
# bootstrap lock so two concurrent first-uses don't race on creation.
#
# NOTE(review): entries are never evicted, so the map grows by one small lock
# object per distinct channel_id seen by the process — confirm the channel
# population is bounded for long-running deployments.
_session_locks: dict[str, threading.Lock] = {}
_session_locks_bootstrap = threading.Lock()


def _get_session_lock(channel_id: str) -> threading.Lock:
    """Return the mutex for ``channel_id``, creating it on first access.

    The common case (lock already registered) is a single dict lookup with
    no locking. On a miss, creation happens under
    ``_session_locks_bootstrap`` and the map is re-checked there, so two
    threads racing on the first use of a channel always receive the same
    lock object and no throwaway ``Lock`` is allocated by the loser.

    Args:
        channel_id: Key identifying the channel whose calls must serialize.

    Returns:
        The ``threading.Lock`` registered for ``channel_id``; repeated calls
        with the same id return the identical object.
    """
    lock = _session_locks.get(channel_id)
    if lock is not None:
        return lock
    with _session_locks_bootstrap:
        # Re-check: another thread may have registered the lock between the
        # unlocked lookup above and our acquiring the bootstrap lock.
        lock = _session_locks.get(channel_id)
        if lock is None:
            lock = threading.Lock()
            _session_locks[channel_id] = lock
        return lock
+ """ + with _get_session_lock(channel_id): + session = get_active_session(channel_id) + # Only resume if session has a valid session_id (not a pre-set model placeholder) + if session is not None and session.get("session_id"): + return resume_session(session["session_id"], message, timeout, on_text=on_text) + # Use model from pre-set session if available, otherwise use provided model + effective_model = model + if session is not None and session.get("model"): + effective_model = session["model"] + response_text, _session_id = start_session( + channel_id, message, effective_model, timeout, on_text=on_text + ) + return response_text def clear_session(channel_id: str) -> bool: diff --git a/src/router.py b/src/router.py index e733845..c6a0344 100644 --- a/src/router.py +++ b/src/router.py @@ -154,8 +154,17 @@ def route_message( channel_cfg = _get_channel_config(channel_id) model = (channel_cfg or {}).get("default_model") or _get_config().get("bot.default_model", "sonnet") + # Voice-mode augment: prepend speaker prefix so Claude knows who spoke + # in a voice channel. Cheap now, future-proof for multi-speaker later. + # (Engineering decision #14 in the plan.) Only the discord-voice adapter + # triggers it — text adapters keep the message verbatim. + claude_text = text + if adapter_name == "discord-voice": + user_name = _get_config().get("voice.user_name", "user") or "user" + claude_text = f"[speaker:{user_name}] {text}" + try: - response = send_message(channel_id, text, model=model, on_text=on_text) + response = send_message(channel_id, claude_text, model=model, on_text=on_text) _set_last_response(channel_id, response) return response, False except Exception as e: diff --git a/tests/test_claude_session_mutex.py b/tests/test_claude_session_mutex.py new file mode 100644 index 0000000..9a6b2f1 --- /dev/null +++ b/tests/test_claude_session_mutex.py @@ -0,0 +1,307 @@ +"""Regression-critical tests for per-channel mutex in src/claude_session.py. 
"""Regression-critical tests for per-channel mutex in src/claude_session.py.

Three scenarios from the eng-review test plan (2026-05-27):

  1. Concurrent `send_message` calls on the SAME channel_id serialize —
     the second waits for the first to finish before its subprocess runs.
  2. Concurrent `send_message` calls on DIFFERENT channel_ids run in parallel
     — independent channels never block each other.
  3. Acquisition contract is documented and consistent: the lock is acquired
     blocking (no acquire timeout), which means a hung subprocess on
     channel X delays subsequent X messages but never X' (X != X'). This
     test pins that behavior so future refactors must preserve it.

The mutex is `threading.Lock`, NOT `asyncio.Lock`, because `send_message`
is a sync function typically dispatched via `asyncio.to_thread` from
async adapters. asyncio.Lock would serialize coroutines only — not the
subprocess invocation. See plan section "Engineering decisions" #2.
"""

import json
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from unittest.mock import patch

import pytest

from src import claude_session
from src.claude_session import (
    _get_session_lock,
    _session_locks,
    send_message,
)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


@pytest.fixture(autouse=True)
def _clear_session_locks():
    """Each test starts with a fresh lock map so we don't share state."""
    _session_locks.clear()
    yield
    _session_locks.clear()


@pytest.fixture
def temp_sessions(tmp_path, monkeypatch):
    """Isolated active.json per test — keeps real session state untouched."""
    sessions_dir = tmp_path / "sessions"
    sessions_dir.mkdir()
    sf = sessions_dir / "active.json"
    sf.write_text("{}")
    monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir)
    monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf)
    return sf


def _canned_result(duration_ms: int) -> dict:
    """Build the successful payload the fake `_run_claude` stand-ins return."""
    return {
        "result": "Hello from Claude!",
        "session_id": "sess-abc-123",
        "usage": {"input_tokens": 10, "output_tokens": 5},
        "total_cost_usd": 0.001,
        "cost_usd": 0.001,
        "duration_ms": duration_ms,
        "num_turns": 1,
        "intermediate_count": 0,
        "subtype": "success",
        "is_error": False,
    }


def _slow_run_claude(sleep_seconds: float, in_critical: threading.Event,
                     concurrent_seen: threading.Event):
    """Build a fake `_run_claude` that signals when inside the critical section.

    The fake holds the simulated subprocess for `sleep_seconds`. Any other
    invocation that overlaps will set `concurrent_seen` — the mutex test
    asserts this NEVER happens for the same channel_id.

    Args:
        sleep_seconds: How long each fake invocation sleeps.
        in_critical: Set by every invocation once it has been counted as
            active, so a test can wait until a call is in flight.
        concurrent_seen: Set when more than one invocation is active at once.

    Returns:
        A callable with the `(cmd, timeout, on_text=None, cwd=None)` signature
        used to patch `claude_session._run_claude`.
    """
    state = {"active": 0, "lock": threading.Lock()}

    def fake(cmd, timeout, on_text=None, cwd=None):
        with state["lock"]:
            state["active"] += 1
            if state["active"] > 1:
                concurrent_seen.set()
        in_critical.set()
        time.sleep(sleep_seconds)
        with state["lock"]:
            state["active"] -= 1
        return _canned_result(int(sleep_seconds * 1000))

    return fake


# ---------------------------------------------------------------------------
# Scenario 1 — same channel serializes
# ---------------------------------------------------------------------------


class TestSameChannelSerializes:
    def test_two_concurrent_calls_same_channel_run_one_at_a_time(
        self, temp_sessions
    ):
        """Two parallel send_message on the SAME channel_id never overlap.

        We instrument `_run_claude` to signal whenever more than one
        invocation is concurrently inside it. The mutex MUST prevent that.
        """
        in_critical = threading.Event()
        concurrent_seen = threading.Event()
        slow = _slow_run_claude(0.25, in_critical, concurrent_seen)

        with patch.object(claude_session, "_run_claude", side_effect=slow):
            start = time.monotonic()
            with ThreadPoolExecutor(max_workers=2) as pool:
                futures = [
                    pool.submit(send_message, "ch-same", f"msg-{i}")
                    for i in range(2)
                ]
                results = [f.result(timeout=10) for f in futures]
            elapsed = time.monotonic() - start

        assert not concurrent_seen.is_set(), (
            "Two send_message calls on the same channel ran concurrently — "
            "mutex did not serialize them."
        )
        assert all(r == "Hello from Claude!" for r in results)
        # Two serial 0.25s subprocesses must take at least ~0.5s total
        # (we allow a generous floor — schedulers can be slow).
        assert elapsed >= 0.45, f"Expected serialized ~0.5s, got {elapsed:.3f}s"

    def test_get_session_lock_returns_same_object_per_channel(self, temp_sessions):
        """`_get_session_lock` returns the SAME lock object for the same channel.

        This is an identity check on the lock registry, not a reentrancy
        check — `threading.Lock` is not reentrant.
        """
        lock_a1 = _get_session_lock("channel-A")
        lock_a2 = _get_session_lock("channel-A")
        lock_b = _get_session_lock("channel-B")
        assert lock_a1 is lock_a2
        assert lock_a1 is not lock_b


# ---------------------------------------------------------------------------
# Scenario 2 — different channels parallel
# ---------------------------------------------------------------------------


class TestDifferentChannelsParallel:
    def test_two_concurrent_calls_different_channels_run_in_parallel(
        self, temp_sessions
    ):
        """Different channels MUST NOT block each other.

        We measure elapsed wall-clock: two 0.4s subprocesses on different
        channels should finish in ~0.4s (parallel), NOT ~0.8s (serialized).
        """
        in_critical = threading.Event()
        # `concurrent_seen` is OK to fire here — we WANT them to overlap.
        concurrent_seen = threading.Event()
        slow = _slow_run_claude(0.4, in_critical, concurrent_seen)

        with patch.object(claude_session, "_run_claude", side_effect=slow):
            start = time.monotonic()
            with ThreadPoolExecutor(max_workers=2) as pool:
                f1 = pool.submit(send_message, "ch-A", "msg-A")
                f2 = pool.submit(send_message, "ch-B", "msg-B")
                results = [f1.result(timeout=10), f2.result(timeout=10)]
            elapsed = time.monotonic() - start

        assert all(r == "Hello from Claude!" for r in results)
        # Parallel execution: total time should be close to 0.4s, well under
        # 0.7s (would mean serialization). 0.65s ceiling allows for GIL +
        # scheduler jitter on a busy test box.
        assert elapsed < 0.65, (
            f"Different channels appear serialized: elapsed {elapsed:.3f}s "
            f"(expected ~0.4s parallel, <0.65s ceiling)"
        )
        assert concurrent_seen.is_set(), (
            "Different channels did not overlap — mutex is too coarse "
            "(should be per-channel, not global)."
        )

    def test_three_channels_all_overlap(self, temp_sessions):
        """Stress: three concurrent channels all run in parallel."""
        in_critical = threading.Event()
        concurrent_seen = threading.Event()
        slow = _slow_run_claude(0.3, in_critical, concurrent_seen)

        with patch.object(claude_session, "_run_claude", side_effect=slow):
            start = time.monotonic()
            with ThreadPoolExecutor(max_workers=3) as pool:
                futures = [
                    pool.submit(send_message, f"ch-{i}", f"msg-{i}")
                    for i in range(3)
                ]
                for f in as_completed(futures, timeout=10):
                    assert f.result() == "Hello from Claude!"
            elapsed = time.monotonic() - start

        # 3 × 0.3s in parallel ≈ 0.3s; serial would be ~0.9s.
        assert elapsed < 0.6, (
            f"Three channels serialized: {elapsed:.3f}s (expected <0.6s)"
        )
        # Back the timing bound with a direct observation that at least two
        # invocations were in flight at the same time.
        assert concurrent_seen.is_set(), (
            "Three channels never overlapped — mutex is too coarse "
            "(should be per-channel, not global)."
        )


# ---------------------------------------------------------------------------
# Scenario 3 — acquisition behavior documented and consistent
# ---------------------------------------------------------------------------


class TestAcquisitionBehavior:
    """Pin the chosen acquisition policy: blocking, no timeout.

    Project style is to bound subprocess execution via `timeout` (default
    5 min) rather than fail-fast on lock acquire. Reasons:

      - Adapter callers (Discord/Telegram/voice) already serialize work via
        asyncio.to_thread; queue depth is naturally bounded.
      - A non-blocking acquire would surface a timing error to the user
        ("busy, try again") for an entirely transient and self-resolving
        condition. Blocking gives FIFO-ish ordering with simple semantics.
      - If a subprocess truly hangs past `timeout`, _run_claude raises
        TimeoutError → the held lock releases (via `with`) → queued
        callers proceed.

    This test pins that: a second caller waits and eventually proceeds; it
    does not raise an exception on contention.
    """

    def test_contested_acquire_blocks_then_proceeds(self, temp_sessions):
        in_critical = threading.Event()
        concurrent_seen = threading.Event()
        slow = _slow_run_claude(0.3, in_critical, concurrent_seen)

        results: list[str | BaseException] = []

        def run(label: str):
            try:
                results.append(send_message("ch-contend", label))
            except BaseException as e:
                results.append(e)

        with patch.object(claude_session, "_run_claude", side_effect=slow):
            # Daemon threads: if the mutex ever deadlocks, the assertions
            # below fail and the stuck worker cannot block interpreter exit.
            t1 = threading.Thread(target=run, args=("first",), daemon=True)
            t1.start()
            # Wait until the first call is inside the critical section so
            # the second is GUARANTEED to contend on the lock.
            assert in_critical.wait(timeout=2.0), "first call never entered"
            in_critical.clear()
            t2 = threading.Thread(target=run, args=("second",), daemon=True)
            t2.start()
            t1.join(timeout=5.0)
            t2.join(timeout=5.0)
            # join(timeout=...) returns silently on expiry, so check liveness.
            assert not t1.is_alive(), "first call never finished"
            assert not t2.is_alive(), "second call never finished"

        assert len(results) == 2
        # Both must return the canned response — no exception, no error.
        assert all(r == "Hello from Claude!" for r in results), (
            f"Contended acquire surfaced an error instead of blocking: {results}"
        )
        # Critical-section overlap check: contended calls MUST serialize.
        assert not concurrent_seen.is_set(), (
            "Contended same-channel calls ran concurrently — mutex broken."
        )

    def test_lock_released_on_subprocess_exception(self, temp_sessions):
        """If `_run_claude` raises, the lock MUST be released so the next
        caller can proceed (otherwise a single error deadlocks the channel
        forever)."""

        call_count = {"n": 0}

        def flaky(cmd, timeout, on_text=None, cwd=None):
            call_count["n"] += 1
            if call_count["n"] == 1:
                raise RuntimeError("simulated subprocess crash")
            return _canned_result(50)

        with patch.object(claude_session, "_run_claude", side_effect=flaky):
            with pytest.raises(RuntimeError, match="simulated subprocess crash"):
                send_message("ch-recover", "first")

            # Second call MUST acquire the lock (proves the first released it).
            # We use a short timeout via a daemon thread so a deadlock fails
            # loudly instead of hanging the test process at shutdown.
            done = threading.Event()
            result_box: list[str] = []
            error_box: list[BaseException] = []

            def second():
                try:
                    result_box.append(send_message("ch-recover", "second"))
                except BaseException as e:
                    # Record the failure so it is reported as an error below
                    # rather than being mistaken for a deadlock.
                    error_box.append(e)
                finally:
                    done.set()

            t = threading.Thread(target=second, daemon=True)
            t.start()
            assert done.wait(timeout=3.0), (
                "Second call deadlocked — lock was not released on exception."
            )
            t.join(timeout=1.0)

        assert not error_box, f"Second call raised instead of returning: {error_box}"
        assert result_box == ["Hello from Claude!"]