#!/usr/bin/env python3
"""STT model spike — compare faster-whisper models on Romanian diacritic accuracy.

Answers the autoplan D1 question: does a Romanian-finetuned whisper-small beat
the generic `small` on diacritics WITHOUT regressing latency? Synthesizes clean
RO audio via Supertonic (ground-truth text known, with diacritics), runs each
model, and scores diacritic preservation + word error rate + latency.

This is an evaluation harness, not production code. Synthetic TTS audio is a
clean probe for diacritic behaviour specifically — it is NOT a proxy for
real-mic acoustic robustness, so weight latency + diacritics, not absolute WER.

Usage:
  python3 tools/voice_stt_spike.py --models small,/home/.../whisper-small-ro-cv11-int8
  python3 tools/voice_stt_spike.py --models small, --threads 4 --trials 2
"""
from __future__ import annotations

import argparse
import statistics
import sys
import tempfile
import time
import unicodedata
import wave
from pathlib import Path

# NOTE: third-party dependencies (httpx, faster_whisper) are imported lazily
# inside the functions that need them, so the pure scoring helpers below can be
# imported and exercised without the TTS client or the STT runtime installed.

SUPERTONIC_URL = "http://127.0.0.1:7788"

# Ground truth with correct Romanian diacritics (mirrors tools/voice_bench.py).
UTTERANCES_RO: list[tuple[str, str]] = [
    ("short", "Salut, ce mai faci?"),
    ("conversational", "Stai puțin să mă gândesc la asta."),
    ("medium", "Am verificat în calendar și avem ședință cu echipa la trei după-amiază."),
    ("numbers", "Costul total este o sută douăzeci și trei de lei și cincizeci de bani."),
    ("question", "Marius, vrei să-ți pun pe agenda de mâine să suni la NOAA?"),
    ("longer", "Vreau să-mi reamintești diseară să verific dacă scriptul de backup a rulat corect."),
]

_DIACRITICS = frozenset("ăâîșțĂÂÎȘȚ")

# Legacy s/t-with-cedilla (U+015F, U+0163, U+015E, U+0162) mapped to the
# s/t-with-comma-below letters (U+0219, U+021B, U+0218, U+021A) used by the
# ground truth. Without this, a transcript using the cedilla code points would
# be scored as a diacritic miss even though it renders the same letter.
_CEDILLA_TO_COMMA = str.maketrans(
    "\u015f\u0163\u015e\u0162",
    "\u0219\u021b\u0218\u021a",
)

# Explicit RO diacritic -> base letter table used by _deaccent().
_DEACCENT_TABLE = str.maketrans("ăâîșțĂÂÎȘȚ", "aaistAAIST")


def _normalize_ro(text: str) -> str:
    """Return *text* in NFC form with cedilla s/t folded to comma-below s/t."""
    # NFC first: a decomposed "s" + combining cedilla composes to U+015F, which
    # the translate step then folds to the comma-below letter.
    return unicodedata.normalize("NFC", text).translate(_CEDILLA_TO_COMMA)


def _strip_punct_lower(text: str) -> list[str]:
    """Tokenize *text* into lowercase words with punctuation removed.

    The text is Unicode-normalized first (see ``_normalize_ro``) so that
    decomposed or cedilla-encoded diacritics compare equal to the ground truth.
    Splitting is on whitespace; every non-alphanumeric character inside a token
    is dropped (so "după-amiază" becomes "dupăamiază"), and tokens that end up
    empty are skipped.
    """
    tokens = []
    for raw in _normalize_ro(text).split():
        word = "".join(c for c in raw if c.isalnum())
        if word:
            tokens.append(word.lower())
    return tokens


def _deaccent(s: str) -> str:
    """Return *s* with Romanian diacritics (and any other combining marks) removed."""
    # Map RO diacritics to base letters for "ignoring diacritics" comparison.
    s = s.translate(_DEACCENT_TABLE)
    # Decompose what is left and drop the combining marks (category Mn).
    return "".join(
        c for c in unicodedata.normalize("NFD", s) if unicodedata.category(c) != "Mn"
    )


