"""
OCR Job Worker - Background Task for Queue Processing

Runs as an asyncio background task in FastAPI. Continuously polls the job
queue and processes OCR requests IN PARALLEL.

Architecture:

    FastAPI startup
        ↓
    start_job_worker()
        ↓
    asyncio.create_task(_job_worker_loop())
        ↓
    while not shutdown:
        # Keep at most OCR_WORKERS jobs in flight
        if a slot is free:
            job = await job_queue.get_next_pending()
            if job:
                asyncio.create_task(_process_job_with_semaphore(job))
            else:
                wait POLL_INTERVAL_SECONDS (or until shutdown)
        else:
            wait until a running job finishes (at most POLL_INTERVAL_SECONDS)

A second background task (_cleanup_loop) periodically deletes expired jobs.
"""

import asyncio
import logging
import os
import time
from pathlib import Path
from typing import Any, Dict, Optional, Set

from .job_queue import job_queue, OCRJobStatus, OCRJob
from .ocr_worker_pool import ocr_worker_pool

logger = logging.getLogger(__name__)

# Global task references
_job_worker_task: Optional[asyncio.Task] = None
_cleanup_task: Optional[asyncio.Task] = None
# Held so the background pre-warm task stays strongly referenced while it runs
# and can be cancelled on shutdown.
_prewarm_task: Optional[asyncio.Task] = None
_shutdown_event: Optional[asyncio.Event] = None
_active_tasks: Set[asyncio.Task] = set()  # Track active job tasks
_concurrency_semaphore: Optional[asyncio.Semaphore] = None  # Limit concurrent jobs

# Configuration
POLL_INTERVAL_SECONDS = 0.1  # How often to check for new jobs (faster for parallel)
CLEANUP_INTERVAL_SECONDS = 3600  # Clean expired jobs every hour
CLEANUP_RETRY_SECONDS = 60  # Delay before retrying after a failed cleanup run
OCR_TIMEOUT_SECONDS = 120  # Max time for OCR processing
PREWARM_TIMEOUT_SECONDS = 90.0  # Max time for the background worker-pool pre-warm
# How long stop_job_worker() lets in-flight jobs drain before cancelling them.
SHUTDOWN_GRACE_SECONDS = OCR_TIMEOUT_SECONDS + 30
DEFAULT_OCR_WORKERS = 2  # Used when OCR_WORKERS is unset or not an integer
DEFAULT_SECONDS_PER_JOB = 7.0  # Wait-time estimate when no queue stats are available
MAX_CONSECUTIVE_ERRORS = 10  # Worker loop stops after this many back-to-back errors
MAX_ERROR_BACKOFF_SECONDS = 30  # Upper bound on the worker loop's error backoff


# ============================================================================
# Small internal helpers
# ============================================================================

def _read_max_concurrent() -> int:
    """
    Read the job concurrency limit from the OCR_WORKERS environment variable.

    Falls back to DEFAULT_OCR_WORKERS when the variable is unset or not an
    integer, and clamps values below 1 to 1: a limit of 0 would leave the
    worker loop permanently "at capacity" so no job would ever be launched.
    """
    raw = os.getenv('OCR_WORKERS')
    if raw is None:
        return DEFAULT_OCR_WORKERS
    try:
        value = int(raw)
    except ValueError:
        logger.warning(f"[JobWorker] Invalid OCR_WORKERS={raw!r}; using {DEFAULT_OCR_WORKERS}")
        return DEFAULT_OCR_WORKERS
    if value < 1:
        logger.warning(f"[JobWorker] OCR_WORKERS={value} is below 1; using 1")
        return 1
    return value


def _get_shutdown_event() -> asyncio.Event:
    """Return the shared shutdown event, creating it only if none exists yet."""
    global _shutdown_event
    if _shutdown_event is None:
        _shutdown_event = asyncio.Event()
    return _shutdown_event


async def _wait_for_shutdown(event: asyncio.Event, timeout: float) -> bool:
    """
    Sleep for up to ``timeout`` seconds, waking early if ``event`` gets set.

    Returns:
        True if the shutdown event is set when the wait ends.
    """
    try:
        await asyncio.wait_for(event.wait(), timeout=timeout)
    except asyncio.TimeoutError:
        pass  # Normal timeout - no shutdown requested
    return event.is_set()


