fix telegram

This commit is contained in:
Claude Agent
2026-02-23 15:12:33 +00:00
parent 6c78fec8a7
commit 8bc567a9c5
426 changed files with 112478 additions and 1 deletions

View File

@@ -0,0 +1,42 @@
"""
OCR Services Module
Provides persistent OCR worker pool with job queue for efficient processing.
Components:
- ocr_worker_pool: Manages ProcessPoolExecutor with persistent PaddleOCR
- job_queue: SQLite-based job queue for async processing
- job_worker: Background task that processes queued jobs
- tesseract_engine: Optimized Tesseract with multi-PSM and polarity fix
Architecture:
FastAPI → job_queue.create_job() → SQLite
job_worker loop → ocr_worker_pool.submit_task() → Worker Process
PaddleOCR/Tesseract
"""
from .ocr_worker_pool import ocr_worker_pool, OCRWorkerPool
from .job_queue import job_queue, OCRJobQueue, OCRJob, OCRJobStatus
from .job_worker import start_job_worker, stop_job_worker
from .tesseract_engine import TesseractEngine
from .validation import OCRValidationEngine
__all__ = [
# Worker pool
"ocr_worker_pool",
"OCRWorkerPool",
# Job queue
"job_queue",
"OCRJobQueue",
"OCRJob",
"OCRJobStatus",
# Job worker
"start_job_worker",
"stop_job_worker",
# Engines
"TesseractEngine",
# Validation
"OCRValidationEngine",
]

View File

@@ -0,0 +1,653 @@
"""
SQLite Job Queue Manager for OCR Processing
Provides async job queue for OCR requests:
- Jobs are stored in SQLite for persistence
- Queue position and time estimation
- Automatic expiration after 24 hours
- Statistics for monitoring
Schema:
ocr_jobs (
id TEXT PRIMARY KEY, -- UUID
status TEXT NOT NULL, -- pending, processing, completed, failed
file_path TEXT NOT NULL, -- Path to uploaded file
mime_type TEXT NOT NULL,
engine TEXT DEFAULT 'doctr_plus',
created_at TIMESTAMP,
started_at TIMESTAMP,
completed_at TIMESTAMP,
result_json TEXT, -- JSON extraction result
error_message TEXT,
processing_time_ms INTEGER, -- Total job time (started_at to completed_at)
ocr_time_ms INTEGER, -- Actual OCR engine processing time
created_by TEXT, -- Username
original_filename TEXT,
expires_at TIMESTAMP,
batch_id INTEGER, -- Foreign key to batch_uploads (for bulk processing)
file_hash TEXT -- SHA-256 hash for duplicate detection (US-007)
)
"""
import asyncio
import json
from decimal import Decimal
class DecimalEncoder(json.JSONEncoder):
"""JSON encoder that handles Decimal types."""
def default(self, obj):
if isinstance(obj, Decimal):
return float(obj)
return super().default(obj)
import logging
import os
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional
import aiosqlite
logger = logging.getLogger(__name__)
# Default paths
DEFAULT_QUEUE_DIR = Path(__file__).parent.parent.parent.parent.parent / "data" / "ocr_queue"
DEFAULT_DB_PATH = DEFAULT_QUEUE_DIR / "ocr_jobs.db"
DEFAULT_FILES_DIR = DEFAULT_QUEUE_DIR / "files"
# Job expiration
JOB_EXPIRY_HOURS = 24
# SQLite busy timeout (milliseconds) - prevents "database is locked" errors
SQLITE_BUSY_TIMEOUT_MS = 5000
class OCRJobStatus(str, Enum):
"""Job status enum."""
pending = "pending"
processing = "processing"
completed = "completed"
failed = "failed"
cancelled = "cancelled"
@dataclass
class OCRJob:
"""OCR Job data class."""
id: str
status: OCRJobStatus
file_path: str
mime_type: str
engine: str = "doctr_plus"
created_at: Optional[datetime] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
result_json: Optional[str] = None
error_message: Optional[str] = None
processing_time_ms: Optional[int] = None # Total job time (started_at to completed_at)
ocr_time_ms: Optional[int] = None # Actual OCR engine processing time
created_by: Optional[str] = None
original_filename: Optional[str] = None
expires_at: Optional[datetime] = None
batch_id: Optional[int] = None # Links to batch_uploads table for bulk processing
file_hash: Optional[str] = None # SHA-256 hash for duplicate detection (US-007)
@property
def queue_wait_ms(self) -> Optional[int]:
"""Calculate queue wait time (created_at to started_at)."""
if self.created_at and self.started_at:
delta = self.started_at - self.created_at
return int(delta.total_seconds() * 1000)
return None
@property
def result(self) -> Optional[Dict]:
"""Parse result_json to dict."""
if self.result_json:
try:
return json.loads(self.result_json)
except json.JSONDecodeError:
return None
return None
class OCRJobQueue:
"""
SQLite-based job queue for OCR processing.
Provides async methods for job management with position
tracking and time estimation.
"""
def __init__(
self,
db_path: Optional[Path] = None,
files_dir: Optional[Path] = None
):
"""
Initialize job queue.
Args:
db_path: Path to SQLite database (default: data/ocr_queue/ocr_jobs.db)
files_dir: Path to files directory (default: data/ocr_queue/files/)
"""
self.db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
self.files_dir = Path(files_dir) if files_dir else DEFAULT_FILES_DIR
self._lock = asyncio.Lock()
self._initialized = False
async def initialize(self) -> None:
"""
Initialize database and directories.
Creates SQLite database and tables if they don't exist.
Creates files directory for uploaded files.
"""
if self._initialized:
return
# Create directories
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self.files_dir.mkdir(parents=True, exist_ok=True)
# Create database and tables
async with aiosqlite.connect(str(self.db_path)) as db:
# Enable WAL mode for better concurrency and set busy timeout
await db.execute("PRAGMA journal_mode=WAL")
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
await db.execute('''
CREATE TABLE IF NOT EXISTS ocr_jobs (
id TEXT PRIMARY KEY,
status TEXT NOT NULL DEFAULT 'pending',
file_path TEXT NOT NULL,
mime_type TEXT NOT NULL,
engine TEXT DEFAULT 'doctr_plus',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
started_at TIMESTAMP,
completed_at TIMESTAMP,
result_json TEXT,
error_message TEXT,
processing_time_ms INTEGER,
ocr_time_ms INTEGER,
created_by TEXT,
original_filename TEXT,
expires_at TIMESTAMP,
batch_id INTEGER
)
''')
# Migration: add ocr_time_ms column if it doesn't exist
try:
await db.execute('ALTER TABLE ocr_jobs ADD COLUMN ocr_time_ms INTEGER')
logger.info("[OCRJobQueue] Added ocr_time_ms column to existing table")
except Exception:
pass # Column already exists
# Migration: add batch_id column if it doesn't exist
try:
await db.execute('ALTER TABLE ocr_jobs ADD COLUMN batch_id INTEGER')
logger.info("[OCRJobQueue] Added batch_id column to existing table")
except Exception:
pass # Column already exists
# Migration: add file_hash column if it doesn't exist (US-007)
try:
await db.execute('ALTER TABLE ocr_jobs ADD COLUMN file_hash TEXT')
logger.info("[OCRJobQueue] Added file_hash column to existing table")
except Exception:
pass # Column already exists
# Index for efficient queue queries
await db.execute('''
CREATE INDEX IF NOT EXISTS idx_ocr_jobs_status
ON ocr_jobs(status, created_at)
''')
# Index for expiration cleanup
await db.execute('''
CREATE INDEX IF NOT EXISTS idx_ocr_jobs_expires
ON ocr_jobs(expires_at)
''')
await db.commit()
self._initialized = True
logger.info(f"[OCRJobQueue] Initialized: db={self.db_path}, files={self.files_dir}")
async def create_job(
self,
file_bytes: bytes,
mime_type: str,
engine: str = "doctr_plus",
username: Optional[str] = None,
original_filename: Optional[str] = None,
batch_id: Optional[int] = None,
file_hash: Optional[str] = None
) -> OCRJob:
"""
Create a new OCR job.
Saves file to disk and creates database record.
Args:
file_bytes: Raw file bytes
mime_type: MIME type of file
engine: OCR engine ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
username: Username of requester
original_filename: Original filename from upload
batch_id: Optional batch ID for bulk upload processing
file_hash: Optional SHA-256 hash for duplicate detection (US-007)
Returns:
Created OCRJob instance
"""
await self.initialize()
# Generate job ID
job_id = str(uuid.uuid4())
# Determine file extension
ext_map = {
'image/jpeg': '.jpg',
'image/png': '.png',
'application/pdf': '.pdf',
}
ext = ext_map.get(mime_type, '.bin')
# Save file
file_path = self.files_dir / f"{job_id}{ext}"
with open(file_path, 'wb') as f:
f.write(file_bytes)
# Calculate expiration
now = datetime.utcnow()
expires_at = now + timedelta(hours=JOB_EXPIRY_HOURS)
# Insert job record
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
await db.execute('''
INSERT INTO ocr_jobs (
id, status, file_path, mime_type, engine,
created_at, created_by, original_filename, expires_at, batch_id, file_hash
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
''', (
job_id, OCRJobStatus.pending.value, str(file_path), mime_type, engine,
now.isoformat(), username, original_filename, expires_at.isoformat(), batch_id, file_hash
))
await db.commit()
logger.info(f"[OCRJobQueue] Created job {job_id}: engine={engine}, file={file_path.name}, batch_id={batch_id}")
return OCRJob(
id=job_id,
status=OCRJobStatus.pending,
file_path=str(file_path),
mime_type=mime_type,
engine=engine,
created_at=now,
created_by=username,
original_filename=original_filename,
expires_at=expires_at,
batch_id=batch_id,
file_hash=file_hash
)
async def get_job(self, job_id: str) -> Optional[OCRJob]:
"""
Get job by ID.
Args:
job_id: Job UUID
Returns:
OCRJob or None if not found
"""
await self.initialize()
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
db.row_factory = aiosqlite.Row
async with db.execute(
'SELECT * FROM ocr_jobs WHERE id = ?',
(job_id,)
) as cursor:
row = await cursor.fetchone()
if row:
return self._row_to_job(row)
return None
async def get_queue_position(self, job_id: str) -> Optional[int]:
"""
Get position in queue for a pending job.
Args:
job_id: Job UUID
Returns:
Queue position (1 = next to process) or None if not pending
"""
await self.initialize()
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
# Check if job is pending
async with db.execute(
'SELECT status, created_at FROM ocr_jobs WHERE id = ?',
(job_id,)
) as cursor:
row = await cursor.fetchone()
if not row or row[0] != OCRJobStatus.pending.value:
return None
job_created_at = row[1]
# Count jobs ahead in queue (created before this job)
async with db.execute('''
SELECT COUNT(*) FROM ocr_jobs
WHERE status = 'pending' AND created_at < ?
''', (job_created_at,)) as cursor:
count = await cursor.fetchone()
return (count[0] + 1) if count else 1
async def get_next_pending(self) -> Optional[OCRJob]:
"""
Get the next pending job (oldest first) and atomically mark it as processing.
This prevents race conditions in parallel processing - only one worker
can claim each job.
Returns:
Next OCRJob to process or None if queue empty
"""
await self.initialize()
now = datetime.utcnow()
async with self._lock: # Serialize access to prevent race conditions
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
db.row_factory = aiosqlite.Row
# Get the next pending job
async with db.execute('''
SELECT * FROM ocr_jobs
WHERE status = 'pending'
ORDER BY created_at ASC
LIMIT 1
''') as cursor:
row = await cursor.fetchone()
if not row:
return None
job_id = row['id']
# Atomically mark as processing
await db.execute('''
UPDATE ocr_jobs
SET status = 'processing', started_at = ?
WHERE id = ? AND status = 'pending'
''', (now.isoformat(), job_id))
await db.commit()
# Fetch the updated job
async with db.execute(
'SELECT * FROM ocr_jobs WHERE id = ?',
(job_id,)
) as cursor:
updated_row = await cursor.fetchone()
if updated_row:
return self._row_to_job(updated_row)
return None
async def update_status(
self,
job_id: str,
status: OCRJobStatus,
result: Optional[Dict] = None,
error: Optional[str] = None,
processing_time_ms: Optional[int] = None,
ocr_time_ms: Optional[int] = None
) -> bool:
"""
Update job status.
Args:
job_id: Job UUID
status: New status
result: Extraction result dict (for completed)
error: Error message (for failed)
processing_time_ms: Total job processing time (started_at to completed_at)
ocr_time_ms: Actual OCR engine processing time
Returns:
True if update successful
"""
await self.initialize()
now = datetime.utcnow()
result_json = json.dumps(result, cls=DecimalEncoder) if result else None
# Build update query based on status
if status == OCRJobStatus.processing:
query = '''
UPDATE ocr_jobs
SET status = ?, started_at = ?
WHERE id = ?
'''
params = (status.value, now.isoformat(), job_id)
elif status == OCRJobStatus.completed:
query = '''
UPDATE ocr_jobs
SET status = ?, completed_at = ?, result_json = ?, processing_time_ms = ?, ocr_time_ms = ?
WHERE id = ?
'''
params = (status.value, now.isoformat(), result_json, processing_time_ms, ocr_time_ms, job_id)
elif status == OCRJobStatus.failed:
query = '''
UPDATE ocr_jobs
SET status = ?, completed_at = ?, error_message = ?, processing_time_ms = ?, ocr_time_ms = ?
WHERE id = ?
'''
params = (status.value, now.isoformat(), error, processing_time_ms, ocr_time_ms, job_id)
else:
query = 'UPDATE ocr_jobs SET status = ? WHERE id = ?'
params = (status.value, job_id)
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
cursor = await db.execute(query, params)
await db.commit()
return cursor.rowcount > 0
async def get_average_processing_time(self) -> float:
"""
Calculate average processing time from recent completed jobs.
Uses last 50 completed jobs for accuracy.
Returns:
Average time in seconds (default 7.0 if no data)
"""
await self.initialize()
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
async with db.execute('''
SELECT AVG(processing_time_ms)
FROM (
SELECT processing_time_ms FROM ocr_jobs
WHERE status = 'completed' AND processing_time_ms IS NOT NULL
ORDER BY completed_at DESC
LIMIT 50
)
''') as cursor:
row = await cursor.fetchone()
if row and row[0]:
return row[0] / 1000.0 # Convert ms to seconds
return 7.0 # Default estimate
async def count_pending(self) -> int:
"""Count pending jobs in queue."""
await self.initialize()
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
async with db.execute(
'SELECT COUNT(*) FROM ocr_jobs WHERE status = ?',
(OCRJobStatus.pending.value,)
) as cursor:
row = await cursor.fetchone()
return row[0] if row else 0
async def count_processing(self) -> int:
"""Count currently processing jobs."""
await self.initialize()
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
async with db.execute(
'SELECT COUNT(*) FROM ocr_jobs WHERE status = ?',
(OCRJobStatus.processing.value,)
) as cursor:
row = await cursor.fetchone()
return row[0] if row else 0
async def cleanup_expired(self) -> int:
"""
Delete expired jobs and their files.
Returns:
Number of jobs deleted
"""
await self.initialize()
now = datetime.utcnow()
deleted = 0
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
db.row_factory = aiosqlite.Row
# Get expired jobs
async with db.execute('''
SELECT id, file_path FROM ocr_jobs
WHERE expires_at < ?
''', (now.isoformat(),)) as cursor:
rows = await cursor.fetchall()
for row in rows:
# Delete file
file_path = Path(row['file_path'])
if file_path.exists():
try:
file_path.unlink()
except Exception as e:
logger.warning(f"[OCRJobQueue] Failed to delete file {file_path}: {e}")
# Delete job record
await db.execute('DELETE FROM ocr_jobs WHERE id = ?', (row['id'],))
deleted += 1
await db.commit()
if deleted > 0:
logger.info(f"[OCRJobQueue] Cleaned up {deleted} expired job(s)")
return deleted
async def cleanup_job_file(self, job_id: str) -> bool:
"""
Delete the file associated with a job.
Called after processing to free disk space.
Args:
job_id: Job UUID
Returns:
True if file deleted
"""
job = await self.get_job(job_id)
if job:
file_path = Path(job.file_path)
if file_path.exists():
try:
file_path.unlink()
return True
except Exception as e:
logger.warning(f"[OCRJobQueue] Failed to delete file {file_path}: {e}")
return False
async def get_queue_stats(self) -> Dict[str, Any]:
"""
Get queue statistics.
Returns:
Dict with pending, processing, completed, failed counts
"""
await self.initialize()
stats = {
"pending": 0,
"processing": 0,
"completed": 0,
"failed": 0,
"average_time_seconds": 0.0,
}
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
async with db.execute('''
SELECT status, COUNT(*) as count
FROM ocr_jobs
GROUP BY status
''') as cursor:
rows = await cursor.fetchall()
for row in rows:
if row[0] in stats:
stats[row[0]] = row[1]
stats["average_time_seconds"] = await self.get_average_processing_time()
return stats
def _row_to_job(self, row: aiosqlite.Row) -> OCRJob:
"""Convert database row to OCRJob."""
def parse_datetime(val):
if val:
try:
return datetime.fromisoformat(val)
except (ValueError, TypeError):
return None
return None
return OCRJob(
id=row['id'],
status=OCRJobStatus(row['status']),
file_path=row['file_path'],
mime_type=row['mime_type'],
engine=row['engine'] or 'doctr_plus',
created_at=parse_datetime(row['created_at']),
started_at=parse_datetime(row['started_at']),
completed_at=parse_datetime(row['completed_at']),
result_json=row['result_json'],
error_message=row['error_message'],
processing_time_ms=row['processing_time_ms'],
ocr_time_ms=row['ocr_time_ms'] if 'ocr_time_ms' in row.keys() else None,
created_by=row['created_by'],
original_filename=row['original_filename'],
expires_at=parse_datetime(row['expires_at']),
batch_id=row['batch_id'] if 'batch_id' in row.keys() else None,
file_hash=row['file_hash'] if 'file_hash' in row.keys() else None,
)
# Singleton instance
job_queue = OCRJobQueue()