def wer(ref: list[str], hyp: list[str]) -> float:
    """Word error rate via Levenshtein on token lists.

    Returns the token-level edit distance between *ref* and *hyp* divided by
    ``len(ref)``. For an empty reference the result is 0.0 when the hypothesis
    is also empty and 1.0 otherwise.
    """
    n, m = len(ref), len(hyp)
    if n == 0:
        return 0.0 if m == 0 else 1.0
    # Single rolling row of the edit-distance table: dp[j] holds the distance
    # between the reference prefix processed so far and hyp[:j].
    dp = list(range(m + 1))
    for i in range(1, n + 1):
        prev = dp[0]  # diagonal cell (previous row, column j - 1)
        dp[0] = i
        for j in range(1, m + 1):
            cur = dp[j]  # previous-row value, saved before it is overwritten
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            dp[j] = min(dp[j] + 1, dp[j - 1] + 1, prev + cost)
            prev = cur
    return dp[m] / n


def diacritic_score(ref: str, hyp: str) -> tuple[int, int]:
    """For each ground-truth word that carries a diacritic, did the hypothesis
    contain that exact diacritized word? Returns (correct, total).

    Matching is by set membership over the hypothesis tokens, so word position
    is not taken into account.
    """
    ref_words = _strip_punct_lower(ref)
    hyp_words = set(_strip_punct_lower(hyp))
    correct = total = 0
    for w in ref_words:
        if set(w) & _DIACRITICS:
            total += 1
            if w in hyp_words:
                correct += 1
    return correct, total


def synthesize(text: str, out_path: Path) -> float:
    """Synthesize *text* via the Supertonic HTTP endpoint into *out_path*.

    Posts to ``{SUPERTONIC_URL}/v1/audio/speech``, writes the response body to
    *out_path*, and returns the duration in seconds read back from the WAV
    header. Raises whatever ``httpx`` raises on transport failure or, via
    ``raise_for_status()``, on a non-success HTTP status.
    """
    import httpx  # lazy: only needed when actually talking to the TTS service

    r = httpx.post(
        f"{SUPERTONIC_URL}/v1/audio/speech",
        json={
            "model": "supertonic-3",
            "input": text,
            "voice": "M2",
            "response_format": "wav",
            "lang": "ro",
        },
        timeout=60.0,
    )
    r.raise_for_status()
    out_path.write_bytes(r.content)
    with wave.open(str(out_path), "rb") as wf:
        return wf.getnframes() / float(wf.getframerate())


