"""
OCR API endpoints with async job queue support.

Endpoints:
- POST /extract        - Submit OCR job (returns job_id immediately)
- GET  /jobs/{job_id}  - Get job status and result
- GET  /queue/status   - Get queue statistics
- GET  /status         - Check OCR service availability

For backwards compatibility, we also support sync mode via query param:
- POST /extract?sync=true - Process synchronously (blocks until complete)
"""
import hashlib
import logging
import os
import tempfile
from datetime import date, datetime
from decimal import Decimal
from pathlib import Path
from typing import Optional

from fastapi import APIRouter, HTTPException, UploadFile, File, Depends, Query
from sqlalchemy.ext.asyncio import AsyncSession

from backend.modules.data_entry.db.database import get_session
from backend.modules.data_entry.db.crud.attachment import AttachmentCRUD
from backend.modules.data_entry.services.ocr_service import ocr_service
from backend.modules.data_entry.services.ocr_engine import OCREngine
from backend.modules.data_entry.services.ocr.job_queue import job_queue, OCRJobStatus as JobStatus
from backend.modules.data_entry.services.ocr.job_worker import estimate_wait_time
from backend.modules.data_entry.schemas.ocr import (
    OCRResponse,
    OCRStatusResponse,
    ExtractionData,
    TvaEntry,
    PaymentMethod,
    # New job queue schemas
    OCREngineChoice,
    OCRJobStatus,
    OCRJobSubmitResponse,
    OCRJobResponse,
    OCRQueueStatusResponse,
)

# Auth integration
from shared.auth.dependencies import get_current_user
from shared.auth.models import CurrentUser

router = APIRouter()
logger = logging.getLogger(__name__)

# MIME types accepted for OCR; shared by the upload and attachment endpoints.
ALLOWED_MIME_TYPES = ('image/jpeg', 'image/png', 'application/pdf')

# Maximum accepted upload size in bytes (10MB).
MAX_UPLOAD_BYTES = 10 * 1024 * 1024

# File extensions accepted as-is for the sync-mode temp file.
_ALLOWED_SUFFIXES = ('.jpg', '.jpeg', '.png', '.pdf')

# Temp-file suffix per MIME type, used when the uploaded filename has no
# usable extension.
_SUFFIX_BY_MIME = {
    'image/jpeg': '.jpg',
    'image/png': '.png',
    'application/pdf': '.pdf',
}


# ============================================================================
# OCR Job Queue Endpoints (NEW)
# ============================================================================

@router.post("/extract", response_model=OCRJobSubmitResponse)
async def submit_ocr_job(
    file: UploadFile = File(...),
    engine: OCREngineChoice = Query(default=OCREngineChoice.auto, description="OCR engine to use"),
    sync: bool = Query(default=False, description="If true, process synchronously (blocks)"),
    current_user: CurrentUser = Depends(get_current_user)
):
    """
    Submit an OCR job for processing.

    By default, returns immediately with a job_id. Poll GET /jobs/{job_id} for result.

    Use ?sync=true for synchronous processing (blocks until complete).
    This is for backwards compatibility but not recommended for production.

    Args:
        file: Image or PDF file (max 10MB)
        engine: OCR engine choice (auto, paddleocr, tesseract)
        sync: If true, process synchronously (legacy mode)

    Returns:
        OCRJobSubmitResponse with job_id, queue_position, estimated_wait

    Raises:
        HTTPException 400: unsupported MIME type or file larger than 10MB.
        HTTPException 422: sync mode only, OCR processing failed.
        HTTPException 500: the job could not be queued.
    """
    if file.content_type not in ALLOWED_MIME_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"File type not supported: {file.content_type}. Allowed: JPG, PNG, PDF"
        )

    # Read at most one byte past the limit. That is enough to detect an
    # oversized upload without buffering an arbitrarily large body in memory.
    content = await file.read(MAX_UPLOAD_BYTES + 1)
    if len(content) > MAX_UPLOAD_BYTES:
        raise HTTPException(
            status_code=400,
            detail="File too large. Maximum size is 10MB."
        )

    # Sync mode - use legacy processing (blocks)
    if sync:
        return await _process_sync(content, file, engine, current_user)

    # Async mode - create job and return immediately
    try:
        job = await job_queue.create_job(
            file_bytes=content,
            mime_type=file.content_type,
            engine=engine.value,
            username=current_user.username,
            original_filename=file.filename
        )

        # A missing position (e.g. the job was already picked up) is
        # reported as position 1 for the client.
        queue_position = await job_queue.get_queue_position(job.id)
        estimated_wait = estimate_wait_time(queue_position or 1)

        return OCRJobSubmitResponse(
            job_id=job.id,
            status=OCRJobStatus.pending,
            queue_position=queue_position or 1,
            estimated_wait_seconds=estimated_wait,
            created_at=job.created_at or datetime.utcnow()
        )
    except Exception as e:
        # Top-level boundary: log the full traceback, then surface a 500
        # with the original exception chained for debugging.
        logger.exception("Failed to create OCR job")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to create OCR job: {str(e)}"
        ) from e