def _elapsed_ms(start: float) -> int:
    """Milliseconds elapsed since ``start``, a time.monotonic() reading."""
    return int((time.monotonic() - start) * 1000)


def _reap_finished_tasks() -> None:
    """
    Remove finished job tasks from ``_active_tasks`` and log their failures.

    Cancelled tasks are checked first: calling ``result()`` on them would raise
    CancelledError, which the worker loop would misread as its own cancellation.
    """
    finished = [t for t in _active_tasks if t.done()]
    for task in finished:
        _active_tasks.discard(task)
        if task.cancelled():
            logger.warning("[JobWorker] Task was cancelled")
            continue
        exc = task.exception()
        if exc is not None:
            logger.error(f"[JobWorker] Task failed: {exc}", exc_info=exc)


async def _cancel_and_wait(task: Optional[asyncio.Task], name: str) -> None:
    """
    Cancel ``task`` if it is still running, then wait for it to finish.

    An error the task ended with is logged instead of raised, so one failed
    background task cannot abort the rest of shutdown/restart.
    """
    if task is None:
        return
    if not task.done():
        task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
    except Exception as e:
        logger.error(f"[JobWorker] {name} task ended with error: {e}", exc_info=True)


# ============================================================================
# Worker loop and job processing
# ============================================================================

async def _job_worker_loop() -> None:
    """
    Main worker loop - processes jobs from the queue IN PARALLEL.

    Runs until the shutdown event is set, the task is cancelled, or
    MAX_CONSECUTIVE_ERRORS back-to-back errors occur. At most OCR_WORKERS jobs
    are in flight at once. Each job runs as its own background task and the
    loop does not wait for it. However the loop exits, it waits for every
    in-flight job before returning.
    """
    global _active_tasks, _concurrency_semaphore

    # Get max concurrent jobs from env (presumably the same variable sizes the
    # worker pool - confirm in ocr_worker_pool)
    max_concurrent = _read_max_concurrent()
    _concurrency_semaphore = asyncio.Semaphore(max_concurrent)
    _active_tasks = set()
    # Reuse the event created by start_job_worker() so stop_job_worker() and
    # _cleanup_loop() observe the same object.
    shutdown_event = _get_shutdown_event()

    logger.info(f"[JobWorker] Starting PARALLEL worker loop (max_concurrent={max_concurrent})...")

    consecutive_errors = 0
    try:
        while not shutdown_event.is_set():
            try:
                _reap_finished_tasks()

                if len(_active_tasks) >= max_concurrent:
                    # At capacity - wake as soon as any job finishes (or after one poll interval)
                    await asyncio.wait(
                        _active_tasks,
                        timeout=POLL_INTERVAL_SECONDS,
                        return_when=asyncio.FIRST_COMPLETED,
                    )
                    continue

                job = await job_queue.get_next_pending()
                # The queue answered, so reset the counter: only back-to-back
                # failures should stop the worker, even during long idle periods.
                consecutive_errors = 0

                if not job:
                    # No pending jobs - wait briefly, waking early on shutdown
                    if await _wait_for_shutdown(shutdown_event, POLL_INTERVAL_SECONDS):
                        break
                    continue

                # Launch job processing as a background task
                task = asyncio.create_task(_process_job_with_semaphore(job))
                _active_tasks.add(task)
                logger.debug(f"[JobWorker] Launched job {job.id} (active={len(_active_tasks)}/{max_concurrent})")

            except asyncio.CancelledError:
                logger.info("[JobWorker] Worker loop cancelled")
                break
            except Exception as e:
                consecutive_errors += 1
                logger.error(
                    f"[JobWorker] Error in worker loop ({consecutive_errors}/{MAX_CONSECUTIVE_ERRORS}): {e}",
                    exc_info=True,
                )
                if consecutive_errors >= MAX_CONSECUTIVE_ERRORS:
                    logger.error("[JobWorker] Too many consecutive errors, stopping worker")
                    break
                # Linear backoff, interruptible by shutdown
                backoff = min(consecutive_errors * 2, MAX_ERROR_BACKOFF_SECONDS)
                if await _wait_for_shutdown(shutdown_event, backoff):
                    break
    finally:
        # Wait for active tasks even if the loop was left via cancellation
        if _active_tasks:
            logger.info(f"[JobWorker] Waiting for {len(_active_tasks)} active tasks to complete...")
            await asyncio.gather(*_active_tasks, return_exceptions=True)
        logger.info("[JobWorker] Worker loop stopped")


