Compare commits
10 Commits
af5af8133f
...
d1bc77e87d
| Author | SHA1 | Date | |
|---|---|---|---|
| d1bc77e87d | |||
| e4f3177fc1 | |||
| 13931db953 | |||
| 23666f7910 | |||
| 217da65417 | |||
| 0cc01c1450 | |||
| c93c4f822e | |||
| 3af6bcaea4 | |||
| a3eefbc799 | |||
| a48562b2f5 |
101
cli.py
101
cli.py
@@ -114,6 +114,104 @@ def _load_sessions_file() -> dict:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _voice_doctor_checks() -> list[tuple[str, bool]]:
    """Run the voice-stack health checks (Pas 10) for `eco doctor`.

    Intended to mirror tools/voice_setup.py (not visible from here — keep the
    two in sync), but reports each check as a ``(label, ok)`` tuple so the
    results can be appended to cmd_doctor's PASS/FAIL list.

    Every check is best-effort: a missing optional dependency, an unreachable
    service or an unreadable file is reported as ``False`` and never raised,
    so the rest of `eco doctor` is unaffected.

    Returns:
        One ``(label, ok)`` tuple per check, in a fixed order: libopus,
        ffmpeg, Supertonic TTS, faster-whisper, silero-vad,
        discord.ext.voice_recv, then one entry per voice asset.
    """
    import importlib.util
    import json as _json
    import urllib.request

    def _importable(module_name: str) -> bool:
        """Return True if *module_name* can be located, without loading it.

        ``find_spec`` is not exception-free: it raises ``ValueError`` when the
        module is already in ``sys.modules`` with ``__spec__`` set to None,
        and ``ModuleNotFoundError`` when the parent package of a dotted name
        is missing. Any failure is treated as "not importable".
        """
        try:
            return importlib.util.find_spec(module_name) is not None
        except Exception:
            return False

    results: list[tuple[str, bool]] = []

    # 1. libopus0 loaded by discord.py
    opus_ok = False
    try:
        import discord

        if not discord.opus.is_loaded():
            try:
                # Private discord.py helper; failing to load is not fatal
                # here — the is_loaded() call below reports the outcome.
                discord.opus._load_default()
            except Exception:
                pass
        opus_ok = bool(discord.opus.is_loaded())
    except Exception:
        # discord.py missing or broken: report FAIL rather than abort.
        opus_ok = False
    results.append(("libopus loaded (discord.py)", opus_ok))

    # 2. ffmpeg in PATH
    results.append(("ffmpeg in PATH", shutil.which("ffmpeg") is not None))

    # 3. Supertonic TTS reachable at http://127.0.0.1:7788/
    # The probe is a real (tiny) synthesis request; only HTTP 200 counts.
    supertonic_url = "http://127.0.0.1:7788/v1/audio/speech"
    supertonic_ok = False
    try:
        payload = _json.dumps({
            "model": "supertonic-3",
            "input": "test",
            "voice": "M2",
            "response_format": "wav",
            "lang": "ro",
        }).encode("utf-8")
        req = urllib.request.Request(
            supertonic_url,
            data=payload,
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        with urllib.request.urlopen(req, timeout=5) as resp:
            supertonic_ok = resp.status == 200
    except Exception:
        # Connection refused, timeout, HTTP error, ... all mean "not OK".
        supertonic_ok = False
    results.append(("Supertonic TTS reachable at :7788", supertonic_ok))

    # 4. faster-whisper importable (don't load model — too slow)
    results.append(("faster-whisper importable", _importable("faster_whisper")))

    # 5. silero-vad importable
    results.append(("silero-vad importable", _importable("silero_vad")))

    # 6. discord.ext.voice_recv importable (vendor package)
    results.append((
        "discord.ext.voice_recv importable",
        _importable("discord.ext.voice_recv"),
    ))

    # 7-9. Voice assets present and non-trivial size
    voice_assets = [
        ("assets/voice/thinking.wav", 1024),
        ("assets/voice/beep_200ms.wav", 512),
        ("assets/voice/mhm.wav", 512),
    ]
    for rel_path, min_bytes in voice_assets:
        path = PROJECT_ROOT / rel_path
        try:
            # is_file() so a directory at the asset path does not pass.
            ok = path.is_file() and path.stat().st_size > min_bytes
        except OSError:
            ok = False
        results.append((f"{rel_path} (>{min_bytes}B)", ok))

    return results
|
||||||
|
|
||||||
|
|
||||||
def cmd_doctor(args):
|
def cmd_doctor(args):
|
||||||
"""Run diagnostic checks."""
|
"""Run diagnostic checks."""
|
||||||
import re
|
import re
|
||||||
@@ -227,6 +325,9 @@ def cmd_doctor(args):
|
|||||||
else:
|
else:
|
||||||
checks.append(("WhatsApp bridge (optional)", True))
|
checks.append(("WhatsApp bridge (optional)", True))
|
||||||
|
|
||||||
|
# ---- Voice stack checks (Pas 10) ----
|
||||||
|
checks.extend(_voice_doctor_checks())
|
||||||
|
|
||||||
# Print results
|
# Print results
|
||||||
all_pass = True
|
all_pass = True
|
||||||
for label, passed in checks:
|
for label, passed in checks:
|
||||||
|
|||||||
@@ -104,6 +104,12 @@
|
|||||||
"ollama": {
|
"ollama": {
|
||||||
"url": "http://10.0.20.161:11434"
|
"url": "http://10.0.20.161:11434"
|
||||||
},
|
},
|
||||||
|
"voice": {
|
||||||
|
"allowed_user_ids": ["949388626146517022"],
|
||||||
|
"user_name": "Marius",
|
||||||
|
"default_voice": "M2",
|
||||||
|
"auto_leave_minutes": 5
|
||||||
|
},
|
||||||
"paths": {
|
"paths": {
|
||||||
"personality": "personality/",
|
"personality": "personality/",
|
||||||
"tools": "tools/",
|
"tools": "tools/",
|
||||||
|
|||||||
@@ -189,4 +189,16 @@ Când lansez sub-agent, îi dau context: AGENTS.md, SOUL.md, USER.md + relevant
|
|||||||
- Discord links: `<url>` pentru a suprima embed-uri
|
- Discord links: `<url>` pentru a suprima embed-uri
|
||||||
- Cand primesc o sarcina mai mare de executat, raspund intotdeauna cu o reactie sau confirmare si apoi trec la executie
|
- Cand primesc o sarcina mai mare de executat, raspund intotdeauna cu o reactie sau confirmare si apoi trec la executie
|
||||||
- **Link-uri:** Folosesc `https://moltbot.tailf7372d.ts.net/echo/` (NU IP 100.120.119.70) pentru ca WhatsApp să le recunoască ca link-uri
|
- **Link-uri:** Folosesc `https://moltbot.tailf7372d.ts.net/echo/` (NU IP 100.120.119.70) pentru ca WhatsApp să le recunoască ca link-uri
|
||||||
- **Link-uri fișiere salvate:** Când salvez/menționez fișiere din `memory/kb/`, ofer automat link către `files.html#memory/kb/path/to/file.md` pentru preview
|
- **Link-uri fișiere salvate:** Când salvez/menționez fișiere din `memory/kb/`, ofer automat link către `files.html#memory/kb/path/to/file.md` pentru preview
|
||||||
|
|
||||||
|
## Voice mode
|
||||||
|
|
||||||
|
Reguli aplicate când `adapter_name == "discord-voice"` — Marius mă ascultă, nu citește. Vocea e intolerantă la lung și la structură.
|
||||||
|
|
||||||
|
- **1-3 propoziții max per răspuns.** Dacă am mai mult de spus, condensez sau mut în chat.
|
||||||
|
- **Fără markdown.** Niciun bold, italic, cod cu backticks, headere. Text plat, atât.
|
||||||
|
- **Fără bullet lists, nici numerotate.** Le pronunț natural ca propoziții: "trei lucruri: în primul rând..., apoi..., și la final..."
|
||||||
|
- **Fără linkuri.** Nu rostesc URL-uri. Dacă e relevant: "îți trimit linkul în chat".
|
||||||
|
- **Numere și valute formulate conversațional.** Scriu "treizeci de lei", nu "30 RON"; "douăzeci și cinci la sută", nu "25%". Modulul `normalize.py` face curățare tehnică, dar eu formulez deja natural — un om vorbește, nu citește tabelul.
|
||||||
|
- **Lung sau structurat → mută în chat.** Dacă răspunsul cere listă, cod, linkuri sau peste 3 propoziții, închei rostit cu "L-am scris în chat." iar restul ajunge în text channel mirror.
|
||||||
|
- **Ton:** cum vorbesc cu Marius la o cafea, nu cum scriu raport. Contracții, pauze, "păi" sau "stai puțin" dacă mă ajută să sune uman. Concis, fără tic-uri robotice.
|
||||||
@@ -63,6 +63,13 @@
|
|||||||
- **Venv:** ~/echo-core/.venv/ | **Model:** base
|
- **Venv:** ~/echo-core/.venv/ | **Model:** base
|
||||||
- **Utilizare:** `whisper.load_model('base').transcribe(path, language='ro')`
|
- **Utilizare:** `whisper.load_model('base').transcribe(path, language='ro')`
|
||||||
|
|
||||||
|
### Discord Voice
|
||||||
|
- **Ce este:** Bot conectat la un voice channel Discord — ascultă microfonul lui Marius, transcrie cu faster-whisper (`small` int8, RO), rutează prin router și răspunde rostit cu Supertonic TTS.
|
||||||
|
- **Cum sunt "în voce":** Slash command `/voice join` mă cheamă în channel; cât stau acolo, presence-ul arată că ascult. `/voice leave` sau auto-leave după 5 minute fără voce.
|
||||||
|
- **Latență așteptată:** ~5 secunde perceput end-to-end (STT p50 2.25s + LLM + TTS first chunk). Peste 3s pornesc un filler audio ("Stai să-mi adun gândurile") ca să nu pară mort.
|
||||||
|
- **Streaming TTS:** răspunsul iese pe clauze, nu cuvânt-cu-cuvânt și nu frază întreagă — primul sunet pleacă imediat ce am o propoziție scurtă.
|
||||||
|
- **Limitări:** 1-3 propoziții max (vezi AGENTS.md § Voice mode). Cuvinte rare, nume proprii sau acronime pot apărea ciudat în STT — dacă sună greșit, cer reformulare în loc să ghicesc.
|
||||||
|
|
||||||
### Pauze respirație
|
### Pauze respirație
|
||||||
- **Script:** `python3 tools/pauza_random.py`
|
- **Script:** `python3 tools/pauza_random.py`
|
||||||
- **Bancă:** memory/kb/tehnici-pauza.md
|
- **Bancă:** memory/kb/tehnici-pauza.md
|
||||||
|
|||||||
@@ -112,6 +112,7 @@ def create_bot(config: Config) -> discord.Client:
|
|||||||
|
|
||||||
intents = discord.Intents.default()
|
intents = discord.Intents.default()
|
||||||
intents.message_content = True
|
intents.message_content = True
|
||||||
|
intents.voice_states = True
|
||||||
|
|
||||||
client = discord.Client(intents=intents)
|
client = discord.Client(intents=intents)
|
||||||
tree = app_commands.CommandTree(client)
|
tree = app_commands.CommandTree(client)
|
||||||
@@ -958,6 +959,11 @@ def create_bot(config: Config) -> discord.Client:
|
|||||||
else:
|
else:
|
||||||
await interaction.followup.send(result or "Eroare TTS.")
|
await interaction.followup.send(result or "Eroare TTS.")
|
||||||
|
|
||||||
|
# Voice slash group (Pas 7)
|
||||||
|
from src.adapters.discord_voice import register as register_voice
|
||||||
|
voice_group = register_voice(tree, client)
|
||||||
|
tree.add_command(voice_group)
|
||||||
|
|
||||||
# --- Ralph commands (autonomous project execution) ---
|
# --- Ralph commands (autonomous project execution) ---
|
||||||
|
|
||||||
async def _autocomplete_by_status(
|
async def _autocomplete_by_status(
|
||||||
@@ -1118,6 +1124,11 @@ def create_bot(config: Config) -> discord.Client:
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
client._ready_at = datetime.now(timezone.utc)
|
client._ready_at = datetime.now(timezone.utc)
|
||||||
logger.info("Echo Core online as %s", client.user)
|
logger.info("Echo Core online as %s", client.user)
|
||||||
|
# Voice models eager warmup (Pas 7)
|
||||||
|
from src.adapters import discord_voice
|
||||||
|
discord_voice._models_warmup_future = asyncio.create_task(
|
||||||
|
discord_voice.warmup_models()
|
||||||
|
)
|
||||||
|
|
||||||
async def _handle_chat(message: discord.Message) -> None:
|
async def _handle_chat(message: discord.Message) -> None:
|
||||||
"""Process a chat message through the router and send the response."""
|
"""Process a chat message through the router and send the response."""
|
||||||
|
|||||||
322
src/adapters/discord_voice.py
Normal file
322
src/adapters/discord_voice.py
Normal file
@@ -0,0 +1,322 @@
|
|||||||
|
"""Discord voice slash commands (Pas 7 — CONVERGENCE wiring).
|
||||||
|
|
||||||
|
Registers the `/voice` slash command group on the existing CommandTree and
|
||||||
|
exposes an async `warmup_models()` for eager model load at bot startup.
|
||||||
|
|
||||||
|
Owns nothing in `src/voice/*` — purely the Discord-facing wiring. Defers
|
||||||
|
heavy lifting to:
|
||||||
|
|
||||||
|
- ``src.voice.pipeline.VoiceSession`` — per-guild session state machine
|
||||||
|
- ``src.voice.pipeline.EchoVoiceSink`` — discord-ext-voice-recv sink
|
||||||
|
- ``src.voice.tts_stream.TTSQueue`` / ``EchoStreamingAudioSource``
|
||||||
|
- ``src.voice._discord_voice_adapter.connect_voice``
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord import app_commands
|
||||||
|
|
||||||
|
# Optional DAVE dep (mandatory at runtime when discord.py 2.7.1 is paired with
|
||||||
|
# Discord voice gateway v=8; tolerated missing in tests / dev environments).
|
||||||
|
try:
|
||||||
|
import davey
|
||||||
|
_HAS_DAVE = True
|
||||||
|
except ImportError:
|
||||||
|
_HAS_DAVE = False
|
||||||
|
|
||||||
|
from src.config import Config
|
||||||
|
from src.voice.pipeline import (
|
||||||
|
VoiceSession,
|
||||||
|
EchoVoiceSink,
|
||||||
|
_get_whisper_model,
|
||||||
|
_get_silero_vad,
|
||||||
|
)
|
||||||
|
from src.voice.tts_stream import TTSQueue, EchoStreamingAudioSource
|
||||||
|
from src.voice._discord_voice_adapter import connect_voice
|
||||||
|
|
||||||
|
log = logging.getLogger("echo-core.discord.voice")
|
||||||
|
|
||||||
|
# Per-guild voice session registry. Key = guild_id.
|
||||||
|
_voice_sessions: dict[int, VoiceSession] = {}
|
||||||
|
|
||||||
|
# Set if model warmup failed; surfaces as ephemeral error on /voice join.
|
||||||
|
_voice_load_error: Optional[str] = None
|
||||||
|
|
||||||
|
# Reference to the eager warmup task created in on_ready, so /voice join can
|
||||||
|
# await it if the user is faster than the background load.
|
||||||
|
_models_warmup_future: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
|
|
||||||
|
async def warmup_models() -> None:
    """Load libopus and the voice models ahead of the first `/voice join`.

    Meant to be scheduled as a background task at bot startup. Each model
    loader is pushed onto a worker thread with ``asyncio.to_thread`` so the
    event loop is not held while it runs.

    This coroutine does not propagate ``Exception``: any failure is rendered
    as ``"<ExceptionType>: <message>"`` into the module-level
    ``_voice_load_error`` and logged, so callers can inspect that variable
    instead of handling an exception.
    """
    global _voice_load_error
    try:
        opus_ready = discord.opus.is_loaded()
        if not opus_ready:
            # Explicit soname; assumes a Linux-style libopus install.
            discord.opus.load_opus("libopus.so.0")

        if _HAS_DAVE:
            log.info(
                "DAVE protocol v%d available (davey %s)",
                davey.DAVE_PROTOCOL_VERSION,
                davey.__version__,
            )

        # Whisper first, then the VAD — one after the other, off-loop.
        for loader in (_get_whisper_model, _get_silero_vad):
            await asyncio.to_thread(loader)

        log.info("Voice models warm")
    except Exception as exc:
        _voice_load_error = f"{type(exc).__name__}: {exc}"
        log.error("Voice models load failed: %s", _voice_load_error)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_whitelist() -> set[int]:
|
||||||
|
"""Read `voice.allowed_user_ids` from config and coerce to int set.
|
||||||
|
|
||||||
|
Re-reads config from disk to pick up any runtime edits between bot start
|
||||||
|
and /voice join.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
raw = Config().get("voice.allowed_user_ids", [])
|
||||||
|
except Exception:
|
||||||
|
raw = []
|
||||||
|
out: set[int] = set()
|
||||||
|
for v in raw or []:
|
||||||
|
try:
|
||||||
|
out.add(int(v))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
continue
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _get_default_voice() -> str:
|
||||||
|
try:
|
||||||
|
return Config().get("voice.default_voice", "M2") or "M2"
|
||||||
|
except Exception:
|
||||||
|
return "M2"
|
||||||
|
|
||||||
|
|
||||||
|
def register(tree: app_commands.CommandTree, bot: discord.Client) -> app_commands.Group:
    """Build the `/voice` slash command group and return it (caller registers).

    The group is NOT added to *tree* here; the caller is expected to call
    ``tree.add_command(...)`` with the returned group.

    Commands defined on the group:

    - ``/voice join``   — connect to the caller's voice channel and start a session
    - ``/voice leave``  — tear down the guild's session
    - ``/voice doctor`` — short libopus / warmup status report
    - ``/voice mirror on|off`` — toggle ``session.mirror_enabled``
    - ``/voice record on|off`` — toggle ``session.record_enabled``

    Args:
        tree: The bot's command tree. Currently unused; kept for the caller's
            existing call signature.
        bot: The Discord client; used for presence updates, the bot's own
            user id, and passed into ``VoiceSession``.

    Returns:
        The populated ``voice`` command group.
    """
    voice_group = app_commands.Group(
        name="voice", description="Echo Core voice channel"
    )

    async def _set_session_flag(
        interaction: discord.Interaction, attr: str, enabled: bool, reply: str
    ) -> None:
        """Set boolean *attr* on this guild's active session and confirm.

        Shared by the mirror/record on|off commands. Replies with
        "Nu sunt în voice." when the guild has no active session.
        NOTE(review): no whitelist check is applied here, so any member who
        can invoke the command can toggle these flags — confirm intended.
        """
        guild_id = interaction.guild.id if interaction.guild else None
        s = _voice_sessions.get(guild_id) if guild_id is not None else None
        if s is None:
            await interaction.followup.send("Nu sunt în voice.", ephemeral=True)
            return
        setattr(s, attr, enabled)
        await interaction.followup.send(reply, ephemeral=True)

    @voice_group.command(name="join", description="Echo intră în voice channel-ul tău")
    async def join(interaction: discord.Interaction) -> None:
        await interaction.response.defer(ephemeral=True)

        # Wait for an in-flight warmup BEFORE looking at _voice_load_error:
        # warmup_models() records failures in that variable instead of
        # raising, so checking it first would miss a failure that lands
        # while we are awaiting the task.
        if _models_warmup_future is not None and not _models_warmup_future.done():
            try:
                await _models_warmup_future
            except Exception as e:
                await interaction.followup.send(
                    f"Voice unavailable: {type(e).__name__}: {e}", ephemeral=True
                )
                return
        if _voice_load_error:
            await interaction.followup.send(
                f"Voice unavailable: {_voice_load_error}", ephemeral=True
            )
            return

        user = interaction.user
        if not isinstance(user, discord.Member) or user.voice is None or user.voice.channel is None:
            await interaction.followup.send(
                "Intră într-un voice channel întâi.", ephemeral=True
            )
            return
        channel = user.voice.channel

        whitelist = _get_whitelist()
        if user.id not in whitelist:
            await interaction.followup.send(
                "Nu ești pe whitelist voice.", ephemeral=True
            )
            return

        # Reject double-join on the same guild.
        guild_id = channel.guild.id
        if guild_id in _voice_sessions:
            await interaction.followup.send(
                "Sunt deja în voice pe acest server. Folosește /voice leave întâi.",
                ephemeral=True,
            )
            return

        # Connect
        try:
            vc = await connect_voice(channel)
        except Exception as e:
            log.exception("connect_voice failed")
            await interaction.followup.send(
                f"Conectare eșuată: {type(e).__name__}: {e}", ephemeral=True
            )
            return

        # Build TTS queue + session
        ttsq = TTSQueue(voice_id=_get_default_voice(), lang="ro")
        ttsq.start()
        try:
            session = VoiceSession(
                channel_id=channel.id,
                guild_id=guild_id,
                voice_client=vc,
                text_channel=interaction.channel,
                record_enabled=False,
                mirror_enabled=True,
                whitelist=whitelist,
                ttsq=ttsq,
                bot=bot,
                loop=asyncio.get_running_loop(),
            )
        except Exception as e:
            log.exception("VoiceSession construction failed")
            # Undo what was started above: stop the queue, drop the connection.
            ttsq.stop()
            try:
                await vc.disconnect(force=True)
            except Exception:
                pass
            await interaction.followup.send(
                f"Sesiune voice eșuată: {type(e).__name__}: {e}", ephemeral=True
            )
            return
        _voice_sessions[guild_id] = session

        # Start TTS streaming source for the entire session. Chain the
        # wake-up beep via `after=` so streaming takes over when beep ends.
        def _start_stream(error: Optional[Exception] = None) -> None:
            if error is not None:
                log.warning("Beep playback ended with error: %s", error)
            try:
                vc.play(EchoStreamingAudioSource(ttsq))
                log.info("TTS streaming source attached")
            except Exception:
                log.exception("EchoStreamingAudioSource attach failed")

        try:
            # NOTE: relative path — resolved against the process working dir.
            vc.play(
                discord.FFmpegPCMAudio("assets/voice/beep_200ms.wav"),
                after=_start_stream,
            )
        except Exception:
            log.warning("Beep playback skipped, starting stream directly", exc_info=True)
            _start_stream()

        # Attach sink
        try:
            bot_user_id = int(bot.user.id) if bot.user is not None else 0
            sink = EchoVoiceSink(session=session, bot_user_id=bot_user_id)
            vc.listen(sink)
        except Exception as e:
            log.exception("Sink attach failed")
            _voice_sessions.pop(guild_id, None)
            try:
                session.cleanup("sink_attach_failed")
            except Exception:
                pass
            await interaction.followup.send(
                f"Atașare sink eșuată: {type(e).__name__}: {e}", ephemeral=True
            )
            return

        # Presence (cosmetic — a failure here must not fail the join)
        try:
            await bot.change_presence(activity=discord.Activity(
                type=discord.ActivityType.listening,
                name=f"{user.display_name} în #{channel.name}",
            ))
        except Exception:
            log.warning("Presence update skipped", exc_info=True)

        await interaction.followup.send(
            f"În voce în #{channel.name}.", ephemeral=True
        )

    @voice_group.command(name="leave", description="Echo iese din voice channel")
    async def leave(interaction: discord.Interaction) -> None:
        await interaction.response.defer(ephemeral=True)
        guild_id = interaction.guild.id if interaction.guild else None
        # Pop first so the registry entry is gone even if cleanup raises.
        session = _voice_sessions.pop(guild_id, None) if guild_id is not None else None
        if session is None:
            await interaction.followup.send(
                "Nu sunt în niciun voice channel aici.", ephemeral=True
            )
            return
        try:
            session.cleanup("user_leave")
        except Exception:
            log.exception("session.cleanup raised")
        try:
            await bot.change_presence(activity=None)
        except Exception:
            log.warning("Presence reset skipped", exc_info=True)
        await interaction.followup.send("Plecat.", ephemeral=True)

    @voice_group.command(name="doctor", description="Verifică voice stack")
    async def doctor(interaction: discord.Interaction) -> None:
        await interaction.response.defer(ephemeral=True)
        checks: list[tuple[str, bool]] = []

        # libopus
        try:
            checks.append(("libopus", bool(discord.opus.is_loaded())))
        except Exception:
            checks.append(("libopus", False))

        # warmup: OK means no load error has been recorded
        checks.append(("voice load error", _voice_load_error is None))

        # Build response
        lines = ["**Voice doctor:**"]
        for label, ok in checks:
            lines.append(f"{'OK' if ok else 'FAIL'} — {label}")
        if _voice_load_error:
            lines.append(f" details: {_voice_load_error}")
        await interaction.followup.send("\n".join(lines), ephemeral=True)

    # --- /voice mirror on|off ---
    mirror_group = app_commands.Group(
        name="mirror", description="Text mirror", parent=voice_group
    )

    @mirror_group.command(name="on", description="Activează text mirror în canal")
    async def mirror_on(interaction: discord.Interaction) -> None:
        await interaction.response.defer(ephemeral=True)
        await _set_session_flag(interaction, "mirror_enabled", True, "Mirror ON.")

    @mirror_group.command(name="off", description="Dezactivează text mirror")
    async def mirror_off(interaction: discord.Interaction) -> None:
        await interaction.response.defer(ephemeral=True)
        await _set_session_flag(interaction, "mirror_enabled", False, "Mirror OFF.")

    # --- /voice record on|off ---
    record_group = app_commands.Group(
        name="record", description="KB recording", parent=voice_group
    )

    @record_group.command(name="on", description="Activează înregistrare în KB")
    async def record_on(interaction: discord.Interaction) -> None:
        await interaction.response.defer(ephemeral=True)
        await _set_session_flag(interaction, "record_enabled", True, "Record ON.")

    @record_group.command(name="off", description="Dezactivează înregistrare")
    async def record_off(interaction: discord.Interaction) -> None:
        await interaction.response.defer(ephemeral=True)
        await _set_session_flag(interaction, "record_enabled", False, "Record OFF.")

    return voice_group
|
||||||
@@ -37,6 +37,42 @@ DEFAULT_TIMEOUT = 300 # seconds
|
|||||||
|
|
||||||
CLAUDE_BIN = os.environ.get("CLAUDE_BIN", "claude")
|
CLAUDE_BIN = os.environ.get("CLAUDE_BIN", "claude")
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Per-channel mutex for send_message
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
#
|
||||||
|
# Two paths can hit `send_message(channel_id, ...)` concurrently for the same
|
||||||
|
# channel: a text adapter (Discord/Telegram/WhatsApp) and the voice adapter
|
||||||
|
# (`adapter_name="discord-voice"`). The underlying Claude CLI subprocess is
|
||||||
|
# blocking (`subprocess.Popen` with stream-json read loop) and stateful via
|
||||||
|
# `--resume <session_id>` — interleaving two concurrent invocations on the
|
||||||
|
# same channel would corrupt the conversation order.
|
||||||
|
#
|
||||||
|
# We use `threading.Lock` (NOT `asyncio.Lock`) because `send_message` is sync
|
||||||
|
# code typically run from `asyncio.to_thread` in async adapters. asyncio.Lock
|
||||||
|
# only serializes coroutines, not threads — it would NOT protect this path.
|
||||||
|
#
|
||||||
|
# Each channel gets its own lock so DIFFERENT channels still run in parallel.
|
||||||
|
# Locks are created lazily on first use; the dict itself is guarded by a
|
||||||
|
# small bootstrap lock so two concurrent first-uses don't race on creation.
|
||||||
|
_session_locks: dict[str, threading.Lock] = {}
|
||||||
|
_session_locks_bootstrap = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_session_lock(channel_id: str) -> threading.Lock:
|
||||||
|
"""Return the channel's mutex, creating it on first access.
|
||||||
|
|
||||||
|
Two threads racing to create the same channel's lock would otherwise
|
||||||
|
end up with different lock objects (setdefault is not atomic across
|
||||||
|
the read-modify-write under all interpreter conditions — defensive).
|
||||||
|
"""
|
||||||
|
lock = _session_locks.get(channel_id)
|
||||||
|
if lock is not None:
|
||||||
|
return lock
|
||||||
|
with _session_locks_bootstrap:
|
||||||
|
return _session_locks.setdefault(channel_id, threading.Lock())
|
||||||
|
|
||||||
|
|
||||||
PERSONALITY_FILES = [
|
PERSONALITY_FILES = [
|
||||||
"IDENTITY.md",
|
"IDENTITY.md",
|
||||||
"SOUL.md",
|
"SOUL.md",
|
||||||
@@ -543,19 +579,28 @@ def send_message(
|
|||||||
timeout: int = DEFAULT_TIMEOUT,
|
timeout: int = DEFAULT_TIMEOUT,
|
||||||
on_text: Callable[[str], None] | None = None,
|
on_text: Callable[[str], None] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""High-level convenience: auto start or resume based on channel state."""
|
"""High-level convenience: auto start or resume based on channel state.
|
||||||
session = get_active_session(channel_id)
|
|
||||||
# Only resume if session has a valid session_id (not a pre-set model placeholder)
|
Concurrency: a per-`channel_id` `threading.Lock` serializes invocations
|
||||||
if session is not None and session.get("session_id"):
|
that hit the same channel (e.g. text adapter + voice adapter racing on
|
||||||
return resume_session(session["session_id"], message, timeout, on_text=on_text)
|
the same Discord guild text channel). Different channels run in
|
||||||
# Use model from pre-set session if available, otherwise use provided model
|
parallel — each holds its own lock. Lock is acquired blocking; we rely
|
||||||
effective_model = model
|
on `timeout` (default 5 minutes) to bound the worst case rather than
|
||||||
if session is not None and session.get("model"):
|
a non-blocking acquire (loss of fairness vs adapter-side queueing).
|
||||||
effective_model = session["model"]
|
"""
|
||||||
response_text, _session_id = start_session(
|
with _get_session_lock(channel_id):
|
||||||
channel_id, message, effective_model, timeout, on_text=on_text
|
session = get_active_session(channel_id)
|
||||||
)
|
# Only resume if session has a valid session_id (not a pre-set model placeholder)
|
||||||
return response_text
|
if session is not None and session.get("session_id"):
|
||||||
|
return resume_session(session["session_id"], message, timeout, on_text=on_text)
|
||||||
|
# Use model from pre-set session if available, otherwise use provided model
|
||||||
|
effective_model = model
|
||||||
|
if session is not None and session.get("model"):
|
||||||
|
effective_model = session["model"]
|
||||||
|
response_text, _session_id = start_session(
|
||||||
|
channel_id, message, effective_model, timeout, on_text=on_text
|
||||||
|
)
|
||||||
|
return response_text
|
||||||
|
|
||||||
|
|
||||||
def clear_session(channel_id: str) -> bool:
|
def clear_session(channel_id: str) -> bool:
|
||||||
|
|||||||
@@ -154,8 +154,17 @@ def route_message(
|
|||||||
channel_cfg = _get_channel_config(channel_id)
|
channel_cfg = _get_channel_config(channel_id)
|
||||||
model = (channel_cfg or {}).get("default_model") or _get_config().get("bot.default_model", "sonnet")
|
model = (channel_cfg or {}).get("default_model") or _get_config().get("bot.default_model", "sonnet")
|
||||||
|
|
||||||
|
# Voice-mode augment: prepend speaker prefix so Claude knows who spoke
|
||||||
|
# in a voice channel. Cheap now, future-proof for multi-speaker later.
|
||||||
|
# (Engineering decision #14 in the plan.) Only the discord-voice adapter
|
||||||
|
# triggers it — text adapters keep the message verbatim.
|
||||||
|
claude_text = text
|
||||||
|
if adapter_name == "discord-voice":
|
||||||
|
user_name = _get_config().get("voice.user_name", "user") or "user"
|
||||||
|
claude_text = f"[speaker:{user_name}] {text}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = send_message(channel_id, text, model=model, on_text=on_text)
|
response = send_message(channel_id, claude_text, model=model, on_text=on_text)
|
||||||
_set_last_response(channel_id, response)
|
_set_last_response(channel_id, response)
|
||||||
return response, False
|
return response, False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
67
src/voice/_discord_voice_adapter.py
Normal file
67
src/voice/_discord_voice_adapter.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Adapter layer over `discord-ext-voice-recv` (vendored at vendor/).
|
||||||
|
|
||||||
|
If discord-ext-voice-recv breaks, swap to py-cord by rewriting only this file.
|
||||||
|
Contract test in tests/test_voice_adapter_contract.py guards drift.
|
||||||
|
|
||||||
|
Downstream consumers (`src/voice/*`, `src/adapters/discord_voice.py`) MUST
|
||||||
|
import from this file — never from `discord.ext.voice_recv` directly.
|
||||||
|
|
||||||
|
## Public API surface (stable across upstream changes)
|
||||||
|
|
||||||
|
- ``VoiceReceiveClient`` — alias for ``voice_recv.VoiceRecvClient``. Subclass
|
||||||
|
of ``discord.VoiceClient`` with extra audio-receive plumbing.
|
||||||
|
Key methods used by the pipeline:
|
||||||
|
* ``await client.disconnect(force: bool = False)`` (from discord.VoiceClient)
|
||||||
|
* ``client.listen(sink, *, after=None)`` — attach an ``AudioSink``;
|
||||||
|
raises ``discord.ClientException`` if not connected or already listening
|
||||||
|
* ``client.stop_listening()`` — detach the current sink
|
||||||
|
* ``client.is_listening() -> bool``
|
||||||
|
* ``client.stop()`` — stop both playing and listening
|
||||||
|
* ``client.sink`` (property, getter+setter) — swap the active sink in place
|
||||||
|
|
||||||
|
- ``AudioSink`` — abstract base. Subclasses MUST implement:
|
||||||
|
* ``write(user: Optional[discord.User|Member], data: VoiceData) -> None``
|
||||||
|
* ``wants_opus() -> bool`` (True → receive opus bytes; False → receive PCM)
|
||||||
|
* ``cleanup() -> None``
|
||||||
|
|
||||||
|
- ``VoiceData`` — per-packet container. Slots: ``packet``, ``source``, ``pcm``.
|
||||||
|
``.pcm`` is decoded 48kHz s16le stereo bytes when ``wants_opus()`` is False.
|
||||||
|
``.opus`` property returns the raw opus bytes from the underlying RTP packet.
|
||||||
|
|
||||||
|
- ``connect_voice(channel) -> VoiceReceiveClient`` — async helper, returns a
|
||||||
|
connected receive-capable voice client.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from discord.ext import voice_recv
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import discord
|
||||||
|
|
||||||
|
|
||||||
|
# --- Stable re-exports -------------------------------------------------------
|
||||||
|
|
||||||
|
VoiceReceiveClient = voice_recv.VoiceRecvClient
|
||||||
|
AudioSink = voice_recv.AudioSink
|
||||||
|
VoiceData = voice_recv.VoiceData
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"VoiceReceiveClient",
|
||||||
|
"AudioSink",
|
||||||
|
"VoiceData",
|
||||||
|
"connect_voice",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def connect_voice(channel: "discord.VoiceChannel") -> VoiceReceiveClient:
    """Join ``channel`` using the receive-capable voice client class.

    Awaits ``channel.connect(cls=VoiceReceiveClient)`` and hands back the
    result, so callers never need to import the vendored client class
    themselves.
    """
    client = await channel.connect(cls=VoiceReceiveClient)
    return client
|
||||||
267
src/voice/normalize.py
Normal file
267
src/voice/normalize.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
"""Voice mode text normalization for Romanian TTS.
|
||||||
|
|
||||||
|
Pure functions — no side effects, no I/O, no logging. Strip markdown,
|
||||||
|
expand numbers / currency / symbols / abbreviations into natural-sounding
|
||||||
|
Romanian text. See plan: src/voice/normalize.py (Pas 3).
|
||||||
|
|
||||||
|
Pipeline order in normalize_for_tts:
|
||||||
|
strip_markdown -> sanitize_punctuation -> expand_abbreviations -> expand_time
-> expand_currency -> expand_numbers_ro -> expand_symbols -> truncate(200 words)
|
||||||
|
|
||||||
|
Currency runs BEFORE generic number expansion so "12.50 RON" becomes
|
||||||
|
"doisprezece lei și cincizeci de bani" rather than
|
||||||
|
"doisprezece virgulă cincizeci RON".
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
|
||||||
|
from num2words import num2words
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- Markdown ----------
|
||||||
|
|
||||||
|
_MARKDOWN_LINK = re.compile(r'\[([^\]]+)\]\([^)]+\)')
|
||||||
|
_MARKDOWN_BOLD = re.compile(r'\*\*([^*]+)\*\*')
|
||||||
|
_MARKDOWN_CODE = re.compile(r'`([^`\n]+)`')
|
||||||
|
_MARKDOWN_ITALIC = re.compile(r'(?<!\*)\*([^*\n]+)\*(?!\*)')
|
||||||
|
_MARKDOWN_HEADING = re.compile(r'^[ \t]*#{1,6}[ \t]+', re.MULTILINE)
|
||||||
|
_MARKDOWN_LIST = re.compile(r'^[ \t]*[-*+][ \t]+', re.MULTILINE)
|
||||||
|
|
||||||
|
|
||||||
|
def strip_markdown(text: str) -> str:
    """Remove common markdown formatting, preserve the visible content.

    Links are reduced to their label, then line-anchored markers (headings,
    list bullets) are removed, and only then inline markers (bold, inline
    code, italic).

    Line-anchored markers must go first: a ``*`` bullet followed by an
    emphasised word on the same line (``* one *two*``) would otherwise be
    read by the italic pattern as the opening delimiter of ``* one *``,
    leaving a stray ``*`` in the spoken text.
    """
    text = _MARKDOWN_LINK.sub(r'\1', text)
    # Block-level markers: both patterns are anchored to line start and the
    # list pattern requires whitespace after the bullet, so they cannot eat
    # the opening delimiter of ``**bold**`` or ``*italic*``.
    text = _MARKDOWN_HEADING.sub('', text)
    text = _MARKDOWN_LIST.sub('', text)
    # Inline markers: bold before italic so ``**`` is never split into two
    # single-asterisk delimiters.
    text = _MARKDOWN_BOLD.sub(r'\1', text)
    text = _MARKDOWN_CODE.sub(r'\1', text)
    text = _MARKDOWN_ITALIC.sub(r'\1', text)
    return text
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- Number helpers ----------
|
||||||
|
|
||||||
|
def _needs_de(n: int) -> bool:
|
||||||
|
"""Romanian: insert 'de' between numeral and noun for n >= 20,
|
||||||
|
except when the trailing 1-19 portion makes it sound off
|
||||||
|
(e.g., 105, 119 -> no 'de'; 120, 200 -> 'de').
|
||||||
|
"""
|
||||||
|
if n < 20:
|
||||||
|
return False
|
||||||
|
last = n % 100
|
||||||
|
if 1 <= last <= 19:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _int_to_ro(n: int) -> str:
    """Spell out the integer ``n`` by calling ``num2words`` with ``lang='ro'``."""
    words = num2words(n, lang='ro')
    return words
|
||||||
|
|
||||||
|
|
||||||
|
def _decimal_to_ro(s: str) -> str:
    """Spell a decimal string ``'X.Y'`` in Romanian words.

    The fractional digits are normally read as one number
    ('3.14' -> 'trei virgulă paisprezece'). When they start with a zero and
    there is more than one of them, each digit is read on its own
    ('3.05' -> 'trei virgulă zero cinci') so the magnitude is not lost.
    """
    whole_digits, frac_digits = s.split('.', 1)
    whole_words = _int_to_ro(int(whole_digits))
    digit_by_digit = len(frac_digits) > 1 and frac_digits.startswith('0')
    if digit_by_digit:
        frac_words = ' '.join([_int_to_ro(int(ch)) for ch in frac_digits])
    else:
        frac_words = _int_to_ro(int(frac_digits))
    return whole_words + " virgulă " + frac_words
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- Numbers ----------
|
||||||
|
|
||||||
|
_NUM_TOKEN = re.compile(r'(?<!\w)(\d+(?:\.\d+)?)(?!\w)')
|
||||||
|
|
||||||
|
|
||||||
|
def expand_numbers_ro(text: str) -> str:
    """Replace every standalone numeric token in ``text`` with Romanian words.

    A token is whatever ``_NUM_TOKEN`` matches: digits with an optional
    ``.digits`` part, not adjacent to word characters. Tokens containing a
    dot are spoken with the 'virgulă' separator. Amounts tied to a currency
    are expected to have been rewritten by ``expand_currency`` already.
    """
    def _spell(found: re.Match) -> str:
        literal = found.group(1)
        return _decimal_to_ro(literal) if '.' in literal else _int_to_ro(int(literal))

    return _NUM_TOKEN.sub(_spell, text)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- Time ----------
|
||||||
|
|
||||||
|
_TIME_PATTERN = re.compile(r'(?<!\d)([01]?\d|2[0-3]):([0-5]?\d)(?!\d)')
|
||||||
|
|
||||||
|
|
||||||
|
def _format_minutes_ro(n: int) -> str:
|
||||||
|
"""Romanian-correct feminine forms for minute counts (0-59)."""
|
||||||
|
if n == 1:
|
||||||
|
return "un minut"
|
||||||
|
if n == 2:
|
||||||
|
return "două minute"
|
||||||
|
if n < 20:
|
||||||
|
return f"{_int_to_ro(n)} minute"
|
||||||
|
last = n % 10
|
||||||
|
rest = n - last
|
||||||
|
if last == 0:
|
||||||
|
return f"{_int_to_ro(n)} de minute"
|
||||||
|
if last == 1:
|
||||||
|
return f"{_int_to_ro(rest)} și una de minute"
|
||||||
|
if last == 2:
|
||||||
|
return f"{_int_to_ro(rest)} și două de minute"
|
||||||
|
return f"{_int_to_ro(rest)} și {_int_to_ro(last)} de minute"
|
||||||
|
|
||||||
|
|
||||||
|
def expand_time(text: str) -> str:
    """Rewrite ``HH:MM`` clock times as colloquial Romanian.

    23:09 -> "douăzeci și trei și nouă minute"
    23:00 -> "douăzeci și trei fix"
    """
    def _speak(found: re.Match) -> str:
        hours, minutes = (int(part) for part in found.group(1, 2))
        hour_words = _int_to_ro(hours)
        if minutes:
            return f"{hour_words} și {_format_minutes_ro(minutes)}"
        # Whole hour: no minute part, just "fix".
        return f"{hour_words} fix"

    return _TIME_PATTERN.sub(_speak, text)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- Currency ----------
|
||||||
|
|
||||||
|
_CURRENCY_MAIN = {
|
||||||
|
'RON': ('leu', 'lei'),
|
||||||
|
'USD': ('dolar', 'dolari'),
|
||||||
|
'EUR': ('euro', 'euro'),
|
||||||
|
'GBP': ('liră', 'lire'),
|
||||||
|
}
|
||||||
|
|
||||||
|
_CURRENCY_SUB = {
|
||||||
|
'RON': ('ban', 'bani'),
|
||||||
|
'USD': ('cent', 'cenți'),
|
||||||
|
'EUR': ('cent', 'cenți'),
|
||||||
|
'GBP': ('penny', 'pence'),
|
||||||
|
}
|
||||||
|
|
||||||
|
_CURRENCY_PATTERNS = [
|
||||||
|
# RON suffix (case-insensitive: RON, ron, lei)
|
||||||
|
(re.compile(r'(?<!\w)(\d+(?:\.\d+)?)\s+(?:RON|lei)\b', re.IGNORECASE), 'RON'),
|
||||||
|
# Prefix currencies
|
||||||
|
(re.compile(r'\$(\d+(?:\.\d+)?)'), 'USD'),
|
||||||
|
(re.compile(r'€(\d+(?:\.\d+)?)'), 'EUR'),
|
||||||
|
(re.compile(r'£(\d+(?:\.\d+)?)'), 'GBP'),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _format_currency_unit(n: int, singular: str, plural: str) -> str:
|
||||||
|
"""Format integer amount + currency noun with proper RO singular/plural
|
||||||
|
and 'de' particle. Uses 'un' (article) for n=1, not 'unu' (cardinal).
|
||||||
|
"""
|
||||||
|
if n == 1:
|
||||||
|
return f"un {singular}"
|
||||||
|
word = _int_to_ro(n)
|
||||||
|
if _needs_de(n):
|
||||||
|
return f"{word} de {plural}"
|
||||||
|
return f"{word} {plural}"
|
||||||
|
|
||||||
|
|
||||||
|
def _format_currency(amount: str, code: str) -> str:
    """Spell a numeric ``amount`` string of currency ``code`` in Romanian.

    Integer amounts use only the main unit. Decimal amounts have their
    fractional part normalised to two digits (one digit is padded with a
    zero, extra digits are cut off) and read as the sub-unit; a zero
    fraction is dropped.
    """
    main_sg, main_pl = _CURRENCY_MAIN[code]
    if '.' not in amount:
        return _format_currency_unit(int(amount), main_sg, main_pl)

    whole_digits, frac_digits = amount.split('.', 1)
    if len(frac_digits) == 1:
        # "12.5" means 50 sub-units, not 5.
        frac_digits += '0'
    else:
        frac_digits = frac_digits[:2]
    whole_value = int(whole_digits)
    frac_value = int(frac_digits)

    spoken_whole = _format_currency_unit(whole_value, main_sg, main_pl)
    if frac_value == 0:
        return spoken_whole
    sub_sg, sub_pl = _CURRENCY_SUB[code]
    spoken_frac = _format_currency_unit(frac_value, sub_sg, sub_pl)
    return f"{spoken_whole} și {spoken_frac}"
|
||||||
|
|
||||||
|
|
||||||
|
def expand_currency(text: str) -> str:
    """Rewrite currency amounts in ``text`` as natural Romanian.

    Applies each ``(pattern, code)`` pair of ``_CURRENCY_PATTERNS`` in order:
    the ``<n> RON`` / ``<n> lei`` suffix form and the ``$``, ``€``, ``£``
    prefix forms. A fractional part is spoken as the sub-unit
    (bani / cenți / pence).
    """
    for matcher, currency in _CURRENCY_PATTERNS:
        # Bind the currency code as a default so each callback keeps its own.
        def _speak(found, _code=currency):
            return _format_currency(found.group(1), _code)

        text = matcher.sub(_speak, text)
    return text
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- Symbols ----------
|
||||||
|
|
||||||
|
def expand_symbols(text: str) -> str:
    """Swap ``%``, ``&``, ``@`` and ``°`` for their spoken Romanian form.

    Whitespace runs are then collapsed to a single space and the result is
    stripped at both ends.
    """
    # One translate pass equals the chained replaces: none of the spoken
    # forms contains any of the replaced symbols.
    spoken = str.maketrans({
        '%': ' la sută',
        '&': ' și ',
        '@': ' la ',
        '°': ' grade',
    })
    collapsed = re.sub(r'\s+', ' ', text.translate(spoken))
    return collapsed.strip()
|
||||||
|
|
||||||
|
|
||||||
|
from tools.tts import sanitize_for_supertonic as sanitize_punctuation
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- Abbreviations ----------
|
||||||
|
|
||||||
|
# Longer patterns first so 'ș.a.m.d.' wins over 'ș.a.'
|
||||||
|
_ABBREVIATIONS = [
|
||||||
|
(re.compile(r'(?<!\w)[șş]\.a\.m\.d\.', re.IGNORECASE), 'și așa mai departe'),
|
||||||
|
(re.compile(r'(?<!\w)[șş]\.a\.', re.IGNORECASE), 'și altele'),
|
||||||
|
(re.compile(r'(?<!\w)etc\.', re.IGNORECASE), 'etcetera'),
|
||||||
|
(re.compile(r'(?<!\w)dl\.', re.IGNORECASE), 'domnul'),
|
||||||
|
(re.compile(r'(?<!\w)dna\.', re.IGNORECASE), 'doamna'),
|
||||||
|
(re.compile(r'(?<!\w)nr\.', re.IGNORECASE), 'numărul'),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def expand_abbreviations(text: str) -> str:
    """Spell out the Romanian abbreviations listed in ``_ABBREVIATIONS``.

    The ``(pattern, replacement)`` pairs are applied in list order.
    """
    expanded = text
    for short_form, spoken in _ABBREVIATIONS:
        expanded = short_form.sub(spoken, expanded)
    return expanded
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- Top-level pipeline ----------
|
||||||
|
|
||||||
|
_MAX_WORDS = 200
|
||||||
|
_TRUNCATE_SUFFIX = "Restul l-am scris în chat."
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_for_tts(text: str) -> str:
    """Run the whole normalization pipeline, then cap the result at 200 words.

    Stages run in this order: markdown stripping, punctuation sanitizing,
    abbreviations, clock times, currency, bare numbers, symbols. When the
    result is longer than ``_MAX_WORDS`` words, only the first ``_MAX_WORDS``
    are kept and "Restul l-am scris în chat." is appended so the listener
    knows the rest is in the text channel.
    """
    stages = (
        strip_markdown,
        sanitize_punctuation,
        expand_abbreviations,
        expand_time,
        expand_currency,
        expand_numbers_ro,
        expand_symbols,
    )
    for stage in stages:
        text = stage(text)

    tokens = text.split()
    if len(tokens) <= _MAX_WORDS:
        return text.strip()
    kept = ' '.join(tokens[:_MAX_WORDS])
    return f"{kept} {_TRUNCATE_SUFFIX}".strip()
|
||||||
580
src/voice/pipeline.py
Normal file
580
src/voice/pipeline.py
Normal file
@@ -0,0 +1,580 @@
|
|||||||
|
"""Central voice pipeline: VAD -> STT -> Claude -> TTS for Discord voice.
|
||||||
|
|
||||||
|
``VoiceSession`` binds per-call state — voice_client, TTS queue, transcript
|
||||||
|
JSONL buffer, whitelist, presence — and exposes a single idempotent
|
||||||
|
``cleanup()`` invoked from every exit path (user /voice leave, network
|
||||||
|
disconnect, crash via ``__exit__``, auto-leave timer, user leaves channel).
|
||||||
|
|
||||||
|
``EchoVoiceSink`` is the discord-ext-voice-recv ``AudioSink`` subclass that
|
||||||
|
runs in the voice_recv reader thread. It batches 20ms PCM packets into
|
||||||
|
100ms windows for silero-vad inference, marks per-user speech timestamps,
|
||||||
|
and on 800ms cumulative silence flushes the accumulated audio through
|
||||||
|
faster-whisper. Hallucinated segments (``no_speech_prob > 0.6``) are
|
||||||
|
dropped. Valid transcripts are scheduled onto the session's event loop
|
||||||
|
via ``asyncio.run_coroutine_threadsafe``.
|
||||||
|
|
||||||
|
The bot's own ``user.id`` is filtered FIRST inside ``write()`` — load-bearing
|
||||||
|
echo prevention so a future whitelist expansion (Bianca, etc.) never lets
|
||||||
|
the bot transcribe itself.
|
||||||
|
|
||||||
|
See plan: ``src/voice/pipeline.py`` (Pas 5), Engineering decisions #4
|
||||||
|
(VAD 100ms batched), #5 (cleanup centralizat), #7 (bot.user.id explicit
|
||||||
|
guard).
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from src.voice._discord_voice_adapter import AudioSink, VoiceData
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Discord delivers 48kHz s16le stereo PCM, 20ms per packet (3840 bytes).
|
||||||
|
SAMPLE_RATE_DISCORD = 48000
|
||||||
|
SAMPLE_RATE_WHISPER = 16000
|
||||||
|
PACKET_MS = 20
|
||||||
|
PACKET_BYTES = 3840 # 48000 Hz * 0.020 s * 2 channels * 2 bytes
|
||||||
|
VAD_WINDOW_MS = 100 # batch 5 * 20ms packets per VAD inference (Decision #4)
|
||||||
|
VAD_WINDOW_BYTES = PACKET_BYTES * (VAD_WINDOW_MS // PACKET_MS)
|
||||||
|
VAD_THRESHOLD = 0.5
|
||||||
|
SILENCE_FLUSH_MS = 800
|
||||||
|
NO_SPEECH_DROP_THRESHOLD = 0.6
|
||||||
|
|
||||||
|
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||||
|
LOGS_DIR = PROJECT_ROOT / "logs"
|
||||||
|
VOICE_METRICS_PATH = LOGS_DIR / "voice_metrics.jsonl"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- Lazy model singletons ----------
|
||||||
|
|
||||||
|
_whisper_model: Any = None
|
||||||
|
_whisper_lock = threading.Lock()
|
||||||
|
_silero_model: Any = None
|
||||||
|
_silero_get_timestamps: Any = None
|
||||||
|
_silero_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_whisper_model() -> Any:
    """Return the shared faster-whisper model, creating it on first use.

    The model is ``small`` / int8 on CPU with the spike-validated
    ``cpu_threads=4`` (see ``tasks/voice-bench-results.md``) and
    ``local_files_only=True``. Creation happens under ``_whisper_lock`` and
    the global is re-checked inside the lock, so the constructor runs once.
    """
    global _whisper_model
    if _whisper_model is None:
        with _whisper_lock:
            # Another thread may have finished loading while we waited.
            if _whisper_model is None:
                from faster_whisper import WhisperModel
                _whisper_model = WhisperModel(
                    "small",
                    device="cpu",
                    compute_type="int8",
                    cpu_threads=4,
                    local_files_only=True,
                )
    return _whisper_model
|
||||||
|
|
||||||
|
|
||||||
|
def _get_silero_vad():
    """Return ``(model, get_speech_timestamps)`` for silero-vad, loading once.

    Loading happens under ``_silero_lock`` and the model global is re-checked
    inside the lock, so ``load_silero_vad()`` runs at most once.
    """
    global _silero_model, _silero_get_timestamps
    if _silero_model is None:
        with _silero_lock:
            # Another thread may have finished loading while we waited.
            if _silero_model is None:
                from silero_vad import get_speech_timestamps, load_silero_vad
                _silero_model = load_silero_vad()
                _silero_get_timestamps = get_speech_timestamps
    return _silero_model, _silero_get_timestamps
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- Audio helpers ----------
|
||||||
|
|
||||||
|
def _pcm48_stereo_to_16_mono(pcm: bytes) -> np.ndarray:
|
||||||
|
"""Discord 48kHz s16le stereo bytes -> 16kHz mono float32 in [-1, 1].
|
||||||
|
|
||||||
|
Cheap downsample: average the two channels, then average every 3
|
||||||
|
samples (48k / 3 = 16k). faster-whisper + silero-vad accept the
|
||||||
|
resulting ``np.float32`` array directly.
|
||||||
|
"""
|
||||||
|
if not pcm:
|
||||||
|
return np.zeros(0, dtype=np.float32)
|
||||||
|
samples = np.frombuffer(pcm, dtype=np.int16)
|
||||||
|
if samples.size % 2 != 0:
|
||||||
|
samples = samples[:-1]
|
||||||
|
stereo = samples.reshape(-1, 2)
|
||||||
|
mono = stereo.mean(axis=1).astype(np.float32) / 32768.0
|
||||||
|
if mono.size == 0:
|
||||||
|
return mono
|
||||||
|
trim = (mono.size // 3) * 3
|
||||||
|
if trim == 0:
|
||||||
|
return np.zeros(0, dtype=np.float32)
|
||||||
|
mono = mono[:trim].reshape(-1, 3).mean(axis=1)
|
||||||
|
return mono.astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- VoiceSession ----------
|
||||||
|
|
||||||
|
class VoiceSession:
    """Per-voice-call state with a single idempotent ``cleanup()``.

    Holds the voice client, TTS queue, optional JSONL transcript handle,
    speaker whitelist and the bot reference for one voice call. Used as a
    context manager: ``__enter__`` takes the session lock and opens the
    transcript file when recording is enabled; ``__exit__`` runs
    ``cleanup("exit")`` and never suppresses exceptions.
    """

    def __init__(
        self,
        *,
        channel_id: int,
        guild_id: int,
        text_channel: Any,
        voice_client: Any,
        bot: Any,
        ttsq: Any,
        whitelist: Optional[set] = None,
        record_enabled: bool = False,
        mirror_enabled: bool = True,
        transcripts_jsonl_path: Optional[Path] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        router_route_message: Optional[Callable] = None,
    ):
        """Store per-call collaborators and settings.

        ``router_route_message`` defaults to ``src.router.route_message``,
        imported lazily here. No file is opened and no lock is taken until
        ``__enter__``.
        """
        self.channel_id = int(channel_id)
        self.guild_id = int(guild_id)
        self.text_channel = text_channel
        self.voice_client = voice_client
        self.bot = bot
        self.ttsq = ttsq
        self.whitelist: set = set(whitelist or set())
        self.record_enabled = bool(record_enabled)
        self.mirror_enabled = bool(mirror_enabled)
        self.transcripts_jsonl_path = transcripts_jsonl_path
        self.loop = loop
        # Injection seam so tests can replace router.route_message without
        # mocking the whole module.
        if router_route_message is None:
            from src.router import route_message as _rm
            self._route_message = _rm
        else:
            self._route_message = router_route_message

        self.last_activity_ts = time.monotonic()
        self._jsonl_fh = None
        self._lock = threading.Lock()
        self._cleaned_up = False
        self._lock_owner_thread: Optional[int] = None

    # ----- context manager -----

    def __enter__(self) -> "VoiceSession":
        """Take the session lock and open the transcript file if recording.

        A failure to open the transcript is logged and recording is skipped;
        it does not abort the session.
        """
        self._lock.acquire()
        self._lock_owner_thread = threading.get_ident()
        if self.record_enabled and self.transcripts_jsonl_path is not None:
            try:
                self.transcripts_jsonl_path.parent.mkdir(
                    parents=True, exist_ok=True,
                )
                self._jsonl_fh = open(
                    self.transcripts_jsonl_path, "a",
                    buffering=1, encoding="utf-8",
                )
            except OSError as e:
                log.warning("voice transcript open failed: %s", e)
                self._jsonl_fh = None
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        """Run ``cleanup("exit")``; returns False so exceptions propagate."""
        self.cleanup("exit")
        return False  # never suppress exceptions

    # ----- cleanup (centralized, idempotent) -----

    def cleanup(self, reason: str) -> None:
        """Single drain path for ALL 5 exit scenarios. Safe to call twice.

        Steps, each best-effort (failures are logged, not raised): close or
        discard the transcript, clear the bot presence, clean up the voice
        client, stop the TTS queue, release the session lock, then append a
        ``cleanup`` metric carrying ``reason``.
        """
        if self._cleaned_up:
            return
        self._cleaned_up = True

        # 1. Flush or discard JSONL transcript.
        if self._jsonl_fh is not None:
            try:
                self._jsonl_fh.flush()
                self._jsonl_fh.close()
            except Exception as e:  # noqa: BLE001
                log.warning("voice transcript flush failed: %s", e)
            self._jsonl_fh = None
        if (not self.record_enabled
                and self.transcripts_jsonl_path is not None
                and self.transcripts_jsonl_path.exists()):
            try:
                self.transcripts_jsonl_path.unlink()
            except OSError:
                pass

        # 2. Restore bot presence (clear Listening activity).
        if self.bot is not None:
            try:
                change = getattr(self.bot, "change_presence", None)
                if callable(change):
                    coro = change(activity=None)
                    if asyncio.iscoroutine(coro):
                        if self.loop is not None and self.loop.is_running():
                            asyncio.run_coroutine_threadsafe(coro, self.loop)
                        else:
                            # Best-effort: close the coroutine so Python
                            # doesn't emit "coroutine was never awaited".
                            coro.close()
            except Exception as e:  # noqa: BLE001
                log.warning("voice presence restore failed: %s", e)

        # 3. Tear down the voice client.
        if self.voice_client is not None:
            try:
                self.voice_client.cleanup()
            except Exception as e:  # noqa: BLE001
                log.warning("voice_client.cleanup failed: %s", e)

        # 4. Stop the TTS queue worker.
        if self.ttsq is not None:
            try:
                self.ttsq.stop()
            except Exception as e:  # noqa: BLE001
                log.warning("ttsq.stop failed: %s", e)

        # 5. Release the session lock (held since __enter__).
        try:
            if self._lock.locked():
                self._lock.release()
        except RuntimeError:
            # threading.Lock may be released from any thread; release()
            # raises only when the lock is not held, i.e. someone else
            # released it between the locked() check and here. Either way
            # it is free for the next caller.
            pass

        self._log_metric({"event": "cleanup", "reason": reason})

    # ----- segment completion (scheduled from sink) -----

    async def on_segment_done(
        self,
        speaker_id: int,
        text: str,
        no_speech_prob: float,
    ) -> None:
        """Mirror, persist, route to Claude, drive TTS via streaming callback.

        Does nothing once the session has been cleaned up. Every step is
        best-effort: failures are logged and the remaining steps still run.
        A missing TTS queue (``ttsq is None``) is tolerated the same way
        ``cleanup()`` tolerates it.
        """
        if self._cleaned_up:
            return
        self.last_activity_ts = time.monotonic()
        speaker_name = self._resolve_speaker_name(speaker_id)

        # Drop any TTS frames from the previous turn so a new utterance cuts off
        # stale Echo speech (barge-in) and never mixes with the new response.
        if self.ttsq is not None:
            try:
                self.ttsq.clear()
            except Exception as e:  # noqa: BLE001
                log.warning("ttsq.clear failed: %s", e)

        # 1. Mirror to text channel (one Unicode 🎤 — exception per plan).
        if self.mirror_enabled and self.text_channel is not None:
            try:
                send = getattr(self.text_channel, "send", None)
                if callable(send):
                    coro = send(f"\U0001f3a4 {speaker_name}: \"{text}\"")
                    if asyncio.iscoroutine(coro):
                        await coro
            except Exception as e:  # noqa: BLE001
                log.warning("voice mirror send failed: %s", e)

        # 2. Append to JSONL transcript buffer if recording.
        if self._jsonl_fh is not None:
            try:
                self._jsonl_fh.write(
                    json.dumps({
                        "ts": time.time(),
                        "speaker_id": speaker_id,
                        "speaker": speaker_name,
                        "text": text,
                        "no_speech_prob": no_speech_prob,
                    }, ensure_ascii=False) + "\n"
                )
            except Exception as e:  # noqa: BLE001
                log.warning("voice transcript write failed: %s", e)

        # One-element list so the nested callback can mutate the counter.
        block_count = [0]

        def voice_stream_callback(block: str) -> None:
            """Called once per Claude streamed text block — pushes to TTS."""
            block_count[0] += 1
            log.info("voice stream block #%d (%d chars): %r",
                     block_count[0], len(block or ""), (block or "")[:80])
            if self.ttsq is None:
                return
            try:
                self.ttsq.push_text(block)
            except Exception as e:  # noqa: BLE001
                log.warning("ttsq.push_text failed: %s", e)

        # Dispatch to Claude. send_message is sync subprocess, run on
        # a worker thread so the loop stays responsive for mirror/TTS.
        try:
            await asyncio.to_thread(
                self._route_message,
                str(self.channel_id),
                str(speaker_id),
                text,
                None,                    # model
                voice_stream_callback,   # on_text
                "discord-voice",         # adapter_name
            )
        except Exception as e:  # noqa: BLE001
            log.error("route_message voice path failed: %s", e)

    # ----- helpers -----

    def _resolve_speaker_name(self, speaker_id: int) -> str:
        """Best-effort display name lookup via the bot user cache.

        Falls back to ``str(speaker_id)`` when there is no bot, no cached
        user, no usable name, or the lookup raises.
        """
        try:
            if self.bot is not None and hasattr(self.bot, "get_user"):
                user = self.bot.get_user(speaker_id)
                if user is not None:
                    name = getattr(user, "display_name", None) or getattr(
                        user, "name", None,
                    )
                    if name:
                        return str(name)
        except Exception:  # noqa: BLE001
            pass
        return str(speaker_id)

    def _log_metric(self, payload: dict) -> None:
        """Append a structured event to ``logs/voice_metrics.jsonl``.

        ``ts`` and ``channel_id`` are added to ``payload``. ``OSError`` is
        swallowed so metrics can never break the call path.
        """
        event = {"ts": time.time(), "channel_id": self.channel_id, **payload}
        try:
            LOGS_DIR.mkdir(parents=True, exist_ok=True)
            with open(VOICE_METRICS_PATH, "a", buffering=1, encoding="utf-8") as f:
                f.write(json.dumps(event, ensure_ascii=False) + "\n")
        except OSError:
            pass
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- EchoVoiceSink ----------
|
||||||
|
|
||||||
|
class EchoVoiceSink(AudioSink):
|
||||||
|
"""PCM-in sink: per-user 20ms buffer -> 100ms VAD windows -> 800ms
|
||||||
|
silence triggers Whisper STT -> schedules ``on_segment_done`` on the
|
||||||
|
session loop.
|
||||||
|
|
||||||
|
Lives in the voice_recv reader thread; uses ``threading`` primitives
|
||||||
|
only (no asyncio in the hot path).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, session: VoiceSession, bot_user_id: int):
|
||||||
|
super().__init__()
|
||||||
|
self.session = session
|
||||||
|
self.bot_user_id = int(bot_user_id) if bot_user_id is not None else 0
|
||||||
|
self.whitelist: set = set(session.whitelist or set())
|
||||||
|
self._user_buffers: dict[int, bytearray] = {}
|
||||||
|
self._packet_accum: dict[int, bytearray] = {}
|
||||||
|
self._last_speech_ts: dict[int, float] = {}
|
||||||
|
self._has_speech: dict[int, bool] = {}
|
||||||
|
self._sink_lock = threading.Lock()
|
||||||
|
# Diagnostics: log once-per-user when packets first arrive and when
|
||||||
|
# VAD first detects speech. Cheap, but tells us exactly where the
|
||||||
|
# chain breaks when "I spoke but Echo heard nothing" happens.
|
||||||
|
self._first_packet_logged: set[int] = set()
|
||||||
|
self._first_speech_logged: set[int] = set()
|
||||||
|
# Background poller that triggers the silence flush even when Discord
|
||||||
|
# DTX stops delivering RTP packets after the user stops speaking. Without
|
||||||
|
# this, sink.write would stop firing and STT would never run on the
|
||||||
|
# final utterance.
|
||||||
|
self._poller_stop = threading.Event()
|
||||||
|
self._poller_thread = threading.Thread(
|
||||||
|
target=self._silence_flush_poller,
|
||||||
|
name="echo-voice-flush-poller",
|
||||||
|
daemon=True,
|
||||||
|
)
|
||||||
|
self._poller_thread.start()
|
||||||
|
|
||||||
|
def wants_opus(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def cleanup(self) -> None:
|
||||||
|
self._poller_stop.set()
|
||||||
|
with self._sink_lock:
|
||||||
|
self._user_buffers.clear()
|
||||||
|
self._packet_accum.clear()
|
||||||
|
self._last_speech_ts.clear()
|
||||||
|
self._has_speech.clear()
|
||||||
|
|
||||||
|
def write(self, user, voice_data: VoiceData) -> None:
|
||||||
|
# ---- FIRST GUARD (LOAD-BEARING): bot's own voice ---------------
|
||||||
|
if user is None:
|
||||||
|
return
|
||||||
|
uid = int(getattr(user, "id", 0) or 0)
|
||||||
|
if uid == 0:
|
||||||
|
return
|
||||||
|
if uid == self.bot_user_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
# ---- SECOND GUARD: whitelist filter ----------------------------
|
||||||
|
if self.whitelist and uid not in self.whitelist:
|
||||||
|
return
|
||||||
|
|
||||||
|
pcm = getattr(voice_data, "pcm", None)
|
||||||
|
if not pcm:
|
||||||
|
return
|
||||||
|
|
||||||
|
if uid not in self._first_packet_logged:
|
||||||
|
self._first_packet_logged.add(uid)
|
||||||
|
log.info("voice sink: first PCM packet from user %s (%d bytes)", uid, len(pcm))
|
||||||
|
|
||||||
|
window_pcm: Optional[bytes] = None
|
||||||
|
pcm_for_stt: Optional[bytes] = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
with self._sink_lock:
|
||||||
|
buf = self._user_buffers.setdefault(uid, bytearray())
|
||||||
|
accum = self._packet_accum.setdefault(uid, bytearray())
|
||||||
|
buf.extend(pcm)
|
||||||
|
accum.extend(pcm)
|
||||||
|
if len(accum) >= VAD_WINDOW_BYTES:
|
||||||
|
window_pcm = bytes(accum[:VAD_WINDOW_BYTES])
|
||||||
|
del accum[:VAD_WINDOW_BYTES]
|
||||||
|
|
||||||
|
if window_pcm is not None:
|
||||||
|
if self._vad_detects_speech(window_pcm):
|
||||||
|
if uid not in self._first_speech_logged:
|
||||||
|
self._first_speech_logged.add(uid)
|
||||||
|
log.info("voice sink: VAD detected speech from user %s", uid)
|
||||||
|
with self._sink_lock:
|
||||||
|
self._last_speech_ts[uid] = time.monotonic()
|
||||||
|
self._has_speech[uid] = True
|
||||||
|
|
||||||
|
pcm_for_stt = self._take_flushable_pcm(uid)
|
||||||
|
if pcm_for_stt:
|
||||||
|
self._flush_to_stt(uid, pcm_for_stt)
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
log.warning("EchoVoiceSink.write failed: %s", e)
|
||||||
|
|
||||||
|
def _take_flushable_pcm(self, uid: int) -> Optional[bytes]:
|
||||||
|
"""If user `uid` has buffered speech that's been silent ≥SILENCE_FLUSH_MS,
|
||||||
|
consume the buffer and return it. Otherwise return None."""
|
||||||
|
with self._sink_lock:
|
||||||
|
if not self._has_speech.get(uid):
|
||||||
|
return None
|
||||||
|
last = self._last_speech_ts.get(uid, 0.0)
|
||||||
|
silence_ms = (time.monotonic() - last) * 1000.0
|
||||||
|
if silence_ms < SILENCE_FLUSH_MS:
|
||||||
|
return None
|
||||||
|
pcm = bytes(self._user_buffers.get(uid, b""))
|
||||||
|
self._user_buffers[uid] = bytearray()
|
||||||
|
self._packet_accum[uid] = bytearray()
|
||||||
|
self._has_speech[uid] = False
|
||||||
|
return pcm if pcm else None
|
||||||
|
|
||||||
|
def _silence_flush_poller(self) -> None:
|
||||||
|
"""Background tick: Discord DTX stops sending RTP packets when the user
|
||||||
|
goes silent, so the inline flush check in `write()` never fires for the
|
||||||
|
last utterance. Poll every 200ms so the trailing audio actually reaches
|
||||||
|
Whisper."""
|
||||||
|
while not self._poller_stop.wait(0.2):
|
||||||
|
try:
|
||||||
|
with self._sink_lock:
|
||||||
|
pending = [uid for uid, has in self._has_speech.items() if has]
|
||||||
|
for uid in pending:
|
||||||
|
pcm = self._take_flushable_pcm(uid)
|
||||||
|
if pcm:
|
||||||
|
self._flush_to_stt(uid, pcm)
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
log.warning("silence flush poller iter failed: %s", e)
|
||||||
|
|
||||||
|
# ----- VAD -----
|
||||||
|
|
||||||
|
def _vad_detects_speech(self, pcm48_stereo: bytes) -> bool:
    """Return True when one VAD window contains speech.

    The window is converted with ``_pcm48_stereo_to_16_mono`` and fed to
    silero-vad in 512-sample chunks (silero-vad v5+ requires exactly 512
    samples per call at 16kHz); the first chunk whose probability reaches
    ``VAD_THRESHOLD`` ends the scan. Trailing samples that do not fill a
    whole chunk are not evaluated. When ``torch`` cannot be imported, a
    plain RMS energy check (> 0.02) is used instead.

    Never raises: any failure yields False. The first failure is logged at
    warning level so a permanently broken VAD is visible in normal logs;
    later ones go to debug so a failure on every window does not flood them.
    """
    try:
        mono16 = _pcm48_stereo_to_16_mono(pcm48_stereo)
        if mono16.size == 0:
            return False
        try:
            import torch
        except ImportError:
            rms = float(np.sqrt(np.mean(mono16.astype(np.float64) ** 2)))
            return rms > 0.02
        model, _ = _get_silero_vad()
        chunk = 512  # silero-vad v5+ hard requirement at 16kHz
        with torch.no_grad():
            for start in range(0, mono16.size - chunk + 1, chunk):
                seg = mono16[start:start + chunk]
                prob = float(model(torch.from_numpy(seg), SAMPLE_RATE_WHISPER).item())
                if prob >= VAD_THRESHOLD:
                    return True
        return False
    except Exception as e:  # noqa: BLE001
        if getattr(self, "_vad_failure_warned", False):
            log.debug("VAD inference failed: %s", e)
        else:
            self._vad_failure_warned = True
            log.warning(
                "VAD inference failed (further failures logged at debug): %s", e
            )
        return False
|
||||||
|
|
||||||
|
# ----- STT flush -----
|
||||||
|
|
||||||
|
def _flush_to_stt(self, user_id: int, pcm48_stereo: bytes) -> None:
    """Transcribe one finished utterance and dispatch the resulting text.

    The PCM is converted with ``_pcm48_stereo_to_16_mono`` and transcribed
    as Romanian by the Whisper model. Segments whose ``no_speech_prob``
    exceeds ``NO_SPEECH_DROP_THRESHOLD`` are skipped; the rest are joined
    with spaces. Non-empty text is passed to ``_schedule_segment_done``
    along with the highest ``no_speech_prob`` seen across all segments,
    skipped ones included. Failures are logged, never raised.
    """
    try:
        audio = _pcm48_stereo_to_16_mono(pcm48_stereo)
        if audio.size == 0:
            return

        prompt = (
            "Echo Core, asistent personal AI românesc al lui Marius. "
            "Conversație colocvială în română."
        )
        segments, _info = _get_whisper_model().transcribe(
            audio,
            language="ro",
            beam_size=5,
            initial_prompt=prompt,
            condition_on_previous_text=False,
        )

        kept: list[str] = []
        highest_no_speech = 0.0
        for segment in segments:
            prob = float(getattr(segment, "no_speech_prob", 0.0) or 0.0)
            highest_no_speech = max(highest_no_speech, prob)
            if prob > NO_SPEECH_DROP_THRESHOLD:
                continue
            piece = (getattr(segment, "text", "") or "").strip()
            if piece:
                kept.append(piece)

        transcript = " ".join(kept).strip()
        if transcript:
            self._schedule_segment_done(user_id, transcript, highest_no_speech)
    except Exception as e:  # noqa: BLE001
        log.warning("Whisper transcribe failed: %s", e)
|
||||||
|
|
||||||
|
def _schedule_segment_done(
    self, user_id: int, text: str, no_speech_prob: float,
) -> None:
    """Hand a transcribed segment to ``session.on_segment_done`` on its loop.

    The coroutine is submitted to ``self.session.loop`` with
    ``asyncio.run_coroutine_threadsafe``. The segment is dropped, with a
    debug log, when the loop is missing or not running.

    Two failure paths are made visible instead of being lost:

    * if scheduling itself fails, the coroutine that was already created is
      closed so it does not trigger a "never awaited" warning;
    * if the handler raises later on the loop, a done-callback logs the
      exception (the returned future is otherwise never inspected).
    """
    loop = self.session.loop
    if loop is None or not loop.is_running():
        log.debug("voice session loop missing — dropping segment")
        return

    coro = None
    try:
        coro = self.session.on_segment_done(user_id, text, no_speech_prob)
        future = asyncio.run_coroutine_threadsafe(coro, loop)
    except Exception as e:  # noqa: BLE001
        close = getattr(coro, "close", None)
        if close is not None:
            close()
        log.warning("voice segment dispatch failed: %s", e)
        return

    def _report_failure(done) -> None:
        # exception() raises on a cancelled future, so check that first.
        if done.cancelled():
            return
        exc = done.exception()
        if exc is not None:
            log.warning("voice segment handler failed: %s", exc)

    future.add_done_callback(_report_failure)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"VoiceSession",
|
||||||
|
"EchoVoiceSink",
|
||||||
|
"SILENCE_FLUSH_MS",
|
||||||
|
"VAD_THRESHOLD",
|
||||||
|
"VAD_WINDOW_MS",
|
||||||
|
"NO_SPEECH_DROP_THRESHOLD",
|
||||||
|
]
|
||||||
333
src/voice/tts_stream.py
Normal file
333
src/voice/tts_stream.py
Normal file
@@ -0,0 +1,333 @@
|
|||||||
|
"""Streaming TTS with clause-level chunking for Discord voice mode.
|
||||||
|
|
||||||
|
A worker thread consumes text -> produces 20ms PCM frames on a queue.Queue.
|
||||||
|
``EchoStreamingAudioSource`` pulls frames into Discord's audio thread so a
|
||||||
|
single ``voice_client.play()`` call lasts the whole turn (eliminates the
|
||||||
|
RTP gap between successive ``play()`` calls and the race with barge-in
|
||||||
|
``stop()``). See plan: src/voice/tts_stream.py (Pas 6 / Lane TTS),
|
||||||
|
Engineering decisions #6, #8, #15.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import queue
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
import threading
|
||||||
|
import wave
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Iterator, List, Optional
|
||||||
|
|
||||||
|
import discord
|
||||||
|
|
||||||
|
from src.voice.normalize import normalize_for_tts
|
||||||
|
from tools.tts import synthesize
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Discord wants 20ms of 16-bit 48kHz stereo PCM per frame.
|
||||||
|
# 48000 Hz * 0.020 s * 2 channels * 2 bytes = 3840 bytes.
|
||||||
|
FRAME_BYTES = 3840
|
||||||
|
TARGET_SAMPLE_RATE = 48000
|
||||||
|
TARGET_CHANNELS = 2
|
||||||
|
TARGET_SAMPLE_WIDTH = 2
|
||||||
|
|
||||||
|
# Sentinel pushed onto the text queue to ask the worker to exit cleanly.
|
||||||
|
_POISON = object()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- Clause segmentation ----------
|
||||||
|
|
||||||
|
# Split at Romanian sentence punctuation followed by whitespace. The
|
||||||
|
# trailing whitespace requirement protects mid-number (1.000), mid-decimal
|
||||||
|
# (12.5), and mid-abbreviation (M.D.) tokens, since none of those have a
|
||||||
|
# space right after the inner punctuation.
|
||||||
|
_CLAUSE_SPLIT = re.compile(r'(?<=[,;:.!?])\s+')
|
||||||
|
|
||||||
|
|
||||||
|
def clause_segments(text: str, min_words: int = 8) -> Iterator[str]:
    """Break ``text`` into clause-sized chunks suitable for streaming TTS.

    Boundaries are the ``, ; : . ! ?`` characters when followed by
    whitespace, which leaves numbers, decimals and abbreviations untouched.
    Consecutive clauses are merged until the chunk holds at least
    ``min_words`` words, then the chunk is yielded. Whatever is left at the
    end is yielded too, even if short, so the tail of a response is never
    lost. ``None``, empty and whitespace-only input yield nothing.
    """
    stripped = text.strip() if text is not None else ''
    if not stripped:
        return

    pending = []       # clauses collected for the current chunk
    pending_words = 0  # running word count of ``pending``
    for raw in _CLAUSE_SPLIT.split(stripped):
        clause = raw.strip() if raw else ''
        if not clause:
            continue
        pending.append(clause)
        pending_words += len(clause.split())
        if pending_words >= min_words:
            yield ' '.join(pending)
            pending = []
            pending_words = 0

    if pending:
        yield ' '.join(pending)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- WAV -> PCM frame conversion ----------
|
||||||
|
|
||||||
|
def _ffmpeg_resample(wav_bytes: bytes, timeout: float = 30.0) -> bytes:
    """Convert any WAV payload to raw 48kHz stereo s16le PCM via ffmpeg.

    The audio is piped through ffmpeg's stdin/stdout, so no temporary file
    is written for the conversion.

    Args:
        wav_bytes: The audio file contents to convert.
        timeout: Seconds to wait for ffmpeg before giving up. Bounds the
            call so a stuck ffmpeg cannot block the caller indefinitely.

    Returns:
        Raw PCM at ``TARGET_SAMPLE_RATE`` / ``TARGET_CHANNELS``, 16-bit
        little-endian.

    Raises:
        RuntimeError: ffmpeg exited non-zero, could not be started (for
            example it is not installed), or exceeded ``timeout``. All
            failures use this one type so callers that handle
            ``RuntimeError`` also survive a missing or stuck binary.
    """
    cmd = [
        'ffmpeg', '-hide_banner', '-loglevel', 'error',
        '-i', 'pipe:0',
        '-f', 's16le',
        '-ar', str(TARGET_SAMPLE_RATE),
        '-ac', str(TARGET_CHANNELS),
        '-acodec', 'pcm_s16le',
        'pipe:1',
    ]
    try:
        proc = subprocess.run(
            cmd,
            input=wav_bytes,
            capture_output=True,
            check=False,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired as e:
        raise RuntimeError(f"ffmpeg resample timed out after {timeout:.0f}s") from e
    except OSError as e:
        raise RuntimeError(f"ffmpeg resample could not start: {e}") from e
    if proc.returncode != 0:
        err = proc.stderr.decode('utf-8', errors='replace')[:200]
        raise RuntimeError(f"ffmpeg resample failed (rc={proc.returncode}): {err}")
    return proc.stdout
|
||||||
|
|
||||||
|
|
||||||
|
def _is_target_format(wav_bytes: bytes) -> bool:
|
||||||
|
"""Quick check whether the WAV already matches Discord's PCM format."""
|
||||||
|
try:
|
||||||
|
with wave.open(io.BytesIO(wav_bytes), 'rb') as w:
|
||||||
|
return (
|
||||||
|
w.getframerate() == TARGET_SAMPLE_RATE
|
||||||
|
and w.getnchannels() == TARGET_CHANNELS
|
||||||
|
and w.getsampwidth() == TARGET_SAMPLE_WIDTH
|
||||||
|
and w.getcomptype() == 'NONE'
|
||||||
|
)
|
||||||
|
except (wave.Error, EOFError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_pcm_native(wav_bytes: bytes) -> bytes:
|
||||||
|
"""Strip the WAV header and return raw PCM (target format assumed)."""
|
||||||
|
with wave.open(io.BytesIO(wav_bytes), 'rb') as w:
|
||||||
|
return w.readframes(w.getnframes())
|
||||||
|
|
||||||
|
|
||||||
|
def wav_to_pcm_20ms_frames(wav_bytes: bytes) -> List[bytes]:
    """Turn a WAV blob into a list of ``FRAME_BYTES``-sized PCM frames.

    Audio already in the target format has its header stripped; anything
    else goes through ffmpeg. The last frame is padded with zero bytes up
    to ``FRAME_BYTES`` so every frame is the same size. Empty input, or
    input that yields no PCM, gives an empty list.
    """
    if not wav_bytes:
        return []

    if _is_target_format(wav_bytes):
        pcm = _extract_pcm_native(wav_bytes)
    else:
        pcm = _ffmpeg_resample(wav_bytes)
    if not pcm:
        return []

    frames: List[bytes] = [
        pcm[start:start + FRAME_BYTES]
        for start in range(0, len(pcm), FRAME_BYTES)
    ]
    # Only the final slice can come up short.
    shortfall = FRAME_BYTES - len(frames[-1])
    if shortfall > 0:
        frames[-1] = frames[-1] + b'\x00' * shortfall
    return frames
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- TTS worker queue ----------
|
||||||
|
|
||||||
|
class TTSQueue:
    """Worker thread: text in -> 20ms PCM frames out.

    Usage::

        ttsq = TTSQueue(voice_id="M2", lang="ro")
        ttsq.start()
        ttsq.push_text("salut Marius, ce mai faci?")
        voice_client.play(EchoStreamingAudioSource(ttsq))
        # ... barge-in detected:
        ttsq.clear()
        # ... session over:
        ttsq.stop()

    Barge-in safety: every queued clause carries the generation number that
    was current when it was pushed. ``clear()`` and ``stop()`` bump the
    generation, so a clause that was already being synthesized when the
    queue was cleared is discarded instead of being played late.
    """

    def __init__(self, voice_id: str = "M2", lang: str = "ro"):
        self.voice_id = voice_id
        self.lang = lang
        # Holds ``(generation, clause)`` tuples plus the ``_POISON`` sentinel.
        self._text_queue: queue.Queue = queue.Queue()
        # Holds ready-to-play PCM frames.
        self._pcm_queue: queue.Queue = queue.Queue()
        self._worker_thread: Optional[threading.Thread] = None
        self._stop_event = threading.Event()
        # Incremented on every clear()/stop(); guarded by _generation_lock.
        self._generation = 0
        self._generation_lock = threading.Lock()

    # --- lifecycle ---

    def start(self) -> None:
        """Start the worker thread; does nothing if it is already running."""
        if self._worker_thread is not None and self._worker_thread.is_alive():
            return
        self._stop_event.clear()
        self._worker_thread = threading.Thread(
            target=self._worker_loop,
            name=f"tts-worker-{self.voice_id}",
            daemon=True,
        )
        self._worker_thread.start()

    def stop(self) -> None:
        """Signal the worker to exit, join (timeout 5s), drop pending work."""
        self._stop_event.set()
        # Wake the worker if it's blocked on get(timeout=...).
        self._text_queue.put(_POISON)
        thread = self._worker_thread
        if thread is not None:
            thread.join(timeout=5.0)
        self._worker_thread = None
        # Also invalidates output of a worker that outlived the join timeout.
        self._invalidate()

    # --- producer side ---

    def push_text(self, text: str) -> None:
        """Normalize, segment into clauses, enqueue each clause for synthesis."""
        if not text:
            return
        cleaned = normalize_for_tts(text)
        with self._generation_lock:
            generation = self._generation
        n = 0
        for clause in clause_segments(cleaned):
            clause = clause.strip()
            if clause:
                self._text_queue.put((generation, clause))
                n += 1
        log.info("ttsq.push_text: input %d chars → %d clauses queued", len(text), n)

    def clear(self) -> None:
        """Drop everything pending (used for barge-in).

        Work the worker is synthesizing right now is dropped as well: its
        generation is now stale, so its frames are never enqueued.
        """
        self._invalidate()

    def is_empty(self) -> bool:
        """Return True when both the text queue and the PCM queue are empty."""
        return self._text_queue.empty() and self._pcm_queue.empty()

    # --- consumer side (called by EchoStreamingAudioSource) ---

    def get_frame_nowait(self) -> Optional[bytes]:
        """Return the next PCM frame if available, else None — no blocking.

        Blocking inside the player's read() loop wrecks Discord's 20ms cadence
        and the client interprets the stream as stuttering / out-of-order.
        """
        try:
            return self._pcm_queue.get_nowait()
        except queue.Empty:
            return None

    # --- internals ---

    @staticmethod
    def _drain(q: queue.Queue) -> None:
        """Remove every item currently in ``q`` without blocking."""
        while True:
            try:
                q.get_nowait()
            except queue.Empty:
                return

    def _invalidate(self) -> None:
        """Bump the generation and empty both queues in one locked step.

        The worker enqueues frames under the same lock, so no frame from an
        older generation can land in the PCM queue after this returns.
        """
        with self._generation_lock:
            self._generation += 1
            self._drain(self._text_queue)
            self._drain(self._pcm_queue)

    def _synthesize_frames(self, clause: str) -> List[bytes]:
        """Synthesize one clause and return its PCM frames ([] on any failure).

        Every failure is logged and reported as an empty list so that the
        worker thread keeps running.
        """
        preview = clause[:60]
        try:
            result = synthesize(clause, voice=self.voice_id, lang=self.lang)
        except Exception as e:  # noqa: BLE001
            log.warning("TTS synth raised for %r: %s", preview, e)
            return []
        if not result.get('ok'):
            log.warning("TTS synth not ok for %r: %s", preview, result.get('error'))
            return []
        path = result.get('path')
        if not path:
            log.warning("TTS synth ok but no path for %r", preview)
            return []

        wav_bytes = b''
        try:
            wav_bytes = Path(path).read_bytes()
        except OSError as e:
            log.warning("TTS WAV read failed for %r: %s", preview, e)
        finally:
            # The synthesized file is removed whether or not the read worked.
            try:
                Path(path).unlink(missing_ok=True)
            except OSError:
                pass
        if not wav_bytes:
            return []

        try:
            frames = wav_to_pcm_20ms_frames(wav_bytes)
        except Exception as e:  # noqa: BLE001
            # Broad on purpose: a missing ffmpeg (OSError) or a malformed
            # WAV must not kill the worker thread.
            log.warning("TTS WAV-to-PCM failed for %r: %s", preview, e)
            return []
        if not frames:
            log.warning("TTS WAV-to-PCM produced 0 frames for %r", preview)
        return frames

    def _worker_loop(self) -> None:
        """Consume clauses until stopped, pushing their frames to the PCM queue."""
        while not self._stop_event.is_set():
            try:
                item = self._text_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            if item is _POISON:
                break
            if isinstance(item, str):
                # Bare string: treat it as belonging to the current generation.
                generation, clause = self._generation, item
            elif isinstance(item, tuple) and len(item) == 2:
                generation, clause = item
            else:
                continue
            if not isinstance(clause, str):
                continue
            if generation != self._generation:
                # Cleared before synthesis even started.
                continue

            frames = self._synthesize_frames(clause)
            if not frames:
                continue

            with self._generation_lock:
                if self._stop_event.is_set():
                    return
                if generation != self._generation:
                    # clear() ran while this clause was being synthesized.
                    log.debug("TTS dropped stale clause %r", clause[:60])
                    continue
                for frame in frames:
                    self._pcm_queue.put(frame)
            log.info("TTS pushed %d frames (%.1fs) for %r",
                     len(frames), len(frames) * 0.02, clause[:60])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- Discord audio source ----------
|
||||||
|
|
||||||
|
class EchoStreamingAudioSource(discord.AudioSource):
    """Pull PCM frames from a ``TTSQueue`` into Discord's audio thread.

    A single ``voice_client.play(EchoStreamingAudioSource(ttsq))`` call
    spans the whole session. When the TTS queue is empty, ``read()``
    returns a 20ms silence frame to keep the player alive — otherwise
    Discord would interpret an empty return as end-of-stream and stop
    the player, so real TTS frames pushed later would be silently
    discarded. The player is explicitly terminated only via
    ``cleanup()`` (called on voice session teardown).
    """

    # 20ms of s16le stereo at 48kHz silence (960 samples × 2 channels × 2 bytes).
    _SILENCE_FRAME = b'\x00' * (960 * 2 * 2)

    def __init__(self, ttsq: TTSQueue):
        self._ttsq = ttsq
        # Set by cleanup(); once True, read() returns b''.
        self._closed = False

    def read(self) -> bytes:
        """Return the next frame, silence if none is ready, b'' once closed."""
        if self._closed:
            return b''
        frame = self._ttsq.get_frame_nowait()
        if frame is None:
            return self._SILENCE_FRAME
        return frame

    def is_opus(self) -> bool:
        """Return False: frames are raw PCM, not Opus-encoded."""
        return False

    def cleanup(self) -> None:
        """Mark the source closed and drop whatever the TTS queue still holds.

        Clearing is best-effort: a failure is logged at debug level and never
        propagated.
        """
        self._closed = True
        try:
            self._ttsq.clear()
        except Exception as e:  # noqa: BLE001
            log.debug("TTS queue clear during cleanup failed: %s", e)
|
||||||
@@ -17,6 +17,13 @@ Lecții capturate din corectările lui Marius. Citește acest fișier la începu
|
|||||||
|
|
||||||
<!-- Lecțiile se adaugă mai jos, cele mai noi sus. -->
|
<!-- Lecțiile se adaugă mai jos, cele mai noi sus. -->
|
||||||
|
|
||||||
|
## Supertonic rejectează ghilimelele curly (Unicode) cu HTTP 500
|
||||||
|
**Data:** 2026-05-27
|
||||||
|
**Context:** Marius a dat o comandă audio pe Discord cu un URL, iar răspunsul lui Claude conținea `„foo"` (ghilimele românești curly). Supertonic a returnat `HTTP 500: synthesis failed: Found 1 unsupported character(s): ['„']` și răspunsul nu s-a mai auzit. Fără retry logic vizibil în UX — pur și simplu tace.
|
||||||
|
**Greșeala:** Am presupus că `normalize_for_tts` produce text deja "TTS-safe" pentru Supertonic. În realitate `strip_markdown` păstrează ghilimelele și punctuația Unicode (`„` U+201E, `”` U+201D, `—` U+2014, `…` U+2026, etc.) pe care Supertonic le refuză.
|
||||||
|
**Regula:** Înainte de orice apel HTTP la Supertonic, **sanitizează punctuația Unicode** la echivalentele ASCII (`„` `“` `”` → `"`, `‘` `’` `‚` → `'`, `–` `—` → `-`, `…` → `...`, `«` `»` → `"`). Funcția `sanitize_punctuation` din `src/voice/normalize.py` face asta și e apelată imediat după `strip_markdown` în pipeline. Dacă apar caractere noi care crapă Supertonic (ex: simboluri matematice, săgeți), adaugă-le în `_TTS_PUNCT_MAP`.
|
||||||
|
**Când se aplică:** Orice cod care trimite text la Supertonic (`tools/tts.py`, `src/voice/tts_stream.py`). Inclusiv testare manuală cu `curl` — folosește text românesc realistic (include `„foo"`, em-dash `—`, ellipsis `…`).
|
||||||
|
|
||||||
## Mai multe threads ≠ mai rapid — fitează `cpu_threads` pe physical cores, nu logical
|
## Mai multe threads ≠ mai rapid — fitează `cpu_threads` pe physical cores, nu logical
|
||||||
**Data:** 2026-05-27
|
**Data:** 2026-05-27
|
||||||
**Context:** Benchmark `tools/voice_bench.py` pentru faster-whisper `small` int8 pe i7-6700T (4 physical / 8 logical cores). Marius a urcat VM-ul de la 2 → 4 → 6 cores online, așteptând că mai multe = mai rapid.
|
**Context:** Benchmark `tools/voice_bench.py` pentru faster-whisper `small` int8 pe i7-6700T (4 physical / 8 logical cores). Marius a urcat VM-ul de la 2 → 4 → 6 cores online, așteptând că mai multe = mai rapid.
|
||||||
|
|||||||
307
tests/test_claude_session_mutex.py
Normal file
307
tests/test_claude_session_mutex.py
Normal file
@@ -0,0 +1,307 @@
|
|||||||
|
"""Regression-critical tests for per-channel mutex in src/claude_session.py.
|
||||||
|
|
||||||
|
Three scenarios from the eng-review test plan (2026-05-27):
|
||||||
|
|
||||||
|
1. Concurrent `send_message` calls on the SAME channel_id serialize —
|
||||||
|
the second waits for the first to finish before its subprocess runs.
|
||||||
|
2. Concurrent `send_message` calls on DIFFERENT channel_ids run in parallel
|
||||||
|
— independent channels never block each other.
|
||||||
|
3. Acquisition contract is documented and consistent: the lock is acquired
|
||||||
|
blocking (no acquire timeout), which means a hung subprocess on
|
||||||
|
channel X delays subsequent X messages but never X' (X != X'). This
|
||||||
|
test pins that behavior so future refactors must preserve it.
|
||||||
|
|
||||||
|
The mutex is `threading.Lock`, NOT `asyncio.Lock`, because `send_message`
|
||||||
|
is a sync function typically dispatched via `asyncio.to_thread` from
|
||||||
|
async adapters. asyncio.Lock would serialize coroutines only — not the
|
||||||
|
subprocess invocation. See plan section "Engineering decisions" #2.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src import claude_session
|
||||||
|
from src.claude_session import (
|
||||||
|
_get_session_lock,
|
||||||
|
_session_locks,
|
||||||
|
send_message,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def _clear_session_locks():
    """Empty the shared per-channel lock map before and after every test."""
    reset = _session_locks.clear
    reset()
    yield
    reset()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def temp_sessions(tmp_path, monkeypatch):
    """Point ``claude_session`` at a throwaway sessions directory.

    Creates ``sessions/active.json`` (containing ``{}``) under ``tmp_path``,
    patches ``SESSIONS_DIR`` and ``_SESSIONS_FILE`` to it, and returns the
    path of the file so real session state is never touched.
    """
    root = tmp_path / "sessions"
    root.mkdir()
    active = root / "active.json"
    active.write_text("{}")
    overrides = (("SESSIONS_DIR", root), ("_SESSIONS_FILE", active))
    for attr, value in overrides:
        monkeypatch.setattr(claude_session, attr, value)
    return active
|
||||||
|
|
||||||
|
|
||||||
|
def _slow_run_claude(sleep_seconds: float, in_critical: threading.Event,
|
||||||
|
concurrent_seen: threading.Event):
|
||||||
|
"""Build a fake `_run_claude` that signals when inside the critical section.
|
||||||
|
|
||||||
|
The fake holds the simulated subprocess for `sleep_seconds`. Any other
|
||||||
|
invocation that overlaps will set `concurrent_seen` — the mutex test
|
||||||
|
asserts this NEVER happens for the same channel_id.
|
||||||
|
"""
|
||||||
|
state = {"active": 0, "lock": threading.Lock()}
|
||||||
|
|
||||||
|
def fake(cmd, timeout, on_text=None, cwd=None):
|
||||||
|
with state["lock"]:
|
||||||
|
state["active"] += 1
|
||||||
|
if state["active"] > 1:
|
||||||
|
concurrent_seen.set()
|
||||||
|
in_critical.set()
|
||||||
|
time.sleep(sleep_seconds)
|
||||||
|
with state["lock"]:
|
||||||
|
state["active"] -= 1
|
||||||
|
return {
|
||||||
|
"result": "Hello from Claude!",
|
||||||
|
"session_id": "sess-abc-123",
|
||||||
|
"usage": {"input_tokens": 10, "output_tokens": 5},
|
||||||
|
"total_cost_usd": 0.001,
|
||||||
|
"cost_usd": 0.001,
|
||||||
|
"duration_ms": int(sleep_seconds * 1000),
|
||||||
|
"num_turns": 1,
|
||||||
|
"intermediate_count": 0,
|
||||||
|
"subtype": "success",
|
||||||
|
"is_error": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
return fake
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Scenario 1 — same channel serializes
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSameChannelSerializes:
    """Scenario 1: calls sharing a channel_id run strictly one after another."""

    def test_two_concurrent_calls_same_channel_run_one_at_a_time(
        self, temp_sessions
    ):
        """Two parallel ``send_message`` calls on one channel never overlap.

        ``_run_claude`` is replaced by an instrumented fake that flags any
        moment where more than one invocation is inside it at once.
        """
        entered = threading.Event()
        overlapped = threading.Event()
        fake_run = _slow_run_claude(0.25, entered, overlapped)

        with patch.object(claude_session, "_run_claude", side_effect=fake_run):
            began = time.monotonic()
            with ThreadPoolExecutor(max_workers=2) as pool:
                pending = [
                    pool.submit(send_message, "ch-same", f"msg-{i}")
                    for i in range(2)
                ]
                replies = [fut.result(timeout=10) for fut in pending]
            took = time.monotonic() - began

        assert not overlapped.is_set(), (
            "Two send_message calls on the same channel ran concurrently — "
            "mutex did not serialize them."
        )
        assert replies == ["Hello from Claude!"] * len(replies)
        # Two 0.25s runs back to back need about 0.5s; the floor is kept
        # slightly lower to tolerate timer granularity.
        assert took >= 0.45, f"Expected serialized ~0.5s, got {took:.3f}s"

    def test_lock_is_reentrant_per_channel_dict(self, temp_sessions):
        """The same channel always maps to one lock object; another channel gets its own."""
        first_a = _get_session_lock("channel-A")
        second_a = _get_session_lock("channel-A")
        only_b = _get_session_lock("channel-B")
        assert first_a is second_a
        assert first_a is not only_b
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Scenario 2 — different channels parallel
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestDifferentChannelsParallel:
    """Scenario 2: calls on distinct channel_ids are free to overlap."""

    def test_two_concurrent_calls_different_channels_run_in_parallel(
        self, temp_sessions
    ):
        """Independent channels must not wait for each other.

        Wall-clock check: two 0.4s fake subprocesses on different channels
        finish in roughly 0.4s when parallel, roughly 0.8s if serialized.
        """
        entered = threading.Event()
        # Overlap is the desired outcome here, so this event SHOULD fire.
        overlapped = threading.Event()
        fake_run = _slow_run_claude(0.4, entered, overlapped)

        with patch.object(claude_session, "_run_claude", side_effect=fake_run):
            began = time.monotonic()
            with ThreadPoolExecutor(max_workers=2) as pool:
                pending = [
                    pool.submit(send_message, channel, message)
                    for channel, message in (("ch-A", "msg-A"), ("ch-B", "msg-B"))
                ]
                replies = [fut.result(timeout=10) for fut in pending]
            took = time.monotonic() - began

        assert replies == ["Hello from Claude!"] * len(replies)
        # Parallel runs land near 0.4s; 0.65s leaves room for scheduler
        # jitter while staying clearly below the serialized duration.
        assert took < 0.65, (
            f"Different channels appear serialized: elapsed {took:.3f}s "
            f"(expected ~0.4s parallel, <0.65s ceiling)"
        )
        assert overlapped.is_set(), (
            "Different channels did not overlap — mutex is too coarse "
            "(should be per-channel, not global)."
        )

    def test_three_channels_all_overlap(self, temp_sessions):
        """Stress variant: three channels at once still run in parallel."""
        entered = threading.Event()
        overlapped = threading.Event()
        fake_run = _slow_run_claude(0.3, entered, overlapped)

        with patch.object(claude_session, "_run_claude", side_effect=fake_run):
            began = time.monotonic()
            with ThreadPoolExecutor(max_workers=3) as pool:
                pending = [
                    pool.submit(send_message, f"ch-{i}", f"msg-{i}")
                    for i in range(3)
                ]
                for finished in as_completed(pending, timeout=10):
                    assert finished.result() == "Hello from Claude!"
            took = time.monotonic() - began

        # Three 0.3s runs: about 0.3s in parallel, about 0.9s if serialized.
        assert took < 0.6, (
            f"Three channels serialized: {took:.3f}s (expected <0.6s)"
        )
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Scenario 3 — acquisition behavior documented and consistent
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAcquisitionBehavior:
|
||||||
|
"""Pin the chosen acquisition policy: blocking, no timeout.
|
||||||
|
|
||||||
|
Project style is to bound subprocess execution via `timeout` (default
|
||||||
|
5 min) rather than fail-fast on lock acquire. Reasons:
|
||||||
|
|
||||||
|
- Adapter callers (Discord/Telegram/voice) already serialize work via
|
||||||
|
asyncio.to_thread; queue depth is naturally bounded.
|
||||||
|
- A non-blocking acquire would surface a timing error to the user
|
||||||
|
("busy, try again") for an entirely transient and self-resolving
|
||||||
|
condition. Blocking gives FIFO-ish ordering with simple semantics.
|
||||||
|
- If a subprocess truly hangs past `timeout`, _run_claude raises
|
||||||
|
TimeoutError → the held lock releases (via `with`) → queued
|
||||||
|
callers proceed.
|
||||||
|
|
||||||
|
This test pins that: a second caller waits and eventually proceeds; it
|
||||||
|
does not raise an exception on contention.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_contested_acquire_blocks_then_proceeds(self, temp_sessions):
|
||||||
|
in_critical = threading.Event()
|
||||||
|
concurrent_seen = threading.Event()
|
||||||
|
slow = _slow_run_claude(0.3, in_critical, concurrent_seen)
|
||||||
|
|
||||||
|
results: list[str | BaseException] = []
|
||||||
|
|
||||||
|
def run(label: str):
|
||||||
|
try:
|
||||||
|
results.append(send_message("ch-contend", label))
|
||||||
|
except BaseException as e:
|
||||||
|
results.append(e)
|
||||||
|
|
||||||
|
with patch.object(claude_session, "_run_claude", side_effect=slow):
|
||||||
|
t1 = threading.Thread(target=run, args=("first",))
|
||||||
|
t1.start()
|
||||||
|
# Wait until the first call is inside the critical section so
|
||||||
|
# the second is GUARANTEED to contend on the lock.
|
||||||
|
assert in_critical.wait(timeout=2.0), "first call never entered"
|
||||||
|
in_critical.clear()
|
||||||
|
t2 = threading.Thread(target=run, args=("second",))
|
||||||
|
t2.start()
|
||||||
|
t1.join(timeout=5.0)
|
||||||
|
t2.join(timeout=5.0)
|
||||||
|
|
||||||
|
assert len(results) == 2
|
||||||
|
# Both must return the canned response — no exception, no error.
|
||||||
|
assert all(r == "Hello from Claude!" for r in results), (
|
||||||
|
f"Contended acquire surfaced an error instead of blocking: {results}"
|
||||||
|
)
|
||||||
|
# Critical-section overlap check: contended calls MUST serialize.
|
||||||
|
assert not concurrent_seen.is_set(), (
|
||||||
|
"Contended same-channel calls ran concurrently — mutex broken."
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_lock_released_on_subprocess_exception(self, temp_sessions):
    """If `_run_claude` raises, the lock MUST be released so the next
    caller can proceed (otherwise a single error deadlocks the channel
    forever).

    The patched ``_run_claude`` raises on its first invocation and returns
    a canned success payload afterwards. The second ``send_message`` runs
    in a helper thread so a stuck lock shows up as a bounded timeout.
    """
    call_count = {"n": 0}

    def flaky(cmd, timeout, on_text=None, cwd=None):
        # First invocation crashes; every later one succeeds.
        call_count["n"] += 1
        if call_count["n"] == 1:
            raise RuntimeError("simulated subprocess crash")
        return {
            "result": "Hello from Claude!",
            "session_id": "sess-abc-123",
            "usage": {"input_tokens": 10, "output_tokens": 5},
            "total_cost_usd": 0.001,
            "cost_usd": 0.001,
            "duration_ms": 50,
            "num_turns": 1,
            "intermediate_count": 0,
            "subtype": "success",
            "is_error": False,
        }

    with patch.object(claude_session, "_run_claude", side_effect=flaky):
        with pytest.raises(RuntimeError, match="simulated subprocess crash"):
            send_message("ch-recover", "first")

        # Second call MUST acquire the lock (proves the first released it).
        # We use a short timeout via a thread so a deadlock would fail loudly.
        done = threading.Event()
        result_box: list[str] = []
        errors: list[BaseException] = []

        def second():
            try:
                result_box.append(send_message("ch-recover", "second"))
            except BaseException as exc:
                # Keep the real failure; otherwise it would be misreported
                # below as a deadlock after the wait times out.
                errors.append(exc)
            finally:
                done.set()

        # Daemon thread: a genuine deadlock must not block interpreter exit.
        t = threading.Thread(target=second, daemon=True)
        t.start()
        assert done.wait(timeout=3.0), (
            "Second call deadlocked — lock was not released on exception."
        )
        t.join(timeout=1.0)
        assert not errors, f"Second call raised instead of returning: {errors}"
        assert result_box == ["Hello from Claude!"]
|
||||||
222
tests/test_voice_adapter_contract.py
Normal file
222
tests/test_voice_adapter_contract.py
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Contract test for `src/voice/_discord_voice_adapter.py`.
|
||||||
|
|
||||||
|
Purpose: catch drift when the vendored `discord-ext-voice-recv` is upgraded.
|
||||||
|
If upstream renames/removes a method we depend on, this test fails LOUDLY
|
||||||
|
before any downstream code breaks at runtime in a Discord voice call.
|
||||||
|
|
||||||
|
Per VENDOR_INFO.md: this test MUST PASS after every vendor upgrade.
|
||||||
|
|
||||||
|
Plain `import` + `hasattr` / `callable` checks — no mocks. We're verifying
|
||||||
|
the SHAPE of the API surface, not behavior.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# --- Adapter re-exports import cleanly --------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_adapter_exports_voice_receive_client():
    """The adapter module exports ``VoiceReceiveClient`` and it is a class."""
    from src.voice._discord_voice_adapter import VoiceReceiveClient as exported

    assert exported is not None
    assert inspect.isclass(exported)
|
||||||
|
|
||||||
|
|
||||||
|
def test_adapter_exports_audio_sink():
    """The adapter module exports ``AudioSink`` and it is a class."""
    from src.voice._discord_voice_adapter import AudioSink as exported

    assert exported is not None
    assert inspect.isclass(exported)
|
||||||
|
|
||||||
|
|
||||||
|
def test_adapter_exports_voice_data():
    """The adapter module exports ``VoiceData`` and it is a class."""
    from src.voice._discord_voice_adapter import VoiceData as exported

    assert exported is not None
    assert inspect.isclass(exported)
|
||||||
|
|
||||||
|
|
||||||
|
def test_adapter_exports_connect_helper():
    """The adapter module exports ``connect_voice`` as a coroutine function."""
    from src.voice._discord_voice_adapter import connect_voice as helper

    assert callable(helper)
    assert inspect.iscoroutinefunction(helper)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Re-exports point at the real vendored classes (no accidental shadowing) -
|
||||||
|
|
||||||
|
|
||||||
|
def test_voice_receive_client_is_voice_recv_client():
    """``VoiceReceiveClient`` is the very same object as ``voice_recv.VoiceRecvClient``."""
    from discord.ext import voice_recv as vendored

    from src.voice._discord_voice_adapter import VoiceReceiveClient as reexported

    assert reexported is vendored.VoiceRecvClient
|
||||||
|
|
||||||
|
|
||||||
|
def test_audio_sink_is_voice_recv_audio_sink():
    """``AudioSink`` is the very same object as ``voice_recv.AudioSink``."""
    from discord.ext import voice_recv as vendored

    from src.voice._discord_voice_adapter import AudioSink as reexported

    assert reexported is vendored.AudioSink
|
||||||
|
|
||||||
|
|
||||||
|
def test_voice_data_is_voice_recv_voice_data():
    """``VoiceData`` is the very same object as ``voice_recv.VoiceData``."""
    from discord.ext import voice_recv as vendored

    from src.voice._discord_voice_adapter import VoiceData as reexported

    assert reexported is vendored.VoiceData
|
||||||
|
|
||||||
|
|
||||||
|
# --- VoiceReceiveClient API surface used by the pipeline --------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "method_name",
    [
        "connect",  # inherited from discord.VoiceClient
        "disconnect",  # inherited from discord.VoiceClient
        "listen",  # voice_recv extension
        "stop_listening",  # voice_recv extension
        "is_listening",  # voice_recv extension
        "stop",  # voice_recv extension (stops play+listen)
        "cleanup",  # voice_recv extension
    ],
)
def test_voice_receive_client_has_method(method_name):
    """Each listed name must resolve to a callable attribute on the client class."""
    from src.voice._discord_voice_adapter import VoiceReceiveClient as client_cls

    member = getattr(client_cls, method_name, None)
    assert member is not None, f"VoiceReceiveClient is missing `.{method_name}()`"
    assert callable(member), f"VoiceReceiveClient.{method_name} is not callable"
|
||||||
|
|
||||||
|
|
||||||
|
def test_voice_receive_client_listen_accepts_sink_and_after():
    """``VoiceReceiveClient.listen`` must declare ``sink`` and ``after`` parameters."""
    from src.voice._discord_voice_adapter import VoiceReceiveClient

    # Only parameter names are checked here, not their kinds or defaults.
    declared = inspect.signature(VoiceReceiveClient.listen).parameters
    assert "sink" in declared, f"VoiceReceiveClient.listen missing `sink` param; got {list(declared)}"
    assert "after" in declared, f"VoiceReceiveClient.listen missing `after` kwarg; got {list(declared)}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_voice_receive_client_has_sink_property():
    """``VoiceReceiveClient.sink`` must be a property with both getter and setter."""
    from src.voice._discord_voice_adapter import VoiceReceiveClient

    # getattr_static avoids triggering the descriptor, so we see the
    # property object itself rather than its computed value.
    descriptor = inspect.getattr_static(VoiceReceiveClient, "sink", None)
    assert isinstance(descriptor, property), "VoiceReceiveClient.sink must be a property"
    assert descriptor.fget is not None, "VoiceReceiveClient.sink property missing getter"
    assert descriptor.fset is not None, "VoiceReceiveClient.sink property missing setter"
|
||||||
|
|
||||||
|
|
||||||
|
# --- AudioSink API surface --------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "method_name",
    [
        "write",  # write(user, voice_data) — the hot path
        "cleanup",
        "wants_opus",  # bool: opus bytes vs decoded PCM
    ],
)
def test_audio_sink_has_method(method_name):
    """Each listed name must resolve to a callable attribute on ``AudioSink``."""
    from src.voice._discord_voice_adapter import AudioSink as sink_cls

    member = getattr(sink_cls, method_name, None)
    assert member is not None, f"AudioSink is missing `.{method_name}()`"
    assert callable(member), f"AudioSink.{method_name} is not callable"
|
||||||
|
|
||||||
|
|
||||||
|
def test_audio_sink_write_signature():
    """``AudioSink.write`` must declare at least three parameters.

    The expected shape is ``(self, user, data)``; only the parameter count
    is asserted, not the names.
    """
    from src.voice._discord_voice_adapter import AudioSink

    params = list(inspect.signature(AudioSink.write).parameters)
    # self, user, data
    assert len(params) >= 3, f"AudioSink.write expected (self, user, data), got {params}"
|
||||||
|
|
||||||
|
|
||||||
|
# --- VoiceData attributes ---------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_voice_data_slots():
    """``VoiceData`` must define ``__slots__`` containing ``packet``, ``source`` and ``pcm``."""
    from src.voice._discord_voice_adapter import VoiceData

    assert hasattr(VoiceData, "__slots__"), "VoiceData lost __slots__ — perf regression risk"
    declared = set(VoiceData.__slots__)
    # Documented attributes the pipeline depends on.
    assert "packet" in declared, f"VoiceData missing `packet` slot; got {declared}"
    assert "source" in declared, f"VoiceData missing `source` slot (speaker user); got {declared}"
    assert "pcm" in declared, f"VoiceData missing `pcm` slot (decoded audio); got {declared}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_voice_data_has_opus_property():
    """``VoiceData.opus`` must be declared as a property on the class."""
    from src.voice._discord_voice_adapter import VoiceData

    # Static lookup returns the descriptor itself instead of evaluating it.
    descriptor = inspect.getattr_static(VoiceData, "opus", None)
    assert isinstance(descriptor, property), "VoiceData.opus must be a property"
|
||||||
|
|
||||||
|
|
||||||
|
# --- Echo-core DAVE-decrypt fork guards -------------------------------------
|
||||||
|
#
|
||||||
|
# Two contract tests pinned by the DAVE receive-side decrypt patch.
|
||||||
|
# See plan: /home/moltbot/.claude/plans/wiggly-exploring-glade.md
|
||||||
|
#
|
||||||
|
# These fail fast on either:
|
||||||
|
# 1. An upstream voice-recv re-install wiping the fork's version marker
|
||||||
|
# (i.e. our patch is gone), OR
|
||||||
|
# 2. A discord.py upgrade renaming the connection-level DAVE attrs the
|
||||||
|
# patch reads (`dave_session`, `dave_protocol_version`).
|
||||||
|
|
||||||
|
|
||||||
|
def test_voice_recv_fork_version():
    """Pin the fork's version marker on the vendored ``voice_recv`` package.

    ``voice_recv.__version__`` must equal ``'0.5.3a+echo.dave1'``. A
    mismatch after a vendor reinstall means the DAVE-decrypt fork patch is
    no longer installed and must be re-applied before deploying.
    """
    from discord.ext import voice_recv

    installed = voice_recv.__version__
    assert installed == "0.5.3a+echo.dave1", (
        f"voice_recv.__version__ is {installed!r}; expected "
        "'0.5.3a+echo.dave1'. The DAVE-decrypt fork patch has been "
        "overwritten — re-apply before reinstalling the vendored package."
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_voice_connection_state_has_dave_attrs():
    """The source of ``VoiceConnectionState`` must still mention both DAVE attrs.

    This is a textual check on the class source for ``dave_session`` and
    ``dave_protocol_version``, so a rename in discord.py fails here rather
    than in a live voice call.
    """
    from discord import voice_state

    state_source = inspect.getsource(voice_state.VoiceConnectionState)
    assert "dave_session" in state_source, (
        "discord.voice_state.VoiceConnectionState source no longer mentions "
        "'dave_session' — discord.py may have renamed the attr. Update "
        "vendor/discord-ext-voice-recv/.../reader.py::_maybe_dave_decrypt."
    )
    assert "dave_protocol_version" in state_source, (
        "discord.voice_state.VoiceConnectionState source no longer mentions "
        "'dave_protocol_version' — discord.py may have renamed the attr. "
        "Update _maybe_dave_decrypt accordingly."
    )
|
||||||
137
tests/test_voice_normalize.py
Normal file
137
tests/test_voice_normalize.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
"""Tests for src/voice/normalize.py — 35 Romanian cases.
|
||||||
|
|
||||||
|
Categories:
|
||||||
|
markdown strip (5), numbers cardinals (6), decimals (4),
|
||||||
|
currency natural (8), symbols (4), abbreviations (4),
|
||||||
|
truncation boundary (2), edge cases empty / whitespace (2).
|
||||||
|
|
||||||
|
Total: 35.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.voice.normalize import (
|
||||||
|
expand_abbreviations,
|
||||||
|
expand_currency,
|
||||||
|
expand_numbers_ro,
|
||||||
|
expand_symbols,
|
||||||
|
normalize_for_tts,
|
||||||
|
strip_markdown,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Markdown stripping (5)
|
||||||
|
# ============================================================
|
||||||
|
@pytest.mark.parametrize(
    "text,expected",
    [
        ("**bold text**", "bold text"),
        ("*italic text*", "italic text"),
        ("`code snippet`", "code snippet"),
        ("[click here](https://example.com)", "click here"),
        ("# Heading text", "Heading text"),
    ],
)
def test_strip_markdown(text, expected):
    """``strip_markdown`` removes the markup and keeps the visible text."""
    actual = strip_markdown(text)
    assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Numbers cardinals (6)
|
||||||
|
# ============================================================
|
||||||
|
@pytest.mark.parametrize(
    "text,expected",
    [
        ("21", "douăzeci și unu"),
        ("81", "optzeci și unu"),
        ("100", "o sută"),
        ("3", "trei"),
        ("0", "zero"),
        ("200", "două sute"),
    ],
)
def test_expand_numbers_cardinals(text, expected):
    """``expand_numbers_ro`` spells out whole numbers in Romanian."""
    actual = expand_numbers_ro(text)
    assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Decimals (4)
|
||||||
|
# ============================================================
|
||||||
|
@pytest.mark.parametrize(
    "text,expected",
    [
        ("3.14", "trei virgulă paisprezece"),
        ("12.5", "doisprezece virgulă cinci"),
        ("0.5", "zero virgulă cinci"),
        ("99.99", "nouăzeci și nouă virgulă nouăzeci și nouă"),
    ],
)
def test_expand_numbers_decimals(text, expected):
    """``expand_numbers_ro`` reads decimals as ``<integer> virgulă <fraction>``."""
    actual = expand_numbers_ro(text)
    assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Currency natural RO (8) — RON / USD / EUR / GBP mix
|
||||||
|
# ============================================================
|
||||||
|
@pytest.mark.parametrize(
    "text,expected",
    [
        ("12.50 RON", "doisprezece lei și cincizeci de bani"),
        ("$25.99", "douăzeci și cinci de dolari și nouăzeci și nouă de cenți"),
        ("€100.50", "o sută de euro și cincizeci de cenți"),
        ("£200", "două sute de lire"),
        ("100 RON", "o sută de lei"),
        ("$1", "un dolar"),
        ("€50", "cincizeci de euro"),
        ("1 RON", "un leu"),
    ],
)
def test_expand_currency(text, expected):
    """``expand_currency`` renders RON / USD / EUR / GBP amounts as spoken Romanian."""
    actual = expand_currency(text)
    assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Symbols (4)
|
||||||
|
# ============================================================
|
||||||
|
@pytest.mark.parametrize(
    "text,expected",
    [
        ("25%", "25 la sută"),
        ("foo & bar", "foo și bar"),
        ("Marius @ home", "Marius la home"),
        ("30°", "30 grade"),
    ],
)
def test_expand_symbols(text, expected):
    """``expand_symbols`` replaces ``%``, ``&``, ``@`` and ``°`` with Romanian words."""
    actual = expand_symbols(text)
    assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Abbreviations (4)
|
||||||
|
# ============================================================
|
||||||
|
@pytest.mark.parametrize(
    "text,expected",
    [
        ("etc.", "etcetera"),
        ("dl. Popescu", "domnul Popescu"),
        ("dna. Ionescu", "doamna Ionescu"),
        ("nr. 5", "numărul 5"),
    ],
)
def test_expand_abbreviations(text, expected):
    """``expand_abbreviations`` expands common Romanian abbreviations in full."""
    actual = expand_abbreviations(text)
    assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Truncation boundary (2)
|
||||||
|
# ============================================================
|
||||||
|
def test_truncate_exactly_200_words_unchanged():
    """An input of exactly 200 words passes through with no deferral suffix."""
    words_in = ["cuvant"] * 200
    spoken = normalize_for_tts(" ".join(words_in))
    assert "Restul l-am scris în chat." not in spoken
    assert spoken.split() == ["cuvant"] * 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_truncate_over_200_words_appends_suffix():
    """A 250-word input keeps its first 200 words and gains the chat-deferral phrase."""
    words_in = ["cuvant"] * 250
    spoken = normalize_for_tts(" ".join(words_in))
    assert spoken.endswith("Restul l-am scris în chat.")
    tokens = spoken.split()
    # First 200 are 'cuvant', followed by the 5-word suffix.
    kept, tail = tokens[:200], tokens[200:]
    assert kept == ["cuvant"] * 200
    assert tail == ["Restul", "l-am", "scris", "în", "chat."]
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Edge cases (2)
|
||||||
|
# ============================================================
|
||||||
|
@pytest.mark.parametrize(
    "text,expected",
    [
        ("", ""),
        (" ", ""),
    ],
)
def test_normalize_edge_cases(text, expected):
    """Empty and whitespace-only inputs normalize to the empty string."""
    actual = normalize_for_tts(text)
    assert actual == expected
|
||||||
302
tests/test_voice_recv_dave.py
Normal file
302
tests/test_voice_recv_dave.py
Normal file
@@ -0,0 +1,302 @@
|
|||||||
|
"""DAVE receive-side decrypt tests for the vendored voice-recv fork.
|
||||||
|
|
||||||
|
Exercises Lane A's patch on
|
||||||
|
`vendor/discord-ext-voice-recv/discord/ext/voice_recv/reader.py`:
|
||||||
|
|
||||||
|
* `_maybe_dave_decrypt(rtp_packet)` — DAVE E2E layer sandwiched between the
|
||||||
|
transport-layer decrypt and the routing into the opus decoder. No-op when
|
||||||
|
the room is non-DAVE, when davey isn't installed, or when the SSRC map
|
||||||
|
hasn't caught up to a new speaker yet.
|
||||||
|
* `callback()` hook — feeds the DAVE-unwrapped plaintext into
|
||||||
|
`packet_router.feed_rtp()` on success, drops the packet on failure WITHOUT
|
||||||
|
killing the reader thread.
|
||||||
|
|
||||||
|
The test fixtures mirror `tests/test_voice_session_cleanup.py:33-54`:
|
||||||
|
* Construct `AudioReader` via `AudioReader.__new__(AudioReader)` + manual
|
||||||
|
attr set so the reader thread is never started.
|
||||||
|
* `MagicMock` everything below the unit under test.
|
||||||
|
|
||||||
|
`_HAS_DAVE` / `_MEDIA_TYPE_AUDIO` on the reader module are monkey-patched per
|
||||||
|
test so the suite passes whether or not `davey` is importable in the venv.
|
||||||
|
The assertions only become meaningful once Lane A's patch has landed and the
|
||||||
|
package has been re-installed (`pip install -e vendor/discord-ext-voice-recv
|
||||||
|
--force-reinstall`); the FILE itself is valid Python regardless.
|
||||||
|
|
||||||
|
See plan: /home/moltbot/.claude/plans/wiggly-exploring-glade.md
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from discord.ext.voice_recv.reader import AudioReader
|
||||||
|
|
||||||
|
|
||||||
|
# Sentinel for `_MEDIA_TYPE_AUDIO`. Using a plain object() keeps the tests
|
||||||
|
# independent of whether davey is importable — we just assert the value
|
||||||
|
# flows through to `dave_session.decrypt()` unchanged.
|
||||||
|
_FAKE_MEDIA_TYPE_AUDIO = object()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def fake_dave_session():
    """Mock DAVE session: ready, not in passthrough, ``decrypt`` returns fixed bytes."""
    session = MagicMock(name="dave_session")
    session.ready = True
    # Default: this user is NOT in passthrough — DAVE decrypt must run.
    # Individual tests can override to True to exercise the passthrough path.
    session.can_passthrough = MagicMock(return_value=False)
    session.decrypt = MagicMock(return_value=b"plaintext_opus")
    return session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def fake_connection(fake_dave_session):
    """Mock connection state with DAVE protocol version 1 and the fake session attached."""
    connection = MagicMock(name="_connection")
    connection.dave_protocol_version = 1
    connection.dave_session = fake_dave_session
    return connection
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def fake_voice_client(fake_connection):
    """Mock voice client carrying the fake connection and a one-entry SSRC map."""
    client = MagicMock(name="voice_client")
    client._connection = fake_connection
    # SSRC 12345 (the one on `fake_rtp_packet`) maps to user id 999000.
    client._ssrc_to_id = {12345: 999_000}
    return client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def fake_rtp_packet():
    """Mock RTP packet: SSRC 12345, fixed ``decrypted_data``, not silence."""
    packet = MagicMock(name="rtp_packet")
    packet.ssrc = 12345
    packet.decrypted_data = b"ciphertext_after_transport_decrypt"
    packet.is_silence = MagicMock(return_value=False)
    return packet
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def reader(fake_voice_client):
    """Bare ``AudioReader`` built without running ``__init__``.

    The instance is created via ``__new__`` and only ``voice_client`` and
    ``error`` are set by hand, so the tests drive it against pure mocks.
    """
    bare = AudioReader.__new__(AudioReader)
    bare.voice_client = fake_voice_client
    bare.error = None
    return bare
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def dave_enabled(monkeypatch):
    """Switch the reader module's DAVE flags on for the duration of a test.

    Sets ``_HAS_DAVE`` to ``True`` and ``_MEDIA_TYPE_AUDIO`` to the
    module-level sentinel so tests can assert exactly what is passed to
    ``dave_session.decrypt``. ``raising=False`` lets the patch apply even
    when the attributes do not exist on the module yet.

    Returns the patched reader module.
    """
    import discord.ext.voice_recv.reader as reader_mod

    overrides = (
        ("_HAS_DAVE", True),
        ("_MEDIA_TYPE_AUDIO", _FAKE_MEDIA_TYPE_AUDIO),
    )
    for attr_name, value in overrides:
        monkeypatch.setattr(reader_mod, attr_name, value, raising=False)
    return reader_mod
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Unit tests: `_maybe_dave_decrypt`
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMaybeDaveDecrypt:
    """Nine unit tests on the DAVE-decrypt gate.

    The gate mirrors `voice_client.can_encrypt` in discord.py 2.7.1 exactly
    (`voice_state.py:272-273`). Bypass semantics on every "DAVE inactive"
    branch let non-DAVE rooms and davey-less environments keep working.

    Outcomes pinned here:
      * bypass (payload returned unchanged): protocol version 0, session
        missing, session not ready, `_HAS_DAVE` false, passthrough peer;
      * drop (returns None): unknown SSRC, `decrypt` raising;
      * decrypt: happy path, and fallback when `can_passthrough` raises.
    """

    def test_protocol_version_zero_bypasses_decrypt(
        self, dave_enabled, reader, fake_connection, fake_dave_session, fake_rtp_packet,
    ):
        """`dave_protocol_version == 0` → return the transport-decrypted
        payload unchanged; never touch `dave_session.decrypt`."""
        fake_connection.dave_protocol_version = 0
        result = reader._maybe_dave_decrypt(fake_rtp_packet)
        assert result is fake_rtp_packet.decrypted_data
        fake_dave_session.decrypt.assert_not_called()

    def test_dave_session_none_bypasses_decrypt(
        self, dave_enabled, reader, fake_connection, fake_dave_session, fake_rtp_packet,
    ):
        """`dave_session is None` → bypass. Pre-MLS-handshake state."""
        fake_connection.dave_session = None
        result = reader._maybe_dave_decrypt(fake_rtp_packet)
        assert result is fake_rtp_packet.decrypted_data
        # Same guard as the sibling bypass tests: the detached session
        # object must not have been used either.
        fake_dave_session.decrypt.assert_not_called()

    def test_dave_session_not_ready_bypasses_decrypt(
        self, dave_enabled, reader, fake_dave_session, fake_rtp_packet,
    ):
        """`dave_session.ready is False` → bypass. Pre-MLS-epoch-1 packets
        are transport-only on the wire."""
        fake_dave_session.ready = False
        result = reader._maybe_dave_decrypt(fake_rtp_packet)
        assert result is fake_rtp_packet.decrypted_data
        fake_dave_session.decrypt.assert_not_called()

    def test_unknown_ssrc_returns_none(
        self, dave_enabled, reader, fake_voice_client, fake_dave_session, fake_rtp_packet,
    ):
        """SSRC not in `_ssrc_to_id` → drop (return None).

        Accepted regression: davey requires per-user keys; when SPEAKING
        events race behind the first audio packet, 1-5 packets per new
        speaker per session are dropped. See plan §Edge cases.
        """
        fake_voice_client._ssrc_to_id.clear()
        result = reader._maybe_dave_decrypt(fake_rtp_packet)
        assert result is None
        fake_dave_session.decrypt.assert_not_called()

    def test_happy_path_invokes_decrypt_and_returns_plaintext(
        self, dave_enabled, reader, fake_dave_session, fake_rtp_packet,
    ):
        """Full DAVE-active path: `decrypt(user_id, MediaType.audio, ciphertext)`
        called exactly once with the expected args; method returns the
        davey plaintext bytes verbatim."""
        # Capture the payload first in case the call rewrites the packet.
        ciphertext = fake_rtp_packet.decrypted_data
        result = reader._maybe_dave_decrypt(fake_rtp_packet)
        assert result == b"plaintext_opus"
        fake_dave_session.decrypt.assert_called_once_with(
            999_000, _FAKE_MEDIA_TYPE_AUDIO, ciphertext,
        )

    def test_decrypt_raises_returns_none_no_crash(
        self, dave_enabled, reader, fake_dave_session, fake_rtp_packet,
    ):
        """davey.decrypt raising → drop the packet, don't propagate, and
        leave `reader.error` untouched so the reader thread stays alive.

        MLS epoch transitions can produce transient decrypt failures —
        bumping `reader.error` would call `self.stop()` and kill the whole
        receive pipeline."""
        fake_dave_session.decrypt.side_effect = RuntimeError(
            "simulated MLS epoch transition fail"
        )
        result = reader._maybe_dave_decrypt(fake_rtp_packet)
        assert result is None
        assert reader.error is None

    def test_has_dave_false_bypasses_even_with_session_present(
        self, monkeypatch, reader, fake_dave_session, fake_rtp_packet,
    ):
        """`_HAS_DAVE = False` → bypass everything, even if a real session
        somehow showed up on the connection. Defensive shim that keeps the
        tests (and any davey-less deploys) green."""
        import discord.ext.voice_recv.reader as reader_mod

        monkeypatch.setattr(reader_mod, "_HAS_DAVE", False, raising=False)
        result = reader._maybe_dave_decrypt(fake_rtp_packet)
        assert result is fake_rtp_packet.decrypted_data
        fake_dave_session.decrypt.assert_not_called()

    def test_can_passthrough_true_returns_payload_without_decrypt(
        self, dave_enabled, reader, fake_dave_session, fake_rtp_packet,
    ):
        """`can_passthrough(user_id) == True` → return the transport-decrypted
        payload as-is; never call `decrypt`. Mirrors Discord's protocol where
        a passthrough-mode peer sends non-DAVE-wrapped packets that the
        receiver must accept verbatim."""
        fake_dave_session.can_passthrough = MagicMock(return_value=True)
        result = reader._maybe_dave_decrypt(fake_rtp_packet)
        assert result is fake_rtp_packet.decrypted_data
        fake_dave_session.can_passthrough.assert_called_once_with(999_000)
        fake_dave_session.decrypt.assert_not_called()

    def test_can_passthrough_raises_falls_through_to_decrypt(
        self, dave_enabled, reader, fake_dave_session, fake_rtp_packet,
    ):
        """`can_passthrough` raising → swallow the error and try `decrypt`.
        Defensive: an older davey build or transient internal state shouldn't
        break the receive pipeline."""
        fake_dave_session.can_passthrough = MagicMock(
            side_effect=RuntimeError("simulated davey internal error")
        )
        result = reader._maybe_dave_decrypt(fake_rtp_packet)
        assert result == b"plaintext_opus"
        fake_dave_session.decrypt.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Integration tests: `callback()` exercises the DAVE hook
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCallbackIntegration:
|
||||||
|
"""Two integration tests for the hook Lane A inserts between transport
|
||||||
|
decrypt (reader.py:141) and the post-decrypt routing (reader.py:159).
|
||||||
|
|
||||||
|
Strategy: stub the transport-decrypt and RTP parsing path so `callback()`
|
||||||
|
reaches the hook, then mock `_maybe_dave_decrypt` directly on the reader
|
||||||
|
instance. The assertion focuses on `feed_rtp` being called (test 8) vs.
|
||||||
|
not called (test 9). The transport path correctness is covered by
|
||||||
|
voice-recv's own upstream tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _wire_callback(reader, monkeypatch, fake_rtp_packet):
|
||||||
|
import discord.ext.voice_recv.reader as reader_mod
|
||||||
|
|
||||||
|
# Redirect rtp parsing — we want an RTP path (not RTCP) so the hook fires.
|
||||||
|
monkeypatch.setattr(reader_mod.rtp, "is_rtcp", lambda data: False)
|
||||||
|
monkeypatch.setattr(reader_mod.rtp, "decode_rtp", lambda data: fake_rtp_packet)
|
||||||
|
|
||||||
|
# Stub the instance attrs `callback()` touches besides the hook.
|
||||||
|
reader.decryptor = MagicMock(name="decryptor")
|
||||||
|
reader.decryptor.decrypt_rtp = MagicMock(return_value=b"ciphertext")
|
||||||
|
reader.packet_router = MagicMock(name="packet_router")
|
||||||
|
reader.packet_router.feed_rtp = MagicMock()
|
||||||
|
reader.speaking_timer = MagicMock(name="speaking_timer")
|
||||||
|
reader.sink = MagicMock(name="sink")
|
||||||
|
|
||||||
|
def test_callback_feeds_when_dave_returns_bytes(
|
||||||
|
self, monkeypatch, reader, fake_rtp_packet,
|
||||||
|
):
|
||||||
|
"""Hook returns plaintext → `feed_rtp` called once with the
|
||||||
|
rtp_packet whose `decrypted_data` is now the post-DAVE plaintext."""
|
||||||
|
self._wire_callback(reader, monkeypatch, fake_rtp_packet)
|
||||||
|
plaintext = b"dave_unwrapped_opus_payload"
|
||||||
|
reader._maybe_dave_decrypt = MagicMock(return_value=plaintext)
|
||||||
|
|
||||||
|
reader.callback(b"raw_packet_bytes")
|
||||||
|
|
||||||
|
reader._maybe_dave_decrypt.assert_called_once_with(fake_rtp_packet)
|
||||||
|
assert reader.packet_router.feed_rtp.call_count == 1
|
||||||
|
called_with = reader.packet_router.feed_rtp.call_args[0][0]
|
||||||
|
assert called_with is fake_rtp_packet
|
||||||
|
assert fake_rtp_packet.decrypted_data == plaintext
|
||||||
|
assert reader.error is None
|
||||||
|
|
||||||
|
def test_callback_drops_when_dave_returns_none(
    self, monkeypatch, reader, fake_rtp_packet,
):
    """A ``None`` result from the hook drops the packet silently.

    ``feed_rtp`` must not be called, no exception may escape
    ``callback()``, and ``reader.error`` must remain ``None``.
    """
    self._wire_callback(reader, monkeypatch, fake_rtp_packet)
    dave_hook = MagicMock(return_value=None)
    reader._maybe_dave_decrypt = dave_hook

    reader.callback(b"raw_packet_bytes")

    dave_hook.assert_called_once_with(fake_rtp_packet)
    reader.packet_router.feed_rtp.assert_not_called()
    assert reader.error is None
|
||||||
319
tests/test_voice_session_cleanup.py
Normal file
319
tests/test_voice_session_cleanup.py
Normal file
@@ -0,0 +1,319 @@
|
|||||||
|
"""Cleanup-path tests for ``src/voice/pipeline.py::VoiceSession``.
|
||||||
|
|
||||||
|
Pins the centralized ``cleanup()`` contract from the voice plan
|
||||||
|
(Engineering decision #5): every one of the FIVE exit paths must drain
|
||||||
|
state cleanly and idempotently — lock released, JSONL flushed or
|
||||||
|
discarded, presence cleared, ``voice_client.cleanup()`` invoked,
|
||||||
|
``ttsq.stop()`` invoked, and a second call to ``cleanup()`` MUST be a
|
||||||
|
no-op (side effects happen exactly once).
|
||||||
|
|
||||||
|
The 5 paths under test:
|
||||||
|
1. ``test_cleanup_on_voice_leave`` — explicit ``/voice leave``
|
||||||
|
2. ``test_cleanup_on_disconnect`` — Discord-level disconnect
|
||||||
|
3. ``test_cleanup_on_crash`` — exception via ``__exit__``
|
||||||
|
4. ``test_cleanup_on_auto_leave`` — 5-min inactivity timer
|
||||||
|
5. ``test_cleanup_on_user_leaves_channel`` — user leaves voice channel
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.voice.pipeline import VoiceSession
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_bot():
    """Bot double with a fixed user id, awaitable ``change_presence`` and a
    ``get_user`` that always returns ``None``."""
    fake_bot = MagicMock(name="bot")
    fake_bot.user = MagicMock()
    fake_bot.user.id = 999_999
    # change_presence is awaited by the code under test, hence AsyncMock.
    fake_bot.change_presence = AsyncMock(name="change_presence")
    fake_bot.get_user = MagicMock(return_value=None)
    return fake_bot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_voice_client():
    """Voice-client double whose ``cleanup`` is a separately named mock so
    its call count can be asserted."""
    client = MagicMock(name="voice_client")
    client.cleanup = MagicMock(name="vc_cleanup")
    return client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_ttsq():
    """TTS-queue double exposing a countable ``stop`` mock."""
    queue = MagicMock(name="ttsq")
    queue.stop = MagicMock(name="ttsq_stop")
    return queue
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_text_channel():
    """Text-channel double with an awaitable ``send``."""
    channel = MagicMock(name="text_channel")
    channel.send = AsyncMock(name="text_send")
    return channel
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_session(
    tmp_path: Path,
    mock_bot,
    mock_voice_client,
    mock_ttsq,
    mock_text_channel,
    *,
    record_enabled: bool = True,
) -> VoiceSession:
    """Build a ``VoiceSession`` wired to the mock collaborators.

    The transcript path lives under ``tmp_path``; its file name depends on
    ``record_enabled`` ("transcripts.jsonl" when on, "noop.jsonl" when off).
    """
    if record_enabled:
        file_name = "transcripts.jsonl"
    else:
        file_name = "noop.jsonl"

    session_kwargs = {
        "channel_id": 1001,
        "guild_id": 42,
        "text_channel": mock_text_channel,
        "voice_client": mock_voice_client,
        "bot": mock_bot,
        "ttsq": mock_ttsq,
        "whitelist": {1234},
        "record_enabled": record_enabled,
        "mirror_enabled": True,
        "transcripts_jsonl_path": tmp_path / file_name,
        "loop": None,
        "router_route_message": MagicMock(name="route_message"),
    }
    return VoiceSession(**session_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_clean_post_cleanup(
    session: VoiceSession,
    voice_client,
    ttsq,
    bot,
    jsonl_path: Path,
    record_enabled: bool,
) -> None:
    """Check the post-``cleanup()`` state shared by every exit-path test.

    Verifies, in order: the session lock can be re-acquired, the voice
    client and TTS queue were each torn down exactly once, presence was
    reset with ``activity=None``, and the transcript file matches the
    recording mode.
    """
    # 1. Lock released: a non-blocking acquire from this thread must succeed.
    got_lock = session._lock.acquire(blocking=False)
    assert got_lock, "session._lock must be released after cleanup()"
    session._lock.release()

    # 2. voice_client.cleanup() ran exactly once.
    assert voice_client.cleanup.call_count == 1, (
        f"voice_client.cleanup() called {voice_client.cleanup.call_count}x, expected 1"
    )

    # 3. ttsq.stop() ran exactly once.
    assert ttsq.stop.call_count == 1, (
        f"ttsq.stop() called {ttsq.stop.call_count}x, expected 1"
    )

    # 4. Presence was reset; the most recent call must carry activity=None.
    presence = bot.change_presence
    assert presence.call_count >= 1, (
        "bot.change_presence was never called — presence not restored"
    )
    presence.assert_called_with(activity=None)

    # 5. Transcript file: absent or empty when recording is off, present when on.
    if not record_enabled:
        assert not jsonl_path.exists() or jsonl_path.stat().st_size == 0
        return
    assert jsonl_path.exists(), (
        "record=on: JSONL file must exist (was created by __enter__ and "
        "left in place by cleanup so transcript can be persisted)"
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Scenario 1 — explicit /voice leave
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCleanupOnVoiceLeave:
    """Exit path 1: explicit ``/voice leave`` issued inside the with-block."""

    def test_cleanup_on_voice_leave(
        self, tmp_path, mock_bot, mock_voice_client, mock_ttsq, mock_text_channel,
    ):
        voice_session = _make_session(
            tmp_path,
            mock_bot,
            mock_voice_client,
            mock_ttsq,
            mock_text_channel,
            record_enabled=True,
        )
        transcript_file = voice_session.transcripts_jsonl_path

        with voice_session:
            # Write a single transcript line before leaving.
            line = json.dumps({"text": "salut"})
            voice_session._jsonl_fh.write(line + "\n")
            voice_session.cleanup("voice_leave")
            assert voice_session._cleaned_up is True

        # Leaving the with-block triggers cleanup again; the shared checks
        # confirm every side effect still happened only once.
        _assert_clean_post_cleanup(
            session=voice_session,
            voice_client=mock_voice_client,
            ttsq=mock_ttsq,
            bot=mock_bot,
            jsonl_path=transcript_file,
            record_enabled=True,
        )

        # A further explicit call must not bump the counters either.
        voice_session.cleanup("redundant")
        assert mock_voice_client.cleanup.call_count == 1
        assert mock_ttsq.stop.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Scenario 2 — Discord-level voice disconnect
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCleanupOnDisconnect:
    """Exit path 2: Discord-level disconnect, recording disabled."""

    def test_cleanup_on_disconnect(
        self, tmp_path, mock_bot, mock_voice_client, mock_ttsq, mock_text_channel,
    ):
        voice_session = _make_session(
            tmp_path,
            mock_bot,
            mock_voice_client,
            mock_ttsq,
            mock_text_channel,
            record_enabled=False,
        )
        transcript_file = voice_session.transcripts_jsonl_path

        voice_session.__enter__()
        # The drop is handled outside any with-block, so cleanup is direct.
        voice_session.cleanup("disconnect")
        _assert_clean_post_cleanup(
            session=voice_session,
            voice_client=mock_voice_client,
            ttsq=mock_ttsq,
            bot=mock_bot,
            jsonl_path=transcript_file,
            record_enabled=False,
        )

        # Repeating the call must be a no-op.
        voice_session.cleanup("disconnect-again")
        assert mock_voice_client.cleanup.call_count == 1
        assert mock_ttsq.stop.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Scenario 3 — crash / exception via __exit__
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCleanupOnCrash:
    """Exit path 3: an exception escapes the with-block."""

    def test_cleanup_on_crash(
        self, tmp_path, mock_bot, mock_voice_client, mock_ttsq, mock_text_channel,
    ):
        voice_session = _make_session(
            tmp_path,
            mock_bot,
            mock_voice_client,
            mock_ttsq,
            mock_text_channel,
            record_enabled=True,
        )
        transcript_file = voice_session.transcripts_jsonl_path

        with pytest.raises(RuntimeError, match="simulated crash"):
            with voice_session:
                # Stand-in for the pipeline failing mid-call.
                raise RuntimeError("simulated crash")

        # The context manager's exit must have performed the full cleanup.
        _assert_clean_post_cleanup(
            session=voice_session,
            voice_client=mock_voice_client,
            ttsq=mock_ttsq,
            bot=mock_bot,
            jsonl_path=transcript_file,
            record_enabled=True,
        )

        # An outer error handler calling cleanup again must change nothing.
        voice_session.cleanup("post-crash")
        assert mock_voice_client.cleanup.call_count == 1
        assert mock_ttsq.stop.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Scenario 4 — auto-leave timer fires after 5 min inactivity
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCleanupOnAutoLeave:
    """Exit path 4: the inactivity auto-leave timer fires."""

    def test_cleanup_on_auto_leave(
        self, tmp_path, mock_bot, mock_voice_client, mock_ttsq, mock_text_channel,
    ):
        voice_session = _make_session(
            tmp_path,
            mock_bot,
            mock_voice_client,
            mock_ttsq,
            mock_text_channel,
            record_enabled=True,
        )
        transcript_file = voice_session.transcripts_jsonl_path

        voice_session.__enter__()
        # The timer callback runs outside any with-block.
        voice_session.cleanup("auto_leave")

        _assert_clean_post_cleanup(
            session=voice_session,
            voice_client=mock_voice_client,
            ttsq=mock_ttsq,
            bot=mock_bot,
            jsonl_path=transcript_file,
            record_enabled=True,
        )

        # Repeating the call must be a no-op.
        voice_session.cleanup("auto_leave_redundant")
        assert mock_voice_client.cleanup.call_count == 1
        assert mock_ttsq.stop.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Scenario 5 — user leaves voice channel themselves
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCleanupOnUserLeaves:
    """Exit path 5: the user leaves the voice channel, recording disabled."""

    def test_cleanup_on_user_leaves_channel(
        self, tmp_path, mock_bot, mock_voice_client, mock_ttsq, mock_text_channel,
    ):
        voice_session = _make_session(
            tmp_path,
            mock_bot,
            mock_voice_client,
            mock_ttsq,
            mock_text_channel,
            record_enabled=False,
        )
        transcript_file = voice_session.transcripts_jsonl_path

        voice_session.__enter__()
        # Mirrors the voice_state_update handler calling cleanup directly.
        voice_session.cleanup("user_left_channel")

        _assert_clean_post_cleanup(
            session=voice_session,
            voice_client=mock_voice_client,
            ttsq=mock_ttsq,
            bot=mock_bot,
            jsonl_path=transcript_file,
            record_enabled=False,
        )

        # Repeating the call must be a no-op.
        voice_session.cleanup("user_left_again")
        assert mock_voice_client.cleanup.call_count == 1
        assert mock_ttsq.stop.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Cross-cutting: failures inside cleanup don't propagate
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCleanupRobustness:
    """Cross-cutting: a failing collaborator must not abort cleanup."""

    def test_cleanup_swallows_voice_client_errors(
        self, tmp_path, mock_bot, mock_voice_client, mock_ttsq, mock_text_channel,
    ):
        """When ``voice_client.cleanup()`` raises, ``ttsq.stop()`` must still
        run and the lock must still be released, so a broken Discord state
        cannot leave the channel locked."""
        mock_voice_client.cleanup.side_effect = RuntimeError("vc died")

        voice_session = _make_session(
            tmp_path,
            mock_bot,
            mock_voice_client,
            mock_ttsq,
            mock_text_channel,
            record_enabled=False,
        )

        with voice_session:
            voice_session.cleanup("voice_leave")

        # The TTS queue was still stopped exactly once.
        assert mock_ttsq.stop.call_count == 1
        # The lock is free: a non-blocking acquire succeeds.
        got_lock = voice_session._lock.acquire(blocking=False)
        assert got_lock, "lock must release even when voice_client.cleanup raises"
        voice_session._lock.release()
|
||||||
20
tools/tts.py
20
tools/tts.py
@@ -23,6 +23,24 @@ VOICES = {"M1", "M2", "M3", "M4", "M5", "F1", "F2", "F3", "F4", "F5"}
|
|||||||
DEFAULT_VOICE = "M2"
|
DEFAULT_VOICE = "M2"
|
||||||
DEFAULT_LANG = "ro"
|
DEFAULT_LANG = "ro"
|
||||||
|
|
||||||
|
# Punctuation Supertonic synthesis rejects with HTTP 500 (Romanian curly quotes,
|
||||||
|
# smart dashes, ellipsis, angle quotes). Mapped to ASCII so a stray „foo" in
|
||||||
|
# any caller's text doesn't kill the whole request.
|
||||||
|
_TTS_PUNCT_MAP = {
    '„': '"', '“': '"', '”': '"',
    '‘': "'", '’': "'", '‚': "'",
    '«': '"', '»': '"',
    '–': '-', '—': '-',
    '…': '...',
}

# Precomputed translation table so sanitizing is one pass over the text
# instead of one full ``str.replace`` scan per mapped character.
_TTS_PUNCT_TABLE = str.maketrans(_TTS_PUNCT_MAP)


def sanitize_for_supertonic(text: str) -> str:
    """Replace Unicode punctuation Supertonic rejects with ASCII equivalents.

    Args:
        text: Text about to be sent for synthesis.

    Returns:
        ``text`` with every character listed in ``_TTS_PUNCT_MAP`` replaced
        by its ASCII counterpart (the ellipsis expands to three dots); all
        other characters are left untouched.
    """
    # No replacement value is itself a key of the map, so a single
    # translate pass yields the same result as sequential replacements.
    return text.translate(_TTS_PUNCT_TABLE)
|
||||||
|
|
||||||
|
|
||||||
def synthesize(text: str, voice: str = DEFAULT_VOICE, lang: str = DEFAULT_LANG) -> dict:
|
def synthesize(text: str, voice: str = DEFAULT_VOICE, lang: str = DEFAULT_LANG) -> dict:
|
||||||
"""Call Supertonic server and save audio to a temp WAV file.
|
"""Call Supertonic server and save audio to a temp WAV file.
|
||||||
@@ -34,6 +52,8 @@ def synthesize(text: str, voice: str = DEFAULT_VOICE, lang: str = DEFAULT_LANG)
|
|||||||
if not text or not text.strip():
|
if not text or not text.strip():
|
||||||
return {"ok": False, "error": "Text gol."}
|
return {"ok": False, "error": "Text gol."}
|
||||||
|
|
||||||
|
text = sanitize_for_supertonic(text)
|
||||||
|
|
||||||
voice = voice.upper()
|
voice = voice.upper()
|
||||||
if voice not in VOICES:
|
if voice not in VOICES:
|
||||||
voice = DEFAULT_VOICE
|
voice = DEFAULT_VOICE
|
||||||
|
|||||||
60
vendor/discord-ext-voice-recv/VENDOR_INFO.md
vendored
60
vendor/discord-ext-voice-recv/VENDOR_INFO.md
vendored
@@ -1,22 +1,76 @@
|
|||||||
# Vendored: discord-ext-voice-recv
|
# Vendored: discord-ext-voice-recv
|
||||||
|
|
||||||
**Upstream:** https://github.com/imayhaveborkedit/discord-ext-voice-recv
|
**Upstream:** https://github.com/imayhaveborkedit/discord-ext-voice-recv
|
||||||
**Pinned commit:** `ac04ea7b0941112e83767cf1c1469b408fa06748` (bump version 0.5.3a)
|
**Pinned commit:** `ac04ea7b0941112e83767cf1c1469b408fa06748` (bump version 0.5.3a, master HEAD Jun 2025)
|
||||||
**Vendored at:** 2026-05-27
|
**Vendored at:** 2026-05-27
|
||||||
|
**Echo Core fork version:** `0.5.3a+echo.dave1` (PEP 440 local segment)
|
||||||
**Reason:** The Discord voice protocol is fragile and upstream is a hobby fork. Adapter
|
**Reason:** The Discord voice protocol is fragile and upstream is a hobby fork. Adapter
|
||||||
layer in `src/voice/_discord_voice_adapter.py` isolates upstream churn — if this
|
layer in `src/voice/_discord_voice_adapter.py` isolates upstream churn — if this
|
||||||
package breaks, swap to py-cord by rewriting only that file.
|
package breaks, swap to py-cord by rewriting only that file.
|
||||||
|
|
||||||
## Update procedure
|
## Echo Core patch: `+echo.dave1` (DAVE E2E receive-side decrypt)
|
||||||
|
|
||||||
|
### Why
|
||||||
|
|
||||||
|
Discord enforces DAVE (E2E media encryption) on voice gateway `v=8` whenever the
|
||||||
|
bot advertises `max_dave_protocol_version > 0` in IDENTIFY. discord.py 2.7.1 (the
|
||||||
|
version Echo Core pins) does so unconditionally — Discord then closes the WS
|
||||||
|
with code **4017** if the bot opts out by sending `max_dave_protocol_version=0`.
|
||||||
|
DAVE is **mandatory**.
|
||||||
|
|
||||||
|
Audio received from a DAVE-active room is **dual-wrapped**: transport layer
|
||||||
|
(`aead_xchacha20_poly1305_rtpsize`) + DAVE E2E. Upstream voice-recv decrypts
|
||||||
|
only the transport layer, then hands DAVE ciphertext to libopus, which raises
|
||||||
|
`OpusError: corrupted stream` on every packet.
|
||||||
|
|
||||||
|
### Patch shape
|
||||||
|
|
||||||
|
~30 lines, all in `discord/ext/voice_recv/reader.py`:
|
||||||
|
|
||||||
|
1. Module-level optional `davey` import (no-op when missing).
|
||||||
|
2. `AudioReader._maybe_dave_decrypt(rtp_packet) -> Optional[bytes]` — gate logic
|
||||||
|
mirrors discord.py 2.7.1 send-side `can_encrypt` exactly. Returns the
|
||||||
|
DAVE-unwrapped payload, the original payload (DAVE inactive), or `None` to
|
||||||
|
drop the packet (unknown SSRC, decrypt failure).
|
||||||
|
3. 4-line hook in `callback()` between transport-decrypt and `feed_rtp`:
|
||||||
|
overwrites `rtp_packet.decrypted_data` in place, or returns early to drop.
|
||||||
|
|
||||||
|
The post-decrypt `is_silence()` check (formerly at reader.py:172) still works
|
||||||
|
because we overwrite `decrypted_data` in place — silence frames produced by
|
||||||
|
davey reach the existing check unchanged.
|
||||||
|
|
||||||
|
### Dependency
|
||||||
|
|
||||||
|
`davey==0.1.5` — matches discord.py 2.7.1 expectation. Pin in
|
||||||
|
`echo-core/requirements.txt`. The import is optional at module level so tests
|
||||||
|
and non-DAVE environments still run; the gate degrades to a bypass.
|
||||||
|
|
||||||
|
### Re-sync strategy
|
||||||
|
|
||||||
|
When upstream voice-recv adds DAVE support natively:
|
||||||
|
|
||||||
|
1. Drop the three patch hunks in `reader.py` (davey import block,
|
||||||
|
`_maybe_dave_decrypt` method, hook in `callback()`).
|
||||||
|
2. Revert `__version__` to upstream value in `__init__.py`.
|
||||||
|
3. Update `Pinned commit` below.
|
||||||
|
4. Run `pytest tests/test_voice_recv_dave.py tests/test_voice_adapter_contract.py`.
|
||||||
|
|
||||||
|
The contract test `test_voice_recv_fork_version` asserts `__version__ ==
|
||||||
|
'0.5.3a+echo.dave1'` and will fail fast on any accidental wipe during a careless
|
||||||
|
upstream sync — forcing a conscious decision to either re-port or drop the
|
||||||
|
patch.
|
||||||
|
|
||||||
|
## Update procedure (vanilla upstream sync)
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd vendor/discord-ext-voice-recv
|
cd vendor/discord-ext-voice-recv
|
||||||
git fetch origin master
|
git fetch origin master
|
||||||
git log HEAD..origin/master --oneline # review what changed
|
git log HEAD..origin/master --oneline # review what changed
|
||||||
git checkout <new-commit>
|
git checkout <new-commit>
|
||||||
|
# RE-APPLY the +echo.dave1 patch if upstream still lacks DAVE
|
||||||
cd ../..
|
cd ../..
|
||||||
source .venv/bin/activate && pip install -e vendor/discord-ext-voice-recv --force-reinstall
|
source .venv/bin/activate && pip install -e vendor/discord-ext-voice-recv --force-reinstall
|
||||||
pytest tests/test_voice_adapter_contract.py -v # MUST PASS — contract guard
|
pytest tests/test_voice_adapter_contract.py tests/test_voice_recv_dave.py -v # MUST PASS — contract + DAVE guards
|
||||||
```
|
```
|
||||||
|
|
||||||
Update this file's `Pinned commit` after a successful upgrade.
|
Update this file's `Pinned commit` after a successful upgrade.
|
||||||
|
|||||||
@@ -17,4 +17,4 @@ __title__ = 'discord.ext.voice_recv'
|
|||||||
__author__ = 'Imayhaveborkedit'
|
__author__ = 'Imayhaveborkedit'
|
||||||
__license__ = 'MIT'
|
__license__ = 'MIT'
|
||||||
__copyright__ = 'Copyright 2021-present Imayhaveborkedit'
|
__copyright__ = 'Copyright 2021-present Imayhaveborkedit'
|
||||||
__version__ = '0.5.3a'
|
__version__ = '0.5.3a+echo.dave1'
|
||||||
|
|||||||
@@ -19,6 +19,15 @@ try:
|
|||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise RuntimeError("pynacl is required") from e
|
raise RuntimeError("pynacl is required") from e
|
||||||
|
|
||||||
|
# Echo Core +echo.dave1 patch: DAVE E2E receive-side decrypt. See VENDOR_INFO.md.
|
||||||
|
try:
|
||||||
|
import davey
|
||||||
|
_MEDIA_TYPE_AUDIO = davey.MediaType.audio
|
||||||
|
_HAS_DAVE = True
|
||||||
|
except ImportError:
|
||||||
|
_MEDIA_TYPE_AUDIO = None
|
||||||
|
_HAS_DAVE = False
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import Optional, Callable, Any, Dict, Literal, Union
|
from typing import Optional, Callable, Any, Dict, Literal, Union
|
||||||
|
|
||||||
@@ -133,12 +142,63 @@ class AudioReader:
|
|||||||
def _is_ip_discovery_packet(self, data: bytes) -> bool:
|
def _is_ip_discovery_packet(self, data: bytes) -> bool:
|
||||||
return len(data) == 74 and data[1] == 0x02
|
return len(data) == 74 and data[1] == 0x02
|
||||||
|
|
||||||
|
def _maybe_dave_decrypt(self, rtp_packet) -> Optional[bytes]:
    """Apply the DAVE E2E layer after transport decryption.

    Returns the payload to feed onward, or ``None`` when the packet must be
    dropped:

    * DAVE inactive (``davey`` not importable, protocol version 0, or no
      ready session): ``rtp_packet.decrypted_data`` is returned unchanged.
    * SSRC not yet mapped to a user id: ``None``.
    * The session reports passthrough for the user:
      ``rtp_packet.decrypted_data`` is returned unchanged.
    * Otherwise the result of ``dave.decrypt``; ``None`` if it raises.

    Silence detection is deliberately not done here: in a DAVE-active room
    the transport-decrypted payload is still ciphertext, so it cannot be
    compared against the plaintext Opus silence frame. The caller overwrites
    ``decrypted_data`` with this method's result, which keeps the existing
    post-decrypt silence check working.
    """
    if _HAS_DAVE:
        connection = self.voice_client._connection
        if getattr(connection, 'dave_protocol_version', 0) != 0:
            session = getattr(connection, 'dave_session', None)
        else:
            session = None
        dave_active = session is not None and session.ready
    else:
        session = None
        dave_active = False

    if not dave_active:
        return rtp_packet.decrypted_data

    speaker_id = self.voice_client._ssrc_to_id.get(rtp_packet.ssrc)
    if speaker_id is None:
        # ACCEPTED REGRESSION: davey needs a per-user key. When the SPEAKING
        # event arrives after the first audio packet, 1-5 packets
        # (~40-200ms) are dropped per new speaker per session.
        return None

    # A user whose decryptor is in passthrough mode sends packets that are
    # not DAVE-wrapped, so they are returned as-is. A failing passthrough
    # check is only logged; the packet then goes through decrypt below.
    try:
        if session.can_passthrough(speaker_id):
            return rtp_packet.decrypted_data
    except Exception as exc:
        log.debug("can_passthrough check failed for ssrc=%s user=%s: %s: %s",
                  rtp_packet.ssrc, speaker_id, type(exc).__name__, exc)

    try:
        return session.decrypt(speaker_id, _MEDIA_TYPE_AUDIO, rtp_packet.decrypted_data)
    except Exception as exc:
        log.debug("DAVE decrypt failed for ssrc=%s user=%s: %s: %s",
                  rtp_packet.ssrc, speaker_id, type(exc).__name__, exc)
        return None
|
||||||
|
|
||||||
def callback(self, packet_data: bytes) -> None:
|
def callback(self, packet_data: bytes) -> None:
|
||||||
packet = rtp_packet = rtcp_packet = None
|
packet = rtp_packet = rtcp_packet = None
|
||||||
try:
|
try:
|
||||||
if not rtp.is_rtcp(packet_data):
|
if not rtp.is_rtcp(packet_data):
|
||||||
packet = rtp_packet = rtp.decode_rtp(packet_data)
|
packet = rtp_packet = rtp.decode_rtp(packet_data)
|
||||||
packet.decrypted_data = self.decryptor.decrypt_rtp(packet)
|
packet.decrypted_data = self.decryptor.decrypt_rtp(packet)
|
||||||
|
# Echo Core +echo.dave1: DAVE E2E layer (no-op when inactive).
|
||||||
|
dave_payload = self._maybe_dave_decrypt(rtp_packet)
|
||||||
|
if dave_payload is None:
|
||||||
|
return # drop packet, do not feed_rtp; reader thread stays alive
|
||||||
|
rtp_packet.decrypted_data = dave_payload
|
||||||
else:
|
else:
|
||||||
packet = rtcp_packet = rtp.decode_rtcp(self.decryptor.decrypt_rtcp(packet_data))
|
packet = rtcp_packet = rtp.decode_rtcp(self.decryptor.decrypt_rtcp(packet_data))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user