Files
roa2web-service-auto/backend/modules/data_entry/routers/ocr.py
Marius Mutu 74f7aefc26 feat(ocr): Implement persistent worker pool with SQLite job queue
Major OCR infrastructure improvements:
- Add persistent SQLite-based job queue for OCR tasks
- Implement worker pool with process isolation and auto-restart
- Add OCR engine selector dropdown (Tesseract/PaddleOCR) in upload zone
- Optimize Tesseract preprocessing based on benchmark results (8x faster)
- Add recognize_cif_optimized() with multi-strategy CIF extraction
- Add Romanian CIF checksum validation
- Increase Telegram long polling timeout from 10s to 30s

Squashed commits:
- feat(ocr): Implement persistent worker pool with SQLite job queue
- feat(ocr): Add OCR engine selector dropdown to upload zone
- perf(telegram): Increase long polling timeout from 10s to 30s
- perf(ocr): Optimize Tesseract preprocessing based on benchmark results

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-31 12:32:12 +02:00

447 lines
15 KiB
Python

"""
OCR API endpoints with async job queue support.
Endpoints:
- POST /extract - Submit OCR job (returns job_id immediately)
- GET /jobs/{job_id} - Get job status and result
- GET /queue/status - Get queue statistics
- GET /status - Check OCR service availability
- POST /extract-attachment/{attachment_id} - Run OCR on an existing attachment (always synchronous)
For backwards compatibility, we also support sync mode via query param:
- POST /extract?sync=true - Process synchronously (blocks until complete)
"""
import os
import tempfile
from datetime import datetime
from decimal import Decimal
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, HTTPException, UploadFile, File, Depends, Query
from sqlalchemy.ext.asyncio import AsyncSession
from backend.modules.data_entry.db.database import get_session
from backend.modules.data_entry.db.crud.attachment import AttachmentCRUD
from backend.modules.data_entry.services.ocr_service import ocr_service
from backend.modules.data_entry.services.ocr_engine import OCREngine
from backend.modules.data_entry.services.ocr.job_queue import job_queue, OCRJobStatus as JobStatus
from backend.modules.data_entry.services.ocr.job_worker import estimate_wait_time
from backend.modules.data_entry.schemas.ocr import (
OCRResponse,
OCRStatusResponse,
ExtractionData,
TvaEntry,
PaymentMethod,
# New job queue schemas
OCREngineChoice,
OCRJobStatus,
OCRJobSubmitResponse,
OCRJobResponse,
OCRQueueStatusResponse,
)
# Auth integration
from shared.auth.dependencies import get_current_user
from shared.auth.models import CurrentUser
# Router on which every endpoint in this module is registered via the
# @router.get / @router.post decorators below.
router = APIRouter()
# ============================================================================
# OCR Job Queue Endpoints (NEW)
# ============================================================================
@router.post("/extract", response_model=OCRJobSubmitResponse)
async def submit_ocr_job(
    file: UploadFile = File(...),
    engine: OCREngineChoice = Query(default=OCREngineChoice.auto, description="OCR engine to use"),
    sync: bool = Query(default=False, description="If true, process synchronously (blocks)"),
    current_user: CurrentUser = Depends(get_current_user)
):
    """
    Submit an OCR job for processing.

    By default, returns immediately with a job_id. Poll GET /jobs/{job_id} for result.
    Use ?sync=true for synchronous processing (blocks until complete).
    This is for backwards compatibility but not recommended for production.

    Args:
        file: Image or PDF file (max 10MB). Only JPEG, PNG and PDF content
            types are accepted.
        engine: OCR engine choice (auto, paddleocr, tesseract)
        sync: If true, process synchronously (legacy mode)
        current_user: Authenticated user; the username is stored on the job.

    Returns:
        OCRJobSubmitResponse with job_id, queue_position, estimated_wait

    Raises:
        HTTPException: 400 for an unsupported content type or an oversized
            file; 500 when the job cannot be created in the queue.
    """
    allowed_types = ['image/jpeg', 'image/png', 'application/pdf']
    max_upload_bytes = 10 * 1024 * 1024  # 10MB limit

    if file.content_type not in allowed_types:
        raise HTTPException(
            status_code=400,
            detail=f"File type not supported: {file.content_type}. Allowed: JPG, PNG, PDF"
        )

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without pulling an arbitrarily large body into memory.
    content = await file.read(max_upload_bytes + 1)
    if len(content) > max_upload_bytes:
        raise HTTPException(
            status_code=400,
            detail="File too large. Maximum size is 10MB."
        )

    # Sync mode - use legacy processing (blocks)
    if sync:
        return await _process_sync(content, file, engine, current_user)

    # Async mode - create job and return immediately
    try:
        job = await job_queue.create_job(
            file_bytes=content,
            mime_type=file.content_type,
            engine=engine.value,
            username=current_user.username,
            original_filename=file.filename
        )

        # A falsy position (None or 0) is reported as 1, i.e. "next in line".
        queue_position = (await job_queue.get_queue_position(job.id)) or 1
        estimated_wait = estimate_wait_time(queue_position)

        return OCRJobSubmitResponse(
            job_id=job.id,
            status=OCRJobStatus.pending,
            queue_position=queue_position,
            estimated_wait_seconds=estimated_wait,
            created_at=job.created_at or datetime.utcnow()
        )
    except Exception as e:
        # Chain the cause so the original traceback is preserved in logs.
        raise HTTPException(
            status_code=500,
            detail=f"Failed to create OCR job: {str(e)}"
        ) from e
