fix telegram
This commit is contained in:
@@ -0,0 +1,94 @@
|
||||
# Alembic configuration for Data Entry module
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
script_location = migrations
|
||||
|
||||
# template used to generate migration files
|
||||
file_template = %%(year)d%%(month).2d%%(day).2d_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# string value is passed to dateutil.tz.gettz()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the
|
||||
# "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to migrations/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions
|
||||
|
||||
# version path separator
|
||||
# version_path_separator = :
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
# recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
# SQLite database URL - will be overridden by env.py using SQLITE_DATABASE_PATH env var
|
||||
sqlalchemy.url = sqlite:///data/receipts/receipts.db
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - disabled
|
||||
# hooks = black
|
||||
# black.type = console_scripts
|
||||
# black.entrypoint = black
|
||||
# black.options = -q
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
@@ -0,0 +1,110 @@
|
||||
"""Application configuration using pydantic-settings."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from pydantic_settings import BaseSettings
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings loaded from environment variables."""
|
||||
|
||||
# App info
|
||||
app_name: str = "Data Entry API"
|
||||
app_version: str = "1.0.0"
|
||||
debug: bool = False
|
||||
|
||||
# API
|
||||
api_host: str = "0.0.0.0"
|
||||
api_port: int = 8003
|
||||
|
||||
# SQLite Database
|
||||
sqlite_database_path: str = "data/receipts/receipts.db"
|
||||
|
||||
# File uploads
|
||||
upload_path: str = "data/uploads"
|
||||
max_upload_size_mb: int = 10
|
||||
allowed_mime_types: List[str] = [
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/gif",
|
||||
"image/webp",
|
||||
"application/pdf",
|
||||
]
|
||||
|
||||
# Oracle Database (for nomenclatures)
|
||||
oracle_user: str = ""
|
||||
oracle_password: str = ""
|
||||
oracle_host: str = "localhost"
|
||||
oracle_port: int = 1526
|
||||
oracle_sid: str = "ROA"
|
||||
|
||||
# JWT Authentication
|
||||
jwt_secret_key: str = "change-me-in-production"
|
||||
jwt_algorithm: str = "HS256"
|
||||
jwt_expire_minutes: int = 480
|
||||
|
||||
# CORS
|
||||
cors_origins: str = "http://localhost:3010,http://localhost:3000"
|
||||
|
||||
# OCR Engines (comma-separated list of active engines shown in UI)
|
||||
# Available: tesseract, paddleocr, doctr, doctr_plus
|
||||
# doctr_plus is recommended (2-tier sequential with early exit)
|
||||
ocr_active_engines: str = "doctr,doctr_plus"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
extra = "ignore"
|
||||
|
||||
@property
|
||||
def database_url(self) -> str:
|
||||
"""Get SQLite database URL for async."""
|
||||
# Resolve to absolute path for Windows/IIS compatibility
|
||||
abs_path = Path(self.sqlite_database_path).resolve()
|
||||
return f"sqlite+aiosqlite:///{abs_path}"
|
||||
|
||||
@property
|
||||
def sync_database_url(self) -> str:
|
||||
"""Get SQLite database URL for sync operations (Alembic)."""
|
||||
# Resolve to absolute path for Windows/IIS compatibility
|
||||
abs_path = Path(self.sqlite_database_path).resolve()
|
||||
return f"sqlite:///{abs_path}"
|
||||
|
||||
@property
|
||||
def upload_path_resolved(self) -> Path:
|
||||
"""Get resolved upload path."""
|
||||
path = Path(self.upload_path)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def max_upload_size_bytes(self) -> int:
|
||||
"""Get max upload size in bytes."""
|
||||
return self.max_upload_size_mb * 1024 * 1024
|
||||
|
||||
@property
|
||||
def cors_origins_list(self) -> List[str]:
|
||||
"""Get CORS origins as list."""
|
||||
return [origin.strip() for origin in self.cors_origins.split(",")]
|
||||
|
||||
@property
|
||||
def ocr_active_engines_list(self) -> List[str]:
|
||||
"""Get OCR active engines as list."""
|
||||
return [engine.strip() for engine in self.ocr_active_engines.split(",")]
|
||||
|
||||
@property
|
||||
def oracle_dsn(self) -> str:
|
||||
"""Get Oracle DSN string."""
|
||||
return f"{self.oracle_host}:{self.oracle_port}/{self.oracle_sid}"
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_settings() -> Settings:
|
||||
"""Get cached settings instance."""
|
||||
return Settings()
|
||||
|
||||
|
||||
# Convenience instance
|
||||
settings = get_settings()
|
||||
@@ -0,0 +1,4 @@
|
||||
# Database module
|
||||
from .database import get_session, init_db, engine
|
||||
|
||||
__all__ = ["get_session", "init_db", "engine"]
|
||||
@@ -0,0 +1,13 @@
|
||||
# CRUD operations
|
||||
from .receipt import ReceiptCRUD
|
||||
from .attachment import AttachmentCRUD
|
||||
from .accounting_entry import AccountingEntryCRUD
|
||||
from .ocr_settings import OCRPreferenceCRUD, OCRMetricsCRUD
|
||||
|
||||
__all__ = [
|
||||
"ReceiptCRUD",
|
||||
"AttachmentCRUD",
|
||||
"AccountingEntryCRUD",
|
||||
"OCRPreferenceCRUD",
|
||||
"OCRMetricsCRUD",
|
||||
]
|
||||
@@ -0,0 +1,197 @@
|
||||
"""CRUD operations for accounting entries."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
|
||||
from sqlalchemy import select, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from backend.modules.data_entry.db.models.accounting_entry import AccountingEntry, EntryType
|
||||
from backend.modules.data_entry.schemas.receipt import AccountingEntryCreate, AccountingEntryUpdate
|
||||
|
||||
|
||||
class AccountingEntryCRUD:
|
||||
"""CRUD operations for AccountingEntry model."""
|
||||
|
||||
@staticmethod
|
||||
async def create(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
data: AccountingEntryCreate,
|
||||
sort_order: int = 0,
|
||||
is_auto_generated: bool = True,
|
||||
) -> AccountingEntry:
|
||||
"""Create a new accounting entry."""
|
||||
entry = AccountingEntry(
|
||||
receipt_id=receipt_id,
|
||||
entry_type=data.entry_type,
|
||||
account_code=data.account_code,
|
||||
account_name=data.account_name,
|
||||
amount=data.amount,
|
||||
partner_id=data.partner_id,
|
||||
cost_center_id=data.cost_center_id,
|
||||
is_auto_generated=is_auto_generated,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
session.add(entry)
|
||||
await session.commit()
|
||||
await session.refresh(entry)
|
||||
return entry
|
||||
|
||||
@staticmethod
|
||||
async def create_bulk(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
entries: List[AccountingEntryCreate],
|
||||
is_auto_generated: bool = True,
|
||||
) -> List[AccountingEntry]:
|
||||
"""Create multiple accounting entries at once."""
|
||||
created_entries = []
|
||||
|
||||
for idx, entry_data in enumerate(entries):
|
||||
entry = AccountingEntry(
|
||||
receipt_id=receipt_id,
|
||||
entry_type=entry_data.entry_type,
|
||||
account_code=entry_data.account_code,
|
||||
account_name=entry_data.account_name,
|
||||
amount=entry_data.amount,
|
||||
partner_id=entry_data.partner_id,
|
||||
cost_center_id=entry_data.cost_center_id,
|
||||
is_auto_generated=is_auto_generated,
|
||||
sort_order=idx,
|
||||
)
|
||||
session.add(entry)
|
||||
created_entries.append(entry)
|
||||
|
||||
await session.commit()
|
||||
|
||||
for entry in created_entries:
|
||||
await session.refresh(entry)
|
||||
|
||||
return created_entries
|
||||
|
||||
@staticmethod
|
||||
async def get_by_id(
|
||||
session: AsyncSession,
|
||||
entry_id: int,
|
||||
) -> Optional[AccountingEntry]:
|
||||
"""Get accounting entry by ID."""
|
||||
query = select(AccountingEntry).where(AccountingEntry.id == entry_id)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
async def get_by_receipt_id(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
) -> List[AccountingEntry]:
|
||||
"""Get all accounting entries for a receipt."""
|
||||
query = select(AccountingEntry).where(
|
||||
AccountingEntry.receipt_id == receipt_id
|
||||
).order_by(AccountingEntry.sort_order.asc())
|
||||
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@staticmethod
|
||||
async def update(
|
||||
session: AsyncSession,
|
||||
entry: AccountingEntry,
|
||||
data: AccountingEntryUpdate,
|
||||
modified_by: str,
|
||||
) -> AccountingEntry:
|
||||
"""Update an accounting entry."""
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(entry, field, value)
|
||||
|
||||
entry.is_auto_generated = False
|
||||
entry.modified_by = modified_by
|
||||
entry.modified_at = datetime.utcnow()
|
||||
|
||||
session.add(entry)
|
||||
await session.commit()
|
||||
await session.refresh(entry)
|
||||
return entry
|
||||
|
||||
@staticmethod
|
||||
async def delete(session: AsyncSession, entry: AccountingEntry) -> bool:
|
||||
"""Delete an accounting entry."""
|
||||
await session.delete(entry)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def delete_all_for_receipt(session: AsyncSession, receipt_id: int) -> int:
|
||||
"""Delete all accounting entries for a receipt."""
|
||||
query = delete(AccountingEntry).where(AccountingEntry.receipt_id == receipt_id)
|
||||
result = await session.execute(query)
|
||||
await session.commit()
|
||||
return result.rowcount
|
||||
|
||||
@staticmethod
|
||||
async def replace_all_for_receipt(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
entries: List[AccountingEntryCreate],
|
||||
modified_by: str,
|
||||
) -> List[AccountingEntry]:
|
||||
"""Replace all entries for a receipt with new ones."""
|
||||
# Delete existing entries
|
||||
await AccountingEntryCRUD.delete_all_for_receipt(session, receipt_id)
|
||||
|
||||
# Create new entries (marked as manually modified)
|
||||
created_entries = []
|
||||
|
||||
for idx, entry_data in enumerate(entries):
|
||||
entry = AccountingEntry(
|
||||
receipt_id=receipt_id,
|
||||
entry_type=entry_data.entry_type,
|
||||
account_code=entry_data.account_code,
|
||||
account_name=entry_data.account_name,
|
||||
amount=entry_data.amount,
|
||||
partner_id=entry_data.partner_id,
|
||||
cost_center_id=entry_data.cost_center_id,
|
||||
is_auto_generated=False,
|
||||
modified_by=modified_by,
|
||||
modified_at=datetime.utcnow(),
|
||||
sort_order=idx,
|
||||
)
|
||||
session.add(entry)
|
||||
created_entries.append(entry)
|
||||
|
||||
await session.commit()
|
||||
|
||||
for entry in created_entries:
|
||||
await session.refresh(entry)
|
||||
|
||||
return created_entries
|
||||
|
||||
@staticmethod
|
||||
async def validate_entries(entries: List[AccountingEntryCreate]) -> tuple[bool, str]:
|
||||
"""
|
||||
Validate accounting entries.
|
||||
Returns (is_valid, error_message).
|
||||
"""
|
||||
if not entries:
|
||||
return False, "At least one entry is required"
|
||||
|
||||
total_debit = sum(
|
||||
e.amount for e in entries if e.entry_type == EntryType.DEBIT
|
||||
)
|
||||
total_credit = sum(
|
||||
e.amount for e in entries if e.entry_type == EntryType.CREDIT
|
||||
)
|
||||
|
||||
# Check balance (debit should equal credit)
|
||||
if abs(total_debit - total_credit) > 0.01:
|
||||
return False, f"Entries not balanced: Debit={total_debit}, Credit={total_credit}"
|
||||
|
||||
# Check for valid account codes
|
||||
for entry in entries:
|
||||
if not entry.account_code or len(entry.account_code) < 3:
|
||||
return False, f"Invalid account code: {entry.account_code}"
|
||||
|
||||
return True, ""
|
||||
@@ -0,0 +1,140 @@
|
||||
"""CRUD operations for receipt attachments."""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
import aiofiles
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from fastapi import UploadFile
|
||||
|
||||
from backend.modules.data_entry.db.models.receipt import ReceiptAttachment
|
||||
from backend.config import settings
|
||||
|
||||
|
||||
class AttachmentCRUD:
|
||||
"""CRUD operations for ReceiptAttachment model."""
|
||||
|
||||
@staticmethod
|
||||
def _generate_stored_filename(original_filename: str) -> str:
|
||||
"""Generate unique filename for storage."""
|
||||
ext = Path(original_filename).suffix.lower()
|
||||
return f"{uuid.uuid4()}{ext}"
|
||||
|
||||
@staticmethod
|
||||
def _get_upload_path(stored_filename: str) -> Path:
|
||||
"""Get full path for storing file, organized by year/month."""
|
||||
now = datetime.utcnow()
|
||||
relative_path = Path(str(now.year)) / f"{now.month:02d}"
|
||||
full_path = settings.data_entry_upload_path_resolved / relative_path
|
||||
|
||||
# Ensure directory exists
|
||||
full_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return relative_path / stored_filename
|
||||
|
||||
@staticmethod
|
||||
async def create(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
file: UploadFile,
|
||||
) -> ReceiptAttachment:
|
||||
"""Create attachment by saving file and creating DB record."""
|
||||
# Generate stored filename
|
||||
stored_filename = AttachmentCRUD._generate_stored_filename(file.filename or "upload")
|
||||
|
||||
# Get relative path
|
||||
relative_path = AttachmentCRUD._get_upload_path(stored_filename)
|
||||
|
||||
# Full path for saving
|
||||
full_path = settings.data_entry_upload_path_resolved / relative_path
|
||||
|
||||
# Read file content
|
||||
content = await file.read()
|
||||
file_size = len(content)
|
||||
|
||||
# Validate file size
|
||||
if file_size > settings.data_entry_max_upload_size_bytes:
|
||||
raise ValueError(f"File too large. Maximum size is {settings.data_entry_max_upload_size_mb}MB")
|
||||
|
||||
# Validate MIME type
|
||||
mime_type = file.content_type or "application/octet-stream"
|
||||
if mime_type not in settings.data_entry_allowed_mime_types:
|
||||
raise ValueError(f"File type not allowed: {mime_type}")
|
||||
|
||||
# Save file
|
||||
async with aiofiles.open(full_path, "wb") as f:
|
||||
await f.write(content)
|
||||
|
||||
# Create DB record
|
||||
attachment = ReceiptAttachment(
|
||||
receipt_id=receipt_id,
|
||||
filename=file.filename or "upload",
|
||||
stored_filename=stored_filename,
|
||||
file_path=str(relative_path),
|
||||
file_size=file_size,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
|
||||
session.add(attachment)
|
||||
await session.commit()
|
||||
await session.refresh(attachment)
|
||||
|
||||
return attachment
|
||||
|
||||
@staticmethod
|
||||
async def get_by_id(
|
||||
session: AsyncSession,
|
||||
attachment_id: int,
|
||||
) -> Optional[ReceiptAttachment]:
|
||||
"""Get attachment by ID."""
|
||||
query = select(ReceiptAttachment).where(ReceiptAttachment.id == attachment_id)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
async def get_by_receipt_id(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
) -> List[ReceiptAttachment]:
|
||||
"""Get all attachments for a receipt."""
|
||||
query = select(ReceiptAttachment).where(
|
||||
ReceiptAttachment.receipt_id == receipt_id
|
||||
).order_by(ReceiptAttachment.uploaded_at.asc())
|
||||
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@staticmethod
|
||||
def get_file_path(attachment: ReceiptAttachment) -> Path:
|
||||
"""Get full file path for an attachment."""
|
||||
return settings.data_entry_upload_path_resolved / attachment.file_path
|
||||
|
||||
@staticmethod
|
||||
async def delete(session: AsyncSession, attachment: ReceiptAttachment) -> bool:
|
||||
"""Delete attachment (file and DB record)."""
|
||||
# Delete file
|
||||
file_path = AttachmentCRUD.get_file_path(attachment)
|
||||
if file_path.exists():
|
||||
os.remove(file_path)
|
||||
|
||||
# Delete DB record
|
||||
await session.delete(attachment)
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def delete_all_for_receipt(session: AsyncSession, receipt_id: int) -> int:
|
||||
"""Delete all attachments for a receipt."""
|
||||
attachments = await AttachmentCRUD.get_by_receipt_id(session, receipt_id)
|
||||
count = 0
|
||||
|
||||
for attachment in attachments:
|
||||
await AttachmentCRUD.delete(session, attachment)
|
||||
count += 1
|
||||
|
||||
return count
|
||||
@@ -0,0 +1,222 @@
|
||||
"""CRUD operations for OCR settings and metrics."""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import func, select, and_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from backend.modules.data_entry.db.models.ocr_settings import (
|
||||
UserOCRPreference,
|
||||
OCRJobMetrics,
|
||||
OCRMetricsSummary,
|
||||
OCREngine,
|
||||
)
|
||||
|
||||
|
||||
class OCRPreferenceCRUD:
|
||||
"""CRUD operations for user OCR preferences."""
|
||||
|
||||
@staticmethod
|
||||
async def get_by_username(session: AsyncSession, username: str) -> Optional[UserOCRPreference]:
|
||||
"""Get user's OCR preference by username."""
|
||||
result = await session.execute(
|
||||
select(UserOCRPreference).where(UserOCRPreference.username == username)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
async def create_or_update(
|
||||
session: AsyncSession,
|
||||
username: str,
|
||||
preferred_engine: OCREngine
|
||||
) -> UserOCRPreference:
|
||||
"""Create or update user's OCR preference."""
|
||||
existing = await OCRPreferenceCRUD.get_by_username(session, username)
|
||||
|
||||
if existing:
|
||||
existing.preferred_engine = preferred_engine
|
||||
existing.updated_at = datetime.utcnow()
|
||||
await session.commit()
|
||||
await session.refresh(existing)
|
||||
return existing
|
||||
else:
|
||||
preference = UserOCRPreference(
|
||||
username=username,
|
||||
preferred_engine=preferred_engine
|
||||
)
|
||||
session.add(preference)
|
||||
await session.commit()
|
||||
await session.refresh(preference)
|
||||
return preference
|
||||
|
||||
@staticmethod
|
||||
async def delete_by_username(session: AsyncSession, username: str) -> bool:
|
||||
"""Delete user's OCR preference."""
|
||||
existing = await OCRPreferenceCRUD.get_by_username(session, username)
|
||||
if existing:
|
||||
await session.delete(existing)
|
||||
await session.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class OCRMetricsCRUD:
|
||||
"""CRUD operations for OCR job metrics."""
|
||||
|
||||
@staticmethod
|
||||
async def create(
|
||||
session: AsyncSession,
|
||||
job_id: str,
|
||||
username: str,
|
||||
engine_requested: str,
|
||||
engine_used: str,
|
||||
processing_time_ms: int = 0,
|
||||
file_size_bytes: int = 0,
|
||||
file_type: str = "image/jpeg",
|
||||
original_filename: Optional[str] = None,
|
||||
success: bool = True,
|
||||
error_message: Optional[str] = None,
|
||||
overall_confidence: float = 0.0,
|
||||
fields_extracted: int = 0,
|
||||
needs_manual_review: Optional[bool] = None,
|
||||
validation_warnings_count: int = 0,
|
||||
validation_errors_count: int = 0,
|
||||
company_id: Optional[int] = None
|
||||
) -> OCRJobMetrics:
|
||||
"""Create a new OCR job metrics record."""
|
||||
metrics = OCRJobMetrics(
|
||||
job_id=job_id,
|
||||
username=username,
|
||||
company_id=company_id,
|
||||
engine_requested=engine_requested,
|
||||
engine_used=engine_used,
|
||||
processing_time_ms=processing_time_ms,
|
||||
file_size_bytes=file_size_bytes,
|
||||
file_type=file_type,
|
||||
original_filename=original_filename,
|
||||
success=success,
|
||||
error_message=error_message,
|
||||
overall_confidence=overall_confidence,
|
||||
fields_extracted=fields_extracted,
|
||||
needs_manual_review=needs_manual_review,
|
||||
validation_warnings_count=validation_warnings_count,
|
||||
validation_errors_count=validation_errors_count,
|
||||
)
|
||||
session.add(metrics)
|
||||
await session.commit()
|
||||
await session.refresh(metrics)
|
||||
return metrics
|
||||
|
||||
@staticmethod
|
||||
async def get_by_job_id(session: AsyncSession, job_id: str) -> Optional[OCRJobMetrics]:
|
||||
"""Get metrics by job ID."""
|
||||
result = await session.execute(
|
||||
select(OCRJobMetrics).where(OCRJobMetrics.job_id == job_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
async def get_user_history(
|
||||
session: AsyncSession,
|
||||
username: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
) -> List[OCRJobMetrics]:
|
||||
"""Get user's OCR job history."""
|
||||
result = await session.execute(
|
||||
select(OCRJobMetrics)
|
||||
.where(OCRJobMetrics.username == username)
|
||||
.order_by(OCRJobMetrics.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@staticmethod
|
||||
async def get_summary_by_engine(
|
||||
session: AsyncSession,
|
||||
days: int = 30,
|
||||
username: Optional[str] = None
|
||||
) -> List[OCRMetricsSummary]:
|
||||
"""Get summary metrics grouped by engine."""
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
# Build query
|
||||
conditions = [OCRJobMetrics.created_at >= cutoff_date]
|
||||
if username:
|
||||
conditions.append(OCRJobMetrics.username == username)
|
||||
|
||||
# Query for aggregated metrics
|
||||
result = await session.execute(
|
||||
select(
|
||||
OCRJobMetrics.engine_used,
|
||||
func.count(OCRJobMetrics.id).label('total_jobs'),
|
||||
func.sum(func.cast(OCRJobMetrics.success, sa.Integer)).label('successful_jobs'),
|
||||
func.avg(OCRJobMetrics.processing_time_ms).label('avg_processing_time_ms'),
|
||||
func.avg(OCRJobMetrics.overall_confidence).label('avg_confidence'),
|
||||
func.avg(OCRJobMetrics.fields_extracted).label('avg_fields_extracted'),
|
||||
)
|
||||
.where(and_(*conditions))
|
||||
.group_by(OCRJobMetrics.engine_used)
|
||||
.order_by(func.count(OCRJobMetrics.id).desc())
|
||||
)
|
||||
|
||||
summaries = []
|
||||
for row in result.all():
|
||||
total = row.total_jobs or 0
|
||||
successful = row.successful_jobs or 0
|
||||
success_rate = successful / total if total > 0 else 0.0
|
||||
summaries.append(OCRMetricsSummary(
|
||||
engine=row.engine_used,
|
||||
total_jobs=total,
|
||||
successful_jobs=successful,
|
||||
failed_jobs=total - successful,
|
||||
success_rate=success_rate,
|
||||
avg_processing_time_ms=float(row.avg_processing_time_ms or 0),
|
||||
avg_confidence=float(row.avg_confidence or 0),
|
||||
avg_fields_extracted=float(row.avg_fields_extracted or 0),
|
||||
))
|
||||
|
||||
return summaries
|
||||
|
||||
@staticmethod
|
||||
async def get_overall_stats(
|
||||
session: AsyncSession,
|
||||
days: int = 30,
|
||||
username: Optional[str] = None
|
||||
) -> dict:
|
||||
"""Get overall OCR statistics."""
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
conditions = [OCRJobMetrics.created_at >= cutoff_date]
|
||||
if username:
|
||||
conditions.append(OCRJobMetrics.username == username)
|
||||
|
||||
result = await session.execute(
|
||||
select(
|
||||
func.count(OCRJobMetrics.id).label('total_jobs'),
|
||||
func.sum(func.cast(OCRJobMetrics.success, sa.Integer)).label('successful_jobs'),
|
||||
func.avg(OCRJobMetrics.processing_time_ms).label('avg_processing_time_ms'),
|
||||
func.avg(OCRJobMetrics.overall_confidence).label('avg_confidence'),
|
||||
)
|
||||
.where(and_(*conditions))
|
||||
)
|
||||
|
||||
row = result.one()
|
||||
total = row.total_jobs or 0
|
||||
successful = row.successful_jobs or 0
|
||||
|
||||
return {
|
||||
"total_jobs": total,
|
||||
"successful_jobs": successful,
|
||||
"failed_jobs": total - successful,
|
||||
"success_rate": (successful / total * 100) if total > 0 else 0.0,
|
||||
"avg_processing_time_ms": float(row.avg_processing_time_ms or 0),
|
||||
"avg_confidence": float(row.avg_confidence or 0),
|
||||
"period_days": days,
|
||||
}
|
||||
|
||||
|
||||
# Import sqlalchemy for func.cast
|
||||
import sqlalchemy as sa
|
||||
@@ -0,0 +1,418 @@
|
||||
"""CRUD operations for receipts."""
|
||||
|
||||
import json
|
||||
from datetime import datetime, date
|
||||
from decimal import Decimal
|
||||
from typing import Optional, List, Tuple, Any
|
||||
from sqlalchemy import select, func, or_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from backend.modules.data_entry.db.models.receipt import Receipt, ReceiptStatus
|
||||
from backend.modules.data_entry.schemas.receipt import ReceiptCreate, ReceiptUpdate, ReceiptFilter
|
||||
|
||||
|
||||
def _serialize_tva_breakdown(tva_breakdown: Optional[List[Any]]) -> Optional[str]:
|
||||
"""Serialize TVA breakdown list to JSON string for SQLite storage."""
|
||||
if tva_breakdown is None:
|
||||
return None
|
||||
|
||||
# Convert Decimal to float for JSON serialization
|
||||
serializable = []
|
||||
for entry in tva_breakdown:
|
||||
if hasattr(entry, 'model_dump'):
|
||||
# Pydantic model
|
||||
item = entry.model_dump()
|
||||
elif isinstance(entry, dict):
|
||||
item = entry.copy()
|
||||
else:
|
||||
item = dict(entry)
|
||||
|
||||
# Convert Decimal to float
|
||||
if 'amount' in item and isinstance(item['amount'], Decimal):
|
||||
item['amount'] = float(item['amount'])
|
||||
|
||||
serializable.append(item)
|
||||
|
||||
return json.dumps(serializable)
|
||||
|
||||
|
||||
def _serialize_payment_methods(payment_methods: Optional[List[Any]]) -> Optional[str]:
|
||||
"""Serialize payment methods list to JSON string for SQLite storage."""
|
||||
if payment_methods is None:
|
||||
return None
|
||||
|
||||
serializable = []
|
||||
for pm in payment_methods:
|
||||
if hasattr(pm, 'model_dump'):
|
||||
item = pm.model_dump()
|
||||
elif isinstance(pm, dict):
|
||||
item = pm.copy()
|
||||
else:
|
||||
item = dict(pm)
|
||||
|
||||
# Convert Decimal to float for JSON
|
||||
if 'amount' in item:
|
||||
if hasattr(item['amount'], '__float__'):
|
||||
item['amount'] = float(item['amount'])
|
||||
|
||||
serializable.append(item)
|
||||
|
||||
return json.dumps(serializable)
|
||||
|
||||
|
||||
class ReceiptCRUD:
|
||||
"""CRUD operations for Receipt model."""
|
||||
|
||||
@staticmethod
|
||||
async def create(
|
||||
session: AsyncSession,
|
||||
data: ReceiptCreate,
|
||||
created_by: str,
|
||||
) -> Receipt:
|
||||
"""Create a new receipt."""
|
||||
# Get data as dict and serialize tva_breakdown and payment_methods to JSON string
|
||||
receipt_data = data.model_dump()
|
||||
receipt_data['tva_breakdown'] = _serialize_tva_breakdown(receipt_data.get('tva_breakdown'))
|
||||
receipt_data['payment_methods'] = _serialize_payment_methods(receipt_data.get('payment_methods'))
|
||||
|
||||
receipt = Receipt(
|
||||
**receipt_data,
|
||||
created_by=created_by,
|
||||
status=ReceiptStatus.DRAFT,
|
||||
)
|
||||
session.add(receipt)
|
||||
await session.commit()
|
||||
await session.refresh(receipt)
|
||||
|
||||
# Reload with relationships to avoid lazy loading issues with async
|
||||
return await ReceiptCRUD.get_by_id(session, receipt.id, include_relations=True)
|
||||
|
||||
@staticmethod
|
||||
async def get_by_id(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
include_relations: bool = True,
|
||||
) -> Optional[Receipt]:
|
||||
"""Get receipt by ID, optionally with relationships."""
|
||||
query = select(Receipt).where(Receipt.id == receipt_id)
|
||||
|
||||
if include_relations:
|
||||
query = query.options(
|
||||
selectinload(Receipt.attachments),
|
||||
selectinload(Receipt.entries),
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
async def get_list(
|
||||
session: AsyncSession,
|
||||
filters: ReceiptFilter,
|
||||
) -> Tuple[List[Receipt], int]:
|
||||
"""Get paginated list of receipts with filters."""
|
||||
# Base query
|
||||
query = select(Receipt).options(
|
||||
selectinload(Receipt.attachments),
|
||||
selectinload(Receipt.entries),
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if filters.status:
|
||||
query = query.where(Receipt.status == filters.status)
|
||||
|
||||
if filters.direction:
|
||||
query = query.where(Receipt.direction == filters.direction)
|
||||
|
||||
if filters.company_id:
|
||||
query = query.where(Receipt.company_id == filters.company_id)
|
||||
|
||||
if filters.created_by:
|
||||
query = query.where(Receipt.created_by == filters.created_by)
|
||||
|
||||
if filters.date_from:
|
||||
query = query.where(Receipt.receipt_date >= filters.date_from)
|
||||
|
||||
if filters.date_to:
|
||||
query = query.where(Receipt.receipt_date <= filters.date_to)
|
||||
|
||||
if filters.search:
|
||||
search_term = f"%{filters.search}%"
|
||||
query = query.where(
|
||||
or_(
|
||||
Receipt.description.ilike(search_term),
|
||||
Receipt.partner_name.ilike(search_term),
|
||||
Receipt.receipt_number.ilike(search_term),
|
||||
)
|
||||
)
|
||||
|
||||
# Bulk upload filters (US-012)
|
||||
# US-005: Support comma-separated values for processing_status filter (e.g., "pending,processing")
|
||||
if filters.processing_status:
|
||||
statuses = [s.strip() for s in filters.processing_status.split(",")]
|
||||
if len(statuses) == 1:
|
||||
query = query.where(Receipt.processing_status == statuses[0])
|
||||
else:
|
||||
query = query.where(Receipt.processing_status.in_(statuses))
|
||||
|
||||
if filters.batch_id:
|
||||
query = query.where(Receipt.batch_id == filters.batch_id)
|
||||
|
||||
# Count total
|
||||
count_query = select(func.count()).select_from(query.subquery())
|
||||
total_result = await session.execute(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
# Apply ordering based on sort_by parameter (US-012)
|
||||
if filters.sort_by == "processing_started_at":
|
||||
query = query.order_by(Receipt.processing_started_at.desc())
|
||||
elif filters.sort_by == "processing_started_at_asc":
|
||||
query = query.order_by(Receipt.processing_started_at.asc())
|
||||
else:
|
||||
# Default ordering
|
||||
query = query.order_by(Receipt.created_at.desc())
|
||||
|
||||
# Apply pagination
|
||||
offset = (filters.page - 1) * filters.page_size
|
||||
query = query.offset(offset).limit(filters.page_size)
|
||||
|
||||
# Execute
|
||||
result = await session.execute(query)
|
||||
receipts = result.scalars().all()
|
||||
|
||||
return list(receipts), total
|
||||
|
||||
@staticmethod
|
||||
async def get_processing_stats(
|
||||
session: AsyncSession,
|
||||
company_id: Optional[int] = None,
|
||||
batch_id: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Get processing status counts for bulk uploaded receipts (US-012)."""
|
||||
# Build base query for counting by processing_status
|
||||
base_conditions = []
|
||||
|
||||
if company_id:
|
||||
base_conditions.append(Receipt.company_id == company_id)
|
||||
|
||||
if batch_id:
|
||||
base_conditions.append(Receipt.batch_id == batch_id)
|
||||
|
||||
# Only count receipts that have a processing_status (bulk uploads)
|
||||
base_conditions.append(Receipt.processing_status.isnot(None))
|
||||
|
||||
query = select(
|
||||
Receipt.processing_status,
|
||||
func.count(Receipt.id).label("count")
|
||||
)
|
||||
|
||||
for condition in base_conditions:
|
||||
query = query.where(condition)
|
||||
|
||||
query = query.group_by(Receipt.processing_status)
|
||||
|
||||
result = await session.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
# Initialize stats
|
||||
stats = {
|
||||
"pending_count": 0,
|
||||
"processing_count": 0,
|
||||
"completed_count": 0,
|
||||
"failed_count": 0,
|
||||
}
|
||||
|
||||
# Map results
|
||||
for row in rows:
|
||||
status = row.processing_status
|
||||
count = row.count
|
||||
if status == "pending":
|
||||
stats["pending_count"] = count
|
||||
elif status == "processing":
|
||||
stats["processing_count"] = count
|
||||
elif status == "completed":
|
||||
stats["completed_count"] = count
|
||||
elif status == "failed":
|
||||
stats["failed_count"] = count
|
||||
|
||||
return stats
|
||||
|
||||
@staticmethod
|
||||
async def get_pending_review(
|
||||
session: AsyncSession,
|
||||
company_id: Optional[int] = None,
|
||||
) -> List[Receipt]:
|
||||
"""Get all receipts pending review."""
|
||||
query = select(Receipt).where(
|
||||
Receipt.status == ReceiptStatus.PENDING_REVIEW
|
||||
).options(
|
||||
selectinload(Receipt.attachments),
|
||||
selectinload(Receipt.entries),
|
||||
)
|
||||
|
||||
if company_id:
|
||||
query = query.where(Receipt.company_id == company_id)
|
||||
|
||||
query = query.order_by(Receipt.submitted_at.asc())
|
||||
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@staticmethod
|
||||
async def update(
|
||||
session: AsyncSession,
|
||||
receipt: Receipt,
|
||||
data: ReceiptUpdate,
|
||||
) -> Receipt:
|
||||
"""Update receipt fields.
|
||||
|
||||
US-407: When a receipt is manually updated, reset processing_status and
|
||||
processing_error to NULL. This allows failed OCR receipts to be corrected
|
||||
manually and then submitted for approval without showing as "error" status.
|
||||
"""
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
|
||||
# Recalculate tva_total from tva_breakdown if breakdown is being updated
|
||||
if 'tva_breakdown' in update_data and update_data['tva_breakdown']:
|
||||
tva_total = sum(
|
||||
float(entry.get('amount', 0) if isinstance(entry, dict) else getattr(entry, 'amount', 0))
|
||||
for entry in update_data['tva_breakdown']
|
||||
)
|
||||
update_data['tva_total'] = round(tva_total, 2)
|
||||
|
||||
# Serialize tva_breakdown and payment_methods to JSON string if present
|
||||
if 'tva_breakdown' in update_data:
|
||||
update_data['tva_breakdown'] = _serialize_tva_breakdown(update_data['tva_breakdown'])
|
||||
if 'payment_methods' in update_data:
|
||||
update_data['payment_methods'] = _serialize_payment_methods(update_data['payment_methods'])
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(receipt, field, value)
|
||||
|
||||
# US-407: Reset processing status when receipt is manually edited
|
||||
# This clears the "failed" status so edited receipts can be submitted for approval
|
||||
if receipt.processing_status == 'failed':
|
||||
receipt.processing_status = None
|
||||
receipt.processing_error = None
|
||||
|
||||
receipt.updated_at = datetime.utcnow()
|
||||
|
||||
session.add(receipt)
|
||||
await session.commit()
|
||||
await session.refresh(receipt)
|
||||
|
||||
# Reload with relationships to avoid lazy loading issues with async
|
||||
return await ReceiptCRUD.get_by_id(session, receipt.id, include_relations=True)
|
||||
|
||||
@staticmethod
|
||||
async def update_status(
|
||||
session: AsyncSession,
|
||||
receipt: Receipt,
|
||||
new_status: ReceiptStatus,
|
||||
reviewed_by: Optional[str] = None,
|
||||
rejection_reason: Optional[str] = None,
|
||||
) -> Receipt:
|
||||
"""Update receipt workflow status."""
|
||||
receipt.status = new_status
|
||||
receipt.updated_at = datetime.utcnow()
|
||||
|
||||
if new_status == ReceiptStatus.PENDING_REVIEW:
|
||||
receipt.submitted_at = datetime.utcnow()
|
||||
|
||||
if new_status in [ReceiptStatus.APPROVED, ReceiptStatus.REJECTED]:
|
||||
receipt.reviewed_by = reviewed_by
|
||||
receipt.reviewed_at = datetime.utcnow()
|
||||
|
||||
if new_status == ReceiptStatus.REJECTED:
|
||||
receipt.rejection_reason = rejection_reason
|
||||
|
||||
if new_status == ReceiptStatus.DRAFT:
|
||||
# Reset review fields when moving back to draft
|
||||
receipt.rejection_reason = None
|
||||
|
||||
session.add(receipt)
|
||||
await session.commit()
|
||||
await session.refresh(receipt)
|
||||
|
||||
# Reload with relationships to avoid lazy loading issues with async
|
||||
return await ReceiptCRUD.get_by_id(session, receipt.id, include_relations=True)
|
||||
|
||||
@staticmethod
|
||||
async def delete(session: AsyncSession, receipt: Receipt) -> bool:
|
||||
"""Delete a receipt (cascade deletes attachments and entries)."""
|
||||
await session.delete(receipt)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def can_edit(receipt: Receipt, username: str) -> bool:
|
||||
"""Check if user can edit receipt."""
|
||||
# DRAFT and REJECTED receipts can be edited (to fix and resubmit)
|
||||
if receipt.status not in [ReceiptStatus.DRAFT, ReceiptStatus.REJECTED]:
|
||||
return False
|
||||
|
||||
# Only creator can edit their own receipts
|
||||
return receipt.created_by == username
|
||||
|
||||
@staticmethod
|
||||
async def can_delete(receipt: Receipt, username: str) -> bool:
|
||||
"""Check if user can delete receipt."""
|
||||
# Only DRAFT receipts can be deleted
|
||||
if receipt.status != ReceiptStatus.DRAFT:
|
||||
return False
|
||||
|
||||
# Only creator can delete their own drafts
|
||||
return receipt.created_by == username
|
||||
|
||||
@staticmethod
|
||||
async def can_submit(receipt: Receipt, username: str) -> bool:
|
||||
"""Check if user can submit receipt for review."""
|
||||
# Only DRAFT or REJECTED receipts can be submitted
|
||||
if receipt.status not in [ReceiptStatus.DRAFT, ReceiptStatus.REJECTED]:
|
||||
return False
|
||||
|
||||
# Only creator can submit their own receipts
|
||||
return receipt.created_by == username
|
||||
|
||||
@staticmethod
|
||||
async def get_stats(
|
||||
session: AsyncSession,
|
||||
company_id: int,
|
||||
created_by: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Get receipt statistics."""
|
||||
base_query = select(
|
||||
Receipt.status,
|
||||
func.count(Receipt.id).label("count"),
|
||||
func.sum(Receipt.amount).label("total_amount"),
|
||||
).where(
|
||||
Receipt.company_id == company_id
|
||||
)
|
||||
|
||||
if created_by:
|
||||
base_query = base_query.where(Receipt.created_by == created_by)
|
||||
|
||||
query = base_query.group_by(Receipt.status)
|
||||
result = await session.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
stats = {
|
||||
"draft": {"count": 0, "amount": 0},
|
||||
"pending_review": {"count": 0, "amount": 0},
|
||||
"approved": {"count": 0, "amount": 0},
|
||||
"rejected": {"count": 0, "amount": 0},
|
||||
"synced": {"count": 0, "amount": 0},
|
||||
"total": {"count": 0, "amount": 0},
|
||||
}
|
||||
|
||||
for row in rows:
|
||||
status_key = row.status.value
|
||||
stats[status_key] = {
|
||||
"count": row.count,
|
||||
"amount": float(row.total_amount or 0),
|
||||
}
|
||||
stats["total"]["count"] += row.count
|
||||
stats["total"]["amount"] += float(row.total_amount or 0)
|
||||
|
||||
return stats
|
||||
@@ -0,0 +1,50 @@
|
||||
"""Database configuration and session management using SQLModel."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from backend.config import settings
|
||||
|
||||
|
||||
# Create async engine
|
||||
# Note: echo=False to disable SQL query logging (too verbose)
|
||||
engine = create_async_engine(
|
||||
settings.data_entry_database_url,
|
||||
echo=False,
|
||||
future=True,
|
||||
)
|
||||
|
||||
# Create async session factory
|
||||
async_session_maker = sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
|
||||
async def init_db() -> None:
|
||||
"""Initialize database - create tables if they don't exist."""
|
||||
# Ensure data directory exists
|
||||
db_path = Path(settings.data_entry_sqlite_database_path)
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Get async database session for dependency injection."""
|
||||
async with async_session_maker() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
# Convenience function for manual session usage
|
||||
async def get_db_session() -> AsyncSession:
|
||||
"""Get a new database session (manual management)."""
|
||||
return async_session_maker()
|
||||
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
Alembic migrations helper for Data Entry module.
|
||||
|
||||
Provides automatic migration execution at backend startup.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_migrations() -> bool:
|
||||
"""
|
||||
Run pending Alembic migrations at startup.
|
||||
|
||||
Returns:
|
||||
True if migrations ran successfully (or no pending migrations),
|
||||
False if migrations failed (backend should continue with WARNING).
|
||||
"""
|
||||
try:
|
||||
from alembic.config import Config
|
||||
from alembic import command
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
# Get the path to alembic.ini
|
||||
data_entry_module = Path(__file__).parent.parent
|
||||
alembic_ini_path = data_entry_module / "alembic.ini"
|
||||
|
||||
if not alembic_ini_path.exists():
|
||||
logger.warning(f"[MIGRATIONS] alembic.ini not found at {alembic_ini_path}")
|
||||
return False
|
||||
|
||||
# Get database path from environment or default
|
||||
db_path = Path(os.getenv(
|
||||
"SQLITE_DATABASE_PATH",
|
||||
"data/receipts/receipts.db"
|
||||
)).resolve()
|
||||
|
||||
# Ensure database directory exists
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create Alembic config
|
||||
alembic_cfg = Config(str(alembic_ini_path))
|
||||
|
||||
# Override database URL
|
||||
sync_db_url = f"sqlite:///{db_path}"
|
||||
alembic_cfg.set_main_option("sqlalchemy.url", sync_db_url)
|
||||
|
||||
# Set script location relative to alembic.ini
|
||||
alembic_cfg.set_main_option(
|
||||
"script_location",
|
||||
str(data_entry_module / "migrations")
|
||||
)
|
||||
|
||||
# Get current revision before upgrade
|
||||
engine = create_engine(sync_db_url)
|
||||
with engine.connect() as connection:
|
||||
context = MigrationContext.configure(connection)
|
||||
current_rev = context.get_current_revision()
|
||||
engine.dispose()
|
||||
|
||||
logger.info(f"[MIGRATIONS] Current revision: {current_rev or 'None (fresh database)'}")
|
||||
logger.info(f"[MIGRATIONS] Database path: {db_path}")
|
||||
|
||||
# Run upgrade to head
|
||||
logger.info("[MIGRATIONS] Checking for pending migrations...")
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
|
||||
# Get new revision after upgrade
|
||||
engine = create_engine(sync_db_url)
|
||||
with engine.connect() as connection:
|
||||
context = MigrationContext.configure(connection)
|
||||
new_rev = context.get_current_revision()
|
||||
engine.dispose()
|
||||
|
||||
if current_rev != new_rev:
|
||||
logger.info(f"[MIGRATIONS] Applied: {current_rev or 'None'} -> {new_rev}")
|
||||
else:
|
||||
logger.info(f"[MIGRATIONS] No pending migrations. Current: {new_rev}")
|
||||
|
||||
return True
|
||||
|
||||
except ImportError as e:
|
||||
logger.warning(f"[MIGRATIONS] Alembic not installed: {e}")
|
||||
logger.warning("[MIGRATIONS] Skipping migrations - install alembic to enable")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[MIGRATIONS] Migration error: {e}", exc_info=True)
|
||||
logger.warning("[MIGRATIONS] Backend will continue without migrations")
|
||||
return False
|
||||
|
||||
|
||||
def get_current_revision() -> str:
|
||||
"""
|
||||
Get the current Alembic revision.
|
||||
|
||||
Returns:
|
||||
Current revision string, or 'unknown' if cannot be determined.
|
||||
"""
|
||||
try:
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
# Get database path from environment or default
|
||||
db_path = Path(os.getenv(
|
||||
"SQLITE_DATABASE_PATH",
|
||||
"data/receipts/receipts.db"
|
||||
)).resolve()
|
||||
|
||||
if not db_path.exists():
|
||||
return "no_database"
|
||||
|
||||
sync_db_url = f"sqlite:///{db_path}"
|
||||
engine = create_engine(sync_db_url)
|
||||
|
||||
with engine.connect() as connection:
|
||||
context = MigrationContext.configure(connection)
|
||||
revision = context.get_current_revision()
|
||||
|
||||
engine.dispose()
|
||||
return revision or "none"
|
||||
|
||||
except ImportError:
|
||||
return "alembic_not_installed"
|
||||
except Exception as e:
|
||||
logger.debug(f"[MIGRATIONS] Could not get revision: {e}")
|
||||
return "unknown"
|
||||
@@ -0,0 +1,29 @@
|
||||
# Database models
|
||||
from .receipt import Receipt, ReceiptAttachment, ReceiptStatus, ReceiptType, ReceiptDirection, ProcessingStatus
|
||||
from .accounting_entry import AccountingEntry, EntryType
|
||||
from .nomenclature import SyncedSupplier, LocalSupplier, SyncedCashRegister
|
||||
from .ocr_settings import UserOCRPreference, OCRJobMetrics, OCRMetricsSummary, OCREngine
|
||||
from .batch import BatchUpload, BatchJob, BatchStatus
|
||||
|
||||
__all__ = [
|
||||
"Receipt",
|
||||
"ReceiptAttachment",
|
||||
"ReceiptStatus",
|
||||
"ReceiptType",
|
||||
"ReceiptDirection",
|
||||
"ProcessingStatus",
|
||||
"AccountingEntry",
|
||||
"EntryType",
|
||||
"SyncedSupplier",
|
||||
"LocalSupplier",
|
||||
"SyncedCashRegister",
|
||||
# OCR Settings & Metrics
|
||||
"UserOCRPreference",
|
||||
"OCRJobMetrics",
|
||||
"OCRMetricsSummary",
|
||||
"OCREngine",
|
||||
# Batch Upload
|
||||
"BatchUpload",
|
||||
"BatchJob",
|
||||
"BatchStatus",
|
||||
]
|
||||
@@ -0,0 +1,49 @@
|
||||
"""AccountingEntry SQLModel model for proposed accounting entries."""
|
||||
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
from sqlmodel import SQLModel, Field, Relationship
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .receipt import Receipt
|
||||
|
||||
|
||||
class EntryType(str, Enum):
|
||||
"""Type of accounting entry."""
|
||||
DEBIT = "debit"
|
||||
CREDIT = "credit"
|
||||
|
||||
|
||||
class AccountingEntry(SQLModel, table=True):
|
||||
"""Proposed accounting entry for a receipt."""
|
||||
|
||||
__tablename__ = "accounting_entries"
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
receipt_id: int = Field(foreign_key="receipts.id", index=True)
|
||||
|
||||
# Account
|
||||
entry_type: EntryType
|
||||
account_code: str = Field(max_length=20) # e.g., 6022, 5311, 4426
|
||||
account_name: Optional[str] = Field(default=None, max_length=200) # Cache: "Cheltuieli combustibil"
|
||||
|
||||
# Amount
|
||||
amount: Decimal = Field(decimal_places=2, max_digits=15)
|
||||
|
||||
# Analytics (optional)
|
||||
partner_id: Optional[int] = Field(default=None)
|
||||
cost_center_id: Optional[int] = Field(default=None)
|
||||
|
||||
# Entry metadata
|
||||
is_auto_generated: bool = Field(default=True) # True if system-generated
|
||||
modified_by: Optional[str] = Field(default=None, max_length=100) # Username if modified
|
||||
modified_at: Optional[datetime] = Field(default=None)
|
||||
|
||||
# Order for display
|
||||
sort_order: int = Field(default=0)
|
||||
|
||||
# Relationship
|
||||
receipt: Optional["Receipt"] = Relationship(back_populates="entries")
|
||||
@@ -0,0 +1,64 @@
|
||||
"""BatchUpload and BatchJob SQLModel models for bulk receipt processing."""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from sqlmodel import SQLModel, Field
|
||||
|
||||
|
||||
class BatchStatus(str, Enum):
|
||||
"""Status of a batch upload."""
|
||||
PENDING = "pending" # Batch created, jobs queued
|
||||
PROCESSING = "processing" # At least one job is processing
|
||||
COMPLETED = "completed" # All jobs completed (success or failed)
|
||||
FAILED = "failed" # Batch-level failure (e.g., all jobs failed)
|
||||
|
||||
|
||||
class BatchUpload(SQLModel, table=True):
|
||||
"""
|
||||
Batch upload record for grouping multiple OCR jobs.
|
||||
|
||||
Tracks overall progress and status of a bulk upload operation.
|
||||
"""
|
||||
|
||||
__tablename__ = "batch_uploads"
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
|
||||
# User info
|
||||
user_id: str = Field(max_length=100, index=True) # Username who created the batch
|
||||
company_id: int = Field(index=True) # Company ID for receipt creation
|
||||
|
||||
# Timestamps
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
# Status tracking
|
||||
status: BatchStatus = Field(default=BatchStatus.PENDING)
|
||||
total_files: int = Field(default=0)
|
||||
|
||||
|
||||
class BatchJob(SQLModel, table=True):
|
||||
"""
|
||||
Junction table linking batch_uploads to ocr_jobs.
|
||||
|
||||
Each record represents one file in a batch, linking to its OCR job.
|
||||
Also stores the receipt_id once the job completes and auto-creates a receipt.
|
||||
"""
|
||||
|
||||
__tablename__ = "batch_jobs"
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
|
||||
# Foreign keys
|
||||
batch_id: int = Field(foreign_key="batch_uploads.id", index=True)
|
||||
job_id: str = Field(max_length=36, index=True) # UUID from ocr_jobs table
|
||||
|
||||
# Original filename for display
|
||||
filename: str = Field(max_length=255)
|
||||
|
||||
# Receipt reference (set after auto-create)
|
||||
receipt_id: Optional[int] = Field(default=None, foreign_key="receipts.id")
|
||||
|
||||
# Timestamps
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
@@ -0,0 +1,46 @@
|
||||
"""Nomenclature models for synced and local data."""
|
||||
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from sqlmodel import SQLModel, Field
|
||||
|
||||
|
||||
class SyncedSupplier(SQLModel, table=True):
|
||||
"""Suppliers synced from Oracle NOM_PARTENERI."""
|
||||
__tablename__ = "synced_suppliers"
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
oracle_id: int = Field(index=True) # Original Oracle ID
|
||||
company_id: int = Field(index=True) # Company this supplier belongs to
|
||||
name: str = Field(max_length=200)
|
||||
fiscal_code: Optional[str] = Field(default=None, max_length=50, index=True) # CUI/CIF
|
||||
address: Optional[str] = Field(default=None, max_length=500)
|
||||
synced_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class LocalSupplier(SQLModel, table=True):
|
||||
"""Suppliers created locally from OCR (not in Oracle)."""
|
||||
__tablename__ = "local_suppliers"
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
company_id: int = Field(index=True)
|
||||
name: str = Field(max_length=200)
|
||||
fiscal_code: Optional[str] = Field(default=None, max_length=50, index=True)
|
||||
address: Optional[str] = Field(default=None, max_length=500)
|
||||
created_by: str = Field(max_length=100) # Username who created it
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
# Flag to indicate if it should be synced to Oracle later
|
||||
pending_oracle_sync: bool = Field(default=True)
|
||||
|
||||
|
||||
class SyncedCashRegister(SQLModel, table=True):
|
||||
"""Cash registers and bank accounts synced from Oracle."""
|
||||
__tablename__ = "synced_cash_registers"
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
oracle_id: int = Field(index=True)
|
||||
company_id: int = Field(index=True)
|
||||
name: str = Field(max_length=100)
|
||||
account_code: str = Field(max_length=20) # 5311, 5121, etc.
|
||||
register_type: str = Field(max_length=10) # 'cash' or 'bank'
|
||||
synced_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
@@ -0,0 +1,102 @@
|
||||
"""OCR settings and metrics SQLModel models."""
|
||||
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from sqlmodel import SQLModel, Field
|
||||
|
||||
|
||||
class OCREngine(str, Enum):
|
||||
"""Available OCR engines."""
|
||||
TESSERACT = "tesseract"
|
||||
DOCTR = "doctr"
|
||||
DOCTR_PLUS = "doctr_plus" # docTR with 2-tier sequential processing + early exit (optimized, recommended)
|
||||
PADDLEOCR = "paddleocr"
|
||||
|
||||
|
||||
class UserOCRPreference(SQLModel, table=True):
|
||||
"""
|
||||
User's preferred OCR engine setting.
|
||||
|
||||
Each user can have one preferred OCR engine that will be
|
||||
auto-selected when they upload new receipts for processing.
|
||||
"""
|
||||
|
||||
__tablename__ = "user_ocr_preferences"
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
|
||||
# User identification
|
||||
username: str = Field(max_length=100, unique=True, index=True)
|
||||
|
||||
# Preference settings
|
||||
preferred_engine: OCREngine = Field(default=OCREngine.DOCTR_PLUS)
|
||||
|
||||
# Timestamps
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class OCRJobMetrics(SQLModel, table=True):
|
||||
"""
|
||||
OCR job processing metrics for analytics.
|
||||
|
||||
Stores metrics for each OCR job to enable:
|
||||
- Performance tracking by engine
|
||||
- Success rate analysis
|
||||
- Processing time trends
|
||||
- User-specific analytics
|
||||
"""
|
||||
|
||||
__tablename__ = "ocr_job_metrics"
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
|
||||
# Job identification
|
||||
job_id: str = Field(max_length=50, unique=True, index=True)
|
||||
|
||||
# User and company context
|
||||
username: str = Field(max_length=100, index=True)
|
||||
company_id: Optional[int] = Field(default=None, index=True)
|
||||
|
||||
# Engine used
|
||||
engine_requested: str = Field(max_length=20) # What user/auto requested
|
||||
engine_used: str = Field(max_length=50) # What was actually used (e.g., "doctr-light")
|
||||
|
||||
# Processing metrics
|
||||
processing_time_ms: int = Field(default=0)
|
||||
file_size_bytes: int = Field(default=0)
|
||||
file_type: str = Field(max_length=50, default="image/jpeg") # MIME type
|
||||
original_filename: Optional[str] = Field(default=None, max_length=255) # Original uploaded filename
|
||||
|
||||
# Success metrics
|
||||
success: bool = Field(default=True)
|
||||
error_message: Optional[str] = Field(default=None, max_length=500)
|
||||
|
||||
# Extraction quality metrics
|
||||
overall_confidence: float = Field(default=0.0)
|
||||
fields_extracted: int = Field(default=0) # Number of fields successfully extracted
|
||||
needs_manual_review: Optional[bool] = Field(default=None)
|
||||
validation_warnings_count: int = Field(default=0)
|
||||
validation_errors_count: int = Field(default=0)
|
||||
|
||||
# Timestamps
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class OCRMetricsSummary(SQLModel):
|
||||
"""
|
||||
Summary metrics for OCR analytics.
|
||||
|
||||
Not a database table - used for API responses.
|
||||
"""
|
||||
engine: str
|
||||
total_jobs: int
|
||||
successful_jobs: int
|
||||
failed_jobs: int
|
||||
success_rate: float # Computed: successful_jobs / total_jobs
|
||||
avg_processing_time_ms: float
|
||||
avg_confidence: float
|
||||
avg_fields_extracted: float
|
||||
@@ -0,0 +1,143 @@
|
||||
"""Receipt and ReceiptAttachment SQLModel models."""
|
||||
|
||||
from datetime import datetime, date
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Optional, List, TYPE_CHECKING
|
||||
|
||||
from sqlmodel import SQLModel, Field, Relationship
|
||||
|
||||
|
||||
class ReceiptType(str, Enum):
|
||||
"""Type of receipt document."""
|
||||
BON_FISCAL = "bon_fiscal"
|
||||
CHITANTA = "chitanta"
|
||||
|
||||
|
||||
class ReceiptDirection(str, Enum):
|
||||
"""Direction of receipt - expense or income."""
|
||||
CHELTUIALA = "cheltuiala" # Expense (receipt from supplier)
|
||||
INCASARE = "incasare" # Income (receipt issued to client)
|
||||
|
||||
|
||||
class ReceiptStatus(str, Enum):
|
||||
"""Workflow status of receipt."""
|
||||
DRAFT = "draft" # User is filling in data
|
||||
PENDING_REVIEW = "pending_review" # Awaiting accountant approval
|
||||
APPROVED = "approved" # Approved by accountant
|
||||
REJECTED = "rejected" # Rejected by accountant
|
||||
SYNCED = "synced" # Synced to Oracle (Phase 2)
|
||||
|
||||
|
||||
class PaymentMode(str, Enum):
|
||||
"""Payment mode - how the expense was paid."""
|
||||
CASA = "casa" # Numerar firma (5311)
|
||||
BANCA = "banca" # Virament/POS (5121)
|
||||
AVANS_DECONTARE = "avans_decontare" # Decont angajat (542)
|
||||
|
||||
|
||||
class ProcessingStatus(str, Enum):
|
||||
"""Processing status for bulk uploaded receipts."""
|
||||
PENDING = "pending" # Waiting in queue
|
||||
PROCESSING = "processing" # Currently being processed by OCR
|
||||
COMPLETED = "completed" # Successfully processed
|
||||
FAILED = "failed" # Processing failed with error
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .accounting_entry import AccountingEntry
|
||||
|
||||
|
||||
class Receipt(SQLModel, table=True):
|
||||
"""Receipt (Bon Fiscal / Chitanta) with approval workflow."""
|
||||
|
||||
__tablename__ = "receipts"
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
|
||||
# Document identification
|
||||
receipt_type: ReceiptType = Field(default=ReceiptType.BON_FISCAL)
|
||||
direction: ReceiptDirection = Field(default=ReceiptDirection.CHELTUIALA)
|
||||
receipt_number: Optional[str] = Field(default=None, max_length=50)
|
||||
receipt_series: Optional[str] = Field(default=None, max_length=20)
|
||||
|
||||
# Main data
|
||||
receipt_date: date
|
||||
amount: Decimal = Field(decimal_places=2, max_digits=15)
|
||||
description: Optional[str] = Field(default=None, max_length=500)
|
||||
|
||||
# TVA info (extracted from OCR) - stored as JSON for multiple entries
|
||||
tva_breakdown: Optional[str] = Field(default=None, max_length=1000) # JSON: [{"code":"A","percent":19,"amount":"15.20"}]
|
||||
tva_total: Optional[Decimal] = Field(default=None, decimal_places=2, max_digits=15)
|
||||
items_count: Optional[int] = Field(default=None)
|
||||
vendor_address: Optional[str] = Field(default=None, max_length=500)
|
||||
|
||||
# Expense type (for auto-generating accounting entries)
|
||||
expense_type_code: Optional[str] = Field(default=None, max_length=20)
|
||||
|
||||
# Oracle references (nomenclatures)
|
||||
company_id: int
|
||||
# partner_id removed - supplier data is text-only (partner_name, cui)
|
||||
partner_name: Optional[str] = Field(default=None, max_length=200) # Supplier name from OCR/selection
|
||||
cui: Optional[str] = Field(default=None, max_length=20) # Fiscal code from OCR
|
||||
ocr_raw_text: Optional[str] = Field(default=None) # Raw OCR text for debugging
|
||||
payment_methods: Optional[str] = Field(default=None, max_length=500) # JSON: [{"method":"CARD","amount":"50.00"}]
|
||||
cash_register_id: Optional[int] = Field(default=None) # Cash/Bank ID from Oracle
|
||||
cash_register_name: Optional[str] = Field(default=None, max_length=100) # Cache for display
|
||||
cash_register_account: Optional[str] = Field(default=None, max_length=20) # Account code (5311, 5121)
|
||||
payment_mode: Optional[str] = Field(default=None, max_length=20) # PaymentMode value: casa/banca/avans_decontare
|
||||
|
||||
# Workflow
|
||||
status: ReceiptStatus = Field(default=ReceiptStatus.DRAFT)
|
||||
created_by: str = Field(max_length=100) # Username of creator
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
submitted_at: Optional[datetime] = Field(default=None) # When submitted for approval
|
||||
|
||||
# Approval
|
||||
reviewed_by: Optional[str] = Field(default=None, max_length=100) # Accountant username
|
||||
reviewed_at: Optional[datetime] = Field(default=None)
|
||||
rejection_reason: Optional[str] = Field(default=None, max_length=500) # Reason for rejection
|
||||
|
||||
# Phase 2 - Oracle sync
|
||||
oracle_synced_at: Optional[datetime] = Field(default=None)
|
||||
oracle_act_id: Optional[int] = Field(default=None)
|
||||
oracle_error: Optional[str] = Field(default=None, max_length=500)
|
||||
|
||||
# Bulk upload batch tracking
|
||||
batch_id: Optional[str] = Field(default=None, max_length=50, index=True)
|
||||
processing_status: Optional[str] = Field(default=None, max_length=20, index=True) # ProcessingStatus enum value
|
||||
processing_error: Optional[str] = Field(default=None) # Full error message text
|
||||
file_hash: Optional[str] = Field(default=None, max_length=64, index=True) # SHA-256 hash for duplicate detection
|
||||
processing_started_at: Optional[datetime] = Field(default=None)
|
||||
processing_completed_at: Optional[datetime] = Field(default=None)
|
||||
|
||||
# Relationships
|
||||
attachments: List["ReceiptAttachment"] = Relationship(
|
||||
back_populates="receipt",
|
||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||
)
|
||||
entries: List["AccountingEntry"] = Relationship(
|
||||
back_populates="receipt",
|
||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||
)
|
||||
|
||||
|
||||
class ReceiptAttachment(SQLModel, table=True):
|
||||
"""Attachment (photo or PDF) for a receipt."""
|
||||
|
||||
__tablename__ = "receipt_attachments"
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
receipt_id: int = Field(foreign_key="receipts.id", index=True)
|
||||
|
||||
# File info
|
||||
filename: str = Field(max_length=255) # Original filename
|
||||
stored_filename: str = Field(max_length=255) # Filename on disk (UUID)
|
||||
file_path: str = Field(max_length=500) # Relative path
|
||||
file_size: int # Size in bytes
|
||||
mime_type: str = Field(max_length=100) # MIME type (image/jpeg, application/pdf)
|
||||
uploaded_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
# Relationship
|
||||
receipt: Optional[Receipt] = Relationship(back_populates="attachments")
|
||||
@@ -0,0 +1,92 @@
|
||||
"""Alembic environment configuration."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from logging.config import fileConfig
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
|
||||
from alembic import context
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
# Import all models to ensure they're registered with SQLModel
|
||||
from backend.modules.data_entry.db.models.receipt import Receipt, ReceiptAttachment
|
||||
from backend.modules.data_entry.db.models.accounting_entry import AccountingEntry
|
||||
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier, SyncedCashRegister
|
||||
from backend.modules.data_entry.db.models.ocr_settings import UserOCRPreference, OCRJobMetrics
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
# Override sqlalchemy.url from environment variable if set
|
||||
# Resolve to absolute path for Windows/IIS compatibility
|
||||
db_path = Path(os.getenv("SQLITE_DATABASE_PATH", "data/receipts/receipts.db")).resolve()
|
||||
config.set_main_option("sqlalchemy.url", f"sqlite:///{db_path}")
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
target_metadata = SQLModel.metadata
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
"""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
render_as_batch=True, # Required for SQLite ALTER TABLE support
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode.
|
||||
|
||||
In this scenario we need to create an Engine
|
||||
and associate a connection with the context.
|
||||
"""
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
render_as_batch=True, # Required for SQLite ALTER TABLE support
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
@@ -0,0 +1,27 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
@@ -0,0 +1,112 @@
|
||||
"""Initial receipts schema
|
||||
|
||||
Revision ID: 001_initial
|
||||
Revises:
|
||||
Create Date: 2024-12-11
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '001_initial'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create receipts table
|
||||
op.create_table(
|
||||
'receipts',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('receipt_type', sa.Enum('BON_FISCAL', 'CHITANTA', name='receipttype'), nullable=False),
|
||||
sa.Column('direction', sa.Enum('CHELTUIALA', 'INCASARE', name='receiptdirection'), nullable=False),
|
||||
sa.Column('receipt_number', sa.String(length=50), nullable=True),
|
||||
sa.Column('receipt_series', sa.String(length=20), nullable=True),
|
||||
sa.Column('receipt_date', sa.Date(), nullable=False),
|
||||
sa.Column('amount', sa.Numeric(precision=15, scale=2), nullable=False),
|
||||
sa.Column('description', sa.String(length=500), nullable=True),
|
||||
sa.Column('expense_type_code', sa.String(length=20), nullable=True),
|
||||
sa.Column('company_id', sa.Integer(), nullable=False),
|
||||
sa.Column('partner_id', sa.Integer(), nullable=True),
|
||||
sa.Column('partner_name', sa.String(length=200), nullable=True),
|
||||
sa.Column('cash_register_id', sa.Integer(), nullable=True),
|
||||
sa.Column('cash_register_name', sa.String(length=100), nullable=True),
|
||||
sa.Column('cash_register_account', sa.String(length=20), nullable=True),
|
||||
sa.Column('status', sa.Enum('DRAFT', 'PENDING_REVIEW', 'APPROVED', 'REJECTED', 'SYNCED', name='receiptstatus'), nullable=False),
|
||||
sa.Column('created_by', sa.String(length=100), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('submitted_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('reviewed_by', sa.String(length=100), nullable=True),
|
||||
sa.Column('reviewed_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('rejection_reason', sa.String(length=500), nullable=True),
|
||||
sa.Column('oracle_synced_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('oracle_act_id', sa.Integer(), nullable=True),
|
||||
sa.Column('oracle_error', sa.String(length=500), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_receipts_company_id'), 'receipts', ['company_id'], unique=False)
|
||||
op.create_index(op.f('ix_receipts_status'), 'receipts', ['status'], unique=False)
|
||||
op.create_index(op.f('ix_receipts_created_by'), 'receipts', ['created_by'], unique=False)
|
||||
op.create_index(op.f('ix_receipts_receipt_date'), 'receipts', ['receipt_date'], unique=False)
|
||||
|
||||
# Create receipt_attachments table
|
||||
op.create_table(
|
||||
'receipt_attachments',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('receipt_id', sa.Integer(), nullable=False),
|
||||
sa.Column('filename', sa.String(length=255), nullable=False),
|
||||
sa.Column('stored_filename', sa.String(length=255), nullable=False),
|
||||
sa.Column('file_path', sa.String(length=500), nullable=False),
|
||||
sa.Column('file_size', sa.Integer(), nullable=False),
|
||||
sa.Column('mime_type', sa.String(length=100), nullable=False),
|
||||
sa.Column('uploaded_at', sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['receipt_id'], ['receipts.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_receipt_attachments_receipt_id'), 'receipt_attachments', ['receipt_id'], unique=False)
|
||||
|
||||
# Create accounting_entries table
|
||||
op.create_table(
|
||||
'accounting_entries',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('receipt_id', sa.Integer(), nullable=False),
|
||||
sa.Column('entry_type', sa.Enum('DEBIT', 'CREDIT', name='entrytype'), nullable=False),
|
||||
sa.Column('account_code', sa.String(length=20), nullable=False),
|
||||
sa.Column('account_name', sa.String(length=200), nullable=True),
|
||||
sa.Column('amount', sa.Numeric(precision=15, scale=2), nullable=False),
|
||||
sa.Column('partner_id', sa.Integer(), nullable=True),
|
||||
sa.Column('cost_center_id', sa.Integer(), nullable=True),
|
||||
sa.Column('is_auto_generated', sa.Boolean(), nullable=False),
|
||||
sa.Column('modified_by', sa.String(length=100), nullable=True),
|
||||
sa.Column('modified_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('sort_order', sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['receipt_id'], ['receipts.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_accounting_entries_receipt_id'), 'accounting_entries', ['receipt_id'], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(op.f('ix_accounting_entries_receipt_id'), table_name='accounting_entries')
|
||||
op.drop_table('accounting_entries')
|
||||
|
||||
op.drop_index(op.f('ix_receipt_attachments_receipt_id'), table_name='receipt_attachments')
|
||||
op.drop_table('receipt_attachments')
|
||||
|
||||
op.drop_index(op.f('ix_receipts_receipt_date'), table_name='receipts')
|
||||
op.drop_index(op.f('ix_receipts_created_by'), table_name='receipts')
|
||||
op.drop_index(op.f('ix_receipts_status'), table_name='receipts')
|
||||
op.drop_index(op.f('ix_receipts_company_id'), table_name='receipts')
|
||||
op.drop_table('receipts')
|
||||
|
||||
# Drop enums (SQLite doesn't actually use these, but for consistency)
|
||||
op.execute("DROP TYPE IF EXISTS receipttype")
|
||||
op.execute("DROP TYPE IF EXISTS receiptdirection")
|
||||
op.execute("DROP TYPE IF EXISTS receiptstatus")
|
||||
op.execute("DROP TYPE IF EXISTS entrytype")
|
||||
@@ -0,0 +1,37 @@
|
||||
"""add_tva_breakdown_to_receipt
|
||||
|
||||
Revision ID: 1cfb423c6953
|
||||
Revises: 001_initial
|
||||
Create Date: 2025-12-12 14:04:22.464289+00:00
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '1cfb423c6953'
|
||||
down_revision: Union[str, None] = '001_initial'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add TVA-related columns to receipts table
|
||||
with op.batch_alter_table('receipts', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('tva_breakdown', sqlmodel.sql.sqltypes.AutoString(length=1000), nullable=True))
|
||||
batch_op.add_column(sa.Column('tva_total', sa.Numeric(precision=15, scale=2), nullable=True))
|
||||
batch_op.add_column(sa.Column('items_count', sa.Integer(), nullable=True))
|
||||
batch_op.add_column(sa.Column('vendor_address', sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove TVA-related columns from receipts table
|
||||
with op.batch_alter_table('receipts', schema=None) as batch_op:
|
||||
batch_op.drop_column('vendor_address')
|
||||
batch_op.drop_column('items_count')
|
||||
batch_op.drop_column('tva_total')
|
||||
batch_op.drop_column('tva_breakdown')
|
||||
@@ -0,0 +1,89 @@
|
||||
"""add nomenclature tables
|
||||
|
||||
Revision ID: 3a653da79002
|
||||
Revises: 1cfb423c6953
|
||||
Create Date: 2025-12-13 00:28:05.719430+00:00
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '3a653da79002'
|
||||
down_revision: Union[str, None] = '1cfb423c6953'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('local_suppliers',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('company_id', sa.Integer(), nullable=False),
|
||||
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
|
||||
sa.Column('fiscal_code', sqlmodel.sql.sqltypes.AutoString(length=50), nullable=True),
|
||||
sa.Column('address', sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True),
|
||||
sa.Column('created_by', sqlmodel.sql.sqltypes.AutoString(length=100), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('pending_oracle_sync', sa.Boolean(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
with op.batch_alter_table('local_suppliers', schema=None) as batch_op:
|
||||
batch_op.create_index(batch_op.f('ix_local_suppliers_company_id'), ['company_id'], unique=False)
|
||||
batch_op.create_index(batch_op.f('ix_local_suppliers_fiscal_code'), ['fiscal_code'], unique=False)
|
||||
|
||||
op.create_table('synced_cash_registers',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('oracle_id', sa.Integer(), nullable=False),
|
||||
sa.Column('company_id', sa.Integer(), nullable=False),
|
||||
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=100), nullable=False),
|
||||
sa.Column('account_code', sqlmodel.sql.sqltypes.AutoString(length=20), nullable=False),
|
||||
sa.Column('register_type', sqlmodel.sql.sqltypes.AutoString(length=10), nullable=False),
|
||||
sa.Column('synced_at', sa.DateTime(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
with op.batch_alter_table('synced_cash_registers', schema=None) as batch_op:
|
||||
batch_op.create_index(batch_op.f('ix_synced_cash_registers_company_id'), ['company_id'], unique=False)
|
||||
batch_op.create_index(batch_op.f('ix_synced_cash_registers_oracle_id'), ['oracle_id'], unique=False)
|
||||
|
||||
op.create_table('synced_suppliers',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('oracle_id', sa.Integer(), nullable=False),
|
||||
sa.Column('company_id', sa.Integer(), nullable=False),
|
||||
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
|
||||
sa.Column('fiscal_code', sqlmodel.sql.sqltypes.AutoString(length=50), nullable=True),
|
||||
sa.Column('address', sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True),
|
||||
sa.Column('synced_at', sa.DateTime(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
with op.batch_alter_table('synced_suppliers', schema=None) as batch_op:
|
||||
batch_op.create_index(batch_op.f('ix_synced_suppliers_company_id'), ['company_id'], unique=False)
|
||||
batch_op.create_index(batch_op.f('ix_synced_suppliers_fiscal_code'), ['fiscal_code'], unique=False)
|
||||
batch_op.create_index(batch_op.f('ix_synced_suppliers_oracle_id'), ['oracle_id'], unique=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('synced_suppliers', schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f('ix_synced_suppliers_oracle_id'))
|
||||
batch_op.drop_index(batch_op.f('ix_synced_suppliers_fiscal_code'))
|
||||
batch_op.drop_index(batch_op.f('ix_synced_suppliers_company_id'))
|
||||
|
||||
op.drop_table('synced_suppliers')
|
||||
with op.batch_alter_table('synced_cash_registers', schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f('ix_synced_cash_registers_oracle_id'))
|
||||
batch_op.drop_index(batch_op.f('ix_synced_cash_registers_company_id'))
|
||||
|
||||
op.drop_table('synced_cash_registers')
|
||||
with op.batch_alter_table('local_suppliers', schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f('ix_local_suppliers_fiscal_code'))
|
||||
batch_op.drop_index(batch_op.f('ix_local_suppliers_company_id'))
|
||||
|
||||
op.drop_table('local_suppliers')
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,35 @@
|
||||
"""add_ocr_fields_to_receipt
|
||||
|
||||
Revision ID: 4b8e5f2a1d93
|
||||
Revises: 3a653da79002
|
||||
Create Date: 2025-12-15 10:00:00.000000+00:00
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '4b8e5f2a1d93'
|
||||
down_revision: Union[str, None] = '3a653da79002'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add OCR-related columns to receipts table
|
||||
with op.batch_alter_table('receipts', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('cui', sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True))
|
||||
batch_op.add_column(sa.Column('ocr_raw_text', sa.Text(), nullable=True))
|
||||
batch_op.add_column(sa.Column('payment_methods', sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove OCR-related columns from receipts table
|
||||
with op.batch_alter_table('receipts', schema=None) as batch_op:
|
||||
batch_op.drop_column('payment_methods')
|
||||
batch_op.drop_column('ocr_raw_text')
|
||||
batch_op.drop_column('cui')
|
||||
@@ -0,0 +1,29 @@
|
||||
"""Remove partner_id from receipts - supplier data is text-only
|
||||
|
||||
Revision ID: 20251215_remove_partner_id
|
||||
Revises: 20251216_payment_mode
|
||||
Create Date: 2025-12-15
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '20251215_remove_partner_id'
|
||||
down_revision: Union[str, None] = '20251216_payment_mode'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Remove partner_id column - supplier data is now text-only (partner_name, cui)."""
|
||||
# Drop the partner_id column
|
||||
op.drop_column('receipts', 'partner_id')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Re-add partner_id column."""
|
||||
op.add_column('receipts', sa.Column('partner_id', sa.Integer(), nullable=True))
|
||||
@@ -0,0 +1,44 @@
|
||||
"""Add payment_mode field to receipts table.
|
||||
|
||||
Revision ID: 20251216_payment_mode
|
||||
Revises: 4b8e5f2a1d93
|
||||
Create Date: 2024-12-16
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '20251216_payment_mode'
|
||||
down_revision = '4b8e5f2a1d93'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add payment_mode column and migrate existing data."""
|
||||
with op.batch_alter_table('receipts', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('payment_mode', sa.String(length=20), nullable=True))
|
||||
|
||||
# Migrate existing data based on cash_register_account
|
||||
op.execute("""
|
||||
UPDATE receipts
|
||||
SET payment_mode = 'casa'
|
||||
WHERE cash_register_account LIKE '531%' AND payment_mode IS NULL
|
||||
""")
|
||||
op.execute("""
|
||||
UPDATE receipts
|
||||
SET payment_mode = 'banca'
|
||||
WHERE cash_register_account LIKE '512%' AND payment_mode IS NULL
|
||||
""")
|
||||
op.execute("""
|
||||
UPDATE receipts
|
||||
SET payment_mode = 'avans_decontare'
|
||||
WHERE cash_register_account LIKE '542%' AND payment_mode IS NULL
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove payment_mode column."""
|
||||
with op.batch_alter_table('receipts', schema=None) as batch_op:
|
||||
batch_op.drop_column('payment_mode')
|
||||
@@ -0,0 +1,40 @@
|
||||
"""Add needs_manual_review flag to receipts table.
|
||||
|
||||
Revision ID: 20251230_needs_manual_review
|
||||
Revises: 20251216_payment_mode
|
||||
Create Date: 2025-12-30
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '20251230_needs_manual_review'
|
||||
down_revision = '20251216_payment_mode'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add needs_manual_review column for OCR validation tracking.
|
||||
|
||||
This column tracks whether a receipt needs manual supervisor review
|
||||
based on OCR extraction validation warnings:
|
||||
- NULL = not validated yet (old receipts before validation feature)
|
||||
- FALSE = validated, no review needed
|
||||
- TRUE = validated, needs review
|
||||
"""
|
||||
with op.batch_alter_table('receipts', schema=None) as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column('needs_manual_review', sa.Boolean(), nullable=True)
|
||||
)
|
||||
|
||||
# NOTE: We do NOT set a default value for existing rows.
|
||||
# NULL indicates the receipt was created before validation was implemented.
|
||||
# Only new receipts (created after this migration) will have TRUE/FALSE values.
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove needs_manual_review column."""
|
||||
with op.batch_alter_table('receipts', schema=None) as batch_op:
|
||||
batch_op.drop_column('needs_manual_review')
|
||||
@@ -0,0 +1,74 @@
|
||||
"""Add OCR settings and metrics tables.
|
||||
|
||||
Revision ID: add_ocr_settings_metrics
|
||||
Revises: 20251230_add_needs_manual_review
|
||||
Create Date: 2025-12-31
|
||||
|
||||
This migration adds:
|
||||
- user_ocr_preferences: Store user's preferred OCR engine
|
||||
- ocr_job_metrics: Store OCR job processing metrics for analytics
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# Revision identifiers
|
||||
revision = 'add_ocr_settings_metrics'
|
||||
down_revision = '20251230_add_needs_manual_review'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Create OCR settings and metrics tables."""
|
||||
|
||||
# Create user_ocr_preferences table
|
||||
op.create_table(
|
||||
'user_ocr_preferences',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('username', sa.String(length=100), nullable=False),
|
||||
sa.Column('preferred_engine', sa.String(length=20), nullable=False, server_default='doctr_plus'),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False, server_default=sa.func.now()),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False, server_default=sa.func.now()),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index('ix_user_ocr_preferences_username', 'user_ocr_preferences', ['username'], unique=True)
|
||||
|
||||
# Create ocr_job_metrics table
|
||||
op.create_table(
|
||||
'ocr_job_metrics',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('job_id', sa.String(length=50), nullable=False),
|
||||
sa.Column('username', sa.String(length=100), nullable=False),
|
||||
sa.Column('company_id', sa.Integer(), nullable=True),
|
||||
sa.Column('engine_requested', sa.String(length=20), nullable=False),
|
||||
sa.Column('engine_used', sa.String(length=50), nullable=False),
|
||||
sa.Column('processing_time_ms', sa.Integer(), nullable=False, server_default='0'),
|
||||
sa.Column('file_size_bytes', sa.Integer(), nullable=False, server_default='0'),
|
||||
sa.Column('file_type', sa.String(length=50), nullable=False, server_default='image/jpeg'),
|
||||
sa.Column('success', sa.Boolean(), nullable=False, server_default='1'),
|
||||
sa.Column('error_message', sa.String(length=500), nullable=True),
|
||||
sa.Column('overall_confidence', sa.Float(), nullable=False, server_default='0.0'),
|
||||
sa.Column('fields_extracted', sa.Integer(), nullable=False, server_default='0'),
|
||||
sa.Column('needs_manual_review', sa.Boolean(), nullable=True),
|
||||
sa.Column('validation_warnings_count', sa.Integer(), nullable=False, server_default='0'),
|
||||
sa.Column('validation_errors_count', sa.Integer(), nullable=False, server_default='0'),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False, server_default=sa.func.now()),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index('ix_ocr_job_metrics_job_id', 'ocr_job_metrics', ['job_id'], unique=True)
|
||||
op.create_index('ix_ocr_job_metrics_username', 'ocr_job_metrics', ['username'], unique=False)
|
||||
op.create_index('ix_ocr_job_metrics_company_id', 'ocr_job_metrics', ['company_id'], unique=False)
|
||||
op.create_index('ix_ocr_job_metrics_created_at', 'ocr_job_metrics', ['created_at'], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop OCR settings and metrics tables."""
|
||||
op.drop_index('ix_ocr_job_metrics_created_at', table_name='ocr_job_metrics')
|
||||
op.drop_index('ix_ocr_job_metrics_company_id', table_name='ocr_job_metrics')
|
||||
op.drop_index('ix_ocr_job_metrics_username', table_name='ocr_job_metrics')
|
||||
op.drop_index('ix_ocr_job_metrics_job_id', table_name='ocr_job_metrics')
|
||||
op.drop_table('ocr_job_metrics')
|
||||
|
||||
op.drop_index('ix_user_ocr_preferences_username', table_name='user_ocr_preferences')
|
||||
op.drop_table('user_ocr_preferences')
|
||||
@@ -0,0 +1,30 @@
|
||||
"""Add original_filename to ocr_job_metrics.
|
||||
|
||||
Revision ID: add_original_filename_to_metrics
|
||||
Revises: add_ocr_settings_metrics
|
||||
Create Date: 2025-12-31
|
||||
|
||||
Adds original_filename column to track the uploaded filename.
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# Revision identifiers
|
||||
revision = 'add_original_filename_to_metrics'
|
||||
down_revision = 'add_ocr_settings_metrics'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add original_filename column to ocr_job_metrics."""
|
||||
op.add_column(
|
||||
'ocr_job_metrics',
|
||||
sa.Column('original_filename', sa.String(length=255), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove original_filename column."""
|
||||
op.drop_column('ocr_job_metrics', 'original_filename')
|
||||
@@ -0,0 +1,54 @@
|
||||
"""Add company_id to batch_uploads table.
|
||||
|
||||
Revision ID: 20260109_batch_company
|
||||
Revises: 20251231_add_original_filename_to_metrics
|
||||
Create Date: 2026-01-09
|
||||
|
||||
This migration adds the company_id column to batch_uploads to support
|
||||
automatic receipt creation during bulk upload processing.
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '20260109_batch_company'
|
||||
down_revision = None # Will be auto-detected
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add company_id column to batch_uploads table."""
|
||||
# Check if column already exists (SQLModel may have created it)
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
|
||||
# Check if batch_uploads table exists
|
||||
if 'batch_uploads' in inspector.get_table_names():
|
||||
columns = [col['name'] for col in inspector.get_columns('batch_uploads')]
|
||||
if 'company_id' not in columns:
|
||||
op.add_column(
|
||||
'batch_uploads',
|
||||
sa.Column('company_id', sa.Integer(), nullable=True)
|
||||
)
|
||||
# Create index for company_id
|
||||
op.create_index(
|
||||
'ix_batch_uploads_company_id',
|
||||
'batch_uploads',
|
||||
['company_id'],
|
||||
unique=False
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove company_id column from batch_uploads table."""
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
|
||||
if 'batch_uploads' in inspector.get_table_names():
|
||||
columns = [col['name'] for col in inspector.get_columns('batch_uploads')]
|
||||
if 'company_id' in columns:
|
||||
op.drop_index('ix_batch_uploads_company_id', table_name='batch_uploads')
|
||||
op.drop_column('batch_uploads', 'company_id')
|
||||
@@ -0,0 +1,125 @@
|
||||
"""Add batch processing fields to receipts table.
|
||||
|
||||
Revision ID: add_batch_processing_fields
|
||||
Revises: add_original_filename_to_metrics
|
||||
Create Date: 2026-01-11
|
||||
|
||||
Adds fields for bulk upload batch tracking:
|
||||
- batch_id: UUID string for grouping receipts from same upload
|
||||
- processing_status: enum (pending/processing/completed/failed)
|
||||
- processing_error: full error message text
|
||||
- file_hash: SHA-256 hash for duplicate detection
|
||||
- processing_started_at: when OCR processing started
|
||||
- processing_completed_at: when OCR processing completed
|
||||
|
||||
Also creates indexes for efficient querying.
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# Revision identifiers
|
||||
revision = 'add_batch_processing_fields'
|
||||
down_revision = 'add_original_filename_to_metrics'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add batch processing columns to receipts table."""
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
|
||||
# Get existing columns
|
||||
columns = [col['name'] for col in inspector.get_columns('receipts')]
|
||||
|
||||
# Add batch_id column with index
|
||||
if 'batch_id' not in columns:
|
||||
op.add_column(
|
||||
'receipts',
|
||||
sa.Column('batch_id', sa.String(length=50), nullable=True)
|
||||
)
|
||||
op.create_index(
|
||||
'ix_receipts_batch_id',
|
||||
'receipts',
|
||||
['batch_id'],
|
||||
unique=False
|
||||
)
|
||||
|
||||
# Add processing_status column with index
|
||||
if 'processing_status' not in columns:
|
||||
op.add_column(
|
||||
'receipts',
|
||||
sa.Column('processing_status', sa.String(length=20), nullable=True)
|
||||
)
|
||||
op.create_index(
|
||||
'ix_receipts_processing_status',
|
||||
'receipts',
|
||||
['processing_status'],
|
||||
unique=False
|
||||
)
|
||||
|
||||
# Add processing_error column (TEXT for full error messages)
|
||||
if 'processing_error' not in columns:
|
||||
op.add_column(
|
||||
'receipts',
|
||||
sa.Column('processing_error', sa.Text(), nullable=True)
|
||||
)
|
||||
|
||||
# Add file_hash column with index for duplicate detection
|
||||
if 'file_hash' not in columns:
|
||||
op.add_column(
|
||||
'receipts',
|
||||
sa.Column('file_hash', sa.String(length=64), nullable=True)
|
||||
)
|
||||
op.create_index(
|
||||
'ix_receipts_file_hash',
|
||||
'receipts',
|
||||
['file_hash'],
|
||||
unique=False
|
||||
)
|
||||
|
||||
# Add processing_started_at column
|
||||
if 'processing_started_at' not in columns:
|
||||
op.add_column(
|
||||
'receipts',
|
||||
sa.Column('processing_started_at', sa.DateTime(), nullable=True)
|
||||
)
|
||||
|
||||
# Add processing_completed_at column
|
||||
if 'processing_completed_at' not in columns:
|
||||
op.add_column(
|
||||
'receipts',
|
||||
sa.Column('processing_completed_at', sa.DateTime(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove batch processing columns from receipts table."""
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
|
||||
columns = [col['name'] for col in inspector.get_columns('receipts')]
|
||||
indexes = [idx['name'] for idx in inspector.get_indexes('receipts')]
|
||||
|
||||
# Remove indexes first (SQLite batch mode)
|
||||
if 'ix_receipts_batch_id' in indexes:
|
||||
op.drop_index('ix_receipts_batch_id', table_name='receipts')
|
||||
if 'ix_receipts_processing_status' in indexes:
|
||||
op.drop_index('ix_receipts_processing_status', table_name='receipts')
|
||||
if 'ix_receipts_file_hash' in indexes:
|
||||
op.drop_index('ix_receipts_file_hash', table_name='receipts')
|
||||
|
||||
# Remove columns (in reverse order of addition)
|
||||
if 'processing_completed_at' in columns:
|
||||
op.drop_column('receipts', 'processing_completed_at')
|
||||
if 'processing_started_at' in columns:
|
||||
op.drop_column('receipts', 'processing_started_at')
|
||||
if 'file_hash' in columns:
|
||||
op.drop_column('receipts', 'file_hash')
|
||||
if 'processing_error' in columns:
|
||||
op.drop_column('receipts', 'processing_error')
|
||||
if 'processing_status' in columns:
|
||||
op.drop_column('receipts', 'processing_status')
|
||||
if 'batch_id' in columns:
|
||||
op.drop_column('receipts', 'batch_id')
|
||||
@@ -0,0 +1,39 @@
|
||||
"""Data Entry module router factory."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
|
||||
def create_data_entry_router() -> APIRouter:
|
||||
"""
|
||||
Create and configure Data Entry module router.
|
||||
|
||||
Includes all data entry endpoints:
|
||||
- /receipts - Receipt CRUD and workflow
|
||||
- /ocr - OCR processing for receipts
|
||||
- /nomenclature - Nomenclature syncing from Oracle
|
||||
- /settings - User settings (OCR preferences)
|
||||
- /metrics - OCR analytics and metrics
|
||||
- /bulk - Bulk upload for batch processing
|
||||
|
||||
Returns:
|
||||
APIRouter: Configured router for data entry module
|
||||
"""
|
||||
router = APIRouter()
|
||||
|
||||
# Import routers here to avoid circular imports
|
||||
from .receipts import router as receipts_router
|
||||
from .ocr import router as ocr_router
|
||||
from .nomenclature import router as nomenclature_router
|
||||
from .ocr_settings import router as ocr_settings_router
|
||||
from .bulk import router as bulk_router
|
||||
|
||||
# Include all sub-routers (no prefix - already prefixed in main.py with /api/data-entry)
|
||||
router.include_router(receipts_router, prefix="/receipts", tags=["data-entry-receipts"])
|
||||
router.include_router(ocr_router, prefix="/ocr", tags=["data-entry-ocr"])
|
||||
router.include_router(nomenclature_router, prefix="/nomenclature", tags=["data-entry-nomenclature"])
|
||||
# OCR settings and metrics (endpoints at /settings/* and /metrics/*)
|
||||
router.include_router(ocr_settings_router, tags=["data-entry-settings"])
|
||||
# Bulk upload for batch processing
|
||||
router.include_router(bulk_router, prefix="/bulk", tags=["data-entry-bulk"])
|
||||
|
||||
return router
|
||||
@@ -0,0 +1,997 @@
|
||||
"""
|
||||
Bulk upload API endpoints for batch receipt processing.
|
||||
|
||||
Endpoints:
|
||||
- POST /upload - Submit multiple files for OCR processing in a single batch
|
||||
- GET /batches/{batch_id}/status - Get batch status with optional long-polling
|
||||
|
||||
Validation:
|
||||
- Max 100 files per batch
|
||||
- Max 10MB per file
|
||||
- Allowed types: PDF, PNG, JPG
|
||||
|
||||
Duplicate Detection (US-007):
|
||||
- SHA-256 hash calculated for each file
|
||||
- Duplicate files (same hash + company_id) are rejected with 409 Conflict info
|
||||
- Duplicates reported in error list, non-duplicates processed normally
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
from typing import Annotated, List, Optional, Union
|
||||
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Depends, Query, Header
|
||||
from sqlalchemy import select, func, and_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from backend.modules.data_entry.db.database import get_session
|
||||
from backend.modules.data_entry.db.models import BatchUpload, BatchJob, BatchStatus, Receipt, ReceiptAttachment
|
||||
from backend.modules.data_entry.schemas.bulk import (
|
||||
BulkUploadResponse,
|
||||
BulkUploadResponseWithDuplicates,
|
||||
BatchStatusResponse,
|
||||
BatchJobInfo,
|
||||
DuplicateFileInfo,
|
||||
RetryResponse,
|
||||
BatchRetryResponse,
|
||||
CancelJobResponse,
|
||||
CancelBatchResponse
|
||||
)
|
||||
from backend.modules.data_entry.services.ocr.job_queue import job_queue, OCRJobStatus
|
||||
from backend.config import settings
|
||||
|
||||
# Auth integration
|
||||
from shared.auth.dependencies import get_current_user
|
||||
from shared.auth.models import CurrentUser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============ Helper for selected company from header ============
|
||||
|
||||
async def get_selected_company(
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
x_selected_company: Annotated[Optional[str], Header()] = None
|
||||
) -> int:
|
||||
"""
|
||||
Get selected company from X-Selected-Company header.
|
||||
|
||||
Validates that the user has access to the specified company.
|
||||
Falls back to user's first company if no header is provided.
|
||||
"""
|
||||
if x_selected_company:
|
||||
try:
|
||||
company_id = int(x_selected_company)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid company ID format: {x_selected_company}"
|
||||
)
|
||||
|
||||
if str(company_id) in current_user.companies:
|
||||
return company_id
|
||||
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Nu aveți acces la firma {company_id}"
|
||||
)
|
||||
|
||||
# No header - use first company from user's list
|
||||
if current_user.companies:
|
||||
try:
|
||||
return int(current_user.companies[0])
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Nu aveți nicio firmă asignată"
|
||||
)
|
||||
|
||||
# Validation constants
|
||||
MAX_FILES_PER_BATCH = 100
|
||||
MAX_FILE_SIZE_BYTES = 10 * 1024 * 1024 # 10MB
|
||||
ALLOWED_MIME_TYPES = {"image/jpeg", "image/png", "application/pdf"}
|
||||
|
||||
|
||||
def compute_file_hash(content: bytes) -> str:
|
||||
"""
|
||||
Compute SHA-256 hash of file content.
|
||||
|
||||
Used for duplicate detection - same file content = same hash.
|
||||
|
||||
Args:
|
||||
content: Raw file bytes
|
||||
|
||||
Returns:
|
||||
Hexadecimal string of SHA-256 hash (64 characters)
|
||||
"""
|
||||
return hashlib.sha256(content).hexdigest()
|
||||
|
||||
|
||||
async def check_duplicate_hashes(
|
||||
session: AsyncSession,
|
||||
file_hashes: List[str],
|
||||
company_id: int
|
||||
) -> dict[str, int]:
|
||||
"""
|
||||
Check which file hashes already exist in the database for this company.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
file_hashes: List of SHA-256 hashes to check
|
||||
company_id: Company ID to scope the duplicate check
|
||||
|
||||
Returns:
|
||||
Dict mapping hash -> existing receipt_id for duplicates found
|
||||
"""
|
||||
if not file_hashes:
|
||||
return {}
|
||||
|
||||
# Query for existing receipts with these hashes for this company
|
||||
result = await session.execute(
|
||||
select(Receipt.file_hash, Receipt.id).where(
|
||||
and_(
|
||||
Receipt.file_hash.in_(file_hashes),
|
||||
Receipt.company_id == company_id
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Build hash -> receipt_id mapping
|
||||
# Note: result.all() is synchronous in SQLAlchemy async, returns list of tuples
|
||||
duplicates = {}
|
||||
rows = result.all()
|
||||
for row in rows:
|
||||
duplicates[row[0]] = row[1]
|
||||
|
||||
return duplicates
|
||||
|
||||
|
||||
@router.post("/upload", response_model=Union[BulkUploadResponse, BulkUploadResponseWithDuplicates])
|
||||
async def bulk_upload(
|
||||
files: List[UploadFile] = File(..., description="Multiple files to upload"),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
selected_company: int = Depends(get_selected_company)
|
||||
):
|
||||
"""
|
||||
Upload multiple files for batch OCR processing.
|
||||
|
||||
Creates a batch record and queues all files as OCR jobs.
|
||||
Invalid files cause entire batch rejection (validation errors).
|
||||
Duplicate files are reported separately and skipped - non-duplicates are processed.
|
||||
|
||||
Duplicate Detection (US-007):
|
||||
- SHA-256 hash calculated for each file before processing
|
||||
- Files with existing hash for same company are rejected with 409 info
|
||||
- Response includes duplicate details with existing_receipt_id
|
||||
|
||||
Args:
|
||||
files: List of image/PDF files (max 100 files, max 10MB each)
|
||||
|
||||
Returns:
|
||||
BulkUploadResponse with batch_id and list of job_ids
|
||||
BulkUploadResponseWithDuplicates if some files were duplicates
|
||||
|
||||
Raises:
|
||||
400: If validation fails (too many files, file too large, invalid type)
|
||||
409: If ALL files are duplicates
|
||||
500: If job creation fails
|
||||
"""
|
||||
# Validate file count
|
||||
if len(files) == 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No files provided"
|
||||
)
|
||||
|
||||
if len(files) > MAX_FILES_PER_BATCH:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Too many files. Maximum {MAX_FILES_PER_BATCH} files per batch."
|
||||
)
|
||||
|
||||
# Pre-validate all files before creating any jobs (atomic check)
|
||||
invalid_files = []
|
||||
file_contents = []
|
||||
|
||||
for file in files:
|
||||
# Check MIME type
|
||||
if file.content_type not in ALLOWED_MIME_TYPES:
|
||||
invalid_files.append(f"{file.filename}: Invalid type ({file.content_type})")
|
||||
continue
|
||||
|
||||
# Read content and check size
|
||||
content = await file.read()
|
||||
if len(content) > MAX_FILE_SIZE_BYTES:
|
||||
invalid_files.append(f"{file.filename}: File too large ({len(content) // (1024*1024)}MB > 10MB)")
|
||||
continue
|
||||
|
||||
# Compute SHA-256 hash for duplicate detection (US-007)
|
||||
file_hash = compute_file_hash(content)
|
||||
|
||||
# Store for later processing
|
||||
file_contents.append({
|
||||
"filename": file.filename,
|
||||
"content": content,
|
||||
"mime_type": file.content_type,
|
||||
"file_hash": file_hash
|
||||
})
|
||||
|
||||
# If any files are invalid, reject the entire batch
|
||||
if invalid_files:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"message": f"Validation failed for {len(invalid_files)} file(s)",
|
||||
"invalid_files": invalid_files
|
||||
}
|
||||
)
|
||||
|
||||
# Check for duplicates BEFORE creating batch (US-007)
|
||||
all_hashes = [f["file_hash"] for f in file_contents]
|
||||
existing_duplicates = await check_duplicate_hashes(session, all_hashes, selected_company)
|
||||
|
||||
# Separate duplicate files from processable files
|
||||
duplicate_files: List[DuplicateFileInfo] = []
|
||||
processable_files = []
|
||||
|
||||
for file_data in file_contents:
|
||||
if file_data["file_hash"] in existing_duplicates:
|
||||
existing_receipt_id = existing_duplicates[file_data["file_hash"]]
|
||||
duplicate_files.append(DuplicateFileInfo(
|
||||
filename=file_data["filename"],
|
||||
error="duplicate",
|
||||
existing_receipt_id=existing_receipt_id,
|
||||
message=f"Fișier duplicat - există deja ca bon #{existing_receipt_id}"
|
||||
))
|
||||
logger.info(
|
||||
f"[BulkUpload] Duplicate detected: {file_data['filename']} "
|
||||
f"(hash={file_data['file_hash'][:16]}...) matches receipt #{existing_receipt_id}"
|
||||
)
|
||||
else:
|
||||
processable_files.append(file_data)
|
||||
|
||||
# If ALL files are duplicates, return 409 Conflict
|
||||
if len(duplicate_files) == len(file_contents):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail={
|
||||
"error": "all_duplicates",
|
||||
"message": f"Toate cele {len(duplicate_files)} fișiere sunt duplicate",
|
||||
"duplicates": [d.model_dump() for d in duplicate_files]
|
||||
}
|
||||
)
|
||||
|
||||
# If no processable files remain after filtering (shouldn't happen but be safe)
|
||||
if not processable_files:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail={
|
||||
"error": "no_files_to_process",
|
||||
"message": "Nu există fișiere de procesat",
|
||||
"duplicates": [d.model_dump() for d in duplicate_files]
|
||||
}
|
||||
)
|
||||
|
||||
# Create batch record with company_id for auto-save
|
||||
batch = BatchUpload(
|
||||
user_id=current_user.username,
|
||||
company_id=selected_company,
|
||||
status=BatchStatus.PENDING,
|
||||
total_files=len(processable_files) # Only count processable files
|
||||
)
|
||||
session.add(batch)
|
||||
await session.flush() # Get batch.id before creating jobs
|
||||
|
||||
# Create OCR jobs for processable files only
|
||||
job_ids = []
|
||||
batch_jobs = []
|
||||
|
||||
try:
|
||||
for file_data in processable_files:
|
||||
# Create OCR job using existing job_queue
|
||||
# Pass batch_id and file_hash for tracking
|
||||
job = await job_queue.create_job(
|
||||
file_bytes=file_data["content"],
|
||||
mime_type=file_data["mime_type"],
|
||||
engine="doctr_plus", # Default engine for bulk
|
||||
username=current_user.username,
|
||||
original_filename=file_data["filename"],
|
||||
batch_id=batch.id, # Link job to batch for auto-save integration
|
||||
file_hash=file_data["file_hash"] # Pass hash for storage in receipt
|
||||
)
|
||||
|
||||
job_ids.append(job.id)
|
||||
|
||||
# Create batch_job link
|
||||
batch_job = BatchJob(
|
||||
batch_id=batch.id,
|
||||
job_id=job.id,
|
||||
filename=file_data["filename"]
|
||||
)
|
||||
batch_jobs.append(batch_job)
|
||||
|
||||
# Add all batch_job records
|
||||
for bj in batch_jobs:
|
||||
session.add(bj)
|
||||
|
||||
# Commit everything atomically
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
f"[BulkUpload] Created batch {batch.id} with {len(job_ids)} jobs "
|
||||
f"for user {current_user.username}"
|
||||
f"{f', {len(duplicate_files)} duplicates skipped' if duplicate_files else ''}"
|
||||
)
|
||||
|
||||
# Return response with duplicate info if any duplicates were found
|
||||
if duplicate_files:
|
||||
return BulkUploadResponseWithDuplicates(
|
||||
batch_id=batch.id,
|
||||
job_ids=job_ids,
|
||||
total_files=len(file_contents),
|
||||
processed_files=len(job_ids),
|
||||
duplicate_files=len(duplicate_files),
|
||||
duplicates=duplicate_files,
|
||||
message=f"{len(job_ids)} fișier(e) în procesare, {len(duplicate_files)} duplicate ignorate"
|
||||
)
|
||||
|
||||
return BulkUploadResponse(
|
||||
batch_id=batch.id,
|
||||
job_ids=job_ids,
|
||||
total_files=len(job_ids),
|
||||
message=f"{len(job_ids)} files queued for processing"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Rollback on any error
|
||||
await session.rollback()
|
||||
logger.error(f"[BulkUpload] Failed to create batch: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to create batch: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
# Long-polling constants
|
||||
MAX_WAIT_SECONDS = 30
|
||||
POLL_INTERVAL_SECONDS = 0.5
|
||||
|
||||
|
||||
async def _get_batch_status_snapshot(
|
||||
batch_id: int,
|
||||
session: AsyncSession
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get current batch status snapshot.
|
||||
|
||||
Returns dict with status counts and jobs list, or None if batch not found.
|
||||
"""
|
||||
# Get batch record
|
||||
batch_result = await session.execute(
|
||||
select(BatchUpload).where(BatchUpload.id == batch_id)
|
||||
)
|
||||
batch = batch_result.scalar_one_or_none()
|
||||
|
||||
if not batch:
|
||||
return None
|
||||
|
||||
# Get all batch_jobs for this batch
|
||||
batch_jobs_result = await session.execute(
|
||||
select(BatchJob).where(BatchJob.batch_id == batch_id)
|
||||
)
|
||||
batch_jobs = batch_jobs_result.scalars().all()
|
||||
|
||||
if not batch_jobs:
|
||||
return {
|
||||
"batch": batch,
|
||||
"pending_count": 0,
|
||||
"processing_count": 0,
|
||||
"completed_count": 0,
|
||||
"failed_count": 0,
|
||||
"jobs": [],
|
||||
"total_amount": None
|
||||
}
|
||||
|
||||
# Get job statuses and error_messages from OCR job queue (SQLite)
|
||||
job_statuses = {}
|
||||
job_errors = {}
|
||||
for bj in batch_jobs:
|
||||
job = await job_queue.get_job(bj.job_id)
|
||||
if job:
|
||||
job_statuses[bj.job_id] = job.status.value
|
||||
job_errors[bj.job_id] = job.error_message
|
||||
else:
|
||||
# Job not found in queue - treat as failed
|
||||
job_statuses[bj.job_id] = "failed"
|
||||
job_errors[bj.job_id] = "Job not found in queue"
|
||||
|
||||
# Count by status
|
||||
pending_count = sum(1 for s in job_statuses.values() if s == "pending")
|
||||
processing_count = sum(1 for s in job_statuses.values() if s == "processing")
|
||||
completed_count = sum(1 for s in job_statuses.values() if s == "completed")
|
||||
failed_count = sum(1 for s in job_statuses.values() if s == "failed")
|
||||
|
||||
# Build jobs list with status info
|
||||
jobs_info = []
|
||||
for bj in batch_jobs:
|
||||
jobs_info.append({
|
||||
"job_id": bj.job_id,
|
||||
"filename": bj.filename,
|
||||
"status": job_statuses.get(bj.job_id, "failed"),
|
||||
"receipt_id": bj.receipt_id,
|
||||
"error_message": job_errors.get(bj.job_id)
|
||||
})
|
||||
|
||||
# Calculate total_amount from completed receipts
|
||||
total_amount = None
|
||||
receipt_ids = [bj.receipt_id for bj in batch_jobs if bj.receipt_id is not None]
|
||||
if receipt_ids:
|
||||
amount_result = await session.execute(
|
||||
select(func.sum(Receipt.amount)).where(Receipt.id.in_(receipt_ids))
|
||||
)
|
||||
total_sum = amount_result.scalar()
|
||||
if total_sum is not None:
|
||||
total_amount = float(total_sum)
|
||||
|
||||
return {
|
||||
"batch": batch,
|
||||
"pending_count": pending_count,
|
||||
"processing_count": processing_count,
|
||||
"completed_count": completed_count,
|
||||
"failed_count": failed_count,
|
||||
"jobs": jobs_info,
|
||||
"total_amount": total_amount
|
||||
}
|
||||
|
||||
|
||||
def _compute_batch_overall_status(pending: int, processing: int, completed: int, failed: int, total: int) -> str:
|
||||
"""Compute overall batch status from job counts."""
|
||||
if pending + processing == 0:
|
||||
# All jobs finished
|
||||
if failed == total:
|
||||
return BatchStatus.FAILED.value
|
||||
return BatchStatus.COMPLETED.value
|
||||
elif processing > 0 or completed > 0 or failed > 0:
|
||||
return BatchStatus.PROCESSING.value
|
||||
else:
|
||||
return BatchStatus.PENDING.value
|
||||
|
||||
|
||||
@router.get("/batches/{batch_id}/status", response_model=BatchStatusResponse)
|
||||
async def get_batch_status(
|
||||
batch_id: int,
|
||||
wait: Optional[int] = Query(
|
||||
default=None,
|
||||
ge=0,
|
||||
le=MAX_WAIT_SECONDS,
|
||||
description="Long-polling wait time in seconds (max 30)"
|
||||
),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get batch processing status with optional long-polling.
|
||||
|
||||
Returns aggregated status counts and individual job statuses.
|
||||
When `wait` parameter is provided, the endpoint will poll until:
|
||||
- Status changes from initial snapshot
|
||||
- All jobs complete (pending + processing = 0)
|
||||
- Timeout is reached
|
||||
|
||||
Args:
|
||||
batch_id: Batch ID to query
|
||||
wait: Optional wait time in seconds for long-polling (0-30)
|
||||
|
||||
Returns:
|
||||
BatchStatusResponse with status counts and job details
|
||||
|
||||
Raises:
|
||||
404: If batch not found
|
||||
"""
|
||||
# Get initial snapshot
|
||||
snapshot = await _get_batch_status_snapshot(batch_id, session)
|
||||
|
||||
if snapshot is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Batch {batch_id} not found"
|
||||
)
|
||||
|
||||
# If long-polling requested and jobs still in progress
|
||||
if wait and wait > 0:
|
||||
initial_pending = snapshot["pending_count"]
|
||||
initial_processing = snapshot["processing_count"]
|
||||
initial_completed = snapshot["completed_count"]
|
||||
initial_failed = snapshot["failed_count"]
|
||||
|
||||
# Only wait if there are still jobs in progress
|
||||
if initial_pending + initial_processing > 0:
|
||||
elapsed = 0.0
|
||||
while elapsed < wait:
|
||||
await asyncio.sleep(POLL_INTERVAL_SECONDS)
|
||||
elapsed += POLL_INTERVAL_SECONDS
|
||||
|
||||
# Refresh snapshot
|
||||
snapshot = await _get_batch_status_snapshot(batch_id, session)
|
||||
if snapshot is None:
|
||||
# Batch deleted during polling (edge case)
|
||||
raise HTTPException(status_code=404, detail=f"Batch {batch_id} not found")
|
||||
|
||||
# Check if status changed
|
||||
current_pending = snapshot["pending_count"]
|
||||
current_processing = snapshot["processing_count"]
|
||||
current_completed = snapshot["completed_count"]
|
||||
current_failed = snapshot["failed_count"]
|
||||
|
||||
if (current_pending != initial_pending or
|
||||
current_processing != initial_processing or
|
||||
current_completed != initial_completed or
|
||||
current_failed != initial_failed):
|
||||
# Status changed, return immediately
|
||||
break
|
||||
|
||||
# Check if all jobs finished
|
||||
if current_pending + current_processing == 0:
|
||||
break
|
||||
|
||||
# Build response
|
||||
batch = snapshot["batch"]
|
||||
total_files = batch.total_files
|
||||
|
||||
overall_status = _compute_batch_overall_status(
|
||||
snapshot["pending_count"],
|
||||
snapshot["processing_count"],
|
||||
snapshot["completed_count"],
|
||||
snapshot["failed_count"],
|
||||
total_files
|
||||
)
|
||||
|
||||
jobs = [
|
||||
BatchJobInfo(
|
||||
job_id=j["job_id"],
|
||||
filename=j["filename"],
|
||||
status=j["status"],
|
||||
receipt_id=j["receipt_id"],
|
||||
error_message=j.get("error_message")
|
||||
)
|
||||
for j in snapshot["jobs"]
|
||||
]
|
||||
|
||||
return BatchStatusResponse(
|
||||
batch_id=batch.id,
|
||||
status=overall_status,
|
||||
total_files=total_files,
|
||||
pending_count=snapshot["pending_count"],
|
||||
processing_count=snapshot["processing_count"],
|
||||
completed_count=snapshot["completed_count"],
|
||||
failed_count=snapshot["failed_count"],
|
||||
jobs=jobs,
|
||||
total_amount=snapshot["total_amount"],
|
||||
created_at=batch.created_at
|
||||
)
|
||||
|
||||
|
||||
# ============ Retry Endpoints (US-006) ============
|
||||
|
||||
|
||||
async def _retry_single_receipt(
|
||||
session: AsyncSession,
|
||||
receipt: Receipt,
|
||||
username: str
|
||||
) -> tuple[bool, Optional[str], Optional[str]]:
|
||||
"""
|
||||
Retry processing for a single receipt.
|
||||
|
||||
Finds the original file from attachments, resets processing status,
|
||||
and creates a new OCR job.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
receipt: Receipt to retry
|
||||
username: Username for the new OCR job
|
||||
|
||||
Returns:
|
||||
Tuple of (success, job_id, error_message)
|
||||
"""
|
||||
# Get the first attachment to find the source file
|
||||
attachments_result = await session.execute(
|
||||
select(ReceiptAttachment)
|
||||
.where(ReceiptAttachment.receipt_id == receipt.id)
|
||||
.limit(1)
|
||||
)
|
||||
attachment = attachments_result.scalar_one_or_none()
|
||||
|
||||
if not attachment:
|
||||
return False, None, "Bonul nu are fișier atașat"
|
||||
|
||||
# Construct full path to attachment file
|
||||
file_path = settings.data_entry_upload_path_resolved / attachment.file_path
|
||||
|
||||
if not file_path.exists():
|
||||
return False, None, "Fișierul original nu mai este disponibil"
|
||||
|
||||
# Read file content
|
||||
try:
|
||||
with open(file_path, 'rb') as f:
|
||||
file_bytes = f.read()
|
||||
except Exception as e:
|
||||
logger.error(f"[Retry] Failed to read file {file_path}: {e}")
|
||||
return False, None, f"Eroare la citirea fișierului: {str(e)}"
|
||||
|
||||
# Create new OCR job
|
||||
try:
|
||||
job = await job_queue.create_job(
|
||||
file_bytes=file_bytes,
|
||||
mime_type=attachment.mime_type,
|
||||
engine="doctr_plus",
|
||||
username=username,
|
||||
original_filename=attachment.filename,
|
||||
batch_id=None, # No batch for retry - direct processing
|
||||
file_hash=receipt.file_hash
|
||||
)
|
||||
|
||||
# Reset receipt processing status
|
||||
receipt.processing_status = "pending"
|
||||
receipt.processing_error = None
|
||||
receipt.processing_started_at = datetime.utcnow()
|
||||
receipt.processing_completed_at = None
|
||||
|
||||
await session.flush()
|
||||
|
||||
logger.info(f"[Retry] Receipt {receipt.id} requeued as job {job.id}")
|
||||
return True, job.id, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Retry] Failed to create job for receipt {receipt.id}: {e}")
|
||||
return False, None, f"Eroare la crearea job-ului OCR: {str(e)}"
|
||||
|
||||
|
||||
@router.post("/retry/{receipt_id}", response_model=RetryResponse)
|
||||
async def retry_receipt(
|
||||
receipt_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
selected_company: int = Depends(get_selected_company)
|
||||
):
|
||||
"""
|
||||
Retry OCR processing for a single failed receipt.
|
||||
|
||||
Resets the receipt's processing_status to 'pending' and creates
|
||||
a new OCR job using the original attachment file.
|
||||
|
||||
Args:
|
||||
receipt_id: ID of the receipt to retry
|
||||
|
||||
Returns:
|
||||
RetryResponse with success status and new job ID
|
||||
|
||||
Raises:
|
||||
404: If receipt not found
|
||||
400: If receipt is not in 'failed' status
|
||||
400: If original file is not available
|
||||
"""
|
||||
# Get the receipt
|
||||
result = await session.execute(
|
||||
select(Receipt).where(
|
||||
and_(
|
||||
Receipt.id == receipt_id,
|
||||
Receipt.company_id == selected_company
|
||||
)
|
||||
)
|
||||
)
|
||||
receipt = result.scalar_one_or_none()
|
||||
|
||||
if not receipt:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Bonul #{receipt_id} nu a fost găsit"
|
||||
)
|
||||
|
||||
# Verify receipt is in failed status
|
||||
if receipt.processing_status != "failed":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Bonul nu este în stare de eroare (status actual: {receipt.processing_status})"
|
||||
)
|
||||
|
||||
# Attempt retry
|
||||
success, job_id, error = await _retry_single_receipt(
|
||||
session, receipt, current_user.username
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=error or "Eroare necunoscută la reîncărcare"
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
||||
return RetryResponse(
|
||||
success=True,
|
||||
receipt_id=receipt_id,
|
||||
job_id=job_id,
|
||||
message="Bon reîncarcat în procesare"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/retry-batch/{batch_id}", response_model=BatchRetryResponse)
|
||||
async def retry_batch_failed(
|
||||
batch_id: str,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
selected_company: int = Depends(get_selected_company)
|
||||
):
|
||||
"""
|
||||
Retry all failed receipts in a batch.
|
||||
|
||||
Finds all receipts with batch_id matching and processing_status='failed',
|
||||
then attempts to retry each one.
|
||||
|
||||
Args:
|
||||
batch_id: Batch ID (UUID string from receipt.batch_id)
|
||||
|
||||
Returns:
|
||||
BatchRetryResponse with counts of successful and failed retries
|
||||
|
||||
Raises:
|
||||
404: If no failed receipts found for batch
|
||||
"""
|
||||
# Find all failed receipts in this batch
|
||||
result = await session.execute(
|
||||
select(Receipt).where(
|
||||
and_(
|
||||
Receipt.batch_id == batch_id,
|
||||
Receipt.company_id == selected_company,
|
||||
Receipt.processing_status == "failed"
|
||||
)
|
||||
)
|
||||
)
|
||||
failed_receipts = result.scalars().all()
|
||||
|
||||
if not failed_receipts:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Nu există bonuri cu erori în batch-ul {batch_id}"
|
||||
)
|
||||
|
||||
# Retry each receipt
|
||||
retried_count = 0
|
||||
failed_count = 0
|
||||
errors = []
|
||||
|
||||
for receipt in failed_receipts:
|
||||
success, job_id, error = await _retry_single_receipt(
|
||||
session, receipt, current_user.username
|
||||
)
|
||||
|
||||
if success:
|
||||
retried_count += 1
|
||||
else:
|
||||
failed_count += 1
|
||||
errors.append(f"Bon #{receipt.id}: {error}")
|
||||
|
||||
await session.commit()
|
||||
|
||||
return BatchRetryResponse(
|
||||
success=retried_count > 0,
|
||||
batch_id=batch_id,
|
||||
retried_count=retried_count,
|
||||
failed_count=failed_count,
|
||||
errors=errors,
|
||||
message=f"{retried_count} bonuri reîncarcate în procesare"
|
||||
+ (f", {failed_count} erori" if failed_count > 0 else "")
|
||||
)
|
||||
|
||||
|
||||
# ============ Cancel Endpoints (US-014) ============
|
||||
|
||||
|
||||
@router.post("/cancel/{job_id}", response_model=CancelJobResponse)
|
||||
async def cancel_job(
|
||||
job_id: str,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Cancel a single OCR processing job.
|
||||
|
||||
Only jobs with status 'pending' or 'processing' can be cancelled.
|
||||
Jobs with status 'completed' or 'failed' cannot be cancelled.
|
||||
|
||||
Important: If a receipt has already been created from this job,
|
||||
it will NOT be deleted - receipts are preserved for audit purposes.
|
||||
|
||||
Args:
|
||||
job_id: The UUID of the OCR job to cancel
|
||||
|
||||
Returns:
|
||||
CancelJobResponse with cancellation details
|
||||
|
||||
Raises:
|
||||
404: If job not found in batch_jobs table
|
||||
400: If job has already completed or failed
|
||||
"""
|
||||
# Find the job in batch_jobs table
|
||||
batch_job_result = await session.execute(
|
||||
select(BatchJob).where(BatchJob.job_id == job_id)
|
||||
)
|
||||
batch_job = batch_job_result.scalar_one_or_none()
|
||||
|
||||
if not batch_job:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Job {job_id} nu a fost găsit"
|
||||
)
|
||||
|
||||
# Get the OCR job from job_queue to check current status
|
||||
ocr_job = await job_queue.get_job(job_id)
|
||||
|
||||
if not ocr_job:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Job {job_id} nu există în coada de procesare"
|
||||
)
|
||||
|
||||
# Check if job can be cancelled
|
||||
current_status = ocr_job.status.value
|
||||
|
||||
if current_status == OCRJobStatus.completed.value:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Job-ul a fost deja procesat cu succes. Nu poate fi anulat."
|
||||
)
|
||||
|
||||
if current_status == OCRJobStatus.failed.value:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Job-ul a eșuat deja. Folosiți opțiunea de reîncercare în loc de anulare."
|
||||
)
|
||||
|
||||
if current_status == OCRJobStatus.cancelled.value:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Job-ul a fost deja anulat."
|
||||
)
|
||||
|
||||
# Update job status to cancelled in job_queue (SQLite)
|
||||
cancelled_at = datetime.utcnow()
|
||||
success = await job_queue.update_status(
|
||||
job_id=job_id,
|
||||
status=OCRJobStatus.cancelled,
|
||||
error="Cancelled by user"
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Eroare la anularea job-ului"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[CancelJob] Job {job_id} cancelled by {current_user.username} "
|
||||
f"(previous status: {current_status})"
|
||||
)
|
||||
|
||||
return CancelJobResponse(
|
||||
success=True,
|
||||
job_id=job_id,
|
||||
cancelled_at=cancelled_at,
|
||||
message=f"Job anulat cu succes"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/cancel-batch/{batch_id}", response_model=CancelBatchResponse)
|
||||
async def cancel_batch(
|
||||
batch_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Cancel all pending/processing jobs in a batch.
|
||||
|
||||
Finds all jobs with status 'pending' or 'processing' in the specified batch
|
||||
and marks them as 'cancelled'. Jobs with status 'completed' or 'failed'
|
||||
are not affected.
|
||||
|
||||
Important: Receipts that have already been created from completed jobs
|
||||
will NOT be deleted - they are preserved for audit purposes.
|
||||
|
||||
Args:
|
||||
batch_id: The batch ID to cancel
|
||||
|
||||
Returns:
|
||||
CancelBatchResponse with counts of cancelled and skipped jobs
|
||||
|
||||
Raises:
|
||||
404: If batch not found or no jobs exist for batch
|
||||
"""
|
||||
# Verify batch exists
|
||||
batch_result = await session.execute(
|
||||
select(BatchUpload).where(BatchUpload.id == batch_id)
|
||||
)
|
||||
batch = batch_result.scalar_one_or_none()
|
||||
|
||||
if not batch:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Batch {batch_id} nu a fost găsit"
|
||||
)
|
||||
|
||||
# Get all batch_jobs for this batch
|
||||
batch_jobs_result = await session.execute(
|
||||
select(BatchJob).where(BatchJob.batch_id == batch_id)
|
||||
)
|
||||
batch_jobs = batch_jobs_result.scalars().all()
|
||||
|
||||
if not batch_jobs:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Nu există job-uri în batch-ul {batch_id}"
|
||||
)
|
||||
|
||||
# Process each job - cancel pending/processing, skip completed/failed
|
||||
cancelled_count = 0
|
||||
skipped_count = 0
|
||||
|
||||
for batch_job in batch_jobs:
|
||||
# Get current job status from OCR job queue
|
||||
ocr_job = await job_queue.get_job(batch_job.job_id)
|
||||
|
||||
if not ocr_job:
|
||||
# Job not found in queue - treat as skipped
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
current_status = ocr_job.status.value
|
||||
|
||||
# Only cancel pending or processing jobs
|
||||
if current_status in (OCRJobStatus.pending.value, OCRJobStatus.processing.value):
|
||||
success = await job_queue.update_status(
|
||||
job_id=batch_job.job_id,
|
||||
status=OCRJobStatus.cancelled,
|
||||
error="Cancelled by user (batch cancel)"
|
||||
)
|
||||
|
||||
if success:
|
||||
cancelled_count += 1
|
||||
logger.debug(f"[CancelBatch] Cancelled job {batch_job.job_id}")
|
||||
else:
|
||||
# Failed to cancel - count as skipped
|
||||
skipped_count += 1
|
||||
logger.warning(
|
||||
f"[CancelBatch] Failed to cancel job {batch_job.job_id}"
|
||||
)
|
||||
else:
|
||||
# Job is completed, failed, or already cancelled - skip it
|
||||
skipped_count += 1
|
||||
|
||||
logger.info(
|
||||
f"[CancelBatch] Batch {batch_id} cancelled by {current_user.username}: "
|
||||
f"{cancelled_count} cancelled, {skipped_count} skipped"
|
||||
)
|
||||
|
||||
# Build message
|
||||
if cancelled_count == 0:
|
||||
message = f"Nu există job-uri de anulat în batch-ul {batch_id}"
|
||||
elif skipped_count == 0:
|
||||
message = f"{cancelled_count} job-uri anulate"
|
||||
else:
|
||||
message = f"{cancelled_count} job-uri anulate, {skipped_count} ignorate (deja procesate)"
|
||||
|
||||
return CancelBatchResponse(
|
||||
success=cancelled_count > 0,
|
||||
batch_id=batch_id,
|
||||
cancelled_count=cancelled_count,
|
||||
skipped_count=skipped_count,
|
||||
message=message
|
||||
)
|
||||
@@ -0,0 +1,260 @@
|
||||
"""Nomenclature API endpoints."""
|
||||
|
||||
from typing import Optional, List, Annotated
|
||||
from fastapi import APIRouter, Depends, HTTPException, Header, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.modules.data_entry.db.database import get_session
|
||||
from backend.modules.data_entry.services.sync_service import SyncService
|
||||
|
||||
# Import auth dependencies
|
||||
import sys
|
||||
from pathlib import Path
|
||||
# Path setup handled by main.py - this is redundant
|
||||
# project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||
# sys.path.insert(0, str(project_root / "shared"))
|
||||
|
||||
from shared.auth.dependencies import get_current_user
|
||||
from shared.auth.models import CurrentUser
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============ Selected Company Dependency ============
|
||||
|
||||
async def get_selected_company(
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
x_selected_company: Annotated[Optional[str], Header()] = None
|
||||
) -> int:
|
||||
"""
|
||||
Get selected company from X-Selected-Company header.
|
||||
Validates user access. Falls back to first company if no header.
|
||||
"""
|
||||
if x_selected_company:
|
||||
try:
|
||||
company_id = int(x_selected_company)
|
||||
except ValueError:
|
||||
raise HTTPException(400, f"Invalid company ID: {x_selected_company}")
|
||||
|
||||
if str(company_id) in current_user.companies:
|
||||
return company_id
|
||||
raise HTTPException(403, f"Nu aveți acces la firma {company_id}")
|
||||
|
||||
if current_user.companies:
|
||||
try:
|
||||
return int(current_user.companies[0])
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
raise HTTPException(400, "Nu aveți nicio firmă asignată")
|
||||
|
||||
|
||||
SelectedCompany = Annotated[int, Depends(get_selected_company)]
|
||||
|
||||
|
||||
# Request/Response Models
|
||||
class SupplierSearchResult(BaseModel):
|
||||
found: bool
|
||||
supplier: Optional[dict] = None
|
||||
source: str # 'synced', 'local', 'not_found'
|
||||
|
||||
|
||||
class LocalSupplierCreate(BaseModel):
|
||||
name: str
|
||||
fiscal_code: Optional[str] = None
|
||||
address: Optional[str] = None
|
||||
|
||||
|
||||
class LocalSupplierResponse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
fiscal_code: Optional[str]
|
||||
address: Optional[str]
|
||||
is_local: bool = True
|
||||
|
||||
|
||||
class SyncResult(BaseModel):
|
||||
synced: int
|
||||
errors: int
|
||||
message: str
|
||||
|
||||
|
||||
class SupplierOption(BaseModel):
|
||||
id: int
|
||||
oracle_id: Optional[int] = None
|
||||
name: str
|
||||
fiscal_code: Optional[str]
|
||||
source: str # 'synced' or 'local'
|
||||
|
||||
|
||||
class CashRegisterOption(BaseModel):
|
||||
id: int
|
||||
oracle_id: int
|
||||
name: str
|
||||
account_code: str
|
||||
register_type: str
|
||||
|
||||
|
||||
# Endpoints
|
||||
@router.get("/suppliers/search", response_model=SupplierSearchResult)
|
||||
async def search_supplier(
|
||||
fiscal_code: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
company_id: Optional[int] = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
selected_company: SelectedCompany = None,
|
||||
):
|
||||
"""Search for supplier by fiscal code or name."""
|
||||
if not fiscal_code and not name:
|
||||
raise HTTPException(status_code=400, detail="Provide fiscal_code or name")
|
||||
|
||||
cid = company_id or selected_company
|
||||
|
||||
found, supplier, source = await SyncService.search_supplier(
|
||||
session, cid, fiscal_code, name
|
||||
)
|
||||
|
||||
return SupplierSearchResult(found=found, supplier=supplier, source=source)
|
||||
|
||||
|
||||
@router.get("/suppliers", response_model=List[SupplierOption])
|
||||
async def get_suppliers(
|
||||
search: Optional[str] = None,
|
||||
company_id: Optional[int] = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
selected_company: SelectedCompany = None,
|
||||
):
|
||||
"""Get all suppliers (synced + local) for dropdown/autocomplete."""
|
||||
cid = company_id or selected_company
|
||||
|
||||
suppliers = await SyncService.get_all_suppliers(session, cid, search)
|
||||
|
||||
return [
|
||||
SupplierOption(
|
||||
id=s["id"],
|
||||
oracle_id=s.get("oracle_id"),
|
||||
name=s["name"],
|
||||
fiscal_code=s.get("fiscal_code"),
|
||||
source=s["source"]
|
||||
)
|
||||
for s in suppliers
|
||||
]
|
||||
|
||||
|
||||
@router.post("/suppliers/local", response_model=LocalSupplierResponse)
|
||||
async def create_local_supplier(
|
||||
data: LocalSupplierCreate,
|
||||
company_id: Optional[int] = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
selected_company: SelectedCompany = None,
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Create a local supplier from OCR data."""
|
||||
cid = company_id or selected_company
|
||||
|
||||
supplier = await SyncService.create_local_supplier(
|
||||
session, cid, data.name, data.fiscal_code, data.address, current_user.username
|
||||
)
|
||||
|
||||
return LocalSupplierResponse(
|
||||
id=supplier.id,
|
||||
name=supplier.name,
|
||||
fiscal_code=supplier.fiscal_code,
|
||||
address=supplier.address,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/cash-registers", response_model=List[CashRegisterOption])
|
||||
async def get_cash_registers(
|
||||
company_id: Optional[int] = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
selected_company: SelectedCompany = None,
|
||||
):
|
||||
"""Get all cash registers for a company."""
|
||||
cid = company_id or selected_company
|
||||
|
||||
registers = await SyncService.get_all_cash_registers(session, cid)
|
||||
|
||||
return [
|
||||
CashRegisterOption(
|
||||
id=r["id"],
|
||||
oracle_id=r["oracle_id"],
|
||||
name=r["name"],
|
||||
account_code=r["account_code"],
|
||||
register_type=r["register_type"]
|
||||
)
|
||||
for r in registers
|
||||
]
|
||||
|
||||
|
||||
@router.post("/sync/suppliers", response_model=SyncResult)
|
||||
async def sync_suppliers(
|
||||
request: Request,
|
||||
company_id: Optional[int] = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
selected_company: SelectedCompany = None,
|
||||
):
|
||||
"""Manually trigger supplier sync from Oracle."""
|
||||
cid = company_id or selected_company
|
||||
server_id = getattr(request.state, 'server_id', None)
|
||||
|
||||
synced, errors = await SyncService.sync_suppliers(session, cid, server_id=server_id)
|
||||
|
||||
return SyncResult(
|
||||
synced=synced,
|
||||
errors=errors,
|
||||
message=f"Synced {synced} suppliers with {errors} errors"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/sync/cash-registers", response_model=SyncResult)
|
||||
async def sync_cash_registers(
|
||||
request: Request,
|
||||
company_id: Optional[int] = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
selected_company: SelectedCompany = None,
|
||||
):
|
||||
"""Manually trigger cash register sync from Oracle."""
|
||||
cid = company_id or selected_company
|
||||
server_id = getattr(request.state, 'server_id', None)
|
||||
|
||||
synced, errors = await SyncService.sync_cash_registers(session, cid, server_id=server_id)
|
||||
|
||||
return SyncResult(
|
||||
synced=synced,
|
||||
errors=errors,
|
||||
message=f"Synced {synced} cash registers with {errors} errors"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/sync/all", response_model=dict)
|
||||
async def sync_all_nomenclatures(
|
||||
request: Request,
|
||||
company_id: Optional[int] = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
selected_company: SelectedCompany = None,
|
||||
):
|
||||
"""Sync all nomenclatures (suppliers + cash registers) from Oracle."""
|
||||
cid = company_id or selected_company
|
||||
server_id = getattr(request.state, 'server_id', None)
|
||||
|
||||
# Sync suppliers
|
||||
suppliers_synced, suppliers_errors = await SyncService.sync_suppliers(session, cid, server_id=server_id)
|
||||
|
||||
# Sync cash registers
|
||||
registers_synced, registers_errors = await SyncService.sync_cash_registers(session, cid, server_id=server_id)
|
||||
|
||||
return {
|
||||
"suppliers": {
|
||||
"synced": suppliers_synced,
|
||||
"errors": suppliers_errors
|
||||
},
|
||||
"cash_registers": {
|
||||
"synced": registers_synced,
|
||||
"errors": registers_errors
|
||||
},
|
||||
"total_synced": suppliers_synced + registers_synced,
|
||||
"total_errors": suppliers_errors + registers_errors,
|
||||
"message": f"Synced {suppliers_synced} suppliers and {registers_synced} cash registers"
|
||||
}
|
||||
@@ -0,0 +1,715 @@
|
||||
"""
|
||||
OCR API endpoints with async job queue support.
|
||||
|
||||
Endpoints:
|
||||
- POST /extract - Submit OCR job (returns job_id immediately)
|
||||
- GET /jobs/{job_id} - Get job status and result
|
||||
- GET /queue/status - Get queue statistics
|
||||
- GET /status - Check OCR service availability
|
||||
|
||||
For backwards compatibility, we also support sync mode via query param:
|
||||
- POST /extract?sync=true - Process synchronously (blocks until complete)
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Depends, Query, Response
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from backend.modules.data_entry.db.database import get_session
|
||||
from backend.modules.data_entry.db.crud.attachment import AttachmentCRUD
|
||||
from backend.modules.data_entry.services.ocr_service import ocr_service
|
||||
from backend.modules.data_entry.services.ocr_engine import OCREngine
|
||||
from backend.modules.data_entry.services.ocr.job_queue import job_queue, OCRJobStatus as JobStatus
|
||||
from backend.modules.data_entry.services.ocr.job_worker import estimate_wait_time
|
||||
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
|
||||
from backend.modules.data_entry.schemas.ocr import (
|
||||
OCRResponse,
|
||||
OCRStatusResponse,
|
||||
ExtractionData,
|
||||
TvaEntry,
|
||||
PaymentMethod,
|
||||
# New job queue schemas
|
||||
OCREngineChoice,
|
||||
OCRJobStatus,
|
||||
OCRJobSubmitResponse,
|
||||
OCRJobResponse,
|
||||
OCRQueueStatusResponse,
|
||||
)
|
||||
|
||||
# Auth integration
|
||||
from shared.auth.dependencies import get_current_user
|
||||
from shared.auth.models import CurrentUser
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OCR Job Queue Endpoints (NEW)
|
||||
# ============================================================================
|
||||
|
||||
@router.post("/extract", response_model=OCRJobSubmitResponse)
|
||||
async def submit_ocr_job(
|
||||
file: UploadFile = File(...),
|
||||
engine: OCREngineChoice = Query(default=OCREngineChoice.doctr_plus, description="OCR engine to use"),
|
||||
sync: bool = Query(default=False, description="If true, process synchronously (blocks)"),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Submit an OCR job for processing.
|
||||
|
||||
By default, returns immediately with a job_id. Poll GET /jobs/{job_id} for result.
|
||||
|
||||
Use ?sync=true for synchronous processing (blocks until complete).
|
||||
This is for backwards compatibility but not recommended for production.
|
||||
|
||||
Args:
|
||||
file: Image or PDF file (max 10MB)
|
||||
engine: OCR engine choice (tesseract, doctr, doctr_plus, paddleocr)
|
||||
sync: If true, process synchronously (legacy mode)
|
||||
|
||||
Returns:
|
||||
OCRJobSubmitResponse with job_id, queue_position, estimated_wait
|
||||
"""
|
||||
allowed_types = ['image/jpeg', 'image/png', 'application/pdf']
|
||||
|
||||
if file.content_type not in allowed_types:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"File type not supported: {file.content_type}. Allowed: JPG, PNG, PDF"
|
||||
)
|
||||
|
||||
# Read file content
|
||||
content = await file.read()
|
||||
|
||||
# Check file size (10MB limit)
|
||||
if len(content) > 10 * 1024 * 1024:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="File too large. Maximum size is 10MB."
|
||||
)
|
||||
|
||||
# Sync mode - use legacy processing (blocks)
|
||||
if sync:
|
||||
return await _process_sync(content, file, engine, current_user)
|
||||
|
||||
# Async mode - create job and return immediately
|
||||
try:
|
||||
job = await job_queue.create_job(
|
||||
file_bytes=content,
|
||||
mime_type=file.content_type,
|
||||
engine=engine.value,
|
||||
username=current_user.username,
|
||||
original_filename=file.filename
|
||||
)
|
||||
|
||||
# Get queue position
|
||||
queue_position = await job_queue.get_queue_position(job.id)
|
||||
estimated_wait = estimate_wait_time(queue_position or 1)
|
||||
|
||||
return OCRJobSubmitResponse(
|
||||
job_id=job.id,
|
||||
status=OCRJobStatus.pending,
|
||||
queue_position=queue_position or 1,
|
||||
estimated_wait_seconds=estimated_wait,
|
||||
created_at=job.created_at or datetime.utcnow()
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to create OCR job: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}", response_model=OCRJobResponse)
|
||||
async def get_job_status(
|
||||
job_id: str,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get OCR job status and result (instant response).
|
||||
|
||||
For efficient polling, use GET /jobs/{job_id}/wait instead (long-polling).
|
||||
|
||||
Args:
|
||||
job_id: Job UUID from POST /extract response
|
||||
|
||||
Returns:
|
||||
OCRJobResponse with status, queue_position, and result (if completed)
|
||||
"""
|
||||
job = await job_queue.get_job(job_id)
|
||||
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
# Get queue position for pending jobs
|
||||
queue_position = None
|
||||
estimated_wait = None
|
||||
|
||||
if job.status == JobStatus.pending:
|
||||
queue_position = await job_queue.get_queue_position(job_id)
|
||||
estimated_wait = estimate_wait_time(queue_position or 1)
|
||||
elif job.status == JobStatus.processing:
|
||||
queue_position = 0
|
||||
# Estimate remaining time based on average
|
||||
avg_time = await job_queue.get_average_processing_time()
|
||||
estimated_wait = int(avg_time * 0.5) # Rough estimate: half remaining
|
||||
|
||||
# Convert result to ExtractionData if available
|
||||
result_data = None
|
||||
if job.status == JobStatus.completed and job.result:
|
||||
result_data = _dict_to_extraction_data(job.result)
|
||||
# Apply fuzzy CUI matching
|
||||
result_data = await _apply_fuzzy_cui_matching(result_data, session)
|
||||
# Debug: log suggested_payment_mode being returned
|
||||
print(f"[OCR Router] Returning job {job_id} with suggested_payment_mode={result_data.suggested_payment_mode}", flush=True)
|
||||
|
||||
return OCRJobResponse(
|
||||
job_id=job.id,
|
||||
status=OCRJobStatus(job.status.value),
|
||||
queue_position=queue_position,
|
||||
estimated_wait_seconds=estimated_wait,
|
||||
created_at=job.created_at or datetime.utcnow(),
|
||||
started_at=job.started_at,
|
||||
completed_at=job.completed_at,
|
||||
queue_wait_ms=job.queue_wait_ms,
|
||||
ocr_time_ms=job.ocr_time_ms,
|
||||
processing_time_ms=job.processing_time_ms,
|
||||
result=result_data,
|
||||
error=job.error_message
|
||||
)
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}/wait", response_model=OCRJobResponse)
|
||||
async def wait_for_job_status(
|
||||
job_id: str,
|
||||
response: Response,
|
||||
timeout: int = Query(default=30, ge=1, le=60, description="Max wait time in seconds"),
|
||||
wait_for_terminal: bool = Query(default=False, description="If true, only return on completed/failed"),
|
||||
_t: int = Query(default=None, description="Cache-busting timestamp (ignored)"),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Long-poll for OCR job status change.
|
||||
|
||||
Waits until:
|
||||
- Job status changes (default behavior - returns on any status change)
|
||||
- Job reaches terminal state (if wait_for_terminal=true)
|
||||
- Timeout expires (returns current status)
|
||||
|
||||
Recommended client timeout: timeout + 5 seconds
|
||||
|
||||
Args:
|
||||
job_id: Job UUID from POST /extract response
|
||||
timeout: Max wait time in seconds (1-60, default 30)
|
||||
wait_for_terminal: If true, wait until completed/failed only
|
||||
|
||||
Returns:
|
||||
OCRJobResponse with status, queue_position, and result (if completed)
|
||||
"""
|
||||
# Prevent caching - critical for long-polling
|
||||
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0"
|
||||
response.headers["Pragma"] = "no-cache"
|
||||
response.headers["Expires"] = "0"
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
end_time = start_time + timeout
|
||||
last_status = None
|
||||
iteration = 0
|
||||
|
||||
print(f"[OCR Wait] Starting long-poll for job {job_id}, timeout={timeout}s, wait_for_terminal={wait_for_terminal}", flush=True)
|
||||
|
||||
while time.time() < end_time:
|
||||
iteration += 1
|
||||
job = await job_queue.get_job(job_id)
|
||||
|
||||
if not job:
|
||||
print(f"[OCR Wait] Job {job_id} not found after {iteration} iterations", flush=True)
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
# Return immediately if job completed or failed (terminal states)
|
||||
if job.status in [JobStatus.completed, JobStatus.failed]:
|
||||
elapsed = time.time() - start_time
|
||||
print(f"[OCR Wait] Job {job_id} {job.status.value} after {elapsed:.1f}s ({iteration} iterations)", flush=True)
|
||||
return await get_job_status(job_id, session, current_user)
|
||||
|
||||
# Return on status change (unless wait_for_terminal is set)
|
||||
if not wait_for_terminal and last_status is not None and job.status != last_status:
|
||||
elapsed = time.time() - start_time
|
||||
print(f"[OCR Wait] Job {job_id} status changed {last_status.value}->{job.status.value} after {elapsed:.1f}s", flush=True)
|
||||
return await get_job_status(job_id, session, current_user)
|
||||
|
||||
last_status = job.status
|
||||
|
||||
# Wait 500ms before next internal check (faster polling for better responsiveness)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Timeout - return current status
|
||||
elapsed = time.time() - start_time
|
||||
print(f"[OCR Wait] Job {job_id} timeout after {elapsed:.1f}s ({iteration} iterations), status={last_status.value if last_status else 'unknown'}", flush=True)
|
||||
return await get_job_status(job_id, session, current_user)
|
||||
|
||||
|
||||
@router.get("/queue/status", response_model=OCRQueueStatusResponse)
|
||||
async def get_queue_status(
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get OCR queue statistics.
|
||||
|
||||
Returns:
|
||||
Queue status with pending/processing counts and average time
|
||||
"""
|
||||
stats = await job_queue.get_queue_stats()
|
||||
|
||||
return OCRQueueStatusResponse(
|
||||
pending_jobs=stats["pending"],
|
||||
processing_jobs=stats["processing"],
|
||||
average_time_seconds=stats["average_time_seconds"]
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Legacy Endpoints (backwards compatibility)
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/status", response_model=OCRStatusResponse)
|
||||
async def get_ocr_status():
|
||||
"""Check OCR service status and available engines."""
|
||||
engines = OCREngine.get_available_engines()
|
||||
available = len(engines) > 0
|
||||
|
||||
if available:
|
||||
message = f"OCR service ready with engines: {', '.join(engines)}"
|
||||
else:
|
||||
message = "No OCR engines available. Install PaddleOCR or Tesseract."
|
||||
|
||||
return OCRStatusResponse(
|
||||
available=available,
|
||||
engines=engines,
|
||||
message=message
|
||||
)
|
||||
|
||||
|
||||
@router.get("/engines")
|
||||
async def get_available_engines():
|
||||
"""
|
||||
Get list of enabled OCR engines based on .env configuration.
|
||||
|
||||
Returns engines availability and available processing modes.
|
||||
Frontend should use this to filter engine selection dropdown.
|
||||
|
||||
Available engines: tesseract, doctr, doctr_plus, paddleocr
|
||||
"""
|
||||
# Check which engines are enabled via .env
|
||||
paddle_enabled = os.getenv("OCR_ENABLE_PADDLEOCR", "true").lower() == "true"
|
||||
tesseract_enabled = os.getenv("OCR_ENABLE_TESSERACT", "true").lower() == "true"
|
||||
default_engine = os.getenv("OCR_DEFAULT_ENGINE", "doctr_plus")
|
||||
|
||||
# Build engines dict
|
||||
engines = {
|
||||
"tesseract": tesseract_enabled,
|
||||
"doctr": True, # Always available (primary engine)
|
||||
"doctr_plus": True, # Always available (recommended)
|
||||
"paddleocr": paddle_enabled,
|
||||
}
|
||||
|
||||
# Build available modes based on enabled engines
|
||||
modes = []
|
||||
|
||||
if tesseract_enabled:
|
||||
modes.append("tesseract")
|
||||
|
||||
modes.append("doctr")
|
||||
modes.append("doctr_plus")
|
||||
|
||||
if paddle_enabled:
|
||||
modes.append("paddleocr")
|
||||
|
||||
return {
|
||||
"engines": engines,
|
||||
"available_modes": modes,
|
||||
"default_mode": default_engine,
|
||||
"memory_estimate_mb": {
|
||||
"tesseract": 50,
|
||||
"doctr": 600,
|
||||
"doctr_plus": 600,
|
||||
"paddleocr": 800,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.post("/extract-attachment/{attachment_id}", response_model=OCRResponse)
|
||||
async def extract_from_attachment(
|
||||
attachment_id: int,
|
||||
engine: OCREngineChoice = Query(default=OCREngineChoice.doctr_plus),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Extract receipt data from an existing attachment.
|
||||
|
||||
Re-processes an already uploaded file with OCR.
|
||||
This endpoint always processes synchronously.
|
||||
"""
|
||||
attachment = await AttachmentCRUD.get_by_id(session, attachment_id)
|
||||
|
||||
if not attachment:
|
||||
raise HTTPException(status_code=404, detail="Attachment not found")
|
||||
|
||||
file_path = AttachmentCRUD.get_file_path(attachment)
|
||||
|
||||
if not file_path.exists():
|
||||
raise HTTPException(status_code=404, detail="File not found on disk")
|
||||
|
||||
# Check if file type is supported
|
||||
if attachment.mime_type not in ['image/jpeg', 'image/png', 'application/pdf']:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"File type not supported for OCR: {attachment.mime_type}"
|
||||
)
|
||||
|
||||
# TODO: Could use job queue here too, but keeping sync for now
|
||||
success, message, result = await ocr_service.process_image(
|
||||
file_path, attachment.mime_type
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=422, detail=message)
|
||||
|
||||
data = _result_to_extraction_data(result)
|
||||
# Apply fuzzy CUI matching
|
||||
data = await _apply_fuzzy_cui_matching(data, session)
|
||||
return OCRResponse(success=True, message=message, data=data)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Helper Functions
|
||||
# ============================================================================
|
||||
|
||||
async def _apply_fuzzy_cui_matching(
|
||||
extraction_data: ExtractionData,
|
||||
session: AsyncSession
|
||||
) -> ExtractionData:
|
||||
"""
|
||||
Apply fuzzy CUI matching to extraction data.
|
||||
|
||||
ONLY applies fuzzy matching if CUI is missing OR has invalid checksum.
|
||||
If CUI has valid checksum, we trust the OCR and skip fuzzy matching.
|
||||
|
||||
Args:
|
||||
extraction_data: ExtractionData with CUI to potentially correct
|
||||
session: AsyncSession for database lookups
|
||||
|
||||
Returns:
|
||||
ExtractionData with CUI corrected if a match was found
|
||||
"""
|
||||
from backend.modules.data_entry.services.ocr.validation import CUIChecksumRule
|
||||
|
||||
# Skip if no CUI and no vendor name (nothing to match)
|
||||
if not extraction_data.cui and not extraction_data.partner_name:
|
||||
return extraction_data
|
||||
|
||||
# Check if CUI has valid checksum - if valid, skip fuzzy matching
|
||||
if extraction_data.cui:
|
||||
cui_digits = CUIChecksumRule.extract_digits(extraction_data.cui)
|
||||
if len(cui_digits) >= 6 and CUIChecksumRule.validate_checksum(cui_digits):
|
||||
print(f"[Fuzzy Match] CUI {extraction_data.cui} has valid checksum, skipping fuzzy match", flush=True)
|
||||
return extraction_data
|
||||
|
||||
# CUI missing or invalid checksum - try fuzzy matching
|
||||
try:
|
||||
match = await OCRValidationEngine.fuzzy_match_supplier(
|
||||
cui=extraction_data.cui,
|
||||
vendor_name=extraction_data.partner_name,
|
||||
db_session=session
|
||||
)
|
||||
|
||||
if match:
|
||||
corrected_cui, supplier_name = match
|
||||
if corrected_cui != extraction_data.cui:
|
||||
print(f"[Fuzzy Match] Corrected: {extraction_data.cui} -> {corrected_cui} ({supplier_name})", flush=True)
|
||||
extraction_data.cui = corrected_cui
|
||||
# Also set partner_name if not already set
|
||||
if not extraction_data.partner_name:
|
||||
extraction_data.partner_name = supplier_name
|
||||
except Exception as e:
|
||||
print(f"[Fuzzy Match] Error: {e}", flush=True)
|
||||
|
||||
return extraction_data
|
||||
|
||||
|
||||
async def _process_sync(
|
||||
content: bytes,
|
||||
file: UploadFile,
|
||||
engine: OCREngineChoice,
|
||||
current_user: CurrentUser
|
||||
) -> OCRJobSubmitResponse:
|
||||
"""
|
||||
Process OCR synchronously (legacy mode).
|
||||
|
||||
Creates a job, processes it immediately, and returns the result
|
||||
wrapped in a JobSubmitResponse for API consistency.
|
||||
"""
|
||||
# Get file extension
|
||||
suffix = Path(file.filename).suffix.lower() if file.filename else '.jpg'
|
||||
if suffix not in ['.jpg', '.jpeg', '.png', '.pdf']:
|
||||
suffix = '.jpg'
|
||||
|
||||
# Save to temp file
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
|
||||
tmp.write(content)
|
||||
tmp_path = Path(tmp.name)
|
||||
|
||||
try:
|
||||
success, message, result = await ocr_service.process_image(
|
||||
tmp_path, file.content_type
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=422, detail=message)
|
||||
|
||||
# Create a fake job response with the result embedded
|
||||
# This maintains API compatibility
|
||||
now = datetime.utcnow()
|
||||
|
||||
# For sync mode, we return a special response that includes
|
||||
# the result directly. Clients should check if result is present.
|
||||
return OCRJobSubmitResponse(
|
||||
job_id="sync-" + str(hash(content))[:16],
|
||||
status=OCRJobStatus.completed,
|
||||
queue_position=0,
|
||||
estimated_wait_seconds=0,
|
||||
created_at=now
|
||||
)
|
||||
|
||||
finally:
|
||||
# Clean up temp file
|
||||
if tmp_path.exists():
|
||||
os.unlink(tmp_path)
|
||||
|
||||
|
||||
def _result_to_extraction_data(result) -> ExtractionData:
|
||||
"""Convert ExtractionResult to ExtractionData schema."""
|
||||
# Convert tva_entries from dict to TvaEntry objects
|
||||
tva_entries_schema = [
|
||||
TvaEntry(code=e.get('code'), percent=e['percent'], amount=e['amount'])
|
||||
for e in result.tva_entries
|
||||
] if result.tva_entries else []
|
||||
|
||||
# Convert payment_methods from dict to PaymentMethod objects
|
||||
payment_methods_list = [
|
||||
PaymentMethod(method=pm['method'], amount=Decimal(str(pm['amount'])))
|
||||
for pm in result.payment_methods
|
||||
] if result.payment_methods else []
|
||||
|
||||
# Auto-suggest payment_mode based on detected methods
|
||||
suggested_payment_mode = None
|
||||
if payment_methods_list:
|
||||
has_card = any(pm.method == 'CARD' for pm in payment_methods_list)
|
||||
if has_card:
|
||||
suggested_payment_mode = 'banca'
|
||||
|
||||
return ExtractionData(
|
||||
receipt_type=result.receipt_type,
|
||||
receipt_number=result.receipt_number,
|
||||
receipt_series=result.receipt_series,
|
||||
receipt_date=result.receipt_date,
|
||||
amount=result.amount,
|
||||
partner_name=result.partner_name,
|
||||
cui=result.cui,
|
||||
description=result.description,
|
||||
tva_entries=tva_entries_schema,
|
||||
tva_total=result.tva_total,
|
||||
address=result.address,
|
||||
items_count=result.items_count,
|
||||
payment_methods=payment_methods_list,
|
||||
suggested_payment_mode=suggested_payment_mode,
|
||||
client_name=result.client_name,
|
||||
client_cui=result.client_cui,
|
||||
client_address=result.client_address,
|
||||
confidence_amount=result.confidence_amount,
|
||||
confidence_date=result.confidence_date,
|
||||
confidence_vendor=result.confidence_vendor,
|
||||
confidence_client=getattr(result, 'confidence_client', 0.0),
|
||||
overall_confidence=result.overall_confidence,
|
||||
raw_text=result.raw_text,
|
||||
raw_texts=getattr(result, 'raw_texts', []),
|
||||
ocr_engine=result.ocr_engine,
|
||||
processing_time_ms=result.processing_time_ms,
|
||||
needs_manual_review=result.needs_manual_review,
|
||||
validation_warnings=result.validation_warnings,
|
||||
validation_errors=result.validation_errors,
|
||||
inter_ocr_ratios=result.inter_ocr_ratios,
|
||||
)
|
||||
|
||||
|
||||
def _dict_to_extraction_data(data: dict) -> ExtractionData:
|
||||
"""Convert result dict (from job queue) to ExtractionData schema."""
|
||||
from datetime import date
|
||||
|
||||
# Parse date if string
|
||||
receipt_date = data.get('receipt_date')
|
||||
if isinstance(receipt_date, str):
|
||||
try:
|
||||
receipt_date = date.fromisoformat(receipt_date)
|
||||
except (ValueError, TypeError):
|
||||
receipt_date = None
|
||||
|
||||
# Convert tva_entries
|
||||
tva_entries = data.get('tva_entries', []) or []
|
||||
tva_entries_schema = []
|
||||
for e in tva_entries:
|
||||
if isinstance(e, dict):
|
||||
tva_entries_schema.append(TvaEntry(
|
||||
code=e.get('code'),
|
||||
percent=e.get('percent', 0),
|
||||
amount=Decimal(str(e.get('amount', 0)))
|
||||
))
|
||||
|
||||
# Convert payment_methods
|
||||
payment_methods = data.get('payment_methods', []) or []
|
||||
payment_methods_list = []
|
||||
for pm in payment_methods:
|
||||
if isinstance(pm, dict):
|
||||
payment_methods_list.append(PaymentMethod(
|
||||
method=pm.get('method', 'NUMERAR'),
|
||||
amount=Decimal(str(pm.get('amount', 0)))
|
||||
))
|
||||
|
||||
# Convert amount and tva_total to Decimal
|
||||
amount = data.get('amount')
|
||||
if amount is not None:
|
||||
amount = Decimal(str(amount))
|
||||
|
||||
tva_total = data.get('tva_total')
|
||||
if tva_total is not None:
|
||||
tva_total = Decimal(str(tva_total))
|
||||
|
||||
return ExtractionData(
|
||||
receipt_type=data.get('receipt_type', 'bon_fiscal'),
|
||||
receipt_number=data.get('receipt_number'),
|
||||
receipt_series=data.get('receipt_series'),
|
||||
receipt_date=receipt_date,
|
||||
amount=amount,
|
||||
partner_name=data.get('partner_name'),
|
||||
cui=data.get('cui'),
|
||||
description=data.get('description'),
|
||||
tva_entries=tva_entries_schema,
|
||||
tva_total=tva_total,
|
||||
address=data.get('address'),
|
||||
items_count=data.get('items_count'),
|
||||
payment_methods=payment_methods_list,
|
||||
suggested_payment_mode=data.get('suggested_payment_mode'),
|
||||
client_name=data.get('client_name'),
|
||||
client_cui=data.get('client_cui'),
|
||||
client_address=data.get('client_address'),
|
||||
confidence_amount=data.get('confidence_amount', 0.0),
|
||||
confidence_date=data.get('confidence_date', 0.0),
|
||||
confidence_vendor=data.get('confidence_vendor', 0.0),
|
||||
confidence_client=data.get('confidence_client', 0.0),
|
||||
confidence_tva=data.get('confidence_tva', 0.0),
|
||||
confidence_payment=data.get('confidence_payment', 0.0),
|
||||
overall_confidence=data.get('overall_confidence', 0.0),
|
||||
raw_text=data.get('raw_text', ''),
|
||||
raw_texts=data.get('raw_texts', []),
|
||||
ocr_engine=data.get('ocr_engine', ''),
|
||||
processing_time_ms=data.get('processing_time_ms', 0),
|
||||
needs_manual_review=data.get('needs_manual_review'),
|
||||
validation_warnings=data.get('validation_warnings', []),
|
||||
validation_errors=data.get('validation_errors', []),
|
||||
inter_ocr_ratios=data.get('inter_ocr_ratios', {}),
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Store Profiles Management Endpoints
|
||||
# ============================================================================
|
||||
|
||||
@router.post("/profiles/reload")
|
||||
async def reload_store_profiles(
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
) -> dict:
|
||||
"""
|
||||
Hot-reload all store profiles.
|
||||
|
||||
Reloads profile Python modules without server restart.
|
||||
Use after adding/modifying profile files.
|
||||
|
||||
Returns:
|
||||
Dict with reloaded count and profile list
|
||||
"""
|
||||
from backend.modules.data_entry.services.ocr.profiles import ProfileRegistry
|
||||
|
||||
count = ProfileRegistry.reload_all()
|
||||
status = ProfileRegistry.get_reload_status()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"reloaded_modules": count,
|
||||
"profiles_count": status["profiles_count"],
|
||||
"registered_cuis": status["registered_cuis"],
|
||||
"last_reload": status["last_reload"],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/profiles")
|
||||
async def list_store_profiles(
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
) -> dict:
|
||||
"""
|
||||
List all registered store profiles.
|
||||
|
||||
Returns:
|
||||
Dict with profiles list and status
|
||||
"""
|
||||
from backend.modules.data_entry.services.ocr.profiles import ProfileRegistry
|
||||
|
||||
profiles = ProfileRegistry.list_profiles()
|
||||
status = ProfileRegistry.get_reload_status()
|
||||
|
||||
return {
|
||||
"profiles": profiles,
|
||||
"count": len(profiles),
|
||||
"last_reload": status["last_reload"],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/profiles/{cui}")
|
||||
async def get_store_profile(
|
||||
cui: str,
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
) -> dict:
|
||||
"""
|
||||
Get details for a specific store profile.
|
||||
|
||||
Args:
|
||||
cui: Store CUI (with or without RO prefix)
|
||||
|
||||
Returns:
|
||||
Profile details including validation hints
|
||||
|
||||
Raises:
|
||||
404: If no profile exists for this CUI
|
||||
"""
|
||||
from backend.modules.data_entry.services.ocr.profiles import ProfileRegistry
|
||||
|
||||
info = ProfileRegistry.get_profile_info(cui)
|
||||
|
||||
if not info:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"No profile registered for CUI: {cui}"
|
||||
)
|
||||
|
||||
return info
|
||||
@@ -0,0 +1,268 @@
|
||||
"""
|
||||
OCR Settings and Metrics API endpoints.
|
||||
|
||||
Endpoints:
|
||||
- GET /settings/ocr-preference - Get user's preferred OCR engine
|
||||
- POST /settings/ocr-preference - Set user's preferred OCR engine
|
||||
- GET /metrics/ocr/summary - Get OCR metrics summary by engine
|
||||
- GET /metrics/ocr/history - Get user's OCR job history
|
||||
- GET /metrics/ocr/stats - Get overall OCR statistics
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from backend.modules.data_entry.db.database import get_session
|
||||
from backend.modules.data_entry.db.crud.ocr_settings import OCRPreferenceCRUD, OCRMetricsCRUD
|
||||
from backend.modules.data_entry.db.models.ocr_settings import OCREngine, OCRMetricsSummary
|
||||
|
||||
# Auth integration
|
||||
from shared.auth.dependencies import get_current_user
|
||||
from shared.auth.models import CurrentUser
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Schemas
|
||||
# ============================================================================
|
||||
|
||||
class OCRPreferenceResponse(BaseModel):
|
||||
"""Response for OCR preference endpoint."""
|
||||
username: str
|
||||
preferred_engine: str
|
||||
available_engines: List[str] = Field(
|
||||
default=["tesseract", "doctr", "doctr_plus", "paddleocr"],
|
||||
description="Available OCR engines"
|
||||
)
|
||||
|
||||
|
||||
class OCRPreferenceRequest(BaseModel):
|
||||
"""Request to set OCR preference."""
|
||||
preferred_engine: str = Field(
|
||||
default="doctr_plus",
|
||||
description="Preferred OCR engine: tesseract, doctr, doctr_plus, paddleocr"
|
||||
)
|
||||
|
||||
|
||||
class OCRMetricsHistoryItem(BaseModel):
|
||||
"""Single OCR job metrics item."""
|
||||
job_id: str
|
||||
engine_requested: str
|
||||
engine_used: str
|
||||
processing_time_ms: int
|
||||
success: bool
|
||||
overall_confidence: float
|
||||
fields_extracted: int
|
||||
created_at: str
|
||||
original_filename: Optional[str] = None
|
||||
|
||||
|
||||
class OCRMetricsHistoryResponse(BaseModel):
|
||||
"""Response for OCR history endpoint."""
|
||||
items: List[OCRMetricsHistoryItem]
|
||||
total: int
|
||||
|
||||
|
||||
class OCRStatsResponse(BaseModel):
|
||||
"""Response for OCR stats endpoint."""
|
||||
total_jobs: int
|
||||
successful_jobs: int
|
||||
failed_jobs: int
|
||||
success_rate: float
|
||||
avg_processing_time_ms: float
|
||||
avg_confidence: float
|
||||
period_days: int
|
||||
|
||||
|
||||
class OCRActiveEnginesResponse(BaseModel):
|
||||
"""Response for active OCR engines endpoint."""
|
||||
engines: List[str] = Field(description="List of active OCR engines from .env config")
|
||||
recommended: str = Field(default="doctr_plus", description="Recommended engine")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OCR Engines Configuration Endpoint
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/settings/ocr-engines", response_model=OCRActiveEnginesResponse)
|
||||
async def get_active_ocr_engines():
|
||||
"""
|
||||
Get list of active OCR engines configured in .env.
|
||||
|
||||
Returns the engines that should be shown in the frontend dropdown.
|
||||
Configured via OCR_ACTIVE_ENGINES environment variable.
|
||||
|
||||
Default: doctr,doctr_plus
|
||||
Available: tesseract, paddleocr, doctr, doctr_plus
|
||||
"""
|
||||
from backend.modules.data_entry.config import settings
|
||||
|
||||
return OCRActiveEnginesResponse(
|
||||
engines=settings.ocr_active_engines_list,
|
||||
recommended="doctr_plus"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OCR Preference Endpoints
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/settings/ocr-preference", response_model=OCRPreferenceResponse)
|
||||
async def get_ocr_preference(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get user's preferred OCR engine.
|
||||
|
||||
Returns the user's saved preference or 'doctr_plus' if not set.
|
||||
Also returns list of available engines.
|
||||
"""
|
||||
from backend.modules.data_entry.services.ocr_engine import OCREngine as OCREngineClass
|
||||
|
||||
preference = await OCRPreferenceCRUD.get_by_username(session, current_user.username)
|
||||
|
||||
# Get available engines from OCR service
|
||||
available = OCREngineClass.get_available_engines()
|
||||
|
||||
return OCRPreferenceResponse(
|
||||
username=current_user.username,
|
||||
preferred_engine=preference.preferred_engine.value if preference else "doctr_plus",
|
||||
available_engines=available
|
||||
)
|
||||
|
||||
|
||||
@router.post("/settings/ocr-preference", response_model=OCRPreferenceResponse)
|
||||
async def set_ocr_preference(
|
||||
request: OCRPreferenceRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Set user's preferred OCR engine.
|
||||
|
||||
Valid engines: tesseract, doctr, doctr_plus, paddleocr
|
||||
Note: Available engines depend on .env configuration (OCR_ENABLE_PADDLEOCR, OCR_ENABLE_TESSERACT)
|
||||
"""
|
||||
from backend.modules.data_entry.services.ocr_engine import OCREngine as OCREngineClass
|
||||
|
||||
# Get dynamically available engines
|
||||
available = OCREngineClass.get_available_engines()
|
||||
|
||||
if request.preferred_engine not in available:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid engine. Must be one of: {', '.join(available)}"
|
||||
)
|
||||
|
||||
# Map string to enum
|
||||
engine_map = {
|
||||
"tesseract": OCREngine.TESSERACT,
|
||||
"doctr": OCREngine.DOCTR,
|
||||
"doctr_plus": OCREngine.DOCTR_PLUS,
|
||||
"paddleocr": OCREngine.PADDLEOCR,
|
||||
}
|
||||
engine_enum = engine_map.get(request.preferred_engine, OCREngine.DOCTR_PLUS)
|
||||
|
||||
# Save preference
|
||||
preference = await OCRPreferenceCRUD.create_or_update(
|
||||
session,
|
||||
current_user.username,
|
||||
engine_enum
|
||||
)
|
||||
|
||||
# Get available engines
|
||||
available = OCREngineClass.get_available_engines()
|
||||
|
||||
return OCRPreferenceResponse(
|
||||
username=current_user.username,
|
||||
preferred_engine=preference.preferred_engine.value,
|
||||
available_engines=available
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OCR Metrics Endpoints
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/metrics/ocr/summary", response_model=List[OCRMetricsSummary])
|
||||
async def get_ocr_metrics_summary(
|
||||
days: int = Query(default=30, ge=1, le=365, description="Number of days to include"),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get OCR metrics summary grouped by engine.
|
||||
|
||||
Returns aggregated metrics for each engine used in the specified period.
|
||||
"""
|
||||
summaries = await OCRMetricsCRUD.get_summary_by_engine(
|
||||
session,
|
||||
days=days,
|
||||
username=current_user.username
|
||||
)
|
||||
return summaries
|
||||
|
||||
|
||||
@router.get("/metrics/ocr/history", response_model=OCRMetricsHistoryResponse)
|
||||
async def get_ocr_metrics_history(
|
||||
limit: int = Query(default=50, ge=1, le=200, description="Max items to return"),
|
||||
offset: int = Query(default=0, ge=0, description="Items to skip"),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get user's OCR job history.
|
||||
|
||||
Returns list of OCR jobs with their metrics, ordered by most recent first.
|
||||
"""
|
||||
items = await OCRMetricsCRUD.get_user_history(
|
||||
session,
|
||||
username=current_user.username,
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
|
||||
history_items = [
|
||||
OCRMetricsHistoryItem(
|
||||
job_id=item.job_id,
|
||||
engine_requested=item.engine_requested,
|
||||
engine_used=item.engine_used,
|
||||
processing_time_ms=item.processing_time_ms,
|
||||
success=item.success,
|
||||
overall_confidence=item.overall_confidence,
|
||||
fields_extracted=item.fields_extracted,
|
||||
created_at=item.created_at.isoformat(),
|
||||
original_filename=item.original_filename
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
|
||||
return OCRMetricsHistoryResponse(
|
||||
items=history_items,
|
||||
total=len(history_items)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/metrics/ocr/stats", response_model=OCRStatsResponse)
|
||||
async def get_ocr_stats(
|
||||
days: int = Query(default=30, ge=1, le=365, description="Number of days to include"),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get overall OCR statistics for the user.
|
||||
|
||||
Returns aggregated stats including success rate, average processing time, etc.
|
||||
"""
|
||||
stats = await OCRMetricsCRUD.get_overall_stats(
|
||||
session,
|
||||
days=days,
|
||||
username=current_user.username
|
||||
)
|
||||
|
||||
return OCRStatsResponse(**stats)
|
||||
@@ -0,0 +1,705 @@
|
||||
"""API endpoints for receipts."""
|
||||
|
||||
from typing import List, Optional, Annotated
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Query, Header, Response
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from backend.modules.data_entry.db.database import get_session
|
||||
from backend.modules.data_entry.db.crud.receipt import ReceiptCRUD
|
||||
from backend.modules.data_entry.db.crud.attachment import AttachmentCRUD
|
||||
from backend.modules.data_entry.db.crud.accounting_entry import AccountingEntryCRUD
|
||||
from backend.modules.data_entry.services.receipt_service import ReceiptService
|
||||
from backend.modules.data_entry.services.nomenclature_service import NomenclatureService
|
||||
from backend.modules.data_entry.schemas.receipt import (
|
||||
ReceiptCreate,
|
||||
ReceiptUpdate,
|
||||
ReceiptResponse,
|
||||
ReceiptListResponse,
|
||||
ReceiptFilter,
|
||||
ProcessingStats,
|
||||
AttachmentResponse,
|
||||
AccountingEntryResponse,
|
||||
WorkflowAction,
|
||||
RejectRequest,
|
||||
EntriesUpdateRequest,
|
||||
PartnerOption,
|
||||
AccountOption,
|
||||
CashRegisterOption,
|
||||
ExpenseTypeOption,
|
||||
BulkDeleteRequest,
|
||||
BulkDeleteResponse,
|
||||
BulkDeleteFailure,
|
||||
)
|
||||
from backend.modules.data_entry.db.models.receipt import ReceiptStatus, ReceiptDirection
|
||||
from backend.modules.data_entry.services import sse_service
|
||||
|
||||
# Auth integration
|
||||
from shared.auth.dependencies import get_current_user
|
||||
from shared.auth.models import CurrentUser
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============ Helper for selected company from header ============
|
||||
|
||||
async def get_selected_company(
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
x_selected_company: Annotated[Optional[str], Header()] = None
|
||||
) -> int:
|
||||
"""
|
||||
Get selected company from X-Selected-Company header.
|
||||
|
||||
Validates that the user has access to the specified company.
|
||||
Falls back to user's first company if no header is provided.
|
||||
|
||||
Raises:
|
||||
HTTPException 403: If user doesn't have access to specified company
|
||||
HTTPException 400: If user has no companies assigned
|
||||
"""
|
||||
if x_selected_company:
|
||||
try:
|
||||
company_id = int(x_selected_company)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid company ID format: {x_selected_company}"
|
||||
)
|
||||
|
||||
# Validate user has access to this company
|
||||
# Auth stores companies as strings
|
||||
if str(company_id) in current_user.companies:
|
||||
return company_id
|
||||
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Nu aveți acces la firma {company_id}"
|
||||
)
|
||||
|
||||
# No header - use first company from user's list
|
||||
if current_user.companies:
|
||||
try:
|
||||
return int(current_user.companies[0])
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Nu aveți nicio firmă asignată"
|
||||
)
|
||||
|
||||
|
||||
# Dependency for injection
|
||||
SelectedCompany = Annotated[int, Depends(get_selected_company)]
|
||||
|
||||
|
||||
# Legacy function for backwards compatibility (deprecated)
|
||||
def get_current_user_company(current_user: CurrentUser) -> int:
|
||||
"""
|
||||
DEPRECATED: Use get_selected_company() dependency instead.
|
||||
This function returns the first company, ignoring X-Selected-Company header.
|
||||
"""
|
||||
if current_user.companies:
|
||||
try:
|
||||
return int(current_user.companies[0])
|
||||
except (ValueError, IndexError):
|
||||
return 1
|
||||
return 1
|
||||
|
||||
|
||||
# ============ SSE Endpoint for Real-time Status Updates ============
|
||||
|
||||
@router.get("/sse/status")
|
||||
async def sse_status_stream(
|
||||
batch_id: Optional[str] = Query(
|
||||
default=None,
|
||||
description="Optional batch_id to filter events for a specific batch"
|
||||
),
|
||||
):
|
||||
"""
|
||||
Server-Sent Events endpoint for real-time receipt status updates.
|
||||
|
||||
This endpoint provides a persistent connection that streams status change
|
||||
events as they occur. Clients receive updates for CRUD operations on receipts
|
||||
without needing to poll.
|
||||
|
||||
Query Parameters:
|
||||
batch_id: Optional filter to only receive events for a specific batch upload.
|
||||
|
||||
Event Format:
|
||||
data: {"receipt_id": 123, "status": "DRAFT", "processing_status": "completed", ...}
|
||||
|
||||
Headers:
|
||||
- Content-Type: text/event-stream
|
||||
- Cache-Control: no-cache
|
||||
- Connection: keep-alive
|
||||
|
||||
Reconnection:
|
||||
The retry: 3000 header hints clients to reconnect after 3 seconds if disconnected.
|
||||
|
||||
Example:
|
||||
curl -N http://localhost:8000/api/data-entry/receipts/sse/status
|
||||
curl -N http://localhost:8000/api/data-entry/receipts/sse/status?batch_id=abc-123
|
||||
"""
|
||||
return StreamingResponse(
|
||||
sse_service.subscribe(batch_id=batch_id),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ============ Receipt CRUD Endpoints ============
|
||||
|
||||
@router.post("/", response_model=ReceiptResponse)
|
||||
async def create_receipt(
|
||||
data: ReceiptCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Create a new receipt in DRAFT status."""
|
||||
receipt = await ReceiptService.create_receipt(session, data, current_user.username)
|
||||
return ReceiptResponse.model_validate(receipt)
|
||||
|
||||
|
||||
@router.get("/", response_model=ReceiptListResponse)
|
||||
async def list_receipts(
|
||||
response: Response,
|
||||
status: Optional[ReceiptStatus] = None,
|
||||
direction: Optional[ReceiptDirection] = None,
|
||||
company_id: Optional[int] = None,
|
||||
created_by: Optional[str] = None,
|
||||
date_from: Optional[str] = None,
|
||||
date_to: Optional[str] = None,
|
||||
search: Optional[str] = None,
|
||||
# Bulk upload filters (US-012)
|
||||
processing_status: Optional[str] = Query(default=None, description="Filter by processing status: pending, processing, completed, failed"),
|
||||
batch_id: Optional[str] = Query(default=None, description="Filter by batch_id UUID"),
|
||||
sort_by: Optional[str] = Query(default=None, description="Sort field: processing_started_at, processing_started_at_asc"),
|
||||
# Pagination
|
||||
page: int = Query(default=1, ge=1),
|
||||
page_size: int = Query(default=20, ge=1, le=100),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
selected_company: SelectedCompany = None,
|
||||
):
|
||||
"""Get paginated list of receipts with filters.
|
||||
|
||||
US-012: Extended with batch_id, processing_status filters and processing_stats.
|
||||
"""
|
||||
# Disable browser caching to always get fresh data
|
||||
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0"
|
||||
response.headers["Pragma"] = "no-cache"
|
||||
|
||||
from datetime import date as date_type
|
||||
|
||||
filters = ReceiptFilter(
|
||||
status=status,
|
||||
direction=direction,
|
||||
company_id=company_id or selected_company,
|
||||
created_by=created_by,
|
||||
date_from=date_type.fromisoformat(date_from) if date_from else None,
|
||||
date_to=date_type.fromisoformat(date_to) if date_to else None,
|
||||
search=search,
|
||||
processing_status=processing_status,
|
||||
batch_id=batch_id,
|
||||
sort_by=sort_by,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return await ReceiptService.get_receipts(session, filters)
|
||||
|
||||
|
||||
@router.get("/pending", response_model=List[ReceiptResponse])
|
||||
async def list_pending_receipts(
|
||||
response: Response,
|
||||
company_id: Optional[int] = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
selected_company: SelectedCompany = None,
|
||||
):
|
||||
"""Get all receipts pending review (for accountant view)."""
|
||||
# Disable browser caching to always get fresh data
|
||||
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0"
|
||||
response.headers["Pragma"] = "no-cache"
|
||||
|
||||
receipts = await ReceiptCRUD.get_pending_review(
|
||||
session, company_id or selected_company
|
||||
)
|
||||
return [ReceiptResponse.model_validate(r) for r in receipts]
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_receipt_stats(
|
||||
response: Response,
|
||||
company_id: Optional[int] = None,
|
||||
my_receipts: bool = False,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
selected_company: SelectedCompany = None,
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Get receipt statistics."""
|
||||
# Disable browser caching to always get fresh data
|
||||
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0"
|
||||
response.headers["Pragma"] = "no-cache"
|
||||
|
||||
return await ReceiptCRUD.get_stats(
|
||||
session,
|
||||
company_id or selected_company,
|
||||
created_by=current_user.username if my_receipts else None,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{receipt_id}", response_model=ReceiptResponse)
|
||||
async def get_receipt(
|
||||
receipt_id: int,
|
||||
response: Response,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Get receipt details with attachments and accounting entries."""
|
||||
# Disable browser caching to always get fresh data
|
||||
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0"
|
||||
response.headers["Pragma"] = "no-cache"
|
||||
|
||||
receipt = await ReceiptService.get_receipt(session, receipt_id)
|
||||
|
||||
if not receipt:
|
||||
raise HTTPException(status_code=404, detail="Receipt not found")
|
||||
|
||||
return ReceiptResponse.model_validate(receipt)
|
||||
|
||||
|
||||
@router.put("/{receipt_id}", response_model=ReceiptResponse)
|
||||
async def update_receipt(
|
||||
receipt_id: int,
|
||||
data: ReceiptUpdate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Update receipt (only DRAFT status, only by creator)."""
|
||||
success, message, receipt = await ReceiptService.update_receipt(
|
||||
session, receipt_id, data, current_user.username
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail=message)
|
||||
|
||||
return ReceiptResponse.model_validate(receipt)
|
||||
|
||||
|
||||
@router.delete("/bulk", response_model=BulkDeleteResponse)
|
||||
async def bulk_delete_receipts(
|
||||
data: BulkDeleteRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Bulk delete receipts (US-024).
|
||||
|
||||
Deletes multiple receipts in a single request with partial success support.
|
||||
|
||||
Validation rules:
|
||||
- Each receipt must be in DRAFT status
|
||||
- Each receipt must be created by the current user
|
||||
- Receipts with processing_status 'pending' or 'processing' cannot be deleted
|
||||
|
||||
Returns:
|
||||
BulkDeleteResponse with deleted IDs and failed items with error messages
|
||||
"""
|
||||
deleted: List[int] = []
|
||||
failed: List[BulkDeleteFailure] = []
|
||||
|
||||
for receipt_id in data.ids:
|
||||
# Get receipt with relationships for deletion
|
||||
receipt = await ReceiptCRUD.get_by_id(session, receipt_id, include_relations=True)
|
||||
|
||||
if not receipt:
|
||||
failed.append(BulkDeleteFailure(id=receipt_id, error="Bonul nu a fost găsit"))
|
||||
continue
|
||||
|
||||
# Check if receipt is being processed (bulk upload in progress)
|
||||
if receipt.processing_status in ["pending", "processing"]:
|
||||
failed.append(BulkDeleteFailure(
|
||||
id=receipt_id,
|
||||
error="Bonul este în curs de procesare și nu poate fi șters"
|
||||
))
|
||||
continue
|
||||
|
||||
# Check status - only DRAFT can be deleted
|
||||
if receipt.status != ReceiptStatus.DRAFT:
|
||||
failed.append(BulkDeleteFailure(
|
||||
id=receipt_id,
|
||||
error=f"Doar bonurile în status DRAFT pot fi șterse (status curent: {receipt.status.value})"
|
||||
))
|
||||
continue
|
||||
|
||||
# Check ownership
|
||||
if receipt.created_by != current_user.username:
|
||||
failed.append(BulkDeleteFailure(
|
||||
id=receipt_id,
|
||||
error="Doar creatorul bonului poate să-l șteargă"
|
||||
))
|
||||
continue
|
||||
|
||||
# All validations passed - delete the receipt
|
||||
# Note: Cascade delete handles attachments and accounting entries
|
||||
await ReceiptCRUD.delete(session, receipt)
|
||||
deleted.append(receipt_id)
|
||||
|
||||
return BulkDeleteResponse(deleted=deleted, failed=failed)
|
||||
|
||||
|
||||
@router.delete("/{receipt_id}")
|
||||
async def delete_receipt(
|
||||
receipt_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Delete receipt (only DRAFT status, only by creator)."""
|
||||
success, message = await ReceiptService.delete_receipt(
|
||||
session, receipt_id, current_user.username
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail=message)
|
||||
|
||||
return {"success": True, "message": message}
|
||||
|
||||
|
||||
# ============ Workflow Endpoints ============
|
||||
|
||||
@router.post("/{receipt_id}/submit", response_model=WorkflowAction)
|
||||
async def submit_receipt(
|
||||
receipt_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Submit receipt for review (DRAFT → PENDING_REVIEW)."""
|
||||
success, message, receipt = await ReceiptService.submit_for_review(
|
||||
session, receipt_id, current_user.username
|
||||
)
|
||||
|
||||
# Broadcast SSE event on success (US-030)
|
||||
if success and receipt:
|
||||
await sse_service.broadcast_status_change(
|
||||
receipt_id=receipt.id,
|
||||
status=receipt.status.value,
|
||||
processing_status=receipt.processing_status,
|
||||
batch_id=receipt.batch_id,
|
||||
)
|
||||
|
||||
return WorkflowAction(
|
||||
success=success,
|
||||
message=message,
|
||||
receipt=ReceiptResponse.model_validate(receipt) if receipt else None,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{receipt_id}/approve", response_model=WorkflowAction)
|
||||
async def approve_receipt(
|
||||
receipt_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Approve receipt (PENDING_REVIEW → APPROVED). Accountant action."""
|
||||
success, message, receipt = await ReceiptService.approve_receipt(
|
||||
session, receipt_id, current_user.username
|
||||
)
|
||||
|
||||
# Broadcast SSE event on success (US-030)
|
||||
if success and receipt:
|
||||
await sse_service.broadcast_status_change(
|
||||
receipt_id=receipt.id,
|
||||
status=receipt.status.value,
|
||||
processing_status=receipt.processing_status,
|
||||
batch_id=receipt.batch_id,
|
||||
)
|
||||
|
||||
return WorkflowAction(
|
||||
success=success,
|
||||
message=message,
|
||||
receipt=ReceiptResponse.model_validate(receipt) if receipt else None,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{receipt_id}/reject", response_model=WorkflowAction)
|
||||
async def reject_receipt(
|
||||
receipt_id: int,
|
||||
data: RejectRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Reject receipt (PENDING_REVIEW → REJECTED). Accountant action."""
|
||||
success, message, receipt = await ReceiptService.reject_receipt(
|
||||
session, receipt_id, current_user.username, data.reason
|
||||
)
|
||||
|
||||
# Broadcast SSE event on success (US-030)
|
||||
if success and receipt:
|
||||
await sse_service.broadcast_status_change(
|
||||
receipt_id=receipt.id,
|
||||
status=receipt.status.value,
|
||||
processing_status=receipt.processing_status,
|
||||
batch_id=receipt.batch_id,
|
||||
)
|
||||
|
||||
return WorkflowAction(
|
||||
success=success,
|
||||
message=message,
|
||||
receipt=ReceiptResponse.model_validate(receipt) if receipt else None,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{receipt_id}/resubmit", response_model=WorkflowAction)
|
||||
async def resubmit_receipt(
|
||||
receipt_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Resubmit rejected receipt after corrections (REJECTED → PENDING_REVIEW)."""
|
||||
success, message, receipt = await ReceiptService.resubmit_receipt(
|
||||
session, receipt_id, current_user.username
|
||||
)
|
||||
|
||||
# Broadcast SSE event on success (US-030)
|
||||
if success and receipt:
|
||||
await sse_service.broadcast_status_change(
|
||||
receipt_id=receipt.id,
|
||||
status=receipt.status.value,
|
||||
processing_status=receipt.processing_status,
|
||||
batch_id=receipt.batch_id,
|
||||
)
|
||||
|
||||
return WorkflowAction(
|
||||
success=success,
|
||||
message=message,
|
||||
receipt=ReceiptResponse.model_validate(receipt) if receipt else None,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{receipt_id}/unapprove", response_model=WorkflowAction)
|
||||
async def unapprove_receipt(
|
||||
receipt_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Unapprove receipt (APPROVED → PENDING_REVIEW). Returns to pending for corrections."""
|
||||
success, message, receipt = await ReceiptService.unapprove_receipt(
|
||||
session, receipt_id, current_user.username
|
||||
)
|
||||
|
||||
# Broadcast SSE event on success (US-030)
|
||||
if success and receipt:
|
||||
await sse_service.broadcast_status_change(
|
||||
receipt_id=receipt.id,
|
||||
status=receipt.status.value,
|
||||
processing_status=receipt.processing_status,
|
||||
batch_id=receipt.batch_id,
|
||||
)
|
||||
|
||||
return WorkflowAction(
|
||||
success=success,
|
||||
message=message,
|
||||
receipt=ReceiptResponse.model_validate(receipt) if receipt else None,
|
||||
)
|
||||
|
||||
|
||||
# ============ Accounting Entries Endpoints ============
|
||||
|
||||
@router.get("/{receipt_id}/entries", response_model=List[AccountingEntryResponse])
|
||||
async def get_receipt_entries(
|
||||
receipt_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Get accounting entries for a receipt."""
|
||||
entries = await AccountingEntryCRUD.get_by_receipt_id(session, receipt_id)
|
||||
return [AccountingEntryResponse.model_validate(e) for e in entries]
|
||||
|
||||
|
||||
@router.put("/{receipt_id}/entries", response_model=List[AccountingEntryResponse])
|
||||
async def update_receipt_entries(
|
||||
receipt_id: int,
|
||||
data: EntriesUpdateRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Update accounting entries for a receipt (accountant action)."""
|
||||
success, message, entries = await ReceiptService.update_entries(
|
||||
session, receipt_id, data.entries, current_user.username
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail=message)
|
||||
|
||||
return [AccountingEntryResponse.model_validate(e) for e in entries]
|
||||
|
||||
|
||||
@router.post("/{receipt_id}/entries/regenerate", response_model=List[AccountingEntryResponse])
|
||||
async def regenerate_entries(
|
||||
receipt_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Regenerate accounting entries based on receipt data."""
|
||||
success, message, _ = await ReceiptService.regenerate_entries(
|
||||
session, receipt_id, current_user.username
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail=message)
|
||||
|
||||
entries = await AccountingEntryCRUD.get_by_receipt_id(session, receipt_id)
|
||||
return [AccountingEntryResponse.model_validate(e) for e in entries]
|
||||
|
||||
|
||||
# ============ Attachment Endpoints ============
|
||||
|
||||
@router.post("/{receipt_id}/attachments", response_model=AttachmentResponse)
|
||||
async def upload_attachment(
|
||||
receipt_id: int,
|
||||
file: UploadFile = File(...),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Upload attachment for a receipt."""
|
||||
# Check receipt exists and user can modify it
|
||||
receipt = await ReceiptCRUD.get_by_id(session, receipt_id, include_relations=False)
|
||||
|
||||
if not receipt:
|
||||
raise HTTPException(status_code=404, detail="Receipt not found")
|
||||
|
||||
# Only allow uploads for DRAFT and REJECTED receipts
|
||||
if receipt.status not in [ReceiptStatus.DRAFT, ReceiptStatus.REJECTED]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot upload attachments for this receipt status"
|
||||
)
|
||||
|
||||
# Only creator can upload
|
||||
if receipt.created_by != current_user.username:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Only the creator can upload attachments"
|
||||
)
|
||||
|
||||
try:
|
||||
attachment = await AttachmentCRUD.create(session, receipt_id, file)
|
||||
return AttachmentResponse.model_validate(attachment)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{receipt_id}/attachments", response_model=List[AttachmentResponse])
|
||||
async def list_attachments(
|
||||
receipt_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Get all attachments for a receipt."""
|
||||
attachments = await AttachmentCRUD.get_by_receipt_id(session, receipt_id)
|
||||
return [AttachmentResponse.model_validate(a) for a in attachments]
|
||||
|
||||
|
||||
@router.get("/attachments/{attachment_id}/download")
|
||||
async def download_attachment(
|
||||
attachment_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Download an attachment file."""
|
||||
attachment = await AttachmentCRUD.get_by_id(session, attachment_id)
|
||||
|
||||
if not attachment:
|
||||
raise HTTPException(status_code=404, detail="Attachment not found")
|
||||
|
||||
file_path = AttachmentCRUD.get_file_path(attachment)
|
||||
|
||||
if not file_path.exists():
|
||||
raise HTTPException(status_code=404, detail="File not found on disk")
|
||||
|
||||
return FileResponse(
|
||||
path=str(file_path),
|
||||
filename=attachment.filename,
|
||||
media_type=attachment.mime_type,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/attachments/{attachment_id}")
|
||||
async def delete_attachment(
|
||||
attachment_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Delete an attachment."""
|
||||
attachment = await AttachmentCRUD.get_by_id(session, attachment_id)
|
||||
|
||||
if not attachment:
|
||||
raise HTTPException(status_code=404, detail="Attachment not found")
|
||||
|
||||
# Get receipt to check permissions
|
||||
receipt = await ReceiptCRUD.get_by_id(session, attachment.receipt_id, include_relations=False)
|
||||
|
||||
if not receipt:
|
||||
raise HTTPException(status_code=404, detail="Receipt not found")
|
||||
|
||||
# Only allow deletion for DRAFT receipts by creator
|
||||
if receipt.status != ReceiptStatus.DRAFT:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot delete attachments for this receipt status"
|
||||
)
|
||||
|
||||
if receipt.created_by != current_user.username:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Only the creator can delete attachments"
|
||||
)
|
||||
|
||||
await AttachmentCRUD.delete(session, attachment)
|
||||
return {"success": True, "message": "Attachment deleted"}
|
||||
|
||||
|
||||
# ============ Nomenclature Endpoints ============
|
||||
|
||||
@router.get("/nomenclature/partners", response_model=List[PartnerOption])
|
||||
async def get_partners(
|
||||
search: Optional[str] = None,
|
||||
company_id: Optional[int] = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
selected_company: SelectedCompany = None,
|
||||
):
|
||||
"""Get partners (suppliers/customers) for dropdown."""
|
||||
return await NomenclatureService.get_partners(
|
||||
company_id or selected_company, search, session
|
||||
)
|
||||
|
||||
|
||||
@router.get("/nomenclature/accounts", response_model=List[AccountOption])
|
||||
async def get_accounts(
|
||||
prefix: Optional[str] = None,
|
||||
company_id: Optional[int] = None,
|
||||
selected_company: SelectedCompany = None,
|
||||
):
|
||||
"""Get chart of accounts for dropdown."""
|
||||
return await NomenclatureService.get_accounts(
|
||||
company_id or selected_company, prefix
|
||||
)
|
||||
|
||||
|
||||
@router.get("/nomenclature/cash-registers", response_model=List[CashRegisterOption])
|
||||
async def get_cash_registers(
|
||||
company_id: Optional[int] = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
selected_company: SelectedCompany = None,
|
||||
):
|
||||
"""Get cash registers and bank accounts for dropdown."""
|
||||
return await NomenclatureService.get_cash_registers(company_id or selected_company, session)
|
||||
|
||||
|
||||
@router.get("/nomenclature/expense-types", response_model=List[ExpenseTypeOption])
|
||||
async def get_expense_types():
|
||||
"""Get predefined expense types for dropdown."""
|
||||
return await NomenclatureService.get_expense_types()
|
||||
@@ -0,0 +1,39 @@
|
||||
# Pydantic schemas
|
||||
from .receipt import (
|
||||
ReceiptCreate,
|
||||
ReceiptUpdate,
|
||||
ReceiptResponse,
|
||||
ReceiptListResponse,
|
||||
ReceiptFilter,
|
||||
AttachmentResponse,
|
||||
AccountingEntryCreate,
|
||||
AccountingEntryUpdate,
|
||||
AccountingEntryResponse,
|
||||
WorkflowAction,
|
||||
RejectRequest,
|
||||
)
|
||||
from .bulk import (
|
||||
BulkUploadResponse,
|
||||
BatchJobInfo,
|
||||
BatchStatusResponse,
|
||||
BulkUploadError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ReceiptCreate",
|
||||
"ReceiptUpdate",
|
||||
"ReceiptResponse",
|
||||
"ReceiptListResponse",
|
||||
"ReceiptFilter",
|
||||
"AttachmentResponse",
|
||||
"AccountingEntryCreate",
|
||||
"AccountingEntryUpdate",
|
||||
"AccountingEntryResponse",
|
||||
"WorkflowAction",
|
||||
"RejectRequest",
|
||||
# Bulk upload schemas
|
||||
"BulkUploadResponse",
|
||||
"BatchJobInfo",
|
||||
"BatchStatusResponse",
|
||||
"BulkUploadError",
|
||||
]
|
||||
@@ -0,0 +1,212 @@
|
||||
"""Pydantic schemas for bulk upload endpoints."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BulkUploadResponse(BaseModel):
|
||||
"""Response schema for bulk upload endpoint."""
|
||||
|
||||
batch_id: int = Field(..., description="Unique batch identifier for tracking")
|
||||
job_ids: List[str] = Field(..., description="List of OCR job UUIDs created")
|
||||
total_files: int = Field(..., description="Number of files in the batch")
|
||||
message: str = Field(..., description="Status message")
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"batch_id": 1,
|
||||
"job_ids": [
|
||||
"550e8400-e29b-41d4-a716-446655440001",
|
||||
"550e8400-e29b-41d4-a716-446655440002",
|
||||
],
|
||||
"total_files": 2,
|
||||
"message": "2 files queued for processing"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class BatchJobInfo(BaseModel):
|
||||
"""Information about a single job in a batch."""
|
||||
|
||||
job_id: str = Field(..., description="OCR job UUID")
|
||||
filename: str = Field(..., description="Original filename")
|
||||
status: str = Field(..., description="Job status: pending, processing, completed, failed")
|
||||
receipt_id: Optional[int] = Field(None, description="Created receipt ID (if completed)")
|
||||
error_message: Optional[str] = Field(None, description="Error message (if failed)")
|
||||
|
||||
|
||||
class BatchStatusResponse(BaseModel):
|
||||
"""Response schema for batch status endpoint."""
|
||||
|
||||
batch_id: int = Field(..., description="Batch identifier")
|
||||
status: str = Field(..., description="Overall batch status")
|
||||
total_files: int = Field(..., description="Total number of files in batch")
|
||||
pending_count: int = Field(..., description="Number of pending jobs")
|
||||
processing_count: int = Field(..., description="Number of processing jobs")
|
||||
completed_count: int = Field(..., description="Number of completed jobs")
|
||||
failed_count: int = Field(..., description="Number of failed jobs")
|
||||
jobs: List[BatchJobInfo] = Field(..., description="List of jobs with their status")
|
||||
total_amount: Optional[float] = Field(None, description="Sum of all receipt amounts")
|
||||
created_at: datetime = Field(..., description="Batch creation timestamp")
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"batch_id": 1,
|
||||
"status": "processing",
|
||||
"total_files": 5,
|
||||
"pending_count": 2,
|
||||
"processing_count": 1,
|
||||
"completed_count": 2,
|
||||
"failed_count": 0,
|
||||
"jobs": [
|
||||
{"job_id": "abc-123", "filename": "bon1.pdf", "status": "completed", "receipt_id": 15},
|
||||
{"job_id": "def-456", "filename": "bon2.jpg", "status": "processing", "receipt_id": None},
|
||||
],
|
||||
"total_amount": 150.50,
|
||||
"created_at": "2025-01-09T10:30:00"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class DuplicateFileInfo(BaseModel):
|
||||
"""Information about a duplicate file detected during upload."""
|
||||
|
||||
filename: str = Field(..., description="Name of the duplicate file")
|
||||
error: str = Field(default="duplicate", description="Error type (always 'duplicate')")
|
||||
existing_receipt_id: int = Field(..., description="ID of the existing receipt with same file hash")
|
||||
message: str = Field(..., description="Human-readable error message")
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"filename": "bon_lidl.pdf",
|
||||
"error": "duplicate",
|
||||
"existing_receipt_id": 123,
|
||||
"message": "Fișier duplicat - există deja ca bon #123"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class BulkUploadResponseWithDuplicates(BaseModel):
|
||||
"""Response schema for bulk upload with partial success (some duplicates)."""
|
||||
|
||||
batch_id: Optional[int] = Field(None, description="Batch ID (None if all files were duplicates)")
|
||||
job_ids: List[str] = Field(default_factory=list, description="List of OCR job UUIDs created")
|
||||
total_files: int = Field(..., description="Total number of files submitted")
|
||||
processed_files: int = Field(..., description="Number of files successfully queued")
|
||||
duplicate_files: int = Field(..., description="Number of duplicate files rejected")
|
||||
duplicates: List[DuplicateFileInfo] = Field(default_factory=list, description="List of duplicate file details")
|
||||
message: str = Field(..., description="Status message")
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"batch_id": 1,
|
||||
"job_ids": ["550e8400-e29b-41d4-a716-446655440001"],
|
||||
"total_files": 3,
|
||||
"processed_files": 1,
|
||||
"duplicate_files": 2,
|
||||
"duplicates": [
|
||||
{
|
||||
"filename": "bon_lidl.pdf",
|
||||
"error": "duplicate",
|
||||
"existing_receipt_id": 123,
|
||||
"message": "Fișier duplicat - există deja ca bon #123"
|
||||
}
|
||||
],
|
||||
"message": "1 fișier în procesare, 2 duplicate ignorate"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class BulkUploadError(BaseModel):
|
||||
"""Error response for bulk upload validation failures."""
|
||||
|
||||
detail: str = Field(..., description="Error message")
|
||||
invalid_files: Optional[List[str]] = Field(None, description="List of invalid filenames")
|
||||
|
||||
|
||||
class RetryResponse(BaseModel):
|
||||
"""Response schema for retry endpoints."""
|
||||
|
||||
success: bool = Field(..., description="Whether the retry was successful")
|
||||
receipt_id: int = Field(..., description="Receipt ID that was retried")
|
||||
job_id: Optional[str] = Field(None, description="New OCR job ID created")
|
||||
message: str = Field(..., description="Status message")
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"success": True,
|
||||
"receipt_id": 123,
|
||||
"job_id": "550e8400-e29b-41d4-a716-446655440001",
|
||||
"message": "Bon reîncarcat în procesare"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class BatchRetryResponse(BaseModel):
|
||||
"""Response schema for batch retry endpoint."""
|
||||
|
||||
success: bool = Field(..., description="Whether any retries were successful")
|
||||
batch_id: str = Field(..., description="Batch ID that was retried")
|
||||
retried_count: int = Field(..., description="Number of receipts successfully retried")
|
||||
failed_count: int = Field(..., description="Number of receipts that couldn't be retried")
|
||||
errors: List[str] = Field(default_factory=list, description="List of error messages")
|
||||
message: str = Field(..., description="Status message")
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"success": True,
|
||||
"batch_id": "abc-123",
|
||||
"retried_count": 3,
|
||||
"failed_count": 0,
|
||||
"errors": [],
|
||||
"message": "3 bonuri reîncarcate în procesare"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class CancelJobResponse(BaseModel):
|
||||
"""Response schema for cancel job endpoint."""
|
||||
|
||||
success: bool = Field(..., description="Whether the cancellation was successful")
|
||||
job_id: str = Field(..., description="Job ID that was cancelled")
|
||||
cancelled_at: datetime = Field(..., description="Timestamp when the job was cancelled")
|
||||
message: str = Field(..., description="Status message")
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"success": True,
|
||||
"job_id": "550e8400-e29b-41d4-a716-446655440001",
|
||||
"cancelled_at": "2025-01-11T15:30:00",
|
||||
"message": "Job anulat cu succes"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class CancelBatchResponse(BaseModel):
|
||||
"""Response schema for cancel batch endpoint."""
|
||||
|
||||
success: bool = Field(..., description="Whether any jobs were cancelled")
|
||||
batch_id: int = Field(..., description="Batch ID that was cancelled")
|
||||
cancelled_count: int = Field(..., description="Number of jobs successfully cancelled")
|
||||
skipped_count: int = Field(..., description="Number of jobs skipped (completed/failed)")
|
||||
message: str = Field(..., description="Status message")
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"success": True,
|
||||
"batch_id": 1,
|
||||
"cancelled_count": 3,
|
||||
"skipped_count": 2,
|
||||
"message": "3 job-uri anulate, 2 ignorate (deja procesate)"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,243 @@
|
||||
"""Pydantic schemas for OCR API."""
|
||||
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
from typing import Optional, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TvaEntry(BaseModel):
|
||||
"""Single TVA entry with code, percentage and amount."""
|
||||
code: Optional[str] = Field(default=None, description="TVA code: A, B, C, D")
|
||||
percent: int = Field(description="TVA percentage: 0, 5, 9, 19, 21")
|
||||
amount: Decimal = Field(description="TVA amount for this rate")
|
||||
|
||||
|
||||
class PaymentMethod(BaseModel):
|
||||
"""Payment method entry from OCR."""
|
||||
method: str = Field(description="CARD or NUMERAR")
|
||||
amount: Decimal = Field(description="Amount paid")
|
||||
|
||||
|
||||
class ValidationWarning(BaseModel):
|
||||
"""Validation warning from OCR extraction."""
|
||||
field: str = Field(description="Field name (e.g., 'amount', 'tva_total')")
|
||||
rule: str = Field(description="Rule name (e.g., 'amount_range', 'tva_ratio')")
|
||||
message: str = Field(description="Human-readable warning message")
|
||||
severity: str = Field(description="Severity: 'info', 'warning', 'error'")
|
||||
suggested_value: Optional[str] = Field(default=None, description="Suggested corrected value")
|
||||
|
||||
|
||||
class ExtractionData(BaseModel):
|
||||
"""Extracted receipt data from OCR."""
|
||||
|
||||
receipt_type: str = Field(default='bon_fiscal', description="Receipt type: bon_fiscal or chitanta")
|
||||
receipt_number: Optional[str] = Field(default=None, description="Receipt number")
|
||||
receipt_series: Optional[str] = Field(default=None, description="Receipt series")
|
||||
receipt_date: Optional[date] = Field(default=None, description="Receipt date")
|
||||
amount: Optional[Decimal] = Field(default=None, description="Total amount")
|
||||
partner_name: Optional[str] = Field(default=None, description="Vendor/partner name")
|
||||
cui: Optional[str] = Field(default=None, description="CUI (fiscal identification code)")
|
||||
description: Optional[str] = Field(default=None, description="Optional description")
|
||||
|
||||
# Additional extracted fields - Multiple TVA entries support
|
||||
tva_entries: List[TvaEntry] = Field(default=[], description="List of TVA entries by rate (A, B, C, D)")
|
||||
tva_total: Optional[Decimal] = Field(default=None, description="Total TVA amount")
|
||||
address: Optional[str] = Field(default=None, description="Vendor address")
|
||||
items_count: Optional[int] = Field(default=None, description="Number of items/articles")
|
||||
|
||||
# Payment methods extracted from receipt
|
||||
payment_methods: List[PaymentMethod] = Field(default=[], description="Payment methods from receipt (CARD, NUMERAR)")
|
||||
suggested_payment_mode: Optional[str] = Field(default=None, description="Auto-suggested payment mode based on OCR (casa/banca)")
|
||||
|
||||
# Client data (for B2B receipts - buyer information)
|
||||
client_name: Optional[str] = Field(default=None, description="Client/customer company name")
|
||||
client_cui: Optional[str] = Field(default=None, description="Client CUI/CIF fiscal code")
|
||||
client_address: Optional[str] = Field(default=None, description="Client address")
|
||||
|
||||
confidence_amount: float = Field(default=0.0, ge=0, le=1, description="Amount extraction confidence")
|
||||
confidence_date: float = Field(default=0.0, ge=0, le=1, description="Date extraction confidence")
|
||||
confidence_vendor: float = Field(default=0.0, ge=0, le=1, description="Vendor extraction confidence")
|
||||
confidence_client: float = Field(default=0.0, ge=0, le=1, description="Client extraction confidence")
|
||||
confidence_tva: float = Field(default=0.0, ge=0, le=1, description="TVA extraction confidence")
|
||||
confidence_payment: float = Field(default=0.0, ge=0, le=1, description="Payment extraction confidence")
|
||||
overall_confidence: float = Field(default=0.0, ge=0, le=1, description="Overall confidence score")
|
||||
raw_text: str = Field(default="", description="Raw OCR text (primary)")
|
||||
raw_texts: List[str] = Field(default=[], description="Raw OCR texts from all engine passes (for analysis)")
|
||||
ocr_engine: str = Field(default="", description="OCR engine used: paddleocr or tesseract")
|
||||
processing_time_ms: int = Field(default=0, ge=0, description="Processing time in milliseconds")
|
||||
|
||||
# Validation results (added by bon-ocr-validation feature)
|
||||
# needs_manual_review: None = not validated yet (old receipts), False = no review needed, True = needs review
|
||||
needs_manual_review: Optional[bool] = Field(default=None, description="Flag for supervisor review (None=not validated, False=ok, True=needs review)")
|
||||
validation_warnings: List[str] = Field(default=[], description="Validation warnings")
|
||||
validation_errors: List[str] = Field(default=[], description="Validation errors")
|
||||
inter_ocr_ratios: dict[str, float] = Field(default={}, description="Inter-OCR consistency ratios")
|
||||
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"receipt_type": "bon_fiscal",
|
||||
"receipt_number": "1360760",
|
||||
"receipt_series": "0146",
|
||||
"receipt_date": "2025-10-11",
|
||||
"amount": 186.16,
|
||||
"partner_name": "FIVE-HOLDING S.A.",
|
||||
"cui": "10562600",
|
||||
"description": None,
|
||||
"tva_entries": [
|
||||
{"code": "A", "percent": 19, "amount": 25.00},
|
||||
{"code": "B", "percent": 9, "amount": 7.31}
|
||||
],
|
||||
"tva_total": 32.31,
|
||||
"address": "JUD. CONSTANTA, MUN. CONSTANTA, STR. ION ROATA NR. 3",
|
||||
"items_count": 17,
|
||||
"confidence_amount": 0.98,
|
||||
"confidence_date": 0.98,
|
||||
"confidence_vendor": 0.95,
|
||||
"overall_confidence": 0.97,
|
||||
"raw_text": "FIVE-HOLDING S.A.\nCIF: RO10562600\n..."
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class OCRResponse(BaseModel):
|
||||
"""OCR API response."""
|
||||
|
||||
success: bool = Field(description="Whether OCR processing was successful")
|
||||
message: str = Field(description="Status message")
|
||||
data: Optional[ExtractionData] = Field(default=None, description="Extracted data")
|
||||
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"success": True,
|
||||
"message": "OCR processing successful. Found: amount, date, vendor",
|
||||
"data": {
|
||||
"receipt_type": "bon_fiscal",
|
||||
"receipt_number": "12345",
|
||||
"receipt_date": "2024-01-15",
|
||||
"amount": 125.50,
|
||||
"partner_name": "MEGA IMAGE SRL",
|
||||
"cui": "12345678",
|
||||
"confidence_amount": 0.95,
|
||||
"confidence_date": 0.90,
|
||||
"confidence_vendor": 0.75,
|
||||
"overall_confidence": 0.87,
|
||||
"raw_text": "BON FISCAL\nMEGA IMAGE SRL\n..."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class OCRStatusResponse(BaseModel):
|
||||
"""OCR service status response."""
|
||||
|
||||
available: bool = Field(description="Whether OCR service is available")
|
||||
engines: list[str] = Field(description="Available OCR engines")
|
||||
message: str = Field(description="Status message")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Job Queue Schemas (for async OCR processing)
|
||||
# ============================================================================
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class OCREngineChoice(str, Enum):
|
||||
"""OCR engine selection options."""
|
||||
tesseract = "tesseract"
|
||||
doctr = "doctr" # 3.3x faster than PaddleOCR with same accuracy (90/100)
|
||||
doctr_plus = "doctr_plus" # docTR with 2-tier sequential processing + early exit (optimized, recommended)
|
||||
paddleocr = "paddleocr"
|
||||
|
||||
|
||||
class OCRJobStatus(str, Enum):
|
||||
"""OCR job status."""
|
||||
pending = "pending"
|
||||
processing = "processing"
|
||||
completed = "completed"
|
||||
failed = "failed"
|
||||
|
||||
|
||||
class OCRJobSubmitResponse(BaseModel):
|
||||
"""Response when submitting an OCR job."""
|
||||
|
||||
job_id: str = Field(description="Unique job identifier (UUID)")
|
||||
status: OCRJobStatus = Field(description="Initial job status (pending)")
|
||||
queue_position: int = Field(description="Position in queue (1 = next to process)")
|
||||
estimated_wait_seconds: int = Field(description="Estimated wait time in seconds")
|
||||
created_at: datetime = Field(description="Job creation timestamp")
|
||||
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"job_id": "abc123-def456-ghi789",
|
||||
"status": "pending",
|
||||
"queue_position": 3,
|
||||
"estimated_wait_seconds": 21,
|
||||
"created_at": "2024-01-15T12:00:00"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class OCRJobResponse(BaseModel):
|
||||
"""Full OCR job status response."""
|
||||
|
||||
job_id: str = Field(description="Unique job identifier")
|
||||
status: OCRJobStatus = Field(description="Current job status")
|
||||
queue_position: Optional[int] = Field(default=None, description="Queue position (None if processing/completed)")
|
||||
estimated_wait_seconds: Optional[int] = Field(default=None, description="Estimated wait time")
|
||||
created_at: datetime = Field(description="Job creation timestamp")
|
||||
started_at: Optional[datetime] = Field(default=None, description="Processing start timestamp")
|
||||
completed_at: Optional[datetime] = Field(default=None, description="Completion timestamp")
|
||||
# Detailed timing breakdown
|
||||
queue_wait_ms: Optional[int] = Field(default=None, description="Time waiting in queue (started_at - created_at)")
|
||||
ocr_time_ms: Optional[int] = Field(default=None, description="Actual OCR engine processing time")
|
||||
processing_time_ms: Optional[int] = Field(default=None, description="Total job processing time (completed_at - started_at)")
|
||||
result: Optional[ExtractionData] = Field(default=None, description="Extraction result (only if completed)")
|
||||
error: Optional[str] = Field(default=None, description="Error message (only if failed)")
|
||||
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"job_id": "abc123-def456-ghi789",
|
||||
"status": "completed",
|
||||
"queue_position": None,
|
||||
"estimated_wait_seconds": 0,
|
||||
"created_at": "2024-01-15T12:00:00",
|
||||
"started_at": "2024-01-15T12:00:21",
|
||||
"completed_at": "2024-01-15T12:00:28",
|
||||
"processing_time_ms": 6543,
|
||||
"result": {
|
||||
"receipt_number": "123",
|
||||
"amount": 85.99,
|
||||
"ocr_engine": "paddleocr-light"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class OCRQueueStatusResponse(BaseModel):
|
||||
"""Queue statistics response."""
|
||||
|
||||
pending_jobs: int = Field(description="Number of jobs waiting in queue")
|
||||
processing_jobs: int = Field(description="Number of jobs currently processing")
|
||||
average_time_seconds: float = Field(description="Average processing time in seconds")
|
||||
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"pending_jobs": 5,
|
||||
"processing_jobs": 1,
|
||||
"average_time_seconds": 7.2
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,311 @@
|
||||
"""Pydantic schemas for receipts API."""
|
||||
|
||||
import json
|
||||
from datetime import datetime, date
|
||||
from decimal import Decimal
|
||||
from typing import Optional, List, Any, Union
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_validator
|
||||
|
||||
from backend.modules.data_entry.db.models.receipt import ReceiptType, ReceiptDirection, ReceiptStatus, ProcessingStatus
|
||||
from backend.modules.data_entry.db.models.accounting_entry import EntryType
|
||||
|
||||
|
||||
# ============ Accounting Entry Schemas ============
|
||||
|
||||
class AccountingEntryBase(BaseModel):
|
||||
"""Base schema for accounting entry."""
|
||||
entry_type: EntryType
|
||||
account_code: str = Field(max_length=20)
|
||||
account_name: Optional[str] = Field(default=None, max_length=200)
|
||||
amount: Decimal
|
||||
partner_id: Optional[int] = None
|
||||
cost_center_id: Optional[int] = None
|
||||
|
||||
|
||||
class AccountingEntryCreate(AccountingEntryBase):
|
||||
"""Schema for creating an accounting entry."""
|
||||
pass
|
||||
|
||||
|
||||
class AccountingEntryUpdate(BaseModel):
|
||||
"""Schema for updating an accounting entry."""
|
||||
entry_type: Optional[EntryType] = None
|
||||
account_code: Optional[str] = Field(default=None, max_length=20)
|
||||
account_name: Optional[str] = Field(default=None, max_length=200)
|
||||
amount: Optional[Decimal] = None
|
||||
partner_id: Optional[int] = None
|
||||
cost_center_id: Optional[int] = None
|
||||
|
||||
|
||||
class AccountingEntryResponse(AccountingEntryBase):
|
||||
"""Schema for accounting entry response."""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
receipt_id: int
|
||||
is_auto_generated: bool
|
||||
modified_by: Optional[str] = None
|
||||
modified_at: Optional[datetime] = None
|
||||
sort_order: int
|
||||
|
||||
|
||||
# ============ Attachment Schemas ============
|
||||
|
||||
class AttachmentResponse(BaseModel):
|
||||
"""Schema for attachment response."""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
receipt_id: int
|
||||
filename: str
|
||||
stored_filename: str
|
||||
file_path: str
|
||||
file_size: int
|
||||
mime_type: str
|
||||
uploaded_at: datetime
|
||||
|
||||
|
||||
# ============ TVA Schema ============
|
||||
|
||||
class TvaEntrySchema(BaseModel):
|
||||
"""Single TVA entry with code, percentage and amount."""
|
||||
code: Optional[str] = Field(default=None, description="TVA code: A, B, C, D")
|
||||
percent: int = Field(description="TVA percentage: 0, 5, 9, 19, 21")
|
||||
amount: Decimal = Field(description="TVA amount for this rate")
|
||||
|
||||
|
||||
class PaymentMethodSchema(BaseModel):
|
||||
"""Payment method entry (CARD/NUMERAR)."""
|
||||
method: str = Field(description="Payment method: CARD or NUMERAR")
|
||||
amount: Decimal = Field(description="Amount paid with this method")
|
||||
|
||||
|
||||
# ============ Receipt Schemas ============
|
||||
|
||||
class ReceiptBase(BaseModel):
|
||||
"""Base schema for receipt."""
|
||||
receipt_type: ReceiptType = ReceiptType.BON_FISCAL
|
||||
direction: ReceiptDirection = ReceiptDirection.CHELTUIALA
|
||||
receipt_number: Optional[str] = Field(default=None, max_length=50)
|
||||
receipt_series: Optional[str] = Field(default=None, max_length=20)
|
||||
receipt_date: date
|
||||
amount: Decimal = Field(gt=0)
|
||||
description: Optional[str] = Field(default=None, max_length=500)
|
||||
# TVA info (multiple entries support)
|
||||
tva_breakdown: Optional[List[TvaEntrySchema]] = Field(default=None, description="List of TVA entries")
|
||||
tva_total: Optional[Decimal] = Field(default=None, description="Total TVA amount")
|
||||
items_count: Optional[int] = Field(default=None, description="Number of items")
|
||||
vendor_address: Optional[str] = Field(default=None, max_length=500, description="Vendor address")
|
||||
# Other fields
|
||||
expense_type_code: Optional[str] = Field(default=None, max_length=20)
|
||||
company_id: int
|
||||
# partner_id removed - supplier data is text-only (partner_name, cui)
|
||||
partner_name: Optional[str] = Field(default=None, max_length=200)
|
||||
cui: Optional[str] = Field(default=None, max_length=20, description="Fiscal code (CUI) from OCR")
|
||||
ocr_raw_text: Optional[str] = Field(default=None, description="Raw OCR text for debugging")
|
||||
payment_methods: Optional[List[PaymentMethodSchema]] = Field(default=None, description="Payment methods from OCR")
|
||||
cash_register_id: Optional[int] = None
|
||||
cash_register_name: Optional[str] = Field(default=None, max_length=100)
|
||||
cash_register_account: Optional[str] = Field(default=None, max_length=20)
|
||||
payment_mode: Optional[str] = Field(default=None, description="Payment mode: casa/banca/avans_decontare")
|
||||
|
||||
|
||||
class ReceiptCreate(ReceiptBase):
|
||||
"""Schema for creating a receipt."""
|
||||
pass
|
||||
|
||||
|
||||
class ReceiptUpdate(BaseModel):
|
||||
"""Schema for updating a receipt (DRAFT only)."""
|
||||
receipt_type: Optional[ReceiptType] = None
|
||||
direction: Optional[ReceiptDirection] = None
|
||||
receipt_number: Optional[str] = Field(default=None, max_length=50)
|
||||
receipt_series: Optional[str] = Field(default=None, max_length=20)
|
||||
receipt_date: Optional[date] = None
|
||||
amount: Optional[Decimal] = Field(default=None, gt=0)
|
||||
description: Optional[str] = Field(default=None, max_length=500)
|
||||
# TVA info (multiple entries support)
|
||||
tva_breakdown: Optional[List[TvaEntrySchema]] = Field(default=None, description="List of TVA entries")
|
||||
tva_total: Optional[Decimal] = Field(default=None, description="Total TVA amount")
|
||||
items_count: Optional[int] = Field(default=None, description="Number of items")
|
||||
vendor_address: Optional[str] = Field(default=None, max_length=500, description="Vendor address")
|
||||
# Other fields
|
||||
expense_type_code: Optional[str] = Field(default=None, max_length=20)
|
||||
# partner_id removed - supplier data is text-only (partner_name, cui)
|
||||
partner_name: Optional[str] = Field(default=None, max_length=200)
|
||||
cui: Optional[str] = Field(default=None, max_length=20, description="Fiscal code (CUI) from OCR")
|
||||
ocr_raw_text: Optional[str] = Field(default=None, description="Raw OCR text for debugging")
|
||||
payment_methods: Optional[List[PaymentMethodSchema]] = Field(default=None, description="Payment methods from OCR")
|
||||
cash_register_id: Optional[int] = None
|
||||
cash_register_name: Optional[str] = Field(default=None, max_length=100)
|
||||
cash_register_account: Optional[str] = Field(default=None, max_length=20)
|
||||
payment_mode: Optional[str] = Field(default=None, description="Payment mode: casa/banca/avans_decontare")
|
||||
|
||||
|
||||
class ReceiptResponse(ReceiptBase):
|
||||
"""Schema for receipt response with all fields."""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
# Override amount to allow zero values in response (validation is on input, not output)
|
||||
amount: Decimal
|
||||
status: ReceiptStatus
|
||||
created_by: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
submitted_at: Optional[datetime] = None
|
||||
reviewed_by: Optional[str] = None
|
||||
reviewed_at: Optional[datetime] = None
|
||||
rejection_reason: Optional[str] = None
|
||||
oracle_synced_at: Optional[datetime] = None
|
||||
oracle_act_id: Optional[int] = None
|
||||
oracle_error: Optional[str] = None
|
||||
|
||||
# Bulk upload batch tracking (US-012)
|
||||
batch_id: Optional[str] = None
|
||||
processing_status: Optional[str] = None
|
||||
processing_error: Optional[str] = None
|
||||
file_hash: Optional[str] = None
|
||||
processing_started_at: Optional[datetime] = None
|
||||
processing_completed_at: Optional[datetime] = None
|
||||
|
||||
# Relationships (optional, loaded when needed)
|
||||
attachments: List[AttachmentResponse] = []
|
||||
entries: List[AccountingEntryResponse] = []
|
||||
|
||||
@field_validator('tva_breakdown', mode='before')
|
||||
@classmethod
|
||||
def parse_tva_breakdown(cls, v: Any) -> Optional[List[dict]]:
|
||||
"""Deserialize tva_breakdown from JSON string if needed."""
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return json.loads(v)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None
|
||||
if isinstance(v, list):
|
||||
return v
|
||||
return None
|
||||
|
||||
@field_validator('payment_methods', mode='before')
|
||||
@classmethod
|
||||
def parse_payment_methods(cls, v: Any) -> Optional[List[dict]]:
|
||||
"""Deserialize payment_methods from JSON string if needed."""
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return json.loads(v)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None
|
||||
if isinstance(v, list):
|
||||
return v
|
||||
return None
|
||||
|
||||
|
||||
class ProcessingStats(BaseModel):
|
||||
"""Statistics for bulk upload processing status (US-012)."""
|
||||
pending_count: int = 0
|
||||
processing_count: int = 0
|
||||
completed_count: int = 0
|
||||
failed_count: int = 0
|
||||
|
||||
|
||||
class ReceiptListResponse(BaseModel):
|
||||
"""Schema for paginated receipt list response."""
|
||||
items: List[ReceiptResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
pages: int
|
||||
# Processing stats for bulk upload filtering (US-012)
|
||||
processing_stats: Optional[ProcessingStats] = None
|
||||
|
||||
|
||||
class ReceiptFilter(BaseModel):
|
||||
"""Schema for filtering receipts."""
|
||||
status: Optional[ReceiptStatus] = None
|
||||
direction: Optional[ReceiptDirection] = None
|
||||
company_id: Optional[int] = None
|
||||
created_by: Optional[str] = None
|
||||
date_from: Optional[date] = None
|
||||
date_to: Optional[date] = None
|
||||
search: Optional[str] = None # Search in description, partner_name
|
||||
# Bulk upload filters (US-012)
|
||||
processing_status: Optional[str] = None # ProcessingStatus enum value
|
||||
batch_id: Optional[str] = None # Filter by batch_id
|
||||
sort_by: Optional[str] = None # Sort field (e.g., "processing_started_at")
|
||||
# Pagination
|
||||
page: int = Field(default=1, ge=1)
|
||||
page_size: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
# ============ Workflow Schemas ============
|
||||
|
||||
class WorkflowAction(BaseModel):
|
||||
"""Schema for workflow action response."""
|
||||
success: bool
|
||||
message: str
|
||||
receipt: Optional[ReceiptResponse] = None
|
||||
|
||||
|
||||
class RejectRequest(BaseModel):
|
||||
"""Schema for rejection request."""
|
||||
reason: str = Field(min_length=5, max_length=500)
|
||||
|
||||
|
||||
class EntriesUpdateRequest(BaseModel):
|
||||
"""Schema for bulk updating accounting entries."""
|
||||
entries: List[AccountingEntryCreate]
|
||||
|
||||
|
||||
# ============ Nomenclature Schemas ============
|
||||
|
||||
class PartnerOption(BaseModel):
|
||||
"""Schema for partner dropdown option (used for autocomplete assistance)."""
|
||||
name: str
|
||||
fiscal_code: Optional[str] = None
|
||||
address: Optional[str] = None
|
||||
source: str = "oracle" # 'oracle' (synced) or 'local'
|
||||
|
||||
|
||||
class AccountOption(BaseModel):
|
||||
"""Schema for account dropdown option."""
|
||||
code: str
|
||||
name: str
|
||||
|
||||
|
||||
class CashRegisterOption(BaseModel):
|
||||
"""Schema for cash register dropdown option."""
|
||||
id: int
|
||||
name: str
|
||||
account_code: str # 5311, 5121, etc.
|
||||
|
||||
|
||||
class ExpenseTypeOption(BaseModel):
|
||||
"""Schema for expense type dropdown option."""
|
||||
code: str
|
||||
name: str
|
||||
account_code: str
|
||||
has_vat: bool
|
||||
vat_percent: Decimal = Decimal("19")
|
||||
|
||||
|
||||
# ============ Bulk Delete Schemas (US-024) ============
|
||||
|
||||
class BulkDeleteRequest(BaseModel):
|
||||
"""Request schema for bulk delete endpoint."""
|
||||
ids: List[int] = Field(..., min_length=1, description="List of receipt IDs to delete")
|
||||
|
||||
|
||||
class BulkDeleteFailure(BaseModel):
|
||||
"""Schema for a single failed deletion."""
|
||||
id: int
|
||||
error: str
|
||||
|
||||
|
||||
class BulkDeleteResponse(BaseModel):
|
||||
"""Response schema for bulk delete with partial success support."""
|
||||
deleted: List[int] = Field(default_factory=list, description="IDs of successfully deleted receipts")
|
||||
failed: List[BulkDeleteFailure] = Field(default_factory=list, description="IDs that failed with error messages")
|
||||
@@ -0,0 +1,16 @@
|
||||
# Business logic services
|
||||
from .receipt_service import ReceiptService
|
||||
from .nomenclature_service import NomenclatureService
|
||||
from .expense_types import EXPENSE_TYPES, ExpenseType
|
||||
from .receipt_auto_create import ReceiptAutoCreateService, ReceiptCreateResult
|
||||
from . import sse_service
|
||||
|
||||
__all__ = [
|
||||
"ReceiptService",
|
||||
"NomenclatureService",
|
||||
"EXPENSE_TYPES",
|
||||
"ExpenseType",
|
||||
"ReceiptAutoCreateService",
|
||||
"ReceiptCreateResult",
|
||||
"sse_service",
|
||||
]
|
||||
@@ -0,0 +1,215 @@
|
||||
"""
|
||||
Cleanup service for auto-deleting expired failed receipts.
|
||||
|
||||
US-008: Backend - Auto-Cleanup Erori După 7 Zile
|
||||
- Finds receipts with processing_status='failed' and processing_completed_at < now() - 7 days
|
||||
- Deletes the receipts and their attached files from storage
|
||||
- Runs at startup and then daily as a background task
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, and_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from backend.modules.data_entry.db.models.receipt import Receipt, ReceiptAttachment
|
||||
from backend.modules.data_entry.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cleanup configuration
|
||||
CLEANUP_RETENTION_DAYS = 7
|
||||
CLEANUP_INTERVAL_HOURS = 24
|
||||
|
||||
# In-memory storage for last cleanup stats (optional - for login notification)
|
||||
_last_cleanup_stats: dict = {
|
||||
"count": 0,
|
||||
"timestamp": None
|
||||
}
|
||||
|
||||
|
||||
def get_last_cleanup_stats() -> dict:
|
||||
"""Get stats from the last cleanup run for notification purposes."""
|
||||
return _last_cleanup_stats.copy()
|
||||
|
||||
|
||||
async def cleanup_expired_failed_receipts(session: AsyncSession) -> int:
|
||||
"""
|
||||
Find and delete receipts with processing_status='failed' older than 7 days.
|
||||
|
||||
This function:
|
||||
1. Queries for failed receipts where processing_completed_at < now() - 7 days
|
||||
2. Deletes attachment files from disk
|
||||
3. Deletes the receipt records (cascade deletes attachment records)
|
||||
|
||||
Args:
|
||||
session: AsyncSession for database operations
|
||||
|
||||
Returns:
|
||||
Number of receipts deleted
|
||||
"""
|
||||
global _last_cleanup_stats
|
||||
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=CLEANUP_RETENTION_DAYS)
|
||||
|
||||
# Find expired failed receipts with their attachments
|
||||
query = select(Receipt).options(
|
||||
selectinload(Receipt.attachments)
|
||||
).where(
|
||||
and_(
|
||||
Receipt.processing_status == "failed",
|
||||
Receipt.processing_completed_at.isnot(None),
|
||||
Receipt.processing_completed_at < cutoff_date
|
||||
)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
expired_receipts = result.scalars().all()
|
||||
|
||||
if not expired_receipts:
|
||||
logger.debug("[Cleanup] No expired failed receipts found")
|
||||
return 0
|
||||
|
||||
deleted_count = 0
|
||||
deleted_files = 0
|
||||
|
||||
upload_base_path = settings.upload_path_resolved
|
||||
|
||||
for receipt in expired_receipts:
|
||||
try:
|
||||
# Delete attachment files from disk
|
||||
for attachment in receipt.attachments:
|
||||
file_path = upload_base_path / attachment.file_path
|
||||
if file_path.exists():
|
||||
try:
|
||||
file_path.unlink()
|
||||
deleted_files += 1
|
||||
logger.debug(f"[Cleanup] Deleted file: {file_path}")
|
||||
except OSError as e:
|
||||
logger.warning(f"[Cleanup] Failed to delete file {file_path}: {e}")
|
||||
|
||||
# Also try to clean up empty parent directories
|
||||
parent_dir = file_path.parent
|
||||
if parent_dir.exists() and parent_dir != upload_base_path:
|
||||
try:
|
||||
# Only remove if directory is empty
|
||||
if not any(parent_dir.iterdir()):
|
||||
parent_dir.rmdir()
|
||||
logger.debug(f"[Cleanup] Removed empty directory: {parent_dir}")
|
||||
except OSError:
|
||||
pass # Directory not empty or permission issue, skip
|
||||
|
||||
# Delete receipt (cascade deletes attachment records in DB)
|
||||
await session.delete(receipt)
|
||||
deleted_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Cleanup] Error deleting receipt {receipt.id}: {e}")
|
||||
continue
|
||||
|
||||
# Commit all deletions
|
||||
if deleted_count > 0:
|
||||
await session.commit()
|
||||
|
||||
# Update stats for notification
|
||||
_last_cleanup_stats = {
|
||||
"count": deleted_count,
|
||||
"files_deleted": deleted_files,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
logger.info(f"[Cleanup] Cleaned up {deleted_count} expired failed receipts ({deleted_files} files)")
|
||||
|
||||
return deleted_count
|
||||
|
||||
|
||||
async def run_cleanup_task(get_session_func) -> None:
|
||||
"""
|
||||
Background task that runs cleanup at startup and then every 24 hours.
|
||||
|
||||
Args:
|
||||
get_session_func: Async generator function that yields database sessions
|
||||
"""
|
||||
logger.info("[Cleanup] Starting cleanup background task")
|
||||
|
||||
# Run immediately at startup
|
||||
try:
|
||||
async for session in get_session_func():
|
||||
count = await cleanup_expired_failed_receipts(session)
|
||||
if count > 0:
|
||||
logger.info(f"[Cleanup] Initial cleanup: {count} receipts removed")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"[Cleanup] Initial cleanup failed: {e}")
|
||||
|
||||
# Then run every 24 hours
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(CLEANUP_INTERVAL_HOURS * 3600)
|
||||
|
||||
async for session in get_session_func():
|
||||
count = await cleanup_expired_failed_receipts(session)
|
||||
if count > 0:
|
||||
logger.info(f"[Cleanup] Daily cleanup: {count} receipts removed")
|
||||
break
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("[Cleanup] Cleanup task cancelled")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Cleanup] Daily cleanup failed: {e}")
|
||||
# Continue running even if one cleanup fails
|
||||
|
||||
|
||||
# Global reference to cleanup task for graceful shutdown
|
||||
_cleanup_task: Optional[asyncio.Task] = None
|
||||
|
||||
|
||||
async def start_cleanup_task(get_session_func) -> bool:
|
||||
"""
|
||||
Start the cleanup background task.
|
||||
|
||||
Args:
|
||||
get_session_func: Async generator function that yields database sessions
|
||||
|
||||
Returns:
|
||||
True if task started successfully, False otherwise
|
||||
"""
|
||||
global _cleanup_task
|
||||
|
||||
if _cleanup_task is not None and not _cleanup_task.done():
|
||||
logger.warning("[Cleanup] Cleanup task already running")
|
||||
return False
|
||||
|
||||
try:
|
||||
_cleanup_task = asyncio.create_task(run_cleanup_task(get_session_func))
|
||||
logger.info("[Cleanup] ✅ Cleanup background task started")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[Cleanup] Failed to start cleanup task: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def stop_cleanup_task() -> None:
|
||||
"""Stop the cleanup background task gracefully."""
|
||||
global _cleanup_task
|
||||
|
||||
if _cleanup_task is not None and not _cleanup_task.done():
|
||||
_cleanup_task.cancel()
|
||||
try:
|
||||
await _cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("[Cleanup] Cleanup task stopped")
|
||||
|
||||
_cleanup_task = None
|
||||
|
||||
|
||||
def is_cleanup_task_running() -> bool:
|
||||
"""Check if the cleanup task is currently running."""
|
||||
return _cleanup_task is not None and not _cleanup_task.done()
|
||||
@@ -0,0 +1,101 @@
|
||||
"""Predefined expense types for automatic accounting entry generation."""
|
||||
|
||||
from decimal import Decimal
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExpenseType:
|
||||
"""Expense type definition with accounting configuration."""
|
||||
code: str
|
||||
name: str
|
||||
account_code: str
|
||||
account_name: str
|
||||
has_vat: bool
|
||||
vat_percent: Decimal = Decimal("19")
|
||||
vat_account: str = "4426"
|
||||
|
||||
|
||||
# Predefined expense types
|
||||
EXPENSE_TYPES: Dict[str, ExpenseType] = {
|
||||
"FUEL": ExpenseType(
|
||||
code="FUEL",
|
||||
name="Combustibil",
|
||||
account_code="6022",
|
||||
account_name="Cheltuieli cu combustibilii",
|
||||
has_vat=True,
|
||||
),
|
||||
"MATERIALS": ExpenseType(
|
||||
code="MATERIALS",
|
||||
name="Materiale consumabile",
|
||||
account_code="6028",
|
||||
account_name="Alte cheltuieli cu materiale consumabile",
|
||||
has_vat=True,
|
||||
),
|
||||
"OFFICE": ExpenseType(
|
||||
code="OFFICE",
|
||||
name="Rechizite birou",
|
||||
account_code="6024",
|
||||
account_name="Cheltuieli privind materialele pentru ambalat",
|
||||
has_vat=True,
|
||||
),
|
||||
"PHONE": ExpenseType(
|
||||
code="PHONE",
|
||||
name="Telefonie / Internet",
|
||||
account_code="626",
|
||||
account_name="Cheltuieli postale si taxe de telecomunicatii",
|
||||
has_vat=True,
|
||||
),
|
||||
"PARKING": ExpenseType(
|
||||
code="PARKING",
|
||||
name="Parcare",
|
||||
account_code="6022",
|
||||
account_name="Cheltuieli cu combustibilii",
|
||||
has_vat=True,
|
||||
),
|
||||
"FOOD": ExpenseType(
|
||||
code="FOOD",
|
||||
name="Alimentatie",
|
||||
account_code="6028",
|
||||
account_name="Alte cheltuieli cu materiale consumabile",
|
||||
has_vat=False, # No deductible VAT for food
|
||||
),
|
||||
"TRANSPORT": ExpenseType(
|
||||
code="TRANSPORT",
|
||||
name="Transport",
|
||||
account_code="624",
|
||||
account_name="Cheltuieli cu transportul de bunuri si personal",
|
||||
has_vat=True,
|
||||
),
|
||||
"OTHER": ExpenseType(
|
||||
code="OTHER",
|
||||
name="Altele",
|
||||
account_code="628",
|
||||
account_name="Alte cheltuieli cu serviciile executate de terti",
|
||||
has_vat=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def get_expense_type(code: str) -> Optional[ExpenseType]:
|
||||
"""Get expense type by code."""
|
||||
return EXPENSE_TYPES.get(code)
|
||||
|
||||
|
||||
def get_all_expense_types() -> Dict[str, ExpenseType]:
|
||||
"""Get all expense types."""
|
||||
return EXPENSE_TYPES.copy()
|
||||
|
||||
|
||||
# Default cash register accounts
|
||||
CASH_REGISTER_ACCOUNTS = {
|
||||
"CASA": {
|
||||
"code": "5311",
|
||||
"name": "Casa in lei",
|
||||
},
|
||||
"BANCA": {
|
||||
"code": "5121",
|
||||
"name": "Conturi la banci in lei",
|
||||
},
|
||||
}
|
||||
@@ -0,0 +1,366 @@
|
||||
"""Image preprocessing for optimal OCR results."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
try:
|
||||
import pdf2image
|
||||
PDF_AVAILABLE = True
|
||||
except ImportError:
|
||||
PDF_AVAILABLE = False
|
||||
|
||||
|
||||
class ImagePreprocessor:
|
||||
"""Preprocess receipt images for OCR."""
|
||||
|
||||
def _add_safety_padding(self, image: np.ndarray, padding: int = 50) -> np.ndarray:
|
||||
"""Add white padding around image to protect edge content during rotation.
|
||||
|
||||
This prevents left/right margin truncation in OCR by ensuring text near
|
||||
edges isn't lost during deskew rotation.
|
||||
"""
|
||||
if len(image.shape) == 2:
|
||||
# Grayscale
|
||||
return cv2.copyMakeBorder(
|
||||
image, padding, padding, padding, padding,
|
||||
cv2.BORDER_CONSTANT, value=255
|
||||
)
|
||||
else:
|
||||
# Color (BGR)
|
||||
return cv2.copyMakeBorder(
|
||||
image, padding, padding, padding, padding,
|
||||
cv2.BORDER_CONSTANT, value=(255, 255, 255)
|
||||
)
|
||||
|
||||
def load_image(self, path: Path) -> np.ndarray:
|
||||
"""Load image from file."""
|
||||
image = cv2.imread(str(path))
|
||||
if image is None:
|
||||
raise ValueError(f"Could not load image: {path}")
|
||||
return image
|
||||
|
||||
def pdf_to_images(self, path: Path, dpi: int = 300) -> List[np.ndarray]:
|
||||
"""
|
||||
Convert PDF to images.
|
||||
|
||||
Args:
|
||||
path: Path to PDF file
|
||||
dpi: Resolution (300 = fast & good quality, 400 = better but slower)
|
||||
"""
|
||||
if not PDF_AVAILABLE:
|
||||
raise RuntimeError("pdf2image not available. Install with: pip install pdf2image")
|
||||
images = pdf2image.convert_from_path(str(path), dpi=dpi)
|
||||
return [np.array(img) for img in images]
|
||||
|
||||
def preprocess(self, image: np.ndarray, high_quality: bool = True) -> np.ndarray:
|
||||
"""
|
||||
Apply LIGHT preprocessing - better for clear PDFs.
|
||||
Heavy binarization can destroy text on clear images.
|
||||
"""
|
||||
return self.preprocess_light(image)
|
||||
|
||||
def preprocess_light(self, image: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Light preprocessing for CLEAR images (PDFs, good scans).
|
||||
Preserves original quality, only enhances contrast.
|
||||
"""
|
||||
# 0. Add safety padding to protect edge content during deskew rotation
|
||||
image = self._add_safety_padding(image)
|
||||
|
||||
# 1. Grayscale
|
||||
if len(image.shape) == 3:
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
else:
|
||||
gray = image.copy()
|
||||
|
||||
# 2a. Scale DOWN if any side exceeds 4000px (PaddleOCR limit)
|
||||
height, width = gray.shape
|
||||
max_side = max(height, width)
|
||||
if max_side > 4000:
|
||||
scale = 4000 / max_side
|
||||
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
|
||||
height, width = gray.shape
|
||||
|
||||
# 2b. Scale UP if too small
|
||||
if width < 1500:
|
||||
scale = 1500 / width
|
||||
# Ensure we don't exceed 4000px after upscaling
|
||||
new_width = int(width * scale)
|
||||
new_height = int(height * scale)
|
||||
if max(new_width, new_height) > 4000:
|
||||
scale = 4000 / max(new_width, new_height)
|
||||
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
|
||||
|
||||
# 3. Deskew
|
||||
gray = self._deskew(gray)
|
||||
|
||||
# 4. Light contrast enhancement only
|
||||
clahe = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(8, 8))
|
||||
enhanced = clahe.apply(gray)
|
||||
|
||||
# NO binarization, NO morphological ops - preserve original quality
|
||||
return enhanced
|
||||
|
||||
def preprocess_medium(self, image: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Medium preprocessing for MIXED-QUALITY images.
|
||||
Balance between Light (too gentle) and Heavy (too aggressive).
|
||||
|
||||
Use cases:
|
||||
- Moderately faded receipts
|
||||
- Photos with uneven lighting
|
||||
- Scans with slight blur
|
||||
|
||||
Preprocessing steps:
|
||||
- Moderate contrast enhancement (CLAHE clipLimit=2.0)
|
||||
- Light denoising (fastNlMeansDenoising h=6)
|
||||
- Gentle sharpening
|
||||
- NO binarization (preserves text boundaries)
|
||||
- NO morphological operations (avoids digit concatenation)
|
||||
|
||||
This method was created to replace preprocess_heavy() which caused
|
||||
digit concatenation errors on high-quality PDFs (85.99 → 859,762.16).
|
||||
"""
|
||||
# 0. Add safety padding to protect edge content during deskew rotation
|
||||
image = self._add_safety_padding(image)
|
||||
|
||||
# 1. Grayscale
|
||||
if len(image.shape) == 3:
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
else:
|
||||
gray = image.copy()
|
||||
|
||||
# 2a. Scale DOWN if any side exceeds 4000px (PaddleOCR limit)
|
||||
height, width = gray.shape
|
||||
max_side = max(height, width)
|
||||
if max_side > 4000:
|
||||
scale = 4000 / max_side
|
||||
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
|
||||
height, width = gray.shape
|
||||
|
||||
# 2b. Scale UP if too small
|
||||
if width < 1500:
|
||||
scale = 1500 / width
|
||||
# Ensure we don't exceed 4000px after upscaling
|
||||
new_width = int(width * scale)
|
||||
new_height = int(height * scale)
|
||||
if max(new_width, new_height) > 4000:
|
||||
scale = 4000 / max(new_width, new_height)
|
||||
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
|
||||
|
||||
# 3. Deskew
|
||||
gray = self._deskew(gray)
|
||||
|
||||
# 4. Moderate contrast enhancement (CLAHE clipLimit=2.0)
|
||||
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
||||
enhanced = clahe.apply(gray)
|
||||
|
||||
# 5. Light denoising (less aggressive than Heavy)
|
||||
denoised = cv2.fastNlMeansDenoising(enhanced, h=6, templateWindowSize=7, searchWindowSize=15)
|
||||
|
||||
# 6. Gentle sharpening
|
||||
gaussian = cv2.GaussianBlur(denoised, (0, 0), 1.0)
|
||||
sharpened = cv2.addWeighted(denoised, 1.3, gaussian, -0.3, 0)
|
||||
|
||||
# NO binarization, NO morphological operations
|
||||
# This preserves text boundaries and avoids digit concatenation
|
||||
return sharpened
|
||||
|
||||
def preprocess_heavy(self, image: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Heavy preprocessing for FADED thermal receipts.
|
||||
Aggressive binarization to recover faded text.
|
||||
|
||||
⚠️ DEPRECATED: Use preprocess_medium() instead.
|
||||
Heavy preprocessing causes digit concatenation on clear PDFs
|
||||
(e.g., 85.99 → 859,762.16 due to binarization + morphological operations).
|
||||
Kept for backward compatibility only.
|
||||
"""
|
||||
# 0. Add safety padding to protect edge content during deskew rotation
|
||||
image = self._add_safety_padding(image)
|
||||
|
||||
# 1. Grayscale
|
||||
if len(image.shape) == 3:
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
else:
|
||||
gray = image.copy()
|
||||
|
||||
# 2a. Scale DOWN if any side exceeds 4000px (PaddleOCR limit)
|
||||
height, width = gray.shape
|
||||
max_side = max(height, width)
|
||||
if max_side > 4000:
|
||||
scale = 4000 / max_side
|
||||
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
|
||||
height, width = gray.shape
|
||||
|
||||
# 2b. Scale UP if too small (larger = better OCR)
|
||||
if width < 1500:
|
||||
scale = 1500 / width
|
||||
# Ensure we don't exceed 4000px after upscaling
|
||||
new_width = int(width * scale)
|
||||
new_height = int(height * scale)
|
||||
if max(new_width, new_height) > 4000:
|
||||
scale = 4000 / max(new_width, new_height)
|
||||
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
|
||||
|
||||
# 3. Deskew
|
||||
gray = self._deskew(gray)
|
||||
|
||||
# 4. Contrast enhancement with CLAHE
|
||||
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
||||
enhanced = clahe.apply(gray)
|
||||
|
||||
# 5. Denoise
|
||||
denoised = cv2.fastNlMeansDenoising(enhanced, h=8, templateWindowSize=7, searchWindowSize=21)
|
||||
|
||||
# 6. Sharpening
|
||||
gaussian = cv2.GaussianBlur(denoised, (0, 0), 2.0)
|
||||
sharpened = cv2.addWeighted(denoised, 1.5, gaussian, -0.5, 0)
|
||||
|
||||
# 7. Adaptive thresholding (binarization)
|
||||
binary = cv2.adaptiveThreshold(
|
||||
sharpened, 255,
|
||||
cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
|
||||
cv2.THRESH_BINARY,
|
||||
blockSize=11, C=5
|
||||
)
|
||||
|
||||
# 8. Morphological operations
|
||||
kernel_close = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
|
||||
result = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel_close)
|
||||
|
||||
return result
|
||||
|
||||
def preprocess_for_tesseract(self, image: np.ndarray, binarize: bool = False,
|
||||
padding: int = 0, clahe_clip: float = 1.5) -> np.ndarray:
|
||||
"""
|
||||
Tesseract-optimized preprocessing (based on comprehensive benchmark).
|
||||
|
||||
BENCHMARK FINDINGS:
|
||||
- DPI 200 is optimal (not 300!)
|
||||
- Padding 40px fixes left margin truncation issues
|
||||
- CLAHE 1.5 for most receipts, 2.0 for difficult ones
|
||||
- NO deskew, NO denoising for clear PDFs
|
||||
|
||||
Recommended usage:
|
||||
- Simple receipts: padding=0, clahe_clip=1.5
|
||||
- Complex receipts: padding=40, clahe_clip=1.5
|
||||
- Difficult/faded: padding=40, clahe_clip=2.0, binarize=True
|
||||
|
||||
Args:
|
||||
image: Input image (RGB from pdf2image or BGR from OpenCV)
|
||||
binarize: Apply Otsu binarization (for faded receipts)
|
||||
padding: White padding in pixels (40px recommended for edge protection)
|
||||
clahe_clip: CLAHE clip limit (1.5 normal, 2.0 for difficult)
|
||||
|
||||
Returns:
|
||||
Preprocessed grayscale image
|
||||
"""
|
||||
# 1. Grayscale (handle both RGB and BGR)
|
||||
if len(image.shape) == 3:
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
||||
else:
|
||||
gray = image.copy()
|
||||
|
||||
# 2. Add padding if specified (protects against left margin truncation)
|
||||
if padding > 0:
|
||||
gray = cv2.copyMakeBorder(
|
||||
gray, padding, padding, padding, padding,
|
||||
cv2.BORDER_CONSTANT, value=255
|
||||
)
|
||||
|
||||
# 3. CLAHE contrast enhancement
|
||||
clahe = cv2.createCLAHE(clipLimit=clahe_clip, tileGridSize=(8, 8))
|
||||
enhanced = clahe.apply(gray)
|
||||
|
||||
# NO deskew, NO denoising - these DEGRADE quality on clear PDFs!
|
||||
|
||||
if not binarize:
|
||||
return enhanced
|
||||
|
||||
# Binarization only for faded receipts
|
||||
_, binary = cv2.threshold(enhanced, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
||||
|
||||
# Ensure correct polarity
|
||||
if np.mean(binary) < 127:
|
||||
binary = 255 - binary
|
||||
|
||||
return binary
|
||||
|
||||
def preprocess_for_tesseract_padded(self, image: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Tesseract preprocessing with optimal padding (40px).
|
||||
|
||||
Best for complex receipts where left margin gets truncated.
|
||||
"""
|
||||
return self.preprocess_for_tesseract(image, padding=40)
|
||||
|
||||
def preprocess_for_tesseract_faded(self, image: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Tesseract preprocessing for FADED thermal receipts.
|
||||
Uses binarization to recover faded text.
|
||||
"""
|
||||
return self.preprocess_for_tesseract(image, binarize=True)
|
||||
|
||||
def get_all_variants(self, image: np.ndarray) -> List[np.ndarray]:
|
||||
"""
|
||||
Generate 2 preprocessing variants for OCR (fast mode).
|
||||
Returns: [light_processed, heavy_processed]
|
||||
"""
|
||||
return [
|
||||
self.preprocess_light(image),
|
||||
self.preprocess_heavy(image),
|
||||
]
|
||||
|
||||
def _deskew(self, image: np.ndarray) -> np.ndarray:
|
||||
"""Correct image rotation/skew using Hough lines.
|
||||
|
||||
Uses expanded canvas to preserve all content during rotation,
|
||||
preventing left/right margin truncation.
|
||||
"""
|
||||
edges = cv2.Canny(image, 50, 150, apertureSize=3)
|
||||
lines = cv2.HoughLinesP(
|
||||
edges, 1, np.pi / 180,
|
||||
threshold=100, minLineLength=100, maxLineGap=10
|
||||
)
|
||||
|
||||
if lines is None:
|
||||
return image
|
||||
|
||||
angles = []
|
||||
for line in lines:
|
||||
x1, y1, x2, y2 = line[0]
|
||||
angle = np.arctan2(y2 - y1, x2 - x1) * 180 / np.pi
|
||||
if abs(angle) < 45:
|
||||
angles.append(angle)
|
||||
|
||||
if not angles:
|
||||
return image
|
||||
|
||||
median_angle = np.median(angles)
|
||||
if abs(median_angle) < 0.5:
|
||||
return image
|
||||
|
||||
h, w = image.shape[:2]
|
||||
center = (w // 2, h // 2)
|
||||
M = cv2.getRotationMatrix2D(center, median_angle, 1.0)
|
||||
|
||||
# Calculate new canvas size to fit entire rotated image (prevents edge truncation)
|
||||
cos_angle = abs(np.cos(np.radians(median_angle)))
|
||||
sin_angle = abs(np.sin(np.radians(median_angle)))
|
||||
new_w = int(h * sin_angle + w * cos_angle)
|
||||
new_h = int(h * cos_angle + w * sin_angle)
|
||||
|
||||
# Adjust rotation matrix for new canvas center
|
||||
M[0, 2] += (new_w - w) / 2
|
||||
M[1, 2] += (new_h - h) / 2
|
||||
|
||||
return cv2.warpAffine(
|
||||
image, M, (new_w, new_h),
|
||||
flags=cv2.INTER_CUBIC,
|
||||
borderMode=cv2.BORDER_CONSTANT,
|
||||
borderValue=255 # White background (grayscale)
|
||||
)
|
||||
@@ -0,0 +1,216 @@
|
||||
"""Service for fetching nomenclatures from Oracle (read-only)."""
|
||||
|
||||
from typing import List, Optional
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from backend.modules.data_entry.schemas.receipt import (
|
||||
PartnerOption,
|
||||
AccountOption,
|
||||
CashRegisterOption,
|
||||
ExpenseTypeOption,
|
||||
)
|
||||
from backend.modules.data_entry.services.expense_types import EXPENSE_TYPES
|
||||
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier, SyncedCashRegister
|
||||
|
||||
|
||||
class NomenclatureService:
|
||||
"""
|
||||
Service for fetching nomenclatures.
|
||||
|
||||
In Phase 1 (MVP), some nomenclatures are hardcoded.
|
||||
In Phase 2, these will be fetched from Oracle.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
async def get_partners(
|
||||
company_id: int,
|
||||
search: Optional[str] = None,
|
||||
session: Optional[AsyncSession] = None
|
||||
) -> List[PartnerOption]:
|
||||
"""
|
||||
Get partners (suppliers/customers) for a company.
|
||||
|
||||
Returns synced suppliers from Oracle + local suppliers created from OCR.
|
||||
If no suppliers exist, returns empty list (frontend will trigger sync).
|
||||
"""
|
||||
partners = []
|
||||
|
||||
if not session:
|
||||
return partners
|
||||
|
||||
# Get synced suppliers from Oracle
|
||||
stmt = select(SyncedSupplier).where(SyncedSupplier.company_id == company_id)
|
||||
if search:
|
||||
stmt = stmt.where(
|
||||
(SyncedSupplier.name.ilike(f"%{search}%")) |
|
||||
(SyncedSupplier.fiscal_code.ilike(f"%{search}%"))
|
||||
)
|
||||
stmt = stmt.order_by(SyncedSupplier.name)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
suppliers = result.scalars().all()
|
||||
|
||||
for s in suppliers:
|
||||
partners.append(PartnerOption(
|
||||
name=s.name,
|
||||
fiscal_code=s.fiscal_code,
|
||||
address=s.address,
|
||||
source="oracle"
|
||||
))
|
||||
|
||||
# Always get local suppliers (not just when synced exist)
|
||||
local_stmt = select(LocalSupplier).where(LocalSupplier.company_id == company_id)
|
||||
if search:
|
||||
local_stmt = local_stmt.where(
|
||||
(LocalSupplier.name.ilike(f"%{search}%")) |
|
||||
(LocalSupplier.fiscal_code.ilike(f"%{search}%"))
|
||||
)
|
||||
local_stmt = local_stmt.order_by(LocalSupplier.name)
|
||||
|
||||
local_result = await session.execute(local_stmt)
|
||||
local_suppliers = local_result.scalars().all()
|
||||
|
||||
for l in local_suppliers:
|
||||
partners.append(PartnerOption(
|
||||
name=l.name,
|
||||
fiscal_code=l.fiscal_code,
|
||||
address=l.address,
|
||||
source="local"
|
||||
))
|
||||
|
||||
return partners
|
||||
|
||||
@staticmethod
|
||||
async def get_accounts(company_id: int, prefix: Optional[str] = None) -> List[AccountOption]:
|
||||
"""
|
||||
Get chart of accounts for a company.
|
||||
|
||||
Phase 1: Returns common expense/income accounts.
|
||||
Phase 2: Will fetch from Oracle PLAN_CONTURI.
|
||||
"""
|
||||
# Common accounts for expenses and receipts
|
||||
accounts = [
|
||||
# Expense accounts (Class 6)
|
||||
AccountOption(code="6022", name="Cheltuieli cu combustibilii"),
|
||||
AccountOption(code="6024", name="Cheltuieli materiale pentru ambalat"),
|
||||
AccountOption(code="6028", name="Alte cheltuieli cu materiale consumabile"),
|
||||
AccountOption(code="624", name="Cheltuieli cu transportul de bunuri si personal"),
|
||||
AccountOption(code="626", name="Cheltuieli postale si taxe telecomunicatii"),
|
||||
AccountOption(code="628", name="Alte cheltuieli cu serviciile executate de terti"),
|
||||
|
||||
# VAT
|
||||
AccountOption(code="4426", name="TVA deductibila"),
|
||||
AccountOption(code="4427", name="TVA colectata"),
|
||||
|
||||
# Cash and Bank (Class 5)
|
||||
AccountOption(code="5311", name="Casa in lei"),
|
||||
AccountOption(code="5121", name="Conturi la banci in lei"),
|
||||
|
||||
# Income accounts (Class 7)
|
||||
AccountOption(code="7588", name="Alte venituri din exploatare"),
|
||||
]
|
||||
|
||||
if prefix:
|
||||
accounts = [a for a in accounts if a.code.startswith(prefix)]
|
||||
|
||||
return accounts
|
||||
|
||||
@staticmethod
|
||||
async def get_cash_registers(
|
||||
company_id: int,
|
||||
session: Optional[AsyncSession] = None
|
||||
) -> List[CashRegisterOption]:
|
||||
"""
|
||||
Get cash registers and bank accounts for a company.
|
||||
|
||||
Phase 1: Returns default options.
|
||||
Phase 2: Returns synced data from SQLite (from Oracle sync).
|
||||
Phase 3: Will fetch live from Oracle NOM_CASE / NOM_BANCI.
|
||||
"""
|
||||
# If session is provided, try to get from synced SQLite data
|
||||
if session:
|
||||
stmt = select(SyncedCashRegister).where(SyncedCashRegister.company_id == company_id)
|
||||
result = await session.execute(stmt)
|
||||
registers = result.scalars().all()
|
||||
|
||||
if registers:
|
||||
return [
|
||||
CashRegisterOption(id=r.id, name=r.name, account_code=r.account_code)
|
||||
for r in registers
|
||||
]
|
||||
|
||||
# Fallback to default cash registers for Phase 1
|
||||
return [
|
||||
CashRegisterOption(id=1, name="Casa principala", account_code="5311"),
|
||||
CashRegisterOption(id=2, name="Cont BCR", account_code="5121"),
|
||||
CashRegisterOption(id=3, name="Cont BRD", account_code="5121"),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
async def get_expense_types() -> List[ExpenseTypeOption]:
|
||||
"""
|
||||
Get predefined expense types with their accounting configuration.
|
||||
"""
|
||||
return [
|
||||
ExpenseTypeOption(
|
||||
code=et.code,
|
||||
name=et.name,
|
||||
account_code=et.account_code,
|
||||
has_vat=et.has_vat,
|
||||
vat_percent=et.vat_percent,
|
||||
)
|
||||
for et in EXPENSE_TYPES.values()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
async def get_companies(username: str) -> List[dict]:
|
||||
"""
|
||||
Get companies accessible by user.
|
||||
|
||||
Phase 1: Returns mock data.
|
||||
Phase 2: Will fetch from shared auth based on user permissions.
|
||||
"""
|
||||
# TODO: Integrate with shared auth to get user's companies
|
||||
return [
|
||||
{"id": 1, "name": "SC Test SRL", "cui": "RO12345678"},
|
||||
{"id": 2, "name": "SC Demo SA", "cui": "RO87654321"},
|
||||
]
|
||||
|
||||
# ============ Phase 2 Oracle Integration Methods ============
|
||||
|
||||
@staticmethod
|
||||
async def _fetch_partners_oracle(company_id: int, search: Optional[str] = None) -> List[PartnerOption]:
|
||||
"""
|
||||
Fetch partners from Oracle NOM_PARTENERI.
|
||||
|
||||
Will be implemented in Phase 2.
|
||||
"""
|
||||
# TODO: Implement using shared oracle_pool
|
||||
# Example query:
|
||||
# SELECT ID_PART, DEN_PART, COD_FISCAL
|
||||
# FROM {schema}.NOM_PARTENERI
|
||||
# WHERE DEN_PART LIKE :search
|
||||
raise NotImplementedError("Oracle integration pending - Phase 2")
|
||||
|
||||
@staticmethod
|
||||
async def _fetch_accounts_oracle(company_id: int, prefix: Optional[str] = None) -> List[AccountOption]:
|
||||
"""
|
||||
Fetch chart of accounts from Oracle PLAN_CONTURI.
|
||||
|
||||
Will be implemented in Phase 2.
|
||||
"""
|
||||
# TODO: Implement using shared oracle_pool
|
||||
raise NotImplementedError("Oracle integration pending - Phase 2")
|
||||
|
||||
@staticmethod
|
||||
async def _fetch_cash_registers_oracle(company_id: int) -> List[CashRegisterOption]:
|
||||
"""
|
||||
Fetch cash registers from Oracle NOM_CASE / NOM_BANCI.
|
||||
|
||||
Will be implemented in Phase 2.
|
||||
"""
|
||||
# TODO: Implement using shared oracle_pool
|
||||
raise NotImplementedError("Oracle integration pending - Phase 2")
|
||||
@@ -0,0 +1,42 @@
|
||||
"""
|
||||
OCR Services Module
|
||||
|
||||
Provides persistent OCR worker pool with job queue for efficient processing.
|
||||
|
||||
Components:
|
||||
- ocr_worker_pool: Manages ProcessPoolExecutor with persistent PaddleOCR
|
||||
- job_queue: SQLite-based job queue for async processing
|
||||
- job_worker: Background task that processes queued jobs
|
||||
- tesseract_engine: Optimized Tesseract with multi-PSM and polarity fix
|
||||
|
||||
Architecture:
|
||||
FastAPI → job_queue.create_job() → SQLite
|
||||
↓
|
||||
job_worker loop → ocr_worker_pool.submit_task() → Worker Process
|
||||
↓
|
||||
PaddleOCR/Tesseract
|
||||
"""
|
||||
|
||||
from .ocr_worker_pool import ocr_worker_pool, OCRWorkerPool
|
||||
from .job_queue import job_queue, OCRJobQueue, OCRJob, OCRJobStatus
|
||||
from .job_worker import start_job_worker, stop_job_worker
|
||||
from .tesseract_engine import TesseractEngine
|
||||
from .validation import OCRValidationEngine
|
||||
|
||||
__all__ = [
|
||||
# Worker pool
|
||||
"ocr_worker_pool",
|
||||
"OCRWorkerPool",
|
||||
# Job queue
|
||||
"job_queue",
|
||||
"OCRJobQueue",
|
||||
"OCRJob",
|
||||
"OCRJobStatus",
|
||||
# Job worker
|
||||
"start_job_worker",
|
||||
"stop_job_worker",
|
||||
# Engines
|
||||
"TesseractEngine",
|
||||
# Validation
|
||||
"OCRValidationEngine",
|
||||
]
|
||||
@@ -0,0 +1,653 @@
|
||||
"""
|
||||
SQLite Job Queue Manager for OCR Processing
|
||||
|
||||
Provides async job queue for OCR requests:
|
||||
- Jobs are stored in SQLite for persistence
|
||||
- Queue position and time estimation
|
||||
- Automatic expiration after 24 hours
|
||||
- Statistics for monitoring
|
||||
|
||||
Schema:
|
||||
ocr_jobs (
|
||||
id TEXT PRIMARY KEY, -- UUID
|
||||
status TEXT NOT NULL, -- pending, processing, completed, failed
|
||||
file_path TEXT NOT NULL, -- Path to uploaded file
|
||||
mime_type TEXT NOT NULL,
|
||||
engine TEXT DEFAULT 'doctr_plus',
|
||||
created_at TIMESTAMP,
|
||||
started_at TIMESTAMP,
|
||||
completed_at TIMESTAMP,
|
||||
result_json TEXT, -- JSON extraction result
|
||||
error_message TEXT,
|
||||
processing_time_ms INTEGER, -- Total job time (started_at to completed_at)
|
||||
ocr_time_ms INTEGER, -- Actual OCR engine processing time
|
||||
created_by TEXT, -- Username
|
||||
original_filename TEXT,
|
||||
expires_at TIMESTAMP,
|
||||
batch_id INTEGER, -- Foreign key to batch_uploads (for bulk processing)
|
||||
file_hash TEXT -- SHA-256 hash for duplicate detection (US-007)
|
||||
)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from decimal import Decimal
|
||||
|
||||
|
||||
class DecimalEncoder(json.JSONEncoder):
|
||||
"""JSON encoder that handles Decimal types."""
|
||||
def default(self, obj):
|
||||
if isinstance(obj, Decimal):
|
||||
return float(obj)
|
||||
return super().default(obj)
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiosqlite
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default paths
|
||||
DEFAULT_QUEUE_DIR = Path(__file__).parent.parent.parent.parent.parent / "data" / "ocr_queue"
|
||||
DEFAULT_DB_PATH = DEFAULT_QUEUE_DIR / "ocr_jobs.db"
|
||||
DEFAULT_FILES_DIR = DEFAULT_QUEUE_DIR / "files"
|
||||
|
||||
# Job expiration
|
||||
JOB_EXPIRY_HOURS = 24
|
||||
|
||||
# SQLite busy timeout (milliseconds) - prevents "database is locked" errors
|
||||
SQLITE_BUSY_TIMEOUT_MS = 5000
|
||||
|
||||
|
||||
class OCRJobStatus(str, Enum):
|
||||
"""Job status enum."""
|
||||
pending = "pending"
|
||||
processing = "processing"
|
||||
completed = "completed"
|
||||
failed = "failed"
|
||||
cancelled = "cancelled"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OCRJob:
|
||||
"""OCR Job data class."""
|
||||
id: str
|
||||
status: OCRJobStatus
|
||||
file_path: str
|
||||
mime_type: str
|
||||
engine: str = "doctr_plus"
|
||||
created_at: Optional[datetime] = None
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
result_json: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
processing_time_ms: Optional[int] = None # Total job time (started_at to completed_at)
|
||||
ocr_time_ms: Optional[int] = None # Actual OCR engine processing time
|
||||
created_by: Optional[str] = None
|
||||
original_filename: Optional[str] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
batch_id: Optional[int] = None # Links to batch_uploads table for bulk processing
|
||||
file_hash: Optional[str] = None # SHA-256 hash for duplicate detection (US-007)
|
||||
|
||||
@property
|
||||
def queue_wait_ms(self) -> Optional[int]:
|
||||
"""Calculate queue wait time (created_at to started_at)."""
|
||||
if self.created_at and self.started_at:
|
||||
delta = self.started_at - self.created_at
|
||||
return int(delta.total_seconds() * 1000)
|
||||
return None
|
||||
|
||||
@property
|
||||
def result(self) -> Optional[Dict]:
|
||||
"""Parse result_json to dict."""
|
||||
if self.result_json:
|
||||
try:
|
||||
return json.loads(self.result_json)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
class OCRJobQueue:
|
||||
"""
|
||||
SQLite-based job queue for OCR processing.
|
||||
|
||||
Provides async methods for job management with position
|
||||
tracking and time estimation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_path: Optional[Path] = None,
|
||||
files_dir: Optional[Path] = None
|
||||
):
|
||||
"""
|
||||
Initialize job queue.
|
||||
|
||||
Args:
|
||||
db_path: Path to SQLite database (default: data/ocr_queue/ocr_jobs.db)
|
||||
files_dir: Path to files directory (default: data/ocr_queue/files/)
|
||||
"""
|
||||
self.db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
|
||||
self.files_dir = Path(files_dir) if files_dir else DEFAULT_FILES_DIR
|
||||
self._lock = asyncio.Lock()
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""
|
||||
Initialize database and directories.
|
||||
|
||||
Creates SQLite database and tables if they don't exist.
|
||||
Creates files directory for uploaded files.
|
||||
"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
# Create directories
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.files_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create database and tables
|
||||
async with aiosqlite.connect(str(self.db_path)) as db:
|
||||
# Enable WAL mode for better concurrency and set busy timeout
|
||||
await db.execute("PRAGMA journal_mode=WAL")
|
||||
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
|
||||
|
||||
await db.execute('''
|
||||
CREATE TABLE IF NOT EXISTS ocr_jobs (
|
||||
id TEXT PRIMARY KEY,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
file_path TEXT NOT NULL,
|
||||
mime_type TEXT NOT NULL,
|
||||
engine TEXT DEFAULT 'doctr_plus',
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
started_at TIMESTAMP,
|
||||
completed_at TIMESTAMP,
|
||||
result_json TEXT,
|
||||
error_message TEXT,
|
||||
processing_time_ms INTEGER,
|
||||
ocr_time_ms INTEGER,
|
||||
created_by TEXT,
|
||||
original_filename TEXT,
|
||||
expires_at TIMESTAMP,
|
||||
batch_id INTEGER
|
||||
)
|
||||
''')
|
||||
|
||||
# Migration: add ocr_time_ms column if it doesn't exist
|
||||
try:
|
||||
await db.execute('ALTER TABLE ocr_jobs ADD COLUMN ocr_time_ms INTEGER')
|
||||
logger.info("[OCRJobQueue] Added ocr_time_ms column to existing table")
|
||||
except Exception:
|
||||
pass # Column already exists
|
||||
|
||||
# Migration: add batch_id column if it doesn't exist
|
||||
try:
|
||||
await db.execute('ALTER TABLE ocr_jobs ADD COLUMN batch_id INTEGER')
|
||||
logger.info("[OCRJobQueue] Added batch_id column to existing table")
|
||||
except Exception:
|
||||
pass # Column already exists
|
||||
|
||||
# Migration: add file_hash column if it doesn't exist (US-007)
|
||||
try:
|
||||
await db.execute('ALTER TABLE ocr_jobs ADD COLUMN file_hash TEXT')
|
||||
logger.info("[OCRJobQueue] Added file_hash column to existing table")
|
||||
except Exception:
|
||||
pass # Column already exists
|
||||
|
||||
# Index for efficient queue queries
|
||||
await db.execute('''
|
||||
CREATE INDEX IF NOT EXISTS idx_ocr_jobs_status
|
||||
ON ocr_jobs(status, created_at)
|
||||
''')
|
||||
|
||||
# Index for expiration cleanup
|
||||
await db.execute('''
|
||||
CREATE INDEX IF NOT EXISTS idx_ocr_jobs_expires
|
||||
ON ocr_jobs(expires_at)
|
||||
''')
|
||||
|
||||
await db.commit()
|
||||
|
||||
self._initialized = True
|
||||
logger.info(f"[OCRJobQueue] Initialized: db={self.db_path}, files={self.files_dir}")
|
||||
|
||||
async def create_job(
|
||||
self,
|
||||
file_bytes: bytes,
|
||||
mime_type: str,
|
||||
engine: str = "doctr_plus",
|
||||
username: Optional[str] = None,
|
||||
original_filename: Optional[str] = None,
|
||||
batch_id: Optional[int] = None,
|
||||
file_hash: Optional[str] = None
|
||||
) -> OCRJob:
|
||||
"""
|
||||
Create a new OCR job.
|
||||
|
||||
Saves file to disk and creates database record.
|
||||
|
||||
Args:
|
||||
file_bytes: Raw file bytes
|
||||
mime_type: MIME type of file
|
||||
engine: OCR engine ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
|
||||
username: Username of requester
|
||||
original_filename: Original filename from upload
|
||||
batch_id: Optional batch ID for bulk upload processing
|
||||
file_hash: Optional SHA-256 hash for duplicate detection (US-007)
|
||||
|
||||
Returns:
|
||||
Created OCRJob instance
|
||||
"""
|
||||
await self.initialize()
|
||||
|
||||
# Generate job ID
|
||||
job_id = str(uuid.uuid4())
|
||||
|
||||
# Determine file extension
|
||||
ext_map = {
|
||||
'image/jpeg': '.jpg',
|
||||
'image/png': '.png',
|
||||
'application/pdf': '.pdf',
|
||||
}
|
||||
ext = ext_map.get(mime_type, '.bin')
|
||||
|
||||
# Save file
|
||||
file_path = self.files_dir / f"{job_id}{ext}"
|
||||
with open(file_path, 'wb') as f:
|
||||
f.write(file_bytes)
|
||||
|
||||
# Calculate expiration
|
||||
now = datetime.utcnow()
|
||||
expires_at = now + timedelta(hours=JOB_EXPIRY_HOURS)
|
||||
|
||||
# Insert job record
|
||||
async with aiosqlite.connect(str(self.db_path)) as db:
|
||||
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
|
||||
await db.execute('''
|
||||
INSERT INTO ocr_jobs (
|
||||
id, status, file_path, mime_type, engine,
|
||||
created_at, created_by, original_filename, expires_at, batch_id, file_hash
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
''', (
|
||||
job_id, OCRJobStatus.pending.value, str(file_path), mime_type, engine,
|
||||
now.isoformat(), username, original_filename, expires_at.isoformat(), batch_id, file_hash
|
||||
))
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"[OCRJobQueue] Created job {job_id}: engine={engine}, file={file_path.name}, batch_id={batch_id}")
|
||||
|
||||
return OCRJob(
|
||||
id=job_id,
|
||||
status=OCRJobStatus.pending,
|
||||
file_path=str(file_path),
|
||||
mime_type=mime_type,
|
||||
engine=engine,
|
||||
created_at=now,
|
||||
created_by=username,
|
||||
original_filename=original_filename,
|
||||
expires_at=expires_at,
|
||||
batch_id=batch_id,
|
||||
file_hash=file_hash
|
||||
)
|
||||
|
||||
async def get_job(self, job_id: str) -> Optional[OCRJob]:
|
||||
"""
|
||||
Get job by ID.
|
||||
|
||||
Args:
|
||||
job_id: Job UUID
|
||||
|
||||
Returns:
|
||||
OCRJob or None if not found
|
||||
"""
|
||||
await self.initialize()
|
||||
|
||||
async with aiosqlite.connect(str(self.db_path)) as db:
|
||||
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
'SELECT * FROM ocr_jobs WHERE id = ?',
|
||||
(job_id,)
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return self._row_to_job(row)
|
||||
return None
|
||||
|
||||
async def get_queue_position(self, job_id: str) -> Optional[int]:
|
||||
"""
|
||||
Get position in queue for a pending job.
|
||||
|
||||
Args:
|
||||
job_id: Job UUID
|
||||
|
||||
Returns:
|
||||
Queue position (1 = next to process) or None if not pending
|
||||
"""
|
||||
await self.initialize()
|
||||
|
||||
async with aiosqlite.connect(str(self.db_path)) as db:
|
||||
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
|
||||
# Check if job is pending
|
||||
async with db.execute(
|
||||
'SELECT status, created_at FROM ocr_jobs WHERE id = ?',
|
||||
(job_id,)
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if not row or row[0] != OCRJobStatus.pending.value:
|
||||
return None
|
||||
job_created_at = row[1]
|
||||
|
||||
# Count jobs ahead in queue (created before this job)
|
||||
async with db.execute('''
|
||||
SELECT COUNT(*) FROM ocr_jobs
|
||||
WHERE status = 'pending' AND created_at < ?
|
||||
''', (job_created_at,)) as cursor:
|
||||
count = await cursor.fetchone()
|
||||
return (count[0] + 1) if count else 1
|
||||
|
||||
async def get_next_pending(self) -> Optional[OCRJob]:
|
||||
"""
|
||||
Get the next pending job (oldest first) and atomically mark it as processing.
|
||||
|
||||
This prevents race conditions in parallel processing - only one worker
|
||||
can claim each job.
|
||||
|
||||
Returns:
|
||||
Next OCRJob to process or None if queue empty
|
||||
"""
|
||||
await self.initialize()
|
||||
|
||||
now = datetime.utcnow()
|
||||
|
||||
async with self._lock: # Serialize access to prevent race conditions
|
||||
async with aiosqlite.connect(str(self.db_path)) as db:
|
||||
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
|
||||
db.row_factory = aiosqlite.Row
|
||||
|
||||
# Get the next pending job
|
||||
async with db.execute('''
|
||||
SELECT * FROM ocr_jobs
|
||||
WHERE status = 'pending'
|
||||
ORDER BY created_at ASC
|
||||
LIMIT 1
|
||||
''') as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
|
||||
job_id = row['id']
|
||||
|
||||
# Atomically mark as processing
|
||||
await db.execute('''
|
||||
UPDATE ocr_jobs
|
||||
SET status = 'processing', started_at = ?
|
||||
WHERE id = ? AND status = 'pending'
|
||||
''', (now.isoformat(), job_id))
|
||||
await db.commit()
|
||||
|
||||
# Fetch the updated job
|
||||
async with db.execute(
|
||||
'SELECT * FROM ocr_jobs WHERE id = ?',
|
||||
(job_id,)
|
||||
) as cursor:
|
||||
updated_row = await cursor.fetchone()
|
||||
if updated_row:
|
||||
return self._row_to_job(updated_row)
|
||||
|
||||
return None
|
||||
|
||||
async def update_status(
|
||||
self,
|
||||
job_id: str,
|
||||
status: OCRJobStatus,
|
||||
result: Optional[Dict] = None,
|
||||
error: Optional[str] = None,
|
||||
processing_time_ms: Optional[int] = None,
|
||||
ocr_time_ms: Optional[int] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Update job status.
|
||||
|
||||
Args:
|
||||
job_id: Job UUID
|
||||
status: New status
|
||||
result: Extraction result dict (for completed)
|
||||
error: Error message (for failed)
|
||||
processing_time_ms: Total job processing time (started_at to completed_at)
|
||||
ocr_time_ms: Actual OCR engine processing time
|
||||
|
||||
Returns:
|
||||
True if update successful
|
||||
"""
|
||||
await self.initialize()
|
||||
|
||||
now = datetime.utcnow()
|
||||
result_json = json.dumps(result, cls=DecimalEncoder) if result else None
|
||||
|
||||
# Build update query based on status
|
||||
if status == OCRJobStatus.processing:
|
||||
query = '''
|
||||
UPDATE ocr_jobs
|
||||
SET status = ?, started_at = ?
|
||||
WHERE id = ?
|
||||
'''
|
||||
params = (status.value, now.isoformat(), job_id)
|
||||
|
||||
elif status == OCRJobStatus.completed:
|
||||
query = '''
|
||||
UPDATE ocr_jobs
|
||||
SET status = ?, completed_at = ?, result_json = ?, processing_time_ms = ?, ocr_time_ms = ?
|
||||
WHERE id = ?
|
||||
'''
|
||||
params = (status.value, now.isoformat(), result_json, processing_time_ms, ocr_time_ms, job_id)
|
||||
|
||||
elif status == OCRJobStatus.failed:
|
||||
query = '''
|
||||
UPDATE ocr_jobs
|
||||
SET status = ?, completed_at = ?, error_message = ?, processing_time_ms = ?, ocr_time_ms = ?
|
||||
WHERE id = ?
|
||||
'''
|
||||
params = (status.value, now.isoformat(), error, processing_time_ms, ocr_time_ms, job_id)
|
||||
|
||||
else:
|
||||
query = 'UPDATE ocr_jobs SET status = ? WHERE id = ?'
|
||||
params = (status.value, job_id)
|
||||
|
||||
async with aiosqlite.connect(str(self.db_path)) as db:
|
||||
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
|
||||
cursor = await db.execute(query, params)
|
||||
await db.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
async def get_average_processing_time(self) -> float:
|
||||
"""
|
||||
Calculate average processing time from recent completed jobs.
|
||||
|
||||
Uses last 50 completed jobs for accuracy.
|
||||
|
||||
Returns:
|
||||
Average time in seconds (default 7.0 if no data)
|
||||
"""
|
||||
await self.initialize()
|
||||
|
||||
async with aiosqlite.connect(str(self.db_path)) as db:
|
||||
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
|
||||
async with db.execute('''
|
||||
SELECT AVG(processing_time_ms)
|
||||
FROM (
|
||||
SELECT processing_time_ms FROM ocr_jobs
|
||||
WHERE status = 'completed' AND processing_time_ms IS NOT NULL
|
||||
ORDER BY completed_at DESC
|
||||
LIMIT 50
|
||||
)
|
||||
''') as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if row and row[0]:
|
||||
return row[0] / 1000.0 # Convert ms to seconds
|
||||
return 7.0 # Default estimate
|
||||
|
||||
async def count_pending(self) -> int:
|
||||
"""Count pending jobs in queue."""
|
||||
await self.initialize()
|
||||
|
||||
async with aiosqlite.connect(str(self.db_path)) as db:
|
||||
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
|
||||
async with db.execute(
|
||||
'SELECT COUNT(*) FROM ocr_jobs WHERE status = ?',
|
||||
(OCRJobStatus.pending.value,)
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
return row[0] if row else 0
|
||||
|
||||
async def count_processing(self) -> int:
|
||||
"""Count currently processing jobs."""
|
||||
await self.initialize()
|
||||
|
||||
async with aiosqlite.connect(str(self.db_path)) as db:
|
||||
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
|
||||
async with db.execute(
|
||||
'SELECT COUNT(*) FROM ocr_jobs WHERE status = ?',
|
||||
(OCRJobStatus.processing.value,)
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
return row[0] if row else 0
|
||||
|
||||
async def cleanup_expired(self) -> int:
|
||||
"""
|
||||
Delete expired jobs and their files.
|
||||
|
||||
Returns:
|
||||
Number of jobs deleted
|
||||
"""
|
||||
await self.initialize()
|
||||
|
||||
now = datetime.utcnow()
|
||||
deleted = 0
|
||||
|
||||
async with aiosqlite.connect(str(self.db_path)) as db:
|
||||
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
|
||||
db.row_factory = aiosqlite.Row
|
||||
|
||||
# Get expired jobs
|
||||
async with db.execute('''
|
||||
SELECT id, file_path FROM ocr_jobs
|
||||
WHERE expires_at < ?
|
||||
''', (now.isoformat(),)) as cursor:
|
||||
rows = await cursor.fetchall()
|
||||
|
||||
for row in rows:
|
||||
# Delete file
|
||||
file_path = Path(row['file_path'])
|
||||
if file_path.exists():
|
||||
try:
|
||||
file_path.unlink()
|
||||
except Exception as e:
|
||||
logger.warning(f"[OCRJobQueue] Failed to delete file {file_path}: {e}")
|
||||
|
||||
# Delete job record
|
||||
await db.execute('DELETE FROM ocr_jobs WHERE id = ?', (row['id'],))
|
||||
deleted += 1
|
||||
|
||||
await db.commit()
|
||||
|
||||
if deleted > 0:
|
||||
logger.info(f"[OCRJobQueue] Cleaned up {deleted} expired job(s)")
|
||||
|
||||
return deleted
|
||||
|
||||
async def cleanup_job_file(self, job_id: str) -> bool:
|
||||
"""
|
||||
Delete the file associated with a job.
|
||||
|
||||
Called after processing to free disk space.
|
||||
|
||||
Args:
|
||||
job_id: Job UUID
|
||||
|
||||
Returns:
|
||||
True if file deleted
|
||||
"""
|
||||
job = await self.get_job(job_id)
|
||||
if job:
|
||||
file_path = Path(job.file_path)
|
||||
if file_path.exists():
|
||||
try:
|
||||
file_path.unlink()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"[OCRJobQueue] Failed to delete file {file_path}: {e}")
|
||||
return False
|
||||
|
||||
async def get_queue_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get queue statistics.
|
||||
|
||||
Returns:
|
||||
Dict with pending, processing, completed, failed counts
|
||||
"""
|
||||
await self.initialize()
|
||||
|
||||
stats = {
|
||||
"pending": 0,
|
||||
"processing": 0,
|
||||
"completed": 0,
|
||||
"failed": 0,
|
||||
"average_time_seconds": 0.0,
|
||||
}
|
||||
|
||||
async with aiosqlite.connect(str(self.db_path)) as db:
|
||||
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
|
||||
async with db.execute('''
|
||||
SELECT status, COUNT(*) as count
|
||||
FROM ocr_jobs
|
||||
GROUP BY status
|
||||
''') as cursor:
|
||||
rows = await cursor.fetchall()
|
||||
for row in rows:
|
||||
if row[0] in stats:
|
||||
stats[row[0]] = row[1]
|
||||
|
||||
stats["average_time_seconds"] = await self.get_average_processing_time()
|
||||
return stats
|
||||
|
||||
def _row_to_job(self, row: aiosqlite.Row) -> OCRJob:
|
||||
"""Convert database row to OCRJob."""
|
||||
def parse_datetime(val):
|
||||
if val:
|
||||
try:
|
||||
return datetime.fromisoformat(val)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
return None
|
||||
|
||||
return OCRJob(
|
||||
id=row['id'],
|
||||
status=OCRJobStatus(row['status']),
|
||||
file_path=row['file_path'],
|
||||
mime_type=row['mime_type'],
|
||||
engine=row['engine'] or 'doctr_plus',
|
||||
created_at=parse_datetime(row['created_at']),
|
||||
started_at=parse_datetime(row['started_at']),
|
||||
completed_at=parse_datetime(row['completed_at']),
|
||||
result_json=row['result_json'],
|
||||
error_message=row['error_message'],
|
||||
processing_time_ms=row['processing_time_ms'],
|
||||
ocr_time_ms=row['ocr_time_ms'] if 'ocr_time_ms' in row.keys() else None,
|
||||
created_by=row['created_by'],
|
||||
original_filename=row['original_filename'],
|
||||
expires_at=parse_datetime(row['expires_at']),
|
||||
batch_id=row['batch_id'] if 'batch_id' in row.keys() else None,
|
||||
file_hash=row['file_hash'] if 'file_hash' in row.keys() else None,
|
||||
)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
job_queue = OCRJobQueue()
|
||||
@@ -0,0 +1,665 @@
|
||||
"""
|
||||
OCR Job Worker - Background Task for Queue Processing
|
||||
|
||||
Runs as an asyncio background task in FastAPI.
|
||||
Continuously polls the job queue and processes OCR requests IN PARALLEL.
|
||||
|
||||
Architecture:
|
||||
FastAPI startup
|
||||
↓
|
||||
start_job_worker()
|
||||
↓
|
||||
asyncio.create_task(_job_worker_loop())
|
||||
↓
|
||||
while True:
|
||||
# Process up to OCR_WORKERS jobs concurrently
|
||||
jobs = get_pending_jobs(limit=available_slots)
|
||||
for job in jobs:
|
||||
asyncio.create_task(_process_job(job))
|
||||
await asyncio.sleep(0.1)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, Set
|
||||
|
||||
from .job_queue import job_queue, OCRJobStatus, OCRJob
|
||||
from .ocr_worker_pool import ocr_worker_pool
|
||||
from backend.modules.data_entry.schemas.ocr import ExtractionData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global task reference
|
||||
_job_worker_task: Optional[asyncio.Task] = None
|
||||
_cleanup_task: Optional[asyncio.Task] = None
|
||||
_shutdown_event: Optional[asyncio.Event] = None
|
||||
_active_tasks: Set[asyncio.Task] = set() # Track active job tasks
|
||||
_concurrency_semaphore: Optional[asyncio.Semaphore] = None # Limit concurrent jobs
|
||||
|
||||
# Configuration
|
||||
POLL_INTERVAL_SECONDS = 0.1 # How often to check for new jobs (faster for parallel)
|
||||
CLEANUP_INTERVAL_SECONDS = 3600 # Clean expired jobs every hour
|
||||
OCR_TIMEOUT_SECONDS = 120 # Max time for OCR processing
|
||||
|
||||
|
||||
async def _job_worker_loop() -> None:
|
||||
"""
|
||||
Main worker loop - processes jobs from queue IN PARALLEL.
|
||||
|
||||
Runs continuously until shutdown. Uses semaphore to limit
|
||||
concurrent jobs to OCR_WORKERS count. Launches jobs as
|
||||
background tasks without waiting for completion.
|
||||
"""
|
||||
global _shutdown_event, _active_tasks, _concurrency_semaphore
|
||||
|
||||
# Get max concurrent jobs from env (matches worker pool size)
|
||||
max_concurrent = int(os.getenv('OCR_WORKERS', '2'))
|
||||
_concurrency_semaphore = asyncio.Semaphore(max_concurrent)
|
||||
_active_tasks = set()
|
||||
|
||||
logger.info(f"[JobWorker] Starting PARALLEL worker loop (max_concurrent={max_concurrent})...")
|
||||
_shutdown_event = asyncio.Event()
|
||||
|
||||
consecutive_errors = 0
|
||||
max_consecutive_errors = 10
|
||||
|
||||
while not _shutdown_event.is_set():
|
||||
try:
|
||||
# Clean up completed tasks
|
||||
done_tasks = {t for t in _active_tasks if t.done()}
|
||||
for task in done_tasks:
|
||||
_active_tasks.discard(task)
|
||||
# Check for exceptions
|
||||
try:
|
||||
task.result()
|
||||
except Exception as e:
|
||||
logger.error(f"[JobWorker] Task failed: {e}")
|
||||
|
||||
# Check if we have capacity for more jobs
|
||||
active_count = len(_active_tasks)
|
||||
available_slots = max_concurrent - active_count
|
||||
|
||||
if available_slots > 0:
|
||||
# Get next pending job
|
||||
job = await job_queue.get_next_pending()
|
||||
|
||||
if job:
|
||||
consecutive_errors = 0
|
||||
# Launch job processing as background task
|
||||
task = asyncio.create_task(_process_job_with_semaphore(job))
|
||||
_active_tasks.add(task)
|
||||
logger.debug(f"[JobWorker] Launched job {job.id} (active={len(_active_tasks)}/{max_concurrent})")
|
||||
else:
|
||||
# No pending jobs - wait briefly
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
_shutdown_event.wait(),
|
||||
timeout=POLL_INTERVAL_SECONDS
|
||||
)
|
||||
if _shutdown_event.is_set():
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
else:
|
||||
# At capacity - wait for a slot to free up
|
||||
await asyncio.sleep(POLL_INTERVAL_SECONDS)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("[JobWorker] Worker loop cancelled")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
consecutive_errors += 1
|
||||
logger.error(f"[JobWorker] Error in worker loop ({consecutive_errors}/{max_consecutive_errors}): {e}")
|
||||
|
||||
if consecutive_errors >= max_consecutive_errors:
|
||||
logger.error("[JobWorker] Too many consecutive errors, stopping worker")
|
||||
break
|
||||
|
||||
await asyncio.sleep(min(consecutive_errors * 2, 30))
|
||||
|
||||
# Wait for active tasks to complete on shutdown
|
||||
if _active_tasks:
|
||||
logger.info(f"[JobWorker] Waiting for {len(_active_tasks)} active tasks to complete...")
|
||||
await asyncio.gather(*_active_tasks, return_exceptions=True)
|
||||
|
||||
logger.info("[JobWorker] Worker loop stopped")
|
||||
|
||||
|
||||
async def _process_job_with_semaphore(job: OCRJob) -> None:
|
||||
"""
|
||||
Process job with semaphore to limit concurrency.
|
||||
|
||||
Acquires semaphore before processing, releases after.
|
||||
This ensures we don't exceed OCR_WORKERS concurrent jobs.
|
||||
"""
|
||||
global _concurrency_semaphore
|
||||
|
||||
async with _concurrency_semaphore:
|
||||
await _process_job(job)
|
||||
|
||||
|
||||
async def _process_job(job: OCRJob) -> None:
|
||||
"""
|
||||
Process a single OCR job.
|
||||
|
||||
Reads file, submits to worker pool, updates job status,
|
||||
and saves metrics for analytics.
|
||||
|
||||
Args:
|
||||
job: OCRJob to process
|
||||
"""
|
||||
logger.info(f"[JobWorker] Processing job {job.id}: engine={job.engine}, file={Path(job.file_path).name}")
|
||||
start_time = time.time()
|
||||
file_size = 0
|
||||
file_type = "image/jpeg"
|
||||
|
||||
try:
|
||||
# Note: Job already marked as 'processing' atomically in get_next_pending()
|
||||
|
||||
# Read file bytes
|
||||
file_path = Path(job.file_path)
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
with open(file_path, 'rb') as f:
|
||||
file_bytes = f.read()
|
||||
|
||||
file_size = len(file_bytes)
|
||||
# Determine file type from job or extension
|
||||
file_type = getattr(job, 'mime_type', 'image/jpeg') or 'image/jpeg'
|
||||
|
||||
# Submit to worker pool
|
||||
result = await ocr_worker_pool.submit_task(
|
||||
image_bytes=file_bytes,
|
||||
engine=job.engine,
|
||||
preprocessing="auto",
|
||||
timeout=OCR_TIMEOUT_SECONDS
|
||||
)
|
||||
|
||||
elapsed_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
if result.get("success"):
|
||||
# Job completed successfully
|
||||
extraction = result.get("extraction", {})
|
||||
|
||||
# Include raw_texts for analysis (from all OCR engine passes)
|
||||
extraction['raw_texts'] = result.get("raw_texts", [])
|
||||
|
||||
# Extract actual OCR processing time from extraction result
|
||||
ocr_time_ms = extraction.get('processing_time_ms', 0)
|
||||
|
||||
# Debug: log suggested_payment_mode
|
||||
spm = extraction.get('suggested_payment_mode')
|
||||
logger.info(f"[JobWorker] Job {job.id} extraction has suggested_payment_mode={spm}")
|
||||
|
||||
await job_queue.update_status(
|
||||
job_id=job.id,
|
||||
status=OCRJobStatus.completed,
|
||||
result=extraction,
|
||||
processing_time_ms=elapsed_ms,
|
||||
ocr_time_ms=ocr_time_ms
|
||||
)
|
||||
|
||||
logger.info(f"[JobWorker] Job {job.id} completed in {elapsed_ms}ms (ocr: {ocr_time_ms}ms)")
|
||||
|
||||
# Save metrics for successful job
|
||||
await _save_job_metrics(
|
||||
job_id=job.id,
|
||||
username=job.created_by or 'unknown',
|
||||
engine_requested=job.engine,
|
||||
engine_used=extraction.get('ocr_engine', job.engine),
|
||||
processing_time_ms=elapsed_ms,
|
||||
file_size_bytes=file_size,
|
||||
file_type=file_type,
|
||||
original_filename=job.original_filename,
|
||||
success=True,
|
||||
overall_confidence=extraction.get('overall_confidence', 0.0),
|
||||
fields_extracted=_count_extracted_fields(extraction),
|
||||
needs_manual_review=extraction.get('needs_manual_review'),
|
||||
validation_warnings_count=len(extraction.get('validation_warnings', [])),
|
||||
validation_errors_count=len(extraction.get('validation_errors', [])),
|
||||
)
|
||||
|
||||
# Auto-save receipt for batch jobs
|
||||
if job.batch_id:
|
||||
auto_save_result = await _auto_save_batch_receipt(
|
||||
job=job,
|
||||
extraction=extraction,
|
||||
file_path=str(file_path)
|
||||
)
|
||||
if not auto_save_result:
|
||||
# Auto-save failed - mark job as failed
|
||||
# Note: job_queue status already updated to 'completed' above
|
||||
# We need to update it back to failed with the auto-save error
|
||||
logger.warning(
|
||||
f"[JobWorker] Job {job.id} OCR succeeded but auto-save failed"
|
||||
)
|
||||
|
||||
else:
|
||||
# Job failed
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
|
||||
await job_queue.update_status(
|
||||
job_id=job.id,
|
||||
status=OCRJobStatus.failed,
|
||||
error=error_msg,
|
||||
processing_time_ms=elapsed_ms
|
||||
)
|
||||
|
||||
logger.warning(f"[JobWorker] Job {job.id} failed after {elapsed_ms}ms: {error_msg}")
|
||||
|
||||
# Save metrics for failed job
|
||||
await _save_job_metrics(
|
||||
job_id=job.id,
|
||||
username=job.created_by or 'unknown',
|
||||
engine_requested=job.engine,
|
||||
engine_used=job.engine,
|
||||
processing_time_ms=elapsed_ms,
|
||||
file_size_bytes=file_size,
|
||||
file_type=file_type,
|
||||
original_filename=job.original_filename,
|
||||
success=False,
|
||||
error_message=error_msg,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
elapsed_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
logger.error(f"[JobWorker] Job {job.id} error after {elapsed_ms}ms: {e}")
|
||||
|
||||
await job_queue.update_status(
|
||||
job_id=job.id,
|
||||
status=OCRJobStatus.failed,
|
||||
error=str(e),
|
||||
processing_time_ms=elapsed_ms
|
||||
)
|
||||
|
||||
# Save metrics for error job
|
||||
await _save_job_metrics(
|
||||
job_id=job.id,
|
||||
username=job.created_by or 'unknown',
|
||||
engine_requested=job.engine,
|
||||
engine_used=job.engine,
|
||||
processing_time_ms=elapsed_ms,
|
||||
file_size_bytes=file_size,
|
||||
file_type=file_type,
|
||||
original_filename=job.original_filename,
|
||||
success=False,
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
finally:
|
||||
# Cleanup file after processing
|
||||
try:
|
||||
await job_queue.cleanup_job_file(job.id)
|
||||
except Exception as e:
|
||||
logger.warning(f"[JobWorker] Failed to cleanup file for job {job.id}: {e}")
|
||||
|
||||
|
||||
async def _cleanup_loop() -> None:
|
||||
"""
|
||||
Periodic cleanup of expired jobs.
|
||||
|
||||
Runs every hour to delete jobs older than 24 hours.
|
||||
"""
|
||||
global _shutdown_event
|
||||
|
||||
logger.info("[JobWorker] Starting cleanup loop...")
|
||||
|
||||
while not _shutdown_event.is_set():
|
||||
try:
|
||||
# Wait for interval or shutdown
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
_shutdown_event.wait(),
|
||||
timeout=CLEANUP_INTERVAL_SECONDS
|
||||
)
|
||||
if _shutdown_event.is_set():
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
pass # Normal timeout, do cleanup
|
||||
|
||||
# Run cleanup
|
||||
deleted = await job_queue.cleanup_expired()
|
||||
if deleted > 0:
|
||||
logger.info(f"[JobWorker] Cleanup: deleted {deleted} expired jobs")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("[JobWorker] Cleanup loop cancelled")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[JobWorker] Cleanup error: {e}")
|
||||
await asyncio.sleep(60) # Retry after 1 minute
|
||||
|
||||
logger.info("[JobWorker] Cleanup loop stopped")
|
||||
|
||||
|
||||
async def start_job_worker() -> bool:
|
||||
"""
|
||||
Start the job worker background task.
|
||||
|
||||
Called at FastAPI startup to begin processing queue.
|
||||
|
||||
Returns:
|
||||
True if started successfully
|
||||
"""
|
||||
global _job_worker_task, _cleanup_task, _shutdown_event
|
||||
|
||||
if _job_worker_task is not None and not _job_worker_task.done():
|
||||
logger.warning("[JobWorker] Already running")
|
||||
return True
|
||||
|
||||
try:
|
||||
# Initialize job queue
|
||||
await job_queue.initialize()
|
||||
|
||||
# Initialize worker pool
|
||||
if not ocr_worker_pool.initialize():
|
||||
logger.error("[JobWorker] Failed to initialize worker pool")
|
||||
return False
|
||||
|
||||
# Pre-warm worker pool in BACKGROUND (don't block startup)
|
||||
# First OCR request may be slower if prewarm isn't done yet
|
||||
async def _background_prewarm():
|
||||
logger.info("[JobWorker] Pre-warming OCR worker pool (background)...")
|
||||
warmup_success = await ocr_worker_pool.prewarm(timeout=90.0)
|
||||
if warmup_success:
|
||||
logger.info("[JobWorker] OCR worker pool pre-warmed successfully")
|
||||
else:
|
||||
logger.warning("[JobWorker] Worker pool pre-warm failed, first request will be slower")
|
||||
|
||||
asyncio.create_task(_background_prewarm())
|
||||
|
||||
# Start worker loop
|
||||
_shutdown_event = asyncio.Event()
|
||||
_job_worker_task = asyncio.create_task(_job_worker_loop())
|
||||
|
||||
# Start cleanup loop
|
||||
_cleanup_task = asyncio.create_task(_cleanup_loop())
|
||||
|
||||
logger.info("[JobWorker] Started successfully")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[JobWorker] Failed to start: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def stop_job_worker() -> None:
|
||||
"""
|
||||
Stop the job worker background task.
|
||||
|
||||
Called at FastAPI shutdown to gracefully stop processing.
|
||||
"""
|
||||
global _job_worker_task, _cleanup_task, _shutdown_event
|
||||
|
||||
logger.info("[JobWorker] Stopping...")
|
||||
|
||||
# Signal shutdown
|
||||
if _shutdown_event:
|
||||
_shutdown_event.set()
|
||||
|
||||
# Cancel worker task
|
||||
if _job_worker_task and not _job_worker_task.done():
|
||||
_job_worker_task.cancel()
|
||||
try:
|
||||
await _job_worker_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Cancel cleanup task
|
||||
if _cleanup_task and not _cleanup_task.done():
|
||||
_cleanup_task.cancel()
|
||||
try:
|
||||
await _cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Shutdown worker pool
|
||||
ocr_worker_pool.shutdown(wait=True)
|
||||
|
||||
_job_worker_task = None
|
||||
_cleanup_task = None
|
||||
_shutdown_event = None
|
||||
|
||||
logger.info("[JobWorker] Stopped")
|
||||
|
||||
|
||||
def is_running() -> bool:
|
||||
"""Check if job worker is running."""
|
||||
return _job_worker_task is not None and not _job_worker_task.done()
|
||||
|
||||
|
||||
def estimate_wait_time(queue_position: int) -> int:
|
||||
"""
|
||||
Estimate wait time for a job in queue.
|
||||
|
||||
Args:
|
||||
queue_position: Position in queue (1 = next)
|
||||
|
||||
Returns:
|
||||
Estimated wait time in seconds
|
||||
"""
|
||||
if queue_position <= 0:
|
||||
return 0
|
||||
|
||||
# Get average processing time (synchronous fallback)
|
||||
# Default ~7 seconds per job if no data
|
||||
avg_time = 7.0
|
||||
|
||||
try:
|
||||
# Try to get from queue stats
|
||||
import asyncio
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# Can't use sync call in async context, use default
|
||||
pass
|
||||
else:
|
||||
avg_time = loop.run_until_complete(job_queue.get_average_processing_time())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Estimate: position * average_time
|
||||
return int(queue_position * avg_time)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Metrics Helper Functions
|
||||
# ============================================================================
|
||||
|
||||
async def _save_job_metrics(
|
||||
job_id: str,
|
||||
username: str,
|
||||
engine_requested: str,
|
||||
engine_used: str,
|
||||
processing_time_ms: int = 0,
|
||||
file_size_bytes: int = 0,
|
||||
file_type: str = "image/jpeg",
|
||||
original_filename: Optional[str] = None,
|
||||
success: bool = True,
|
||||
error_message: Optional[str] = None,
|
||||
overall_confidence: float = 0.0,
|
||||
fields_extracted: int = 0,
|
||||
needs_manual_review: Optional[bool] = None,
|
||||
validation_warnings_count: int = 0,
|
||||
validation_errors_count: int = 0,
|
||||
) -> None:
|
||||
"""
|
||||
Save OCR job metrics to database for analytics.
|
||||
|
||||
Called after each job completes (success or failure).
|
||||
Errors are logged but don't affect job processing.
|
||||
"""
|
||||
try:
|
||||
from backend.modules.data_entry.db.database import get_db_session
|
||||
from backend.modules.data_entry.db.crud.ocr_settings import OCRMetricsCRUD
|
||||
|
||||
async with await get_db_session() as session:
|
||||
await OCRMetricsCRUD.create(
|
||||
session=session,
|
||||
job_id=job_id,
|
||||
username=username,
|
||||
engine_requested=engine_requested,
|
||||
engine_used=engine_used,
|
||||
processing_time_ms=processing_time_ms,
|
||||
file_size_bytes=file_size_bytes,
|
||||
file_type=file_type,
|
||||
original_filename=original_filename,
|
||||
success=success,
|
||||
error_message=error_message,
|
||||
overall_confidence=overall_confidence,
|
||||
fields_extracted=fields_extracted,
|
||||
needs_manual_review=needs_manual_review,
|
||||
validation_warnings_count=validation_warnings_count,
|
||||
validation_errors_count=validation_errors_count,
|
||||
)
|
||||
logger.debug(f"[JobWorker] Saved metrics for job {job_id}")
|
||||
|
||||
except Exception as e:
|
||||
# Log but don't fail - metrics are nice-to-have
|
||||
logger.warning(f"[JobWorker] Failed to save metrics for job {job_id}: {e}")
|
||||
|
||||
|
||||
def _count_extracted_fields(extraction: dict) -> int:
|
||||
"""
|
||||
Count number of successfully extracted fields from OCR result.
|
||||
|
||||
Counts non-None values in key fields.
|
||||
"""
|
||||
key_fields = [
|
||||
'receipt_number',
|
||||
'receipt_date',
|
||||
'amount',
|
||||
'partner_name',
|
||||
'cui',
|
||||
'tva_total',
|
||||
'address',
|
||||
'items_count',
|
||||
]
|
||||
|
||||
count = 0
|
||||
for field in key_fields:
|
||||
value = extraction.get(field)
|
||||
if value is not None and value != '' and value != []:
|
||||
count += 1
|
||||
|
||||
# Also count TVA entries if present
|
||||
tva_entries = extraction.get('tva_entries', [])
|
||||
if tva_entries and len(tva_entries) > 0:
|
||||
count += 1
|
||||
|
||||
# Count payment methods if present
|
||||
payment_methods = extraction.get('payment_methods', [])
|
||||
if payment_methods and len(payment_methods) > 0:
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Auto-Save Batch Receipt Helper
|
||||
# ============================================================================
|
||||
|
||||
async def _auto_save_batch_receipt(
|
||||
job: OCRJob,
|
||||
extraction: dict,
|
||||
file_path: str
|
||||
) -> bool:
|
||||
"""
|
||||
Automatically create a receipt from OCR result for batch jobs.
|
||||
|
||||
Called when a batch job completes successfully. Creates the receipt,
|
||||
attachment, and accounting entries using ReceiptAutoCreateService.
|
||||
|
||||
Args:
|
||||
job: Completed OCRJob with batch_id set
|
||||
extraction: OCR extraction result dict
|
||||
file_path: Path to the original uploaded file
|
||||
|
||||
Returns:
|
||||
True if receipt created successfully, False otherwise
|
||||
"""
|
||||
if not job.batch_id:
|
||||
return True # Not a batch job, nothing to do
|
||||
|
||||
logger.info(f"[JobWorker] Auto-saving receipt for batch job {job.id} (batch_id={job.batch_id})")
|
||||
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from backend.modules.data_entry.db.database import get_db_session
|
||||
from backend.modules.data_entry.db.models import BatchUpload
|
||||
from backend.modules.data_entry.services.receipt_auto_create import ReceiptAutoCreateService
|
||||
from sqlalchemy import select
|
||||
|
||||
# Convert extraction dict to ExtractionData schema
|
||||
ocr_result = ExtractionData(**extraction)
|
||||
|
||||
async with await get_db_session() as session:
|
||||
# Get batch info to retrieve company_id and user_id
|
||||
batch_result = await session.execute(
|
||||
select(BatchUpload).where(BatchUpload.id == job.batch_id)
|
||||
)
|
||||
batch = batch_result.scalar_one_or_none()
|
||||
|
||||
if not batch:
|
||||
error_msg = f"Batch {job.batch_id} not found"
|
||||
logger.error(f"[JobWorker] Auto-save failed for job {job.id}: {error_msg}")
|
||||
await job_queue.update_status(
|
||||
job_id=job.id,
|
||||
status=OCRJobStatus.failed,
|
||||
error=f"Auto-save error: {error_msg}"
|
||||
)
|
||||
return False
|
||||
|
||||
# Call ReceiptAutoCreateService
|
||||
result = await ReceiptAutoCreateService.create_from_ocr_result(
|
||||
session=session,
|
||||
job_id=job.id,
|
||||
ocr_result=ocr_result,
|
||||
username=job.created_by or batch.user_id,
|
||||
batch_id=job.batch_id,
|
||||
company_id=batch.company_id,
|
||||
file_path=file_path,
|
||||
original_filename=job.original_filename,
|
||||
file_hash=job.file_hash # Pass file_hash for duplicate detection (US-007)
|
||||
)
|
||||
|
||||
if result.success:
|
||||
logger.info(
|
||||
f"[JobWorker] Auto-save successful for job {job.id}: "
|
||||
f"receipt_id={result.receipt_id}"
|
||||
)
|
||||
return True
|
||||
else:
|
||||
error_msg = result.error_message or "Unknown error"
|
||||
logger.warning(
|
||||
f"[JobWorker] Auto-save validation failed for job {job.id}: {error_msg}"
|
||||
)
|
||||
# Update job status to failed with the auto-save error
|
||||
await job_queue.update_status(
|
||||
job_id=job.id,
|
||||
status=OCRJobStatus.failed,
|
||||
error=f"Auto-save error: {error_msg}"
|
||||
)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f"[JobWorker] Auto-save exception for job {job.id}: {error_msg}")
|
||||
|
||||
# Update job status to failed
|
||||
try:
|
||||
await job_queue.update_status(
|
||||
job_id=job.id,
|
||||
status=OCRJobStatus.failed,
|
||||
error=f"Auto-save error: {error_msg}"
|
||||
)
|
||||
except Exception as update_err:
|
||||
logger.error(f"[JobWorker] Failed to update job status after auto-save error: {update_err}")
|
||||
|
||||
return False
|
||||
@@ -0,0 +1,561 @@
|
||||
"""
|
||||
OCR Worker Pool Manager
|
||||
|
||||
Manages a ProcessPoolExecutor with persistent OCR engine initialization.
|
||||
Key features:
|
||||
- ProcessPoolExecutor with configurable max_workers (from OCR_WORKERS env)
|
||||
- Configurable max_tasks_per_child (from OCR_MAX_TASKS_PER_CHILD env, 0=no restart)
|
||||
- mp_context='spawn' for Windows IIS compatibility
|
||||
- docTR/PaddleOCR loaded ONCE at worker spawn (not 30s per request)
|
||||
- atexit + signal handlers for cleanup
|
||||
- Health check with auto-respawn
|
||||
- Orphan process cleanup on Windows
|
||||
|
||||
Architecture:
|
||||
Main Process │ Worker Process (PERSISTENT)
|
||||
──────────────────────│──────────────────────────────────
|
||||
OCRWorkerPool │ Worker initialized once
|
||||
↓ │ ↓
|
||||
submit_task() ────────│────→ process_ocr()
|
||||
↓ │ ↓
|
||||
Future.result() ←─────│──── Return result
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
import gc
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import ProcessPoolExecutor, Future, ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Try to import psutil for orphan process cleanup
|
||||
try:
|
||||
import psutil
|
||||
PSUTIL_AVAILABLE = True
|
||||
except ImportError:
|
||||
PSUTIL_AVAILABLE = False
|
||||
logger.warning("[OCRWorkerPool] psutil not available - orphan cleanup disabled")
|
||||
|
||||
|
||||
class OCRWorkerPool:
|
||||
"""
|
||||
Singleton manager for OCR ProcessPoolExecutor.
|
||||
|
||||
Ensures OCR engines are loaded once and reused for all requests.
|
||||
Uses max_tasks_per_child=5 to restart worker every 5 tasks (prevents memory leak).
|
||||
"""
|
||||
|
||||
_instance: Optional["OCRWorkerPool"] = None
|
||||
_initialized: bool = False
|
||||
|
||||
def __new__(cls) -> "OCRWorkerPool":
|
||||
"""Singleton pattern - only one pool instance."""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize worker pool (runs only once due to singleton)."""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._executor: Optional[ProcessPoolExecutor] = None
|
||||
self._worker_pid: Optional[int] = None
|
||||
self._is_warming: bool = False
|
||||
self._is_shutdown: bool = False
|
||||
self._lock = asyncio.Lock() if asyncio.get_event_loop_policy() else None
|
||||
self._sync_lock = mp.Lock()
|
||||
|
||||
# Register cleanup handlers
|
||||
# NOTE: Only use atexit, NOT signal handlers!
|
||||
# Signal handlers interfere with FastAPI's shutdown handling.
|
||||
# FastAPI's shutdown event calls stop_job_worker() which calls shutdown().
|
||||
atexit.register(self._cleanup_on_exit)
|
||||
|
||||
self._initialized = True
|
||||
logger.info("[OCRWorkerPool] Singleton instance created")
|
||||
|
||||
def initialize(self) -> bool:
|
||||
"""
|
||||
Initialize the ProcessPoolExecutor.
|
||||
|
||||
Creates executor with spawn context for Windows compatibility.
|
||||
Uses max_tasks_per_child=5 to restart worker periodically (prevents memory leak).
|
||||
|
||||
Returns:
|
||||
True if initialization successful
|
||||
"""
|
||||
if self._executor is not None:
|
||||
logger.warning("[OCRWorkerPool] Already initialized")
|
||||
return True
|
||||
|
||||
if self._is_shutdown:
|
||||
logger.error("[OCRWorkerPool] Cannot initialize - pool is shutdown")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Cleanup any orphan workers from previous runs
|
||||
self._cleanup_orphan_workers()
|
||||
|
||||
# Read configuration from environment
|
||||
max_workers = int(os.getenv('OCR_WORKERS', '2'))
|
||||
max_tasks_raw = os.getenv('OCR_MAX_TASKS_PER_CHILD', '0')
|
||||
# 0 means no restart (None in ProcessPoolExecutor)
|
||||
max_tasks_per_child = int(max_tasks_raw) if max_tasks_raw and int(max_tasks_raw) > 0 else None
|
||||
|
||||
# Create executor with spawn context (Windows compatible)
|
||||
# Use mp_context='spawn' explicitly for cross-platform consistency
|
||||
mp_context = mp.get_context('spawn')
|
||||
|
||||
# max_tasks_per_child only available in Python 3.11+
|
||||
executor_kwargs = {
|
||||
'max_workers': max_workers,
|
||||
'mp_context': mp_context,
|
||||
'initializer': _worker_initializer,
|
||||
}
|
||||
if sys.version_info >= (3, 11) and max_tasks_per_child is not None:
|
||||
executor_kwargs['max_tasks_per_child'] = max_tasks_per_child
|
||||
else:
|
||||
logger.info(f"[OCRWorkerPool] max_tasks_per_child not supported (Python {sys.version_info.major}.{sys.version_info.minor})")
|
||||
|
||||
self._executor = ProcessPoolExecutor(**executor_kwargs)
|
||||
|
||||
logger.info(f"[OCRWorkerPool] ProcessPoolExecutor created (spawn context, max_workers={max_workers}, max_tasks_per_child={max_tasks_per_child})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[OCRWorkerPool] Initialization failed: {e}")
|
||||
return False
|
||||
|
||||
async def prewarm(self, timeout: float = 60.0) -> bool:
|
||||
"""
|
||||
Pre-warm the worker by loading PaddleOCR before first request.
|
||||
|
||||
This is called at FastAPI startup to avoid 30s delay on first request.
|
||||
Submits a dummy task that triggers PaddleOCR initialization.
|
||||
|
||||
Args:
|
||||
timeout: Maximum seconds to wait for warmup (default 60s)
|
||||
|
||||
Returns:
|
||||
True if warmup successful, False if timeout or error
|
||||
"""
|
||||
if self._executor is None:
|
||||
logger.error("[OCRWorkerPool] Cannot prewarm - not initialized")
|
||||
return False
|
||||
|
||||
if self._is_warming:
|
||||
logger.warning("[OCRWorkerPool] Already warming up")
|
||||
return False
|
||||
|
||||
self._is_warming = True
|
||||
logger.info("[OCRWorkerPool] Starting pre-warm (loading PaddleOCR in worker)...")
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Submit warmup task that initializes PaddleOCR
|
||||
loop = asyncio.get_event_loop()
|
||||
future = self._executor.submit(_warmup_task)
|
||||
|
||||
# Wait with timeout
|
||||
result = await loop.run_in_executor(None, future.result, timeout)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
if result.get("success"):
|
||||
logger.info(f"[OCRWorkerPool] Pre-warm complete in {elapsed:.1f}s - PaddleOCR ready")
|
||||
self._worker_pid = result.get("pid")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"[OCRWorkerPool] Pre-warm failed: {result.get('error')}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
elapsed = time.time() - start_time
|
||||
logger.error(f"[OCRWorkerPool] Pre-warm failed after {elapsed:.1f}s: {e}")
|
||||
return False
|
||||
finally:
|
||||
self._is_warming = False
|
||||
|
||||
async def submit_task(
|
||||
self,
|
||||
image_bytes: bytes,
|
||||
engine: str = "doctr_plus",
|
||||
preprocessing: str = "auto",
|
||||
timeout: float = 120.0
|
||||
) -> dict:
|
||||
"""
|
||||
Submit OCR task to worker process.
|
||||
|
||||
Args:
|
||||
image_bytes: Raw image bytes
|
||||
engine: OCR engine ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
|
||||
preprocessing: Preprocessing mode ('light', 'medium', 'heavy', 'auto')
|
||||
timeout: Maximum processing time in seconds
|
||||
|
||||
Returns:
|
||||
Dict with extraction results
|
||||
|
||||
Raises:
|
||||
RuntimeError: If pool not initialized or task fails
|
||||
"""
|
||||
if self._executor is None:
|
||||
raise RuntimeError("OCR worker pool not initialized")
|
||||
|
||||
if self._is_shutdown:
|
||||
raise RuntimeError("OCR worker pool is shutdown")
|
||||
|
||||
logger.info(f"[OCRWorkerPool] Submitting task: engine={engine}, preprocessing={preprocessing}, size={len(image_bytes)} bytes")
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
future = self._executor.submit(
|
||||
_process_ocr_task,
|
||||
image_bytes,
|
||||
engine,
|
||||
preprocessing
|
||||
)
|
||||
|
||||
# Wait for result with timeout
|
||||
result = await loop.run_in_executor(None, future.result, timeout)
|
||||
|
||||
logger.info(f"[OCRWorkerPool] Task complete: success={result.get('success')}")
|
||||
return result
|
||||
|
||||
except TimeoutError:
|
||||
logger.error(f"[OCRWorkerPool] Task timed out after {timeout}s")
|
||||
raise RuntimeError(f"OCR processing timed out after {timeout}s")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[OCRWorkerPool] Task failed: {e}")
|
||||
raise RuntimeError(f"OCR processing failed: {e}")
|
||||
|
||||
def is_healthy(self) -> bool:
|
||||
"""
|
||||
Check if worker pool is healthy.
|
||||
|
||||
Returns:
|
||||
True if pool is ready to accept tasks
|
||||
"""
|
||||
if self._executor is None:
|
||||
return False
|
||||
if self._is_shutdown:
|
||||
return False
|
||||
|
||||
# Check if worker process is still alive
|
||||
if self._worker_pid and PSUTIL_AVAILABLE:
|
||||
try:
|
||||
proc = psutil.Process(self._worker_pid)
|
||||
if not proc.is_running():
|
||||
logger.warning("[OCRWorkerPool] Worker process died, needs respawn")
|
||||
return False
|
||||
except psutil.NoSuchProcess:
|
||||
logger.warning("[OCRWorkerPool] Worker process not found")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def shutdown(self, wait: bool = True, timeout: float = 10.0) -> None:
|
||||
"""
|
||||
Shutdown the worker pool gracefully.
|
||||
|
||||
Args:
|
||||
wait: Wait for pending tasks to complete
|
||||
timeout: Maximum wait time in seconds
|
||||
"""
|
||||
if self._executor is None:
|
||||
return
|
||||
|
||||
logger.info("[OCRWorkerPool] Shutting down...")
|
||||
self._is_shutdown = True
|
||||
|
||||
try:
|
||||
self._executor.shutdown(wait=wait, cancel_futures=True)
|
||||
logger.info("[OCRWorkerPool] Executor shutdown complete")
|
||||
except Exception as e:
|
||||
logger.error(f"[OCRWorkerPool] Shutdown error: {e}")
|
||||
|
||||
self._executor = None
|
||||
self._worker_pid = None
|
||||
|
||||
# Final orphan cleanup
|
||||
self._cleanup_orphan_workers()
|
||||
logger.info("[OCRWorkerPool] Shutdown complete")
|
||||
|
||||
def _cleanup_orphan_workers(self) -> int:
|
||||
"""
|
||||
Clean up orphan Python processes from previous runs.
|
||||
|
||||
On Windows with NSSM, orphan processes may remain after service restart.
|
||||
This finds and kills any python.exe processes that were OCR workers.
|
||||
|
||||
Returns:
|
||||
Number of processes killed
|
||||
"""
|
||||
if not PSUTIL_AVAILABLE:
|
||||
return 0
|
||||
|
||||
killed = 0
|
||||
current_pid = os.getpid()
|
||||
|
||||
try:
|
||||
for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
|
||||
try:
|
||||
# Skip self
|
||||
if proc.pid == current_pid:
|
||||
continue
|
||||
|
||||
# Look for Python processes with OCR-related cmdline
|
||||
if proc.name().lower() in ('python.exe', 'python3.exe', 'python', 'python3'):
|
||||
cmdline = ' '.join(proc.cmdline() or [])
|
||||
|
||||
# Check if this is an OCR worker process
|
||||
if 'ocr_worker_process' in cmdline.lower() or 'process_ocr_task' in cmdline.lower():
|
||||
logger.warning(f"[OCRWorkerPool] Killing orphan worker: PID={proc.pid}")
|
||||
proc.kill()
|
||||
proc.wait(timeout=5)
|
||||
killed += 1
|
||||
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[OCRWorkerPool] Orphan cleanup error: {e}")
|
||||
|
||||
if killed > 0:
|
||||
logger.info(f"[OCRWorkerPool] Cleaned up {killed} orphan worker(s)")
|
||||
|
||||
return killed
|
||||
|
||||
def _cleanup_on_exit(self) -> None:
|
||||
"""atexit handler for cleanup."""
|
||||
logger.info("[OCRWorkerPool] atexit cleanup triggered")
|
||||
self.shutdown(wait=False)
|
||||
|
||||
def _signal_handler(self, signum: int, frame: Any) -> None:
|
||||
"""Signal handler for SIGTERM/SIGINT."""
|
||||
logger.info(f"[OCRWorkerPool] Received signal {signum}, shutting down...")
|
||||
self.shutdown(wait=False)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# WORKER PROCESS FUNCTIONS
|
||||
# ============================================================================
|
||||
# These functions run in the child process, not the main FastAPI process.
|
||||
|
||||
# Global engines - persist between tasks in worker process
|
||||
_paddle_engine = None
|
||||
_tesseract_engine = None
|
||||
_doctr_engine = None # docTR engine (PyTorch backend)
|
||||
_worker_initialized = False
|
||||
|
||||
|
||||
def _worker_initializer() -> None:
|
||||
"""
|
||||
Called once when worker process spawns.
|
||||
|
||||
Initializes global OCR engines IN PARALLEL for faster startup.
|
||||
Uses ThreadPoolExecutor to load enabled engines concurrently.
|
||||
Respects OCR_ENABLE_PADDLEOCR and OCR_ENABLE_TESSERACT from .env.
|
||||
|
||||
Total warmup time = max(engine_times) instead of sum(engine_times).
|
||||
"""
|
||||
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
|
||||
|
||||
if _worker_initialized:
|
||||
print(f"[Worker {os.getpid()}] Already initialized", flush=True)
|
||||
return
|
||||
|
||||
# Check which engines are enabled via .env
|
||||
paddle_enabled = os.getenv("OCR_ENABLE_PADDLEOCR", "true").lower() == "true"
|
||||
tesseract_enabled = os.getenv("OCR_ENABLE_TESSERACT", "true").lower() == "true"
|
||||
|
||||
enabled_engines = ["doctr"] # docTR is always loaded (primary engine)
|
||||
if paddle_enabled:
|
||||
enabled_engines.append("paddle")
|
||||
if tesseract_enabled:
|
||||
enabled_engines.append("tesseract")
|
||||
|
||||
print(f"[Worker {os.getpid()}] Initializing OCR engines: {enabled_engines}", flush=True)
|
||||
if not paddle_enabled:
|
||||
print(f"[Worker {os.getpid()}] PaddleOCR DISABLED - saving ~800MB RAM", flush=True)
|
||||
if not tesseract_enabled:
|
||||
print(f"[Worker {os.getpid()}] Tesseract DISABLED - saving ~50MB RAM", flush=True)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Define loader functions - each runs in its own thread
|
||||
def load_doctr():
|
||||
try:
|
||||
from backend.modules.data_entry.services.ocr.ocr_worker_process import initialize_doctr_engine
|
||||
engine = initialize_doctr_engine()
|
||||
return ("doctr", engine, None)
|
||||
except Exception as e:
|
||||
return ("doctr", None, str(e))
|
||||
|
||||
def load_paddle():
|
||||
if not paddle_enabled:
|
||||
return ("paddle", None, "disabled via OCR_ENABLE_PADDLEOCR=false")
|
||||
try:
|
||||
from backend.modules.data_entry.services.ocr.ocr_worker_process import initialize_paddle_engine
|
||||
engine = initialize_paddle_engine()
|
||||
return ("paddle", engine, None)
|
||||
except Exception as e:
|
||||
return ("paddle", None, str(e))
|
||||
|
||||
def load_tesseract():
|
||||
if not tesseract_enabled:
|
||||
return ("tesseract", None, "disabled via OCR_ENABLE_TESSERACT=false")
|
||||
try:
|
||||
from backend.modules.data_entry.services.ocr.tesseract_engine import TesseractEngine
|
||||
engine = TesseractEngine()
|
||||
return ("tesseract", engine, None)
|
||||
except Exception as e:
|
||||
return ("tesseract", None, str(e))
|
||||
|
||||
# Build list of futures for enabled engines only
|
||||
futures_to_submit = [load_doctr] # docTR always loaded
|
||||
if paddle_enabled:
|
||||
futures_to_submit.append(load_paddle)
|
||||
if tesseract_enabled:
|
||||
futures_to_submit.append(load_tesseract)
|
||||
|
||||
# Load engines in parallel using ThreadPoolExecutor
|
||||
with ThreadPoolExecutor(max_workers=len(futures_to_submit)) as executor:
|
||||
futures = [executor.submit(fn) for fn in futures_to_submit]
|
||||
|
||||
for future in as_completed(futures):
|
||||
name, engine, error = future.result()
|
||||
if error and "disabled" not in error:
|
||||
print(f"[Worker {os.getpid()}] {name} init failed: {error}", flush=True)
|
||||
elif engine:
|
||||
print(f"[Worker {os.getpid()}] {name} loaded", flush=True)
|
||||
if name == "doctr":
|
||||
_doctr_engine = engine
|
||||
elif name == "paddle":
|
||||
_paddle_engine = engine
|
||||
elif name == "tesseract":
|
||||
_tesseract_engine = engine
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
_worker_initialized = True
|
||||
print(f"[Worker {os.getpid()}] Initialization complete in {elapsed:.1f}s (engines: {enabled_engines})", flush=True)
|
||||
|
||||
|
||||
def _warmup_task() -> dict:
|
||||
"""
|
||||
Warmup task that ensures engines are loaded.
|
||||
|
||||
Called at FastAPI startup to pre-warm the worker.
|
||||
Returns success status and worker PID.
|
||||
"""
|
||||
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
|
||||
|
||||
try:
|
||||
# Ensure initialization
|
||||
if not _worker_initialized:
|
||||
_worker_initializer()
|
||||
|
||||
# Quick test - create a small dummy image
|
||||
import numpy as np
|
||||
dummy_img = np.ones((100, 100, 3), dtype=np.uint8) * 255
|
||||
|
||||
# Test docTR if available (fastest engine)
|
||||
if _doctr_engine is not None:
|
||||
try:
|
||||
_doctr_engine([dummy_img])
|
||||
print(f"[Worker {os.getpid()}] docTR warmup OK", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[Worker {os.getpid()}] docTR warmup error: {e}", flush=True)
|
||||
|
||||
# Test PaddleOCR if available
|
||||
if _paddle_engine is not None:
|
||||
try:
|
||||
_paddle_engine.predict(dummy_img)
|
||||
print(f"[Worker {os.getpid()}] PaddleOCR warmup OK", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[Worker {os.getpid()}] PaddleOCR warmup error: {e}", flush=True)
|
||||
|
||||
# Cleanup
|
||||
gc.collect()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"pid": os.getpid(),
|
||||
"doctr_available": _doctr_engine is not None,
|
||||
"paddle_available": _paddle_engine is not None,
|
||||
"tesseract_available": _tesseract_engine is not None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"pid": os.getpid(),
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
def _process_ocr_task(
|
||||
image_bytes: bytes,
|
||||
engine: str = "doctr_plus",
|
||||
preprocessing: str = "auto"
|
||||
) -> dict:
|
||||
"""
|
||||
Process OCR task in worker process.
|
||||
|
||||
This is the main work function called for each OCR request.
|
||||
Uses persistent global engines loaded at worker init.
|
||||
|
||||
Args:
|
||||
image_bytes: Raw image bytes
|
||||
engine: OCR engine choice ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
|
||||
preprocessing: Preprocessing mode
|
||||
|
||||
Returns:
|
||||
Dict with extraction results
|
||||
"""
|
||||
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
|
||||
|
||||
try:
|
||||
# Ensure initialization
|
||||
if not _worker_initialized:
|
||||
_worker_initializer()
|
||||
|
||||
# Import processing function
|
||||
from backend.modules.data_entry.services.ocr.ocr_worker_process import process_ocr
|
||||
|
||||
# Run OCR
|
||||
result = process_ocr(
|
||||
image_bytes=image_bytes,
|
||||
paddle_engine=_paddle_engine,
|
||||
tesseract_engine=_tesseract_engine,
|
||||
engine=engine,
|
||||
preprocessing=preprocessing,
|
||||
doctr_engine=_doctr_engine
|
||||
)
|
||||
|
||||
# Cleanup after each task
|
||||
gc.collect()
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Worker {os.getpid()}] Task error: {e}", flush=True)
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"pid": os.getpid()
|
||||
}
|
||||
|
||||
|
||||
# Singleton instance
|
||||
ocr_worker_pool = OCRWorkerPool()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,258 @@
|
||||
# Store Profiles - OCR Extraction
|
||||
|
||||
Sistem de profile specifice pentru extracție OCR cu hot-reload.
|
||||
|
||||
---
|
||||
|
||||
## Quick Start: Adaugă un profil nou
|
||||
|
||||
```bash
|
||||
# 1. Generează profil din PDF-uri (dry-run pentru preview)
|
||||
python scripts/generate_store_profile.py \
|
||||
--name "Magazin Nou SRL" \
|
||||
--cui "12345678" \
|
||||
--receipts "docs/data-entry/MagazinNou*.pdf" \
|
||||
--dry-run
|
||||
|
||||
# 2. Generează și salvează
|
||||
python scripts/generate_store_profile.py \
|
||||
--name "Magazin Nou SRL" \
|
||||
--cui "12345678" \
|
||||
--receipts "docs/data-entry/MagazinNou*.pdf" \
|
||||
--output backend/modules/data_entry/services/ocr/profiles/magazin_nou.py
|
||||
|
||||
# 3. Hot-reload (fără restart server)
|
||||
curl -X POST http://localhost:8000/api/data-entry/ocr/profiles/reload
|
||||
|
||||
# 4. Verifică
|
||||
curl http://localhost:8000/api/data-entry/ocr/profiles
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Structura directorului
|
||||
|
||||
```
|
||||
profiles/
|
||||
├── __init__.py # ProfileRegistry + hot-reload (~390 linii)
|
||||
├── base.py # BaseStoreProfile + pattern-uri generice (~410 linii)
|
||||
├── lidl.py # Multi-rate TVA (A/B)
|
||||
├── omv.py # B2B, date YYYY.MM.DD
|
||||
├── socar.py # B2B, date YYYY.MM.DD
|
||||
├── brick.py # Standard TVA
|
||||
├── dedeman.py # E-factura support
|
||||
├── kineterra.py # Non-VAT payer
|
||||
├── gama_ink.py # Standard TVA (toner/cartușe)
|
||||
├── electrobering.py # Standard TVA (electronice)
|
||||
├── pictus_velum.py # Standard TVA (rechizite)
|
||||
├── unlimited_keys.py # Standard TVA, NUMERAR payment
|
||||
├── best_print.py # Non-VAT payer (neplătitor TVA)
|
||||
├── stepout_market.py # TVA 5% (cărți/librărie)
|
||||
└── README.md # Acest fișier
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Profile existente (12 profile)
|
||||
|
||||
> **Note**: Pattern-urile TVA sunt **flexibile** și acceptă ORICE cotă (5%, 9%, 11%, 19%, 21%, etc.)
|
||||
> pentru a gestiona atât datele istorice cât și schimbările viitoare ale legislației.
|
||||
|
||||
| Magazin | CUI | Fișier | Caracteristici |
|
||||
|---------|-----|--------|----------------|
|
||||
| LIDL DISCOUNT S.R.L. | 22891860 | `lidl.py` | Multi-rate TVA (coduri A, B, C, D) |
|
||||
| OMV PETROM MARKETING S.R.L. | 11201891 | `omv.py` | B2B (client CUI), date YYYY.MM.DD |
|
||||
| SOCAR PETROLEUM S.A. | 12546600 | `socar.py` | B2B (client CUI), date YYYY.MM.DD |
|
||||
| FIVE-HOLDING S.A. (BRICK) | 10562600 | `brick.py` | Standard TVA |
|
||||
| DEDEMAN SRL | 2816464 | `dedeman.py` | E-factura support |
|
||||
| KINETERRA CONCEPT SRL | 31180432 | `kineterra.py` | Non-VAT payer (returnează `[]`) |
|
||||
| GAMA INK SERVICE SRL | 17741882 | `gama_ink.py` | Standard TVA (toner, cartușe) |
|
||||
| ELECTROBERING S.R.L. | 2744937 | `electrobering.py` | Standard TVA (electronice) |
|
||||
| PICTUS VELUM SRL | 39634534 | `pictus_velum.py` | Standard TVA (rechizite) |
|
||||
| UNLIMITED KEYS S.R.L. | 18993187 | `unlimited_keys.py` | Standard TVA, **NUMERAR** plată |
|
||||
| BEST PRINT TRADE ACTIV SRL | 45417955 | `best_print.py` | **Non-VAT payer** (neplătitor TVA) |
|
||||
| STEPOUT MARKET SRL | 35532655 | `stepout_market.py` | TVA 5% (cărți, librărie) |
|
||||
|
||||
---
|
||||
|
||||
## API Endpoints
|
||||
|
||||
| Endpoint | Metodă | Descriere |
|
||||
|----------|--------|-----------|
|
||||
| `/api/data-entry/ocr/profiles` | GET | Lista toate profilele |
|
||||
| `/api/data-entry/ocr/profiles/{cui}` | GET | Detalii profil (acceptă RO prefix) |
|
||||
| `/api/data-entry/ocr/profiles/reload` | POST | Hot-reload toate profilele |
|
||||
|
||||
### Exemple API
|
||||
|
||||
```bash
|
||||
# Lista profile
|
||||
curl http://localhost:8000/api/data-entry/ocr/profiles \
|
||||
-H "Authorization: Bearer <token>"
|
||||
|
||||
# Detalii profil (cu sau fără RO prefix)
|
||||
curl http://localhost:8000/api/data-entry/ocr/profiles/22891860
|
||||
curl http://localhost:8000/api/data-entry/ocr/profiles/RO22891860
|
||||
|
||||
# Hot-reload după modificări
|
||||
curl -X POST http://localhost:8000/api/data-entry/ocr/profiles/reload \
|
||||
-H "Authorization: Bearer <token>"
|
||||
|
||||
# Response reload:
|
||||
{
|
||||
"success": true,
|
||||
"reloaded_modules": 12,
|
||||
"profiles_count": 12,
|
||||
"registered_cuis": ["22891860", "11201891", "12546600", "10562600", ...],
|
||||
"last_reload": "2026-01-06T22:37:05.000000"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Cum funcționează sistemul
|
||||
|
||||
### Flow de extracție
|
||||
|
||||
```
|
||||
ReceiptExtractor.extract()
|
||||
│
|
||||
├─► STEP 1: Extrage vendor + CUI
|
||||
│ └─► _extract_vendor(), _extract_cui()
|
||||
│
|
||||
├─► ProfileRegistry.get_profile(cui)
|
||||
│ └─► Returnează profil specific sau None
|
||||
│
|
||||
├─► STEP 2: Extracție cu profil (dacă există)
|
||||
│ ├─► profile.extract_total()
|
||||
│ ├─► profile.extract_date()
|
||||
│ ├─► profile.extract_receipt_number()
|
||||
│ ├─► profile.extract_tva_entries()
|
||||
│ ├─► profile.extract_payment_methods()
|
||||
│ └─► profile.extract_client_cui()
|
||||
│
|
||||
└─► STEP 3-4: Validare + post-procesare
|
||||
```
|
||||
|
||||
### Fallback
|
||||
|
||||
Dacă nu există profil pentru CUI, se folosește logica generică din `ReceiptExtractor`.
|
||||
|
||||
---
|
||||
|
||||
## Structura unui profil
|
||||
|
||||
```python
|
||||
from .base import BaseStoreProfile
|
||||
from . import ProfileRegistry
|
||||
|
||||
@ProfileRegistry.register
|
||||
class MagazinNouProfile(BaseStoreProfile):
|
||||
"""Docstring cu descriere magazin."""
|
||||
|
||||
CUI_LIST = ["12345678"] # Poate avea mai multe CUI-uri
|
||||
NAME_PATTERNS = ["MAGAZIN", "MAGAZIN NOU", "MAG4ZIN"] # OCR variants
|
||||
STORE_NAME = "Magazin Nou SRL"
|
||||
|
||||
# Override doar ce e diferit de base class
|
||||
def extract_tva_entries(self, text: str) -> List[dict]:
|
||||
# Pattern-uri specifice magazinului
|
||||
...
|
||||
|
||||
def get_validation_hints(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"has_multi_rate_tva": False,
|
||||
"card_equals_total": True,
|
||||
"has_client_cui": False,
|
||||
"has_efactura": False,
|
||||
"is_non_vat_payer": False,
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Pattern-uri disponibile în base.py
|
||||
|
||||
BaseStoreProfile include pattern-uri generice OCR-tolerant:
|
||||
|
||||
| Pattern | Descriere |
|
||||
|---------|-----------|
|
||||
| `TOTAL_PATTERNS` | 8 variante pentru TOTAL (TOTAL:, TOTAL DE PLATA, etc.) |
|
||||
| `DATE_PATTERNS` | 6 variante (DD.MM.YYYY, YYYY-MM-DD, DD/MM/YYYY) |
|
||||
| `DATE_PATTERNS_OCR_SPACES` | 4 variante cu spații OCR ("2025. 08. 14") |
|
||||
| `NUMBER_PATTERNS` | 11 variante pentru număr bon (NDS, BF, C3POS) |
|
||||
| `PAYMENT_PATTERNS` | 8 variante pentru CARD/NUMERAR |
|
||||
| `CLIENT_MARKERS` | 6 variante pentru secțiune CLIENT |
|
||||
| `CLIENT_CUI_PATTERNS` | 7 variante pentru CUI client |
|
||||
|
||||
### Metode implementate în base class
|
||||
|
||||
- `extract_total(text)` → `Tuple[Decimal, float]`
|
||||
- `extract_date(text)` → `Tuple[date, float]`
|
||||
- `extract_receipt_number(text)` → `Tuple[str, float]`
|
||||
- `extract_payment_methods(text)` → `List[dict]`
|
||||
- `extract_client_cui(text)` → `Tuple[str, float]`
|
||||
- `extract_client_name(text)` → `Tuple[str, float]`
|
||||
|
||||
---
|
||||
|
||||
## Când ai nevoie de profil custom?
|
||||
|
||||
| Situație | Exemplu | Ce trebuie override |
|
||||
|----------|---------|---------------------|
|
||||
| **Multi-rate TVA** | Lidl (TVA A, TVA B) | `extract_tva_entries()` |
|
||||
| **Format dată special** | OMV/Socar (YYYY.MM.DD) | `DATE_PATTERNS_OCR_SPACES` |
|
||||
| **B2B receipts** | Benzinării (au client CUI) | `extract_client_cui()` |
|
||||
| **Non-VAT payer** | Kineterra | `extract_tva_entries()` returnează `[]` |
|
||||
| **E-factura** | Dedeman | `extract_efactura_reference()` |
|
||||
|
||||
---
|
||||
|
||||
## Decizii de design
|
||||
|
||||
1. **Hot-reload manual** - endpoint `/profiles/reload` apelat când se modifică fișiere
|
||||
2. **Persistență în Python** - profile în Git, version controlled
|
||||
3. **Fallback graceful** - dacă nu există profil, folosește logica generică
|
||||
4. **CUI normalization** - gestionează automat prefixul "RO" și whitespace
|
||||
5. **Deduplicare TVA** - folosește `seen = set()` pentru a evita duplicate
|
||||
|
||||
---
|
||||
|
||||
## Comenzi utile
|
||||
|
||||
```bash
|
||||
# Verifică syntax Python pentru toate profilele
|
||||
for f in backend/modules/data_entry/services/ocr/profiles/*.py; do
|
||||
python3 -m py_compile "$f" && echo "✓ $(basename $f)"
|
||||
done
|
||||
|
||||
# Lista profile
|
||||
ls -la backend/modules/data_entry/services/ocr/profiles/
|
||||
|
||||
# Pornește backend pentru testare
|
||||
cd backend && source venv/bin/activate
|
||||
uvicorn main:app --host 0.0.0.0 --port 8000 --workers 1
|
||||
|
||||
# Test OCR pe un PDF
|
||||
curl -X POST -F "file=@docs/data-entry/test.pdf" \
|
||||
-H "Authorization: Bearer <token>" \
|
||||
"http://localhost:8000/api/data-entry/ocr/extract?engine=doctr_plus"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Script generare profile
|
||||
|
||||
`scripts/generate_store_profile.py` - generator automat de profile
|
||||
|
||||
```bash
|
||||
# Vezi help
|
||||
python scripts/generate_store_profile.py --help
|
||||
|
||||
# Funcționalități:
|
||||
# - Analizează PDF-uri via OCR API
|
||||
# - Detectează: TVA format, date format, payment patterns, B2B
|
||||
# - Generează cod Python cu OCR error variants
|
||||
# - Suportă glob patterns (*.pdf)
|
||||
# - Verifică sintaxa după generare
|
||||
```
|
||||
@@ -0,0 +1,398 @@
|
||||
"""
|
||||
Store Profiles Registry with Hot-Reload Support.
|
||||
|
||||
This module provides a registry for store-specific OCR extraction profiles.
|
||||
Profiles can be reloaded at runtime without restarting the server.
|
||||
|
||||
Usage:
|
||||
from backend.modules.data_entry.services.ocr.profiles import ProfileRegistry
|
||||
|
||||
# Get profile for a CUI
|
||||
profile = ProfileRegistry.get_profile("22891860")
|
||||
if profile:
|
||||
tva_entries = profile.extract_tva_entries(text)
|
||||
|
||||
# Reload all profiles (after file changes)
|
||||
count = ProfileRegistry.reload_all()
|
||||
|
||||
Architecture:
|
||||
- ProfileRegistry: Singleton registry with class methods
|
||||
- BaseStoreProfile: Abstract base class for profiles
|
||||
- @ProfileRegistry.register: Decorator for profile classes
|
||||
|
||||
Hot-Reload Mechanism:
|
||||
1. Admin calls POST /profiles/reload endpoint
|
||||
2. Registry clears instance cache
|
||||
3. importlib.reload() re-executes each profile module
|
||||
4. @register decorator re-registers classes with new code
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Type, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .base import BaseStoreProfile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Directory containing profile modules
|
||||
PROFILES_DIR = Path(__file__).parent
|
||||
|
||||
|
||||
class ProfileRegistry:
|
||||
"""
|
||||
Registry for store-specific OCR extraction profiles.
|
||||
|
||||
Uses class methods for singleton-like behavior without explicit instantiation.
|
||||
Supports hot-reload via importlib.reload() for runtime updates.
|
||||
|
||||
Attributes:
|
||||
_profiles: Maps CUI -> profile class (not instance)
|
||||
_instances: Maps CUI -> profile instance (lazy, cleared on reload)
|
||||
_last_reload: Timestamp of last reload
|
||||
_loaded: Whether initial load has been performed
|
||||
"""
|
||||
|
||||
# Class-level storage (singleton pattern via class methods)
|
||||
_profiles: Dict[str, Type["BaseStoreProfile"]] = {}
|
||||
_instances: Dict[str, "BaseStoreProfile"] = {}
|
||||
_last_reload: Optional[datetime] = None
|
||||
_loaded: bool = False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Registration
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
def register(cls, profile_class: Type["BaseStoreProfile"]) -> Type["BaseStoreProfile"]:
|
||||
"""
|
||||
Decorator to register a store profile class.
|
||||
|
||||
Registers the profile for all CUIs in the class's CUI_LIST.
|
||||
Safe for re-registration during hot-reload (overwrites existing).
|
||||
|
||||
Usage:
|
||||
@ProfileRegistry.register
|
||||
class LidlProfile(BaseStoreProfile):
|
||||
CUI_LIST = ["22891860"]
|
||||
...
|
||||
|
||||
Args:
|
||||
profile_class: Profile class to register
|
||||
|
||||
Returns:
|
||||
The same class (allows use as decorator)
|
||||
|
||||
Raises:
|
||||
ValueError: If CUI_LIST is empty
|
||||
"""
|
||||
cui_list = getattr(profile_class, 'CUI_LIST', [])
|
||||
store_name = getattr(profile_class, 'STORE_NAME', profile_class.__name__)
|
||||
|
||||
if not cui_list:
|
||||
logger.warning(f"Profile {profile_class.__name__} has empty CUI_LIST, skipping")
|
||||
return profile_class
|
||||
|
||||
# Register for each CUI
|
||||
for cui in cui_list:
|
||||
# Normalize CUI (remove RO prefix, strip whitespace)
|
||||
normalized_cui = cls._normalize_cui(cui)
|
||||
|
||||
if normalized_cui in cls._profiles:
|
||||
old_class = cls._profiles[normalized_cui]
|
||||
logger.debug(
|
||||
f"Re-registering CUI {normalized_cui}: "
|
||||
f"{old_class.__name__} -> {profile_class.__name__}"
|
||||
)
|
||||
# Clear cached instance for this CUI
|
||||
cls._instances.pop(normalized_cui, None)
|
||||
|
||||
cls._profiles[normalized_cui] = profile_class
|
||||
logger.debug(f"Registered profile {profile_class.__name__} for CUI {normalized_cui}")
|
||||
|
||||
logger.info(f"Registered {store_name} for CUIs: {cui_list}")
|
||||
return profile_class
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Lookup
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
def get_profile(cls, cui: Optional[str]) -> Optional["BaseStoreProfile"]:
|
||||
"""
|
||||
Get profile instance for a CUI.
|
||||
|
||||
Uses lazy instantiation - creates instance on first access.
|
||||
Returns None if no profile is registered for this CUI.
|
||||
|
||||
Args:
|
||||
cui: CUI to lookup (with or without RO prefix)
|
||||
|
||||
Returns:
|
||||
Profile instance or None
|
||||
"""
|
||||
if not cui:
|
||||
return None
|
||||
|
||||
# Ensure profiles are loaded
|
||||
if not cls._loaded:
|
||||
cls._load_all_profiles()
|
||||
|
||||
normalized_cui = cls._normalize_cui(cui)
|
||||
|
||||
# Check if profile exists
|
||||
profile_class = cls._profiles.get(normalized_cui)
|
||||
if not profile_class:
|
||||
return None
|
||||
|
||||
# Lazy instantiation
|
||||
if normalized_cui not in cls._instances:
|
||||
try:
|
||||
cls._instances[normalized_cui] = profile_class()
|
||||
logger.debug(f"Instantiated {profile_class.__name__} for CUI {normalized_cui}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to instantiate {profile_class.__name__}: {e}")
|
||||
return None
|
||||
|
||||
return cls._instances[normalized_cui]
|
||||
|
||||
@classmethod
|
||||
def has_profile(cls, cui: Optional[str]) -> bool:
|
||||
"""Check if a profile exists for this CUI."""
|
||||
if not cui:
|
||||
return False
|
||||
if not cls._loaded:
|
||||
cls._load_all_profiles()
|
||||
return cls._normalize_cui(cui) in cls._profiles
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Listing
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
def list_profiles(cls) -> List[Dict]:
|
||||
"""
|
||||
List all registered profiles.
|
||||
|
||||
Returns:
|
||||
List of dicts with cui, class_name, store_name, name_patterns
|
||||
"""
|
||||
if not cls._loaded:
|
||||
cls._load_all_profiles()
|
||||
|
||||
result = []
|
||||
seen_classes = set()
|
||||
|
||||
for cui, profile_class in cls._profiles.items():
|
||||
# Avoid duplicates for profiles with multiple CUIs
|
||||
if profile_class.__name__ in seen_classes:
|
||||
continue
|
||||
seen_classes.add(profile_class.__name__)
|
||||
|
||||
result.append({
|
||||
"cuis": list(getattr(profile_class, 'CUI_LIST', [])),
|
||||
"class_name": profile_class.__name__,
|
||||
"store_name": getattr(profile_class, 'STORE_NAME', profile_class.__name__),
|
||||
"name_patterns": list(getattr(profile_class, 'NAME_PATTERNS', [])),
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_profile_info(cls, cui: str) -> Optional[Dict]:
|
||||
"""
|
||||
Get detailed info about a profile.
|
||||
|
||||
Args:
|
||||
cui: CUI to lookup
|
||||
|
||||
Returns:
|
||||
Dict with profile details or None
|
||||
"""
|
||||
profile = cls.get_profile(cui)
|
||||
if not profile:
|
||||
return None
|
||||
|
||||
return {
|
||||
"cui": cui,
|
||||
"cuis": list(profile.CUI_LIST),
|
||||
"class_name": profile.__class__.__name__,
|
||||
"store_name": profile.STORE_NAME,
|
||||
"name_patterns": list(profile.NAME_PATTERNS),
|
||||
"validation_hints": profile.get_validation_hints(),
|
||||
}
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Hot-Reload
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
def reload_all(cls) -> int:
|
||||
"""
|
||||
Hot-reload all profile modules.
|
||||
|
||||
Clears instance cache and reloads all .py files in profiles directory.
|
||||
Decorator re-registers classes with updated code.
|
||||
|
||||
Returns:
|
||||
Number of modules reloaded
|
||||
"""
|
||||
logger.info("Starting profile hot-reload...")
|
||||
|
||||
# Clear instance cache (will be recreated on next get_profile)
|
||||
cls._instances.clear()
|
||||
|
||||
# Get list of profile modules (exclude __init__, base)
|
||||
module_names = cls._get_profile_module_names()
|
||||
|
||||
# Determine the module prefix based on how THIS module was imported
|
||||
base_package = cls.__module__
|
||||
|
||||
count = 0
|
||||
for module_name in module_names:
|
||||
full_name = f"{base_package}.{module_name}"
|
||||
|
||||
try:
|
||||
if full_name in sys.modules:
|
||||
# Reload existing module
|
||||
importlib.reload(sys.modules[full_name])
|
||||
logger.debug(f"Reloaded module: {module_name}")
|
||||
else:
|
||||
# Import new module
|
||||
importlib.import_module(full_name)
|
||||
logger.debug(f"Imported new module: {module_name}")
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reload {module_name}: {e}")
|
||||
|
||||
cls._last_reload = datetime.utcnow()
|
||||
cls._loaded = True
|
||||
|
||||
logger.info(f"Profile hot-reload complete: {count} modules, {len(cls._profiles)} profiles")
|
||||
return count
|
||||
|
||||
@classmethod
|
||||
def get_reload_status(cls) -> Dict:
|
||||
"""Get status of the registry including last reload time."""
|
||||
return {
|
||||
"loaded": cls._loaded,
|
||||
"last_reload": cls._last_reload.isoformat() if cls._last_reload else None,
|
||||
"profiles_count": len(cls._profiles),
|
||||
"instances_count": len(cls._instances),
|
||||
"registered_cuis": list(cls._profiles.keys()),
|
||||
}
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Internal methods
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
def _normalize_cui(cls, cui: str) -> str:
|
||||
"""
|
||||
Normalize CUI for consistent lookup.
|
||||
|
||||
- Removes RO prefix (with or without space)
|
||||
- Strips whitespace
|
||||
- Converts to uppercase
|
||||
|
||||
Args:
|
||||
cui: Raw CUI string
|
||||
|
||||
Returns:
|
||||
Normalized CUI (digits only)
|
||||
"""
|
||||
if not cui:
|
||||
return ""
|
||||
|
||||
cui = str(cui).strip().upper()
|
||||
|
||||
# Remove RO prefix (handles "RO12345" and "RO 12345")
|
||||
if cui.startswith("RO"):
|
||||
cui = cui[2:].lstrip()
|
||||
|
||||
return cui.strip()
|
||||
|
||||
@classmethod
|
||||
def _get_profile_module_names(cls) -> List[str]:
|
||||
"""
|
||||
Get list of profile module names from profiles directory.
|
||||
|
||||
Excludes __init__.py and base.py.
|
||||
|
||||
Returns:
|
||||
List of module names (without .py extension)
|
||||
"""
|
||||
excluded = {"__init__", "base", "__pycache__"}
|
||||
modules = []
|
||||
|
||||
for path in PROFILES_DIR.glob("*.py"):
|
||||
name = path.stem
|
||||
if name not in excluded:
|
||||
modules.append(name)
|
||||
|
||||
return sorted(modules)
|
||||
|
||||
@classmethod
|
||||
def _load_all_profiles(cls) -> None:
|
||||
"""
|
||||
Initial load of all profile modules.
|
||||
|
||||
Called automatically on first get_profile() if not already loaded.
|
||||
"""
|
||||
if cls._loaded:
|
||||
return
|
||||
|
||||
logger.info("Loading store profiles...")
|
||||
|
||||
module_names = cls._get_profile_module_names()
|
||||
|
||||
# Determine the module prefix based on how THIS module was imported
|
||||
# This handles both:
|
||||
# - Running from backend dir: "modules.data_entry.services.ocr.profiles"
|
||||
# - Running from project root: "backend.modules.data_entry.services.ocr.profiles"
|
||||
this_module = cls.__module__ # e.g. "backend.modules..." or "modules..."
|
||||
base_package = this_module # Use the same prefix for child modules
|
||||
|
||||
for module_name in module_names:
|
||||
full_name = f"{base_package}.{module_name}"
|
||||
try:
|
||||
importlib.import_module(full_name)
|
||||
logger.debug(f"Loaded module: {module_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load {module_name}: {e}")
|
||||
|
||||
cls._loaded = True
|
||||
cls._last_reload = datetime.utcnow()
|
||||
|
||||
logger.info(f"Loaded {len(cls._profiles)} store profiles")
|
||||
|
||||
@classmethod
|
||||
def clear(cls) -> None:
|
||||
"""
|
||||
Clear all registered profiles.
|
||||
|
||||
Mainly useful for testing.
|
||||
"""
|
||||
cls._profiles.clear()
|
||||
cls._instances.clear()
|
||||
cls._loaded = False
|
||||
cls._last_reload = None
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Module exports
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
__all__ = [
|
||||
"ProfileRegistry",
|
||||
"BaseStoreProfile",
|
||||
]
|
||||
|
||||
# Re-export BaseStoreProfile for convenience
|
||||
from .base import BaseStoreProfile
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,655 @@
|
||||
"""
|
||||
Optimized Tesseract Engine for OCR - SPEED + QUALITY OPTIMIZED
|
||||
|
||||
Performance optimizations (vs previous version):
|
||||
- Single PSM mode (PSM 4) instead of multi-PSM (4 modes × 2 calls = 8x faster)
|
||||
- Single Tesseract call per image (skip image_to_data for speed)
|
||||
- Lighter preprocessing (no over-binarization)
|
||||
- --dpi 300 flag for proper scaling
|
||||
- OEM 3 (default LSTM+Legacy) for balanced speed/accuracy
|
||||
|
||||
Quality optimizations for Romanian receipts:
|
||||
- PSM 4: Single column layout (optimal for receipts)
|
||||
- Polarity correction: ensures black text on white background
|
||||
- Language: Romanian only (-l ron) for faster recognition
|
||||
- Fallback to PSM 6 if PSM 4 produces poor results
|
||||
|
||||
Previous issues fixed:
|
||||
- Was 8x slower than PaddleOCR due to multi-PSM + dual calls
|
||||
- Produced gibberish on clear PDFs due to over-binarization
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
# Check Tesseract availability
|
||||
try:
|
||||
import pytesseract
|
||||
TESSERACT_AVAILABLE = True
|
||||
except ImportError:
|
||||
TESSERACT_AVAILABLE = False
|
||||
pytesseract = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OCRResult:
|
||||
"""Raw OCR result from Tesseract."""
|
||||
text: str
|
||||
confidence: float
|
||||
boxes: List[dict] = field(default_factory=list)
|
||||
engine: str = "tesseract"
|
||||
|
||||
|
||||
class TesseractEngine:
|
||||
"""
|
||||
Optimized Tesseract engine for receipt OCR.
|
||||
|
||||
TESTED OPTIMAL SETTINGS (from comprehensive benchmark):
|
||||
- DPI 200 for PDF loading (not 300!)
|
||||
- Padding 40px for edge protection
|
||||
- PSM 6 for complex receipts, PSM 4 for simple ones
|
||||
- Multi-pass strategy when quality is critical
|
||||
|
||||
SPEED vs QUALITY tradeoff:
|
||||
- Fast mode (single pass): ~0.9s, ~6-7 keywords
|
||||
- Quality mode (multi-pass): ~1.7s, ~8-9 keywords (+2 more keywords)
|
||||
|
||||
BENCHMARK RESULTS:
|
||||
- padded_psm6_40: Best for complex receipts (igiena, five-holding)
|
||||
- baseline_psm4: Best for simple receipts (rechizite, benzina)
|
||||
- multi-pass: Best overall quality but slower
|
||||
"""
|
||||
|
||||
# PSM modes for receipts
|
||||
PSM_SINGLE_COLUMN = 4 # Best for simple vertical receipts
|
||||
PSM_UNIFORM_BLOCK = 6 # Best for complex layouts
|
||||
PSM_SPARSE_TEXT = 11 # Fallback for difficult receipts
|
||||
|
||||
# Optimal padding (from benchmark)
|
||||
DEFAULT_PADDING = 40
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Tesseract engine."""
|
||||
if not TESSERACT_AVAILABLE:
|
||||
raise RuntimeError("pytesseract not available. Install with: pip install pytesseract")
|
||||
|
||||
# Verify Tesseract installation
|
||||
try:
|
||||
self._version = pytesseract.get_tesseract_version()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Tesseract not installed or not in PATH: {e}")
|
||||
|
||||
logger.info(f"[TesseractEngine] Initialized (v{self._version})")
|
||||
|
||||
def recognize(self, image: np.ndarray, fast_mode: bool = True) -> OCRResult:
|
||||
"""
|
||||
Perform OCR recognition on image (OPTIMIZED).
|
||||
|
||||
SPEED: Uses single PSM mode + single Tesseract call.
|
||||
Previously used 4 PSM modes × 2 calls = 8 Tesseract invocations.
|
||||
Now uses 1-2 calls maximum (with fallback).
|
||||
|
||||
Args:
|
||||
image: Preprocessed grayscale image (DO NOT binarize for clear PDFs!)
|
||||
fast_mode: If True, skip confidence calculation for maximum speed
|
||||
|
||||
Returns:
|
||||
OCRResult with text and confidence
|
||||
"""
|
||||
if not TESSERACT_AVAILABLE:
|
||||
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
|
||||
|
||||
# Ensure grayscale
|
||||
if len(image.shape) == 3:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# Fix polarity (black text on white background)
|
||||
image = self._ensure_correct_polarity(image)
|
||||
|
||||
# Try PSM 4 first (single column - best for receipts)
|
||||
result = self._recognize_fast(image, self.PSM_SINGLE_COLUMN, fast_mode)
|
||||
|
||||
# If poor result, try PSM 6 as fallback
|
||||
if not result.text.strip() or result.confidence < 0.3:
|
||||
logger.debug(f"[Tesseract] PSM {self.PSM_SINGLE_COLUMN} poor result, trying PSM {self.PSM_UNIFORM_BLOCK}")
|
||||
fallback = self._recognize_fast(image, self.PSM_UNIFORM_BLOCK, fast_mode)
|
||||
if len(fallback.text) > len(result.text):
|
||||
result = fallback
|
||||
|
||||
if result.text.strip():
|
||||
logger.info(f"[TesseractEngine] Result: {len(result.text)} chars, conf={result.confidence:.0%}")
|
||||
|
||||
return result
|
||||
|
||||
def _recognize_fast(self, image: np.ndarray, psm: int, fast_mode: bool = True) -> OCRResult:
|
||||
"""
|
||||
Fast single-call Tesseract recognition.
|
||||
|
||||
Optimizations:
|
||||
- Single call (image_to_string only in fast mode)
|
||||
- OEM 3 (LSTM+Legacy) - faster than OEM 1
|
||||
- --dpi 300 for proper scaling
|
||||
- Romanian only (-l ron)
|
||||
|
||||
Args:
|
||||
image: Grayscale image
|
||||
psm: Page segmentation mode
|
||||
fast_mode: Skip confidence calculation for speed
|
||||
|
||||
Returns:
|
||||
OCRResult
|
||||
"""
|
||||
# Build optimized config:
|
||||
# OEM 3 = LSTM + Legacy (faster than pure LSTM)
|
||||
# --dpi 300 = proper scaling hint
|
||||
# -l ron = Romanian only (faster, avoids eng confusion)
|
||||
config = f'--psm {psm} --oem 3 --dpi 300 -l ron'
|
||||
|
||||
try:
|
||||
if fast_mode:
|
||||
# Fast path: just get text, estimate confidence
|
||||
text = pytesseract.image_to_string(image, config=config)
|
||||
# Estimate confidence based on text quality
|
||||
confidence = self._estimate_confidence(text)
|
||||
else:
|
||||
# Accurate path: get text + real confidence
|
||||
text = pytesseract.image_to_string(image, config=config)
|
||||
data = pytesseract.image_to_data(
|
||||
image, config=config, output_type=pytesseract.Output.DICT
|
||||
)
|
||||
confidences = [int(c) for c in data['conf'] if int(c) > 0]
|
||||
confidence = sum(confidences) / len(confidences) / 100 if confidences else 0.0
|
||||
|
||||
return OCRResult(
|
||||
text=text,
|
||||
confidence=confidence,
|
||||
boxes=[],
|
||||
engine="tesseract"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[Tesseract] PSM {psm} error: {e}")
|
||||
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
|
||||
|
||||
def _estimate_confidence(self, text: str) -> float:
|
||||
"""
|
||||
Estimate OCR confidence based on text quality.
|
||||
|
||||
Heuristics:
|
||||
- More alphanumeric chars = higher confidence
|
||||
- Less garbage chars = higher confidence
|
||||
- Romanian-specific patterns boost confidence
|
||||
"""
|
||||
if not text.strip():
|
||||
return 0.0
|
||||
|
||||
# Count valid vs garbage chars
|
||||
valid_chars = sum(1 for c in text if c.isalnum() or c in '.,;:-/\n ')
|
||||
total_chars = len(text)
|
||||
|
||||
if total_chars == 0:
|
||||
return 0.0
|
||||
|
||||
# Base confidence from char ratio
|
||||
confidence = valid_chars / total_chars
|
||||
|
||||
# Boost for Romanian receipt patterns
|
||||
text_lower = text.lower()
|
||||
if any(word in text_lower for word in ['total', 'lei', 'ron', 'buc', 'tva', 'cif', 'bon']):
|
||||
confidence = min(confidence + 0.1, 1.0)
|
||||
|
||||
return confidence
|
||||
|
||||
def recognize_multipass(self, image: np.ndarray) -> OCRResult:
|
||||
"""
|
||||
Multi-pass OCR for maximum quality (slower but more accurate).
|
||||
|
||||
Strategy (from benchmark testing):
|
||||
- Pass 1: PSM 4 (single column) - no padding, fast baseline
|
||||
- Pass 2: PSM 6 (uniform block) - with 40px padding, better for complex layouts
|
||||
- Pass 3: PSM 11 (sparse text) - with 40px padding + stronger CLAHE, for difficult receipts
|
||||
|
||||
Merges results: picks the pass with highest keyword count.
|
||||
On average finds +2.1 more keywords than single-pass (~8.7 vs 6.6).
|
||||
|
||||
Time: ~1.7s (vs ~0.9s for single pass)
|
||||
|
||||
Args:
|
||||
image: Input image (RGB or grayscale)
|
||||
|
||||
Returns:
|
||||
OCRResult from the best pass
|
||||
"""
|
||||
if not TESSERACT_AVAILABLE:
|
||||
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
|
||||
|
||||
# Ensure grayscale
|
||||
if len(image.shape) == 3:
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
else:
|
||||
gray = image.copy()
|
||||
|
||||
# Define passes with different settings
|
||||
passes = [
|
||||
# Pass 1: Fast baseline (no padding) - good for simple receipts
|
||||
{"name": "pass1_psm4", "psm": 4, "padding": 0, "clahe_clip": 1.5},
|
||||
# Pass 2: Padded PSM 6 - good for complex receipts
|
||||
{"name": "pass2_psm6_padded", "psm": 6, "padding": 40, "clahe_clip": 1.5},
|
||||
# Pass 3: Sparse text with stronger enhancement - for difficult cases
|
||||
{"name": "pass3_psm11", "psm": 11, "padding": 40, "clahe_clip": 2.0},
|
||||
]
|
||||
|
||||
best_result = None
|
||||
best_score = -1
|
||||
all_keywords = set()
|
||||
|
||||
for p in passes:
|
||||
# Apply preprocessing for this pass
|
||||
processed = gray.copy()
|
||||
|
||||
# Add padding if specified
|
||||
if p["padding"] > 0:
|
||||
processed = cv2.copyMakeBorder(
|
||||
processed, p["padding"], p["padding"], p["padding"], p["padding"],
|
||||
cv2.BORDER_CONSTANT, value=255
|
||||
)
|
||||
|
||||
# Apply CLAHE
|
||||
clahe = cv2.createCLAHE(clipLimit=p["clahe_clip"], tileGridSize=(8, 8))
|
||||
processed = clahe.apply(processed)
|
||||
|
||||
# Ensure correct polarity
|
||||
processed = self._ensure_correct_polarity(processed)
|
||||
|
||||
# Run OCR
|
||||
config = f'--psm {p["psm"]} --oem 3 -l ron'
|
||||
try:
|
||||
text = pytesseract.image_to_string(processed, config=config)
|
||||
confidence = self._estimate_confidence(text)
|
||||
|
||||
# Score based on Romanian receipt keywords
|
||||
text_lower = text.lower()
|
||||
keywords = ['cif', 'total', 'tva', 'lei', 'ron', 'buc', 'fiscal', 'bon',
|
||||
'hartie', 'prosop', 'saci', 'creion', 'constanta', 'bucuresti']
|
||||
found_keywords = [kw for kw in keywords if kw in text_lower]
|
||||
all_keywords.update(found_keywords)
|
||||
|
||||
# Score: keywords + CIF bonus + TOTAL bonus
|
||||
score = len(found_keywords) * 10
|
||||
if self._has_cif_pattern(text):
|
||||
score += 15
|
||||
if self._has_total_pattern(text):
|
||||
score += 10
|
||||
|
||||
logger.debug(f"[Tesseract] {p['name']}: {len(found_keywords)} keywords, score={score}")
|
||||
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_result = OCRResult(
|
||||
text=text,
|
||||
confidence=confidence,
|
||||
boxes=[],
|
||||
engine=f"tesseract-multipass-{p['name']}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[Tesseract] {p['name']} failed: {e}")
|
||||
continue
|
||||
|
||||
if best_result:
|
||||
logger.info(f"[TesseractEngine] Multi-pass best: {best_result.engine}, "
|
||||
f"{len(all_keywords)} total keywords found")
|
||||
return best_result
|
||||
|
||||
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract-multipass")
|
||||
|
||||
def _has_cif_pattern(self, text: str) -> bool:
|
||||
"""Check if text contains a valid CIF/CUI pattern."""
|
||||
import re
|
||||
text_upper = text.upper()
|
||||
patterns = [
|
||||
r'CIF[:\s]*RO?\d{6,10}',
|
||||
r'CUI[:\s]*RO?\d{6,10}',
|
||||
r'C\.?I\.?F\.?[:\s]*RO?\d{6,10}',
|
||||
]
|
||||
for pattern in patterns:
|
||||
if re.search(pattern, text_upper):
|
||||
return True
|
||||
return bool(re.search(r'RO\d{7,10}', text_upper))
|
||||
|
||||
def _has_total_pattern(self, text: str) -> bool:
|
||||
"""Check if TOTAL is properly recognized (not truncated to BTOTAL/OTAL)."""
|
||||
import re
|
||||
text_upper = text.upper()
|
||||
return bool(re.search(r'(^|\s)TOTAL\s', text_upper, re.MULTILINE))
|
||||
|
||||
def recognize_with_boxes(self, image: np.ndarray, psm: int = 4) -> OCRResult:
|
||||
"""
|
||||
Recognition with bounding boxes (slower, for debugging/visualization).
|
||||
|
||||
Use this only when you need box coordinates.
|
||||
For normal OCR, use recognize() which is faster.
|
||||
|
||||
Args:
|
||||
image: Grayscale image
|
||||
psm: Page segmentation mode (default: 4 for receipts)
|
||||
|
||||
Returns:
|
||||
OCRResult with text, confidence, and boxes
|
||||
"""
|
||||
if not TESSERACT_AVAILABLE:
|
||||
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
|
||||
|
||||
# Ensure grayscale
|
||||
if len(image.shape) == 3:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
image = self._ensure_correct_polarity(image)
|
||||
config = f'--psm {psm} --oem 3 --dpi 300 -l ron'
|
||||
|
||||
try:
|
||||
text = pytesseract.image_to_string(image, config=config)
|
||||
data = pytesseract.image_to_data(
|
||||
image, config=config, output_type=pytesseract.Output.DICT
|
||||
)
|
||||
|
||||
confidences = [int(c) for c in data['conf'] if int(c) > 0]
|
||||
avg_conf = sum(confidences) / len(confidences) / 100 if confidences else 0.0
|
||||
|
||||
boxes = []
|
||||
for i in range(len(data['text'])):
|
||||
if data['text'][i].strip() and int(data['conf'][i]) > 0:
|
||||
boxes.append({
|
||||
'text': data['text'][i],
|
||||
'confidence': int(data['conf'][i]) / 100,
|
||||
'box': [data['left'][i], data['top'][i], data['width'][i], data['height'][i]]
|
||||
})
|
||||
|
||||
return OCRResult(text=text, confidence=avg_conf, boxes=boxes, engine="tesseract")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[Tesseract] recognize_with_boxes error: {e}")
|
||||
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
|
||||
|
||||
def _ensure_correct_polarity(self, image: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Ensure image has black text on white background.
|
||||
|
||||
Receipts should have dark text on light background.
|
||||
If image is inverted (light text on dark), invert it.
|
||||
|
||||
Detection method:
|
||||
- Calculate mean pixel value
|
||||
- If mean < 127, image is mostly dark (inverted)
|
||||
- Invert to correct polarity
|
||||
|
||||
Args:
|
||||
image: Grayscale image
|
||||
|
||||
Returns:
|
||||
Polarity-corrected image
|
||||
"""
|
||||
mean_value = np.mean(image)
|
||||
|
||||
if mean_value < 127:
|
||||
# Image is mostly dark = inverted (white text on black)
|
||||
logger.debug(f"[TesseractEngine] Detected inverted polarity (mean={mean_value:.1f}), correcting...")
|
||||
return 255 - image
|
||||
|
||||
return image
|
||||
|
||||
def recognize_numbers_only(self, image: np.ndarray) -> OCRResult:
|
||||
"""
|
||||
OCR optimized for numeric content (amounts, totals).
|
||||
|
||||
Uses character whitelist to reduce errors on numbers.
|
||||
|
||||
Args:
|
||||
image: Preprocessed grayscale image
|
||||
|
||||
Returns:
|
||||
OCRResult with numeric text
|
||||
"""
|
||||
if not TESSERACT_AVAILABLE:
|
||||
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
|
||||
|
||||
# Ensure grayscale
|
||||
if len(image.shape) == 3:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# Fix polarity
|
||||
image = self._ensure_correct_polarity(image)
|
||||
|
||||
# Config for numbers only
|
||||
# Whitelist: digits, comma, period, space, RON, LEI
|
||||
config = '--psm 6 --oem 1 -c tessedit_char_whitelist=0123456789.,- '
|
||||
|
||||
try:
|
||||
text = pytesseract.image_to_string(image, config=config)
|
||||
|
||||
data = pytesseract.image_to_data(
|
||||
image,
|
||||
config=config,
|
||||
output_type=pytesseract.Output.DICT
|
||||
)
|
||||
|
||||
confidences = [int(c) for c in data['conf'] if int(c) > 0]
|
||||
avg_conf = sum(confidences) / len(confidences) / 100 if confidences else 0.0
|
||||
|
||||
return OCRResult(
|
||||
text=text.strip(),
|
||||
confidence=avg_conf,
|
||||
boxes=[],
|
||||
engine="tesseract-numeric"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[TesseractEngine] Numeric OCR error: {e}")
|
||||
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
|
||||
|
||||
def recognize_cif_optimized(self, image: np.ndarray) -> Optional[str]:
|
||||
"""
|
||||
Optimized CIF extraction using multi-strategy approach.
|
||||
|
||||
BENCHMARK RESULTS (from test_critical_fields.py):
|
||||
- digit_opt_dpi200: 33% accuracy (best)
|
||||
- digit_whitelist: Works well on specific receipts
|
||||
- basic_ron_eng: Good backup
|
||||
|
||||
Strategy:
|
||||
1. Try digit-optimized preprocessing (2x scale + Otsu)
|
||||
2. Try character whitelist (RO + digits only)
|
||||
3. Try standard ron+eng config
|
||||
4. Return best match based on CIF pattern validation
|
||||
|
||||
Args:
|
||||
image: Input image (RGB from pdf2image or BGR from OpenCV)
|
||||
|
||||
Returns:
|
||||
Extracted CIF string (e.g., "RO10562600") or None
|
||||
"""
|
||||
import re
|
||||
|
||||
if not TESSERACT_AVAILABLE:
|
||||
return None
|
||||
|
||||
# Ensure grayscale
|
||||
if len(image.shape) == 3:
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
||||
else:
|
||||
gray = image.copy()
|
||||
|
||||
# Extract top 35% of image (where CIF is typically found)
|
||||
height = gray.shape[0]
|
||||
top_region = gray[:int(height * 0.35), :]
|
||||
|
||||
candidates = []
|
||||
|
||||
# Strategy 1: Digit-optimized preprocessing (best performer: 33% accuracy)
|
||||
try:
|
||||
# Scale up 2x + Otsu binarization
|
||||
scaled = cv2.resize(top_region, None, fx=2.0, fy=2.0, interpolation=cv2.INTER_CUBIC)
|
||||
clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
|
||||
enhanced = clahe.apply(scaled)
|
||||
_, binary = cv2.threshold(enhanced, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
||||
if np.mean(binary) < 127:
|
||||
binary = 255 - binary
|
||||
|
||||
text = pytesseract.image_to_string(binary, config='--psm 6 --oem 3 -l ron')
|
||||
cif = self._extract_cif_from_text(text)
|
||||
if cif:
|
||||
candidates.append(('digit_opt', cif))
|
||||
except Exception as e:
|
||||
logger.debug(f"[TesseractEngine] digit_opt strategy failed: {e}")
|
||||
|
||||
# Strategy 2: Character whitelist (RO + digits only)
|
||||
try:
|
||||
# Add padding
|
||||
padded = cv2.copyMakeBorder(top_region, 40, 40, 40, 40, cv2.BORDER_CONSTANT, value=255)
|
||||
scaled = cv2.resize(padded, None, fx=2.0, fy=2.0, interpolation=cv2.INTER_CUBIC)
|
||||
|
||||
config = '--psm 6 --oem 1 -c tessedit_char_whitelist=0123456789ROro'
|
||||
text = pytesseract.image_to_string(scaled, config=config)
|
||||
cif = self._extract_cif_from_text(text)
|
||||
if cif:
|
||||
candidates.append(('whitelist', cif))
|
||||
except Exception as e:
|
||||
logger.debug(f"[TesseractEngine] whitelist strategy failed: {e}")
|
||||
|
||||
# Strategy 3: Standard ron+eng config (good backup)
|
||||
try:
|
||||
padded = cv2.copyMakeBorder(top_region, 40, 40, 40, 40, cv2.BORDER_CONSTANT, value=255)
|
||||
clahe = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(8, 8))
|
||||
enhanced = clahe.apply(padded)
|
||||
|
||||
text = pytesseract.image_to_string(enhanced, config='--psm 6 --oem 3 -l ron+eng')
|
||||
cif = self._extract_cif_from_text(text)
|
||||
if cif:
|
||||
candidates.append(('ron_eng', cif))
|
||||
except Exception as e:
|
||||
logger.debug(f"[TesseractEngine] ron_eng strategy failed: {e}")
|
||||
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
# Log all candidates
|
||||
for strategy, cif in candidates:
|
||||
logger.debug(f"[TesseractEngine] CIF candidate from {strategy}: {cif}")
|
||||
|
||||
# Use majority voting if multiple strategies agree
|
||||
from collections import Counter
|
||||
cif_counts = Counter(cif for _, cif in candidates)
|
||||
most_common_cif, count = cif_counts.most_common(1)[0]
|
||||
|
||||
if count > 1:
|
||||
# Multiple strategies agree
|
||||
logger.info(f"[TesseractEngine] CIF extracted (majority {count} strategies): {most_common_cif}")
|
||||
return most_common_cif
|
||||
|
||||
# No agreement - prefer digit_opt strategy (33% accuracy in benchmarks)
|
||||
for strategy, cif in candidates:
|
||||
if strategy == 'digit_opt':
|
||||
logger.info(f"[TesseractEngine] CIF extracted via digit_opt (preferred): {cif}")
|
||||
return cif
|
||||
|
||||
# Fallback to first candidate
|
||||
strategy, cif = candidates[0]
|
||||
logger.info(f"[TesseractEngine] CIF extracted via {strategy}: {cif}")
|
||||
return cif
|
||||
|
||||
def _extract_cif_from_text(self, text: str) -> Optional[str]:
|
||||
"""Extract CIF/CUI from OCR text."""
|
||||
import re
|
||||
text_upper = text.upper().replace(' ', '')
|
||||
|
||||
patterns = [
|
||||
r'CIF[:\s]*R?O?(\d{6,10})',
|
||||
r'CUI[:\s]*R?O?(\d{6,10})',
|
||||
r'C\.?I\.?F\.?[:\s]*R?O?(\d{6,10})',
|
||||
r'RO(\d{7,10})',
|
||||
r'R\.?O\.?[\s:]*(\d{6,10})',
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, text_upper)
|
||||
if match:
|
||||
digits = match.group(1).lstrip('0') or '0'
|
||||
return f"RO{digits}"
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def validate_romanian_cif(cif: str) -> bool:
|
||||
"""
|
||||
Validate Romanian CIF/CUI using checksum algorithm.
|
||||
|
||||
Romanian CIF format: RO + 2-10 digits
|
||||
The last digit is a control digit calculated using modulo 11.
|
||||
|
||||
Algorithm:
|
||||
1. Multiply each digit by corresponding weight (from right to left: 2,3,4,5,6,7,2,3,4,5)
|
||||
2. Sum all products
|
||||
3. Remainder of sum / 11 is the control digit
|
||||
4. If remainder is 10, control digit is 0
|
||||
|
||||
Args:
|
||||
cif: CIF string (e.g., "RO10562600", "10562600")
|
||||
|
||||
Returns:
|
||||
True if CIF is valid, False otherwise
|
||||
"""
|
||||
# Remove RO prefix and spaces
|
||||
cif = cif.upper().replace(' ', '').replace('RO', '')
|
||||
|
||||
# Must be 2-10 digits
|
||||
if not cif.isdigit() or len(cif) < 2 or len(cif) > 10:
|
||||
return False
|
||||
|
||||
# Weights for checksum calculation (right to left)
|
||||
weights = [2, 3, 4, 5, 6, 7, 2, 3, 4, 5]
|
||||
|
||||
# Pad with zeros on the left to make it 10 digits
|
||||
cif_padded = cif.zfill(10)
|
||||
|
||||
# Calculate checksum (excluding last digit which is control)
|
||||
total = 0
|
||||
for i in range(9):
|
||||
total += int(cif_padded[i]) * weights[i]
|
||||
|
||||
# Control digit
|
||||
control = total % 11
|
||||
if control == 10:
|
||||
control = 0
|
||||
|
||||
# Compare with last digit
|
||||
return int(cif_padded[9]) == control
|
||||
|
||||
@staticmethod
|
||||
def is_available() -> bool:
|
||||
"""Check if Tesseract is available."""
|
||||
if not TESSERACT_AVAILABLE:
|
||||
return False
|
||||
|
||||
try:
|
||||
pytesseract.get_tesseract_version()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_version() -> Optional[str]:
|
||||
"""Get Tesseract version string."""
|
||||
if not TESSERACT_AVAILABLE:
|
||||
return None
|
||||
|
||||
try:
|
||||
return str(pytesseract.get_tesseract_version())
|
||||
except Exception:
|
||||
return None
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,476 @@
|
||||
"""OCR engine wrapper for PaddleOCR, docTR, and Tesseract."""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Setup logging (respects LOG_LEVEL env var set in main.py)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Disable PaddleOCR model source check for faster startup (PaddleX 3.x)
|
||||
os.environ['PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK'] = 'True'
|
||||
|
||||
# Lazy imports - these will be imported on first use
|
||||
PaddleOCR = None # Will be imported lazily
|
||||
pytesseract = None # Will be imported lazily
|
||||
doctr_ocr_predictor = None # Will be imported lazily
|
||||
|
||||
# Check availability without importing heavy libraries
|
||||
def _check_paddle_available() -> bool:
|
||||
"""Check if paddleocr is installed without importing it."""
|
||||
try:
|
||||
import importlib.util
|
||||
return importlib.util.find_spec("paddleocr") is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _check_tesseract_available() -> bool:
|
||||
"""Check if pytesseract is installed without importing it."""
|
||||
try:
|
||||
import importlib.util
|
||||
return importlib.util.find_spec("pytesseract") is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _check_doctr_available() -> bool:
|
||||
"""Check if doctr is installed without importing it."""
|
||||
try:
|
||||
import importlib.util
|
||||
return importlib.util.find_spec("doctr") is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
PADDLE_AVAILABLE = _check_paddle_available()
|
||||
TESSERACT_AVAILABLE = _check_tesseract_available()
|
||||
DOCTR_AVAILABLE = _check_doctr_available()
|
||||
|
||||
|
||||
@dataclass
|
||||
class OCRResult:
|
||||
"""Raw OCR result."""
|
||||
text: str
|
||||
confidence: float
|
||||
boxes: List[dict]
|
||||
engine: str = "" # OCR engine used: paddleocr or tesseract
|
||||
|
||||
|
||||
class OCREngine:
|
||||
"""Unified OCR engine with fallback support."""
|
||||
|
||||
def __init__(self):
|
||||
self._paddle = None
|
||||
self._paddle_init_started = False
|
||||
self._paddle_ready = threading.Event() # Signals when PaddleOCR is FULLY ready
|
||||
self._paddle_init_lock = threading.Lock()
|
||||
|
||||
self._doctr = None
|
||||
self._doctr_init_started = False
|
||||
self._doctr_ready = threading.Event() # Signals when docTR is FULLY ready
|
||||
self._doctr_init_lock = threading.Lock()
|
||||
|
||||
def _init_paddle_lazy(self):
|
||||
"""Lazy initialize PaddleOCR on first use (avoids slow startup)."""
|
||||
global PaddleOCR
|
||||
|
||||
with self._paddle_init_lock:
|
||||
if self._paddle_init_started:
|
||||
return # Already initializing or done
|
||||
self._paddle_init_started = True
|
||||
|
||||
if PADDLE_AVAILABLE:
|
||||
try:
|
||||
print("Importing PaddleOCR (first use, may take ~15-20 seconds)...", flush=True)
|
||||
from paddleocr import PaddleOCR as _PaddleOCR
|
||||
PaddleOCR = _PaddleOCR
|
||||
|
||||
print("Initializing PaddleOCR engine...", flush=True)
|
||||
# PaddleOCR 3.x API - optimized for Romanian receipts
|
||||
# Note: 'latin' not available in PaddleOCR 3.x, 'en' works well for receipts
|
||||
self._paddle = PaddleOCR(
|
||||
lang='en', # 'en' handles Latin alphabet well for receipts
|
||||
# High quality settings for better accuracy
|
||||
det_db_thresh=0.3, # Lower threshold = detect more text (default 0.3)
|
||||
det_db_box_thresh=0.5, # Box confidence threshold (default 0.5)
|
||||
det_db_unclip_ratio=1.8, # Expand detected boxes slightly (default 1.5)
|
||||
rec_batch_num=6, # Batch size for recognition
|
||||
use_angle_cls=True, # Enable text angle classification
|
||||
)
|
||||
print("PaddleOCR initialized successfully with high-quality settings", flush=True)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to initialize PaddleOCR: {e}", flush=True)
|
||||
self._paddle = None
|
||||
|
||||
# Signal that initialization is complete (success or failure)
|
||||
self._paddle_ready.set()
|
||||
|
||||
def _init_doctr_lazy(self):
|
||||
"""Lazy initialize docTR on first use (avoids slow startup)."""
|
||||
global doctr_ocr_predictor
|
||||
|
||||
with self._doctr_init_lock:
|
||||
if self._doctr_init_started:
|
||||
return # Already initializing or done
|
||||
self._doctr_init_started = True
|
||||
|
||||
if DOCTR_AVAILABLE:
|
||||
try:
|
||||
print("Importing docTR (first use, may take ~10-15 seconds)...", flush=True)
|
||||
from doctr.io import DocumentFile
|
||||
from doctr.models import ocr_predictor
|
||||
|
||||
print("Initializing docTR engine (PyTorch backend)...", flush=True)
|
||||
# Initialize docTR predictor with pretrained models
|
||||
# Uses db_resnet50 for detection and crnn_vgg16_bn for recognition
|
||||
self._doctr = ocr_predictor(
|
||||
det_arch='db_resnet50',
|
||||
reco_arch='crnn_vgg16_bn',
|
||||
pretrained=True,
|
||||
assume_straight_pages=True,
|
||||
straighten_pages=False,
|
||||
preserve_aspect_ratio=True,
|
||||
)
|
||||
doctr_ocr_predictor = self._doctr
|
||||
print("docTR initialized successfully with PyTorch backend", flush=True)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to initialize docTR: {e}", flush=True)
|
||||
self._doctr = None
|
||||
|
||||
# Signal that initialization is complete (success or failure)
|
||||
self._doctr_ready.set()
|
||||
|
||||
def wait_for_doctr(self, timeout: float = 30.0) -> bool:
|
||||
"""
|
||||
Wait for docTR to be fully initialized.
|
||||
|
||||
Args:
|
||||
timeout: Max seconds to wait (default 30s)
|
||||
|
||||
Returns:
|
||||
True if docTR is ready, False if timeout or unavailable
|
||||
"""
|
||||
if not DOCTR_AVAILABLE:
|
||||
return False
|
||||
|
||||
if self._doctr is not None:
|
||||
return True # Already ready
|
||||
|
||||
if not self._doctr_init_started:
|
||||
# Start initialization if not already started
|
||||
self._init_doctr_lazy()
|
||||
|
||||
# Wait for initialization to complete
|
||||
print(f"[OCR] Waiting for docTR to be ready (max {timeout}s)...", flush=True)
|
||||
start = time.time()
|
||||
ready = self._doctr_ready.wait(timeout=timeout)
|
||||
elapsed = time.time() - start
|
||||
|
||||
if ready and self._doctr is not None:
|
||||
print(f"[OCR] docTR ready after {elapsed:.1f}s", flush=True)
|
||||
return True
|
||||
else:
|
||||
print(f"[OCR] docTR not ready after {elapsed:.1f}s (timeout or failed)", flush=True)
|
||||
return False
|
||||
|
||||
def is_doctr_ready(self) -> bool:
|
||||
"""Check if docTR is ready without waiting."""
|
||||
return self._doctr is not None
|
||||
|
||||
def wait_for_paddle(self, timeout: float = 30.0) -> bool:
|
||||
"""
|
||||
Wait for PaddleOCR to be fully initialized.
|
||||
|
||||
Args:
|
||||
timeout: Max seconds to wait (default 30s)
|
||||
|
||||
Returns:
|
||||
True if PaddleOCR is ready, False if timeout or unavailable
|
||||
"""
|
||||
if not PADDLE_AVAILABLE:
|
||||
return False
|
||||
|
||||
if self._paddle is not None:
|
||||
return True # Already ready
|
||||
|
||||
if not self._paddle_init_started:
|
||||
# Start initialization if not already started
|
||||
self._init_paddle_lazy()
|
||||
|
||||
# Wait for initialization to complete
|
||||
print(f"[OCR] Waiting for PaddleOCR to be ready (max {timeout}s)...", flush=True)
|
||||
start = time.time()
|
||||
ready = self._paddle_ready.wait(timeout=timeout)
|
||||
elapsed = time.time() - start
|
||||
|
||||
if ready and self._paddle is not None:
|
||||
print(f"[OCR] PaddleOCR ready after {elapsed:.1f}s", flush=True)
|
||||
return True
|
||||
else:
|
||||
print(f"[OCR] PaddleOCR not ready after {elapsed:.1f}s (timeout or failed)", flush=True)
|
||||
return False
|
||||
|
||||
def is_paddle_ready(self) -> bool:
|
||||
"""Check if PaddleOCR is ready without waiting."""
|
||||
return self._paddle is not None
|
||||
|
||||
def recognize(self, image: np.ndarray) -> OCRResult:
|
||||
"""Perform OCR on preprocessed image."""
|
||||
logger.info(f"[OCR] Starting recognition, image shape: {image.shape}, dtype: {image.dtype}")
|
||||
|
||||
# Lazy init PaddleOCR on first call
|
||||
self._init_paddle_lazy()
|
||||
|
||||
if PADDLE_AVAILABLE and self._paddle:
|
||||
logger.info("[OCR] Using PaddleOCR engine")
|
||||
return self._paddle_recognize(image)
|
||||
elif TESSERACT_AVAILABLE:
|
||||
logger.info("[OCR] Using Tesseract engine (PaddleOCR not available)")
|
||||
return self._tesseract_recognize(image)
|
||||
else:
|
||||
logger.error("[OCR] No OCR engine available!")
|
||||
raise RuntimeError(
|
||||
"No OCR engine available. Install PaddleOCR or Tesseract."
|
||||
)
|
||||
|
||||
def _paddle_recognize(self, image: np.ndarray) -> OCRResult:
|
||||
"""Recognize text using PaddleOCR 3.x API."""
|
||||
# Wait for PaddleOCR to be fully ready (handles background init)
|
||||
if not self.wait_for_paddle(timeout=30.0):
|
||||
logger.warning("[PaddleOCR] Not ready, falling back to Tesseract")
|
||||
if TESSERACT_AVAILABLE:
|
||||
return self._tesseract_recognize(image)
|
||||
raise RuntimeError("PaddleOCR not ready and Tesseract not available")
|
||||
|
||||
try:
|
||||
logger.info(f"[PaddleOCR] Processing image, shape: {image.shape}")
|
||||
|
||||
# PaddleOCR 3.x requires 3-channel images
|
||||
if len(image.shape) == 2:
|
||||
# Convert grayscale to 3-channel BGR
|
||||
import cv2
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
|
||||
logger.info(f"[PaddleOCR] Converted to BGR, new shape: {image.shape}")
|
||||
|
||||
# PaddleOCR 3.x uses predict() with new parameter names
|
||||
logger.info("[PaddleOCR] Calling predict()...")
|
||||
result = self._paddle.predict(image, use_textline_orientation=True)
|
||||
logger.info(f"[PaddleOCR] predict() returned, result type: {type(result)}")
|
||||
|
||||
if not result or len(result) == 0:
|
||||
logger.warning("[PaddleOCR] No results returned")
|
||||
return OCRResult(text="", confidence=0.0, boxes=[], engine="paddleocr")
|
||||
|
||||
# PaddleOCR 3.x returns OCRResult objects with different structure
|
||||
ocr_result = result[0]
|
||||
|
||||
# Extract texts and scores from the new format
|
||||
rec_texts = ocr_result.get('rec_texts', [])
|
||||
rec_scores = ocr_result.get('rec_scores', [])
|
||||
dt_polys = ocr_result.get('dt_polys', [])
|
||||
|
||||
if not rec_texts:
|
||||
return OCRResult(text="", confidence=0.0, boxes=[], engine="paddleocr")
|
||||
|
||||
boxes = []
|
||||
for i, text in enumerate(rec_texts):
|
||||
conf = rec_scores[i] if i < len(rec_scores) else 0.0
|
||||
box = dt_polys[i].tolist() if i < len(dt_polys) else []
|
||||
boxes.append({
|
||||
'text': text,
|
||||
'confidence': float(conf),
|
||||
'box': box
|
||||
})
|
||||
|
||||
avg_conf = sum(rec_scores) / len(rec_scores) if rec_scores else 0.0
|
||||
text_result = '\n'.join(rec_texts)
|
||||
logger.info(f"[PaddleOCR] SUCCESS - Found {len(rec_texts)} text lines, avg confidence: {avg_conf:.2%}")
|
||||
logger.debug(f"[PaddleOCR] Raw text preview: {text_result[:200]}...")
|
||||
return OCRResult(
|
||||
text=text_result,
|
||||
confidence=float(avg_conf),
|
||||
boxes=boxes,
|
||||
engine="paddleocr"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[PaddleOCR] ERROR: {e}, falling back to Tesseract")
|
||||
if TESSERACT_AVAILABLE:
|
||||
return self._tesseract_recognize(image)
|
||||
raise
|
||||
|
||||
def _tesseract_recognize(self, image: np.ndarray) -> OCRResult:
|
||||
"""Recognize text using Tesseract."""
|
||||
global pytesseract
|
||||
|
||||
logger.info(f"[Tesseract] Processing image, shape: {image.shape}")
|
||||
|
||||
# Lazy import pytesseract
|
||||
if pytesseract is None:
|
||||
logger.info("[Tesseract] Importing pytesseract...")
|
||||
import pytesseract as _pytesseract
|
||||
pytesseract = _pytesseract
|
||||
|
||||
# PSM 4: Single column (best for receipts)
|
||||
config = '--psm 4 -l ron+eng'
|
||||
text = pytesseract.image_to_string(image, config=config)
|
||||
|
||||
# Quick confidence estimate
|
||||
data = pytesseract.image_to_data(image, config=config, output_type=pytesseract.Output.DICT)
|
||||
confidences = [int(c) for c in data['conf'] if int(c) > 0]
|
||||
avg_conf = sum(confidences) / len(confidences) / 100 if confidences else 0.0
|
||||
|
||||
logger.info(f"[Tesseract] Done: {len(text)} chars, conf: {avg_conf:.2%}")
|
||||
return OCRResult(text=text, confidence=avg_conf, boxes=[], engine="tesseract")
|
||||
|
||||
def _doctr_recognize(self, image: np.ndarray) -> OCRResult:
|
||||
"""Recognize text using docTR."""
|
||||
# Wait for docTR to be fully ready
|
||||
if not self.wait_for_doctr(timeout=30.0):
|
||||
logger.warning("[docTR] Not ready, falling back to Tesseract")
|
||||
if TESSERACT_AVAILABLE:
|
||||
return self._tesseract_recognize(image)
|
||||
raise RuntimeError("docTR not ready and Tesseract not available")
|
||||
|
||||
try:
|
||||
logger.info(f"[docTR] Processing image, shape: {image.shape}")
|
||||
|
||||
# docTR requires RGB images
|
||||
import cv2
|
||||
if len(image.shape) == 2:
|
||||
# Convert grayscale to RGB
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
||||
logger.info(f"[docTR] Converted grayscale to RGB, new shape: {image.shape}")
|
||||
elif image.shape[2] == 4:
|
||||
# Convert RGBA to RGB
|
||||
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
|
||||
logger.info(f"[docTR] Converted RGBA to RGB, new shape: {image.shape}")
|
||||
elif image.shape[2] == 3:
|
||||
# Check if BGR (from OpenCV) and convert to RGB
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
logger.info(f"[docTR] Converted BGR to RGB, shape: {image.shape}")
|
||||
|
||||
# Process image with docTR
|
||||
logger.info("[docTR] Running prediction...")
|
||||
from doctr.io import DocumentFile
|
||||
|
||||
# docTR expects a document (list of pages as numpy arrays)
|
||||
result = self._doctr([image])
|
||||
|
||||
if not result or not result.pages:
|
||||
logger.warning("[docTR] No results returned")
|
||||
return OCRResult(text="", confidence=0.0, boxes=[], engine="doctr")
|
||||
|
||||
# Extract text from all pages
|
||||
all_texts = []
|
||||
all_confidences = []
|
||||
boxes = []
|
||||
|
||||
for page in result.pages:
|
||||
for block in page.blocks:
|
||||
for line in block.lines:
|
||||
line_text = ' '.join(word.value for word in line.words)
|
||||
line_confidence = sum(w.confidence for w in line.words) / len(line.words) if line.words else 0.0
|
||||
all_texts.append(line_text)
|
||||
all_confidences.append(line_confidence)
|
||||
|
||||
# Store word-level boxes
|
||||
for word in line.words:
|
||||
boxes.append({
|
||||
'text': word.value,
|
||||
'confidence': float(word.confidence),
|
||||
'box': word.geometry # (xmin, ymin), (xmax, ymax)
|
||||
})
|
||||
|
||||
text_result = '\n'.join(all_texts)
|
||||
avg_conf = sum(all_confidences) / len(all_confidences) if all_confidences else 0.0
|
||||
|
||||
logger.info(f"[docTR] SUCCESS - Found {len(all_texts)} text lines, avg confidence: {avg_conf:.2%}")
|
||||
logger.debug(f"[docTR] Raw text preview: {text_result[:200]}...")
|
||||
|
||||
return OCRResult(
|
||||
text=text_result,
|
||||
confidence=float(avg_conf),
|
||||
boxes=boxes,
|
||||
engine="doctr"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[docTR] ERROR: {e}, falling back to Tesseract")
|
||||
if TESSERACT_AVAILABLE:
|
||||
return self._tesseract_recognize(image)
|
||||
raise
|
||||
|
||||
def recognize_dual(self, image: np.ndarray) -> Tuple[OCRResult, Optional[OCRResult]]:
|
||||
"""
|
||||
Run both OCR engines and return both results.
|
||||
|
||||
Returns:
|
||||
Tuple of (paddle_result, tesseract_result)
|
||||
tesseract_result may be None if Tesseract is not available
|
||||
"""
|
||||
logger.info(f"[OCR Dual] Starting dual recognition, image shape: {image.shape}")
|
||||
|
||||
# Lazy init PaddleOCR
|
||||
self._init_paddle_lazy()
|
||||
|
||||
paddle_result = None
|
||||
tesseract_result = None
|
||||
|
||||
# Run PaddleOCR
|
||||
if PADDLE_AVAILABLE and self._paddle:
|
||||
try:
|
||||
logger.info("[OCR Dual] Running PaddleOCR...")
|
||||
paddle_result = self._paddle_recognize(image)
|
||||
logger.info(f"[OCR Dual] PaddleOCR: {len(paddle_result.text)} chars, conf: {paddle_result.confidence:.2%}")
|
||||
except Exception as e:
|
||||
logger.error(f"[OCR Dual] PaddleOCR failed: {e}")
|
||||
paddle_result = OCRResult(text="", confidence=0.0, boxes=[], engine="paddleocr")
|
||||
|
||||
# Run Tesseract
|
||||
if TESSERACT_AVAILABLE:
|
||||
try:
|
||||
logger.info("[OCR Dual] Running Tesseract...")
|
||||
tesseract_result = self._tesseract_recognize(image)
|
||||
logger.info(f"[OCR Dual] Tesseract: {len(tesseract_result.text)} chars, conf: {tesseract_result.confidence:.2%}")
|
||||
except Exception as e:
|
||||
logger.error(f"[OCR Dual] Tesseract failed: {e}")
|
||||
tesseract_result = OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
|
||||
|
||||
# Fallback if PaddleOCR not available
|
||||
if paddle_result is None:
|
||||
if tesseract_result:
|
||||
paddle_result = tesseract_result
|
||||
else:
|
||||
raise RuntimeError("No OCR engine available")
|
||||
|
||||
return paddle_result, tesseract_result
|
||||
|
||||
@staticmethod
|
||||
def get_available_engines() -> List[str]:
|
||||
"""
|
||||
Return list of available OCR engines.
|
||||
|
||||
Respects OCR_ENABLE_PADDLEOCR and OCR_ENABLE_TESSERACT from .env.
|
||||
Engines that are disabled via .env are not returned even if installed.
|
||||
|
||||
Available engines: tesseract, doctr, doctr_plus, paddleocr
|
||||
"""
|
||||
# Check .env settings
|
||||
paddle_enabled = os.getenv("OCR_ENABLE_PADDLEOCR", "true").lower() == "true"
|
||||
tesseract_enabled = os.getenv("OCR_ENABLE_TESSERACT", "true").lower() == "true"
|
||||
|
||||
engines = []
|
||||
|
||||
# Base engines (only if installed AND enabled)
|
||||
if TESSERACT_AVAILABLE and tesseract_enabled:
|
||||
engines.append('tesseract')
|
||||
if DOCTR_AVAILABLE:
|
||||
engines.append('doctr')
|
||||
engines.append('doctr_plus') # docTR with 2-tier sequential + early exit
|
||||
if PADDLE_AVAILABLE and paddle_enabled:
|
||||
engines.append('paddleocr')
|
||||
|
||||
return engines
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,735 @@
|
||||
"""Main OCR service coordinating preprocessing, recognition, and extraction."""
|
||||
|
||||
import os
|
||||
import re
|
||||
import gc
|
||||
import logging
|
||||
import threading
|
||||
|
||||
# Disable PaddleOCR model source check for faster startup (PaddleX 3.x) - must be set before import
|
||||
os.environ['PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK'] = 'True'
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from backend.modules.data_entry.services.ocr_engine import OCREngine
|
||||
from backend.modules.data_entry.services.ocr_extractor import ReceiptExtractor, ExtractionResult
|
||||
from backend.modules.data_entry.services.image_preprocessor import ImagePreprocessor
|
||||
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
|
||||
|
||||
# Setup logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_memory_usage_mb() -> float:
|
||||
"""Get current process memory usage in MB."""
|
||||
try:
|
||||
import resource
|
||||
# Get memory in KB, convert to MB
|
||||
rusage = resource.getrusage(resource.RUSAGE_SELF)
|
||||
return rusage.ru_maxrss / 1024 # Linux returns KB
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
||||
|
||||
class OCRService:
|
||||
"""Service for OCR processing of receipt images."""
|
||||
|
||||
# Single worker to prevent memory accumulation from parallel OCR
|
||||
_executor = ThreadPoolExecutor(max_workers=1)
|
||||
# Semaphore to ensure only one OCR operation at a time (memory protection)
|
||||
_ocr_semaphore = threading.Semaphore(1)
|
||||
# Memory threshold in MB - if exceeded, force GC before processing
|
||||
_memory_threshold_mb = 2500
|
||||
|
||||
def __init__(self):
|
||||
self.preprocessor = ImagePreprocessor()
|
||||
self.ocr_engine = OCREngine()
|
||||
self.extractor = ReceiptExtractor()
|
||||
|
||||
async def process_image(
|
||||
self,
|
||||
image_path: Path,
|
||||
mime_type: str
|
||||
) -> Tuple[bool, str, Optional[ExtractionResult]]:
|
||||
"""
|
||||
Process receipt image and extract structured data.
|
||||
|
||||
Args:
|
||||
image_path: Path to the image file
|
||||
mime_type: MIME type of the file
|
||||
|
||||
Returns:
|
||||
Tuple of (success, message, extraction_result)
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(
|
||||
self._executor,
|
||||
self._process_sync,
|
||||
image_path,
|
||||
mime_type
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
return False, f"OCR processing failed: {str(e)}", None
|
||||
|
||||
def _cleanup_memory(self, *arrays):
|
||||
"""Explicitly delete numpy arrays and force garbage collection."""
|
||||
for arr in arrays:
|
||||
if arr is not None:
|
||||
try:
|
||||
del arr
|
||||
except:
|
||||
pass
|
||||
gc.collect()
|
||||
|
||||
def _process_sync(
|
||||
self,
|
||||
image_path: Path,
|
||||
mime_type: str
|
||||
) -> Tuple[bool, str, Optional[ExtractionResult]]:
|
||||
"""Synchronous processing with ADAPTIVE OCR pipeline."""
|
||||
|
||||
# Acquire semaphore to ensure only one OCR at a time
|
||||
acquired = self._ocr_semaphore.acquire(timeout=120) # 2 min timeout
|
||||
if not acquired:
|
||||
return False, "OCR service busy - please try again", None
|
||||
|
||||
try:
|
||||
return self._process_sync_internal(image_path, mime_type)
|
||||
finally:
|
||||
# Always release semaphore and cleanup
|
||||
self._ocr_semaphore.release()
|
||||
# Force garbage collection after EVERY OCR request
|
||||
gc.collect()
|
||||
mem_after = get_memory_usage_mb()
|
||||
print(f"[OCR Service] Memory after cleanup: {mem_after:.0f}MB", flush=True)
|
||||
|
||||
def _process_sync_internal(
|
||||
self,
|
||||
image_path: Path,
|
||||
mime_type: str
|
||||
) -> Tuple[bool, str, Optional[ExtractionResult]]:
|
||||
"""Internal processing - called with semaphore held."""
|
||||
|
||||
start_time = time.time()
|
||||
mem_before = get_memory_usage_mb()
|
||||
print(f"[OCR Service] Starting processing: {image_path}, mime: {mime_type}", flush=True)
|
||||
print(f"[OCR Service] Memory before: {mem_before:.0f}MB", flush=True)
|
||||
|
||||
# Check if memory is high - force GC before processing
|
||||
if mem_before > self._memory_threshold_mb:
|
||||
print(f"[OCR Service] WARNING: Memory high ({mem_before:.0f}MB > {self._memory_threshold_mb}MB), forcing GC...", flush=True)
|
||||
gc.collect()
|
||||
mem_after_gc = get_memory_usage_mb()
|
||||
print(f"[OCR Service] Memory after pre-GC: {mem_after_gc:.0f}MB", flush=True)
|
||||
|
||||
# Load image
|
||||
images = None # For cleanup
|
||||
image = None
|
||||
if mime_type == 'application/pdf':
|
||||
try:
|
||||
images = self.preprocessor.pdf_to_images(image_path)
|
||||
if not images:
|
||||
return False, "Failed to extract images from PDF", None
|
||||
image = images[0]
|
||||
# Delete other pages immediately to save memory
|
||||
if len(images) > 1:
|
||||
for i in range(1, len(images)):
|
||||
del images[i]
|
||||
images = [image]
|
||||
except RuntimeError as e:
|
||||
return False, str(e), None
|
||||
else:
|
||||
try:
|
||||
image = self.preprocessor.load_image(image_path)
|
||||
except ValueError as e:
|
||||
return False, str(e), None
|
||||
|
||||
raw_texts = []
|
||||
extraction = None
|
||||
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
# STEP 1: PaddleOCR + Light (fastest, best for clear PDFs)
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
print("=" * 60, flush=True)
|
||||
print("[OCR] STEP 1: PaddleOCR + Light preprocessing", flush=True)
|
||||
print("=" * 60, flush=True)
|
||||
light_img = self.preprocessor.preprocess_light(image)
|
||||
|
||||
try:
|
||||
paddle_light = self.ocr_engine._paddle_recognize(light_img)
|
||||
# Cleanup light_img immediately after OCR
|
||||
del light_img
|
||||
light_img = None
|
||||
|
||||
if paddle_light and paddle_light.text:
|
||||
extraction = self.extractor.extract(paddle_light.text)
|
||||
extraction.ocr_engine = "paddle-light"
|
||||
raw_texts.append(f"═══ PaddleOCR (light, conf: {paddle_light.confidence:.0%}) ═══\n{paddle_light.text}")
|
||||
|
||||
# Log extraction results
|
||||
print(f"[OCR] Step 1 Results:", flush=True)
|
||||
print(f" - OCR Confidence: {paddle_light.confidence:.0%}", flush=True)
|
||||
print(f" - Amount: {extraction.amount}", flush=True)
|
||||
print(f" - Date: {extraction.receipt_date}", flush=True)
|
||||
print(f" - Number: {extraction.receipt_number}", flush=True)
|
||||
print(f" - CUI: {extraction.cui}", flush=True)
|
||||
print(f" - TVA: {extraction.tva_total} (entries: {len(extraction.tva_entries) if extraction.tva_entries else 0})", flush=True)
|
||||
print(f" - Overall Confidence: {extraction.overall_confidence:.0%}", flush=True)
|
||||
|
||||
# Early exit if complete
|
||||
if self._is_extraction_complete(extraction):
|
||||
extraction.raw_text = "\n\n".join(raw_texts)
|
||||
elapsed_ms = int((time.time() - start_time) * 1000)
|
||||
extraction.processing_time_ms = elapsed_ms
|
||||
print(f"[OCR] *** EARLY EXIT at Step 1 - All fields found! ({elapsed_ms}ms) ***", flush=True)
|
||||
# Cleanup before return
|
||||
del image
|
||||
if images:
|
||||
del images
|
||||
return True, "OCR complete (fast mode)", extraction
|
||||
else:
|
||||
print("[OCR] -> Step 1 incomplete, continuing to Step 2...", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[OCR] PaddleOCR light failed: {e}", flush=True)
|
||||
extraction = ExtractionResult()
|
||||
# Cleanup on error
|
||||
if light_img is not None:
|
||||
del light_img
|
||||
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
# STEP 2: PaddleOCR + Medium (balanced preprocessing)
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
print("=" * 60, flush=True)
|
||||
print("[OCR] STEP 2: PaddleOCR + Medium preprocessing", flush=True)
|
||||
print("=" * 60, flush=True)
|
||||
medium_img = self.preprocessor.preprocess_medium(image)
|
||||
|
||||
try:
|
||||
paddle_medium = self.ocr_engine._paddle_recognize(medium_img)
|
||||
# Cleanup medium_img immediately after OCR
|
||||
del medium_img
|
||||
medium_img = None
|
||||
|
||||
if paddle_medium and paddle_medium.text:
|
||||
extraction_medium = self.extractor.extract(paddle_medium.text)
|
||||
extraction_medium.ocr_engine = "paddle-medium"
|
||||
raw_texts.append(f"═══ PaddleOCR (medium, conf: {paddle_medium.confidence:.0%}) ═══\n{paddle_medium.text}")
|
||||
|
||||
print(f"[OCR] Step 2 (Medium) Results:", flush=True)
|
||||
print(f" - OCR Confidence: {paddle_medium.confidence:.0%}", flush=True)
|
||||
print(f" - Amount: {extraction_medium.amount}", flush=True)
|
||||
print(f" - Date: {extraction_medium.receipt_date}", flush=True)
|
||||
print(f" - CUI: {extraction_medium.cui}", flush=True)
|
||||
|
||||
# Merge with previous
|
||||
extraction = self._merge_extractions(extraction, extraction_medium)
|
||||
|
||||
print(f"[OCR] After merge:", flush=True)
|
||||
print(f" - Amount: {extraction.amount}", flush=True)
|
||||
print(f" - Date: {extraction.receipt_date}", flush=True)
|
||||
print(f" - Number: {extraction.receipt_number}", flush=True)
|
||||
print(f" - CUI: {extraction.cui}", flush=True)
|
||||
print(f" - TVA: {extraction.tva_total}", flush=True)
|
||||
print(f" - Overall Confidence: {extraction.overall_confidence:.0%}", flush=True)
|
||||
|
||||
if self._is_extraction_complete(extraction):
|
||||
extraction.raw_text = "\n\n".join(raw_texts)
|
||||
extraction.ocr_engine = "paddle-adaptive"
|
||||
elapsed_ms = int((time.time() - start_time) * 1000)
|
||||
extraction.processing_time_ms = elapsed_ms
|
||||
print(f"[OCR] *** EARLY EXIT at Step 2 - All fields found after merge! ({elapsed_ms}ms) ***", flush=True)
|
||||
# Cleanup before return
|
||||
del image
|
||||
if images:
|
||||
del images
|
||||
return True, "OCR complete (paddle dual)", extraction
|
||||
else:
|
||||
print("[OCR] -> Step 2 incomplete, continuing to Step 3 (Tesseract)...", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[OCR] PaddleOCR medium failed: {e}", flush=True)
|
||||
# Cleanup on error
|
||||
if medium_img is not None:
|
||||
del medium_img
|
||||
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
# STEP 3: Tesseract - ONLY to complete missing fields
|
||||
# Uses Tesseract-optimized preprocessing (binarized, high contrast)
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
print("=" * 60, flush=True)
|
||||
print("[OCR] STEP 3: Tesseract (complement only, not override)", flush=True)
|
||||
print("=" * 60, flush=True)
|
||||
|
||||
tesseract_img = None
|
||||
try:
|
||||
# Use Tesseract-specific preprocessing (Otsu binarization)
|
||||
tesseract_img = self.preprocessor.preprocess_for_tesseract(image)
|
||||
tesseract_result = self.ocr_engine._tesseract_recognize(tesseract_img)
|
||||
# Cleanup tesseract_img immediately after OCR
|
||||
del tesseract_img
|
||||
tesseract_img = None
|
||||
|
||||
if tesseract_result and tesseract_result.text:
|
||||
extraction_tess = self.extractor.extract(tesseract_result.text)
|
||||
extraction_tess.ocr_engine = "tesseract"
|
||||
raw_texts.append(f"═══ Tesseract (conf: {tesseract_result.confidence:.0%}) ═══\n{tesseract_result.text}")
|
||||
|
||||
print(f"[OCR] Step 3 (Tesseract) Results:", flush=True)
|
||||
print(f" - OCR Confidence: {tesseract_result.confidence:.0%}", flush=True)
|
||||
print(f" - Amount: {extraction_tess.amount}", flush=True)
|
||||
print(f" - Date: {extraction_tess.receipt_date}", flush=True)
|
||||
print(f" - CUI: {extraction_tess.cui}", flush=True)
|
||||
|
||||
# IMPORTANT: Tesseract only COMPLETES missing fields, never overrides!
|
||||
extraction = self._complement_extraction(extraction, extraction_tess)
|
||||
except Exception as e:
|
||||
print(f"[OCR] Tesseract failed: {e}", flush=True)
|
||||
# Cleanup on error
|
||||
if tesseract_img is not None:
|
||||
del tesseract_img
|
||||
|
||||
# Cleanup original image - no longer needed
|
||||
del image
|
||||
if images:
|
||||
del images
|
||||
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
# FINAL VALIDATION: Fix impossible values
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
if extraction:
|
||||
extraction = self._final_validation(extraction)
|
||||
|
||||
# Final result
|
||||
if extraction is None:
|
||||
return False, "No text detected", None
|
||||
|
||||
extraction.raw_text = "\n\n".join(raw_texts)
|
||||
extraction.ocr_engine = "adaptive-full"
|
||||
|
||||
# Build result message
|
||||
fields_found = []
|
||||
if extraction.amount: fields_found.append("amount")
|
||||
if extraction.receipt_date: fields_found.append("date")
|
||||
if extraction.receipt_number: fields_found.append("number")
|
||||
if extraction.cui: fields_found.append("CUI")
|
||||
if extraction.tva_total or extraction.tva_entries: fields_found.append("TVA")
|
||||
|
||||
message = f"OCR complete (full pipeline). Found: {', '.join(fields_found) or 'no fields'}"
|
||||
|
||||
elapsed_ms = int((time.time() - start_time) * 1000)
|
||||
extraction.processing_time_ms = elapsed_ms
|
||||
|
||||
print("=" * 60, flush=True)
|
||||
print(f"[OCR] FINAL RESULT (full pipeline) - {elapsed_ms}ms", flush=True)
|
||||
print("=" * 60, flush=True)
|
||||
print(f" - Amount: {extraction.amount}", flush=True)
|
||||
print(f" - Date: {extraction.receipt_date}", flush=True)
|
||||
print(f" - Number: {extraction.receipt_number}", flush=True)
|
||||
print(f" - CUI: {extraction.cui}", flush=True)
|
||||
print(f" - TVA: {extraction.tva_total}", flush=True)
|
||||
print(f" - Overall Confidence: {extraction.overall_confidence:.0%}", flush=True)
|
||||
print(f" - Processing Time: {elapsed_ms}ms", flush=True)
|
||||
print(f" - Message: {message}", flush=True)
|
||||
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
# VALIDATION: Apply validation rules to final extraction
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
print("\n" + "=" * 60, flush=True)
|
||||
print("[Validation] Applying validation rules...", flush=True)
|
||||
print("=" * 60, flush=True)
|
||||
|
||||
validator = OCRValidationEngine()
|
||||
|
||||
# Prepare data for validation with safe type conversions
|
||||
def safe_float(value) -> Optional[float]:
|
||||
"""Safely convert Decimal or number to float."""
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
def safe_payment_sum(methods: list, method_type: str) -> Optional[float]:
|
||||
"""Safely sum payment amounts for a given method type."""
|
||||
if not methods:
|
||||
return None
|
||||
try:
|
||||
total = sum(
|
||||
float(pm.get('amount', 0) or 0)
|
||||
for pm in methods
|
||||
if pm.get('method') == method_type
|
||||
)
|
||||
return total if total > 0 else None
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
validation_data = {
|
||||
'amount': safe_float(extraction.amount),
|
||||
'tva': safe_float(extraction.tva_total),
|
||||
'cui': extraction.cui,
|
||||
'card_amount': safe_payment_sum(extraction.payment_methods, 'CARD'),
|
||||
'cash_amount': safe_payment_sum(extraction.payment_methods, 'NUMERAR'),
|
||||
'tva_entries': {
|
||||
entry.get('code', ''): safe_float(entry.get('amount'))
|
||||
for entry in (extraction.tva_entries or [])
|
||||
if entry.get('code') and safe_float(entry.get('amount')) is not None
|
||||
}
|
||||
}
|
||||
|
||||
# Run validation (no light/medium comparison for final result)
|
||||
validated_result = validator.validate_extraction(validation_data)
|
||||
|
||||
# Apply validation results to extraction
|
||||
extraction.needs_manual_review = validated_result.needs_manual_review
|
||||
extraction.validation_warnings = validated_result.validation_warnings
|
||||
extraction.validation_errors = validated_result.validation_errors
|
||||
extraction.confidence_adjustments = validated_result.confidence_adjustments
|
||||
extraction.inter_ocr_ratios = validated_result.inter_ocr_ratios
|
||||
|
||||
print(f"[Validation] Complete:", flush=True)
|
||||
print(f" - Warnings: {len(extraction.validation_warnings)}", flush=True)
|
||||
print(f" - Errors: {len(extraction.validation_errors)}", flush=True)
|
||||
print(f" - Needs Manual Review: {extraction.needs_manual_review}", flush=True)
|
||||
if extraction.validation_warnings:
|
||||
for warning in extraction.validation_warnings:
|
||||
print(f" [!] {warning}", flush=True)
|
||||
|
||||
return True, message, extraction
|
||||
|
||||
def _merge_extractions(
|
||||
self,
|
||||
paddle: Optional[ExtractionResult],
|
||||
tesseract: Optional[ExtractionResult]
|
||||
) -> ExtractionResult:
|
||||
"""
|
||||
Merge two extractions, picking best fields from each engine.
|
||||
|
||||
Strategy:
|
||||
- For each field, prefer the one with higher confidence
|
||||
- Use validation rules (CUI format, date validity, company indicators)
|
||||
- Combine TVA entries if different
|
||||
"""
|
||||
result = ExtractionResult()
|
||||
|
||||
# Handle case where one is None
|
||||
if paddle is None and tesseract is None:
|
||||
return result
|
||||
if paddle is None:
|
||||
return tesseract
|
||||
if tesseract is None:
|
||||
return paddle
|
||||
|
||||
print("[Merge] Comparing PaddleOCR vs Tesseract extractions...", flush=True)
|
||||
|
||||
# === AMOUNT ===
|
||||
# Pick higher confidence, both must be positive
|
||||
if paddle.amount and tesseract.amount:
|
||||
if paddle.confidence_amount >= tesseract.confidence_amount:
|
||||
result.amount = paddle.amount
|
||||
result.confidence_amount = paddle.confidence_amount
|
||||
print(f"[Merge] Amount: PaddleOCR {paddle.amount} (conf: {paddle.confidence_amount:.0%})", flush=True)
|
||||
else:
|
||||
result.amount = tesseract.amount
|
||||
result.confidence_amount = tesseract.confidence_amount
|
||||
print(f"[Merge] Amount: Tesseract {tesseract.amount} (conf: {tesseract.confidence_amount:.0%})", flush=True)
|
||||
elif paddle.amount:
|
||||
result.amount = paddle.amount
|
||||
result.confidence_amount = paddle.confidence_amount
|
||||
elif tesseract.amount:
|
||||
result.amount = tesseract.amount
|
||||
result.confidence_amount = tesseract.confidence_amount
|
||||
|
||||
# === DATE ===
|
||||
# Pick higher confidence, validate date reasonableness
|
||||
if paddle.receipt_date and tesseract.receipt_date:
|
||||
if paddle.confidence_date >= tesseract.confidence_date:
|
||||
result.receipt_date = paddle.receipt_date
|
||||
result.confidence_date = paddle.confidence_date
|
||||
print(f"[Merge] Date: PaddleOCR {paddle.receipt_date}", flush=True)
|
||||
else:
|
||||
result.receipt_date = tesseract.receipt_date
|
||||
result.confidence_date = tesseract.confidence_date
|
||||
print(f"[Merge] Date: Tesseract {tesseract.receipt_date}", flush=True)
|
||||
elif paddle.receipt_date:
|
||||
result.receipt_date = paddle.receipt_date
|
||||
result.confidence_date = paddle.confidence_date
|
||||
elif tesseract.receipt_date:
|
||||
result.receipt_date = tesseract.receipt_date
|
||||
result.confidence_date = tesseract.confidence_date
|
||||
|
||||
# === VENDOR NAME ===
|
||||
# Prefer one with company indicators (S.R.L., S.A., etc.)
|
||||
paddle_has_indicator = self._has_company_indicator(paddle.partner_name)
|
||||
tesseract_has_indicator = self._has_company_indicator(tesseract.partner_name)
|
||||
|
||||
if paddle.partner_name and tesseract.partner_name:
|
||||
if paddle_has_indicator and not tesseract_has_indicator:
|
||||
result.partner_name = paddle.partner_name
|
||||
result.confidence_vendor = paddle.confidence_vendor
|
||||
print(f"[Merge] Vendor: PaddleOCR '{paddle.partner_name}' (has company indicator)", flush=True)
|
||||
elif tesseract_has_indicator and not paddle_has_indicator:
|
||||
result.partner_name = tesseract.partner_name
|
||||
result.confidence_vendor = tesseract.confidence_vendor
|
||||
print(f"[Merge] Vendor: Tesseract '{tesseract.partner_name}' (has company indicator)", flush=True)
|
||||
elif paddle.confidence_vendor >= tesseract.confidence_vendor:
|
||||
result.partner_name = paddle.partner_name
|
||||
result.confidence_vendor = paddle.confidence_vendor
|
||||
print(f"[Merge] Vendor: PaddleOCR '{paddle.partner_name}' (higher conf)", flush=True)
|
||||
else:
|
||||
result.partner_name = tesseract.partner_name
|
||||
result.confidence_vendor = tesseract.confidence_vendor
|
||||
print(f"[Merge] Vendor: Tesseract '{tesseract.partner_name}' (higher conf)", flush=True)
|
||||
elif paddle.partner_name:
|
||||
result.partner_name = paddle.partner_name
|
||||
result.confidence_vendor = paddle.confidence_vendor
|
||||
elif tesseract.partner_name:
|
||||
result.partner_name = tesseract.partner_name
|
||||
result.confidence_vendor = tesseract.confidence_vendor
|
||||
|
||||
# === CUI (Fiscal Code) ===
|
||||
# Validate format: 6-10 digits, prefer valid one
|
||||
paddle_cui_valid = self._is_valid_cui(paddle.cui)
|
||||
tesseract_cui_valid = self._is_valid_cui(tesseract.cui)
|
||||
|
||||
if paddle.cui and tesseract.cui:
|
||||
if paddle_cui_valid and not tesseract_cui_valid:
|
||||
result.cui = paddle.cui
|
||||
print(f"[Merge] CUI: PaddleOCR {paddle.cui} (valid format)", flush=True)
|
||||
elif tesseract_cui_valid and not paddle_cui_valid:
|
||||
result.cui = tesseract.cui
|
||||
print(f"[Merge] CUI: Tesseract {tesseract.cui} (valid format)", flush=True)
|
||||
else:
|
||||
# Both valid or both invalid - prefer PaddleOCR
|
||||
result.cui = paddle.cui
|
||||
print(f"[Merge] CUI: PaddleOCR {paddle.cui}", flush=True)
|
||||
elif paddle.cui and paddle_cui_valid:
|
||||
result.cui = paddle.cui
|
||||
elif tesseract.cui and tesseract_cui_valid:
|
||||
result.cui = tesseract.cui
|
||||
elif paddle.cui:
|
||||
result.cui = paddle.cui
|
||||
elif tesseract.cui:
|
||||
result.cui = tesseract.cui
|
||||
|
||||
# === TVA ENTRIES ===
|
||||
# Prefer non-empty, use the one with more entries or higher amounts
|
||||
if paddle.tva_entries and tesseract.tva_entries:
|
||||
# Compare: prefer the one with actual amounts (not just 0)
|
||||
paddle_total = sum(e.get('amount', Decimal('0')) for e in paddle.tva_entries)
|
||||
tesseract_total = sum(e.get('amount', Decimal('0')) for e in tesseract.tva_entries)
|
||||
|
||||
if paddle_total >= tesseract_total:
|
||||
result.tva_entries = paddle.tva_entries
|
||||
result.tva_total = paddle.tva_total
|
||||
print(f"[Merge] TVA: PaddleOCR (total: {paddle_total})", flush=True)
|
||||
else:
|
||||
result.tva_entries = tesseract.tva_entries
|
||||
result.tva_total = tesseract.tva_total
|
||||
print(f"[Merge] TVA: Tesseract (total: {tesseract_total})", flush=True)
|
||||
elif paddle.tva_entries:
|
||||
result.tva_entries = paddle.tva_entries
|
||||
result.tva_total = paddle.tva_total
|
||||
elif tesseract.tva_entries:
|
||||
result.tva_entries = tesseract.tva_entries
|
||||
result.tva_total = tesseract.tva_total
|
||||
|
||||
# === OTHER FIELDS ===
|
||||
# Simple preference: paddle > tesseract
|
||||
result.receipt_number = paddle.receipt_number or tesseract.receipt_number
|
||||
result.receipt_series = paddle.receipt_series or tesseract.receipt_series
|
||||
result.receipt_type = paddle.receipt_type or tesseract.receipt_type
|
||||
result.items_count = paddle.items_count or tesseract.items_count
|
||||
result.address = paddle.address or tesseract.address
|
||||
result.description = paddle.description or tesseract.description
|
||||
|
||||
return result
|
||||
|
||||
def _has_company_indicator(self, name: Optional[str]) -> bool:
|
||||
"""Check if vendor name has company type indicator (S.R.L., S.A., etc.)"""
|
||||
if not name:
|
||||
return False
|
||||
name_upper = name.upper()
|
||||
indicators = [
|
||||
r'\bS\.?\s*R\.?\s*L\.?\b',
|
||||
r'\bS\.?\s*A\.?\b',
|
||||
r'\bS\.?\s*N\.?\s*C\.?\b',
|
||||
r'\bP\.?\s*F\.?\s*A\.?\b',
|
||||
r'\bI\.?\s*I\.?\b',
|
||||
r'\bHOLDING\b',
|
||||
r'\bGROUP\b',
|
||||
r'\bCOMPANY\b',
|
||||
]
|
||||
for indicator in indicators:
|
||||
if re.search(indicator, name_upper):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_valid_cui(self, cui: Optional[str]) -> bool:
|
||||
"""Validate CUI format: 6-10 digits."""
|
||||
if not cui:
|
||||
return False
|
||||
# Remove any RO prefix
|
||||
cui_clean = re.sub(r'^RO', '', cui.upper())
|
||||
# Must be 6-10 digits
|
||||
return bool(re.match(r'^\d{6,10}$', cui_clean))
|
||||
|
||||
def _is_extraction_complete(self, ext: ExtractionResult, min_confidence: float = 0.85) -> bool:
|
||||
"""
|
||||
Check if extraction has ALL required fields to skip further processing.
|
||||
|
||||
Required for early exit (ALL must be true):
|
||||
- Overall confidence >= 85%
|
||||
- ALL 5 critical fields present: number, date, amount, TVA, CUI
|
||||
"""
|
||||
# Must have high confidence
|
||||
if ext.overall_confidence < min_confidence:
|
||||
print(f"[OCR] Confidence {ext.overall_confidence:.0%} < {min_confidence:.0%} - continuing", flush=True)
|
||||
return False
|
||||
|
||||
# Check all required fields
|
||||
has_number = bool(ext.receipt_number)
|
||||
has_date = bool(ext.receipt_date)
|
||||
has_amount = bool(ext.amount)
|
||||
has_tva = bool(ext.tva_total) or bool(ext.tva_entries)
|
||||
has_cui = bool(ext.cui)
|
||||
|
||||
missing = []
|
||||
if not has_number: missing.append("number")
|
||||
if not has_date: missing.append("date")
|
||||
if not has_amount: missing.append("amount")
|
||||
if not has_tva: missing.append("TVA")
|
||||
if not has_cui: missing.append("CUI")
|
||||
|
||||
if missing:
|
||||
print(f"[OCR] Missing: {', '.join(missing)} - continuing", flush=True)
|
||||
return False
|
||||
|
||||
print(f"[OCR] OK: All 5 fields found with {ext.overall_confidence:.0%} confidence", flush=True)
|
||||
return True
|
||||
|
||||
def _complement_extraction(
|
||||
self,
|
||||
primary: Optional[ExtractionResult],
|
||||
secondary: Optional[ExtractionResult]
|
||||
) -> ExtractionResult:
|
||||
"""
|
||||
Complement primary extraction with missing fields from secondary.
|
||||
NEVER overrides existing values - only fills in gaps.
|
||||
|
||||
This is different from _merge_extractions which can override values.
|
||||
"""
|
||||
if primary is None and secondary is None:
|
||||
return ExtractionResult()
|
||||
if primary is None:
|
||||
return secondary
|
||||
if secondary is None:
|
||||
return primary
|
||||
|
||||
print("[Complement] Adding missing fields from Tesseract...", flush=True)
|
||||
|
||||
# Only fill missing amount
|
||||
if not primary.amount and secondary.amount:
|
||||
primary.amount = secondary.amount
|
||||
primary.confidence_amount = secondary.confidence_amount
|
||||
print(f"[Complement] Added amount: {secondary.amount}", flush=True)
|
||||
|
||||
# Only fill missing date
|
||||
if not primary.receipt_date and secondary.receipt_date:
|
||||
primary.receipt_date = secondary.receipt_date
|
||||
primary.confidence_date = secondary.confidence_date
|
||||
print(f"[Complement] Added date: {secondary.receipt_date}", flush=True)
|
||||
|
||||
# Only fill missing vendor
|
||||
if not primary.partner_name and secondary.partner_name:
|
||||
primary.partner_name = secondary.partner_name
|
||||
primary.confidence_vendor = secondary.confidence_vendor
|
||||
print(f"[Complement] Added vendor: {secondary.partner_name}", flush=True)
|
||||
|
||||
# Only fill missing CUI
|
||||
if not primary.cui and secondary.cui and self._is_valid_cui(secondary.cui):
|
||||
primary.cui = secondary.cui
|
||||
print(f"[Complement] Added CUI: {secondary.cui}", flush=True)
|
||||
|
||||
# Only fill missing TVA
|
||||
if not primary.tva_entries and secondary.tva_entries:
|
||||
primary.tva_entries = secondary.tva_entries
|
||||
primary.tva_total = secondary.tva_total
|
||||
print(f"[Complement] Added TVA: {secondary.tva_total}", flush=True)
|
||||
|
||||
# Only fill missing receipt number
|
||||
if not primary.receipt_number and secondary.receipt_number:
|
||||
primary.receipt_number = secondary.receipt_number
|
||||
print(f"[Complement] Added number: {secondary.receipt_number}", flush=True)
|
||||
|
||||
# Only fill missing address
|
||||
if not primary.address and secondary.address:
|
||||
primary.address = secondary.address
|
||||
print(f"[Complement] Added address: {secondary.address}", flush=True)
|
||||
|
||||
return primary
|
||||
|
||||
def _final_validation(self, extraction: ExtractionResult) -> ExtractionResult:
|
||||
"""
|
||||
Final validation and correction of impossible values.
|
||||
|
||||
Key rules:
|
||||
1. TVA cannot be greater than TOTAL (it's always a fraction)
|
||||
2. If TVA > TOTAL, recalculate TOTAL from TVA using known rates
|
||||
3. Validate TVA entries sum equals TVA total
|
||||
"""
|
||||
print("[Final Validation] Checking extracted values...", flush=True)
|
||||
|
||||
# Rule 1: TVA cannot be greater than TOTAL
|
||||
if extraction.tva_total and extraction.amount:
|
||||
if extraction.tva_total > extraction.amount:
|
||||
print(f"[Final Validation] TVA ({extraction.tva_total}) > TOTAL ({extraction.amount}) - IMPOSSIBLE!", flush=True)
|
||||
|
||||
# Calculate TOTAL from TVA using reverse formula:
|
||||
# total = base + tva = tva * (100/rate + 1) = tva * (100 + rate) / rate
|
||||
# For 9% TVA: total = tva * 109 / 9 = tva * 12.11
|
||||
# For 19% TVA: total = tva * 119 / 19 = tva * 6.26
|
||||
# For 21% TVA: total = tva * 121 / 21 = tva * 5.76
|
||||
|
||||
rate = 19 # Default rate assumption
|
||||
if extraction.tva_entries:
|
||||
# Use the rate from the first entry
|
||||
rate = extraction.tva_entries[0].get('percent', 19)
|
||||
|
||||
if rate > 0:
|
||||
# Formula: total = tva * (100 + rate) / rate
|
||||
calculated_total = extraction.tva_total * (Decimal('100') + Decimal(str(rate))) / Decimal(str(rate))
|
||||
calculated_total = calculated_total.quantize(Decimal('0.01'))
|
||||
|
||||
print(f"[Final Validation] Calculated TOTAL from TVA: {calculated_total} (using {rate}% rate)", flush=True)
|
||||
|
||||
extraction.amount = calculated_total
|
||||
extraction.confidence_amount = 0.70 # Lower confidence for calculated value
|
||||
|
||||
# Rule 2: TVA cannot be more than ~25% of total (max Romanian rate is 21%)
|
||||
if extraction.tva_total and extraction.amount:
|
||||
tva_percent = extraction.tva_total / extraction.amount * Decimal('100')
|
||||
if tva_percent > Decimal('25'):
|
||||
print(f"[Final Validation] Warning: TVA is {tva_percent:.1f}% of total - suspicious", flush=True)
|
||||
|
||||
# Rule 3: Validate TVA entries sum
|
||||
if extraction.tva_entries and extraction.tva_total:
|
||||
entries_sum = sum(e.get('amount', Decimal('0')) for e in extraction.tva_entries)
|
||||
tolerance = Decimal('0.05')
|
||||
if abs(entries_sum - extraction.tva_total) > tolerance:
|
||||
print(f"[Final Validation] TVA entries sum ({entries_sum}) != tva_total ({extraction.tva_total})", flush=True)
|
||||
# Use the sum as it's more reliable
|
||||
extraction.tva_total = entries_sum
|
||||
|
||||
print(f"[Final Validation] Done. Amount={extraction.amount}, TVA={extraction.tva_total}", flush=True)
|
||||
return extraction
|
||||
|
||||
|
||||
# Singleton instance
|
||||
ocr_service = OCRService()
|
||||
@@ -0,0 +1,385 @@
|
||||
"""
|
||||
Auto-create Receipt from OCR results for bulk upload flow.
|
||||
|
||||
This service handles automatic creation of Receipt records from OCR extraction
|
||||
results, enabling end-to-end processing without manual UI intervention.
|
||||
|
||||
The service:
|
||||
1. Maps OCR ExtractionData fields to Receipt fields
|
||||
2. Creates attachment from the original uploaded file
|
||||
3. Generates accounting entries
|
||||
4. Links the receipt back to the batch job for tracking
|
||||
"""
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from backend.modules.data_entry.db.models.receipt import (
|
||||
Receipt,
|
||||
ReceiptAttachment,
|
||||
ReceiptStatus,
|
||||
ReceiptType,
|
||||
ReceiptDirection,
|
||||
)
|
||||
from backend.modules.data_entry.db.models.batch import BatchJob
|
||||
from backend.modules.data_entry.db.crud.receipt import ReceiptCRUD
|
||||
from backend.modules.data_entry.db.crud.accounting_entry import AccountingEntryCRUD
|
||||
from backend.modules.data_entry.schemas.receipt import ReceiptCreate, TvaEntrySchema, PaymentMethodSchema
|
||||
from backend.modules.data_entry.schemas.ocr import ExtractionData
|
||||
from backend.modules.data_entry.services.receipt_service import ReceiptService
|
||||
from backend.modules.data_entry.services import sse_service
|
||||
from backend.config import settings
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReceiptCreateResult:
|
||||
"""Result of auto-create operation."""
|
||||
success: bool
|
||||
receipt_id: Optional[int] = None
|
||||
error_message: Optional[str] = None
|
||||
|
||||
|
||||
class ReceiptAutoCreateService:
|
||||
"""
|
||||
Service for automatically creating receipts from OCR results.
|
||||
|
||||
Used by the bulk upload flow to create receipts without user intervention.
|
||||
Created receipts are in DRAFT status and require review before approval.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _validate_ocr_result(ocr_result: ExtractionData) -> tuple[bool, str]:
|
||||
"""
|
||||
Perform minimal validation on OCR result.
|
||||
|
||||
Validates:
|
||||
- amount > 0 (required for receipt)
|
||||
- date is valid and not in future
|
||||
|
||||
Args:
|
||||
ocr_result: Extracted data from OCR
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
# Validate amount exists and is positive
|
||||
if ocr_result.amount is None:
|
||||
return False, "Amount not extracted from receipt"
|
||||
|
||||
if ocr_result.amount <= 0:
|
||||
return False, f"Invalid amount: {ocr_result.amount} (must be > 0)"
|
||||
|
||||
# Validate date exists and is not in the future
|
||||
if ocr_result.receipt_date is None:
|
||||
return False, "Receipt date not extracted"
|
||||
|
||||
today = date.today()
|
||||
if ocr_result.receipt_date > today:
|
||||
return False, f"Receipt date {ocr_result.receipt_date} is in the future"
|
||||
|
||||
return True, ""
|
||||
|
||||
@staticmethod
|
||||
def _map_ocr_to_receipt(
|
||||
ocr_result: ExtractionData,
|
||||
company_id: int,
|
||||
) -> ReceiptCreate:
|
||||
"""
|
||||
Map OCR ExtractionData fields to ReceiptCreate schema.
|
||||
|
||||
Args:
|
||||
ocr_result: Extracted data from OCR
|
||||
company_id: Company ID for the receipt
|
||||
|
||||
Returns:
|
||||
ReceiptCreate schema ready for database insertion
|
||||
"""
|
||||
# Map receipt type
|
||||
receipt_type = ReceiptType.BON_FISCAL
|
||||
if ocr_result.receipt_type == "chitanta":
|
||||
receipt_type = ReceiptType.CHITANTA
|
||||
|
||||
# Map TVA breakdown from OCR TvaEntry to schema TvaEntrySchema
|
||||
tva_breakdown: Optional[List[TvaEntrySchema]] = None
|
||||
if ocr_result.tva_entries:
|
||||
tva_breakdown = [
|
||||
TvaEntrySchema(
|
||||
code=entry.code,
|
||||
percent=entry.percent,
|
||||
amount=entry.amount
|
||||
)
|
||||
for entry in ocr_result.tva_entries
|
||||
]
|
||||
|
||||
# Map payment methods
|
||||
payment_methods: Optional[List[PaymentMethodSchema]] = None
|
||||
if ocr_result.payment_methods:
|
||||
payment_methods = [
|
||||
PaymentMethodSchema(
|
||||
method=pm.method,
|
||||
amount=pm.amount
|
||||
)
|
||||
for pm in ocr_result.payment_methods
|
||||
]
|
||||
|
||||
# Create receipt data
|
||||
return ReceiptCreate(
|
||||
receipt_type=receipt_type,
|
||||
direction=ReceiptDirection.CHELTUIALA, # Default to expense
|
||||
receipt_number=ocr_result.receipt_number,
|
||||
receipt_series=ocr_result.receipt_series,
|
||||
receipt_date=ocr_result.receipt_date,
|
||||
amount=ocr_result.amount,
|
||||
description=ocr_result.description,
|
||||
tva_breakdown=tva_breakdown,
|
||||
tva_total=ocr_result.tva_total,
|
||||
items_count=ocr_result.items_count,
|
||||
vendor_address=ocr_result.address,
|
||||
company_id=company_id,
|
||||
partner_name=ocr_result.partner_name,
|
||||
cui=ocr_result.cui,
|
||||
ocr_raw_text=ocr_result.raw_text[:5000] if ocr_result.raw_text else None, # Limit size
|
||||
payment_methods=payment_methods,
|
||||
payment_mode=ocr_result.suggested_payment_mode,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _create_attachment_from_file(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
source_file_path: str,
|
||||
original_filename: Optional[str] = None,
|
||||
) -> Optional[ReceiptAttachment]:
|
||||
"""
|
||||
Create attachment by copying file from OCR job location.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
receipt_id: Receipt ID to attach to
|
||||
source_file_path: Path to the original file from OCR job
|
||||
original_filename: Original filename from upload (optional)
|
||||
|
||||
Returns:
|
||||
Created ReceiptAttachment or None if failed
|
||||
"""
|
||||
source_path = Path(source_file_path)
|
||||
|
||||
if not source_path.exists():
|
||||
logger.warning(f"[ReceiptAutoCreate] Source file not found: {source_path}")
|
||||
return None
|
||||
|
||||
# Generate stored filename
|
||||
ext = source_path.suffix.lower()
|
||||
stored_filename = f"{uuid.uuid4()}{ext}"
|
||||
|
||||
# Determine relative path (organized by year/month)
|
||||
now = datetime.utcnow()
|
||||
relative_path = Path(str(now.year)) / f"{now.month:02d}"
|
||||
|
||||
# Full destination path
|
||||
dest_dir = settings.data_entry_upload_path_resolved / relative_path
|
||||
dest_dir.mkdir(parents=True, exist_ok=True)
|
||||
dest_path = dest_dir / stored_filename
|
||||
|
||||
# Copy file to attachments directory
|
||||
try:
|
||||
shutil.copy2(source_path, dest_path)
|
||||
except Exception as e:
|
||||
logger.error(f"[ReceiptAutoCreate] Failed to copy file: {e}")
|
||||
return None
|
||||
|
||||
# Get file size
|
||||
file_size = dest_path.stat().st_size
|
||||
|
||||
# Determine MIME type
|
||||
mime_map = {
|
||||
'.jpg': 'image/jpeg',
|
||||
'.jpeg': 'image/jpeg',
|
||||
'.png': 'image/png',
|
||||
'.pdf': 'application/pdf',
|
||||
}
|
||||
mime_type = mime_map.get(ext, 'application/octet-stream')
|
||||
|
||||
# Use original filename if provided, otherwise use source filename
|
||||
display_filename = original_filename or source_path.name
|
||||
|
||||
# Create attachment record
|
||||
attachment = ReceiptAttachment(
|
||||
receipt_id=receipt_id,
|
||||
filename=display_filename,
|
||||
stored_filename=stored_filename,
|
||||
file_path=str(relative_path / stored_filename),
|
||||
file_size=file_size,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
|
||||
session.add(attachment)
|
||||
await session.flush()
|
||||
|
||||
return attachment
|
||||
|
||||
@staticmethod
|
||||
async def _update_batch_job_receipt_id(
|
||||
session: AsyncSession,
|
||||
job_id: str,
|
||||
receipt_id: int,
|
||||
) -> None:
|
||||
"""
|
||||
Update batch_jobs table with the created receipt_id.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
job_id: OCR job UUID
|
||||
receipt_id: Created receipt ID
|
||||
"""
|
||||
await session.execute(
|
||||
update(BatchJob)
|
||||
.where(BatchJob.job_id == job_id)
|
||||
.values(receipt_id=receipt_id)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_from_ocr_result(
|
||||
session: AsyncSession,
|
||||
job_id: str,
|
||||
ocr_result: ExtractionData,
|
||||
username: str,
|
||||
batch_id: int,
|
||||
company_id: int,
|
||||
file_path: Optional[str] = None,
|
||||
original_filename: Optional[str] = None,
|
||||
file_hash: Optional[str] = None,
|
||||
) -> ReceiptCreateResult:
|
||||
"""
|
||||
Create a receipt from OCR extraction result.
|
||||
|
||||
This method:
|
||||
1. Validates the OCR result (amount > 0, date valid)
|
||||
2. Maps OCR fields to Receipt fields
|
||||
3. Creates the Receipt in DRAFT status
|
||||
4. Creates attachment from original file
|
||||
5. Generates accounting entries
|
||||
6. Updates batch_jobs with receipt_id
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
job_id: OCR job UUID for tracking
|
||||
ocr_result: Extracted data from OCR processing
|
||||
username: User who initiated the upload
|
||||
batch_id: Batch ID for grouping
|
||||
company_id: Company ID for the receipt
|
||||
file_path: Path to the original uploaded file
|
||||
original_filename: Original filename from upload
|
||||
file_hash: SHA-256 hash of the file for duplicate detection (US-007)
|
||||
|
||||
Returns:
|
||||
ReceiptCreateResult with success status and receipt_id or error
|
||||
"""
|
||||
try:
|
||||
# Step 1: Validate OCR result
|
||||
is_valid, error_msg = ReceiptAutoCreateService._validate_ocr_result(ocr_result)
|
||||
if not is_valid:
|
||||
logger.warning(f"[ReceiptAutoCreate] Validation failed for job {job_id}: {error_msg}")
|
||||
return ReceiptCreateResult(
|
||||
success=False,
|
||||
error_message=error_msg
|
||||
)
|
||||
|
||||
# Step 2: Map OCR to Receipt schema
|
||||
receipt_data = ReceiptAutoCreateService._map_ocr_to_receipt(
|
||||
ocr_result=ocr_result,
|
||||
company_id=company_id,
|
||||
)
|
||||
|
||||
# Step 3: Create receipt in DRAFT status
|
||||
receipt = await ReceiptCRUD.create(session, receipt_data, created_by=username)
|
||||
|
||||
# Set batch tracking fields (US-007, US-011)
|
||||
receipt.batch_id = str(batch_id)
|
||||
receipt.file_hash = file_hash
|
||||
receipt.processing_status = "completed"
|
||||
session.add(receipt)
|
||||
await session.flush()
|
||||
|
||||
logger.info(
|
||||
f"[ReceiptAutoCreate] Created receipt {receipt.id} for job {job_id}: "
|
||||
f"amount={receipt.amount}, vendor={receipt.partner_name}, file_hash={file_hash[:16] if file_hash else None}..."
|
||||
)
|
||||
|
||||
# Step 4: Create attachment from original file (if path provided)
|
||||
if file_path:
|
||||
attachment = await ReceiptAutoCreateService._create_attachment_from_file(
|
||||
session=session,
|
||||
receipt_id=receipt.id,
|
||||
source_file_path=file_path,
|
||||
original_filename=original_filename,
|
||||
)
|
||||
if attachment:
|
||||
logger.info(f"[ReceiptAutoCreate] Created attachment for receipt {receipt.id}")
|
||||
else:
|
||||
logger.warning(f"[ReceiptAutoCreate] Failed to create attachment for receipt {receipt.id}")
|
||||
|
||||
# Step 5: Generate accounting entries
|
||||
# Note: For DRAFT status, entries are generated but not required for validation
|
||||
try:
|
||||
entries = ReceiptService.generate_accounting_entries(receipt)
|
||||
if entries:
|
||||
await AccountingEntryCRUD.create_bulk(
|
||||
session, receipt.id, entries, is_auto_generated=True
|
||||
)
|
||||
logger.info(
|
||||
f"[ReceiptAutoCreate] Generated {len(entries)} accounting entries "
|
||||
f"for receipt {receipt.id}"
|
||||
)
|
||||
except Exception as e:
|
||||
# Don't fail the receipt creation if entry generation fails
|
||||
logger.warning(
|
||||
f"[ReceiptAutoCreate] Failed to generate entries for receipt {receipt.id}: {e}"
|
||||
)
|
||||
|
||||
# Step 6: Update batch_jobs with receipt_id
|
||||
await ReceiptAutoCreateService._update_batch_job_receipt_id(
|
||||
session=session,
|
||||
job_id=job_id,
|
||||
receipt_id=receipt.id,
|
||||
)
|
||||
|
||||
# Commit all changes
|
||||
await session.commit()
|
||||
|
||||
# Broadcast SSE event for real-time updates (US-030)
|
||||
try:
|
||||
await sse_service.broadcast_status_change(
|
||||
receipt_id=receipt.id,
|
||||
status=receipt.status.value,
|
||||
processing_status=receipt.processing_status,
|
||||
batch_id=receipt.batch_id,
|
||||
)
|
||||
except Exception as e:
|
||||
# Don't fail the receipt creation if SSE broadcast fails
|
||||
logger.warning(f"[ReceiptAutoCreate] SSE broadcast failed for receipt {receipt.id}: {e}")
|
||||
|
||||
return ReceiptCreateResult(
|
||||
success=True,
|
||||
receipt_id=receipt.id
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[ReceiptAutoCreate] Failed to create receipt for job {job_id}: {e}")
|
||||
await session.rollback()
|
||||
return ReceiptCreateResult(
|
||||
success=False,
|
||||
error_message=str(e)
|
||||
)
|
||||
@@ -0,0 +1,457 @@
|
||||
"""Business logic service for receipts workflow."""
|
||||
|
||||
from decimal import Decimal, ROUND_HALF_UP
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from backend.modules.data_entry.db.models.receipt import Receipt, ReceiptStatus, ReceiptDirection
|
||||
from backend.modules.data_entry.db.models.accounting_entry import EntryType
|
||||
from backend.modules.data_entry.db.crud.receipt import ReceiptCRUD
|
||||
from backend.modules.data_entry.db.crud.accounting_entry import AccountingEntryCRUD
|
||||
from backend.modules.data_entry.schemas.receipt import (
|
||||
ReceiptCreate,
|
||||
ReceiptUpdate,
|
||||
ReceiptFilter,
|
||||
ReceiptResponse,
|
||||
ReceiptListResponse,
|
||||
ProcessingStats,
|
||||
AccountingEntryCreate,
|
||||
)
|
||||
from backend.modules.data_entry.services.expense_types import EXPENSE_TYPES, get_expense_type
|
||||
|
||||
|
||||
# Payment mode to accounting account mapping
|
||||
PAYMENT_MODE_ACCOUNTS = {
|
||||
'casa': ('5311', 'Casa in lei'),
|
||||
'banca': ('5121', 'Conturi la banci in lei'),
|
||||
'avans_decontare': ('542', 'Avansuri de trezorerie'),
|
||||
}
|
||||
|
||||
|
||||
class ReceiptService:
|
||||
"""Service for receipt business logic and workflow."""
|
||||
|
||||
@staticmethod
|
||||
async def create_receipt(
|
||||
session: AsyncSession,
|
||||
data: ReceiptCreate,
|
||||
created_by: str,
|
||||
) -> Receipt:
|
||||
"""Create a new receipt in DRAFT status."""
|
||||
return await ReceiptCRUD.create(session, data, created_by)
|
||||
|
||||
@staticmethod
|
||||
async def get_receipt(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
) -> Optional[Receipt]:
|
||||
"""Get receipt by ID with all relationships."""
|
||||
return await ReceiptCRUD.get_by_id(session, receipt_id, include_relations=True)
|
||||
|
||||
@staticmethod
|
||||
async def get_receipts(
|
||||
session: AsyncSession,
|
||||
filters: ReceiptFilter,
|
||||
) -> ReceiptListResponse:
|
||||
"""Get paginated list of receipts with processing_stats (US-012)."""
|
||||
receipts, total = await ReceiptCRUD.get_list(session, filters)
|
||||
|
||||
pages = (total + filters.page_size - 1) // filters.page_size if total > 0 else 1
|
||||
|
||||
# Get processing stats for bulk uploaded receipts (US-012)
|
||||
stats_dict = await ReceiptCRUD.get_processing_stats(
|
||||
session,
|
||||
company_id=filters.company_id,
|
||||
batch_id=filters.batch_id,
|
||||
)
|
||||
processing_stats = ProcessingStats(**stats_dict)
|
||||
|
||||
return ReceiptListResponse(
|
||||
items=[ReceiptResponse.model_validate(r) for r in receipts],
|
||||
total=total,
|
||||
page=filters.page,
|
||||
page_size=filters.page_size,
|
||||
pages=pages,
|
||||
processing_stats=processing_stats,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_receipt(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
data: ReceiptUpdate,
|
||||
username: str,
|
||||
) -> Tuple[bool, str, Optional[Receipt]]:
|
||||
"""
|
||||
Update receipt (only DRAFT status).
|
||||
Returns (success, message, receipt).
|
||||
"""
|
||||
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
|
||||
|
||||
if not receipt:
|
||||
return False, "Receipt not found", None
|
||||
|
||||
if not await ReceiptCRUD.can_edit(receipt, username):
|
||||
return False, "Cannot edit this receipt", None
|
||||
|
||||
updated = await ReceiptCRUD.update(session, receipt, data)
|
||||
return True, "Receipt updated", updated
|
||||
|
||||
@staticmethod
|
||||
async def delete_receipt(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
username: str,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
Delete receipt (only DRAFT status).
|
||||
Returns (success, message).
|
||||
"""
|
||||
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
|
||||
|
||||
if not receipt:
|
||||
return False, "Receipt not found"
|
||||
|
||||
if not await ReceiptCRUD.can_delete(receipt, username):
|
||||
return False, "Cannot delete this receipt"
|
||||
|
||||
await ReceiptCRUD.delete(session, receipt)
|
||||
return True, "Receipt deleted"
|
||||
|
||||
@staticmethod
|
||||
def generate_accounting_entries(receipt: Receipt) -> List[AccountingEntryCreate]:
|
||||
"""
|
||||
Generate accounting entries based on receipt data and expense type.
|
||||
"""
|
||||
entries: List[AccountingEntryCreate] = []
|
||||
|
||||
# Get expense type configuration
|
||||
expense_type = get_expense_type(receipt.expense_type_code or "OTHER")
|
||||
if not expense_type:
|
||||
expense_type = EXPENSE_TYPES["OTHER"]
|
||||
|
||||
amount = Decimal(str(receipt.amount))
|
||||
|
||||
if receipt.direction == ReceiptDirection.CHELTUIALA:
|
||||
# Expense: Debit expense account, Credit cash/bank
|
||||
if expense_type.has_vat:
|
||||
# Calculate net and VAT
|
||||
vat_rate = expense_type.vat_percent / Decimal("100")
|
||||
net_amount = (amount / (1 + vat_rate)).quantize(
|
||||
Decimal("0.01"), rounding=ROUND_HALF_UP
|
||||
)
|
||||
vat_amount = amount - net_amount
|
||||
|
||||
# Debit: Expense account (net)
|
||||
entries.append(AccountingEntryCreate(
|
||||
entry_type=EntryType.DEBIT,
|
||||
account_code=expense_type.account_code,
|
||||
account_name=expense_type.account_name,
|
||||
amount=net_amount,
|
||||
))
|
||||
|
||||
# Debit: VAT deductible
|
||||
entries.append(AccountingEntryCreate(
|
||||
entry_type=EntryType.DEBIT,
|
||||
account_code=expense_type.vat_account,
|
||||
account_name="TVA deductibila",
|
||||
amount=vat_amount,
|
||||
))
|
||||
else:
|
||||
# No VAT - full amount to expense
|
||||
entries.append(AccountingEntryCreate(
|
||||
entry_type=EntryType.DEBIT,
|
||||
account_code=expense_type.account_code,
|
||||
account_name=expense_type.account_name,
|
||||
amount=amount,
|
||||
))
|
||||
|
||||
# Credit entry - based on payment_mode (new) or cash_register (legacy)
|
||||
if receipt.payment_mode and receipt.payment_mode in PAYMENT_MODE_ACCOUNTS:
|
||||
credit_account, credit_name = PAYMENT_MODE_ACCOUNTS[receipt.payment_mode]
|
||||
elif receipt.cash_register_account:
|
||||
# Backwards compatibility for existing receipts
|
||||
credit_account = receipt.cash_register_account
|
||||
credit_name = receipt.cash_register_name or "Casa/Banca"
|
||||
else:
|
||||
# Default fallback
|
||||
credit_account = "5311"
|
||||
credit_name = "Casa in lei"
|
||||
|
||||
entries.append(AccountingEntryCreate(
|
||||
entry_type=EntryType.CREDIT,
|
||||
account_code=credit_account,
|
||||
account_name=credit_name,
|
||||
amount=amount,
|
||||
))
|
||||
|
||||
else:
|
||||
# Income: Debit cash/bank, Credit income account
|
||||
# Based on payment_mode (new) or cash_register (legacy)
|
||||
if receipt.payment_mode and receipt.payment_mode in PAYMENT_MODE_ACCOUNTS:
|
||||
cash_account, cash_name = PAYMENT_MODE_ACCOUNTS[receipt.payment_mode]
|
||||
elif receipt.cash_register_account:
|
||||
cash_account = receipt.cash_register_account
|
||||
cash_name = receipt.cash_register_name or "Casa/Banca"
|
||||
else:
|
||||
cash_account = "5311"
|
||||
cash_name = "Casa in lei"
|
||||
|
||||
# Debit: Cash/Bank
|
||||
entries.append(AccountingEntryCreate(
|
||||
entry_type=EntryType.DEBIT,
|
||||
account_code=cash_account,
|
||||
account_name=cash_name,
|
||||
amount=amount,
|
||||
))
|
||||
|
||||
# Credit: Income account (7xx - to be configured)
|
||||
entries.append(AccountingEntryCreate(
|
||||
entry_type=EntryType.CREDIT,
|
||||
account_code="7588",
|
||||
account_name="Alte venituri din exploatare",
|
||||
amount=amount,
|
||||
))
|
||||
|
||||
return entries
|
||||
|
||||
@staticmethod
|
||||
async def submit_for_review(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
username: str,
|
||||
) -> Tuple[bool, str, Optional[Receipt]]:
|
||||
"""
|
||||
Submit receipt for review (DRAFT/REJECTED → PENDING_REVIEW).
|
||||
Generates accounting entries automatically.
|
||||
"""
|
||||
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
|
||||
|
||||
if not receipt:
|
||||
return False, "Receipt not found", None
|
||||
|
||||
if not await ReceiptCRUD.can_submit(receipt, username):
|
||||
return False, "Cannot submit this receipt", None
|
||||
|
||||
# Check if receipt has at least one attachment
|
||||
if not receipt.attachments:
|
||||
return False, "Receipt must have at least one attachment", None
|
||||
|
||||
# Check required fields
|
||||
if not receipt.expense_type_code:
|
||||
return False, "Expense type is required", None
|
||||
|
||||
# Validate payment_mode or cash_register (backwards compatibility)
|
||||
if not receipt.payment_mode and not receipt.cash_register_account:
|
||||
return False, "Modul de plata este obligatoriu", None
|
||||
|
||||
# Generate accounting entries
|
||||
entries = ReceiptService.generate_accounting_entries(receipt)
|
||||
|
||||
# Delete existing entries and create new ones
|
||||
await AccountingEntryCRUD.delete_all_for_receipt(session, receipt_id)
|
||||
await AccountingEntryCRUD.create_bulk(session, receipt_id, entries, is_auto_generated=True)
|
||||
|
||||
# Refresh receipt to clear stale relationship references after entry deletion
|
||||
await session.refresh(receipt)
|
||||
|
||||
# Update status
|
||||
updated = await ReceiptCRUD.update_status(
|
||||
session, receipt, ReceiptStatus.PENDING_REVIEW
|
||||
)
|
||||
|
||||
# Reload with entries
|
||||
updated = await ReceiptCRUD.get_by_id(session, receipt_id)
|
||||
|
||||
return True, "Receipt submitted for review", updated
|
||||
|
||||
@staticmethod
|
||||
async def approve_receipt(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
username: str,
|
||||
) -> Tuple[bool, str, Optional[Receipt]]:
|
||||
"""
|
||||
Approve receipt (PENDING_REVIEW → APPROVED).
|
||||
Requires valid CUI (fiscal code) for approval.
|
||||
"""
|
||||
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
|
||||
|
||||
if not receipt:
|
||||
return False, "Receipt not found", None
|
||||
|
||||
if receipt.status != ReceiptStatus.PENDING_REVIEW:
|
||||
return False, "Receipt is not pending review", None
|
||||
|
||||
# Validate CUI is present (required for Oracle import)
|
||||
if not receipt.cui:
|
||||
return False, "Trebuie completat codul fiscal (CUI) pentru aprobare", None
|
||||
|
||||
# Validate accounting entries
|
||||
if not receipt.entries:
|
||||
return False, "Receipt has no accounting entries", None
|
||||
|
||||
# Update status
|
||||
updated = await ReceiptCRUD.update_status(
|
||||
session, receipt, ReceiptStatus.APPROVED, reviewed_by=username
|
||||
)
|
||||
|
||||
return True, "Receipt approved", updated
|
||||
|
||||
@staticmethod
|
||||
async def unapprove_receipt(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
username: str,
|
||||
) -> Tuple[bool, str, Optional[Receipt]]:
|
||||
"""
|
||||
Unapprove receipt (APPROVED → PENDING_REVIEW).
|
||||
Returns receipt to pending review for corrections.
|
||||
"""
|
||||
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
|
||||
|
||||
if not receipt:
|
||||
return False, "Receipt not found", None
|
||||
|
||||
if receipt.status != ReceiptStatus.APPROVED:
|
||||
return False, "Receipt is not approved", None
|
||||
|
||||
# Update status back to pending review
|
||||
updated = await ReceiptCRUD.update_status(
|
||||
session, receipt, ReceiptStatus.PENDING_REVIEW
|
||||
)
|
||||
|
||||
return True, "Receipt returned to pending review", updated
|
||||
|
||||
@staticmethod
|
||||
async def reject_receipt(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
username: str,
|
||||
reason: str,
|
||||
) -> Tuple[bool, str, Optional[Receipt]]:
|
||||
"""
|
||||
Reject receipt (PENDING_REVIEW → REJECTED).
|
||||
"""
|
||||
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
|
||||
|
||||
if not receipt:
|
||||
return False, "Receipt not found", None
|
||||
|
||||
if receipt.status != ReceiptStatus.PENDING_REVIEW:
|
||||
return False, "Receipt is not pending review", None
|
||||
|
||||
# Update status
|
||||
updated = await ReceiptCRUD.update_status(
|
||||
session,
|
||||
receipt,
|
||||
ReceiptStatus.REJECTED,
|
||||
reviewed_by=username,
|
||||
rejection_reason=reason,
|
||||
)
|
||||
|
||||
return True, "Receipt rejected", updated
|
||||
|
||||
@staticmethod
|
||||
async def resubmit_receipt(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
username: str,
|
||||
) -> Tuple[bool, str, Optional[Receipt]]:
|
||||
"""
|
||||
Resubmit rejected receipt after corrections (REJECTED → PENDING_REVIEW).
|
||||
"""
|
||||
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
|
||||
|
||||
if not receipt:
|
||||
return False, "Receipt not found", None
|
||||
|
||||
if receipt.status != ReceiptStatus.REJECTED:
|
||||
return False, "Receipt is not rejected", None
|
||||
|
||||
if receipt.created_by != username:
|
||||
return False, "Only the creator can resubmit", None
|
||||
|
||||
# Re-generate accounting entries
|
||||
entries = ReceiptService.generate_accounting_entries(receipt)
|
||||
await AccountingEntryCRUD.delete_all_for_receipt(session, receipt_id)
|
||||
await AccountingEntryCRUD.create_bulk(session, receipt_id, entries, is_auto_generated=True)
|
||||
|
||||
# Refresh receipt to clear stale relationship references after entry deletion
|
||||
await session.refresh(receipt)
|
||||
|
||||
# Update status
|
||||
updated = await ReceiptCRUD.update_status(
|
||||
session, receipt, ReceiptStatus.PENDING_REVIEW
|
||||
)
|
||||
|
||||
# Reload with entries
|
||||
updated = await ReceiptCRUD.get_by_id(session, receipt_id)
|
||||
|
||||
return True, "Receipt resubmitted for review", updated
|
||||
|
||||
@staticmethod
|
||||
async def regenerate_entries(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
username: str,
|
||||
) -> Tuple[bool, str, List[AccountingEntryCreate]]:
|
||||
"""
|
||||
Regenerate accounting entries for a receipt.
|
||||
"""
|
||||
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
|
||||
|
||||
if not receipt:
|
||||
return False, "Receipt not found", []
|
||||
|
||||
if receipt.status not in [ReceiptStatus.DRAFT, ReceiptStatus.PENDING_REVIEW]:
|
||||
return False, "Cannot regenerate entries for this receipt status", []
|
||||
|
||||
# Generate new entries
|
||||
entries = ReceiptService.generate_accounting_entries(receipt)
|
||||
|
||||
# Replace existing entries
|
||||
await AccountingEntryCRUD.delete_all_for_receipt(session, receipt_id)
|
||||
await AccountingEntryCRUD.create_bulk(session, receipt_id, entries, is_auto_generated=True)
|
||||
|
||||
return True, "Entries regenerated", entries
|
||||
|
||||
@staticmethod
|
||||
async def update_entries(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
entries: List[AccountingEntryCreate],
|
||||
username: str,
|
||||
) -> Tuple[bool, str, List]:
|
||||
"""
|
||||
Update accounting entries for a receipt (accountant action).
|
||||
"""
|
||||
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
|
||||
|
||||
if not receipt:
|
||||
return False, "Receipt not found", []
|
||||
|
||||
if receipt.status != ReceiptStatus.PENDING_REVIEW:
|
||||
return False, "Can only modify entries for receipts pending review", []
|
||||
|
||||
# Validate entries
|
||||
is_valid, error = await AccountingEntryCRUD.validate_entries(entries)
|
||||
if not is_valid:
|
||||
return False, error, []
|
||||
|
||||
# Replace entries
|
||||
updated_entries = await AccountingEntryCRUD.replace_all_for_receipt(
|
||||
session, receipt_id, entries, username
|
||||
)
|
||||
|
||||
return True, "Entries updated", updated_entries
|
||||
|
||||
@staticmethod
|
||||
async def get_pending_count(
|
||||
session: AsyncSession,
|
||||
company_id: Optional[int] = None,
|
||||
) -> int:
|
||||
"""Get count of receipts pending review."""
|
||||
receipts = await ReceiptCRUD.get_pending_review(session, company_id)
|
||||
return len(receipts)
|
||||
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
Server-Sent Events (SSE) service for real-time status updates.
|
||||
|
||||
This module implements an event broadcaster pattern using asyncio.Queue per client.
|
||||
When receipt status changes occur (CRUD operations), events are pushed to all
|
||||
connected clients who are listening for that specific batch or all receipts.
|
||||
|
||||
Usage:
|
||||
# In router endpoint (SSE stream):
|
||||
async for event in sse_service.subscribe(batch_id=None):
|
||||
yield event
|
||||
|
||||
# When status changes (from CRUD operations):
|
||||
await sse_service.broadcast_status_change(receipt_id, status, processing_status, batch_id)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import AsyncGenerator, Optional
|
||||
from datetime import datetime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StatusChangeEvent:
|
||||
"""Event data for receipt status changes."""
|
||||
receipt_id: int
|
||||
status: str
|
||||
processing_status: Optional[str] = None
|
||||
batch_id: Optional[str] = None
|
||||
timestamp: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.timestamp is None:
|
||||
self.timestamp = datetime.utcnow().isoformat()
|
||||
|
||||
def to_sse_data(self) -> str:
|
||||
"""Format as SSE data line."""
|
||||
data = asdict(self)
|
||||
return f"data: {json.dumps(data)}\n\n"
|
||||
|
||||
|
||||
class SSEEventBroadcaster:
|
||||
"""
|
||||
Manages SSE client connections and broadcasts events.
|
||||
|
||||
Each client gets its own asyncio.Queue. When an event occurs,
|
||||
it's pushed to all relevant queues based on batch_id filtering.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Dict of {client_id: (queue, batch_id_filter)}
|
||||
# batch_id_filter is None for clients that want all events
|
||||
self._clients: dict[str, tuple[asyncio.Queue, Optional[str]]] = {}
|
||||
self._client_counter = 0
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def _generate_client_id(self) -> str:
|
||||
"""Generate unique client ID."""
|
||||
async with self._lock:
|
||||
self._client_counter += 1
|
||||
return f"client_{self._client_counter}_{datetime.utcnow().timestamp()}"
|
||||
|
||||
async def subscribe(
|
||||
self,
|
||||
batch_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Subscribe to SSE events.
|
||||
|
||||
Args:
|
||||
batch_id: Optional filter - only receive events for this batch.
|
||||
If None, receives all events.
|
||||
|
||||
Yields:
|
||||
SSE-formatted event strings (ready to send to client).
|
||||
"""
|
||||
client_id = await self._generate_client_id()
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
# Register client
|
||||
async with self._lock:
|
||||
self._clients[client_id] = (queue, batch_id)
|
||||
|
||||
logger.info(
|
||||
f"SSE client {client_id} connected (batch_id filter: {batch_id}). "
|
||||
f"Total clients: {len(self._clients)}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Send initial retry hint for reconnection
|
||||
yield "retry: 3000\n\n"
|
||||
|
||||
# Keep connection alive and yield events
|
||||
while True:
|
||||
try:
|
||||
# Wait for events with timeout for keep-alive
|
||||
event = await asyncio.wait_for(queue.get(), timeout=30.0)
|
||||
yield event
|
||||
except asyncio.TimeoutError:
|
||||
# Send keep-alive comment to prevent connection timeout
|
||||
yield ": keep-alive\n\n"
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"SSE client {client_id} subscription cancelled")
|
||||
raise
|
||||
finally:
|
||||
# Cleanup: remove client from registry
|
||||
async with self._lock:
|
||||
self._clients.pop(client_id, None)
|
||||
logger.info(
|
||||
f"SSE client {client_id} disconnected. "
|
||||
f"Remaining clients: {len(self._clients)}"
|
||||
)
|
||||
|
||||
async def broadcast_status_change(
|
||||
self,
|
||||
receipt_id: int,
|
||||
status: str,
|
||||
processing_status: Optional[str] = None,
|
||||
batch_id: Optional[str] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Broadcast a status change event to all relevant clients.
|
||||
|
||||
Args:
|
||||
receipt_id: The receipt ID that changed.
|
||||
status: New workflow status (DRAFT, PENDING_REVIEW, etc.).
|
||||
processing_status: New processing status (pending, processing, completed, failed).
|
||||
batch_id: The batch ID this receipt belongs to (for filtering).
|
||||
|
||||
Returns:
|
||||
Number of clients notified.
|
||||
"""
|
||||
event = StatusChangeEvent(
|
||||
receipt_id=receipt_id,
|
||||
status=status,
|
||||
processing_status=processing_status,
|
||||
batch_id=batch_id,
|
||||
)
|
||||
sse_data = event.to_sse_data()
|
||||
|
||||
notified = 0
|
||||
async with self._lock:
|
||||
for client_id, (queue, client_batch_filter) in self._clients.items():
|
||||
# Send event if:
|
||||
# 1. Client has no filter (wants all events), OR
|
||||
# 2. Client's filter matches the event's batch_id
|
||||
if client_batch_filter is None or client_batch_filter == batch_id:
|
||||
try:
|
||||
queue.put_nowait(sse_data)
|
||||
notified += 1
|
||||
except asyncio.QueueFull:
|
||||
logger.warning(
|
||||
f"SSE queue full for client {client_id}, dropping event"
|
||||
)
|
||||
|
||||
if notified > 0:
|
||||
logger.debug(
|
||||
f"SSE broadcast: receipt_id={receipt_id}, status={status}, "
|
||||
f"processing_status={processing_status}, notified={notified} clients"
|
||||
)
|
||||
|
||||
return notified
|
||||
|
||||
@property
|
||||
def client_count(self) -> int:
|
||||
"""Get current number of connected clients."""
|
||||
return len(self._clients)
|
||||
|
||||
|
||||
# Singleton instance for the application
|
||||
sse_broadcaster = SSEEventBroadcaster()
|
||||
|
||||
|
||||
# Convenience functions for external use
|
||||
async def subscribe(batch_id: Optional[str] = None) -> AsyncGenerator[str, None]:
|
||||
"""Subscribe to SSE status change events."""
|
||||
async for event in sse_broadcaster.subscribe(batch_id):
|
||||
yield event
|
||||
|
||||
|
||||
async def broadcast_status_change(
|
||||
receipt_id: int,
|
||||
status: str,
|
||||
processing_status: Optional[str] = None,
|
||||
batch_id: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Broadcast a status change event."""
|
||||
return await sse_broadcaster.broadcast_status_change(
|
||||
receipt_id=receipt_id,
|
||||
status=status,
|
||||
processing_status=processing_status,
|
||||
batch_id=batch_id,
|
||||
)
|
||||
@@ -0,0 +1,451 @@
|
||||
"""Service for syncing nomenclatures from Oracle to SQLite."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Tuple
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
# Path setup handled by main.py - this is redundant
|
||||
# project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||
# sys.path.insert(0, str(project_root / "shared"))
|
||||
|
||||
from shared.database.oracle_pool import oracle_pool
|
||||
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier, SyncedCashRegister
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache for schema lookups (populated dynamically from Oracle)
|
||||
# Key format: (server_id, company_id) for multi-server support
|
||||
_schema_cache: dict[tuple, str] = {}
|
||||
|
||||
|
||||
class SyncService:
|
||||
"""Service for syncing nomenclatures from Oracle."""
|
||||
|
||||
@staticmethod
|
||||
async def get_schema_for_company(company_id: int, server_id: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
Get Oracle schema for company ID from V_NOM_FIRME view.
|
||||
Results are cached in memory for performance.
|
||||
|
||||
Args:
|
||||
company_id: The company ID to look up
|
||||
server_id: Optional Oracle server ID for multi-server mode
|
||||
"""
|
||||
# Check cache first - use (server_id, company_id) as key for multi-server support
|
||||
cache_key = (server_id, company_id)
|
||||
if cache_key in _schema_cache:
|
||||
return _schema_cache[cache_key]
|
||||
|
||||
try:
|
||||
async with oracle_pool.get_connection(server_id) as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute("""
|
||||
SELECT SCHEMA
|
||||
FROM CONTAFIN_ORACLE.V_NOM_FIRME
|
||||
WHERE ID_FIRMA = :company_id
|
||||
""", {'company_id': company_id})
|
||||
result = cursor.fetchone()
|
||||
|
||||
if result:
|
||||
schema = result[0]
|
||||
_schema_cache[cache_key] = schema
|
||||
logger.info(f"Resolved schema for company {company_id} on server {server_id}: {schema}")
|
||||
return schema
|
||||
else:
|
||||
logger.warning(f"No schema found for company {company_id} on server {server_id}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching schema for company {company_id} on server {server_id}: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def sync_suppliers(session: AsyncSession, company_id: int, server_id: Optional[str] = None) -> Tuple[int, int]:
|
||||
"""
|
||||
Sync suppliers (furnizori, id_tip_part=17) from Oracle to SQLite.
|
||||
Uses CORESP_TIP_PART joined with VNOM_PARTENERI view.
|
||||
Returns (synced_count, error_count).
|
||||
|
||||
Args:
|
||||
session: SQLAlchemy async session for SQLite
|
||||
company_id: The company ID to sync suppliers for
|
||||
server_id: Optional Oracle server ID for multi-server mode
|
||||
"""
|
||||
schema = await SyncService.get_schema_for_company(company_id, server_id)
|
||||
if not schema:
|
||||
logger.warning(f"No schema mapping for company {company_id} on server {server_id}")
|
||||
return 0, 0
|
||||
|
||||
synced = 0
|
||||
errors = 0
|
||||
|
||||
try:
|
||||
async with oracle_pool.get_connection(server_id) as connection:
|
||||
with connection.cursor() as cursor:
|
||||
# Fetch active suppliers from Oracle
|
||||
# id_tip_part = 17 means "furnizori" (suppliers)
|
||||
# Using CORESP_TIP_PART to filter by partner type
|
||||
cursor.execute(f"""
|
||||
SELECT B.ID_PART, B.DENUMIRE, B.COD_FISCAL, B.ADRESA
|
||||
FROM {schema}.CORESP_TIP_PART A
|
||||
INNER JOIN {schema}.VNOM_PARTENERI B ON A.ID_PART = B.ID_PART
|
||||
WHERE A.ID_TIP_PART = 17
|
||||
AND (B.INACTIV = 0 OR B.INACTIV IS NULL)
|
||||
AND B.ID_PART IS NOT NULL
|
||||
ORDER BY B.DENUMIRE
|
||||
""")
|
||||
rows = cursor.fetchall()
|
||||
|
||||
for row in rows:
|
||||
try:
|
||||
oracle_id, name, fiscal_code, address = row
|
||||
|
||||
# Check if already exists
|
||||
stmt = select(SyncedSupplier).where(
|
||||
SyncedSupplier.oracle_id == oracle_id,
|
||||
SyncedSupplier.company_id == company_id
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
# Update existing record
|
||||
existing.name = name or ""
|
||||
existing.fiscal_code = fiscal_code
|
||||
existing.address = address
|
||||
existing.synced_at = datetime.utcnow()
|
||||
logger.debug(f"Updated supplier {oracle_id}: {name}")
|
||||
else:
|
||||
# Create new record
|
||||
supplier = SyncedSupplier(
|
||||
oracle_id=oracle_id,
|
||||
company_id=company_id,
|
||||
name=name or "",
|
||||
fiscal_code=fiscal_code,
|
||||
address=address,
|
||||
)
|
||||
session.add(supplier)
|
||||
logger.debug(f"Created supplier {oracle_id}: {name}")
|
||||
|
||||
synced += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing supplier row {row}: {e}")
|
||||
errors += 1
|
||||
|
||||
# Commit all changes
|
||||
await session.commit()
|
||||
logger.info(f"Synced {synced} suppliers for company {company_id}, {errors} errors")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error syncing suppliers for company {company_id}: {e}")
|
||||
errors += 1
|
||||
await session.rollback()
|
||||
|
||||
return synced, errors
|
||||
|
||||
@staticmethod
|
||||
async def sync_cash_registers(session: AsyncSession, company_id: int, server_id: Optional[str] = None) -> Tuple[int, int]:
|
||||
"""
|
||||
Sync cash registers and bank accounts from Oracle to SQLite.
|
||||
Returns (synced_count, error_count).
|
||||
|
||||
Uses CORESP_TIP_PART with:
|
||||
- id_tip_part = 22: CASA LEI
|
||||
- id_tip_part = 23: CASA VALUTA
|
||||
- id_tip_part = 24: BANCA LEI
|
||||
- id_tip_part = 25: BANCA VALUTA
|
||||
|
||||
Args:
|
||||
session: SQLAlchemy async session for SQLite
|
||||
company_id: The company ID to sync cash registers for
|
||||
server_id: Optional Oracle server ID for multi-server mode
|
||||
"""
|
||||
schema = await SyncService.get_schema_for_company(company_id, server_id)
|
||||
if not schema:
|
||||
logger.warning(f"No schema mapping for company {company_id} on server {server_id}")
|
||||
return 0, 0
|
||||
|
||||
synced = 0
|
||||
errors = 0
|
||||
|
||||
# Partner types mapping
|
||||
# 22=CASA LEI, 23=CASA VALUTA -> cash
|
||||
# 24=BANCA LEI, 25=BANCA VALUTA -> bank
|
||||
partner_types = [22, 23, 24, 25]
|
||||
|
||||
try:
|
||||
async with oracle_pool.get_connection(server_id) as connection:
|
||||
with connection.cursor() as cursor:
|
||||
# Fetch cash/bank partners from CORESP_TIP_PART
|
||||
cursor.execute(f"""
|
||||
SELECT B.ID_PART, B.DENUMIRE, A.ID_TIP_PART
|
||||
FROM {schema}.CORESP_TIP_PART A
|
||||
INNER JOIN {schema}.VNOM_PARTENERI B ON A.ID_PART = B.ID_PART
|
||||
WHERE A.ID_TIP_PART IN (22, 23, 24, 25)
|
||||
AND (B.INACTIV = 0 OR B.INACTIV IS NULL)
|
||||
AND B.ID_PART IS NOT NULL
|
||||
ORDER BY A.ID_TIP_PART, B.DENUMIRE
|
||||
""")
|
||||
rows = cursor.fetchall()
|
||||
|
||||
# Type mapping: 22=CASA LEI, 23=CASA VALUTA -> cash; 24=BANCA LEI, 25=BANCA VALUTA -> bank
|
||||
type_mapping = {
|
||||
22: ("cash", "CASA_LEI"),
|
||||
23: ("cash", "CASA_VALUTA"),
|
||||
24: ("bank", "BANCA_LEI"),
|
||||
25: ("bank", "BANCA_VALUTA"),
|
||||
}
|
||||
|
||||
for row in rows:
|
||||
try:
|
||||
oracle_id, name, tip_part_id = row
|
||||
|
||||
# Determine type based on partner type
|
||||
register_type, account_code = type_mapping.get(tip_part_id, ("cash", "UNKNOWN"))
|
||||
|
||||
# Check if already exists
|
||||
stmt = select(SyncedCashRegister).where(
|
||||
SyncedCashRegister.oracle_id == oracle_id,
|
||||
SyncedCashRegister.company_id == company_id
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
# Update existing record
|
||||
existing.name = name or ""
|
||||
existing.account_code = account_code
|
||||
existing.register_type = register_type
|
||||
existing.synced_at = datetime.utcnow()
|
||||
logger.debug(f"Updated cash register {oracle_id}: {name}")
|
||||
else:
|
||||
# Create new record
|
||||
cash_register = SyncedCashRegister(
|
||||
oracle_id=oracle_id,
|
||||
company_id=company_id,
|
||||
name=name or "",
|
||||
account_code=account_code,
|
||||
register_type=register_type,
|
||||
)
|
||||
session.add(cash_register)
|
||||
logger.debug(f"Created cash register {oracle_id}: {name}")
|
||||
|
||||
synced += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing cash register row {row}: {e}")
|
||||
errors += 1
|
||||
|
||||
# Commit all changes
|
||||
await session.commit()
|
||||
logger.info(f"Synced {synced} cash registers for company {company_id}, {errors} errors")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error syncing cash registers for company {company_id}: {e}")
|
||||
errors += 1
|
||||
await session.rollback()
|
||||
|
||||
return synced, errors
|
||||
|
||||
@staticmethod
|
||||
def _get_fiscal_code_variants(fiscal_code: str) -> list:
|
||||
"""
|
||||
Generate all possible variants of a Romanian fiscal code (CUI).
|
||||
Database may store: "22891860", "RO22891860", "RO 22891860"
|
||||
OCR may extract: "RO22891860" or "22891860"
|
||||
"""
|
||||
import re
|
||||
# Extract just the digits
|
||||
digits = re.sub(r'[^0-9]', '', fiscal_code)
|
||||
if not digits:
|
||||
return [fiscal_code]
|
||||
|
||||
# Generate all variants
|
||||
variants = [
|
||||
digits, # Just digits: 22891860
|
||||
f"RO{digits}", # With RO prefix: RO22891860
|
||||
f"RO {digits}", # With RO prefix and space: RO 22891860
|
||||
]
|
||||
# Also add the original if different
|
||||
if fiscal_code not in variants:
|
||||
variants.append(fiscal_code)
|
||||
|
||||
return variants
|
||||
|
||||
@staticmethod
|
||||
async def search_supplier(
|
||||
session: AsyncSession,
|
||||
company_id: int,
|
||||
fiscal_code: Optional[str] = None,
|
||||
name: Optional[str] = None
|
||||
) -> Tuple[bool, Optional[dict], str]:
|
||||
"""
|
||||
Search for supplier in SQLite first, then Oracle if not found.
|
||||
Returns (found, supplier_data, source).
|
||||
Source can be: 'synced', 'local', 'not_found'
|
||||
"""
|
||||
# 1. Search in synced suppliers
|
||||
if fiscal_code:
|
||||
# Search all variants of the fiscal code (with/without RO, with/without space)
|
||||
variants = SyncService._get_fiscal_code_variants(fiscal_code)
|
||||
stmt = select(SyncedSupplier).where(
|
||||
SyncedSupplier.company_id == company_id,
|
||||
SyncedSupplier.fiscal_code.in_(variants)
|
||||
)
|
||||
elif name:
|
||||
stmt = select(SyncedSupplier).where(
|
||||
SyncedSupplier.company_id == company_id,
|
||||
SyncedSupplier.name.ilike(f"%{name}%")
|
||||
)
|
||||
else:
|
||||
return False, None, "no_query"
|
||||
|
||||
result = await session.execute(stmt)
|
||||
supplier = result.scalar_one_or_none()
|
||||
|
||||
if supplier:
|
||||
# Return only text data - no IDs needed for autocomplete
|
||||
return True, {
|
||||
"name": supplier.name,
|
||||
"fiscal_code": supplier.fiscal_code,
|
||||
"address": supplier.address,
|
||||
}, "synced"
|
||||
|
||||
# 2. Search in local suppliers
|
||||
if fiscal_code:
|
||||
# Search all variants of the fiscal code (with/without RO, with/without space)
|
||||
variants = SyncService._get_fiscal_code_variants(fiscal_code)
|
||||
stmt = select(LocalSupplier).where(
|
||||
LocalSupplier.company_id == company_id,
|
||||
LocalSupplier.fiscal_code.in_(variants)
|
||||
)
|
||||
elif name:
|
||||
stmt = select(LocalSupplier).where(
|
||||
LocalSupplier.company_id == company_id,
|
||||
LocalSupplier.name.ilike(f"%{name}%")
|
||||
)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
local = result.scalar_one_or_none()
|
||||
|
||||
if local:
|
||||
# Return only text data - no IDs needed for autocomplete
|
||||
return True, {
|
||||
"name": local.name,
|
||||
"fiscal_code": local.fiscal_code,
|
||||
"address": local.address,
|
||||
}, "local"
|
||||
|
||||
# 3. Try live Oracle search (optional fallback for unsynced data)
|
||||
# This is a fallback - ideally sync should be up to date
|
||||
# TODO: Implement live Oracle search if needed
|
||||
|
||||
return False, None, "not_found"
|
||||
|
||||
@staticmethod
|
||||
async def create_local_supplier(
|
||||
session: AsyncSession,
|
||||
company_id: int,
|
||||
name: str,
|
||||
fiscal_code: Optional[str],
|
||||
address: Optional[str],
|
||||
created_by: str
|
||||
) -> LocalSupplier:
|
||||
"""Create a local supplier entry from OCR data."""
|
||||
supplier = LocalSupplier(
|
||||
company_id=company_id,
|
||||
name=name,
|
||||
fiscal_code=fiscal_code,
|
||||
address=address,
|
||||
created_by=created_by,
|
||||
)
|
||||
session.add(supplier)
|
||||
await session.commit()
|
||||
await session.refresh(supplier)
|
||||
logger.info(f"Created local supplier: {name} (CUI: {fiscal_code})")
|
||||
return supplier
|
||||
|
||||
@staticmethod
|
||||
async def get_all_suppliers(
|
||||
session: AsyncSession,
|
||||
company_id: int,
|
||||
search: Optional[str] = None
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Get all suppliers (synced + local) for a company.
|
||||
Used for dropdown/autocomplete in UI.
|
||||
"""
|
||||
suppliers = []
|
||||
|
||||
# Get synced suppliers
|
||||
stmt = select(SyncedSupplier).where(SyncedSupplier.company_id == company_id)
|
||||
if search:
|
||||
stmt = stmt.where(
|
||||
(SyncedSupplier.name.ilike(f"%{search}%")) |
|
||||
(SyncedSupplier.fiscal_code.ilike(f"%{search}%"))
|
||||
)
|
||||
stmt = stmt.limit(50) # Limit results for performance
|
||||
|
||||
result = await session.execute(stmt)
|
||||
synced = result.scalars().all()
|
||||
|
||||
for s in synced:
|
||||
suppliers.append({
|
||||
"id": s.id,
|
||||
"oracle_id": s.oracle_id,
|
||||
"name": s.name,
|
||||
"fiscal_code": s.fiscal_code,
|
||||
"source": "synced"
|
||||
})
|
||||
|
||||
# Get local suppliers
|
||||
stmt = select(LocalSupplier).where(LocalSupplier.company_id == company_id)
|
||||
if search:
|
||||
stmt = stmt.where(
|
||||
(LocalSupplier.name.ilike(f"%{search}%")) |
|
||||
(LocalSupplier.fiscal_code.ilike(f"%{search}%"))
|
||||
)
|
||||
stmt = stmt.limit(50)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
local = result.scalars().all()
|
||||
|
||||
for l in local:
|
||||
suppliers.append({
|
||||
"id": l.id,
|
||||
"name": l.name,
|
||||
"fiscal_code": l.fiscal_code,
|
||||
"source": "local"
|
||||
})
|
||||
|
||||
return suppliers
|
||||
|
||||
@staticmethod
|
||||
async def get_all_cash_registers(
|
||||
session: AsyncSession,
|
||||
company_id: int
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Get all cash registers for a company.
|
||||
Used for dropdown in UI.
|
||||
"""
|
||||
stmt = select(SyncedCashRegister).where(SyncedCashRegister.company_id == company_id)
|
||||
result = await session.execute(stmt)
|
||||
registers = result.scalars().all()
|
||||
|
||||
return [
|
||||
{
|
||||
"id": r.id,
|
||||
"oracle_id": r.oracle_id,
|
||||
"name": r.name,
|
||||
"account_code": r.account_code,
|
||||
"register_type": r.register_type
|
||||
}
|
||||
for r in registers
|
||||
]
|
||||
Reference in New Issue
Block a user