fix telegram
This commit is contained in:
@@ -0,0 +1,42 @@
|
||||
"""
|
||||
OCR Services Module
|
||||
|
||||
Provides persistent OCR worker pool with job queue for efficient processing.
|
||||
|
||||
Components:
|
||||
- ocr_worker_pool: Manages ProcessPoolExecutor with persistent PaddleOCR
|
||||
- job_queue: SQLite-based job queue for async processing
|
||||
- job_worker: Background task that processes queued jobs
|
||||
- tesseract_engine: Optimized Tesseract with multi-PSM and polarity fix
|
||||
|
||||
Architecture:
|
||||
FastAPI → job_queue.create_job() → SQLite
|
||||
↓
|
||||
job_worker loop → ocr_worker_pool.submit_task() → Worker Process
|
||||
↓
|
||||
PaddleOCR/Tesseract
|
||||
"""
|
||||
|
||||
from .ocr_worker_pool import ocr_worker_pool, OCRWorkerPool
|
||||
from .job_queue import job_queue, OCRJobQueue, OCRJob, OCRJobStatus
|
||||
from .job_worker import start_job_worker, stop_job_worker
|
||||
from .tesseract_engine import TesseractEngine
|
||||
from .validation import OCRValidationEngine
|
||||
|
||||
# Public API of the OCR services package, grouped by the submodule that
# provides each name.
__all__ = [
    # Worker pool (ocr_worker_pool)
    "ocr_worker_pool", "OCRWorkerPool",
    # Job queue (job_queue)
    "job_queue", "OCRJobQueue", "OCRJob", "OCRJobStatus",
    # Job worker (job_worker)
    "start_job_worker", "stop_job_worker",
    # Engines (tesseract_engine)
    "TesseractEngine",
    # Validation (validation)
    "OCRValidationEngine",
]
|
||||
@@ -0,0 +1,653 @@
|
||||
"""
|
||||
SQLite Job Queue Manager for OCR Processing
|
||||
|
||||
Provides async job queue for OCR requests:
|
||||
- Jobs are stored in SQLite for persistence
|
||||
- Queue position and time estimation
|
||||
- Automatic expiration after 24 hours
|
||||
- Statistics for monitoring
|
||||
|
||||
Schema:
|
||||
ocr_jobs (
|
||||
id TEXT PRIMARY KEY, -- UUID
|
||||
status TEXT NOT NULL, -- pending, processing, completed, failed, cancelled
|
||||
file_path TEXT NOT NULL, -- Path to uploaded file
|
||||
mime_type TEXT NOT NULL,
|
||||
engine TEXT DEFAULT 'doctr_plus',
|
||||
created_at TIMESTAMP,
|
||||
started_at TIMESTAMP,
|
||||
completed_at TIMESTAMP,
|
||||
result_json TEXT, -- JSON extraction result
|
||||
error_message TEXT,
|
||||
processing_time_ms INTEGER, -- Total job time (started_at to completed_at)
|
||||
ocr_time_ms INTEGER, -- Actual OCR engine processing time
|
||||
created_by TEXT, -- Username
|
||||
original_filename TEXT,
|
||||
expires_at TIMESTAMP,
|
||||
batch_id INTEGER, -- Foreign key to batch_uploads (for bulk processing)
|
||||
file_hash TEXT -- SHA-256 hash for duplicate detection (US-007)
|
||||
)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from decimal import Decimal
|
||||
|
||||
|
||||
class DecimalEncoder(json.JSONEncoder):
    """JSON encoder that serialises ``Decimal`` values as floats."""

    def default(self, obj):
        """Return ``float(obj)`` for Decimals; defer everything else to the base class."""
        if not isinstance(obj, Decimal):
            # The base implementation raises TypeError for unsupported types.
            return super().default(obj)
        return float(obj)
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiosqlite
|
||||
|
||||
logger = logging.getLogger(__name__)

# Default paths: the queue lives in <root>/data/ocr_queue, where <root> is
# five directory levels above this file.
_QUEUE_ROOT = Path(__file__).parent.parent.parent.parent.parent
DEFAULT_QUEUE_DIR = _QUEUE_ROOT / "data" / "ocr_queue"
DEFAULT_DB_PATH = DEFAULT_QUEUE_DIR / "ocr_jobs.db"
DEFAULT_FILES_DIR = DEFAULT_QUEUE_DIR / "files"

# Job expiration (hours after creation)
JOB_EXPIRY_HOURS = 24

# SQLite busy timeout (milliseconds) - prevents "database is locked" errors
SQLITE_BUSY_TIMEOUT_MS = 5000
|
||||
|
||||
|
||||
class OCRJobStatus(str, Enum):
    """Lifecycle states of an OCR job; each member's value is the string stored in SQLite."""

    # Waiting in the queue
    pending = "pending"
    # Claimed by a worker
    processing = "processing"
    # Terminal states
    completed = "completed"
    failed = "failed"
    cancelled = "cancelled"
|
||||
|
||||
|
||||
@dataclass
class OCRJob:
    """In-memory representation of one row of the ``ocr_jobs`` table."""

    # Required fields
    id: str
    status: OCRJobStatus
    file_path: str
    mime_type: str

    # Optional fields (defaults mirror the table defaults)
    engine: str = "doctr_plus"
    created_at: Optional[datetime] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    result_json: Optional[str] = None
    error_message: Optional[str] = None
    processing_time_ms: Optional[int] = None  # Total job time (started_at to completed_at)
    ocr_time_ms: Optional[int] = None  # Actual OCR engine processing time
    created_by: Optional[str] = None
    original_filename: Optional[str] = None
    expires_at: Optional[datetime] = None
    batch_id: Optional[int] = None  # Links to batch_uploads table for bulk processing
    file_hash: Optional[str] = None  # SHA-256 hash for duplicate detection (US-007)

    @property
    def queue_wait_ms(self) -> Optional[int]:
        """Milliseconds between ``created_at`` and ``started_at``, or None if either is unset."""
        if not (self.created_at and self.started_at):
            return None
        waited = self.started_at - self.created_at
        return int(waited.total_seconds() * 1000)

    @property
    def result(self) -> Optional[Dict]:
        """Decode ``result_json``; None when it is empty or not valid JSON."""
        if not self.result_json:
            return None
        try:
            parsed = json.loads(self.result_json)
        except json.JSONDecodeError:
            return None
        return parsed
|
||||
|
||||
|
||||
class OCRJobQueue:
    """
    SQLite-based job queue for OCR processing.

    Provides async methods for job management with position
    tracking and time estimation. Every operation opens its own
    short-lived aiosqlite connection with a busy timeout applied.
    """

    # Columns introduced after the first schema version, as (name, SQL type).
    # initialize() adds whichever of these an existing database is missing.
    _MIGRATION_COLUMNS = (
        ("ocr_time_ms", "INTEGER"),
        ("batch_id", "INTEGER"),
        ("file_hash", "TEXT"),
    )

    def __init__(
        self,
        db_path: Optional[Path] = None,
        files_dir: Optional[Path] = None
    ):
        """
        Initialize job queue.

        Args:
            db_path: Path to SQLite database (default: data/ocr_queue/ocr_jobs.db)
            files_dir: Path to files directory (default: data/ocr_queue/files/)
        """
        self.db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
        self.files_dir = Path(files_dir) if files_dir else DEFAULT_FILES_DIR
        self._lock = asyncio.Lock()
        self._initialized = False

    async def initialize(self) -> None:
        """
        Initialize database and directories.

        Creates the SQLite database, table and indexes if they don't exist,
        upgrades older databases by adding any missing columns, and creates
        the directory for uploaded files. Subsequent calls are no-ops.

        Raises:
            sqlite3.OperationalError: If a schema migration fails for any
                reason other than the column already being present.
        """
        if self._initialized:
            return

        # Local import: only needed to recognise the migration race below.
        import sqlite3

        # Create directories
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self.files_dir.mkdir(parents=True, exist_ok=True)

        # Create database and tables
        async with aiosqlite.connect(str(self.db_path)) as db:
            # Enable WAL mode for better concurrency and set busy timeout
            await db.execute("PRAGMA journal_mode=WAL")
            await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")

            await db.execute('''
                CREATE TABLE IF NOT EXISTS ocr_jobs (
                    id TEXT PRIMARY KEY,
                    status TEXT NOT NULL DEFAULT 'pending',
                    file_path TEXT NOT NULL,
                    mime_type TEXT NOT NULL,
                    engine TEXT DEFAULT 'doctr_plus',
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    started_at TIMESTAMP,
                    completed_at TIMESTAMP,
                    result_json TEXT,
                    error_message TEXT,
                    processing_time_ms INTEGER,
                    ocr_time_ms INTEGER,
                    created_by TEXT,
                    original_filename TEXT,
                    expires_at TIMESTAMP,
                    batch_id INTEGER,
                    file_hash TEXT
                )
            ''')

            # Migrations: inspect the live schema and add only what is missing,
            # instead of blindly running ALTER TABLE and swallowing every error.
            async with db.execute("PRAGMA table_info(ocr_jobs)") as cursor:
                # table_info rows are (cid, name, type, notnull, dflt_value, pk)
                existing_columns = {row[1] for row in await cursor.fetchall()}

            for column, sql_type in self._MIGRATION_COLUMNS:
                if column in existing_columns:
                    continue
                try:
                    # Identifiers come from the class constant above, never from input.
                    await db.execute(f'ALTER TABLE ocr_jobs ADD COLUMN {column} {sql_type}')
                except sqlite3.OperationalError as exc:
                    # Another initializer may have added the column between our
                    # schema check and the ALTER; anything else is a real failure.
                    if "duplicate column" not in str(exc).lower():
                        raise
                    continue
                logger.info(f"[OCRJobQueue] Added {column} column to existing table")

            # Index for efficient queue queries
            await db.execute('''
                CREATE INDEX IF NOT EXISTS idx_ocr_jobs_status
                ON ocr_jobs(status, created_at)
            ''')

            # Index for expiration cleanup
            await db.execute('''
                CREATE INDEX IF NOT EXISTS idx_ocr_jobs_expires
                ON ocr_jobs(expires_at)
            ''')

            await db.commit()

        self._initialized = True
        logger.info(f"[OCRJobQueue] Initialized: db={self.db_path}, files={self.files_dir}")

    async def create_job(
        self,
        file_bytes: bytes,
        mime_type: str,
        engine: str = "doctr_plus",
        username: Optional[str] = None,
        original_filename: Optional[str] = None,
        batch_id: Optional[int] = None,
        file_hash: Optional[str] = None
    ) -> OCRJob:
        """
        Create a new OCR job.

        Saves file to disk and creates database record. If the database
        insert fails, the file that was just written is removed again so no
        orphan is left behind, and the original exception is re-raised.

        Args:
            file_bytes: Raw file bytes
            mime_type: MIME type of file
            engine: OCR engine ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
            username: Username of requester
            original_filename: Original filename from upload
            batch_id: Optional batch ID for bulk upload processing
            file_hash: Optional SHA-256 hash for duplicate detection (US-007)

        Returns:
            Created OCRJob instance
        """
        await self.initialize()

        # Generate job ID
        job_id = str(uuid.uuid4())

        # Determine file extension (unknown MIME types are stored as .bin)
        ext_map = {
            'image/jpeg': '.jpg',
            'image/png': '.png',
            'application/pdf': '.pdf',
        }
        ext = ext_map.get(mime_type, '.bin')

        # Save file
        file_path = self.files_dir / f"{job_id}{ext}"
        with open(file_path, 'wb') as f:
            f.write(file_bytes)

        # Calculate expiration
        now = datetime.utcnow()
        expires_at = now + timedelta(hours=JOB_EXPIRY_HOURS)

        # Insert job record
        try:
            async with aiosqlite.connect(str(self.db_path)) as db:
                await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
                await db.execute('''
                    INSERT INTO ocr_jobs (
                        id, status, file_path, mime_type, engine,
                        created_at, created_by, original_filename, expires_at, batch_id, file_hash
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ''', (
                    job_id, OCRJobStatus.pending.value, str(file_path), mime_type, engine,
                    now.isoformat(), username, original_filename, expires_at.isoformat(), batch_id, file_hash
                ))
                await db.commit()
        except Exception:
            # No row references this file, so cleanup_expired() would never
            # find it. Remove it now (best effort) and surface the DB error.
            try:
                file_path.unlink()
            except OSError as cleanup_error:
                logger.warning(f"[OCRJobQueue] Failed to delete file {file_path}: {cleanup_error}")
            raise

        logger.info(f"[OCRJobQueue] Created job {job_id}: engine={engine}, file={file_path.name}, batch_id={batch_id}")

        return OCRJob(
            id=job_id,
            status=OCRJobStatus.pending,
            file_path=str(file_path),
            mime_type=mime_type,
            engine=engine,
            created_at=now,
            created_by=username,
            original_filename=original_filename,
            expires_at=expires_at,
            batch_id=batch_id,
            file_hash=file_hash
        )

    async def get_job(self, job_id: str) -> Optional[OCRJob]:
        """
        Get job by ID.

        Args:
            job_id: Job UUID

        Returns:
            OCRJob or None if not found
        """
        await self.initialize()

        async with aiosqlite.connect(str(self.db_path)) as db:
            await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
            db.row_factory = aiosqlite.Row
            async with db.execute(
                'SELECT * FROM ocr_jobs WHERE id = ?',
                (job_id,)
            ) as cursor:
                row = await cursor.fetchone()
                if row:
                    return self._row_to_job(row)
        return None

    async def get_queue_position(self, job_id: str) -> Optional[int]:
        """
        Get position in queue for a pending job.

        Args:
            job_id: Job UUID

        Returns:
            Queue position (1 = next to process) or None if not pending
        """
        await self.initialize()

        async with aiosqlite.connect(str(self.db_path)) as db:
            await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")

            # Check if job is pending
            async with db.execute(
                'SELECT status, created_at FROM ocr_jobs WHERE id = ?',
                (job_id,)
            ) as cursor:
                row = await cursor.fetchone()
                if not row or row[0] != OCRJobStatus.pending.value:
                    return None
                job_created_at = row[1]

            # Count jobs ahead in queue (created before this job)
            async with db.execute('''
                SELECT COUNT(*) FROM ocr_jobs
                WHERE status = 'pending' AND created_at < ?
            ''', (job_created_at,)) as cursor:
                count = await cursor.fetchone()
                return (count[0] + 1) if count else 1

    async def get_next_pending(self) -> Optional[OCRJob]:
        """
        Get the next pending job (oldest first) and mark it as processing.

        The claim is a conditional UPDATE (``... AND status = 'pending'``).
        If that UPDATE changes no row, something else claimed the job between
        the SELECT and the UPDATE, and None is returned so the job is not
        processed twice. The asyncio lock serializes callers on this instance.

        Returns:
            Next OCRJob to process, or None if the queue is empty or the
            candidate job was claimed elsewhere
        """
        await self.initialize()

        now = datetime.utcnow()

        async with self._lock:  # Serialize access to prevent race conditions
            async with aiosqlite.connect(str(self.db_path)) as db:
                await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
                db.row_factory = aiosqlite.Row

                # Get the next pending job
                async with db.execute('''
                    SELECT * FROM ocr_jobs
                    WHERE status = 'pending'
                    ORDER BY created_at ASC
                    LIMIT 1
                ''') as cursor:
                    row = await cursor.fetchone()
                    if not row:
                        return None

                job_id = row['id']

                # Claim it: only succeeds while the row is still pending
                claim = await db.execute('''
                    UPDATE ocr_jobs
                    SET status = 'processing', started_at = ?
                    WHERE id = ? AND status = 'pending'
                ''', (now.isoformat(), job_id))
                await db.commit()

                if claim.rowcount == 0:
                    # Lost the race; the caller simply polls again.
                    return None

                # Fetch the updated job
                async with db.execute(
                    'SELECT * FROM ocr_jobs WHERE id = ?',
                    (job_id,)
                ) as cursor:
                    updated_row = await cursor.fetchone()
                    if updated_row:
                        return self._row_to_job(updated_row)

        return None

    async def update_status(
        self,
        job_id: str,
        status: OCRJobStatus,
        result: Optional[Dict] = None,
        error: Optional[str] = None,
        processing_time_ms: Optional[int] = None,
        ocr_time_ms: Optional[int] = None
    ) -> bool:
        """
        Update job status.

        Which columns are written depends on the new status: ``processing``
        sets started_at; ``completed`` sets completed_at, result_json and
        both timings; ``failed`` sets completed_at, error_message and both
        timings; any other status only changes the status column.

        Args:
            job_id: Job UUID
            status: New status
            result: Extraction result dict (for completed)
            error: Error message (for failed)
            processing_time_ms: Total job processing time (started_at to completed_at)
            ocr_time_ms: Actual OCR engine processing time

        Returns:
            True if a row was updated
        """
        await self.initialize()

        now = datetime.utcnow()
        result_json = json.dumps(result, cls=DecimalEncoder) if result else None

        # Build update query based on status
        if status == OCRJobStatus.processing:
            query = '''
                UPDATE ocr_jobs
                SET status = ?, started_at = ?
                WHERE id = ?
            '''
            params = (status.value, now.isoformat(), job_id)

        elif status == OCRJobStatus.completed:
            query = '''
                UPDATE ocr_jobs
                SET status = ?, completed_at = ?, result_json = ?, processing_time_ms = ?, ocr_time_ms = ?
                WHERE id = ?
            '''
            params = (status.value, now.isoformat(), result_json, processing_time_ms, ocr_time_ms, job_id)

        elif status == OCRJobStatus.failed:
            query = '''
                UPDATE ocr_jobs
                SET status = ?, completed_at = ?, error_message = ?, processing_time_ms = ?, ocr_time_ms = ?
                WHERE id = ?
            '''
            params = (status.value, now.isoformat(), error, processing_time_ms, ocr_time_ms, job_id)

        else:
            query = 'UPDATE ocr_jobs SET status = ? WHERE id = ?'
            params = (status.value, job_id)

        async with aiosqlite.connect(str(self.db_path)) as db:
            await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
            cursor = await db.execute(query, params)
            await db.commit()
            return cursor.rowcount > 0

    async def get_average_processing_time(self) -> float:
        """
        Calculate average processing time from recent completed jobs.

        Uses the 50 most recently completed jobs that have a recorded time.

        Returns:
            Average time in seconds (default 7.0 if no data)
        """
        await self.initialize()

        async with aiosqlite.connect(str(self.db_path)) as db:
            await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
            async with db.execute('''
                SELECT AVG(processing_time_ms)
                FROM (
                    SELECT processing_time_ms FROM ocr_jobs
                    WHERE status = 'completed' AND processing_time_ms IS NOT NULL
                    ORDER BY completed_at DESC
                    LIMIT 50
                )
            ''') as cursor:
                row = await cursor.fetchone()
                if row and row[0]:
                    return row[0] / 1000.0  # Convert ms to seconds
        return 7.0  # Default estimate

    async def count_pending(self) -> int:
        """Count pending jobs in queue."""
        await self.initialize()

        async with aiosqlite.connect(str(self.db_path)) as db:
            await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
            async with db.execute(
                'SELECT COUNT(*) FROM ocr_jobs WHERE status = ?',
                (OCRJobStatus.pending.value,)
            ) as cursor:
                row = await cursor.fetchone()
                return row[0] if row else 0

    async def count_processing(self) -> int:
        """Count currently processing jobs."""
        await self.initialize()

        async with aiosqlite.connect(str(self.db_path)) as db:
            await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
            async with db.execute(
                'SELECT COUNT(*) FROM ocr_jobs WHERE status = ?',
                (OCRJobStatus.processing.value,)
            ) as cursor:
                row = await cursor.fetchone()
                return row[0] if row else 0

    async def cleanup_expired(self) -> int:
        """
        Delete expired jobs and their files.

        A job is expired when its expires_at is earlier than the current UTC
        time. File deletion is best effort: a failure is logged and the
        database row is still removed.

        Returns:
            Number of jobs deleted
        """
        await self.initialize()

        now = datetime.utcnow()
        deleted = 0

        async with aiosqlite.connect(str(self.db_path)) as db:
            await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
            db.row_factory = aiosqlite.Row

            # Get expired jobs
            async with db.execute('''
                SELECT id, file_path FROM ocr_jobs
                WHERE expires_at < ?
            ''', (now.isoformat(),)) as cursor:
                rows = await cursor.fetchall()

            for row in rows:
                # Delete file
                file_path = Path(row['file_path'])
                if file_path.exists():
                    try:
                        file_path.unlink()
                    except Exception as e:
                        logger.warning(f"[OCRJobQueue] Failed to delete file {file_path}: {e}")

                # Delete job record
                await db.execute('DELETE FROM ocr_jobs WHERE id = ?', (row['id'],))
                deleted += 1

            await db.commit()

        if deleted > 0:
            logger.info(f"[OCRJobQueue] Cleaned up {deleted} expired job(s)")

        return deleted

    async def cleanup_job_file(self, job_id: str) -> bool:
        """
        Delete the file associated with a job.

        Called after processing to free disk space. The job record itself
        is left untouched.

        Args:
            job_id: Job UUID

        Returns:
            True if file deleted; False if the job or file does not exist
            or the deletion failed
        """
        job = await self.get_job(job_id)
        if job:
            file_path = Path(job.file_path)
            if file_path.exists():
                try:
                    file_path.unlink()
                    return True
                except Exception as e:
                    logger.warning(f"[OCRJobQueue] Failed to delete file {file_path}: {e}")
        return False

    async def get_queue_stats(self) -> Dict[str, Any]:
        """
        Get queue statistics.

        Returns:
            Dict with pending, processing, completed and failed counts plus
            average_time_seconds. Statuses without a key in the dict (such
            as cancelled) are not reported.
        """
        await self.initialize()

        stats = {
            "pending": 0,
            "processing": 0,
            "completed": 0,
            "failed": 0,
            "average_time_seconds": 0.0,
        }

        async with aiosqlite.connect(str(self.db_path)) as db:
            await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
            async with db.execute('''
                SELECT status, COUNT(*) as count
                FROM ocr_jobs
                GROUP BY status
            ''') as cursor:
                rows = await cursor.fetchall()
                for row in rows:
                    if row[0] in stats:
                        stats[row[0]] = row[1]

        stats["average_time_seconds"] = await self.get_average_processing_time()
        return stats

    def _row_to_job(self, row: aiosqlite.Row) -> OCRJob:
        """Convert database row to OCRJob."""
        def parse_datetime(val):
            # Timestamps are stored as ISO-8601 text; unparseable values become None.
            if val:
                try:
                    return datetime.fromisoformat(val)
                except (ValueError, TypeError):
                    return None
            return None

        return OCRJob(
            id=row['id'],
            status=OCRJobStatus(row['status']),
            file_path=row['file_path'],
            mime_type=row['mime_type'],
            engine=row['engine'] or 'doctr_plus',
            created_at=parse_datetime(row['created_at']),
            started_at=parse_datetime(row['started_at']),
            completed_at=parse_datetime(row['completed_at']),
            result_json=row['result_json'],
            error_message=row['error_message'],
            processing_time_ms=row['processing_time_ms'],
            # Columns added by migrations are read defensively.
            ocr_time_ms=row['ocr_time_ms'] if 'ocr_time_ms' in row.keys() else None,
            created_by=row['created_by'],
            original_filename=row['original_filename'],
            expires_at=parse_datetime(row['expires_at']),
            batch_id=row['batch_id'] if 'batch_id' in row.keys() else None,
            file_hash=row['file_hash'] if 'file_hash' in row.keys() else None,
        )
