fix telegram

This commit is contained in:
Claude Agent
2026-02-23 15:12:33 +00:00
parent 6c78fec8a7
commit 8bc567a9c5
426 changed files with 112478 additions and 1 deletions

View File

@@ -0,0 +1,16 @@
# Business logic services
from .receipt_service import ReceiptService
from .nomenclature_service import NomenclatureService
from .expense_types import EXPENSE_TYPES, ExpenseType
from .receipt_auto_create import ReceiptAutoCreateService, ReceiptCreateResult
from . import sse_service
__all__ = [
"ReceiptService",
"NomenclatureService",
"EXPENSE_TYPES",
"ExpenseType",
"ReceiptAutoCreateService",
"ReceiptCreateResult",
"sse_service",
]

View File

@@ -0,0 +1,215 @@
"""
Cleanup service for auto-deleting expired failed receipts.
US-008: Backend - Auto-Cleanup Erori După 7 Zile
- Finds receipts with processing_status='failed' and processing_completed_at < now() - 7 days
- Deletes the receipts and their attached files from storage
- Runs at startup and then daily as a background task
"""
import asyncio
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional
from sqlalchemy import select, and_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from backend.modules.data_entry.db.models.receipt import Receipt, ReceiptAttachment
from backend.modules.data_entry.config import settings
logger = logging.getLogger(__name__)
# Cleanup configuration
CLEANUP_RETENTION_DAYS = 7
CLEANUP_INTERVAL_HOURS = 24
# In-memory storage for last cleanup stats (optional - for login notification)
_last_cleanup_stats: dict = {
"count": 0,
"timestamp": None
}
def get_last_cleanup_stats() -> dict:
"""Get stats from the last cleanup run for notification purposes."""
return _last_cleanup_stats.copy()
async def cleanup_expired_failed_receipts(session: AsyncSession) -> int:
"""
Find and delete receipts with processing_status='failed' older than 7 days.
This function:
1. Queries for failed receipts where processing_completed_at < now() - 7 days
2. Deletes attachment files from disk
3. Deletes the receipt records (cascade deletes attachment records)
Args:
session: AsyncSession for database operations
Returns:
Number of receipts deleted
"""
global _last_cleanup_stats
cutoff_date = datetime.utcnow() - timedelta(days=CLEANUP_RETENTION_DAYS)
# Find expired failed receipts with their attachments
query = select(Receipt).options(
selectinload(Receipt.attachments)
).where(
and_(
Receipt.processing_status == "failed",
Receipt.processing_completed_at.isnot(None),
Receipt.processing_completed_at < cutoff_date
)
)
result = await session.execute(query)
expired_receipts = result.scalars().all()
if not expired_receipts:
logger.debug("[Cleanup] No expired failed receipts found")
return 0
deleted_count = 0
deleted_files = 0
upload_base_path = settings.upload_path_resolved
for receipt in expired_receipts:
try:
# Delete attachment files from disk
for attachment in receipt.attachments:
file_path = upload_base_path / attachment.file_path
if file_path.exists():
try:
file_path.unlink()
deleted_files += 1
logger.debug(f"[Cleanup] Deleted file: {file_path}")
except OSError as e:
logger.warning(f"[Cleanup] Failed to delete file {file_path}: {e}")
# Also try to clean up empty parent directories
parent_dir = file_path.parent
if parent_dir.exists() and parent_dir != upload_base_path:
try:
# Only remove if directory is empty
if not any(parent_dir.iterdir()):
parent_dir.rmdir()
logger.debug(f"[Cleanup] Removed empty directory: {parent_dir}")
except OSError:
pass # Directory not empty or permission issue, skip
# Delete receipt (cascade deletes attachment records in DB)
await session.delete(receipt)
deleted_count += 1
except Exception as e:
logger.error(f"[Cleanup] Error deleting receipt {receipt.id}: {e}")
continue
# Commit all deletions
if deleted_count > 0:
await session.commit()
# Update stats for notification
_last_cleanup_stats = {
"count": deleted_count,
"files_deleted": deleted_files,
"timestamp": datetime.utcnow().isoformat()
}
logger.info(f"[Cleanup] Cleaned up {deleted_count} expired failed receipts ({deleted_files} files)")
return deleted_count
async def run_cleanup_task(get_session_func) -> None:
"""
Background task that runs cleanup at startup and then every 24 hours.
Args:
get_session_func: Async generator function that yields database sessions
"""
logger.info("[Cleanup] Starting cleanup background task")
# Run immediately at startup
try:
async for session in get_session_func():
count = await cleanup_expired_failed_receipts(session)
if count > 0:
logger.info(f"[Cleanup] Initial cleanup: {count} receipts removed")
break
except Exception as e:
logger.error(f"[Cleanup] Initial cleanup failed: {e}")
# Then run every 24 hours
while True:
try:
await asyncio.sleep(CLEANUP_INTERVAL_HOURS * 3600)
async for session in get_session_func():
count = await cleanup_expired_failed_receipts(session)
if count > 0:
logger.info(f"[Cleanup] Daily cleanup: {count} receipts removed")
break
except asyncio.CancelledError:
logger.info("[Cleanup] Cleanup task cancelled")
raise
except Exception as e:
logger.error(f"[Cleanup] Daily cleanup failed: {e}")
# Continue running even if one cleanup fails
# Global reference to cleanup task for graceful shutdown
_cleanup_task: Optional[asyncio.Task] = None
async def start_cleanup_task(get_session_func) -> bool:
"""
Start the cleanup background task.
Args:
get_session_func: Async generator function that yields database sessions
Returns:
True if task started successfully, False otherwise
"""
global _cleanup_task
if _cleanup_task is not None and not _cleanup_task.done():
logger.warning("[Cleanup] Cleanup task already running")
return False
try:
_cleanup_task = asyncio.create_task(run_cleanup_task(get_session_func))
logger.info("[Cleanup] ✅ Cleanup background task started")
return True
except Exception as e:
logger.error(f"[Cleanup] Failed to start cleanup task: {e}")
return False
async def stop_cleanup_task() -> None:
"""Stop the cleanup background task gracefully."""
global _cleanup_task
if _cleanup_task is not None and not _cleanup_task.done():
_cleanup_task.cancel()
try:
await _cleanup_task
except asyncio.CancelledError:
pass
logger.info("[Cleanup] Cleanup task stopped")
_cleanup_task = None
def is_cleanup_task_running() -> bool:
"""Check if the cleanup task is currently running."""
return _cleanup_task is not None and not _cleanup_task.done()

View File

@@ -0,0 +1,101 @@
"""Predefined expense types for automatic accounting entry generation."""
from decimal import Decimal
from dataclasses import dataclass
from typing import Dict, Optional
@dataclass
class ExpenseType:
"""Expense type definition with accounting configuration."""
code: str
name: str
account_code: str
account_name: str
has_vat: bool
vat_percent: Decimal = Decimal("19")
vat_account: str = "4426"
# Predefined expense types
EXPENSE_TYPES: Dict[str, ExpenseType] = {
"FUEL": ExpenseType(
code="FUEL",
name="Combustibil",
account_code="6022",
account_name="Cheltuieli cu combustibilii",
has_vat=True,
),
"MATERIALS": ExpenseType(
code="MATERIALS",
name="Materiale consumabile",
account_code="6028",
account_name="Alte cheltuieli cu materiale consumabile",
has_vat=True,
),
"OFFICE": ExpenseType(
code="OFFICE",
name="Rechizite birou",
account_code="6024",
account_name="Cheltuieli privind materialele pentru ambalat",
has_vat=True,
),
"PHONE": ExpenseType(
code="PHONE",
name="Telefonie / Internet",
account_code="626",
account_name="Cheltuieli postale si taxe de telecomunicatii",
has_vat=True,
),
"PARKING": ExpenseType(
code="PARKING",
name="Parcare",
account_code="6022",
account_name="Cheltuieli cu combustibilii",
has_vat=True,
),
"FOOD": ExpenseType(
code="FOOD",
name="Alimentatie",
account_code="6028",
account_name="Alte cheltuieli cu materiale consumabile",
has_vat=False, # No deductible VAT for food
),
"TRANSPORT": ExpenseType(
code="TRANSPORT",
name="Transport",
account_code="624",
account_name="Cheltuieli cu transportul de bunuri si personal",
has_vat=True,
),
"OTHER": ExpenseType(
code="OTHER",
name="Altele",
account_code="628",
account_name="Alte cheltuieli cu serviciile executate de terti",
has_vat=True,
),
}
def get_expense_type(code: str) -> Optional[ExpenseType]:
"""Get expense type by code."""
return EXPENSE_TYPES.get(code)
def get_all_expense_types() -> Dict[str, ExpenseType]:
"""Get all expense types."""
return EXPENSE_TYPES.copy()
# Default cash register accounts
CASH_REGISTER_ACCOUNTS = {
"CASA": {
"code": "5311",
"name": "Casa in lei",
},
"BANCA": {
"code": "5121",
"name": "Conturi la banci in lei",
},
}

View File

