#!/usr/bin/env python3
"""
OCR Direct Validation Tests

This script validates the OCR extraction accuracy by:
1. Generating a test JWT token
2. Calling the OCR API endpoint with PDF receipts
3. Comparing extracted data with expected values from expected_receipts.json

Run:
    python tests/ocr-validation/ocr-direct-validation.py
    python tests/ocr-validation/ocr-direct-validation.py --engine doctr_plus
    python tests/ocr-validation/ocr-direct-validation.py --receipt receipt_01
"""

import sys
import os
import json
import argparse
import time
import requests
from pathlib import Path
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, List, Any

# Add backend and project root to path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
sys.path.insert(0, str(project_root / 'backend'))

# Import JWT handler to create test tokens
from jose import jwt

# OCR engine time (ms) above which --stop-on-issue treats a receipt as too slow.
OCR_SLOW_THRESHOLD_MS = 10000


def create_test_token(secret_key: str) -> str:
    """Create a test JWT token for API authentication.

    The token is signed with HS256 using ``secret_key`` and expires one hour
    after creation.
    """
    # Timezone-aware "now": datetime.utcnow() is deprecated since Python 3.12.
    now = datetime.now(timezone.utc)
    expire = now + timedelta(hours=1)

    payload = {
        "username": "ocr_test_user",
        "user_id": 999,
        "companies": ["TEST"],
        "permissions": ["read", "write"],
        "exp": expire,
        "iat": now,
        "type": "access"
    }

    return jwt.encode(payload, secret_key, algorithm="HS256")


def normalize_cui(cui: Optional[str]) -> Optional[str]:
    """Normalize a CUI by upper-casing it and removing whitespace and a leading RO.

    Returns None for empty/None input. Non-string values (e.g. a numeric CUI
    loaded from JSON) are converted with ``str()`` first.
    """
    if not cui:
        return None
    normalized = ''.join(str(cui).upper().split())
    # Only strip "RO" as a prefix; replacing it anywhere could damage the
    # remainder of a misread value.
    if normalized.startswith('RO'):
        normalized = normalized[2:]
    return normalized


def normalize_date(date: Optional[str]) -> Optional[str]:
    """Normalize an ISO-format date/datetime string to YYYY-MM-DD.

    Returns None for empty/None input. If the value cannot be parsed as ISO
    format it is returned unchanged, so the caller compares the raw value.
    """
    if not date:
        return None
    try:
        # fromisoformat() on older Pythons rejects a trailing "Z".
        parsed = datetime.fromisoformat(date.replace('Z', '+00:00'))
        return parsed.strftime('%Y-%m-%d')
    except (AttributeError, TypeError, ValueError):
        return date


def compare_with_tolerance(expected: float, actual, tolerance: float) -> bool:
    """Compare numbers with a relative tolerance.

    ``actual`` may be a number or a numeric string (as returned by the API).
    Returns True when the absolute difference is within
    ``abs(expected) * tolerance`` or within 0.01 (one cent); False when
    ``actual`` is None or not convertible to float.
    """
    if actual is None:
        return False
    # Handle string values from API
    try:
        actual_float = float(actual)
    except (ValueError, TypeError):
        return False
    diff = abs(expected - actual_float)
    # abs() keeps the relative threshold meaningful for negative amounts.
    threshold = abs(expected) * tolerance
    return diff <= threshold or diff <= 0.01  # Within tolerance or 1 cent


def _safe_float(value, default=0.0):
    """Convert ``value`` to float, returning ``default`` if it is None or invalid."""
    if value is None:
        return default
    try:
        return float(value)
    except (ValueError, TypeError):
        return default


