feat(voice): improve Romanian STT — hallucination gate + finetuned model

Gemma 4 cloud audio was infeasible (31b-cloud has no audio; E4B broken
upstream, no deploy host), so improve faster-whisper instead.

- Pin temperature=0.0 to disable the fallback ladder that re-decoded unclear
  audio up to 6x (source of the 16-24s latency outliers); reject hallucinated
  segments via avg_logprob/compression_ratio in the new pure _filter_segments.
- Adopt mikr/whisper-small-ro-cv11 (CT2 int8) via configurable voice.stt_model:
  spike showed WER 24%->10%, numbers fixed at source, +0.33s p50 (in budget).
- Add tools/voice_stt_mine.py (log mining) + tools/voice_stt_spike.py (model
  eval with diacritic scoring) + tests for the gate and miner.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
2026-06-27 18:16:16 +00:00
parent ec23d188ec
commit ce273d14db
9 changed files with 664 additions and 16 deletions

180
tools/voice_stt_spike.py Normal file
View File

@@ -0,0 +1,180 @@
#!/usr/bin/env python3
"""STT model spike — compare faster-whisper models on Romanian diacritic accuracy.
Answers the autoplan D1 question: does a Romanian-finetuned whisper-small beat the
generic `small` on diacritics WITHOUT regressing latency? Synthesizes clean RO
audio via Supertonic (ground-truth text known, with diacritics), runs each model,
and scores diacritic preservation + word error rate + latency.
This is an evaluation harness, not production code. Synthetic TTS audio is a clean
probe for diacritic behaviour specifically — it is NOT a proxy for real-mic acoustic
robustness, so weight latency + diacritics, not absolute WER.
Usage:
python3 tools/voice_stt_spike.py --models small,/home/.../whisper-small-ro-cv11-int8
python3 tools/voice_stt_spike.py --models small,<path> --threads 4 --trials 2
"""
from __future__ import annotations
import argparse
import statistics
import sys
import tempfile
import time
import unicodedata
from pathlib import Path
import httpx
SUPERTONIC_URL = "http://127.0.0.1:7788"
# Ground truth with correct Romanian diacritics (mirrors tools/voice_bench.py).
UTTERANCES_RO: list[tuple[str, str]] = [
("short", "Salut, ce mai faci?"),
("conversational", "Stai puțin să mă gândesc la asta."),
("medium", "Am verificat în calendar și avem ședință cu echipa la trei după-amiază."),
("numbers", "Costul total este o sută douăzeci și trei de lei și cincizeci de bani."),
("question", "Marius, vrei să-ți pun pe agenda de mâine să suni la NOAA?"),
("longer", "Vreau să-mi reamintești diseară să verific dacă scriptul de backup a rulat corect."),
]
_DIACRITICS = set("ăâîșțĂÂÎȘȚ")
def _strip_punct_lower(text: str) -> list[str]:
out = []
for raw in text.split():
w = "".join(c for c in raw if c.isalnum() or c in _DIACRITICS)
if w:
out.append(w.lower())
return out
def _deaccent(s: str) -> str:
# Map RO diacritics to base letters for "ignoring diacritics" comparison.
table = str.maketrans("ăâîșțĂÂÎȘȚ", "aaistAAIST")
s = s.translate(table)
return "".join(c for c in unicodedata.normalize("NFD", s)
if unicodedata.category(c) != "Mn")
def wer(ref: list[str], hyp: list[str]) -> float:
"""Word error rate via Levenshtein on token lists."""
n, m = len(ref), len(hyp)
if n == 0:
return 0.0 if m == 0 else 1.0
dp = list(range(m + 1))
for i in range(1, n + 1):
prev = dp[0]
dp[0] = i
for j in range(1, m + 1):
cur = dp[j]
cost = 0 if ref[i - 1] == hyp[j - 1] else 1
dp[j] = min(dp[j] + 1, dp[j - 1] + 1, prev + cost)
prev = cur
return dp[m] / n
def diacritic_score(ref: str, hyp: str) -> tuple[int, int]:
"""For each ground-truth word that carries a diacritic, did the hypothesis
contain that exact diacritized word? Returns (correct, total)."""
ref_words = _strip_punct_lower(ref)
hyp_words = set(_strip_punct_lower(hyp))
hyp_deaccent = {_deaccent(w) for w in hyp_words}
correct = total = 0
for w in ref_words:
if set(w) & _DIACRITICS:
total += 1
if w in hyp_words:
correct += 1
return correct, total
def synthesize(text: str, out_path: Path) -> float:
import wave
r = httpx.post(f"{SUPERTONIC_URL}/v1/audio/speech",
json={"model": "supertonic-3", "input": text, "voice": "M2",
"response_format": "wav", "lang": "ro"}, timeout=60.0)
r.raise_for_status()
out_path.write_bytes(r.content)
with wave.open(str(out_path), "rb") as wf:
return wf.getnframes() / float(wf.getframerate())
def run_model(model_ref: str, wavs: list[tuple[str, str, str, float]],
threads: int, trials: int) -> dict:
from faster_whisper import WhisperModel
t0 = time.perf_counter()
model = WhisperModel(model_ref, device="cpu", compute_type="int8", cpu_threads=threads)
load_s = time.perf_counter() - t0
rows, lats = [], []
dia_c = dia_t = 0
wers = []
for name, ref_text, wav_path, _dur in wavs:
best_text = ""
for trial in range(trials):
t1 = time.perf_counter()
segments, _ = model.transcribe(wav_path, language="ro", beam_size=5,
temperature=0.0, condition_on_previous_text=False)
text = " ".join(s.text.strip() for s in segments).strip()
lats.append(time.perf_counter() - t1)
if trial == 0:
best_text = text
c, tt = diacritic_score(ref_text, best_text)
dia_c += c
dia_t += tt
w = wer(_strip_punct_lower(ref_text), _strip_punct_lower(best_text))
wers.append(w)
rows.append((name, ref_text, best_text, c, tt, w))
return {
"model": model_ref, "load_s": load_s, "rows": rows,
"p50": statistics.median(lats), "p95": sorted(lats)[max(0, int(0.95*(len(lats)-1)))],
"dia_correct": dia_c, "dia_total": dia_t,
"wer": statistics.mean(wers) if wers else 1.0,
}
def main(argv=None) -> int:
ap = argparse.ArgumentParser(description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter)
ap.add_argument("--models", required=True, help="CSV of model names or CT2 dirs")
ap.add_argument("--threads", type=int, default=4)
ap.add_argument("--trials", type=int, default=2)
args = ap.parse_args(argv)
work = Path(tempfile.mkdtemp(prefix="stt_spike_"))
print(f"[spike] synth dir {work}", flush=True)
wavs = []
for name, text in UTTERANCES_RO:
p = work / f"{name}.wav"
dur = synthesize(text, p)
wavs.append((name, text, str(p), dur))
print(f"[spike] TTS {name}: {dur:.2f}s", flush=True)
results = []
for ref in args.models.split(","):
ref = ref.strip()
if not ref:
continue
print(f"[spike] running {ref} (threads={args.threads})…", flush=True)
results.append(run_model(ref, wavs, args.threads, args.trials))
print("\n" + "=" * 72)
print(f"{'model':<42} {'p50':>6} {'p95':>6} {'WER':>6} {'diacr':>8}")
print("-" * 72)
for r in results:
dia = f"{r['dia_correct']}/{r['dia_total']}"
label = Path(r["model"]).name if "/" in r["model"] else r["model"]
print(f"{label:<42} {r['p50']:>5.2f}s {r['p95']:>5.2f}s {r['wer']*100:>5.1f}% {dia:>8}")
print("=" * 72)
for r in results:
label = Path(r["model"]).name if "/" in r["model"] else r["model"]
print(f"\n### {label}")
for name, ref_text, hyp, c, tt, w in r["rows"]:
print(f" [{name}] ref: {ref_text}")
print(f" [{name}] hyp: {hyp} (diacr {c}/{tt}, wer {w*100:.0f}%)")
return 0
if __name__ == "__main__":
sys.exit(main())