@router.get("/jobs/{job_id}", response_model=OCRJobResponse)
async def get_job_status(
    job_id: str,
    current_user: CurrentUser = Depends(get_current_user)
):
    """
    Get OCR job status and result.

    Poll this endpoint to check job progress.
    Recommended polling interval: 2 seconds.

    Args:
        job_id: Job UUID from POST /extract response

    Returns:
        OCRJobResponse with status, queue_position, and result (if completed)

    Raises:
        HTTPException 404: no job with this id.
    """
    job = await job_queue.get_job(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")

    # NOTE(review): the job is looked up by id only. Any authenticated user
    # who knows a job_id can read its result. Jobs are created with
    # username=current_user.username (see submit_ocr_job). Consider rejecting
    # access when the stored owner differs; confirm the job model field name.

    queue_position = None
    estimated_wait = None

    if job.status == JobStatus.pending:
        queue_position = await job_queue.get_queue_position(job_id)
        estimated_wait = estimate_wait_time(queue_position or 1)
    elif job.status == JobStatus.processing:
        queue_position = 0
        # Rough estimate: assume the job is about halfway through an
        # average-length run.
        avg_time = await job_queue.get_average_processing_time()
        estimated_wait = int(avg_time * 0.5)

    # Only completed jobs carry an extraction result.
    result_data = None
    if job.status == JobStatus.completed and job.result:
        result_data = _dict_to_extraction_data(job.result)

    return OCRJobResponse(
        job_id=job.id,
        status=OCRJobStatus(job.status.value),
        queue_position=queue_position,
        estimated_wait_seconds=estimated_wait,
        created_at=job.created_at or datetime.utcnow(),
        started_at=job.started_at,
        completed_at=job.completed_at,
        processing_time_ms=job.processing_time_ms,
        result=result_data,
        error=job.error_message
    )


@router.get("/queue/status", response_model=OCRQueueStatusResponse)
async def get_queue_status(
    current_user: CurrentUser = Depends(get_current_user)
):
    """
    Get OCR queue statistics.

    Returns:
        Queue status with pending/processing counts and average time
    """
    stats = await job_queue.get_queue_stats()

    return OCRQueueStatusResponse(
        pending_jobs=stats["pending"],
        processing_jobs=stats["processing"],
        average_time_seconds=stats["average_time_seconds"]
    )


# ============================================================================
# Legacy Endpoints (backwards compatibility)
# ============================================================================

@router.get("/status", response_model=OCRStatusResponse)
async def get_ocr_status():
    """Check OCR service status and available engines (no auth required)."""
    engines = OCREngine.get_available_engines()
    available = len(engines) > 0

    if available:
        message = f"OCR service ready with engines: {', '.join(engines)}"
    else:
        message = "No OCR engines available. Install PaddleOCR or Tesseract."

    return OCRStatusResponse(
        available=available,
        engines=engines,
        message=message
    )


@router.post("/extract-attachment/{attachment_id}", response_model=OCRResponse)
async def extract_from_attachment(
    attachment_id: int,
    engine: OCREngineChoice = Query(default=OCREngineChoice.auto),
    session: AsyncSession = Depends(get_session),
    current_user: CurrentUser = Depends(get_current_user)
):
    """
    Extract receipt data from an existing attachment.

    Re-processes an already uploaded file with OCR.
    This endpoint always processes synchronously.

    Raises:
        HTTPException 404: attachment row or its file on disk is missing.
        HTTPException 400: attachment MIME type is not supported for OCR.
        HTTPException 422: OCR processing failed.
    """
    # NOTE(review): `engine` is accepted but not forwarded to
    # ocr_service.process_image. The attachment is also not checked against
    # current_user. Confirm whether either is intended.
    attachment = await AttachmentCRUD.get_by_id(session, attachment_id)
    if not attachment:
        raise HTTPException(status_code=404, detail="Attachment not found")

    file_path = AttachmentCRUD.get_file_path(attachment)
    if not file_path.exists():
        raise HTTPException(status_code=404, detail="File not found on disk")

    if attachment.mime_type not in ALLOWED_MIME_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"File type not supported for OCR: {attachment.mime_type}"
        )

    # TODO: Could use job queue here too, but keeping sync for now
    success, message, result = await ocr_service.process_image(
        file_path,
        attachment.mime_type
    )

    if not success:
        raise HTTPException(status_code=422, detail=message)

    data = _result_to_extraction_data(result)
    return OCRResponse(success=True, message=message, data=data)


