From ce273d14dbb2d2ec56e74bdd5e3dbabd25407853 Mon Sep 17 00:00:00 2001 From: Marius Mutu Date: Sat, 27 Jun 2026 18:16:16 +0000 Subject: [PATCH] =?UTF-8?q?feat(voice):=20improve=20Romanian=20STT=20?= =?UTF-8?q?=E2=80=94=20hallucination=20gate=20+=20finetuned=20model?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Gemma 4 cloud audio was infeasible (31b-cloud has no audio; E4B broken upstream, no deploy host), so improve faster-whisper instead. - Pin temperature=0.0 to disable the fallback ladder that re-decoded unclear audio up to 6x (source of the 16-24s latency outliers); reject hallucinated segments via avg_logprob/compression_ratio in the new pure _filter_segments. - Adopt mikr/whisper-small-ro-cv11 (CT2 int8) via configurable voice.stt_model: spike showed WER 24%->10%, numbers fixed at source, +0.33s p50 (in budget). - Add tools/voice_stt_mine.py (log mining) + tools/voice_stt_spike.py (model eval with diacritic scoring) + tests for the gate and miner. 
Co-Authored-By: Claude Fable 5 --- .gitignore | 1 + config.json | 3 +- src/voice/pipeline.py | 70 ++++++++--- tasks/lessons.md | 14 +++ tasks/voice-stt-quality.md | 61 ++++++++++ tests/test_voice_pipeline_filter.py | 85 +++++++++++++ tests/test_voice_stt_mine.py | 100 ++++++++++++++++ tools/voice_stt_mine.py | 166 +++++++++++++++++++++++++ tools/voice_stt_spike.py | 180 ++++++++++++++++++++++++++++ 9 files changed, 664 insertions(+), 16 deletions(-) create mode 100644 tasks/voice-stt-quality.md create mode 100644 tests/test_voice_pipeline_filter.py create mode 100644 tests/test_voice_stt_mine.py create mode 100644 tools/voice_stt_mine.py create mode 100644 tools/voice_stt_spike.py diff --git a/.gitignore b/.gitignore index c4f5fcf..565419d 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,4 @@ memory.bak/ approved-tasks.json dashboard/status.json tools/anaf-monitor/monitor.log +models/ diff --git a/config.json b/config.json index dc3dda9..e9a39b3 100644 --- a/config.json +++ b/config.json @@ -110,7 +110,8 @@ ], "user_name": "Marius", "default_voice": "F1", - "auto_leave_minutes": 5 + "auto_leave_minutes": 5, + "stt_model": "/home/moltbot/echo-core/models/whisper-small-ro-cv11-int8" }, "paths": { "personality": "personality/", diff --git a/src/voice/pipeline.py b/src/voice/pipeline.py index b528649..fe37eea 100644 --- a/src/voice/pipeline.py +++ b/src/voice/pipeline.py @@ -49,6 +49,14 @@ VAD_WINDOW_BYTES = PACKET_BYTES * (VAD_WINDOW_MS // PACKET_MS) VAD_THRESHOLD = 0.5 SILENCE_FLUSH_MS = 800 NO_SPEECH_DROP_THRESHOLD = 0.6 +# Hallucination rejection (no re-decode). faster-whisper's default temperature +# is a fallback ladder [0.0..1.0]; on unclear audio it re-decodes the segment up +# to 6x, which is what produced the 16-24s outliers in voice_stt_log.jsonl +# against a >7s conversational-abort budget. 
We pin temperature=0.0 (no fallback) +# and instead REJECT low-quality segments using the avg_logprob / compression_ratio +# that faster-whisper already computes per segment — zero extra latency. +AVG_LOGPROB_DROP_THRESHOLD = -1.0 # drop seg if avg_logprob below this +COMPRESSION_RATIO_DROP_THRESHOLD = 2.4 # drop seg if gzip ratio above this (repetition/garbage) PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent LOGS_DIR = PROJECT_ROOT / "logs" @@ -83,19 +91,28 @@ _silero_lock = threading.Lock() def _get_whisper_model() -> Any: - """Lazy-load faster-whisper ``small`` int8 with the spike-validated - ``cpu_threads=4`` (see ``tasks/voice-bench-results.md``).""" + """Lazy-load faster-whisper int8 with the spike-validated ``cpu_threads=4`` + (see ``tasks/voice-bench-results.md``). + + Model is configurable via ``voice.stt_model`` (default ``"small"``). It may be + a faster-whisper model name or a path to a local CT2 dir — e.g. the Romanian + Common-Voice finetune that halved WER and fixed number transcription in the + D1 spike (``tools/voice_stt_spike.py``). Custom paths still load with + ``local_files_only=True`` since they live on disk.""" global _whisper_model if _whisper_model is not None: return _whisper_model with _whisper_lock: if _whisper_model is not None: return _whisper_model + from src.config import Config + model_id = Config().get("voice.stt_model", "small") or "small" from faster_whisper import WhisperModel _whisper_model = WhisperModel( - "small", device="cpu", compute_type="int8", cpu_threads=4, + model_id, device="cpu", compute_type="int8", cpu_threads=4, local_files_only=True, ) + log.info("STT model loaded: %s", model_id) return _whisper_model @@ -145,6 +162,38 @@ def _pcm48_stereo_to_16_mono(pcm: bytes) -> np.ndarray: return np.ascontiguousarray(mono16, dtype=np.float32) +def _filter_segments(segments: Any) -> tuple[list[str], float]: + """Keep transcribable segments, drop silence and hallucinations. 
+ + Pure + side-effect free (no model, no I/O) so the rejection thresholds are + unit-testable with fake segment objects. A segment is dropped when: + - ``no_speech_prob`` is high (silence/non-speech), OR + - ``avg_logprob`` is below ``AVG_LOGPROB_DROP_THRESHOLD`` (decoder unsure), OR + - ``compression_ratio`` exceeds ``COMPRESSION_RATIO_DROP_THRESHOLD`` (looped/garbage). + The avg_logprob/compression checks replace faster-whisper's temperature-fallback + re-decode (the source of the 16-24s latency outliers) with zero-cost rejection. + Returns ``(kept_text_parts, worst_no_speech_prob)``. + """ + text_parts: list[str] = [] + worst_no_speech = 0.0 + for seg in segments: + no_sp = float(getattr(seg, "no_speech_prob", 0.0) or 0.0) + if no_sp > worst_no_speech: + worst_no_speech = no_sp + if no_sp > NO_SPEECH_DROP_THRESHOLD: + continue + avg_lp = getattr(seg, "avg_logprob", None) + if avg_lp is not None and float(avg_lp) < AVG_LOGPROB_DROP_THRESHOLD: + continue + comp = getattr(seg, "compression_ratio", None) + if comp is not None and float(comp) > COMPRESSION_RATIO_DROP_THRESHOLD: + continue + seg_text = (getattr(seg, "text", "") or "").strip() + if seg_text: + text_parts.append(seg_text) + return text_parts, worst_no_speech + + # ---------- VoiceSession ---------- class VoiceSession: @@ -679,6 +728,7 @@ class EchoVoiceSink(AudioSink): model = _get_whisper_model() segments, _info = model.transcribe( mono16, language="ro", beam_size=5, + temperature=0.0, # no fallback ladder — reject bad segments instead (see thresholds above) initial_prompt=( "Conversatie in romana cu asistentul Eco (Echo Core). " "Marius i se adreseaza cu 'Salut, Eco', 'Eco' sau 'Echo Core' " @@ -689,20 +739,10 @@ class EchoVoiceSink(AudioSink): "F1, F2, F3, F4, F5. Exemple: vorbeste cu vocea M5, voce F3, " "treci pe vocea F1." 
), - hotwords="Eco Echo Core Marius Bianca", + hotwords="Eco Echo Core Marius Bianca Bitcoin", condition_on_previous_text=False, ) - text_parts: list[str] = [] - worst_no_speech = 0.0 - for seg in segments: - no_sp = float(getattr(seg, "no_speech_prob", 0.0) or 0.0) - if no_sp > worst_no_speech: - worst_no_speech = no_sp - if no_sp > NO_SPEECH_DROP_THRESHOLD: - continue - seg_text = (getattr(seg, "text", "") or "").strip() - if seg_text: - text_parts.append(seg_text) + text_parts, worst_no_speech = _filter_segments(segments) if not text_parts: return text = " ".join(text_parts).strip() diff --git a/tasks/lessons.md b/tasks/lessons.md index 4887b0d..4cba6a5 100644 --- a/tasks/lessons.md +++ b/tasks/lessons.md @@ -51,3 +51,17 @@ Lecții capturate din corectările lui Marius. Citește acest fișier la începu **Greșeala:** Am editat index.json direct, cu o schemă diferită față de ce produce update_notes_index.py. **Regula:** Niciodată nu scriei manual în `memory/kb/index.json`. Fluxul corect: (1) creezi fișierul `.md` în `memory/kb//`, (2) rulezi `python3 tools/update_notes_index.py`. Dacă ai nevoie să salvezi o notiță din Facebook/video, folosești `scripts/transcribe_video.sh --save-kb` care face totul corect. **Când se aplică:** Orice salvare de notiță în KB (Facebook, YouTube, coaching, insights, orice). Dacă ești tentat să `json.dump` în index.json — stop, rulează scriptul. + +## Verifică că modelul/tool-ul numit chiar are capabilitatea ÎNAINTE de a planifica în jurul lui +**Data:** 2026-06-27 +**Context:** Marius a cerut să folosesc `gemma4:31b-cloud` (Ollama) pentru decodare audio ca alternativă la Whisper. Am verificat pe pagina oficială Ollama: variantele cloud (31b) suportă doar Text+Image — audio există DOAR pe E2B/E4B (edge, local), iar acela e stricat de o regresie upstream deschisă (issue #16584). Premisa cererii era infezabilă. 
+**Greșeala (evitată):** Dacă planificam direct integrarea fără să verific pagina modelului, scriam cod de cablare Ollama audio care n-ar fi funcționat niciodată. Search-ul generic spunea „Gemma 4 are audio" — adevărat la nivel de familie, fals pentru modelul cloud specific cerut. +**Regula:** Când userul numește un model/serviciu specific pentru o capabilitate (audio, vision, tool-use, context lung), verifică pagina/docs ACELUI model exact înainte de a planifica. Capabilitățile diferă per variantă (cloud vs edge, sizes). Fetch pagina oficială, nu te baza pe search agregat la nivel de familie. +**Când se aplică:** Orice task care pornește de la „folosește modelul X pentru Y". Confirmă X→Y pe sursa primară înainte de plan mode. + +## Corecția post-STT de text e cosmetică dacă consumatorul e un LLM — fixează la sursă (model), nu cu dicționar +**Data:** 2026-06-27 +**Context:** Plan inițial pentru curățarea transcrierii Whisper avea 4 piese, inclusiv dicționar de restaurare diacritice + canonicalizare wake-word. Două review-uri independente (/autoplan CEO+Eng) au arătat: textul transcris merge la Claude, care citește română fără diacritice perfect; NU există wake-word gate în cod (`on_segment_done` dispatch necondiționat); singurul consumator precis (`detect_voice_change`) e deja fuzz-hardenat. Un spike a confirmat că modelul RO-finetuned (`mikr/whisper-small-ro-cv11`) înjumătățește WER (24%→10%) și fixează numerele la SURSĂ, +0.33s latență. +**Greșeala (evitată):** Construirea unui strat de corecție hand-curat (întreținere perpetuă, risc de regresie pe cuvinte ambigue) când fix-ul real era un model finetuned cu același cost de inferență. +**Regula:** Înainte de a peticit output-ul unui model ML cu post-procesare rule-based, întreabă: (1) cine e CONSUMATORUL textului? (un LLM tolerează erori; un parser regex nu); (2) există un model finetuned care fixează la sursă cu același cost? Spike-uiește modelul ÎNAINTE de a scrie straturi de corecție. 
Verifică unde merge output-ul prin cod, nu presupune un gate care „pare" că există. +**Când se aplică:** Orice îmbunătățire de calitate STT/OCR/ML output. Tool de spike: `tools/voice_stt_spike.py`. diff --git a/tasks/voice-stt-quality.md b/tasks/voice-stt-quality.md new file mode 100644 index 0000000..7ee5e6e --- /dev/null +++ b/tasks/voice-stt-quality.md @@ -0,0 +1,61 @@ +# Voice STT Quality — îmbunătățiri Whisper + +Branch: `voice/stt-quality`. Origine: cererea de a folosi Gemma 4 cloud pentru audio +(infezabil — `gemma4:31b-cloud` n-are audio, E4B e stricat upstream, fără host de deploy). +Pivot la îmbunătățirea Whisper, validat prin `/autoplan` (CEO + Eng review). + +## Ce s-a livrat + +### 1. Gate rejection halucinații (cost zero latență) — `src/voice/pipeline.py` +- `model.transcribe(..., temperature=0.0)` — **dezactivează scara de fallback** a faster-whisper. + Codul vechi nu pasa `temperature`, deci folosea implicit `[0.0..1.0]` (6 pași) care re-decoda + segmentul pe audio prost → exact sursa latențelor de 24.4s / 16.7s din `voice_stt_log.jsonl`. +- `_filter_segments()` — funcție pură nouă care dropează segmentele cu `no_speech_prob` mare, + `avg_logprob < -1.0` (decoder nesigur) sau `compression_ratio > 2.4` (buclă/gunoi). Zero + re-decodare. Prinde „Care pune o zana judiciul tugea" / „Acest lucru a fost foarte mult". +- `hotwords += Bitcoin`. `initial_prompt` neatins (evită taxa de latență pe fiecare enunț). +- Teste: `tests/test_voice_pipeline_filter.py` (8 cazuri). + +### 2. Unealtă de mining — `tools/voice_stt_mine.py` +- CLI read-only peste `voice_stt_log.jsonl`: frecvențe token, tokeni rari (candidați + hotwords/corecții), candidați diacritice lipsă, rânduri suspecte de halucinație. +- Tolerează rânduri fără `text_corrected` (citește `text`). Teste: `tests/test_voice_stt_mine.py` (13). + +### 3. Spike model RO-finetuned (D1) — `tools/voice_stt_spike.py` +Compară modele faster-whisper pe audio RO sintetizat (Supertonic) cu ground-truth diacritizat. 
+ +**Rezultat (threads=4, beam=5):** + +| Model | p50 | p95 | WER | Diacritice | +|-------|-----|-----|-----|-----------| +| `small` (baseline) | 2.59s | 3.04s | 24.2% | 12/20 | +| **`mikr/whisper-small-ro-cv11`** (CT2 int8) | 2.92s | 3.25s | **10.5%** | **17/20** | + +- WER se înjumătățește; diacritice 60%→85%; numere PERFECTE (baseline: „120 si 3 delei" + → finetuned: „o sută douăzeci și trei de lei"). Cost: +0.33s p50 (în bugetul 1.5-3s). +- Modelul CT2: `~/.cache/echo-ct2/whisper-small-ro-cv11-int8` (234M int8). + +### 4. Model STT configurabil — `src/voice/pipeline.py::_get_whisper_model` +- Citește `voice.stt_model` din config (default `"small"`). Adopția finetuned = flip config, + nu cod. **Default-ul din cod rămâne `small`** (folosit când cheia lipsește); `config.json` din acest commit setează deja `voice.stt_model` pe modelul finetuned. + +## Cum adopți modelul finetuned (când decizi) +```bash +# config.json → "voice": { ..., "stt_model": "/home/moltbot/.cache/echo-ct2/whisper-small-ro-cv11-int8" } +systemctl --user restart echo-core # reload model +``` +Re-rulează spike-ul oricând: `python3 tools/voice_stt_spike.py --models "small,/cale/către/whisper-small-ro-cv11-int8" --threads 4` + +## Decizii autoplan respinse (din review) +- ❌ `temperature=[0.0..0.6]` fallback → regresie latență pe worst-case. Înlocuit cu rejection. +- ❌ `canonicalize_wakeword` → **nu există wake gate** în cod (verificat); ar fi spart `detect_voice_change`. +- ❌ Dicționar diacritice pe calea Claude → Claude citește română stâlcită OK; finetuned-ul rezolvă la sursă. +- ⏸ `correct_vocab` / `src/voice/stt_correct.py` → deferate (21 mostre = anecdotă; mining adună întâi date). + +## Note de mediu +- `transformers 5.12.1` instalat în `.venv` pentru conversia CT2 (one-time). A downgradat + `tokenizers` 0.23.1→0.22.2 (faster-whisper încă OK, pin `<1` respectat). Se poate `pip uninstall + transformers` dacă nu mai e nevoie de conversii. +- **Pre-existent, neatins de mine:** `tools/tts.py` modificat necommis sparge 2 teste din + `test_voice_normalize.py` (truncare 200 cuvinte). Confirmat: cu `tts.py` committed, testele trec. 
+``` diff --git a/tests/test_voice_pipeline_filter.py b/tests/test_voice_pipeline_filter.py new file mode 100644 index 0000000..df9a94b --- /dev/null +++ b/tests/test_voice_pipeline_filter.py @@ -0,0 +1,85 @@ +"""Tests for src/voice/pipeline.py::_filter_segments — STT hallucination gate. + +The gate replaces faster-whisper's temperature-fallback re-decode (the source of +16-24s latency outliers) with zero-cost segment rejection on no_speech_prob, +avg_logprob, and compression_ratio. +""" +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from src.voice.pipeline import ( # noqa: E402 + AVG_LOGPROB_DROP_THRESHOLD, + COMPRESSION_RATIO_DROP_THRESHOLD, + NO_SPEECH_DROP_THRESHOLD, + _filter_segments, +) + + +@dataclass +class FakeSeg: + text: str = "" + no_speech_prob: float = 0.0 + avg_logprob: Optional[float] = 0.0 + compression_ratio: Optional[float] = 1.0 + + +def test_keeps_clean_segment(): + parts, worst = _filter_segments([FakeSeg(text="salut eco", avg_logprob=-0.3, compression_ratio=1.5)]) + assert parts == ["salut eco"] + assert worst == 0.0 + + +def test_drops_high_no_speech(): + seg = FakeSeg(text="hmm", no_speech_prob=NO_SPEECH_DROP_THRESHOLD + 0.1) + parts, worst = _filter_segments([seg]) + assert parts == [] + assert worst == NO_SPEECH_DROP_THRESHOLD + 0.1 # still tracked for logging + + +def test_drops_low_avg_logprob_hallucination(): + # "Care pune o zana judiciul tugea" style: decoder unsure + seg = FakeSeg(text="zana judiciul tugea", avg_logprob=AVG_LOGPROB_DROP_THRESHOLD - 0.5) + parts, _ = _filter_segments([seg]) + assert parts == [] + + +def test_drops_high_compression_ratio_loop(): + seg = FakeSeg(text="da da da da da", compression_ratio=COMPRESSION_RATIO_DROP_THRESHOLD + 1.0) + parts, _ = _filter_segments([seg]) + assert parts == [] + + +def test_keeps_when_metrics_missing(): + # Older/edge segments may not expose 
avg_logprob/compression_ratio + seg = FakeSeg(text="ok", avg_logprob=None, compression_ratio=None) + parts, _ = _filter_segments([seg]) + assert parts == ["ok"] + + +def test_drops_empty_text(): + parts, _ = _filter_segments([FakeSeg(text=" ", avg_logprob=-0.2)]) + assert parts == [] + + +def test_worst_no_speech_is_max_across_segments(): + segs = [ + FakeSeg(text="a", no_speech_prob=0.1, avg_logprob=-0.2), + FakeSeg(text="b", no_speech_prob=0.4, avg_logprob=-0.2), + ] + parts, worst = _filter_segments(segs) + assert parts == ["a", "b"] + assert worst == 0.4 + + +def test_mixed_keep_and_drop(): + segs = [ + FakeSeg(text="bun venit", avg_logprob=-0.3), + FakeSeg(text="garbage", avg_logprob=-3.0), # dropped: low logprob + FakeSeg(text="la revedere", avg_logprob=-0.5), + ] + parts, _ = _filter_segments(segs) + assert parts == ["bun venit", "la revedere"] diff --git a/tests/test_voice_stt_mine.py b/tests/test_voice_stt_mine.py new file mode 100644 index 0000000..46e3326 --- /dev/null +++ b/tests/test_voice_stt_mine.py @@ -0,0 +1,100 @@ +"""Tests for tools/voice_stt_mine.py — STT log mining helpers. + +Pure-function coverage: tokenize, token_frequency, rare_tokens, +missing_diacritic_candidates, suspect_rows, row_text (back-compat with rows +that predate the text_corrected field). +""" +import sys +from pathlib import Path + +import pytest + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from tools.voice_stt_mine import ( # noqa: E402 + missing_diacritic_candidates, + rare_tokens, + row_text, + suspect_rows, + token_frequency, + tokenize, +) + + +def test_tokenize_lowercases_and_drops_punct(): + assert tokenize("Salut, Eco!") == ["salut", "eco"] + + +def test_tokenize_keeps_diacritics(): + assert tokenize("ședință și prețul") == ["ședință", "și", "prețul"] + + +def test_tokenize_drops_digits(): + # M3, numbers etc. 
are not alphabetic word tokens + assert tokenize("M3 are 120 lei") == ["m", "are", "lei"] + + +def test_tokenize_empty_and_none(): + assert tokenize("") == [] + assert tokenize(None) == [] + + +def test_row_text_prefers_raw_text_field(): + # Mining always wants raw STT output (the `text` field), even once + # newer rows add `text_corrected`. + assert row_text({"text": "cat", "text_corrected": "cât"}) == "cat" + + +def test_row_text_missing_field(): + assert row_text({}) == "" + + +def test_token_frequency_counts_across_rows(): + rows = [{"text": "eco eco"}, {"text": "Eco salut"}] + freq = token_frequency(rows) + assert freq["eco"] == 3 + assert freq["salut"] == 1 + + +def test_rare_tokens_returns_singletons_sorted(): + rows = [{"text": "eco eco salut bitcoin"}] + rare = rare_tokens(token_frequency(rows)) + assert rare == ["bitcoin", "salut"] # eco appears twice -> excluded + assert "eco" not in rare + + +def test_missing_diacritic_candidates_flags_ascii_words(): + rows = [{"text": "pretul este mare"}, {"text": "ședință corectă"}] + cands = missing_diacritic_candidates(token_frequency(rows), min_len=4) + assert "pretul" in cands + assert "mare" in cands + # words carrying diacritics are NOT restore candidates + assert "ședință" not in cands + assert "corectă" not in cands + + +def test_missing_diacritic_respects_min_len(): + rows = [{"text": "cat de bun"}] + cands = missing_diacritic_candidates(token_frequency(rows), min_len=4) + assert "cat" not in cands # len 3 < 4 + assert "bun" not in cands + + +def test_suspect_rows_flags_high_latency(): + rows = [ + {"text": "ok", "stt_latency_s": 2.0, "no_speech_prob": 0.0}, + {"text": "M3.", "stt_latency_s": 24.4, "no_speech_prob": 0.58}, + ] + suspects = suspect_rows(rows) + assert len(suspects) == 1 + assert suspects[0]["text"] == "M3." 
+ + +def test_suspect_rows_flags_borderline_no_speech(): + rows = [{"text": "x", "stt_latency_s": 1.0, "no_speech_prob": 0.55}] + assert len(suspect_rows(rows)) == 1 + + +def test_suspect_rows_tolerates_missing_fields(): + # rows without latency/no_speech must not crash + assert suspect_rows([{"text": "x"}]) == [] diff --git a/tools/voice_stt_mine.py b/tools/voice_stt_mine.py new file mode 100644 index 0000000..6320dde --- /dev/null +++ b/tools/voice_stt_mine.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 +"""Mine logs/voice_stt_log.jsonl for STT correction candidates. + +Read-only analysis tool. Surfaces what the always-on STT log has captured so +Marius can decide hotwords, spot recurring mistranscriptions, and judge whether +a model swap (e.g. a Romanian-finetuned Whisper) actually helps. + +Pure helpers (tokenize / aggregate) are importable and tested; the CLI just +prints reports. Always mines the raw `text` field, so rows written before the +`text_corrected` field existed are handled the same as newer ones. + +Usage: + python3 tools/voice_stt_mine.py # full report + python3 tools/voice_stt_mine.py --tokens # token frequency only + python3 tools/voice_stt_mine.py --rare # one-off tokens (candidates) + python3 tools/voice_stt_mine.py --suspect # likely hallucination rows + python3 tools/voice_stt_mine.py --log PATH # custom log path +""" +from __future__ import annotations + +import argparse +import json +import re +import sys +from collections import Counter +from pathlib import Path +from typing import Iterable, Iterator + +PROJECT_ROOT = Path(__file__).resolve().parent.parent +DEFAULT_LOG = PROJECT_ROOT / "logs" / "voice_stt_log.jsonl" + +# Latency above this (s) almost always means the decoder thrashed on unclear +# audio — a strong hallucination signal worth reviewing. Mirrors the >7s +# conversational-abort budget from tasks/voice-bench-results.md. 
+SUSPECT_LATENCY_S = 7.0 +SUSPECT_NO_SPEECH = 0.5 + +_TOKEN_RE = re.compile(r"[A-Za-zĂÂÎȘȚăâîșț]+", re.UNICODE) +# Romanian diacritic letters; a token with none of these is a diacritic-restore +# candidate worth a human glance (not auto-corrected — see plan D2). +_DIACRITICS = set("ĂÂÎȘȚăâîșț") + + +def read_log(path: Path) -> list[dict]: + """Parse the JSONL log; skip malformed lines instead of crashing.""" + rows: list[dict] = [] + if not path.exists(): + return rows + with path.open(encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + try: + rows.append(json.loads(line)) + except json.JSONDecodeError: + continue + return rows + + +def row_text(row: dict) -> str: + """Raw transcript for a row. New rows may add `text_corrected`; mining + always wants the raw STT output, which lives in `text`.""" + return (row.get("text") or "").strip() + + +def tokenize(text: str) -> list[str]: + """Split into alphabetic word tokens, lowercased. Drops digits/punct.""" + return [t.lower() for t in _TOKEN_RE.findall(text or "")] + + +def token_frequency(rows: Iterable[dict]) -> Counter: + counter: Counter = Counter() + for row in rows: + counter.update(tokenize(row_text(row))) + return counter + + +def rare_tokens(freq: Counter, max_count: int = 1) -> list[str]: + """Tokens seen at most `max_count` times — candidate mistranscriptions, + proper nouns to add as hotwords, or code-switch garbage.""" + return sorted(t for t, c in freq.items() if c <= max_count) + + +def missing_diacritic_candidates(freq: Counter, min_len: int = 4) -> list[str]: + """All-ASCII tokens (no Romanian diacritics) of reasonable length, sorted by + frequency. 
These are the words a diacritic-restore pass would target — kept + as a review list only (v1 does not auto-restore, per plan D2).""" + out = [ + (t, c) for t, c in freq.items() + if len(t) >= min_len and not (set(t) & _DIACRITICS) and t.isalpha() + ] + out.sort(key=lambda tc: (-tc[1], tc[0])) + return [t for t, _ in out] + + +def suspect_rows(rows: Iterable[dict]) -> list[dict]: + """Rows that look like hallucinations: very high latency or borderline + no_speech_prob that still produced text.""" + out = [] + for row in rows: + lat = float(row.get("stt_latency_s") or 0.0) + nsp = float(row.get("no_speech_prob") or 0.0) + if lat >= SUSPECT_LATENCY_S or nsp >= SUSPECT_NO_SPEECH: + out.append(row) + return out + + +def _iter_report(rows: list[dict]) -> Iterator[str]: + freq = token_frequency(rows) + yield f"entries: {len(rows)}" + if rows: + lats = [float(r.get("stt_latency_s") or 0.0) for r in rows] + yield f"latency: mean={sum(lats)/len(lats):.2f}s max={max(lats):.2f}s" + yield "" + yield "== top tokens ==" + for tok, cnt in freq.most_common(20): + yield f" {cnt:>3} {tok}" + yield "" + yield "== rare tokens (<=1, candidate corrections / hotwords) ==" + rare = rare_tokens(freq) + yield " " + (", ".join(rare) if rare else "(none)") + yield "" + yield "== missing-diacritic candidates (review only) ==" + cands = missing_diacritic_candidates(freq)[:30] + yield " " + (", ".join(cands) if cands else "(none)") + yield "" + suspects = suspect_rows(rows) + yield f"== likely-hallucination rows ({len(suspects)}) ==" + for r in suspects: + yield (f" lat={float(r.get('stt_latency_s') or 0):.1f}s " + f"nsp={float(r.get('no_speech_prob') or 0):.2f} " + f"{row_text(r)!r}") + + +def main(argv: list[str] | None = None) -> int: + ap = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + ap.add_argument("--log", type=Path, default=DEFAULT_LOG, help="path to voice_stt_log.jsonl") + ap.add_argument("--tokens", action="store_true", 
help="token frequency only") + ap.add_argument("--rare", action="store_true", help="one-off tokens only") + ap.add_argument("--suspect", action="store_true", help="likely-hallucination rows only") + args = ap.parse_args(argv) + + rows = read_log(args.log) + if not rows: + print(f"no entries in {args.log}", file=sys.stderr) + return 1 + + freq = token_frequency(rows) + if args.tokens: + for tok, cnt in freq.most_common(): + print(f"{cnt:>4} {tok}") + elif args.rare: + print("\n".join(rare_tokens(freq))) + elif args.suspect: + for r in suspect_rows(rows): + print(f"lat={float(r.get('stt_latency_s') or 0):.1f}s " + f"nsp={float(r.get('no_speech_prob') or 0):.2f} {row_text(r)!r}") + else: + print("\n".join(_iter_report(rows))) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tools/voice_stt_spike.py b/tools/voice_stt_spike.py new file mode 100644 index 0000000..8dd6dfe --- /dev/null +++ b/tools/voice_stt_spike.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 +"""STT model spike — compare faster-whisper models on Romanian diacritic accuracy. + +Answers the autoplan D1 question: does a Romanian-finetuned whisper-small beat the +generic `small` on diacritics WITHOUT regressing latency? Synthesizes clean RO +audio via Supertonic (ground-truth text known, with diacritics), runs each model, +and scores diacritic preservation + word error rate + latency. + +This is an evaluation harness, not production code. Synthetic TTS audio is a clean +probe for diacritic behaviour specifically — it is NOT a proxy for real-mic acoustic +robustness, so weight latency + diacritics, not absolute WER. 
+ +Usage: + python3 tools/voice_stt_spike.py --models small,/home/.../whisper-small-ro-cv11-int8 + python3 tools/voice_stt_spike.py --models small, --threads 4 --trials 2 +""" +from __future__ import annotations + +import argparse +import statistics +import sys +import tempfile +import time +import unicodedata +from pathlib import Path + +import httpx + +SUPERTONIC_URL = "http://127.0.0.1:7788" + +# Ground truth with correct Romanian diacritics (mirrors tools/voice_bench.py). +UTTERANCES_RO: list[tuple[str, str]] = [ + ("short", "Salut, ce mai faci?"), + ("conversational", "Stai puțin să mă gândesc la asta."), + ("medium", "Am verificat în calendar și avem ședință cu echipa la trei după-amiază."), + ("numbers", "Costul total este o sută douăzeci și trei de lei și cincizeci de bani."), + ("question", "Marius, vrei să-ți pun pe agenda de mâine să suni la NOAA?"), + ("longer", "Vreau să-mi reamintești diseară să verific dacă scriptul de backup a rulat corect."), +] + +_DIACRITICS = set("ăâîșțĂÂÎȘȚ") + + +def _strip_punct_lower(text: str) -> list[str]: + out = [] + for raw in text.split(): + w = "".join(c for c in raw if c.isalnum() or c in _DIACRITICS) + if w: + out.append(w.lower()) + return out + + +def _deaccent(s: str) -> str: + # Map RO diacritics to base letters for "ignoring diacritics" comparison. 
+ table = str.maketrans("ăâîșțĂÂÎȘȚ", "aaistAAIST") + s = s.translate(table) + return "".join(c for c in unicodedata.normalize("NFD", s) + if unicodedata.category(c) != "Mn") + + +def wer(ref: list[str], hyp: list[str]) -> float: + """Word error rate via Levenshtein on token lists.""" + n, m = len(ref), len(hyp) + if n == 0: + return 0.0 if m == 0 else 1.0 + dp = list(range(m + 1)) + for i in range(1, n + 1): + prev = dp[0] + dp[0] = i + for j in range(1, m + 1): + cur = dp[j] + cost = 0 if ref[i - 1] == hyp[j - 1] else 1 + dp[j] = min(dp[j] + 1, dp[j - 1] + 1, prev + cost) + prev = cur + return dp[m] / n + + +def diacritic_score(ref: str, hyp: str) -> tuple[int, int]: + """For each ground-truth word that carries a diacritic, did the hypothesis + contain that exact diacritized word? Returns (correct, total).""" + ref_words = _strip_punct_lower(ref) + hyp_words = set(_strip_punct_lower(hyp)) + hyp_deaccent = {_deaccent(w) for w in hyp_words} + correct = total = 0 + for w in ref_words: + if set(w) & _DIACRITICS: + total += 1 + if w in hyp_words: + correct += 1 + return correct, total + + +def synthesize(text: str, out_path: Path) -> float: + import wave + r = httpx.post(f"{SUPERTONIC_URL}/v1/audio/speech", + json={"model": "supertonic-3", "input": text, "voice": "M2", + "response_format": "wav", "lang": "ro"}, timeout=60.0) + r.raise_for_status() + out_path.write_bytes(r.content) + with wave.open(str(out_path), "rb") as wf: + return wf.getnframes() / float(wf.getframerate()) + + +def run_model(model_ref: str, wavs: list[tuple[str, str, str, float]], + threads: int, trials: int) -> dict: + from faster_whisper import WhisperModel + t0 = time.perf_counter() + model = WhisperModel(model_ref, device="cpu", compute_type="int8", cpu_threads=threads) + load_s = time.perf_counter() - t0 + rows, lats = [], [] + dia_c = dia_t = 0 + wers = [] + for name, ref_text, wav_path, _dur in wavs: + best_text = "" + for trial in range(trials): + t1 = time.perf_counter() + segments, _ = 
model.transcribe(wav_path, language="ro", beam_size=5, + temperature=0.0, condition_on_previous_text=False) + text = " ".join(s.text.strip() for s in segments).strip() + lats.append(time.perf_counter() - t1) + if trial == 0: + best_text = text + c, tt = diacritic_score(ref_text, best_text) + dia_c += c + dia_t += tt + w = wer(_strip_punct_lower(ref_text), _strip_punct_lower(best_text)) + wers.append(w) + rows.append((name, ref_text, best_text, c, tt, w)) + return { + "model": model_ref, "load_s": load_s, "rows": rows, + "p50": statistics.median(lats), "p95": sorted(lats)[max(0, int(0.95*(len(lats)-1)))], + "dia_correct": dia_c, "dia_total": dia_t, + "wer": statistics.mean(wers) if wers else 1.0, + } + + +def main(argv=None) -> int: + ap = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + ap.add_argument("--models", required=True, help="CSV of model names or CT2 dirs") + ap.add_argument("--threads", type=int, default=4) + ap.add_argument("--trials", type=int, default=2) + args = ap.parse_args(argv) + + work = Path(tempfile.mkdtemp(prefix="stt_spike_")) + print(f"[spike] synth dir {work}", flush=True) + wavs = [] + for name, text in UTTERANCES_RO: + p = work / f"{name}.wav" + dur = synthesize(text, p) + wavs.append((name, text, str(p), dur)) + print(f"[spike] TTS {name}: {dur:.2f}s", flush=True) + + results = [] + for ref in args.models.split(","): + ref = ref.strip() + if not ref: + continue + print(f"[spike] running {ref} (threads={args.threads})…", flush=True) + results.append(run_model(ref, wavs, args.threads, args.trials)) + + print("\n" + "=" * 72) + print(f"{'model':<42} {'p50':>6} {'p95':>6} {'WER':>6} {'diacr':>8}") + print("-" * 72) + for r in results: + dia = f"{r['dia_correct']}/{r['dia_total']}" + label = Path(r["model"]).name if "/" in r["model"] else r["model"] + print(f"{label:<42} {r['p50']:>5.2f}s {r['p95']:>5.2f}s {r['wer']*100:>5.1f}% {dia:>8}") + print("=" * 72) + for r in results: + label = 
Path(r["model"]).name if "/" in r["model"] else r["model"] + print(f"\n### {label}") + for name, ref_text, hyp, c, tt, w in r["rows"]: + print(f" [{name}] ref: {ref_text}") + print(f" [{name}] hyp: {hyp} (diacr {c}/{tt}, wer {w*100:.0f}%)") + return 0 + + +if __name__ == "__main__": + sys.exit(main())