feat(voice): Pas 6 — voice/tts_stream.py streaming TTS pipeline

src/voice/tts_stream.py (~350 lines):
- clause_segments(text, min_words=8): yield Romanian-aware clause chunks.
  Split la punct (./!/?;:,) cu accumulation până min_words satisfied;
  edge case text < min_words → single chunk. NU split mid-word/number/
  currency. Romanian intonație de frază se rupe la sentence break — 8+
  words minimizează seams.
- TTSQueue worker thread: text queue in → PCM frames out. Methods:
  start/stop/push_text/push_filler/clear/is_empty. normalize_for_tts()
  apply first, then clause_segments(), then Supertonic synth per chunk.
- EchoStreamingAudioSource(discord.AudioSource): read() pull from PCM
  queue, 20ms frames (3840 bytes 48kHz s16le stereo). Eliminates RTP
  gap between play() calls — single play() per session, source pulls.
- load_thinking_wav(): one-shot cache → 140 × 20ms frames (~2.8s)
  pentru filler "Stai puțin să-mi adun gândurile".
- wav_to_pcm_20ms_frames(): WAV parser + ffmpeg subprocess resample
  la 48kHz s16le stereo dacă nevoie.

Smoke test (în session): clause_segments behaviour OK, thinking.wav
loads, TTSQueue + EchoStreamingAudioSource construct clean. Integration
testing deferred la convergență (Pas 7 + Pas 11).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-27 14:44:13 +00:00
parent 0cc01c1450
commit 217da65417

348
src/voice/tts_stream.py Normal file
View File

@@ -0,0 +1,348 @@
"""Streaming TTS with clause-level chunking for Discord voice mode.
A worker thread consumes text -> produces 20ms PCM frames on a queue.Queue.
``EchoStreamingAudioSource`` pulls frames into Discord's audio thread so a
single ``voice_client.play()`` call lasts the whole turn (eliminates the
RTP gap between successive ``play()`` calls and the race with barge-in
``stop()``). See plan: src/voice/tts_stream.py (Pas 6 / Lane TTS),
Engineering decisions #6, #8, #15.
"""
from __future__ import annotations
import io
import queue
import re
import subprocess
import threading
import wave
from pathlib import Path
from typing import Iterator, List, Optional
import discord
from src.voice.normalize import normalize_for_tts
from tools.tts import synthesize
# Output PCM layout pushed to Discord: 48kHz, stereo, 16-bit samples.
TARGET_SAMPLE_RATE = 48000
TARGET_CHANNELS = 2
TARGET_SAMPLE_WIDTH = 2