View File

@@ -0,0 +1,665 @@
"""
OCR Job Worker - Background Task for Queue Processing
Runs as an asyncio background task in FastAPI.
Continuously polls the job queue and processes OCR requests IN PARALLEL.
Architecture:
FastAPI startup
start_job_worker()
asyncio.create_task(_job_worker_loop())
while True:
# Process up to OCR_WORKERS jobs concurrently
jobs = get_pending_jobs(limit=available_slots)
for job in jobs:
asyncio.create_task(_process_job(job))
await asyncio.sleep(0.1)
"""
import asyncio
import logging
import os
import time
from pathlib import Path
from typing import Optional, Set
from .job_queue import job_queue, OCRJobStatus, OCRJob
from .ocr_worker_pool import ocr_worker_pool
from backend.modules.data_entry.schemas.ocr import ExtractionData
logger = logging.getLogger(__name__)
# Global task reference
_job_worker_task: Optional[asyncio.Task] = None
_cleanup_task: Optional[asyncio.Task] = None
_shutdown_event: Optional[asyncio.Event] = None
_active_tasks: Set[asyncio.Task] = set() # Track active job tasks
_concurrency_semaphore: Optional[asyncio.Semaphore] = None # Limit concurrent jobs
# Configuration
POLL_INTERVAL_SECONDS = 0.1 # How often to check for new jobs (faster for parallel)
CLEANUP_INTERVAL_SECONDS = 3600 # Clean expired jobs every hour
OCR_TIMEOUT_SECONDS = 120 # Max time for OCR processing
async def _job_worker_loop() -> None:
"""
Main worker loop - processes jobs from queue IN PARALLEL.
Runs continuously until shutdown. Uses semaphore to limit
concurrent jobs to OCR_WORKERS count. Launches jobs as
background tasks without waiting for completion.
"""
global _shutdown_event, _active_tasks, _concurrency_semaphore
# Get max concurrent jobs from env (matches worker pool size)
max_concurrent = int(os.getenv('OCR_WORKERS', '2'))
_concurrency_semaphore = asyncio.Semaphore(max_concurrent)
_active_tasks = set()
logger.info(f"[JobWorker] Starting PARALLEL worker loop (max_concurrent={max_concurrent})...")
_shutdown_event = asyncio.Event()
consecutive_errors = 0
max_consecutive_errors = 10
while not _shutdown_event.is_set():
try:
# Clean up completed tasks
done_tasks = {t for t in _active_tasks if t.done()}
for task in done_tasks:
_active_tasks.discard(task)
# Check for exceptions
try:
task.result()
except Exception as e:
logger.error(f"[JobWorker] Task failed: {e}")
# Check if we have capacity for more jobs
active_count = len(_active_tasks)
available_slots = max_concurrent - active_count
if available_slots > 0:
# Get next pending job
job = await job_queue.get_next_pending()
if job:
consecutive_errors = 0
# Launch job processing as background task
task = asyncio.create_task(_process_job_with_semaphore(job))
_active_tasks.add(task)
logger.debug(f"[JobWorker] Launched job {job.id} (active={len(_active_tasks)}/{max_concurrent})")
else:
# No pending jobs - wait briefly
try:
await asyncio.wait_for(
_shutdown_event.wait(),
timeout=POLL_INTERVAL_SECONDS
)
if _shutdown_event.is_set():
break
except asyncio.TimeoutError:
pass
else:
# At capacity - wait for a slot to free up
await asyncio.sleep(POLL_INTERVAL_SECONDS)
except asyncio.CancelledError:
logger.info("[JobWorker] Worker loop cancelled")
break
except Exception as e:
consecutive_errors += 1
logger.error(f"[JobWorker] Error in worker loop ({consecutive_errors}/{max_consecutive_errors}): {e}")
if consecutive_errors >= max_consecutive_errors:
logger.error("[JobWorker] Too many consecutive errors, stopping worker")
break
await asyncio.sleep(min(consecutive_errors * 2, 30))
# Wait for active tasks to complete on shutdown
if _active_tasks:
logger.info(f"[JobWorker] Waiting for {len(_active_tasks)} active tasks to complete...")
await asyncio.gather(*_active_tasks, return_exceptions=True)
logger.info("[JobWorker] Worker loop stopped")
async def _process_job_with_semaphore(job: OCRJob) -> None:
"""
Process job with semaphore to limit concurrency.
Acquires semaphore before processing, releases after.
This ensures we don't exceed OCR_WORKERS concurrent jobs.
"""
global _concurrency_semaphore
async with _concurrency_semaphore:
await _process_job(job)
async def _process_job(job: OCRJob) -> None:
"""
Process a single OCR job.
Reads file, submits to worker pool, updates job status,
and saves metrics for analytics.
Args:
job: OCRJob to process
"""
logger.info(f"[JobWorker] Processing job {job.id}: engine={job.engine}, file={Path(job.file_path).name}")
start_time = time.time()
file_size = 0
file_type = "image/jpeg"
try:
# Note: Job already marked as 'processing' atomically in get_next_pending()
# Read file bytes
file_path = Path(job.file_path)
if not file_path.exists():
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, 'rb') as f:
file_bytes = f.read()
file_size = len(file_bytes)
# Determine file type from job or extension
file_type = getattr(job, 'mime_type', 'image/jpeg') or 'image/jpeg'
# Submit to worker pool
result = await ocr_worker_pool.submit_task(
image_bytes=file_bytes,
engine=job.engine,
preprocessing="auto",
timeout=OCR_TIMEOUT_SECONDS
)
elapsed_ms = int((time.time() - start_time) * 1000)
if result.get("success"):
# Job completed successfully
extraction = result.get("extraction", {})
# Include raw_texts for analysis (from all OCR engine passes)
extraction['raw_texts'] = result.get("raw_texts", [])
# Extract actual OCR processing time from extraction result
ocr_time_ms = extraction.get('processing_time_ms', 0)
# Debug: log suggested_payment_mode
spm = extraction.get('suggested_payment_mode')
logger.info(f"[JobWorker] Job {job.id} extraction has suggested_payment_mode={spm}")
await job_queue.update_status(
job_id=job.id,
status=OCRJobStatus.completed,
result=extraction,
processing_time_ms=elapsed_ms,
ocr_time_ms=ocr_time_ms
)
logger.info(f"[JobWorker] Job {job.id} completed in {elapsed_ms}ms (ocr: {ocr_time_ms}ms)")
# Save metrics for successful job
await _save_job_metrics(
job_id=job.id,
username=job.created_by or 'unknown',
engine_requested=job.engine,
engine_used=extraction.get('ocr_engine', job.engine),
processing_time_ms=elapsed_ms,
file_size_bytes=file_size,
file_type=file_type,
original_filename=job.original_filename,
success=True,
overall_confidence=extraction.get('overall_confidence', 0.0),
fields_extracted=_count_extracted_fields(extraction),
needs_manual_review=extraction.get('needs_manual_review'),
validation_warnings_count=len(extraction.get('validation_warnings', [])),
validation_errors_count=len(extraction.get('validation_errors', [])),
)
# Auto-save receipt for batch jobs
if job.batch_id:
auto_save_result = await _auto_save_batch_receipt(
job=job,
extraction=extraction,
file_path=str(file_path)
)
if not auto_save_result:
# Auto-save failed - mark job as failed
# Note: job_queue status already updated to 'completed' above
# We need to update it back to failed with the auto-save error
logger.warning(
f"[JobWorker] Job {job.id} OCR succeeded but auto-save failed"
)
else:
# Job failed
error_msg = result.get("error", "Unknown error")
await job_queue.update_status(
job_id=job.id,
status=OCRJobStatus.failed,
error=error_msg,
processing_time_ms=elapsed_ms
)
logger.warning(f"[JobWorker] Job {job.id} failed after {elapsed_ms}ms: {error_msg}")
# Save metrics for failed job
await _save_job_metrics(
job_id=job.id,
username=job.created_by or 'unknown',
engine_requested=job.engine,
engine_used=job.engine,
processing_time_ms=elapsed_ms,
file_size_bytes=file_size,
file_type=file_type,
original_filename=job.original_filename,
success=False,
error_message=error_msg,
)
except Exception as e:
elapsed_ms = int((time.time() - start_time) * 1000)
logger.error(f"[JobWorker] Job {job.id} error after {elapsed_ms}ms: {e}")
await job_queue.update_status(
job_id=job.id,
status=OCRJobStatus.failed,
error=str(e),
processing_time_ms=elapsed_ms
)
# Save metrics for error job
await _save_job_metrics(
job_id=job.id,
username=job.created_by or 'unknown',
engine_requested=job.engine,
engine_used=job.engine,
processing_time_ms=elapsed_ms,
file_size_bytes=file_size,
file_type=file_type,
original_filename=job.original_filename,
success=False,
error_message=str(e),
)
finally:
# Cleanup file after processing
try:
await job_queue.cleanup_job_file(job.id)
except Exception as e:
logger.warning(f"[JobWorker] Failed to cleanup file for job {job.id}: {e}")
async def _cleanup_loop() -> None:
"""
Periodic cleanup of expired jobs.
Runs every hour to delete jobs older than 24 hours.
"""
global _shutdown_event
logger.info("[JobWorker] Starting cleanup loop...")
while not _shutdown_event.is_set():
try:
# Wait for interval or shutdown
try:
await asyncio.wait_for(
_shutdown_event.wait(),
timeout=CLEANUP_INTERVAL_SECONDS
)
if _shutdown_event.is_set():
break
except asyncio.TimeoutError:
pass # Normal timeout, do cleanup
# Run cleanup
deleted = await job_queue.cleanup_expired()
if deleted > 0:
logger.info(f"[JobWorker] Cleanup: deleted {deleted} expired jobs")
except asyncio.CancelledError:
logger.info("[JobWorker] Cleanup loop cancelled")
break
except Exception as e:
logger.error(f"[JobWorker] Cleanup error: {e}")
await asyncio.sleep(60) # Retry after 1 minute
logger.info("[JobWorker] Cleanup loop stopped")
async def start_job_worker() -> bool:
"""
Start the job worker background task.
Called at FastAPI startup to begin processing queue.
Returns:
True if started successfully
"""
global _job_worker_task, _cleanup_task, _shutdown_event
if _job_worker_task is not None and not _job_worker_task.done():
logger.warning("[JobWorker] Already running")
return True
try:
# Initialize job queue
await job_queue.initialize()
# Initialize worker pool
if not ocr_worker_pool.initialize():
logger.error("[JobWorker] Failed to initialize worker pool")
return False
# Pre-warm worker pool in BACKGROUND (don't block startup)
# First OCR request may be slower if prewarm isn't done yet
async def _background_prewarm():
logger.info("[JobWorker] Pre-warming OCR worker pool (background)...")
warmup_success = await ocr_worker_pool.prewarm(timeout=90.0)
if warmup_success:
logger.info("[JobWorker] OCR worker pool pre-warmed successfully")
else:
logger.warning("[JobWorker] Worker pool pre-warm failed, first request will be slower")
asyncio.create_task(_background_prewarm())
# Start worker loop
_shutdown_event = asyncio.Event()
_job_worker_task = asyncio.create_task(_job_worker_loop())
# Start cleanup loop
_cleanup_task = asyncio.create_task(_cleanup_loop())
logger.info("[JobWorker] Started successfully")
return True
except Exception as e:
logger.error(f"[JobWorker] Failed to start: {e}")
return False
async def stop_job_worker() -> None:
"""
Stop the job worker background task.
Called at FastAPI shutdown to gracefully stop processing.
"""
global _job_worker_task, _cleanup_task, _shutdown_event
logger.info("[JobWorker] Stopping...")
# Signal shutdown
if _shutdown_event:
_shutdown_event.set()
# Cancel worker task
if _job_worker_task and not _job_worker_task.done():
_job_worker_task.cancel()
try:
await _job_worker_task
except asyncio.CancelledError:
pass
# Cancel cleanup task
if _cleanup_task and not _cleanup_task.done():
_cleanup_task.cancel()
try:
await _cleanup_task
except asyncio.CancelledError:
pass
# Shutdown worker pool
ocr_worker_pool.shutdown(wait=True)
_job_worker_task = None
_cleanup_task = None
_shutdown_event = None
logger.info("[JobWorker] Stopped")
def is_running() -> bool:
"""Check if job worker is running."""
return _job_worker_task is not None and not _job_worker_task.done()
def estimate_wait_time(queue_position: int) -> int:
"""
Estimate wait time for a job in queue.
Args:
queue_position: Position in queue (1 = next)
Returns:
Estimated wait time in seconds
"""
if queue_position <= 0:
return 0
# Get average processing time (synchronous fallback)
# Default ~7 seconds per job if no data
avg_time = 7.0
try:
# Try to get from queue stats
import asyncio
loop = asyncio.get_event_loop()
if loop.is_running():
# Can't use sync call in async context, use default
pass
else:
avg_time = loop.run_until_complete(job_queue.get_average_processing_time())
except Exception:
pass
# Estimate: position * average_time
return int(queue_position * avg_time)
# ============================================================================
# Metrics Helper Functions
# ============================================================================
async def _save_job_metrics(
job_id: str,
username: str,
engine_requested: str,
engine_used: str,
processing_time_ms: int = 0,
file_size_bytes: int = 0,
file_type: str = "image/jpeg",
original_filename: Optional[str] = None,
success: bool = True,
error_message: Optional[str] = None,
overall_confidence: float = 0.0,
fields_extracted: int = 0,
needs_manual_review: Optional[bool] = None,
validation_warnings_count: int = 0,
validation_errors_count: int = 0,
) -> None:
"""
Save OCR job metrics to database for analytics.
Called after each job completes (success or failure).
Errors are logged but don't affect job processing.
"""
try:
from backend.modules.data_entry.db.database import get_db_session
from backend.modules.data_entry.db.crud.ocr_settings import OCRMetricsCRUD
async with await get_db_session() as session:
await OCRMetricsCRUD.create(
session=session,
job_id=job_id,
username=username,
engine_requested=engine_requested,
engine_used=engine_used,
processing_time_ms=processing_time_ms,
file_size_bytes=file_size_bytes,
file_type=file_type,
original_filename=original_filename,
success=success,
error_message=error_message,
overall_confidence=overall_confidence,
fields_extracted=fields_extracted,
needs_manual_review=needs_manual_review,
validation_warnings_count=validation_warnings_count,
validation_errors_count=validation_errors_count,
)
logger.debug(f"[JobWorker] Saved metrics for job {job_id}")
except Exception as e:
# Log but don't fail - metrics are nice-to-have
logger.warning(f"[JobWorker] Failed to save metrics for job {job_id}: {e}")
def _count_extracted_fields(extraction: dict) -> int:
"""
Count number of successfully extracted fields from OCR result.
Counts non-None values in key fields.
"""
key_fields = [
'receipt_number',
'receipt_date',
'amount',
'partner_name',
'cui',
'tva_total',
'address',
'items_count',
]
count = 0
for field in key_fields:
value = extraction.get(field)
if value is not None and value != '' and value != []:
count += 1
# Also count TVA entries if present
tva_entries = extraction.get('tva_entries', [])
if tva_entries and len(tva_entries) > 0:
count += 1
# Count payment methods if present
payment_methods = extraction.get('payment_methods', [])
if payment_methods and len(payment_methods) > 0:
count += 1
return count
# ============================================================================
# Auto-Save Batch Receipt Helper
# ============================================================================
async def _auto_save_batch_receipt(
job: OCRJob,
extraction: dict,
file_path: str
) -> bool:
"""
Automatically create a receipt from OCR result for batch jobs.
Called when a batch job completes successfully. Creates the receipt,
attachment, and accounting entries using ReceiptAutoCreateService.
Args:
job: Completed OCRJob with batch_id set
extraction: OCR extraction result dict
file_path: Path to the original uploaded file
Returns:
True if receipt created successfully, False otherwise
"""
if not job.batch_id:
return True # Not a batch job, nothing to do
logger.info(f"[JobWorker] Auto-saving receipt for batch job {job.id} (batch_id={job.batch_id})")
try:
# Import here to avoid circular imports
from backend.modules.data_entry.db.database import get_db_session
from backend.modules.data_entry.db.models import BatchUpload
from backend.modules.data_entry.services.receipt_auto_create import ReceiptAutoCreateService
from sqlalchemy import select
# Convert extraction dict to ExtractionData schema
ocr_result = ExtractionData(**extraction)
async with await get_db_session() as session:
# Get batch info to retrieve company_id and user_id
batch_result = await session.execute(
select(BatchUpload).where(BatchUpload.id == job.batch_id)
)
batch = batch_result.scalar_one_or_none()
if not batch:
error_msg = f"Batch {job.batch_id} not found"
logger.error(f"[JobWorker] Auto-save failed for job {job.id}: {error_msg}")
await job_queue.update_status(
job_id=job.id,
status=OCRJobStatus.failed,
error=f"Auto-save error: {error_msg}"
)
return False
# Call ReceiptAutoCreateService
result = await ReceiptAutoCreateService.create_from_ocr_result(
session=session,
job_id=job.id,
ocr_result=ocr_result,
username=job.created_by or batch.user_id,
batch_id=job.batch_id,
company_id=batch.company_id,
file_path=file_path,
original_filename=job.original_filename,
file_hash=job.file_hash # Pass file_hash for duplicate detection (US-007)
)
if result.success:
logger.info(
f"[JobWorker] Auto-save successful for job {job.id}: "
f"receipt_id={result.receipt_id}"
)
return True
else:
error_msg = result.error_message or "Unknown error"
logger.warning(
f"[JobWorker] Auto-save validation failed for job {job.id}: {error_msg}"
)
# Update job status to failed with the auto-save error
await job_queue.update_status(
job_id=job.id,
status=OCRJobStatus.failed,
error=f"Auto-save error: {error_msg}"
)
return False
except Exception as e:
error_msg = str(e)
logger.error(f"[JobWorker] Auto-save exception for job {job.id}: {error_msg}")
# Update job status to failed
try:
await job_queue.update_status(
job_id=job.id,
status=OCRJobStatus.failed,
error=f"Auto-save error: {error_msg}"
)
except Exception as update_err:
logger.error(f"[JobWorker] Failed to update job status after auto-save error: {update_err}")
return False

