feat(voice): improve Romanian STT — hallucination gate + finetuned model
Gemma 4 cloud audio was infeasible (31b-cloud has no audio; E4B broken upstream, no deploy host), so improve faster-whisper instead. - Pin temperature=0.0 to disable the fallback ladder that re-decoded unclear audio up to 6x (source of the 16-24s latency outliers); reject hallucinated segments via avg_logprob/compression_ratio in the new pure _filter_segments. - Adopt mikr/whisper-small-ro-cv11 (CT2 int8) via configurable voice.stt_model: spike showed WER 24%->10%, numbers fixed at source, +0.33s p50 (in budget). - Add tools/voice_stt_mine.py (log mining) + tools/voice_stt_spike.py (model eval with diacritic scoring) + tests for the gate and miner. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
166
tools/voice_stt_mine.py
Normal file
166
tools/voice_stt_mine.py
Normal file
@@ -0,0 +1,166 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Mine logs/voice_stt_log.jsonl for STT correction candidates.
|
||||
|
||||
Read-only analysis tool. Surfaces what the always-on STT log has captured so
|
||||
Marius can decide hotwords, spot recurring mistranscriptions, and judge whether
|
||||
a model swap (e.g. a Romanian-finetuned Whisper) actually helps.
|
||||
|
||||
Pure helpers (tokenize / aggregate) are importable and tested; the CLI just
|
||||
prints reports. Tolerates rows written before the `text_corrected` field
|
||||
existed (falls back to `text`).
|
||||
|
||||
Usage:
|
||||
python3 tools/voice_stt_mine.py # full report
|
||||
python3 tools/voice_stt_mine.py --tokens # token frequency only
|
||||
python3 tools/voice_stt_mine.py --rare # one-off tokens (candidates)
|
||||
python3 tools/voice_stt_mine.py --suspect # likely hallucination rows
|
||||
python3 tools/voice_stt_mine.py --log PATH # custom log path
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Iterator
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
DEFAULT_LOG = PROJECT_ROOT / "logs" / "voice_stt_log.jsonl"

# Latency above this (s) almost always means the decoder thrashed on unclear
# audio — a strong hallucination signal worth reviewing. Mirrors the >7s
# conversational-abort budget from tasks/voice-bench-results.md.
SUSPECT_LATENCY_S = 7.0
# Rows whose no_speech_prob is at or above this are flagged as suspect.
SUSPECT_NO_SPEECH = 0.5

# Romanian diacritic letters. Besides the standard comma-below forms (Ș ș Ț ț)
# we also accept the legacy cedilla code points (Ş ş Ţ ţ), which are still
# common in Romanian text; without them a word such as "ştiu" would be split
# into the bogus token "tiu".
_RO_STANDARD_DIACRITICS = "ĂÂÎȘȚăâîșț"
_RO_LEGACY_CEDILLA = "\u015e\u015f\u0162\u0163"  # Ş ş Ţ ţ
_RO_DIACRITIC_CHARS = _RO_STANDARD_DIACRITICS + _RO_LEGACY_CEDILLA

_TOKEN_RE = re.compile(rf"[A-Za-z{_RO_DIACRITIC_CHARS}]+", re.UNICODE)
# A token with none of these is a diacritic-restore candidate worth a human
# glance (not auto-corrected — see plan D2).
_DIACRITICS = set(_RO_DIACRITIC_CHARS)
|
||||
|
||||
|
||||
def read_log(path: Path) -> list[dict]:
    """Parse the JSONL log into a list of row dicts.

    Never crashes on bad content: blank lines, lines that are not valid JSON,
    and valid JSON values that are not objects (lists, numbers, strings,
    null) are all skipped, because every consumer treats a row as a dict and
    calls ``row.get(...)`` on it. A missing file yields an empty list.
    """
    rows: list[dict] = []
    try:
        handle = path.open(encoding="utf-8")
    except FileNotFoundError:
        return rows
    with handle:
        for raw_line in handle:
            payload = raw_line.strip()
            if not payload:
                continue
            try:
                record = json.loads(payload)
            except json.JSONDecodeError:
                continue
            # A well-formed but non-object line is just as unusable as a
            # malformed one for the mining helpers.
            if isinstance(record, dict):
                rows.append(record)
    return rows
