feat(ocr): Add docTR OCR engine with metrics infrastructure

Add docTR as primary OCR engine with 2-tier sequential processing,
OCR metrics tracking, and simplified engine selection.

Features:
- docTR OCR engine with light+medium preprocessing tiers
- doctr_plus mode with early exit optimization (~65% fast path)
- OCR metrics dashboard with per-engine statistics
- User OCR preference persistence
- Parallel worker pool for OCR processing
- Cross-validation for extraction quality

Engine options: tesseract, doctr, doctr_plus (recommended), paddleocr

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-02 05:37:16 +02:00
parent 74f7aefc26
commit 495790411f
75 changed files with 23349 additions and 1311 deletions

View File

@@ -13,13 +13,14 @@ Schema:
status TEXT NOT NULL, -- pending, processing, completed, failed
file_path TEXT NOT NULL, -- Path to uploaded file
mime_type TEXT NOT NULL,
engine TEXT DEFAULT 'auto',
engine TEXT DEFAULT 'doctr_plus',
created_at TIMESTAMP,
started_at TIMESTAMP,
completed_at TIMESTAMP,
result_json TEXT, -- JSON extraction result
error_message TEXT,
processing_time_ms INTEGER,
processing_time_ms INTEGER, -- Total job time (started_at to completed_at)
ocr_time_ms INTEGER, -- Actual OCR engine processing time
created_by TEXT, -- Username
original_filename TEXT,
expires_at TIMESTAMP
@@ -74,17 +75,26 @@ class OCRJob:
status: OCRJobStatus
file_path: str
mime_type: str
engine: str = "auto"
engine: str = "doctr_plus"
created_at: Optional[datetime] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
result_json: Optional[str] = None
error_message: Optional[str] = None
processing_time_ms: Optional[int] = None
processing_time_ms: Optional[int] = None # Total job time (started_at to completed_at)
ocr_time_ms: Optional[int] = None # Actual OCR engine processing time
created_by: Optional[str] = None
original_filename: Optional[str] = None
expires_at: Optional[datetime] = None
@property
def queue_wait_ms(self) -> Optional[int]:
"""Calculate queue wait time (created_at to started_at)."""
if self.created_at and self.started_at:
delta = self.started_at - self.created_at
return int(delta.total_seconds() * 1000)
return None
@property
def result(self) -> Optional[Dict]:
"""Parse result_json to dict."""
@@ -143,19 +153,27 @@ class OCRJobQueue:
status TEXT NOT NULL DEFAULT 'pending',
file_path TEXT NOT NULL,
mime_type TEXT NOT NULL,
engine TEXT DEFAULT 'auto',
engine TEXT DEFAULT 'doctr_plus',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
started_at TIMESTAMP,
completed_at TIMESTAMP,
result_json TEXT,
error_message TEXT,
processing_time_ms INTEGER,
ocr_time_ms INTEGER,
created_by TEXT,
original_filename TEXT,
expires_at TIMESTAMP
)
''')
# Migration: add ocr_time_ms column if it doesn't exist
try:
await db.execute('ALTER TABLE ocr_jobs ADD COLUMN ocr_time_ms INTEGER')
logger.info("[OCRJobQueue] Added ocr_time_ms column to existing table")
except Exception:
pass # Column already exists
# Index for efficient queue queries
await db.execute('''
CREATE INDEX IF NOT EXISTS idx_ocr_jobs_status
@@ -177,7 +195,7 @@ class OCRJobQueue:
self,
file_bytes: bytes,
mime_type: str,
engine: str = "auto",
engine: str = "doctr_plus",
username: Optional[str] = None,
original_filename: Optional[str] = None
) -> OCRJob:
@@ -189,7 +207,7 @@ class OCRJobQueue:
Args:
file_bytes: Raw file bytes
mime_type: MIME type of file
engine: OCR engine ('auto', 'paddleocr', 'tesseract')
engine: OCR engine ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
username: Username of requester
original_filename: Original filename from upload
@@ -301,24 +319,52 @@ class OCRJobQueue:
async def get_next_pending(self) -> Optional[OCRJob]:
"""
Get the next pending job (oldest first).
Get the next pending job (oldest first) and atomically mark it as processing.
This prevents race conditions in parallel processing - only one worker
can claim each job.
Returns:
Next OCRJob to process or None if queue empty
"""
await self.initialize()
async with aiosqlite.connect(str(self.db_path)) as db:
db.row_factory = aiosqlite.Row
async with db.execute('''
SELECT * FROM ocr_jobs
WHERE status = 'pending'
ORDER BY created_at ASC
LIMIT 1
''') as cursor:
row = await cursor.fetchone()
if row:
return self._row_to_job(row)
now = datetime.utcnow()
async with self._lock: # Serialize access to prevent race conditions
async with aiosqlite.connect(str(self.db_path)) as db:
db.row_factory = aiosqlite.Row
# Get the next pending job
async with db.execute('''
SELECT * FROM ocr_jobs
WHERE status = 'pending'
ORDER BY created_at ASC
LIMIT 1
''') as cursor:
row = await cursor.fetchone()
if not row:
return None
job_id = row['id']
# Atomically mark as processing
await db.execute('''
UPDATE ocr_jobs
SET status = 'processing', started_at = ?
WHERE id = ? AND status = 'pending'
''', (now.isoformat(), job_id))
await db.commit()
# Fetch the updated job
async with db.execute(
'SELECT * FROM ocr_jobs WHERE id = ?',
(job_id,)
) as cursor:
updated_row = await cursor.fetchone()
if updated_row:
return self._row_to_job(updated_row)
return None
async def update_status(
@@ -327,7 +373,8 @@ class OCRJobQueue:
status: OCRJobStatus,
result: Optional[Dict] = None,
error: Optional[str] = None,
processing_time_ms: Optional[int] = None
processing_time_ms: Optional[int] = None,
ocr_time_ms: Optional[int] = None
) -> bool:
"""
Update job status.
@@ -337,7 +384,8 @@ class OCRJobQueue:
status: New status
result: Extraction result dict (for completed)
error: Error message (for failed)
processing_time_ms: Processing time
processing_time_ms: Total job processing time (started_at to completed_at)
ocr_time_ms: Actual OCR engine processing time
Returns:
True if update successful
@@ -359,18 +407,18 @@ class OCRJobQueue:
elif status == OCRJobStatus.completed:
query = '''
UPDATE ocr_jobs
SET status = ?, completed_at = ?, result_json = ?, processing_time_ms = ?
SET status = ?, completed_at = ?, result_json = ?, processing_time_ms = ?, ocr_time_ms = ?
WHERE id = ?
'''
params = (status.value, now.isoformat(), result_json, processing_time_ms, job_id)
params = (status.value, now.isoformat(), result_json, processing_time_ms, ocr_time_ms, job_id)
elif status == OCRJobStatus.failed:
query = '''
UPDATE ocr_jobs
SET status = ?, completed_at = ?, error_message = ?, processing_time_ms = ?
SET status = ?, completed_at = ?, error_message = ?, processing_time_ms = ?, ocr_time_ms = ?
WHERE id = ?
'''
params = (status.value, now.isoformat(), error, processing_time_ms, job_id)
params = (status.value, now.isoformat(), error, processing_time_ms, ocr_time_ms, job_id)
else:
query = 'UPDATE ocr_jobs SET status = ? WHERE id = ?'
@@ -542,13 +590,14 @@ class OCRJobQueue:
status=OCRJobStatus(row['status']),
file_path=row['file_path'],
mime_type=row['mime_type'],
engine=row['engine'] or 'auto',
engine=row['engine'] or 'doctr_plus',
created_at=parse_datetime(row['created_at']),
started_at=parse_datetime(row['started_at']),
completed_at=parse_datetime(row['completed_at']),
result_json=row['result_json'],
error_message=row['error_message'],
processing_time_ms=row['processing_time_ms'],
ocr_time_ms=row['ocr_time_ms'] if 'ocr_time_ms' in row.keys() else None,
created_by=row['created_by'],
original_filename=row['original_filename'],
expires_at=parse_datetime(row['expires_at']),

View File

@@ -2,7 +2,7 @@
OCR Job Worker - Background Task for Queue Processing
Runs as an asyncio background task in FastAPI.
Continuously polls the job queue and processes OCR requests.
Continuously polls the job queue and processes OCR requests IN PARALLEL.
Architecture:
FastAPI startup
@@ -12,18 +12,19 @@ Architecture:
asyncio.create_task(_job_worker_loop())
while True:
job = job_queue.get_next_pending()
if job:
result = ocr_worker_pool.submit_task(...)
job_queue.update_status(...)
await asyncio.sleep(0.5)
# Process up to OCR_WORKERS jobs concurrently
jobs = get_pending_jobs(limit=available_slots)
for job in jobs:
asyncio.create_task(_process_job(job))
await asyncio.sleep(0.1)
"""
import asyncio
import logging
import os
import time
from pathlib import Path
from typing import Optional
from typing import Optional, Set
from .job_queue import job_queue, OCRJobStatus, OCRJob
from .ocr_worker_pool import ocr_worker_pool
@@ -34,47 +35,76 @@ logger = logging.getLogger(__name__)
_job_worker_task: Optional[asyncio.Task] = None
_cleanup_task: Optional[asyncio.Task] = None
_shutdown_event: Optional[asyncio.Event] = None
_active_tasks: Set[asyncio.Task] = set() # Track active job tasks
_concurrency_semaphore: Optional[asyncio.Semaphore] = None # Limit concurrent jobs
# Configuration
POLL_INTERVAL_SECONDS = 0.5 # How often to check for new jobs
POLL_INTERVAL_SECONDS = 0.1 # How often to check for new jobs (faster for parallel)
CLEANUP_INTERVAL_SECONDS = 3600 # Clean expired jobs every hour
OCR_TIMEOUT_SECONDS = 120 # Max time for OCR processing
async def _job_worker_loop() -> None:
"""
Main worker loop - processes jobs from queue.
Main worker loop - processes jobs from queue IN PARALLEL.
Runs continuously until shutdown. Polls queue every 0.5s
and submits jobs to worker pool for processing.
Runs continuously until shutdown. Uses semaphore to limit
concurrent jobs to OCR_WORKERS count. Launches jobs as
background tasks without waiting for completion.
"""
global _shutdown_event
global _shutdown_event, _active_tasks, _concurrency_semaphore
logger.info("[JobWorker] Starting worker loop...")
# Get max concurrent jobs from env (matches worker pool size)
max_concurrent = int(os.getenv('OCR_WORKERS', '2'))
_concurrency_semaphore = asyncio.Semaphore(max_concurrent)
_active_tasks = set()
logger.info(f"[JobWorker] Starting PARALLEL worker loop (max_concurrent={max_concurrent})...")
_shutdown_event = asyncio.Event()
consecutive_errors = 0
max_consecutive_errors = 5
max_consecutive_errors = 10
while not _shutdown_event.is_set():
try:
# Get next pending job
job = await job_queue.get_next_pending()
if job:
consecutive_errors = 0 # Reset error counter on success
await _process_job(job)
else:
# No jobs - wait before polling again
# Clean up completed tasks
done_tasks = {t for t in _active_tasks if t.done()}
for task in done_tasks:
_active_tasks.discard(task)
# Check for exceptions
try:
await asyncio.wait_for(
_shutdown_event.wait(),
timeout=POLL_INTERVAL_SECONDS
)
if _shutdown_event.is_set():
break
except asyncio.TimeoutError:
pass # Normal timeout, continue loop
task.result()
except Exception as e:
logger.error(f"[JobWorker] Task failed: {e}")
# Check if we have capacity for more jobs
active_count = len(_active_tasks)
available_slots = max_concurrent - active_count
if available_slots > 0:
# Get next pending job
job = await job_queue.get_next_pending()
if job:
consecutive_errors = 0
# Launch job processing as background task
task = asyncio.create_task(_process_job_with_semaphore(job))
_active_tasks.add(task)
logger.debug(f"[JobWorker] Launched job {job.id} (active={len(_active_tasks)}/{max_concurrent})")
else:
# No pending jobs - wait briefly
try:
await asyncio.wait_for(
_shutdown_event.wait(),
timeout=POLL_INTERVAL_SECONDS
)
if _shutdown_event.is_set():
break
except asyncio.TimeoutError:
pass
else:
# At capacity - wait for a slot to free up
await asyncio.sleep(POLL_INTERVAL_SECONDS)
except asyncio.CancelledError:
logger.info("[JobWorker] Worker loop cancelled")
@@ -88,27 +118,46 @@ async def _job_worker_loop() -> None:
logger.error("[JobWorker] Too many consecutive errors, stopping worker")
break
# Backoff on errors
await asyncio.sleep(min(consecutive_errors * 2, 30))
# Wait for active tasks to complete on shutdown
if _active_tasks:
logger.info(f"[JobWorker] Waiting for {len(_active_tasks)} active tasks to complete...")
await asyncio.gather(*_active_tasks, return_exceptions=True)
logger.info("[JobWorker] Worker loop stopped")
async def _process_job_with_semaphore(job: OCRJob) -> None:
"""
Process job with semaphore to limit concurrency.
Acquires semaphore before processing, releases after.
This ensures we don't exceed OCR_WORKERS concurrent jobs.
"""
global _concurrency_semaphore
async with _concurrency_semaphore:
await _process_job(job)
async def _process_job(job: OCRJob) -> None:
"""
Process a single OCR job.
Reads file, submits to worker pool, updates job status.
Reads file, submits to worker pool, updates job status,
and saves metrics for analytics.
Args:
job: OCRJob to process
"""
logger.info(f"[JobWorker] Processing job {job.id}: engine={job.engine}, file={Path(job.file_path).name}")
start_time = time.time()
file_size = 0
file_type = "image/jpeg"
try:
# Mark as processing
await job_queue.update_status(job.id, OCRJobStatus.processing)
# Note: Job already marked as 'processing' atomically in get_next_pending()
# Read file bytes
file_path = Path(job.file_path)
@@ -118,6 +167,10 @@ async def _process_job(job: OCRJob) -> None:
with open(file_path, 'rb') as f:
file_bytes = f.read()
file_size = len(file_bytes)
# Determine file type from job or extension
file_type = getattr(job, 'mime_type', 'image/jpeg') or 'image/jpeg'
# Submit to worker pool
result = await ocr_worker_pool.submit_task(
image_bytes=file_bytes,
@@ -132,14 +185,43 @@ async def _process_job(job: OCRJob) -> None:
# Job completed successfully
extraction = result.get("extraction", {})
# Include raw_texts for analysis (from all OCR engine passes)
extraction['raw_texts'] = result.get("raw_texts", [])
# Extract actual OCR processing time from extraction result
ocr_time_ms = extraction.get('processing_time_ms', 0)
# Debug: log suggested_payment_mode
spm = extraction.get('suggested_payment_mode')
logger.info(f"[JobWorker] Job {job.id} extraction has suggested_payment_mode={spm}")
await job_queue.update_status(
job_id=job.id,
status=OCRJobStatus.completed,
result=extraction,
processing_time_ms=elapsed_ms
processing_time_ms=elapsed_ms,
ocr_time_ms=ocr_time_ms
)
logger.info(f"[JobWorker] Job {job.id} completed in {elapsed_ms}ms")
logger.info(f"[JobWorker] Job {job.id} completed in {elapsed_ms}ms (ocr: {ocr_time_ms}ms)")
# Save metrics for successful job
await _save_job_metrics(
job_id=job.id,
username=job.created_by or 'unknown',
engine_requested=job.engine,
engine_used=extraction.get('ocr_engine', job.engine),
processing_time_ms=elapsed_ms,
file_size_bytes=file_size,
file_type=file_type,
original_filename=job.original_filename,
success=True,
overall_confidence=extraction.get('overall_confidence', 0.0),
fields_extracted=_count_extracted_fields(extraction),
needs_manual_review=extraction.get('needs_manual_review'),
validation_warnings_count=len(extraction.get('validation_warnings', [])),
validation_errors_count=len(extraction.get('validation_errors', [])),
)
else:
# Job failed
@@ -154,6 +236,20 @@ async def _process_job(job: OCRJob) -> None:
logger.warning(f"[JobWorker] Job {job.id} failed after {elapsed_ms}ms: {error_msg}")
# Save metrics for failed job
await _save_job_metrics(
job_id=job.id,
username=job.created_by or 'unknown',
engine_requested=job.engine,
engine_used=job.engine,
processing_time_ms=elapsed_ms,
file_size_bytes=file_size,
file_type=file_type,
original_filename=job.original_filename,
success=False,
error_message=error_msg,
)
except Exception as e:
elapsed_ms = int((time.time() - start_time) * 1000)
@@ -166,6 +262,20 @@ async def _process_job(job: OCRJob) -> None:
processing_time_ms=elapsed_ms
)
# Save metrics for error job
await _save_job_metrics(
job_id=job.id,
username=job.created_by or 'unknown',
engine_requested=job.engine,
engine_used=job.engine,
processing_time_ms=elapsed_ms,
file_size_bytes=file_size,
file_type=file_type,
original_filename=job.original_filename,
success=False,
error_message=str(e),
)
finally:
# Cleanup file after processing
try:
@@ -340,3 +450,96 @@ def estimate_wait_time(queue_position: int) -> int:
# Estimate: position * average_time
return int(queue_position * avg_time)
# ============================================================================
# Metrics Helper Functions
# ============================================================================
async def _save_job_metrics(
job_id: str,
username: str,
engine_requested: str,
engine_used: str,
processing_time_ms: int = 0,
file_size_bytes: int = 0,
file_type: str = "image/jpeg",
original_filename: Optional[str] = None,
success: bool = True,
error_message: Optional[str] = None,
overall_confidence: float = 0.0,
fields_extracted: int = 0,
needs_manual_review: Optional[bool] = None,
validation_warnings_count: int = 0,
validation_errors_count: int = 0,
) -> None:
"""
Save OCR job metrics to database for analytics.
Called after each job completes (success or failure).
Errors are logged but don't affect job processing.
"""
try:
from backend.modules.data_entry.db.database import get_db_session
from backend.modules.data_entry.db.crud.ocr_settings import OCRMetricsCRUD
async with await get_db_session() as session:
await OCRMetricsCRUD.create(
session=session,
job_id=job_id,
username=username,
engine_requested=engine_requested,
engine_used=engine_used,
processing_time_ms=processing_time_ms,
file_size_bytes=file_size_bytes,
file_type=file_type,
original_filename=original_filename,
success=success,
error_message=error_message,
overall_confidence=overall_confidence,
fields_extracted=fields_extracted,
needs_manual_review=needs_manual_review,
validation_warnings_count=validation_warnings_count,
validation_errors_count=validation_errors_count,
)
logger.debug(f"[JobWorker] Saved metrics for job {job_id}")
except Exception as e:
# Log but don't fail - metrics are nice-to-have
logger.warning(f"[JobWorker] Failed to save metrics for job {job_id}: {e}")
def _count_extracted_fields(extraction: dict) -> int:
"""
Count number of successfully extracted fields from OCR result.
Counts non-None values in key fields.
"""
key_fields = [
'receipt_number',
'receipt_date',
'amount',
'partner_name',
'cui',
'tva_total',
'address',
'items_count',
]
count = 0
for field in key_fields:
value = extraction.get(field)
if value is not None and value != '' and value != []:
count += 1
# Also count TVA entries if present
tva_entries = extraction.get('tva_entries', [])
if tva_entries and len(tva_entries) > 0:
count += 1
# Count payment methods if present
payment_methods = extraction.get('payment_methods', [])
if payment_methods and len(payment_methods) > 0:
count += 1
return count