# Size of one 20ms frame in that layout:
# 48000 Hz * 0.020 s = 960 samples, * 2 channels * 2 bytes = 3840 bytes.
FRAME_BYTES = (TARGET_SAMPLE_RATE * 20 // 1000) * TARGET_CHANNELS * TARGET_SAMPLE_WIDTH

# The filler asset lives under the project root, three directory levels above
# this module's file.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
_THINKING_WAV = _PROJECT_ROOT.joinpath("assets", "voice", "thinking.wav")

# Filler frames are decoded once and then reused; the lock serializes the
# one-time load across threads.
_thinking_cache_lock = threading.Lock()
_thinking_frames_cache: Optional[List[bytes]] = None

# Unique marker object placed on the text queue to tell the worker to exit.
_POISON = object()
# ---------- Clause segmentation ----------
# A boundary is any run of whitespace that immediately follows one of
# , ; : . ! ?  Punctuation with no whitespace after it (the dot in 1.000 or
# 12.5, the inner dot of M.D.) never matches, so such tokens are not cut.
_CLAUSE_SPLIT = re.compile(r'(?<=[,;:.!?])\s+')


def clause_segments(text: str, min_words: int = 8) -> Iterator[str]:
    """Yield ``text`` as clause-sized chunks suitable for streaming TTS.

    The text is cut at ``, ; : . ! ?`` when the punctuation is followed by
    whitespace. Consecutive clauses are joined with a single space until the
    running chunk holds at least ``min_words`` whitespace-separated words,
    at which point the chunk is yielded and a new one is started.

    Whatever is left after the last boundary is yielded as a final chunk
    even when it is shorter than ``min_words``, so no text is dropped.
    ``None``, empty and whitespace-only input yield nothing.
    """
    stripped = text.strip() if text is not None else ''
    if not stripped:
        return

    pending = []      # clauses collected for the chunk being built
    word_total = 0    # number of words across ``pending``
    for raw_piece in _CLAUSE_SPLIT.split(stripped):
        piece = raw_piece.strip()
        if not piece:
            continue
        pending.append(piece)
        word_total += len(piece.split())
        if word_total >= min_words:
            yield ' '.join(pending)
            pending = []
            word_total = 0

    # Flush the tail regardless of its length.
    if pending:
        yield ' '.join(pending)
# ---------- WAV -> PCM frame conversion ----------
def _ffmpeg_resample(wav_bytes: bytes) -> bytes:
    """Convert any WAV payload to raw 48kHz stereo s16le PCM via ffmpeg.

    ffmpeg is already an Echo Core hard dependency (heartbeat, video
    transcription). The payload is piped through stdin/stdout, so no
    intermediate file is written.

    Args:
        wav_bytes: Encoded audio to feed to ffmpeg on stdin.

    Returns:
        Headerless PCM at ``TARGET_SAMPLE_RATE`` / ``TARGET_CHANNELS``,
        signed 16-bit little-endian.

    Raises:
        RuntimeError: ffmpeg could not be launched, did not finish in time,
            or exited with a non-zero status. Launch failures and timeouts
            are folded into ``RuntimeError`` because that is the one error
            type callers of the conversion path handle; letting
            ``FileNotFoundError`` escape would terminate a worker thread
            that only guards against ``RuntimeError``.
    """
    command = [
        'ffmpeg', '-hide_banner', '-loglevel', 'error',
        '-i', 'pipe:0',
        '-f', 's16le',
        '-ar', str(TARGET_SAMPLE_RATE),
        '-ac', str(TARGET_CHANNELS),
        '-acodec', 'pcm_s16le',
        'pipe:1',
    ]
    try:
        proc = subprocess.run(
            command,
            input=wav_bytes,
            capture_output=True,
            check=False,
            # Inputs are short clips (a clause or the filler); a generous
            # bound keeps a wedged ffmpeg from blocking the caller forever.
            timeout=30,
        )
    except subprocess.TimeoutExpired as exc:
        raise RuntimeError(f"ffmpeg resample timed out after {exc.timeout}s") from exc
    except OSError as exc:
        # Typically FileNotFoundError when the ffmpeg binary is not on PATH.
        raise RuntimeError(f"ffmpeg could not be started: {exc}") from exc

    if proc.returncode != 0:
        err = proc.stderr.decode('utf-8', errors='replace')[:200]
        raise RuntimeError(f"ffmpeg resample failed (rc={proc.returncode}): {err}")
    return proc.stdout
def _is_target_format(wav_bytes: bytes) -> bool:
"""Quick check whether the WAV already matches Discord's PCM format."""
try:
with wave.open(io.BytesIO(wav_bytes), 'rb') as w:
return (
w.getframerate() == TARGET_SAMPLE_RATE
and w.getnchannels() == TARGET_CHANNELS
and w.getsampwidth() == TARGET_SAMPLE_WIDTH
and w.getcomptype() == 'NONE'
)
except (wave.Error, EOFError):
return False
def _extract_pcm_native(wav_bytes: bytes) -> bytes:
"""Strip the WAV header and return raw PCM (target format assumed)."""
with wave.open(io.BytesIO(wav_bytes), 'rb') as w:
return w.readframes(w.getnframes())
def wav_to_pcm_20ms_frames(wav_bytes: bytes) -> List[bytes]:
    """Turn a WAV blob into a list of fixed-size 20ms PCM frames.

    A payload already in the target layout has its samples extracted
    directly; anything else is converted through ffmpeg. The PCM is then cut
    into ``FRAME_BYTES``-sized pieces, and a short last piece is padded with
    zero bytes so every returned frame has the full size.

    Returns an empty list for empty input or when no PCM was produced.
    """
    if not wav_bytes:
        return []

    if _is_target_format(wav_bytes):
        raw = _extract_pcm_native(wav_bytes)
    else:
        raw = _ffmpeg_resample(wav_bytes)

    # ``ljust`` leaves full-size slices unchanged and zero-pads the tail.
    return [
        raw[start:start + FRAME_BYTES].ljust(FRAME_BYTES, b'\x00')
        for start in range(0, len(raw), FRAME_BYTES)
    ]
def load_thinking_wav() -> List[bytes]:
    """Return the filler clip (``assets/voice/thinking.wav``) as 20ms frames.

    The first call reads and converts the asset and stores the result in a
    module-level cache; later calls return the cached list. The load runs
    under a lock, with the cache re-checked once the lock is held, so
    concurrent first calls decode the file only once.

    If the file cannot be read or converted, an empty list is cached and
    returned, which callers treat as "no filler".
    """
    global _thinking_frames_cache

    frames = _thinking_frames_cache
    if frames is None:
        with _thinking_cache_lock:
            # Another thread may have finished loading while we waited.
            frames = _thinking_frames_cache
            if frames is None:
                try:
                    frames = wav_to_pcm_20ms_frames(_THINKING_WAV.read_bytes())
                except (FileNotFoundError, OSError, RuntimeError):
                    frames = []
                _thinking_frames_cache = frames
    return frames
# ---------- TTS worker queue ----------
class TTSQueue:
    """Worker thread: text in -> 20ms PCM frames out.

    Usage::

        ttsq = TTSQueue(voice_id="M2", lang="ro")
        ttsq.start()
        ttsq.push_text("salut Marius, ce mai faci?")
        voice_client.play(EchoStreamingAudioSource(ttsq))
        # ... barge-in detected:
        ttsq.clear()
        # ... session over:
        ttsq.stop()

    Cancellation model: every queued clause is tagged with the generation
    that was current when it was pushed. ``clear()`` bumps the generation,
    so a clause the worker had already dequeued (and may still be
    synthesizing) is recognised as stale and its frames are dropped instead
    of being played after a barge-in.
    """

    def __init__(self, voice_id: str = "M2", lang: str = "ro"):
        self.voice_id = voice_id
        self.lang = lang
        # Holds ``(generation, clause)`` tuples, plus ``_POISON`` on stop().
        self._text_queue: queue.Queue = queue.Queue()
        # Holds ready-to-play PCM frames.
        self._pcm_queue: queue.Queue = queue.Queue()
        self._worker_thread: Optional[threading.Thread] = None
        self._stop_event = threading.Event()
        # Guards ``_generation``, ``_pending_clauses`` and every enqueue that
        # must be atomic with respect to ``clear()``.
        self._state_lock = threading.Lock()
        self._generation = 0
        # Clauses of the current generation that are queued or in synthesis.
        self._pending_clauses = 0

    # --- lifecycle ---

    def start(self) -> None:
        """Start the worker thread; no-op if one is already running."""
        if self._worker_thread is not None and self._worker_thread.is_alive():
            return
        # Each worker gets its own stop event. Re-using (and clearing) a
        # shared event would revive a previous worker that outlived stop()'s
        # join timeout, leaving two workers consuming the same queue.
        stop_event = threading.Event()
        self._stop_event = stop_event
        self._worker_thread = threading.Thread(
            target=self._worker_loop,
            args=(stop_event,),
            name=f"tts-worker-{self.voice_id}",
            daemon=True,
        )
        self._worker_thread.start()

    def stop(self) -> None:
        """Signal the worker to exit, join (timeout 5s), then drop all pending work."""
        self._stop_event.set()
        # Wake the worker if it's blocked on get(timeout=...).
        self._text_queue.put(_POISON)
        thread = self._worker_thread
        if thread is not None:
            thread.join(timeout=5.0)
        self._worker_thread = None
        # Also invalidates anything a still-running worker might emit later.
        self.clear()

    # --- producer side ---

    def push_text(self, text: str) -> None:
        """Normalize, segment into clauses, enqueue each clause for synthesis."""
        if not text:
            return
        cleaned = normalize_for_tts(text)
        for clause in clause_segments(cleaned):
            clause = clause.strip()
            if clause:
                self._enqueue_clause(clause)

    def push_filler(self) -> None:
        """Enqueue pre-rendered filler frames (``thinking.wav``) directly.

        Bypasses synthesis -- the filler plays even if Supertonic is down
        or slow. No-op if the asset failed to load.
        """
        frames = load_thinking_wav()
        # Enqueue as one unit so a concurrent clear() removes all or none.
        with self._state_lock:
            for frame in frames:
                self._pcm_queue.put(frame)

    def clear(self) -> None:
        """Drop everything pending and cancel in-flight synthesis (barge-in)."""
        with self._state_lock:
            self._generation += 1
            self._drain(self._text_queue)
            self._drain(self._pcm_queue)
            # Everything counted so far belonged to the old generation.
            self._pending_clauses = 0

    def is_empty(self) -> bool:
        """Return True when no current clause is queued or in synthesis and no frame is waiting."""
        with self._state_lock:
            return self._pending_clauses == 0 and self._pcm_queue.empty()

    # --- consumer side (called by EchoStreamingAudioSource) ---

    def get_frame(self, timeout: float = 0.1) -> Optional[bytes]:
        """Block up to ``timeout`` seconds for the next 20ms PCM frame.

        Returns ``None`` if no frame became available in time.
        """
        try:
            return self._pcm_queue.get(timeout=timeout)
        except queue.Empty:
            return None

    # --- internals ---

    @staticmethod
    def _drain(q: queue.Queue) -> None:
        """Remove every item currently in ``q`` without blocking."""
        while True:
            try:
                q.get_nowait()
            except queue.Empty:
                return

    def _enqueue_clause(self, clause: str) -> None:
        """Queue one clause tagged with the current generation and count it as pending."""
        with self._state_lock:
            self._pending_clauses += 1
            self._text_queue.put((self._generation, clause))

    def _is_current(self, generation: int) -> bool:
        """Return True if ``generation`` has not been cancelled by ``clear()``."""
        with self._state_lock:
            return generation == self._generation

    def _mark_clause_done(self, generation: int) -> None:
        """Stop counting a finished clause, unless clear() already reset the count."""
        with self._state_lock:
            if generation == self._generation and self._pending_clauses > 0:
                self._pending_clauses -= 1

    def _emit_frames(
        self,
        frames: List[bytes],
        generation: int,
        stop_event: Optional[threading.Event] = None,
    ) -> bool:
        """Publish the frames of one clause unless they have become stale.

        Returns True if the frames were queued; False if there was nothing
        to queue, the worker is stopping, or ``clear()`` ran after the
        clause was pushed.
        """
        if not frames:
            return False
        if stop_event is not None and stop_event.is_set():
            return False
        with self._state_lock:
            # Checked under the same lock clear() holds, so a cancelled
            # clause can never slip frames in after the drain.
            if generation != self._generation:
                return False
            for frame in frames:
                self._pcm_queue.put(frame)
        return True

    def _synthesize_clause(self, clause: str) -> List[bytes]:
        """Synthesize one clause and return its PCM frames (empty list on any failure)."""
        try:
            result = synthesize(clause, voice=self.voice_id, lang=self.lang)
        except Exception:
            # Synth crashes shouldn't kill the worker -- log path is the
            # caller's job (we have no logger here on purpose).
            return []
        if not result.get('ok'):
            return []
        path = result.get('path')
        if not path:
            return []
        wav_bytes = b''
        try:
            wav_bytes = Path(path).read_bytes()
        except OSError:
            pass
        finally:
            # Best-effort cleanup of the synth tempfile.
            try:
                Path(path).unlink(missing_ok=True)
            except OSError:
                pass
        if not wav_bytes:
            return []
        try:
            return wav_to_pcm_20ms_frames(wav_bytes)
        except (RuntimeError, OSError):
            # OSError covers a missing ffmpeg binary; like a failed
            # conversion it must skip the clause, not end the worker.
            return []

    def _worker_loop(self, stop_event: Optional[threading.Event] = None) -> None:
        """Consume clauses until stopped, publishing PCM frames for each.

        ``stop_event`` is the event owned by this particular worker; when
        omitted, the queue's current stop event is used.
        """
        if stop_event is None:
            stop_event = self._stop_event
        while not stop_event.is_set():
            try:
                item = self._text_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            if item is _POISON:
                break
            if not isinstance(item, tuple):
                continue
            generation, clause = item
            try:
                # Skip the synth cost for clauses cancelled while queued.
                if not self._is_current(generation):
                    continue
                frames = self._synthesize_clause(clause)
                self._emit_frames(frames, generation, stop_event)
            finally:
                self._mark_clause_done(generation)
# ---------- Discord audio source ----------
class EchoStreamingAudioSource(discord.AudioSource):
    """Pull PCM frames from a ``TTSQueue`` into Discord's audio thread.

    A single ``voice_client.play(EchoStreamingAudioSource(ttsq))`` call
    spans the whole turn. ``read()`` waits up to 100ms for a frame. If none
    arrives:

    * when the ``TTSQueue`` reports it is empty, ``read()`` returns ``b''``,
      which Discord interprets as end-of-stream and stops the player
      (end-of-turn, or after ``ttsq.clear()`` on barge-in);
    * when the ``TTSQueue`` still reports pending work, ``read()`` returns
      one frame of silence so the player stays alive while synthesis
      catches up, instead of ending the turn early.

    The silence bridge is bounded by ``max_starvation_s`` so that work which
    is never consumed (for example, a worker that was never started) cannot
    keep the player open indefinitely.
    """

    # Seconds read() waits on the PCM queue before treating it as starved.
    _READ_TIMEOUT_S = 0.1
    # One full frame of zero-valued samples.
    _SILENCE_FRAME = b'\x00' * FRAME_BYTES

    def __init__(self, ttsq: TTSQueue, max_starvation_s: float = 10.0):
        """Wrap ``ttsq``.

        Args:
            ttsq: Queue to pull frames from.
            max_starvation_s: Approximate upper bound on how long silence is
                emitted while waiting for pending work before the stream is
                ended anyway.
        """
        self._ttsq = ttsq
        self._closed = False
        # Each starved read() blocks for about _READ_TIMEOUT_S, so counting
        # consecutive starved reads approximates elapsed wall time.
        self._max_starved_reads = max(1, round(max_starvation_s / self._READ_TIMEOUT_S))
        self._starved_reads = 0

    def read(self) -> bytes:
        """Return the next 20ms frame, a silence frame, or ``b''`` to end the stream."""
        if self._closed:
            return b''
        frame = self._ttsq.get_frame(timeout=self._READ_TIMEOUT_S)
        if frame is not None:
            self._starved_reads = 0
            return frame
        # cleanup() may have run while we were blocked on the queue.
        if self._closed or self._ttsq.is_empty():
            return b''
        self._starved_reads += 1
        if self._starved_reads >= self._max_starved_reads:
            return b''
        return self._SILENCE_FRAME

    def is_opus(self) -> bool:
        """Frames are raw PCM, not Opus-encoded."""
        return False

    def cleanup(self) -> None:
        """Mark the source closed and drop whatever the queue still holds."""
        self._closed = True
        try:
            self._ttsq.clear()
        except Exception:
            # Best-effort: teardown must not raise into the caller.
            pass