async def _process_job_with_semaphore(job: OCRJob) -> None:
    """
    Process ``job`` while holding the concurrency semaphore.

    The semaphore (sized by _job_worker_loop) caps how many jobs run at once,
    as a second guard next to the loop's own active-task count.
    """
    async with _concurrency_semaphore:
        await _process_job(job)


async def _process_job(job: OCRJob) -> None:
    """
    Process a single OCR job.

    Reads the job's file, submits it to the OCR worker pool, records the
    outcome with job_queue.update_status(), and saves analytics metrics. The
    job's file is cleaned up afterwards regardless of outcome.

    Args:
        job: OCRJob to process. Per job_queue.get_next_pending() it has already
            been marked 'processing' atomically.
    """
    logger.info(f"[JobWorker] Processing job {job.id}: engine={job.engine}, file={Path(job.file_path).name}")

    # monotonic() is unaffected by wall-clock adjustments, so durations stay accurate
    start_time = time.monotonic()
    file_size = 0
    file_type = "image/jpeg"

    try:
        file_path = Path(job.file_path)
        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")

        # Read in the default executor so disk I/O doesn't block other jobs on the loop
        loop = asyncio.get_running_loop()
        file_bytes = await loop.run_in_executor(None, file_path.read_bytes)
        file_size = len(file_bytes)

        # Determine file type from job (falls back to JPEG)
        file_type = getattr(job, 'mime_type', 'image/jpeg') or 'image/jpeg'

        result = await ocr_worker_pool.submit_task(
            image_bytes=file_bytes,
            engine=job.engine,
            preprocessing="auto",
            timeout=OCR_TIMEOUT_SECONDS,
        )
        elapsed_ms = _elapsed_ms(start_time)

        if result.get("success"):
            # `or {}` also covers an extraction key that is present but None
            extraction = result.get("extraction") or {}
            # Include raw_texts for analysis (from all OCR engine passes)
            extraction['raw_texts'] = result.get("raw_texts", [])
            # Actual OCR processing time as reported by the extraction result
            ocr_time_ms = extraction.get('processing_time_ms', 0)

            spm = extraction.get('suggested_payment_mode')
            logger.debug(f"[JobWorker] Job {job.id} extraction has suggested_payment_mode={spm}")

            await job_queue.update_status(
                job_id=job.id,
                status=OCRJobStatus.completed,
                result=extraction,
                processing_time_ms=elapsed_ms,
                ocr_time_ms=ocr_time_ms,
            )
            logger.info(f"[JobWorker] Job {job.id} completed in {elapsed_ms}ms (ocr: {ocr_time_ms}ms)")

            await _save_success_metrics(job, extraction, elapsed_ms, file_size, file_type)
        else:
            error_msg = result.get("error") or "Unknown error"
            await job_queue.update_status(
                job_id=job.id,
                status=OCRJobStatus.failed,
                error=error_msg,
                processing_time_ms=elapsed_ms,
            )
            logger.warning(f"[JobWorker] Job {job.id} failed after {elapsed_ms}ms: {error_msg}")

            await _save_failed_job_metrics(job, elapsed_ms, file_size, file_type, error_msg)

    except Exception as e:
        elapsed_ms = _elapsed_ms(start_time)
        logger.error(f"[JobWorker] Job {job.id} error after {elapsed_ms}ms: {e}", exc_info=True)

        await job_queue.update_status(
            job_id=job.id,
            status=OCRJobStatus.failed,
            error=str(e),
            processing_time_ms=elapsed_ms,
        )
        await _save_failed_job_metrics(job, elapsed_ms, file_size, file_type, str(e))

    finally:
        # Cleanup file after processing
        try:
            await job_queue.cleanup_job_file(job.id)
        except Exception as e:
            logger.warning(f"[JobWorker] Failed to cleanup file for job {job.id}: {e}")


