Files
nlp-master/retranscribe_tail.py
Marius Mutu 763999f3a9 feat: anti-hallucination params + retranscribe script for fixing broken transcripts
- transcribe.py: add --max-context 0, --entropy-thold 2.4, --max-len 60,
  --suppress-nst, --no-fallback to whisper.cpp to prevent hallucination loops
- transcribe.py: remove interactive quality gate (runs unattended now)
- run.bat: remove pause prompts for unattended operation
- retranscribe_tail.py: new script that detects hallucination bursts in SRT
  files, extracts and re-transcribes only the affected audio segments, then
  splices the result back together. Drops segments that re-hallucinate
  (silence/music). Backs up originals to transcripts/backup/.
- fix_hallucinations.bat: Windows wrapper for retranscribe_tail.py

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-24 21:17:14 +02:00

416 lines
14 KiB
Python

"""
Re-transcribe only the hallucinated portions of a transcript.
Detects hallucination bursts (repeated lines in SRT), classifies each as:
- "burst": short hallucination with good content after → extract just that segment
- "tail": hallucination runs to end of file → extract from burst start to end
Extracts audio for each bad segment, re-transcribes with anti-hallucination
parameters, and splices everything back together.
Usage:
python retranscribe_tail.py # auto-detect all broken transcripts
python retranscribe_tail.py "Master 25M1 Z2B" # fix a specific file
python retranscribe_tail.py --dry-run # show what would be fixed, don't run
"""
import os
import re
import shutil
import subprocess
import sys
import logging
from pathlib import Path
# Directory layout (all relative to the working directory).
TRANSCRIPTS_DIR = Path("transcripts")
AUDIO_DIR = Path("audio")
WAV_CACHE_DIR = Path("audio_wav")  # source WAVs that bad segments are cut from
TEMP_DIR = Path("retranscribe_tmp")  # scratch dir for chunk WAV/SRT files; removed in main() after a real run
# whisper.cpp binary and model; both overridable via environment variables.
WHISPER_BIN = os.getenv("WHISPER_BIN", r"whisper-cli.exe")
WHISPER_MODEL = os.getenv("WHISPER_MODEL", r"models\ggml-medium-q5_0.bin")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
log = logging.getLogger(__name__)
MIN_REPEATS = 4  # consecutive identical lines to count as hallucination
# If fewer than this many good entries remain after the last burst, treat as tail
TAIL_THRESHOLD = 50
# --- SRT parsing ---
def parse_srt(srt_path: Path) -> list[dict]:
    """Parse an SRT file into a list of entry dicts.

    Each entry has: ``index`` (int), ``start``/``end`` (original timestamp
    strings), ``start_sec``/``end_sec`` (float seconds), and ``text``
    (may span multiple lines).  Malformed blocks — missing index or
    unparseable timestamp line — are skipped silently.
    """
    # utf-8-sig transparently strips a leading BOM if present; plain UTF-8 is
    # unaffected.  With plain "utf-8" a BOM would make int(lines[0]) fail and
    # silently drop the first subtitle entry.
    content = srt_path.read_text(encoding="utf-8-sig")
    blocks = re.split(r"\n\n+", content.strip())
    entries = []
    for block in blocks:
        lines = block.strip().split("\n")
        if len(lines) < 3:
            continue
        try:
            idx = int(lines[0])
        except ValueError:
            continue
        # Accept both "," (SRT standard) and "." as the millisecond separator.
        ts_match = re.match(
            r"(\d{2}):(\d{2}):(\d{2})[,.](\d{3})\s*-->\s*(\d{2}):(\d{2}):(\d{2})[,.](\d{3})",
            lines[1],
        )
        if not ts_match:
            continue
        h, m, s, ms = (int(x) for x in ts_match.groups()[:4])
        start_sec = h * 3600 + m * 60 + s + ms / 1000
        h2, m2, s2, ms2 = (int(x) for x in ts_match.groups()[4:])
        end_sec = h2 * 3600 + m2 * 60 + s2 + ms2 / 1000
        text = "\n".join(lines[2:]).strip()
        entries.append({
            "index": idx,
            "start": lines[1].split("-->")[0].strip(),
            "end": lines[1].split("-->")[1].strip(),
            "start_sec": start_sec,
            "end_sec": end_sec,
            "text": text,
        })
    return entries