@@ -0,0 +1,366 @@
"""Image preprocessing for optimal OCR results."""
from pathlib import Path
from typing import List
import numpy as np
import cv2
try:
import pdf2image
PDF_AVAILABLE = True
except ImportError:
PDF_AVAILABLE = False
class ImagePreprocessor:
"""Preprocess receipt images for OCR."""
def _add_safety_padding(self, image: np.ndarray, padding: int = 50) -> np.ndarray:
"""Add white padding around image to protect edge content during rotation.
This prevents left/right margin truncation in OCR by ensuring text near
edges isn't lost during deskew rotation.
"""
if len(image.shape) == 2:
# Grayscale
return cv2.copyMakeBorder(
image, padding, padding, padding, padding,
cv2.BORDER_CONSTANT, value=255
)
else:
# Color (BGR)
return cv2.copyMakeBorder(
image, padding, padding, padding, padding,
cv2.BORDER_CONSTANT, value=(255, 255, 255)
)
def load_image(self, path: Path) -> np.ndarray:
"""Load image from file."""
image = cv2.imread(str(path))
if image is None:
raise ValueError(f"Could not load image: {path}")
return image
def pdf_to_images(self, path: Path, dpi: int = 300) -> List[np.ndarray]:
"""
Convert PDF to images.
Args:
path: Path to PDF file
dpi: Resolution (300 = fast & good quality, 400 = better but slower)
"""
if not PDF_AVAILABLE:
raise RuntimeError("pdf2image not available. Install with: pip install pdf2image")
images = pdf2image.convert_from_path(str(path), dpi=dpi)
return [np.array(img) for img in images]
def preprocess(self, image: np.ndarray, high_quality: bool = True) -> np.ndarray:
"""
Apply LIGHT preprocessing - better for clear PDFs.
Heavy binarization can destroy text on clear images.
"""
return self.preprocess_light(image)
def preprocess_light(self, image: np.ndarray) -> np.ndarray:
"""
Light preprocessing for CLEAR images (PDFs, good scans).
Preserves original quality, only enhances contrast.
"""
# 0. Add safety padding to protect edge content during deskew rotation
image = self._add_safety_padding(image)
# 1. Grayscale
if len(image.shape) == 3:
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
else:
gray = image.copy()
# 2a. Scale DOWN if any side exceeds 4000px (PaddleOCR limit)
height, width = gray.shape
max_side = max(height, width)
if max_side > 4000:
scale = 4000 / max_side
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
height, width = gray.shape
# 2b. Scale UP if too small
if width < 1500:
scale = 1500 / width
# Ensure we don't exceed 4000px after upscaling
new_width = int(width * scale)
new_height = int(height * scale)
if max(new_width, new_height) > 4000:
scale = 4000 / max(new_width, new_height)
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
# 3. Deskew
gray = self._deskew(gray)
# 4. Light contrast enhancement only
clahe = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(8, 8))
enhanced = clahe.apply(gray)
# NO binarization, NO morphological ops - preserve original quality
return enhanced
def preprocess_medium(self, image: np.ndarray) -> np.ndarray:
"""
Medium preprocessing for MIXED-QUALITY images.
Balance between Light (too gentle) and Heavy (too aggressive).
Use cases:
- Moderately faded receipts
- Photos with uneven lighting
- Scans with slight blur
Preprocessing steps:
- Moderate contrast enhancement (CLAHE clipLimit=2.0)
- Light denoising (fastNlMeansDenoising h=6)
- Gentle sharpening
- NO binarization (preserves text boundaries)
- NO morphological operations (avoids digit concatenation)
This method was created to replace preprocess_heavy() which caused
digit concatenation errors on high-quality PDFs (85.99 → 859,762.16).
"""
# 0. Add safety padding to protect edge content during deskew rotation
image = self._add_safety_padding(image)
# 1. Grayscale
if len(image.shape) == 3:
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
else:
gray = image.copy()
# 2a. Scale DOWN if any side exceeds 4000px (PaddleOCR limit)
height, width = gray.shape
max_side = max(height, width)
if max_side > 4000:
scale = 4000 / max_side
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
height, width = gray.shape
# 2b. Scale UP if too small
if width < 1500:
scale = 1500 / width
# Ensure we don't exceed 4000px after upscaling
new_width = int(width * scale)
new_height = int(height * scale)
if max(new_width, new_height) > 4000:
scale = 4000 / max(new_width, new_height)
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
# 3. Deskew
gray = self._deskew(gray)
# 4. Moderate contrast enhancement (CLAHE clipLimit=2.0)
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
enhanced = clahe.apply(gray)
# 5. Light denoising (less aggressive than Heavy)
denoised = cv2.fastNlMeansDenoising(enhanced, h=6, templateWindowSize=7, searchWindowSize=15)
# 6. Gentle sharpening
gaussian = cv2.GaussianBlur(denoised, (0, 0), 1.0)
sharpened = cv2.addWeighted(denoised, 1.3, gaussian, -0.3, 0)
# NO binarization, NO morphological operations
# This preserves text boundaries and avoids digit concatenation
return sharpened
def preprocess_heavy(self, image: np.ndarray) -> np.ndarray:
"""
Heavy preprocessing for FADED thermal receipts.
Aggressive binarization to recover faded text.
⚠️ DEPRECATED: Use preprocess_medium() instead.
Heavy preprocessing causes digit concatenation on clear PDFs
(e.g., 85.99 → 859,762.16 due to binarization + morphological operations).
Kept for backward compatibility only.
"""
# 0. Add safety padding to protect edge content during deskew rotation
image = self._add_safety_padding(image)
# 1. Grayscale
if len(image.shape) == 3:
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
else:
gray = image.copy()
# 2a. Scale DOWN if any side exceeds 4000px (PaddleOCR limit)
height, width = gray.shape
max_side = max(height, width)
if max_side > 4000:
scale = 4000 / max_side
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
height, width = gray.shape
# 2b. Scale UP if too small (larger = better OCR)
if width < 1500:
scale = 1500 / width
# Ensure we don't exceed 4000px after upscaling
new_width = int(width * scale)
new_height = int(height * scale)
if max(new_width, new_height) > 4000:
scale = 4000 / max(new_width, new_height)
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
# 3. Deskew
gray = self._deskew(gray)
# 4. Contrast enhancement with CLAHE
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
enhanced = clahe.apply(gray)
# 5. Denoise
denoised = cv2.fastNlMeansDenoising(enhanced, h=8, templateWindowSize=7, searchWindowSize=21)
# 6. Sharpening
gaussian = cv2.GaussianBlur(denoised, (0, 0), 2.0)
sharpened = cv2.addWeighted(denoised, 1.5, gaussian, -0.5, 0)
# 7. Adaptive thresholding (binarization)
binary = cv2.adaptiveThreshold(
sharpened, 255,
cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
cv2.THRESH_BINARY,
blockSize=11, C=5
)
# 8. Morphological operations
kernel_close = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
result = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel_close)
return result
def preprocess_for_tesseract(self, image: np.ndarray, binarize: bool = False,
padding: int = 0, clahe_clip: float = 1.5) -> np.ndarray:
"""
Tesseract-optimized preprocessing (based on comprehensive benchmark).
BENCHMARK FINDINGS:
- DPI 200 is optimal (not 300!)
- Padding 40px fixes left margin truncation issues
- CLAHE 1.5 for most receipts, 2.0 for difficult ones
- NO deskew, NO denoising for clear PDFs
Recommended usage:
- Simple receipts: padding=0, clahe_clip=1.5
- Complex receipts: padding=40, clahe_clip=1.5
- Difficult/faded: padding=40, clahe_clip=2.0, binarize=True
Args:
image: Input image (RGB from pdf2image or BGR from OpenCV)
binarize: Apply Otsu binarization (for faded receipts)
padding: White padding in pixels (40px recommended for edge protection)
clahe_clip: CLAHE clip limit (1.5 normal, 2.0 for difficult)
Returns:
Preprocessed grayscale image
"""
# 1. Grayscale (handle both RGB and BGR)
if len(image.shape) == 3:
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
else:
gray = image.copy()
# 2. Add padding if specified (protects against left margin truncation)
if padding > 0:
gray = cv2.copyMakeBorder(
gray, padding, padding, padding, padding,
cv2.BORDER_CONSTANT, value=255
)
# 3. CLAHE contrast enhancement
clahe = cv2.createCLAHE(clipLimit=clahe_clip, tileGridSize=(8, 8))
enhanced = clahe.apply(gray)
# NO deskew, NO denoising - these DEGRADE quality on clear PDFs!
if not binarize:
return enhanced
# Binarization only for faded receipts
_, binary = cv2.threshold(enhanced, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
# Ensure correct polarity
if np.mean(binary) < 127:
binary = 255 - binary
return binary
def preprocess_for_tesseract_padded(self, image: np.ndarray) -> np.ndarray:
"""
Tesseract preprocessing with optimal padding (40px).
Best for complex receipts where left margin gets truncated.
"""
return self.preprocess_for_tesseract(image, padding=40)
def preprocess_for_tesseract_faded(self, image: np.ndarray) -> np.ndarray:
"""
Tesseract preprocessing for FADED thermal receipts.
Uses binarization to recover faded text.
"""
return self.preprocess_for_tesseract(image, binarize=True)
def get_all_variants(self, image: np.ndarray) -> List[np.ndarray]:
"""
Generate 2 preprocessing variants for OCR (fast mode).
Returns: [light_processed, heavy_processed]
"""
return [
self.preprocess_light(image),
self.preprocess_heavy(image),
]
def _deskew(self, image: np.ndarray) -> np.ndarray:
"""Correct image rotation/skew using Hough lines.
Uses expanded canvas to preserve all content during rotation,
preventing left/right margin truncation.
"""
edges = cv2.Canny(image, 50, 150, apertureSize=3)
lines = cv2.HoughLinesP(
edges, 1, np.pi / 180,
threshold=100, minLineLength=100, maxLineGap=10
)
if lines is None:
return image
angles = []
for line in lines:
x1, y1, x2, y2 = line[0]
angle = np.arctan2(y2 - y1, x2 - x1) * 180 / np.pi
if abs(angle) < 45:
angles.append(angle)
if not angles:
return image
median_angle = np.median(angles)
if abs(median_angle) < 0.5:
return image
h, w = image.shape[:2]
center = (w // 2, h // 2)
M = cv2.getRotationMatrix2D(center, median_angle, 1.0)
# Calculate new canvas size to fit entire rotated image (prevents edge truncation)
cos_angle = abs(np.cos(np.radians(median_angle)))
sin_angle = abs(np.sin(np.radians(median_angle)))
new_w = int(h * sin_angle + w * cos_angle)
new_h = int(h * cos_angle + w * sin_angle)
# Adjust rotation matrix for new canvas center
M[0, 2] += (new_w - w) / 2
M[1, 2] += (new_h - h) / 2
return cv2.warpAffine(
image, M, (new_w, new_h),
flags=cv2.INTER_CUBIC,
borderMode=cv2.BORDER_CONSTANT,
borderValue=255 # White background (grayscale)
)

View File

@@ -0,0 +1,216 @@
"""Service for fetching nomenclatures from Oracle (read-only)."""
from typing import List, Optional
from decimal import Decimal
from sqlmodel import select
from sqlalchemy.ext.asyncio import AsyncSession
from backend.modules.data_entry.schemas.receipt import (
PartnerOption,
AccountOption,
CashRegisterOption,
ExpenseTypeOption,
)
from backend.modules.data_entry.services.expense_types import EXPENSE_TYPES
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier, SyncedCashRegister
class NomenclatureService:
"""
Service for fetching nomenclatures.
In Phase 1 (MVP), some nomenclatures are hardcoded.
In Phase 2, these will be fetched from Oracle.
"""
@staticmethod
async def get_partners(
company_id: int,
search: Optional[str] = None,
session: Optional[AsyncSession] = None
) -> List[PartnerOption]:
"""
Get partners (suppliers/customers) for a company.
Returns synced suppliers from Oracle + local suppliers created from OCR.
If no suppliers exist, returns empty list (frontend will trigger sync).
"""
partners = []
if not session:
return partners
# Get synced suppliers from Oracle
stmt = select(SyncedSupplier).where(SyncedSupplier.company_id == company_id)
if search:
stmt = stmt.where(
(SyncedSupplier.name.ilike(f"%{search}%")) |
(SyncedSupplier.fiscal_code.ilike(f"%{search}%"))
)
stmt = stmt.order_by(SyncedSupplier.name)
result = await session.execute(stmt)
suppliers = result.scalars().all()
for s in suppliers:
partners.append(PartnerOption(
name=s.name,
fiscal_code=s.fiscal_code,
address=s.address,
source="oracle"
))
# Always get local suppliers (not just when synced exist)
local_stmt = select(LocalSupplier).where(LocalSupplier.company_id == company_id)
if search:
local_stmt = local_stmt.where(
(LocalSupplier.name.ilike(f"%{search}%")) |
(LocalSupplier.fiscal_code.ilike(f"%{search}%"))
)
local_stmt = local_stmt.order_by(LocalSupplier.name)
local_result = await session.execute(local_stmt)
local_suppliers = local_result.scalars().all()
for l in local_suppliers:
partners.append(PartnerOption(
name=l.name,
fiscal_code=l.fiscal_code,
address=l.address,
source="local"
))
return partners
@staticmethod
async def get_accounts(company_id: int, prefix: Optional[str] = None) -> List[AccountOption]:
"""
Get chart of accounts for a company.
Phase 1: Returns common expense/income accounts.
Phase 2: Will fetch from Oracle PLAN_CONTURI.
"""
# Common accounts for expenses and receipts
accounts = [
# Expense accounts (Class 6)
AccountOption(code="6022", name="Cheltuieli cu combustibilii"),
AccountOption(code="6024", name="Cheltuieli materiale pentru ambalat"),
AccountOption(code="6028", name="Alte cheltuieli cu materiale consumabile"),
AccountOption(code="624", name="Cheltuieli cu transportul de bunuri si personal"),
AccountOption(code="626", name="Cheltuieli postale si taxe telecomunicatii"),
AccountOption(code="628", name="Alte cheltuieli cu serviciile executate de terti"),
# VAT
AccountOption(code="4426", name="TVA deductibila"),
AccountOption(code="4427", name="TVA colectata"),
# Cash and Bank (Class 5)
AccountOption(code="5311", name="Casa in lei"),
AccountOption(code="5121", name="Conturi la banci in lei"),
# Income accounts (Class 7)
AccountOption(code="7588", name="Alte venituri din exploatare"),
]
if prefix:
accounts = [a for a in accounts if a.code.startswith(prefix)]
return accounts
@staticmethod
async def get_cash_registers(
company_id: int,
session: Optional[AsyncSession] = None
) -> List[CashRegisterOption]:
"""
Get cash registers and bank accounts for a company.
Phase 1: Returns default options.
Phase 2: Returns synced data from SQLite (from Oracle sync).
Phase 3: Will fetch live from Oracle NOM_CASE / NOM_BANCI.
"""
# If session is provided, try to get from synced SQLite data
if session:
stmt = select(SyncedCashRegister).where(SyncedCashRegister.company_id == company_id)
result = await session.execute(stmt)
registers = result.scalars().all()
if registers:
return [
CashRegisterOption(id=r.id, name=r.name, account_code=r.account_code)
for r in registers
]
# Fallback to default cash registers for Phase 1
return [
CashRegisterOption(id=1, name="Casa principala", account_code="5311"),
CashRegisterOption(id=2, name="Cont BCR", account_code="5121"),
CashRegisterOption(id=3, name="Cont BRD", account_code="5121"),
]
@staticmethod
async def get_expense_types() -> List[ExpenseTypeOption]:
"""
Get predefined expense types with their accounting configuration.
"""
return [
ExpenseTypeOption(
code=et.code,
name=et.name,
account_code=et.account_code,
has_vat=et.has_vat,
vat_percent=et.vat_percent,
)
for et in EXPENSE_TYPES.values()
]
@staticmethod
async def get_companies(username: str) -> List[dict]:
"""
Get companies accessible by user.
Phase 1: Returns mock data.
Phase 2: Will fetch from shared auth based on user permissions.
"""
# TODO: Integrate with shared auth to get user's companies
return [
{"id": 1, "name": "SC Test SRL", "cui": "RO12345678"},
{"id": 2, "name": "SC Demo SA", "cui": "RO87654321"},
]
# ============ Phase 2 Oracle Integration Methods ============
@staticmethod
async def _fetch_partners_oracle(company_id: int, search: Optional[str] = None) -> List[PartnerOption]:
"""
Fetch partners from Oracle NOM_PARTENERI.
Will be implemented in Phase 2.
"""
# TODO: Implement using shared oracle_pool
# Example query:
# SELECT ID_PART, DEN_PART, COD_FISCAL
# FROM {schema}.NOM_PARTENERI
# WHERE DEN_PART LIKE :search
raise NotImplementedError("Oracle integration pending - Phase 2")
@staticmethod
async def _fetch_accounts_oracle(company_id: int, prefix: Optional[str] = None) -> List[AccountOption]:
"""
Fetch chart of accounts from Oracle PLAN_CONTURI.
Will be implemented in Phase 2.
"""
# TODO: Implement using shared oracle_pool
raise NotImplementedError("Oracle integration pending - Phase 2")
@staticmethod
async def _fetch_cash_registers_oracle(company_id: int) -> List[CashRegisterOption]:
"""
Fetch cash registers from Oracle NOM_CASE / NOM_BANCI.
Will be implemented in Phase 2.
"""
# TODO: Implement using shared oracle_pool
raise NotImplementedError("Oracle integration pending - Phase 2")

View File

@@ -0,0 +1,42 @@
"""
OCR Services Module
Provides persistent OCR worker pool with job queue for efficient processing.
Components:
- ocr_worker_pool: Manages ProcessPoolExecutor with persistent PaddleOCR
- job_queue: SQLite-based job queue for async processing
- job_worker: Background task that processes queued jobs
- tesseract_engine: Optimized Tesseract with multi-PSM and polarity fix
Architecture:
FastAPI → job_queue.create_job() → SQLite
job_worker loop → ocr_worker_pool.submit_task() → Worker Process
PaddleOCR/Tesseract
"""
from .ocr_worker_pool import ocr_worker_pool, OCRWorkerPool
from .job_queue import job_queue, OCRJobQueue, OCRJob, OCRJobStatus
from .job_worker import start_job_worker, stop_job_worker
from .tesseract_engine import TesseractEngine
from .validation import OCRValidationEngine
__all__ = [
# Worker pool
"ocr_worker_pool",
"OCRWorkerPool",
# Job queue
"job_queue",
"OCRJobQueue",
"OCRJob",
"OCRJobStatus",
# Job worker
"start_job_worker",
"stop_job_worker",
# Engines
"TesseractEngine",
# Validation
"OCRValidationEngine",
]

View File

@@ -0,0 +1,653 @@
"""
SQLite Job Queue Manager for OCR Processing
Provides async job queue for OCR requests:
- Jobs are stored in SQLite for persistence
- Queue position and time estimation
- Automatic expiration after 24 hours
- Statistics for monitoring
Schema:
ocr_jobs (
id TEXT PRIMARY KEY, -- UUID
status TEXT NOT NULL, -- pending, processing, completed, failed
file_path TEXT NOT NULL, -- Path to uploaded file
mime_type TEXT NOT NULL,
engine TEXT DEFAULT 'doctr_plus',
created_at TIMESTAMP,
started_at TIMESTAMP,
completed_at TIMESTAMP,
result_json TEXT, -- JSON extraction result
error_message TEXT,
processing_time_ms INTEGER, -- Total job time (started_at to completed_at)
ocr_time_ms INTEGER, -- Actual OCR engine processing time
created_by TEXT, -- Username
original_filename TEXT,
expires_at TIMESTAMP,
batch_id INTEGER, -- Foreign key to batch_uploads (for bulk processing)
file_hash TEXT -- SHA-256 hash for duplicate detection (US-007)
)
"""
import asyncio
import json
from decimal import Decimal
class DecimalEncoder(json.JSONEncoder):
"""JSON encoder that handles Decimal types."""
def default(self, obj):
if isinstance(obj, Decimal):
return float(obj)
return super().default(obj)
import logging
import os
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional
import aiosqlite
logger = logging.getLogger(__name__)
# Default paths
DEFAULT_QUEUE_DIR = Path(__file__).parent.parent.parent.parent.parent / "data" / "ocr_queue"
DEFAULT_DB_PATH = DEFAULT_QUEUE_DIR / "ocr_jobs.db"
DEFAULT_FILES_DIR = DEFAULT_QUEUE_DIR / "files"
# Job expiration
JOB_EXPIRY_HOURS = 24
# SQLite busy timeout (milliseconds) - prevents "database is locked" errors
SQLITE_BUSY_TIMEOUT_MS = 5000
class OCRJobStatus(str, Enum):
"""Job status enum."""
pending = "pending"
processing = "processing"
completed = "completed"
failed = "failed"
cancelled = "cancelled"
@dataclass
class OCRJob:
"""OCR Job data class."""
id: str
status: OCRJobStatus
file_path: str
mime_type: str
engine: str = "doctr_plus"
created_at: Optional[datetime] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
result_json: Optional[str] = None
error_message: Optional[str] = None
processing_time_ms: Optional[int] = None # Total job time (started_at to completed_at)
ocr_time_ms: Optional[int] = None # Actual OCR engine processing time
created_by: Optional[str] = None
original_filename: Optional[str] = None
expires_at: Optional[datetime] = None
batch_id: Optional[int] = None # Links to batch_uploads table for bulk processing
file_hash: Optional[str] = None # SHA-256 hash for duplicate detection (US-007)
@property
def queue_wait_ms(self) -> Optional[int]:
"""Calculate queue wait time (created_at to started_at)."""
if self.created_at and self.started_at:
delta = self.started_at - self.created_at
return int(delta.total_seconds() * 1000)
return None
@property
def result(self) -> Optional[Dict]:
"""Parse result_json to dict."""
if self.result_json:
try:
return json.loads(self.result_json)
except json.JSONDecodeError:
return None
return None
class OCRJobQueue:
"""
SQLite-based job queue for OCR processing.
Provides async methods for job management with position
tracking and time estimation.
"""
def __init__(
self,
db_path: Optional[Path] = None,
files_dir: Optional[Path] = None
):
"""
Initialize job queue.
Args:
db_path: Path to SQLite database (default: data/ocr_queue/ocr_jobs.db)
files_dir: Path to files directory (default: data/ocr_queue/files/)
"""
self.db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
self.files_dir = Path(files_dir) if files_dir else DEFAULT_FILES_DIR
self._lock = asyncio.Lock()
self._initialized = False
async def initialize(self) -> None:
"""
Initialize database and directories.
Creates SQLite database and tables if they don't exist.
Creates files directory for uploaded files.
"""
if self._initialized:
return
# Create directories
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self.files_dir.mkdir(parents=True, exist_ok=True)
# Create database and tables
async with aiosqlite.connect(str(self.db_path)) as db:
# Enable WAL mode for better concurrency and set busy timeout
await db.execute("PRAGMA journal_mode=WAL")
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
await db.execute('''
CREATE TABLE IF NOT EXISTS ocr_jobs (
id TEXT PRIMARY KEY,
status TEXT NOT NULL DEFAULT 'pending',
file_path TEXT NOT NULL,
mime_type TEXT NOT NULL,
engine TEXT DEFAULT 'doctr_plus',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
started_at TIMESTAMP,
completed_at TIMESTAMP,
result_json TEXT,
error_message TEXT,
processing_time_ms INTEGER,
ocr_time_ms INTEGER,
created_by TEXT,
original_filename TEXT,
expires_at TIMESTAMP,
batch_id INTEGER
)
''')
# Migration: add ocr_time_ms column if it doesn't exist
try:
await db.execute('ALTER TABLE ocr_jobs ADD COLUMN ocr_time_ms INTEGER')
logger.info("[OCRJobQueue] Added ocr_time_ms column to existing table")
except Exception:
pass # Column already exists
# Migration: add batch_id column if it doesn't exist
try:
await db.execute('ALTER TABLE ocr_jobs ADD COLUMN batch_id INTEGER')
logger.info("[OCRJobQueue] Added batch_id column to existing table")
except Exception:
pass # Column already exists
# Migration: add file_hash column if it doesn't exist (US-007)
try:
await db.execute('ALTER TABLE ocr_jobs ADD COLUMN file_hash TEXT')
logger.info("[OCRJobQueue] Added file_hash column to existing table")
except Exception:
pass # Column already exists
# Index for efficient queue queries
await db.execute('''
CREATE INDEX IF NOT EXISTS idx_ocr_jobs_status
ON ocr_jobs(status, created_at)
''')
# Index for expiration cleanup
await db.execute('''
CREATE INDEX IF NOT EXISTS idx_ocr_jobs_expires
ON ocr_jobs(expires_at)
''')
await db.commit()
self._initialized = True
logger.info(f"[OCRJobQueue] Initialized: db={self.db_path}, files={self.files_dir}")
async def create_job(
self,
file_bytes: bytes,
mime_type: str,
engine: str = "doctr_plus",
username: Optional[str] = None,
original_filename: Optional[str] = None,
batch_id: Optional[int] = None,
file_hash: Optional[str] = None
) -> OCRJob:
"""
Create a new OCR job.
Saves file to disk and creates database record.
Args:
file_bytes: Raw file bytes
mime_type: MIME type of file
engine: OCR engine ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
username: Username of requester
original_filename: Original filename from upload
batch_id: Optional batch ID for bulk upload processing
file_hash: Optional SHA-256 hash for duplicate detection (US-007)
Returns:
Created OCRJob instance
"""
await self.initialize()
# Generate job ID
job_id = str(uuid.uuid4())
# Determine file extension
ext_map = {
'image/jpeg': '.jpg',
'image/png': '.png',
'application/pdf': '.pdf',
}
ext = ext_map.get(mime_type, '.bin')
# Save file
file_path = self.files_dir / f"{job_id}{ext}"
with open(file_path, 'wb') as f:
f.write(file_bytes)
# Calculate expiration
now = datetime.utcnow()
expires_at = now + timedelta(hours=JOB_EXPIRY_HOURS)
# Insert job record
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
await db.execute('''
INSERT INTO ocr_jobs (
id, status, file_path, mime_type, engine,
created_at, created_by, original_filename, expires_at, batch_id, file_hash
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
''', (
job_id, OCRJobStatus.pending.value, str(file_path), mime_type, engine,
now.isoformat(), username, original_filename, expires_at.isoformat(), batch_id, file_hash
))
await db.commit()
logger.info(f"[OCRJobQueue] Created job {job_id}: engine={engine}, file={file_path.name}, batch_id={batch_id}")
return OCRJob(
id=job_id,
status=OCRJobStatus.pending,
file_path=str(file_path),
mime_type=mime_type,
engine=engine,
created_at=now,
created_by=username,
original_filename=original_filename,
expires_at=expires_at,
batch_id=batch_id,
file_hash=file_hash
)
async def get_job(self, job_id: str) -> Optional[OCRJob]:
"""
Get job by ID.
Args:
job_id: Job UUID
Returns:
OCRJob or None if not found
"""
await self.initialize()
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
db.row_factory = aiosqlite.Row
async with db.execute(
'SELECT * FROM ocr_jobs WHERE id = ?',
(job_id,)
) as cursor:
row = await cursor.fetchone()
if row:
return self._row_to_job(row)
return None
async def get_queue_position(self, job_id: str) -> Optional[int]:
"""
Get position in queue for a pending job.
Args:
job_id: Job UUID
Returns:
Queue position (1 = next to process) or None if not pending
"""
await self.initialize()
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
# Check if job is pending
async with db.execute(
'SELECT status, created_at FROM ocr_jobs WHERE id = ?',
(job_id,)
) as cursor:
row = await cursor.fetchone()
if not row or row[0] != OCRJobStatus.pending.value:
return None
job_created_at = row[1]
# Count jobs ahead in queue (created before this job)
async with db.execute('''
SELECT COUNT(*) FROM ocr_jobs
WHERE status = 'pending' AND created_at < ?
''', (job_created_at,)) as cursor:
count = await cursor.fetchone()
return (count[0] + 1) if count else 1
async def get_next_pending(self) -> Optional[OCRJob]:
"""
Get the next pending job (oldest first) and atomically mark it as processing.
This prevents race conditions in parallel processing - only one worker
can claim each job.
Returns:
Next OCRJob to process or None if queue empty
"""
await self.initialize()
now = datetime.utcnow()
async with self._lock: # Serialize access to prevent race conditions
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
db.row_factory = aiosqlite.Row
# Get the next pending job
async with db.execute('''
SELECT * FROM ocr_jobs
WHERE status = 'pending'
ORDER BY created_at ASC
LIMIT 1
''') as cursor:
row = await cursor.fetchone()
if not row:
return None
job_id = row['id']
# Atomically mark as processing
await db.execute('''
UPDATE ocr_jobs
SET status = 'processing', started_at = ?
WHERE id = ? AND status = 'pending'
''', (now.isoformat(), job_id))
await db.commit()
# Fetch the updated job
async with db.execute(
'SELECT * FROM ocr_jobs WHERE id = ?',
(job_id,)
) as cursor:
updated_row = await cursor.fetchone()
if updated_row:
return self._row_to_job(updated_row)
return None
async def update_status(
self,
job_id: str,
status: OCRJobStatus,
result: Optional[Dict] = None,
error: Optional[str] = None,
processing_time_ms: Optional[int] = None,
ocr_time_ms: Optional[int] = None
) -> bool:
"""
Update job status.
Args:
job_id: Job UUID
status: New status
result: Extraction result dict (for completed)
error: Error message (for failed)
processing_time_ms: Total job processing time (started_at to completed_at)
ocr_time_ms: Actual OCR engine processing time
Returns:
True if update successful
"""
await self.initialize()
now = datetime.utcnow()
result_json = json.dumps(result, cls=DecimalEncoder) if result else None
# Build update query based on status
if status == OCRJobStatus.processing:
query = '''
UPDATE ocr_jobs
SET status = ?, started_at = ?
WHERE id = ?
'''
params = (status.value, now.isoformat(), job_id)
elif status == OCRJobStatus.completed:
query = '''
UPDATE ocr_jobs
SET status = ?, completed_at = ?, result_json = ?, processing_time_ms = ?, ocr_time_ms = ?
WHERE id = ?
'''
params = (status.value, now.isoformat(), result_json, processing_time_ms, ocr_time_ms, job_id)
elif status == OCRJobStatus.failed:
query = '''
UPDATE ocr_jobs
SET status = ?, completed_at = ?, error_message = ?, processing_time_ms = ?, ocr_time_ms = ?
WHERE id = ?
'''
params = (status.value, now.isoformat(), error, processing_time_ms, ocr_time_ms, job_id)
else:
query = 'UPDATE ocr_jobs SET status = ? WHERE id = ?'
params = (status.value, job_id)
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
cursor = await db.execute(query, params)
await db.commit()
return cursor.rowcount > 0
async def get_average_processing_time(self) -> float:
"""
Calculate average processing time from recent completed jobs.
Uses last 50 completed jobs for accuracy.
Returns:
Average time in seconds (default 7.0 if no data)
"""
await self.initialize()
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
async with db.execute('''
SELECT AVG(processing_time_ms)
FROM (
SELECT processing_time_ms FROM ocr_jobs
WHERE status = 'completed' AND processing_time_ms IS NOT NULL
ORDER BY completed_at DESC
LIMIT 50
)
''') as cursor:
row = await cursor.fetchone()
if row and row[0]:
return row[0] / 1000.0 # Convert ms to seconds
return 7.0 # Default estimate
async def count_pending(self) -> int:
"""Count pending jobs in queue."""
await self.initialize()
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
async with db.execute(
'SELECT COUNT(*) FROM ocr_jobs WHERE status = ?',
(OCRJobStatus.pending.value,)
) as cursor:
row = await cursor.fetchone()
return row[0] if row else 0
async def count_processing(self) -> int:
"""Count currently processing jobs."""
await self.initialize()
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
async with db.execute(
'SELECT COUNT(*) FROM ocr_jobs WHERE status = ?',
(OCRJobStatus.processing.value,)
) as cursor:
row = await cursor.fetchone()
return row[0] if row else 0
async def cleanup_expired(self) -> int:
"""
Delete expired jobs and their files.
Returns:
Number of jobs deleted
"""
await self.initialize()
now = datetime.utcnow()
deleted = 0
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
db.row_factory = aiosqlite.Row
# Get expired jobs
async with db.execute('''
SELECT id, file_path FROM ocr_jobs
WHERE expires_at < ?
''', (now.isoformat(),)) as cursor:
rows = await cursor.fetchall()
for row in rows:
# Delete file
file_path = Path(row['file_path'])
if file_path.exists():
try:
file_path.unlink()
except Exception as e:
logger.warning(f"[OCRJobQueue] Failed to delete file {file_path}: {e}")
# Delete job record
await db.execute('DELETE FROM ocr_jobs WHERE id = ?', (row['id'],))
deleted += 1
await db.commit()
if deleted > 0:
logger.info(f"[OCRJobQueue] Cleaned up {deleted} expired job(s)")
return deleted
async def cleanup_job_file(self, job_id: str) -> bool:
"""
Delete the file associated with a job.
Called after processing to free disk space.
Args:
job_id: Job UUID
Returns:
True if file deleted
"""
job = await self.get_job(job_id)
if job:
file_path = Path(job.file_path)
if file_path.exists():
try:
file_path.unlink()
return True
except Exception as e:
logger.warning(f"[OCRJobQueue] Failed to delete file {file_path}: {e}")
return False
async def get_queue_stats(self) -> Dict[str, Any]:
"""
Get queue statistics.
Returns:
Dict with pending, processing, completed, failed counts
"""
await self.initialize()
stats = {
"pending": 0,
"processing": 0,
"completed": 0,
"failed": 0,
"average_time_seconds": 0.0,
}
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
async with db.execute('''
SELECT status, COUNT(*) as count
FROM ocr_jobs
GROUP BY status
''') as cursor:
rows = await cursor.fetchall()
for row in rows:
if row[0] in stats:
stats[row[0]] = row[1]
stats["average_time_seconds"] = await self.get_average_processing_time()
return stats
def _row_to_job(self, row: aiosqlite.Row) -> OCRJob:
"""Convert database row to OCRJob."""
def parse_datetime(val):
if val:
try:
return datetime.fromisoformat(val)
except (ValueError, TypeError):
return None
return None
return OCRJob(
id=row['id'],
status=OCRJobStatus(row['status']),
file_path=row['file_path'],
mime_type=row['mime_type'],
engine=row['engine'] or 'doctr_plus',
created_at=parse_datetime(row['created_at']),
started_at=parse_datetime(row['started_at']),
completed_at=parse_datetime(row['completed_at']),
result_json=row['result_json'],
error_message=row['error_message'],
processing_time_ms=row['processing_time_ms'],
ocr_time_ms=row['ocr_time_ms'] if 'ocr_time_ms' in row.keys() else None,
created_by=row['created_by'],
original_filename=row['original_filename'],
expires_at=parse_datetime(row['expires_at']),
batch_id=row['batch_id'] if 'batch_id' in row.keys() else None,
file_hash=row['file_hash'] if 'file_hash' in row.keys() else None,
)
# Singleton instance
job_queue = OCRJobQueue()

View File

@@ -0,0 +1,665 @@
"""
OCR Job Worker - Background Task for Queue Processing
Runs as an asyncio background task in FastAPI.
Continuously polls the job queue and processes OCR requests IN PARALLEL.
Architecture:
FastAPI startup
start_job_worker()
asyncio.create_task(_job_worker_loop())
while True:
# Process up to OCR_WORKERS jobs concurrently
jobs = get_pending_jobs(limit=available_slots)
for job in jobs:
asyncio.create_task(_process_job(job))
await asyncio.sleep(0.1)
"""
import asyncio
import logging
import os
import time
from pathlib import Path
from typing import Optional, Set
from .job_queue import job_queue, OCRJobStatus, OCRJob
from .ocr_worker_pool import ocr_worker_pool
from backend.modules.data_entry.schemas.ocr import ExtractionData
logger = logging.getLogger(__name__)
# Global task reference
_job_worker_task: Optional[asyncio.Task] = None
_cleanup_task: Optional[asyncio.Task] = None
_shutdown_event: Optional[asyncio.Event] = None
_active_tasks: Set[asyncio.Task] = set() # Track active job tasks
_concurrency_semaphore: Optional[asyncio.Semaphore] = None # Limit concurrent jobs
# Configuration
POLL_INTERVAL_SECONDS = 0.1 # How often to check for new jobs (faster for parallel)
CLEANUP_INTERVAL_SECONDS = 3600 # Clean expired jobs every hour
OCR_TIMEOUT_SECONDS = 120 # Max time for OCR processing
async def _job_worker_loop() -> None:
"""
Main worker loop - processes jobs from queue IN PARALLEL.
Runs continuously until shutdown. Uses semaphore to limit
concurrent jobs to OCR_WORKERS count. Launches jobs as
background tasks without waiting for completion.
"""
global _shutdown_event, _active_tasks, _concurrency_semaphore
# Get max concurrent jobs from env (matches worker pool size)
max_concurrent = int(os.getenv('OCR_WORKERS', '2'))
_concurrency_semaphore = asyncio.Semaphore(max_concurrent)
_active_tasks = set()
logger.info(f"[JobWorker] Starting PARALLEL worker loop (max_concurrent={max_concurrent})...")
_shutdown_event = asyncio.Event()
consecutive_errors = 0
max_consecutive_errors = 10
while not _shutdown_event.is_set():
try:
# Clean up completed tasks
done_tasks = {t for t in _active_tasks if t.done()}
for task in done_tasks:
_active_tasks.discard(task)
# Check for exceptions
try:
task.result()
except Exception as e:
logger.error(f"[JobWorker] Task failed: {e}")
# Check if we have capacity for more jobs
active_count = len(_active_tasks)
available_slots = max_concurrent - active_count
if available_slots > 0:
# Get next pending job
job = await job_queue.get_next_pending()
if job:
consecutive_errors = 0
# Launch job processing as background task
task = asyncio.create_task(_process_job_with_semaphore(job))
_active_tasks.add(task)
logger.debug(f"[JobWorker] Launched job {job.id} (active={len(_active_tasks)}/{max_concurrent})")
else:
# No pending jobs - wait briefly
try:
await asyncio.wait_for(
_shutdown_event.wait(),
timeout=POLL_INTERVAL_SECONDS
)
if _shutdown_event.is_set():
break
except asyncio.TimeoutError:
pass
else:
# At capacity - wait for a slot to free up
await asyncio.sleep(POLL_INTERVAL_SECONDS)
except asyncio.CancelledError:
logger.info("[JobWorker] Worker loop cancelled")
break
except Exception as e:
consecutive_errors += 1
logger.error(f"[JobWorker] Error in worker loop ({consecutive_errors}/{max_consecutive_errors}): {e}")
if consecutive_errors >= max_consecutive_errors:
logger.error("[JobWorker] Too many consecutive errors, stopping worker")
break
await asyncio.sleep(min(consecutive_errors * 2, 30))
# Wait for active tasks to complete on shutdown
if _active_tasks:
logger.info(f"[JobWorker] Waiting for {len(_active_tasks)} active tasks to complete...")
await asyncio.gather(*_active_tasks, return_exceptions=True)
logger.info("[JobWorker] Worker loop stopped")
async def _process_job_with_semaphore(job: OCRJob) -> None:
"""
Process job with semaphore to limit concurrency.
Acquires semaphore before processing, releases after.
This ensures we don't exceed OCR_WORKERS concurrent jobs.
"""
global _concurrency_semaphore
async with _concurrency_semaphore:
await _process_job(job)
async def _process_job(job: OCRJob) -> None:
"""
Process a single OCR job.
Reads file, submits to worker pool, updates job status,
and saves metrics for analytics.
Args:
job: OCRJob to process
"""
logger.info(f"[JobWorker] Processing job {job.id}: engine={job.engine}, file={Path(job.file_path).name}")
start_time = time.time()
file_size = 0
file_type = "image/jpeg"
try:
# Note: Job already marked as 'processing' atomically in get_next_pending()
# Read file bytes
file_path = Path(job.file_path)
if not file_path.exists():
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, 'rb') as f:
file_bytes = f.read()
file_size = len(file_bytes)
# Determine file type from job or extension
file_type = getattr(job, 'mime_type', 'image/jpeg') or 'image/jpeg'
# Submit to worker pool
result = await ocr_worker_pool.submit_task(
image_bytes=file_bytes,
engine=job.engine,
preprocessing="auto",
timeout=OCR_TIMEOUT_SECONDS
)
elapsed_ms = int((time.time() - start_time) * 1000)
if result.get("success"):
# Job completed successfully
extraction = result.get("extraction", {})
# Include raw_texts for analysis (from all OCR engine passes)
extraction['raw_texts'] = result.get("raw_texts", [])
# Extract actual OCR processing time from extraction result
ocr_time_ms = extraction.get('processing_time_ms', 0)
# Debug: log suggested_payment_mode
spm = extraction.get('suggested_payment_mode')
logger.info(f"[JobWorker] Job {job.id} extraction has suggested_payment_mode={spm}")
await job_queue.update_status(
job_id=job.id,
status=OCRJobStatus.completed,
result=extraction,
processing_time_ms=elapsed_ms,
ocr_time_ms=ocr_time_ms
)
logger.info(f"[JobWorker] Job {job.id} completed in {elapsed_ms}ms (ocr: {ocr_time_ms}ms)")
# Save metrics for successful job
await _save_job_metrics(
job_id=job.id,
username=job.created_by or 'unknown',
engine_requested=job.engine,
engine_used=extraction.get('ocr_engine', job.engine),
processing_time_ms=elapsed_ms,
file_size_bytes=file_size,
file_type=file_type,
original_filename=job.original_filename,
success=True,
overall_confidence=extraction.get('overall_confidence', 0.0),
fields_extracted=_count_extracted_fields(extraction),
needs_manual_review=extraction.get('needs_manual_review'),
validation_warnings_count=len(extraction.get('validation_warnings', [])),
validation_errors_count=len(extraction.get('validation_errors', [])),
)
# Auto-save receipt for batch jobs
if job.batch_id:
auto_save_result = await _auto_save_batch_receipt(
job=job,
extraction=extraction,
file_path=str(file_path)
)
if not auto_save_result:
# Auto-save failed - mark job as failed
# Note: job_queue status already updated to 'completed' above
# We need to update it back to failed with the auto-save error
logger.warning(
f"[JobWorker] Job {job.id} OCR succeeded but auto-save failed"
)
else:
# Job failed
error_msg = result.get("error", "Unknown error")
await job_queue.update_status(
job_id=job.id,
status=OCRJobStatus.failed,
error=error_msg,
processing_time_ms=elapsed_ms
)
logger.warning(f"[JobWorker] Job {job.id} failed after {elapsed_ms}ms: {error_msg}")
# Save metrics for failed job
await _save_job_metrics(
job_id=job.id,
username=job.created_by or 'unknown',
engine_requested=job.engine,
engine_used=job.engine,
processing_time_ms=elapsed_ms,
file_size_bytes=file_size,
file_type=file_type,
original_filename=job.original_filename,
success=False,
error_message=error_msg,
)
except Exception as e:
elapsed_ms = int((time.time() - start_time) * 1000)
logger.error(f"[JobWorker] Job {job.id} error after {elapsed_ms}ms: {e}")
await job_queue.update_status(
job_id=job.id,
status=OCRJobStatus.failed,
error=str(e),
processing_time_ms=elapsed_ms
)
# Save metrics for error job
await _save_job_metrics(
job_id=job.id,
username=job.created_by or 'unknown',
engine_requested=job.engine,
engine_used=job.engine,
processing_time_ms=elapsed_ms,
file_size_bytes=file_size,
file_type=file_type,
original_filename=job.original_filename,
success=False,
error_message=str(e),
)
finally:
# Cleanup file after processing
try:
await job_queue.cleanup_job_file(job.id)
except Exception as e:
logger.warning(f"[JobWorker] Failed to cleanup file for job {job.id}: {e}")
async def _cleanup_loop() -> None:
"""
Periodic cleanup of expired jobs.
Runs every hour to delete jobs older than 24 hours.
"""
global _shutdown_event
logger.info("[JobWorker] Starting cleanup loop...")
while not _shutdown_event.is_set():
try:
# Wait for interval or shutdown
try:
await asyncio.wait_for(
_shutdown_event.wait(),
timeout=CLEANUP_INTERVAL_SECONDS
)
if _shutdown_event.is_set():
break
except asyncio.TimeoutError:
pass # Normal timeout, do cleanup
# Run cleanup
deleted = await job_queue.cleanup_expired()
if deleted > 0:
logger.info(f"[JobWorker] Cleanup: deleted {deleted} expired jobs")
except asyncio.CancelledError:
logger.info("[JobWorker] Cleanup loop cancelled")
break
except Exception as e:
logger.error(f"[JobWorker] Cleanup error: {e}")
await asyncio.sleep(60) # Retry after 1 minute
logger.info("[JobWorker] Cleanup loop stopped")
async def start_job_worker() -> bool:
"""
Start the job worker background task.
Called at FastAPI startup to begin processing queue.
Returns:
True if started successfully
"""
global _job_worker_task, _cleanup_task, _shutdown_event
if _job_worker_task is not None and not _job_worker_task.done():
logger.warning("[JobWorker] Already running")
return True
try:
# Initialize job queue
await job_queue.initialize()
# Initialize worker pool
if not ocr_worker_pool.initialize():
logger.error("[JobWorker] Failed to initialize worker pool")
return False
# Pre-warm worker pool in BACKGROUND (don't block startup)
# First OCR request may be slower if prewarm isn't done yet
async def _background_prewarm():
logger.info("[JobWorker] Pre-warming OCR worker pool (background)...")
warmup_success = await ocr_worker_pool.prewarm(timeout=90.0)
if warmup_success:
logger.info("[JobWorker] OCR worker pool pre-warmed successfully")
else:
logger.warning("[JobWorker] Worker pool pre-warm failed, first request will be slower")
asyncio.create_task(_background_prewarm())
# Start worker loop
_shutdown_event = asyncio.Event()
_job_worker_task = asyncio.create_task(_job_worker_loop())
# Start cleanup loop
_cleanup_task = asyncio.create_task(_cleanup_loop())
logger.info("[JobWorker] Started successfully")
return True
except Exception as e:
logger.error(f"[JobWorker] Failed to start: {e}")
return False
async def stop_job_worker() -> None:
"""
Stop the job worker background task.
Called at FastAPI shutdown to gracefully stop processing.
"""
global _job_worker_task, _cleanup_task, _shutdown_event
logger.info("[JobWorker] Stopping...")
# Signal shutdown
if _shutdown_event:
_shutdown_event.set()
# Cancel worker task
if _job_worker_task and not _job_worker_task.done():
_job_worker_task.cancel()
try:
await _job_worker_task
except asyncio.CancelledError:
pass
# Cancel cleanup task
if _cleanup_task and not _cleanup_task.done():
_cleanup_task.cancel()
try:
await _cleanup_task
except asyncio.CancelledError:
pass
# Shutdown worker pool
ocr_worker_pool.shutdown(wait=True)
_job_worker_task = None
_cleanup_task = None
_shutdown_event = None
logger.info("[JobWorker] Stopped")
def is_running() -> bool:
"""Check if job worker is running."""
return _job_worker_task is not None and not _job_worker_task.done()
def estimate_wait_time(queue_position: int) -> int:
"""
Estimate wait time for a job in queue.
Args:
queue_position: Position in queue (1 = next)
Returns:
Estimated wait time in seconds
"""
if queue_position <= 0:
return 0
# Get average processing time (synchronous fallback)
# Default ~7 seconds per job if no data
avg_time = 7.0
try:
# Try to get from queue stats
import asyncio
loop = asyncio.get_event_loop()
if loop.is_running():
# Can't use sync call in async context, use default
pass
else:
avg_time = loop.run_until_complete(job_queue.get_average_processing_time())
except Exception:
pass
# Estimate: position * average_time
return int(queue_position * avg_time)
# ============================================================================
# Metrics Helper Functions
# ============================================================================
async def _save_job_metrics(
job_id: str,
username: str,
engine_requested: str,
engine_used: str,
processing_time_ms: int = 0,
file_size_bytes: int = 0,
file_type: str = "image/jpeg",
original_filename: Optional[str] = None,
success: bool = True,
error_message: Optional[str] = None,
overall_confidence: float = 0.0,
fields_extracted: int = 0,
needs_manual_review: Optional[bool] = None,
validation_warnings_count: int = 0,
validation_errors_count: int = 0,
) -> None:
"""
Save OCR job metrics to database for analytics.
Called after each job completes (success or failure).
Errors are logged but don't affect job processing.
"""
try:
from backend.modules.data_entry.db.database import get_db_session
from backend.modules.data_entry.db.crud.ocr_settings import OCRMetricsCRUD
async with await get_db_session() as session:
await OCRMetricsCRUD.create(
session=session,
job_id=job_id,
username=username,
engine_requested=engine_requested,
engine_used=engine_used,
processing_time_ms=processing_time_ms,
file_size_bytes=file_size_bytes,
file_type=file_type,
original_filename=original_filename,
success=success,
error_message=error_message,
overall_confidence=overall_confidence,
fields_extracted=fields_extracted,
needs_manual_review=needs_manual_review,
validation_warnings_count=validation_warnings_count,
validation_errors_count=validation_errors_count,
)
logger.debug(f"[JobWorker] Saved metrics for job {job_id}")
except Exception as e:
# Log but don't fail - metrics are nice-to-have
logger.warning(f"[JobWorker] Failed to save metrics for job {job_id}: {e}")
def _count_extracted_fields(extraction: dict) -> int:
"""
Count number of successfully extracted fields from OCR result.
Counts non-None values in key fields.
"""
key_fields = [
'receipt_number',
'receipt_date',
'amount',
'partner_name',
'cui',
'tva_total',
'address',
'items_count',
]
count = 0
for field in key_fields:
value = extraction.get(field)
if value is not None and value != '' and value != []:
count += 1
# Also count TVA entries if present
tva_entries = extraction.get('tva_entries', [])
if tva_entries and len(tva_entries) > 0:
count += 1
# Count payment methods if present
payment_methods = extraction.get('payment_methods', [])
if payment_methods and len(payment_methods) > 0:
count += 1
return count
# ============================================================================
# Auto-Save Batch Receipt Helper
# ============================================================================
async def _auto_save_batch_receipt(
job: OCRJob,
extraction: dict,
file_path: str
) -> bool:
"""
Automatically create a receipt from OCR result for batch jobs.
Called when a batch job completes successfully. Creates the receipt,
attachment, and accounting entries using ReceiptAutoCreateService.
Args:
job: Completed OCRJob with batch_id set
extraction: OCR extraction result dict
file_path: Path to the original uploaded file
Returns:
True if receipt created successfully, False otherwise
"""
if not job.batch_id:
return True # Not a batch job, nothing to do
logger.info(f"[JobWorker] Auto-saving receipt for batch job {job.id} (batch_id={job.batch_id})")
try:
# Import here to avoid circular imports
from backend.modules.data_entry.db.database import get_db_session
from backend.modules.data_entry.db.models import BatchUpload
from backend.modules.data_entry.services.receipt_auto_create import ReceiptAutoCreateService
from sqlalchemy import select
# Convert extraction dict to ExtractionData schema
ocr_result = ExtractionData(**extraction)
async with await get_db_session() as session:
# Get batch info to retrieve company_id and user_id
batch_result = await session.execute(
select(BatchUpload).where(BatchUpload.id == job.batch_id)
)
batch = batch_result.scalar_one_or_none()
if not batch:
error_msg = f"Batch {job.batch_id} not found"
logger.error(f"[JobWorker] Auto-save failed for job {job.id}: {error_msg}")
await job_queue.update_status(
job_id=job.id,
status=OCRJobStatus.failed,
error=f"Auto-save error: {error_msg}"
)
return False
# Call ReceiptAutoCreateService
result = await ReceiptAutoCreateService.create_from_ocr_result(
session=session,
job_id=job.id,
ocr_result=ocr_result,
username=job.created_by or batch.user_id,
batch_id=job.batch_id,
company_id=batch.company_id,
file_path=file_path,
original_filename=job.original_filename,
file_hash=job.file_hash # Pass file_hash for duplicate detection (US-007)
)
if result.success:
logger.info(
f"[JobWorker] Auto-save successful for job {job.id}: "
f"receipt_id={result.receipt_id}"
)
return True
else:
error_msg = result.error_message or "Unknown error"
logger.warning(
f"[JobWorker] Auto-save validation failed for job {job.id}: {error_msg}"
)
# Update job status to failed with the auto-save error
await job_queue.update_status(
job_id=job.id,
status=OCRJobStatus.failed,
error=f"Auto-save error: {error_msg}"
)
return False
except Exception as e:
error_msg = str(e)
logger.error(f"[JobWorker] Auto-save exception for job {job.id}: {error_msg}")
# Update job status to failed
try:
await job_queue.update_status(
job_id=job.id,
status=OCRJobStatus.failed,
error=f"Auto-save error: {error_msg}"
)
except Exception as update_err:
logger.error(f"[JobWorker] Failed to update job status after auto-save error: {update_err}")
return False

View File

@@ -0,0 +1,561 @@
"""
OCR Worker Pool Manager
Manages a ProcessPoolExecutor with persistent OCR engine initialization.
Key features:
- ProcessPoolExecutor with configurable max_workers (from OCR_WORKERS env)
- Configurable max_tasks_per_child (from OCR_MAX_TASKS_PER_CHILD env, 0=no restart)
- mp_context='spawn' for Windows IIS compatibility
- docTR/PaddleOCR loaded ONCE at worker spawn (not 30s per request)
- atexit + signal handlers for cleanup
- Health check with auto-respawn
- Orphan process cleanup on Windows
Architecture:
Main Process │ Worker Process (PERSISTENT)
──────────────────────│──────────────────────────────────
OCRWorkerPool │ Worker initialized once
↓ │ ↓
submit_task() ────────│────→ process_ocr()
↓ │ ↓
Future.result() ←─────│──── Return result
"""
import asyncio
import atexit
import gc
import logging
import multiprocessing as mp
import os
import signal
import sys
import time
from concurrent.futures import ProcessPoolExecutor, Future, ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Callable, Optional
logger = logging.getLogger(__name__)
# Try to import psutil for orphan process cleanup
try:
import psutil
PSUTIL_AVAILABLE = True
except ImportError:
PSUTIL_AVAILABLE = False
logger.warning("[OCRWorkerPool] psutil not available - orphan cleanup disabled")
class OCRWorkerPool:
"""
Singleton manager for OCR ProcessPoolExecutor.
Ensures OCR engines are loaded once and reused for all requests.
Uses max_tasks_per_child=5 to restart worker every 5 tasks (prevents memory leak).
"""
_instance: Optional["OCRWorkerPool"] = None
_initialized: bool = False
def __new__(cls) -> "OCRWorkerPool":
"""Singleton pattern - only one pool instance."""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
"""Initialize worker pool (runs only once due to singleton)."""
if self._initialized:
return
self._executor: Optional[ProcessPoolExecutor] = None
self._worker_pid: Optional[int] = None
self._is_warming: bool = False
self._is_shutdown: bool = False
self._lock = asyncio.Lock() if asyncio.get_event_loop_policy() else None
self._sync_lock = mp.Lock()
# Register cleanup handlers
# NOTE: Only use atexit, NOT signal handlers!
# Signal handlers interfere with FastAPI's shutdown handling.
# FastAPI's shutdown event calls stop_job_worker() which calls shutdown().
atexit.register(self._cleanup_on_exit)
self._initialized = True
logger.info("[OCRWorkerPool] Singleton instance created")
def initialize(self) -> bool:
"""
Initialize the ProcessPoolExecutor.
Creates executor with spawn context for Windows compatibility.
Uses max_tasks_per_child=5 to restart worker periodically (prevents memory leak).
Returns:
True if initialization successful
"""
if self._executor is not None:
logger.warning("[OCRWorkerPool] Already initialized")
return True
if self._is_shutdown:
logger.error("[OCRWorkerPool] Cannot initialize - pool is shutdown")
return False
try:
# Cleanup any orphan workers from previous runs
self._cleanup_orphan_workers()
# Read configuration from environment
max_workers = int(os.getenv('OCR_WORKERS', '2'))
max_tasks_raw = os.getenv('OCR_MAX_TASKS_PER_CHILD', '0')
# 0 means no restart (None in ProcessPoolExecutor)
max_tasks_per_child = int(max_tasks_raw) if max_tasks_raw and int(max_tasks_raw) > 0 else None
# Create executor with spawn context (Windows compatible)
# Use mp_context='spawn' explicitly for cross-platform consistency
mp_context = mp.get_context('spawn')
# max_tasks_per_child only available in Python 3.11+
executor_kwargs = {
'max_workers': max_workers,
'mp_context': mp_context,
'initializer': _worker_initializer,
}
if sys.version_info >= (3, 11) and max_tasks_per_child is not None:
executor_kwargs['max_tasks_per_child'] = max_tasks_per_child
else:
logger.info(f"[OCRWorkerPool] max_tasks_per_child not supported (Python {sys.version_info.major}.{sys.version_info.minor})")
self._executor = ProcessPoolExecutor(**executor_kwargs)
logger.info(f"[OCRWorkerPool] ProcessPoolExecutor created (spawn context, max_workers={max_workers}, max_tasks_per_child={max_tasks_per_child})")
return True
except Exception as e:
logger.error(f"[OCRWorkerPool] Initialization failed: {e}")
return False
async def prewarm(self, timeout: float = 60.0) -> bool:
"""
Pre-warm the worker by loading PaddleOCR before first request.
This is called at FastAPI startup to avoid 30s delay on first request.
Submits a dummy task that triggers PaddleOCR initialization.
Args:
timeout: Maximum seconds to wait for warmup (default 60s)
Returns:
True if warmup successful, False if timeout or error
"""
if self._executor is None:
logger.error("[OCRWorkerPool] Cannot prewarm - not initialized")
return False
if self._is_warming:
logger.warning("[OCRWorkerPool] Already warming up")
return False
self._is_warming = True
logger.info("[OCRWorkerPool] Starting pre-warm (loading PaddleOCR in worker)...")
start_time = time.time()
try:
# Submit warmup task that initializes PaddleOCR
loop = asyncio.get_event_loop()
future = self._executor.submit(_warmup_task)
# Wait with timeout
result = await loop.run_in_executor(None, future.result, timeout)
elapsed = time.time() - start_time
if result.get("success"):
logger.info(f"[OCRWorkerPool] Pre-warm complete in {elapsed:.1f}s - PaddleOCR ready")
self._worker_pid = result.get("pid")
return True
else:
logger.error(f"[OCRWorkerPool] Pre-warm failed: {result.get('error')}")
return False
except Exception as e:
elapsed = time.time() - start_time
logger.error(f"[OCRWorkerPool] Pre-warm failed after {elapsed:.1f}s: {e}")
return False
finally:
self._is_warming = False
async def submit_task(
self,
image_bytes: bytes,
engine: str = "doctr_plus",
preprocessing: str = "auto",
timeout: float = 120.0
) -> dict:
"""
Submit OCR task to worker process.
Args:
image_bytes: Raw image bytes
engine: OCR engine ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
preprocessing: Preprocessing mode ('light', 'medium', 'heavy', 'auto')
timeout: Maximum processing time in seconds
Returns:
Dict with extraction results
Raises:
RuntimeError: If pool not initialized or task fails
"""
if self._executor is None:
raise RuntimeError("OCR worker pool not initialized")
if self._is_shutdown:
raise RuntimeError("OCR worker pool is shutdown")
logger.info(f"[OCRWorkerPool] Submitting task: engine={engine}, preprocessing={preprocessing}, size={len(image_bytes)} bytes")
try:
loop = asyncio.get_event_loop()
future = self._executor.submit(
_process_ocr_task,
image_bytes,
engine,
preprocessing
)
# Wait for result with timeout
result = await loop.run_in_executor(None, future.result, timeout)
logger.info(f"[OCRWorkerPool] Task complete: success={result.get('success')}")
return result
except TimeoutError:
logger.error(f"[OCRWorkerPool] Task timed out after {timeout}s")
raise RuntimeError(f"OCR processing timed out after {timeout}s")
except Exception as e:
logger.error(f"[OCRWorkerPool] Task failed: {e}")
raise RuntimeError(f"OCR processing failed: {e}")
def is_healthy(self) -> bool:
"""
Check if worker pool is healthy.
Returns:
True if pool is ready to accept tasks
"""
if self._executor is None:
return False
if self._is_shutdown:
return False
# Check if worker process is still alive
if self._worker_pid and PSUTIL_AVAILABLE:
try:
proc = psutil.Process(self._worker_pid)
if not proc.is_running():
logger.warning("[OCRWorkerPool] Worker process died, needs respawn")
return False
except psutil.NoSuchProcess:
logger.warning("[OCRWorkerPool] Worker process not found")
return False
return True
def shutdown(self, wait: bool = True, timeout: float = 10.0) -> None:
"""
Shutdown the worker pool gracefully.
Args:
wait: Wait for pending tasks to complete
timeout: Maximum wait time in seconds
"""
if self._executor is None:
return
logger.info("[OCRWorkerPool] Shutting down...")
self._is_shutdown = True
try:
self._executor.shutdown(wait=wait, cancel_futures=True)
logger.info("[OCRWorkerPool] Executor shutdown complete")
except Exception as e:
logger.error(f"[OCRWorkerPool] Shutdown error: {e}")
self._executor = None
self._worker_pid = None
# Final orphan cleanup
self._cleanup_orphan_workers()
logger.info("[OCRWorkerPool] Shutdown complete")
def _cleanup_orphan_workers(self) -> int:
"""
Clean up orphan Python processes from previous runs.
On Windows with NSSM, orphan processes may remain after service restart.
This finds and kills any python.exe processes that were OCR workers.
Returns:
Number of processes killed
"""
if not PSUTIL_AVAILABLE:
return 0
killed = 0
current_pid = os.getpid()
try:
for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
try:
# Skip self
if proc.pid == current_pid:
continue
# Look for Python processes with OCR-related cmdline
if proc.name().lower() in ('python.exe', 'python3.exe', 'python', 'python3'):
cmdline = ' '.join(proc.cmdline() or [])
# Check if this is an OCR worker process
if 'ocr_worker_process' in cmdline.lower() or 'process_ocr_task' in cmdline.lower():
logger.warning(f"[OCRWorkerPool] Killing orphan worker: PID={proc.pid}")
proc.kill()
proc.wait(timeout=5)
killed += 1
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
continue
except Exception as e:
logger.error(f"[OCRWorkerPool] Orphan cleanup error: {e}")
if killed > 0:
logger.info(f"[OCRWorkerPool] Cleaned up {killed} orphan worker(s)")
return killed
def _cleanup_on_exit(self) -> None:
"""atexit handler for cleanup."""
logger.info("[OCRWorkerPool] atexit cleanup triggered")
self.shutdown(wait=False)
def _signal_handler(self, signum: int, frame: Any) -> None:
"""Signal handler for SIGTERM/SIGINT."""
logger.info(f"[OCRWorkerPool] Received signal {signum}, shutting down...")
self.shutdown(wait=False)
# ============================================================================
# WORKER PROCESS FUNCTIONS
# ============================================================================
# These functions run in the child process, not the main FastAPI process.
# Global engines - persist between tasks in worker process
_paddle_engine = None
_tesseract_engine = None
_doctr_engine = None # docTR engine (PyTorch backend)
_worker_initialized = False
def _worker_initializer() -> None:
"""
Called once when worker process spawns.
Initializes global OCR engines IN PARALLEL for faster startup.
Uses ThreadPoolExecutor to load enabled engines concurrently.
Respects OCR_ENABLE_PADDLEOCR and OCR_ENABLE_TESSERACT from .env.
Total warmup time = max(engine_times) instead of sum(engine_times).
"""
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
if _worker_initialized:
print(f"[Worker {os.getpid()}] Already initialized", flush=True)
return
# Check which engines are enabled via .env
paddle_enabled = os.getenv("OCR_ENABLE_PADDLEOCR", "true").lower() == "true"
tesseract_enabled = os.getenv("OCR_ENABLE_TESSERACT", "true").lower() == "true"
enabled_engines = ["doctr"] # docTR is always loaded (primary engine)
if paddle_enabled:
enabled_engines.append("paddle")
if tesseract_enabled:
enabled_engines.append("tesseract")
print(f"[Worker {os.getpid()}] Initializing OCR engines: {enabled_engines}", flush=True)
if not paddle_enabled:
print(f"[Worker {os.getpid()}] PaddleOCR DISABLED - saving ~800MB RAM", flush=True)
if not tesseract_enabled:
print(f"[Worker {os.getpid()}] Tesseract DISABLED - saving ~50MB RAM", flush=True)
start_time = time.time()
# Define loader functions - each runs in its own thread
def load_doctr():
try:
from backend.modules.data_entry.services.ocr.ocr_worker_process import initialize_doctr_engine
engine = initialize_doctr_engine()
return ("doctr", engine, None)
except Exception as e:
return ("doctr", None, str(e))
def load_paddle():
if not paddle_enabled:
return ("paddle", None, "disabled via OCR_ENABLE_PADDLEOCR=false")
try:
from backend.modules.data_entry.services.ocr.ocr_worker_process import initialize_paddle_engine
engine = initialize_paddle_engine()
return ("paddle", engine, None)
except Exception as e:
return ("paddle", None, str(e))
def load_tesseract():
if not tesseract_enabled:
return ("tesseract", None, "disabled via OCR_ENABLE_TESSERACT=false")
try:
from backend.modules.data_entry.services.ocr.tesseract_engine import TesseractEngine
engine = TesseractEngine()
return ("tesseract", engine, None)
except Exception as e:
return ("tesseract", None, str(e))
# Build list of futures for enabled engines only
futures_to_submit = [load_doctr] # docTR always loaded
if paddle_enabled:
futures_to_submit.append(load_paddle)
if tesseract_enabled:
futures_to_submit.append(load_tesseract)
# Load engines in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=len(futures_to_submit)) as executor:
futures = [executor.submit(fn) for fn in futures_to_submit]
for future in as_completed(futures):
name, engine, error = future.result()
if error and "disabled" not in error:
print(f"[Worker {os.getpid()}] {name} init failed: {error}", flush=True)
elif engine:
print(f"[Worker {os.getpid()}] {name} loaded", flush=True)
if name == "doctr":
_doctr_engine = engine
elif name == "paddle":
_paddle_engine = engine
elif name == "tesseract":
_tesseract_engine = engine
elapsed = time.time() - start_time
_worker_initialized = True
print(f"[Worker {os.getpid()}] Initialization complete in {elapsed:.1f}s (engines: {enabled_engines})", flush=True)
def _warmup_task() -> dict:
"""
Warmup task that ensures engines are loaded.
Called at FastAPI startup to pre-warm the worker.
Returns success status and worker PID.
"""
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
try:
# Ensure initialization
if not _worker_initialized:
_worker_initializer()
# Quick test - create a small dummy image
import numpy as np
dummy_img = np.ones((100, 100, 3), dtype=np.uint8) * 255
# Test docTR if available (fastest engine)
if _doctr_engine is not None:
try:
_doctr_engine([dummy_img])
print(f"[Worker {os.getpid()}] docTR warmup OK", flush=True)
except Exception as e:
print(f"[Worker {os.getpid()}] docTR warmup error: {e}", flush=True)
# Test PaddleOCR if available
if _paddle_engine is not None:
try:
_paddle_engine.predict(dummy_img)
print(f"[Worker {os.getpid()}] PaddleOCR warmup OK", flush=True)
except Exception as e:
print(f"[Worker {os.getpid()}] PaddleOCR warmup error: {e}", flush=True)
# Cleanup
gc.collect()
return {
"success": True,
"pid": os.getpid(),
"doctr_available": _doctr_engine is not None,
"paddle_available": _paddle_engine is not None,
"tesseract_available": _tesseract_engine is not None
}
except Exception as e:
return {
"success": False,
"pid": os.getpid(),
"error": str(e)
}
def _process_ocr_task(
image_bytes: bytes,
engine: str = "doctr_plus",
preprocessing: str = "auto"
) -> dict:
"""
Process OCR task in worker process.
This is the main work function called for each OCR request.
Uses persistent global engines loaded at worker init.
Args:
image_bytes: Raw image bytes
engine: OCR engine choice ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
preprocessing: Preprocessing mode
Returns:
Dict with extraction results
"""
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
try:
# Ensure initialization
if not _worker_initialized:
_worker_initializer()
# Import processing function
from backend.modules.data_entry.services.ocr.ocr_worker_process import process_ocr
# Run OCR
result = process_ocr(
image_bytes=image_bytes,
paddle_engine=_paddle_engine,
tesseract_engine=_tesseract_engine,
engine=engine,
preprocessing=preprocessing,
doctr_engine=_doctr_engine
)
# Cleanup after each task
gc.collect()
return result
except Exception as e:
print(f"[Worker {os.getpid()}] Task error: {e}", flush=True)
import traceback
traceback.print_exc()
return {
"success": False,
"error": str(e),
"pid": os.getpid()
}
# Singleton instance
ocr_worker_pool = OCRWorkerPool()

View File

@@ -0,0 +1,258 @@
# Store Profiles - OCR Extraction
Sistem de profile specifice pentru extracție OCR cu hot-reload.
---
## Quick Start: Adaugă un profil nou
```bash
# 1. Generează profil din PDF-uri (dry-run pentru preview)
python scripts/generate_store_profile.py \
--name "Magazin Nou SRL" \
--cui "12345678" \
--receipts "docs/data-entry/MagazinNou*.pdf" \
--dry-run
# 2. Generează și salvează
python scripts/generate_store_profile.py \
--name "Magazin Nou SRL" \
--cui "12345678" \
--receipts "docs/data-entry/MagazinNou*.pdf" \
--output backend/modules/data_entry/services/ocr/profiles/magazin_nou.py
# 3. Hot-reload (fără restart server)
curl -X POST http://localhost:8000/api/data-entry/ocr/profiles/reload
# 4. Verifică
curl http://localhost:8000/api/data-entry/ocr/profiles
```
---
## Structura directorului
```
profiles/
├── __init__.py # ProfileRegistry + hot-reload (~390 linii)
├── base.py # BaseStoreProfile + pattern-uri generice (~410 linii)
├── lidl.py # Multi-rate TVA (A/B)
├── omv.py # B2B, date YYYY.MM.DD
├── socar.py # B2B, date YYYY.MM.DD
├── brick.py # Standard TVA
├── dedeman.py # E-factura support
├── kineterra.py # Non-VAT payer
├── gama_ink.py # Standard TVA (toner/cartușe)
├── electrobering.py # Standard TVA (electronice)
├── pictus_velum.py # Standard TVA (rechizite)
├── unlimited_keys.py # Standard TVA, NUMERAR payment
├── best_print.py # Non-VAT payer (neplătitor TVA)
├── stepout_market.py # TVA 5% (cărți/librărie)
└── README.md # Acest fișier
```
---
## Profile existente (12 profile)
> **Note**: Pattern-urile TVA sunt **flexibile** și acceptă ORICE cotă (5%, 9%, 11%, 19%, 21%, etc.)
> pentru a gestiona atât datele istorice cât și schimbările viitoare ale legislației.
| Magazin | CUI | Fișier | Caracteristici |
|---------|-----|--------|----------------|
| LIDL DISCOUNT S.R.L. | 22891860 | `lidl.py` | Multi-rate TVA (coduri A, B, C, D) |
| OMV PETROM MARKETING S.R.L. | 11201891 | `omv.py` | B2B (client CUI), date YYYY.MM.DD |
| SOCAR PETROLEUM S.A. | 12546600 | `socar.py` | B2B (client CUI), date YYYY.MM.DD |
| FIVE-HOLDING S.A. (BRICK) | 10562600 | `brick.py` | Standard TVA |
| DEDEMAN SRL | 2816464 | `dedeman.py` | E-factura support |
| KINETERRA CONCEPT SRL | 31180432 | `kineterra.py` | Non-VAT payer (returnează `[]`) |
| GAMA INK SERVICE SRL | 17741882 | `gama_ink.py` | Standard TVA (toner, cartușe) |
| ELECTROBERING S.R.L. | 2744937 | `electrobering.py` | Standard TVA (electronice) |
| PICTUS VELUM SRL | 39634534 | `pictus_velum.py` | Standard TVA (rechizite) |
| UNLIMITED KEYS S.R.L. | 18993187 | `unlimited_keys.py` | Standard TVA, **NUMERAR** plată |
| BEST PRINT TRADE ACTIV SRL | 45417955 | `best_print.py` | **Non-VAT payer** (neplătitor TVA) |
| STEPOUT MARKET SRL | 35532655 | `stepout_market.py` | TVA 5% (cărți, librărie) |
---
## API Endpoints
| Endpoint | Metodă | Descriere |
|----------|--------|-----------|
| `/api/data-entry/ocr/profiles` | GET | Lista toate profilele |
| `/api/data-entry/ocr/profiles/{cui}` | GET | Detalii profil (acceptă RO prefix) |
| `/api/data-entry/ocr/profiles/reload` | POST | Hot-reload toate profilele |
### Exemple API
```bash
# Lista profile
curl http://localhost:8000/api/data-entry/ocr/profiles \
-H "Authorization: Bearer <token>"
# Detalii profil (cu sau fără RO prefix)
curl http://localhost:8000/api/data-entry/ocr/profiles/22891860
curl http://localhost:8000/api/data-entry/ocr/profiles/RO22891860
# Hot-reload după modificări
curl -X POST http://localhost:8000/api/data-entry/ocr/profiles/reload \
-H "Authorization: Bearer <token>"
# Response reload:
{
"success": true,
"reloaded_modules": 12,
"profiles_count": 12,
"registered_cuis": ["22891860", "11201891", "12546600", "10562600", ...],
"last_reload": "2026-01-06T22:37:05.000000"
}
```
---
## Cum funcționează sistemul
### Flow de extracție
```
ReceiptExtractor.extract()
├─► STEP 1: Extrage vendor + CUI
│ └─► _extract_vendor(), _extract_cui()
├─► ProfileRegistry.get_profile(cui)
│ └─► Returnează profil specific sau None
├─► STEP 2: Extracție cu profil (dacă există)
│ ├─► profile.extract_total()
│ ├─► profile.extract_date()
│ ├─► profile.extract_receipt_number()
│ ├─► profile.extract_tva_entries()
│ ├─► profile.extract_payment_methods()
│ └─► profile.extract_client_cui()
└─► STEP 3-4: Validare + post-procesare
```
### Fallback
Dacă nu există profil pentru CUI, se folosește logica generică din `ReceiptExtractor`.
---
## Structura unui profil
```python
from .base import BaseStoreProfile
from . import ProfileRegistry
@ProfileRegistry.register
class MagazinNouProfile(BaseStoreProfile):
"""Docstring cu descriere magazin."""
CUI_LIST = ["12345678"] # Poate avea mai multe CUI-uri
NAME_PATTERNS = ["MAGAZIN", "MAGAZIN NOU", "MAG4ZIN"] # OCR variants
STORE_NAME = "Magazin Nou SRL"
# Override doar ce e diferit de base class
def extract_tva_entries(self, text: str) -> List[dict]:
# Pattern-uri specifice magazinului
...
def get_validation_hints(self) -> Dict[str, Any]:
return {
"has_multi_rate_tva": False,
"card_equals_total": True,
"has_client_cui": False,
"has_efactura": False,
"is_non_vat_payer": False,
}
```
---
## Pattern-uri disponibile în base.py
BaseStoreProfile include pattern-uri generice OCR-tolerant:
| Pattern | Descriere |
|---------|-----------|
| `TOTAL_PATTERNS` | 8 variante pentru TOTAL (TOTAL:, TOTAL DE PLATA, etc.) |
| `DATE_PATTERNS` | 6 variante (DD.MM.YYYY, YYYY-MM-DD, DD/MM/YYYY) |
| `DATE_PATTERNS_OCR_SPACES` | 4 variante cu spații OCR ("2025. 08. 14") |
| `NUMBER_PATTERNS` | 11 variante pentru număr bon (NDS, BF, C3POS) |
| `PAYMENT_PATTERNS` | 8 variante pentru CARD/NUMERAR |
| `CLIENT_MARKERS` | 6 variante pentru secțiune CLIENT |
| `CLIENT_CUI_PATTERNS` | 7 variante pentru CUI client |
### Metode implementate în base class
- `extract_total(text)``Tuple[Decimal, float]`
- `extract_date(text)``Tuple[date, float]`
- `extract_receipt_number(text)``Tuple[str, float]`
- `extract_payment_methods(text)``List[dict]`
- `extract_client_cui(text)``Tuple[str, float]`
- `extract_client_name(text)``Tuple[str, float]`
---
## Când ai nevoie de profil custom?
| Situație | Exemplu | Ce trebuie override |
|----------|---------|---------------------|
| **Multi-rate TVA** | Lidl (TVA A, TVA B) | `extract_tva_entries()` |
| **Format dată special** | OMV/Socar (YYYY.MM.DD) | `DATE_PATTERNS_OCR_SPACES` |
| **B2B receipts** | Benzinării (au client CUI) | `extract_client_cui()` |
| **Non-VAT payer** | Kineterra | `extract_tva_entries()` returnează `[]` |
| **E-factura** | Dedeman | `extract_efactura_reference()` |
---
## Decizii de design
1. **Hot-reload manual** - endpoint `/profiles/reload` apelat când se modifică fișiere
2. **Persistență în Python** - profile în Git, version controlled
3. **Fallback graceful** - dacă nu există profil, folosește logica generică
4. **CUI normalization** - gestionează automat prefixul "RO" și whitespace
5. **Deduplicare TVA** - folosește `seen = set()` pentru a evita duplicate
---
## Comenzi utile
```bash
# Verifică syntax Python pentru toate profilele
for f in backend/modules/data_entry/services/ocr/profiles/*.py; do
python3 -m py_compile "$f" && echo "✓ $(basename $f)"
done
# Lista profile
ls -la backend/modules/data_entry/services/ocr/profiles/
# Pornește backend pentru testare
cd backend && source venv/bin/activate
uvicorn main:app --host 0.0.0.0 --port 8000 --workers 1
# Test OCR pe un PDF
curl -X POST -F "file=@docs/data-entry/test.pdf" \
-H "Authorization: Bearer <token>" \
"http://localhost:8000/api/data-entry/ocr/extract?engine=doctr_plus"
```
---
## Script generare profile
`scripts/generate_store_profile.py` - generator automat de profile
```bash
# Vezi help
python scripts/generate_store_profile.py --help
# Funcționalități:
# - Analizează PDF-uri via OCR API
# - Detectează: TVA format, date format, payment patterns, B2B
# - Generează cod Python cu OCR error variants
# - Suportă glob patterns (*.pdf)
# - Verifică sintaxa după generare
```

View File

@@ -0,0 +1,398 @@
"""
Store Profiles Registry with Hot-Reload Support.
This module provides a registry for store-specific OCR extraction profiles.
Profiles can be reloaded at runtime without restarting the server.
Usage:
from backend.modules.data_entry.services.ocr.profiles import ProfileRegistry
# Get profile for a CUI
profile = ProfileRegistry.get_profile("22891860")
if profile:
tva_entries = profile.extract_tva_entries(text)
# Reload all profiles (after file changes)
count = ProfileRegistry.reload_all()
Architecture:
- ProfileRegistry: Singleton registry with class methods
- BaseStoreProfile: Abstract base class for profiles
- @ProfileRegistry.register: Decorator for profile classes
Hot-Reload Mechanism:
1. Admin calls POST /profiles/reload endpoint
2. Registry clears instance cache
3. importlib.reload() re-executes each profile module
4. @register decorator re-registers classes with new code
"""
from __future__ import annotations
import importlib
import logging
import sys
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Type, TYPE_CHECKING
if TYPE_CHECKING:
from .base import BaseStoreProfile
logger = logging.getLogger(__name__)
# Directory containing profile modules
PROFILES_DIR = Path(__file__).parent
class ProfileRegistry:
"""
Registry for store-specific OCR extraction profiles.
Uses class methods for singleton-like behavior without explicit instantiation.
Supports hot-reload via importlib.reload() for runtime updates.
Attributes:
_profiles: Maps CUI -> profile class (not instance)
_instances: Maps CUI -> profile instance (lazy, cleared on reload)
_last_reload: Timestamp of last reload
_loaded: Whether initial load has been performed
"""
# Class-level storage (singleton pattern via class methods)
_profiles: Dict[str, Type["BaseStoreProfile"]] = {}
_instances: Dict[str, "BaseStoreProfile"] = {}
_last_reload: Optional[datetime] = None
_loaded: bool = False
# -------------------------------------------------------------------------
# Registration
# -------------------------------------------------------------------------
@classmethod
def register(cls, profile_class: Type["BaseStoreProfile"]) -> Type["BaseStoreProfile"]:
"""
Decorator to register a store profile class.
Registers the profile for all CUIs in the class's CUI_LIST.
Safe for re-registration during hot-reload (overwrites existing).
Usage:
@ProfileRegistry.register
class LidlProfile(BaseStoreProfile):
CUI_LIST = ["22891860"]
...
Args:
profile_class: Profile class to register
Returns:
The same class (allows use as decorator)
Raises:
ValueError: If CUI_LIST is empty
"""
cui_list = getattr(profile_class, 'CUI_LIST', [])
store_name = getattr(profile_class, 'STORE_NAME', profile_class.__name__)
if not cui_list:
logger.warning(f"Profile {profile_class.__name__} has empty CUI_LIST, skipping")
return profile_class
# Register for each CUI
for cui in cui_list:
# Normalize CUI (remove RO prefix, strip whitespace)
normalized_cui = cls._normalize_cui(cui)
if normalized_cui in cls._profiles:
old_class = cls._profiles[normalized_cui]
logger.debug(
f"Re-registering CUI {normalized_cui}: "
f"{old_class.__name__} -> {profile_class.__name__}"
)
# Clear cached instance for this CUI
cls._instances.pop(normalized_cui, None)
cls._profiles[normalized_cui] = profile_class
logger.debug(f"Registered profile {profile_class.__name__} for CUI {normalized_cui}")
logger.info(f"Registered {store_name} for CUIs: {cui_list}")
return profile_class
# -------------------------------------------------------------------------
# Lookup
# -------------------------------------------------------------------------
@classmethod
def get_profile(cls, cui: Optional[str]) -> Optional["BaseStoreProfile"]:
"""
Get profile instance for a CUI.
Uses lazy instantiation - creates instance on first access.
Returns None if no profile is registered for this CUI.
Args:
cui: CUI to lookup (with or without RO prefix)
Returns:
Profile instance or None
"""
if not cui:
return None
# Ensure profiles are loaded
if not cls._loaded:
cls._load_all_profiles()
normalized_cui = cls._normalize_cui(cui)
# Check if profile exists
profile_class = cls._profiles.get(normalized_cui)
if not profile_class:
return None
# Lazy instantiation
if normalized_cui not in cls._instances:
try:
cls._instances[normalized_cui] = profile_class()
logger.debug(f"Instantiated {profile_class.__name__} for CUI {normalized_cui}")
except Exception as e:
logger.error(f"Failed to instantiate {profile_class.__name__}: {e}")
return None
return cls._instances[normalized_cui]
@classmethod
def has_profile(cls, cui: Optional[str]) -> bool:
"""Check if a profile exists for this CUI."""
if not cui:
return False
if not cls._loaded:
cls._load_all_profiles()
return cls._normalize_cui(cui) in cls._profiles
# -------------------------------------------------------------------------
# Listing
# -------------------------------------------------------------------------
@classmethod
def list_profiles(cls) -> List[Dict]:
"""
List all registered profiles.
Returns:
List of dicts with cui, class_name, store_name, name_patterns
"""
if not cls._loaded:
cls._load_all_profiles()
result = []
seen_classes = set()
for cui, profile_class in cls._profiles.items():
# Avoid duplicates for profiles with multiple CUIs
if profile_class.__name__ in seen_classes:
continue
seen_classes.add(profile_class.__name__)
result.append({
"cuis": list(getattr(profile_class, 'CUI_LIST', [])),
"class_name": profile_class.__name__,
"store_name": getattr(profile_class, 'STORE_NAME', profile_class.__name__),
"name_patterns": list(getattr(profile_class, 'NAME_PATTERNS', [])),
})
return result
@classmethod
def get_profile_info(cls, cui: str) -> Optional[Dict]:
"""
Get detailed info about a profile.
Args:
cui: CUI to lookup
Returns:
Dict with profile details or None
"""
profile = cls.get_profile(cui)
if not profile:
return None
return {
"cui": cui,
"cuis": list(profile.CUI_LIST),
"class_name": profile.__class__.__name__,
"store_name": profile.STORE_NAME,
"name_patterns": list(profile.NAME_PATTERNS),
"validation_hints": profile.get_validation_hints(),
}
# -------------------------------------------------------------------------
# Hot-Reload
# -------------------------------------------------------------------------
@classmethod
def reload_all(cls) -> int:
"""
Hot-reload all profile modules.
Clears instance cache and reloads all .py files in profiles directory.
Decorator re-registers classes with updated code.
Returns:
Number of modules reloaded
"""
logger.info("Starting profile hot-reload...")
# Clear instance cache (will be recreated on next get_profile)
cls._instances.clear()
# Get list of profile modules (exclude __init__, base)
module_names = cls._get_profile_module_names()
# Determine the module prefix based on how THIS module was imported
base_package = cls.__module__
count = 0
for module_name in module_names:
full_name = f"{base_package}.{module_name}"
try:
if full_name in sys.modules:
# Reload existing module
importlib.reload(sys.modules[full_name])
logger.debug(f"Reloaded module: {module_name}")
else:
# Import new module
importlib.import_module(full_name)
logger.debug(f"Imported new module: {module_name}")
count += 1
except Exception as e:
logger.error(f"Failed to reload {module_name}: {e}")
cls._last_reload = datetime.utcnow()
cls._loaded = True
logger.info(f"Profile hot-reload complete: {count} modules, {len(cls._profiles)} profiles")
return count
@classmethod
def get_reload_status(cls) -> Dict:
"""Get status of the registry including last reload time."""
return {
"loaded": cls._loaded,
"last_reload": cls._last_reload.isoformat() if cls._last_reload else None,
"profiles_count": len(cls._profiles),
"instances_count": len(cls._instances),
"registered_cuis": list(cls._profiles.keys()),
}
# -------------------------------------------------------------------------
# Internal methods
# -------------------------------------------------------------------------
@classmethod
def _normalize_cui(cls, cui: str) -> str:
"""
Normalize CUI for consistent lookup.
- Removes RO prefix (with or without space)
- Strips whitespace
- Converts to uppercase
Args:
cui: Raw CUI string
Returns:
Normalized CUI (digits only)
"""
if not cui:
return ""
cui = str(cui).strip().upper()
# Remove RO prefix (handles "RO12345" and "RO 12345")
if cui.startswith("RO"):
cui = cui[2:].lstrip()
return cui.strip()
@classmethod
def _get_profile_module_names(cls) -> List[str]:
"""
Get list of profile module names from profiles directory.
Excludes __init__.py and base.py.
Returns:
List of module names (without .py extension)
"""
excluded = {"__init__", "base", "__pycache__"}
modules = []
for path in PROFILES_DIR.glob("*.py"):
name = path.stem
if name not in excluded:
modules.append(name)
return sorted(modules)
@classmethod
def _load_all_profiles(cls) -> None:
"""
Initial load of all profile modules.
Called automatically on first get_profile() if not already loaded.
"""
if cls._loaded:
return
logger.info("Loading store profiles...")
module_names = cls._get_profile_module_names()
# Determine the module prefix based on how THIS module was imported
# This handles both:
# - Running from backend dir: "modules.data_entry.services.ocr.profiles"
# - Running from project root: "backend.modules.data_entry.services.ocr.profiles"
this_module = cls.__module__ # e.g. "backend.modules..." or "modules..."
base_package = this_module # Use the same prefix for child modules
for module_name in module_names:
full_name = f"{base_package}.{module_name}"
try:
importlib.import_module(full_name)
logger.debug(f"Loaded module: {module_name}")
except Exception as e:
logger.error(f"Failed to load {module_name}: {e}")
cls._loaded = True
cls._last_reload = datetime.utcnow()
logger.info(f"Loaded {len(cls._profiles)} store profiles")
@classmethod
def clear(cls) -> None:
"""
Clear all registered profiles.
Mainly useful for testing.
"""
cls._profiles.clear()
cls._instances.clear()
cls._loaded = False
cls._last_reload = None
# -------------------------------------------------------------------------
# Module exports
# -------------------------------------------------------------------------
__all__ = [
"ProfileRegistry",
"BaseStoreProfile",
]
# Re-export BaseStoreProfile for convenience
from .base import BaseStoreProfile

View File

@@ -0,0 +1,655 @@
"""
Optimized Tesseract Engine for OCR - SPEED + QUALITY OPTIMIZED
Performance optimizations (vs previous version):
- Single PSM mode (PSM 4) instead of multi-PSM (4 modes × 2 calls = 8x faster)
- Single Tesseract call per image (skip image_to_data for speed)
- Lighter preprocessing (no over-binarization)
- --dpi 300 flag for proper scaling
- OEM 3 (default LSTM+Legacy) for balanced speed/accuracy
Quality optimizations for Romanian receipts:
- PSM 4: Single column layout (optimal for receipts)
- Polarity correction: ensures black text on white background
- Language: Romanian only (-l ron) for faster recognition
- Fallback to PSM 6 if PSM 4 produces poor results
Previous issues fixed:
- Was 8x slower than PaddleOCR due to multi-PSM + dual calls
- Produced gibberish on clear PDFs due to over-binarization
"""
import logging
import os
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
import cv2
import numpy as np
# Check Tesseract availability
try:
import pytesseract
TESSERACT_AVAILABLE = True
except ImportError:
TESSERACT_AVAILABLE = False
pytesseract = None
logger = logging.getLogger(__name__)
@dataclass
class OCRResult:
"""Raw OCR result from Tesseract."""
text: str
confidence: float
boxes: List[dict] = field(default_factory=list)
engine: str = "tesseract"
class TesseractEngine:
"""
Optimized Tesseract engine for receipt OCR.
TESTED OPTIMAL SETTINGS (from comprehensive benchmark):
- DPI 200 for PDF loading (not 300!)
- Padding 40px for edge protection
- PSM 6 for complex receipts, PSM 4 for simple ones
- Multi-pass strategy when quality is critical
SPEED vs QUALITY tradeoff:
- Fast mode (single pass): ~0.9s, ~6-7 keywords
- Quality mode (multi-pass): ~1.7s, ~8-9 keywords (+2 more keywords)
BENCHMARK RESULTS:
- padded_psm6_40: Best for complex receipts (igiena, five-holding)
- baseline_psm4: Best for simple receipts (rechizite, benzina)
- multi-pass: Best overall quality but slower
"""
# PSM modes for receipts
PSM_SINGLE_COLUMN = 4 # Best for simple vertical receipts
PSM_UNIFORM_BLOCK = 6 # Best for complex layouts
PSM_SPARSE_TEXT = 11 # Fallback for difficult receipts
# Optimal padding (from benchmark)
DEFAULT_PADDING = 40
def __init__(self):
"""Initialize Tesseract engine."""
if not TESSERACT_AVAILABLE:
raise RuntimeError("pytesseract not available. Install with: pip install pytesseract")
# Verify Tesseract installation
try:
self._version = pytesseract.get_tesseract_version()
except Exception as e:
raise RuntimeError(f"Tesseract not installed or not in PATH: {e}")
logger.info(f"[TesseractEngine] Initialized (v{self._version})")
def recognize(self, image: np.ndarray, fast_mode: bool = True) -> OCRResult:
"""
Perform OCR recognition on image (OPTIMIZED).
SPEED: Uses single PSM mode + single Tesseract call.
Previously used 4 PSM modes × 2 calls = 8 Tesseract invocations.
Now uses 1-2 calls maximum (with fallback).
Args:
image: Preprocessed grayscale image (DO NOT binarize for clear PDFs!)
fast_mode: If True, skip confidence calculation for maximum speed
Returns:
OCRResult with text and confidence
"""
if not TESSERACT_AVAILABLE:
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
# Ensure grayscale
if len(image.shape) == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
# Fix polarity (black text on white background)
image = self._ensure_correct_polarity(image)
# Try PSM 4 first (single column - best for receipts)
result = self._recognize_fast(image, self.PSM_SINGLE_COLUMN, fast_mode)
# If poor result, try PSM 6 as fallback
if not result.text.strip() or result.confidence < 0.3:
logger.debug(f"[Tesseract] PSM {self.PSM_SINGLE_COLUMN} poor result, trying PSM {self.PSM_UNIFORM_BLOCK}")
fallback = self._recognize_fast(image, self.PSM_UNIFORM_BLOCK, fast_mode)
if len(fallback.text) > len(result.text):
result = fallback
if result.text.strip():
logger.info(f"[TesseractEngine] Result: {len(result.text)} chars, conf={result.confidence:.0%}")
return result
def _recognize_fast(self, image: np.ndarray, psm: int, fast_mode: bool = True) -> OCRResult:
"""
Fast single-call Tesseract recognition.
Optimizations:
- Single call (image_to_string only in fast mode)
- OEM 3 (LSTM+Legacy) - faster than OEM 1
- --dpi 300 for proper scaling
- Romanian only (-l ron)
Args:
image: Grayscale image
psm: Page segmentation mode
fast_mode: Skip confidence calculation for speed
Returns:
OCRResult
"""
# Build optimized config:
# OEM 3 = LSTM + Legacy (faster than pure LSTM)
# --dpi 300 = proper scaling hint
# -l ron = Romanian only (faster, avoids eng confusion)
config = f'--psm {psm} --oem 3 --dpi 300 -l ron'
try:
if fast_mode:
# Fast path: just get text, estimate confidence
text = pytesseract.image_to_string(image, config=config)
# Estimate confidence based on text quality
confidence = self._estimate_confidence(text)
else:
# Accurate path: get text + real confidence
text = pytesseract.image_to_string(image, config=config)
data = pytesseract.image_to_data(
image, config=config, output_type=pytesseract.Output.DICT
)
confidences = [int(c) for c in data['conf'] if int(c) > 0]
confidence = sum(confidences) / len(confidences) / 100 if confidences else 0.0
return OCRResult(
text=text,
confidence=confidence,
boxes=[],
engine="tesseract"
)
except Exception as e:
logger.warning(f"[Tesseract] PSM {psm} error: {e}")
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
def _estimate_confidence(self, text: str) -> float:
"""
Estimate OCR confidence based on text quality.
Heuristics:
- More alphanumeric chars = higher confidence
- Less garbage chars = higher confidence
- Romanian-specific patterns boost confidence
"""
if not text.strip():
return 0.0
# Count valid vs garbage chars
valid_chars = sum(1 for c in text if c.isalnum() or c in '.,;:-/\n ')
total_chars = len(text)
if total_chars == 0:
return 0.0
# Base confidence from char ratio
confidence = valid_chars / total_chars
# Boost for Romanian receipt patterns
text_lower = text.lower()
if any(word in text_lower for word in ['total', 'lei', 'ron', 'buc', 'tva', 'cif', 'bon']):
confidence = min(confidence + 0.1, 1.0)
return confidence
def recognize_multipass(self, image: np.ndarray) -> OCRResult:
"""
Multi-pass OCR for maximum quality (slower but more accurate).
Strategy (from benchmark testing):
- Pass 1: PSM 4 (single column) - no padding, fast baseline
- Pass 2: PSM 6 (uniform block) - with 40px padding, better for complex layouts
- Pass 3: PSM 11 (sparse text) - with 40px padding + stronger CLAHE, for difficult receipts
Merges results: picks the pass with highest keyword count.
On average finds +2.1 more keywords than single-pass (~8.7 vs 6.6).
Time: ~1.7s (vs ~0.9s for single pass)
Args:
image: Input image (RGB or grayscale)
Returns:
OCRResult from the best pass
"""
if not TESSERACT_AVAILABLE:
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
# Ensure grayscale
if len(image.shape) == 3:
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
else:
gray = image.copy()
# Define passes with different settings
passes = [
# Pass 1: Fast baseline (no padding) - good for simple receipts
{"name": "pass1_psm4", "psm": 4, "padding": 0, "clahe_clip": 1.5},
# Pass 2: Padded PSM 6 - good for complex receipts
{"name": "pass2_psm6_padded", "psm": 6, "padding": 40, "clahe_clip": 1.5},
# Pass 3: Sparse text with stronger enhancement - for difficult cases
{"name": "pass3_psm11", "psm": 11, "padding": 40, "clahe_clip": 2.0},
]
best_result = None
best_score = -1
all_keywords = set()
for p in passes:
# Apply preprocessing for this pass
processed = gray.copy()
# Add padding if specified
if p["padding"] > 0:
processed = cv2.copyMakeBorder(
processed, p["padding"], p["padding"], p["padding"], p["padding"],
cv2.BORDER_CONSTANT, value=255
)
# Apply CLAHE
clahe = cv2.createCLAHE(clipLimit=p["clahe_clip"], tileGridSize=(8, 8))
processed = clahe.apply(processed)
# Ensure correct polarity
processed = self._ensure_correct_polarity(processed)
# Run OCR
config = f'--psm {p["psm"]} --oem 3 -l ron'
try:
text = pytesseract.image_to_string(processed, config=config)
confidence = self._estimate_confidence(text)
# Score based on Romanian receipt keywords
text_lower = text.lower()
keywords = ['cif', 'total', 'tva', 'lei', 'ron', 'buc', 'fiscal', 'bon',
'hartie', 'prosop', 'saci', 'creion', 'constanta', 'bucuresti']
found_keywords = [kw for kw in keywords if kw in text_lower]
all_keywords.update(found_keywords)
# Score: keywords + CIF bonus + TOTAL bonus
score = len(found_keywords) * 10
if self._has_cif_pattern(text):
score += 15
if self._has_total_pattern(text):
score += 10
logger.debug(f"[Tesseract] {p['name']}: {len(found_keywords)} keywords, score={score}")
if score > best_score:
best_score = score
best_result = OCRResult(
text=text,
confidence=confidence,
boxes=[],
engine=f"tesseract-multipass-{p['name']}"
)
except Exception as e:
logger.warning(f"[Tesseract] {p['name']} failed: {e}")
continue
if best_result:
logger.info(f"[TesseractEngine] Multi-pass best: {best_result.engine}, "
f"{len(all_keywords)} total keywords found")
return best_result
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract-multipass")
def _has_cif_pattern(self, text: str) -> bool:
"""Check if text contains a valid CIF/CUI pattern."""
import re
text_upper = text.upper()
patterns = [
r'CIF[:\s]*RO?\d{6,10}',
r'CUI[:\s]*RO?\d{6,10}',
r'C\.?I\.?F\.?[:\s]*RO?\d{6,10}',
]
for pattern in patterns:
if re.search(pattern, text_upper):
return True
return bool(re.search(r'RO\d{7,10}', text_upper))
def _has_total_pattern(self, text: str) -> bool:
"""Check if TOTAL is properly recognized (not truncated to BTOTAL/OTAL)."""
import re
text_upper = text.upper()
return bool(re.search(r'(^|\s)TOTAL\s', text_upper, re.MULTILINE))
def recognize_with_boxes(self, image: np.ndarray, psm: int = 4) -> OCRResult:
"""
Recognition with bounding boxes (slower, for debugging/visualization).
Use this only when you need box coordinates.
For normal OCR, use recognize() which is faster.
Args:
image: Grayscale image
psm: Page segmentation mode (default: 4 for receipts)
Returns:
OCRResult with text, confidence, and boxes
"""
if not TESSERACT_AVAILABLE:
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
# Ensure grayscale
if len(image.shape) == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
image = self._ensure_correct_polarity(image)
config = f'--psm {psm} --oem 3 --dpi 300 -l ron'
try:
text = pytesseract.image_to_string(image, config=config)
data = pytesseract.image_to_data(
image, config=config, output_type=pytesseract.Output.DICT
)
confidences = [int(c) for c in data['conf'] if int(c) > 0]
avg_conf = sum(confidences) / len(confidences) / 100 if confidences else 0.0
boxes = []
for i in range(len(data['text'])):
if data['text'][i].strip() and int(data['conf'][i]) > 0:
boxes.append({
'text': data['text'][i],
'confidence': int(data['conf'][i]) / 100,
'box': [data['left'][i], data['top'][i], data['width'][i], data['height'][i]]
})
return OCRResult(text=text, confidence=avg_conf, boxes=boxes, engine="tesseract")
except Exception as e:
logger.warning(f"[Tesseract] recognize_with_boxes error: {e}")
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
def _ensure_correct_polarity(self, image: np.ndarray) -> np.ndarray:
"""
Ensure image has black text on white background.
Receipts should have dark text on light background.
If image is inverted (light text on dark), invert it.
Detection method:
- Calculate mean pixel value
- If mean < 127, image is mostly dark (inverted)
- Invert to correct polarity
Args:
image: Grayscale image
Returns:
Polarity-corrected image
"""
mean_value = np.mean(image)
if mean_value < 127:
# Image is mostly dark = inverted (white text on black)
logger.debug(f"[TesseractEngine] Detected inverted polarity (mean={mean_value:.1f}), correcting...")
return 255 - image
return image
def recognize_numbers_only(self, image: np.ndarray) -> OCRResult:
"""
OCR optimized for numeric content (amounts, totals).
Uses character whitelist to reduce errors on numbers.
Args:
image: Preprocessed grayscale image
Returns:
OCRResult with numeric text
"""
if not TESSERACT_AVAILABLE:
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
# Ensure grayscale
if len(image.shape) == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
# Fix polarity
image = self._ensure_correct_polarity(image)
# Config for numbers only
# Whitelist: digits, comma, period, space, RON, LEI
config = '--psm 6 --oem 1 -c tessedit_char_whitelist=0123456789.,- '
try:
text = pytesseract.image_to_string(image, config=config)
data = pytesseract.image_to_data(
image,
config=config,
output_type=pytesseract.Output.DICT
)
confidences = [int(c) for c in data['conf'] if int(c) > 0]
avg_conf = sum(confidences) / len(confidences) / 100 if confidences else 0.0
return OCRResult(
text=text.strip(),
confidence=avg_conf,
boxes=[],
engine="tesseract-numeric"
)
except Exception as e:
logger.error(f"[TesseractEngine] Numeric OCR error: {e}")
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
def recognize_cif_optimized(self, image: np.ndarray) -> Optional[str]:
"""
Optimized CIF extraction using multi-strategy approach.
BENCHMARK RESULTS (from test_critical_fields.py):
- digit_opt_dpi200: 33% accuracy (best)
- digit_whitelist: Works well on specific receipts
- basic_ron_eng: Good backup
Strategy:
1. Try digit-optimized preprocessing (2x scale + Otsu)
2. Try character whitelist (RO + digits only)
3. Try standard ron+eng config
4. Return best match based on CIF pattern validation
Args:
image: Input image (RGB from pdf2image or BGR from OpenCV)
Returns:
Extracted CIF string (e.g., "RO10562600") or None
"""
import re
if not TESSERACT_AVAILABLE:
return None
# Ensure grayscale
if len(image.shape) == 3:
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
else:
gray = image.copy()
# Extract top 35% of image (where CIF is typically found)
height = gray.shape[0]
top_region = gray[:int(height * 0.35), :]
candidates = []
# Strategy 1: Digit-optimized preprocessing (best performer: 33% accuracy)
try:
# Scale up 2x + Otsu binarization
scaled = cv2.resize(top_region, None, fx=2.0, fy=2.0, interpolation=cv2.INTER_CUBIC)
clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
enhanced = clahe.apply(scaled)
_, binary = cv2.threshold(enhanced, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
if np.mean(binary) < 127:
binary = 255 - binary
text = pytesseract.image_to_string(binary, config='--psm 6 --oem 3 -l ron')
cif = self._extract_cif_from_text(text)
if cif:
candidates.append(('digit_opt', cif))
except Exception as e:
logger.debug(f"[TesseractEngine] digit_opt strategy failed: {e}")
# Strategy 2: Character whitelist (RO + digits only)
try:
# Add padding
padded = cv2.copyMakeBorder(top_region, 40, 40, 40, 40, cv2.BORDER_CONSTANT, value=255)
scaled = cv2.resize(padded, None, fx=2.0, fy=2.0, interpolation=cv2.INTER_CUBIC)
config = '--psm 6 --oem 1 -c tessedit_char_whitelist=0123456789ROro'
text = pytesseract.image_to_string(scaled, config=config)
cif = self._extract_cif_from_text(text)
if cif:
candidates.append(('whitelist', cif))
except Exception as e:
logger.debug(f"[TesseractEngine] whitelist strategy failed: {e}")
# Strategy 3: Standard ron+eng config (good backup)
try:
padded = cv2.copyMakeBorder(top_region, 40, 40, 40, 40, cv2.BORDER_CONSTANT, value=255)
clahe = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(8, 8))
enhanced = clahe.apply(padded)
text = pytesseract.image_to_string(enhanced, config='--psm 6 --oem 3 -l ron+eng')
cif = self._extract_cif_from_text(text)
if cif:
candidates.append(('ron_eng', cif))
except Exception as e:
logger.debug(f"[TesseractEngine] ron_eng strategy failed: {e}")
if not candidates:
return None
# Log all candidates
for strategy, cif in candidates:
logger.debug(f"[TesseractEngine] CIF candidate from {strategy}: {cif}")
# Use majority voting if multiple strategies agree
from collections import Counter
cif_counts = Counter(cif for _, cif in candidates)
most_common_cif, count = cif_counts.most_common(1)[0]
if count > 1:
# Multiple strategies agree
logger.info(f"[TesseractEngine] CIF extracted (majority {count} strategies): {most_common_cif}")
return most_common_cif
# No agreement - prefer digit_opt strategy (33% accuracy in benchmarks)
for strategy, cif in candidates:
if strategy == 'digit_opt':
logger.info(f"[TesseractEngine] CIF extracted via digit_opt (preferred): {cif}")
return cif
# Fallback to first candidate
strategy, cif = candidates[0]
logger.info(f"[TesseractEngine] CIF extracted via {strategy}: {cif}")
return cif
def _extract_cif_from_text(self, text: str) -> Optional[str]:
"""Extract CIF/CUI from OCR text."""
import re
text_upper = text.upper().replace(' ', '')
patterns = [
r'CIF[:\s]*R?O?(\d{6,10})',
r'CUI[:\s]*R?O?(\d{6,10})',
r'C\.?I\.?F\.?[:\s]*R?O?(\d{6,10})',
r'RO(\d{7,10})',
r'R\.?O\.?[\s:]*(\d{6,10})',
]
for pattern in patterns:
match = re.search(pattern, text_upper)
if match:
digits = match.group(1).lstrip('0') or '0'
return f"RO{digits}"
return None
@staticmethod
def validate_romanian_cif(cif: str) -> bool:
"""
Validate Romanian CIF/CUI using checksum algorithm.
Romanian CIF format: RO + 2-10 digits
The last digit is a control digit calculated using modulo 11.
Algorithm:
1. Multiply each digit by corresponding weight (from right to left: 2,3,4,5,6,7,2,3,4,5)
2. Sum all products
3. Remainder of sum / 11 is the control digit
4. If remainder is 10, control digit is 0
Args:
cif: CIF string (e.g., "RO10562600", "10562600")
Returns:
True if CIF is valid, False otherwise
"""
# Remove RO prefix and spaces
cif = cif.upper().replace(' ', '').replace('RO', '')
# Must be 2-10 digits
if not cif.isdigit() or len(cif) < 2 or len(cif) > 10:
return False
# Weights for checksum calculation (right to left)
weights = [2, 3, 4, 5, 6, 7, 2, 3, 4, 5]
# Pad with zeros on the left to make it 10 digits
cif_padded = cif.zfill(10)
# Calculate checksum (excluding last digit which is control)
total = 0
for i in range(9):
total += int(cif_padded[i]) * weights[i]
# Control digit
control = total % 11
if control == 10:
control = 0
# Compare with last digit
return int(cif_padded[9]) == control
@staticmethod
def is_available() -> bool:
"""Check if Tesseract is available."""
if not TESSERACT_AVAILABLE:
return False
try:
pytesseract.get_tesseract_version()
return True
except Exception:
return False
@staticmethod
def get_version() -> Optional[str]:
"""Get Tesseract version string."""
if not TESSERACT_AVAILABLE:
return None
try:
return str(pytesseract.get_tesseract_version())
except Exception:
return None

View File

@@ -0,0 +1,476 @@
"""OCR engine wrapper for PaddleOCR, docTR, and Tesseract."""
import os
import logging
import threading
import time
from dataclasses import dataclass
from typing import List, Optional, Tuple
import numpy as np
# Setup logging (respects LOG_LEVEL env var set in main.py)
logger = logging.getLogger(__name__)
# Disable PaddleOCR model source check for faster startup (PaddleX 3.x)
os.environ['PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK'] = 'True'
# Lazy imports - these will be imported on first use
PaddleOCR = None # Will be imported lazily
pytesseract = None # Will be imported lazily
doctr_ocr_predictor = None # Will be imported lazily
# Check availability without importing heavy libraries
def _check_paddle_available() -> bool:
"""Check if paddleocr is installed without importing it."""
try:
import importlib.util
return importlib.util.find_spec("paddleocr") is not None
except Exception:
return False
def _check_tesseract_available() -> bool:
"""Check if pytesseract is installed without importing it."""
try:
import importlib.util
return importlib.util.find_spec("pytesseract") is not None
except Exception:
return False
def _check_doctr_available() -> bool:
"""Check if doctr is installed without importing it."""
try:
import importlib.util
return importlib.util.find_spec("doctr") is not None
except Exception:
return False
PADDLE_AVAILABLE = _check_paddle_available()
TESSERACT_AVAILABLE = _check_tesseract_available()
DOCTR_AVAILABLE = _check_doctr_available()
@dataclass
class OCRResult:
"""Raw OCR result."""
text: str
confidence: float
boxes: List[dict]
engine: str = "" # OCR engine used: paddleocr or tesseract
class OCREngine:
"""Unified OCR engine with fallback support."""
def __init__(self):
self._paddle = None
self._paddle_init_started = False
self._paddle_ready = threading.Event() # Signals when PaddleOCR is FULLY ready
self._paddle_init_lock = threading.Lock()
self._doctr = None
self._doctr_init_started = False
self._doctr_ready = threading.Event() # Signals when docTR is FULLY ready
self._doctr_init_lock = threading.Lock()
def _init_paddle_lazy(self):
"""Lazy initialize PaddleOCR on first use (avoids slow startup)."""
global PaddleOCR
with self._paddle_init_lock:
if self._paddle_init_started:
return # Already initializing or done
self._paddle_init_started = True
if PADDLE_AVAILABLE:
try:
print("Importing PaddleOCR (first use, may take ~15-20 seconds)...", flush=True)
from paddleocr import PaddleOCR as _PaddleOCR
PaddleOCR = _PaddleOCR
print("Initializing PaddleOCR engine...", flush=True)
# PaddleOCR 3.x API - optimized for Romanian receipts
# Note: 'latin' not available in PaddleOCR 3.x, 'en' works well for receipts
self._paddle = PaddleOCR(
lang='en', # 'en' handles Latin alphabet well for receipts
# High quality settings for better accuracy
det_db_thresh=0.3, # Lower threshold = detect more text (default 0.3)
det_db_box_thresh=0.5, # Box confidence threshold (default 0.5)
det_db_unclip_ratio=1.8, # Expand detected boxes slightly (default 1.5)
rec_batch_num=6, # Batch size for recognition
use_angle_cls=True, # Enable text angle classification
)
print("PaddleOCR initialized successfully with high-quality settings", flush=True)
except Exception as e:
print(f"Warning: Failed to initialize PaddleOCR: {e}", flush=True)
self._paddle = None
# Signal that initialization is complete (success or failure)
self._paddle_ready.set()
def _init_doctr_lazy(self):
"""Lazy initialize docTR on first use (avoids slow startup)."""
global doctr_ocr_predictor
with self._doctr_init_lock:
if self._doctr_init_started:
return # Already initializing or done
self._doctr_init_started = True
if DOCTR_AVAILABLE:
try:
print("Importing docTR (first use, may take ~10-15 seconds)...", flush=True)
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
print("Initializing docTR engine (PyTorch backend)...", flush=True)
# Initialize docTR predictor with pretrained models
# Uses db_resnet50 for detection and crnn_vgg16_bn for recognition
self._doctr = ocr_predictor(
det_arch='db_resnet50',
reco_arch='crnn_vgg16_bn',
pretrained=True,
assume_straight_pages=True,
straighten_pages=False,
preserve_aspect_ratio=True,
)
doctr_ocr_predictor = self._doctr
print("docTR initialized successfully with PyTorch backend", flush=True)
except Exception as e:
print(f"Warning: Failed to initialize docTR: {e}", flush=True)
self._doctr = None
# Signal that initialization is complete (success or failure)
self._doctr_ready.set()
def wait_for_doctr(self, timeout: float = 30.0) -> bool:
"""
Wait for docTR to be fully initialized.
Args:
timeout: Max seconds to wait (default 30s)
Returns:
True if docTR is ready, False if timeout or unavailable
"""
if not DOCTR_AVAILABLE:
return False
if self._doctr is not None:
return True # Already ready
if not self._doctr_init_started:
# Start initialization if not already started
self._init_doctr_lazy()
# Wait for initialization to complete
print(f"[OCR] Waiting for docTR to be ready (max {timeout}s)...", flush=True)
start = time.time()
ready = self._doctr_ready.wait(timeout=timeout)
elapsed = time.time() - start
if ready and self._doctr is not None:
print(f"[OCR] docTR ready after {elapsed:.1f}s", flush=True)
return True
else:
print(f"[OCR] docTR not ready after {elapsed:.1f}s (timeout or failed)", flush=True)
return False
def is_doctr_ready(self) -> bool:
"""Check if docTR is ready without waiting."""
return self._doctr is not None
def wait_for_paddle(self, timeout: float = 30.0) -> bool:
"""
Wait for PaddleOCR to be fully initialized.
Args:
timeout: Max seconds to wait (default 30s)
Returns:
True if PaddleOCR is ready, False if timeout or unavailable
"""
if not PADDLE_AVAILABLE:
return False
if self._paddle is not None:
return True # Already ready
if not self._paddle_init_started:
# Start initialization if not already started
self._init_paddle_lazy()
# Wait for initialization to complete
print(f"[OCR] Waiting for PaddleOCR to be ready (max {timeout}s)...", flush=True)
start = time.time()
ready = self._paddle_ready.wait(timeout=timeout)
elapsed = time.time() - start
if ready and self._paddle is not None:
print(f"[OCR] PaddleOCR ready after {elapsed:.1f}s", flush=True)
return True
else:
print(f"[OCR] PaddleOCR not ready after {elapsed:.1f}s (timeout or failed)", flush=True)
return False
def is_paddle_ready(self) -> bool:
"""Check if PaddleOCR is ready without waiting."""
return self._paddle is not None
def recognize(self, image: np.ndarray) -> OCRResult:
"""Perform OCR on preprocessed image."""
logger.info(f"[OCR] Starting recognition, image shape: {image.shape}, dtype: {image.dtype}")
# Lazy init PaddleOCR on first call
self._init_paddle_lazy()
if PADDLE_AVAILABLE and self._paddle:
logger.info("[OCR] Using PaddleOCR engine")
return self._paddle_recognize(image)
elif TESSERACT_AVAILABLE:
logger.info("[OCR] Using Tesseract engine (PaddleOCR not available)")
return self._tesseract_recognize(image)
else:
logger.error("[OCR] No OCR engine available!")
raise RuntimeError(
"No OCR engine available. Install PaddleOCR or Tesseract."
)
def _paddle_recognize(self, image: np.ndarray) -> OCRResult:
"""Recognize text using PaddleOCR 3.x API."""
# Wait for PaddleOCR to be fully ready (handles background init)
if not self.wait_for_paddle(timeout=30.0):
logger.warning("[PaddleOCR] Not ready, falling back to Tesseract")
if TESSERACT_AVAILABLE:
return self._tesseract_recognize(image)
raise RuntimeError("PaddleOCR not ready and Tesseract not available")
try:
logger.info(f"[PaddleOCR] Processing image, shape: {image.shape}")
# PaddleOCR 3.x requires 3-channel images
if len(image.shape) == 2:
# Convert grayscale to 3-channel BGR
import cv2
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
logger.info(f"[PaddleOCR] Converted to BGR, new shape: {image.shape}")
# PaddleOCR 3.x uses predict() with new parameter names
logger.info("[PaddleOCR] Calling predict()...")
result = self._paddle.predict(image, use_textline_orientation=True)
logger.info(f"[PaddleOCR] predict() returned, result type: {type(result)}")
if not result or len(result) == 0:
logger.warning("[PaddleOCR] No results returned")
return OCRResult(text="", confidence=0.0, boxes=[], engine="paddleocr")
# PaddleOCR 3.x returns OCRResult objects with different structure
ocr_result = result[0]
# Extract texts and scores from the new format
rec_texts = ocr_result.get('rec_texts', [])
rec_scores = ocr_result.get('rec_scores', [])
dt_polys = ocr_result.get('dt_polys', [])
if not rec_texts:
return OCRResult(text="", confidence=0.0, boxes=[], engine="paddleocr")
boxes = []
for i, text in enumerate(rec_texts):
conf = rec_scores[i] if i < len(rec_scores) else 0.0
box = dt_polys[i].tolist() if i < len(dt_polys) else []
boxes.append({
'text': text,
'confidence': float(conf),
'box': box
})
avg_conf = sum(rec_scores) / len(rec_scores) if rec_scores else 0.0
text_result = '\n'.join(rec_texts)
logger.info(f"[PaddleOCR] SUCCESS - Found {len(rec_texts)} text lines, avg confidence: {avg_conf:.2%}")
logger.debug(f"[PaddleOCR] Raw text preview: {text_result[:200]}...")
return OCRResult(
text=text_result,
confidence=float(avg_conf),
boxes=boxes,
engine="paddleocr"
)
except Exception as e:
logger.error(f"[PaddleOCR] ERROR: {e}, falling back to Tesseract")
if TESSERACT_AVAILABLE:
return self._tesseract_recognize(image)
raise
def _tesseract_recognize(self, image: np.ndarray) -> OCRResult:
"""Recognize text using Tesseract."""
global pytesseract
logger.info(f"[Tesseract] Processing image, shape: {image.shape}")
# Lazy import pytesseract
if pytesseract is None:
logger.info("[Tesseract] Importing pytesseract...")
import pytesseract as _pytesseract
pytesseract = _pytesseract
# PSM 4: Single column (best for receipts)
config = '--psm 4 -l ron+eng'
text = pytesseract.image_to_string(image, config=config)
# Quick confidence estimate
data = pytesseract.image_to_data(image, config=config, output_type=pytesseract.Output.DICT)
confidences = [int(c) for c in data['conf'] if int(c) > 0]
avg_conf = sum(confidences) / len(confidences) / 100 if confidences else 0.0
logger.info(f"[Tesseract] Done: {len(text)} chars, conf: {avg_conf:.2%}")
return OCRResult(text=text, confidence=avg_conf, boxes=[], engine="tesseract")
def _doctr_recognize(self, image: np.ndarray) -> OCRResult:
"""Recognize text using docTR."""
# Wait for docTR to be fully ready
if not self.wait_for_doctr(timeout=30.0):
logger.warning("[docTR] Not ready, falling back to Tesseract")
if TESSERACT_AVAILABLE:
return self._tesseract_recognize(image)
raise RuntimeError("docTR not ready and Tesseract not available")
try:
logger.info(f"[docTR] Processing image, shape: {image.shape}")
# docTR requires RGB images
import cv2
if len(image.shape) == 2:
# Convert grayscale to RGB
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
logger.info(f"[docTR] Converted grayscale to RGB, new shape: {image.shape}")
elif image.shape[2] == 4:
# Convert RGBA to RGB
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
logger.info(f"[docTR] Converted RGBA to RGB, new shape: {image.shape}")
elif image.shape[2] == 3:
# Check if BGR (from OpenCV) and convert to RGB
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
logger.info(f"[docTR] Converted BGR to RGB, shape: {image.shape}")
# Process image with docTR
logger.info("[docTR] Running prediction...")
from doctr.io import DocumentFile
# docTR expects a document (list of pages as numpy arrays)
result = self._doctr([image])
if not result or not result.pages:
logger.warning("[docTR] No results returned")
return OCRResult(text="", confidence=0.0, boxes=[], engine="doctr")
# Extract text from all pages
all_texts = []
all_confidences = []
boxes = []
for page in result.pages:
for block in page.blocks:
for line in block.lines:
line_text = ' '.join(word.value for word in line.words)
line_confidence = sum(w.confidence for w in line.words) / len(line.words) if line.words else 0.0
all_texts.append(line_text)
all_confidences.append(line_confidence)
# Store word-level boxes
for word in line.words:
boxes.append({
'text': word.value,
'confidence': float(word.confidence),
'box': word.geometry # (xmin, ymin), (xmax, ymax)
})
text_result = '\n'.join(all_texts)
avg_conf = sum(all_confidences) / len(all_confidences) if all_confidences else 0.0
logger.info(f"[docTR] SUCCESS - Found {len(all_texts)} text lines, avg confidence: {avg_conf:.2%}")
logger.debug(f"[docTR] Raw text preview: {text_result[:200]}...")
return OCRResult(
text=text_result,
confidence=float(avg_conf),
boxes=boxes,
engine="doctr"
)
except Exception as e:
logger.error(f"[docTR] ERROR: {e}, falling back to Tesseract")
if TESSERACT_AVAILABLE:
return self._tesseract_recognize(image)
raise
def recognize_dual(self, image: np.ndarray) -> Tuple[OCRResult, Optional[OCRResult]]:
"""
Run both OCR engines and return both results.
Returns:
Tuple of (paddle_result, tesseract_result)
tesseract_result may be None if Tesseract is not available
"""
logger.info(f"[OCR Dual] Starting dual recognition, image shape: {image.shape}")
# Lazy init PaddleOCR
self._init_paddle_lazy()
paddle_result = None
tesseract_result = None
# Run PaddleOCR
if PADDLE_AVAILABLE and self._paddle:
try:
logger.info("[OCR Dual] Running PaddleOCR...")
paddle_result = self._paddle_recognize(image)
logger.info(f"[OCR Dual] PaddleOCR: {len(paddle_result.text)} chars, conf: {paddle_result.confidence:.2%}")
except Exception as e:
logger.error(f"[OCR Dual] PaddleOCR failed: {e}")
paddle_result = OCRResult(text="", confidence=0.0, boxes=[], engine="paddleocr")
# Run Tesseract
if TESSERACT_AVAILABLE:
try:
logger.info("[OCR Dual] Running Tesseract...")
tesseract_result = self._tesseract_recognize(image)
logger.info(f"[OCR Dual] Tesseract: {len(tesseract_result.text)} chars, conf: {tesseract_result.confidence:.2%}")
except Exception as e:
logger.error(f"[OCR Dual] Tesseract failed: {e}")
tesseract_result = OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
# Fallback if PaddleOCR not available
if paddle_result is None:
if tesseract_result:
paddle_result = tesseract_result
else:
raise RuntimeError("No OCR engine available")
return paddle_result, tesseract_result
@staticmethod
def get_available_engines() -> List[str]:
"""
Return list of available OCR engines.
Respects OCR_ENABLE_PADDLEOCR and OCR_ENABLE_TESSERACT from .env.
Engines that are disabled via .env are not returned even if installed.
Available engines: tesseract, doctr, doctr_plus, paddleocr
"""
# Check .env settings
paddle_enabled = os.getenv("OCR_ENABLE_PADDLEOCR", "true").lower() == "true"
tesseract_enabled = os.getenv("OCR_ENABLE_TESSERACT", "true").lower() == "true"
engines = []
# Base engines (only if installed AND enabled)
if TESSERACT_AVAILABLE and tesseract_enabled:
engines.append('tesseract')
if DOCTR_AVAILABLE:
engines.append('doctr')
engines.append('doctr_plus') # docTR with 2-tier sequential + early exit
if PADDLE_AVAILABLE and paddle_enabled:
engines.append('paddleocr')
return engines

View File

@@ -0,0 +1,735 @@
"""Main OCR service coordinating preprocessing, recognition, and extraction."""
import os
import re
import gc
import logging
import threading
# Disable PaddleOCR model source check for faster startup (PaddleX 3.x) - must be set before import
os.environ['PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK'] = 'True'
import time
import asyncio
from concurrent.futures import ThreadPoolExecutor
from decimal import Decimal
from pathlib import Path
from typing import Optional, Tuple
from backend.modules.data_entry.services.ocr_engine import OCREngine
from backend.modules.data_entry.services.ocr_extractor import ReceiptExtractor, ExtractionResult
from backend.modules.data_entry.services.image_preprocessor import ImagePreprocessor
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
# Setup logging
logger = logging.getLogger(__name__)
def get_memory_usage_mb() -> float:
"""Get current process memory usage in MB."""
try:
import resource
# Get memory in KB, convert to MB
rusage = resource.getrusage(resource.RUSAGE_SELF)
return rusage.ru_maxrss / 1024 # Linux returns KB
except Exception:
return 0.0
class OCRService:
"""Service for OCR processing of receipt images."""
# Single worker to prevent memory accumulation from parallel OCR
_executor = ThreadPoolExecutor(max_workers=1)
# Semaphore to ensure only one OCR operation at a time (memory protection)
_ocr_semaphore = threading.Semaphore(1)
# Memory threshold in MB - if exceeded, force GC before processing
_memory_threshold_mb = 2500
def __init__(self):
self.preprocessor = ImagePreprocessor()
self.ocr_engine = OCREngine()
self.extractor = ReceiptExtractor()
async def process_image(
self,
image_path: Path,
mime_type: str
) -> Tuple[bool, str, Optional[ExtractionResult]]:
"""
Process receipt image and extract structured data.
Args:
image_path: Path to the image file
mime_type: MIME type of the file
Returns:
Tuple of (success, message, extraction_result)
"""
try:
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(
self._executor,
self._process_sync,
image_path,
mime_type
)
return result
except Exception as e:
return False, f"OCR processing failed: {str(e)}", None
def _cleanup_memory(self, *arrays):
"""Explicitly delete numpy arrays and force garbage collection."""
for arr in arrays:
if arr is not None:
try:
del arr
except:
pass
gc.collect()
def _process_sync(
self,
image_path: Path,
mime_type: str
) -> Tuple[bool, str, Optional[ExtractionResult]]:
"""Synchronous processing with ADAPTIVE OCR pipeline."""
# Acquire semaphore to ensure only one OCR at a time
acquired = self._ocr_semaphore.acquire(timeout=120) # 2 min timeout
if not acquired:
return False, "OCR service busy - please try again", None
try:
return self._process_sync_internal(image_path, mime_type)
finally:
# Always release semaphore and cleanup
self._ocr_semaphore.release()
# Force garbage collection after EVERY OCR request
gc.collect()
mem_after = get_memory_usage_mb()
print(f"[OCR Service] Memory after cleanup: {mem_after:.0f}MB", flush=True)
def _process_sync_internal(
self,
image_path: Path,
mime_type: str
) -> Tuple[bool, str, Optional[ExtractionResult]]:
"""Internal processing - called with semaphore held."""
start_time = time.time()
mem_before = get_memory_usage_mb()
print(f"[OCR Service] Starting processing: {image_path}, mime: {mime_type}", flush=True)
print(f"[OCR Service] Memory before: {mem_before:.0f}MB", flush=True)
# Check if memory is high - force GC before processing
if mem_before > self._memory_threshold_mb:
print(f"[OCR Service] WARNING: Memory high ({mem_before:.0f}MB > {self._memory_threshold_mb}MB), forcing GC...", flush=True)
gc.collect()
mem_after_gc = get_memory_usage_mb()
print(f"[OCR Service] Memory after pre-GC: {mem_after_gc:.0f}MB", flush=True)
# Load image
images = None # For cleanup
image = None
if mime_type == 'application/pdf':
try:
images = self.preprocessor.pdf_to_images(image_path)
if not images:
return False, "Failed to extract images from PDF", None
image = images[0]
# Delete other pages immediately to save memory
if len(images) > 1:
for i in range(1, len(images)):
del images[i]
images = [image]
except RuntimeError as e:
return False, str(e), None
else:
try:
image = self.preprocessor.load_image(image_path)
except ValueError as e:
return False, str(e), None
raw_texts = []
extraction = None
# ══════════════════════════════════════════════════════════════
# STEP 1: PaddleOCR + Light (fastest, best for clear PDFs)
# ══════════════════════════════════════════════════════════════
print("=" * 60, flush=True)
print("[OCR] STEP 1: PaddleOCR + Light preprocessing", flush=True)
print("=" * 60, flush=True)
light_img = self.preprocessor.preprocess_light(image)
try:
paddle_light = self.ocr_engine._paddle_recognize(light_img)
# Cleanup light_img immediately after OCR
del light_img
light_img = None
if paddle_light and paddle_light.text:
extraction = self.extractor.extract(paddle_light.text)
extraction.ocr_engine = "paddle-light"
raw_texts.append(f"═══ PaddleOCR (light, conf: {paddle_light.confidence:.0%}) ═══\n{paddle_light.text}")
# Log extraction results
print(f"[OCR] Step 1 Results:", flush=True)
print(f" - OCR Confidence: {paddle_light.confidence:.0%}", flush=True)
print(f" - Amount: {extraction.amount}", flush=True)
print(f" - Date: {extraction.receipt_date}", flush=True)
print(f" - Number: {extraction.receipt_number}", flush=True)
print(f" - CUI: {extraction.cui}", flush=True)
print(f" - TVA: {extraction.tva_total} (entries: {len(extraction.tva_entries) if extraction.tva_entries else 0})", flush=True)
print(f" - Overall Confidence: {extraction.overall_confidence:.0%}", flush=True)
# Early exit if complete
if self._is_extraction_complete(extraction):
extraction.raw_text = "\n\n".join(raw_texts)
elapsed_ms = int((time.time() - start_time) * 1000)
extraction.processing_time_ms = elapsed_ms
print(f"[OCR] *** EARLY EXIT at Step 1 - All fields found! ({elapsed_ms}ms) ***", flush=True)
# Cleanup before return
del image
if images:
del images
return True, "OCR complete (fast mode)", extraction
else:
print("[OCR] -> Step 1 incomplete, continuing to Step 2...", flush=True)
except Exception as e:
print(f"[OCR] PaddleOCR light failed: {e}", flush=True)
extraction = ExtractionResult()
# Cleanup on error
if light_img is not None:
del light_img
# ══════════════════════════════════════════════════════════════
# STEP 2: PaddleOCR + Medium (balanced preprocessing)
# ══════════════════════════════════════════════════════════════
print("=" * 60, flush=True)
print("[OCR] STEP 2: PaddleOCR + Medium preprocessing", flush=True)
print("=" * 60, flush=True)
medium_img = self.preprocessor.preprocess_medium(image)
try:
paddle_medium = self.ocr_engine._paddle_recognize(medium_img)
# Cleanup medium_img immediately after OCR
del medium_img
medium_img = None
if paddle_medium and paddle_medium.text:
extraction_medium = self.extractor.extract(paddle_medium.text)
extraction_medium.ocr_engine = "paddle-medium"
raw_texts.append(f"═══ PaddleOCR (medium, conf: {paddle_medium.confidence:.0%}) ═══\n{paddle_medium.text}")
print(f"[OCR] Step 2 (Medium) Results:", flush=True)
print(f" - OCR Confidence: {paddle_medium.confidence:.0%}", flush=True)
print(f" - Amount: {extraction_medium.amount}", flush=True)
print(f" - Date: {extraction_medium.receipt_date}", flush=True)
print(f" - CUI: {extraction_medium.cui}", flush=True)
# Merge with previous
extraction = self._merge_extractions(extraction, extraction_medium)
print(f"[OCR] After merge:", flush=True)
print(f" - Amount: {extraction.amount}", flush=True)
print(f" - Date: {extraction.receipt_date}", flush=True)
print(f" - Number: {extraction.receipt_number}", flush=True)
print(f" - CUI: {extraction.cui}", flush=True)
print(f" - TVA: {extraction.tva_total}", flush=True)
print(f" - Overall Confidence: {extraction.overall_confidence:.0%}", flush=True)
if self._is_extraction_complete(extraction):
extraction.raw_text = "\n\n".join(raw_texts)
extraction.ocr_engine = "paddle-adaptive"
elapsed_ms = int((time.time() - start_time) * 1000)
extraction.processing_time_ms = elapsed_ms
print(f"[OCR] *** EARLY EXIT at Step 2 - All fields found after merge! ({elapsed_ms}ms) ***", flush=True)
# Cleanup before return
del image
if images:
del images
return True, "OCR complete (paddle dual)", extraction
else:
print("[OCR] -> Step 2 incomplete, continuing to Step 3 (Tesseract)...", flush=True)
except Exception as e:
print(f"[OCR] PaddleOCR medium failed: {e}", flush=True)
# Cleanup on error
if medium_img is not None:
del medium_img
# ══════════════════════════════════════════════════════════════
# STEP 3: Tesseract - ONLY to complete missing fields
# Uses Tesseract-optimized preprocessing (binarized, high contrast)
# ══════════════════════════════════════════════════════════════
print("=" * 60, flush=True)
print("[OCR] STEP 3: Tesseract (complement only, not override)", flush=True)
print("=" * 60, flush=True)
tesseract_img = None
try:
# Use Tesseract-specific preprocessing (Otsu binarization)
tesseract_img = self.preprocessor.preprocess_for_tesseract(image)
tesseract_result = self.ocr_engine._tesseract_recognize(tesseract_img)
# Cleanup tesseract_img immediately after OCR
del tesseract_img
tesseract_img = None
if tesseract_result and tesseract_result.text:
extraction_tess = self.extractor.extract(tesseract_result.text)
extraction_tess.ocr_engine = "tesseract"
raw_texts.append(f"═══ Tesseract (conf: {tesseract_result.confidence:.0%}) ═══\n{tesseract_result.text}")
print(f"[OCR] Step 3 (Tesseract) Results:", flush=True)
print(f" - OCR Confidence: {tesseract_result.confidence:.0%}", flush=True)
print(f" - Amount: {extraction_tess.amount}", flush=True)
print(f" - Date: {extraction_tess.receipt_date}", flush=True)
print(f" - CUI: {extraction_tess.cui}", flush=True)
# IMPORTANT: Tesseract only COMPLETES missing fields, never overrides!
extraction = self._complement_extraction(extraction, extraction_tess)
except Exception as e:
print(f"[OCR] Tesseract failed: {e}", flush=True)
# Cleanup on error
if tesseract_img is not None:
del tesseract_img
# Cleanup original image - no longer needed
del image
if images:
del images
# ══════════════════════════════════════════════════════════════
# FINAL VALIDATION: Fix impossible values
# ══════════════════════════════════════════════════════════════
if extraction:
extraction = self._final_validation(extraction)
# Final result
if extraction is None:
return False, "No text detected", None
extraction.raw_text = "\n\n".join(raw_texts)
extraction.ocr_engine = "adaptive-full"
# Build result message
fields_found = []
if extraction.amount: fields_found.append("amount")
if extraction.receipt_date: fields_found.append("date")
if extraction.receipt_number: fields_found.append("number")
if extraction.cui: fields_found.append("CUI")
if extraction.tva_total or extraction.tva_entries: fields_found.append("TVA")
message = f"OCR complete (full pipeline). Found: {', '.join(fields_found) or 'no fields'}"
elapsed_ms = int((time.time() - start_time) * 1000)
extraction.processing_time_ms = elapsed_ms
print("=" * 60, flush=True)
print(f"[OCR] FINAL RESULT (full pipeline) - {elapsed_ms}ms", flush=True)
print("=" * 60, flush=True)
print(f" - Amount: {extraction.amount}", flush=True)
print(f" - Date: {extraction.receipt_date}", flush=True)
print(f" - Number: {extraction.receipt_number}", flush=True)
print(f" - CUI: {extraction.cui}", flush=True)
print(f" - TVA: {extraction.tva_total}", flush=True)
print(f" - Overall Confidence: {extraction.overall_confidence:.0%}", flush=True)
print(f" - Processing Time: {elapsed_ms}ms", flush=True)
print(f" - Message: {message}", flush=True)
# ══════════════════════════════════════════════════════════════
# VALIDATION: Apply validation rules to final extraction
# ══════════════════════════════════════════════════════════════
print("\n" + "=" * 60, flush=True)
print("[Validation] Applying validation rules...", flush=True)
print("=" * 60, flush=True)
validator = OCRValidationEngine()
# Prepare data for validation with safe type conversions
def safe_float(value) -> Optional[float]:
"""Safely convert Decimal or number to float."""
if value is None:
return None
try:
return float(value)
except (TypeError, ValueError):
return None
def safe_payment_sum(methods: list, method_type: str) -> Optional[float]:
"""Safely sum payment amounts for a given method type."""
if not methods:
return None
try:
total = sum(
float(pm.get('amount', 0) or 0)
for pm in methods
if pm.get('method') == method_type
)
return total if total > 0 else None
except (TypeError, ValueError):
return None
validation_data = {
'amount': safe_float(extraction.amount),
'tva': safe_float(extraction.tva_total),
'cui': extraction.cui,
'card_amount': safe_payment_sum(extraction.payment_methods, 'CARD'),
'cash_amount': safe_payment_sum(extraction.payment_methods, 'NUMERAR'),
'tva_entries': {
entry.get('code', ''): safe_float(entry.get('amount'))
for entry in (extraction.tva_entries or [])
if entry.get('code') and safe_float(entry.get('amount')) is not None
}
}
# Run validation (no light/medium comparison for final result)
validated_result = validator.validate_extraction(validation_data)
# Apply validation results to extraction
extraction.needs_manual_review = validated_result.needs_manual_review
extraction.validation_warnings = validated_result.validation_warnings
extraction.validation_errors = validated_result.validation_errors
extraction.confidence_adjustments = validated_result.confidence_adjustments
extraction.inter_ocr_ratios = validated_result.inter_ocr_ratios
print(f"[Validation] Complete:", flush=True)
print(f" - Warnings: {len(extraction.validation_warnings)}", flush=True)
print(f" - Errors: {len(extraction.validation_errors)}", flush=True)
print(f" - Needs Manual Review: {extraction.needs_manual_review}", flush=True)
if extraction.validation_warnings:
for warning in extraction.validation_warnings:
print(f" [!] {warning}", flush=True)
return True, message, extraction
def _merge_extractions(
self,
paddle: Optional[ExtractionResult],
tesseract: Optional[ExtractionResult]
) -> ExtractionResult:
"""
Merge two extractions, picking best fields from each engine.
Strategy:
- For each field, prefer the one with higher confidence
- Use validation rules (CUI format, date validity, company indicators)
- Combine TVA entries if different
"""
result = ExtractionResult()
# Handle case where one is None
if paddle is None and tesseract is None:
return result
if paddle is None:
return tesseract
if tesseract is None:
return paddle
print("[Merge] Comparing PaddleOCR vs Tesseract extractions...", flush=True)
# === AMOUNT ===
# Pick higher confidence, both must be positive
if paddle.amount and tesseract.amount:
if paddle.confidence_amount >= tesseract.confidence_amount:
result.amount = paddle.amount
result.confidence_amount = paddle.confidence_amount
print(f"[Merge] Amount: PaddleOCR {paddle.amount} (conf: {paddle.confidence_amount:.0%})", flush=True)
else:
result.amount = tesseract.amount
result.confidence_amount = tesseract.confidence_amount
print(f"[Merge] Amount: Tesseract {tesseract.amount} (conf: {tesseract.confidence_amount:.0%})", flush=True)
elif paddle.amount:
result.amount = paddle.amount
result.confidence_amount = paddle.confidence_amount
elif tesseract.amount:
result.amount = tesseract.amount
result.confidence_amount = tesseract.confidence_amount
# === DATE ===
# Pick higher confidence, validate date reasonableness
if paddle.receipt_date and tesseract.receipt_date:
if paddle.confidence_date >= tesseract.confidence_date:
result.receipt_date = paddle.receipt_date
result.confidence_date = paddle.confidence_date
print(f"[Merge] Date: PaddleOCR {paddle.receipt_date}", flush=True)
else:
result.receipt_date = tesseract.receipt_date
result.confidence_date = tesseract.confidence_date
print(f"[Merge] Date: Tesseract {tesseract.receipt_date}", flush=True)
elif paddle.receipt_date:
result.receipt_date = paddle.receipt_date
result.confidence_date = paddle.confidence_date
elif tesseract.receipt_date:
result.receipt_date = tesseract.receipt_date
result.confidence_date = tesseract.confidence_date
# === VENDOR NAME ===
# Prefer one with company indicators (S.R.L., S.A., etc.)
paddle_has_indicator = self._has_company_indicator(paddle.partner_name)
tesseract_has_indicator = self._has_company_indicator(tesseract.partner_name)
if paddle.partner_name and tesseract.partner_name:
if paddle_has_indicator and not tesseract_has_indicator:
result.partner_name = paddle.partner_name
result.confidence_vendor = paddle.confidence_vendor
print(f"[Merge] Vendor: PaddleOCR '{paddle.partner_name}' (has company indicator)", flush=True)
elif tesseract_has_indicator and not paddle_has_indicator:
result.partner_name = tesseract.partner_name
result.confidence_vendor = tesseract.confidence_vendor
print(f"[Merge] Vendor: Tesseract '{tesseract.partner_name}' (has company indicator)", flush=True)
elif paddle.confidence_vendor >= tesseract.confidence_vendor:
result.partner_name = paddle.partner_name
result.confidence_vendor = paddle.confidence_vendor
print(f"[Merge] Vendor: PaddleOCR '{paddle.partner_name}' (higher conf)", flush=True)
else:
result.partner_name = tesseract.partner_name
result.confidence_vendor = tesseract.confidence_vendor
print(f"[Merge] Vendor: Tesseract '{tesseract.partner_name}' (higher conf)", flush=True)
elif paddle.partner_name:
result.partner_name = paddle.partner_name
result.confidence_vendor = paddle.confidence_vendor
elif tesseract.partner_name:
result.partner_name = tesseract.partner_name
result.confidence_vendor = tesseract.confidence_vendor
# === CUI (Fiscal Code) ===
# Validate format: 6-10 digits, prefer valid one
paddle_cui_valid = self._is_valid_cui(paddle.cui)
tesseract_cui_valid = self._is_valid_cui(tesseract.cui)
if paddle.cui and tesseract.cui:
if paddle_cui_valid and not tesseract_cui_valid:
result.cui = paddle.cui
print(f"[Merge] CUI: PaddleOCR {paddle.cui} (valid format)", flush=True)
elif tesseract_cui_valid and not paddle_cui_valid:
result.cui = tesseract.cui
print(f"[Merge] CUI: Tesseract {tesseract.cui} (valid format)", flush=True)
else:
# Both valid or both invalid - prefer PaddleOCR
result.cui = paddle.cui
print(f"[Merge] CUI: PaddleOCR {paddle.cui}", flush=True)
elif paddle.cui and paddle_cui_valid:
result.cui = paddle.cui
elif tesseract.cui and tesseract_cui_valid:
result.cui = tesseract.cui
elif paddle.cui:
result.cui = paddle.cui
elif tesseract.cui:
result.cui = tesseract.cui
# === TVA ENTRIES ===
# Prefer non-empty, use the one with more entries or higher amounts
if paddle.tva_entries and tesseract.tva_entries:
# Compare: prefer the one with actual amounts (not just 0)
paddle_total = sum(e.get('amount', Decimal('0')) for e in paddle.tva_entries)
tesseract_total = sum(e.get('amount', Decimal('0')) for e in tesseract.tva_entries)
if paddle_total >= tesseract_total:
result.tva_entries = paddle.tva_entries
result.tva_total = paddle.tva_total
print(f"[Merge] TVA: PaddleOCR (total: {paddle_total})", flush=True)
else:
result.tva_entries = tesseract.tva_entries
result.tva_total = tesseract.tva_total
print(f"[Merge] TVA: Tesseract (total: {tesseract_total})", flush=True)
elif paddle.tva_entries:
result.tva_entries = paddle.tva_entries
result.tva_total = paddle.tva_total
elif tesseract.tva_entries:
result.tva_entries = tesseract.tva_entries
result.tva_total = tesseract.tva_total
# === OTHER FIELDS ===
# Simple preference: paddle > tesseract
result.receipt_number = paddle.receipt_number or tesseract.receipt_number
result.receipt_series = paddle.receipt_series or tesseract.receipt_series
result.receipt_type = paddle.receipt_type or tesseract.receipt_type
result.items_count = paddle.items_count or tesseract.items_count
result.address = paddle.address or tesseract.address
result.description = paddle.description or tesseract.description
return result
def _has_company_indicator(self, name: Optional[str]) -> bool:
"""Check if vendor name has company type indicator (S.R.L., S.A., etc.)"""
if not name:
return False
name_upper = name.upper()
indicators = [
r'\bS\.?\s*R\.?\s*L\.?\b',
r'\bS\.?\s*A\.?\b',
r'\bS\.?\s*N\.?\s*C\.?\b',
r'\bP\.?\s*F\.?\s*A\.?\b',
r'\bI\.?\s*I\.?\b',
r'\bHOLDING\b',
r'\bGROUP\b',
r'\bCOMPANY\b',
]
for indicator in indicators:
if re.search(indicator, name_upper):
return True
return False
def _is_valid_cui(self, cui: Optional[str]) -> bool:
"""Validate CUI format: 6-10 digits."""
if not cui:
return False
# Remove any RO prefix
cui_clean = re.sub(r'^RO', '', cui.upper())
# Must be 6-10 digits
return bool(re.match(r'^\d{6,10}$', cui_clean))
def _is_extraction_complete(self, ext: ExtractionResult, min_confidence: float = 0.85) -> bool:
"""
Check if extraction has ALL required fields to skip further processing.
Required for early exit (ALL must be true):
- Overall confidence >= 85%
- ALL 5 critical fields present: number, date, amount, TVA, CUI
"""
# Must have high confidence
if ext.overall_confidence < min_confidence:
print(f"[OCR] Confidence {ext.overall_confidence:.0%} < {min_confidence:.0%} - continuing", flush=True)
return False
# Check all required fields
has_number = bool(ext.receipt_number)
has_date = bool(ext.receipt_date)
has_amount = bool(ext.amount)
has_tva = bool(ext.tva_total) or bool(ext.tva_entries)
has_cui = bool(ext.cui)
missing = []
if not has_number: missing.append("number")
if not has_date: missing.append("date")
if not has_amount: missing.append("amount")
if not has_tva: missing.append("TVA")
if not has_cui: missing.append("CUI")
if missing:
print(f"[OCR] Missing: {', '.join(missing)} - continuing", flush=True)
return False
print(f"[OCR] OK: All 5 fields found with {ext.overall_confidence:.0%} confidence", flush=True)
return True
def _complement_extraction(
self,
primary: Optional[ExtractionResult],
secondary: Optional[ExtractionResult]
) -> ExtractionResult:
"""
Complement primary extraction with missing fields from secondary.
NEVER overrides existing values - only fills in gaps.
This is different from _merge_extractions which can override values.
"""
if primary is None and secondary is None:
return ExtractionResult()
if primary is None:
return secondary
if secondary is None:
return primary
print("[Complement] Adding missing fields from Tesseract...", flush=True)
# Only fill missing amount
if not primary.amount and secondary.amount:
primary.amount = secondary.amount
primary.confidence_amount = secondary.confidence_amount
print(f"[Complement] Added amount: {secondary.amount}", flush=True)
# Only fill missing date
if not primary.receipt_date and secondary.receipt_date:
primary.receipt_date = secondary.receipt_date
primary.confidence_date = secondary.confidence_date
print(f"[Complement] Added date: {secondary.receipt_date}", flush=True)
# Only fill missing vendor
if not primary.partner_name and secondary.partner_name:
primary.partner_name = secondary.partner_name
primary.confidence_vendor = secondary.confidence_vendor
print(f"[Complement] Added vendor: {secondary.partner_name}", flush=True)
# Only fill missing CUI
if not primary.cui and secondary.cui and self._is_valid_cui(secondary.cui):
primary.cui = secondary.cui
print(f"[Complement] Added CUI: {secondary.cui}", flush=True)
# Only fill missing TVA
if not primary.tva_entries and secondary.tva_entries:
primary.tva_entries = secondary.tva_entries
primary.tva_total = secondary.tva_total
print(f"[Complement] Added TVA: {secondary.tva_total}", flush=True)
# Only fill missing receipt number
if not primary.receipt_number and secondary.receipt_number:
primary.receipt_number = secondary.receipt_number
print(f"[Complement] Added number: {secondary.receipt_number}", flush=True)
# Only fill missing address
if not primary.address and secondary.address:
primary.address = secondary.address
print(f"[Complement] Added address: {secondary.address}", flush=True)
return primary
def _final_validation(self, extraction: ExtractionResult) -> ExtractionResult:
"""
Final validation and correction of impossible values.
Key rules:
1. TVA cannot be greater than TOTAL (it's always a fraction)
2. If TVA > TOTAL, recalculate TOTAL from TVA using known rates
3. Validate TVA entries sum equals TVA total
"""
print("[Final Validation] Checking extracted values...", flush=True)
# Rule 1: TVA cannot be greater than TOTAL
if extraction.tva_total and extraction.amount:
if extraction.tva_total > extraction.amount:
print(f"[Final Validation] TVA ({extraction.tva_total}) > TOTAL ({extraction.amount}) - IMPOSSIBLE!", flush=True)
# Calculate TOTAL from TVA using reverse formula:
# total = base + tva = tva * (100/rate + 1) = tva * (100 + rate) / rate
# For 9% TVA: total = tva * 109 / 9 = tva * 12.11
# For 19% TVA: total = tva * 119 / 19 = tva * 6.26
# For 21% TVA: total = tva * 121 / 21 = tva * 5.76
rate = 19 # Default rate assumption
if extraction.tva_entries:
# Use the rate from the first entry
rate = extraction.tva_entries[0].get('percent', 19)
if rate > 0:
# Formula: total = tva * (100 + rate) / rate
calculated_total = extraction.tva_total * (Decimal('100') + Decimal(str(rate))) / Decimal(str(rate))
calculated_total = calculated_total.quantize(Decimal('0.01'))
print(f"[Final Validation] Calculated TOTAL from TVA: {calculated_total} (using {rate}% rate)", flush=True)
extraction.amount = calculated_total
extraction.confidence_amount = 0.70 # Lower confidence for calculated value
# Rule 2: TVA cannot be more than ~25% of total (max Romanian rate is 21%)
if extraction.tva_total and extraction.amount:
tva_percent = extraction.tva_total / extraction.amount * Decimal('100')
if tva_percent > Decimal('25'):
print(f"[Final Validation] Warning: TVA is {tva_percent:.1f}% of total - suspicious", flush=True)
# Rule 3: Validate TVA entries sum
if extraction.tva_entries and extraction.tva_total:
entries_sum = sum(e.get('amount', Decimal('0')) for e in extraction.tva_entries)
tolerance = Decimal('0.05')
if abs(entries_sum - extraction.tva_total) > tolerance:
print(f"[Final Validation] TVA entries sum ({entries_sum}) != tva_total ({extraction.tva_total})", flush=True)
# Use the sum as it's more reliable
extraction.tva_total = entries_sum
print(f"[Final Validation] Done. Amount={extraction.amount}, TVA={extraction.tva_total}", flush=True)
return extraction
# Singleton instance
ocr_service = OCRService()

View File

@@ -0,0 +1,385 @@
"""
Auto-create Receipt from OCR results for bulk upload flow.
This service handles automatic creation of Receipt records from OCR extraction
results, enabling end-to-end processing without manual UI intervention.
The service:
1. Maps OCR ExtractionData fields to Receipt fields
2. Creates attachment from the original uploaded file
3. Generates accounting entries
4. Links the receipt back to the batch job for tracking
"""
import logging
import shutil
import uuid
from dataclasses import dataclass
from datetime import date, datetime
from decimal import Decimal
from pathlib import Path
from typing import Optional, List
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from backend.modules.data_entry.db.models.receipt import (
Receipt,
ReceiptAttachment,
ReceiptStatus,
ReceiptType,
ReceiptDirection,
)
from backend.modules.data_entry.db.models.batch import BatchJob
from backend.modules.data_entry.db.crud.receipt import ReceiptCRUD
from backend.modules.data_entry.db.crud.accounting_entry import AccountingEntryCRUD
from backend.modules.data_entry.schemas.receipt import ReceiptCreate, TvaEntrySchema, PaymentMethodSchema
from backend.modules.data_entry.schemas.ocr import ExtractionData
from backend.modules.data_entry.services.receipt_service import ReceiptService
from backend.modules.data_entry.services import sse_service
from backend.config import settings
logger = logging.getLogger(__name__)
@dataclass
class ReceiptCreateResult:
"""Result of auto-create operation."""
success: bool
receipt_id: Optional[int] = None
error_message: Optional[str] = None
class ReceiptAutoCreateService:
"""
Service for automatically creating receipts from OCR results.
Used by the bulk upload flow to create receipts without user intervention.
Created receipts are in DRAFT status and require review before approval.
"""
@staticmethod
def _validate_ocr_result(ocr_result: ExtractionData) -> tuple[bool, str]:
"""
Perform minimal validation on OCR result.
Validates:
- amount > 0 (required for receipt)
- date is valid and not in future
Args:
ocr_result: Extracted data from OCR
Returns:
Tuple of (is_valid, error_message)
"""
# Validate amount exists and is positive
if ocr_result.amount is None:
return False, "Amount not extracted from receipt"
if ocr_result.amount <= 0:
return False, f"Invalid amount: {ocr_result.amount} (must be > 0)"
# Validate date exists and is not in the future
if ocr_result.receipt_date is None:
return False, "Receipt date not extracted"
today = date.today()
if ocr_result.receipt_date > today:
return False, f"Receipt date {ocr_result.receipt_date} is in the future"
return True, ""
@staticmethod
def _map_ocr_to_receipt(
ocr_result: ExtractionData,
company_id: int,
) -> ReceiptCreate:
"""
Map OCR ExtractionData fields to ReceiptCreate schema.
Args:
ocr_result: Extracted data from OCR
company_id: Company ID for the receipt
Returns:
ReceiptCreate schema ready for database insertion
"""
# Map receipt type
receipt_type = ReceiptType.BON_FISCAL
if ocr_result.receipt_type == "chitanta":
receipt_type = ReceiptType.CHITANTA
# Map TVA breakdown from OCR TvaEntry to schema TvaEntrySchema
tva_breakdown: Optional[List[TvaEntrySchema]] = None
if ocr_result.tva_entries:
tva_breakdown = [
TvaEntrySchema(
code=entry.code,
percent=entry.percent,
amount=entry.amount
)
for entry in ocr_result.tva_entries
]
# Map payment methods
payment_methods: Optional[List[PaymentMethodSchema]] = None
if ocr_result.payment_methods:
payment_methods = [
PaymentMethodSchema(
method=pm.method,
amount=pm.amount
)
for pm in ocr_result.payment_methods
]
# Create receipt data
return ReceiptCreate(
receipt_type=receipt_type,
direction=ReceiptDirection.CHELTUIALA, # Default to expense
receipt_number=ocr_result.receipt_number,
receipt_series=ocr_result.receipt_series,
receipt_date=ocr_result.receipt_date,
amount=ocr_result.amount,
description=ocr_result.description,
tva_breakdown=tva_breakdown,
tva_total=ocr_result.tva_total,
items_count=ocr_result.items_count,
vendor_address=ocr_result.address,
company_id=company_id,
partner_name=ocr_result.partner_name,
cui=ocr_result.cui,
ocr_raw_text=ocr_result.raw_text[:5000] if ocr_result.raw_text else None, # Limit size
payment_methods=payment_methods,
payment_mode=ocr_result.suggested_payment_mode,
)
@staticmethod
async def _create_attachment_from_file(
session: AsyncSession,
receipt_id: int,
source_file_path: str,
original_filename: Optional[str] = None,
) -> Optional[ReceiptAttachment]:
"""
Create attachment by copying file from OCR job location.
Args:
session: Database session
receipt_id: Receipt ID to attach to
source_file_path: Path to the original file from OCR job
original_filename: Original filename from upload (optional)
Returns:
Created ReceiptAttachment or None if failed
"""
source_path = Path(source_file_path)
if not source_path.exists():
logger.warning(f"[ReceiptAutoCreate] Source file not found: {source_path}")
return None
# Generate stored filename
ext = source_path.suffix.lower()
stored_filename = f"{uuid.uuid4()}{ext}"
# Determine relative path (organized by year/month)
now = datetime.utcnow()
relative_path = Path(str(now.year)) / f"{now.month:02d}"
# Full destination path
dest_dir = settings.data_entry_upload_path_resolved / relative_path
dest_dir.mkdir(parents=True, exist_ok=True)
dest_path = dest_dir / stored_filename
# Copy file to attachments directory
try:
shutil.copy2(source_path, dest_path)
except Exception as e:
logger.error(f"[ReceiptAutoCreate] Failed to copy file: {e}")
return None
# Get file size
file_size = dest_path.stat().st_size
# Determine MIME type
mime_map = {
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.png': 'image/png',
'.pdf': 'application/pdf',
}
mime_type = mime_map.get(ext, 'application/octet-stream')
# Use original filename if provided, otherwise use source filename
display_filename = original_filename or source_path.name
# Create attachment record
attachment = ReceiptAttachment(
receipt_id=receipt_id,
filename=display_filename,
stored_filename=stored_filename,
file_path=str(relative_path / stored_filename),
file_size=file_size,
mime_type=mime_type,
)
session.add(attachment)
await session.flush()
return attachment
@staticmethod
async def _update_batch_job_receipt_id(
session: AsyncSession,
job_id: str,
receipt_id: int,
) -> None:
"""
Update batch_jobs table with the created receipt_id.
Args:
session: Database session
job_id: OCR job UUID
receipt_id: Created receipt ID
"""
await session.execute(
update(BatchJob)
.where(BatchJob.job_id == job_id)
.values(receipt_id=receipt_id)
)
@staticmethod
async def create_from_ocr_result(
session: AsyncSession,
job_id: str,
ocr_result: ExtractionData,
username: str,
batch_id: int,
company_id: int,
file_path: Optional[str] = None,
original_filename: Optional[str] = None,
file_hash: Optional[str] = None,
) -> ReceiptCreateResult:
"""
Create a receipt from OCR extraction result.
This method:
1. Validates the OCR result (amount > 0, date valid)
2. Maps OCR fields to Receipt fields
3. Creates the Receipt in DRAFT status
4. Creates attachment from original file
5. Generates accounting entries
6. Updates batch_jobs with receipt_id
Args:
session: Database session
job_id: OCR job UUID for tracking
ocr_result: Extracted data from OCR processing
username: User who initiated the upload
batch_id: Batch ID for grouping
company_id: Company ID for the receipt
file_path: Path to the original uploaded file
original_filename: Original filename from upload
file_hash: SHA-256 hash of the file for duplicate detection (US-007)
Returns:
ReceiptCreateResult with success status and receipt_id or error
"""
try:
# Step 1: Validate OCR result
is_valid, error_msg = ReceiptAutoCreateService._validate_ocr_result(ocr_result)
if not is_valid:
logger.warning(f"[ReceiptAutoCreate] Validation failed for job {job_id}: {error_msg}")
return ReceiptCreateResult(
success=False,
error_message=error_msg
)
# Step 2: Map OCR to Receipt schema
receipt_data = ReceiptAutoCreateService._map_ocr_to_receipt(
ocr_result=ocr_result,
company_id=company_id,
)
# Step 3: Create receipt in DRAFT status
receipt = await ReceiptCRUD.create(session, receipt_data, created_by=username)
# Set batch tracking fields (US-007, US-011)
receipt.batch_id = str(batch_id)
receipt.file_hash = file_hash
receipt.processing_status = "completed"
session.add(receipt)
await session.flush()
logger.info(
f"[ReceiptAutoCreate] Created receipt {receipt.id} for job {job_id}: "
f"amount={receipt.amount}, vendor={receipt.partner_name}, file_hash={file_hash[:16] if file_hash else None}..."
)
# Step 4: Create attachment from original file (if path provided)
if file_path:
attachment = await ReceiptAutoCreateService._create_attachment_from_file(
session=session,
receipt_id=receipt.id,
source_file_path=file_path,
original_filename=original_filename,
)
if attachment:
logger.info(f"[ReceiptAutoCreate] Created attachment for receipt {receipt.id}")
else:
logger.warning(f"[ReceiptAutoCreate] Failed to create attachment for receipt {receipt.id}")
# Step 5: Generate accounting entries
# Note: For DRAFT status, entries are generated but not required for validation
try:
entries = ReceiptService.generate_accounting_entries(receipt)
if entries:
await AccountingEntryCRUD.create_bulk(
session, receipt.id, entries, is_auto_generated=True
)
logger.info(
f"[ReceiptAutoCreate] Generated {len(entries)} accounting entries "
f"for receipt {receipt.id}"
)
except Exception as e:
# Don't fail the receipt creation if entry generation fails
logger.warning(
f"[ReceiptAutoCreate] Failed to generate entries for receipt {receipt.id}: {e}"
)
# Step 6: Update batch_jobs with receipt_id
await ReceiptAutoCreateService._update_batch_job_receipt_id(
session=session,
job_id=job_id,
receipt_id=receipt.id,
)
# Commit all changes
await session.commit()
# Broadcast SSE event for real-time updates (US-030)
try:
await sse_service.broadcast_status_change(
receipt_id=receipt.id,
status=receipt.status.value,
processing_status=receipt.processing_status,
batch_id=receipt.batch_id,
)
except Exception as e:
# Don't fail the receipt creation if SSE broadcast fails
logger.warning(f"[ReceiptAutoCreate] SSE broadcast failed for receipt {receipt.id}: {e}")
return ReceiptCreateResult(
success=True,
receipt_id=receipt.id
)
except Exception as e:
logger.error(f"[ReceiptAutoCreate] Failed to create receipt for job {job_id}: {e}")
await session.rollback()
return ReceiptCreateResult(
success=False,
error_message=str(e)
)

View File

@@ -0,0 +1,457 @@
"""Business logic service for receipts workflow."""
from decimal import Decimal, ROUND_HALF_UP
from typing import List, Optional, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from backend.modules.data_entry.db.models.receipt import Receipt, ReceiptStatus, ReceiptDirection
from backend.modules.data_entry.db.models.accounting_entry import EntryType
from backend.modules.data_entry.db.crud.receipt import ReceiptCRUD
from backend.modules.data_entry.db.crud.accounting_entry import AccountingEntryCRUD
from backend.modules.data_entry.schemas.receipt import (
ReceiptCreate,
ReceiptUpdate,
ReceiptFilter,
ReceiptResponse,
ReceiptListResponse,
ProcessingStats,
AccountingEntryCreate,
)
from backend.modules.data_entry.services.expense_types import EXPENSE_TYPES, get_expense_type
# Payment mode to accounting account mapping
PAYMENT_MODE_ACCOUNTS = {
'casa': ('5311', 'Casa in lei'),
'banca': ('5121', 'Conturi la banci in lei'),
'avans_decontare': ('542', 'Avansuri de trezorerie'),
}
class ReceiptService:
"""Service for receipt business logic and workflow."""
@staticmethod
async def create_receipt(
session: AsyncSession,
data: ReceiptCreate,
created_by: str,
) -> Receipt:
"""Create a new receipt in DRAFT status."""
return await ReceiptCRUD.create(session, data, created_by)
@staticmethod
async def get_receipt(
session: AsyncSession,
receipt_id: int,
) -> Optional[Receipt]:
"""Get receipt by ID with all relationships."""
return await ReceiptCRUD.get_by_id(session, receipt_id, include_relations=True)
@staticmethod
async def get_receipts(
session: AsyncSession,
filters: ReceiptFilter,
) -> ReceiptListResponse:
"""Get paginated list of receipts with processing_stats (US-012)."""
receipts, total = await ReceiptCRUD.get_list(session, filters)
pages = (total + filters.page_size - 1) // filters.page_size if total > 0 else 1
# Get processing stats for bulk uploaded receipts (US-012)
stats_dict = await ReceiptCRUD.get_processing_stats(
session,
company_id=filters.company_id,
batch_id=filters.batch_id,
)
processing_stats = ProcessingStats(**stats_dict)
return ReceiptListResponse(
items=[ReceiptResponse.model_validate(r) for r in receipts],
total=total,
page=filters.page,
page_size=filters.page_size,
pages=pages,
processing_stats=processing_stats,
)
@staticmethod
async def update_receipt(
session: AsyncSession,
receipt_id: int,
data: ReceiptUpdate,
username: str,
) -> Tuple[bool, str, Optional[Receipt]]:
"""
Update receipt (only DRAFT status).
Returns (success, message, receipt).
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found", None
if not await ReceiptCRUD.can_edit(receipt, username):
return False, "Cannot edit this receipt", None
updated = await ReceiptCRUD.update(session, receipt, data)
return True, "Receipt updated", updated
@staticmethod
async def delete_receipt(
session: AsyncSession,
receipt_id: int,
username: str,
) -> Tuple[bool, str]:
"""
Delete receipt (only DRAFT status).
Returns (success, message).
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found"
if not await ReceiptCRUD.can_delete(receipt, username):
return False, "Cannot delete this receipt"
await ReceiptCRUD.delete(session, receipt)
return True, "Receipt deleted"
@staticmethod
def generate_accounting_entries(receipt: Receipt) -> List[AccountingEntryCreate]:
"""
Generate accounting entries based on receipt data and expense type.
"""
entries: List[AccountingEntryCreate] = []
# Get expense type configuration
expense_type = get_expense_type(receipt.expense_type_code or "OTHER")
if not expense_type:
expense_type = EXPENSE_TYPES["OTHER"]
amount = Decimal(str(receipt.amount))
if receipt.direction == ReceiptDirection.CHELTUIALA:
# Expense: Debit expense account, Credit cash/bank
if expense_type.has_vat:
# Calculate net and VAT
vat_rate = expense_type.vat_percent / Decimal("100")
net_amount = (amount / (1 + vat_rate)).quantize(
Decimal("0.01"), rounding=ROUND_HALF_UP
)
vat_amount = amount - net_amount
# Debit: Expense account (net)
entries.append(AccountingEntryCreate(
entry_type=EntryType.DEBIT,
account_code=expense_type.account_code,
account_name=expense_type.account_name,
amount=net_amount,
))
# Debit: VAT deductible
entries.append(AccountingEntryCreate(
entry_type=EntryType.DEBIT,
account_code=expense_type.vat_account,
account_name="TVA deductibila",
amount=vat_amount,
))
else:
# No VAT - full amount to expense
entries.append(AccountingEntryCreate(
entry_type=EntryType.DEBIT,
account_code=expense_type.account_code,
account_name=expense_type.account_name,
amount=amount,
))
# Credit entry - based on payment_mode (new) or cash_register (legacy)
if receipt.payment_mode and receipt.payment_mode in PAYMENT_MODE_ACCOUNTS:
credit_account, credit_name = PAYMENT_MODE_ACCOUNTS[receipt.payment_mode]
elif receipt.cash_register_account:
# Backwards compatibility for existing receipts
credit_account = receipt.cash_register_account
credit_name = receipt.cash_register_name or "Casa/Banca"
else:
# Default fallback
credit_account = "5311"
credit_name = "Casa in lei"
entries.append(AccountingEntryCreate(
entry_type=EntryType.CREDIT,
account_code=credit_account,
account_name=credit_name,
amount=amount,
))
else:
# Income: Debit cash/bank, Credit income account
# Based on payment_mode (new) or cash_register (legacy)
if receipt.payment_mode and receipt.payment_mode in PAYMENT_MODE_ACCOUNTS:
cash_account, cash_name = PAYMENT_MODE_ACCOUNTS[receipt.payment_mode]
elif receipt.cash_register_account:
cash_account = receipt.cash_register_account
cash_name = receipt.cash_register_name or "Casa/Banca"
else:
cash_account = "5311"
cash_name = "Casa in lei"
# Debit: Cash/Bank
entries.append(AccountingEntryCreate(
entry_type=EntryType.DEBIT,
account_code=cash_account,
account_name=cash_name,
amount=amount,
))
# Credit: Income account (7xx - to be configured)
entries.append(AccountingEntryCreate(
entry_type=EntryType.CREDIT,
account_code="7588",
account_name="Alte venituri din exploatare",
amount=amount,
))
return entries
@staticmethod
async def submit_for_review(
session: AsyncSession,
receipt_id: int,
username: str,
) -> Tuple[bool, str, Optional[Receipt]]:
"""
Submit receipt for review (DRAFT/REJECTED → PENDING_REVIEW).
Generates accounting entries automatically.
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found", None
if not await ReceiptCRUD.can_submit(receipt, username):
return False, "Cannot submit this receipt", None
# Check if receipt has at least one attachment
if not receipt.attachments:
return False, "Receipt must have at least one attachment", None
# Check required fields
if not receipt.expense_type_code:
return False, "Expense type is required", None
# Validate payment_mode or cash_register (backwards compatibility)
if not receipt.payment_mode and not receipt.cash_register_account:
return False, "Modul de plata este obligatoriu", None
# Generate accounting entries
entries = ReceiptService.generate_accounting_entries(receipt)
# Delete existing entries and create new ones
await AccountingEntryCRUD.delete_all_for_receipt(session, receipt_id)
await AccountingEntryCRUD.create_bulk(session, receipt_id, entries, is_auto_generated=True)
# Refresh receipt to clear stale relationship references after entry deletion
await session.refresh(receipt)
# Update status
updated = await ReceiptCRUD.update_status(
session, receipt, ReceiptStatus.PENDING_REVIEW
)
# Reload with entries
updated = await ReceiptCRUD.get_by_id(session, receipt_id)
return True, "Receipt submitted for review", updated
@staticmethod
async def approve_receipt(
session: AsyncSession,
receipt_id: int,
username: str,
) -> Tuple[bool, str, Optional[Receipt]]:
"""
Approve receipt (PENDING_REVIEW → APPROVED).
Requires valid CUI (fiscal code) for approval.
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found", None
if receipt.status != ReceiptStatus.PENDING_REVIEW:
return False, "Receipt is not pending review", None
# Validate CUI is present (required for Oracle import)
if not receipt.cui:
return False, "Trebuie completat codul fiscal (CUI) pentru aprobare", None
# Validate accounting entries
if not receipt.entries:
return False, "Receipt has no accounting entries", None
# Update status
updated = await ReceiptCRUD.update_status(
session, receipt, ReceiptStatus.APPROVED, reviewed_by=username
)
return True, "Receipt approved", updated
@staticmethod
async def unapprove_receipt(
session: AsyncSession,
receipt_id: int,
username: str,
) -> Tuple[bool, str, Optional[Receipt]]:
"""
Unapprove receipt (APPROVED → PENDING_REVIEW).
Returns receipt to pending review for corrections.
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found", None
if receipt.status != ReceiptStatus.APPROVED:
return False, "Receipt is not approved", None
# Update status back to pending review
updated = await ReceiptCRUD.update_status(
session, receipt, ReceiptStatus.PENDING_REVIEW
)
return True, "Receipt returned to pending review", updated
@staticmethod
async def reject_receipt(
session: AsyncSession,
receipt_id: int,
username: str,
reason: str,
) -> Tuple[bool, str, Optional[Receipt]]:
"""
Reject receipt (PENDING_REVIEW → REJECTED).
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found", None
if receipt.status != ReceiptStatus.PENDING_REVIEW:
return False, "Receipt is not pending review", None
# Update status
updated = await ReceiptCRUD.update_status(
session,
receipt,
ReceiptStatus.REJECTED,
reviewed_by=username,
rejection_reason=reason,
)
return True, "Receipt rejected", updated
@staticmethod
async def resubmit_receipt(
session: AsyncSession,
receipt_id: int,
username: str,
) -> Tuple[bool, str, Optional[Receipt]]:
"""
Resubmit rejected receipt after corrections (REJECTED → PENDING_REVIEW).
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found", None
if receipt.status != ReceiptStatus.REJECTED:
return False, "Receipt is not rejected", None
if receipt.created_by != username:
return False, "Only the creator can resubmit", None
# Re-generate accounting entries
entries = ReceiptService.generate_accounting_entries(receipt)
await AccountingEntryCRUD.delete_all_for_receipt(session, receipt_id)
await AccountingEntryCRUD.create_bulk(session, receipt_id, entries, is_auto_generated=True)
# Refresh receipt to clear stale relationship references after entry deletion
await session.refresh(receipt)
# Update status
updated = await ReceiptCRUD.update_status(
session, receipt, ReceiptStatus.PENDING_REVIEW
)
# Reload with entries
updated = await ReceiptCRUD.get_by_id(session, receipt_id)
return True, "Receipt resubmitted for review", updated
@staticmethod
async def regenerate_entries(
session: AsyncSession,
receipt_id: int,
username: str,
) -> Tuple[bool, str, List[AccountingEntryCreate]]:
"""
Regenerate accounting entries for a receipt.
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found", []
if receipt.status not in [ReceiptStatus.DRAFT, ReceiptStatus.PENDING_REVIEW]:
return False, "Cannot regenerate entries for this receipt status", []
# Generate new entries
entries = ReceiptService.generate_accounting_entries(receipt)
# Replace existing entries
await AccountingEntryCRUD.delete_all_for_receipt(session, receipt_id)
await AccountingEntryCRUD.create_bulk(session, receipt_id, entries, is_auto_generated=True)
return True, "Entries regenerated", entries
@staticmethod
async def update_entries(
session: AsyncSession,
receipt_id: int,
entries: List[AccountingEntryCreate],
username: str,
) -> Tuple[bool, str, List]:
"""
Update accounting entries for a receipt (accountant action).
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found", []
if receipt.status != ReceiptStatus.PENDING_REVIEW:
return False, "Can only modify entries for receipts pending review", []
# Validate entries
is_valid, error = await AccountingEntryCRUD.validate_entries(entries)
if not is_valid:
return False, error, []
# Replace entries
updated_entries = await AccountingEntryCRUD.replace_all_for_receipt(
session, receipt_id, entries, username
)
return True, "Entries updated", updated_entries
@staticmethod
async def get_pending_count(
session: AsyncSession,
company_id: Optional[int] = None,
) -> int:
"""Get count of receipts pending review."""
receipts = await ReceiptCRUD.get_pending_review(session, company_id)
return len(receipts)

View File

@@ -0,0 +1,197 @@
"""
Server-Sent Events (SSE) service for real-time status updates.
This module implements an event broadcaster pattern using asyncio.Queue per client.
When receipt status changes occur (CRUD operations), events are pushed to all
connected clients who are listening for that specific batch or all receipts.
Usage:
# In router endpoint (SSE stream):
async for event in sse_service.subscribe(batch_id=None):
yield event
# When status changes (from CRUD operations):
await sse_service.broadcast_status_change(receipt_id, status, processing_status, batch_id)
"""
import asyncio
import json
import logging
from dataclasses import dataclass, asdict
from typing import AsyncGenerator, Optional
from datetime import datetime
logger = logging.getLogger(__name__)
@dataclass
class StatusChangeEvent:
"""Event data for receipt status changes."""
receipt_id: int
status: str
processing_status: Optional[str] = None
batch_id: Optional[str] = None
timestamp: Optional[str] = None
def __post_init__(self):
if self.timestamp is None:
self.timestamp = datetime.utcnow().isoformat()
def to_sse_data(self) -> str:
"""Format as SSE data line."""
data = asdict(self)
return f"data: {json.dumps(data)}\n\n"
class SSEEventBroadcaster:
"""
Manages SSE client connections and broadcasts events.
Each client gets its own asyncio.Queue. When an event occurs,
it's pushed to all relevant queues based on batch_id filtering.
"""
def __init__(self):
# Dict of {client_id: (queue, batch_id_filter)}
# batch_id_filter is None for clients that want all events
self._clients: dict[str, tuple[asyncio.Queue, Optional[str]]] = {}
self._client_counter = 0
self._lock = asyncio.Lock()
async def _generate_client_id(self) -> str:
"""Generate unique client ID."""
async with self._lock:
self._client_counter += 1
return f"client_{self._client_counter}_{datetime.utcnow().timestamp()}"
async def subscribe(
self,
batch_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
"""
Subscribe to SSE events.
Args:
batch_id: Optional filter - only receive events for this batch.
If None, receives all events.
Yields:
SSE-formatted event strings (ready to send to client).
"""
client_id = await self._generate_client_id()
queue: asyncio.Queue = asyncio.Queue()
# Register client
async with self._lock:
self._clients[client_id] = (queue, batch_id)
logger.info(
f"SSE client {client_id} connected (batch_id filter: {batch_id}). "
f"Total clients: {len(self._clients)}"
)
try:
# Send initial retry hint for reconnection
yield "retry: 3000\n\n"
# Keep connection alive and yield events
while True:
try:
# Wait for events with timeout for keep-alive
event = await asyncio.wait_for(queue.get(), timeout=30.0)
yield event
except asyncio.TimeoutError:
# Send keep-alive comment to prevent connection timeout
yield ": keep-alive\n\n"
except asyncio.CancelledError:
logger.info(f"SSE client {client_id} subscription cancelled")
raise
finally:
# Cleanup: remove client from registry
async with self._lock:
self._clients.pop(client_id, None)
logger.info(
f"SSE client {client_id} disconnected. "
f"Remaining clients: {len(self._clients)}"
)
async def broadcast_status_change(
self,
receipt_id: int,
status: str,
processing_status: Optional[str] = None,
batch_id: Optional[str] = None,
) -> int:
"""
Broadcast a status change event to all relevant clients.
Args:
receipt_id: The receipt ID that changed.
status: New workflow status (DRAFT, PENDING_REVIEW, etc.).
processing_status: New processing status (pending, processing, completed, failed).
batch_id: The batch ID this receipt belongs to (for filtering).
Returns:
Number of clients notified.
"""
event = StatusChangeEvent(
receipt_id=receipt_id,
status=status,
processing_status=processing_status,
batch_id=batch_id,
)
sse_data = event.to_sse_data()
notified = 0
async with self._lock:
for client_id, (queue, client_batch_filter) in self._clients.items():
# Send event if:
# 1. Client has no filter (wants all events), OR
# 2. Client's filter matches the event's batch_id
if client_batch_filter is None or client_batch_filter == batch_id:
try:
queue.put_nowait(sse_data)
notified += 1
except asyncio.QueueFull:
logger.warning(
f"SSE queue full for client {client_id}, dropping event"
)
if notified > 0:
logger.debug(
f"SSE broadcast: receipt_id={receipt_id}, status={status}, "
f"processing_status={processing_status}, notified={notified} clients"
)
return notified
@property
def client_count(self) -> int:
"""Get current number of connected clients."""
return len(self._clients)
# Singleton instance for the application
sse_broadcaster = SSEEventBroadcaster()
# Convenience functions for external use
async def subscribe(batch_id: Optional[str] = None) -> AsyncGenerator[str, None]:
"""Subscribe to SSE status change events."""
async for event in sse_broadcaster.subscribe(batch_id):
yield event
async def broadcast_status_change(
receipt_id: int,
status: str,
processing_status: Optional[str] = None,
batch_id: Optional[str] = None,
) -> int:
"""Broadcast a status change event."""
return await sse_broadcaster.broadcast_status_change(
receipt_id=receipt_id,
status=status,
processing_status=processing_status,
batch_id=batch_id,
)

View File

@@ -0,0 +1,451 @@
"""Service for syncing nomenclatures from Oracle to SQLite."""
import sys
from pathlib import Path
from typing import Optional, List, Tuple
from datetime import datetime
import logging
from sqlmodel import select
from sqlalchemy.ext.asyncio import AsyncSession
# Path setup handled by main.py - this is redundant
# project_root = Path(__file__).parent.parent.parent.parent.parent
# sys.path.insert(0, str(project_root / "shared"))
from shared.database.oracle_pool import oracle_pool
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier, SyncedCashRegister
logger = logging.getLogger(__name__)
# Cache for schema lookups (populated dynamically from Oracle)
# Key format: (server_id, company_id) for multi-server support
_schema_cache: dict[tuple, str] = {}
class SyncService:
"""Service for syncing nomenclatures from Oracle."""
@staticmethod
async def get_schema_for_company(company_id: int, server_id: Optional[str] = None) -> Optional[str]:
"""
Get Oracle schema for company ID from V_NOM_FIRME view.
Results are cached in memory for performance.
Args:
company_id: The company ID to look up
server_id: Optional Oracle server ID for multi-server mode
"""
# Check cache first - use (server_id, company_id) as key for multi-server support
cache_key = (server_id, company_id)
if cache_key in _schema_cache:
return _schema_cache[cache_key]
try:
async with oracle_pool.get_connection(server_id) as connection:
with connection.cursor() as cursor:
cursor.execute("""
SELECT SCHEMA
FROM CONTAFIN_ORACLE.V_NOM_FIRME
WHERE ID_FIRMA = :company_id
""", {'company_id': company_id})
result = cursor.fetchone()
if result:
schema = result[0]
_schema_cache[cache_key] = schema
logger.info(f"Resolved schema for company {company_id} on server {server_id}: {schema}")
return schema
else:
logger.warning(f"No schema found for company {company_id} on server {server_id}")
return None
except Exception as e:
logger.error(f"Error fetching schema for company {company_id} on server {server_id}: {e}")
return None
@staticmethod
async def sync_suppliers(session: AsyncSession, company_id: int, server_id: Optional[str] = None) -> Tuple[int, int]:
"""
Sync suppliers (furnizori, id_tip_part=17) from Oracle to SQLite.
Uses CORESP_TIP_PART joined with VNOM_PARTENERI view.
Returns (synced_count, error_count).
Args:
session: SQLAlchemy async session for SQLite
company_id: The company ID to sync suppliers for
server_id: Optional Oracle server ID for multi-server mode
"""
schema = await SyncService.get_schema_for_company(company_id, server_id)
if not schema:
logger.warning(f"No schema mapping for company {company_id} on server {server_id}")
return 0, 0
synced = 0
errors = 0
try:
async with oracle_pool.get_connection(server_id) as connection:
with connection.cursor() as cursor:
# Fetch active suppliers from Oracle
# id_tip_part = 17 means "furnizori" (suppliers)
# Using CORESP_TIP_PART to filter by partner type
cursor.execute(f"""
SELECT B.ID_PART, B.DENUMIRE, B.COD_FISCAL, B.ADRESA
FROM {schema}.CORESP_TIP_PART A
INNER JOIN {schema}.VNOM_PARTENERI B ON A.ID_PART = B.ID_PART
WHERE A.ID_TIP_PART = 17
AND (B.INACTIV = 0 OR B.INACTIV IS NULL)
AND B.ID_PART IS NOT NULL
ORDER BY B.DENUMIRE
""")
rows = cursor.fetchall()
for row in rows:
try:
oracle_id, name, fiscal_code, address = row
# Check if already exists
stmt = select(SyncedSupplier).where(
SyncedSupplier.oracle_id == oracle_id,
SyncedSupplier.company_id == company_id
)
result = await session.execute(stmt)
existing = result.scalar_one_or_none()
if existing:
# Update existing record
existing.name = name or ""
existing.fiscal_code = fiscal_code
existing.address = address
existing.synced_at = datetime.utcnow()
logger.debug(f"Updated supplier {oracle_id}: {name}")
else:
# Create new record
supplier = SyncedSupplier(
oracle_id=oracle_id,
company_id=company_id,
name=name or "",
fiscal_code=fiscal_code,
address=address,
)
session.add(supplier)
logger.debug(f"Created supplier {oracle_id}: {name}")
synced += 1
except Exception as e:
logger.error(f"Error processing supplier row {row}: {e}")
errors += 1
# Commit all changes
await session.commit()
logger.info(f"Synced {synced} suppliers for company {company_id}, {errors} errors")
except Exception as e:
logger.error(f"Error syncing suppliers for company {company_id}: {e}")
errors += 1
await session.rollback()
return synced, errors
@staticmethod
async def sync_cash_registers(session: AsyncSession, company_id: int, server_id: Optional[str] = None) -> Tuple[int, int]:
"""
Sync cash registers and bank accounts from Oracle to SQLite.
Returns (synced_count, error_count).
Uses CORESP_TIP_PART with:
- id_tip_part = 22: CASA LEI
- id_tip_part = 23: CASA VALUTA
- id_tip_part = 24: BANCA LEI
- id_tip_part = 25: BANCA VALUTA
Args:
session: SQLAlchemy async session for SQLite
company_id: The company ID to sync cash registers for
server_id: Optional Oracle server ID for multi-server mode
"""
schema = await SyncService.get_schema_for_company(company_id, server_id)
if not schema:
logger.warning(f"No schema mapping for company {company_id} on server {server_id}")
return 0, 0
synced = 0
errors = 0
# Partner types mapping
# 22=CASA LEI, 23=CASA VALUTA -> cash
# 24=BANCA LEI, 25=BANCA VALUTA -> bank
partner_types = [22, 23, 24, 25]
try:
async with oracle_pool.get_connection(server_id) as connection:
with connection.cursor() as cursor:
# Fetch cash/bank partners from CORESP_TIP_PART
cursor.execute(f"""
SELECT B.ID_PART, B.DENUMIRE, A.ID_TIP_PART
FROM {schema}.CORESP_TIP_PART A
INNER JOIN {schema}.VNOM_PARTENERI B ON A.ID_PART = B.ID_PART
WHERE A.ID_TIP_PART IN (22, 23, 24, 25)
AND (B.INACTIV = 0 OR B.INACTIV IS NULL)
AND B.ID_PART IS NOT NULL
ORDER BY A.ID_TIP_PART, B.DENUMIRE
""")
rows = cursor.fetchall()
# Type mapping: 22=CASA LEI, 23=CASA VALUTA -> cash; 24=BANCA LEI, 25=BANCA VALUTA -> bank
type_mapping = {
22: ("cash", "CASA_LEI"),
23: ("cash", "CASA_VALUTA"),
24: ("bank", "BANCA_LEI"),
25: ("bank", "BANCA_VALUTA"),
}
for row in rows:
try:
oracle_id, name, tip_part_id = row
# Determine type based on partner type
register_type, account_code = type_mapping.get(tip_part_id, ("cash", "UNKNOWN"))
# Check if already exists
stmt = select(SyncedCashRegister).where(
SyncedCashRegister.oracle_id == oracle_id,
SyncedCashRegister.company_id == company_id
)
result = await session.execute(stmt)
existing = result.scalar_one_or_none()
if existing:
# Update existing record
existing.name = name or ""
existing.account_code = account_code
existing.register_type = register_type
existing.synced_at = datetime.utcnow()
logger.debug(f"Updated cash register {oracle_id}: {name}")
else:
# Create new record
cash_register = SyncedCashRegister(
oracle_id=oracle_id,
company_id=company_id,
name=name or "",
account_code=account_code,
register_type=register_type,
)
session.add(cash_register)
logger.debug(f"Created cash register {oracle_id}: {name}")
synced += 1
except Exception as e:
logger.error(f"Error processing cash register row {row}: {e}")
errors += 1
# Commit all changes
await session.commit()
logger.info(f"Synced {synced} cash registers for company {company_id}, {errors} errors")
except Exception as e:
logger.error(f"Error syncing cash registers for company {company_id}: {e}")
errors += 1
await session.rollback()
return synced, errors
@staticmethod
def _get_fiscal_code_variants(fiscal_code: str) -> list:
"""
Generate all possible variants of a Romanian fiscal code (CUI).
Database may store: "22891860", "RO22891860", "RO 22891860"
OCR may extract: "RO22891860" or "22891860"
"""
import re
# Extract just the digits
digits = re.sub(r'[^0-9]', '', fiscal_code)
if not digits:
return [fiscal_code]
# Generate all variants
variants = [
digits, # Just digits: 22891860
f"RO{digits}", # With RO prefix: RO22891860
f"RO {digits}", # With RO prefix and space: RO 22891860
]
# Also add the original if different
if fiscal_code not in variants:
variants.append(fiscal_code)
return variants
@staticmethod
async def search_supplier(
session: AsyncSession,
company_id: int,
fiscal_code: Optional[str] = None,
name: Optional[str] = None
) -> Tuple[bool, Optional[dict], str]:
"""
Search for supplier in SQLite first, then Oracle if not found.
Returns (found, supplier_data, source).
Source can be: 'synced', 'local', 'not_found'
"""
# 1. Search in synced suppliers
if fiscal_code:
# Search all variants of the fiscal code (with/without RO, with/without space)
variants = SyncService._get_fiscal_code_variants(fiscal_code)
stmt = select(SyncedSupplier).where(
SyncedSupplier.company_id == company_id,
SyncedSupplier.fiscal_code.in_(variants)
)
elif name:
stmt = select(SyncedSupplier).where(
SyncedSupplier.company_id == company_id,
SyncedSupplier.name.ilike(f"%{name}%")
)
else:
return False, None, "no_query"
result = await session.execute(stmt)
supplier = result.scalar_one_or_none()
if supplier:
# Return only text data - no IDs needed for autocomplete
return True, {
"name": supplier.name,
"fiscal_code": supplier.fiscal_code,
"address": supplier.address,
}, "synced"
# 2. Search in local suppliers
if fiscal_code:
# Search all variants of the fiscal code (with/without RO, with/without space)
variants = SyncService._get_fiscal_code_variants(fiscal_code)
stmt = select(LocalSupplier).where(
LocalSupplier.company_id == company_id,
LocalSupplier.fiscal_code.in_(variants)
)
elif name:
stmt = select(LocalSupplier).where(
LocalSupplier.company_id == company_id,
LocalSupplier.name.ilike(f"%{name}%")
)
result = await session.execute(stmt)
local = result.scalar_one_or_none()
if local:
# Return only text data - no IDs needed for autocomplete
return True, {
"name": local.name,
"fiscal_code": local.fiscal_code,
"address": local.address,
}, "local"
# 3. Try live Oracle search (optional fallback for unsynced data)
# This is a fallback - ideally sync should be up to date
# TODO: Implement live Oracle search if needed
return False, None, "not_found"
@staticmethod
async def create_local_supplier(
session: AsyncSession,
company_id: int,
name: str,
fiscal_code: Optional[str],
address: Optional[str],
created_by: str
) -> LocalSupplier:
"""Create a local supplier entry from OCR data."""
supplier = LocalSupplier(
company_id=company_id,
name=name,
fiscal_code=fiscal_code,
address=address,
created_by=created_by,
)
session.add(supplier)
await session.commit()
await session.refresh(supplier)
logger.info(f"Created local supplier: {name} (CUI: {fiscal_code})")
return supplier
@staticmethod
async def get_all_suppliers(
session: AsyncSession,
company_id: int,
search: Optional[str] = None
) -> List[dict]:
"""
Get all suppliers (synced + local) for a company.
Used for dropdown/autocomplete in UI.
"""
suppliers = []
# Get synced suppliers
stmt = select(SyncedSupplier).where(SyncedSupplier.company_id == company_id)
if search:
stmt = stmt.where(
(SyncedSupplier.name.ilike(f"%{search}%")) |
(SyncedSupplier.fiscal_code.ilike(f"%{search}%"))
)
stmt = stmt.limit(50) # Limit results for performance
result = await session.execute(stmt)
synced = result.scalars().all()
for s in synced:
suppliers.append({
"id": s.id,
"oracle_id": s.oracle_id,
"name": s.name,
"fiscal_code": s.fiscal_code,
"source": "synced"
})
# Get local suppliers
stmt = select(LocalSupplier).where(LocalSupplier.company_id == company_id)
if search:
stmt = stmt.where(
(LocalSupplier.name.ilike(f"%{search}%")) |
(LocalSupplier.fiscal_code.ilike(f"%{search}%"))
)
stmt = stmt.limit(50)
result = await session.execute(stmt)
local = result.scalars().all()
for l in local:
suppliers.append({
"id": l.id,
"name": l.name,
"fiscal_code": l.fiscal_code,
"source": "local"
})
return suppliers
@staticmethod
async def get_all_cash_registers(
session: AsyncSession,
company_id: int
) -> List[dict]:
"""
Get all cash registers for a company.
Used for dropdown in UI.
"""
stmt = select(SyncedCashRegister).where(SyncedCashRegister.company_id == company_id)
result = await session.execute(stmt)
registers = result.scalars().all()
return [
{
"id": r.id,
"oracle_id": r.oracle_id,
"name": r.name,
"account_code": r.account_code,
"register_type": r.register_type
}
for r in registers
]