async def _save_success_metrics(
    job: OCRJob,
    extraction: Dict[str, Any],
    elapsed_ms: int,
    file_size: int,
    file_type: str,
) -> None:
    """
    Save analytics metrics for a job that completed successfully.

    The derived values are computed defensively. The job is already marked
    completed at this point, so an unexpected extraction shape must not raise
    back into _process_job, whose error path would re-mark the job as failed.
    """
    try:
        derived = {
            'engine_used': extraction.get('ocr_engine', job.engine),
            'overall_confidence': extraction.get('overall_confidence', 0.0),
            'fields_extracted': _count_extracted_fields(extraction),
            'needs_manual_review': extraction.get('needs_manual_review'),
            # `or []` tolerates keys that are present but None
            'validation_warnings_count': len(extraction.get('validation_warnings') or []),
            'validation_errors_count': len(extraction.get('validation_errors') or []),
        }
    except Exception as e:
        logger.warning(f"[JobWorker] Failed to derive metrics for job {job.id}: {e}")
        return

    await _save_job_metrics(
        job_id=job.id,
        username=job.created_by or 'unknown',
        engine_requested=job.engine,
        processing_time_ms=elapsed_ms,
        file_size_bytes=file_size,
        file_type=file_type,
        original_filename=job.original_filename,
        success=True,
        **derived,
    )


async def _save_failed_job_metrics(
    job: OCRJob,
    elapsed_ms: int,
    file_size: int,
    file_type: str,
    error_message: str,
) -> None:
    """Save analytics metrics for a job that failed (OCR error or exception)."""
    await _save_job_metrics(
        job_id=job.id,
        username=job.created_by or 'unknown',
        engine_requested=job.engine,
        engine_used=job.engine,
        processing_time_ms=elapsed_ms,
        file_size_bytes=file_size,
        file_type=file_type,
        original_filename=job.original_filename,
        success=False,
        error_message=error_message,
    )


async def _cleanup_loop() -> None:
    """
    Periodically delete expired jobs via job_queue.cleanup_expired().

    The first run happens CLEANUP_INTERVAL_SECONDS after start. After a failed
    run the next attempt happens CLEANUP_RETRY_SECONDS later instead of a full
    interval. The loop exits when the shutdown event is set or the task is
    cancelled.
    """
    shutdown_event = _get_shutdown_event()
    logger.info("[JobWorker] Starting cleanup loop...")

    delay = CLEANUP_INTERVAL_SECONDS
    while True:
        try:
            # Wait for the interval, waking early on shutdown
            if await _wait_for_shutdown(shutdown_event, delay):
                break

            deleted = await job_queue.cleanup_expired()
            if deleted:
                logger.info(f"[JobWorker] Cleanup: deleted {deleted} expired jobs")
            delay = CLEANUP_INTERVAL_SECONDS

        except asyncio.CancelledError:
            logger.info("[JobWorker] Cleanup loop cancelled")
            break
        except Exception as e:
            logger.error(f"[JobWorker] Cleanup error: {e}", exc_info=True)
            delay = CLEANUP_RETRY_SECONDS  # Retry after 1 minute

    logger.info("[JobWorker] Cleanup loop stopped")


async def _background_prewarm() -> None:
    """
    Pre-warm the OCR worker pool without blocking startup.

    Failures, including exceptions raised by prewarm(), are logged and not
    raised. The only consequence is that the first OCR request may be slower.
    """
    logger.info("[JobWorker] Pre-warming OCR worker pool (background)...")
    try:
        warmup_success = await ocr_worker_pool.prewarm(timeout=PREWARM_TIMEOUT_SECONDS)
    except Exception as e:
        logger.warning(f"[JobWorker] Worker pool pre-warm raised: {e}", exc_info=True)
        warmup_success = False

    if warmup_success:
        logger.info("[JobWorker] OCR worker pool pre-warmed successfully")
    else:
        logger.warning("[JobWorker] Worker pool pre-warm failed, first request will be slower")


# ============================================================================
# Lifecycle
# ============================================================================