|
||||
|
||||
|
||||
# Singleton instance
|
||||
job_queue = OCRJobQueue()
|
||||
@@ -0,0 +1,665 @@
|
||||
"""
|
||||
OCR Job Worker - Background Task for Queue Processing
|
||||
|
||||
Runs as an asyncio background task in FastAPI.
|
||||
Continuously polls the job queue and processes OCR requests IN PARALLEL.
|
||||
|
||||
Architecture:
|
||||
FastAPI startup
|
||||
↓
|
||||
start_job_worker()
|
||||
↓
|
||||
asyncio.create_task(_job_worker_loop())
|
||||
↓
|
||||
while True:
|
||||
# Process up to OCR_WORKERS jobs concurrently
|
||||
jobs = get_pending_jobs(limit=available_slots)
|
||||
for job in jobs:
|
||||
asyncio.create_task(_process_job(job))
|
||||
await asyncio.sleep(0.1)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, Set
|
||||
|
||||
from .job_queue import job_queue, OCRJobStatus, OCRJob
|
||||
from .ocr_worker_pool import ocr_worker_pool
|
||||
from backend.modules.data_entry.schemas.ocr import ExtractionData
|
||||
|
||||
logger = logging.getLogger(__name__)

# Module-level worker state (assigned by the worker functions in this module)
_job_worker_task: Optional[asyncio.Task] = None
_cleanup_task: Optional[asyncio.Task] = None
_shutdown_event: Optional[asyncio.Event] = None
# Job tasks currently in flight
_active_tasks: Set[asyncio.Task] = set()
# Caps the number of jobs processed at once
_concurrency_semaphore: Optional[asyncio.Semaphore] = None

# Configuration (all values in seconds)
# How often to check for new jobs (faster for parallel)
POLL_INTERVAL_SECONDS = 0.1
# Clean expired jobs every hour
CLEANUP_INTERVAL_SECONDS = 3600
# Max time for OCR processing
OCR_TIMEOUT_SECONDS = 120
|
||||
|
||||
|
||||
async def _job_worker_loop() -> None:
    """
    Main worker loop - processes jobs from queue IN PARALLEL.

    Runs continuously until shutdown. Limits concurrent jobs to the
    OCR_WORKERS environment variable (default 2). Launches jobs as
    background tasks without waiting for completion, and reaps finished
    tasks at the start of every iteration.

    The loop stops when the shutdown event is set, when the loop task itself
    is cancelled, or after 10 consecutive errors. On exit it waits for all
    still-active job tasks to finish.
    """
    global _shutdown_event, _active_tasks, _concurrency_semaphore

    # Get max concurrent jobs from env (matches worker pool size)
    max_concurrent = int(os.getenv('OCR_WORKERS', '2'))
    _concurrency_semaphore = asyncio.Semaphore(max_concurrent)
    _active_tasks = set()

    logger.info(f"[JobWorker] Starting PARALLEL worker loop (max_concurrent={max_concurrent})...")
    _shutdown_event = asyncio.Event()

    consecutive_errors = 0
    max_consecutive_errors = 10

    while not _shutdown_event.is_set():
        try:
            # Clean up completed tasks
            done_tasks = {t for t in _active_tasks if t.done()}
            for task in done_tasks:
                _active_tasks.discard(task)

                # result() on a cancelled task raises CancelledError, which is
                # not an Exception subclass on Python 3.8+. It would escape
                # the handler below and be mistaken for cancellation of this
                # loop, so handle a cancelled job task explicitly.
                if task.cancelled():
                    logger.warning("[JobWorker] Task was cancelled")
                    continue

                # Check for exceptions
                try:
                    task.result()
                except Exception as e:
                    logger.error(f"[JobWorker] Task failed: {e}")

            # Check if we have capacity for more jobs
            active_count = len(_active_tasks)
            available_slots = max_concurrent - active_count

            if available_slots > 0:
                # Get next pending job
                job = await job_queue.get_next_pending()

                if job:
                    consecutive_errors = 0
                    # Launch job processing as background task
                    task = asyncio.create_task(_process_job_with_semaphore(job))
                    _active_tasks.add(task)
                    logger.debug(f"[JobWorker] Launched job {job.id} (active={len(_active_tasks)}/{max_concurrent})")
                else:
                    # No pending jobs - wait briefly (wakes early on shutdown)
                    try:
                        await asyncio.wait_for(
                            _shutdown_event.wait(),
                            timeout=POLL_INTERVAL_SECONDS
                        )
                        if _shutdown_event.is_set():
                            break
                    except asyncio.TimeoutError:
                        pass
            else:
                # At capacity - wait for a slot to free up
                await asyncio.sleep(POLL_INTERVAL_SECONDS)

        except asyncio.CancelledError:
            logger.info("[JobWorker] Worker loop cancelled")
            break

        except Exception as e:
            consecutive_errors += 1
            logger.error(f"[JobWorker] Error in worker loop ({consecutive_errors}/{max_consecutive_errors}): {e}")

            if consecutive_errors >= max_consecutive_errors:
                logger.error("[JobWorker] Too many consecutive errors, stopping worker")
                break

            # Linear backoff, capped at 30 seconds
            await asyncio.sleep(min(consecutive_errors * 2, 30))

    # Wait for active tasks to complete on shutdown
    if _active_tasks:
        logger.info(f"[JobWorker] Waiting for {len(_active_tasks)} active tasks to complete...")
        await asyncio.gather(*_active_tasks, return_exceptions=True)

    logger.info("[JobWorker] Worker loop stopped")
|
||||
|
||||
|
||||
async def _process_job_with_semaphore(job: OCRJob) -> None:
    """
    Run ``_process_job`` while holding one slot of the concurrency semaphore.

    The slot is taken before processing starts and is always given back
    afterwards, even if processing raises, so no more jobs run at once
    than the semaphore allows.
    """
    await _concurrency_semaphore.acquire()
    try:
        await _process_job(job)
    finally:
        _concurrency_semaphore.release()
|
||||
|
||||
|
||||
async def _process_job(job: OCRJob) -> None:
    """
    Process a single OCR job.

    Reads file, submits to worker pool, updates job status,
    and saves metrics for analytics. For jobs that belong to a batch the
    extracted receipt is auto-saved as well. The job's file is removed
    afterwards regardless of the outcome.

    Any exception raised while processing is caught, recorded on the job as
    a failure, and reported through the metrics.

    Args:
        job: OCRJob to process
    """
    logger.info(f"[JobWorker] Processing job {job.id}: engine={job.engine}, file={Path(job.file_path).name}")
    start_time = time.time()
    # Fallbacks used by the metrics if we fail before the file is read
    file_size = 0
    file_type = "image/jpeg"

    try:
        # Note: Job already marked as 'processing' atomically in get_next_pending()

        # Read file bytes
        file_path = Path(job.file_path)
        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")

        with open(file_path, 'rb') as f:
            file_bytes = f.read()

        file_size = len(file_bytes)
        # Determine file type from the job record
        file_type = getattr(job, 'mime_type', 'image/jpeg') or 'image/jpeg'

        # Submit to worker pool
        result = await ocr_worker_pool.submit_task(
            image_bytes=file_bytes,
            engine=job.engine,
            preprocessing="auto",
            timeout=OCR_TIMEOUT_SECONDS
        )

        elapsed_ms = int((time.time() - start_time) * 1000)

        if result.get("success"):
            # Job completed successfully. "or {}" also covers an explicit
            # None, which would otherwise crash on the item assignment below.
            extraction = result.get("extraction") or {}

            # Include raw_texts for analysis (from all OCR engine passes)
            extraction['raw_texts'] = result.get("raw_texts", [])

            # Extract actual OCR processing time from extraction result
            ocr_time_ms = extraction.get('processing_time_ms', 0)

            # Debug: log suggested_payment_mode
            spm = extraction.get('suggested_payment_mode')
            logger.info(f"[JobWorker] Job {job.id} extraction has suggested_payment_mode={spm}")

            await job_queue.update_status(
                job_id=job.id,
                status=OCRJobStatus.completed,
                result=extraction,
                processing_time_ms=elapsed_ms,
                ocr_time_ms=ocr_time_ms
            )

            logger.info(f"[JobWorker] Job {job.id} completed in {elapsed_ms}ms (ocr: {ocr_time_ms}ms)")

            # Save metrics for successful job. The "or []" guards keep a
            # None list from raising here and flipping a job that is
            # already stored as completed back to failed.
            await _save_job_metrics(
                job_id=job.id,
                username=job.created_by or 'unknown',
                engine_requested=job.engine,
                engine_used=extraction.get('ocr_engine', job.engine),
                processing_time_ms=elapsed_ms,
                file_size_bytes=file_size,
                file_type=file_type,
                original_filename=job.original_filename,
                success=True,
                overall_confidence=extraction.get('overall_confidence', 0.0),
                fields_extracted=_count_extracted_fields(extraction),
                needs_manual_review=extraction.get('needs_manual_review'),
                validation_warnings_count=len(extraction.get('validation_warnings') or []),
                validation_errors_count=len(extraction.get('validation_errors') or []),
            )

            # Auto-save receipt for batch jobs
            if job.batch_id:
                auto_save_result = await _auto_save_batch_receipt(
                    job=job,
                    extraction=extraction,
                    file_path=str(file_path)
                )
                if not auto_save_result:
                    # Auto-save failed. The queue status was already set to
                    # 'completed' above and is NOT changed here; this branch
                    # only logs the failure.
                    # NOTE(review): confirm whether the job should be moved
                    # to 'failed' here or whether the auto-save helper
                    # already records the error.
                    logger.warning(
                        f"[JobWorker] Job {job.id} OCR succeeded but auto-save failed"
                    )

        else:
            # Job failed
            error_msg = result.get("error", "Unknown error")

            await job_queue.update_status(
                job_id=job.id,
                status=OCRJobStatus.failed,
                error=error_msg,
                processing_time_ms=elapsed_ms
            )

            logger.warning(f"[JobWorker] Job {job.id} failed after {elapsed_ms}ms: {error_msg}")

            # Save metrics for failed job
            await _save_job_metrics(
                job_id=job.id,
                username=job.created_by or 'unknown',
                engine_requested=job.engine,
                engine_used=job.engine,
                processing_time_ms=elapsed_ms,
                file_size_bytes=file_size,
                file_type=file_type,
                original_filename=job.original_filename,
                success=False,
                error_message=error_msg,
            )

    except Exception as e:
        elapsed_ms = int((time.time() - start_time) * 1000)

        logger.error(f"[JobWorker] Job {job.id} error after {elapsed_ms}ms: {e}")

        await job_queue.update_status(
            job_id=job.id,
            status=OCRJobStatus.failed,
            error=str(e),
            processing_time_ms=elapsed_ms
        )

        # Save metrics for error job
        await _save_job_metrics(
            job_id=job.id,
            username=job.created_by or 'unknown',
            engine_requested=job.engine,
            engine_used=job.engine,
            processing_time_ms=elapsed_ms,
            file_size_bytes=file_size,
            file_type=file_type,
            original_filename=job.original_filename,
            success=False,
            error_message=str(e),
        )

    finally:
        # Cleanup file after processing (best effort)
        try:
            await job_queue.cleanup_job_file(job.id)
        except Exception as e:
            logger.warning(f"[JobWorker] Failed to cleanup file for job {job.id}: {e}")
|
||||
|
||||
|
||||
async def _cleanup_loop() -> None:
    """
    Periodically purge expired jobs from the queue until shutdown.

    Each iteration waits up to CLEANUP_INTERVAL_SECONDS for the shutdown
    event. If the event fires, the loop exits; if the wait times out, one
    cleanup pass is executed via job_queue.cleanup_expired().

    Task cancellation ends the loop. Any other error is logged and followed
    by a 60 second pause before the next attempt.
    """
    logger.info("[JobWorker] Starting cleanup loop...")

    while True:
        if _shutdown_event.is_set():
            break

        try:
            # Sleep for one interval, waking early if shutdown is signalled.
            interval_elapsed = False
            try:
                await asyncio.wait_for(
                    _shutdown_event.wait(),
                    timeout=CLEANUP_INTERVAL_SECONDS
                )
            except asyncio.TimeoutError:
                # Normal case: nobody asked us to stop, time to clean up.
                interval_elapsed = True

            if not interval_elapsed and _shutdown_event.is_set():
                break

            removed = await job_queue.cleanup_expired()
            if removed > 0:
                logger.info(f"[JobWorker] Cleanup: deleted {removed} expired jobs")

        except asyncio.CancelledError:
            logger.info("[JobWorker] Cleanup loop cancelled")
            break

        except Exception as err:
            logger.error(f"[JobWorker] Cleanup error: {err}")
            await asyncio.sleep(60)  # Retry after 1 minute

    logger.info("[JobWorker] Cleanup loop stopped")
|
||||
|
||||
|
||||
async def start_job_worker() -> bool:
    """
    Start the job worker background tasks.

    Called at FastAPI startup to begin processing the queue. Initializes the
    job queue and the OCR worker pool, kicks off a background pre-warm of the
    pool, then starts the worker loop and the cleanup loop.

    Returns:
        True if the worker was started (or was already running), False if the
        worker pool could not be initialized or any startup step raised.
    """
    global _job_worker_task, _cleanup_task, _shutdown_event, _prewarm_task

    if _job_worker_task is not None and not _job_worker_task.done():
        logger.warning("[JobWorker] Already running")
        return True

    try:
        # Initialize job queue
        await job_queue.initialize()

        # Initialize worker pool
        if not ocr_worker_pool.initialize():
            logger.error("[JobWorker] Failed to initialize worker pool")
            return False

        # Pre-warm worker pool in BACKGROUND (don't block startup).
        # First OCR request may be slower if prewarm isn't done yet.
        async def _background_prewarm() -> None:
            logger.info("[JobWorker] Pre-warming OCR worker pool (background)...")
            try:
                warmup_success = await ocr_worker_pool.prewarm(timeout=90.0)
            except Exception as exc:
                # Never let a pre-warm problem surface as an unretrieved task exception.
                logger.warning(f"[JobWorker] Worker pool pre-warm failed, first request will be slower: {exc}")
                return
            if warmup_success:
                logger.info("[JobWorker] OCR worker pool pre-warmed successfully")
            else:
                logger.warning("[JobWorker] Worker pool pre-warm failed, first request will be slower")

        # The event loop only keeps weak references to tasks, so hold a strong
        # reference at module level; otherwise the task may be garbage-collected
        # before it finishes.
        _prewarm_task = asyncio.create_task(_background_prewarm())

        # Start worker loop
        _shutdown_event = asyncio.Event()
        _job_worker_task = asyncio.create_task(_job_worker_loop())

        # Start cleanup loop
        _cleanup_task = asyncio.create_task(_cleanup_loop())

        logger.info("[JobWorker] Started successfully")
        return True

    except Exception as e:
        logger.error(f"[JobWorker] Failed to start: {e}")
        return False
|
||||
|
||||
|
||||
async def stop_job_worker() -> None:
    """
    Stop the job worker background tasks.

    Called at FastAPI shutdown. Sets the shutdown event, cancels and awaits
    the worker and cleanup tasks, shuts the OCR worker pool down (waiting for
    it), and finally resets the module-level task/event references.
    """
    global _job_worker_task, _cleanup_task, _shutdown_event

    logger.info("[JobWorker] Stopping...")

    # Signal shutdown so the loops can exit on their own
    if _shutdown_event:
        _shutdown_event.set()

    # Cancel the worker task first, then the cleanup task, waiting for each
    for background_task in (_job_worker_task, _cleanup_task):
        if not background_task or background_task.done():
            continue
        background_task.cancel()
        try:
            await background_task
        except asyncio.CancelledError:
            pass

    # Shutdown worker pool
    ocr_worker_pool.shutdown(wait=True)

    _job_worker_task = _cleanup_task = _shutdown_event = None

    logger.info("[JobWorker] Stopped")
|
||||
|
||||
|
||||
def is_running() -> bool:
    """Report whether the job worker task exists and has not finished."""
    worker = _job_worker_task
    if worker is None:
        return False
    return not worker.done()
|
||||
|
||||
|
||||
def estimate_wait_time(queue_position: int) -> int:
    """
    Estimate wait time for a job in queue.

    The estimate is ``queue_position * average_processing_time``. The average
    is read from the job queue only when no event loop is currently running
    in this thread (a coroutine cannot be driven synchronously from inside a
    running loop); otherwise, or when the lookup fails or yields no usable
    number, a default of 7 seconds per job is used.

    Args:
        queue_position: Position in queue (1 = next)

    Returns:
        Estimated wait time in seconds (0 for positions <= 0)
    """
    if queue_position <= 0:
        return 0

    # Default ~7 seconds per job if no data
    avg_time = 7.0

    try:
        loop = asyncio.get_event_loop()
        if not loop.is_running():
            # NOTE(review): assumes the queue reports the average in seconds - confirm.
            measured = float(
                loop.run_until_complete(job_queue.get_average_processing_time())
            )
            # Ignore zero/negative/NaN averages: they would produce a bogus
            # "no wait" estimate for jobs that are still queued.
            if measured > 0:
                avg_time = measured
        # else: can't use a sync call inside a running loop, keep the default
    except Exception as exc:
        # Best effort only (covers a None/non-numeric average as well).
        logger.debug(f"[JobWorker] Average processing time unavailable, using default: {exc}")

    # Estimate: position * average_time
    return int(queue_position * avg_time)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Metrics Helper Functions
|
||||
# ============================================================================
|
||||
|
||||
async def _save_job_metrics(
    job_id: str,
    username: str,
    engine_requested: str,
    engine_used: str,
    processing_time_ms: int = 0,
    file_size_bytes: int = 0,
    file_type: str = "image/jpeg",
    original_filename: Optional[str] = None,
    success: bool = True,
    error_message: Optional[str] = None,
    overall_confidence: float = 0.0,
    fields_extracted: int = 0,
    needs_manual_review: Optional[bool] = None,
    validation_warnings_count: int = 0,
    validation_errors_count: int = 0,
) -> None:
    """
    Persist one OCR job's metrics row for analytics.

    All arguments are forwarded unchanged to OCRMetricsCRUD.create() inside a
    database session. This is best effort: any exception (including a failed
    import of the database modules) is logged as a warning and swallowed so
    that job processing is never affected.
    """
    try:
        from backend.modules.data_entry.db.database import get_db_session
        from backend.modules.data_entry.db.crud.ocr_settings import OCRMetricsCRUD

        # Gather the record once so the CRUD call stays compact.
        metrics_record = dict(
            job_id=job_id,
            username=username,
            engine_requested=engine_requested,
            engine_used=engine_used,
            processing_time_ms=processing_time_ms,
            file_size_bytes=file_size_bytes,
            file_type=file_type,
            original_filename=original_filename,
            success=success,
            error_message=error_message,
            overall_confidence=overall_confidence,
            fields_extracted=fields_extracted,
            needs_manual_review=needs_manual_review,
            validation_warnings_count=validation_warnings_count,
            validation_errors_count=validation_errors_count,
        )

        async with await get_db_session() as db:
            await OCRMetricsCRUD.create(session=db, **metrics_record)
            logger.debug(f"[JobWorker] Saved metrics for job {job_id}")

    except Exception as err:
        # Log but don't fail - metrics are nice-to-have
        logger.warning(f"[JobWorker] Failed to save metrics for job {job_id}: {err}")
|
||||
|
||||
|
||||
def _count_extracted_fields(extraction: dict) -> int:
|
||||
"""
|
||||
Count number of successfully extracted fields from OCR result.
|
||||
|
||||
Counts non-None values in key fields.
|
||||
"""
|
||||
key_fields = [
|
||||
'receipt_number',
|
||||
'receipt_date',
|
||||
'amount',
|
||||
'partner_name',
|
||||
'cui',
|
||||
'tva_total',
|
||||
'address',
|
||||
'items_count',
|
||||
]
|
||||
|
||||
count = 0
|
||||
for field in key_fields:
|
||||
value = extraction.get(field)
|
||||
if value is not None and value != '' and value != []:
|
||||
count += 1
|
||||
|
||||
# Also count TVA entries if present
|
||||
tva_entries = extraction.get('tva_entries', [])
|
||||
if tva_entries and len(tva_entries) > 0:
|
||||
count += 1
|
||||
|
||||
# Count payment methods if present
|
||||
payment_methods = extraction.get('payment_methods', [])
|
||||
if payment_methods and len(payment_methods) > 0:
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Auto-Save Batch Receipt Helper
|
||||
# ============================================================================
|
||||
|
||||
async def _auto_save_batch_receipt(
    job: OCRJob,
    extraction: dict,
    file_path: str
) -> bool:
    """
    Create a receipt from an OCR result for a job that belongs to a batch.

    Jobs without a batch_id are ignored (returns True). Otherwise the batch
    row is loaded to obtain company/user information and
    ReceiptAutoCreateService.create_from_ocr_result() is invoked. On every
    failure path (batch missing, service reports failure, or an exception)
    the job is flagged as failed in the queue with an "Auto-save error: ..."
    message.

    Args:
        job: Completed OCRJob (batch_id decides whether anything happens)
        extraction: OCR extraction result dict, unpacked into ExtractionData
        file_path: Path to the original uploaded file

    Returns:
        True if the receipt was created (or the job is not a batch job),
        False otherwise
    """
    if not job.batch_id:
        return True  # Not a batch job, nothing to do

    logger.info(f"[JobWorker] Auto-saving receipt for batch job {job.id} (batch_id={job.batch_id})")

    async def _flag_job_failed(reason: str) -> None:
        # Shared by every failure path: record the auto-save error on the job.
        await job_queue.update_status(
            job_id=job.id,
            status=OCRJobStatus.failed,
            error=f"Auto-save error: {reason}"
        )

    try:
        # Import here to avoid circular imports
        from backend.modules.data_entry.db.database import get_db_session
        from backend.modules.data_entry.db.models import BatchUpload
        from backend.modules.data_entry.services.receipt_auto_create import ReceiptAutoCreateService
        from sqlalchemy import select

        # Convert extraction dict to ExtractionData schema
        parsed_extraction = ExtractionData(**extraction)

        async with await get_db_session() as db:
            # The batch row supplies company_id and the fallback user
            batch_query = select(BatchUpload).where(BatchUpload.id == job.batch_id)
            batch = (await db.execute(batch_query)).scalar_one_or_none()

            if not batch:
                reason = f"Batch {job.batch_id} not found"
                logger.error(f"[JobWorker] Auto-save failed for job {job.id}: {reason}")
                await _flag_job_failed(reason)
                return False

            outcome = await ReceiptAutoCreateService.create_from_ocr_result(
                session=db,
                job_id=job.id,
                ocr_result=parsed_extraction,
                username=job.created_by or batch.user_id,
                batch_id=job.batch_id,
                company_id=batch.company_id,
                file_path=file_path,
                original_filename=job.original_filename,
                file_hash=job.file_hash  # Pass file_hash for duplicate detection (US-007)
            )

            if not outcome.success:
                reason = outcome.error_message or "Unknown error"
                logger.warning(
                    f"[JobWorker] Auto-save validation failed for job {job.id}: {reason}"
                )
                # Update job status to failed with the auto-save error
                await _flag_job_failed(reason)
                return False

            logger.info(
                f"[JobWorker] Auto-save successful for job {job.id}: "
                f"receipt_id={outcome.receipt_id}"
            )
            return True

    except Exception as err:
        reason = str(err)
        logger.error(f"[JobWorker] Auto-save exception for job {job.id}: {reason}")

        # Update job status to failed; a failure here must not mask the original error
        try:
            await _flag_job_failed(reason)
        except Exception as update_err:
            logger.error(f"[JobWorker] Failed to update job status after auto-save error: {update_err}")

        return False
|
||||
@@ -0,0 +1,561 @@
|
||||
"""
|
||||
OCR Worker Pool Manager
|
||||
|
||||
Manages a ProcessPoolExecutor with persistent OCR engine initialization.
|
||||
Key features:
|
||||
- ProcessPoolExecutor with configurable max_workers (from OCR_WORKERS env)
|
||||
- Configurable max_tasks_per_child (from OCR_MAX_TASKS_PER_CHILD env, 0=no restart)
|
||||
- mp_context='spawn' for Windows IIS compatibility
|
||||
- docTR/PaddleOCR loaded ONCE at worker spawn (not 30s per request)
|
||||
- atexit handler for cleanup (signal handlers are deliberately NOT installed - they interfere with FastAPI shutdown)
|
||||
- Health check with auto-respawn
|
||||
- Orphan process cleanup on Windows
|
||||
|
||||
Architecture:
|
||||
Main Process │ Worker Process (PERSISTENT)
|
||||
──────────────────────│──────────────────────────────────
|
||||
OCRWorkerPool │ Worker initialized once
|
||||
↓ │ ↓
|
||||
submit_task() ────────│────→ process_ocr()
|
||||
↓ │ ↓
|
||||
Future.result() ←─────│──── Return result
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
import gc
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import ProcessPoolExecutor, Future, ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Try to import psutil for orphan process cleanup
|
||||
try:
|
||||
import psutil
|
||||
PSUTIL_AVAILABLE = True
|
||||
except ImportError:
|
||||
PSUTIL_AVAILABLE = False
|
||||
logger.warning("[OCRWorkerPool] psutil not available - orphan cleanup disabled")
|
||||
|
||||
|
||||
class OCRWorkerPool:
    """
    Singleton manager for the OCR ProcessPoolExecutor.

    Ensures OCR engines are loaded once per worker process and reused for
    all requests. Pool size comes from the OCR_WORKERS environment variable
    (default 2). Worker recycling is controlled by OCR_MAX_TASKS_PER_CHILD
    (default 0 = never restart workers; requires Python 3.11+ to take effect).
    """

    _instance: Optional["OCRWorkerPool"] = None
    _initialized: bool = False

    def __new__(cls) -> "OCRWorkerPool":
        """Singleton pattern - only one pool instance."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Set up instance state (runs only once due to singleton)."""
        if self._initialized:
            return

        self._executor: Optional[ProcessPoolExecutor] = None
        self._worker_pid: Optional[int] = None
        self._is_warming: bool = False
        self._is_shutdown: bool = False
        # The former `asyncio.get_event_loop_policy()` condition was always
        # true, so the lock is simply created unconditionally.
        self._lock = asyncio.Lock()
        self._sync_lock = mp.Lock()

        # Register cleanup handlers
        # NOTE: Only use atexit, NOT signal handlers!
        # Signal handlers interfere with FastAPI's shutdown handling.
        # FastAPI's shutdown event calls stop_job_worker() which calls shutdown().
        atexit.register(self._cleanup_on_exit)

        self._initialized = True
        logger.info("[OCRWorkerPool] Singleton instance created")

    def initialize(self) -> bool:
        """
        Create the ProcessPoolExecutor.

        Uses the 'spawn' multiprocessing context for Windows compatibility
        and `_worker_initializer` as the per-process initializer. Reads
        OCR_WORKERS (default 2) and OCR_MAX_TASKS_PER_CHILD (default 0 =
        disabled) from the environment; the latter is only applied on
        Python 3.11+ where ProcessPoolExecutor supports it.

        Returns:
            True if the executor exists after the call (already initialized
            counts as success); False if the pool was shut down earlier or
            creation failed (including invalid environment values).
        """
        if self._executor is not None:
            logger.warning("[OCRWorkerPool] Already initialized")
            return True

        if self._is_shutdown:
            logger.error("[OCRWorkerPool] Cannot initialize - pool is shutdown")
            return False

        try:
            # Cleanup any orphan workers from previous runs
            self._cleanup_orphan_workers()

            # Read configuration from environment
            max_workers = int(os.getenv('OCR_WORKERS', '2'))
            max_tasks_raw = os.getenv('OCR_MAX_TASKS_PER_CHILD', '0')
            # 0 (or empty) means no restart (None in ProcessPoolExecutor)
            max_tasks_per_child = int(max_tasks_raw) if max_tasks_raw and int(max_tasks_raw) > 0 else None

            # Use mp_context='spawn' explicitly for cross-platform consistency
            executor_kwargs = {
                'max_workers': max_workers,
                'mp_context': mp.get_context('spawn'),
                'initializer': _worker_initializer,
            }

            if max_tasks_per_child is None:
                # Recycling was not requested; nothing to report as unsupported.
                logger.info("[OCRWorkerPool] Worker recycling disabled (OCR_MAX_TASKS_PER_CHILD=0)")
            elif sys.version_info >= (3, 11):
                # max_tasks_per_child only available in Python 3.11+
                executor_kwargs['max_tasks_per_child'] = max_tasks_per_child
            else:
                logger.info(f"[OCRWorkerPool] max_tasks_per_child not supported (Python {sys.version_info.major}.{sys.version_info.minor})")
                max_tasks_per_child = None  # report the effective value below

            self._executor = ProcessPoolExecutor(**executor_kwargs)

            logger.info(f"[OCRWorkerPool] ProcessPoolExecutor created (spawn context, max_workers={max_workers}, max_tasks_per_child={max_tasks_per_child})")
            return True

        except Exception as e:
            logger.error(f"[OCRWorkerPool] Initialization failed: {e}")
            return False

    async def prewarm(self, timeout: float = 60.0) -> bool:
        """
        Pre-warm a worker by running `_warmup_task` before the first request.

        Submitting the task forces a worker process to spawn and load its
        engines, so the first real request does not pay that cost. On success
        the PID reported by the worker is remembered for `is_healthy()`.

        Args:
            timeout: Maximum seconds to wait for warmup (default 60s)

        Returns:
            True if warmup reported success; False if the pool is not
            initialized, a warmup is already in progress, the wait timed
            out, or any error occurred.
        """
        # On Python < 3.11 this is a different class from the builtin TimeoutError.
        from concurrent.futures import TimeoutError as FuturesTimeoutError

        if self._executor is None:
            logger.error("[OCRWorkerPool] Cannot prewarm - not initialized")
            return False

        if self._is_warming:
            logger.warning("[OCRWorkerPool] Already warming up")
            return False

        self._is_warming = True
        logger.info("[OCRWorkerPool] Starting pre-warm (loading PaddleOCR in worker)...")
        start_time = time.time()

        try:
            loop = asyncio.get_running_loop()
            future = self._executor.submit(_warmup_task)

            # Block a default-executor thread (not the event loop) on the result
            result = await loop.run_in_executor(None, future.result, timeout)

            elapsed = time.time() - start_time
            if result.get("success"):
                logger.info(f"[OCRWorkerPool] Pre-warm complete in {elapsed:.1f}s - PaddleOCR ready")
                self._worker_pid = result.get("pid")
                return True
            else:
                logger.error(f"[OCRWorkerPool] Pre-warm failed: {result.get('error')}")
                return False

        except FuturesTimeoutError:
            elapsed = time.time() - start_time
            # str(TimeoutError()) is empty, so state the reason explicitly
            logger.error(f"[OCRWorkerPool] Pre-warm failed after {elapsed:.1f}s: timed out after {timeout}s")
            return False

        except Exception as e:
            elapsed = time.time() - start_time
            logger.error(f"[OCRWorkerPool] Pre-warm failed after {elapsed:.1f}s: {e}")
            return False

        finally:
            self._is_warming = False

    async def submit_task(
        self,
        image_bytes: bytes,
        engine: str = "doctr_plus",
        preprocessing: str = "auto",
        timeout: float = 120.0
    ) -> dict:
        """
        Submit an OCR task to a worker process and await its result.

        Args:
            image_bytes: Raw image bytes
            engine: OCR engine ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
            preprocessing: Preprocessing mode ('light', 'medium', 'heavy', 'auto')
            timeout: Maximum processing time in seconds

        Returns:
            Dict with extraction results as returned by the worker

        Raises:
            RuntimeError: If the pool is not initialized or is shut down, the
                task times out, or the task fails for any other reason (the
                original exception is chained as __cause__).
        """
        # concurrent.futures.TimeoutError is only an alias of the builtin
        # TimeoutError from Python 3.11 on; catching the futures class works
        # on every supported version.
        from concurrent.futures import TimeoutError as FuturesTimeoutError

        if self._executor is None:
            raise RuntimeError("OCR worker pool not initialized")

        if self._is_shutdown:
            raise RuntimeError("OCR worker pool is shutdown")

        logger.info(f"[OCRWorkerPool] Submitting task: engine={engine}, preprocessing={preprocessing}, size={len(image_bytes)} bytes")

        future = None
        try:
            loop = asyncio.get_running_loop()
            future = self._executor.submit(
                _process_ocr_task,
                image_bytes,
                engine,
                preprocessing
            )

            # Wait for result with timeout (in a thread, not on the event loop)
            result = await loop.run_in_executor(None, future.result, timeout)

            logger.info(f"[OCRWorkerPool] Task complete: success={result.get('success')}")
            return result

        except FuturesTimeoutError as exc:
            # Best effort: drops the task if it has not started yet; a task
            # that is already running in a worker cannot be interrupted.
            if future is not None:
                future.cancel()
            logger.error(f"[OCRWorkerPool] Task timed out after {timeout}s")
            raise RuntimeError(f"OCR processing timed out after {timeout}s") from exc

        except Exception as e:
            logger.error(f"[OCRWorkerPool] Task failed: {e}")
            raise RuntimeError(f"OCR processing failed: {e}") from e

    def is_healthy(self) -> bool:
        """
        Check if the worker pool is able to accept tasks.

        Returns:
            False when there is no executor, the pool is shut down, or (with
            psutil available) the worker PID recorded by `prewarm()` is no
            longer running; True otherwise. Only that single recorded PID is
            inspected.
        """
        if self._executor is None:
            return False
        if self._is_shutdown:
            return False

        # Check if the pre-warmed worker process is still alive
        if self._worker_pid and PSUTIL_AVAILABLE:
            try:
                proc = psutil.Process(self._worker_pid)
                if not proc.is_running():
                    logger.warning("[OCRWorkerPool] Worker process died, needs respawn")
                    return False
            except psutil.NoSuchProcess:
                logger.warning("[OCRWorkerPool] Worker process not found")
                return False

        return True

    def shutdown(self, wait: bool = True, timeout: float = 10.0) -> None:
        """
        Shut the worker pool down.

        Pending (not yet started) futures are cancelled. After this call the
        pool is marked as shut down and `initialize()` will refuse to create
        a new executor.

        Args:
            wait: Whether to block until the executor has finished
            timeout: Currently unused; accepted for interface compatibility
        """
        if self._executor is None:
            return

        logger.info("[OCRWorkerPool] Shutting down...")
        self._is_shutdown = True

        try:
            self._executor.shutdown(wait=wait, cancel_futures=True)
            logger.info("[OCRWorkerPool] Executor shutdown complete")
        except Exception as e:
            logger.error(f"[OCRWorkerPool] Shutdown error: {e}")

        self._executor = None
        self._worker_pid = None

        # Final orphan cleanup
        self._cleanup_orphan_workers()
        logger.info("[OCRWorkerPool] Shutdown complete")

    def _cleanup_orphan_workers(self) -> int:
        """
        Kill leftover OCR worker processes from previous runs.

        On Windows with NSSM, orphan processes may remain after a service
        restart. Scans all processes (requires psutil) and kills Python
        processes - other than the current one - whose command line contains
        'ocr_worker_process' or 'process_ocr_task'.

        Returns:
            Number of processes killed (0 when psutil is unavailable)
        """
        if not PSUTIL_AVAILABLE:
            return 0

        killed = 0
        current_pid = os.getpid()
        python_names = ('python.exe', 'python3.exe', 'python', 'python3')

        try:
            for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
                try:
                    # Skip self
                    if proc.pid == current_pid:
                        continue

                    # Look for Python processes with OCR-related cmdline
                    if proc.name().lower() not in python_names:
                        continue

                    cmdline = ' '.join(proc.cmdline() or []).lower()
                    if 'ocr_worker_process' in cmdline or 'process_ocr_task' in cmdline:
                        logger.warning(f"[OCRWorkerPool] Killing orphan worker: PID={proc.pid}")
                        proc.kill()
                        proc.wait(timeout=5)
                        killed += 1

                except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess, psutil.TimeoutExpired):
                    # One vanished/protected/slow-to-die process must not
                    # abort the sweep over the remaining ones.
                    continue

        except Exception as e:
            logger.error(f"[OCRWorkerPool] Orphan cleanup error: {e}")

        if killed > 0:
            logger.info(f"[OCRWorkerPool] Cleaned up {killed} orphan worker(s)")

        return killed

    def _cleanup_on_exit(self) -> None:
        """atexit handler: shut the pool down without waiting for tasks."""
        logger.info("[OCRWorkerPool] atexit cleanup triggered")
        self.shutdown(wait=False)

    def _signal_handler(self, signum: int, frame: Any) -> None:
        """
        Signal handler for SIGTERM/SIGINT.

        Not registered by this class (see the note in __init__); kept for
        callers that want to install it themselves.
        """
        logger.info(f"[OCRWorkerPool] Received signal {signum}, shutting down...")
        self.shutdown(wait=False)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# WORKER PROCESS FUNCTIONS