View File

@@ -0,0 +1,561 @@
"""
OCR Worker Pool Manager
Manages a ProcessPoolExecutor with persistent OCR engine initialization.
Key features:
- ProcessPoolExecutor with configurable max_workers (from OCR_WORKERS env)
- Configurable max_tasks_per_child (from OCR_MAX_TASKS_PER_CHILD env, 0=no restart)
- mp_context='spawn' for Windows IIS compatibility
- docTR/PaddleOCR loaded ONCE at worker spawn (not 30s per request)
- atexit + signal handlers for cleanup
- Health check with auto-respawn
- Orphan process cleanup on Windows
Architecture:
Main Process │ Worker Process (PERSISTENT)
──────────────────────│──────────────────────────────────
OCRWorkerPool │ Worker initialized once
↓ │ ↓
submit_task() ────────│────→ process_ocr()
↓ │ ↓
Future.result() ←─────│──── Return result
"""
import asyncio
import atexit
import gc
import logging
import multiprocessing as mp
import os
import signal
import sys
import time
from concurrent.futures import ProcessPoolExecutor, Future, ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Callable, Optional
logger = logging.getLogger(__name__)
# Try to import psutil for orphan process cleanup
try:
import psutil
PSUTIL_AVAILABLE = True
except ImportError:
PSUTIL_AVAILABLE = False
logger.warning("[OCRWorkerPool] psutil not available - orphan cleanup disabled")
class OCRWorkerPool:
"""
Singleton manager for OCR ProcessPoolExecutor.
Ensures OCR engines are loaded once and reused for all requests.
Uses max_tasks_per_child=5 to restart worker every 5 tasks (prevents memory leak).
"""
_instance: Optional["OCRWorkerPool"] = None
_initialized: bool = False
def __new__(cls) -> "OCRWorkerPool":
"""Singleton pattern - only one pool instance."""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
"""Initialize worker pool (runs only once due to singleton)."""
if self._initialized:
return
self._executor: Optional[ProcessPoolExecutor] = None
self._worker_pid: Optional[int] = None
self._is_warming: bool = False
self._is_shutdown: bool = False
self._lock = asyncio.Lock() if asyncio.get_event_loop_policy() else None
self._sync_lock = mp.Lock()
# Register cleanup handlers
# NOTE: Only use atexit, NOT signal handlers!
# Signal handlers interfere with FastAPI's shutdown handling.
# FastAPI's shutdown event calls stop_job_worker() which calls shutdown().
atexit.register(self._cleanup_on_exit)
self._initialized = True
logger.info("[OCRWorkerPool] Singleton instance created")
def initialize(self) -> bool:
"""
Initialize the ProcessPoolExecutor.
Creates executor with spawn context for Windows compatibility.
Uses max_tasks_per_child=5 to restart worker periodically (prevents memory leak).
Returns:
True if initialization successful
"""
if self._executor is not None:
logger.warning("[OCRWorkerPool] Already initialized")
return True
if self._is_shutdown:
logger.error("[OCRWorkerPool] Cannot initialize - pool is shutdown")
return False
try:
# Cleanup any orphan workers from previous runs
self._cleanup_orphan_workers()
# Read configuration from environment
max_workers = int(os.getenv('OCR_WORKERS', '2'))
max_tasks_raw = os.getenv('OCR_MAX_TASKS_PER_CHILD', '0')
# 0 means no restart (None in ProcessPoolExecutor)
max_tasks_per_child = int(max_tasks_raw) if max_tasks_raw and int(max_tasks_raw) > 0 else None
# Create executor with spawn context (Windows compatible)
# Use mp_context='spawn' explicitly for cross-platform consistency
mp_context = mp.get_context('spawn')
# max_tasks_per_child only available in Python 3.11+
executor_kwargs = {
'max_workers': max_workers,
'mp_context': mp_context,
'initializer': _worker_initializer,
}
if sys.version_info >= (3, 11) and max_tasks_per_child is not None:
executor_kwargs['max_tasks_per_child'] = max_tasks_per_child
else:
logger.info(f"[OCRWorkerPool] max_tasks_per_child not supported (Python {sys.version_info.major}.{sys.version_info.minor})")
self._executor = ProcessPoolExecutor(**executor_kwargs)
logger.info(f"[OCRWorkerPool] ProcessPoolExecutor created (spawn context, max_workers={max_workers}, max_tasks_per_child={max_tasks_per_child})")
return True
except Exception as e:
logger.error(f"[OCRWorkerPool] Initialization failed: {e}")
return False
async def prewarm(self, timeout: float = 60.0) -> bool:
"""
Pre-warm the worker by loading PaddleOCR before first request.
This is called at FastAPI startup to avoid 30s delay on first request.
Submits a dummy task that triggers PaddleOCR initialization.
Args:
timeout: Maximum seconds to wait for warmup (default 60s)
Returns:
True if warmup successful, False if timeout or error
"""
if self._executor is None:
logger.error("[OCRWorkerPool] Cannot prewarm - not initialized")
return False
if self._is_warming:
logger.warning("[OCRWorkerPool] Already warming up")
return False
self._is_warming = True
logger.info("[OCRWorkerPool] Starting pre-warm (loading PaddleOCR in worker)...")
start_time = time.time()
try:
# Submit warmup task that initializes PaddleOCR
loop = asyncio.get_event_loop()
future = self._executor.submit(_warmup_task)
# Wait with timeout
result = await loop.run_in_executor(None, future.result, timeout)
elapsed = time.time() - start_time
if result.get("success"):
logger.info(f"[OCRWorkerPool] Pre-warm complete in {elapsed:.1f}s - PaddleOCR ready")
self._worker_pid = result.get("pid")
return True
else:
logger.error(f"[OCRWorkerPool] Pre-warm failed: {result.get('error')}")
return False
except Exception as e:
elapsed = time.time() - start_time
logger.error(f"[OCRWorkerPool] Pre-warm failed after {elapsed:.1f}s: {e}")
return False
finally:
self._is_warming = False
async def submit_task(
self,
image_bytes: bytes,
engine: str = "doctr_plus",
preprocessing: str = "auto",
timeout: float = 120.0
) -> dict:
"""
Submit OCR task to worker process.
Args:
image_bytes: Raw image bytes
engine: OCR engine ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
preprocessing: Preprocessing mode ('light', 'medium', 'heavy', 'auto')
timeout: Maximum processing time in seconds
Returns:
Dict with extraction results
Raises:
RuntimeError: If pool not initialized or task fails
"""
if self._executor is None:
raise RuntimeError("OCR worker pool not initialized")
if self._is_shutdown:
raise RuntimeError("OCR worker pool is shutdown")
logger.info(f"[OCRWorkerPool] Submitting task: engine={engine}, preprocessing={preprocessing}, size={len(image_bytes)} bytes")
try:
loop = asyncio.get_event_loop()
future = self._executor.submit(
_process_ocr_task,
image_bytes,
engine,
preprocessing
)
# Wait for result with timeout
result = await loop.run_in_executor(None, future.result, timeout)
logger.info(f"[OCRWorkerPool] Task complete: success={result.get('success')}")
return result
except TimeoutError:
logger.error(f"[OCRWorkerPool] Task timed out after {timeout}s")
raise RuntimeError(f"OCR processing timed out after {timeout}s")
except Exception as e:
logger.error(f"[OCRWorkerPool] Task failed: {e}")
raise RuntimeError(f"OCR processing failed: {e}")
def is_healthy(self) -> bool:
"""
Check if worker pool is healthy.
Returns:
True if pool is ready to accept tasks
"""
if self._executor is None:
return False
if self._is_shutdown:
return False
# Check if worker process is still alive
if self._worker_pid and PSUTIL_AVAILABLE:
try:
proc = psutil.Process(self._worker_pid)
if not proc.is_running():
logger.warning("[OCRWorkerPool] Worker process died, needs respawn")
return False
except psutil.NoSuchProcess:
logger.warning("[OCRWorkerPool] Worker process not found")
return False
return True
def shutdown(self, wait: bool = True, timeout: float = 10.0) -> None:
"""
Shutdown the worker pool gracefully.
Args:
wait: Wait for pending tasks to complete
timeout: Maximum wait time in seconds
"""
if self._executor is None:
return
logger.info("[OCRWorkerPool] Shutting down...")
self._is_shutdown = True
try:
self._executor.shutdown(wait=wait, cancel_futures=True)
logger.info("[OCRWorkerPool] Executor shutdown complete")
except Exception as e:
logger.error(f"[OCRWorkerPool] Shutdown error: {e}")
self._executor = None
self._worker_pid = None
# Final orphan cleanup
self._cleanup_orphan_workers()
logger.info("[OCRWorkerPool] Shutdown complete")
def _cleanup_orphan_workers(self) -> int:
"""
Clean up orphan Python processes from previous runs.
On Windows with NSSM, orphan processes may remain after service restart.
This finds and kills any python.exe processes that were OCR workers.
Returns:
Number of processes killed
"""
if not PSUTIL_AVAILABLE:
return 0
killed = 0
current_pid = os.getpid()
try:
for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
try:
# Skip self
if proc.pid == current_pid:
continue
# Look for Python processes with OCR-related cmdline
if proc.name().lower() in ('python.exe', 'python3.exe', 'python', 'python3'):
cmdline = ' '.join(proc.cmdline() or [])
# Check if this is an OCR worker process
if 'ocr_worker_process' in cmdline.lower() or 'process_ocr_task' in cmdline.lower():
logger.warning(f"[OCRWorkerPool] Killing orphan worker: PID={proc.pid}")
proc.kill()
proc.wait(timeout=5)
killed += 1
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
continue
except Exception as e:
logger.error(f"[OCRWorkerPool] Orphan cleanup error: {e}")
if killed > 0:
logger.info(f"[OCRWorkerPool] Cleaned up {killed} orphan worker(s)")
return killed
def _cleanup_on_exit(self) -> None:
"""atexit handler for cleanup."""
logger.info("[OCRWorkerPool] atexit cleanup triggered")
self.shutdown(wait=False)
def _signal_handler(self, signum: int, frame: Any) -> None:
"""Signal handler for SIGTERM/SIGINT."""
logger.info(f"[OCRWorkerPool] Received signal {signum}, shutting down...")
self.shutdown(wait=False)
# ============================================================================
# WORKER PROCESS FUNCTIONS
# ============================================================================
# These functions run in the child process, not the main FastAPI process.
# Global engines - persist between tasks in worker process
_paddle_engine = None
_tesseract_engine = None
_doctr_engine = None # docTR engine (PyTorch backend)
_worker_initialized = False
def _worker_initializer() -> None:
"""
Called once when worker process spawns.
Initializes global OCR engines IN PARALLEL for faster startup.
Uses ThreadPoolExecutor to load enabled engines concurrently.
Respects OCR_ENABLE_PADDLEOCR and OCR_ENABLE_TESSERACT from .env.
Total warmup time = max(engine_times) instead of sum(engine_times).
"""
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
if _worker_initialized:
print(f"[Worker {os.getpid()}] Already initialized", flush=True)
return
# Check which engines are enabled via .env
paddle_enabled = os.getenv("OCR_ENABLE_PADDLEOCR", "true").lower() == "true"
tesseract_enabled = os.getenv("OCR_ENABLE_TESSERACT", "true").lower() == "true"
enabled_engines = ["doctr"] # docTR is always loaded (primary engine)
if paddle_enabled:
enabled_engines.append("paddle")
if tesseract_enabled:
enabled_engines.append("tesseract")
print(f"[Worker {os.getpid()}] Initializing OCR engines: {enabled_engines}", flush=True)
if not paddle_enabled:
print(f"[Worker {os.getpid()}] PaddleOCR DISABLED - saving ~800MB RAM", flush=True)
if not tesseract_enabled:
print(f"[Worker {os.getpid()}] Tesseract DISABLED - saving ~50MB RAM", flush=True)
start_time = time.time()
# Define loader functions - each runs in its own thread
def load_doctr():
try:
from backend.modules.data_entry.services.ocr.ocr_worker_process import initialize_doctr_engine
engine = initialize_doctr_engine()
return ("doctr", engine, None)
except Exception as e:
return ("doctr", None, str(e))
def load_paddle():
if not paddle_enabled:
return ("paddle", None, "disabled via OCR_ENABLE_PADDLEOCR=false")
try:
from backend.modules.data_entry.services.ocr.ocr_worker_process import initialize_paddle_engine
engine = initialize_paddle_engine()
return ("paddle", engine, None)
except Exception as e:
return ("paddle", None, str(e))
def load_tesseract():
if not tesseract_enabled:
return ("tesseract", None, "disabled via OCR_ENABLE_TESSERACT=false")
try:
from backend.modules.data_entry.services.ocr.tesseract_engine import TesseractEngine
engine = TesseractEngine()
return ("tesseract", engine, None)
except Exception as e:
return ("tesseract", None, str(e))
# Build list of futures for enabled engines only
futures_to_submit = [load_doctr] # docTR always loaded
if paddle_enabled:
futures_to_submit.append(load_paddle)
if tesseract_enabled:
futures_to_submit.append(load_tesseract)
# Load engines in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=len(futures_to_submit)) as executor:
futures = [executor.submit(fn) for fn in futures_to_submit]
for future in as_completed(futures):
name, engine, error = future.result()
if error and "disabled" not in error:
print(f"[Worker {os.getpid()}] {name} init failed: {error}", flush=True)
elif engine:
print(f"[Worker {os.getpid()}] {name} loaded", flush=True)
if name == "doctr":
_doctr_engine = engine
elif name == "paddle":
_paddle_engine = engine
elif name == "tesseract":
_tesseract_engine = engine
elapsed = time.time() - start_time
_worker_initialized = True
print(f"[Worker {os.getpid()}] Initialization complete in {elapsed:.1f}s (engines: {enabled_engines})", flush=True)
def _warmup_task() -> dict:
"""
Warmup task that ensures engines are loaded.
Called at FastAPI startup to pre-warm the worker.
Returns success status and worker PID.
"""
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
try:
# Ensure initialization
if not _worker_initialized:
_worker_initializer()
# Quick test - create a small dummy image
import numpy as np
dummy_img = np.ones((100, 100, 3), dtype=np.uint8) * 255
# Test docTR if available (fastest engine)
if _doctr_engine is not None:
try:
_doctr_engine([dummy_img])
print(f"[Worker {os.getpid()}] docTR warmup OK", flush=True)
except Exception as e:
print(f"[Worker {os.getpid()}] docTR warmup error: {e}", flush=True)
# Test PaddleOCR if available
if _paddle_engine is not None:
try:
_paddle_engine.predict(dummy_img)
print(f"[Worker {os.getpid()}] PaddleOCR warmup OK", flush=True)
except Exception as e:
print(f"[Worker {os.getpid()}] PaddleOCR warmup error: {e}", flush=True)
# Cleanup
gc.collect()
return {
"success": True,
"pid": os.getpid(),
"doctr_available": _doctr_engine is not None,
"paddle_available": _paddle_engine is not None,
"tesseract_available": _tesseract_engine is not None
}
except Exception as e:
return {
"success": False,
"pid": os.getpid(),
"error": str(e)
}
def _process_ocr_task(
image_bytes: bytes,
engine: str = "doctr_plus",
preprocessing: str = "auto"
) -> dict:
"""
Process OCR task in worker process.
This is the main work function called for each OCR request.
Uses persistent global engines loaded at worker init.
Args:
image_bytes: Raw image bytes
engine: OCR engine choice ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
preprocessing: Preprocessing mode
Returns:
Dict with extraction results
"""
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
try:
# Ensure initialization
if not _worker_initialized:
_worker_initializer()
# Import processing function
from backend.modules.data_entry.services.ocr.ocr_worker_process import process_ocr
# Run OCR
result = process_ocr(
image_bytes=image_bytes,
paddle_engine=_paddle_engine,
tesseract_engine=_tesseract_engine,
engine=engine,
preprocessing=preprocessing,
doctr_engine=_doctr_engine
)
# Cleanup after each task
gc.collect()
return result
except Exception as e:
print(f"[Worker {os.getpid()}] Task error: {e}", flush=True)
import traceback
traceback.print_exc()
return {
"success": False,
"error": str(e),
"pid": os.getpid()
}
# Singleton instance
ocr_worker_pool = OCRWorkerPool()