async def start_job_worker() -> bool:
    """
    Start the job worker background tasks.

    Called at FastAPI startup. The steps are:
      1. Initialize the job queue and the OCR worker pool.
      2. Launch a background pool pre-warm.
      3. Launch the worker loop and the cleanup loop.

    Calling it while the worker is already running is a no-op.

    Returns:
        True if the worker is running (started now or already), False if
        initialization failed.
    """
    global _job_worker_task, _cleanup_task, _shutdown_event, _prewarm_task

    if is_running():
        logger.warning("[JobWorker] Already running")
        return True

    try:
        await job_queue.initialize()

        if not ocr_worker_pool.initialize():
            logger.error("[JobWorker] Failed to initialize worker pool")
            return False

        # A previous run whose worker loop died may still have background tasks alive
        await _cancel_and_wait(_cleanup_task, "cleanup")
        await _cancel_and_wait(_prewarm_task, "pre-warm")

        # Pre-warm in the BACKGROUND (don't block startup); keep the reference
        # so the task isn't garbage-collected mid-run.
        _prewarm_task = asyncio.create_task(_background_prewarm())

        # Fresh event for this run; both loops pick it up via _get_shutdown_event()
        _shutdown_event = asyncio.Event()
        _job_worker_task = asyncio.create_task(_job_worker_loop())
        _cleanup_task = asyncio.create_task(_cleanup_loop())

        logger.info("[JobWorker] Started successfully")
        return True

    except Exception as e:
        logger.error(f"[JobWorker] Failed to start: {e}", exc_info=True)
        return False


async def stop_job_worker() -> None:
    """
    Stop the job worker background tasks.

    Called at FastAPI shutdown. The steps are:
      1. Signal shutdown.
      2. Give the worker loop up to SHUTDOWN_GRACE_SECONDS to drain in-flight
         jobs on its own.
      3. Cancel whatever is still running.
      4. Shut down the OCR worker pool.
    """
    global _job_worker_task, _cleanup_task, _shutdown_event, _prewarm_task

    logger.info("[JobWorker] Stopping...")

    if _shutdown_event is not None:
        _shutdown_event.set()

    # Let the loop notice the event and exit cleanly rather than cancelling it
    # mid-await (e.g. while get_next_pending() is claiming a job).
    if _job_worker_task is not None and not _job_worker_task.done():
        done, _ = await asyncio.wait({_job_worker_task}, timeout=SHUTDOWN_GRACE_SECONDS)
        if not done:
            logger.warning(f"[JobWorker] Worker loop still busy after {SHUTDOWN_GRACE_SECONDS}s, cancelling")

    await _cancel_and_wait(_job_worker_task, "worker")
    await _cancel_and_wait(_cleanup_task, "cleanup")
    await _cancel_and_wait(_prewarm_task, "pre-warm")

    # Shutdown worker pool
    ocr_worker_pool.shutdown(wait=True)

    _job_worker_task = None
    _cleanup_task = None
    _prewarm_task = None
    _shutdown_event = None

    logger.info("[JobWorker] Stopped")


def is_running() -> bool:
    """Check if job worker is running."""
    return _job_worker_task is not None and not _job_worker_task.done()


# ============================================================================
# Wait-time estimation
# ============================================================================

def _coerce_avg_time(value: Any) -> float:
    """
    Return ``value`` as a positive, finite float.

    Returns DEFAULT_SECONDS_PER_JOB for None, non-numeric values, zero,
    negative values, NaN and infinity.
    """
    try:
        seconds = float(value)
    except (TypeError, ValueError):
        return DEFAULT_SECONDS_PER_JOB
    # The chained comparison is False for NaN as well as for out-of-range values
    return seconds if 0 < seconds < float('inf') else DEFAULT_SECONDS_PER_JOB


def estimate_wait_time(queue_position: int) -> int:
    """
    Estimate wait time for a job in queue (synchronous).

    Uses job_queue's average processing time only when it can be obtained
    synchronously, i.e. when no event loop is running in the calling thread.
    In every other case, including calls from async request handlers, it uses
    DEFAULT_SECONDS_PER_JOB. Async callers should prefer
    estimate_wait_time_async().

    NOTE(review): the average is assumed to be in seconds, matching the 7.0
    default; confirm against job_queue.get_average_processing_time().

    Args:
        queue_position: Position in queue (1 = next)

    Returns:
        Estimated wait time in seconds (0 for non-positive positions)
    """
    if queue_position <= 0:
        return 0

    avg_time = DEFAULT_SECONDS_PER_JOB
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread, so the coroutine can be driven
        # synchronously. get_event_loop() may itself raise (e.g. in a worker
        # thread with no loop); that is caught and the default is kept.
        try:
            measured = asyncio.get_event_loop().run_until_complete(
                job_queue.get_average_processing_time()
            )
        except Exception as e:
            logger.debug(f"[JobWorker] Could not read average processing time: {e}")
        else:
            avg_time = _coerce_avg_time(measured)

    # Estimate: position * average_time
    return int(queue_position * avg_time)


