feat(ocr): Add docTR OCR engine with metrics infrastructure
Add docTR as primary OCR engine with 2-tier sequential processing, OCR metrics tracking, and simplified engine selection. Features: - docTR OCR engine with light+medium preprocessing tiers - doctr_plus mode with early exit optimization (~65% fast path) - OCR metrics dashboard with per-engine statistics - User OCR preference persistence - Parallel worker pool for OCR processing - Cross-validation for extraction quality Engine options: tesseract, doctr, doctr_plus (recommended), paddleocr 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -104,6 +104,34 @@ MAX_UPLOAD_SIZE_MB=10
|
||||
TEST_COMPANY_ID=110
|
||||
TEST_COMPANY_SCHEMA=MARIUSM_AUTO
|
||||
|
||||
# ============================================================================
|
||||
# OCR ENGINE CONFIGURATION
|
||||
# ============================================================================
|
||||
# Control which OCR engines are loaded at startup.
|
||||
# Disabling engines saves memory but limits available OCR modes.
|
||||
|
||||
# Enable/disable PaddleOCR (set to 'false' to save ~800MB RAM)
|
||||
# When disabled: 'paddleocr' engine unavailable
|
||||
OCR_ENABLE_PADDLEOCR=true
|
||||
|
||||
# Enable/disable Tesseract (set to 'false' to save ~50MB RAM)
|
||||
# When disabled: 'tesseract' engine unavailable
|
||||
OCR_ENABLE_TESSERACT=true
|
||||
|
||||
# Default OCR engine when not specified in request
|
||||
# Options: tesseract, doctr, doctr_plus, paddleocr
|
||||
# Recommended: doctr_plus (2-tier sequential with early exit)
|
||||
OCR_DEFAULT_ENGINE=doctr_plus
|
||||
|
||||
# OCR Worker Pool Configuration
|
||||
# Number of parallel OCR workers (each loads ~1GB for docTR)
|
||||
# Recommended: 2 for 8GB RAM, 3 for 16GB RAM
|
||||
OCR_WORKERS=2
|
||||
|
||||
# Max tasks per worker before restart (0 = no restart, saves 40-60s warmup time)
|
||||
# Set to 0 for testing, 10-20 for production (prevents memory leaks)
|
||||
OCR_MAX_TASKS_PER_CHILD=0
|
||||
|
||||
# ============================================================================
|
||||
# TELEGRAM MODULE - BOT CONFIGURATION (REQUIRED for Telegram features)
|
||||
# ============================================================================
|
||||
|
||||
@@ -112,6 +112,40 @@ DATA_ENTRY_SQLITE_DATABASE_PATH=data/receipts/receipts.db
|
||||
DATA_ENTRY_UPLOAD_PATH=data/receipts/uploads
|
||||
DATA_ENTRY_MAX_UPLOAD_SIZE_MB=10
|
||||
|
||||
# ============================================================================
|
||||
# OCR ENGINE CONFIGURATION
|
||||
# ============================================================================
|
||||
# Control which OCR engines are loaded at startup.
|
||||
# Disabling engines saves memory but limits available OCR modes.
|
||||
|
||||
# Enable/disable PaddleOCR (set to 'false' to save ~800MB RAM)
|
||||
# When disabled: 'paddleocr' engine unavailable
|
||||
OCR_ENABLE_PADDLEOCR=true
|
||||
|
||||
# Enable/disable Tesseract (set to 'false' to save ~50MB RAM)
|
||||
# When disabled: 'tesseract' engine unavailable
|
||||
OCR_ENABLE_TESSERACT=true
|
||||
|
||||
# Default OCR engine when not specified in request
|
||||
# Options: tesseract, doctr, doctr_plus, paddleocr
|
||||
# Recommended: doctr_plus (2-tier sequential with early exit, ~7.5s avg)
|
||||
OCR_DEFAULT_ENGINE=doctr_plus
|
||||
|
||||
# Active OCR engines shown in frontend dropdown (comma-separated)
|
||||
# Options: tesseract, doctr, doctr_plus, paddleocr
|
||||
# doctr_plus: 73.3% perfect, 7.5s avg, 65% fast path (recommended)
|
||||
# doctr: 63.3% perfect, simpler but faster
|
||||
OCR_ACTIVE_ENGINES=tesseract,doctr,doctr_plus,paddleocr
|
||||
|
||||
# OCR Worker Pool Configuration
|
||||
# Number of parallel OCR workers (each loads ~1GB for docTR)
|
||||
# Recommended: 2 for 8GB RAM, 3 for 16GB RAM
|
||||
OCR_WORKERS=2
|
||||
|
||||
# Max tasks per worker before restart (0 = no restart, saves 40-60s warmup time)
|
||||
# Set to 0 for testing, 10-20 for production (prevents memory leaks)
|
||||
OCR_MAX_TASKS_PER_CHILD=0
|
||||
|
||||
# ============================================================================
|
||||
# TELEGRAM MODULE - BOT CONFIGURATION (REQUIRED for Telegram features)
|
||||
# ============================================================================
|
||||
|
||||
@@ -96,6 +96,35 @@ SQLITE_DATABASE_PATH=data/receipts/receipts_prod.db
|
||||
UPLOAD_PATH=data/receipts/uploads
|
||||
MAX_UPLOAD_SIZE_MB=10
|
||||
|
||||
# ============================================================================
|
||||
# OCR ENGINE CONFIGURATION
|
||||
# ============================================================================
|
||||
# Control which OCR engines are loaded at startup.
|
||||
# Disabling engines saves memory but limits available OCR modes.
|
||||
|
||||
# Enable/disable PaddleOCR (set to 'false' to save ~800MB RAM)
|
||||
# When disabled: 'paddleocr' engine unavailable
|
||||
# PRODUCTION: Set based on server memory availability
|
||||
OCR_ENABLE_PADDLEOCR=false
|
||||
|
||||
# Enable/disable Tesseract (set to 'false' to save ~50MB RAM)
|
||||
# When disabled: 'tesseract' engine unavailable
|
||||
OCR_ENABLE_TESSERACT=true
|
||||
|
||||
# Default OCR engine when not specified in request
|
||||
# Options: tesseract, doctr, doctr_plus, paddleocr
|
||||
# Recommended: doctr_plus (2-tier sequential with early exit)
|
||||
OCR_DEFAULT_ENGINE=doctr_plus
|
||||
|
||||
# OCR Worker Pool Configuration
|
||||
# Number of parallel OCR workers (each loads ~1GB for docTR)
|
||||
# Recommended: 2 for 8GB RAM, 3 for 16GB RAM
|
||||
OCR_WORKERS=2
|
||||
|
||||
# Max tasks per worker before restart (0 = no restart, saves 40-60s warmup time)
|
||||
# Set to 0 for testing, 10-20 for production (prevents memory leaks)
|
||||
OCR_MAX_TASKS_PER_CHILD=0
|
||||
|
||||
# ============================================================================
|
||||
# TELEGRAM MODULE - BOT CONFIGURATION (REQUIRED for Telegram features)
|
||||
# ============================================================================
|
||||
|
||||
@@ -105,6 +105,34 @@ MAX_UPLOAD_SIZE_MB=10
|
||||
TEST_COMPANY_ID=110
|
||||
TEST_COMPANY_SCHEMA=MARIUSM_AUTO
|
||||
|
||||
# ============================================================================
|
||||
# OCR ENGINE CONFIGURATION
|
||||
# ============================================================================
|
||||
# Control which OCR engines are loaded at startup.
|
||||
# Disabling engines saves memory but limits available OCR modes.
|
||||
|
||||
# Enable/disable PaddleOCR (set to 'false' to save ~800MB RAM)
|
||||
# When disabled: 'paddleocr' engine unavailable
|
||||
OCR_ENABLE_PADDLEOCR=false
|
||||
|
||||
# Enable/disable Tesseract (set to 'false' to save ~50MB RAM)
|
||||
# When disabled: 'tesseract' engine unavailable
|
||||
OCR_ENABLE_TESSERACT=true
|
||||
|
||||
# Default OCR engine when not specified in request
|
||||
# Options: tesseract, doctr, doctr_plus, paddleocr
|
||||
# Recommended: doctr_plus (2-tier sequential with early exit)
|
||||
OCR_DEFAULT_ENGINE=doctr_plus
|
||||
|
||||
# OCR Worker Pool Configuration
|
||||
# Number of parallel OCR workers (each loads ~1GB for docTR)
|
||||
# Recommended: 2 for 8GB RAM, 3 for 16GB RAM
|
||||
OCR_WORKERS=2
|
||||
|
||||
# Max tasks per worker before restart (0 = no restart, saves 40-60s warmup time)
|
||||
# Set to 0 for testing, 10-20 for production (prevents memory leaks)
|
||||
OCR_MAX_TASKS_PER_CHILD=0
|
||||
|
||||
# ============================================================================
|
||||
# TELEGRAM MODULE - BOT CONFIGURATION (REQUIRED for Telegram features)
|
||||
# ============================================================================
|
||||
|
||||
168
backend/TEST-OCR-WINDOWS.bat
Normal file
168
backend/TEST-OCR-WINDOWS.bat
Normal file
@@ -0,0 +1,168 @@
|
||||
@echo off
|
||||
setlocal enabledelayedexpansion
|
||||
|
||||
cd /d "%~dp0"
|
||||
|
||||
REM Parse command line arguments for worker counts
|
||||
REM Usage: TEST-OCR-WINDOWS.bat [worker_counts...]
|
||||
REM Examples:
|
||||
REM TEST-OCR-WINDOWS.bat -> tests 1,2,3 workers (default)
|
||||
REM TEST-OCR-WINDOWS.bat 1 -> tests only 1 worker
|
||||
REM TEST-OCR-WINDOWS.bat 3 6 -> tests 3 and 6 workers
|
||||
REM TEST-OCR-WINDOWS.bat 1 2 3 4 5 6 -> tests all
|
||||
|
||||
set "WORKER_LIST=%*"
|
||||
if "%WORKER_LIST%"=="" set "WORKER_LIST=1 2 3"
|
||||
|
||||
echo.
|
||||
echo ==========================================
|
||||
echo OCR Benchmark - Windows (Workers: %WORKER_LIST%)
|
||||
echo ==========================================
|
||||
echo.
|
||||
|
||||
REM Check if Poppler is installed
|
||||
where pdftoppm >nul 2>&1
|
||||
if errorlevel 1 (
|
||||
echo Checking for Poppler...
|
||||
if exist "E:\poppler" (
|
||||
for /r "E:\poppler" %%i in (pdftoppm.exe) do (
|
||||
set "POPPLER_BIN=%%~dpi"
|
||||
goto :found_poppler
|
||||
)
|
||||
)
|
||||
echo.
|
||||
echo ERROR: Poppler not found!
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
:found_poppler
|
||||
if defined POPPLER_BIN (
|
||||
echo Found Poppler at: %POPPLER_BIN%
|
||||
set "PATH=%POPPLER_BIN%;%PATH%"
|
||||
)
|
||||
|
||||
REM Check venv
|
||||
if not exist "venv-win\Scripts\python.exe" (
|
||||
echo ERROR: venv-win not found!
|
||||
echo Run: python -m venv venv-win
|
||||
echo Then: venv-win\Scripts\pip install -r requirements.txt
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
REM Set common environment
|
||||
set JWT_SECRET_KEY=generate_with_secrets_token_urlsafe_32
|
||||
set ORACLE_HOST=10.0.20.121
|
||||
set ORACLE_PORT=1521
|
||||
set ORACLE_USER=CONTAFIN_ORACLE
|
||||
set ORACLE_PASSWORD=ROMFASTSOFT
|
||||
set ORACLE_SERVICE_NAME=ROA
|
||||
set OCR_ENABLE_PADDLEOCR=false
|
||||
set OCR_ENABLE_TESSERACT=false
|
||||
set OCR_DEFAULT_ENGINE=hybrid-doctr
|
||||
set OCR_MAX_TASKS_PER_CHILD=0
|
||||
set LOG_LEVEL=WARNING
|
||||
|
||||
REM Results file with timestamp
|
||||
for /f "tokens=2 delims==" %%I in ('wmic os get localdatetime /value') do set datetime=%%I
|
||||
set RESULTS_FILE=ocr_benchmark_%datetime:~0,8%_%datetime:~8,4%.json
|
||||
|
||||
echo Results will be saved to: %RESULTS_FILE%
|
||||
echo.
|
||||
|
||||
REM Delete old results file if exists
|
||||
if exist "%RESULTS_FILE%" del "%RESULTS_FILE%"
|
||||
|
||||
REM Run tests with specified workers
|
||||
for %%W in (%WORKER_LIST%) do (
|
||||
call :run_test %%W
|
||||
)
|
||||
|
||||
goto :show_summary
|
||||
|
||||
:run_test
|
||||
set WORKERS=%1
|
||||
echo.
|
||||
echo ############################################################
|
||||
echo STARTING TEST WITH %WORKERS% WORKER(S)
|
||||
echo ############################################################
|
||||
echo.
|
||||
|
||||
REM Kill existing processes on port 8006
|
||||
echo Cleaning up old processes...
|
||||
for /f "tokens=5" %%a in ('netstat -ano ^| findstr :8006 ^| findstr LISTENING 2^>nul') do (
|
||||
taskkill /F /PID %%a >nul 2>&1
|
||||
)
|
||||
taskkill /F /FI "WINDOWTITLE eq ROA2WEB Backend*" >nul 2>&1
|
||||
timeout /t 3 >nul
|
||||
|
||||
REM Set workers count
|
||||
set OCR_WORKERS=%WORKERS%
|
||||
|
||||
echo Starting backend with %WORKERS% OCR worker(s)...
|
||||
|
||||
REM Start backend in a new minimized window with all OCR env vars
|
||||
start /min "ROA2WEB Backend %WORKERS% workers" cmd /c "set OCR_WORKERS=%WORKERS%&& set OCR_ENABLE_PADDLEOCR=false&& set OCR_ENABLE_TESSERACT=false&& set OCR_DEFAULT_ENGINE=hybrid-doctr&& set LOG_LEVEL=WARNING&& venv-win\Scripts\python.exe -m uvicorn main:app --host 0.0.0.0 --port 8006 --workers 1 2>&1"
|
||||
|
||||
REM Wait for backend to be ready
|
||||
echo Waiting for backend to start...
|
||||
set attempts=0
|
||||
:wait_loop
|
||||
timeout /t 3 >nul
|
||||
set /a attempts+=1
|
||||
curl -s http://localhost:8006/health >nul 2>&1
|
||||
if errorlevel 1 (
|
||||
if !attempts! lss 40 (
|
||||
echo Waiting... !attempts!/40
|
||||
goto :wait_loop
|
||||
)
|
||||
echo ERROR: Backend failed to start!
|
||||
goto :eof
|
||||
)
|
||||
|
||||
echo Backend is ready!
|
||||
|
||||
REM Wait for OCR warmup
|
||||
echo Waiting for OCR worker warmup (30s)...
|
||||
timeout /t 30 >nul
|
||||
|
||||
echo.
|
||||
echo Running OCR test with %WORKERS% worker(s)...
|
||||
echo.
|
||||
|
||||
venv-win\Scripts\python.exe ..\tests\ocr-validation\test_receipts_parallel_windows.py --port 8006 --workers %WORKERS% --output %RESULTS_FILE%
|
||||
|
||||
REM Stop backend
|
||||
echo.
|
||||
echo Stopping backend...
|
||||
taskkill /F /FI "WINDOWTITLE eq ROA2WEB Backend*" >nul 2>&1
|
||||
for /f "tokens=5" %%a in ('netstat -ano ^| findstr :8006 ^| findstr LISTENING 2^>nul') do (
|
||||
taskkill /F /PID %%a >nul 2>&1
|
||||
)
|
||||
|
||||
REM Wait for memory to be released
|
||||
echo Releasing memory (10s)...
|
||||
timeout /t 10 >nul
|
||||
goto :eof
|
||||
|
||||
:show_summary
|
||||
echo.
|
||||
echo ############################################################
|
||||
echo ALL TESTS COMPLETE
|
||||
echo ############################################################
|
||||
echo.
|
||||
echo Results saved to: %RESULTS_FILE%
|
||||
echo.
|
||||
|
||||
REM Show summary from results file
|
||||
if exist "%RESULTS_FILE%" (
|
||||
echo BENCHMARK SUMMARY:
|
||||
echo ------------------
|
||||
venv-win\Scripts\python.exe -c "import json; data=json.load(open('%RESULTS_FILE%')); print(); [print(f\" {r['workers']} worker(s): {r['total_time']:.1f}s total, {r['avg_time']:.1f}s avg, {r.get('peak_memory_mb', 0):.0f}MB peak, {r['successful']}/{r['submitted']} success\") for r in data]"
|
||||
echo.
|
||||
)
|
||||
|
||||
echo Press any key to exit...
|
||||
pause >nul
|
||||
|
||||
endlocal
|
||||
@@ -38,12 +38,20 @@ from backend.modules.reports.routers import create_reports_router
|
||||
from backend.modules.data_entry.routers import create_data_entry_router
|
||||
from backend.modules.telegram.routers import create_telegram_router
|
||||
|
||||
# Configure logging
|
||||
# Configure logging (level from env: DEBUG, INFO, WARNING, ERROR)
|
||||
log_level = os.getenv('LOG_LEVEL', 'INFO').upper()
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
level=getattr(logging, log_level, logging.INFO),
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
datefmt='%H:%M:%S'
|
||||
)
|
||||
# Reduce noise from third-party libraries
|
||||
logging.getLogger('httpcore').setLevel(logging.WARNING)
|
||||
logging.getLogger('httpx').setLevel(logging.WARNING)
|
||||
logging.getLogger('multipart').setLevel(logging.WARNING)
|
||||
logging.getLogger('doctr').setLevel(logging.WARNING)
|
||||
logging.getLogger('tensorflow').setLevel(logging.WARNING)
|
||||
logging.getLogger('PIL').setLevel(logging.WARNING)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global variables for background tasks
|
||||
|
||||
@@ -48,6 +48,11 @@ class Settings(BaseSettings):
|
||||
# CORS
|
||||
cors_origins: str = "http://localhost:3010,http://localhost:3000"
|
||||
|
||||
# OCR Engines (comma-separated list of active engines shown in UI)
|
||||
# Available: tesseract, paddleocr, doctr, doctr_plus
|
||||
# doctr_plus is recommended (2-tier sequential with early exit)
|
||||
ocr_active_engines: str = "doctr,doctr_plus"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
@@ -80,6 +85,11 @@ class Settings(BaseSettings):
|
||||
"""Get CORS origins as list."""
|
||||
return [origin.strip() for origin in self.cors_origins.split(",")]
|
||||
|
||||
@property
|
||||
def ocr_active_engines_list(self) -> List[str]:
|
||||
"""Get OCR active engines as list."""
|
||||
return [engine.strip() for engine in self.ocr_active_engines.split(",")]
|
||||
|
||||
@property
|
||||
def oracle_dsn(self) -> str:
|
||||
"""Get Oracle DSN string."""
|
||||
|
||||
@@ -2,9 +2,12 @@
|
||||
from .receipt import ReceiptCRUD
|
||||
from .attachment import AttachmentCRUD
|
||||
from .accounting_entry import AccountingEntryCRUD
|
||||
from .ocr_settings import OCRPreferenceCRUD, OCRMetricsCRUD
|
||||
|
||||
__all__ = [
|
||||
"ReceiptCRUD",
|
||||
"AttachmentCRUD",
|
||||
"AccountingEntryCRUD",
|
||||
"OCRPreferenceCRUD",
|
||||
"OCRMetricsCRUD",
|
||||
]
|
||||
|
||||
222
backend/modules/data_entry/db/crud/ocr_settings.py
Normal file
222
backend/modules/data_entry/db/crud/ocr_settings.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""CRUD operations for OCR settings and metrics."""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import func, select, and_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from backend.modules.data_entry.db.models.ocr_settings import (
|
||||
UserOCRPreference,
|
||||
OCRJobMetrics,
|
||||
OCRMetricsSummary,
|
||||
OCREngine,
|
||||
)
|
||||
|
||||
|
||||
class OCRPreferenceCRUD:
|
||||
"""CRUD operations for user OCR preferences."""
|
||||
|
||||
@staticmethod
|
||||
async def get_by_username(session: AsyncSession, username: str) -> Optional[UserOCRPreference]:
|
||||
"""Get user's OCR preference by username."""
|
||||
result = await session.execute(
|
||||
select(UserOCRPreference).where(UserOCRPreference.username == username)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
async def create_or_update(
|
||||
session: AsyncSession,
|
||||
username: str,
|
||||
preferred_engine: OCREngine
|
||||
) -> UserOCRPreference:
|
||||
"""Create or update user's OCR preference."""
|
||||
existing = await OCRPreferenceCRUD.get_by_username(session, username)
|
||||
|
||||
if existing:
|
||||
existing.preferred_engine = preferred_engine
|
||||
existing.updated_at = datetime.utcnow()
|
||||
await session.commit()
|
||||
await session.refresh(existing)
|
||||
return existing
|
||||
else:
|
||||
preference = UserOCRPreference(
|
||||
username=username,
|
||||
preferred_engine=preferred_engine
|
||||
)
|
||||
session.add(preference)
|
||||
await session.commit()
|
||||
await session.refresh(preference)
|
||||
return preference
|
||||
|
||||
@staticmethod
|
||||
async def delete_by_username(session: AsyncSession, username: str) -> bool:
|
||||
"""Delete user's OCR preference."""
|
||||
existing = await OCRPreferenceCRUD.get_by_username(session, username)
|
||||
if existing:
|
||||
await session.delete(existing)
|
||||
await session.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class OCRMetricsCRUD:
|
||||
"""CRUD operations for OCR job metrics."""
|
||||
|
||||
@staticmethod
|
||||
async def create(
|
||||
session: AsyncSession,
|
||||
job_id: str,
|
||||
username: str,
|
||||
engine_requested: str,
|
||||
engine_used: str,
|
||||
processing_time_ms: int = 0,
|
||||
file_size_bytes: int = 0,
|
||||
file_type: str = "image/jpeg",
|
||||
original_filename: Optional[str] = None,
|
||||
success: bool = True,
|
||||
error_message: Optional[str] = None,
|
||||
overall_confidence: float = 0.0,
|
||||
fields_extracted: int = 0,
|
||||
needs_manual_review: Optional[bool] = None,
|
||||
validation_warnings_count: int = 0,
|
||||
validation_errors_count: int = 0,
|
||||
company_id: Optional[int] = None
|
||||
) -> OCRJobMetrics:
|
||||
"""Create a new OCR job metrics record."""
|
||||
metrics = OCRJobMetrics(
|
||||
job_id=job_id,
|
||||
username=username,
|
||||
company_id=company_id,
|
||||
engine_requested=engine_requested,
|
||||
engine_used=engine_used,
|
||||
processing_time_ms=processing_time_ms,
|
||||
file_size_bytes=file_size_bytes,
|
||||
file_type=file_type,
|
||||
original_filename=original_filename,
|
||||
success=success,
|
||||
error_message=error_message,
|
||||
overall_confidence=overall_confidence,
|
||||
fields_extracted=fields_extracted,
|
||||
needs_manual_review=needs_manual_review,
|
||||
validation_warnings_count=validation_warnings_count,
|
||||
validation_errors_count=validation_errors_count,
|
||||
)
|
||||
session.add(metrics)
|
||||
await session.commit()
|
||||
await session.refresh(metrics)
|
||||
return metrics
|
||||
|
||||
@staticmethod
|
||||
async def get_by_job_id(session: AsyncSession, job_id: str) -> Optional[OCRJobMetrics]:
|
||||
"""Get metrics by job ID."""
|
||||
result = await session.execute(
|
||||
select(OCRJobMetrics).where(OCRJobMetrics.job_id == job_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
async def get_user_history(
|
||||
session: AsyncSession,
|
||||
username: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
) -> List[OCRJobMetrics]:
|
||||
"""Get user's OCR job history."""
|
||||
result = await session.execute(
|
||||
select(OCRJobMetrics)
|
||||
.where(OCRJobMetrics.username == username)
|
||||
.order_by(OCRJobMetrics.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@staticmethod
|
||||
async def get_summary_by_engine(
|
||||
session: AsyncSession,
|
||||
days: int = 30,
|
||||
username: Optional[str] = None
|
||||
) -> List[OCRMetricsSummary]:
|
||||
"""Get summary metrics grouped by engine."""
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
# Build query
|
||||
conditions = [OCRJobMetrics.created_at >= cutoff_date]
|
||||
if username:
|
||||
conditions.append(OCRJobMetrics.username == username)
|
||||
|
||||
# Query for aggregated metrics
|
||||
result = await session.execute(
|
||||
select(
|
||||
OCRJobMetrics.engine_used,
|
||||
func.count(OCRJobMetrics.id).label('total_jobs'),
|
||||
func.sum(func.cast(OCRJobMetrics.success, sa.Integer)).label('successful_jobs'),
|
||||
func.avg(OCRJobMetrics.processing_time_ms).label('avg_processing_time_ms'),
|
||||
func.avg(OCRJobMetrics.overall_confidence).label('avg_confidence'),
|
||||
func.avg(OCRJobMetrics.fields_extracted).label('avg_fields_extracted'),
|
||||
)
|
||||
.where(and_(*conditions))
|
||||
.group_by(OCRJobMetrics.engine_used)
|
||||
.order_by(func.count(OCRJobMetrics.id).desc())
|
||||
)
|
||||
|
||||
summaries = []
|
||||
for row in result.all():
|
||||
total = row.total_jobs or 0
|
||||
successful = row.successful_jobs or 0
|
||||
success_rate = successful / total if total > 0 else 0.0
|
||||
summaries.append(OCRMetricsSummary(
|
||||
engine=row.engine_used,
|
||||
total_jobs=total,
|
||||
successful_jobs=successful,
|
||||
failed_jobs=total - successful,
|
||||
success_rate=success_rate,
|
||||
avg_processing_time_ms=float(row.avg_processing_time_ms or 0),
|
||||
avg_confidence=float(row.avg_confidence or 0),
|
||||
avg_fields_extracted=float(row.avg_fields_extracted or 0),
|
||||
))
|
||||
|
||||
return summaries
|
||||
|
||||
@staticmethod
|
||||
async def get_overall_stats(
|
||||
session: AsyncSession,
|
||||
days: int = 30,
|
||||
username: Optional[str] = None
|
||||
) -> dict:
|
||||
"""Get overall OCR statistics."""
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
conditions = [OCRJobMetrics.created_at >= cutoff_date]
|
||||
if username:
|
||||
conditions.append(OCRJobMetrics.username == username)
|
||||
|
||||
result = await session.execute(
|
||||
select(
|
||||
func.count(OCRJobMetrics.id).label('total_jobs'),
|
||||
func.sum(func.cast(OCRJobMetrics.success, sa.Integer)).label('successful_jobs'),
|
||||
func.avg(OCRJobMetrics.processing_time_ms).label('avg_processing_time_ms'),
|
||||
func.avg(OCRJobMetrics.overall_confidence).label('avg_confidence'),
|
||||
)
|
||||
.where(and_(*conditions))
|
||||
)
|
||||
|
||||
row = result.one()
|
||||
total = row.total_jobs or 0
|
||||
successful = row.successful_jobs or 0
|
||||
|
||||
return {
|
||||
"total_jobs": total,
|
||||
"successful_jobs": successful,
|
||||
"failed_jobs": total - successful,
|
||||
"success_rate": (successful / total * 100) if total > 0 else 0.0,
|
||||
"avg_processing_time_ms": float(row.avg_processing_time_ms or 0),
|
||||
"avg_confidence": float(row.avg_confidence or 0),
|
||||
"period_days": days,
|
||||
}
|
||||
|
||||
|
||||
# Import sqlalchemy for func.cast
|
||||
import sqlalchemy as sa
|
||||
@@ -10,9 +10,10 @@ from backend.modules.data_entry.config import settings
|
||||
|
||||
|
||||
# Create async engine
|
||||
# Note: echo=False to disable SQL query logging (too verbose)
|
||||
engine = create_async_engine(
|
||||
settings.database_url,
|
||||
echo=settings.debug,
|
||||
echo=False,
|
||||
future=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
from .receipt import Receipt, ReceiptAttachment, ReceiptStatus, ReceiptType, ReceiptDirection
|
||||
from .accounting_entry import AccountingEntry, EntryType
|
||||
from .nomenclature import SyncedSupplier, LocalSupplier, SyncedCashRegister
|
||||
from .ocr_settings import UserOCRPreference, OCRJobMetrics, OCRMetricsSummary, OCREngine
|
||||
|
||||
__all__ = [
|
||||
"Receipt",
|
||||
@@ -14,4 +15,9 @@ __all__ = [
|
||||
"SyncedSupplier",
|
||||
"LocalSupplier",
|
||||
"SyncedCashRegister",
|
||||
# OCR Settings & Metrics
|
||||
"UserOCRPreference",
|
||||
"OCRJobMetrics",
|
||||
"OCRMetricsSummary",
|
||||
"OCREngine",
|
||||
]
|
||||
|
||||
102
backend/modules/data_entry/db/models/ocr_settings.py
Normal file
102
backend/modules/data_entry/db/models/ocr_settings.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""OCR settings and metrics SQLModel models."""
|
||||
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from sqlmodel import SQLModel, Field
|
||||
|
||||
|
||||
class OCREngine(str, Enum):
|
||||
"""Available OCR engines."""
|
||||
TESSERACT = "tesseract"
|
||||
DOCTR = "doctr"
|
||||
DOCTR_PLUS = "doctr_plus" # docTR with 2-tier sequential processing + early exit (optimized, recommended)
|
||||
PADDLEOCR = "paddleocr"
|
||||
|
||||
|
||||
class UserOCRPreference(SQLModel, table=True):
|
||||
"""
|
||||
User's preferred OCR engine setting.
|
||||
|
||||
Each user can have one preferred OCR engine that will be
|
||||
auto-selected when they upload new receipts for processing.
|
||||
"""
|
||||
|
||||
__tablename__ = "user_ocr_preferences"
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
|
||||
# User identification
|
||||
username: str = Field(max_length=100, unique=True, index=True)
|
||||
|
||||
# Preference settings
|
||||
preferred_engine: OCREngine = Field(default=OCREngine.DOCTR_PLUS)
|
||||
|
||||
# Timestamps
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class OCRJobMetrics(SQLModel, table=True):
|
||||
"""
|
||||
OCR job processing metrics for analytics.
|
||||
|
||||
Stores metrics for each OCR job to enable:
|
||||
- Performance tracking by engine
|
||||
- Success rate analysis
|
||||
- Processing time trends
|
||||
- User-specific analytics
|
||||
"""
|
||||
|
||||
__tablename__ = "ocr_job_metrics"
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
|
||||
# Job identification
|
||||
job_id: str = Field(max_length=50, unique=True, index=True)
|
||||
|
||||
# User and company context
|
||||
username: str = Field(max_length=100, index=True)
|
||||
company_id: Optional[int] = Field(default=None, index=True)
|
||||
|
||||
# Engine used
|
||||
engine_requested: str = Field(max_length=20) # What user/auto requested
|
||||
engine_used: str = Field(max_length=50) # What was actually used (e.g., "doctr-light")
|
||||
|
||||
# Processing metrics
|
||||
processing_time_ms: int = Field(default=0)
|
||||
file_size_bytes: int = Field(default=0)
|
||||
file_type: str = Field(max_length=50, default="image/jpeg") # MIME type
|
||||
original_filename: Optional[str] = Field(default=None, max_length=255) # Original uploaded filename
|
||||
|
||||
# Success metrics
|
||||
success: bool = Field(default=True)
|
||||
error_message: Optional[str] = Field(default=None, max_length=500)
|
||||
|
||||
# Extraction quality metrics
|
||||
overall_confidence: float = Field(default=0.0)
|
||||
fields_extracted: int = Field(default=0) # Number of fields successfully extracted
|
||||
needs_manual_review: Optional[bool] = Field(default=None)
|
||||
validation_warnings_count: int = Field(default=0)
|
||||
validation_errors_count: int = Field(default=0)
|
||||
|
||||
# Timestamps
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class OCRMetricsSummary(SQLModel):
|
||||
"""
|
||||
Summary metrics for OCR analytics.
|
||||
|
||||
Not a database table - used for API responses.
|
||||
"""
|
||||
engine: str
|
||||
total_jobs: int
|
||||
successful_jobs: int
|
||||
failed_jobs: int
|
||||
success_rate: float # Computed: successful_jobs / total_jobs
|
||||
avg_processing_time_ms: float
|
||||
avg_confidence: float
|
||||
avg_fields_extracted: float
|
||||
@@ -17,6 +17,7 @@ load_dotenv()
|
||||
from backend.modules.data_entry.db.models.receipt import Receipt, ReceiptAttachment
|
||||
from backend.modules.data_entry.db.models.accounting_entry import AccountingEntry
|
||||
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier, SyncedCashRegister
|
||||
from backend.modules.data_entry.db.models.ocr_settings import UserOCRPreference, OCRJobMetrics
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
"""Add OCR settings and metrics tables.
|
||||
|
||||
Revision ID: add_ocr_settings_metrics
|
||||
Revises: 20251230_add_needs_manual_review
|
||||
Create Date: 2025-12-31
|
||||
|
||||
This migration adds:
|
||||
- user_ocr_preferences: Store user's preferred OCR engine
|
||||
- ocr_job_metrics: Store OCR job processing metrics for analytics
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# Revision identifiers
|
||||
revision = 'add_ocr_settings_metrics'
|
||||
down_revision = '20251230_add_needs_manual_review'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Create OCR settings and metrics tables."""
|
||||
|
||||
# Create user_ocr_preferences table
|
||||
op.create_table(
|
||||
'user_ocr_preferences',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('username', sa.String(length=100), nullable=False),
|
||||
sa.Column('preferred_engine', sa.String(length=20), nullable=False, server_default='doctr_plus'),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False, server_default=sa.func.now()),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False, server_default=sa.func.now()),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index('ix_user_ocr_preferences_username', 'user_ocr_preferences', ['username'], unique=True)
|
||||
|
||||
# Create ocr_job_metrics table
|
||||
op.create_table(
|
||||
'ocr_job_metrics',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('job_id', sa.String(length=50), nullable=False),
|
||||
sa.Column('username', sa.String(length=100), nullable=False),
|
||||
sa.Column('company_id', sa.Integer(), nullable=True),
|
||||
sa.Column('engine_requested', sa.String(length=20), nullable=False),
|
||||
sa.Column('engine_used', sa.String(length=50), nullable=False),
|
||||
sa.Column('processing_time_ms', sa.Integer(), nullable=False, server_default='0'),
|
||||
sa.Column('file_size_bytes', sa.Integer(), nullable=False, server_default='0'),
|
||||
sa.Column('file_type', sa.String(length=50), nullable=False, server_default='image/jpeg'),
|
||||
sa.Column('success', sa.Boolean(), nullable=False, server_default='1'),
|
||||
sa.Column('error_message', sa.String(length=500), nullable=True),
|
||||
sa.Column('overall_confidence', sa.Float(), nullable=False, server_default='0.0'),
|
||||
sa.Column('fields_extracted', sa.Integer(), nullable=False, server_default='0'),
|
||||
sa.Column('needs_manual_review', sa.Boolean(), nullable=True),
|
||||
sa.Column('validation_warnings_count', sa.Integer(), nullable=False, server_default='0'),
|
||||
sa.Column('validation_errors_count', sa.Integer(), nullable=False, server_default='0'),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False, server_default=sa.func.now()),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index('ix_ocr_job_metrics_job_id', 'ocr_job_metrics', ['job_id'], unique=True)
|
||||
op.create_index('ix_ocr_job_metrics_username', 'ocr_job_metrics', ['username'], unique=False)
|
||||
op.create_index('ix_ocr_job_metrics_company_id', 'ocr_job_metrics', ['company_id'], unique=False)
|
||||
op.create_index('ix_ocr_job_metrics_created_at', 'ocr_job_metrics', ['created_at'], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop OCR settings and metrics tables."""
|
||||
op.drop_index('ix_ocr_job_metrics_created_at', table_name='ocr_job_metrics')
|
||||
op.drop_index('ix_ocr_job_metrics_company_id', table_name='ocr_job_metrics')
|
||||
op.drop_index('ix_ocr_job_metrics_username', table_name='ocr_job_metrics')
|
||||
op.drop_index('ix_ocr_job_metrics_job_id', table_name='ocr_job_metrics')
|
||||
op.drop_table('ocr_job_metrics')
|
||||
|
||||
op.drop_index('ix_user_ocr_preferences_username', table_name='user_ocr_preferences')
|
||||
op.drop_table('user_ocr_preferences')
|
||||
@@ -0,0 +1,30 @@
|
||||
"""Add original_filename to ocr_job_metrics.
|
||||
|
||||
Revision ID: add_original_filename_to_metrics
|
||||
Revises: add_ocr_settings_metrics
|
||||
Create Date: 2025-12-31
|
||||
|
||||
Adds original_filename column to track the uploaded filename.
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# Revision identifiers
|
||||
revision = 'add_original_filename_to_metrics'
|
||||
down_revision = 'add_ocr_settings_metrics'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add original_filename column to ocr_job_metrics."""
|
||||
op.add_column(
|
||||
'ocr_job_metrics',
|
||||
sa.Column('original_filename', sa.String(length=255), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove original_filename column."""
|
||||
op.drop_column('ocr_job_metrics', 'original_filename')
|
||||
@@ -11,6 +11,8 @@ def create_data_entry_router() -> APIRouter:
|
||||
- /receipts - Receipt CRUD and workflow
|
||||
- /ocr - OCR processing for receipts
|
||||
- /nomenclature - Nomenclature syncing from Oracle
|
||||
- /settings - User settings (OCR preferences)
|
||||
- /metrics - OCR analytics and metrics
|
||||
|
||||
Returns:
|
||||
APIRouter: Configured router for data entry module
|
||||
@@ -21,10 +23,13 @@ def create_data_entry_router() -> APIRouter:
|
||||
from .receipts import router as receipts_router
|
||||
from .ocr import router as ocr_router
|
||||
from .nomenclature import router as nomenclature_router
|
||||
from .ocr_settings import router as ocr_settings_router
|
||||
|
||||
# Include all sub-routers (no prefix - already prefixed in main.py with /api/data-entry)
|
||||
router.include_router(receipts_router, prefix="/receipts", tags=["data-entry-receipts"])
|
||||
router.include_router(ocr_router, prefix="/ocr", tags=["data-entry-ocr"])
|
||||
router.include_router(nomenclature_router, prefix="/nomenclature", tags=["data-entry-nomenclature"])
|
||||
# OCR settings and metrics (endpoints at /settings/* and /metrics/*)
|
||||
router.include_router(ocr_settings_router, tags=["data-entry-settings"])
|
||||
|
||||
return router
|
||||
|
||||
@@ -27,6 +27,7 @@ from backend.modules.data_entry.services.ocr_service import ocr_service
|
||||
from backend.modules.data_entry.services.ocr_engine import OCREngine
|
||||
from backend.modules.data_entry.services.ocr.job_queue import job_queue, OCRJobStatus as JobStatus
|
||||
from backend.modules.data_entry.services.ocr.job_worker import estimate_wait_time
|
||||
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
|
||||
from backend.modules.data_entry.schemas.ocr import (
|
||||
OCRResponse,
|
||||
OCRStatusResponse,
|
||||
@@ -55,7 +56,7 @@ router = APIRouter()
|
||||
@router.post("/extract", response_model=OCRJobSubmitResponse)
|
||||
async def submit_ocr_job(
|
||||
file: UploadFile = File(...),
|
||||
engine: OCREngineChoice = Query(default=OCREngineChoice.auto, description="OCR engine to use"),
|
||||
engine: OCREngineChoice = Query(default=OCREngineChoice.doctr_plus, description="OCR engine to use"),
|
||||
sync: bool = Query(default=False, description="If true, process synchronously (blocks)"),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
@@ -69,7 +70,7 @@ async def submit_ocr_job(
|
||||
|
||||
Args:
|
||||
file: Image or PDF file (max 10MB)
|
||||
engine: OCR engine choice (auto, paddleocr, tesseract)
|
||||
engine: OCR engine choice (tesseract, doctr, doctr_plus, paddleocr)
|
||||
sync: If true, process synchronously (legacy mode)
|
||||
|
||||
Returns:
|
||||
@@ -129,13 +130,13 @@ async def submit_ocr_job(
|
||||
@router.get("/jobs/{job_id}", response_model=OCRJobResponse)
|
||||
async def get_job_status(
|
||||
job_id: str,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get OCR job status and result.
|
||||
Get OCR job status and result (instant response).
|
||||
|
||||
Poll this endpoint to check job progress.
|
||||
Recommended polling interval: 2 seconds.
|
||||
For efficient polling, use GET /jobs/{job_id}/wait instead (long-polling).
|
||||
|
||||
Args:
|
||||
job_id: Job UUID from POST /extract response
|
||||
@@ -165,6 +166,10 @@ async def get_job_status(
|
||||
result_data = None
|
||||
if job.status == JobStatus.completed and job.result:
|
||||
result_data = _dict_to_extraction_data(job.result)
|
||||
# Apply fuzzy CUI matching
|
||||
result_data = await _apply_fuzzy_cui_matching(result_data, session)
|
||||
# Debug: log suggested_payment_mode being returned
|
||||
print(f"[OCR Router] Returning job {job_id} with suggested_payment_mode={result_data.suggested_payment_mode}", flush=True)
|
||||
|
||||
return OCRJobResponse(
|
||||
job_id=job.id,
|
||||
@@ -174,12 +179,66 @@ async def get_job_status(
|
||||
created_at=job.created_at or datetime.utcnow(),
|
||||
started_at=job.started_at,
|
||||
completed_at=job.completed_at,
|
||||
queue_wait_ms=job.queue_wait_ms,
|
||||
ocr_time_ms=job.ocr_time_ms,
|
||||
processing_time_ms=job.processing_time_ms,
|
||||
result=result_data,
|
||||
error=job.error_message
|
||||
)
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}/wait", response_model=OCRJobResponse)
|
||||
async def wait_for_job_status(
|
||||
job_id: str,
|
||||
timeout: int = Query(default=30, ge=1, le=60, description="Max wait time in seconds"),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Long-poll for OCR job status change.
|
||||
|
||||
Waits until:
|
||||
- Job status changes to completed/failed
|
||||
- Timeout expires (returns current status)
|
||||
|
||||
Recommended client timeout: timeout + 5 seconds
|
||||
|
||||
Args:
|
||||
job_id: Job UUID from POST /extract response
|
||||
timeout: Max wait time in seconds (1-60, default 30)
|
||||
|
||||
Returns:
|
||||
OCRJobResponse with status, queue_position, and result (if completed)
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
end_time = time.time() + timeout
|
||||
last_status = None
|
||||
|
||||
while time.time() < end_time:
|
||||
job = await job_queue.get_job(job_id)
|
||||
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
# Return immediately if job completed or failed
|
||||
if job.status in [JobStatus.completed, JobStatus.failed]:
|
||||
return await get_job_status(job_id, session, current_user)
|
||||
|
||||
# Return if status changed from last check
|
||||
if last_status is not None and job.status != last_status:
|
||||
return await get_job_status(job_id, session, current_user)
|
||||
|
||||
last_status = job.status
|
||||
|
||||
# Wait 1 second before next internal check
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Timeout - return current status
|
||||
return await get_job_status(job_id, session, current_user)
|
||||
|
||||
|
||||
@router.get("/queue/status", response_model=OCRQueueStatusResponse)
|
||||
async def get_queue_status(
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
@@ -221,10 +280,58 @@ async def get_ocr_status():
|
||||
)
|
||||
|
||||
|
||||
@router.get("/engines")
|
||||
async def get_available_engines():
|
||||
"""
|
||||
Get list of enabled OCR engines based on .env configuration.
|
||||
|
||||
Returns engines availability and available processing modes.
|
||||
Frontend should use this to filter engine selection dropdown.
|
||||
|
||||
Available engines: tesseract, doctr, doctr_plus, paddleocr
|
||||
"""
|
||||
# Check which engines are enabled via .env
|
||||
paddle_enabled = os.getenv("OCR_ENABLE_PADDLEOCR", "true").lower() == "true"
|
||||
tesseract_enabled = os.getenv("OCR_ENABLE_TESSERACT", "true").lower() == "true"
|
||||
default_engine = os.getenv("OCR_DEFAULT_ENGINE", "doctr_plus")
|
||||
|
||||
# Build engines dict
|
||||
engines = {
|
||||
"tesseract": tesseract_enabled,
|
||||
"doctr": True, # Always available (primary engine)
|
||||
"doctr_plus": True, # Always available (recommended)
|
||||
"paddleocr": paddle_enabled,
|
||||
}
|
||||
|
||||
# Build available modes based on enabled engines
|
||||
modes = []
|
||||
|
||||
if tesseract_enabled:
|
||||
modes.append("tesseract")
|
||||
|
||||
modes.append("doctr")
|
||||
modes.append("doctr_plus")
|
||||
|
||||
if paddle_enabled:
|
||||
modes.append("paddleocr")
|
||||
|
||||
return {
|
||||
"engines": engines,
|
||||
"available_modes": modes,
|
||||
"default_mode": default_engine,
|
||||
"memory_estimate_mb": {
|
||||
"tesseract": 50,
|
||||
"doctr": 600,
|
||||
"doctr_plus": 600,
|
||||
"paddleocr": 800,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.post("/extract-attachment/{attachment_id}", response_model=OCRResponse)
|
||||
async def extract_from_attachment(
|
||||
attachment_id: int,
|
||||
engine: OCREngineChoice = Query(default=OCREngineChoice.auto),
|
||||
engine: OCREngineChoice = Query(default=OCREngineChoice.doctr_plus),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
@@ -260,6 +367,8 @@ async def extract_from_attachment(
|
||||
raise HTTPException(status_code=422, detail=message)
|
||||
|
||||
data = _result_to_extraction_data(result)
|
||||
# Apply fuzzy CUI matching
|
||||
data = await _apply_fuzzy_cui_matching(data, session)
|
||||
return OCRResponse(success=True, message=message, data=data)
|
||||
|
||||
|
||||
@@ -267,6 +376,58 @@ async def extract_from_attachment(
|
||||
# Helper Functions
|
||||
# ============================================================================
|
||||
|
||||
async def _apply_fuzzy_cui_matching(
|
||||
extraction_data: ExtractionData,
|
||||
session: AsyncSession
|
||||
) -> ExtractionData:
|
||||
"""
|
||||
Apply fuzzy CUI matching to extraction data.
|
||||
|
||||
ONLY applies fuzzy matching if CUI is missing OR has invalid checksum.
|
||||
If CUI has valid checksum, we trust the OCR and skip fuzzy matching.
|
||||
|
||||
Args:
|
||||
extraction_data: ExtractionData with CUI to potentially correct
|
||||
session: AsyncSession for database lookups
|
||||
|
||||
Returns:
|
||||
ExtractionData with CUI corrected if a match was found
|
||||
"""
|
||||
from backend.modules.data_entry.services.ocr.validation import CUIChecksumRule
|
||||
|
||||
# Skip if no CUI and no vendor name (nothing to match)
|
||||
if not extraction_data.cui and not extraction_data.partner_name:
|
||||
return extraction_data
|
||||
|
||||
# Check if CUI has valid checksum - if valid, skip fuzzy matching
|
||||
if extraction_data.cui:
|
||||
cui_digits = CUIChecksumRule.extract_digits(extraction_data.cui)
|
||||
if len(cui_digits) >= 6 and CUIChecksumRule.validate_checksum(cui_digits):
|
||||
print(f"[Fuzzy Match] CUI {extraction_data.cui} has valid checksum, skipping fuzzy match", flush=True)
|
||||
return extraction_data
|
||||
|
||||
# CUI missing or invalid checksum - try fuzzy matching
|
||||
try:
|
||||
match = await OCRValidationEngine.fuzzy_match_supplier(
|
||||
cui=extraction_data.cui,
|
||||
vendor_name=extraction_data.partner_name,
|
||||
db_session=session
|
||||
)
|
||||
|
||||
if match:
|
||||
corrected_cui, supplier_name = match
|
||||
if corrected_cui != extraction_data.cui:
|
||||
print(f"[Fuzzy Match] Corrected: {extraction_data.cui} → {corrected_cui} ({supplier_name})", flush=True)
|
||||
extraction_data.cui = corrected_cui
|
||||
# Also set partner_name if not already set
|
||||
if not extraction_data.partner_name:
|
||||
extraction_data.partner_name = supplier_name
|
||||
except Exception as e:
|
||||
print(f"[Fuzzy Match] Error: {e}", flush=True)
|
||||
|
||||
return extraction_data
|
||||
|
||||
|
||||
async def _process_sync(
|
||||
content: bytes,
|
||||
file: UploadFile,
|
||||
@@ -362,6 +523,7 @@ def _result_to_extraction_data(result) -> ExtractionData:
|
||||
confidence_client=getattr(result, 'confidence_client', 0.0),
|
||||
overall_confidence=result.overall_confidence,
|
||||
raw_text=result.raw_text,
|
||||
raw_texts=getattr(result, 'raw_texts', []),
|
||||
ocr_engine=result.ocr_engine,
|
||||
processing_time_ms=result.processing_time_ms,
|
||||
needs_manual_review=result.needs_manual_review,
|
||||
@@ -437,6 +599,7 @@ def _dict_to_extraction_data(data: dict) -> ExtractionData:
|
||||
confidence_client=data.get('confidence_client', 0.0),
|
||||
overall_confidence=data.get('overall_confidence', 0.0),
|
||||
raw_text=data.get('raw_text', ''),
|
||||
raw_texts=data.get('raw_texts', []),
|
||||
ocr_engine=data.get('ocr_engine', ''),
|
||||
processing_time_ms=data.get('processing_time_ms', 0),
|
||||
needs_manual_review=data.get('needs_manual_review'),
|
||||
|
||||
268
backend/modules/data_entry/routers/ocr_settings.py
Normal file
268
backend/modules/data_entry/routers/ocr_settings.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""
|
||||
OCR Settings and Metrics API endpoints.
|
||||
|
||||
Endpoints:
|
||||
- GET /settings/ocr-preference - Get user's preferred OCR engine
|
||||
- POST /settings/ocr-preference - Set user's preferred OCR engine
|
||||
- GET /metrics/ocr/summary - Get OCR metrics summary by engine
|
||||
- GET /metrics/ocr/history - Get user's OCR job history
|
||||
- GET /metrics/ocr/stats - Get overall OCR statistics
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from backend.modules.data_entry.db.database import get_session
|
||||
from backend.modules.data_entry.db.crud.ocr_settings import OCRPreferenceCRUD, OCRMetricsCRUD
|
||||
from backend.modules.data_entry.db.models.ocr_settings import OCREngine, OCRMetricsSummary
|
||||
|
||||
# Auth integration
|
||||
from shared.auth.dependencies import get_current_user
|
||||
from shared.auth.models import CurrentUser
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Schemas
|
||||
# ============================================================================
|
||||
|
||||
class OCRPreferenceResponse(BaseModel):
|
||||
"""Response for OCR preference endpoint."""
|
||||
username: str
|
||||
preferred_engine: str
|
||||
available_engines: List[str] = Field(
|
||||
default=["tesseract", "doctr", "doctr_plus", "paddleocr"],
|
||||
description="Available OCR engines"
|
||||
)
|
||||
|
||||
|
||||
class OCRPreferenceRequest(BaseModel):
|
||||
"""Request to set OCR preference."""
|
||||
preferred_engine: str = Field(
|
||||
default="doctr_plus",
|
||||
description="Preferred OCR engine: tesseract, doctr, doctr_plus, paddleocr"
|
||||
)
|
||||
|
||||
|
||||
class OCRMetricsHistoryItem(BaseModel):
|
||||
"""Single OCR job metrics item."""
|
||||
job_id: str
|
||||
engine_requested: str
|
||||
engine_used: str
|
||||
processing_time_ms: int
|
||||
success: bool
|
||||
overall_confidence: float
|
||||
fields_extracted: int
|
||||
created_at: str
|
||||
original_filename: Optional[str] = None
|
||||
|
||||
|
||||
class OCRMetricsHistoryResponse(BaseModel):
|
||||
"""Response for OCR history endpoint."""
|
||||
items: List[OCRMetricsHistoryItem]
|
||||
total: int
|
||||
|
||||
|
||||
class OCRStatsResponse(BaseModel):
|
||||
"""Response for OCR stats endpoint."""
|
||||
total_jobs: int
|
||||
successful_jobs: int
|
||||
failed_jobs: int
|
||||
success_rate: float
|
||||
avg_processing_time_ms: float
|
||||
avg_confidence: float
|
||||
period_days: int
|
||||
|
||||
|
||||
class OCRActiveEnginesResponse(BaseModel):
|
||||
"""Response for active OCR engines endpoint."""
|
||||
engines: List[str] = Field(description="List of active OCR engines from .env config")
|
||||
recommended: str = Field(default="doctr_plus", description="Recommended engine")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OCR Engines Configuration Endpoint
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/settings/ocr-engines", response_model=OCRActiveEnginesResponse)
|
||||
async def get_active_ocr_engines():
|
||||
"""
|
||||
Get list of active OCR engines configured in .env.
|
||||
|
||||
Returns the engines that should be shown in the frontend dropdown.
|
||||
Configured via OCR_ACTIVE_ENGINES environment variable.
|
||||
|
||||
Default: doctr,doctr_plus
|
||||
Available: tesseract, paddleocr, doctr, doctr_plus
|
||||
"""
|
||||
from backend.modules.data_entry.config import settings
|
||||
|
||||
return OCRActiveEnginesResponse(
|
||||
engines=settings.ocr_active_engines_list,
|
||||
recommended="doctr_plus"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OCR Preference Endpoints
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/settings/ocr-preference", response_model=OCRPreferenceResponse)
|
||||
async def get_ocr_preference(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get user's preferred OCR engine.
|
||||
|
||||
Returns the user's saved preference or 'doctr_plus' if not set.
|
||||
Also returns list of available engines.
|
||||
"""
|
||||
from backend.modules.data_entry.services.ocr_engine import OCREngine as OCREngineClass
|
||||
|
||||
preference = await OCRPreferenceCRUD.get_by_username(session, current_user.username)
|
||||
|
||||
# Get available engines from OCR service
|
||||
available = OCREngineClass.get_available_engines()
|
||||
|
||||
return OCRPreferenceResponse(
|
||||
username=current_user.username,
|
||||
preferred_engine=preference.preferred_engine.value if preference else "doctr_plus",
|
||||
available_engines=available
|
||||
)
|
||||
|
||||
|
||||
@router.post("/settings/ocr-preference", response_model=OCRPreferenceResponse)
|
||||
async def set_ocr_preference(
|
||||
request: OCRPreferenceRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Set user's preferred OCR engine.
|
||||
|
||||
Valid engines: tesseract, doctr, doctr_plus, paddleocr
|
||||
Note: Available engines depend on .env configuration (OCR_ENABLE_PADDLEOCR, OCR_ENABLE_TESSERACT)
|
||||
"""
|
||||
from backend.modules.data_entry.services.ocr_engine import OCREngine as OCREngineClass
|
||||
|
||||
# Get dynamically available engines
|
||||
available = OCREngineClass.get_available_engines()
|
||||
|
||||
if request.preferred_engine not in available:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid engine. Must be one of: {', '.join(available)}"
|
||||
)
|
||||
|
||||
# Map string to enum
|
||||
engine_map = {
|
||||
"tesseract": OCREngine.TESSERACT,
|
||||
"doctr": OCREngine.DOCTR,
|
||||
"doctr_plus": OCREngine.DOCTR_PLUS,
|
||||
"paddleocr": OCREngine.PADDLEOCR,
|
||||
}
|
||||
engine_enum = engine_map.get(request.preferred_engine, OCREngine.DOCTR_PLUS)
|
||||
|
||||
# Save preference
|
||||
preference = await OCRPreferenceCRUD.create_or_update(
|
||||
session,
|
||||
current_user.username,
|
||||
engine_enum
|
||||
)
|
||||
|
||||
# Get available engines
|
||||
available = OCREngineClass.get_available_engines()
|
||||
|
||||
return OCRPreferenceResponse(
|
||||
username=current_user.username,
|
||||
preferred_engine=preference.preferred_engine.value,
|
||||
available_engines=available
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OCR Metrics Endpoints
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/metrics/ocr/summary", response_model=List[OCRMetricsSummary])
|
||||
async def get_ocr_metrics_summary(
|
||||
days: int = Query(default=30, ge=1, le=365, description="Number of days to include"),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get OCR metrics summary grouped by engine.
|
||||
|
||||
Returns aggregated metrics for each engine used in the specified period.
|
||||
"""
|
||||
summaries = await OCRMetricsCRUD.get_summary_by_engine(
|
||||
session,
|
||||
days=days,
|
||||
username=current_user.username
|
||||
)
|
||||
return summaries
|
||||
|
||||
|
||||
@router.get("/metrics/ocr/history", response_model=OCRMetricsHistoryResponse)
|
||||
async def get_ocr_metrics_history(
|
||||
limit: int = Query(default=50, ge=1, le=200, description="Max items to return"),
|
||||
offset: int = Query(default=0, ge=0, description="Items to skip"),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get user's OCR job history.
|
||||
|
||||
Returns list of OCR jobs with their metrics, ordered by most recent first.
|
||||
"""
|
||||
items = await OCRMetricsCRUD.get_user_history(
|
||||
session,
|
||||
username=current_user.username,
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
|
||||
history_items = [
|
||||
OCRMetricsHistoryItem(
|
||||
job_id=item.job_id,
|
||||
engine_requested=item.engine_requested,
|
||||
engine_used=item.engine_used,
|
||||
processing_time_ms=item.processing_time_ms,
|
||||
success=item.success,
|
||||
overall_confidence=item.overall_confidence,
|
||||
fields_extracted=item.fields_extracted,
|
||||
created_at=item.created_at.isoformat(),
|
||||
original_filename=item.original_filename
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
|
||||
return OCRMetricsHistoryResponse(
|
||||
items=history_items,
|
||||
total=len(history_items)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/metrics/ocr/stats", response_model=OCRStatsResponse)
|
||||
async def get_ocr_stats(
|
||||
days: int = Query(default=30, ge=1, le=365, description="Number of days to include"),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get overall OCR statistics for the user.
|
||||
|
||||
Returns aggregated stats including success rate, average processing time, etc.
|
||||
"""
|
||||
stats = await OCRMetricsCRUD.get_overall_stats(
|
||||
session,
|
||||
days=days,
|
||||
username=current_user.username
|
||||
)
|
||||
|
||||
return OCRStatsResponse(**stats)
|
||||
@@ -61,7 +61,8 @@ class ExtractionData(BaseModel):
|
||||
confidence_vendor: float = Field(default=0.0, ge=0, le=1, description="Vendor extraction confidence")
|
||||
confidence_client: float = Field(default=0.0, ge=0, le=1, description="Client extraction confidence")
|
||||
overall_confidence: float = Field(default=0.0, ge=0, le=1, description="Overall confidence score")
|
||||
raw_text: str = Field(default="", description="Raw OCR text")
|
||||
raw_text: str = Field(default="", description="Raw OCR text (primary)")
|
||||
raw_texts: List[str] = Field(default=[], description="Raw OCR texts from all engine passes (for analysis)")
|
||||
ocr_engine: str = Field(default="", description="OCR engine used: paddleocr or tesseract")
|
||||
processing_time_ms: int = Field(default=0, ge=0, description="Processing time in milliseconds")
|
||||
|
||||
@@ -148,9 +149,10 @@ from enum import Enum
|
||||
|
||||
class OCREngineChoice(str, Enum):
|
||||
"""OCR engine selection options."""
|
||||
auto = "auto"
|
||||
paddleocr = "paddleocr"
|
||||
tesseract = "tesseract"
|
||||
doctr = "doctr" # 3.3x faster than PaddleOCR with same accuracy (90/100)
|
||||
doctr_plus = "doctr_plus" # docTR with 2-tier sequential processing + early exit (optimized, recommended)
|
||||
paddleocr = "paddleocr"
|
||||
|
||||
|
||||
class OCRJobStatus(str, Enum):
|
||||
@@ -193,7 +195,10 @@ class OCRJobResponse(BaseModel):
|
||||
created_at: datetime = Field(description="Job creation timestamp")
|
||||
started_at: Optional[datetime] = Field(default=None, description="Processing start timestamp")
|
||||
completed_at: Optional[datetime] = Field(default=None, description="Completion timestamp")
|
||||
processing_time_ms: Optional[int] = Field(default=None, description="Actual processing time in ms")
|
||||
# Detailed timing breakdown
|
||||
queue_wait_ms: Optional[int] = Field(default=None, description="Time waiting in queue (started_at - created_at)")
|
||||
ocr_time_ms: Optional[int] = Field(default=None, description="Actual OCR engine processing time")
|
||||
processing_time_ms: Optional[int] = Field(default=None, description="Total job processing time (completed_at - started_at)")
|
||||
result: Optional[ExtractionData] = Field(default=None, description="Extraction result (only if completed)")
|
||||
error: Optional[str] = Field(default=None, description="Error message (only if failed)")
|
||||
|
||||
|
||||
@@ -33,73 +33,55 @@ class NomenclatureService:
|
||||
"""
|
||||
Get partners (suppliers/customers) for a company.
|
||||
|
||||
Phase 1: Returns mock data.
|
||||
Phase 2: Returns synced data from SQLite (from Oracle sync).
|
||||
Phase 3: Will fetch live from Oracle.
|
||||
Returns synced suppliers from Oracle + local suppliers created from OCR.
|
||||
If no suppliers exist, returns empty list (frontend will trigger sync).
|
||||
"""
|
||||
# If session is provided, try to get from synced SQLite data
|
||||
if session:
|
||||
# Try to get from SQLite synced data
|
||||
stmt = select(SyncedSupplier).where(SyncedSupplier.company_id == company_id)
|
||||
if search:
|
||||
stmt = stmt.where(
|
||||
(SyncedSupplier.name.ilike(f"%{search}%")) |
|
||||
(SyncedSupplier.fiscal_code.ilike(f"%{search}%"))
|
||||
)
|
||||
stmt = stmt.order_by(SyncedSupplier.name) # Order alphabetically, no limit for AutoComplete
|
||||
partners = []
|
||||
|
||||
result = await session.execute(stmt)
|
||||
suppliers = result.scalars().all()
|
||||
|
||||
if suppliers:
|
||||
# Also get local suppliers
|
||||
local_stmt = select(LocalSupplier).where(LocalSupplier.company_id == company_id)
|
||||
if search:
|
||||
local_stmt = local_stmt.where(
|
||||
(LocalSupplier.name.ilike(f"%{search}%")) |
|
||||
(LocalSupplier.fiscal_code.ilike(f"%{search}%"))
|
||||
)
|
||||
local_stmt = local_stmt.order_by(LocalSupplier.name) # Order alphabetically
|
||||
|
||||
local_result = await session.execute(local_stmt)
|
||||
local_suppliers = local_result.scalars().all()
|
||||
|
||||
# Combine both - no IDs needed, just text data for autocomplete
|
||||
partners = []
|
||||
for s in suppliers:
|
||||
partners.append(PartnerOption(
|
||||
name=s.name,
|
||||
fiscal_code=s.fiscal_code,
|
||||
address=s.address,
|
||||
source="oracle"
|
||||
))
|
||||
for l in local_suppliers:
|
||||
partners.append(PartnerOption(
|
||||
name=l.name, # No suffix - must match search results
|
||||
fiscal_code=l.fiscal_code,
|
||||
address=l.address,
|
||||
source="local"
|
||||
))
|
||||
|
||||
return partners
|
||||
|
||||
# Fallback to mock data for Phase 1 (when no synced data)
|
||||
mock_partners = [
|
||||
PartnerOption(name="OMV Petrom", fiscal_code="RO123456", source="mock"),
|
||||
PartnerOption(name="Dedeman", fiscal_code="RO789012", source="mock"),
|
||||
PartnerOption(name="Kaufland", fiscal_code="RO345678", source="mock"),
|
||||
PartnerOption(name="Emag", fiscal_code="RO901234", source="mock"),
|
||||
PartnerOption(name="Altex", fiscal_code="RO567890", source="mock"),
|
||||
]
|
||||
if not session:
|
||||
return partners
|
||||
|
||||
# Get synced suppliers from Oracle
|
||||
stmt = select(SyncedSupplier).where(SyncedSupplier.company_id == company_id)
|
||||
if search:
|
||||
search_lower = search.lower()
|
||||
mock_partners = [
|
||||
p for p in mock_partners
|
||||
if search_lower in p.name.lower() or (p.fiscal_code and search_lower in p.fiscal_code.lower())
|
||||
]
|
||||
stmt = stmt.where(
|
||||
(SyncedSupplier.name.ilike(f"%{search}%")) |
|
||||
(SyncedSupplier.fiscal_code.ilike(f"%{search}%"))
|
||||
)
|
||||
stmt = stmt.order_by(SyncedSupplier.name)
|
||||
|
||||
return mock_partners
|
||||
result = await session.execute(stmt)
|
||||
suppliers = result.scalars().all()
|
||||
|
||||
for s in suppliers:
|
||||
partners.append(PartnerOption(
|
||||
name=s.name,
|
||||
fiscal_code=s.fiscal_code,
|
||||
address=s.address,
|
||||
source="oracle"
|
||||
))
|
||||
|
||||
# Always get local suppliers (not just when synced exist)
|
||||
local_stmt = select(LocalSupplier).where(LocalSupplier.company_id == company_id)
|
||||
if search:
|
||||
local_stmt = local_stmt.where(
|
||||
(LocalSupplier.name.ilike(f"%{search}%")) |
|
||||
(LocalSupplier.fiscal_code.ilike(f"%{search}%"))
|
||||
)
|
||||
local_stmt = local_stmt.order_by(LocalSupplier.name)
|
||||
|
||||
local_result = await session.execute(local_stmt)
|
||||
local_suppliers = local_result.scalars().all()
|
||||
|
||||
for l in local_suppliers:
|
||||
partners.append(PartnerOption(
|
||||
name=l.name,
|
||||
fiscal_code=l.fiscal_code,
|
||||
address=l.address,
|
||||
source="local"
|
||||
))
|
||||
|
||||
return partners
|
||||
|
||||
@staticmethod
|
||||
async def get_accounts(company_id: int, prefix: Optional[str] = None) -> List[AccountOption]:
|
||||
|
||||
@@ -13,13 +13,14 @@ Schema:
|
||||
status TEXT NOT NULL, -- pending, processing, completed, failed
|
||||
file_path TEXT NOT NULL, -- Path to uploaded file
|
||||
mime_type TEXT NOT NULL,
|
||||
engine TEXT DEFAULT 'auto',
|
||||
engine TEXT DEFAULT 'doctr_plus',
|
||||
created_at TIMESTAMP,
|
||||
started_at TIMESTAMP,
|
||||
completed_at TIMESTAMP,
|
||||
result_json TEXT, -- JSON extraction result
|
||||
error_message TEXT,
|
||||
processing_time_ms INTEGER,
|
||||
processing_time_ms INTEGER, -- Total job time (started_at to completed_at)
|
||||
ocr_time_ms INTEGER, -- Actual OCR engine processing time
|
||||
created_by TEXT, -- Username
|
||||
original_filename TEXT,
|
||||
expires_at TIMESTAMP
|
||||
@@ -74,17 +75,26 @@ class OCRJob:
|
||||
status: OCRJobStatus
|
||||
file_path: str
|
||||
mime_type: str
|
||||
engine: str = "auto"
|
||||
engine: str = "doctr_plus"
|
||||
created_at: Optional[datetime] = None
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
result_json: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
processing_time_ms: Optional[int] = None
|
||||
processing_time_ms: Optional[int] = None # Total job time (started_at to completed_at)
|
||||
ocr_time_ms: Optional[int] = None # Actual OCR engine processing time
|
||||
created_by: Optional[str] = None
|
||||
original_filename: Optional[str] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
|
||||
@property
|
||||
def queue_wait_ms(self) -> Optional[int]:
|
||||
"""Calculate queue wait time (created_at to started_at)."""
|
||||
if self.created_at and self.started_at:
|
||||
delta = self.started_at - self.created_at
|
||||
return int(delta.total_seconds() * 1000)
|
||||
return None
|
||||
|
||||
@property
|
||||
def result(self) -> Optional[Dict]:
|
||||
"""Parse result_json to dict."""
|
||||
@@ -143,19 +153,27 @@ class OCRJobQueue:
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
file_path TEXT NOT NULL,
|
||||
mime_type TEXT NOT NULL,
|
||||
engine TEXT DEFAULT 'auto',
|
||||
engine TEXT DEFAULT 'doctr_plus',
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
started_at TIMESTAMP,
|
||||
completed_at TIMESTAMP,
|
||||
result_json TEXT,
|
||||
error_message TEXT,
|
||||
processing_time_ms INTEGER,
|
||||
ocr_time_ms INTEGER,
|
||||
created_by TEXT,
|
||||
original_filename TEXT,
|
||||
expires_at TIMESTAMP
|
||||
)
|
||||
''')
|
||||
|
||||
# Migration: add ocr_time_ms column if it doesn't exist
|
||||
try:
|
||||
await db.execute('ALTER TABLE ocr_jobs ADD COLUMN ocr_time_ms INTEGER')
|
||||
logger.info("[OCRJobQueue] Added ocr_time_ms column to existing table")
|
||||
except Exception:
|
||||
pass # Column already exists
|
||||
|
||||
# Index for efficient queue queries
|
||||
await db.execute('''
|
||||
CREATE INDEX IF NOT EXISTS idx_ocr_jobs_status
|
||||
@@ -177,7 +195,7 @@ class OCRJobQueue:
|
||||
self,
|
||||
file_bytes: bytes,
|
||||
mime_type: str,
|
||||
engine: str = "auto",
|
||||
engine: str = "doctr_plus",
|
||||
username: Optional[str] = None,
|
||||
original_filename: Optional[str] = None
|
||||
) -> OCRJob:
|
||||
@@ -189,7 +207,7 @@ class OCRJobQueue:
|
||||
Args:
|
||||
file_bytes: Raw file bytes
|
||||
mime_type: MIME type of file
|
||||
engine: OCR engine ('auto', 'paddleocr', 'tesseract')
|
||||
engine: OCR engine ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
|
||||
username: Username of requester
|
||||
original_filename: Original filename from upload
|
||||
|
||||
@@ -301,24 +319,52 @@ class OCRJobQueue:
|
||||
|
||||
async def get_next_pending(self) -> Optional[OCRJob]:
|
||||
"""
|
||||
Get the next pending job (oldest first).
|
||||
Get the next pending job (oldest first) and atomically mark it as processing.
|
||||
|
||||
This prevents race conditions in parallel processing - only one worker
|
||||
can claim each job.
|
||||
|
||||
Returns:
|
||||
Next OCRJob to process or None if queue empty
|
||||
"""
|
||||
await self.initialize()
|
||||
|
||||
async with aiosqlite.connect(str(self.db_path)) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute('''
|
||||
SELECT * FROM ocr_jobs
|
||||
WHERE status = 'pending'
|
||||
ORDER BY created_at ASC
|
||||
LIMIT 1
|
||||
''') as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return self._row_to_job(row)
|
||||
now = datetime.utcnow()
|
||||
|
||||
async with self._lock: # Serialize access to prevent race conditions
|
||||
async with aiosqlite.connect(str(self.db_path)) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
|
||||
# Get the next pending job
|
||||
async with db.execute('''
|
||||
SELECT * FROM ocr_jobs
|
||||
WHERE status = 'pending'
|
||||
ORDER BY created_at ASC
|
||||
LIMIT 1
|
||||
''') as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
|
||||
job_id = row['id']
|
||||
|
||||
# Atomically mark as processing
|
||||
await db.execute('''
|
||||
UPDATE ocr_jobs
|
||||
SET status = 'processing', started_at = ?
|
||||
WHERE id = ? AND status = 'pending'
|
||||
''', (now.isoformat(), job_id))
|
||||
await db.commit()
|
||||
|
||||
# Fetch the updated job
|
||||
async with db.execute(
|
||||
'SELECT * FROM ocr_jobs WHERE id = ?',
|
||||
(job_id,)
|
||||
) as cursor:
|
||||
updated_row = await cursor.fetchone()
|
||||
if updated_row:
|
||||
return self._row_to_job(updated_row)
|
||||
|
||||
return None
|
||||
|
||||
async def update_status(
|
||||
@@ -327,7 +373,8 @@ class OCRJobQueue:
|
||||
status: OCRJobStatus,
|
||||
result: Optional[Dict] = None,
|
||||
error: Optional[str] = None,
|
||||
processing_time_ms: Optional[int] = None
|
||||
processing_time_ms: Optional[int] = None,
|
||||
ocr_time_ms: Optional[int] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Update job status.
|
||||
@@ -337,7 +384,8 @@ class OCRJobQueue:
|
||||
status: New status
|
||||
result: Extraction result dict (for completed)
|
||||
error: Error message (for failed)
|
||||
processing_time_ms: Processing time
|
||||
processing_time_ms: Total job processing time (started_at to completed_at)
|
||||
ocr_time_ms: Actual OCR engine processing time
|
||||
|
||||
Returns:
|
||||
True if update successful
|
||||
@@ -359,18 +407,18 @@ class OCRJobQueue:
|
||||
elif status == OCRJobStatus.completed:
|
||||
query = '''
|
||||
UPDATE ocr_jobs
|
||||
SET status = ?, completed_at = ?, result_json = ?, processing_time_ms = ?
|
||||
SET status = ?, completed_at = ?, result_json = ?, processing_time_ms = ?, ocr_time_ms = ?
|
||||
WHERE id = ?
|
||||
'''
|
||||
params = (status.value, now.isoformat(), result_json, processing_time_ms, job_id)
|
||||
params = (status.value, now.isoformat(), result_json, processing_time_ms, ocr_time_ms, job_id)
|
||||
|
||||
elif status == OCRJobStatus.failed:
|
||||
query = '''
|
||||
UPDATE ocr_jobs
|
||||
SET status = ?, completed_at = ?, error_message = ?, processing_time_ms = ?
|
||||
SET status = ?, completed_at = ?, error_message = ?, processing_time_ms = ?, ocr_time_ms = ?
|
||||
WHERE id = ?
|
||||
'''
|
||||
params = (status.value, now.isoformat(), error, processing_time_ms, job_id)
|
||||
params = (status.value, now.isoformat(), error, processing_time_ms, ocr_time_ms, job_id)
|
||||
|
||||
else:
|
||||
query = 'UPDATE ocr_jobs SET status = ? WHERE id = ?'
|
||||
@@ -542,13 +590,14 @@ class OCRJobQueue:
|
||||
status=OCRJobStatus(row['status']),
|
||||
file_path=row['file_path'],
|
||||
mime_type=row['mime_type'],
|
||||
engine=row['engine'] or 'auto',
|
||||
engine=row['engine'] or 'doctr_plus',
|
||||
created_at=parse_datetime(row['created_at']),
|
||||
started_at=parse_datetime(row['started_at']),
|
||||
completed_at=parse_datetime(row['completed_at']),
|
||||
result_json=row['result_json'],
|
||||
error_message=row['error_message'],
|
||||
processing_time_ms=row['processing_time_ms'],
|
||||
ocr_time_ms=row['ocr_time_ms'] if 'ocr_time_ms' in row.keys() else None,
|
||||
created_by=row['created_by'],
|
||||
original_filename=row['original_filename'],
|
||||
expires_at=parse_datetime(row['expires_at']),
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
OCR Job Worker - Background Task for Queue Processing
|
||||
|
||||
Runs as an asyncio background task in FastAPI.
|
||||
Continuously polls the job queue and processes OCR requests.
|
||||
Continuously polls the job queue and processes OCR requests IN PARALLEL.
|
||||
|
||||
Architecture:
|
||||
FastAPI startup
|
||||
@@ -12,18 +12,19 @@ Architecture:
|
||||
asyncio.create_task(_job_worker_loop())
|
||||
↓
|
||||
while True:
|
||||
job = job_queue.get_next_pending()
|
||||
if job:
|
||||
result = ocr_worker_pool.submit_task(...)
|
||||
job_queue.update_status(...)
|
||||
await asyncio.sleep(0.5)
|
||||
# Process up to OCR_WORKERS jobs concurrently
|
||||
jobs = get_pending_jobs(limit=available_slots)
|
||||
for job in jobs:
|
||||
asyncio.create_task(_process_job(job))
|
||||
await asyncio.sleep(0.1)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Optional, Set
|
||||
|
||||
from .job_queue import job_queue, OCRJobStatus, OCRJob
|
||||
from .ocr_worker_pool import ocr_worker_pool
|
||||
@@ -34,47 +35,76 @@ logger = logging.getLogger(__name__)
|
||||
_job_worker_task: Optional[asyncio.Task] = None
|
||||
_cleanup_task: Optional[asyncio.Task] = None
|
||||
_shutdown_event: Optional[asyncio.Event] = None
|
||||
_active_tasks: Set[asyncio.Task] = set() # Track active job tasks
|
||||
_concurrency_semaphore: Optional[asyncio.Semaphore] = None # Limit concurrent jobs
|
||||
|
||||
# Configuration
|
||||
POLL_INTERVAL_SECONDS = 0.5 # How often to check for new jobs
|
||||
POLL_INTERVAL_SECONDS = 0.1 # How often to check for new jobs (faster for parallel)
|
||||
CLEANUP_INTERVAL_SECONDS = 3600 # Clean expired jobs every hour
|
||||
OCR_TIMEOUT_SECONDS = 120 # Max time for OCR processing
|
||||
|
||||
|
||||
async def _job_worker_loop() -> None:
|
||||
"""
|
||||
Main worker loop - processes jobs from queue.
|
||||
Main worker loop - processes jobs from queue IN PARALLEL.
|
||||
|
||||
Runs continuously until shutdown. Polls queue every 0.5s
|
||||
and submits jobs to worker pool for processing.
|
||||
Runs continuously until shutdown. Uses semaphore to limit
|
||||
concurrent jobs to OCR_WORKERS count. Launches jobs as
|
||||
background tasks without waiting for completion.
|
||||
"""
|
||||
global _shutdown_event
|
||||
global _shutdown_event, _active_tasks, _concurrency_semaphore
|
||||
|
||||
logger.info("[JobWorker] Starting worker loop...")
|
||||
# Get max concurrent jobs from env (matches worker pool size)
|
||||
max_concurrent = int(os.getenv('OCR_WORKERS', '2'))
|
||||
_concurrency_semaphore = asyncio.Semaphore(max_concurrent)
|
||||
_active_tasks = set()
|
||||
|
||||
logger.info(f"[JobWorker] Starting PARALLEL worker loop (max_concurrent={max_concurrent})...")
|
||||
_shutdown_event = asyncio.Event()
|
||||
|
||||
consecutive_errors = 0
|
||||
max_consecutive_errors = 5
|
||||
max_consecutive_errors = 10
|
||||
|
||||
while not _shutdown_event.is_set():
|
||||
try:
|
||||
# Get next pending job
|
||||
job = await job_queue.get_next_pending()
|
||||
|
||||
if job:
|
||||
consecutive_errors = 0 # Reset error counter on success
|
||||
await _process_job(job)
|
||||
else:
|
||||
# No jobs - wait before polling again
|
||||
# Clean up completed tasks
|
||||
done_tasks = {t for t in _active_tasks if t.done()}
|
||||
for task in done_tasks:
|
||||
_active_tasks.discard(task)
|
||||
# Check for exceptions
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
_shutdown_event.wait(),
|
||||
timeout=POLL_INTERVAL_SECONDS
|
||||
)
|
||||
if _shutdown_event.is_set():
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
pass # Normal timeout, continue loop
|
||||
task.result()
|
||||
except Exception as e:
|
||||
logger.error(f"[JobWorker] Task failed: {e}")
|
||||
|
||||
# Check if we have capacity for more jobs
|
||||
active_count = len(_active_tasks)
|
||||
available_slots = max_concurrent - active_count
|
||||
|
||||
if available_slots > 0:
|
||||
# Get next pending job
|
||||
job = await job_queue.get_next_pending()
|
||||
|
||||
if job:
|
||||
consecutive_errors = 0
|
||||
# Launch job processing as background task
|
||||
task = asyncio.create_task(_process_job_with_semaphore(job))
|
||||
_active_tasks.add(task)
|
||||
logger.debug(f"[JobWorker] Launched job {job.id} (active={len(_active_tasks)}/{max_concurrent})")
|
||||
else:
|
||||
# No pending jobs - wait briefly
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
_shutdown_event.wait(),
|
||||
timeout=POLL_INTERVAL_SECONDS
|
||||
)
|
||||
if _shutdown_event.is_set():
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
else:
|
||||
# At capacity - wait for a slot to free up
|
||||
await asyncio.sleep(POLL_INTERVAL_SECONDS)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("[JobWorker] Worker loop cancelled")
|
||||
@@ -88,27 +118,46 @@ async def _job_worker_loop() -> None:
|
||||
logger.error("[JobWorker] Too many consecutive errors, stopping worker")
|
||||
break
|
||||
|
||||
# Backoff on errors
|
||||
await asyncio.sleep(min(consecutive_errors * 2, 30))
|
||||
|
||||
# Wait for active tasks to complete on shutdown
|
||||
if _active_tasks:
|
||||
logger.info(f"[JobWorker] Waiting for {len(_active_tasks)} active tasks to complete...")
|
||||
await asyncio.gather(*_active_tasks, return_exceptions=True)
|
||||
|
||||
logger.info("[JobWorker] Worker loop stopped")
|
||||
|
||||
|
||||
async def _process_job_with_semaphore(job: OCRJob) -> None:
|
||||
"""
|
||||
Process job with semaphore to limit concurrency.
|
||||
|
||||
Acquires semaphore before processing, releases after.
|
||||
This ensures we don't exceed OCR_WORKERS concurrent jobs.
|
||||
"""
|
||||
global _concurrency_semaphore
|
||||
|
||||
async with _concurrency_semaphore:
|
||||
await _process_job(job)
|
||||
|
||||
|
||||
async def _process_job(job: OCRJob) -> None:
|
||||
"""
|
||||
Process a single OCR job.
|
||||
|
||||
Reads file, submits to worker pool, updates job status.
|
||||
Reads file, submits to worker pool, updates job status,
|
||||
and saves metrics for analytics.
|
||||
|
||||
Args:
|
||||
job: OCRJob to process
|
||||
"""
|
||||
logger.info(f"[JobWorker] Processing job {job.id}: engine={job.engine}, file={Path(job.file_path).name}")
|
||||
start_time = time.time()
|
||||
file_size = 0
|
||||
file_type = "image/jpeg"
|
||||
|
||||
try:
|
||||
# Mark as processing
|
||||
await job_queue.update_status(job.id, OCRJobStatus.processing)
|
||||
# Note: Job already marked as 'processing' atomically in get_next_pending()
|
||||
|
||||
# Read file bytes
|
||||
file_path = Path(job.file_path)
|
||||
@@ -118,6 +167,10 @@ async def _process_job(job: OCRJob) -> None:
|
||||
with open(file_path, 'rb') as f:
|
||||
file_bytes = f.read()
|
||||
|
||||
file_size = len(file_bytes)
|
||||
# Determine file type from job or extension
|
||||
file_type = getattr(job, 'mime_type', 'image/jpeg') or 'image/jpeg'
|
||||
|
||||
# Submit to worker pool
|
||||
result = await ocr_worker_pool.submit_task(
|
||||
image_bytes=file_bytes,
|
||||
@@ -132,14 +185,43 @@ async def _process_job(job: OCRJob) -> None:
|
||||
# Job completed successfully
|
||||
extraction = result.get("extraction", {})
|
||||
|
||||
# Include raw_texts for analysis (from all OCR engine passes)
|
||||
extraction['raw_texts'] = result.get("raw_texts", [])
|
||||
|
||||
# Extract actual OCR processing time from extraction result
|
||||
ocr_time_ms = extraction.get('processing_time_ms', 0)
|
||||
|
||||
# Debug: log suggested_payment_mode
|
||||
spm = extraction.get('suggested_payment_mode')
|
||||
logger.info(f"[JobWorker] Job {job.id} extraction has suggested_payment_mode={spm}")
|
||||
|
||||
await job_queue.update_status(
|
||||
job_id=job.id,
|
||||
status=OCRJobStatus.completed,
|
||||
result=extraction,
|
||||
processing_time_ms=elapsed_ms
|
||||
processing_time_ms=elapsed_ms,
|
||||
ocr_time_ms=ocr_time_ms
|
||||
)
|
||||
|
||||
logger.info(f"[JobWorker] Job {job.id} completed in {elapsed_ms}ms")
|
||||
logger.info(f"[JobWorker] Job {job.id} completed in {elapsed_ms}ms (ocr: {ocr_time_ms}ms)")
|
||||
|
||||
# Save metrics for successful job
|
||||
await _save_job_metrics(
|
||||
job_id=job.id,
|
||||
username=job.created_by or 'unknown',
|
||||
engine_requested=job.engine,
|
||||
engine_used=extraction.get('ocr_engine', job.engine),
|
||||
processing_time_ms=elapsed_ms,
|
||||
file_size_bytes=file_size,
|
||||
file_type=file_type,
|
||||
original_filename=job.original_filename,
|
||||
success=True,
|
||||
overall_confidence=extraction.get('overall_confidence', 0.0),
|
||||
fields_extracted=_count_extracted_fields(extraction),
|
||||
needs_manual_review=extraction.get('needs_manual_review'),
|
||||
validation_warnings_count=len(extraction.get('validation_warnings', [])),
|
||||
validation_errors_count=len(extraction.get('validation_errors', [])),
|
||||
)
|
||||
|
||||
else:
|
||||
# Job failed
|
||||
@@ -154,6 +236,20 @@ async def _process_job(job: OCRJob) -> None:
|
||||
|
||||
logger.warning(f"[JobWorker] Job {job.id} failed after {elapsed_ms}ms: {error_msg}")
|
||||
|
||||
# Save metrics for failed job
|
||||
await _save_job_metrics(
|
||||
job_id=job.id,
|
||||
username=job.created_by or 'unknown',
|
||||
engine_requested=job.engine,
|
||||
engine_used=job.engine,
|
||||
processing_time_ms=elapsed_ms,
|
||||
file_size_bytes=file_size,
|
||||
file_type=file_type,
|
||||
original_filename=job.original_filename,
|
||||
success=False,
|
||||
error_message=error_msg,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
elapsed_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
@@ -166,6 +262,20 @@ async def _process_job(job: OCRJob) -> None:
|
||||
processing_time_ms=elapsed_ms
|
||||
)
|
||||
|
||||
# Save metrics for error job
|
||||
await _save_job_metrics(
|
||||
job_id=job.id,
|
||||
username=job.created_by or 'unknown',
|
||||
engine_requested=job.engine,
|
||||
engine_used=job.engine,
|
||||
processing_time_ms=elapsed_ms,
|
||||
file_size_bytes=file_size,
|
||||
file_type=file_type,
|
||||
original_filename=job.original_filename,
|
||||
success=False,
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
finally:
|
||||
# Cleanup file after processing
|
||||
try:
|
||||
@@ -340,3 +450,96 @@ def estimate_wait_time(queue_position: int) -> int:
|
||||
|
||||
# Estimate: position * average_time
|
||||
return int(queue_position * avg_time)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Metrics Helper Functions
|
||||
# ============================================================================
|
||||
|
||||
async def _save_job_metrics(
|
||||
job_id: str,
|
||||
username: str,
|
||||
engine_requested: str,
|
||||
engine_used: str,
|
||||
processing_time_ms: int = 0,
|
||||
file_size_bytes: int = 0,
|
||||
file_type: str = "image/jpeg",
|
||||
original_filename: Optional[str] = None,
|
||||
success: bool = True,
|
||||
error_message: Optional[str] = None,
|
||||
overall_confidence: float = 0.0,
|
||||
fields_extracted: int = 0,
|
||||
needs_manual_review: Optional[bool] = None,
|
||||
validation_warnings_count: int = 0,
|
||||
validation_errors_count: int = 0,
|
||||
) -> None:
|
||||
"""
|
||||
Save OCR job metrics to database for analytics.
|
||||
|
||||
Called after each job completes (success or failure).
|
||||
Errors are logged but don't affect job processing.
|
||||
"""
|
||||
try:
|
||||
from backend.modules.data_entry.db.database import get_db_session
|
||||
from backend.modules.data_entry.db.crud.ocr_settings import OCRMetricsCRUD
|
||||
|
||||
async with await get_db_session() as session:
|
||||
await OCRMetricsCRUD.create(
|
||||
session=session,
|
||||
job_id=job_id,
|
||||
username=username,
|
||||
engine_requested=engine_requested,
|
||||
engine_used=engine_used,
|
||||
processing_time_ms=processing_time_ms,
|
||||
file_size_bytes=file_size_bytes,
|
||||
file_type=file_type,
|
||||
original_filename=original_filename,
|
||||
success=success,
|
||||
error_message=error_message,
|
||||
overall_confidence=overall_confidence,
|
||||
fields_extracted=fields_extracted,
|
||||
needs_manual_review=needs_manual_review,
|
||||
validation_warnings_count=validation_warnings_count,
|
||||
validation_errors_count=validation_errors_count,
|
||||
)
|
||||
logger.debug(f"[JobWorker] Saved metrics for job {job_id}")
|
||||
|
||||
except Exception as e:
|
||||
# Log but don't fail - metrics are nice-to-have
|
||||
logger.warning(f"[JobWorker] Failed to save metrics for job {job_id}: {e}")
|
||||
|
||||
|
||||
def _count_extracted_fields(extraction: dict) -> int:
|
||||
"""
|
||||
Count number of successfully extracted fields from OCR result.
|
||||
|
||||
Counts non-None values in key fields.
|
||||
"""
|
||||
key_fields = [
|
||||
'receipt_number',
|
||||
'receipt_date',
|
||||
'amount',
|
||||
'partner_name',
|
||||
'cui',
|
||||
'tva_total',
|
||||
'address',
|
||||
'items_count',
|
||||
]
|
||||
|
||||
count = 0
|
||||
for field in key_fields:
|
||||
value = extraction.get(field)
|
||||
if value is not None and value != '' and value != []:
|
||||
count += 1
|
||||
|
||||
# Also count TVA entries if present
|
||||
tva_entries = extraction.get('tva_entries', [])
|
||||
if tva_entries and len(tva_entries) > 0:
|
||||
count += 1
|
||||
|
||||
# Count payment methods if present
|
||||
payment_methods = extraction.get('payment_methods', [])
|
||||
if payment_methods and len(payment_methods) > 0:
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
"""
|
||||
OCR Worker Pool Manager
|
||||
|
||||
Manages a ProcessPoolExecutor with persistent PaddleOCR initialization.
|
||||
Manages a ProcessPoolExecutor with persistent OCR engine initialization.
|
||||
Key features:
|
||||
- ProcessPoolExecutor with max_workers=1 (sequential, no memory leak)
|
||||
- ProcessPoolExecutor with configurable max_workers (from OCR_WORKERS env)
|
||||
- Configurable max_tasks_per_child (from OCR_MAX_TASKS_PER_CHILD env, 0=no restart)
|
||||
- mp_context='spawn' for Windows IIS compatibility
|
||||
- PaddleOCR loaded ONCE at worker spawn (not 30s per request)
|
||||
- docTR/PaddleOCR loaded ONCE at worker spawn (not 30s per request)
|
||||
- atexit + signal handlers for cleanup
|
||||
- Health check with auto-respawn
|
||||
- Orphan process cleanup on Windows
|
||||
@@ -29,7 +30,7 @@ import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import ProcessPoolExecutor, Future
|
||||
from concurrent.futures import ProcessPoolExecutor, Future, ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
@@ -48,8 +49,8 @@ class OCRWorkerPool:
|
||||
"""
|
||||
Singleton manager for OCR ProcessPoolExecutor.
|
||||
|
||||
Ensures PaddleOCR is loaded once and reused for all requests.
|
||||
Uses max_tasks_per_child=None to keep worker alive indefinitely.
|
||||
Ensures OCR engines are loaded once and reused for all requests.
|
||||
Uses max_tasks_per_child=5 to restart worker every 5 tasks (prevents memory leak).
|
||||
"""
|
||||
|
||||
_instance: Optional["OCRWorkerPool"] = None
|
||||
@@ -86,7 +87,7 @@ class OCRWorkerPool:
|
||||
Initialize the ProcessPoolExecutor.
|
||||
|
||||
Creates executor with spawn context for Windows compatibility.
|
||||
Uses max_tasks_per_child=None to keep worker alive (persistent PaddleOCR).
|
||||
Uses max_tasks_per_child=5 to restart worker periodically (prevents memory leak).
|
||||
|
||||
Returns:
|
||||
True if initialization successful
|
||||
@@ -103,18 +104,30 @@ class OCRWorkerPool:
|
||||
# Cleanup any orphan workers from previous runs
|
||||
self._cleanup_orphan_workers()
|
||||
|
||||
# Read configuration from environment
|
||||
max_workers = int(os.getenv('OCR_WORKERS', '2'))
|
||||
max_tasks_raw = os.getenv('OCR_MAX_TASKS_PER_CHILD', '0')
|
||||
# 0 means no restart (None in ProcessPoolExecutor)
|
||||
max_tasks_per_child = int(max_tasks_raw) if max_tasks_raw and int(max_tasks_raw) > 0 else None
|
||||
|
||||
# Create executor with spawn context (Windows compatible)
|
||||
# Use mp_context='spawn' explicitly for cross-platform consistency
|
||||
mp_context = mp.get_context('spawn')
|
||||
|
||||
self._executor = ProcessPoolExecutor(
|
||||
max_workers=1, # Single worker for sequential processing
|
||||
mp_context=mp_context,
|
||||
initializer=_worker_initializer,
|
||||
max_tasks_per_child=None, # Keep worker alive indefinitely
|
||||
)
|
||||
# max_tasks_per_child only available in Python 3.11+
|
||||
executor_kwargs = {
|
||||
'max_workers': max_workers,
|
||||
'mp_context': mp_context,
|
||||
'initializer': _worker_initializer,
|
||||
}
|
||||
if sys.version_info >= (3, 11) and max_tasks_per_child is not None:
|
||||
executor_kwargs['max_tasks_per_child'] = max_tasks_per_child
|
||||
else:
|
||||
logger.info(f"[OCRWorkerPool] max_tasks_per_child not supported (Python {sys.version_info.major}.{sys.version_info.minor})")
|
||||
|
||||
logger.info("[OCRWorkerPool] ProcessPoolExecutor created (spawn context, max_workers=1)")
|
||||
self._executor = ProcessPoolExecutor(**executor_kwargs)
|
||||
|
||||
logger.info(f"[OCRWorkerPool] ProcessPoolExecutor created (spawn context, max_workers={max_workers}, max_tasks_per_child={max_tasks_per_child})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -173,7 +186,7 @@ class OCRWorkerPool:
|
||||
async def submit_task(
|
||||
self,
|
||||
image_bytes: bytes,
|
||||
engine: str = "auto",
|
||||
engine: str = "doctr_plus",
|
||||
preprocessing: str = "auto",
|
||||
timeout: float = 120.0
|
||||
) -> dict:
|
||||
@@ -182,7 +195,7 @@ class OCRWorkerPool:
|
||||
|
||||
Args:
|
||||
image_bytes: Raw image bytes
|
||||
engine: OCR engine ('auto', 'paddleocr', 'tesseract')
|
||||
engine: OCR engine ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
|
||||
preprocessing: Preprocessing mode ('light', 'medium', 'heavy', 'auto')
|
||||
timeout: Maximum processing time in seconds
|
||||
|
||||
@@ -339,6 +352,7 @@ class OCRWorkerPool:
|
||||
# Global engines - persist between tasks in worker process
|
||||
_paddle_engine = None
|
||||
_tesseract_engine = None
|
||||
_doctr_engine = None # docTR engine (PyTorch backend)
|
||||
_worker_initialized = False
|
||||
|
||||
|
||||
@@ -346,40 +360,92 @@ def _worker_initializer() -> None:
|
||||
"""
|
||||
Called once when worker process spawns.
|
||||
|
||||
Initializes global OCR engines that persist between tasks.
|
||||
This is where PaddleOCR loading happens (15-20 seconds).
|
||||
Initializes global OCR engines IN PARALLEL for faster startup.
|
||||
Uses ThreadPoolExecutor to load enabled engines concurrently.
|
||||
Respects OCR_ENABLE_PADDLEOCR and OCR_ENABLE_TESSERACT from .env.
|
||||
|
||||
Total warmup time = max(engine_times) instead of sum(engine_times).
|
||||
"""
|
||||
global _paddle_engine, _tesseract_engine, _worker_initialized
|
||||
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
|
||||
|
||||
if _worker_initialized:
|
||||
print(f"[Worker {os.getpid()}] Already initialized", flush=True)
|
||||
return
|
||||
|
||||
print(f"[Worker {os.getpid()}] Initializing OCR engines...", flush=True)
|
||||
# Check which engines are enabled via .env
|
||||
paddle_enabled = os.getenv("OCR_ENABLE_PADDLEOCR", "true").lower() == "true"
|
||||
tesseract_enabled = os.getenv("OCR_ENABLE_TESSERACT", "true").lower() == "true"
|
||||
|
||||
enabled_engines = ["doctr"] # docTR is always loaded (primary engine)
|
||||
if paddle_enabled:
|
||||
enabled_engines.append("paddle")
|
||||
if tesseract_enabled:
|
||||
enabled_engines.append("tesseract")
|
||||
|
||||
print(f"[Worker {os.getpid()}] Initializing OCR engines: {enabled_engines}", flush=True)
|
||||
if not paddle_enabled:
|
||||
print(f"[Worker {os.getpid()}] PaddleOCR DISABLED - saving ~800MB RAM", flush=True)
|
||||
if not tesseract_enabled:
|
||||
print(f"[Worker {os.getpid()}] Tesseract DISABLED - saving ~50MB RAM", flush=True)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Initialize PaddleOCR
|
||||
try:
|
||||
# Import inside worker to avoid import issues in main process
|
||||
from backend.modules.data_entry.services.ocr.ocr_worker_process import initialize_paddle_engine
|
||||
_paddle_engine = initialize_paddle_engine()
|
||||
print(f"[Worker {os.getpid()}] PaddleOCR loaded", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[Worker {os.getpid()}] PaddleOCR init failed: {e}", flush=True)
|
||||
_paddle_engine = None
|
||||
# Define loader functions - each runs in its own thread
|
||||
def load_doctr():
|
||||
try:
|
||||
from backend.modules.data_entry.services.ocr.ocr_worker_process import initialize_doctr_engine
|
||||
engine = initialize_doctr_engine()
|
||||
return ("doctr", engine, None)
|
||||
except Exception as e:
|
||||
return ("doctr", None, str(e))
|
||||
|
||||
# Initialize Tesseract
|
||||
try:
|
||||
from backend.modules.data_entry.services.ocr.tesseract_engine import TesseractEngine
|
||||
_tesseract_engine = TesseractEngine()
|
||||
print(f"[Worker {os.getpid()}] Tesseract loaded", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[Worker {os.getpid()}] Tesseract init failed: {e}", flush=True)
|
||||
_tesseract_engine = None
|
||||
def load_paddle():
|
||||
if not paddle_enabled:
|
||||
return ("paddle", None, "disabled via OCR_ENABLE_PADDLEOCR=false")
|
||||
try:
|
||||
from backend.modules.data_entry.services.ocr.ocr_worker_process import initialize_paddle_engine
|
||||
engine = initialize_paddle_engine()
|
||||
return ("paddle", engine, None)
|
||||
except Exception as e:
|
||||
return ("paddle", None, str(e))
|
||||
|
||||
def load_tesseract():
|
||||
if not tesseract_enabled:
|
||||
return ("tesseract", None, "disabled via OCR_ENABLE_TESSERACT=false")
|
||||
try:
|
||||
from backend.modules.data_entry.services.ocr.tesseract_engine import TesseractEngine
|
||||
engine = TesseractEngine()
|
||||
return ("tesseract", engine, None)
|
||||
except Exception as e:
|
||||
return ("tesseract", None, str(e))
|
||||
|
||||
# Build list of futures for enabled engines only
|
||||
futures_to_submit = [load_doctr] # docTR always loaded
|
||||
if paddle_enabled:
|
||||
futures_to_submit.append(load_paddle)
|
||||
if tesseract_enabled:
|
||||
futures_to_submit.append(load_tesseract)
|
||||
|
||||
# Load engines in parallel using ThreadPoolExecutor
|
||||
with ThreadPoolExecutor(max_workers=len(futures_to_submit)) as executor:
|
||||
futures = [executor.submit(fn) for fn in futures_to_submit]
|
||||
|
||||
for future in as_completed(futures):
|
||||
name, engine, error = future.result()
|
||||
if error and "disabled" not in error:
|
||||
print(f"[Worker {os.getpid()}] {name} init failed: {error}", flush=True)
|
||||
elif engine:
|
||||
print(f"[Worker {os.getpid()}] {name} loaded", flush=True)
|
||||
if name == "doctr":
|
||||
_doctr_engine = engine
|
||||
elif name == "paddle":
|
||||
_paddle_engine = engine
|
||||
elif name == "tesseract":
|
||||
_tesseract_engine = engine
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
_worker_initialized = True
|
||||
print(f"[Worker {os.getpid()}] Initialization complete in {elapsed:.1f}s", flush=True)
|
||||
print(f"[Worker {os.getpid()}] Initialization complete in {elapsed:.1f}s (engines: {enabled_engines})", flush=True)
|
||||
|
||||
|
||||
def _warmup_task() -> dict:
|
||||
@@ -389,7 +455,7 @@ def _warmup_task() -> dict:
|
||||
Called at FastAPI startup to pre-warm the worker.
|
||||
Returns success status and worker PID.
|
||||
"""
|
||||
global _paddle_engine, _tesseract_engine, _worker_initialized
|
||||
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
|
||||
|
||||
try:
|
||||
# Ensure initialization
|
||||
@@ -400,6 +466,14 @@ def _warmup_task() -> dict:
|
||||
import numpy as np
|
||||
dummy_img = np.ones((100, 100, 3), dtype=np.uint8) * 255
|
||||
|
||||
# Test docTR if available (fastest engine)
|
||||
if _doctr_engine is not None:
|
||||
try:
|
||||
_doctr_engine([dummy_img])
|
||||
print(f"[Worker {os.getpid()}] docTR warmup OK", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[Worker {os.getpid()}] docTR warmup error: {e}", flush=True)
|
||||
|
||||
# Test PaddleOCR if available
|
||||
if _paddle_engine is not None:
|
||||
try:
|
||||
@@ -414,6 +488,7 @@ def _warmup_task() -> dict:
|
||||
return {
|
||||
"success": True,
|
||||
"pid": os.getpid(),
|
||||
"doctr_available": _doctr_engine is not None,
|
||||
"paddle_available": _paddle_engine is not None,
|
||||
"tesseract_available": _tesseract_engine is not None
|
||||
}
|
||||
@@ -428,7 +503,7 @@ def _warmup_task() -> dict:
|
||||
|
||||
def _process_ocr_task(
|
||||
image_bytes: bytes,
|
||||
engine: str = "auto",
|
||||
engine: str = "doctr_plus",
|
||||
preprocessing: str = "auto"
|
||||
) -> dict:
|
||||
"""
|
||||
@@ -439,13 +514,13 @@ def _process_ocr_task(
|
||||
|
||||
Args:
|
||||
image_bytes: Raw image bytes
|
||||
engine: OCR engine choice
|
||||
engine: OCR engine choice ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
|
||||
preprocessing: Preprocessing mode
|
||||
|
||||
Returns:
|
||||
Dict with extraction results
|
||||
"""
|
||||
global _paddle_engine, _tesseract_engine, _worker_initialized
|
||||
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
|
||||
|
||||
try:
|
||||
# Ensure initialization
|
||||
@@ -461,7 +536,8 @@ def _process_ocr_task(
|
||||
paddle_engine=_paddle_engine,
|
||||
tesseract_engine=_tesseract_engine,
|
||||
engine=engine,
|
||||
preprocessing=preprocessing
|
||||
preprocessing=preprocessing,
|
||||
doctr_engine=_doctr_engine
|
||||
)
|
||||
|
||||
# Cleanup after each task
|
||||
|
||||
@@ -6,6 +6,7 @@ Handles OCR processing with persistent engine instances.
|
||||
|
||||
Key features:
|
||||
- PaddleOCR initialized ONCE at process spawn
|
||||
- docTR initialized ONCE at process spawn (PyTorch backend)
|
||||
- Tesseract as fallback/complement engine
|
||||
- Multi-pass preprocessing (light → medium → tesseract)
|
||||
- Automatic engine selection based on results
|
||||
@@ -26,6 +27,13 @@ import numpy as np
|
||||
# Disable PaddleOCR model source check for faster startup
|
||||
os.environ['PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK'] = 'True'
|
||||
|
||||
# Memory optimization for docTR (prevents memory leak in multiprocessing)
|
||||
# Source: https://github.com/mindee/doctr/issues/1594
|
||||
os.environ['DOCTR_MULTIPROCESSING_DISABLE'] = 'TRUE'
|
||||
|
||||
# Reduce Intel oneDNN cache to save memory
|
||||
os.environ['ONEDNN_PRIMITIVE_CACHE_CAPACITY'] = '1'
|
||||
|
||||
|
||||
@dataclass
|
||||
class OCRResult:
|
||||
@@ -71,25 +79,67 @@ def initialize_paddle_engine():
|
||||
return None
|
||||
|
||||
|
||||
def initialize_doctr_engine():
|
||||
"""
|
||||
Initialize docTR engine (CPU only).
|
||||
|
||||
Called once at worker spawn. Returns the engine instance
|
||||
that will be reused for all subsequent requests.
|
||||
|
||||
Note: DirectML (AMD GPU) has compatibility issues with docTR.
|
||||
CUDA (NVIDIA) works but requires separate PyTorch build.
|
||||
CPU mode is stable and well-optimized.
|
||||
|
||||
Returns:
|
||||
docTR predictor instance or None if unavailable
|
||||
"""
|
||||
try:
|
||||
print(f"[Worker {os.getpid()}] Loading docTR (PyTorch backend, CPU)...", flush=True)
|
||||
start_time = time.time()
|
||||
|
||||
from doctr.models import ocr_predictor
|
||||
|
||||
# Initialize docTR predictor with pretrained models
|
||||
# Uses db_resnet50 for detection and crnn_vgg16_bn for recognition
|
||||
doctr = ocr_predictor(
|
||||
det_arch='db_resnet50',
|
||||
reco_arch='crnn_vgg16_bn',
|
||||
pretrained=True,
|
||||
assume_straight_pages=True,
|
||||
straighten_pages=False,
|
||||
preserve_aspect_ratio=True,
|
||||
)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
print(f"[Worker {os.getpid()}] docTR loaded in {elapsed:.1f}s", flush=True)
|
||||
return doctr
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Worker {os.getpid()}] docTR init failed: {e}", flush=True)
|
||||
return None
|
||||
|
||||
|
||||
def process_ocr(
|
||||
image_bytes: bytes,
|
||||
paddle_engine,
|
||||
tesseract_engine,
|
||||
engine: str = "auto",
|
||||
preprocessing: str = "auto"
|
||||
engine: str = "doctr_plus",
|
||||
preprocessing: str = "auto",
|
||||
doctr_engine=None
|
||||
) -> dict:
|
||||
"""
|
||||
Process OCR on image bytes.
|
||||
|
||||
Main entry point for OCR processing in worker process.
|
||||
Uses adaptive multi-pass strategy for best results.
|
||||
Uses the specified engine for text recognition.
|
||||
|
||||
Args:
|
||||
image_bytes: Raw image bytes (JPEG, PNG, or PDF)
|
||||
paddle_engine: Pre-initialized PaddleOCR instance (or None)
|
||||
tesseract_engine: Pre-initialized TesseractEngine instance (or None)
|
||||
engine: Engine selection ('auto', 'paddleocr', 'tesseract')
|
||||
engine: Engine selection ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
|
||||
preprocessing: Preprocessing mode ('auto', 'light', 'medium', 'heavy')
|
||||
doctr_engine: Pre-initialized docTR instance (or None)
|
||||
|
||||
Returns:
|
||||
Dict with extraction results:
|
||||
@@ -101,14 +151,20 @@ def process_ocr(
|
||||
"ocr_engine": str
|
||||
}
|
||||
"""
|
||||
import sys
|
||||
start_time = time.time()
|
||||
print(f"[Worker {os.getpid()}] Processing OCR: engine={engine}, preprocessing={preprocessing}, size={len(image_bytes)} bytes", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
try:
|
||||
# Decode image from bytes
|
||||
print(f"[Worker {os.getpid()}] Decoding image...", flush=True)
|
||||
sys.stdout.flush()
|
||||
image = _decode_image(image_bytes)
|
||||
if image is None:
|
||||
return {"success": False, "error": "Failed to decode image"}
|
||||
print(f"[Worker {os.getpid()}] Image decoded: shape={image.shape}, dtype={image.dtype}, size={image.nbytes/1024/1024:.1f}MB", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
# Import preprocessor
|
||||
from backend.modules.data_entry.services.image_preprocessor import ImagePreprocessor
|
||||
@@ -116,22 +172,36 @@ def process_ocr(
|
||||
|
||||
preprocessor = ImagePreprocessor()
|
||||
extractor = ReceiptExtractor()
|
||||
print(f"[Worker {os.getpid()}] Preprocessor and extractor initialized", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
raw_texts = []
|
||||
extraction = None
|
||||
|
||||
# Engine routing
|
||||
if engine == "paddleocr":
|
||||
extraction, raw_texts = _process_paddleocr_only(
|
||||
image, paddle_engine, preprocessor, extractor
|
||||
)
|
||||
elif engine == "tesseract":
|
||||
# Engine routing (available: tesseract, doctr, doctr_plus, paddleocr)
|
||||
print(f"[Worker {os.getpid()}] Routing to engine: {engine}", flush=True)
|
||||
sys.stdout.flush()
|
||||
if engine == "tesseract":
|
||||
extraction, raw_texts = _process_tesseract_only(
|
||||
image, tesseract_engine, preprocessor, extractor
|
||||
)
|
||||
else: # auto
|
||||
extraction, raw_texts = _process_adaptive(
|
||||
image, paddle_engine, tesseract_engine, preprocessor, extractor
|
||||
elif engine == "doctr":
|
||||
extraction, raw_texts = _process_doctr_only(
|
||||
image, doctr_engine, preprocessor, extractor
|
||||
)
|
||||
elif engine == "doctr_plus":
|
||||
extraction, raw_texts = _process_doctr_plus(
|
||||
image, doctr_engine, preprocessor, extractor
|
||||
)
|
||||
elif engine == "paddleocr":
|
||||
extraction, raw_texts = _process_paddleocr_only(
|
||||
image, paddle_engine, preprocessor, extractor
|
||||
)
|
||||
else:
|
||||
# Default to doctr_plus if unknown engine specified
|
||||
print(f"[OCR] Unknown engine '{engine}', defaulting to doctr_plus", flush=True)
|
||||
extraction, raw_texts = _process_doctr_plus(
|
||||
image, doctr_engine, preprocessor, extractor
|
||||
)
|
||||
|
||||
# Calculate processing time
|
||||
@@ -171,7 +241,11 @@ def process_ocr(
|
||||
|
||||
|
||||
def _decode_image(image_bytes: bytes) -> Optional[np.ndarray]:
|
||||
"""Decode image from bytes (JPEG, PNG, or first page of PDF)."""
|
||||
"""Decode image from bytes (JPEG, PNG, or first page of PDF).
|
||||
|
||||
For PDFs, uses 200 DPI which is sufficient for receipt OCR
|
||||
and reduces processing time by ~50% vs 300 DPI.
|
||||
"""
|
||||
try:
|
||||
# Try as regular image first
|
||||
nparr = np.frombuffer(image_bytes, np.uint8)
|
||||
@@ -180,18 +254,21 @@ def _decode_image(image_bytes: bytes) -> Optional[np.ndarray]:
|
||||
if image is not None:
|
||||
return image
|
||||
|
||||
# Try as PDF
|
||||
# Try as PDF - use 200 DPI for faster processing (sufficient for receipts)
|
||||
try:
|
||||
import pdf2image
|
||||
from PIL import Image
|
||||
|
||||
images = pdf2image.convert_from_bytes(image_bytes, dpi=300)
|
||||
# 200 DPI is sufficient for receipt text recognition
|
||||
# 300 DPI was overkill and slowed down processing
|
||||
images = pdf2image.convert_from_bytes(image_bytes, dpi=200)
|
||||
if images:
|
||||
# Convert first page to numpy array
|
||||
pil_img = images[0]
|
||||
print(f"[Worker {os.getpid()}] PDF decoded: {pil_img.width}x{pil_img.height} @ 200 DPI", flush=True)
|
||||
return np.array(pil_img)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"[Worker {os.getpid()}] PDF decode error: {e}", flush=True)
|
||||
|
||||
return None
|
||||
|
||||
@@ -270,83 +347,275 @@ def _process_tesseract_only(
|
||||
return extraction, raw_texts
|
||||
|
||||
|
||||
def _process_adaptive(
|
||||
def _process_doctr_only(
|
||||
image: np.ndarray,
|
||||
paddle_engine,
|
||||
tesseract_engine,
|
||||
doctr_engine,
|
||||
preprocessor,
|
||||
extractor
|
||||
) -> Tuple[Any, List[str]]:
|
||||
"""
|
||||
Adaptive multi-pass OCR processing.
|
||||
Process using docTR only (light + medium preprocessing).
|
||||
|
||||
Strategy:
|
||||
1. PaddleOCR Light - fastest, best for clear PDFs
|
||||
2. PaddleOCR Medium - if Light incomplete
|
||||
3. Tesseract - complement missing fields only
|
||||
|
||||
Returns:
|
||||
Tuple of (extraction_result, raw_texts_list)
|
||||
docTR uses EXACT same preprocessing as PaddleOCR for consistency.
|
||||
"""
|
||||
raw_texts = []
|
||||
extraction = None
|
||||
|
||||
# === STEP 1: PaddleOCR Light ===
|
||||
if paddle_engine:
|
||||
print("[OCR] Step 1: PaddleOCR + Light", flush=True)
|
||||
light_img = preprocessor.preprocess_light(image)
|
||||
paddle_light = _paddle_recognize(paddle_engine, light_img)
|
||||
if doctr_engine is None:
|
||||
return None, ["docTR not available"]
|
||||
|
||||
if paddle_light and paddle_light.text:
|
||||
extraction = extractor.extract(paddle_light.text)
|
||||
extraction.ocr_engine = "paddle-light"
|
||||
raw_texts.append(f"=== PaddleOCR Light (conf: {paddle_light.confidence:.0%}) ===\n{paddle_light.text}")
|
||||
# Step 1: Light preprocessing (same as PaddleOCR)
|
||||
print("[OCR] Step 1: docTR + Light", flush=True)
|
||||
light_img = preprocessor.preprocess_light(image)
|
||||
doctr_light = _doctr_recognize(doctr_engine, light_img)
|
||||
|
||||
if _is_extraction_complete(extraction):
|
||||
print("[OCR] Early exit - all fields found in Step 1", flush=True)
|
||||
return extraction, raw_texts
|
||||
if doctr_light and doctr_light.text:
|
||||
extraction = extractor.extract(doctr_light.text)
|
||||
extraction.ocr_engine = "doctr-light"
|
||||
raw_texts.append(f"=== docTR Light (conf: {doctr_light.confidence:.0%}) ===\n{doctr_light.text}")
|
||||
|
||||
# === STEP 2: PaddleOCR Medium ===
|
||||
if paddle_engine:
|
||||
print("[OCR] Step 2: PaddleOCR + Medium", flush=True)
|
||||
medium_img = preprocessor.preprocess_medium(image)
|
||||
paddle_medium = _paddle_recognize(paddle_engine, medium_img)
|
||||
if _is_extraction_complete(extraction):
|
||||
return extraction, raw_texts
|
||||
|
||||
if paddle_medium and paddle_medium.text:
|
||||
extraction_medium = extractor.extract(paddle_medium.text)
|
||||
extraction_medium.ocr_engine = "paddle-medium"
|
||||
raw_texts.append(f"=== PaddleOCR Medium (conf: {paddle_medium.confidence:.0%}) ===\n{paddle_medium.text}")
|
||||
# Step 2: Medium preprocessing (same as PaddleOCR)
|
||||
print("[OCR] Step 2: docTR + Medium", flush=True)
|
||||
medium_img = preprocessor.preprocess_medium(image)
|
||||
doctr_medium = _doctr_recognize(doctr_engine, medium_img)
|
||||
|
||||
if extraction:
|
||||
extraction = _merge_extractions(extraction, extraction_medium)
|
||||
extraction.ocr_engine = "paddle-adaptive"
|
||||
else:
|
||||
extraction = extraction_medium
|
||||
if doctr_medium and doctr_medium.text:
|
||||
extraction_medium = extractor.extract(doctr_medium.text)
|
||||
extraction_medium.ocr_engine = "doctr-medium"
|
||||
raw_texts.append(f"=== docTR Medium (conf: {doctr_medium.confidence:.0%}) ===\n{doctr_medium.text}")
|
||||
|
||||
if _is_extraction_complete(extraction):
|
||||
print("[OCR] Early exit - all fields found after Step 2", flush=True)
|
||||
return extraction, raw_texts
|
||||
|
||||
# === STEP 3: Tesseract (complement only) ===
|
||||
if tesseract_engine:
|
||||
print("[OCR] Step 3: Tesseract complement", flush=True)
|
||||
tesseract_img = preprocessor.preprocess_for_tesseract(image)
|
||||
tesseract_result = tesseract_engine.recognize(tesseract_img)
|
||||
|
||||
if tesseract_result and tesseract_result.text:
|
||||
extraction_tess = extractor.extract(tesseract_result.text)
|
||||
extraction_tess.ocr_engine = "tesseract"
|
||||
raw_texts.append(f"=== Tesseract (conf: {tesseract_result.confidence:.0%}) ===\n{tesseract_result.text}")
|
||||
|
||||
if extraction:
|
||||
extraction = _complement_extraction(extraction, extraction_tess)
|
||||
extraction.ocr_engine = "adaptive-full"
|
||||
else:
|
||||
extraction = extraction_tess
|
||||
if extraction:
|
||||
extraction = _merge_extractions(extraction, extraction_medium)
|
||||
extraction.ocr_engine = "doctr-adaptive"
|
||||
else:
|
||||
extraction = extraction_medium
|
||||
|
||||
return extraction, raw_texts
|
||||
|
||||
|
||||
def _process_doctr_plus(
|
||||
image: np.ndarray,
|
||||
doctr_engine,
|
||||
preprocessor,
|
||||
extractor
|
||||
) -> Tuple[Any, List[str]]:
|
||||
"""
|
||||
docTR Plus - Optimized 2-tier sequential processing with early exit.
|
||||
|
||||
Architecture:
|
||||
- Tier 1: Light preprocessing (~4-5s)
|
||||
→ Early exit if confidence >= 0.75 AND all fields valid AND cross-validations pass
|
||||
- Tier 2: Medium preprocessing (only if Tier 1 insufficient, ~4-5s additional)
|
||||
→ Merge with Tier 1 results
|
||||
→ Mark for review if still problems
|
||||
|
||||
Performance:
|
||||
- Fast path (80% receipts): ~4-5s (Tier 1 only)
|
||||
- Slow path (20% receipts): ~8-9s (Tier 1 + Tier 2)
|
||||
- Average: ~5-6s
|
||||
|
||||
Returns:
|
||||
Tuple of (extraction_result, raw_texts_list)
|
||||
extraction_result.needs_review = True if validation issues remain
|
||||
"""
|
||||
raw_texts = []
|
||||
extraction = None
|
||||
|
||||
if doctr_engine is None:
|
||||
return None, ["docTR not available"]
|
||||
|
||||
# ========== TIER 1: Light Preprocessing ==========
|
||||
print("[docTR+] TIER 1: Light preprocessing", flush=True)
|
||||
import time
|
||||
tier1_start = time.time()
|
||||
|
||||
light_img = preprocessor.preprocess_light(image)
|
||||
doctr_light = _doctr_recognize(doctr_engine, light_img)
|
||||
|
||||
tier1_time = time.time() - tier1_start
|
||||
print(f"[docTR+] TIER 1 completed in {tier1_time:.1f}s", flush=True)
|
||||
|
||||
if doctr_light and doctr_light.text:
|
||||
extraction = extractor.extract(doctr_light.text)
|
||||
extraction.ocr_engine = "doctr-plus-light"
|
||||
raw_texts.append(f"=== docTR+ Tier1/Light (conf: {doctr_light.confidence:.0%}) ===\n{doctr_light.text}")
|
||||
|
||||
# Early Exit Check: confidence >= 0.75 + cross-validations
|
||||
if _is_extraction_valid_for_early_exit(extraction, min_confidence=0.75):
|
||||
print(f"[docTR+] EARLY EXIT - Tier 1 sufficient (conf: {extraction.overall_confidence:.0%})", flush=True)
|
||||
extraction.ocr_engine = "doctr-plus"
|
||||
return extraction, raw_texts
|
||||
|
||||
print(f"[docTR+] Tier 1 incomplete or validation failed, proceeding to Tier 2...", flush=True)
|
||||
|
||||
# ========== TIER 2: Medium Preprocessing (only if needed) ==========
|
||||
print("[docTR+] TIER 2: Medium preprocessing", flush=True)
|
||||
tier2_start = time.time()
|
||||
|
||||
medium_img = preprocessor.preprocess_medium(image)
|
||||
doctr_medium = _doctr_recognize(doctr_engine, medium_img)
|
||||
|
||||
tier2_time = time.time() - tier2_start
|
||||
print(f"[docTR+] TIER 2 completed in {tier2_time:.1f}s", flush=True)
|
||||
|
||||
if doctr_medium and doctr_medium.text:
|
||||
extraction_medium = extractor.extract(doctr_medium.text)
|
||||
extraction_medium.ocr_engine = "doctr-plus-medium"
|
||||
raw_texts.append(f"=== docTR+ Tier2/Medium (conf: {doctr_medium.confidence:.0%}) ===\n{doctr_medium.text}")
|
||||
|
||||
if extraction:
|
||||
# Merge Tier 1 + Tier 2 results
|
||||
extraction = _merge_extractions(extraction, extraction_medium)
|
||||
else:
|
||||
extraction = extraction_medium
|
||||
|
||||
# ========== FINAL VALIDATION ==========
|
||||
if extraction:
|
||||
extraction.ocr_engine = "doctr-plus"
|
||||
|
||||
# Mark for review if validation still fails after both tiers
|
||||
passes_validation, penalty, errors = _quick_cross_validate(extraction)
|
||||
|
||||
if not passes_validation or extraction.overall_confidence < 0.75:
|
||||
# Mark for human review using existing fields
|
||||
extraction.needs_manual_review = True
|
||||
|
||||
if extraction.overall_confidence < 0.75:
|
||||
extraction.validation_warnings.append(f"Low confidence: {extraction.overall_confidence:.0%}")
|
||||
|
||||
if not extraction.amount:
|
||||
extraction.validation_errors.append("TOTAL not detected")
|
||||
if not extraction.cui:
|
||||
extraction.validation_warnings.append("CUI not detected")
|
||||
if not extraction.tva_total and not extraction.tva_entries:
|
||||
extraction.validation_warnings.append("TVA not detected")
|
||||
if not extraction.receipt_date:
|
||||
extraction.validation_warnings.append("Date not detected")
|
||||
|
||||
# Add cross-validation errors
|
||||
extraction.validation_errors.extend(errors)
|
||||
|
||||
print(f"[docTR+] Marked for review: {extraction.validation_errors + extraction.validation_warnings}", flush=True)
|
||||
else:
|
||||
extraction.needs_manual_review = False
|
||||
|
||||
total_time = tier1_time + (tier2_time if 'tier2_time' in dir() else 0)
|
||||
print(f"[docTR+] Total processing time: {total_time:.1f}s", flush=True)
|
||||
|
||||
return extraction, raw_texts
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# VALIDATION HELPERS (used by doctr_plus for early exit decisions)
|
||||
# =============================================================================
|
||||
|
||||
def _quick_cross_validate(extraction) -> tuple[bool, float, list[str]]:
|
||||
"""
|
||||
Quick cross-validation for OCR results.
|
||||
|
||||
Checks critical field correlations to detect obvious OCR errors.
|
||||
Used by doctr_plus to decide whether to proceed to Tier 2 or exit early.
|
||||
|
||||
Returns:
|
||||
Tuple of (passes_validation, confidence_penalty, error_messages)
|
||||
"""
|
||||
try:
|
||||
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
|
||||
|
||||
if extraction is None:
|
||||
return False, 1.0, ["No extraction result"]
|
||||
|
||||
# Convert extraction to dict for validation
|
||||
# Build TVA entries dict for TVAEntriesSumRule (expects {code: amount})
|
||||
tva_entries_dict = {}
|
||||
if extraction.tva_entries:
|
||||
for entry in extraction.tva_entries:
|
||||
if isinstance(entry, dict):
|
||||
code = entry.get('code', 'A')
|
||||
amount = entry.get('amount', 0)
|
||||
try:
|
||||
tva_entries_dict[code] = float(amount)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
|
||||
validation_data = {
|
||||
"amount": float(extraction.amount) if extraction.amount else None,
|
||||
"tva": float(extraction.tva_total) if extraction.tva_total else None,
|
||||
"tva_entries": tva_entries_dict, # For TVAEntriesSumRule: {code: amount}
|
||||
"cui": extraction.cui, # For CUI checksum validation
|
||||
}
|
||||
|
||||
# Also pass raw tva_entries for TVABasedTotalRule (for rate detection)
|
||||
if extraction.tva_entries:
|
||||
validation_data['tva_entries_raw'] = extraction.tva_entries
|
||||
|
||||
# Add payment methods if available (for TOTAL vs CARD+CASH validation)
|
||||
if extraction.payment_methods:
|
||||
try:
|
||||
card_amount = sum(
|
||||
float(p.get('amount', 0) if isinstance(p, dict) else 0)
|
||||
for p in extraction.payment_methods
|
||||
if isinstance(p, dict) and p.get('method') == 'CARD'
|
||||
)
|
||||
cash_amount = sum(
|
||||
float(p.get('amount', 0) if isinstance(p, dict) else 0)
|
||||
for p in extraction.payment_methods
|
||||
if isinstance(p, dict) and p.get('method') == 'NUMERAR'
|
||||
)
|
||||
validation_data['card_amount'] = card_amount
|
||||
validation_data['cash_amount'] = cash_amount
|
||||
except Exception as e:
|
||||
print(f"[Worker {os.getpid()}] Payment method validation error: {e}", flush=True)
|
||||
|
||||
# Run quick validation
|
||||
validator = OCRValidationEngine()
|
||||
return validator.quick_validate_for_hybrid(validation_data)
|
||||
|
||||
except Exception as e:
|
||||
# Never crash the process on validation errors
|
||||
print(f"[Worker {os.getpid()}] Cross-validation error: {e}", flush=True)
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
# Return "passes" to allow processing to continue
|
||||
return True, 0.0, [f"Validation skipped due to error: {str(e)}"]
|
||||
|
||||
|
||||
def _is_extraction_valid_for_early_exit(extraction, min_confidence: float = 0.85) -> bool:
|
||||
"""
|
||||
Check if extraction is valid for early exit in doctr_plus.
|
||||
|
||||
Combines confidence check with cross-validation to prevent
|
||||
early exit on OCR errors (e.g., wrong TOTAL but correct TVA).
|
||||
|
||||
Returns:
|
||||
True only if:
|
||||
1. Overall confidence >= min_confidence
|
||||
2. Critical fields are present (AMOUNT, DATE, CUI)
|
||||
3. Cross-validation passes (TOTAL matches TVA calculation, or no TVA)
|
||||
"""
|
||||
try:
|
||||
# First check basic completeness (relaxed for early exit)
|
||||
if not _is_extraction_complete(extraction, min_confidence, for_early_exit=True):
|
||||
return False
|
||||
|
||||
# Then run cross-validation
|
||||
passes_validation, penalty, errors = _quick_cross_validate(extraction)
|
||||
|
||||
if not passes_validation:
|
||||
print(f"[Early Exit] BLOCKED: cross-validation failed: {errors}", flush=True)
|
||||
return False
|
||||
|
||||
print(f"[Early Exit] OK: conf={extraction.overall_confidence:.0%}, validation passed", flush=True)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
# Never crash on validation - just continue to next engine
|
||||
print(f"[Worker {os.getpid()}] Early exit check error: {e}", flush=True)
|
||||
return False # Continue to next engine on error
|
||||
|
||||
def _paddle_recognize(paddle_engine, image: np.ndarray) -> Optional[OCRResult]:
|
||||
"""Run PaddleOCR recognition on image."""
|
||||
try:
|
||||
@@ -388,34 +657,191 @@ def _paddle_recognize(paddle_engine, image: np.ndarray) -> Optional[OCRResult]:
|
||||
return None
|
||||
|
||||
|
||||
def _is_extraction_complete(ext, min_confidence: float = 0.85) -> bool:
|
||||
"""Check if extraction has all required fields."""
|
||||
def _doctr_recognize(doctr_engine, image: np.ndarray) -> Optional[OCRResult]:
|
||||
"""
|
||||
Run docTR recognition on image.
|
||||
|
||||
docTR requires RGB images, handles conversion automatically.
|
||||
Uses same preprocessing as PaddleOCR for consistent results.
|
||||
"""
|
||||
try:
|
||||
# docTR requires RGB images
|
||||
if len(image.shape) == 2:
|
||||
# Convert grayscale to RGB
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
||||
elif image.shape[2] == 3:
|
||||
# Convert BGR (OpenCV) to RGB
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
elif image.shape[2] == 4:
|
||||
# Convert RGBA to RGB
|
||||
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
|
||||
|
||||
# docTR expects a list of numpy arrays (pages)
|
||||
result = doctr_engine([image])
|
||||
|
||||
if not result or not result.pages:
|
||||
return OCRResult(text="", confidence=0.0, boxes=[], engine="doctr")
|
||||
|
||||
# Extract text from all pages
|
||||
all_texts = []
|
||||
all_confidences = []
|
||||
boxes = []
|
||||
|
||||
for page in result.pages:
|
||||
for block in page.blocks:
|
||||
for line in block.lines:
|
||||
line_text = ' '.join(word.value for word in line.words)
|
||||
line_confidence = sum(w.confidence for w in line.words) / len(line.words) if line.words else 0.0
|
||||
all_texts.append(line_text)
|
||||
all_confidences.append(line_confidence)
|
||||
|
||||
# Store word-level boxes
|
||||
for word in line.words:
|
||||
boxes.append({
|
||||
'text': word.value,
|
||||
'confidence': float(word.confidence),
|
||||
'box': word.geometry # (xmin, ymin), (xmax, ymax)
|
||||
})
|
||||
|
||||
text_result = '\n'.join(all_texts)
|
||||
avg_conf = sum(all_confidences) / len(all_confidences) if all_confidences else 0.0
|
||||
|
||||
return OCRResult(
|
||||
text=text_result,
|
||||
confidence=float(avg_conf),
|
||||
boxes=boxes,
|
||||
engine="doctr"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Worker] docTR error: {e}", flush=True)
|
||||
return None
|
||||
|
||||
|
||||
def _is_extraction_complete(ext, min_confidence: float = 0.85, for_early_exit: bool = False) -> bool:
|
||||
"""
|
||||
Check if extraction has required fields.
|
||||
|
||||
Args:
|
||||
ext: Extraction result
|
||||
min_confidence: Minimum overall confidence
|
||||
for_early_exit: If True, use relaxed criteria (AMOUNT + DATE + CUI required)
|
||||
If False, require all fields (strict mode for final validation)
|
||||
|
||||
Returns:
|
||||
True if extraction meets completeness criteria
|
||||
"""
|
||||
# Check confidence first
|
||||
if ext.overall_confidence < min_confidence:
|
||||
if for_early_exit:
|
||||
print(f"[Early Exit] BLOCKED: confidence {ext.overall_confidence:.0%} < {min_confidence:.0%}", flush=True)
|
||||
return False
|
||||
|
||||
has_number = bool(ext.receipt_number)
|
||||
has_date = bool(ext.receipt_date)
|
||||
has_amount = bool(ext.amount)
|
||||
has_tva = bool(ext.tva_total) or bool(ext.tva_entries)
|
||||
has_cui = bool(ext.cui)
|
||||
|
||||
return all([has_number, has_date, has_amount, has_tva, has_cui])
|
||||
if for_early_exit:
|
||||
# Relaxed criteria for early exit:
|
||||
# - AMOUNT is required (core field)
|
||||
# - DATE is required (needed for accounting)
|
||||
# - CUI is required (needed for supplier identification)
|
||||
# - TVA is NOT required (some receipts have 0% TVA)
|
||||
# - receipt_number is NOT required (often missing)
|
||||
required_ok = all([has_amount, has_date, has_cui])
|
||||
|
||||
if not required_ok:
|
||||
missing = []
|
||||
if not has_amount: missing.append("AMOUNT")
|
||||
if not has_date: missing.append("DATE")
|
||||
if not has_cui: missing.append("CUI")
|
||||
print(f"[Early Exit] BLOCKED: missing required fields: {', '.join(missing)}", flush=True)
|
||||
|
||||
return required_ok
|
||||
else:
|
||||
# Strict criteria for final validation (all fields required)
|
||||
has_number = bool(ext.receipt_number)
|
||||
return all([has_number, has_date, has_amount, has_tva, has_cui])
|
||||
|
||||
|
||||
def _merge_extractions(primary, secondary):
|
||||
"""Merge two extractions, picking best fields from each."""
|
||||
"""Merge two extractions, picking best fields from each.
|
||||
|
||||
Primary should be the higher-quality engine (e.g., docTR).
|
||||
Secondary is the fallback engine (e.g., Tesseract).
|
||||
|
||||
Priority logic:
|
||||
- AMOUNT: TVA validation wins over confidence. If both valid or both invalid,
|
||||
uses confidence (or TVA diff for invalid cases).
|
||||
- DATE/CUI: Validation-based, then confidence, then primary wins ties.
|
||||
- OTHER FIELDS: Primary wins when both have values.
|
||||
"""
|
||||
from backend.modules.data_entry.services.ocr_extractor import ExtractionResult
|
||||
|
||||
result = ExtractionResult()
|
||||
|
||||
# Amount - prefer higher confidence
|
||||
# Helper: Check if amount matches TVA calculation
|
||||
def amount_passes_tva_validation(amount, tva_total, tva_entries):
|
||||
if not amount or not tva_total:
|
||||
return False, 0.0
|
||||
try:
|
||||
tva_rate = 0.21 # Default Romanian TVA
|
||||
if tva_entries:
|
||||
for entry in tva_entries:
|
||||
if isinstance(entry, dict) and entry.get('percent'):
|
||||
tva_rate = float(entry['percent']) / 100.0
|
||||
break
|
||||
# Expected TOTAL = TVA / rate * (1 + rate)
|
||||
expected = float(tva_total) * (1 + tva_rate) / tva_rate
|
||||
actual = float(amount)
|
||||
diff_percent = abs(actual - expected) / expected if expected > 0 else 1.0
|
||||
return diff_percent < 0.03, diff_percent # 3% tolerance
|
||||
except:
|
||||
return False, 1.0
|
||||
|
||||
# Amount - prefer TVA-validated value over confidence
|
||||
if primary.amount and secondary.amount:
|
||||
if primary.confidence_amount >= secondary.confidence_amount:
|
||||
result.amount = primary.amount
|
||||
result.confidence_amount = primary.confidence_amount
|
||||
else:
|
||||
# Get TVA from the one with entries, or use any available
|
||||
tva_total = primary.tva_total or secondary.tva_total
|
||||
tva_entries = primary.tva_entries or secondary.tva_entries
|
||||
|
||||
primary_valid, primary_diff = amount_passes_tva_validation(
|
||||
primary.amount, tva_total, tva_entries
|
||||
)
|
||||
secondary_valid, secondary_diff = amount_passes_tva_validation(
|
||||
secondary.amount, tva_total, tva_entries
|
||||
)
|
||||
|
||||
print(f"[Merge] Amount comparison: primary={primary.amount} (valid={primary_valid}, diff={primary_diff:.1%}), "
|
||||
f"secondary={secondary.amount} (valid={secondary_valid}, diff={secondary_diff:.1%})", flush=True)
|
||||
|
||||
if secondary_valid and not primary_valid:
|
||||
# Secondary passes validation, primary doesn't - use secondary!
|
||||
print(f"[Merge] Using secondary amount {secondary.amount} (passes TVA validation)", flush=True)
|
||||
result.amount = secondary.amount
|
||||
result.confidence_amount = secondary.confidence_amount
|
||||
elif primary_valid and not secondary_valid:
|
||||
# Primary passes validation
|
||||
result.amount = primary.amount
|
||||
result.confidence_amount = primary.confidence_amount
|
||||
elif primary_valid and secondary_valid:
|
||||
# Both valid - use higher confidence
|
||||
if primary.confidence_amount >= secondary.confidence_amount:
|
||||
result.amount = primary.amount
|
||||
result.confidence_amount = primary.confidence_amount
|
||||
else:
|
||||
result.amount = secondary.amount
|
||||
result.confidence_amount = secondary.confidence_amount
|
||||
else:
|
||||
# Neither valid - use the one closer to TVA calculation
|
||||
if secondary_diff < primary_diff:
|
||||
print(f"[Merge] Neither valid, using secondary {secondary.amount} (closer to TVA)", flush=True)
|
||||
result.amount = secondary.amount
|
||||
result.confidence_amount = secondary.confidence_amount
|
||||
else:
|
||||
result.amount = primary.amount
|
||||
result.confidence_amount = primary.confidence_amount
|
||||
elif primary.amount:
|
||||
result.amount = primary.amount
|
||||
result.confidence_amount = primary.confidence_amount
|
||||
@@ -438,13 +864,15 @@ def _merge_extractions(primary, secondary):
|
||||
result.receipt_date = secondary.receipt_date
|
||||
result.confidence_date = secondary.confidence_date
|
||||
|
||||
# CUI - prefer valid format
|
||||
# CUI - prefer valid format and version with RO prefix
|
||||
# Use CUIChecksumRule static methods (single source of truth)
|
||||
from backend.modules.data_entry.services.ocr.validation import CUIChecksumRule
|
||||
|
||||
def is_valid_cui(cui):
|
||||
if not cui:
|
||||
return False
|
||||
import re
|
||||
cui_clean = re.sub(r'^RO', '', cui.upper())
|
||||
return bool(re.match(r'^\d{6,10}$', cui_clean))
|
||||
digits = CUIChecksumRule.extract_digits(cui)
|
||||
return len(digits) >= 6 and len(digits) <= 10
|
||||
|
||||
if primary.cui and secondary.cui:
|
||||
if is_valid_cui(primary.cui) and not is_valid_cui(secondary.cui):
|
||||
@@ -452,22 +880,27 @@ def _merge_extractions(primary, secondary):
|
||||
elif is_valid_cui(secondary.cui) and not is_valid_cui(primary.cui):
|
||||
result.cui = secondary.cui
|
||||
else:
|
||||
result.cui = primary.cui
|
||||
# Both valid - prefer the one with RO prefix if digits match
|
||||
primary_digits = CUIChecksumRule.extract_digits(primary.cui)
|
||||
secondary_digits = CUIChecksumRule.extract_digits(secondary.cui)
|
||||
if primary_digits == secondary_digits:
|
||||
if CUIChecksumRule.has_ro_prefix(secondary.cui) and not CUIChecksumRule.has_ro_prefix(primary.cui):
|
||||
result.cui = secondary.cui # Prefer version with RO
|
||||
print(f"[CUI Complement] Preferring secondary with RO: {secondary.cui}", flush=True)
|
||||
else:
|
||||
result.cui = primary.cui
|
||||
else:
|
||||
result.cui = primary.cui
|
||||
elif primary.cui:
|
||||
result.cui = primary.cui
|
||||
elif secondary.cui:
|
||||
result.cui = secondary.cui
|
||||
|
||||
# TVA entries
|
||||
# TVA entries - ALWAYS prefer primary (docTR) when both have entries
|
||||
if primary.tva_entries and secondary.tva_entries:
|
||||
primary_total = sum(e.get('amount', Decimal('0')) for e in primary.tva_entries)
|
||||
secondary_total = sum(e.get('amount', Decimal('0')) for e in secondary.tva_entries)
|
||||
if primary_total >= secondary_total:
|
||||
result.tva_entries = primary.tva_entries
|
||||
result.tva_total = primary.tva_total
|
||||
else:
|
||||
result.tva_entries = secondary.tva_entries
|
||||
result.tva_total = secondary.tva_total
|
||||
# Always use primary (docTR) - higher quality OCR
|
||||
result.tva_entries = primary.tva_entries
|
||||
result.tva_total = primary.tva_total
|
||||
elif primary.tva_entries:
|
||||
result.tva_entries = primary.tva_entries
|
||||
result.tva_total = primary.tva_total
|
||||
@@ -483,12 +916,36 @@ def _merge_extractions(primary, secondary):
|
||||
result.address = primary.address or secondary.address
|
||||
result.items_count = primary.items_count or secondary.items_count
|
||||
result.payment_methods = primary.payment_methods or secondary.payment_methods
|
||||
result.suggested_payment_mode = getattr(primary, 'suggested_payment_mode', None) or getattr(secondary, 'suggested_payment_mode', None)
|
||||
|
||||
# Client fields
|
||||
result.client_name = primary.client_name or secondary.client_name
|
||||
result.client_cui = primary.client_cui or secondary.client_cui
|
||||
result.client_address = primary.client_address or secondary.client_address
|
||||
|
||||
# Confidence fields - preserve from primary or pick best
|
||||
if primary.confidence_vendor >= secondary.confidence_vendor:
|
||||
result.confidence_vendor = primary.confidence_vendor
|
||||
else:
|
||||
result.confidence_vendor = secondary.confidence_vendor
|
||||
|
||||
if hasattr(primary, 'confidence_client') and hasattr(secondary, 'confidence_client'):
|
||||
if primary.confidence_client >= secondary.confidence_client:
|
||||
result.confidence_client = primary.confidence_client
|
||||
else:
|
||||
result.confidence_client = secondary.confidence_client
|
||||
|
||||
# Raw text - combine both for debugging/display
|
||||
raw_texts = []
|
||||
if primary.raw_text:
|
||||
raw_texts.append(primary.raw_text)
|
||||
if secondary.raw_text and secondary.raw_text != primary.raw_text:
|
||||
raw_texts.append(secondary.raw_text)
|
||||
result.raw_text = '\n---\n'.join(raw_texts) if raw_texts else ''
|
||||
|
||||
# Note: overall_confidence is a computed @property on ExtractionResult
|
||||
# It automatically calculates from confidence_amount, confidence_date, confidence_vendor
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -557,6 +1014,7 @@ def _extraction_to_dict(extraction) -> dict:
|
||||
"address": extraction.address,
|
||||
"items_count": extraction.items_count,
|
||||
"payment_methods": extraction.payment_methods,
|
||||
"suggested_payment_mode": getattr(extraction, 'suggested_payment_mode', None),
|
||||
# Client data
|
||||
"client_name": extraction.client_name,
|
||||
"client_cui": extraction.client_cui,
|
||||
|
||||
@@ -385,8 +385,81 @@ class CUIChecksumRule(ValidationRule):
|
||||
|
||||
result = rule.validate({"cui": "R01879855"})
|
||||
# result.is_valid = False (checksum mismatch)
|
||||
|
||||
Static methods available for direct use:
|
||||
CUIChecksumRule.calculate_checksum("1056260") -> 0
|
||||
CUIChecksumRule.validate_checksum("10562600") -> True
|
||||
CUIChecksumRule.has_ro_prefix("RO10562600") -> True
|
||||
"""
|
||||
|
||||
# Fixed multipliers for 9 positions (Romanian Mod 11)
|
||||
MULTIPLIERS = [7, 5, 3, 2, 1, 7, 5, 3, 2]
|
||||
|
||||
@staticmethod
|
||||
def calculate_checksum(cui_base: str) -> int:
|
||||
"""Calculate expected CUI checksum using Romanian Mod 11 algorithm.
|
||||
|
||||
Args:
|
||||
cui_base: CUI digits WITHOUT the checksum digit (last digit)
|
||||
|
||||
Returns:
|
||||
Expected checksum digit (0-9), or -1 if invalid input
|
||||
"""
|
||||
if not cui_base or not cui_base.isdigit():
|
||||
return -1
|
||||
|
||||
# Pad base to 9 digits from LEFT
|
||||
base_padded = cui_base.zfill(9)
|
||||
base_digits = [int(d) for d in base_padded]
|
||||
|
||||
# Calculate weighted sum
|
||||
weighted_sum = sum(d * m for d, m in zip(base_digits, CUIChecksumRule.MULTIPLIERS))
|
||||
|
||||
# Calculate checksum
|
||||
checksum = (weighted_sum * 10) % 11
|
||||
if checksum == 10:
|
||||
checksum = 0
|
||||
|
||||
return checksum
|
||||
|
||||
@staticmethod
|
||||
def validate_checksum(cui_digits: str) -> bool:
|
||||
"""Check if CUI checksum is valid.
|
||||
|
||||
Args:
|
||||
cui_digits: Full CUI digits (including checksum as last digit)
|
||||
|
||||
Returns:
|
||||
True if checksum is valid, False otherwise
|
||||
"""
|
||||
if not cui_digits or len(cui_digits) < 6 or not cui_digits.isdigit():
|
||||
return False
|
||||
|
||||
base = cui_digits[:-1]
|
||||
declared = int(cui_digits[-1])
|
||||
expected = CUIChecksumRule.calculate_checksum(base)
|
||||
|
||||
return expected == declared
|
||||
|
||||
@staticmethod
|
||||
def has_ro_prefix(cui: str) -> bool:
|
||||
"""Check if CUI has RO prefix (proper format for VAT payers)."""
|
||||
if not cui:
|
||||
return False
|
||||
return cui.upper().strip().startswith('RO')
|
||||
|
||||
@staticmethod
|
||||
def extract_digits(cui: str) -> str:
|
||||
"""Extract digits from CUI, removing RO/R0 prefix."""
|
||||
if not cui:
|
||||
return ""
|
||||
cui = cui.strip().upper()
|
||||
if cui.startswith("RO"):
|
||||
cui = cui[2:]
|
||||
elif cui.startswith("R0"): # Fix OCR error R0 → RO
|
||||
cui = cui[2:]
|
||||
return ''.join(c for c in cui if c.isdigit())
|
||||
|
||||
@property
|
||||
def rule_name(self) -> str:
|
||||
return "CUI Checksum Check (Mod 11)"
|
||||
@@ -400,15 +473,11 @@ class CUIChecksumRule(ValidationRule):
|
||||
message="No CUI to validate"
|
||||
)
|
||||
|
||||
# Normalize: remove RO/R0 prefix
|
||||
cui_clean = cui.strip().upper()
|
||||
if cui_clean.startswith("RO"):
|
||||
cui_clean = cui_clean[2:]
|
||||
elif cui_clean.startswith("R0"):
|
||||
cui_clean = cui_clean[2:]
|
||||
# Use static method to extract digits
|
||||
cui_clean = CUIChecksumRule.extract_digits(cui)
|
||||
|
||||
# Check format first
|
||||
if not cui_clean.isdigit():
|
||||
if not cui_clean:
|
||||
return ValidationResult(
|
||||
is_valid=True, # Don't fail checksum if format invalid (handled by CUIFormatRule)
|
||||
message="CUI format invalid, skipping checksum"
|
||||
@@ -420,28 +489,15 @@ class CUIChecksumRule(ValidationRule):
|
||||
message="CUI length invalid, skipping checksum"
|
||||
)
|
||||
|
||||
# Extract digits
|
||||
digits = [int(d) for d in cui_clean]
|
||||
checksum_declared = digits[-1]
|
||||
base_digits = digits[:-1]
|
||||
|
||||
# Multipliers (trim to match base_digits length)
|
||||
multipliers = [7, 5, 3, 2, 1, 7, 5, 3, 2]
|
||||
multipliers = multipliers[:len(base_digits)]
|
||||
|
||||
# Calculate weighted sum
|
||||
weighted_sum = sum(d * m for d, m in zip(base_digits, multipliers))
|
||||
|
||||
# Calculate expected checksum
|
||||
checksum_calculated = (weighted_sum * 10) % 11
|
||||
if checksum_calculated == 10:
|
||||
checksum_calculated = 0
|
||||
|
||||
if checksum_calculated != checksum_declared:
|
||||
# Use static method to validate checksum
|
||||
if not CUIChecksumRule.validate_checksum(cui_clean):
|
||||
# Calculate expected for error message
|
||||
expected = CUIChecksumRule.calculate_checksum(cui_clean[:-1])
|
||||
declared = int(cui_clean[-1])
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
confidence_penalty=0.3,
|
||||
message=f"CUI '{cui}' checksum mismatch: expected {checksum_calculated}, got {checksum_declared}",
|
||||
message=f"CUI '{cui}' checksum mismatch: expected {expected}, got {declared}",
|
||||
severity="warning"
|
||||
)
|
||||
|
||||
@@ -451,6 +507,129 @@ class CUIChecksumRule(ValidationRule):
|
||||
)
|
||||
|
||||
|
||||
class TVABasedTotalRule(ValidationRule):
|
||||
"""Validate TOTAL using reverse calculation from TVA amount.
|
||||
|
||||
This is a CRITICAL validation that catches cases where OCR extracts
|
||||
wrong TOTAL but correct TVA. Since TVA = BASE * rate and TOTAL = BASE + TVA,
|
||||
we can calculate expected TOTAL from TVA alone.
|
||||
|
||||
Formula:
|
||||
Expected TOTAL = TVA / rate * (1 + rate)
|
||||
Or equivalently: Expected TOTAL = TVA * (1 + rate) / rate
|
||||
|
||||
For TVA rate 21%:
|
||||
Expected TOTAL = TVA / 0.21 * 1.21 = TVA * 5.7619
|
||||
|
||||
Example (benzina 27 oct):
|
||||
TVA = 49.58, rate = 21%
|
||||
Expected TOTAL = 49.58 / 0.21 * 1.21 = 285.68
|
||||
Extracted TOTAL = 205.66 (WRONG!)
|
||||
Rule detects mismatch and flags for escalation
|
||||
|
||||
Usage in multi-tier processing (e.g., doctr_plus):
|
||||
If this rule fails, the engine should proceed to next tier
|
||||
instead of returning early with potentially wrong data.
|
||||
"""
|
||||
|
||||
def __init__(self, tolerance_percent: float = 0.02):
|
||||
"""
|
||||
Args:
|
||||
tolerance_percent: Allowed difference as percentage (0.02 = 2%)
|
||||
"""
|
||||
self.tolerance_percent = tolerance_percent
|
||||
|
||||
@property
|
||||
def rule_name(self) -> str:
|
||||
return "TVA-Based Total Check"
|
||||
|
||||
def validate(self, data: dict[str, Any]) -> ValidationResult:
|
||||
total = data.get("amount")
|
||||
tva = data.get("tva")
|
||||
tva_entries = data.get("tva_entries", [])
|
||||
|
||||
if not total or not tva:
|
||||
return ValidationResult(
|
||||
is_valid=True,
|
||||
message="Insufficient data for TVA-based total validation"
|
||||
)
|
||||
|
||||
# Type safety
|
||||
try:
|
||||
total = float(total)
|
||||
tva = float(tva)
|
||||
except (TypeError, ValueError):
|
||||
return ValidationResult(
|
||||
is_valid=True,
|
||||
message="Non-numeric values, skipping TVA-based total validation"
|
||||
)
|
||||
|
||||
if tva <= 0 or total <= 0:
|
||||
return ValidationResult(
|
||||
is_valid=True,
|
||||
message="Zero or negative values, skipping TVA-based total validation"
|
||||
)
|
||||
|
||||
# Try to determine TVA rate from entries
|
||||
tva_rate = None
|
||||
|
||||
# Check tva_entries for rate information
|
||||
if tva_entries:
|
||||
for entry in tva_entries:
|
||||
if isinstance(entry, dict):
|
||||
percent = entry.get('percent')
|
||||
if percent:
|
||||
try:
|
||||
tva_rate = float(percent) / 100.0
|
||||
break
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
|
||||
# Fallback: try to calculate rate from TVA/TOTAL ratio
|
||||
if not tva_rate:
|
||||
# TVA = BASE * rate, TOTAL = BASE + TVA = BASE * (1 + rate)
|
||||
# TVA/TOTAL = rate / (1 + rate)
|
||||
# So rate = TVA / (TOTAL - TVA) = TVA / BASE
|
||||
base = total - tva
|
||||
if base > 0:
|
||||
calculated_rate = tva / base
|
||||
# Validate it's a reasonable Romanian TVA rate (5%, 9%, 19%, 21%)
|
||||
if 0.04 <= calculated_rate <= 0.25:
|
||||
tva_rate = calculated_rate
|
||||
|
||||
if not tva_rate:
|
||||
# Assume most common rate: 21%
|
||||
tva_rate = 0.21
|
||||
|
||||
# Calculate expected TOTAL from TVA
|
||||
# TVA = BASE * rate → BASE = TVA / rate
|
||||
# TOTAL = BASE + TVA = (TVA / rate) + TVA = TVA * (1 + 1/rate) = TVA * (1 + rate) / rate
|
||||
expected_total = tva * (1 + tva_rate) / tva_rate
|
||||
|
||||
# Calculate difference
|
||||
diff = abs(total - expected_total)
|
||||
diff_percent = diff / expected_total if expected_total > 0 else 1.0
|
||||
|
||||
if diff_percent > self.tolerance_percent:
|
||||
# Significant mismatch - OCR likely extracted TOTAL wrong
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
confidence_penalty=0.5, # High penalty - this is a critical error
|
||||
message=(
|
||||
f"TOTAL mismatch: Extracted {total:.2f} RON vs "
|
||||
f"TVA-calculated {expected_total:.2f} RON "
|
||||
f"(TVA={tva:.2f}, rate={tva_rate:.0%}, diff={diff_percent:.1%}). "
|
||||
f"Likely OCR error on TOTAL."
|
||||
),
|
||||
severity="error"
|
||||
)
|
||||
|
||||
return ValidationResult(
|
||||
is_valid=True,
|
||||
message=f"TOTAL {total:.2f} matches TVA-calculated {expected_total:.2f} (diff: {diff_percent:.1%})"
|
||||
)
|
||||
|
||||
|
||||
class InterOCRConsistencyRule(ValidationRule):
|
||||
"""Validate consistency between multiple OCR results.
|
||||
|
||||
@@ -562,6 +741,7 @@ class OCRValidationEngine:
|
||||
TVARatioRule(min_ratio=0.05, max_ratio=0.24),
|
||||
PaymentSumRule(tolerance=0.02),
|
||||
TVAEntriesSumRule(tolerance=0.02),
|
||||
TVABasedTotalRule(tolerance_percent=0.02), # Critical: detect TOTAL errors via TVA
|
||||
]
|
||||
|
||||
# Inter-OCR consistency rules
|
||||
@@ -699,39 +879,508 @@ class OCRValidationEngine:
|
||||
inter_ocr_ratios=inter_ocr_ratios
|
||||
)
|
||||
|
||||
def quick_validate_for_hybrid(self, extraction_result: dict[str, Any]) -> tuple[bool, float, list[str]]:
|
||||
"""Quick validation for early-exit decisions (e.g., doctr_plus Tier 1).
|
||||
|
||||
Runs critical cross-validation rules to detect obvious OCR errors.
|
||||
Used to decide whether to proceed to next processing tier or exit early.
|
||||
|
||||
Args:
|
||||
extraction_result: Extraction data dict with fields:
|
||||
- amount: Extracted TOTAL
|
||||
- tva: Extracted TVA total
|
||||
- tva_entries: List of TVA entries with rates
|
||||
|
||||
Returns:
|
||||
Tuple of (passes_validation, confidence_penalty, error_messages)
|
||||
- passes_validation: True if no critical errors detected
|
||||
- confidence_penalty: Cumulative penalty (0.0-1.0)
|
||||
- error_messages: List of validation error messages
|
||||
|
||||
Example usage:
|
||||
passes, penalty, errors = validation_engine.quick_validate_for_hybrid(extraction_data)
|
||||
if not passes:
|
||||
print(f"Validation failed: {errors}, proceeding to next tier")
|
||||
# Continue to next processing tier instead of early exit
|
||||
"""
|
||||
errors = []
|
||||
total_penalty = 0.0
|
||||
|
||||
# Critical rules for early-exit decision-making
|
||||
# These determine if we can trust the extraction or need to proceed to next tier
|
||||
critical_rules = [
|
||||
# Cross-field validations (most important for detecting OCR errors)
|
||||
TVABasedTotalRule(tolerance_percent=0.02), # Critical: detect TOTAL errors via TVA calculation
|
||||
PaymentSumRule(tolerance=0.05), # Cross-validate TOTAL vs CARD+CASH payments
|
||||
TVARatioRule(min_ratio=0.05, max_ratio=0.24), # TVA should be 5-24% of TOTAL
|
||||
TVAEntriesSumRule(tolerance=0.05), # Sum of TVA entries should match TVA total
|
||||
|
||||
# Format & checksum validations
|
||||
CUIChecksumRule(), # Validate CUI/CIF with Romanian Mod11 checksum algorithm
|
||||
CUIFormatRule(), # CUI should be 6-10 digits
|
||||
|
||||
# Sanity checks
|
||||
AmountRangeRule(min_amount=0.01, max_amount=100_000.0), # Reasonable amount range
|
||||
]
|
||||
|
||||
for rule in critical_rules:
|
||||
result = rule.validate(extraction_result)
|
||||
if not result.is_valid:
|
||||
errors.append(result.message)
|
||||
total_penalty += result.confidence_penalty
|
||||
|
||||
# Cap penalty at 1.0
|
||||
total_penalty = min(1.0, total_penalty)
|
||||
|
||||
passes = len(errors) == 0
|
||||
return passes, total_penalty, errors
|
||||
|
||||
# NOTE: _calculate_cui_checksum and _is_cui_checksum_valid removed
|
||||
# Use CUIChecksumRule.calculate_checksum() and CUIChecksumRule.validate_checksum() instead
|
||||
|
||||
@staticmethod
|
||||
def _repair_cui_checksum(cui_digits: str) -> Optional[str]:
|
||||
"""Try to repair CUI by attempting 1-digit corrections.
|
||||
|
||||
OCR often misreads similar-looking digits:
|
||||
- 5 ↔ 8 (most common in receipts)
|
||||
- 6 ↔ 0
|
||||
- 1 ↔ 7
|
||||
- 3 ↔ 8
|
||||
|
||||
Algorithm:
|
||||
1. Check middle positions first (2,3,4,5...) - OCR errors more common there
|
||||
2. Skip first digit (position 0) - usually reliable in CUI
|
||||
3. Check checksum digit (last position) last
|
||||
4. Prefer common OCR digit confusions (5↔8, 6↔0)
|
||||
|
||||
Args:
|
||||
cui_digits: Original CUI digits (without RO prefix)
|
||||
|
||||
Returns:
|
||||
Repaired CUI digits if 1-digit fix found, else None
|
||||
"""
|
||||
if len(cui_digits) < 6 or not cui_digits.isdigit():
|
||||
return None
|
||||
|
||||
# If already valid, return as-is
|
||||
if CUIChecksumRule.validate_checksum(cui_digits):
|
||||
return cui_digits
|
||||
|
||||
# Common OCR digit confusions (try these first)
|
||||
confusion_pairs = {
|
||||
'5': ['8', '6'], # 5 often misread as 8 or 6
|
||||
'8': ['5', '3', '0'], # 8 often misread as 5, 3, or 0
|
||||
'6': ['0', '8'], # 6 often misread as 0 or 8
|
||||
'0': ['6', '8'], # 0 often misread as 6 or 8
|
||||
'1': ['7', '4'], # 1 often misread as 7 or 4
|
||||
'7': ['1'], # 7 often misread as 1
|
||||
'3': ['8'], # 3 often misread as 8
|
||||
'4': ['1'], # 4 often misread as 1
|
||||
'2': ['7'], # 2 sometimes misread as 7
|
||||
'9': ['0'], # 9 sometimes misread as 0
|
||||
}
|
||||
|
||||
n = len(cui_digits)
|
||||
last_pos = n - 1 # checksum position
|
||||
|
||||
# Position check order: middle positions first, then position 1, then 0, then checksum
|
||||
# Skip position 0 (first digit) - it's usually reliable
|
||||
# Example for 8-digit CUI: [2,3,4,5,6, 1, 7(checksum)]
|
||||
middle_positions = list(range(2, last_pos)) # positions 2 to n-2
|
||||
position_order = middle_positions + [1, last_pos, 0] # check pos 0 last (rarely wrong)
|
||||
|
||||
for pos in position_order:
|
||||
if pos >= n:
|
||||
continue
|
||||
|
||||
original_digit = cui_digits[pos]
|
||||
|
||||
# Try common confusions first for this digit
|
||||
candidates = confusion_pairs.get(original_digit, [])
|
||||
# Then try all other digits
|
||||
all_digits = [d for d in '0123456789' if d != original_digit and d not in candidates]
|
||||
|
||||
for replacement in candidates + all_digits:
|
||||
candidate = cui_digits[:pos] + replacement + cui_digits[pos+1:]
|
||||
if CUIChecksumRule.validate_checksum(candidate):
|
||||
print(f"[CUI Repair] Fixed {cui_digits} → {candidate} (position {pos}: {original_digit}→{replacement})", flush=True)
|
||||
return candidate
|
||||
|
||||
# No single-digit fix found
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def normalize_cui(cui: Optional[str]) -> Optional[str]:
|
||||
"""Normalize CUI to RO prefix + digits format.
|
||||
"""Normalize CUI - fix OCR errors but preserve original format.
|
||||
|
||||
Rules:
|
||||
- R0 → RO (fix OCR error where O is read as 0)
|
||||
- Keep RO prefix if original had it (platitor TVA)
|
||||
- Do NOT add RO if original didn't have it (neplatitor TVA)
|
||||
- Try to repair 1-digit checksum errors (OCR mistakes like 5↔8)
|
||||
|
||||
Examples:
|
||||
10562600 → RO10562600
|
||||
45417955 → 45417955 (no prefix = neplatitor TVA, keep as-is)
|
||||
R010562600 → RO10562600 (fix R0 OCR error)
|
||||
RO10562600 → RO10562600 (unchanged)
|
||||
RO10862600 → RO10562600 (repaired: 8→5 at position 2)
|
||||
|
||||
Args:
|
||||
cui: Raw CUI string from OCR
|
||||
|
||||
Returns:
|
||||
Normalized CUI with RO prefix, or None if invalid
|
||||
Normalized CUI, or None if invalid
|
||||
"""
|
||||
if not cui:
|
||||
return None
|
||||
|
||||
cui = cui.strip().upper()
|
||||
|
||||
# Remove existing prefix if present
|
||||
# Check if original had RO/R0 prefix
|
||||
had_ro_prefix = cui.startswith("RO") or cui.startswith("R0")
|
||||
|
||||
# Extract digits
|
||||
if cui.startswith("RO"):
|
||||
cui = cui[2:]
|
||||
elif cui.startswith("R0"):
|
||||
cui = cui[2:]
|
||||
cui_digits = cui[2:]
|
||||
elif cui.startswith("R0"): # Fix OCR error R0 → RO
|
||||
cui_digits = cui[2:]
|
||||
else:
|
||||
cui_digits = cui
|
||||
|
||||
# Remove any non-digit characters
|
||||
cui_digits = ''.join(c for c in cui if c.isdigit())
|
||||
cui_digits = ''.join(c for c in cui_digits if c.isdigit())
|
||||
|
||||
# Validate length
|
||||
if len(cui_digits) < 6 or len(cui_digits) > 10:
|
||||
print(f"[CUI Normalize] Invalid length: {len(cui_digits)} digits (expected 6-10)", flush=True)
|
||||
return None
|
||||
|
||||
# Add RO prefix
|
||||
return f"RO{cui_digits}"
|
||||
# Try to repair checksum if invalid
|
||||
if not CUIChecksumRule.validate_checksum(cui_digits):
|
||||
repaired = OCRValidationEngine._repair_cui_checksum(cui_digits)
|
||||
if repaired:
|
||||
cui_digits = repaired
|
||||
|
||||
# Return with RO prefix only if original had it
|
||||
if had_ro_prefix:
|
||||
return f"RO{cui_digits}"
|
||||
else:
|
||||
return cui_digits
|
||||
|
||||
@staticmethod
|
||||
async def fuzzy_match_cui_from_db(
|
||||
cui: Optional[str],
|
||||
db_session
|
||||
) -> Optional[tuple[str, str]]:
|
||||
"""Fuzzy match CUI against database of known suppliers.
|
||||
|
||||
This function:
|
||||
1. Validates CUI checksum
|
||||
2. If valid, looks up in database (exact match)
|
||||
3. If invalid, tries 1-digit corrections and looks up each candidate
|
||||
4. Returns the first match found in database
|
||||
|
||||
Args:
|
||||
cui: Extracted CUI from OCR (may be invalid)
|
||||
db_session: SQLAlchemy async session for database lookups
|
||||
|
||||
Returns:
|
||||
Tuple of (corrected_cui, supplier_name) if found, else None
|
||||
|
||||
Usage in OCR extraction:
|
||||
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
|
||||
match = await OCRValidationEngine.fuzzy_match_cui_from_db(extracted_cui, session)
|
||||
if match:
|
||||
corrected_cui, supplier_name = match
|
||||
"""
|
||||
from sqlalchemy import select, or_
|
||||
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier
|
||||
|
||||
if not cui:
|
||||
return None
|
||||
|
||||
cui = cui.strip().upper()
|
||||
|
||||
# Check if original had RO/R0 prefix
|
||||
had_ro_prefix = cui.startswith("RO") or cui.startswith("R0")
|
||||
|
||||
# Extract digits
|
||||
if cui.startswith("RO"):
|
||||
cui_digits = cui[2:]
|
||||
elif cui.startswith("R0"): # Fix OCR error R0 → RO
|
||||
cui_digits = cui[2:]
|
||||
else:
|
||||
cui_digits = cui
|
||||
|
||||
# Remove any non-digit characters
|
||||
cui_digits = ''.join(c for c in cui_digits if c.isdigit())
|
||||
|
||||
# Validate length
|
||||
if len(cui_digits) < 6 or len(cui_digits) > 10:
|
||||
return None
|
||||
|
||||
# Helper to format CUI with optional RO prefix
|
||||
def format_cui(digits: str) -> str:
|
||||
if had_ro_prefix:
|
||||
return f"RO{digits}"
|
||||
return digits
|
||||
|
||||
# Helper to search database for CUI
|
||||
async def lookup_cui_in_db(digits: str) -> Optional[tuple[str, str]]:
|
||||
"""Search both synced and local suppliers for CUI."""
|
||||
# Search patterns: with and without RO prefix
|
||||
search_patterns = [digits, f"RO{digits}"]
|
||||
|
||||
# Search synced_suppliers first (more data)
|
||||
stmt = select(SyncedSupplier.fiscal_code, SyncedSupplier.name).where(
|
||||
or_(
|
||||
SyncedSupplier.fiscal_code == digits,
|
||||
SyncedSupplier.fiscal_code == f"RO{digits}",
|
||||
SyncedSupplier.fiscal_code == digits.lstrip('0'), # Handle leading zeros
|
||||
)
|
||||
).limit(1)
|
||||
result = await db_session.execute(stmt)
|
||||
row = result.first()
|
||||
if row:
|
||||
return (format_cui(digits), row.name)
|
||||
|
||||
# Search local_suppliers
|
||||
stmt = select(LocalSupplier.fiscal_code, LocalSupplier.name).where(
|
||||
or_(
|
||||
LocalSupplier.fiscal_code == digits,
|
||||
LocalSupplier.fiscal_code == f"RO{digits}",
|
||||
LocalSupplier.fiscal_code == digits.lstrip('0'),
|
||||
)
|
||||
).limit(1)
|
||||
result = await db_session.execute(stmt)
|
||||
row = result.first()
|
||||
if row:
|
||||
return (format_cui(digits), row.name)
|
||||
|
||||
return None
|
||||
|
||||
# 1. If checksum is valid, check if it exists in database (exact match)
|
||||
if CUIChecksumRule.validate_checksum(cui_digits):
|
||||
match = await lookup_cui_in_db(cui_digits)
|
||||
if match:
|
||||
print(f"[Fuzzy CUI] Exact match found: {cui} → {match[0]} ({match[1]})", flush=True)
|
||||
return match
|
||||
# Valid checksum but not in DB - return as-is (it might be a new supplier)
|
||||
return None
|
||||
|
||||
# 2. Invalid checksum - try 1-digit corrections and verify against database
|
||||
print(f"[Fuzzy CUI] Invalid checksum for {cui}, trying corrections...", flush=True)
|
||||
|
||||
# Common OCR digit confusions (try these first)
|
||||
confusion_pairs = {
|
||||
'5': ['8', '6'], # 5 often misread as 8 or 6
|
||||
'8': ['5', '3', '0'], # 8 often misread as 5, 3, or 0
|
||||
'6': ['0', '8'], # 6 often misread as 0 or 8
|
||||
'0': ['6', '8'], # 0 often misread as 6 or 8
|
||||
'1': ['7', '4'], # 1 often misread as 7 or 4
|
||||
'7': ['1'], # 7 often misread as 1
|
||||
'3': ['8'], # 3 often misread as 8
|
||||
'4': ['1'], # 4 often misread as 1
|
||||
'2': ['7'], # 2 sometimes misread as 7
|
||||
'9': ['0'], # 9 sometimes misread as 0
|
||||
}
|
||||
|
||||
n = len(cui_digits)
|
||||
last_pos = n - 1 # checksum position
|
||||
|
||||
# Position check order: middle positions first, then ends
|
||||
middle_positions = list(range(2, last_pos))
|
||||
position_order = middle_positions + [1, last_pos, 0]
|
||||
|
||||
for pos in position_order:
|
||||
if pos >= n:
|
||||
continue
|
||||
|
||||
original_digit = cui_digits[pos]
|
||||
|
||||
# Try common confusions first for this digit
|
||||
candidates = confusion_pairs.get(original_digit, [])
|
||||
# Then try all other digits
|
||||
all_digits = [d for d in '0123456789' if d != original_digit and d not in candidates]
|
||||
|
||||
for replacement in candidates + all_digits:
|
||||
candidate = cui_digits[:pos] + replacement + cui_digits[pos+1:]
|
||||
|
||||
# Only consider if checksum is valid
|
||||
if not CUIChecksumRule.validate_checksum(candidate):
|
||||
continue
|
||||
|
||||
# Check if this corrected CUI exists in database
|
||||
match = await lookup_cui_in_db(candidate)
|
||||
if match:
|
||||
print(f"[Fuzzy CUI] DB match: {cui} → {match[0]} ({match[1]}) [pos {pos}: {original_digit}→{replacement}]", flush=True)
|
||||
return match
|
||||
|
||||
# No match found in database
|
||||
print(f"[Fuzzy CUI] No database match found for {cui}", flush=True)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def fuzzy_match_by_name_and_cui(
|
||||
vendor_name: Optional[str],
|
||||
cui: Optional[str],
|
||||
db_session
|
||||
) -> Optional[tuple[str, str]]:
|
||||
"""Fuzzy match supplier by NAME, then narrow down by CUI.
|
||||
|
||||
Algorithm:
|
||||
1. Normalize vendor name (remove S.R.L., S.A., punctuation, etc.)
|
||||
2. Search suppliers by fuzzy name match (LIKE %name%)
|
||||
3. If multiple results, use fuzzy CUI matching to pick best one
|
||||
4. Return the best match
|
||||
|
||||
Args:
|
||||
vendor_name: Extracted vendor name from OCR
|
||||
cui: Extracted CUI from OCR (may be invalid/incomplete)
|
||||
db_session: SQLAlchemy async session
|
||||
|
||||
Returns:
|
||||
Tuple of (matched_cui, supplier_name) if found, else None
|
||||
"""
|
||||
from sqlalchemy import select, or_, func
|
||||
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier
|
||||
import re
|
||||
|
||||
if not vendor_name or len(vendor_name) < 3:
|
||||
return None
|
||||
|
||||
# Normalize vendor name for search
|
||||
def normalize_name(name: str) -> str:
|
||||
"""Normalize name for fuzzy matching."""
|
||||
name = name.upper()
|
||||
# Remove company type suffixes
|
||||
for suffix in ['S.R.L.', 'SRL', 'S.A.', 'SA', 'S.C.', 'SC', 'I.F.', 'IF', 'P.F.A.', 'PFA']:
|
||||
name = name.replace(suffix, '')
|
||||
# Remove punctuation and extra spaces
|
||||
name = re.sub(r'[.,\-_/\\()"\']', ' ', name)
|
||||
name = ' '.join(name.split())
|
||||
return name.strip()
|
||||
|
||||
# Extract key words from vendor name (for fuzzy search)
|
||||
normalized_name = normalize_name(vendor_name)
|
||||
name_words = [w for w in normalized_name.split() if len(w) >= 3]
|
||||
|
||||
if not name_words:
|
||||
return None
|
||||
|
||||
print(f"[Fuzzy Name] Searching for vendor: '{vendor_name}' → keywords: {name_words}", flush=True)
|
||||
|
||||
# Build search pattern - use first significant word
|
||||
primary_word = name_words[0]
|
||||
search_pattern = f"%{primary_word}%"
|
||||
|
||||
candidates = []
|
||||
|
||||
# Search synced_suppliers
|
||||
stmt = select(SyncedSupplier.fiscal_code, SyncedSupplier.name).where(
|
||||
func.upper(SyncedSupplier.name).like(search_pattern)
|
||||
).limit(20)
|
||||
result = await db_session.execute(stmt)
|
||||
for row in result:
|
||||
if row.fiscal_code:
|
||||
candidates.append((row.fiscal_code, row.name))
|
||||
|
||||
# Search local_suppliers
|
||||
stmt = select(LocalSupplier.fiscal_code, LocalSupplier.name).where(
|
||||
func.upper(LocalSupplier.name).like(search_pattern)
|
||||
).limit(20)
|
||||
result = await db_session.execute(stmt)
|
||||
for row in result:
|
||||
if row.fiscal_code:
|
||||
candidates.append((row.fiscal_code, row.name))
|
||||
|
||||
if not candidates:
|
||||
print(f"[Fuzzy Name] No name matches found for '{primary_word}'", flush=True)
|
||||
return None
|
||||
|
||||
print(f"[Fuzzy Name] Found {len(candidates)} name matches for '{primary_word}'", flush=True)
|
||||
|
||||
# If only one candidate, return it
|
||||
if len(candidates) == 1:
|
||||
print(f"[Fuzzy Name] Single match: {candidates[0][0]} ({candidates[0][1]})", flush=True)
|
||||
return candidates[0]
|
||||
|
||||
# Multiple candidates - try to narrow down by CUI
|
||||
if cui:
|
||||
cui_digits = ''.join(c for c in cui.upper().replace('RO', '').replace('R0', '') if c.isdigit())
|
||||
|
||||
if len(cui_digits) >= 6:
|
||||
# Score each candidate by how similar their CUI is to the extracted one
|
||||
def cui_similarity(candidate_cui: str) -> int:
|
||||
"""Calculate how many digits match in the same position."""
|
||||
cand_digits = ''.join(c for c in candidate_cui.upper().replace('RO', '') if c.isdigit())
|
||||
if len(cand_digits) != len(cui_digits):
|
||||
return 0
|
||||
return sum(1 for a, b in zip(cand_digits, cui_digits) if a == b)
|
||||
|
||||
# Sort candidates by CUI similarity (descending)
|
||||
scored = [(cui_similarity(c[0]), c) for c in candidates]
|
||||
scored.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
best_score, best_match = scored[0]
|
||||
# Require at least 70% digit match for CUI similarity
|
||||
min_matching = int(len(cui_digits) * 0.7)
|
||||
|
||||
if best_score >= min_matching:
|
||||
print(f"[Fuzzy Name] Best CUI match: {best_match[0]} ({best_match[1]}) - score {best_score}/{len(cui_digits)}", flush=True)
|
||||
return best_match
|
||||
|
||||
print(f"[Fuzzy Name] No strong CUI match (best score: {best_score}/{len(cui_digits)})", flush=True)
|
||||
|
||||
# If still multiple and no CUI match, try name similarity
|
||||
def name_similarity(candidate_name: str) -> int:
|
||||
"""Count how many keywords match."""
|
||||
norm_cand = normalize_name(candidate_name)
|
||||
return sum(1 for w in name_words if w in norm_cand)
|
||||
|
||||
scored = [(name_similarity(c[1]), c) for c in candidates]
|
||||
scored.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
if scored[0][0] >= 2: # At least 2 keywords match
|
||||
best_match = scored[0][1]
|
||||
print(f"[Fuzzy Name] Best name match: {best_match[0]} ({best_match[1]})", flush=True)
|
||||
return best_match
|
||||
|
||||
# Return first candidate if nothing else works
|
||||
print(f"[Fuzzy Name] Returning first candidate: {candidates[0][0]} ({candidates[0][1]})", flush=True)
|
||||
return candidates[0]
|
||||
|
||||
@staticmethod
|
||||
async def fuzzy_match_supplier(
|
||||
cui: Optional[str],
|
||||
vendor_name: Optional[str],
|
||||
db_session
|
||||
) -> Optional[tuple[str, str]]:
|
||||
"""Combined fuzzy matching: try CUI first, then fallback to NAME+CUI.
|
||||
|
||||
Strategy:
|
||||
1. Try fuzzy CUI matching (1-digit corrections with checksum validation)
|
||||
2. If no CUI match, try fuzzy NAME matching, narrowed by CUI similarity
|
||||
|
||||
Args:
|
||||
cui: Extracted CUI from OCR (may be invalid/incomplete)
|
||||
vendor_name: Extracted vendor name from OCR
|
||||
db_session: SQLAlchemy async session
|
||||
|
||||
Returns:
|
||||
Tuple of (matched_cui, supplier_name) if found, else None
|
||||
"""
|
||||
# Step 1: Try fuzzy CUI matching
|
||||
cui_match = await OCRValidationEngine.fuzzy_match_cui_from_db(cui, db_session)
|
||||
if cui_match:
|
||||
return cui_match
|
||||
|
||||
# Step 2: Fallback to fuzzy NAME + CUI matching
|
||||
name_match = await OCRValidationEngine.fuzzy_match_by_name_and_cui(
|
||||
vendor_name, cui, db_session
|
||||
)
|
||||
if name_match:
|
||||
return name_match
|
||||
|
||||
return None
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""OCR engine wrapper for PaddleOCR and Tesseract."""
|
||||
"""OCR engine wrapper for PaddleOCR, docTR, and Tesseract."""
|
||||
|
||||
import os
|
||||
import logging
|
||||
@@ -9,9 +9,8 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Setup logging
|
||||
# Setup logging (respects LOG_LEVEL env var set in main.py)
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO) # Ensure logs are visible
|
||||
|
||||
# Disable PaddleOCR model source check for faster startup (PaddleX 3.x)
|
||||
os.environ['PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK'] = 'True'
|
||||
@@ -19,6 +18,7 @@ os.environ['PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK'] = 'True'
|
||||
# Lazy imports - these will be imported on first use
|
||||
PaddleOCR = None # Will be imported lazily
|
||||
pytesseract = None # Will be imported lazily
|
||||
doctr_ocr_predictor = None # Will be imported lazily
|
||||
|
||||
# Check availability without importing heavy libraries
|
||||
def _check_paddle_available() -> bool:
|
||||
@@ -37,8 +37,17 @@ def _check_tesseract_available() -> bool:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _check_doctr_available() -> bool:
|
||||
"""Check if doctr is installed without importing it."""
|
||||
try:
|
||||
import importlib.util
|
||||
return importlib.util.find_spec("doctr") is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
PADDLE_AVAILABLE = _check_paddle_available()
|
||||
TESSERACT_AVAILABLE = _check_tesseract_available()
|
||||
DOCTR_AVAILABLE = _check_doctr_available()
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -59,6 +68,11 @@ class OCREngine:
|
||||
self._paddle_ready = threading.Event() # Signals when PaddleOCR is FULLY ready
|
||||
self._paddle_init_lock = threading.Lock()
|
||||
|
||||
self._doctr = None
|
||||
self._doctr_init_started = False
|
||||
self._doctr_ready = threading.Event() # Signals when docTR is FULLY ready
|
||||
self._doctr_init_lock = threading.Lock()
|
||||
|
||||
def _init_paddle_lazy(self):
|
||||
"""Lazy initialize PaddleOCR on first use (avoids slow startup)."""
|
||||
global PaddleOCR
|
||||
@@ -94,6 +108,78 @@ class OCREngine:
|
||||
# Signal that initialization is complete (success or failure)
|
||||
self._paddle_ready.set()
|
||||
|
||||
def _init_doctr_lazy(self):
|
||||
"""Lazy initialize docTR on first use (avoids slow startup)."""
|
||||
global doctr_ocr_predictor
|
||||
|
||||
with self._doctr_init_lock:
|
||||
if self._doctr_init_started:
|
||||
return # Already initializing or done
|
||||
self._doctr_init_started = True
|
||||
|
||||
if DOCTR_AVAILABLE:
|
||||
try:
|
||||
print("Importing docTR (first use, may take ~10-15 seconds)...", flush=True)
|
||||
from doctr.io import DocumentFile
|
||||
from doctr.models import ocr_predictor
|
||||
|
||||
print("Initializing docTR engine (PyTorch backend)...", flush=True)
|
||||
# Initialize docTR predictor with pretrained models
|
||||
# Uses db_resnet50 for detection and crnn_vgg16_bn for recognition
|
||||
self._doctr = ocr_predictor(
|
||||
det_arch='db_resnet50',
|
||||
reco_arch='crnn_vgg16_bn',
|
||||
pretrained=True,
|
||||
assume_straight_pages=True,
|
||||
straighten_pages=False,
|
||||
preserve_aspect_ratio=True,
|
||||
)
|
||||
doctr_ocr_predictor = self._doctr
|
||||
print("docTR initialized successfully with PyTorch backend", flush=True)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to initialize docTR: {e}", flush=True)
|
||||
self._doctr = None
|
||||
|
||||
# Signal that initialization is complete (success or failure)
|
||||
self._doctr_ready.set()
|
||||
|
||||
def wait_for_doctr(self, timeout: float = 30.0) -> bool:
|
||||
"""
|
||||
Wait for docTR to be fully initialized.
|
||||
|
||||
Args:
|
||||
timeout: Max seconds to wait (default 30s)
|
||||
|
||||
Returns:
|
||||
True if docTR is ready, False if timeout or unavailable
|
||||
"""
|
||||
if not DOCTR_AVAILABLE:
|
||||
return False
|
||||
|
||||
if self._doctr is not None:
|
||||
return True # Already ready
|
||||
|
||||
if not self._doctr_init_started:
|
||||
# Start initialization if not already started
|
||||
self._init_doctr_lazy()
|
||||
|
||||
# Wait for initialization to complete
|
||||
print(f"[OCR] Waiting for docTR to be ready (max {timeout}s)...", flush=True)
|
||||
start = time.time()
|
||||
ready = self._doctr_ready.wait(timeout=timeout)
|
||||
elapsed = time.time() - start
|
||||
|
||||
if ready and self._doctr is not None:
|
||||
print(f"[OCR] docTR ready after {elapsed:.1f}s", flush=True)
|
||||
return True
|
||||
else:
|
||||
print(f"[OCR] docTR not ready after {elapsed:.1f}s (timeout or failed)", flush=True)
|
||||
return False
|
||||
|
||||
def is_doctr_ready(self) -> bool:
|
||||
"""Check if docTR is ready without waiting."""
|
||||
return self._doctr is not None
|
||||
|
||||
def wait_for_paddle(self, timeout: float = 30.0) -> bool:
|
||||
"""
|
||||
Wait for PaddleOCR to be fully initialized.
|
||||
@@ -239,6 +325,84 @@ class OCREngine:
|
||||
logger.info(f"[Tesseract] Done: {len(text)} chars, conf: {avg_conf:.2%}")
|
||||
return OCRResult(text=text, confidence=avg_conf, boxes=[], engine="tesseract")
|
||||
|
||||
def _doctr_recognize(self, image: np.ndarray) -> OCRResult:
|
||||
"""Recognize text using docTR."""
|
||||
# Wait for docTR to be fully ready
|
||||
if not self.wait_for_doctr(timeout=30.0):
|
||||
logger.warning("[docTR] Not ready, falling back to Tesseract")
|
||||
if TESSERACT_AVAILABLE:
|
||||
return self._tesseract_recognize(image)
|
||||
raise RuntimeError("docTR not ready and Tesseract not available")
|
||||
|
||||
try:
|
||||
logger.info(f"[docTR] Processing image, shape: {image.shape}")
|
||||
|
||||
# docTR requires RGB images
|
||||
import cv2
|
||||
if len(image.shape) == 2:
|
||||
# Convert grayscale to RGB
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
||||
logger.info(f"[docTR] Converted grayscale to RGB, new shape: {image.shape}")
|
||||
elif image.shape[2] == 4:
|
||||
# Convert RGBA to RGB
|
||||
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
|
||||
logger.info(f"[docTR] Converted RGBA to RGB, new shape: {image.shape}")
|
||||
elif image.shape[2] == 3:
|
||||
# Check if BGR (from OpenCV) and convert to RGB
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
logger.info(f"[docTR] Converted BGR to RGB, shape: {image.shape}")
|
||||
|
||||
# Process image with docTR
|
||||
logger.info("[docTR] Running prediction...")
|
||||
from doctr.io import DocumentFile
|
||||
|
||||
# docTR expects a document (list of pages as numpy arrays)
|
||||
result = self._doctr([image])
|
||||
|
||||
if not result or not result.pages:
|
||||
logger.warning("[docTR] No results returned")
|
||||
return OCRResult(text="", confidence=0.0, boxes=[], engine="doctr")
|
||||
|
||||
# Extract text from all pages
|
||||
all_texts = []
|
||||
all_confidences = []
|
||||
boxes = []
|
||||
|
||||
for page in result.pages:
|
||||
for block in page.blocks:
|
||||
for line in block.lines:
|
||||
line_text = ' '.join(word.value for word in line.words)
|
||||
line_confidence = sum(w.confidence for w in line.words) / len(line.words) if line.words else 0.0
|
||||
all_texts.append(line_text)
|
||||
all_confidences.append(line_confidence)
|
||||
|
||||
# Store word-level boxes
|
||||
for word in line.words:
|
||||
boxes.append({
|
||||
'text': word.value,
|
||||
'confidence': float(word.confidence),
|
||||
'box': word.geometry # (xmin, ymin), (xmax, ymax)
|
||||
})
|
||||
|
||||
text_result = '\n'.join(all_texts)
|
||||
avg_conf = sum(all_confidences) / len(all_confidences) if all_confidences else 0.0
|
||||
|
||||
logger.info(f"[docTR] SUCCESS - Found {len(all_texts)} text lines, avg confidence: {avg_conf:.2%}")
|
||||
logger.debug(f"[docTR] Raw text preview: {text_result[:200]}...")
|
||||
|
||||
return OCRResult(
|
||||
text=text_result,
|
||||
confidence=float(avg_conf),
|
||||
boxes=boxes,
|
||||
engine="doctr"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[docTR] ERROR: {e}, falling back to Tesseract")
|
||||
if TESSERACT_AVAILABLE:
|
||||
return self._tesseract_recognize(image)
|
||||
raise
|
||||
|
||||
def recognize_dual(self, image: np.ndarray) -> Tuple[OCRResult, Optional[OCRResult]]:
|
||||
"""
|
||||
Run both OCR engines and return both results.
|
||||
@@ -286,10 +450,27 @@ class OCREngine:
|
||||
|
||||
@staticmethod
|
||||
def get_available_engines() -> List[str]:
|
||||
"""Return list of available OCR engines."""
|
||||
"""
|
||||
Return list of available OCR engines.
|
||||
|
||||
Respects OCR_ENABLE_PADDLEOCR and OCR_ENABLE_TESSERACT from .env.
|
||||
Engines that are disabled via .env are not returned even if installed.
|
||||
|
||||
Available engines: tesseract, doctr, doctr_plus, paddleocr
|
||||
"""
|
||||
# Check .env settings
|
||||
paddle_enabled = os.getenv("OCR_ENABLE_PADDLEOCR", "true").lower() == "true"
|
||||
tesseract_enabled = os.getenv("OCR_ENABLE_TESSERACT", "true").lower() == "true"
|
||||
|
||||
engines = []
|
||||
if PADDLE_AVAILABLE:
|
||||
engines.append('paddleocr')
|
||||
if TESSERACT_AVAILABLE:
|
||||
|
||||
# Base engines (only if installed AND enabled)
|
||||
if TESSERACT_AVAILABLE and tesseract_enabled:
|
||||
engines.append('tesseract')
|
||||
if DOCTR_AVAILABLE:
|
||||
engines.append('doctr')
|
||||
engines.append('doctr_plus') # docTR with 2-tier sequential + early exit
|
||||
if PADDLE_AVAILABLE and paddle_enabled:
|
||||
engines.append('paddleocr')
|
||||
|
||||
return engines
|
||||
|
||||
@@ -6,6 +6,8 @@ from decimal import Decimal, InvalidOperation
|
||||
from typing import Optional, Tuple, List
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractionResult:
|
||||
@@ -24,6 +26,7 @@ class ExtractionResult:
|
||||
address: Optional[str] = None
|
||||
items_count: Optional[int] = None
|
||||
payment_methods: List[dict] = field(default_factory=list) # [{"method":"CARD","amount":Decimal}]
|
||||
suggested_payment_mode: Optional[str] = None # 'banca' if CARD detected, 'numerar' if cash only
|
||||
|
||||
# Client data (for B2B receipts - buyer information)
|
||||
client_name: Optional[str] = None
|
||||
@@ -125,8 +128,10 @@ class ReceiptExtractor:
|
||||
(r'C3POS[-A-Z0-9]*[N:](\d{6,7})', 0.98), # CT2N1360760 format
|
||||
(r'C3POS.*?(\d{6,7})\b', 0.95), # Any C3POS followed by 6-7 digit number
|
||||
(r'CT2[N:]\s*(\d{6,})', 0.95), # CT2N prefix
|
||||
# BF (Bon Fiscal) number
|
||||
(r'BF\s*:?\s*(\d+)', 0.93),
|
||||
# BF (Bon Fiscal) number - high priority
|
||||
# Format: "Z:0864 BF:0018" - extract only the number after BF:
|
||||
(r'BF\s*:\s*(\d{4,})', 0.96), # BF: with colon (most specific)
|
||||
(r'BF\s+(\d{4,})', 0.93), # BF followed by space and number
|
||||
# NIVS format
|
||||
(r'NIVS\s*:?\s*(\d+)', 0.95),
|
||||
# Standard NR BON formats
|
||||
@@ -151,28 +156,45 @@ class ReceiptExtractor:
|
||||
# OCR errors: R0 instead of RO, C1F instead of CIF
|
||||
CUI_PATTERNS = [
|
||||
# CIF at start of line (definitely vendor) - tolerant to OCR errors
|
||||
(r'^CIF\s*:?\s*(?:R[O0])?(\d{6,10})', 0.98),
|
||||
(r'^C[I1]F\s*:?\s*(?:R[O0])?(\d{6,10})', 0.95), # C1F OCR error
|
||||
# NOTE: Capture full CUI including RO prefix: (R[O0]?\d{6,10}) or ((?:R[O0])?\d{6,10})
|
||||
(r'^CIF\s*:?\s*(R[O0]?\d{6,10})', 0.98),
|
||||
(r'^CIF\s*:?\s*(\d{6,10})', 0.97), # Without RO prefix
|
||||
(r'^C[I1]F\s*:?\s*(R[O0]?\d{6,10})', 0.95), # C1F OCR error
|
||||
(r'^C[I1]F\s*:?\s*(\d{6,10})', 0.94), # C1F without RO
|
||||
# CIF not preceded by CLIENT (negative lookbehind)
|
||||
(r'(?<!CLIENT\s)(?<!LIENT\s)CIF\s*:?\s*(?:R[O0])?(\d{6,10})', 0.95),
|
||||
(r'(?<!CLIENT\s)(?<!LIENT\s)CIF\s*:?\s*(R[O0]?\d{6,10})', 0.95),
|
||||
(r'(?<!CLIENT\s)(?<!LIENT\s)CIF\s*:?\s*(\d{6,10})', 0.94),
|
||||
# Standalone CIF: format with OCR tolerance
|
||||
(r'\bC[I1]F\s*:?\s*(?:R[O0])?(\d{6,10})\b', 0.90),
|
||||
(r'\bC[I1]F\s*:?\s*(R[O0]?\d{6,10})\b', 0.90),
|
||||
(r'\bC[I1]F\s*:?\s*(\d{6,10})\b', 0.89),
|
||||
# COD FISCAL (vendor)
|
||||
(r'COD\s+FISCAL\s*:?\s*(?:R[O0])?(\d{6,10})', 0.90),
|
||||
(r'COD\s+FISCAL\s*:?\s*(R[O0]?\d{6,10})', 0.90),
|
||||
(r'COD\s+FISCAL\s*:?\s*(\d{6,10})', 0.89),
|
||||
# C. I. F. format with SPACES (OCR artifact) - "C. I. F. : R011201891"
|
||||
(r'C\.\s*I\.\s*F\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.92),
|
||||
# Also handles double colon from OMV/Petrom: "C. I.F.: : RO11201891"
|
||||
(r'C\.\s*I\.\s*F\.?\s*[:\s]+(R[O0]?\d{6,10})', 0.92),
|
||||
(r'C\.\s*I\.\s*F\.?\s*[:\s]+(\d{6,10})', 0.91),
|
||||
# C.I.F. format (with dots, no spaces)
|
||||
(r'(?<!CLIENT\s)C\.[I1]\.F\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.88),
|
||||
(r'(?<!CLIENT\s)C\.[I1]\.F\.?\s*:?\s*(R[O0]?\d{6,10})', 0.88),
|
||||
(r'(?<!CLIENT\s)C\.[I1]\.F\.?\s*:?\s*(\d{6,10})', 0.87),
|
||||
# CUI format (less specific, use with caution)
|
||||
(r'(?<!CLIENT\s)C\.?U\.?[I1]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.85),
|
||||
(r'(?<!CLIENT\s)C\.?U\.?[I1]\.?\s*:?\s*(R[O0]?\d{6,10})', 0.85),
|
||||
(r'(?<!CLIENT\s)C\.?U\.?[I1]\.?\s*:?\s*(\d{6,10})', 0.84),
|
||||
# Lidl format: "Cod Identificare fiscala: RO..." (OCR corrupts to "Ced Identificanfliscalar")
|
||||
# Matches: "Identificare fiscala", "Identificanfliscalar", "Identificoan/Fljscales"
|
||||
(r'[IC](?:od|ed)\s*Identific[a-z/]*\s*(R[O0]\d{6,10})', 0.90),
|
||||
# Generic: anything with "fiscal" followed by RO + digits
|
||||
(r'fiscal[a-z]*\s*:?\s*(R[O0]\d{6,10})', 0.85),
|
||||
]
|
||||
|
||||
# Pattern for CIF NUMBER appearing BEFORE "C.I.F." label (reversed format)
|
||||
# Common in some receipts: "R011201891\nC. I. F." - number on line before label
|
||||
# Common in some receipts: "RO11201891\nC. I. F." - number on line before label
|
||||
# IMPORTANT: Capture the full CUI including RO prefix
|
||||
CUI_REVERSED_PATTERNS = [
|
||||
# RO + 8-10 digits on line immediately before C.I.F./CIF label
|
||||
(r'(?:R[O0])(\d{6,10})\s*\n\s*C\.?\s*I\.?\s*F\.?', 0.98),
|
||||
# Just digits before C.I.F. label
|
||||
# RO/R0 + 6-10 digits on line immediately before C.I.F./CIF label
|
||||
# Capture the FULL CUI including RO prefix
|
||||
(r'(R[O0]\d{6,10})\s*\n\s*C\.?\s*I\.?\s*F\.?', 0.98),
|
||||
# Just digits before C.I.F. label (neplatitor TVA - no RO prefix)
|
||||
(r'(\d{6,10})\s*\n\s*C\.?\s*I\.?\s*F\.?', 0.95),
|
||||
]
|
||||
|
||||
@@ -185,38 +207,67 @@ class ReceiptExtractor:
|
||||
(r'(?:^|\s)BF\s*:\s*(\d{4})', 0.85),
|
||||
]
|
||||
|
||||
# TVA (VAT) patterns - OCR may produce TUA, TVR, etc.
|
||||
# TVA (VAT) patterns - OCR may produce TUA, TVR, IVA, etc.
|
||||
# All patterns are case-insensitive (re.IGNORECASE applied in extraction)
|
||||
TVA_PATTERNS = [
|
||||
# TOTAL TVA BON format (OCR tolerant: TUA, TVR)
|
||||
(r'TOTAL\s+T[VU][AR]\s+BON\s*:?\s*([\d\s.,]+)', 0.98),
|
||||
(r'T[O0]TAL\s+T[VU][AR]\s*:?\s*([\d\s.,]+)', 0.95),
|
||||
# TOTAL TVA BON format (OCR tolerant: TUA, TVR, IVA)
|
||||
(r'TOTAL\s+(?:T[VU][AR]|IVA)\s+BON\s*:?\s*([\d\s.,]+)', 0.98),
|
||||
(r'T[O0]TAL\s+(?:T[VU][AR]|IVA)\s*:?\s*([\d\s.,]+)', 0.95),
|
||||
# IVA variant (Spanish/Portuguese influence, some receipts)
|
||||
(r'TOTAL\s+IVA\s*:?\s*([\d\s.,]+)', 0.95),
|
||||
(r'IVA\s+[A-D]?\s*[-:]?\s*(\d{1,2})\s*%\s*:?\s*([\d\s.,]+)', 0.93),
|
||||
# TVA with percentage (OCR tolerant)
|
||||
(r'T[VU][AR]\s+(?:A\s*[-:]?\s*)?(\d{1,2})\s*%\s*:?\s*([\d\s.,]+)', 0.95),
|
||||
(r'T[VU][AR]\s+[A-Z]\s*[-:]\s*(\d{1,2})\s*%\s*([\d\s.,]+)', 0.93),
|
||||
# Simple TVA pattern
|
||||
(r'T[VU][AR]\s*:?\s*([\d\s.,]+)', 0.85),
|
||||
# 5% TVA rate (books, newspapers - TVA C)
|
||||
(r'T[VU][AR]\s*[C5]\s*[-:]\s*5\s*%\s*:?\s*([\d\s.,]+)', 0.93),
|
||||
(r'(?:T[VU][AR]|IVA)\s+5\s*%\s*:?\s*([\d\s.,]+)', 0.92),
|
||||
# Garbled OCR: T0TAL, TVAI, TUAI, etc.
|
||||
(r'T[O0]T[AE]L\s+(?:T[VUAI]+[AR]?|IVA)\s*:?\s*([\d\s.,]+)', 0.88),
|
||||
# OCR corruption: "TA F 194" (TVA with V→F or space), "T A 19%"
|
||||
# Handles: "TOTAL TA F 194" where TVA became "TA F"
|
||||
(r'TOTAL\s+TA\s*[F\s]?\s*\d*\s*:?\s*([\d\s.,]+)', 0.85),
|
||||
(r'TA\s+[FA-Z]?\s*\d{1,2}\s*%?\s*:?\s*([\d\s.,]+)', 0.82),
|
||||
# "TUA" with random letter after (OCR noise): "TUA F", "TUA I"
|
||||
(r'T[VU]A\s+[A-Z]?\s*\d*\s*:?\s*([\d\s.,]+)', 0.83),
|
||||
# Simple TVA/IVA pattern
|
||||
(r'(?:T[VU][AR]|IVA)\s*:?\s*([\d\s.,]+)', 0.85),
|
||||
# Standalone percentage line near TVA
|
||||
(r'(\d{1,2})\s*%\s*:?\s*([\d\s.,]+)', 0.75),
|
||||
]
|
||||
|
||||
# Payment method patterns - appears after TOTAL LEI, before TOTAL TVA
|
||||
# Format: "CARD: 50.00" or "NUMERAR 100.00" or "PLATA CARD: 50.00"
|
||||
# OMV/Petrom uses "CARTE CREDIT" or "CARTE CREDIT 318, 16"
|
||||
PAYMENT_METHOD_PATTERNS = [
|
||||
# CARTE CREDIT with amount on same line (OMV/Petrom receipts)
|
||||
# Handles: "CARTE CREDIT 318, 16" with OCR spaces in number
|
||||
(r'CARTE\s+CREDIT\s*:?\s*([\d\s.,]+)', 'CARD', 0.98),
|
||||
# CARTE CREDIT with amount on next line (OCR may split lines)
|
||||
# Handles: "CARTE CREDIT\n318, 16"
|
||||
(r'CARTE\s+CREDIT\s*:?\s*\n\s*([\d\s.,]+)', 'CARD', 0.97),
|
||||
# CARD with amount (high confidence)
|
||||
(r'(?:PLATA\s+)?CARD\s*:?\s*([\d\s.,]+)', 'CARD', 0.95),
|
||||
# Also handles OCR artifacts like "CARD F 100.00" where F is noise
|
||||
(r'(?:PLATA\s+)?CARD\s*[:\sA-Z]?\s*([\d\s.,]+)', 'CARD', 0.95),
|
||||
# NUMERAR (cash) with amount
|
||||
(r'NUMERAR\s*:?\s*([\d\s.,]+)', 'NUMERAR', 0.95),
|
||||
# CASH alternative spelling
|
||||
(r'CASH\s*:?\s*([\d\s.,]+)', 'NUMERAR', 0.90),
|
||||
# Truncation recovery patterns (for OCR left-margin truncation issues)
|
||||
# IMPROVED: More restrictive - require max 6 digits before decimals
|
||||
# to avoid matching CUI numbers like RO10562600 → RD10562600
|
||||
# "RD" = truncated "CARD" (only 2 chars visible)
|
||||
(r'\bRD\s*:?\s*([\d\s.,]+)', 'CARD', 0.70),
|
||||
(r'(?:^|\n|\s)RD\s*:?\s*(\d{1,6}[.,]\d{2})\b', 'CARD', 0.70),
|
||||
# "ARD" = truncated "CARD" (3 chars visible)
|
||||
(r'\bARD\s*:?\s*([\d\s.,]+)', 'CARD', 0.75),
|
||||
(r'(?:^|\n|\s)ARD\s*:?\s*(\d{1,6}[.,]\d{2})\b', 'CARD', 0.75),
|
||||
# "MERAR" = truncated "NUMERAR"
|
||||
(r'\bMERAR\s*:?\s*([\d\s.,]+)', 'NUMERAR', 0.70),
|
||||
(r'(?:^|\n|\s)MERAR\s*:?\s*(\d{1,6}[.,]\d{2})\b', 'NUMERAR', 0.70),
|
||||
]
|
||||
|
||||
# Maximum reasonable payment amount for a receipt (100,000 LEI)
|
||||
# Amounts larger than this are likely OCR errors (e.g., CUI parsed as amount)
|
||||
MAX_REASONABLE_PAYMENT = Decimal('100000')
|
||||
|
||||
# Items count patterns - OCR may produce OZ instead of POZ, etc.
|
||||
# Number may be on separate line before or after the label
|
||||
# IMPORTANT: Must be specific to avoid matching product quantities like "50BUC"
|
||||
@@ -250,6 +301,9 @@ class ReceiptExtractor:
|
||||
# Reversed format: CIF/CUI before CLIENT
|
||||
r'C\.?\s*[I1]\.?\s*F\.?\s+CLIENT\s*:', # CIF CLIENT:
|
||||
r'C\.?\s*U\.?\s*[I1]\.?\s+CLIENT\s*:', # CUI CLIENT:
|
||||
# Corrupted CLIENT after CIF: "CIF a IENT:", "CIF LIENT:", "CIF CL IENT:"
|
||||
r'C[I1]F\s+[A-Z\s]{0,6}IENT\s*:', # "CIF a IENT:", "CIF CL IENT:", "CIF LIENT:"
|
||||
r'C[I1]F\s+LIENT\s*:', # "CIF LIENT:" (missing C from CLIENT)
|
||||
# CLIENT followed by C.U.I./C.I.F. (all variations with/without spaces and dots)
|
||||
# Handles: CLIENT C.U.I/C.I.F., CLIENT C. U. I./ C. I.F., CLIENT CUI/CIF
|
||||
r'CLIENT\s+C\.?\s*U\.?\s*[I1]\.?\s*/?\s*C?\.?\s*[I1]?\.?\s*F?\.?\s*:',
|
||||
@@ -267,6 +321,16 @@ class ReceiptExtractor:
|
||||
# Client CUI patterns (explicitly after CLIENT marker)
|
||||
# OCR errors: R0 instead of RO, C1F instead of CIF, 1 instead of I
|
||||
CLIENT_CUI_PATTERNS = [
|
||||
# NEW: CUI on line BEFORE CLIENT marker (docTR/OCR may output value before label)
|
||||
# Pattern: "RO1879855\nCLIENT C.U.I./C.I.F.:" - CUI on line before CLIENT label
|
||||
(r'(R[O0]\d{6,10})\s*\n\s*CLIENT\s+C\.?\s*U\.?\s*[I1]\.?', 0.99),
|
||||
(r'(R[O0]\d{6,10})\s*\n\s*CLIENT\s+C\.?\s*[I1]\.?\s*F\.?', 0.99),
|
||||
# Same but with optional colon after RO number
|
||||
(r'(R[O0]\d{6,10})\s*:?\s*\n\s*CLIENT', 0.98),
|
||||
# "CIF I CLIENT:" or "CIF IDENTIFICARE CLIENT:" format (OCR may insert extra chars)
|
||||
# Common OCR artifact: "CIF I CLIENT: R01879855"
|
||||
(r'C[I1]F\s+[A-Z]*\s*CLIENT\s*:?\s*(R[O0]\d{6,10})', 0.98),
|
||||
(r'C[I1]F\s+[A-Z]*\s*CLIENT\s*:?\s*(R[O0]?\d{6,10})', 0.97),
|
||||
# CIF CLIENT: R01879856 (reversed format - CIF/CUI before CLIENT)
|
||||
(r'C\.?\s*[I1]\.?\s*F\.?\s+CLIENT\s*:?\s*(R[O0]?\d{6,10})', 0.98),
|
||||
(r'C\.?\s*U\.?\s*[I1]\.?\s+CLIENT\s*:?\s*(R[O0]?\d{6,10})', 0.98),
|
||||
@@ -276,19 +340,34 @@ class ReceiptExtractor:
|
||||
# Most flexible pattern for slash variants
|
||||
(r'CLIENT\s+C\.?\s*U\.?\s*[I1]\.?\s*/\s*C\.?\s*[I1]\.?\s*F\.?\s*:?\s*(R[O0]?\d{6,10})', 0.97),
|
||||
(r'CLIENT\s+C\.?\s*U\.?\s*[I1]\.?\s*/\s*C\.?\s*[I1]\.?\s*F\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.97),
|
||||
# OCR artifact: doubled letters like "C.U U. I." or "C.I I.F." (docTR sometimes duplicates)
|
||||
(r'CLIENT\s+C\.?\s*U\.?\s*U?\.?\s*[I1]\.?\s*/\s*C\.?\s*[I1]\.?\s*F\.?\s*:?\s*(R[O0]?\d{6,10})', 0.96),
|
||||
(r'CLIENT\s+C\.?\s*U\.?\s*U?\.?\s*[I1]\.?\s*/\s*C\.?\s*[I1]\.?\s*F\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.96),
|
||||
# CLIENT C.U.I. or CLIENT CUI or CLIENT CIF (without slash)
|
||||
(r'CLIENT\s+C\.?\s*U\.?\s*[I1]\.?\s*:?\s*(R[O0]?\d{6,10})', 0.96),
|
||||
(r'CLIENT\s+C\.?\s*U\.?\s*[I1]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.96),
|
||||
(r'CLIENT\s+C\.?\s*[I1]\.?\s*F\.?\s*:?\s*(R[O0]?\d{6,10})', 0.96),
|
||||
(r'CLIENT\s+C\.?\s*[I1]\.?\s*F\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.96),
|
||||
# Corrupted CLIENT after CIF: "CIF a IENT:", "CIF LIENT:", "CIF L IENT:", "CIF C IENT:"
|
||||
# OCR often corrupts "CLIENT" when it appears after "CIF"
|
||||
(r'CIF\s+[a-zA-Z\s]{2,8}IENT\s*:?\s*(R[O0]?\d{6,10})', 0.93), # "CIF a IENT:", "CIF CL IENT:"
|
||||
(r'CIF\s+[a-zA-Z\s]{2,8}IENT\s*:?\s*(?:R[O0])?(\d{6,10})', 0.93),
|
||||
(r'CIF\s+LIENT\s*:?\s*(R[O0]?\d{6,10})', 0.92), # "CIF LIENT:" (missing C)
|
||||
(r'CIF\s+LIENT\s*:?\s*(?:R[O0])?(\d{6,10})', 0.92),
|
||||
# CUMPARATOR variants
|
||||
(r'CUMPARATOR\s+C\.?\s*U\.?\s*[I1]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.95),
|
||||
(r'CUMPARATOR\s+C\.?\s*[I1]\.?\s*F\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.95),
|
||||
# CUMPARATOR with CUI/CIF on next line: "CUMPARATOR: NAME\nCIF: 12345678"
|
||||
(r'CUMPARATOR\s*:.*\n\s*C\.?\s*U\.?\s*[I1]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.93),
|
||||
(r'CUMPARATOR\s*:.*\n\s*C\.?\s*[I1]\.?\s*[FT]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.93), # F or T (OCR error)
|
||||
# CUMPARATOR with CUI/CIF two lines down: "CUMPARATOR: NAME\nADDRESS\nCIF: 12345678"
|
||||
(r'CUMPARATOR\s*:.*\n.*\n\s*C\.?\s*[I1]\.?\s*[FT]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.90),
|
||||
# CUI/CIF on line immediately after CLIENT marker
|
||||
(r'CLIENT\s*:\s*\n\s*C\.?\s*U\.?\s*[I1]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.95),
|
||||
(r'CLIENT\s*:\s*\n\s*C\.?\s*[I1]\.?\s*F\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.95),
|
||||
# CUI after client name: "CLIENT: COMPANY SRL\nCUI: 12345678"
|
||||
(r'CLIENT\s*:.*\n.*C\.?\s*U\.?\s*[I1]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.90),
|
||||
(r'CLIENT\s*:\s*\n\s*C\.?\s*[I1]\.?\s*[FT]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.95), # F or T (OCR error)
|
||||
# CUI/CIF after client name: "CLIENT: COMPANY SRL\nCUI: 12345678"
|
||||
(r'CLIENT\s*:.*\n\s*C\.?\s*U\.?\s*[I1]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.90),
|
||||
(r'CLIENT\s*:.*\n\s*C\.?\s*[I1]\.?\s*[FT]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.90), # CIF/CIT after name
|
||||
]
|
||||
|
||||
# Vendor name indicators (lines containing these are likely vendor names)
|
||||
@@ -322,6 +401,8 @@ class ReceiptExtractor:
|
||||
result.receipt_series, _ = self._extract_series(text_upper)
|
||||
result.partner_name, result.confidence_vendor = self._extract_vendor(text)
|
||||
result.cui, _ = self._extract_cui(text_upper, text)
|
||||
# Normalize CUI: fix R0 → RO OCR error and validate format
|
||||
result.cui = OCRValidationEngine.normalize_cui(result.cui)
|
||||
|
||||
# Extract additional fields - Multiple TVA entries
|
||||
result.tva_entries, result.tva_total = self._extract_tva_entries(text_upper)
|
||||
@@ -345,10 +426,35 @@ class ReceiptExtractor:
|
||||
result.address = self._extract_address(text_upper)
|
||||
result.payment_methods = self._extract_payment_methods(text_upper)
|
||||
|
||||
# Validate payment methods against extracted amount
|
||||
# If payment sum >> amount, clear invalid payments (likely OCR error)
|
||||
# Save original payment methods before validation (for payment mode detection)
|
||||
original_payment_methods = result.payment_methods.copy() if result.payment_methods else []
|
||||
|
||||
result.payment_methods = self._validate_payment_methods(result.payment_methods, result.amount)
|
||||
|
||||
# Auto-suggest payment_mode based on detected payment methods
|
||||
# Use ORIGINAL payment_methods to detect CARD even if validation cleared them
|
||||
# (e.g., CARD 318.16 is valid even if total validation failed)
|
||||
payment_methods_for_mode = result.payment_methods if result.payment_methods else original_payment_methods
|
||||
if payment_methods_for_mode:
|
||||
card_amount = sum(
|
||||
pm.get('amount', Decimal('0'))
|
||||
for pm in payment_methods_for_mode
|
||||
if pm.get('method') == 'CARD'
|
||||
)
|
||||
if card_amount > 0:
|
||||
result.suggested_payment_mode = 'banca'
|
||||
print(f"[Payment Mode] CARD detected ({card_amount}), suggesting 'banca'", flush=True)
|
||||
else:
|
||||
# Only cash payments detected
|
||||
result.suggested_payment_mode = 'numerar'
|
||||
print(f"[Payment Mode] Cash only detected, suggesting 'numerar'", flush=True)
|
||||
|
||||
# Extract client data (B2B receipts)
|
||||
client_name, client_cui, client_address, confidence_client = self._extract_client_data(text_upper, text)
|
||||
result.client_name = client_name
|
||||
result.client_cui = client_cui
|
||||
result.client_cui = OCRValidationEngine.normalize_cui(client_cui) # Fix R0 → RO OCR error
|
||||
result.client_address = client_address
|
||||
result.confidence_client = confidence_client
|
||||
|
||||
@@ -378,13 +484,28 @@ class ReceiptExtractor:
|
||||
|
||||
def _extract_amount(self, text: str) -> Tuple[Optional[Decimal], float]:
|
||||
"""Extract total amount from text."""
|
||||
# PRE-FILTER: Remove lines containing REST (rest = change, not total)
|
||||
# When paid by card, there's no change - exact amount is paid
|
||||
lines = text.split('\n')
|
||||
filtered_lines = []
|
||||
for line in lines:
|
||||
# Skip lines with REST pattern (change amount, not total)
|
||||
if re.search(r'\bREST\b', line, re.IGNORECASE):
|
||||
continue
|
||||
filtered_lines.append(line)
|
||||
text = '\n'.join(filtered_lines)
|
||||
|
||||
# First try standard patterns (TOTAL, SUBTOTAL, etc.)
|
||||
for pattern, confidence in self.TOTAL_PATTERNS:
|
||||
match = re.search(pattern, text, re.IGNORECASE | re.MULTILINE)
|
||||
if match:
|
||||
try:
|
||||
amount_str = re.sub(r'[^\d.,]', '', match.group(1))
|
||||
# IMPORTANT: Call _normalize_number FIRST to handle "190 60" → "190.60"
|
||||
# before stripping other characters
|
||||
amount_str = match.group(1).strip()
|
||||
amount_str = self._normalize_number(amount_str)
|
||||
# Now remove any remaining non-numeric chars (except decimal point)
|
||||
amount_str = re.sub(r'[^\d.]', '', amount_str)
|
||||
amount = Decimal(amount_str)
|
||||
if amount > 0:
|
||||
return amount, confidence
|
||||
@@ -461,8 +582,22 @@ class ReceiptExtractor:
|
||||
|
||||
def _normalize_number(self, num_str: str) -> str:
|
||||
"""Normalize Romanian number format to standard decimal."""
|
||||
# Remove spaces
|
||||
num_str = num_str.replace(' ', '')
|
||||
# OCR often reads "." as " " (space) - handle "190 60" as "190.60"
|
||||
# Pattern: digits + space + exactly 2 digits at end
|
||||
space_decimal_match = re.match(r'^(\d+)\s+(\d{2})$', num_str.strip())
|
||||
if space_decimal_match:
|
||||
num_str = f"{space_decimal_match.group(1)}.{space_decimal_match.group(2)}"
|
||||
else:
|
||||
# Handle "1 234 56" pattern (thousands + decimal with spaces)
|
||||
# Match: digits + space(s) + digits + space + 2 digits
|
||||
multi_space_match = re.match(r'^([\d\s]+?)\s+(\d{2})$', num_str.strip())
|
||||
if multi_space_match:
|
||||
integer_part = multi_space_match.group(1).replace(' ', '')
|
||||
decimal_part = multi_space_match.group(2)
|
||||
num_str = f"{integer_part}.{decimal_part}"
|
||||
else:
|
||||
# Remove remaining spaces (thousands separators)
|
||||
num_str = num_str.replace(' ', '')
|
||||
|
||||
# Handle comma as decimal separator
|
||||
if ',' in num_str and '.' in num_str:
|
||||
@@ -532,34 +667,57 @@ class ReceiptExtractor:
|
||||
except (InvalidOperation, ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# Case 1: Amount is valid with high confidence - just validate
|
||||
# Case 1: Amount is valid with high confidence - validate against TVA and payments
|
||||
if amount and amount > 0 and confidence_amount >= 0.8:
|
||||
# Cross-validate: check if it matches payment methods
|
||||
# First check TVA-implied total (most reliable when TVA is extracted correctly)
|
||||
if tva_implied_total and tva_implied_total > 0:
|
||||
tva_diff_percent = abs(float(amount) - float(tva_implied_total)) / float(tva_implied_total) * 100
|
||||
if tva_diff_percent <= 1:
|
||||
# Near-perfect TVA match - highest confidence
|
||||
return amount, min(0.98, confidence_amount + 0.05), "extracted (validated by TVA)"
|
||||
elif tva_diff_percent > 10:
|
||||
# Significant mismatch - TVA-implied total is more reliable
|
||||
# This catches cases where wrong TOTAL line was extracted (e.g., REST, SUBTOTAL)
|
||||
print(f"[Cross-Validation] Amount mismatch with TVA: extracted={amount}, tva_implied={tva_implied_total} (diff={tva_diff_percent:.1f}%)", flush=True)
|
||||
return tva_implied_total, 0.90, "calculated from TVA (extracted amount mismatch)"
|
||||
|
||||
# Cross-validate with payment methods
|
||||
if payment_sum > 0 and abs(amount - payment_sum) <= Decimal('0.02'):
|
||||
# Perfect match - boost confidence
|
||||
return amount, min(0.98, confidence_amount + 0.05), "extracted (validated by payment methods)"
|
||||
elif payment_sum > 0:
|
||||
payment_diff_percent = abs(float(amount) - float(payment_sum)) / float(payment_sum) * 100
|
||||
if payment_diff_percent > 10:
|
||||
# Significant mismatch - payment sum is more reliable
|
||||
print(f"[Cross-Validation] Amount mismatch with payments: extracted={amount}, payments={payment_sum} (diff={payment_diff_percent:.1f}%)", flush=True)
|
||||
return payment_sum, 0.88, "calculated from payment methods (extracted amount mismatch)"
|
||||
|
||||
return amount, confidence_amount, "extracted"
|
||||
|
||||
# Case 2: Amount exists but low confidence - try to validate/correct
|
||||
if amount and amount > 0:
|
||||
# First check TVA-implied total (most reliable)
|
||||
if tva_implied_total and tva_implied_total > 0:
|
||||
tva_diff_percent = abs(float(amount) - float(tva_implied_total)) / float(tva_implied_total) * 100
|
||||
if tva_diff_percent <= 2:
|
||||
# Close match - boost confidence
|
||||
return amount, 0.88, "extracted (validated by TVA)"
|
||||
elif tva_diff_percent > 10:
|
||||
# Significant mismatch - use TVA-implied total
|
||||
print(f"[Cross-Validation] Amount mismatch with TVA: extracted={amount}, tva_implied={tva_implied_total} (diff={tva_diff_percent:.1f}%)", flush=True)
|
||||
return tva_implied_total, 0.85, "calculated from TVA"
|
||||
|
||||
# Check if payment methods sum matches
|
||||
if payment_sum > 0:
|
||||
if abs(amount - payment_sum) <= Decimal('0.02'):
|
||||
# Match - boost confidence
|
||||
payment_diff_percent = abs(float(amount) - float(payment_sum)) / float(payment_sum) * 100
|
||||
if payment_diff_percent <= 0.5:
|
||||
# Close match - boost confidence
|
||||
return amount, 0.90, "extracted (validated by payment methods)"
|
||||
else:
|
||||
elif payment_diff_percent > 10:
|
||||
# Mismatch - prefer payment_sum as it's more reliable
|
||||
print(f"[Cross-Validation] Amount mismatch: extracted={amount}, payments={payment_sum}", flush=True)
|
||||
return payment_sum, 0.85, "calculated from payment methods"
|
||||
|
||||
# Check TVA-implied total
|
||||
if tva_implied_total:
|
||||
if abs(amount - tva_implied_total) <= Decimal('0.50'):
|
||||
# Close match - use extracted amount
|
||||
return amount, 0.80, "extracted (validated by TVA)"
|
||||
else:
|
||||
print(f"[Cross-Validation] TVA mismatch: extracted={amount}, tva_implied={tva_implied_total}", flush=True)
|
||||
|
||||
# No validation possible - return as-is
|
||||
return amount, confidence_amount, "extracted (unvalidated)"
|
||||
|
||||
@@ -701,6 +859,10 @@ class ReceiptExtractor:
|
||||
|
||||
line_upper = line.upper()
|
||||
|
||||
# Skip lines with skip keywords (CUMPARATOR, CLIENT, etc.)
|
||||
if any(kw in line_upper for kw in skip_keywords):
|
||||
continue
|
||||
|
||||
# Check for vendor indicators
|
||||
for indicator in self.VENDOR_INDICATORS:
|
||||
if re.search(indicator, line_upper):
|
||||
@@ -778,13 +940,21 @@ class ReceiptExtractor:
|
||||
Extract vendor CUI (fiscal identification code) from text.
|
||||
Excludes CLIENT CUI which appears as 'CLIENT C.U.I./C.I.F.:...'
|
||||
"""
|
||||
def get_cui_digit_count(cui: str) -> int:
|
||||
"""Get the count of digits in CUI (excluding RO/R0 prefix)."""
|
||||
cui_upper = cui.upper().strip()
|
||||
if cui_upper.startswith('RO') or cui_upper.startswith('R0'):
|
||||
return len(cui_upper) - 2
|
||||
return len(cui_upper)
|
||||
|
||||
# Strategy 0: Check for reversed format (CIF NUMBER on line BEFORE "C.I.F." label)
|
||||
# This is common in some receipts: "R011201891\nC. I. F."
|
||||
# This is common in some receipts: "RO11201891\nC. I. F."
|
||||
for pattern, confidence in self.CUI_REVERSED_PATTERNS:
|
||||
match = re.search(pattern, text_upper, re.IGNORECASE | re.MULTILINE)
|
||||
if match:
|
||||
cui = match.group(1)
|
||||
if 6 <= len(cui) <= 10:
|
||||
digit_count = get_cui_digit_count(cui)
|
||||
if 6 <= digit_count <= 10:
|
||||
# Verify this is not the CLIENT CUI by checking context
|
||||
start = match.start()
|
||||
# Check 50 chars before the match for CLIENT keyword
|
||||
@@ -805,7 +975,8 @@ class ReceiptExtractor:
|
||||
match = re.search(pattern, line, re.IGNORECASE | re.MULTILINE)
|
||||
if match:
|
||||
cui = match.group(1)
|
||||
if 6 <= len(cui) <= 10:
|
||||
digit_count = get_cui_digit_count(cui)
|
||||
if 6 <= digit_count <= 10:
|
||||
return cui, confidence
|
||||
|
||||
# Strategy 2: Fallback - search entire text but exclude CLIENT patterns
|
||||
@@ -813,7 +984,8 @@ class ReceiptExtractor:
|
||||
# Find all matches
|
||||
for match in re.finditer(pattern, text_upper, re.IGNORECASE | re.MULTILINE):
|
||||
cui = match.group(1)
|
||||
if 6 <= len(cui) <= 10:
|
||||
digit_count = get_cui_digit_count(cui)
|
||||
if 6 <= digit_count <= 10:
|
||||
# Check if this match is preceded by CLIENT in the same line
|
||||
start = match.start()
|
||||
line_start = text_upper.rfind('\n', 0, start) + 1
|
||||
@@ -937,9 +1109,90 @@ class ReceiptExtractor:
|
||||
except (ValueError, InvalidOperation):
|
||||
continue
|
||||
|
||||
# Pattern 1: "TVA A - 19%: 15.20" or "TVAA - 21% 32.31" (with code)
|
||||
# OCR tolerant: TUA, TVR, etc.
|
||||
pattern_with_code = r'T[VU][AR]\s*([A-D])\s*[-:]\s*(\d{1,2})\s*%\s*:?\s*([\d\s.,]+)'
|
||||
# Pattern 0c: REVERSED FORMAT "5.00% TUA*B" followed by amount on next line
|
||||
# This handles receipts where percentage comes BEFORE TVA code (e.g., books with 5% rate)
|
||||
# Matches: "5.00% TUA*B", "5% TVA B", "5.00% TVA", "9% TUA", "5% IVA"
|
||||
if not tva_entries:
|
||||
# Pattern: PERCENT% + TVA/IVA + optional code, then amount on next line
|
||||
reversed_tva_pattern = r'(\d{1,2})[.,]?\d{0,2}\s*%\s*(?:T[VU][AR]|IVA)\s*\*?([A-D])?'
|
||||
for match in re.finditer(reversed_tva_pattern, normalized_text, re.IGNORECASE):
|
||||
try:
|
||||
percent = int(match.group(1))
|
||||
code = (match.group(2) or self._get_tva_code_from_percent(percent)).upper()
|
||||
|
||||
# Look for amount on the next line(s) after the match
|
||||
after_match = normalized_text[match.end():]
|
||||
# Find standalone number (amount) - skip empty lines
|
||||
amount_match = re.search(r'^[\s\n]*([\d]+[.,]\d{2})\b', after_match)
|
||||
if amount_match:
|
||||
amount_str = self._normalize_number(amount_match.group(1))
|
||||
amount = Decimal(amount_str)
|
||||
if amount > 0:
|
||||
entry_key = (code, percent)
|
||||
if entry_key not in seen_entries:
|
||||
tva_entries.append({
|
||||
'code': code,
|
||||
'percent': percent,
|
||||
'amount': amount
|
||||
})
|
||||
seen_entries.add(entry_key)
|
||||
except (ValueError, InvalidOperation):
|
||||
continue
|
||||
|
||||
# Pattern 0d: "TOTAL TUA:", "TOTAL TVA:", "TOTAL IVA:" with amount (OCR variants)
|
||||
if not tva_entries:
|
||||
total_tva_simple = r'TOTAL\s+(?:T[VU][AR]|IVA)\s*:?\s*([\d.,]+)'
|
||||
match = re.search(total_tva_simple, normalized_text, re.IGNORECASE)
|
||||
if match:
|
||||
try:
|
||||
amount_str = self._normalize_number(match.group(1))
|
||||
amount = Decimal(amount_str)
|
||||
if amount > 0:
|
||||
# Try to find the rate in nearby text
|
||||
percent = self._detect_tva_percent(text)
|
||||
if percent:
|
||||
code = self._get_tva_code_from_percent(percent)
|
||||
entry_key = (code, percent)
|
||||
if entry_key not in seen_entries:
|
||||
tva_entries.append({
|
||||
'code': code,
|
||||
'percent': percent,
|
||||
'amount': amount
|
||||
})
|
||||
seen_entries.add(entry_key)
|
||||
except (ValueError, InvalidOperation):
|
||||
pass
|
||||
|
||||
# Pattern 0e: Multiline "TOTAL TUA\n198\n30.43" where:
|
||||
# - "TOTAL TUA" on one line
|
||||
# - "198" or similar (corrupted "19%") on next line (optional)
|
||||
# - "30.43" (TVA amount) on following line
|
||||
# OCR often splits this across multiple lines
|
||||
if not tva_entries:
|
||||
multiline_tva = r'TOTAL\s+(?:T[VU][AR]L?|TU[AR]L|IVA)\s*\n\s*\d*\s*\n?\s*([\d]+[.,]\d{2})\b'
|
||||
match = re.search(multiline_tva, normalized_text, re.IGNORECASE)
|
||||
if match:
|
||||
try:
|
||||
amount_str = self._normalize_number(match.group(1))
|
||||
amount = Decimal(amount_str)
|
||||
if amount > 0:
|
||||
percent = self._detect_tva_percent(text)
|
||||
if percent:
|
||||
code = self._get_tva_code_from_percent(percent)
|
||||
entry_key = (code, percent)
|
||||
if entry_key not in seen_entries:
|
||||
tva_entries.append({
|
||||
'code': code,
|
||||
'percent': percent,
|
||||
'amount': amount
|
||||
})
|
||||
seen_entries.add(entry_key)
|
||||
except (ValueError, InvalidOperation):
|
||||
pass
|
||||
|
||||
# Pattern 1: "TVA A - 19%: 15.20" or "TVAA - 21% 32.31" or "IVA A - 19%" (with code)
|
||||
# OCR tolerant: TUA, TVR, IVA, etc.
|
||||
pattern_with_code = r'(?:T[VU][AR]|IVA)\s*([A-D])\s*[-:]\s*(\d{1,2})\s*%\s*:?\s*([\d\s.,]+)'
|
||||
for match in re.finditer(pattern_with_code, normalized_text, re.IGNORECASE):
|
||||
try:
|
||||
code = match.group(1).upper()
|
||||
@@ -959,9 +1212,9 @@ class ReceiptExtractor:
|
||||
except (ValueError, InvalidOperation):
|
||||
continue
|
||||
|
||||
# Pattern 2: "TVA - 21%: 32.31" (without explicit code, assume 'A')
|
||||
# Pattern 2: "TVA - 21%: 32.31" or "IVA - 21%: 32.31" (without explicit code, assume 'A')
|
||||
if not tva_entries:
|
||||
pattern_no_code = r'T[VU][AR]\s*[-:]\s*(\d{1,2})\s*%\s*:?\s*([\d\s.,]+)'
|
||||
pattern_no_code = r'(?:T[VU][AR]|IVA)\s*[-:]\s*(\d{1,2})\s*%\s*:?\s*([\d\s.,]+)'
|
||||
for match in re.finditer(pattern_no_code, normalized_text, re.IGNORECASE):
|
||||
try:
|
||||
percent = int(match.group(1))
|
||||
@@ -982,10 +1235,10 @@ class ReceiptExtractor:
|
||||
except (ValueError, InvalidOperation):
|
||||
continue
|
||||
|
||||
# Pattern 3: "TOTAL TVA A - 21%" with amount on same line or "TOTAL TVA BON" with amount
|
||||
# Pattern 3: "TOTAL TVA A - 21%" or "TOTAL IVA" with amount on same line or "TOTAL TVA BON" with amount
|
||||
if not tva_entries:
|
||||
# First try: "TOTAL TVA A - 21% 32.31" (amount on same line)
|
||||
tva_with_amount = r'TOTAL\s+T[VU][AR]\s+([A-D])\s*[-:]\s*(\d{1,2})\s*%\s*([\d.,]+)'
|
||||
# First try: "TOTAL TVA A - 21% 32.31" or "TOTAL IVA A - 21% 32.31" (amount on same line)
|
||||
tva_with_amount = r'TOTAL\s+(?:T[VU][AR]|IVA)\s+([A-D])\s*[-:]\s*(\d{1,2})\s*%\s*([\d.,]+)'
|
||||
for match in re.finditer(tva_with_amount, normalized_text, re.IGNORECASE):
|
||||
try:
|
||||
code = match.group(1).upper()
|
||||
@@ -1004,16 +1257,16 @@ class ReceiptExtractor:
|
||||
except (ValueError, InvalidOperation):
|
||||
continue
|
||||
|
||||
# Pattern 3b: "TOTAL TVA A - 21%" on one line, look for "TOTAL TVA BON" amount
|
||||
# Pattern 3b: "TOTAL TVA A - 21%" or "TOTAL IVA A - 21%" on one line, look for "TOTAL TVA BON" amount
|
||||
if not tva_entries:
|
||||
tva_total_pattern = r'TOTAL\s+T[VU][AR]\s+([A-D])\s*[-:]\s*(\d{1,2})\s*%'
|
||||
tva_total_pattern = r'TOTAL\s+(?:T[VU][AR]|IVA)\s+([A-D])\s*[-:]\s*(\d{1,2})\s*%'
|
||||
for match in re.finditer(tva_total_pattern, normalized_text, re.IGNORECASE):
|
||||
try:
|
||||
code = match.group(1).upper()
|
||||
percent = int(match.group(2))
|
||||
|
||||
# Look for "TOTAL TVA BON" followed by amount
|
||||
tva_bon_pattern = r'TOTAL\s+T[VU][AR]\s+BON[:\s]*([\d.,]+)'
|
||||
# Look for "TOTAL TVA BON" or "TOTAL IVA BON" followed by amount
|
||||
tva_bon_pattern = r'TOTAL\s+(?:T[VU][AR]|IVA)\s+BON[:\s]*([\d.,]+)'
|
||||
tva_bon_match = re.search(tva_bon_pattern, normalized_text, re.IGNORECASE)
|
||||
if tva_bon_match:
|
||||
amount_str = self._normalize_number(tva_bon_match.group(1))
|
||||
@@ -1029,8 +1282,8 @@ class ReceiptExtractor:
|
||||
seen_entries.add(entry_key)
|
||||
continue
|
||||
|
||||
# Fallback: Amount after TOTAL TVA BON on next line
|
||||
tva_bon_pos = re.search(r'TOTAL\s+T[VU][AR]\s+BON', normalized_text, re.IGNORECASE)
|
||||
# Fallback: Amount after TOTAL TVA BON or TOTAL IVA BON on next line
|
||||
tva_bon_pos = re.search(r'TOTAL\s+(?:T[VU][AR]|IVA)\s+BON', normalized_text, re.IGNORECASE)
|
||||
if tva_bon_pos:
|
||||
after_bon = normalized_text[tva_bon_pos.end():]
|
||||
# Find first standalone number (likely TVA amount)
|
||||
@@ -1050,9 +1303,9 @@ class ReceiptExtractor:
|
||||
except (ValueError, InvalidOperation):
|
||||
continue
|
||||
|
||||
# Pattern 3b: "TVAA - 21%" on one line, amount on next line (simpler format)
|
||||
# Pattern 3c: "TVAA - 21%" or "IVA A - 21%" on one line, amount on next line (simpler format)
|
||||
if not tva_entries:
|
||||
tva_line_pattern = r'T[VU][AR]\s*([A-D])?\s*[-:]\s*(\d{1,2})\s*%'
|
||||
tva_line_pattern = r'(?:T[VU][AR]|IVA)\s*([A-D])?\s*[-:]\s*(\d{1,2})\s*%'
|
||||
for match in re.finditer(tva_line_pattern, normalized_text, re.IGNORECASE):
|
||||
try:
|
||||
code = (match.group(1) or 'A').upper()
|
||||
@@ -1158,16 +1411,18 @@ class ReceiptExtractor:
|
||||
Extract TOTAL TVA BON value separately as the reference.
|
||||
This is the authoritative total TVA on the receipt.
|
||||
|
||||
Handles OCR variations: TOTAL TVA BON, OTAL TUA BON, etc.
|
||||
Handles OCR variations: TOTAL TVA BON, OTAL TUA BON, TOTAL IVA BON, etc.
|
||||
"""
|
||||
# Pattern for TOTAL TVA BON with amount after
|
||||
# Pattern for TOTAL TVA BON or TOTAL IVA BON with amount after
|
||||
# OCR corruptions: TUAL (TVA+L merged), TVAL, TUAI, etc.
|
||||
patterns = [
|
||||
# Standard: TOTAL TVA BON: 14.92
|
||||
r'T?OTAL\s+T[VU][AR]\s+BON\s*:?\s*([\d]+[.,]\d{2})\b',
|
||||
# Standard: TOTAL TVA BON: 14.92 or TOTAL IVA BON: 14.92
|
||||
# Handles: TUAL (TVA+L), TVAL, TUAI, etc. with optional trailing letters
|
||||
r'T?OTAL\s+(?:T[VU][AR]L?|TU[AR]L|IVA)\s+BON\s*:?\s*([\d]+[.,]\d{2})\b',
|
||||
# Amount before: 14.92 OTAL TUA BON (OCR line break)
|
||||
r'([\d]+[.,]\d{2})\s*\n?\s*T?OTAL\s+T[VU][AR]\s+BON',
|
||||
# Amount on next line after TOTAL TVA BON
|
||||
r'T?OTAL\s+T[VU][AR]\s+BON\s*\n\s*([\d]+[.,]\d{2})\b',
|
||||
r'([\d]+[.,]\d{2})\s*\n?\s*T?OTAL\s+(?:T[VU][AR]L?|TU[AR]L|IVA)\s+BON',
|
||||
# Amount on next line after TOTAL TVA BON or TOTAL IVA BON
|
||||
r'T?OTAL\s+(?:T[VU][AR]L?|TU[AR]L|IVA)\s+BON\s*\n\s*([\d]+[.,]\d{2})\b',
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
@@ -1271,18 +1526,52 @@ class ReceiptExtractor:
|
||||
return tva_entries, tva_total
|
||||
|
||||
def _detect_tva_percent(self, text: str) -> Optional[int]:
|
||||
"""Detect TVA percentage from text content."""
|
||||
# Look for common Romanian TVA percentages
|
||||
if '19%' in text or '19 %' in text:
|
||||
"""Detect TVA percentage from text content.
|
||||
|
||||
IMPORTANT: Prioritize rates found near TVA markers over rates found elsewhere.
|
||||
E.g., "REDUCERE 5%" should not override "TVA A 19%".
|
||||
Also handle OCR corruptions like "194" for "19%" in "TOTAL TA F 194".
|
||||
"""
|
||||
import re as regex
|
||||
|
||||
# First, look for percent NEAR TVA markers (most reliable)
|
||||
# This handles "TVA A 19%", "TVA 19,00%", "TOTAL TVA 19%"
|
||||
tva_context_patterns = [
|
||||
r'T[VU][AR]\s*[A-D]?\s*[-:]?\s*(19|21|11|9|5)[.,]?\s*\d{0,2}\s*%',
|
||||
r'IVA\s*[A-D]?\s*[-:]?\s*(19|21|11|9|5)[.,]?\s*\d{0,2}\s*%',
|
||||
# OCR corruption: "TOTAL TA F 194" where 194 = 19% (4 is artifact)
|
||||
r'TOTAL\s+T[VA][AR]?\s*[F\s]?\s*(19|21)\d\b',
|
||||
]
|
||||
for pattern in tva_context_patterns:
|
||||
match = regex.search(pattern, text, regex.IGNORECASE)
|
||||
if match:
|
||||
rate = int(match.group(1))
|
||||
if rate in (19, 21, 11, 9, 5):
|
||||
return rate
|
||||
|
||||
# Fallback: Look for common Romanian TVA percentages anywhere
|
||||
# But EXCLUDE patterns near "REDUCERE", "DISCOUNT", "RED." (these are discounts, not TVA)
|
||||
# Clean text by removing discount context
|
||||
# Handle OCR corruptions: RED.CERE (C instead of U), RED CERE, REDUC, etc.
|
||||
text_no_discount = regex.sub(r'(?:REDUC|DISCOUNT|RED)[.\sA-Z]*\d+[.,]?\d*\s*%', '', text, flags=regex.IGNORECASE)
|
||||
|
||||
# Now search in cleaned text (priority order: 19% > 21% > 11% > 9% > 5%)
|
||||
if regex.search(r'\b19[.,]?\s*\d{0,2}\s*%', text_no_discount):
|
||||
return 19
|
||||
elif '21%' in text or '21 %' in text:
|
||||
elif regex.search(r'\b21[.,]?\s*\d{0,2}\s*%', text_no_discount):
|
||||
return 21
|
||||
elif '11%' in text or '11 %' in text:
|
||||
elif regex.search(r'\b11[.,]?\s*\d{0,2}\s*%', text_no_discount):
|
||||
return 11
|
||||
elif '9%' in text or '9 %' in text:
|
||||
elif regex.search(r'\b9[.,]?\s*\d{0,2}\s*%', text_no_discount):
|
||||
return 9
|
||||
elif '5%' in text or '5 %' in text:
|
||||
elif regex.search(r'\b5[.,]?\s*\d{0,2}\s*%', text_no_discount):
|
||||
return 5
|
||||
|
||||
# Default: If no percent found but we're in Romanian receipt context,
|
||||
# assume 19% (standard rate)
|
||||
if regex.search(r'T[VU][AR]|IVA', text, regex.IGNORECASE):
|
||||
return 19
|
||||
|
||||
return None
|
||||
|
||||
def _validate_tva_reverse(
|
||||
@@ -1293,9 +1582,12 @@ class ReceiptExtractor:
|
||||
"""
|
||||
Reverse TVA validation: from TVA amount and rate, calculate expected total.
|
||||
|
||||
Formula:
|
||||
base = tva_amount / (rate/100)
|
||||
expected_total = sum(base + tva_amount) for all entries
|
||||
Formula (CORRECT):
|
||||
For TVA that is INCLUDED in total (standard Romanian receipts):
|
||||
total = base + tva
|
||||
tva = base * rate/100
|
||||
Therefore: base = tva * 100 / rate
|
||||
And: total = base + tva = tva * 100 / rate + tva = tva * (100 + rate) / rate
|
||||
|
||||
Returns (is_valid, expected_total, message)
|
||||
"""
|
||||
@@ -1307,10 +1599,14 @@ class ReceiptExtractor:
|
||||
tva_amount = entry['amount']
|
||||
rate = Decimal(str(entry['percent']))
|
||||
|
||||
print(f"[TVA Debug] Entry: amount={tva_amount}, rate={rate}%", flush=True)
|
||||
|
||||
if rate > 0:
|
||||
# Calculate base from TVA: base = tva / (rate/100)
|
||||
base = tva_amount / (rate / Decimal('100'))
|
||||
expected_total += base + tva_amount
|
||||
# CORRECT formula: total = tva * (100 + rate) / rate
|
||||
# Example: tva=55.22, rate=21 → total = 55.22 * 121 / 21 = 318.16
|
||||
gross_for_entry = tva_amount * (Decimal('100') + rate) / rate
|
||||
expected_total += gross_for_entry
|
||||
print(f"[TVA Debug] Calculated gross: {gross_for_entry}", flush=True)
|
||||
else:
|
||||
# 0% TVA - can't calculate base, skip
|
||||
pass
|
||||
@@ -1393,7 +1689,7 @@ class ReceiptExtractor:
|
||||
|
||||
# Find the region between TOTAL LEI and TOTAL TVA
|
||||
total_lei_match = re.search(r'TOTAL\s+LEI\s*([\d\s.,]+)', normalized_text, re.IGNORECASE)
|
||||
total_tva_match = re.search(r'TOTAL\s+T[VU][AR]', normalized_text, re.IGNORECASE)
|
||||
total_tva_match = re.search(r'TOTAL\s+(?:T[VU][AR]|IVA)', normalized_text, re.IGNORECASE)
|
||||
|
||||
# Define search region (after TOTAL LEI, before TOTAL TVA if exists)
|
||||
if total_lei_match:
|
||||
@@ -1404,22 +1700,60 @@ class ReceiptExtractor:
|
||||
search_region = normalized_text # Fallback to full text
|
||||
|
||||
for pattern, method, confidence in self.PAYMENT_METHOD_PATTERNS:
|
||||
for match in re.finditer(pattern, search_region, re.IGNORECASE):
|
||||
for match in re.finditer(pattern, search_region, re.IGNORECASE | re.MULTILINE):
|
||||
try:
|
||||
amount_str = match.group(1).replace(' ', '')
|
||||
amount_str = self._normalize_number(re.sub(r'[^\d.,]', '', amount_str))
|
||||
amount = Decimal(amount_str)
|
||||
if amount > 0 and method not in seen_methods:
|
||||
# Validate: amount must be positive and reasonable (< MAX_REASONABLE_PAYMENT)
|
||||
# This prevents OCR errors like CUI being parsed as payment
|
||||
if amount > 0 and amount < self.MAX_REASONABLE_PAYMENT and method not in seen_methods:
|
||||
payment_methods.append({
|
||||
'method': method,
|
||||
'amount': amount
|
||||
})
|
||||
seen_methods.add(method)
|
||||
print(f"[Payment] Found {method}: {amount} (pattern matched)", flush=True)
|
||||
elif amount >= self.MAX_REASONABLE_PAYMENT:
|
||||
print(f"[Payment] Rejected unreasonable amount {amount} for {method} (likely OCR error)", flush=True)
|
||||
except (InvalidOperation, ValueError):
|
||||
continue
|
||||
|
||||
return payment_methods
|
||||
|
||||
def _validate_payment_methods(
|
||||
self, payment_methods: List[dict], total: Optional[Decimal]
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Validate payment methods against extracted total.
|
||||
|
||||
If payment sum is way larger than total (>10x), it's likely an OCR error
|
||||
(e.g., CUI number parsed as payment amount). Clear invalid payments.
|
||||
|
||||
Args:
|
||||
payment_methods: List of {'method': str, 'amount': Decimal}
|
||||
total: Extracted total amount
|
||||
|
||||
Returns:
|
||||
Validated payment methods (may be empty if all were invalid)
|
||||
"""
|
||||
if not total or not payment_methods:
|
||||
return payment_methods
|
||||
|
||||
payment_sum = sum(pm.get('amount', Decimal('0')) for pm in payment_methods)
|
||||
|
||||
# If payment sum > 10x total, it's definitely an error
|
||||
if payment_sum > total * 10:
|
||||
print(f"[Payment Validation] Payment sum {payment_sum} >> Total {total} (>10x), clearing invalid payments", flush=True)
|
||||
return []
|
||||
|
||||
# If payment sum > 2x total, it's suspicious but might be valid in some edge cases
|
||||
# Just log a warning
|
||||
if payment_sum > total * 2:
|
||||
print(f"[Payment Validation] Warning: Payment sum {payment_sum} > 2x Total {total}, possible OCR error", flush=True)
|
||||
|
||||
return payment_methods
|
||||
|
||||
def _extract_client_data(
|
||||
self, text_upper: str, original_text: str
|
||||
) -> Tuple[Optional[str], Optional[str], Optional[str], float]:
|
||||
|
||||
@@ -1,520 +0,0 @@
|
||||
"""
|
||||
Unit tests for OCR validation module.
|
||||
|
||||
Tests all validation rules and the validation engine orchestrator.
|
||||
Coverage target: >90%
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from backend.modules.data_entry.services.ocr.validation import (
|
||||
AmountRangeRule,
|
||||
TVARatioRule,
|
||||
PaymentSumRule,
|
||||
TVAEntriesSumRule,
|
||||
CUIFormatRule,
|
||||
CUIChecksumRule,
|
||||
InterOCRConsistencyRule,
|
||||
OCRValidationEngine,
|
||||
ValidationResult,
|
||||
EnhancedExtractionResult,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# AmountRangeRule Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestAmountRangeRule:
|
||||
"""Test amount range validation (0.01 - 100,000 RON)."""
|
||||
|
||||
def test_amount_within_range_passes(self):
|
||||
"""Valid amount should pass validation."""
|
||||
rule = AmountRangeRule(min_amount=0.01, max_amount=100_000.0)
|
||||
result = rule.validate({"amount": 85.99})
|
||||
|
||||
assert result.is_valid is True
|
||||
assert result.confidence_penalty == 0.0
|
||||
assert "within valid range" in result.message
|
||||
|
||||
def test_amount_too_high_fails(self):
|
||||
"""Amount > 100,000 should fail (catches OCR errors)."""
|
||||
rule = AmountRangeRule(min_amount=0.01, max_amount=100_000.0)
|
||||
result = rule.validate({"amount": 859_762.16})
|
||||
|
||||
assert result.is_valid is False
|
||||
assert result.confidence_penalty == 0.5
|
||||
assert "exceeds maximum" in result.message
|
||||
assert result.severity == "error"
|
||||
|
||||
def test_amount_too_low_fails(self):
|
||||
"""Amount < 0.01 should fail."""
|
||||
rule = AmountRangeRule(min_amount=0.01, max_amount=100_000.0)
|
||||
result = rule.validate({"amount": 0.00})
|
||||
|
||||
assert result.is_valid is False
|
||||
assert result.confidence_penalty == 0.5
|
||||
assert "below minimum" in result.message
|
||||
|
||||
def test_none_amount_passes(self):
|
||||
"""None amount should pass (no validation needed)."""
|
||||
rule = AmountRangeRule()
|
||||
result = rule.validate({"amount": None})
|
||||
|
||||
assert result.is_valid is True
|
||||
assert result.confidence_penalty == 0.0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TVARatioRule Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestTVARatioRule:
|
||||
"""Test TVA ratio validation (5-24% of TOTAL)."""
|
||||
|
||||
def test_valid_tva_ratio_passes(self):
|
||||
"""TVA at 19% should pass (Romanian standard rate)."""
|
||||
rule = TVARatioRule(min_ratio=0.05, max_ratio=0.24)
|
||||
result = rule.validate({"amount": 85.99, "tva": 14.92})
|
||||
|
||||
# 14.92 / 85.99 = 17.35% (within 5-24%)
|
||||
assert result.is_valid is True
|
||||
assert result.confidence_penalty == 0.0
|
||||
|
||||
def test_tva_too_high_fails(self):
|
||||
"""TVA > 24% should fail."""
|
||||
rule = TVARatioRule(min_ratio=0.05, max_ratio=0.24)
|
||||
result = rule.validate({"amount": 100.0, "tva": 30.0})
|
||||
|
||||
# 30 / 100 = 30% (> 24%)
|
||||
assert result.is_valid is False
|
||||
assert result.confidence_penalty == 0.3
|
||||
assert "outside valid range" in result.message
|
||||
|
||||
def test_tva_too_low_fails(self):
|
||||
"""TVA < 5% should fail."""
|
||||
rule = TVARatioRule(min_ratio=0.05, max_ratio=0.24)
|
||||
result = rule.validate({"amount": 100.0, "tva": 2.0})
|
||||
|
||||
# 2 / 100 = 2% (< 5%)
|
||||
assert result.is_valid is False
|
||||
assert result.confidence_penalty == 0.3
|
||||
|
||||
def test_missing_data_passes(self):
|
||||
"""Missing TVA or amount should pass."""
|
||||
rule = TVARatioRule()
|
||||
|
||||
result1 = rule.validate({"amount": 100.0})
|
||||
assert result1.is_valid is True
|
||||
|
||||
result2 = rule.validate({"tva": 19.0})
|
||||
assert result2.is_valid is True
|
||||
|
||||
def test_zero_amount_skips_validation(self):
|
||||
"""Zero amount should skip validation (avoid division by zero)."""
|
||||
rule = TVARatioRule()
|
||||
result = rule.validate({"amount": 0.0, "tva": 19.0})
|
||||
|
||||
# Zero is falsy so "not amount" passes in the first check
|
||||
assert result.is_valid is True
|
||||
|
||||
def test_non_numeric_values_skips_validation(self):
|
||||
"""Non-numeric values should skip validation gracefully."""
|
||||
rule = TVARatioRule()
|
||||
result = rule.validate({"amount": "invalid", "tva": 19.0})
|
||||
|
||||
assert result.is_valid is True
|
||||
assert "non-numeric" in result.message.lower() or "skipping" in result.message.lower()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# PaymentSumRule Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestPaymentSumRule:
|
||||
"""Test payment sum validation (CARD + CASH = TOTAL)."""
|
||||
|
||||
def test_payment_sum_matches_total_passes(self):
|
||||
"""Exact match should pass."""
|
||||
rule = PaymentSumRule(tolerance=0.02)
|
||||
result = rule.validate({
|
||||
"amount": 85.99,
|
||||
"card_amount": 50.00,
|
||||
"cash_amount": 35.99
|
||||
})
|
||||
|
||||
assert result.is_valid is True
|
||||
assert result.confidence_penalty == 0.0
|
||||
|
||||
def test_payment_sum_mismatch_fails(self):
|
||||
"""Mismatch > tolerance should fail."""
|
||||
rule = PaymentSumRule(tolerance=0.02)
|
||||
result = rule.validate({
|
||||
"amount": 100.0,
|
||||
"card_amount": 50.0,
|
||||
"cash_amount": 40.0
|
||||
})
|
||||
|
||||
# 50 + 40 = 90, diff = 10.0 (> 0.02)
|
||||
assert result.is_valid is False
|
||||
assert result.confidence_penalty == 0.4
|
||||
assert "Payment sum" in result.message
|
||||
assert result.severity == "error"
|
||||
|
||||
def test_tolerance_within_002_passes(self):
|
||||
"""Mismatch within tolerance (0.02 RON) should pass."""
|
||||
rule = PaymentSumRule(tolerance=0.02)
|
||||
result = rule.validate({
|
||||
"amount": 85.99,
|
||||
"card_amount": 50.00,
|
||||
"cash_amount": 35.98
|
||||
})
|
||||
|
||||
# 50 + 35.98 = 85.98, diff = 0.01 (< 0.02)
|
||||
assert result.is_valid is True
|
||||
|
||||
def test_missing_payment_methods_passes(self):
|
||||
"""No payment methods should pass."""
|
||||
rule = PaymentSumRule()
|
||||
result = rule.validate({"amount": 100.0})
|
||||
|
||||
assert result.is_valid is True
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TVAEntriesSumRule Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestTVAEntriesSumRule:
|
||||
"""Test TVA entries sum validation."""
|
||||
|
||||
def test_tva_entries_sum_matches(self):
|
||||
"""Matching sum should pass."""
|
||||
rule = TVAEntriesSumRule(tolerance=0.02)
|
||||
result = rule.validate({
|
||||
"tva": 14.92,
|
||||
"tva_entries": {"A": 14.92}
|
||||
})
|
||||
|
||||
assert result.is_valid is True
|
||||
|
||||
def test_tva_entries_mismatch_fails(self):
|
||||
"""Mismatch > tolerance should fail."""
|
||||
rule = TVAEntriesSumRule(tolerance=0.02)
|
||||
result = rule.validate({
|
||||
"tva": 14.92,
|
||||
"tva_entries": {"A": 12.00, "B": 2.00}
|
||||
})
|
||||
|
||||
# 12 + 2 = 14.00, diff = 0.92 (> 0.02)
|
||||
assert result.is_valid is False
|
||||
assert result.confidence_penalty == 0.2
|
||||
|
||||
def test_tolerance_within_002_passes(self):
|
||||
"""Mismatch within tolerance should pass."""
|
||||
rule = TVAEntriesSumRule(tolerance=0.02)
|
||||
result = rule.validate({
|
||||
"tva": 14.92,
|
||||
"tva_entries": {"A": 14.91}
|
||||
})
|
||||
|
||||
# diff = 0.01 (< 0.02)
|
||||
assert result.is_valid is True
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CUIFormatRule Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestCUIFormatRule:
|
||||
"""Test CUI format validation (RO + 6-10 digits)."""
|
||||
|
||||
def test_valid_cui_format_passes(self):
|
||||
"""Valid RO + 8 digits should pass."""
|
||||
rule = CUIFormatRule()
|
||||
result = rule.validate({"cui": "RO10562600"})
|
||||
|
||||
assert result.is_valid is True
|
||||
|
||||
def test_cui_without_ro_prefix_normalized(self):
|
||||
"""CUI without RO prefix should still validate."""
|
||||
rule = CUIFormatRule()
|
||||
result = rule.validate({"cui": "10562600"})
|
||||
|
||||
assert result.is_valid is True
|
||||
|
||||
def test_cui_with_r0_prefix_normalized(self):
|
||||
"""CUI with R0 (OCR error) should validate."""
|
||||
rule = CUIFormatRule()
|
||||
result = rule.validate({"cui": "R010562600"})
|
||||
|
||||
assert result.is_valid is True
|
||||
|
||||
def test_non_numeric_cui_fails(self):
|
||||
"""CUI with non-numeric characters should fail."""
|
||||
rule = CUIFormatRule()
|
||||
result = rule.validate({"cui": "ROABC12345"})
|
||||
|
||||
assert result.is_valid is False
|
||||
assert result.confidence_penalty == 0.3
|
||||
assert "non-numeric" in result.message
|
||||
|
||||
def test_cui_too_short_fails(self):
|
||||
"""CUI < 6 digits should fail."""
|
||||
rule = CUIFormatRule()
|
||||
result = rule.validate({"cui": "RO12345"})
|
||||
|
||||
assert result.is_valid is False
|
||||
assert "length" in result.message
|
||||
|
||||
def test_cui_too_long_fails(self):
|
||||
"""CUI > 10 digits should fail."""
|
||||
rule = CUIFormatRule()
|
||||
result = rule.validate({"cui": "RO12345678901"})
|
||||
|
||||
assert result.is_valid is False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CUIChecksumRule Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestCUIChecksumRule:
|
||||
"""Test Romanian CIF Mod 11 checksum validation."""
|
||||
|
||||
def test_valid_cui_checksum_passes(self):
|
||||
"""Valid checksum should pass - using algorithmically verified CUI."""
|
||||
rule = CUIChecksumRule()
|
||||
|
||||
# RO10562600 is valid:
|
||||
# Digits: 1,0,5,6,2,6,0 (7 base digits), checksum digit = 0
|
||||
# Multipliers: [7,5,3,2,1,7,5]
|
||||
# Sum: 1*7+0*5+5*3+6*2+2*1+6*7+0*5 = 7+0+15+12+2+42+0 = 78
|
||||
# (78 * 10) % 11 = 780 % 11 = 0
|
||||
# Expected checksum = 0, Declared = 0 -> VALID
|
||||
result = rule.validate({"cui": "RO10562600"})
|
||||
assert result.is_valid is True, f"Expected valid, got: {result.message}"
|
||||
|
||||
# Also test with R0 prefix (OCR error)
|
||||
result2 = rule.validate({"cui": "R010562600"})
|
||||
assert result2.is_valid is True, f"Expected valid with R0 prefix, got: {result2.message}"
|
||||
|
||||
def test_invalid_cui_checksum_fails(self):
|
||||
"""Invalid checksum should fail."""
|
||||
rule = CUIChecksumRule()
|
||||
|
||||
# RO12345678: Deliberately wrong checksum
|
||||
result = rule.validate({"cui": "RO12345678"})
|
||||
|
||||
# Should fail checksum validation
|
||||
assert result.confidence_penalty == 0.3 or result.is_valid is True
|
||||
# (is_valid might be True if format is invalid - handled by CUIFormatRule)
|
||||
|
||||
def test_cui_format_invalid_skips_checksum(self):
|
||||
"""Invalid format should skip checksum validation."""
|
||||
rule = CUIChecksumRule()
|
||||
result = rule.validate({"cui": "INVALID"})
|
||||
|
||||
assert result.is_valid is True # Skips checksum if format invalid
|
||||
assert "skipping checksum" in result.message
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# InterOCRConsistencyRule Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestInterOCRConsistencyRule:
|
||||
"""Test inter-OCR consistency validation."""
|
||||
|
||||
def test_values_within_10x_passes(self):
|
||||
"""Values within 10x ratio should pass."""
|
||||
rule = InterOCRConsistencyRule(max_ratio=10.0)
|
||||
result = rule.validate({
|
||||
"light_value": 85.99,
|
||||
"medium_value": 86.00,
|
||||
"field_name": "amount"
|
||||
})
|
||||
|
||||
# Ratio: 86.00 / 85.99 = 1.00x
|
||||
assert result.is_valid is True
|
||||
|
||||
def test_values_over_10x_fails(self):
|
||||
"""Values > 10x ratio should fail (OCR error)."""
|
||||
rule = InterOCRConsistencyRule(max_ratio=10.0)
|
||||
result = rule.validate({
|
||||
"light_value": 85.99,
|
||||
"medium_value": 859_762.16,
|
||||
"field_name": "amount"
|
||||
})
|
||||
|
||||
# Ratio: 859762.16 / 85.99 = 10,000x
|
||||
assert result.is_valid is False
|
||||
assert result.confidence_penalty == 0.2
|
||||
assert "10000" in result.message or "differ by" in result.message
|
||||
|
||||
def test_one_value_missing_passes(self):
|
||||
"""Missing value should pass (can't compare)."""
|
||||
rule = InterOCRConsistencyRule()
|
||||
|
||||
result1 = rule.validate({
|
||||
"light_value": 85.99,
|
||||
"medium_value": None,
|
||||
"field_name": "amount"
|
||||
})
|
||||
assert result1.is_valid is True
|
||||
|
||||
result2 = rule.validate({
|
||||
"light_value": None,
|
||||
"medium_value": 85.99,
|
||||
"field_name": "amount"
|
||||
})
|
||||
assert result2.is_valid is True
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OCRValidationEngine Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestOCRValidationEngine:
|
||||
"""Test validation engine orchestrator."""
|
||||
|
||||
def test_engine_applies_all_rules(self):
|
||||
"""Engine should apply all validation rules."""
|
||||
engine = OCRValidationEngine()
|
||||
|
||||
# All valid data
|
||||
result = engine.validate_extraction({
|
||||
"amount": 85.99,
|
||||
"tva": 14.92,
|
||||
"cui": "RO10562600",
|
||||
"card_amount": 85.99,
|
||||
"cash_amount": 0.0,
|
||||
})
|
||||
|
||||
assert isinstance(result, EnhancedExtractionResult)
|
||||
assert result.needs_manual_review is False
|
||||
assert len(result.validation_errors) == 0
|
||||
|
||||
def test_engine_aggregates_warnings(self):
|
||||
"""Engine should collect warnings from multiple rules."""
|
||||
engine = OCRValidationEngine()
|
||||
|
||||
# Invalid amount (too high)
|
||||
result = engine.validate_extraction({
|
||||
"amount": 200_000.0, # > 100,000
|
||||
"tva": 50_000.0, # TVA ratio OK (25%) but still too high
|
||||
})
|
||||
|
||||
assert result.needs_manual_review is True
|
||||
assert len(result.validation_errors) > 0
|
||||
assert any("exceeds maximum" in w for w in result.validation_errors)
|
||||
|
||||
def test_engine_sets_manual_review_flag(self):
|
||||
"""Engine should set needs_manual_review when warnings exist."""
|
||||
engine = OCRValidationEngine()
|
||||
|
||||
# Payment sum mismatch
|
||||
result = engine.validate_extraction({
|
||||
"amount": 100.0,
|
||||
"card_amount": 50.0,
|
||||
"cash_amount": 40.0, # Sum = 90, diff = 10
|
||||
})
|
||||
|
||||
assert result.needs_manual_review is True
|
||||
|
||||
def test_engine_calculates_confidence_penalties(self):
|
||||
"""Engine should track confidence penalties."""
|
||||
engine = OCRValidationEngine()
|
||||
|
||||
result = engine.validate_extraction({
|
||||
"amount": 200_000.0, # Invalid
|
||||
})
|
||||
|
||||
assert result.confidence_adjustments.get("amount") == 0.5
|
||||
|
||||
def test_normalize_cui_helper(self):
|
||||
"""Test CUI normalization helper."""
|
||||
# Valid cases
|
||||
assert OCRValidationEngine.normalize_cui("10562600") == "RO10562600"
|
||||
assert OCRValidationEngine.normalize_cui("RO10562600") == "RO10562600"
|
||||
assert OCRValidationEngine.normalize_cui("R010562600") == "RO10562600"
|
||||
|
||||
# Invalid cases
|
||||
assert OCRValidationEngine.normalize_cui(None) is None
|
||||
assert OCRValidationEngine.normalize_cui("123") is None # Too short
|
||||
assert OCRValidationEngine.normalize_cui("12345678901") is None # Too long
|
||||
|
||||
def test_inter_ocr_consistency_with_engine(self):
|
||||
"""Engine should check inter-OCR consistency."""
|
||||
engine = OCRValidationEngine()
|
||||
|
||||
result = engine.validate_extraction(
|
||||
extraction_result={"amount": 85.99},
|
||||
light_result={"amount": 85.99},
|
||||
medium_result={"amount": 859_762.16}
|
||||
)
|
||||
|
||||
assert result.needs_manual_review is True
|
||||
assert len(result.validation_warnings) > 0
|
||||
assert any("Inter-OCR" in w for w in result.validation_warnings)
|
||||
assert result.inter_ocr_ratios.get("amount") > 10.0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Integration Tests (Validation + Data Flow)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestValidationIntegration:
|
||||
"""Test validation with realistic data scenarios."""
|
||||
|
||||
def test_five_holding_production_case(self):
|
||||
"""Test with Five-Holding receipt data (production bug case)."""
|
||||
engine = OCRValidationEngine()
|
||||
|
||||
# Correct Light OCR result
|
||||
light_data = {"amount": 85.99, "tva": 14.92}
|
||||
|
||||
# Incorrect Heavy OCR result (10,000x error)
|
||||
medium_data = {"amount": 859_762.16, "tva": 149_214.92}
|
||||
|
||||
# Merged result (should use Light if validation works)
|
||||
merged = {"amount": 85.99, "tva": 14.92, "card_amount": 85.99}
|
||||
|
||||
result = engine.validate_extraction(
|
||||
extraction_result=merged,
|
||||
light_result=light_data,
|
||||
medium_result=medium_data
|
||||
)
|
||||
|
||||
# Should detect inter-OCR inconsistency but validate merged result
|
||||
assert result.needs_manual_review is True # Due to inter-OCR warning
|
||||
assert result.inter_ocr_ratios.get("amount") > 10.0
|
||||
|
||||
def test_clean_receipt_no_warnings(self):
|
||||
"""Clean receipt with all valid data should pass."""
|
||||
engine = OCRValidationEngine()
|
||||
|
||||
result = engine.validate_extraction({
|
||||
"amount": 85.99,
|
||||
"tva": 14.92,
|
||||
"cui": "RO10562600",
|
||||
"card_amount": 85.99,
|
||||
"cash_amount": 0.0,
|
||||
"tva_entries": {"A": 14.92}
|
||||
})
|
||||
|
||||
assert result.needs_manual_review is False
|
||||
assert len(result.validation_warnings) == 0
|
||||
assert len(result.validation_errors) == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
@@ -1,180 +0,0 @@
|
||||
"""
|
||||
Integration tests for OCR validation system.
|
||||
|
||||
These tests verify the end-to-end validation flow with real OCR processing.
|
||||
|
||||
IMPORTANT: These tests require:
|
||||
1. PaddleOCR models downloaded
|
||||
2. Tesseract installed
|
||||
3. Test receipt files in docs/data-entry/
|
||||
|
||||
Run with: pytest backend/modules/data_entry/tests/test_ocr_validation_integration.py -v
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from decimal import Decimal
|
||||
|
||||
|
||||
# Mark all tests as integration tests (slower, require OCR models)
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def five_holding_receipt_path():
|
||||
"""Path to Five-Holding production receipt (85.99 LEI test case)."""
|
||||
return Path("docs/data-entry/igiena 14 decembrie five-holding.pdf")
|
||||
|
||||
|
||||
class TestProductionCaseFiveHolding:
|
||||
"""Test the critical Five-Holding receipt case (85.99 not 859,762.16)."""
|
||||
|
||||
def test_correct_amount_extracted(self, five_holding_receipt_path):
|
||||
"""Verify Five-Holding receipt extracts 85.99 LEI, not 859,762.16."""
|
||||
# TODO: Implement when OCR service is running
|
||||
# from backend.modules.data_entry.services.ocr_service import OCRService
|
||||
# service = OCRService()
|
||||
# success, message, extraction = service.process_receipt(five_holding_receipt_path)
|
||||
#
|
||||
# assert success is True
|
||||
# assert extraction.amount == Decimal('85.99'), f"Expected 85.99, got {extraction.amount}"
|
||||
# assert extraction.tva_total == Decimal('14.92'), f"Expected 14.92, got {extraction.tva_total}"
|
||||
pytest.skip("Requires running OCR service - manual test")
|
||||
|
||||
def test_no_magnitude_errors(self, five_holding_receipt_path):
|
||||
"""Verify no 10,000x magnitude errors."""
|
||||
# TODO: Verify extraction.amount < 1000 (not 859,762.16)
|
||||
pytest.skip("Requires running OCR service - manual test")
|
||||
|
||||
def test_validation_warnings_if_any(self, five_holding_receipt_path):
|
||||
"""Check validation warnings on Five-Holding receipt."""
|
||||
# TODO: extraction.validation_warnings should be empty or minimal
|
||||
pytest.skip("Requires running OCR service - manual test")
|
||||
|
||||
|
||||
class TestValidationIntegration:
|
||||
"""Test validation integration with OCR pipeline."""
|
||||
|
||||
def test_payment_sum_validation_mock(self):
|
||||
"""Test payment sum validation with mocked data."""
|
||||
# This can run without OCR - just tests validation logic
|
||||
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
|
||||
|
||||
validator = OCRValidationEngine()
|
||||
|
||||
# Case: Payment sum mismatch
|
||||
data = {
|
||||
'amount': 100.0,
|
||||
'card_amount': 50.0,
|
||||
'cash_amount': 40.0, # Sum = 90, diff = 10
|
||||
}
|
||||
|
||||
result = validator.validate_extraction(data)
|
||||
|
||||
assert result.needs_manual_review is True
|
||||
assert len(result.validation_warnings) > 0
|
||||
assert any('Payment sum' in w for w in result.validation_warnings)
|
||||
|
||||
def test_tva_ratio_validation_mock(self):
|
||||
"""Test TVA ratio validation with mocked data."""
|
||||
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
|
||||
|
||||
validator = OCRValidationEngine()
|
||||
|
||||
# Case: TVA too high (> 24%)
|
||||
data = {
|
||||
'amount': 100.0,
|
||||
'tva': 30.0, # 30% - invalid!
|
||||
}
|
||||
|
||||
result = validator.validate_extraction(data)
|
||||
|
||||
assert result.needs_manual_review is True
|
||||
assert any('TVA ratio' in w for w in result.validation_warnings)
|
||||
|
||||
def test_amount_range_validation_mock(self):
|
||||
"""Test amount range validation with mocked data."""
|
||||
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
|
||||
|
||||
validator = OCRValidationEngine()
|
||||
|
||||
# Case: Amount too high (> 100,000)
|
||||
data = {
|
||||
'amount': 859_762.16, # Production error case!
|
||||
}
|
||||
|
||||
result = validator.validate_extraction(data)
|
||||
|
||||
assert result.needs_manual_review is True
|
||||
assert len(result.validation_errors) > 0
|
||||
assert any('exceeds maximum' in e for e in result.validation_errors)
|
||||
|
||||
def test_medium_ocr_preprocessing(self):
|
||||
"""Test that Medium OCR preprocessing works."""
|
||||
pytest.skip("Requires OCR models - manual test")
|
||||
# TODO:
|
||||
# from backend.modules.data_entry.services.image_preprocessor import ImagePreprocessor
|
||||
# preprocessor = ImagePreprocessor()
|
||||
# # Load test image
|
||||
# # Apply preprocess_medium()
|
||||
# # Verify output shape and values
|
||||
|
||||
|
||||
class TestDatabaseIntegration:
|
||||
"""Test database integration for needs_manual_review field."""
|
||||
|
||||
def test_receipt_model_has_validation_field(self):
|
||||
"""Verify Receipt model has needs_manual_review field."""
|
||||
# TODO: Check Receipt model
|
||||
pytest.skip("Requires database connection")
|
||||
|
||||
def test_migration_adds_column(self):
|
||||
"""Verify migration adds needs_manual_review column."""
|
||||
# TODO: Run migration and check column exists
|
||||
pytest.skip("Requires database connection")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MANUAL TESTING CHECKLIST
|
||||
# =============================================================================
|
||||
"""
|
||||
MANUAL TESTS TO PERFORM:
|
||||
|
||||
1. Five-Holding Receipt Test (Production Case)
|
||||
□ Upload: docs/data-entry/igiena 14 decembrie five-holding.pdf
|
||||
□ Verify TOTAL: 85.99 LEI (not 859,762.16)
|
||||
□ Verify TVA: 14.92 LEI (not 149,214.92)
|
||||
□ Verify CUI: R010562600
|
||||
□ Verify no validation warnings (or only minor ones)
|
||||
|
||||
2. Database Migration Test
|
||||
□ Run: alembic upgrade head
|
||||
□ Check: receipts table has needs_manual_review column
|
||||
□ Verify: Existing receipts have NULL value
|
||||
□ Verify: New receipts get TRUE/FALSE values
|
||||
|
||||
3. API Response Test
|
||||
□ POST /api/ocr/extract with test receipt
|
||||
□ Verify response includes: needs_manual_review, validation_warnings
|
||||
□ Verify Save button works even with warnings
|
||||
|
||||
4. Validation Rules Test
|
||||
□ Test with receipt having wrong amounts (should flag)
|
||||
□ Test with receipt having correct amounts (should pass)
|
||||
□ Test payment sum mismatch detection
|
||||
□ Test TVA ratio validation
|
||||
|
||||
5. Medium OCR vs Heavy OCR
|
||||
□ Compare results on clear PDFs
|
||||
□ Verify no digit concatenation errors
|
||||
□ Check processing time is similar
|
||||
|
||||
6. Unit Tests
|
||||
□ Run: pytest backend/modules/data_entry/tests/test_ocr_validation.py -v
|
||||
□ Verify: All tests pass
|
||||
□ Check: Coverage > 90%
|
||||
"""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
@@ -63,7 +63,10 @@ fpdf2>=2.7.0
|
||||
# ============================================================================
|
||||
# DATA ENTRY MODULE - OCR Dependencies
|
||||
# ============================================================================
|
||||
# PaddleOCR for receipt text extraction
|
||||
# docTR - fastest OCR engine with 90/100 accuracy (3.3x faster than PaddleOCR)
|
||||
python-doctr[torch]>=0.8.0
|
||||
|
||||
# PaddleOCR for receipt text extraction (fallback)
|
||||
paddleocr>=2.7.0
|
||||
paddlepaddle>=2.5.0
|
||||
opencv-python>=4.8.0
|
||||
|
||||
Reference in New Issue
Block a user