def run_model(
    model_ref: str,
    wavs: list[tuple[str, str, str, float]],
    threads: int,
    trials: int,
) -> dict:
    """Transcribe every utterance with one model and collect scores.

    Args:
        model_ref: Model name or CT2 directory passed to ``WhisperModel``.
        wavs: ``(name, reference_text, wav_path, duration_s)`` tuples.
        threads: Value passed as ``cpu_threads``.
        trials: Transcriptions per utterance. Every trial contributes a latency
            sample; only the first trial's transcript is scored.

    Returns:
        A dict with keys ``model``, ``load_s``, ``rows``, ``p50``, ``p95``,
        ``dia_correct``, ``dia_total`` and ``wer`` (mean per-utterance WER).

    Raises:
        ValueError: if ``trials < 1`` or ``wavs`` is empty. Either would leave
            no latency samples to summarize, so this is checked before the
            model is loaded.
    """
    if trials < 1:
        raise ValueError(f"trials must be >= 1, got {trials}")
    if not wavs:
        raise ValueError("wavs must contain at least one utterance")

    from faster_whisper import WhisperModel

    t0 = time.perf_counter()
    model = WhisperModel(model_ref, device="cpu", compute_type="int8", cpu_threads=threads)
    load_s = time.perf_counter() - t0

    rows, lats = [], []
    dia_c = dia_t = 0
    wers = []
    for name, ref_text, wav_path, _dur in wavs:
        first_text = ""
        for trial in range(trials):
            t1 = time.perf_counter()
            segments, _ = model.transcribe(
                wav_path,
                language="ro",
                beam_size=5,
                temperature=0.0,
                condition_on_previous_text=False,
            )
            # The segments are consumed inside the timed region so the latency
            # sample covers producing the full transcript text.
            text = " ".join(s.text.strip() for s in segments).strip()
            lats.append(time.perf_counter() - t1)
            if trial == 0:
                first_text = text
        c, tt = diacritic_score(ref_text, first_text)
        dia_c += c
        dia_t += tt
        w = wer(_strip_punct_lower(ref_text), _strip_punct_lower(first_text))
        wers.append(w)
        rows.append((name, ref_text, first_text, c, tt, w))

    return {
        "model": model_ref,
        "load_s": load_s,
        "rows": rows,
        "p50": statistics.median(lats),
        "p95": sorted(lats)[max(0, int(0.95 * (len(lats) - 1)))],
        "dia_correct": dia_c,
        "dia_total": dia_t,
        "wer": statistics.mean(wers) if wers else 1.0,
    }


def _model_label(model_ref: str) -> str:
    """Return the display label for a model: the final path component for
    path-like references, the reference itself otherwise."""
    return Path(model_ref).name if "/" in model_ref else model_ref


def main(argv: list[str] | None = None) -> int:
    """CLI entry point: synthesize the utterances, run each model, print a report.

    Returns 0 on completion. Invalid arguments exit through ``argparse``.
    """
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    ap.add_argument("--models", required=True, help="CSV of model names or CT2 dirs")
    ap.add_argument("--threads", type=int, default=4)
    ap.add_argument("--trials", type=int, default=2)
    args = ap.parse_args(argv)

    # Reject unusable arguments before spending time on TTS and model loading.
    if args.trials < 1:
        ap.error("--trials must be >= 1")
    if args.threads < 1:
        ap.error("--threads must be >= 1")
    model_refs = [m.strip() for m in args.models.split(",") if m.strip()]
    if not model_refs:
        ap.error("--models did not contain any model reference")

    # The directory is not removed by this script; its path is printed below.
    work = Path(tempfile.mkdtemp(prefix="stt_spike_"))
    print(f"[spike] synth dir {work}", flush=True)
    wavs = []
    for name, text in UTTERANCES_RO:
        p = work / f"{name}.wav"
        dur = synthesize(text, p)
        wavs.append((name, text, str(p), dur))
        print(f"[spike] TTS {name}: {dur:.2f}s", flush=True)

    results = []
    for ref in model_refs:
        print(f"[spike] running {ref} (threads={args.threads})…", flush=True)
        results.append(run_model(ref, wavs, args.threads, args.trials))

    # Summary table: one row per model.
    print("\n" + "=" * 72)
    print(f"{'model':<42} {'p50':>6} {'p95':>6} {'WER':>6} {'diacr':>8}")
    print("-" * 72)
    for r in results:
        dia = f"{r['dia_correct']}/{r['dia_total']}"
        label = _model_label(r["model"])
        print(f"{label:<42} {r['p50']:>5.2f}s {r['p95']:>5.2f}s {r['wer']*100:>5.1f}% {dia:>8}")
    print("=" * 72)

    # Per-utterance detail: reference vs. hypothesis for each model.
    for r in results:
        label = _model_label(r["model"])
        print(f"\n### {label}")
        for name, ref_text, hyp, c, tt, w in r["rows"]:
            print(f" [{name}] ref: {ref_text}")
            print(f" [{name}] hyp: {hyp} (diacr {c}/{tt}, wer {w*100:.0f}%)")
    return 0


if __name__ == "__main__":
    sys.exit(main())