Files
roa2web-service-auto/tests/ocr-validation/ocr-direct-validation.py
Marius Mutu 495790411f feat(ocr): Add docTR OCR engine with metrics infrastructure
Add docTR as primary OCR engine with 2-tier sequential processing,
OCR metrics tracking, and simplified engine selection.

Features:
- docTR OCR engine with light+medium preprocessing tiers
- doctr_plus mode with early exit optimization (~65% fast path)
- OCR metrics dashboard with per-engine statistics
- User OCR preference persistence
- Parallel worker pool for OCR processing
- Cross-validation for extraction quality

Engine options: tesseract, doctr, doctr_plus (recommended), paddleocr

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-02 05:37:16 +02:00

594 lines
24 KiB
Python

#!/usr/bin/env python3
"""
OCR Direct Validation Tests
This script validates the OCR extraction accuracy by:
1. Generating a test JWT token
2. Calling the OCR API endpoint with PDF receipts
3. Comparing extracted data with expected values from expected_receipts.json
Run:
python tests/ocr-validation/ocr-direct-validation.py
python tests/ocr-validation/ocr-direct-validation.py --engine doctr_plus
python tests/ocr-validation/ocr-direct-validation.py --receipt receipt_01
"""
import sys
import os
import json
import argparse
import time
import requests
from pathlib import Path
from datetime import datetime, timedelta
from typing import Optional, Dict, List, Any
# Make the repository root and its backend/ directory importable; this file
# lives in <root>/tests/ocr-validation/. Slice-assigning at index 0 puts
# backend/ first and the root second at the front of sys.path.
project_root = Path(__file__).parent.parent.parent
sys.path[0:0] = [str(project_root / 'backend'), str(project_root)]
# Import JWT handler to create test tokens
from jose import jwt
def create_test_token(secret_key: str, *, expires_in: timedelta = timedelta(hours=1)) -> str:
    """Create a signed HS256 JWT for authenticating test requests to the API.

    The payload describes a synthetic user (``ocr_test_user``, id 999) with
    ``read``/``write`` permissions on company ``TEST`` and token type
    ``"access"``.

    Args:
        secret_key: HMAC secret used to sign the token. It presumably has to
            match the API's JWT secret for requests to be accepted.
        expires_in: Token lifetime (default: one hour).

    Returns:
        The encoded JWT string.
    """
    # Local import keeps the module-level import block unchanged.
    from datetime import timezone

    # Timezone-aware UTC timestamps: datetime.utcnow() is deprecated since
    # Python 3.12 and returns a naive value that is only implicitly UTC.
    now = datetime.now(timezone.utc)
    payload = {
        "username": "ocr_test_user",
        "user_id": 999,
        "companies": ["TEST"],
        "permissions": ["read", "write"],
        "exp": now + expires_in,
        "iat": now,
        "type": "access",
    }
    return jwt.encode(payload, secret_key, algorithm="HS256")
def normalize_cui(cui: Optional[str]) -> Optional[str]:
    """Normalize a CUI (fiscal code) so differently formatted values compare equal.

    Upper-cases the value, removes all whitespace, then strips a leading
    ``RO`` prefix, so ``"ro 123 456"`` and ``"123456"`` both become
    ``"123456"``.

    Args:
        cui: Raw CUI as stored in the expected data or returned by the API.

    Returns:
        The normalized CUI, or ``None`` when *cui* is ``None`` or empty.
    """
    if not cui:
        return None
    # Remove every kind of whitespace first, so a spaced prefix ("R O 123")
    # is still recognised as a prefix.
    compact = ''.join(cui.upper().split())
    # Only a *leading* "RO" is the prefix. str.replace() would also delete
    # "RO" sequences anywhere else in the value.
    if compact.startswith('RO'):
        compact = compact[2:]
    return compact