View File

@@ -1,11 +1,12 @@
"""
OCR Worker Pool Manager
Manages a ProcessPoolExecutor with persistent PaddleOCR initialization.
Manages a ProcessPoolExecutor with persistent OCR engine initialization.
Key features:
- ProcessPoolExecutor with max_workers=1 (sequential, no memory leak)
- ProcessPoolExecutor with configurable max_workers (from OCR_WORKERS env)
- Configurable max_tasks_per_child (from OCR_MAX_TASKS_PER_CHILD env, 0=no restart)
- mp_context='spawn' for Windows IIS compatibility
- PaddleOCR loaded ONCE at worker spawn (not 30s per request)
- docTR/PaddleOCR loaded ONCE at worker spawn (not 30s per request)
- atexit + signal handlers for cleanup
- Health check with auto-respawn
- Orphan process cleanup on Windows
@@ -29,7 +30,7 @@ import os
import signal
import sys
import time
from concurrent.futures import ProcessPoolExecutor, Future
from concurrent.futures import ProcessPoolExecutor, Future, ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Callable, Optional
@@ -48,8 +49,8 @@ class OCRWorkerPool:
"""
Singleton manager for OCR ProcessPoolExecutor.
Ensures PaddleOCR is loaded once and reused for all requests.
Uses max_tasks_per_child=None to keep worker alive indefinitely.
Ensures OCR engines are loaded once and reused for all requests.
Uses max_tasks_per_child=5 to restart worker every 5 tasks (prevents memory leak).
"""
_instance: Optional["OCRWorkerPool"] = None
@@ -86,7 +87,7 @@ class OCRWorkerPool:
Initialize the ProcessPoolExecutor.
Creates executor with spawn context for Windows compatibility.
Uses max_tasks_per_child=None to keep worker alive (persistent PaddleOCR).
Uses max_tasks_per_child=5 to restart worker periodically (prevents memory leak).
Returns:
True if initialization successful
@@ -103,18 +104,30 @@ class OCRWorkerPool:
# Cleanup any orphan workers from previous runs
self._cleanup_orphan_workers()
# Read configuration from environment
max_workers = int(os.getenv('OCR_WORKERS', '2'))
max_tasks_raw = os.getenv('OCR_MAX_TASKS_PER_CHILD', '0')
# 0 means no restart (None in ProcessPoolExecutor)
max_tasks_per_child = int(max_tasks_raw) if max_tasks_raw and int(max_tasks_raw) > 0 else None
# Create executor with spawn context (Windows compatible)
# Use mp_context='spawn' explicitly for cross-platform consistency
mp_context = mp.get_context('spawn')
self._executor = ProcessPoolExecutor(
max_workers=1, # Single worker for sequential processing
mp_context=mp_context,
initializer=_worker_initializer,
max_tasks_per_child=None, # Keep worker alive indefinitely
)
# max_tasks_per_child only available in Python 3.11+
executor_kwargs = {
'max_workers': max_workers,
'mp_context': mp_context,
'initializer': _worker_initializer,
}
if sys.version_info >= (3, 11) and max_tasks_per_child is not None:
executor_kwargs['max_tasks_per_child'] = max_tasks_per_child
else:
logger.info(f"[OCRWorkerPool] max_tasks_per_child not supported (Python {sys.version_info.major}.{sys.version_info.minor})")
logger.info("[OCRWorkerPool] ProcessPoolExecutor created (spawn context, max_workers=1)")
self._executor = ProcessPoolExecutor(**executor_kwargs)
logger.info(f"[OCRWorkerPool] ProcessPoolExecutor created (spawn context, max_workers={max_workers}, max_tasks_per_child={max_tasks_per_child})")
return True
except Exception as e:
@@ -173,7 +186,7 @@ class OCRWorkerPool:
async def submit_task(
self,
image_bytes: bytes,
engine: str = "auto",
engine: str = "doctr_plus",
preprocessing: str = "auto",
timeout: float = 120.0
) -> dict:
@@ -182,7 +195,7 @@ class OCRWorkerPool:
Args:
image_bytes: Raw image bytes
engine: OCR engine ('auto', 'paddleocr', 'tesseract')
engine: OCR engine ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
preprocessing: Preprocessing mode ('light', 'medium', 'heavy', 'auto')
timeout: Maximum processing time in seconds
@@ -339,6 +352,7 @@ class OCRWorkerPool:
# Global engines - persist between tasks in worker process
_paddle_engine = None
_tesseract_engine = None
_doctr_engine = None # docTR engine (PyTorch backend)
_worker_initialized = False
@@ -346,40 +360,92 @@ def _worker_initializer() -> None:
"""
Called once when worker process spawns.
Initializes global OCR engines that persist between tasks.
This is where PaddleOCR loading happens (15-20 seconds).
Initializes global OCR engines IN PARALLEL for faster startup.
Uses ThreadPoolExecutor to load enabled engines concurrently.
Respects OCR_ENABLE_PADDLEOCR and OCR_ENABLE_TESSERACT from .env.
Total warmup time = max(engine_times) instead of sum(engine_times).
"""
global _paddle_engine, _tesseract_engine, _worker_initialized
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
if _worker_initialized:
print(f"[Worker {os.getpid()}] Already initialized", flush=True)
return
print(f"[Worker {os.getpid()}] Initializing OCR engines...", flush=True)
# Check which engines are enabled via .env
paddle_enabled = os.getenv("OCR_ENABLE_PADDLEOCR", "true").lower() == "true"
tesseract_enabled = os.getenv("OCR_ENABLE_TESSERACT", "true").lower() == "true"
enabled_engines = ["doctr"] # docTR is always loaded (primary engine)
if paddle_enabled:
enabled_engines.append("paddle")
if tesseract_enabled:
enabled_engines.append("tesseract")
print(f"[Worker {os.getpid()}] Initializing OCR engines: {enabled_engines}", flush=True)
if not paddle_enabled:
print(f"[Worker {os.getpid()}] PaddleOCR DISABLED - saving ~800MB RAM", flush=True)
if not tesseract_enabled:
print(f"[Worker {os.getpid()}] Tesseract DISABLED - saving ~50MB RAM", flush=True)
start_time = time.time()
# Initialize PaddleOCR
try:
# Import inside worker to avoid import issues in main process
from backend.modules.data_entry.services.ocr.ocr_worker_process import initialize_paddle_engine
_paddle_engine = initialize_paddle_engine()
print(f"[Worker {os.getpid()}] PaddleOCR loaded", flush=True)
except Exception as e:
print(f"[Worker {os.getpid()}] PaddleOCR init failed: {e}", flush=True)
_paddle_engine = None
# Define loader functions - each runs in its own thread
def load_doctr():
try:
from backend.modules.data_entry.services.ocr.ocr_worker_process import initialize_doctr_engine
engine = initialize_doctr_engine()
return ("doctr", engine, None)
except Exception as e:
return ("doctr", None, str(e))
# Initialize Tesseract
try:
from backend.modules.data_entry.services.ocr.tesseract_engine import TesseractEngine
_tesseract_engine = TesseractEngine()
print(f"[Worker {os.getpid()}] Tesseract loaded", flush=True)
except Exception as e:
print(f"[Worker {os.getpid()}] Tesseract init failed: {e}", flush=True)
_tesseract_engine = None
def load_paddle():
if not paddle_enabled:
return ("paddle", None, "disabled via OCR_ENABLE_PADDLEOCR=false")
try:
from backend.modules.data_entry.services.ocr.ocr_worker_process import initialize_paddle_engine
engine = initialize_paddle_engine()
return ("paddle", engine, None)
except Exception as e:
return ("paddle", None, str(e))
def load_tesseract():
if not tesseract_enabled:
return ("tesseract", None, "disabled via OCR_ENABLE_TESSERACT=false")
try:
from backend.modules.data_entry.services.ocr.tesseract_engine import TesseractEngine
engine = TesseractEngine()
return ("tesseract", engine, None)
except Exception as e:
return ("tesseract", None, str(e))
# Build list of futures for enabled engines only
futures_to_submit = [load_doctr] # docTR always loaded
if paddle_enabled:
futures_to_submit.append(load_paddle)
if tesseract_enabled:
futures_to_submit.append(load_tesseract)
# Load engines in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=len(futures_to_submit)) as executor:
futures = [executor.submit(fn) for fn in futures_to_submit]
for future in as_completed(futures):
name, engine, error = future.result()
if error and "disabled" not in error:
print(f"[Worker {os.getpid()}] {name} init failed: {error}", flush=True)
elif engine:
print(f"[Worker {os.getpid()}] {name} loaded", flush=True)
if name == "doctr":
_doctr_engine = engine
elif name == "paddle":
_paddle_engine = engine
elif name == "tesseract":
_tesseract_engine = engine
elapsed = time.time() - start_time
_worker_initialized = True
print(f"[Worker {os.getpid()}] Initialization complete in {elapsed:.1f}s", flush=True)
print(f"[Worker {os.getpid()}] Initialization complete in {elapsed:.1f}s (engines: {enabled_engines})", flush=True)
def _warmup_task() -> dict:
@@ -389,7 +455,7 @@ def _warmup_task() -> dict:
Called at FastAPI startup to pre-warm the worker.
Returns success status and worker PID.
"""
global _paddle_engine, _tesseract_engine, _worker_initialized
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
try:
# Ensure initialization
@@ -400,6 +466,14 @@ def _warmup_task() -> dict:
import numpy as np
dummy_img = np.ones((100, 100, 3), dtype=np.uint8) * 255
# Test docTR if available (fastest engine)
if _doctr_engine is not None:
try:
_doctr_engine([dummy_img])
print(f"[Worker {os.getpid()}] docTR warmup OK", flush=True)
except Exception as e:
print(f"[Worker {os.getpid()}] docTR warmup error: {e}", flush=True)
# Test PaddleOCR if available
if _paddle_engine is not None:
try:
@@ -414,6 +488,7 @@ def _warmup_task() -> dict:
return {
"success": True,
"pid": os.getpid(),
"doctr_available": _doctr_engine is not None,
"paddle_available": _paddle_engine is not None,
"tesseract_available": _tesseract_engine is not None
}
@@ -428,7 +503,7 @@ def _warmup_task() -> dict:
def _process_ocr_task(
image_bytes: bytes,
engine: str = "auto",
engine: str = "doctr_plus",
preprocessing: str = "auto"
) -> dict:
"""
@@ -439,13 +514,13 @@ def _process_ocr_task(
Args:
image_bytes: Raw image bytes
engine: OCR engine choice
engine: OCR engine choice ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
preprocessing: Preprocessing mode
Returns:
Dict with extraction results
"""
global _paddle_engine, _tesseract_engine, _worker_initialized
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
try:
# Ensure initialization
@@ -461,7 +536,8 @@ def _process_ocr_task(
paddle_engine=_paddle_engine,
tesseract_engine=_tesseract_engine,
engine=engine,
preprocessing=preprocessing
preprocessing=preprocessing,
doctr_engine=_doctr_engine
)
# Cleanup after each task