|
||||
# ============================================================================
|
||||
# These functions run in the child process, not the main FastAPI process.
|
||||
|
||||
# Global engines - persist between tasks in worker process
|
||||
_paddle_engine = None
|
||||
_tesseract_engine = None
|
||||
_doctr_engine = None # docTR engine (PyTorch backend)
|
||||
_worker_initialized = False
|
||||
|
||||
|
||||
def _worker_initializer() -> None:
|
||||
"""
|
||||
Called once when worker process spawns.
|
||||
|
||||
Initializes global OCR engines IN PARALLEL for faster startup.
|
||||
Uses ThreadPoolExecutor to load enabled engines concurrently.
|
||||
Respects OCR_ENABLE_PADDLEOCR and OCR_ENABLE_TESSERACT from .env.
|
||||
|
||||
Total warmup time = max(engine_times) instead of sum(engine_times).
|
||||
"""
|
||||
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
|
||||
|
||||
if _worker_initialized:
|
||||
print(f"[Worker {os.getpid()}] Already initialized", flush=True)
|
||||
return
|
||||
|
||||
# Check which engines are enabled via .env
|
||||
paddle_enabled = os.getenv("OCR_ENABLE_PADDLEOCR", "true").lower() == "true"
|
||||
tesseract_enabled = os.getenv("OCR_ENABLE_TESSERACT", "true").lower() == "true"
|
||||
|
||||
enabled_engines = ["doctr"] # docTR is always loaded (primary engine)
|
||||
if paddle_enabled:
|
||||
enabled_engines.append("paddle")
|
||||
if tesseract_enabled:
|
||||
enabled_engines.append("tesseract")
|
||||
|
||||
print(f"[Worker {os.getpid()}] Initializing OCR engines: {enabled_engines}", flush=True)
|
||||
if not paddle_enabled:
|
||||
print(f"[Worker {os.getpid()}] PaddleOCR DISABLED - saving ~800MB RAM", flush=True)
|
||||
if not tesseract_enabled:
|
||||
print(f"[Worker {os.getpid()}] Tesseract DISABLED - saving ~50MB RAM", flush=True)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Define loader functions - each runs in its own thread
|
||||
def load_doctr():
|
||||
try:
|
||||
from backend.modules.data_entry.services.ocr.ocr_worker_process import initialize_doctr_engine
|
||||
engine = initialize_doctr_engine()
|
||||
return ("doctr", engine, None)
|
||||
except Exception as e:
|
||||
return ("doctr", None, str(e))
|
||||
|
||||
def load_paddle():
|
||||
if not paddle_enabled:
|
||||
return ("paddle", None, "disabled via OCR_ENABLE_PADDLEOCR=false")
|
||||
try:
|
||||
from backend.modules.data_entry.services.ocr.ocr_worker_process import initialize_paddle_engine
|
||||
engine = initialize_paddle_engine()
|
||||
return ("paddle", engine, None)
|
||||
except Exception as e:
|
||||
return ("paddle", None, str(e))
|
||||
|
||||
def load_tesseract():
|
||||
if not tesseract_enabled:
|
||||
return ("tesseract", None, "disabled via OCR_ENABLE_TESSERACT=false")
|
||||
try:
|
||||
from backend.modules.data_entry.services.ocr.tesseract_engine import TesseractEngine
|
||||
engine = TesseractEngine()
|
||||
return ("tesseract", engine, None)
|
||||
except Exception as e:
|
||||
return ("tesseract", None, str(e))
|
||||
|
||||
# Build list of futures for enabled engines only
|
||||
futures_to_submit = [load_doctr] # docTR always loaded
|
||||
if paddle_enabled:
|
||||
futures_to_submit.append(load_paddle)
|
||||
if tesseract_enabled:
|
||||
futures_to_submit.append(load_tesseract)
|
||||
|
||||
# Load engines in parallel using ThreadPoolExecutor
|
||||
with ThreadPoolExecutor(max_workers=len(futures_to_submit)) as executor:
|
||||
futures = [executor.submit(fn) for fn in futures_to_submit]
|
||||
|
||||
for future in as_completed(futures):
|
||||
name, engine, error = future.result()
|
||||
if error and "disabled" not in error:
|
||||
print(f"[Worker {os.getpid()}] {name} init failed: {error}", flush=True)
|
||||
elif engine:
|
||||
print(f"[Worker {os.getpid()}] {name} loaded", flush=True)
|
||||
if name == "doctr":
|
||||
_doctr_engine = engine
|
||||
elif name == "paddle":
|
||||
_paddle_engine = engine
|
||||
elif name == "tesseract":
|
||||
_tesseract_engine = engine
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
_worker_initialized = True
|
||||
print(f"[Worker {os.getpid()}] Initialization complete in {elapsed:.1f}s (engines: {enabled_engines})", flush=True)
|
||||
|
||||
|
||||
def _warmup_task() -> dict:
    """
    Make sure this worker's engines are loaded and exercise them once.

    Runs `_worker_initializer()` if needed, then feeds a blank white
    100x100 RGB image to docTR and PaddleOCR when they are available.
    Failures of an individual engine are printed but do not fail the warmup.

    Returns:
        On success: {"success": True, "pid": ..., "doctr_available": ...,
        "paddle_available": ..., "tesseract_available": ...}.
        On an unexpected error: {"success": False, "pid": ..., "error": ...}.
    """
    pid = os.getpid()

    try:
        if not _worker_initialized:
            _worker_initializer()

        # Quick test input - a small all-white dummy image
        import numpy as np
        blank = np.full((100, 100, 3), 255, dtype=np.uint8)

        # (label, engine, how to invoke it); docTR first (fastest engine)
        probes = (
            ("docTR", _doctr_engine, lambda eng: eng([blank])),
            ("PaddleOCR", _paddle_engine, lambda eng: eng.predict(blank)),
        )
        for label, candidate, invoke in probes:
            if candidate is None:
                continue
            try:
                invoke(candidate)
                print(f"[Worker {pid}] {label} warmup OK", flush=True)
            except Exception as err:
                print(f"[Worker {pid}] {label} warmup error: {err}", flush=True)

        # Cleanup
        gc.collect()

        return {
            "success": True,
            "pid": os.getpid(),
            "doctr_available": _doctr_engine is not None,
            "paddle_available": _paddle_engine is not None,
            "tesseract_available": _tesseract_engine is not None
        }

    except Exception as err:
        return {
            "success": False,
            "pid": os.getpid(),
            "error": str(err)
        }