View File

@@ -0,0 +1,258 @@
# Store Profiles - OCR Extraction
Sistem de profile specifice pentru extracție OCR cu hot-reload.
---
## Quick Start: Adaugă un profil nou
```bash
# 1. Generează profil din PDF-uri (dry-run pentru preview)
python scripts/generate_store_profile.py \
--name "Magazin Nou SRL" \
--cui "12345678" \
--receipts "docs/data-entry/MagazinNou*.pdf" \
--dry-run
# 2. Generează și salvează
python scripts/generate_store_profile.py \
--name "Magazin Nou SRL" \
--cui "12345678" \
--receipts "docs/data-entry/MagazinNou*.pdf" \
--output backend/modules/data_entry/services/ocr/profiles/magazin_nou.py
# 3. Hot-reload (fără restart server)
curl -X POST http://localhost:8000/api/data-entry/ocr/profiles/reload
# 4. Verifică
curl http://localhost:8000/api/data-entry/ocr/profiles
```
---
## Structura directorului
```
profiles/
├── __init__.py # ProfileRegistry + hot-reload (~390 linii)
├── base.py # BaseStoreProfile + pattern-uri generice (~410 linii)
├── lidl.py # Multi-rate TVA (A/B)
├── omv.py # B2B, date YYYY.MM.DD
├── socar.py # B2B, date YYYY.MM.DD
├── brick.py # Standard TVA
├── dedeman.py # E-factura support
├── kineterra.py # Non-VAT payer
├── gama_ink.py # Standard TVA (toner/cartușe)
├── electrobering.py # Standard TVA (electronice)
├── pictus_velum.py # Standard TVA (rechizite)
├── unlimited_keys.py # Standard TVA, NUMERAR payment
├── best_print.py # Non-VAT payer (neplătitor TVA)
├── stepout_market.py # TVA 5% (cărți/librărie)
└── README.md # Acest fișier
```
---
## Profile existente (12 profile)
> **Note**: Pattern-urile TVA sunt **flexibile** și acceptă ORICE cotă (5%, 9%, 11%, 19%, 21%, etc.)
> pentru a gestiona atât datele istorice cât și schimbările viitoare ale legislației.
| Magazin | CUI | Fișier | Caracteristici |
|---------|-----|--------|----------------|
| LIDL DISCOUNT S.R.L. | 22891860 | `lidl.py` | Multi-rate TVA (coduri A, B, C, D) |
| OMV PETROM MARKETING S.R.L. | 11201891 | `omv.py` | B2B (client CUI), date YYYY.MM.DD |
| SOCAR PETROLEUM S.A. | 12546600 | `socar.py` | B2B (client CUI), date YYYY.MM.DD |
| FIVE-HOLDING S.A. (BRICK) | 10562600 | `brick.py` | Standard TVA |
| DEDEMAN SRL | 2816464 | `dedeman.py` | E-factura support |
| KINETERRA CONCEPT SRL | 31180432 | `kineterra.py` | Non-VAT payer (returnează `[]`) |
| GAMA INK SERVICE SRL | 17741882 | `gama_ink.py` | Standard TVA (toner, cartușe) |
| ELECTROBERING S.R.L. | 2744937 | `electrobering.py` | Standard TVA (electronice) |
| PICTUS VELUM SRL | 39634534 | `pictus_velum.py` | Standard TVA (rechizite) |
| UNLIMITED KEYS S.R.L. | 18993187 | `unlimited_keys.py` | Standard TVA, **NUMERAR** plată |
| BEST PRINT TRADE ACTIV SRL | 45417955 | `best_print.py` | **Non-VAT payer** (neplătitor TVA) |
| STEPOUT MARKET SRL | 35532655 | `stepout_market.py` | TVA 5% (cărți, librărie) |
---
## API Endpoints
| Endpoint | Metodă | Descriere |
|----------|--------|-----------|
| `/api/data-entry/ocr/profiles` | GET | Lista toate profilele |
| `/api/data-entry/ocr/profiles/{cui}` | GET | Detalii profil (acceptă RO prefix) |
| `/api/data-entry/ocr/profiles/reload` | POST | Hot-reload toate profilele |
### Exemple API
```bash
# Lista profile
curl http://localhost:8000/api/data-entry/ocr/profiles \
-H "Authorization: Bearer <token>"
# Detalii profil (cu sau fără RO prefix)
curl http://localhost:8000/api/data-entry/ocr/profiles/22891860
curl http://localhost:8000/api/data-entry/ocr/profiles/RO22891860
# Hot-reload după modificări
curl -X POST http://localhost:8000/api/data-entry/ocr/profiles/reload \
-H "Authorization: Bearer <token>"
# Response reload:
{
"success": true,
"reloaded_modules": 12,
"profiles_count": 12,
"registered_cuis": ["22891860", "11201891", "12546600", "10562600", ...],
"last_reload": "2026-01-06T22:37:05.000000"
}
```
---
## Cum funcționează sistemul
### Flow de extracție
```
ReceiptExtractor.extract()
├─► STEP 1: Extrage vendor + CUI
│ └─► _extract_vendor(), _extract_cui()
├─► ProfileRegistry.get_profile(cui)
│ └─► Returnează profil specific sau None
├─► STEP 2: Extracție cu profil (dacă există)
│ ├─► profile.extract_total()
│ ├─► profile.extract_date()
│ ├─► profile.extract_receipt_number()
│ ├─► profile.extract_tva_entries()
│ ├─► profile.extract_payment_methods()
│ └─► profile.extract_client_cui()
└─► STEP 3-4: Validare + post-procesare
```
### Fallback
Dacă nu există profil pentru CUI, se folosește logica generică din `ReceiptExtractor`.
---
## Structura unui profil
```python
from .base import BaseStoreProfile
from . import ProfileRegistry
@ProfileRegistry.register
class MagazinNouProfile(BaseStoreProfile):
"""Docstring cu descriere magazin."""
CUI_LIST = ["12345678"] # Poate avea mai multe CUI-uri
NAME_PATTERNS = ["MAGAZIN", "MAGAZIN NOU", "MAG4ZIN"] # OCR variants
STORE_NAME = "Magazin Nou SRL"
# Override doar ce e diferit de base class
def extract_tva_entries(self, text: str) -> List[dict]:
# Pattern-uri specifice magazinului
...
def get_validation_hints(self) -> Dict[str, Any]:
return {
"has_multi_rate_tva": False,
"card_equals_total": True,
"has_client_cui": False,
"has_efactura": False,
"is_non_vat_payer": False,
}
```
---
## Pattern-uri disponibile în base.py
BaseStoreProfile include pattern-uri generice OCR-tolerant:
| Pattern | Descriere |
|---------|-----------|
| `TOTAL_PATTERNS` | 8 variante pentru TOTAL (TOTAL:, TOTAL DE PLATA, etc.) |
| `DATE_PATTERNS` | 6 variante (DD.MM.YYYY, YYYY-MM-DD, DD/MM/YYYY) |
| `DATE_PATTERNS_OCR_SPACES` | 4 variante cu spații OCR ("2025. 08. 14") |
| `NUMBER_PATTERNS` | 11 variante pentru număr bon (NDS, BF, C3POS) |
| `PAYMENT_PATTERNS` | 8 variante pentru CARD/NUMERAR |
| `CLIENT_MARKERS` | 6 variante pentru secțiune CLIENT |
| `CLIENT_CUI_PATTERNS` | 7 variante pentru CUI client |
### Metode implementate în base class
- `extract_total(text)``Tuple[Decimal, float]`
- `extract_date(text)``Tuple[date, float]`
- `extract_receipt_number(text)``Tuple[str, float]`
- `extract_payment_methods(text)``List[dict]`
- `extract_client_cui(text)``Tuple[str, float]`
- `extract_client_name(text)``Tuple[str, float]`
---
## Când ai nevoie de profil custom?
| Situație | Exemplu | Ce trebuie override |
|----------|---------|---------------------|
| **Multi-rate TVA** | Lidl (TVA A, TVA B) | `extract_tva_entries()` |
| **Format dată special** | OMV/Socar (YYYY.MM.DD) | `DATE_PATTERNS_OCR_SPACES` |
| **B2B receipts** | Benzinării (au client CUI) | `extract_client_cui()` |
| **Non-VAT payer** | Kineterra | `extract_tva_entries()` returnează `[]` |
| **E-factura** | Dedeman | `extract_efactura_reference()` |
---
## Decizii de design
1. **Hot-reload manual** - endpoint `/profiles/reload` apelat când se modifică fișiere
2. **Persistență în Python** - profile în Git, version controlled
3. **Fallback graceful** - dacă nu există profil, folosește logica generică
4. **CUI normalization** - gestionează automat prefixul "RO" și whitespace
5. **Deduplicare TVA** - folosește `seen = set()` pentru a evita duplicate
---
## Comenzi utile
```bash
# Verifică syntax Python pentru toate profilele
for f in backend/modules/data_entry/services/ocr/profiles/*.py; do
python3 -m py_compile "$f" && echo "✓ $(basename $f)"
done
# Lista profile
ls -la backend/modules/data_entry/services/ocr/profiles/
# Pornește backend pentru testare
cd backend && source venv/bin/activate
uvicorn main:app --host 0.0.0.0 --port 8000 --workers 1
# Test OCR pe un PDF
curl -X POST -F "file=@docs/data-entry/test.pdf" \
-H "Authorization: Bearer <token>" \
"http://localhost:8000/api/data-entry/ocr/extract?engine=doctr_plus"
```
---
## Script generare profile
`scripts/generate_store_profile.py` - generator automat de profile
```bash
# Vezi help
python scripts/generate_store_profile.py --help
# Funcționalități:
# - Analizează PDF-uri via OCR API
# - Detectează: TVA format, date format, payment patterns, B2B
# - Generează cod Python cu OCR error variants
# - Suportă glob patterns (*.pdf)
# - Verifică sintaxa după generare
```

