feat(voice): Pas 6 — voice/tts_stream.py streaming TTS pipeline
src/voice/tts_stream.py (~280 lines): - clause_segments(text, min_words=8): yield Romanian-aware clause chunks. Split la punct (./!/?;:,) cu accumulation până min_words satisfied; edge case text < min_words → single chunk. NU split mid-word/number/ currency. Romanian intonație de frază se rupe la sentence break — 8+ words minimizează seams. - TTSQueue worker thread: text queue in → PCM frames out. Methods: start/stop/push_text/push_filler/clear/is_empty. normalize_for_tts() apply first, then clause_segments(), then Supertonic synth per chunk. - EchoStreamingAudioSource(discord.AudioSource): read() pull from PCM queue, 20ms frames (3840 bytes 48kHz s16le stereo). Eliminates RTP gap between play() calls — single play() per session, source pulls. - load_thinking_wav(): one-shot cache → 140 × 20ms frames (~2.8s) pentru filler "Stai puțin să-mi adun gândurile". - wav_to_pcm_20ms_frames(): WAV parser + ffmpeg subprocess resample la 48kHz s16le stereo dacă nevoie. Smoke test (în session): clause_segments behaviour OK, thinking.wav loads, TTSQueue + EchoStreamingAudioSource construct clean. Integration testing deferred la convergență (Pas 7 + Pas 11). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
348
src/voice/tts_stream.py
Normal file
348
src/voice/tts_stream.py
Normal file
@@ -0,0 +1,348 @@
|
||||
"""Streaming TTS with clause-level chunking for Discord voice mode.
|
||||
|
||||
A worker thread consumes text -> produces 20ms PCM frames on a queue.Queue.
|
||||
``EchoStreamingAudioSource`` pulls frames into Discord's audio thread so a
|
||||
single ``voice_client.play()`` call lasts the whole turn (eliminates the
|
||||
RTP gap between successive ``play()`` calls and the race with barge-in
|
||||
``stop()``). See plan: src/voice/tts_stream.py (Pas 6 / Lane TTS),
|
||||
Engineering decisions #6, #8, #15.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import queue
|
||||
import re
|
||||
import subprocess
|
||||
import threading
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from typing import Iterator, List, Optional
|
||||
|
||||
import discord
|
||||
|
||||
from src.voice.normalize import normalize_for_tts
|
||||
from tools.tts import synthesize
|
||||
|
||||
|
||||
# Discord wants 20ms of 16-bit 48kHz stereo PCM per frame.
|
||||
# 48000 Hz * 0.020 s * 2 channels * 2 bytes = 3840 bytes.
|
||||
FRAME_BYTES = 3840
|
||||
TARGET_SAMPLE_RATE = 48000
|
||||
TARGET_CHANNELS = 2
|
||||
TARGET_SAMPLE_WIDTH = 2
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
_THINKING_WAV = _PROJECT_ROOT / "assets" / "voice" / "thinking.wav"
|
||||
|
||||
# Cached filler frames (load + resample once, reuse forever).
|
||||
_thinking_frames_cache: Optional[List[bytes]] = None
|
||||
_thinking_cache_lock = threading.Lock()
|
||||
|
||||
# Sentinel pushed onto the text queue to ask the worker to exit cleanly.
|
||||
_POISON = object()
|
||||
|
||||
|
||||
# ---------- Clause segmentation ----------
|
||||
|
||||
# Split at Romanian sentence punctuation followed by whitespace. The
|
||||
# trailing whitespace requirement protects mid-number (1.000), mid-decimal
|
||||
# (12.5), and mid-abbreviation (M.D.) tokens, since none of those have a
|
||||
# space right after the inner punctuation.
|
||||
_CLAUSE_SPLIT = re.compile(r'(?<=[,;:.!?])\s+')
|
||||
|
||||
|
||||
def clause_segments(text: str, min_words: int = 8) -> Iterator[str]:
|
||||
"""Yield text in clause-sized chunks for streaming TTS.
|
||||
|
||||
Splits at ``, ; : . ! ?`` boundaries (only when the punctuation is
|
||||
followed by whitespace, so numbers / decimals / abbreviations stay
|
||||
intact). Short clauses are buffered and merged with the next one
|
||||
until the accumulated chunk has at least ``min_words`` words. The
|
||||
final remainder is always yielded, even if it's shorter than
|
||||
``min_words`` -- otherwise the tail of the response would never
|
||||
reach the TTS.
|
||||
"""
|
||||
if text is None:
|
||||
return
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return
|
||||
pieces = [p.strip() for p in _CLAUSE_SPLIT.split(text) if p and p.strip()]
|
||||
if not pieces:
|
||||
return
|
||||
buffer = ''
|
||||
for clause in pieces:
|
||||
buffer = (buffer + ' ' + clause).strip() if buffer else clause
|
||||
if len(buffer.split()) >= min_words:
|
||||
yield buffer
|
||||
buffer = ''
|
||||
if buffer:
|
||||
yield buffer
|
||||
|
||||
|
||||
# ---------- WAV -> PCM frame conversion ----------
|
||||
|
||||
def _ffmpeg_resample(wav_bytes: bytes) -> bytes:
|
||||
"""Convert any WAV payload to raw 48kHz stereo s16le PCM via ffmpeg.
|
||||
|
||||
ffmpeg is already an Echo Core hard dependency (heartbeat, video
|
||||
transcription). Using a stdin/stdout pipe keeps the synth tempfile
|
||||
short-lived and avoids extra disk traffic.
|
||||
"""
|
||||
proc = subprocess.run(
|
||||
[
|
||||
'ffmpeg', '-hide_banner', '-loglevel', 'error',
|
||||
'-i', 'pipe:0',
|
||||
'-f', 's16le',
|
||||
'-ar', str(TARGET_SAMPLE_RATE),
|
||||
'-ac', str(TARGET_CHANNELS),
|
||||
'-acodec', 'pcm_s16le',
|
||||
'pipe:1',
|
||||
],
|
||||
input=wav_bytes,
|
||||
capture_output=True,
|
||||
check=False,
|
||||
)
|
||||
if proc.returncode != 0:
|
||||
err = proc.stderr.decode('utf-8', errors='replace')[:200]
|
||||
raise RuntimeError(f"ffmpeg resample failed (rc={proc.returncode}): {err}")
|
||||
return proc.stdout
|
||||
|
||||
|
||||
def _is_target_format(wav_bytes: bytes) -> bool:
|
||||
"""Quick check whether the WAV already matches Discord's PCM format."""
|
||||
try:
|
||||
with wave.open(io.BytesIO(wav_bytes), 'rb') as w:
|
||||
return (
|
||||
w.getframerate() == TARGET_SAMPLE_RATE
|
||||
and w.getnchannels() == TARGET_CHANNELS
|
||||
and w.getsampwidth() == TARGET_SAMPLE_WIDTH
|
||||
and w.getcomptype() == 'NONE'
|
||||
)
|
||||
except (wave.Error, EOFError):
|
||||
return False
|
||||
|
||||
|
||||
def _extract_pcm_native(wav_bytes: bytes) -> bytes:
|
||||
"""Strip the WAV header and return raw PCM (target format assumed)."""
|
||||
with wave.open(io.BytesIO(wav_bytes), 'rb') as w:
|
||||
return w.readframes(w.getnframes())
|
||||
|
||||
|
||||
def wav_to_pcm_20ms_frames(wav_bytes: bytes) -> List[bytes]:
|
||||
"""Parse a WAV blob, normalize to 48kHz s16le stereo, slice into 20ms frames.
|
||||
|
||||
The final frame is zero-padded to a full 3840 bytes so Discord's audio
|
||||
thread always reads whole frames. Empty input yields an empty list.
|
||||
"""
|
||||
if not wav_bytes:
|
||||
return []
|
||||
pcm = _extract_pcm_native(wav_bytes) if _is_target_format(wav_bytes) else _ffmpeg_resample(wav_bytes)
|
||||
if not pcm:
|
||||
return []
|
||||
frames: List[bytes] = []
|
||||
for offset in range(0, len(pcm), FRAME_BYTES):
|
||||
chunk = pcm[offset:offset + FRAME_BYTES]
|
||||
if len(chunk) < FRAME_BYTES:
|
||||
chunk = chunk + b'\x00' * (FRAME_BYTES - len(chunk))
|
||||
frames.append(chunk)
|
||||
return frames
|
||||
|
||||
|
||||
def load_thinking_wav() -> List[bytes]:
|
||||
"""Load ``assets/voice/thinking.wav`` and cache it as 20ms PCM frames.
|
||||
|
||||
Safe to call from multiple threads; the load happens at most once.
|
||||
Returns an empty list if the asset is missing or fails to decode
|
||||
(callers treat that as a no-op filler).
|
||||
"""
|
||||
global _thinking_frames_cache
|
||||
if _thinking_frames_cache is not None:
|
||||
return _thinking_frames_cache
|
||||
with _thinking_cache_lock:
|
||||
if _thinking_frames_cache is not None:
|
||||
return _thinking_frames_cache
|
||||
try:
|
||||
wav_bytes = _THINKING_WAV.read_bytes()
|
||||
_thinking_frames_cache = wav_to_pcm_20ms_frames(wav_bytes)
|
||||
except (FileNotFoundError, OSError, RuntimeError):
|
||||
_thinking_frames_cache = []
|
||||
return _thinking_frames_cache
|
||||
|
||||
|
||||
# ---------- TTS worker queue ----------
|
||||
|
||||
class TTSQueue:
|
||||
"""Worker thread: text in -> 20ms PCM frames out.
|
||||
|
||||
Usage::
|
||||
|
||||
ttsq = TTSQueue(voice_id="M2", lang="ro")
|
||||
ttsq.start()
|
||||
ttsq.push_text("salut Marius, ce mai faci?")
|
||||
voice_client.play(EchoStreamingAudioSource(ttsq))
|
||||
# ... barge-in detected:
|
||||
ttsq.clear()
|
||||
# ... session over:
|
||||
ttsq.stop()
|
||||
"""
|
||||
|
||||
def __init__(self, voice_id: str = "M2", lang: str = "ro"):
|
||||
self.voice_id = voice_id
|
||||
self.lang = lang
|
||||
self._text_queue: queue.Queue = queue.Queue()
|
||||
self._pcm_queue: queue.Queue = queue.Queue()
|
||||
self._worker_thread: Optional[threading.Thread] = None
|
||||
self._stop_event = threading.Event()
|
||||
|
||||
# --- lifecycle ---
|
||||
|
||||
def start(self) -> None:
|
||||
if self._worker_thread is not None and self._worker_thread.is_alive():
|
||||
return
|
||||
self._stop_event.clear()
|
||||
self._worker_thread = threading.Thread(
|
||||
target=self._worker_loop,
|
||||
name=f"tts-worker-{self.voice_id}",
|
||||
daemon=True,
|
||||
)
|
||||
self._worker_thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal the worker to exit, drain queues, join (timeout 5s)."""
|
||||
self._stop_event.set()
|
||||
# Wake the worker if it's blocked on get(timeout=...).
|
||||
self._text_queue.put(_POISON)
|
||||
thread = self._worker_thread
|
||||
if thread is not None:
|
||||
thread.join(timeout=5.0)
|
||||
self._worker_thread = None
|
||||
self._drain(self._text_queue)
|
||||
self._drain(self._pcm_queue)
|
||||
|
||||
# --- producer side ---
|
||||
|
||||
def push_text(self, text: str) -> None:
|
||||
"""Normalize, segment into clauses, enqueue each clause for synthesis."""
|
||||
if not text:
|
||||
return
|
||||
cleaned = normalize_for_tts(text)
|
||||
for clause in clause_segments(cleaned):
|
||||
clause = clause.strip()
|
||||
if clause:
|
||||
self._text_queue.put(clause)
|
||||
|
||||
def push_filler(self) -> None:
|
||||
"""Enqueue pre-rendered filler frames (``thinking.wav``) directly.
|
||||
|
||||
Bypasses synthesis -- the filler plays even if Supertonic is down
|
||||
or slow. No-op if the asset failed to load.
|
||||
"""
|
||||
for frame in load_thinking_wav():
|
||||
self._pcm_queue.put(frame)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Drop everything pending (used for barge-in)."""
|
||||
self._drain(self._text_queue)
|
||||
self._drain(self._pcm_queue)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return self._text_queue.empty() and self._pcm_queue.empty()
|
||||
|
||||
# --- consumer side (called by EchoStreamingAudioSource) ---
|
||||
|
||||
def get_frame(self, timeout: float = 0.1) -> Optional[bytes]:
|
||||
"""Block up to ``timeout`` seconds for the next 20ms PCM frame."""
|
||||
try:
|
||||
return self._pcm_queue.get(timeout=timeout)
|
||||
except queue.Empty:
|
||||
return None
|
||||
|
||||
# --- internals ---
|
||||
|
||||
@staticmethod
|
||||
def _drain(q: queue.Queue) -> None:
|
||||
while True:
|
||||
try:
|
||||
q.get_nowait()
|
||||
except queue.Empty:
|
||||
return
|
||||
|
||||
def _worker_loop(self) -> None:
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
item = self._text_queue.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
continue
|
||||
if item is _POISON:
|
||||
break
|
||||
if not isinstance(item, str):
|
||||
continue
|
||||
try:
|
||||
result = synthesize(item, voice=self.voice_id, lang=self.lang)
|
||||
except Exception:
|
||||
# Synth crashes shouldn't kill the worker -- log path is the
|
||||
# caller's job (we have no logger here on purpose).
|
||||
continue
|
||||
if not result.get('ok'):
|
||||
continue
|
||||
path = result.get('path')
|
||||
if not path:
|
||||
continue
|
||||
wav_bytes = b''
|
||||
try:
|
||||
wav_bytes = Path(path).read_bytes()
|
||||
except OSError:
|
||||
pass
|
||||
finally:
|
||||
# Best-effort cleanup of the synth tempfile.
|
||||
try:
|
||||
Path(path).unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
if not wav_bytes:
|
||||
continue
|
||||
try:
|
||||
frames = wav_to_pcm_20ms_frames(wav_bytes)
|
||||
except RuntimeError:
|
||||
continue
|
||||
for frame in frames:
|
||||
if self._stop_event.is_set():
|
||||
return
|
||||
self._pcm_queue.put(frame)
|
||||
|
||||
|
||||
# ---------- Discord audio source ----------
|
||||
|
||||
class EchoStreamingAudioSource(discord.AudioSource):
|
||||
"""Pull PCM frames from a ``TTSQueue`` into Discord's audio thread.
|
||||
|
||||
A single ``voice_client.play(EchoStreamingAudioSource(ttsq))`` call
|
||||
spans the whole turn. The audio thread blocks on the PCM queue for
|
||||
up to 100ms per ``read()``; if it stays empty past that, ``read()``
|
||||
returns ``b''`` which Discord interprets as end-of-stream and stops
|
||||
the player (which is exactly what we want at end-of-turn or after
|
||||
``ttsq.clear()`` on barge-in).
|
||||
"""
|
||||
|
||||
def __init__(self, ttsq: TTSQueue):
|
||||
self._ttsq = ttsq
|
||||
self._closed = False
|
||||
|
||||
def read(self) -> bytes:
|
||||
if self._closed:
|
||||
return b''
|
||||
frame = self._ttsq.get_frame(timeout=0.1)
|
||||
if frame is None:
|
||||
return b''
|
||||
return frame
|
||||
|
||||
def is_opus(self) -> bool:
|
||||
return False
|
||||
|
||||
def cleanup(self) -> None:
|
||||
self._closed = True
|
||||
try:
|
||||
self._ttsq.clear()
|
||||
except Exception:
|
||||
pass
|
||||
Reference in New Issue
Block a user