|
||||
|
||||
|
||||
def _process_ocr_task(
    image_bytes: bytes,
    engine: str = "doctr_plus",
    preprocessing: str = "auto"
) -> dict:
    """
    Run one OCR request inside the worker process.

    Lazily initializes the worker's engines if that has not happened yet,
    then delegates to `process_ocr` with the process-wide engine handles and
    triggers a garbage collection before returning.

    Args:
        image_bytes: Raw image bytes
        engine: OCR engine choice ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
        preprocessing: Preprocessing mode

    Returns:
        The dict produced by `process_ocr`, or
        {"success": False, "error": ..., "pid": ...} if anything raised.
    """
    try:
        # Ensure initialization
        if not _worker_initialized:
            _worker_initializer()

        # Imported here so it is resolved inside the worker process
        from backend.modules.data_entry.services.ocr.ocr_worker_process import process_ocr

        ocr_kwargs = {
            "image_bytes": image_bytes,
            "paddle_engine": _paddle_engine,
            "tesseract_engine": _tesseract_engine,
            "engine": engine,
            "preprocessing": preprocessing,
            "doctr_engine": _doctr_engine,
        }
        outcome = process_ocr(**ocr_kwargs)

        # Cleanup after each task
        gc.collect()

        return outcome

    except Exception as err:
        import traceback

        print(f"[Worker {os.getpid()}] Task error: {err}", flush=True)
        traceback.print_exc()
        return {
            "success": False,
            "error": str(err),
            "pid": os.getpid()
        }
