feat(ocr): Add docTR OCR engine with metrics infrastructure

Add docTR as primary OCR engine with 2-tier sequential processing,
OCR metrics tracking, and simplified engine selection.

Features:
- docTR OCR engine with light+medium preprocessing tiers
- doctr_plus mode with early exit optimization (~65% fast path)
- OCR metrics dashboard with per-engine statistics
- User OCR preference persistence
- Parallel worker pool for OCR processing
- Cross-validation for extraction quality

Engine options: tesseract, doctr, doctr_plus (recommended), paddleocr

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-02 05:37:16 +02:00
parent 74f7aefc26
commit 495790411f
75 changed files with 23349 additions and 1311 deletions

View File

@@ -104,6 +104,34 @@ MAX_UPLOAD_SIZE_MB=10
TEST_COMPANY_ID=110
TEST_COMPANY_SCHEMA=MARIUSM_AUTO
# ============================================================================
# OCR ENGINE CONFIGURATION
# ============================================================================
# Control which OCR engines are loaded at startup.
# Disabling engines saves memory but limits available OCR modes.
# Enable/disable PaddleOCR (set to 'false' to save ~800MB RAM)
# When disabled: 'paddleocr' engine unavailable
OCR_ENABLE_PADDLEOCR=true
# Enable/disable Tesseract (set to 'false' to save ~50MB RAM)
# When disabled: 'tesseract' engine unavailable
OCR_ENABLE_TESSERACT=true
# Default OCR engine when not specified in request
# Options: tesseract, doctr, doctr_plus, paddleocr
# Recommended: doctr_plus (2-tier sequential with early exit)
OCR_DEFAULT_ENGINE=doctr_plus
# OCR Worker Pool Configuration
# Number of parallel OCR workers (each loads ~1GB for docTR)
# Recommended: 2 for 8GB RAM, 3 for 16GB RAM
OCR_WORKERS=2
# Max tasks per worker before restart (0 = no restart, saves 40-60s warmup time)
# Set to 0 for testing, 10-20 for production (prevents memory leaks)
OCR_MAX_TASKS_PER_CHILD=0
# ============================================================================
# TELEGRAM MODULE - BOT CONFIGURATION (REQUIRED for Telegram features)
# ============================================================================
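A minimal sketch of how a backend might read the OCR settings above at startup. The names `OCRConfig`, `_flag` and `load_ocr_config` are hypothetical, and both the "docTR modes are always loaded" rule and the fallback to `doctr_plus` are assumptions for illustration, not code from this commit.

```python
from dataclasses import dataclass
from typing import List, Mapping


@dataclass
class OCRConfig:
    enabled_engines: List[str]
    default_engine: str
    workers: int
    max_tasks_per_child: int  # 0 means workers are never restarted


def _flag(env: Mapping[str, str], name: str, default: bool = True) -> bool:
    """Interpret 'true'/'false' style values; an unset variable keeps the default."""
    raw = env.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in ("1", "true", "yes", "on")


def load_ocr_config(env: Mapping[str, str]) -> OCRConfig:
    # Assumption: the docTR modes are always loaded; only the other two are optional.
    enabled = ["doctr", "doctr_plus"]
    if _flag(env, "OCR_ENABLE_TESSERACT"):
        enabled.append("tesseract")
    if _flag(env, "OCR_ENABLE_PADDLEOCR"):
        enabled.append("paddleocr")

    default = env.get("OCR_DEFAULT_ENGINE", "doctr_plus").strip()
    if default not in enabled:
        # Hypothetical fallback for an unknown or disabled default engine.
        default = "doctr_plus"

    return OCRConfig(
        enabled_engines=enabled,
        default_engine=default,
        workers=max(1, int(env.get("OCR_WORKERS", "2"))),
        max_tasks_per_child=max(0, int(env.get("OCR_MAX_TASKS_PER_CHILD", "0"))),
    )
```

In an application this would typically be called once as `load_ocr_config(os.environ)`; taking a mapping as the argument keeps the function easy to test.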

View File

@@ -112,6 +112,40 @@ DATA_ENTRY_SQLITE_DATABASE_PATH=data/receipts/receipts.db
DATA_ENTRY_UPLOAD_PATH=data/receipts/uploads
DATA_ENTRY_MAX_UPLOAD_SIZE_MB=10
# ============================================================================
# OCR ENGINE CONFIGURATION
# ============================================================================
# Control which OCR engines are loaded at startup.
# Disabling engines saves memory but limits available OCR modes.
# Enable/disable PaddleOCR (set to 'false' to save ~800MB RAM)
# When disabled: 'paddleocr' engine unavailable
OCR_ENABLE_PADDLEOCR=true
# Enable/disable Tesseract (set to 'false' to save ~50MB RAM)
# When disabled: 'tesseract' engine unavailable
OCR_ENABLE_TESSERACT=true
# Default OCR engine when not specified in request
# Options: tesseract, doctr, doctr_plus, paddleocr
# Recommended: doctr_plus (2-tier sequential with early exit, ~7.5s avg)
OCR_DEFAULT_ENGINE=doctr_plus
# Active OCR engines shown in frontend dropdown (comma-separated)
# Options: tesseract, doctr, doctr_plus, paddleocr
# doctr_plus: 73.3% perfect, 7.5s avg, 65% fast path (recommended)
# doctr: 63.3% perfect, simpler but faster
OCR_ACTIVE_ENGINES=tesseract,doctr,doctr_plus,paddleocr
# OCR Worker Pool Configuration
# Number of parallel OCR workers (each loads ~1GB for docTR)
# Recommended: 2 for 8GB RAM, 3 for 16GB RAM
OCR_WORKERS=2
# Max tasks per worker before restart (0 = no restart, saves 40-60s warmup time)
# Set to 0 for testing, 10-20 for production (prevents memory leaks)
OCR_MAX_TASKS_PER_CHILD=0
# ============================================================================
# TELEGRAM MODULE - BOT CONFIGURATION (REQUIRED for Telegram features)
# ============================================================================
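`OCR_ACTIVE_ENGINES` is a comma-separated value, and a plain `split(",")` keeps empty entries (an empty string yields `[""]`). The sketch below shows one way such a value could be normalised before it reaches the frontend dropdown. The helper name `parse_active_engines` and the decision to drop unknown names silently are assumptions, not part of this commit.

```python
from typing import Iterable, List

KNOWN_ENGINES = ("tesseract", "doctr", "doctr_plus", "paddleocr")


def parse_active_engines(raw: str, known: Iterable[str] = KNOWN_ENGINES) -> List[str]:
    """Split a comma-separated value, dropping blanks, unknown names and duplicates."""
    allowed = set(known)
    result: List[str] = []
    for item in raw.split(","):
        name = item.strip().lower()
        if name and name in allowed and name not in result:
            result.append(name)
    return result
```

Order is preserved, so the first engine listed in the variable stays first in the resulting list.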

View File

@@ -96,6 +96,35 @@ SQLITE_DATABASE_PATH=data/receipts/receipts_prod.db
UPLOAD_PATH=data/receipts/uploads
MAX_UPLOAD_SIZE_MB=10
# ============================================================================
# OCR ENGINE CONFIGURATION
# ============================================================================
# Control which OCR engines are loaded at startup.
# Disabling engines saves memory but limits available OCR modes.
# Enable/disable PaddleOCR (set to 'false' to save ~800MB RAM)
# When disabled: 'paddleocr' engine unavailable
# PRODUCTION: Set based on server memory availability
OCR_ENABLE_PADDLEOCR=false
# Enable/disable Tesseract (set to 'false' to save ~50MB RAM)
# When disabled: 'tesseract' engine unavailable
OCR_ENABLE_TESSERACT=true
# Default OCR engine when not specified in request
# Options: tesseract, doctr, doctr_plus, paddleocr
# Recommended: doctr_plus (2-tier sequential with early exit)
OCR_DEFAULT_ENGINE=doctr_plus
# OCR Worker Pool Configuration
# Number of parallel OCR workers (each loads ~1GB for docTR)
# Recommended: 2 for 8GB RAM, 3 for 16GB RAM
OCR_WORKERS=2
# Max tasks per worker before restart (0 = no restart, saves 40-60s warmup time)
# Set to 0 for testing, 10-20 for production (prevents memory leaks)
OCR_MAX_TASKS_PER_CHILD=0
# ============================================================================
# TELEGRAM MODULE - BOT CONFIGURATION (REQUIRED for Telegram features)
# ============================================================================

View File

@@ -105,6 +105,34 @@ MAX_UPLOAD_SIZE_MB=10
TEST_COMPANY_ID=110
TEST_COMPANY_SCHEMA=MARIUSM_AUTO
# ============================================================================
# OCR ENGINE CONFIGURATION
# ============================================================================
# Control which OCR engines are loaded at startup.
# Disabling engines saves memory but limits available OCR modes.
# Enable/disable PaddleOCR (set to 'false' to save ~800MB RAM)
# When disabled: 'paddleocr' engine unavailable
OCR_ENABLE_PADDLEOCR=false
# Enable/disable Tesseract (set to 'false' to save ~50MB RAM)
# When disabled: 'tesseract' engine unavailable
OCR_ENABLE_TESSERACT=true
# Default OCR engine when not specified in request
# Options: tesseract, doctr, doctr_plus, paddleocr
# Recommended: doctr_plus (2-tier sequential with early exit)
OCR_DEFAULT_ENGINE=doctr_plus
# OCR Worker Pool Configuration
# Number of parallel OCR workers (each loads ~1GB for docTR)
# Recommended: 2 for 8GB RAM, 3 for 16GB RAM
OCR_WORKERS=2
# Max tasks per worker before restart (0 = no restart, saves 40-60s warmup time)
# Set to 0 for testing, 10-20 for production (prevents memory leaks)
OCR_MAX_TASKS_PER_CHILD=0
# ============================================================================
# TELEGRAM MODULE - BOT CONFIGURATION (REQUIRED for Telegram features)
# ============================================================================

View File

@@ -0,0 +1,168 @@
@echo off
setlocal enabledelayedexpansion
cd /d "%~dp0"
REM Parse command line arguments for worker counts
REM Usage: TEST-OCR-WINDOWS.bat [worker_counts...]
REM Examples:
REM TEST-OCR-WINDOWS.bat -> tests 1,2,3 workers (default)
REM TEST-OCR-WINDOWS.bat 1 -> tests only 1 worker
REM TEST-OCR-WINDOWS.bat 3 6 -> tests 3 and 6 workers
REM TEST-OCR-WINDOWS.bat 1 2 3 4 5 6 -> tests all
set "WORKER_LIST=%*"
if "%WORKER_LIST%"=="" set "WORKER_LIST=1 2 3"
echo.
echo ==========================================
echo OCR Benchmark - Windows (Workers: %WORKER_LIST%)
echo ==========================================
echo.
REM Check if Poppler is installed
where pdftoppm >nul 2>&1
if errorlevel 1 (
echo Checking for Poppler...
if exist "E:\poppler" (
for /r "E:\poppler" %%i in (pdftoppm.exe) do (
set "POPPLER_BIN=%%~dpi"
goto :found_poppler
)
)
echo.
echo ERROR: Poppler not found!
pause
exit /b 1
)
:found_poppler
if defined POPPLER_BIN (
echo Found Poppler at: %POPPLER_BIN%
set "PATH=%POPPLER_BIN%;%PATH%"
)
REM Check venv
if not exist "venv-win\Scripts\python.exe" (
echo ERROR: venv-win not found!
echo Run: python -m venv venv-win
echo Then: venv-win\Scripts\pip install -r requirements.txt
pause
exit /b 1
)
REM Set common environment
set JWT_SECRET_KEY=generate_with_secrets_token_urlsafe_32
set ORACLE_HOST=10.0.20.121
set ORACLE_PORT=1521
set ORACLE_USER=CONTAFIN_ORACLE
set ORACLE_PASSWORD=ROMFASTSOFT
set ORACLE_SERVICE_NAME=ROA
set OCR_ENABLE_PADDLEOCR=false
set OCR_ENABLE_TESSERACT=false
set OCR_DEFAULT_ENGINE=doctr_plus
set OCR_MAX_TASKS_PER_CHILD=0
set LOG_LEVEL=WARNING
REM Results file with timestamp
for /f "tokens=2 delims==" %%I in ('wmic os get localdatetime /value') do set datetime=%%I
set RESULTS_FILE=ocr_benchmark_%datetime:~0,8%_%datetime:~8,4%.json
echo Results will be saved to: %RESULTS_FILE%
echo.
REM Delete old results file if exists
if exist "%RESULTS_FILE%" del "%RESULTS_FILE%"
REM Run tests with specified workers
for %%W in (%WORKER_LIST%) do (
call :run_test %%W
)
goto :show_summary
:run_test
set WORKERS=%1
echo.
echo ############################################################
echo STARTING TEST WITH %WORKERS% WORKER(S)
echo ############################################################
echo.
REM Kill existing processes on port 8006
echo Cleaning up old processes...
for /f "tokens=5" %%a in ('netstat -ano ^| findstr :8006 ^| findstr LISTENING 2^>nul') do (
taskkill /F /PID %%a >nul 2>&1
)
taskkill /F /FI "WINDOWTITLE eq ROA2WEB Backend*" >nul 2>&1
timeout /t 3 >nul
REM Set workers count
set OCR_WORKERS=%WORKERS%
echo Starting backend with %WORKERS% OCR worker(s)...
REM Start backend in a new minimized window with all OCR env vars
start /min "ROA2WEB Backend %WORKERS% workers" cmd /c "set OCR_WORKERS=%WORKERS%&& set OCR_ENABLE_PADDLEOCR=false&& set OCR_ENABLE_TESSERACT=false&& set OCR_DEFAULT_ENGINE=doctr_plus&& set LOG_LEVEL=WARNING&& venv-win\Scripts\python.exe -m uvicorn main:app --host 0.0.0.0 --port 8006 --workers 1 2>&1"
REM Wait for backend to be ready
echo Waiting for backend to start...
set attempts=0
:wait_loop
timeout /t 3 >nul
set /a attempts+=1
curl -s http://localhost:8006/health >nul 2>&1
if errorlevel 1 (
if !attempts! lss 40 (
echo Waiting... !attempts!/40
goto :wait_loop
)
echo ERROR: Backend failed to start!
goto :eof
)
echo Backend is ready!
REM Wait for OCR warmup
echo Waiting for OCR worker warmup (30s)...
timeout /t 30 >nul
echo.
echo Running OCR test with %WORKERS% worker(s)...
echo.
venv-win\Scripts\python.exe ..\tests\ocr-validation\test_receipts_parallel_windows.py --port 8006 --workers %WORKERS% --output %RESULTS_FILE%
REM Stop backend
echo.
echo Stopping backend...
taskkill /F /FI "WINDOWTITLE eq ROA2WEB Backend*" >nul 2>&1
for /f "tokens=5" %%a in ('netstat -ano ^| findstr :8006 ^| findstr LISTENING 2^>nul') do (
taskkill /F /PID %%a >nul 2>&1
)
REM Wait for memory to be released
echo Releasing memory (10s)...
timeout /t 10 >nul
goto :eof
:show_summary
echo.
echo ############################################################
echo ALL TESTS COMPLETE
echo ############################################################
echo.
echo Results saved to: %RESULTS_FILE%
echo.
REM Show summary from results file
if exist "%RESULTS_FILE%" (
echo BENCHMARK SUMMARY:
echo ------------------
venv-win\Scripts\python.exe -c "import json; data=json.load(open('%RESULTS_FILE%')); print(); [print(f\" {r['workers']} worker(s): {r['total_time']:.1f}s total, {r['avg_time']:.1f}s avg, {r.get('peak_memory_mb', 0):.0f}MB peak, {r['successful']}/{r['submitted']} success\") for r in data]"
echo.
)
echo Press any key to exit...
pause >nul
endlocal
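The `:wait_loop` section of the script retries a health check up to 40 times with a 3 second pause. The same readiness wait can be sketched in Python as below. `http_ok` and `wait_until_ready` are hypothetical helper names; unlike `curl -s`, which succeeds on any HTTP response, `http_ok` here only accepts status 200, which is an assumption about what "ready" should mean.

```python
import time
import urllib.error
import urllib.request
from typing import Callable


def http_ok(url: str, timeout: float = 2.0) -> bool:
    """Return True if a GET to `url` answers with HTTP 200."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return resp.status == 200
    except (urllib.error.URLError, OSError):
        return False


def wait_until_ready(
    probe: Callable[[], bool],
    attempts: int = 40,
    interval_s: float = 3.0,
    sleep: Callable[[float], None] = time.sleep,
) -> bool:
    """Sleep, then probe, up to `attempts` times; True as soon as a probe succeeds."""
    for _ in range(attempts):
        sleep(interval_s)
        if probe():
            return True
    return False
```

A call such as `wait_until_ready(lambda: http_ok("http://localhost:8006/health"))` would mirror the batch loop; the injectable `probe` and `sleep` exist only to make the loop testable without a running backend.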

View File

@@ -38,12 +38,20 @@ from backend.modules.reports.routers import create_reports_router
from backend.modules.data_entry.routers import create_data_entry_router
from backend.modules.telegram.routers import create_telegram_router
# Configure logging (level from env: DEBUG, INFO, WARNING, ERROR)
log_level = os.getenv('LOG_LEVEL', 'INFO').upper()
logging.basicConfig(
    level=getattr(logging, log_level, logging.INFO),
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%H:%M:%S'
)
# Reduce noise from third-party libraries
logging.getLogger('httpcore').setLevel(logging.WARNING)
logging.getLogger('httpx').setLevel(logging.WARNING)
logging.getLogger('multipart').setLevel(logging.WARNING)
logging.getLogger('doctr').setLevel(logging.WARNING)
logging.getLogger('tensorflow').setLevel(logging.WARNING)
logging.getLogger('PIL').setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
# Global variables for background tasks
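The `getattr(logging, log_level, logging.INFO)` lookup falls back to `INFO` for unknown names, but it would also return any other upper-case attribute of the `logging` module, for example the string `logging.BASIC_FORMAT`. A slightly stricter variant is sketched below; `resolve_log_level` is a hypothetical helper, not something this commit defines.

```python
import logging


def resolve_log_level(name: str, default: int = logging.INFO) -> int:
    """Map a level name such as 'debug' to its numeric value, else `default`."""
    value = getattr(logging, name.strip().upper(), None)
    # Accept only integer attributes (DEBUG, INFO, WARNING, ...), not strings or callables.
    if isinstance(value, int) and not isinstance(value, bool):
        return value
    return default
```

It could be used as `logging.basicConfig(level=resolve_log_level(os.getenv("LOG_LEVEL", "INFO")))`.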

View File

@@ -48,6 +48,11 @@ class Settings(BaseSettings):
    # CORS
    cors_origins: str = "http://localhost:3010,http://localhost:3000"

    # OCR Engines (comma-separated list of active engines shown in UI)
    # Available: tesseract, paddleocr, doctr, doctr_plus
    # doctr_plus is recommended (2-tier sequential with early exit)
    ocr_active_engines: str = "doctr,doctr_plus"

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"
@@ -80,6 +85,11 @@ class Settings(BaseSettings):
"""Get CORS origins as list."""
        return [origin.strip() for origin in self.cors_origins.split(",")]

    @property
    def ocr_active_engines_list(self) -> List[str]:
        """Get OCR active engines as list."""
        return [engine.strip() for engine in self.ocr_active_engines.split(",")]

    @property
    def oracle_dsn(self) -> str:
"""Get Oracle DSN string."""

View File

@@ -2,9 +2,12 @@
from .receipt import ReceiptCRUD
from .attachment import AttachmentCRUD
from .accounting_entry import AccountingEntryCRUD
from .ocr_settings import OCRPreferenceCRUD, OCRMetricsCRUD
__all__ = [
    "ReceiptCRUD",
    "AttachmentCRUD",
    "AccountingEntryCRUD",
    "OCRPreferenceCRUD",
    "OCRMetricsCRUD",
]

View File

@@ -0,0 +1,222 @@
"""CRUD operations for OCR settings and metrics."""
from datetime import datetime, timedelta
from typing import List, Optional

import sqlalchemy as sa  # needed for sa.Integer in the cast() calls below
from sqlalchemy import func, select, and_
from sqlalchemy.ext.asyncio import AsyncSession

from backend.modules.data_entry.db.models.ocr_settings import (
    UserOCRPreference,
    OCRJobMetrics,
    OCRMetricsSummary,
    OCREngine,
)


class OCRPreferenceCRUD:
    """CRUD operations for user OCR preferences."""

    @staticmethod
    async def get_by_username(session: AsyncSession, username: str) -> Optional[UserOCRPreference]:
        """Get user's OCR preference by username."""
        result = await session.execute(
            select(UserOCRPreference).where(UserOCRPreference.username == username)
        )
        return result.scalar_one_or_none()

    @staticmethod
    async def create_or_update(
        session: AsyncSession,
        username: str,
        preferred_engine: OCREngine
    ) -> UserOCRPreference:
        """Create or update user's OCR preference."""
        existing = await OCRPreferenceCRUD.get_by_username(session, username)
        if existing:
            existing.preferred_engine = preferred_engine
            existing.updated_at = datetime.utcnow()
            await session.commit()
            await session.refresh(existing)
            return existing
        else:
            preference = UserOCRPreference(
                username=username,
                preferred_engine=preferred_engine
            )
            session.add(preference)
            await session.commit()
            await session.refresh(preference)
            return preference

    @staticmethod
    async def delete_by_username(session: AsyncSession, username: str) -> bool:
        """Delete user's OCR preference."""
        existing = await OCRPreferenceCRUD.get_by_username(session, username)
        if existing:
            await session.delete(existing)
            await session.commit()
            return True
        return False


class OCRMetricsCRUD:
    """CRUD operations for OCR job metrics."""

    @staticmethod
    async def create(
        session: AsyncSession,
        job_id: str,
        username: str,
        engine_requested: str,
        engine_used: str,
        processing_time_ms: int = 0,
        file_size_bytes: int = 0,
        file_type: str = "image/jpeg",
        original_filename: Optional[str] = None,
        success: bool = True,
        error_message: Optional[str] = None,
        overall_confidence: float = 0.0,
        fields_extracted: int = 0,
        needs_manual_review: Optional[bool] = None,
        validation_warnings_count: int = 0,
        validation_errors_count: int = 0,
        company_id: Optional[int] = None
    ) -> OCRJobMetrics:
        """Create a new OCR job metrics record."""
        metrics = OCRJobMetrics(
            job_id=job_id,
            username=username,
            company_id=company_id,
            engine_requested=engine_requested,
            engine_used=engine_used,
            processing_time_ms=processing_time_ms,
            file_size_bytes=file_size_bytes,
            file_type=file_type,
            original_filename=original_filename,
            success=success,
            error_message=error_message,
            overall_confidence=overall_confidence,
            fields_extracted=fields_extracted,
            needs_manual_review=needs_manual_review,
            validation_warnings_count=validation_warnings_count,
            validation_errors_count=validation_errors_count,
        )
        session.add(metrics)
        await session.commit()
        await session.refresh(metrics)
        return metrics

    @staticmethod
    async def get_by_job_id(session: AsyncSession, job_id: str) -> Optional[OCRJobMetrics]:
        """Get metrics by job ID."""
        result = await session.execute(
            select(OCRJobMetrics).where(OCRJobMetrics.job_id == job_id)
        )
        return result.scalar_one_or_none()

    @staticmethod
    async def get_user_history(
        session: AsyncSession,
        username: str,
        limit: int = 50,
        offset: int = 0
    ) -> List[OCRJobMetrics]:
        """Get user's OCR job history."""
        result = await session.execute(
            select(OCRJobMetrics)
            .where(OCRJobMetrics.username == username)
            .order_by(OCRJobMetrics.created_at.desc())
            .limit(limit)
            .offset(offset)
        )
        return list(result.scalars().all())

    @staticmethod
    async def get_summary_by_engine(
        session: AsyncSession,
        days: int = 30,
        username: Optional[str] = None
    ) -> List[OCRMetricsSummary]:
        """Get summary metrics grouped by engine."""
        cutoff_date = datetime.utcnow() - timedelta(days=days)

        # Build query
        conditions = [OCRJobMetrics.created_at >= cutoff_date]
        if username:
            conditions.append(OCRJobMetrics.username == username)

        # Query for aggregated metrics
        result = await session.execute(
            select(
                OCRJobMetrics.engine_used,
                func.count(OCRJobMetrics.id).label('total_jobs'),
                func.sum(func.cast(OCRJobMetrics.success, sa.Integer)).label('successful_jobs'),
                func.avg(OCRJobMetrics.processing_time_ms).label('avg_processing_time_ms'),
                func.avg(OCRJobMetrics.overall_confidence).label('avg_confidence'),
                func.avg(OCRJobMetrics.fields_extracted).label('avg_fields_extracted'),
            )
            .where(and_(*conditions))
            .group_by(OCRJobMetrics.engine_used)
            .order_by(func.count(OCRJobMetrics.id).desc())
        )

        summaries = []
        for row in result.all():
            total = row.total_jobs or 0
            successful = row.successful_jobs or 0
            success_rate = successful / total if total > 0 else 0.0
            summaries.append(OCRMetricsSummary(
                engine=row.engine_used,
                total_jobs=total,
                successful_jobs=successful,
                failed_jobs=total - successful,
                success_rate=success_rate,
                avg_processing_time_ms=float(row.avg_processing_time_ms or 0),
                avg_confidence=float(row.avg_confidence or 0),
                avg_fields_extracted=float(row.avg_fields_extracted or 0),
            ))
        return summaries

    @staticmethod
    async def get_overall_stats(
        session: AsyncSession,
        days: int = 30,
        username: Optional[str] = None
    ) -> dict:
        """Get overall OCR statistics."""
        cutoff_date = datetime.utcnow() - timedelta(days=days)

        conditions = [OCRJobMetrics.created_at >= cutoff_date]
        if username:
            conditions.append(OCRJobMetrics.username == username)

        result = await session.execute(
            select(
                func.count(OCRJobMetrics.id).label('total_jobs'),
                func.sum(func.cast(OCRJobMetrics.success, sa.Integer)).label('successful_jobs'),
                func.avg(OCRJobMetrics.processing_time_ms).label('avg_processing_time_ms'),
                func.avg(OCRJobMetrics.overall_confidence).label('avg_confidence'),
            )
            .where(and_(*conditions))
        )
        row = result.one()

        total = row.total_jobs or 0
        successful = row.successful_jobs or 0
        return {
            "total_jobs": total,
            "successful_jobs": successful,
            "failed_jobs": total - successful,
            "success_rate": (successful / total * 100) if total > 0 else 0.0,
            "avg_processing_time_ms": float(row.avg_processing_time_ms or 0),
            "avg_confidence": float(row.avg_confidence or 0),
            "period_days": days,
        }
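The per-engine aggregation in `get_summary_by_engine` boils down to a `GROUP BY` with `COUNT`, `SUM(CAST(success AS INTEGER))` and `AVG`. The sketch below reproduces that shape with the standard-library `sqlite3` module so it can be run without SQLAlchemy. The table is reduced to three columns and the four rows are made-up sample data, not benchmark results.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE ocr_job_metrics ("
    " engine_used TEXT, success BOOLEAN, processing_time_ms INTEGER)"
)
# Invented sample rows, for illustration only.
conn.executemany(
    "INSERT INTO ocr_job_metrics VALUES (?, ?, ?)",
    [
        ("doctr_plus", True, 7000),
        ("doctr_plus", True, 8000),
        ("doctr_plus", False, 9000),
        ("tesseract", True, 3000),
    ],
)

rows = conn.execute(
    """
    SELECT engine_used,
           COUNT(*)                      AS total_jobs,
           SUM(CAST(success AS INTEGER)) AS successful_jobs,
           AVG(processing_time_ms)       AS avg_processing_time_ms
    FROM ocr_job_metrics
    GROUP BY engine_used
    ORDER BY COUNT(*) DESC
    """
).fetchall()

# Build one summary dict per engine, busiest engine first.
summary = {
    engine: {
        "total_jobs": total,
        "successful_jobs": ok,
        "success_rate": ok / total if total else 0.0,  # fraction, 0.0 to 1.0
        "avg_processing_time_ms": float(avg),
    }
    for engine, total, ok, avg in rows
}
```

Note that in the code above `get_summary_by_engine` reports `success_rate` as a fraction, while `get_overall_stats` multiplies it by 100, so a consumer has to treat the two values differently.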

View File

@@ -10,9 +10,10 @@ from backend.modules.data_entry.config import settings
# Create async engine
# Note: echo=False to disable SQL query logging (too verbose)
engine = create_async_engine(
    settings.database_url,
    echo=False,
    future=True,
)

View File

@@ -2,6 +2,7 @@
from .receipt import Receipt, ReceiptAttachment, ReceiptStatus, ReceiptType, ReceiptDirection
from .accounting_entry import AccountingEntry, EntryType
from .nomenclature import SyncedSupplier, LocalSupplier, SyncedCashRegister
from .ocr_settings import UserOCRPreference, OCRJobMetrics, OCRMetricsSummary, OCREngine
__all__ = [
"Receipt",
@@ -14,4 +15,9 @@ __all__ = [
"SyncedSupplier",
"LocalSupplier",
"SyncedCashRegister",
# OCR Settings & Metrics
"UserOCRPreference",
"OCRJobMetrics",
"OCRMetricsSummary",
"OCREngine",
]

View File

@@ -0,0 +1,102 @@
"""OCR settings and metrics SQLModel models."""
from datetime import datetime
from decimal import Decimal
from enum import Enum
from typing import Optional

from sqlmodel import SQLModel, Field


class OCREngine(str, Enum):
    """Available OCR engines."""
    TESSERACT = "tesseract"
    DOCTR = "doctr"
    DOCTR_PLUS = "doctr_plus"  # docTR with 2-tier sequential processing + early exit (optimized, recommended)
    PADDLEOCR = "paddleocr"


class UserOCRPreference(SQLModel, table=True):
    """
    User's preferred OCR engine setting.

    Each user can have one preferred OCR engine that will be
    auto-selected when they upload new receipts for processing.
    """
    __tablename__ = "user_ocr_preferences"

    id: Optional[int] = Field(default=None, primary_key=True)

    # User identification
    username: str = Field(max_length=100, unique=True, index=True)

    # Preference settings
    preferred_engine: OCREngine = Field(default=OCREngine.DOCTR_PLUS)

    # Timestamps
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)


class OCRJobMetrics(SQLModel, table=True):
    """
    OCR job processing metrics for analytics.

    Stores metrics for each OCR job to enable:
    - Performance tracking by engine
    - Success rate analysis
    - Processing time trends
    - User-specific analytics
    """
    __tablename__ = "ocr_job_metrics"

    id: Optional[int] = Field(default=None, primary_key=True)

    # Job identification
    job_id: str = Field(max_length=50, unique=True, index=True)

    # User and company context
    username: str = Field(max_length=100, index=True)
    company_id: Optional[int] = Field(default=None, index=True)

    # Engine used
    engine_requested: str = Field(max_length=20)  # What user/auto requested
    engine_used: str = Field(max_length=50)  # What was actually used (e.g., "doctr-light")

    # Processing metrics
    processing_time_ms: int = Field(default=0)
    file_size_bytes: int = Field(default=0)
    file_type: str = Field(max_length=50, default="image/jpeg")  # MIME type
    original_filename: Optional[str] = Field(default=None, max_length=255)  # Original uploaded filename

    # Success metrics
    success: bool = Field(default=True)
    error_message: Optional[str] = Field(default=None, max_length=500)

    # Extraction quality metrics
    overall_confidence: float = Field(default=0.0)
    fields_extracted: int = Field(default=0)  # Number of fields successfully extracted
    needs_manual_review: Optional[bool] = Field(default=None)
    validation_warnings_count: int = Field(default=0)
    validation_errors_count: int = Field(default=0)

    # Timestamps
    created_at: datetime = Field(default_factory=datetime.utcnow)


class OCRMetricsSummary(SQLModel):
    """
    Summary metrics for OCR analytics.

    Not a database table - used for API responses.
    """
    engine: str
    total_jobs: int
    successful_jobs: int
    failed_jobs: int
    success_rate: float  # Computed: successful_jobs / total_jobs
    avg_processing_time_ms: float
    avg_confidence: float
    avg_fields_extracted: float
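Because `OCREngine` mixes in `str`, its members compare equal to their plain string values and can be built directly from request or config input. The sketch below repeats the enum so it runs on its own and adds a hypothetical `coerce_engine` helper; falling back to `doctr_plus` for unknown names is an assumption, not behaviour confirmed by this commit.

```python
from enum import Enum


class OCREngine(str, Enum):
    TESSERACT = "tesseract"
    DOCTR = "doctr"
    DOCTR_PLUS = "doctr_plus"
    PADDLEOCR = "paddleocr"


def coerce_engine(value: str, fallback: OCREngine = OCREngine.DOCTR_PLUS) -> OCREngine:
    """Return the matching enum member, or `fallback` for unknown names."""
    try:
        return OCREngine(value.strip().lower())
    except ValueError:
        return fallback
```

Looking a member up by value raises `ValueError` for anything outside the four names, which is what the helper catches.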

View File

@@ -17,6 +17,7 @@ load_dotenv()
from backend.modules.data_entry.db.models.receipt import Receipt, ReceiptAttachment
from backend.modules.data_entry.db.models.accounting_entry import AccountingEntry
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier, SyncedCashRegister
from backend.modules.data_entry.db.models.ocr_settings import UserOCRPreference, OCRJobMetrics
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.

View File

@@ -0,0 +1,74 @@
"""Add OCR settings and metrics tables.
Revision ID: add_ocr_settings_metrics
Revises: 20251230_add_needs_manual_review
Create Date: 2025-12-31
This migration adds:
- user_ocr_preferences: Store user's preferred OCR engine
- ocr_job_metrics: Store OCR job processing metrics for analytics
"""
from alembic import op
import sqlalchemy as sa

# Revision identifiers
revision = 'add_ocr_settings_metrics'
down_revision = '20251230_add_needs_manual_review'
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Create OCR settings and metrics tables."""
    # Create user_ocr_preferences table
    op.create_table(
        'user_ocr_preferences',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('username', sa.String(length=100), nullable=False),
        sa.Column('preferred_engine', sa.String(length=20), nullable=False, server_default='doctr_plus'),
        sa.Column('created_at', sa.DateTime(), nullable=False, server_default=sa.func.now()),
        sa.Column('updated_at', sa.DateTime(), nullable=False, server_default=sa.func.now()),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_index('ix_user_ocr_preferences_username', 'user_ocr_preferences', ['username'], unique=True)

    # Create ocr_job_metrics table
    op.create_table(
        'ocr_job_metrics',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('job_id', sa.String(length=50), nullable=False),
        sa.Column('username', sa.String(length=100), nullable=False),
        sa.Column('company_id', sa.Integer(), nullable=True),
        sa.Column('engine_requested', sa.String(length=20), nullable=False),
        sa.Column('engine_used', sa.String(length=50), nullable=False),
        sa.Column('processing_time_ms', sa.Integer(), nullable=False, server_default='0'),
        sa.Column('file_size_bytes', sa.Integer(), nullable=False, server_default='0'),
        sa.Column('file_type', sa.String(length=50), nullable=False, server_default='image/jpeg'),
        sa.Column('success', sa.Boolean(), nullable=False, server_default='1'),
        sa.Column('error_message', sa.String(length=500), nullable=True),
        sa.Column('overall_confidence', sa.Float(), nullable=False, server_default='0.0'),
        sa.Column('fields_extracted', sa.Integer(), nullable=False, server_default='0'),
        sa.Column('needs_manual_review', sa.Boolean(), nullable=True),
        sa.Column('validation_warnings_count', sa.Integer(), nullable=False, server_default='0'),
        sa.Column('validation_errors_count', sa.Integer(), nullable=False, server_default='0'),
        sa.Column('created_at', sa.DateTime(), nullable=False, server_default=sa.func.now()),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_index('ix_ocr_job_metrics_job_id', 'ocr_job_metrics', ['job_id'], unique=True)
    op.create_index('ix_ocr_job_metrics_username', 'ocr_job_metrics', ['username'], unique=False)
    op.create_index('ix_ocr_job_metrics_company_id', 'ocr_job_metrics', ['company_id'], unique=False)
    op.create_index('ix_ocr_job_metrics_created_at', 'ocr_job_metrics', ['created_at'], unique=False)


def downgrade() -> None:
    """Drop OCR settings and metrics tables."""
    op.drop_index('ix_ocr_job_metrics_created_at', table_name='ocr_job_metrics')
    op.drop_index('ix_ocr_job_metrics_company_id', table_name='ocr_job_metrics')
    op.drop_index('ix_ocr_job_metrics_username', table_name='ocr_job_metrics')
    op.drop_index('ix_ocr_job_metrics_job_id', table_name='ocr_job_metrics')
    op.drop_table('ocr_job_metrics')
    op.drop_index('ix_user_ocr_preferences_username', table_name='user_ocr_preferences')
    op.drop_table('user_ocr_preferences')

View File

@@ -0,0 +1,30 @@
"""Add original_filename to ocr_job_metrics.
Revision ID: add_original_filename_to_metrics
Revises: add_ocr_settings_metrics
Create Date: 2025-12-31
Adds original_filename column to track the uploaded filename.
"""
from alembic import op
import sqlalchemy as sa

# Revision identifiers
revision = 'add_original_filename_to_metrics'
down_revision = 'add_ocr_settings_metrics'
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Add original_filename column to ocr_job_metrics."""
    op.add_column(
        'ocr_job_metrics',
        sa.Column('original_filename', sa.String(length=255), nullable=True)
    )


def downgrade() -> None:
    """Remove original_filename column."""
    op.drop_column('ocr_job_metrics', 'original_filename')

View File

@@ -11,6 +11,8 @@ def create_data_entry_router() -> APIRouter:
- /receipts - Receipt CRUD and workflow
- /ocr - OCR processing for receipts
- /nomenclature - Nomenclature syncing from Oracle
- /settings - User settings (OCR preferences)
- /metrics - OCR analytics and metrics
Returns:
APIRouter: Configured router for data entry module
@@ -21,10 +23,13 @@ def create_data_entry_router() -> APIRouter:
    from .receipts import router as receipts_router
    from .ocr import router as ocr_router
    from .nomenclature import router as nomenclature_router
    from .ocr_settings import router as ocr_settings_router

    # Include all sub-routers (no prefix - already prefixed in main.py with /api/data-entry)
    router.include_router(receipts_router, prefix="/receipts", tags=["data-entry-receipts"])
    router.include_router(ocr_router, prefix="/ocr", tags=["data-entry-ocr"])
    router.include_router(nomenclature_router, prefix="/nomenclature", tags=["data-entry-nomenclature"])
    # OCR settings and metrics (endpoints at /settings/* and /metrics/*)
    router.include_router(ocr_settings_router, tags=["data-entry-settings"])

    return router

View File

@@ -27,6 +27,7 @@ from backend.modules.data_entry.services.ocr_service import ocr_service
from backend.modules.data_entry.services.ocr_engine import OCREngine
from backend.modules.data_entry.services.ocr.job_queue import job_queue, OCRJobStatus as JobStatus
from backend.modules.data_entry.services.ocr.job_worker import estimate_wait_time
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
from backend.modules.data_entry.schemas.ocr import (
OCRResponse,
OCRStatusResponse,
@@ -55,7 +56,7 @@ router = APIRouter()
@router.post("/extract", response_model=OCRJobSubmitResponse)
async def submit_ocr_job(
    file: UploadFile = File(...),
    engine: OCREngineChoice = Query(default=OCREngineChoice.doctr_plus, description="OCR engine to use"),
    sync: bool = Query(default=False, description="If true, process synchronously (blocks)"),
    current_user: CurrentUser = Depends(get_current_user)
):
@@ -69,7 +70,7 @@ async def submit_ocr_job(
    Args:
        file: Image or PDF file (max 10MB)
        engine: OCR engine choice (tesseract, doctr, doctr_plus, paddleocr)
        sync: If true, process synchronously (legacy mode)

    Returns:
@@ -129,13 +130,13 @@ async def submit_ocr_job(
@router.get("/jobs/{job_id}", response_model=OCRJobResponse)
async def get_job_status(
    job_id: str,
    session: AsyncSession = Depends(get_session),
    current_user: CurrentUser = Depends(get_current_user)
):
    """
    Get OCR job status and result (instant response).

    Poll this endpoint to check job progress.
    Recommended polling interval: 2 seconds.
    For efficient polling, use GET /jobs/{job_id}/wait instead (long-polling).

    Args:
        job_id: Job UUID from POST /extract response
@@ -165,6 +166,10 @@ async def get_job_status(
    result_data = None
    if job.status == JobStatus.completed and job.result:
        result_data = _dict_to_extraction_data(job.result)
        # Apply fuzzy CUI matching
        result_data = await _apply_fuzzy_cui_matching(result_data, session)
        # Debug: log suggested_payment_mode being returned
        print(f"[OCR Router] Returning job {job_id} with suggested_payment_mode={result_data.suggested_payment_mode}", flush=True)

    return OCRJobResponse(
        job_id=job.id,
@@ -174,12 +179,66 @@ async def get_job_status(
created_at=job.created_at or datetime.utcnow(),
started_at=job.started_at,
completed_at=job.completed_at,
queue_wait_ms=job.queue_wait_ms,
ocr_time_ms=job.ocr_time_ms,
processing_time_ms=job.processing_time_ms,
result=result_data,
error=job.error_message
)
@router.get("/jobs/{job_id}/wait", response_model=OCRJobResponse)
async def wait_for_job_status(
job_id: str,
timeout: int = Query(default=30, ge=1, le=60, description="Max wait time in seconds"),
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Long-poll for OCR job status change.
Waits until:
- Job status changes to completed/failed
- Timeout expires (returns current status)
Recommended client timeout: timeout + 5 seconds
Args:
job_id: Job UUID from POST /extract response
timeout: Max wait time in seconds (1-60, default 30)
Returns:
OCRJobResponse with status, queue_position, and result (if completed)
"""
import asyncio
import time
end_time = time.time() + timeout
last_status = None
while time.time() < end_time:
job = await job_queue.get_job(job_id)
if not job:
raise HTTPException(status_code=404, detail="Job not found")
# Return immediately if job completed or failed
if job.status in [JobStatus.completed, JobStatus.failed]:
return await get_job_status(job_id, session, current_user)
# Return if status changed from last check
if last_status is not None and job.status != last_status:
return await get_job_status(job_id, session, current_user)
last_status = job.status
# Wait 1 second before next internal check
await asyncio.sleep(1)
# Timeout - return current status
return await get_job_status(job_id, session, current_user)
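The long-poll loop in `wait_for_job_status` can be isolated from FastAPI as a small coroutine. The following is a minimal sketch, not the commit's code: `wait_for_change`, its `get_status` callable, and the `interval` parameter are hypothetical names introduced for illustration. It also assumes `time.monotonic()` in place of `time.time()`, so the deadline is unaffected by wall-clock adjustments.

```python
import asyncio
import time
from typing import Awaitable, Callable

# Statuses after which no further change is expected.
TERMINAL_STATUSES = {"completed", "failed"}


async def wait_for_change(
    get_status: Callable[[], Awaitable[str]],  # hypothetical status source
    timeout: float,
    interval: float = 1.0,
) -> str:
    """Return once the status is terminal, has changed, or the timeout expires."""
    deadline = time.monotonic() + timeout
    last_status = None
    while time.monotonic() < deadline:
        status = await get_status()
        # Terminal state: answer immediately.
        if status in TERMINAL_STATUSES:
            return status
        # Any transition (e.g. pending to processing) is also reported.
        if last_status is not None and status != last_status:
            return status
        last_status = status
        await asyncio.sleep(interval)
    # Timeout: report whatever the status is now.
    return await get_status()
```

A client would call the real endpoint with an HTTP timeout slightly above the `timeout` query parameter, as the docstring recommends.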
@router.get("/queue/status", response_model=OCRQueueStatusResponse)
async def get_queue_status(
current_user: CurrentUser = Depends(get_current_user)
@@ -221,10 +280,58 @@ async def get_ocr_status():
)
@router.get("/engines")
async def get_available_engines():
"""
Get list of enabled OCR engines based on .env configuration.
Returns engines availability and available processing modes.
Frontend should use this to filter engine selection dropdown.
Available engines: tesseract, doctr, doctr_plus, paddleocr
"""
# Check which engines are enabled via .env
paddle_enabled = os.getenv("OCR_ENABLE_PADDLEOCR", "true").lower() == "true"
tesseract_enabled = os.getenv("OCR_ENABLE_TESSERACT", "true").lower() == "true"
default_engine = os.getenv("OCR_DEFAULT_ENGINE", "doctr_plus")
# Build engines dict
engines = {
"tesseract": tesseract_enabled,
"doctr": True, # Always available (primary engine)
"doctr_plus": True, # Always available (recommended)
"paddleocr": paddle_enabled,
}
# Build available modes based on enabled engines
modes = []
if tesseract_enabled:
modes.append("tesseract")
modes.append("doctr")
modes.append("doctr_plus")
if paddle_enabled:
modes.append("paddleocr")
return {
"engines": engines,
"available_modes": modes,
"default_mode": default_engine,
"memory_estimate_mb": {
"tesseract": 50,
"doctr": 600,
"doctr_plus": 600,
"paddleocr": 800,
}
}
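The mode list built by `get_available_engines` depends only on two environment flags, so it can be expressed as a pure function that is easy to unit-test. This is a sketch under that assumption; `available_modes` and its `env` mapping argument are hypothetical and not part of the commit.

```python
from typing import List, Mapping


def available_modes(env: Mapping[str, str]) -> List[str]:
    """Derive the selectable OCR modes from OCR_ENABLE_* style flags."""

    def flag(name: str) -> bool:
        # Same convention as the endpoint: a missing flag counts as enabled.
        return env.get(name, "true").lower() == "true"

    modes: List[str] = []
    if flag("OCR_ENABLE_TESSERACT"):
        modes.append("tesseract")
    # docTR modes are always offered, matching the endpoint above.
    modes.extend(["doctr", "doctr_plus"])
    if flag("OCR_ENABLE_PADDLEOCR"):
        modes.append("paddleocr")
    return modes
```

Passing `os.environ` as `env` would reproduce the endpoint's behaviour, while tests can pass a plain dict.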
@router.post("/extract-attachment/{attachment_id}", response_model=OCRResponse)
async def extract_from_attachment(
attachment_id: int,
engine: OCREngineChoice = Query(default=OCREngineChoice.auto),
engine: OCREngineChoice = Query(default=OCREngineChoice.doctr_plus),
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user)
):
@@ -260,6 +367,8 @@ async def extract_from_attachment(
raise HTTPException(status_code=422, detail=message)
data = _result_to_extraction_data(result)
# Apply fuzzy CUI matching
data = await _apply_fuzzy_cui_matching(data, session)
return OCRResponse(success=True, message=message, data=data)
@@ -267,6 +376,58 @@ async def extract_from_attachment(
# Helper Functions
# ============================================================================
async def _apply_fuzzy_cui_matching(
extraction_data: ExtractionData,
session: AsyncSession
) -> ExtractionData:
"""
Apply fuzzy CUI matching to extraction data.
ONLY applies fuzzy matching if CUI is missing OR has invalid checksum.
If CUI has valid checksum, we trust the OCR and skip fuzzy matching.
Args:
extraction_data: ExtractionData with CUI to potentially correct
session: AsyncSession for database lookups
Returns:
ExtractionData with CUI corrected if a match was found
"""
from backend.modules.data_entry.services.ocr.validation import CUIChecksumRule
# Skip if no CUI and no vendor name (nothing to match)
if not extraction_data.cui and not extraction_data.partner_name:
return extraction_data
# Check if CUI has valid checksum - if valid, skip fuzzy matching
if extraction_data.cui:
cui_digits = CUIChecksumRule.extract_digits(extraction_data.cui)
if len(cui_digits) >= 6 and CUIChecksumRule.validate_checksum(cui_digits):
print(f"[Fuzzy Match] CUI {extraction_data.cui} has valid checksum, skipping fuzzy match", flush=True)
return extraction_data
# CUI missing or invalid checksum - try fuzzy matching
try:
match = await OCRValidationEngine.fuzzy_match_supplier(
cui=extraction_data.cui,
vendor_name=extraction_data.partner_name,
db_session=session
)
if match:
corrected_cui, supplier_name = match
if corrected_cui != extraction_data.cui:
print(f"[Fuzzy Match] Corrected: {extraction_data.cui} -> {corrected_cui} ({supplier_name})", flush=True)
extraction_data.cui = corrected_cui
# Also set partner_name if not already set
if not extraction_data.partner_name:
extraction_data.partner_name = supplier_name
except Exception as e:
print(f"[Fuzzy Match] Error: {e}", flush=True)
return extraction_data
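The checksum gate above relies on `CUIChecksumRule`, whose implementation is not shown in this diff. As an illustration only, the standard Romanian CUI check-digit algorithm (control key `753217532`) might be sketched as follows. The function names mirror the calls above, but the bodies are an assumption about what that class does, not its actual code.

```python
import re

# Control key of the standard Romanian CUI check-digit algorithm.
CONTROL_KEY = "753217532"


def extract_digits(cui: str) -> str:
    """Drop the 'RO' prefix, spaces, and any other non-digit characters."""
    return re.sub(r"\D", "", cui)


def validate_checksum(digits: str) -> bool:
    """Return True if the last digit matches the computed control digit."""
    if not digits.isdigit() or not 2 <= len(digits) <= 10:
        return False
    body, check_digit = digits[:-1], int(digits[-1])
    # Right-align the body against the 9-digit key by left-padding with zeros.
    padded = body.zfill(len(CONTROL_KEY))
    total = sum(int(d) * int(k) for d, k in zip(padded, CONTROL_KEY))
    expected = (total * 10) % 11
    if expected == 10:
        expected = 0
    return expected == check_digit
```

The router additionally requires at least 6 digits before trusting the checksum, which filters out short OCR fragments that happen to pass the arithmetic.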
async def _process_sync(
content: bytes,
file: UploadFile,
@@ -362,6 +523,7 @@ def _result_to_extraction_data(result) -> ExtractionData:
confidence_client=getattr(result, 'confidence_client', 0.0),
overall_confidence=result.overall_confidence,
raw_text=result.raw_text,
raw_texts=getattr(result, 'raw_texts', []),
ocr_engine=result.ocr_engine,
processing_time_ms=result.processing_time_ms,
needs_manual_review=result.needs_manual_review,
@@ -437,6 +599,7 @@ def _dict_to_extraction_data(data: dict) -> ExtractionData:
confidence_client=data.get('confidence_client', 0.0),
overall_confidence=data.get('overall_confidence', 0.0),
raw_text=data.get('raw_text', ''),
raw_texts=data.get('raw_texts', []),
ocr_engine=data.get('ocr_engine', ''),
processing_time_ms=data.get('processing_time_ms', 0),
needs_manual_review=data.get('needs_manual_review'),

@@ -0,0 +1,268 @@
"""
OCR Settings and Metrics API endpoints.
Endpoints:
- GET /settings/ocr-preference - Get user's preferred OCR engine
- POST /settings/ocr-preference - Set user's preferred OCR engine
- GET /metrics/ocr/summary - Get OCR metrics summary by engine
- GET /metrics/ocr/history - Get user's OCR job history
- GET /metrics/ocr/stats - Get overall OCR statistics
"""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from backend.modules.data_entry.db.database import get_session
from backend.modules.data_entry.db.crud.ocr_settings import OCRPreferenceCRUD, OCRMetricsCRUD
from backend.modules.data_entry.db.models.ocr_settings import OCREngine, OCRMetricsSummary
# Auth integration
from shared.auth.dependencies import get_current_user
from shared.auth.models import CurrentUser
router = APIRouter()
# ============================================================================
# Schemas
# ============================================================================
class OCRPreferenceResponse(BaseModel):
"""Response for OCR preference endpoint."""
username: str
preferred_engine: str
available_engines: List[str] = Field(
default=["tesseract", "doctr", "doctr_plus", "paddleocr"],
description="Available OCR engines"
)
class OCRPreferenceRequest(BaseModel):
"""Request to set OCR preference."""
preferred_engine: str = Field(
default="doctr_plus",
description="Preferred OCR engine: tesseract, doctr, doctr_plus, paddleocr"
)
class OCRMetricsHistoryItem(BaseModel):
"""Single OCR job metrics item."""
job_id: str
engine_requested: str
engine_used: str
processing_time_ms: int
success: bool
overall_confidence: float
fields_extracted: int
created_at: str
original_filename: Optional[str] = None
class OCRMetricsHistoryResponse(BaseModel):
"""Response for OCR history endpoint."""
items: List[OCRMetricsHistoryItem]
total: int
class OCRStatsResponse(BaseModel):
"""Response for OCR stats endpoint."""
total_jobs: int
successful_jobs: int
failed_jobs: int
success_rate: float
avg_processing_time_ms: float
avg_confidence: float
period_days: int
class OCRActiveEnginesResponse(BaseModel):
"""Response for active OCR engines endpoint."""
engines: List[str] = Field(description="List of active OCR engines from .env config")
recommended: str = Field(default="doctr_plus", description="Recommended engine")
# ============================================================================
# OCR Engines Configuration Endpoint
# ============================================================================
@router.get("/settings/ocr-engines", response_model=OCRActiveEnginesResponse)
async def get_active_ocr_engines():
"""
Get list of active OCR engines configured in .env.
Returns the engines that should be shown in the frontend dropdown.
Configured via OCR_ACTIVE_ENGINES environment variable.
Default: doctr,doctr_plus
Available: tesseract, paddleocr, doctr, doctr_plus
"""
from backend.modules.data_entry.config import settings
return OCRActiveEnginesResponse(
engines=settings.ocr_active_engines_list,
recommended="doctr_plus"
)
# ============================================================================
# OCR Preference Endpoints
# ============================================================================
@router.get("/settings/ocr-preference", response_model=OCRPreferenceResponse)
async def get_ocr_preference(
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Get user's preferred OCR engine.
Returns the user's saved preference or 'doctr_plus' if not set.
Also returns list of available engines.
"""
from backend.modules.data_entry.services.ocr_engine import OCREngine as OCREngineClass
preference = await OCRPreferenceCRUD.get_by_username(session, current_user.username)
# Get available engines from OCR service
available = OCREngineClass.get_available_engines()
return OCRPreferenceResponse(
username=current_user.username,
preferred_engine=preference.preferred_engine.value if preference else "doctr_plus",
available_engines=available
)
@router.post("/settings/ocr-preference", response_model=OCRPreferenceResponse)
async def set_ocr_preference(
request: OCRPreferenceRequest,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Set user's preferred OCR engine.
Valid engines: tesseract, doctr, doctr_plus, paddleocr
Note: Available engines depend on .env configuration (OCR_ENABLE_PADDLEOCR, OCR_ENABLE_TESSERACT)
"""
from backend.modules.data_entry.services.ocr_engine import OCREngine as OCREngineClass
# Get dynamically available engines
available = OCREngineClass.get_available_engines()
if request.preferred_engine not in available:
raise HTTPException(
status_code=400,
detail=f"Invalid engine. Must be one of: {', '.join(available)}"
)
# Map string to enum
engine_map = {
"tesseract": OCREngine.TESSERACT,
"doctr": OCREngine.DOCTR,
"doctr_plus": OCREngine.DOCTR_PLUS,
"paddleocr": OCREngine.PADDLEOCR,
}
engine_enum = engine_map.get(request.preferred_engine, OCREngine.DOCTR_PLUS)
# Save preference
preference = await OCRPreferenceCRUD.create_or_update(
session,
current_user.username,
engine_enum
)
# Get available engines
available = OCREngineClass.get_available_engines()
return OCRPreferenceResponse(
username=current_user.username,
preferred_engine=preference.preferred_engine.value,
available_engines=available
)
# ============================================================================
# OCR Metrics Endpoints
# ============================================================================
@router.get("/metrics/ocr/summary", response_model=List[OCRMetricsSummary])
async def get_ocr_metrics_summary(
days: int = Query(default=30, ge=1, le=365, description="Number of days to include"),
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Get OCR metrics summary grouped by engine.
Returns aggregated metrics for each engine used in the specified period.
"""
summaries = await OCRMetricsCRUD.get_summary_by_engine(
session,
days=days,
username=current_user.username
)
return summaries
@router.get("/metrics/ocr/history", response_model=OCRMetricsHistoryResponse)
async def get_ocr_metrics_history(
limit: int = Query(default=50, ge=1, le=200, description="Max items to return"),
offset: int = Query(default=0, ge=0, description="Items to skip"),
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Get user's OCR job history.
Returns list of OCR jobs with their metrics, ordered by most recent first.
"""
items = await OCRMetricsCRUD.get_user_history(
session,
username=current_user.username,
limit=limit,
offset=offset
)
history_items = [
OCRMetricsHistoryItem(
job_id=item.job_id,
engine_requested=item.engine_requested,
engine_used=item.engine_used,
processing_time_ms=item.processing_time_ms,
success=item.success,
overall_confidence=item.overall_confidence,
fields_extracted=item.fields_extracted,
created_at=item.created_at.isoformat(),
original_filename=item.original_filename
)
for item in items
]
return OCRMetricsHistoryResponse(
items=history_items,
total=len(history_items)
)
@router.get("/metrics/ocr/stats", response_model=OCRStatsResponse)
async def get_ocr_stats(
days: int = Query(default=30, ge=1, le=365, description="Number of days to include"),
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Get overall OCR statistics for the user.
Returns aggregated stats including success rate, average processing time, etc.
"""
stats = await OCRMetricsCRUD.get_overall_stats(
session,
days=days,
username=current_user.username
)
return OCRStatsResponse(**stats)

@@ -61,7 +61,8 @@ class ExtractionData(BaseModel):
confidence_vendor: float = Field(default=0.0, ge=0, le=1, description="Vendor extraction confidence")
confidence_client: float = Field(default=0.0, ge=0, le=1, description="Client extraction confidence")
overall_confidence: float = Field(default=0.0, ge=0, le=1, description="Overall confidence score")
raw_text: str = Field(default="", description="Raw OCR text")
raw_text: str = Field(default="", description="Raw OCR text (primary)")
raw_texts: List[str] = Field(default_factory=list, description="Raw OCR texts from all engine passes (for analysis)")
ocr_engine: str = Field(default="", description="OCR engine used: tesseract, doctr, doctr_plus, or paddleocr")
processing_time_ms: int = Field(default=0, ge=0, description="Processing time in milliseconds")
@@ -148,9 +149,10 @@ from enum import Enum
class OCREngineChoice(str, Enum):
"""OCR engine selection options."""
auto = "auto"
paddleocr = "paddleocr"
tesseract = "tesseract"
doctr = "doctr" # 3.3x faster than PaddleOCR with same accuracy (90/100)
doctr_plus = "doctr_plus" # docTR with 2-tier sequential processing + early exit (optimized, recommended)
paddleocr = "paddleocr"
class OCRJobStatus(str, Enum):
@@ -193,7 +195,10 @@ class OCRJobResponse(BaseModel):
created_at: datetime = Field(description="Job creation timestamp")
started_at: Optional[datetime] = Field(default=None, description="Processing start timestamp")
completed_at: Optional[datetime] = Field(default=None, description="Completion timestamp")
processing_time_ms: Optional[int] = Field(default=None, description="Actual processing time in ms")
# Detailed timing breakdown
queue_wait_ms: Optional[int] = Field(default=None, description="Time waiting in queue (started_at - created_at)")
ocr_time_ms: Optional[int] = Field(default=None, description="Actual OCR engine processing time")
processing_time_ms: Optional[int] = Field(default=None, description="Total job processing time (completed_at - started_at)")
result: Optional[ExtractionData] = Field(default=None, description="Extraction result (only if completed)")
error: Optional[str] = Field(default=None, description="Error message (only if failed)")

@@ -33,73 +33,55 @@ class NomenclatureService:
"""
Get partners (suppliers/customers) for a company.
Phase 1: Returns mock data.
Phase 2: Returns synced data from SQLite (from Oracle sync).
Phase 3: Will fetch live from Oracle.
Returns synced suppliers from Oracle + local suppliers created from OCR.
If no suppliers exist, returns empty list (frontend will trigger sync).
"""
# If session is provided, try to get from synced SQLite data
if session:
# Try to get from SQLite synced data
stmt = select(SyncedSupplier).where(SyncedSupplier.company_id == company_id)
if search:
stmt = stmt.where(
(SyncedSupplier.name.ilike(f"%{search}%")) |
(SyncedSupplier.fiscal_code.ilike(f"%{search}%"))
)
stmt = stmt.order_by(SyncedSupplier.name) # Order alphabetically, no limit for AutoComplete
partners = []
result = await session.execute(stmt)
suppliers = result.scalars().all()
if suppliers:
# Also get local suppliers
local_stmt = select(LocalSupplier).where(LocalSupplier.company_id == company_id)
if search:
local_stmt = local_stmt.where(
(LocalSupplier.name.ilike(f"%{search}%")) |
(LocalSupplier.fiscal_code.ilike(f"%{search}%"))
)
local_stmt = local_stmt.order_by(LocalSupplier.name) # Order alphabetically
local_result = await session.execute(local_stmt)
local_suppliers = local_result.scalars().all()
# Combine both - no IDs needed, just text data for autocomplete
partners = []
for s in suppliers:
partners.append(PartnerOption(
name=s.name,
fiscal_code=s.fiscal_code,
address=s.address,
source="oracle"
))
for l in local_suppliers:
partners.append(PartnerOption(
name=l.name, # No suffix - must match search results
fiscal_code=l.fiscal_code,
address=l.address,
source="local"
))
return partners
# Fallback to mock data for Phase 1 (when no synced data)
mock_partners = [
PartnerOption(name="OMV Petrom", fiscal_code="RO123456", source="mock"),
PartnerOption(name="Dedeman", fiscal_code="RO789012", source="mock"),
PartnerOption(name="Kaufland", fiscal_code="RO345678", source="mock"),
PartnerOption(name="Emag", fiscal_code="RO901234", source="mock"),
PartnerOption(name="Altex", fiscal_code="RO567890", source="mock"),
]
if not session:
return partners
# Get synced suppliers from Oracle
stmt = select(SyncedSupplier).where(SyncedSupplier.company_id == company_id)
if search:
search_lower = search.lower()
mock_partners = [
p for p in mock_partners
if search_lower in p.name.lower() or (p.fiscal_code and search_lower in p.fiscal_code.lower())
]
stmt = stmt.where(
(SyncedSupplier.name.ilike(f"%{search}%")) |
(SyncedSupplier.fiscal_code.ilike(f"%{search}%"))
)
stmt = stmt.order_by(SyncedSupplier.name)
return mock_partners
result = await session.execute(stmt)
suppliers = result.scalars().all()
for s in suppliers:
partners.append(PartnerOption(
name=s.name,
fiscal_code=s.fiscal_code,
address=s.address,
source="oracle"
))
# Always get local suppliers (not just when synced exist)
local_stmt = select(LocalSupplier).where(LocalSupplier.company_id == company_id)
if search:
local_stmt = local_stmt.where(
(LocalSupplier.name.ilike(f"%{search}%")) |
(LocalSupplier.fiscal_code.ilike(f"%{search}%"))
)
local_stmt = local_stmt.order_by(LocalSupplier.name)
local_result = await session.execute(local_stmt)
local_suppliers = local_result.scalars().all()
for l in local_suppliers:
partners.append(PartnerOption(
name=l.name,
fiscal_code=l.fiscal_code,
address=l.address,
source="local"
))
return partners
@staticmethod
async def get_accounts(company_id: int, prefix: Optional[str] = None) -> List[AccountOption]:

@@ -13,13 +13,14 @@ Schema:
status TEXT NOT NULL, -- pending, processing, completed, failed
file_path TEXT NOT NULL, -- Path to uploaded file
mime_type TEXT NOT NULL,
engine TEXT DEFAULT 'auto',
engine TEXT DEFAULT 'doctr_plus',
created_at TIMESTAMP,
started_at TIMESTAMP,
completed_at TIMESTAMP,
result_json TEXT, -- JSON extraction result
error_message TEXT,
processing_time_ms INTEGER,
processing_time_ms INTEGER, -- Total job time (started_at to completed_at)
ocr_time_ms INTEGER, -- Actual OCR engine processing time
created_by TEXT, -- Username
original_filename TEXT,
expires_at TIMESTAMP
@@ -74,17 +75,26 @@ class OCRJob:
status: OCRJobStatus
file_path: str
mime_type: str
engine: str = "auto"
engine: str = "doctr_plus"
created_at: Optional[datetime] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
result_json: Optional[str] = None
error_message: Optional[str] = None
processing_time_ms: Optional[int] = None
processing_time_ms: Optional[int] = None # Total job time (started_at to completed_at)
ocr_time_ms: Optional[int] = None # Actual OCR engine processing time
created_by: Optional[str] = None
original_filename: Optional[str] = None
expires_at: Optional[datetime] = None
@property
def queue_wait_ms(self) -> Optional[int]:
"""Calculate queue wait time (created_at to started_at)."""
if self.created_at and self.started_at:
delta = self.started_at - self.created_at
return int(delta.total_seconds() * 1000)
return None
@property
def result(self) -> Optional[Dict]:
"""Parse result_json to dict."""
@@ -143,19 +153,27 @@ class OCRJobQueue:
status TEXT NOT NULL DEFAULT 'pending',
file_path TEXT NOT NULL,
mime_type TEXT NOT NULL,
engine TEXT DEFAULT 'auto',
engine TEXT DEFAULT 'doctr_plus',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
started_at TIMESTAMP,
completed_at TIMESTAMP,
result_json TEXT,
error_message TEXT,
processing_time_ms INTEGER,
ocr_time_ms INTEGER,
created_by TEXT,
original_filename TEXT,
expires_at TIMESTAMP
)
''')
# Migration: add ocr_time_ms column if it doesn't exist
try:
await db.execute('ALTER TABLE ocr_jobs ADD COLUMN ocr_time_ms INTEGER')
logger.info("[OCRJobQueue] Added ocr_time_ms column to existing table")
except Exception:
pass # Column already exists
# Index for efficient queue queries
await db.execute('''
CREATE INDEX IF NOT EXISTS idx_ocr_jobs_status
@@ -177,7 +195,7 @@ class OCRJobQueue:
self,
file_bytes: bytes,
mime_type: str,
engine: str = "auto",
engine: str = "doctr_plus",
username: Optional[str] = None,
original_filename: Optional[str] = None
) -> OCRJob:
@@ -189,7 +207,7 @@ class OCRJobQueue:
Args:
file_bytes: Raw file bytes
mime_type: MIME type of file
engine: OCR engine ('auto', 'paddleocr', 'tesseract')
engine: OCR engine ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
username: Username of requester
original_filename: Original filename from upload
@@ -301,24 +319,52 @@ class OCRJobQueue:
async def get_next_pending(self) -> Optional[OCRJob]:
"""
Get the next pending job (oldest first).
Get the next pending job (oldest first) and atomically mark it as processing.
This prevents race conditions in parallel processing - only one worker
can claim each job.
Returns:
Next OCRJob to process or None if queue empty
"""
await self.initialize()
async with aiosqlite.connect(str(self.db_path)) as db:
db.row_factory = aiosqlite.Row
async with db.execute('''
SELECT * FROM ocr_jobs
WHERE status = 'pending'
ORDER BY created_at ASC
LIMIT 1
''') as cursor:
row = await cursor.fetchone()
if row:
return self._row_to_job(row)
now = datetime.utcnow()
async with self._lock: # Serialize access to prevent race conditions
async with aiosqlite.connect(str(self.db_path)) as db:
db.row_factory = aiosqlite.Row
# Get the next pending job
async with db.execute('''
SELECT * FROM ocr_jobs
WHERE status = 'pending'
ORDER BY created_at ASC
LIMIT 1
''') as cursor:
row = await cursor.fetchone()
if not row:
return None
job_id = row['id']
# Atomically mark as processing
await db.execute('''
UPDATE ocr_jobs
SET status = 'processing', started_at = ?
WHERE id = ? AND status = 'pending'
''', (now.isoformat(), job_id))
await db.commit()
# Fetch the updated job
async with db.execute(
'SELECT * FROM ocr_jobs WHERE id = ?',
(job_id,)
) as cursor:
updated_row = await cursor.fetchone()
if updated_row:
return self._row_to_job(updated_row)
return None
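The claim step in `get_next_pending` depends on an in-process `asyncio.Lock`, which only protects workers inside one event loop. A variant that lets SQLite itself serialize the claim might look like the sketch below. It uses the stdlib `sqlite3` module instead of `aiosqlite`, and `claim_next_pending` is a hypothetical helper; it assumes the connection was opened with `isolation_level=None` so that the explicit `BEGIN IMMEDIATE` is not preceded by an implicit transaction.

```python
import sqlite3
from datetime import datetime, timezone
from typing import Optional


def claim_next_pending(conn: sqlite3.Connection) -> Optional[str]:
    """Claim the oldest pending job and return its id, or None if none is free."""
    # BEGIN IMMEDIATE takes the write lock up front, so the SELECT and the
    # UPDATE below behave as one unit with respect to other connections.
    conn.execute("BEGIN IMMEDIATE")
    try:
        row = conn.execute(
            "SELECT id FROM ocr_jobs WHERE status = 'pending' "
            "ORDER BY created_at ASC LIMIT 1"
        ).fetchone()
        if row is None:
            conn.execute("COMMIT")
            return None
        now = datetime.now(timezone.utc).isoformat()
        cur = conn.execute(
            "UPDATE ocr_jobs SET status = 'processing', started_at = ? "
            "WHERE id = ? AND status = 'pending'",
            (now, row[0]),
        )
        conn.execute("COMMIT")
        # rowcount == 1 confirms this caller performed the transition.
        return row[0] if cur.rowcount == 1 else None
    except Exception:
        conn.execute("ROLLBACK")
        raise
```

Checking `rowcount` is the part worth carrying over even when the lock is kept: it confirms that the `UPDATE` actually moved the row out of `pending`.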
async def update_status(
@@ -327,7 +373,8 @@ class OCRJobQueue:
status: OCRJobStatus,
result: Optional[Dict] = None,
error: Optional[str] = None,
processing_time_ms: Optional[int] = None
processing_time_ms: Optional[int] = None,
ocr_time_ms: Optional[int] = None
) -> bool:
"""
Update job status.
@@ -337,7 +384,8 @@ class OCRJobQueue:
status: New status
result: Extraction result dict (for completed)
error: Error message (for failed)
processing_time_ms: Processing time
processing_time_ms: Total job processing time (started_at to completed_at)
ocr_time_ms: Actual OCR engine processing time
Returns:
True if update successful
@@ -359,18 +407,18 @@ class OCRJobQueue:
elif status == OCRJobStatus.completed:
query = '''
UPDATE ocr_jobs
SET status = ?, completed_at = ?, result_json = ?, processing_time_ms = ?
SET status = ?, completed_at = ?, result_json = ?, processing_time_ms = ?, ocr_time_ms = ?
WHERE id = ?
'''
params = (status.value, now.isoformat(), result_json, processing_time_ms, job_id)
params = (status.value, now.isoformat(), result_json, processing_time_ms, ocr_time_ms, job_id)
elif status == OCRJobStatus.failed:
query = '''
UPDATE ocr_jobs
SET status = ?, completed_at = ?, error_message = ?, processing_time_ms = ?
SET status = ?, completed_at = ?, error_message = ?, processing_time_ms = ?, ocr_time_ms = ?
WHERE id = ?
'''
params = (status.value, now.isoformat(), error, processing_time_ms, job_id)
params = (status.value, now.isoformat(), error, processing_time_ms, ocr_time_ms, job_id)
else:
query = 'UPDATE ocr_jobs SET status = ? WHERE id = ?'
@@ -542,13 +590,14 @@ class OCRJobQueue:
status=OCRJobStatus(row['status']),
file_path=row['file_path'],
mime_type=row['mime_type'],
engine=row['engine'] or 'auto',
engine=row['engine'] or 'doctr_plus',
created_at=parse_datetime(row['created_at']),
started_at=parse_datetime(row['started_at']),
completed_at=parse_datetime(row['completed_at']),
result_json=row['result_json'],
error_message=row['error_message'],
processing_time_ms=row['processing_time_ms'],
ocr_time_ms=row['ocr_time_ms'] if 'ocr_time_ms' in row.keys() else None,
created_by=row['created_by'],
original_filename=row['original_filename'],
expires_at=parse_datetime(row['expires_at']),

@@ -2,7 +2,7 @@
OCR Job Worker - Background Task for Queue Processing
Runs as an asyncio background task in FastAPI.
Continuously polls the job queue and processes OCR requests.
Continuously polls the job queue and processes OCR requests IN PARALLEL.
Architecture:
FastAPI startup
@@ -12,18 +12,19 @@ Architecture:
asyncio.create_task(_job_worker_loop())
while True:
job = job_queue.get_next_pending()
if job:
result = ocr_worker_pool.submit_task(...)
job_queue.update_status(...)
await asyncio.sleep(0.5)
# Process up to OCR_WORKERS jobs concurrently
jobs = get_pending_jobs(limit=available_slots)
for job in jobs:
asyncio.create_task(_process_job(job))
await asyncio.sleep(0.1)
"""
import asyncio
import logging
import os
import time
from pathlib import Path
from typing import Optional
from typing import Optional, Set
from .job_queue import job_queue, OCRJobStatus, OCRJob
from .ocr_worker_pool import ocr_worker_pool
@@ -34,47 +35,76 @@ logger = logging.getLogger(__name__)
_job_worker_task: Optional[asyncio.Task] = None
_cleanup_task: Optional[asyncio.Task] = None
_shutdown_event: Optional[asyncio.Event] = None
_active_tasks: Set[asyncio.Task] = set() # Track active job tasks
_concurrency_semaphore: Optional[asyncio.Semaphore] = None # Limit concurrent jobs
# Configuration
POLL_INTERVAL_SECONDS = 0.5 # How often to check for new jobs
POLL_INTERVAL_SECONDS = 0.1 # How often to check for new jobs (faster for parallel)
CLEANUP_INTERVAL_SECONDS = 3600 # Clean expired jobs every hour
OCR_TIMEOUT_SECONDS = 120 # Max time for OCR processing
async def _job_worker_loop() -> None:
"""
Main worker loop - processes jobs from queue.
Main worker loop - processes jobs from queue IN PARALLEL.
Runs continuously until shutdown. Polls queue every 0.5s
and submits jobs to worker pool for processing.
Runs continuously until shutdown. Uses semaphore to limit
concurrent jobs to OCR_WORKERS count. Launches jobs as
background tasks without waiting for completion.
"""
global _shutdown_event
global _shutdown_event, _active_tasks, _concurrency_semaphore
logger.info("[JobWorker] Starting worker loop...")
# Get max concurrent jobs from env (matches worker pool size)
max_concurrent = int(os.getenv('OCR_WORKERS', '2'))
_concurrency_semaphore = asyncio.Semaphore(max_concurrent)
_active_tasks = set()
logger.info(f"[JobWorker] Starting PARALLEL worker loop (max_concurrent={max_concurrent})...")
_shutdown_event = asyncio.Event()
consecutive_errors = 0
max_consecutive_errors = 5
max_consecutive_errors = 10
while not _shutdown_event.is_set():
try:
# Get next pending job
job = await job_queue.get_next_pending()
if job:
consecutive_errors = 0 # Reset error counter on success
await _process_job(job)
else:
# No jobs - wait before polling again
# Clean up completed tasks
done_tasks = {t for t in _active_tasks if t.done()}
for task in done_tasks:
_active_tasks.discard(task)
# Check for exceptions
try:
await asyncio.wait_for(
_shutdown_event.wait(),
timeout=POLL_INTERVAL_SECONDS
)
if _shutdown_event.is_set():
break
except asyncio.TimeoutError:
pass # Normal timeout, continue loop
task.result()
except Exception as e:
logger.error(f"[JobWorker] Task failed: {e}")
# Check if we have capacity for more jobs
active_count = len(_active_tasks)
available_slots = max_concurrent - active_count
if available_slots > 0:
# Get next pending job
job = await job_queue.get_next_pending()
if job:
consecutive_errors = 0
# Launch job processing as background task
task = asyncio.create_task(_process_job_with_semaphore(job))
_active_tasks.add(task)
logger.debug(f"[JobWorker] Launched job {job.id} (active={len(_active_tasks)}/{max_concurrent})")
else:
# No pending jobs - wait briefly
try:
await asyncio.wait_for(
_shutdown_event.wait(),
timeout=POLL_INTERVAL_SECONDS
)
if _shutdown_event.is_set():
break
except asyncio.TimeoutError:
pass
else:
# At capacity - wait for a slot to free up
await asyncio.sleep(POLL_INTERVAL_SECONDS)
except asyncio.CancelledError:
logger.info("[JobWorker] Worker loop cancelled")
@@ -88,27 +118,46 @@ async def _job_worker_loop() -> None:
logger.error("[JobWorker] Too many consecutive errors, stopping worker")
break
# Backoff on errors
await asyncio.sleep(min(consecutive_errors * 2, 30))
# Wait for active tasks to complete on shutdown
if _active_tasks:
logger.info(f"[JobWorker] Waiting for {len(_active_tasks)} active tasks to complete...")
await asyncio.gather(*_active_tasks, return_exceptions=True)
logger.info("[JobWorker] Worker loop stopped")
async def _process_job_with_semaphore(job: OCRJob) -> None:
"""
Process job with semaphore to limit concurrency.
Acquires semaphore before processing, releases after.
This ensures we don't exceed OCR_WORKERS concurrent jobs.
"""
global _concurrency_semaphore
async with _concurrency_semaphore:
await _process_job(job)
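The effect of the semaphore can be checked in isolation. The sketch below is a standalone illustration, not the worker's code: `run_bounded` is a hypothetical helper that launches `n_jobs` coroutines at once and reports the highest number that were ever inside the guarded section, with a short sleep standing in for the OCR call.

```python
import asyncio


async def run_bounded(n_jobs: int, max_concurrent: int) -> int:
    """Run n_jobs dummy jobs and return the peak number running concurrently."""
    semaphore = asyncio.Semaphore(max_concurrent)
    active = 0
    peak = 0

    async def process(job_index: int) -> None:
        nonlocal active, peak
        # Only max_concurrent coroutines can be inside this block at a time.
        async with semaphore:
            active += 1
            peak = max(peak, active)
            await asyncio.sleep(0.01)  # stand-in for the OCR call
            active -= 1

    await asyncio.gather(*(process(i) for i in range(n_jobs)))
    return peak
```

Because the loop above already stops launching tasks once `len(_active_tasks)` reaches `max_concurrent`, the semaphore acts as a second guard rather than the primary limiter.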
async def _process_job(job: OCRJob) -> None:
"""
Process a single OCR job.
Reads file, submits to worker pool, updates job status.
Reads file, submits to worker pool, updates job status,
and saves metrics for analytics.
Args:
job: OCRJob to process
"""
logger.info(f"[JobWorker] Processing job {job.id}: engine={job.engine}, file={Path(job.file_path).name}")
start_time = time.time()
file_size = 0
file_type = "image/jpeg"
try:
# Mark as processing
await job_queue.update_status(job.id, OCRJobStatus.processing)
# Note: Job already marked as 'processing' atomically in get_next_pending()
# Read file bytes
file_path = Path(job.file_path)
@@ -118,6 +167,10 @@ async def _process_job(job: OCRJob) -> None:
with open(file_path, 'rb') as f:
file_bytes = f.read()
file_size = len(file_bytes)
# Determine file type from the job's MIME type (falls back to image/jpeg)
file_type = getattr(job, 'mime_type', 'image/jpeg') or 'image/jpeg'
# Submit to worker pool
result = await ocr_worker_pool.submit_task(
image_bytes=file_bytes,
@@ -132,14 +185,43 @@ async def _process_job(job: OCRJob) -> None:
# Job completed successfully
extraction = result.get("extraction", {})
# Include raw_texts for analysis (from all OCR engine passes)
extraction['raw_texts'] = result.get("raw_texts", [])
# Extract actual OCR processing time from extraction result
ocr_time_ms = extraction.get('processing_time_ms', 0)
# Debug: log suggested_payment_mode
spm = extraction.get('suggested_payment_mode')
logger.info(f"[JobWorker] Job {job.id} extraction has suggested_payment_mode={spm}")
await job_queue.update_status(
job_id=job.id,
status=OCRJobStatus.completed,
result=extraction,
processing_time_ms=elapsed_ms
processing_time_ms=elapsed_ms,
ocr_time_ms=ocr_time_ms
)
logger.info(f"[JobWorker] Job {job.id} completed in {elapsed_ms}ms")
logger.info(f"[JobWorker] Job {job.id} completed in {elapsed_ms}ms (ocr: {ocr_time_ms}ms)")
# Save metrics for successful job
await _save_job_metrics(
job_id=job.id,
username=job.created_by or 'unknown',
engine_requested=job.engine,
engine_used=extraction.get('ocr_engine', job.engine),
processing_time_ms=elapsed_ms,
file_size_bytes=file_size,
file_type=file_type,
original_filename=job.original_filename,
success=True,
overall_confidence=extraction.get('overall_confidence', 0.0),
fields_extracted=_count_extracted_fields(extraction),
needs_manual_review=extraction.get('needs_manual_review'),
validation_warnings_count=len(extraction.get('validation_warnings', [])),
validation_errors_count=len(extraction.get('validation_errors', [])),
)
else:
# Job failed
@@ -154,6 +236,20 @@ async def _process_job(job: OCRJob) -> None:
logger.warning(f"[JobWorker] Job {job.id} failed after {elapsed_ms}ms: {error_msg}")
# Save metrics for failed job
await _save_job_metrics(
job_id=job.id,
username=job.created_by or 'unknown',
engine_requested=job.engine,
engine_used=job.engine,
processing_time_ms=elapsed_ms,
file_size_bytes=file_size,
file_type=file_type,
original_filename=job.original_filename,
success=False,
error_message=error_msg,
)
except Exception as e:
elapsed_ms = int((time.time() - start_time) * 1000)
@@ -166,6 +262,20 @@ async def _process_job(job: OCRJob) -> None:
processing_time_ms=elapsed_ms
)
# Save metrics for error job
await _save_job_metrics(
job_id=job.id,
username=job.created_by or 'unknown',
engine_requested=job.engine,
engine_used=job.engine,
processing_time_ms=elapsed_ms,
file_size_bytes=file_size,
file_type=file_type,
original_filename=job.original_filename,
success=False,
error_message=str(e),
)
finally:
# Cleanup file after processing
try:
@@ -340,3 +450,96 @@ def estimate_wait_time(queue_position: int) -> int:
# Estimate: position * average_time
return int(queue_position * avg_time)
# ============================================================================
# Metrics Helper Functions
# ============================================================================
async def _save_job_metrics(
job_id: str,
username: str,
engine_requested: str,
engine_used: str,
processing_time_ms: int = 0,
file_size_bytes: int = 0,
file_type: str = "image/jpeg",
original_filename: Optional[str] = None,
success: bool = True,
error_message: Optional[str] = None,
overall_confidence: float = 0.0,
fields_extracted: int = 0,
needs_manual_review: Optional[bool] = None,
validation_warnings_count: int = 0,
validation_errors_count: int = 0,
) -> None:
"""
Save OCR job metrics to database for analytics.
Called after each job completes (success or failure).
Errors are logged but don't affect job processing.
"""
try:
from backend.modules.data_entry.db.database import get_db_session
from backend.modules.data_entry.db.crud.ocr_settings import OCRMetricsCRUD
async with await get_db_session() as session:
await OCRMetricsCRUD.create(
session=session,
job_id=job_id,
username=username,
engine_requested=engine_requested,
engine_used=engine_used,
processing_time_ms=processing_time_ms,
file_size_bytes=file_size_bytes,
file_type=file_type,
original_filename=original_filename,
success=success,
error_message=error_message,
overall_confidence=overall_confidence,
fields_extracted=fields_extracted,
needs_manual_review=needs_manual_review,
validation_warnings_count=validation_warnings_count,
validation_errors_count=validation_errors_count,
)
logger.debug(f"[JobWorker] Saved metrics for job {job_id}")
except Exception as e:
# Log but don't fail - metrics are nice-to-have
logger.warning(f"[JobWorker] Failed to save metrics for job {job_id}: {e}")
def _count_extracted_fields(extraction: dict) -> int:
"""
Count number of successfully extracted fields from OCR result.
Counts non-None values in key fields.
"""
key_fields = [
'receipt_number',
'receipt_date',
'amount',
'partner_name',
'cui',
'tva_total',
'address',
'items_count',
]
count = 0
for field in key_fields:
value = extraction.get(field)
if value is not None and value != '' and value != []:
count += 1
# Also count TVA entries if present
tva_entries = extraction.get('tva_entries', [])
if tva_entries:
count += 1
# Count payment methods if present
payment_methods = extraction.get('payment_methods', [])
if payment_methods:
count += 1
return count


@@ -1,11 +1,12 @@
"""
OCR Worker Pool Manager
Manages a ProcessPoolExecutor with persistent PaddleOCR initialization.
Manages a ProcessPoolExecutor with persistent OCR engine initialization.
Key features:
- ProcessPoolExecutor with max_workers=1 (sequential, no memory leak)
- ProcessPoolExecutor with configurable max_workers (from OCR_WORKERS env)
- Configurable max_tasks_per_child (from OCR_MAX_TASKS_PER_CHILD env, 0=no restart)
- mp_context='spawn' for Windows IIS compatibility
- PaddleOCR loaded ONCE at worker spawn (not 30s per request)
- docTR/PaddleOCR loaded ONCE at worker spawn (not 30s per request)
- atexit + signal handlers for cleanup
- Health check with auto-respawn
- Orphan process cleanup on Windows
@@ -29,7 +30,7 @@ import os
import signal
import sys
import time
from concurrent.futures import ProcessPoolExecutor, Future
from concurrent.futures import ProcessPoolExecutor, Future, ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Callable, Optional
@@ -48,8 +49,8 @@ class OCRWorkerPool:
"""
Singleton manager for OCR ProcessPoolExecutor.
Ensures PaddleOCR is loaded once and reused for all requests.
Uses max_tasks_per_child=None to keep worker alive indefinitely.
Ensures OCR engines are loaded once and reused for all requests.
Reads max_tasks_per_child from OCR_MAX_TASKS_PER_CHILD: 0 keeps workers alive indefinitely, a positive value restarts each worker after that many tasks (limits memory growth).
"""
_instance: Optional["OCRWorkerPool"] = None
@@ -86,7 +87,7 @@ class OCRWorkerPool:
Initialize the ProcessPoolExecutor.
Creates executor with spawn context for Windows compatibility.
Uses max_tasks_per_child=None to keep worker alive (persistent PaddleOCR).
Uses max_tasks_per_child from OCR_MAX_TASKS_PER_CHILD to restart workers periodically (0 = never restart; requires Python 3.11+).
Returns:
True if initialization successful
@@ -103,18 +104,30 @@ class OCRWorkerPool:
# Cleanup any orphan workers from previous runs
self._cleanup_orphan_workers()
# Read configuration from environment
max_workers = int(os.getenv('OCR_WORKERS', '2'))
max_tasks_raw = os.getenv('OCR_MAX_TASKS_PER_CHILD', '0')
# 0 means no restart (None in ProcessPoolExecutor)
max_tasks_per_child = int(max_tasks_raw) if max_tasks_raw and int(max_tasks_raw) > 0 else None
# Create executor with spawn context (Windows compatible)
# Use mp_context='spawn' explicitly for cross-platform consistency
mp_context = mp.get_context('spawn')
self._executor = ProcessPoolExecutor(
max_workers=1, # Single worker for sequential processing
mp_context=mp_context,
initializer=_worker_initializer,
max_tasks_per_child=None, # Keep worker alive indefinitely
)
# max_tasks_per_child only available in Python 3.11+
executor_kwargs = {
'max_workers': max_workers,
'mp_context': mp_context,
'initializer': _worker_initializer,
}
if sys.version_info >= (3, 11) and max_tasks_per_child is not None:
executor_kwargs['max_tasks_per_child'] = max_tasks_per_child
elif max_tasks_per_child is not None:
logger.info(f"[OCRWorkerPool] max_tasks_per_child not supported (Python {sys.version_info.major}.{sys.version_info.minor})")
logger.info("[OCRWorkerPool] ProcessPoolExecutor created (spawn context, max_workers=1)")
self._executor = ProcessPoolExecutor(**executor_kwargs)
logger.info(f"[OCRWorkerPool] ProcessPoolExecutor created (spawn context, max_workers={max_workers}, max_tasks_per_child={max_tasks_per_child})")
return True
except Exception as e:
@@ -173,7 +186,7 @@ class OCRWorkerPool:
async def submit_task(
self,
image_bytes: bytes,
engine: str = "auto",
engine: str = "doctr_plus",
preprocessing: str = "auto",
timeout: float = 120.0
) -> dict:
@@ -182,7 +195,7 @@ class OCRWorkerPool:
Args:
image_bytes: Raw image bytes
engine: OCR engine ('auto', 'paddleocr', 'tesseract')
engine: OCR engine ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
preprocessing: Preprocessing mode ('light', 'medium', 'heavy', 'auto')
timeout: Maximum processing time in seconds
@@ -339,6 +352,7 @@ class OCRWorkerPool:
# Global engines - persist between tasks in worker process
_paddle_engine = None
_tesseract_engine = None
_doctr_engine = None # docTR engine (PyTorch backend)
_worker_initialized = False
@@ -346,40 +360,92 @@ def _worker_initializer() -> None:
"""
Called once when worker process spawns.
Initializes global OCR engines that persist between tasks.
This is where PaddleOCR loading happens (15-20 seconds).
Initializes global OCR engines IN PARALLEL for faster startup.
Uses ThreadPoolExecutor to load enabled engines concurrently.
Respects OCR_ENABLE_PADDLEOCR and OCR_ENABLE_TESSERACT from .env.
Total warmup time approaches max(engine_times) rather than sum(engine_times).
"""
global _paddle_engine, _tesseract_engine, _worker_initialized
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
if _worker_initialized:
print(f"[Worker {os.getpid()}] Already initialized", flush=True)
return
print(f"[Worker {os.getpid()}] Initializing OCR engines...", flush=True)
# Check which engines are enabled via .env
paddle_enabled = os.getenv("OCR_ENABLE_PADDLEOCR", "true").lower() == "true"
tesseract_enabled = os.getenv("OCR_ENABLE_TESSERACT", "true").lower() == "true"
enabled_engines = ["doctr"] # docTR is always loaded (primary engine)
if paddle_enabled:
enabled_engines.append("paddle")
if tesseract_enabled:
enabled_engines.append("tesseract")
print(f"[Worker {os.getpid()}] Initializing OCR engines: {enabled_engines}", flush=True)
if not paddle_enabled:
print(f"[Worker {os.getpid()}] PaddleOCR DISABLED - saving ~800MB RAM", flush=True)
if not tesseract_enabled:
print(f"[Worker {os.getpid()}] Tesseract DISABLED - saving ~50MB RAM", flush=True)
start_time = time.time()
# Initialize PaddleOCR
try:
# Import inside worker to avoid import issues in main process
from backend.modules.data_entry.services.ocr.ocr_worker_process import initialize_paddle_engine
_paddle_engine = initialize_paddle_engine()
print(f"[Worker {os.getpid()}] PaddleOCR loaded", flush=True)
except Exception as e:
print(f"[Worker {os.getpid()}] PaddleOCR init failed: {e}", flush=True)
_paddle_engine = None
# Define loader functions - each runs in its own thread
def load_doctr():
try:
from backend.modules.data_entry.services.ocr.ocr_worker_process import initialize_doctr_engine
engine = initialize_doctr_engine()
return ("doctr", engine, None)
except Exception as e:
return ("doctr", None, str(e))
# Initialize Tesseract
try:
from backend.modules.data_entry.services.ocr.tesseract_engine import TesseractEngine
_tesseract_engine = TesseractEngine()
print(f"[Worker {os.getpid()}] Tesseract loaded", flush=True)
except Exception as e:
print(f"[Worker {os.getpid()}] Tesseract init failed: {e}", flush=True)
_tesseract_engine = None
def load_paddle():
if not paddle_enabled:
return ("paddle", None, "disabled via OCR_ENABLE_PADDLEOCR=false")
try:
from backend.modules.data_entry.services.ocr.ocr_worker_process import initialize_paddle_engine
engine = initialize_paddle_engine()
return ("paddle", engine, None)
except Exception as e:
return ("paddle", None, str(e))
def load_tesseract():
if not tesseract_enabled:
return ("tesseract", None, "disabled via OCR_ENABLE_TESSERACT=false")
try:
from backend.modules.data_entry.services.ocr.tesseract_engine import TesseractEngine
engine = TesseractEngine()
return ("tesseract", engine, None)
except Exception as e:
return ("tesseract", None, str(e))
# Build list of futures for enabled engines only
futures_to_submit = [load_doctr] # docTR always loaded
if paddle_enabled:
futures_to_submit.append(load_paddle)
if tesseract_enabled:
futures_to_submit.append(load_tesseract)
# Load engines in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=len(futures_to_submit)) as executor:
futures = [executor.submit(fn) for fn in futures_to_submit]
for future in as_completed(futures):
name, engine, error = future.result()
if error and "disabled" not in error:
print(f"[Worker {os.getpid()}] {name} init failed: {error}", flush=True)
elif engine:
print(f"[Worker {os.getpid()}] {name} loaded", flush=True)
if name == "doctr":
_doctr_engine = engine
elif name == "paddle":
_paddle_engine = engine
elif name == "tesseract":
_tesseract_engine = engine
elapsed = time.time() - start_time
_worker_initialized = True
print(f"[Worker {os.getpid()}] Initialization complete in {elapsed:.1f}s", flush=True)
print(f"[Worker {os.getpid()}] Initialization complete in {elapsed:.1f}s (engines: {enabled_engines})", flush=True)
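
The parallel loading pattern used by the initializer can be sketched on its own. This is a simplified illustration under stated assumptions: `load_engines`, `slow_loader`, and `broken_loader` are hypothetical helpers, and `time.sleep` stands in for model loading. Real engine loading may hold the GIL for part of the time, so the speedup in practice can be smaller than shown here.

```python
import time
from concurrent.futures import ThreadPoolExecutor, as_completed


def load_engines(loaders):
    """Run zero-argument loader callables concurrently.

    `loaders` maps an engine name to a callable that returns the engine.
    Returns (engines, errors): loaded engines by name, error messages by name.
    """
    engines, errors = {}, {}
    with ThreadPoolExecutor(max_workers=max(1, len(loaders))) as pool:
        futures = {pool.submit(fn): name for name, fn in loaders.items()}
        for future in as_completed(futures):
            name = futures[future]
            try:
                engines[name] = future.result()
            except Exception as exc:  # one failed engine must not block the others
                errors[name] = str(exc)
    return engines, errors


def slow_loader(value, delay=0.2):
    """Build a fake loader that sleeps for `delay` seconds (stand-in for model load)."""
    def _load():
        time.sleep(delay)
        return value
    return _load


def broken_loader():
    raise RuntimeError("model missing")


if __name__ == "__main__":
    start = time.perf_counter()
    engines, errors = load_engines({
        "doctr": slow_loader("D"),
        "paddle": slow_loader("P"),
        "tesseract": broken_loader,
    })
    print(engines, errors, f"{time.perf_counter() - start:.2f}s")
```

Mapping each future back to its name avoids having every loader return a `(name, engine, error)` tuple, and it lets a loader simply raise on failure.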
def _warmup_task() -> dict:
@@ -389,7 +455,7 @@ def _warmup_task() -> dict:
Called at FastAPI startup to pre-warm the worker.
Returns success status and worker PID.
"""
global _paddle_engine, _tesseract_engine, _worker_initialized
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
try:
# Ensure initialization
@@ -400,6 +466,14 @@ def _warmup_task() -> dict:
import numpy as np
dummy_img = np.ones((100, 100, 3), dtype=np.uint8) * 255
# Test docTR if available (fastest engine)
if _doctr_engine is not None:
try:
_doctr_engine([dummy_img])
print(f"[Worker {os.getpid()}] docTR warmup OK", flush=True)
except Exception as e:
print(f"[Worker {os.getpid()}] docTR warmup error: {e}", flush=True)
# Test PaddleOCR if available
if _paddle_engine is not None:
try:
@@ -414,6 +488,7 @@ def _warmup_task() -> dict:
return {
"success": True,
"pid": os.getpid(),
"doctr_available": _doctr_engine is not None,
"paddle_available": _paddle_engine is not None,
"tesseract_available": _tesseract_engine is not None
}
@@ -428,7 +503,7 @@ def _warmup_task() -> dict:
def _process_ocr_task(
image_bytes: bytes,
engine: str = "auto",
engine: str = "doctr_plus",
preprocessing: str = "auto"
) -> dict:
"""
@@ -439,13 +514,13 @@ def _process_ocr_task(
Args:
image_bytes: Raw image bytes
engine: OCR engine choice
engine: OCR engine choice ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
preprocessing: Preprocessing mode
Returns:
Dict with extraction results
"""
global _paddle_engine, _tesseract_engine, _worker_initialized
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
try:
# Ensure initialization
@@ -461,7 +536,8 @@ def _process_ocr_task(
paddle_engine=_paddle_engine,
tesseract_engine=_tesseract_engine,
engine=engine,
preprocessing=preprocessing
preprocessing=preprocessing,
doctr_engine=_doctr_engine
)
# Cleanup after each task
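
The handling of `OCR_MAX_TASKS_PER_CHILD` in the pool initializer can be sketched as two small pure functions. This is an illustrative sketch, not the project's code: `parse_max_tasks` and `build_executor_kwargs` are hypothetical names, and treating a non-numeric value as "no restart" is an assumption (the code in the diff would raise `ValueError` instead).

```python
import sys
from typing import Optional


def parse_max_tasks(raw: Optional[str]) -> Optional[int]:
    """Map the env string to ProcessPoolExecutor's convention.

    0, empty, missing, or non-numeric input -> None (never restart workers).
    A positive integer -> restart each worker after that many tasks.
    """
    try:
        value = int(raw) if raw else 0
    except ValueError:
        value = 0  # assumption: fall back to "no restart" on bad input
    return value if value > 0 else None


def build_executor_kwargs(max_workers, max_tasks, version=sys.version_info):
    """Only pass max_tasks_per_child when it is set and the interpreter supports it."""
    kwargs = {"max_workers": max_workers}
    if max_tasks is not None and version >= (3, 11):
        kwargs["max_tasks_per_child"] = max_tasks
    return kwargs


if __name__ == "__main__":
    print(build_executor_kwargs(2, parse_max_tasks("10")))
    print(build_executor_kwargs(2, parse_max_tasks("0")))
```

Keeping the version check next to the `None` check means a value of 0 on Python 3.11+ is never reported as "not supported".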


@@ -6,6 +6,7 @@ Handles OCR processing with persistent engine instances.
Key features:
- PaddleOCR initialized ONCE at process spawn
- docTR initialized ONCE at process spawn (PyTorch backend)
- Tesseract as fallback/complement engine
- Multi-pass preprocessing (light → medium → tesseract)
- Automatic engine selection based on results
@@ -26,6 +27,13 @@ import numpy as np
# Disable PaddleOCR model source check for faster startup
os.environ['PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK'] = 'True'
# Memory optimization for docTR (prevents memory leak in multiprocessing)
# Source: https://github.com/mindee/doctr/issues/1594
os.environ['DOCTR_MULTIPROCESSING_DISABLE'] = 'TRUE'
# Reduce Intel oneDNN cache to save memory
os.environ['ONEDNN_PRIMITIVE_CACHE_CAPACITY'] = '1'
@dataclass
class OCRResult:
@@ -71,25 +79,67 @@ def initialize_paddle_engine():
return None
def initialize_doctr_engine():
"""
Initialize docTR engine (CPU only).
Called once at worker spawn. Returns the engine instance
that will be reused for all subsequent requests.
Note: DirectML (AMD GPU) has compatibility issues with docTR.
CUDA (NVIDIA) works but requires separate PyTorch build.
CPU mode is stable and well-optimized.
Returns:
docTR predictor instance or None if unavailable
"""
try:
print(f"[Worker {os.getpid()}] Loading docTR (PyTorch backend, CPU)...", flush=True)
start_time = time.time()
from doctr.models import ocr_predictor
# Initialize docTR predictor with pretrained models
# Uses db_resnet50 for detection and crnn_vgg16_bn for recognition
doctr = ocr_predictor(
det_arch='db_resnet50',
reco_arch='crnn_vgg16_bn',
pretrained=True,
assume_straight_pages=True,
straighten_pages=False,
preserve_aspect_ratio=True,
)
elapsed = time.time() - start_time
print(f"[Worker {os.getpid()}] docTR loaded in {elapsed:.1f}s", flush=True)
return doctr
except Exception as e:
print(f"[Worker {os.getpid()}] docTR init failed: {e}", flush=True)
return None
def process_ocr(
image_bytes: bytes,
paddle_engine,
tesseract_engine,
engine: str = "auto",
preprocessing: str = "auto"
engine: str = "doctr_plus",
preprocessing: str = "auto",
doctr_engine=None
) -> dict:
"""
Process OCR on image bytes.
Main entry point for OCR processing in worker process.
Uses adaptive multi-pass strategy for best results.
Uses the specified engine for text recognition.
Args:
image_bytes: Raw image bytes (JPEG, PNG, or PDF)
paddle_engine: Pre-initialized PaddleOCR instance (or None)
tesseract_engine: Pre-initialized TesseractEngine instance (or None)
engine: Engine selection ('auto', 'paddleocr', 'tesseract')
engine: Engine selection ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
preprocessing: Preprocessing mode ('auto', 'light', 'medium', 'heavy')
doctr_engine: Pre-initialized docTR instance (or None)
Returns:
Dict with extraction results:
@@ -101,14 +151,20 @@ def process_ocr(
"ocr_engine": str
}
"""
import sys
start_time = time.time()
print(f"[Worker {os.getpid()}] Processing OCR: engine={engine}, preprocessing={preprocessing}, size={len(image_bytes)} bytes", flush=True)
try:
# Decode image from bytes
print(f"[Worker {os.getpid()}] Decoding image...", flush=True)
image = _decode_image(image_bytes)
if image is None:
return {"success": False, "error": "Failed to decode image"}
print(f"[Worker {os.getpid()}] Image decoded: shape={image.shape}, dtype={image.dtype}, size={image.nbytes/1024/1024:.1f}MB", flush=True)
# Import preprocessor
from backend.modules.data_entry.services.image_preprocessor import ImagePreprocessor
@@ -116,22 +172,36 @@ def process_ocr(
preprocessor = ImagePreprocessor()
extractor = ReceiptExtractor()
print(f"[Worker {os.getpid()}] Preprocessor and extractor initialized", flush=True)
raw_texts = []
extraction = None
# Engine routing
if engine == "paddleocr":
extraction, raw_texts = _process_paddleocr_only(
image, paddle_engine, preprocessor, extractor
)
elif engine == "tesseract":
# Engine routing (available: tesseract, doctr, doctr_plus, paddleocr)
print(f"[Worker {os.getpid()}] Routing to engine: {engine}", flush=True)
if engine == "tesseract":
extraction, raw_texts = _process_tesseract_only(
image, tesseract_engine, preprocessor, extractor
)
else: # auto
extraction, raw_texts = _process_adaptive(
image, paddle_engine, tesseract_engine, preprocessor, extractor
elif engine == "doctr":
extraction, raw_texts = _process_doctr_only(
image, doctr_engine, preprocessor, extractor
)
elif engine == "doctr_plus":
extraction, raw_texts = _process_doctr_plus(
image, doctr_engine, preprocessor, extractor
)
elif engine == "paddleocr":
extraction, raw_texts = _process_paddleocr_only(
image, paddle_engine, preprocessor, extractor
)
else:
# Default to doctr_plus if unknown engine specified
print(f"[OCR] Unknown engine '{engine}', defaulting to doctr_plus", flush=True)
extraction, raw_texts = _process_doctr_plus(
image, doctr_engine, preprocessor, extractor
)
# Calculate processing time
@@ -171,7 +241,11 @@ def process_ocr(
def _decode_image(image_bytes: bytes) -> Optional[np.ndarray]:
"""Decode image from bytes (JPEG, PNG, or first page of PDF)."""
"""Decode image from bytes (JPEG, PNG, or first page of PDF).
For PDFs, uses 200 DPI which is sufficient for receipt OCR
and reduces processing time by ~50% vs 300 DPI.
"""
try:
# Try as regular image first
nparr = np.frombuffer(image_bytes, np.uint8)
@@ -180,18 +254,21 @@ def _decode_image(image_bytes: bytes) -> Optional[np.ndarray]:
if image is not None:
return image
# Try as PDF
# Try as PDF - use 200 DPI for faster processing (sufficient for receipts)
try:
import pdf2image
from PIL import Image
images = pdf2image.convert_from_bytes(image_bytes, dpi=300)
# 200 DPI is sufficient for receipt text recognition
# 300 DPI was overkill and slowed down processing
images = pdf2image.convert_from_bytes(image_bytes, dpi=200)
if images:
# Convert first page to numpy array
pil_img = images[0]
print(f"[Worker {os.getpid()}] PDF decoded: {pil_img.width}x{pil_img.height} @ 200 DPI", flush=True)
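# PIL yields RGB; convert to BGR so PDFs match the cv2.imdecode path
return cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2BGR)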
except Exception:
pass
except Exception as e:
print(f"[Worker {os.getpid()}] PDF decode error: {e}", flush=True)
return None
@@ -270,83 +347,275 @@ def _process_tesseract_only(
return extraction, raw_texts
def _process_adaptive(
def _process_doctr_only(
image: np.ndarray,
paddle_engine,
tesseract_engine,
doctr_engine,
preprocessor,
extractor
) -> Tuple[Any, List[str]]:
"""
Adaptive multi-pass OCR processing.
Process using docTR only (light + medium preprocessing).
Strategy:
1. PaddleOCR Light - fastest, best for clear PDFs
2. PaddleOCR Medium - if Light incomplete
3. Tesseract - complement missing fields only
Returns:
Tuple of (extraction_result, raw_texts_list)
docTR uses exactly the same preprocessing as PaddleOCR for consistency.
"""
raw_texts = []
extraction = None
# === STEP 1: PaddleOCR Light ===
if paddle_engine:
print("[OCR] Step 1: PaddleOCR + Light", flush=True)
light_img = preprocessor.preprocess_light(image)
paddle_light = _paddle_recognize(paddle_engine, light_img)
if doctr_engine is None:
return None, ["docTR not available"]
if paddle_light and paddle_light.text:
extraction = extractor.extract(paddle_light.text)
extraction.ocr_engine = "paddle-light"
raw_texts.append(f"=== PaddleOCR Light (conf: {paddle_light.confidence:.0%}) ===\n{paddle_light.text}")
# Step 1: Light preprocessing (same as PaddleOCR)
print("[OCR] Step 1: docTR + Light", flush=True)
light_img = preprocessor.preprocess_light(image)
doctr_light = _doctr_recognize(doctr_engine, light_img)
if _is_extraction_complete(extraction):
print("[OCR] Early exit - all fields found in Step 1", flush=True)
return extraction, raw_texts
if doctr_light and doctr_light.text:
extraction = extractor.extract(doctr_light.text)
extraction.ocr_engine = "doctr-light"
raw_texts.append(f"=== docTR Light (conf: {doctr_light.confidence:.0%}) ===\n{doctr_light.text}")
# === STEP 2: PaddleOCR Medium ===
if paddle_engine:
print("[OCR] Step 2: PaddleOCR + Medium", flush=True)
medium_img = preprocessor.preprocess_medium(image)
paddle_medium = _paddle_recognize(paddle_engine, medium_img)
if _is_extraction_complete(extraction):
return extraction, raw_texts
if paddle_medium and paddle_medium.text:
extraction_medium = extractor.extract(paddle_medium.text)
extraction_medium.ocr_engine = "paddle-medium"
raw_texts.append(f"=== PaddleOCR Medium (conf: {paddle_medium.confidence:.0%}) ===\n{paddle_medium.text}")
# Step 2: Medium preprocessing (same as PaddleOCR)
print("[OCR] Step 2: docTR + Medium", flush=True)
medium_img = preprocessor.preprocess_medium(image)
doctr_medium = _doctr_recognize(doctr_engine, medium_img)
if extraction:
extraction = _merge_extractions(extraction, extraction_medium)
extraction.ocr_engine = "paddle-adaptive"
else:
extraction = extraction_medium
if doctr_medium and doctr_medium.text:
extraction_medium = extractor.extract(doctr_medium.text)
extraction_medium.ocr_engine = "doctr-medium"
raw_texts.append(f"=== docTR Medium (conf: {doctr_medium.confidence:.0%}) ===\n{doctr_medium.text}")
if _is_extraction_complete(extraction):
print("[OCR] Early exit - all fields found after Step 2", flush=True)
return extraction, raw_texts
# === STEP 3: Tesseract (complement only) ===
if tesseract_engine:
print("[OCR] Step 3: Tesseract complement", flush=True)
tesseract_img = preprocessor.preprocess_for_tesseract(image)
tesseract_result = tesseract_engine.recognize(tesseract_img)
if tesseract_result and tesseract_result.text:
extraction_tess = extractor.extract(tesseract_result.text)
extraction_tess.ocr_engine = "tesseract"
raw_texts.append(f"=== Tesseract (conf: {tesseract_result.confidence:.0%}) ===\n{tesseract_result.text}")
if extraction:
extraction = _complement_extraction(extraction, extraction_tess)
extraction.ocr_engine = "adaptive-full"
else:
extraction = extraction_tess
if extraction:
extraction = _merge_extractions(extraction, extraction_medium)
extraction.ocr_engine = "doctr-adaptive"
else:
extraction = extraction_medium
return extraction, raw_texts
def _process_doctr_plus(
image: np.ndarray,
doctr_engine,
preprocessor,
extractor
) -> Tuple[Any, List[str]]:
"""
docTR Plus - Optimized 2-tier sequential processing with early exit.
Architecture:
- Tier 1: Light preprocessing (~4-5s)
→ Early exit if confidence >= 0.75 AND all fields valid AND cross-validations pass
- Tier 2: Medium preprocessing (only if Tier 1 insufficient, ~4-5s additional)
→ Merge with Tier 1 results
→ Mark for review if still problems
Performance:
- Fast path (~65% of receipts): ~4-5s (Tier 1 only)
- Slow path (~35% of receipts): ~8-9s (Tier 1 + Tier 2)
- Average: ~5-6s
Returns:
Tuple of (extraction_result, raw_texts_list)
extraction_result.needs_manual_review = True if validation issues remain
"""
raw_texts = []
extraction = None
if doctr_engine is None:
return None, ["docTR not available"]
# ========== TIER 1: Light Preprocessing ==========
print("[docTR+] TIER 1: Light preprocessing", flush=True)
tier1_start = time.time()
light_img = preprocessor.preprocess_light(image)
doctr_light = _doctr_recognize(doctr_engine, light_img)
tier1_time = time.time() - tier1_start
print(f"[docTR+] TIER 1 completed in {tier1_time:.1f}s", flush=True)
if doctr_light and doctr_light.text:
extraction = extractor.extract(doctr_light.text)
extraction.ocr_engine = "doctr-plus-light"
raw_texts.append(f"=== docTR+ Tier1/Light (conf: {doctr_light.confidence:.0%}) ===\n{doctr_light.text}")
# Early Exit Check: confidence >= 0.75 + cross-validations
if _is_extraction_valid_for_early_exit(extraction, min_confidence=0.75):
print(f"[docTR+] EARLY EXIT - Tier 1 sufficient (conf: {extraction.overall_confidence:.0%})", flush=True)
extraction.ocr_engine = "doctr-plus"
return extraction, raw_texts
print("[docTR+] Tier 1 incomplete or validation failed, proceeding to Tier 2...", flush=True)
# ========== TIER 2: Medium Preprocessing (only if needed) ==========
print("[docTR+] TIER 2: Medium preprocessing", flush=True)
tier2_start = time.time()
medium_img = preprocessor.preprocess_medium(image)
doctr_medium = _doctr_recognize(doctr_engine, medium_img)
tier2_time = time.time() - tier2_start
print(f"[docTR+] TIER 2 completed in {tier2_time:.1f}s", flush=True)
if doctr_medium and doctr_medium.text:
extraction_medium = extractor.extract(doctr_medium.text)
extraction_medium.ocr_engine = "doctr-plus-medium"
raw_texts.append(f"=== docTR+ Tier2/Medium (conf: {doctr_medium.confidence:.0%}) ===\n{doctr_medium.text}")
if extraction:
# Merge Tier 1 + Tier 2 results
extraction = _merge_extractions(extraction, extraction_medium)
else:
extraction = extraction_medium
# ========== FINAL VALIDATION ==========
if extraction:
extraction.ocr_engine = "doctr-plus"
# Mark for review if validation still fails after both tiers
passes_validation, penalty, errors = _quick_cross_validate(extraction)
if not passes_validation or extraction.overall_confidence < 0.75:
# Mark for human review using existing fields
extraction.needs_manual_review = True
if extraction.overall_confidence < 0.75:
extraction.validation_warnings.append(f"Low confidence: {extraction.overall_confidence:.0%}")
if not extraction.amount:
extraction.validation_errors.append("TOTAL not detected")
if not extraction.cui:
extraction.validation_warnings.append("CUI not detected")
if not extraction.tva_total and not extraction.tva_entries:
extraction.validation_warnings.append("TVA not detected")
if not extraction.receipt_date:
extraction.validation_warnings.append("Date not detected")
# Add cross-validation errors
extraction.validation_errors.extend(errors)
print(f"[docTR+] Marked for review: {extraction.validation_errors + extraction.validation_warnings}", flush=True)
else:
extraction.needs_manual_review = False
# Tier 2 always ran if we reach this point (the early exit returns above)
total_time = tier1_time + tier2_time
print(f"[docTR+] Total processing time: {total_time:.1f}s", flush=True)
return extraction, raw_texts
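
The two-tier control flow above can be reduced to a small, engine-independent sketch. Everything here is hypothetical and for illustration only: `TierResult`, `merge_results`, and `run_two_tier` are invented names, and the merge rule (Tier 1 values win, Tier 2 fills gaps, highest confidence kept) is an assumption, not the behaviour of `_merge_extractions`.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple


@dataclass
class TierResult:
    confidence: float
    fields: dict


def merge_results(first: Optional[TierResult], second: Optional[TierResult]) -> Optional[TierResult]:
    """Hypothetical merge: keep Tier 1 values, fill missing ones from Tier 2."""
    if first is None:
        return second
    if second is None:
        return first
    fields = dict(second.fields)
    fields.update({k: v for k, v in first.fields.items() if v is not None})
    return TierResult(max(first.confidence, second.confidence), fields)


def run_two_tier(
    tier1: Callable[[], Optional[TierResult]],
    tier2: Callable[[], Optional[TierResult]],
    is_valid: Callable[[TierResult], bool],
    min_confidence: float = 0.75,
) -> Tuple[Optional[TierResult], List[str]]:
    """Run tier1; run tier2 only if tier1 is missing, low-confidence, or invalid."""
    tiers_run = ["tier1"]
    first = tier1()
    if first is not None and first.confidence >= min_confidence and is_valid(first):
        return first, tiers_run          # early exit: the expensive tier is skipped
    tiers_run.append("tier2")
    return merge_results(first, tier2()), tiers_run


if __name__ == "__main__":
    result, tiers = run_two_tier(
        lambda: TierResult(0.5, {"amount": 10.0, "cui": None}),
        lambda: TierResult(0.8, {"amount": 11.0, "cui": "RO123"}),
        is_valid=lambda r: r.fields.get("cui") is not None,
    )
    print(result, tiers)
```

Passing the tiers as callables keeps the second pass lazy, which is what makes the fast path cheaper than always running both.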
# =============================================================================
# VALIDATION HELPERS (used by doctr_plus for early exit decisions)
# =============================================================================
def _quick_cross_validate(extraction) -> tuple[bool, float, list[str]]:
"""
Quick cross-validation for OCR results.
Checks critical field correlations to detect obvious OCR errors.
Used by doctr_plus to decide whether to proceed to Tier 2 or exit early.
Returns:
Tuple of (passes_validation, confidence_penalty, error_messages)
"""
try:
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
if extraction is None:
return False, 1.0, ["No extraction result"]
# Convert extraction to dict for validation
# Build TVA entries dict for TVAEntriesSumRule (expects {code: amount})
tva_entries_dict = {}
if extraction.tva_entries:
for entry in extraction.tva_entries:
if isinstance(entry, dict):
code = entry.get('code', 'A')
amount = entry.get('amount', 0)
try:
tva_entries_dict[code] = float(amount)
except (TypeError, ValueError):
pass
validation_data = {
"amount": float(extraction.amount) if extraction.amount else None,
"tva": float(extraction.tva_total) if extraction.tva_total else None,
"tva_entries": tva_entries_dict, # For TVAEntriesSumRule: {code: amount}
"cui": extraction.cui, # For CUI checksum validation
}
# Also pass raw tva_entries for TVABasedTotalRule (for rate detection)
if extraction.tva_entries:
validation_data['tva_entries_raw'] = extraction.tva_entries
# Add payment methods if available (for TOTAL vs CARD+CASH validation)
if extraction.payment_methods:
try:
card_amount = sum(
float(p.get('amount', 0) if isinstance(p, dict) else 0)
for p in extraction.payment_methods
if isinstance(p, dict) and p.get('method') == 'CARD'
)
cash_amount = sum(
float(p.get('amount', 0) if isinstance(p, dict) else 0)
for p in extraction.payment_methods
if isinstance(p, dict) and p.get('method') == 'NUMERAR'
)
validation_data['card_amount'] = card_amount
validation_data['cash_amount'] = cash_amount
except Exception as e:
print(f"[Worker {os.getpid()}] Payment method validation error: {e}", flush=True)
# Run quick validation
validator = OCRValidationEngine()
return validator.quick_validate_for_hybrid(validation_data)
except Exception as e:
# Never crash the process on validation errors
print(f"[Worker {os.getpid()}] Cross-validation error: {e}", flush=True)
import traceback
traceback.print_exc()
# Return "passes" to allow processing to continue
return True, 0.0, [f"Validation skipped due to error: {str(e)}"]
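
The payment-method aggregation feeding the TOTAL vs CARD+CASH check can be sketched as follows. This is an assumption-laden illustration: `payment_totals` and `total_matches_payments` are hypothetical helpers, the 0.01 tolerance is invented for the example, and the actual comparison lives in `OCRValidationEngine`, which is not shown in this diff.

```python
def payment_totals(payment_methods):
    """Sum amounts per method, skipping malformed entries instead of failing."""
    totals = {"CARD": 0.0, "NUMERAR": 0.0}
    for entry in payment_methods or []:
        if not isinstance(entry, dict):
            continue
        method = entry.get("method")
        if method not in totals:
            continue
        try:
            totals[method] += float(entry.get("amount", 0))
        except (TypeError, ValueError):
            continue  # unreadable amount: ignore this entry only
    return totals


def total_matches_payments(total, totals, tolerance=0.01):
    """Hypothetical rule: pass when no payments were read, or when they sum to the total."""
    paid = totals["CARD"] + totals["NUMERAR"]
    return paid == 0 or abs(total - paid) <= tolerance


if __name__ == "__main__":
    totals = payment_totals([
        {"method": "CARD", "amount": "12.50"},
        {"method": "NUMERAR", "amount": 7.5},
    ])
    print(totals, total_matches_payments(20.0, totals))
```

Handling conversion errors per entry means one unreadable amount does not discard the amounts that were read correctly.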
def _is_extraction_valid_for_early_exit(extraction, min_confidence: float = 0.85) -> bool:
"""
Check if extraction is valid for early exit in doctr_plus.
Combines confidence check with cross-validation to prevent
early exit on OCR errors (e.g., wrong TOTAL but correct TVA).
Returns:
True only if:
1. Overall confidence >= min_confidence
2. Critical fields are present (AMOUNT, DATE, CUI)
3. Cross-validation passes (TOTAL matches TVA calculation, or no TVA)
"""
try:
# First check basic completeness (relaxed for early exit)
if not _is_extraction_complete(extraction, min_confidence, for_early_exit=True):
return False
# Then run cross-validation
passes_validation, penalty, errors = _quick_cross_validate(extraction)
if not passes_validation:
print(f"[Early Exit] BLOCKED: cross-validation failed: {errors}", flush=True)
return False
print(f"[Early Exit] OK: conf={extraction.overall_confidence:.0%}, validation passed", flush=True)
return True
except Exception as e:
# Never crash on validation - just continue to next engine
print(f"[Worker {os.getpid()}] Early exit check error: {e}", flush=True)
return False # Continue to next engine on error
def _paddle_recognize(paddle_engine, image: np.ndarray) -> Optional[OCRResult]:
"""Run PaddleOCR recognition on image."""
try:
@@ -388,34 +657,191 @@ def _paddle_recognize(paddle_engine, image: np.ndarray) -> Optional[OCRResult]:
return None
def _is_extraction_complete(ext, min_confidence: float = 0.85) -> bool:
"""Check if extraction has all required fields."""
def _doctr_recognize(doctr_engine, image: np.ndarray) -> Optional[OCRResult]:
"""
Run docTR recognition on image.
docTR requires RGB images; this function converts grayscale, BGR, and RGBA input automatically.
Uses same preprocessing as PaddleOCR for consistent results.
"""
try:
# docTR requires RGB images
if len(image.shape) == 2:
# Convert grayscale to RGB
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
elif image.shape[2] == 3:
# Convert BGR (OpenCV) to RGB
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
elif image.shape[2] == 4:
# Convert RGBA to RGB
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
# docTR expects a list of numpy arrays (pages)
result = doctr_engine([image])
if not result or not result.pages:
return OCRResult(text="", confidence=0.0, boxes=[], engine="doctr")
# Extract text from all pages
all_texts = []
all_confidences = []
boxes = []
for page in result.pages:
for block in page.blocks:
for line in block.lines:
line_text = ' '.join(word.value for word in line.words)
line_confidence = sum(w.confidence for w in line.words) / len(line.words) if line.words else 0.0
all_texts.append(line_text)
all_confidences.append(line_confidence)
# Store word-level boxes
for word in line.words:
boxes.append({
'text': word.value,
'confidence': float(word.confidence),
'box': word.geometry # (xmin, ymin), (xmax, ymax)
})
text_result = '\n'.join(all_texts)
avg_conf = sum(all_confidences) / len(all_confidences) if all_confidences else 0.0
return OCRResult(
text=text_result,
confidence=float(avg_conf),
boxes=boxes,
engine="doctr"
)
except Exception as e:
print(f"[Worker] docTR error: {e}", flush=True)
return None
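Review note: the confidence reported by `_doctr_recognize` is a mean of per-line means, not a mean over all words. The sketch below reproduces just that aggregation on plain tuples so the effect is easy to see. The function name `aggregate_lines` and the sample words are illustrative assumptions, not part of the module, and no docTR objects are involved.

```python
# Hypothetical helper mirroring the aggregation in _doctr_recognize.
# Input: list of lines, each a list of (word_text, confidence) pairs.
def aggregate_lines(lines):
    texts, line_confs = [], []
    for words in lines:
        # Join the words of one line with spaces, as the worker does
        texts.append(" ".join(text for text, _ in words))
        # Per-line confidence is the plain mean of its word confidences
        line_confs.append(sum(c for _, c in words) / len(words) if words else 0.0)
    # Overall confidence is the mean of the per-line means
    overall = sum(line_confs) / len(line_confs) if line_confs else 0.0
    return "\n".join(texts), overall


text, conf = aggregate_lines([
    [("TOTAL", 0.9), ("285.68", 0.7)],  # made-up sample words
    [("TVA", 1.0)],
])
print(text)
print(round(conf, 2))
```

With this scheme a one-word line counts as much as a ten-word line, so a word-weighted mean would generally give a different number for the same page.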
def _is_extraction_complete(ext, min_confidence: float = 0.85, for_early_exit: bool = False) -> bool:
"""
Check if extraction has required fields.
Args:
ext: Extraction result
min_confidence: Minimum overall confidence
for_early_exit: If True, use relaxed criteria (AMOUNT + DATE + CUI required)
If False, require all fields (strict mode for final validation)
Returns:
True if extraction meets completeness criteria
"""
# Check confidence first
if ext.overall_confidence < min_confidence:
if for_early_exit:
print(f"[Early Exit] BLOCKED: confidence {ext.overall_confidence:.0%} < {min_confidence:.0%}", flush=True)
return False
has_number = bool(ext.receipt_number)
has_date = bool(ext.receipt_date)
has_amount = bool(ext.amount)
has_tva = bool(ext.tva_total) or bool(ext.tva_entries)
has_cui = bool(ext.cui)
return all([has_number, has_date, has_amount, has_tva, has_cui])
if for_early_exit:
# Relaxed criteria for early exit:
# - AMOUNT is required (core field)
# - DATE is required (needed for accounting)
# - CUI is required (needed for supplier identification)
# - TVA is NOT required (some receipts have 0% TVA)
# - receipt_number is NOT required (often missing)
required_ok = all([has_amount, has_date, has_cui])
if not required_ok:
missing = []
if not has_amount: missing.append("AMOUNT")
if not has_date: missing.append("DATE")
if not has_cui: missing.append("CUI")
print(f"[Early Exit] BLOCKED: missing required fields: {', '.join(missing)}", flush=True)
return required_ok
else:
# Strict criteria for final validation (all fields required)
has_number = bool(ext.receipt_number)
return all([has_number, has_date, has_amount, has_tva, has_cui])
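Review note: a minimal sketch of the two completeness modes introduced here, assuming an extraction object that exposes the same attribute names as above. `is_complete` and the sample receipt values are hypothetical stand-ins for illustration; logging is omitted.

```python
from types import SimpleNamespace


def is_complete(ext, min_confidence=0.85, for_early_exit=False):
    # Confidence gate applies in both modes
    if ext.overall_confidence < min_confidence:
        return False
    has_date = bool(ext.receipt_date)
    has_amount = bool(ext.amount)
    has_cui = bool(ext.cui)
    if for_early_exit:
        # Relaxed mode: AMOUNT + DATE + CUI are enough
        return has_amount and has_date and has_cui
    # Strict mode: receipt number and TVA are required as well
    has_number = bool(ext.receipt_number)
    has_tva = bool(ext.tva_total) or bool(ext.tva_entries)
    return all([has_number, has_date, has_amount, has_tva, has_cui])


# Made-up receipt: no receipt number and no TVA information
receipt = SimpleNamespace(
    overall_confidence=0.9,
    receipt_number=None,
    receipt_date="2025-10-27",
    amount=285.68,
    tva_total=None,
    tva_entries=[],
    cui="RO10562600",
)
print(is_complete(receipt, for_early_exit=True))
print(is_complete(receipt))
```

The same receipt is accepted for early exit but rejected by the strict check, which is the intended difference between the two modes.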
def _merge_extractions(primary, secondary):
"""Merge two extractions, picking best fields from each."""
"""Merge two extractions, picking best fields from each.
Primary should be the higher-quality engine (e.g., docTR).
Secondary is the fallback engine (e.g., Tesseract).
Priority logic:
- AMOUNT: TVA validation wins over confidence. If both are valid, the higher
confidence wins; if both are invalid, the amount closer to the TVA-derived total wins.
- DATE/CUI: Validation-based, then confidence, then primary wins ties.
- OTHER FIELDS: Primary wins when both have values.
"""
from backend.modules.data_entry.services.ocr_extractor import ExtractionResult
result = ExtractionResult()
# Amount - prefer higher confidence
# Helper: Check if amount matches TVA calculation
def amount_passes_tva_validation(amount, tva_total, tva_entries):
if not amount or not tva_total:
return False, 0.0
try:
tva_rate = 0.21 # Default Romanian TVA
if tva_entries:
for entry in tva_entries:
if isinstance(entry, dict) and entry.get('percent'):
tva_rate = float(entry['percent']) / 100.0
break
# Expected TOTAL = TVA / rate * (1 + rate)
expected = float(tva_total) * (1 + tva_rate) / tva_rate
actual = float(amount)
diff_percent = abs(actual - expected) / expected if expected > 0 else 1.0
return diff_percent < 0.03, diff_percent # 3% tolerance
except (TypeError, ValueError, ZeroDivisionError):
return False, 1.0
# Amount - prefer TVA-validated value over confidence
if primary.amount and secondary.amount:
if primary.confidence_amount >= secondary.confidence_amount:
result.amount = primary.amount
result.confidence_amount = primary.confidence_amount
else:
# Get TVA from the one with entries, or use any available
tva_total = primary.tva_total or secondary.tva_total
tva_entries = primary.tva_entries or secondary.tva_entries
primary_valid, primary_diff = amount_passes_tva_validation(
primary.amount, tva_total, tva_entries
)
secondary_valid, secondary_diff = amount_passes_tva_validation(
secondary.amount, tva_total, tva_entries
)
print(f"[Merge] Amount comparison: primary={primary.amount} (valid={primary_valid}, diff={primary_diff:.1%}), "
f"secondary={secondary.amount} (valid={secondary_valid}, diff={secondary_diff:.1%})", flush=True)
if secondary_valid and not primary_valid:
# Secondary passes validation, primary doesn't - use secondary!
print(f"[Merge] Using secondary amount {secondary.amount} (passes TVA validation)", flush=True)
result.amount = secondary.amount
result.confidence_amount = secondary.confidence_amount
elif primary_valid and not secondary_valid:
# Primary passes validation
result.amount = primary.amount
result.confidence_amount = primary.confidence_amount
elif primary_valid and secondary_valid:
# Both valid - use higher confidence
if primary.confidence_amount >= secondary.confidence_amount:
result.amount = primary.amount
result.confidence_amount = primary.confidence_amount
else:
result.amount = secondary.amount
result.confidence_amount = secondary.confidence_amount
else:
# Neither valid - use the one closer to TVA calculation
if secondary_diff < primary_diff:
print(f"[Merge] Neither valid, using secondary {secondary.amount} (closer to TVA)", flush=True)
result.amount = secondary.amount
result.confidence_amount = secondary.confidence_amount
else:
result.amount = primary.amount
result.confidence_amount = primary.confidence_amount
elif primary.amount:
result.amount = primary.amount
result.confidence_amount = primary.confidence_amount
@@ -438,13 +864,15 @@ def _merge_extractions(primary, secondary):
result.receipt_date = secondary.receipt_date
result.confidence_date = secondary.confidence_date
# CUI - prefer valid format
# CUI - prefer valid format and version with RO prefix
# Use CUIChecksumRule static methods (single source of truth)
from backend.modules.data_entry.services.ocr.validation import CUIChecksumRule
def is_valid_cui(cui):
if not cui:
return False
import re
cui_clean = re.sub(r'^RO', '', cui.upper())
return bool(re.match(r'^\d{6,10}$', cui_clean))
digits = CUIChecksumRule.extract_digits(cui)
return 6 <= len(digits) <= 10
if primary.cui and secondary.cui:
if is_valid_cui(primary.cui) and not is_valid_cui(secondary.cui):
@@ -452,22 +880,27 @@ def _merge_extractions(primary, secondary):
elif is_valid_cui(secondary.cui) and not is_valid_cui(primary.cui):
result.cui = secondary.cui
else:
result.cui = primary.cui
# Both valid - prefer the one with RO prefix if digits match
primary_digits = CUIChecksumRule.extract_digits(primary.cui)
secondary_digits = CUIChecksumRule.extract_digits(secondary.cui)
if primary_digits == secondary_digits:
if CUIChecksumRule.has_ro_prefix(secondary.cui) and not CUIChecksumRule.has_ro_prefix(primary.cui):
result.cui = secondary.cui # Prefer version with RO
print(f"[CUI Complement] Preferring secondary with RO: {secondary.cui}", flush=True)
else:
result.cui = primary.cui
else:
result.cui = primary.cui
elif primary.cui:
result.cui = primary.cui
elif secondary.cui:
result.cui = secondary.cui
# TVA entries
# TVA entries - ALWAYS prefer primary (docTR) when both have entries
if primary.tva_entries and secondary.tva_entries:
primary_total = sum(e.get('amount', Decimal('0')) for e in primary.tva_entries)
secondary_total = sum(e.get('amount', Decimal('0')) for e in secondary.tva_entries)
if primary_total >= secondary_total:
result.tva_entries = primary.tva_entries
result.tva_total = primary.tva_total
else:
result.tva_entries = secondary.tva_entries
result.tva_total = secondary.tva_total
# Always use primary (docTR) - higher quality OCR
result.tva_entries = primary.tva_entries
result.tva_total = primary.tva_total
elif primary.tva_entries:
result.tva_entries = primary.tva_entries
result.tva_total = primary.tva_total
@@ -483,12 +916,36 @@ def _merge_extractions(primary, secondary):
result.address = primary.address or secondary.address
result.items_count = primary.items_count or secondary.items_count
result.payment_methods = primary.payment_methods or secondary.payment_methods
result.suggested_payment_mode = getattr(primary, 'suggested_payment_mode', None) or getattr(secondary, 'suggested_payment_mode', None)
# Client fields
result.client_name = primary.client_name or secondary.client_name
result.client_cui = primary.client_cui or secondary.client_cui
result.client_address = primary.client_address or secondary.client_address
# Confidence fields - preserve from primary or pick best
if primary.confidence_vendor >= secondary.confidence_vendor:
result.confidence_vendor = primary.confidence_vendor
else:
result.confidence_vendor = secondary.confidence_vendor
if hasattr(primary, 'confidence_client') and hasattr(secondary, 'confidence_client'):
if primary.confidence_client >= secondary.confidence_client:
result.confidence_client = primary.confidence_client
else:
result.confidence_client = secondary.confidence_client
# Raw text - combine both for debugging/display
raw_texts = []
if primary.raw_text:
raw_texts.append(primary.raw_text)
if secondary.raw_text and secondary.raw_text != primary.raw_text:
raw_texts.append(secondary.raw_text)
result.raw_text = '\n---\n'.join(raw_texts) if raw_texts else ''
# Note: overall_confidence is a computed @property on ExtractionResult
# It automatically calculates from confidence_amount, confidence_date, confidence_vendor
return result
@@ -557,6 +1014,7 @@ def _extraction_to_dict(extraction) -> dict:
"address": extraction.address,
"items_count": extraction.items_count,
"payment_methods": extraction.payment_methods,
"suggested_payment_mode": getattr(extraction, 'suggested_payment_mode', None),
# Client data
"client_name": extraction.client_name,
"client_cui": extraction.client_cui,


@@ -385,8 +385,81 @@ class CUIChecksumRule(ValidationRule):
result = rule.validate({"cui": "R01879855"})
# result.is_valid = False (checksum mismatch)
Static methods available for direct use:
CUIChecksumRule.calculate_checksum("1056260") -> 0
CUIChecksumRule.validate_checksum("10562600") -> True
CUIChecksumRule.has_ro_prefix("RO10562600") -> True
"""
# Fixed multipliers for 9 positions (Romanian Mod 11)
MULTIPLIERS = [7, 5, 3, 2, 1, 7, 5, 3, 2]
@staticmethod
def calculate_checksum(cui_base: str) -> int:
"""Calculate expected CUI checksum using Romanian Mod 11 algorithm.
Args:
cui_base: CUI digits WITHOUT the checksum digit (last digit)
Returns:
Expected checksum digit (0-9), or -1 if invalid input
"""
if not cui_base or not cui_base.isdigit():
return -1
# Pad base to 9 digits from LEFT
base_padded = cui_base.zfill(9)
base_digits = [int(d) for d in base_padded]
# Calculate weighted sum
weighted_sum = sum(d * m for d, m in zip(base_digits, CUIChecksumRule.MULTIPLIERS))
# Calculate checksum
checksum = (weighted_sum * 10) % 11
if checksum == 10:
checksum = 0
return checksum
@staticmethod
def validate_checksum(cui_digits: str) -> bool:
"""Check if CUI checksum is valid.
Args:
cui_digits: Full CUI digits (including checksum as last digit)
Returns:
True if checksum is valid, False otherwise
"""
if not cui_digits or len(cui_digits) < 6 or not cui_digits.isdigit():
return False
base = cui_digits[:-1]
declared = int(cui_digits[-1])
expected = CUIChecksumRule.calculate_checksum(base)
return expected == declared
@staticmethod
def has_ro_prefix(cui: str) -> bool:
"""Check if CUI has RO prefix (proper format for VAT payers)."""
if not cui:
return False
return cui.upper().strip().startswith('RO')
@staticmethod
def extract_digits(cui: str) -> str:
"""Extract digits from CUI, removing RO/R0 prefix."""
if not cui:
return ""
cui = cui.strip().upper()
if cui.startswith("RO"):
cui = cui[2:]
elif cui.startswith("R0"): # Fix OCR error R0 → RO
cui = cui[2:]
return ''.join(c for c in cui if c.isdigit())
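Review note: a standalone sketch of the Mod 11 check implemented by the static methods above, handy for verifying the docstring examples by hand. The weights are copied from `MULTIPLIERS`; the function names `cui_checksum` and `cui_is_valid` are illustrative, and the upper length bounds are an assumption added here (the methods above only check the lower bound).

```python
# Weights copied from CUIChecksumRule.MULTIPLIERS in the diff
MULTIPLIERS = [7, 5, 3, 2, 1, 7, 5, 3, 2]


def cui_checksum(base: str) -> int:
    """Expected check digit for CUI digits given WITHOUT the last digit."""
    # len(base) > 9 guard is an assumption of this sketch
    if not base or not base.isdigit() or len(base) > 9:
        return -1
    padded = base.zfill(9)  # left-pad to 9 digits so the weights line up
    weighted = sum(int(d) * m for d, m in zip(padded, MULTIPLIERS))
    check = (weighted * 10) % 11
    return 0 if check == 10 else check


def cui_is_valid(digits: str) -> bool:
    """True if the last digit matches the computed check digit."""
    if not digits or not digits.isdigit() or not 6 <= len(digits) <= 10:
        return False
    return cui_checksum(digits[:-1]) == int(digits[-1])


print(cui_checksum("1056260"))   # 0 (weighted sum 78; 780 % 11 == 10, mapped to 0)
print(cui_is_valid("10562600"))  # True
print(cui_is_valid("10862600"))  # False (expected check digit 7, found 0)
```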
@property
def rule_name(self) -> str:
return "CUI Checksum Check (Mod 11)"
@@ -400,15 +473,11 @@ class CUIChecksumRule(ValidationRule):
message="No CUI to validate"
)
# Normalize: remove RO/R0 prefix
cui_clean = cui.strip().upper()
if cui_clean.startswith("RO"):
cui_clean = cui_clean[2:]
elif cui_clean.startswith("R0"):
cui_clean = cui_clean[2:]
# Use static method to extract digits
cui_clean = CUIChecksumRule.extract_digits(cui)
# Check format first
if not cui_clean.isdigit():
if not cui_clean:
return ValidationResult(
is_valid=True, # Don't fail checksum if format invalid (handled by CUIFormatRule)
message="CUI format invalid, skipping checksum"
@@ -420,28 +489,15 @@ class CUIChecksumRule(ValidationRule):
message="CUI length invalid, skipping checksum"
)
# Extract digits
digits = [int(d) for d in cui_clean]
checksum_declared = digits[-1]
base_digits = digits[:-1]
# Multipliers (trim to match base_digits length)
multipliers = [7, 5, 3, 2, 1, 7, 5, 3, 2]
multipliers = multipliers[:len(base_digits)]
# Calculate weighted sum
weighted_sum = sum(d * m for d, m in zip(base_digits, multipliers))
# Calculate expected checksum
checksum_calculated = (weighted_sum * 10) % 11
if checksum_calculated == 10:
checksum_calculated = 0
if checksum_calculated != checksum_declared:
# Use static method to validate checksum
if not CUIChecksumRule.validate_checksum(cui_clean):
# Calculate expected for error message
expected = CUIChecksumRule.calculate_checksum(cui_clean[:-1])
declared = int(cui_clean[-1])
return ValidationResult(
is_valid=False,
confidence_penalty=0.3,
message=f"CUI '{cui}' checksum mismatch: expected {checksum_calculated}, got {checksum_declared}",
message=f"CUI '{cui}' checksum mismatch: expected {expected}, got {declared}",
severity="warning"
)
@@ -451,6 +507,129 @@ class CUIChecksumRule(ValidationRule):
)
class TVABasedTotalRule(ValidationRule):
"""Validate TOTAL using reverse calculation from TVA amount.
This is a CRITICAL validation that catches cases where OCR extracts
wrong TOTAL but correct TVA. Since TVA = BASE * rate and TOTAL = BASE + TVA,
we can calculate expected TOTAL from TVA alone.
Formula:
Expected TOTAL = TVA / rate * (1 + rate)
Or equivalently: Expected TOTAL = TVA * (1 + rate) / rate
For TVA rate 21%:
Expected TOTAL = TVA / 0.21 * 1.21 = TVA * 5.7619
Example (benzina 27 oct):
TVA = 49.58, rate = 21%
Expected TOTAL = 49.58 / 0.21 * 1.21 = 285.68
Extracted TOTAL = 205.66 (WRONG!)
Rule detects mismatch and flags for escalation
Usage in multi-tier processing (e.g., doctr_plus):
If this rule fails, the engine should proceed to next tier
instead of returning early with potentially wrong data.
"""
def __init__(self, tolerance_percent: float = 0.02):
"""
Args:
tolerance_percent: Allowed difference as percentage (0.02 = 2%)
"""
self.tolerance_percent = tolerance_percent
@property
def rule_name(self) -> str:
return "TVA-Based Total Check"
def validate(self, data: dict[str, Any]) -> ValidationResult:
total = data.get("amount")
tva = data.get("tva")
tva_entries = data.get("tva_entries", [])
if not total or not tva:
return ValidationResult(
is_valid=True,
message="Insufficient data for TVA-based total validation"
)
# Type safety
try:
total = float(total)
tva = float(tva)
except (TypeError, ValueError):
return ValidationResult(
is_valid=True,
message="Non-numeric values, skipping TVA-based total validation"
)
if tva <= 0 or total <= 0:
return ValidationResult(
is_valid=True,
message="Zero or negative values, skipping TVA-based total validation"
)
# Try to determine TVA rate from entries
tva_rate = None
# Check tva_entries for rate information
if tva_entries:
for entry in tva_entries:
if isinstance(entry, dict):
percent = entry.get('percent')
if percent:
try:
tva_rate = float(percent) / 100.0
break
except (TypeError, ValueError):
pass
# Fallback: try to calculate rate from TVA/TOTAL ratio
if not tva_rate:
# TVA = BASE * rate, TOTAL = BASE + TVA = BASE * (1 + rate)
# TVA/TOTAL = rate / (1 + rate)
# So rate = TVA / (TOTAL - TVA) = TVA / BASE
base = total - tva
if base > 0:
calculated_rate = tva / base
# Validate it's a reasonable Romanian TVA rate (5%, 9%, 19%, 21%)
if 0.04 <= calculated_rate <= 0.25:
tva_rate = calculated_rate
if not tva_rate:
# Assume most common rate: 21%
tva_rate = 0.21
# Calculate expected TOTAL from TVA
# TVA = BASE * rate → BASE = TVA / rate
# TOTAL = BASE + TVA = (TVA / rate) + TVA = TVA * (1 + 1/rate) = TVA * (1 + rate) / rate
expected_total = tva * (1 + tva_rate) / tva_rate
# Calculate difference
diff = abs(total - expected_total)
diff_percent = diff / expected_total if expected_total > 0 else 1.0
if diff_percent > self.tolerance_percent:
# Significant mismatch - OCR likely extracted TOTAL wrong
return ValidationResult(
is_valid=False,
confidence_penalty=0.5, # High penalty - this is a critical error
message=(
f"TOTAL mismatch: Extracted {total:.2f} RON vs "
f"TVA-calculated {expected_total:.2f} RON "
f"(TVA={tva:.2f}, rate={tva_rate:.0%}, diff={diff_percent:.1%}). "
f"Likely OCR error on TOTAL."
),
severity="error"
)
return ValidationResult(
is_valid=True,
message=f"TOTAL {total:.2f} matches TVA-calculated {expected_total:.2f} (diff: {diff_percent:.1%})"
)
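Review note: a small numeric sketch of the reverse calculation used by `TVABasedTotalRule`, using the figures from its docstring. The helper names `expected_total_from_tva` and `total_matches_tva` are illustrative and not part of the module.

```python
def expected_total_from_tva(tva: float, rate: float) -> float:
    # TVA = BASE * rate and TOTAL = BASE + TVA, so TOTAL = TVA * (1 + rate) / rate
    return tva * (1 + rate) / rate


def total_matches_tva(total: float, tva: float, rate: float = 0.21,
                      tolerance: float = 0.02) -> bool:
    expected = expected_total_from_tva(tva, rate)
    return abs(total - expected) / expected <= tolerance


print(round(expected_total_from_tva(49.58, 0.21), 2))  # 285.68
print(total_matches_tva(205.66, 49.58))  # False: about 28% off
print(total_matches_tva(285.68, 49.58))  # True

# If the rate is inferred from the same TOTAL and TVA, the expected total
# equals the extracted total by construction.
total, tva = 120.0, 20.0
inferred_rate = tva / (total - tva)
print(abs(expected_total_from_tva(tva, inferred_rate) - total) < 1e-9)  # True
```

One consequence worth keeping in mind: whenever the fallback branch accepts a rate inferred from TVA / (TOTAL - TVA), the comparison cannot fail, so the rule is only informative when the rate comes from the TVA entries or from the 21% default.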
class InterOCRConsistencyRule(ValidationRule):
"""Validate consistency between multiple OCR results.
@@ -562,6 +741,7 @@ class OCRValidationEngine:
TVARatioRule(min_ratio=0.05, max_ratio=0.24),
PaymentSumRule(tolerance=0.02),
TVAEntriesSumRule(tolerance=0.02),
TVABasedTotalRule(tolerance_percent=0.02), # Critical: detect TOTAL errors via TVA
]
# Inter-OCR consistency rules
@@ -699,39 +879,508 @@ class OCRValidationEngine:
inter_ocr_ratios=inter_ocr_ratios
)
def quick_validate_for_hybrid(self, extraction_result: dict[str, Any]) -> tuple[bool, float, list[str]]:
"""Quick validation for early-exit decisions (e.g., doctr_plus Tier 1).
Runs critical cross-validation rules to detect obvious OCR errors.
Used to decide whether to proceed to next processing tier or exit early.
Args:
extraction_result: Extraction data dict with fields:
- amount: Extracted TOTAL
- tva: Extracted TVA total
- tva_entries: List of TVA entries with rates
Returns:
Tuple of (passes_validation, confidence_penalty, error_messages)
- passes_validation: True if no critical errors detected
- confidence_penalty: Cumulative penalty (0.0-1.0)
- error_messages: List of validation error messages
Example usage:
passes, penalty, errors = validation_engine.quick_validate_for_hybrid(extraction_data)
if not passes:
print(f"Validation failed: {errors}, proceeding to next tier")
# Continue to next processing tier instead of early exit
"""
errors = []
total_penalty = 0.0
# Critical rules for early-exit decision-making
# These determine if we can trust the extraction or need to proceed to next tier
critical_rules = [
# Cross-field validations (most important for detecting OCR errors)
TVABasedTotalRule(tolerance_percent=0.02), # Critical: detect TOTAL errors via TVA calculation
PaymentSumRule(tolerance=0.05), # Cross-validate TOTAL vs CARD+CASH payments
TVARatioRule(min_ratio=0.05, max_ratio=0.24), # TVA should be 5-24% of TOTAL
TVAEntriesSumRule(tolerance=0.05), # Sum of TVA entries should match TVA total
# Format & checksum validations
CUIChecksumRule(), # Validate CUI/CIF with Romanian Mod11 checksum algorithm
CUIFormatRule(), # CUI should be 6-10 digits
# Sanity checks
AmountRangeRule(min_amount=0.01, max_amount=100_000.0), # Reasonable amount range
]
for rule in critical_rules:
result = rule.validate(extraction_result)
if not result.is_valid:
errors.append(result.message)
total_penalty += result.confidence_penalty
# Cap penalty at 1.0
total_penalty = min(1.0, total_penalty)
passes = len(errors) == 0
return passes, total_penalty, errors
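Review note: the aggregation in `quick_validate_for_hybrid` can be sketched independently of the real rule classes. Below, rules are plain callables and `RuleResult`, `failing`, `passing`, and `quick_validate` are hypothetical stand-ins; the penalties 0.5 and 0.3 match the values used by the TVA-based and CUI checksum rules in this diff.

```python
from dataclasses import dataclass


@dataclass
class RuleResult:
    is_valid: bool
    message: str = ""
    confidence_penalty: float = 0.0


def failing(message, penalty):
    # Toy rule that always fails with the given message and penalty
    return lambda data: RuleResult(False, message, penalty)


def passing():
    # Toy rule that always passes
    return lambda data: RuleResult(True)


def quick_validate(data, rules):
    errors, penalty = [], 0.0
    for rule in rules:
        result = rule(data)
        if not result.is_valid:
            errors.append(result.message)
            penalty += result.confidence_penalty
    # Any single failure blocks early exit; the penalty is capped at 1.0
    return not errors, min(1.0, penalty), errors


ok, penalty, errors = quick_validate({}, [
    failing("TOTAL mismatch", 0.5),
    passing(),
    failing("CUI checksum mismatch", 0.3),
])
print(ok, round(penalty, 2), errors)
```

Because one failed rule is enough to return `passes=False`, the penalty value matters only to callers that keep processing and want to discount the final confidence.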
# NOTE: _calculate_cui_checksum and _is_cui_checksum_valid removed
# Use CUIChecksumRule.calculate_checksum() and CUIChecksumRule.validate_checksum() instead
@staticmethod
def _repair_cui_checksum(cui_digits: str) -> Optional[str]:
"""Try to repair CUI by attempting 1-digit corrections.
OCR often misreads similar-looking digits:
- 5 ↔ 8 (most common in receipts)
- 6 ↔ 0
- 1 ↔ 7
- 3 ↔ 8
Algorithm:
1. Check middle positions first (2,3,4,5...) - OCR errors more common there
2. Then check position 1, then the checksum digit (last position)
3. Check first digit (position 0) last - usually reliable in CUI
4. At each position, try common OCR digit confusions first (5↔8, 6↔0)
Args:
cui_digits: Original CUI digits (without RO prefix)
Returns:
Repaired CUI digits if 1-digit fix found, else None
"""
if len(cui_digits) < 6 or not cui_digits.isdigit():
return None
# If already valid, return as-is
if CUIChecksumRule.validate_checksum(cui_digits):
return cui_digits
# Common OCR digit confusions (try these first)
confusion_pairs = {
'5': ['8', '6'], # 5 often misread as 8 or 6
'8': ['5', '3', '0'], # 8 often misread as 5, 3, or 0
'6': ['0', '8'], # 6 often misread as 0 or 8
'0': ['6', '8'], # 0 often misread as 6 or 8
'1': ['7', '4'], # 1 often misread as 7 or 4
'7': ['1'], # 7 often misread as 1
'3': ['8'], # 3 often misread as 8
'4': ['1'], # 4 often misread as 1
'2': ['7'], # 2 sometimes misread as 7
'9': ['0'], # 9 sometimes misread as 0
}
n = len(cui_digits)
last_pos = n - 1 # checksum position
# Position check order: middle positions first, then position 1, then checksum, then 0
# Position 0 (first digit) is checked last - it's usually reliable
# Example for 8-digit CUI: [2,3,4,5,6, 1, 7(checksum), 0]
middle_positions = list(range(2, last_pos)) # positions 2 to n-2
position_order = middle_positions + [1, last_pos, 0] # check pos 0 last (rarely wrong)
for pos in position_order:
if pos >= n:
continue
original_digit = cui_digits[pos]
# Try common confusions first for this digit
candidates = confusion_pairs.get(original_digit, [])
# Then try all other digits
all_digits = [d for d in '0123456789' if d != original_digit and d not in candidates]
for replacement in candidates + all_digits:
candidate = cui_digits[:pos] + replacement + cui_digits[pos+1:]
if CUIChecksumRule.validate_checksum(candidate):
print(f"[CUI Repair] Fixed {cui_digits} → {candidate} (position {pos}: {original_digit}→{replacement})", flush=True)
return candidate
# No single-digit fix found
return None
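Review note: a single-digit repair is not unique, which is why the first hit depends on the search order. The sketch below enumerates every single-digit substitution that passes the checksum, in the same position order as `_repair_cui_checksum`. All names are illustrative, and `CONFUSIONS` is a shortened subset of the table above, so the ordering of later candidates may differ from the real function.

```python
MULTIPLIERS = [7, 5, 3, 2, 1, 7, 5, 3, 2]
# Subset of the confusion table from the diff (assumption: shortened for brevity)
CONFUSIONS = {"5": "86", "8": "530", "6": "08", "0": "68"}


def is_valid(digits: str) -> bool:
    if not digits.isdigit() or not 6 <= len(digits) <= 10:
        return False
    padded = digits[:-1].zfill(9)
    weighted = sum(int(d) * m for d, m in zip(padded, MULTIPLIERS))
    check = (weighted * 10) % 11
    return (0 if check == 10 else check) == int(digits[-1])


def candidate_repairs(digits: str):
    """Yield valid single-digit substitutions in repair order (expects 6-10 digits)."""
    n = len(digits)
    # Middle positions first, then position 1, the checksum digit, and position 0
    order = list(range(2, n - 1)) + [1, n - 1, 0]
    for pos in order:
        original = digits[pos]
        preferred = list(CONFUSIONS.get(original, ""))
        others = [d for d in "0123456789" if d != original and d not in preferred]
        for replacement in preferred + others:
            candidate = digits[:pos] + replacement + digits[pos + 1:]
            if is_valid(candidate):
                yield candidate


candidates = list(candidate_repairs("10862600"))
print(candidates[0])             # 10562600 (8 -> 5 at position 2)
print("10862607" in candidates)  # True: changing only the check digit also validates
```

Since several candidates pass the checksum, a repair accepted without further evidence can produce a valid but wrong CUI; confirming the candidate against known suppliers, as `fuzzy_match_cui_from_db` does, narrows that risk.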
@staticmethod
def normalize_cui(cui: Optional[str]) -> Optional[str]:
"""Normalize CUI to RO prefix + digits format.
"""Normalize CUI - fix OCR errors but preserve original format.
Rules:
- R0 → RO (fix OCR error where O is read as 0)
- Keep RO prefix if original had it (platitor TVA)
- Do NOT add RO if original didn't have it (neplatitor TVA)
- Try to repair 1-digit checksum errors (OCR mistakes like 5↔8)
Examples:
10562600 → RO10562600
45417955 → 45417955 (no prefix = neplatitor TVA, keep as-is)
R010562600 → RO10562600 (fix R0 OCR error)
RO10562600 → RO10562600 (unchanged)
RO10862600 → RO10562600 (repaired: 8→5 at position 2)
Args:
cui: Raw CUI string from OCR
Returns:
Normalized CUI with RO prefix, or None if invalid
Normalized CUI, or None if invalid
"""
if not cui:
return None
cui = cui.strip().upper()
# Remove existing prefix if present
# Check if original had RO/R0 prefix
had_ro_prefix = cui.startswith("RO") or cui.startswith("R0")
# Extract digits
if cui.startswith("RO"):
cui = cui[2:]
elif cui.startswith("R0"):
cui = cui[2:]
cui_digits = cui[2:]
elif cui.startswith("R0"): # Fix OCR error R0 → RO
cui_digits = cui[2:]
else:
cui_digits = cui
# Remove any non-digit characters
cui_digits = ''.join(c for c in cui if c.isdigit())
cui_digits = ''.join(c for c in cui_digits if c.isdigit())
# Validate length
if len(cui_digits) < 6 or len(cui_digits) > 10:
print(f"[CUI Normalize] Invalid length: {len(cui_digits)} digits (expected 6-10)", flush=True)
return None
# Add RO prefix
return f"RO{cui_digits}"
# Try to repair checksum if invalid
if not CUIChecksumRule.validate_checksum(cui_digits):
repaired = OCRValidationEngine._repair_cui_checksum(cui_digits)
if repaired:
cui_digits = repaired
# Return with RO prefix only if original had it
if had_ro_prefix:
return f"RO{cui_digits}"
else:
return cui_digits
@staticmethod
async def fuzzy_match_cui_from_db(
cui: Optional[str],
db_session
) -> Optional[tuple[str, str]]:
"""Fuzzy match CUI against database of known suppliers.
This function:
1. Validates CUI checksum
2. If valid, looks up in database (exact match)
3. If invalid, tries 1-digit corrections and looks up each candidate
4. Returns the first match found in database
Args:
cui: Extracted CUI from OCR (may be invalid)
db_session: SQLAlchemy async session for database lookups
Returns:
Tuple of (corrected_cui, supplier_name) if found, else None
Usage in OCR extraction:
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
match = await OCRValidationEngine.fuzzy_match_cui_from_db(extracted_cui, session)
if match:
corrected_cui, supplier_name = match
"""
from sqlalchemy import select, or_
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier
if not cui:
return None
cui = cui.strip().upper()
# Check if original had RO/R0 prefix
had_ro_prefix = cui.startswith("RO") or cui.startswith("R0")
# Extract digits
if cui.startswith("RO"):
cui_digits = cui[2:]
elif cui.startswith("R0"): # Fix OCR error R0 → RO
cui_digits = cui[2:]
else:
cui_digits = cui
# Remove any non-digit characters
cui_digits = ''.join(c for c in cui_digits if c.isdigit())
# Validate length
if len(cui_digits) < 6 or len(cui_digits) > 10:
return None
# Helper to format CUI with optional RO prefix
def format_cui(digits: str) -> str:
if had_ro_prefix:
return f"RO{digits}"
return digits
# Helper to search database for CUI
async def lookup_cui_in_db(digits: str) -> Optional[tuple[str, str]]:
"""Search both synced and local suppliers for CUI."""
# Search patterns: with and without RO prefix
search_patterns = [digits, f"RO{digits}"]
# Search synced_suppliers first (more data)
stmt = select(SyncedSupplier.fiscal_code, SyncedSupplier.name).where(
or_(
SyncedSupplier.fiscal_code == digits,
SyncedSupplier.fiscal_code == f"RO{digits}",
SyncedSupplier.fiscal_code == digits.lstrip('0'), # Handle leading zeros
)
).limit(1)
result = await db_session.execute(stmt)
row = result.first()
if row:
return (format_cui(digits), row.name)
# Search local_suppliers
stmt = select(LocalSupplier.fiscal_code, LocalSupplier.name).where(
or_(
LocalSupplier.fiscal_code == digits,
LocalSupplier.fiscal_code == f"RO{digits}",
LocalSupplier.fiscal_code == digits.lstrip('0'),
)
).limit(1)
result = await db_session.execute(stmt)
row = result.first()
if row:
return (format_cui(digits), row.name)
return None
# 1. If checksum is valid, check if it exists in database (exact match)
if CUIChecksumRule.validate_checksum(cui_digits):
match = await lookup_cui_in_db(cui_digits)
if match:
print(f"[Fuzzy CUI] Exact match found: {cui} → {match[0]} ({match[1]})", flush=True)
return match
# Valid checksum but not in DB - report no match (it might be a new supplier)
return None
# 2. Invalid checksum - try 1-digit corrections and verify against database
print(f"[Fuzzy CUI] Invalid checksum for {cui}, trying corrections...", flush=True)
# Common OCR digit confusions (try these first)
confusion_pairs = {
'5': ['8', '6'], # 5 often misread as 8 or 6
'8': ['5', '3', '0'], # 8 often misread as 5, 3, or 0
'6': ['0', '8'], # 6 often misread as 0 or 8
'0': ['6', '8'], # 0 often misread as 6 or 8
'1': ['7', '4'], # 1 often misread as 7 or 4
'7': ['1'], # 7 often misread as 1
'3': ['8'], # 3 often misread as 8
'4': ['1'], # 4 often misread as 1
'2': ['7'], # 2 sometimes misread as 7
'9': ['0'], # 9 sometimes misread as 0
}
n = len(cui_digits)
last_pos = n - 1 # checksum position
# Position check order: middle positions first, then ends
middle_positions = list(range(2, last_pos))
position_order = middle_positions + [1, last_pos, 0]
for pos in position_order:
if pos >= n:
continue
original_digit = cui_digits[pos]
# Try common confusions first for this digit
candidates = confusion_pairs.get(original_digit, [])
# Then try all other digits
all_digits = [d for d in '0123456789' if d != original_digit and d not in candidates]
for replacement in candidates + all_digits:
candidate = cui_digits[:pos] + replacement + cui_digits[pos+1:]
# Only consider if checksum is valid
if not CUIChecksumRule.validate_checksum(candidate):
continue
# Check if this corrected CUI exists in database
match = await lookup_cui_in_db(candidate)
if match:
print(f"[Fuzzy CUI] DB match: {cui} → {match[0]} ({match[1]}) [pos {pos}: {original_digit}→{replacement}]", flush=True)
return match
# No match found in database
print(f"[Fuzzy CUI] No database match found for {cui}", flush=True)
return None
@staticmethod
async def fuzzy_match_by_name_and_cui(
vendor_name: Optional[str],
cui: Optional[str],
db_session
) -> Optional[tuple[str, str]]:
"""Fuzzy match supplier by NAME, then narrow down by CUI.
Algorithm:
1. Normalize vendor name (remove S.R.L., S.A., punctuation, etc.)
2. Search suppliers by fuzzy name match (LIKE %name%)
3. If multiple results, use fuzzy CUI matching to pick best one
4. Return the best match
Args:
vendor_name: Extracted vendor name from OCR
cui: Extracted CUI from OCR (may be invalid/incomplete)
db_session: SQLAlchemy async session
Returns:
Tuple of (matched_cui, supplier_name) if found, else None
"""
from sqlalchemy import select, or_, func
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier
import re
if not vendor_name or len(vendor_name) < 3:
return None
# Normalize vendor name for search
def normalize_name(name: str) -> str:
    """Normalize name for fuzzy matching."""
    name = name.upper()
    # Remove company type suffixes as whole tokens only, so that names
    # such as "SALAM" or "SCALA" are not mangled by substring removal
    name = re.sub(r'\b(?:S\.?R\.?L|S\.?A|S\.?C|I\.?F|P\.?F\.?A)\b\.?', ' ', name)
    # Remove punctuation and extra spaces
    name = re.sub(r'[.,\-_/\\()"\']', ' ', name)
    return ' '.join(name.split())
# Extract key words from vendor name (for fuzzy search)
normalized_name = normalize_name(vendor_name)
name_words = [w for w in normalized_name.split() if len(w) >= 3]
if not name_words:
return None
print(f"[Fuzzy Name] Searching for vendor: '{vendor_name}' → keywords: {name_words}", flush=True)
# Build search pattern - use first significant word
primary_word = name_words[0]
search_pattern = f"%{primary_word}%"
candidates = []
# Search synced_suppliers, then local_suppliers (up to 20 rows each)
for model in (SyncedSupplier, LocalSupplier):
    stmt = select(model.fiscal_code, model.name).where(
        func.upper(model.name).like(search_pattern)
    ).limit(20)
    result = await db_session.execute(stmt)
    for row in result:
        if row.fiscal_code:
            candidates.append((row.fiscal_code, row.name))
if not candidates:
print(f"[Fuzzy Name] No name matches found for '{primary_word}'", flush=True)
return None
print(f"[Fuzzy Name] Found {len(candidates)} name matches for '{primary_word}'", flush=True)
# If only one candidate, return it
if len(candidates) == 1:
print(f"[Fuzzy Name] Single match: {candidates[0][0]} ({candidates[0][1]})", flush=True)
return candidates[0]
# Multiple candidates - try to narrow down by CUI
if cui:
cui_digits = ''.join(c for c in cui.upper().replace('RO', '').replace('R0', '') if c.isdigit())
if len(cui_digits) >= 6:
# Score each candidate by how similar their CUI is to the extracted one
def cui_similarity(candidate_cui: str) -> int:
"""Calculate how many digits match in the same position."""
cand_digits = ''.join(c for c in candidate_cui.upper().replace('RO', '') if c.isdigit())
if len(cand_digits) != len(cui_digits):
return 0
return sum(1 for a, b in zip(cand_digits, cui_digits) if a == b)
# Sort candidates by CUI similarity (descending)
scored = [(cui_similarity(c[0]), c) for c in candidates]
scored.sort(key=lambda x: x[0], reverse=True)
best_score, best_match = scored[0]
# Require at least 70% digit match for CUI similarity
min_matching = int(len(cui_digits) * 0.7)
if best_score >= min_matching:
print(f"[Fuzzy Name] Best CUI match: {best_match[0]} ({best_match[1]}) - score {best_score}/{len(cui_digits)}", flush=True)
return best_match
print(f"[Fuzzy Name] No strong CUI match (best score: {best_score}/{len(cui_digits)})", flush=True)
# If still multiple and no CUI match, try name similarity
def name_similarity(candidate_name: str) -> int:
"""Count how many keywords match."""
norm_cand = normalize_name(candidate_name)
return sum(1 for w in name_words if w in norm_cand)
scored = [(name_similarity(c[1]), c) for c in candidates]
scored.sort(key=lambda x: x[0], reverse=True)
if scored[0][0] >= 2: # At least 2 keywords match
best_match = scored[0][1]
print(f"[Fuzzy Name] Best name match: {best_match[0]} ({best_match[1]})", flush=True)
return best_match
# Return first candidate if nothing else works
print(f"[Fuzzy Name] Returning first candidate: {candidates[0][0]} ({candidates[0][1]})", flush=True)
return candidates[0]
@staticmethod
async def fuzzy_match_supplier(
cui: Optional[str],
vendor_name: Optional[str],
db_session
) -> Optional[tuple[str, str]]:
"""Combined fuzzy matching: try CUI first, then fallback to NAME+CUI.
Strategy:
1. Try fuzzy CUI matching (1-digit corrections with checksum validation)
2. If no CUI match, try fuzzy NAME matching, narrowed by CUI similarity
Args:
cui: Extracted CUI from OCR (may be invalid/incomplete)
vendor_name: Extracted vendor name from OCR
db_session: SQLAlchemy async session
Returns:
Tuple of (matched_cui, supplier_name) if found, else None
"""
# Step 1: Try fuzzy CUI matching
cui_match = await OCRValidationEngine.fuzzy_match_cui_from_db(cui, db_session)
if cui_match:
return cui_match
# Step 2: Fallback to fuzzy NAME + CUI matching
name_match = await OCRValidationEngine.fuzzy_match_by_name_and_cui(
vendor_name, cui, db_session
)
if name_match:
return name_match
return None
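Reviewer note: the single-digit correction in `fuzzy_match_cui_from_db` can be sketched standalone as below. This is a minimal sketch under stated assumptions, not the code in this commit: `validate_checksum` uses the publicly documented Romanian CUI control key 753217532 and may differ from `CUIChecksumRule.validate_checksum`, whose body is not part of this diff. `single_digit_candidates` is a hypothetical helper, and the confusion-pair ordering and database lookup are left out.

```python
# Assumed control key for the Romanian CUI check digit (not taken from this diff).
CONTROL_KEY = (7, 5, 3, 2, 1, 7, 5, 3, 2)


def validate_checksum(cui_digits: str) -> bool:
    """Return True if the last digit matches the control digit of the rest."""
    if not (cui_digits.isascii() and cui_digits.isdigit()):
        return False
    if not 2 <= len(cui_digits) <= 10:
        return False
    # Left-pad the body (everything except the control digit) to 9 digits.
    body = cui_digits[:-1].zfill(9)
    control = int(cui_digits[-1])
    total = sum(int(d) * k for d, k in zip(body, CONTROL_KEY))
    expected = (total * 10) % 11
    if expected == 10:
        expected = 0
    return expected == control


def single_digit_candidates(cui_digits: str) -> list[str]:
    """Hypothetical helper: all one-digit substitutions with a valid checksum."""
    candidates = []
    for pos, original in enumerate(cui_digits):
        for replacement in "0123456789":
            if replacement == original:
                continue
            candidate = cui_digits[:pos] + replacement + cui_digits[pos + 1:]
            if validate_checksum(candidate):
                candidates.append(candidate)
    return candidates
```

In the commit, each checksum-valid candidate is additionally looked up in the database, and the first hit is returned.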


@@ -1,4 +1,4 @@
"""OCR engine wrapper for PaddleOCR and Tesseract."""
"""OCR engine wrapper for PaddleOCR, docTR, and Tesseract."""
import os
import logging
@@ -9,9 +9,8 @@ from typing import List, Optional, Tuple
import numpy as np
# Setup logging
# Setup logging (respects LOG_LEVEL env var set in main.py)
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO) # Ensure logs are visible
# Disable PaddleOCR model source check for faster startup (PaddleX 3.x)
os.environ['PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK'] = 'True'
@@ -19,6 +18,7 @@ os.environ['PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK'] = 'True'
# Lazy imports - these will be imported on first use
PaddleOCR = None # Will be imported lazily
pytesseract = None # Will be imported lazily
doctr_ocr_predictor = None # Will be imported lazily
# Check availability without importing heavy libraries
def _check_paddle_available() -> bool:
@@ -37,8 +37,17 @@ def _check_tesseract_available() -> bool:
except Exception:
return False
def _check_doctr_available() -> bool:
"""Check if doctr is installed without importing it."""
try:
import importlib.util
return importlib.util.find_spec("doctr") is not None
except Exception:
return False
PADDLE_AVAILABLE = _check_paddle_available()
TESSERACT_AVAILABLE = _check_tesseract_available()
DOCTR_AVAILABLE = _check_doctr_available()
@dataclass
@@ -59,6 +68,11 @@ class OCREngine:
self._paddle_ready = threading.Event() # Signals when PaddleOCR is FULLY ready
self._paddle_init_lock = threading.Lock()
self._doctr = None
self._doctr_init_started = False
self._doctr_ready = threading.Event() # Signals when docTR is FULLY ready
self._doctr_init_lock = threading.Lock()
def _init_paddle_lazy(self):
"""Lazy initialize PaddleOCR on first use (avoids slow startup)."""
global PaddleOCR
@@ -94,6 +108,78 @@ class OCREngine:
# Signal that initialization is complete (success or failure)
self._paddle_ready.set()
def _init_doctr_lazy(self):
"""Lazy initialize docTR on first use (avoids slow startup)."""
global doctr_ocr_predictor
with self._doctr_init_lock:
if self._doctr_init_started:
return # Already initializing or done
self._doctr_init_started = True
if DOCTR_AVAILABLE:
try:
print("Importing docTR (first use, may take ~10-15 seconds)...", flush=True)
from doctr.models import ocr_predictor
print("Initializing docTR engine (PyTorch backend)...", flush=True)
# Initialize docTR predictor with pretrained models
# Uses db_resnet50 for detection and crnn_vgg16_bn for recognition
self._doctr = ocr_predictor(
det_arch='db_resnet50',
reco_arch='crnn_vgg16_bn',
pretrained=True,
assume_straight_pages=True,
straighten_pages=False,
preserve_aspect_ratio=True,
)
doctr_ocr_predictor = self._doctr
print("docTR initialized successfully with PyTorch backend", flush=True)
except Exception as e:
print(f"Warning: Failed to initialize docTR: {e}", flush=True)
self._doctr = None
# Signal that initialization is complete (success or failure)
self._doctr_ready.set()
def wait_for_doctr(self, timeout: float = 30.0) -> bool:
"""
Wait for docTR to be fully initialized.
Args:
timeout: Max seconds to wait (default 30s)
Returns:
True if docTR is ready, False if timeout or unavailable
"""
if not DOCTR_AVAILABLE:
return False
if self._doctr is not None:
return True # Already ready
if not self._doctr_init_started:
# Start initialization if not already started
self._init_doctr_lazy()
# Wait for initialization to complete
print(f"[OCR] Waiting for docTR to be ready (max {timeout}s)...", flush=True)
start = time.time()
ready = self._doctr_ready.wait(timeout=timeout)
elapsed = time.time() - start
if ready and self._doctr is not None:
print(f"[OCR] docTR ready after {elapsed:.1f}s", flush=True)
return True
else:
print(f"[OCR] docTR not ready after {elapsed:.1f}s (timeout or failed)", flush=True)
return False
def is_doctr_ready(self) -> bool:
"""Check if docTR is ready without waiting."""
return self._doctr is not None
def wait_for_paddle(self, timeout: float = 30.0) -> bool:
"""
Wait for PaddleOCR to be fully initialized.
@@ -239,6 +325,84 @@ class OCREngine:
logger.info(f"[Tesseract] Done: {len(text)} chars, conf: {avg_conf:.2%}")
return OCRResult(text=text, confidence=avg_conf, boxes=[], engine="tesseract")
def _doctr_recognize(self, image: np.ndarray) -> OCRResult:
"""Recognize text using docTR."""
# Wait for docTR to be fully ready
if not self.wait_for_doctr(timeout=30.0):
logger.warning("[docTR] Not ready, falling back to Tesseract")
if TESSERACT_AVAILABLE:
return self._tesseract_recognize(image)
raise RuntimeError("docTR not ready and Tesseract not available")
try:
logger.info(f"[docTR] Processing image, shape: {image.shape}")
# docTR requires RGB images
import cv2
if len(image.shape) == 2:
# Convert grayscale to RGB
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
logger.info(f"[docTR] Converted grayscale to RGB, new shape: {image.shape}")
elif image.shape[2] == 4:
# Convert RGBA to RGB
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
logger.info(f"[docTR] Converted RGBA to RGB, new shape: {image.shape}")
elif image.shape[2] == 3:
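# Assume BGR channel order (OpenCV convention) and convert to RGB.
# The order cannot be detected from the array itself, so an image that is
# already RGB would have its red and blue channels swapped here.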
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
logger.info(f"[docTR] Converted BGR to RGB, shape: {image.shape}")
# Process image with docTR
logger.info("[docTR] Running prediction...")
# docTR accepts a list of pages as numpy arrays, so DocumentFile is not needed here
result = self._doctr([image])
if not result or not result.pages:
logger.warning("[docTR] No results returned")
return OCRResult(text="", confidence=0.0, boxes=[], engine="doctr")
# Extract text from all pages
all_texts = []
all_confidences = []
boxes = []
for page in result.pages:
for block in page.blocks:
for line in block.lines:
line_text = ' '.join(word.value for word in line.words)
line_confidence = sum(w.confidence for w in line.words) / len(line.words) if line.words else 0.0
all_texts.append(line_text)
all_confidences.append(line_confidence)
# Store word-level boxes
for word in line.words:
boxes.append({
'text': word.value,
'confidence': float(word.confidence),
'box': word.geometry # ((xmin, ymin), (xmax, ymax)), relative to page size
})
text_result = '\n'.join(all_texts)
avg_conf = sum(all_confidences) / len(all_confidences) if all_confidences else 0.0
logger.info(f"[docTR] SUCCESS - Found {len(all_texts)} text lines, avg confidence: {avg_conf:.2%}")
logger.debug(f"[docTR] Raw text preview: {text_result[:200]}...")
return OCRResult(
text=text_result,
confidence=float(avg_conf),
boxes=boxes,
engine="doctr"
)
except Exception as e:
logger.error(f"[docTR] ERROR: {e}, falling back to Tesseract")
if TESSERACT_AVAILABLE:
return self._tesseract_recognize(image)
raise
def recognize_dual(self, image: np.ndarray) -> Tuple[OCRResult, Optional[OCRResult]]:
"""
Run both OCR engines and return both results.
@@ -286,10 +450,27 @@ class OCREngine:
@staticmethod
def get_available_engines() -> List[str]:
"""Return list of available OCR engines."""
"""
Return list of available OCR engines.
Respects OCR_ENABLE_PADDLEOCR and OCR_ENABLE_TESSERACT from .env.
Engines that are disabled via .env are not returned even if installed.
Available engines: tesseract, doctr, doctr_plus, paddleocr
"""
# Check .env settings
paddle_enabled = os.getenv("OCR_ENABLE_PADDLEOCR", "true").lower() == "true"
tesseract_enabled = os.getenv("OCR_ENABLE_TESSERACT", "true").lower() == "true"
engines = []
if PADDLE_AVAILABLE:
engines.append('paddleocr')
if TESSERACT_AVAILABLE:
# Base engines (only if installed AND enabled)
if TESSERACT_AVAILABLE and tesseract_enabled:
engines.append('tesseract')
if DOCTR_AVAILABLE:
engines.append('doctr')
engines.append('doctr_plus') # docTR with 2-tier sequential + early exit
if PADDLE_AVAILABLE and paddle_enabled:
engines.append('paddleocr')
return engines
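Reviewer note: the lazy-initialization pattern used for docTR above (a lock-guarded "started" flag plus a `threading.Event` that is set on success or failure) can be sketched in isolation as follows. `LazyResource` and `factory` are hypothetical names for illustration; the sketch records the exception instead of printing it, which differs from the commit.

```python
import threading
import time
from typing import Any, Callable, Optional


class LazyResource:
    """Hypothetical wrapper: build an expensive object once, on first use."""

    def __init__(self, factory: Callable[[], Any]) -> None:
        self._factory = factory
        self._value: Optional[Any] = None
        self._started = False
        self._lock = threading.Lock()
        self._ready = threading.Event()  # set when init finished (ok or failed)
        self.error: Optional[Exception] = None

    def start(self) -> None:
        # Only the first caller gets past the flag; the lock is released
        # before the slow factory call so other threads are not blocked on it.
        with self._lock:
            if self._started:
                return
            self._started = True
        try:
            self._value = self._factory()
        except Exception as exc:
            self.error = exc
            self._value = None
        finally:
            self._ready.set()

    def wait(self, timeout: float = 30.0) -> bool:
        """Return True once the resource is usable, False on timeout or failure."""
        if self._value is not None:
            return True
        self.start()  # no-op if another thread already started initialization
        return self._ready.wait(timeout=timeout) and self._value is not None
```

Setting the event in `finally` is what lets waiters distinguish "still loading" from "failed": after a failed import the event is set but the value stays `None`, so `wait` returns `False` immediately instead of blocking for the full timeout.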


@@ -6,6 +6,8 @@ from decimal import Decimal, InvalidOperation
from typing import Optional, Tuple, List
from dataclasses import dataclass, field
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
@dataclass
class ExtractionResult:
@@ -24,6 +26,7 @@ class ExtractionResult:
address: Optional[str] = None
items_count: Optional[int] = None
payment_methods: List[dict] = field(default_factory=list) # [{"method":"CARD","amount":Decimal}]
suggested_payment_mode: Optional[str] = None # 'banca' if CARD detected, 'numerar' if cash only
# Client data (for B2B receipts - buyer information)
client_name: Optional[str] = None
@@ -125,8 +128,10 @@ class ReceiptExtractor:
(r'C3POS[-A-Z0-9]*[N:](\d{6,7})', 0.98), # CT2N1360760 format
(r'C3POS.*?(\d{6,7})\b', 0.95), # Any C3POS followed by 6-7 digit number
(r'CT2[N:]\s*(\d{6,})', 0.95), # CT2N prefix
# BF (Bon Fiscal) number
(r'BF\s*:?\s*(\d+)', 0.93),
# BF (Bon Fiscal) number - high priority
# Format: "Z:0864 BF:0018" - extract only the number after BF:
(r'BF\s*:\s*(\d{4,})', 0.96), # BF: with colon (most specific)
(r'BF\s+(\d{4,})', 0.93), # BF followed by space and number
# NIVS format
(r'NIVS\s*:?\s*(\d+)', 0.95),
# Standard NR BON formats
@@ -151,28 +156,45 @@ class ReceiptExtractor:
# OCR errors: R0 instead of RO, C1F instead of CIF
CUI_PATTERNS = [
# CIF at start of line (definitely vendor) - tolerant to OCR errors
(r'^CIF\s*:?\s*(?:R[O0])?(\d{6,10})', 0.98),
(r'^C[I1]F\s*:?\s*(?:R[O0])?(\d{6,10})', 0.95), # C1F OCR error
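# NOTE: Capture the full CUI including the RO prefix. In (R[O0]?\d{6,10}) the
# leading R is required (only the O/0 is optional), so each label also has a
# digits-only variant for CUIs printed without a prefix.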
(r'^CIF\s*:?\s*(R[O0]?\d{6,10})', 0.98),
(r'^CIF\s*:?\s*(\d{6,10})', 0.97), # Without RO prefix
(r'^C[I1]F\s*:?\s*(R[O0]?\d{6,10})', 0.95), # C1F OCR error
(r'^C[I1]F\s*:?\s*(\d{6,10})', 0.94), # C1F without RO
# CIF not preceded by CLIENT (negative lookbehind)
(r'(?<!CLIENT\s)(?<!LIENT\s)CIF\s*:?\s*(?:R[O0])?(\d{6,10})', 0.95),
(r'(?<!CLIENT\s)(?<!LIENT\s)CIF\s*:?\s*(R[O0]?\d{6,10})', 0.95),
(r'(?<!CLIENT\s)(?<!LIENT\s)CIF\s*:?\s*(\d{6,10})', 0.94),
# Standalone CIF: format with OCR tolerance
(r'\bC[I1]F\s*:?\s*(?:R[O0])?(\d{6,10})\b', 0.90),
(r'\bC[I1]F\s*:?\s*(R[O0]?\d{6,10})\b', 0.90),
(r'\bC[I1]F\s*:?\s*(\d{6,10})\b', 0.89),
# COD FISCAL (vendor)
(r'COD\s+FISCAL\s*:?\s*(?:R[O0])?(\d{6,10})', 0.90),
(r'COD\s+FISCAL\s*:?\s*(R[O0]?\d{6,10})', 0.90),
(r'COD\s+FISCAL\s*:?\s*(\d{6,10})', 0.89),
# C. I. F. format with SPACES (OCR artifact) - "C. I. F. : R011201891"
(r'C\.\s*I\.\s*F\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.92),
# Also handles double colon from OMV/Petrom: "C. I.F.: : RO11201891"
(r'C\.\s*I\.\s*F\.?\s*[:\s]+(R[O0]?\d{6,10})', 0.92),
(r'C\.\s*I\.\s*F\.?\s*[:\s]+(\d{6,10})', 0.91),
# C.I.F. format (with dots, no spaces)
(r'(?<!CLIENT\s)C\.[I1]\.F\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.88),
(r'(?<!CLIENT\s)C\.[I1]\.F\.?\s*:?\s*(R[O0]?\d{6,10})', 0.88),
(r'(?<!CLIENT\s)C\.[I1]\.F\.?\s*:?\s*(\d{6,10})', 0.87),
# CUI format (less specific, use with caution)
(r'(?<!CLIENT\s)C\.?U\.?[I1]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.85),
(r'(?<!CLIENT\s)C\.?U\.?[I1]\.?\s*:?\s*(R[O0]?\d{6,10})', 0.85),
(r'(?<!CLIENT\s)C\.?U\.?[I1]\.?\s*:?\s*(\d{6,10})', 0.84),
# Lidl format: "Cod Identificare fiscala: RO..." (OCR corrupts to "Ced Identificanfliscalar")
# Matches: "Identificare fiscala", "Identificanfliscalar", "Identificoan/Fljscales"
(r'[IC](?:od|ed)\s*Identific[a-z/]*\s*(R[O0]\d{6,10})', 0.90),
# Generic: anything with "fiscal" followed by RO + digits
(r'fiscal[a-z]*\s*:?\s*(R[O0]\d{6,10})', 0.85),
]
# Pattern for CIF NUMBER appearing BEFORE "C.I.F." label (reversed format)
# Common in some receipts: "R011201891\nC. I. F." - number on line before label
# Common in some receipts: "RO11201891\nC. I. F." - number on line before label
# IMPORTANT: Capture the full CUI including RO prefix
CUI_REVERSED_PATTERNS = [
# RO + 8-10 digits on line immediately before C.I.F./CIF label
(r'(?:R[O0])(\d{6,10})\s*\n\s*C\.?\s*I\.?\s*F\.?', 0.98),
# Just digits before C.I.F. label
# RO/R0 + 6-10 digits on line immediately before C.I.F./CIF label
# Capture the FULL CUI including RO prefix
(r'(R[O0]\d{6,10})\s*\n\s*C\.?\s*I\.?\s*F\.?', 0.98),
# Just digits before C.I.F. label (neplatitor TVA - no RO prefix)
(r'(\d{6,10})\s*\n\s*C\.?\s*I\.?\s*F\.?', 0.95),
]
@@ -185,38 +207,67 @@ class ReceiptExtractor:
(r'(?:^|\s)BF\s*:\s*(\d{4})', 0.85),
]
# TVA (VAT) patterns - OCR may produce TUA, TVR, etc.
# TVA (VAT) patterns - OCR may produce TUA, TVR, IVA, etc.
# All patterns are case-insensitive (re.IGNORECASE applied in extraction)
TVA_PATTERNS = [
# TOTAL TVA BON format (OCR tolerant: TUA, TVR)
(r'TOTAL\s+T[VU][AR]\s+BON\s*:?\s*([\d\s.,]+)', 0.98),
(r'T[O0]TAL\s+T[VU][AR]\s*:?\s*([\d\s.,]+)', 0.95),
# TOTAL TVA BON format (OCR tolerant: TUA, TVR, IVA)
(r'TOTAL\s+(?:T[VU][AR]|IVA)\s+BON\s*:?\s*([\d\s.,]+)', 0.98),
(r'T[O0]TAL\s+(?:T[VU][AR]|IVA)\s*:?\s*([\d\s.,]+)', 0.95),
# IVA variant (Spanish/Portuguese influence, some receipts)
(r'TOTAL\s+IVA\s*:?\s*([\d\s.,]+)', 0.95),
(r'IVA\s+[A-D]?\s*[-:]?\s*(\d{1,2})\s*%\s*:?\s*([\d\s.,]+)', 0.93),
# TVA with percentage (OCR tolerant)
(r'T[VU][AR]\s+(?:A\s*[-:]?\s*)?(\d{1,2})\s*%\s*:?\s*([\d\s.,]+)', 0.95),
(r'T[VU][AR]\s+[A-Z]\s*[-:]\s*(\d{1,2})\s*%\s*([\d\s.,]+)', 0.93),
# Simple TVA pattern
(r'T[VU][AR]\s*:?\s*([\d\s.,]+)', 0.85),
# 5% TVA rate (books, newspapers - TVA C)
(r'T[VU][AR]\s*[C5]\s*[-:]\s*5\s*%\s*:?\s*([\d\s.,]+)', 0.93),
(r'(?:T[VU][AR]|IVA)\s+5\s*%\s*:?\s*([\d\s.,]+)', 0.92),
# Garbled OCR: T0TAL, TVAI, TUAI, etc.
(r'T[O0]T[AE]L\s+(?:T[VUAI]+[AR]?|IVA)\s*:?\s*([\d\s.,]+)', 0.88),
# OCR corruption: "TA F 194" (TVA with V→F or space), "T A 19%"
# Handles: "TOTAL TA F 194" where TVA became "TA F"
(r'TOTAL\s+TA\s*[F\s]?\s*\d*\s*:?\s*([\d\s.,]+)', 0.85),
(r'TA\s+[FA-Z]?\s*\d{1,2}\s*%?\s*:?\s*([\d\s.,]+)', 0.82),
# "TUA" with random letter after (OCR noise): "TUA F", "TUA I"
(r'T[VU]A\s+[A-Z]?\s*\d*\s*:?\s*([\d\s.,]+)', 0.83),
# Simple TVA/IVA pattern
(r'(?:T[VU][AR]|IVA)\s*:?\s*([\d\s.,]+)', 0.85),
# Standalone percentage line near TVA
(r'(\d{1,2})\s*%\s*:?\s*([\d\s.,]+)', 0.75),
]
# Payment method patterns - appears after TOTAL LEI, before TOTAL TVA
# Format: "CARD: 50.00" or "NUMERAR 100.00" or "PLATA CARD: 50.00"
# OMV/Petrom uses "CARTE CREDIT" or "CARTE CREDIT 318, 16"
PAYMENT_METHOD_PATTERNS = [
# CARTE CREDIT with amount on same line (OMV/Petrom receipts)
# Handles: "CARTE CREDIT 318, 16" with OCR spaces in number
(r'CARTE\s+CREDIT\s*:?\s*([\d\s.,]+)', 'CARD', 0.98),
# CARTE CREDIT with amount on next line (OCR may split lines)
# Handles: "CARTE CREDIT\n318, 16"
(r'CARTE\s+CREDIT\s*:?\s*\n\s*([\d\s.,]+)', 'CARD', 0.97),
# CARD with amount (high confidence)
(r'(?:PLATA\s+)?CARD\s*:?\s*([\d\s.,]+)', 'CARD', 0.95),
# Also handles OCR artifacts like "CARD F 100.00" where F is noise
(r'(?:PLATA\s+)?CARD\s*[:\sA-Z]?\s*([\d\s.,]+)', 'CARD', 0.95),
# NUMERAR (cash) with amount
(r'NUMERAR\s*:?\s*([\d\s.,]+)', 'NUMERAR', 0.95),
# CASH alternative spelling
(r'CASH\s*:?\s*([\d\s.,]+)', 'NUMERAR', 0.90),
# Truncation recovery patterns (for OCR left-margin truncation issues)
# IMPROVED: More restrictive - require max 6 digits before decimals
# to avoid matching CUI numbers like RO10562600 → RD10562600
# "RD" = truncated "CARD" (only 2 chars visible)
(r'\bRD\s*:?\s*([\d\s.,]+)', 'CARD', 0.70),
(r'(?:^|\n|\s)RD\s*:?\s*(\d{1,6}[.,]\d{2})\b', 'CARD', 0.70),
# "ARD" = truncated "CARD" (3 chars visible)
(r'\bARD\s*:?\s*([\d\s.,]+)', 'CARD', 0.75),
(r'(?:^|\n|\s)ARD\s*:?\s*(\d{1,6}[.,]\d{2})\b', 'CARD', 0.75),
# "MERAR" = truncated "NUMERAR"
(r'\bMERAR\s*:?\s*([\d\s.,]+)', 'NUMERAR', 0.70),
(r'(?:^|\n|\s)MERAR\s*:?\s*(\d{1,6}[.,]\d{2})\b', 'NUMERAR', 0.70),
]
# Maximum reasonable payment amount for a receipt (100,000 LEI)
# Amounts larger than this are likely OCR errors (e.g., CUI parsed as amount)
MAX_REASONABLE_PAYMENT = Decimal('100000')
# Items count patterns - OCR may produce OZ instead of POZ, etc.
# Number may be on separate line before or after the label
# IMPORTANT: Must be specific to avoid matching product quantities like "50BUC"
@@ -250,6 +301,9 @@ class ReceiptExtractor:
# Reversed format: CIF/CUI before CLIENT
r'C\.?\s*[I1]\.?\s*F\.?\s+CLIENT\s*:', # CIF CLIENT:
r'C\.?\s*U\.?\s*[I1]\.?\s+CLIENT\s*:', # CUI CLIENT:
# Corrupted CLIENT after CIF: "CIF a IENT:", "CIF LIENT:", "CIF CL IENT:"
r'C[I1]F\s+[A-Z\s]{0,6}IENT\s*:', # "CIF a IENT:", "CIF CL IENT:", "CIF LIENT:"
r'C[I1]F\s+LIENT\s*:', # "CIF LIENT:" (missing C from CLIENT)
# CLIENT followed by C.U.I./C.I.F. (all variations with/without spaces and dots)
# Handles: CLIENT C.U.I/C.I.F., CLIENT C. U. I./ C. I.F., CLIENT CUI/CIF
r'CLIENT\s+C\.?\s*U\.?\s*[I1]\.?\s*/?\s*C?\.?\s*[I1]?\.?\s*F?\.?\s*:',
@@ -267,6 +321,16 @@ class ReceiptExtractor:
# Client CUI patterns (explicitly after CLIENT marker)
# OCR errors: R0 instead of RO, C1F instead of CIF, 1 instead of I
CLIENT_CUI_PATTERNS = [
# NEW: CUI on line BEFORE CLIENT marker (docTR/OCR may output value before label)
# Pattern: "RO1879855\nCLIENT C.U.I./C.I.F.:" - CUI on line before CLIENT label
(r'(R[O0]\d{6,10})\s*\n\s*CLIENT\s+C\.?\s*U\.?\s*[I1]\.?', 0.99),
(r'(R[O0]\d{6,10})\s*\n\s*CLIENT\s+C\.?\s*[I1]\.?\s*F\.?', 0.99),
# Same but with optional colon after RO number
(r'(R[O0]\d{6,10})\s*:?\s*\n\s*CLIENT', 0.98),
# "CIF I CLIENT:" or "CIF IDENTIFICARE CLIENT:" format (OCR may insert extra chars)
# Common OCR artifact: "CIF I CLIENT: R01879855"
(r'C[I1]F\s+[A-Z]*\s*CLIENT\s*:?\s*(R[O0]\d{6,10})', 0.98),
(r'C[I1]F\s+[A-Z]*\s*CLIENT\s*:?\s*(R[O0]?\d{6,10})', 0.97),
# CIF CLIENT: R01879856 (reversed format - CIF/CUI before CLIENT)
(r'C\.?\s*[I1]\.?\s*F\.?\s+CLIENT\s*:?\s*(R[O0]?\d{6,10})', 0.98),
(r'C\.?\s*U\.?\s*[I1]\.?\s+CLIENT\s*:?\s*(R[O0]?\d{6,10})', 0.98),
@@ -276,19 +340,34 @@ class ReceiptExtractor:
# Most flexible pattern for slash variants
(r'CLIENT\s+C\.?\s*U\.?\s*[I1]\.?\s*/\s*C\.?\s*[I1]\.?\s*F\.?\s*:?\s*(R[O0]?\d{6,10})', 0.97),
(r'CLIENT\s+C\.?\s*U\.?\s*[I1]\.?\s*/\s*C\.?\s*[I1]\.?\s*F\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.97),
# OCR artifact: doubled letters like "C.U U. I." or "C.I I.F." (docTR sometimes duplicates)
(r'CLIENT\s+C\.?\s*U\.?\s*U?\.?\s*[I1]\.?\s*/\s*C\.?\s*[I1]\.?\s*F\.?\s*:?\s*(R[O0]?\d{6,10})', 0.96),
(r'CLIENT\s+C\.?\s*U\.?\s*U?\.?\s*[I1]\.?\s*/\s*C\.?\s*[I1]\.?\s*F\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.96),
# CLIENT C.U.I. or CLIENT CUI or CLIENT CIF (without slash)
(r'CLIENT\s+C\.?\s*U\.?\s*[I1]\.?\s*:?\s*(R[O0]?\d{6,10})', 0.96),
(r'CLIENT\s+C\.?\s*U\.?\s*[I1]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.96),
(r'CLIENT\s+C\.?\s*[I1]\.?\s*F\.?\s*:?\s*(R[O0]?\d{6,10})', 0.96),
(r'CLIENT\s+C\.?\s*[I1]\.?\s*F\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.96),
# Corrupted CLIENT after CIF: "CIF a IENT:", "CIF LIENT:", "CIF L IENT:", "CIF C IENT:"
# OCR often corrupts "CLIENT" when it appears after "CIF"
(r'CIF\s+[a-zA-Z\s]{2,8}IENT\s*:?\s*(R[O0]?\d{6,10})', 0.93), # "CIF a IENT:", "CIF CL IENT:"
(r'CIF\s+[a-zA-Z\s]{2,8}IENT\s*:?\s*(?:R[O0])?(\d{6,10})', 0.93),
(r'CIF\s+LIENT\s*:?\s*(R[O0]?\d{6,10})', 0.92), # "CIF LIENT:" (missing C)
(r'CIF\s+LIENT\s*:?\s*(?:R[O0])?(\d{6,10})', 0.92),
# CUMPARATOR variants
(r'CUMPARATOR\s+C\.?\s*U\.?\s*[I1]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.95),
(r'CUMPARATOR\s+C\.?\s*[I1]\.?\s*F\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.95),
# CUMPARATOR with CUI/CIF on next line: "CUMPARATOR: NAME\nCIF: 12345678"
(r'CUMPARATOR\s*:.*\n\s*C\.?\s*U\.?\s*[I1]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.93),
(r'CUMPARATOR\s*:.*\n\s*C\.?\s*[I1]\.?\s*[FT]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.93), # F or T (OCR error)
# CUMPARATOR with CUI/CIF two lines down: "CUMPARATOR: NAME\nADDRESS\nCIF: 12345678"
(r'CUMPARATOR\s*:.*\n.*\n\s*C\.?\s*[I1]\.?\s*[FT]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.90),
# CUI/CIF on line immediately after CLIENT marker
(r'CLIENT\s*:\s*\n\s*C\.?\s*U\.?\s*[I1]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.95),
(r'CLIENT\s*:\s*\n\s*C\.?\s*[I1]\.?\s*F\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.95),
# CUI after client name: "CLIENT: COMPANY SRL\nCUI: 12345678"
(r'CLIENT\s*:.*\n.*C\.?\s*U\.?\s*[I1]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.90),
(r'CLIENT\s*:\s*\n\s*C\.?\s*[I1]\.?\s*[FT]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.95), # F or T (OCR error)
# CUI/CIF after client name: "CLIENT: COMPANY SRL\nCUI: 12345678"
(r'CLIENT\s*:.*\n\s*C\.?\s*U\.?\s*[I1]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.90),
(r'CLIENT\s*:.*\n\s*C\.?\s*[I1]\.?\s*[FT]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.90), # CIF/CIT after name
]
# Vendor name indicators (lines containing these are likely vendor names)
@@ -322,6 +401,8 @@ class ReceiptExtractor:
result.receipt_series, _ = self._extract_series(text_upper)
result.partner_name, result.confidence_vendor = self._extract_vendor(text)
result.cui, _ = self._extract_cui(text_upper, text)
# Normalize CUI: fix R0 → RO OCR error and validate format
result.cui = OCRValidationEngine.normalize_cui(result.cui)
# Extract additional fields - Multiple TVA entries
result.tva_entries, result.tva_total = self._extract_tva_entries(text_upper)
@@ -345,10 +426,35 @@ class ReceiptExtractor:
result.address = self._extract_address(text_upper)
result.payment_methods = self._extract_payment_methods(text_upper)
# Validate payment methods against extracted amount
# If payment sum >> amount, clear invalid payments (likely OCR error)
# Save original payment methods before validation (for payment mode detection)
original_payment_methods = result.payment_methods.copy() if result.payment_methods else []
result.payment_methods = self._validate_payment_methods(result.payment_methods, result.amount)
# Auto-suggest payment_mode based on detected payment methods
# Use ORIGINAL payment_methods to detect CARD even if validation cleared them
# (e.g., CARD 318.16 is valid even if total validation failed)
payment_methods_for_mode = result.payment_methods if result.payment_methods else original_payment_methods
if payment_methods_for_mode:
card_amount = sum(
pm.get('amount', Decimal('0'))
for pm in payment_methods_for_mode
if pm.get('method') == 'CARD'
)
if card_amount > 0:
result.suggested_payment_mode = 'banca'
print(f"[Payment Mode] CARD detected ({card_amount}), suggesting 'banca'", flush=True)
else:
# Only cash payments detected
result.suggested_payment_mode = 'numerar'
print("[Payment Mode] Cash only detected, suggesting 'numerar'", flush=True)
# Extract client data (B2B receipts)
client_name, client_cui, client_address, confidence_client = self._extract_client_data(text_upper, text)
result.client_name = client_name
result.client_cui = client_cui
result.client_cui = OCRValidationEngine.normalize_cui(client_cui) # Fix R0 → RO OCR error
result.client_address = client_address
result.confidence_client = confidence_client
@@ -378,13 +484,28 @@ class ReceiptExtractor:
def _extract_amount(self, text: str) -> Tuple[Optional[Decimal], float]:
"""Extract total amount from text."""
# PRE-FILTER: Remove lines containing REST (rest = change, not total)
# When paid by card, there's no change - exact amount is paid
lines = text.split('\n')
filtered_lines = []
for line in lines:
# Skip lines with REST pattern (change amount, not total)
if re.search(r'\bREST\b', line, re.IGNORECASE):
continue
filtered_lines.append(line)
text = '\n'.join(filtered_lines)
# First try standard patterns (TOTAL, SUBTOTAL, etc.)
for pattern, confidence in self.TOTAL_PATTERNS:
match = re.search(pattern, text, re.IGNORECASE | re.MULTILINE)
if match:
try:
amount_str = re.sub(r'[^\d.,]', '', match.group(1))
# IMPORTANT: Call _normalize_number FIRST to handle "190 60" → "190.60"
# before stripping other characters
amount_str = match.group(1).strip()
amount_str = self._normalize_number(amount_str)
# Now remove any remaining non-numeric chars (except decimal point)
amount_str = re.sub(r'[^\d.]', '', amount_str)
amount = Decimal(amount_str)
if amount > 0:
return amount, confidence
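Reviewer note: the "190 60" → "190.60" handling referenced in this hunk can be illustrated with a small standalone function. This is a sketch under assumptions: `normalize_spaced_number` is a hypothetical name, it folds the two regexes from `_normalize_number` into one, and it does not reproduce the comma/dot separator handling that follows in the real method.

```python
import re


def normalize_spaced_number(num_str: str) -> str:
    """Treat a trailing space + exactly two digits as the decimal part."""
    s = num_str.strip()
    # "190 60" and "1 234 56": everything before the last gap is the integer part.
    match = re.fullmatch(r'([\d\s]+?)\s+(\d{2})', s)
    if match:
        integer_part = re.sub(r'\s+', '', match.group(1))
        return f"{integer_part}.{match.group(2)}"
    # Otherwise the remaining spaces are thousands separators or OCR noise.
    return re.sub(r'\s+', '', s)
```

Note the inherent ambiguity: a value such as "12 34" is read as 12.34, so the heuristic only makes sense for fields that are known to be monetary amounts.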
@@ -461,8 +582,22 @@ class ReceiptExtractor:
def _normalize_number(self, num_str: str) -> str:
"""Normalize Romanian number format to standard decimal."""
# Remove spaces
num_str = num_str.replace(' ', '')
# OCR often reads "." as " " (space) - handle "190 60" as "190.60"
# Pattern: digits + space + exactly 2 digits at end
space_decimal_match = re.match(r'^(\d+)\s+(\d{2})$', num_str.strip())
if space_decimal_match:
num_str = f"{space_decimal_match.group(1)}.{space_decimal_match.group(2)}"
else:
# Handle "1 234 56" pattern (thousands + decimal with spaces)
# Match: digits + space(s) + digits + space + 2 digits
multi_space_match = re.match(r'^([\d\s]+?)\s+(\d{2})$', num_str.strip())
if multi_space_match:
integer_part = multi_space_match.group(1).replace(' ', '')
decimal_part = multi_space_match.group(2)
num_str = f"{integer_part}.{decimal_part}"
else:
# Remove remaining spaces (thousands separators)
num_str = num_str.replace(' ', '')
# Handle comma as decimal separator
if ',' in num_str and '.' in num_str:
@@ -532,34 +667,57 @@ class ReceiptExtractor:
except (InvalidOperation, ValueError, TypeError):
pass
# Case 1: Amount is valid with high confidence - just validate
# Case 1: Amount is valid with high confidence - validate against TVA and payments
if amount and amount > 0 and confidence_amount >= 0.8:
# Cross-validate: check if it matches payment methods
# First check TVA-implied total (most reliable when TVA is extracted correctly)
if tva_implied_total and tva_implied_total > 0:
tva_diff_percent = abs(float(amount) - float(tva_implied_total)) / float(tva_implied_total) * 100
if tva_diff_percent <= 1:
# Near-perfect TVA match - highest confidence
return amount, min(0.98, confidence_amount + 0.05), "extracted (validated by TVA)"
elif tva_diff_percent > 10:
# Significant mismatch - TVA-implied total is more reliable
# This catches cases where wrong TOTAL line was extracted (e.g., REST, SUBTOTAL)
print(f"[Cross-Validation] Amount mismatch with TVA: extracted={amount}, tva_implied={tva_implied_total} (diff={tva_diff_percent:.1f}%)", flush=True)
return tva_implied_total, 0.90, "calculated from TVA (extracted amount mismatch)"
# Cross-validate with payment methods
if payment_sum > 0 and abs(amount - payment_sum) <= Decimal('0.02'):
# Perfect match - boost confidence
return amount, min(0.98, confidence_amount + 0.05), "extracted (validated by payment methods)"
elif payment_sum > 0:
payment_diff_percent = abs(float(amount) - float(payment_sum)) / float(payment_sum) * 100
if payment_diff_percent > 10:
# Significant mismatch - payment sum is more reliable
print(f"[Cross-Validation] Amount mismatch with payments: extracted={amount}, payments={payment_sum} (diff={payment_diff_percent:.1f}%)", flush=True)
return payment_sum, 0.88, "calculated from payment methods (extracted amount mismatch)"
return amount, confidence_amount, "extracted"
# Case 2: Amount exists but low confidence - try to validate/correct
if amount and amount > 0:
# First check TVA-implied total (most reliable)
if tva_implied_total and tva_implied_total > 0:
tva_diff_percent = abs(float(amount) - float(tva_implied_total)) / float(tva_implied_total) * 100
if tva_diff_percent <= 2:
# Close match - boost confidence
return amount, 0.88, "extracted (validated by TVA)"
elif tva_diff_percent > 10:
# Significant mismatch - use TVA-implied total
print(f"[Cross-Validation] Amount mismatch with TVA: extracted={amount}, tva_implied={tva_implied_total} (diff={tva_diff_percent:.1f}%)", flush=True)
return tva_implied_total, 0.85, "calculated from TVA"
# Check if payment methods sum matches
if payment_sum > 0:
if abs(amount - payment_sum) <= Decimal('0.02'):
# Match - boost confidence
payment_diff_percent = abs(float(amount) - float(payment_sum)) / float(payment_sum) * 100
if payment_diff_percent <= 0.5:
# Close match - boost confidence
return amount, 0.90, "extracted (validated by payment methods)"
else:
elif payment_diff_percent > 10:
# Mismatch - prefer payment_sum as it's more reliable
print(f"[Cross-Validation] Amount mismatch: extracted={amount}, payments={payment_sum}", flush=True)
return payment_sum, 0.85, "calculated from payment methods"
# Check TVA-implied total
if tva_implied_total:
if abs(amount - tva_implied_total) <= Decimal('0.50'):
# Close match - use extracted amount
return amount, 0.80, "extracted (validated by TVA)"
else:
print(f"[Cross-Validation] TVA mismatch: extracted={amount}, tva_implied={tva_implied_total}", flush=True)
# No validation possible - return as-is
return amount, confidence_amount, "extracted (unvalidated)"
@@ -701,6 +859,10 @@ class ReceiptExtractor:
line_upper = line.upper()
# Skip lines with skip keywords (CUMPARATOR, CLIENT, etc.)
if any(kw in line_upper for kw in skip_keywords):
continue
# Check for vendor indicators
for indicator in self.VENDOR_INDICATORS:
if re.search(indicator, line_upper):
@@ -778,13 +940,21 @@ class ReceiptExtractor:
Extract vendor CUI (fiscal identification code) from text.
Excludes CLIENT CUI which appears as 'CLIENT C.U.I./C.I.F.:...'
"""
def get_cui_digit_count(cui: str) -> int:
"""Get the count of digits in CUI (excluding RO/R0 prefix)."""
cui_upper = cui.upper().strip()
if cui_upper.startswith(('RO', 'R0')):
return len(cui_upper) - 2
return len(cui_upper)
# Strategy 0: Check for reversed format (CIF NUMBER on line BEFORE "C.I.F." label)
# This is common in some receipts: "R011201891\nC. I. F."
# This is common in some receipts: "RO11201891\nC. I. F."
for pattern, confidence in self.CUI_REVERSED_PATTERNS:
match = re.search(pattern, text_upper, re.IGNORECASE | re.MULTILINE)
if match:
cui = match.group(1)
if 6 <= len(cui) <= 10:
digit_count = get_cui_digit_count(cui)
if 6 <= digit_count <= 10:
# Verify this is not the CLIENT CUI by checking context
start = match.start()
# Check 50 chars before the match for CLIENT keyword
@@ -805,7 +975,8 @@ class ReceiptExtractor:
match = re.search(pattern, line, re.IGNORECASE | re.MULTILINE)
if match:
cui = match.group(1)
if 6 <= len(cui) <= 10:
digit_count = get_cui_digit_count(cui)
if 6 <= digit_count <= 10:
return cui, confidence
# Strategy 2: Fallback - search entire text but exclude CLIENT patterns
@@ -813,7 +984,8 @@ class ReceiptExtractor:
# Find all matches
for match in re.finditer(pattern, text_upper, re.IGNORECASE | re.MULTILINE):
cui = match.group(1)
if 6 <= len(cui) <= 10:
digit_count = get_cui_digit_count(cui)
if 6 <= digit_count <= 10:
# Check if this match is preceded by CLIENT in the same line
start = match.start()
line_start = text_upper.rfind('\n', 0, start) + 1
@@ -937,9 +1109,90 @@ class ReceiptExtractor:
except (ValueError, InvalidOperation):
continue
# Pattern 1: "TVA A - 19%: 15.20" or "TVAA - 21% 32.31" (with code)
# OCR tolerant: TUA, TVR, etc.
pattern_with_code = r'T[VU][AR]\s*([A-D])\s*[-:]\s*(\d{1,2})\s*%\s*:?\s*([\d\s.,]+)'
# Pattern 0c: REVERSED FORMAT "5.00% TUA*B" followed by amount on next line
# This handles receipts where percentage comes BEFORE TVA code (e.g., books with 5% rate)
# Matches: "5.00% TUA*B", "5% TVA B", "5.00% TVA", "9% TUA", "5% IVA"
if not tva_entries:
# Pattern: PERCENT% + TVA/IVA + optional code, then amount on next line
reversed_tva_pattern = r'(\d{1,2})[.,]?\d{0,2}\s*%\s*(?:T[VU][AR]|IVA)\s*\*?([A-D])?'
for match in re.finditer(reversed_tva_pattern, normalized_text, re.IGNORECASE):
try:
percent = int(match.group(1))
code = (match.group(2) or self._get_tva_code_from_percent(percent)).upper()
# Look for amount on the next line(s) after the match
after_match = normalized_text[match.end():]
# Find standalone number (amount) - skip empty lines
amount_match = re.search(r'^[\s\n]*([\d]+[.,]\d{2})\b', after_match)
if amount_match:
amount_str = self._normalize_number(amount_match.group(1))
amount = Decimal(amount_str)
if amount > 0:
entry_key = (code, percent)
if entry_key not in seen_entries:
tva_entries.append({
'code': code,
'percent': percent,
'amount': amount
})
seen_entries.add(entry_key)
except (ValueError, InvalidOperation):
continue
# Pattern 0d: "TOTAL TUA:", "TOTAL TVA:", "TOTAL IVA:" with amount (OCR variants)
if not tva_entries:
total_tva_simple = r'TOTAL\s+(?:T[VU][AR]|IVA)\s*:?\s*([\d.,]+)'
match = re.search(total_tva_simple, normalized_text, re.IGNORECASE)
if match:
try:
amount_str = self._normalize_number(match.group(1))
amount = Decimal(amount_str)
if amount > 0:
# Try to find the rate in nearby text
percent = self._detect_tva_percent(text)
if percent:
code = self._get_tva_code_from_percent(percent)
entry_key = (code, percent)
if entry_key not in seen_entries:
tva_entries.append({
'code': code,
'percent': percent,
'amount': amount
})
seen_entries.add(entry_key)
except (ValueError, InvalidOperation):
pass
# Pattern 0e: Multiline "TOTAL TUA\n198\n30.43" where:
# - "TOTAL TUA" on one line
# - "198" or similar (corrupted "19%") on next line (optional)
# - "30.43" (TVA amount) on following line
# OCR often splits this across multiple lines
if not tva_entries:
multiline_tva = r'TOTAL\s+(?:T[VU][AR]L?|TU[AR]L|IVA)\s*\n\s*\d*\s*\n?\s*([\d]+[.,]\d{2})\b'
match = re.search(multiline_tva, normalized_text, re.IGNORECASE)
if match:
try:
amount_str = self._normalize_number(match.group(1))
amount = Decimal(amount_str)
if amount > 0:
percent = self._detect_tva_percent(text)
if percent:
code = self._get_tva_code_from_percent(percent)
entry_key = (code, percent)
if entry_key not in seen_entries:
tva_entries.append({
'code': code,
'percent': percent,
'amount': amount
})
seen_entries.add(entry_key)
except (ValueError, InvalidOperation):
pass
# Pattern 1: "TVA A - 19%: 15.20" or "TVAA - 21% 32.31" or "IVA A - 19%" (with code)
# OCR tolerant: TUA, TVR, IVA, etc.
pattern_with_code = r'(?:T[VU][AR]|IVA)\s*([A-D])\s*[-:]\s*(\d{1,2})\s*%\s*:?\s*([\d\s.,]+)'
for match in re.finditer(pattern_with_code, normalized_text, re.IGNORECASE):
try:
code = match.group(1).upper()
@@ -959,9 +1212,9 @@ class ReceiptExtractor:
except (ValueError, InvalidOperation):
continue
# Pattern 2: "TVA - 21%: 32.31" (without explicit code, assume 'A')
# Pattern 2: "TVA - 21%: 32.31" or "IVA - 21%: 32.31" (without explicit code, assume 'A')
if not tva_entries:
pattern_no_code = r'T[VU][AR]\s*[-:]\s*(\d{1,2})\s*%\s*:?\s*([\d\s.,]+)'
pattern_no_code = r'(?:T[VU][AR]|IVA)\s*[-:]\s*(\d{1,2})\s*%\s*:?\s*([\d\s.,]+)'
for match in re.finditer(pattern_no_code, normalized_text, re.IGNORECASE):
try:
percent = int(match.group(1))
@@ -982,10 +1235,10 @@ class ReceiptExtractor:
except (ValueError, InvalidOperation):
continue
# Pattern 3: "TOTAL TVA A - 21%" with amount on same line or "TOTAL TVA BON" with amount
# Pattern 3: "TOTAL TVA A - 21%" or "TOTAL IVA" with amount on same line or "TOTAL TVA BON" with amount
if not tva_entries:
# First try: "TOTAL TVA A - 21% 32.31" (amount on same line)
tva_with_amount = r'TOTAL\s+T[VU][AR]\s+([A-D])\s*[-:]\s*(\d{1,2})\s*%\s*([\d.,]+)'
# First try: "TOTAL TVA A - 21% 32.31" or "TOTAL IVA A - 21% 32.31" (amount on same line)
tva_with_amount = r'TOTAL\s+(?:T[VU][AR]|IVA)\s+([A-D])\s*[-:]\s*(\d{1,2})\s*%\s*([\d.,]+)'
for match in re.finditer(tva_with_amount, normalized_text, re.IGNORECASE):
try:
code = match.group(1).upper()
@@ -1004,16 +1257,16 @@ class ReceiptExtractor:
except (ValueError, InvalidOperation):
continue
# Pattern 3b: "TOTAL TVA A - 21%" on one line, look for "TOTAL TVA BON" amount
# Pattern 3b: "TOTAL TVA A - 21%" or "TOTAL IVA A - 21%" on one line, look for "TOTAL TVA BON" amount
if not tva_entries:
tva_total_pattern = r'TOTAL\s+T[VU][AR]\s+([A-D])\s*[-:]\s*(\d{1,2})\s*%'
tva_total_pattern = r'TOTAL\s+(?:T[VU][AR]|IVA)\s+([A-D])\s*[-:]\s*(\d{1,2})\s*%'
for match in re.finditer(tva_total_pattern, normalized_text, re.IGNORECASE):
try:
code = match.group(1).upper()
percent = int(match.group(2))
# Look for "TOTAL TVA BON" followed by amount
tva_bon_pattern = r'TOTAL\s+T[VU][AR]\s+BON[:\s]*([\d.,]+)'
# Look for "TOTAL TVA BON" or "TOTAL IVA BON" followed by amount
tva_bon_pattern = r'TOTAL\s+(?:T[VU][AR]|IVA)\s+BON[:\s]*([\d.,]+)'
tva_bon_match = re.search(tva_bon_pattern, normalized_text, re.IGNORECASE)
if tva_bon_match:
amount_str = self._normalize_number(tva_bon_match.group(1))
@@ -1029,8 +1282,8 @@ class ReceiptExtractor:
seen_entries.add(entry_key)
continue
# Fallback: Amount after TOTAL TVA BON on next line
tva_bon_pos = re.search(r'TOTAL\s+T[VU][AR]\s+BON', normalized_text, re.IGNORECASE)
# Fallback: Amount after TOTAL TVA BON or TOTAL IVA BON on next line
tva_bon_pos = re.search(r'TOTAL\s+(?:T[VU][AR]|IVA)\s+BON', normalized_text, re.IGNORECASE)
if tva_bon_pos:
after_bon = normalized_text[tva_bon_pos.end():]
# Find first standalone number (likely TVA amount)
@@ -1050,9 +1303,9 @@ class ReceiptExtractor:
except (ValueError, InvalidOperation):
continue
# Pattern 3b: "TVAA - 21%" on one line, amount on next line (simpler format)
# Pattern 3c: "TVAA - 21%" or "IVA A - 21%" on one line, amount on next line (simpler format)
if not tva_entries:
tva_line_pattern = r'T[VU][AR]\s*([A-D])?\s*[-:]\s*(\d{1,2})\s*%'
tva_line_pattern = r'(?:T[VU][AR]|IVA)\s*([A-D])?\s*[-:]\s*(\d{1,2})\s*%'
for match in re.finditer(tva_line_pattern, normalized_text, re.IGNORECASE):
try:
code = (match.group(1) or 'A').upper()
@@ -1158,16 +1411,18 @@ class ReceiptExtractor:
Extract TOTAL TVA BON value separately as the reference.
This is the authoritative total TVA on the receipt.
Handles OCR variations: TOTAL TVA BON, OTAL TUA BON, etc.
Handles OCR variations: TOTAL TVA BON, OTAL TUA BON, TOTAL IVA BON, etc.
"""
# Pattern for TOTAL TVA BON with amount after
# Pattern for TOTAL TVA BON or TOTAL IVA BON with amount after
# OCR corruptions: TUAL (TVA+L merged), TVAL, etc. (only a trailing "L" is tolerated; TUAI is not matched)
patterns = [
# Standard: TOTAL TVA BON: 14.92
r'T?OTAL\s+T[VU][AR]\s+BON\s*:?\s*([\d]+[.,]\d{2})\b',
# Standard: TOTAL TVA BON: 14.92 or TOTAL IVA BON: 14.92
# Handles: TUAL (TVA+L), TVAL, etc. with an optional trailing "L" (other trailing letters such as TUAI do not match)
r'T?OTAL\s+(?:T[VU][AR]L?|TU[AR]L|IVA)\s+BON\s*:?\s*([\d]+[.,]\d{2})\b',
# Amount before: 14.92 OTAL TUA BON (OCR line break)
r'([\d]+[.,]\d{2})\s*\n?\s*T?OTAL\s+T[VU][AR]\s+BON',
# Amount on next line after TOTAL TVA BON
r'T?OTAL\s+T[VU][AR]\s+BON\s*\n\s*([\d]+[.,]\d{2})\b',
r'([\d]+[.,]\d{2})\s*\n?\s*T?OTAL\s+(?:T[VU][AR]L?|TU[AR]L|IVA)\s+BON',
# Amount on next line after TOTAL TVA BON or TOTAL IVA BON
r'T?OTAL\s+(?:T[VU][AR]L?|TU[AR]L|IVA)\s+BON\s*\n\s*([\d]+[.,]\d{2})\b',
]
for pattern in patterns:
@@ -1271,18 +1526,52 @@ class ReceiptExtractor:
return tva_entries, tva_total
def _detect_tva_percent(self, text: str) -> Optional[int]:
"""Detect TVA percentage from text content."""
# Look for common Romanian TVA percentages
if '19%' in text or '19 %' in text:
"""Detect TVA percentage from text content.
IMPORTANT: Prioritize rates found near TVA markers over rates found elsewhere.
E.g., "REDUCERE 5%" should not override "TVA A 19%".
Also handle OCR corruptions like "194" for "19%" in "TOTAL TA F 194".
"""
import re as regex
# First, look for percent NEAR TVA markers (most reliable)
# This handles "TVA A 19%", "TVA 19,00%", "TOTAL TVA 19%"
tva_context_patterns = [
r'T[VU][AR]\s*[A-D]?\s*[-:]?\s*(19|21|11|9|5)[.,]?\s*\d{0,2}\s*%',
r'IVA\s*[A-D]?\s*[-:]?\s*(19|21|11|9|5)[.,]?\s*\d{0,2}\s*%',
# OCR corruption: "TOTAL TA F 194" where 194 = 19% (4 is artifact)
r'TOTAL\s+T[VA][AR]?\s*[F\s]?\s*(19|21)\d\b',
]
for pattern in tva_context_patterns:
match = regex.search(pattern, text, regex.IGNORECASE)
if match:
rate = int(match.group(1))
if rate in (19, 21, 11, 9, 5):
return rate
# Fallback: Look for common Romanian TVA percentages anywhere
# But EXCLUDE patterns near "REDUCERE", "DISCOUNT", "RED." (these are discounts, not TVA)
# Clean text by removing discount context
# Handle OCR corruptions: RED.CERE (C instead of U), RED CERE, REDUC, etc.
text_no_discount = regex.sub(r'(?:REDUC|DISCOUNT|RED)[.\sA-Z]*\d+[.,]?\d*\s*%', '', text, flags=regex.IGNORECASE)
# Now search in cleaned text (priority order: 19% > 21% > 11% > 9% > 5%)
if regex.search(r'\b19[.,]?\s*\d{0,2}\s*%', text_no_discount):
return 19
elif '21%' in text or '21 %' in text:
elif regex.search(r'\b21[.,]?\s*\d{0,2}\s*%', text_no_discount):
return 21
elif '11%' in text or '11 %' in text:
elif regex.search(r'\b11[.,]?\s*\d{0,2}\s*%', text_no_discount):
return 11
elif '9%' in text or '9 %' in text:
elif regex.search(r'\b9[.,]?\s*\d{0,2}\s*%', text_no_discount):
return 9
elif '5%' in text or '5 %' in text:
elif regex.search(r'\b5[.,]?\s*\d{0,2}\s*%', text_no_discount):
return 5
# Default: If no percent found but we're in Romanian receipt context,
# assume 19% (standard rate)
if regex.search(r'T[VU][AR]|IVA', text, regex.IGNORECASE):
return 19
return None
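The lookup order described in the `_detect_tva_percent` docstring can be tried in isolation. The following is a minimal sketch, not the committed method: `detect_rate`, `TVA_CONTEXT`, `DISCOUNT` and `KNOWN_RATES` are hypothetical names, the regexes are adapted from the hunk above, and the final "assume 19%" fallback and the "194" corruption pattern are deliberately left out.

```python
import re
from typing import Optional

# Rates in the same priority order as the diff (19 > 21 > 11 > 9 > 5).
KNOWN_RATES = (19, 21, 11, 9, 5)

# A rate that appears right after a TVA/IVA marker (OCR-tolerant: TUA, TVR, ...).
TVA_CONTEXT = re.compile(
    r'(?:T[VU][AR]|IVA)\s*[A-D]?\s*[-:]?\s*(19|21|11|9|5)[.,]?\s*\d{0,2}\s*%',
    re.IGNORECASE,
)

# Discount lines such as "REDUCERE 5%" or "DISCOUNT 10%" that must not be read as TVA.
DISCOUNT = re.compile(r'(?:REDUC|DISCOUNT|RED)[.\sA-Z]*\d+[.,]?\d*\s*%', re.IGNORECASE)


def detect_rate(text: str) -> Optional[int]:
    """Return the TVA rate, preferring rates found next to a TVA marker."""
    match = TVA_CONTEXT.search(text)
    if match:
        return int(match.group(1))

    # Fallback: strip discount percentages, then look for any known rate.
    cleaned = DISCOUNT.sub('', text)
    for rate in KNOWN_RATES:
        if re.search(rf'\b{rate}[.,]?\s*\d{{0,2}}\s*%', cleaned):
            return rate
    return None


print(detect_rate("REDUCERE 5%\nTVA A 19%"))       # 19 (TVA marker wins over the discount)
print(detect_rate("REDUCERE 5%\nTOTAL LEI 50.00"))  # None (the only percentage is a discount)
```

Stripping discount text before the fallback search is what keeps "REDUCERE 5%" from being reported as a 5% TVA rate when no TVA marker carries a percentage.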
def _validate_tva_reverse(
@@ -1293,9 +1582,12 @@ class ReceiptExtractor:
"""
Reverse TVA validation: from TVA amount and rate, calculate expected total.
Formula:
base = tva_amount / (rate/100)
expected_total = sum(base + tva_amount) for all entries
Formula (CORRECT):
For TVA that is INCLUDED in total (standard Romanian receipts):
total = base + tva
tva = base * rate/100
Therefore: base = tva * 100 / rate
And: total = base + tva = tva * 100 / rate + tva = tva * (100 + rate) / rate
Returns (is_valid, expected_total, message)
"""
@@ -1307,10 +1599,14 @@ class ReceiptExtractor:
tva_amount = entry['amount']
rate = Decimal(str(entry['percent']))
print(f"[TVA Debug] Entry: amount={tva_amount}, rate={rate}%", flush=True)
if rate > 0:
# Calculate base from TVA: base = tva / (rate/100)
base = tva_amount / (rate / Decimal('100'))
expected_total += base + tva_amount
# CORRECT formula: total = tva * (100 + rate) / rate
# Example: tva=55.22, rate=21 → total = 55.22 * 121 / 21 ≈ 318.17
gross_for_entry = tva_amount * (Decimal('100') + rate) / rate
expected_total += gross_for_entry
print(f"[TVA Debug] Calculated gross: {gross_for_entry}", flush=True)
else:
# 0% TVA - can't calculate base, skip
pass
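The reverse-TVA arithmetic in this hunk can be checked outside the extractor. This is a minimal sketch under assumptions: `gross_from_tva` is a hypothetical helper, the entries are shaped like the `tva_entries` dicts above, and rounding to two decimals is added here for readability rather than taken from the commit.

```python
from decimal import Decimal, ROUND_HALF_UP


def gross_from_tva(entries):
    """Rebuild the gross total from TVA amounts: total = tva * (100 + rate) / rate."""
    total = Decimal('0')
    for entry in entries:
        rate = Decimal(str(entry['percent']))
        if rate <= 0:
            # A 0% entry carries no information about its base, so skip it.
            continue
        total += entry['amount'] * (Decimal('100') + rate) / rate
    return total.quantize(Decimal('0.01'), rounding=ROUND_HALF_UP)


print(gross_from_tva([{'amount': Decimal('55.22'), 'percent': 21}]))  # 318.17
```

Note that `tva * (100 + rate) / rate` is algebraically the same as `tva / (rate / 100) + tva`; the two forms differ only in how intermediate `Decimal` values are rounded.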
@@ -1393,7 +1689,7 @@ class ReceiptExtractor:
# Find the region between TOTAL LEI and TOTAL TVA
total_lei_match = re.search(r'TOTAL\s+LEI\s*([\d\s.,]+)', normalized_text, re.IGNORECASE)
total_tva_match = re.search(r'TOTAL\s+T[VU][AR]', normalized_text, re.IGNORECASE)
total_tva_match = re.search(r'TOTAL\s+(?:T[VU][AR]|IVA)', normalized_text, re.IGNORECASE)
# Define search region (after TOTAL LEI, before TOTAL TVA if exists)
if total_lei_match:
@@ -1404,22 +1700,60 @@ class ReceiptExtractor:
search_region = normalized_text # Fallback to full text
for pattern, method, confidence in self.PAYMENT_METHOD_PATTERNS:
for match in re.finditer(pattern, search_region, re.IGNORECASE):
for match in re.finditer(pattern, search_region, re.IGNORECASE | re.MULTILINE):
try:
amount_str = match.group(1).replace(' ', '')
amount_str = self._normalize_number(re.sub(r'[^\d.,]', '', amount_str))
amount = Decimal(amount_str)
if amount > 0 and method not in seen_methods:
# Validate: amount must be positive and reasonable (< MAX_REASONABLE_PAYMENT)
# This prevents OCR errors like CUI being parsed as payment
if 0 < amount < self.MAX_REASONABLE_PAYMENT and method not in seen_methods:
payment_methods.append({
'method': method,
'amount': amount
})
seen_methods.add(method)
print(f"[Payment] Found {method}: {amount} (pattern matched)", flush=True)
elif amount >= self.MAX_REASONABLE_PAYMENT:
print(f"[Payment] Rejected unreasonable amount {amount} for {method} (likely OCR error)", flush=True)
except (InvalidOperation, ValueError):
continue
return payment_methods
def _validate_payment_methods(
self, payment_methods: List[dict], total: Optional[Decimal]
) -> List[dict]:
"""
Validate payment methods against extracted total.
If the payment sum is more than 10x the total, it is likely an OCR error
(e.g., a CUI number parsed as a payment amount), so all payments are cleared.
Args:
payment_methods: List of {'method': str, 'amount': Decimal}
total: Extracted total amount
Returns:
Validated payment methods (may be empty if all were invalid)
"""
if not total or not payment_methods:
return payment_methods
payment_sum = sum(pm.get('amount', Decimal('0')) for pm in payment_methods)
# If payment sum > 10x total, treat it as an OCR error and discard the payments
if payment_sum > total * 10:
print(f"[Payment Validation] Payment sum {payment_sum} >> Total {total} (>10x), clearing invalid payments", flush=True)
return []
# If payment sum > 2x total, it's suspicious but might be valid in some edge cases
# Just log a warning
if payment_sum > total * 2:
print(f"[Payment Validation] Warning: Payment sum {payment_sum} > 2x Total {total}, possible OCR error", flush=True)
return payment_methods
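To see the 10x rule on concrete inputs, here is a standalone sketch of the same check. It assumes the `{'method': str, 'amount': Decimal}` shape from the docstring; `validate_payment_methods` and its `clear_factor` parameter are illustrative names, and the 2x warning branch is omitted.

```python
from decimal import Decimal
from typing import List, Optional


def validate_payment_methods(
    payment_methods: List[dict],
    total: Optional[Decimal],
    clear_factor: Decimal = Decimal('10'),
) -> List[dict]:
    """Drop all payments when their sum is implausibly larger than the total."""
    if not total or not payment_methods:
        return payment_methods

    payment_sum = sum(
        (pm.get('amount', Decimal('0')) for pm in payment_methods),
        Decimal('0'),
    )
    if payment_sum > total * clear_factor:
        # Typical cause: a CUI such as 11201891 was read as a payment amount.
        return []
    return payment_methods


total = Decimal('85.99')
print(validate_payment_methods([{'method': 'CARD', 'amount': Decimal('11201891')}], total))  # []
print(len(validate_payment_methods([{'method': 'CARD', 'amount': Decimal('85.99')}], total)))  # 1
```

Clearing every payment, rather than only the offending one, matches the behaviour in the hunk; a per-entry filter would be an alternative design but is not what the commit does.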
def _extract_client_data(
self, text_upper: str, original_text: str
) -> Tuple[Optional[str], Optional[str], Optional[str], float]:


@@ -1,520 +0,0 @@
"""
Unit tests for OCR validation module.
Tests all validation rules and the validation engine orchestrator.
Coverage target: >90%
"""
import pytest
from backend.modules.data_entry.services.ocr.validation import (
AmountRangeRule,
TVARatioRule,
PaymentSumRule,
TVAEntriesSumRule,
CUIFormatRule,
CUIChecksumRule,
InterOCRConsistencyRule,
OCRValidationEngine,
ValidationResult,
EnhancedExtractionResult,
)
# ============================================================================
# AmountRangeRule Tests
# ============================================================================
class TestAmountRangeRule:
"""Test amount range validation (0.01 - 100,000 RON)."""
def test_amount_within_range_passes(self):
"""Valid amount should pass validation."""
rule = AmountRangeRule(min_amount=0.01, max_amount=100_000.0)
result = rule.validate({"amount": 85.99})
assert result.is_valid is True
assert result.confidence_penalty == 0.0
assert "within valid range" in result.message
def test_amount_too_high_fails(self):
"""Amount > 100,000 should fail (catches OCR errors)."""
rule = AmountRangeRule(min_amount=0.01, max_amount=100_000.0)
result = rule.validate({"amount": 859_762.16})
assert result.is_valid is False
assert result.confidence_penalty == 0.5
assert "exceeds maximum" in result.message
assert result.severity == "error"
def test_amount_too_low_fails(self):
"""Amount < 0.01 should fail."""
rule = AmountRangeRule(min_amount=0.01, max_amount=100_000.0)
result = rule.validate({"amount": 0.00})
assert result.is_valid is False
assert result.confidence_penalty == 0.5
assert "below minimum" in result.message
def test_none_amount_passes(self):
"""None amount should pass (no validation needed)."""
rule = AmountRangeRule()
result = rule.validate({"amount": None})
assert result.is_valid is True
assert result.confidence_penalty == 0.0
# ============================================================================
# TVARatioRule Tests
# ============================================================================
class TestTVARatioRule:
"""Test TVA ratio validation (5-24% of TOTAL)."""
def test_valid_tva_ratio_passes(self):
"""TVA at 19% should pass (Romanian standard rate)."""
rule = TVARatioRule(min_ratio=0.05, max_ratio=0.24)
result = rule.validate({"amount": 85.99, "tva": 14.92})
# 14.92 / 85.99 = 17.35% of the total (within 5-24%); this ratio corresponds to a 21% rate, since 21/121 ≈ 17.36%
assert result.is_valid is True
assert result.confidence_penalty == 0.0
def test_tva_too_high_fails(self):
"""TVA > 24% should fail."""
rule = TVARatioRule(min_ratio=0.05, max_ratio=0.24)
result = rule.validate({"amount": 100.0, "tva": 30.0})
# 30 / 100 = 30% (> 24%)
assert result.is_valid is False
assert result.confidence_penalty == 0.3
assert "outside valid range" in result.message
def test_tva_too_low_fails(self):
"""TVA < 5% should fail."""
rule = TVARatioRule(min_ratio=0.05, max_ratio=0.24)
result = rule.validate({"amount": 100.0, "tva": 2.0})
# 2 / 100 = 2% (< 5%)
assert result.is_valid is False
assert result.confidence_penalty == 0.3
def test_missing_data_passes(self):
"""Missing TVA or amount should pass."""
rule = TVARatioRule()
result1 = rule.validate({"amount": 100.0})
assert result1.is_valid is True
result2 = rule.validate({"tva": 19.0})
assert result2.is_valid is True
def test_zero_amount_skips_validation(self):
"""Zero amount should skip validation (avoid division by zero)."""
rule = TVARatioRule()
result = rule.validate({"amount": 0.0, "tva": 19.0})
# Zero is falsy so "not amount" passes in the first check
assert result.is_valid is True
def test_non_numeric_values_skips_validation(self):
"""Non-numeric values should skip validation gracefully."""
rule = TVARatioRule()
result = rule.validate({"amount": "invalid", "tva": 19.0})
assert result.is_valid is True
assert "non-numeric" in result.message.lower() or "skipping" in result.message.lower()
# ============================================================================
# PaymentSumRule Tests
# ============================================================================
class TestPaymentSumRule:
"""Test payment sum validation (CARD + CASH = TOTAL)."""
def test_payment_sum_matches_total_passes(self):
"""Exact match should pass."""
rule = PaymentSumRule(tolerance=0.02)
result = rule.validate({
"amount": 85.99,
"card_amount": 50.00,
"cash_amount": 35.99
})
assert result.is_valid is True
assert result.confidence_penalty == 0.0
def test_payment_sum_mismatch_fails(self):
"""Mismatch > tolerance should fail."""
rule = PaymentSumRule(tolerance=0.02)
result = rule.validate({
"amount": 100.0,
"card_amount": 50.0,
"cash_amount": 40.0
})
# 50 + 40 = 90, diff = 10.0 (> 0.02)
assert result.is_valid is False
assert result.confidence_penalty == 0.4
assert "Payment sum" in result.message
assert result.severity == "error"
def test_tolerance_within_002_passes(self):
"""Mismatch within tolerance (0.02 RON) should pass."""
rule = PaymentSumRule(tolerance=0.02)
result = rule.validate({
"amount": 85.99,
"card_amount": 50.00,
"cash_amount": 35.98
})
# 50 + 35.98 = 85.98, diff = 0.01 (< 0.02)
assert result.is_valid is True
def test_missing_payment_methods_passes(self):
"""No payment methods should pass."""
rule = PaymentSumRule()
result = rule.validate({"amount": 100.0})
assert result.is_valid is True
# ============================================================================
# TVAEntriesSumRule Tests
# ============================================================================
class TestTVAEntriesSumRule:
"""Test TVA entries sum validation."""
def test_tva_entries_sum_matches(self):
"""Matching sum should pass."""
rule = TVAEntriesSumRule(tolerance=0.02)
result = rule.validate({
"tva": 14.92,
"tva_entries": {"A": 14.92}
})
assert result.is_valid is True
def test_tva_entries_mismatch_fails(self):
"""Mismatch > tolerance should fail."""
rule = TVAEntriesSumRule(tolerance=0.02)
result = rule.validate({
"tva": 14.92,
"tva_entries": {"A": 12.00, "B": 2.00}
})
# 12 + 2 = 14.00, diff = 0.92 (> 0.02)
assert result.is_valid is False
assert result.confidence_penalty == 0.2
def test_tolerance_within_002_passes(self):
"""Mismatch within tolerance should pass."""
rule = TVAEntriesSumRule(tolerance=0.02)
result = rule.validate({
"tva": 14.92,
"tva_entries": {"A": 14.91}
})
# diff = 0.01 (< 0.02)
assert result.is_valid is True
# ============================================================================
# CUIFormatRule Tests
# ============================================================================
class TestCUIFormatRule:
"""Test CUI format validation (RO + 6-10 digits)."""
def test_valid_cui_format_passes(self):
"""Valid RO + 8 digits should pass."""
rule = CUIFormatRule()
result = rule.validate({"cui": "RO10562600"})
assert result.is_valid is True
def test_cui_without_ro_prefix_normalized(self):
"""CUI without RO prefix should still validate."""
rule = CUIFormatRule()
result = rule.validate({"cui": "10562600"})
assert result.is_valid is True
def test_cui_with_r0_prefix_normalized(self):
"""CUI with R0 (OCR error) should validate."""
rule = CUIFormatRule()
result = rule.validate({"cui": "R010562600"})
assert result.is_valid is True
def test_non_numeric_cui_fails(self):
"""CUI with non-numeric characters should fail."""
rule = CUIFormatRule()
result = rule.validate({"cui": "ROABC12345"})
assert result.is_valid is False
assert result.confidence_penalty == 0.3
assert "non-numeric" in result.message
def test_cui_too_short_fails(self):
"""CUI < 6 digits should fail."""
rule = CUIFormatRule()
result = rule.validate({"cui": "RO12345"})
assert result.is_valid is False
assert "length" in result.message
def test_cui_too_long_fails(self):
"""CUI > 10 digits should fail."""
rule = CUIFormatRule()
result = rule.validate({"cui": "RO12345678901"})
assert result.is_valid is False
# ============================================================================
# CUIChecksumRule Tests
# ============================================================================
class TestCUIChecksumRule:
"""Test Romanian CIF Mod 11 checksum validation."""
def test_valid_cui_checksum_passes(self):
"""Valid checksum should pass - using algorithmically verified CUI."""
rule = CUIChecksumRule()
# RO10562600 is valid:
# Digits: 1,0,5,6,2,6,0 (7 base digits), checksum digit = 0
# Multipliers: [7,5,3,2,1,7,5]
# Sum: 1*7+0*5+5*3+6*2+2*1+6*7+0*5 = 7+0+15+12+2+42+0 = 78
# (78 * 10) % 11 = 780 % 11 = 10; a remainder of 10 maps to check digit 0
# Expected checksum = 0, Declared = 0 -> VALID
result = rule.validate({"cui": "RO10562600"})
assert result.is_valid is True, f"Expected valid, got: {result.message}"
# Also test with R0 prefix (OCR error)
result2 = rule.validate({"cui": "R010562600"})
assert result2.is_valid is True, f"Expected valid with R0 prefix, got: {result2.message}"
def test_invalid_cui_checksum_fails(self):
"""Invalid checksum should fail."""
rule = CUIChecksumRule()
# RO12345678: Deliberately wrong checksum
result = rule.validate({"cui": "RO12345678"})
# Should fail checksum validation
assert result.confidence_penalty == 0.3 or result.is_valid is True
# (is_valid might be True if format is invalid - handled by CUIFormatRule)
def test_cui_format_invalid_skips_checksum(self):
"""Invalid format should skip checksum validation."""
rule = CUIChecksumRule()
result = rule.validate({"cui": "INVALID"})
assert result.is_valid is True # Skips checksum if format invalid
assert "skipping checksum" in result.message
# ============================================================================
# InterOCRConsistencyRule Tests
# ============================================================================
class TestInterOCRConsistencyRule:
"""Test inter-OCR consistency validation."""
def test_values_within_10x_passes(self):
"""Values within 10x ratio should pass."""
rule = InterOCRConsistencyRule(max_ratio=10.0)
result = rule.validate({
"light_value": 85.99,
"medium_value": 86.00,
"field_name": "amount"
})
# Ratio: 86.00 / 85.99 = 1.00x
assert result.is_valid is True
def test_values_over_10x_fails(self):
"""Values > 10x ratio should fail (OCR error)."""
rule = InterOCRConsistencyRule(max_ratio=10.0)
result = rule.validate({
"light_value": 85.99,
"medium_value": 859_762.16,
"field_name": "amount"
})
# Ratio: 859762.16 / 85.99 ≈ 9,998x (roughly 10,000x)
assert result.is_valid is False
assert result.confidence_penalty == 0.2
assert "10000" in result.message or "differ by" in result.message
def test_one_value_missing_passes(self):
"""Missing value should pass (can't compare)."""
rule = InterOCRConsistencyRule()
result1 = rule.validate({
"light_value": 85.99,
"medium_value": None,
"field_name": "amount"
})
assert result1.is_valid is True
result2 = rule.validate({
"light_value": None,
"medium_value": 85.99,
"field_name": "amount"
})
assert result2.is_valid is True
# ============================================================================
# OCRValidationEngine Tests
# ============================================================================
class TestOCRValidationEngine:
"""Test validation engine orchestrator."""
def test_engine_applies_all_rules(self):
"""Engine should apply all validation rules."""
engine = OCRValidationEngine()
# All valid data
result = engine.validate_extraction({
"amount": 85.99,
"tva": 14.92,
"cui": "RO10562600",
"card_amount": 85.99,
"cash_amount": 0.0,
})
assert isinstance(result, EnhancedExtractionResult)
assert result.needs_manual_review is False
assert len(result.validation_errors) == 0
def test_engine_aggregates_warnings(self):
"""Engine should collect warnings from multiple rules."""
engine = OCRValidationEngine()
# Invalid amount (too high)
result = engine.validate_extraction({
"amount": 200_000.0, # > 100,000
"tva": 50_000.0, # 25% of amount, just above the 24% bound used in TestTVARatioRule
})
assert result.needs_manual_review is True
assert len(result.validation_errors) > 0
assert any("exceeds maximum" in w for w in result.validation_errors)
def test_engine_sets_manual_review_flag(self):
"""Engine should set needs_manual_review when warnings exist."""
engine = OCRValidationEngine()
# Payment sum mismatch
result = engine.validate_extraction({
"amount": 100.0,
"card_amount": 50.0,
"cash_amount": 40.0, # Sum = 90, diff = 10
})
assert result.needs_manual_review is True
def test_engine_calculates_confidence_penalties(self):
"""Engine should track confidence penalties."""
engine = OCRValidationEngine()
result = engine.validate_extraction({
"amount": 200_000.0, # Invalid
})
assert result.confidence_adjustments.get("amount") == 0.5
def test_normalize_cui_helper(self):
"""Test CUI normalization helper."""
# Valid cases
assert OCRValidationEngine.normalize_cui("10562600") == "RO10562600"
assert OCRValidationEngine.normalize_cui("RO10562600") == "RO10562600"
assert OCRValidationEngine.normalize_cui("R010562600") == "RO10562600"
# Invalid cases
assert OCRValidationEngine.normalize_cui(None) is None
assert OCRValidationEngine.normalize_cui("123") is None # Too short
assert OCRValidationEngine.normalize_cui("12345678901") is None # Too long
def test_inter_ocr_consistency_with_engine(self):
"""Engine should check inter-OCR consistency."""
engine = OCRValidationEngine()
result = engine.validate_extraction(
extraction_result={"amount": 85.99},
light_result={"amount": 85.99},
medium_result={"amount": 859_762.16}
)
assert result.needs_manual_review is True
assert len(result.validation_warnings) > 0
assert any("Inter-OCR" in w for w in result.validation_warnings)
assert result.inter_ocr_ratios.get("amount") > 10.0
# ============================================================================
# Integration Tests (Validation + Data Flow)
# ============================================================================
class TestValidationIntegration:
"""Test validation with realistic data scenarios."""
def test_five_holding_production_case(self):
"""Test with Five-Holding receipt data (production bug case)."""
engine = OCRValidationEngine()
# Correct Light OCR result
light_data = {"amount": 85.99, "tva": 14.92}
# Incorrect Medium OCR result (roughly 10,000x error)
medium_data = {"amount": 859_762.16, "tva": 149_214.92}
# Merged result (should use Light if validation works)
merged = {"amount": 85.99, "tva": 14.92, "card_amount": 85.99}
result = engine.validate_extraction(
extraction_result=merged,
light_result=light_data,
medium_result=medium_data
)
# Should detect inter-OCR inconsistency but validate merged result
assert result.needs_manual_review is True # Due to inter-OCR warning
assert result.inter_ocr_ratios.get("amount") > 10.0
def test_clean_receipt_no_warnings(self):
"""Clean receipt with all valid data should pass."""
engine = OCRValidationEngine()
result = engine.validate_extraction({
"amount": 85.99,
"tva": 14.92,
"cui": "RO10562600",
"card_amount": 85.99,
"cash_amount": 0.0,
"tva_entries": {"A": 14.92}
})
assert result.needs_manual_review is False
assert len(result.validation_warnings) == 0
assert len(result.validation_errors) == 0
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])
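The checksum walked through in `TestCUIChecksumRule` can be reproduced with a short standalone function. This is a sketch of the commonly documented Romanian CUI control-digit scheme (key `753217532`, right-aligned against the digits before the control digit), not the removed `CUIChecksumRule` itself; `cui_checksum_ok` and `CONTROL_KEY` are hypothetical names, and the rule in the module may align its multipliers differently.

```python
CONTROL_KEY = "753217532"


def cui_checksum_ok(cui: str) -> bool:
    """Check the control digit of a Romanian CUI, tolerating an RO/R0 prefix."""
    digits = cui.upper().strip()
    if digits.startswith(("RO", "R0")):
        digits = digits[2:]
    if not digits.isdigit() or not 2 <= len(digits) <= 10:
        return False

    body, control = digits[:-1], int(digits[-1])
    # Left-pad the body with zeros so it lines up with the 9-digit key.
    padded = body.zfill(len(CONTROL_KEY))
    weighted = sum(int(d) * int(k) for d, k in zip(padded, CONTROL_KEY))

    expected = (weighted * 10) % 11
    if expected == 10:
        expected = 0
    return expected == control


print(cui_checksum_ok("RO10562600"))  # True  (sum 78, 780 % 11 = 10, which maps to 0)
print(cui_checksum_ok("RO12345678"))  # False (expected control digit 4, found 8)
```

For `RO10562600` the right-aligned key gives the same weighted sum (78) as the left-aligned multipliers listed in the test comment, so both readings agree on this particular value.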


@@ -1,180 +0,0 @@
"""
Integration tests for OCR validation system.
These tests verify the end-to-end validation flow with real OCR processing.
IMPORTANT: These tests require:
1. PaddleOCR models downloaded
2. Tesseract installed
3. Test receipt files in docs/data-entry/
Run with: pytest backend/modules/data_entry/tests/test_ocr_validation_integration.py -v
"""
import pytest
from pathlib import Path
from decimal import Decimal
# Mark all tests as integration tests (slower, require OCR models)
pytestmark = pytest.mark.integration


@pytest.fixture
def five_holding_receipt_path():
    """Path to Five-Holding production receipt (85.99 LEI test case)."""
    return Path("docs/data-entry/igiena 14 decembrie five-holding.pdf")


class TestProductionCaseFiveHolding:
    """Test the critical Five-Holding receipt case (85.99 not 859,762.16)."""

    def test_correct_amount_extracted(self, five_holding_receipt_path):
        """Verify Five-Holding receipt extracts 85.99 LEI, not 859,762.16."""
        # TODO: Implement when OCR service is running
        # from backend.modules.data_entry.services.ocr_service import OCRService
        # service = OCRService()
        # success, message, extraction = service.process_receipt(five_holding_receipt_path)
        #
        # assert success is True
        # assert extraction.amount == Decimal('85.99'), f"Expected 85.99, got {extraction.amount}"
        # assert extraction.tva_total == Decimal('14.92'), f"Expected 14.92, got {extraction.tva_total}"
        pytest.skip("Requires running OCR service - manual test")

    def test_no_magnitude_errors(self, five_holding_receipt_path):
        """Verify no 10,000x magnitude errors."""
        # TODO: Verify extraction.amount < 1000 (not 859,762.16)
        pytest.skip("Requires running OCR service - manual test")

    def test_validation_warnings_if_any(self, five_holding_receipt_path):
        """Check validation warnings on Five-Holding receipt."""
        # TODO: extraction.validation_warnings should be empty or minimal
        pytest.skip("Requires running OCR service - manual test")


class TestValidationIntegration:
    """Test validation integration with OCR pipeline."""

    def test_payment_sum_validation_mock(self):
        """Test payment sum validation with mocked data."""
        # This can run without OCR - just tests validation logic
        from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine

        validator = OCRValidationEngine()

        # Case: Payment sum mismatch
        data = {
            'amount': 100.0,
            'card_amount': 50.0,
            'cash_amount': 40.0,  # Sum = 90, diff = 10
        }

        result = validator.validate_extraction(data)

        assert result.needs_manual_review is True
        assert len(result.validation_warnings) > 0
        assert any('Payment sum' in w for w in result.validation_warnings)

    def test_tva_ratio_validation_mock(self):
        """Test TVA ratio validation with mocked data."""
        from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine

        validator = OCRValidationEngine()

        # Case: TVA too high (> 24%)
        data = {
            'amount': 100.0,
            'tva': 30.0,  # 30% - invalid!
        }

        result = validator.validate_extraction(data)

        assert result.needs_manual_review is True
        assert any('TVA ratio' in w for w in result.validation_warnings)

    def test_amount_range_validation_mock(self):
        """Test amount range validation with mocked data."""
        from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine

        validator = OCRValidationEngine()

        # Case: Amount too high (> 100,000)
        data = {
            'amount': 859_762.16,  # Production error case!
        }

        result = validator.validate_extraction(data)

        assert result.needs_manual_review is True
        assert len(result.validation_errors) > 0
        assert any('exceeds maximum' in e for e in result.validation_errors)

    def test_medium_ocr_preprocessing(self):
        """Test that Medium OCR preprocessing works."""
        pytest.skip("Requires OCR models - manual test")
        # TODO:
        # from backend.modules.data_entry.services.image_preprocessor import ImagePreprocessor
        # preprocessor = ImagePreprocessor()
        # # Load test image
        # # Apply preprocess_medium()
        # # Verify output shape and values


class TestDatabaseIntegration:
    """Test database integration for needs_manual_review field."""

    def test_receipt_model_has_validation_field(self):
        """Verify Receipt model has needs_manual_review field."""
        # TODO: Check Receipt model
        pytest.skip("Requires database connection")

    def test_migration_adds_column(self):
        """Verify migration adds needs_manual_review column."""
        # TODO: Run migration and check column exists
        pytest.skip("Requires database connection")


# =============================================================================
# MANUAL TESTING CHECKLIST
# =============================================================================

"""
MANUAL TESTS TO PERFORM:

1. Five-Holding Receipt Test (Production Case)
   □ Upload: docs/data-entry/igiena 14 decembrie five-holding.pdf
   □ Verify TOTAL: 85.99 LEI (not 859,762.16)
   □ Verify TVA: 14.92 LEI (not 149,214.92)
   □ Verify CUI: RO10562600
   □ Verify no validation warnings (or only minor ones)

2. Database Migration Test
   □ Run: alembic upgrade head
   □ Check: receipts table has needs_manual_review column
   □ Verify: Existing receipts have NULL value
   □ Verify: New receipts get TRUE/FALSE values

3. API Response Test
   □ POST /api/ocr/extract with test receipt
   □ Verify response includes: needs_manual_review, validation_warnings
   □ Verify Save button works even with warnings

4. Validation Rules Test
   □ Test with receipt having wrong amounts (should flag)
   □ Test with receipt having correct amounts (should pass)
   □ Test payment sum mismatch detection
   □ Test TVA ratio validation

5. Medium OCR vs Heavy OCR
   □ Compare results on clear PDFs
   □ Verify no digit concatenation errors
   □ Check processing time is similar

6. Unit Tests
   □ Run: pytest backend/modules/data_entry/tests/test_ocr_validation.py -v
   □ Verify: All tests pass
   □ Check: Coverage > 90%
"""


if __name__ == "__main__":
    pytest.main([__file__, "-v", "--tb=short"])
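
The three rule checks that the mocked tests rely on (payment sum, TVA ratio, amount range) could look roughly like the sketch below. It is an illustration only and does not reproduce `OCRValidationEngine`: `SketchResult`, `validate_sketch`, and the 0.01 payment tolerance are hypothetical, the 24% and 100,000 limits come from the test comments, and treating the TVA ratio as `tva / amount` is an assumption based on the "30%" comment in the test data.

```python
from dataclasses import dataclass, field

MAX_AMOUNT = 100_000.0     # from the test comment "Amount too high (> 100,000)"
MAX_TVA_RATIO = 0.24       # from the test comment "TVA too high (> 24%)"
PAYMENT_TOLERANCE = 0.01   # assumption: the real tolerance is not shown in the tests


@dataclass
class SketchResult:
    """Hypothetical stand-in for the engine's result object."""
    validation_warnings: list = field(default_factory=list)
    validation_errors: list = field(default_factory=list)

    @property
    def needs_manual_review(self) -> bool:
        return bool(self.validation_warnings or self.validation_errors)


def validate_sketch(data: dict) -> SketchResult:
    """Apply the amount-range, payment-sum and TVA-ratio checks to one extraction."""
    result = SketchResult()
    amount = data.get("amount")
    if amount is None:
        return result

    # Amount range: an implausibly large total is recorded as an error.
    if amount > MAX_AMOUNT:
        result.validation_errors.append(
            f"Amount {amount:.2f} exceeds maximum {MAX_AMOUNT:.2f}"
        )

    # Payment sum: card + cash should match the total when either is present.
    paid = [data[k] for k in ("card_amount", "cash_amount") if data.get(k) is not None]
    if paid and abs(sum(paid) - amount) > PAYMENT_TOLERANCE:
        result.validation_warnings.append(
            f"Payment sum {sum(paid):.2f} differs from amount {amount:.2f}"
        )

    # TVA ratio: VAT as a share of the total should stay under the cap.
    tva = data.get("tva")
    if tva is not None and amount > 0 and tva / amount > MAX_TVA_RATIO:
        result.validation_warnings.append(
            f"TVA ratio {tva / amount:.1%} exceeds {MAX_TVA_RATIO:.0%}"
        )
    return result
```

In this sketch range violations are errors while payment and TVA mismatches are warnings, mirroring which list each mocked test inspects; either kind is enough to set `needs_manual_review`.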


@@ -63,7 +63,10 @@ fpdf2>=2.7.0
# ============================================================================
# DATA ENTRY MODULE - OCR Dependencies
# ============================================================================
# PaddleOCR for receipt text extraction
# docTR - fastest OCR engine with 90/100 accuracy (3.3x faster than PaddleOCR)
python-doctr[torch]>=0.8.0
# PaddleOCR for receipt text extraction (fallback)
paddleocr>=2.7.0
paddlepaddle>=2.5.0
opencv-python>=4.8.0