Compare commits
10 Commits
af5af8133f
...
d1bc77e87d
| Author | SHA1 | Date | |
|---|---|---|---|
| d1bc77e87d | |||
| e4f3177fc1 | |||
| 13931db953 | |||
| 23666f7910 | |||
| 217da65417 | |||
| 0cc01c1450 | |||
| c93c4f822e | |||
| 3af6bcaea4 | |||
| a3eefbc799 | |||
| a48562b2f5 |
101
cli.py
101
cli.py
@@ -114,6 +114,104 @@ def _load_sessions_file() -> dict:
|
||||
return {}
|
||||
|
||||
|
||||
def _voice_doctor_checks() -> list[tuple[str, bool]]:
|
||||
"""Voice-stack health checks (Pas 10).
|
||||
|
||||
Mirrors the logic in tools/voice_setup.py but returns (label, ok) tuples
|
||||
so they integrate with cmd_doctor's PASS/FAIL output. All checks degrade
|
||||
gracefully — ImportError on optional voice deps is reported as FAIL, never
|
||||
raised, so the rest of `eco doctor` is unaffected.
|
||||
"""
|
||||
import importlib.util
|
||||
import json as _json
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
|
||||
results: list[tuple[str, bool]] = []
|
||||
|
||||
# 1. libopus0 loaded by discord.py
|
||||
try:
|
||||
import discord
|
||||
if not discord.opus.is_loaded():
|
||||
try:
|
||||
discord.opus._load_default()
|
||||
except Exception:
|
||||
pass
|
||||
results.append(("libopus loaded (discord.py)", discord.opus.is_loaded()))
|
||||
except ImportError:
|
||||
results.append(("libopus loaded (discord.py)", False))
|
||||
except Exception:
|
||||
results.append(("libopus loaded (discord.py)", False))
|
||||
|
||||
# 2. ffmpeg in PATH
|
||||
results.append(("ffmpeg in PATH", shutil.which("ffmpeg") is not None))
|
||||
|
||||
# 3. Supertonic TTS reachable at http://127.0.0.1:7788/
|
||||
supertonic_url = "http://127.0.0.1:7788/v1/audio/speech"
|
||||
supertonic_ok = False
|
||||
try:
|
||||
payload = _json.dumps({
|
||||
"model": "supertonic-3",
|
||||
"input": "test",
|
||||
"voice": "M2",
|
||||
"response_format": "wav",
|
||||
"lang": "ro",
|
||||
}).encode("utf-8")
|
||||
req = urllib.request.Request(
|
||||
supertonic_url,
|
||||
data=payload,
|
||||
headers={"Content-Type": "application/json"},
|
||||
method="POST",
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=5) as resp:
|
||||
supertonic_ok = resp.status == 200
|
||||
except (urllib.error.URLError, ConnectionError, OSError):
|
||||
supertonic_ok = False
|
||||
except Exception:
|
||||
supertonic_ok = False
|
||||
results.append(("Supertonic TTS reachable at :7788", supertonic_ok))
|
||||
|
||||
# 4. faster-whisper importable (don't load model — too slow)
|
||||
results.append((
|
||||
"faster-whisper importable",
|
||||
importlib.util.find_spec("faster_whisper") is not None,
|
||||
))
|
||||
|
||||
# 5. silero-vad importable
|
||||
results.append((
|
||||
"silero-vad importable",
|
||||
importlib.util.find_spec("silero_vad") is not None,
|
||||
))
|
||||
|
||||
# 6. discord.ext.voice_recv importable (vendor package)
|
||||
voice_recv_ok = False
|
||||
try:
|
||||
voice_recv_ok = importlib.util.find_spec("discord.ext.voice_recv") is not None
|
||||
except (ImportError, ValueError, ModuleNotFoundError):
|
||||
voice_recv_ok = False
|
||||
except Exception:
|
||||
voice_recv_ok = False
|
||||
results.append(("discord.ext.voice_recv importable", voice_recv_ok))
|
||||
|
||||
# 7-9. Voice assets present and non-trivial size
|
||||
voice_assets = [
|
||||
("assets/voice/thinking.wav", 1024),
|
||||
("assets/voice/beep_200ms.wav", 512),
|
||||
("assets/voice/mhm.wav", 512),
|
||||
]
|
||||
for rel_path, min_bytes in voice_assets:
|
||||
path = PROJECT_ROOT / rel_path
|
||||
ok = False
|
||||
try:
|
||||
ok = path.exists() and path.stat().st_size > min_bytes
|
||||
except OSError:
|
||||
ok = False
|
||||
label = f"{rel_path} (>{min_bytes}B)"
|
||||
results.append((label, ok))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def cmd_doctor(args):
|
||||
"""Run diagnostic checks."""
|
||||
import re
|
||||
@@ -227,6 +325,9 @@ def cmd_doctor(args):
|
||||
else:
|
||||
checks.append(("WhatsApp bridge (optional)", True))
|
||||
|
||||
# ---- Voice stack checks (Pas 10) ----
|
||||
checks.extend(_voice_doctor_checks())
|
||||
|
||||
# Print results
|
||||
all_pass = True
|
||||
for label, passed in checks:
|
||||
|
||||
@@ -104,6 +104,12 @@
|
||||
"ollama": {
|
||||
"url": "http://10.0.20.161:11434"
|
||||
},
|
||||
"voice": {
|
||||
"allowed_user_ids": ["949388626146517022"],
|
||||
"user_name": "Marius",
|
||||
"default_voice": "M2",
|
||||
"auto_leave_minutes": 5
|
||||
},
|
||||
"paths": {
|
||||
"personality": "personality/",
|
||||
"tools": "tools/",
|
||||
|
||||
@@ -189,4 +189,16 @@ Când lansez sub-agent, îi dau context: AGENTS.md, SOUL.md, USER.md + relevant
|
||||
- Discord links: `<url>` pentru a suprima embed-uri
|
||||
- Cand primesc o sarcina mai mare de executat, raspund intotdeauna cu o reactie sau confirmare si apoi trec la executie
|
||||
- **Link-uri:** Folosesc `https://moltbot.tailf7372d.ts.net/echo/` (NU IP 100.120.119.70) pentru ca WhatsApp să le recunoască ca link-uri
|
||||
- **Link-uri fișiere salvate:** Când salvez/menționez fișiere din `memory/kb/`, ofer automat link către `files.html#memory/kb/path/to/file.md` pentru preview
|
||||
- **Link-uri fișiere salvate:** Când salvez/menționez fișiere din `memory/kb/`, ofer automat link către `files.html#memory/kb/path/to/file.md` pentru preview
|
||||
|
||||
## Voice mode
|
||||
|
||||
Reguli aplicate când `adapter_name == "discord-voice"` — Marius mă ascultă, nu citește. Vocea e intolerantă la lung și la structură.
|
||||
|
||||
- **1-3 propoziții max per răspuns.** Dacă am mai mult de spus, condensez sau mut în chat.
|
||||
- **Fără markdown.** Niciun bold, italic, cod cu backticks, headere. Text plat, atât.
|
||||
- **Fără bullet lists, nici numerotate.** Le pronunț natural ca propoziții: "trei lucruri: în primul rând..., apoi..., și la final..."
|
||||
- **Fără linkuri.** Nu rostesc URL-uri. Dacă e relevant: "îți trimit linkul în chat".
|
||||
- **Numere și valute formulate conversațional.** Scriu "treizeci de lei", nu "30 RON"; "douăzeci și cinci la sută", nu "25%". Modulul `normalize.py` face curățare tehnică, dar eu formulez deja natural — un om vorbește, nu citește tabelul.
|
||||
- **Lung sau structurat → mută în chat.** Dacă răspunsul cere listă, cod, linkuri sau peste 3 propoziții, închei rostit cu "L-am scris în chat." iar restul ajunge în text channel mirror.
|
||||
- **Ton:** cum vorbesc cu Marius la o cafea, nu cum scriu raport. Contracții, pauze, "păi" sau "stai puțin" dacă mă ajută să sune uman. Concis, fără tic-uri robotice.
|
||||
@@ -63,6 +63,13 @@
|
||||
- **Venv:** ~/echo-core/.venv/ | **Model:** base
|
||||
- **Utilizare:** `whisper.load_model('base').transcribe(path, language='ro')`
|
||||
|
||||
### Discord Voice
|
||||
- **Ce este:** Bot conectat la un voice channel Discord — ascultă microfonul lui Marius, transcrie cu faster-whisper (`small` int8, RO), rutează prin router și răspunde rostit cu Supertonic TTS.
|
||||
- **Cum sunt "în voce":** Slash command `/voice join` mă cheamă în channel; cât stau acolo, presence-ul arată că ascult. `/voice leave` sau auto-leave după 5 minute fără voce.
|
||||
- **Latență așteptată:** ~5 secunde perceput end-to-end (STT p50 2.25s + LLM + TTS first chunk). Peste 3s pornesc un filler audio ("Stai să-mi adun gândurile") ca să nu pară mort.
|
||||
- **Streaming TTS:** răspunsul iese pe clauze, nu cuvânt-cu-cuvânt și nu frază întreagă — primul sunet pleacă imediat ce am o propoziție scurtă.
|
||||
- **Limitări:** 1-3 propoziții max (vezi AGENTS.md § Voice mode). Cuvinte rare, nume proprii sau acronime pot apărea ciudat în STT — dacă sună greșit, cer reformulare în loc să ghicesc.
|
||||
|
||||
### Pauze respirație
|
||||
- **Script:** `python3 tools/pauza_random.py`
|
||||
- **Bancă:** memory/kb/tehnici-pauza.md
|
||||
|
||||
@@ -112,6 +112,7 @@ def create_bot(config: Config) -> discord.Client:
|
||||
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
intents.voice_states = True
|
||||
|
||||
client = discord.Client(intents=intents)
|
||||
tree = app_commands.CommandTree(client)
|
||||
@@ -958,6 +959,11 @@ def create_bot(config: Config) -> discord.Client:
|
||||
else:
|
||||
await interaction.followup.send(result or "Eroare TTS.")
|
||||
|
||||
# Voice slash group (Pas 7)
|
||||
from src.adapters.discord_voice import register as register_voice
|
||||
voice_group = register_voice(tree, client)
|
||||
tree.add_command(voice_group)
|
||||
|
||||
# --- Ralph commands (autonomous project execution) ---
|
||||
|
||||
async def _autocomplete_by_status(
|
||||
@@ -1118,6 +1124,11 @@ def create_bot(config: Config) -> discord.Client:
|
||||
from datetime import datetime, timezone
|
||||
client._ready_at = datetime.now(timezone.utc)
|
||||
logger.info("Echo Core online as %s", client.user)
|
||||
# Voice models eager warmup (Pas 7)
|
||||
from src.adapters import discord_voice
|
||||
discord_voice._models_warmup_future = asyncio.create_task(
|
||||
discord_voice.warmup_models()
|
||||
)
|
||||
|
||||
async def _handle_chat(message: discord.Message) -> None:
|
||||
"""Process a chat message through the router and send the response."""
|
||||
|
||||
322
src/adapters/discord_voice.py
Normal file
322
src/adapters/discord_voice.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""Discord voice slash commands (Pas 7 — CONVERGENCE wiring).
|
||||
|
||||
Registers the `/voice` slash command group on the existing CommandTree and
|
||||
exposes an async `warmup_models()` for eager model load at bot startup.
|
||||
|
||||
Owns nothing in `src/voice/*` — purely the Discord-facing wiring. Defers
|
||||
heavy lifting to:
|
||||
|
||||
- ``src.voice.pipeline.VoiceSession`` — per-guild session state machine
|
||||
- ``src.voice.pipeline.EchoVoiceSink`` — discord-ext-voice-recv sink
|
||||
- ``src.voice.tts_stream.TTSQueue`` / ``EchoStreamingAudioSource``
|
||||
- ``src.voice._discord_voice_adapter.connect_voice``
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import discord
|
||||
from discord import app_commands
|
||||
|
||||
# Optional DAVE dep (mandatory at runtime when discord.py 2.7.1 is paired with
|
||||
# Discord voice gateway v=8; tolerated missing in tests / dev environments).
|
||||
try:
|
||||
import davey
|
||||
_HAS_DAVE = True
|
||||
except ImportError:
|
||||
_HAS_DAVE = False
|
||||
|
||||
from src.config import Config
|
||||
from src.voice.pipeline import (
|
||||
VoiceSession,
|
||||
EchoVoiceSink,
|
||||
_get_whisper_model,
|
||||
_get_silero_vad,
|
||||
)
|
||||
from src.voice.tts_stream import TTSQueue, EchoStreamingAudioSource
|
||||
from src.voice._discord_voice_adapter import connect_voice
|
||||
|
||||
log = logging.getLogger("echo-core.discord.voice")
|
||||
|
||||
# Per-guild voice session registry. Key = guild_id.
|
||||
_voice_sessions: dict[int, VoiceSession] = {}
|
||||
|
||||
# Set if model warmup failed; surfaces as ephemeral error on /voice join.
|
||||
_voice_load_error: Optional[str] = None
|
||||
|
||||
# Reference to the eager warmup task created in on_ready, so /voice join can
|
||||
# await it if the user is faster than the background load.
|
||||
_models_warmup_future: Optional[asyncio.Task] = None
|
||||
|
||||
|
||||
async def warmup_models() -> None:
|
||||
"""Eager model load — called from `on_ready()` as a background task.
|
||||
|
||||
Runs the (synchronous, blocking) model loaders on a worker thread so the
|
||||
event loop stays responsive. On failure, sets `_voice_load_error` instead
|
||||
of raising, so `/voice join` can degrade gracefully.
|
||||
"""
|
||||
global _voice_load_error
|
||||
try:
|
||||
if not discord.opus.is_loaded():
|
||||
discord.opus.load_opus("libopus.so.0")
|
||||
if _HAS_DAVE:
|
||||
log.info("DAVE protocol v%d available (davey %s)",
|
||||
davey.DAVE_PROTOCOL_VERSION, davey.__version__)
|
||||
await asyncio.to_thread(_get_whisper_model)
|
||||
await asyncio.to_thread(_get_silero_vad)
|
||||
log.info("Voice models warm")
|
||||
except Exception as e:
|
||||
_voice_load_error = f"{type(e).__name__}: {e}"
|
||||
log.error("Voice models load failed: %s", _voice_load_error)
|
||||
|
||||
|
||||
def _get_whitelist() -> set[int]:
|
||||
"""Read `voice.allowed_user_ids` from config and coerce to int set.
|
||||
|
||||
Re-reads config from disk to pick up any runtime edits between bot start
|
||||
and /voice join.
|
||||
"""
|
||||
try:
|
||||
raw = Config().get("voice.allowed_user_ids", [])
|
||||
except Exception:
|
||||
raw = []
|
||||
out: set[int] = set()
|
||||
for v in raw or []:
|
||||
try:
|
||||
out.add(int(v))
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
return out
|
||||
|
||||
|
||||
def _get_default_voice() -> str:
|
||||
try:
|
||||
return Config().get("voice.default_voice", "M2") or "M2"
|
||||
except Exception:
|
||||
return "M2"
|
||||
|
||||
|
||||
def register(tree: app_commands.CommandTree, bot: discord.Client) -> app_commands.Group:
|
||||
"""Build the `/voice` slash command group and return it (caller registers)."""
|
||||
voice_group = app_commands.Group(
|
||||
name="voice", description="Echo Core voice channel"
|
||||
)
|
||||
|
||||
@voice_group.command(name="join", description="Echo intră în voice channel-ul tău")
|
||||
async def join(interaction: discord.Interaction) -> None:
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
if _voice_load_error:
|
||||
await interaction.followup.send(
|
||||
f"Voice unavailable: {_voice_load_error}", ephemeral=True
|
||||
)
|
||||
return
|
||||
if _models_warmup_future is not None and not _models_warmup_future.done():
|
||||
try:
|
||||
await _models_warmup_future
|
||||
except Exception as e:
|
||||
await interaction.followup.send(
|
||||
f"Voice unavailable: {type(e).__name__}: {e}", ephemeral=True
|
||||
)
|
||||
return
|
||||
user = interaction.user
|
||||
if not isinstance(user, discord.Member) or user.voice is None or user.voice.channel is None:
|
||||
await interaction.followup.send(
|
||||
"Intră într-un voice channel întâi.", ephemeral=True
|
||||
)
|
||||
return
|
||||
channel = user.voice.channel
|
||||
whitelist = _get_whitelist()
|
||||
if user.id not in whitelist:
|
||||
await interaction.followup.send(
|
||||
"Nu ești pe whitelist voice.", ephemeral=True
|
||||
)
|
||||
return
|
||||
# Reject double-join on the same guild.
|
||||
guild_id = channel.guild.id
|
||||
if guild_id in _voice_sessions:
|
||||
await interaction.followup.send(
|
||||
"Sunt deja în voice pe acest server. Folosește /voice leave întâi.",
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
# Connect
|
||||
try:
|
||||
vc = await connect_voice(channel)
|
||||
except Exception as e:
|
||||
log.exception("connect_voice failed")
|
||||
await interaction.followup.send(
|
||||
f"Conectare eșuată: {type(e).__name__}: {e}", ephemeral=True
|
||||
)
|
||||
return
|
||||
# Build TTS queue + session
|
||||
ttsq = TTSQueue(voice_id=_get_default_voice(), lang="ro")
|
||||
ttsq.start()
|
||||
try:
|
||||
session = VoiceSession(
|
||||
channel_id=channel.id,
|
||||
guild_id=guild_id,
|
||||
voice_client=vc,
|
||||
text_channel=interaction.channel,
|
||||
record_enabled=False,
|
||||
mirror_enabled=True,
|
||||
whitelist=whitelist,
|
||||
ttsq=ttsq,
|
||||
bot=bot,
|
||||
loop=asyncio.get_running_loop(),
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception("VoiceSession construction failed")
|
||||
ttsq.stop()
|
||||
try:
|
||||
await vc.disconnect(force=True)
|
||||
except Exception:
|
||||
pass
|
||||
await interaction.followup.send(
|
||||
f"Sesiune voice eșuată: {type(e).__name__}: {e}", ephemeral=True
|
||||
)
|
||||
return
|
||||
_voice_sessions[guild_id] = session
|
||||
# Start TTS streaming source for the entire session. Chain the
|
||||
# wake-up beep via `after=` so streaming takes over when beep ends.
|
||||
def _start_stream(error: Optional[Exception] = None) -> None:
|
||||
if error is not None:
|
||||
log.warning("Beep playback ended with error: %s", error)
|
||||
try:
|
||||
vc.play(EchoStreamingAudioSource(ttsq))
|
||||
log.info("TTS streaming source attached")
|
||||
except Exception:
|
||||
log.exception("EchoStreamingAudioSource attach failed")
|
||||
try:
|
||||
vc.play(
|
||||
discord.FFmpegPCMAudio("assets/voice/beep_200ms.wav"),
|
||||
after=_start_stream,
|
||||
)
|
||||
except Exception:
|
||||
log.warning("Beep playback skipped, starting stream directly", exc_info=True)
|
||||
_start_stream()
|
||||
# Attach sink
|
||||
try:
|
||||
bot_user_id = int(bot.user.id) if bot.user is not None else 0
|
||||
sink = EchoVoiceSink(session=session, bot_user_id=bot_user_id)
|
||||
vc.listen(sink)
|
||||
except Exception as e:
|
||||
log.exception("Sink attach failed")
|
||||
_voice_sessions.pop(guild_id, None)
|
||||
try:
|
||||
session.cleanup("sink_attach_failed")
|
||||
except Exception:
|
||||
pass
|
||||
await interaction.followup.send(
|
||||
f"Atașare sink eșuată: {type(e).__name__}: {e}", ephemeral=True
|
||||
)
|
||||
return
|
||||
# Presence
|
||||
try:
|
||||
await bot.change_presence(activity=discord.Activity(
|
||||
type=discord.ActivityType.listening,
|
||||
name=f"{user.display_name} în #{channel.name}",
|
||||
))
|
||||
except Exception:
|
||||
log.warning("Presence update skipped", exc_info=True)
|
||||
await interaction.followup.send(
|
||||
f"În voce în #{channel.name}.", ephemeral=True
|
||||
)
|
||||
|
||||
@voice_group.command(name="leave", description="Echo iese din voice channel")
|
||||
async def leave(interaction: discord.Interaction) -> None:
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
guild_id = interaction.guild.id if interaction.guild else None
|
||||
session = _voice_sessions.pop(guild_id, None) if guild_id is not None else None
|
||||
if session is None:
|
||||
await interaction.followup.send(
|
||||
"Nu sunt în niciun voice channel aici.", ephemeral=True
|
||||
)
|
||||
return
|
||||
try:
|
||||
session.cleanup("user_leave")
|
||||
except Exception:
|
||||
log.exception("session.cleanup raised")
|
||||
try:
|
||||
await bot.change_presence(activity=None)
|
||||
except Exception:
|
||||
log.warning("Presence reset skipped", exc_info=True)
|
||||
await interaction.followup.send("Plecat.", ephemeral=True)
|
||||
|
||||
@voice_group.command(name="doctor", description="Verifică voice stack")
|
||||
async def doctor(interaction: discord.Interaction) -> None:
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
checks: list[tuple[str, bool]] = []
|
||||
# libopus
|
||||
try:
|
||||
checks.append(("libopus", bool(discord.opus.is_loaded())))
|
||||
except Exception:
|
||||
checks.append(("libopus", False))
|
||||
# warmup
|
||||
checks.append(("voice load error", _voice_load_error is None))
|
||||
# Build response
|
||||
lines = ["**Voice doctor:**"]
|
||||
for label, ok in checks:
|
||||
lines.append(f"{'OK' if ok else 'FAIL'} — {label}")
|
||||
if _voice_load_error:
|
||||
lines.append(f" details: {_voice_load_error}")
|
||||
await interaction.followup.send("\n".join(lines), ephemeral=True)
|
||||
|
||||
# --- /voice mirror on|off ---
|
||||
mirror_group = app_commands.Group(
|
||||
name="mirror", description="Text mirror", parent=voice_group
|
||||
)
|
||||
|
||||
@mirror_group.command(name="on", description="Activează text mirror în canal")
|
||||
async def mirror_on(interaction: discord.Interaction) -> None:
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
guild_id = interaction.guild.id if interaction.guild else None
|
||||
s = _voice_sessions.get(guild_id) if guild_id is not None else None
|
||||
if s is None:
|
||||
await interaction.followup.send("Nu sunt în voice.", ephemeral=True)
|
||||
return
|
||||
s.mirror_enabled = True
|
||||
await interaction.followup.send("Mirror ON.", ephemeral=True)
|
||||
|
||||
@mirror_group.command(name="off", description="Dezactivează text mirror")
|
||||
async def mirror_off(interaction: discord.Interaction) -> None:
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
guild_id = interaction.guild.id if interaction.guild else None
|
||||
s = _voice_sessions.get(guild_id) if guild_id is not None else None
|
||||
if s is None:
|
||||
await interaction.followup.send("Nu sunt în voice.", ephemeral=True)
|
||||
return
|
||||
s.mirror_enabled = False
|
||||
await interaction.followup.send("Mirror OFF.", ephemeral=True)
|
||||
|
||||
# --- /voice record on|off ---
|
||||
record_group = app_commands.Group(
|
||||
name="record", description="KB recording", parent=voice_group
|
||||
)
|
||||
|
||||
@record_group.command(name="on", description="Activează înregistrare în KB")
|
||||
async def record_on(interaction: discord.Interaction) -> None:
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
guild_id = interaction.guild.id if interaction.guild else None
|
||||
s = _voice_sessions.get(guild_id) if guild_id is not None else None
|
||||
if s is None:
|
||||
await interaction.followup.send("Nu sunt în voice.", ephemeral=True)
|
||||
return
|
||||
s.record_enabled = True
|
||||
await interaction.followup.send("Record ON.", ephemeral=True)
|
||||
|
||||
@record_group.command(name="off", description="Dezactivează înregistrare")
|
||||
async def record_off(interaction: discord.Interaction) -> None:
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
guild_id = interaction.guild.id if interaction.guild else None
|
||||
s = _voice_sessions.get(guild_id) if guild_id is not None else None
|
||||
if s is None:
|
||||
await interaction.followup.send("Nu sunt în voice.", ephemeral=True)
|
||||
return
|
||||
s.record_enabled = False
|
||||
await interaction.followup.send("Record OFF.", ephemeral=True)
|
||||
|
||||
return voice_group
|
||||
@@ -37,6 +37,42 @@ DEFAULT_TIMEOUT = 300 # seconds
|
||||
|
||||
CLAUDE_BIN = os.environ.get("CLAUDE_BIN", "claude")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-channel mutex for send_message
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Two paths can hit `send_message(channel_id, ...)` concurrently for the same
|
||||
# channel: a text adapter (Discord/Telegram/WhatsApp) and the voice adapter
|
||||
# (`adapter_name="discord-voice"`). The underlying Claude CLI subprocess is
|
||||
# blocking (`subprocess.Popen` with stream-json read loop) and stateful via
|
||||
# `--resume <session_id>` — interleaving two concurrent invocations on the
|
||||
# same channel would corrupt the conversation order.
|
||||
#
|
||||
# We use `threading.Lock` (NOT `asyncio.Lock`) because `send_message` is sync
|
||||
# code typically run from `asyncio.to_thread` in async adapters. asyncio.Lock
|
||||
# only serializes coroutines, not threads — it would NOT protect this path.
|
||||
#
|
||||
# Each channel gets its own lock so DIFFERENT channels still run in parallel.
|
||||
# Locks are created lazily on first use; the dict itself is guarded by a
|
||||
# small bootstrap lock so two concurrent first-uses don't race on creation.
|
||||
_session_locks: dict[str, threading.Lock] = {}
|
||||
_session_locks_bootstrap = threading.Lock()
|
||||
|
||||
|
||||
def _get_session_lock(channel_id: str) -> threading.Lock:
|
||||
"""Return the channel's mutex, creating it on first access.
|
||||
|
||||
Two threads racing to create the same channel's lock would otherwise
|
||||
end up with different lock objects (setdefault is not atomic across
|
||||
the read-modify-write under all interpreter conditions — defensive).
|
||||
"""
|
||||
lock = _session_locks.get(channel_id)
|
||||
if lock is not None:
|
||||
return lock
|
||||
with _session_locks_bootstrap:
|
||||
return _session_locks.setdefault(channel_id, threading.Lock())
|
||||
|
||||
|
||||
PERSONALITY_FILES = [
|
||||
"IDENTITY.md",
|
||||
"SOUL.md",
|
||||
@@ -543,19 +579,28 @@ def send_message(
|
||||
timeout: int = DEFAULT_TIMEOUT,
|
||||
on_text: Callable[[str], None] | None = None,
|
||||
) -> str:
|
||||
"""High-level convenience: auto start or resume based on channel state."""
|
||||
session = get_active_session(channel_id)
|
||||
# Only resume if session has a valid session_id (not a pre-set model placeholder)
|
||||
if session is not None and session.get("session_id"):
|
||||
return resume_session(session["session_id"], message, timeout, on_text=on_text)
|
||||
# Use model from pre-set session if available, otherwise use provided model
|
||||
effective_model = model
|
||||
if session is not None and session.get("model"):
|
||||
effective_model = session["model"]
|
||||
response_text, _session_id = start_session(
|
||||
channel_id, message, effective_model, timeout, on_text=on_text
|
||||
)
|
||||
return response_text
|
||||
"""High-level convenience: auto start or resume based on channel state.
|
||||
|
||||
Concurrency: a per-`channel_id` `threading.Lock` serializes invocations
|
||||
that hit the same channel (e.g. text adapter + voice adapter racing on
|
||||
the same Discord guild text channel). Different channels run in
|
||||
parallel — each holds its own lock. Lock is acquired blocking; we rely
|
||||
on `timeout` (default 5 minutes) to bound the worst case rather than
|
||||
a non-blocking acquire (loss of fairness vs adapter-side queueing).
|
||||
"""
|
||||
with _get_session_lock(channel_id):
|
||||
session = get_active_session(channel_id)
|
||||
# Only resume if session has a valid session_id (not a pre-set model placeholder)
|
||||
if session is not None and session.get("session_id"):
|
||||
return resume_session(session["session_id"], message, timeout, on_text=on_text)
|
||||
# Use model from pre-set session if available, otherwise use provided model
|
||||
effective_model = model
|
||||
if session is not None and session.get("model"):
|
||||
effective_model = session["model"]
|
||||
response_text, _session_id = start_session(
|
||||
channel_id, message, effective_model, timeout, on_text=on_text
|
||||
)
|
||||
return response_text
|
||||
|
||||
|
||||
def clear_session(channel_id: str) -> bool:
|
||||
|
||||
@@ -154,8 +154,17 @@ def route_message(
|
||||
channel_cfg = _get_channel_config(channel_id)
|
||||
model = (channel_cfg or {}).get("default_model") or _get_config().get("bot.default_model", "sonnet")
|
||||
|
||||
# Voice-mode augment: prepend speaker prefix so Claude knows who spoke
|
||||
# in a voice channel. Cheap now, future-proof for multi-speaker later.
|
||||
# (Engineering decision #14 in the plan.) Only the discord-voice adapter
|
||||
# triggers it — text adapters keep the message verbatim.
|
||||
claude_text = text
|
||||
if adapter_name == "discord-voice":
|
||||
user_name = _get_config().get("voice.user_name", "user") or "user"
|
||||
claude_text = f"[speaker:{user_name}] {text}"
|
||||
|
||||
try:
|
||||
response = send_message(channel_id, text, model=model, on_text=on_text)
|
||||
response = send_message(channel_id, claude_text, model=model, on_text=on_text)
|
||||
_set_last_response(channel_id, response)
|
||||
return response, False
|
||||
except Exception as e:
|
||||
|
||||
67
src/voice/_discord_voice_adapter.py
Normal file
67
src/voice/_discord_voice_adapter.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Adapter layer over `discord-ext-voice-recv` (vendored at vendor/).
|
||||
|
||||
If discord-ext-voice-recv breaks, swap to py-cord by rewriting only this file.
|
||||
Contract test in tests/test_voice_adapter_contract.py guards drift.
|
||||
|
||||
Downstream consumers (`src/voice/*`, `src/adapters/discord_voice.py`) MUST
|
||||
import from this file — never from `discord.ext.voice_recv` directly.
|
||||
|
||||
## Public API surface (stable across upstream changes)
|
||||
|
||||
- ``VoiceReceiveClient`` — alias for ``voice_recv.VoiceRecvClient``. Subclass
|
||||
of ``discord.VoiceClient`` with extra audio-receive plumbing.
|
||||
Key methods used by the pipeline:
|
||||
* ``await client.disconnect(force: bool = False)`` (from discord.VoiceClient)
|
||||
* ``client.listen(sink, *, after=None)`` — attach an ``AudioSink``;
|
||||
raises ``discord.ClientException`` if not connected or already listening
|
||||
* ``client.stop_listening()`` — detach the current sink
|
||||
* ``client.is_listening() -> bool``
|
||||
* ``client.stop()`` — stop both playing and listening
|
||||
* ``client.sink`` (property, getter+setter) — swap the active sink in place
|
||||
|
||||
- ``AudioSink`` — abstract base. Subclasses MUST implement:
|
||||
* ``write(user: Optional[discord.User|Member], data: VoiceData) -> None``
|
||||
* ``wants_opus() -> bool`` (True → receive opus bytes; False → receive PCM)
|
||||
* ``cleanup() -> None``
|
||||
|
||||
- ``VoiceData`` — per-packet container. Slots: ``packet``, ``source``, ``pcm``.
|
||||
``.pcm`` is decoded 48kHz s16le stereo bytes when ``wants_opus()`` is False.
|
||||
``.opus`` property returns the raw opus bytes from the underlying RTP packet.
|
||||
|
||||
- ``connect_voice(channel) -> VoiceReceiveClient`` — async helper, returns a
|
||||
connected receive-capable voice client.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from discord.ext import voice_recv
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import discord
|
||||
|
||||
|
||||
# --- Stable re-exports -------------------------------------------------------
|
||||
|
||||
VoiceReceiveClient = voice_recv.VoiceRecvClient
|
||||
AudioSink = voice_recv.AudioSink
|
||||
VoiceData = voice_recv.VoiceData
|
||||
|
||||
|
||||
__all__ = [
|
||||
"VoiceReceiveClient",
|
||||
"AudioSink",
|
||||
"VoiceData",
|
||||
"connect_voice",
|
||||
]
|
||||
|
||||
|
||||
async def connect_voice(channel: "discord.VoiceChannel") -> VoiceReceiveClient:
|
||||
"""Connect to a Discord voice channel with the receive-capable client.
|
||||
|
||||
Thin wrapper around ``channel.connect(cls=VoiceRecvClient)`` so callers
|
||||
don't have to import the vendored class directly.
|
||||
"""
|
||||
return await channel.connect(cls=VoiceReceiveClient)
|
||||
267
src/voice/normalize.py
Normal file
267
src/voice/normalize.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""Voice mode text normalization for Romanian TTS.
|
||||
|
||||
Pure functions — no side effects, no I/O, no logging. Strip markdown,
|
||||
expand numbers / currency / symbols / abbreviations into natural-sounding
|
||||
Romanian text. See plan: src/voice/normalize.py (Pas 3).
|
||||
|
||||
Pipeline order in normalize_for_tts:
|
||||
strip_markdown -> expand_abbreviations -> expand_currency
|
||||
-> expand_numbers_ro -> expand_symbols -> truncate(200)
|
||||
|
||||
Currency runs BEFORE generic number expansion so "12.50 RON" becomes
|
||||
"doisprezece lei și cincizeci de bani" rather than
|
||||
"doisprezece virgulă cincizeci RON".
|
||||
"""
|
||||
import re
|
||||
|
||||
from num2words import num2words
|
||||
|
||||
|
||||
# ---------- Markdown ----------
|
||||
|
||||
_MARKDOWN_LINK = re.compile(r'\[([^\]]+)\]\([^)]+\)')
|
||||
_MARKDOWN_BOLD = re.compile(r'\*\*([^*]+)\*\*')
|
||||
_MARKDOWN_CODE = re.compile(r'`([^`\n]+)`')
|
||||
_MARKDOWN_ITALIC = re.compile(r'(?<!\*)\*([^*\n]+)\*(?!\*)')
|
||||
_MARKDOWN_HEADING = re.compile(r'^[ \t]*#{1,6}[ \t]+', re.MULTILINE)
|
||||
_MARKDOWN_LIST = re.compile(r'^[ \t]*[-*+][ \t]+', re.MULTILINE)
|
||||
|
||||
|
||||
def strip_markdown(text: str) -> str:
|
||||
"""Remove common markdown formatting, preserve the visible content."""
|
||||
text = _MARKDOWN_LINK.sub(r'\1', text)
|
||||
text = _MARKDOWN_BOLD.sub(r'\1', text)
|
||||
text = _MARKDOWN_CODE.sub(r'\1', text)
|
||||
text = _MARKDOWN_ITALIC.sub(r'\1', text)
|
||||
text = _MARKDOWN_HEADING.sub('', text)
|
||||
text = _MARKDOWN_LIST.sub('', text)
|
||||
return text
|
||||
|
||||
|
||||
# ---------- Number helpers ----------
|
||||
|
||||
def _needs_de(n: int) -> bool:
|
||||
"""Romanian: insert 'de' between numeral and noun for n >= 20,
|
||||
except when the trailing 1-19 portion makes it sound off
|
||||
(e.g., 105, 119 -> no 'de'; 120, 200 -> 'de').
|
||||
"""
|
||||
if n < 20:
|
||||
return False
|
||||
last = n % 100
|
||||
if 1 <= last <= 19:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _int_to_ro(n: int) -> str:
|
||||
return num2words(n, lang='ro')
|
||||
|
||||
|
||||
def _decimal_to_ro(s: str) -> str:
|
||||
"""Convert decimal string 'X.Y' to RO words.
|
||||
|
||||
Decimal part is read as a whole number ('3.14' -> 'trei virgulă paisprezece'),
|
||||
unless it has a leading zero ('3.05' -> 'trei virgulă zero cinci') so the
|
||||
magnitude is preserved.
|
||||
"""
|
||||
int_part, dec_part = s.split('.', 1)
|
||||
int_words = _int_to_ro(int(int_part))
|
||||
if dec_part.startswith('0') and len(dec_part) > 1:
|
||||
dec_words = ' '.join(_int_to_ro(int(d)) for d in dec_part)
|
||||
else:
|
||||
dec_words = _int_to_ro(int(dec_part))
|
||||
return f"{int_words} virgulă {dec_words}"
|
||||
|
||||
|
||||
# ---------- Numbers ----------
|
||||
|
||||
_NUM_TOKEN = re.compile(r'(?<!\w)(\d+(?:\.\d+)?)(?!\w)')
|
||||
|
||||
|
||||
def expand_numbers_ro(text: str) -> str:
|
||||
"""Expand bare numeric tokens to Romanian words.
|
||||
|
||||
Only matches pure number tokens (no surrounding letters). Decimals
|
||||
use 'virgulă' separator. Currency-bound numbers should already be
|
||||
handled by expand_currency before this runs.
|
||||
"""
|
||||
def _sub(match: re.Match) -> str:
|
||||
token = match.group(1)
|
||||
if '.' in token:
|
||||
return _decimal_to_ro(token)
|
||||
return _int_to_ro(int(token))
|
||||
|
||||
return _NUM_TOKEN.sub(_sub, text)
|
||||
|
||||
|
||||
# ---------- Time ----------
|
||||
|
||||
_TIME_PATTERN = re.compile(r'(?<!\d)([01]?\d|2[0-3]):([0-5]?\d)(?!\d)')
|
||||
|
||||
|
||||
def _format_minutes_ro(n: int) -> str:
|
||||
"""Romanian-correct feminine forms for minute counts (0-59)."""
|
||||
if n == 1:
|
||||
return "un minut"
|
||||
if n == 2:
|
||||
return "două minute"
|
||||
if n < 20:
|
||||
return f"{_int_to_ro(n)} minute"
|
||||
last = n % 10
|
||||
rest = n - last
|
||||
if last == 0:
|
||||
return f"{_int_to_ro(n)} de minute"
|
||||
if last == 1:
|
||||
return f"{_int_to_ro(rest)} și una de minute"
|
||||
if last == 2:
|
||||
return f"{_int_to_ro(rest)} și două de minute"
|
||||
return f"{_int_to_ro(rest)} și {_int_to_ro(last)} de minute"
|
||||
|
||||
|
||||
def expand_time(text: str) -> str:
|
||||
"""Expand ``HH:MM`` clock times into colloquial Romanian.
|
||||
|
||||
23:09 -> "douăzeci și trei și nouă minute"
|
||||
23:00 -> "douăzeci și trei fix"
|
||||
"""
|
||||
def _sub(match: re.Match) -> str:
|
||||
h = int(match.group(1))
|
||||
m = int(match.group(2))
|
||||
hour_str = _int_to_ro(h)
|
||||
if m == 0:
|
||||
return f"{hour_str} fix"
|
||||
return f"{hour_str} și {_format_minutes_ro(m)}"
|
||||
|
||||
return _TIME_PATTERN.sub(_sub, text)
|
||||
|
||||
|
||||
# ---------- Currency ----------
|
||||
|
||||
_CURRENCY_MAIN = {
|
||||
'RON': ('leu', 'lei'),
|
||||
'USD': ('dolar', 'dolari'),
|
||||
'EUR': ('euro', 'euro'),
|
||||
'GBP': ('liră', 'lire'),
|
||||
}
|
||||
|
||||
_CURRENCY_SUB = {
|
||||
'RON': ('ban', 'bani'),
|
||||
'USD': ('cent', 'cenți'),
|
||||
'EUR': ('cent', 'cenți'),
|
||||
'GBP': ('penny', 'pence'),
|
||||
}
|
||||
|
||||
_CURRENCY_PATTERNS = [
|
||||
# RON suffix (case-insensitive: RON, ron, lei)
|
||||
(re.compile(r'(?<!\w)(\d+(?:\.\d+)?)\s+(?:RON|lei)\b', re.IGNORECASE), 'RON'),
|
||||
# Prefix currencies
|
||||
(re.compile(r'\$(\d+(?:\.\d+)?)'), 'USD'),
|
||||
(re.compile(r'€(\d+(?:\.\d+)?)'), 'EUR'),
|
||||
(re.compile(r'£(\d+(?:\.\d+)?)'), 'GBP'),
|
||||
]
|
||||
|
||||
|
||||
def _format_currency_unit(n: int, singular: str, plural: str) -> str:
|
||||
"""Format integer amount + currency noun with proper RO singular/plural
|
||||
and 'de' particle. Uses 'un' (article) for n=1, not 'unu' (cardinal).
|
||||
"""
|
||||
if n == 1:
|
||||
return f"un {singular}"
|
||||
word = _int_to_ro(n)
|
||||
if _needs_de(n):
|
||||
return f"{word} de {plural}"
|
||||
return f"{word} {plural}"
|
||||
|
||||
|
||||
def _format_currency(amount: str, code: str) -> str:
|
||||
main_sg, main_pl = _CURRENCY_MAIN[code]
|
||||
if '.' in amount:
|
||||
whole_s, frac_s = amount.split('.', 1)
|
||||
# Normalize fractional part to 2 digits so "12.5 RON" reads as
|
||||
# 50 bani, not 5 bani.
|
||||
if len(frac_s) == 1:
|
||||
frac_s = frac_s + '0'
|
||||
elif len(frac_s) > 2:
|
||||
frac_s = frac_s[:2]
|
||||
whole = int(whole_s)
|
||||
frac = int(frac_s)
|
||||
whole_part = _format_currency_unit(whole, main_sg, main_pl)
|
||||
if frac == 0:
|
||||
return whole_part
|
||||
sub_sg, sub_pl = _CURRENCY_SUB[code]
|
||||
frac_part = _format_currency_unit(frac, sub_sg, sub_pl)
|
||||
return f"{whole_part} și {frac_part}"
|
||||
return _format_currency_unit(int(amount), main_sg, main_pl)
|
||||
|
||||
|
||||
def expand_currency(text: str) -> str:
|
||||
"""Expand currency amounts into natural Romanian.
|
||||
|
||||
Recognises ``<n> RON`` / ``<n> lei`` suffix and ``$``, ``€``, ``£`` prefix
|
||||
forms with optional 2-decimal fractional part (treated as sub-unit:
|
||||
bani / cenți / pence).
|
||||
"""
|
||||
for pattern, code in _CURRENCY_PATTERNS:
|
||||
text = pattern.sub(lambda m, c=code: _format_currency(m.group(1), c), text)
|
||||
return text
|
||||
|
||||
|
||||
# ---------- Symbols ----------
|
||||
|
||||
def expand_symbols(text: str) -> str:
|
||||
"""Replace common symbols with their Romanian spoken form."""
|
||||
text = text.replace('%', ' la sută')
|
||||
text = text.replace('&', ' și ')
|
||||
text = text.replace('@', ' la ')
|
||||
text = text.replace('°', ' grade')
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
return text
|
||||
|
||||
|
||||
from tools.tts import sanitize_for_supertonic as sanitize_punctuation
|
||||
|
||||
|
||||
# ---------- Abbreviations ----------
|
||||
|
||||
# Longer patterns first so 'ș.a.m.d.' wins over 'ș.a.'
|
||||
_ABBREVIATIONS = [
|
||||
(re.compile(r'(?<!\w)[șş]\.a\.m\.d\.', re.IGNORECASE), 'și așa mai departe'),
|
||||
(re.compile(r'(?<!\w)[șş]\.a\.', re.IGNORECASE), 'și altele'),
|
||||
(re.compile(r'(?<!\w)etc\.', re.IGNORECASE), 'etcetera'),
|
||||
(re.compile(r'(?<!\w)dl\.', re.IGNORECASE), 'domnul'),
|
||||
(re.compile(r'(?<!\w)dna\.', re.IGNORECASE), 'doamna'),
|
||||
(re.compile(r'(?<!\w)nr\.', re.IGNORECASE), 'numărul'),
|
||||
]
|
||||
|
||||
|
||||
def expand_abbreviations(text: str) -> str:
|
||||
"""Expand Romanian abbreviations into their full forms."""
|
||||
for pattern, replacement in _ABBREVIATIONS:
|
||||
text = pattern.sub(replacement, text)
|
||||
return text
|
||||
|
||||
|
||||
# ---------- Top-level pipeline ----------
|
||||
|
||||
_MAX_WORDS = 200
|
||||
_TRUNCATE_SUFFIX = "Restul l-am scris în chat."
|
||||
|
||||
|
||||
def normalize_for_tts(text: str) -> str:
|
||||
"""Apply the full normalization pipeline and truncate to 200 words.
|
||||
|
||||
If the text exceeds 200 words, the first 200 are kept and the suffix
|
||||
"Restul l-am scris în chat." is appended so the listener knows the
|
||||
response continues in the text channel mirror.
|
||||
"""
|
||||
text = strip_markdown(text)
|
||||
text = sanitize_punctuation(text)
|
||||
text = expand_abbreviations(text)
|
||||
text = expand_time(text)
|
||||
text = expand_currency(text)
|
||||
text = expand_numbers_ro(text)
|
||||
text = expand_symbols(text)
|
||||
words = text.split()
|
||||
if len(words) > _MAX_WORDS:
|
||||
text = ' '.join(words[:_MAX_WORDS]) + f" {_TRUNCATE_SUFFIX}"
|
||||
return text.strip()
|
||||
580
src/voice/pipeline.py
Normal file
580
src/voice/pipeline.py
Normal file
@@ -0,0 +1,580 @@
|
||||
"""Central voice pipeline: VAD -> STT -> Claude -> TTS for Discord voice.
|
||||
|
||||
``VoiceSession`` binds per-call state — voice_client, TTS queue, transcript
|
||||
JSONL buffer, whitelist, presence — and exposes a single idempotent
|
||||
``cleanup()`` invoked from every exit path (user /voice leave, network
|
||||
disconnect, crash via ``__exit__``, auto-leave timer, user leaves channel).
|
||||
|
||||
``EchoVoiceSink`` is the discord-ext-voice-recv ``AudioSink`` subclass that
|
||||
runs in the voice_recv reader thread. It batches 20ms PCM packets into
|
||||
100ms windows for silero-vad inference, marks per-user speech timestamps,
|
||||
and on 800ms cumulative silence flushes the accumulated audio through
|
||||
faster-whisper. Hallucinated segments (``no_speech_prob > 0.6``) are
|
||||
dropped. Valid transcripts are scheduled onto the session's event loop
|
||||
via ``asyncio.run_coroutine_threadsafe``.
|
||||
|
||||
The bot's own ``user.id`` is filtered FIRST inside ``write()`` — load-bearing
|
||||
echo prevention so a future whitelist expansion (Bianca, etc.) never lets
|
||||
the bot transcribe itself.
|
||||
|
||||
See plan: ``src/voice/pipeline.py`` (Pas 5), Engineering decisions #4
|
||||
(VAD 100ms batched), #5 (cleanup centralizat), #7 (bot.user.id explicit
|
||||
guard).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from src.voice._discord_voice_adapter import AudioSink, VoiceData
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Discord delivers 48kHz s16le stereo PCM, 20ms per packet (3840 bytes).
|
||||
SAMPLE_RATE_DISCORD = 48000
|
||||
SAMPLE_RATE_WHISPER = 16000
|
||||
PACKET_MS = 20
|
||||
PACKET_BYTES = 3840 # 48000 Hz * 0.020 s * 2 channels * 2 bytes
|
||||
VAD_WINDOW_MS = 100 # batch 5 * 20ms packets per VAD inference (Decision #4)
|
||||
VAD_WINDOW_BYTES = PACKET_BYTES * (VAD_WINDOW_MS // PACKET_MS)
|
||||
VAD_THRESHOLD = 0.5
|
||||
SILENCE_FLUSH_MS = 800
|
||||
NO_SPEECH_DROP_THRESHOLD = 0.6
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
LOGS_DIR = PROJECT_ROOT / "logs"
|
||||
VOICE_METRICS_PATH = LOGS_DIR / "voice_metrics.jsonl"
|
||||
|
||||
|
||||
# ---------- Lazy model singletons ----------
|
||||
|
||||
_whisper_model: Any = None
|
||||
_whisper_lock = threading.Lock()
|
||||
_silero_model: Any = None
|
||||
_silero_get_timestamps: Any = None
|
||||
_silero_lock = threading.Lock()
|
||||
|
||||
|
||||
def _get_whisper_model() -> Any:
|
||||
"""Lazy-load faster-whisper ``small`` int8 with the spike-validated
|
||||
``cpu_threads=4`` (see ``tasks/voice-bench-results.md``)."""
|
||||
global _whisper_model
|
||||
if _whisper_model is not None:
|
||||
return _whisper_model
|
||||
with _whisper_lock:
|
||||
if _whisper_model is not None:
|
||||
return _whisper_model
|
||||
from faster_whisper import WhisperModel
|
||||
_whisper_model = WhisperModel(
|
||||
"small", device="cpu", compute_type="int8", cpu_threads=4,
|
||||
local_files_only=True,
|
||||
)
|
||||
return _whisper_model
|
||||
|
||||
|
||||
def _get_silero_vad():
|
||||
"""Lazy-load silero-vad. Returns ``(model, get_speech_timestamps)``."""
|
||||
global _silero_model, _silero_get_timestamps
|
||||
if _silero_model is not None:
|
||||
return _silero_model, _silero_get_timestamps
|
||||
with _silero_lock:
|
||||
if _silero_model is not None:
|
||||
return _silero_model, _silero_get_timestamps
|
||||
from silero_vad import get_speech_timestamps, load_silero_vad
|
||||
_silero_model = load_silero_vad()
|
||||
_silero_get_timestamps = get_speech_timestamps
|
||||
return _silero_model, _silero_get_timestamps
|
||||
|
||||
|
||||
# ---------- Audio helpers ----------
|
||||
|
||||
def _pcm48_stereo_to_16_mono(pcm: bytes) -> np.ndarray:
|
||||
"""Discord 48kHz s16le stereo bytes -> 16kHz mono float32 in [-1, 1].
|
||||
|
||||
Cheap downsample: average the two channels, then average every 3
|
||||
samples (48k / 3 = 16k). faster-whisper + silero-vad accept the
|
||||
resulting ``np.float32`` array directly.
|
||||
"""
|
||||
if not pcm:
|
||||
return np.zeros(0, dtype=np.float32)
|
||||
samples = np.frombuffer(pcm, dtype=np.int16)
|
||||
if samples.size % 2 != 0:
|
||||
samples = samples[:-1]
|
||||
stereo = samples.reshape(-1, 2)
|
||||
mono = stereo.mean(axis=1).astype(np.float32) / 32768.0
|
||||
if mono.size == 0:
|
||||
return mono
|
||||
trim = (mono.size // 3) * 3
|
||||
if trim == 0:
|
||||
return np.zeros(0, dtype=np.float32)
|
||||
mono = mono[:trim].reshape(-1, 3).mean(axis=1)
|
||||
return mono.astype(np.float32)
|
||||
|
||||
|
||||
# ---------- VoiceSession ----------
|
||||
|
||||
class VoiceSession:
|
||||
"""Per-voice-call state with a single idempotent ``cleanup()``."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
channel_id: int,
|
||||
guild_id: int,
|
||||
text_channel: Any,
|
||||
voice_client: Any,
|
||||
bot: Any,
|
||||
ttsq: Any,
|
||||
whitelist: Optional[set] = None,
|
||||
record_enabled: bool = False,
|
||||
mirror_enabled: bool = True,
|
||||
transcripts_jsonl_path: Optional[Path] = None,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
router_route_message: Optional[Callable] = None,
|
||||
):
|
||||
self.channel_id = int(channel_id)
|
||||
self.guild_id = int(guild_id)
|
||||
self.text_channel = text_channel
|
||||
self.voice_client = voice_client
|
||||
self.bot = bot
|
||||
self.ttsq = ttsq
|
||||
self.whitelist: set = set(whitelist or set())
|
||||
self.record_enabled = bool(record_enabled)
|
||||
self.mirror_enabled = bool(mirror_enabled)
|
||||
self.transcripts_jsonl_path = transcripts_jsonl_path
|
||||
self.loop = loop
|
||||
# Injection seam so tests can replace router.route_message without
|
||||
# mocking the whole module.
|
||||
if router_route_message is None:
|
||||
from src.router import route_message as _rm
|
||||
self._route_message = _rm
|
||||
else:
|
||||
self._route_message = router_route_message
|
||||
|
||||
self.last_activity_ts = time.monotonic()
|
||||
self._jsonl_fh = None
|
||||
self._lock = threading.Lock()
|
||||
self._cleaned_up = False
|
||||
self._lock_owner_thread: Optional[int] = None
|
||||
|
||||
# ----- context manager -----
|
||||
|
||||
def __enter__(self) -> "VoiceSession":
|
||||
self._lock.acquire()
|
||||
self._lock_owner_thread = threading.get_ident()
|
||||
if self.record_enabled and self.transcripts_jsonl_path is not None:
|
||||
try:
|
||||
self.transcripts_jsonl_path.parent.mkdir(
|
||||
parents=True, exist_ok=True,
|
||||
)
|
||||
self._jsonl_fh = open(
|
||||
self.transcripts_jsonl_path, "a",
|
||||
buffering=1, encoding="utf-8",
|
||||
)
|
||||
except OSError as e:
|
||||
log.warning("voice transcript open failed: %s", e)
|
||||
self._jsonl_fh = None
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
|
||||
self.cleanup("exit")
|
||||
return False # never suppress exceptions
|
||||
|
||||
# ----- cleanup (centralized, idempotent) -----
|
||||
|
||||
def cleanup(self, reason: str) -> None:
|
||||
"""Single drain path for ALL 5 exit scenarios. Safe to call twice."""
|
||||
if self._cleaned_up:
|
||||
return
|
||||
self._cleaned_up = True
|
||||
|
||||
# 1. Flush or discard JSONL transcript.
|
||||
if self._jsonl_fh is not None:
|
||||
try:
|
||||
self._jsonl_fh.flush()
|
||||
self._jsonl_fh.close()
|
||||
except Exception as e: # noqa: BLE001
|
||||
log.warning("voice transcript flush failed: %s", e)
|
||||
self._jsonl_fh = None
|
||||
if (not self.record_enabled
|
||||
and self.transcripts_jsonl_path is not None
|
||||
and self.transcripts_jsonl_path.exists()):
|
||||
try:
|
||||
self.transcripts_jsonl_path.unlink()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# 2. Restore bot presence (clear Listening activity).
|
||||
if self.bot is not None:
|
||||
try:
|
||||
change = getattr(self.bot, "change_presence", None)
|
||||
if callable(change):
|
||||
coro = change(activity=None)
|
||||
if asyncio.iscoroutine(coro):
|
||||
if self.loop is not None and self.loop.is_running():
|
||||
asyncio.run_coroutine_threadsafe(coro, self.loop)
|
||||
else:
|
||||
# Best-effort: close the coroutine so Python
|
||||
# doesn't emit "coroutine was never awaited".
|
||||
coro.close()
|
||||
except Exception as e: # noqa: BLE001
|
||||
log.warning("voice presence restore failed: %s", e)
|
||||
|
||||
# 3. Tear down the voice client.
|
||||
if self.voice_client is not None:
|
||||
try:
|
||||
self.voice_client.cleanup()
|
||||
except Exception as e: # noqa: BLE001
|
||||
log.warning("voice_client.cleanup failed: %s", e)
|
||||
|
||||
# 4. Stop the TTS queue worker.
|
||||
if self.ttsq is not None:
|
||||
try:
|
||||
self.ttsq.stop()
|
||||
except Exception as e: # noqa: BLE001
|
||||
log.warning("ttsq.stop failed: %s", e)
|
||||
|
||||
# 5. Release the session lock (held since __enter__).
|
||||
try:
|
||||
if self._lock.locked():
|
||||
self._lock.release()
|
||||
except RuntimeError:
|
||||
# Released from a different thread than acquired it — already
|
||||
# free for the next caller; nothing to do.
|
||||
pass
|
||||
|
||||
self._log_metric({"event": "cleanup", "reason": reason})
|
||||
|
||||
# ----- segment completion (scheduled from sink) -----
|
||||
|
||||
async def on_segment_done(
|
||||
self,
|
||||
speaker_id: int,
|
||||
text: str,
|
||||
no_speech_prob: float,
|
||||
) -> None:
|
||||
"""Mirror, persist, route to Claude, drive TTS via streaming callback."""
|
||||
if self._cleaned_up:
|
||||
return
|
||||
self.last_activity_ts = time.monotonic()
|
||||
speaker_name = self._resolve_speaker_name(speaker_id)
|
||||
|
||||
# Drop any TTS frames from the previous turn so a new utterance cuts off
|
||||
# stale Echo speech (barge-in) and never mixes with the new response.
|
||||
try:
|
||||
self.ttsq.clear()
|
||||
except Exception as e: # noqa: BLE001
|
||||
log.warning("ttsq.clear failed: %s", e)
|
||||
|
||||
# 1. Mirror to text channel (one Unicode 🎤 — exception per plan).
|
||||
if self.mirror_enabled and self.text_channel is not None:
|
||||
try:
|
||||
send = getattr(self.text_channel, "send", None)
|
||||
if callable(send):
|
||||
coro = send(f"\U0001f3a4 {speaker_name}: \"{text}\"")
|
||||
if asyncio.iscoroutine(coro):
|
||||
await coro
|
||||
except Exception as e: # noqa: BLE001
|
||||
log.warning("voice mirror send failed: %s", e)
|
||||
|
||||
# 2. Append to JSONL transcript buffer if recording.
|
||||
if self._jsonl_fh is not None:
|
||||
try:
|
||||
self._jsonl_fh.write(
|
||||
json.dumps({
|
||||
"ts": time.time(),
|
||||
"speaker_id": speaker_id,
|
||||
"speaker": speaker_name,
|
||||
"text": text,
|
||||
"no_speech_prob": no_speech_prob,
|
||||
}, ensure_ascii=False) + "\n"
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
log.warning("voice transcript write failed: %s", e)
|
||||
|
||||
block_count = [0]
|
||||
|
||||
def voice_stream_callback(block: str) -> None:
|
||||
"""Called once per Claude streamed text block — pushes to TTS."""
|
||||
block_count[0] += 1
|
||||
log.info("voice stream block #%d (%d chars): %r",
|
||||
block_count[0], len(block or ""), (block or "")[:80])
|
||||
try:
|
||||
self.ttsq.push_text(block)
|
||||
except Exception as e: # noqa: BLE001
|
||||
log.warning("ttsq.push_text failed: %s", e)
|
||||
|
||||
# Dispatch to Claude. send_message is sync subprocess, run on
|
||||
# a worker thread so the loop stays responsive for mirror/TTS.
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
self._route_message,
|
||||
str(self.channel_id),
|
||||
str(speaker_id),
|
||||
text,
|
||||
None, # model
|
||||
voice_stream_callback, # on_text
|
||||
"discord-voice", # adapter_name
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
log.error("route_message voice path failed: %s", e)
|
||||
|
||||
# ----- helpers -----
|
||||
|
||||
def _resolve_speaker_name(self, speaker_id: int) -> str:
|
||||
"""Best-effort display name lookup via the bot user cache."""
|
||||
try:
|
||||
if self.bot is not None and hasattr(self.bot, "get_user"):
|
||||
user = self.bot.get_user(speaker_id)
|
||||
if user is not None:
|
||||
name = getattr(user, "display_name", None) or getattr(
|
||||
user, "name", None,
|
||||
)
|
||||
if name:
|
||||
return str(name)
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
return str(speaker_id)
|
||||
|
||||
def _log_metric(self, payload: dict) -> None:
|
||||
"""Append a structured event to ``logs/voice_metrics.jsonl``."""
|
||||
event = {"ts": time.time(), "channel_id": self.channel_id, **payload}
|
||||
try:
|
||||
LOGS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
with open(VOICE_METRICS_PATH, "a", buffering=1, encoding="utf-8") as f:
|
||||
f.write(json.dumps(event, ensure_ascii=False) + "\n")
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
# ---------- EchoVoiceSink ----------
|
||||
|
||||
class EchoVoiceSink(AudioSink):
|
||||
"""PCM-in sink: per-user 20ms buffer -> 100ms VAD windows -> 800ms
|
||||
silence triggers Whisper STT -> schedules ``on_segment_done`` on the
|
||||
session loop.
|
||||
|
||||
Lives in the voice_recv reader thread; uses ``threading`` primitives
|
||||
only (no asyncio in the hot path).
|
||||
"""
|
||||
|
||||
def __init__(self, session: VoiceSession, bot_user_id: int):
|
||||
super().__init__()
|
||||
self.session = session
|
||||
self.bot_user_id = int(bot_user_id) if bot_user_id is not None else 0
|
||||
self.whitelist: set = set(session.whitelist or set())
|
||||
self._user_buffers: dict[int, bytearray] = {}
|
||||
self._packet_accum: dict[int, bytearray] = {}
|
||||
self._last_speech_ts: dict[int, float] = {}
|
||||
self._has_speech: dict[int, bool] = {}
|
||||
self._sink_lock = threading.Lock()
|
||||
# Diagnostics: log once-per-user when packets first arrive and when
|
||||
# VAD first detects speech. Cheap, but tells us exactly where the
|
||||
# chain breaks when "I spoke but Echo heard nothing" happens.
|
||||
self._first_packet_logged: set[int] = set()
|
||||
self._first_speech_logged: set[int] = set()
|
||||
# Background poller that triggers the silence flush even when Discord
|
||||
# DTX stops delivering RTP packets after the user stops speaking. Without
|
||||
# this, sink.write would stop firing and STT would never run on the
|
||||
# final utterance.
|
||||
self._poller_stop = threading.Event()
|
||||
self._poller_thread = threading.Thread(
|
||||
target=self._silence_flush_poller,
|
||||
name="echo-voice-flush-poller",
|
||||
daemon=True,
|
||||
)
|
||||
self._poller_thread.start()
|
||||
|
||||
def wants_opus(self) -> bool:
|
||||
return False
|
||||
|
||||
def cleanup(self) -> None:
|
||||
self._poller_stop.set()
|
||||
with self._sink_lock:
|
||||
self._user_buffers.clear()
|
||||
self._packet_accum.clear()
|
||||
self._last_speech_ts.clear()
|
||||
self._has_speech.clear()
|
||||
|
||||
def write(self, user, voice_data: VoiceData) -> None:
|
||||
# ---- FIRST GUARD (LOAD-BEARING): bot's own voice ---------------
|
||||
if user is None:
|
||||
return
|
||||
uid = int(getattr(user, "id", 0) or 0)
|
||||
if uid == 0:
|
||||
return
|
||||
if uid == self.bot_user_id:
|
||||
return
|
||||
|
||||
# ---- SECOND GUARD: whitelist filter ----------------------------
|
||||
if self.whitelist and uid not in self.whitelist:
|
||||
return
|
||||
|
||||
pcm = getattr(voice_data, "pcm", None)
|
||||
if not pcm:
|
||||
return
|
||||
|
||||
if uid not in self._first_packet_logged:
|
||||
self._first_packet_logged.add(uid)
|
||||
log.info("voice sink: first PCM packet from user %s (%d bytes)", uid, len(pcm))
|
||||
|
||||
window_pcm: Optional[bytes] = None
|
||||
pcm_for_stt: Optional[bytes] = None
|
||||
|
||||
try:
|
||||
with self._sink_lock:
|
||||
buf = self._user_buffers.setdefault(uid, bytearray())
|
||||
accum = self._packet_accum.setdefault(uid, bytearray())
|
||||
buf.extend(pcm)
|
||||
accum.extend(pcm)
|
||||
if len(accum) >= VAD_WINDOW_BYTES:
|
||||
window_pcm = bytes(accum[:VAD_WINDOW_BYTES])
|
||||
del accum[:VAD_WINDOW_BYTES]
|
||||
|
||||
if window_pcm is not None:
|
||||
if self._vad_detects_speech(window_pcm):
|
||||
if uid not in self._first_speech_logged:
|
||||
self._first_speech_logged.add(uid)
|
||||
log.info("voice sink: VAD detected speech from user %s", uid)
|
||||
with self._sink_lock:
|
||||
self._last_speech_ts[uid] = time.monotonic()
|
||||
self._has_speech[uid] = True
|
||||
|
||||
pcm_for_stt = self._take_flushable_pcm(uid)
|
||||
if pcm_for_stt:
|
||||
self._flush_to_stt(uid, pcm_for_stt)
|
||||
except Exception as e: # noqa: BLE001
|
||||
log.warning("EchoVoiceSink.write failed: %s", e)
|
||||
|
||||
def _take_flushable_pcm(self, uid: int) -> Optional[bytes]:
|
||||
"""If user `uid` has buffered speech that's been silent ≥SILENCE_FLUSH_MS,
|
||||
consume the buffer and return it. Otherwise return None."""
|
||||
with self._sink_lock:
|
||||
if not self._has_speech.get(uid):
|
||||
return None
|
||||
last = self._last_speech_ts.get(uid, 0.0)
|
||||
silence_ms = (time.monotonic() - last) * 1000.0
|
||||
if silence_ms < SILENCE_FLUSH_MS:
|
||||
return None
|
||||
pcm = bytes(self._user_buffers.get(uid, b""))
|
||||
self._user_buffers[uid] = bytearray()
|
||||
self._packet_accum[uid] = bytearray()
|
||||
self._has_speech[uid] = False
|
||||
return pcm if pcm else None
|
||||
|
||||
def _silence_flush_poller(self) -> None:
|
||||
"""Background tick: Discord DTX stops sending RTP packets when the user
|
||||
goes silent, so the inline flush check in `write()` never fires for the
|
||||
last utterance. Poll every 200ms so the trailing audio actually reaches
|
||||
Whisper."""
|
||||
while not self._poller_stop.wait(0.2):
|
||||
try:
|
||||
with self._sink_lock:
|
||||
pending = [uid for uid, has in self._has_speech.items() if has]
|
||||
for uid in pending:
|
||||
pcm = self._take_flushable_pcm(uid)
|
||||
if pcm:
|
||||
self._flush_to_stt(uid, pcm)
|
||||
except Exception as e: # noqa: BLE001
|
||||
log.warning("silence flush poller iter failed: %s", e)
|
||||
|
||||
# ----- VAD -----
|
||||
|
||||
def _vad_detects_speech(self, pcm48_stereo: bytes) -> bool:
|
||||
"""Run silero-vad on a 100ms window. silero-vad v5+ requires exactly
|
||||
512 samples per call at 16kHz, so we slice the window into 512-sample
|
||||
chunks and return True if any chunk crosses the threshold."""
|
||||
try:
|
||||
mono16 = _pcm48_stereo_to_16_mono(pcm48_stereo)
|
||||
if mono16.size == 0:
|
||||
return False
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
rms = float(np.sqrt(np.mean(mono16.astype(np.float64) ** 2)))
|
||||
return rms > 0.02
|
||||
model, _ = _get_silero_vad()
|
||||
chunk = 512 # silero-vad v5+ hard requirement at 16kHz
|
||||
max_prob = 0.0
|
||||
with torch.no_grad():
|
||||
for start in range(0, mono16.size - chunk + 1, chunk):
|
||||
seg = mono16[start:start + chunk]
|
||||
p = float(model(torch.from_numpy(seg), SAMPLE_RATE_WHISPER).item())
|
||||
if p > max_prob:
|
||||
max_prob = p
|
||||
if p >= VAD_THRESHOLD:
|
||||
return True
|
||||
return False
|
||||
except Exception as e: # noqa: BLE001
|
||||
log.debug("VAD inference failed: %s", e)
|
||||
return False
|
||||
|
||||
# ----- STT flush -----
|
||||
|
||||
def _flush_to_stt(self, user_id: int, pcm48_stereo: bytes) -> None:
|
||||
"""Downsample, Whisper-transcribe RO, drop hallucinations, dispatch."""
|
||||
try:
|
||||
mono16 = _pcm48_stereo_to_16_mono(pcm48_stereo)
|
||||
if mono16.size == 0:
|
||||
return
|
||||
model = _get_whisper_model()
|
||||
segments, _info = model.transcribe(
|
||||
mono16, language="ro", beam_size=5,
|
||||
initial_prompt=(
|
||||
"Echo Core, asistent personal AI românesc al lui Marius. "
|
||||
"Conversație colocvială în română."
|
||||
),
|
||||
condition_on_previous_text=False,
|
||||
)
|
||||
text_parts: list[str] = []
|
||||
worst_no_speech = 0.0
|
||||
for seg in segments:
|
||||
no_sp = float(getattr(seg, "no_speech_prob", 0.0) or 0.0)
|
||||
if no_sp > worst_no_speech:
|
||||
worst_no_speech = no_sp
|
||||
if no_sp > NO_SPEECH_DROP_THRESHOLD:
|
||||
continue
|
||||
seg_text = (getattr(seg, "text", "") or "").strip()
|
||||
if seg_text:
|
||||
text_parts.append(seg_text)
|
||||
if not text_parts:
|
||||
return
|
||||
text = " ".join(text_parts).strip()
|
||||
if not text:
|
||||
return
|
||||
self._schedule_segment_done(user_id, text, worst_no_speech)
|
||||
except Exception as e: # noqa: BLE001
|
||||
log.warning("Whisper transcribe failed: %s", e)
|
||||
|
||||
def _schedule_segment_done(
|
||||
self, user_id: int, text: str, no_speech_prob: float,
|
||||
) -> None:
|
||||
loop = self.session.loop
|
||||
if loop is None or not loop.is_running():
|
||||
log.debug("voice session loop missing — dropping segment")
|
||||
return
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.session.on_segment_done(user_id, text, no_speech_prob),
|
||||
loop,
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
log.warning("voice segment dispatch failed: %s", e)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"VoiceSession",
|
||||
"EchoVoiceSink",
|
||||
"SILENCE_FLUSH_MS",
|
||||
"VAD_THRESHOLD",
|
||||
"VAD_WINDOW_MS",
|
||||
"NO_SPEECH_DROP_THRESHOLD",
|
||||
]
|
||||
333
src/voice/tts_stream.py
Normal file
333
src/voice/tts_stream.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""Streaming TTS with clause-level chunking for Discord voice mode.
|
||||
|
||||
A worker thread consumes text -> produces 20ms PCM frames on a queue.Queue.
|
||||
``EchoStreamingAudioSource`` pulls frames into Discord's audio thread so a
|
||||
single ``voice_client.play()`` call lasts the whole turn (eliminates the
|
||||
RTP gap between successive ``play()`` calls and the race with barge-in
|
||||
``stop()``). See plan: src/voice/tts_stream.py (Pas 6 / Lane TTS),
|
||||
Engineering decisions #6, #8, #15.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import logging
|
||||
import queue
|
||||
import re
|
||||
import subprocess
|
||||
import threading
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from typing import Iterator, List, Optional
|
||||
|
||||
import discord
|
||||
|
||||
from src.voice.normalize import normalize_for_tts
|
||||
from tools.tts import synthesize
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Discord wants 20ms of 16-bit 48kHz stereo PCM per frame.
|
||||
# 48000 Hz * 0.020 s * 2 channels * 2 bytes = 3840 bytes.
|
||||
FRAME_BYTES = 3840
|
||||
TARGET_SAMPLE_RATE = 48000
|
||||
TARGET_CHANNELS = 2
|
||||
TARGET_SAMPLE_WIDTH = 2
|
||||
|
||||
# Sentinel pushed onto the text queue to ask the worker to exit cleanly.
|
||||
_POISON = object()
|
||||
|
||||
|
||||
# ---------- Clause segmentation ----------
|
||||
|
||||
# Split at Romanian sentence punctuation followed by whitespace. The
|
||||
# trailing whitespace requirement protects mid-number (1.000), mid-decimal
|
||||
# (12.5), and mid-abbreviation (M.D.) tokens, since none of those have a
|
||||
# space right after the inner punctuation.
|
||||
_CLAUSE_SPLIT = re.compile(r'(?<=[,;:.!?])\s+')
|
||||
|
||||
|
||||
def clause_segments(text: str, min_words: int = 8) -> Iterator[str]:
|
||||
"""Yield text in clause-sized chunks for streaming TTS.
|
||||
|
||||
Splits at ``, ; : . ! ?`` boundaries (only when the punctuation is
|
||||
followed by whitespace, so numbers / decimals / abbreviations stay
|
||||
intact). Short clauses are buffered and merged with the next one
|
||||
until the accumulated chunk has at least ``min_words`` words. The
|
||||
final remainder is always yielded, even if it's shorter than
|
||||
``min_words`` -- otherwise the tail of the response would never
|
||||
reach the TTS.
|
||||
"""
|
||||
if text is None:
|
||||
return
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return
|
||||
pieces = [p.strip() for p in _CLAUSE_SPLIT.split(text) if p and p.strip()]
|
||||
if not pieces:
|
||||
return
|
||||
buffer = ''
|
||||
for clause in pieces:
|
||||
buffer = (buffer + ' ' + clause).strip() if buffer else clause
|
||||
if len(buffer.split()) >= min_words:
|
||||
yield buffer
|
||||
buffer = ''
|
||||
if buffer:
|
||||
yield buffer
|
||||
|
||||
|
||||
# ---------- WAV -> PCM frame conversion ----------
|
||||
|
||||
def _ffmpeg_resample(wav_bytes: bytes) -> bytes:
|
||||
"""Convert any WAV payload to raw 48kHz stereo s16le PCM via ffmpeg.
|
||||
|
||||
ffmpeg is already an Echo Core hard dependency (heartbeat, video
|
||||
transcription). Using a stdin/stdout pipe keeps the synth tempfile
|
||||
short-lived and avoids extra disk traffic.
|
||||
"""
|
||||
proc = subprocess.run(
|
||||
[
|
||||
'ffmpeg', '-hide_banner', '-loglevel', 'error',
|
||||
'-i', 'pipe:0',
|
||||
'-f', 's16le',
|
||||
'-ar', str(TARGET_SAMPLE_RATE),
|
||||
'-ac', str(TARGET_CHANNELS),
|
||||
'-acodec', 'pcm_s16le',
|
||||
'pipe:1',
|
||||
],
|
||||
input=wav_bytes,
|
||||
capture_output=True,
|
||||
check=False,
|
||||
)
|
||||
if proc.returncode != 0:
|
||||
err = proc.stderr.decode('utf-8', errors='replace')[:200]
|
||||
raise RuntimeError(f"ffmpeg resample failed (rc={proc.returncode}): {err}")
|
||||
return proc.stdout
|
||||
|
||||
|
||||
def _is_target_format(wav_bytes: bytes) -> bool:
|
||||
"""Quick check whether the WAV already matches Discord's PCM format."""
|
||||
try:
|
||||
with wave.open(io.BytesIO(wav_bytes), 'rb') as w:
|
||||
return (
|
||||
w.getframerate() == TARGET_SAMPLE_RATE
|
||||
and w.getnchannels() == TARGET_CHANNELS
|
||||
and w.getsampwidth() == TARGET_SAMPLE_WIDTH
|
||||
and w.getcomptype() == 'NONE'
|
||||
)
|
||||
except (wave.Error, EOFError):
|
||||
return False
|
||||
|
||||
|
||||
def _extract_pcm_native(wav_bytes: bytes) -> bytes:
|
||||
"""Strip the WAV header and return raw PCM (target format assumed)."""
|
||||
with wave.open(io.BytesIO(wav_bytes), 'rb') as w:
|
||||
return w.readframes(w.getnframes())
|
||||
|
||||
|
||||
def wav_to_pcm_20ms_frames(wav_bytes: bytes) -> List[bytes]:
|
||||
"""Parse a WAV blob, normalize to 48kHz s16le stereo, slice into 20ms frames.
|
||||
|
||||
The final frame is zero-padded to a full 3840 bytes so Discord's audio
|
||||
thread always reads whole frames. Empty input yields an empty list.
|
||||
"""
|
||||
if not wav_bytes:
|
||||
return []
|
||||
pcm = _extract_pcm_native(wav_bytes) if _is_target_format(wav_bytes) else _ffmpeg_resample(wav_bytes)
|
||||
if not pcm:
|
||||
return []
|
||||
frames: List[bytes] = []
|
||||
for offset in range(0, len(pcm), FRAME_BYTES):
|
||||
chunk = pcm[offset:offset + FRAME_BYTES]
|
||||
if len(chunk) < FRAME_BYTES:
|
||||
chunk = chunk + b'\x00' * (FRAME_BYTES - len(chunk))
|
||||
frames.append(chunk)
|
||||
return frames
|
||||
|
||||
|
||||
# ---------- TTS worker queue ----------
|
||||
|
||||
class TTSQueue:
|
||||
"""Worker thread: text in -> 20ms PCM frames out.
|
||||
|
||||
Usage::
|
||||
|
||||
ttsq = TTSQueue(voice_id="M2", lang="ro")
|
||||
ttsq.start()
|
||||
ttsq.push_text("salut Marius, ce mai faci?")
|
||||
voice_client.play(EchoStreamingAudioSource(ttsq))
|
||||
# ... barge-in detected:
|
||||
ttsq.clear()
|
||||
# ... session over:
|
||||
ttsq.stop()
|
||||
"""
|
||||
|
||||
def __init__(self, voice_id: str = "M2", lang: str = "ro"):
|
||||
self.voice_id = voice_id
|
||||
self.lang = lang
|
||||
self._text_queue: queue.Queue = queue.Queue()
|
||||
self._pcm_queue: queue.Queue = queue.Queue()
|
||||
self._worker_thread: Optional[threading.Thread] = None
|
||||
self._stop_event = threading.Event()
|
||||
|
||||
# --- lifecycle ---
|
||||
|
||||
def start(self) -> None:
|
||||
if self._worker_thread is not None and self._worker_thread.is_alive():
|
||||
return
|
||||
self._stop_event.clear()
|
||||
self._worker_thread = threading.Thread(
|
||||
target=self._worker_loop,
|
||||
name=f"tts-worker-{self.voice_id}",
|
||||
daemon=True,
|
||||
)
|
||||
self._worker_thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal the worker to exit, drain queues, join (timeout 5s)."""
|
||||
self._stop_event.set()
|
||||
# Wake the worker if it's blocked on get(timeout=...).
|
||||
self._text_queue.put(_POISON)
|
||||
thread = self._worker_thread
|
||||
if thread is not None:
|
||||
thread.join(timeout=5.0)
|
||||
self._worker_thread = None
|
||||
self._drain(self._text_queue)
|
||||
self._drain(self._pcm_queue)
|
||||
|
||||
# --- producer side ---
|
||||
|
||||
def push_text(self, text: str) -> None:
|
||||
"""Normalize, segment into clauses, enqueue each clause for synthesis."""
|
||||
if not text:
|
||||
return
|
||||
cleaned = normalize_for_tts(text)
|
||||
n = 0
|
||||
for clause in clause_segments(cleaned):
|
||||
clause = clause.strip()
|
||||
if clause:
|
||||
self._text_queue.put(clause)
|
||||
n += 1
|
||||
log.info("ttsq.push_text: input %d chars → %d clauses queued", len(text), n)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Drop everything pending (used for barge-in)."""
|
||||
self._drain(self._text_queue)
|
||||
self._drain(self._pcm_queue)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return self._text_queue.empty() and self._pcm_queue.empty()
|
||||
|
||||
# --- consumer side (called by EchoStreamingAudioSource) ---
|
||||
|
||||
def get_frame_nowait(self) -> Optional[bytes]:
|
||||
"""Return the next PCM frame if available, else None — no blocking.
|
||||
|
||||
Blocking inside the player's read() loop wrecks Discord's 20ms cadence
|
||||
and the client interprets the stream as stuttering / out-of-order.
|
||||
"""
|
||||
try:
|
||||
return self._pcm_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
return None
|
||||
|
||||
# --- internals ---
|
||||
|
||||
@staticmethod
|
||||
def _drain(q: queue.Queue) -> None:
|
||||
while True:
|
||||
try:
|
||||
q.get_nowait()
|
||||
except queue.Empty:
|
||||
return
|
||||
|
||||
def _worker_loop(self) -> None:
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
item = self._text_queue.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
continue
|
||||
if item is _POISON:
|
||||
break
|
||||
if not isinstance(item, str):
|
||||
continue
|
||||
preview = item[:60]
|
||||
try:
|
||||
result = synthesize(item, voice=self.voice_id, lang=self.lang)
|
||||
except Exception as e:
|
||||
log.warning("TTS synth raised for %r: %s", preview, e)
|
||||
continue
|
||||
if not result.get('ok'):
|
||||
log.warning("TTS synth not ok for %r: %s", preview, result.get('error'))
|
||||
continue
|
||||
path = result.get('path')
|
||||
if not path:
|
||||
log.warning("TTS synth ok but no path for %r", preview)
|
||||
continue
|
||||
wav_bytes = b''
|
||||
try:
|
||||
wav_bytes = Path(path).read_bytes()
|
||||
except OSError as e:
|
||||
log.warning("TTS WAV read failed for %r: %s", preview, e)
|
||||
finally:
|
||||
try:
|
||||
Path(path).unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
if not wav_bytes:
|
||||
continue
|
||||
try:
|
||||
frames = wav_to_pcm_20ms_frames(wav_bytes)
|
||||
except RuntimeError as e:
|
||||
log.warning("TTS WAV-to-PCM failed for %r: %s", preview, e)
|
||||
continue
|
||||
if not frames:
|
||||
log.warning("TTS WAV-to-PCM produced 0 frames for %r", preview)
|
||||
continue
|
||||
for frame in frames:
|
||||
if self._stop_event.is_set():
|
||||
return
|
||||
self._pcm_queue.put(frame)
|
||||
log.info("TTS pushed %d frames (%.1fs) for %r",
|
||||
len(frames), len(frames) * 0.02, preview)
|
||||
|
||||
|
||||
# ---------- Discord audio source ----------
|
||||
|
||||
class EchoStreamingAudioSource(discord.AudioSource):
|
||||
"""Pull PCM frames from a ``TTSQueue`` into Discord's audio thread.
|
||||
|
||||
A single ``voice_client.play(EchoStreamingAudioSource(ttsq))`` call
|
||||
spans the whole session. When the TTS queue is empty, ``read()``
|
||||
returns a 20ms silence frame to keep the player alive — otherwise
|
||||
Discord would interpret an empty return as end-of-stream and stop
|
||||
the player, so real TTS frames pushed later would be silently
|
||||
discarded. The player is explicitly terminated only via
|
||||
``cleanup()`` (called on voice session teardown).
|
||||
"""
|
||||
|
||||
# 20ms of s16le stereo at 48kHz silence (960 samples × 2 channels × 2 bytes).
|
||||
_SILENCE_FRAME = b'\x00' * (960 * 2 * 2)
|
||||
|
||||
def __init__(self, ttsq: TTSQueue):
|
||||
self._ttsq = ttsq
|
||||
self._closed = False
|
||||
|
||||
def read(self) -> bytes:
|
||||
if self._closed:
|
||||
return b''
|
||||
frame = self._ttsq.get_frame_nowait()
|
||||
if frame is None:
|
||||
return self._SILENCE_FRAME
|
||||
return frame
|
||||
|
||||
def is_opus(self) -> bool:
|
||||
return False
|
||||
|
||||
def cleanup(self) -> None:
|
||||
self._closed = True
|
||||
try:
|
||||
self._ttsq.clear()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -17,6 +17,13 @@ Lecții capturate din corectările lui Marius. Citește acest fișier la începu
|
||||
|
||||
<!-- Lecțiile se adaugă mai jos, cele mai noi sus. -->
|
||||
|
||||
## Supertonic rejectează ghilimelele curly (Unicode) cu HTTP 500
|
||||
**Data:** 2026-05-27
|
||||
**Context:** Marius a dat o comandă audio pe Discord cu un URL, iar răspunsul lui Claude conținea `„foo"` (ghilimele românești curly). Supertonic a returnat `HTTP 500: synthesis failed: Found 1 unsupported character(s): ['„']` și răspunsul nu s-a mai auzit. Fără retry logic vizibil în UX — pur și simplu tace.
|
||||
**Greșeala:** Am presupus că `normalize_for_tts` produce text deja "TTS-safe" pentru Supertonic. În realitate `strip_markdown` păstrează ghilimelele Unicode (`„` U+201E, `"` U+201D, `—` U+2014, `…` U+2026, etc.) pe care Supertonic le refuză.
|
||||
**Regula:** Înainte de orice apel HTTP la Supertonic, **sanitizează punctuația Unicode** la echivalentele ASCII (`„` `"` `"` → `"`, `'` `'` `‚` → `'`, `–` `—` → `-`, `…` → `...`, `«` `»` → `"`). Funcția `sanitize_punctuation` în `src/voice/normalize.py` face asta și e apelată chiar după `strip_markdown` în pipeline. Dacă apar caractere noi care crapă Supertonic (ex: simboluri matematice, săgeți), adaugă-le în `_TTS_PUNCT_MAP`.
|
||||
**Când se aplică:** Orice cod care trimite text la Supertonic (`tools/tts.py`, `src/voice/tts_stream.py`). Inclusiv testare manuală cu `curl` — folosește text românesc realistic (include `„foo"`, em-dash `—`, ellipsis `…`).
|
||||
|
||||
## Mai multe threads ≠ mai rapid — fitează `cpu_threads` pe physical cores, nu logical
|
||||
**Data:** 2026-05-27
|
||||
**Context:** Benchmark `tools/voice_bench.py` pentru faster-whisper `small` int8 pe i7-6700T (4 physical / 8 logical cores). Marius a urcat VM-ul de la 2 → 4 → 6 cores online, așteptând că mai multe = mai rapid.
|
||||
|
||||
307
tests/test_claude_session_mutex.py
Normal file
307
tests/test_claude_session_mutex.py
Normal file
@@ -0,0 +1,307 @@
|
||||
"""Regression-critical tests for per-channel mutex in src/claude_session.py.
|
||||
|
||||
Three scenarios from the eng-review test plan (2026-05-27):
|
||||
|
||||
1. Concurrent `send_message` calls on the SAME channel_id serialize —
|
||||
the second waits for the first to finish before its subprocess runs.
|
||||
2. Concurrent `send_message` calls on DIFFERENT channel_ids run in parallel
|
||||
— independent channels never block each other.
|
||||
3. Acquisition contract is documented and consistent: the lock is acquired
|
||||
blocking (no acquire timeout), which means a hung subprocess on
|
||||
channel X delays subsequent X messages but never X' (X != X'). This
|
||||
test pins that behavior so future refactors must preserve it.
|
||||
|
||||
The mutex is `threading.Lock`, NOT `asyncio.Lock`, because `send_message`
|
||||
is a sync function typically dispatched via `asyncio.to_thread` from
|
||||
async adapters. asyncio.Lock would serialize coroutines only — not the
|
||||
subprocess invocation. See plan section "Engineering decisions" #2.
|
||||
"""
|
||||
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src import claude_session
|
||||
from src.claude_session import (
|
||||
_get_session_lock,
|
||||
_session_locks,
|
||||
send_message,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_session_locks():
|
||||
"""Each test starts with a fresh lock map so we don't share state."""
|
||||
_session_locks.clear()
|
||||
yield
|
||||
_session_locks.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_sessions(tmp_path, monkeypatch):
|
||||
"""Isolated active.json per test — keeps real session state untouched."""
|
||||
sessions_dir = tmp_path / "sessions"
|
||||
sessions_dir.mkdir()
|
||||
sf = sessions_dir / "active.json"
|
||||
sf.write_text("{}")
|
||||
monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir)
|
||||
monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf)
|
||||
return sf
|
||||
|
||||
|
||||
def _slow_run_claude(sleep_seconds: float, in_critical: threading.Event,
|
||||
concurrent_seen: threading.Event):
|
||||
"""Build a fake `_run_claude` that signals when inside the critical section.
|
||||
|
||||
The fake holds the simulated subprocess for `sleep_seconds`. Any other
|
||||
invocation that overlaps will set `concurrent_seen` — the mutex test
|
||||
asserts this NEVER happens for the same channel_id.
|
||||
"""
|
||||
state = {"active": 0, "lock": threading.Lock()}
|
||||
|
||||
def fake(cmd, timeout, on_text=None, cwd=None):
|
||||
with state["lock"]:
|
||||
state["active"] += 1
|
||||
if state["active"] > 1:
|
||||
concurrent_seen.set()
|
||||
in_critical.set()
|
||||
time.sleep(sleep_seconds)
|
||||
with state["lock"]:
|
||||
state["active"] -= 1
|
||||
return {
|
||||
"result": "Hello from Claude!",
|
||||
"session_id": "sess-abc-123",
|
||||
"usage": {"input_tokens": 10, "output_tokens": 5},
|
||||
"total_cost_usd": 0.001,
|
||||
"cost_usd": 0.001,
|
||||
"duration_ms": int(sleep_seconds * 1000),
|
||||
"num_turns": 1,
|
||||
"intermediate_count": 0,
|
||||
"subtype": "success",
|
||||
"is_error": False,
|
||||
}
|
||||
|
||||
return fake
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario 1 — same channel serializes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSameChannelSerializes:
|
||||
def test_two_concurrent_calls_same_channel_run_one_at_a_time(
|
||||
self, temp_sessions
|
||||
):
|
||||
"""Two parallel send_message on the SAME channel_id never overlap.
|
||||
|
||||
We instrument `_run_claude` to signal whenever more than one
|
||||
invocation is concurrently inside it. The mutex MUST prevent that.
|
||||
"""
|
||||
in_critical = threading.Event()
|
||||
concurrent_seen = threading.Event()
|
||||
slow = _slow_run_claude(0.25, in_critical, concurrent_seen)
|
||||
|
||||
with patch.object(claude_session, "_run_claude", side_effect=slow):
|
||||
start = time.monotonic()
|
||||
with ThreadPoolExecutor(max_workers=2) as pool:
|
||||
futures = [
|
||||
pool.submit(send_message, "ch-same", f"msg-{i}")
|
||||
for i in range(2)
|
||||
]
|
||||
results = [f.result(timeout=10) for f in futures]
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
assert not concurrent_seen.is_set(), (
|
||||
"Two send_message calls on the same channel ran concurrently — "
|
||||
"mutex did not serialize them."
|
||||
)
|
||||
assert all(r == "Hello from Claude!" for r in results)
|
||||
# Two serial 0.25s subprocesses must take at least ~0.5s total
|
||||
# (we allow a generous floor — schedulers can be slow).
|
||||
assert elapsed >= 0.45, f"Expected serialized ~0.5s, got {elapsed:.3f}s"
|
||||
|
||||
def test_lock_is_reentrant_per_channel_dict(self, temp_sessions):
|
||||
"""`_get_session_lock` returns the SAME lock object for the same channel."""
|
||||
lock_a1 = _get_session_lock("channel-A")
|
||||
lock_a2 = _get_session_lock("channel-A")
|
||||
lock_b = _get_session_lock("channel-B")
|
||||
assert lock_a1 is lock_a2
|
||||
assert lock_a1 is not lock_b
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario 2 — different channels parallel
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDifferentChannelsParallel:
|
||||
def test_two_concurrent_calls_different_channels_run_in_parallel(
|
||||
self, temp_sessions
|
||||
):
|
||||
"""Different channels MUST NOT block each other.
|
||||
|
||||
We measure elapsed wall-clock: two 0.4s subprocesses on different
|
||||
channels should finish in ~0.4s (parallel), NOT ~0.8s (serialized).
|
||||
"""
|
||||
in_critical = threading.Event()
|
||||
# `concurrent_seen` is OK to fire here — we WANT them to overlap.
|
||||
concurrent_seen = threading.Event()
|
||||
slow = _slow_run_claude(0.4, in_critical, concurrent_seen)
|
||||
|
||||
with patch.object(claude_session, "_run_claude", side_effect=slow):
|
||||
start = time.monotonic()
|
||||
with ThreadPoolExecutor(max_workers=2) as pool:
|
||||
f1 = pool.submit(send_message, "ch-A", "msg-A")
|
||||
f2 = pool.submit(send_message, "ch-B", "msg-B")
|
||||
results = [f1.result(timeout=10), f2.result(timeout=10)]
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
assert all(r == "Hello from Claude!" for r in results)
|
||||
# Parallel execution: total time should be close to 0.4s, well under
|
||||
# 0.7s (would mean serialization). 0.65s ceiling allows for GIL +
|
||||
# scheduler jitter on a busy test box.
|
||||
assert elapsed < 0.65, (
|
||||
f"Different channels appear serialized: elapsed {elapsed:.3f}s "
|
||||
f"(expected ~0.4s parallel, <0.65s ceiling)"
|
||||
)
|
||||
assert concurrent_seen.is_set(), (
|
||||
"Different channels did not overlap — mutex is too coarse "
|
||||
"(should be per-channel, not global)."
|
||||
)
|
||||
|
||||
def test_three_channels_all_overlap(self, temp_sessions):
|
||||
"""Stress: three concurrent channels all run in parallel."""
|
||||
in_critical = threading.Event()
|
||||
concurrent_seen = threading.Event()
|
||||
slow = _slow_run_claude(0.3, in_critical, concurrent_seen)
|
||||
|
||||
with patch.object(claude_session, "_run_claude", side_effect=slow):
|
||||
start = time.monotonic()
|
||||
with ThreadPoolExecutor(max_workers=3) as pool:
|
||||
futures = [
|
||||
pool.submit(send_message, f"ch-{i}", f"msg-{i}")
|
||||
for i in range(3)
|
||||
]
|
||||
for f in as_completed(futures, timeout=10):
|
||||
assert f.result() == "Hello from Claude!"
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
# 3 × 0.3s in parallel ≈ 0.3s; serial would be ~0.9s.
|
||||
assert elapsed < 0.6, (
|
||||
f"Three channels serialized: {elapsed:.3f}s (expected <0.6s)"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario 3 — acquisition behavior documented and consistent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAcquisitionBehavior:
|
||||
"""Pin the chosen acquisition policy: blocking, no timeout.
|
||||
|
||||
Project style is to bound subprocess execution via `timeout` (default
|
||||
5 min) rather than fail-fast on lock acquire. Reasons:
|
||||
|
||||
- Adapter callers (Discord/Telegram/voice) already serialize work via
|
||||
asyncio.to_thread; queue depth is naturally bounded.
|
||||
- A non-blocking acquire would surface a timing error to the user
|
||||
("busy, try again") for an entirely transient and self-resolving
|
||||
condition. Blocking gives FIFO-ish ordering with simple semantics.
|
||||
- If a subprocess truly hangs past `timeout`, _run_claude raises
|
||||
TimeoutError → the held lock releases (via `with`) → queued
|
||||
callers proceed.
|
||||
|
||||
This test pins that: a second caller waits and eventually proceeds; it
|
||||
does not raise an exception on contention.
|
||||
"""
|
||||
|
||||
def test_contested_acquire_blocks_then_proceeds(self, temp_sessions):
|
||||
in_critical = threading.Event()
|
||||
concurrent_seen = threading.Event()
|
||||
slow = _slow_run_claude(0.3, in_critical, concurrent_seen)
|
||||
|
||||
results: list[str | BaseException] = []
|
||||
|
||||
def run(label: str):
|
||||
try:
|
||||
results.append(send_message("ch-contend", label))
|
||||
except BaseException as e:
|
||||
results.append(e)
|
||||
|
||||
with patch.object(claude_session, "_run_claude", side_effect=slow):
|
||||
t1 = threading.Thread(target=run, args=("first",))
|
||||
t1.start()
|
||||
# Wait until the first call is inside the critical section so
|
||||
# the second is GUARANTEED to contend on the lock.
|
||||
assert in_critical.wait(timeout=2.0), "first call never entered"
|
||||
in_critical.clear()
|
||||
t2 = threading.Thread(target=run, args=("second",))
|
||||
t2.start()
|
||||
t1.join(timeout=5.0)
|
||||
t2.join(timeout=5.0)
|
||||
|
||||
assert len(results) == 2
|
||||
# Both must return the canned response — no exception, no error.
|
||||
assert all(r == "Hello from Claude!" for r in results), (
|
||||
f"Contended acquire surfaced an error instead of blocking: {results}"
|
||||
)
|
||||
# Critical-section overlap check: contended calls MUST serialize.
|
||||
assert not concurrent_seen.is_set(), (
|
||||
"Contended same-channel calls ran concurrently — mutex broken."
|
||||
)
|
||||
|
||||
def test_lock_released_on_subprocess_exception(self, temp_sessions):
|
||||
"""If `_run_claude` raises, the lock MUST be released so the next
|
||||
caller can proceed (otherwise a single error deadlocks the channel
|
||||
forever)."""
|
||||
|
||||
call_count = {"n": 0}
|
||||
|
||||
def flaky(cmd, timeout, on_text=None, cwd=None):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
raise RuntimeError("simulated subprocess crash")
|
||||
return {
|
||||
"result": "Hello from Claude!",
|
||||
"session_id": "sess-abc-123",
|
||||
"usage": {"input_tokens": 10, "output_tokens": 5},
|
||||
"total_cost_usd": 0.001,
|
||||
"cost_usd": 0.001,
|
||||
"duration_ms": 50,
|
||||
"num_turns": 1,
|
||||
"intermediate_count": 0,
|
||||
"subtype": "success",
|
||||
"is_error": False,
|
||||
}
|
||||
|
||||
with patch.object(claude_session, "_run_claude", side_effect=flaky):
|
||||
with pytest.raises(RuntimeError, match="simulated subprocess crash"):
|
||||
send_message("ch-recover", "first")
|
||||
|
||||
# Second call MUST acquire the lock (proves the first released it).
|
||||
# We use a short timeout via a thread so a deadlock would fail loudly.
|
||||
done = threading.Event()
|
||||
result_box: list[str] = []
|
||||
|
||||
def second():
|
||||
result_box.append(send_message("ch-recover", "second"))
|
||||
done.set()
|
||||
|
||||
t = threading.Thread(target=second)
|
||||
t.start()
|
||||
assert done.wait(timeout=3.0), (
|
||||
"Second call deadlocked — lock was not released on exception."
|
||||
)
|
||||
t.join(timeout=1.0)
|
||||
assert result_box == ["Hello from Claude!"]
|
||||
222
tests/test_voice_adapter_contract.py
Normal file
222
tests/test_voice_adapter_contract.py
Normal file
@@ -0,0 +1,222 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Contract test for `src/voice/_discord_voice_adapter.py`.
|
||||
|
||||
Purpose: catch drift when the vendored `discord-ext-voice-recv` is upgraded.
|
||||
If upstream renames/removes a method we depend on, this test fails LOUDLY
|
||||
before any downstream code breaks at runtime in a Discord voice call.
|
||||
|
||||
Per VENDOR_INFO.md: this test MUST PASS after every vendor upgrade.
|
||||
|
||||
Plain `import` + `hasattr` / `callable` checks — no mocks. We're verifying
|
||||
the SHAPE of the API surface, not behavior.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# --- Adapter re-exports import cleanly --------------------------------------
|
||||
|
||||
|
||||
def test_adapter_exports_voice_receive_client():
|
||||
from src.voice._discord_voice_adapter import VoiceReceiveClient
|
||||
|
||||
assert VoiceReceiveClient is not None
|
||||
assert inspect.isclass(VoiceReceiveClient)
|
||||
|
||||
|
||||
def test_adapter_exports_audio_sink():
|
||||
from src.voice._discord_voice_adapter import AudioSink
|
||||
|
||||
assert AudioSink is not None
|
||||
assert inspect.isclass(AudioSink)
|
||||
|
||||
|
||||
def test_adapter_exports_voice_data():
|
||||
from src.voice._discord_voice_adapter import VoiceData
|
||||
|
||||
assert VoiceData is not None
|
||||
assert inspect.isclass(VoiceData)
|
||||
|
||||
|
||||
def test_adapter_exports_connect_helper():
|
||||
from src.voice._discord_voice_adapter import connect_voice
|
||||
|
||||
assert callable(connect_voice)
|
||||
assert inspect.iscoroutinefunction(connect_voice)
|
||||
|
||||
|
||||
# --- Re-exports point at the real vendored classes (no accidental shadowing) -
|
||||
|
||||
|
||||
def test_voice_receive_client_is_voice_recv_client():
|
||||
from discord.ext import voice_recv
|
||||
|
||||
from src.voice._discord_voice_adapter import VoiceReceiveClient
|
||||
|
||||
assert VoiceReceiveClient is voice_recv.VoiceRecvClient
|
||||
|
||||
|
||||
def test_audio_sink_is_voice_recv_audio_sink():
|
||||
from discord.ext import voice_recv
|
||||
|
||||
from src.voice._discord_voice_adapter import AudioSink
|
||||
|
||||
assert AudioSink is voice_recv.AudioSink
|
||||
|
||||
|
||||
def test_voice_data_is_voice_recv_voice_data():
|
||||
from discord.ext import voice_recv
|
||||
|
||||
from src.voice._discord_voice_adapter import VoiceData
|
||||
|
||||
assert VoiceData is voice_recv.VoiceData
|
||||
|
||||
|
||||
# --- VoiceReceiveClient API surface used by the pipeline --------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"method_name",
|
||||
[
|
||||
"connect", # inherited from discord.VoiceClient
|
||||
"disconnect", # inherited from discord.VoiceClient
|
||||
"listen", # voice_recv extension
|
||||
"stop_listening", # voice_recv extension
|
||||
"is_listening", # voice_recv extension
|
||||
"stop", # voice_recv extension (stops play+listen)
|
||||
"cleanup", # voice_recv extension
|
||||
],
|
||||
)
|
||||
def test_voice_receive_client_has_method(method_name):
|
||||
from src.voice._discord_voice_adapter import VoiceReceiveClient
|
||||
|
||||
attr = getattr(VoiceReceiveClient, method_name, None)
|
||||
assert attr is not None, f"VoiceReceiveClient is missing `.{method_name}()`"
|
||||
assert callable(attr), f"VoiceReceiveClient.{method_name} is not callable"
|
||||
|
||||
|
||||
def test_voice_receive_client_listen_accepts_sink_and_after():
|
||||
"""`.listen(sink, *, after=None)` is the canonical call shape."""
|
||||
from src.voice._discord_voice_adapter import VoiceReceiveClient
|
||||
|
||||
sig = inspect.signature(VoiceReceiveClient.listen)
|
||||
params = sig.parameters
|
||||
assert "sink" in params, f"VoiceReceiveClient.listen missing `sink` param; got {list(params)}"
|
||||
assert "after" in params, f"VoiceReceiveClient.listen missing `after` kwarg; got {list(params)}"
|
||||
|
||||
|
||||
def test_voice_receive_client_has_sink_property():
|
||||
"""`.sink` is read/write so we can swap sinks in place."""
|
||||
from src.voice._discord_voice_adapter import VoiceReceiveClient
|
||||
|
||||
sink_attr = inspect.getattr_static(VoiceReceiveClient, "sink", None)
|
||||
assert isinstance(sink_attr, property), "VoiceReceiveClient.sink must be a property"
|
||||
assert sink_attr.fget is not None, "VoiceReceiveClient.sink property missing getter"
|
||||
assert sink_attr.fset is not None, "VoiceReceiveClient.sink property missing setter"
|
||||
|
||||
|
||||
# --- AudioSink API surface --------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"method_name",
|
||||
[
|
||||
"write", # write(user, voice_data) — the hot path
|
||||
"cleanup",
|
||||
"wants_opus", # bool: opus bytes vs decoded PCM
|
||||
],
|
||||
)
|
||||
def test_audio_sink_has_method(method_name):
|
||||
from src.voice._discord_voice_adapter import AudioSink
|
||||
|
||||
attr = getattr(AudioSink, method_name, None)
|
||||
assert attr is not None, f"AudioSink is missing `.{method_name}()`"
|
||||
assert callable(attr), f"AudioSink.{method_name} is not callable"
|
||||
|
||||
|
||||
def test_audio_sink_write_signature():
|
||||
"""`.write(self, user, data)` — user is the speaker (Optional), data is VoiceData."""
|
||||
from src.voice._discord_voice_adapter import AudioSink
|
||||
|
||||
sig = inspect.signature(AudioSink.write)
|
||||
params = list(sig.parameters)
|
||||
# self, user, data
|
||||
assert len(params) >= 3, f"AudioSink.write expected (self, user, data), got {params}"
|
||||
|
||||
|
||||
# --- VoiceData attributes ---------------------------------------------------
|
||||
|
||||
|
||||
def test_voice_data_slots():
|
||||
"""VoiceData uses __slots__ for per-packet allocation. Pipeline reads these."""
|
||||
from src.voice._discord_voice_adapter import VoiceData
|
||||
|
||||
assert hasattr(VoiceData, "__slots__"), "VoiceData lost __slots__ — perf regression risk"
|
||||
slots = set(VoiceData.__slots__)
|
||||
# Documented attributes the pipeline depends on.
|
||||
assert "packet" in slots, f"VoiceData missing `packet` slot; got {slots}"
|
||||
assert "source" in slots, f"VoiceData missing `source` slot (speaker user); got {slots}"
|
||||
assert "pcm" in slots, f"VoiceData missing `pcm` slot (decoded audio); got {slots}"
|
||||
|
||||
|
||||
def test_voice_data_has_opus_property():
|
||||
"""`.opus` exposes the raw opus bytes from the underlying RTP packet."""
|
||||
from src.voice._discord_voice_adapter import VoiceData
|
||||
|
||||
opus_attr = inspect.getattr_static(VoiceData, "opus", None)
|
||||
assert isinstance(opus_attr, property), "VoiceData.opus must be a property"
|
||||
|
||||
|
||||
# --- Echo-core DAVE-decrypt fork guards -------------------------------------
|
||||
#
|
||||
# Two contract tests pinned by the DAVE receive-side decrypt patch.
|
||||
# See plan: /home/moltbot/.claude/plans/wiggly-exploring-glade.md
|
||||
#
|
||||
# These fail fast on either:
|
||||
# 1. An upstream voice-recv re-install wiping the fork's version marker
|
||||
# (i.e. our patch is gone), OR
|
||||
# 2. A discord.py upgrade renaming the connection-level DAVE attrs the
|
||||
# patch reads (`dave_session`, `dave_protocol_version`).
|
||||
|
||||
|
||||
def test_voice_recv_fork_version():
|
||||
"""Echo-core fork tag for the DAVE-decrypt patch.
|
||||
|
||||
Lane A bumps `voice_recv.__version__` to `'0.5.3a+echo.dave1'` (PEP 440
|
||||
local segment). If this assertion fails after a vendor reinstall, the
|
||||
fork patch has been lost — re-apply `_maybe_dave_decrypt` + the
|
||||
`callback()` hook before deploying, or live voice will regress to the
|
||||
`opus_decode: corrupted stream` error chain.
|
||||
"""
|
||||
from discord.ext import voice_recv
|
||||
|
||||
assert voice_recv.__version__ == "0.5.3a+echo.dave1", (
|
||||
f"voice_recv.__version__ is {voice_recv.__version__!r}; expected "
|
||||
"'0.5.3a+echo.dave1'. The DAVE-decrypt fork patch has been "
|
||||
"overwritten — re-apply before reinstalling the vendored package."
|
||||
)
|
||||
|
||||
|
||||
def test_voice_connection_state_has_dave_attrs():
|
||||
"""`_maybe_dave_decrypt` reads `dave_session` and `dave_protocol_version`
|
||||
off the discord.py `VoiceConnectionState`. If a future discord.py upgrade
|
||||
renames either attr, fail loudly here rather than in a live voice call
|
||||
(where the symptom is silent packet drops).
|
||||
"""
|
||||
from discord import voice_state
|
||||
|
||||
src = inspect.getsource(voice_state.VoiceConnectionState)
|
||||
assert "dave_session" in src, (
|
||||
"discord.voice_state.VoiceConnectionState source no longer mentions "
|
||||
"'dave_session' — discord.py may have renamed the attr. Update "
|
||||
"vendor/discord-ext-voice-recv/.../reader.py::_maybe_dave_decrypt."
|
||||
)
|
||||
assert "dave_protocol_version" in src, (
|
||||
"discord.voice_state.VoiceConnectionState source no longer mentions "
|
||||
"'dave_protocol_version' — discord.py may have renamed the attr. "
|
||||
"Update _maybe_dave_decrypt accordingly."
|
||||
)
|
||||
137
tests/test_voice_normalize.py
Normal file
137
tests/test_voice_normalize.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""Tests for src/voice/normalize.py — 35 Romanian cases.
|
||||
|
||||
Categories:
|
||||
markdown strip (5), numbers cardinals (6), decimals (4),
|
||||
currency natural (8), symbols (4), abbreviations (4),
|
||||
truncation boundary (2), edge cases empty / whitespace (2).
|
||||
|
||||
Total: 35.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from src.voice.normalize import (
|
||||
expand_abbreviations,
|
||||
expand_currency,
|
||||
expand_numbers_ro,
|
||||
expand_symbols,
|
||||
normalize_for_tts,
|
||||
strip_markdown,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Markdown stripping (5)
|
||||
# ============================================================
|
||||
@pytest.mark.parametrize("text,expected", [
|
||||
("**bold text**", "bold text"),
|
||||
("*italic text*", "italic text"),
|
||||
("`code snippet`", "code snippet"),
|
||||
("[click here](https://example.com)", "click here"),
|
||||
("# Heading text", "Heading text"),
|
||||
])
|
||||
def test_strip_markdown(text, expected):
|
||||
assert strip_markdown(text) == expected
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Numbers cardinals (6)
|
||||
# ============================================================
|
||||
@pytest.mark.parametrize("text,expected", [
|
||||
("21", "douăzeci și unu"),
|
||||
("81", "optzeci și unu"),
|
||||
("100", "o sută"),
|
||||
("3", "trei"),
|
||||
("0", "zero"),
|
||||
("200", "două sute"),
|
||||
])
|
||||
def test_expand_numbers_cardinals(text, expected):
|
||||
assert expand_numbers_ro(text) == expected
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Decimals (4)
|
||||
# ============================================================
|
||||
@pytest.mark.parametrize("text,expected", [
|
||||
("3.14", "trei virgulă paisprezece"),
|
||||
("12.5", "doisprezece virgulă cinci"),
|
||||
("0.5", "zero virgulă cinci"),
|
||||
("99.99", "nouăzeci și nouă virgulă nouăzeci și nouă"),
|
||||
])
|
||||
def test_expand_numbers_decimals(text, expected):
|
||||
assert expand_numbers_ro(text) == expected
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Currency natural RO (8) — RON / USD / EUR / GBP mix
|
||||
# ============================================================
|
||||
@pytest.mark.parametrize("text,expected", [
|
||||
("12.50 RON", "doisprezece lei și cincizeci de bani"),
|
||||
("$25.99", "douăzeci și cinci de dolari și nouăzeci și nouă de cenți"),
|
||||
("€100.50", "o sută de euro și cincizeci de cenți"),
|
||||
("£200", "două sute de lire"),
|
||||
("100 RON", "o sută de lei"),
|
||||
("$1", "un dolar"),
|
||||
("€50", "cincizeci de euro"),
|
||||
("1 RON", "un leu"),
|
||||
])
|
||||
def test_expand_currency(text, expected):
|
||||
assert expand_currency(text) == expected
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Symbols (4)
|
||||
# ============================================================
|
||||
@pytest.mark.parametrize("text,expected", [
|
||||
("25%", "25 la sută"),
|
||||
("foo & bar", "foo și bar"),
|
||||
("Marius @ home", "Marius la home"),
|
||||
("30°", "30 grade"),
|
||||
])
|
||||
def test_expand_symbols(text, expected):
|
||||
assert expand_symbols(text) == expected
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Abbreviations (4)
|
||||
# ============================================================
|
||||
@pytest.mark.parametrize("text,expected", [
|
||||
("etc.", "etcetera"),
|
||||
("dl. Popescu", "domnul Popescu"),
|
||||
("dna. Ionescu", "doamna Ionescu"),
|
||||
("nr. 5", "numărul 5"),
|
||||
])
|
||||
def test_expand_abbreviations(text, expected):
|
||||
assert expand_abbreviations(text) == expected
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Truncation boundary (2)
|
||||
# ============================================================
|
||||
def test_truncate_exactly_200_words_unchanged():
|
||||
"""Exactly 200 simple word tokens — no truncation, no suffix."""
|
||||
text = " ".join(["cuvant"] * 200)
|
||||
out = normalize_for_tts(text)
|
||||
assert "Restul l-am scris în chat." not in out
|
||||
assert out.split() == ["cuvant"] * 200
|
||||
|
||||
|
||||
def test_truncate_over_200_words_appends_suffix():
|
||||
"""250 word tokens — keep first 200 then append the chat-deferral phrase."""
|
||||
text = " ".join(["cuvant"] * 250)
|
||||
out = normalize_for_tts(text)
|
||||
assert out.endswith("Restul l-am scris în chat.")
|
||||
words = out.split()
|
||||
# First 200 are 'cuvant', followed by the 5-word suffix.
|
||||
assert words[:200] == ["cuvant"] * 200
|
||||
assert words[200:] == ["Restul", "l-am", "scris", "în", "chat."]
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Edge cases (2)
|
||||
# ============================================================
|
||||
@pytest.mark.parametrize("text,expected", [
|
||||
("", ""),
|
||||
(" ", ""),
|
||||
])
|
||||
def test_normalize_edge_cases(text, expected):
|
||||
assert normalize_for_tts(text) == expected
|
||||
302
tests/test_voice_recv_dave.py
Normal file
302
tests/test_voice_recv_dave.py
Normal file
@@ -0,0 +1,302 @@
|
||||
"""DAVE receive-side decrypt tests for the vendored voice-recv fork.
|
||||
|
||||
Exercises Lane A's patch on
|
||||
`vendor/discord-ext-voice-recv/discord/ext/voice_recv/reader.py`:
|
||||
|
||||
* `_maybe_dave_decrypt(rtp_packet)` — DAVE E2E layer sandwiched between the
|
||||
transport-layer decrypt and the routing into the opus decoder. No-op when
|
||||
the room is non-DAVE, when davey isn't installed, or when the SSRC map
|
||||
hasn't caught up to a new speaker yet.
|
||||
* `callback()` hook — feeds the DAVE-unwrapped plaintext into
|
||||
`packet_router.feed_rtp()` on success, drops the packet on failure WITHOUT
|
||||
killing the reader thread.
|
||||
|
||||
The test fixtures mirror `tests/test_voice_session_cleanup.py:33-54`:
|
||||
* Construct `AudioReader` via `AudioReader.__new__(AudioReader)` + manual
|
||||
attr set so the reader thread is never started.
|
||||
* `MagicMock` everything below the unit under test.
|
||||
|
||||
`_HAS_DAVE` / `_MEDIA_TYPE_AUDIO` on the reader module are monkey-patched per
|
||||
test so the suite passes whether or not `davey` is importable in the venv.
|
||||
The assertions only become meaningful once Lane A's patch has landed and the
|
||||
package has been re-installed (`pip install -e vendor/discord-ext-voice-recv
|
||||
--force-reinstall`); the FILE itself is valid Python regardless.
|
||||
|
||||
See plan: /home/moltbot/.claude/plans/wiggly-exploring-glade.md
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from discord.ext.voice_recv.reader import AudioReader
|
||||
|
||||
|
||||
# Sentinel for `_MEDIA_TYPE_AUDIO`. Using a plain object() keeps the tests
|
||||
# independent of whether davey is importable — we just assert the value
|
||||
# flows through to `dave_session.decrypt()` unchanged.
|
||||
_FAKE_MEDIA_TYPE_AUDIO = object()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_dave_session():
|
||||
sess = MagicMock(name="dave_session")
|
||||
sess.ready = True
|
||||
# Default: this user is NOT in passthrough — DAVE decrypt must run.
|
||||
# Individual tests can override to True to exercise the passthrough path.
|
||||
sess.can_passthrough = MagicMock(return_value=False)
|
||||
sess.decrypt = MagicMock(return_value=b"plaintext_opus")
|
||||
return sess
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_connection(fake_dave_session):
|
||||
conn = MagicMock(name="_connection")
|
||||
conn.dave_protocol_version = 1
|
||||
conn.dave_session = fake_dave_session
|
||||
return conn
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_voice_client(fake_connection):
|
||||
vc = MagicMock(name="voice_client")
|
||||
vc._connection = fake_connection
|
||||
vc._ssrc_to_id = {12345: 999_000}
|
||||
return vc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_rtp_packet():
|
||||
pkt = MagicMock(name="rtp_packet")
|
||||
pkt.ssrc = 12345
|
||||
pkt.decrypted_data = b"ciphertext_after_transport_decrypt"
|
||||
pkt.is_silence = MagicMock(return_value=False)
|
||||
return pkt
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reader(fake_voice_client):
|
||||
"""`AudioReader` instance with no reader thread spawned.
|
||||
|
||||
Same pattern used by `tests/test_voice_session_cleanup.py` for
|
||||
`VoiceSession` — bypass `__init__` so we can drive the public surface
|
||||
against pure mocks.
|
||||
"""
|
||||
r = AudioReader.__new__(AudioReader)
|
||||
r.voice_client = fake_voice_client
|
||||
r.error = None
|
||||
return r
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dave_enabled(monkeypatch):
|
||||
"""Force the reader module's DAVE-availability flags ON.
|
||||
|
||||
Pins `_MEDIA_TYPE_AUDIO` to a known sentinel so the happy-path test can
|
||||
assert exactly what gets passed to `dave_session.decrypt`. `raising=False`
|
||||
keeps the monkeypatch valid even if Lane A's patch hasn't landed yet —
|
||||
the tests will still fail (no `_maybe_dave_decrypt` attr), just for the
|
||||
right reason.
|
||||
"""
|
||||
import discord.ext.voice_recv.reader as reader_mod
|
||||
|
||||
monkeypatch.setattr(reader_mod, "_HAS_DAVE", True, raising=False)
|
||||
monkeypatch.setattr(
|
||||
reader_mod, "_MEDIA_TYPE_AUDIO", _FAKE_MEDIA_TYPE_AUDIO, raising=False
|
||||
)
|
||||
return reader_mod
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests: `_maybe_dave_decrypt`
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMaybeDaveDecrypt:
|
||||
"""Seven unit tests on the DAVE-decrypt gate.
|
||||
|
||||
The gate mirrors `voice_client.can_encrypt` in discord.py 2.7.1 exactly
|
||||
(`voice_state.py:272-273`). Bypass semantics on every "DAVE inactive"
|
||||
branch let non-DAVE rooms and davey-less environments keep working.
|
||||
"""
|
||||
|
||||
def test_protocol_version_zero_bypasses_decrypt(
|
||||
self, dave_enabled, reader, fake_connection, fake_dave_session, fake_rtp_packet,
|
||||
):
|
||||
"""`dave_protocol_version == 0` → return the transport-decrypted
|
||||
payload unchanged; never touch `dave_session.decrypt`."""
|
||||
fake_connection.dave_protocol_version = 0
|
||||
result = reader._maybe_dave_decrypt(fake_rtp_packet)
|
||||
assert result is fake_rtp_packet.decrypted_data
|
||||
fake_dave_session.decrypt.assert_not_called()
|
||||
|
||||
def test_dave_session_none_bypasses_decrypt(
|
||||
self, dave_enabled, reader, fake_connection, fake_rtp_packet,
|
||||
):
|
||||
"""`dave_session is None` → bypass. Pre-MLS-handshake state."""
|
||||
fake_connection.dave_session = None
|
||||
result = reader._maybe_dave_decrypt(fake_rtp_packet)
|
||||
assert result is fake_rtp_packet.decrypted_data
|
||||
|
||||
def test_dave_session_not_ready_bypasses_decrypt(
|
||||
self, dave_enabled, reader, fake_dave_session, fake_rtp_packet,
|
||||
):
|
||||
"""`dave_session.ready is False` → bypass. Pre-MLS-epoch-1 packets
|
||||
are transport-only on the wire."""
|
||||
fake_dave_session.ready = False
|
||||
result = reader._maybe_dave_decrypt(fake_rtp_packet)
|
||||
assert result is fake_rtp_packet.decrypted_data
|
||||
fake_dave_session.decrypt.assert_not_called()
|
||||
|
||||
def test_unknown_ssrc_returns_none(
|
||||
self, dave_enabled, reader, fake_voice_client, fake_dave_session, fake_rtp_packet,
|
||||
):
|
||||
"""SSRC not in `_ssrc_to_id` → drop (return None).
|
||||
|
||||
Accepted regression: davey requires per-user keys; when SPEAKING
|
||||
events race behind the first audio packet, 1-5 packets per new
|
||||
speaker per session are dropped. See plan §Edge cases.
|
||||
"""
|
||||
fake_voice_client._ssrc_to_id.clear()
|
||||
result = reader._maybe_dave_decrypt(fake_rtp_packet)
|
||||
assert result is None
|
||||
fake_dave_session.decrypt.assert_not_called()
|
||||
|
||||
def test_happy_path_invokes_decrypt_and_returns_plaintext(
|
||||
self, dave_enabled, reader, fake_dave_session, fake_rtp_packet,
|
||||
):
|
||||
"""Full DAVE-active path: `decrypt(user_id, MediaType.audio, ciphertext)`
|
||||
called exactly once with the expected args; method returns the
|
||||
davey plaintext bytes verbatim."""
|
||||
ciphertext = fake_rtp_packet.decrypted_data
|
||||
result = reader._maybe_dave_decrypt(fake_rtp_packet)
|
||||
assert result == b"plaintext_opus"
|
||||
fake_dave_session.decrypt.assert_called_once_with(
|
||||
999_000, _FAKE_MEDIA_TYPE_AUDIO, ciphertext,
|
||||
)
|
||||
|
||||
def test_decrypt_raises_returns_none_no_crash(
|
||||
self, dave_enabled, reader, fake_dave_session, fake_rtp_packet,
|
||||
):
|
||||
"""davey.decrypt raising → drop the packet, don't propagate, and
|
||||
leave `reader.error` untouched so the reader thread stays alive.
|
||||
|
||||
MLS epoch transitions can produce transient decrypt failures —
|
||||
bumping `reader.error` would call `self.stop()` and kill the whole
|
||||
receive pipeline."""
|
||||
fake_dave_session.decrypt.side_effect = RuntimeError(
|
||||
"simulated MLS epoch transition fail"
|
||||
)
|
||||
result = reader._maybe_dave_decrypt(fake_rtp_packet)
|
||||
assert result is None
|
||||
assert reader.error is None
|
||||
|
||||
def test_has_dave_false_bypasses_even_with_session_present(
|
||||
self, monkeypatch, reader, fake_dave_session, fake_rtp_packet,
|
||||
):
|
||||
"""`_HAS_DAVE = False` → bypass everything, even if a real session
|
||||
somehow showed up on the connection. Defensive shim that keeps the
|
||||
tests (and any davey-less deploys) green."""
|
||||
import discord.ext.voice_recv.reader as reader_mod
|
||||
|
||||
monkeypatch.setattr(reader_mod, "_HAS_DAVE", False, raising=False)
|
||||
result = reader._maybe_dave_decrypt(fake_rtp_packet)
|
||||
assert result is fake_rtp_packet.decrypted_data
|
||||
fake_dave_session.decrypt.assert_not_called()
|
||||
|
||||
def test_can_passthrough_true_returns_payload_without_decrypt(
|
||||
self, dave_enabled, reader, fake_dave_session, fake_rtp_packet,
|
||||
):
|
||||
"""`can_passthrough(user_id) == True` → return the transport-decrypted
|
||||
payload as-is; never call `decrypt`. Mirrors Discord's protocol where
|
||||
a passthrough-mode peer sends non-DAVE-wrapped packets that the
|
||||
receiver must accept verbatim."""
|
||||
fake_dave_session.can_passthrough = MagicMock(return_value=True)
|
||||
result = reader._maybe_dave_decrypt(fake_rtp_packet)
|
||||
assert result is fake_rtp_packet.decrypted_data
|
||||
fake_dave_session.can_passthrough.assert_called_once_with(999_000)
|
||||
fake_dave_session.decrypt.assert_not_called()
|
||||
|
||||
def test_can_passthrough_raises_falls_through_to_decrypt(
|
||||
self, dave_enabled, reader, fake_dave_session, fake_rtp_packet,
|
||||
):
|
||||
"""`can_passthrough` raising → swallow the error and try `decrypt`.
|
||||
Defensive: an older davey build or transient internal state shouldn't
|
||||
break the receive pipeline."""
|
||||
fake_dave_session.can_passthrough = MagicMock(
|
||||
side_effect=RuntimeError("simulated davey internal error")
|
||||
)
|
||||
result = reader._maybe_dave_decrypt(fake_rtp_packet)
|
||||
assert result == b"plaintext_opus"
|
||||
fake_dave_session.decrypt.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration tests: `callback()` exercises the DAVE hook
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCallbackIntegration:
|
||||
"""Two integration tests for the hook Lane A inserts between transport
|
||||
decrypt (reader.py:141) and the post-decrypt routing (reader.py:159).
|
||||
|
||||
Strategy: stub the transport-decrypt and RTP parsing path so `callback()`
|
||||
reaches the hook, then mock `_maybe_dave_decrypt` directly on the reader
|
||||
instance. The assertion focuses on `feed_rtp` being called (test 8) vs.
|
||||
not called (test 9). The transport path correctness is covered by
|
||||
voice-recv's own upstream tests.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _wire_callback(reader, monkeypatch, fake_rtp_packet):
|
||||
import discord.ext.voice_recv.reader as reader_mod
|
||||
|
||||
# Redirect rtp parsing — we want an RTP path (not RTCP) so the hook fires.
|
||||
monkeypatch.setattr(reader_mod.rtp, "is_rtcp", lambda data: False)
|
||||
monkeypatch.setattr(reader_mod.rtp, "decode_rtp", lambda data: fake_rtp_packet)
|
||||
|
||||
# Stub the instance attrs `callback()` touches besides the hook.
|
||||
reader.decryptor = MagicMock(name="decryptor")
|
||||
reader.decryptor.decrypt_rtp = MagicMock(return_value=b"ciphertext")
|
||||
reader.packet_router = MagicMock(name="packet_router")
|
||||
reader.packet_router.feed_rtp = MagicMock()
|
||||
reader.speaking_timer = MagicMock(name="speaking_timer")
|
||||
reader.sink = MagicMock(name="sink")
|
||||
|
||||
def test_callback_feeds_when_dave_returns_bytes(
|
||||
self, monkeypatch, reader, fake_rtp_packet,
|
||||
):
|
||||
"""Hook returns plaintext → `feed_rtp` called once with the
|
||||
rtp_packet whose `decrypted_data` is now the post-DAVE plaintext."""
|
||||
self._wire_callback(reader, monkeypatch, fake_rtp_packet)
|
||||
plaintext = b"dave_unwrapped_opus_payload"
|
||||
reader._maybe_dave_decrypt = MagicMock(return_value=plaintext)
|
||||
|
||||
reader.callback(b"raw_packet_bytes")
|
||||
|
||||
reader._maybe_dave_decrypt.assert_called_once_with(fake_rtp_packet)
|
||||
assert reader.packet_router.feed_rtp.call_count == 1
|
||||
called_with = reader.packet_router.feed_rtp.call_args[0][0]
|
||||
assert called_with is fake_rtp_packet
|
||||
assert fake_rtp_packet.decrypted_data == plaintext
|
||||
assert reader.error is None
|
||||
|
||||
def test_callback_drops_when_dave_returns_none(
|
||||
self, monkeypatch, reader, fake_rtp_packet,
|
||||
):
|
||||
"""Hook returns None → `feed_rtp` NOT called, no exception propagated,
|
||||
`reader.error` stays None (reader thread survives the drop)."""
|
||||
self._wire_callback(reader, monkeypatch, fake_rtp_packet)
|
||||
reader._maybe_dave_decrypt = MagicMock(return_value=None)
|
||||
|
||||
reader.callback(b"raw_packet_bytes")
|
||||
|
||||
reader._maybe_dave_decrypt.assert_called_once_with(fake_rtp_packet)
|
||||
reader.packet_router.feed_rtp.assert_not_called()
|
||||
assert reader.error is None
|
||||
319
tests/test_voice_session_cleanup.py
Normal file
319
tests/test_voice_session_cleanup.py
Normal file
@@ -0,0 +1,319 @@
|
||||
"""Cleanup-path tests for ``src/voice/pipeline.py::VoiceSession``.
|
||||
|
||||
Pins the centralized ``cleanup()`` contract from the voice plan
|
||||
(Engineering decision #5): every one of the FIVE exit paths must drain
|
||||
state cleanly and idempotently — lock released, JSONL flushed or
|
||||
discarded, presence cleared, ``voice_client.cleanup()`` invoked,
|
||||
``ttsq.stop()`` invoked, and a second call to ``cleanup()`` MUST be a
|
||||
no-op (side effects happen exactly once).
|
||||
|
||||
The 5 paths under test:
|
||||
1. ``test_cleanup_on_voice_leave`` — explicit ``/voice leave``
|
||||
2. ``test_cleanup_on_disconnect`` — Discord-level disconnect
|
||||
3. ``test_cleanup_on_crash`` — exception via ``__exit__``
|
||||
4. ``test_cleanup_on_auto_leave`` — 5-min inactivity timer
|
||||
5. ``test_cleanup_on_user_leaves_channel`` — user leaves voice channel
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from src.voice.pipeline import VoiceSession
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_bot():
|
||||
bot = MagicMock(name="bot")
|
||||
bot.user = MagicMock()
|
||||
bot.user.id = 999_999
|
||||
bot.change_presence = AsyncMock(name="change_presence")
|
||||
bot.get_user = MagicMock(return_value=None)
|
||||
return bot
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_voice_client():
|
||||
vc = MagicMock(name="voice_client")
|
||||
vc.cleanup = MagicMock(name="vc_cleanup")
|
||||
return vc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ttsq():
|
||||
ttsq = MagicMock(name="ttsq")
|
||||
ttsq.stop = MagicMock(name="ttsq_stop")
|
||||
return ttsq
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_text_channel():
|
||||
tc = MagicMock(name="text_channel")
|
||||
tc.send = AsyncMock(name="text_send")
|
||||
return tc
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_session(
|
||||
tmp_path: Path,
|
||||
mock_bot,
|
||||
mock_voice_client,
|
||||
mock_ttsq,
|
||||
mock_text_channel,
|
||||
*,
|
||||
record_enabled: bool = True,
|
||||
) -> VoiceSession:
|
||||
jsonl = tmp_path / ("transcripts.jsonl" if record_enabled else "noop.jsonl")
|
||||
return VoiceSession(
|
||||
channel_id=1001,
|
||||
guild_id=42,
|
||||
text_channel=mock_text_channel,
|
||||
voice_client=mock_voice_client,
|
||||
bot=mock_bot,
|
||||
ttsq=mock_ttsq,
|
||||
whitelist={1234},
|
||||
record_enabled=record_enabled,
|
||||
mirror_enabled=True,
|
||||
transcripts_jsonl_path=jsonl,
|
||||
loop=None,
|
||||
router_route_message=MagicMock(name="route_message"),
|
||||
)
|
||||
|
||||
|
||||
def _assert_clean_post_cleanup(
|
||||
session: VoiceSession,
|
||||
voice_client,
|
||||
ttsq,
|
||||
bot,
|
||||
jsonl_path: Path,
|
||||
record_enabled: bool,
|
||||
) -> None:
|
||||
"""Assertions shared across all five cleanup-path tests."""
|
||||
# 1. Lock released — non-blocking acquire from this thread returns True.
|
||||
acquired = session._lock.acquire(blocking=False)
|
||||
assert acquired, "session._lock must be released after cleanup()"
|
||||
session._lock.release()
|
||||
|
||||
# 2. voice_client.cleanup() called exactly once.
|
||||
assert voice_client.cleanup.call_count == 1, (
|
||||
f"voice_client.cleanup() called {voice_client.cleanup.call_count}x, "
|
||||
f"expected 1"
|
||||
)
|
||||
|
||||
# 3. ttsq.stop() called exactly once.
|
||||
assert ttsq.stop.call_count == 1, (
|
||||
f"ttsq.stop() called {ttsq.stop.call_count}x, expected 1"
|
||||
)
|
||||
|
||||
# 4. bot.change_presence(activity=None) called at least once with that kwarg.
|
||||
assert bot.change_presence.call_count >= 1, (
|
||||
"bot.change_presence was never called — presence not restored"
|
||||
)
|
||||
bot.change_presence.assert_called_with(activity=None)
|
||||
|
||||
# 5. JSONL flushed (record=on) OR absent (record=off).
|
||||
if record_enabled:
|
||||
assert jsonl_path.exists(), (
|
||||
"record=on: JSONL file must exist (was created by __enter__ and "
|
||||
"left in place by cleanup so transcript can be persisted)"
|
||||
)
|
||||
else:
|
||||
# record=off: cleanup unlinks the file if it ever existed.
|
||||
assert not jsonl_path.exists() or jsonl_path.stat().st_size == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario 1 — explicit /voice leave
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCleanupOnVoiceLeave:
|
||||
def test_cleanup_on_voice_leave(
|
||||
self, tmp_path, mock_bot, mock_voice_client, mock_ttsq, mock_text_channel,
|
||||
):
|
||||
session = _make_session(
|
||||
tmp_path, mock_bot, mock_voice_client, mock_ttsq, mock_text_channel,
|
||||
record_enabled=True,
|
||||
)
|
||||
jsonl_path = session.transcripts_jsonl_path
|
||||
|
||||
with session:
|
||||
# Simulate one transcript line.
|
||||
session._jsonl_fh.write(json.dumps({"text": "salut"}) + "\n")
|
||||
session.cleanup("voice_leave")
|
||||
assert session._cleaned_up is True
|
||||
|
||||
# __exit__ called cleanup("exit") — must be a no-op the second time.
|
||||
_assert_clean_post_cleanup(
|
||||
session, mock_voice_client, mock_ttsq, mock_bot,
|
||||
jsonl_path, record_enabled=True,
|
||||
)
|
||||
|
||||
# Idempotency: a third explicit call still doesn't bump counts.
|
||||
session.cleanup("redundant")
|
||||
assert mock_voice_client.cleanup.call_count == 1
|
||||
assert mock_ttsq.stop.call_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario 2 — Discord-level voice disconnect
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCleanupOnDisconnect:
|
||||
def test_cleanup_on_disconnect(
|
||||
self, tmp_path, mock_bot, mock_voice_client, mock_ttsq, mock_text_channel,
|
||||
):
|
||||
session = _make_session(
|
||||
tmp_path, mock_bot, mock_voice_client, mock_ttsq, mock_text_channel,
|
||||
record_enabled=False,
|
||||
)
|
||||
jsonl_path = session.transcripts_jsonl_path
|
||||
|
||||
session.__enter__()
|
||||
# Network drop arrives outside the with-block.
|
||||
session.cleanup("disconnect")
|
||||
_assert_clean_post_cleanup(
|
||||
session, mock_voice_client, mock_ttsq, mock_bot,
|
||||
jsonl_path, record_enabled=False,
|
||||
)
|
||||
|
||||
# Idempotency.
|
||||
session.cleanup("disconnect-again")
|
||||
assert mock_voice_client.cleanup.call_count == 1
|
||||
assert mock_ttsq.stop.call_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario 3 — crash / exception via __exit__
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCleanupOnCrash:
|
||||
def test_cleanup_on_crash(
|
||||
self, tmp_path, mock_bot, mock_voice_client, mock_ttsq, mock_text_channel,
|
||||
):
|
||||
session = _make_session(
|
||||
tmp_path, mock_bot, mock_voice_client, mock_ttsq, mock_text_channel,
|
||||
record_enabled=True,
|
||||
)
|
||||
jsonl_path = session.transcripts_jsonl_path
|
||||
|
||||
with pytest.raises(RuntimeError, match="simulated crash"):
|
||||
with session:
|
||||
# Pipeline raises mid-call.
|
||||
raise RuntimeError("simulated crash")
|
||||
|
||||
# __exit__ must have driven cleanup — every side effect happened once.
|
||||
_assert_clean_post_cleanup(
|
||||
session, mock_voice_client, mock_ttsq, mock_bot,
|
||||
jsonl_path, record_enabled=True,
|
||||
)
|
||||
|
||||
# Idempotency: explicit follow-up call (e.g. an outer error handler
|
||||
# also tries to cleanup) MUST be a no-op.
|
||||
session.cleanup("post-crash")
|
||||
assert mock_voice_client.cleanup.call_count == 1
|
||||
assert mock_ttsq.stop.call_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario 4 — auto-leave timer fires after 5 min inactivity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCleanupOnAutoLeave:
|
||||
def test_cleanup_on_auto_leave(
|
||||
self, tmp_path, mock_bot, mock_voice_client, mock_ttsq, mock_text_channel,
|
||||
):
|
||||
session = _make_session(
|
||||
tmp_path, mock_bot, mock_voice_client, mock_ttsq, mock_text_channel,
|
||||
record_enabled=True,
|
||||
)
|
||||
jsonl_path = session.transcripts_jsonl_path
|
||||
|
||||
session.__enter__()
|
||||
# The auto-leave timer trips outside the with-block.
|
||||
session.cleanup("auto_leave")
|
||||
|
||||
_assert_clean_post_cleanup(
|
||||
session, mock_voice_client, mock_ttsq, mock_bot,
|
||||
jsonl_path, record_enabled=True,
|
||||
)
|
||||
|
||||
# Idempotency.
|
||||
session.cleanup("auto_leave_redundant")
|
||||
assert mock_voice_client.cleanup.call_count == 1
|
||||
assert mock_ttsq.stop.call_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario 5 — user leaves voice channel themselves
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCleanupOnUserLeaves:
|
||||
def test_cleanup_on_user_leaves_channel(
|
||||
self, tmp_path, mock_bot, mock_voice_client, mock_ttsq, mock_text_channel,
|
||||
):
|
||||
session = _make_session(
|
||||
tmp_path, mock_bot, mock_voice_client, mock_ttsq, mock_text_channel,
|
||||
record_enabled=False,
|
||||
)
|
||||
jsonl_path = session.transcripts_jsonl_path
|
||||
|
||||
session.__enter__()
|
||||
# voice_state_update event handler invokes cleanup directly.
|
||||
session.cleanup("user_left_channel")
|
||||
|
||||
_assert_clean_post_cleanup(
|
||||
session, mock_voice_client, mock_ttsq, mock_bot,
|
||||
jsonl_path, record_enabled=False,
|
||||
)
|
||||
|
||||
# Idempotency.
|
||||
session.cleanup("user_left_again")
|
||||
assert mock_voice_client.cleanup.call_count == 1
|
||||
assert mock_ttsq.stop.call_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cross-cutting: failures inside cleanup don't propagate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCleanupRobustness:
|
||||
def test_cleanup_swallows_voice_client_errors(
|
||||
self, tmp_path, mock_bot, mock_voice_client, mock_ttsq, mock_text_channel,
|
||||
):
|
||||
"""If voice_client.cleanup() raises, ttsq.stop() must still run and
|
||||
the lock must still release — otherwise a broken Discord state would
|
||||
deadlock the channel forever."""
|
||||
mock_voice_client.cleanup.side_effect = RuntimeError("vc died")
|
||||
|
||||
session = _make_session(
|
||||
tmp_path, mock_bot, mock_voice_client, mock_ttsq, mock_text_channel,
|
||||
record_enabled=False,
|
||||
)
|
||||
|
||||
with session:
|
||||
session.cleanup("voice_leave")
|
||||
|
||||
# ttsq.stop still ran exactly once.
|
||||
assert mock_ttsq.stop.call_count == 1
|
||||
# Lock released.
|
||||
acquired = session._lock.acquire(blocking=False)
|
||||
assert acquired, "lock must release even when voice_client.cleanup raises"
|
||||
session._lock.release()
|
||||
20
tools/tts.py
20
tools/tts.py
@@ -23,6 +23,24 @@ VOICES = {"M1", "M2", "M3", "M4", "M5", "F1", "F2", "F3", "F4", "F5"}
|
||||
DEFAULT_VOICE = "M2"
|
||||
DEFAULT_LANG = "ro"
|
||||
|
||||
# Punctuation Supertonic synthesis rejects with HTTP 500 (Romanian curly quotes,
|
||||
# smart dashes, ellipsis, angle quotes). Mapped to ASCII so a stray „foo" in
|
||||
# any caller's text doesn't kill the whole request.
|
||||
_TTS_PUNCT_MAP = {
|
||||
'„': '"', '“': '"', '”': '"',
|
||||
'‘': "'", '’': "'", '‚': "'",
|
||||
'«': '"', '»': '"',
|
||||
'–': '-', '—': '-',
|
||||
'…': '...',
|
||||
}
|
||||
|
||||
|
||||
def sanitize_for_supertonic(text: str) -> str:
|
||||
"""Replace Unicode punctuation Supertonic rejects with ASCII equivalents."""
|
||||
for src, dst in _TTS_PUNCT_MAP.items():
|
||||
text = text.replace(src, dst)
|
||||
return text
|
||||
|
||||
|
||||
def synthesize(text: str, voice: str = DEFAULT_VOICE, lang: str = DEFAULT_LANG) -> dict:
|
||||
"""Call Supertonic server and save audio to a temp WAV file.
|
||||
@@ -34,6 +52,8 @@ def synthesize(text: str, voice: str = DEFAULT_VOICE, lang: str = DEFAULT_LANG)
|
||||
if not text or not text.strip():
|
||||
return {"ok": False, "error": "Text gol."}
|
||||
|
||||
text = sanitize_for_supertonic(text)
|
||||
|
||||
voice = voice.upper()
|
||||
if voice not in VOICES:
|
||||
voice = DEFAULT_VOICE
|
||||
|
||||
60
vendor/discord-ext-voice-recv/VENDOR_INFO.md
vendored
60
vendor/discord-ext-voice-recv/VENDOR_INFO.md
vendored
@@ -1,22 +1,76 @@
|
||||
# Vendored: discord-ext-voice-recv
|
||||
|
||||
**Upstream:** https://github.com/imayhaveborkedit/discord-ext-voice-recv
|
||||
**Pinned commit:** `ac04ea7b0941112e83767cf1c1469b408fa06748` (bump version 0.5.3a)
|
||||
**Pinned commit:** `ac04ea7b0941112e83767cf1c1469b408fa06748` (bump version 0.5.3a, master HEAD Jun 2025)
|
||||
**Vendored at:** 2026-05-27
|
||||
**Echo Core fork version:** `0.5.3a+echo.dave1` (PEP 440 local segment)
|
||||
**Reason:** Discord voice protocol is fragile, upstream is hobby fork. Adapter
|
||||
layer in `src/voice/_discord_voice_adapter.py` isolates upstream churn — if this
|
||||
package breaks, swap to py-cord by rewriting only that file.
|
||||
|
||||
## Update procedure
|
||||
## Echo Core patch: `+echo.dave1` (DAVE E2E receive-side decrypt)
|
||||
|
||||
### Why
|
||||
|
||||
Discord enforces DAVE (E2E media encryption) on voice gateway `v=8` whenever the
|
||||
bot advertises `max_dave_protocol_version > 0` in IDENTIFY. discord.py 2.7.1 (the
|
||||
version Echo Core pins) does so unconditionally — Discord then closes the WS
|
||||
with code **4017** if the bot opts out by sending `max_dave_protocol_version=0`.
|
||||
DAVE is **mandatory**.
|
||||
|
||||
Audio received from a DAVE-active room is **dual-wrapped**: transport layer
|
||||
(`aead_xchacha20_poly1305_rtpsize`) + DAVE E2E. Upstream voice-recv decrypts
|
||||
only the transport layer, then hands DAVE ciphertext to libopus, which raises
|
||||
`OpusError: corrupted stream` on every packet.
|
||||
|
||||
### Patch shape
|
||||
|
||||
~30 lines, all in `discord/ext/voice_recv/reader.py`:
|
||||
|
||||
1. Module-level optional `davey` import (no-op when missing).
|
||||
2. `AudioReader._maybe_dave_decrypt(rtp_packet) -> Optional[bytes]` — gate logic
|
||||
mirrors discord.py 2.7.1 send-side `can_encrypt` exactly. Returns the
|
||||
DAVE-unwrapped payload, the original payload (DAVE inactive), or `None` to
|
||||
drop the packet (unknown SSRC, decrypt failure).
|
||||
3. 4-line hook in `callback()` between transport-decrypt and `feed_rtp`:
|
||||
overwrites `rtp_packet.decrypted_data` in place, or returns early to drop.
|
||||
|
||||
The post-decrypt `is_silence()` check (formerly at reader.py:172) still works
|
||||
because we overwrite `decrypted_data` in place — silence frames produced by
|
||||
davey reach the existing check unchanged.
|
||||
|
||||
### Dependency
|
||||
|
||||
`davey==0.1.5` — matches discord.py 2.7.1 expectation. Pin in
|
||||
`echo-core/requirements.txt`. The import is optional at module level so tests
|
||||
and non-DAVE environments still run; the gate degrades to a bypass.
|
||||
|
||||
### Re-sync strategy
|
||||
|
||||
When upstream voice-recv adds DAVE support natively:
|
||||
|
||||
1. Drop the three patch hunks in `reader.py` (davey import block,
|
||||
`_maybe_dave_decrypt` method, hook in `callback()`).
|
||||
2. Revert `__version__` to upstream value in `__init__.py`.
|
||||
3. Update `Pinned commit` below.
|
||||
4. Run `pytest tests/test_voice_recv_dave.py tests/test_voice_adapter_contract.py`.
|
||||
|
||||
The contract test `test_voice_recv_fork_version` asserts `__version__ ==
|
||||
'0.5.3a+echo.dave1'` and will fail fast on any accidental wipe during a careless
|
||||
upstream sync — forcing a conscious decision to either re-port or drop the
|
||||
patch.
|
||||
|
||||
## Update procedure (vanilla upstream sync)
|
||||
|
||||
```bash
|
||||
cd vendor/discord-ext-voice-recv
|
||||
git fetch origin master
|
||||
git log HEAD..origin/master --oneline # review what changed
|
||||
git checkout <new-commit>
|
||||
# RE-APPLY the +echo.dave1 patch if upstream still lacks DAVE
|
||||
cd ../..
|
||||
source .venv/bin/activate && pip install -e vendor/discord-ext-voice-recv --force-reinstall
|
||||
pytest tests/test_voice_adapter_contract.py -v # MUST PASS — contract guard
|
||||
pytest tests/test_voice_adapter_contract.py tests/test_voice_recv_dave.py -v # MUST PASS — contract + DAVE guards
|
||||
```
|
||||
|
||||
Update this file's `Pinned commit` after a successful upgrade.
|
||||
|
||||
@@ -17,4 +17,4 @@ __title__ = 'discord.ext.voice_recv'
|
||||
__author__ = 'Imayhaveborkedit'
|
||||
__license__ = 'MIT'
|
||||
__copyright__ = 'Copyright 2021-present Imayhaveborkedit'
|
||||
__version__ = '0.5.3a'
|
||||
__version__ = '0.5.3a+echo.dave1'
|
||||
|
||||
@@ -19,6 +19,15 @@ try:
|
||||
except ImportError as e:
|
||||
raise RuntimeError("pynacl is required") from e
|
||||
|
||||
# Echo Core +echo.dave1 patch: DAVE E2E receive-side decrypt. See VENDOR_INFO.md.
|
||||
try:
|
||||
import davey
|
||||
_MEDIA_TYPE_AUDIO = davey.MediaType.audio
|
||||
_HAS_DAVE = True
|
||||
except ImportError:
|
||||
_MEDIA_TYPE_AUDIO = None
|
||||
_HAS_DAVE = False
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Optional, Callable, Any, Dict, Literal, Union
|
||||
|
||||
@@ -133,12 +142,63 @@ class AudioReader:
|
||||
def _is_ip_discovery_packet(self, data: bytes) -> bool:
|
||||
return len(data) == 74 and data[1] == 0x02
|
||||
|
||||
def _maybe_dave_decrypt(self, rtp_packet) -> Optional[bytes]:
|
||||
"""DAVE E2E layer applied after transport decrypt.
|
||||
|
||||
Returns the (possibly DAVE-unwrapped) opus payload, or None to drop the
|
||||
packet. No-op when DAVE is inactive — non-DAVE rooms and environments
|
||||
without `davey` installed pass through unchanged.
|
||||
|
||||
NOTE: `is_silence()` is NOT checked here. In a DAVE-active room the
|
||||
transport-decrypted payload is ciphertext, so `is_silence()` (which
|
||||
compares to plaintext OPUS_SILENCE ``b'\\xf8\\xff\\xfe'``) never matches.
|
||||
Silence frames are handled either by davey.decrypt returning plaintext
|
||||
silence (then caught at the existing post-decrypt silence check on
|
||||
``decrypted_data``), or dropped via the decrypt-raises path. The
|
||||
existing post-decrypt silence check continues to work because we
|
||||
overwrite ``decrypted_data`` in place.
|
||||
"""
|
||||
if not _HAS_DAVE:
|
||||
return rtp_packet.decrypted_data
|
||||
conn = self.voice_client._connection
|
||||
if getattr(conn, 'dave_protocol_version', 0) == 0:
|
||||
return rtp_packet.decrypted_data
|
||||
dave = getattr(conn, 'dave_session', None)
|
||||
if dave is None or not dave.ready:
|
||||
return rtp_packet.decrypted_data
|
||||
user_id = self.voice_client._ssrc_to_id.get(rtp_packet.ssrc)
|
||||
if user_id is None:
|
||||
# ACCEPTED REGRESSION: davey requires per-user key. When SPEAKING
|
||||
# event races behind the first audio packet, we drop 1-5 packets
|
||||
# (~40-200ms) per new speaker per session.
|
||||
return None
|
||||
# can_passthrough(user_id) mirrors Discord's protocol: when this user's
|
||||
# decryptor is in passthrough mode, packets are not DAVE-wrapped and
|
||||
# must be returned as-is. Otherwise davey.decrypt unwraps DAVE E2E.
|
||||
try:
|
||||
if dave.can_passthrough(user_id):
|
||||
return rtp_packet.decrypted_data
|
||||
except Exception as e:
|
||||
log.debug("can_passthrough check failed for ssrc=%s user=%s: %s: %s",
|
||||
rtp_packet.ssrc, user_id, type(e).__name__, e)
|
||||
try:
|
||||
return dave.decrypt(user_id, _MEDIA_TYPE_AUDIO, rtp_packet.decrypted_data)
|
||||
except Exception as e:
|
||||
log.debug("DAVE decrypt failed for ssrc=%s user=%s: %s: %s",
|
||||
rtp_packet.ssrc, user_id, type(e).__name__, e)
|
||||
return None
|
||||
|
||||
def callback(self, packet_data: bytes) -> None:
|
||||
packet = rtp_packet = rtcp_packet = None
|
||||
try:
|
||||
if not rtp.is_rtcp(packet_data):
|
||||
packet = rtp_packet = rtp.decode_rtp(packet_data)
|
||||
packet.decrypted_data = self.decryptor.decrypt_rtp(packet)
|
||||
# Echo Core +echo.dave1: DAVE E2E layer (no-op when inactive).
|
||||
dave_payload = self._maybe_dave_decrypt(rtp_packet)
|
||||
if dave_payload is None:
|
||||
return # drop packet, do not feed_rtp; reader thread stays alive
|
||||
rtp_packet.decrypted_data = dave_payload
|
||||
else:
|
||||
packet = rtcp_packet = rtp.decode_rtcp(self.decryptor.decrypt_rtcp(packet_data))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user