|
||||
|
||||
|
||||
# Singleton instance
# Constructing it only sets up state and registers the atexit hook; the
# ProcessPoolExecutor itself is created later by initialize().
ocr_worker_pool = OCRWorkerPool()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,258 @@
|
||||
# Store Profiles - OCR Extraction
|
||||
|
||||
Sistem de profile specifice fiecărui magazin pentru extracția OCR, cu suport de hot-reload.
|
||||
|
||||
---
|
||||
|
||||
## Quick Start: Adaugă un profil nou
|
||||
|
||||
```bash
|
||||
# 1. Generează profil din PDF-uri (dry-run pentru preview)
|
||||
python scripts/generate_store_profile.py \
|
||||
--name "Magazin Nou SRL" \
|
||||
--cui "12345678" \
|
||||
--receipts "docs/data-entry/MagazinNou*.pdf" \
|
||||
--dry-run
|
||||
|
||||
# 2. Generează și salvează
|
||||
python scripts/generate_store_profile.py \
|
||||
--name "Magazin Nou SRL" \
|
||||
--cui "12345678" \
|
||||
--receipts "docs/data-entry/MagazinNou*.pdf" \
|
||||
--output backend/modules/data_entry/services/ocr/profiles/magazin_nou.py
|
||||
|
||||
# 3. Hot-reload (fără restart server)
|
||||
curl -X POST http://localhost:8000/api/data-entry/ocr/profiles/reload
|
||||
|
||||
# 4. Verifică
|
||||
curl http://localhost:8000/api/data-entry/ocr/profiles
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Structura directorului
|
||||
|
||||
```
|
||||
profiles/
|
||||
├── __init__.py # ProfileRegistry + hot-reload (~390 linii)
|
||||
├── base.py # BaseStoreProfile + pattern-uri generice (~410 linii)
|
||||
├── lidl.py # Multi-rate TVA (A/B)
|
||||
├── omv.py # B2B, date YYYY.MM.DD
|
||||
├── socar.py # B2B, date YYYY.MM.DD
|
||||
├── brick.py # Standard TVA
|
||||
├── dedeman.py # E-factura support
|
||||
├── kineterra.py # Non-VAT payer
|
||||
├── gama_ink.py # Standard TVA (toner/cartușe)
|
||||
├── electrobering.py # Standard TVA (electronice)
|
||||
├── pictus_velum.py # Standard TVA (rechizite)
|
||||
├── unlimited_keys.py # Standard TVA, NUMERAR payment
|
||||
├── best_print.py # Non-VAT payer (neplătitor TVA)
|
||||
├── stepout_market.py # TVA 5% (cărți/librărie)
|
||||
└── README.md # Acest fișier
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Profile existente (12 profile)
|
||||
|
||||
> **Notă**: Pattern-urile TVA sunt **flexibile** și acceptă ORICE cotă (5%, 9%, 11%, 19%, 21% etc.)
|
||||
> pentru a gestiona atât datele istorice cât și schimbările viitoare ale legislației.
|
||||
|
||||
| Magazin | CUI | Fișier | Caracteristici |
|
||||
|---------|-----|--------|----------------|
|
||||
| LIDL DISCOUNT S.R.L. | 22891860 | `lidl.py` | Multi-rate TVA (coduri A, B, C, D) |
|
||||
| OMV PETROM MARKETING S.R.L. | 11201891 | `omv.py` | B2B (client CUI), date YYYY.MM.DD |
|
||||
| SOCAR PETROLEUM S.A. | 12546600 | `socar.py` | B2B (client CUI), date YYYY.MM.DD |
|
||||
| FIVE-HOLDING S.A. (BRICK) | 10562600 | `brick.py` | Standard TVA |
|
||||
| DEDEMAN SRL | 2816464 | `dedeman.py` | E-factura support |
|
||||
| KINETERRA CONCEPT SRL | 31180432 | `kineterra.py` | Non-VAT payer (returnează `[]`) |
|
||||
| GAMA INK SERVICE SRL | 17741882 | `gama_ink.py` | Standard TVA (toner, cartușe) |
|
||||
| ELECTROBERING S.R.L. | 2744937 | `electrobering.py` | Standard TVA (electronice) |
|
||||
| PICTUS VELUM SRL | 39634534 | `pictus_velum.py` | Standard TVA (rechizite) |
|
||||
| UNLIMITED KEYS S.R.L. | 18993187 | `unlimited_keys.py` | Standard TVA, **NUMERAR** plată |
|
||||
| BEST PRINT TRADE ACTIV SRL | 45417955 | `best_print.py` | **Non-VAT payer** (neplătitor TVA) |
|
||||
| STEPOUT MARKET SRL | 35532655 | `stepout_market.py` | TVA 5% (cărți, librărie) |
|
||||
|
||||
---
|
||||
|
||||
## API Endpoints
|
||||
|
||||
| Endpoint | Metodă | Descriere |
|
||||
|----------|--------|-----------|
|
||||
| `/api/data-entry/ocr/profiles` | GET | Lista toate profilele |
|
||||
| `/api/data-entry/ocr/profiles/{cui}` | GET | Detalii profil (acceptă RO prefix) |
|
||||
| `/api/data-entry/ocr/profiles/reload` | POST | Hot-reload toate profilele |
|
||||
|
||||
### Exemple API
|
||||
|
||||
```bash
|
||||
# Lista profile
|
||||
curl http://localhost:8000/api/data-entry/ocr/profiles \
|
||||
-H "Authorization: Bearer <token>"
|
||||
|
||||
# Detalii profil (cu sau fără RO prefix)
|
||||
curl http://localhost:8000/api/data-entry/ocr/profiles/22891860
|
||||
curl http://localhost:8000/api/data-entry/ocr/profiles/RO22891860
|
||||
|
||||
# Hot-reload după modificări
|
||||
curl -X POST http://localhost:8000/api/data-entry/ocr/profiles/reload \
|
||||
-H "Authorization: Bearer <token>"
|
||||
|
||||
# Response reload:
|
||||
{
|
||||
"success": true,
|
||||
"reloaded_modules": 12,
|
||||
"profiles_count": 12,
|
||||
"registered_cuis": ["22891860", "11201891", "12546600", "10562600", ...],
|
||||
"last_reload": "2026-01-06T22:37:05.000000"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Cum funcționează sistemul
|
||||
|
||||
### Flow de extracție
|
||||
|
||||
```
|
||||
ReceiptExtractor.extract()
|
||||
│
|
||||
├─► STEP 1: Extrage vendor + CUI
|
||||
│ └─► _extract_vendor(), _extract_cui()
|
||||
│
|
||||
├─► ProfileRegistry.get_profile(cui)
|
||||
│ └─► Returnează profil specific sau None
|
||||
│
|
||||
├─► STEP 2: Extracție cu profil (dacă există)
|
||||
│ ├─► profile.extract_total()
|
||||
│ ├─► profile.extract_date()
|
||||
│ ├─► profile.extract_receipt_number()
|
||||
│ ├─► profile.extract_tva_entries()
|
||||
│ ├─► profile.extract_payment_methods()
|
||||
│ └─► profile.extract_client_cui()
|
||||
│
|
||||
└─► STEP 3-4: Validare + post-procesare
|
||||
```
|
||||
|
||||
### Fallback
|
||||
|
||||
Dacă nu există profil pentru CUI, se folosește logica generică din `ReceiptExtractor`.
|
||||
|
||||
---
|
||||
|
||||
## Structura unui profil
|
||||
|
||||
```python
|
||||
from typing import Any, Dict, List

