feat(ocr): Add docTR OCR engine with metrics infrastructure
Add docTR as primary OCR engine with 2-tier sequential processing, OCR metrics tracking, and simplified engine selection. Features: - docTR OCR engine with light+medium preprocessing tiers - doctr_plus mode with early exit optimization (~65% fast path) - OCR metrics dashboard with per-engine statistics - User OCR preference persistence - Parallel worker pool for OCR processing - Cross-validation for extraction quality Engine options: tesseract, doctr, doctr_plus (recommended), paddleocr 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -460,6 +460,8 @@ venv.bak/
|
||||
venv.bak/
|
||||
venv/
|
||||
venv/
|
||||
venv-win/
|
||||
ocr_benchmark_*.json
|
||||
wallet/
|
||||
wheels/
|
||||
wheels/
|
||||
@@ -520,5 +522,6 @@ backend/data/cache/*.db
|
||||
backend/data/receipts/*.db
|
||||
backend/data/telegram/*.db
|
||||
backend/data/receipts/uploads/*
|
||||
backend/data/ocr_queue/
|
||||
!backend/data/*/.gitkeep
|
||||
|
||||
|
||||
@@ -1,111 +0,0 @@
|
||||
# Implementation Prompt: Persistent OCR Worker + Engine Selection + Job Queue
|
||||
|
||||
## Context
|
||||
|
||||
Ești într-un git worktree separat (`feature/ocr-persistent-worker-queue`) pentru implementarea sistemului OCR îmbunătățit.
|
||||
|
||||
**Branch**: `feature/ocr-persistent-worker-queue`
|
||||
**Bază**: `main` (commit 77d74a9)
|
||||
|
||||
## Task Principal
|
||||
|
||||
Implementează sistemul OCR cu:
|
||||
1. **Worker Persistent** - PaddleOCR încărcat O DATĂ la startup (nu 30s per request)
|
||||
2. **Engine Selection** - Parametru API pentru `paddleocr`, `tesseract`, sau `auto`
|
||||
3. **Tesseract Optimizat** - Fix inversare imagine + OEM 1 + multi-PSM
|
||||
4. **Windows IIS Compatible** - Funcționează cu NSSM service management
|
||||
5. **SQLite Job Queue** - Procesare secvențială cu poziție în coadă și estimare timp
|
||||
|
||||
## Planul Detaliat
|
||||
|
||||
Citește planul complet în:
|
||||
```
|
||||
/home/marius/.claude/plans/serene-growing-newell.md
|
||||
```
|
||||
|
||||
## Fișiere de Creat (Noi)
|
||||
|
||||
1. `backend/modules/data_entry/services/ocr/__init__.py`
|
||||
2. `backend/modules/data_entry/services/ocr/ocr_worker_pool.py` - Manager ProcessPoolExecutor
|
||||
3. `backend/modules/data_entry/services/ocr/ocr_worker_process.py` - Cod pentru worker process
|
||||
4. `backend/modules/data_entry/services/ocr/tesseract_engine.py` - Tesseract optimizat
|
||||
5. `backend/modules/data_entry/services/ocr/job_queue.py` - SQLite Job Queue Manager
|
||||
6. `backend/modules/data_entry/services/ocr/job_worker.py` - Background worker pentru coadă
|
||||
7. `data/ocr_queue/` - Director pentru fișiere în coadă
|
||||
|
||||
## Fișiere de Modificat
|
||||
|
||||
1. `backend/modules/data_entry/schemas/ocr.py` - Noi scheme OCRJobResponse
|
||||
2. `backend/modules/data_entry/routers/ocr.py` - Endpoint-uri job queue + engine param
|
||||
3. `backend/modules/data_entry/services/ocr_service.py` - Folosire worker pool
|
||||
4. `backend/modules/data_entry/services/image_preprocessor.py` - Fix inversare Tesseract
|
||||
5. `backend/main.py` - Startup/shutdown hooks pentru worker pool + job worker
|
||||
|
||||
## Ordine Implementare (Faze)
|
||||
|
||||
### Faza 1: Infrastructură Worker
|
||||
1. Creare `services/ocr/__init__.py`
|
||||
2. Creare `ocr_worker_pool.py`
|
||||
3. Creare `ocr_worker_process.py`
|
||||
4. Creare `tesseract_engine.py`
|
||||
|
||||
### Faza 2: SQLite Job Queue
|
||||
5. Creare `job_queue.py` cu schema SQLite
|
||||
6. Creare `job_worker.py` background task
|
||||
7. Creare director `data/ocr_queue/`
|
||||
|
||||
### Faza 3: API Integration
|
||||
8. Update `schemas/ocr.py` - adăugare OCRJobResponse, OCRJobStatus
|
||||
9. Update `routers/ocr.py` - modificare /extract, adăugare /jobs/{id}
|
||||
10. Update `main.py` - startup job worker
|
||||
|
||||
### Faza 4: Tesseract Fix
|
||||
11. Fix inversare în `image_preprocessor.py`
|
||||
|
||||
### Faza 5: Frontend (opțional)
|
||||
12. Update componenta OCR pentru polling
|
||||
13. Afișare poziție coadă și estimare timp
|
||||
|
||||
## Criterii Succes
|
||||
|
||||
- [ ] Prima cerere OCR după restart: <5s (nu 30s)
|
||||
- [ ] 10 cereri consecutive fără memory leak
|
||||
- [ ] `?engine=tesseract` produce text lizibil
|
||||
- [ ] `?engine=paddleocr` funcționează independent
|
||||
- [ ] POST /extract returnează instant (<100ms) cu job_id
|
||||
- [ ] GET /jobs/{id} returnează poziție corectă în coadă
|
||||
- [ ] Estimare timp ±30% din realitate
|
||||
- [ ] Jobs expiră și se șterg după 24h
|
||||
- [ ] Windows: stop service → no orphan python.exe
|
||||
|
||||
## Comenzi Utile
|
||||
|
||||
```bash
|
||||
# Start backend (development)
|
||||
cd /mnt/e/proiecte/ab-worktrees/ocr-persistent-worker-queue
|
||||
./start-test.sh
|
||||
|
||||
# Verificare OCR
|
||||
curl -X POST "http://localhost:8001/api/data-entry/ocr/extract?engine=auto" \
|
||||
-H "Authorization: Bearer $TOKEN" \
|
||||
-F "file=@test-receipt.jpg"
|
||||
|
||||
# Poll job status
|
||||
curl "http://localhost:8001/api/data-entry/ocr/jobs/{job_id}" \
|
||||
-H "Authorization: Bearer $TOKEN"
|
||||
|
||||
# Queue status
|
||||
curl "http://localhost:8001/api/data-entry/ocr/queue/status" \
|
||||
-H "Authorization: Bearer $TOKEN"
|
||||
```
|
||||
|
||||
## Notă Importantă
|
||||
|
||||
Acest worktree este izolat de main. După finalizare, SI DOAR DUPA TOATE TESTELE, OFERA POSIBILITATEA UTILIZATORULUI (NU FACE TU AUTOMAT):
|
||||
1. Commit toate schimbările
|
||||
2. Push branch-ul: `git push -u origin feature/ocr-persistent-worker-queue`
|
||||
3. Creează PR către main
|
||||
|
||||
## Start
|
||||
|
||||
Începe cu Faza 1 - creează directorul `services/ocr/` și primul fișier `__init__.py`.
|
||||
@@ -104,6 +104,34 @@ MAX_UPLOAD_SIZE_MB=10
|
||||
TEST_COMPANY_ID=110
|
||||
TEST_COMPANY_SCHEMA=MARIUSM_AUTO
|
||||
|
||||
# ============================================================================
|
||||
# OCR ENGINE CONFIGURATION
|
||||
# ============================================================================
|
||||
# Control which OCR engines are loaded at startup.
|
||||
# Disabling engines saves memory but limits available OCR modes.
|
||||
|
||||
# Enable/disable PaddleOCR (set to 'false' to save ~800MB RAM)
|
||||
# When disabled: 'paddleocr' engine unavailable
|
||||
OCR_ENABLE_PADDLEOCR=true
|
||||
|
||||
# Enable/disable Tesseract (set to 'false' to save ~50MB RAM)
|
||||
# When disabled: 'tesseract' engine unavailable
|
||||
OCR_ENABLE_TESSERACT=true
|
||||
|
||||
# Default OCR engine when not specified in request
|
||||
# Options: tesseract, doctr, doctr_plus, paddleocr
|
||||
# Recommended: doctr_plus (2-tier sequential with early exit)
|
||||
OCR_DEFAULT_ENGINE=doctr_plus
|
||||
|
||||
# OCR Worker Pool Configuration
|
||||
# Number of parallel OCR workers (each loads ~1GB for docTR)
|
||||
# Recommended: 2 for 8GB RAM, 3 for 16GB RAM
|
||||
OCR_WORKERS=2
|
||||
|
||||
# Max tasks per worker before restart (0 = no restart, saves 40-60s warmup time)
|
||||
# Set to 0 for testing, 10-20 for production (prevents memory leaks)
|
||||
OCR_MAX_TASKS_PER_CHILD=0
|
||||
|
||||
# ============================================================================
|
||||
# TELEGRAM MODULE - BOT CONFIGURATION (REQUIRED for Telegram features)
|
||||
# ============================================================================
|
||||
|
||||
@@ -112,6 +112,40 @@ DATA_ENTRY_SQLITE_DATABASE_PATH=data/receipts/receipts.db
|
||||
DATA_ENTRY_UPLOAD_PATH=data/receipts/uploads
|
||||
DATA_ENTRY_MAX_UPLOAD_SIZE_MB=10
|
||||
|
||||
# ============================================================================
|
||||
# OCR ENGINE CONFIGURATION
|
||||
# ============================================================================
|
||||
# Control which OCR engines are loaded at startup.
|
||||
# Disabling engines saves memory but limits available OCR modes.
|
||||
|
||||
# Enable/disable PaddleOCR (set to 'false' to save ~800MB RAM)
|
||||
# When disabled: 'paddleocr' engine unavailable
|
||||
OCR_ENABLE_PADDLEOCR=true
|
||||
|
||||
# Enable/disable Tesseract (set to 'false' to save ~50MB RAM)
|
||||
# When disabled: 'tesseract' engine unavailable
|
||||
OCR_ENABLE_TESSERACT=true
|
||||
|
||||
# Default OCR engine when not specified in request
|
||||
# Options: tesseract, doctr, doctr_plus, paddleocr
|
||||
# Recommended: doctr_plus (2-tier sequential with early exit, ~7.5s avg)
|
||||
OCR_DEFAULT_ENGINE=doctr_plus
|
||||
|
||||
# Active OCR engines shown in frontend dropdown (comma-separated)
|
||||
# Options: tesseract, doctr, doctr_plus, paddleocr
|
||||
# doctr_plus: 73.3% perfect, 7.5s avg, 65% fast path (recommended)
|
||||
# doctr: 63.3% perfect, simpler but faster
|
||||
OCR_ACTIVE_ENGINES=tesseract,doctr,doctr_plus,paddleocr
|
||||
|
||||
# OCR Worker Pool Configuration
|
||||
# Number of parallel OCR workers (each loads ~1GB for docTR)
|
||||
# Recommended: 2 for 8GB RAM, 3 for 16GB RAM
|
||||
OCR_WORKERS=2
|
||||
|
||||
# Max tasks per worker before restart (0 = no restart, saves 40-60s warmup time)
|
||||
# Set to 0 for testing, 10-20 for production (prevents memory leaks)
|
||||
OCR_MAX_TASKS_PER_CHILD=0
|
||||
|
||||
# ============================================================================
|
||||
# TELEGRAM MODULE - BOT CONFIGURATION (REQUIRED for Telegram features)
|
||||
# ============================================================================
|
||||
|
||||
@@ -96,6 +96,35 @@ SQLITE_DATABASE_PATH=data/receipts/receipts_prod.db
|
||||
UPLOAD_PATH=data/receipts/uploads
|
||||
MAX_UPLOAD_SIZE_MB=10
|
||||
|
||||
# ============================================================================
|
||||
# OCR ENGINE CONFIGURATION
|
||||
# ============================================================================
|
||||
# Control which OCR engines are loaded at startup.
|
||||
# Disabling engines saves memory but limits available OCR modes.
|
||||
|
||||
# Enable/disable PaddleOCR (set to 'false' to save ~800MB RAM)
|
||||
# When disabled: 'paddleocr' engine unavailable
|
||||
# PRODUCTION: Set based on server memory availability
|
||||
OCR_ENABLE_PADDLEOCR=false
|
||||
|
||||
# Enable/disable Tesseract (set to 'false' to save ~50MB RAM)
|
||||
# When disabled: 'tesseract' engine unavailable
|
||||
OCR_ENABLE_TESSERACT=true
|
||||
|
||||
# Default OCR engine when not specified in request
|
||||
# Options: tesseract, doctr, doctr_plus, paddleocr
|
||||
# Recommended: doctr_plus (2-tier sequential with early exit)
|
||||
OCR_DEFAULT_ENGINE=doctr_plus
|
||||
|
||||
# OCR Worker Pool Configuration
|
||||
# Number of parallel OCR workers (each loads ~1GB for docTR)
|
||||
# Recommended: 2 for 8GB RAM, 3 for 16GB RAM
|
||||
OCR_WORKERS=2
|
||||
|
||||
# Max tasks per worker before restart (0 = no restart, saves 40-60s warmup time)
|
||||
# Set to 0 for testing, 10-20 for production (prevents memory leaks)
|
||||
OCR_MAX_TASKS_PER_CHILD=0
|
||||
|
||||
# ============================================================================
|
||||
# TELEGRAM MODULE - BOT CONFIGURATION (REQUIRED for Telegram features)
|
||||
# ============================================================================
|
||||
|
||||
@@ -105,6 +105,34 @@ MAX_UPLOAD_SIZE_MB=10
|
||||
TEST_COMPANY_ID=110
|
||||
TEST_COMPANY_SCHEMA=MARIUSM_AUTO
|
||||
|
||||
# ============================================================================
|
||||
# OCR ENGINE CONFIGURATION
|
||||
# ============================================================================
|
||||
# Control which OCR engines are loaded at startup.
|
||||
# Disabling engines saves memory but limits available OCR modes.
|
||||
|
||||
# Enable/disable PaddleOCR (set to 'false' to save ~800MB RAM)
|
||||
# When disabled: 'paddleocr' engine unavailable
|
||||
OCR_ENABLE_PADDLEOCR=false
|
||||
|
||||
# Enable/disable Tesseract (set to 'false' to save ~50MB RAM)
|
||||
# When disabled: 'tesseract' engine unavailable
|
||||
OCR_ENABLE_TESSERACT=true
|
||||
|
||||
# Default OCR engine when not specified in request
|
||||
# Options: tesseract, doctr, doctr_plus, paddleocr
|
||||
# Recommended: doctr_plus (2-tier sequential with early exit)
|
||||
OCR_DEFAULT_ENGINE=doctr_plus
|
||||
|
||||
# OCR Worker Pool Configuration
|
||||
# Number of parallel OCR workers (each loads ~1GB for docTR)
|
||||
# Recommended: 2 for 8GB RAM, 3 for 16GB RAM
|
||||
OCR_WORKERS=2
|
||||
|
||||
# Max tasks per worker before restart (0 = no restart, saves 40-60s warmup time)
|
||||
# Set to 0 for testing, 10-20 for production (prevents memory leaks)
|
||||
OCR_MAX_TASKS_PER_CHILD=0
|
||||
|
||||
# ============================================================================
|
||||
# TELEGRAM MODULE - BOT CONFIGURATION (REQUIRED for Telegram features)
|
||||
# ============================================================================
|
||||
|
||||
168
backend/TEST-OCR-WINDOWS.bat
Normal file
168
backend/TEST-OCR-WINDOWS.bat
Normal file
@@ -0,0 +1,168 @@
|
||||
@echo off
|
||||
setlocal enabledelayedexpansion
|
||||
|
||||
cd /d "%~dp0"
|
||||
|
||||
REM Parse command line arguments for worker counts
|
||||
REM Usage: TEST-OCR-WINDOWS.bat [worker_counts...]
|
||||
REM Examples:
|
||||
REM TEST-OCR-WINDOWS.bat -> tests 1,2,3 workers (default)
|
||||
REM TEST-OCR-WINDOWS.bat 1 -> tests only 1 worker
|
||||
REM TEST-OCR-WINDOWS.bat 3 6 -> tests 3 and 6 workers
|
||||
REM TEST-OCR-WINDOWS.bat 1 2 3 4 5 6 -> tests all
|
||||
|
||||
set "WORKER_LIST=%*"
|
||||
if "%WORKER_LIST%"=="" set "WORKER_LIST=1 2 3"
|
||||
|
||||
echo.
|
||||
echo ==========================================
|
||||
echo OCR Benchmark - Windows (Workers: %WORKER_LIST%)
|
||||
echo ==========================================
|
||||
echo.
|
||||
|
||||
REM Check if Poppler is installed
|
||||
where pdftoppm >nul 2>&1
|
||||
if errorlevel 1 (
|
||||
echo Checking for Poppler...
|
||||
if exist "E:\poppler" (
|
||||
for /r "E:\poppler" %%i in (pdftoppm.exe) do (
|
||||
set "POPPLER_BIN=%%~dpi"
|
||||
goto :found_poppler
|
||||
)
|
||||
)
|
||||
echo.
|
||||
echo ERROR: Poppler not found!
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
:found_poppler
|
||||
if defined POPPLER_BIN (
|
||||
echo Found Poppler at: %POPPLER_BIN%
|
||||
set "PATH=%POPPLER_BIN%;%PATH%"
|
||||
)
|
||||
|
||||
REM Check venv
|
||||
if not exist "venv-win\Scripts\python.exe" (
|
||||
echo ERROR: venv-win not found!
|
||||
echo Run: python -m venv venv-win
|
||||
echo Then: venv-win\Scripts\pip install -r requirements.txt
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
REM Set common environment
|
||||
set JWT_SECRET_KEY=generate_with_secrets_token_urlsafe_32
|
||||
set ORACLE_HOST=10.0.20.121
|
||||
set ORACLE_PORT=1521
|
||||
set ORACLE_USER=CONTAFIN_ORACLE
|
||||
set ORACLE_PASSWORD=ROMFASTSOFT
|
||||
set ORACLE_SERVICE_NAME=ROA
|
||||
set OCR_ENABLE_PADDLEOCR=false
|
||||
set OCR_ENABLE_TESSERACT=false
|
||||
set OCR_DEFAULT_ENGINE=hybrid-doctr
|
||||
set OCR_MAX_TASKS_PER_CHILD=0
|
||||
set LOG_LEVEL=WARNING
|
||||
|
||||
REM Results file with timestamp
|
||||
for /f "tokens=2 delims==" %%I in ('wmic os get localdatetime /value') do set datetime=%%I
|
||||
set RESULTS_FILE=ocr_benchmark_%datetime:~0,8%_%datetime:~8,4%.json
|
||||
|
||||
echo Results will be saved to: %RESULTS_FILE%
|
||||
echo.
|
||||
|
||||
REM Delete old results file if exists
|
||||
if exist "%RESULTS_FILE%" del "%RESULTS_FILE%"
|
||||
|
||||
REM Run tests with specified workers
|
||||
for %%W in (%WORKER_LIST%) do (
|
||||
call :run_test %%W
|
||||
)
|
||||
|
||||
goto :show_summary
|
||||
|
||||
:run_test
|
||||
set WORKERS=%1
|
||||
echo.
|
||||
echo ############################################################
|
||||
echo STARTING TEST WITH %WORKERS% WORKER(S)
|
||||
echo ############################################################
|
||||
echo.
|
||||
|
||||
REM Kill existing processes on port 8006
|
||||
echo Cleaning up old processes...
|
||||
for /f "tokens=5" %%a in ('netstat -ano ^| findstr :8006 ^| findstr LISTENING 2^>nul') do (
|
||||
taskkill /F /PID %%a >nul 2>&1
|
||||
)
|
||||
taskkill /F /FI "WINDOWTITLE eq ROA2WEB Backend*" >nul 2>&1
|
||||
timeout /t 3 >nul
|
||||
|
||||
REM Set workers count
|
||||
set OCR_WORKERS=%WORKERS%
|
||||
|
||||
echo Starting backend with %WORKERS% OCR worker(s)...
|
||||
|
||||
REM Start backend in a new minimized window with all OCR env vars
|
||||
start /min "ROA2WEB Backend %WORKERS% workers" cmd /c "set OCR_WORKERS=%WORKERS%&& set OCR_ENABLE_PADDLEOCR=false&& set OCR_ENABLE_TESSERACT=false&& set OCR_DEFAULT_ENGINE=hybrid-doctr&& set LOG_LEVEL=WARNING&& venv-win\Scripts\python.exe -m uvicorn main:app --host 0.0.0.0 --port 8006 --workers 1 2>&1"
|
||||
|
||||
REM Wait for backend to be ready
|
||||
echo Waiting for backend to start...
|
||||
set attempts=0
|
||||
:wait_loop
|
||||
timeout /t 3 >nul
|
||||
set /a attempts+=1
|
||||
curl -s http://localhost:8006/health >nul 2>&1
|
||||
if errorlevel 1 (
|
||||
if !attempts! lss 40 (
|
||||
echo Waiting... !attempts!/40
|
||||
goto :wait_loop
|
||||
)
|
||||
echo ERROR: Backend failed to start!
|
||||
goto :eof
|
||||
)
|
||||
|
||||
echo Backend is ready!
|
||||
|
||||
REM Wait for OCR warmup
|
||||
echo Waiting for OCR worker warmup (30s)...
|
||||
timeout /t 30 >nul
|
||||
|
||||
echo.
|
||||
echo Running OCR test with %WORKERS% worker(s)...
|
||||
echo.
|
||||
|
||||
venv-win\Scripts\python.exe ..\tests\ocr-validation\test_receipts_parallel_windows.py --port 8006 --workers %WORKERS% --output %RESULTS_FILE%
|
||||
|
||||
REM Stop backend
|
||||
echo.
|
||||
echo Stopping backend...
|
||||
taskkill /F /FI "WINDOWTITLE eq ROA2WEB Backend*" >nul 2>&1
|
||||
for /f "tokens=5" %%a in ('netstat -ano ^| findstr :8006 ^| findstr LISTENING 2^>nul') do (
|
||||
taskkill /F /PID %%a >nul 2>&1
|
||||
)
|
||||
|
||||
REM Wait for memory to be released
|
||||
echo Releasing memory (10s)...
|
||||
timeout /t 10 >nul
|
||||
goto :eof
|
||||
|
||||
:show_summary
|
||||
echo.
|
||||
echo ############################################################
|
||||
echo ALL TESTS COMPLETE
|
||||
echo ############################################################
|
||||
echo.
|
||||
echo Results saved to: %RESULTS_FILE%
|
||||
echo.
|
||||
|
||||
REM Show summary from results file
|
||||
if exist "%RESULTS_FILE%" (
|
||||
echo BENCHMARK SUMMARY:
|
||||
echo ------------------
|
||||
venv-win\Scripts\python.exe -c "import json; data=json.load(open('%RESULTS_FILE%')); print(); [print(f\" {r['workers']} worker(s): {r['total_time']:.1f}s total, {r['avg_time']:.1f}s avg, {r.get('peak_memory_mb', 0):.0f}MB peak, {r['successful']}/{r['submitted']} success\") for r in data]"
|
||||
echo.
|
||||
)
|
||||
|
||||
echo Press any key to exit...
|
||||
pause >nul
|
||||
|
||||
endlocal
|
||||
@@ -38,12 +38,20 @@ from backend.modules.reports.routers import create_reports_router
|
||||
from backend.modules.data_entry.routers import create_data_entry_router
|
||||
from backend.modules.telegram.routers import create_telegram_router
|
||||
|
||||
# Configure logging
|
||||
# Configure logging (level from env: DEBUG, INFO, WARNING, ERROR)
|
||||
log_level = os.getenv('LOG_LEVEL', 'INFO').upper()
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
level=getattr(logging, log_level, logging.INFO),
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
datefmt='%H:%M:%S'
|
||||
)
|
||||
# Reduce noise from third-party libraries
|
||||
logging.getLogger('httpcore').setLevel(logging.WARNING)
|
||||
logging.getLogger('httpx').setLevel(logging.WARNING)
|
||||
logging.getLogger('multipart').setLevel(logging.WARNING)
|
||||
logging.getLogger('doctr').setLevel(logging.WARNING)
|
||||
logging.getLogger('tensorflow').setLevel(logging.WARNING)
|
||||
logging.getLogger('PIL').setLevel(logging.WARNING)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global variables for background tasks
|
||||
|
||||
@@ -48,6 +48,11 @@ class Settings(BaseSettings):
|
||||
# CORS
|
||||
cors_origins: str = "http://localhost:3010,http://localhost:3000"
|
||||
|
||||
# OCR Engines (comma-separated list of active engines shown in UI)
|
||||
# Available: tesseract, paddleocr, doctr, doctr_plus
|
||||
# doctr_plus is recommended (2-tier sequential with early exit)
|
||||
ocr_active_engines: str = "doctr,doctr_plus"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
@@ -80,6 +85,11 @@ class Settings(BaseSettings):
|
||||
"""Get CORS origins as list."""
|
||||
return [origin.strip() for origin in self.cors_origins.split(",")]
|
||||
|
||||
@property
|
||||
def ocr_active_engines_list(self) -> List[str]:
|
||||
"""Get OCR active engines as list."""
|
||||
return [engine.strip() for engine in self.ocr_active_engines.split(",")]
|
||||
|
||||
@property
|
||||
def oracle_dsn(self) -> str:
|
||||
"""Get Oracle DSN string."""
|
||||
|
||||
@@ -2,9 +2,12 @@
|
||||
from .receipt import ReceiptCRUD
|
||||
from .attachment import AttachmentCRUD
|
||||
from .accounting_entry import AccountingEntryCRUD
|
||||
from .ocr_settings import OCRPreferenceCRUD, OCRMetricsCRUD
|
||||
|
||||
__all__ = [
|
||||
"ReceiptCRUD",
|
||||
"AttachmentCRUD",
|
||||
"AccountingEntryCRUD",
|
||||
"OCRPreferenceCRUD",
|
||||
"OCRMetricsCRUD",
|
||||
]
|
||||
|
||||
222
backend/modules/data_entry/db/crud/ocr_settings.py
Normal file
222
backend/modules/data_entry/db/crud/ocr_settings.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""CRUD operations for OCR settings and metrics."""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import func, select, and_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from backend.modules.data_entry.db.models.ocr_settings import (
|
||||
UserOCRPreference,
|
||||
OCRJobMetrics,
|
||||
OCRMetricsSummary,
|
||||
OCREngine,
|
||||
)
|
||||
|
||||
|
||||
class OCRPreferenceCRUD:
|
||||
"""CRUD operations for user OCR preferences."""
|
||||
|
||||
@staticmethod
|
||||
async def get_by_username(session: AsyncSession, username: str) -> Optional[UserOCRPreference]:
|
||||
"""Get user's OCR preference by username."""
|
||||
result = await session.execute(
|
||||
select(UserOCRPreference).where(UserOCRPreference.username == username)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
async def create_or_update(
|
||||
session: AsyncSession,
|
||||
username: str,
|
||||
preferred_engine: OCREngine
|
||||
) -> UserOCRPreference:
|
||||
"""Create or update user's OCR preference."""
|
||||
existing = await OCRPreferenceCRUD.get_by_username(session, username)
|
||||
|
||||
if existing:
|
||||
existing.preferred_engine = preferred_engine
|
||||
existing.updated_at = datetime.utcnow()
|
||||
await session.commit()
|
||||
await session.refresh(existing)
|
||||
return existing
|
||||
else:
|
||||
preference = UserOCRPreference(
|
||||
username=username,
|
||||
preferred_engine=preferred_engine
|
||||
)
|
||||
session.add(preference)
|
||||
await session.commit()
|
||||
await session.refresh(preference)
|
||||
return preference
|
||||
|
||||
@staticmethod
|
||||
async def delete_by_username(session: AsyncSession, username: str) -> bool:
|
||||
"""Delete user's OCR preference."""
|
||||
existing = await OCRPreferenceCRUD.get_by_username(session, username)
|
||||
if existing:
|
||||
await session.delete(existing)
|
||||
await session.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class OCRMetricsCRUD:
|
||||
"""CRUD operations for OCR job metrics."""
|
||||
|
||||
@staticmethod
|
||||
async def create(
|
||||
session: AsyncSession,
|
||||
job_id: str,
|
||||
username: str,
|
||||
engine_requested: str,
|
||||
engine_used: str,
|
||||
processing_time_ms: int = 0,
|
||||
file_size_bytes: int = 0,
|
||||
file_type: str = "image/jpeg",
|
||||
original_filename: Optional[str] = None,
|
||||
success: bool = True,
|
||||
error_message: Optional[str] = None,
|
||||
overall_confidence: float = 0.0,
|
||||
fields_extracted: int = 0,
|
||||
needs_manual_review: Optional[bool] = None,
|
||||
validation_warnings_count: int = 0,
|
||||
validation_errors_count: int = 0,
|
||||
company_id: Optional[int] = None
|
||||
) -> OCRJobMetrics:
|
||||
"""Create a new OCR job metrics record."""
|
||||
metrics = OCRJobMetrics(
|
||||
job_id=job_id,
|
||||
username=username,
|
||||
company_id=company_id,
|
||||
engine_requested=engine_requested,
|
||||
engine_used=engine_used,
|
||||
processing_time_ms=processing_time_ms,
|
||||
file_size_bytes=file_size_bytes,
|
||||
file_type=file_type,
|
||||
original_filename=original_filename,
|
||||
success=success,
|
||||
error_message=error_message,
|
||||
overall_confidence=overall_confidence,
|
||||
fields_extracted=fields_extracted,
|
||||
needs_manual_review=needs_manual_review,
|
||||
validation_warnings_count=validation_warnings_count,
|
||||
validation_errors_count=validation_errors_count,
|
||||
)
|
||||
session.add(metrics)
|
||||
await session.commit()
|
||||
await session.refresh(metrics)
|
||||
return metrics
|
||||
|
||||
@staticmethod
|
||||
async def get_by_job_id(session: AsyncSession, job_id: str) -> Optional[OCRJobMetrics]:
|
||||
"""Get metrics by job ID."""
|
||||
result = await session.execute(
|
||||
select(OCRJobMetrics).where(OCRJobMetrics.job_id == job_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
async def get_user_history(
|
||||
session: AsyncSession,
|
||||
username: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
) -> List[OCRJobMetrics]:
|
||||
"""Get user's OCR job history."""
|
||||
result = await session.execute(
|
||||
select(OCRJobMetrics)
|
||||
.where(OCRJobMetrics.username == username)
|
||||
.order_by(OCRJobMetrics.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@staticmethod
|
||||
async def get_summary_by_engine(
|
||||
session: AsyncSession,
|
||||
days: int = 30,
|
||||
username: Optional[str] = None
|
||||
) -> List[OCRMetricsSummary]:
|
||||
"""Get summary metrics grouped by engine."""
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
# Build query
|
||||
conditions = [OCRJobMetrics.created_at >= cutoff_date]
|
||||
if username:
|
||||
conditions.append(OCRJobMetrics.username == username)
|
||||
|
||||
# Query for aggregated metrics
|
||||
result = await session.execute(
|
||||
select(
|
||||
OCRJobMetrics.engine_used,
|
||||
func.count(OCRJobMetrics.id).label('total_jobs'),
|
||||
func.sum(func.cast(OCRJobMetrics.success, sa.Integer)).label('successful_jobs'),
|
||||
func.avg(OCRJobMetrics.processing_time_ms).label('avg_processing_time_ms'),
|
||||
func.avg(OCRJobMetrics.overall_confidence).label('avg_confidence'),
|
||||
func.avg(OCRJobMetrics.fields_extracted).label('avg_fields_extracted'),
|
||||
)
|
||||
.where(and_(*conditions))
|
||||
.group_by(OCRJobMetrics.engine_used)
|
||||
.order_by(func.count(OCRJobMetrics.id).desc())
|
||||
)
|
||||
|
||||
summaries = []
|
||||
for row in result.all():
|
||||
total = row.total_jobs or 0
|
||||
successful = row.successful_jobs or 0
|
||||
success_rate = successful / total if total > 0 else 0.0
|
||||
summaries.append(OCRMetricsSummary(
|
||||
engine=row.engine_used,
|
||||
total_jobs=total,
|
||||
successful_jobs=successful,
|
||||
failed_jobs=total - successful,
|
||||
success_rate=success_rate,
|
||||
avg_processing_time_ms=float(row.avg_processing_time_ms or 0),
|
||||
avg_confidence=float(row.avg_confidence or 0),
|
||||
avg_fields_extracted=float(row.avg_fields_extracted or 0),
|
||||
))
|
||||
|
||||
return summaries
|
||||
|
||||
@staticmethod
|
||||
async def get_overall_stats(
|
||||
session: AsyncSession,
|
||||
days: int = 30,
|
||||
username: Optional[str] = None
|
||||
) -> dict:
|
||||
"""Get overall OCR statistics."""
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
conditions = [OCRJobMetrics.created_at >= cutoff_date]
|
||||
if username:
|
||||
conditions.append(OCRJobMetrics.username == username)
|
||||
|
||||
result = await session.execute(
|
||||
select(
|
||||
func.count(OCRJobMetrics.id).label('total_jobs'),
|
||||
func.sum(func.cast(OCRJobMetrics.success, sa.Integer)).label('successful_jobs'),
|
||||
func.avg(OCRJobMetrics.processing_time_ms).label('avg_processing_time_ms'),
|
||||
func.avg(OCRJobMetrics.overall_confidence).label('avg_confidence'),
|
||||
)
|
||||
.where(and_(*conditions))
|
||||
)
|
||||
|
||||
row = result.one()
|
||||
total = row.total_jobs or 0
|
||||
successful = row.successful_jobs or 0
|
||||
|
||||
return {
|
||||
"total_jobs": total,
|
||||
"successful_jobs": successful,
|
||||
"failed_jobs": total - successful,
|
||||
"success_rate": (successful / total * 100) if total > 0 else 0.0,
|
||||
"avg_processing_time_ms": float(row.avg_processing_time_ms or 0),
|
||||
"avg_confidence": float(row.avg_confidence or 0),
|
||||
"period_days": days,
|
||||
}
|
||||
|
||||
|
||||
# Import sqlalchemy for func.cast
|
||||
import sqlalchemy as sa
|
||||
@@ -10,9 +10,10 @@ from backend.modules.data_entry.config import settings
|
||||
|
||||
|
||||
# Create async engine
|
||||
# Note: echo=False to disable SQL query logging (too verbose)
|
||||
engine = create_async_engine(
|
||||
settings.database_url,
|
||||
echo=settings.debug,
|
||||
echo=False,
|
||||
future=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
from .receipt import Receipt, ReceiptAttachment, ReceiptStatus, ReceiptType, ReceiptDirection
|
||||
from .accounting_entry import AccountingEntry, EntryType
|
||||
from .nomenclature import SyncedSupplier, LocalSupplier, SyncedCashRegister
|
||||
from .ocr_settings import UserOCRPreference, OCRJobMetrics, OCRMetricsSummary, OCREngine
|
||||
|
||||
__all__ = [
|
||||
"Receipt",
|
||||
@@ -14,4 +15,9 @@ __all__ = [
|
||||
"SyncedSupplier",
|
||||
"LocalSupplier",
|
||||
"SyncedCashRegister",
|
||||
# OCR Settings & Metrics
|
||||
"UserOCRPreference",
|
||||
"OCRJobMetrics",
|
||||
"OCRMetricsSummary",
|
||||
"OCREngine",
|
||||
]
|
||||
|
||||
102
backend/modules/data_entry/db/models/ocr_settings.py
Normal file
102
backend/modules/data_entry/db/models/ocr_settings.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""OCR settings and metrics SQLModel models."""
|
||||
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from sqlmodel import SQLModel, Field
|
||||
|
||||
|
||||
class OCREngine(str, Enum):
|
||||
"""Available OCR engines."""
|
||||
TESSERACT = "tesseract"
|
||||
DOCTR = "doctr"
|
||||
DOCTR_PLUS = "doctr_plus" # docTR with 2-tier sequential processing + early exit (optimized, recommended)
|
||||
PADDLEOCR = "paddleocr"
|
||||
|
||||
|
||||
class UserOCRPreference(SQLModel, table=True):
|
||||
"""
|
||||
User's preferred OCR engine setting.
|
||||
|
||||
Each user can have one preferred OCR engine that will be
|
||||
auto-selected when they upload new receipts for processing.
|
||||
"""
|
||||
|
||||
__tablename__ = "user_ocr_preferences"
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
|
||||
# User identification
|
||||
username: str = Field(max_length=100, unique=True, index=True)
|
||||
|
||||
# Preference settings
|
||||
preferred_engine: OCREngine = Field(default=OCREngine.DOCTR_PLUS)
|
||||
|
||||
# Timestamps
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class OCRJobMetrics(SQLModel, table=True):
|
||||
"""
|
||||
OCR job processing metrics for analytics.
|
||||
|
||||
Stores metrics for each OCR job to enable:
|
||||
- Performance tracking by engine
|
||||
- Success rate analysis
|
||||
- Processing time trends
|
||||
- User-specific analytics
|
||||
"""
|
||||
|
||||
__tablename__ = "ocr_job_metrics"
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
|
||||
# Job identification
|
||||
job_id: str = Field(max_length=50, unique=True, index=True)
|
||||
|
||||
# User and company context
|
||||
username: str = Field(max_length=100, index=True)
|
||||
company_id: Optional[int] = Field(default=None, index=True)
|
||||
|
||||
# Engine used
|
||||
engine_requested: str = Field(max_length=20) # What user/auto requested
|
||||
engine_used: str = Field(max_length=50) # What was actually used (e.g., "doctr-light")
|
||||
|
||||
# Processing metrics
|
||||
processing_time_ms: int = Field(default=0)
|
||||
file_size_bytes: int = Field(default=0)
|
||||
file_type: str = Field(max_length=50, default="image/jpeg") # MIME type
|
||||
original_filename: Optional[str] = Field(default=None, max_length=255) # Original uploaded filename
|
||||
|
||||
# Success metrics
|
||||
success: bool = Field(default=True)
|
||||
error_message: Optional[str] = Field(default=None, max_length=500)
|
||||
|
||||
# Extraction quality metrics
|
||||
overall_confidence: float = Field(default=0.0)
|
||||
fields_extracted: int = Field(default=0) # Number of fields successfully extracted
|
||||
needs_manual_review: Optional[bool] = Field(default=None)
|
||||
validation_warnings_count: int = Field(default=0)
|
||||
validation_errors_count: int = Field(default=0)
|
||||
|
||||
# Timestamps
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class OCRMetricsSummary(SQLModel):
|
||||
"""
|
||||
Summary metrics for OCR analytics.
|
||||
|
||||
Not a database table - used for API responses.
|
||||
"""
|
||||
engine: str
|
||||
total_jobs: int
|
||||
successful_jobs: int
|
||||
failed_jobs: int
|
||||
success_rate: float # Computed: successful_jobs / total_jobs
|
||||
avg_processing_time_ms: float
|
||||
avg_confidence: float
|
||||
avg_fields_extracted: float
|
||||
@@ -17,6 +17,7 @@ load_dotenv()
|
||||
from backend.modules.data_entry.db.models.receipt import Receipt, ReceiptAttachment
|
||||
from backend.modules.data_entry.db.models.accounting_entry import AccountingEntry
|
||||
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier, SyncedCashRegister
|
||||
from backend.modules.data_entry.db.models.ocr_settings import UserOCRPreference, OCRJobMetrics
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
"""Add OCR settings and metrics tables.
|
||||
|
||||
Revision ID: add_ocr_settings_metrics
|
||||
Revises: 20251230_add_needs_manual_review
|
||||
Create Date: 2025-12-31
|
||||
|
||||
This migration adds:
|
||||
- user_ocr_preferences: Store user's preferred OCR engine
|
||||
- ocr_job_metrics: Store OCR job processing metrics for analytics
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# Revision identifiers
|
||||
revision = 'add_ocr_settings_metrics'
|
||||
down_revision = '20251230_add_needs_manual_review'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Create OCR settings and metrics tables."""
|
||||
|
||||
# Create user_ocr_preferences table
|
||||
op.create_table(
|
||||
'user_ocr_preferences',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('username', sa.String(length=100), nullable=False),
|
||||
sa.Column('preferred_engine', sa.String(length=20), nullable=False, server_default='doctr_plus'),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False, server_default=sa.func.now()),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False, server_default=sa.func.now()),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index('ix_user_ocr_preferences_username', 'user_ocr_preferences', ['username'], unique=True)
|
||||
|
||||
# Create ocr_job_metrics table
|
||||
op.create_table(
|
||||
'ocr_job_metrics',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('job_id', sa.String(length=50), nullable=False),
|
||||
sa.Column('username', sa.String(length=100), nullable=False),
|
||||
sa.Column('company_id', sa.Integer(), nullable=True),
|
||||
sa.Column('engine_requested', sa.String(length=20), nullable=False),
|
||||
sa.Column('engine_used', sa.String(length=50), nullable=False),
|
||||
sa.Column('processing_time_ms', sa.Integer(), nullable=False, server_default='0'),
|
||||
sa.Column('file_size_bytes', sa.Integer(), nullable=False, server_default='0'),
|
||||
sa.Column('file_type', sa.String(length=50), nullable=False, server_default='image/jpeg'),
|
||||
sa.Column('success', sa.Boolean(), nullable=False, server_default='1'),
|
||||
sa.Column('error_message', sa.String(length=500), nullable=True),
|
||||
sa.Column('overall_confidence', sa.Float(), nullable=False, server_default='0.0'),
|
||||
sa.Column('fields_extracted', sa.Integer(), nullable=False, server_default='0'),
|
||||
sa.Column('needs_manual_review', sa.Boolean(), nullable=True),
|
||||
sa.Column('validation_warnings_count', sa.Integer(), nullable=False, server_default='0'),
|
||||
sa.Column('validation_errors_count', sa.Integer(), nullable=False, server_default='0'),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False, server_default=sa.func.now()),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index('ix_ocr_job_metrics_job_id', 'ocr_job_metrics', ['job_id'], unique=True)
|
||||
op.create_index('ix_ocr_job_metrics_username', 'ocr_job_metrics', ['username'], unique=False)
|
||||
op.create_index('ix_ocr_job_metrics_company_id', 'ocr_job_metrics', ['company_id'], unique=False)
|
||||
op.create_index('ix_ocr_job_metrics_created_at', 'ocr_job_metrics', ['created_at'], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop OCR settings and metrics tables."""
|
||||
op.drop_index('ix_ocr_job_metrics_created_at', table_name='ocr_job_metrics')
|
||||
op.drop_index('ix_ocr_job_metrics_company_id', table_name='ocr_job_metrics')
|
||||
op.drop_index('ix_ocr_job_metrics_username', table_name='ocr_job_metrics')
|
||||
op.drop_index('ix_ocr_job_metrics_job_id', table_name='ocr_job_metrics')
|
||||
op.drop_table('ocr_job_metrics')
|
||||
|
||||
op.drop_index('ix_user_ocr_preferences_username', table_name='user_ocr_preferences')
|
||||
op.drop_table('user_ocr_preferences')
|
||||
@@ -0,0 +1,30 @@
|
||||
"""Add original_filename to ocr_job_metrics.
|
||||
|
||||
Revision ID: add_original_filename_to_metrics
|
||||
Revises: add_ocr_settings_metrics
|
||||
Create Date: 2025-12-31
|
||||
|
||||
Adds original_filename column to track the uploaded filename.
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# Revision identifiers
|
||||
revision = 'add_original_filename_to_metrics'
|
||||
down_revision = 'add_ocr_settings_metrics'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add original_filename column to ocr_job_metrics."""
|
||||
op.add_column(
|
||||
'ocr_job_metrics',
|
||||
sa.Column('original_filename', sa.String(length=255), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove original_filename column."""
|
||||
op.drop_column('ocr_job_metrics', 'original_filename')
|
||||
@@ -11,6 +11,8 @@ def create_data_entry_router() -> APIRouter:
|
||||
- /receipts - Receipt CRUD and workflow
|
||||
- /ocr - OCR processing for receipts
|
||||
- /nomenclature - Nomenclature syncing from Oracle
|
||||
- /settings - User settings (OCR preferences)
|
||||
- /metrics - OCR analytics and metrics
|
||||
|
||||
Returns:
|
||||
APIRouter: Configured router for data entry module
|
||||
@@ -21,10 +23,13 @@ def create_data_entry_router() -> APIRouter:
|
||||
from .receipts import router as receipts_router
|
||||
from .ocr import router as ocr_router
|
||||
from .nomenclature import router as nomenclature_router
|
||||
from .ocr_settings import router as ocr_settings_router
|
||||
|
||||
# Include all sub-routers (no prefix - already prefixed in main.py with /api/data-entry)
|
||||
router.include_router(receipts_router, prefix="/receipts", tags=["data-entry-receipts"])
|
||||
router.include_router(ocr_router, prefix="/ocr", tags=["data-entry-ocr"])
|
||||
router.include_router(nomenclature_router, prefix="/nomenclature", tags=["data-entry-nomenclature"])
|
||||
# OCR settings and metrics (endpoints at /settings/* and /metrics/*)
|
||||
router.include_router(ocr_settings_router, tags=["data-entry-settings"])
|
||||
|
||||
return router
|
||||
|
||||
@@ -27,6 +27,7 @@ from backend.modules.data_entry.services.ocr_service import ocr_service
|
||||
from backend.modules.data_entry.services.ocr_engine import OCREngine
|
||||
from backend.modules.data_entry.services.ocr.job_queue import job_queue, OCRJobStatus as JobStatus
|
||||
from backend.modules.data_entry.services.ocr.job_worker import estimate_wait_time
|
||||
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
|
||||
from backend.modules.data_entry.schemas.ocr import (
|
||||
OCRResponse,
|
||||
OCRStatusResponse,
|
||||
@@ -55,7 +56,7 @@ router = APIRouter()
|
||||
@router.post("/extract", response_model=OCRJobSubmitResponse)
|
||||
async def submit_ocr_job(
|
||||
file: UploadFile = File(...),
|
||||
engine: OCREngineChoice = Query(default=OCREngineChoice.auto, description="OCR engine to use"),
|
||||
engine: OCREngineChoice = Query(default=OCREngineChoice.doctr_plus, description="OCR engine to use"),
|
||||
sync: bool = Query(default=False, description="If true, process synchronously (blocks)"),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
@@ -69,7 +70,7 @@ async def submit_ocr_job(
|
||||
|
||||
Args:
|
||||
file: Image or PDF file (max 10MB)
|
||||
engine: OCR engine choice (auto, paddleocr, tesseract)
|
||||
engine: OCR engine choice (tesseract, doctr, doctr_plus, paddleocr)
|
||||
sync: If true, process synchronously (legacy mode)
|
||||
|
||||
Returns:
|
||||
@@ -129,13 +130,13 @@ async def submit_ocr_job(
|
||||
@router.get("/jobs/{job_id}", response_model=OCRJobResponse)
|
||||
async def get_job_status(
|
||||
job_id: str,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get OCR job status and result.
|
||||
Get OCR job status and result (instant response).
|
||||
|
||||
Poll this endpoint to check job progress.
|
||||
Recommended polling interval: 2 seconds.
|
||||
For efficient polling, use GET /jobs/{job_id}/wait instead (long-polling).
|
||||
|
||||
Args:
|
||||
job_id: Job UUID from POST /extract response
|
||||
@@ -165,6 +166,10 @@ async def get_job_status(
|
||||
result_data = None
|
||||
if job.status == JobStatus.completed and job.result:
|
||||
result_data = _dict_to_extraction_data(job.result)
|
||||
# Apply fuzzy CUI matching
|
||||
result_data = await _apply_fuzzy_cui_matching(result_data, session)
|
||||
# Debug: log suggested_payment_mode being returned
|
||||
print(f"[OCR Router] Returning job {job_id} with suggested_payment_mode={result_data.suggested_payment_mode}", flush=True)
|
||||
|
||||
return OCRJobResponse(
|
||||
job_id=job.id,
|
||||
@@ -174,12 +179,66 @@ async def get_job_status(
|
||||
created_at=job.created_at or datetime.utcnow(),
|
||||
started_at=job.started_at,
|
||||
completed_at=job.completed_at,
|
||||
queue_wait_ms=job.queue_wait_ms,
|
||||
ocr_time_ms=job.ocr_time_ms,
|
||||
processing_time_ms=job.processing_time_ms,
|
||||
result=result_data,
|
||||
error=job.error_message
|
||||
)
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}/wait", response_model=OCRJobResponse)
|
||||
async def wait_for_job_status(
|
||||
job_id: str,
|
||||
timeout: int = Query(default=30, ge=1, le=60, description="Max wait time in seconds"),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Long-poll for OCR job status change.
|
||||
|
||||
Waits until:
|
||||
- Job status changes to completed/failed
|
||||
- Timeout expires (returns current status)
|
||||
|
||||
Recommended client timeout: timeout + 5 seconds
|
||||
|
||||
Args:
|
||||
job_id: Job UUID from POST /extract response
|
||||
timeout: Max wait time in seconds (1-60, default 30)
|
||||
|
||||
Returns:
|
||||
OCRJobResponse with status, queue_position, and result (if completed)
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
end_time = time.time() + timeout
|
||||
last_status = None
|
||||
|
||||
while time.time() < end_time:
|
||||
job = await job_queue.get_job(job_id)
|
||||
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
# Return immediately if job completed or failed
|
||||
if job.status in [JobStatus.completed, JobStatus.failed]:
|
||||
return await get_job_status(job_id, session, current_user)
|
||||
|
||||
# Return if status changed from last check
|
||||
if last_status is not None and job.status != last_status:
|
||||
return await get_job_status(job_id, session, current_user)
|
||||
|
||||
last_status = job.status
|
||||
|
||||
# Wait 1 second before next internal check
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Timeout - return current status
|
||||
return await get_job_status(job_id, session, current_user)
|
||||
|
||||
|
||||
@router.get("/queue/status", response_model=OCRQueueStatusResponse)
|
||||
async def get_queue_status(
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
@@ -221,10 +280,58 @@ async def get_ocr_status():
|
||||
)
|
||||
|
||||
|
||||
@router.get("/engines")
|
||||
async def get_available_engines():
|
||||
"""
|
||||
Get list of enabled OCR engines based on .env configuration.
|
||||
|
||||
Returns engines availability and available processing modes.
|
||||
Frontend should use this to filter engine selection dropdown.
|
||||
|
||||
Available engines: tesseract, doctr, doctr_plus, paddleocr
|
||||
"""
|
||||
# Check which engines are enabled via .env
|
||||
paddle_enabled = os.getenv("OCR_ENABLE_PADDLEOCR", "true").lower() == "true"
|
||||
tesseract_enabled = os.getenv("OCR_ENABLE_TESSERACT", "true").lower() == "true"
|
||||
default_engine = os.getenv("OCR_DEFAULT_ENGINE", "doctr_plus")
|
||||
|
||||
# Build engines dict
|
||||
engines = {
|
||||
"tesseract": tesseract_enabled,
|
||||
"doctr": True, # Always available (primary engine)
|
||||
"doctr_plus": True, # Always available (recommended)
|
||||
"paddleocr": paddle_enabled,
|
||||
}
|
||||
|
||||
# Build available modes based on enabled engines
|
||||
modes = []
|
||||
|
||||
if tesseract_enabled:
|
||||
modes.append("tesseract")
|
||||
|
||||
modes.append("doctr")
|
||||
modes.append("doctr_plus")
|
||||
|
||||
if paddle_enabled:
|
||||
modes.append("paddleocr")
|
||||
|
||||
return {
|
||||
"engines": engines,
|
||||
"available_modes": modes,
|
||||
"default_mode": default_engine,
|
||||
"memory_estimate_mb": {
|
||||
"tesseract": 50,
|
||||
"doctr": 600,
|
||||
"doctr_plus": 600,
|
||||
"paddleocr": 800,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.post("/extract-attachment/{attachment_id}", response_model=OCRResponse)
|
||||
async def extract_from_attachment(
|
||||
attachment_id: int,
|
||||
engine: OCREngineChoice = Query(default=OCREngineChoice.auto),
|
||||
engine: OCREngineChoice = Query(default=OCREngineChoice.doctr_plus),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
@@ -260,6 +367,8 @@ async def extract_from_attachment(
|
||||
raise HTTPException(status_code=422, detail=message)
|
||||
|
||||
data = _result_to_extraction_data(result)
|
||||
# Apply fuzzy CUI matching
|
||||
data = await _apply_fuzzy_cui_matching(data, session)
|
||||
return OCRResponse(success=True, message=message, data=data)
|
||||
|
||||
|
||||
@@ -267,6 +376,58 @@ async def extract_from_attachment(
|
||||
# Helper Functions
|
||||
# ============================================================================
|
||||
|
||||
async def _apply_fuzzy_cui_matching(
|
||||
extraction_data: ExtractionData,
|
||||
session: AsyncSession
|
||||
) -> ExtractionData:
|
||||
"""
|
||||
Apply fuzzy CUI matching to extraction data.
|
||||
|
||||
ONLY applies fuzzy matching if CUI is missing OR has invalid checksum.
|
||||
If CUI has valid checksum, we trust the OCR and skip fuzzy matching.
|
||||
|
||||
Args:
|
||||
extraction_data: ExtractionData with CUI to potentially correct
|
||||
session: AsyncSession for database lookups
|
||||
|
||||
Returns:
|
||||
ExtractionData with CUI corrected if a match was found
|
||||
"""
|
||||
from backend.modules.data_entry.services.ocr.validation import CUIChecksumRule
|
||||
|
||||
# Skip if no CUI and no vendor name (nothing to match)
|
||||
if not extraction_data.cui and not extraction_data.partner_name:
|
||||
return extraction_data
|
||||
|
||||
# Check if CUI has valid checksum - if valid, skip fuzzy matching
|
||||
if extraction_data.cui:
|
||||
cui_digits = CUIChecksumRule.extract_digits(extraction_data.cui)
|
||||
if len(cui_digits) >= 6 and CUIChecksumRule.validate_checksum(cui_digits):
|
||||
print(f"[Fuzzy Match] CUI {extraction_data.cui} has valid checksum, skipping fuzzy match", flush=True)
|
||||
return extraction_data
|
||||
|
||||
# CUI missing or invalid checksum - try fuzzy matching
|
||||
try:
|
||||
match = await OCRValidationEngine.fuzzy_match_supplier(
|
||||
cui=extraction_data.cui,
|
||||
vendor_name=extraction_data.partner_name,
|
||||
db_session=session
|
||||
)
|
||||
|
||||
if match:
|
||||
corrected_cui, supplier_name = match
|
||||
if corrected_cui != extraction_data.cui:
|
||||
print(f"[Fuzzy Match] Corrected: {extraction_data.cui} → {corrected_cui} ({supplier_name})", flush=True)
|
||||
extraction_data.cui = corrected_cui
|
||||
# Also set partner_name if not already set
|
||||
if not extraction_data.partner_name:
|
||||
extraction_data.partner_name = supplier_name
|
||||
except Exception as e:
|
||||
print(f"[Fuzzy Match] Error: {e}", flush=True)
|
||||
|
||||
return extraction_data
|
||||
|
||||
|
||||
async def _process_sync(
|
||||
content: bytes,
|
||||
file: UploadFile,
|
||||
@@ -362,6 +523,7 @@ def _result_to_extraction_data(result) -> ExtractionData:
|
||||
confidence_client=getattr(result, 'confidence_client', 0.0),
|
||||
overall_confidence=result.overall_confidence,
|
||||
raw_text=result.raw_text,
|
||||
raw_texts=getattr(result, 'raw_texts', []),
|
||||
ocr_engine=result.ocr_engine,
|
||||
processing_time_ms=result.processing_time_ms,
|
||||
needs_manual_review=result.needs_manual_review,
|
||||
@@ -437,6 +599,7 @@ def _dict_to_extraction_data(data: dict) -> ExtractionData:
|
||||
confidence_client=data.get('confidence_client', 0.0),
|
||||
overall_confidence=data.get('overall_confidence', 0.0),
|
||||
raw_text=data.get('raw_text', ''),
|
||||
raw_texts=data.get('raw_texts', []),
|
||||
ocr_engine=data.get('ocr_engine', ''),
|
||||
processing_time_ms=data.get('processing_time_ms', 0),
|
||||
needs_manual_review=data.get('needs_manual_review'),
|
||||
|
||||
268
backend/modules/data_entry/routers/ocr_settings.py
Normal file
268
backend/modules/data_entry/routers/ocr_settings.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""
|
||||
OCR Settings and Metrics API endpoints.
|
||||
|
||||
Endpoints:
|
||||
- GET /settings/ocr-preference - Get user's preferred OCR engine
|
||||
- POST /settings/ocr-preference - Set user's preferred OCR engine
|
||||
- GET /metrics/ocr/summary - Get OCR metrics summary by engine
|
||||
- GET /metrics/ocr/history - Get user's OCR job history
|
||||
- GET /metrics/ocr/stats - Get overall OCR statistics
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from backend.modules.data_entry.db.database import get_session
|
||||
from backend.modules.data_entry.db.crud.ocr_settings import OCRPreferenceCRUD, OCRMetricsCRUD
|
||||
from backend.modules.data_entry.db.models.ocr_settings import OCREngine, OCRMetricsSummary
|
||||
|
||||
# Auth integration
|
||||
from shared.auth.dependencies import get_current_user
|
||||
from shared.auth.models import CurrentUser
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Schemas
|
||||
# ============================================================================
|
||||
|
||||
class OCRPreferenceResponse(BaseModel):
|
||||
"""Response for OCR preference endpoint."""
|
||||
username: str
|
||||
preferred_engine: str
|
||||
available_engines: List[str] = Field(
|
||||
default=["tesseract", "doctr", "doctr_plus", "paddleocr"],
|
||||
description="Available OCR engines"
|
||||
)
|
||||
|
||||
|
||||
class OCRPreferenceRequest(BaseModel):
|
||||
"""Request to set OCR preference."""
|
||||
preferred_engine: str = Field(
|
||||
default="doctr_plus",
|
||||
description="Preferred OCR engine: tesseract, doctr, doctr_plus, paddleocr"
|
||||
)
|
||||
|
||||
|
||||
class OCRMetricsHistoryItem(BaseModel):
|
||||
"""Single OCR job metrics item."""
|
||||
job_id: str
|
||||
engine_requested: str
|
||||
engine_used: str
|
||||
processing_time_ms: int
|
||||
success: bool
|
||||
overall_confidence: float
|
||||
fields_extracted: int
|
||||
created_at: str
|
||||
original_filename: Optional[str] = None
|
||||
|
||||
|
||||
class OCRMetricsHistoryResponse(BaseModel):
|
||||
"""Response for OCR history endpoint."""
|
||||
items: List[OCRMetricsHistoryItem]
|
||||
total: int
|
||||
|
||||
|
||||
class OCRStatsResponse(BaseModel):
|
||||
"""Response for OCR stats endpoint."""
|
||||
total_jobs: int
|
||||
successful_jobs: int
|
||||
failed_jobs: int
|
||||
success_rate: float
|
||||
avg_processing_time_ms: float
|
||||
avg_confidence: float
|
||||
period_days: int
|
||||
|
||||
|
||||
class OCRActiveEnginesResponse(BaseModel):
|
||||
"""Response for active OCR engines endpoint."""
|
||||
engines: List[str] = Field(description="List of active OCR engines from .env config")
|
||||
recommended: str = Field(default="doctr_plus", description="Recommended engine")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OCR Engines Configuration Endpoint
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/settings/ocr-engines", response_model=OCRActiveEnginesResponse)
|
||||
async def get_active_ocr_engines():
|
||||
"""
|
||||
Get list of active OCR engines configured in .env.
|
||||
|
||||
Returns the engines that should be shown in the frontend dropdown.
|
||||
Configured via OCR_ACTIVE_ENGINES environment variable.
|
||||
|
||||
Default: doctr,doctr_plus
|
||||
Available: tesseract, paddleocr, doctr, doctr_plus
|
||||
"""
|
||||
from backend.modules.data_entry.config import settings
|
||||
|
||||
return OCRActiveEnginesResponse(
|
||||
engines=settings.ocr_active_engines_list,
|
||||
recommended="doctr_plus"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OCR Preference Endpoints
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/settings/ocr-preference", response_model=OCRPreferenceResponse)
|
||||
async def get_ocr_preference(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get user's preferred OCR engine.
|
||||
|
||||
Returns the user's saved preference or 'doctr_plus' if not set.
|
||||
Also returns list of available engines.
|
||||
"""
|
||||
from backend.modules.data_entry.services.ocr_engine import OCREngine as OCREngineClass
|
||||
|
||||
preference = await OCRPreferenceCRUD.get_by_username(session, current_user.username)
|
||||
|
||||
# Get available engines from OCR service
|
||||
available = OCREngineClass.get_available_engines()
|
||||
|
||||
return OCRPreferenceResponse(
|
||||
username=current_user.username,
|
||||
preferred_engine=preference.preferred_engine.value if preference else "doctr_plus",
|
||||
available_engines=available
|
||||
)
|
||||
|
||||
|
||||
@router.post("/settings/ocr-preference", response_model=OCRPreferenceResponse)
|
||||
async def set_ocr_preference(
|
||||
request: OCRPreferenceRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Set user's preferred OCR engine.
|
||||
|
||||
Valid engines: tesseract, doctr, doctr_plus, paddleocr
|
||||
Note: Available engines depend on .env configuration (OCR_ENABLE_PADDLEOCR, OCR_ENABLE_TESSERACT)
|
||||
"""
|
||||
from backend.modules.data_entry.services.ocr_engine import OCREngine as OCREngineClass
|
||||
|
||||
# Get dynamically available engines
|
||||
available = OCREngineClass.get_available_engines()
|
||||
|
||||
if request.preferred_engine not in available:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid engine. Must be one of: {', '.join(available)}"
|
||||
)
|
||||
|
||||
# Map string to enum
|
||||
engine_map = {
|
||||
"tesseract": OCREngine.TESSERACT,
|
||||
"doctr": OCREngine.DOCTR,
|
||||
"doctr_plus": OCREngine.DOCTR_PLUS,
|
||||
"paddleocr": OCREngine.PADDLEOCR,
|
||||
}
|
||||
engine_enum = engine_map.get(request.preferred_engine, OCREngine.DOCTR_PLUS)
|
||||
|
||||
# Save preference
|
||||
preference = await OCRPreferenceCRUD.create_or_update(
|
||||
session,
|
||||
current_user.username,
|
||||
engine_enum
|
||||
)
|
||||
|
||||
# Get available engines
|
||||
available = OCREngineClass.get_available_engines()
|
||||
|
||||
return OCRPreferenceResponse(
|
||||
username=current_user.username,
|
||||
preferred_engine=preference.preferred_engine.value,
|
||||
available_engines=available
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OCR Metrics Endpoints
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/metrics/ocr/summary", response_model=List[OCRMetricsSummary])
|
||||
async def get_ocr_metrics_summary(
|
||||
days: int = Query(default=30, ge=1, le=365, description="Number of days to include"),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get OCR metrics summary grouped by engine.
|
||||
|
||||
Returns aggregated metrics for each engine used in the specified period.
|
||||
"""
|
||||
summaries = await OCRMetricsCRUD.get_summary_by_engine(
|
||||
session,
|
||||
days=days,
|
||||
username=current_user.username
|
||||
)
|
||||
return summaries
|
||||
|
||||
|
||||
@router.get("/metrics/ocr/history", response_model=OCRMetricsHistoryResponse)
|
||||
async def get_ocr_metrics_history(
|
||||
limit: int = Query(default=50, ge=1, le=200, description="Max items to return"),
|
||||
offset: int = Query(default=0, ge=0, description="Items to skip"),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get user's OCR job history.
|
||||
|
||||
Returns list of OCR jobs with their metrics, ordered by most recent first.
|
||||
"""
|
||||
items = await OCRMetricsCRUD.get_user_history(
|
||||
session,
|
||||
username=current_user.username,
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
|
||||
history_items = [
|
||||
OCRMetricsHistoryItem(
|
||||
job_id=item.job_id,
|
||||
engine_requested=item.engine_requested,
|
||||
engine_used=item.engine_used,
|
||||
processing_time_ms=item.processing_time_ms,
|
||||
success=item.success,
|
||||
overall_confidence=item.overall_confidence,
|
||||
fields_extracted=item.fields_extracted,
|
||||
created_at=item.created_at.isoformat(),
|
||||
original_filename=item.original_filename
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
|
||||
return OCRMetricsHistoryResponse(
|
||||
items=history_items,
|
||||
total=len(history_items)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/metrics/ocr/stats", response_model=OCRStatsResponse)
|
||||
async def get_ocr_stats(
|
||||
days: int = Query(default=30, ge=1, le=365, description="Number of days to include"),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get overall OCR statistics for the user.
|
||||
|
||||
Returns aggregated stats including success rate, average processing time, etc.
|
||||
"""
|
||||
stats = await OCRMetricsCRUD.get_overall_stats(
|
||||
session,
|
||||
days=days,
|
||||
username=current_user.username
|
||||
)
|
||||
|
||||
return OCRStatsResponse(**stats)
|
||||
@@ -61,7 +61,8 @@ class ExtractionData(BaseModel):
|
||||
confidence_vendor: float = Field(default=0.0, ge=0, le=1, description="Vendor extraction confidence")
|
||||
confidence_client: float = Field(default=0.0, ge=0, le=1, description="Client extraction confidence")
|
||||
overall_confidence: float = Field(default=0.0, ge=0, le=1, description="Overall confidence score")
|
||||
raw_text: str = Field(default="", description="Raw OCR text")
|
||||
raw_text: str = Field(default="", description="Raw OCR text (primary)")
|
||||
raw_texts: List[str] = Field(default=[], description="Raw OCR texts from all engine passes (for analysis)")
|
||||
ocr_engine: str = Field(default="", description="OCR engine used: paddleocr or tesseract")
|
||||
processing_time_ms: int = Field(default=0, ge=0, description="Processing time in milliseconds")
|
||||
|
||||
@@ -148,9 +149,10 @@ from enum import Enum
|
||||
|
||||
class OCREngineChoice(str, Enum):
|
||||
"""OCR engine selection options."""
|
||||
auto = "auto"
|
||||
paddleocr = "paddleocr"
|
||||
tesseract = "tesseract"
|
||||
doctr = "doctr" # 3.3x faster than PaddleOCR with same accuracy (90/100)
|
||||
doctr_plus = "doctr_plus" # docTR with 2-tier sequential processing + early exit (optimized, recommended)
|
||||
paddleocr = "paddleocr"
|
||||
|
||||
|
||||
class OCRJobStatus(str, Enum):
|
||||
@@ -193,7 +195,10 @@ class OCRJobResponse(BaseModel):
|
||||
created_at: datetime = Field(description="Job creation timestamp")
|
||||
started_at: Optional[datetime] = Field(default=None, description="Processing start timestamp")
|
||||
completed_at: Optional[datetime] = Field(default=None, description="Completion timestamp")
|
||||
processing_time_ms: Optional[int] = Field(default=None, description="Actual processing time in ms")
|
||||
# Detailed timing breakdown
|
||||
queue_wait_ms: Optional[int] = Field(default=None, description="Time waiting in queue (started_at - created_at)")
|
||||
ocr_time_ms: Optional[int] = Field(default=None, description="Actual OCR engine processing time")
|
||||
processing_time_ms: Optional[int] = Field(default=None, description="Total job processing time (completed_at - started_at)")
|
||||
result: Optional[ExtractionData] = Field(default=None, description="Extraction result (only if completed)")
|
||||
error: Optional[str] = Field(default=None, description="Error message (only if failed)")
|
||||
|
||||
|
||||
@@ -33,73 +33,55 @@ class NomenclatureService:
|
||||
"""
|
||||
Get partners (suppliers/customers) for a company.
|
||||
|
||||
Phase 1: Returns mock data.
|
||||
Phase 2: Returns synced data from SQLite (from Oracle sync).
|
||||
Phase 3: Will fetch live from Oracle.
|
||||
Returns synced suppliers from Oracle + local suppliers created from OCR.
|
||||
If no suppliers exist, returns empty list (frontend will trigger sync).
|
||||
"""
|
||||
# If session is provided, try to get from synced SQLite data
|
||||
if session:
|
||||
# Try to get from SQLite synced data
|
||||
stmt = select(SyncedSupplier).where(SyncedSupplier.company_id == company_id)
|
||||
if search:
|
||||
stmt = stmt.where(
|
||||
(SyncedSupplier.name.ilike(f"%{search}%")) |
|
||||
(SyncedSupplier.fiscal_code.ilike(f"%{search}%"))
|
||||
)
|
||||
stmt = stmt.order_by(SyncedSupplier.name) # Order alphabetically, no limit for AutoComplete
|
||||
partners = []
|
||||
|
||||
result = await session.execute(stmt)
|
||||
suppliers = result.scalars().all()
|
||||
|
||||
if suppliers:
|
||||
# Also get local suppliers
|
||||
local_stmt = select(LocalSupplier).where(LocalSupplier.company_id == company_id)
|
||||
if search:
|
||||
local_stmt = local_stmt.where(
|
||||
(LocalSupplier.name.ilike(f"%{search}%")) |
|
||||
(LocalSupplier.fiscal_code.ilike(f"%{search}%"))
|
||||
)
|
||||
local_stmt = local_stmt.order_by(LocalSupplier.name) # Order alphabetically
|
||||
|
||||
local_result = await session.execute(local_stmt)
|
||||
local_suppliers = local_result.scalars().all()
|
||||
|
||||
# Combine both - no IDs needed, just text data for autocomplete
|
||||
partners = []
|
||||
for s in suppliers:
|
||||
partners.append(PartnerOption(
|
||||
name=s.name,
|
||||
fiscal_code=s.fiscal_code,
|
||||
address=s.address,
|
||||
source="oracle"
|
||||
))
|
||||
for l in local_suppliers:
|
||||
partners.append(PartnerOption(
|
||||
name=l.name, # No suffix - must match search results
|
||||
fiscal_code=l.fiscal_code,
|
||||
address=l.address,
|
||||
source="local"
|
||||
))
|
||||
|
||||
return partners
|
||||
|
||||
# Fallback to mock data for Phase 1 (when no synced data)
|
||||
mock_partners = [
|
||||
PartnerOption(name="OMV Petrom", fiscal_code="RO123456", source="mock"),
|
||||
PartnerOption(name="Dedeman", fiscal_code="RO789012", source="mock"),
|
||||
PartnerOption(name="Kaufland", fiscal_code="RO345678", source="mock"),
|
||||
PartnerOption(name="Emag", fiscal_code="RO901234", source="mock"),
|
||||
PartnerOption(name="Altex", fiscal_code="RO567890", source="mock"),
|
||||
]
|
||||
if not session:
|
||||
return partners
|
||||
|
||||
# Get synced suppliers from Oracle
|
||||
stmt = select(SyncedSupplier).where(SyncedSupplier.company_id == company_id)
|
||||
if search:
|
||||
search_lower = search.lower()
|
||||
mock_partners = [
|
||||
p for p in mock_partners
|
||||
if search_lower in p.name.lower() or (p.fiscal_code and search_lower in p.fiscal_code.lower())
|
||||
]
|
||||
stmt = stmt.where(
|
||||
(SyncedSupplier.name.ilike(f"%{search}%")) |
|
||||
(SyncedSupplier.fiscal_code.ilike(f"%{search}%"))
|
||||
)
|
||||
stmt = stmt.order_by(SyncedSupplier.name)
|
||||
|
||||
return mock_partners
|
||||
result = await session.execute(stmt)
|
||||
suppliers = result.scalars().all()
|
||||
|
||||
for s in suppliers:
|
||||
partners.append(PartnerOption(
|
||||
name=s.name,
|
||||
fiscal_code=s.fiscal_code,
|
||||
address=s.address,
|
||||
source="oracle"
|
||||
))
|
||||
|
||||
# Always get local suppliers (not just when synced exist)
|
||||
local_stmt = select(LocalSupplier).where(LocalSupplier.company_id == company_id)
|
||||
if search:
|
||||
local_stmt = local_stmt.where(
|
||||
(LocalSupplier.name.ilike(f"%{search}%")) |
|
||||
(LocalSupplier.fiscal_code.ilike(f"%{search}%"))
|
||||
)
|
||||
local_stmt = local_stmt.order_by(LocalSupplier.name)
|
||||
|
||||
local_result = await session.execute(local_stmt)
|
||||
local_suppliers = local_result.scalars().all()
|
||||
|
||||
for l in local_suppliers:
|
||||
partners.append(PartnerOption(
|
||||
name=l.name,
|
||||
fiscal_code=l.fiscal_code,
|
||||
address=l.address,
|
||||
source="local"
|
||||
))
|
||||
|
||||
return partners
|
||||
|
||||
@staticmethod
|
||||
async def get_accounts(company_id: int, prefix: Optional[str] = None) -> List[AccountOption]:
|
||||
|
||||
@@ -13,13 +13,14 @@ Schema:
|
||||
status TEXT NOT NULL, -- pending, processing, completed, failed
|
||||
file_path TEXT NOT NULL, -- Path to uploaded file
|
||||
mime_type TEXT NOT NULL,
|
||||
engine TEXT DEFAULT 'auto',
|
||||
engine TEXT DEFAULT 'doctr_plus',
|
||||
created_at TIMESTAMP,
|
||||
started_at TIMESTAMP,
|
||||
completed_at TIMESTAMP,
|
||||
result_json TEXT, -- JSON extraction result
|
||||
error_message TEXT,
|
||||
processing_time_ms INTEGER,
|
||||
processing_time_ms INTEGER, -- Total job time (started_at to completed_at)
|
||||
ocr_time_ms INTEGER, -- Actual OCR engine processing time
|
||||
created_by TEXT, -- Username
|
||||
original_filename TEXT,
|
||||
expires_at TIMESTAMP
|
||||
@@ -74,17 +75,26 @@ class OCRJob:
|
||||
status: OCRJobStatus
|
||||
file_path: str
|
||||
mime_type: str
|
||||
engine: str = "auto"
|
||||
engine: str = "doctr_plus"
|
||||
created_at: Optional[datetime] = None
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
result_json: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
processing_time_ms: Optional[int] = None
|
||||
processing_time_ms: Optional[int] = None # Total job time (started_at to completed_at)
|
||||
ocr_time_ms: Optional[int] = None # Actual OCR engine processing time
|
||||
created_by: Optional[str] = None
|
||||
original_filename: Optional[str] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
|
||||
@property
|
||||
def queue_wait_ms(self) -> Optional[int]:
|
||||
"""Calculate queue wait time (created_at to started_at)."""
|
||||
if self.created_at and self.started_at:
|
||||
delta = self.started_at - self.created_at
|
||||
return int(delta.total_seconds() * 1000)
|
||||
return None
|
||||
|
||||
@property
|
||||
def result(self) -> Optional[Dict]:
|
||||
"""Parse result_json to dict."""
|
||||
@@ -143,19 +153,27 @@ class OCRJobQueue:
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
file_path TEXT NOT NULL,
|
||||
mime_type TEXT NOT NULL,
|
||||
engine TEXT DEFAULT 'auto',
|
||||
engine TEXT DEFAULT 'doctr_plus',
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
started_at TIMESTAMP,
|
||||
completed_at TIMESTAMP,
|
||||
result_json TEXT,
|
||||
error_message TEXT,
|
||||
processing_time_ms INTEGER,
|
||||
ocr_time_ms INTEGER,
|
||||
created_by TEXT,
|
||||
original_filename TEXT,
|
||||
expires_at TIMESTAMP
|
||||
)
|
||||
''')
|
||||
|
||||
# Migration: add ocr_time_ms column if it doesn't exist
|
||||
try:
|
||||
await db.execute('ALTER TABLE ocr_jobs ADD COLUMN ocr_time_ms INTEGER')
|
||||
logger.info("[OCRJobQueue] Added ocr_time_ms column to existing table")
|
||||
except Exception:
|
||||
pass # Column already exists
|
||||
|
||||
# Index for efficient queue queries
|
||||
await db.execute('''
|
||||
CREATE INDEX IF NOT EXISTS idx_ocr_jobs_status
|
||||
@@ -177,7 +195,7 @@ class OCRJobQueue:
|
||||
self,
|
||||
file_bytes: bytes,
|
||||
mime_type: str,
|
||||
engine: str = "auto",
|
||||
engine: str = "doctr_plus",
|
||||
username: Optional[str] = None,
|
||||
original_filename: Optional[str] = None
|
||||
) -> OCRJob:
|
||||
@@ -189,7 +207,7 @@ class OCRJobQueue:
|
||||
Args:
|
||||
file_bytes: Raw file bytes
|
||||
mime_type: MIME type of file
|
||||
engine: OCR engine ('auto', 'paddleocr', 'tesseract')
|
||||
engine: OCR engine ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
|
||||
username: Username of requester
|
||||
original_filename: Original filename from upload
|
||||
|
||||
@@ -301,24 +319,52 @@ class OCRJobQueue:
|
||||
|
||||
async def get_next_pending(self) -> Optional[OCRJob]:
|
||||
"""
|
||||
Get the next pending job (oldest first).
|
||||
Get the next pending job (oldest first) and atomically mark it as processing.
|
||||
|
||||
This prevents race conditions in parallel processing - only one worker
|
||||
can claim each job.
|
||||
|
||||
Returns:
|
||||
Next OCRJob to process or None if queue empty
|
||||
"""
|
||||
await self.initialize()
|
||||
|
||||
async with aiosqlite.connect(str(self.db_path)) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute('''
|
||||
SELECT * FROM ocr_jobs
|
||||
WHERE status = 'pending'
|
||||
ORDER BY created_at ASC
|
||||
LIMIT 1
|
||||
''') as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return self._row_to_job(row)
|
||||
now = datetime.utcnow()
|
||||
|
||||
async with self._lock: # Serialize access to prevent race conditions
|
||||
async with aiosqlite.connect(str(self.db_path)) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
|
||||
# Get the next pending job
|
||||
async with db.execute('''
|
||||
SELECT * FROM ocr_jobs
|
||||
WHERE status = 'pending'
|
||||
ORDER BY created_at ASC
|
||||
LIMIT 1
|
||||
''') as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
|
||||
job_id = row['id']
|
||||
|
||||
# Atomically mark as processing
|
||||
await db.execute('''
|
||||
UPDATE ocr_jobs
|
||||
SET status = 'processing', started_at = ?
|
||||
WHERE id = ? AND status = 'pending'
|
||||
''', (now.isoformat(), job_id))
|
||||
await db.commit()
|
||||
|
||||
# Fetch the updated job
|
||||
async with db.execute(
|
||||
'SELECT * FROM ocr_jobs WHERE id = ?',
|
||||
(job_id,)
|
||||
) as cursor:
|
||||
updated_row = await cursor.fetchone()
|
||||
if updated_row:
|
||||
return self._row_to_job(updated_row)
|
||||
|
||||
return None
|
||||
|
||||
async def update_status(
|
||||
@@ -327,7 +373,8 @@ class OCRJobQueue:
|
||||
status: OCRJobStatus,
|
||||
result: Optional[Dict] = None,
|
||||
error: Optional[str] = None,
|
||||
processing_time_ms: Optional[int] = None
|
||||
processing_time_ms: Optional[int] = None,
|
||||
ocr_time_ms: Optional[int] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Update job status.
|
||||
@@ -337,7 +384,8 @@ class OCRJobQueue:
|
||||
status: New status
|
||||
result: Extraction result dict (for completed)
|
||||
error: Error message (for failed)
|
||||
processing_time_ms: Processing time
|
||||
processing_time_ms: Total job processing time (started_at to completed_at)
|
||||
ocr_time_ms: Actual OCR engine processing time
|
||||
|
||||
Returns:
|
||||
True if update successful
|
||||
@@ -359,18 +407,18 @@ class OCRJobQueue:
|
||||
elif status == OCRJobStatus.completed:
|
||||
query = '''
|
||||
UPDATE ocr_jobs
|
||||
SET status = ?, completed_at = ?, result_json = ?, processing_time_ms = ?
|
||||
SET status = ?, completed_at = ?, result_json = ?, processing_time_ms = ?, ocr_time_ms = ?
|
||||
WHERE id = ?
|
||||
'''
|
||||
params = (status.value, now.isoformat(), result_json, processing_time_ms, job_id)
|
||||
params = (status.value, now.isoformat(), result_json, processing_time_ms, ocr_time_ms, job_id)
|
||||
|
||||
elif status == OCRJobStatus.failed:
|
||||
query = '''
|
||||
UPDATE ocr_jobs
|
||||
SET status = ?, completed_at = ?, error_message = ?, processing_time_ms = ?
|
||||
SET status = ?, completed_at = ?, error_message = ?, processing_time_ms = ?, ocr_time_ms = ?
|
||||
WHERE id = ?
|
||||
'''
|
||||
params = (status.value, now.isoformat(), error, processing_time_ms, job_id)
|
||||
params = (status.value, now.isoformat(), error, processing_time_ms, ocr_time_ms, job_id)
|
||||
|
||||
else:
|
||||
query = 'UPDATE ocr_jobs SET status = ? WHERE id = ?'
|
||||
@@ -542,13 +590,14 @@ class OCRJobQueue:
|
||||
status=OCRJobStatus(row['status']),
|
||||
file_path=row['file_path'],
|
||||
mime_type=row['mime_type'],
|
||||
engine=row['engine'] or 'auto',
|
||||
engine=row['engine'] or 'doctr_plus',
|
||||
created_at=parse_datetime(row['created_at']),
|
||||
started_at=parse_datetime(row['started_at']),
|
||||
completed_at=parse_datetime(row['completed_at']),
|
||||
result_json=row['result_json'],
|
||||
error_message=row['error_message'],
|
||||
processing_time_ms=row['processing_time_ms'],
|
||||
ocr_time_ms=row['ocr_time_ms'] if 'ocr_time_ms' in row.keys() else None,
|
||||
created_by=row['created_by'],
|
||||
original_filename=row['original_filename'],
|
||||
expires_at=parse_datetime(row['expires_at']),
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
OCR Job Worker - Background Task for Queue Processing
|
||||
|
||||
Runs as an asyncio background task in FastAPI.
|
||||
Continuously polls the job queue and processes OCR requests.
|
||||
Continuously polls the job queue and processes OCR requests IN PARALLEL.
|
||||
|
||||
Architecture:
|
||||
FastAPI startup
|
||||
@@ -12,18 +12,19 @@ Architecture:
|
||||
asyncio.create_task(_job_worker_loop())
|
||||
↓
|
||||
while True:
|
||||
job = job_queue.get_next_pending()
|
||||
if job:
|
||||
result = ocr_worker_pool.submit_task(...)
|
||||
job_queue.update_status(...)
|
||||
await asyncio.sleep(0.5)
|
||||
# Process up to OCR_WORKERS jobs concurrently
|
||||
jobs = get_pending_jobs(limit=available_slots)
|
||||
for job in jobs:
|
||||
asyncio.create_task(_process_job(job))
|
||||
await asyncio.sleep(0.1)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Optional, Set
|
||||
|
||||
from .job_queue import job_queue, OCRJobStatus, OCRJob
|
||||
from .ocr_worker_pool import ocr_worker_pool
|
||||
@@ -34,47 +35,76 @@ logger = logging.getLogger(__name__)
|
||||
_job_worker_task: Optional[asyncio.Task] = None
|
||||
_cleanup_task: Optional[asyncio.Task] = None
|
||||
_shutdown_event: Optional[asyncio.Event] = None
|
||||
_active_tasks: Set[asyncio.Task] = set() # Track active job tasks
|
||||
_concurrency_semaphore: Optional[asyncio.Semaphore] = None # Limit concurrent jobs
|
||||
|
||||
# Configuration
|
||||
POLL_INTERVAL_SECONDS = 0.5 # How often to check for new jobs
|
||||
POLL_INTERVAL_SECONDS = 0.1 # How often to check for new jobs (faster for parallel)
|
||||
CLEANUP_INTERVAL_SECONDS = 3600 # Clean expired jobs every hour
|
||||
OCR_TIMEOUT_SECONDS = 120 # Max time for OCR processing
|
||||
|
||||
|
||||
async def _job_worker_loop() -> None:
|
||||
"""
|
||||
Main worker loop - processes jobs from queue.
|
||||
Main worker loop - processes jobs from queue IN PARALLEL.
|
||||
|
||||
Runs continuously until shutdown. Polls queue every 0.5s
|
||||
and submits jobs to worker pool for processing.
|
||||
Runs continuously until shutdown. Uses semaphore to limit
|
||||
concurrent jobs to OCR_WORKERS count. Launches jobs as
|
||||
background tasks without waiting for completion.
|
||||
"""
|
||||
global _shutdown_event
|
||||
global _shutdown_event, _active_tasks, _concurrency_semaphore
|
||||
|
||||
logger.info("[JobWorker] Starting worker loop...")
|
||||
# Get max concurrent jobs from env (matches worker pool size)
|
||||
max_concurrent = int(os.getenv('OCR_WORKERS', '2'))
|
||||
_concurrency_semaphore = asyncio.Semaphore(max_concurrent)
|
||||
_active_tasks = set()
|
||||
|
||||
logger.info(f"[JobWorker] Starting PARALLEL worker loop (max_concurrent={max_concurrent})...")
|
||||
_shutdown_event = asyncio.Event()
|
||||
|
||||
consecutive_errors = 0
|
||||
max_consecutive_errors = 5
|
||||
max_consecutive_errors = 10
|
||||
|
||||
while not _shutdown_event.is_set():
|
||||
try:
|
||||
# Get next pending job
|
||||
job = await job_queue.get_next_pending()
|
||||
|
||||
if job:
|
||||
consecutive_errors = 0 # Reset error counter on success
|
||||
await _process_job(job)
|
||||
else:
|
||||
# No jobs - wait before polling again
|
||||
# Clean up completed tasks
|
||||
done_tasks = {t for t in _active_tasks if t.done()}
|
||||
for task in done_tasks:
|
||||
_active_tasks.discard(task)
|
||||
# Check for exceptions
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
_shutdown_event.wait(),
|
||||
timeout=POLL_INTERVAL_SECONDS
|
||||
)
|
||||
if _shutdown_event.is_set():
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
pass # Normal timeout, continue loop
|
||||
task.result()
|
||||
except Exception as e:
|
||||
logger.error(f"[JobWorker] Task failed: {e}")
|
||||
|
||||
# Check if we have capacity for more jobs
|
||||
active_count = len(_active_tasks)
|
||||
available_slots = max_concurrent - active_count
|
||||
|
||||
if available_slots > 0:
|
||||
# Get next pending job
|
||||
job = await job_queue.get_next_pending()
|
||||
|
||||
if job:
|
||||
consecutive_errors = 0
|
||||
# Launch job processing as background task
|
||||
task = asyncio.create_task(_process_job_with_semaphore(job))
|
||||
_active_tasks.add(task)
|
||||
logger.debug(f"[JobWorker] Launched job {job.id} (active={len(_active_tasks)}/{max_concurrent})")
|
||||
else:
|
||||
# No pending jobs - wait briefly
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
_shutdown_event.wait(),
|
||||
timeout=POLL_INTERVAL_SECONDS
|
||||
)
|
||||
if _shutdown_event.is_set():
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
else:
|
||||
# At capacity - wait for a slot to free up
|
||||
await asyncio.sleep(POLL_INTERVAL_SECONDS)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("[JobWorker] Worker loop cancelled")
|
||||
@@ -88,27 +118,46 @@ async def _job_worker_loop() -> None:
|
||||
logger.error("[JobWorker] Too many consecutive errors, stopping worker")
|
||||
break
|
||||
|
||||
# Backoff on errors
|
||||
await asyncio.sleep(min(consecutive_errors * 2, 30))
|
||||
|
||||
# Wait for active tasks to complete on shutdown
|
||||
if _active_tasks:
|
||||
logger.info(f"[JobWorker] Waiting for {len(_active_tasks)} active tasks to complete...")
|
||||
await asyncio.gather(*_active_tasks, return_exceptions=True)
|
||||
|
||||
logger.info("[JobWorker] Worker loop stopped")
|
||||
|
||||
|
||||
async def _process_job_with_semaphore(job: OCRJob) -> None:
|
||||
"""
|
||||
Process job with semaphore to limit concurrency.
|
||||
|
||||
Acquires semaphore before processing, releases after.
|
||||
This ensures we don't exceed OCR_WORKERS concurrent jobs.
|
||||
"""
|
||||
global _concurrency_semaphore
|
||||
|
||||
async with _concurrency_semaphore:
|
||||
await _process_job(job)
|
||||
|
||||
|
||||
async def _process_job(job: OCRJob) -> None:
|
||||
"""
|
||||
Process a single OCR job.
|
||||
|
||||
Reads file, submits to worker pool, updates job status.
|
||||
Reads file, submits to worker pool, updates job status,
|
||||
and saves metrics for analytics.
|
||||
|
||||
Args:
|
||||
job: OCRJob to process
|
||||
"""
|
||||
logger.info(f"[JobWorker] Processing job {job.id}: engine={job.engine}, file={Path(job.file_path).name}")
|
||||
start_time = time.time()
|
||||
file_size = 0
|
||||
file_type = "image/jpeg"
|
||||
|
||||
try:
|
||||
# Mark as processing
|
||||
await job_queue.update_status(job.id, OCRJobStatus.processing)
|
||||
# Note: Job already marked as 'processing' atomically in get_next_pending()
|
||||
|
||||
# Read file bytes
|
||||
file_path = Path(job.file_path)
|
||||
@@ -118,6 +167,10 @@ async def _process_job(job: OCRJob) -> None:
|
||||
with open(file_path, 'rb') as f:
|
||||
file_bytes = f.read()
|
||||
|
||||
file_size = len(file_bytes)
|
||||
# Determine file type from job or extension
|
||||
file_type = getattr(job, 'mime_type', 'image/jpeg') or 'image/jpeg'
|
||||
|
||||
# Submit to worker pool
|
||||
result = await ocr_worker_pool.submit_task(
|
||||
image_bytes=file_bytes,
|
||||
@@ -132,14 +185,43 @@ async def _process_job(job: OCRJob) -> None:
|
||||
# Job completed successfully
|
||||
extraction = result.get("extraction", {})
|
||||
|
||||
# Include raw_texts for analysis (from all OCR engine passes)
|
||||
extraction['raw_texts'] = result.get("raw_texts", [])
|
||||
|
||||
# Extract actual OCR processing time from extraction result
|
||||
ocr_time_ms = extraction.get('processing_time_ms', 0)
|
||||
|
||||
# Debug: log suggested_payment_mode
|
||||
spm = extraction.get('suggested_payment_mode')
|
||||
logger.info(f"[JobWorker] Job {job.id} extraction has suggested_payment_mode={spm}")
|
||||
|
||||
await job_queue.update_status(
|
||||
job_id=job.id,
|
||||
status=OCRJobStatus.completed,
|
||||
result=extraction,
|
||||
processing_time_ms=elapsed_ms
|
||||
processing_time_ms=elapsed_ms,
|
||||
ocr_time_ms=ocr_time_ms
|
||||
)
|
||||
|
||||
logger.info(f"[JobWorker] Job {job.id} completed in {elapsed_ms}ms")
|
||||
logger.info(f"[JobWorker] Job {job.id} completed in {elapsed_ms}ms (ocr: {ocr_time_ms}ms)")
|
||||
|
||||
# Save metrics for successful job
|
||||
await _save_job_metrics(
|
||||
job_id=job.id,
|
||||
username=job.created_by or 'unknown',
|
||||
engine_requested=job.engine,
|
||||
engine_used=extraction.get('ocr_engine', job.engine),
|
||||
processing_time_ms=elapsed_ms,
|
||||
file_size_bytes=file_size,
|
||||
file_type=file_type,
|
||||
original_filename=job.original_filename,
|
||||
success=True,
|
||||
overall_confidence=extraction.get('overall_confidence', 0.0),
|
||||
fields_extracted=_count_extracted_fields(extraction),
|
||||
needs_manual_review=extraction.get('needs_manual_review'),
|
||||
validation_warnings_count=len(extraction.get('validation_warnings', [])),
|
||||
validation_errors_count=len(extraction.get('validation_errors', [])),
|
||||
)
|
||||
|
||||
else:
|
||||
# Job failed
|
||||
@@ -154,6 +236,20 @@ async def _process_job(job: OCRJob) -> None:
|
||||
|
||||
logger.warning(f"[JobWorker] Job {job.id} failed after {elapsed_ms}ms: {error_msg}")
|
||||
|
||||
# Save metrics for failed job
|
||||
await _save_job_metrics(
|
||||
job_id=job.id,
|
||||
username=job.created_by or 'unknown',
|
||||
engine_requested=job.engine,
|
||||
engine_used=job.engine,
|
||||
processing_time_ms=elapsed_ms,
|
||||
file_size_bytes=file_size,
|
||||
file_type=file_type,
|
||||
original_filename=job.original_filename,
|
||||
success=False,
|
||||
error_message=error_msg,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
elapsed_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
@@ -166,6 +262,20 @@ async def _process_job(job: OCRJob) -> None:
|
||||
processing_time_ms=elapsed_ms
|
||||
)
|
||||
|
||||
# Save metrics for error job
|
||||
await _save_job_metrics(
|
||||
job_id=job.id,
|
||||
username=job.created_by or 'unknown',
|
||||
engine_requested=job.engine,
|
||||
engine_used=job.engine,
|
||||
processing_time_ms=elapsed_ms,
|
||||
file_size_bytes=file_size,
|
||||
file_type=file_type,
|
||||
original_filename=job.original_filename,
|
||||
success=False,
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
finally:
|
||||
# Cleanup file after processing
|
||||
try:
|
||||
@@ -340,3 +450,96 @@ def estimate_wait_time(queue_position: int) -> int:
|
||||
|
||||
# Estimate: position * average_time
|
||||
return int(queue_position * avg_time)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Metrics Helper Functions
|
||||
# ============================================================================
|
||||
|
||||
async def _save_job_metrics(
|
||||
job_id: str,
|
||||
username: str,
|
||||
engine_requested: str,
|
||||
engine_used: str,
|
||||
processing_time_ms: int = 0,
|
||||
file_size_bytes: int = 0,
|
||||
file_type: str = "image/jpeg",
|
||||
original_filename: Optional[str] = None,
|
||||
success: bool = True,
|
||||
error_message: Optional[str] = None,
|
||||
overall_confidence: float = 0.0,
|
||||
fields_extracted: int = 0,
|
||||
needs_manual_review: Optional[bool] = None,
|
||||
validation_warnings_count: int = 0,
|
||||
validation_errors_count: int = 0,
|
||||
) -> None:
|
||||
"""
|
||||
Save OCR job metrics to database for analytics.
|
||||
|
||||
Called after each job completes (success or failure).
|
||||
Errors are logged but don't affect job processing.
|
||||
"""
|
||||
try:
|
||||
from backend.modules.data_entry.db.database import get_db_session
|
||||
from backend.modules.data_entry.db.crud.ocr_settings import OCRMetricsCRUD
|
||||
|
||||
async with await get_db_session() as session:
|
||||
await OCRMetricsCRUD.create(
|
||||
session=session,
|
||||
job_id=job_id,
|
||||
username=username,
|
||||
engine_requested=engine_requested,
|
||||
engine_used=engine_used,
|
||||
processing_time_ms=processing_time_ms,
|
||||
file_size_bytes=file_size_bytes,
|
||||
file_type=file_type,
|
||||
original_filename=original_filename,
|
||||
success=success,
|
||||
error_message=error_message,
|
||||
overall_confidence=overall_confidence,
|
||||
fields_extracted=fields_extracted,
|
||||
needs_manual_review=needs_manual_review,
|
||||
validation_warnings_count=validation_warnings_count,
|
||||
validation_errors_count=validation_errors_count,
|
||||
)
|
||||
logger.debug(f"[JobWorker] Saved metrics for job {job_id}")
|
||||
|
||||
except Exception as e:
|
||||
# Log but don't fail - metrics are nice-to-have
|
||||
logger.warning(f"[JobWorker] Failed to save metrics for job {job_id}: {e}")
|
||||
|
||||
|
||||
def _count_extracted_fields(extraction: dict) -> int:
|
||||
"""
|
||||
Count number of successfully extracted fields from OCR result.
|
||||
|
||||
Counts non-None values in key fields.
|
||||
"""
|
||||
key_fields = [
|
||||
'receipt_number',
|
||||
'receipt_date',
|
||||
'amount',
|
||||
'partner_name',
|
||||
'cui',
|
||||
'tva_total',
|
||||
'address',
|
||||
'items_count',
|
||||
]
|
||||
|
||||
count = 0
|
||||
for field in key_fields:
|
||||
value = extraction.get(field)
|
||||
if value is not None and value != '' and value != []:
|
||||
count += 1
|
||||
|
||||
# Also count TVA entries if present
|
||||
tva_entries = extraction.get('tva_entries', [])
|
||||
if tva_entries and len(tva_entries) > 0:
|
||||
count += 1
|
||||
|
||||
# Count payment methods if present
|
||||
payment_methods = extraction.get('payment_methods', [])
|
||||
if payment_methods and len(payment_methods) > 0:
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
"""
|
||||
OCR Worker Pool Manager
|
||||
|
||||
Manages a ProcessPoolExecutor with persistent PaddleOCR initialization.
|
||||
Manages a ProcessPoolExecutor with persistent OCR engine initialization.
|
||||
Key features:
|
||||
- ProcessPoolExecutor with max_workers=1 (sequential, no memory leak)
|
||||
- ProcessPoolExecutor with configurable max_workers (from OCR_WORKERS env)
|
||||
- Configurable max_tasks_per_child (from OCR_MAX_TASKS_PER_CHILD env, 0=no restart)
|
||||
- mp_context='spawn' for Windows IIS compatibility
|
||||
- PaddleOCR loaded ONCE at worker spawn (not 30s per request)
|
||||
- docTR/PaddleOCR loaded ONCE at worker spawn (not 30s per request)
|
||||
- atexit + signal handlers for cleanup
|
||||
- Health check with auto-respawn
|
||||
- Orphan process cleanup on Windows
|
||||
@@ -29,7 +30,7 @@ import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import ProcessPoolExecutor, Future
|
||||
from concurrent.futures import ProcessPoolExecutor, Future, ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
@@ -48,8 +49,8 @@ class OCRWorkerPool:
|
||||
"""
|
||||
Singleton manager for OCR ProcessPoolExecutor.
|
||||
|
||||
Ensures PaddleOCR is loaded once and reused for all requests.
|
||||
Uses max_tasks_per_child=None to keep worker alive indefinitely.
|
||||
Ensures OCR engines are loaded once and reused for all requests.
|
||||
Uses max_tasks_per_child=5 to restart worker every 5 tasks (prevents memory leak).
|
||||
"""
|
||||
|
||||
_instance: Optional["OCRWorkerPool"] = None
|
||||
@@ -86,7 +87,7 @@ class OCRWorkerPool:
|
||||
Initialize the ProcessPoolExecutor.
|
||||
|
||||
Creates executor with spawn context for Windows compatibility.
|
||||
Uses max_tasks_per_child=None to keep worker alive (persistent PaddleOCR).
|
||||
Uses max_tasks_per_child=5 to restart worker periodically (prevents memory leak).
|
||||
|
||||
Returns:
|
||||
True if initialization successful
|
||||
@@ -103,18 +104,30 @@ class OCRWorkerPool:
|
||||
# Cleanup any orphan workers from previous runs
|
||||
self._cleanup_orphan_workers()
|
||||
|
||||
# Read configuration from environment
|
||||
max_workers = int(os.getenv('OCR_WORKERS', '2'))
|
||||
max_tasks_raw = os.getenv('OCR_MAX_TASKS_PER_CHILD', '0')
|
||||
# 0 means no restart (None in ProcessPoolExecutor)
|
||||
max_tasks_per_child = int(max_tasks_raw) if max_tasks_raw and int(max_tasks_raw) > 0 else None
|
||||
|
||||
# Create executor with spawn context (Windows compatible)
|
||||
# Use mp_context='spawn' explicitly for cross-platform consistency
|
||||
mp_context = mp.get_context('spawn')
|
||||
|
||||
self._executor = ProcessPoolExecutor(
|
||||
max_workers=1, # Single worker for sequential processing
|
||||
mp_context=mp_context,
|
||||
initializer=_worker_initializer,
|
||||
max_tasks_per_child=None, # Keep worker alive indefinitely
|
||||
)
|
||||
# max_tasks_per_child only available in Python 3.11+
|
||||
executor_kwargs = {
|
||||
'max_workers': max_workers,
|
||||
'mp_context': mp_context,
|
||||
'initializer': _worker_initializer,
|
||||
}
|
||||
if sys.version_info >= (3, 11) and max_tasks_per_child is not None:
|
||||
executor_kwargs['max_tasks_per_child'] = max_tasks_per_child
|
||||
else:
|
||||
logger.info(f"[OCRWorkerPool] max_tasks_per_child not supported (Python {sys.version_info.major}.{sys.version_info.minor})")
|
||||
|
||||
logger.info("[OCRWorkerPool] ProcessPoolExecutor created (spawn context, max_workers=1)")
|
||||
self._executor = ProcessPoolExecutor(**executor_kwargs)
|
||||
|
||||
logger.info(f"[OCRWorkerPool] ProcessPoolExecutor created (spawn context, max_workers={max_workers}, max_tasks_per_child={max_tasks_per_child})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -173,7 +186,7 @@ class OCRWorkerPool:
|
||||
async def submit_task(
|
||||
self,
|
||||
image_bytes: bytes,
|
||||
engine: str = "auto",
|
||||
engine: str = "doctr_plus",
|
||||
preprocessing: str = "auto",
|
||||
timeout: float = 120.0
|
||||
) -> dict:
|
||||
@@ -182,7 +195,7 @@ class OCRWorkerPool:
|
||||
|
||||
Args:
|
||||
image_bytes: Raw image bytes
|
||||
engine: OCR engine ('auto', 'paddleocr', 'tesseract')
|
||||
engine: OCR engine ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
|
||||
preprocessing: Preprocessing mode ('light', 'medium', 'heavy', 'auto')
|
||||
timeout: Maximum processing time in seconds
|
||||
|
||||
@@ -339,6 +352,7 @@ class OCRWorkerPool:
|
||||
# Global engines - persist between tasks in worker process
|
||||
_paddle_engine = None
|
||||
_tesseract_engine = None
|
||||
_doctr_engine = None # docTR engine (PyTorch backend)
|
||||
_worker_initialized = False
|
||||
|
||||
|
||||
@@ -346,40 +360,92 @@ def _worker_initializer() -> None:
|
||||
"""
|
||||
Called once when worker process spawns.
|
||||
|
||||
Initializes global OCR engines that persist between tasks.
|
||||
This is where PaddleOCR loading happens (15-20 seconds).
|
||||
Initializes global OCR engines IN PARALLEL for faster startup.
|
||||
Uses ThreadPoolExecutor to load enabled engines concurrently.
|
||||
Respects OCR_ENABLE_PADDLEOCR and OCR_ENABLE_TESSERACT from .env.
|
||||
|
||||
Total warmup time = max(engine_times) instead of sum(engine_times).
|
||||
"""
|
||||
global _paddle_engine, _tesseract_engine, _worker_initialized
|
||||
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
|
||||
|
||||
if _worker_initialized:
|
||||
print(f"[Worker {os.getpid()}] Already initialized", flush=True)
|
||||
return
|
||||
|
||||
print(f"[Worker {os.getpid()}] Initializing OCR engines...", flush=True)
|
||||
# Check which engines are enabled via .env
|
||||
paddle_enabled = os.getenv("OCR_ENABLE_PADDLEOCR", "true").lower() == "true"
|
||||
tesseract_enabled = os.getenv("OCR_ENABLE_TESSERACT", "true").lower() == "true"
|
||||
|
||||
enabled_engines = ["doctr"] # docTR is always loaded (primary engine)
|
||||
if paddle_enabled:
|
||||
enabled_engines.append("paddle")
|
||||
if tesseract_enabled:
|
||||
enabled_engines.append("tesseract")
|
||||
|
||||
print(f"[Worker {os.getpid()}] Initializing OCR engines: {enabled_engines}", flush=True)
|
||||
if not paddle_enabled:
|
||||
print(f"[Worker {os.getpid()}] PaddleOCR DISABLED - saving ~800MB RAM", flush=True)
|
||||
if not tesseract_enabled:
|
||||
print(f"[Worker {os.getpid()}] Tesseract DISABLED - saving ~50MB RAM", flush=True)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Initialize PaddleOCR
|
||||
try:
|
||||
# Import inside worker to avoid import issues in main process
|
||||
from backend.modules.data_entry.services.ocr.ocr_worker_process import initialize_paddle_engine
|
||||
_paddle_engine = initialize_paddle_engine()
|
||||
print(f"[Worker {os.getpid()}] PaddleOCR loaded", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[Worker {os.getpid()}] PaddleOCR init failed: {e}", flush=True)
|
||||
_paddle_engine = None
|
||||
# Define loader functions - each runs in its own thread
|
||||
def load_doctr():
|
||||
try:
|
||||
from backend.modules.data_entry.services.ocr.ocr_worker_process import initialize_doctr_engine
|
||||
engine = initialize_doctr_engine()
|
||||
return ("doctr", engine, None)
|
||||
except Exception as e:
|
||||
return ("doctr", None, str(e))
|
||||
|
||||
# Initialize Tesseract
|
||||
try:
|
||||
from backend.modules.data_entry.services.ocr.tesseract_engine import TesseractEngine
|
||||
_tesseract_engine = TesseractEngine()
|
||||
print(f"[Worker {os.getpid()}] Tesseract loaded", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[Worker {os.getpid()}] Tesseract init failed: {e}", flush=True)
|
||||
_tesseract_engine = None
|
||||
def load_paddle():
|
||||
if not paddle_enabled:
|
||||
return ("paddle", None, "disabled via OCR_ENABLE_PADDLEOCR=false")
|
||||
try:
|
||||
from backend.modules.data_entry.services.ocr.ocr_worker_process import initialize_paddle_engine
|
||||
engine = initialize_paddle_engine()
|
||||
return ("paddle", engine, None)
|
||||
except Exception as e:
|
||||
return ("paddle", None, str(e))
|
||||
|
||||
def load_tesseract():
|
||||
if not tesseract_enabled:
|
||||
return ("tesseract", None, "disabled via OCR_ENABLE_TESSERACT=false")
|
||||
try:
|
||||
from backend.modules.data_entry.services.ocr.tesseract_engine import TesseractEngine
|
||||
engine = TesseractEngine()
|
||||
return ("tesseract", engine, None)
|
||||
except Exception as e:
|
||||
return ("tesseract", None, str(e))
|
||||
|
||||
# Build list of futures for enabled engines only
|
||||
futures_to_submit = [load_doctr] # docTR always loaded
|
||||
if paddle_enabled:
|
||||
futures_to_submit.append(load_paddle)
|
||||
if tesseract_enabled:
|
||||
futures_to_submit.append(load_tesseract)
|
||||
|
||||
# Load engines in parallel using ThreadPoolExecutor
|
||||
with ThreadPoolExecutor(max_workers=len(futures_to_submit)) as executor:
|
||||
futures = [executor.submit(fn) for fn in futures_to_submit]
|
||||
|
||||
for future in as_completed(futures):
|
||||
name, engine, error = future.result()
|
||||
if error and "disabled" not in error:
|
||||
print(f"[Worker {os.getpid()}] {name} init failed: {error}", flush=True)
|
||||
elif engine:
|
||||
print(f"[Worker {os.getpid()}] {name} loaded", flush=True)
|
||||
if name == "doctr":
|
||||
_doctr_engine = engine
|
||||
elif name == "paddle":
|
||||
_paddle_engine = engine
|
||||
elif name == "tesseract":
|
||||
_tesseract_engine = engine
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
_worker_initialized = True
|
||||
print(f"[Worker {os.getpid()}] Initialization complete in {elapsed:.1f}s", flush=True)
|
||||
print(f"[Worker {os.getpid()}] Initialization complete in {elapsed:.1f}s (engines: {enabled_engines})", flush=True)
|
||||
|
||||
|
||||
def _warmup_task() -> dict:
|
||||
@@ -389,7 +455,7 @@ def _warmup_task() -> dict:
|
||||
Called at FastAPI startup to pre-warm the worker.
|
||||
Returns success status and worker PID.
|
||||
"""
|
||||
global _paddle_engine, _tesseract_engine, _worker_initialized
|
||||
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
|
||||
|
||||
try:
|
||||
# Ensure initialization
|
||||
@@ -400,6 +466,14 @@ def _warmup_task() -> dict:
|
||||
import numpy as np
|
||||
dummy_img = np.ones((100, 100, 3), dtype=np.uint8) * 255
|
||||
|
||||
# Test docTR if available (fastest engine)
|
||||
if _doctr_engine is not None:
|
||||
try:
|
||||
_doctr_engine([dummy_img])
|
||||
print(f"[Worker {os.getpid()}] docTR warmup OK", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[Worker {os.getpid()}] docTR warmup error: {e}", flush=True)
|
||||
|
||||
# Test PaddleOCR if available
|
||||
if _paddle_engine is not None:
|
||||
try:
|
||||
@@ -414,6 +488,7 @@ def _warmup_task() -> dict:
|
||||
return {
|
||||
"success": True,
|
||||
"pid": os.getpid(),
|
||||
"doctr_available": _doctr_engine is not None,
|
||||
"paddle_available": _paddle_engine is not None,
|
||||
"tesseract_available": _tesseract_engine is not None
|
||||
}
|
||||
@@ -428,7 +503,7 @@ def _warmup_task() -> dict:
|
||||
|
||||
def _process_ocr_task(
|
||||
image_bytes: bytes,
|
||||
engine: str = "auto",
|
||||
engine: str = "doctr_plus",
|
||||
preprocessing: str = "auto"
|
||||
) -> dict:
|
||||
"""
|
||||
@@ -439,13 +514,13 @@ def _process_ocr_task(
|
||||
|
||||
Args:
|
||||
image_bytes: Raw image bytes
|
||||
engine: OCR engine choice
|
||||
engine: OCR engine choice ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
|
||||
preprocessing: Preprocessing mode
|
||||
|
||||
Returns:
|
||||
Dict with extraction results
|
||||
"""
|
||||
global _paddle_engine, _tesseract_engine, _worker_initialized
|
||||
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
|
||||
|
||||
try:
|
||||
# Ensure initialization
|
||||
@@ -461,7 +536,8 @@ def _process_ocr_task(
|
||||
paddle_engine=_paddle_engine,
|
||||
tesseract_engine=_tesseract_engine,
|
||||
engine=engine,
|
||||
preprocessing=preprocessing
|
||||
preprocessing=preprocessing,
|
||||
doctr_engine=_doctr_engine
|
||||
)
|
||||
|
||||
# Cleanup after each task
|
||||
|
||||
@@ -6,6 +6,7 @@ Handles OCR processing with persistent engine instances.
|
||||
|
||||
Key features:
|
||||
- PaddleOCR initialized ONCE at process spawn
|
||||
- docTR initialized ONCE at process spawn (PyTorch backend)
|
||||
- Tesseract as fallback/complement engine
|
||||
- Multi-pass preprocessing (light → medium → tesseract)
|
||||
- Automatic engine selection based on results
|
||||
@@ -26,6 +27,13 @@ import numpy as np
|
||||
# Disable PaddleOCR model source check for faster startup
|
||||
os.environ['PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK'] = 'True'
|
||||
|
||||
# Memory optimization for docTR (prevents memory leak in multiprocessing)
|
||||
# Source: https://github.com/mindee/doctr/issues/1594
|
||||
os.environ['DOCTR_MULTIPROCESSING_DISABLE'] = 'TRUE'
|
||||
|
||||
# Reduce Intel oneDNN cache to save memory
|
||||
os.environ['ONEDNN_PRIMITIVE_CACHE_CAPACITY'] = '1'
|
||||
|
||||
|
||||
@dataclass
|
||||
class OCRResult:
|
||||
@@ -71,25 +79,67 @@ def initialize_paddle_engine():
|
||||
return None
|
||||
|
||||
|
||||
def initialize_doctr_engine():
|
||||
"""
|
||||
Initialize docTR engine (CPU only).
|
||||
|
||||
Called once at worker spawn. Returns the engine instance
|
||||
that will be reused for all subsequent requests.
|
||||
|
||||
Note: DirectML (AMD GPU) has compatibility issues with docTR.
|
||||
CUDA (NVIDIA) works but requires separate PyTorch build.
|
||||
CPU mode is stable and well-optimized.
|
||||
|
||||
Returns:
|
||||
docTR predictor instance or None if unavailable
|
||||
"""
|
||||
try:
|
||||
print(f"[Worker {os.getpid()}] Loading docTR (PyTorch backend, CPU)...", flush=True)
|
||||
start_time = time.time()
|
||||
|
||||
from doctr.models import ocr_predictor
|
||||
|
||||
# Initialize docTR predictor with pretrained models
|
||||
# Uses db_resnet50 for detection and crnn_vgg16_bn for recognition
|
||||
doctr = ocr_predictor(
|
||||
det_arch='db_resnet50',
|
||||
reco_arch='crnn_vgg16_bn',
|
||||
pretrained=True,
|
||||
assume_straight_pages=True,
|
||||
straighten_pages=False,
|
||||
preserve_aspect_ratio=True,
|
||||
)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
print(f"[Worker {os.getpid()}] docTR loaded in {elapsed:.1f}s", flush=True)
|
||||
return doctr
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Worker {os.getpid()}] docTR init failed: {e}", flush=True)
|
||||
return None
|
||||
|
||||
|
||||
def process_ocr(
|
||||
image_bytes: bytes,
|
||||
paddle_engine,
|
||||
tesseract_engine,
|
||||
engine: str = "auto",
|
||||
preprocessing: str = "auto"
|
||||
engine: str = "doctr_plus",
|
||||
preprocessing: str = "auto",
|
||||
doctr_engine=None
|
||||
) -> dict:
|
||||
"""
|
||||
Process OCR on image bytes.
|
||||
|
||||
Main entry point for OCR processing in worker process.
|
||||
Uses adaptive multi-pass strategy for best results.
|
||||
Uses the specified engine for text recognition.
|
||||
|
||||
Args:
|
||||
image_bytes: Raw image bytes (JPEG, PNG, or PDF)
|
||||
paddle_engine: Pre-initialized PaddleOCR instance (or None)
|
||||
tesseract_engine: Pre-initialized TesseractEngine instance (or None)
|
||||
engine: Engine selection ('auto', 'paddleocr', 'tesseract')
|
||||
engine: Engine selection ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
|
||||
preprocessing: Preprocessing mode ('auto', 'light', 'medium', 'heavy')
|
||||
doctr_engine: Pre-initialized docTR instance (or None)
|
||||
|
||||
Returns:
|
||||
Dict with extraction results:
|
||||
@@ -101,14 +151,20 @@ def process_ocr(
|
||||
"ocr_engine": str
|
||||
}
|
||||
"""
|
||||
import sys
|
||||
start_time = time.time()
|
||||
print(f"[Worker {os.getpid()}] Processing OCR: engine={engine}, preprocessing={preprocessing}, size={len(image_bytes)} bytes", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
try:
|
||||
# Decode image from bytes
|
||||
print(f"[Worker {os.getpid()}] Decoding image...", flush=True)
|
||||
sys.stdout.flush()
|
||||
image = _decode_image(image_bytes)
|
||||
if image is None:
|
||||
return {"success": False, "error": "Failed to decode image"}
|
||||
print(f"[Worker {os.getpid()}] Image decoded: shape={image.shape}, dtype={image.dtype}, size={image.nbytes/1024/1024:.1f}MB", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
# Import preprocessor
|
||||
from backend.modules.data_entry.services.image_preprocessor import ImagePreprocessor
|
||||
@@ -116,22 +172,36 @@ def process_ocr(
|
||||
|
||||
preprocessor = ImagePreprocessor()
|
||||
extractor = ReceiptExtractor()
|
||||
print(f"[Worker {os.getpid()}] Preprocessor and extractor initialized", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
raw_texts = []
|
||||
extraction = None
|
||||
|
||||
# Engine routing
|
||||
if engine == "paddleocr":
|
||||
extraction, raw_texts = _process_paddleocr_only(
|
||||
image, paddle_engine, preprocessor, extractor
|
||||
)
|
||||
elif engine == "tesseract":
|
||||
# Engine routing (available: tesseract, doctr, doctr_plus, paddleocr)
|
||||
print(f"[Worker {os.getpid()}] Routing to engine: {engine}", flush=True)
|
||||
sys.stdout.flush()
|
||||
if engine == "tesseract":
|
||||
extraction, raw_texts = _process_tesseract_only(
|
||||
image, tesseract_engine, preprocessor, extractor
|
||||
)
|
||||
else: # auto
|
||||
extraction, raw_texts = _process_adaptive(
|
||||
image, paddle_engine, tesseract_engine, preprocessor, extractor
|
||||
elif engine == "doctr":
|
||||
extraction, raw_texts = _process_doctr_only(
|
||||
image, doctr_engine, preprocessor, extractor
|
||||
)
|
||||
elif engine == "doctr_plus":
|
||||
extraction, raw_texts = _process_doctr_plus(
|
||||
image, doctr_engine, preprocessor, extractor
|
||||
)
|
||||
elif engine == "paddleocr":
|
||||
extraction, raw_texts = _process_paddleocr_only(
|
||||
image, paddle_engine, preprocessor, extractor
|
||||
)
|
||||
else:
|
||||
# Default to doctr_plus if unknown engine specified
|
||||
print(f"[OCR] Unknown engine '{engine}', defaulting to doctr_plus", flush=True)
|
||||
extraction, raw_texts = _process_doctr_plus(
|
||||
image, doctr_engine, preprocessor, extractor
|
||||
)
|
||||
|
||||
# Calculate processing time
|
||||
@@ -171,7 +241,11 @@ def process_ocr(
|
||||
|
||||
|
||||
def _decode_image(image_bytes: bytes) -> Optional[np.ndarray]:
|
||||
"""Decode image from bytes (JPEG, PNG, or first page of PDF)."""
|
||||
"""Decode image from bytes (JPEG, PNG, or first page of PDF).
|
||||
|
||||
For PDFs, uses 200 DPI which is sufficient for receipt OCR
|
||||
and reduces processing time by ~50% vs 300 DPI.
|
||||
"""
|
||||
try:
|
||||
# Try as regular image first
|
||||
nparr = np.frombuffer(image_bytes, np.uint8)
|
||||
@@ -180,18 +254,21 @@ def _decode_image(image_bytes: bytes) -> Optional[np.ndarray]:
|
||||
if image is not None:
|
||||
return image
|
||||
|
||||
# Try as PDF
|
||||
# Try as PDF - use 200 DPI for faster processing (sufficient for receipts)
|
||||
try:
|
||||
import pdf2image
|
||||
from PIL import Image
|
||||
|
||||
images = pdf2image.convert_from_bytes(image_bytes, dpi=300)
|
||||
# 200 DPI is sufficient for receipt text recognition
|
||||
# 300 DPI was overkill and slowed down processing
|
||||
images = pdf2image.convert_from_bytes(image_bytes, dpi=200)
|
||||
if images:
|
||||
# Convert first page to numpy array
|
||||
pil_img = images[0]
|
||||
print(f"[Worker {os.getpid()}] PDF decoded: {pil_img.width}x{pil_img.height} @ 200 DPI", flush=True)
|
||||
return np.array(pil_img)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"[Worker {os.getpid()}] PDF decode error: {e}", flush=True)
|
||||
|
||||
return None
|
||||
|
||||
@@ -270,83 +347,275 @@ def _process_tesseract_only(
|
||||
return extraction, raw_texts
|
||||
|
||||
|
||||
def _process_adaptive(
|
||||
def _process_doctr_only(
|
||||
image: np.ndarray,
|
||||
paddle_engine,
|
||||
tesseract_engine,
|
||||
doctr_engine,
|
||||
preprocessor,
|
||||
extractor
|
||||
) -> Tuple[Any, List[str]]:
|
||||
"""
|
||||
Adaptive multi-pass OCR processing.
|
||||
Process using docTR only (light + medium preprocessing).
|
||||
|
||||
Strategy:
|
||||
1. PaddleOCR Light - fastest, best for clear PDFs
|
||||
2. PaddleOCR Medium - if Light incomplete
|
||||
3. Tesseract - complement missing fields only
|
||||
|
||||
Returns:
|
||||
Tuple of (extraction_result, raw_texts_list)
|
||||
docTR uses EXACT same preprocessing as PaddleOCR for consistency.
|
||||
"""
|
||||
raw_texts = []
|
||||
extraction = None
|
||||
|
||||
# === STEP 1: PaddleOCR Light ===
|
||||
if paddle_engine:
|
||||
print("[OCR] Step 1: PaddleOCR + Light", flush=True)
|
||||
light_img = preprocessor.preprocess_light(image)
|
||||
paddle_light = _paddle_recognize(paddle_engine, light_img)
|
||||
if doctr_engine is None:
|
||||
return None, ["docTR not available"]
|
||||
|
||||
if paddle_light and paddle_light.text:
|
||||
extraction = extractor.extract(paddle_light.text)
|
||||
extraction.ocr_engine = "paddle-light"
|
||||
raw_texts.append(f"=== PaddleOCR Light (conf: {paddle_light.confidence:.0%}) ===\n{paddle_light.text}")
|
||||
# Step 1: Light preprocessing (same as PaddleOCR)
|
||||
print("[OCR] Step 1: docTR + Light", flush=True)
|
||||
light_img = preprocessor.preprocess_light(image)
|
||||
doctr_light = _doctr_recognize(doctr_engine, light_img)
|
||||
|
||||
if _is_extraction_complete(extraction):
|
||||
print("[OCR] Early exit - all fields found in Step 1", flush=True)
|
||||
return extraction, raw_texts
|
||||
if doctr_light and doctr_light.text:
|
||||
extraction = extractor.extract(doctr_light.text)
|
||||
extraction.ocr_engine = "doctr-light"
|
||||
raw_texts.append(f"=== docTR Light (conf: {doctr_light.confidence:.0%}) ===\n{doctr_light.text}")
|
||||
|
||||
# === STEP 2: PaddleOCR Medium ===
|
||||
if paddle_engine:
|
||||
print("[OCR] Step 2: PaddleOCR + Medium", flush=True)
|
||||
medium_img = preprocessor.preprocess_medium(image)
|
||||
paddle_medium = _paddle_recognize(paddle_engine, medium_img)
|
||||
if _is_extraction_complete(extraction):
|
||||
return extraction, raw_texts
|
||||
|
||||
if paddle_medium and paddle_medium.text:
|
||||
extraction_medium = extractor.extract(paddle_medium.text)
|
||||
extraction_medium.ocr_engine = "paddle-medium"
|
||||
raw_texts.append(f"=== PaddleOCR Medium (conf: {paddle_medium.confidence:.0%}) ===\n{paddle_medium.text}")
|
||||
# Step 2: Medium preprocessing (same as PaddleOCR)
|
||||
print("[OCR] Step 2: docTR + Medium", flush=True)
|
||||
medium_img = preprocessor.preprocess_medium(image)
|
||||
doctr_medium = _doctr_recognize(doctr_engine, medium_img)
|
||||
|
||||
if extraction:
|
||||
extraction = _merge_extractions(extraction, extraction_medium)
|
||||
extraction.ocr_engine = "paddle-adaptive"
|
||||
else:
|
||||
extraction = extraction_medium
|
||||
if doctr_medium and doctr_medium.text:
|
||||
extraction_medium = extractor.extract(doctr_medium.text)
|
||||
extraction_medium.ocr_engine = "doctr-medium"
|
||||
raw_texts.append(f"=== docTR Medium (conf: {doctr_medium.confidence:.0%}) ===\n{doctr_medium.text}")
|
||||
|
||||
if _is_extraction_complete(extraction):
|
||||
print("[OCR] Early exit - all fields found after Step 2", flush=True)
|
||||
return extraction, raw_texts
|
||||
|
||||
# === STEP 3: Tesseract (complement only) ===
|
||||
if tesseract_engine:
|
||||
print("[OCR] Step 3: Tesseract complement", flush=True)
|
||||
tesseract_img = preprocessor.preprocess_for_tesseract(image)
|
||||
tesseract_result = tesseract_engine.recognize(tesseract_img)
|
||||
|
||||
if tesseract_result and tesseract_result.text:
|
||||
extraction_tess = extractor.extract(tesseract_result.text)
|
||||
extraction_tess.ocr_engine = "tesseract"
|
||||
raw_texts.append(f"=== Tesseract (conf: {tesseract_result.confidence:.0%}) ===\n{tesseract_result.text}")
|
||||
|
||||
if extraction:
|
||||
extraction = _complement_extraction(extraction, extraction_tess)
|
||||
extraction.ocr_engine = "adaptive-full"
|
||||
else:
|
||||
extraction = extraction_tess
|
||||
if extraction:
|
||||
extraction = _merge_extractions(extraction, extraction_medium)
|
||||
extraction.ocr_engine = "doctr-adaptive"
|
||||
else:
|
||||
extraction = extraction_medium
|
||||
|
||||
return extraction, raw_texts
|
||||
|
||||
|
||||
def _process_doctr_plus(
|
||||
image: np.ndarray,
|
||||
doctr_engine,
|
||||
preprocessor,
|
||||
extractor
|
||||
) -> Tuple[Any, List[str]]:
|
||||
"""
|
||||
docTR Plus - Optimized 2-tier sequential processing with early exit.
|
||||
|
||||
Architecture:
|
||||
- Tier 1: Light preprocessing (~4-5s)
|
||||
→ Early exit if confidence >= 0.75 AND all fields valid AND cross-validations pass
|
||||
- Tier 2: Medium preprocessing (only if Tier 1 insufficient, ~4-5s additional)
|
||||
→ Merge with Tier 1 results
|
||||
→ Mark for review if still problems
|
||||
|
||||
Performance:
|
||||
- Fast path (80% receipts): ~4-5s (Tier 1 only)
|
||||
- Slow path (20% receipts): ~8-9s (Tier 1 + Tier 2)
|
||||
- Average: ~5-6s
|
||||
|
||||
Returns:
|
||||
Tuple of (extraction_result, raw_texts_list)
|
||||
extraction_result.needs_review = True if validation issues remain
|
||||
"""
|
||||
raw_texts = []
|
||||
extraction = None
|
||||
|
||||
if doctr_engine is None:
|
||||
return None, ["docTR not available"]
|
||||
|
||||
# ========== TIER 1: Light Preprocessing ==========
|
||||
print("[docTR+] TIER 1: Light preprocessing", flush=True)
|
||||
import time
|
||||
tier1_start = time.time()
|
||||
|
||||
light_img = preprocessor.preprocess_light(image)
|
||||
doctr_light = _doctr_recognize(doctr_engine, light_img)
|
||||
|
||||
tier1_time = time.time() - tier1_start
|
||||
print(f"[docTR+] TIER 1 completed in {tier1_time:.1f}s", flush=True)
|
||||
|
||||
if doctr_light and doctr_light.text:
|
||||
extraction = extractor.extract(doctr_light.text)
|
||||
extraction.ocr_engine = "doctr-plus-light"
|
||||
raw_texts.append(f"=== docTR+ Tier1/Light (conf: {doctr_light.confidence:.0%}) ===\n{doctr_light.text}")
|
||||
|
||||
# Early Exit Check: confidence >= 0.75 + cross-validations
|
||||
if _is_extraction_valid_for_early_exit(extraction, min_confidence=0.75):
|
||||
print(f"[docTR+] EARLY EXIT - Tier 1 sufficient (conf: {extraction.overall_confidence:.0%})", flush=True)
|
||||
extraction.ocr_engine = "doctr-plus"
|
||||
return extraction, raw_texts
|
||||
|
||||
print(f"[docTR+] Tier 1 incomplete or validation failed, proceeding to Tier 2...", flush=True)
|
||||
|
||||
# ========== TIER 2: Medium Preprocessing (only if needed) ==========
|
||||
print("[docTR+] TIER 2: Medium preprocessing", flush=True)
|
||||
tier2_start = time.time()
|
||||
|
||||
medium_img = preprocessor.preprocess_medium(image)
|
||||
doctr_medium = _doctr_recognize(doctr_engine, medium_img)
|
||||
|
||||
tier2_time = time.time() - tier2_start
|
||||
print(f"[docTR+] TIER 2 completed in {tier2_time:.1f}s", flush=True)
|
||||
|
||||
if doctr_medium and doctr_medium.text:
|
||||
extraction_medium = extractor.extract(doctr_medium.text)
|
||||
extraction_medium.ocr_engine = "doctr-plus-medium"
|
||||
raw_texts.append(f"=== docTR+ Tier2/Medium (conf: {doctr_medium.confidence:.0%}) ===\n{doctr_medium.text}")
|
||||
|
||||
if extraction:
|
||||
# Merge Tier 1 + Tier 2 results
|
||||
extraction = _merge_extractions(extraction, extraction_medium)
|
||||
else:
|
||||
extraction = extraction_medium
|
||||
|
||||
# ========== FINAL VALIDATION ==========
|
||||
if extraction:
|
||||
extraction.ocr_engine = "doctr-plus"
|
||||
|
||||
# Mark for review if validation still fails after both tiers
|
||||
passes_validation, penalty, errors = _quick_cross_validate(extraction)
|
||||
|
||||
if not passes_validation or extraction.overall_confidence < 0.75:
|
||||
# Mark for human review using existing fields
|
||||
extraction.needs_manual_review = True
|
||||
|
||||
if extraction.overall_confidence < 0.75:
|
||||
extraction.validation_warnings.append(f"Low confidence: {extraction.overall_confidence:.0%}")
|
||||
|
||||
if not extraction.amount:
|
||||
extraction.validation_errors.append("TOTAL not detected")
|
||||
if not extraction.cui:
|
||||
extraction.validation_warnings.append("CUI not detected")
|
||||
if not extraction.tva_total and not extraction.tva_entries:
|
||||
extraction.validation_warnings.append("TVA not detected")
|
||||
if not extraction.receipt_date:
|
||||
extraction.validation_warnings.append("Date not detected")
|
||||
|
||||
# Add cross-validation errors
|
||||
extraction.validation_errors.extend(errors)
|
||||
|
||||
print(f"[docTR+] Marked for review: {extraction.validation_errors + extraction.validation_warnings}", flush=True)
|
||||
else:
|
||||
extraction.needs_manual_review = False
|
||||
|
||||
total_time = tier1_time + (tier2_time if 'tier2_time' in dir() else 0)
|
||||
print(f"[docTR+] Total processing time: {total_time:.1f}s", flush=True)
|
||||
|
||||
return extraction, raw_texts
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# VALIDATION HELPERS (used by doctr_plus for early exit decisions)
|
||||
# =============================================================================
|
||||
|
||||
def _quick_cross_validate(extraction) -> tuple[bool, float, list[str]]:
|
||||
"""
|
||||
Quick cross-validation for OCR results.
|
||||
|
||||
Checks critical field correlations to detect obvious OCR errors.
|
||||
Used by doctr_plus to decide whether to proceed to Tier 2 or exit early.
|
||||
|
||||
Returns:
|
||||
Tuple of (passes_validation, confidence_penalty, error_messages)
|
||||
"""
|
||||
try:
|
||||
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
|
||||
|
||||
if extraction is None:
|
||||
return False, 1.0, ["No extraction result"]
|
||||
|
||||
# Convert extraction to dict for validation
|
||||
# Build TVA entries dict for TVAEntriesSumRule (expects {code: amount})
|
||||
tva_entries_dict = {}
|
||||
if extraction.tva_entries:
|
||||
for entry in extraction.tva_entries:
|
||||
if isinstance(entry, dict):
|
||||
code = entry.get('code', 'A')
|
||||
amount = entry.get('amount', 0)
|
||||
try:
|
||||
tva_entries_dict[code] = float(amount)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
|
||||
validation_data = {
|
||||
"amount": float(extraction.amount) if extraction.amount else None,
|
||||
"tva": float(extraction.tva_total) if extraction.tva_total else None,
|
||||
"tva_entries": tva_entries_dict, # For TVAEntriesSumRule: {code: amount}
|
||||
"cui": extraction.cui, # For CUI checksum validation
|
||||
}
|
||||
|
||||
# Also pass raw tva_entries for TVABasedTotalRule (for rate detection)
|
||||
if extraction.tva_entries:
|
||||
validation_data['tva_entries_raw'] = extraction.tva_entries
|
||||
|
||||
# Add payment methods if available (for TOTAL vs CARD+CASH validation)
|
||||
if extraction.payment_methods:
|
||||
try:
|
||||
card_amount = sum(
|
||||
float(p.get('amount', 0) if isinstance(p, dict) else 0)
|
||||
for p in extraction.payment_methods
|
||||
if isinstance(p, dict) and p.get('method') == 'CARD'
|
||||
)
|
||||
cash_amount = sum(
|
||||
float(p.get('amount', 0) if isinstance(p, dict) else 0)
|
||||
for p in extraction.payment_methods
|
||||
if isinstance(p, dict) and p.get('method') == 'NUMERAR'
|
||||
)
|
||||
validation_data['card_amount'] = card_amount
|
||||
validation_data['cash_amount'] = cash_amount
|
||||
except Exception as e:
|
||||
print(f"[Worker {os.getpid()}] Payment method validation error: {e}", flush=True)
|
||||
|
||||
# Run quick validation
|
||||
validator = OCRValidationEngine()
|
||||
return validator.quick_validate_for_hybrid(validation_data)
|
||||
|
||||
except Exception as e:
|
||||
# Never crash the process on validation errors
|
||||
print(f"[Worker {os.getpid()}] Cross-validation error: {e}", flush=True)
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
# Return "passes" to allow processing to continue
|
||||
return True, 0.0, [f"Validation skipped due to error: {str(e)}"]
|
||||
|
||||
|
||||
def _is_extraction_valid_for_early_exit(extraction, min_confidence: float = 0.85) -> bool:
|
||||
"""
|
||||
Check if extraction is valid for early exit in doctr_plus.
|
||||
|
||||
Combines confidence check with cross-validation to prevent
|
||||
early exit on OCR errors (e.g., wrong TOTAL but correct TVA).
|
||||
|
||||
Returns:
|
||||
True only if:
|
||||
1. Overall confidence >= min_confidence
|
||||
2. Critical fields are present (AMOUNT, DATE, CUI)
|
||||
3. Cross-validation passes (TOTAL matches TVA calculation, or no TVA)
|
||||
"""
|
||||
try:
|
||||
# First check basic completeness (relaxed for early exit)
|
||||
if not _is_extraction_complete(extraction, min_confidence, for_early_exit=True):
|
||||
return False
|
||||
|
||||
# Then run cross-validation
|
||||
passes_validation, penalty, errors = _quick_cross_validate(extraction)
|
||||
|
||||
if not passes_validation:
|
||||
print(f"[Early Exit] BLOCKED: cross-validation failed: {errors}", flush=True)
|
||||
return False
|
||||
|
||||
print(f"[Early Exit] OK: conf={extraction.overall_confidence:.0%}, validation passed", flush=True)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
# Never crash on validation - just continue to next engine
|
||||
print(f"[Worker {os.getpid()}] Early exit check error: {e}", flush=True)
|
||||
return False # Continue to next engine on error
|
||||
|
||||
def _paddle_recognize(paddle_engine, image: np.ndarray) -> Optional[OCRResult]:
|
||||
"""Run PaddleOCR recognition on image."""
|
||||
try:
|
||||
@@ -388,34 +657,191 @@ def _paddle_recognize(paddle_engine, image: np.ndarray) -> Optional[OCRResult]:
|
||||
return None
|
||||
|
||||
|
||||
def _is_extraction_complete(ext, min_confidence: float = 0.85) -> bool:
|
||||
"""Check if extraction has all required fields."""
|
||||
def _doctr_recognize(doctr_engine, image: np.ndarray) -> Optional[OCRResult]:
|
||||
"""
|
||||
Run docTR recognition on image.
|
||||
|
||||
docTR requires RGB images, handles conversion automatically.
|
||||
Uses same preprocessing as PaddleOCR for consistent results.
|
||||
"""
|
||||
try:
|
||||
# docTR requires RGB images
|
||||
if len(image.shape) == 2:
|
||||
# Convert grayscale to RGB
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
||||
elif image.shape[2] == 3:
|
||||
# Convert BGR (OpenCV) to RGB
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
elif image.shape[2] == 4:
|
||||
# Convert RGBA to RGB
|
||||
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
|
||||
|
||||
# docTR expects a list of numpy arrays (pages)
|
||||
result = doctr_engine([image])
|
||||
|
||||
if not result or not result.pages:
|
||||
return OCRResult(text="", confidence=0.0, boxes=[], engine="doctr")
|
||||
|
||||
# Extract text from all pages
|
||||
all_texts = []
|
||||
all_confidences = []
|
||||
boxes = []
|
||||
|
||||
for page in result.pages:
|
||||
for block in page.blocks:
|
||||
for line in block.lines:
|
||||
line_text = ' '.join(word.value for word in line.words)
|
||||
line_confidence = sum(w.confidence for w in line.words) / len(line.words) if line.words else 0.0
|
||||
all_texts.append(line_text)
|
||||
all_confidences.append(line_confidence)
|
||||
|
||||
# Store word-level boxes
|
||||
for word in line.words:
|
||||
boxes.append({
|
||||
'text': word.value,
|
||||
'confidence': float(word.confidence),
|
||||
'box': word.geometry # (xmin, ymin), (xmax, ymax)
|
||||
})
|
||||
|
||||
text_result = '\n'.join(all_texts)
|
||||
avg_conf = sum(all_confidences) / len(all_confidences) if all_confidences else 0.0
|
||||
|
||||
return OCRResult(
|
||||
text=text_result,
|
||||
confidence=float(avg_conf),
|
||||
boxes=boxes,
|
||||
engine="doctr"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Worker] docTR error: {e}", flush=True)
|
||||
return None
|
||||
|
||||
|
||||
def _is_extraction_complete(ext, min_confidence: float = 0.85, for_early_exit: bool = False) -> bool:
|
||||
"""
|
||||
Check if extraction has required fields.
|
||||
|
||||
Args:
|
||||
ext: Extraction result
|
||||
min_confidence: Minimum overall confidence
|
||||
for_early_exit: If True, use relaxed criteria (AMOUNT + DATE + CUI required)
|
||||
If False, require all fields (strict mode for final validation)
|
||||
|
||||
Returns:
|
||||
True if extraction meets completeness criteria
|
||||
"""
|
||||
# Check confidence first
|
||||
if ext.overall_confidence < min_confidence:
|
||||
if for_early_exit:
|
||||
print(f"[Early Exit] BLOCKED: confidence {ext.overall_confidence:.0%} < {min_confidence:.0%}", flush=True)
|
||||
return False
|
||||
|
||||
has_number = bool(ext.receipt_number)
|
||||
has_date = bool(ext.receipt_date)
|
||||
has_amount = bool(ext.amount)
|
||||
has_tva = bool(ext.tva_total) or bool(ext.tva_entries)
|
||||
has_cui = bool(ext.cui)
|
||||
|
||||
return all([has_number, has_date, has_amount, has_tva, has_cui])
|
||||
if for_early_exit:
|
||||
# Relaxed criteria for early exit:
|
||||
# - AMOUNT is required (core field)
|
||||
# - DATE is required (needed for accounting)
|
||||
# - CUI is required (needed for supplier identification)
|
||||
# - TVA is NOT required (some receipts have 0% TVA)
|
||||
# - receipt_number is NOT required (often missing)
|
||||
required_ok = all([has_amount, has_date, has_cui])
|
||||
|
||||
if not required_ok:
|
||||
missing = []
|
||||
if not has_amount: missing.append("AMOUNT")
|
||||
if not has_date: missing.append("DATE")
|
||||
if not has_cui: missing.append("CUI")
|
||||
print(f"[Early Exit] BLOCKED: missing required fields: {', '.join(missing)}", flush=True)
|
||||
|
||||
return required_ok
|
||||
else:
|
||||
# Strict criteria for final validation (all fields required)
|
||||
has_number = bool(ext.receipt_number)
|
||||
return all([has_number, has_date, has_amount, has_tva, has_cui])
|
||||
|
||||
|
||||
def _merge_extractions(primary, secondary):
|
||||
"""Merge two extractions, picking best fields from each."""
|
||||
"""Merge two extractions, picking best fields from each.
|
||||
|
||||
Primary should be the higher-quality engine (e.g., docTR).
|
||||
Secondary is the fallback engine (e.g., Tesseract).
|
||||
|
||||
Priority logic:
|
||||
- AMOUNT: TVA validation wins over confidence. If both valid or both invalid,
|
||||
uses confidence (or TVA diff for invalid cases).
|
||||
- DATE/CUI: Validation-based, then confidence, then primary wins ties.
|
||||
- OTHER FIELDS: Primary wins when both have values.
|
||||
"""
|
||||
from backend.modules.data_entry.services.ocr_extractor import ExtractionResult
|
||||
|
||||
result = ExtractionResult()
|
||||
|
||||
# Amount - prefer higher confidence
|
||||
# Helper: Check if amount matches TVA calculation
|
||||
def amount_passes_tva_validation(amount, tva_total, tva_entries):
|
||||
if not amount or not tva_total:
|
||||
return False, 0.0
|
||||
try:
|
||||
tva_rate = 0.21 # Default Romanian TVA
|
||||
if tva_entries:
|
||||
for entry in tva_entries:
|
||||
if isinstance(entry, dict) and entry.get('percent'):
|
||||
tva_rate = float(entry['percent']) / 100.0
|
||||
break
|
||||
# Expected TOTAL = TVA / rate * (1 + rate)
|
||||
expected = float(tva_total) * (1 + tva_rate) / tva_rate
|
||||
actual = float(amount)
|
||||
diff_percent = abs(actual - expected) / expected if expected > 0 else 1.0
|
||||
return diff_percent < 0.03, diff_percent # 3% tolerance
|
||||
except:
|
||||
return False, 1.0
|
||||
|
||||
# Amount - prefer TVA-validated value over confidence
|
||||
if primary.amount and secondary.amount:
|
||||
if primary.confidence_amount >= secondary.confidence_amount:
|
||||
result.amount = primary.amount
|
||||
result.confidence_amount = primary.confidence_amount
|
||||
else:
|
||||
# Get TVA from the one with entries, or use any available
|
||||
tva_total = primary.tva_total or secondary.tva_total
|
||||
tva_entries = primary.tva_entries or secondary.tva_entries
|
||||
|
||||
primary_valid, primary_diff = amount_passes_tva_validation(
|
||||
primary.amount, tva_total, tva_entries
|
||||
)
|
||||
secondary_valid, secondary_diff = amount_passes_tva_validation(
|
||||
secondary.amount, tva_total, tva_entries
|
||||
)
|
||||
|
||||
print(f"[Merge] Amount comparison: primary={primary.amount} (valid={primary_valid}, diff={primary_diff:.1%}), "
|
||||
f"secondary={secondary.amount} (valid={secondary_valid}, diff={secondary_diff:.1%})", flush=True)
|
||||
|
||||
if secondary_valid and not primary_valid:
|
||||
# Secondary passes validation, primary doesn't - use secondary!
|
||||
print(f"[Merge] Using secondary amount {secondary.amount} (passes TVA validation)", flush=True)
|
||||
result.amount = secondary.amount
|
||||
result.confidence_amount = secondary.confidence_amount
|
||||
elif primary_valid and not secondary_valid:
|
||||
# Primary passes validation
|
||||
result.amount = primary.amount
|
||||
result.confidence_amount = primary.confidence_amount
|
||||
elif primary_valid and secondary_valid:
|
||||
# Both valid - use higher confidence
|
||||
if primary.confidence_amount >= secondary.confidence_amount:
|
||||
result.amount = primary.amount
|
||||
result.confidence_amount = primary.confidence_amount
|
||||
else:
|
||||
result.amount = secondary.amount
|
||||
result.confidence_amount = secondary.confidence_amount
|
||||
else:
|
||||
# Neither valid - use the one closer to TVA calculation
|
||||
if secondary_diff < primary_diff:
|
||||
print(f"[Merge] Neither valid, using secondary {secondary.amount} (closer to TVA)", flush=True)
|
||||
result.amount = secondary.amount
|
||||
result.confidence_amount = secondary.confidence_amount
|
||||
else:
|
||||
result.amount = primary.amount
|
||||
result.confidence_amount = primary.confidence_amount
|
||||
elif primary.amount:
|
||||
result.amount = primary.amount
|
||||
result.confidence_amount = primary.confidence_amount
|
||||
@@ -438,13 +864,15 @@ def _merge_extractions(primary, secondary):
|
||||
result.receipt_date = secondary.receipt_date
|
||||
result.confidence_date = secondary.confidence_date
|
||||
|
||||
# CUI - prefer valid format
|
||||
# CUI - prefer valid format and version with RO prefix
|
||||
# Use CUIChecksumRule static methods (single source of truth)
|
||||
from backend.modules.data_entry.services.ocr.validation import CUIChecksumRule
|
||||
|
||||
def is_valid_cui(cui):
|
||||
if not cui:
|
||||
return False
|
||||
import re
|
||||
cui_clean = re.sub(r'^RO', '', cui.upper())
|
||||
return bool(re.match(r'^\d{6,10}$', cui_clean))
|
||||
digits = CUIChecksumRule.extract_digits(cui)
|
||||
return len(digits) >= 6 and len(digits) <= 10
|
||||
|
||||
if primary.cui and secondary.cui:
|
||||
if is_valid_cui(primary.cui) and not is_valid_cui(secondary.cui):
|
||||
@@ -452,22 +880,27 @@ def _merge_extractions(primary, secondary):
|
||||
elif is_valid_cui(secondary.cui) and not is_valid_cui(primary.cui):
|
||||
result.cui = secondary.cui
|
||||
else:
|
||||
result.cui = primary.cui
|
||||
# Both valid - prefer the one with RO prefix if digits match
|
||||
primary_digits = CUIChecksumRule.extract_digits(primary.cui)
|
||||
secondary_digits = CUIChecksumRule.extract_digits(secondary.cui)
|
||||
if primary_digits == secondary_digits:
|
||||
if CUIChecksumRule.has_ro_prefix(secondary.cui) and not CUIChecksumRule.has_ro_prefix(primary.cui):
|
||||
result.cui = secondary.cui # Prefer version with RO
|
||||
print(f"[CUI Complement] Preferring secondary with RO: {secondary.cui}", flush=True)
|
||||
else:
|
||||
result.cui = primary.cui
|
||||
else:
|
||||
result.cui = primary.cui
|
||||
elif primary.cui:
|
||||
result.cui = primary.cui
|
||||
elif secondary.cui:
|
||||
result.cui = secondary.cui
|
||||
|
||||
# TVA entries
|
||||
# TVA entries - ALWAYS prefer primary (docTR) when both have entries
|
||||
if primary.tva_entries and secondary.tva_entries:
|
||||
primary_total = sum(e.get('amount', Decimal('0')) for e in primary.tva_entries)
|
||||
secondary_total = sum(e.get('amount', Decimal('0')) for e in secondary.tva_entries)
|
||||
if primary_total >= secondary_total:
|
||||
result.tva_entries = primary.tva_entries
|
||||
result.tva_total = primary.tva_total
|
||||
else:
|
||||
result.tva_entries = secondary.tva_entries
|
||||
result.tva_total = secondary.tva_total
|
||||
# Always use primary (docTR) - higher quality OCR
|
||||
result.tva_entries = primary.tva_entries
|
||||
result.tva_total = primary.tva_total
|
||||
elif primary.tva_entries:
|
||||
result.tva_entries = primary.tva_entries
|
||||
result.tva_total = primary.tva_total
|
||||
@@ -483,12 +916,36 @@ def _merge_extractions(primary, secondary):
|
||||
result.address = primary.address or secondary.address
|
||||
result.items_count = primary.items_count or secondary.items_count
|
||||
result.payment_methods = primary.payment_methods or secondary.payment_methods
|
||||
result.suggested_payment_mode = getattr(primary, 'suggested_payment_mode', None) or getattr(secondary, 'suggested_payment_mode', None)
|
||||
|
||||
# Client fields
|
||||
result.client_name = primary.client_name or secondary.client_name
|
||||
result.client_cui = primary.client_cui or secondary.client_cui
|
||||
result.client_address = primary.client_address or secondary.client_address
|
||||
|
||||
# Confidence fields - preserve from primary or pick best
|
||||
if primary.confidence_vendor >= secondary.confidence_vendor:
|
||||
result.confidence_vendor = primary.confidence_vendor
|
||||
else:
|
||||
result.confidence_vendor = secondary.confidence_vendor
|
||||
|
||||
if hasattr(primary, 'confidence_client') and hasattr(secondary, 'confidence_client'):
|
||||
if primary.confidence_client >= secondary.confidence_client:
|
||||
result.confidence_client = primary.confidence_client
|
||||
else:
|
||||
result.confidence_client = secondary.confidence_client
|
||||
|
||||
# Raw text - combine both for debugging/display
|
||||
raw_texts = []
|
||||
if primary.raw_text:
|
||||
raw_texts.append(primary.raw_text)
|
||||
if secondary.raw_text and secondary.raw_text != primary.raw_text:
|
||||
raw_texts.append(secondary.raw_text)
|
||||
result.raw_text = '\n---\n'.join(raw_texts) if raw_texts else ''
|
||||
|
||||
# Note: overall_confidence is a computed @property on ExtractionResult
|
||||
# It automatically calculates from confidence_amount, confidence_date, confidence_vendor
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -557,6 +1014,7 @@ def _extraction_to_dict(extraction) -> dict:
|
||||
"address": extraction.address,
|
||||
"items_count": extraction.items_count,
|
||||
"payment_methods": extraction.payment_methods,
|
||||
"suggested_payment_mode": getattr(extraction, 'suggested_payment_mode', None),
|
||||
# Client data
|
||||
"client_name": extraction.client_name,
|
||||
"client_cui": extraction.client_cui,
|
||||
|
||||
@@ -385,8 +385,81 @@ class CUIChecksumRule(ValidationRule):
|
||||
|
||||
result = rule.validate({"cui": "R01879855"})
|
||||
# result.is_valid = False (checksum mismatch)
|
||||
|
||||
Static methods available for direct use:
|
||||
CUIChecksumRule.calculate_checksum("1056260") -> 0
|
||||
CUIChecksumRule.validate_checksum("10562600") -> True
|
||||
CUIChecksumRule.has_ro_prefix("RO10562600") -> True
|
||||
"""
|
||||
|
||||
# Fixed multipliers for 9 positions (Romanian Mod 11)
|
||||
MULTIPLIERS = [7, 5, 3, 2, 1, 7, 5, 3, 2]
|
||||
|
||||
@staticmethod
|
||||
def calculate_checksum(cui_base: str) -> int:
|
||||
"""Calculate expected CUI checksum using Romanian Mod 11 algorithm.
|
||||
|
||||
Args:
|
||||
cui_base: CUI digits WITHOUT the checksum digit (last digit)
|
||||
|
||||
Returns:
|
||||
Expected checksum digit (0-9), or -1 if invalid input
|
||||
"""
|
||||
if not cui_base or not cui_base.isdigit():
|
||||
return -1
|
||||
|
||||
# Pad base to 9 digits from LEFT
|
||||
base_padded = cui_base.zfill(9)
|
||||
base_digits = [int(d) for d in base_padded]
|
||||
|
||||
# Calculate weighted sum
|
||||
weighted_sum = sum(d * m for d, m in zip(base_digits, CUIChecksumRule.MULTIPLIERS))
|
||||
|
||||
# Calculate checksum
|
||||
checksum = (weighted_sum * 10) % 11
|
||||
if checksum == 10:
|
||||
checksum = 0
|
||||
|
||||
return checksum
|
||||
|
||||
@staticmethod
|
||||
def validate_checksum(cui_digits: str) -> bool:
|
||||
"""Check if CUI checksum is valid.
|
||||
|
||||
Args:
|
||||
cui_digits: Full CUI digits (including checksum as last digit)
|
||||
|
||||
Returns:
|
||||
True if checksum is valid, False otherwise
|
||||
"""
|
||||
if not cui_digits or len(cui_digits) < 6 or not cui_digits.isdigit():
|
||||
return False
|
||||
|
||||
base = cui_digits[:-1]
|
||||
declared = int(cui_digits[-1])
|
||||
expected = CUIChecksumRule.calculate_checksum(base)
|
||||
|
||||
return expected == declared
|
||||
|
||||
@staticmethod
|
||||
def has_ro_prefix(cui: str) -> bool:
|
||||
"""Check if CUI has RO prefix (proper format for VAT payers)."""
|
||||
if not cui:
|
||||
return False
|
||||
return cui.upper().strip().startswith('RO')
|
||||
|
||||
@staticmethod
|
||||
def extract_digits(cui: str) -> str:
|
||||
"""Extract digits from CUI, removing RO/R0 prefix."""
|
||||
if not cui:
|
||||
return ""
|
||||
cui = cui.strip().upper()
|
||||
if cui.startswith("RO"):
|
||||
cui = cui[2:]
|
||||
elif cui.startswith("R0"): # Fix OCR error R0 → RO
|
||||
cui = cui[2:]
|
||||
return ''.join(c for c in cui if c.isdigit())
|
||||
|
||||
@property
|
||||
def rule_name(self) -> str:
|
||||
return "CUI Checksum Check (Mod 11)"
|
||||
@@ -400,15 +473,11 @@ class CUIChecksumRule(ValidationRule):
|
||||
message="No CUI to validate"
|
||||
)
|
||||
|
||||
# Normalize: remove RO/R0 prefix
|
||||
cui_clean = cui.strip().upper()
|
||||
if cui_clean.startswith("RO"):
|
||||
cui_clean = cui_clean[2:]
|
||||
elif cui_clean.startswith("R0"):
|
||||
cui_clean = cui_clean[2:]
|
||||
# Use static method to extract digits
|
||||
cui_clean = CUIChecksumRule.extract_digits(cui)
|
||||
|
||||
# Check format first
|
||||
if not cui_clean.isdigit():
|
||||
if not cui_clean:
|
||||
return ValidationResult(
|
||||
is_valid=True, # Don't fail checksum if format invalid (handled by CUIFormatRule)
|
||||
message="CUI format invalid, skipping checksum"
|
||||
@@ -420,28 +489,15 @@ class CUIChecksumRule(ValidationRule):
|
||||
message="CUI length invalid, skipping checksum"
|
||||
)
|
||||
|
||||
# Extract digits
|
||||
digits = [int(d) for d in cui_clean]
|
||||
checksum_declared = digits[-1]
|
||||
base_digits = digits[:-1]
|
||||
|
||||
# Multipliers (trim to match base_digits length)
|
||||
multipliers = [7, 5, 3, 2, 1, 7, 5, 3, 2]
|
||||
multipliers = multipliers[:len(base_digits)]
|
||||
|
||||
# Calculate weighted sum
|
||||
weighted_sum = sum(d * m for d, m in zip(base_digits, multipliers))
|
||||
|
||||
# Calculate expected checksum
|
||||
checksum_calculated = (weighted_sum * 10) % 11
|
||||
if checksum_calculated == 10:
|
||||
checksum_calculated = 0
|
||||
|
||||
if checksum_calculated != checksum_declared:
|
||||
# Use static method to validate checksum
|
||||
if not CUIChecksumRule.validate_checksum(cui_clean):
|
||||
# Calculate expected for error message
|
||||
expected = CUIChecksumRule.calculate_checksum(cui_clean[:-1])
|
||||
declared = int(cui_clean[-1])
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
confidence_penalty=0.3,
|
||||
message=f"CUI '{cui}' checksum mismatch: expected {checksum_calculated}, got {checksum_declared}",
|
||||
message=f"CUI '{cui}' checksum mismatch: expected {expected}, got {declared}",
|
||||
severity="warning"
|
||||
)
|
||||
|
||||
@@ -451,6 +507,129 @@ class CUIChecksumRule(ValidationRule):
|
||||
)
|
||||
|
||||
|
||||
class TVABasedTotalRule(ValidationRule):
|
||||
"""Validate TOTAL using reverse calculation from TVA amount.
|
||||
|
||||
This is a CRITICAL validation that catches cases where OCR extracts
|
||||
wrong TOTAL but correct TVA. Since TVA = BASE * rate and TOTAL = BASE + TVA,
|
||||
we can calculate expected TOTAL from TVA alone.
|
||||
|
||||
Formula:
|
||||
Expected TOTAL = TVA / rate * (1 + rate)
|
||||
Or equivalently: Expected TOTAL = TVA * (1 + rate) / rate
|
||||
|
||||
For TVA rate 21%:
|
||||
Expected TOTAL = TVA / 0.21 * 1.21 = TVA * 5.7619
|
||||
|
||||
Example (benzina 27 oct):
|
||||
TVA = 49.58, rate = 21%
|
||||
Expected TOTAL = 49.58 / 0.21 * 1.21 = 285.68
|
||||
Extracted TOTAL = 205.66 (WRONG!)
|
||||
Rule detects mismatch and flags for escalation
|
||||
|
||||
Usage in multi-tier processing (e.g., doctr_plus):
|
||||
If this rule fails, the engine should proceed to next tier
|
||||
instead of returning early with potentially wrong data.
|
||||
"""
|
||||
|
||||
def __init__(self, tolerance_percent: float = 0.02):
|
||||
"""
|
||||
Args:
|
||||
tolerance_percent: Allowed difference as percentage (0.02 = 2%)
|
||||
"""
|
||||
self.tolerance_percent = tolerance_percent
|
||||
|
||||
@property
|
||||
def rule_name(self) -> str:
|
||||
return "TVA-Based Total Check"
|
||||
|
||||
def validate(self, data: dict[str, Any]) -> ValidationResult:
|
||||
total = data.get("amount")
|
||||
tva = data.get("tva")
|
||||
tva_entries = data.get("tva_entries", [])
|
||||
|
||||
if not total or not tva:
|
||||
return ValidationResult(
|
||||
is_valid=True,
|
||||
message="Insufficient data for TVA-based total validation"
|
||||
)
|
||||
|
||||
# Type safety
|
||||
try:
|
||||
total = float(total)
|
||||
tva = float(tva)
|
||||
except (TypeError, ValueError):
|
||||
return ValidationResult(
|
||||
is_valid=True,
|
||||
message="Non-numeric values, skipping TVA-based total validation"
|
||||
)
|
||||
|
||||
if tva <= 0 or total <= 0:
|
||||
return ValidationResult(
|
||||
is_valid=True,
|
||||
message="Zero or negative values, skipping TVA-based total validation"
|
||||
)
|
||||
|
||||
# Try to determine TVA rate from entries
|
||||
tva_rate = None
|
||||
|
||||
# Check tva_entries for rate information
|
||||
if tva_entries:
|
||||
for entry in tva_entries:
|
||||
if isinstance(entry, dict):
|
||||
percent = entry.get('percent')
|
||||
if percent:
|
||||
try:
|
||||
tva_rate = float(percent) / 100.0
|
||||
break
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
|
||||
# Fallback: try to calculate rate from TVA/TOTAL ratio
|
||||
if not tva_rate:
|
||||
# TVA = BASE * rate, TOTAL = BASE + TVA = BASE * (1 + rate)
|
||||
# TVA/TOTAL = rate / (1 + rate)
|
||||
# So rate = TVA / (TOTAL - TVA) = TVA / BASE
|
||||
base = total - tva
|
||||
if base > 0:
|
||||
calculated_rate = tva / base
|
||||
# Validate it's a reasonable Romanian TVA rate (5%, 9%, 19%, 21%)
|
||||
if 0.04 <= calculated_rate <= 0.25:
|
||||
tva_rate = calculated_rate
|
||||
|
||||
if not tva_rate:
|
||||
# Assume most common rate: 21%
|
||||
tva_rate = 0.21
|
||||
|
||||
# Calculate expected TOTAL from TVA
|
||||
# TVA = BASE * rate → BASE = TVA / rate
|
||||
# TOTAL = BASE + TVA = (TVA / rate) + TVA = TVA * (1 + 1/rate) = TVA * (1 + rate) / rate
|
||||
expected_total = tva * (1 + tva_rate) / tva_rate
|
||||
|
||||
# Calculate difference
|
||||
diff = abs(total - expected_total)
|
||||
diff_percent = diff / expected_total if expected_total > 0 else 1.0
|
||||
|
||||
if diff_percent > self.tolerance_percent:
|
||||
# Significant mismatch - OCR likely extracted TOTAL wrong
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
confidence_penalty=0.5, # High penalty - this is a critical error
|
||||
message=(
|
||||
f"TOTAL mismatch: Extracted {total:.2f} RON vs "
|
||||
f"TVA-calculated {expected_total:.2f} RON "
|
||||
f"(TVA={tva:.2f}, rate={tva_rate:.0%}, diff={diff_percent:.1%}). "
|
||||
f"Likely OCR error on TOTAL."
|
||||
),
|
||||
severity="error"
|
||||
)
|
||||
|
||||
return ValidationResult(
|
||||
is_valid=True,
|
||||
message=f"TOTAL {total:.2f} matches TVA-calculated {expected_total:.2f} (diff: {diff_percent:.1%})"
|
||||
)
|
||||
|
||||
|
||||
class InterOCRConsistencyRule(ValidationRule):
|
||||
"""Validate consistency between multiple OCR results.
|
||||
|
||||
@@ -562,6 +741,7 @@ class OCRValidationEngine:
|
||||
TVARatioRule(min_ratio=0.05, max_ratio=0.24),
|
||||
PaymentSumRule(tolerance=0.02),
|
||||
TVAEntriesSumRule(tolerance=0.02),
|
||||
TVABasedTotalRule(tolerance_percent=0.02), # Critical: detect TOTAL errors via TVA
|
||||
]
|
||||
|
||||
# Inter-OCR consistency rules
|
||||
@@ -699,39 +879,508 @@ class OCRValidationEngine:
|
||||
inter_ocr_ratios=inter_ocr_ratios
|
||||
)
|
||||
|
||||
def quick_validate_for_hybrid(self, extraction_result: dict[str, Any]) -> tuple[bool, float, list[str]]:
|
||||
"""Quick validation for early-exit decisions (e.g., doctr_plus Tier 1).
|
||||
|
||||
Runs critical cross-validation rules to detect obvious OCR errors.
|
||||
Used to decide whether to proceed to next processing tier or exit early.
|
||||
|
||||
Args:
|
||||
extraction_result: Extraction data dict with fields:
|
||||
- amount: Extracted TOTAL
|
||||
- tva: Extracted TVA total
|
||||
- tva_entries: List of TVA entries with rates
|
||||
|
||||
Returns:
|
||||
Tuple of (passes_validation, confidence_penalty, error_messages)
|
||||
- passes_validation: True if no critical errors detected
|
||||
- confidence_penalty: Cumulative penalty (0.0-1.0)
|
||||
- error_messages: List of validation error messages
|
||||
|
||||
Example usage:
|
||||
passes, penalty, errors = validation_engine.quick_validate_for_hybrid(extraction_data)
|
||||
if not passes:
|
||||
print(f"Validation failed: {errors}, proceeding to next tier")
|
||||
# Continue to next processing tier instead of early exit
|
||||
"""
|
||||
errors = []
|
||||
total_penalty = 0.0
|
||||
|
||||
# Critical rules for early-exit decision-making
|
||||
# These determine if we can trust the extraction or need to proceed to next tier
|
||||
critical_rules = [
|
||||
# Cross-field validations (most important for detecting OCR errors)
|
||||
TVABasedTotalRule(tolerance_percent=0.02), # Critical: detect TOTAL errors via TVA calculation
|
||||
PaymentSumRule(tolerance=0.05), # Cross-validate TOTAL vs CARD+CASH payments
|
||||
TVARatioRule(min_ratio=0.05, max_ratio=0.24), # TVA should be 5-24% of TOTAL
|
||||
TVAEntriesSumRule(tolerance=0.05), # Sum of TVA entries should match TVA total
|
||||
|
||||
# Format & checksum validations
|
||||
CUIChecksumRule(), # Validate CUI/CIF with Romanian Mod11 checksum algorithm
|
||||
CUIFormatRule(), # CUI should be 6-10 digits
|
||||
|
||||
# Sanity checks
|
||||
AmountRangeRule(min_amount=0.01, max_amount=100_000.0), # Reasonable amount range
|
||||
]
|
||||
|
||||
for rule in critical_rules:
|
||||
result = rule.validate(extraction_result)
|
||||
if not result.is_valid:
|
||||
errors.append(result.message)
|
||||
total_penalty += result.confidence_penalty
|
||||
|
||||
# Cap penalty at 1.0
|
||||
total_penalty = min(1.0, total_penalty)
|
||||
|
||||
passes = len(errors) == 0
|
||||
return passes, total_penalty, errors
|
||||
|
||||
# NOTE: _calculate_cui_checksum and _is_cui_checksum_valid removed
|
||||
# Use CUIChecksumRule.calculate_checksum() and CUIChecksumRule.validate_checksum() instead
|
||||
|
||||
@staticmethod
|
||||
def _repair_cui_checksum(cui_digits: str) -> Optional[str]:
|
||||
"""Try to repair CUI by attempting 1-digit corrections.
|
||||
|
||||
OCR often misreads similar-looking digits:
|
||||
- 5 ↔ 8 (most common in receipts)
|
||||
- 6 ↔ 0
|
||||
- 1 ↔ 7
|
||||
- 3 ↔ 8
|
||||
|
||||
Algorithm:
|
||||
1. Check middle positions first (2,3,4,5...) - OCR errors more common there
|
||||
2. Skip first digit (position 0) - usually reliable in CUI
|
||||
3. Check checksum digit (last position) last
|
||||
4. Prefer common OCR digit confusions (5↔8, 6↔0)
|
||||
|
||||
Args:
|
||||
cui_digits: Original CUI digits (without RO prefix)
|
||||
|
||||
Returns:
|
||||
Repaired CUI digits if 1-digit fix found, else None
|
||||
"""
|
||||
if len(cui_digits) < 6 or not cui_digits.isdigit():
|
||||
return None
|
||||
|
||||
# If already valid, return as-is
|
||||
if CUIChecksumRule.validate_checksum(cui_digits):
|
||||
return cui_digits
|
||||
|
||||
# Common OCR digit confusions (try these first)
|
||||
confusion_pairs = {
|
||||
'5': ['8', '6'], # 5 often misread as 8 or 6
|
||||
'8': ['5', '3', '0'], # 8 often misread as 5, 3, or 0
|
||||
'6': ['0', '8'], # 6 often misread as 0 or 8
|
||||
'0': ['6', '8'], # 0 often misread as 6 or 8
|
||||
'1': ['7', '4'], # 1 often misread as 7 or 4
|
||||
'7': ['1'], # 7 often misread as 1
|
||||
'3': ['8'], # 3 often misread as 8
|
||||
'4': ['1'], # 4 often misread as 1
|
||||
'2': ['7'], # 2 sometimes misread as 7
|
||||
'9': ['0'], # 9 sometimes misread as 0
|
||||
}
|
||||
|
||||
n = len(cui_digits)
|
||||
last_pos = n - 1 # checksum position
|
||||
|
||||
# Position check order: middle positions first, then position 1, then 0, then checksum
|
||||
# Skip position 0 (first digit) - it's usually reliable
|
||||
# Example for 8-digit CUI: [2,3,4,5,6, 1, 7(checksum)]
|
||||
middle_positions = list(range(2, last_pos)) # positions 2 to n-2
|
||||
position_order = middle_positions + [1, last_pos, 0] # check pos 0 last (rarely wrong)
|
||||
|
||||
for pos in position_order:
|
||||
if pos >= n:
|
||||
continue
|
||||
|
||||
original_digit = cui_digits[pos]
|
||||
|
||||
# Try common confusions first for this digit
|
||||
candidates = confusion_pairs.get(original_digit, [])
|
||||
# Then try all other digits
|
||||
all_digits = [d for d in '0123456789' if d != original_digit and d not in candidates]
|
||||
|
||||
for replacement in candidates + all_digits:
|
||||
candidate = cui_digits[:pos] + replacement + cui_digits[pos+1:]
|
||||
if CUIChecksumRule.validate_checksum(candidate):
|
||||
print(f"[CUI Repair] Fixed {cui_digits} → {candidate} (position {pos}: {original_digit}→{replacement})", flush=True)
|
||||
return candidate
|
||||
|
||||
# No single-digit fix found
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def normalize_cui(cui: Optional[str]) -> Optional[str]:
|
||||
"""Normalize CUI to RO prefix + digits format.
|
||||
"""Normalize CUI - fix OCR errors but preserve original format.
|
||||
|
||||
Rules:
|
||||
- R0 → RO (fix OCR error where O is read as 0)
|
||||
- Keep RO prefix if original had it (platitor TVA)
|
||||
- Do NOT add RO if original didn't have it (neplatitor TVA)
|
||||
- Try to repair 1-digit checksum errors (OCR mistakes like 5↔8)
|
||||
|
||||
Examples:
|
||||
10562600 → RO10562600
|
||||
45417955 → 45417955 (no prefix = neplatitor TVA, keep as-is)
|
||||
R010562600 → RO10562600 (fix R0 OCR error)
|
||||
RO10562600 → RO10562600 (unchanged)
|
||||
RO10862600 → RO10562600 (repaired: 8→5 at position 2)
|
||||
|
||||
Args:
|
||||
cui: Raw CUI string from OCR
|
||||
|
||||
Returns:
|
||||
Normalized CUI with RO prefix, or None if invalid
|
||||
Normalized CUI, or None if invalid
|
||||
"""
|
||||
if not cui:
|
||||
return None
|
||||
|
||||
cui = cui.strip().upper()
|
||||
|
||||
# Remove existing prefix if present
|
||||
# Check if original had RO/R0 prefix
|
||||
had_ro_prefix = cui.startswith("RO") or cui.startswith("R0")
|
||||
|
||||
# Extract digits
|
||||
if cui.startswith("RO"):
|
||||
cui = cui[2:]
|
||||
elif cui.startswith("R0"):
|
||||
cui = cui[2:]
|
||||
cui_digits = cui[2:]
|
||||
elif cui.startswith("R0"): # Fix OCR error R0 → RO
|
||||
cui_digits = cui[2:]
|
||||
else:
|
||||
cui_digits = cui
|
||||
|
||||
# Remove any non-digit characters
|
||||
cui_digits = ''.join(c for c in cui if c.isdigit())
|
||||
cui_digits = ''.join(c for c in cui_digits if c.isdigit())
|
||||
|
||||
# Validate length
|
||||
if len(cui_digits) < 6 or len(cui_digits) > 10:
|
||||
print(f"[CUI Normalize] Invalid length: {len(cui_digits)} digits (expected 6-10)", flush=True)
|
||||
return None
|
||||
|
||||
# Add RO prefix
|
||||
return f"RO{cui_digits}"
|
||||
# Try to repair checksum if invalid
|
||||
if not CUIChecksumRule.validate_checksum(cui_digits):
|
||||
repaired = OCRValidationEngine._repair_cui_checksum(cui_digits)
|
||||
if repaired:
|
||||
cui_digits = repaired
|
||||
|
||||
# Return with RO prefix only if original had it
|
||||
if had_ro_prefix:
|
||||
return f"RO{cui_digits}"
|
||||
else:
|
||||
return cui_digits
|
||||
|
||||
@staticmethod
|
||||
async def fuzzy_match_cui_from_db(
|
||||
cui: Optional[str],
|
||||
db_session
|
||||
) -> Optional[tuple[str, str]]:
|
||||
"""Fuzzy match CUI against database of known suppliers.
|
||||
|
||||
This function:
|
||||
1. Validates CUI checksum
|
||||
2. If valid, looks up in database (exact match)
|
||||
3. If invalid, tries 1-digit corrections and looks up each candidate
|
||||
4. Returns the first match found in database
|
||||
|
||||
Args:
|
||||
cui: Extracted CUI from OCR (may be invalid)
|
||||
db_session: SQLAlchemy async session for database lookups
|
||||
|
||||
Returns:
|
||||
Tuple of (corrected_cui, supplier_name) if found, else None
|
||||
|
||||
Usage in OCR extraction:
|
||||
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
|
||||
match = await OCRValidationEngine.fuzzy_match_cui_from_db(extracted_cui, session)
|
||||
if match:
|
||||
corrected_cui, supplier_name = match
|
||||
"""
|
||||
from sqlalchemy import select, or_
|
||||
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier
|
||||
|
||||
if not cui:
|
||||
return None
|
||||
|
||||
cui = cui.strip().upper()
|
||||
|
||||
# Check if original had RO/R0 prefix
|
||||
had_ro_prefix = cui.startswith("RO") or cui.startswith("R0")
|
||||
|
||||
# Extract digits
|
||||
if cui.startswith("RO"):
|
||||
cui_digits = cui[2:]
|
||||
elif cui.startswith("R0"): # Fix OCR error R0 → RO
|
||||
cui_digits = cui[2:]
|
||||
else:
|
||||
cui_digits = cui
|
||||
|
||||
# Remove any non-digit characters
|
||||
cui_digits = ''.join(c for c in cui_digits if c.isdigit())
|
||||
|
||||
# Validate length
|
||||
if len(cui_digits) < 6 or len(cui_digits) > 10:
|
||||
return None
|
||||
|
||||
# Helper to format CUI with optional RO prefix
|
||||
def format_cui(digits: str) -> str:
|
||||
if had_ro_prefix:
|
||||
return f"RO{digits}"
|
||||
return digits
|
||||
|
||||
# Helper to search database for CUI
|
||||
async def lookup_cui_in_db(digits: str) -> Optional[tuple[str, str]]:
|
||||
"""Search both synced and local suppliers for CUI."""
|
||||
# Search patterns: with and without RO prefix
|
||||
search_patterns = [digits, f"RO{digits}"]
|
||||
|
||||
# Search synced_suppliers first (more data)
|
||||
stmt = select(SyncedSupplier.fiscal_code, SyncedSupplier.name).where(
|
||||
or_(
|
||||
SyncedSupplier.fiscal_code == digits,
|
||||
SyncedSupplier.fiscal_code == f"RO{digits}",
|
||||
SyncedSupplier.fiscal_code == digits.lstrip('0'), # Handle leading zeros
|
||||
)
|
||||
).limit(1)
|
||||
result = await db_session.execute(stmt)
|
||||
row = result.first()
|
||||
if row:
|
||||
return (format_cui(digits), row.name)
|
||||
|
||||
# Search local_suppliers
|
||||
stmt = select(LocalSupplier.fiscal_code, LocalSupplier.name).where(
|
||||
or_(
|
||||
LocalSupplier.fiscal_code == digits,
|
||||
LocalSupplier.fiscal_code == f"RO{digits}",
|
||||
LocalSupplier.fiscal_code == digits.lstrip('0'),
|
||||
)
|
||||
).limit(1)
|
||||
result = await db_session.execute(stmt)
|
||||
row = result.first()
|
||||
if row:
|
||||
return (format_cui(digits), row.name)
|
||||
|
||||
return None
|
||||
|
||||
# 1. If checksum is valid, check if it exists in database (exact match)
|
||||
if CUIChecksumRule.validate_checksum(cui_digits):
|
||||
match = await lookup_cui_in_db(cui_digits)
|
||||
if match:
|
||||
print(f"[Fuzzy CUI] Exact match found: {cui} → {match[0]} ({match[1]})", flush=True)
|
||||
return match
|
||||
# Valid checksum but not in DB - return as-is (it might be a new supplier)
|
||||
return None
|
||||
|
||||
# 2. Invalid checksum - try 1-digit corrections and verify against database
|
||||
print(f"[Fuzzy CUI] Invalid checksum for {cui}, trying corrections...", flush=True)
|
||||
|
||||
# Common OCR digit confusions (try these first)
|
||||
confusion_pairs = {
|
||||
'5': ['8', '6'], # 5 often misread as 8 or 6
|
||||
'8': ['5', '3', '0'], # 8 often misread as 5, 3, or 0
|
||||
'6': ['0', '8'], # 6 often misread as 0 or 8
|
||||
'0': ['6', '8'], # 0 often misread as 6 or 8
|
||||
'1': ['7', '4'], # 1 often misread as 7 or 4
|
||||
'7': ['1'], # 7 often misread as 1
|
||||
'3': ['8'], # 3 often misread as 8
|
||||
'4': ['1'], # 4 often misread as 1
|
||||
'2': ['7'], # 2 sometimes misread as 7
|
||||
'9': ['0'], # 9 sometimes misread as 0
|
||||
}
|
||||
|
||||
n = len(cui_digits)
|
||||
last_pos = n - 1 # checksum position
|
||||
|
||||
# Position check order: middle positions first, then ends
|
||||
middle_positions = list(range(2, last_pos))
|
||||
position_order = middle_positions + [1, last_pos, 0]
|
||||
|
||||
for pos in position_order:
|
||||
if pos >= n:
|
||||
continue
|
||||
|
||||
original_digit = cui_digits[pos]
|
||||
|
||||
# Try common confusions first for this digit
|
||||
candidates = confusion_pairs.get(original_digit, [])
|
||||
# Then try all other digits
|
||||
all_digits = [d for d in '0123456789' if d != original_digit and d not in candidates]
|
||||
|
||||
for replacement in candidates + all_digits:
|
||||
candidate = cui_digits[:pos] + replacement + cui_digits[pos+1:]
|
||||
|
||||
# Only consider if checksum is valid
|
||||
if not CUIChecksumRule.validate_checksum(candidate):
|
||||
continue
|
||||
|
||||
# Check if this corrected CUI exists in database
|
||||
match = await lookup_cui_in_db(candidate)
|
||||
if match:
|
||||
print(f"[Fuzzy CUI] DB match: {cui} → {match[0]} ({match[1]}) [pos {pos}: {original_digit}→{replacement}]", flush=True)
|
||||
return match
|
||||
|
||||
# No match found in database
|
||||
print(f"[Fuzzy CUI] No database match found for {cui}", flush=True)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def fuzzy_match_by_name_and_cui(
|
||||
vendor_name: Optional[str],
|
||||
cui: Optional[str],
|
||||
db_session
|
||||
) -> Optional[tuple[str, str]]:
|
||||
"""Fuzzy match supplier by NAME, then narrow down by CUI.
|
||||
|
||||
Algorithm:
|
||||
1. Normalize vendor name (remove S.R.L., S.A., punctuation, etc.)
|
||||
2. Search suppliers by fuzzy name match (LIKE %name%)
|
||||
3. If multiple results, use fuzzy CUI matching to pick best one
|
||||
4. Return the best match
|
||||
|
||||
Args:
|
||||
vendor_name: Extracted vendor name from OCR
|
||||
cui: Extracted CUI from OCR (may be invalid/incomplete)
|
||||
db_session: SQLAlchemy async session
|
||||
|
||||
Returns:
|
||||
Tuple of (matched_cui, supplier_name) if found, else None
|
||||
"""
|
||||
from sqlalchemy import select, or_, func
|
||||
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier
|
||||
import re
|
||||
|
||||
if not vendor_name or len(vendor_name) < 3:
|
||||
return None
|
||||
|
||||
# Normalize vendor name for search
|
||||
def normalize_name(name: str) -> str:
|
||||
"""Normalize name for fuzzy matching."""
|
||||
name = name.upper()
|
||||
# Remove company type suffixes
|
||||
for suffix in ['S.R.L.', 'SRL', 'S.A.', 'SA', 'S.C.', 'SC', 'I.F.', 'IF', 'P.F.A.', 'PFA']:
|
||||
name = name.replace(suffix, '')
|
||||
# Remove punctuation and extra spaces
|
||||
name = re.sub(r'[.,\-_/\\()"\']', ' ', name)
|
||||
name = ' '.join(name.split())
|
||||
return name.strip()
|
||||
|
||||
# Extract key words from vendor name (for fuzzy search)
|
||||
normalized_name = normalize_name(vendor_name)
|
||||
name_words = [w for w in normalized_name.split() if len(w) >= 3]
|
||||
|
||||
if not name_words:
|
||||
return None
|
||||
|
||||
print(f"[Fuzzy Name] Searching for vendor: '{vendor_name}' → keywords: {name_words}", flush=True)
|
||||
|
||||
# Build search pattern - use first significant word
|
||||
primary_word = name_words[0]
|
||||
search_pattern = f"%{primary_word}%"
|
||||
|
||||
candidates = []
|
||||
|
||||
# Search synced_suppliers
|
||||
stmt = select(SyncedSupplier.fiscal_code, SyncedSupplier.name).where(
|
||||
func.upper(SyncedSupplier.name).like(search_pattern)
|
||||
).limit(20)
|
||||
result = await db_session.execute(stmt)
|
||||
for row in result:
|
||||
if row.fiscal_code:
|
||||
candidates.append((row.fiscal_code, row.name))
|
||||
|
||||
# Search local_suppliers
|
||||
stmt = select(LocalSupplier.fiscal_code, LocalSupplier.name).where(
|
||||
func.upper(LocalSupplier.name).like(search_pattern)
|
||||
).limit(20)
|
||||
result = await db_session.execute(stmt)
|
||||
for row in result:
|
||||
if row.fiscal_code:
|
||||
candidates.append((row.fiscal_code, row.name))
|
||||
|
||||
if not candidates:
|
||||
print(f"[Fuzzy Name] No name matches found for '{primary_word}'", flush=True)
|
||||
return None
|
||||
|
||||
print(f"[Fuzzy Name] Found {len(candidates)} name matches for '{primary_word}'", flush=True)
|
||||
|
||||
# If only one candidate, return it
|
||||
if len(candidates) == 1:
|
||||
print(f"[Fuzzy Name] Single match: {candidates[0][0]} ({candidates[0][1]})", flush=True)
|
||||
return candidates[0]
|
||||
|
||||
# Multiple candidates - try to narrow down by CUI
|
||||
if cui:
|
||||
cui_digits = ''.join(c for c in cui.upper().replace('RO', '').replace('R0', '') if c.isdigit())
|
||||
|
||||
if len(cui_digits) >= 6:
|
||||
# Score each candidate by how similar their CUI is to the extracted one
|
||||
def cui_similarity(candidate_cui: str) -> int:
|
||||
"""Calculate how many digits match in the same position."""
|
||||
cand_digits = ''.join(c for c in candidate_cui.upper().replace('RO', '') if c.isdigit())
|
||||
if len(cand_digits) != len(cui_digits):
|
||||
return 0
|
||||
return sum(1 for a, b in zip(cand_digits, cui_digits) if a == b)
|
||||
|
||||
# Sort candidates by CUI similarity (descending)
|
||||
scored = [(cui_similarity(c[0]), c) for c in candidates]
|
||||
scored.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
best_score, best_match = scored[0]
|
||||
# Require at least 70% digit match for CUI similarity
|
||||
min_matching = int(len(cui_digits) * 0.7)
|
||||
|
||||
if best_score >= min_matching:
|
||||
print(f"[Fuzzy Name] Best CUI match: {best_match[0]} ({best_match[1]}) - score {best_score}/{len(cui_digits)}", flush=True)
|
||||
return best_match
|
||||
|
||||
print(f"[Fuzzy Name] No strong CUI match (best score: {best_score}/{len(cui_digits)})", flush=True)
|
||||
|
||||
# If still multiple and no CUI match, try name similarity
|
||||
def name_similarity(candidate_name: str) -> int:
|
||||
"""Count how many keywords match."""
|
||||
norm_cand = normalize_name(candidate_name)
|
||||
return sum(1 for w in name_words if w in norm_cand)
|
||||
|
||||
scored = [(name_similarity(c[1]), c) for c in candidates]
|
||||
scored.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
if scored[0][0] >= 2: # At least 2 keywords match
|
||||
best_match = scored[0][1]
|
||||
print(f"[Fuzzy Name] Best name match: {best_match[0]} ({best_match[1]})", flush=True)
|
||||
return best_match
|
||||
|
||||
# Return first candidate if nothing else works
|
||||
print(f"[Fuzzy Name] Returning first candidate: {candidates[0][0]} ({candidates[0][1]})", flush=True)
|
||||
return candidates[0]
|
||||
|
||||
@staticmethod
|
||||
async def fuzzy_match_supplier(
|
||||
cui: Optional[str],
|
||||
vendor_name: Optional[str],
|
||||
db_session
|
||||
) -> Optional[tuple[str, str]]:
|
||||
"""Combined fuzzy matching: try CUI first, then fallback to NAME+CUI.
|
||||
|
||||
Strategy:
|
||||
1. Try fuzzy CUI matching (1-digit corrections with checksum validation)
|
||||
2. If no CUI match, try fuzzy NAME matching, narrowed by CUI similarity
|
||||
|
||||
Args:
|
||||
cui: Extracted CUI from OCR (may be invalid/incomplete)
|
||||
vendor_name: Extracted vendor name from OCR
|
||||
db_session: SQLAlchemy async session
|
||||
|
||||
Returns:
|
||||
Tuple of (matched_cui, supplier_name) if found, else None
|
||||
"""
|
||||
# Step 1: Try fuzzy CUI matching
|
||||
cui_match = await OCRValidationEngine.fuzzy_match_cui_from_db(cui, db_session)
|
||||
if cui_match:
|
||||
return cui_match
|
||||
|
||||
# Step 2: Fallback to fuzzy NAME + CUI matching
|
||||
name_match = await OCRValidationEngine.fuzzy_match_by_name_and_cui(
|
||||
vendor_name, cui, db_session
|
||||
)
|
||||
if name_match:
|
||||
return name_match
|
||||
|
||||
return None
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""OCR engine wrapper for PaddleOCR and Tesseract."""
|
||||
"""OCR engine wrapper for PaddleOCR, docTR, and Tesseract."""
|
||||
|
||||
import os
|
||||
import logging
|
||||
@@ -9,9 +9,8 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Setup logging
|
||||
# Setup logging (respects LOG_LEVEL env var set in main.py)
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO) # Ensure logs are visible
|
||||
|
||||
# Disable PaddleOCR model source check for faster startup (PaddleX 3.x)
|
||||
os.environ['PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK'] = 'True'
|
||||
@@ -19,6 +18,7 @@ os.environ['PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK'] = 'True'
|
||||
# Lazy imports - these will be imported on first use
|
||||
PaddleOCR = None # Will be imported lazily
|
||||
pytesseract = None # Will be imported lazily
|
||||
doctr_ocr_predictor = None # Will be imported lazily
|
||||
|
||||
# Check availability without importing heavy libraries
|
||||
def _check_paddle_available() -> bool:
|
||||
@@ -37,8 +37,17 @@ def _check_tesseract_available() -> bool:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _check_doctr_available() -> bool:
|
||||
"""Check if doctr is installed without importing it."""
|
||||
try:
|
||||
import importlib.util
|
||||
return importlib.util.find_spec("doctr") is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
PADDLE_AVAILABLE = _check_paddle_available()
|
||||
TESSERACT_AVAILABLE = _check_tesseract_available()
|
||||
DOCTR_AVAILABLE = _check_doctr_available()
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -59,6 +68,11 @@ class OCREngine:
|
||||
self._paddle_ready = threading.Event() # Signals when PaddleOCR is FULLY ready
|
||||
self._paddle_init_lock = threading.Lock()
|
||||
|
||||
self._doctr = None
|
||||
self._doctr_init_started = False
|
||||
self._doctr_ready = threading.Event() # Signals when docTR is FULLY ready
|
||||
self._doctr_init_lock = threading.Lock()
|
||||
|
||||
def _init_paddle_lazy(self):
|
||||
"""Lazy initialize PaddleOCR on first use (avoids slow startup)."""
|
||||
global PaddleOCR
|
||||
@@ -94,6 +108,78 @@ class OCREngine:
|
||||
# Signal that initialization is complete (success or failure)
|
||||
self._paddle_ready.set()
|
||||
|
||||
def _init_doctr_lazy(self):
|
||||
"""Lazy initialize docTR on first use (avoids slow startup)."""
|
||||
global doctr_ocr_predictor
|
||||
|
||||
with self._doctr_init_lock:
|
||||
if self._doctr_init_started:
|
||||
return # Already initializing or done
|
||||
self._doctr_init_started = True
|
||||
|
||||
if DOCTR_AVAILABLE:
|
||||
try:
|
||||
print("Importing docTR (first use, may take ~10-15 seconds)...", flush=True)
|
||||
from doctr.io import DocumentFile
|
||||
from doctr.models import ocr_predictor
|
||||
|
||||
print("Initializing docTR engine (PyTorch backend)...", flush=True)
|
||||
# Initialize docTR predictor with pretrained models
|
||||
# Uses db_resnet50 for detection and crnn_vgg16_bn for recognition
|
||||
self._doctr = ocr_predictor(
|
||||
det_arch='db_resnet50',
|
||||
reco_arch='crnn_vgg16_bn',
|
||||
pretrained=True,
|
||||
assume_straight_pages=True,
|
||||
straighten_pages=False,
|
||||
preserve_aspect_ratio=True,
|
||||
)
|
||||
doctr_ocr_predictor = self._doctr
|
||||
print("docTR initialized successfully with PyTorch backend", flush=True)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to initialize docTR: {e}", flush=True)
|
||||
self._doctr = None
|
||||
|
||||
# Signal that initialization is complete (success or failure)
|
||||
self._doctr_ready.set()
|
||||
|
||||
def wait_for_doctr(self, timeout: float = 30.0) -> bool:
|
||||
"""
|
||||
Wait for docTR to be fully initialized.
|
||||
|
||||
Args:
|
||||
timeout: Max seconds to wait (default 30s)
|
||||
|
||||
Returns:
|
||||
True if docTR is ready, False if timeout or unavailable
|
||||
"""
|
||||
if not DOCTR_AVAILABLE:
|
||||
return False
|
||||
|
||||
if self._doctr is not None:
|
||||
return True # Already ready
|
||||
|
||||
if not self._doctr_init_started:
|
||||
# Start initialization if not already started
|
||||
self._init_doctr_lazy()
|
||||
|
||||
# Wait for initialization to complete
|
||||
print(f"[OCR] Waiting for docTR to be ready (max {timeout}s)...", flush=True)
|
||||
start = time.time()
|
||||
ready = self._doctr_ready.wait(timeout=timeout)
|
||||
elapsed = time.time() - start
|
||||
|
||||
if ready and self._doctr is not None:
|
||||
print(f"[OCR] docTR ready after {elapsed:.1f}s", flush=True)
|
||||
return True
|
||||
else:
|
||||
print(f"[OCR] docTR not ready after {elapsed:.1f}s (timeout or failed)", flush=True)
|
||||
return False
|
||||
|
||||
def is_doctr_ready(self) -> bool:
|
||||
"""Check if docTR is ready without waiting."""
|
||||
return self._doctr is not None
|
||||
|
||||
def wait_for_paddle(self, timeout: float = 30.0) -> bool:
|
||||
"""
|
||||
Wait for PaddleOCR to be fully initialized.
|
||||
@@ -239,6 +325,84 @@ class OCREngine:
|
||||
logger.info(f"[Tesseract] Done: {len(text)} chars, conf: {avg_conf:.2%}")
|
||||
return OCRResult(text=text, confidence=avg_conf, boxes=[], engine="tesseract")
|
||||
|
||||
def _doctr_recognize(self, image: np.ndarray) -> OCRResult:
|
||||
"""Recognize text using docTR."""
|
||||
# Wait for docTR to be fully ready
|
||||
if not self.wait_for_doctr(timeout=30.0):
|
||||
logger.warning("[docTR] Not ready, falling back to Tesseract")
|
||||
if TESSERACT_AVAILABLE:
|
||||
return self._tesseract_recognize(image)
|
||||
raise RuntimeError("docTR not ready and Tesseract not available")
|
||||
|
||||
try:
|
||||
logger.info(f"[docTR] Processing image, shape: {image.shape}")
|
||||
|
||||
# docTR requires RGB images
|
||||
import cv2
|
||||
if len(image.shape) == 2:
|
||||
# Convert grayscale to RGB
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
||||
logger.info(f"[docTR] Converted grayscale to RGB, new shape: {image.shape}")
|
||||
elif image.shape[2] == 4:
|
||||
# Convert RGBA to RGB
|
||||
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
|
||||
logger.info(f"[docTR] Converted RGBA to RGB, new shape: {image.shape}")
|
||||
elif image.shape[2] == 3:
|
||||
# Check if BGR (from OpenCV) and convert to RGB
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
logger.info(f"[docTR] Converted BGR to RGB, shape: {image.shape}")
|
||||
|
||||
# Process image with docTR
|
||||
logger.info("[docTR] Running prediction...")
|
||||
from doctr.io import DocumentFile
|
||||
|
||||
# docTR expects a document (list of pages as numpy arrays)
|
||||
result = self._doctr([image])
|
||||
|
||||
if not result or not result.pages:
|
||||
logger.warning("[docTR] No results returned")
|
||||
return OCRResult(text="", confidence=0.0, boxes=[], engine="doctr")
|
||||
|
||||
# Extract text from all pages
|
||||
all_texts = []
|
||||
all_confidences = []
|
||||
boxes = []
|
||||
|
||||
for page in result.pages:
|
||||
for block in page.blocks:
|
||||
for line in block.lines:
|
||||
line_text = ' '.join(word.value for word in line.words)
|
||||
line_confidence = sum(w.confidence for w in line.words) / len(line.words) if line.words else 0.0
|
||||
all_texts.append(line_text)
|
||||
all_confidences.append(line_confidence)
|
||||
|
||||
# Store word-level boxes
|
||||
for word in line.words:
|
||||
boxes.append({
|
||||
'text': word.value,
|
||||
'confidence': float(word.confidence),
|
||||
'box': word.geometry # (xmin, ymin), (xmax, ymax)
|
||||
})
|
||||
|
||||
text_result = '\n'.join(all_texts)
|
||||
avg_conf = sum(all_confidences) / len(all_confidences) if all_confidences else 0.0
|
||||
|
||||
logger.info(f"[docTR] SUCCESS - Found {len(all_texts)} text lines, avg confidence: {avg_conf:.2%}")
|
||||
logger.debug(f"[docTR] Raw text preview: {text_result[:200]}...")
|
||||
|
||||
return OCRResult(
|
||||
text=text_result,
|
||||
confidence=float(avg_conf),
|
||||
boxes=boxes,
|
||||
engine="doctr"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[docTR] ERROR: {e}, falling back to Tesseract")
|
||||
if TESSERACT_AVAILABLE:
|
||||
return self._tesseract_recognize(image)
|
||||
raise
|
||||
|
||||
def recognize_dual(self, image: np.ndarray) -> Tuple[OCRResult, Optional[OCRResult]]:
|
||||
"""
|
||||
Run both OCR engines and return both results.
|
||||
@@ -286,10 +450,27 @@ class OCREngine:
|
||||
|
||||
@staticmethod
|
||||
def get_available_engines() -> List[str]:
|
||||
"""Return list of available OCR engines."""
|
||||
"""
|
||||
Return list of available OCR engines.
|
||||
|
||||
Respects OCR_ENABLE_PADDLEOCR and OCR_ENABLE_TESSERACT from .env.
|
||||
Engines that are disabled via .env are not returned even if installed.
|
||||
|
||||
Available engines: tesseract, doctr, doctr_plus, paddleocr
|
||||
"""
|
||||
# Check .env settings
|
||||
paddle_enabled = os.getenv("OCR_ENABLE_PADDLEOCR", "true").lower() == "true"
|
||||
tesseract_enabled = os.getenv("OCR_ENABLE_TESSERACT", "true").lower() == "true"
|
||||
|
||||
engines = []
|
||||
if PADDLE_AVAILABLE:
|
||||
engines.append('paddleocr')
|
||||
if TESSERACT_AVAILABLE:
|
||||
|
||||
# Base engines (only if installed AND enabled)
|
||||
if TESSERACT_AVAILABLE and tesseract_enabled:
|
||||
engines.append('tesseract')
|
||||
if DOCTR_AVAILABLE:
|
||||
engines.append('doctr')
|
||||
engines.append('doctr_plus') # docTR with 2-tier sequential + early exit
|
||||
if PADDLE_AVAILABLE and paddle_enabled:
|
||||
engines.append('paddleocr')
|
||||
|
||||
return engines
|
||||
|
||||
@@ -6,6 +6,8 @@ from decimal import Decimal, InvalidOperation
|
||||
from typing import Optional, Tuple, List
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractionResult:
|
||||
@@ -24,6 +26,7 @@ class ExtractionResult:
|
||||
address: Optional[str] = None
|
||||
items_count: Optional[int] = None
|
||||
payment_methods: List[dict] = field(default_factory=list) # [{"method":"CARD","amount":Decimal}]
|
||||
suggested_payment_mode: Optional[str] = None # 'banca' if CARD detected, 'numerar' if cash only
|
||||
|
||||
# Client data (for B2B receipts - buyer information)
|
||||
client_name: Optional[str] = None
|
||||
@@ -125,8 +128,10 @@ class ReceiptExtractor:
|
||||
(r'C3POS[-A-Z0-9]*[N:](\d{6,7})', 0.98), # CT2N1360760 format
|
||||
(r'C3POS.*?(\d{6,7})\b', 0.95), # Any C3POS followed by 6-7 digit number
|
||||
(r'CT2[N:]\s*(\d{6,})', 0.95), # CT2N prefix
|
||||
# BF (Bon Fiscal) number
|
||||
(r'BF\s*:?\s*(\d+)', 0.93),
|
||||
# BF (Bon Fiscal) number - high priority
|
||||
# Format: "Z:0864 BF:0018" - extract only the number after BF:
|
||||
(r'BF\s*:\s*(\d{4,})', 0.96), # BF: with colon (most specific)
|
||||
(r'BF\s+(\d{4,})', 0.93), # BF followed by space and number
|
||||
# NIVS format
|
||||
(r'NIVS\s*:?\s*(\d+)', 0.95),
|
||||
# Standard NR BON formats
|
||||
@@ -151,28 +156,45 @@ class ReceiptExtractor:
|
||||
# OCR errors: R0 instead of RO, C1F instead of CIF
|
||||
CUI_PATTERNS = [
|
||||
# CIF at start of line (definitely vendor) - tolerant to OCR errors
|
||||
(r'^CIF\s*:?\s*(?:R[O0])?(\d{6,10})', 0.98),
|
||||
(r'^C[I1]F\s*:?\s*(?:R[O0])?(\d{6,10})', 0.95), # C1F OCR error
|
||||
# NOTE: Capture full CUI including RO prefix: (R[O0]?\d{6,10}) or ((?:R[O0])?\d{6,10})
|
||||
(r'^CIF\s*:?\s*(R[O0]?\d{6,10})', 0.98),
|
||||
(r'^CIF\s*:?\s*(\d{6,10})', 0.97), # Without RO prefix
|
||||
(r'^C[I1]F\s*:?\s*(R[O0]?\d{6,10})', 0.95), # C1F OCR error
|
||||
(r'^C[I1]F\s*:?\s*(\d{6,10})', 0.94), # C1F without RO
|
||||
# CIF not preceded by CLIENT (negative lookbehind)
|
||||
(r'(?<!CLIENT\s)(?<!LIENT\s)CIF\s*:?\s*(?:R[O0])?(\d{6,10})', 0.95),
|
||||
(r'(?<!CLIENT\s)(?<!LIENT\s)CIF\s*:?\s*(R[O0]?\d{6,10})', 0.95),
|
||||
(r'(?<!CLIENT\s)(?<!LIENT\s)CIF\s*:?\s*(\d{6,10})', 0.94),
|
||||
# Standalone CIF: format with OCR tolerance
|
||||
(r'\bC[I1]F\s*:?\s*(?:R[O0])?(\d{6,10})\b', 0.90),
|
||||
(r'\bC[I1]F\s*:?\s*(R[O0]?\d{6,10})\b', 0.90),
|
||||
(r'\bC[I1]F\s*:?\s*(\d{6,10})\b', 0.89),
|
||||
# COD FISCAL (vendor)
|
||||
(r'COD\s+FISCAL\s*:?\s*(?:R[O0])?(\d{6,10})', 0.90),
|
||||
(r'COD\s+FISCAL\s*:?\s*(R[O0]?\d{6,10})', 0.90),
|
||||
(r'COD\s+FISCAL\s*:?\s*(\d{6,10})', 0.89),
|
||||
# C. I. F. format with SPACES (OCR artifact) - "C. I. F. : R011201891"
|
||||
(r'C\.\s*I\.\s*F\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.92),
|
||||
# Also handles double colon from OMV/Petrom: "C. I.F.: : RO11201891"
|
||||
(r'C\.\s*I\.\s*F\.?\s*[:\s]+(R[O0]?\d{6,10})', 0.92),
|
||||
(r'C\.\s*I\.\s*F\.?\s*[:\s]+(\d{6,10})', 0.91),
|
||||
# C.I.F. format (with dots, no spaces)
|
||||
(r'(?<!CLIENT\s)C\.[I1]\.F\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.88),
|
||||
(r'(?<!CLIENT\s)C\.[I1]\.F\.?\s*:?\s*(R[O0]?\d{6,10})', 0.88),
|
||||
(r'(?<!CLIENT\s)C\.[I1]\.F\.?\s*:?\s*(\d{6,10})', 0.87),
|
||||
# CUI format (less specific, use with caution)
|
||||
(r'(?<!CLIENT\s)C\.?U\.?[I1]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.85),
|
||||
(r'(?<!CLIENT\s)C\.?U\.?[I1]\.?\s*:?\s*(R[O0]?\d{6,10})', 0.85),
|
||||
(r'(?<!CLIENT\s)C\.?U\.?[I1]\.?\s*:?\s*(\d{6,10})', 0.84),
|
||||
# Lidl format: "Cod Identificare fiscala: RO..." (OCR corrupts to "Ced Identificanfliscalar")
|
||||
# Matches: "Identificare fiscala", "Identificanfliscalar", "Identificoan/Fljscales"
|
||||
(r'[IC](?:od|ed)\s*Identific[a-z/]*\s*(R[O0]\d{6,10})', 0.90),
|
||||
# Generic: anything with "fiscal" followed by RO + digits
|
||||
(r'fiscal[a-z]*\s*:?\s*(R[O0]\d{6,10})', 0.85),
|
||||
]
|
||||
|
||||
# Pattern for CIF NUMBER appearing BEFORE "C.I.F." label (reversed format)
|
||||
# Common in some receipts: "R011201891\nC. I. F." - number on line before label
|
||||
# Common in some receipts: "RO11201891\nC. I. F." - number on line before label
|
||||
# IMPORTANT: Capture the full CUI including RO prefix
|
||||
CUI_REVERSED_PATTERNS = [
|
||||
# RO + 8-10 digits on line immediately before C.I.F./CIF label
|
||||
(r'(?:R[O0])(\d{6,10})\s*\n\s*C\.?\s*I\.?\s*F\.?', 0.98),
|
||||
# Just digits before C.I.F. label
|
||||
# RO/R0 + 6-10 digits on line immediately before C.I.F./CIF label
|
||||
# Capture the FULL CUI including RO prefix
|
||||
(r'(R[O0]\d{6,10})\s*\n\s*C\.?\s*I\.?\s*F\.?', 0.98),
|
||||
# Just digits before C.I.F. label (neplatitor TVA - no RO prefix)
|
||||
(r'(\d{6,10})\s*\n\s*C\.?\s*I\.?\s*F\.?', 0.95),
|
||||
]
|
||||
|
||||
@@ -185,38 +207,67 @@ class ReceiptExtractor:
|
||||
(r'(?:^|\s)BF\s*:\s*(\d{4})', 0.85),
|
||||
]
|
||||
|
||||
# TVA (VAT) patterns - OCR may produce TUA, TVR, etc.
|
||||
# TVA (VAT) patterns - OCR may produce TUA, TVR, IVA, etc.
|
||||
# All patterns are case-insensitive (re.IGNORECASE applied in extraction)
|
||||
TVA_PATTERNS = [
|
||||
# TOTAL TVA BON format (OCR tolerant: TUA, TVR)
|
||||
(r'TOTAL\s+T[VU][AR]\s+BON\s*:?\s*([\d\s.,]+)', 0.98),
|
||||
(r'T[O0]TAL\s+T[VU][AR]\s*:?\s*([\d\s.,]+)', 0.95),
|
||||
# TOTAL TVA BON format (OCR tolerant: TUA, TVR, IVA)
|
||||
(r'TOTAL\s+(?:T[VU][AR]|IVA)\s+BON\s*:?\s*([\d\s.,]+)', 0.98),
|
||||
(r'T[O0]TAL\s+(?:T[VU][AR]|IVA)\s*:?\s*([\d\s.,]+)', 0.95),
|
||||
# IVA variant (Spanish/Portuguese influence, some receipts)
|
||||
(r'TOTAL\s+IVA\s*:?\s*([\d\s.,]+)', 0.95),
|
||||
(r'IVA\s+[A-D]?\s*[-:]?\s*(\d{1,2})\s*%\s*:?\s*([\d\s.,]+)', 0.93),
|
||||
# TVA with percentage (OCR tolerant)
|
||||
(r'T[VU][AR]\s+(?:A\s*[-:]?\s*)?(\d{1,2})\s*%\s*:?\s*([\d\s.,]+)', 0.95),
|
||||
(r'T[VU][AR]\s+[A-Z]\s*[-:]\s*(\d{1,2})\s*%\s*([\d\s.,]+)', 0.93),
|
||||
# Simple TVA pattern
|
||||
(r'T[VU][AR]\s*:?\s*([\d\s.,]+)', 0.85),
|
||||
# 5% TVA rate (books, newspapers - TVA C)
|
||||
(r'T[VU][AR]\s*[C5]\s*[-:]\s*5\s*%\s*:?\s*([\d\s.,]+)', 0.93),
|
||||
(r'(?:T[VU][AR]|IVA)\s+5\s*%\s*:?\s*([\d\s.,]+)', 0.92),
|
||||
# Garbled OCR: T0TAL, TVAI, TUAI, etc.
|
||||
(r'T[O0]T[AE]L\s+(?:T[VUAI]+[AR]?|IVA)\s*:?\s*([\d\s.,]+)', 0.88),
|
||||
# OCR corruption: "TA F 194" (TVA with V→F or space), "T A 19%"
|
||||
# Handles: "TOTAL TA F 194" where TVA became "TA F"
|
||||
(r'TOTAL\s+TA\s*[F\s]?\s*\d*\s*:?\s*([\d\s.,]+)', 0.85),
|
||||
(r'TA\s+[FA-Z]?\s*\d{1,2}\s*%?\s*:?\s*([\d\s.,]+)', 0.82),
|
||||
# "TUA" with random letter after (OCR noise): "TUA F", "TUA I"
|
||||
(r'T[VU]A\s+[A-Z]?\s*\d*\s*:?\s*([\d\s.,]+)', 0.83),
|
||||
# Simple TVA/IVA pattern
|
||||
(r'(?:T[VU][AR]|IVA)\s*:?\s*([\d\s.,]+)', 0.85),
|
||||
# Standalone percentage line near TVA
|
||||
(r'(\d{1,2})\s*%\s*:?\s*([\d\s.,]+)', 0.75),
|
||||
]
|
||||
|
||||
# Payment method patterns - appears after TOTAL LEI, before TOTAL TVA
|
||||
# Format: "CARD: 50.00" or "NUMERAR 100.00" or "PLATA CARD: 50.00"
|
||||
# OMV/Petrom uses "CARTE CREDIT" or "CARTE CREDIT 318, 16"
|
||||
PAYMENT_METHOD_PATTERNS = [
|
||||
# CARTE CREDIT with amount on same line (OMV/Petrom receipts)
|
||||
# Handles: "CARTE CREDIT 318, 16" with OCR spaces in number
|
||||
(r'CARTE\s+CREDIT\s*:?\s*([\d\s.,]+)', 'CARD', 0.98),
|
||||
# CARTE CREDIT with amount on next line (OCR may split lines)
|
||||
# Handles: "CARTE CREDIT\n318, 16"
|
||||
(r'CARTE\s+CREDIT\s*:?\s*\n\s*([\d\s.,]+)', 'CARD', 0.97),
|
||||
# CARD with amount (high confidence)
|
||||
(r'(?:PLATA\s+)?CARD\s*:?\s*([\d\s.,]+)', 'CARD', 0.95),
|
||||
# Also handles OCR artifacts like "CARD F 100.00" where F is noise
|
||||
(r'(?:PLATA\s+)?CARD\s*[:\sA-Z]?\s*([\d\s.,]+)', 'CARD', 0.95),
|
||||
# NUMERAR (cash) with amount
|
||||
(r'NUMERAR\s*:?\s*([\d\s.,]+)', 'NUMERAR', 0.95),
|
||||
# CASH alternative spelling
|
||||
(r'CASH\s*:?\s*([\d\s.,]+)', 'NUMERAR', 0.90),
|
||||
# Truncation recovery patterns (for OCR left-margin truncation issues)
|
||||
# IMPROVED: More restrictive - require max 6 digits before decimals
|
||||
# to avoid matching CUI numbers like RO10562600 → RD10562600
|
||||
# "RD" = truncated "CARD" (only 2 chars visible)
|
||||
(r'\bRD\s*:?\s*([\d\s.,]+)', 'CARD', 0.70),
|
||||
(r'(?:^|\n|\s)RD\s*:?\s*(\d{1,6}[.,]\d{2})\b', 'CARD', 0.70),
|
||||
# "ARD" = truncated "CARD" (3 chars visible)
|
||||
(r'\bARD\s*:?\s*([\d\s.,]+)', 'CARD', 0.75),
|
||||
(r'(?:^|\n|\s)ARD\s*:?\s*(\d{1,6}[.,]\d{2})\b', 'CARD', 0.75),
|
||||
# "MERAR" = truncated "NUMERAR"
|
||||
(r'\bMERAR\s*:?\s*([\d\s.,]+)', 'NUMERAR', 0.70),
|
||||
(r'(?:^|\n|\s)MERAR\s*:?\s*(\d{1,6}[.,]\d{2})\b', 'NUMERAR', 0.70),
|
||||
]
|
||||
|
||||
# Maximum reasonable payment amount for a receipt (100,000 LEI)
|
||||
# Amounts larger than this are likely OCR errors (e.g., CUI parsed as amount)
|
||||
MAX_REASONABLE_PAYMENT = Decimal('100000')
|
||||
|
||||
# Items count patterns - OCR may produce OZ instead of POZ, etc.
|
||||
# Number may be on separate line before or after the label
|
||||
# IMPORTANT: Must be specific to avoid matching product quantities like "50BUC"
|
||||
@@ -250,6 +301,9 @@ class ReceiptExtractor:
|
||||
# Reversed format: CIF/CUI before CLIENT
|
||||
r'C\.?\s*[I1]\.?\s*F\.?\s+CLIENT\s*:', # CIF CLIENT:
|
||||
r'C\.?\s*U\.?\s*[I1]\.?\s+CLIENT\s*:', # CUI CLIENT:
|
||||
# Corrupted CLIENT after CIF: "CIF a IENT:", "CIF LIENT:", "CIF CL IENT:"
|
||||
r'C[I1]F\s+[A-Z\s]{0,6}IENT\s*:', # "CIF a IENT:", "CIF CL IENT:", "CIF LIENT:"
|
||||
r'C[I1]F\s+LIENT\s*:', # "CIF LIENT:" (missing C from CLIENT)
|
||||
# CLIENT followed by C.U.I./C.I.F. (all variations with/without spaces and dots)
|
||||
# Handles: CLIENT C.U.I/C.I.F., CLIENT C. U. I./ C. I.F., CLIENT CUI/CIF
|
||||
r'CLIENT\s+C\.?\s*U\.?\s*[I1]\.?\s*/?\s*C?\.?\s*[I1]?\.?\s*F?\.?\s*:',
|
||||
@@ -267,6 +321,16 @@ class ReceiptExtractor:
|
||||
# Client CUI patterns (explicitly after CLIENT marker)
|
||||
# OCR errors: R0 instead of RO, C1F instead of CIF, 1 instead of I
|
||||
CLIENT_CUI_PATTERNS = [
|
||||
# NEW: CUI on line BEFORE CLIENT marker (docTR/OCR may output value before label)
|
||||
# Pattern: "RO1879855\nCLIENT C.U.I./C.I.F.:" - CUI on line before CLIENT label
|
||||
(r'(R[O0]\d{6,10})\s*\n\s*CLIENT\s+C\.?\s*U\.?\s*[I1]\.?', 0.99),
|
||||
(r'(R[O0]\d{6,10})\s*\n\s*CLIENT\s+C\.?\s*[I1]\.?\s*F\.?', 0.99),
|
||||
# Same but with optional colon after RO number
|
||||
(r'(R[O0]\d{6,10})\s*:?\s*\n\s*CLIENT', 0.98),
|
||||
# "CIF I CLIENT:" or "CIF IDENTIFICARE CLIENT:" format (OCR may insert extra chars)
|
||||
# Common OCR artifact: "CIF I CLIENT: R01879855"
|
||||
(r'C[I1]F\s+[A-Z]*\s*CLIENT\s*:?\s*(R[O0]\d{6,10})', 0.98),
|
||||
(r'C[I1]F\s+[A-Z]*\s*CLIENT\s*:?\s*(R[O0]?\d{6,10})', 0.97),
|
||||
# CIF CLIENT: R01879856 (reversed format - CIF/CUI before CLIENT)
|
||||
(r'C\.?\s*[I1]\.?\s*F\.?\s+CLIENT\s*:?\s*(R[O0]?\d{6,10})', 0.98),
|
||||
(r'C\.?\s*U\.?\s*[I1]\.?\s+CLIENT\s*:?\s*(R[O0]?\d{6,10})', 0.98),
|
||||
@@ -276,19 +340,34 @@ class ReceiptExtractor:
|
||||
# Most flexible pattern for slash variants
|
||||
(r'CLIENT\s+C\.?\s*U\.?\s*[I1]\.?\s*/\s*C\.?\s*[I1]\.?\s*F\.?\s*:?\s*(R[O0]?\d{6,10})', 0.97),
|
||||
(r'CLIENT\s+C\.?\s*U\.?\s*[I1]\.?\s*/\s*C\.?\s*[I1]\.?\s*F\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.97),
|
||||
# OCR artifact: doubled letters like "C.U U. I." or "C.I I.F." (docTR sometimes duplicates)
|
||||
(r'CLIENT\s+C\.?\s*U\.?\s*U?\.?\s*[I1]\.?\s*/\s*C\.?\s*[I1]\.?\s*F\.?\s*:?\s*(R[O0]?\d{6,10})', 0.96),
|
||||
(r'CLIENT\s+C\.?\s*U\.?\s*U?\.?\s*[I1]\.?\s*/\s*C\.?\s*[I1]\.?\s*F\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.96),
|
||||
# CLIENT C.U.I. or CLIENT CUI or CLIENT CIF (without slash)
|
||||
(r'CLIENT\s+C\.?\s*U\.?\s*[I1]\.?\s*:?\s*(R[O0]?\d{6,10})', 0.96),
|
||||
(r'CLIENT\s+C\.?\s*U\.?\s*[I1]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.96),
|
||||
(r'CLIENT\s+C\.?\s*[I1]\.?\s*F\.?\s*:?\s*(R[O0]?\d{6,10})', 0.96),
|
||||
(r'CLIENT\s+C\.?\s*[I1]\.?\s*F\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.96),
|
||||
# Corrupted CLIENT after CIF: "CIF a IENT:", "CIF LIENT:", "CIF L IENT:", "CIF C IENT:"
|
||||
# OCR often corrupts "CLIENT" when it appears after "CIF"
|
||||
(r'CIF\s+[a-zA-Z\s]{2,8}IENT\s*:?\s*(R[O0]?\d{6,10})', 0.93), # "CIF a IENT:", "CIF CL IENT:"
|
||||
(r'CIF\s+[a-zA-Z\s]{2,8}IENT\s*:?\s*(?:R[O0])?(\d{6,10})', 0.93),
|
||||
(r'CIF\s+LIENT\s*:?\s*(R[O0]?\d{6,10})', 0.92), # "CIF LIENT:" (missing C)
|
||||
(r'CIF\s+LIENT\s*:?\s*(?:R[O0])?(\d{6,10})', 0.92),
|
||||
# CUMPARATOR variants
|
||||
(r'CUMPARATOR\s+C\.?\s*U\.?\s*[I1]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.95),
|
||||
(r'CUMPARATOR\s+C\.?\s*[I1]\.?\s*F\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.95),
|
||||
# CUMPARATOR with CUI/CIF on next line: "CUMPARATOR: NAME\nCIF: 12345678"
|
||||
(r'CUMPARATOR\s*:.*\n\s*C\.?\s*U\.?\s*[I1]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.93),
|
||||
(r'CUMPARATOR\s*:.*\n\s*C\.?\s*[I1]\.?\s*[FT]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.93), # F or T (OCR error)
|
||||
# CUMPARATOR with CUI/CIF two lines down: "CUMPARATOR: NAME\nADDRESS\nCIF: 12345678"
|
||||
(r'CUMPARATOR\s*:.*\n.*\n\s*C\.?\s*[I1]\.?\s*[FT]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.90),
|
||||
# CUI/CIF on line immediately after CLIENT marker
|
||||
(r'CLIENT\s*:\s*\n\s*C\.?\s*U\.?\s*[I1]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.95),
|
||||
(r'CLIENT\s*:\s*\n\s*C\.?\s*[I1]\.?\s*F\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.95),
|
||||
# CUI after client name: "CLIENT: COMPANY SRL\nCUI: 12345678"
|
||||
(r'CLIENT\s*:.*\n.*C\.?\s*U\.?\s*[I1]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.90),
|
||||
(r'CLIENT\s*:\s*\n\s*C\.?\s*[I1]\.?\s*[FT]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.95), # F or T (OCR error)
|
||||
# CUI/CIF after client name: "CLIENT: COMPANY SRL\nCUI: 12345678"
|
||||
(r'CLIENT\s*:.*\n\s*C\.?\s*U\.?\s*[I1]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.90),
|
||||
(r'CLIENT\s*:.*\n\s*C\.?\s*[I1]\.?\s*[FT]\.?\s*:?\s*(?:R[O0])?(\d{6,10})', 0.90), # CIF/CIT after name
|
||||
]
|
||||
|
||||
# Vendor name indicators (lines containing these are likely vendor names)
|
||||
@@ -322,6 +401,8 @@ class ReceiptExtractor:
|
||||
result.receipt_series, _ = self._extract_series(text_upper)
|
||||
result.partner_name, result.confidence_vendor = self._extract_vendor(text)
|
||||
result.cui, _ = self._extract_cui(text_upper, text)
|
||||
# Normalize CUI: fix R0 → RO OCR error and validate format
|
||||
result.cui = OCRValidationEngine.normalize_cui(result.cui)
|
||||
|
||||
# Extract additional fields - Multiple TVA entries
|
||||
result.tva_entries, result.tva_total = self._extract_tva_entries(text_upper)
|
||||
@@ -345,10 +426,35 @@ class ReceiptExtractor:
|
||||
result.address = self._extract_address(text_upper)
|
||||
result.payment_methods = self._extract_payment_methods(text_upper)
|
||||
|
||||
# Validate payment methods against extracted amount
|
||||
# If payment sum >> amount, clear invalid payments (likely OCR error)
|
||||
# Save original payment methods before validation (for payment mode detection)
|
||||
original_payment_methods = result.payment_methods.copy() if result.payment_methods else []
|
||||
|
||||
result.payment_methods = self._validate_payment_methods(result.payment_methods, result.amount)
|
||||
|
||||
# Auto-suggest payment_mode based on detected payment methods
|
||||
# Use ORIGINAL payment_methods to detect CARD even if validation cleared them
|
||||
# (e.g., CARD 318.16 is valid even if total validation failed)
|
||||
payment_methods_for_mode = result.payment_methods if result.payment_methods else original_payment_methods
|
||||
if payment_methods_for_mode:
|
||||
card_amount = sum(
|
||||
pm.get('amount', Decimal('0'))
|
||||
for pm in payment_methods_for_mode
|
||||
if pm.get('method') == 'CARD'
|
||||
)
|
||||
if card_amount > 0:
|
||||
result.suggested_payment_mode = 'banca'
|
||||
print(f"[Payment Mode] CARD detected ({card_amount}), suggesting 'banca'", flush=True)
|
||||
else:
|
||||
# Only cash payments detected
|
||||
result.suggested_payment_mode = 'numerar'
|
||||
print(f"[Payment Mode] Cash only detected, suggesting 'numerar'", flush=True)
|
||||
|
||||
# Extract client data (B2B receipts)
|
||||
client_name, client_cui, client_address, confidence_client = self._extract_client_data(text_upper, text)
|
||||
result.client_name = client_name
|
||||
result.client_cui = client_cui
|
||||
result.client_cui = OCRValidationEngine.normalize_cui(client_cui) # Fix R0 → RO OCR error
|
||||
result.client_address = client_address
|
||||
result.confidence_client = confidence_client
|
||||
|
||||
@@ -378,13 +484,28 @@ class ReceiptExtractor:
|
||||
|
||||
def _extract_amount(self, text: str) -> Tuple[Optional[Decimal], float]:
|
||||
"""Extract total amount from text."""
|
||||
# PRE-FILTER: Remove lines containing REST (rest = change, not total)
|
||||
# When paid by card, there's no change - exact amount is paid
|
||||
lines = text.split('\n')
|
||||
filtered_lines = []
|
||||
for line in lines:
|
||||
# Skip lines with REST pattern (change amount, not total)
|
||||
if re.search(r'\bREST\b', line, re.IGNORECASE):
|
||||
continue
|
||||
filtered_lines.append(line)
|
||||
text = '\n'.join(filtered_lines)
|
||||
|
||||
# First try standard patterns (TOTAL, SUBTOTAL, etc.)
|
||||
for pattern, confidence in self.TOTAL_PATTERNS:
|
||||
match = re.search(pattern, text, re.IGNORECASE | re.MULTILINE)
|
||||
if match:
|
||||
try:
|
||||
amount_str = re.sub(r'[^\d.,]', '', match.group(1))
|
||||
# IMPORTANT: Call _normalize_number FIRST to handle "190 60" → "190.60"
|
||||
# before stripping other characters
|
||||
amount_str = match.group(1).strip()
|
||||
amount_str = self._normalize_number(amount_str)
|
||||
# Now remove any remaining non-numeric chars (except decimal point)
|
||||
amount_str = re.sub(r'[^\d.]', '', amount_str)
|
||||
amount = Decimal(amount_str)
|
||||
if amount > 0:
|
||||
return amount, confidence
|
||||
@@ -461,8 +582,22 @@ class ReceiptExtractor:
|
||||
|
||||
def _normalize_number(self, num_str: str) -> str:
|
||||
"""Normalize Romanian number format to standard decimal."""
|
||||
# Remove spaces
|
||||
num_str = num_str.replace(' ', '')
|
||||
# OCR often reads "." as " " (space) - handle "190 60" as "190.60"
|
||||
# Pattern: digits + space + exactly 2 digits at end
|
||||
space_decimal_match = re.match(r'^(\d+)\s+(\d{2})$', num_str.strip())
|
||||
if space_decimal_match:
|
||||
num_str = f"{space_decimal_match.group(1)}.{space_decimal_match.group(2)}"
|
||||
else:
|
||||
# Handle "1 234 56" pattern (thousands + decimal with spaces)
|
||||
# Match: digits + space(s) + digits + space + 2 digits
|
||||
multi_space_match = re.match(r'^([\d\s]+?)\s+(\d{2})$', num_str.strip())
|
||||
if multi_space_match:
|
||||
integer_part = multi_space_match.group(1).replace(' ', '')
|
||||
decimal_part = multi_space_match.group(2)
|
||||
num_str = f"{integer_part}.{decimal_part}"
|
||||
else:
|
||||
# Remove remaining spaces (thousands separators)
|
||||
num_str = num_str.replace(' ', '')
|
||||
|
||||
# Handle comma as decimal separator
|
||||
if ',' in num_str and '.' in num_str:
|
||||
@@ -532,34 +667,57 @@ class ReceiptExtractor:
|
||||
except (InvalidOperation, ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# Case 1: Amount is valid with high confidence - just validate
|
||||
# Case 1: Amount is valid with high confidence - validate against TVA and payments
|
||||
if amount and amount > 0 and confidence_amount >= 0.8:
|
||||
# Cross-validate: check if it matches payment methods
|
||||
# First check TVA-implied total (most reliable when TVA is extracted correctly)
|
||||
if tva_implied_total and tva_implied_total > 0:
|
||||
tva_diff_percent = abs(float(amount) - float(tva_implied_total)) / float(tva_implied_total) * 100
|
||||
if tva_diff_percent <= 1:
|
||||
# Near-perfect TVA match - highest confidence
|
||||
return amount, min(0.98, confidence_amount + 0.05), "extracted (validated by TVA)"
|
||||
elif tva_diff_percent > 10:
|
||||
# Significant mismatch - TVA-implied total is more reliable
|
||||
# This catches cases where wrong TOTAL line was extracted (e.g., REST, SUBTOTAL)
|
||||
print(f"[Cross-Validation] Amount mismatch with TVA: extracted={amount}, tva_implied={tva_implied_total} (diff={tva_diff_percent:.1f}%)", flush=True)
|
||||
return tva_implied_total, 0.90, "calculated from TVA (extracted amount mismatch)"
|
||||
|
||||
# Cross-validate with payment methods
|
||||
if payment_sum > 0 and abs(amount - payment_sum) <= Decimal('0.02'):
|
||||
# Perfect match - boost confidence
|
||||
return amount, min(0.98, confidence_amount + 0.05), "extracted (validated by payment methods)"
|
||||
elif payment_sum > 0:
|
||||
payment_diff_percent = abs(float(amount) - float(payment_sum)) / float(payment_sum) * 100
|
||||
if payment_diff_percent > 10:
|
||||
# Significant mismatch - payment sum is more reliable
|
||||
print(f"[Cross-Validation] Amount mismatch with payments: extracted={amount}, payments={payment_sum} (diff={payment_diff_percent:.1f}%)", flush=True)
|
||||
return payment_sum, 0.88, "calculated from payment methods (extracted amount mismatch)"
|
||||
|
||||
return amount, confidence_amount, "extracted"
|
||||
|
||||
# Case 2: Amount exists but low confidence - try to validate/correct
|
||||
if amount and amount > 0:
|
||||
# First check TVA-implied total (most reliable)
|
||||
if tva_implied_total and tva_implied_total > 0:
|
||||
tva_diff_percent = abs(float(amount) - float(tva_implied_total)) / float(tva_implied_total) * 100
|
||||
if tva_diff_percent <= 2:
|
||||
# Close match - boost confidence
|
||||
return amount, 0.88, "extracted (validated by TVA)"
|
||||
elif tva_diff_percent > 10:
|
||||
# Significant mismatch - use TVA-implied total
|
||||
print(f"[Cross-Validation] Amount mismatch with TVA: extracted={amount}, tva_implied={tva_implied_total} (diff={tva_diff_percent:.1f}%)", flush=True)
|
||||
return tva_implied_total, 0.85, "calculated from TVA"
|
||||
|
||||
# Check if payment methods sum matches
|
||||
if payment_sum > 0:
|
||||
if abs(amount - payment_sum) <= Decimal('0.02'):
|
||||
# Match - boost confidence
|
||||
payment_diff_percent = abs(float(amount) - float(payment_sum)) / float(payment_sum) * 100
|
||||
if payment_diff_percent <= 0.5:
|
||||
# Close match - boost confidence
|
||||
return amount, 0.90, "extracted (validated by payment methods)"
|
||||
else:
|
||||
elif payment_diff_percent > 10:
|
||||
# Mismatch - prefer payment_sum as it's more reliable
|
||||
print(f"[Cross-Validation] Amount mismatch: extracted={amount}, payments={payment_sum}", flush=True)
|
||||
return payment_sum, 0.85, "calculated from payment methods"
|
||||
|
||||
# Check TVA-implied total
|
||||
if tva_implied_total:
|
||||
if abs(amount - tva_implied_total) <= Decimal('0.50'):
|
||||
# Close match - use extracted amount
|
||||
return amount, 0.80, "extracted (validated by TVA)"
|
||||
else:
|
||||
print(f"[Cross-Validation] TVA mismatch: extracted={amount}, tva_implied={tva_implied_total}", flush=True)
|
||||
|
||||
# No validation possible - return as-is
|
||||
return amount, confidence_amount, "extracted (unvalidated)"
|
||||
|
||||
@@ -701,6 +859,10 @@ class ReceiptExtractor:
|
||||
|
||||
line_upper = line.upper()
|
||||
|
||||
# Skip lines with skip keywords (CUMPARATOR, CLIENT, etc.)
|
||||
if any(kw in line_upper for kw in skip_keywords):
|
||||
continue
|
||||
|
||||
# Check for vendor indicators
|
||||
for indicator in self.VENDOR_INDICATORS:
|
||||
if re.search(indicator, line_upper):
|
||||
@@ -778,13 +940,21 @@ class ReceiptExtractor:
|
||||
Extract vendor CUI (fiscal identification code) from text.
|
||||
Excludes CLIENT CUI which appears as 'CLIENT C.U.I./C.I.F.:...'
|
||||
"""
|
||||
def get_cui_digit_count(cui: str) -> int:
|
||||
"""Get the count of digits in CUI (excluding RO/R0 prefix)."""
|
||||
cui_upper = cui.upper().strip()
|
||||
if cui_upper.startswith('RO') or cui_upper.startswith('R0'):
|
||||
return len(cui_upper) - 2
|
||||
return len(cui_upper)
|
||||
|
||||
# Strategy 0: Check for reversed format (CIF NUMBER on line BEFORE "C.I.F." label)
|
||||
# This is common in some receipts: "R011201891\nC. I. F."
|
||||
# This is common in some receipts: "RO11201891\nC. I. F."
|
||||
for pattern, confidence in self.CUI_REVERSED_PATTERNS:
|
||||
match = re.search(pattern, text_upper, re.IGNORECASE | re.MULTILINE)
|
||||
if match:
|
||||
cui = match.group(1)
|
||||
if 6 <= len(cui) <= 10:
|
||||
digit_count = get_cui_digit_count(cui)
|
||||
if 6 <= digit_count <= 10:
|
||||
# Verify this is not the CLIENT CUI by checking context
|
||||
start = match.start()
|
||||
# Check 50 chars before the match for CLIENT keyword
|
||||
@@ -805,7 +975,8 @@ class ReceiptExtractor:
|
||||
match = re.search(pattern, line, re.IGNORECASE | re.MULTILINE)
|
||||
if match:
|
||||
cui = match.group(1)
|
||||
if 6 <= len(cui) <= 10:
|
||||
digit_count = get_cui_digit_count(cui)
|
||||
if 6 <= digit_count <= 10:
|
||||
return cui, confidence
|
||||
|
||||
# Strategy 2: Fallback - search entire text but exclude CLIENT patterns
|
||||
@@ -813,7 +984,8 @@ class ReceiptExtractor:
|
||||
# Find all matches
|
||||
for match in re.finditer(pattern, text_upper, re.IGNORECASE | re.MULTILINE):
|
||||
cui = match.group(1)
|
||||
if 6 <= len(cui) <= 10:
|
||||
digit_count = get_cui_digit_count(cui)
|
||||
if 6 <= digit_count <= 10:
|
||||
# Check if this match is preceded by CLIENT in the same line
|
||||
start = match.start()
|
||||
line_start = text_upper.rfind('\n', 0, start) + 1
|
||||
@@ -937,9 +1109,90 @@ class ReceiptExtractor:
|
||||
except (ValueError, InvalidOperation):
|
||||
continue
|
||||
|
||||
# Pattern 1: "TVA A - 19%: 15.20" or "TVAA - 21% 32.31" (with code)
|
||||
# OCR tolerant: TUA, TVR, etc.
|
||||
pattern_with_code = r'T[VU][AR]\s*([A-D])\s*[-:]\s*(\d{1,2})\s*%\s*:?\s*([\d\s.,]+)'
|
||||
# Pattern 0c: REVERSED FORMAT "5.00% TUA*B" followed by amount on next line
|
||||
# This handles receipts where percentage comes BEFORE TVA code (e.g., books with 5% rate)
|
||||
# Matches: "5.00% TUA*B", "5% TVA B", "5.00% TVA", "9% TUA", "5% IVA"
|
||||
if not tva_entries:
|
||||
# Pattern: PERCENT% + TVA/IVA + optional code, then amount on next line
|
||||
reversed_tva_pattern = r'(\d{1,2})[.,]?\d{0,2}\s*%\s*(?:T[VU][AR]|IVA)\s*\*?([A-D])?'
|
||||
for match in re.finditer(reversed_tva_pattern, normalized_text, re.IGNORECASE):
|
||||
try:
|
||||
percent = int(match.group(1))
|
||||
code = (match.group(2) or self._get_tva_code_from_percent(percent)).upper()
|
||||
|
||||
# Look for amount on the next line(s) after the match
|
||||
after_match = normalized_text[match.end():]
|
||||
# Find standalone number (amount) - skip empty lines
|
||||
amount_match = re.search(r'^[\s\n]*([\d]+[.,]\d{2})\b', after_match)
|
||||
if amount_match:
|
||||
amount_str = self._normalize_number(amount_match.group(1))
|
||||
amount = Decimal(amount_str)
|
||||
if amount > 0:
|
||||
entry_key = (code, percent)
|
||||
if entry_key not in seen_entries:
|
||||
tva_entries.append({
|
||||
'code': code,
|
||||
'percent': percent,
|
||||
'amount': amount
|
||||
})
|
||||
seen_entries.add(entry_key)
|
||||
except (ValueError, InvalidOperation):
|
||||
continue
|
||||
|
||||
# Pattern 0d: "TOTAL TUA:", "TOTAL TVA:", "TOTAL IVA:" with amount (OCR variants)
|
||||
if not tva_entries:
|
||||
total_tva_simple = r'TOTAL\s+(?:T[VU][AR]|IVA)\s*:?\s*([\d.,]+)'
|
||||
match = re.search(total_tva_simple, normalized_text, re.IGNORECASE)
|
||||
if match:
|
||||
try:
|
||||
amount_str = self._normalize_number(match.group(1))
|
||||
amount = Decimal(amount_str)
|
||||
if amount > 0:
|
||||
# Try to find the rate in nearby text
|
||||
percent = self._detect_tva_percent(text)
|
||||
if percent:
|
||||
code = self._get_tva_code_from_percent(percent)
|
||||
entry_key = (code, percent)
|
||||
if entry_key not in seen_entries:
|
||||
tva_entries.append({
|
||||
'code': code,
|
||||
'percent': percent,
|
||||
'amount': amount
|
||||
})
|
||||
seen_entries.add(entry_key)
|
||||
except (ValueError, InvalidOperation):
|
||||
pass
|
||||
|
||||
# Pattern 0e: Multiline "TOTAL TUA\n198\n30.43" where:
|
||||
# - "TOTAL TUA" on one line
|
||||
# - "198" or similar (corrupted "19%") on next line (optional)
|
||||
# - "30.43" (TVA amount) on following line
|
||||
# OCR often splits this across multiple lines
|
||||
if not tva_entries:
|
||||
multiline_tva = r'TOTAL\s+(?:T[VU][AR]L?|TU[AR]L|IVA)\s*\n\s*\d*\s*\n?\s*([\d]+[.,]\d{2})\b'
|
||||
match = re.search(multiline_tva, normalized_text, re.IGNORECASE)
|
||||
if match:
|
||||
try:
|
||||
amount_str = self._normalize_number(match.group(1))
|
||||
amount = Decimal(amount_str)
|
||||
if amount > 0:
|
||||
percent = self._detect_tva_percent(text)
|
||||
if percent:
|
||||
code = self._get_tva_code_from_percent(percent)
|
||||
entry_key = (code, percent)
|
||||
if entry_key not in seen_entries:
|
||||
tva_entries.append({
|
||||
'code': code,
|
||||
'percent': percent,
|
||||
'amount': amount
|
||||
})
|
||||
seen_entries.add(entry_key)
|
||||
except (ValueError, InvalidOperation):
|
||||
pass
|
||||
|
||||
# Pattern 1: "TVA A - 19%: 15.20" or "TVAA - 21% 32.31" or "IVA A - 19%" (with code)
|
||||
# OCR tolerant: TUA, TVR, IVA, etc.
|
||||
pattern_with_code = r'(?:T[VU][AR]|IVA)\s*([A-D])\s*[-:]\s*(\d{1,2})\s*%\s*:?\s*([\d\s.,]+)'
|
||||
for match in re.finditer(pattern_with_code, normalized_text, re.IGNORECASE):
|
||||
try:
|
||||
code = match.group(1).upper()
|
||||
@@ -959,9 +1212,9 @@ class ReceiptExtractor:
|
||||
except (ValueError, InvalidOperation):
|
||||
continue
|
||||
|
||||
# Pattern 2: "TVA - 21%: 32.31" (without explicit code, assume 'A')
|
||||
# Pattern 2: "TVA - 21%: 32.31" or "IVA - 21%: 32.31" (without explicit code, assume 'A')
|
||||
if not tva_entries:
|
||||
pattern_no_code = r'T[VU][AR]\s*[-:]\s*(\d{1,2})\s*%\s*:?\s*([\d\s.,]+)'
|
||||
pattern_no_code = r'(?:T[VU][AR]|IVA)\s*[-:]\s*(\d{1,2})\s*%\s*:?\s*([\d\s.,]+)'
|
||||
for match in re.finditer(pattern_no_code, normalized_text, re.IGNORECASE):
|
||||
try:
|
||||
percent = int(match.group(1))
|
||||
@@ -982,10 +1235,10 @@ class ReceiptExtractor:
|
||||
except (ValueError, InvalidOperation):
|
||||
continue
|
||||
|
||||
# Pattern 3: "TOTAL TVA A - 21%" with amount on same line or "TOTAL TVA BON" with amount
|
||||
# Pattern 3: "TOTAL TVA A - 21%" or "TOTAL IVA" with amount on same line or "TOTAL TVA BON" with amount
|
||||
if not tva_entries:
|
||||
# First try: "TOTAL TVA A - 21% 32.31" (amount on same line)
|
||||
tva_with_amount = r'TOTAL\s+T[VU][AR]\s+([A-D])\s*[-:]\s*(\d{1,2})\s*%\s*([\d.,]+)'
|
||||
# First try: "TOTAL TVA A - 21% 32.31" or "TOTAL IVA A - 21% 32.31" (amount on same line)
|
||||
tva_with_amount = r'TOTAL\s+(?:T[VU][AR]|IVA)\s+([A-D])\s*[-:]\s*(\d{1,2})\s*%\s*([\d.,]+)'
|
||||
for match in re.finditer(tva_with_amount, normalized_text, re.IGNORECASE):
|
||||
try:
|
||||
code = match.group(1).upper()
|
||||
@@ -1004,16 +1257,16 @@ class ReceiptExtractor:
|
||||
except (ValueError, InvalidOperation):
|
||||
continue
|
||||
|
||||
# Pattern 3b: "TOTAL TVA A - 21%" on one line, look for "TOTAL TVA BON" amount
|
||||
# Pattern 3b: "TOTAL TVA A - 21%" or "TOTAL IVA A - 21%" on one line, look for "TOTAL TVA BON" amount
|
||||
if not tva_entries:
|
||||
tva_total_pattern = r'TOTAL\s+T[VU][AR]\s+([A-D])\s*[-:]\s*(\d{1,2})\s*%'
|
||||
tva_total_pattern = r'TOTAL\s+(?:T[VU][AR]|IVA)\s+([A-D])\s*[-:]\s*(\d{1,2})\s*%'
|
||||
for match in re.finditer(tva_total_pattern, normalized_text, re.IGNORECASE):
|
||||
try:
|
||||
code = match.group(1).upper()
|
||||
percent = int(match.group(2))
|
||||
|
||||
# Look for "TOTAL TVA BON" followed by amount
|
||||
tva_bon_pattern = r'TOTAL\s+T[VU][AR]\s+BON[:\s]*([\d.,]+)'
|
||||
# Look for "TOTAL TVA BON" or "TOTAL IVA BON" followed by amount
|
||||
tva_bon_pattern = r'TOTAL\s+(?:T[VU][AR]|IVA)\s+BON[:\s]*([\d.,]+)'
|
||||
tva_bon_match = re.search(tva_bon_pattern, normalized_text, re.IGNORECASE)
|
||||
if tva_bon_match:
|
||||
amount_str = self._normalize_number(tva_bon_match.group(1))
|
||||
@@ -1029,8 +1282,8 @@ class ReceiptExtractor:
|
||||
seen_entries.add(entry_key)
|
||||
continue
|
||||
|
||||
# Fallback: Amount after TOTAL TVA BON on next line
|
||||
tva_bon_pos = re.search(r'TOTAL\s+T[VU][AR]\s+BON', normalized_text, re.IGNORECASE)
|
||||
# Fallback: Amount after TOTAL TVA BON or TOTAL IVA BON on next line
|
||||
tva_bon_pos = re.search(r'TOTAL\s+(?:T[VU][AR]|IVA)\s+BON', normalized_text, re.IGNORECASE)
|
||||
if tva_bon_pos:
|
||||
after_bon = normalized_text[tva_bon_pos.end():]
|
||||
# Find first standalone number (likely TVA amount)
|
||||
@@ -1050,9 +1303,9 @@ class ReceiptExtractor:
|
||||
except (ValueError, InvalidOperation):
|
||||
continue
|
||||
|
||||
# Pattern 3b: "TVAA - 21%" on one line, amount on next line (simpler format)
|
||||
# Pattern 3c: "TVAA - 21%" or "IVA A - 21%" on one line, amount on next line (simpler format)
|
||||
if not tva_entries:
|
||||
tva_line_pattern = r'T[VU][AR]\s*([A-D])?\s*[-:]\s*(\d{1,2})\s*%'
|
||||
tva_line_pattern = r'(?:T[VU][AR]|IVA)\s*([A-D])?\s*[-:]\s*(\d{1,2})\s*%'
|
||||
for match in re.finditer(tva_line_pattern, normalized_text, re.IGNORECASE):
|
||||
try:
|
||||
code = (match.group(1) or 'A').upper()
|
||||
@@ -1158,16 +1411,18 @@ class ReceiptExtractor:
|
||||
Extract TOTAL TVA BON value separately as the reference.
|
||||
This is the authoritative total TVA on the receipt.
|
||||
|
||||
Handles OCR variations: TOTAL TVA BON, OTAL TUA BON, etc.
|
||||
Handles OCR variations: TOTAL TVA BON, OTAL TUA BON, TOTAL IVA BON, etc.
|
||||
"""
|
||||
# Pattern for TOTAL TVA BON with amount after
|
||||
# Pattern for TOTAL TVA BON or TOTAL IVA BON with amount after
|
||||
# OCR corruptions: TUAL (TVA+L merged), TVAL, TUAI, etc.
|
||||
patterns = [
|
||||
# Standard: TOTAL TVA BON: 14.92
|
||||
r'T?OTAL\s+T[VU][AR]\s+BON\s*:?\s*([\d]+[.,]\d{2})\b',
|
||||
# Standard: TOTAL TVA BON: 14.92 or TOTAL IVA BON: 14.92
|
||||
# Handles: TUAL (TVA+L), TVAL, TUAI, etc. with optional trailing letters
|
||||
r'T?OTAL\s+(?:T[VU][AR]L?|TU[AR]L|IVA)\s+BON\s*:?\s*([\d]+[.,]\d{2})\b',
|
||||
# Amount before: 14.92 OTAL TUA BON (OCR line break)
|
||||
r'([\d]+[.,]\d{2})\s*\n?\s*T?OTAL\s+T[VU][AR]\s+BON',
|
||||
# Amount on next line after TOTAL TVA BON
|
||||
r'T?OTAL\s+T[VU][AR]\s+BON\s*\n\s*([\d]+[.,]\d{2})\b',
|
||||
r'([\d]+[.,]\d{2})\s*\n?\s*T?OTAL\s+(?:T[VU][AR]L?|TU[AR]L|IVA)\s+BON',
|
||||
# Amount on next line after TOTAL TVA BON or TOTAL IVA BON
|
||||
r'T?OTAL\s+(?:T[VU][AR]L?|TU[AR]L|IVA)\s+BON\s*\n\s*([\d]+[.,]\d{2})\b',
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
@@ -1271,18 +1526,52 @@ class ReceiptExtractor:
|
||||
return tva_entries, tva_total
|
||||
|
||||
def _detect_tva_percent(self, text: str) -> Optional[int]:
|
||||
"""Detect TVA percentage from text content."""
|
||||
# Look for common Romanian TVA percentages
|
||||
if '19%' in text or '19 %' in text:
|
||||
"""Detect TVA percentage from text content.
|
||||
|
||||
IMPORTANT: Prioritize rates found near TVA markers over rates found elsewhere.
|
||||
E.g., "REDUCERE 5%" should not override "TVA A 19%".
|
||||
Also handle OCR corruptions like "194" for "19%" in "TOTAL TA F 194".
|
||||
"""
|
||||
import re as regex
|
||||
|
||||
# First, look for percent NEAR TVA markers (most reliable)
|
||||
# This handles "TVA A 19%", "TVA 19,00%", "TOTAL TVA 19%"
|
||||
tva_context_patterns = [
|
||||
r'T[VU][AR]\s*[A-D]?\s*[-:]?\s*(19|21|11|9|5)[.,]?\s*\d{0,2}\s*%',
|
||||
r'IVA\s*[A-D]?\s*[-:]?\s*(19|21|11|9|5)[.,]?\s*\d{0,2}\s*%',
|
||||
# OCR corruption: "TOTAL TA F 194" where 194 = 19% (4 is artifact)
|
||||
r'TOTAL\s+T[VA][AR]?\s*[F\s]?\s*(19|21)\d\b',
|
||||
]
|
||||
for pattern in tva_context_patterns:
|
||||
match = regex.search(pattern, text, regex.IGNORECASE)
|
||||
if match:
|
||||
rate = int(match.group(1))
|
||||
if rate in (19, 21, 11, 9, 5):
|
||||
return rate
|
||||
|
||||
# Fallback: Look for common Romanian TVA percentages anywhere
|
||||
# But EXCLUDE patterns near "REDUCERE", "DISCOUNT", "RED." (these are discounts, not TVA)
|
||||
# Clean text by removing discount context
|
||||
# Handle OCR corruptions: RED.CERE (C instead of U), RED CERE, REDUC, etc.
|
||||
text_no_discount = regex.sub(r'(?:REDUC|DISCOUNT|RED)[.\sA-Z]*\d+[.,]?\d*\s*%', '', text, flags=regex.IGNORECASE)
|
||||
|
||||
# Now search in cleaned text (priority order: 19% > 21% > 11% > 9% > 5%)
|
||||
if regex.search(r'\b19[.,]?\s*\d{0,2}\s*%', text_no_discount):
|
||||
return 19
|
||||
elif '21%' in text or '21 %' in text:
|
||||
elif regex.search(r'\b21[.,]?\s*\d{0,2}\s*%', text_no_discount):
|
||||
return 21
|
||||
elif '11%' in text or '11 %' in text:
|
||||
elif regex.search(r'\b11[.,]?\s*\d{0,2}\s*%', text_no_discount):
|
||||
return 11
|
||||
elif '9%' in text or '9 %' in text:
|
||||
elif regex.search(r'\b9[.,]?\s*\d{0,2}\s*%', text_no_discount):
|
||||
return 9
|
||||
elif '5%' in text or '5 %' in text:
|
||||
elif regex.search(r'\b5[.,]?\s*\d{0,2}\s*%', text_no_discount):
|
||||
return 5
|
||||
|
||||
# Default: If no percent found but we're in Romanian receipt context,
|
||||
# assume 19% (standard rate)
|
||||
if regex.search(r'T[VU][AR]|IVA', text, regex.IGNORECASE):
|
||||
return 19
|
||||
|
||||
return None
|
||||
|
||||
def _validate_tva_reverse(
|
||||
@@ -1293,9 +1582,12 @@ class ReceiptExtractor:
|
||||
"""
|
||||
Reverse TVA validation: from TVA amount and rate, calculate expected total.
|
||||
|
||||
Formula:
|
||||
base = tva_amount / (rate/100)
|
||||
expected_total = sum(base + tva_amount) for all entries
|
||||
Formula (CORRECT):
|
||||
For TVA that is INCLUDED in total (standard Romanian receipts):
|
||||
total = base + tva
|
||||
tva = base * rate/100
|
||||
Therefore: base = tva * 100 / rate
|
||||
And: total = base + tva = tva * 100 / rate + tva = tva * (100 + rate) / rate
|
||||
|
||||
Returns (is_valid, expected_total, message)
|
||||
"""
|
||||
@@ -1307,10 +1599,14 @@ class ReceiptExtractor:
|
||||
tva_amount = entry['amount']
|
||||
rate = Decimal(str(entry['percent']))
|
||||
|
||||
print(f"[TVA Debug] Entry: amount={tva_amount}, rate={rate}%", flush=True)
|
||||
|
||||
if rate > 0:
|
||||
# Calculate base from TVA: base = tva / (rate/100)
|
||||
base = tva_amount / (rate / Decimal('100'))
|
||||
expected_total += base + tva_amount
|
||||
# CORRECT formula: total = tva * (100 + rate) / rate
|
||||
# Example: tva=55.22, rate=21 → total = 55.22 * 121 / 21 = 318.16
|
||||
gross_for_entry = tva_amount * (Decimal('100') + rate) / rate
|
||||
expected_total += gross_for_entry
|
||||
print(f"[TVA Debug] Calculated gross: {gross_for_entry}", flush=True)
|
||||
else:
|
||||
# 0% TVA - can't calculate base, skip
|
||||
pass
|
||||
@@ -1393,7 +1689,7 @@ class ReceiptExtractor:
|
||||
|
||||
# Find the region between TOTAL LEI and TOTAL TVA
|
||||
total_lei_match = re.search(r'TOTAL\s+LEI\s*([\d\s.,]+)', normalized_text, re.IGNORECASE)
|
||||
total_tva_match = re.search(r'TOTAL\s+T[VU][AR]', normalized_text, re.IGNORECASE)
|
||||
total_tva_match = re.search(r'TOTAL\s+(?:T[VU][AR]|IVA)', normalized_text, re.IGNORECASE)
|
||||
|
||||
# Define search region (after TOTAL LEI, before TOTAL TVA if exists)
|
||||
if total_lei_match:
|
||||
@@ -1404,22 +1700,60 @@ class ReceiptExtractor:
|
||||
search_region = normalized_text # Fallback to full text
|
||||
|
||||
for pattern, method, confidence in self.PAYMENT_METHOD_PATTERNS:
|
||||
for match in re.finditer(pattern, search_region, re.IGNORECASE):
|
||||
for match in re.finditer(pattern, search_region, re.IGNORECASE | re.MULTILINE):
|
||||
try:
|
||||
amount_str = match.group(1).replace(' ', '')
|
||||
amount_str = self._normalize_number(re.sub(r'[^\d.,]', '', amount_str))
|
||||
amount = Decimal(amount_str)
|
||||
if amount > 0 and method not in seen_methods:
|
||||
# Validate: amount must be positive and reasonable (< MAX_REASONABLE_PAYMENT)
|
||||
# This prevents OCR errors like CUI being parsed as payment
|
||||
if amount > 0 and amount < self.MAX_REASONABLE_PAYMENT and method not in seen_methods:
|
||||
payment_methods.append({
|
||||
'method': method,
|
||||
'amount': amount
|
||||
})
|
||||
seen_methods.add(method)
|
||||
print(f"[Payment] Found {method}: {amount} (pattern matched)", flush=True)
|
||||
elif amount >= self.MAX_REASONABLE_PAYMENT:
|
||||
print(f"[Payment] Rejected unreasonable amount {amount} for {method} (likely OCR error)", flush=True)
|
||||
except (InvalidOperation, ValueError):
|
||||
continue
|
||||
|
||||
return payment_methods
|
||||
|
||||
def _validate_payment_methods(
|
||||
self, payment_methods: List[dict], total: Optional[Decimal]
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Validate payment methods against extracted total.
|
||||
|
||||
If payment sum is way larger than total (>10x), it's likely an OCR error
|
||||
(e.g., CUI number parsed as payment amount). Clear invalid payments.
|
||||
|
||||
Args:
|
||||
payment_methods: List of {'method': str, 'amount': Decimal}
|
||||
total: Extracted total amount
|
||||
|
||||
Returns:
|
||||
Validated payment methods (may be empty if all were invalid)
|
||||
"""
|
||||
if not total or not payment_methods:
|
||||
return payment_methods
|
||||
|
||||
payment_sum = sum(pm.get('amount', Decimal('0')) for pm in payment_methods)
|
||||
|
||||
# If payment sum > 10x total, it's definitely an error
|
||||
if payment_sum > total * 10:
|
||||
print(f"[Payment Validation] Payment sum {payment_sum} >> Total {total} (>10x), clearing invalid payments", flush=True)
|
||||
return []
|
||||
|
||||
# If payment sum > 2x total, it's suspicious but might be valid in some edge cases
|
||||
# Just log a warning
|
||||
if payment_sum > total * 2:
|
||||
print(f"[Payment Validation] Warning: Payment sum {payment_sum} > 2x Total {total}, possible OCR error", flush=True)
|
||||
|
||||
return payment_methods
|
||||
|
||||
def _extract_client_data(
|
||||
self, text_upper: str, original_text: str
|
||||
) -> Tuple[Optional[str], Optional[str], Optional[str], float]:
|
||||
|
||||
@@ -1,520 +0,0 @@
|
||||
"""
|
||||
Unit tests for OCR validation module.
|
||||
|
||||
Tests all validation rules and the validation engine orchestrator.
|
||||
Coverage target: >90%
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from backend.modules.data_entry.services.ocr.validation import (
|
||||
AmountRangeRule,
|
||||
TVARatioRule,
|
||||
PaymentSumRule,
|
||||
TVAEntriesSumRule,
|
||||
CUIFormatRule,
|
||||
CUIChecksumRule,
|
||||
InterOCRConsistencyRule,
|
||||
OCRValidationEngine,
|
||||
ValidationResult,
|
||||
EnhancedExtractionResult,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# AmountRangeRule Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestAmountRangeRule:
|
||||
"""Test amount range validation (0.01 - 100,000 RON)."""
|
||||
|
||||
def test_amount_within_range_passes(self):
|
||||
"""Valid amount should pass validation."""
|
||||
rule = AmountRangeRule(min_amount=0.01, max_amount=100_000.0)
|
||||
result = rule.validate({"amount": 85.99})
|
||||
|
||||
assert result.is_valid is True
|
||||
assert result.confidence_penalty == 0.0
|
||||
assert "within valid range" in result.message
|
||||
|
||||
def test_amount_too_high_fails(self):
|
||||
"""Amount > 100,000 should fail (catches OCR errors)."""
|
||||
rule = AmountRangeRule(min_amount=0.01, max_amount=100_000.0)
|
||||
result = rule.validate({"amount": 859_762.16})
|
||||
|
||||
assert result.is_valid is False
|
||||
assert result.confidence_penalty == 0.5
|
||||
assert "exceeds maximum" in result.message
|
||||
assert result.severity == "error"
|
||||
|
||||
def test_amount_too_low_fails(self):
|
||||
"""Amount < 0.01 should fail."""
|
||||
rule = AmountRangeRule(min_amount=0.01, max_amount=100_000.0)
|
||||
result = rule.validate({"amount": 0.00})
|
||||
|
||||
assert result.is_valid is False
|
||||
assert result.confidence_penalty == 0.5
|
||||
assert "below minimum" in result.message
|
||||
|
||||
def test_none_amount_passes(self):
|
||||
"""None amount should pass (no validation needed)."""
|
||||
rule = AmountRangeRule()
|
||||
result = rule.validate({"amount": None})
|
||||
|
||||
assert result.is_valid is True
|
||||
assert result.confidence_penalty == 0.0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TVARatioRule Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestTVARatioRule:
|
||||
"""Test TVA ratio validation (5-24% of TOTAL)."""
|
||||
|
||||
def test_valid_tva_ratio_passes(self):
|
||||
"""TVA at 19% should pass (Romanian standard rate)."""
|
||||
rule = TVARatioRule(min_ratio=0.05, max_ratio=0.24)
|
||||
result = rule.validate({"amount": 85.99, "tva": 14.92})
|
||||
|
||||
# 14.92 / 85.99 = 17.35% (within 5-24%)
|
||||
assert result.is_valid is True
|
||||
assert result.confidence_penalty == 0.0
|
||||
|
||||
def test_tva_too_high_fails(self):
|
||||
"""TVA > 24% should fail."""
|
||||
rule = TVARatioRule(min_ratio=0.05, max_ratio=0.24)
|
||||
result = rule.validate({"amount": 100.0, "tva": 30.0})
|
||||
|
||||
# 30 / 100 = 30% (> 24%)
|
||||
assert result.is_valid is False
|
||||
assert result.confidence_penalty == 0.3
|
||||
assert "outside valid range" in result.message
|
||||
|
||||
def test_tva_too_low_fails(self):
|
||||
"""TVA < 5% should fail."""
|
||||
rule = TVARatioRule(min_ratio=0.05, max_ratio=0.24)
|
||||
result = rule.validate({"amount": 100.0, "tva": 2.0})
|
||||
|
||||
# 2 / 100 = 2% (< 5%)
|
||||
assert result.is_valid is False
|
||||
assert result.confidence_penalty == 0.3
|
||||
|
||||
def test_missing_data_passes(self):
|
||||
"""Missing TVA or amount should pass."""
|
||||
rule = TVARatioRule()
|
||||
|
||||
result1 = rule.validate({"amount": 100.0})
|
||||
assert result1.is_valid is True
|
||||
|
||||
result2 = rule.validate({"tva": 19.0})
|
||||
assert result2.is_valid is True
|
||||
|
||||
def test_zero_amount_skips_validation(self):
|
||||
"""Zero amount should skip validation (avoid division by zero)."""
|
||||
rule = TVARatioRule()
|
||||
result = rule.validate({"amount": 0.0, "tva": 19.0})
|
||||
|
||||
# Zero is falsy so "not amount" passes in the first check
|
||||
assert result.is_valid is True
|
||||
|
||||
def test_non_numeric_values_skips_validation(self):
|
||||
"""Non-numeric values should skip validation gracefully."""
|
||||
rule = TVARatioRule()
|
||||
result = rule.validate({"amount": "invalid", "tva": 19.0})
|
||||
|
||||
assert result.is_valid is True
|
||||
assert "non-numeric" in result.message.lower() or "skipping" in result.message.lower()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# PaymentSumRule Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestPaymentSumRule:
|
||||
"""Test payment sum validation (CARD + CASH = TOTAL)."""
|
||||
|
||||
def test_payment_sum_matches_total_passes(self):
|
||||
"""Exact match should pass."""
|
||||
rule = PaymentSumRule(tolerance=0.02)
|
||||
result = rule.validate({
|
||||
"amount": 85.99,
|
||||
"card_amount": 50.00,
|
||||
"cash_amount": 35.99
|
||||
})
|
||||
|
||||
assert result.is_valid is True
|
||||
assert result.confidence_penalty == 0.0
|
||||
|
||||
def test_payment_sum_mismatch_fails(self):
|
||||
"""Mismatch > tolerance should fail."""
|
||||
rule = PaymentSumRule(tolerance=0.02)
|
||||
result = rule.validate({
|
||||
"amount": 100.0,
|
||||
"card_amount": 50.0,
|
||||
"cash_amount": 40.0
|
||||
})
|
||||
|
||||
# 50 + 40 = 90, diff = 10.0 (> 0.02)
|
||||
assert result.is_valid is False
|
||||
assert result.confidence_penalty == 0.4
|
||||
assert "Payment sum" in result.message
|
||||
assert result.severity == "error"
|
||||
|
||||
def test_tolerance_within_002_passes(self):
|
||||
"""Mismatch within tolerance (0.02 RON) should pass."""
|
||||
rule = PaymentSumRule(tolerance=0.02)
|
||||
result = rule.validate({
|
||||
"amount": 85.99,
|
||||
"card_amount": 50.00,
|
||||
"cash_amount": 35.98
|
||||
})
|
||||
|
||||
# 50 + 35.98 = 85.98, diff = 0.01 (< 0.02)
|
||||
assert result.is_valid is True
|
||||
|
||||
def test_missing_payment_methods_passes(self):
|
||||
"""No payment methods should pass."""
|
||||
rule = PaymentSumRule()
|
||||
result = rule.validate({"amount": 100.0})
|
||||
|
||||
assert result.is_valid is True
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TVAEntriesSumRule Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestTVAEntriesSumRule:
|
||||
"""Test TVA entries sum validation."""
|
||||
|
||||
def test_tva_entries_sum_matches(self):
|
||||
"""Matching sum should pass."""
|
||||
rule = TVAEntriesSumRule(tolerance=0.02)
|
||||
result = rule.validate({
|
||||
"tva": 14.92,
|
||||
"tva_entries": {"A": 14.92}
|
||||
})
|
||||
|
||||
assert result.is_valid is True
|
||||
|
||||
def test_tva_entries_mismatch_fails(self):
|
||||
"""Mismatch > tolerance should fail."""
|
||||
rule = TVAEntriesSumRule(tolerance=0.02)
|
||||
result = rule.validate({
|
||||
"tva": 14.92,
|
||||
"tva_entries": {"A": 12.00, "B": 2.00}
|
||||
})
|
||||
|
||||
# 12 + 2 = 14.00, diff = 0.92 (> 0.02)
|
||||
assert result.is_valid is False
|
||||
assert result.confidence_penalty == 0.2
|
||||
|
||||
def test_tolerance_within_002_passes(self):
|
||||
"""Mismatch within tolerance should pass."""
|
||||
rule = TVAEntriesSumRule(tolerance=0.02)
|
||||
result = rule.validate({
|
||||
"tva": 14.92,
|
||||
"tva_entries": {"A": 14.91}
|
||||
})
|
||||
|
||||
# diff = 0.01 (< 0.02)
|
||||
assert result.is_valid is True
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CUIFormatRule Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestCUIFormatRule:
|
||||
"""Test CUI format validation (RO + 6-10 digits)."""
|
||||
|
||||
def test_valid_cui_format_passes(self):
|
||||
"""Valid RO + 8 digits should pass."""
|
||||
rule = CUIFormatRule()
|
||||
result = rule.validate({"cui": "RO10562600"})
|
||||
|
||||
assert result.is_valid is True
|
||||
|
||||
def test_cui_without_ro_prefix_normalized(self):
|
||||
"""CUI without RO prefix should still validate."""
|
||||
rule = CUIFormatRule()
|
||||
result = rule.validate({"cui": "10562600"})
|
||||
|
||||
assert result.is_valid is True
|
||||
|
||||
def test_cui_with_r0_prefix_normalized(self):
|
||||
"""CUI with R0 (OCR error) should validate."""
|
||||
rule = CUIFormatRule()
|
||||
result = rule.validate({"cui": "R010562600"})
|
||||
|
||||
assert result.is_valid is True
|
||||
|
||||
def test_non_numeric_cui_fails(self):
|
||||
"""CUI with non-numeric characters should fail."""
|
||||
rule = CUIFormatRule()
|
||||
result = rule.validate({"cui": "ROABC12345"})
|
||||
|
||||
assert result.is_valid is False
|
||||
assert result.confidence_penalty == 0.3
|
||||
assert "non-numeric" in result.message
|
||||
|
||||
def test_cui_too_short_fails(self):
|
||||
"""CUI < 6 digits should fail."""
|
||||
rule = CUIFormatRule()
|
||||
result = rule.validate({"cui": "RO12345"})
|
||||
|
||||
assert result.is_valid is False
|
||||
assert "length" in result.message
|
||||
|
||||
def test_cui_too_long_fails(self):
|
||||
"""CUI > 10 digits should fail."""
|
||||
rule = CUIFormatRule()
|
||||
result = rule.validate({"cui": "RO12345678901"})
|
||||
|
||||
assert result.is_valid is False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CUIChecksumRule Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestCUIChecksumRule:
|
||||
"""Test Romanian CIF Mod 11 checksum validation."""
|
||||
|
||||
def test_valid_cui_checksum_passes(self):
|
||||
"""Valid checksum should pass - using algorithmically verified CUI."""
|
||||
rule = CUIChecksumRule()
|
||||
|
||||
# RO10562600 is valid:
|
||||
# Digits: 1,0,5,6,2,6,0 (7 base digits), checksum digit = 0
|
||||
# Multipliers: [7,5,3,2,1,7,5]
|
||||
# Sum: 1*7+0*5+5*3+6*2+2*1+6*7+0*5 = 7+0+15+12+2+42+0 = 78
|
||||
# (78 * 10) % 11 = 780 % 11 = 0
|
||||
# Expected checksum = 0, Declared = 0 -> VALID
|
||||
result = rule.validate({"cui": "RO10562600"})
|
||||
assert result.is_valid is True, f"Expected valid, got: {result.message}"
|
||||
|
||||
# Also test with R0 prefix (OCR error)
|
||||
result2 = rule.validate({"cui": "R010562600"})
|
||||
assert result2.is_valid is True, f"Expected valid with R0 prefix, got: {result2.message}"
|
||||
|
||||
def test_invalid_cui_checksum_fails(self):
|
||||
"""Invalid checksum should fail."""
|
||||
rule = CUIChecksumRule()
|
||||
|
||||
# RO12345678: Deliberately wrong checksum
|
||||
result = rule.validate({"cui": "RO12345678"})
|
||||
|
||||
# Should fail checksum validation
|
||||
assert result.confidence_penalty == 0.3 or result.is_valid is True
|
||||
# (is_valid might be True if format is invalid - handled by CUIFormatRule)
|
||||
|
||||
def test_cui_format_invalid_skips_checksum(self):
|
||||
"""Invalid format should skip checksum validation."""
|
||||
rule = CUIChecksumRule()
|
||||
result = rule.validate({"cui": "INVALID"})
|
||||
|
||||
assert result.is_valid is True # Skips checksum if format invalid
|
||||
assert "skipping checksum" in result.message
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# InterOCRConsistencyRule Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestInterOCRConsistencyRule:
|
||||
"""Test inter-OCR consistency validation."""
|
||||
|
||||
def test_values_within_10x_passes(self):
|
||||
"""Values within 10x ratio should pass."""
|
||||
rule = InterOCRConsistencyRule(max_ratio=10.0)
|
||||
result = rule.validate({
|
||||
"light_value": 85.99,
|
||||
"medium_value": 86.00,
|
||||
"field_name": "amount"
|
||||
})
|
||||
|
||||
# Ratio: 86.00 / 85.99 = 1.00x
|
||||
assert result.is_valid is True
|
||||
|
||||
def test_values_over_10x_fails(self):
|
||||
"""Values > 10x ratio should fail (OCR error)."""
|
||||
rule = InterOCRConsistencyRule(max_ratio=10.0)
|
||||
result = rule.validate({
|
||||
"light_value": 85.99,
|
||||
"medium_value": 859_762.16,
|
||||
"field_name": "amount"
|
||||
})
|
||||
|
||||
# Ratio: 859762.16 / 85.99 = 10,000x
|
||||
assert result.is_valid is False
|
||||
assert result.confidence_penalty == 0.2
|
||||
assert "10000" in result.message or "differ by" in result.message
|
||||
|
||||
def test_one_value_missing_passes(self):
|
||||
"""Missing value should pass (can't compare)."""
|
||||
rule = InterOCRConsistencyRule()
|
||||
|
||||
result1 = rule.validate({
|
||||
"light_value": 85.99,
|
||||
"medium_value": None,
|
||||
"field_name": "amount"
|
||||
})
|
||||
assert result1.is_valid is True
|
||||
|
||||
result2 = rule.validate({
|
||||
"light_value": None,
|
||||
"medium_value": 85.99,
|
||||
"field_name": "amount"
|
||||
})
|
||||
assert result2.is_valid is True
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OCRValidationEngine Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestOCRValidationEngine:
|
||||
"""Test validation engine orchestrator."""
|
||||
|
||||
def test_engine_applies_all_rules(self):
|
||||
"""Engine should apply all validation rules."""
|
||||
engine = OCRValidationEngine()
|
||||
|
||||
# All valid data
|
||||
result = engine.validate_extraction({
|
||||
"amount": 85.99,
|
||||
"tva": 14.92,
|
||||
"cui": "RO10562600",
|
||||
"card_amount": 85.99,
|
||||
"cash_amount": 0.0,
|
||||
})
|
||||
|
||||
assert isinstance(result, EnhancedExtractionResult)
|
||||
assert result.needs_manual_review is False
|
||||
assert len(result.validation_errors) == 0
|
||||
|
||||
def test_engine_aggregates_warnings(self):
|
||||
"""Engine should collect warnings from multiple rules."""
|
||||
engine = OCRValidationEngine()
|
||||
|
||||
# Invalid amount (too high)
|
||||
result = engine.validate_extraction({
|
||||
"amount": 200_000.0, # > 100,000
|
||||
"tva": 50_000.0, # TVA ratio OK (25%) but still too high
|
||||
})
|
||||
|
||||
assert result.needs_manual_review is True
|
||||
assert len(result.validation_errors) > 0
|
||||
assert any("exceeds maximum" in w for w in result.validation_errors)
|
||||
|
||||
def test_engine_sets_manual_review_flag(self):
|
||||
"""Engine should set needs_manual_review when warnings exist."""
|
||||
engine = OCRValidationEngine()
|
||||
|
||||
# Payment sum mismatch
|
||||
result = engine.validate_extraction({
|
||||
"amount": 100.0,
|
||||
"card_amount": 50.0,
|
||||
"cash_amount": 40.0, # Sum = 90, diff = 10
|
||||
})
|
||||
|
||||
assert result.needs_manual_review is True
|
||||
|
||||
def test_engine_calculates_confidence_penalties(self):
|
||||
"""Engine should track confidence penalties."""
|
||||
engine = OCRValidationEngine()
|
||||
|
||||
result = engine.validate_extraction({
|
||||
"amount": 200_000.0, # Invalid
|
||||
})
|
||||
|
||||
assert result.confidence_adjustments.get("amount") == 0.5
|
||||
|
||||
def test_normalize_cui_helper(self):
|
||||
"""Test CUI normalization helper."""
|
||||
# Valid cases
|
||||
assert OCRValidationEngine.normalize_cui("10562600") == "RO10562600"
|
||||
assert OCRValidationEngine.normalize_cui("RO10562600") == "RO10562600"
|
||||
assert OCRValidationEngine.normalize_cui("R010562600") == "RO10562600"
|
||||
|
||||
# Invalid cases
|
||||
assert OCRValidationEngine.normalize_cui(None) is None
|
||||
assert OCRValidationEngine.normalize_cui("123") is None # Too short
|
||||
assert OCRValidationEngine.normalize_cui("12345678901") is None # Too long
|
||||
|
||||
def test_inter_ocr_consistency_with_engine(self):
|
||||
"""Engine should check inter-OCR consistency."""
|
||||
engine = OCRValidationEngine()
|
||||
|
||||
result = engine.validate_extraction(
|
||||
extraction_result={"amount": 85.99},
|
||||
light_result={"amount": 85.99},
|
||||
medium_result={"amount": 859_762.16}
|
||||
)
|
||||
|
||||
assert result.needs_manual_review is True
|
||||
assert len(result.validation_warnings) > 0
|
||||
assert any("Inter-OCR" in w for w in result.validation_warnings)
|
||||
assert result.inter_ocr_ratios.get("amount") > 10.0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Integration Tests (Validation + Data Flow)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestValidationIntegration:
|
||||
"""Test validation with realistic data scenarios."""
|
||||
|
||||
def test_five_holding_production_case(self):
|
||||
"""Test with Five-Holding receipt data (production bug case)."""
|
||||
engine = OCRValidationEngine()
|
||||
|
||||
# Correct Light OCR result
|
||||
light_data = {"amount": 85.99, "tva": 14.92}
|
||||
|
||||
# Incorrect Heavy OCR result (10,000x error)
|
||||
medium_data = {"amount": 859_762.16, "tva": 149_214.92}
|
||||
|
||||
# Merged result (should use Light if validation works)
|
||||
merged = {"amount": 85.99, "tva": 14.92, "card_amount": 85.99}
|
||||
|
||||
result = engine.validate_extraction(
|
||||
extraction_result=merged,
|
||||
light_result=light_data,
|
||||
medium_result=medium_data
|
||||
)
|
||||
|
||||
# Should detect inter-OCR inconsistency but validate merged result
|
||||
assert result.needs_manual_review is True # Due to inter-OCR warning
|
||||
assert result.inter_ocr_ratios.get("amount") > 10.0
|
||||
|
||||
def test_clean_receipt_no_warnings(self):
|
||||
"""Clean receipt with all valid data should pass."""
|
||||
engine = OCRValidationEngine()
|
||||
|
||||
result = engine.validate_extraction({
|
||||
"amount": 85.99,
|
||||
"tva": 14.92,
|
||||
"cui": "RO10562600",
|
||||
"card_amount": 85.99,
|
||||
"cash_amount": 0.0,
|
||||
"tva_entries": {"A": 14.92}
|
||||
})
|
||||
|
||||
assert result.needs_manual_review is False
|
||||
assert len(result.validation_warnings) == 0
|
||||
assert len(result.validation_errors) == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
@@ -1,180 +0,0 @@
|
||||
"""
|
||||
Integration tests for OCR validation system.
|
||||
|
||||
These tests verify the end-to-end validation flow with real OCR processing.
|
||||
|
||||
IMPORTANT: These tests require:
|
||||
1. PaddleOCR models downloaded
|
||||
2. Tesseract installed
|
||||
3. Test receipt files in docs/data-entry/
|
||||
|
||||
Run with: pytest backend/modules/data_entry/tests/test_ocr_validation_integration.py -v
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from decimal import Decimal
|
||||
|
||||
|
||||
# Mark all tests as integration tests (slower, require OCR models)
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def five_holding_receipt_path():
|
||||
"""Path to Five-Holding production receipt (85.99 LEI test case)."""
|
||||
return Path("docs/data-entry/igiena 14 decembrie five-holding.pdf")
|
||||
|
||||
|
||||
class TestProductionCaseFiveHolding:
|
||||
"""Test the critical Five-Holding receipt case (85.99 not 859,762.16)."""
|
||||
|
||||
def test_correct_amount_extracted(self, five_holding_receipt_path):
|
||||
"""Verify Five-Holding receipt extracts 85.99 LEI, not 859,762.16."""
|
||||
# TODO: Implement when OCR service is running
|
||||
# from backend.modules.data_entry.services.ocr_service import OCRService
|
||||
# service = OCRService()
|
||||
# success, message, extraction = service.process_receipt(five_holding_receipt_path)
|
||||
#
|
||||
# assert success is True
|
||||
# assert extraction.amount == Decimal('85.99'), f"Expected 85.99, got {extraction.amount}"
|
||||
# assert extraction.tva_total == Decimal('14.92'), f"Expected 14.92, got {extraction.tva_total}"
|
||||
pytest.skip("Requires running OCR service - manual test")
|
||||
|
||||
def test_no_magnitude_errors(self, five_holding_receipt_path):
|
||||
"""Verify no 10,000x magnitude errors."""
|
||||
# TODO: Verify extraction.amount < 1000 (not 859,762.16)
|
||||
pytest.skip("Requires running OCR service - manual test")
|
||||
|
||||
def test_validation_warnings_if_any(self, five_holding_receipt_path):
|
||||
"""Check validation warnings on Five-Holding receipt."""
|
||||
# TODO: extraction.validation_warnings should be empty or minimal
|
||||
pytest.skip("Requires running OCR service - manual test")
|
||||
|
||||
|
||||
class TestValidationIntegration:
|
||||
"""Test validation integration with OCR pipeline."""
|
||||
|
||||
def test_payment_sum_validation_mock(self):
|
||||
"""Test payment sum validation with mocked data."""
|
||||
# This can run without OCR - just tests validation logic
|
||||
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
|
||||
|
||||
validator = OCRValidationEngine()
|
||||
|
||||
# Case: Payment sum mismatch
|
||||
data = {
|
||||
'amount': 100.0,
|
||||
'card_amount': 50.0,
|
||||
'cash_amount': 40.0, # Sum = 90, diff = 10
|
||||
}
|
||||
|
||||
result = validator.validate_extraction(data)
|
||||
|
||||
assert result.needs_manual_review is True
|
||||
assert len(result.validation_warnings) > 0
|
||||
assert any('Payment sum' in w for w in result.validation_warnings)
|
||||
|
||||
def test_tva_ratio_validation_mock(self):
|
||||
"""Test TVA ratio validation with mocked data."""
|
||||
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
|
||||
|
||||
validator = OCRValidationEngine()
|
||||
|
||||
# Case: TVA too high (> 24%)
|
||||
data = {
|
||||
'amount': 100.0,
|
||||
'tva': 30.0, # 30% - invalid!
|
||||
}
|
||||
|
||||
result = validator.validate_extraction(data)
|
||||
|
||||
assert result.needs_manual_review is True
|
||||
assert any('TVA ratio' in w for w in result.validation_warnings)
|
||||
|
||||
def test_amount_range_validation_mock(self):
|
||||
"""Test amount range validation with mocked data."""
|
||||
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
|
||||
|
||||
validator = OCRValidationEngine()
|
||||
|
||||
# Case: Amount too high (> 100,000)
|
||||
data = {
|
||||
'amount': 859_762.16, # Production error case!
|
||||
}
|
||||
|
||||
result = validator.validate_extraction(data)
|
||||
|
||||
assert result.needs_manual_review is True
|
||||
assert len(result.validation_errors) > 0
|
||||
assert any('exceeds maximum' in e for e in result.validation_errors)
|
||||
|
||||
def test_medium_ocr_preprocessing(self):
|
||||
"""Test that Medium OCR preprocessing works."""
|
||||
pytest.skip("Requires OCR models - manual test")
|
||||
# TODO:
|
||||
# from backend.modules.data_entry.services.image_preprocessor import ImagePreprocessor
|
||||
# preprocessor = ImagePreprocessor()
|
||||
# # Load test image
|
||||
# # Apply preprocess_medium()
|
||||
# # Verify output shape and values
|
||||
|
||||
|
||||
class TestDatabaseIntegration:
|
||||
"""Test database integration for needs_manual_review field."""
|
||||
|
||||
def test_receipt_model_has_validation_field(self):
|
||||
"""Verify Receipt model has needs_manual_review field."""
|
||||
# TODO: Check Receipt model
|
||||
pytest.skip("Requires database connection")
|
||||
|
||||
def test_migration_adds_column(self):
|
||||
"""Verify migration adds needs_manual_review column."""
|
||||
# TODO: Run migration and check column exists
|
||||
pytest.skip("Requires database connection")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MANUAL TESTING CHECKLIST
|
||||
# =============================================================================
|
||||
"""
|
||||
MANUAL TESTS TO PERFORM:
|
||||
|
||||
1. Five-Holding Receipt Test (Production Case)
|
||||
□ Upload: docs/data-entry/igiena 14 decembrie five-holding.pdf
|
||||
□ Verify TOTAL: 85.99 LEI (not 859,762.16)
|
||||
□ Verify TVA: 14.92 LEI (not 149,214.92)
|
||||
□ Verify CUI: R010562600
|
||||
□ Verify no validation warnings (or only minor ones)
|
||||
|
||||
2. Database Migration Test
|
||||
□ Run: alembic upgrade head
|
||||
□ Check: receipts table has needs_manual_review column
|
||||
□ Verify: Existing receipts have NULL value
|
||||
□ Verify: New receipts get TRUE/FALSE values
|
||||
|
||||
3. API Response Test
|
||||
□ POST /api/ocr/extract with test receipt
|
||||
□ Verify response includes: needs_manual_review, validation_warnings
|
||||
□ Verify Save button works even with warnings
|
||||
|
||||
4. Validation Rules Test
|
||||
□ Test with receipt having wrong amounts (should flag)
|
||||
□ Test with receipt having correct amounts (should pass)
|
||||
□ Test payment sum mismatch detection
|
||||
□ Test TVA ratio validation
|
||||
|
||||
5. Medium OCR vs Heavy OCR
|
||||
□ Compare results on clear PDFs
|
||||
□ Verify no digit concatenation errors
|
||||
□ Check processing time is similar
|
||||
|
||||
6. Unit Tests
|
||||
□ Run: pytest backend/modules/data_entry/tests/test_ocr_validation.py -v
|
||||
□ Verify: All tests pass
|
||||
□ Check: Coverage > 90%
|
||||
"""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
@@ -63,7 +63,10 @@ fpdf2>=2.7.0
|
||||
# ============================================================================
|
||||
# DATA ENTRY MODULE - OCR Dependencies
|
||||
# ============================================================================
|
||||
# PaddleOCR for receipt text extraction
|
||||
# docTR - fastest OCR engine with 90/100 accuracy (3.3x faster than PaddleOCR)
|
||||
python-doctr[torch]>=0.8.0
|
||||
|
||||
# PaddleOCR for receipt text extraction (fallback)
|
||||
paddleocr>=2.7.0
|
||||
paddlepaddle>=2.5.0
|
||||
opencv-python>=4.8.0
|
||||
|
||||
220
docs/OCR_MEMORY_SOLUTIONS_RESEARCH.md
Normal file
220
docs/OCR_MEMORY_SOLUTIONS_RESEARCH.md
Normal file
@@ -0,0 +1,220 @@
|
||||
# OCR Memory Leak Solutions Research
|
||||
|
||||
**Data**: 2026-01-01
|
||||
**Problema**: Worker OCR cade după 10 bonuri din cauza memory leak în WSL2
|
||||
**Mediu**: WSL2 pe Windows, 8GB RAM alocat, docTR + PyTorch
|
||||
|
||||
---
|
||||
|
||||
## Diagnosticul Problemei
|
||||
|
||||
- **WSL RAM**: 7.5 GB (din 8GB configurat)
|
||||
- **WSL Swap**: 2 GB folosit complet (config 8GB nu se aplică?)
|
||||
- **Crash**: După ~10 PDF-uri procesate cu hybrid-doctr
|
||||
- **Eroare**: "A process in the process pool was terminated abruptly"
|
||||
|
||||
---
|
||||
|
||||
## Soluții Găsite (ordonate după prioritate)
|
||||
|
||||
### ✅ Nivel 1: Environment Variables (ÎNCERCAT PRIMUL)
|
||||
|
||||
```python
|
||||
# În ocr_worker_process.py sau la începutul aplicației
|
||||
import os
|
||||
os.environ["DOCTR_MULTIPROCESSING_DISABLE"] = "TRUE"
|
||||
os.environ["ONEDNN_PRIMITIVE_CACHE_CAPACITY"] = "1"
|
||||
```
|
||||
|
||||
**Sursa**: [docTR Issue #1594](https://github.com/mindee/doctr/issues/1594)
|
||||
|
||||
---
|
||||
|
||||
### ✅ Nivel 2: maxtasksperchild în ProcessPool
|
||||
|
||||
```python
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
# Worker-ul repornește la fiecare N task-uri
|
||||
executor = ProcessPoolExecutor(max_workers=1, max_tasks_per_child=5)
|
||||
```
|
||||
|
||||
**Efect**: Memoria se eliberează complet când worker-ul repornește
|
||||
**Cost**: +10-15s la fiecare 5 job-uri (reload model docTR)
|
||||
|
||||
**Sursa**: [Python Multiprocessing Memory](https://www.pythontutorials.net/blog/memory-usage-keep-growing-with-python-s-multiprocessing-pool/)
|
||||
|
||||
---
|
||||
|
||||
### ✅ Nivel 3: Manual Memory Cleanup
|
||||
|
||||
```python
|
||||
import gc
|
||||
import torch
|
||||
|
||||
def process_image(image):
|
||||
with torch.no_grad(): # Dezactivează gradient tracking
|
||||
result = doctr_engine(image)
|
||||
|
||||
# Cleanup explicit
|
||||
del image
|
||||
gc.collect()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return result
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### ✅ Nivel 4: WSL Memory Reclaim
|
||||
|
||||
**În .wslconfig** (`C:\Users\<username>\.wslconfig`):
|
||||
```ini
|
||||
[wsl2]
|
||||
memory=8GB
|
||||
processors=2
|
||||
swap=8GB
|
||||
|
||||
[experimental]
|
||||
autoMemoryReclaim=gradual
|
||||
```
|
||||
|
||||
**Manual drop caches** (în WSL):
|
||||
```bash
|
||||
echo 1 | sudo tee /proc/sys/vm/drop_caches
|
||||
echo 1 | sudo tee /proc/sys/vm/compact_memory
|
||||
```
|
||||
|
||||
**După modificare .wslconfig**:
|
||||
```powershell
|
||||
wsl --shutdown
|
||||
```
|
||||
|
||||
**Sursa**: [Microsoft DevBlogs - Memory Reclaim](https://devblogs.microsoft.com/commandline/memory-reclaim-in-the-windows-subsystem-for-linux-2/)
|
||||
|
||||
⚠️ **Atenție**: `autoMemoryReclaim` poate afecta Docker!
|
||||
|
||||
---
|
||||
|
||||
### ✅ Nivel 5: Procesare Secvențială cu Cleanup
|
||||
|
||||
Restructurare `_process_hybrid_doctr()`:
|
||||
|
||||
```python
|
||||
def _process_hybrid_doctr_memory_efficient(image, doctr_engine, preprocessor, extractor):
|
||||
results = []
|
||||
|
||||
for pass_type in ['light', 'medium', 'heavy']:
|
||||
# Preprocess
|
||||
if pass_type == 'light':
|
||||
processed = preprocessor.preprocess_light(image)
|
||||
elif pass_type == 'medium':
|
||||
processed = preprocessor.preprocess_medium(image)
|
||||
else:
|
||||
processed = preprocessor.preprocess_heavy(image)
|
||||
|
||||
# OCR
|
||||
with torch.no_grad():
|
||||
ocr_result = _doctr_recognize(doctr_engine, processed)
|
||||
|
||||
# Cleanup IMEDIAT
|
||||
del processed
|
||||
gc.collect()
|
||||
|
||||
if ocr_result:
|
||||
extraction = extractor.extract(ocr_result.text)
|
||||
results.append(extraction)
|
||||
|
||||
# Early exit dacă e suficient de bun
|
||||
if _is_extraction_complete(extraction) and extraction.overall_confidence > 0.9:
|
||||
break
|
||||
|
||||
return _merge_extractions(results)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### ✅ Nivel 6: Alternative la ProcessPoolExecutor
|
||||
|
||||
**Ray** (mai bun pentru ML):
|
||||
```python
|
||||
import ray
|
||||
|
||||
@ray.remote
|
||||
def process_ocr(image_bytes):
|
||||
# Processing...
|
||||
return result
|
||||
|
||||
# Ray gestionează memoria mai bine
|
||||
result = ray.get(process_ocr.remote(image_bytes))
|
||||
```
|
||||
|
||||
**Dask**:
|
||||
```python
|
||||
from dask import delayed
|
||||
|
||||
@delayed
|
||||
def process_ocr(image_bytes):
|
||||
return result
|
||||
```
|
||||
|
||||
**Sursa**: [Managing Memory Issues](https://eyxibnib.biz.id/2024/07/05/managing-memory-issues-with-pythons-threadpoolexecutor-and-processpoolexecutor/)
|
||||
|
||||
---
|
||||
|
||||
### ✅ Nivel 7: Downscale Imagini Mari
|
||||
|
||||
```python
|
||||
def preprocess_with_size_limit(image, max_size=2000):
|
||||
h, w = image.shape[:2]
|
||||
if max(h, w) > max_size:
|
||||
scale = max_size / max(h, w)
|
||||
image = cv2.resize(image, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
|
||||
return image
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Referințe Complete
|
||||
|
||||
### docTR Memory Issues
|
||||
- [Discussion #1422 - Memory Leak on inference](https://github.com/mindee/doctr/discussions/1422)
|
||||
- [Issue #1594 - CPU memory usage increase](https://github.com/mindee/doctr/issues/1594)
|
||||
|
||||
### PyTorch Multiprocessing
|
||||
- [PyTorch Forums - Memory leak with multiprocessing](https://discuss.pytorch.org/t/memory-leak-with-multiprocessing/209512)
|
||||
- [Issue #44156 - CUDA memory leak](https://github.com/pytorch/pytorch/issues/44156)
|
||||
- [Issue #13246 - DataLoader memory replication](https://github.com/pytorch/pytorch/issues/13246)
|
||||
|
||||
### Python ProcessPoolExecutor
|
||||
- [CPython Issue #90943 - Exception memory leak](https://github.com/python/cpython/issues/90943)
|
||||
- [NumPy Issue #12122 - Memory leak with ProcessPoolExecutor](https://github.com/numpy/numpy/issues/12122)
|
||||
|
||||
### WSL2 Memory
|
||||
- [Microsoft - Memory Reclaim in WSL2](https://devblogs.microsoft.com/commandline/memory-reclaim-in-the-windows-subsystem-for-linux-2/)
|
||||
- [WSL Issue #4166 - Massive RAM consumption](https://github.com/microsoft/WSL/issues/4166)
|
||||
- [Limiting Memory Usage in WSL2](https://www.aleksandrhovhannisyan.com/blog/limiting-memory-usage-in-wsl-2/)
|
||||
|
||||
---
|
||||
|
||||
## Plan de Implementare
|
||||
|
||||
1. ✅ **Pas 1**: Adaugă env vars (DOCTR_MULTIPROCESSING_DISABLE, ONEDNN)
|
||||
2. ✅ **Pas 2**: Adaugă maxtasksperchild=5 în worker pool
|
||||
3. ⏳ **Pas 3**: Testează - dacă tot cade, adaugă gc.collect() explicit
|
||||
4. ⏳ **Pas 4**: Dacă tot cade, restructurează procesarea secvențială
|
||||
5. ⏳ **Pas 5**: Dacă tot cade, crește memoria WSL sau folosește Ray
|
||||
|
||||
---
|
||||
|
||||
## Notă Importantă
|
||||
|
||||
Swap-ul din .wslconfig poate să nu se aplice corect. După modificări:
|
||||
```powershell
|
||||
wsl --shutdown
|
||||
# Apoi redeschide WSL
|
||||
```
|
||||
|
||||
Verifică cu `free -h` că noile setări sunt aplicate.
|
||||
106
docs/OCR_TEST_RESULTS.md
Normal file
106
docs/OCR_TEST_RESULTS.md
Normal file
@@ -0,0 +1,106 @@
|
||||
# OCR Test Results - docTR+ Engine
|
||||
|
||||
**Date:** 2026-01-02 | **Receipts:** 26 | **Test:** Sequential
|
||||
|
||||
## Summary Comparison
|
||||
|
||||
| Workers | Avg | Total | Mem Used | Mem Avail |
|
||||
|---------|-----|-------|----------|-----------|
|
||||
| 1 | 6.8s | 176s | 3.2GB | 4.1GB |
|
||||
| 2 | 7.2s | 187s | 3.1GB | 4.1GB |
|
||||
| 3 | 6.8s | 176s | 3.9GB | 3.3GB |
|
||||
|
||||
**Success Rate:** 80.8% (21/26) - same for all configs
|
||||
|
||||
**Note:** For sequential tests, 1 worker ≈ 3 workers speed!
|
||||
Multiple workers only help with parallel requests.
|
||||
|
||||
## Detailed Results (1 Worker)
|
||||
|
||||
| # | Receipt | Time | Tier | Result | Notes |
|
||||
|---|---------|------|------|--------|-------|
|
||||
| 01 | abonament kineterra | 6.8s | T1 | ✓ | 97% |
|
||||
| 02 | benzina 14 august | 6.0s | T1 | ✓ | 83% |
|
||||
| 03 | benzina 27 octombrie | 5.9s | T1 | ✓ | 83% |
|
||||
| 04 | igiena 11 octombrie | 7.7s | T1 | ✓ | 97% |
|
||||
| 05 | igiena 14 dec five-holding | 11.5s | T1+T2 | ✗ | TOTAL ±1 |
|
||||
| 06 | rechizite 12 dec pictus | 5.9s | T1 | ✓ | 97% |
|
||||
| 07 | benzina 10 mai 2025 | 5.1s | T1 | ✓ | 83% |
|
||||
| 08 | brick consumabil 604 50% | 4.8s | T1 | ✓ | 97% |
|
||||
| 09 | benzina 13 septembrie | 4.9s | T1 | ✓ | 83% |
|
||||
| 10 | brick consumabile 604 | 5.3s | T1 | ✓ | 97% |
|
||||
| 11 | benzina 20 dec | 5.8s | T1 | ✓ | 79% |
|
||||
| 12 | bon fiscal Dedeman | 5.7s | T1 | ✓ | 90% |
|
||||
| 13 | factura Dedeman | 6.8s | T1 | ✓ | 97% |
|
||||
| 14 | benzina 13 iulie | 5.7s | T1 | ✓ | 95% |
|
||||
| 15 | best print stampila | 4.5s | T1 | ✓ | 94% |
|
||||
| 16 | electrobering telecomanda | 4.8s | T1 | ✓ | 97% |
|
||||
| 17 | brick igiena 8 oct | 11.9s | T1+T2 | ✗ | TOTAL/CUI |
|
||||
| 18 | gama ink refill toner | 5.9s | T1 | ✓ | 94% |
|
||||
| 19 | kineterra fizioterapie | 4.6s | T1 | ✓ | 97% |
|
||||
| 20 | brick igiena 1 sept | 12.5s | T1+T2 | ✗ | ALL None |
|
||||
| 21 | kineterra abonament | 5.6s | T1 | ✓ | 97% |
|
||||
| 22 | brick igiena electrice | 15.9s | T1+T2 | ✗ | DATE None |
|
||||
| 23 | electrobering igiena | 4.4s | T1 | ✓ | 97% |
|
||||
| 24 | Lidl papetarie 604 | 5.8s | T1 | ✓ | 87% |
|
||||
| 25 | brick igiena 604 | 6.8s | T1 | ✗ | DATE ±1 |
|
||||
| 26 | unlimited duplicat | 4.8s | T1 | ✓ | 86% |
|
||||
|
||||
## Time Comparison by Receipt
|
||||
|
||||
| # | Receipt | 1W | 2W | 3W |
|
||||
|---|---------|----|----|-----|
|
||||
| 01 | abonament kineterra | 6.8s | 6.7s | 5.8s |
|
||||
| 02 | benzina 14 august | 6.0s | 5.5s | 5.8s |
|
||||
| 03 | benzina 27 octombrie | 5.9s | 5.9s | 5.7s |
|
||||
| 04 | igiena 11 octombrie | 7.7s | 8.9s | 7.4s |
|
||||
| 05 | igiena 14 dec (FAIL) | 11.5s | 12.3s | 11.9s |
|
||||
| 06 | rechizite pictus | 5.9s | 5.9s | 5.7s |
|
||||
| 07 | benzina 10 mai | 5.1s | 6.0s | 5.8s |
|
||||
| 08 | brick 50% | 4.8s | 5.9s | 5.5s |
|
||||
| 09 | benzina 13 sept | 4.9s | 5.9s | 5.3s |
|
||||
| 10 | brick consumabile | 5.3s | 5.7s | 5.7s |
|
||||
| 11 | benzina 20 dec | 5.8s | 5.4s | 5.8s |
|
||||
| 12 | bon Dedeman | 5.7s | 5.9s | 5.8s |
|
||||
| 13 | factura Dedeman | 6.8s | 6.9s | 6.8s |
|
||||
| 14 | benzina 13 iulie | 5.7s | 6.1s | 5.4s |
|
||||
| 15 | best print | 4.5s | 5.8s | 4.8s |
|
||||
| 16 | electrobering | 4.8s | 4.2s | 4.7s |
|
||||
| 17 | brick 8 oct (FAIL) | 11.9s | 13.1s | 12.0s |
|
||||
| 18 | gama ink | 5.9s | 5.9s | 4.7s |
|
||||
| 19 | kineterra fizioterapie | 4.6s | 5.9s | 4.8s |
|
||||
| 20 | brick 1 sept (FAIL) | 12.5s | 13.2s | 13.1s |
|
||||
| 21 | kineterra abonament | 5.6s | 4.9s | 4.8s |
|
||||
| 22 | brick electrice (FAIL) | 15.9s | 17.0s | 15.5s |
|
||||
| 23 | electrobering igiena | 4.4s | 5.4s | 5.0s |
|
||||
| 24 | Lidl papetarie | 5.8s | 6.9s | 5.8s |
|
||||
| 25 | brick 604 (FAIL) | 6.8s | 6.5s | 6.9s |
|
||||
| 26 | unlimited duplicat | 4.8s | 5.8s | 5.0s |
|
||||
|---|---------|----|----|-----|
|
||||
| **AVG** | | **6.8s** | **7.2s** | **6.8s** |
|
||||
| **TOTAL** | | **176s** | **187s** | **176s** |
|
||||
|
||||
## Tier Analysis
|
||||
|
||||
- **T1 only (early exit):** 21 receipts (~5-6s)
|
||||
- **T1+T2 (full):** 5 receipts (~12-16s)
|
||||
|
||||
## Failures (5)
|
||||
|
||||
| Receipt | Issue | Fixable |
|
||||
|---------|-------|---------|
|
||||
| igiena 14 dec | TOTAL ±1 | No |
|
||||
| brick 8 oct | TOTAL/CUI | Maybe |
|
||||
| brick 1 sept | ALL None | No (bad doc) |
|
||||
| brick electrice | DATE None | Maybe |
|
||||
| brick 604 | DATE ±1 | No |
|
||||
|
||||
## Recommendation
|
||||
|
||||
```
|
||||
OCR_WORKERS=1 # Best for sequential, saves RAM
|
||||
OCR_WORKERS=2 # For parallel requests (production)
|
||||
OCR_MAX_TASKS_PER_CHILD=0 # No restart
|
||||
```
|
||||
|
||||
**For 8GB RAM:** Use 1-2 workers max
|
||||
Binary file not shown.
BIN
docs/data-entry/benzina 07 aug. 2024.pdf
Normal file
BIN
docs/data-entry/benzina 07 aug. 2024.pdf
Normal file
Binary file not shown.
BIN
docs/data-entry/benzina 10 mai 2025.pdf
Normal file
BIN
docs/data-entry/benzina 10 mai 2025.pdf
Normal file
Binary file not shown.
BIN
docs/data-entry/benzina 13 iulie.pdf
Normal file
BIN
docs/data-entry/benzina 13 iulie.pdf
Normal file
Binary file not shown.
BIN
docs/data-entry/benzina 13 septembrie .pdf
Normal file
BIN
docs/data-entry/benzina 13 septembrie .pdf
Normal file
Binary file not shown.
BIN
docs/data-entry/benzina 20 dec.pdf
Normal file
BIN
docs/data-entry/benzina 20 dec.pdf
Normal file
Binary file not shown.
BIN
docs/data-entry/best print stampila .pdf
Normal file
BIN
docs/data-entry/best print stampila .pdf
Normal file
Binary file not shown.
BIN
docs/data-entry/bon fiscal Dedeman - efactura.pdf
Normal file
BIN
docs/data-entry/bon fiscal Dedeman - efactura.pdf
Normal file
Binary file not shown.
BIN
docs/data-entry/brick consumabil 604 50% deductibil 22 dec.pdf
Normal file
BIN
docs/data-entry/brick consumabil 604 50% deductibil 22 dec.pdf
Normal file
Binary file not shown.
BIN
docs/data-entry/brick consumabile 604 22 dec.pdf
Normal file
BIN
docs/data-entry/brick consumabile 604 22 dec.pdf
Normal file
Binary file not shown.
6740
docs/data-entry/brick igiena 1 sept.pdf
Normal file
6740
docs/data-entry/brick igiena 1 sept.pdf
Normal file
File diff suppressed because it is too large
Load Diff
2552
docs/data-entry/brick igiena 604.pdf
Normal file
2552
docs/data-entry/brick igiena 604.pdf
Normal file
File diff suppressed because it is too large
Load Diff
2292
docs/data-entry/brick igiena 8 octombrie 98.95 lei card.pdf
Normal file
2292
docs/data-entry/brick igiena 8 octombrie 98.95 lei card.pdf
Normal file
File diff suppressed because it is too large
Load Diff
2610
docs/data-entry/brick igiena, electrice consumabile 604.pdf
Normal file
2610
docs/data-entry/brick igiena, electrice consumabile 604.pdf
Normal file
File diff suppressed because it is too large
Load Diff
BIN
docs/data-entry/electrobering igiena iulie 604.pdf
Normal file
BIN
docs/data-entry/electrobering igiena iulie 604.pdf
Normal file
Binary file not shown.
BIN
docs/data-entry/electrobering telecomanda.pdf
Normal file
BIN
docs/data-entry/electrobering telecomanda.pdf
Normal file
Binary file not shown.
1370
docs/data-entry/factura 70005116259 20.09.2025 Dedeman.pdf
Normal file
1370
docs/data-entry/factura 70005116259 20.09.2025 Dedeman.pdf
Normal file
File diff suppressed because it is too large
Load Diff
Binary file not shown.
BIN
docs/data-entry/kineterra abonament terapie august 2024.pdf
Normal file
BIN
docs/data-entry/kineterra abonament terapie august 2024.pdf
Normal file
Binary file not shown.
BIN
docs/data-entry/kineterra fizioterapie 9 sept.pdf
Normal file
BIN
docs/data-entry/kineterra fizioterapie 9 sept.pdf
Normal file
Binary file not shown.
BIN
docs/data-entry/stepout market carti tva 5%.pdf
Normal file
BIN
docs/data-entry/stepout market carti tva 5%.pdf
Normal file
Binary file not shown.
BIN
docs/data-entry/unlimited duplicat chei 23 mai.pdf
Normal file
BIN
docs/data-entry/unlimited duplicat chei 23 mai.pdf
Normal file
Binary file not shown.
@@ -109,15 +109,7 @@ async def get_current_user_from_request(request: Request) -> CurrentUser:
|
||||
Raises:
|
||||
HTTPException: Dacă utilizatorul nu este autentificat
|
||||
"""
|
||||
print(f"[DEPENDENCY DEBUG] get_current_user_from_request called")
|
||||
print(f"[DEPENDENCY DEBUG] request.state attributes: {dir(request.state)}")
|
||||
print(f"[DEPENDENCY DEBUG] has is_authenticated: {hasattr(request.state, 'is_authenticated')}")
|
||||
print(f"[DEPENDENCY DEBUG] is_authenticated value: {getattr(request.state, 'is_authenticated', 'NOT_SET')}")
|
||||
print(f"[DEPENDENCY DEBUG] has user: {hasattr(request.state, 'user')}")
|
||||
print(f"[DEPENDENCY DEBUG] user value: {getattr(request.state, 'user', 'NOT_SET')}")
|
||||
|
||||
if not hasattr(request.state, 'is_authenticated') or not request.state.is_authenticated:
|
||||
print(f"[DEPENDENCY DEBUG] Returning 401: Authentication required")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required",
|
||||
@@ -125,14 +117,12 @@ async def get_current_user_from_request(request: Request) -> CurrentUser:
|
||||
)
|
||||
|
||||
if not hasattr(request.state, 'user') or not request.state.user:
|
||||
print(f"[DEPENDENCY DEBUG] Returning 401: User not found in request")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found in request",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
print(f"[DEPENDENCY DEBUG] Returning user: {request.state.user}")
|
||||
|
||||
return request.state.user
|
||||
|
||||
|
||||
|
||||
@@ -246,7 +246,6 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
|
||||
Returns:
|
||||
Response-ul HTTP
|
||||
"""
|
||||
print(f"[ORIGINAL MIDDLEWARE] dispatch called for path: {request.url.path}")
|
||||
start_time = time.time()
|
||||
path = request.url.path
|
||||
|
||||
@@ -268,9 +267,7 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
|
||||
return response
|
||||
|
||||
# Extrage token-ul
|
||||
print(f"[MIDDLEWARE DEBUG] Extracting token for path: {path}")
|
||||
token = self._extract_token_from_header(request)
|
||||
print(f"[MIDDLEWARE DEBUG] Extracted token: {token[:30] if token else 'None'}...")
|
||||
|
||||
if not token:
|
||||
# Nu există token - pentru endpoint-urile protejate returnează 401
|
||||
@@ -289,9 +286,7 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
|
||||
)
|
||||
|
||||
# Validează token-ul
|
||||
print(f"[MIDDLEWARE DEBUG] Validating token: {token[:30]}...")
|
||||
token_data = jwt_handler.verify_token(token)
|
||||
print(f"[MIDDLEWARE DEBUG] Token validation result: {token_data}")
|
||||
|
||||
if not token_data:
|
||||
# Token invalid
|
||||
|
||||
@@ -19,7 +19,8 @@ export const menuSections = [
|
||||
title: 'Sistem',
|
||||
items: [
|
||||
{ to: '/reports/telegram', icon: 'pi pi-telegram', label: 'Telegram Bot' },
|
||||
{ to: '/reports/cache-stats', icon: 'pi pi-chart-bar', label: 'Statistici Cache' }
|
||||
{ to: '/reports/cache-stats', icon: 'pi pi-chart-bar', label: 'Statistici Cache' },
|
||||
{ to: '/data-entry/ocr-metrics', icon: 'pi pi-eye', label: 'Statistici OCR' }
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
@@ -263,6 +263,12 @@ const formatDate = (dateStr) => {
|
||||
|
||||
const getEngineClass = (engine) => {
|
||||
if (!engine) return ''
|
||||
// docTR engines
|
||||
if (engine === 'doctr-light') return 'doctr-fast'
|
||||
if (engine === 'doctr-medium') return 'doctr'
|
||||
if (engine === 'doctr-adaptive') return 'doctr-adaptive'
|
||||
if (engine.includes('doctr')) return 'doctr'
|
||||
// PaddleOCR engines
|
||||
if (engine === 'paddle-light') return 'fast'
|
||||
if (engine === 'paddle-adaptive') return 'adaptive'
|
||||
if (engine === 'adaptive-full') return 'full'
|
||||
@@ -273,13 +279,23 @@ const getEngineClass = (engine) => {
|
||||
|
||||
const getEngineIcon = (engine) => {
|
||||
if (!engine) return 'pi pi-cog'
|
||||
if (engine === 'paddle-light') return 'pi pi-bolt' // Fast/lightning
|
||||
if (engine === 'adaptive-full') return 'pi pi-cog' // Full pipeline
|
||||
// docTR - use bolt for fast modes
|
||||
if (engine === 'doctr-light') return 'pi pi-bolt'
|
||||
if (engine.includes('doctr')) return 'pi pi-bolt'
|
||||
// PaddleOCR
|
||||
if (engine === 'paddle-light') return 'pi pi-bolt'
|
||||
if (engine === 'adaptive-full') return 'pi pi-cog'
|
||||
return 'pi pi-cog'
|
||||
}
|
||||
|
||||
const getEngineLabel = (engine) => {
|
||||
if (!engine) return ''
|
||||
// docTR engines
|
||||
if (engine === 'doctr-light') return 'docTR Fast'
|
||||
if (engine === 'doctr-medium') return 'docTR Medium'
|
||||
if (engine === 'doctr-adaptive') return 'docTR Adaptive'
|
||||
if (engine.includes('doctr')) return 'docTR'
|
||||
// PaddleOCR engines
|
||||
if (engine === 'paddle-light') return 'Fast Mode (PaddleOCR)'
|
||||
if (engine === 'paddle-adaptive') return 'Adaptive (Paddle dual)'
|
||||
if (engine === 'adaptive-full') return 'Full Pipeline'
|
||||
@@ -615,6 +631,22 @@ const formatProcessingTime = (ms) => {
|
||||
color: #92400e;
|
||||
}
|
||||
|
||||
/* docTR engine styles */
|
||||
.ocr-engine-badge.doctr {
|
||||
background: #ede9fe;
|
||||
color: #5b21b6;
|
||||
}
|
||||
|
||||
.ocr-engine-badge.doctr-fast {
|
||||
background: #d1fae5;
|
||||
color: #047857;
|
||||
}
|
||||
|
||||
.ocr-engine-badge.doctr-adaptive {
|
||||
background: #e0e7ff;
|
||||
color: #3730a3;
|
||||
}
|
||||
|
||||
.ocr-message-badge {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
|
||||
@@ -60,7 +60,14 @@
|
||||
optionValue="value"
|
||||
placeholder="Motor OCR"
|
||||
class="engine-selector dropdown-borderless"
|
||||
/>
|
||||
>
|
||||
<template #option="{ option }">
|
||||
<div class="engine-option">
|
||||
<span class="engine-label">{{ option.label }}</span>
|
||||
<span class="engine-desc">{{ option.desc }}</span>
|
||||
</div>
|
||||
</template>
|
||||
</Dropdown>
|
||||
<Button
|
||||
label="Proceseaza OCR"
|
||||
icon="pi pi-cog"
|
||||
@@ -77,9 +84,10 @@
|
||||
</template>
|
||||
|
||||
<script setup>
|
||||
import { ref, computed } from 'vue'
|
||||
import { ref, computed, onMounted, watch } from 'vue'
|
||||
import Dropdown from 'primevue/dropdown'
|
||||
import api from '@data-entry/services/api'
|
||||
import { useOCRSettingsStore } from '@data-entry/stores/ocrSettingsStore'
|
||||
|
||||
const emit = defineEmits(['ocr-result', 'file-selected', 'error'])
|
||||
|
||||
@@ -89,20 +97,73 @@ const isDragging = ref(false)
|
||||
const processing = ref(false)
|
||||
const error = ref(null)
|
||||
|
||||
// OCR Engine selection
|
||||
// OCR Settings Store - manages user preferences
|
||||
const ocrStore = useOCRSettingsStore()
|
||||
|
||||
// OCR Engine selection (synced with store)
|
||||
const selectedEngine = ref('auto')
|
||||
const engineOptions = [
|
||||
{ label: 'Auto (Recomandat)', value: 'auto' },
|
||||
{ label: 'PaddleOCR', value: 'paddleocr' },
|
||||
{ label: 'Tesseract', value: 'tesseract' }
|
||||
]
|
||||
|
||||
// Engine config - labels and descriptions for dropdown
|
||||
const engineConfig = {
|
||||
'auto': {
|
||||
label: 'Auto',
|
||||
desc: 'docTR→Paddle→Tess · General'
|
||||
},
|
||||
'doctr': {
|
||||
label: 'docTR',
|
||||
desc: 'Rapid, bună acuratețe'
|
||||
},
|
||||
'paddleocr': {
|
||||
label: 'PaddleOCR',
|
||||
desc: 'Cea mai bună calitate'
|
||||
},
|
||||
'tesseract': {
|
||||
label: 'Tesseract',
|
||||
desc: 'Cel mai rapid, calitate redusă'
|
||||
},
|
||||
'hybrid': {
|
||||
label: 'Hybrid',
|
||||
desc: 'docTR+Tess paralel · Recomandat'
|
||||
},
|
||||
'hybrid-quality': {
|
||||
label: 'Hybrid Calitate',
|
||||
desc: 'Paddle→docTR→Tess · Acuratețe max'
|
||||
},
|
||||
}
|
||||
|
||||
// Compute engine options from store's available engines
|
||||
const engineOptions = computed(() => {
|
||||
return ocrStore.availableEngines.map(engine => ({
|
||||
label: engineConfig[engine]?.label || engine,
|
||||
desc: engineConfig[engine]?.desc || '',
|
||||
value: engine
|
||||
}))
|
||||
})
|
||||
|
||||
// Load user's preferred engine on mount
|
||||
onMounted(async () => {
|
||||
await ocrStore.loadPreference()
|
||||
selectedEngine.value = ocrStore.preferredEngine
|
||||
console.log('[OCRUploadZone] Loaded user preference:', selectedEngine.value)
|
||||
})
|
||||
|
||||
// Save preference when user changes engine
|
||||
watch(selectedEngine, async (newEngine, oldEngine) => {
|
||||
if (oldEngine && newEngine !== oldEngine && ocrStore.initialized) {
|
||||
try {
|
||||
await ocrStore.setPreference(newEngine)
|
||||
console.log('[OCRUploadZone] Saved user preference:', newEngine)
|
||||
} catch (err) {
|
||||
console.error('[OCRUploadZone] Failed to save preference:', err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Job queue state
|
||||
const jobId = ref(null)
|
||||
const queuePosition = ref(null)
|
||||
const estimatedWait = ref(null)
|
||||
const jobStatus = ref(null)
|
||||
let pollInterval = null
|
||||
|
||||
// Dynamic processing messages
|
||||
const processingMessage = computed(() => {
|
||||
@@ -223,26 +284,36 @@ const processOCR = async () => {
|
||||
}
|
||||
|
||||
const pollJobStatus = async (id) => {
|
||||
const maxAttempts = 120 // 2 minutes max (120 * 1s)
|
||||
let attempts = 0
|
||||
const LONG_POLL_TIMEOUT = 30 // seconds
|
||||
const MAX_TOTAL_TIME = 120 // 2 minutes max
|
||||
const startTime = Date.now()
|
||||
|
||||
const poll = async () => {
|
||||
try {
|
||||
const response = await api.get(`/ocr/jobs/${id}`)
|
||||
const job = response.data
|
||||
// Check if exceeded max total time
|
||||
const elapsed = (Date.now() - startTime) / 1000
|
||||
if (elapsed >= MAX_TOTAL_TIME) {
|
||||
processing.value = false
|
||||
error.value = 'Timeout - procesarea a durat prea mult'
|
||||
emit('error', error.value)
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
// Long-poll with 30s server timeout, 35s axios timeout
|
||||
const response = await api.get(`/ocr/jobs/${id}/wait`, {
|
||||
params: { timeout: LONG_POLL_TIMEOUT },
|
||||
timeout: (LONG_POLL_TIMEOUT + 5) * 1000
|
||||
})
|
||||
|
||||
const job = response.data
|
||||
jobStatus.value = job.status
|
||||
queuePosition.value = job.queue_position
|
||||
estimatedWait.value = job.estimated_wait_seconds
|
||||
|
||||
console.log('📊 OCR Poll:', { status: job.status, position: job.queue_position })
|
||||
console.log('📊 OCR Long-Poll:', { status: job.status, position: job.queue_position })
|
||||
|
||||
if (job.status === 'completed') {
|
||||
// Success! Emit result
|
||||
clearInterval(pollInterval)
|
||||
pollInterval = null
|
||||
processing.value = false
|
||||
|
||||
if (job.result) {
|
||||
console.log('✅ OCR Complete:', job.result)
|
||||
emit('ocr-result', {
|
||||
@@ -257,47 +328,36 @@ const pollJobStatus = async (id) => {
|
||||
}
|
||||
|
||||
if (job.status === 'failed') {
|
||||
// Failed
|
||||
clearInterval(pollInterval)
|
||||
pollInterval = null
|
||||
processing.value = false
|
||||
|
||||
error.value = job.error || 'OCR processing failed'
|
||||
emit('error', error.value)
|
||||
return
|
||||
}
|
||||
|
||||
// Still pending/processing - continue polling
|
||||
attempts++
|
||||
if (attempts >= maxAttempts) {
|
||||
clearInterval(pollInterval)
|
||||
pollInterval = null
|
||||
processing.value = false
|
||||
error.value = 'Timeout - procesarea a durat prea mult'
|
||||
emit('error', error.value)
|
||||
// Still pending/processing - long-poll again
|
||||
if (processing.value) {
|
||||
await poll()
|
||||
}
|
||||
|
||||
} catch (err) {
|
||||
console.error('🔴 Poll Error:', err.message)
|
||||
attempts++
|
||||
// Don't stop on poll errors - network might be flaky
|
||||
if (attempts >= maxAttempts) {
|
||||
clearInterval(pollInterval)
|
||||
pollInterval = null
|
||||
processing.value = false
|
||||
error.value = 'Eroare la verificarea starii job-ului'
|
||||
emit('error', error.value)
|
||||
// Handle timeout (normal for long-poll)
|
||||
if (err.code === 'ECONNABORTED' || err.message?.includes('timeout')) {
|
||||
console.log('⏱️ Long-poll timeout, retrying...')
|
||||
if (processing.value) {
|
||||
await poll()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Real error
|
||||
console.error('🔴 Poll Error:', err.message)
|
||||
processing.value = false
|
||||
error.value = 'Eroare la verificarea starii job-ului'
|
||||
emit('error', error.value)
|
||||
}
|
||||
}
|
||||
|
||||
// Initial poll immediately
|
||||
await poll()
|
||||
|
||||
// Continue polling every 1 second if still processing
|
||||
if (processing.value) {
|
||||
pollInterval = setInterval(poll, 1000)
|
||||
}
|
||||
}
|
||||
|
||||
const formatFileSize = (bytes) => {
|
||||
@@ -313,10 +373,7 @@ const reset = () => {
|
||||
queuePosition.value = null
|
||||
estimatedWait.value = null
|
||||
jobStatus.value = null
|
||||
if (pollInterval) {
|
||||
clearInterval(pollInterval)
|
||||
pollInterval = null
|
||||
}
|
||||
processing.value = false // Stop any ongoing long-poll
|
||||
if (fileInput.value) {
|
||||
fileInput.value.value = ''
|
||||
}
|
||||
@@ -415,7 +472,7 @@ defineExpose({ reset, processOCR })
|
||||
|
||||
/* Engine selector dropdown */
|
||||
.engine-selector {
|
||||
min-width: 150px;
|
||||
min-width: 180px;
|
||||
}
|
||||
|
||||
.engine-selector:deep(.p-dropdown-label) {
|
||||
@@ -428,6 +485,25 @@ defineExpose({ reset, processOCR })
|
||||
width: 2rem !important;
|
||||
}
|
||||
|
||||
/* Engine dropdown option with description */
|
||||
.engine-option {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 2px;
|
||||
padding: 4px 0;
|
||||
}
|
||||
|
||||
.engine-label {
|
||||
font-weight: 500;
|
||||
font-size: 0.875rem;
|
||||
color: #1e293b;
|
||||
}
|
||||
|
||||
.engine-desc {
|
||||
font-size: 0.75rem;
|
||||
color: #64748b;
|
||||
}
|
||||
|
||||
/* Processing state */
|
||||
.processing-state {
|
||||
display: flex;
|
||||
|
||||
173
src/modules/data-entry/stores/ocrSettingsStore.js
Normal file
173
src/modules/data-entry/stores/ocrSettingsStore.js
Normal file
@@ -0,0 +1,173 @@
|
||||
/**
|
||||
* OCR Settings Store
|
||||
*
|
||||
* Manages user's OCR engine preference and metrics.
|
||||
* - Auto-loads user's preferred engine on mount
|
||||
* - Saves preference to backend on change
|
||||
* - Provides OCR metrics for dashboard
|
||||
*/
|
||||
|
||||
import { defineStore } from 'pinia'
|
||||
import { ref, computed } from 'vue'
|
||||
import api from '@data-entry/services/api'
|
||||
|
||||
export const useOCRSettingsStore = defineStore('ocrSettings', () => {
|
||||
// State
|
||||
const preferredEngine = ref('doctr_plus')
|
||||
// Available engines
|
||||
// NOTE: This default list is overwritten by loadPreference() from backend
|
||||
// Backend filters engines based on OCR_ENABLE_PADDLEOCR and OCR_ENABLE_TESSERACT
|
||||
const availableEngines = ref([
|
||||
'tesseract',
|
||||
'doctr',
|
||||
'doctr_plus', // Recommended: 2-tier sequential with early exit
|
||||
'paddleocr',
|
||||
])
|
||||
const loading = ref(false)
|
||||
const error = ref(null)
|
||||
const initialized = ref(false)
|
||||
|
||||
// Metrics state
|
||||
const metrics = ref({
|
||||
summary: [],
|
||||
stats: null,
|
||||
history: [],
|
||||
historyTotal: 0,
|
||||
})
|
||||
const metricsLoading = ref(false)
|
||||
|
||||
// Computed
|
||||
const isLoading = computed(() => loading.value)
|
||||
const hasError = computed(() => !!error.value)
|
||||
|
||||
// Actions
|
||||
async function loadPreference() {
|
||||
if (initialized.value) return
|
||||
|
||||
loading.value = true
|
||||
error.value = null
|
||||
|
||||
try {
|
||||
const response = await api.get('/settings/ocr-preference')
|
||||
preferredEngine.value = response.data.preferred_engine
|
||||
availableEngines.value = response.data.available_engines
|
||||
initialized.value = true
|
||||
console.log('[OCRSettings] Loaded preference:', preferredEngine.value)
|
||||
} catch (err) {
|
||||
console.error('[OCRSettings] Failed to load preference:', err)
|
||||
error.value = err.message
|
||||
// Use defaults on error
|
||||
preferredEngine.value = 'doctr_plus'
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function setPreference(engine) {
|
||||
loading.value = true
|
||||
error.value = null
|
||||
|
||||
try {
|
||||
const response = await api.post('/settings/ocr-preference', {
|
||||
preferred_engine: engine
|
||||
})
|
||||
preferredEngine.value = response.data.preferred_engine
|
||||
console.log('[OCRSettings] Saved preference:', preferredEngine.value)
|
||||
} catch (err) {
|
||||
console.error('[OCRSettings] Failed to save preference:', err)
|
||||
error.value = err.message
|
||||
throw err
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function loadMetricsSummary(days = 30) {
|
||||
metricsLoading.value = true
|
||||
|
||||
try {
|
||||
const response = await api.get('/metrics/ocr/summary', { params: { days } })
|
||||
metrics.value.summary = response.data
|
||||
console.log('[OCRSettings] Loaded metrics summary:', metrics.value.summary.length, 'engines')
|
||||
} catch (err) {
|
||||
console.error('[OCRSettings] Failed to load metrics summary:', err)
|
||||
} finally {
|
||||
metricsLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function loadMetricsStats(days = 30) {
|
||||
try {
|
||||
const response = await api.get('/metrics/ocr/stats', { params: { days } })
|
||||
metrics.value.stats = response.data
|
||||
console.log('[OCRSettings] Loaded metrics stats:', metrics.value.stats)
|
||||
} catch (err) {
|
||||
console.error('[OCRSettings] Failed to load metrics stats:', err)
|
||||
}
|
||||
}
|
||||
|
||||
async function loadMetricsHistory(limit = 50, offset = 0) {
|
||||
try {
|
||||
const response = await api.get('/metrics/ocr/history', { params: { limit, offset } })
|
||||
metrics.value.history = response.data.items
|
||||
metrics.value.historyTotal = response.data.total
|
||||
console.log('[OCRSettings] Loaded metrics history:', metrics.value.history.length, 'items')
|
||||
} catch (err) {
|
||||
console.error('[OCRSettings] Failed to load metrics history:', err)
|
||||
}
|
||||
}
|
||||
|
||||
async function loadAllMetrics(days = 30) {
|
||||
metricsLoading.value = true
|
||||
try {
|
||||
await Promise.all([
|
||||
loadMetricsSummary(days),
|
||||
loadMetricsStats(days),
|
||||
loadMetricsHistory(20),
|
||||
])
|
||||
} finally {
|
||||
metricsLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// Reset state
|
||||
function $reset() {
|
||||
preferredEngine.value = 'doctr_plus'
|
||||
availableEngines.value = [
|
||||
'tesseract', 'doctr', 'doctr_plus', 'paddleocr',
|
||||
]
|
||||
loading.value = false
|
||||
error.value = null
|
||||
initialized.value = false
|
||||
metrics.value = {
|
||||
summary: [],
|
||||
stats: null,
|
||||
history: [],
|
||||
historyTotal: 0,
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
// State
|
||||
preferredEngine,
|
||||
availableEngines,
|
||||
loading,
|
||||
error,
|
||||
initialized,
|
||||
metrics,
|
||||
metricsLoading,
|
||||
|
||||
// Computed
|
||||
isLoading,
|
||||
hasError,
|
||||
|
||||
// Actions
|
||||
loadPreference,
|
||||
setPreference,
|
||||
loadMetricsSummary,
|
||||
loadMetricsStats,
|
||||
loadMetricsHistory,
|
||||
loadAllMetrics,
|
||||
$reset,
|
||||
}
|
||||
})
|
||||
@@ -399,7 +399,9 @@ export const useReceiptsStore = defineStore('receipts', {
|
||||
this.partners.push({
|
||||
id: response.data.id,
|
||||
name: response.data.name,
|
||||
code: response.data.fiscal_code,
|
||||
fiscal_code: response.data.fiscal_code,
|
||||
address: response.data.address,
|
||||
source: 'local',
|
||||
})
|
||||
return response.data
|
||||
} catch (error) {
|
||||
@@ -407,6 +409,20 @@ export const useReceiptsStore = defineStore('receipts', {
|
||||
}
|
||||
},
|
||||
|
||||
async syncSuppliers() {
|
||||
try {
|
||||
// Use apiClient directly - nomenclature endpoints at /api/nomenclature
|
||||
const response = await apiClient.post('/nomenclature/sync/suppliers')
|
||||
console.log('[receiptsStore] Synced suppliers:', response.data)
|
||||
// Refresh partners list after sync
|
||||
await this.fetchPartners()
|
||||
return response.data
|
||||
} catch (error) {
|
||||
console.error('[receiptsStore] Supplier sync failed:', error)
|
||||
throw error
|
||||
}
|
||||
},
|
||||
|
||||
// ============ Stats ============
|
||||
|
||||
async fetchStats() {
|
||||
|
||||
1322
src/modules/data-entry/views/OCRMetricsView.vue
Normal file
1322
src/modules/data-entry/views/OCRMetricsView.vue
Normal file
File diff suppressed because it is too large
Load Diff
@@ -244,7 +244,20 @@
|
||||
<div class="form-group">
|
||||
<div class="form-row">
|
||||
<div class="form-field flex-2">
|
||||
<label>Furnizor</label>
|
||||
<div class="label-with-action">
|
||||
<label>Furnizor</label>
|
||||
<Button
|
||||
v-if="!isReadOnly"
|
||||
icon="pi pi-sync"
|
||||
size="small"
|
||||
text
|
||||
rounded
|
||||
:loading="syncingSuppliers"
|
||||
@click="resyncSuppliers"
|
||||
v-tooltip.top="'Re-sincronizeaza furnizorii din Oracle'"
|
||||
class="sync-btn"
|
||||
/>
|
||||
</div>
|
||||
<AutoComplete
|
||||
v-model="form.partner_name"
|
||||
:suggestions="filteredPartners"
|
||||
@@ -265,10 +278,22 @@
|
||||
<div class="form-field flex-1">
|
||||
<label>CUI</label>
|
||||
<InputText v-model="form.cui" placeholder="RO12345678" :disabled="isReadOnly" />
|
||||
<small v-if="supplierWarning.show" class="p-text-warning supplier-warning">
|
||||
<i class="pi pi-exclamation-triangle"></i>
|
||||
Negasit
|
||||
</small>
|
||||
<div v-if="supplierWarning.show" class="supplier-warning-box">
|
||||
<small class="p-text-warning">
|
||||
<i class="pi pi-exclamation-triangle"></i>
|
||||
Negasit - se va crea automat la salvare
|
||||
</small>
|
||||
<Button
|
||||
v-if="!isReadOnly"
|
||||
label="Creaza acum"
|
||||
icon="pi pi-plus"
|
||||
size="small"
|
||||
severity="warning"
|
||||
text
|
||||
@click="createLocalSupplierFromWarning"
|
||||
class="supplier-create-btn"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<!-- Adresa colapsata -->
|
||||
@@ -406,14 +431,13 @@
|
||||
<div class="form-group form-group-last">
|
||||
<div class="form-row">
|
||||
<div class="form-field flex-1">
|
||||
<label>Tip Cheltuiala *</label>
|
||||
<label>Tip Cheltuiala</label>
|
||||
<Dropdown
|
||||
v-model="form.expense_type_code"
|
||||
:options="expenseTypes"
|
||||
optionLabel="name"
|
||||
optionValue="code"
|
||||
placeholder="Selecteaza tip cheltuiala"
|
||||
required
|
||||
:disabled="isReadOnly"
|
||||
class="dropdown-borderless"
|
||||
/>
|
||||
@@ -678,7 +702,7 @@
|
||||
</template>
|
||||
|
||||
<script setup>
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { ref, computed, onMounted, watch } from 'vue'
|
||||
import { useRoute, useRouter } from 'vue-router'
|
||||
import { useToast } from 'primevue/usetoast'
|
||||
import { useReceiptsStore } from '@data-entry/stores/receiptsStore'
|
||||
@@ -771,6 +795,7 @@ const missingClientWarning = ref(false)
|
||||
// AutoComplete support
|
||||
const filteredPartners = ref([])
|
||||
const supplierSource = ref(null) // 'local', 'synced', or null
|
||||
const syncingSuppliers = ref(false)
|
||||
|
||||
const partners = computed(() => store.partners)
|
||||
const expenseTypes = computed(() => store.expenseTypes)
|
||||
@@ -812,7 +837,6 @@ const missingRequiredFields = computed(() => {
|
||||
const missing = []
|
||||
if (!validationState.value.hasAmount) missing.push('Suma')
|
||||
if (!validationState.value.hasDate) missing.push('Data')
|
||||
if (!validationState.value.hasExpenseType) missing.push('Tip cheltuiala')
|
||||
if (!validationState.value.hasAttachment) missing.push('Atasament')
|
||||
return missing
|
||||
})
|
||||
@@ -848,6 +872,11 @@ const searchPartners = (event) => {
|
||||
onMounted(async () => {
|
||||
await store.fetchAllNomenclatures()
|
||||
|
||||
// Sync suppliers from Oracle if list is empty (first use or no synced data)
|
||||
if (store.partners.length === 0) {
|
||||
await syncSuppliersIfNeeded()
|
||||
}
|
||||
|
||||
if (isEditMode.value || isViewMode.value) {
|
||||
await loadReceipt()
|
||||
} else {
|
||||
@@ -856,6 +885,76 @@ onMounted(async () => {
|
||||
}
|
||||
})
|
||||
|
||||
// Sync suppliers from Oracle if not already synced
|
||||
const syncSuppliersIfNeeded = async () => {
|
||||
try {
|
||||
toast.add({
|
||||
severity: 'info',
|
||||
summary: 'Sincronizare furnizori',
|
||||
detail: 'Se sincronizeaza furnizorii din Oracle...',
|
||||
life: 3000,
|
||||
})
|
||||
|
||||
const result = await store.syncSuppliers()
|
||||
|
||||
toast.add({
|
||||
severity: 'success',
|
||||
summary: 'Sincronizare completa',
|
||||
detail: `${result.synced || store.partners.length} furnizori sincronizati`,
|
||||
life: 3000,
|
||||
})
|
||||
} catch (error) {
|
||||
console.warn('[ReceiptCreateView] Supplier sync failed:', error)
|
||||
toast.add({
|
||||
severity: 'warn',
|
||||
summary: 'Sincronizare esuata',
|
||||
detail: 'Nu s-au putut sincroniza furnizorii. Puteti continua cu furnizori locali.',
|
||||
life: 5000,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Watch for company changes - sync suppliers in background
|
||||
watch(
|
||||
() => companyStore.selectedCompany,
|
||||
async (newCompany, oldCompany) => {
|
||||
// Only trigger if company actually changed (not on initial load)
|
||||
if (newCompany && oldCompany && newCompany.id_firma !== oldCompany.id_firma) {
|
||||
console.log('[ReceiptCreateView] Company changed, syncing suppliers in background...')
|
||||
// Background sync - don't await, don't block UI
|
||||
store.syncSuppliers().then(result => {
|
||||
console.log('[ReceiptCreateView] Background sync complete:', result)
|
||||
}).catch(error => {
|
||||
console.warn('[ReceiptCreateView] Background sync failed:', error)
|
||||
})
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
// Manual re-sync suppliers from Oracle (button click)
|
||||
const resyncSuppliers = async () => {
|
||||
syncingSuppliers.value = true
|
||||
try {
|
||||
const result = await store.syncSuppliers()
|
||||
toast.add({
|
||||
severity: 'success',
|
||||
summary: 'Sincronizare completa',
|
||||
detail: `${result.synced || 0} furnizori noi din Oracle`,
|
||||
life: 3000,
|
||||
})
|
||||
} catch (error) {
|
||||
console.warn('[ReceiptCreateView] Manual supplier sync failed:', error)
|
||||
toast.add({
|
||||
severity: 'error',
|
||||
summary: 'Sincronizare esuata',
|
||||
detail: error.message || 'Nu s-au putut sincroniza furnizorii',
|
||||
life: 5000,
|
||||
})
|
||||
} finally {
|
||||
syncingSuppliers.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const loadReceipt = async () => {
|
||||
try {
|
||||
receipt.value = await store.fetchReceiptById(receiptId.value)
|
||||
@@ -954,6 +1053,7 @@ const onOCRFileSelected = (file) => {
|
||||
}
|
||||
|
||||
const onOCRResult = (data) => {
|
||||
console.log('[OCR Result] Received data, suggested_payment_mode:', data.suggested_payment_mode, 'payment_methods:', data.payment_methods)
|
||||
ocrData.value = data
|
||||
toast.add({
|
||||
severity: 'success',
|
||||
@@ -1142,9 +1242,11 @@ const applyOCRData = async (data) => {
|
||||
}
|
||||
|
||||
// Auto-suggest payment_mode if OCR detected CARD
|
||||
console.log('[OCR Apply] suggested_payment_mode:', data.suggested_payment_mode, 'payment_methods:', data.payment_methods)
|
||||
if (data.suggested_payment_mode) {
|
||||
form.value.payment_mode = data.suggested_payment_mode
|
||||
paymentSetFromOCR.value = true // Show OCR indicator
|
||||
console.log('[OCR Apply] Set payment_mode to:', data.suggested_payment_mode)
|
||||
}
|
||||
|
||||
// AUTO-DETECT DIRECTION (PLATĂ/ÎNCASARE) based on CUI matching
|
||||
@@ -1367,6 +1469,36 @@ const cancelCreateSupplier = () => {
|
||||
pendingSupplierData.value = null
|
||||
}
|
||||
|
||||
// Create local supplier immediately from warning (inline button)
|
||||
const createLocalSupplierFromWarning = async () => {
|
||||
if (!form.value.cui) return
|
||||
|
||||
try {
|
||||
await store.createLocalSupplier({
|
||||
name: form.value.partner_name || supplierWarning.value.name || `Furnizor ${form.value.cui}`,
|
||||
fiscal_code: form.value.cui,
|
||||
address: form.value.vendor_address || null
|
||||
})
|
||||
|
||||
toast.add({
|
||||
severity: 'success',
|
||||
summary: 'Furnizor creat',
|
||||
detail: `${form.value.partner_name || form.value.cui} a fost adaugat`,
|
||||
life: 3000,
|
||||
})
|
||||
|
||||
supplierWarning.value = { show: false, cui: '', name: '' }
|
||||
supplierSource.value = 'local'
|
||||
} catch (error) {
|
||||
toast.add({
|
||||
severity: 'error',
|
||||
summary: 'Eroare',
|
||||
detail: error.message || 'Nu s-a putut crea furnizorul',
|
||||
life: 5000,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to calculate similarity between two CUI strings
|
||||
// Returns a value between 0 and 1 (1 = identical)
|
||||
const calculateCuiSimilarity = (cui1, cui2) => {
|
||||
@@ -1727,16 +1859,6 @@ const validateForm = () => {
|
||||
return false
|
||||
}
|
||||
|
||||
if (!form.value.expense_type_code) {
|
||||
toast.add({
|
||||
severity: 'warn',
|
||||
summary: 'Validare',
|
||||
detail: 'Tipul cheltuielii este obligatoriu',
|
||||
life: 3000,
|
||||
})
|
||||
return false
|
||||
}
|
||||
|
||||
// Payment mode is validated at submit time, not at draft save
|
||||
// (can save draft without payment mode, but submit requires it)
|
||||
|
||||
@@ -1749,6 +1871,31 @@ const saveReceipt = async () => {
|
||||
saving.value = true
|
||||
|
||||
try {
|
||||
// Auto-create local supplier if CUI is present but not found in database
|
||||
if (form.value.cui && supplierWarning.value.show) {
|
||||
try {
|
||||
await store.createLocalSupplier({
|
||||
name: form.value.partner_name || `Furnizor ${form.value.cui}`,
|
||||
fiscal_code: form.value.cui,
|
||||
address: form.value.vendor_address || null
|
||||
})
|
||||
|
||||
toast.add({
|
||||
severity: 'info',
|
||||
summary: 'Furnizor local creat',
|
||||
detail: `${form.value.partner_name || form.value.cui} adaugat automat`,
|
||||
life: 3000,
|
||||
})
|
||||
|
||||
// Clear warning since supplier is now created
|
||||
supplierWarning.value = { show: false, cui: '', name: '' }
|
||||
supplierSource.value = 'local'
|
||||
} catch (error) {
|
||||
console.warn('[saveReceipt] Failed to auto-create local supplier:', error)
|
||||
// Continue with save anyway - supplier creation is optional
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up payment_methods and tva_breakdown - convert null amounts to 0
|
||||
const cleanedPaymentMethods = form.value.payment_methods?.map(pm => ({
|
||||
...pm,
|
||||
@@ -1898,6 +2045,46 @@ const submitForReview = async () => {
|
||||
font-size: 1.1rem;
|
||||
}
|
||||
|
||||
/* Supplier warning box with inline create button */
|
||||
.supplier-warning-box {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
margin-top: 0.25rem;
|
||||
}
|
||||
|
||||
.supplier-warning-box small {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.25rem;
|
||||
}
|
||||
|
||||
.supplier-create-btn {
|
||||
padding: 0.25rem 0.5rem !important;
|
||||
font-size: 0.75rem !important;
|
||||
}
|
||||
|
||||
/* Label with action button (sync) */
|
||||
.label-with-action {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.25rem;
|
||||
}
|
||||
|
||||
.label-with-action label {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.sync-btn {
|
||||
width: 1.5rem !important;
|
||||
height: 1.5rem !important;
|
||||
padding: 0 !important;
|
||||
}
|
||||
|
||||
.sync-btn .pi {
|
||||
font-size: 0.75rem;
|
||||
}
|
||||
|
||||
/* 2-column layout */
|
||||
.receipt-form-layout {
|
||||
display: grid;
|
||||
|
||||
@@ -78,6 +78,12 @@ const routes = [
|
||||
name: 'ReceiptEdit',
|
||||
component: () => import('@data-entry/views/receipts/ReceiptCreateView.vue'),
|
||||
meta: { requiresAuth: true, title: 'Editare Bon - ROA2WEB' }
|
||||
},
|
||||
{
|
||||
path: 'ocr-metrics',
|
||||
name: 'OCRMetrics',
|
||||
component: () => import('@data-entry/views/OCRMetricsView.vue'),
|
||||
meta: { requiresAuth: true, title: 'Metrici OCR - ROA2WEB' }
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@@ -59,6 +59,9 @@ start_backend() {
|
||||
print_info "Starting unified backend on port $PORT..."
|
||||
cd "$BACKEND_DIR"
|
||||
|
||||
# Clear old log file
|
||||
> "$LOG_FILE"
|
||||
|
||||
# Activate virtual environment and start uvicorn
|
||||
nohup bash -c "source venv/bin/activate && uvicorn main:app --host 0.0.0.0 --port $PORT --reload" > "$LOG_FILE" 2>&1 &
|
||||
|
||||
|
||||
@@ -151,6 +151,10 @@ else
|
||||
source .env
|
||||
set +a
|
||||
|
||||
# Clear old log files
|
||||
> /tmp/unified_backend_prod.log
|
||||
> /tmp/unified_frontend_prod.log
|
||||
|
||||
# Start backend with auto-restart on crash (OOM protection)
|
||||
print_message "Starting unified backend with auto-restart (includes Reports, Data Entry, and Telegram bot)..."
|
||||
nohup ./run-with-restart.sh 8000 /tmp/unified_backend_prod.log > /dev/null 2>&1 &
|
||||
|
||||
@@ -106,6 +106,11 @@ sleep 2
|
||||
# Step 2: Start Unified Backend (8000)
|
||||
print_message "2. Starting Unified Backend on port 8000..."
|
||||
|
||||
# Clear old log files (always, before any port checks)
|
||||
> /tmp/unified_backend_test.log
|
||||
> /tmp/unified_frontend_test.log
|
||||
print_message "Log files cleared"
|
||||
|
||||
if check_port 8000; then
|
||||
print_warning "Port 8000 already in use - Unified Backend may be running"
|
||||
else
|
||||
|
||||
629
tests/ocr-validation/expected_receipts.json
Normal file
629
tests/ocr-validation/expected_receipts.json
Normal file
@@ -0,0 +1,629 @@
|
||||
{
|
||||
"receipts": [
|
||||
{
|
||||
"id": "receipt_01",
|
||||
"filename": "abonament kineterra.pdf",
|
||||
"furnizor": "KINETERRA CONCEPT SRL",
|
||||
"cui_furnizor": "31180432",
|
||||
"client": null,
|
||||
"cui_client": null,
|
||||
"total": 1900.0,
|
||||
"tva_details": [],
|
||||
"total_tva": 0.0,
|
||||
"card": 1900.0,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2025-11-10",
|
||||
"numar_bon": "0039",
|
||||
"notes": "Neplatitor TVA - abonament terapie"
|
||||
},
|
||||
{
|
||||
"id": "receipt_02",
|
||||
"filename": "benzina 14 august.pdf",
|
||||
"furnizor": "OMV PETROM MARKETING S.R.L.",
|
||||
"cui_furnizor": "RO11201891",
|
||||
"client": "ROMFAST SRL",
|
||||
"cui_client": "RO1879855",
|
||||
"total": 318.16,
|
||||
"tva_details": [
|
||||
{
|
||||
"rate": 21,
|
||||
"value": 55.22
|
||||
}
|
||||
],
|
||||
"total_tva": 55.22,
|
||||
"card": 318.16,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2025-08-14",
|
||||
"numar_bon": "2850-00075",
|
||||
"notes": "Benzina standard 95"
|
||||
},
|
||||
{
|
||||
"id": "receipt_03",
|
||||
"filename": "benzina 27 octombrie .pdf",
|
||||
"furnizor": "OMV PETROM MARKETING S.R.L.",
|
||||
"cui_furnizor": "RO11201891",
|
||||
"client": "ROMFAST SRL",
|
||||
"cui_client": "RO1879855",
|
||||
"total": 285.66,
|
||||
"tva_details": [
|
||||
{
|
||||
"rate": 21,
|
||||
"value": 49.58
|
||||
}
|
||||
],
|
||||
"total_tva": 49.58,
|
||||
"card": 285.66,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2025-10-27",
|
||||
"numar_bon": "2857-00217",
|
||||
"notes": "Benzina standard 95"
|
||||
},
|
||||
{
|
||||
"id": "receipt_04",
|
||||
"filename": "igiena 11 octombrie .pdf",
|
||||
"furnizor": "FIVE-HOLDING S.A.",
|
||||
"cui_furnizor": "RO10562600",
|
||||
"client": null,
|
||||
"cui_client": "RO1879855",
|
||||
"total": 186.16,
|
||||
"tva_details": [
|
||||
{
|
||||
"rate": 21,
|
||||
"value": 32.31
|
||||
}
|
||||
],
|
||||
"total_tva": 32.31,
|
||||
"card": 186.16,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2025-10-11",
|
||||
"numar_bon": "0171",
|
||||
"notes": "BRICK - produse igiena"
|
||||
},
|
||||
{
|
||||
"id": "receipt_05",
|
||||
"filename": "igiena 14 decembrie five-holding.pdf",
|
||||
"furnizor": "FIVE-HOLDING S.A.",
|
||||
"cui_furnizor": "RO10562600",
|
||||
"client": null,
|
||||
"cui_client": "RO1879855",
|
||||
"total": 85.99,
|
||||
"tva_details": [
|
||||
{
|
||||
"rate": 21,
|
||||
"value": 14.92
|
||||
}
|
||||
],
|
||||
"total_tva": 14.92,
|
||||
"card": 85.99,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2025-12-14",
|
||||
"numar_bon": "0126",
|
||||
"notes": "BRICK - produse igiena"
|
||||
},
|
||||
{
|
||||
"id": "receipt_06",
|
||||
"filename": "rechizite 12 decembrie pictus.pdf",
|
||||
"furnizor": "PICTUS VELUM SRL",
|
||||
"cui_furnizor": "RO39634534",
|
||||
"client": null,
|
||||
"cui_client": "RO1879855",
|
||||
"total": 11.9,
|
||||
"tva_details": [
|
||||
{
|
||||
"rate": 21,
|
||||
"value": 2.07
|
||||
}
|
||||
],
|
||||
"total_tva": 2.07,
|
||||
"card": 11.9,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2025-12-12",
|
||||
"numar_bon": "0060",
|
||||
"notes": "Rechizite - creioane, radiera"
|
||||
},
|
||||
{
|
||||
"id": "receipt_07",
|
||||
"filename": "benzina 10 mai 2025.pdf",
|
||||
"furnizor": "OMV PETROM MARKETING S.R.L.",
|
||||
"cui_furnizor": "RO11201891",
|
||||
"client": "ROMFAST SRL",
|
||||
"cui_client": "RO1879855",
|
||||
"total": 231.83,
|
||||
"tva_details": [
|
||||
{
|
||||
"rate": 19,
|
||||
"value": 37.01
|
||||
}
|
||||
],
|
||||
"total_tva": 37.01,
|
||||
"card": 231.83,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2025-05-10",
|
||||
"numar_bon": "2863-00239",
|
||||
"notes": "Benzina standard 95 - Petrom Baia"
|
||||
},
|
||||
{
|
||||
"id": "receipt_08",
|
||||
"filename": "brick consumabil 604 50% deductibil 22 dec.pdf",
|
||||
"furnizor": "FIVE-HOLDING S.A.",
|
||||
"cui_furnizor": "RO10562600",
|
||||
"client": null,
|
||||
"cui_client": "RO1879855",
|
||||
"total": 21.18,
|
||||
"tva_details": [
|
||||
{
|
||||
"rate": 21,
|
||||
"value": 3.68
|
||||
}
|
||||
],
|
||||
"total_tva": 3.68,
|
||||
"card": 21.18,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2025-12-22",
|
||||
"numar_bon": "0159",
|
||||
"notes": "BRICK - lichid spalare parbriz"
|
||||
},
|
||||
{
|
||||
"id": "receipt_09",
|
||||
"filename": "benzina 13 septembrie .pdf",
|
||||
"furnizor": "OMV PETROM MARKETING S.R.L.",
|
||||
"cui_furnizor": "RO11201891",
|
||||
"client": "ROMFAST SRL",
|
||||
"cui_client": "RO1879855",
|
||||
"total": 275.91,
|
||||
"tva_details": [
|
||||
{
|
||||
"rate": 21,
|
||||
"value": 47.89
|
||||
}
|
||||
],
|
||||
"total_tva": 47.89,
|
||||
"card": 275.91,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2025-09-13",
|
||||
"numar_bon": "2813-00298",
|
||||
"notes": "Benzina standard 95"
|
||||
},
|
||||
{
|
||||
"id": "receipt_10",
|
||||
"filename": "brick consumabile 604 22 dec.pdf",
|
||||
"furnizor": "FIVE-HOLDING S.A.",
|
||||
"cui_furnizor": "RO10562600",
|
||||
"client": null,
|
||||
"cui_client": "RO1879855",
|
||||
"total": 5.27,
|
||||
"tva_details": [
|
||||
{
|
||||
"rate": 21,
|
||||
"value": 0.91
|
||||
}
|
||||
],
|
||||
"total_tva": 0.91,
|
||||
"card": 5.27,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2025-12-22",
|
||||
"numar_bon": "0175",
|
||||
"notes": "BRICK - suport polita"
|
||||
},
|
||||
{
|
||||
"id": "receipt_11",
|
||||
"filename": "benzina 20 dec.pdf",
|
||||
"furnizor": "OMV PETROM MARKETING S.R.L.",
|
||||
"cui_furnizor": "RO11201891",
|
||||
"client": "ROMFAST SRL",
|
||||
"cui_client": "RO1879855",
|
||||
"total": 282.79,
|
||||
"tva_details": [
|
||||
{
|
||||
"rate": 21,
|
||||
"value": 49.08
|
||||
}
|
||||
],
|
||||
"total_tva": 49.08,
|
||||
"card": 282.79,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2025-12-20",
|
||||
"numar_bon": "2820-00306",
|
||||
"notes": "Benzina standard 95 - Petrom 26 Constanta"
|
||||
},
|
||||
{
|
||||
"id": "receipt_12",
|
||||
"filename": "bon fiscal Dedeman - efactura.pdf",
|
||||
"furnizor": "DEDEMAN SRL",
|
||||
"cui_furnizor": "RO2816464",
|
||||
"client": "ROMFAST SRL CONSTANTA",
|
||||
"cui_client": "1879855",
|
||||
"total": 5.83,
|
||||
"tva_details": [
|
||||
{
|
||||
"rate": 19,
|
||||
"value": 0.93
|
||||
}
|
||||
],
|
||||
"total_tva": 0.93,
|
||||
"card": 5.83,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2024-12-18",
|
||||
"numar_bon": "0066",
|
||||
"notes": "Dedeman - garnituri, reductie"
|
||||
},
|
||||
{
|
||||
"id": "receipt_13",
|
||||
"filename": "factura 70005116259 20.09.2025 Dedeman.pdf",
|
||||
"furnizor": "DEDEMAN SRL",
|
||||
"cui_furnizor": "RO2816464",
|
||||
"client": "ONCR BLEUMARIN CONSTANTA",
|
||||
"cui_client": "46598884",
|
||||
"total": 53.7,
|
||||
"tva_details": [
|
||||
{
|
||||
"rate": 21,
|
||||
"value": 9.32
|
||||
}
|
||||
],
|
||||
"total_tva": 9.32,
|
||||
"card": 53.7,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2025-09-20",
|
||||
"numar_bon": "0164",
|
||||
"notes": "Dedeman - folie delimitare, baterii"
|
||||
},
|
||||
{
|
||||
"id": "receipt_14",
|
||||
"filename": "benzina 13 iulie.pdf",
|
||||
"furnizor": "SOCAR PETROLEUM S.A.",
|
||||
"cui_furnizor": "RO12546600",
|
||||
"client": "ROMFAST SRL",
|
||||
"cui_client": "RO1879855",
|
||||
"total": 252.4,
|
||||
"tva_details": [
|
||||
{
|
||||
"rate": 19,
|
||||
"value": 40.3
|
||||
}
|
||||
],
|
||||
"total_tva": 40.3,
|
||||
"card": 252.4,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2025-07-13",
|
||||
"numar_bon": "2443-00129",
|
||||
"notes": "NANO 95 - Socar Adjud Vrancea"
|
||||
},
|
||||
{
|
||||
"id": "receipt_15",
|
||||
"filename": "best print stampila .pdf",
|
||||
"furnizor": "BEST PRINT TRADE ACTIV SRL",
|
||||
"cui_furnizor": "45417955",
|
||||
"client": null,
|
||||
"cui_client": "RO1879855",
|
||||
"total": 100.0,
|
||||
"tva_details": [],
|
||||
"total_tva": 0.0,
|
||||
"card": 100.0,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2025-07-15",
|
||||
"numar_bon": "0018",
|
||||
"notes": "Neplatitor TVA - stampila"
|
||||
},
|
||||
{
|
||||
"id": "receipt_16",
|
||||
"filename": "electrobering telecomanda.pdf",
|
||||
"furnizor": "ELECTROBERING S.R.L.",
|
||||
"cui_furnizor": "RO2744937",
|
||||
"client": null,
|
||||
"cui_client": "1879855",
|
||||
"total": 35.0,
|
||||
"tva_details": [
|
||||
{
|
||||
"rate": 19,
|
||||
"value": 5.59
|
||||
}
|
||||
],
|
||||
"total_tva": 5.59,
|
||||
"card": 35.0,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2025-07-17",
|
||||
"numar_bon": "0073",
|
||||
"notes": "Telecomanda A.C."
|
||||
},
|
||||
{
|
||||
"id": "receipt_17a",
|
||||
"filename": "stepout market carti tva 5%.pdf",
|
||||
"page": 1,
|
||||
"furnizor": "STEPOUT MARKET SRL",
|
||||
"cui_furnizor": "RO35532655",
|
||||
"client": null,
|
||||
"cui_client": "1879855",
|
||||
"total": 156.0,
|
||||
"tva_details": [
|
||||
{
|
||||
"rate": 5,
|
||||
"value": 7.43
|
||||
}
|
||||
],
|
||||
"total_tva": 7.43,
|
||||
"card": 156.0,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2024-10-22",
|
||||
"numar_bon": "000009",
|
||||
"notes": "Carti - TVA 5%",
|
||||
"pages": 2
|
||||
},
|
||||
{
|
||||
"id": "receipt_17b",
|
||||
"filename": "stepout market carti tva 5%.pdf",
|
||||
"page": 2,
|
||||
"furnizor": "STEPOUT MARKET SRL",
|
||||
"cui_furnizor": "RO35532655",
|
||||
"client": null,
|
||||
"cui_client": "1879855",
|
||||
"total": 78.0,
|
||||
"tva_details": [
|
||||
{
|
||||
"rate": 5,
|
||||
"value": 3.71
|
||||
}
|
||||
],
|
||||
"total_tva": 3.71,
|
||||
"card": 78.0,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2024-10-24",
|
||||
"numar_bon": "000024",
|
||||
"notes": "Carti - TVA 5%",
|
||||
"pages": 2
|
||||
},
|
||||
{
|
||||
"id": "receipt_18",
|
||||
"filename": "brick igiena 8 octombrie 98.95 lei card.pdf",
|
||||
"furnizor": "FIVE-HOLDING S.A.",
|
||||
"cui_furnizor": "RO10562600",
|
||||
"client": null,
|
||||
"cui_client": "RO1879855",
|
||||
"total": 98.95,
|
||||
"tva_details": [
|
||||
{
|
||||
"rate": 19,
|
||||
"value": 15.8
|
||||
}
|
||||
],
|
||||
"total_tva": 15.8,
|
||||
"card": 98.95,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2024-10-08",
|
||||
"numar_bon": "0299",
|
||||
"notes": "BRICK - produse igiena"
|
||||
},
|
||||
{
|
||||
"id": "receipt_19",
|
||||
"filename": "gama ink refill toner imprimanta 17 sept 2024.pdf",
|
||||
"furnizor": "GAMA INK SERVICE SRL",
|
||||
"cui_furnizor": "RO17741882",
|
||||
"client": null,
|
||||
"cui_client": "RO1879855",
|
||||
"total": 45.0,
|
||||
"tva_details": [
|
||||
{
|
||||
"rate": 19,
|
||||
"value": 7.18
|
||||
}
|
||||
],
|
||||
"total_tva": 7.18,
|
||||
"card": 45.0,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2024-09-17",
|
||||
"numar_bon": "0041",
|
||||
"notes": "Incarcare toner HP"
|
||||
},
|
||||
{
|
||||
"id": "receipt_20",
|
||||
"filename": "kineterra fizioterapie 9 sept.pdf",
|
||||
"furnizor": "KINETERRA CONCEPT SRL",
|
||||
"cui_furnizor": "31180432",
|
||||
"client": null,
|
||||
"cui_client": null,
|
||||
"total": 650.0,
|
||||
"tva_details": [],
|
||||
"total_tva": 0.0,
|
||||
"card": 650.0,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2024-09-09",
|
||||
"numar_bon": "0024",
|
||||
"notes": "Neplatitor TVA - diatermie tecar"
|
||||
},
|
||||
{
|
||||
"id": "receipt_21",
|
||||
"filename": "brick igiena 1 sept.pdf",
|
||||
"furnizor": "FIVE-HOLDING S.A.",
|
||||
"cui_furnizor": "RO10562600",
|
||||
"client": null,
|
||||
"cui_client": "RO1879855",
|
||||
"total": 82.86,
|
||||
"tva_details": [
|
||||
{
|
||||
"rate": 19,
|
||||
"value": 13.23
|
||||
}
|
||||
],
|
||||
"total_tva": 13.23,
|
||||
"card": 82.86,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2024-09-01",
|
||||
"numar_bon": "0047",
|
||||
"notes": "BRICK - produse igiena, instalatii"
|
||||
},
|
||||
{
|
||||
"id": "receipt_22",
|
||||
"filename": "kineterra abonament terapie august 2024.pdf",
|
||||
"furnizor": "KINETERRA CONCEPT SRL",
|
||||
"cui_furnizor": "31180432",
|
||||
"client": null,
|
||||
"cui_client": null,
|
||||
"total": 750.0,
|
||||
"tva_details": [],
|
||||
"total_tva": 0.0,
|
||||
"card": 750.0,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2024-08-27",
|
||||
"numar_bon": "0029",
|
||||
"notes": "Neplatitor TVA - terapie acvatica"
|
||||
},
|
||||
{
|
||||
"id": "receipt_23a",
|
||||
"filename": "benzina 07 aug. 2024.pdf",
|
||||
"page": 1,
|
||||
"furnizor": "OMV PETROM MARKETING S.R.L.",
|
||||
"cui_furnizor": "RO11201891",
|
||||
"client": "ROMFAST SRL",
|
||||
"cui_client": "RO1879855",
|
||||
"total": 263.28,
|
||||
"tva_details": [
|
||||
{
|
||||
"rate": 19,
|
||||
"value": 42.04
|
||||
}
|
||||
],
|
||||
"total_tva": 42.04,
|
||||
"card": 263.28,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2024-08-01",
|
||||
"numar_bon": "2134-00220",
|
||||
"notes": "Benzina standard 95 - Petrom 1 Huedin",
|
||||
"pages": 2
|
||||
},
|
||||
{
|
||||
"id": "receipt_23b",
|
||||
"filename": "benzina 07 aug. 2024.pdf",
|
||||
"page": 2,
|
||||
"furnizor": "OMV PETROM MARKETING S.R.L.",
|
||||
"cui_furnizor": "RO11201891",
|
||||
"client": "ROMFAST SRL",
|
||||
"cui_client": "RO1879855",
|
||||
"total": 306.67,
|
||||
"tva_details": [
|
||||
{
|
||||
"rate": 19,
|
||||
"value": 48.96
|
||||
}
|
||||
],
|
||||
"total_tva": 48.96,
|
||||
"card": 306.67,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2024-08-02",
|
||||
"numar_bon": "2193-00699",
|
||||
"notes": "Benzina standard 95 - Petrom A2 KM66",
|
||||
"pages": 2
|
||||
},
|
||||
{
|
||||
"id": "receipt_24",
|
||||
"filename": "brick igiena, electrice consumabile 604.pdf",
|
||||
"furnizor": "FIVE-HOLDING S.A.",
|
||||
"cui_furnizor": "RO10562600",
|
||||
"client": null,
|
||||
"cui_client": "RO1879855",
|
||||
"total": 190.6,
|
||||
"tva_details": [
|
||||
{
|
||||
"rate": 19,
|
||||
"value": 30.43
|
||||
}
|
||||
],
|
||||
"total_tva": 30.43,
|
||||
"card": 190.6,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2024-08-05",
|
||||
"numar_bon": "0207",
|
||||
"notes": "BRICK - electrice, instalatii, igiena"
|
||||
},
|
||||
{
|
||||
"id": "receipt_25",
|
||||
"filename": "electrobering igiena iulie 604.pdf",
|
||||
"furnizor": "ELECTROBERING S.R.L.",
|
||||
"cui_furnizor": "RO2744937",
|
||||
"client": null,
|
||||
"cui_client": "RO1879855",
|
||||
"total": 62.0,
|
||||
"tva_details": [
|
||||
{
|
||||
"rate": 19,
|
||||
"value": 9.9
|
||||
}
|
||||
],
|
||||
"total_tva": 9.9,
|
||||
"card": 62.0,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2024-07-11",
|
||||
"numar_bon": "0059",
|
||||
"notes": "Filtru, spray detector"
|
||||
},
|
||||
{
|
||||
"id": "receipt_26",
|
||||
"filename": "Lidl papetarie 604 fara TVA. nu are cod fiscal.pdf",
|
||||
"furnizor": "LIDL DISCOUNT S.R.L.",
|
||||
"cui_furnizor": "RO22891860",
|
||||
"client": null,
|
||||
"cui_client": null,
|
||||
"total": 39.96,
|
||||
"tva_details": [
|
||||
{
|
||||
"rate": 19,
|
||||
"value": 6.38
|
||||
}
|
||||
],
|
||||
"total_tva": 6.38,
|
||||
"card": 39.96,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2024-07-01",
|
||||
"numar_bon": "00719",
|
||||
"notes": "Papetarie - agende, caiete. FARA CIF CLIENT!"
|
||||
},
|
||||
{
|
||||
"id": "receipt_27",
|
||||
"filename": "brick igiena 604.pdf",
|
||||
"furnizor": "FIVE-HOLDING S.A.",
|
||||
"cui_furnizor": "RO10562600",
|
||||
"client": null,
|
||||
"cui_client": "RO1879855",
|
||||
"total": 155.15,
|
||||
"tva_details": [
|
||||
{
|
||||
"rate": 19,
|
||||
"value": 24.77
|
||||
}
|
||||
],
|
||||
"total_tva": 24.77,
|
||||
"card": 155.15,
|
||||
"numerar": 0.0,
|
||||
"data_bon": "2024-06-28",
|
||||
"numar_bon": "0293",
|
||||
"notes": "BRICK - igiena, consumabile auto"
|
||||
},
|
||||
{
|
||||
"id": "receipt_28",
|
||||
"filename": "unlimited duplicat chei 23 mai.pdf",
|
||||
"furnizor": "UNLIMITED KEYS S.R.L.",
|
||||
"cui_furnizor": "RO18993187",
|
||||
"client": null,
|
||||
"cui_client": "1879855",
|
||||
"total": 80.0,
|
||||
"tva_details": [
|
||||
{
|
||||
"rate": 19,
|
||||
"value": 12.77
|
||||
}
|
||||
],
|
||||
"total_tva": 12.77,
|
||||
"card": 0.0,
|
||||
"numerar": 80.0,
|
||||
"data_bon": "2024-05-23",
|
||||
"numar_bon": "000004",
|
||||
"notes": "Duplicat cheie yala - NUMERAR"
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"total_receipts": 30,
|
||||
"total_files": 28,
|
||||
"extracted_by": "Claude - manual extraction",
|
||||
"extraction_date": "2026-01-01",
|
||||
"notes": "Some PDF files contain multiple receipts (pages)"
|
||||
}
|
||||
}
|
||||
127
tests/ocr-validation/get_raw_ocr_text.py
Normal file
127
tests/ocr-validation/get_raw_ocr_text.py
Normal file
@@ -0,0 +1,127 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Quick script to get raw OCR text for specific receipts.
|
||||
Usage: python get_raw_ocr_text.py <receipt_path>
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import requests
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
sys.path.insert(0, str(project_root / 'backend'))
|
||||
|
||||
from jose import jwt
|
||||
|
||||
API_BASE = "http://localhost:8000/api/data-entry"
|
||||
|
||||
def create_test_token() -> str:
|
||||
"""Create a test JWT token for API authentication."""
|
||||
secret_key = os.getenv('JWT_SECRET_KEY', 'generate_with_secrets_token_urlsafe_32')
|
||||
now = datetime.utcnow()
|
||||
expire = now + timedelta(hours=1)
|
||||
|
||||
payload = {
|
||||
"username": "ocr_test_user",
|
||||
"user_id": 999,
|
||||
"companies": ["TEST"],
|
||||
"permissions": ["read", "write"],
|
||||
"exp": expire,
|
||||
"iat": now,
|
||||
"type": "access"
|
||||
}
|
||||
|
||||
return jwt.encode(payload, secret_key, algorithm="HS256")
|
||||
|
||||
|
||||
def get_raw_ocr_text(file_path: str, token: str) -> dict:
|
||||
"""Submit file to OCR and get raw text."""
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
return {"error": f"File not found: {file_path}"}
|
||||
|
||||
# Submit OCR job
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Processing: {path.name}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
headers = {'Authorization': f'Bearer {token}'}
|
||||
|
||||
with open(path, 'rb') as f:
|
||||
files = {'file': (path.name, f, 'application/pdf')}
|
||||
|
||||
response = requests.post(f"{API_BASE}/ocr/extract?engine=doctr_plus", files=files, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
return {"error": f"Submit failed: {response.status_code} - {response.text}"}
|
||||
|
||||
result = response.json()
|
||||
job_id = result.get('job_id')
|
||||
print(f"Job ID: {job_id}")
|
||||
|
||||
# Poll for completion
|
||||
max_wait = 120
|
||||
start = time.time()
|
||||
|
||||
while time.time() - start < max_wait:
|
||||
status_response = requests.get(f"{API_BASE}/ocr/jobs/{job_id}", headers=headers)
|
||||
if status_response.status_code != 200:
|
||||
return {"error": f"Status check failed: {status_response.status_code}"}
|
||||
|
||||
status = status_response.json()
|
||||
job_status = status.get('status')
|
||||
|
||||
if job_status == 'completed':
|
||||
result = status.get('result', {})
|
||||
|
||||
# Print raw texts
|
||||
raw_texts = result.get('raw_texts', [])
|
||||
print(f"\n--- RAW OCR TEXT ({len(raw_texts)} passes) ---\n")
|
||||
|
||||
for i, raw_text in enumerate(raw_texts):
|
||||
print(f"\n=== Pass {i+1} ===")
|
||||
print(raw_text[:3000] if len(raw_text) > 3000 else raw_text)
|
||||
print(f"\n[Text length: {len(raw_text)} chars]")
|
||||
|
||||
# Print extracted fields
|
||||
print(f"\n--- EXTRACTED FIELDS ---")
|
||||
print(f"TOTAL: {result.get('amount')}")
|
||||
print(f"DATE: {result.get('receipt_date')}")
|
||||
print(f"CUI: {result.get('cui')}")
|
||||
print(f"TVA Total: {result.get('tva_total')}")
|
||||
print(f"TVA Entries: {result.get('tva_entries')}")
|
||||
print(f"Confidence: {result.get('overall_confidence')}")
|
||||
print(f"Engine: {result.get('ocr_engine')}")
|
||||
|
||||
return result
|
||||
|
||||
elif job_status == 'failed':
|
||||
return {"error": f"OCR failed: {status.get('error')}"}
|
||||
|
||||
print(f" Status: {job_status}, waiting...")
|
||||
time.sleep(2)
|
||||
|
||||
return {"error": "Timeout waiting for OCR"}
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Create test token
|
||||
token = create_test_token()
|
||||
print(f"Using JWT token for authentication")
|
||||
|
||||
if len(sys.argv) < 2:
|
||||
# Default: process the two receipts user wants to see
|
||||
receipts = [
|
||||
"/mnt/e/proiecte/ab-worktrees/doctr-ocr-metrics/docs/data-entry/brick igiena 1 sept.pdf",
|
||||
"/mnt/e/proiecte/ab-worktrees/doctr-ocr-metrics/docs/data-entry/brick igiena, electrice consumabile 604.pdf"
|
||||
]
|
||||
else:
|
||||
receipts = sys.argv[1:]
|
||||
|
||||
for receipt in receipts:
|
||||
result = get_raw_ocr_text(receipt, token)
|
||||
if "error" in result:
|
||||
print(f"ERROR: {result['error']}")
|
||||
593
tests/ocr-validation/ocr-direct-validation.py
Normal file
593
tests/ocr-validation/ocr-direct-validation.py
Normal file
@@ -0,0 +1,593 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
OCR Direct Validation Tests
|
||||
|
||||
This script validates the OCR extraction accuracy by:
|
||||
1. Generating a test JWT token
|
||||
2. Calling the OCR API endpoint with PDF receipts
|
||||
3. Comparing extracted data with expected values from expected_receipts.json
|
||||
|
||||
Run:
|
||||
python tests/ocr-validation/ocr-direct-validation.py
|
||||
python tests/ocr-validation/ocr-direct-validation.py --engine doctr_plus
|
||||
python tests/ocr-validation/ocr-direct-validation.py --receipt receipt_01
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
import time
|
||||
import requests
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, List, Any
|
||||
|
||||
# Add backend and project root to path
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
sys.path.insert(0, str(project_root / 'backend'))
|
||||
|
||||
# Import JWT handler to create test tokens
|
||||
from jose import jwt
|
||||
|
||||
|
||||
def create_test_token(secret_key: str) -> str:
|
||||
"""Create a test JWT token for API authentication."""
|
||||
now = datetime.utcnow()
|
||||
expire = now + timedelta(hours=1)
|
||||
|
||||
payload = {
|
||||
"username": "ocr_test_user",
|
||||
"user_id": 999,
|
||||
"companies": ["TEST"],
|
||||
"permissions": ["read", "write"],
|
||||
"exp": expire,
|
||||
"iat": now,
|
||||
"type": "access"
|
||||
}
|
||||
|
||||
return jwt.encode(payload, secret_key, algorithm="HS256")
|
||||
|
||||
|
||||
def normalize_cui(cui: Optional[str]) -> Optional[str]:
|
||||
"""Normalize CUI by removing RO prefix and spaces."""
|
||||
if not cui:
|
||||
return None
|
||||
return cui.upper().replace('RO', '').replace(' ', '')
|
||||
|
||||
|
||||
def normalize_date(date: Optional[str]) -> Optional[str]:
|
||||
"""Normalize date to YYYY-MM-DD format."""
|
||||
if not date:
|
||||
return None
|
||||
try:
|
||||
# Try parsing ISO format
|
||||
from datetime import datetime
|
||||
parsed = datetime.fromisoformat(date.replace('Z', '+00:00'))
|
||||
return parsed.strftime('%Y-%m-%d')
|
||||
except:
|
||||
return date
|
||||
|
||||
|
||||
def compare_with_tolerance(expected: float, actual, tolerance: float) -> bool:
|
||||
"""Compare numbers with tolerance."""
|
||||
if actual is None:
|
||||
return False
|
||||
# Handle string values from API
|
||||
try:
|
||||
actual_float = float(actual)
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
diff = abs(expected - actual_float)
|
||||
threshold = expected * tolerance
|
||||
return diff <= threshold or diff <= 0.01 # Within tolerance or 1 cent
|
||||
|
||||
|
||||
def submit_ocr_job(api_base: str, token: str, pdf_path: Path, engine: str) -> dict:
|
||||
"""Submit a PDF file for OCR processing and wait for result.
|
||||
|
||||
Returns dict with detailed timing information:
|
||||
- timing.submit_duration: time to submit job and get job_id
|
||||
- timing.poll_count: number of poll requests made
|
||||
- timing.poll_duration: total time spent polling
|
||||
- timing.wall_time: total elapsed time (submit + poll)
|
||||
- timing.api_reported_ms: processing time reported by API
|
||||
"""
|
||||
headers = {'Authorization': f'Bearer {token}'}
|
||||
|
||||
# Timing tracking
|
||||
timing = {
|
||||
'submit_duration': 0.0,
|
||||
'poll_count': 0,
|
||||
'poll_duration': 0.0,
|
||||
'wall_time': 0.0,
|
||||
'api_reported_ms': 0,
|
||||
}
|
||||
|
||||
wall_start = time.time()
|
||||
|
||||
# Submit job
|
||||
submit_start = time.time()
|
||||
with open(pdf_path, 'rb') as f:
|
||||
files = {'file': (pdf_path.name, f, 'application/pdf')}
|
||||
response = requests.post(
|
||||
f'{api_base}/api/data-entry/ocr/extract?engine={engine}',
|
||||
headers=headers,
|
||||
files=files,
|
||||
timeout=60
|
||||
)
|
||||
timing['submit_duration'] = time.time() - submit_start
|
||||
|
||||
if not response.ok:
|
||||
timing['wall_time'] = time.time() - wall_start
|
||||
return {'status': 'failed', 'error': f'Submit failed: {response.status_code} - {response.text}', 'timing': timing}
|
||||
|
||||
job = response.json()
|
||||
job_id = job.get('job_id')
|
||||
|
||||
if not job_id:
|
||||
timing['wall_time'] = time.time() - wall_start
|
||||
return {'status': 'failed', 'error': 'No job_id in response', 'timing': timing}
|
||||
|
||||
# Poll for result
|
||||
poll_start = time.time()
|
||||
max_wait = 120 # 2 minutes
|
||||
|
||||
while time.time() - wall_start < max_wait:
|
||||
timing['poll_count'] += 1
|
||||
poll_response = requests.get(
|
||||
f'{api_base}/api/data-entry/ocr/jobs/{job_id}/wait?timeout=30',
|
||||
headers=headers,
|
||||
timeout=35
|
||||
)
|
||||
|
||||
if not poll_response.ok:
|
||||
time.sleep(1)
|
||||
continue
|
||||
|
||||
job_status = poll_response.json()
|
||||
|
||||
if job_status.get('status') == 'completed':
|
||||
timing['poll_duration'] = time.time() - poll_start
|
||||
timing['wall_time'] = time.time() - wall_start
|
||||
# Detailed timing from API
|
||||
timing['queue_wait_ms'] = job_status.get('queue_wait_ms', 0) or 0
|
||||
timing['ocr_time_ms'] = job_status.get('ocr_time_ms', 0) or 0
|
||||
timing['processing_time_ms'] = job_status.get('processing_time_ms', 0) or 0
|
||||
return {
|
||||
'status': 'completed',
|
||||
'result': job_status.get('result', {}),
|
||||
'processing_time_ms': job_status.get('processing_time_ms', 0),
|
||||
'timing': timing
|
||||
}
|
||||
|
||||
if job_status.get('status') == 'failed':
|
||||
timing['poll_duration'] = time.time() - poll_start
|
||||
timing['wall_time'] = time.time() - wall_start
|
||||
return {'status': 'failed', 'error': job_status.get('error', 'Unknown error'), 'timing': timing}
|
||||
|
||||
# Still pending - show status but don't spam
|
||||
if timing['poll_count'] <= 3 or timing['poll_count'] % 5 == 0:
|
||||
print(f" Status: {job_status.get('status')}, position: {job_status.get('queue_position')}, polls: {timing['poll_count']}")
|
||||
|
||||
timing['poll_duration'] = time.time() - poll_start
|
||||
timing['wall_time'] = time.time() - wall_start
|
||||
return {'status': 'failed', 'error': 'Timeout waiting for OCR result', 'timing': timing}
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='OCR Direct Validation')
|
||||
parser.add_argument('--engine', default='doctr_plus',
|
||||
choices=['tesseract', 'doctr', 'doctr_plus', 'paddleocr'],
|
||||
help='OCR engine to use (doctr_plus recommended)')
|
||||
parser.add_argument('--receipt', help='Specific receipt ID to test (e.g., receipt_01)')
|
||||
parser.add_argument('--verbose', '-v', action='store_true', help='Verbose output')
|
||||
parser.add_argument('--api-base', default='http://localhost:8000', help='API base URL')
|
||||
parser.add_argument('--stop-on-issue', action='store_true',
|
||||
help='Stop at first receipt with wall_time > 7.5s or extraction errors')
|
||||
parser.add_argument('--include-multipage', action='store_true',
|
||||
help='Include multi-page PDFs (normally skipped)')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Paths
|
||||
script_dir = Path(__file__).parent
|
||||
expected_path = script_dir / 'expected_receipts.json'
|
||||
pdf_base_path = script_dir.parent.parent / 'docs' / 'data-entry'
|
||||
|
||||
# JWT secret from environment or default
|
||||
jwt_secret = os.getenv('JWT_SECRET_KEY', 'generate_with_secrets_token_urlsafe_32')
|
||||
|
||||
# Create test token
|
||||
token = create_test_token(jwt_secret)
|
||||
|
||||
# Load expected data
|
||||
print(f"\n{'='*60}")
|
||||
print("OCR API VALIDATION")
|
||||
print(f"{'='*60}")
|
||||
print(f"Engine: {args.engine}")
|
||||
print(f"API Base: {args.api_base}")
|
||||
print(f"Expected data: {expected_path}")
|
||||
print(f"PDF folder: {pdf_base_path}")
|
||||
|
||||
with open(expected_path) as f:
|
||||
expected_data = json.load(f)
|
||||
|
||||
# Filter receipts
|
||||
receipts_to_test = expected_data['receipts']
|
||||
|
||||
# Skip multi-page PDFs unless explicitly included
|
||||
if not args.include_multipage:
|
||||
original_count = len(receipts_to_test)
|
||||
receipts_to_test = [r for r in receipts_to_test if r.get('page') is None]
|
||||
skipped = original_count - len(receipts_to_test)
|
||||
if skipped > 0:
|
||||
print(f"Skipping {skipped} multi-page PDF entries (use --include-multipage to include)")
|
||||
|
||||
# Filter by specific receipt ID if requested
|
||||
if args.receipt:
|
||||
receipts_to_test = [r for r in receipts_to_test if r['id'] == args.receipt]
|
||||
if not receipts_to_test:
|
||||
print(f"\nError: Receipt ID '{args.receipt}' not found")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Receipts to test: {len(receipts_to_test)}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Results storage
|
||||
results: List[Dict[str, Any]] = []
|
||||
|
||||
# Test each receipt
|
||||
for expected in receipts_to_test:
|
||||
pdf_path = pdf_base_path / expected['filename']
|
||||
|
||||
if not pdf_path.exists():
|
||||
print(f"[SKIP] File not found: {expected['filename']}")
|
||||
continue
|
||||
|
||||
print(f"[TEST] Processing: {expected['filename']}")
|
||||
|
||||
try:
|
||||
# Submit OCR job via API
|
||||
start_time = datetime.now()
|
||||
ocr_result = submit_ocr_job(args.api_base, token, pdf_path, args.engine)
|
||||
# Handle processing_time which may be string or number
|
||||
raw_time = ocr_result.get('processing_time_ms')
|
||||
if raw_time is not None:
|
||||
processing_time = float(raw_time)
|
||||
else:
|
||||
processing_time = (datetime.now() - start_time).total_seconds() * 1000
|
||||
|
||||
if ocr_result.get('status') == 'failed':
|
||||
print(f" [ERROR] OCR failed: {ocr_result.get('error')}")
|
||||
results.append({
|
||||
'receipt_id': expected['id'],
|
||||
'filename': expected['filename'],
|
||||
'status': 'failed',
|
||||
'error': ocr_result.get('error'),
|
||||
})
|
||||
continue
|
||||
|
||||
# Get extracted values
|
||||
extracted = ocr_result.get('result', {})
|
||||
|
||||
# Safe float conversion helper
|
||||
def safe_float(value, default=0.0):
|
||||
if value is None:
|
||||
return default
|
||||
try:
|
||||
return float(value)
|
||||
except (ValueError, TypeError):
|
||||
return default
|
||||
|
||||
# Compare results
|
||||
comparison = {
|
||||
'receipt_id': expected['id'],
|
||||
'filename': expected['filename'],
|
||||
'status': 'completed',
|
||||
'total_expected': expected['total'],
|
||||
'total_extracted': safe_float(extracted.get('amount'), None),
|
||||
'total_match': False,
|
||||
'date_expected': expected['data_bon'],
|
||||
'date_extracted': extracted.get('receipt_date'),
|
||||
'date_match': False,
|
||||
'cui_expected': expected['cui_furnizor'],
|
||||
'cui_extracted': extracted.get('cui'),
|
||||
'cui_match': False,
|
||||
'tva_expected': expected['total_tva'],
|
||||
'tva_extracted': safe_float(extracted.get('tva_total'), None),
|
||||
'tva_match': False,
|
||||
'confidence': safe_float(extracted.get('overall_confidence'), 0),
|
||||
'processing_time_ms': processing_time,
|
||||
'ocr_engine': extracted.get('ocr_engine', args.engine),
|
||||
'errors': [],
|
||||
# NEW: Save full extraction for analysis
|
||||
'full_extraction': {
|
||||
'amount': extracted.get('amount'),
|
||||
'receipt_date': extracted.get('receipt_date'),
|
||||
'cui': extracted.get('cui'),
|
||||
'tva_total': extracted.get('tva_total'),
|
||||
'tva_entries': extracted.get('tva_entries', []),
|
||||
'supplier_name': extracted.get('supplier_name'),
|
||||
'receipt_number': extracted.get('receipt_number'),
|
||||
'payment_methods': extracted.get('payment_methods', []),
|
||||
'items_count': extracted.get('items_count'),
|
||||
'overall_confidence': extracted.get('overall_confidence'),
|
||||
'confidence_amount': extracted.get('confidence_amount'),
|
||||
'confidence_date': extracted.get('confidence_date'),
|
||||
'confidence_cui': extracted.get('confidence_cui'),
|
||||
},
|
||||
# NEW: Save raw OCR texts from each engine pass
|
||||
'raw_texts': extracted.get('raw_texts', []),
|
||||
}
|
||||
|
||||
# Compare TOTAL
|
||||
comparison['total_match'] = compare_with_tolerance(
|
||||
expected['total'],
|
||||
extracted.get('amount'),
|
||||
0.02 # 2% tolerance
|
||||
)
|
||||
if not comparison['total_match']:
|
||||
comparison['errors'].append(
|
||||
f"TOTAL: expected {expected['total']}, got {extracted.get('amount')}"
|
||||
)
|
||||
|
||||
# Compare DATE
|
||||
normalized_expected_date = normalize_date(expected['data_bon'])
|
||||
normalized_extracted_date = normalize_date(extracted.get('receipt_date'))
|
||||
comparison['date_match'] = normalized_expected_date == normalized_extracted_date
|
||||
if not comparison['date_match']:
|
||||
comparison['errors'].append(
|
||||
f"DATE: expected {normalized_expected_date}, got {normalized_extracted_date}"
|
||||
)
|
||||
|
||||
# Compare CUI
|
||||
normalized_expected_cui = normalize_cui(expected['cui_furnizor'])
|
||||
normalized_extracted_cui = normalize_cui(extracted.get('cui'))
|
||||
comparison['cui_match'] = normalized_expected_cui == normalized_extracted_cui
|
||||
if not comparison['cui_match']:
|
||||
comparison['errors'].append(
|
||||
f"CUI: expected {normalized_expected_cui}, got {normalized_extracted_cui}"
|
||||
)
|
||||
|
||||
# Compare TVA
|
||||
if expected['total_tva'] > 0:
|
||||
comparison['tva_match'] = compare_with_tolerance(
|
||||
expected['total_tva'],
|
||||
extracted.get('tva_total'),
|
||||
0.05 # 5% tolerance
|
||||
)
|
||||
if not comparison['tva_match']:
|
||||
comparison['errors'].append(
|
||||
f"TVA: expected {expected['total_tva']}, got {extracted.get('tva_total')}"
|
||||
)
|
||||
else:
|
||||
# No TVA expected (neplatitor TVA)
|
||||
tva_extracted = safe_float(extracted.get('tva_total'), None)
|
||||
comparison['tva_match'] = tva_extracted is None or tva_extracted == 0 or tva_extracted == 0.0
|
||||
|
||||
results.append(comparison)
|
||||
|
||||
# Get timing info from API (detailed breakdown)
|
||||
t = ocr_result.get('timing', {})
|
||||
wall_ms = t.get('wall_time', 0) * 1000
|
||||
queue_wait_ms = t.get('queue_wait_ms', 0)
|
||||
ocr_time_ms = t.get('ocr_time_ms', 0)
|
||||
processing_time_ms = t.get('processing_time_ms', 0)
|
||||
# Overhead = wall_time - processing_time (includes network, polling)
|
||||
overhead_ms = wall_ms - processing_time_ms if processing_time_ms else 0
|
||||
|
||||
# Print result
|
||||
status = 'PASS' if not comparison['errors'] else 'FAIL'
|
||||
print(f" [{status}] Total: {expected['total']} vs {extracted.get('amount')} ({comparison['total_match']})")
|
||||
print(f" Date: {expected['data_bon']} vs {extracted.get('receipt_date')} ({comparison['date_match']})")
|
||||
print(f" CUI: {expected['cui_furnizor']} vs {extracted.get('cui')} ({comparison['cui_match']})")
|
||||
print(f" TVA: {expected['total_tva']} vs {extracted.get('tva_total')} ({comparison['tva_match']})")
|
||||
print(f" Confidence: {comparison['confidence']*100:.1f}%")
|
||||
|
||||
# Print detailed timing breakdown
|
||||
print(f" TIMING: ocr={ocr_time_ms}ms, queue_wait={queue_wait_ms}ms, "
|
||||
f"job_total={processing_time_ms}ms, wall={wall_ms:.0f}ms")
|
||||
print(f" overhead={overhead_ms:.0f}ms (wall - job_total)")
|
||||
|
||||
if comparison['errors'] and args.verbose:
|
||||
for err in comparison['errors']:
|
||||
print(f" Error: {err}")
|
||||
|
||||
# Stop on issue if requested
|
||||
if args.stop_on_issue:
|
||||
has_errors = len(comparison['errors']) > 0
|
||||
# Use OCR time for threshold (actual processing, not queue wait)
|
||||
ocr_too_slow = ocr_time_ms > 10000 # 10s threshold for actual OCR
|
||||
|
||||
if has_errors or ocr_too_slow:
|
||||
print(f"\n{'='*60}")
|
||||
print(f"STOP: Issue detected on {expected['filename']}")
|
||||
print(f"{'='*60}")
|
||||
if ocr_too_slow:
|
||||
print(f" SLOW: ocr_time={ocr_time_ms}ms > 10000ms threshold")
|
||||
if has_errors:
|
||||
print(f" ERRORS: {comparison['errors']}")
|
||||
print(f"\n Full timing breakdown:")
|
||||
print(f" ocr_time_ms: {ocr_time_ms}ms (actual OCR engine time)")
|
||||
print(f" queue_wait_ms: {queue_wait_ms}ms (waiting in queue)")
|
||||
print(f" processing_time_ms: {processing_time_ms}ms (job total)")
|
||||
print(f" wall_time: {wall_ms:.0f}ms (client-side)")
|
||||
print(f" overhead: {overhead_ms:.0f}ms (network + polling)")
|
||||
print(f"\n Full extraction:")
|
||||
print(json.dumps(extracted, indent=4, default=str))
|
||||
sys.exit(1)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f" [ERROR] {str(e)}")
|
||||
if args.verbose:
|
||||
traceback.print_exc()
|
||||
results.append({
|
||||
'receipt_id': expected['id'],
|
||||
'filename': expected['filename'],
|
||||
'status': 'error',
|
||||
'error': str(e),
|
||||
})
|
||||
|
||||
# Calculate statistics
|
||||
completed_results = [r for r in results if r.get('status') == 'completed']
|
||||
total_tests = len(completed_results)
|
||||
|
||||
if total_tests == 0:
|
||||
print("\nNo tests completed successfully!")
|
||||
sys.exit(1)
|
||||
|
||||
perfect_matches = len([r for r in completed_results
|
||||
if r['total_match'] and r['date_match'] and r['cui_match'] and r['tva_match']])
|
||||
total_match_rate = len([r for r in completed_results if r['total_match']]) / total_tests
|
||||
date_match_rate = len([r for r in completed_results if r['date_match']]) / total_tests
|
||||
cui_match_rate = len([r for r in completed_results if r['cui_match']]) / total_tests
|
||||
tva_match_rate = len([r for r in completed_results if r['tva_match']]) / total_tests
|
||||
avg_confidence = sum(r['confidence'] for r in completed_results) / total_tests
|
||||
avg_processing_time = sum(r['processing_time_ms'] for r in completed_results) / total_tests
|
||||
|
||||
# Print summary
|
||||
print(f"\n{'='*60}")
|
||||
print("OCR VALIDATION SUMMARY")
|
||||
print(f"{'='*60}")
|
||||
print(f"Total Receipts Tested: {total_tests}")
|
||||
print(f"Perfect Matches: {perfect_matches} ({perfect_matches/total_tests*100:.1f}%)")
|
||||
print("---")
|
||||
print(f"Total Amount Match Rate: {total_match_rate*100:.1f}%")
|
||||
print(f"Date Match Rate: {date_match_rate*100:.1f}%")
|
||||
print(f"CUI Match Rate: {cui_match_rate*100:.1f}%")
|
||||
print(f"TVA Match Rate: {tva_match_rate*100:.1f}%")
|
||||
print("---")
|
||||
print(f"Average Confidence: {avg_confidence*100:.1f}%")
|
||||
print(f"Average Processing Time: {avg_processing_time:.0f}ms")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# Failed receipts
|
||||
failed_results = [r for r in completed_results if r.get('errors')]
|
||||
if failed_results:
|
||||
print(f"\nFAILED RECEIPTS ({len(failed_results)}):")
|
||||
for r in failed_results:
|
||||
print(f" - {r['filename']}: {'; '.join(r['errors'])}")
|
||||
|
||||
# Categorize problems for analysis
|
||||
problems_analysis = {
|
||||
'cui_issues': [],
|
||||
'tva_issues': [],
|
||||
'total_issues': [],
|
||||
'date_issues': [],
|
||||
'confidence_issues': [],
|
||||
}
|
||||
|
||||
for r in completed_results:
|
||||
# CUI issues
|
||||
if not r.get('cui_match'):
|
||||
cui_expected = normalize_cui(r.get('cui_expected'))
|
||||
cui_got = normalize_cui(r.get('cui_extracted'))
|
||||
issue_type = 'missing' if not cui_got else 'mismatch'
|
||||
|
||||
# Check if it's a digit substitution (same length, 1-2 chars different)
|
||||
if cui_expected and cui_got and len(cui_expected) == len(cui_got):
|
||||
diff_count = sum(1 for a, b in zip(cui_expected, cui_got) if a != b)
|
||||
if diff_count <= 2:
|
||||
issue_type = 'digit_substitution'
|
||||
|
||||
problems_analysis['cui_issues'].append({
|
||||
'file': r['filename'],
|
||||
'expected': r.get('cui_expected'),
|
||||
'got': r.get('cui_extracted'),
|
||||
'type': issue_type,
|
||||
'confidence': r.get('confidence', 0),
|
||||
})
|
||||
|
||||
# TVA issues
|
||||
if not r.get('tva_match'):
|
||||
tva_expected = r.get('tva_expected', 0)
|
||||
tva_got = r.get('tva_extracted')
|
||||
issue_type = 'missing' if tva_got is None else 'mismatch'
|
||||
|
||||
# Check for 5% rate (books)
|
||||
if tva_expected and tva_expected > 0:
|
||||
full_ext = r.get('full_extraction', {})
|
||||
total = full_ext.get('amount')
|
||||
if total and tva_expected:
|
||||
try:
|
||||
implied_rate = float(tva_expected) / float(total) * 100
|
||||
if 4 <= implied_rate <= 6:
|
||||
issue_type = 'low_rate_5pct'
|
||||
except:
|
||||
pass
|
||||
|
||||
problems_analysis['tva_issues'].append({
|
||||
'file': r['filename'],
|
||||
'expected': tva_expected,
|
||||
'got': tva_got,
|
||||
'type': issue_type,
|
||||
'tva_entries': r.get('full_extraction', {}).get('tva_entries', []),
|
||||
})
|
||||
|
||||
# TOTAL issues
|
||||
if not r.get('total_match'):
|
||||
problems_analysis['total_issues'].append({
|
||||
'file': r['filename'],
|
||||
'expected': r.get('total_expected'),
|
||||
'got': r.get('total_extracted'),
|
||||
'confidence': r.get('confidence', 0),
|
||||
'payment_methods': r.get('full_extraction', {}).get('payment_methods', []),
|
||||
})
|
||||
|
||||
# DATE issues
|
||||
if not r.get('date_match'):
|
||||
problems_analysis['date_issues'].append({
|
||||
'file': r['filename'],
|
||||
'expected': r.get('date_expected'),
|
||||
'got': r.get('date_extracted'),
|
||||
})
|
||||
|
||||
# Low confidence issues
|
||||
if r.get('confidence', 0) < 0.7:
|
||||
problems_analysis['confidence_issues'].append({
|
||||
'file': r['filename'],
|
||||
'confidence': r.get('confidence', 0),
|
||||
'errors': r.get('errors', []),
|
||||
})
|
||||
|
||||
# Save detailed report
|
||||
report = {
|
||||
'test_date': datetime.now().isoformat(),
|
||||
'engine': args.engine,
|
||||
'statistics': {
|
||||
'total_tests': total_tests,
|
||||
'perfect_matches': perfect_matches,
|
||||
'perfect_match_rate': perfect_matches / total_tests,
|
||||
'total_match_rate': total_match_rate,
|
||||
'date_match_rate': date_match_rate,
|
||||
'cui_match_rate': cui_match_rate,
|
||||
'tva_match_rate': tva_match_rate,
|
||||
'avg_confidence': avg_confidence,
|
||||
'avg_processing_time_ms': avg_processing_time,
|
||||
},
|
||||
'problems_analysis': problems_analysis,
|
||||
'failed_receipts': [
|
||||
{'filename': r['filename'], 'errors': r['errors']}
|
||||
for r in failed_results
|
||||
],
|
||||
'detailed_results': results,
|
||||
}
|
||||
|
||||
# Save report with engine name in filename
|
||||
report_path = script_dir / f'ocr_report_{args.engine.replace("-", "_")}_FULL.json'
|
||||
with open(report_path, 'w') as f:
|
||||
json.dump(report, f, indent=2, default=str)
|
||||
print(f"\nReport saved to: {report_path}")
|
||||
|
||||
# Exit with error if match rates are below threshold
|
||||
if total_match_rate < 0.8:
|
||||
print(f"\n[FAIL] Total match rate {total_match_rate*100:.1f}% is below 80% threshold")
|
||||
sys.exit(1)
|
||||
|
||||
print("\n[PASS] OCR validation completed successfully!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
135
tests/ocr-validation/test_receipts_parallel.py
Normal file
135
tests/ocr-validation/test_receipts_parallel.py
Normal file
@@ -0,0 +1,135 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test receipts in PARALLEL to measure real worker benefit."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import requests
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from datetime import datetime, timedelta
|
||||
from jose import jwt
|
||||
|
||||
API_BASE = "http://localhost:8000"
|
||||
PDF_FOLDER = "/mnt/e/proiecte/ab-worktrees/doctr-ocr-metrics/docs/data-entry"
|
||||
EXPECTED_FILE = "/mnt/e/proiecte/ab-worktrees/doctr-ocr-metrics/tests/ocr-validation/expected_receipts.json"
|
||||
|
||||
def get_jwt_token():
|
||||
secret_key = os.getenv('JWT_SECRET_KEY', 'generate_with_secrets_token_urlsafe_32')
|
||||
now = datetime.utcnow()
|
||||
payload = {
|
||||
"username": "MARIUS", "user_id": 1, "companies": ["604"],
|
||||
"permissions": ["read", "write"], "exp": now + timedelta(hours=1),
|
||||
"iat": now, "type": "access"
|
||||
}
|
||||
return jwt.encode(payload, secret_key, algorithm="HS256")
|
||||
|
||||
def submit_job(pdf_path, headers):
|
||||
"""Submit OCR job and return job_id immediately."""
|
||||
filename = os.path.basename(pdf_path)
|
||||
try:
|
||||
with open(pdf_path, "rb") as f:
|
||||
files = {"file": (filename, f, "application/pdf")}
|
||||
response = requests.post(
|
||||
f"{API_BASE}/api/data-entry/ocr/extract?engine=doctr_plus",
|
||||
files=files, headers=headers, timeout=30
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return response.json().get("job_id"), filename, None
|
||||
return None, filename, f"HTTP {response.status_code}"
|
||||
except Exception as e:
|
||||
return None, filename, str(e)
|
||||
|
||||
def wait_for_job(job_id, filename, headers, timeout=180):
|
||||
"""Wait for job completion."""
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
try:
|
||||
resp = requests.get(
|
||||
f"{API_BASE}/api/data-entry/ocr/jobs/{job_id}/wait?timeout=30",
|
||||
headers=headers, timeout=35
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
status = data.get("status")
|
||||
if status == "completed":
|
||||
result = data.get("result", {})
|
||||
conf = result.get("overall_confidence", 0)
|
||||
return {"success": True, "conf": conf, "time": time.time() - start}
|
||||
elif status == "error":
|
||||
return {"success": False, "error": data.get("error", "unknown"), "time": time.time() - start}
|
||||
time.sleep(1)
|
||||
except Exception as e:
|
||||
time.sleep(1)
|
||||
return {"success": False, "error": "timeout", "time": time.time() - start}
|
||||
|
||||
def main():
|
||||
# Load receipts
|
||||
with open(EXPECTED_FILE) as f:
|
||||
data = json.load(f)
|
||||
receipts = data.get("receipts", data)
|
||||
receipts = [r for r in receipts if r.get("pages", 1) == 1]
|
||||
|
||||
token = get_jwt_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"PARALLEL TEST: {len(receipts)} receipts")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# PHASE 1: Submit ALL jobs rapidly
|
||||
print("Phase 1: Submitting all jobs...")
|
||||
total_start = time.time()
|
||||
jobs = []
|
||||
|
||||
for r in receipts:
|
||||
pdf_path = os.path.join(PDF_FOLDER, r["filename"])
|
||||
if os.path.exists(pdf_path):
|
||||
job_id, filename, error = submit_job(pdf_path, headers)
|
||||
if job_id:
|
||||
jobs.append((job_id, filename))
|
||||
else:
|
||||
print(f" Submit failed: {filename} - {error}")
|
||||
|
||||
submit_time = time.time() - total_start
|
||||
print(f"Submitted {len(jobs)} jobs in {submit_time:.1f}s")
|
||||
|
||||
# PHASE 2: Wait for ALL results in parallel
|
||||
print("\nPhase 2: Waiting for results...")
|
||||
wait_start = time.time()
|
||||
results = []
|
||||
|
||||
with ThreadPoolExecutor(max_workers=26) as executor:
|
||||
futures = {executor.submit(wait_for_job, job_id, fn, headers): fn
|
||||
for job_id, fn in jobs}
|
||||
|
||||
for future in as_completed(futures):
|
||||
filename = futures[future]
|
||||
result = future.result()
|
||||
result["filename"] = filename
|
||||
results.append(result)
|
||||
|
||||
if result["success"]:
|
||||
print(f" OK: {filename[:45]:47} {result['time']:5.1f}s conf={result['conf']:.0%}")
|
||||
else:
|
||||
print(f" ERR: {filename[:45]:47} {result['time']:5.1f}s {result.get('error','?')}")
|
||||
|
||||
total_time = time.time() - total_start
|
||||
|
||||
# Summary
|
||||
print(f"\n{'='*60}")
|
||||
print("SUMMARY")
|
||||
print(f"{'='*60}")
|
||||
successful = [r for r in results if r["success"]]
|
||||
failed = [r for r in results if not r["success"]]
|
||||
|
||||
print(f"Success: {len(successful)}/{len(results)}")
|
||||
print(f"Submit phase: {submit_time:.1f}s")
|
||||
print(f"Wait phase: {time.time() - wait_start:.1f}s")
|
||||
print(f"TOTAL TIME: {total_time:.1f}s")
|
||||
|
||||
if successful:
|
||||
times = [r["time"] for r in successful]
|
||||
print(f"\nPer-job: avg={sum(times)/len(times):.1f}s, min={min(times):.1f}s, max={max(times):.1f}s")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
314
tests/ocr-validation/test_receipts_parallel_windows.py
Normal file
314
tests/ocr-validation/test_receipts_parallel_windows.py
Normal file
@@ -0,0 +1,314 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Parallel OCR test for Windows.
|
||||
Run from backend directory: python tests\ocr-validation\test_receipts_parallel_windows.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
from jose import jwt
|
||||
|
||||
try:
|
||||
import psutil
|
||||
PSUTIL_AVAILABLE = True
|
||||
except ImportError:
|
||||
PSUTIL_AVAILABLE = False
|
||||
print("Warning: psutil not installed, memory tracking disabled")
|
||||
|
||||
# Paths - relative to backend directory
|
||||
SCRIPT_DIR = Path(__file__).parent
|
||||
BACKEND_DIR = SCRIPT_DIR.parent.parent / "backend"
|
||||
PDF_FOLDER = SCRIPT_DIR.parent.parent / "docs" / "data-entry"
|
||||
EXPECTED_FILE = SCRIPT_DIR / "expected_receipts.json"
|
||||
|
||||
|
||||
class MemoryMonitor:
|
||||
"""Monitor memory usage of backend process and its children (OCR workers)."""
|
||||
|
||||
def __init__(self, port=8006):
|
||||
self.port = port
|
||||
self.peak_memory_mb = 0
|
||||
self.current_memory_mb = 0
|
||||
self._stop_event = threading.Event()
|
||||
self._thread = None
|
||||
self._process = None
|
||||
|
||||
def _find_backend_process(self):
|
||||
"""Find the backend process by port."""
|
||||
if not PSUTIL_AVAILABLE:
|
||||
return None
|
||||
try:
|
||||
for conn in psutil.net_connections(kind='inet'):
|
||||
if conn.laddr.port == self.port and conn.status == 'LISTEN':
|
||||
return psutil.Process(conn.pid)
|
||||
except (psutil.AccessDenied, psutil.NoSuchProcess):
|
||||
pass
|
||||
return None
|
||||
|
||||
def _get_total_memory(self):
|
||||
"""Get total memory of backend + all child processes (OCR workers)."""
|
||||
if not self._process:
|
||||
self._process = self._find_backend_process()
|
||||
if not self._process:
|
||||
return 0
|
||||
try:
|
||||
# Get memory of main process
|
||||
total = self._process.memory_info().rss
|
||||
# Add memory of all child processes (OCR workers)
|
||||
for child in self._process.children(recursive=True):
|
||||
try:
|
||||
total += child.memory_info().rss
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
||||
pass
|
||||
return total / (1024 * 1024) # Convert to MB
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
||||
self._process = None
|
||||
return 0
|
||||
|
||||
def _monitor_loop(self):
|
||||
"""Background thread that monitors memory every 0.5s."""
|
||||
while not self._stop_event.is_set():
|
||||
mem = self._get_total_memory()
|
||||
if mem > 0:
|
||||
self.current_memory_mb = mem
|
||||
if mem > self.peak_memory_mb:
|
||||
self.peak_memory_mb = mem
|
||||
self._stop_event.wait(0.5)
|
||||
|
||||
def start(self):
|
||||
"""Start monitoring in background thread."""
|
||||
if not PSUTIL_AVAILABLE:
|
||||
return
|
||||
self._stop_event.clear()
|
||||
self._thread = threading.Thread(target=self._monitor_loop, daemon=True)
|
||||
self._thread.start()
|
||||
# Wait a bit to get initial reading
|
||||
time.sleep(1)
|
||||
|
||||
def stop(self):
|
||||
"""Stop monitoring and return peak memory."""
|
||||
if self._thread:
|
||||
self._stop_event.set()
|
||||
self._thread.join(timeout=2)
|
||||
return self.peak_memory_mb
|
||||
|
||||
|
||||
def get_jwt_token():
|
||||
secret_key = os.getenv('JWT_SECRET_KEY', 'generate_with_secrets_token_urlsafe_32')
|
||||
now = datetime.utcnow()
|
||||
payload = {
|
||||
"username": "MARIUS",
|
||||
"user_id": 1,
|
||||
"companies": ["604"],
|
||||
"permissions": ["read", "write"],
|
||||
"exp": now + timedelta(hours=1),
|
||||
"iat": now,
|
||||
"type": "access"
|
||||
}
|
||||
return jwt.encode(payload, secret_key, algorithm="HS256")
|
||||
|
||||
|
||||
def submit_job(pdf_path, headers, api_base):
|
||||
"""Submit OCR job and return job_id immediately."""
|
||||
filename = os.path.basename(pdf_path)
|
||||
try:
|
||||
with open(pdf_path, "rb") as f:
|
||||
files = {"file": (filename, f, "application/pdf")}
|
||||
response = requests.post(
|
||||
f"{api_base}/api/data-entry/ocr/extract?engine=doctr_plus",
|
||||
files=files,
|
||||
headers=headers,
|
||||
timeout=30
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return response.json().get("job_id"), filename, None
|
||||
return None, filename, f"HTTP {response.status_code}: {response.text[:100]}"
|
||||
except Exception as e:
|
||||
return None, filename, str(e)
|
||||
|
||||
|
||||
def wait_for_job(job_id, filename, headers, api_base, timeout=180):
|
||||
"""Wait for job completion."""
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
try:
|
||||
resp = requests.get(
|
||||
f"{api_base}/api/data-entry/ocr/jobs/{job_id}/wait?timeout=30",
|
||||
headers=headers,
|
||||
timeout=35
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
status = data.get("status")
|
||||
if status == "completed":
|
||||
result = data.get("result", {})
|
||||
conf = result.get("overall_confidence", 0)
|
||||
return {"success": True, "conf": conf, "time": time.time() - start, "filename": filename}
|
||||
elif status in ("error", "failed"):
|
||||
return {"success": False, "error": data.get("error", "unknown"), "time": time.time() - start, "filename": filename}
|
||||
time.sleep(1)
|
||||
except Exception as e:
|
||||
time.sleep(1)
|
||||
return {"success": False, "error": "timeout", "time": time.time() - start, "filename": filename}
|
||||
|
||||
|
||||
def run_test(api_base, workers, output_file=None, port=8006):
|
||||
"""Run test and return results dict."""
|
||||
# Load receipts
|
||||
if not EXPECTED_FILE.exists():
|
||||
print(f"ERROR: {EXPECTED_FILE} not found!")
|
||||
return None
|
||||
|
||||
with open(EXPECTED_FILE) as f:
|
||||
data = json.load(f)
|
||||
receipts = data.get("receipts", data)
|
||||
receipts = [r for r in receipts if r.get("pages", 1) == 1]
|
||||
|
||||
token = get_jwt_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
# Start memory monitoring
|
||||
memory_monitor = MemoryMonitor(port=port)
|
||||
memory_monitor.start()
|
||||
|
||||
header = f"TEST: {len(receipts)} receipts, {workers} worker(s)"
|
||||
print()
|
||||
print("=" * 60)
|
||||
print(header)
|
||||
print(f"Backend: {api_base}")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
# PHASE 1: Submit ALL jobs rapidly
|
||||
print("Phase 1: Submitting all jobs...")
|
||||
total_start = time.time()
|
||||
jobs = []
|
||||
|
||||
for r in receipts:
|
||||
pdf_path = PDF_FOLDER / r["filename"]
|
||||
if pdf_path.exists():
|
||||
job_id, filename, error = submit_job(str(pdf_path), headers, api_base)
|
||||
if job_id:
|
||||
jobs.append((job_id, filename))
|
||||
else:
|
||||
print(f" Submit failed: {filename} - {error}")
|
||||
else:
|
||||
print(f" File not found: {r['filename']}")
|
||||
|
||||
submit_time = time.time() - total_start
|
||||
print(f"Submitted {len(jobs)} jobs in {submit_time:.1f}s")
|
||||
print()
|
||||
|
||||
# PHASE 2: Wait for ALL results in parallel
|
||||
print("Phase 2: Waiting for results...")
|
||||
wait_start = time.time()
|
||||
results = []
|
||||
|
||||
with ThreadPoolExecutor(max_workers=26) as executor:
|
||||
futures = {
|
||||
executor.submit(wait_for_job, job_id, fn, headers, api_base): fn
|
||||
for job_id, fn in jobs
|
||||
}
|
||||
|
||||
for future in as_completed(futures):
|
||||
result = future.result()
|
||||
results.append(result)
|
||||
|
||||
if result["success"]:
|
||||
print(f" OK: {result['filename'][:45]:47} {result['time']:5.1f}s conf={result['conf']:.0%}")
|
||||
else:
|
||||
print(f" ERR: {result['filename'][:45]:47} {result['time']:5.1f}s {result.get('error', '?')}")
|
||||
|
||||
total_time = time.time() - total_start
|
||||
wait_time = time.time() - wait_start
|
||||
|
||||
# Stop memory monitoring and get peak
|
||||
peak_memory_mb = memory_monitor.stop()
|
||||
|
||||
# Summary
|
||||
print()
|
||||
print("=" * 60)
|
||||
print(f"SUMMARY - {workers} WORKER(S)")
|
||||
print("=" * 60)
|
||||
successful = [r for r in results if r["success"]]
|
||||
failed = [r for r in results if not r["success"]]
|
||||
|
||||
print(f"Success: {len(successful)}/{len(results)}")
|
||||
print(f"Submit phase: {submit_time:.1f}s")
|
||||
print(f"Wait phase: {wait_time:.1f}s")
|
||||
print(f"TOTAL TIME: {total_time:.1f}s")
|
||||
if peak_memory_mb > 0:
|
||||
print(f"PEAK MEMORY: {peak_memory_mb:.0f} MB")
|
||||
|
||||
avg_time = sum(r["time"] for r in successful) / len(successful) if successful else 0
|
||||
min_time = min(r["time"] for r in successful) if successful else 0
|
||||
max_time = max(r["time"] for r in successful) if successful else 0
|
||||
avg_conf = sum(r["conf"] for r in successful) / len(successful) if successful else 0
|
||||
|
||||
if successful:
|
||||
print(f"\nPer-job: avg={avg_time:.1f}s, min={min_time:.1f}s, max={max_time:.1f}s")
|
||||
|
||||
if failed:
|
||||
print(f"\nFailed jobs ({len(failed)}):")
|
||||
for r in failed:
|
||||
print(f" - {r['filename']}: {r.get('error', '?')}")
|
||||
|
||||
# Build result dict
|
||||
result_data = {
|
||||
"workers": workers,
|
||||
"total_receipts": len(receipts),
|
||||
"submitted": len(jobs),
|
||||
"successful": len(successful),
|
||||
"failed": len(failed),
|
||||
"submit_time": round(submit_time, 1),
|
||||
"wait_time": round(wait_time, 1),
|
||||
"total_time": round(total_time, 1),
|
||||
"avg_time": round(avg_time, 1),
|
||||
"min_time": round(min_time, 1),
|
||||
"max_time": round(max_time, 1),
|
||||
"avg_confidence": round(avg_conf * 100, 1),
|
||||
"peak_memory_mb": round(peak_memory_mb, 0),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# Write to file if specified
|
||||
if output_file:
|
||||
# Append to existing results
|
||||
all_results = []
|
||||
if Path(output_file).exists():
|
||||
try:
|
||||
with open(output_file) as f:
|
||||
all_results = json.load(f)
|
||||
except:
|
||||
all_results = []
|
||||
all_results.append(result_data)
|
||||
with open(output_file, 'w') as f:
|
||||
json.dump(all_results, f, indent=2)
|
||||
print(f"\nResults saved to: {output_file}")
|
||||
|
||||
return result_data
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Parallel OCR Test")
|
||||
parser.add_argument("--port", type=int, default=8006, help="Backend port")
|
||||
parser.add_argument("--host", default="localhost", help="Backend host")
|
||||
parser.add_argument("--workers", type=int, default=1, help="Number of OCR workers (for labeling)")
|
||||
parser.add_argument("--output", type=str, help="Output JSON file for results")
|
||||
args = parser.parse_args()
|
||||
|
||||
api_base = f"http://{args.host}:{args.port}"
|
||||
run_test(api_base, args.workers, args.output, port=args.port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
228
tests/ocr-validation/test_receipts_sequential.py
Normal file
228
tests/ocr-validation/test_receipts_sequential.py
Normal file
@@ -0,0 +1,228 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test each receipt sequentially and report results."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import requests
|
||||
from datetime import datetime, timedelta
|
||||
from jose import jwt
|
||||
|
||||
API_BASE = "http://localhost:8000"
|
||||
PDF_FOLDER = "/mnt/e/proiecte/ab-worktrees/doctr-ocr-metrics/docs/data-entry"
|
||||
EXPECTED_FILE = "/mnt/e/proiecte/ab-worktrees/doctr-ocr-metrics/tests/ocr-validation/expected_receipts.json"
|
||||
|
||||
def get_jwt_token():
|
||||
"""Create a test JWT token for API authentication."""
|
||||
secret_key = os.getenv('JWT_SECRET_KEY', 'generate_with_secrets_token_urlsafe_32')
|
||||
now = datetime.utcnow()
|
||||
expire = now + timedelta(hours=1)
|
||||
|
||||
payload = {
|
||||
"username": "MARIUS",
|
||||
"user_id": 1,
|
||||
"companies": ["604"],
|
||||
"permissions": ["read", "write"],
|
||||
"exp": expire,
|
||||
"iat": now,
|
||||
"type": "access"
|
||||
}
|
||||
return jwt.encode(payload, secret_key, algorithm="HS256")
|
||||
|
||||
def test_receipt(pdf_path: str, expected: dict, headers: dict) -> dict:
|
||||
"""Test a single receipt and return results."""
|
||||
filename = os.path.basename(pdf_path)
|
||||
result = {
|
||||
"filename": filename,
|
||||
"success": False,
|
||||
"time_ms": 0,
|
||||
"error": None,
|
||||
"extracted": {},
|
||||
"matches": {},
|
||||
"issues": []
|
||||
}
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
with open(pdf_path, "rb") as f:
|
||||
files = {"file": (filename, f, "application/pdf")}
|
||||
response = requests.post(
|
||||
f"{API_BASE}/api/data-entry/ocr/extract?engine=doctr_plus",
|
||||
files=files,
|
||||
headers=headers,
|
||||
timeout=120
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
result["error"] = f"HTTP {response.status_code}"
|
||||
result["time_ms"] = int((time.time() - start_time) * 1000)
|
||||
return result
|
||||
|
||||
job_data = response.json()
|
||||
job_id = job_data.get("job_id")
|
||||
|
||||
# Poll for completion
|
||||
for _ in range(60): # Max 60 polls (2 minutes)
|
||||
poll_response = requests.get(
|
||||
f"{API_BASE}/api/data-entry/ocr/jobs/{job_id}/wait?timeout=30",
|
||||
headers=headers,
|
||||
timeout=35
|
||||
)
|
||||
if poll_response.status_code == 200:
|
||||
job_result = poll_response.json()
|
||||
status = job_result.get("status")
|
||||
if status == "completed":
|
||||
break
|
||||
elif status == "error":
|
||||
result["error"] = job_result.get("error", "Unknown error")
|
||||
result["time_ms"] = int((time.time() - start_time) * 1000)
|
||||
return result
|
||||
time.sleep(1)
|
||||
|
||||
result["time_ms"] = int((time.time() - start_time) * 1000)
|
||||
|
||||
if job_result.get("status") != "completed":
|
||||
result["error"] = f"Timeout - status: {job_result.get('status')}"
|
||||
return result
|
||||
|
||||
# Extract fields (correct field names from API)
|
||||
extraction = job_result.get("result", {})
|
||||
result["extracted"] = {
|
||||
"total": extraction.get("amount"), # API uses "amount" not "total"
|
||||
"date": extraction.get("receipt_date"), # API uses "receipt_date" not "date"
|
||||
"cui": extraction.get("cui"),
|
||||
"tva_total": extraction.get("tva_total"),
|
||||
"confidence": extraction.get("overall_confidence")
|
||||
}
|
||||
|
||||
# Compare with expected (use correct field names from expected_receipts.json)
|
||||
exp_total = expected.get("total")
|
||||
exp_date = expected.get("data_bon")
|
||||
exp_cui = expected.get("cui_furnizor")
|
||||
|
||||
# Normalize for comparison
|
||||
def normalize_total(val):
|
||||
if val is None:
|
||||
return None
|
||||
return float(str(val).replace(',', '.'))
|
||||
|
||||
def normalize_cui(val):
|
||||
if val is None:
|
||||
return None
|
||||
return str(val).upper().replace('RO', '').replace(' ', '').strip()
|
||||
|
||||
ext_total = normalize_total(result["extracted"]["total"])
|
||||
ext_cui = normalize_cui(result["extracted"]["cui"])
|
||||
exp_cui_norm = normalize_cui(exp_cui)
|
||||
exp_total_norm = normalize_total(exp_total)
|
||||
|
||||
result["matches"]["total"] = abs(ext_total - exp_total_norm) < 0.01 if ext_total and exp_total_norm else None
|
||||
result["matches"]["date"] = result["extracted"]["date"] == exp_date if exp_date else None
|
||||
result["matches"]["cui"] = ext_cui == exp_cui_norm if exp_cui else None
|
||||
|
||||
# Check for issues
|
||||
if exp_total and not result["matches"]["total"]:
|
||||
result["issues"].append(f"TOTAL: got {result['extracted']['total']}, expected {exp_total}")
|
||||
if exp_date and not result["matches"]["date"]:
|
||||
result["issues"].append(f"DATE: got {result['extracted']['date']}, expected {exp_date}")
|
||||
if exp_cui and not result["matches"]["cui"]:
|
||||
result["issues"].append(f"CUI: got {result['extracted']['cui']}, expected {exp_cui}")
|
||||
|
||||
result["success"] = len(result["issues"]) == 0
|
||||
|
||||
except Exception as e:
|
||||
result["error"] = str(e)
|
||||
result["time_ms"] = int((time.time() - start_time) * 1000)
|
||||
|
||||
return result
|
||||
|
||||
def main():
|
||||
# Load expected data
|
||||
with open(EXPECTED_FILE) as f:
|
||||
expected_data = json.load(f)
|
||||
|
||||
# Handle both formats: list or dict with "receipts" key
|
||||
if isinstance(expected_data, dict) and "receipts" in expected_data:
|
||||
all_receipts = expected_data["receipts"]
|
||||
else:
|
||||
all_receipts = expected_data
|
||||
|
||||
# Get JWT token
|
||||
token = get_jwt_token()
|
||||
if not token:
|
||||
print("ERROR: Could not get JWT token")
|
||||
sys.exit(1)
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
# Filter single-page receipts
|
||||
receipts = [r for r in all_receipts if r.get("pages", 1) == 1]
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Testing {len(receipts)} single-page receipts with doctr_plus")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
results = []
|
||||
times = []
|
||||
|
||||
for i, receipt in enumerate(receipts, 1):
|
||||
filename = receipt["filename"]
|
||||
pdf_path = os.path.join(PDF_FOLDER, filename)
|
||||
|
||||
if not os.path.exists(pdf_path):
|
||||
print(f"[{i:02d}/{len(receipts)}] SKIP: {filename} (not found)")
|
||||
continue
|
||||
|
||||
print(f"[{i:02d}/{len(receipts)}] Testing: {filename[:50]}...", end=" ", flush=True)
|
||||
|
||||
result = test_receipt(pdf_path, receipt, headers)
|
||||
results.append(result)
|
||||
|
||||
if result["error"]:
|
||||
print(f"ERROR ({result['time_ms']}ms): {result['error']}")
|
||||
elif result["success"]:
|
||||
print(f"OK ({result['time_ms']}ms) conf={result['extracted'].get('confidence', 0):.2f}")
|
||||
times.append(result["time_ms"])
|
||||
else:
|
||||
print(f"FAIL ({result['time_ms']}ms): {'; '.join(result['issues'])}")
|
||||
times.append(result["time_ms"])
|
||||
|
||||
# Summary
|
||||
print(f"\n{'='*60}")
|
||||
print("SUMMARY")
|
||||
print(f"{'='*60}")
|
||||
|
||||
successful = [r for r in results if r["success"]]
|
||||
failed = [r for r in results if not r["success"] and not r["error"]]
|
||||
errors = [r for r in results if r["error"]]
|
||||
|
||||
print(f"Total: {len(results)}")
|
||||
print(f"Success: {len(successful)} ({len(successful)*100/len(results):.1f}%)")
|
||||
print(f"Failed: {len(failed)}")
|
||||
print(f"Errors: {len(errors)}")
|
||||
|
||||
if times:
|
||||
avg_time = sum(times) / len(times)
|
||||
print(f"\nTiming: avg={avg_time:.0f}ms, min={min(times)}ms, max={max(times)}ms")
|
||||
|
||||
# Flag slow ones
|
||||
slow_threshold = avg_time * 2
|
||||
slow = [r for r in results if r["time_ms"] > slow_threshold and not r["error"]]
|
||||
if slow:
|
||||
print(f"\nSlow receipts (>{slow_threshold:.0f}ms):")
|
||||
for r in slow:
|
||||
print(f" - {r['filename']}: {r['time_ms']}ms")
|
||||
|
||||
if failed:
|
||||
print(f"\nFailed receipts:")
|
||||
for r in failed:
|
||||
print(f" - {r['filename']}: {'; '.join(r['issues'])}")
|
||||
|
||||
if errors:
|
||||
print(f"\nError receipts:")
|
||||
for r in errors:
|
||||
print(f" - {r['filename']}: {r['error']}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user