Files
echo-core/tests/test_claude_session_mutex.py
Marius Mutu 3af6bcaea4 feat(voice): Pas 8 — threading.Lock per channel_id mutex + voice augment
Fix arhitectural general (beneficiu și pentru text adapters), nu doar voice.

src/claude_session.py:
- _session_locks: dict[str, threading.Lock] cu bootstrap lock pentru
  lazy creation thread-safe.
- _get_session_lock(channel_id) helper.
- send_message() body wrapped în with _get_session_lock(channel_id).
- threading.Lock (NU asyncio.Lock) — send_message e sync subprocess.run
  blocking; asyncio.Lock nu protejează cod sync rulat via to_thread.
- Per-channel granularity preserved — different channels run în paralel.
- send_message() public signature unchanged.

src/router.py:
- route_message(): dacă adapter_name == "discord-voice", prepend
  [speaker:<user_name>] prefix (Config.get("voice.user_name", "user")).
- Original text variable left untouched for downstream paths.
- Text adapters: zero behavior change.
- route_message() public signature unchanged.

tests/test_claude_session_mutex.py — 6 tests REGRESSION-CRITICAL:
- same channel serializes (concurrent → mutex serializes, no overlap)
- same channel lock identity (same dict entry per channel_id)
- different channels run in parallel (overlap MUST fire)
- 3 channels all overlap
- contested acquire blocks then proceeds (policy: blocking, not fail-fast)
- lock released on subprocess exception (no deadlock on crash)

Acquisition policy: BLOCKING acquire bound by claude --timeout (5min default)
nu fail-fast — adapters already serialize via asyncio.to_thread queue, un
non-blocking acquire ar surface transient busy errors.

Test results: 82 passed (51 existing + 31 new). 2 PRE-EXISTING failures în
TestPromptInjectionProtection (stale assertion vs current prompt text) —
out of scope, recomand ticket separat.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-27 14:43:05 +00:00