# ============================================================================
# Helper Functions
# ============================================================================

async def _process_sync(
    content: bytes,
    file: UploadFile,
    engine: OCREngineChoice,
    current_user: CurrentUser
) -> OCRJobSubmitResponse:
    """
    Process OCR synchronously (legacy mode).

    Writes the upload to a temporary file, runs ocr_service.process_image on
    it, and deletes the file afterwards. No queue job is created. A synthetic
    "completed" OCRJobSubmitResponse is returned for API consistency.

    Args:
        content: Raw uploaded bytes (already size-checked by the caller).
        file: The original upload; used for filename and content type.
        engine: Currently unused (see NOTE below).
        current_user: Currently unused.

    Returns:
        OCRJobSubmitResponse with status=completed and a deterministic
        "sync-<sha256 prefix>" job_id.

    Raises:
        HTTPException 422: OCR processing failed.
    """
    # NOTE(review): `engine` is not passed to ocr_service.process_image, so
    # sync mode ignores the requested engine. Also, the extraction result is
    # NOT included in the response (OCRJobSubmitResponse has no result field
    # that this code populates). Clients relying on sync mode receive only
    # the status. Confirm the intended contract.

    # Prefer the uploaded file's extension. Otherwise derive one from the
    # MIME type so PDFs and PNGs are not mislabeled as JPEG.
    suffix = Path(file.filename).suffix.lower() if file.filename else ''
    if suffix not in _ALLOWED_SUFFIXES:
        suffix = _SUFFIX_BY_MIME.get(file.content_type, '.jpg')

    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    tmp_path = Path(tmp.name)
    try:
        # The write happens inside the try so a failed write still removes
        # the temp file.
        with tmp:
            tmp.write(content)

        success, message, result = await ocr_service.process_image(
            tmp_path,
            file.content_type
        )

        if not success:
            raise HTTPException(status_code=422, detail=message)

        # SHA-256 gives a stable, non-negative id across processes.
        # Built-in hash() on bytes is salted per process.
        sync_id = "sync-" + hashlib.sha256(content).hexdigest()[:16]

        return OCRJobSubmitResponse(
            job_id=sync_id,
            status=OCRJobStatus.completed,
            queue_position=0,
            estimated_wait_seconds=0,
            created_at=datetime.utcnow()
        )
    finally:
        tmp_path.unlink(missing_ok=True)


def _suggest_payment_mode(payment_methods):
    """Return 'banca' if any detected payment method is 'CARD', else None.

    Args:
        payment_methods: Iterable of PaymentMethod-like objects exposing a
            ``method`` attribute.
    """
    if any(pm.method == 'CARD' for pm in payment_methods):
        return 'banca'
    return None