def submit_ocr_job(api_base: str, token: str, pdf_path: Path, engine: str) -> dict:
    """Submit a PDF file for OCR processing and wait for result.

    Returns a dict with ``status`` ('completed' or 'failed'), either
    ``result``/``processing_time_ms`` or ``error``, and detailed timing
    information:
    - timing.submit_duration: time to submit job and get job_id
    - timing.poll_count: number of poll requests made
    - timing.poll_duration: total time spent polling
    - timing.wall_time: total elapsed time (submit + poll)
    - timing.api_reported_ms: processing time reported by API

    On completion the API-reported ``queue_wait_ms``, ``ocr_time_ms`` and
    ``processing_time_ms`` are added to ``timing`` as well.
    """
    headers = {'Authorization': f'Bearer {token}'}

    # Timing tracking
    timing = {
        'submit_duration': 0.0,
        'poll_count': 0,
        'poll_duration': 0.0,
        'wall_time': 0.0,
        'api_reported_ms': 0,
    }

    wall_start = time.time()

    # Submit job
    submit_start = time.time()
    with open(pdf_path, 'rb') as f:
        files = {'file': (pdf_path.name, f, 'application/pdf')}
        response = requests.post(
            f'{api_base}/api/data-entry/ocr/extract?engine={engine}',
            headers=headers,
            files=files,
            timeout=60
        )
    timing['submit_duration'] = time.time() - submit_start

    if not response.ok:
        timing['wall_time'] = time.time() - wall_start
        return {
            'status': 'failed',
            'error': f'Submit failed: {response.status_code} - {response.text}',
            'timing': timing,
        }

    job = response.json()
    job_id = job.get('job_id')

    if not job_id:
        timing['wall_time'] = time.time() - wall_start
        return {'status': 'failed', 'error': 'No job_id in response', 'timing': timing}

    # Poll for result
    poll_start = time.time()
    max_wait = 120  # 2 minutes

    while time.time() - wall_start < max_wait:
        timing['poll_count'] += 1
        poll_response = requests.get(
            f'{api_base}/api/data-entry/ocr/jobs/{job_id}/wait?timeout=30',
            headers=headers,
            timeout=35
        )

        if not poll_response.ok:
            # Back off briefly before retrying a failed poll.
            time.sleep(1)
            continue

        job_status = poll_response.json()
        state = job_status.get('status')

        if state == 'completed':
            timing['poll_duration'] = time.time() - poll_start
            timing['wall_time'] = time.time() - wall_start
            # Detailed timing from API
            timing['queue_wait_ms'] = job_status.get('queue_wait_ms', 0) or 0
            timing['ocr_time_ms'] = job_status.get('ocr_time_ms', 0) or 0
            timing['processing_time_ms'] = job_status.get('processing_time_ms', 0) or 0
            return {
                'status': 'completed',
                'result': job_status.get('result', {}),
                'processing_time_ms': job_status.get('processing_time_ms', 0),
                'timing': timing
            }

        if state == 'failed':
            timing['poll_duration'] = time.time() - poll_start
            timing['wall_time'] = time.time() - wall_start
            return {
                'status': 'failed',
                'error': job_status.get('error', 'Unknown error'),
                'timing': timing,
            }

        # Still pending - show status but don't spam
        if timing['poll_count'] <= 3 or timing['poll_count'] % 5 == 0:
            print(f" Status: {job_status.get('status')}, position: {job_status.get('queue_position')}, polls: {timing['poll_count']}")

    timing['poll_duration'] = time.time() - poll_start
    timing['wall_time'] = time.time() - wall_start
    return {'status': 'failed', 'error': 'Timeout waiting for OCR result', 'timing': timing}


