feat(voice): Pas 8 — threading.Lock per channel_id mutex + voice augment
Fix arhitectural general (beneficiu și pentru text adapters), nu doar voice.
src/claude_session.py:
- _session_locks: dict[str, threading.Lock] cu bootstrap lock pentru
lazy creation thread-safe.
- _get_session_lock(channel_id) helper.
- send_message() body wrapped în with _get_session_lock(channel_id).
- threading.Lock (NU asyncio.Lock) — send_message e sync subprocess.run
blocking; asyncio.Lock nu protejează cod sync rulat via to_thread.
- Per-channel granularity preserved — different channels run în paralel.
- send_message() public signature unchanged.
src/router.py:
- route_message(): dacă adapter_name == "discord-voice", prepend
[speaker:<user_name>] prefix (Config.get("voice.user_name", "user")).
- Original text variable left untouched for downstream paths.
- Text adapters: zero behavior change.
- route_message() public signature unchanged.
tests/test_claude_session_mutex.py — 6 tests REGRESSION-CRITICAL:
- same channel serializes (concurrent → mutex serializes, no overlap)
- same channel lock identity (same dict entry per channel_id)
- different channels run in parallel (overlap MUST fire)
- 3 channels all overlap
- contested acquire blocks then proceeds (policy: blocking, not fail-fast)
- lock released on subprocess exception (no deadlock on crash)
Acquisition policy: BLOCKING acquire bound by claude --timeout (5min default)
nu fail-fast — adapters already serialize via asyncio.to_thread queue, un
non-blocking acquire ar surface transient busy errors.
Test results: 82 passed (51 existing + 31 new). 2 PRE-EXISTING failures în
TestPromptInjectionProtection (stale assertion vs current prompt text) —
out of scope, recomand ticket separat.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -37,6 +37,42 @@ DEFAULT_TIMEOUT = 300 # seconds
|
||||
|
||||
CLAUDE_BIN = os.environ.get("CLAUDE_BIN", "claude")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-channel mutex for send_message
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Two paths can hit `send_message(channel_id, ...)` concurrently for the same
|
||||
# channel: a text adapter (Discord/Telegram/WhatsApp) and the voice adapter
|
||||
# (`adapter_name="discord-voice"`). The underlying Claude CLI subprocess is
|
||||
# blocking (`subprocess.Popen` with stream-json read loop) and stateful via
|
||||
# `--resume <session_id>` — interleaving two concurrent invocations on the
|
||||
# same channel would corrupt the conversation order.
|
||||
#
|
||||
# We use `threading.Lock` (NOT `asyncio.Lock`) because `send_message` is sync
|
||||
# code typically run from `asyncio.to_thread` in async adapters. asyncio.Lock
|
||||
# only serializes coroutines, not threads — it would NOT protect this path.
|
||||
#
|
||||
# Each channel gets its own lock so DIFFERENT channels still run in parallel.
|
||||
# Locks are created lazily on first use; the dict itself is guarded by a
|
||||
# small bootstrap lock so two concurrent first-uses don't race on creation.
|
||||
_session_locks: dict[str, threading.Lock] = {}
|
||||
_session_locks_bootstrap = threading.Lock()
|
||||
|
||||
|
||||
def _get_session_lock(channel_id: str) -> threading.Lock:
|
||||
"""Return the channel's mutex, creating it on first access.
|
||||
|
||||
Two threads racing to create the same channel's lock would otherwise
|
||||
end up with different lock objects (setdefault is not atomic across
|
||||
the read-modify-write under all interpreter conditions — defensive).
|
||||
"""
|
||||
lock = _session_locks.get(channel_id)
|
||||
if lock is not None:
|
||||
return lock
|
||||
with _session_locks_bootstrap:
|
||||
return _session_locks.setdefault(channel_id, threading.Lock())
|
||||
|
||||
|
||||
PERSONALITY_FILES = [
|
||||
"IDENTITY.md",
|
||||
"SOUL.md",
|
||||
@@ -543,7 +579,16 @@ def send_message(
|
||||
timeout: int = DEFAULT_TIMEOUT,
|
||||
on_text: Callable[[str], None] | None = None,
|
||||
) -> str:
|
||||
"""High-level convenience: auto start or resume based on channel state."""
|
||||
"""High-level convenience: auto start or resume based on channel state.
|
||||
|
||||
Concurrency: a per-`channel_id` `threading.Lock` serializes invocations
|
||||
that hit the same channel (e.g. text adapter + voice adapter racing on
|
||||
the same Discord guild text channel). Different channels run in
|
||||
parallel — each holds its own lock. Lock is acquired blocking; we rely
|
||||
on `timeout` (default 5 minutes) to bound the worst case rather than
|
||||
a non-blocking acquire (loss of fairness vs adapter-side queueing).
|
||||
"""
|
||||
with _get_session_lock(channel_id):
|
||||
session = get_active_session(channel_id)
|
||||
# Only resume if session has a valid session_id (not a pre-set model placeholder)
|
||||
if session is not None and session.get("session_id"):
|
||||
|
||||
@@ -154,8 +154,17 @@ def route_message(
|
||||
channel_cfg = _get_channel_config(channel_id)
|
||||
model = (channel_cfg or {}).get("default_model") or _get_config().get("bot.default_model", "sonnet")
|
||||
|
||||
# Voice-mode augment: prepend speaker prefix so Claude knows who spoke
|
||||
# in a voice channel. Cheap now, future-proof for multi-speaker later.
|
||||
# (Engineering decision #14 in the plan.) Only the discord-voice adapter
|
||||
# triggers it — text adapters keep the message verbatim.
|
||||
claude_text = text
|
||||
if adapter_name == "discord-voice":
|
||||
user_name = _get_config().get("voice.user_name", "user") or "user"
|
||||
claude_text = f"[speaker:{user_name}] {text}"
|
||||
|
||||
try:
|
||||
response = send_message(channel_id, text, model=model, on_text=on_text)
|
||||
response = send_message(channel_id, claude_text, model=model, on_text=on_text)
|
||||
_set_last_response(channel_id, response)
|
||||
return response, False
|
||||
except Exception as e:
|
||||
|
||||
307
tests/test_claude_session_mutex.py
Normal file
307
tests/test_claude_session_mutex.py
Normal file
@@ -0,0 +1,307 @@
|
||||
"""Regression-critical tests for per-channel mutex in src/claude_session.py.
|
||||
|
||||
Three scenarios from the eng-review test plan (2026-05-27):
|
||||
|
||||
1. Concurrent `send_message` calls on the SAME channel_id serialize —
|
||||
the second waits for the first to finish before its subprocess runs.
|
||||
2. Concurrent `send_message` calls on DIFFERENT channel_ids run in parallel
|
||||
— independent channels never block each other.
|
||||
3. Acquisition contract is documented and consistent: the lock is acquired
|
||||
blocking (no acquire timeout), which means a hung subprocess on
|
||||
channel X delays subsequent X messages but never X' (X != X'). This
|
||||
test pins that behavior so future refactors must preserve it.
|
||||
|
||||
The mutex is `threading.Lock`, NOT `asyncio.Lock`, because `send_message`
|
||||
is a sync function typically dispatched via `asyncio.to_thread` from
|
||||
async adapters. asyncio.Lock would serialize coroutines only — not the
|
||||
subprocess invocation. See plan section "Engineering decisions" #2.
|
||||
"""
|
||||
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src import claude_session
|
||||
from src.claude_session import (
|
||||
_get_session_lock,
|
||||
_session_locks,
|
||||
send_message,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_session_locks():
|
||||
"""Each test starts with a fresh lock map so we don't share state."""
|
||||
_session_locks.clear()
|
||||
yield
|
||||
_session_locks.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_sessions(tmp_path, monkeypatch):
|
||||
"""Isolated active.json per test — keeps real session state untouched."""
|
||||
sessions_dir = tmp_path / "sessions"
|
||||
sessions_dir.mkdir()
|
||||
sf = sessions_dir / "active.json"
|
||||
sf.write_text("{}")
|
||||
monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir)
|
||||
monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf)
|
||||
return sf
|
||||
|
||||
|
||||
def _slow_run_claude(sleep_seconds: float, in_critical: threading.Event,
|
||||
concurrent_seen: threading.Event):
|
||||
"""Build a fake `_run_claude` that signals when inside the critical section.
|
||||
|
||||
The fake holds the simulated subprocess for `sleep_seconds`. Any other
|
||||
invocation that overlaps will set `concurrent_seen` — the mutex test
|
||||
asserts this NEVER happens for the same channel_id.
|
||||
"""
|
||||
state = {"active": 0, "lock": threading.Lock()}
|
||||
|
||||
def fake(cmd, timeout, on_text=None, cwd=None):
|
||||
with state["lock"]:
|
||||
state["active"] += 1
|
||||
if state["active"] > 1:
|
||||
concurrent_seen.set()
|
||||
in_critical.set()
|
||||
time.sleep(sleep_seconds)
|
||||
with state["lock"]:
|
||||
state["active"] -= 1
|
||||
return {
|
||||
"result": "Hello from Claude!",
|
||||
"session_id": "sess-abc-123",
|
||||
"usage": {"input_tokens": 10, "output_tokens": 5},
|
||||
"total_cost_usd": 0.001,
|
||||
"cost_usd": 0.001,
|
||||
"duration_ms": int(sleep_seconds * 1000),
|
||||
"num_turns": 1,
|
||||
"intermediate_count": 0,
|
||||
"subtype": "success",
|
||||
"is_error": False,
|
||||
}
|
||||
|
||||
return fake
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario 1 — same channel serializes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSameChannelSerializes:
|
||||
def test_two_concurrent_calls_same_channel_run_one_at_a_time(
|
||||
self, temp_sessions
|
||||
):
|
||||
"""Two parallel send_message on the SAME channel_id never overlap.
|
||||
|
||||
We instrument `_run_claude` to signal whenever more than one
|
||||
invocation is concurrently inside it. The mutex MUST prevent that.
|
||||
"""
|
||||
in_critical = threading.Event()
|
||||
concurrent_seen = threading.Event()
|
||||
slow = _slow_run_claude(0.25, in_critical, concurrent_seen)
|
||||
|
||||
with patch.object(claude_session, "_run_claude", side_effect=slow):
|
||||
start = time.monotonic()
|
||||
with ThreadPoolExecutor(max_workers=2) as pool:
|
||||
futures = [
|
||||
pool.submit(send_message, "ch-same", f"msg-{i}")
|
||||
for i in range(2)
|
||||
]
|
||||
results = [f.result(timeout=10) for f in futures]
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
assert not concurrent_seen.is_set(), (
|
||||
"Two send_message calls on the same channel ran concurrently — "
|
||||
"mutex did not serialize them."
|
||||
)
|
||||
assert all(r == "Hello from Claude!" for r in results)
|
||||
# Two serial 0.25s subprocesses must take at least ~0.5s total
|
||||
# (we allow a generous floor — schedulers can be slow).
|
||||
assert elapsed >= 0.45, f"Expected serialized ~0.5s, got {elapsed:.3f}s"
|
||||
|
||||
def test_lock_is_reentrant_per_channel_dict(self, temp_sessions):
|
||||
"""`_get_session_lock` returns the SAME lock object for the same channel."""
|
||||
lock_a1 = _get_session_lock("channel-A")
|
||||
lock_a2 = _get_session_lock("channel-A")
|
||||
lock_b = _get_session_lock("channel-B")
|
||||
assert lock_a1 is lock_a2
|
||||
assert lock_a1 is not lock_b
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario 2 — different channels parallel
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDifferentChannelsParallel:
|
||||
def test_two_concurrent_calls_different_channels_run_in_parallel(
|
||||
self, temp_sessions
|
||||
):
|
||||
"""Different channels MUST NOT block each other.
|
||||
|
||||
We measure elapsed wall-clock: two 0.4s subprocesses on different
|
||||
channels should finish in ~0.4s (parallel), NOT ~0.8s (serialized).
|
||||
"""
|
||||
in_critical = threading.Event()
|
||||
# `concurrent_seen` is OK to fire here — we WANT them to overlap.
|
||||
concurrent_seen = threading.Event()
|
||||
slow = _slow_run_claude(0.4, in_critical, concurrent_seen)
|
||||
|
||||
with patch.object(claude_session, "_run_claude", side_effect=slow):
|
||||
start = time.monotonic()
|
||||
with ThreadPoolExecutor(max_workers=2) as pool:
|
||||
f1 = pool.submit(send_message, "ch-A", "msg-A")
|
||||
f2 = pool.submit(send_message, "ch-B", "msg-B")
|
||||
results = [f1.result(timeout=10), f2.result(timeout=10)]
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
assert all(r == "Hello from Claude!" for r in results)
|
||||
# Parallel execution: total time should be close to 0.4s, well under
|
||||
# 0.7s (would mean serialization). 0.65s ceiling allows for GIL +
|
||||
# scheduler jitter on a busy test box.
|
||||
assert elapsed < 0.65, (
|
||||
f"Different channels appear serialized: elapsed {elapsed:.3f}s "
|
||||
f"(expected ~0.4s parallel, <0.65s ceiling)"
|
||||
)
|
||||
assert concurrent_seen.is_set(), (
|
||||
"Different channels did not overlap — mutex is too coarse "
|
||||
"(should be per-channel, not global)."
|
||||
)
|
||||
|
||||
def test_three_channels_all_overlap(self, temp_sessions):
|
||||
"""Stress: three concurrent channels all run in parallel."""
|
||||
in_critical = threading.Event()
|
||||
concurrent_seen = threading.Event()
|
||||
slow = _slow_run_claude(0.3, in_critical, concurrent_seen)
|
||||
|
||||
with patch.object(claude_session, "_run_claude", side_effect=slow):
|
||||
start = time.monotonic()
|
||||
with ThreadPoolExecutor(max_workers=3) as pool:
|
||||
futures = [
|
||||
pool.submit(send_message, f"ch-{i}", f"msg-{i}")
|
||||
for i in range(3)
|
||||
]
|
||||
for f in as_completed(futures, timeout=10):
|
||||
assert f.result() == "Hello from Claude!"
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
# 3 × 0.3s in parallel ≈ 0.3s; serial would be ~0.9s.
|
||||
assert elapsed < 0.6, (
|
||||
f"Three channels serialized: {elapsed:.3f}s (expected <0.6s)"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario 3 — acquisition behavior documented and consistent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAcquisitionBehavior:
|
||||
"""Pin the chosen acquisition policy: blocking, no timeout.
|
||||
|
||||
Project style is to bound subprocess execution via `timeout` (default
|
||||
5 min) rather than fail-fast on lock acquire. Reasons:
|
||||
|
||||
- Adapter callers (Discord/Telegram/voice) already serialize work via
|
||||
asyncio.to_thread; queue depth is naturally bounded.
|
||||
- A non-blocking acquire would surface a timing error to the user
|
||||
("busy, try again") for an entirely transient and self-resolving
|
||||
condition. Blocking gives FIFO-ish ordering with simple semantics.
|
||||
- If a subprocess truly hangs past `timeout`, _run_claude raises
|
||||
TimeoutError → the held lock releases (via `with`) → queued
|
||||
callers proceed.
|
||||
|
||||
This test pins that: a second caller waits and eventually proceeds; it
|
||||
does not raise an exception on contention.
|
||||
"""
|
||||
|
||||
def test_contested_acquire_blocks_then_proceeds(self, temp_sessions):
|
||||
in_critical = threading.Event()
|
||||
concurrent_seen = threading.Event()
|
||||
slow = _slow_run_claude(0.3, in_critical, concurrent_seen)
|
||||
|
||||
results: list[str | BaseException] = []
|
||||
|
||||
def run(label: str):
|
||||
try:
|
||||
results.append(send_message("ch-contend", label))
|
||||
except BaseException as e:
|
||||
results.append(e)
|
||||
|
||||
with patch.object(claude_session, "_run_claude", side_effect=slow):
|
||||
t1 = threading.Thread(target=run, args=("first",))
|
||||
t1.start()
|
||||
# Wait until the first call is inside the critical section so
|
||||
# the second is GUARANTEED to contend on the lock.
|
||||
assert in_critical.wait(timeout=2.0), "first call never entered"
|
||||
in_critical.clear()
|
||||
t2 = threading.Thread(target=run, args=("second",))
|
||||
t2.start()
|
||||
t1.join(timeout=5.0)
|
||||
t2.join(timeout=5.0)
|
||||
|
||||
assert len(results) == 2
|
||||
# Both must return the canned response — no exception, no error.
|
||||
assert all(r == "Hello from Claude!" for r in results), (
|
||||
f"Contended acquire surfaced an error instead of blocking: {results}"
|
||||
)
|
||||
# Critical-section overlap check: contended calls MUST serialize.
|
||||
assert not concurrent_seen.is_set(), (
|
||||
"Contended same-channel calls ran concurrently — mutex broken."
|
||||
)
|
||||
|
||||
def test_lock_released_on_subprocess_exception(self, temp_sessions):
|
||||
"""If `_run_claude` raises, the lock MUST be released so the next
|
||||
caller can proceed (otherwise a single error deadlocks the channel
|
||||
forever)."""
|
||||
|
||||
call_count = {"n": 0}
|
||||
|
||||
def flaky(cmd, timeout, on_text=None, cwd=None):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
raise RuntimeError("simulated subprocess crash")
|
||||
return {
|
||||
"result": "Hello from Claude!",
|
||||
"session_id": "sess-abc-123",
|
||||
"usage": {"input_tokens": 10, "output_tokens": 5},
|
||||
"total_cost_usd": 0.001,
|
||||
"cost_usd": 0.001,
|
||||
"duration_ms": 50,
|
||||
"num_turns": 1,
|
||||
"intermediate_count": 0,
|
||||
"subtype": "success",
|
||||
"is_error": False,
|
||||
}
|
||||
|
||||
with patch.object(claude_session, "_run_claude", side_effect=flaky):
|
||||
with pytest.raises(RuntimeError, match="simulated subprocess crash"):
|
||||
send_message("ch-recover", "first")
|
||||
|
||||
# Second call MUST acquire the lock (proves the first released it).
|
||||
# We use a short timeout via a thread so a deadlock would fail loudly.
|
||||
done = threading.Event()
|
||||
result_box: list[str] = []
|
||||
|
||||
def second():
|
||||
result_box.append(send_message("ch-recover", "second"))
|
||||
done.set()
|
||||
|
||||
t = threading.Thread(target=second)
|
||||
t.start()
|
||||
assert done.wait(timeout=3.0), (
|
||||
"Second call deadlocked — lock was not released on exception."
|
||||
)
|
||||
t.join(timeout=1.0)
|
||||
assert result_box == ["Hello from Claude!"]
|
||||
Reference in New Issue
Block a user