|
||||
|
||||
|
||||
def row_text(row: dict) -> str:
    """Return the raw STT transcript stored in a log row, whitespace-trimmed.

    Only the `text` field is consulted: newer rows may also carry
    `text_corrected`, but mining always wants the uncorrected STT output.
    A missing or falsy `text` yields the empty string.
    """
    raw = row.get("text")
    if not raw:
        return ""
    return raw.strip()
|
||||
|
||||
|
||||
def tokenize(text: str) -> list[str]:
    """Return the lowercased alphabetic word tokens found in *text*.

    Whatever `_TOKEN_RE` does not match (digits, punctuation, whitespace) acts
    as a separator and is dropped. A falsy *text* gives an empty list.
    """
    words = _TOKEN_RE.findall(text if text else "")
    return [word.lower() for word in words]
|
||||
|
||||
|
||||
def token_frequency(rows: Iterable[dict]) -> Counter:
    """Count how often each token occurs across the raw text of all *rows*."""
    return Counter(
        token
        for row in rows
        for token in tokenize(row_text(row))
    )
|
||||
|
||||
|
||||
def rare_tokens(freq: Counter, max_count: int = 1) -> list[str]:
    """Return, alphabetically sorted, the tokens counted at most `max_count` times.

    These are the candidate mistranscriptions, proper nouns worth adding as
    hotwords, or code-switch garbage.
    """
    rare = [token for token in freq if freq[token] <= max_count]
    rare.sort()
    return rare
|
||||
|
||||
|
||||
def missing_diacritic_candidates(freq: Counter, min_len: int = 4) -> list[str]:
    """List tokens a diacritic-restore pass would target, most frequent first.

    A token qualifies when it has at least `min_len` characters, is purely
    alphabetic and contains none of the letters in `_DIACRITICS`. Ties in
    frequency are broken alphabetically. This is a review list only (v1 does
    not auto-restore, per plan D2).
    """
    def _qualifies(token: str) -> bool:
        return (
            len(token) >= min_len
            and _DIACRITICS.isdisjoint(token)
            and token.isalpha()
        )

    return sorted(
        (token for token in freq if _qualifies(token)),
        key=lambda token: (-freq[token], token),
    )
|
||||
|
||||
|
||||
def suspect_rows(rows: Iterable[dict]) -> list[dict]:
    """Select the rows that look like hallucinations.

    A row is flagged when its `stt_latency_s` reaches SUSPECT_LATENCY_S or
    its `no_speech_prob` reaches SUSPECT_NO_SPEECH. Missing or falsy values
    count as 0.0. Input order is preserved.
    """
    def _is_suspect(row: dict) -> bool:
        # Parse both numbers up front so a bad value in either field surfaces
        # regardless of which threshold would have matched first.
        latency, no_speech = (
            float(row.get(field) or 0.0)
            for field in ("stt_latency_s", "no_speech_prob")
        )
        return latency >= SUSPECT_LATENCY_S or no_speech >= SUSPECT_NO_SPEECH

    return [row for row in rows if _is_suspect(row)]