async def estimate_wait_time_async(queue_position: int) -> int:
    """
    Estimate wait time for a job in queue, awaiting the queue's real stats.

    Same contract as estimate_wait_time(). It is usable from async code, where
    the synchronous version can only return the default-based guess.

    Args:
        queue_position: Position in queue (1 = next)

    Returns:
        Estimated wait time in seconds (0 for non-positive positions)
    """
    if queue_position <= 0:
        return 0

    try:
        measured = await job_queue.get_average_processing_time()
    except Exception as e:
        logger.debug(f"[JobWorker] Could not read average processing time: {e}")
        measured = None

    return int(queue_position * _coerce_avg_time(measured))


# ============================================================================
# Metrics Helper Functions
# ============================================================================

async def _save_job_metrics(
    job_id: str,
    username: str,
    engine_requested: str,
    engine_used: str,
    processing_time_ms: int = 0,
    file_size_bytes: int = 0,
    file_type: str = "image/jpeg",
    original_filename: Optional[str] = None,
    success: bool = True,
    error_message: Optional[str] = None,
    overall_confidence: float = 0.0,
    fields_extracted: int = 0,
    needs_manual_review: Optional[bool] = None,
    validation_warnings_count: int = 0,
    validation_errors_count: int = 0,
) -> None:
    """
    Save OCR job metrics to the database for analytics.

    Called after each job finishes (success or failure). Any error, including
    a failed import of the DB modules, is logged as a warning and swallowed, so
    metrics never affect job processing.
    """
    try:
        # Imported at call time rather than module import time (reason not
        # visible here - presumably to avoid import cycles; verify).
        from backend.modules.data_entry.db.database import get_db_session
        from backend.modules.data_entry.db.crud.ocr_settings import OCRMetricsCRUD

        async with await get_db_session() as session:
            await OCRMetricsCRUD.create(
                session=session,
                job_id=job_id,
                username=username,
                engine_requested=engine_requested,
                engine_used=engine_used,
                processing_time_ms=processing_time_ms,
                file_size_bytes=file_size_bytes,
                file_type=file_type,
                original_filename=original_filename,
                success=success,
                error_message=error_message,
                overall_confidence=overall_confidence,
                fields_extracted=fields_extracted,
                needs_manual_review=needs_manual_review,
                validation_warnings_count=validation_warnings_count,
                validation_errors_count=validation_errors_count,
            )
        logger.debug(f"[JobWorker] Saved metrics for job {job_id}")
    except Exception as e:
        # Log but don't fail - metrics are nice-to-have
        logger.warning(f"[JobWorker] Failed to save metrics for job {job_id}: {e}")


# Scalar fields counted by _count_extracted_fields()
_KEY_EXTRACTION_FIELDS = (
    'receipt_number',
    'receipt_date',
    'amount',
    'partner_name',
    'cui',
    'tva_total',
    'address',
    'items_count',
)


def _count_extracted_fields(extraction: dict) -> int:
    """
    Count how many key fields an OCR extraction result filled in.

    Each entry of _KEY_EXTRACTION_FIELDS counts once when its value is not
    None, '' or []. Falsy-but-present values such as 0 or False DO count.
    Non-empty 'tva_entries' and 'payment_methods' each add one more.
    """
    count = 0
    for field in _KEY_EXTRACTION_FIELDS:
        value = extraction.get(field)
        if value is not None and value != '' and value != []:
            count += 1

    # Also count TVA entries if present
    if extraction.get('tva_entries'):
        count += 1

    # Count payment methods if present
    if extraction.get('payment_methods'):
        count += 1

    return count