def normalize_date(date: Optional[str]) -> Optional[str]:
    """Normalize an ISO-8601 date or datetime string to ``YYYY-MM-DD``.

    A trailing ``Z`` is accepted as UTC and surrounding whitespace is
    ignored. Values that cannot be parsed are returned unchanged, so they
    still show up verbatim in mismatch reports.

    Args:
        date: Date string to normalize.

    Returns:
        The ``YYYY-MM-DD`` form, the original value if it is not an ISO
        date, or ``None`` when *date* is ``None`` or empty.
    """
    if not date:
        return None
    try:
        # fromisoformat() only accepts the "Z" suffix from Python 3.11 on.
        parsed = datetime.fromisoformat(date.strip().replace('Z', '+00:00'))
    except (TypeError, ValueError, AttributeError):
        # Not an ISO string (or not a string at all): report it as-is.
        return date
    return parsed.strftime('%Y-%m-%d')
def compare_with_tolerance(expected: float, actual, tolerance: float) -> bool:
    """Return True when *actual* is close enough to *expected*.

    Args:
        expected: Reference value.
        actual: Extracted value: a number, a numeric string, or None.
        tolerance: Allowed relative error (e.g. ``0.02`` for 2%).

    Returns:
        True if ``|expected - actual|`` is within ``tolerance * |expected|``
        or within 0.01 (one cent) absolutely. False when *actual* is None,
        non-numeric or NaN.
    """
    if actual is None:
        return False
    # The API may return numbers as strings.
    try:
        actual_float = float(actual)
    except (ValueError, TypeError):
        return False
    diff = abs(expected - actual_float)
    # abs(): a negative expected value must not make the relative threshold
    # negative, which would reject every value except exact cent matches.
    threshold = abs(expected) * tolerance
    return diff <= threshold or diff <= 0.01  # Within tolerance or 1 cent
def submit_ocr_job(api_base: str, token: str, pdf_path: Path, engine: str,
                   max_wait: float = 120.0) -> dict:
    """Submit a PDF for OCR processing, then long-poll the job until it finishes.

    Args:
        api_base: Base URL of the API, e.g. ``http://localhost:8000``.
        token: Bearer token sent in the ``Authorization`` header.
        pdf_path: PDF file to upload.
        engine: OCR engine name, sent as the ``engine`` query parameter.
        max_wait: Overall time budget in seconds, covering submit and polling.

    Returns:
        A dict with ``status`` (``'completed'`` or ``'failed'``) and
        ``timing``. Completed jobs also carry ``result`` and
        ``processing_time_ms``; failed ones carry ``error``. ``timing`` holds:

        - submit_duration: seconds to upload the file and get a job_id
        - poll_count: number of poll requests made
        - poll_duration: seconds spent polling
        - wall_time: total elapsed seconds (submit + poll)
        - api_reported_ms: always 0 (never populated; kept for compatibility)
        - queue_wait_ms / ocr_time_ms / processing_time_ms: API-reported
          breakdown, present only for completed jobs

    Raises:
        requests.RequestException: if the upload fails, or a poll fails with a
            non-timeout network error.
        Whatever ``Response.json()`` raises when a response body is not JSON.
    """
    headers = {'Authorization': f'Bearer {token}'}
    timing = {
        'submit_duration': 0.0,
        'poll_count': 0,
        'poll_duration': 0.0,
        'wall_time': 0.0,
        'api_reported_ms': 0,
    }
    # monotonic(): durations must not jump if the system clock is adjusted.
    wall_start = time.monotonic()

    # --- Submit job ---
    submit_start = time.monotonic()
    with open(pdf_path, 'rb') as f:
        files = {'file': (pdf_path.name, f, 'application/pdf')}
        response = requests.post(
            f'{api_base}/api/data-entry/ocr/extract',
            params={'engine': engine},
            headers=headers,
            files=files,
            timeout=60,
        )
    timing['submit_duration'] = time.monotonic() - submit_start

    if not response.ok:
        timing['wall_time'] = time.monotonic() - wall_start
        return {'status': 'failed', 'error': f'Submit failed: {response.status_code} - {response.text}', 'timing': timing}

    job_id = response.json().get('job_id')
    if not job_id:
        timing['wall_time'] = time.monotonic() - wall_start
        return {'status': 'failed', 'error': 'No job_id in response', 'timing': timing}

    # --- Long-poll for the result ---
    poll_start = time.monotonic()

    def _stamp() -> None:
        """Record polling and total durations just before returning."""
        now = time.monotonic()
        timing['poll_duration'] = now - poll_start
        timing['wall_time'] = now - wall_start

    while True:
        remaining = max_wait - (time.monotonic() - wall_start)
        if remaining <= 0:
            break
        # Ask the server to hold the request no longer than the remaining
        # budget. A fixed 30 s wait could overshoot max_wait by ~35 s.
        wait_s = max(1, min(30, int(remaining)))
        timing['poll_count'] += 1
        try:
            poll_response = requests.get(
                f'{api_base}/api/data-entry/ocr/jobs/{job_id}/wait',
                params={'timeout': wait_s},
                headers=headers,
                timeout=wait_s + 5,
            )
        except requests.Timeout:
            # The long-poll outlived its client-side timeout; retry while the
            # budget lasts instead of failing the whole receipt.
            continue
        if not poll_response.ok:
            time.sleep(1)
            continue

        job_status = poll_response.json()
        status = job_status.get('status')
        if status == 'completed':
            _stamp()
            # Detailed timing from the API (missing or null treated as 0).
            timing['queue_wait_ms'] = job_status.get('queue_wait_ms', 0) or 0
            timing['ocr_time_ms'] = job_status.get('ocr_time_ms', 0) or 0
            timing['processing_time_ms'] = job_status.get('processing_time_ms', 0) or 0
            return {
                'status': 'completed',
                'result': job_status.get('result', {}),
                'processing_time_ms': job_status.get('processing_time_ms', 0),
                'timing': timing,
            }
        if status == 'failed':
            _stamp()
            return {'status': 'failed', 'error': job_status.get('error', 'Unknown error'), 'timing': timing}

        # Still pending: log the first three polls, then every fifth.
        if timing['poll_count'] <= 3 or timing['poll_count'] % 5 == 0:
            print(f" Status: {job_status.get('status')}, position: {job_status.get('queue_position')}, polls: {timing['poll_count']}")

    _stamp()
    return {'status': 'failed', 'error': 'Timeout waiting for OCR result', 'timing': timing}