def _fmt_ts(sec: float) -> str:
h = int(sec // 3600)
m = int((sec % 3600) // 60)
s = int(sec % 60)
ms = int((sec % 1) * 1000)
return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
# --- Hallucination detection ---
def detect_bursts(entries: list[dict]) -> list[dict]:
"""
Find all hallucination bursts. Returns list of:
{start_idx, end_idx, start_sec, end_sec, text, count, type}
where type is "burst" (good content follows) or "tail" (nothing useful after).
"""
bursts = []
i = 0
while i < len(entries) - MIN_REPEATS:
text = entries[i]["text"].strip()
if not text:
i += 1
continue
run = 1
while i + run < len(entries) and entries[i + run]["text"].strip() == text:
run += 1
if run >= MIN_REPEATS:
bursts.append({
"start_idx": i,
"end_idx": i + run - 1,
"start_sec": entries[i]["start_sec"],
"end_sec": entries[i + run - 1]["end_sec"],
"text": text,
"count": run,
})
i += run
else:
i += 1
# Classify each burst
for burst in bursts:
remaining_after = len(entries) - burst["end_idx"] - 1
burst["type"] = "burst" if remaining_after >= TAIL_THRESHOLD else "tail"
return bursts
# --- Audio / transcription ---
def find_ffmpeg() -> str:
    """Find an ffmpeg binary.

    Checks each candidate as a relative path first, then on PATH via
    shutil.which (replacing the previous per-candidate ``ffmpeg -version``
    subprocess probe — same result, no process spawned).  Falls back to the
    bare name "ffmpeg" and lets subprocess fail loudly if it's truly absent.
    """
    for candidate in ["ffmpeg", "ffmpeg.exe", r"ffmpeg\ffmpeg.exe"]:
        if Path(candidate).exists():
            return candidate
        if shutil.which(candidate):
            return candidate
    return "ffmpeg"
def extract_audio_segment(wav_path: str, start_sec: float, end_sec: float | None,
                          output_path: str):
    """Cut one span out of a WAV into a 16 kHz mono PCM file via ffmpeg.

    When end_sec is None the cut runs from start_sec to the end of the file.
    Raises RuntimeError when ffmpeg exits non-zero.
    """
    args = [find_ffmpeg(), "-i", wav_path, "-ss", f"{start_sec:.3f}"]
    # Bounded cut: pass an explicit duration; open-ended cut: omit -t entirely.
    if end_sec is not None:
        args.extend(["-t", f"{end_sec - start_sec:.3f}"])
    # Whisper expects 16 kHz mono signed 16-bit PCM.
    args.extend(["-acodec", "pcm_s16le", "-ar", "16000", "-ac", "1", "-y", output_path])
    end_label = "end" if end_sec is None else f"{end_sec:.1f}s"
    label = f"{start_sec:.1f}s-{end_label}"
    log.info(f" Extracting audio [{label}]: {Path(output_path).name}")
    proc = subprocess.run(args, capture_output=True, text=True, timeout=120)
    if proc.returncode != 0:
        raise RuntimeError(f"ffmpeg failed: {proc.stderr[:300]}")
def transcribe_chunk(audio_path: str, output_base: str) -> bool:
    """Transcribe one audio chunk with whisper.cpp, writing .txt and .srt.

    Returns True when whisper exits cleanly.  Output lands at
    <output_base>.txt / <output_base>.srt.
    """
    thread_count = str(os.cpu_count() or 4)
    cmd = [
        WHISPER_BIN,
        "--model", WHISPER_MODEL,
        "--language", "ro",
        "--no-gpu",
        "--threads", thread_count,
        "--beam-size", "1",
        "--best-of", "1",
        # Anti-hallucination settings: no carried context between segments,
        # bail out of high-entropy decodes, cap segment length, drop
        # non-speech tokens, and never retry at a higher temperature.
        "--max-context", "0",
        "--entropy-thold", "2.4",
        "--max-len", "60",
        "--suppress-nst",
        "--no-fallback",
        "--output-txt",
        "--output-srt",
        "--output-file", output_base,
        "--file", audio_path,
    ]
    log.info(f" CMD: {' '.join(cmd)}")
    # Prepend the binary's directory to PATH so DLLs next to it resolve.
    env = os.environ.copy()
    bin_dir = str(Path(WHISPER_BIN).resolve().parent)
    env["PATH"] = bin_dir + os.pathsep + env.get("PATH", "")
    proc = subprocess.run(
        cmd,
        stdout=sys.stdout,
        stderr=sys.stderr,
        timeout=7200,
        env=env,
    )
    return proc.returncode == 0
# --- Splice ---
def build_srt_block(idx: int, entry: dict) -> str:
    """Render one SRT entry (index line, timing line, text) with a trailing newline."""
    parts = [str(idx), f"{entry['start']} --> {entry['end']}", entry['text']]
    return "\n".join(parts) + "\n"
def build_srt_block_offset(idx: int, entry: dict, offset_sec: float) -> str:
    """Render one SRT entry with both timestamps shifted by offset_sec seconds."""
    shifted_start = _fmt_ts(entry["start_sec"] + offset_sec)
    shifted_end = _fmt_ts(entry["end_sec"] + offset_sec)
    return f"{idx}\n{shifted_start} --> {shifted_end}\n{entry['text']}\n"
def fix_transcript(stem: str, dry_run: bool = False) -> bool:
    """Fix one hallucinated transcript in place. Returns True on success.

    Detects hallucination bursts in transcripts/<stem>.srt, backs up the
    original .srt/.txt to transcripts/backup/, extracts and re-transcribes
    only the affected audio spans, and rewrites both files with the spliced
    result.  Bursts whose re-transcription fails or hallucinates again are
    dropped from the output entirely.  With dry_run=True, only reports what
    would be fixed.
    """
    srt_path = TRANSCRIPTS_DIR / f"{stem}.srt"
    txt_path = TRANSCRIPTS_DIR / f"{stem}.txt"
    if not srt_path.exists():
        log.error(f"SRT not found: {srt_path}")
        return False
    entries = parse_srt(srt_path)
    bursts = detect_bursts(entries)
    if not bursts:
        log.info(f" {stem}: no hallucination detected, skipping")
        return True
    # Report findings
    for b in bursts:
        log.info(
            f" {b['type'].upper()} (RE-TRANSCRIBE): entries {b['start_idx']}-{b['end_idx']} "
            f"({_fmt_ts(b['start_sec'])} - {_fmt_ts(b['end_sec'])}), "
            f"\"{b['text'][:50]}\" x{b['count']}"
        )
    if dry_run:
        return True
    # Find audio source (WAV)
    audio_src = find_wav_for_stem(stem)
    if not audio_src:
        log.error(f" No audio found for {stem}")
        return False
    TEMP_DIR.mkdir(exist_ok=True)
    # Backup originals.  NOTE(review): a second run overwrites the previous
    # backup, so the backup reflects the state before the *latest* run, not
    # the pristine original.
    backup_dir = TRANSCRIPTS_DIR / "backup"
    backup_dir.mkdir(exist_ok=True)
    for p in [txt_path, srt_path]:
        if p.exists():
            backup = backup_dir / p.name
            if backup.exists():
                backup.unlink()
            shutil.copy2(p, backup)
            log.info(f" Backed up: {backup}")
    # Build the fixed transcript by processing segments between bursts.
    # Strategy: keep good entries as-is, replace each burst with re-transcription.
    #
    # Segments to process:
    # [0 .. burst0.start) → keep (good)
    # [burst0.start .. burst0.end] → re-transcribe
    # (burst0.end .. burst1.start) → keep (good)
    # [burst1.start .. burst1.end] → re-transcribe
    # ...
    # after last burst:
    # if type=burst → keep remaining entries
    # if type=tail → re-transcribe from burst start to end of audio
    result_txt_parts = []
    result_srt_entries = []  # entry dicts in final order; renumbered on write
    chunk_idx = 0
    prev_end_idx = -1  # last processed entry index
    for burst in bursts:
        # 1) Keep good entries before this burst
        for i in range(prev_end_idx + 1, burst["start_idx"]):
            e = entries[i]
            result_srt_entries.append(e)
            result_txt_parts.append(e["text"])
        # 2) Re-transcribe the hallucinated segment
        chunk_wav = str(TEMP_DIR / f"{stem}_chunk{chunk_idx}.wav")
        if burst["type"] == "tail":
            # Tail burst: everything from the burst start to end-of-audio is suspect.
            extract_audio_segment(audio_src, burst["start_sec"], None, chunk_wav)
        else:
            extract_audio_segment(
                audio_src, burst["start_sec"], burst["end_sec"], chunk_wav
            )
        chunk_base = str(TEMP_DIR / f"{stem}_chunk{chunk_idx}")
        success = transcribe_chunk(chunk_wav, chunk_base)
        chunk_srt = Path(f"{chunk_base}.srt")
        chunk_usable = False
        if success and chunk_srt.exists():
            chunk_entries = parse_srt(chunk_srt)
            # Check if the retranscription itself hallucinated
            if detect_bursts(chunk_entries):
                log.warning(f" Chunk {chunk_idx} hallucinated again — "
                            f"likely silence/music, dropping segment")
            elif not chunk_entries:
                log.warning(f" Chunk {chunk_idx} produced empty output, dropping segment")
            else:
                chunk_usable = True
        else:
            log.warning(f" Whisper failed on chunk {chunk_idx}, dropping segment")
        if chunk_usable:
            # Read re-transcribed entries and offset timestamps — chunk
            # timestamps are relative to the chunk's own start.
            offset = burst["start_sec"]
            for ce in chunk_entries:
                ce["start_sec"] += offset
                ce["end_sec"] += offset
                ce["start"] = _fmt_ts(ce["start_sec"])
                ce["end"] = _fmt_ts(ce["end_sec"])
                result_srt_entries.append(ce)
                result_txt_parts.append(ce["text"])
        else:
            log.info(f" Segment {_fmt_ts(burst['start_sec'])} - "
                     f"{_fmt_ts(burst['end_sec'])} removed (no usable speech)")
        if burst["type"] == "tail":
            # Tail chunk already covered to end-of-audio: skip every remaining entry.
            prev_end_idx = len(entries)
        else:
            prev_end_idx = burst["end_idx"]
        chunk_idx += 1
    # 3) Keep any remaining good entries after last burst
    if prev_end_idx < len(entries) - 1:
        for i in range(prev_end_idx + 1, len(entries)):
            e = entries[i]
            result_srt_entries.append(e)
            result_txt_parts.append(e["text"])
    # Write final TXT
    with open(txt_path, "w", encoding="utf-8") as f:
        f.write("\n".join(result_txt_parts))
        f.write("\n")
    log.info(f" Written: {txt_path} ({txt_path.stat().st_size} bytes)")
    # Write final SRT (entries renumbered from 1)
    with open(srt_path, "w", encoding="utf-8") as f:
        for i, e in enumerate(result_srt_entries, 1):
            f.write(f"{i}\n{e['start']} --> {e['end']}\n{e['text']}\n\n")
    log.info(f" Written: {srt_path} ({srt_path.stat().st_size} bytes)")
    log.info(f" {stem}: FIXED ({len(result_srt_entries)} entries, "
             f"{chunk_idx} chunk(s) re-transcribed)")
    return True
def find_wav_for_stem(stem: str) -> str | None:
    """Find the WAV file corresponding to a transcript stem.

    Tries an exact match, then the " [Audio]" download suffix, then falls
    back to the lexicographically first partial match.  The fallback is
    sorted because glob iteration order is filesystem-dependent; without
    sorting, repeated runs could pick different files.
    """
    # Direct match
    wav = WAV_CACHE_DIR / f"{stem}.wav"
    if wav.exists():
        return str(wav)
    # Try with [Audio] suffix (original download names)
    wav_audio = WAV_CACHE_DIR / f"{stem} [Audio].wav"
    if wav_audio.exists():
        return str(wav_audio)
    # Deterministic partial match.  NOTE(review): glob treats "[...]" in the
    # stem as a character class; fine for current names, but bracketed stems
    # would need glob.escape() — confirm against real stems.
    matches = sorted(WAV_CACHE_DIR.glob(f"{stem}*.wav"))
    if matches:
        return str(matches[0])
    return None
def find_broken_transcripts() -> list[str]:
    """Return the stem of every SRT under TRANSCRIPTS_DIR containing a hallucination burst."""
    return [
        srt_file.stem
        for srt_file in sorted(TRANSCRIPTS_DIR.glob("*.srt"))
        if detect_bursts(parse_srt(srt_file))
    ]
def main():
    """CLI entry point.

    With positional args: fix the named transcript stems.  Without args:
    scan all transcripts and fix every broken one.  --dry-run reports
    without modifying anything.
    """
    argv = sys.argv[1:]
    dry_run = "--dry-run" in argv
    stems = [a for a in argv if a != "--dry-run"]
    if not stems:
        log.info("Scanning for hallucinated transcripts...")
        stems = find_broken_transcripts()
        if not stems:
            log.info("All transcripts look clean!")
            return
        log.info(f"Found {len(stems)} broken transcript(s): {stems}")
    for stem in stems:
        log.info(f"\n{'='*60}")
        log.info(f"Processing: {stem}")
        log.info(f"{'='*60}")
        fix_transcript(stem, dry_run=dry_run)
    # Leave the scratch dir in place on dry runs (nothing was created anyway).
    if not dry_run and TEMP_DIR.exists():
        shutil.rmtree(TEMP_DIR)
        log.info("Cleaned up temp directory")


if __name__ == "__main__":
    main()