from .base import BaseStoreProfile
|
||||
from . import ProfileRegistry
|
||||
|
||||
@ProfileRegistry.register
|
||||
class MagazinNouProfile(BaseStoreProfile):
|
||||
"""Docstring cu descriere magazin."""
|
||||
|
||||
CUI_LIST = ["12345678"] # Poate avea mai multe CUI-uri
|
||||
NAME_PATTERNS = ["MAGAZIN", "MAGAZIN NOU", "MAG4ZIN"] # OCR variants
|
||||
STORE_NAME = "Magazin Nou SRL"
|
||||
|
||||
# Override doar ce e diferit de base class
|
||||
def extract_tva_entries(self, text: str) -> List[dict]:
|
||||
# Pattern-uri specifice magazinului
|
||||
...
|
||||
|
||||
def get_validation_hints(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"has_multi_rate_tva": False,
|
||||
"card_equals_total": True,
|
||||
"has_client_cui": False,
|
||||
"has_efactura": False,
|
||||
"is_non_vat_payer": False,
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Pattern-uri disponibile în base.py
|
||||
|
||||
BaseStoreProfile include pattern-uri generice OCR-tolerant:
|
||||
|
||||
| Pattern | Descriere |
|
||||
|---------|-----------|
|
||||
| `TOTAL_PATTERNS` | 8 variante pentru TOTAL (TOTAL:, TOTAL DE PLATA, etc.) |
|
||||
| `DATE_PATTERNS` | 6 variante (DD.MM.YYYY, YYYY-MM-DD, DD/MM/YYYY) |
|
||||
| `DATE_PATTERNS_OCR_SPACES` | 4 variante cu spații OCR ("2025. 08. 14") |
|
||||
| `NUMBER_PATTERNS` | 11 variante pentru număr bon (NDS, BF, C3POS) |
|
||||
| `PAYMENT_PATTERNS` | 8 variante pentru CARD/NUMERAR |
|
||||
| `CLIENT_MARKERS` | 6 variante pentru secțiune CLIENT |
|
||||
| `CLIENT_CUI_PATTERNS` | 7 variante pentru CUI client |
|
||||
|
||||
### Metode implementate în base class
|
||||
|
||||
- `extract_total(text)` → `Tuple[Decimal, float]`
|
||||
- `extract_date(text)` → `Tuple[date, float]`
|
||||
- `extract_receipt_number(text)` → `Tuple[str, float]`
|
||||
- `extract_payment_methods(text)` → `List[dict]`
|
||||
- `extract_client_cui(text)` → `Tuple[str, float]`
|
||||
- `extract_client_name(text)` → `Tuple[str, float]`
|
||||
|
||||
---
|
||||
|
||||
## Când ai nevoie de profil custom?
|
||||
|
||||
| Situație | Exemplu | Ce trebuie override |
|
||||
|----------|---------|---------------------|
|
||||
| **Multi-rate TVA** | Lidl (TVA A, TVA B) | `extract_tva_entries()` |
|
||||
| **Format dată special** | OMV/Socar (YYYY.MM.DD) | `DATE_PATTERNS_OCR_SPACES` |
|
||||
| **B2B receipts** | Benzinării (au client CUI) | `extract_client_cui()` |
|
||||
| **Non-VAT payer** | Kineterra | `extract_tva_entries()` returnează `[]` |
|
||||
| **E-factura** | Dedeman | `extract_efactura_reference()` |
|
||||
|
||||
---
|
||||
|
||||
## Decizii de design
|
||||
|
||||
1. **Hot-reload manual** - endpoint `/profiles/reload` apelat când se modifică fișiere
|
||||
2. **Persistență în Python** - profile în Git, version controlled
|
||||
3. **Fallback graceful** - dacă nu există profil, folosește logica generică
|
||||
4. **CUI normalization** - gestionează automat prefixul "RO" și whitespace
|
||||
5. **Deduplicare TVA** - folosește `seen = set()` pentru a evita duplicate
|
||||
|
||||
---
|
||||
|
||||
## Comenzi utile
|
||||
|
||||
```bash
|
||||
# Verifică syntax Python pentru toate profilele
|
||||
for f in backend/modules/data_entry/services/ocr/profiles/*.py; do
|
||||
  python3 -m py_compile "$f" && echo "✓ $(basename "$f")"
|
||||
done
|
||||
|
||||
# Lista profile
|
||||
ls -la backend/modules/data_entry/services/ocr/profiles/
|
||||
|
||||
# Pornește backend pentru testare
|
||||
cd backend && source venv/bin/activate
|
||||
uvicorn main:app --host 0.0.0.0 --port 8000 --workers 1
|
||||
|
||||
# Test OCR pe un PDF
|
||||
curl -X POST -F "file=@docs/data-entry/test.pdf" \
|
||||
-H "Authorization: Bearer <token>" \
|
||||
"http://localhost:8000/api/data-entry/ocr/extract?engine=doctr_plus"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Script generare profile
|
||||
|
||||
`scripts/generate_store_profile.py` - generator automat de profile
|
||||
|
||||
```bash
|
||||
# Vezi help
|
||||
python scripts/generate_store_profile.py --help
|
||||
|
||||
# Funcționalități:
|
||||
# - Analizează PDF-uri via OCR API
|
||||
# - Detectează: TVA format, date format, payment patterns, B2B
|
||||
# - Generează cod Python cu OCR error variants
|
||||
# - Suportă glob patterns (*.pdf)
|
||||
# - Verifică sintaxa după generare
|
||||
```
|
||||
@@ -0,0 +1,398 @@
|
||||
"""
|
||||
Store Profiles Registry with Hot-Reload Support.
|
||||
|
||||
This module provides a registry for store-specific OCR extraction profiles.
|
||||
Profiles can be reloaded at runtime without restarting the server.
|
||||
|
||||
Usage:
|
||||
from backend.modules.data_entry.services.ocr.profiles import ProfileRegistry
|
||||
|
||||
# Get profile for a CUI
|
||||
profile = ProfileRegistry.get_profile("22891860")
|
||||
if profile:
|
||||
tva_entries = profile.extract_tva_entries(text)
|
||||
|
||||
# Reload all profiles (after file changes)
|
||||
count = ProfileRegistry.reload_all()
|
||||
|
||||
Architecture:
|
||||
- ProfileRegistry: Singleton registry with class methods
|
||||
- BaseStoreProfile: Abstract base class for profiles
|
||||
- @ProfileRegistry.register: Decorator for profile classes
|
||||
|
||||
Hot-Reload Mechanism:
|
||||
1. Admin calls POST /profiles/reload endpoint
|
||||
2. Registry clears instance cache
|
||||
3. importlib.reload() re-executes each profile module
|
||||
4. @register decorator re-registers classes with new code
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Type, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .base import BaseStoreProfile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Directory containing profile modules (the directory of this file);
# ProfileRegistry globs it for "*.py" to discover profile modules.
|
||||
PROFILES_DIR = Path(__file__).parent
|
||||
|
||||
|
||||
class ProfileRegistry:
    """
    Registry for store-specific OCR extraction profiles.

    All state lives on the class and every method is a classmethod, so the
    registry behaves like a singleton without explicit instantiation.
    Profile modules register their classes through the ``register``
    decorator; ``reload_all`` re-executes those modules at runtime via
    ``importlib.reload()``.

    Attributes:
        _profiles: Maps normalized CUI -> profile class (not instance)
        _instances: Maps normalized CUI -> profile instance (lazy, cleared on reload)
        _last_reload: Timestamp of the last load/reload (naive, from ``utcnow()``)
        _loaded: Whether the initial load has been performed
    """

    # Class-level storage (singleton pattern via class methods)
    _profiles: Dict[str, Type["BaseStoreProfile"]] = {}
    _instances: Dict[str, "BaseStoreProfile"] = {}
    _last_reload: Optional[datetime] = None
    _loaded: bool = False

    # -------------------------------------------------------------------------
    # Registration
    # -------------------------------------------------------------------------

    @classmethod
    def register(cls, profile_class: Type["BaseStoreProfile"]) -> Type["BaseStoreProfile"]:
        """
        Decorator to register a store profile class.

        Registers the profile for all CUIs in the class's CUI_LIST.
        Safe for re-registration during hot-reload (overwrites existing).
        A class with an empty or missing CUI_LIST is NOT registered: a
        warning is logged and the class is returned untouched.

        Usage:
            @ProfileRegistry.register
            class LidlProfile(BaseStoreProfile):
                CUI_LIST = ["22891860"]
                ...

        Args:
            profile_class: Profile class to register

        Returns:
            The same class (allows use as decorator)
        """
        cui_list = getattr(profile_class, 'CUI_LIST', [])
        store_name = getattr(profile_class, 'STORE_NAME', profile_class.__name__)

        if not cui_list:
            logger.warning(f"Profile {profile_class.__name__} has empty CUI_LIST, skipping")
            return profile_class

        # Register for each CUI
        for cui in cui_list:
            # Normalize CUI (remove RO prefix, strip whitespace)
            normalized_cui = cls._normalize_cui(cui)

            if normalized_cui in cls._profiles:
                old_class = cls._profiles[normalized_cui]
                logger.debug(
                    f"Re-registering CUI {normalized_cui}: "
                    f"{old_class.__name__} -> {profile_class.__name__}"
                )
                # Drop the cached instance so the next lookup builds one
                # from the newly registered class.
                cls._instances.pop(normalized_cui, None)

            cls._profiles[normalized_cui] = profile_class
            logger.debug(f"Registered profile {profile_class.__name__} for CUI {normalized_cui}")

        logger.info(f"Registered {store_name} for CUIs: {cui_list}")
        return profile_class

    # -------------------------------------------------------------------------
    # Lookup
    # -------------------------------------------------------------------------

    @classmethod
    def get_profile(cls, cui: Optional[str]) -> Optional["BaseStoreProfile"]:
        """
        Get profile instance for a CUI.

        Uses lazy instantiation - creates instance on first access and
        caches it per normalized CUI.

        Args:
            cui: CUI to lookup (with or without RO prefix)

        Returns:
            Profile instance, or None when ``cui`` is empty, no profile is
            registered for it, or the profile class fails to instantiate.
        """
        if not cui:
            return None

        # Ensure profiles are loaded
        if not cls._loaded:
            cls._load_all_profiles()

        normalized_cui = cls._normalize_cui(cui)

        profile_class = cls._profiles.get(normalized_cui)
        if not profile_class:
            return None

        # Lazy instantiation
        if normalized_cui not in cls._instances:
            try:
                cls._instances[normalized_cui] = profile_class()
                logger.debug(f"Instantiated {profile_class.__name__} for CUI {normalized_cui}")
            except Exception as e:
                # Best effort: a broken profile must not break extraction,
                # callers treat None as "no profile registered".
                logger.error(f"Failed to instantiate {profile_class.__name__}: {e}")
                return None

        return cls._instances[normalized_cui]

    @classmethod
    def has_profile(cls, cui: Optional[str]) -> bool:
        """Check if a profile exists for this CUI."""
        if not cui:
            return False
        if not cls._loaded:
            cls._load_all_profiles()
        return cls._normalize_cui(cui) in cls._profiles

    # -------------------------------------------------------------------------
    # Listing
    # -------------------------------------------------------------------------

    @classmethod
    def list_profiles(cls) -> List[Dict]:
        """
        List all registered profiles.

        A class registered under several CUIs is listed once. Classes are
        de-duplicated by identity, so two different profile classes that
        happen to share the same ``__name__`` are both listed.

        Returns:
            List of dicts with cuis, class_name, store_name, name_patterns
        """
        if not cls._loaded:
            cls._load_all_profiles()

        result = []
        seen_classes = set()

        for profile_class in cls._profiles.values():
            # Avoid duplicates for profiles with multiple CUIs
            if profile_class in seen_classes:
                continue
            seen_classes.add(profile_class)

            result.append({
                "cuis": list(getattr(profile_class, 'CUI_LIST', [])),
                "class_name": profile_class.__name__,
                "store_name": getattr(profile_class, 'STORE_NAME', profile_class.__name__),
                "name_patterns": list(getattr(profile_class, 'NAME_PATTERNS', [])),
            })

        return result

    @classmethod
    def get_profile_info(cls, cui: str) -> Optional[Dict]:
        """
        Get detailed info about a profile.

        Args:
            cui: CUI to lookup

        Returns:
            Dict with profile details or None. The ``cui`` key echoes the
            argument exactly as passed in (not normalized).
        """
        profile = cls.get_profile(cui)
        if not profile:
            return None

        return {
            "cui": cui,
            "cuis": list(profile.CUI_LIST),
            "class_name": profile.__class__.__name__,
            "store_name": profile.STORE_NAME,
            "name_patterns": list(profile.NAME_PATTERNS),
            "validation_hints": profile.get_validation_hints(),
        }

    # -------------------------------------------------------------------------
    # Hot-Reload
    # -------------------------------------------------------------------------

    @classmethod
    def reload_all(cls) -> int:
        """
        Hot-reload all profile modules.

        Clears the instance cache and reloads (or imports, for new files)
        every profile module found in the profiles directory. The
        ``register`` decorator re-registers classes with updated code.
        Afterwards, registrations left over from before the reload are
        removed when they are stale: the CUI was dropped from a reloaded
        module, or the profile file no longer exists.

        A module that fails to reload is logged and skipped; its previous
        registrations are kept.

        Returns:
            Number of modules reloaded
        """
        logger.info("Starting profile hot-reload...")

        # Clear instance cache (will be recreated on next get_profile)
        cls._instances.clear()

        # Snapshot taken before re-execution so leftovers can be detected.
        previous = dict(cls._profiles)

        # Get list of profile modules (exclude __init__, base)
        module_names = cls._get_profile_module_names()

        # Child modules share the import prefix under which this class's
        # own module was imported.
        base_package = cls.__module__

        reloaded = set()
        count = 0
        for module_name in module_names:
            full_name = f"{base_package}.{module_name}"

            try:
                if full_name in sys.modules:
                    # Reload existing module
                    importlib.reload(sys.modules[full_name])
                    logger.debug(f"Reloaded module: {module_name}")
                else:
                    # Import new module
                    importlib.import_module(full_name)
                    logger.debug(f"Imported new module: {module_name}")
                reloaded.add(full_name)
                count += 1
            except Exception as e:
                logger.error(f"Failed to reload {module_name}: {e}", exc_info=True)

        pruned = cls._prune_stale_profiles(previous, reloaded, base_package)
        if pruned:
            logger.info(f"Removed {pruned} stale profile registration(s)")

        # NOTE(review): utcnow() yields a naive datetime and is deprecated
        # since Python 3.12; kept because get_reload_status() exposes its
        # isoformat() output.
        cls._last_reload = datetime.utcnow()
        cls._loaded = True

        logger.info(f"Profile hot-reload complete: {count} modules, {len(cls._profiles)} profiles")
        return count

    @classmethod
    def _prune_stale_profiles(
        cls,
        previous: Dict[str, Type["BaseStoreProfile"]],
        reloaded: set,
        base_package: str,
    ) -> int:
        """
        Remove registrations that survived a reload without being refreshed.

        A successful reload creates new class objects, so an entry that
        still points at its pre-reload class was not re-registered. It is
        stale when the owning module was reloaded successfully, or when the
        owning module is a direct child of this package whose ``.py`` file
        is gone. Entries owned by modules that failed to reload, or by
        modules outside this package, are left alone.

        Args:
            previous: Snapshot of ``_profiles`` taken before the reload
            reloaded: Full names of modules reloaded/imported successfully
            base_package: Import prefix of the profiles package

        Returns:
            Number of registrations removed
        """
        prefix = f"{base_package}."
        removed = 0

        for cui, old_class in previous.items():
            if cls._profiles.get(cui) is not old_class:
                continue  # refreshed (or already replaced) during the reload

            owner = getattr(old_class, "__module__", "") or ""
            stale = owner in reloaded
            if not stale and owner.startswith(prefix):
                leaf = owner[len(prefix):]
                stale = "." not in leaf and not (PROFILES_DIR / f"{leaf}.py").exists()

            if stale:
                del cls._profiles[cui]
                cls._instances.pop(cui, None)
                removed += 1
                logger.debug(f"Unregistered stale profile {old_class.__name__} for CUI {cui}")

        return removed

    @classmethod
    def get_reload_status(cls) -> Dict:
        """Get status of the registry including last reload time."""
        return {
            "loaded": cls._loaded,
            "last_reload": cls._last_reload.isoformat() if cls._last_reload else None,
            "profiles_count": len(cls._profiles),
            "instances_count": len(cls._instances),
            "registered_cuis": list(cls._profiles.keys()),
        }

    # -------------------------------------------------------------------------
    # Internal methods
    # -------------------------------------------------------------------------

    @classmethod
    def _normalize_cui(cls, cui: str) -> str:
        """
        Normalize CUI for consistent lookup.

        - Strips surrounding whitespace
        - Converts to uppercase
        - Removes a leading RO prefix (with or without a following space)

        No other validation is performed: characters that are not digits
        are kept as they are.

        Args:
            cui: Raw CUI string

        Returns:
            Normalized CUI, or "" for an empty/None input
        """
        if not cui:
            return ""

        cui = str(cui).strip().upper()

        # Remove RO prefix (handles "RO12345" and "RO 12345")
        if cui.startswith("RO"):
            cui = cui[2:].lstrip()

        return cui.strip()

    @classmethod
    def _get_profile_module_names(cls) -> List[str]:
        """
        Get list of profile module names from profiles directory.

        Excludes __init__.py and base.py.

        Returns:
            Sorted list of module names (without .py extension)
        """
        excluded = {"__init__", "base", "__pycache__"}
        modules = []

        for path in PROFILES_DIR.glob("*.py"):
            name = path.stem
            if name not in excluded:
                modules.append(name)

        return sorted(modules)

    @classmethod
    def _load_all_profiles(cls) -> None:
        """
        Initial load of all profile modules.

        Called automatically by the lookup/listing methods if the registry
        has not been loaded yet. A module that fails to import is logged
        and skipped.
        """
        if cls._loaded:
            return

        logger.info("Loading store profiles...")

        module_names = cls._get_profile_module_names()

        # Determine the module prefix based on how THIS module was imported
        # This handles both:
        # - Running from backend dir: "modules.data_entry.services.ocr.profiles"
        # - Running from project root: "backend.modules.data_entry.services.ocr.profiles"
        this_module = cls.__module__  # e.g. "backend.modules..." or "modules..."
        base_package = this_module  # Use the same prefix for child modules

        for module_name in module_names:
            full_name = f"{base_package}.{module_name}"
            try:
                importlib.import_module(full_name)
                logger.debug(f"Loaded module: {module_name}")
            except Exception as e:
                logger.error(f"Failed to load {module_name}: {e}")

        cls._loaded = True
        cls._last_reload = datetime.utcnow()

        logger.info(f"Loaded {len(cls._profiles)} store profiles")

    @classmethod
    def clear(cls) -> None:
        """
        Clear all registered profiles.

        Mainly useful for testing.
        """
        cls._profiles.clear()
        cls._instances.clear()
        cls._loaded = False
        cls._last_reload = None
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Module exports
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# Public names of this package.
__all__ = [
|
||||
"ProfileRegistry",
|
||||
"BaseStoreProfile",
|
||||
]
|
||||
|
||||
# Re-export BaseStoreProfile for convenience.
# NOTE(review): presumably kept at the bottom so ProfileRegistry is already
# defined if importing .base leads back into this package - confirm before moving.
|
||||
from .base import BaseStoreProfile
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,655 @@
|
||||
"""
|
||||
Optimized Tesseract Engine for OCR - SPEED + QUALITY OPTIMIZED
|
||||
|
||||
Performance optimizations (vs previous version):
|
||||
- Single PSM mode (PSM 4) instead of multi-PSM (4 modes × 2 calls = 8x faster)
|
||||
- Single Tesseract call per image (skip image_to_data for speed)
|
||||
- Lighter preprocessing (no over-binarization)
|
||||
- --dpi 300 flag for proper scaling
|
||||
- OEM 3 (Tesseract's default: engine chosen based on what is available) for balanced speed/accuracy
|
||||
|
||||
Quality optimizations for Romanian receipts:
|
||||
- PSM 4: Single column layout (optimal for receipts)
|
||||
- Polarity correction: ensures black text on white background
|
||||
- Language: Romanian only (-l ron) for faster recognition
|
||||
- Fallback to PSM 6 if PSM 4 produces poor results
|
||||
|
||||
Previous issues fixed:
|
||||
- Was 8x slower than PaddleOCR due to multi-PSM + dual calls
|
||||
- Produced gibberish on clear PDFs due to over-binarization
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
# Check Tesseract availability.
# pytesseract is optional at import time: when the import fails the flag is
# False and the name is bound to None, so code must check TESSERACT_AVAILABLE
# before touching pytesseract.
|
||||
try:
|
||||
import pytesseract
|
||||
TESSERACT_AVAILABLE = True
|
||||
except ImportError:
|
||||
TESSERACT_AVAILABLE = False
|
||||
pytesseract = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class OCRResult:
    """Container for the raw output of one Tesseract OCR run."""

    # Recognized text as returned by the engine.
    text: str
    # Confidence score attached to ``text``.
    confidence: float
    # Box entries; defaults to a fresh empty list for every instance.
    boxes: List[dict] = field(default_factory=list)
    # Identifier of the engine that produced this result.
    engine: str = "tesseract"
|
||||
|
||||
|
||||
class TesseractEngine:
|
||||
"""
|
||||
Optimized Tesseract engine for receipt OCR.
|
||||
|
||||
TESTED OPTIMAL SETTINGS (from comprehensive benchmark):
|
||||
- DPI 200 for PDF loading (not 300!)
|
||||
- Padding 40px for edge protection
|
||||
- PSM 6 for complex receipts, PSM 4 for simple ones
|
||||
- Multi-pass strategy when quality is critical
|
||||
|
||||
SPEED vs QUALITY tradeoff:
|
||||
- Fast mode (single pass): ~0.9s, ~6-7 keywords
|
||||
- Quality mode (multi-pass): ~1.7s, ~8-9 keywords (+2 more keywords)
|
||||
|
||||
BENCHMARK RESULTS:
|
||||
- padded_psm6_40: Best for complex receipts (igiena, five-holding)
|
||||
- baseline_psm4: Best for simple receipts (rechizite, benzina)
|
||||
- multi-pass: Best overall quality but slower
|
||||
"""
|
||||
|
||||
# PSM modes for receipts
|
||||
PSM_SINGLE_COLUMN = 4 # Best for simple vertical receipts
|
||||
PSM_UNIFORM_BLOCK = 6 # Best for complex layouts
|
||||
PSM_SPARSE_TEXT = 11 # Fallback for difficult receipts
|
||||
|
||||
# Optimal padding (from benchmark)
|
||||
DEFAULT_PADDING = 40
|
||||
|
||||
def __init__(self):
    """
    Initialize Tesseract engine.

    Stores the detected Tesseract version on ``self._version`` and logs it.

    Raises:
        RuntimeError: If the pytesseract package could not be imported, or
            if querying the Tesseract version fails (binary missing or not
            in PATH). In the second case the original error is attached as
            the explicit cause.
    """
    if not TESSERACT_AVAILABLE:
        raise RuntimeError("pytesseract not available. Install with: pip install pytesseract")

    # Verify Tesseract installation
    try:
        self._version = pytesseract.get_tesseract_version()
    except Exception as exc:
        # Chain explicitly so the traceback shows the underlying failure as
        # the direct cause of this error.
        raise RuntimeError(f"Tesseract not installed or not in PATH: {exc}") from exc

    logger.info(f"[TesseractEngine] Initialized (v{self._version})")
|
||||
|
||||
def recognize(self, image: np.ndarray, fast_mode: bool = True) -> OCRResult:
    """
    Run OCR on an image with at most two Tesseract passes.

    The image is converted to grayscale when it has three dimensions,
    passed through the polarity correction, and recognized with the
    single-column PSM. When that pass yields blank text or a confidence
    below 0.3, the uniform-block PSM is tried as well and its result is
    used only if it produced strictly longer text.

    Args:
        image: Input image as a numpy array (2-D grayscale or 3-D color)
        fast_mode: Forwarded unchanged to ``_recognize_fast``

    Returns:
        OCRResult with text and confidence; an empty result when
        pytesseract is unavailable
    """
    if not TESSERACT_AVAILABLE:
        return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")

    # Grayscale first, then make sure text is dark on a light background.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image
    gray = self._ensure_correct_polarity(gray)

    # Primary pass: single-column layout.
    best = self._recognize_fast(gray, self.PSM_SINGLE_COLUMN, fast_mode)

    primary_is_poor = not best.text.strip() or best.confidence < 0.3
    if primary_is_poor:
        logger.debug(f"[Tesseract] PSM {self.PSM_SINGLE_COLUMN} poor result, trying PSM {self.PSM_UNIFORM_BLOCK}")
        alternative = self._recognize_fast(gray, self.PSM_UNIFORM_BLOCK, fast_mode)
        # Longer raw text wins; a tie keeps the primary pass.
        if len(alternative.text) > len(best.text):
            best = alternative

    if best.text.strip():
        logger.info(f"[TesseractEngine] Result: {len(best.text)} chars, conf={best.confidence:.0%}")

    return best
|
||||
|
||||
def _recognize_fast(self, image: np.ndarray, psm: int, fast_mode: bool = True) -> OCRResult:
    """
    Run Tesseract once for a given page segmentation mode.

    Tesseract is invoked with ``--oem 3`` (Tesseract's default engine
    selection), ``--dpi 300`` and Romanian only (``-l ron``).

    In fast mode only ``image_to_string`` is called and the confidence is
    a heuristic from ``_estimate_confidence``. Otherwise ``image_to_data``
    is called as well and the confidence is the mean of the positive
    per-entry ``conf`` values, scaled to 0..1.

    Args:
        image: Grayscale image
        psm: Page segmentation mode
        fast_mode: Skip the extra ``image_to_data`` call for speed

    Returns:
        OCRResult; an empty result (confidence 0.0) if Tesseract raises
    """
    # Build config:
    # --oem 3  = default engine mode (based on what is available)
    # --dpi 300 = resolution hint
    # -l ron   = Romanian only
    config = f'--psm {psm} --oem 3 --dpi 300 -l ron'

    try:
        if fast_mode:
            # Fast path: just get text, estimate confidence
            text = pytesseract.image_to_string(image, config=config)
            confidence = self._estimate_confidence(text)
        else:
            # Accurate path: get text + real confidence
            text = pytesseract.image_to_string(image, config=config)
            data = pytesseract.image_to_data(
                image, config=config, output_type=pytesseract.Output.DICT
            )
            # conf values may arrive as ints, floats or strings such as
            # '96.5' depending on the Tesseract/pytesseract versions;
            # int('96.5') would raise and discard the whole result, so
            # parse as float and skip anything unparsable.
            scores = []
            for raw in data['conf']:
                try:
                    value = float(raw)
                except (TypeError, ValueError):
                    continue
                if value > 0:
                    scores.append(value)
            confidence = sum(scores) / len(scores) / 100 if scores else 0.0

        return OCRResult(
            text=text,
            confidence=confidence,
            boxes=[],
            engine="tesseract"
        )

    except Exception as e:
        # Best effort: a failed pass is reported as an empty result so the
        # caller can fall back to another PSM.
        logger.warning(f"[Tesseract] PSM {psm} error: {e}")
        return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
|
||||
|
||||
def _estimate_confidence(self, text: str) -> float:
|
||||
"""
|
||||
Estimate OCR confidence based on text quality.
|
||||
|
||||
Heuristics:
|
||||
- More alphanumeric chars = higher confidence
|
||||
- Less garbage chars = higher confidence
|
||||
- Romanian-specific patterns boost confidence
|
||||
"""
|
||||
if not text.strip():
|
||||
return 0.0
|
||||
|
||||
# Count valid vs garbage chars
|
||||
valid_chars = sum(1 for c in text if c.isalnum() or c in '.,;:-/\n ')
|
||||
total_chars = len(text)
|
||||
|
||||
if total_chars == 0:
|
||||
return 0.0
|
||||
|
||||
# Base confidence from char ratio
|
||||
confidence = valid_chars / total_chars
|
||||
|
||||
# Boost for Romanian receipt patterns
|
||||
text_lower = text.lower()
|
||||
if any(word in text_lower for word in ['total', 'lei', 'ron', 'buc', 'tva', 'cif', 'bon']):
|
||||
confidence = min(confidence + 0.1, 1.0)
|
||||
|
||||
return confidence
|
||||
|
||||
def recognize_multipass(self, image: np.ndarray) -> OCRResult:
    """
    Run three differently configured Tesseract passes and keep the best one.

    Passes, tried in this order:
    - pass1_psm4: PSM 4, no padding, CLAHE clip limit 1.5
    - pass2_psm6_padded: PSM 6, 40px white border, CLAHE clip limit 1.5
    - pass3_psm11: PSM 11, 40px white border, CLAHE clip limit 2.0

    Each pass is scored with 10 points per receipt keyword found, plus 15
    if a CIF pattern is present and 10 if a standalone TOTAL is present.
    The first pass reaching the highest score wins. Its text is returned
    as-is (the passes are not merged), with a heuristic confidence from
    _estimate_confidence(). A pass whose OCR step raises is logged and
    skipped.

    Args:
        image: Input image. A 3-dimensional array is converted to
            grayscale; anything else is copied and used directly.

    Returns:
        OCRResult of the winning pass, with engine set to
        "tesseract-multipass-<pass name>", or an empty OCRResult when
        Tesseract is unavailable or no pass produced a result.
    """
    if not TESSERACT_AVAILABLE:
        return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")

    # NOTE(review): the conversion assumes BGR channel order, while other
    # entry points in this file describe RGB input - confirm with callers.
    if len(image.shape) == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    else:
        gray = image.copy()

    # Words whose presence suggests the pass read the receipt correctly.
    keywords = ['cif', 'total', 'tva', 'lei', 'ron', 'buc', 'fiscal', 'bon',
                'hartie', 'prosop', 'saci', 'creion', 'constanta', 'bucuresti']

    passes = [
        # Fast baseline for simple receipts.
        {"name": "pass1_psm4", "psm": 4, "padding": 0, "clahe_clip": 1.5},
        # Padded uniform-block mode for complex layouts.
        {"name": "pass2_psm6_padded", "psm": 6, "padding": 40, "clahe_clip": 1.5},
        # Sparse-text mode with stronger contrast enhancement.
        {"name": "pass3_psm11", "psm": 11, "padding": 40, "clahe_clip": 2.0},
    ]

    best_result = None
    best_score = -1
    all_keywords = set()

    for p in passes:
        border = p["padding"]
        if border > 0:
            # White border so text touching the edge is not clipped.
            processed = cv2.copyMakeBorder(
                gray.copy(), border, border, border, border,
                cv2.BORDER_CONSTANT, value=255
            )
        else:
            processed = gray.copy()

        equalizer = cv2.createCLAHE(clipLimit=p["clahe_clip"], tileGridSize=(8, 8))
        processed = self._ensure_correct_polarity(equalizer.apply(processed))

        config = f'--psm {p["psm"]} --oem 3 -l ron'
        try:
            text = pytesseract.image_to_string(processed, config=config)
            confidence = self._estimate_confidence(text)

            lowered = text.lower()
            found_keywords = [kw for kw in keywords if kw in lowered]
            all_keywords.update(found_keywords)

            # Keyword hits plus bonuses for the two fields that matter most.
            score = len(found_keywords) * 10
            if self._has_cif_pattern(text):
                score += 15
            if self._has_total_pattern(text):
                score += 10

            logger.debug(f"[Tesseract] {p['name']}: {len(found_keywords)} keywords, score={score}")

            # Strict comparison: on a tie the earlier pass is kept.
            if score > best_score:
                best_score = score
                best_result = OCRResult(
                    text=text,
                    confidence=confidence,
                    boxes=[],
                    engine=f"tesseract-multipass-{p['name']}"
                )

        except Exception as e:
            logger.warning(f"[Tesseract] {p['name']} failed: {e}")
            continue

    if not best_result:
        return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract-multipass")

    logger.info(f"[TesseractEngine] Multi-pass best: {best_result.engine}, "
                f"{len(all_keywords)} total keywords found")
    return best_result
|
||||
|
||||
def _has_cif_pattern(self, text: str) -> bool:
|
||||
"""Check if text contains a valid CIF/CUI pattern."""
|
||||
import re
|
||||
text_upper = text.upper()
|
||||
patterns = [
|
||||
r'CIF[:\s]*RO?\d{6,10}',
|
||||
r'CUI[:\s]*RO?\d{6,10}',
|
||||
r'C\.?I\.?F\.?[:\s]*RO?\d{6,10}',
|
||||
]
|
||||
for pattern in patterns:
|
||||
if re.search(pattern, text_upper):
|
||||
return True
|
||||
return bool(re.search(r'RO\d{7,10}', text_upper))
|
||||
|
||||
def _has_total_pattern(self, text: str) -> bool:
|
||||
"""Check if TOTAL is properly recognized (not truncated to BTOTAL/OTAL)."""
|
||||
import re
|
||||
text_upper = text.upper()
|
||||
return bool(re.search(r'(^|\s)TOTAL\s', text_upper, re.MULTILINE))
|
||||
|
||||
def recognize_with_boxes(self, image: np.ndarray, psm: int = 4) -> OCRResult:
    """
    Recognize text and also return per-word bounding boxes.

    This runs Tesseract twice (once for the text, once for word-level
    data), so it is slower than text-only recognition. Use it only when
    box coordinates are needed, e.g. for debugging or visualization.

    Args:
        image: Grayscale image; a 3-dimensional array is converted to
            grayscale first.
        psm: Tesseract page segmentation mode (default 4).

    Returns:
        OCRResult with the full text, the mean word confidence scaled to
        0.0-1.0, and one box dict per recognized word with keys "text",
        "confidence" and "box" ([left, top, width, height]). An empty
        OCRResult is returned when Tesseract is unavailable or the OCR
        call fails.
    """
    if not TESSERACT_AVAILABLE:
        return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")

    if len(image.shape) == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    image = self._ensure_correct_polarity(image)
    config = f'--psm {psm} --oem 3 --dpi 300 -l ron'

    try:
        text = pytesseract.image_to_string(image, config=config)
        data = pytesseract.image_to_data(
            image, config=config, output_type=pytesseract.Output.DICT
        )

        # Confidence values may arrive as ints or as strings holding a
        # fractional number (e.g. "96.06"), depending on the pytesseract and
        # Tesseract versions. float() accepts both, whereas int() raises on
        # the latter and would silently discard the whole result.
        word_confidences = [float(raw) for raw in data['conf']]

        # Non-positive confidence marks layout rows (blocks, lines), not words.
        positive = [value for value in word_confidences if value > 0]
        avg_conf = sum(positive) / len(positive) / 100 if positive else 0.0

        boxes = [
            {
                'text': word,
                'confidence': value / 100,
                'box': [left, top, width, height],
            }
            for word, value, left, top, width, height in zip(
                data['text'], word_confidences,
                data['left'], data['top'], data['width'], data['height'],
            )
            if word.strip() and value > 0
        ]

        return OCRResult(text=text, confidence=avg_conf, boxes=boxes, engine="tesseract")

    except Exception as e:
        logger.warning(f"[Tesseract] recognize_with_boxes error: {e}")
        return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
|
||||
|
||||
def _ensure_correct_polarity(self, image: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Ensure image has black text on white background.
|
||||
|
||||
Receipts should have dark text on light background.
|
||||
If image is inverted (light text on dark), invert it.
|
||||
|
||||
Detection method:
|
||||
- Calculate mean pixel value
|
||||
- If mean < 127, image is mostly dark (inverted)
|
||||
- Invert to correct polarity
|
||||
|
||||
Args:
|
||||
image: Grayscale image
|
||||
|
||||
Returns:
|
||||
Polarity-corrected image
|
||||
"""
|
||||
mean_value = np.mean(image)
|
||||
|
||||
if mean_value < 127:
|
||||
# Image is mostly dark = inverted (white text on black)
|
||||
logger.debug(f"[TesseractEngine] Detected inverted polarity (mean={mean_value:.1f}), correcting...")
|
||||
return 255 - image
|
||||
|
||||
return image
|
||||
|
||||
def recognize_numbers_only(self, image: np.ndarray) -> OCRResult:
    """
    Recognize numeric content such as amounts and totals.

    A character whitelist restricts Tesseract to digits and number
    punctuation, which reduces misreads on numeric fields.

    Args:
        image: Preprocessed grayscale image; a 3-dimensional array is
            converted to grayscale first.

    Returns:
        OCRResult with the stripped text and the mean word confidence
        scaled to 0.0-1.0, using engine "tesseract-numeric". An empty
        OCRResult with engine "tesseract" is returned when Tesseract is
        unavailable or the OCR call fails.
    """
    if not TESSERACT_AVAILABLE:
        return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")

    if len(image.shape) == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    image = self._ensure_correct_polarity(image)

    # Whitelist: digits, period, comma and hyphen. Letters (e.g. "RON",
    # "LEI") are deliberately not included.
    # NOTE(review): the trailing space is probably dropped when the config
    # string is split into arguments - confirm if spaces must be whitelisted.
    config = '--psm 6 --oem 1 -c tessedit_char_whitelist=0123456789.,- '

    try:
        text = pytesseract.image_to_string(image, config=config)

        data = pytesseract.image_to_data(
            image,
            config=config,
            output_type=pytesseract.Output.DICT
        )

        # Confidence values may be ints or strings holding a fractional
        # number (e.g. "96.06"), depending on library versions. float()
        # handles both, whereas int() raises on the latter.
        confidences = [value for value in map(float, data['conf']) if value > 0]
        avg_conf = sum(confidences) / len(confidences) / 100 if confidences else 0.0

        return OCRResult(
            text=text.strip(),
            confidence=avg_conf,
            boxes=[],
            engine="tesseract-numeric"
        )

    except Exception as e:
        logger.error(f"[TesseractEngine] Numeric OCR error: {e}")
        return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
|
||||
|
||||
def recognize_cif_optimized(self, image: np.ndarray) -> Optional[str]:
    """
    Extract the CIF from the top of a receipt using three OCR strategies.

    Only the top 35% of the image is examined. Each strategy runs
    independently; a failing strategy is logged at debug level and ignored.

    Strategies, in order:
    1. digit_opt: 2x upscale, CLAHE (clip 3.0), Otsu binarization, polarity
       fix, OCR with "-l ron".
    2. whitelist: 40px white padding, 2x upscale, OCR restricted to digits
       and the letters R/O.
    3. ron_eng: 40px white padding, CLAHE (clip 1.5), OCR with
       "-l ron+eng".

    Selection: a CIF found by more than one strategy wins. Otherwise the
    digit_opt candidate is preferred, and failing that the first candidate
    found.

    Args:
        image: Input image; a 3-dimensional array is converted to
            grayscale assuming RGB channel order.

    Returns:
        Extracted CIF string (e.g. "RO10562600"), or None when Tesseract is
        unavailable or no strategy found a CIF.
    """
    from collections import Counter

    if not TESSERACT_AVAILABLE:
        return None

    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) if len(image.shape) == 3 else image.copy()

    # Restrict OCR to the receipt header.
    header_rows = int(gray.shape[0] * 0.35)
    top_region = gray[:header_rows, :]

    candidates = []

    # Strategy 1: digit-optimized preprocessing.
    try:
        upscaled = cv2.resize(top_region, None, fx=2.0, fy=2.0, interpolation=cv2.INTER_CUBIC)
        contrast = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8)).apply(upscaled)
        _, binary = cv2.threshold(contrast, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        # Keep dark text on a light background after thresholding.
        if np.mean(binary) < 127:
            binary = 255 - binary

        found = self._extract_cif_from_text(
            pytesseract.image_to_string(binary, config='--psm 6 --oem 3 -l ron')
        )
        if found:
            candidates.append(('digit_opt', found))
    except Exception as e:
        logger.debug(f"[TesseractEngine] digit_opt strategy failed: {e}")

    # Strategy 2: character whitelist.
    try:
        bordered = cv2.copyMakeBorder(top_region, 40, 40, 40, 40, cv2.BORDER_CONSTANT, value=255)
        upscaled = cv2.resize(bordered, None, fx=2.0, fy=2.0, interpolation=cv2.INTER_CUBIC)

        config = '--psm 6 --oem 1 -c tessedit_char_whitelist=0123456789ROro'
        found = self._extract_cif_from_text(pytesseract.image_to_string(upscaled, config=config))
        if found:
            candidates.append(('whitelist', found))
    except Exception as e:
        logger.debug(f"[TesseractEngine] whitelist strategy failed: {e}")

    # Strategy 3: general-purpose Romanian + English configuration.
    try:
        bordered = cv2.copyMakeBorder(top_region, 40, 40, 40, 40, cv2.BORDER_CONSTANT, value=255)
        contrast = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(8, 8)).apply(bordered)

        found = self._extract_cif_from_text(
            pytesseract.image_to_string(contrast, config='--psm 6 --oem 3 -l ron+eng')
        )
        if found:
            candidates.append(('ron_eng', found))
    except Exception as e:
        logger.debug(f"[TesseractEngine] ron_eng strategy failed: {e}")

    if not candidates:
        return None

    for strategy, cif in candidates:
        logger.debug(f"[TesseractEngine] CIF candidate from {strategy}: {cif}")

    # Agreement between strategies is the strongest signal.
    most_common_cif, count = Counter(cif for _, cif in candidates).most_common(1)[0]
    if count > 1:
        logger.info(f"[TesseractEngine] CIF extracted (majority {count} strategies): {most_common_cif}")
        return most_common_cif

    # No agreement. digit_opt runs first, so when it produced a candidate
    # that candidate is always at index 0; either way the first one is used.
    strategy, cif = candidates[0]
    if strategy == 'digit_opt':
        logger.info(f"[TesseractEngine] CIF extracted via digit_opt (preferred): {cif}")
    else:
        logger.info(f"[TesseractEngine] CIF extracted via {strategy}: {cif}")
    return cif
|
||||
|
||||
def _extract_cif_from_text(self, text: str) -> Optional[str]:
|
||||
"""Extract CIF/CUI from OCR text."""
|
||||
import re
|
||||
text_upper = text.upper().replace(' ', '')
|
||||
|
||||
patterns = [
|
||||
r'CIF[:\s]*R?O?(\d{6,10})',
|
||||
r'CUI[:\s]*R?O?(\d{6,10})',
|
||||
r'C\.?I\.?F\.?[:\s]*R?O?(\d{6,10})',
|
||||
r'RO(\d{7,10})',
|
||||
r'R\.?O\.?[\s:]*(\d{6,10})',
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, text_upper)
|
||||
if match:
|
||||
digits = match.group(1).lstrip('0') or '0'
|
||||
return f"RO{digits}"
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def validate_romanian_cif(cif: str) -> bool:
|
||||
"""
|
||||
Validate Romanian CIF/CUI using checksum algorithm.
|
||||
|
||||
Romanian CIF format: RO + 2-10 digits
|
||||
The last digit is a control digit calculated using modulo 11.
|
||||
|
||||
Algorithm:
|
||||
1. Multiply each digit by corresponding weight (from right to left: 2,3,4,5,6,7,2,3,4,5)
|
||||
2. Sum all products
|
||||
3. Remainder of sum / 11 is the control digit
|
||||
4. If remainder is 10, control digit is 0
|
||||
|
||||
Args:
|
||||
cif: CIF string (e.g., "RO10562600", "10562600")
|
||||
|
||||
Returns:
|
||||
True if CIF is valid, False otherwise
|
||||
"""
|
||||
# Remove RO prefix and spaces
|
||||
cif = cif.upper().replace(' ', '').replace('RO', '')
|
||||
|
||||
# Must be 2-10 digits
|
||||
if not cif.isdigit() or len(cif) < 2 or len(cif) > 10:
|
||||
return False
|
||||
|
||||
# Weights for checksum calculation (right to left)
|
||||
weights = [2, 3, 4, 5, 6, 7, 2, 3, 4, 5]
|
||||
|
||||
# Pad with zeros on the left to make it 10 digits
|
||||
cif_padded = cif.zfill(10)
|
||||
|
||||
# Calculate checksum (excluding last digit which is control)
|
||||
total = 0
|
||||
for i in range(9):
|
||||
total += int(cif_padded[i]) * weights[i]
|
||||
|
||||
# Control digit
|
||||
control = total % 11
|
||||
if control == 10:
|
||||
control = 0
|
||||
|
||||
# Compare with last digit
|
||||
return int(cif_padded[9]) == control
|
||||
|
||||
@staticmethod
def is_available() -> bool:
    """
    Report whether Tesseract can actually be used.

    Returns:
        True only if the module-level availability flag is set and
        querying the Tesseract version succeeds; False otherwise.
    """
    if TESSERACT_AVAILABLE:
        # The flag alone is not enough: probing the version confirms the
        # call into Tesseract really works.
        try:
            pytesseract.get_tesseract_version()
        except Exception:
            return False
        return True

    return False
|
||||
|
||||
@staticmethod
def get_version() -> Optional[str]:
    """
    Return the Tesseract version as a string.

    Returns:
        The version string, or None when the module-level availability
        flag is unset or the version query fails.
    """
    if not TESSERACT_AVAILABLE:
        return None

    try:
        version_text = str(pytesseract.get_tesseract_version())
    except Exception:
        version_text = None
    return version_text
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user