def _compare_receipt(expected: Dict[str, Any], extracted: Dict[str, Any],
                     processing_time: float, engine: str) -> Dict[str, Any]:
    """Build the comparison record for one receipt.

    Compares total (2% tolerance), date, CUI and TVA (5% tolerance) between
    the expected entry and the extracted OCR result, and collects a
    human-readable message in ``errors`` for each mismatch.
    """
    comparison: Dict[str, Any] = {
        'receipt_id': expected['id'],
        'filename': expected['filename'],
        'status': 'completed',
        'total_expected': expected['total'],
        'total_extracted': _safe_float(extracted.get('amount'), None),
        'total_match': False,
        'date_expected': expected['data_bon'],
        'date_extracted': extracted.get('receipt_date'),
        'date_match': False,
        'cui_expected': expected['cui_furnizor'],
        'cui_extracted': extracted.get('cui'),
        'cui_match': False,
        'tva_expected': expected['total_tva'],
        'tva_extracted': _safe_float(extracted.get('tva_total'), None),
        'tva_match': False,
        'confidence': _safe_float(extracted.get('overall_confidence'), 0),
        'processing_time_ms': processing_time,
        'ocr_engine': extracted.get('ocr_engine', engine),
        'errors': [],
        # Save full extraction for analysis
        'full_extraction': {
            'amount': extracted.get('amount'),
            'receipt_date': extracted.get('receipt_date'),
            'cui': extracted.get('cui'),
            'tva_total': extracted.get('tva_total'),
            'tva_entries': extracted.get('tva_entries', []),
            'supplier_name': extracted.get('supplier_name'),
            'receipt_number': extracted.get('receipt_number'),
            'payment_methods': extracted.get('payment_methods', []),
            'items_count': extracted.get('items_count'),
            'overall_confidence': extracted.get('overall_confidence'),
            'confidence_amount': extracted.get('confidence_amount'),
            'confidence_date': extracted.get('confidence_date'),
            'confidence_cui': extracted.get('confidence_cui'),
        },
        # Save raw OCR texts from each engine pass
        'raw_texts': extracted.get('raw_texts', []),
    }
    errors = comparison['errors']

    # Compare TOTAL
    comparison['total_match'] = compare_with_tolerance(
        expected['total'],
        extracted.get('amount'),
        0.02  # 2% tolerance
    )
    if not comparison['total_match']:
        errors.append(
            f"TOTAL: expected {expected['total']}, got {extracted.get('amount')}"
        )

    # Compare DATE
    normalized_expected_date = normalize_date(expected['data_bon'])
    normalized_extracted_date = normalize_date(extracted.get('receipt_date'))
    comparison['date_match'] = normalized_expected_date == normalized_extracted_date
    if not comparison['date_match']:
        errors.append(
            f"DATE: expected {normalized_expected_date}, got {normalized_extracted_date}"
        )

    # Compare CUI
    normalized_expected_cui = normalize_cui(expected['cui_furnizor'])
    normalized_extracted_cui = normalize_cui(extracted.get('cui'))
    comparison['cui_match'] = normalized_expected_cui == normalized_extracted_cui
    if not comparison['cui_match']:
        errors.append(
            f"CUI: expected {normalized_expected_cui}, got {normalized_extracted_cui}"
        )

    # Compare TVA
    if expected['total_tva'] > 0:
        comparison['tva_match'] = compare_with_tolerance(
            expected['total_tva'],
            extracted.get('tva_total'),
            0.05  # 5% tolerance
        )
        if not comparison['tva_match']:
            errors.append(
                f"TVA: expected {expected['total_tva']}, got {extracted.get('tva_total')}"
            )
    else:
        # No TVA expected (neplatitor TVA): a missing or zero value matches.
        # A mismatch here is recorded in tva_match only, not in errors.
        tva_extracted = _safe_float(extracted.get('tva_total'), None)
        comparison['tva_match'] = tva_extracted is None or tva_extracted == 0

    return comparison