View File

@@ -0,0 +1,398 @@
"""
Store Profiles Registry with Hot-Reload Support.
This module provides a registry for store-specific OCR extraction profiles.
Profiles can be reloaded at runtime without restarting the server.
Usage:
from backend.modules.data_entry.services.ocr.profiles import ProfileRegistry
# Get profile for a CUI
profile = ProfileRegistry.get_profile("22891860")
if profile:
tva_entries = profile.extract_tva_entries(text)
# Reload all profiles (after file changes)
count = ProfileRegistry.reload_all()
Architecture:
- ProfileRegistry: Singleton registry with class methods
- BaseStoreProfile: Abstract base class for profiles
- @ProfileRegistry.register: Decorator for profile classes
Hot-Reload Mechanism:
1. Admin calls POST /profiles/reload endpoint
2. Registry clears instance cache
3. importlib.reload() re-executes each profile module
4. @register decorator re-registers classes with new code
"""
from __future__ import annotations
import importlib
import logging
import sys
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Type, TYPE_CHECKING
if TYPE_CHECKING:
from .base import BaseStoreProfile
logger = logging.getLogger(__name__)
# Directory containing profile modules
PROFILES_DIR = Path(__file__).parent
class ProfileRegistry:
"""
Registry for store-specific OCR extraction profiles.
Uses class methods for singleton-like behavior without explicit instantiation.
Supports hot-reload via importlib.reload() for runtime updates.
Attributes:
_profiles: Maps CUI -> profile class (not instance)
_instances: Maps CUI -> profile instance (lazy, cleared on reload)
_last_reload: Timestamp of last reload
_loaded: Whether initial load has been performed
"""
# Class-level storage (singleton pattern via class methods)
_profiles: Dict[str, Type["BaseStoreProfile"]] = {}
_instances: Dict[str, "BaseStoreProfile"] = {}
_last_reload: Optional[datetime] = None
_loaded: bool = False
# -------------------------------------------------------------------------
# Registration
# -------------------------------------------------------------------------
@classmethod
def register(cls, profile_class: Type["BaseStoreProfile"]) -> Type["BaseStoreProfile"]:
"""
Decorator to register a store profile class.
Registers the profile for all CUIs in the class's CUI_LIST.
Safe for re-registration during hot-reload (overwrites existing).
Usage:
@ProfileRegistry.register
class LidlProfile(BaseStoreProfile):
CUI_LIST = ["22891860"]
...
Args:
profile_class: Profile class to register
Returns:
The same class (allows use as decorator)
Raises:
ValueError: If CUI_LIST is empty
"""
cui_list = getattr(profile_class, 'CUI_LIST', [])
store_name = getattr(profile_class, 'STORE_NAME', profile_class.__name__)
if not cui_list:
logger.warning(f"Profile {profile_class.__name__} has empty CUI_LIST, skipping")
return profile_class
# Register for each CUI
for cui in cui_list:
# Normalize CUI (remove RO prefix, strip whitespace)
normalized_cui = cls._normalize_cui(cui)
if normalized_cui in cls._profiles:
old_class = cls._profiles[normalized_cui]
logger.debug(
f"Re-registering CUI {normalized_cui}: "
f"{old_class.__name__} -> {profile_class.__name__}"
)
# Clear cached instance for this CUI
cls._instances.pop(normalized_cui, None)
cls._profiles[normalized_cui] = profile_class
logger.debug(f"Registered profile {profile_class.__name__} for CUI {normalized_cui}")
logger.info(f"Registered {store_name} for CUIs: {cui_list}")
return profile_class
# -------------------------------------------------------------------------
# Lookup
# -------------------------------------------------------------------------
@classmethod
def get_profile(cls, cui: Optional[str]) -> Optional["BaseStoreProfile"]:
"""
Get profile instance for a CUI.
Uses lazy instantiation - creates instance on first access.
Returns None if no profile is registered for this CUI.
Args:
cui: CUI to lookup (with or without RO prefix)
Returns:
Profile instance or None
"""
if not cui:
return None
# Ensure profiles are loaded
if not cls._loaded:
cls._load_all_profiles()
normalized_cui = cls._normalize_cui(cui)
# Check if profile exists
profile_class = cls._profiles.get(normalized_cui)
if not profile_class:
return None
# Lazy instantiation
if normalized_cui not in cls._instances:
try:
cls._instances[normalized_cui] = profile_class()
logger.debug(f"Instantiated {profile_class.__name__} for CUI {normalized_cui}")
except Exception as e:
logger.error(f"Failed to instantiate {profile_class.__name__}: {e}")
return None
return cls._instances[normalized_cui]
@classmethod
def has_profile(cls, cui: Optional[str]) -> bool:
"""Check if a profile exists for this CUI."""
if not cui:
return False
if not cls._loaded:
cls._load_all_profiles()
return cls._normalize_cui(cui) in cls._profiles
# -------------------------------------------------------------------------
# Listing
# -------------------------------------------------------------------------
@classmethod
def list_profiles(cls) -> List[Dict]:
"""
List all registered profiles.
Returns:
List of dicts with cui, class_name, store_name, name_patterns
"""
if not cls._loaded:
cls._load_all_profiles()
result = []
seen_classes = set()
for cui, profile_class in cls._profiles.items():
# Avoid duplicates for profiles with multiple CUIs
if profile_class.__name__ in seen_classes:
continue
seen_classes.add(profile_class.__name__)
result.append({
"cuis": list(getattr(profile_class, 'CUI_LIST', [])),
"class_name": profile_class.__name__,
"store_name": getattr(profile_class, 'STORE_NAME', profile_class.__name__),
"name_patterns": list(getattr(profile_class, 'NAME_PATTERNS', [])),
})
return result
@classmethod
def get_profile_info(cls, cui: str) -> Optional[Dict]:
"""
Get detailed info about a profile.
Args:
cui: CUI to lookup
Returns:
Dict with profile details or None
"""
profile = cls.get_profile(cui)
if not profile:
return None
return {
"cui": cui,
"cuis": list(profile.CUI_LIST),
"class_name": profile.__class__.__name__,
"store_name": profile.STORE_NAME,
"name_patterns": list(profile.NAME_PATTERNS),
"validation_hints": profile.get_validation_hints(),
}
# -------------------------------------------------------------------------
# Hot-Reload
# -------------------------------------------------------------------------
@classmethod
def reload_all(cls) -> int:
"""
Hot-reload all profile modules.
Clears instance cache and reloads all .py files in profiles directory.
Decorator re-registers classes with updated code.
Returns:
Number of modules reloaded
"""
logger.info("Starting profile hot-reload...")
# Clear instance cache (will be recreated on next get_profile)
cls._instances.clear()
# Get list of profile modules (exclude __init__, base)
module_names = cls._get_profile_module_names()
# Determine the module prefix based on how THIS module was imported
base_package = cls.__module__
count = 0
for module_name in module_names:
full_name = f"{base_package}.{module_name}"
try:
if full_name in sys.modules:
# Reload existing module
importlib.reload(sys.modules[full_name])
logger.debug(f"Reloaded module: {module_name}")
else:
# Import new module
importlib.import_module(full_name)
logger.debug(f"Imported new module: {module_name}")
count += 1
except Exception as e:
logger.error(f"Failed to reload {module_name}: {e}")
cls._last_reload = datetime.utcnow()
cls._loaded = True
logger.info(f"Profile hot-reload complete: {count} modules, {len(cls._profiles)} profiles")
return count
@classmethod
def get_reload_status(cls) -> Dict:
"""Get status of the registry including last reload time."""
return {
"loaded": cls._loaded,
"last_reload": cls._last_reload.isoformat() if cls._last_reload else None,
"profiles_count": len(cls._profiles),
"instances_count": len(cls._instances),
"registered_cuis": list(cls._profiles.keys()),
}
# -------------------------------------------------------------------------
# Internal methods
# -------------------------------------------------------------------------
@classmethod
def _normalize_cui(cls, cui: str) -> str:
"""
Normalize CUI for consistent lookup.
- Removes RO prefix (with or without space)
- Strips whitespace
- Converts to uppercase
Args:
cui: Raw CUI string
Returns:
Normalized CUI (digits only)
"""
if not cui:
return ""
cui = str(cui).strip().upper()
# Remove RO prefix (handles "RO12345" and "RO 12345")
if cui.startswith("RO"):
cui = cui[2:].lstrip()
return cui.strip()
@classmethod
def _get_profile_module_names(cls) -> List[str]:
"""
Get list of profile module names from profiles directory.
Excludes __init__.py and base.py.
Returns:
List of module names (without .py extension)
"""
excluded = {"__init__", "base", "__pycache__"}
modules = []
for path in PROFILES_DIR.glob("*.py"):
name = path.stem
if name not in excluded:
modules.append(name)
return sorted(modules)
@classmethod
def _load_all_profiles(cls) -> None:
"""
Initial load of all profile modules.
Called automatically on first get_profile() if not already loaded.
"""
if cls._loaded:
return
logger.info("Loading store profiles...")
module_names = cls._get_profile_module_names()
# Determine the module prefix based on how THIS module was imported
# This handles both:
# - Running from backend dir: "modules.data_entry.services.ocr.profiles"
# - Running from project root: "backend.modules.data_entry.services.ocr.profiles"
this_module = cls.__module__ # e.g. "backend.modules..." or "modules..."
base_package = this_module # Use the same prefix for child modules
for module_name in module_names:
full_name = f"{base_package}.{module_name}"
try:
importlib.import_module(full_name)
logger.debug(f"Loaded module: {module_name}")
except Exception as e:
logger.error(f"Failed to load {module_name}: {e}")
cls._loaded = True
cls._last_reload = datetime.utcnow()
logger.info(f"Loaded {len(cls._profiles)} store profiles")
@classmethod
def clear(cls) -> None:
"""
Clear all registered profiles.
Mainly useful for testing.
"""
cls._profiles.clear()
cls._instances.clear()
cls._loaded = False
cls._last_reload = None
# -------------------------------------------------------------------------
# Module exports
# -------------------------------------------------------------------------
__all__ = [
"ProfileRegistry",
"BaseStoreProfile",
]
# Re-export BaseStoreProfile for convenience
from .base import BaseStoreProfile