View File

@@ -6,6 +6,7 @@ Handles OCR processing with persistent engine instances.
Key features:
- PaddleOCR initialized ONCE at process spawn
- docTR initialized ONCE at process spawn (PyTorch backend)
- Tesseract as fallback/complement engine
- Multi-pass preprocessing (light → medium → tesseract)
- Automatic engine selection based on results
@@ -26,6 +27,13 @@ import numpy as np
# Disable PaddleOCR model source check for faster startup
os.environ['PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK'] = 'True'
# Memory optimization for docTR (prevents memory leak in multiprocessing)
# Source: https://github.com/mindee/doctr/issues/1594
os.environ['DOCTR_MULTIPROCESSING_DISABLE'] = 'TRUE'
# Reduce Intel oneDNN cache to save memory
os.environ['ONEDNN_PRIMITIVE_CACHE_CAPACITY'] = '1'
@dataclass
class OCRResult:
@@ -71,25 +79,67 @@ def initialize_paddle_engine():
return None
def initialize_doctr_engine():
"""
Initialize docTR engine (CPU only).
Called once at worker spawn. Returns the engine instance
that will be reused for all subsequent requests.
Note: DirectML (AMD GPU) has compatibility issues with docTR.
CUDA (NVIDIA) works but requires separate PyTorch build.
CPU mode is stable and well-optimized.
Returns:
docTR predictor instance or None if unavailable
"""
try:
print(f"[Worker {os.getpid()}] Loading docTR (PyTorch backend, CPU)...", flush=True)
start_time = time.time()
from doctr.models import ocr_predictor
# Initialize docTR predictor with pretrained models
# Uses db_resnet50 for detection and crnn_vgg16_bn for recognition
doctr = ocr_predictor(
det_arch='db_resnet50',
reco_arch='crnn_vgg16_bn',
pretrained=True,
assume_straight_pages=True,
straighten_pages=False,
preserve_aspect_ratio=True,
)
elapsed = time.time() - start_time
print(f"[Worker {os.getpid()}] docTR loaded in {elapsed:.1f}s", flush=True)
return doctr
except Exception as e:
print(f"[Worker {os.getpid()}] docTR init failed: {e}", flush=True)
return None
def process_ocr(
image_bytes: bytes,
paddle_engine,
tesseract_engine,
engine: str = "auto",
preprocessing: str = "auto"
engine: str = "doctr_plus",
preprocessing: str = "auto",
doctr_engine=None
) -> dict:
"""
Process OCR on image bytes.
Main entry point for OCR processing in worker process.
Uses adaptive multi-pass strategy for best results.
Uses the specified engine for text recognition.
Args:
image_bytes: Raw image bytes (JPEG, PNG, or PDF)
paddle_engine: Pre-initialized PaddleOCR instance (or None)
tesseract_engine: Pre-initialized TesseractEngine instance (or None)
engine: Engine selection ('auto', 'paddleocr', 'tesseract')
engine: Engine selection ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
preprocessing: Preprocessing mode ('auto', 'light', 'medium', 'heavy')
doctr_engine: Pre-initialized docTR instance (or None)
Returns:
Dict with extraction results:
@@ -101,14 +151,20 @@ def process_ocr(
"ocr_engine": str
}
"""
import sys
start_time = time.time()
print(f"[Worker {os.getpid()}] Processing OCR: engine={engine}, preprocessing={preprocessing}, size={len(image_bytes)} bytes", flush=True)
sys.stdout.flush()
try:
# Decode image from bytes
print(f"[Worker {os.getpid()}] Decoding image...", flush=True)
sys.stdout.flush()
image = _decode_image(image_bytes)
if image is None:
return {"success": False, "error": "Failed to decode image"}
print(f"[Worker {os.getpid()}] Image decoded: shape={image.shape}, dtype={image.dtype}, size={image.nbytes/1024/1024:.1f}MB", flush=True)
sys.stdout.flush()
# Import preprocessor
from backend.modules.data_entry.services.image_preprocessor import ImagePreprocessor
@@ -116,22 +172,36 @@ def process_ocr(
preprocessor = ImagePreprocessor()
extractor = ReceiptExtractor()
print(f"[Worker {os.getpid()}] Preprocessor and extractor initialized", flush=True)
sys.stdout.flush()
raw_texts = []
extraction = None
# Engine routing
if engine == "paddleocr":
extraction, raw_texts = _process_paddleocr_only(
image, paddle_engine, preprocessor, extractor
)
elif engine == "tesseract":
# Engine routing (available: tesseract, doctr, doctr_plus, paddleocr)
print(f"[Worker {os.getpid()}] Routing to engine: {engine}", flush=True)
sys.stdout.flush()
if engine == "tesseract":
extraction, raw_texts = _process_tesseract_only(
image, tesseract_engine, preprocessor, extractor
)
else: # auto
extraction, raw_texts = _process_adaptive(
image, paddle_engine, tesseract_engine, preprocessor, extractor
elif engine == "doctr":
extraction, raw_texts = _process_doctr_only(
image, doctr_engine, preprocessor, extractor
)
elif engine == "doctr_plus":
extraction, raw_texts = _process_doctr_plus(
image, doctr_engine, preprocessor, extractor
)
elif engine == "paddleocr":
extraction, raw_texts = _process_paddleocr_only(
image, paddle_engine, preprocessor, extractor
)
else:
# Default to doctr_plus if unknown engine specified
print(f"[OCR] Unknown engine '{engine}', defaulting to doctr_plus", flush=True)
extraction, raw_texts = _process_doctr_plus(
image, doctr_engine, preprocessor, extractor
)
# Calculate processing time
@@ -171,7 +241,11 @@ def process_ocr(
def _decode_image(image_bytes: bytes) -> Optional[np.ndarray]:
"""Decode image from bytes (JPEG, PNG, or first page of PDF)."""
"""Decode image from bytes (JPEG, PNG, or first page of PDF).
For PDFs, uses 200 DPI which is sufficient for receipt OCR
and reduces processing time by ~50% vs 300 DPI.
"""
try:
# Try as regular image first
nparr = np.frombuffer(image_bytes, np.uint8)
@@ -180,18 +254,21 @@ def _decode_image(image_bytes: bytes) -> Optional[np.ndarray]:
if image is not None:
return image
# Try as PDF
# Try as PDF - use 200 DPI for faster processing (sufficient for receipts)
try:
import pdf2image
from PIL import Image
images = pdf2image.convert_from_bytes(image_bytes, dpi=300)
# 200 DPI is sufficient for receipt text recognition
# 300 DPI was overkill and slowed down processing
images = pdf2image.convert_from_bytes(image_bytes, dpi=200)
if images:
# Convert first page to numpy array
pil_img = images[0]
print(f"[Worker {os.getpid()}] PDF decoded: {pil_img.width}x{pil_img.height} @ 200 DPI", flush=True)
return np.array(pil_img)
except Exception:
pass
except Exception as e:
print(f"[Worker {os.getpid()}] PDF decode error: {e}", flush=True)
return None
@@ -270,83 +347,275 @@ def _process_tesseract_only(
return extraction, raw_texts
def _process_adaptive(
def _process_doctr_only(
image: np.ndarray,
paddle_engine,
tesseract_engine,
doctr_engine,
preprocessor,
extractor
) -> Tuple[Any, List[str]]:
"""
Adaptive multi-pass OCR processing.
Process using docTR only (light + medium preprocessing).
Strategy:
1. PaddleOCR Light - fastest, best for clear PDFs
2. PaddleOCR Medium - if Light incomplete
3. Tesseract - complement missing fields only
Returns:
Tuple of (extraction_result, raw_texts_list)
docTR uses EXACT same preprocessing as PaddleOCR for consistency.
"""
raw_texts = []
extraction = None
# === STEP 1: PaddleOCR Light ===
if paddle_engine:
print("[OCR] Step 1: PaddleOCR + Light", flush=True)
light_img = preprocessor.preprocess_light(image)
paddle_light = _paddle_recognize(paddle_engine, light_img)
if doctr_engine is None:
return None, ["docTR not available"]
if paddle_light and paddle_light.text:
extraction = extractor.extract(paddle_light.text)
extraction.ocr_engine = "paddle-light"
raw_texts.append(f"=== PaddleOCR Light (conf: {paddle_light.confidence:.0%}) ===\n{paddle_light.text}")
# Step 1: Light preprocessing (same as PaddleOCR)
print("[OCR] Step 1: docTR + Light", flush=True)
light_img = preprocessor.preprocess_light(image)
doctr_light = _doctr_recognize(doctr_engine, light_img)
if _is_extraction_complete(extraction):
print("[OCR] Early exit - all fields found in Step 1", flush=True)
return extraction, raw_texts
if doctr_light and doctr_light.text:
extraction = extractor.extract(doctr_light.text)
extraction.ocr_engine = "doctr-light"
raw_texts.append(f"=== docTR Light (conf: {doctr_light.confidence:.0%}) ===\n{doctr_light.text}")
# === STEP 2: PaddleOCR Medium ===
if paddle_engine:
print("[OCR] Step 2: PaddleOCR + Medium", flush=True)
medium_img = preprocessor.preprocess_medium(image)
paddle_medium = _paddle_recognize(paddle_engine, medium_img)
if _is_extraction_complete(extraction):
return extraction, raw_texts
if paddle_medium and paddle_medium.text:
extraction_medium = extractor.extract(paddle_medium.text)
extraction_medium.ocr_engine = "paddle-medium"
raw_texts.append(f"=== PaddleOCR Medium (conf: {paddle_medium.confidence:.0%}) ===\n{paddle_medium.text}")
# Step 2: Medium preprocessing (same as PaddleOCR)
print("[OCR] Step 2: docTR + Medium", flush=True)
medium_img = preprocessor.preprocess_medium(image)
doctr_medium = _doctr_recognize(doctr_engine, medium_img)
if extraction:
extraction = _merge_extractions(extraction, extraction_medium)
extraction.ocr_engine = "paddle-adaptive"
else:
extraction = extraction_medium
if doctr_medium and doctr_medium.text:
extraction_medium = extractor.extract(doctr_medium.text)
extraction_medium.ocr_engine = "doctr-medium"
raw_texts.append(f"=== docTR Medium (conf: {doctr_medium.confidence:.0%}) ===\n{doctr_medium.text}")
if _is_extraction_complete(extraction):
print("[OCR] Early exit - all fields found after Step 2", flush=True)
return extraction, raw_texts
# === STEP 3: Tesseract (complement only) ===
if tesseract_engine:
print("[OCR] Step 3: Tesseract complement", flush=True)
tesseract_img = preprocessor.preprocess_for_tesseract(image)
tesseract_result = tesseract_engine.recognize(tesseract_img)
if tesseract_result and tesseract_result.text:
extraction_tess = extractor.extract(tesseract_result.text)
extraction_tess.ocr_engine = "tesseract"
raw_texts.append(f"=== Tesseract (conf: {tesseract_result.confidence:.0%}) ===\n{tesseract_result.text}")
if extraction:
extraction = _complement_extraction(extraction, extraction_tess)
extraction.ocr_engine = "adaptive-full"
else:
extraction = extraction_tess
if extraction:
extraction = _merge_extractions(extraction, extraction_medium)
extraction.ocr_engine = "doctr-adaptive"
else:
extraction = extraction_medium
return extraction, raw_texts
def _process_doctr_plus(
image: np.ndarray,
doctr_engine,
preprocessor,
extractor
) -> Tuple[Any, List[str]]:
"""
docTR Plus - Optimized 2-tier sequential processing with early exit.
Architecture:
- Tier 1: Light preprocessing (~4-5s)
→ Early exit if confidence >= 0.75 AND all fields valid AND cross-validations pass
- Tier 2: Medium preprocessing (only if Tier 1 insufficient, ~4-5s additional)
→ Merge with Tier 1 results
→ Mark for review if still problems
Performance:
- Fast path (80% receipts): ~4-5s (Tier 1 only)
- Slow path (20% receipts): ~8-9s (Tier 1 + Tier 2)
- Average: ~5-6s
Returns:
Tuple of (extraction_result, raw_texts_list)
extraction_result.needs_review = True if validation issues remain
"""
raw_texts = []
extraction = None
if doctr_engine is None:
return None, ["docTR not available"]
# ========== TIER 1: Light Preprocessing ==========
print("[docTR+] TIER 1: Light preprocessing", flush=True)
import time
tier1_start = time.time()
light_img = preprocessor.preprocess_light(image)
doctr_light = _doctr_recognize(doctr_engine, light_img)
tier1_time = time.time() - tier1_start
print(f"[docTR+] TIER 1 completed in {tier1_time:.1f}s", flush=True)
if doctr_light and doctr_light.text:
extraction = extractor.extract(doctr_light.text)
extraction.ocr_engine = "doctr-plus-light"
raw_texts.append(f"=== docTR+ Tier1/Light (conf: {doctr_light.confidence:.0%}) ===\n{doctr_light.text}")
# Early Exit Check: confidence >= 0.75 + cross-validations
if _is_extraction_valid_for_early_exit(extraction, min_confidence=0.75):
print(f"[docTR+] EARLY EXIT - Tier 1 sufficient (conf: {extraction.overall_confidence:.0%})", flush=True)
extraction.ocr_engine = "doctr-plus"
return extraction, raw_texts
print(f"[docTR+] Tier 1 incomplete or validation failed, proceeding to Tier 2...", flush=True)
# ========== TIER 2: Medium Preprocessing (only if needed) ==========
print("[docTR+] TIER 2: Medium preprocessing", flush=True)
tier2_start = time.time()
medium_img = preprocessor.preprocess_medium(image)
doctr_medium = _doctr_recognize(doctr_engine, medium_img)
tier2_time = time.time() - tier2_start
print(f"[docTR+] TIER 2 completed in {tier2_time:.1f}s", flush=True)
if doctr_medium and doctr_medium.text:
extraction_medium = extractor.extract(doctr_medium.text)
extraction_medium.ocr_engine = "doctr-plus-medium"
raw_texts.append(f"=== docTR+ Tier2/Medium (conf: {doctr_medium.confidence:.0%}) ===\n{doctr_medium.text}")
if extraction:
# Merge Tier 1 + Tier 2 results
extraction = _merge_extractions(extraction, extraction_medium)
else:
extraction = extraction_medium
# ========== FINAL VALIDATION ==========
if extraction:
extraction.ocr_engine = "doctr-plus"
# Mark for review if validation still fails after both tiers
passes_validation, penalty, errors = _quick_cross_validate(extraction)
if not passes_validation or extraction.overall_confidence < 0.75:
# Mark for human review using existing fields
extraction.needs_manual_review = True
if extraction.overall_confidence < 0.75:
extraction.validation_warnings.append(f"Low confidence: {extraction.overall_confidence:.0%}")
if not extraction.amount:
extraction.validation_errors.append("TOTAL not detected")
if not extraction.cui:
extraction.validation_warnings.append("CUI not detected")
if not extraction.tva_total and not extraction.tva_entries:
extraction.validation_warnings.append("TVA not detected")
if not extraction.receipt_date:
extraction.validation_warnings.append("Date not detected")
# Add cross-validation errors
extraction.validation_errors.extend(errors)
print(f"[docTR+] Marked for review: {extraction.validation_errors + extraction.validation_warnings}", flush=True)
else:
extraction.needs_manual_review = False
total_time = tier1_time + (tier2_time if 'tier2_time' in dir() else 0)
print(f"[docTR+] Total processing time: {total_time:.1f}s", flush=True)
return extraction, raw_texts
# =============================================================================
# VALIDATION HELPERS (used by doctr_plus for early exit decisions)
# =============================================================================
def _quick_cross_validate(extraction) -> tuple[bool, float, list[str]]:
"""
Quick cross-validation for OCR results.
Checks critical field correlations to detect obvious OCR errors.
Used by doctr_plus to decide whether to proceed to Tier 2 or exit early.
Returns:
Tuple of (passes_validation, confidence_penalty, error_messages)
"""
try:
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
if extraction is None:
return False, 1.0, ["No extraction result"]
# Convert extraction to dict for validation
# Build TVA entries dict for TVAEntriesSumRule (expects {code: amount})
tva_entries_dict = {}
if extraction.tva_entries:
for entry in extraction.tva_entries:
if isinstance(entry, dict):
code = entry.get('code', 'A')
amount = entry.get('amount', 0)
try:
tva_entries_dict[code] = float(amount)
except (TypeError, ValueError):
pass
validation_data = {
"amount": float(extraction.amount) if extraction.amount else None,
"tva": float(extraction.tva_total) if extraction.tva_total else None,
"tva_entries": tva_entries_dict, # For TVAEntriesSumRule: {code: amount}
"cui": extraction.cui, # For CUI checksum validation
}
# Also pass raw tva_entries for TVABasedTotalRule (for rate detection)
if extraction.tva_entries:
validation_data['tva_entries_raw'] = extraction.tva_entries
# Add payment methods if available (for TOTAL vs CARD+CASH validation)
if extraction.payment_methods:
try:
card_amount = sum(
float(p.get('amount', 0) if isinstance(p, dict) else 0)
for p in extraction.payment_methods
if isinstance(p, dict) and p.get('method') == 'CARD'
)
cash_amount = sum(
float(p.get('amount', 0) if isinstance(p, dict) else 0)
for p in extraction.payment_methods
if isinstance(p, dict) and p.get('method') == 'NUMERAR'
)
validation_data['card_amount'] = card_amount
validation_data['cash_amount'] = cash_amount
except Exception as e:
print(f"[Worker {os.getpid()}] Payment method validation error: {e}", flush=True)
# Run quick validation
validator = OCRValidationEngine()
return validator.quick_validate_for_hybrid(validation_data)
except Exception as e:
# Never crash the process on validation errors
print(f"[Worker {os.getpid()}] Cross-validation error: {e}", flush=True)
import traceback
traceback.print_exc()
# Return "passes" to allow processing to continue
return True, 0.0, [f"Validation skipped due to error: {str(e)}"]
def _is_extraction_valid_for_early_exit(extraction, min_confidence: float = 0.85) -> bool:
"""
Check if extraction is valid for early exit in doctr_plus.
Combines confidence check with cross-validation to prevent
early exit on OCR errors (e.g., wrong TOTAL but correct TVA).
Returns:
True only if:
1. Overall confidence >= min_confidence
2. Critical fields are present (AMOUNT, DATE, CUI)
3. Cross-validation passes (TOTAL matches TVA calculation, or no TVA)
"""
try:
# First check basic completeness (relaxed for early exit)
if not _is_extraction_complete(extraction, min_confidence, for_early_exit=True):
return False
# Then run cross-validation
passes_validation, penalty, errors = _quick_cross_validate(extraction)
if not passes_validation:
print(f"[Early Exit] BLOCKED: cross-validation failed: {errors}", flush=True)
return False
print(f"[Early Exit] OK: conf={extraction.overall_confidence:.0%}, validation passed", flush=True)
return True
except Exception as e:
# Never crash on validation - just continue to next engine
print(f"[Worker {os.getpid()}] Early exit check error: {e}", flush=True)
return False # Continue to next engine on error
def _paddle_recognize(paddle_engine, image: np.ndarray) -> Optional[OCRResult]:
"""Run PaddleOCR recognition on image."""
try:
@@ -388,34 +657,191 @@ def _paddle_recognize(paddle_engine, image: np.ndarray) -> Optional[OCRResult]:
return None
def _is_extraction_complete(ext, min_confidence: float = 0.85) -> bool:
"""Check if extraction has all required fields."""
def _doctr_recognize(doctr_engine, image: np.ndarray) -> Optional[OCRResult]:
"""
Run docTR recognition on image.
docTR requires RGB images, handles conversion automatically.
Uses same preprocessing as PaddleOCR for consistent results.
"""
try:
# docTR requires RGB images
if len(image.shape) == 2:
# Convert grayscale to RGB
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
elif image.shape[2] == 3:
# Convert BGR (OpenCV) to RGB
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
elif image.shape[2] == 4:
# Convert RGBA to RGB
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
# docTR expects a list of numpy arrays (pages)
result = doctr_engine([image])
if not result or not result.pages:
return OCRResult(text="", confidence=0.0, boxes=[], engine="doctr")
# Extract text from all pages
all_texts = []
all_confidences = []
boxes = []
for page in result.pages:
for block in page.blocks:
for line in block.lines:
line_text = ' '.join(word.value for word in line.words)
line_confidence = sum(w.confidence for w in line.words) / len(line.words) if line.words else 0.0
all_texts.append(line_text)
all_confidences.append(line_confidence)
# Store word-level boxes
for word in line.words:
boxes.append({
'text': word.value,
'confidence': float(word.confidence),
'box': word.geometry # (xmin, ymin), (xmax, ymax)
})
text_result = '\n'.join(all_texts)
avg_conf = sum(all_confidences) / len(all_confidences) if all_confidences else 0.0
return OCRResult(
text=text_result,
confidence=float(avg_conf),
boxes=boxes,
engine="doctr"
)
except Exception as e:
print(f"[Worker] docTR error: {e}", flush=True)
return None
def _is_extraction_complete(ext, min_confidence: float = 0.85, for_early_exit: bool = False) -> bool:
"""
Check if extraction has required fields.
Args:
ext: Extraction result
min_confidence: Minimum overall confidence
for_early_exit: If True, use relaxed criteria (AMOUNT + DATE + CUI required)
If False, require all fields (strict mode for final validation)
Returns:
True if extraction meets completeness criteria
"""
# Check confidence first
if ext.overall_confidence < min_confidence:
if for_early_exit:
print(f"[Early Exit] BLOCKED: confidence {ext.overall_confidence:.0%} < {min_confidence:.0%}", flush=True)
return False
has_number = bool(ext.receipt_number)
has_date = bool(ext.receipt_date)
has_amount = bool(ext.amount)
has_tva = bool(ext.tva_total) or bool(ext.tva_entries)
has_cui = bool(ext.cui)
return all([has_number, has_date, has_amount, has_tva, has_cui])
if for_early_exit:
# Relaxed criteria for early exit:
# - AMOUNT is required (core field)
# - DATE is required (needed for accounting)
# - CUI is required (needed for supplier identification)
# - TVA is NOT required (some receipts have 0% TVA)
# - receipt_number is NOT required (often missing)
required_ok = all([has_amount, has_date, has_cui])
if not required_ok:
missing = []
if not has_amount: missing.append("AMOUNT")
if not has_date: missing.append("DATE")
if not has_cui: missing.append("CUI")
print(f"[Early Exit] BLOCKED: missing required fields: {', '.join(missing)}", flush=True)
return required_ok
else:
# Strict criteria for final validation (all fields required)
has_number = bool(ext.receipt_number)
return all([has_number, has_date, has_amount, has_tva, has_cui])
def _merge_extractions(primary, secondary):
"""Merge two extractions, picking best fields from each."""
"""Merge two extractions, picking best fields from each.
Primary should be the higher-quality engine (e.g., docTR).
Secondary is the fallback engine (e.g., Tesseract).
Priority logic:
- AMOUNT: TVA validation wins over confidence. If both valid or both invalid,
uses confidence (or TVA diff for invalid cases).
- DATE/CUI: Validation-based, then confidence, then primary wins ties.
- OTHER FIELDS: Primary wins when both have values.
"""
from backend.modules.data_entry.services.ocr_extractor import ExtractionResult
result = ExtractionResult()
# Amount - prefer higher confidence
# Helper: Check if amount matches TVA calculation
def amount_passes_tva_validation(amount, tva_total, tva_entries):
if not amount or not tva_total:
return False, 0.0
try:
tva_rate = 0.21 # Default Romanian TVA
if tva_entries:
for entry in tva_entries:
if isinstance(entry, dict) and entry.get('percent'):
tva_rate = float(entry['percent']) / 100.0
break
# Expected TOTAL = TVA / rate * (1 + rate)
expected = float(tva_total) * (1 + tva_rate) / tva_rate
actual = float(amount)
diff_percent = abs(actual - expected) / expected if expected > 0 else 1.0
return diff_percent < 0.03, diff_percent # 3% tolerance
except:
return False, 1.0
# Amount - prefer TVA-validated value over confidence
if primary.amount and secondary.amount:
if primary.confidence_amount >= secondary.confidence_amount:
result.amount = primary.amount
result.confidence_amount = primary.confidence_amount
else:
# Get TVA from the one with entries, or use any available
tva_total = primary.tva_total or secondary.tva_total
tva_entries = primary.tva_entries or secondary.tva_entries
primary_valid, primary_diff = amount_passes_tva_validation(
primary.amount, tva_total, tva_entries
)
secondary_valid, secondary_diff = amount_passes_tva_validation(
secondary.amount, tva_total, tva_entries
)
print(f"[Merge] Amount comparison: primary={primary.amount} (valid={primary_valid}, diff={primary_diff:.1%}), "
f"secondary={secondary.amount} (valid={secondary_valid}, diff={secondary_diff:.1%})", flush=True)
if secondary_valid and not primary_valid:
# Secondary passes validation, primary doesn't - use secondary!
print(f"[Merge] Using secondary amount {secondary.amount} (passes TVA validation)", flush=True)
result.amount = secondary.amount
result.confidence_amount = secondary.confidence_amount
elif primary_valid and not secondary_valid:
# Primary passes validation
result.amount = primary.amount
result.confidence_amount = primary.confidence_amount
elif primary_valid and secondary_valid:
# Both valid - use higher confidence
if primary.confidence_amount >= secondary.confidence_amount:
result.amount = primary.amount
result.confidence_amount = primary.confidence_amount
else:
result.amount = secondary.amount
result.confidence_amount = secondary.confidence_amount
else:
# Neither valid - use the one closer to TVA calculation
if secondary_diff < primary_diff:
print(f"[Merge] Neither valid, using secondary {secondary.amount} (closer to TVA)", flush=True)
result.amount = secondary.amount
result.confidence_amount = secondary.confidence_amount
else:
result.amount = primary.amount
result.confidence_amount = primary.confidence_amount
elif primary.amount:
result.amount = primary.amount
result.confidence_amount = primary.confidence_amount
@@ -438,13 +864,15 @@ def _merge_extractions(primary, secondary):
result.receipt_date = secondary.receipt_date
result.confidence_date = secondary.confidence_date
# CUI - prefer valid format
# CUI - prefer valid format and version with RO prefix
# Use CUIChecksumRule static methods (single source of truth)
from backend.modules.data_entry.services.ocr.validation import CUIChecksumRule
def is_valid_cui(cui):
if not cui:
return False
import re
cui_clean = re.sub(r'^RO', '', cui.upper())
return bool(re.match(r'^\d{6,10}$', cui_clean))
digits = CUIChecksumRule.extract_digits(cui)
return len(digits) >= 6 and len(digits) <= 10
if primary.cui and secondary.cui:
if is_valid_cui(primary.cui) and not is_valid_cui(secondary.cui):
@@ -452,22 +880,27 @@ def _merge_extractions(primary, secondary):
elif is_valid_cui(secondary.cui) and not is_valid_cui(primary.cui):
result.cui = secondary.cui
else:
result.cui = primary.cui
# Both valid - prefer the one with RO prefix if digits match
primary_digits = CUIChecksumRule.extract_digits(primary.cui)
secondary_digits = CUIChecksumRule.extract_digits(secondary.cui)
if primary_digits == secondary_digits:
if CUIChecksumRule.has_ro_prefix(secondary.cui) and not CUIChecksumRule.has_ro_prefix(primary.cui):
result.cui = secondary.cui # Prefer version with RO
print(f"[CUI Complement] Preferring secondary with RO: {secondary.cui}", flush=True)
else:
result.cui = primary.cui
else:
result.cui = primary.cui
elif primary.cui:
result.cui = primary.cui
elif secondary.cui:
result.cui = secondary.cui
# TVA entries
# TVA entries - ALWAYS prefer primary (docTR) when both have entries
if primary.tva_entries and secondary.tva_entries:
primary_total = sum(e.get('amount', Decimal('0')) for e in primary.tva_entries)
secondary_total = sum(e.get('amount', Decimal('0')) for e in secondary.tva_entries)
if primary_total >= secondary_total:
result.tva_entries = primary.tva_entries
result.tva_total = primary.tva_total
else:
result.tva_entries = secondary.tva_entries
result.tva_total = secondary.tva_total
# Always use primary (docTR) - higher quality OCR
result.tva_entries = primary.tva_entries
result.tva_total = primary.tva_total
elif primary.tva_entries:
result.tva_entries = primary.tva_entries
result.tva_total = primary.tva_total
@@ -483,12 +916,36 @@ def _merge_extractions(primary, secondary):
result.address = primary.address or secondary.address
result.items_count = primary.items_count or secondary.items_count
result.payment_methods = primary.payment_methods or secondary.payment_methods
result.suggested_payment_mode = getattr(primary, 'suggested_payment_mode', None) or getattr(secondary, 'suggested_payment_mode', None)
# Client fields
result.client_name = primary.client_name or secondary.client_name
result.client_cui = primary.client_cui or secondary.client_cui
result.client_address = primary.client_address or secondary.client_address
# Confidence fields - preserve from primary or pick best
if primary.confidence_vendor >= secondary.confidence_vendor:
result.confidence_vendor = primary.confidence_vendor
else:
result.confidence_vendor = secondary.confidence_vendor
if hasattr(primary, 'confidence_client') and hasattr(secondary, 'confidence_client'):
if primary.confidence_client >= secondary.confidence_client:
result.confidence_client = primary.confidence_client
else:
result.confidence_client = secondary.confidence_client
# Raw text - combine both for debugging/display
raw_texts = []
if primary.raw_text:
raw_texts.append(primary.raw_text)
if secondary.raw_text and secondary.raw_text != primary.raw_text:
raw_texts.append(secondary.raw_text)
result.raw_text = '\n---\n'.join(raw_texts) if raw_texts else ''
# Note: overall_confidence is a computed @property on ExtractionResult
# It automatically calculates from confidence_amount, confidence_date, confidence_vendor
return result
@@ -557,6 +1014,7 @@ def _extraction_to_dict(extraction) -> dict:
"address": extraction.address,
"items_count": extraction.items_count,
"payment_methods": extraction.payment_methods,
"suggested_payment_mode": getattr(extraction, 'suggested_payment_mode', None),
# Client data
"client_name": extraction.client_name,
"client_cui": extraction.client_cui,

View File

@@ -385,8 +385,81 @@ class CUIChecksumRule(ValidationRule):
result = rule.validate({"cui": "R01879855"})
# result.is_valid = False (checksum mismatch)
Static methods available for direct use:
CUIChecksumRule.calculate_checksum("1056260") -> 0
CUIChecksumRule.validate_checksum("10562600") -> True
CUIChecksumRule.has_ro_prefix("RO10562600") -> True
"""
# Fixed multipliers for 9 positions (Romanian Mod 11)
MULTIPLIERS = [7, 5, 3, 2, 1, 7, 5, 3, 2]
@staticmethod
def calculate_checksum(cui_base: str) -> int:
"""Calculate expected CUI checksum using Romanian Mod 11 algorithm.
Args:
cui_base: CUI digits WITHOUT the checksum digit (last digit)
Returns:
Expected checksum digit (0-9), or -1 if invalid input
"""
if not cui_base or not cui_base.isdigit():
return -1
# Pad base to 9 digits from LEFT
base_padded = cui_base.zfill(9)
base_digits = [int(d) for d in base_padded]
# Calculate weighted sum
weighted_sum = sum(d * m for d, m in zip(base_digits, CUIChecksumRule.MULTIPLIERS))
# Calculate checksum
checksum = (weighted_sum * 10) % 11
if checksum == 10:
checksum = 0
return checksum
@staticmethod
def validate_checksum(cui_digits: str) -> bool:
"""Check if CUI checksum is valid.
Args:
cui_digits: Full CUI digits (including checksum as last digit)
Returns:
True if checksum is valid, False otherwise
"""
if not cui_digits or len(cui_digits) < 6 or not cui_digits.isdigit():
return False
base = cui_digits[:-1]
declared = int(cui_digits[-1])
expected = CUIChecksumRule.calculate_checksum(base)
return expected == declared
@staticmethod
def has_ro_prefix(cui: str) -> bool:
"""Check if CUI has RO prefix (proper format for VAT payers)."""
if not cui:
return False
return cui.upper().strip().startswith('RO')
@staticmethod
def extract_digits(cui: str) -> str:
"""Extract digits from CUI, removing RO/R0 prefix."""
if not cui:
return ""
cui = cui.strip().upper()
if cui.startswith("RO"):
cui = cui[2:]
elif cui.startswith("R0"): # Fix OCR error R0 → RO
cui = cui[2:]
return ''.join(c for c in cui if c.isdigit())
@property
def rule_name(self) -> str:
return "CUI Checksum Check (Mod 11)"
@@ -400,15 +473,11 @@ class CUIChecksumRule(ValidationRule):
message="No CUI to validate"
)
# Normalize: remove RO/R0 prefix
cui_clean = cui.strip().upper()
if cui_clean.startswith("RO"):
cui_clean = cui_clean[2:]
elif cui_clean.startswith("R0"):
cui_clean = cui_clean[2:]
# Use static method to extract digits
cui_clean = CUIChecksumRule.extract_digits(cui)
# Check format first
if not cui_clean.isdigit():
if not cui_clean:
return ValidationResult(
is_valid=True, # Don't fail checksum if format invalid (handled by CUIFormatRule)
message="CUI format invalid, skipping checksum"
@@ -420,28 +489,15 @@ class CUIChecksumRule(ValidationRule):
message="CUI length invalid, skipping checksum"
)
# Extract digits
digits = [int(d) for d in cui_clean]
checksum_declared = digits[-1]
base_digits = digits[:-1]
# Multipliers (trim to match base_digits length)
multipliers = [7, 5, 3, 2, 1, 7, 5, 3, 2]
multipliers = multipliers[:len(base_digits)]
# Calculate weighted sum
weighted_sum = sum(d * m for d, m in zip(base_digits, multipliers))
# Calculate expected checksum
checksum_calculated = (weighted_sum * 10) % 11
if checksum_calculated == 10:
checksum_calculated = 0
if checksum_calculated != checksum_declared:
# Use static method to validate checksum
if not CUIChecksumRule.validate_checksum(cui_clean):
# Calculate expected for error message
expected = CUIChecksumRule.calculate_checksum(cui_clean[:-1])
declared = int(cui_clean[-1])
return ValidationResult(
is_valid=False,
confidence_penalty=0.3,
message=f"CUI '{cui}' checksum mismatch: expected {checksum_calculated}, got {checksum_declared}",
message=f"CUI '{cui}' checksum mismatch: expected {expected}, got {declared}",
severity="warning"
)
@@ -451,6 +507,129 @@ class CUIChecksumRule(ValidationRule):
)
class TVABasedTotalRule(ValidationRule):
"""Validate TOTAL using reverse calculation from TVA amount.
This is a CRITICAL validation that catches cases where OCR extracts
wrong TOTAL but correct TVA. Since TVA = BASE * rate and TOTAL = BASE + TVA,
we can calculate expected TOTAL from TVA alone.
Formula:
Expected TOTAL = TVA / rate * (1 + rate)
Or equivalently: Expected TOTAL = TVA * (1 + rate) / rate
For TVA rate 21%:
Expected TOTAL = TVA / 0.21 * 1.21 = TVA * 5.7619
Example (benzina 27 oct):
TVA = 49.58, rate = 21%
Expected TOTAL = 49.58 / 0.21 * 1.21 = 285.68
Extracted TOTAL = 205.66 (WRONG!)
Rule detects mismatch and flags for escalation
Usage in multi-tier processing (e.g., doctr_plus):
If this rule fails, the engine should proceed to next tier
instead of returning early with potentially wrong data.
"""
def __init__(self, tolerance_percent: float = 0.02):
"""
Args:
tolerance_percent: Allowed difference as percentage (0.02 = 2%)
"""
self.tolerance_percent = tolerance_percent
@property
def rule_name(self) -> str:
return "TVA-Based Total Check"
def validate(self, data: dict[str, Any]) -> ValidationResult:
total = data.get("amount")
tva = data.get("tva")
tva_entries = data.get("tva_entries", [])
if not total or not tva:
return ValidationResult(
is_valid=True,
message="Insufficient data for TVA-based total validation"
)
# Type safety
try:
total = float(total)
tva = float(tva)
except (TypeError, ValueError):
return ValidationResult(
is_valid=True,
message="Non-numeric values, skipping TVA-based total validation"
)
if tva <= 0 or total <= 0:
return ValidationResult(
is_valid=True,
message="Zero or negative values, skipping TVA-based total validation"
)
# Try to determine TVA rate from entries
tva_rate = None
# Check tva_entries for rate information
if tva_entries:
for entry in tva_entries:
if isinstance(entry, dict):
percent = entry.get('percent')
if percent:
try:
tva_rate = float(percent) / 100.0
break
except (TypeError, ValueError):
pass
# Fallback: try to calculate rate from TVA/TOTAL ratio
if not tva_rate:
# TVA = BASE * rate, TOTAL = BASE + TVA = BASE * (1 + rate)
# TVA/TOTAL = rate / (1 + rate)
# So rate = TVA / (TOTAL - TVA) = TVA / BASE
base = total - tva
if base > 0:
calculated_rate = tva / base
# Validate it's a reasonable Romanian TVA rate (5%, 9%, 19%, 21%)
if 0.04 <= calculated_rate <= 0.25:
tva_rate = calculated_rate
if not tva_rate:
# Assume most common rate: 21%
tva_rate = 0.21
# Calculate expected TOTAL from TVA
# TVA = BASE * rate → BASE = TVA / rate
# TOTAL = BASE + TVA = (TVA / rate) + TVA = TVA * (1 + 1/rate) = TVA * (1 + rate) / rate
expected_total = tva * (1 + tva_rate) / tva_rate
# Calculate difference
diff = abs(total - expected_total)
diff_percent = diff / expected_total if expected_total > 0 else 1.0
if diff_percent > self.tolerance_percent:
# Significant mismatch - OCR likely extracted TOTAL wrong
return ValidationResult(
is_valid=False,
confidence_penalty=0.5, # High penalty - this is a critical error
message=(
f"TOTAL mismatch: Extracted {total:.2f} RON vs "
f"TVA-calculated {expected_total:.2f} RON "
f"(TVA={tva:.2f}, rate={tva_rate:.0%}, diff={diff_percent:.1%}). "
f"Likely OCR error on TOTAL."
),
severity="error"
)
return ValidationResult(
is_valid=True,
message=f"TOTAL {total:.2f} matches TVA-calculated {expected_total:.2f} (diff: {diff_percent:.1%})"
)
class InterOCRConsistencyRule(ValidationRule):
"""Validate consistency between multiple OCR results.
@@ -562,6 +741,7 @@ class OCRValidationEngine:
TVARatioRule(min_ratio=0.05, max_ratio=0.24),
PaymentSumRule(tolerance=0.02),
TVAEntriesSumRule(tolerance=0.02),
TVABasedTotalRule(tolerance_percent=0.02), # Critical: detect TOTAL errors via TVA
]
# Inter-OCR consistency rules
@@ -699,39 +879,508 @@ class OCRValidationEngine:
inter_ocr_ratios=inter_ocr_ratios
)
def quick_validate_for_hybrid(self, extraction_result: dict[str, Any]) -> tuple[bool, float, list[str]]:
"""Quick validation for early-exit decisions (e.g., doctr_plus Tier 1).
Runs critical cross-validation rules to detect obvious OCR errors.
Used to decide whether to proceed to next processing tier or exit early.
Args:
extraction_result: Extraction data dict with fields:
- amount: Extracted TOTAL
- tva: Extracted TVA total
- tva_entries: List of TVA entries with rates
Returns:
Tuple of (passes_validation, confidence_penalty, error_messages)
- passes_validation: True if no critical errors detected
- confidence_penalty: Cumulative penalty (0.0-1.0)
- error_messages: List of validation error messages
Example usage:
passes, penalty, errors = validation_engine.quick_validate_for_hybrid(extraction_data)
if not passes:
print(f"Validation failed: {errors}, proceeding to next tier")
# Continue to next processing tier instead of early exit
"""
errors = []
total_penalty = 0.0
# Critical rules for early-exit decision-making
# These determine if we can trust the extraction or need to proceed to next tier
critical_rules = [
# Cross-field validations (most important for detecting OCR errors)
TVABasedTotalRule(tolerance_percent=0.02), # Critical: detect TOTAL errors via TVA calculation
PaymentSumRule(tolerance=0.05), # Cross-validate TOTAL vs CARD+CASH payments
TVARatioRule(min_ratio=0.05, max_ratio=0.24), # TVA should be 5-24% of TOTAL
TVAEntriesSumRule(tolerance=0.05), # Sum of TVA entries should match TVA total
# Format & checksum validations
CUIChecksumRule(), # Validate CUI/CIF with Romanian Mod11 checksum algorithm
CUIFormatRule(), # CUI should be 6-10 digits
# Sanity checks
AmountRangeRule(min_amount=0.01, max_amount=100_000.0), # Reasonable amount range
]
for rule in critical_rules:
result = rule.validate(extraction_result)
if not result.is_valid:
errors.append(result.message)
total_penalty += result.confidence_penalty
# Cap penalty at 1.0
total_penalty = min(1.0, total_penalty)
passes = len(errors) == 0
return passes, total_penalty, errors
# NOTE: _calculate_cui_checksum and _is_cui_checksum_valid removed
# Use CUIChecksumRule.calculate_checksum() and CUIChecksumRule.validate_checksum() instead
@staticmethod
def _repair_cui_checksum(cui_digits: str) -> Optional[str]:
"""Try to repair CUI by attempting 1-digit corrections.
OCR often misreads similar-looking digits:
- 5 ↔ 8 (most common in receipts)
- 6 ↔ 0
- 1 ↔ 7
- 3 ↔ 8
Algorithm:
1. Check middle positions first (2,3,4,5...) - OCR errors more common there
2. Skip first digit (position 0) - usually reliable in CUI
3. Check checksum digit (last position) last
4. Prefer common OCR digit confusions (5↔8, 6↔0)
Args:
cui_digits: Original CUI digits (without RO prefix)
Returns:
Repaired CUI digits if 1-digit fix found, else None
"""
if len(cui_digits) < 6 or not cui_digits.isdigit():
return None
# If already valid, return as-is
if CUIChecksumRule.validate_checksum(cui_digits):
return cui_digits
# Common OCR digit confusions (try these first)
confusion_pairs = {
'5': ['8', '6'], # 5 often misread as 8 or 6
'8': ['5', '3', '0'], # 8 often misread as 5, 3, or 0
'6': ['0', '8'], # 6 often misread as 0 or 8
'0': ['6', '8'], # 0 often misread as 6 or 8
'1': ['7', '4'], # 1 often misread as 7 or 4
'7': ['1'], # 7 often misread as 1
'3': ['8'], # 3 often misread as 8
'4': ['1'], # 4 often misread as 1
'2': ['7'], # 2 sometimes misread as 7
'9': ['0'], # 9 sometimes misread as 0
}
n = len(cui_digits)
last_pos = n - 1 # checksum position
# Position check order: middle positions first, then position 1, then 0, then checksum
# Skip position 0 (first digit) - it's usually reliable
# Example for 8-digit CUI: [2,3,4,5,6, 1, 7(checksum)]
middle_positions = list(range(2, last_pos)) # positions 2 to n-2
position_order = middle_positions + [1, last_pos, 0] # check pos 0 last (rarely wrong)
for pos in position_order:
if pos >= n:
continue
original_digit = cui_digits[pos]
# Try common confusions first for this digit
candidates = confusion_pairs.get(original_digit, [])
# Then try all other digits
all_digits = [d for d in '0123456789' if d != original_digit and d not in candidates]
for replacement in candidates + all_digits:
candidate = cui_digits[:pos] + replacement + cui_digits[pos+1:]
if CUIChecksumRule.validate_checksum(candidate):
print(f"[CUI Repair] Fixed {cui_digits}{candidate} (position {pos}: {original_digit}{replacement})", flush=True)
return candidate
# No single-digit fix found
return None
@staticmethod
def normalize_cui(cui: Optional[str]) -> Optional[str]:
"""Normalize CUI to RO prefix + digits format.
"""Normalize CUI - fix OCR errors but preserve original format.
Rules:
- R0 → RO (fix OCR error where O is read as 0)
- Keep RO prefix if original had it (platitor TVA)
- Do NOT add RO if original didn't have it (neplatitor TVA)
- Try to repair 1-digit checksum errors (OCR mistakes like 5↔8)
Examples:
10562600 → RO10562600
45417955 → 45417955 (no prefix = neplatitor TVA, keep as-is)
R010562600 → RO10562600 (fix R0 OCR error)
RO10562600 → RO10562600 (unchanged)
RO10862600 → RO10562600 (repaired: 8→5 at position 2)
Args:
cui: Raw CUI string from OCR
Returns:
Normalized CUI with RO prefix, or None if invalid
Normalized CUI, or None if invalid
"""
if not cui:
return None
cui = cui.strip().upper()
# Remove existing prefix if present
# Check if original had RO/R0 prefix
had_ro_prefix = cui.startswith("RO") or cui.startswith("R0")
# Extract digits
if cui.startswith("RO"):
cui = cui[2:]
elif cui.startswith("R0"):
cui = cui[2:]
cui_digits = cui[2:]
elif cui.startswith("R0"): # Fix OCR error R0 → RO
cui_digits = cui[2:]
else:
cui_digits = cui
# Remove any non-digit characters
cui_digits = ''.join(c for c in cui if c.isdigit())
cui_digits = ''.join(c for c in cui_digits if c.isdigit())
# Validate length
if len(cui_digits) < 6 or len(cui_digits) > 10:
print(f"[CUI Normalize] Invalid length: {len(cui_digits)} digits (expected 6-10)", flush=True)
return None
# Add RO prefix
return f"RO{cui_digits}"
# Try to repair checksum if invalid
if not CUIChecksumRule.validate_checksum(cui_digits):
repaired = OCRValidationEngine._repair_cui_checksum(cui_digits)
if repaired:
cui_digits = repaired
# Return with RO prefix only if original had it
if had_ro_prefix:
return f"RO{cui_digits}"
else:
return cui_digits
@staticmethod
async def fuzzy_match_cui_from_db(
cui: Optional[str],
db_session
) -> Optional[tuple[str, str]]:
"""Fuzzy match CUI against database of known suppliers.
This function:
1. Validates CUI checksum
2. If valid, looks up in database (exact match)
3. If invalid, tries 1-digit corrections and looks up each candidate
4. Returns the first match found in database
Args:
cui: Extracted CUI from OCR (may be invalid)
db_session: SQLAlchemy async session for database lookups
Returns:
Tuple of (corrected_cui, supplier_name) if found, else None
Usage in OCR extraction:
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
match = await OCRValidationEngine.fuzzy_match_cui_from_db(extracted_cui, session)
if match:
corrected_cui, supplier_name = match
"""
from sqlalchemy import select, or_
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier
if not cui:
return None
cui = cui.strip().upper()
# Check if original had RO/R0 prefix
had_ro_prefix = cui.startswith("RO") or cui.startswith("R0")
# Extract digits
if cui.startswith("RO"):
cui_digits = cui[2:]
elif cui.startswith("R0"): # Fix OCR error R0 → RO
cui_digits = cui[2:]
else:
cui_digits = cui
# Remove any non-digit characters
cui_digits = ''.join(c for c in cui_digits if c.isdigit())
# Validate length
if len(cui_digits) < 6 or len(cui_digits) > 10:
return None
# Helper to format CUI with optional RO prefix
def format_cui(digits: str) -> str:
if had_ro_prefix:
return f"RO{digits}"
return digits
# Helper to search database for CUI
async def lookup_cui_in_db(digits: str) -> Optional[tuple[str, str]]:
"""Search both synced and local suppliers for CUI."""
# Search patterns: with and without RO prefix
search_patterns = [digits, f"RO{digits}"]
# Search synced_suppliers first (more data)
stmt = select(SyncedSupplier.fiscal_code, SyncedSupplier.name).where(
or_(
SyncedSupplier.fiscal_code == digits,
SyncedSupplier.fiscal_code == f"RO{digits}",
SyncedSupplier.fiscal_code == digits.lstrip('0'), # Handle leading zeros
)
).limit(1)
result = await db_session.execute(stmt)
row = result.first()
if row:
return (format_cui(digits), row.name)
# Search local_suppliers
stmt = select(LocalSupplier.fiscal_code, LocalSupplier.name).where(
or_(
LocalSupplier.fiscal_code == digits,
LocalSupplier.fiscal_code == f"RO{digits}",
LocalSupplier.fiscal_code == digits.lstrip('0'),
)
).limit(1)
result = await db_session.execute(stmt)
row = result.first()
if row:
return (format_cui(digits), row.name)
return None
# 1. If checksum is valid, check if it exists in database (exact match)
if CUIChecksumRule.validate_checksum(cui_digits):
match = await lookup_cui_in_db(cui_digits)
if match:
print(f"[Fuzzy CUI] Exact match found: {cui}{match[0]} ({match[1]})", flush=True)
return match
# Valid checksum but not in DB - return as-is (it might be a new supplier)
return None
# 2. Invalid checksum - try 1-digit corrections and verify against database
print(f"[Fuzzy CUI] Invalid checksum for {cui}, trying corrections...", flush=True)
# Common OCR digit confusions (try these first)
confusion_pairs = {
'5': ['8', '6'], # 5 often misread as 8 or 6
'8': ['5', '3', '0'], # 8 often misread as 5, 3, or 0
'6': ['0', '8'], # 6 often misread as 0 or 8
'0': ['6', '8'], # 0 often misread as 6 or 8
'1': ['7', '4'], # 1 often misread as 7 or 4
'7': ['1'], # 7 often misread as 1
'3': ['8'], # 3 often misread as 8
'4': ['1'], # 4 often misread as 1
'2': ['7'], # 2 sometimes misread as 7
'9': ['0'], # 9 sometimes misread as 0
}
n = len(cui_digits)
last_pos = n - 1 # checksum position
# Position check order: middle positions first, then ends
middle_positions = list(range(2, last_pos))
position_order = middle_positions + [1, last_pos, 0]
for pos in position_order:
if pos >= n:
continue
original_digit = cui_digits[pos]
# Try common confusions first for this digit
candidates = confusion_pairs.get(original_digit, [])
# Then try all other digits
all_digits = [d for d in '0123456789' if d != original_digit and d not in candidates]
for replacement in candidates + all_digits:
candidate = cui_digits[:pos] + replacement + cui_digits[pos+1:]
# Only consider if checksum is valid
if not CUIChecksumRule.validate_checksum(candidate):
continue
# Check if this corrected CUI exists in database
match = await lookup_cui_in_db(candidate)
if match:
print(f"[Fuzzy CUI] DB match: {cui}{match[0]} ({match[1]}) [pos {pos}: {original_digit}{replacement}]", flush=True)
return match
# No match found in database
print(f"[Fuzzy CUI] No database match found for {cui}", flush=True)
return None
@staticmethod
async def fuzzy_match_by_name_and_cui(
vendor_name: Optional[str],
cui: Optional[str],
db_session
) -> Optional[tuple[str, str]]:
"""Fuzzy match supplier by NAME, then narrow down by CUI.
Algorithm:
1. Normalize vendor name (remove S.R.L., S.A., punctuation, etc.)
2. Search suppliers by fuzzy name match (LIKE %name%)
3. If multiple results, use fuzzy CUI matching to pick best one
4. Return the best match
Args:
vendor_name: Extracted vendor name from OCR
cui: Extracted CUI from OCR (may be invalid/incomplete)
db_session: SQLAlchemy async session
Returns:
Tuple of (matched_cui, supplier_name) if found, else None
"""
from sqlalchemy import select, or_, func
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier
import re
if not vendor_name or len(vendor_name) < 3:
return None
# Normalize vendor name for search
def normalize_name(name: str) -> str:
"""Normalize name for fuzzy matching."""
name = name.upper()
# Remove company type suffixes
for suffix in ['S.R.L.', 'SRL', 'S.A.', 'SA', 'S.C.', 'SC', 'I.F.', 'IF', 'P.F.A.', 'PFA']:
name = name.replace(suffix, '')
# Remove punctuation and extra spaces
name = re.sub(r'[.,\-_/\\()"\']', ' ', name)
name = ' '.join(name.split())
return name.strip()
# Extract key words from vendor name (for fuzzy search)
normalized_name = normalize_name(vendor_name)
name_words = [w for w in normalized_name.split() if len(w) >= 3]
if not name_words:
return None
print(f"[Fuzzy Name] Searching for vendor: '{vendor_name}' → keywords: {name_words}", flush=True)
# Build search pattern - use first significant word
primary_word = name_words[0]
search_pattern = f"%{primary_word}%"
candidates = []
# Search synced_suppliers
stmt = select(SyncedSupplier.fiscal_code, SyncedSupplier.name).where(
func.upper(SyncedSupplier.name).like(search_pattern)
).limit(20)
result = await db_session.execute(stmt)
for row in result:
if row.fiscal_code:
candidates.append((row.fiscal_code, row.name))
# Search local_suppliers
stmt = select(LocalSupplier.fiscal_code, LocalSupplier.name).where(
func.upper(LocalSupplier.name).like(search_pattern)
).limit(20)
result = await db_session.execute(stmt)
for row in result:
if row.fiscal_code:
candidates.append((row.fiscal_code, row.name))
if not candidates:
print(f"[Fuzzy Name] No name matches found for '{primary_word}'", flush=True)
return None
print(f"[Fuzzy Name] Found {len(candidates)} name matches for '{primary_word}'", flush=True)
# If only one candidate, return it
if len(candidates) == 1:
print(f"[Fuzzy Name] Single match: {candidates[0][0]} ({candidates[0][1]})", flush=True)
return candidates[0]
# Multiple candidates - try to narrow down by CUI
if cui:
cui_digits = ''.join(c for c in cui.upper().replace('RO', '').replace('R0', '') if c.isdigit())
if len(cui_digits) >= 6:
# Score each candidate by how similar their CUI is to the extracted one
def cui_similarity(candidate_cui: str) -> int:
"""Calculate how many digits match in the same position."""
cand_digits = ''.join(c for c in candidate_cui.upper().replace('RO', '') if c.isdigit())
if len(cand_digits) != len(cui_digits):
return 0
return sum(1 for a, b in zip(cand_digits, cui_digits) if a == b)
# Sort candidates by CUI similarity (descending)
scored = [(cui_similarity(c[0]), c) for c in candidates]
scored.sort(key=lambda x: x[0], reverse=True)
best_score, best_match = scored[0]
# Require at least 70% digit match for CUI similarity
min_matching = int(len(cui_digits) * 0.7)
if best_score >= min_matching:
print(f"[Fuzzy Name] Best CUI match: {best_match[0]} ({best_match[1]}) - score {best_score}/{len(cui_digits)}", flush=True)
return best_match
print(f"[Fuzzy Name] No strong CUI match (best score: {best_score}/{len(cui_digits)})", flush=True)
# If still multiple and no CUI match, try name similarity
def name_similarity(candidate_name: str) -> int:
"""Count how many keywords match."""
norm_cand = normalize_name(candidate_name)
return sum(1 for w in name_words if w in norm_cand)
scored = [(name_similarity(c[1]), c) for c in candidates]
scored.sort(key=lambda x: x[0], reverse=True)
if scored[0][0] >= 2: # At least 2 keywords match
best_match = scored[0][1]
print(f"[Fuzzy Name] Best name match: {best_match[0]} ({best_match[1]})", flush=True)
return best_match
# Return first candidate if nothing else works
print(f"[Fuzzy Name] Returning first candidate: {candidates[0][0]} ({candidates[0][1]})", flush=True)
return candidates[0]
@staticmethod
async def fuzzy_match_supplier(
cui: Optional[str],
vendor_name: Optional[str],
db_session
) -> Optional[tuple[str, str]]:
"""Combined fuzzy matching: try CUI first, then fallback to NAME+CUI.
Strategy:
1. Try fuzzy CUI matching (1-digit corrections with checksum validation)
2. If no CUI match, try fuzzy NAME matching, narrowed by CUI similarity
Args:
cui: Extracted CUI from OCR (may be invalid/incomplete)
vendor_name: Extracted vendor name from OCR
db_session: SQLAlchemy async session
Returns:
Tuple of (matched_cui, supplier_name) if found, else None
"""
# Step 1: Try fuzzy CUI matching
cui_match = await OCRValidationEngine.fuzzy_match_cui_from_db(cui, db_session)
if cui_match:
return cui_match
# Step 2: Fallback to fuzzy NAME + CUI matching
name_match = await OCRValidationEngine.fuzzy_match_by_name_and_cui(
vendor_name, cui, db_session
)
if name_match:
return name_match
return None