"""Regression-critical tests for per-channel mutex in src/claude_session.py. Three scenarios from the eng-review test plan (2026-05-27): 1. Concurrent `send_message` calls on the SAME channel_id serialize — the second waits for the first to finish before its subprocess runs. 2. Concurrent `send_message` calls on DIFFERENT channel_ids run in parallel — independent channels never block each other. 3. Acquisition contract is documented and consistent: the lock is acquired blocking (no acquire timeout), which means a hung subprocess on channel X delays subsequent X messages but never X' (X != X'). This test pins that behavior so future refactors must preserve it. The mutex is `threading.Lock`, NOT `asyncio.Lock`, because `send_message` is a sync function typically dispatched via `asyncio.to_thread` from async adapters. asyncio.Lock would serialize coroutines only — not the subprocess invocation. See plan section "Engineering decisions" #2. """ import json import threading import time from concurrent.futures import ThreadPoolExecutor, as_completed from unittest.mock import patch import pytest from src import claude_session from src.claude_session import ( _get_session_lock, _session_locks, send_message, ) # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @pytest.fixture(autouse=True) def _clear_session_locks(): """Each test starts with a fresh lock map so we don't share state.""" _session_locks.clear() yield _session_locks.clear() @pytest.fixture def temp_sessions(tmp_path, monkeypatch): """Isolated active.json per test — keeps real session state untouched.""" sessions_dir = tmp_path / "sessions" sessions_dir.mkdir() sf = sessions_dir / "active.json" sf.write_text("{}") monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir) monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf) return sf def _slow_run_claude(sleep_seconds: float, in_critical: threading.Event, concurrent_seen: threading.Event): """Build a fake `_run_claude` that signals when inside the critical section. The fake holds the simulated subprocess for `sleep_seconds`. Any other invocation that overlaps will set `concurrent_seen` — the mutex test asserts this NEVER happens for the same channel_id. """ state = {"active": 0, "lock": threading.Lock()} def fake(cmd, timeout, on_text=None, cwd=None): with state["lock"]: state["active"] += 1 if state["active"] > 1: concurrent_seen.set() in_critical.set() time.sleep(sleep_seconds) with state["lock"]: state["active"] -= 1 return { "result": "Hello from Claude!", "session_id": "sess-abc-123", "usage": {"input_tokens": 10, "output_tokens": 5}, "total_cost_usd": 0.001, "cost_usd": 0.001, "duration_ms": int(sleep_seconds * 1000), "num_turns": 1, "intermediate_count": 0, "subtype": "success", "is_error": False, } return fake # --------------------------------------------------------------------------- # Scenario 1 — same channel serializes # --------------------------------------------------------------------------- class TestSameChannelSerializes: def test_two_concurrent_calls_same_channel_run_one_at_a_time( self, temp_sessions ): """Two parallel send_message on the SAME channel_id never overlap. We instrument `_run_claude` to signal whenever more than one invocation is concurrently inside it. The mutex MUST prevent that. """ in_critical = threading.Event() concurrent_seen = threading.Event() slow = _slow_run_claude(0.25, in_critical, concurrent_seen) with patch.object(claude_session, "_run_claude", side_effect=slow): start = time.monotonic() with ThreadPoolExecutor(max_workers=2) as pool: futures = [ pool.submit(send_message, "ch-same", f"msg-{i}") for i in range(2) ] results = [f.result(timeout=10) for f in futures] elapsed = time.monotonic() - start assert not concurrent_seen.is_set(), ( "Two send_message calls on the same channel ran concurrently — " "mutex did not serialize them." ) assert all(r == "Hello from Claude!" for r in results) # Two serial 0.25s subprocesses must take at least ~0.5s total # (we allow a generous floor — schedulers can be slow). assert elapsed >= 0.45, f"Expected serialized ~0.5s, got {elapsed:.3f}s" def test_lock_is_reentrant_per_channel_dict(self, temp_sessions): """`_get_session_lock` returns the SAME lock object for the same channel.""" lock_a1 = _get_session_lock("channel-A") lock_a2 = _get_session_lock("channel-A") lock_b = _get_session_lock("channel-B") assert lock_a1 is lock_a2 assert lock_a1 is not lock_b # --------------------------------------------------------------------------- # Scenario 2 — different channels parallel # --------------------------------------------------------------------------- class TestDifferentChannelsParallel: def test_two_concurrent_calls_different_channels_run_in_parallel( self, temp_sessions ): """Different channels MUST NOT block each other. We measure elapsed wall-clock: two 0.4s subprocesses on different channels should finish in ~0.4s (parallel), NOT ~0.8s (serialized). """ in_critical = threading.Event() # `concurrent_seen` is OK to fire here — we WANT them to overlap. concurrent_seen = threading.Event() slow = _slow_run_claude(0.4, in_critical, concurrent_seen) with patch.object(claude_session, "_run_claude", side_effect=slow): start = time.monotonic() with ThreadPoolExecutor(max_workers=2) as pool: f1 = pool.submit(send_message, "ch-A", "msg-A") f2 = pool.submit(send_message, "ch-B", "msg-B") results = [f1.result(timeout=10), f2.result(timeout=10)] elapsed = time.monotonic() - start assert all(r == "Hello from Claude!" for r in results) # Parallel execution: total time should be close to 0.4s, well under # 0.7s (would mean serialization). 0.65s ceiling allows for GIL + # scheduler jitter on a busy test box. assert elapsed < 0.65, ( f"Different channels appear serialized: elapsed {elapsed:.3f}s " f"(expected ~0.4s parallel, <0.65s ceiling)" ) assert concurrent_seen.is_set(), ( "Different channels did not overlap — mutex is too coarse " "(should be per-channel, not global)." ) def test_three_channels_all_overlap(self, temp_sessions): """Stress: three concurrent channels all run in parallel.""" in_critical = threading.Event() concurrent_seen = threading.Event() slow = _slow_run_claude(0.3, in_critical, concurrent_seen) with patch.object(claude_session, "_run_claude", side_effect=slow): start = time.monotonic() with ThreadPoolExecutor(max_workers=3) as pool: futures = [ pool.submit(send_message, f"ch-{i}", f"msg-{i}") for i in range(3) ] for f in as_completed(futures, timeout=10): assert f.result() == "Hello from Claude!" elapsed = time.monotonic() - start # 3 × 0.3s in parallel ≈ 0.3s; serial would be ~0.9s. assert elapsed < 0.6, ( f"Three channels serialized: {elapsed:.3f}s (expected <0.6s)" ) # --------------------------------------------------------------------------- # Scenario 3 — acquisition behavior documented and consistent # --------------------------------------------------------------------------- class TestAcquisitionBehavior: """Pin the chosen acquisition policy: blocking, no timeout. Project style is to bound subprocess execution via `timeout` (default 5 min) rather than fail-fast on lock acquire. Reasons: - Adapter callers (Discord/Telegram/voice) already serialize work via asyncio.to_thread; queue depth is naturally bounded. - A non-blocking acquire would surface a timing error to the user ("busy, try again") for an entirely transient and self-resolving condition. Blocking gives FIFO-ish ordering with simple semantics. - If a subprocess truly hangs past `timeout`, _run_claude raises TimeoutError → the held lock releases (via `with`) → queued callers proceed. This test pins that: a second caller waits and eventually proceeds; it does not raise an exception on contention. """ def test_contested_acquire_blocks_then_proceeds(self, temp_sessions): in_critical = threading.Event() concurrent_seen = threading.Event() slow = _slow_run_claude(0.3, in_critical, concurrent_seen) results: list[str | BaseException] = [] def run(label: str): try: results.append(send_message("ch-contend", label)) except BaseException as e: results.append(e) with patch.object(claude_session, "_run_claude", side_effect=slow): t1 = threading.Thread(target=run, args=("first",)) t1.start() # Wait until the first call is inside the critical section so # the second is GUARANTEED to contend on the lock. assert in_critical.wait(timeout=2.0), "first call never entered" in_critical.clear() t2 = threading.Thread(target=run, args=("second",)) t2.start() t1.join(timeout=5.0) t2.join(timeout=5.0) assert len(results) == 2 # Both must return the canned response — no exception, no error. assert all(r == "Hello from Claude!" for r in results), ( f"Contended acquire surfaced an error instead of blocking: {results}" ) # Critical-section overlap check: contended calls MUST serialize. assert not concurrent_seen.is_set(), ( "Contended same-channel calls ran concurrently — mutex broken." ) def test_lock_released_on_subprocess_exception(self, temp_sessions): """If `_run_claude` raises, the lock MUST be released so the next caller can proceed (otherwise a single error deadlocks the channel forever).""" call_count = {"n": 0} def flaky(cmd, timeout, on_text=None, cwd=None): call_count["n"] += 1 if call_count["n"] == 1: raise RuntimeError("simulated subprocess crash") return { "result": "Hello from Claude!", "session_id": "sess-abc-123", "usage": {"input_tokens": 10, "output_tokens": 5}, "total_cost_usd": 0.001, "cost_usd": 0.001, "duration_ms": 50, "num_turns": 1, "intermediate_count": 0, "subtype": "success", "is_error": False, } with patch.object(claude_session, "_run_claude", side_effect=flaky): with pytest.raises(RuntimeError, match="simulated subprocess crash"): send_message("ch-recover", "first") # Second call MUST acquire the lock (proves the first released it). # We use a short timeout via a thread so a deadlock would fail loudly. done = threading.Event() result_box: list[str] = [] def second(): result_box.append(send_message("ch-recover", "second")) done.set() t = threading.Thread(target=second) t.start() assert done.wait(timeout=3.0), ( "Second call deadlocked — lock was not released on exception." ) t.join(timeout=1.0) assert result_box == ["Hello from Claude!"]