|
||||
|
||||
|
||||
def _iter_report(rows: list[dict]) -> Iterator[str]:
    """Yield the full text report for *rows*, one output line at a time.

    Sections, in order: entry count and latency summary, the 20 most common
    tokens, one-off tokens, up to 30 missing-diacritic candidates, and the
    likely-hallucination rows.
    """
    freq = token_frequency(rows)

    # Header: volume and latency summary.
    yield f"entries: {len(rows)}"
    if rows:
        latencies = [float(row.get("stt_latency_s") or 0.0) for row in rows]
        mean_latency = sum(latencies) / len(latencies)
        yield f"latency: mean={mean_latency:.2f}s max={max(latencies):.2f}s"
    yield ""

    yield "== top tokens =="
    yield from (f" {count:>3} {token}" for token, count in freq.most_common(20))
    yield ""

    yield "== rare tokens (<=1, candidate corrections / hotwords) =="
    one_offs = rare_tokens(freq)
    yield " " + (", ".join(one_offs) if one_offs else "(none)")
    yield ""

    yield "== missing-diacritic candidates (review only) =="
    candidates = missing_diacritic_candidates(freq)[:30]
    yield " " + (", ".join(candidates) if candidates else "(none)")
    yield ""

    flagged = suspect_rows(rows)
    yield f"== likely-hallucination rows ({len(flagged)}) =="
    for row in flagged:
        latency = float(row.get('stt_latency_s') or 0)
        no_speech = float(row.get('no_speech_prob') or 0)
        yield f" lat={latency:.1f}s nsp={no_speech:.2f} {row_text(row)!r}"
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
    """CLI entry point: load the log and print the requested report.

    Returns 1 (with a message on stderr) when the log has no entries,
    otherwise 0. The first of --tokens, --rare, --suspect that is set wins;
    with none of them the full report is printed.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--log", type=Path, default=DEFAULT_LOG, help="path to voice_stt_log.jsonl")
    parser.add_argument("--tokens", action="store_true", help="token frequency only")
    parser.add_argument("--rare", action="store_true", help="one-off tokens only")
    parser.add_argument("--suspect", action="store_true", help="likely-hallucination rows only")
    args = parser.parse_args(argv)

    rows = read_log(args.log)
    if not rows:
        print(f"no entries in {args.log}", file=sys.stderr)
        return 1

    freq = token_frequency(rows)

    if args.tokens:
        for token, count in freq.most_common():
            print(f"{count:>4} {token}")
        return 0

    if args.rare:
        print("\n".join(rare_tokens(freq)))
        return 0

    if args.suspect:
        for row in suspect_rows(rows):
            latency = float(row.get('stt_latency_s') or 0)
            no_speech = float(row.get('no_speech_prob') or 0)
            print(f"lat={latency:.1f}s nsp={no_speech:.2f} {row_text(row)!r}")
        return 0

    print("\n".join(_iter_report(rows)))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
|
||||
180
tools/voice_stt_spike.py
Normal file
180
tools/voice_stt_spike.py
Normal file
@@ -0,0 +1,180 @@
|
||||
#!/usr/bin/env python3
|
||||
"""STT model spike — compare faster-whisper models on Romanian diacritic accuracy.
|
||||
|
||||
Answers the autoplan D1 question: does a Romanian-finetuned whisper-small beat the
|
||||
generic `small` on diacritics WITHOUT regressing latency? Synthesizes clean RO
|
||||
audio via Supertonic (ground-truth text known, with diacritics), runs each model,
|
||||
and scores diacritic preservation + word error rate + latency.
|
||||
|
||||
This is an evaluation harness, not production code. Synthetic TTS audio is a clean
|
||||
probe for diacritic behaviour specifically — it is NOT a proxy for real-mic acoustic
|
||||
robustness, so weight latency + diacritics, not absolute WER.
|
||||
|
||||
Usage:
|
||||
python3 tools/voice_stt_spike.py --models small,/home/.../whisper-small-ro-cv11-int8
|
||||
python3 tools/voice_stt_spike.py --models small,<path> --threads 4 --trials 2
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import statistics
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import unicodedata
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
# Base URL of the local Supertonic TTS service used to synthesize probe audio.
SUPERTONIC_URL = "http://127.0.0.1:7788"

# (name, ground-truth text) pairs, written with correct Romanian diacritics
# (mirrors tools/voice_bench.py).
UTTERANCES_RO: list[tuple[str, str]] = [
    (
        "short",
        "Salut, ce mai faci?",
    ),
    (
        "conversational",
        "Stai puțin să mă gândesc la asta.",
    ),
    (
        "medium",
        "Am verificat în calendar și avem ședință cu echipa la trei după-amiază.",
    ),
    (
        "numbers",
        "Costul total este o sută douăzeci și trei de lei și cincizeci de bani.",
    ),
    (
        "question",
        "Marius, vrei să-ți pun pe agenda de mâine să suni la NOAA?",
    ),
    (
        "longer",
        "Vreau să-mi reamintești diseară să verific dacă scriptul de backup a rulat corect.",
    ),
]

# Letters that mark a word as carrying a Romanian diacritic.
_DIACRITICS = {"ă", "â", "î", "ș", "ț", "Ă", "Â", "Î", "Ș", "Ț"}
|
||||
|
||||
|
||||
def _strip_punct_lower(text: str) -> list[str]:
    """Split *text* on whitespace into lowercased words without punctuation.

    Inside each word only alphanumeric characters (and the letters in
    `_DIACRITICS`) are kept; a word that ends up empty is dropped.
    """
    def _clean(chunk: str) -> str:
        return "".join(
            filter(lambda ch: ch.isalnum() or ch in _DIACRITICS, chunk)
        )

    cleaned = map(_clean, text.split())
    return [word.lower() for word in cleaned if word]
