feat(voice): Pas 8 — threading.Lock per channel_id mutex + voice augment
Fix arhitectural general (beneficiu și pentru text adapters), nu doar voice.
src/claude_session.py:
- _session_locks: dict[str, threading.Lock] cu bootstrap lock pentru
lazy creation thread-safe.
- _get_session_lock(channel_id) helper.
- send_message() body wrapped în with _get_session_lock(channel_id).
- threading.Lock (NU asyncio.Lock) — send_message e cod sync blocking
(subprocess.Popen + stream-json read loop); asyncio.Lock nu protejează cod sync rulat via to_thread.
- Per-channel granularity preserved — different channels run în paralel.
- send_message() public signature unchanged.
src/router.py:
- route_message(): dacă adapter_name == "discord-voice", prepend
[speaker:<user_name>] prefix (Config.get("voice.user_name", "user")).
- Original text variable left untouched for downstream paths.
- Text adapters: zero behavior change.
- route_message() public signature unchanged.
tests/test_claude_session_mutex.py — 6 tests REGRESSION-CRITICAL:
- same channel serializes (concurrent → mutex serializes, no overlap)
- same channel lock identity (same dict entry per channel_id)
- different channels run in parallel (overlap MUST fire)
- 3 channels all overlap
- contested acquire blocks then proceeds (policy: blocking, not fail-fast)
- lock released on subprocess exception (no deadlock on crash)
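Acquisition policy: BLOCKING acquire, nu fail-fast. Așteptarea e limitată de
`timeout`-ul subprocess-ului claude care deține lock-ul (5min default,
DEFAULT_TIMEOUT = 300s): la expirare lock-ul se eliberează și următorul caller
continuă. Adapters already serialize via asyncio.to_thread queue, iar un
non-blocking acquire ar surface transient busy errors.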
Test results: 82 passed (51 existing + 31 new). 2 PRE-EXISTING failures în
TestPromptInjectionProtection (stale assertion vs current prompt text) —
out of scope, recomand ticket separat.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -37,6 +37,42 @@ DEFAULT_TIMEOUT = 300 # seconds
|
|||||||
|
|
||||||
CLAUDE_BIN = os.environ.get("CLAUDE_BIN", "claude")
|
CLAUDE_BIN = os.environ.get("CLAUDE_BIN", "claude")
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Per-channel mutex for send_message
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
#
|
||||||
|
# Two paths can hit `send_message(channel_id, ...)` concurrently for the same
|
||||||
|
# channel: a text adapter (Discord/Telegram/WhatsApp) and the voice adapter
|
||||||
|
# (`adapter_name="discord-voice"`). The underlying Claude CLI subprocess is
|
||||||
|
# blocking (`subprocess.Popen` with stream-json read loop) and stateful via
|
||||||
|
# `--resume <session_id>` — interleaving two concurrent invocations on the
|
||||||
|
# same channel would corrupt the conversation order.
|
||||||
|
#
|
||||||
|
# We use `threading.Lock` (NOT `asyncio.Lock`) because `send_message` is sync
|
||||||
|
# code typically run from `asyncio.to_thread` in async adapters. asyncio.Lock
|
||||||
|
# only serializes coroutines, not threads — it would NOT protect this path.
|
||||||
|
#
|
||||||
|
# Each channel gets its own lock so DIFFERENT channels still run in parallel.
|
||||||
|
# Locks are created lazily on first use; the dict itself is guarded by a
|
||||||
|
# small bootstrap lock so two concurrent first-uses don't race on creation.
|
||||||
|
_session_locks: dict[str, threading.Lock] = {}
|
||||||
|
_session_locks_bootstrap = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_session_lock(channel_id: str) -> threading.Lock:
|
||||||
|
"""Return the channel's mutex, creating it on first access.
|
||||||
|
|
||||||
|
Two threads racing to create the same channel's lock would otherwise
|
||||||
|
end up with different lock objects (setdefault is not atomic across
|
||||||
|
the read-modify-write under all interpreter conditions — defensive).
|
||||||
|
"""
|
||||||
|
lock = _session_locks.get(channel_id)
|
||||||
|
if lock is not None:
|
||||||
|
return lock
|
||||||
|
with _session_locks_bootstrap:
|
||||||
|
return _session_locks.setdefault(channel_id, threading.Lock())
|
||||||
|
|
||||||
|
|
||||||
PERSONALITY_FILES = [
|
PERSONALITY_FILES = [
|
||||||
"IDENTITY.md",
|
"IDENTITY.md",
|
||||||
"SOUL.md",
|
"SOUL.md",
|
||||||
@@ -543,19 +579,28 @@ def send_message(
|
|||||||
timeout: int = DEFAULT_TIMEOUT,
|
timeout: int = DEFAULT_TIMEOUT,
|
||||||
on_text: Callable[[str], None] | None = None,
|
on_text: Callable[[str], None] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""High-level convenience: auto start or resume based on channel state."""
|
"""High-level convenience: auto start or resume based on channel state.
|
||||||
session = get_active_session(channel_id)
|
|
||||||
# Only resume if session has a valid session_id (not a pre-set model placeholder)
|
Concurrency: a per-`channel_id` `threading.Lock` serializes invocations
|
||||||
if session is not None and session.get("session_id"):
|
that hit the same channel (e.g. text adapter + voice adapter racing on
|
||||||
return resume_session(session["session_id"], message, timeout, on_text=on_text)
|
the same Discord guild text channel). Different channels run in
|
||||||
# Use model from pre-set session if available, otherwise use provided model
|
parallel — each holds its own lock. Lock is acquired blocking; we rely
|
||||||
effective_model = model
|
on `timeout` (default 5 minutes) to bound the worst case rather than
|
||||||
if session is not None and session.get("model"):
|
a non-blocking acquire (loss of fairness vs adapter-side queueing).
|
||||||
effective_model = session["model"]
|
"""
|
||||||
response_text, _session_id = start_session(
|
with _get_session_lock(channel_id):
|
||||||
channel_id, message, effective_model, timeout, on_text=on_text
|
session = get_active_session(channel_id)
|
||||||
)
|
# Only resume if session has a valid session_id (not a pre-set model placeholder)
|
||||||
return response_text
|
if session is not None and session.get("session_id"):
|
||||||
|
return resume_session(session["session_id"], message, timeout, on_text=on_text)
|
||||||
|
# Use model from pre-set session if available, otherwise use provided model
|
||||||
|
effective_model = model
|
||||||
|
if session is not None and session.get("model"):
|
||||||
|
effective_model = session["model"]
|
||||||
|
response_text, _session_id = start_session(
|
||||||
|
channel_id, message, effective_model, timeout, on_text=on_text
|
||||||
|
)
|
||||||
|
return response_text
|
||||||
|
|
||||||
|
|
||||||
def clear_session(channel_id: str) -> bool:
|
def clear_session(channel_id: str) -> bool:
|
||||||
|
|||||||
@@ -154,8 +154,17 @@ def route_message(
|
|||||||
channel_cfg = _get_channel_config(channel_id)
|
channel_cfg = _get_channel_config(channel_id)
|
||||||
model = (channel_cfg or {}).get("default_model") or _get_config().get("bot.default_model", "sonnet")
|
model = (channel_cfg or {}).get("default_model") or _get_config().get("bot.default_model", "sonnet")
|
||||||
|
|
||||||
|
# Voice-mode augment: prepend speaker prefix so Claude knows who spoke
|
||||||
|
# in a voice channel. Cheap now, future-proof for multi-speaker later.
|
||||||
|
# (Engineering decision #14 in the plan.) Only the discord-voice adapter
|
||||||
|
# triggers it — text adapters keep the message verbatim.
|
||||||
|
claude_text = text
|
||||||
|
if adapter_name == "discord-voice":
|
||||||
|
user_name = _get_config().get("voice.user_name", "user") or "user"
|
||||||
|
claude_text = f"[speaker:{user_name}] {text}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = send_message(channel_id, text, model=model, on_text=on_text)
|
response = send_message(channel_id, claude_text, model=model, on_text=on_text)
|
||||||
_set_last_response(channel_id, response)
|
_set_last_response(channel_id, response)
|
||||||
return response, False
|
return response, False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
307
tests/test_claude_session_mutex.py
Normal file
307
tests/test_claude_session_mutex.py
Normal file
@@ -0,0 +1,307 @@
|
|||||||
|
"""Regression-critical tests for per-channel mutex in src/claude_session.py.
|
||||||
|
|
||||||
|
Three scenarios from the eng-review test plan (2026-05-27):
|
||||||
|
|
||||||
|
1. Concurrent `send_message` calls on the SAME channel_id serialize —
|
||||||
|
the second waits for the first to finish before its subprocess runs.
|
||||||
|
2. Concurrent `send_message` calls on DIFFERENT channel_ids run in parallel
|
||||||
|
— independent channels never block each other.
|
||||||
|
3. Acquisition contract is documented and consistent: the lock is acquired
|
||||||
|
blocking (no acquire timeout), which means a hung subprocess on
|
||||||
|
channel X delays subsequent X messages but never X' (X != X'). This
|
||||||
|
test pins that behavior so future refactors must preserve it.
|
||||||
|
|
||||||
|
The mutex is `threading.Lock`, NOT `asyncio.Lock`, because `send_message`
|
||||||
|
is a sync function typically dispatched via `asyncio.to_thread` from
|
||||||
|
async adapters. asyncio.Lock would serialize coroutines only — not the
|
||||||
|
subprocess invocation. See plan section "Engineering decisions" #2.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src import claude_session
|
||||||
|
from src.claude_session import (
|
||||||
|
_get_session_lock,
|
||||||
|
_session_locks,
|
||||||
|
send_message,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _clear_session_locks():
|
||||||
|
"""Each test starts with a fresh lock map so we don't share state."""
|
||||||
|
_session_locks.clear()
|
||||||
|
yield
|
||||||
|
_session_locks.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_sessions(tmp_path, monkeypatch):
|
||||||
|
"""Isolated active.json per test — keeps real session state untouched."""
|
||||||
|
sessions_dir = tmp_path / "sessions"
|
||||||
|
sessions_dir.mkdir()
|
||||||
|
sf = sessions_dir / "active.json"
|
||||||
|
sf.write_text("{}")
|
||||||
|
monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir)
|
||||||
|
monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf)
|
||||||
|
return sf
|
||||||
|
|
||||||
|
|
||||||
|
def _slow_run_claude(sleep_seconds: float, in_critical: threading.Event,
|
||||||
|
concurrent_seen: threading.Event):
|
||||||
|
"""Build a fake `_run_claude` that signals when inside the critical section.
|
||||||
|
|
||||||
|
The fake holds the simulated subprocess for `sleep_seconds`. Any other
|
||||||
|
invocation that overlaps will set `concurrent_seen` — the mutex test
|
||||||
|
asserts this NEVER happens for the same channel_id.
|
||||||
|
"""
|
||||||
|
state = {"active": 0, "lock": threading.Lock()}
|
||||||
|
|
||||||
|
def fake(cmd, timeout, on_text=None, cwd=None):
|
||||||
|
with state["lock"]:
|
||||||
|
state["active"] += 1
|
||||||
|
if state["active"] > 1:
|
||||||
|
concurrent_seen.set()
|
||||||
|
in_critical.set()
|
||||||
|
time.sleep(sleep_seconds)
|
||||||
|
with state["lock"]:
|
||||||
|
state["active"] -= 1
|
||||||
|
return {
|
||||||
|
"result": "Hello from Claude!",
|
||||||
|
"session_id": "sess-abc-123",
|
||||||
|
"usage": {"input_tokens": 10, "output_tokens": 5},
|
||||||
|
"total_cost_usd": 0.001,
|
||||||
|
"cost_usd": 0.001,
|
||||||
|
"duration_ms": int(sleep_seconds * 1000),
|
||||||
|
"num_turns": 1,
|
||||||
|
"intermediate_count": 0,
|
||||||
|
"subtype": "success",
|
||||||
|
"is_error": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
return fake
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Scenario 1 — same channel serializes
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSameChannelSerializes:
|
||||||
|
def test_two_concurrent_calls_same_channel_run_one_at_a_time(
|
||||||
|
self, temp_sessions
|
||||||
|
):
|
||||||
|
"""Two parallel send_message on the SAME channel_id never overlap.
|
||||||
|
|
||||||
|
We instrument `_run_claude` to signal whenever more than one
|
||||||
|
invocation is concurrently inside it. The mutex MUST prevent that.
|
||||||
|
"""
|
||||||
|
in_critical = threading.Event()
|
||||||
|
concurrent_seen = threading.Event()
|
||||||
|
slow = _slow_run_claude(0.25, in_critical, concurrent_seen)
|
||||||
|
|
||||||
|
with patch.object(claude_session, "_run_claude", side_effect=slow):
|
||||||
|
start = time.monotonic()
|
||||||
|
with ThreadPoolExecutor(max_workers=2) as pool:
|
||||||
|
futures = [
|
||||||
|
pool.submit(send_message, "ch-same", f"msg-{i}")
|
||||||
|
for i in range(2)
|
||||||
|
]
|
||||||
|
results = [f.result(timeout=10) for f in futures]
|
||||||
|
elapsed = time.monotonic() - start
|
||||||
|
|
||||||
|
assert not concurrent_seen.is_set(), (
|
||||||
|
"Two send_message calls on the same channel ran concurrently — "
|
||||||
|
"mutex did not serialize them."
|
||||||
|
)
|
||||||
|
assert all(r == "Hello from Claude!" for r in results)
|
||||||
|
# Two serial 0.25s subprocesses must take at least ~0.5s total
|
||||||
|
# (we allow a generous floor — schedulers can be slow).
|
||||||
|
assert elapsed >= 0.45, f"Expected serialized ~0.5s, got {elapsed:.3f}s"
|
||||||
|
|
||||||
|
def test_lock_is_reentrant_per_channel_dict(self, temp_sessions):
|
||||||
|
"""`_get_session_lock` returns the SAME lock object for the same channel."""
|
||||||
|
lock_a1 = _get_session_lock("channel-A")
|
||||||
|
lock_a2 = _get_session_lock("channel-A")
|
||||||
|
lock_b = _get_session_lock("channel-B")
|
||||||
|
assert lock_a1 is lock_a2
|
||||||
|
assert lock_a1 is not lock_b
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Scenario 2 — different channels parallel
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestDifferentChannelsParallel:
|
||||||
|
def test_two_concurrent_calls_different_channels_run_in_parallel(
|
||||||
|
self, temp_sessions
|
||||||
|
):
|
||||||
|
"""Different channels MUST NOT block each other.
|
||||||
|
|
||||||
|
We measure elapsed wall-clock: two 0.4s subprocesses on different
|
||||||
|
channels should finish in ~0.4s (parallel), NOT ~0.8s (serialized).
|
||||||
|
"""
|
||||||
|
in_critical = threading.Event()
|
||||||
|
# `concurrent_seen` is OK to fire here — we WANT them to overlap.
|
||||||
|
concurrent_seen = threading.Event()
|
||||||
|
slow = _slow_run_claude(0.4, in_critical, concurrent_seen)
|
||||||
|
|
||||||
|
with patch.object(claude_session, "_run_claude", side_effect=slow):
|
||||||
|
start = time.monotonic()
|
||||||
|
with ThreadPoolExecutor(max_workers=2) as pool:
|
||||||
|
f1 = pool.submit(send_message, "ch-A", "msg-A")
|
||||||
|
f2 = pool.submit(send_message, "ch-B", "msg-B")
|
||||||
|
results = [f1.result(timeout=10), f2.result(timeout=10)]
|
||||||
|
elapsed = time.monotonic() - start
|
||||||
|
|
||||||
|
assert all(r == "Hello from Claude!" for r in results)
|
||||||
|
# Parallel execution: total time should be close to 0.4s, well under
|
||||||
|
# 0.7s (would mean serialization). 0.65s ceiling allows for GIL +
|
||||||
|
# scheduler jitter on a busy test box.
|
||||||
|
assert elapsed < 0.65, (
|
||||||
|
f"Different channels appear serialized: elapsed {elapsed:.3f}s "
|
||||||
|
f"(expected ~0.4s parallel, <0.65s ceiling)"
|
||||||
|
)
|
||||||
|
assert concurrent_seen.is_set(), (
|
||||||
|
"Different channels did not overlap — mutex is too coarse "
|
||||||
|
"(should be per-channel, not global)."
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_three_channels_all_overlap(self, temp_sessions):
|
||||||
|
"""Stress: three concurrent channels all run in parallel."""
|
||||||
|
in_critical = threading.Event()
|
||||||
|
concurrent_seen = threading.Event()
|
||||||
|
slow = _slow_run_claude(0.3, in_critical, concurrent_seen)
|
||||||
|
|
||||||
|
with patch.object(claude_session, "_run_claude", side_effect=slow):
|
||||||
|
start = time.monotonic()
|
||||||
|
with ThreadPoolExecutor(max_workers=3) as pool:
|
||||||
|
futures = [
|
||||||
|
pool.submit(send_message, f"ch-{i}", f"msg-{i}")
|
||||||
|
for i in range(3)
|
||||||
|
]
|
||||||
|
for f in as_completed(futures, timeout=10):
|
||||||
|
assert f.result() == "Hello from Claude!"
|
||||||
|
elapsed = time.monotonic() - start
|
||||||
|
|
||||||
|
# 3 × 0.3s in parallel ≈ 0.3s; serial would be ~0.9s.
|
||||||
|
assert elapsed < 0.6, (
|
||||||
|
f"Three channels serialized: {elapsed:.3f}s (expected <0.6s)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Scenario 3 — acquisition behavior documented and consistent
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAcquisitionBehavior:
|
||||||
|
"""Pin the chosen acquisition policy: blocking, no timeout.
|
||||||
|
|
||||||
|
Project style is to bound subprocess execution via `timeout` (default
|
||||||
|
5 min) rather than fail-fast on lock acquire. Reasons:
|
||||||
|
|
||||||
|
- Adapter callers (Discord/Telegram/voice) already serialize work via
|
||||||
|
asyncio.to_thread; queue depth is naturally bounded.
|
||||||
|
- A non-blocking acquire would surface a timing error to the user
|
||||||
|
("busy, try again") for an entirely transient and self-resolving
|
||||||
|
condition. Blocking gives FIFO-ish ordering with simple semantics.
|
||||||
|
- If a subprocess truly hangs past `timeout`, _run_claude raises
|
||||||
|
TimeoutError → the held lock releases (via `with`) → queued
|
||||||
|
callers proceed.
|
||||||
|
|
||||||
|
This test pins that: a second caller waits and eventually proceeds; it
|
||||||
|
does not raise an exception on contention.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_contested_acquire_blocks_then_proceeds(self, temp_sessions):
|
||||||
|
in_critical = threading.Event()
|
||||||
|
concurrent_seen = threading.Event()
|
||||||
|
slow = _slow_run_claude(0.3, in_critical, concurrent_seen)
|
||||||
|
|
||||||
|
results: list[str | BaseException] = []
|
||||||
|
|
||||||
|
def run(label: str):
|
||||||
|
try:
|
||||||
|
results.append(send_message("ch-contend", label))
|
||||||
|
except BaseException as e:
|
||||||
|
results.append(e)
|
||||||
|
|
||||||
|
with patch.object(claude_session, "_run_claude", side_effect=slow):
|
||||||
|
t1 = threading.Thread(target=run, args=("first",))
|
||||||
|
t1.start()
|
||||||
|
# Wait until the first call is inside the critical section so
|
||||||
|
# the second is GUARANTEED to contend on the lock.
|
||||||
|
assert in_critical.wait(timeout=2.0), "first call never entered"
|
||||||
|
in_critical.clear()
|
||||||
|
t2 = threading.Thread(target=run, args=("second",))
|
||||||
|
t2.start()
|
||||||
|
t1.join(timeout=5.0)
|
||||||
|
t2.join(timeout=5.0)
|
||||||
|
|
||||||
|
assert len(results) == 2
|
||||||
|
# Both must return the canned response — no exception, no error.
|
||||||
|
assert all(r == "Hello from Claude!" for r in results), (
|
||||||
|
f"Contended acquire surfaced an error instead of blocking: {results}"
|
||||||
|
)
|
||||||
|
# Critical-section overlap check: contended calls MUST serialize.
|
||||||
|
assert not concurrent_seen.is_set(), (
|
||||||
|
"Contended same-channel calls ran concurrently — mutex broken."
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_lock_released_on_subprocess_exception(self, temp_sessions):
|
||||||
|
"""If `_run_claude` raises, the lock MUST be released so the next
|
||||||
|
caller can proceed (otherwise a single error deadlocks the channel
|
||||||
|
forever)."""
|
||||||
|
|
||||||
|
call_count = {"n": 0}
|
||||||
|
|
||||||
|
def flaky(cmd, timeout, on_text=None, cwd=None):
|
||||||
|
call_count["n"] += 1
|
||||||
|
if call_count["n"] == 1:
|
||||||
|
raise RuntimeError("simulated subprocess crash")
|
||||||
|
return {
|
||||||
|
"result": "Hello from Claude!",
|
||||||
|
"session_id": "sess-abc-123",
|
||||||
|
"usage": {"input_tokens": 10, "output_tokens": 5},
|
||||||
|
"total_cost_usd": 0.001,
|
||||||
|
"cost_usd": 0.001,
|
||||||
|
"duration_ms": 50,
|
||||||
|
"num_turns": 1,
|
||||||
|
"intermediate_count": 0,
|
||||||
|
"subtype": "success",
|
||||||
|
"is_error": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(claude_session, "_run_claude", side_effect=flaky):
|
||||||
|
with pytest.raises(RuntimeError, match="simulated subprocess crash"):
|
||||||
|
send_message("ch-recover", "first")
|
||||||
|
|
||||||
|
# Second call MUST acquire the lock (proves the first released it).
|
||||||
|
# We use a short timeout via a thread so a deadlock would fail loudly.
|
||||||
|
done = threading.Event()
|
||||||
|
result_box: list[str] = []
|
||||||
|
|
||||||
|
def second():
|
||||||
|
result_box.append(send_message("ch-recover", "second"))
|
||||||
|
done.set()
|
||||||
|
|
||||||
|
t = threading.Thread(target=second)
|
||||||
|
t.start()
|
||||||
|
assert done.wait(timeout=3.0), (
|
||||||
|
"Second call deadlocked — lock was not released on exception."
|
||||||
|
)
|
||||||
|
t.join(timeout=1.0)
|
||||||
|
assert result_box == ["Hello from Claude!"]
|
||||||
Reference in New Issue
Block a user