308 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Regression-critical tests for per-channel mutex in src/claude_session.py.
Three scenarios from the eng-review test plan (2026-05-27):
1. Concurrent `send_message` calls on the SAME channel_id serialize —
the second waits for the first to finish before its subprocess runs.
2. Concurrent `send_message` calls on DIFFERENT channel_ids run in parallel
— independent channels never block each other.
3. Acquisition contract is documented and consistent: the lock is acquired
blocking (no acquire timeout), which means a hung subprocess on
channel X delays subsequent X messages but never X' (X != X'). This
test pins that behavior so future refactors must preserve it.
The mutex is `threading.Lock`, NOT `asyncio.Lock`, because `send_message`
is a sync function typically dispatched via `asyncio.to_thread` from
async adapters. asyncio.Lock would serialize coroutines only — not the
subprocess invocation. See plan section "Engineering decisions" #2.
"""
import json
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from unittest.mock import patch
import pytest
from src import claude_session
from src.claude_session import (
_get_session_lock,
_session_locks,
send_message,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _clear_session_locks():
    """Empty the shared `_session_locks` registry around every test.

    Runs automatically for each test in this module: the registry is
    cleared before the test body executes and once more after it finishes,
    so lock objects created by one test are never visible to another.
    """
    # Setup: discard whatever an earlier test may have registered.
    _session_locks.clear()
    yield
    # Teardown: hand an empty registry to the next test.
    _session_locks.clear()
@pytest.fixture
def temp_sessions(tmp_path, monkeypatch):
    """Redirect `claude_session` to a throwaway session store for one test.

    Creates `<tmp_path>/sessions/active.json` holding an empty JSON object
    and patches `claude_session.SESSIONS_DIR` and
    `claude_session._SESSIONS_FILE` to point at it, so the test never reads
    or writes the real session state. `monkeypatch` undoes both patches
    when the test ends.

    Returns:
        Path of the temporary `active.json` file.
    """
    store_dir = tmp_path / "sessions"
    store_dir.mkdir()

    active_file = store_dir / "active.json"
    active_file.write_text("{}")

    # Patch the directory first, then the file, on the module under test.
    overrides = (
        ("SESSIONS_DIR", store_dir),
        ("_SESSIONS_FILE", active_file),
    )
    for attr_name, replacement in overrides:
        monkeypatch.setattr(claude_session, attr_name, replacement)

    return active_file
def _slow_run_claude(sleep_seconds: float, in_critical: threading.Event,
                     concurrent_seen: threading.Event):
    """Build a fake `_run_claude` that reports overlapping invocations.

    The returned callable simulates a subprocess by sleeping for
    `sleep_seconds` and then returning a canned success payload. While it
    runs it maintains a counter of in-flight invocations (shared by all
    calls to the same returned callable).

    Args:
        sleep_seconds: How long each fake invocation holds the simulated
            subprocess.
        in_critical: Set every time an invocation has entered the fake, so
            a test can wait until a call is known to be in progress.
        concurrent_seen: Set as soon as an invocation starts while another
            one is still in flight. Same-channel tests assert this is
            NEVER set; different-channel tests assert that it IS set.

    Returns:
        A callable with the signature `(cmd, timeout, on_text=None,
        cwd=None)` that returns the canned result dict.
    """
    # In-flight counter plus the lock guarding it; private to this fake.
    state = {"active": 0, "lock": threading.Lock()}

    def fake(cmd, timeout, on_text=None, cwd=None):
        with state["lock"]:
            state["active"] += 1
            if state["active"] > 1:
                # Someone else is already inside: overlap detected.
                concurrent_seen.set()
        in_critical.set()
        try:
            time.sleep(sleep_seconds)
        finally:
            # Always undo the increment, even if the sleep raises, so a
            # failed invocation cannot leave the counter inflated and make
            # a later, non-overlapping call look concurrent.
            with state["lock"]:
                state["active"] -= 1
        return {
            "result": "Hello from Claude!",
            "session_id": "sess-abc-123",
            "usage": {"input_tokens": 10, "output_tokens": 5},
            "total_cost_usd": 0.001,
            "cost_usd": 0.001,
            "duration_ms": int(sleep_seconds * 1000),
            "num_turns": 1,
            "intermediate_count": 0,
            "subtype": "success",
            "is_error": False,
        }

    return fake
# ---------------------------------------------------------------------------
# Scenario 1 — same channel serializes
# ---------------------------------------------------------------------------
class TestSameChannelSerializes:
    """Scenario 1: calls that share a channel_id must run one at a time."""

    def test_two_concurrent_calls_same_channel_run_one_at_a_time(
        self, temp_sessions
    ):
        """Submit two send_message calls for one channel; require no overlap.

        `_run_claude` is replaced by an instrumented fake that raises a
        flag whenever a second invocation starts while the first is still
        in flight. With a working per-channel mutex that flag stays clear
        and the two 0.25s fakes take roughly 0.5s in total.
        """
        entered = threading.Event()
        overlapped = threading.Event()
        fake_run = _slow_run_claude(0.25, entered, overlapped)

        with patch.object(claude_session, "_run_claude", side_effect=fake_run):
            started_at = time.monotonic()
            with ThreadPoolExecutor(max_workers=2) as pool:
                pending = []
                for i in range(2):
                    pending.append(
                        pool.submit(send_message, "ch-same", f"msg-{i}")
                    )
                results = [fut.result(timeout=10) for fut in pending]
            elapsed = time.monotonic() - started_at

        assert not overlapped.is_set(), (
            "Two send_message calls on the same channel ran concurrently — "
            "mutex did not serialize them."
        )
        assert results == ["Hello from Claude!", "Hello from Claude!"]
        # Back-to-back 0.25s runs add up to about 0.5s; the floor sits a
        # little below that to tolerate timer granularity.
        assert elapsed >= 0.45, f"Expected serialized ~0.5s, got {elapsed:.3f}s"

    def test_lock_is_reentrant_per_channel_dict(self, temp_sessions):
        """Same channel_id yields the same lock object; another id does not.

        Despite "reentrant" in the name, this only checks object identity
        of what `_get_session_lock` returns — it never re-acquires a lock.
        """
        first = _get_session_lock("channel-A")
        again = _get_session_lock("channel-A")
        other = _get_session_lock("channel-B")

        assert again is first
        assert other is not first
# ---------------------------------------------------------------------------
# Scenario 2 — different channels parallel
# ---------------------------------------------------------------------------
class TestDifferentChannelsParallel:
    """Scenario 2: calls on distinct channel_ids must not block each other."""

    def test_two_concurrent_calls_different_channels_run_in_parallel(
        self, temp_sessions
    ):
        """Different channels MUST NOT block each other.

        We measure elapsed wall-clock: two 0.4s subprocesses on different
        channels should finish in ~0.4s (parallel), NOT ~0.8s (serialized).
        The instrumented fake must also report that the two invocations
        were in flight at the same time.
        """
        in_critical = threading.Event()
        # `concurrent_seen` is OK to fire here — we WANT them to overlap.
        concurrent_seen = threading.Event()
        slow = _slow_run_claude(0.4, in_critical, concurrent_seen)

        with patch.object(claude_session, "_run_claude", side_effect=slow):
            start = time.monotonic()
            with ThreadPoolExecutor(max_workers=2) as pool:
                f1 = pool.submit(send_message, "ch-A", "msg-A")
                f2 = pool.submit(send_message, "ch-B", "msg-B")
                results = [f1.result(timeout=10), f2.result(timeout=10)]
            elapsed = time.monotonic() - start

        assert all(r == "Hello from Claude!" for r in results)
        # Parallel execution: total time should be close to 0.4s, well
        # under the ~0.8s that serialization would cost. The 0.65s ceiling
        # allows for GIL + scheduler jitter on a busy test box.
        assert elapsed < 0.65, (
            f"Different channels appear serialized: elapsed {elapsed:.3f}s "
            f"(expected ~0.4s parallel, <0.65s ceiling)"
        )
        assert concurrent_seen.is_set(), (
            "Different channels did not overlap — mutex is too coarse "
            "(should be per-channel, not global)."
        )

    def test_three_channels_all_overlap(self, temp_sessions):
        """Stress: three concurrent channels all run in parallel.

        Checks both the wall-clock bound and the overlap flag raised by
        the instrumented fake, mirroring the two-channel test.
        """
        in_critical = threading.Event()
        concurrent_seen = threading.Event()
        slow = _slow_run_claude(0.3, in_critical, concurrent_seen)

        with patch.object(claude_session, "_run_claude", side_effect=slow):
            start = time.monotonic()
            with ThreadPoolExecutor(max_workers=3) as pool:
                futures = [
                    pool.submit(send_message, f"ch-{i}", f"msg-{i}")
                    for i in range(3)
                ]
                for f in as_completed(futures, timeout=10):
                    assert f.result() == "Hello from Claude!"
            elapsed = time.monotonic() - start

        # 3 × 0.3s in parallel ≈ 0.3s; serial would be ~0.9s.
        assert elapsed < 0.6, (
            f"Three channels serialized: {elapsed:.3f}s (expected <0.6s)"
        )
        # The test name promises overlap, so assert it explicitly instead
        # of relying on timing alone.
        assert concurrent_seen.is_set(), (
            "Three distinct channels never overlapped — mutex is too coarse "
            "(should be per-channel, not global)."
        )
# ---------------------------------------------------------------------------
# Scenario 3 — acquisition behavior documented and consistent
# ---------------------------------------------------------------------------
class TestAcquisitionBehavior:
    """Pin the chosen acquisition policy: blocking, no timeout.

    Project style is to bound subprocess execution via `timeout` (default
    5 min) rather than fail-fast on lock acquire. Reasons:

    - Adapter callers (Discord/Telegram/voice) already serialize work via
      asyncio.to_thread; queue depth is naturally bounded.
    - A non-blocking acquire would surface a timing error to the user
      ("busy, try again") for an entirely transient and self-resolving
      condition. Blocking gives FIFO-ish ordering with simple semantics.
    - If a subprocess truly hangs past `timeout`, _run_claude raises
      TimeoutError → the held lock releases (via `with`) → queued
      callers proceed.

    These tests pin that: a second caller waits and eventually proceeds; it
    does not raise an exception on contention, and an exception in one call
    does not leave the channel locked.

    Worker threads are created as daemon threads so that, if the mutex ever
    does deadlock, the failing assertion is reported and the interpreter
    can still exit instead of hanging on a stuck thread.
    """

    def test_contested_acquire_blocks_then_proceeds(self, temp_sessions):
        """A second call on a busy channel waits, then returns normally."""
        in_critical = threading.Event()
        concurrent_seen = threading.Event()
        slow = _slow_run_claude(0.3, in_critical, concurrent_seen)
        results: list[str | BaseException] = []

        def run(label: str):
            # Record either the return value or whatever was raised so the
            # main thread can inspect both outcomes.
            try:
                results.append(send_message("ch-contend", label))
            except BaseException as e:
                results.append(e)

        with patch.object(claude_session, "_run_claude", side_effect=slow):
            t1 = threading.Thread(target=run, args=("first",), daemon=True)
            t1.start()
            # Wait until the first call is inside the critical section so
            # the second is GUARANTEED to contend on the lock.
            assert in_critical.wait(timeout=2.0), "first call never entered"
            in_critical.clear()
            t2 = threading.Thread(target=run, args=("second",), daemon=True)
            t2.start()
            t1.join(timeout=5.0)
            t2.join(timeout=5.0)
            # join() returns silently on timeout; check explicitly so a
            # stuck caller is reported as such.
            assert not t1.is_alive() and not t2.is_alive(), (
                "A contended call never finished — possible deadlock."
            )

        assert len(results) == 2
        # Both must return the canned response — no exception, no error.
        assert all(r == "Hello from Claude!" for r in results), (
            f"Contended acquire surfaced an error instead of blocking: {results}"
        )
        # Critical-section overlap check: contended calls MUST serialize.
        assert not concurrent_seen.is_set(), (
            "Contended same-channel calls ran concurrently — mutex broken."
        )

    def test_lock_released_on_subprocess_exception(self, temp_sessions):
        """If `_run_claude` raises, the lock MUST be released so the next
        caller can proceed (otherwise a single error deadlocks the channel
        forever)."""
        call_count = {"n": 0}

        def flaky(cmd, timeout, on_text=None, cwd=None):
            # First invocation crashes; every later one succeeds.
            call_count["n"] += 1
            if call_count["n"] == 1:
                raise RuntimeError("simulated subprocess crash")
            return {
                "result": "Hello from Claude!",
                "session_id": "sess-abc-123",
                "usage": {"input_tokens": 10, "output_tokens": 5},
                "total_cost_usd": 0.001,
                "cost_usd": 0.001,
                "duration_ms": 50,
                "num_turns": 1,
                "intermediate_count": 0,
                "subtype": "success",
                "is_error": False,
            }

        with patch.object(claude_session, "_run_claude", side_effect=flaky):
            with pytest.raises(RuntimeError, match="simulated subprocess crash"):
                send_message("ch-recover", "first")

            # Second call MUST acquire the lock (proves the first released
            # it). It runs in a daemon thread with a bounded wait so a
            # deadlock fails loudly instead of hanging the test run.
            done = threading.Event()
            result_box: list[str | BaseException] = []

            def second():
                try:
                    result_box.append(send_message("ch-recover", "second"))
                except BaseException as exc:
                    # Keep the error visible to the final assertion rather
                    # than letting it masquerade as a deadlock.
                    result_box.append(exc)
                finally:
                    done.set()

            t = threading.Thread(target=second, daemon=True)
            t.start()
            assert done.wait(timeout=3.0), (
                "Second call deadlocked — lock was not released on exception."
            )
            t.join(timeout=1.0)

        assert result_box == ["Hello from Claude!"]