|
||||
|
||||
|
||||
def _deaccent(s: str) -> str:
    """Return *s* with diacritics removed, for "ignoring diacritics" comparison.

    Romanian letters are mapped to their base letters explicitly; any other
    accent is then dropped by decomposing to NFD and discarding the
    nonspacing marks (Unicode category "Mn").
    """
    ro_to_base = str.maketrans(dict(zip("ăâîșțĂÂÎȘȚ", "aaistAAIST")))
    decomposed = unicodedata.normalize("NFD", s.translate(ro_to_base))
    kept = [ch for ch in decomposed if unicodedata.category(ch) != "Mn"]
    return "".join(kept)
|
||||
|
||||
|
||||
def wer(ref: list[str], hyp: list[str]) -> float:
    """Word error rate: token-level Levenshtein distance divided by len(ref).

    With an empty reference the result is 0.0 when the hypothesis is empty
    too, and 1.0 otherwise.
    """
    if not ref:
        return 1.0 if hyp else 0.0

    # previous[j] is the edit distance between the reference prefix handled so
    # far and the first j hypothesis tokens.
    previous = list(range(len(hyp) + 1))
    for i, ref_token in enumerate(ref, start=1):
        current = [i]
        for j, hyp_token in enumerate(hyp, start=1):
            substitution = previous[j - 1] + (0 if ref_token == hyp_token else 1)
            deletion = previous[j] + 1
            insertion = current[j - 1] + 1
            current.append(min(deletion, insertion, substitution))
        previous = current
    return previous[-1] / len(ref)
|
||||
|
||||
|
||||
def diacritic_score(ref: str, hyp: str) -> tuple[int, int]:
    """Score how many diacritized ground-truth words the hypothesis reproduced.

    Both texts are normalised with `_strip_punct_lower`. Every reference word
    containing at least one letter from `_DIACRITICS` counts towards the
    total (repeated words count each time); it counts as correct when that
    exact diacritized word occurs anywhere in the hypothesis.

    Returns:
        (correct, total)
    """
    hyp_words = set(_strip_punct_lower(hyp))
    diacritized = [
        word for word in _strip_punct_lower(ref) if set(word) & _DIACRITICS
    ]
    correct = sum(1 for word in diacritized if word in hyp_words)
    return correct, len(diacritized)
|
||||
|
||||
|
||||
def synthesize(text: str, out_path: Path) -> float:
    """Synthesize *text* via the Supertonic service and save the WAV to *out_path*.

    Raises whatever `raise_for_status()` raises when the service answers with
    an error status.

    Returns:
        Duration of the written audio in seconds (frames / frame rate).
    """
    import wave

    payload = {
        "model": "supertonic-3",
        "input": text,
        "voice": "M2",
        "response_format": "wav",
        "lang": "ro",
    }
    response = httpx.post(
        f"{SUPERTONIC_URL}/v1/audio/speech", json=payload, timeout=60.0
    )
    response.raise_for_status()
    out_path.write_bytes(response.content)

    with wave.open(str(out_path), "rb") as wav_file:
        frame_count = wav_file.getnframes()
        frame_rate = wav_file.getframerate()
    return frame_count / float(frame_rate)