def _analyze_problems(completed_results: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
    """Categorize mismatches of completed receipts by field for later analysis."""
    problems_analysis: Dict[str, List[Dict[str, Any]]] = {
        'cui_issues': [],
        'tva_issues': [],
        'total_issues': [],
        'date_issues': [],
        'confidence_issues': [],
    }

    for r in completed_results:
        # CUI issues
        if not r.get('cui_match'):
            cui_expected = normalize_cui(r.get('cui_expected'))
            cui_got = normalize_cui(r.get('cui_extracted'))
            issue_type = 'missing' if not cui_got else 'mismatch'
            # Check if it's a digit substitution (same length, 1-2 chars different)
            if cui_expected and cui_got and len(cui_expected) == len(cui_got):
                diff_count = sum(1 for a, b in zip(cui_expected, cui_got) if a != b)
                if diff_count <= 2:
                    issue_type = 'digit_substitution'
            problems_analysis['cui_issues'].append({
                'file': r['filename'],
                'expected': r.get('cui_expected'),
                'got': r.get('cui_extracted'),
                'type': issue_type,
                'confidence': r.get('confidence', 0),
            })

        # TVA issues
        if not r.get('tva_match'):
            tva_expected = r.get('tva_expected', 0)
            tva_got = r.get('tva_extracted')
            issue_type = 'missing' if tva_got is None else 'mismatch'
            # Check for 5% rate (books)
            if tva_expected and tva_expected > 0:
                total = r.get('full_extraction', {}).get('amount')
                if total:
                    try:
                        implied_rate = float(tva_expected) / float(total) * 100
                        if 4 <= implied_rate <= 6:
                            issue_type = 'low_rate_5pct'
                    except (ValueError, TypeError, ZeroDivisionError):
                        # Unparseable or zero total: keep the generic issue type.
                        pass
            problems_analysis['tva_issues'].append({
                'file': r['filename'],
                'expected': tva_expected,
                'got': tva_got,
                'type': issue_type,
                'tva_entries': r.get('full_extraction', {}).get('tva_entries', []),
            })

        # TOTAL issues
        if not r.get('total_match'):
            problems_analysis['total_issues'].append({
                'file': r['filename'],
                'expected': r.get('total_expected'),
                'got': r.get('total_extracted'),
                'confidence': r.get('confidence', 0),
                'payment_methods': r.get('full_extraction', {}).get('payment_methods', []),
            })

        # DATE issues
        if not r.get('date_match'):
            problems_analysis['date_issues'].append({
                'file': r['filename'],
                'expected': r.get('date_expected'),
                'got': r.get('date_extracted'),
            })

        # Low confidence issues
        if r.get('confidence', 0) < 0.7:
            problems_analysis['confidence_issues'].append({
                'file': r['filename'],
                'confidence': r.get('confidence', 0),
                'errors': r.get('errors', []),
            })

    return problems_analysis


def main():
    """Run the OCR validation from the command line.

    Loads expected_receipts.json next to this script, submits each receipt PDF
    to the OCR API, prints per-receipt and summary results, writes a JSON
    report next to this script, and exits with status 1 when no receipt
    completed, when --stop-on-issue triggers, or when the total-amount match
    rate is below 80%.
    """
    parser = argparse.ArgumentParser(description='OCR Direct Validation')
    parser.add_argument('--engine', default='doctr_plus',
                        choices=['tesseract', 'doctr', 'doctr_plus', 'paddleocr'],
                        help='OCR engine to use (doctr_plus recommended)')
    parser.add_argument('--receipt', help='Specific receipt ID to test (e.g., receipt_01)')
    parser.add_argument('--verbose', '-v', action='store_true', help='Verbose output')
    parser.add_argument('--api-base', default='http://localhost:8000', help='API base URL')
    # Help text matches the check actually performed below (OCR time, 10s).
    parser.add_argument('--stop-on-issue', action='store_true',
                        help='Stop at first receipt with ocr_time > 10s or extraction errors')
    parser.add_argument('--include-multipage', action='store_true',
                        help='Include multi-page PDFs (normally skipped)')
    args = parser.parse_args()

    # Paths
    script_dir = Path(__file__).parent
    expected_path = script_dir / 'expected_receipts.json'
    pdf_base_path = script_dir.parent.parent / 'tests' / 'fixtures' / 'ocr-samples'

    # JWT secret from environment or default
    jwt_secret = os.getenv('JWT_SECRET_KEY', 'generate_with_secrets_token_urlsafe_32')

    # Create test token
    token = create_test_token(jwt_secret)

    # Load expected data
    print(f"\n{'='*60}")
    print("OCR API VALIDATION")
    print(f"{'='*60}")
    print(f"Engine: {args.engine}")
    print(f"API Base: {args.api_base}")
    print(f"Expected data: {expected_path}")
    print(f"PDF folder: {pdf_base_path}")

    with open(expected_path, encoding='utf-8') as f:
        expected_data = json.load(f)

    # Filter receipts
    receipts_to_test = expected_data['receipts']

    # Skip multi-page PDFs unless explicitly included
    if not args.include_multipage:
        original_count = len(receipts_to_test)
        receipts_to_test = [r for r in receipts_to_test if r.get('page') is None]
        skipped = original_count - len(receipts_to_test)
        if skipped > 0:
            print(f"Skipping {skipped} multi-page PDF entries (use --include-multipage to include)")

    # Filter by specific receipt ID if requested
    if args.receipt:
        receipts_to_test = [r for r in receipts_to_test if r['id'] == args.receipt]
        if not receipts_to_test:
            print(f"\nError: Receipt ID '{args.receipt}' not found")
            sys.exit(1)

    print(f"Receipts to test: {len(receipts_to_test)}")
    print(f"{'='*60}\n")

    # Results storage
    results: List[Dict[str, Any]] = []

    # Test each receipt
    for expected in receipts_to_test:
        pdf_path = pdf_base_path / expected['filename']

        if not pdf_path.exists():
            print(f"[SKIP] File not found: {expected['filename']}")
            continue

        print(f"[TEST] Processing: {expected['filename']}")

        try:
            # Submit OCR job via API
            start_time = datetime.now()
            ocr_result = submit_ocr_job(args.api_base, token, pdf_path, args.engine)

            if ocr_result.get('status') == 'failed':
                print(f" [ERROR] OCR failed: {ocr_result.get('error')}")
                results.append({
                    'receipt_id': expected['id'],
                    'filename': expected['filename'],
                    'status': 'failed',
                    'error': ocr_result.get('error'),
                })
                continue

            # processing_time_ms may be a string, a number, missing or junk;
            # fall back to the client-side elapsed time when it is unusable.
            processing_time = _safe_float(ocr_result.get('processing_time_ms'), None)
            if processing_time is None:
                processing_time = (datetime.now() - start_time).total_seconds() * 1000

            # Get extracted values ("result": null from the API becomes {})
            extracted = ocr_result.get('result') or {}

            comparison = _compare_receipt(expected, extracted, processing_time, args.engine)
            results.append(comparison)

            # Get timing info from API (detailed breakdown)
            t = ocr_result.get('timing', {})
            wall_ms = t.get('wall_time', 0) * 1000
            queue_wait_ms = t.get('queue_wait_ms', 0)
            ocr_time_ms = t.get('ocr_time_ms', 0)
            processing_time_ms = t.get('processing_time_ms', 0)
            # Overhead = wall_time - processing_time (includes network, polling)
            overhead_ms = wall_ms - processing_time_ms if processing_time_ms else 0

            # Print result
            status = 'PASS' if not comparison['errors'] else 'FAIL'
            print(f" [{status}] Total: {expected['total']} vs {extracted.get('amount')} ({comparison['total_match']})")
            print(f" Date: {expected['data_bon']} vs {extracted.get('receipt_date')} ({comparison['date_match']})")
            print(f" CUI: {expected['cui_furnizor']} vs {extracted.get('cui')} ({comparison['cui_match']})")
            print(f" TVA: {expected['total_tva']} vs {extracted.get('tva_total')} ({comparison['tva_match']})")
            print(f" Confidence: {comparison['confidence']*100:.1f}%")
            # Print detailed timing breakdown
            print(f" TIMING: ocr={ocr_time_ms}ms, queue_wait={queue_wait_ms}ms, "
                  f"job_total={processing_time_ms}ms, wall={wall_ms:.0f}ms")
            print(f" overhead={overhead_ms:.0f}ms (wall - job_total)")

            if comparison['errors'] and args.verbose:
                for err in comparison['errors']:
                    print(f" Error: {err}")

            # Stop on issue if requested
            if args.stop_on_issue:
                has_errors = len(comparison['errors']) > 0
                # Use OCR time for threshold (actual processing, not queue wait)
                ocr_too_slow = ocr_time_ms > OCR_SLOW_THRESHOLD_MS

                if has_errors or ocr_too_slow:
                    print(f"\n{'='*60}")
                    print(f"STOP: Issue detected on {expected['filename']}")
                    print(f"{'='*60}")
                    if ocr_too_slow:
                        print(f" SLOW: ocr_time={ocr_time_ms}ms > {OCR_SLOW_THRESHOLD_MS}ms threshold")
                    if has_errors:
                        print(f" ERRORS: {comparison['errors']}")
                    print("\n Full timing breakdown:")
                    print(f" ocr_time_ms: {ocr_time_ms}ms (actual OCR engine time)")
                    print(f" queue_wait_ms: {queue_wait_ms}ms (waiting in queue)")
                    print(f" processing_time_ms: {processing_time_ms}ms (job total)")
                    print(f" wall_time: {wall_ms:.0f}ms (client-side)")
                    print(f" overhead: {overhead_ms:.0f}ms (network + polling)")
                    print("\n Full extraction:")
                    print(json.dumps(extracted, indent=4, default=str))
                    # SystemExit is not an Exception subclass, so the handler
                    # below does not swallow it.
                    sys.exit(1)

        except Exception as e:
            # Per-receipt boundary: record the failure and keep testing the rest.
            import traceback
            print(f" [ERROR] {str(e)}")
            if args.verbose:
                traceback.print_exc()
            results.append({
                'receipt_id': expected['id'],
                'filename': expected['filename'],
                'status': 'error',
                'error': str(e),
            })

    # Calculate statistics
    completed_results = [r for r in results if r.get('status') == 'completed']
    total_tests = len(completed_results)

    if total_tests == 0:
        print("\nNo tests completed successfully!")
        sys.exit(1)

    def match_rate(key: str) -> float:
        """Fraction of completed receipts whose ``key`` flag is truthy."""
        return sum(1 for r in completed_results if r[key]) / total_tests

    perfect_matches = sum(
        1 for r in completed_results
        if r['total_match'] and r['date_match'] and r['cui_match'] and r['tva_match']
    )
    total_match_rate = match_rate('total_match')
    date_match_rate = match_rate('date_match')
    cui_match_rate = match_rate('cui_match')
    tva_match_rate = match_rate('tva_match')
    avg_confidence = sum(r['confidence'] for r in completed_results) / total_tests
    avg_processing_time = sum(r['processing_time_ms'] for r in completed_results) / total_tests

    # Print summary
    print(f"\n{'='*60}")
    print("OCR VALIDATION SUMMARY")
    print(f"{'='*60}")
    print(f"Total Receipts Tested: {total_tests}")
    print(f"Perfect Matches: {perfect_matches} ({perfect_matches/total_tests*100:.1f}%)")
    print("---")
    print(f"Total Amount Match Rate: {total_match_rate*100:.1f}%")
    print(f"Date Match Rate: {date_match_rate*100:.1f}%")
    print(f"CUI Match Rate: {cui_match_rate*100:.1f}%")
    print(f"TVA Match Rate: {tva_match_rate*100:.1f}%")
    print("---")
    print(f"Average Confidence: {avg_confidence*100:.1f}%")
    print(f"Average Processing Time: {avg_processing_time:.0f}ms")
    print(f"{'='*60}")

    # Failed receipts
    failed_results = [r for r in completed_results if r.get('errors')]
    if failed_results:
        print(f"\nFAILED RECEIPTS ({len(failed_results)}):")
        for r in failed_results:
            print(f" - {r['filename']}: {'; '.join(r['errors'])}")

    # Categorize problems for analysis
    problems_analysis = _analyze_problems(completed_results)

    # Save detailed report
    report = {
        'test_date': datetime.now().isoformat(),
        'engine': args.engine,
        'statistics': {
            'total_tests': total_tests,
            'perfect_matches': perfect_matches,
            'perfect_match_rate': perfect_matches / total_tests,
            'total_match_rate': total_match_rate,
            'date_match_rate': date_match_rate,
            'cui_match_rate': cui_match_rate,
            'tva_match_rate': tva_match_rate,
            'avg_confidence': avg_confidence,
            'avg_processing_time_ms': avg_processing_time,
        },
        'problems_analysis': problems_analysis,
        'failed_receipts': [
            {'filename': r['filename'], 'errors': r['errors']}
            for r in failed_results
        ],
        'detailed_results': results,
    }

    # Save report with engine name in filename
    report_path = script_dir / f'ocr_report_{args.engine.replace("-", "_")}_FULL.json'
    with open(report_path, 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2, default=str)
    print(f"\nReport saved to: {report_path}")

    # Exit with error if match rates are below threshold
    if total_match_rate < 0.8:
        print(f"\n[FAIL] Total match rate {total_match_rate*100:.1f}% is below 80% threshold")
        sys.exit(1)

    print("\n[PASS] OCR validation completed successfully!")


if __name__ == '__main__':
    main()