@router.get("/jobs/{job_id}", response_model=OCRJobResponse)
async def get_job_status(
    job_id: str,
    current_user: CurrentUser = Depends(get_current_user)
):
    """
    Report the current state of an OCR job, including its result once done.

    Intended to be polled by clients (recommended interval: 2 seconds).

    Args:
        job_id: Job UUID from POST /extract response
        current_user: Authenticated user (required by the auth dependency).

    Returns:
        OCRJobResponse with status, queue_position, and result (if completed)

    Raises:
        HTTPException: 404 when no job with this id exists.
    """
    job = await job_queue.get_job(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")

    # Position and wait estimate only apply to pending/processing jobs.
    position = None
    wait_seconds = None
    if job.status == JobStatus.pending:
        position = await job_queue.get_queue_position(job_id)
        wait_seconds = estimate_wait_time(position or 1)
    elif job.status == JobStatus.processing:
        position = 0
        # Rough estimate: assume half of the average processing time remains.
        average = await job_queue.get_average_processing_time()
        wait_seconds = int(average * 0.5)

    # The stored result dict is converted to the schema only for completed jobs.
    extraction = (
        _dict_to_extraction_data(job.result)
        if job.status == JobStatus.completed and job.result
        else None
    )

    response_fields = {
        "job_id": job.id,
        "status": OCRJobStatus(job.status.value),
        "queue_position": position,
        "estimated_wait_seconds": wait_seconds,
        "created_at": job.created_at or datetime.utcnow(),
        "started_at": job.started_at,
        "completed_at": job.completed_at,
        "processing_time_ms": job.processing_time_ms,
        "result": extraction,
        "error": job.error_message,
    }
    return OCRJobResponse(**response_fields)
@router.get("/queue/status", response_model=OCRQueueStatusResponse)
async def get_queue_status(
    current_user: CurrentUser = Depends(get_current_user)
):
    """
    Return a snapshot of the OCR queue.

    Returns:
        OCRQueueStatusResponse carrying the pending count, the processing
        count, and the average processing time in seconds, as reported by
        ``job_queue.get_queue_stats()``.
    """
    queue_stats = await job_queue.get_queue_stats()
    pending = queue_stats["pending"]
    processing = queue_stats["processing"]
    average_seconds = queue_stats["average_time_seconds"]
    return OCRQueueStatusResponse(
        pending_jobs=pending,
        processing_jobs=processing,
        average_time_seconds=average_seconds,
    )
# ============================================================================
# Legacy Endpoints (backwards compatibility)
# ============================================================================
@router.get("/status", response_model=OCRStatusResponse)
async def get_ocr_status():
    """
    Report whether any OCR engine is available.

    Returns:
        OCRStatusResponse listing the engines returned by
        ``OCREngine.get_available_engines()`` with a human-readable message.
    """
    engines = OCREngine.get_available_engines()

    # Guard clause: nothing installed, nothing to join.
    if not len(engines) > 0:
        return OCRStatusResponse(
            available=False,
            engines=engines,
            message="No OCR engines available. Install PaddleOCR or Tesseract.",
        )

    engine_names = ', '.join(engines)
    return OCRStatusResponse(
        available=True,
        engines=engines,
        message=f"OCR service ready with engines: {engine_names}",
    )
@router.post("/extract-attachment/{attachment_id}", response_model=OCRResponse)
async def extract_from_attachment(
    attachment_id: int,
    engine: OCREngineChoice = Query(default=OCREngineChoice.auto),
    session: AsyncSession = Depends(get_session),
    current_user: CurrentUser = Depends(get_current_user)
):
    """
    Run OCR on a file that was already uploaded as an attachment.

    Processing is always synchronous; the job queue is not involved.

    NOTE(review): ``engine`` is accepted but is not forwarded to
    ``ocr_service.process_image`` in this function.

    Args:
        attachment_id: Identifier of the stored attachment.
        engine: Requested OCR engine (currently unused here, see note).
        session: Database session used to load the attachment.
        current_user: Authenticated user (required by the auth dependency).

    Returns:
        OCRResponse with ``success=True``, the service message and the
        extracted data.

    Raises:
        HTTPException: 404 if the attachment or its file is missing,
            400 for an unsupported MIME type, 422 when OCR fails.
    """
    supported_mime_types = ['image/jpeg', 'image/png', 'application/pdf']

    attachment = await AttachmentCRUD.get_by_id(session, attachment_id)
    if not attachment:
        raise HTTPException(status_code=404, detail="Attachment not found")

    stored_file = AttachmentCRUD.get_file_path(attachment)
    if not stored_file.exists():
        raise HTTPException(status_code=404, detail="File not found on disk")

    # Check if file type is supported
    if attachment.mime_type not in supported_mime_types:
        raise HTTPException(
            status_code=400,
            detail=f"File type not supported for OCR: {attachment.mime_type}"
        )

    # TODO: Could use job queue here too, but keeping sync for now
    outcome = await ocr_service.process_image(stored_file, attachment.mime_type)
    ok, message, extraction = outcome
    if not ok:
        raise HTTPException(status_code=422, detail=message)

    return OCRResponse(
        success=True,
        message=message,
        data=_result_to_extraction_data(extraction),
    )
# ============================================================================
# Helper Functions
# ============================================================================
async def _process_sync(
    content: bytes,
    file: UploadFile,
    engine: OCREngineChoice,
    current_user: CurrentUser
) -> OCRJobSubmitResponse:
    """
    Process OCR synchronously (legacy mode).

    Writes the upload to a temporary file, runs ``ocr_service.process_image``
    on it and returns an ``OCRJobSubmitResponse`` already marked completed.
    No job is created in the queue; the returned ``job_id`` is a synthetic
    ``sync-<digest>`` label derived from the file content.

    NOTE(review): the extraction result is computed but not included in the
    response built here, so sync callers receive no extracted data. Confirm
    whether the response schema should carry it.
    NOTE(review): ``engine`` and ``current_user`` are accepted but not used
    by this helper.

    Args:
        content: Raw bytes of the uploaded file.
        file: The upload; only ``filename`` and ``content_type`` are read.
        engine: Requested OCR engine (unused, see note).
        current_user: Authenticated user (unused, see note).

    Returns:
        OCRJobSubmitResponse with status ``completed`` and zero queue
        position / wait time.

    Raises:
        HTTPException: 422 when the OCR service reports failure.
    """
    import hashlib  # local import: only this helper needs it

    # Keep a recognised extension on the temp file; default to .jpg otherwise.
    suffix = Path(file.filename).suffix.lower() if file.filename else '.jpg'
    if suffix not in ['.jpg', '.jpeg', '.png', '.pdf']:
        suffix = '.jpg'

    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    tmp_path = Path(tmp.name)
    try:
        # Write inside the try so a failed write still triggers cleanup, and
        # close the handle before the OCR service opens the path.
        with tmp:
            tmp.write(content)

        success, message, _result = await ocr_service.process_image(
            tmp_path, file.content_type
        )
        if not success:
            raise HTTPException(status_code=422, detail=message)

        # hash() of bytes is salted per process and may be negative; a SHA-256
        # prefix gives a stable, fixed-width identifier for the same content.
        content_digest = hashlib.sha256(content).hexdigest()[:16]

        return OCRJobSubmitResponse(
            job_id="sync-" + content_digest,
            status=OCRJobStatus.completed,
            queue_position=0,
            estimated_wait_seconds=0,
            created_at=datetime.utcnow()
        )
    finally:
        # Clean up temp file; missing_ok avoids an exists()/unlink() race.
        tmp_path.unlink(missing_ok=True)
def _result_to_extraction_data(result) -> ExtractionData:
    """
    Build an ExtractionData schema object from an OCR result object.

    TVA entries and payment methods are read as dicts from ``result`` and
    wrapped in their schema classes; all other fields are copied by name.
    ``suggested_payment_mode`` is 'banca' when any payment method is 'CARD',
    otherwise None.
    """
    # Convert tva_entries from dict to TvaEntry objects
    tva_entries_schema = []
    if result.tva_entries:
        for entry in result.tva_entries:
            tva_entries_schema.append(
                TvaEntry(
                    code=entry.get('code'),
                    percent=entry['percent'],
                    amount=entry['amount'],
                )
            )

    # Convert payment_methods from dict to PaymentMethod objects
    payment_methods_list = []
    if result.payment_methods:
        for payment in result.payment_methods:
            payment_methods_list.append(
                PaymentMethod(
                    method=payment['method'],
                    amount=Decimal(str(payment['amount'])),
                )
            )

    # Auto-suggest payment_mode: a card payment implies a bank transaction.
    paid_by_card = any(pm.method == 'CARD' for pm in payment_methods_list)
    suggested_payment_mode = 'banca' if paid_by_card else None

    # Fields copied verbatim from the result object.
    passthrough_fields = (
        'receipt_type',
        'receipt_number',
        'receipt_series',
        'receipt_date',
        'amount',
        'partner_name',
        'cui',
        'description',
        'tva_total',
        'address',
        'items_count',
        'client_name',
        'client_cui',
        'client_address',
        'confidence_amount',
        'confidence_date',
        'confidence_vendor',
        'overall_confidence',
        'raw_text',
        'ocr_engine',
        'processing_time_ms',
        'needs_manual_review',
        'validation_warnings',
        'validation_errors',
        'inter_ocr_ratios',
    )
    fields = {name: getattr(result, name) for name in passthrough_fields}
    fields.update(
        tva_entries=tva_entries_schema,
        payment_methods=payment_methods_list,
        suggested_payment_mode=suggested_payment_mode,
        # The result object may not define confidence_client; default to 0.0.
        confidence_client=getattr(result, 'confidence_client', 0.0),
    )
    return ExtractionData(**fields)
def _dict_to_extraction_data(data: dict) -> ExtractionData:
    """
    Build an ExtractionData schema object from a job-queue result dict.

    ISO date strings are parsed into ``date`` (unparseable strings become
    None), monetary values are converted to ``Decimal`` via their string
    form, and non-dict items in the TVA / payment lists are skipped.
    Missing keys fall back to the defaults spelled out below.
    """
    from datetime import date

    # Parse date if string
    parsed_date = data.get('receipt_date')
    if isinstance(parsed_date, str):
        try:
            parsed_date = date.fromisoformat(parsed_date)
        except (ValueError, TypeError):
            parsed_date = None

    # Convert tva_entries (a missing or None list is treated as empty)
    vat_rows = [
        TvaEntry(
            code=row.get('code'),
            percent=row.get('percent', 0),
            amount=Decimal(str(row.get('amount', 0))),
        )
        for row in (data.get('tva_entries', []) or [])
        if isinstance(row, dict)
    ]

    # Convert payment_methods (a missing or None list is treated as empty)
    payments = [
        PaymentMethod(
            method=row.get('method', 'NUMERAR'),
            amount=Decimal(str(row.get('amount', 0))),
        )
        for row in (data.get('payment_methods', []) or [])
        if isinstance(row, dict)
    ]

    # Convert amount and tva_total to Decimal, leaving None untouched
    raw_amount = data.get('amount')
    total_amount = None if raw_amount is None else Decimal(str(raw_amount))
    raw_vat_total = data.get('tva_total')
    vat_total = None if raw_vat_total is None else Decimal(str(raw_vat_total))

    return ExtractionData(
        receipt_type=data.get('receipt_type', 'bon_fiscal'),
        receipt_number=data.get('receipt_number'),
        receipt_series=data.get('receipt_series'),
        receipt_date=parsed_date,
        amount=total_amount,
        partner_name=data.get('partner_name'),
        cui=data.get('cui'),
        description=data.get('description'),
        tva_entries=vat_rows,
        tva_total=vat_total,
        address=data.get('address'),
        items_count=data.get('items_count'),
        payment_methods=payments,
        suggested_payment_mode=data.get('suggested_payment_mode'),
        client_name=data.get('client_name'),
        client_cui=data.get('client_cui'),
        client_address=data.get('client_address'),
        confidence_amount=data.get('confidence_amount', 0.0),
        confidence_date=data.get('confidence_date', 0.0),
        confidence_vendor=data.get('confidence_vendor', 0.0),
        confidence_client=data.get('confidence_client', 0.0),
        overall_confidence=data.get('overall_confidence', 0.0),
        raw_text=data.get('raw_text', ''),
        ocr_engine=data.get('ocr_engine', ''),
        processing_time_ms=data.get('processing_time_ms', 0),
        needs_manual_review=data.get('needs_manual_review'),
        validation_warnings=data.get('validation_warnings', []),
        validation_errors=data.get('validation_errors', []),
        inter_ocr_ratios=data.get('inter_ocr_ratios', {}),
    )