|
||||
|
||||
|
||||
def run_model(model_ref: str, wavs: list[tuple[str, str, str, float]],
              threads: int, trials: int) -> dict:
    """Transcribe every WAV with one model and collect accuracy + latency stats.

    Args:
        model_ref: faster-whisper model name or path to a CT2 model directory.
        wavs: (name, reference text, wav path, duration) per utterance; the
            duration is not used here.
        threads: CPU threads handed to the model.
        trials: transcription runs per utterance (must be >= 1). Every run
            contributes a latency sample; only the first run's text is scored.

    Returns:
        Dict with keys `model`, `load_s`, `rows` (per-utterance tuples of
        name, reference, hypothesis, diacritics correct, diacritics total,
        WER), `p50`, `p95`, `dia_correct`, `dia_total` and `wer` (mean WER).
        With no utterances, `p50`/`p95` are NaN and `wer` is 1.0.

    Raises:
        ValueError: if `trials` is below 1. Checked before the model is
            loaded so a bad argument fails fast instead of crashing in the
            latency statistics after all the expensive work.
    """
    if trials < 1:
        raise ValueError(f"trials must be >= 1, got {trials}")

    from faster_whisper import WhisperModel

    load_start = time.perf_counter()
    model = WhisperModel(model_ref, device="cpu", compute_type="int8", cpu_threads=threads)
    load_s = time.perf_counter() - load_start

    rows, latencies, wers = [], [], []
    dia_correct = dia_total = 0
    for name, ref_text, wav_path, _dur in wavs:
        scored_text = ""
        for trial in range(trials):
            start = time.perf_counter()
            segments, _ = model.transcribe(wav_path, language="ro", beam_size=5,
                                           temperature=0.0, condition_on_previous_text=False)
            # The timer is stopped only after the segments have been consumed,
            # so the sample covers the whole transcription, not just the call.
            text = " ".join(s.text.strip() for s in segments).strip()
            latencies.append(time.perf_counter() - start)
            if trial == 0:
                scored_text = text
        correct, total = diacritic_score(ref_text, scored_text)
        dia_correct += correct
        dia_total += total
        utterance_wer = wer(_strip_punct_lower(ref_text), _strip_punct_lower(scored_text))
        wers.append(utterance_wer)
        rows.append((name, ref_text, scored_text, correct, total, utterance_wer))

    ordered = sorted(latencies)
    if ordered:
        p50 = statistics.median(ordered)
        p95 = ordered[max(0, int(0.95 * (len(ordered) - 1)))]
    else:
        # No utterances: there is no latency to report, mirroring the
        # empty-case fallback already used for the mean WER below.
        p50 = p95 = float("nan")

    return {
        "model": model_ref, "load_s": load_s, "rows": rows,
        "p50": p50, "p95": p95,
        "dia_correct": dia_correct, "dia_total": dia_total,
        "wer": statistics.mean(wers) if wers else 1.0,
    }
|
||||
|
||||
|
||||
def main(argv=None) -> int:
    """CLI entry point: synthesize the probe utterances, run each model, print results.

    Prints one summary table row per model (p50, p95, WER, diacritic score),
    followed by the per-utterance reference/hypothesis pairs. Returns 0.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--models", required=True, help="CSV of model names or CT2 dirs")
    parser.add_argument("--threads", type=int, default=4)
    parser.add_argument("--trials", type=int, default=2)
    args = parser.parse_args(argv)

    def _label(model_ref: str) -> str:
        # Filesystem paths are shortened to their last component for display.
        return Path(model_ref).name if "/" in model_ref else model_ref

    # Synthesize every utterance once; all models transcribe the same WAVs.
    work = Path(tempfile.mkdtemp(prefix="stt_spike_"))
    print(f"[spike] synth dir {work}", flush=True)
    wavs = []
    for name, text in UTTERANCES_RO:
        wav_path = work / f"{name}.wav"
        dur = synthesize(text, wav_path)
        wavs.append((name, text, str(wav_path), dur))
        print(f"[spike] TTS {name}: {dur:.2f}s", flush=True)

    results = []
    for entry in args.models.split(","):
        ref = entry.strip()
        if ref:
            print(f"[spike] running {ref} (threads={args.threads})…", flush=True)
            results.append(run_model(ref, wavs, args.threads, args.trials))

    # Summary table.
    rule = "=" * 72
    print("\n" + rule)
    print(f"{'model':<42} {'p50':>6} {'p95':>6} {'WER':>6} {'diacr':>8}")
    print("-" * 72)
    for r in results:
        dia = f"{r['dia_correct']}/{r['dia_total']}"
        label = _label(r["model"])
        print(f"{label:<42} {r['p50']:>5.2f}s {r['p95']:>5.2f}s {r['wer']*100:>5.1f}% {dia:>8}")
    print(rule)

    # Per-utterance detail for each model.
    for r in results:
        print(f"\n### {_label(r['model'])}")
        for name, ref_text, hyp, c, tt, w in r["rows"]:
            print(f" [{name}] ref: {ref_text}")
            print(f" [{name}] hyp: {hyp} (diacr {c}/{tt}, wer {w*100:.0f}%)")
    return 0


if __name__ == "__main__":
    sys.exit(main())
|
||||
Reference in New Issue
Block a user