feat(ocr): Add docTR OCR engine with metrics infrastructure
Add docTR as primary OCR engine with 2-tier sequential processing, OCR metrics tracking, and simplified engine selection. Features: - docTR OCR engine with light+medium preprocessing tiers - doctr_plus mode with early exit optimization (~65% fast path) - OCR metrics dashboard with per-engine statistics - User OCR preference persistence - Parallel worker pool for OCR processing - Cross-validation for extraction quality Engine options: tesseract, doctr, doctr_plus (recommended), paddleocr 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -13,13 +13,14 @@ Schema:
|
||||
status TEXT NOT NULL, -- pending, processing, completed, failed
|
||||
file_path TEXT NOT NULL, -- Path to uploaded file
|
||||
mime_type TEXT NOT NULL,
|
||||
engine TEXT DEFAULT 'auto',
|
||||
engine TEXT DEFAULT 'doctr_plus',
|
||||
created_at TIMESTAMP,
|
||||
started_at TIMESTAMP,
|
||||
completed_at TIMESTAMP,
|
||||
result_json TEXT, -- JSON extraction result
|
||||
error_message TEXT,
|
||||
processing_time_ms INTEGER,
|
||||
processing_time_ms INTEGER, -- Total job time (started_at to completed_at)
|
||||
ocr_time_ms INTEGER, -- Actual OCR engine processing time
|
||||
created_by TEXT, -- Username
|
||||
original_filename TEXT,
|
||||
expires_at TIMESTAMP
|
||||
@@ -74,17 +75,26 @@ class OCRJob:
|
||||
status: OCRJobStatus
|
||||
file_path: str
|
||||
mime_type: str
|
||||
engine: str = "auto"
|
||||
engine: str = "doctr_plus"
|
||||
created_at: Optional[datetime] = None
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
result_json: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
processing_time_ms: Optional[int] = None
|
||||
processing_time_ms: Optional[int] = None # Total job time (started_at to completed_at)
|
||||
ocr_time_ms: Optional[int] = None # Actual OCR engine processing time
|
||||
created_by: Optional[str] = None
|
||||
original_filename: Optional[str] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
|
||||
@property
|
||||
def queue_wait_ms(self) -> Optional[int]:
|
||||
"""Calculate queue wait time (created_at to started_at)."""
|
||||
if self.created_at and self.started_at:
|
||||
delta = self.started_at - self.created_at
|
||||
return int(delta.total_seconds() * 1000)
|
||||
return None
|
||||
|
||||
@property
|
||||
def result(self) -> Optional[Dict]:
|
||||
"""Parse result_json to dict."""
|
||||
@@ -143,19 +153,27 @@ class OCRJobQueue:
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
file_path TEXT NOT NULL,
|
||||
mime_type TEXT NOT NULL,
|
||||
engine TEXT DEFAULT 'auto',
|
||||
engine TEXT DEFAULT 'doctr_plus',
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
started_at TIMESTAMP,
|
||||
completed_at TIMESTAMP,
|
||||
result_json TEXT,
|
||||
error_message TEXT,
|
||||
processing_time_ms INTEGER,
|
||||
ocr_time_ms INTEGER,
|
||||
created_by TEXT,
|
||||
original_filename TEXT,
|
||||
expires_at TIMESTAMP
|
||||
)
|
||||
''')
|
||||
|
||||
# Migration: add ocr_time_ms column if it doesn't exist
|
||||
try:
|
||||
await db.execute('ALTER TABLE ocr_jobs ADD COLUMN ocr_time_ms INTEGER')
|
||||
logger.info("[OCRJobQueue] Added ocr_time_ms column to existing table")
|
||||
except Exception:
|
||||
pass # Column already exists
|
||||
|
||||
# Index for efficient queue queries
|
||||
await db.execute('''
|
||||
CREATE INDEX IF NOT EXISTS idx_ocr_jobs_status
|
||||
@@ -177,7 +195,7 @@ class OCRJobQueue:
|
||||
self,
|
||||
file_bytes: bytes,
|
||||
mime_type: str,
|
||||
engine: str = "auto",
|
||||
engine: str = "doctr_plus",
|
||||
username: Optional[str] = None,
|
||||
original_filename: Optional[str] = None
|
||||
) -> OCRJob:
|
||||
@@ -189,7 +207,7 @@ class OCRJobQueue:
|
||||
Args:
|
||||
file_bytes: Raw file bytes
|
||||
mime_type: MIME type of file
|
||||
engine: OCR engine ('auto', 'paddleocr', 'tesseract')
|
||||
engine: OCR engine ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
|
||||
username: Username of requester
|
||||
original_filename: Original filename from upload
|
||||
|
||||
@@ -301,24 +319,52 @@ class OCRJobQueue:
|
||||
|
||||
async def get_next_pending(self) -> Optional[OCRJob]:
|
||||
"""
|
||||
Get the next pending job (oldest first).
|
||||
Get the next pending job (oldest first) and atomically mark it as processing.
|
||||
|
||||
This prevents race conditions in parallel processing - only one worker
|
||||
can claim each job.
|
||||
|
||||
Returns:
|
||||
Next OCRJob to process or None if queue empty
|
||||
"""
|
||||
await self.initialize()
|
||||
|
||||
async with aiosqlite.connect(str(self.db_path)) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute('''
|
||||
SELECT * FROM ocr_jobs
|
||||
WHERE status = 'pending'
|
||||
ORDER BY created_at ASC
|
||||
LIMIT 1
|
||||
''') as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return self._row_to_job(row)
|
||||
now = datetime.utcnow()
|
||||
|
||||
async with self._lock: # Serialize access to prevent race conditions
|
||||
async with aiosqlite.connect(str(self.db_path)) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
|
||||
# Get the next pending job
|
||||
async with db.execute('''
|
||||
SELECT * FROM ocr_jobs
|
||||
WHERE status = 'pending'
|
||||
ORDER BY created_at ASC
|
||||
LIMIT 1
|
||||
''') as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
|
||||
job_id = row['id']
|
||||
|
||||
# Atomically mark as processing
|
||||
await db.execute('''
|
||||
UPDATE ocr_jobs
|
||||
SET status = 'processing', started_at = ?
|
||||
WHERE id = ? AND status = 'pending'
|
||||
''', (now.isoformat(), job_id))
|
||||
await db.commit()
|
||||
|
||||
# Fetch the updated job
|
||||
async with db.execute(
|
||||
'SELECT * FROM ocr_jobs WHERE id = ?',
|
||||
(job_id,)
|
||||
) as cursor:
|
||||
updated_row = await cursor.fetchone()
|
||||
if updated_row:
|
||||
return self._row_to_job(updated_row)
|
||||
|
||||
return None
|
||||
|
||||
async def update_status(
|
||||
@@ -327,7 +373,8 @@ class OCRJobQueue:
|
||||
status: OCRJobStatus,
|
||||
result: Optional[Dict] = None,
|
||||
error: Optional[str] = None,
|
||||
processing_time_ms: Optional[int] = None
|
||||
processing_time_ms: Optional[int] = None,
|
||||
ocr_time_ms: Optional[int] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Update job status.
|
||||
@@ -337,7 +384,8 @@ class OCRJobQueue:
|
||||
status: New status
|
||||
result: Extraction result dict (for completed)
|
||||
error: Error message (for failed)
|
||||
processing_time_ms: Processing time
|
||||
processing_time_ms: Total job processing time (started_at to completed_at)
|
||||
ocr_time_ms: Actual OCR engine processing time
|
||||
|
||||
Returns:
|
||||
True if update successful
|
||||
@@ -359,18 +407,18 @@ class OCRJobQueue:
|
||||
elif status == OCRJobStatus.completed:
|
||||
query = '''
|
||||
UPDATE ocr_jobs
|
||||
SET status = ?, completed_at = ?, result_json = ?, processing_time_ms = ?
|
||||
SET status = ?, completed_at = ?, result_json = ?, processing_time_ms = ?, ocr_time_ms = ?
|
||||
WHERE id = ?
|
||||
'''
|
||||
params = (status.value, now.isoformat(), result_json, processing_time_ms, job_id)
|
||||
params = (status.value, now.isoformat(), result_json, processing_time_ms, ocr_time_ms, job_id)
|
||||
|
||||
elif status == OCRJobStatus.failed:
|
||||
query = '''
|
||||
UPDATE ocr_jobs
|
||||
SET status = ?, completed_at = ?, error_message = ?, processing_time_ms = ?
|
||||
SET status = ?, completed_at = ?, error_message = ?, processing_time_ms = ?, ocr_time_ms = ?
|
||||
WHERE id = ?
|
||||
'''
|
||||
params = (status.value, now.isoformat(), error, processing_time_ms, job_id)
|
||||
params = (status.value, now.isoformat(), error, processing_time_ms, ocr_time_ms, job_id)
|
||||
|
||||
else:
|
||||
query = 'UPDATE ocr_jobs SET status = ? WHERE id = ?'
|
||||
@@ -542,13 +590,14 @@ class OCRJobQueue:
|
||||
status=OCRJobStatus(row['status']),
|
||||
file_path=row['file_path'],
|
||||
mime_type=row['mime_type'],
|
||||
engine=row['engine'] or 'auto',
|
||||
engine=row['engine'] or 'doctr_plus',
|
||||
created_at=parse_datetime(row['created_at']),
|
||||
started_at=parse_datetime(row['started_at']),
|
||||
completed_at=parse_datetime(row['completed_at']),
|
||||
result_json=row['result_json'],
|
||||
error_message=row['error_message'],
|
||||
processing_time_ms=row['processing_time_ms'],
|
||||
ocr_time_ms=row['ocr_time_ms'] if 'ocr_time_ms' in row.keys() else None,
|
||||
created_by=row['created_by'],
|
||||
original_filename=row['original_filename'],
|
||||
expires_at=parse_datetime(row['expires_at']),
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
OCR Job Worker - Background Task for Queue Processing
|
||||
|
||||
Runs as an asyncio background task in FastAPI.
|
||||
Continuously polls the job queue and processes OCR requests.
|
||||
Continuously polls the job queue and processes OCR requests IN PARALLEL.
|
||||
|
||||
Architecture:
|
||||
FastAPI startup
|
||||
@@ -12,18 +12,19 @@ Architecture:
|
||||
asyncio.create_task(_job_worker_loop())
|
||||
↓
|
||||
while True:
|
||||
job = job_queue.get_next_pending()
|
||||
if job:
|
||||
result = ocr_worker_pool.submit_task(...)
|
||||
job_queue.update_status(...)
|
||||
await asyncio.sleep(0.5)
|
||||
# Process up to OCR_WORKERS jobs concurrently
|
||||
jobs = get_pending_jobs(limit=available_slots)
|
||||
for job in jobs:
|
||||
asyncio.create_task(_process_job(job))
|
||||
await asyncio.sleep(0.1)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Optional, Set
|
||||
|
||||
from .job_queue import job_queue, OCRJobStatus, OCRJob
|
||||
from .ocr_worker_pool import ocr_worker_pool
|
||||
@@ -34,47 +35,76 @@ logger = logging.getLogger(__name__)
|
||||
_job_worker_task: Optional[asyncio.Task] = None
|
||||
_cleanup_task: Optional[asyncio.Task] = None
|
||||
_shutdown_event: Optional[asyncio.Event] = None
|
||||
_active_tasks: Set[asyncio.Task] = set() # Track active job tasks
|
||||
_concurrency_semaphore: Optional[asyncio.Semaphore] = None # Limit concurrent jobs
|
||||
|
||||
# Configuration
|
||||
POLL_INTERVAL_SECONDS = 0.5 # How often to check for new jobs
|
||||
POLL_INTERVAL_SECONDS = 0.1 # How often to check for new jobs (faster for parallel)
|
||||
CLEANUP_INTERVAL_SECONDS = 3600 # Clean expired jobs every hour
|
||||
OCR_TIMEOUT_SECONDS = 120 # Max time for OCR processing
|
||||
|
||||
|
||||
async def _job_worker_loop() -> None:
|
||||
"""
|
||||
Main worker loop - processes jobs from queue.
|
||||
Main worker loop - processes jobs from queue IN PARALLEL.
|
||||
|
||||
Runs continuously until shutdown. Polls queue every 0.5s
|
||||
and submits jobs to worker pool for processing.
|
||||
Runs continuously until shutdown. Uses semaphore to limit
|
||||
concurrent jobs to OCR_WORKERS count. Launches jobs as
|
||||
background tasks without waiting for completion.
|
||||
"""
|
||||
global _shutdown_event
|
||||
global _shutdown_event, _active_tasks, _concurrency_semaphore
|
||||
|
||||
logger.info("[JobWorker] Starting worker loop...")
|
||||
# Get max concurrent jobs from env (matches worker pool size)
|
||||
max_concurrent = int(os.getenv('OCR_WORKERS', '2'))
|
||||
_concurrency_semaphore = asyncio.Semaphore(max_concurrent)
|
||||
_active_tasks = set()
|
||||
|
||||
logger.info(f"[JobWorker] Starting PARALLEL worker loop (max_concurrent={max_concurrent})...")
|
||||
_shutdown_event = asyncio.Event()
|
||||
|
||||
consecutive_errors = 0
|
||||
max_consecutive_errors = 5
|
||||
max_consecutive_errors = 10
|
||||
|
||||
while not _shutdown_event.is_set():
|
||||
try:
|
||||
# Get next pending job
|
||||
job = await job_queue.get_next_pending()
|
||||
|
||||
if job:
|
||||
consecutive_errors = 0 # Reset error counter on success
|
||||
await _process_job(job)
|
||||
else:
|
||||
# No jobs - wait before polling again
|
||||
# Clean up completed tasks
|
||||
done_tasks = {t for t in _active_tasks if t.done()}
|
||||
for task in done_tasks:
|
||||
_active_tasks.discard(task)
|
||||
# Check for exceptions
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
_shutdown_event.wait(),
|
||||
timeout=POLL_INTERVAL_SECONDS
|
||||
)
|
||||
if _shutdown_event.is_set():
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
pass # Normal timeout, continue loop
|
||||
task.result()
|
||||
except Exception as e:
|
||||
logger.error(f"[JobWorker] Task failed: {e}")
|
||||
|
||||
# Check if we have capacity for more jobs
|
||||
active_count = len(_active_tasks)
|
||||
available_slots = max_concurrent - active_count
|
||||
|
||||
if available_slots > 0:
|
||||
# Get next pending job
|
||||
job = await job_queue.get_next_pending()
|
||||
|
||||
if job:
|
||||
consecutive_errors = 0
|
||||
# Launch job processing as background task
|
||||
task = asyncio.create_task(_process_job_with_semaphore(job))
|
||||
_active_tasks.add(task)
|
||||
logger.debug(f"[JobWorker] Launched job {job.id} (active={len(_active_tasks)}/{max_concurrent})")
|
||||
else:
|
||||
# No pending jobs - wait briefly
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
_shutdown_event.wait(),
|
||||
timeout=POLL_INTERVAL_SECONDS
|
||||
)
|
||||
if _shutdown_event.is_set():
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
else:
|
||||
# At capacity - wait for a slot to free up
|
||||
await asyncio.sleep(POLL_INTERVAL_SECONDS)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("[JobWorker] Worker loop cancelled")
|
||||
@@ -88,27 +118,46 @@ async def _job_worker_loop() -> None:
|
||||
logger.error("[JobWorker] Too many consecutive errors, stopping worker")
|
||||
break
|
||||
|
||||
# Backoff on errors
|
||||
await asyncio.sleep(min(consecutive_errors * 2, 30))
|
||||
|
||||
# Wait for active tasks to complete on shutdown
|
||||
if _active_tasks:
|
||||
logger.info(f"[JobWorker] Waiting for {len(_active_tasks)} active tasks to complete...")
|
||||
await asyncio.gather(*_active_tasks, return_exceptions=True)
|
||||
|
||||
logger.info("[JobWorker] Worker loop stopped")
|
||||
|
||||
|
||||
async def _process_job_with_semaphore(job: OCRJob) -> None:
|
||||
"""
|
||||
Process job with semaphore to limit concurrency.
|
||||
|
||||
Acquires semaphore before processing, releases after.
|
||||
This ensures we don't exceed OCR_WORKERS concurrent jobs.
|
||||
"""
|
||||
global _concurrency_semaphore
|
||||
|
||||
async with _concurrency_semaphore:
|
||||
await _process_job(job)
|
||||
|
||||
|
||||
async def _process_job(job: OCRJob) -> None:
|
||||
"""
|
||||
Process a single OCR job.
|
||||
|
||||
Reads file, submits to worker pool, updates job status.
|
||||
Reads file, submits to worker pool, updates job status,
|
||||
and saves metrics for analytics.
|
||||
|
||||
Args:
|
||||
job: OCRJob to process
|
||||
"""
|
||||
logger.info(f"[JobWorker] Processing job {job.id}: engine={job.engine}, file={Path(job.file_path).name}")
|
||||
start_time = time.time()
|
||||
file_size = 0
|
||||
file_type = "image/jpeg"
|
||||
|
||||
try:
|
||||
# Mark as processing
|
||||
await job_queue.update_status(job.id, OCRJobStatus.processing)
|
||||
# Note: Job already marked as 'processing' atomically in get_next_pending()
|
||||
|
||||
# Read file bytes
|
||||
file_path = Path(job.file_path)
|
||||
@@ -118,6 +167,10 @@ async def _process_job(job: OCRJob) -> None:
|
||||
with open(file_path, 'rb') as f:
|
||||
file_bytes = f.read()
|
||||
|
||||
file_size = len(file_bytes)
|
||||
# Determine file type from job or extension
|
||||
file_type = getattr(job, 'mime_type', 'image/jpeg') or 'image/jpeg'
|
||||
|
||||
# Submit to worker pool
|
||||
result = await ocr_worker_pool.submit_task(
|
||||
image_bytes=file_bytes,
|
||||
@@ -132,14 +185,43 @@ async def _process_job(job: OCRJob) -> None:
|
||||
# Job completed successfully
|
||||
extraction = result.get("extraction", {})
|
||||
|
||||
# Include raw_texts for analysis (from all OCR engine passes)
|
||||
extraction['raw_texts'] = result.get("raw_texts", [])
|
||||
|
||||
# Extract actual OCR processing time from extraction result
|
||||
ocr_time_ms = extraction.get('processing_time_ms', 0)
|
||||
|
||||
# Debug: log suggested_payment_mode
|
||||
spm = extraction.get('suggested_payment_mode')
|
||||
logger.info(f"[JobWorker] Job {job.id} extraction has suggested_payment_mode={spm}")
|
||||
|
||||
await job_queue.update_status(
|
||||
job_id=job.id,
|
||||
status=OCRJobStatus.completed,
|
||||
result=extraction,
|
||||
processing_time_ms=elapsed_ms
|
||||
processing_time_ms=elapsed_ms,
|
||||
ocr_time_ms=ocr_time_ms
|
||||
)
|
||||
|
||||
logger.info(f"[JobWorker] Job {job.id} completed in {elapsed_ms}ms")
|
||||
logger.info(f"[JobWorker] Job {job.id} completed in {elapsed_ms}ms (ocr: {ocr_time_ms}ms)")
|
||||
|
||||
# Save metrics for successful job
|
||||
await _save_job_metrics(
|
||||
job_id=job.id,
|
||||
username=job.created_by or 'unknown',
|
||||
engine_requested=job.engine,
|
||||
engine_used=extraction.get('ocr_engine', job.engine),
|
||||
processing_time_ms=elapsed_ms,
|
||||
file_size_bytes=file_size,
|
||||
file_type=file_type,
|
||||
original_filename=job.original_filename,
|
||||
success=True,
|
||||
overall_confidence=extraction.get('overall_confidence', 0.0),
|
||||
fields_extracted=_count_extracted_fields(extraction),
|
||||
needs_manual_review=extraction.get('needs_manual_review'),
|
||||
validation_warnings_count=len(extraction.get('validation_warnings', [])),
|
||||
validation_errors_count=len(extraction.get('validation_errors', [])),
|
||||
)
|
||||
|
||||
else:
|
||||
# Job failed
|
||||
@@ -154,6 +236,20 @@ async def _process_job(job: OCRJob) -> None:
|
||||
|
||||
logger.warning(f"[JobWorker] Job {job.id} failed after {elapsed_ms}ms: {error_msg}")
|
||||
|
||||
# Save metrics for failed job
|
||||
await _save_job_metrics(
|
||||
job_id=job.id,
|
||||
username=job.created_by or 'unknown',
|
||||
engine_requested=job.engine,
|
||||
engine_used=job.engine,
|
||||
processing_time_ms=elapsed_ms,
|
||||
file_size_bytes=file_size,
|
||||
file_type=file_type,
|
||||
original_filename=job.original_filename,
|
||||
success=False,
|
||||
error_message=error_msg,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
elapsed_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
@@ -166,6 +262,20 @@ async def _process_job(job: OCRJob) -> None:
|
||||
processing_time_ms=elapsed_ms
|
||||
)
|
||||
|
||||
# Save metrics for error job
|
||||
await _save_job_metrics(
|
||||
job_id=job.id,
|
||||
username=job.created_by or 'unknown',
|
||||
engine_requested=job.engine,
|
||||
engine_used=job.engine,
|
||||
processing_time_ms=elapsed_ms,
|
||||
file_size_bytes=file_size,
|
||||
file_type=file_type,
|
||||
original_filename=job.original_filename,
|
||||
success=False,
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
finally:
|
||||
# Cleanup file after processing
|
||||
try:
|
||||
@@ -340,3 +450,96 @@ def estimate_wait_time(queue_position: int) -> int:
|
||||
|
||||
# Estimate: position * average_time
|
||||
return int(queue_position * avg_time)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Metrics Helper Functions
|
||||
# ============================================================================
|
||||
|
||||
async def _save_job_metrics(
|
||||
job_id: str,
|
||||
username: str,
|
||||
engine_requested: str,
|
||||
engine_used: str,
|
||||
processing_time_ms: int = 0,
|
||||
file_size_bytes: int = 0,
|
||||
file_type: str = "image/jpeg",
|
||||
original_filename: Optional[str] = None,
|
||||
success: bool = True,
|
||||
error_message: Optional[str] = None,
|
||||
overall_confidence: float = 0.0,
|
||||
fields_extracted: int = 0,
|
||||
needs_manual_review: Optional[bool] = None,
|
||||
validation_warnings_count: int = 0,
|
||||
validation_errors_count: int = 0,
|
||||
) -> None:
|
||||
"""
|
||||
Save OCR job metrics to database for analytics.
|
||||
|
||||
Called after each job completes (success or failure).
|
||||
Errors are logged but don't affect job processing.
|
||||
"""
|
||||
try:
|
||||
from backend.modules.data_entry.db.database import get_db_session
|
||||
from backend.modules.data_entry.db.crud.ocr_settings import OCRMetricsCRUD
|
||||
|
||||
async with await get_db_session() as session:
|
||||
await OCRMetricsCRUD.create(
|
||||
session=session,
|
||||
job_id=job_id,
|
||||
username=username,
|
||||
engine_requested=engine_requested,
|
||||
engine_used=engine_used,
|
||||
processing_time_ms=processing_time_ms,
|
||||
file_size_bytes=file_size_bytes,
|
||||
file_type=file_type,
|
||||
original_filename=original_filename,
|
||||
success=success,
|
||||
error_message=error_message,
|
||||
overall_confidence=overall_confidence,
|
||||
fields_extracted=fields_extracted,
|
||||
needs_manual_review=needs_manual_review,
|
||||
validation_warnings_count=validation_warnings_count,
|
||||
validation_errors_count=validation_errors_count,
|
||||
)
|
||||
logger.debug(f"[JobWorker] Saved metrics for job {job_id}")
|
||||
|
||||
except Exception as e:
|
||||
# Log but don't fail - metrics are nice-to-have
|
||||
logger.warning(f"[JobWorker] Failed to save metrics for job {job_id}: {e}")
|
||||
|
||||
|
||||
def _count_extracted_fields(extraction: dict) -> int:
|
||||
"""
|
||||
Count number of successfully extracted fields from OCR result.
|
||||
|
||||
Counts non-None values in key fields.
|
||||
"""
|
||||
key_fields = [
|
||||
'receipt_number',
|
||||
'receipt_date',
|
||||
'amount',
|
||||
'partner_name',
|
||||
'cui',
|
||||
'tva_total',
|
||||
'address',
|
||||
'items_count',
|
||||
]
|
||||
|
||||
count = 0
|
||||
for field in key_fields:
|
||||
value = extraction.get(field)
|
||||
if value is not None and value != '' and value != []:
|
||||
count += 1
|
||||
|
||||
# Also count TVA entries if present
|
||||
tva_entries = extraction.get('tva_entries', [])
|
||||
if tva_entries and len(tva_entries) > 0:
|
||||
count += 1
|
||||
|
||||
# Count payment methods if present
|
||||
payment_methods = extraction.get('payment_methods', [])
|
||||
if payment_methods and len(payment_methods) > 0:
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
"""
|
||||
OCR Worker Pool Manager
|
||||
|
||||
Manages a ProcessPoolExecutor with persistent PaddleOCR initialization.
|
||||
Manages a ProcessPoolExecutor with persistent OCR engine initialization.
|
||||
Key features:
|
||||
- ProcessPoolExecutor with max_workers=1 (sequential, no memory leak)
|
||||
- ProcessPoolExecutor with configurable max_workers (from OCR_WORKERS env)
|
||||
- Configurable max_tasks_per_child (from OCR_MAX_TASKS_PER_CHILD env, 0=no restart)
|
||||
- mp_context='spawn' for Windows IIS compatibility
|
||||
- PaddleOCR loaded ONCE at worker spawn (not 30s per request)
|
||||
- docTR/PaddleOCR loaded ONCE at worker spawn (not 30s per request)
|
||||
- atexit + signal handlers for cleanup
|
||||
- Health check with auto-respawn
|
||||
- Orphan process cleanup on Windows
|
||||
@@ -29,7 +30,7 @@ import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import ProcessPoolExecutor, Future
|
||||
from concurrent.futures import ProcessPoolExecutor, Future, ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
@@ -48,8 +49,8 @@ class OCRWorkerPool:
|
||||
"""
|
||||
Singleton manager for OCR ProcessPoolExecutor.
|
||||
|
||||
Ensures PaddleOCR is loaded once and reused for all requests.
|
||||
Uses max_tasks_per_child=None to keep worker alive indefinitely.
|
||||
Ensures OCR engines are loaded once and reused for all requests.
|
||||
Uses max_tasks_per_child=5 to restart worker every 5 tasks (prevents memory leak).
|
||||
"""
|
||||
|
||||
_instance: Optional["OCRWorkerPool"] = None
|
||||
@@ -86,7 +87,7 @@ class OCRWorkerPool:
|
||||
Initialize the ProcessPoolExecutor.
|
||||
|
||||
Creates executor with spawn context for Windows compatibility.
|
||||
Uses max_tasks_per_child=None to keep worker alive (persistent PaddleOCR).
|
||||
Uses max_tasks_per_child=5 to restart worker periodically (prevents memory leak).
|
||||
|
||||
Returns:
|
||||
True if initialization successful
|
||||
@@ -103,18 +104,30 @@ class OCRWorkerPool:
|
||||
# Cleanup any orphan workers from previous runs
|
||||
self._cleanup_orphan_workers()
|
||||
|
||||
# Read configuration from environment
|
||||
max_workers = int(os.getenv('OCR_WORKERS', '2'))
|
||||
max_tasks_raw = os.getenv('OCR_MAX_TASKS_PER_CHILD', '0')
|
||||
# 0 means no restart (None in ProcessPoolExecutor)
|
||||
max_tasks_per_child = int(max_tasks_raw) if max_tasks_raw and int(max_tasks_raw) > 0 else None
|
||||
|
||||
# Create executor with spawn context (Windows compatible)
|
||||
# Use mp_context='spawn' explicitly for cross-platform consistency
|
||||
mp_context = mp.get_context('spawn')
|
||||
|
||||
self._executor = ProcessPoolExecutor(
|
||||
max_workers=1, # Single worker for sequential processing
|
||||
mp_context=mp_context,
|
||||
initializer=_worker_initializer,
|
||||
max_tasks_per_child=None, # Keep worker alive indefinitely
|
||||
)
|
||||
# max_tasks_per_child only available in Python 3.11+
|
||||
executor_kwargs = {
|
||||
'max_workers': max_workers,
|
||||
'mp_context': mp_context,
|
||||
'initializer': _worker_initializer,
|
||||
}
|
||||
if sys.version_info >= (3, 11) and max_tasks_per_child is not None:
|
||||
executor_kwargs['max_tasks_per_child'] = max_tasks_per_child
|
||||
else:
|
||||
logger.info(f"[OCRWorkerPool] max_tasks_per_child not supported (Python {sys.version_info.major}.{sys.version_info.minor})")
|
||||
|
||||
logger.info("[OCRWorkerPool] ProcessPoolExecutor created (spawn context, max_workers=1)")
|
||||
self._executor = ProcessPoolExecutor(**executor_kwargs)
|
||||
|
||||
logger.info(f"[OCRWorkerPool] ProcessPoolExecutor created (spawn context, max_workers={max_workers}, max_tasks_per_child={max_tasks_per_child})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -173,7 +186,7 @@ class OCRWorkerPool:
|
||||
async def submit_task(
|
||||
self,
|
||||
image_bytes: bytes,
|
||||
engine: str = "auto",
|
||||
engine: str = "doctr_plus",
|
||||
preprocessing: str = "auto",
|
||||
timeout: float = 120.0
|
||||
) -> dict:
|
||||
@@ -182,7 +195,7 @@ class OCRWorkerPool:
|
||||
|
||||
Args:
|
||||
image_bytes: Raw image bytes
|
||||
engine: OCR engine ('auto', 'paddleocr', 'tesseract')
|
||||
engine: OCR engine ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
|
||||
preprocessing: Preprocessing mode ('light', 'medium', 'heavy', 'auto')
|
||||
timeout: Maximum processing time in seconds
|
||||
|
||||
@@ -339,6 +352,7 @@ class OCRWorkerPool:
|
||||
# Global engines - persist between tasks in worker process
|
||||
_paddle_engine = None
|
||||
_tesseract_engine = None
|
||||
_doctr_engine = None # docTR engine (PyTorch backend)
|
||||
_worker_initialized = False
|
||||
|
||||
|
||||
@@ -346,40 +360,92 @@ def _worker_initializer() -> None:
|
||||
"""
|
||||
Called once when worker process spawns.
|
||||
|
||||
Initializes global OCR engines that persist between tasks.
|
||||
This is where PaddleOCR loading happens (15-20 seconds).
|
||||
Initializes global OCR engines IN PARALLEL for faster startup.
|
||||
Uses ThreadPoolExecutor to load enabled engines concurrently.
|
||||
Respects OCR_ENABLE_PADDLEOCR and OCR_ENABLE_TESSERACT from .env.
|
||||
|
||||
Total warmup time = max(engine_times) instead of sum(engine_times).
|
||||
"""
|
||||
global _paddle_engine, _tesseract_engine, _worker_initialized
|
||||
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
|
||||
|
||||
if _worker_initialized:
|
||||
print(f"[Worker {os.getpid()}] Already initialized", flush=True)
|
||||
return
|
||||
|
||||
print(f"[Worker {os.getpid()}] Initializing OCR engines...", flush=True)
|
||||
# Check which engines are enabled via .env
|
||||
paddle_enabled = os.getenv("OCR_ENABLE_PADDLEOCR", "true").lower() == "true"
|
||||
tesseract_enabled = os.getenv("OCR_ENABLE_TESSERACT", "true").lower() == "true"
|
||||
|
||||
enabled_engines = ["doctr"] # docTR is always loaded (primary engine)
|
||||
if paddle_enabled:
|
||||
enabled_engines.append("paddle")
|
||||
if tesseract_enabled:
|
||||
enabled_engines.append("tesseract")
|
||||
|
||||
print(f"[Worker {os.getpid()}] Initializing OCR engines: {enabled_engines}", flush=True)
|
||||
if not paddle_enabled:
|
||||
print(f"[Worker {os.getpid()}] PaddleOCR DISABLED - saving ~800MB RAM", flush=True)
|
||||
if not tesseract_enabled:
|
||||
print(f"[Worker {os.getpid()}] Tesseract DISABLED - saving ~50MB RAM", flush=True)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Initialize PaddleOCR
|
||||
try:
|
||||
# Import inside worker to avoid import issues in main process
|
||||
from backend.modules.data_entry.services.ocr.ocr_worker_process import initialize_paddle_engine
|
||||
_paddle_engine = initialize_paddle_engine()
|
||||
print(f"[Worker {os.getpid()}] PaddleOCR loaded", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[Worker {os.getpid()}] PaddleOCR init failed: {e}", flush=True)
|
||||
_paddle_engine = None
|
||||
# Define loader functions - each runs in its own thread
|
||||
def load_doctr():
|
||||
try:
|
||||
from backend.modules.data_entry.services.ocr.ocr_worker_process import initialize_doctr_engine
|
||||
engine = initialize_doctr_engine()
|
||||
return ("doctr", engine, None)
|
||||
except Exception as e:
|
||||
return ("doctr", None, str(e))
|
||||
|
||||
# Initialize Tesseract
|
||||
try:
|
||||
from backend.modules.data_entry.services.ocr.tesseract_engine import TesseractEngine
|
||||
_tesseract_engine = TesseractEngine()
|
||||
print(f"[Worker {os.getpid()}] Tesseract loaded", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[Worker {os.getpid()}] Tesseract init failed: {e}", flush=True)
|
||||
_tesseract_engine = None
|
||||
def load_paddle():
|
||||
if not paddle_enabled:
|
||||
return ("paddle", None, "disabled via OCR_ENABLE_PADDLEOCR=false")
|
||||
try:
|
||||
from backend.modules.data_entry.services.ocr.ocr_worker_process import initialize_paddle_engine
|
||||
engine = initialize_paddle_engine()
|
||||
return ("paddle", engine, None)
|
||||
except Exception as e:
|
||||
return ("paddle", None, str(e))
|
||||
|
||||
def load_tesseract():
|
||||
if not tesseract_enabled:
|
||||
return ("tesseract", None, "disabled via OCR_ENABLE_TESSERACT=false")
|
||||
try:
|
||||
from backend.modules.data_entry.services.ocr.tesseract_engine import TesseractEngine
|
||||
engine = TesseractEngine()
|
||||
return ("tesseract", engine, None)
|
||||
except Exception as e:
|
||||
return ("tesseract", None, str(e))
|
||||
|
||||
# Build list of futures for enabled engines only
|
||||
futures_to_submit = [load_doctr] # docTR always loaded
|
||||
if paddle_enabled:
|
||||
futures_to_submit.append(load_paddle)
|
||||
if tesseract_enabled:
|
||||
futures_to_submit.append(load_tesseract)
|
||||
|
||||
# Load engines in parallel using ThreadPoolExecutor
|
||||
with ThreadPoolExecutor(max_workers=len(futures_to_submit)) as executor:
|
||||
futures = [executor.submit(fn) for fn in futures_to_submit]
|
||||
|
||||
for future in as_completed(futures):
|
||||
name, engine, error = future.result()
|
||||
if error and "disabled" not in error:
|
||||
print(f"[Worker {os.getpid()}] {name} init failed: {error}", flush=True)
|
||||
elif engine:
|
||||
print(f"[Worker {os.getpid()}] {name} loaded", flush=True)
|
||||
if name == "doctr":
|
||||
_doctr_engine = engine
|
||||
elif name == "paddle":
|
||||
_paddle_engine = engine
|
||||
elif name == "tesseract":
|
||||
_tesseract_engine = engine
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
_worker_initialized = True
|
||||
print(f"[Worker {os.getpid()}] Initialization complete in {elapsed:.1f}s", flush=True)
|
||||
print(f"[Worker {os.getpid()}] Initialization complete in {elapsed:.1f}s (engines: {enabled_engines})", flush=True)
|
||||
|
||||
|
||||
def _warmup_task() -> dict:
|
||||
@@ -389,7 +455,7 @@ def _warmup_task() -> dict:
|
||||
Called at FastAPI startup to pre-warm the worker.
|
||||
Returns success status and worker PID.
|
||||
"""
|
||||
global _paddle_engine, _tesseract_engine, _worker_initialized
|
||||
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
|
||||
|
||||
try:
|
||||
# Ensure initialization
|
||||
@@ -400,6 +466,14 @@ def _warmup_task() -> dict:
|
||||
import numpy as np
|
||||
dummy_img = np.ones((100, 100, 3), dtype=np.uint8) * 255
|
||||
|
||||
# Test docTR if available (fastest engine)
|
||||
if _doctr_engine is not None:
|
||||
try:
|
||||
_doctr_engine([dummy_img])
|
||||
print(f"[Worker {os.getpid()}] docTR warmup OK", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[Worker {os.getpid()}] docTR warmup error: {e}", flush=True)
|
||||
|
||||
# Test PaddleOCR if available
|
||||
if _paddle_engine is not None:
|
||||
try:
|
||||
@@ -414,6 +488,7 @@ def _warmup_task() -> dict:
|
||||
return {
|
||||
"success": True,
|
||||
"pid": os.getpid(),
|
||||
"doctr_available": _doctr_engine is not None,
|
||||
"paddle_available": _paddle_engine is not None,
|
||||
"tesseract_available": _tesseract_engine is not None
|
||||
}
|
||||
@@ -428,7 +503,7 @@ def _warmup_task() -> dict:
|
||||
|
||||
def _process_ocr_task(
|
||||
image_bytes: bytes,
|
||||
engine: str = "auto",
|
||||
engine: str = "doctr_plus",
|
||||
preprocessing: str = "auto"
|
||||
) -> dict:
|
||||
"""
|
||||
@@ -439,13 +514,13 @@ def _process_ocr_task(
|
||||
|
||||
Args:
|
||||
image_bytes: Raw image bytes
|
||||
engine: OCR engine choice
|
||||
engine: OCR engine choice ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
|
||||
preprocessing: Preprocessing mode
|
||||
|
||||
Returns:
|
||||
Dict with extraction results
|
||||
"""
|
||||
global _paddle_engine, _tesseract_engine, _worker_initialized
|
||||
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
|
||||
|
||||
try:
|
||||
# Ensure initialization
|
||||
@@ -461,7 +536,8 @@ def _process_ocr_task(
|
||||
paddle_engine=_paddle_engine,
|
||||
tesseract_engine=_tesseract_engine,
|
||||
engine=engine,
|
||||
preprocessing=preprocessing
|
||||
preprocessing=preprocessing,
|
||||
doctr_engine=_doctr_engine
|
||||
)
|
||||
|
||||
# Cleanup after each task
|
||||
|
||||
@@ -6,6 +6,7 @@ Handles OCR processing with persistent engine instances.
|
||||
|
||||
Key features:
|
||||
- PaddleOCR initialized ONCE at process spawn
|
||||
- docTR initialized ONCE at process spawn (PyTorch backend)
|
||||
- Tesseract as fallback/complement engine
|
||||
- Multi-pass preprocessing (light → medium → tesseract)
|
||||
- Automatic engine selection based on results
|
||||
@@ -26,6 +27,13 @@ import numpy as np
|
||||
# Disable PaddleOCR model source check for faster startup
|
||||
os.environ['PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK'] = 'True'
|
||||
|
||||
# Memory optimization for docTR (prevents memory leak in multiprocessing)
|
||||
# Source: https://github.com/mindee/doctr/issues/1594
|
||||
os.environ['DOCTR_MULTIPROCESSING_DISABLE'] = 'TRUE'
|
||||
|
||||
# Reduce Intel oneDNN cache to save memory
|
||||
os.environ['ONEDNN_PRIMITIVE_CACHE_CAPACITY'] = '1'
|
||||
|
||||
|
||||
@dataclass
|
||||
class OCRResult:
|
||||
@@ -71,25 +79,67 @@ def initialize_paddle_engine():
|
||||
return None
|
||||
|
||||
|
||||
def initialize_doctr_engine():
|
||||
"""
|
||||
Initialize docTR engine (CPU only).
|
||||
|
||||
Called once at worker spawn. Returns the engine instance
|
||||
that will be reused for all subsequent requests.
|
||||
|
||||
Note: DirectML (AMD GPU) has compatibility issues with docTR.
|
||||
CUDA (NVIDIA) works but requires separate PyTorch build.
|
||||
CPU mode is stable and well-optimized.
|
||||
|
||||
Returns:
|
||||
docTR predictor instance or None if unavailable
|
||||
"""
|
||||
try:
|
||||
print(f"[Worker {os.getpid()}] Loading docTR (PyTorch backend, CPU)...", flush=True)
|
||||
start_time = time.time()
|
||||
|
||||
from doctr.models import ocr_predictor
|
||||
|
||||
# Initialize docTR predictor with pretrained models
|
||||
# Uses db_resnet50 for detection and crnn_vgg16_bn for recognition
|
||||
doctr = ocr_predictor(
|
||||
det_arch='db_resnet50',
|
||||
reco_arch='crnn_vgg16_bn',
|
||||
pretrained=True,
|
||||
assume_straight_pages=True,
|
||||
straighten_pages=False,
|
||||
preserve_aspect_ratio=True,
|
||||
)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
print(f"[Worker {os.getpid()}] docTR loaded in {elapsed:.1f}s", flush=True)
|
||||
return doctr
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Worker {os.getpid()}] docTR init failed: {e}", flush=True)
|
||||
return None
|
||||
|
||||
|
||||
def process_ocr(
|
||||
image_bytes: bytes,
|
||||
paddle_engine,
|
||||
tesseract_engine,
|
||||
engine: str = "auto",
|
||||
preprocessing: str = "auto"
|
||||
engine: str = "doctr_plus",
|
||||
preprocessing: str = "auto",
|
||||
doctr_engine=None
|
||||
) -> dict:
|
||||
"""
|
||||
Process OCR on image bytes.
|
||||
|
||||
Main entry point for OCR processing in worker process.
|
||||
Uses adaptive multi-pass strategy for best results.
|
||||
Uses the specified engine for text recognition.
|
||||
|
||||
Args:
|
||||
image_bytes: Raw image bytes (JPEG, PNG, or PDF)
|
||||
paddle_engine: Pre-initialized PaddleOCR instance (or None)
|
||||
tesseract_engine: Pre-initialized TesseractEngine instance (or None)
|
||||
engine: Engine selection ('auto', 'paddleocr', 'tesseract')
|
||||
engine: Engine selection ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
|
||||
preprocessing: Preprocessing mode ('auto', 'light', 'medium', 'heavy')
|
||||
doctr_engine: Pre-initialized docTR instance (or None)
|
||||
|
||||
Returns:
|
||||
Dict with extraction results:
|
||||
@@ -101,14 +151,20 @@ def process_ocr(
|
||||
"ocr_engine": str
|
||||
}
|
||||
"""
|
||||
import sys
|
||||
start_time = time.time()
|
||||
print(f"[Worker {os.getpid()}] Processing OCR: engine={engine}, preprocessing={preprocessing}, size={len(image_bytes)} bytes", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
try:
|
||||
# Decode image from bytes
|
||||
print(f"[Worker {os.getpid()}] Decoding image...", flush=True)
|
||||
sys.stdout.flush()
|
||||
image = _decode_image(image_bytes)
|
||||
if image is None:
|
||||
return {"success": False, "error": "Failed to decode image"}
|
||||
print(f"[Worker {os.getpid()}] Image decoded: shape={image.shape}, dtype={image.dtype}, size={image.nbytes/1024/1024:.1f}MB", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
# Import preprocessor
|
||||
from backend.modules.data_entry.services.image_preprocessor import ImagePreprocessor
|
||||
@@ -116,22 +172,36 @@ def process_ocr(
|
||||
|
||||
preprocessor = ImagePreprocessor()
|
||||
extractor = ReceiptExtractor()
|
||||
print(f"[Worker {os.getpid()}] Preprocessor and extractor initialized", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
raw_texts = []
|
||||
extraction = None
|
||||
|
||||
# Engine routing
|
||||
if engine == "paddleocr":
|
||||
extraction, raw_texts = _process_paddleocr_only(
|
||||
image, paddle_engine, preprocessor, extractor
|
||||
)
|
||||
elif engine == "tesseract":
|
||||
# Engine routing (available: tesseract, doctr, doctr_plus, paddleocr)
|
||||
print(f"[Worker {os.getpid()}] Routing to engine: {engine}", flush=True)
|
||||
sys.stdout.flush()
|
||||
if engine == "tesseract":
|
||||
extraction, raw_texts = _process_tesseract_only(
|
||||
image, tesseract_engine, preprocessor, extractor
|
||||
)
|
||||
else: # auto
|
||||
extraction, raw_texts = _process_adaptive(
|
||||
image, paddle_engine, tesseract_engine, preprocessor, extractor
|
||||
elif engine == "doctr":
|
||||
extraction, raw_texts = _process_doctr_only(
|
||||
image, doctr_engine, preprocessor, extractor
|
||||
)
|
||||
elif engine == "doctr_plus":
|
||||
extraction, raw_texts = _process_doctr_plus(
|
||||
image, doctr_engine, preprocessor, extractor
|
||||
)
|
||||
elif engine == "paddleocr":
|
||||
extraction, raw_texts = _process_paddleocr_only(
|
||||
image, paddle_engine, preprocessor, extractor
|
||||
)
|
||||
else:
|
||||
# Default to doctr_plus if unknown engine specified
|
||||
print(f"[OCR] Unknown engine '{engine}', defaulting to doctr_plus", flush=True)
|
||||
extraction, raw_texts = _process_doctr_plus(
|
||||
image, doctr_engine, preprocessor, extractor
|
||||
)
|
||||
|
||||
# Calculate processing time
|
||||
@@ -171,7 +241,11 @@ def process_ocr(
|
||||
|
||||
|
||||
def _decode_image(image_bytes: bytes) -> Optional[np.ndarray]:
|
||||
"""Decode image from bytes (JPEG, PNG, or first page of PDF)."""
|
||||
"""Decode image from bytes (JPEG, PNG, or first page of PDF).
|
||||
|
||||
For PDFs, uses 200 DPI which is sufficient for receipt OCR
|
||||
and reduces processing time by ~50% vs 300 DPI.
|
||||
"""
|
||||
try:
|
||||
# Try as regular image first
|
||||
nparr = np.frombuffer(image_bytes, np.uint8)
|
||||
@@ -180,18 +254,21 @@ def _decode_image(image_bytes: bytes) -> Optional[np.ndarray]:
|
||||
if image is not None:
|
||||
return image
|
||||
|
||||
# Try as PDF
|
||||
# Try as PDF - use 200 DPI for faster processing (sufficient for receipts)
|
||||
try:
|
||||
import pdf2image
|
||||
from PIL import Image
|
||||
|
||||
images = pdf2image.convert_from_bytes(image_bytes, dpi=300)
|
||||
# 200 DPI is sufficient for receipt text recognition
|
||||
# 300 DPI was overkill and slowed down processing
|
||||
images = pdf2image.convert_from_bytes(image_bytes, dpi=200)
|
||||
if images:
|
||||
# Convert first page to numpy array
|
||||
pil_img = images[0]
|
||||
print(f"[Worker {os.getpid()}] PDF decoded: {pil_img.width}x{pil_img.height} @ 200 DPI", flush=True)
|
||||
return np.array(pil_img)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"[Worker {os.getpid()}] PDF decode error: {e}", flush=True)
|
||||
|
||||
return None
|
||||
|
||||
@@ -270,83 +347,275 @@ def _process_tesseract_only(
|
||||
return extraction, raw_texts
|
||||
|
||||
|
||||
def _process_adaptive(
|
||||
def _process_doctr_only(
|
||||
image: np.ndarray,
|
||||
paddle_engine,
|
||||
tesseract_engine,
|
||||
doctr_engine,
|
||||
preprocessor,
|
||||
extractor
|
||||
) -> Tuple[Any, List[str]]:
|
||||
"""
|
||||
Adaptive multi-pass OCR processing.
|
||||
Process using docTR only (light + medium preprocessing).
|
||||
|
||||
Strategy:
|
||||
1. PaddleOCR Light - fastest, best for clear PDFs
|
||||
2. PaddleOCR Medium - if Light incomplete
|
||||
3. Tesseract - complement missing fields only
|
||||
|
||||
Returns:
|
||||
Tuple of (extraction_result, raw_texts_list)
|
||||
docTR uses EXACT same preprocessing as PaddleOCR for consistency.
|
||||
"""
|
||||
raw_texts = []
|
||||
extraction = None
|
||||
|
||||
# === STEP 1: PaddleOCR Light ===
|
||||
if paddle_engine:
|
||||
print("[OCR] Step 1: PaddleOCR + Light", flush=True)
|
||||
light_img = preprocessor.preprocess_light(image)
|
||||
paddle_light = _paddle_recognize(paddle_engine, light_img)
|
||||
if doctr_engine is None:
|
||||
return None, ["docTR not available"]
|
||||
|
||||
if paddle_light and paddle_light.text:
|
||||
extraction = extractor.extract(paddle_light.text)
|
||||
extraction.ocr_engine = "paddle-light"
|
||||
raw_texts.append(f"=== PaddleOCR Light (conf: {paddle_light.confidence:.0%}) ===\n{paddle_light.text}")
|
||||
# Step 1: Light preprocessing (same as PaddleOCR)
|
||||
print("[OCR] Step 1: docTR + Light", flush=True)
|
||||
light_img = preprocessor.preprocess_light(image)
|
||||
doctr_light = _doctr_recognize(doctr_engine, light_img)
|
||||
|
||||
if _is_extraction_complete(extraction):
|
||||
print("[OCR] Early exit - all fields found in Step 1", flush=True)
|
||||
return extraction, raw_texts
|
||||
if doctr_light and doctr_light.text:
|
||||
extraction = extractor.extract(doctr_light.text)
|
||||
extraction.ocr_engine = "doctr-light"
|
||||
raw_texts.append(f"=== docTR Light (conf: {doctr_light.confidence:.0%}) ===\n{doctr_light.text}")
|
||||
|
||||
# === STEP 2: PaddleOCR Medium ===
|
||||
if paddle_engine:
|
||||
print("[OCR] Step 2: PaddleOCR + Medium", flush=True)
|
||||
medium_img = preprocessor.preprocess_medium(image)
|
||||
paddle_medium = _paddle_recognize(paddle_engine, medium_img)
|
||||
if _is_extraction_complete(extraction):
|
||||
return extraction, raw_texts
|
||||
|
||||
if paddle_medium and paddle_medium.text:
|
||||
extraction_medium = extractor.extract(paddle_medium.text)
|
||||
extraction_medium.ocr_engine = "paddle-medium"
|
||||
raw_texts.append(f"=== PaddleOCR Medium (conf: {paddle_medium.confidence:.0%}) ===\n{paddle_medium.text}")
|
||||
# Step 2: Medium preprocessing (same as PaddleOCR)
|
||||
print("[OCR] Step 2: docTR + Medium", flush=True)
|
||||
medium_img = preprocessor.preprocess_medium(image)
|
||||
doctr_medium = _doctr_recognize(doctr_engine, medium_img)
|
||||
|
||||
if extraction:
|
||||
extraction = _merge_extractions(extraction, extraction_medium)
|
||||
extraction.ocr_engine = "paddle-adaptive"
|
||||
else:
|
||||
extraction = extraction_medium
|
||||
if doctr_medium and doctr_medium.text:
|
||||
extraction_medium = extractor.extract(doctr_medium.text)
|
||||
extraction_medium.ocr_engine = "doctr-medium"
|
||||
raw_texts.append(f"=== docTR Medium (conf: {doctr_medium.confidence:.0%}) ===\n{doctr_medium.text}")
|
||||
|
||||
if _is_extraction_complete(extraction):
|
||||
print("[OCR] Early exit - all fields found after Step 2", flush=True)
|
||||
return extraction, raw_texts
|
||||
|
||||
# === STEP 3: Tesseract (complement only) ===
|
||||
if tesseract_engine:
|
||||
print("[OCR] Step 3: Tesseract complement", flush=True)
|
||||
tesseract_img = preprocessor.preprocess_for_tesseract(image)
|
||||
tesseract_result = tesseract_engine.recognize(tesseract_img)
|
||||
|
||||
if tesseract_result and tesseract_result.text:
|
||||
extraction_tess = extractor.extract(tesseract_result.text)
|
||||
extraction_tess.ocr_engine = "tesseract"
|
||||
raw_texts.append(f"=== Tesseract (conf: {tesseract_result.confidence:.0%}) ===\n{tesseract_result.text}")
|
||||
|
||||
if extraction:
|
||||
extraction = _complement_extraction(extraction, extraction_tess)
|
||||
extraction.ocr_engine = "adaptive-full"
|
||||
else:
|
||||
extraction = extraction_tess
|
||||
if extraction:
|
||||
extraction = _merge_extractions(extraction, extraction_medium)
|
||||
extraction.ocr_engine = "doctr-adaptive"
|
||||
else:
|
||||
extraction = extraction_medium
|
||||
|
||||
return extraction, raw_texts
|
||||
|
||||
|
||||
def _process_doctr_plus(
|
||||
image: np.ndarray,
|
||||
doctr_engine,
|
||||
preprocessor,
|
||||
extractor
|
||||
) -> Tuple[Any, List[str]]:
|
||||
"""
|
||||
docTR Plus - Optimized 2-tier sequential processing with early exit.
|
||||
|
||||
Architecture:
|
||||
- Tier 1: Light preprocessing (~4-5s)
|
||||
→ Early exit if confidence >= 0.75 AND all fields valid AND cross-validations pass
|
||||
- Tier 2: Medium preprocessing (only if Tier 1 insufficient, ~4-5s additional)
|
||||
→ Merge with Tier 1 results
|
||||
→ Mark for review if still problems
|
||||
|
||||
Performance:
|
||||
- Fast path (80% receipts): ~4-5s (Tier 1 only)
|
||||
- Slow path (20% receipts): ~8-9s (Tier 1 + Tier 2)
|
||||
- Average: ~5-6s
|
||||
|
||||
Returns:
|
||||
Tuple of (extraction_result, raw_texts_list)
|
||||
extraction_result.needs_review = True if validation issues remain
|
||||
"""
|
||||
raw_texts = []
|
||||
extraction = None
|
||||
|
||||
if doctr_engine is None:
|
||||
return None, ["docTR not available"]
|
||||
|
||||
# ========== TIER 1: Light Preprocessing ==========
|
||||
print("[docTR+] TIER 1: Light preprocessing", flush=True)
|
||||
import time
|
||||
tier1_start = time.time()
|
||||
|
||||
light_img = preprocessor.preprocess_light(image)
|
||||
doctr_light = _doctr_recognize(doctr_engine, light_img)
|
||||
|
||||
tier1_time = time.time() - tier1_start
|
||||
print(f"[docTR+] TIER 1 completed in {tier1_time:.1f}s", flush=True)
|
||||
|
||||
if doctr_light and doctr_light.text:
|
||||
extraction = extractor.extract(doctr_light.text)
|
||||
extraction.ocr_engine = "doctr-plus-light"
|
||||
raw_texts.append(f"=== docTR+ Tier1/Light (conf: {doctr_light.confidence:.0%}) ===\n{doctr_light.text}")
|
||||
|
||||
# Early Exit Check: confidence >= 0.75 + cross-validations
|
||||
if _is_extraction_valid_for_early_exit(extraction, min_confidence=0.75):
|
||||
print(f"[docTR+] EARLY EXIT - Tier 1 sufficient (conf: {extraction.overall_confidence:.0%})", flush=True)
|
||||
extraction.ocr_engine = "doctr-plus"
|
||||
return extraction, raw_texts
|
||||
|
||||
print(f"[docTR+] Tier 1 incomplete or validation failed, proceeding to Tier 2...", flush=True)
|
||||
|
||||
# ========== TIER 2: Medium Preprocessing (only if needed) ==========
|
||||
print("[docTR+] TIER 2: Medium preprocessing", flush=True)
|
||||
tier2_start = time.time()
|
||||
|
||||
medium_img = preprocessor.preprocess_medium(image)
|
||||
doctr_medium = _doctr_recognize(doctr_engine, medium_img)
|
||||
|
||||
tier2_time = time.time() - tier2_start
|
||||
print(f"[docTR+] TIER 2 completed in {tier2_time:.1f}s", flush=True)
|
||||
|
||||
if doctr_medium and doctr_medium.text:
|
||||
extraction_medium = extractor.extract(doctr_medium.text)
|
||||
extraction_medium.ocr_engine = "doctr-plus-medium"
|
||||
raw_texts.append(f"=== docTR+ Tier2/Medium (conf: {doctr_medium.confidence:.0%}) ===\n{doctr_medium.text}")
|
||||
|
||||
if extraction:
|
||||
# Merge Tier 1 + Tier 2 results
|
||||
extraction = _merge_extractions(extraction, extraction_medium)
|
||||
else:
|
||||
extraction = extraction_medium
|
||||
|
||||
# ========== FINAL VALIDATION ==========
|
||||
if extraction:
|
||||
extraction.ocr_engine = "doctr-plus"
|
||||
|
||||
# Mark for review if validation still fails after both tiers
|
||||
passes_validation, penalty, errors = _quick_cross_validate(extraction)
|
||||
|
||||
if not passes_validation or extraction.overall_confidence < 0.75:
|
||||
# Mark for human review using existing fields
|
||||
extraction.needs_manual_review = True
|
||||
|
||||
if extraction.overall_confidence < 0.75:
|
||||
extraction.validation_warnings.append(f"Low confidence: {extraction.overall_confidence:.0%}")
|
||||
|
||||
if not extraction.amount:
|
||||
extraction.validation_errors.append("TOTAL not detected")
|
||||
if not extraction.cui:
|
||||
extraction.validation_warnings.append("CUI not detected")
|
||||
if not extraction.tva_total and not extraction.tva_entries:
|
||||
extraction.validation_warnings.append("TVA not detected")
|
||||
if not extraction.receipt_date:
|
||||
extraction.validation_warnings.append("Date not detected")
|
||||
|
||||
# Add cross-validation errors
|
||||
extraction.validation_errors.extend(errors)
|
||||
|
||||
print(f"[docTR+] Marked for review: {extraction.validation_errors + extraction.validation_warnings}", flush=True)
|
||||
else:
|
||||
extraction.needs_manual_review = False
|
||||
|
||||
total_time = tier1_time + (tier2_time if 'tier2_time' in dir() else 0)
|
||||
print(f"[docTR+] Total processing time: {total_time:.1f}s", flush=True)
|
||||
|
||||
return extraction, raw_texts
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# VALIDATION HELPERS (used by doctr_plus for early exit decisions)
|
||||
# =============================================================================
|
||||
|
||||
def _quick_cross_validate(extraction) -> tuple[bool, float, list[str]]:
|
||||
"""
|
||||
Quick cross-validation for OCR results.
|
||||
|
||||
Checks critical field correlations to detect obvious OCR errors.
|
||||
Used by doctr_plus to decide whether to proceed to Tier 2 or exit early.
|
||||
|
||||
Returns:
|
||||
Tuple of (passes_validation, confidence_penalty, error_messages)
|
||||
"""
|
||||
try:
|
||||
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
|
||||
|
||||
if extraction is None:
|
||||
return False, 1.0, ["No extraction result"]
|
||||
|
||||
# Convert extraction to dict for validation
|
||||
# Build TVA entries dict for TVAEntriesSumRule (expects {code: amount})
|
||||
tva_entries_dict = {}
|
||||
if extraction.tva_entries:
|
||||
for entry in extraction.tva_entries:
|
||||
if isinstance(entry, dict):
|
||||
code = entry.get('code', 'A')
|
||||
amount = entry.get('amount', 0)
|
||||
try:
|
||||
tva_entries_dict[code] = float(amount)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
|
||||
validation_data = {
|
||||
"amount": float(extraction.amount) if extraction.amount else None,
|
||||
"tva": float(extraction.tva_total) if extraction.tva_total else None,
|
||||
"tva_entries": tva_entries_dict, # For TVAEntriesSumRule: {code: amount}
|
||||
"cui": extraction.cui, # For CUI checksum validation
|
||||
}
|
||||
|
||||
# Also pass raw tva_entries for TVABasedTotalRule (for rate detection)
|
||||
if extraction.tva_entries:
|
||||
validation_data['tva_entries_raw'] = extraction.tva_entries
|
||||
|
||||
# Add payment methods if available (for TOTAL vs CARD+CASH validation)
|
||||
if extraction.payment_methods:
|
||||
try:
|
||||
card_amount = sum(
|
||||
float(p.get('amount', 0) if isinstance(p, dict) else 0)
|
||||
for p in extraction.payment_methods
|
||||
if isinstance(p, dict) and p.get('method') == 'CARD'
|
||||
)
|
||||
cash_amount = sum(
|
||||
float(p.get('amount', 0) if isinstance(p, dict) else 0)
|
||||
for p in extraction.payment_methods
|
||||
if isinstance(p, dict) and p.get('method') == 'NUMERAR'
|
||||
)
|
||||
validation_data['card_amount'] = card_amount
|
||||
validation_data['cash_amount'] = cash_amount
|
||||
except Exception as e:
|
||||
print(f"[Worker {os.getpid()}] Payment method validation error: {e}", flush=True)
|
||||
|
||||
# Run quick validation
|
||||
validator = OCRValidationEngine()
|
||||
return validator.quick_validate_for_hybrid(validation_data)
|
||||
|
||||
except Exception as e:
|
||||
# Never crash the process on validation errors
|
||||
print(f"[Worker {os.getpid()}] Cross-validation error: {e}", flush=True)
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
# Return "passes" to allow processing to continue
|
||||
return True, 0.0, [f"Validation skipped due to error: {str(e)}"]
|
||||
|
||||
|
||||
def _is_extraction_valid_for_early_exit(extraction, min_confidence: float = 0.85) -> bool:
|
||||
"""
|
||||
Check if extraction is valid for early exit in doctr_plus.
|
||||
|
||||
Combines confidence check with cross-validation to prevent
|
||||
early exit on OCR errors (e.g., wrong TOTAL but correct TVA).
|
||||
|
||||
Returns:
|
||||
True only if:
|
||||
1. Overall confidence >= min_confidence
|
||||
2. Critical fields are present (AMOUNT, DATE, CUI)
|
||||
3. Cross-validation passes (TOTAL matches TVA calculation, or no TVA)
|
||||
"""
|
||||
try:
|
||||
# First check basic completeness (relaxed for early exit)
|
||||
if not _is_extraction_complete(extraction, min_confidence, for_early_exit=True):
|
||||
return False
|
||||
|
||||
# Then run cross-validation
|
||||
passes_validation, penalty, errors = _quick_cross_validate(extraction)
|
||||
|
||||
if not passes_validation:
|
||||
print(f"[Early Exit] BLOCKED: cross-validation failed: {errors}", flush=True)
|
||||
return False
|
||||
|
||||
print(f"[Early Exit] OK: conf={extraction.overall_confidence:.0%}, validation passed", flush=True)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
# Never crash on validation - just continue to next engine
|
||||
print(f"[Worker {os.getpid()}] Early exit check error: {e}", flush=True)
|
||||
return False # Continue to next engine on error
|
||||
|
||||
def _paddle_recognize(paddle_engine, image: np.ndarray) -> Optional[OCRResult]:
|
||||
"""Run PaddleOCR recognition on image."""
|
||||
try:
|
||||
@@ -388,34 +657,191 @@ def _paddle_recognize(paddle_engine, image: np.ndarray) -> Optional[OCRResult]:
|
||||
return None
|
||||
|
||||
|
||||
def _is_extraction_complete(ext, min_confidence: float = 0.85) -> bool:
|
||||
"""Check if extraction has all required fields."""
|
||||
def _doctr_recognize(doctr_engine, image: np.ndarray) -> Optional[OCRResult]:
|
||||
"""
|
||||
Run docTR recognition on image.
|
||||
|
||||
docTR requires RGB images, handles conversion automatically.
|
||||
Uses same preprocessing as PaddleOCR for consistent results.
|
||||
"""
|
||||
try:
|
||||
# docTR requires RGB images
|
||||
if len(image.shape) == 2:
|
||||
# Convert grayscale to RGB
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
||||
elif image.shape[2] == 3:
|
||||
# Convert BGR (OpenCV) to RGB
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
elif image.shape[2] == 4:
|
||||
# Convert RGBA to RGB
|
||||
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
|
||||
|
||||
# docTR expects a list of numpy arrays (pages)
|
||||
result = doctr_engine([image])
|
||||
|
||||
if not result or not result.pages:
|
||||
return OCRResult(text="", confidence=0.0, boxes=[], engine="doctr")
|
||||
|
||||
# Extract text from all pages
|
||||
all_texts = []
|
||||
all_confidences = []
|
||||
boxes = []
|
||||
|
||||
for page in result.pages:
|
||||
for block in page.blocks:
|
||||
for line in block.lines:
|
||||
line_text = ' '.join(word.value for word in line.words)
|
||||
line_confidence = sum(w.confidence for w in line.words) / len(line.words) if line.words else 0.0
|
||||
all_texts.append(line_text)
|
||||
all_confidences.append(line_confidence)
|
||||
|
||||
# Store word-level boxes
|
||||
for word in line.words:
|
||||
boxes.append({
|
||||
'text': word.value,
|
||||
'confidence': float(word.confidence),
|
||||
'box': word.geometry # (xmin, ymin), (xmax, ymax)
|
||||
})
|
||||
|
||||
text_result = '\n'.join(all_texts)
|
||||
avg_conf = sum(all_confidences) / len(all_confidences) if all_confidences else 0.0
|
||||
|
||||
return OCRResult(
|
||||
text=text_result,
|
||||
confidence=float(avg_conf),
|
||||
boxes=boxes,
|
||||
engine="doctr"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Worker] docTR error: {e}", flush=True)
|
||||
return None
|
||||
|
||||
|
||||
def _is_extraction_complete(ext, min_confidence: float = 0.85, for_early_exit: bool = False) -> bool:
|
||||
"""
|
||||
Check if extraction has required fields.
|
||||
|
||||
Args:
|
||||
ext: Extraction result
|
||||
min_confidence: Minimum overall confidence
|
||||
for_early_exit: If True, use relaxed criteria (AMOUNT + DATE + CUI required)
|
||||
If False, require all fields (strict mode for final validation)
|
||||
|
||||
Returns:
|
||||
True if extraction meets completeness criteria
|
||||
"""
|
||||
# Check confidence first
|
||||
if ext.overall_confidence < min_confidence:
|
||||
if for_early_exit:
|
||||
print(f"[Early Exit] BLOCKED: confidence {ext.overall_confidence:.0%} < {min_confidence:.0%}", flush=True)
|
||||
return False
|
||||
|
||||
has_number = bool(ext.receipt_number)
|
||||
has_date = bool(ext.receipt_date)
|
||||
has_amount = bool(ext.amount)
|
||||
has_tva = bool(ext.tva_total) or bool(ext.tva_entries)
|
||||
has_cui = bool(ext.cui)
|
||||
|
||||
return all([has_number, has_date, has_amount, has_tva, has_cui])
|
||||
if for_early_exit:
|
||||
# Relaxed criteria for early exit:
|
||||
# - AMOUNT is required (core field)
|
||||
# - DATE is required (needed for accounting)
|
||||
# - CUI is required (needed for supplier identification)
|
||||
# - TVA is NOT required (some receipts have 0% TVA)
|
||||
# - receipt_number is NOT required (often missing)
|
||||
required_ok = all([has_amount, has_date, has_cui])
|
||||
|
||||
if not required_ok:
|
||||
missing = []
|
||||
if not has_amount: missing.append("AMOUNT")
|
||||
if not has_date: missing.append("DATE")
|
||||
if not has_cui: missing.append("CUI")
|
||||
print(f"[Early Exit] BLOCKED: missing required fields: {', '.join(missing)}", flush=True)
|
||||
|
||||
return required_ok
|
||||
else:
|
||||
# Strict criteria for final validation (all fields required)
|
||||
has_number = bool(ext.receipt_number)
|
||||
return all([has_number, has_date, has_amount, has_tva, has_cui])
|
||||
|
||||
|
||||
def _merge_extractions(primary, secondary):
|
||||
"""Merge two extractions, picking best fields from each."""
|
||||
"""Merge two extractions, picking best fields from each.
|
||||
|
||||
Primary should be the higher-quality engine (e.g., docTR).
|
||||
Secondary is the fallback engine (e.g., Tesseract).
|
||||
|
||||
Priority logic:
|
||||
- AMOUNT: TVA validation wins over confidence. If both valid or both invalid,
|
||||
uses confidence (or TVA diff for invalid cases).
|
||||
- DATE/CUI: Validation-based, then confidence, then primary wins ties.
|
||||
- OTHER FIELDS: Primary wins when both have values.
|
||||
"""
|
||||
from backend.modules.data_entry.services.ocr_extractor import ExtractionResult
|
||||
|
||||
result = ExtractionResult()
|
||||
|
||||
# Amount - prefer higher confidence
|
||||
# Helper: Check if amount matches TVA calculation
|
||||
def amount_passes_tva_validation(amount, tva_total, tva_entries):
|
||||
if not amount or not tva_total:
|
||||
return False, 0.0
|
||||
try:
|
||||
tva_rate = 0.21 # Default Romanian TVA
|
||||
if tva_entries:
|
||||
for entry in tva_entries:
|
||||
if isinstance(entry, dict) and entry.get('percent'):
|
||||
tva_rate = float(entry['percent']) / 100.0
|
||||
break
|
||||
# Expected TOTAL = TVA / rate * (1 + rate)
|
||||
expected = float(tva_total) * (1 + tva_rate) / tva_rate
|
||||
actual = float(amount)
|
||||
diff_percent = abs(actual - expected) / expected if expected > 0 else 1.0
|
||||
return diff_percent < 0.03, diff_percent # 3% tolerance
|
||||
except:
|
||||
return False, 1.0
|
||||
|
||||
# Amount - prefer TVA-validated value over confidence
|
||||
if primary.amount and secondary.amount:
|
||||
if primary.confidence_amount >= secondary.confidence_amount:
|
||||
result.amount = primary.amount
|
||||
result.confidence_amount = primary.confidence_amount
|
||||
else:
|
||||
# Get TVA from the one with entries, or use any available
|
||||
tva_total = primary.tva_total or secondary.tva_total
|
||||
tva_entries = primary.tva_entries or secondary.tva_entries
|
||||
|
||||
primary_valid, primary_diff = amount_passes_tva_validation(
|
||||
primary.amount, tva_total, tva_entries
|
||||
)
|
||||
secondary_valid, secondary_diff = amount_passes_tva_validation(
|
||||
secondary.amount, tva_total, tva_entries
|
||||
)
|
||||
|
||||
print(f"[Merge] Amount comparison: primary={primary.amount} (valid={primary_valid}, diff={primary_diff:.1%}), "
|
||||
f"secondary={secondary.amount} (valid={secondary_valid}, diff={secondary_diff:.1%})", flush=True)
|
||||
|
||||
if secondary_valid and not primary_valid:
|
||||
# Secondary passes validation, primary doesn't - use secondary!
|
||||
print(f"[Merge] Using secondary amount {secondary.amount} (passes TVA validation)", flush=True)
|
||||
result.amount = secondary.amount
|
||||
result.confidence_amount = secondary.confidence_amount
|
||||
elif primary_valid and not secondary_valid:
|
||||
# Primary passes validation
|
||||
result.amount = primary.amount
|
||||
result.confidence_amount = primary.confidence_amount
|
||||
elif primary_valid and secondary_valid:
|
||||
# Both valid - use higher confidence
|
||||
if primary.confidence_amount >= secondary.confidence_amount:
|
||||
result.amount = primary.amount
|
||||
result.confidence_amount = primary.confidence_amount
|
||||
else:
|
||||
result.amount = secondary.amount
|
||||
result.confidence_amount = secondary.confidence_amount
|
||||
else:
|
||||
# Neither valid - use the one closer to TVA calculation
|
||||
if secondary_diff < primary_diff:
|
||||
print(f"[Merge] Neither valid, using secondary {secondary.amount} (closer to TVA)", flush=True)
|
||||
result.amount = secondary.amount
|
||||
result.confidence_amount = secondary.confidence_amount
|
||||
else:
|
||||
result.amount = primary.amount
|
||||
result.confidence_amount = primary.confidence_amount
|
||||
elif primary.amount:
|
||||
result.amount = primary.amount
|
||||
result.confidence_amount = primary.confidence_amount
|
||||
@@ -438,13 +864,15 @@ def _merge_extractions(primary, secondary):
|
||||
result.receipt_date = secondary.receipt_date
|
||||
result.confidence_date = secondary.confidence_date
|
||||
|
||||
# CUI - prefer valid format
|
||||
# CUI - prefer valid format and version with RO prefix
|
||||
# Use CUIChecksumRule static methods (single source of truth)
|
||||
from backend.modules.data_entry.services.ocr.validation import CUIChecksumRule
|
||||
|
||||
def is_valid_cui(cui):
|
||||
if not cui:
|
||||
return False
|
||||
import re
|
||||
cui_clean = re.sub(r'^RO', '', cui.upper())
|
||||
return bool(re.match(r'^\d{6,10}$', cui_clean))
|
||||
digits = CUIChecksumRule.extract_digits(cui)
|
||||
return len(digits) >= 6 and len(digits) <= 10
|
||||
|
||||
if primary.cui and secondary.cui:
|
||||
if is_valid_cui(primary.cui) and not is_valid_cui(secondary.cui):
|
||||
@@ -452,22 +880,27 @@ def _merge_extractions(primary, secondary):
|
||||
elif is_valid_cui(secondary.cui) and not is_valid_cui(primary.cui):
|
||||
result.cui = secondary.cui
|
||||
else:
|
||||
result.cui = primary.cui
|
||||
# Both valid - prefer the one with RO prefix if digits match
|
||||
primary_digits = CUIChecksumRule.extract_digits(primary.cui)
|
||||
secondary_digits = CUIChecksumRule.extract_digits(secondary.cui)
|
||||
if primary_digits == secondary_digits:
|
||||
if CUIChecksumRule.has_ro_prefix(secondary.cui) and not CUIChecksumRule.has_ro_prefix(primary.cui):
|
||||
result.cui = secondary.cui # Prefer version with RO
|
||||
print(f"[CUI Complement] Preferring secondary with RO: {secondary.cui}", flush=True)
|
||||
else:
|
||||
result.cui = primary.cui
|
||||
else:
|
||||
result.cui = primary.cui
|
||||
elif primary.cui:
|
||||
result.cui = primary.cui
|
||||
elif secondary.cui:
|
||||
result.cui = secondary.cui
|
||||
|
||||
# TVA entries
|
||||
# TVA entries - ALWAYS prefer primary (docTR) when both have entries
|
||||
if primary.tva_entries and secondary.tva_entries:
|
||||
primary_total = sum(e.get('amount', Decimal('0')) for e in primary.tva_entries)
|
||||
secondary_total = sum(e.get('amount', Decimal('0')) for e in secondary.tva_entries)
|
||||
if primary_total >= secondary_total:
|
||||
result.tva_entries = primary.tva_entries
|
||||
result.tva_total = primary.tva_total
|
||||
else:
|
||||
result.tva_entries = secondary.tva_entries
|
||||
result.tva_total = secondary.tva_total
|
||||
# Always use primary (docTR) - higher quality OCR
|
||||
result.tva_entries = primary.tva_entries
|
||||
result.tva_total = primary.tva_total
|
||||
elif primary.tva_entries:
|
||||
result.tva_entries = primary.tva_entries
|
||||
result.tva_total = primary.tva_total
|
||||
@@ -483,12 +916,36 @@ def _merge_extractions(primary, secondary):
|
||||
result.address = primary.address or secondary.address
|
||||
result.items_count = primary.items_count or secondary.items_count
|
||||
result.payment_methods = primary.payment_methods or secondary.payment_methods
|
||||
result.suggested_payment_mode = getattr(primary, 'suggested_payment_mode', None) or getattr(secondary, 'suggested_payment_mode', None)
|
||||
|
||||
# Client fields
|
||||
result.client_name = primary.client_name or secondary.client_name
|
||||
result.client_cui = primary.client_cui or secondary.client_cui
|
||||
result.client_address = primary.client_address or secondary.client_address
|
||||
|
||||
# Confidence fields - preserve from primary or pick best
|
||||
if primary.confidence_vendor >= secondary.confidence_vendor:
|
||||
result.confidence_vendor = primary.confidence_vendor
|
||||
else:
|
||||
result.confidence_vendor = secondary.confidence_vendor
|
||||
|
||||
if hasattr(primary, 'confidence_client') and hasattr(secondary, 'confidence_client'):
|
||||
if primary.confidence_client >= secondary.confidence_client:
|
||||
result.confidence_client = primary.confidence_client
|
||||
else:
|
||||
result.confidence_client = secondary.confidence_client
|
||||
|
||||
# Raw text - combine both for debugging/display
|
||||
raw_texts = []
|
||||
if primary.raw_text:
|
||||
raw_texts.append(primary.raw_text)
|
||||
if secondary.raw_text and secondary.raw_text != primary.raw_text:
|
||||
raw_texts.append(secondary.raw_text)
|
||||
result.raw_text = '\n---\n'.join(raw_texts) if raw_texts else ''
|
||||
|
||||
# Note: overall_confidence is a computed @property on ExtractionResult
|
||||
# It automatically calculates from confidence_amount, confidence_date, confidence_vendor
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -557,6 +1014,7 @@ def _extraction_to_dict(extraction) -> dict:
|
||||
"address": extraction.address,
|
||||
"items_count": extraction.items_count,
|
||||
"payment_methods": extraction.payment_methods,
|
||||
"suggested_payment_mode": getattr(extraction, 'suggested_payment_mode', None),
|
||||
# Client data
|
||||
"client_name": extraction.client_name,
|
||||
"client_cui": extraction.client_cui,
|
||||
|
||||
@@ -385,8 +385,81 @@ class CUIChecksumRule(ValidationRule):
|
||||
|
||||
result = rule.validate({"cui": "R01879855"})
|
||||
# result.is_valid = False (checksum mismatch)
|
||||
|
||||
Static methods available for direct use:
|
||||
CUIChecksumRule.calculate_checksum("1056260") -> 0
|
||||
CUIChecksumRule.validate_checksum("10562600") -> True
|
||||
CUIChecksumRule.has_ro_prefix("RO10562600") -> True
|
||||
"""
|
||||
|
||||
# Fixed multipliers for 9 positions (Romanian Mod 11)
|
||||
MULTIPLIERS = [7, 5, 3, 2, 1, 7, 5, 3, 2]
|
||||
|
||||
@staticmethod
|
||||
def calculate_checksum(cui_base: str) -> int:
|
||||
"""Calculate expected CUI checksum using Romanian Mod 11 algorithm.
|
||||
|
||||
Args:
|
||||
cui_base: CUI digits WITHOUT the checksum digit (last digit)
|
||||
|
||||
Returns:
|
||||
Expected checksum digit (0-9), or -1 if invalid input
|
||||
"""
|
||||
if not cui_base or not cui_base.isdigit():
|
||||
return -1
|
||||
|
||||
# Pad base to 9 digits from LEFT
|
||||
base_padded = cui_base.zfill(9)
|
||||
base_digits = [int(d) for d in base_padded]
|
||||
|
||||
# Calculate weighted sum
|
||||
weighted_sum = sum(d * m for d, m in zip(base_digits, CUIChecksumRule.MULTIPLIERS))
|
||||
|
||||
# Calculate checksum
|
||||
checksum = (weighted_sum * 10) % 11
|
||||
if checksum == 10:
|
||||
checksum = 0
|
||||
|
||||
return checksum
|
||||
|
||||
@staticmethod
|
||||
def validate_checksum(cui_digits: str) -> bool:
|
||||
"""Check if CUI checksum is valid.
|
||||
|
||||
Args:
|
||||
cui_digits: Full CUI digits (including checksum as last digit)
|
||||
|
||||
Returns:
|
||||
True if checksum is valid, False otherwise
|
||||
"""
|
||||
if not cui_digits or len(cui_digits) < 6 or not cui_digits.isdigit():
|
||||
return False
|
||||
|
||||
base = cui_digits[:-1]
|
||||
declared = int(cui_digits[-1])
|
||||
expected = CUIChecksumRule.calculate_checksum(base)
|
||||
|
||||
return expected == declared
|
||||
|
||||
@staticmethod
|
||||
def has_ro_prefix(cui: str) -> bool:
|
||||
"""Check if CUI has RO prefix (proper format for VAT payers)."""
|
||||
if not cui:
|
||||
return False
|
||||
return cui.upper().strip().startswith('RO')
|
||||
|
||||
@staticmethod
|
||||
def extract_digits(cui: str) -> str:
|
||||
"""Extract digits from CUI, removing RO/R0 prefix."""
|
||||
if not cui:
|
||||
return ""
|
||||
cui = cui.strip().upper()
|
||||
if cui.startswith("RO"):
|
||||
cui = cui[2:]
|
||||
elif cui.startswith("R0"): # Fix OCR error R0 → RO
|
||||
cui = cui[2:]
|
||||
return ''.join(c for c in cui if c.isdigit())
|
||||
|
||||
@property
|
||||
def rule_name(self) -> str:
|
||||
return "CUI Checksum Check (Mod 11)"
|
||||
@@ -400,15 +473,11 @@ class CUIChecksumRule(ValidationRule):
|
||||
message="No CUI to validate"
|
||||
)
|
||||
|
||||
# Normalize: remove RO/R0 prefix
|
||||
cui_clean = cui.strip().upper()
|
||||
if cui_clean.startswith("RO"):
|
||||
cui_clean = cui_clean[2:]
|
||||
elif cui_clean.startswith("R0"):
|
||||
cui_clean = cui_clean[2:]
|
||||
# Use static method to extract digits
|
||||
cui_clean = CUIChecksumRule.extract_digits(cui)
|
||||
|
||||
# Check format first
|
||||
if not cui_clean.isdigit():
|
||||
if not cui_clean:
|
||||
return ValidationResult(
|
||||
is_valid=True, # Don't fail checksum if format invalid (handled by CUIFormatRule)
|
||||
message="CUI format invalid, skipping checksum"
|
||||
@@ -420,28 +489,15 @@ class CUIChecksumRule(ValidationRule):
|
||||
message="CUI length invalid, skipping checksum"
|
||||
)
|
||||
|
||||
# Extract digits
|
||||
digits = [int(d) for d in cui_clean]
|
||||
checksum_declared = digits[-1]
|
||||
base_digits = digits[:-1]
|
||||
|
||||
# Multipliers (trim to match base_digits length)
|
||||
multipliers = [7, 5, 3, 2, 1, 7, 5, 3, 2]
|
||||
multipliers = multipliers[:len(base_digits)]
|
||||
|
||||
# Calculate weighted sum
|
||||
weighted_sum = sum(d * m for d, m in zip(base_digits, multipliers))
|
||||
|
||||
# Calculate expected checksum
|
||||
checksum_calculated = (weighted_sum * 10) % 11
|
||||
if checksum_calculated == 10:
|
||||
checksum_calculated = 0
|
||||
|
||||
if checksum_calculated != checksum_declared:
|
||||
# Use static method to validate checksum
|
||||
if not CUIChecksumRule.validate_checksum(cui_clean):
|
||||
# Calculate expected for error message
|
||||
expected = CUIChecksumRule.calculate_checksum(cui_clean[:-1])
|
||||
declared = int(cui_clean[-1])
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
confidence_penalty=0.3,
|
||||
message=f"CUI '{cui}' checksum mismatch: expected {checksum_calculated}, got {checksum_declared}",
|
||||
message=f"CUI '{cui}' checksum mismatch: expected {expected}, got {declared}",
|
||||
severity="warning"
|
||||
)
|
||||
|
||||
@@ -451,6 +507,129 @@ class CUIChecksumRule(ValidationRule):
|
||||
)
|
||||
|
||||
|
||||
class TVABasedTotalRule(ValidationRule):
|
||||
"""Validate TOTAL using reverse calculation from TVA amount.
|
||||
|
||||
This is a CRITICAL validation that catches cases where OCR extracts
|
||||
wrong TOTAL but correct TVA. Since TVA = BASE * rate and TOTAL = BASE + TVA,
|
||||
we can calculate expected TOTAL from TVA alone.
|
||||
|
||||
Formula:
|
||||
Expected TOTAL = TVA / rate * (1 + rate)
|
||||
Or equivalently: Expected TOTAL = TVA * (1 + rate) / rate
|
||||
|
||||
For TVA rate 21%:
|
||||
Expected TOTAL = TVA / 0.21 * 1.21 = TVA * 5.7619
|
||||
|
||||
Example (benzina 27 oct):
|
||||
TVA = 49.58, rate = 21%
|
||||
Expected TOTAL = 49.58 / 0.21 * 1.21 = 285.68
|
||||
Extracted TOTAL = 205.66 (WRONG!)
|
||||
Rule detects mismatch and flags for escalation
|
||||
|
||||
Usage in multi-tier processing (e.g., doctr_plus):
|
||||
If this rule fails, the engine should proceed to next tier
|
||||
instead of returning early with potentially wrong data.
|
||||
"""
|
||||
|
||||
def __init__(self, tolerance_percent: float = 0.02):
|
||||
"""
|
||||
Args:
|
||||
tolerance_percent: Allowed difference as percentage (0.02 = 2%)
|
||||
"""
|
||||
self.tolerance_percent = tolerance_percent
|
||||
|
||||
@property
|
||||
def rule_name(self) -> str:
|
||||
return "TVA-Based Total Check"
|
||||
|
||||
def validate(self, data: dict[str, Any]) -> ValidationResult:
|
||||
total = data.get("amount")
|
||||
tva = data.get("tva")
|
||||
tva_entries = data.get("tva_entries", [])
|
||||
|
||||
if not total or not tva:
|
||||
return ValidationResult(
|
||||
is_valid=True,
|
||||
message="Insufficient data for TVA-based total validation"
|
||||
)
|
||||
|
||||
# Type safety
|
||||
try:
|
||||
total = float(total)
|
||||
tva = float(tva)
|
||||
except (TypeError, ValueError):
|
||||
return ValidationResult(
|
||||
is_valid=True,
|
||||
message="Non-numeric values, skipping TVA-based total validation"
|
||||
)
|
||||
|
||||
if tva <= 0 or total <= 0:
|
||||
return ValidationResult(
|
||||
is_valid=True,
|
||||
message="Zero or negative values, skipping TVA-based total validation"
|
||||
)
|
||||
|
||||
# Try to determine TVA rate from entries
|
||||
tva_rate = None
|
||||
|
||||
# Check tva_entries for rate information
|
||||
if tva_entries:
|
||||
for entry in tva_entries:
|
||||
if isinstance(entry, dict):
|
||||
percent = entry.get('percent')
|
||||
if percent:
|
||||
try:
|
||||
tva_rate = float(percent) / 100.0
|
||||
break
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
|
||||
# Fallback: try to calculate rate from TVA/TOTAL ratio
|
||||
if not tva_rate:
|
||||
# TVA = BASE * rate, TOTAL = BASE + TVA = BASE * (1 + rate)
|
||||
# TVA/TOTAL = rate / (1 + rate)
|
||||
# So rate = TVA / (TOTAL - TVA) = TVA / BASE
|
||||
base = total - tva
|
||||
if base > 0:
|
||||
calculated_rate = tva / base
|
||||
# Validate it's a reasonable Romanian TVA rate (5%, 9%, 19%, 21%)
|
||||
if 0.04 <= calculated_rate <= 0.25:
|
||||
tva_rate = calculated_rate
|
||||
|
||||
if not tva_rate:
|
||||
# Assume most common rate: 21%
|
||||
tva_rate = 0.21
|
||||
|
||||
# Calculate expected TOTAL from TVA
|
||||
# TVA = BASE * rate → BASE = TVA / rate
|
||||
# TOTAL = BASE + TVA = (TVA / rate) + TVA = TVA * (1 + 1/rate) = TVA * (1 + rate) / rate
|
||||
expected_total = tva * (1 + tva_rate) / tva_rate
|
||||
|
||||
# Calculate difference
|
||||
diff = abs(total - expected_total)
|
||||
diff_percent = diff / expected_total if expected_total > 0 else 1.0
|
||||
|
||||
if diff_percent > self.tolerance_percent:
|
||||
# Significant mismatch - OCR likely extracted TOTAL wrong
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
confidence_penalty=0.5, # High penalty - this is a critical error
|
||||
message=(
|
||||
f"TOTAL mismatch: Extracted {total:.2f} RON vs "
|
||||
f"TVA-calculated {expected_total:.2f} RON "
|
||||
f"(TVA={tva:.2f}, rate={tva_rate:.0%}, diff={diff_percent:.1%}). "
|
||||
f"Likely OCR error on TOTAL."
|
||||
),
|
||||
severity="error"
|
||||
)
|
||||
|
||||
return ValidationResult(
|
||||
is_valid=True,
|
||||
message=f"TOTAL {total:.2f} matches TVA-calculated {expected_total:.2f} (diff: {diff_percent:.1%})"
|
||||
)
|
||||
|
||||
|
||||
class InterOCRConsistencyRule(ValidationRule):
|
||||
"""Validate consistency between multiple OCR results.
|
||||
|
||||
@@ -562,6 +741,7 @@ class OCRValidationEngine:
|
||||
TVARatioRule(min_ratio=0.05, max_ratio=0.24),
|
||||
PaymentSumRule(tolerance=0.02),
|
||||
TVAEntriesSumRule(tolerance=0.02),
|
||||
TVABasedTotalRule(tolerance_percent=0.02), # Critical: detect TOTAL errors via TVA
|
||||
]
|
||||
|
||||
# Inter-OCR consistency rules
|
||||
@@ -699,39 +879,508 @@ class OCRValidationEngine:
|
||||
inter_ocr_ratios=inter_ocr_ratios
|
||||
)
|
||||
|
||||
def quick_validate_for_hybrid(self, extraction_result: dict[str, Any]) -> tuple[bool, float, list[str]]:
|
||||
"""Quick validation for early-exit decisions (e.g., doctr_plus Tier 1).
|
||||
|
||||
Runs critical cross-validation rules to detect obvious OCR errors.
|
||||
Used to decide whether to proceed to next processing tier or exit early.
|
||||
|
||||
Args:
|
||||
extraction_result: Extraction data dict with fields:
|
||||
- amount: Extracted TOTAL
|
||||
- tva: Extracted TVA total
|
||||
- tva_entries: List of TVA entries with rates
|
||||
|
||||
Returns:
|
||||
Tuple of (passes_validation, confidence_penalty, error_messages)
|
||||
- passes_validation: True if no critical errors detected
|
||||
- confidence_penalty: Cumulative penalty (0.0-1.0)
|
||||
- error_messages: List of validation error messages
|
||||
|
||||
Example usage:
|
||||
passes, penalty, errors = validation_engine.quick_validate_for_hybrid(extraction_data)
|
||||
if not passes:
|
||||
print(f"Validation failed: {errors}, proceeding to next tier")
|
||||
# Continue to next processing tier instead of early exit
|
||||
"""
|
||||
errors = []
|
||||
total_penalty = 0.0
|
||||
|
||||
# Critical rules for early-exit decision-making
|
||||
# These determine if we can trust the extraction or need to proceed to next tier
|
||||
critical_rules = [
|
||||
# Cross-field validations (most important for detecting OCR errors)
|
||||
TVABasedTotalRule(tolerance_percent=0.02), # Critical: detect TOTAL errors via TVA calculation
|
||||
PaymentSumRule(tolerance=0.05), # Cross-validate TOTAL vs CARD+CASH payments
|
||||
TVARatioRule(min_ratio=0.05, max_ratio=0.24), # TVA should be 5-24% of TOTAL
|
||||
TVAEntriesSumRule(tolerance=0.05), # Sum of TVA entries should match TVA total
|
||||
|
||||
# Format & checksum validations
|
||||
CUIChecksumRule(), # Validate CUI/CIF with Romanian Mod11 checksum algorithm
|
||||
CUIFormatRule(), # CUI should be 6-10 digits
|
||||
|
||||
# Sanity checks
|
||||
AmountRangeRule(min_amount=0.01, max_amount=100_000.0), # Reasonable amount range
|
||||
]
|
||||
|
||||
for rule in critical_rules:
|
||||
result = rule.validate(extraction_result)
|
||||
if not result.is_valid:
|
||||
errors.append(result.message)
|
||||
total_penalty += result.confidence_penalty
|
||||
|
||||
# Cap penalty at 1.0
|
||||
total_penalty = min(1.0, total_penalty)
|
||||
|
||||
passes = len(errors) == 0
|
||||
return passes, total_penalty, errors
|
||||
|
||||
# NOTE: _calculate_cui_checksum and _is_cui_checksum_valid removed
|
||||
# Use CUIChecksumRule.calculate_checksum() and CUIChecksumRule.validate_checksum() instead
|
||||
|
||||
@staticmethod
|
||||
def _repair_cui_checksum(cui_digits: str) -> Optional[str]:
|
||||
"""Try to repair CUI by attempting 1-digit corrections.
|
||||
|
||||
OCR often misreads similar-looking digits:
|
||||
- 5 ↔ 8 (most common in receipts)
|
||||
- 6 ↔ 0
|
||||
- 1 ↔ 7
|
||||
- 3 ↔ 8
|
||||
|
||||
Algorithm:
|
||||
1. Check middle positions first (2,3,4,5...) - OCR errors more common there
|
||||
2. Skip first digit (position 0) - usually reliable in CUI
|
||||
3. Check checksum digit (last position) last
|
||||
4. Prefer common OCR digit confusions (5↔8, 6↔0)
|
||||
|
||||
Args:
|
||||
cui_digits: Original CUI digits (without RO prefix)
|
||||
|
||||
Returns:
|
||||
Repaired CUI digits if 1-digit fix found, else None
|
||||
"""
|
||||
if len(cui_digits) < 6 or not cui_digits.isdigit():
|
||||
return None
|
||||
|
||||
# If already valid, return as-is
|
||||
if CUIChecksumRule.validate_checksum(cui_digits):
|
||||
return cui_digits
|
||||
|
||||
# Common OCR digit confusions (try these first)
|
||||
confusion_pairs = {
|
||||
'5': ['8', '6'], # 5 often misread as 8 or 6
|
||||
'8': ['5', '3', '0'], # 8 often misread as 5, 3, or 0
|
||||
'6': ['0', '8'], # 6 often misread as 0 or 8
|
||||
'0': ['6', '8'], # 0 often misread as 6 or 8
|
||||
'1': ['7', '4'], # 1 often misread as 7 or 4
|
||||
'7': ['1'], # 7 often misread as 1
|
||||
'3': ['8'], # 3 often misread as 8
|
||||
'4': ['1'], # 4 often misread as 1
|
||||
'2': ['7'], # 2 sometimes misread as 7
|
||||
'9': ['0'], # 9 sometimes misread as 0
|
||||
}
|
||||
|
||||
n = len(cui_digits)
|
||||
last_pos = n - 1 # checksum position
|
||||
|
||||
# Position check order: middle positions first, then position 1, then 0, then checksum
|
||||
# Skip position 0 (first digit) - it's usually reliable
|
||||
# Example for 8-digit CUI: [2,3,4,5,6, 1, 7(checksum)]
|
||||
middle_positions = list(range(2, last_pos)) # positions 2 to n-2
|
||||
position_order = middle_positions + [1, last_pos, 0] # check pos 0 last (rarely wrong)
|
||||
|
||||
for pos in position_order:
|
||||
if pos >= n:
|
||||
continue
|
||||
|
||||
original_digit = cui_digits[pos]
|
||||
|
||||
# Try common confusions first for this digit
|
||||
candidates = confusion_pairs.get(original_digit, [])
|
||||
# Then try all other digits
|
||||
all_digits = [d for d in '0123456789' if d != original_digit and d not in candidates]
|
||||
|
||||
for replacement in candidates + all_digits:
|
||||
candidate = cui_digits[:pos] + replacement + cui_digits[pos+1:]
|
||||
if CUIChecksumRule.validate_checksum(candidate):
|
||||
print(f"[CUI Repair] Fixed {cui_digits} → {candidate} (position {pos}: {original_digit}→{replacement})", flush=True)
|
||||
return candidate
|
||||
|
||||
# No single-digit fix found
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def normalize_cui(cui: Optional[str]) -> Optional[str]:
|
||||
"""Normalize CUI to RO prefix + digits format.
|
||||
"""Normalize CUI - fix OCR errors but preserve original format.
|
||||
|
||||
Rules:
|
||||
- R0 → RO (fix OCR error where O is read as 0)
|
||||
- Keep RO prefix if original had it (platitor TVA)
|
||||
- Do NOT add RO if original didn't have it (neplatitor TVA)
|
||||
- Try to repair 1-digit checksum errors (OCR mistakes like 5↔8)
|
||||
|
||||
Examples:
|
||||
10562600 → RO10562600
|
||||
45417955 → 45417955 (no prefix = neplatitor TVA, keep as-is)
|
||||
R010562600 → RO10562600 (fix R0 OCR error)
|
||||
RO10562600 → RO10562600 (unchanged)
|
||||
RO10862600 → RO10562600 (repaired: 8→5 at position 2)
|
||||
|
||||
Args:
|
||||
cui: Raw CUI string from OCR
|
||||
|
||||
Returns:
|
||||
Normalized CUI with RO prefix, or None if invalid
|
||||
Normalized CUI, or None if invalid
|
||||
"""
|
||||
if not cui:
|
||||
return None
|
||||
|
||||
cui = cui.strip().upper()
|
||||
|
||||
# Remove existing prefix if present
|
||||
# Check if original had RO/R0 prefix
|
||||
had_ro_prefix = cui.startswith("RO") or cui.startswith("R0")
|
||||
|
||||
# Extract digits
|
||||
if cui.startswith("RO"):
|
||||
cui = cui[2:]
|
||||
elif cui.startswith("R0"):
|
||||
cui = cui[2:]
|
||||
cui_digits = cui[2:]
|
||||
elif cui.startswith("R0"): # Fix OCR error R0 → RO
|
||||
cui_digits = cui[2:]
|
||||
else:
|
||||
cui_digits = cui
|
||||
|
||||
# Remove any non-digit characters
|
||||
cui_digits = ''.join(c for c in cui if c.isdigit())
|
||||
cui_digits = ''.join(c for c in cui_digits if c.isdigit())
|
||||
|
||||
# Validate length
|
||||
if len(cui_digits) < 6 or len(cui_digits) > 10:
|
||||
print(f"[CUI Normalize] Invalid length: {len(cui_digits)} digits (expected 6-10)", flush=True)
|
||||
return None
|
||||
|
||||
# Add RO prefix
|
||||
return f"RO{cui_digits}"
|
||||
# Try to repair checksum if invalid
|
||||
if not CUIChecksumRule.validate_checksum(cui_digits):
|
||||
repaired = OCRValidationEngine._repair_cui_checksum(cui_digits)
|
||||
if repaired:
|
||||
cui_digits = repaired
|
||||
|
||||
# Return with RO prefix only if original had it
|
||||
if had_ro_prefix:
|
||||
return f"RO{cui_digits}"
|
||||
else:
|
||||
return cui_digits
|
||||
|
||||
@staticmethod
|
||||
async def fuzzy_match_cui_from_db(
|
||||
cui: Optional[str],
|
||||
db_session
|
||||
) -> Optional[tuple[str, str]]:
|
||||
"""Fuzzy match CUI against database of known suppliers.
|
||||
|
||||
This function:
|
||||
1. Validates CUI checksum
|
||||
2. If valid, looks up in database (exact match)
|
||||
3. If invalid, tries 1-digit corrections and looks up each candidate
|
||||
4. Returns the first match found in database
|
||||
|
||||
Args:
|
||||
cui: Extracted CUI from OCR (may be invalid)
|
||||
db_session: SQLAlchemy async session for database lookups
|
||||
|
||||
Returns:
|
||||
Tuple of (corrected_cui, supplier_name) if found, else None
|
||||
|
||||
Usage in OCR extraction:
|
||||
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
|
||||
match = await OCRValidationEngine.fuzzy_match_cui_from_db(extracted_cui, session)
|
||||
if match:
|
||||
corrected_cui, supplier_name = match
|
||||
"""
|
||||
from sqlalchemy import select, or_
|
||||
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier
|
||||
|
||||
if not cui:
|
||||
return None
|
||||
|
||||
cui = cui.strip().upper()
|
||||
|
||||
# Check if original had RO/R0 prefix
|
||||
had_ro_prefix = cui.startswith("RO") or cui.startswith("R0")
|
||||
|
||||
# Extract digits
|
||||
if cui.startswith("RO"):
|
||||
cui_digits = cui[2:]
|
||||
elif cui.startswith("R0"): # Fix OCR error R0 → RO
|
||||
cui_digits = cui[2:]
|
||||
else:
|
||||
cui_digits = cui
|
||||
|
||||
# Remove any non-digit characters
|
||||
cui_digits = ''.join(c for c in cui_digits if c.isdigit())
|
||||
|
||||
# Validate length
|
||||
if len(cui_digits) < 6 or len(cui_digits) > 10:
|
||||
return None
|
||||
|
||||
# Helper to format CUI with optional RO prefix
|
||||
def format_cui(digits: str) -> str:
|
||||
if had_ro_prefix:
|
||||
return f"RO{digits}"
|
||||
return digits
|
||||
|
||||
# Helper to search database for CUI
|
||||
async def lookup_cui_in_db(digits: str) -> Optional[tuple[str, str]]:
|
||||
"""Search both synced and local suppliers for CUI."""
|
||||
# Search patterns: with and without RO prefix
|
||||
search_patterns = [digits, f"RO{digits}"]
|
||||
|
||||
# Search synced_suppliers first (more data)
|
||||
stmt = select(SyncedSupplier.fiscal_code, SyncedSupplier.name).where(
|
||||
or_(
|
||||
SyncedSupplier.fiscal_code == digits,
|
||||
SyncedSupplier.fiscal_code == f"RO{digits}",
|
||||
SyncedSupplier.fiscal_code == digits.lstrip('0'), # Handle leading zeros
|
||||
)
|
||||
).limit(1)
|
||||
result = await db_session.execute(stmt)
|
||||
row = result.first()
|
||||
if row:
|
||||
return (format_cui(digits), row.name)
|
||||
|
||||
# Search local_suppliers
|
||||
stmt = select(LocalSupplier.fiscal_code, LocalSupplier.name).where(
|
||||
or_(
|
||||
LocalSupplier.fiscal_code == digits,
|
||||
LocalSupplier.fiscal_code == f"RO{digits}",
|
||||
LocalSupplier.fiscal_code == digits.lstrip('0'),
|
||||
)
|
||||
).limit(1)
|
||||
result = await db_session.execute(stmt)
|
||||
row = result.first()
|
||||
if row:
|
||||
return (format_cui(digits), row.name)
|
||||
|
||||
return None
|
||||
|
||||
# 1. If checksum is valid, check if it exists in database (exact match)
|
||||
if CUIChecksumRule.validate_checksum(cui_digits):
|
||||
match = await lookup_cui_in_db(cui_digits)
|
||||
if match:
|
||||
print(f"[Fuzzy CUI] Exact match found: {cui} → {match[0]} ({match[1]})", flush=True)
|
||||
return match
|
||||
# Valid checksum but not in DB - return as-is (it might be a new supplier)
|
||||
return None
|
||||
|
||||
# 2. Invalid checksum - try 1-digit corrections and verify against database
|
||||
print(f"[Fuzzy CUI] Invalid checksum for {cui}, trying corrections...", flush=True)
|
||||
|
||||
# Common OCR digit confusions (try these first)
|
||||
confusion_pairs = {
|
||||
'5': ['8', '6'], # 5 often misread as 8 or 6
|
||||
'8': ['5', '3', '0'], # 8 often misread as 5, 3, or 0
|
||||
'6': ['0', '8'], # 6 often misread as 0 or 8
|
||||
'0': ['6', '8'], # 0 often misread as 6 or 8
|
||||
'1': ['7', '4'], # 1 often misread as 7 or 4
|
||||
'7': ['1'], # 7 often misread as 1
|
||||
'3': ['8'], # 3 often misread as 8
|
||||
'4': ['1'], # 4 often misread as 1
|
||||
'2': ['7'], # 2 sometimes misread as 7
|
||||
'9': ['0'], # 9 sometimes misread as 0
|
||||
}
|
||||
|
||||
n = len(cui_digits)
|
||||
last_pos = n - 1 # checksum position
|
||||
|
||||
# Position check order: middle positions first, then ends
|
||||
middle_positions = list(range(2, last_pos))
|
||||
position_order = middle_positions + [1, last_pos, 0]
|
||||
|
||||
for pos in position_order:
|
||||
if pos >= n:
|
||||
continue
|
||||
|
||||
original_digit = cui_digits[pos]
|
||||
|
||||
# Try common confusions first for this digit
|
||||
candidates = confusion_pairs.get(original_digit, [])
|
||||
# Then try all other digits
|
||||
all_digits = [d for d in '0123456789' if d != original_digit and d not in candidates]
|
||||
|
||||
for replacement in candidates + all_digits:
|
||||
candidate = cui_digits[:pos] + replacement + cui_digits[pos+1:]
|
||||
|
||||
# Only consider if checksum is valid
|
||||
if not CUIChecksumRule.validate_checksum(candidate):
|
||||
continue
|
||||
|
||||
# Check if this corrected CUI exists in database
|
||||
match = await lookup_cui_in_db(candidate)
|
||||
if match:
|
||||
print(f"[Fuzzy CUI] DB match: {cui} → {match[0]} ({match[1]}) [pos {pos}: {original_digit}→{replacement}]", flush=True)
|
||||
return match
|
||||
|
||||
# No match found in database
|
||||
print(f"[Fuzzy CUI] No database match found for {cui}", flush=True)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def fuzzy_match_by_name_and_cui(
|
||||
vendor_name: Optional[str],
|
||||
cui: Optional[str],
|
||||
db_session
|
||||
) -> Optional[tuple[str, str]]:
|
||||
"""Fuzzy match supplier by NAME, then narrow down by CUI.
|
||||
|
||||
Algorithm:
|
||||
1. Normalize vendor name (remove S.R.L., S.A., punctuation, etc.)
|
||||
2. Search suppliers by fuzzy name match (LIKE %name%)
|
||||
3. If multiple results, use fuzzy CUI matching to pick best one
|
||||
4. Return the best match
|
||||
|
||||
Args:
|
||||
vendor_name: Extracted vendor name from OCR
|
||||
cui: Extracted CUI from OCR (may be invalid/incomplete)
|
||||
db_session: SQLAlchemy async session
|
||||
|
||||
Returns:
|
||||
Tuple of (matched_cui, supplier_name) if found, else None
|
||||
"""
|
||||
from sqlalchemy import select, or_, func
|
||||
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier
|
||||
import re
|
||||
|
||||
if not vendor_name or len(vendor_name) < 3:
|
||||
return None
|
||||
|
||||
# Normalize vendor name for search
|
||||
def normalize_name(name: str) -> str:
|
||||
"""Normalize name for fuzzy matching."""
|
||||
name = name.upper()
|
||||
# Remove company type suffixes
|
||||
for suffix in ['S.R.L.', 'SRL', 'S.A.', 'SA', 'S.C.', 'SC', 'I.F.', 'IF', 'P.F.A.', 'PFA']:
|
||||
name = name.replace(suffix, '')
|
||||
# Remove punctuation and extra spaces
|
||||
name = re.sub(r'[.,\-_/\\()"\']', ' ', name)
|
||||
name = ' '.join(name.split())
|
||||
return name.strip()
|
||||
|
||||
# Extract key words from vendor name (for fuzzy search)
|
||||
normalized_name = normalize_name(vendor_name)
|
||||
name_words = [w for w in normalized_name.split() if len(w) >= 3]
|
||||
|
||||
if not name_words:
|
||||
return None
|
||||
|
||||
print(f"[Fuzzy Name] Searching for vendor: '{vendor_name}' → keywords: {name_words}", flush=True)
|
||||
|
||||
# Build search pattern - use first significant word
|
||||
primary_word = name_words[0]
|
||||
search_pattern = f"%{primary_word}%"
|
||||
|
||||
candidates = []
|
||||
|
||||
# Search synced_suppliers
|
||||
stmt = select(SyncedSupplier.fiscal_code, SyncedSupplier.name).where(
|
||||
func.upper(SyncedSupplier.name).like(search_pattern)
|
||||
).limit(20)
|
||||
result = await db_session.execute(stmt)
|
||||
for row in result:
|
||||
if row.fiscal_code:
|
||||
candidates.append((row.fiscal_code, row.name))
|
||||
|
||||
# Search local_suppliers
|
||||
stmt = select(LocalSupplier.fiscal_code, LocalSupplier.name).where(
|
||||
func.upper(LocalSupplier.name).like(search_pattern)
|
||||
).limit(20)
|
||||
result = await db_session.execute(stmt)
|
||||
for row in result:
|
||||
if row.fiscal_code:
|
||||
candidates.append((row.fiscal_code, row.name))
|
||||
|
||||
if not candidates:
|
||||
print(f"[Fuzzy Name] No name matches found for '{primary_word}'", flush=True)
|
||||
return None
|
||||
|
||||
print(f"[Fuzzy Name] Found {len(candidates)} name matches for '{primary_word}'", flush=True)
|
||||
|
||||
# If only one candidate, return it
|
||||
if len(candidates) == 1:
|
||||
print(f"[Fuzzy Name] Single match: {candidates[0][0]} ({candidates[0][1]})", flush=True)
|
||||
return candidates[0]
|
||||
|
||||
# Multiple candidates - try to narrow down by CUI
|
||||
if cui:
|
||||
cui_digits = ''.join(c for c in cui.upper().replace('RO', '').replace('R0', '') if c.isdigit())
|
||||
|
||||
if len(cui_digits) >= 6:
|
||||
# Score each candidate by how similar their CUI is to the extracted one
|
||||
def cui_similarity(candidate_cui: str) -> int:
|
||||
"""Calculate how many digits match in the same position."""
|
||||
cand_digits = ''.join(c for c in candidate_cui.upper().replace('RO', '') if c.isdigit())
|
||||
if len(cand_digits) != len(cui_digits):
|
||||
return 0
|
||||
return sum(1 for a, b in zip(cand_digits, cui_digits) if a == b)
|
||||
|
||||
# Sort candidates by CUI similarity (descending)
|
||||
scored = [(cui_similarity(c[0]), c) for c in candidates]
|
||||
scored.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
best_score, best_match = scored[0]
|
||||
# Require at least 70% digit match for CUI similarity
|
||||
min_matching = int(len(cui_digits) * 0.7)
|
||||
|
||||
if best_score >= min_matching:
|
||||
print(f"[Fuzzy Name] Best CUI match: {best_match[0]} ({best_match[1]}) - score {best_score}/{len(cui_digits)}", flush=True)
|
||||
return best_match
|
||||
|
||||
print(f"[Fuzzy Name] No strong CUI match (best score: {best_score}/{len(cui_digits)})", flush=True)
|
||||
|
||||
# If still multiple and no CUI match, try name similarity
|
||||
def name_similarity(candidate_name: str) -> int:
|
||||
"""Count how many keywords match."""
|
||||
norm_cand = normalize_name(candidate_name)
|
||||
return sum(1 for w in name_words if w in norm_cand)
|
||||
|
||||
scored = [(name_similarity(c[1]), c) for c in candidates]
|
||||
scored.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
if scored[0][0] >= 2: # At least 2 keywords match
|
||||
best_match = scored[0][1]
|
||||
print(f"[Fuzzy Name] Best name match: {best_match[0]} ({best_match[1]})", flush=True)
|
||||
return best_match
|
||||
|
||||
# Return first candidate if nothing else works
|
||||
print(f"[Fuzzy Name] Returning first candidate: {candidates[0][0]} ({candidates[0][1]})", flush=True)
|
||||
return candidates[0]
|
||||
|
||||
@staticmethod
|
||||
async def fuzzy_match_supplier(
|
||||
cui: Optional[str],
|
||||
vendor_name: Optional[str],
|
||||
db_session
|
||||
) -> Optional[tuple[str, str]]:
|
||||
"""Combined fuzzy matching: try CUI first, then fallback to NAME+CUI.
|
||||
|
||||
Strategy:
|
||||
1. Try fuzzy CUI matching (1-digit corrections with checksum validation)
|
||||
2. If no CUI match, try fuzzy NAME matching, narrowed by CUI similarity
|
||||
|
||||
Args:
|
||||
cui: Extracted CUI from OCR (may be invalid/incomplete)
|
||||
vendor_name: Extracted vendor name from OCR
|
||||
db_session: SQLAlchemy async session
|
||||
|
||||
Returns:
|
||||
Tuple of (matched_cui, supplier_name) if found, else None
|
||||
"""
|
||||
# Step 1: Try fuzzy CUI matching
|
||||
cui_match = await OCRValidationEngine.fuzzy_match_cui_from_db(cui, db_session)
|
||||
if cui_match:
|
||||
return cui_match
|
||||
|
||||
# Step 2: Fallback to fuzzy NAME + CUI matching
|
||||
name_match = await OCRValidationEngine.fuzzy_match_by_name_and_cui(
|
||||
vendor_name, cui, db_session
|
||||
)
|
||||
if name_match:
|
||||
return name_match
|
||||
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user