# ocr_time_ms above this value counts as "too slow" for --stop-on-issue.
_SLOW_OCR_MS = 10000


def _safe_float(value, default=0.0):
    """Convert *value* to float, returning *default* for None or non-numeric input."""
    if value is None:
        return default
    try:
        return float(value)
    except (ValueError, TypeError):
        return default


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of the validation run."""
    parser = argparse.ArgumentParser(description='OCR Direct Validation')
    parser.add_argument('--engine', default='doctr_plus',
                        choices=['tesseract', 'doctr', 'doctr_plus', 'paddleocr'],
                        help='OCR engine to use (doctr_plus recommended)')
    parser.add_argument('--receipt', help='Specific receipt ID to test (e.g., receipt_01)')
    parser.add_argument('--verbose', '-v', action='store_true', help='Verbose output')
    parser.add_argument('--api-base', default='http://localhost:8000', help='API base URL')
    # Help text now matches the check actually performed in main().
    parser.add_argument('--stop-on-issue', action='store_true',
                        help=f'Stop at first receipt with ocr_time > {_SLOW_OCR_MS // 1000}s '
                             f'or extraction errors')
    parser.add_argument('--include-multipage', action='store_true',
                        help='Include multi-page PDFs (normally skipped)')
    return parser.parse_args()


def _build_comparison(expected: Dict[str, Any], extracted: Dict[str, Any],
                      engine: str, processing_time: float) -> Dict[str, Any]:
    """Compare one API extraction against its expected receipt record.

    Args:
        expected: Entry from expected_receipts.json (uses ``id``, ``filename``,
            ``total``, ``data_bon``, ``cui_furnizor`` and ``total_tva``).
        extracted: ``result`` payload of the completed OCR job.
        engine: Engine requested on the command line. Used when the API does
            not report ``ocr_engine``.
        processing_time: Processing time of this receipt in milliseconds.

    Returns:
        The per-receipt result dict stored in the report. It holds expected
        vs. extracted values, a ``*_match`` flag per field, confidence, the
        full extraction, the raw OCR texts, and an ``errors`` list that is
        empty on a perfect match.
    """
    comparison = {
        'receipt_id': expected['id'],
        'filename': expected['filename'],
        'status': 'completed',
        'total_expected': expected['total'],
        'total_extracted': _safe_float(extracted.get('amount'), None),
        'total_match': False,
        'date_expected': expected['data_bon'],
        'date_extracted': extracted.get('receipt_date'),
        'date_match': False,
        'cui_expected': expected['cui_furnizor'],
        'cui_extracted': extracted.get('cui'),
        'cui_match': False,
        'tva_expected': expected['total_tva'],
        'tva_extracted': _safe_float(extracted.get('tva_total'), None),
        'tva_match': False,
        'confidence': _safe_float(extracted.get('overall_confidence'), 0),
        'processing_time_ms': processing_time,
        'ocr_engine': extracted.get('ocr_engine', engine),
        'errors': [],
        # Full extraction, kept for offline analysis of failures.
        'full_extraction': {
            'amount': extracted.get('amount'),
            'receipt_date': extracted.get('receipt_date'),
            'cui': extracted.get('cui'),
            'tva_total': extracted.get('tva_total'),
            'tva_entries': extracted.get('tva_entries', []),
            'supplier_name': extracted.get('supplier_name'),
            'receipt_number': extracted.get('receipt_number'),
            'payment_methods': extracted.get('payment_methods', []),
            'items_count': extracted.get('items_count'),
            'overall_confidence': extracted.get('overall_confidence'),
            'confidence_amount': extracted.get('confidence_amount'),
            'confidence_date': extracted.get('confidence_date'),
            'confidence_cui': extracted.get('confidence_cui'),
        },
        # Raw OCR texts from each engine pass.
        'raw_texts': extracted.get('raw_texts', []),
    }
    errors = comparison['errors']

    # TOTAL: 2% relative tolerance (or one cent absolute).
    comparison['total_match'] = compare_with_tolerance(
        expected['total'], extracted.get('amount'), 0.02)
    if not comparison['total_match']:
        errors.append(f"TOTAL: expected {expected['total']}, got {extracted.get('amount')}")

    # DATE: compared in YYYY-MM-DD form.
    expected_date = normalize_date(expected['data_bon'])
    extracted_date = normalize_date(extracted.get('receipt_date'))
    comparison['date_match'] = expected_date == extracted_date
    if not comparison['date_match']:
        errors.append(f"DATE: expected {expected_date}, got {extracted_date}")

    # CUI: compared after normalize_cui().
    expected_cui = normalize_cui(expected['cui_furnizor'])
    extracted_cui = normalize_cui(extracted.get('cui'))
    comparison['cui_match'] = expected_cui == extracted_cui
    if not comparison['cui_match']:
        errors.append(f"CUI: expected {expected_cui}, got {extracted_cui}")

    # TVA
    if expected['total_tva'] > 0:
        comparison['tva_match'] = compare_with_tolerance(
            expected['total_tva'], extracted.get('tva_total'), 0.05)
    else:
        # No TVA expected (neplatitor TVA): accept a missing or zero value.
        tva_extracted = _safe_float(extracted.get('tva_total'), None)
        comparison['tva_match'] = tva_extracted is None or tva_extracted == 0
    if not comparison['tva_match']:
        # Report the mismatch in both branches. A spurious TVA on a no-TVA
        # receipt used to print PASS while still failing the perfect-match count.
        errors.append(f"TVA: expected {expected['total_tva']}, got {extracted.get('tva_total')}")

    return comparison


def _analyze_problems(completed_results: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
    """Group mismatches of completed receipts by field for the JSON report.

    CUI issues are typed ``missing``, ``mismatch`` or ``digit_substitution``
    (same length, at most 2 differing characters). TVA issues are typed
    ``missing``, ``mismatch`` or ``low_rate_5pct`` (expected TVA is 4-6% of
    the extracted total). Receipts with confidence below 0.7 are listed
    under ``confidence_issues``.
    """
    problems = {
        'cui_issues': [],
        'tva_issues': [],
        'total_issues': [],
        'date_issues': [],
        'confidence_issues': [],
    }
    for r in completed_results:
        full_ext = r.get('full_extraction', {})

        # CUI issues
        if not r.get('cui_match'):
            cui_expected = normalize_cui(r.get('cui_expected'))
            cui_got = normalize_cui(r.get('cui_extracted'))
            issue_type = 'missing' if not cui_got else 'mismatch'
            # Same length with 1-2 differing characters suggests OCR digit confusion.
            if cui_expected and cui_got and len(cui_expected) == len(cui_got):
                diff_count = sum(a != b for a, b in zip(cui_expected, cui_got))
                if diff_count <= 2:
                    issue_type = 'digit_substitution'
            problems['cui_issues'].append({
                'file': r['filename'],
                'expected': r.get('cui_expected'),
                'got': r.get('cui_extracted'),
                'type': issue_type,
                'confidence': r.get('confidence', 0),
            })

        # TVA issues
        if not r.get('tva_match'):
            tva_expected = r.get('tva_expected', 0)
            tva_got = r.get('tva_extracted')
            issue_type = 'missing' if tva_got is None else 'mismatch'
            # Check for the 5% rate (books): expected TVA ~4-6% of the total.
            total = full_ext.get('amount')
            if tva_expected and tva_expected > 0 and total:
                try:
                    implied_rate = float(tva_expected) / float(total) * 100
                except (TypeError, ValueError, ZeroDivisionError):
                    implied_rate = None
                if implied_rate is not None and 4 <= implied_rate <= 6:
                    issue_type = 'low_rate_5pct'
            problems['tva_issues'].append({
                'file': r['filename'],
                'expected': tva_expected,
                'got': tva_got,
                'type': issue_type,
                'tva_entries': full_ext.get('tva_entries', []),
            })

        # TOTAL issues
        if not r.get('total_match'):
            problems['total_issues'].append({
                'file': r['filename'],
                'expected': r.get('total_expected'),
                'got': r.get('total_extracted'),
                'confidence': r.get('confidence', 0),
                'payment_methods': full_ext.get('payment_methods', []),
            })

        # DATE issues
        if not r.get('date_match'):
            problems['date_issues'].append({
                'file': r['filename'],
                'expected': r.get('date_expected'),
                'got': r.get('date_extracted'),
            })

        # Low confidence issues
        if r.get('confidence', 0) < 0.7:
            problems['confidence_issues'].append({
                'file': r['filename'],
                'confidence': r.get('confidence', 0),
                'errors': r.get('errors', []),
            })
    return problems


def main():
    """Validate OCR extraction for the expected receipts and write a JSON report.

    Each receipt in expected_receipts.json is processed if its PDF exists
    under docs/data-entry. Multi-page entries are included only with
    --include-multipage. The PDF is submitted to the running API, and the
    TOTAL, DATE, CUI and TVA fields are compared with the expected values.
    A summary is printed and ocr_report_<engine>_FULL.json is written next
    to this script.

    Exits with status 1 in any of these cases:
    - the requested receipt is not found;
    - no receipt completes;
    - --stop-on-issue triggers;
    - the TOTAL match rate is below 80%.
    """
    import traceback

    args = _parse_args()

    # Paths
    script_dir = Path(__file__).parent
    expected_path = script_dir / 'expected_receipts.json'
    pdf_base_path = script_dir.parent.parent / 'docs' / 'data-entry'

    # JWT secret from the environment. NOTE(review): the fallback looks like a
    # placeholder; confirm it matches the local API configuration.
    jwt_secret = os.getenv('JWT_SECRET_KEY', 'generate_with_secrets_token_urlsafe_32')
    token = create_test_token(jwt_secret)

    print(f"\n{'='*60}")
    print("OCR API VALIDATION")
    print(f"{'='*60}")
    print(f"Engine: {args.engine}")
    print(f"API Base: {args.api_base}")
    print(f"Expected data: {expected_path}")
    print(f"PDF folder: {pdf_base_path}")

    with open(expected_path, encoding='utf-8') as f:
        expected_data = json.load(f)

    receipts_to_test = expected_data['receipts']
    # Entries carrying a 'page' key belong to multi-page PDFs; skip them by default.
    if not args.include_multipage:
        original_count = len(receipts_to_test)
        receipts_to_test = [r for r in receipts_to_test if r.get('page') is None]
        skipped = original_count - len(receipts_to_test)
        if skipped > 0:
            print(f"Skipping {skipped} multi-page PDF entries (use --include-multipage to include)")

    if args.receipt:
        receipts_to_test = [r for r in receipts_to_test if r['id'] == args.receipt]
        if not receipts_to_test:
            print(f"\nError: Receipt ID '{args.receipt}' not found")
            sys.exit(1)

    print(f"Receipts to test: {len(receipts_to_test)}")
    print(f"{'='*60}\n")

    results: List[Dict[str, Any]] = []
    for expected in receipts_to_test:
        pdf_path = pdf_base_path / expected['filename']
        if not pdf_path.exists():
            print(f"[SKIP] File not found: {expected['filename']}")
            continue

        print(f"[TEST] Processing: {expected['filename']}")
        try:
            start_time = datetime.now()
            ocr_result = submit_ocr_job(args.api_base, token, pdf_path, args.engine)

            if ocr_result.get('status') == 'failed':
                print(f" [ERROR] OCR failed: {ocr_result.get('error')}")
                results.append({
                    'receipt_id': expected['id'],
                    'filename': expected['filename'],
                    'status': 'failed',
                    'error': ocr_result.get('error'),
                })
                continue

            # processing_time_ms may arrive as a number or a string. Fall back
            # to client-side elapsed time when it is missing or unparseable.
            processing_time = _safe_float(ocr_result.get('processing_time_ms'), None)
            if processing_time is None:
                processing_time = (datetime.now() - start_time).total_seconds() * 1000

            # 'or {}' also guards against an explicit null result.
            extracted = ocr_result.get('result') or {}
            comparison = _build_comparison(expected, extracted, args.engine, processing_time)
            results.append(comparison)

            # Timing breakdown collected by submit_ocr_job.
            t = ocr_result.get('timing', {})
            wall_ms = t.get('wall_time', 0) * 1000
            queue_wait_ms = t.get('queue_wait_ms', 0)
            ocr_time_ms = t.get('ocr_time_ms', 0)
            processing_time_ms = t.get('processing_time_ms', 0)
            # Overhead = wall time not spent inside the job (network, polling).
            overhead_ms = wall_ms - _safe_float(processing_time_ms) if processing_time_ms else 0

            status = 'PASS' if not comparison['errors'] else 'FAIL'
            print(f" [{status}] Total: {expected['total']} vs {extracted.get('amount')} ({comparison['total_match']})")
            print(f" Date: {expected['data_bon']} vs {extracted.get('receipt_date')} ({comparison['date_match']})")
            print(f" CUI: {expected['cui_furnizor']} vs {extracted.get('cui')} ({comparison['cui_match']})")
            print(f" TVA: {expected['total_tva']} vs {extracted.get('tva_total')} ({comparison['tva_match']})")
            print(f" Confidence: {comparison['confidence']*100:.1f}%")
            print(f" TIMING: ocr={ocr_time_ms}ms, queue_wait={queue_wait_ms}ms, "
                  f"job_total={processing_time_ms}ms, wall={wall_ms:.0f}ms")
            print(f" overhead={overhead_ms:.0f}ms (wall - job_total)")
            if comparison['errors'] and args.verbose:
                for err in comparison['errors']:
                    print(f" Error: {err}")

            if args.stop_on_issue:
                has_errors = bool(comparison['errors'])
                # Judge speed on actual OCR engine time, not queue wait.
                ocr_too_slow = _safe_float(ocr_time_ms) > _SLOW_OCR_MS
                if has_errors or ocr_too_slow:
                    print(f"\n{'='*60}")
                    print(f"STOP: Issue detected on {expected['filename']}")
                    print(f"{'='*60}")
                    if ocr_too_slow:
                        print(f" SLOW: ocr_time={ocr_time_ms}ms > {_SLOW_OCR_MS}ms threshold")
                    if has_errors:
                        print(f" ERRORS: {comparison['errors']}")
                    print("\n Full timing breakdown:")
                    print(f" ocr_time_ms: {ocr_time_ms}ms (actual OCR engine time)")
                    print(f" queue_wait_ms: {queue_wait_ms}ms (waiting in queue)")
                    print(f" processing_time_ms: {processing_time_ms}ms (job total)")
                    print(f" wall_time: {wall_ms:.0f}ms (client-side)")
                    print(f" overhead: {overhead_ms:.0f}ms (network + polling)")
                    print("\n Full extraction:")
                    print(json.dumps(extracted, indent=4, default=str))
                    sys.exit(1)
        except Exception as e:
            # Record the error for this receipt and continue with the next one.
            # sys.exit() raises SystemExit, which is not caught here.
            print(f" [ERROR] {str(e)}")
            if args.verbose:
                traceback.print_exc()
            results.append({
                'receipt_id': expected['id'],
                'filename': expected['filename'],
                'status': 'error',
                'error': str(e),
            })

    # Statistics are computed over completed receipts only.
    completed_results = [r for r in results if r.get('status') == 'completed']
    total_tests = len(completed_results)
    unfinished = len(results) - total_tests
    if total_tests == 0:
        print("\nNo tests completed successfully!")
        sys.exit(1)

    def _rate(key: str) -> float:
        """Fraction of completed receipts whose *key* flag is true."""
        return sum(1 for r in completed_results if r[key]) / total_tests

    perfect_matches = sum(
        1 for r in completed_results
        if r['total_match'] and r['date_match'] and r['cui_match'] and r['tva_match']
    )
    total_match_rate = _rate('total_match')
    date_match_rate = _rate('date_match')
    cui_match_rate = _rate('cui_match')
    tva_match_rate = _rate('tva_match')
    avg_confidence = sum(r['confidence'] for r in completed_results) / total_tests
    avg_processing_time = sum(r['processing_time_ms'] for r in completed_results) / total_tests

    print(f"\n{'='*60}")
    print("OCR VALIDATION SUMMARY")
    print(f"{'='*60}")
    print(f"Total Receipts Tested: {total_tests}")
    if unfinished:
        # Failed/errored jobs are excluded from the rates below; make that visible.
        print(f"Failed/Errored Receipts: {unfinished} (excluded from rates)")
    print(f"Perfect Matches: {perfect_matches} ({perfect_matches/total_tests*100:.1f}%)")
    print("---")
    print(f"Total Amount Match Rate: {total_match_rate*100:.1f}%")
    print(f"Date Match Rate: {date_match_rate*100:.1f}%")
    print(f"CUI Match Rate: {cui_match_rate*100:.1f}%")
    print(f"TVA Match Rate: {tva_match_rate*100:.1f}%")
    print("---")
    print(f"Average Confidence: {avg_confidence*100:.1f}%")
    print(f"Average Processing Time: {avg_processing_time:.0f}ms")
    print(f"{'='*60}")

    failed_results = [r for r in completed_results if r.get('errors')]
    if failed_results:
        print(f"\nFAILED RECEIPTS ({len(failed_results)}):")
        for r in failed_results:
            print(f" - {r['filename']}: {'; '.join(r['errors'])}")

    report = {
        'test_date': datetime.now().isoformat(),
        'engine': args.engine,
        'statistics': {
            'total_tests': total_tests,
            'failed_jobs': unfinished,
            'perfect_matches': perfect_matches,
            'perfect_match_rate': perfect_matches / total_tests,
            'total_match_rate': total_match_rate,
            'date_match_rate': date_match_rate,
            'cui_match_rate': cui_match_rate,
            'tva_match_rate': tva_match_rate,
            'avg_confidence': avg_confidence,
            'avg_processing_time_ms': avg_processing_time,
        },
        'problems_analysis': _analyze_problems(completed_results),
        'failed_receipts': [
            {'filename': r['filename'], 'errors': r['errors']}
            for r in failed_results
        ],
        'detailed_results': results,
    }

    # The engine name is part of the report filename.
    report_path = script_dir / f'ocr_report_{args.engine.replace("-", "_")}_FULL.json'
    with open(report_path, 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2, default=str)
    print(f"\nReport saved to: {report_path}")

    # Only the TOTAL match rate gates the exit status.
    if total_match_rate < 0.8:
        print(f"\n[FAIL] Total match rate {total_match_rate*100:.1f}% is below 80% threshold")
        sys.exit(1)

    print("\n[PASS] OCR validation completed successfully!")
# Entry point: run the validation only when executed as a script, not on import.
if __name__ == "__main__":
    main()