Fix arhitectural general (beneficiu și pentru text adapters), nu doar voice.
src/claude_session.py:
- _session_locks: dict[str, threading.Lock] cu bootstrap lock pentru
lazy creation thread-safe.
- _get_session_lock(channel_id) helper.
- send_message() body wrapped în with _get_session_lock(channel_id).
- threading.Lock (NU asyncio.Lock) — send_message e sync subprocess.run
blocking; asyncio.Lock nu protejează cod sync rulat via to_thread.
- Per-channel granularity preserved — different channels run în paralel.
- send_message() public signature unchanged.
src/router.py:
- route_message(): dacă adapter_name == "discord-voice", prepend
[speaker:<user_name>] prefix (Config.get("voice.user_name", "user")).
- Original text variable left untouched for downstream paths.
- Text adapters: zero behavior change.
- route_message() public signature unchanged.
tests/test_claude_session_mutex.py — 6 tests REGRESSION-CRITICAL:
- same channel serializes (concurrent → mutex serializes, no overlap)
- same channel lock identity (same dict entry per channel_id)
- different channels run in parallel (overlap MUST fire)
- 3 channels all overlap
- contested acquire blocks then proceeds (policy: blocking, not fail-fast)
- lock released on subprocess exception (no deadlock on crash)
Acquisition policy: BLOCKING acquire (the wait is bounded by the lock holder's claude --timeout, 5 min default),
nu fail-fast — adapters already dispatch via asyncio.to_thread (bounded default executor pool), iar un
non-blocking acquire ar surface transient busy errors.
Test results: 82 passed (51 existing + 31 new). 2 PRE-EXISTING failures în
TestPromptInjectionProtection (stale assertion vs current prompt text) —
out of scope, recomand ticket separat.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
308 lines
12 KiB
Python
308 lines
12 KiB
Python
"""Regression-critical tests for per-channel mutex in src/claude_session.py.
|
||
|
||
Three scenarios from the eng-review test plan (2026-05-27):
|
||
|
||
1. Concurrent `send_message` calls on the SAME channel_id serialize —
|
||
the second waits for the first to finish before its subprocess runs.
|
||
2. Concurrent `send_message` calls on DIFFERENT channel_ids run in parallel
|
||
— independent channels never block each other.
|
||
3. Acquisition contract is documented and consistent: the lock is acquired
|
||
blocking (no acquire timeout), which means a hung subprocess on
|
||
channel X delays subsequent X messages but never X' (X != X'). This
|
||
test pins that behavior so future refactors must preserve it.
|
||
|
||
The mutex is `threading.Lock`, NOT `asyncio.Lock`, because `send_message`
|
||
is a sync function typically dispatched via `asyncio.to_thread` from
|
||
async adapters. asyncio.Lock would serialize coroutines only — not the
|
||
subprocess invocation. See plan section "Engineering decisions" #2.
|
||
"""
|
||
|
||
import json
|
||
import threading
|
||
import time
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
from unittest.mock import patch
|
||
|
||
import pytest
|
||
|
||
from src import claude_session
|
||
from src.claude_session import (
|
||
_get_session_lock,
|
||
_session_locks,
|
||
send_message,
|
||
)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Helpers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
@pytest.fixture(autouse=True)
def _clear_session_locks():
    """Give every test an empty per-channel lock registry.

    ``_session_locks`` is module-level state imported from
    ``src.claude_session``. It is wiped before the test body runs and again
    during teardown, so lock objects never leak from one test into another.
    """
    wipe = _session_locks.clear
    wipe()
    yield
    wipe()
|
||
|
||
|
||
@pytest.fixture
def temp_sessions(tmp_path, monkeypatch):
    """Redirect claude_session's session storage into *tmp_path*.

    Creates ``<tmp_path>/sessions/active.json`` containing ``{}`` and
    monkeypatches ``SESSIONS_DIR`` and ``_SESSIONS_FILE`` so that the real
    session state is never touched. Returns the path of that sessions file.
    """
    root = tmp_path / "sessions"
    root.mkdir()
    active = root / "active.json"
    active.write_text("{}")
    overrides = (("SESSIONS_DIR", root), ("_SESSIONS_FILE", active))
    for attr_name, replacement in overrides:
        monkeypatch.setattr(claude_session, attr_name, replacement)
    return active
|
||
|
||
|
||
def _slow_run_claude(sleep_seconds: float, in_critical: threading.Event,
                     concurrent_seen: threading.Event):
    """Return a stand-in for ``_run_claude`` that reports overlapping calls.

    Each call to the returned function does the following:

    1. Counts itself into a shared "currently inside" tally.
    2. Sets *in_critical*.
    3. Blocks for *sleep_seconds*, simulating the subprocess.
    4. Counts itself back out.
    5. Returns a canned successful Claude result.

    If a call arrives while another call is still inside, *concurrent_seen*
    is set. Same-channel tests assert that this never happens.
    """
    tally_guard = threading.Lock()
    in_flight = 0

    def fake(cmd, timeout, on_text=None, cwd=None):
        nonlocal in_flight
        with tally_guard:
            in_flight += 1
            if in_flight > 1:
                concurrent_seen.set()
        in_critical.set()
        time.sleep(sleep_seconds)
        with tally_guard:
            in_flight -= 1

        canned = {
            "result": "Hello from Claude!",
            "session_id": "sess-abc-123",
            "usage": {"input_tokens": 10, "output_tokens": 5},
            "total_cost_usd": 0.001,
            "cost_usd": 0.001,
            "duration_ms": int(sleep_seconds * 1000),
            "num_turns": 1,
            "intermediate_count": 0,
            "subtype": "success",
            "is_error": False,
        }
        return canned

    return fake
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Scenario 1 — same channel serializes
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestSameChannelSerializes:
    """Scenario 1: calls that share a channel_id share one mutex."""

    def test_two_concurrent_calls_same_channel_run_one_at_a_time(
        self, temp_sessions
    ):
        """Two parallel send_message on the SAME channel_id never overlap.

        We instrument `_run_claude` to signal whenever more than one
        invocation is concurrently inside it. The mutex MUST prevent that.
        """
        in_critical = threading.Event()
        concurrent_seen = threading.Event()
        slow = _slow_run_claude(0.25, in_critical, concurrent_seen)

        with patch.object(claude_session, "_run_claude", side_effect=slow):
            start = time.monotonic()
            with ThreadPoolExecutor(max_workers=2) as pool:
                futures = [
                    pool.submit(send_message, "ch-same", f"msg-{i}")
                    for i in range(2)
                ]
                results = [f.result(timeout=10) for f in futures]
            elapsed = time.monotonic() - start

        assert not concurrent_seen.is_set(), (
            "Two send_message calls on the same channel ran concurrently — "
            "mutex did not serialize them."
        )
        assert all(r == "Hello from Claude!" for r in results)
        # Two serial 0.25s subprocesses must take at least ~0.5s in total.
        # This is a lower bound: a slow scheduler only makes `elapsed`
        # longer. The 0.05s of slack covers timer granularity.
        assert elapsed >= 0.45, f"Expected serialized ~0.5s, got {elapsed:.3f}s"

    def test_lock_is_reentrant_per_channel_dict(self, temp_sessions):
        """`_get_session_lock` returns the SAME lock object for the same channel.

        Despite its historical name, this test checks per-channel lock
        *identity*, not reentrancy. The mutex is a plain ``threading.Lock``,
        which is not reentrant. The name is kept so that existing test IDs
        stay stable.
        """
        lock_a1 = _get_session_lock("channel-A")
        lock_a2 = _get_session_lock("channel-A")
        lock_b = _get_session_lock("channel-B")
        assert lock_a1 is lock_a2
        assert lock_a1 is not lock_b

    def test_concurrent_first_access_yields_single_lock(self):
        """Racing first lookups of a new channel must all get ONE lock.

        Lazy creation is supposed to be thread-safe (bootstrap lock). A
        barrier releases all threads at once so their first
        `_get_session_lock` calls race. If two distinct locks were handed
        out, the same channel could run two subprocesses concurrently.
        This is best-effort: a race that is absent under the GIL may not
        always manifest. However, a failure here is always a real bug.
        """
        n_threads = 16
        gate = threading.Barrier(n_threads, timeout=5.0)
        seen = []
        seen_guard = threading.Lock()

        def grab():
            gate.wait()
            lock = _get_session_lock("ch-race")
            with seen_guard:
                seen.append(lock)

        threads = [threading.Thread(target=grab, daemon=True)
                   for _ in range(n_threads)]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=5.0)

        assert len(seen) == n_threads, "some racing lookups never completed"
        assert all(lock is seen[0] for lock in seen), (
            "Concurrent first access created more than one lock for the "
            "same channel — lazy creation is not thread-safe."
        )
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Scenario 2 — different channels parallel
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestDifferentChannelsParallel:
    """Scenario 2: different channel_ids must never block each other.

    Overlap is proven with a ``threading.Barrier`` placed in front of the
    fake subprocess. The barrier only opens once every call is inside
    ``_run_claude`` at the same time, which a global (non-per-channel) mutex
    makes impossible. This is deterministic, unlike wall-clock ceilings,
    which flake on loaded CI machines. If the calls are serialized, the
    barrier times out and the waiting calls raise
    ``threading.BrokenBarrierError``. That error is turned into a clear test
    failure.
    """

    # Per-call barrier timeout. It is only reached when the mutex is too
    # coarse.
    _BARRIER_TIMEOUT = 5.0

    @classmethod
    def _overlap_required(cls, parties, in_critical, concurrent_seen):
        """Build a fake `_run_claude` that waits for `parties` concurrent calls.

        Each call first meets its peers at the barrier and then delegates to
        `_slow_run_claude`, which produces the canned result and the usual
        instrumentation.
        """
        barrier = threading.Barrier(parties, timeout=cls._BARRIER_TIMEOUT)
        slow = _slow_run_claude(0.05, in_critical, concurrent_seen)

        def fake(*args, **kwargs):
            barrier.wait()
            return slow(*args, **kwargs)

        return fake

    @staticmethod
    def _results_or_fail(futures):
        """Return every future's result; a broken barrier means no overlap."""
        try:
            return [f.result(timeout=10) for f in futures]
        except threading.BrokenBarrierError:
            pytest.fail(
                "Different channels did not overlap — mutex is too coarse "
                "(should be per-channel, not global)."
            )

    def test_two_concurrent_calls_different_channels_run_in_parallel(
        self, temp_sessions
    ):
        """Different channels MUST NOT block each other.

        Both calls must be inside `_run_claude` at the same moment, or the
        2-party barrier never opens.
        """
        in_critical = threading.Event()
        concurrent_seen = threading.Event()
        fake = self._overlap_required(2, in_critical, concurrent_seen)

        with patch.object(claude_session, "_run_claude", side_effect=fake):
            with ThreadPoolExecutor(max_workers=2) as pool:
                futures = [
                    pool.submit(send_message, "ch-A", "msg-A"),
                    pool.submit(send_message, "ch-B", "msg-B"),
                ]
                results = self._results_or_fail(futures)

        assert in_critical.is_set(), "patched _run_claude was never reached"
        assert all(r == "Hello from Claude!" for r in results)

    def test_three_channels_all_overlap(self, temp_sessions):
        """Stress: three concurrent channels are ALL inside `_run_claude` at once.

        A 3-party barrier only opens when every channel holds its own lock
        simultaneously. Pairwise overlap is not enough to pass.
        """
        in_critical = threading.Event()
        concurrent_seen = threading.Event()
        fake = self._overlap_required(3, in_critical, concurrent_seen)

        with patch.object(claude_session, "_run_claude", side_effect=fake):
            with ThreadPoolExecutor(max_workers=3) as pool:
                futures = [
                    pool.submit(send_message, f"ch-{i}", f"msg-{i}")
                    for i in range(3)
                ]
                results = self._results_or_fail(futures)

        assert in_critical.is_set(), "patched _run_claude was never reached"
        assert all(r == "Hello from Claude!" for r in results)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Scenario 3 — acquisition behavior documented and consistent
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestAcquisitionBehavior:
    """Pin the chosen acquisition policy: blocking, no timeout.

    Project style is to bound subprocess execution via `timeout` (default
    5 min) rather than fail-fast on lock acquire. Reasons:

    - Adapter callers (Discord/Telegram/voice) dispatch work via
      asyncio.to_thread. That does not serialize anything by itself.
      However, assuming the loop's default executor is used, its worker
      pool is bounded. So the number of threads that can queue on one
      channel's lock is bounded too.
    - A non-blocking acquire would surface a timing error to the user
      ("busy, try again") for an entirely transient and self-resolving
      condition. Blocking keeps the semantics simple. Note that
      threading.Lock does not guarantee FIFO order among waiters.
    - If a subprocess truly hangs past `timeout`, _run_claude is expected to
      raise (e.g. TimeoutError). The held lock is then released (via
      `with`), and queued callers proceed. The "an exception releases the
      lock" half of that chain is pinned by
      `test_lock_released_on_subprocess_exception`.

    The contention test pins the rest: a second caller blocks while the
    first holds the lock, then proceeds. It does not raise an exception on
    contention.
    """

    @staticmethod
    def _ok_response(duration_ms=50):
        """Canned successful `_run_claude` payload used by the fakes below."""
        return {
            "result": "Hello from Claude!",
            "session_id": "sess-abc-123",
            "usage": {"input_tokens": 10, "output_tokens": 5},
            "total_cost_usd": 0.001,
            "cost_usd": 0.001,
            "duration_ms": duration_ms,
            "num_turns": 1,
            "intermediate_count": 0,
            "subtype": "success",
            "is_error": False,
        }

    def test_contested_acquire_blocks_then_proceeds(self, temp_sessions):
        """A same-channel second caller blocks, then completes without error.

        The first call parks inside the fake subprocess until the test
        releases it. The second caller is therefore guaranteed to contend
        for the lock, with no reliance on sleep timing. While the first call
        is parked, the second must neither finish nor enter `_run_claude`.
        """
        first_entered = threading.Event()
        release_first = threading.Event()
        entries = {"n": 0}
        entries_guard = threading.Lock()

        def gated(*args, **kwargs):
            with entries_guard:
                entries["n"] += 1
                is_first = entries["n"] == 1
            if is_first:
                first_entered.set()
                # Hold the critical section until the test has checked the
                # second caller. The timeout keeps a failing run from hanging.
                release_first.wait(timeout=5.0)
            return self._ok_response()

        results: list[str | BaseException] = []

        def run(label: str):
            try:
                results.append(send_message("ch-contend", label))
            except BaseException as e:
                results.append(e)

        # Daemon threads: a deadlock regression must fail the test, not hang
        # interpreter shutdown.
        t1 = threading.Thread(target=run, args=("first",), daemon=True)
        t2 = threading.Thread(target=run, args=("second",), daemon=True)

        with patch.object(claude_session, "_run_claude", side_effect=gated):
            t1.start()
            try:
                assert first_entered.wait(timeout=2.0), "first call never entered"
                t2.start()
                # Give the second caller time to reach the lock. It must
                # still be waiting and must not have entered the critical
                # section.
                t2.join(timeout=0.2)
                assert t2.is_alive(), (
                    "Second same-channel call finished while the first still "
                    "held the lock — acquire did not block."
                )
                assert entries["n"] == 1, (
                    "Contended same-channel calls ran concurrently — mutex broken."
                )
            finally:
                release_first.set()
                t1.join(timeout=5.0)
                if t2.ident is not None:  # join() raises on a never-started thread
                    t2.join(timeout=5.0)

        assert not t1.is_alive() and not t2.is_alive(), (
            "A contended caller never finished — possible deadlock."
        )
        assert len(results) == 2
        # Both must return the canned response — no exception, no error.
        assert all(r == "Hello from Claude!" for r in results), (
            f"Contended acquire surfaced an error instead of blocking: {results}"
        )
        assert entries["n"] == 2

    def test_lock_released_on_subprocess_exception(self, temp_sessions):
        """If `_run_claude` raises, the lock MUST be released so the next
        caller can proceed (otherwise a single error deadlocks the channel
        forever)."""
        call_count = {"n": 0}

        def flaky(cmd, timeout, on_text=None, cwd=None):
            call_count["n"] += 1
            if call_count["n"] == 1:
                raise RuntimeError("simulated subprocess crash")
            return self._ok_response()

        with patch.object(claude_session, "_run_claude", side_effect=flaky):
            with pytest.raises(RuntimeError, match="simulated subprocess crash"):
                send_message("ch-recover", "first")

            # The second call MUST acquire the lock, which proves the first
            # call released it. It runs in a daemon thread with a bounded wait
            # so that a deadlock fails loudly instead of hanging the run.
            done = threading.Event()
            outcome: list[str | BaseException] = []

            def second():
                try:
                    outcome.append(send_message("ch-recover", "second"))
                except BaseException as e:
                    # Record the error so it is reported as-is, not
                    # misdiagnosed as a deadlock.
                    outcome.append(e)
                finally:
                    done.set()

            t = threading.Thread(target=second, daemon=True)
            t.start()
            assert done.wait(timeout=3.0), (
                "Second call deadlocked — lock was not released on exception."
            )
            t.join(timeout=1.0)

        assert outcome == ["Hello from Claude!"], (
            f"Second call after a crash did not succeed: {outcome!r}"
        )
|