def _result_to_extraction_data(result) -> ExtractionData:
    """Convert an ExtractionResult (from ocr_service) to the ExtractionData schema.

    ``result.tva_entries`` and ``result.payment_methods`` are expected to be
    lists of dicts. ``tva_entries`` items need 'percent' and 'amount' keys;
    ``payment_methods`` items need 'method' and 'amount' keys.
    """
    tva_entries_schema = [
        TvaEntry(code=e.get('code'), percent=e['percent'], amount=e['amount'])
        for e in result.tva_entries
    ] if result.tva_entries else []

    payment_methods_list = [
        PaymentMethod(method=pm['method'], amount=Decimal(str(pm['amount'])))
        for pm in result.payment_methods
    ] if result.payment_methods else []

    return ExtractionData(
        receipt_type=result.receipt_type,
        receipt_number=result.receipt_number,
        receipt_series=result.receipt_series,
        receipt_date=result.receipt_date,
        amount=result.amount,
        partner_name=result.partner_name,
        cui=result.cui,
        description=result.description,
        tva_entries=tva_entries_schema,
        tva_total=result.tva_total,
        address=result.address,
        items_count=result.items_count,
        payment_methods=payment_methods_list,
        suggested_payment_mode=_suggest_payment_mode(payment_methods_list),
        client_name=result.client_name,
        client_cui=result.client_cui,
        client_address=result.client_address,
        confidence_amount=result.confidence_amount,
        confidence_date=result.confidence_date,
        confidence_vendor=result.confidence_vendor,
        # Older ExtractionResult objects may lack this attribute.
        confidence_client=getattr(result, 'confidence_client', 0.0),
        overall_confidence=result.overall_confidence,
        raw_text=result.raw_text,
        ocr_engine=result.ocr_engine,
        processing_time_ms=result.processing_time_ms,
        needs_manual_review=result.needs_manual_review,
        validation_warnings=result.validation_warnings,
        validation_errors=result.validation_errors,
        inter_ocr_ratios=result.inter_ocr_ratios,
    )


def _parse_receipt_date(value):
    """Normalize a stored receipt date.

    Strings are parsed as ISO dates. ISO datetime strings are also accepted
    and truncated to their date. Unparseable strings become None.
    Non-string values (None, date objects) are returned unchanged.
    """
    if not isinstance(value, str):
        return value
    try:
        return date.fromisoformat(value)
    except ValueError:
        pass
    try:
        return datetime.fromisoformat(value).date()
    except ValueError:
        return None


def _to_decimal(value):
    """Convert a numeric value to Decimal via str (avoids float artifacts); None passes through."""
    return None if value is None else Decimal(str(value))


def _dict_to_extraction_data(data: dict) -> ExtractionData:
    """Convert a result dict (from the job queue) to the ExtractionData schema.

    Missing keys fall back to neutral defaults. Explicit None for list/dict
    fields is normalized to an empty container. ``suggested_payment_mode``
    is derived from the payment methods when the dict does not provide one,
    matching _result_to_extraction_data.
    """
    # Non-dict entries are skipped.
    tva_entries_schema = [
        TvaEntry(
            code=e.get('code'),
            percent=e.get('percent', 0),
            amount=Decimal(str(e.get('amount', 0)))
        )
        for e in (data.get('tva_entries') or [])
        if isinstance(e, dict)
    ]

    payment_methods_list = [
        PaymentMethod(
            method=pm.get('method', 'NUMERAR'),
            amount=Decimal(str(pm.get('amount', 0)))
        )
        for pm in (data.get('payment_methods') or [])
        if isinstance(pm, dict)
    ]

    suggested_payment_mode = (
        data.get('suggested_payment_mode')
        or _suggest_payment_mode(payment_methods_list)
    )

    return ExtractionData(
        receipt_type=data.get('receipt_type', 'bon_fiscal'),
        receipt_number=data.get('receipt_number'),
        receipt_series=data.get('receipt_series'),
        receipt_date=_parse_receipt_date(data.get('receipt_date')),
        amount=_to_decimal(data.get('amount')),
        partner_name=data.get('partner_name'),
        cui=data.get('cui'),
        description=data.get('description'),
        tva_entries=tva_entries_schema,
        tva_total=_to_decimal(data.get('tva_total')),
        address=data.get('address'),
        items_count=data.get('items_count'),
        payment_methods=payment_methods_list,
        suggested_payment_mode=suggested_payment_mode,
        client_name=data.get('client_name'),
        client_cui=data.get('client_cui'),
        client_address=data.get('client_address'),
        confidence_amount=data.get('confidence_amount', 0.0),
        confidence_date=data.get('confidence_date', 0.0),
        confidence_vendor=data.get('confidence_vendor', 0.0),
        confidence_client=data.get('confidence_client', 0.0),
        overall_confidence=data.get('overall_confidence', 0.0),
        raw_text=data.get('raw_text', ''),
        ocr_engine=data.get('ocr_engine', ''),
        processing_time_ms=data.get('processing_time_ms', 0),
        needs_manual_review=data.get('needs_manual_review'),
        validation_warnings=data.get('validation_warnings') or [],
        validation_errors=data.get('validation_errors') or [],
        inter_ocr_ratios=data.get('inter_ocr_ratios') or {},
    )