View File

@@ -0,0 +1,655 @@
"""
Optimized Tesseract Engine for OCR - SPEED + QUALITY OPTIMIZED
Performance optimizations (vs previous version):
- Single PSM mode (PSM 4) instead of multi-PSM (4 modes × 2 calls = 8x faster)
- Single Tesseract call per image (skip image_to_data for speed)
- Lighter preprocessing (no over-binarization)
- --dpi 300 flag for proper scaling
- OEM 3 (default LSTM+Legacy) for balanced speed/accuracy
Quality optimizations for Romanian receipts:
- PSM 4: Single column layout (optimal for receipts)
- Polarity correction: ensures black text on white background
- Language: Romanian only (-l ron) for faster recognition
- Fallback to PSM 6 if PSM 4 produces poor results
Previous issues fixed:
- Was 8x slower than PaddleOCR due to multi-PSM + dual calls
- Produced gibberish on clear PDFs due to over-binarization
"""
import logging
import os
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
import cv2
import numpy as np
# Check Tesseract availability
try:
import pytesseract
TESSERACT_AVAILABLE = True
except ImportError:
TESSERACT_AVAILABLE = False
pytesseract = None
logger = logging.getLogger(__name__)
@dataclass
class OCRResult:
"""Raw OCR result from Tesseract."""
text: str
confidence: float
boxes: List[dict] = field(default_factory=list)
engine: str = "tesseract"
class TesseractEngine:
"""
Optimized Tesseract engine for receipt OCR.
TESTED OPTIMAL SETTINGS (from comprehensive benchmark):
- DPI 200 for PDF loading (not 300!)
- Padding 40px for edge protection
- PSM 6 for complex receipts, PSM 4 for simple ones
- Multi-pass strategy when quality is critical
SPEED vs QUALITY tradeoff:
- Fast mode (single pass): ~0.9s, ~6-7 keywords
- Quality mode (multi-pass): ~1.7s, ~8-9 keywords (+2 more keywords)
BENCHMARK RESULTS:
- padded_psm6_40: Best for complex receipts (igiena, five-holding)
- baseline_psm4: Best for simple receipts (rechizite, benzina)
- multi-pass: Best overall quality but slower
"""
# PSM modes for receipts
PSM_SINGLE_COLUMN = 4 # Best for simple vertical receipts
PSM_UNIFORM_BLOCK = 6 # Best for complex layouts
PSM_SPARSE_TEXT = 11 # Fallback for difficult receipts
# Optimal padding (from benchmark)
DEFAULT_PADDING = 40
def __init__(self):
"""Initialize Tesseract engine."""
if not TESSERACT_AVAILABLE:
raise RuntimeError("pytesseract not available. Install with: pip install pytesseract")
# Verify Tesseract installation
try:
self._version = pytesseract.get_tesseract_version()
except Exception as e:
raise RuntimeError(f"Tesseract not installed or not in PATH: {e}")
logger.info(f"[TesseractEngine] Initialized (v{self._version})")
def recognize(self, image: np.ndarray, fast_mode: bool = True) -> OCRResult:
"""
Perform OCR recognition on image (OPTIMIZED).
SPEED: Uses single PSM mode + single Tesseract call.
Previously used 4 PSM modes × 2 calls = 8 Tesseract invocations.
Now uses 1-2 calls maximum (with fallback).
Args:
image: Preprocessed grayscale image (DO NOT binarize for clear PDFs!)
fast_mode: If True, skip confidence calculation for maximum speed
Returns:
OCRResult with text and confidence
"""
if not TESSERACT_AVAILABLE:
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
# Ensure grayscale
if len(image.shape) == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
# Fix polarity (black text on white background)
image = self._ensure_correct_polarity(image)
# Try PSM 4 first (single column - best for receipts)
result = self._recognize_fast(image, self.PSM_SINGLE_COLUMN, fast_mode)
# If poor result, try PSM 6 as fallback
if not result.text.strip() or result.confidence < 0.3:
logger.debug(f"[Tesseract] PSM {self.PSM_SINGLE_COLUMN} poor result, trying PSM {self.PSM_UNIFORM_BLOCK}")
fallback = self._recognize_fast(image, self.PSM_UNIFORM_BLOCK, fast_mode)
if len(fallback.text) > len(result.text):
result = fallback
if result.text.strip():
logger.info(f"[TesseractEngine] Result: {len(result.text)} chars, conf={result.confidence:.0%}")
return result
def _recognize_fast(self, image: np.ndarray, psm: int, fast_mode: bool = True) -> OCRResult:
"""
Fast single-call Tesseract recognition.
Optimizations:
- Single call (image_to_string only in fast mode)
- OEM 3 (LSTM+Legacy) - faster than OEM 1
- --dpi 300 for proper scaling
- Romanian only (-l ron)
Args:
image: Grayscale image
psm: Page segmentation mode
fast_mode: Skip confidence calculation for speed
Returns:
OCRResult
"""
# Build optimized config:
# OEM 3 = LSTM + Legacy (faster than pure LSTM)
# --dpi 300 = proper scaling hint
# -l ron = Romanian only (faster, avoids eng confusion)
config = f'--psm {psm} --oem 3 --dpi 300 -l ron'
try:
if fast_mode:
# Fast path: just get text, estimate confidence
text = pytesseract.image_to_string(image, config=config)
# Estimate confidence based on text quality
confidence = self._estimate_confidence(text)
else:
# Accurate path: get text + real confidence
text = pytesseract.image_to_string(image, config=config)
data = pytesseract.image_to_data(
image, config=config, output_type=pytesseract.Output.DICT
)
confidences = [int(c) for c in data['conf'] if int(c) > 0]
confidence = sum(confidences) / len(confidences) / 100 if confidences else 0.0
return OCRResult(
text=text,
confidence=confidence,
boxes=[],
engine="tesseract"
)
except Exception as e:
logger.warning(f"[Tesseract] PSM {psm} error: {e}")
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
def _estimate_confidence(self, text: str) -> float:
"""
Estimate OCR confidence based on text quality.
Heuristics:
- More alphanumeric chars = higher confidence
- Less garbage chars = higher confidence
- Romanian-specific patterns boost confidence
"""
if not text.strip():
return 0.0
# Count valid vs garbage chars
valid_chars = sum(1 for c in text if c.isalnum() or c in '.,;:-/\n ')
total_chars = len(text)
if total_chars == 0:
return 0.0
# Base confidence from char ratio
confidence = valid_chars / total_chars
# Boost for Romanian receipt patterns
text_lower = text.lower()
if any(word in text_lower for word in ['total', 'lei', 'ron', 'buc', 'tva', 'cif', 'bon']):
confidence = min(confidence + 0.1, 1.0)
return confidence
def recognize_multipass(self, image: np.ndarray) -> OCRResult:
"""
Multi-pass OCR for maximum quality (slower but more accurate).
Strategy (from benchmark testing):
- Pass 1: PSM 4 (single column) - no padding, fast baseline
- Pass 2: PSM 6 (uniform block) - with 40px padding, better for complex layouts
- Pass 3: PSM 11 (sparse text) - with 40px padding + stronger CLAHE, for difficult receipts
Merges results: picks the pass with highest keyword count.
On average finds +2.1 more keywords than single-pass (~8.7 vs 6.6).
Time: ~1.7s (vs ~0.9s for single pass)
Args:
image: Input image (RGB or grayscale)
Returns:
OCRResult from the best pass
"""
if not TESSERACT_AVAILABLE:
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
# Ensure grayscale
if len(image.shape) == 3:
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
else:
gray = image.copy()
# Define passes with different settings
passes = [
# Pass 1: Fast baseline (no padding) - good for simple receipts
{"name": "pass1_psm4", "psm": 4, "padding": 0, "clahe_clip": 1.5},
# Pass 2: Padded PSM 6 - good for complex receipts
{"name": "pass2_psm6_padded", "psm": 6, "padding": 40, "clahe_clip": 1.5},
# Pass 3: Sparse text with stronger enhancement - for difficult cases
{"name": "pass3_psm11", "psm": 11, "padding": 40, "clahe_clip": 2.0},
]
best_result = None
best_score = -1
all_keywords = set()
for p in passes:
# Apply preprocessing for this pass
processed = gray.copy()
# Add padding if specified
if p["padding"] > 0:
processed = cv2.copyMakeBorder(
processed, p["padding"], p["padding"], p["padding"], p["padding"],
cv2.BORDER_CONSTANT, value=255
)
# Apply CLAHE
clahe = cv2.createCLAHE(clipLimit=p["clahe_clip"], tileGridSize=(8, 8))
processed = clahe.apply(processed)
# Ensure correct polarity
processed = self._ensure_correct_polarity(processed)
# Run OCR
config = f'--psm {p["psm"]} --oem 3 -l ron'
try:
text = pytesseract.image_to_string(processed, config=config)
confidence = self._estimate_confidence(text)
# Score based on Romanian receipt keywords
text_lower = text.lower()
keywords = ['cif', 'total', 'tva', 'lei', 'ron', 'buc', 'fiscal', 'bon',
'hartie', 'prosop', 'saci', 'creion', 'constanta', 'bucuresti']
found_keywords = [kw for kw in keywords if kw in text_lower]
all_keywords.update(found_keywords)
# Score: keywords + CIF bonus + TOTAL bonus
score = len(found_keywords) * 10
if self._has_cif_pattern(text):
score += 15
if self._has_total_pattern(text):
score += 10
logger.debug(f"[Tesseract] {p['name']}: {len(found_keywords)} keywords, score={score}")
if score > best_score:
best_score = score
best_result = OCRResult(
text=text,
confidence=confidence,
boxes=[],
engine=f"tesseract-multipass-{p['name']}"
)
except Exception as e:
logger.warning(f"[Tesseract] {p['name']} failed: {e}")
continue
if best_result:
logger.info(f"[TesseractEngine] Multi-pass best: {best_result.engine}, "
f"{len(all_keywords)} total keywords found")
return best_result
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract-multipass")
def _has_cif_pattern(self, text: str) -> bool:
"""Check if text contains a valid CIF/CUI pattern."""
import re
text_upper = text.upper()
patterns = [
r'CIF[:\s]*RO?\d{6,10}',
r'CUI[:\s]*RO?\d{6,10}',
r'C\.?I\.?F\.?[:\s]*RO?\d{6,10}',
]
for pattern in patterns:
if re.search(pattern, text_upper):
return True
return bool(re.search(r'RO\d{7,10}', text_upper))
def _has_total_pattern(self, text: str) -> bool:
"""Check if TOTAL is properly recognized (not truncated to BTOTAL/OTAL)."""
import re
text_upper = text.upper()
return bool(re.search(r'(^|\s)TOTAL\s', text_upper, re.MULTILINE))
def recognize_with_boxes(self, image: np.ndarray, psm: int = 4) -> OCRResult:
"""
Recognition with bounding boxes (slower, for debugging/visualization).
Use this only when you need box coordinates.
For normal OCR, use recognize() which is faster.
Args:
image: Grayscale image
psm: Page segmentation mode (default: 4 for receipts)
Returns:
OCRResult with text, confidence, and boxes
"""
if not TESSERACT_AVAILABLE:
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
# Ensure grayscale
if len(image.shape) == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
image = self._ensure_correct_polarity(image)
config = f'--psm {psm} --oem 3 --dpi 300 -l ron'
try:
text = pytesseract.image_to_string(image, config=config)
data = pytesseract.image_to_data(
image, config=config, output_type=pytesseract.Output.DICT
)
confidences = [int(c) for c in data['conf'] if int(c) > 0]
avg_conf = sum(confidences) / len(confidences) / 100 if confidences else 0.0
boxes = []
for i in range(len(data['text'])):
if data['text'][i].strip() and int(data['conf'][i]) > 0:
boxes.append({
'text': data['text'][i],
'confidence': int(data['conf'][i]) / 100,
'box': [data['left'][i], data['top'][i], data['width'][i], data['height'][i]]
})
return OCRResult(text=text, confidence=avg_conf, boxes=boxes, engine="tesseract")
except Exception as e:
logger.warning(f"[Tesseract] recognize_with_boxes error: {e}")
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
def _ensure_correct_polarity(self, image: np.ndarray) -> np.ndarray:
"""
Ensure image has black text on white background.
Receipts should have dark text on light background.
If image is inverted (light text on dark), invert it.
Detection method:
- Calculate mean pixel value
- If mean < 127, image is mostly dark (inverted)
- Invert to correct polarity
Args:
image: Grayscale image
Returns:
Polarity-corrected image
"""
mean_value = np.mean(image)
if mean_value < 127:
# Image is mostly dark = inverted (white text on black)
logger.debug(f"[TesseractEngine] Detected inverted polarity (mean={mean_value:.1f}), correcting...")
return 255 - image
return image
def recognize_numbers_only(self, image: np.ndarray) -> OCRResult:
"""
OCR optimized for numeric content (amounts, totals).
Uses character whitelist to reduce errors on numbers.
Args:
image: Preprocessed grayscale image
Returns:
OCRResult with numeric text
"""
if not TESSERACT_AVAILABLE:
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
# Ensure grayscale
if len(image.shape) == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
# Fix polarity
image = self._ensure_correct_polarity(image)
# Config for numbers only
# Whitelist: digits, comma, period, space, RON, LEI
config = '--psm 6 --oem 1 -c tessedit_char_whitelist=0123456789.,- '
try:
text = pytesseract.image_to_string(image, config=config)
data = pytesseract.image_to_data(
image,
config=config,
output_type=pytesseract.Output.DICT
)
confidences = [int(c) for c in data['conf'] if int(c) > 0]
avg_conf = sum(confidences) / len(confidences) / 100 if confidences else 0.0
return OCRResult(
text=text.strip(),
confidence=avg_conf,
boxes=[],
engine="tesseract-numeric"
)
except Exception as e:
logger.error(f"[TesseractEngine] Numeric OCR error: {e}")
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
def recognize_cif_optimized(self, image: np.ndarray) -> Optional[str]:
"""
Optimized CIF extraction using multi-strategy approach.
BENCHMARK RESULTS (from test_critical_fields.py):
- digit_opt_dpi200: 33% accuracy (best)
- digit_whitelist: Works well on specific receipts
- basic_ron_eng: Good backup
Strategy:
1. Try digit-optimized preprocessing (2x scale + Otsu)
2. Try character whitelist (RO + digits only)
3. Try standard ron+eng config
4. Return best match based on CIF pattern validation
Args:
image: Input image (RGB from pdf2image or BGR from OpenCV)
Returns:
Extracted CIF string (e.g., "RO10562600") or None
"""
import re
if not TESSERACT_AVAILABLE:
return None
# Ensure grayscale
if len(image.shape) == 3:
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
else:
gray = image.copy()
# Extract top 35% of image (where CIF is typically found)
height = gray.shape[0]
top_region = gray[:int(height * 0.35), :]
candidates = []
# Strategy 1: Digit-optimized preprocessing (best performer: 33% accuracy)
try:
# Scale up 2x + Otsu binarization
scaled = cv2.resize(top_region, None, fx=2.0, fy=2.0, interpolation=cv2.INTER_CUBIC)
clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
enhanced = clahe.apply(scaled)
_, binary = cv2.threshold(enhanced, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
if np.mean(binary) < 127:
binary = 255 - binary
text = pytesseract.image_to_string(binary, config='--psm 6 --oem 3 -l ron')
cif = self._extract_cif_from_text(text)
if cif:
candidates.append(('digit_opt', cif))
except Exception as e:
logger.debug(f"[TesseractEngine] digit_opt strategy failed: {e}")
# Strategy 2: Character whitelist (RO + digits only)
try:
# Add padding
padded = cv2.copyMakeBorder(top_region, 40, 40, 40, 40, cv2.BORDER_CONSTANT, value=255)
scaled = cv2.resize(padded, None, fx=2.0, fy=2.0, interpolation=cv2.INTER_CUBIC)
config = '--psm 6 --oem 1 -c tessedit_char_whitelist=0123456789ROro'
text = pytesseract.image_to_string(scaled, config=config)
cif = self._extract_cif_from_text(text)
if cif:
candidates.append(('whitelist', cif))
except Exception as e:
logger.debug(f"[TesseractEngine] whitelist strategy failed: {e}")
# Strategy 3: Standard ron+eng config (good backup)
try:
padded = cv2.copyMakeBorder(top_region, 40, 40, 40, 40, cv2.BORDER_CONSTANT, value=255)
clahe = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(8, 8))
enhanced = clahe.apply(padded)
text = pytesseract.image_to_string(enhanced, config='--psm 6 --oem 3 -l ron+eng')
cif = self._extract_cif_from_text(text)
if cif:
candidates.append(('ron_eng', cif))
except Exception as e:
logger.debug(f"[TesseractEngine] ron_eng strategy failed: {e}")
if not candidates:
return None
# Log all candidates
for strategy, cif in candidates:
logger.debug(f"[TesseractEngine] CIF candidate from {strategy}: {cif}")
# Use majority voting if multiple strategies agree
from collections import Counter
cif_counts = Counter(cif for _, cif in candidates)
most_common_cif, count = cif_counts.most_common(1)[0]
if count > 1:
# Multiple strategies agree
logger.info(f"[TesseractEngine] CIF extracted (majority {count} strategies): {most_common_cif}")
return most_common_cif
# No agreement - prefer digit_opt strategy (33% accuracy in benchmarks)
for strategy, cif in candidates:
if strategy == 'digit_opt':
logger.info(f"[TesseractEngine] CIF extracted via digit_opt (preferred): {cif}")
return cif
# Fallback to first candidate
strategy, cif = candidates[0]
logger.info(f"[TesseractEngine] CIF extracted via {strategy}: {cif}")
return cif
def _extract_cif_from_text(self, text: str) -> Optional[str]:
"""Extract CIF/CUI from OCR text."""
import re
text_upper = text.upper().replace(' ', '')
patterns = [
r'CIF[:\s]*R?O?(\d{6,10})',
r'CUI[:\s]*R?O?(\d{6,10})',
r'C\.?I\.?F\.?[:\s]*R?O?(\d{6,10})',
r'RO(\d{7,10})',
r'R\.?O\.?[\s:]*(\d{6,10})',
]
for pattern in patterns:
match = re.search(pattern, text_upper)
if match:
digits = match.group(1).lstrip('0') or '0'
return f"RO{digits}"
return None
@staticmethod
def validate_romanian_cif(cif: str) -> bool:
"""
Validate Romanian CIF/CUI using checksum algorithm.
Romanian CIF format: RO + 2-10 digits
The last digit is a control digit calculated using modulo 11.
Algorithm:
1. Multiply each digit by corresponding weight (from right to left: 2,3,4,5,6,7,2,3,4,5)
2. Sum all products
3. Remainder of sum / 11 is the control digit
4. If remainder is 10, control digit is 0
Args:
cif: CIF string (e.g., "RO10562600", "10562600")
Returns:
True if CIF is valid, False otherwise
"""
# Remove RO prefix and spaces
cif = cif.upper().replace(' ', '').replace('RO', '')
# Must be 2-10 digits
if not cif.isdigit() or len(cif) < 2 or len(cif) > 10:
return False
# Weights for checksum calculation (right to left)
weights = [2, 3, 4, 5, 6, 7, 2, 3, 4, 5]
# Pad with zeros on the left to make it 10 digits
cif_padded = cif.zfill(10)
# Calculate checksum (excluding last digit which is control)
total = 0
for i in range(9):
total += int(cif_padded[i]) * weights[i]
# Control digit
control = total % 11
if control == 10:
control = 0
# Compare with last digit
return int(cif_padded[9]) == control
@staticmethod
def is_available() -> bool:
"""Check if Tesseract is available."""
if not TESSERACT_AVAILABLE:
return False
try:
pytesseract.get_tesseract_version()
return True
except Exception:
return False
@staticmethod
def get_version() -> Optional[str]:
"""Get Tesseract version string."""
if not TESSERACT_AVAILABLE:
return None
try:
return str(pytesseract.get_tesseract_version())
except Exception:
return None