fix telegram

This commit is contained in:
Claude Agent
2026-02-23 15:12:33 +00:00
parent 6c78fec8a7
commit 8bc567a9c5
426 changed files with 112478 additions and 1 deletions

View File

@@ -0,0 +1,94 @@
# Alembic configuration for Data Entry module
[alembic]
# path to migration scripts
script_location = migrations
# template used to generate migration files
file_template = %%(year)d%%(month).2d%%(day).2d_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# string value is passed to dateutil.tz.gettz()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to migrations/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions
# version path separator
# version_path_separator = :
# set to 'true' to search source files recursively
# in each "version_locations" directory
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# SQLite database URL - will be overridden by env.py using SQLITE_DATABASE_PATH env var
sqlalchemy.url = sqlite:///data/receipts/receipts.db
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - disabled
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -q
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@@ -0,0 +1,110 @@
"""Application configuration using pydantic-settings."""
import os
from pathlib import Path
from typing import List
from pydantic_settings import BaseSettings
from functools import lru_cache
class Settings(BaseSettings):
"""Application settings loaded from environment variables."""
# App info
app_name: str = "Data Entry API"
app_version: str = "1.0.0"
debug: bool = False
# API
api_host: str = "0.0.0.0"
api_port: int = 8003
# SQLite Database
sqlite_database_path: str = "data/receipts/receipts.db"
# File uploads
upload_path: str = "data/uploads"
max_upload_size_mb: int = 10
allowed_mime_types: List[str] = [
"image/jpeg",
"image/png",
"image/gif",
"image/webp",
"application/pdf",
]
# Oracle Database (for nomenclatures)
oracle_user: str = ""
oracle_password: str = ""
oracle_host: str = "localhost"
oracle_port: int = 1526
oracle_sid: str = "ROA"
# JWT Authentication
jwt_secret_key: str = "change-me-in-production"
jwt_algorithm: str = "HS256"
jwt_expire_minutes: int = 480
# CORS
cors_origins: str = "http://localhost:3010,http://localhost:3000"
# OCR Engines (comma-separated list of active engines shown in UI)
# Available: tesseract, paddleocr, doctr, doctr_plus
# doctr_plus is recommended (2-tier sequential with early exit)
ocr_active_engines: str = "doctr,doctr_plus"
class Config:
env_file = ".env"
env_file_encoding = "utf-8"
extra = "ignore"
@property
def database_url(self) -> str:
"""Get SQLite database URL for async."""
# Resolve to absolute path for Windows/IIS compatibility
abs_path = Path(self.sqlite_database_path).resolve()
return f"sqlite+aiosqlite:///{abs_path}"
@property
def sync_database_url(self) -> str:
"""Get SQLite database URL for sync operations (Alembic)."""
# Resolve to absolute path for Windows/IIS compatibility
abs_path = Path(self.sqlite_database_path).resolve()
return f"sqlite:///{abs_path}"
@property
def upload_path_resolved(self) -> Path:
"""Get resolved upload path."""
path = Path(self.upload_path)
path.mkdir(parents=True, exist_ok=True)
return path
@property
def max_upload_size_bytes(self) -> int:
"""Get max upload size in bytes."""
return self.max_upload_size_mb * 1024 * 1024
@property
def cors_origins_list(self) -> List[str]:
"""Get CORS origins as list."""
return [origin.strip() for origin in self.cors_origins.split(",")]
@property
def ocr_active_engines_list(self) -> List[str]:
"""Get OCR active engines as list."""
return [engine.strip() for engine in self.ocr_active_engines.split(",")]
@property
def oracle_dsn(self) -> str:
"""Get Oracle DSN string."""
return f"{self.oracle_host}:{self.oracle_port}/{self.oracle_sid}"
@lru_cache()
def get_settings() -> Settings:
"""Get cached settings instance."""
return Settings()
# Convenience instance
settings = get_settings()

View File

@@ -0,0 +1,4 @@
# Database module
from .database import get_session, init_db, engine
__all__ = ["get_session", "init_db", "engine"]

View File

@@ -0,0 +1,13 @@
# CRUD operations
from .receipt import ReceiptCRUD
from .attachment import AttachmentCRUD
from .accounting_entry import AccountingEntryCRUD
from .ocr_settings import OCRPreferenceCRUD, OCRMetricsCRUD
__all__ = [
"ReceiptCRUD",
"AttachmentCRUD",
"AccountingEntryCRUD",
"OCRPreferenceCRUD",
"OCRMetricsCRUD",
]

View File

@@ -0,0 +1,197 @@
"""CRUD operations for accounting entries."""
from datetime import datetime
from typing import Optional, List
from sqlalchemy import select, delete
from sqlalchemy.ext.asyncio import AsyncSession
from backend.modules.data_entry.db.models.accounting_entry import AccountingEntry, EntryType
from backend.modules.data_entry.schemas.receipt import AccountingEntryCreate, AccountingEntryUpdate
class AccountingEntryCRUD:
"""CRUD operations for AccountingEntry model."""
@staticmethod
async def create(
session: AsyncSession,
receipt_id: int,
data: AccountingEntryCreate,
sort_order: int = 0,
is_auto_generated: bool = True,
) -> AccountingEntry:
"""Create a new accounting entry."""
entry = AccountingEntry(
receipt_id=receipt_id,
entry_type=data.entry_type,
account_code=data.account_code,
account_name=data.account_name,
amount=data.amount,
partner_id=data.partner_id,
cost_center_id=data.cost_center_id,
is_auto_generated=is_auto_generated,
sort_order=sort_order,
)
session.add(entry)
await session.commit()
await session.refresh(entry)
return entry
@staticmethod
async def create_bulk(
session: AsyncSession,
receipt_id: int,
entries: List[AccountingEntryCreate],
is_auto_generated: bool = True,
) -> List[AccountingEntry]:
"""Create multiple accounting entries at once."""
created_entries = []
for idx, entry_data in enumerate(entries):
entry = AccountingEntry(
receipt_id=receipt_id,
entry_type=entry_data.entry_type,
account_code=entry_data.account_code,
account_name=entry_data.account_name,
amount=entry_data.amount,
partner_id=entry_data.partner_id,
cost_center_id=entry_data.cost_center_id,
is_auto_generated=is_auto_generated,
sort_order=idx,
)
session.add(entry)
created_entries.append(entry)
await session.commit()
for entry in created_entries:
await session.refresh(entry)
return created_entries
@staticmethod
async def get_by_id(
session: AsyncSession,
entry_id: int,
) -> Optional[AccountingEntry]:
"""Get accounting entry by ID."""
query = select(AccountingEntry).where(AccountingEntry.id == entry_id)
result = await session.execute(query)
return result.scalar_one_or_none()
@staticmethod
async def get_by_receipt_id(
session: AsyncSession,
receipt_id: int,
) -> List[AccountingEntry]:
"""Get all accounting entries for a receipt."""
query = select(AccountingEntry).where(
AccountingEntry.receipt_id == receipt_id
).order_by(AccountingEntry.sort_order.asc())
result = await session.execute(query)
return list(result.scalars().all())
@staticmethod
async def update(
session: AsyncSession,
entry: AccountingEntry,
data: AccountingEntryUpdate,
modified_by: str,
) -> AccountingEntry:
"""Update an accounting entry."""
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(entry, field, value)
entry.is_auto_generated = False
entry.modified_by = modified_by
entry.modified_at = datetime.utcnow()
session.add(entry)
await session.commit()
await session.refresh(entry)
return entry
@staticmethod
async def delete(session: AsyncSession, entry: AccountingEntry) -> bool:
"""Delete an accounting entry."""
await session.delete(entry)
await session.commit()
return True
@staticmethod
async def delete_all_for_receipt(session: AsyncSession, receipt_id: int) -> int:
"""Delete all accounting entries for a receipt."""
query = delete(AccountingEntry).where(AccountingEntry.receipt_id == receipt_id)
result = await session.execute(query)
await session.commit()
return result.rowcount
@staticmethod
async def replace_all_for_receipt(
session: AsyncSession,
receipt_id: int,
entries: List[AccountingEntryCreate],
modified_by: str,
) -> List[AccountingEntry]:
"""Replace all entries for a receipt with new ones."""
# Delete existing entries
await AccountingEntryCRUD.delete_all_for_receipt(session, receipt_id)
# Create new entries (marked as manually modified)
created_entries = []
for idx, entry_data in enumerate(entries):
entry = AccountingEntry(
receipt_id=receipt_id,
entry_type=entry_data.entry_type,
account_code=entry_data.account_code,
account_name=entry_data.account_name,
amount=entry_data.amount,
partner_id=entry_data.partner_id,
cost_center_id=entry_data.cost_center_id,
is_auto_generated=False,
modified_by=modified_by,
modified_at=datetime.utcnow(),
sort_order=idx,
)
session.add(entry)
created_entries.append(entry)
await session.commit()
for entry in created_entries:
await session.refresh(entry)
return created_entries
@staticmethod
async def validate_entries(entries: List[AccountingEntryCreate]) -> tuple[bool, str]:
"""
Validate accounting entries.
Returns (is_valid, error_message).
"""
if not entries:
return False, "At least one entry is required"
total_debit = sum(
e.amount for e in entries if e.entry_type == EntryType.DEBIT
)
total_credit = sum(
e.amount for e in entries if e.entry_type == EntryType.CREDIT
)
# Check balance (debit should equal credit)
if abs(total_debit - total_credit) > 0.01:
return False, f"Entries not balanced: Debit={total_debit}, Credit={total_credit}"
# Check for valid account codes
for entry in entries:
if not entry.account_code or len(entry.account_code) < 3:
return False, f"Invalid account code: {entry.account_code}"
return True, ""

View File

@@ -0,0 +1,140 @@
"""CRUD operations for receipt attachments."""
import os
import uuid
import aiofiles
from datetime import datetime
from pathlib import Path
from typing import Optional, List
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi import UploadFile
from backend.modules.data_entry.db.models.receipt import ReceiptAttachment
from backend.config import settings
class AttachmentCRUD:
"""CRUD operations for ReceiptAttachment model."""
@staticmethod
def _generate_stored_filename(original_filename: str) -> str:
"""Generate unique filename for storage."""
ext = Path(original_filename).suffix.lower()
return f"{uuid.uuid4()}{ext}"
@staticmethod
def _get_upload_path(stored_filename: str) -> Path:
"""Get full path for storing file, organized by year/month."""
now = datetime.utcnow()
relative_path = Path(str(now.year)) / f"{now.month:02d}"
full_path = settings.data_entry_upload_path_resolved / relative_path
# Ensure directory exists
full_path.mkdir(parents=True, exist_ok=True)
return relative_path / stored_filename
@staticmethod
async def create(
session: AsyncSession,
receipt_id: int,
file: UploadFile,
) -> ReceiptAttachment:
"""Create attachment by saving file and creating DB record."""
# Generate stored filename
stored_filename = AttachmentCRUD._generate_stored_filename(file.filename or "upload")
# Get relative path
relative_path = AttachmentCRUD._get_upload_path(stored_filename)
# Full path for saving
full_path = settings.data_entry_upload_path_resolved / relative_path
# Read file content
content = await file.read()
file_size = len(content)
# Validate file size
if file_size > settings.data_entry_max_upload_size_bytes:
raise ValueError(f"File too large. Maximum size is {settings.data_entry_max_upload_size_mb}MB")
# Validate MIME type
mime_type = file.content_type or "application/octet-stream"
if mime_type not in settings.data_entry_allowed_mime_types:
raise ValueError(f"File type not allowed: {mime_type}")
# Save file
async with aiofiles.open(full_path, "wb") as f:
await f.write(content)
# Create DB record
attachment = ReceiptAttachment(
receipt_id=receipt_id,
filename=file.filename or "upload",
stored_filename=stored_filename,
file_path=str(relative_path),
file_size=file_size,
mime_type=mime_type,
)
session.add(attachment)
await session.commit()
await session.refresh(attachment)
return attachment
@staticmethod
async def get_by_id(
session: AsyncSession,
attachment_id: int,
) -> Optional[ReceiptAttachment]:
"""Get attachment by ID."""
query = select(ReceiptAttachment).where(ReceiptAttachment.id == attachment_id)
result = await session.execute(query)
return result.scalar_one_or_none()
@staticmethod
async def get_by_receipt_id(
session: AsyncSession,
receipt_id: int,
) -> List[ReceiptAttachment]:
"""Get all attachments for a receipt."""
query = select(ReceiptAttachment).where(
ReceiptAttachment.receipt_id == receipt_id
).order_by(ReceiptAttachment.uploaded_at.asc())
result = await session.execute(query)
return list(result.scalars().all())
@staticmethod
def get_file_path(attachment: ReceiptAttachment) -> Path:
"""Get full file path for an attachment."""
return settings.data_entry_upload_path_resolved / attachment.file_path
@staticmethod
async def delete(session: AsyncSession, attachment: ReceiptAttachment) -> bool:
"""Delete attachment (file and DB record)."""
# Delete file
file_path = AttachmentCRUD.get_file_path(attachment)
if file_path.exists():
os.remove(file_path)
# Delete DB record
await session.delete(attachment)
await session.commit()
return True
@staticmethod
async def delete_all_for_receipt(session: AsyncSession, receipt_id: int) -> int:
"""Delete all attachments for a receipt."""
attachments = await AttachmentCRUD.get_by_receipt_id(session, receipt_id)
count = 0
for attachment in attachments:
await AttachmentCRUD.delete(session, attachment)
count += 1
return count

View File

@@ -0,0 +1,222 @@
"""CRUD operations for OCR settings and metrics."""
from datetime import datetime, timedelta
from typing import List, Optional
from sqlalchemy import func, select, and_
from sqlalchemy.ext.asyncio import AsyncSession
from backend.modules.data_entry.db.models.ocr_settings import (
UserOCRPreference,
OCRJobMetrics,
OCRMetricsSummary,
OCREngine,
)
class OCRPreferenceCRUD:
"""CRUD operations for user OCR preferences."""
@staticmethod
async def get_by_username(session: AsyncSession, username: str) -> Optional[UserOCRPreference]:
"""Get user's OCR preference by username."""
result = await session.execute(
select(UserOCRPreference).where(UserOCRPreference.username == username)
)
return result.scalar_one_or_none()
@staticmethod
async def create_or_update(
session: AsyncSession,
username: str,
preferred_engine: OCREngine
) -> UserOCRPreference:
"""Create or update user's OCR preference."""
existing = await OCRPreferenceCRUD.get_by_username(session, username)
if existing:
existing.preferred_engine = preferred_engine
existing.updated_at = datetime.utcnow()
await session.commit()
await session.refresh(existing)
return existing
else:
preference = UserOCRPreference(
username=username,
preferred_engine=preferred_engine
)
session.add(preference)
await session.commit()
await session.refresh(preference)
return preference
@staticmethod
async def delete_by_username(session: AsyncSession, username: str) -> bool:
"""Delete user's OCR preference."""
existing = await OCRPreferenceCRUD.get_by_username(session, username)
if existing:
await session.delete(existing)
await session.commit()
return True
return False
class OCRMetricsCRUD:
"""CRUD operations for OCR job metrics."""
@staticmethod
async def create(
session: AsyncSession,
job_id: str,
username: str,
engine_requested: str,
engine_used: str,
processing_time_ms: int = 0,
file_size_bytes: int = 0,
file_type: str = "image/jpeg",
original_filename: Optional[str] = None,
success: bool = True,
error_message: Optional[str] = None,
overall_confidence: float = 0.0,
fields_extracted: int = 0,
needs_manual_review: Optional[bool] = None,
validation_warnings_count: int = 0,
validation_errors_count: int = 0,
company_id: Optional[int] = None
) -> OCRJobMetrics:
"""Create a new OCR job metrics record."""
metrics = OCRJobMetrics(
job_id=job_id,
username=username,
company_id=company_id,
engine_requested=engine_requested,
engine_used=engine_used,
processing_time_ms=processing_time_ms,
file_size_bytes=file_size_bytes,
file_type=file_type,
original_filename=original_filename,
success=success,
error_message=error_message,
overall_confidence=overall_confidence,
fields_extracted=fields_extracted,
needs_manual_review=needs_manual_review,
validation_warnings_count=validation_warnings_count,
validation_errors_count=validation_errors_count,
)
session.add(metrics)
await session.commit()
await session.refresh(metrics)
return metrics
@staticmethod
async def get_by_job_id(session: AsyncSession, job_id: str) -> Optional[OCRJobMetrics]:
"""Get metrics by job ID."""
result = await session.execute(
select(OCRJobMetrics).where(OCRJobMetrics.job_id == job_id)
)
return result.scalar_one_or_none()
@staticmethod
async def get_user_history(
session: AsyncSession,
username: str,
limit: int = 50,
offset: int = 0
) -> List[OCRJobMetrics]:
"""Get user's OCR job history."""
result = await session.execute(
select(OCRJobMetrics)
.where(OCRJobMetrics.username == username)
.order_by(OCRJobMetrics.created_at.desc())
.limit(limit)
.offset(offset)
)
return list(result.scalars().all())
@staticmethod
async def get_summary_by_engine(
session: AsyncSession,
days: int = 30,
username: Optional[str] = None
) -> List[OCRMetricsSummary]:
"""Get summary metrics grouped by engine."""
cutoff_date = datetime.utcnow() - timedelta(days=days)
# Build query
conditions = [OCRJobMetrics.created_at >= cutoff_date]
if username:
conditions.append(OCRJobMetrics.username == username)
# Query for aggregated metrics
result = await session.execute(
select(
OCRJobMetrics.engine_used,
func.count(OCRJobMetrics.id).label('total_jobs'),
func.sum(func.cast(OCRJobMetrics.success, sa.Integer)).label('successful_jobs'),
func.avg(OCRJobMetrics.processing_time_ms).label('avg_processing_time_ms'),
func.avg(OCRJobMetrics.overall_confidence).label('avg_confidence'),
func.avg(OCRJobMetrics.fields_extracted).label('avg_fields_extracted'),
)
.where(and_(*conditions))
.group_by(OCRJobMetrics.engine_used)
.order_by(func.count(OCRJobMetrics.id).desc())
)
summaries = []
for row in result.all():
total = row.total_jobs or 0
successful = row.successful_jobs or 0
success_rate = successful / total if total > 0 else 0.0
summaries.append(OCRMetricsSummary(
engine=row.engine_used,
total_jobs=total,
successful_jobs=successful,
failed_jobs=total - successful,
success_rate=success_rate,
avg_processing_time_ms=float(row.avg_processing_time_ms or 0),
avg_confidence=float(row.avg_confidence or 0),
avg_fields_extracted=float(row.avg_fields_extracted or 0),
))
return summaries
@staticmethod
async def get_overall_stats(
session: AsyncSession,
days: int = 30,
username: Optional[str] = None
) -> dict:
"""Get overall OCR statistics."""
cutoff_date = datetime.utcnow() - timedelta(days=days)
conditions = [OCRJobMetrics.created_at >= cutoff_date]
if username:
conditions.append(OCRJobMetrics.username == username)
result = await session.execute(
select(
func.count(OCRJobMetrics.id).label('total_jobs'),
func.sum(func.cast(OCRJobMetrics.success, sa.Integer)).label('successful_jobs'),
func.avg(OCRJobMetrics.processing_time_ms).label('avg_processing_time_ms'),
func.avg(OCRJobMetrics.overall_confidence).label('avg_confidence'),
)
.where(and_(*conditions))
)
row = result.one()
total = row.total_jobs or 0
successful = row.successful_jobs or 0
return {
"total_jobs": total,
"successful_jobs": successful,
"failed_jobs": total - successful,
"success_rate": (successful / total * 100) if total > 0 else 0.0,
"avg_processing_time_ms": float(row.avg_processing_time_ms or 0),
"avg_confidence": float(row.avg_confidence or 0),
"period_days": days,
}
# Import sqlalchemy for func.cast
import sqlalchemy as sa

View File

@@ -0,0 +1,418 @@
"""CRUD operations for receipts."""
import json
from datetime import datetime, date
from decimal import Decimal
from typing import Optional, List, Tuple, Any
from sqlalchemy import select, func, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from backend.modules.data_entry.db.models.receipt import Receipt, ReceiptStatus
from backend.modules.data_entry.schemas.receipt import ReceiptCreate, ReceiptUpdate, ReceiptFilter
def _serialize_tva_breakdown(tva_breakdown: Optional[List[Any]]) -> Optional[str]:
"""Serialize TVA breakdown list to JSON string for SQLite storage."""
if tva_breakdown is None:
return None
# Convert Decimal to float for JSON serialization
serializable = []
for entry in tva_breakdown:
if hasattr(entry, 'model_dump'):
# Pydantic model
item = entry.model_dump()
elif isinstance(entry, dict):
item = entry.copy()
else:
item = dict(entry)
# Convert Decimal to float
if 'amount' in item and isinstance(item['amount'], Decimal):
item['amount'] = float(item['amount'])
serializable.append(item)
return json.dumps(serializable)
def _serialize_payment_methods(payment_methods: Optional[List[Any]]) -> Optional[str]:
"""Serialize payment methods list to JSON string for SQLite storage."""
if payment_methods is None:
return None
serializable = []
for pm in payment_methods:
if hasattr(pm, 'model_dump'):
item = pm.model_dump()
elif isinstance(pm, dict):
item = pm.copy()
else:
item = dict(pm)
# Convert Decimal to float for JSON
if 'amount' in item:
if hasattr(item['amount'], '__float__'):
item['amount'] = float(item['amount'])
serializable.append(item)
return json.dumps(serializable)
class ReceiptCRUD:
"""CRUD operations for Receipt model."""
@staticmethod
async def create(
session: AsyncSession,
data: ReceiptCreate,
created_by: str,
) -> Receipt:
"""Create a new receipt."""
# Get data as dict and serialize tva_breakdown and payment_methods to JSON string
receipt_data = data.model_dump()
receipt_data['tva_breakdown'] = _serialize_tva_breakdown(receipt_data.get('tva_breakdown'))
receipt_data['payment_methods'] = _serialize_payment_methods(receipt_data.get('payment_methods'))
receipt = Receipt(
**receipt_data,
created_by=created_by,
status=ReceiptStatus.DRAFT,
)
session.add(receipt)
await session.commit()
await session.refresh(receipt)
# Reload with relationships to avoid lazy loading issues with async
return await ReceiptCRUD.get_by_id(session, receipt.id, include_relations=True)
@staticmethod
async def get_by_id(
session: AsyncSession,
receipt_id: int,
include_relations: bool = True,
) -> Optional[Receipt]:
"""Get receipt by ID, optionally with relationships."""
query = select(Receipt).where(Receipt.id == receipt_id)
if include_relations:
query = query.options(
selectinload(Receipt.attachments),
selectinload(Receipt.entries),
)
result = await session.execute(query)
return result.scalar_one_or_none()
@staticmethod
async def get_list(
session: AsyncSession,
filters: ReceiptFilter,
) -> Tuple[List[Receipt], int]:
"""Get paginated list of receipts with filters."""
# Base query
query = select(Receipt).options(
selectinload(Receipt.attachments),
selectinload(Receipt.entries),
)
# Apply filters
if filters.status:
query = query.where(Receipt.status == filters.status)
if filters.direction:
query = query.where(Receipt.direction == filters.direction)
if filters.company_id:
query = query.where(Receipt.company_id == filters.company_id)
if filters.created_by:
query = query.where(Receipt.created_by == filters.created_by)
if filters.date_from:
query = query.where(Receipt.receipt_date >= filters.date_from)
if filters.date_to:
query = query.where(Receipt.receipt_date <= filters.date_to)
if filters.search:
search_term = f"%{filters.search}%"
query = query.where(
or_(
Receipt.description.ilike(search_term),
Receipt.partner_name.ilike(search_term),
Receipt.receipt_number.ilike(search_term),
)
)
# Bulk upload filters (US-012)
# US-005: Support comma-separated values for processing_status filter (e.g., "pending,processing")
if filters.processing_status:
statuses = [s.strip() for s in filters.processing_status.split(",")]
if len(statuses) == 1:
query = query.where(Receipt.processing_status == statuses[0])
else:
query = query.where(Receipt.processing_status.in_(statuses))
if filters.batch_id:
query = query.where(Receipt.batch_id == filters.batch_id)
# Count total
count_query = select(func.count()).select_from(query.subquery())
total_result = await session.execute(count_query)
total = total_result.scalar() or 0
# Apply ordering based on sort_by parameter (US-012)
if filters.sort_by == "processing_started_at":
query = query.order_by(Receipt.processing_started_at.desc())
elif filters.sort_by == "processing_started_at_asc":
query = query.order_by(Receipt.processing_started_at.asc())
else:
# Default ordering
query = query.order_by(Receipt.created_at.desc())
# Apply pagination
offset = (filters.page - 1) * filters.page_size
query = query.offset(offset).limit(filters.page_size)
# Execute
result = await session.execute(query)
receipts = result.scalars().all()
return list(receipts), total
@staticmethod
async def get_processing_stats(
session: AsyncSession,
company_id: Optional[int] = None,
batch_id: Optional[str] = None,
) -> dict:
"""Get processing status counts for bulk uploaded receipts (US-012)."""
# Build base query for counting by processing_status
base_conditions = []
if company_id:
base_conditions.append(Receipt.company_id == company_id)
if batch_id:
base_conditions.append(Receipt.batch_id == batch_id)
# Only count receipts that have a processing_status (bulk uploads)
base_conditions.append(Receipt.processing_status.isnot(None))
query = select(
Receipt.processing_status,
func.count(Receipt.id).label("count")
)
for condition in base_conditions:
query = query.where(condition)
query = query.group_by(Receipt.processing_status)
result = await session.execute(query)
rows = result.all()
# Initialize stats
stats = {
"pending_count": 0,
"processing_count": 0,
"completed_count": 0,
"failed_count": 0,
}
# Map results
for row in rows:
status = row.processing_status
count = row.count
if status == "pending":
stats["pending_count"] = count
elif status == "processing":
stats["processing_count"] = count
elif status == "completed":
stats["completed_count"] = count
elif status == "failed":
stats["failed_count"] = count
return stats
@staticmethod
async def get_pending_review(
session: AsyncSession,
company_id: Optional[int] = None,
) -> List[Receipt]:
"""Get all receipts pending review."""
query = select(Receipt).where(
Receipt.status == ReceiptStatus.PENDING_REVIEW
).options(
selectinload(Receipt.attachments),
selectinload(Receipt.entries),
)
if company_id:
query = query.where(Receipt.company_id == company_id)
query = query.order_by(Receipt.submitted_at.asc())
result = await session.execute(query)
return list(result.scalars().all())
@staticmethod
async def update(
session: AsyncSession,
receipt: Receipt,
data: ReceiptUpdate,
) -> Receipt:
"""Update receipt fields.
US-407: When a receipt is manually updated, reset processing_status and
processing_error to NULL. This allows failed OCR receipts to be corrected
manually and then submitted for approval without showing as "error" status.
"""
update_data = data.model_dump(exclude_unset=True)
# Recalculate tva_total from tva_breakdown if breakdown is being updated
if 'tva_breakdown' in update_data and update_data['tva_breakdown']:
tva_total = sum(
float(entry.get('amount', 0) if isinstance(entry, dict) else getattr(entry, 'amount', 0))
for entry in update_data['tva_breakdown']
)
update_data['tva_total'] = round(tva_total, 2)
# Serialize tva_breakdown and payment_methods to JSON string if present
if 'tva_breakdown' in update_data:
update_data['tva_breakdown'] = _serialize_tva_breakdown(update_data['tva_breakdown'])
if 'payment_methods' in update_data:
update_data['payment_methods'] = _serialize_payment_methods(update_data['payment_methods'])
for field, value in update_data.items():
setattr(receipt, field, value)
# US-407: Reset processing status when receipt is manually edited
# This clears the "failed" status so edited receipts can be submitted for approval
if receipt.processing_status == 'failed':
receipt.processing_status = None
receipt.processing_error = None
receipt.updated_at = datetime.utcnow()
session.add(receipt)
await session.commit()
await session.refresh(receipt)
# Reload with relationships to avoid lazy loading issues with async
return await ReceiptCRUD.get_by_id(session, receipt.id, include_relations=True)
@staticmethod
async def update_status(
session: AsyncSession,
receipt: Receipt,
new_status: ReceiptStatus,
reviewed_by: Optional[str] = None,
rejection_reason: Optional[str] = None,
) -> Receipt:
"""Update receipt workflow status."""
receipt.status = new_status
receipt.updated_at = datetime.utcnow()
if new_status == ReceiptStatus.PENDING_REVIEW:
receipt.submitted_at = datetime.utcnow()
if new_status in [ReceiptStatus.APPROVED, ReceiptStatus.REJECTED]:
receipt.reviewed_by = reviewed_by
receipt.reviewed_at = datetime.utcnow()
if new_status == ReceiptStatus.REJECTED:
receipt.rejection_reason = rejection_reason
if new_status == ReceiptStatus.DRAFT:
# Reset review fields when moving back to draft
receipt.rejection_reason = None
session.add(receipt)
await session.commit()
await session.refresh(receipt)
# Reload with relationships to avoid lazy loading issues with async
return await ReceiptCRUD.get_by_id(session, receipt.id, include_relations=True)
@staticmethod
async def delete(session: AsyncSession, receipt: Receipt) -> bool:
"""Delete a receipt (cascade deletes attachments and entries)."""
await session.delete(receipt)
await session.commit()
return True
@staticmethod
async def can_edit(receipt: Receipt, username: str) -> bool:
"""Check if user can edit receipt."""
# DRAFT and REJECTED receipts can be edited (to fix and resubmit)
if receipt.status not in [ReceiptStatus.DRAFT, ReceiptStatus.REJECTED]:
return False
# Only creator can edit their own receipts
return receipt.created_by == username
@staticmethod
async def can_delete(receipt: Receipt, username: str) -> bool:
"""Check if user can delete receipt."""
# Only DRAFT receipts can be deleted
if receipt.status != ReceiptStatus.DRAFT:
return False
# Only creator can delete their own drafts
return receipt.created_by == username
@staticmethod
async def can_submit(receipt: Receipt, username: str) -> bool:
"""Check if user can submit receipt for review."""
# Only DRAFT or REJECTED receipts can be submitted
if receipt.status not in [ReceiptStatus.DRAFT, ReceiptStatus.REJECTED]:
return False
# Only creator can submit their own receipts
return receipt.created_by == username
@staticmethod
async def get_stats(
session: AsyncSession,
company_id: int,
created_by: Optional[str] = None,
) -> dict:
"""Get receipt statistics."""
base_query = select(
Receipt.status,
func.count(Receipt.id).label("count"),
func.sum(Receipt.amount).label("total_amount"),
).where(
Receipt.company_id == company_id
)
if created_by:
base_query = base_query.where(Receipt.created_by == created_by)
query = base_query.group_by(Receipt.status)
result = await session.execute(query)
rows = result.all()
stats = {
"draft": {"count": 0, "amount": 0},
"pending_review": {"count": 0, "amount": 0},
"approved": {"count": 0, "amount": 0},
"rejected": {"count": 0, "amount": 0},
"synced": {"count": 0, "amount": 0},
"total": {"count": 0, "amount": 0},
}
for row in rows:
status_key = row.status.value
stats[status_key] = {
"count": row.count,
"amount": float(row.total_amount or 0),
}
stats["total"]["count"] += row.count
stats["total"]["amount"] += float(row.total_amount or 0)
return stats

View File

@@ -0,0 +1,50 @@
"""Database configuration and session management using SQLModel."""
from pathlib import Path
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from sqlmodel import SQLModel
from backend.config import settings
# Create async engine
# Note: echo=False to disable SQL query logging (too verbose)
engine = create_async_engine(
settings.data_entry_database_url,
echo=False,
future=True,
)
# Create async session factory
async_session_maker = sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False,
)
async def init_db() -> None:
"""Initialize database - create tables if they don't exist."""
# Ensure data directory exists
db_path = Path(settings.data_entry_sqlite_database_path)
db_path.parent.mkdir(parents=True, exist_ok=True)
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
async def get_session() -> AsyncGenerator[AsyncSession, None]:
"""Get async database session for dependency injection."""
async with async_session_maker() as session:
try:
yield session
finally:
await session.close()
# Convenience function for manual session usage
async def get_db_session() -> AsyncSession:
"""Get a new database session (manual management)."""
return async_session_maker()

View File

@@ -0,0 +1,131 @@
"""
Alembic migrations helper for Data Entry module.
Provides automatic migration execution at backend startup.
"""
import logging
import os
from pathlib import Path
logger = logging.getLogger(__name__)
def run_migrations() -> bool:
"""
Run pending Alembic migrations at startup.
Returns:
True if migrations ran successfully (or no pending migrations),
False if migrations failed (backend should continue with WARNING).
"""
try:
from alembic.config import Config
from alembic import command
from alembic.runtime.migration import MigrationContext
from sqlalchemy import create_engine
# Get the path to alembic.ini
data_entry_module = Path(__file__).parent.parent
alembic_ini_path = data_entry_module / "alembic.ini"
if not alembic_ini_path.exists():
logger.warning(f"[MIGRATIONS] alembic.ini not found at {alembic_ini_path}")
return False
# Get database path from environment or default
db_path = Path(os.getenv(
"SQLITE_DATABASE_PATH",
"data/receipts/receipts.db"
)).resolve()
# Ensure database directory exists
db_path.parent.mkdir(parents=True, exist_ok=True)
# Create Alembic config
alembic_cfg = Config(str(alembic_ini_path))
# Override database URL
sync_db_url = f"sqlite:///{db_path}"
alembic_cfg.set_main_option("sqlalchemy.url", sync_db_url)
# Set script location relative to alembic.ini
alembic_cfg.set_main_option(
"script_location",
str(data_entry_module / "migrations")
)
# Get current revision before upgrade
engine = create_engine(sync_db_url)
with engine.connect() as connection:
context = MigrationContext.configure(connection)
current_rev = context.get_current_revision()
engine.dispose()
logger.info(f"[MIGRATIONS] Current revision: {current_rev or 'None (fresh database)'}")
logger.info(f"[MIGRATIONS] Database path: {db_path}")
# Run upgrade to head
logger.info("[MIGRATIONS] Checking for pending migrations...")
command.upgrade(alembic_cfg, "head")
# Get new revision after upgrade
engine = create_engine(sync_db_url)
with engine.connect() as connection:
context = MigrationContext.configure(connection)
new_rev = context.get_current_revision()
engine.dispose()
if current_rev != new_rev:
logger.info(f"[MIGRATIONS] Applied: {current_rev or 'None'} -> {new_rev}")
else:
logger.info(f"[MIGRATIONS] No pending migrations. Current: {new_rev}")
return True
except ImportError as e:
logger.warning(f"[MIGRATIONS] Alembic not installed: {e}")
logger.warning("[MIGRATIONS] Skipping migrations - install alembic to enable")
return False
except Exception as e:
logger.error(f"[MIGRATIONS] Migration error: {e}", exc_info=True)
logger.warning("[MIGRATIONS] Backend will continue without migrations")
return False
def get_current_revision() -> str:
"""
Get the current Alembic revision.
Returns:
Current revision string, or 'unknown' if cannot be determined.
"""
try:
from alembic.runtime.migration import MigrationContext
from sqlalchemy import create_engine
# Get database path from environment or default
db_path = Path(os.getenv(
"SQLITE_DATABASE_PATH",
"data/receipts/receipts.db"
)).resolve()
if not db_path.exists():
return "no_database"
sync_db_url = f"sqlite:///{db_path}"
engine = create_engine(sync_db_url)
with engine.connect() as connection:
context = MigrationContext.configure(connection)
revision = context.get_current_revision()
engine.dispose()
return revision or "none"
except ImportError:
return "alembic_not_installed"
except Exception as e:
logger.debug(f"[MIGRATIONS] Could not get revision: {e}")
return "unknown"

View File

@@ -0,0 +1,29 @@
# Database models
from .receipt import Receipt, ReceiptAttachment, ReceiptStatus, ReceiptType, ReceiptDirection, ProcessingStatus
from .accounting_entry import AccountingEntry, EntryType
from .nomenclature import SyncedSupplier, LocalSupplier, SyncedCashRegister
from .ocr_settings import UserOCRPreference, OCRJobMetrics, OCRMetricsSummary, OCREngine
from .batch import BatchUpload, BatchJob, BatchStatus
__all__ = [
"Receipt",
"ReceiptAttachment",
"ReceiptStatus",
"ReceiptType",
"ReceiptDirection",
"ProcessingStatus",
"AccountingEntry",
"EntryType",
"SyncedSupplier",
"LocalSupplier",
"SyncedCashRegister",
# OCR Settings & Metrics
"UserOCRPreference",
"OCRJobMetrics",
"OCRMetricsSummary",
"OCREngine",
# Batch Upload
"BatchUpload",
"BatchJob",
"BatchStatus",
]

View File

@@ -0,0 +1,49 @@
"""AccountingEntry SQLModel model for proposed accounting entries."""
from datetime import datetime
from decimal import Decimal
from enum import Enum
from typing import Optional, TYPE_CHECKING
from sqlmodel import SQLModel, Field, Relationship
if TYPE_CHECKING:
from .receipt import Receipt
class EntryType(str, Enum):
"""Type of accounting entry."""
DEBIT = "debit"
CREDIT = "credit"
class AccountingEntry(SQLModel, table=True):
"""Proposed accounting entry for a receipt."""
__tablename__ = "accounting_entries"
id: Optional[int] = Field(default=None, primary_key=True)
receipt_id: int = Field(foreign_key="receipts.id", index=True)
# Account
entry_type: EntryType
account_code: str = Field(max_length=20) # e.g., 6022, 5311, 4426
account_name: Optional[str] = Field(default=None, max_length=200) # Cache: "Cheltuieli combustibil"
# Amount
amount: Decimal = Field(decimal_places=2, max_digits=15)
# Analytics (optional)
partner_id: Optional[int] = Field(default=None)
cost_center_id: Optional[int] = Field(default=None)
# Entry metadata
is_auto_generated: bool = Field(default=True) # True if system-generated
modified_by: Optional[str] = Field(default=None, max_length=100) # Username if modified
modified_at: Optional[datetime] = Field(default=None)
# Order for display
sort_order: int = Field(default=0)
# Relationship
receipt: Optional["Receipt"] = Relationship(back_populates="entries")

View File

@@ -0,0 +1,64 @@
"""BatchUpload and BatchJob SQLModel models for bulk receipt processing."""
from datetime import datetime
from enum import Enum
from typing import Optional
from sqlmodel import SQLModel, Field
class BatchStatus(str, Enum):
"""Status of a batch upload."""
PENDING = "pending" # Batch created, jobs queued
PROCESSING = "processing" # At least one job is processing
COMPLETED = "completed" # All jobs completed (success or failed)
FAILED = "failed" # Batch-level failure (e.g., all jobs failed)
class BatchUpload(SQLModel, table=True):
"""
Batch upload record for grouping multiple OCR jobs.
Tracks overall progress and status of a bulk upload operation.
"""
__tablename__ = "batch_uploads"
id: Optional[int] = Field(default=None, primary_key=True)
# User info
user_id: str = Field(max_length=100, index=True) # Username who created the batch
company_id: int = Field(index=True) # Company ID for receipt creation
# Timestamps
created_at: datetime = Field(default_factory=datetime.utcnow)
# Status tracking
status: BatchStatus = Field(default=BatchStatus.PENDING)
total_files: int = Field(default=0)
class BatchJob(SQLModel, table=True):
"""
Junction table linking batch_uploads to ocr_jobs.
Each record represents one file in a batch, linking to its OCR job.
Also stores the receipt_id once the job completes and auto-creates a receipt.
"""
__tablename__ = "batch_jobs"
id: Optional[int] = Field(default=None, primary_key=True)
# Foreign keys
batch_id: int = Field(foreign_key="batch_uploads.id", index=True)
job_id: str = Field(max_length=36, index=True) # UUID from ocr_jobs table
# Original filename for display
filename: str = Field(max_length=255)
# Receipt reference (set after auto-create)
receipt_id: Optional[int] = Field(default=None, foreign_key="receipts.id")
# Timestamps
created_at: datetime = Field(default_factory=datetime.utcnow)

View File

@@ -0,0 +1,46 @@
"""Nomenclature models for synced and local data."""
from typing import Optional
from datetime import datetime
from sqlmodel import SQLModel, Field
class SyncedSupplier(SQLModel, table=True):
"""Suppliers synced from Oracle NOM_PARTENERI."""
__tablename__ = "synced_suppliers"
id: Optional[int] = Field(default=None, primary_key=True)
oracle_id: int = Field(index=True) # Original Oracle ID
company_id: int = Field(index=True) # Company this supplier belongs to
name: str = Field(max_length=200)
fiscal_code: Optional[str] = Field(default=None, max_length=50, index=True) # CUI/CIF
address: Optional[str] = Field(default=None, max_length=500)
synced_at: datetime = Field(default_factory=datetime.utcnow)
class LocalSupplier(SQLModel, table=True):
"""Suppliers created locally from OCR (not in Oracle)."""
__tablename__ = "local_suppliers"
id: Optional[int] = Field(default=None, primary_key=True)
company_id: int = Field(index=True)
name: str = Field(max_length=200)
fiscal_code: Optional[str] = Field(default=None, max_length=50, index=True)
address: Optional[str] = Field(default=None, max_length=500)
created_by: str = Field(max_length=100) # Username who created it
created_at: datetime = Field(default_factory=datetime.utcnow)
# Flag to indicate if it should be synced to Oracle later
pending_oracle_sync: bool = Field(default=True)
class SyncedCashRegister(SQLModel, table=True):
"""Cash registers and bank accounts synced from Oracle."""
__tablename__ = "synced_cash_registers"
id: Optional[int] = Field(default=None, primary_key=True)
oracle_id: int = Field(index=True)
company_id: int = Field(index=True)
name: str = Field(max_length=100)
account_code: str = Field(max_length=20) # 5311, 5121, etc.
register_type: str = Field(max_length=10) # 'cash' or 'bank'
synced_at: datetime = Field(default_factory=datetime.utcnow)

View File

@@ -0,0 +1,102 @@
"""OCR settings and metrics SQLModel models."""
from datetime import datetime
from decimal import Decimal
from enum import Enum
from typing import Optional
from sqlmodel import SQLModel, Field
class OCREngine(str, Enum):
"""Available OCR engines."""
TESSERACT = "tesseract"
DOCTR = "doctr"
DOCTR_PLUS = "doctr_plus" # docTR with 2-tier sequential processing + early exit (optimized, recommended)
PADDLEOCR = "paddleocr"
class UserOCRPreference(SQLModel, table=True):
"""
User's preferred OCR engine setting.
Each user can have one preferred OCR engine that will be
auto-selected when they upload new receipts for processing.
"""
__tablename__ = "user_ocr_preferences"
id: Optional[int] = Field(default=None, primary_key=True)
# User identification
username: str = Field(max_length=100, unique=True, index=True)
# Preference settings
preferred_engine: OCREngine = Field(default=OCREngine.DOCTR_PLUS)
# Timestamps
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
class OCRJobMetrics(SQLModel, table=True):
"""
OCR job processing metrics for analytics.
Stores metrics for each OCR job to enable:
- Performance tracking by engine
- Success rate analysis
- Processing time trends
- User-specific analytics
"""
__tablename__ = "ocr_job_metrics"
id: Optional[int] = Field(default=None, primary_key=True)
# Job identification
job_id: str = Field(max_length=50, unique=True, index=True)
# User and company context
username: str = Field(max_length=100, index=True)
company_id: Optional[int] = Field(default=None, index=True)
# Engine used
engine_requested: str = Field(max_length=20) # What user/auto requested
engine_used: str = Field(max_length=50) # What was actually used (e.g., "doctr-light")
# Processing metrics
processing_time_ms: int = Field(default=0)
file_size_bytes: int = Field(default=0)
file_type: str = Field(max_length=50, default="image/jpeg") # MIME type
original_filename: Optional[str] = Field(default=None, max_length=255) # Original uploaded filename
# Success metrics
success: bool = Field(default=True)
error_message: Optional[str] = Field(default=None, max_length=500)
# Extraction quality metrics
overall_confidence: float = Field(default=0.0)
fields_extracted: int = Field(default=0) # Number of fields successfully extracted
needs_manual_review: Optional[bool] = Field(default=None)
validation_warnings_count: int = Field(default=0)
validation_errors_count: int = Field(default=0)
# Timestamps
created_at: datetime = Field(default_factory=datetime.utcnow)
class OCRMetricsSummary(SQLModel):
"""
Summary metrics for OCR analytics.
Not a database table - used for API responses.
"""
engine: str
total_jobs: int
successful_jobs: int
failed_jobs: int
success_rate: float # Computed: successful_jobs / total_jobs
avg_processing_time_ms: float
avg_confidence: float
avg_fields_extracted: float

View File

@@ -0,0 +1,143 @@
"""Receipt and ReceiptAttachment SQLModel models."""
from datetime import datetime, date
from decimal import Decimal
from enum import Enum
from typing import Optional, List, TYPE_CHECKING
from sqlmodel import SQLModel, Field, Relationship
class ReceiptType(str, Enum):
"""Type of receipt document."""
BON_FISCAL = "bon_fiscal"
CHITANTA = "chitanta"
class ReceiptDirection(str, Enum):
"""Direction of receipt - expense or income."""
CHELTUIALA = "cheltuiala" # Expense (receipt from supplier)
INCASARE = "incasare" # Income (receipt issued to client)
class ReceiptStatus(str, Enum):
"""Workflow status of receipt."""
DRAFT = "draft" # User is filling in data
PENDING_REVIEW = "pending_review" # Awaiting accountant approval
APPROVED = "approved" # Approved by accountant
REJECTED = "rejected" # Rejected by accountant
SYNCED = "synced" # Synced to Oracle (Phase 2)
class PaymentMode(str, Enum):
"""Payment mode - how the expense was paid."""
CASA = "casa" # Numerar firma (5311)
BANCA = "banca" # Virament/POS (5121)
AVANS_DECONTARE = "avans_decontare" # Decont angajat (542)
class ProcessingStatus(str, Enum):
"""Processing status for bulk uploaded receipts."""
PENDING = "pending" # Waiting in queue
PROCESSING = "processing" # Currently being processed by OCR
COMPLETED = "completed" # Successfully processed
FAILED = "failed" # Processing failed with error
if TYPE_CHECKING:
from .accounting_entry import AccountingEntry
class Receipt(SQLModel, table=True):
"""Receipt (Bon Fiscal / Chitanta) with approval workflow."""
__tablename__ = "receipts"
id: Optional[int] = Field(default=None, primary_key=True)
# Document identification
receipt_type: ReceiptType = Field(default=ReceiptType.BON_FISCAL)
direction: ReceiptDirection = Field(default=ReceiptDirection.CHELTUIALA)
receipt_number: Optional[str] = Field(default=None, max_length=50)
receipt_series: Optional[str] = Field(default=None, max_length=20)
# Main data
receipt_date: date
amount: Decimal = Field(decimal_places=2, max_digits=15)
description: Optional[str] = Field(default=None, max_length=500)
# TVA info (extracted from OCR) - stored as JSON for multiple entries
tva_breakdown: Optional[str] = Field(default=None, max_length=1000) # JSON: [{"code":"A","percent":19,"amount":"15.20"}]
tva_total: Optional[Decimal] = Field(default=None, decimal_places=2, max_digits=15)
items_count: Optional[int] = Field(default=None)
vendor_address: Optional[str] = Field(default=None, max_length=500)
# Expense type (for auto-generating accounting entries)
expense_type_code: Optional[str] = Field(default=None, max_length=20)
# Oracle references (nomenclatures)
company_id: int
# partner_id removed - supplier data is text-only (partner_name, cui)
partner_name: Optional[str] = Field(default=None, max_length=200) # Supplier name from OCR/selection
cui: Optional[str] = Field(default=None, max_length=20) # Fiscal code from OCR
ocr_raw_text: Optional[str] = Field(default=None) # Raw OCR text for debugging
payment_methods: Optional[str] = Field(default=None, max_length=500) # JSON: [{"method":"CARD","amount":"50.00"}]
cash_register_id: Optional[int] = Field(default=None) # Cash/Bank ID from Oracle
cash_register_name: Optional[str] = Field(default=None, max_length=100) # Cache for display
cash_register_account: Optional[str] = Field(default=None, max_length=20) # Account code (5311, 5121)
payment_mode: Optional[str] = Field(default=None, max_length=20) # PaymentMode value: casa/banca/avans_decontare
# Workflow
status: ReceiptStatus = Field(default=ReceiptStatus.DRAFT)
created_by: str = Field(max_length=100) # Username of creator
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
submitted_at: Optional[datetime] = Field(default=None) # When submitted for approval
# Approval
reviewed_by: Optional[str] = Field(default=None, max_length=100) # Accountant username
reviewed_at: Optional[datetime] = Field(default=None)
rejection_reason: Optional[str] = Field(default=None, max_length=500) # Reason for rejection
# Phase 2 - Oracle sync
oracle_synced_at: Optional[datetime] = Field(default=None)
oracle_act_id: Optional[int] = Field(default=None)
oracle_error: Optional[str] = Field(default=None, max_length=500)
# Bulk upload batch tracking
batch_id: Optional[str] = Field(default=None, max_length=50, index=True)
processing_status: Optional[str] = Field(default=None, max_length=20, index=True) # ProcessingStatus enum value
processing_error: Optional[str] = Field(default=None) # Full error message text
file_hash: Optional[str] = Field(default=None, max_length=64, index=True) # SHA-256 hash for duplicate detection
processing_started_at: Optional[datetime] = Field(default=None)
processing_completed_at: Optional[datetime] = Field(default=None)
# Relationships
attachments: List["ReceiptAttachment"] = Relationship(
back_populates="receipt",
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
)
entries: List["AccountingEntry"] = Relationship(
back_populates="receipt",
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
)
class ReceiptAttachment(SQLModel, table=True):
"""Attachment (photo or PDF) for a receipt."""
__tablename__ = "receipt_attachments"
id: Optional[int] = Field(default=None, primary_key=True)
receipt_id: int = Field(foreign_key="receipts.id", index=True)
# File info
filename: str = Field(max_length=255) # Original filename
stored_filename: str = Field(max_length=255) # Filename on disk (UUID)
file_path: str = Field(max_length=500) # Relative path
file_size: int # Size in bytes
mime_type: str = Field(max_length=100) # MIME type (image/jpeg, application/pdf)
uploaded_at: datetime = Field(default_factory=datetime.utcnow)
# Relationship
receipt: Optional[Receipt] = Relationship(back_populates="attachments")

View File

@@ -0,0 +1,92 @@
"""Alembic environment configuration."""
import os
from pathlib import Path
from logging.config import fileConfig
from dotenv import load_dotenv
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from sqlmodel import SQLModel
# Load environment variables from .env file
load_dotenv()
# Import all models to ensure they're registered with SQLModel
from backend.modules.data_entry.db.models.receipt import Receipt, ReceiptAttachment
from backend.modules.data_entry.db.models.accounting_entry import AccountingEntry
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier, SyncedCashRegister
from backend.modules.data_entry.db.models.ocr_settings import UserOCRPreference, OCRJobMetrics
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Override sqlalchemy.url from environment variable if set
# Resolve to absolute path for Windows/IIS compatibility
db_path = Path(os.getenv("SQLITE_DATABASE_PATH", "data/receipts/receipts.db")).resolve()
config.set_main_option("sqlalchemy.url", f"sqlite:///{db_path}")
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
target_metadata = SQLModel.metadata
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
render_as_batch=True, # Required for SQLite ALTER TABLE support
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection,
target_metadata=target_metadata,
render_as_batch=True, # Required for SQLite ALTER TABLE support
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@@ -0,0 +1,27 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,112 @@
"""Initial receipts schema
Revision ID: 001_initial
Revises:
Create Date: 2024-12-11
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = '001_initial'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Create receipts table
op.create_table(
'receipts',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('receipt_type', sa.Enum('BON_FISCAL', 'CHITANTA', name='receipttype'), nullable=False),
sa.Column('direction', sa.Enum('CHELTUIALA', 'INCASARE', name='receiptdirection'), nullable=False),
sa.Column('receipt_number', sa.String(length=50), nullable=True),
sa.Column('receipt_series', sa.String(length=20), nullable=True),
sa.Column('receipt_date', sa.Date(), nullable=False),
sa.Column('amount', sa.Numeric(precision=15, scale=2), nullable=False),
sa.Column('description', sa.String(length=500), nullable=True),
sa.Column('expense_type_code', sa.String(length=20), nullable=True),
sa.Column('company_id', sa.Integer(), nullable=False),
sa.Column('partner_id', sa.Integer(), nullable=True),
sa.Column('partner_name', sa.String(length=200), nullable=True),
sa.Column('cash_register_id', sa.Integer(), nullable=True),
sa.Column('cash_register_name', sa.String(length=100), nullable=True),
sa.Column('cash_register_account', sa.String(length=20), nullable=True),
sa.Column('status', sa.Enum('DRAFT', 'PENDING_REVIEW', 'APPROVED', 'REJECTED', 'SYNCED', name='receiptstatus'), nullable=False),
sa.Column('created_by', sa.String(length=100), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('submitted_at', sa.DateTime(), nullable=True),
sa.Column('reviewed_by', sa.String(length=100), nullable=True),
sa.Column('reviewed_at', sa.DateTime(), nullable=True),
sa.Column('rejection_reason', sa.String(length=500), nullable=True),
sa.Column('oracle_synced_at', sa.DateTime(), nullable=True),
sa.Column('oracle_act_id', sa.Integer(), nullable=True),
sa.Column('oracle_error', sa.String(length=500), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_receipts_company_id'), 'receipts', ['company_id'], unique=False)
op.create_index(op.f('ix_receipts_status'), 'receipts', ['status'], unique=False)
op.create_index(op.f('ix_receipts_created_by'), 'receipts', ['created_by'], unique=False)
op.create_index(op.f('ix_receipts_receipt_date'), 'receipts', ['receipt_date'], unique=False)
# Create receipt_attachments table
op.create_table(
'receipt_attachments',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('receipt_id', sa.Integer(), nullable=False),
sa.Column('filename', sa.String(length=255), nullable=False),
sa.Column('stored_filename', sa.String(length=255), nullable=False),
sa.Column('file_path', sa.String(length=500), nullable=False),
sa.Column('file_size', sa.Integer(), nullable=False),
sa.Column('mime_type', sa.String(length=100), nullable=False),
sa.Column('uploaded_at', sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(['receipt_id'], ['receipts.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_receipt_attachments_receipt_id'), 'receipt_attachments', ['receipt_id'], unique=False)
# Create accounting_entries table
op.create_table(
'accounting_entries',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('receipt_id', sa.Integer(), nullable=False),
sa.Column('entry_type', sa.Enum('DEBIT', 'CREDIT', name='entrytype'), nullable=False),
sa.Column('account_code', sa.String(length=20), nullable=False),
sa.Column('account_name', sa.String(length=200), nullable=True),
sa.Column('amount', sa.Numeric(precision=15, scale=2), nullable=False),
sa.Column('partner_id', sa.Integer(), nullable=True),
sa.Column('cost_center_id', sa.Integer(), nullable=True),
sa.Column('is_auto_generated', sa.Boolean(), nullable=False),
sa.Column('modified_by', sa.String(length=100), nullable=True),
sa.Column('modified_at', sa.DateTime(), nullable=True),
sa.Column('sort_order', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['receipt_id'], ['receipts.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_accounting_entries_receipt_id'), 'accounting_entries', ['receipt_id'], unique=False)
def downgrade() -> None:
op.drop_index(op.f('ix_accounting_entries_receipt_id'), table_name='accounting_entries')
op.drop_table('accounting_entries')
op.drop_index(op.f('ix_receipt_attachments_receipt_id'), table_name='receipt_attachments')
op.drop_table('receipt_attachments')
op.drop_index(op.f('ix_receipts_receipt_date'), table_name='receipts')
op.drop_index(op.f('ix_receipts_created_by'), table_name='receipts')
op.drop_index(op.f('ix_receipts_status'), table_name='receipts')
op.drop_index(op.f('ix_receipts_company_id'), table_name='receipts')
op.drop_table('receipts')
# Drop enums (SQLite doesn't actually use these, but for consistency)
op.execute("DROP TYPE IF EXISTS receipttype")
op.execute("DROP TYPE IF EXISTS receiptdirection")
op.execute("DROP TYPE IF EXISTS receiptstatus")
op.execute("DROP TYPE IF EXISTS entrytype")

View File

@@ -0,0 +1,37 @@
"""add_tva_breakdown_to_receipt
Revision ID: 1cfb423c6953
Revises: 001_initial
Create Date: 2025-12-12 14:04:22.464289+00:00
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = '1cfb423c6953'
down_revision: Union[str, None] = '001_initial'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add TVA-related columns to receipts table
with op.batch_alter_table('receipts', schema=None) as batch_op:
batch_op.add_column(sa.Column('tva_breakdown', sqlmodel.sql.sqltypes.AutoString(length=1000), nullable=True))
batch_op.add_column(sa.Column('tva_total', sa.Numeric(precision=15, scale=2), nullable=True))
batch_op.add_column(sa.Column('items_count', sa.Integer(), nullable=True))
batch_op.add_column(sa.Column('vendor_address', sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True))
def downgrade() -> None:
# Remove TVA-related columns from receipts table
with op.batch_alter_table('receipts', schema=None) as batch_op:
batch_op.drop_column('vendor_address')
batch_op.drop_column('items_count')
batch_op.drop_column('tva_total')
batch_op.drop_column('tva_breakdown')

View File

@@ -0,0 +1,89 @@
"""add nomenclature tables
Revision ID: 3a653da79002
Revises: 1cfb423c6953
Create Date: 2025-12-13 00:28:05.719430+00:00
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = '3a653da79002'
down_revision: Union[str, None] = '1cfb423c6953'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('local_suppliers',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('company_id', sa.Integer(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
sa.Column('fiscal_code', sqlmodel.sql.sqltypes.AutoString(length=50), nullable=True),
sa.Column('address', sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True),
sa.Column('created_by', sqlmodel.sql.sqltypes.AutoString(length=100), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('pending_oracle_sync', sa.Boolean(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
with op.batch_alter_table('local_suppliers', schema=None) as batch_op:
batch_op.create_index(batch_op.f('ix_local_suppliers_company_id'), ['company_id'], unique=False)
batch_op.create_index(batch_op.f('ix_local_suppliers_fiscal_code'), ['fiscal_code'], unique=False)
op.create_table('synced_cash_registers',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('oracle_id', sa.Integer(), nullable=False),
sa.Column('company_id', sa.Integer(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=100), nullable=False),
sa.Column('account_code', sqlmodel.sql.sqltypes.AutoString(length=20), nullable=False),
sa.Column('register_type', sqlmodel.sql.sqltypes.AutoString(length=10), nullable=False),
sa.Column('synced_at', sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
with op.batch_alter_table('synced_cash_registers', schema=None) as batch_op:
batch_op.create_index(batch_op.f('ix_synced_cash_registers_company_id'), ['company_id'], unique=False)
batch_op.create_index(batch_op.f('ix_synced_cash_registers_oracle_id'), ['oracle_id'], unique=False)
op.create_table('synced_suppliers',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('oracle_id', sa.Integer(), nullable=False),
sa.Column('company_id', sa.Integer(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
sa.Column('fiscal_code', sqlmodel.sql.sqltypes.AutoString(length=50), nullable=True),
sa.Column('address', sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True),
sa.Column('synced_at', sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
with op.batch_alter_table('synced_suppliers', schema=None) as batch_op:
batch_op.create_index(batch_op.f('ix_synced_suppliers_company_id'), ['company_id'], unique=False)
batch_op.create_index(batch_op.f('ix_synced_suppliers_fiscal_code'), ['fiscal_code'], unique=False)
batch_op.create_index(batch_op.f('ix_synced_suppliers_oracle_id'), ['oracle_id'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('synced_suppliers', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_synced_suppliers_oracle_id'))
batch_op.drop_index(batch_op.f('ix_synced_suppliers_fiscal_code'))
batch_op.drop_index(batch_op.f('ix_synced_suppliers_company_id'))
op.drop_table('synced_suppliers')
with op.batch_alter_table('synced_cash_registers', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_synced_cash_registers_oracle_id'))
batch_op.drop_index(batch_op.f('ix_synced_cash_registers_company_id'))
op.drop_table('synced_cash_registers')
with op.batch_alter_table('local_suppliers', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_local_suppliers_fiscal_code'))
batch_op.drop_index(batch_op.f('ix_local_suppliers_company_id'))
op.drop_table('local_suppliers')
# ### end Alembic commands ###

View File

@@ -0,0 +1,35 @@
"""add_ocr_fields_to_receipt
Revision ID: 4b8e5f2a1d93
Revises: 3a653da79002
Create Date: 2025-12-15 10:00:00.000000+00:00
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = '4b8e5f2a1d93'
down_revision: Union[str, None] = '3a653da79002'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add OCR-related columns to receipts table
with op.batch_alter_table('receipts', schema=None) as batch_op:
batch_op.add_column(sa.Column('cui', sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True))
batch_op.add_column(sa.Column('ocr_raw_text', sa.Text(), nullable=True))
batch_op.add_column(sa.Column('payment_methods', sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True))
def downgrade() -> None:
# Remove OCR-related columns from receipts table
with op.batch_alter_table('receipts', schema=None) as batch_op:
batch_op.drop_column('payment_methods')
batch_op.drop_column('ocr_raw_text')
batch_op.drop_column('cui')

View File

@@ -0,0 +1,29 @@
"""Remove partner_id from receipts - supplier data is text-only
Revision ID: 20251215_remove_partner_id
Revises: 20251216_payment_mode
Create Date: 2025-12-15
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '20251215_remove_partner_id'
down_revision: Union[str, None] = '20251216_payment_mode'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Remove partner_id column - supplier data is now text-only (partner_name, cui)."""
# Drop the partner_id column
op.drop_column('receipts', 'partner_id')
def downgrade() -> None:
"""Re-add partner_id column."""
op.add_column('receipts', sa.Column('partner_id', sa.Integer(), nullable=True))

View File

@@ -0,0 +1,44 @@
"""Add payment_mode field to receipts table.
Revision ID: 20251216_payment_mode
Revises: 4b8e5f2a1d93
Create Date: 2024-12-16
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '20251216_payment_mode'
down_revision = '4b8e5f2a1d93'
branch_labels = None
depends_on = None
def upgrade() -> None:
"""Add payment_mode column and migrate existing data."""
with op.batch_alter_table('receipts', schema=None) as batch_op:
batch_op.add_column(sa.Column('payment_mode', sa.String(length=20), nullable=True))
# Migrate existing data based on cash_register_account
op.execute("""
UPDATE receipts
SET payment_mode = 'casa'
WHERE cash_register_account LIKE '531%' AND payment_mode IS NULL
""")
op.execute("""
UPDATE receipts
SET payment_mode = 'banca'
WHERE cash_register_account LIKE '512%' AND payment_mode IS NULL
""")
op.execute("""
UPDATE receipts
SET payment_mode = 'avans_decontare'
WHERE cash_register_account LIKE '542%' AND payment_mode IS NULL
""")
def downgrade() -> None:
"""Remove payment_mode column."""
with op.batch_alter_table('receipts', schema=None) as batch_op:
batch_op.drop_column('payment_mode')

View File

@@ -0,0 +1,40 @@
"""Add needs_manual_review flag to receipts table.
Revision ID: 20251230_needs_manual_review
Revises: 20251216_payment_mode
Create Date: 2025-12-30
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '20251230_needs_manual_review'
down_revision = '20251216_payment_mode'
branch_labels = None
depends_on = None
def upgrade() -> None:
"""Add needs_manual_review column for OCR validation tracking.
This column tracks whether a receipt needs manual supervisor review
based on OCR extraction validation warnings:
- NULL = not validated yet (old receipts before validation feature)
- FALSE = validated, no review needed
- TRUE = validated, needs review
"""
with op.batch_alter_table('receipts', schema=None) as batch_op:
batch_op.add_column(
sa.Column('needs_manual_review', sa.Boolean(), nullable=True)
)
# NOTE: We do NOT set a default value for existing rows.
# NULL indicates the receipt was created before validation was implemented.
# Only new receipts (created after this migration) will have TRUE/FALSE values.
def downgrade() -> None:
"""Remove needs_manual_review column."""
with op.batch_alter_table('receipts', schema=None) as batch_op:
batch_op.drop_column('needs_manual_review')

View File

@@ -0,0 +1,74 @@
"""Add OCR settings and metrics tables.
Revision ID: add_ocr_settings_metrics
Revises: 20251230_add_needs_manual_review
Create Date: 2025-12-31
This migration adds:
- user_ocr_preferences: Store user's preferred OCR engine
- ocr_job_metrics: Store OCR job processing metrics for analytics
"""
from alembic import op
import sqlalchemy as sa
# Revision identifiers
revision = 'add_ocr_settings_metrics'
down_revision = '20251230_add_needs_manual_review'
branch_labels = None
depends_on = None
def upgrade() -> None:
"""Create OCR settings and metrics tables."""
# Create user_ocr_preferences table
op.create_table(
'user_ocr_preferences',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('username', sa.String(length=100), nullable=False),
sa.Column('preferred_engine', sa.String(length=20), nullable=False, server_default='doctr_plus'),
sa.Column('created_at', sa.DateTime(), nullable=False, server_default=sa.func.now()),
sa.Column('updated_at', sa.DateTime(), nullable=False, server_default=sa.func.now()),
sa.PrimaryKeyConstraint('id')
)
op.create_index('ix_user_ocr_preferences_username', 'user_ocr_preferences', ['username'], unique=True)
# Create ocr_job_metrics table
op.create_table(
'ocr_job_metrics',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('job_id', sa.String(length=50), nullable=False),
sa.Column('username', sa.String(length=100), nullable=False),
sa.Column('company_id', sa.Integer(), nullable=True),
sa.Column('engine_requested', sa.String(length=20), nullable=False),
sa.Column('engine_used', sa.String(length=50), nullable=False),
sa.Column('processing_time_ms', sa.Integer(), nullable=False, server_default='0'),
sa.Column('file_size_bytes', sa.Integer(), nullable=False, server_default='0'),
sa.Column('file_type', sa.String(length=50), nullable=False, server_default='image/jpeg'),
sa.Column('success', sa.Boolean(), nullable=False, server_default='1'),
sa.Column('error_message', sa.String(length=500), nullable=True),
sa.Column('overall_confidence', sa.Float(), nullable=False, server_default='0.0'),
sa.Column('fields_extracted', sa.Integer(), nullable=False, server_default='0'),
sa.Column('needs_manual_review', sa.Boolean(), nullable=True),
sa.Column('validation_warnings_count', sa.Integer(), nullable=False, server_default='0'),
sa.Column('validation_errors_count', sa.Integer(), nullable=False, server_default='0'),
sa.Column('created_at', sa.DateTime(), nullable=False, server_default=sa.func.now()),
sa.PrimaryKeyConstraint('id')
)
op.create_index('ix_ocr_job_metrics_job_id', 'ocr_job_metrics', ['job_id'], unique=True)
op.create_index('ix_ocr_job_metrics_username', 'ocr_job_metrics', ['username'], unique=False)
op.create_index('ix_ocr_job_metrics_company_id', 'ocr_job_metrics', ['company_id'], unique=False)
op.create_index('ix_ocr_job_metrics_created_at', 'ocr_job_metrics', ['created_at'], unique=False)
def downgrade() -> None:
"""Drop OCR settings and metrics tables."""
op.drop_index('ix_ocr_job_metrics_created_at', table_name='ocr_job_metrics')
op.drop_index('ix_ocr_job_metrics_company_id', table_name='ocr_job_metrics')
op.drop_index('ix_ocr_job_metrics_username', table_name='ocr_job_metrics')
op.drop_index('ix_ocr_job_metrics_job_id', table_name='ocr_job_metrics')
op.drop_table('ocr_job_metrics')
op.drop_index('ix_user_ocr_preferences_username', table_name='user_ocr_preferences')
op.drop_table('user_ocr_preferences')

View File

@@ -0,0 +1,30 @@
"""Add original_filename to ocr_job_metrics.
Revision ID: add_original_filename_to_metrics
Revises: add_ocr_settings_metrics
Create Date: 2025-12-31
Adds original_filename column to track the uploaded filename.
"""
from alembic import op
import sqlalchemy as sa
# Revision identifiers
revision = 'add_original_filename_to_metrics'
down_revision = 'add_ocr_settings_metrics'
branch_labels = None
depends_on = None
def upgrade() -> None:
"""Add original_filename column to ocr_job_metrics."""
op.add_column(
'ocr_job_metrics',
sa.Column('original_filename', sa.String(length=255), nullable=True)
)
def downgrade() -> None:
"""Remove original_filename column."""
op.drop_column('ocr_job_metrics', 'original_filename')

View File

@@ -0,0 +1,54 @@
"""Add company_id to batch_uploads table.
Revision ID: 20260109_batch_company
Revises: 20251231_add_original_filename_to_metrics
Create Date: 2026-01-09
This migration adds the company_id column to batch_uploads to support
automatic receipt creation during bulk upload processing.
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '20260109_batch_company'
down_revision = None # Will be auto-detected
branch_labels = None
depends_on = None
def upgrade() -> None:
"""Add company_id column to batch_uploads table."""
# Check if column already exists (SQLModel may have created it)
conn = op.get_bind()
inspector = sa.inspect(conn)
# Check if batch_uploads table exists
if 'batch_uploads' in inspector.get_table_names():
columns = [col['name'] for col in inspector.get_columns('batch_uploads')]
if 'company_id' not in columns:
op.add_column(
'batch_uploads',
sa.Column('company_id', sa.Integer(), nullable=True)
)
# Create index for company_id
op.create_index(
'ix_batch_uploads_company_id',
'batch_uploads',
['company_id'],
unique=False
)
def downgrade() -> None:
"""Remove company_id column from batch_uploads table."""
conn = op.get_bind()
inspector = sa.inspect(conn)
if 'batch_uploads' in inspector.get_table_names():
columns = [col['name'] for col in inspector.get_columns('batch_uploads')]
if 'company_id' in columns:
op.drop_index('ix_batch_uploads_company_id', table_name='batch_uploads')
op.drop_column('batch_uploads', 'company_id')

View File

@@ -0,0 +1,125 @@
"""Add batch processing fields to receipts table.
Revision ID: add_batch_processing_fields
Revises: add_original_filename_to_metrics
Create Date: 2026-01-11
Adds fields for bulk upload batch tracking:
- batch_id: UUID string for grouping receipts from same upload
- processing_status: enum (pending/processing/completed/failed)
- processing_error: full error message text
- file_hash: SHA-256 hash for duplicate detection
- processing_started_at: when OCR processing started
- processing_completed_at: when OCR processing completed
Also creates indexes for efficient querying.
"""
from alembic import op
import sqlalchemy as sa
# Revision identifiers
revision = 'add_batch_processing_fields'
down_revision = 'add_original_filename_to_metrics'
branch_labels = None
depends_on = None
def upgrade() -> None:
"""Add batch processing columns to receipts table."""
conn = op.get_bind()
inspector = sa.inspect(conn)
# Get existing columns
columns = [col['name'] for col in inspector.get_columns('receipts')]
# Add batch_id column with index
if 'batch_id' not in columns:
op.add_column(
'receipts',
sa.Column('batch_id', sa.String(length=50), nullable=True)
)
op.create_index(
'ix_receipts_batch_id',
'receipts',
['batch_id'],
unique=False
)
# Add processing_status column with index
if 'processing_status' not in columns:
op.add_column(
'receipts',
sa.Column('processing_status', sa.String(length=20), nullable=True)
)
op.create_index(
'ix_receipts_processing_status',
'receipts',
['processing_status'],
unique=False
)
# Add processing_error column (TEXT for full error messages)
if 'processing_error' not in columns:
op.add_column(
'receipts',
sa.Column('processing_error', sa.Text(), nullable=True)
)
# Add file_hash column with index for duplicate detection
if 'file_hash' not in columns:
op.add_column(
'receipts',
sa.Column('file_hash', sa.String(length=64), nullable=True)
)
op.create_index(
'ix_receipts_file_hash',
'receipts',
['file_hash'],
unique=False
)
# Add processing_started_at column
if 'processing_started_at' not in columns:
op.add_column(
'receipts',
sa.Column('processing_started_at', sa.DateTime(), nullable=True)
)
# Add processing_completed_at column
if 'processing_completed_at' not in columns:
op.add_column(
'receipts',
sa.Column('processing_completed_at', sa.DateTime(), nullable=True)
)
def downgrade() -> None:
"""Remove batch processing columns from receipts table."""
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('receipts')]
indexes = [idx['name'] for idx in inspector.get_indexes('receipts')]
# Remove indexes first (SQLite batch mode)
if 'ix_receipts_batch_id' in indexes:
op.drop_index('ix_receipts_batch_id', table_name='receipts')
if 'ix_receipts_processing_status' in indexes:
op.drop_index('ix_receipts_processing_status', table_name='receipts')
if 'ix_receipts_file_hash' in indexes:
op.drop_index('ix_receipts_file_hash', table_name='receipts')
# Remove columns (in reverse order of addition)
if 'processing_completed_at' in columns:
op.drop_column('receipts', 'processing_completed_at')
if 'processing_started_at' in columns:
op.drop_column('receipts', 'processing_started_at')
if 'file_hash' in columns:
op.drop_column('receipts', 'file_hash')
if 'processing_error' in columns:
op.drop_column('receipts', 'processing_error')
if 'processing_status' in columns:
op.drop_column('receipts', 'processing_status')
if 'batch_id' in columns:
op.drop_column('receipts', 'batch_id')

View File

@@ -0,0 +1,39 @@
"""Data Entry module router factory."""
from fastapi import APIRouter
def create_data_entry_router() -> APIRouter:
"""
Create and configure Data Entry module router.
Includes all data entry endpoints:
- /receipts - Receipt CRUD and workflow
- /ocr - OCR processing for receipts
- /nomenclature - Nomenclature syncing from Oracle
- /settings - User settings (OCR preferences)
- /metrics - OCR analytics and metrics
- /bulk - Bulk upload for batch processing
Returns:
APIRouter: Configured router for data entry module
"""
router = APIRouter()
# Import routers here to avoid circular imports
from .receipts import router as receipts_router
from .ocr import router as ocr_router
from .nomenclature import router as nomenclature_router
from .ocr_settings import router as ocr_settings_router
from .bulk import router as bulk_router
# Include all sub-routers (no prefix - already prefixed in main.py with /api/data-entry)
router.include_router(receipts_router, prefix="/receipts", tags=["data-entry-receipts"])
router.include_router(ocr_router, prefix="/ocr", tags=["data-entry-ocr"])
router.include_router(nomenclature_router, prefix="/nomenclature", tags=["data-entry-nomenclature"])
# OCR settings and metrics (endpoints at /settings/* and /metrics/*)
router.include_router(ocr_settings_router, tags=["data-entry-settings"])
# Bulk upload for batch processing
router.include_router(bulk_router, prefix="/bulk", tags=["data-entry-bulk"])
return router

View File

@@ -0,0 +1,997 @@
"""
Bulk upload API endpoints for batch receipt processing.
Endpoints:
- POST /upload - Submit multiple files for OCR processing in a single batch
- GET /batches/{batch_id}/status - Get batch status with optional long-polling
Validation:
- Max 100 files per batch
- Max 10MB per file
- Allowed types: PDF, PNG, JPG
Duplicate Detection (US-007):
- SHA-256 hash calculated for each file
- Duplicate files (same hash + company_id) are rejected with 409 Conflict info
- Duplicates reported in error list, non-duplicates processed normally
"""
import asyncio
import hashlib
import logging
from datetime import datetime
from decimal import Decimal
from pathlib import Path
from typing import Annotated, List, Optional, Union
from fastapi import APIRouter, HTTPException, UploadFile, File, Depends, Query, Header
from sqlalchemy import select, func, and_
from sqlalchemy.ext.asyncio import AsyncSession
from backend.modules.data_entry.db.database import get_session
from backend.modules.data_entry.db.models import BatchUpload, BatchJob, BatchStatus, Receipt, ReceiptAttachment
from backend.modules.data_entry.schemas.bulk import (
BulkUploadResponse,
BulkUploadResponseWithDuplicates,
BatchStatusResponse,
BatchJobInfo,
DuplicateFileInfo,
RetryResponse,
BatchRetryResponse,
CancelJobResponse,
CancelBatchResponse
)
from backend.modules.data_entry.services.ocr.job_queue import job_queue, OCRJobStatus
from backend.config import settings
# Auth integration
from shared.auth.dependencies import get_current_user
from shared.auth.models import CurrentUser
logger = logging.getLogger(__name__)
router = APIRouter()
# ============ Helper for selected company from header ============
async def get_selected_company(
current_user: CurrentUser = Depends(get_current_user),
x_selected_company: Annotated[Optional[str], Header()] = None
) -> int:
"""
Get selected company from X-Selected-Company header.
Validates that the user has access to the specified company.
Falls back to user's first company if no header is provided.
"""
if x_selected_company:
try:
company_id = int(x_selected_company)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Invalid company ID format: {x_selected_company}"
)
if str(company_id) in current_user.companies:
return company_id
raise HTTPException(
status_code=403,
detail=f"Nu aveți acces la firma {company_id}"
)
# No header - use first company from user's list
if current_user.companies:
try:
return int(current_user.companies[0])
except (ValueError, IndexError):
pass
raise HTTPException(
status_code=400,
detail="Nu aveți nicio firmă asignată"
)
# Validation constants
MAX_FILES_PER_BATCH = 100
MAX_FILE_SIZE_BYTES = 10 * 1024 * 1024 # 10MB
ALLOWED_MIME_TYPES = {"image/jpeg", "image/png", "application/pdf"}
def compute_file_hash(content: bytes) -> str:
"""
Compute SHA-256 hash of file content.
Used for duplicate detection - same file content = same hash.
Args:
content: Raw file bytes
Returns:
Hexadecimal string of SHA-256 hash (64 characters)
"""
return hashlib.sha256(content).hexdigest()
async def check_duplicate_hashes(
session: AsyncSession,
file_hashes: List[str],
company_id: int
) -> dict[str, int]:
"""
Check which file hashes already exist in the database for this company.
Args:
session: Database session
file_hashes: List of SHA-256 hashes to check
company_id: Company ID to scope the duplicate check
Returns:
Dict mapping hash -> existing receipt_id for duplicates found
"""
if not file_hashes:
return {}
# Query for existing receipts with these hashes for this company
result = await session.execute(
select(Receipt.file_hash, Receipt.id).where(
and_(
Receipt.file_hash.in_(file_hashes),
Receipt.company_id == company_id
)
)
)
# Build hash -> receipt_id mapping
# Note: result.all() is synchronous in SQLAlchemy async, returns list of tuples
duplicates = {}
rows = result.all()
for row in rows:
duplicates[row[0]] = row[1]
return duplicates
@router.post("/upload", response_model=Union[BulkUploadResponse, BulkUploadResponseWithDuplicates])
async def bulk_upload(
files: List[UploadFile] = File(..., description="Multiple files to upload"),
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
selected_company: int = Depends(get_selected_company)
):
"""
Upload multiple files for batch OCR processing.
Creates a batch record and queues all files as OCR jobs.
Invalid files cause entire batch rejection (validation errors).
Duplicate files are reported separately and skipped - non-duplicates are processed.
Duplicate Detection (US-007):
- SHA-256 hash calculated for each file before processing
- Files with existing hash for same company are rejected with 409 info
- Response includes duplicate details with existing_receipt_id
Args:
files: List of image/PDF files (max 100 files, max 10MB each)
Returns:
BulkUploadResponse with batch_id and list of job_ids
BulkUploadResponseWithDuplicates if some files were duplicates
Raises:
400: If validation fails (too many files, file too large, invalid type)
409: If ALL files are duplicates
500: If job creation fails
"""
# Validate file count
if len(files) == 0:
raise HTTPException(
status_code=400,
detail="No files provided"
)
if len(files) > MAX_FILES_PER_BATCH:
raise HTTPException(
status_code=400,
detail=f"Too many files. Maximum {MAX_FILES_PER_BATCH} files per batch."
)
# Pre-validate all files before creating any jobs (atomic check)
invalid_files = []
file_contents = []
for file in files:
# Check MIME type
if file.content_type not in ALLOWED_MIME_TYPES:
invalid_files.append(f"{file.filename}: Invalid type ({file.content_type})")
continue
# Read content and check size
content = await file.read()
if len(content) > MAX_FILE_SIZE_BYTES:
invalid_files.append(f"{file.filename}: File too large ({len(content) // (1024*1024)}MB > 10MB)")
continue
# Compute SHA-256 hash for duplicate detection (US-007)
file_hash = compute_file_hash(content)
# Store for later processing
file_contents.append({
"filename": file.filename,
"content": content,
"mime_type": file.content_type,
"file_hash": file_hash
})
# If any files are invalid, reject the entire batch
if invalid_files:
raise HTTPException(
status_code=400,
detail={
"message": f"Validation failed for {len(invalid_files)} file(s)",
"invalid_files": invalid_files
}
)
# Check for duplicates BEFORE creating batch (US-007)
all_hashes = [f["file_hash"] for f in file_contents]
existing_duplicates = await check_duplicate_hashes(session, all_hashes, selected_company)
# Separate duplicate files from processable files
duplicate_files: List[DuplicateFileInfo] = []
processable_files = []
for file_data in file_contents:
if file_data["file_hash"] in existing_duplicates:
existing_receipt_id = existing_duplicates[file_data["file_hash"]]
duplicate_files.append(DuplicateFileInfo(
filename=file_data["filename"],
error="duplicate",
existing_receipt_id=existing_receipt_id,
message=f"Fișier duplicat - există deja ca bon #{existing_receipt_id}"
))
logger.info(
f"[BulkUpload] Duplicate detected: {file_data['filename']} "
f"(hash={file_data['file_hash'][:16]}...) matches receipt #{existing_receipt_id}"
)
else:
processable_files.append(file_data)
# If ALL files are duplicates, return 409 Conflict
if len(duplicate_files) == len(file_contents):
raise HTTPException(
status_code=409,
detail={
"error": "all_duplicates",
"message": f"Toate cele {len(duplicate_files)} fișiere sunt duplicate",
"duplicates": [d.model_dump() for d in duplicate_files]
}
)
# If no processable files remain after filtering (shouldn't happen but be safe)
if not processable_files:
raise HTTPException(
status_code=409,
detail={
"error": "no_files_to_process",
"message": "Nu există fișiere de procesat",
"duplicates": [d.model_dump() for d in duplicate_files]
}
)
# Create batch record with company_id for auto-save
batch = BatchUpload(
user_id=current_user.username,
company_id=selected_company,
status=BatchStatus.PENDING,
total_files=len(processable_files) # Only count processable files
)
session.add(batch)
await session.flush() # Get batch.id before creating jobs
# Create OCR jobs for processable files only
job_ids = []
batch_jobs = []
try:
for file_data in processable_files:
# Create OCR job using existing job_queue
# Pass batch_id and file_hash for tracking
job = await job_queue.create_job(
file_bytes=file_data["content"],
mime_type=file_data["mime_type"],
engine="doctr_plus", # Default engine for bulk
username=current_user.username,
original_filename=file_data["filename"],
batch_id=batch.id, # Link job to batch for auto-save integration
file_hash=file_data["file_hash"] # Pass hash for storage in receipt
)
job_ids.append(job.id)
# Create batch_job link
batch_job = BatchJob(
batch_id=batch.id,
job_id=job.id,
filename=file_data["filename"]
)
batch_jobs.append(batch_job)
# Add all batch_job records
for bj in batch_jobs:
session.add(bj)
# Commit everything atomically
await session.commit()
logger.info(
f"[BulkUpload] Created batch {batch.id} with {len(job_ids)} jobs "
f"for user {current_user.username}"
f"{f', {len(duplicate_files)} duplicates skipped' if duplicate_files else ''}"
)
# Return response with duplicate info if any duplicates were found
if duplicate_files:
return BulkUploadResponseWithDuplicates(
batch_id=batch.id,
job_ids=job_ids,
total_files=len(file_contents),
processed_files=len(job_ids),
duplicate_files=len(duplicate_files),
duplicates=duplicate_files,
message=f"{len(job_ids)} fișier(e) în procesare, {len(duplicate_files)} duplicate ignorate"
)
return BulkUploadResponse(
batch_id=batch.id,
job_ids=job_ids,
total_files=len(job_ids),
message=f"{len(job_ids)} files queued for processing"
)
except Exception as e:
# Rollback on any error
await session.rollback()
logger.error(f"[BulkUpload] Failed to create batch: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to create batch: {str(e)}"
)
# Long-polling constants
MAX_WAIT_SECONDS = 30
POLL_INTERVAL_SECONDS = 0.5
async def _get_batch_status_snapshot(
batch_id: int,
session: AsyncSession
) -> Optional[dict]:
"""
Get current batch status snapshot.
Returns dict with status counts and jobs list, or None if batch not found.
"""
# Get batch record
batch_result = await session.execute(
select(BatchUpload).where(BatchUpload.id == batch_id)
)
batch = batch_result.scalar_one_or_none()
if not batch:
return None
# Get all batch_jobs for this batch
batch_jobs_result = await session.execute(
select(BatchJob).where(BatchJob.batch_id == batch_id)
)
batch_jobs = batch_jobs_result.scalars().all()
if not batch_jobs:
return {
"batch": batch,
"pending_count": 0,
"processing_count": 0,
"completed_count": 0,
"failed_count": 0,
"jobs": [],
"total_amount": None
}
# Get job statuses and error_messages from OCR job queue (SQLite)
job_statuses = {}
job_errors = {}
for bj in batch_jobs:
job = await job_queue.get_job(bj.job_id)
if job:
job_statuses[bj.job_id] = job.status.value
job_errors[bj.job_id] = job.error_message
else:
# Job not found in queue - treat as failed
job_statuses[bj.job_id] = "failed"
job_errors[bj.job_id] = "Job not found in queue"
# Count by status
pending_count = sum(1 for s in job_statuses.values() if s == "pending")
processing_count = sum(1 for s in job_statuses.values() if s == "processing")
completed_count = sum(1 for s in job_statuses.values() if s == "completed")
failed_count = sum(1 for s in job_statuses.values() if s == "failed")
# Build jobs list with status info
jobs_info = []
for bj in batch_jobs:
jobs_info.append({
"job_id": bj.job_id,
"filename": bj.filename,
"status": job_statuses.get(bj.job_id, "failed"),
"receipt_id": bj.receipt_id,
"error_message": job_errors.get(bj.job_id)
})
# Calculate total_amount from completed receipts
total_amount = None
receipt_ids = [bj.receipt_id for bj in batch_jobs if bj.receipt_id is not None]
if receipt_ids:
amount_result = await session.execute(
select(func.sum(Receipt.amount)).where(Receipt.id.in_(receipt_ids))
)
total_sum = amount_result.scalar()
if total_sum is not None:
total_amount = float(total_sum)
return {
"batch": batch,
"pending_count": pending_count,
"processing_count": processing_count,
"completed_count": completed_count,
"failed_count": failed_count,
"jobs": jobs_info,
"total_amount": total_amount
}
def _compute_batch_overall_status(pending: int, processing: int, completed: int, failed: int, total: int) -> str:
"""Compute overall batch status from job counts."""
if pending + processing == 0:
# All jobs finished
if failed == total:
return BatchStatus.FAILED.value
return BatchStatus.COMPLETED.value
elif processing > 0 or completed > 0 or failed > 0:
return BatchStatus.PROCESSING.value
else:
return BatchStatus.PENDING.value
@router.get("/batches/{batch_id}/status", response_model=BatchStatusResponse)
async def get_batch_status(
batch_id: int,
wait: Optional[int] = Query(
default=None,
ge=0,
le=MAX_WAIT_SECONDS,
description="Long-polling wait time in seconds (max 30)"
),
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Get batch processing status with optional long-polling.
Returns aggregated status counts and individual job statuses.
When `wait` parameter is provided, the endpoint will poll until:
- Status changes from initial snapshot
- All jobs complete (pending + processing = 0)
- Timeout is reached
Args:
batch_id: Batch ID to query
wait: Optional wait time in seconds for long-polling (0-30)
Returns:
BatchStatusResponse with status counts and job details
Raises:
404: If batch not found
"""
# Get initial snapshot
snapshot = await _get_batch_status_snapshot(batch_id, session)
if snapshot is None:
raise HTTPException(
status_code=404,
detail=f"Batch {batch_id} not found"
)
# If long-polling requested and jobs still in progress
if wait and wait > 0:
initial_pending = snapshot["pending_count"]
initial_processing = snapshot["processing_count"]
initial_completed = snapshot["completed_count"]
initial_failed = snapshot["failed_count"]
# Only wait if there are still jobs in progress
if initial_pending + initial_processing > 0:
elapsed = 0.0
while elapsed < wait:
await asyncio.sleep(POLL_INTERVAL_SECONDS)
elapsed += POLL_INTERVAL_SECONDS
# Refresh snapshot
snapshot = await _get_batch_status_snapshot(batch_id, session)
if snapshot is None:
# Batch deleted during polling (edge case)
raise HTTPException(status_code=404, detail=f"Batch {batch_id} not found")
# Check if status changed
current_pending = snapshot["pending_count"]
current_processing = snapshot["processing_count"]
current_completed = snapshot["completed_count"]
current_failed = snapshot["failed_count"]
if (current_pending != initial_pending or
current_processing != initial_processing or
current_completed != initial_completed or
current_failed != initial_failed):
# Status changed, return immediately
break
# Check if all jobs finished
if current_pending + current_processing == 0:
break
# Build response
batch = snapshot["batch"]
total_files = batch.total_files
overall_status = _compute_batch_overall_status(
snapshot["pending_count"],
snapshot["processing_count"],
snapshot["completed_count"],
snapshot["failed_count"],
total_files
)
jobs = [
BatchJobInfo(
job_id=j["job_id"],
filename=j["filename"],
status=j["status"],
receipt_id=j["receipt_id"],
error_message=j.get("error_message")
)
for j in snapshot["jobs"]
]
return BatchStatusResponse(
batch_id=batch.id,
status=overall_status,
total_files=total_files,
pending_count=snapshot["pending_count"],
processing_count=snapshot["processing_count"],
completed_count=snapshot["completed_count"],
failed_count=snapshot["failed_count"],
jobs=jobs,
total_amount=snapshot["total_amount"],
created_at=batch.created_at
)
# ============ Retry Endpoints (US-006) ============
async def _retry_single_receipt(
session: AsyncSession,
receipt: Receipt,
username: str
) -> tuple[bool, Optional[str], Optional[str]]:
"""
Retry processing for a single receipt.
Finds the original file from attachments, resets processing status,
and creates a new OCR job.
Args:
session: Database session
receipt: Receipt to retry
username: Username for the new OCR job
Returns:
Tuple of (success, job_id, error_message)
"""
# Get the first attachment to find the source file
attachments_result = await session.execute(
select(ReceiptAttachment)
.where(ReceiptAttachment.receipt_id == receipt.id)
.limit(1)
)
attachment = attachments_result.scalar_one_or_none()
if not attachment:
return False, None, "Bonul nu are fișier atașat"
# Construct full path to attachment file
file_path = settings.data_entry_upload_path_resolved / attachment.file_path
if not file_path.exists():
return False, None, "Fișierul original nu mai este disponibil"
# Read file content
try:
with open(file_path, 'rb') as f:
file_bytes = f.read()
except Exception as e:
logger.error(f"[Retry] Failed to read file {file_path}: {e}")
return False, None, f"Eroare la citirea fișierului: {str(e)}"
# Create new OCR job
try:
job = await job_queue.create_job(
file_bytes=file_bytes,
mime_type=attachment.mime_type,
engine="doctr_plus",
username=username,
original_filename=attachment.filename,
batch_id=None, # No batch for retry - direct processing
file_hash=receipt.file_hash
)
# Reset receipt processing status
receipt.processing_status = "pending"
receipt.processing_error = None
receipt.processing_started_at = datetime.utcnow()
receipt.processing_completed_at = None
await session.flush()
logger.info(f"[Retry] Receipt {receipt.id} requeued as job {job.id}")
return True, job.id, None
except Exception as e:
logger.error(f"[Retry] Failed to create job for receipt {receipt.id}: {e}")
return False, None, f"Eroare la crearea job-ului OCR: {str(e)}"
@router.post("/retry/{receipt_id}", response_model=RetryResponse)
async def retry_receipt(
receipt_id: int,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
selected_company: int = Depends(get_selected_company)
):
"""
Retry OCR processing for a single failed receipt.
Resets the receipt's processing_status to 'pending' and creates
a new OCR job using the original attachment file.
Args:
receipt_id: ID of the receipt to retry
Returns:
RetryResponse with success status and new job ID
Raises:
404: If receipt not found
400: If receipt is not in 'failed' status
400: If original file is not available
"""
# Get the receipt
result = await session.execute(
select(Receipt).where(
and_(
Receipt.id == receipt_id,
Receipt.company_id == selected_company
)
)
)
receipt = result.scalar_one_or_none()
if not receipt:
raise HTTPException(
status_code=404,
detail=f"Bonul #{receipt_id} nu a fost găsit"
)
# Verify receipt is in failed status
if receipt.processing_status != "failed":
raise HTTPException(
status_code=400,
detail=f"Bonul nu este în stare de eroare (status actual: {receipt.processing_status})"
)
# Attempt retry
success, job_id, error = await _retry_single_receipt(
session, receipt, current_user.username
)
if not success:
raise HTTPException(
status_code=400,
detail=error or "Eroare necunoscută la reîncărcare"
)
await session.commit()
return RetryResponse(
success=True,
receipt_id=receipt_id,
job_id=job_id,
message="Bon reîncarcat în procesare"
)
@router.post("/retry-batch/{batch_id}", response_model=BatchRetryResponse)
async def retry_batch_failed(
batch_id: str,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
selected_company: int = Depends(get_selected_company)
):
"""
Retry all failed receipts in a batch.
Finds all receipts with batch_id matching and processing_status='failed',
then attempts to retry each one.
Args:
batch_id: Batch ID (UUID string from receipt.batch_id)
Returns:
BatchRetryResponse with counts of successful and failed retries
Raises:
404: If no failed receipts found for batch
"""
# Find all failed receipts in this batch
result = await session.execute(
select(Receipt).where(
and_(
Receipt.batch_id == batch_id,
Receipt.company_id == selected_company,
Receipt.processing_status == "failed"
)
)
)
failed_receipts = result.scalars().all()
if not failed_receipts:
raise HTTPException(
status_code=404,
detail=f"Nu există bonuri cu erori în batch-ul {batch_id}"
)
# Retry each receipt
retried_count = 0
failed_count = 0
errors = []
for receipt in failed_receipts:
success, job_id, error = await _retry_single_receipt(
session, receipt, current_user.username
)
if success:
retried_count += 1
else:
failed_count += 1
errors.append(f"Bon #{receipt.id}: {error}")
await session.commit()
return BatchRetryResponse(
success=retried_count > 0,
batch_id=batch_id,
retried_count=retried_count,
failed_count=failed_count,
errors=errors,
message=f"{retried_count} bonuri reîncarcate în procesare"
+ (f", {failed_count} erori" if failed_count > 0 else "")
)
# ============ Cancel Endpoints (US-014) ============
@router.post("/cancel/{job_id}", response_model=CancelJobResponse)
async def cancel_job(
job_id: str,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Cancel a single OCR processing job.
Only jobs with status 'pending' or 'processing' can be cancelled.
Jobs with status 'completed' or 'failed' cannot be cancelled.
Important: If a receipt has already been created from this job,
it will NOT be deleted - receipts are preserved for audit purposes.
Args:
job_id: The UUID of the OCR job to cancel
Returns:
CancelJobResponse with cancellation details
Raises:
404: If job not found in batch_jobs table
400: If job has already completed or failed
"""
# Find the job in batch_jobs table
batch_job_result = await session.execute(
select(BatchJob).where(BatchJob.job_id == job_id)
)
batch_job = batch_job_result.scalar_one_or_none()
if not batch_job:
raise HTTPException(
status_code=404,
detail=f"Job {job_id} nu a fost găsit"
)
# Get the OCR job from job_queue to check current status
ocr_job = await job_queue.get_job(job_id)
if not ocr_job:
raise HTTPException(
status_code=404,
detail=f"Job {job_id} nu există în coada de procesare"
)
# Check if job can be cancelled
current_status = ocr_job.status.value
if current_status == OCRJobStatus.completed.value:
raise HTTPException(
status_code=400,
detail=f"Job-ul a fost deja procesat cu succes. Nu poate fi anulat."
)
if current_status == OCRJobStatus.failed.value:
raise HTTPException(
status_code=400,
detail=f"Job-ul a eșuat deja. Folosiți opțiunea de reîncercare în loc de anulare."
)
if current_status == OCRJobStatus.cancelled.value:
raise HTTPException(
status_code=400,
detail=f"Job-ul a fost deja anulat."
)
# Update job status to cancelled in job_queue (SQLite)
cancelled_at = datetime.utcnow()
success = await job_queue.update_status(
job_id=job_id,
status=OCRJobStatus.cancelled,
error="Cancelled by user"
)
if not success:
raise HTTPException(
status_code=500,
detail=f"Eroare la anularea job-ului"
)
logger.info(
f"[CancelJob] Job {job_id} cancelled by {current_user.username} "
f"(previous status: {current_status})"
)
return CancelJobResponse(
success=True,
job_id=job_id,
cancelled_at=cancelled_at,
message=f"Job anulat cu succes"
)
@router.post("/cancel-batch/{batch_id}", response_model=CancelBatchResponse)
async def cancel_batch(
batch_id: int,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Cancel all pending/processing jobs in a batch.
Finds all jobs with status 'pending' or 'processing' in the specified batch
and marks them as 'cancelled'. Jobs with status 'completed' or 'failed'
are not affected.
Important: Receipts that have already been created from completed jobs
will NOT be deleted - they are preserved for audit purposes.
Args:
batch_id: The batch ID to cancel
Returns:
CancelBatchResponse with counts of cancelled and skipped jobs
Raises:
404: If batch not found or no jobs exist for batch
"""
# Verify batch exists
batch_result = await session.execute(
select(BatchUpload).where(BatchUpload.id == batch_id)
)
batch = batch_result.scalar_one_or_none()
if not batch:
raise HTTPException(
status_code=404,
detail=f"Batch {batch_id} nu a fost găsit"
)
# Get all batch_jobs for this batch
batch_jobs_result = await session.execute(
select(BatchJob).where(BatchJob.batch_id == batch_id)
)
batch_jobs = batch_jobs_result.scalars().all()
if not batch_jobs:
raise HTTPException(
status_code=404,
detail=f"Nu există job-uri în batch-ul {batch_id}"
)
# Process each job - cancel pending/processing, skip completed/failed
cancelled_count = 0
skipped_count = 0
for batch_job in batch_jobs:
# Get current job status from OCR job queue
ocr_job = await job_queue.get_job(batch_job.job_id)
if not ocr_job:
# Job not found in queue - treat as skipped
skipped_count += 1
continue
current_status = ocr_job.status.value
# Only cancel pending or processing jobs
if current_status in (OCRJobStatus.pending.value, OCRJobStatus.processing.value):
success = await job_queue.update_status(
job_id=batch_job.job_id,
status=OCRJobStatus.cancelled,
error="Cancelled by user (batch cancel)"
)
if success:
cancelled_count += 1
logger.debug(f"[CancelBatch] Cancelled job {batch_job.job_id}")
else:
# Failed to cancel - count as skipped
skipped_count += 1
logger.warning(
f"[CancelBatch] Failed to cancel job {batch_job.job_id}"
)
else:
# Job is completed, failed, or already cancelled - skip it
skipped_count += 1
logger.info(
f"[CancelBatch] Batch {batch_id} cancelled by {current_user.username}: "
f"{cancelled_count} cancelled, {skipped_count} skipped"
)
# Build message
if cancelled_count == 0:
message = f"Nu există job-uri de anulat în batch-ul {batch_id}"
elif skipped_count == 0:
message = f"{cancelled_count} job-uri anulate"
else:
message = f"{cancelled_count} job-uri anulate, {skipped_count} ignorate (deja procesate)"
return CancelBatchResponse(
success=cancelled_count > 0,
batch_id=batch_id,
cancelled_count=cancelled_count,
skipped_count=skipped_count,
message=message
)

View File

@@ -0,0 +1,260 @@
"""Nomenclature API endpoints."""
from typing import Optional, List, Annotated
from fastapi import APIRouter, Depends, HTTPException, Header, Request
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel
from backend.modules.data_entry.db.database import get_session
from backend.modules.data_entry.services.sync_service import SyncService
# Import auth dependencies
import sys
from pathlib import Path
# Path setup handled by main.py - this is redundant
# project_root = Path(__file__).parent.parent.parent.parent.parent
# sys.path.insert(0, str(project_root / "shared"))
from shared.auth.dependencies import get_current_user
from shared.auth.models import CurrentUser
router = APIRouter()
# ============ Selected Company Dependency ============
async def get_selected_company(
current_user: CurrentUser = Depends(get_current_user),
x_selected_company: Annotated[Optional[str], Header()] = None
) -> int:
"""
Get selected company from X-Selected-Company header.
Validates user access. Falls back to first company if no header.
"""
if x_selected_company:
try:
company_id = int(x_selected_company)
except ValueError:
raise HTTPException(400, f"Invalid company ID: {x_selected_company}")
if str(company_id) in current_user.companies:
return company_id
raise HTTPException(403, f"Nu aveți acces la firma {company_id}")
if current_user.companies:
try:
return int(current_user.companies[0])
except (ValueError, IndexError):
pass
raise HTTPException(400, "Nu aveți nicio firmă asignată")
SelectedCompany = Annotated[int, Depends(get_selected_company)]
# Request/Response Models
class SupplierSearchResult(BaseModel):
found: bool
supplier: Optional[dict] = None
source: str # 'synced', 'local', 'not_found'
class LocalSupplierCreate(BaseModel):
name: str
fiscal_code: Optional[str] = None
address: Optional[str] = None
class LocalSupplierResponse(BaseModel):
id: int
name: str
fiscal_code: Optional[str]
address: Optional[str]
is_local: bool = True
class SyncResult(BaseModel):
synced: int
errors: int
message: str
class SupplierOption(BaseModel):
id: int
oracle_id: Optional[int] = None
name: str
fiscal_code: Optional[str]
source: str # 'synced' or 'local'
class CashRegisterOption(BaseModel):
id: int
oracle_id: int
name: str
account_code: str
register_type: str
# Endpoints
@router.get("/suppliers/search", response_model=SupplierSearchResult)
async def search_supplier(
fiscal_code: Optional[str] = None,
name: Optional[str] = None,
company_id: Optional[int] = None,
session: AsyncSession = Depends(get_session),
selected_company: SelectedCompany = None,
):
"""Search for supplier by fiscal code or name."""
if not fiscal_code and not name:
raise HTTPException(status_code=400, detail="Provide fiscal_code or name")
cid = company_id or selected_company
found, supplier, source = await SyncService.search_supplier(
session, cid, fiscal_code, name
)
return SupplierSearchResult(found=found, supplier=supplier, source=source)
@router.get("/suppliers", response_model=List[SupplierOption])
async def get_suppliers(
search: Optional[str] = None,
company_id: Optional[int] = None,
session: AsyncSession = Depends(get_session),
selected_company: SelectedCompany = None,
):
"""Get all suppliers (synced + local) for dropdown/autocomplete."""
cid = company_id or selected_company
suppliers = await SyncService.get_all_suppliers(session, cid, search)
return [
SupplierOption(
id=s["id"],
oracle_id=s.get("oracle_id"),
name=s["name"],
fiscal_code=s.get("fiscal_code"),
source=s["source"]
)
for s in suppliers
]
@router.post("/suppliers/local", response_model=LocalSupplierResponse)
async def create_local_supplier(
data: LocalSupplierCreate,
company_id: Optional[int] = None,
session: AsyncSession = Depends(get_session),
selected_company: SelectedCompany = None,
current_user: CurrentUser = Depends(get_current_user),
):
"""Create a local supplier from OCR data."""
cid = company_id or selected_company
supplier = await SyncService.create_local_supplier(
session, cid, data.name, data.fiscal_code, data.address, current_user.username
)
return LocalSupplierResponse(
id=supplier.id,
name=supplier.name,
fiscal_code=supplier.fiscal_code,
address=supplier.address,
)
@router.get("/cash-registers", response_model=List[CashRegisterOption])
async def get_cash_registers(
company_id: Optional[int] = None,
session: AsyncSession = Depends(get_session),
selected_company: SelectedCompany = None,
):
"""Get all cash registers for a company."""
cid = company_id or selected_company
registers = await SyncService.get_all_cash_registers(session, cid)
return [
CashRegisterOption(
id=r["id"],
oracle_id=r["oracle_id"],
name=r["name"],
account_code=r["account_code"],
register_type=r["register_type"]
)
for r in registers
]
@router.post("/sync/suppliers", response_model=SyncResult)
async def sync_suppliers(
request: Request,
company_id: Optional[int] = None,
session: AsyncSession = Depends(get_session),
selected_company: SelectedCompany = None,
):
"""Manually trigger supplier sync from Oracle."""
cid = company_id or selected_company
server_id = getattr(request.state, 'server_id', None)
synced, errors = await SyncService.sync_suppliers(session, cid, server_id=server_id)
return SyncResult(
synced=synced,
errors=errors,
message=f"Synced {synced} suppliers with {errors} errors"
)
@router.post("/sync/cash-registers", response_model=SyncResult)
async def sync_cash_registers(
request: Request,
company_id: Optional[int] = None,
session: AsyncSession = Depends(get_session),
selected_company: SelectedCompany = None,
):
"""Manually trigger cash register sync from Oracle."""
cid = company_id or selected_company
server_id = getattr(request.state, 'server_id', None)
synced, errors = await SyncService.sync_cash_registers(session, cid, server_id=server_id)
return SyncResult(
synced=synced,
errors=errors,
message=f"Synced {synced} cash registers with {errors} errors"
)
@router.post("/sync/all", response_model=dict)
async def sync_all_nomenclatures(
request: Request,
company_id: Optional[int] = None,
session: AsyncSession = Depends(get_session),
selected_company: SelectedCompany = None,
):
"""Sync all nomenclatures (suppliers + cash registers) from Oracle."""
cid = company_id or selected_company
server_id = getattr(request.state, 'server_id', None)
# Sync suppliers
suppliers_synced, suppliers_errors = await SyncService.sync_suppliers(session, cid, server_id=server_id)
# Sync cash registers
registers_synced, registers_errors = await SyncService.sync_cash_registers(session, cid, server_id=server_id)
return {
"suppliers": {
"synced": suppliers_synced,
"errors": suppliers_errors
},
"cash_registers": {
"synced": registers_synced,
"errors": registers_errors
},
"total_synced": suppliers_synced + registers_synced,
"total_errors": suppliers_errors + registers_errors,
"message": f"Synced {suppliers_synced} suppliers and {registers_synced} cash registers"
}

View File

@@ -0,0 +1,715 @@
"""
OCR API endpoints with async job queue support.
Endpoints:
- POST /extract - Submit OCR job (returns job_id immediately)
- GET /jobs/{job_id} - Get job status and result
- GET /queue/status - Get queue statistics
- GET /status - Check OCR service availability
For backwards compatibility, we also support sync mode via query param:
- POST /extract?sync=true - Process synchronously (blocks until complete)
"""
import os
import tempfile
from datetime import datetime
from decimal import Decimal
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, HTTPException, UploadFile, File, Depends, Query, Response
from sqlalchemy.ext.asyncio import AsyncSession
from backend.modules.data_entry.db.database import get_session
from backend.modules.data_entry.db.crud.attachment import AttachmentCRUD
from backend.modules.data_entry.services.ocr_service import ocr_service
from backend.modules.data_entry.services.ocr_engine import OCREngine
from backend.modules.data_entry.services.ocr.job_queue import job_queue, OCRJobStatus as JobStatus
from backend.modules.data_entry.services.ocr.job_worker import estimate_wait_time
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
from backend.modules.data_entry.schemas.ocr import (
OCRResponse,
OCRStatusResponse,
ExtractionData,
TvaEntry,
PaymentMethod,
# New job queue schemas
OCREngineChoice,
OCRJobStatus,
OCRJobSubmitResponse,
OCRJobResponse,
OCRQueueStatusResponse,
)
# Auth integration
from shared.auth.dependencies import get_current_user
from shared.auth.models import CurrentUser
router = APIRouter()
# ============================================================================
# OCR Job Queue Endpoints (NEW)
# ============================================================================
@router.post("/extract", response_model=OCRJobSubmitResponse)
async def submit_ocr_job(
file: UploadFile = File(...),
engine: OCREngineChoice = Query(default=OCREngineChoice.doctr_plus, description="OCR engine to use"),
sync: bool = Query(default=False, description="If true, process synchronously (blocks)"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Submit an OCR job for processing.
By default, returns immediately with a job_id. Poll GET /jobs/{job_id} for result.
Use ?sync=true for synchronous processing (blocks until complete).
This is for backwards compatibility but not recommended for production.
Args:
file: Image or PDF file (max 10MB)
engine: OCR engine choice (tesseract, doctr, doctr_plus, paddleocr)
sync: If true, process synchronously (legacy mode)
Returns:
OCRJobSubmitResponse with job_id, queue_position, estimated_wait
"""
allowed_types = ['image/jpeg', 'image/png', 'application/pdf']
if file.content_type not in allowed_types:
raise HTTPException(
status_code=400,
detail=f"File type not supported: {file.content_type}. Allowed: JPG, PNG, PDF"
)
# Read file content
content = await file.read()
# Check file size (10MB limit)
if len(content) > 10 * 1024 * 1024:
raise HTTPException(
status_code=400,
detail="File too large. Maximum size is 10MB."
)
# Sync mode - use legacy processing (blocks)
if sync:
return await _process_sync(content, file, engine, current_user)
# Async mode - create job and return immediately
try:
job = await job_queue.create_job(
file_bytes=content,
mime_type=file.content_type,
engine=engine.value,
username=current_user.username,
original_filename=file.filename
)
# Get queue position
queue_position = await job_queue.get_queue_position(job.id)
estimated_wait = estimate_wait_time(queue_position or 1)
return OCRJobSubmitResponse(
job_id=job.id,
status=OCRJobStatus.pending,
queue_position=queue_position or 1,
estimated_wait_seconds=estimated_wait,
created_at=job.created_at or datetime.utcnow()
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to create OCR job: {str(e)}"
)
@router.get("/jobs/{job_id}", response_model=OCRJobResponse)
async def get_job_status(
job_id: str,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Get OCR job status and result (instant response).
For efficient polling, use GET /jobs/{job_id}/wait instead (long-polling).
Args:
job_id: Job UUID from POST /extract response
Returns:
OCRJobResponse with status, queue_position, and result (if completed)
"""
job = await job_queue.get_job(job_id)
if not job:
raise HTTPException(status_code=404, detail="Job not found")
# Get queue position for pending jobs
queue_position = None
estimated_wait = None
if job.status == JobStatus.pending:
queue_position = await job_queue.get_queue_position(job_id)
estimated_wait = estimate_wait_time(queue_position or 1)
elif job.status == JobStatus.processing:
queue_position = 0
# Estimate remaining time based on average
avg_time = await job_queue.get_average_processing_time()
estimated_wait = int(avg_time * 0.5) # Rough estimate: half remaining
# Convert result to ExtractionData if available
result_data = None
if job.status == JobStatus.completed and job.result:
result_data = _dict_to_extraction_data(job.result)
# Apply fuzzy CUI matching
result_data = await _apply_fuzzy_cui_matching(result_data, session)
# Debug: log suggested_payment_mode being returned
print(f"[OCR Router] Returning job {job_id} with suggested_payment_mode={result_data.suggested_payment_mode}", flush=True)
return OCRJobResponse(
job_id=job.id,
status=OCRJobStatus(job.status.value),
queue_position=queue_position,
estimated_wait_seconds=estimated_wait,
created_at=job.created_at or datetime.utcnow(),
started_at=job.started_at,
completed_at=job.completed_at,
queue_wait_ms=job.queue_wait_ms,
ocr_time_ms=job.ocr_time_ms,
processing_time_ms=job.processing_time_ms,
result=result_data,
error=job.error_message
)
@router.get("/jobs/{job_id}/wait", response_model=OCRJobResponse)
async def wait_for_job_status(
job_id: str,
response: Response,
timeout: int = Query(default=30, ge=1, le=60, description="Max wait time in seconds"),
wait_for_terminal: bool = Query(default=False, description="If true, only return on completed/failed"),
_t: int = Query(default=None, description="Cache-busting timestamp (ignored)"),
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Long-poll for OCR job status change.
Waits until:
- Job status changes (default behavior - returns on any status change)
- Job reaches terminal state (if wait_for_terminal=true)
- Timeout expires (returns current status)
Recommended client timeout: timeout + 5 seconds
Args:
job_id: Job UUID from POST /extract response
timeout: Max wait time in seconds (1-60, default 30)
wait_for_terminal: If true, wait until completed/failed only
Returns:
OCRJobResponse with status, queue_position, and result (if completed)
"""
# Prevent caching - critical for long-polling
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0"
response.headers["Pragma"] = "no-cache"
response.headers["Expires"] = "0"
import asyncio
import time
start_time = time.time()
end_time = start_time + timeout
last_status = None
iteration = 0
print(f"[OCR Wait] Starting long-poll for job {job_id}, timeout={timeout}s, wait_for_terminal={wait_for_terminal}", flush=True)
while time.time() < end_time:
iteration += 1
job = await job_queue.get_job(job_id)
if not job:
print(f"[OCR Wait] Job {job_id} not found after {iteration} iterations", flush=True)
raise HTTPException(status_code=404, detail="Job not found")
# Return immediately if job completed or failed (terminal states)
if job.status in [JobStatus.completed, JobStatus.failed]:
elapsed = time.time() - start_time
print(f"[OCR Wait] Job {job_id} {job.status.value} after {elapsed:.1f}s ({iteration} iterations)", flush=True)
return await get_job_status(job_id, session, current_user)
# Return on status change (unless wait_for_terminal is set)
if not wait_for_terminal and last_status is not None and job.status != last_status:
elapsed = time.time() - start_time
print(f"[OCR Wait] Job {job_id} status changed {last_status.value}->{job.status.value} after {elapsed:.1f}s", flush=True)
return await get_job_status(job_id, session, current_user)
last_status = job.status
# Wait 500ms before next internal check (faster polling for better responsiveness)
await asyncio.sleep(0.5)
# Timeout - return current status
elapsed = time.time() - start_time
print(f"[OCR Wait] Job {job_id} timeout after {elapsed:.1f}s ({iteration} iterations), status={last_status.value if last_status else 'unknown'}", flush=True)
return await get_job_status(job_id, session, current_user)
@router.get("/queue/status", response_model=OCRQueueStatusResponse)
async def get_queue_status(
current_user: CurrentUser = Depends(get_current_user)
):
"""
Get OCR queue statistics.
Returns:
Queue status with pending/processing counts and average time
"""
stats = await job_queue.get_queue_stats()
return OCRQueueStatusResponse(
pending_jobs=stats["pending"],
processing_jobs=stats["processing"],
average_time_seconds=stats["average_time_seconds"]
)
# ============================================================================
# Legacy Endpoints (backwards compatibility)
# ============================================================================
@router.get("/status", response_model=OCRStatusResponse)
async def get_ocr_status():
"""Check OCR service status and available engines."""
engines = OCREngine.get_available_engines()
available = len(engines) > 0
if available:
message = f"OCR service ready with engines: {', '.join(engines)}"
else:
message = "No OCR engines available. Install PaddleOCR or Tesseract."
return OCRStatusResponse(
available=available,
engines=engines,
message=message
)
@router.get("/engines")
async def get_available_engines():
"""
Get list of enabled OCR engines based on .env configuration.
Returns engines availability and available processing modes.
Frontend should use this to filter engine selection dropdown.
Available engines: tesseract, doctr, doctr_plus, paddleocr
"""
# Check which engines are enabled via .env
paddle_enabled = os.getenv("OCR_ENABLE_PADDLEOCR", "true").lower() == "true"
tesseract_enabled = os.getenv("OCR_ENABLE_TESSERACT", "true").lower() == "true"
default_engine = os.getenv("OCR_DEFAULT_ENGINE", "doctr_plus")
# Build engines dict
engines = {
"tesseract": tesseract_enabled,
"doctr": True, # Always available (primary engine)
"doctr_plus": True, # Always available (recommended)
"paddleocr": paddle_enabled,
}
# Build available modes based on enabled engines
modes = []
if tesseract_enabled:
modes.append("tesseract")
modes.append("doctr")
modes.append("doctr_plus")
if paddle_enabled:
modes.append("paddleocr")
return {
"engines": engines,
"available_modes": modes,
"default_mode": default_engine,
"memory_estimate_mb": {
"tesseract": 50,
"doctr": 600,
"doctr_plus": 600,
"paddleocr": 800,
}
}
@router.post("/extract-attachment/{attachment_id}", response_model=OCRResponse)
async def extract_from_attachment(
attachment_id: int,
engine: OCREngineChoice = Query(default=OCREngineChoice.doctr_plus),
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Extract receipt data from an existing attachment.
Re-processes an already uploaded file with OCR.
This endpoint always processes synchronously.
"""
attachment = await AttachmentCRUD.get_by_id(session, attachment_id)
if not attachment:
raise HTTPException(status_code=404, detail="Attachment not found")
file_path = AttachmentCRUD.get_file_path(attachment)
if not file_path.exists():
raise HTTPException(status_code=404, detail="File not found on disk")
# Check if file type is supported
if attachment.mime_type not in ['image/jpeg', 'image/png', 'application/pdf']:
raise HTTPException(
status_code=400,
detail=f"File type not supported for OCR: {attachment.mime_type}"
)
# TODO: Could use job queue here too, but keeping sync for now
success, message, result = await ocr_service.process_image(
file_path, attachment.mime_type
)
if not success:
raise HTTPException(status_code=422, detail=message)
data = _result_to_extraction_data(result)
# Apply fuzzy CUI matching
data = await _apply_fuzzy_cui_matching(data, session)
return OCRResponse(success=True, message=message, data=data)
# ============================================================================
# Helper Functions
# ============================================================================
async def _apply_fuzzy_cui_matching(
extraction_data: ExtractionData,
session: AsyncSession
) -> ExtractionData:
"""
Apply fuzzy CUI matching to extraction data.
ONLY applies fuzzy matching if CUI is missing OR has invalid checksum.
If CUI has valid checksum, we trust the OCR and skip fuzzy matching.
Args:
extraction_data: ExtractionData with CUI to potentially correct
session: AsyncSession for database lookups
Returns:
ExtractionData with CUI corrected if a match was found
"""
from backend.modules.data_entry.services.ocr.validation import CUIChecksumRule
# Skip if no CUI and no vendor name (nothing to match)
if not extraction_data.cui and not extraction_data.partner_name:
return extraction_data
# Check if CUI has valid checksum - if valid, skip fuzzy matching
if extraction_data.cui:
cui_digits = CUIChecksumRule.extract_digits(extraction_data.cui)
if len(cui_digits) >= 6 and CUIChecksumRule.validate_checksum(cui_digits):
print(f"[Fuzzy Match] CUI {extraction_data.cui} has valid checksum, skipping fuzzy match", flush=True)
return extraction_data
# CUI missing or invalid checksum - try fuzzy matching
try:
match = await OCRValidationEngine.fuzzy_match_supplier(
cui=extraction_data.cui,
vendor_name=extraction_data.partner_name,
db_session=session
)
if match:
corrected_cui, supplier_name = match
if corrected_cui != extraction_data.cui:
print(f"[Fuzzy Match] Corrected: {extraction_data.cui} -> {corrected_cui} ({supplier_name})", flush=True)
extraction_data.cui = corrected_cui
# Also set partner_name if not already set
if not extraction_data.partner_name:
extraction_data.partner_name = supplier_name
except Exception as e:
print(f"[Fuzzy Match] Error: {e}", flush=True)
return extraction_data
async def _process_sync(
content: bytes,
file: UploadFile,
engine: OCREngineChoice,
current_user: CurrentUser
) -> OCRJobSubmitResponse:
"""
Process OCR synchronously (legacy mode).
Creates a job, processes it immediately, and returns the result
wrapped in a JobSubmitResponse for API consistency.
"""
# Get file extension
suffix = Path(file.filename).suffix.lower() if file.filename else '.jpg'
if suffix not in ['.jpg', '.jpeg', '.png', '.pdf']:
suffix = '.jpg'
# Save to temp file
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
tmp.write(content)
tmp_path = Path(tmp.name)
try:
success, message, result = await ocr_service.process_image(
tmp_path, file.content_type
)
if not success:
raise HTTPException(status_code=422, detail=message)
# Create a fake job response with the result embedded
# This maintains API compatibility
now = datetime.utcnow()
# For sync mode, we return a special response that includes
# the result directly. Clients should check if result is present.
return OCRJobSubmitResponse(
job_id="sync-" + str(hash(content))[:16],
status=OCRJobStatus.completed,
queue_position=0,
estimated_wait_seconds=0,
created_at=now
)
finally:
# Clean up temp file
if tmp_path.exists():
os.unlink(tmp_path)
def _result_to_extraction_data(result) -> ExtractionData:
"""Convert ExtractionResult to ExtractionData schema."""
# Convert tva_entries from dict to TvaEntry objects
tva_entries_schema = [
TvaEntry(code=e.get('code'), percent=e['percent'], amount=e['amount'])
for e in result.tva_entries
] if result.tva_entries else []
# Convert payment_methods from dict to PaymentMethod objects
payment_methods_list = [
PaymentMethod(method=pm['method'], amount=Decimal(str(pm['amount'])))
for pm in result.payment_methods
] if result.payment_methods else []
# Auto-suggest payment_mode based on detected methods
suggested_payment_mode = None
if payment_methods_list:
has_card = any(pm.method == 'CARD' for pm in payment_methods_list)
if has_card:
suggested_payment_mode = 'banca'
return ExtractionData(
receipt_type=result.receipt_type,
receipt_number=result.receipt_number,
receipt_series=result.receipt_series,
receipt_date=result.receipt_date,
amount=result.amount,
partner_name=result.partner_name,
cui=result.cui,
description=result.description,
tva_entries=tva_entries_schema,
tva_total=result.tva_total,
address=result.address,
items_count=result.items_count,
payment_methods=payment_methods_list,
suggested_payment_mode=suggested_payment_mode,
client_name=result.client_name,
client_cui=result.client_cui,
client_address=result.client_address,
confidence_amount=result.confidence_amount,
confidence_date=result.confidence_date,
confidence_vendor=result.confidence_vendor,
confidence_client=getattr(result, 'confidence_client', 0.0),
overall_confidence=result.overall_confidence,
raw_text=result.raw_text,
raw_texts=getattr(result, 'raw_texts', []),
ocr_engine=result.ocr_engine,
processing_time_ms=result.processing_time_ms,
needs_manual_review=result.needs_manual_review,
validation_warnings=result.validation_warnings,
validation_errors=result.validation_errors,
inter_ocr_ratios=result.inter_ocr_ratios,
)
def _dict_to_extraction_data(data: dict) -> ExtractionData:
"""Convert result dict (from job queue) to ExtractionData schema."""
from datetime import date
# Parse date if string
receipt_date = data.get('receipt_date')
if isinstance(receipt_date, str):
try:
receipt_date = date.fromisoformat(receipt_date)
except (ValueError, TypeError):
receipt_date = None
# Convert tva_entries
tva_entries = data.get('tva_entries', []) or []
tva_entries_schema = []
for e in tva_entries:
if isinstance(e, dict):
tva_entries_schema.append(TvaEntry(
code=e.get('code'),
percent=e.get('percent', 0),
amount=Decimal(str(e.get('amount', 0)))
))
# Convert payment_methods
payment_methods = data.get('payment_methods', []) or []
payment_methods_list = []
for pm in payment_methods:
if isinstance(pm, dict):
payment_methods_list.append(PaymentMethod(
method=pm.get('method', 'NUMERAR'),
amount=Decimal(str(pm.get('amount', 0)))
))
# Convert amount and tva_total to Decimal
amount = data.get('amount')
if amount is not None:
amount = Decimal(str(amount))
tva_total = data.get('tva_total')
if tva_total is not None:
tva_total = Decimal(str(tva_total))
return ExtractionData(
receipt_type=data.get('receipt_type', 'bon_fiscal'),
receipt_number=data.get('receipt_number'),
receipt_series=data.get('receipt_series'),
receipt_date=receipt_date,
amount=amount,
partner_name=data.get('partner_name'),
cui=data.get('cui'),
description=data.get('description'),
tva_entries=tva_entries_schema,
tva_total=tva_total,
address=data.get('address'),
items_count=data.get('items_count'),
payment_methods=payment_methods_list,
suggested_payment_mode=data.get('suggested_payment_mode'),
client_name=data.get('client_name'),
client_cui=data.get('client_cui'),
client_address=data.get('client_address'),
confidence_amount=data.get('confidence_amount', 0.0),
confidence_date=data.get('confidence_date', 0.0),
confidence_vendor=data.get('confidence_vendor', 0.0),
confidence_client=data.get('confidence_client', 0.0),
confidence_tva=data.get('confidence_tva', 0.0),
confidence_payment=data.get('confidence_payment', 0.0),
overall_confidence=data.get('overall_confidence', 0.0),
raw_text=data.get('raw_text', ''),
raw_texts=data.get('raw_texts', []),
ocr_engine=data.get('ocr_engine', ''),
processing_time_ms=data.get('processing_time_ms', 0),
needs_manual_review=data.get('needs_manual_review'),
validation_warnings=data.get('validation_warnings', []),
validation_errors=data.get('validation_errors', []),
inter_ocr_ratios=data.get('inter_ocr_ratios', {}),
)
# ============================================================================
# Store Profiles Management Endpoints
# ============================================================================
@router.post("/profiles/reload")
async def reload_store_profiles(
current_user: CurrentUser = Depends(get_current_user)
) -> dict:
"""
Hot-reload all store profiles.
Reloads profile Python modules without server restart.
Use after adding/modifying profile files.
Returns:
Dict with reloaded count and profile list
"""
from backend.modules.data_entry.services.ocr.profiles import ProfileRegistry
count = ProfileRegistry.reload_all()
status = ProfileRegistry.get_reload_status()
return {
"success": True,
"reloaded_modules": count,
"profiles_count": status["profiles_count"],
"registered_cuis": status["registered_cuis"],
"last_reload": status["last_reload"],
}
@router.get("/profiles")
async def list_store_profiles(
current_user: CurrentUser = Depends(get_current_user)
) -> dict:
"""
List all registered store profiles.
Returns:
Dict with profiles list and status
"""
from backend.modules.data_entry.services.ocr.profiles import ProfileRegistry
profiles = ProfileRegistry.list_profiles()
status = ProfileRegistry.get_reload_status()
return {
"profiles": profiles,
"count": len(profiles),
"last_reload": status["last_reload"],
}
@router.get("/profiles/{cui}")
async def get_store_profile(
cui: str,
current_user: CurrentUser = Depends(get_current_user)
) -> dict:
"""
Get details for a specific store profile.
Args:
cui: Store CUI (with or without RO prefix)
Returns:
Profile details including validation hints
Raises:
404: If no profile exists for this CUI
"""
from backend.modules.data_entry.services.ocr.profiles import ProfileRegistry
info = ProfileRegistry.get_profile_info(cui)
if not info:
raise HTTPException(
status_code=404,
detail=f"No profile registered for CUI: {cui}"
)
return info

View File

@@ -0,0 +1,268 @@
"""
OCR Settings and Metrics API endpoints.
Endpoints:
- GET /settings/ocr-preference - Get user's preferred OCR engine
- POST /settings/ocr-preference - Set user's preferred OCR engine
- GET /metrics/ocr/summary - Get OCR metrics summary by engine
- GET /metrics/ocr/history - Get user's OCR job history
- GET /metrics/ocr/stats - Get overall OCR statistics
"""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from backend.modules.data_entry.db.database import get_session
from backend.modules.data_entry.db.crud.ocr_settings import OCRPreferenceCRUD, OCRMetricsCRUD
from backend.modules.data_entry.db.models.ocr_settings import OCREngine, OCRMetricsSummary
# Auth integration
from shared.auth.dependencies import get_current_user
from shared.auth.models import CurrentUser
router = APIRouter()
# ============================================================================
# Schemas
# ============================================================================
class OCRPreferenceResponse(BaseModel):
"""Response for OCR preference endpoint."""
username: str
preferred_engine: str
available_engines: List[str] = Field(
default=["tesseract", "doctr", "doctr_plus", "paddleocr"],
description="Available OCR engines"
)
class OCRPreferenceRequest(BaseModel):
"""Request to set OCR preference."""
preferred_engine: str = Field(
default="doctr_plus",
description="Preferred OCR engine: tesseract, doctr, doctr_plus, paddleocr"
)
class OCRMetricsHistoryItem(BaseModel):
"""Single OCR job metrics item."""
job_id: str
engine_requested: str
engine_used: str
processing_time_ms: int
success: bool
overall_confidence: float
fields_extracted: int
created_at: str
original_filename: Optional[str] = None
class OCRMetricsHistoryResponse(BaseModel):
"""Response for OCR history endpoint."""
items: List[OCRMetricsHistoryItem]
total: int
class OCRStatsResponse(BaseModel):
"""Response for OCR stats endpoint."""
total_jobs: int
successful_jobs: int
failed_jobs: int
success_rate: float
avg_processing_time_ms: float
avg_confidence: float
period_days: int
class OCRActiveEnginesResponse(BaseModel):
"""Response for active OCR engines endpoint."""
engines: List[str] = Field(description="List of active OCR engines from .env config")
recommended: str = Field(default="doctr_plus", description="Recommended engine")
# ============================================================================
# OCR Engines Configuration Endpoint
# ============================================================================
@router.get("/settings/ocr-engines", response_model=OCRActiveEnginesResponse)
async def get_active_ocr_engines():
"""
Get list of active OCR engines configured in .env.
Returns the engines that should be shown in the frontend dropdown.
Configured via OCR_ACTIVE_ENGINES environment variable.
Default: doctr,doctr_plus
Available: tesseract, paddleocr, doctr, doctr_plus
"""
from backend.modules.data_entry.config import settings
return OCRActiveEnginesResponse(
engines=settings.ocr_active_engines_list,
recommended="doctr_plus"
)
# ============================================================================
# OCR Preference Endpoints
# ============================================================================
@router.get("/settings/ocr-preference", response_model=OCRPreferenceResponse)
async def get_ocr_preference(
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Get user's preferred OCR engine.
Returns the user's saved preference or 'doctr_plus' if not set.
Also returns list of available engines.
"""
from backend.modules.data_entry.services.ocr_engine import OCREngine as OCREngineClass
preference = await OCRPreferenceCRUD.get_by_username(session, current_user.username)
# Get available engines from OCR service
available = OCREngineClass.get_available_engines()
return OCRPreferenceResponse(
username=current_user.username,
preferred_engine=preference.preferred_engine.value if preference else "doctr_plus",
available_engines=available
)
@router.post("/settings/ocr-preference", response_model=OCRPreferenceResponse)
async def set_ocr_preference(
request: OCRPreferenceRequest,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Set user's preferred OCR engine.
Valid engines: tesseract, doctr, doctr_plus, paddleocr
Note: Available engines depend on .env configuration (OCR_ENABLE_PADDLEOCR, OCR_ENABLE_TESSERACT)
"""
from backend.modules.data_entry.services.ocr_engine import OCREngine as OCREngineClass
# Get dynamically available engines
available = OCREngineClass.get_available_engines()
if request.preferred_engine not in available:
raise HTTPException(
status_code=400,
detail=f"Invalid engine. Must be one of: {', '.join(available)}"
)
# Map string to enum
engine_map = {
"tesseract": OCREngine.TESSERACT,
"doctr": OCREngine.DOCTR,
"doctr_plus": OCREngine.DOCTR_PLUS,
"paddleocr": OCREngine.PADDLEOCR,
}
engine_enum = engine_map.get(request.preferred_engine, OCREngine.DOCTR_PLUS)
# Save preference
preference = await OCRPreferenceCRUD.create_or_update(
session,
current_user.username,
engine_enum
)
# Get available engines
available = OCREngineClass.get_available_engines()
return OCRPreferenceResponse(
username=current_user.username,
preferred_engine=preference.preferred_engine.value,
available_engines=available
)
# ============================================================================
# OCR Metrics Endpoints
# ============================================================================
@router.get("/metrics/ocr/summary", response_model=List[OCRMetricsSummary])
async def get_ocr_metrics_summary(
days: int = Query(default=30, ge=1, le=365, description="Number of days to include"),
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Get OCR metrics summary grouped by engine.
Returns aggregated metrics for each engine used in the specified period.
"""
summaries = await OCRMetricsCRUD.get_summary_by_engine(
session,
days=days,
username=current_user.username
)
return summaries
@router.get("/metrics/ocr/history", response_model=OCRMetricsHistoryResponse)
async def get_ocr_metrics_history(
limit: int = Query(default=50, ge=1, le=200, description="Max items to return"),
offset: int = Query(default=0, ge=0, description="Items to skip"),
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Get user's OCR job history.
Returns list of OCR jobs with their metrics, ordered by most recent first.
"""
items = await OCRMetricsCRUD.get_user_history(
session,
username=current_user.username,
limit=limit,
offset=offset
)
history_items = [
OCRMetricsHistoryItem(
job_id=item.job_id,
engine_requested=item.engine_requested,
engine_used=item.engine_used,
processing_time_ms=item.processing_time_ms,
success=item.success,
overall_confidence=item.overall_confidence,
fields_extracted=item.fields_extracted,
created_at=item.created_at.isoformat(),
original_filename=item.original_filename
)
for item in items
]
return OCRMetricsHistoryResponse(
items=history_items,
total=len(history_items)
)
@router.get("/metrics/ocr/stats", response_model=OCRStatsResponse)
async def get_ocr_stats(
days: int = Query(default=30, ge=1, le=365, description="Number of days to include"),
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Get overall OCR statistics for the user.
Returns aggregated stats including success rate, average processing time, etc.
"""
stats = await OCRMetricsCRUD.get_overall_stats(
session,
days=days,
username=current_user.username
)
return OCRStatsResponse(**stats)

View File

@@ -0,0 +1,705 @@
"""API endpoints for receipts."""
from typing import List, Optional, Annotated
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Query, Header, Response
from fastapi.responses import FileResponse, StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from backend.modules.data_entry.db.database import get_session
from backend.modules.data_entry.db.crud.receipt import ReceiptCRUD
from backend.modules.data_entry.db.crud.attachment import AttachmentCRUD
from backend.modules.data_entry.db.crud.accounting_entry import AccountingEntryCRUD
from backend.modules.data_entry.services.receipt_service import ReceiptService
from backend.modules.data_entry.services.nomenclature_service import NomenclatureService
from backend.modules.data_entry.schemas.receipt import (
ReceiptCreate,
ReceiptUpdate,
ReceiptResponse,
ReceiptListResponse,
ReceiptFilter,
ProcessingStats,
AttachmentResponse,
AccountingEntryResponse,
WorkflowAction,
RejectRequest,
EntriesUpdateRequest,
PartnerOption,
AccountOption,
CashRegisterOption,
ExpenseTypeOption,
BulkDeleteRequest,
BulkDeleteResponse,
BulkDeleteFailure,
)
from backend.modules.data_entry.db.models.receipt import ReceiptStatus, ReceiptDirection
from backend.modules.data_entry.services import sse_service
# Auth integration
from shared.auth.dependencies import get_current_user
from shared.auth.models import CurrentUser
router = APIRouter()
# ============ Helper for selected company from header ============
async def get_selected_company(
current_user: CurrentUser = Depends(get_current_user),
x_selected_company: Annotated[Optional[str], Header()] = None
) -> int:
"""
Get selected company from X-Selected-Company header.
Validates that the user has access to the specified company.
Falls back to user's first company if no header is provided.
Raises:
HTTPException 403: If user doesn't have access to specified company
HTTPException 400: If user has no companies assigned
"""
if x_selected_company:
try:
company_id = int(x_selected_company)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Invalid company ID format: {x_selected_company}"
)
# Validate user has access to this company
# Auth stores companies as strings
if str(company_id) in current_user.companies:
return company_id
raise HTTPException(
status_code=403,
detail=f"Nu aveți acces la firma {company_id}"
)
# No header - use first company from user's list
if current_user.companies:
try:
return int(current_user.companies[0])
except (ValueError, IndexError):
pass
raise HTTPException(
status_code=400,
detail="Nu aveți nicio firmă asignată"
)
# Dependency for injection
SelectedCompany = Annotated[int, Depends(get_selected_company)]
# Legacy function for backwards compatibility (deprecated)
def get_current_user_company(current_user: CurrentUser) -> int:
"""
DEPRECATED: Use get_selected_company() dependency instead.
This function returns the first company, ignoring X-Selected-Company header.
"""
if current_user.companies:
try:
return int(current_user.companies[0])
except (ValueError, IndexError):
return 1
return 1
# ============ SSE Endpoint for Real-time Status Updates ============
@router.get("/sse/status")
async def sse_status_stream(
batch_id: Optional[str] = Query(
default=None,
description="Optional batch_id to filter events for a specific batch"
),
):
"""
Server-Sent Events endpoint for real-time receipt status updates.
This endpoint provides a persistent connection that streams status change
events as they occur. Clients receive updates for CRUD operations on receipts
without needing to poll.
Query Parameters:
batch_id: Optional filter to only receive events for a specific batch upload.
Event Format:
data: {"receipt_id": 123, "status": "DRAFT", "processing_status": "completed", ...}
Headers:
- Content-Type: text/event-stream
- Cache-Control: no-cache
- Connection: keep-alive
Reconnection:
The retry: 3000 header hints clients to reconnect after 3 seconds if disconnected.
Example:
curl -N http://localhost:8000/api/data-entry/receipts/sse/status
curl -N http://localhost:8000/api/data-entry/receipts/sse/status?batch_id=abc-123
"""
return StreamingResponse(
sse_service.subscribe(batch_id=batch_id),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # Disable nginx buffering
},
)
# ============ Receipt CRUD Endpoints ============
@router.post("/", response_model=ReceiptResponse)
async def create_receipt(
data: ReceiptCreate,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
"""Create a new receipt in DRAFT status."""
receipt = await ReceiptService.create_receipt(session, data, current_user.username)
return ReceiptResponse.model_validate(receipt)
@router.get("/", response_model=ReceiptListResponse)
async def list_receipts(
response: Response,
status: Optional[ReceiptStatus] = None,
direction: Optional[ReceiptDirection] = None,
company_id: Optional[int] = None,
created_by: Optional[str] = None,
date_from: Optional[str] = None,
date_to: Optional[str] = None,
search: Optional[str] = None,
# Bulk upload filters (US-012)
processing_status: Optional[str] = Query(default=None, description="Filter by processing status: pending, processing, completed, failed"),
batch_id: Optional[str] = Query(default=None, description="Filter by batch_id UUID"),
sort_by: Optional[str] = Query(default=None, description="Sort field: processing_started_at, processing_started_at_asc"),
# Pagination
page: int = Query(default=1, ge=1),
page_size: int = Query(default=20, ge=1, le=100),
session: AsyncSession = Depends(get_session),
selected_company: SelectedCompany = None,
):
"""Get paginated list of receipts with filters.
US-012: Extended with batch_id, processing_status filters and processing_stats.
"""
# Disable browser caching to always get fresh data
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0"
response.headers["Pragma"] = "no-cache"
from datetime import date as date_type
filters = ReceiptFilter(
status=status,
direction=direction,
company_id=company_id or selected_company,
created_by=created_by,
date_from=date_type.fromisoformat(date_from) if date_from else None,
date_to=date_type.fromisoformat(date_to) if date_to else None,
search=search,
processing_status=processing_status,
batch_id=batch_id,
sort_by=sort_by,
page=page,
page_size=page_size,
)
return await ReceiptService.get_receipts(session, filters)
@router.get("/pending", response_model=List[ReceiptResponse])
async def list_pending_receipts(
response: Response,
company_id: Optional[int] = None,
session: AsyncSession = Depends(get_session),
selected_company: SelectedCompany = None,
):
"""Get all receipts pending review (for accountant view)."""
# Disable browser caching to always get fresh data
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0"
response.headers["Pragma"] = "no-cache"
receipts = await ReceiptCRUD.get_pending_review(
session, company_id or selected_company
)
return [ReceiptResponse.model_validate(r) for r in receipts]
@router.get("/stats")
async def get_receipt_stats(
response: Response,
company_id: Optional[int] = None,
my_receipts: bool = False,
session: AsyncSession = Depends(get_session),
selected_company: SelectedCompany = None,
current_user: CurrentUser = Depends(get_current_user),
):
"""Get receipt statistics."""
# Disable browser caching to always get fresh data
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0"
response.headers["Pragma"] = "no-cache"
return await ReceiptCRUD.get_stats(
session,
company_id or selected_company,
created_by=current_user.username if my_receipts else None,
)
@router.get("/{receipt_id}", response_model=ReceiptResponse)
async def get_receipt(
receipt_id: int,
response: Response,
session: AsyncSession = Depends(get_session),
):
"""Get receipt details with attachments and accounting entries."""
# Disable browser caching to always get fresh data
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0"
response.headers["Pragma"] = "no-cache"
receipt = await ReceiptService.get_receipt(session, receipt_id)
if not receipt:
raise HTTPException(status_code=404, detail="Receipt not found")
return ReceiptResponse.model_validate(receipt)
@router.put("/{receipt_id}", response_model=ReceiptResponse)
async def update_receipt(
receipt_id: int,
data: ReceiptUpdate,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
"""Update receipt (only DRAFT status, only by creator)."""
success, message, receipt = await ReceiptService.update_receipt(
session, receipt_id, data, current_user.username
)
if not success:
raise HTTPException(status_code=400, detail=message)
return ReceiptResponse.model_validate(receipt)
@router.delete("/bulk", response_model=BulkDeleteResponse)
async def bulk_delete_receipts(
data: BulkDeleteRequest,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
"""
Bulk delete receipts (US-024).
Deletes multiple receipts in a single request with partial success support.
Validation rules:
- Each receipt must be in DRAFT status
- Each receipt must be created by the current user
- Receipts with processing_status 'pending' or 'processing' cannot be deleted
Returns:
BulkDeleteResponse with deleted IDs and failed items with error messages
"""
deleted: List[int] = []
failed: List[BulkDeleteFailure] = []
for receipt_id in data.ids:
# Get receipt with relationships for deletion
receipt = await ReceiptCRUD.get_by_id(session, receipt_id, include_relations=True)
if not receipt:
failed.append(BulkDeleteFailure(id=receipt_id, error="Bonul nu a fost găsit"))
continue
# Check if receipt is being processed (bulk upload in progress)
if receipt.processing_status in ["pending", "processing"]:
failed.append(BulkDeleteFailure(
id=receipt_id,
error="Bonul este în curs de procesare și nu poate fi șters"
))
continue
# Check status - only DRAFT can be deleted
if receipt.status != ReceiptStatus.DRAFT:
failed.append(BulkDeleteFailure(
id=receipt_id,
error=f"Doar bonurile în status DRAFT pot fi șterse (status curent: {receipt.status.value})"
))
continue
# Check ownership
if receipt.created_by != current_user.username:
failed.append(BulkDeleteFailure(
id=receipt_id,
error="Doar creatorul bonului poate să-l șteargă"
))
continue
# All validations passed - delete the receipt
# Note: Cascade delete handles attachments and accounting entries
await ReceiptCRUD.delete(session, receipt)
deleted.append(receipt_id)
return BulkDeleteResponse(deleted=deleted, failed=failed)
@router.delete("/{receipt_id}")
async def delete_receipt(
receipt_id: int,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
"""Delete receipt (only DRAFT status, only by creator)."""
success, message = await ReceiptService.delete_receipt(
session, receipt_id, current_user.username
)
if not success:
raise HTTPException(status_code=400, detail=message)
return {"success": True, "message": message}
# ============ Workflow Endpoints ============
@router.post("/{receipt_id}/submit", response_model=WorkflowAction)
async def submit_receipt(
receipt_id: int,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
"""Submit receipt for review (DRAFT → PENDING_REVIEW)."""
success, message, receipt = await ReceiptService.submit_for_review(
session, receipt_id, current_user.username
)
# Broadcast SSE event on success (US-030)
if success and receipt:
await sse_service.broadcast_status_change(
receipt_id=receipt.id,
status=receipt.status.value,
processing_status=receipt.processing_status,
batch_id=receipt.batch_id,
)
return WorkflowAction(
success=success,
message=message,
receipt=ReceiptResponse.model_validate(receipt) if receipt else None,
)
@router.post("/{receipt_id}/approve", response_model=WorkflowAction)
async def approve_receipt(
receipt_id: int,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
"""Approve receipt (PENDING_REVIEW → APPROVED). Accountant action."""
success, message, receipt = await ReceiptService.approve_receipt(
session, receipt_id, current_user.username
)
# Broadcast SSE event on success (US-030)
if success and receipt:
await sse_service.broadcast_status_change(
receipt_id=receipt.id,
status=receipt.status.value,
processing_status=receipt.processing_status,
batch_id=receipt.batch_id,
)
return WorkflowAction(
success=success,
message=message,
receipt=ReceiptResponse.model_validate(receipt) if receipt else None,
)
@router.post("/{receipt_id}/reject", response_model=WorkflowAction)
async def reject_receipt(
receipt_id: int,
data: RejectRequest,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
"""Reject receipt (PENDING_REVIEW → REJECTED). Accountant action."""
success, message, receipt = await ReceiptService.reject_receipt(
session, receipt_id, current_user.username, data.reason
)
# Broadcast SSE event on success (US-030)
if success and receipt:
await sse_service.broadcast_status_change(
receipt_id=receipt.id,
status=receipt.status.value,
processing_status=receipt.processing_status,
batch_id=receipt.batch_id,
)
return WorkflowAction(
success=success,
message=message,
receipt=ReceiptResponse.model_validate(receipt) if receipt else None,
)
@router.post("/{receipt_id}/resubmit", response_model=WorkflowAction)
async def resubmit_receipt(
receipt_id: int,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
"""Resubmit rejected receipt after corrections (REJECTED → PENDING_REVIEW)."""
success, message, receipt = await ReceiptService.resubmit_receipt(
session, receipt_id, current_user.username
)
# Broadcast SSE event on success (US-030)
if success and receipt:
await sse_service.broadcast_status_change(
receipt_id=receipt.id,
status=receipt.status.value,
processing_status=receipt.processing_status,
batch_id=receipt.batch_id,
)
return WorkflowAction(
success=success,
message=message,
receipt=ReceiptResponse.model_validate(receipt) if receipt else None,
)
@router.post("/{receipt_id}/unapprove", response_model=WorkflowAction)
async def unapprove_receipt(
receipt_id: int,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
"""Unapprove receipt (APPROVED → PENDING_REVIEW). Returns to pending for corrections."""
success, message, receipt = await ReceiptService.unapprove_receipt(
session, receipt_id, current_user.username
)
# Broadcast SSE event on success (US-030)
if success and receipt:
await sse_service.broadcast_status_change(
receipt_id=receipt.id,
status=receipt.status.value,
processing_status=receipt.processing_status,
batch_id=receipt.batch_id,
)
return WorkflowAction(
success=success,
message=message,
receipt=ReceiptResponse.model_validate(receipt) if receipt else None,
)
# ============ Accounting Entries Endpoints ============
@router.get("/{receipt_id}/entries", response_model=List[AccountingEntryResponse])
async def get_receipt_entries(
receipt_id: int,
session: AsyncSession = Depends(get_session),
):
"""Get accounting entries for a receipt."""
entries = await AccountingEntryCRUD.get_by_receipt_id(session, receipt_id)
return [AccountingEntryResponse.model_validate(e) for e in entries]
@router.put("/{receipt_id}/entries", response_model=List[AccountingEntryResponse])
async def update_receipt_entries(
receipt_id: int,
data: EntriesUpdateRequest,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
"""Update accounting entries for a receipt (accountant action)."""
success, message, entries = await ReceiptService.update_entries(
session, receipt_id, data.entries, current_user.username
)
if not success:
raise HTTPException(status_code=400, detail=message)
return [AccountingEntryResponse.model_validate(e) for e in entries]
@router.post("/{receipt_id}/entries/regenerate", response_model=List[AccountingEntryResponse])
async def regenerate_entries(
receipt_id: int,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
"""Regenerate accounting entries based on receipt data."""
success, message, _ = await ReceiptService.regenerate_entries(
session, receipt_id, current_user.username
)
if not success:
raise HTTPException(status_code=400, detail=message)
entries = await AccountingEntryCRUD.get_by_receipt_id(session, receipt_id)
return [AccountingEntryResponse.model_validate(e) for e in entries]
# ============ Attachment Endpoints ============
@router.post("/{receipt_id}/attachments", response_model=AttachmentResponse)
async def upload_attachment(
receipt_id: int,
file: UploadFile = File(...),
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
"""Upload attachment for a receipt."""
# Check receipt exists and user can modify it
receipt = await ReceiptCRUD.get_by_id(session, receipt_id, include_relations=False)
if not receipt:
raise HTTPException(status_code=404, detail="Receipt not found")
# Only allow uploads for DRAFT and REJECTED receipts
if receipt.status not in [ReceiptStatus.DRAFT, ReceiptStatus.REJECTED]:
raise HTTPException(
status_code=400,
detail="Cannot upload attachments for this receipt status"
)
# Only creator can upload
if receipt.created_by != current_user.username:
raise HTTPException(
status_code=403,
detail="Only the creator can upload attachments"
)
try:
attachment = await AttachmentCRUD.create(session, receipt_id, file)
return AttachmentResponse.model_validate(attachment)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/{receipt_id}/attachments", response_model=List[AttachmentResponse])
async def list_attachments(
receipt_id: int,
session: AsyncSession = Depends(get_session),
):
"""Get all attachments for a receipt."""
attachments = await AttachmentCRUD.get_by_receipt_id(session, receipt_id)
return [AttachmentResponse.model_validate(a) for a in attachments]
@router.get("/attachments/{attachment_id}/download")
async def download_attachment(
attachment_id: int,
session: AsyncSession = Depends(get_session),
):
"""Download an attachment file."""
attachment = await AttachmentCRUD.get_by_id(session, attachment_id)
if not attachment:
raise HTTPException(status_code=404, detail="Attachment not found")
file_path = AttachmentCRUD.get_file_path(attachment)
if not file_path.exists():
raise HTTPException(status_code=404, detail="File not found on disk")
return FileResponse(
path=str(file_path),
filename=attachment.filename,
media_type=attachment.mime_type,
)
@router.delete("/attachments/{attachment_id}")
async def delete_attachment(
attachment_id: int,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
"""Delete an attachment."""
attachment = await AttachmentCRUD.get_by_id(session, attachment_id)
if not attachment:
raise HTTPException(status_code=404, detail="Attachment not found")
# Get receipt to check permissions
receipt = await ReceiptCRUD.get_by_id(session, attachment.receipt_id, include_relations=False)
if not receipt:
raise HTTPException(status_code=404, detail="Receipt not found")
# Only allow deletion for DRAFT receipts by creator
if receipt.status != ReceiptStatus.DRAFT:
raise HTTPException(
status_code=400,
detail="Cannot delete attachments for this receipt status"
)
if receipt.created_by != current_user.username:
raise HTTPException(
status_code=403,
detail="Only the creator can delete attachments"
)
await AttachmentCRUD.delete(session, attachment)
return {"success": True, "message": "Attachment deleted"}
# ============ Nomenclature Endpoints ============
@router.get("/nomenclature/partners", response_model=List[PartnerOption])
async def get_partners(
search: Optional[str] = None,
company_id: Optional[int] = None,
session: AsyncSession = Depends(get_session),
selected_company: SelectedCompany = None,
):
"""Get partners (suppliers/customers) for dropdown."""
return await NomenclatureService.get_partners(
company_id or selected_company, search, session
)
@router.get("/nomenclature/accounts", response_model=List[AccountOption])
async def get_accounts(
prefix: Optional[str] = None,
company_id: Optional[int] = None,
selected_company: SelectedCompany = None,
):
"""Get chart of accounts for dropdown."""
return await NomenclatureService.get_accounts(
company_id or selected_company, prefix
)
@router.get("/nomenclature/cash-registers", response_model=List[CashRegisterOption])
async def get_cash_registers(
company_id: Optional[int] = None,
session: AsyncSession = Depends(get_session),
selected_company: SelectedCompany = None,
):
"""Get cash registers and bank accounts for dropdown."""
return await NomenclatureService.get_cash_registers(company_id or selected_company, session)
@router.get("/nomenclature/expense-types", response_model=List[ExpenseTypeOption])
async def get_expense_types():
"""Get predefined expense types for dropdown."""
return await NomenclatureService.get_expense_types()

View File

@@ -0,0 +1,39 @@
# Pydantic schemas
from .receipt import (
ReceiptCreate,
ReceiptUpdate,
ReceiptResponse,
ReceiptListResponse,
ReceiptFilter,
AttachmentResponse,
AccountingEntryCreate,
AccountingEntryUpdate,
AccountingEntryResponse,
WorkflowAction,
RejectRequest,
)
from .bulk import (
BulkUploadResponse,
BatchJobInfo,
BatchStatusResponse,
BulkUploadError,
)
__all__ = [
"ReceiptCreate",
"ReceiptUpdate",
"ReceiptResponse",
"ReceiptListResponse",
"ReceiptFilter",
"AttachmentResponse",
"AccountingEntryCreate",
"AccountingEntryUpdate",
"AccountingEntryResponse",
"WorkflowAction",
"RejectRequest",
# Bulk upload schemas
"BulkUploadResponse",
"BatchJobInfo",
"BatchStatusResponse",
"BulkUploadError",
]

View File

@@ -0,0 +1,212 @@
"""Pydantic schemas for bulk upload endpoints."""
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel, Field
class BulkUploadResponse(BaseModel):
"""Response schema for bulk upload endpoint."""
batch_id: int = Field(..., description="Unique batch identifier for tracking")
job_ids: List[str] = Field(..., description="List of OCR job UUIDs created")
total_files: int = Field(..., description="Number of files in the batch")
message: str = Field(..., description="Status message")
class Config:
json_schema_extra = {
"example": {
"batch_id": 1,
"job_ids": [
"550e8400-e29b-41d4-a716-446655440001",
"550e8400-e29b-41d4-a716-446655440002",
],
"total_files": 2,
"message": "2 files queued for processing"
}
}
class BatchJobInfo(BaseModel):
"""Information about a single job in a batch."""
job_id: str = Field(..., description="OCR job UUID")
filename: str = Field(..., description="Original filename")
status: str = Field(..., description="Job status: pending, processing, completed, failed")
receipt_id: Optional[int] = Field(None, description="Created receipt ID (if completed)")
error_message: Optional[str] = Field(None, description="Error message (if failed)")
class BatchStatusResponse(BaseModel):
"""Response schema for batch status endpoint."""
batch_id: int = Field(..., description="Batch identifier")
status: str = Field(..., description="Overall batch status")
total_files: int = Field(..., description="Total number of files in batch")
pending_count: int = Field(..., description="Number of pending jobs")
processing_count: int = Field(..., description="Number of processing jobs")
completed_count: int = Field(..., description="Number of completed jobs")
failed_count: int = Field(..., description="Number of failed jobs")
jobs: List[BatchJobInfo] = Field(..., description="List of jobs with their status")
total_amount: Optional[float] = Field(None, description="Sum of all receipt amounts")
created_at: datetime = Field(..., description="Batch creation timestamp")
class Config:
json_schema_extra = {
"example": {
"batch_id": 1,
"status": "processing",
"total_files": 5,
"pending_count": 2,
"processing_count": 1,
"completed_count": 2,
"failed_count": 0,
"jobs": [
{"job_id": "abc-123", "filename": "bon1.pdf", "status": "completed", "receipt_id": 15},
{"job_id": "def-456", "filename": "bon2.jpg", "status": "processing", "receipt_id": None},
],
"total_amount": 150.50,
"created_at": "2025-01-09T10:30:00"
}
}
class DuplicateFileInfo(BaseModel):
"""Information about a duplicate file detected during upload."""
filename: str = Field(..., description="Name of the duplicate file")
error: str = Field(default="duplicate", description="Error type (always 'duplicate')")
existing_receipt_id: int = Field(..., description="ID of the existing receipt with same file hash")
message: str = Field(..., description="Human-readable error message")
class Config:
json_schema_extra = {
"example": {
"filename": "bon_lidl.pdf",
"error": "duplicate",
"existing_receipt_id": 123,
"message": "Fișier duplicat - există deja ca bon #123"
}
}
class BulkUploadResponseWithDuplicates(BaseModel):
"""Response schema for bulk upload with partial success (some duplicates)."""
batch_id: Optional[int] = Field(None, description="Batch ID (None if all files were duplicates)")
job_ids: List[str] = Field(default_factory=list, description="List of OCR job UUIDs created")
total_files: int = Field(..., description="Total number of files submitted")
processed_files: int = Field(..., description="Number of files successfully queued")
duplicate_files: int = Field(..., description="Number of duplicate files rejected")
duplicates: List[DuplicateFileInfo] = Field(default_factory=list, description="List of duplicate file details")
message: str = Field(..., description="Status message")
class Config:
json_schema_extra = {
"example": {
"batch_id": 1,
"job_ids": ["550e8400-e29b-41d4-a716-446655440001"],
"total_files": 3,
"processed_files": 1,
"duplicate_files": 2,
"duplicates": [
{
"filename": "bon_lidl.pdf",
"error": "duplicate",
"existing_receipt_id": 123,
"message": "Fișier duplicat - există deja ca bon #123"
}
],
"message": "1 fișier în procesare, 2 duplicate ignorate"
}
}
class BulkUploadError(BaseModel):
"""Error response for bulk upload validation failures."""
detail: str = Field(..., description="Error message")
invalid_files: Optional[List[str]] = Field(None, description="List of invalid filenames")
class RetryResponse(BaseModel):
"""Response schema for retry endpoints."""
success: bool = Field(..., description="Whether the retry was successful")
receipt_id: int = Field(..., description="Receipt ID that was retried")
job_id: Optional[str] = Field(None, description="New OCR job ID created")
message: str = Field(..., description="Status message")
class Config:
json_schema_extra = {
"example": {
"success": True,
"receipt_id": 123,
"job_id": "550e8400-e29b-41d4-a716-446655440001",
"message": "Bon reîncarcat în procesare"
}
}
class BatchRetryResponse(BaseModel):
"""Response schema for batch retry endpoint."""
success: bool = Field(..., description="Whether any retries were successful")
batch_id: str = Field(..., description="Batch ID that was retried")
retried_count: int = Field(..., description="Number of receipts successfully retried")
failed_count: int = Field(..., description="Number of receipts that couldn't be retried")
errors: List[str] = Field(default_factory=list, description="List of error messages")
message: str = Field(..., description="Status message")
class Config:
json_schema_extra = {
"example": {
"success": True,
"batch_id": "abc-123",
"retried_count": 3,
"failed_count": 0,
"errors": [],
"message": "3 bonuri reîncarcate în procesare"
}
}
class CancelJobResponse(BaseModel):
"""Response schema for cancel job endpoint."""
success: bool = Field(..., description="Whether the cancellation was successful")
job_id: str = Field(..., description="Job ID that was cancelled")
cancelled_at: datetime = Field(..., description="Timestamp when the job was cancelled")
message: str = Field(..., description="Status message")
class Config:
json_schema_extra = {
"example": {
"success": True,
"job_id": "550e8400-e29b-41d4-a716-446655440001",
"cancelled_at": "2025-01-11T15:30:00",
"message": "Job anulat cu succes"
}
}
class CancelBatchResponse(BaseModel):
"""Response schema for cancel batch endpoint."""
success: bool = Field(..., description="Whether any jobs were cancelled")
batch_id: int = Field(..., description="Batch ID that was cancelled")
cancelled_count: int = Field(..., description="Number of jobs successfully cancelled")
skipped_count: int = Field(..., description="Number of jobs skipped (completed/failed)")
message: str = Field(..., description="Status message")
class Config:
json_schema_extra = {
"example": {
"success": True,
"batch_id": 1,
"cancelled_count": 3,
"skipped_count": 2,
"message": "3 job-uri anulate, 2 ignorate (deja procesate)"
}
}

View File

@@ -0,0 +1,243 @@
"""Pydantic schemas for OCR API."""
from datetime import date
from decimal import Decimal
from typing import Optional, List
from pydantic import BaseModel, Field
class TvaEntry(BaseModel):
"""Single TVA entry with code, percentage and amount."""
code: Optional[str] = Field(default=None, description="TVA code: A, B, C, D")
percent: int = Field(description="TVA percentage: 0, 5, 9, 19, 21")
amount: Decimal = Field(description="TVA amount for this rate")
class PaymentMethod(BaseModel):
"""Payment method entry from OCR."""
method: str = Field(description="CARD or NUMERAR")
amount: Decimal = Field(description="Amount paid")
class ValidationWarning(BaseModel):
"""Validation warning from OCR extraction."""
field: str = Field(description="Field name (e.g., 'amount', 'tva_total')")
rule: str = Field(description="Rule name (e.g., 'amount_range', 'tva_ratio')")
message: str = Field(description="Human-readable warning message")
severity: str = Field(description="Severity: 'info', 'warning', 'error'")
suggested_value: Optional[str] = Field(default=None, description="Suggested corrected value")
class ExtractionData(BaseModel):
"""Extracted receipt data from OCR."""
receipt_type: str = Field(default='bon_fiscal', description="Receipt type: bon_fiscal or chitanta")
receipt_number: Optional[str] = Field(default=None, description="Receipt number")
receipt_series: Optional[str] = Field(default=None, description="Receipt series")
receipt_date: Optional[date] = Field(default=None, description="Receipt date")
amount: Optional[Decimal] = Field(default=None, description="Total amount")
partner_name: Optional[str] = Field(default=None, description="Vendor/partner name")
cui: Optional[str] = Field(default=None, description="CUI (fiscal identification code)")
description: Optional[str] = Field(default=None, description="Optional description")
# Additional extracted fields - Multiple TVA entries support
tva_entries: List[TvaEntry] = Field(default=[], description="List of TVA entries by rate (A, B, C, D)")
tva_total: Optional[Decimal] = Field(default=None, description="Total TVA amount")
address: Optional[str] = Field(default=None, description="Vendor address")
items_count: Optional[int] = Field(default=None, description="Number of items/articles")
# Payment methods extracted from receipt
payment_methods: List[PaymentMethod] = Field(default=[], description="Payment methods from receipt (CARD, NUMERAR)")
suggested_payment_mode: Optional[str] = Field(default=None, description="Auto-suggested payment mode based on OCR (casa/banca)")
# Client data (for B2B receipts - buyer information)
client_name: Optional[str] = Field(default=None, description="Client/customer company name")
client_cui: Optional[str] = Field(default=None, description="Client CUI/CIF fiscal code")
client_address: Optional[str] = Field(default=None, description="Client address")
confidence_amount: float = Field(default=0.0, ge=0, le=1, description="Amount extraction confidence")
confidence_date: float = Field(default=0.0, ge=0, le=1, description="Date extraction confidence")
confidence_vendor: float = Field(default=0.0, ge=0, le=1, description="Vendor extraction confidence")
confidence_client: float = Field(default=0.0, ge=0, le=1, description="Client extraction confidence")
confidence_tva: float = Field(default=0.0, ge=0, le=1, description="TVA extraction confidence")
confidence_payment: float = Field(default=0.0, ge=0, le=1, description="Payment extraction confidence")
overall_confidence: float = Field(default=0.0, ge=0, le=1, description="Overall confidence score")
raw_text: str = Field(default="", description="Raw OCR text (primary)")
raw_texts: List[str] = Field(default=[], description="Raw OCR texts from all engine passes (for analysis)")
ocr_engine: str = Field(default="", description="OCR engine used: paddleocr or tesseract")
processing_time_ms: int = Field(default=0, ge=0, description="Processing time in milliseconds")
# Validation results (added by bon-ocr-validation feature)
# needs_manual_review: None = not validated yet (old receipts), False = no review needed, True = needs review
needs_manual_review: Optional[bool] = Field(default=None, description="Flag for supervisor review (None=not validated, False=ok, True=needs review)")
validation_warnings: List[str] = Field(default=[], description="Validation warnings")
validation_errors: List[str] = Field(default=[], description="Validation errors")
inter_ocr_ratios: dict[str, float] = Field(default={}, description="Inter-OCR consistency ratios")
class Config:
"""Pydantic config."""
json_schema_extra = {
"example": {
"receipt_type": "bon_fiscal",
"receipt_number": "1360760",
"receipt_series": "0146",
"receipt_date": "2025-10-11",
"amount": 186.16,
"partner_name": "FIVE-HOLDING S.A.",
"cui": "10562600",
"description": None,
"tva_entries": [
{"code": "A", "percent": 19, "amount": 25.00},
{"code": "B", "percent": 9, "amount": 7.31}
],
"tva_total": 32.31,
"address": "JUD. CONSTANTA, MUN. CONSTANTA, STR. ION ROATA NR. 3",
"items_count": 17,
"confidence_amount": 0.98,
"confidence_date": 0.98,
"confidence_vendor": 0.95,
"overall_confidence": 0.97,
"raw_text": "FIVE-HOLDING S.A.\nCIF: RO10562600\n..."
}
}
class OCRResponse(BaseModel):
"""OCR API response."""
success: bool = Field(description="Whether OCR processing was successful")
message: str = Field(description="Status message")
data: Optional[ExtractionData] = Field(default=None, description="Extracted data")
class Config:
"""Pydantic config."""
json_schema_extra = {
"example": {
"success": True,
"message": "OCR processing successful. Found: amount, date, vendor",
"data": {
"receipt_type": "bon_fiscal",
"receipt_number": "12345",
"receipt_date": "2024-01-15",
"amount": 125.50,
"partner_name": "MEGA IMAGE SRL",
"cui": "12345678",
"confidence_amount": 0.95,
"confidence_date": 0.90,
"confidence_vendor": 0.75,
"overall_confidence": 0.87,
"raw_text": "BON FISCAL\nMEGA IMAGE SRL\n..."
}
}
}
class OCRStatusResponse(BaseModel):
"""OCR service status response."""
available: bool = Field(description="Whether OCR service is available")
engines: list[str] = Field(description="Available OCR engines")
message: str = Field(description="Status message")
# ============================================================================
# Job Queue Schemas (for async OCR processing)
# ============================================================================
from datetime import datetime
from enum import Enum
class OCREngineChoice(str, Enum):
"""OCR engine selection options."""
tesseract = "tesseract"
doctr = "doctr" # 3.3x faster than PaddleOCR with same accuracy (90/100)
doctr_plus = "doctr_plus" # docTR with 2-tier sequential processing + early exit (optimized, recommended)
paddleocr = "paddleocr"
class OCRJobStatus(str, Enum):
"""OCR job status."""
pending = "pending"
processing = "processing"
completed = "completed"
failed = "failed"
class OCRJobSubmitResponse(BaseModel):
"""Response when submitting an OCR job."""
job_id: str = Field(description="Unique job identifier (UUID)")
status: OCRJobStatus = Field(description="Initial job status (pending)")
queue_position: int = Field(description="Position in queue (1 = next to process)")
estimated_wait_seconds: int = Field(description="Estimated wait time in seconds")
created_at: datetime = Field(description="Job creation timestamp")
class Config:
"""Pydantic config."""
json_schema_extra = {
"example": {
"job_id": "abc123-def456-ghi789",
"status": "pending",
"queue_position": 3,
"estimated_wait_seconds": 21,
"created_at": "2024-01-15T12:00:00"
}
}
class OCRJobResponse(BaseModel):
"""Full OCR job status response."""
job_id: str = Field(description="Unique job identifier")
status: OCRJobStatus = Field(description="Current job status")
queue_position: Optional[int] = Field(default=None, description="Queue position (None if processing/completed)")
estimated_wait_seconds: Optional[int] = Field(default=None, description="Estimated wait time")
created_at: datetime = Field(description="Job creation timestamp")
started_at: Optional[datetime] = Field(default=None, description="Processing start timestamp")
completed_at: Optional[datetime] = Field(default=None, description="Completion timestamp")
# Detailed timing breakdown
queue_wait_ms: Optional[int] = Field(default=None, description="Time waiting in queue (started_at - created_at)")
ocr_time_ms: Optional[int] = Field(default=None, description="Actual OCR engine processing time")
processing_time_ms: Optional[int] = Field(default=None, description="Total job processing time (completed_at - started_at)")
result: Optional[ExtractionData] = Field(default=None, description="Extraction result (only if completed)")
error: Optional[str] = Field(default=None, description="Error message (only if failed)")
class Config:
"""Pydantic config."""
json_schema_extra = {
"example": {
"job_id": "abc123-def456-ghi789",
"status": "completed",
"queue_position": None,
"estimated_wait_seconds": 0,
"created_at": "2024-01-15T12:00:00",
"started_at": "2024-01-15T12:00:21",
"completed_at": "2024-01-15T12:00:28",
"processing_time_ms": 6543,
"result": {
"receipt_number": "123",
"amount": 85.99,
"ocr_engine": "paddleocr-light"
}
}
}
class OCRQueueStatusResponse(BaseModel):
"""Queue statistics response."""
pending_jobs: int = Field(description="Number of jobs waiting in queue")
processing_jobs: int = Field(description="Number of jobs currently processing")
average_time_seconds: float = Field(description="Average processing time in seconds")
class Config:
"""Pydantic config."""
json_schema_extra = {
"example": {
"pending_jobs": 5,
"processing_jobs": 1,
"average_time_seconds": 7.2
}
}

View File

@@ -0,0 +1,311 @@
"""Pydantic schemas for receipts API."""
import json
from datetime import datetime, date
from decimal import Decimal
from typing import Optional, List, Any, Union
from pydantic import BaseModel, Field, ConfigDict, field_validator
from backend.modules.data_entry.db.models.receipt import ReceiptType, ReceiptDirection, ReceiptStatus, ProcessingStatus
from backend.modules.data_entry.db.models.accounting_entry import EntryType
# ============ Accounting Entry Schemas ============
class AccountingEntryBase(BaseModel):
"""Base schema for accounting entry."""
entry_type: EntryType
account_code: str = Field(max_length=20)
account_name: Optional[str] = Field(default=None, max_length=200)
amount: Decimal
partner_id: Optional[int] = None
cost_center_id: Optional[int] = None
class AccountingEntryCreate(AccountingEntryBase):
"""Schema for creating an accounting entry."""
pass
class AccountingEntryUpdate(BaseModel):
"""Schema for updating an accounting entry."""
entry_type: Optional[EntryType] = None
account_code: Optional[str] = Field(default=None, max_length=20)
account_name: Optional[str] = Field(default=None, max_length=200)
amount: Optional[Decimal] = None
partner_id: Optional[int] = None
cost_center_id: Optional[int] = None
class AccountingEntryResponse(AccountingEntryBase):
"""Schema for accounting entry response."""
model_config = ConfigDict(from_attributes=True)
id: int
receipt_id: int
is_auto_generated: bool
modified_by: Optional[str] = None
modified_at: Optional[datetime] = None
sort_order: int
# ============ Attachment Schemas ============
class AttachmentResponse(BaseModel):
"""Schema for attachment response."""
model_config = ConfigDict(from_attributes=True)
id: int
receipt_id: int
filename: str
stored_filename: str
file_path: str
file_size: int
mime_type: str
uploaded_at: datetime
# ============ TVA Schema ============
class TvaEntrySchema(BaseModel):
"""Single TVA entry with code, percentage and amount."""
code: Optional[str] = Field(default=None, description="TVA code: A, B, C, D")
percent: int = Field(description="TVA percentage: 0, 5, 9, 19, 21")
amount: Decimal = Field(description="TVA amount for this rate")
class PaymentMethodSchema(BaseModel):
"""Payment method entry (CARD/NUMERAR)."""
method: str = Field(description="Payment method: CARD or NUMERAR")
amount: Decimal = Field(description="Amount paid with this method")
# ============ Receipt Schemas ============
class ReceiptBase(BaseModel):
"""Base schema for receipt."""
receipt_type: ReceiptType = ReceiptType.BON_FISCAL
direction: ReceiptDirection = ReceiptDirection.CHELTUIALA
receipt_number: Optional[str] = Field(default=None, max_length=50)
receipt_series: Optional[str] = Field(default=None, max_length=20)
receipt_date: date
amount: Decimal = Field(gt=0)
description: Optional[str] = Field(default=None, max_length=500)
# TVA info (multiple entries support)
tva_breakdown: Optional[List[TvaEntrySchema]] = Field(default=None, description="List of TVA entries")
tva_total: Optional[Decimal] = Field(default=None, description="Total TVA amount")
items_count: Optional[int] = Field(default=None, description="Number of items")
vendor_address: Optional[str] = Field(default=None, max_length=500, description="Vendor address")
# Other fields
expense_type_code: Optional[str] = Field(default=None, max_length=20)
company_id: int
# partner_id removed - supplier data is text-only (partner_name, cui)
partner_name: Optional[str] = Field(default=None, max_length=200)
cui: Optional[str] = Field(default=None, max_length=20, description="Fiscal code (CUI) from OCR")
ocr_raw_text: Optional[str] = Field(default=None, description="Raw OCR text for debugging")
payment_methods: Optional[List[PaymentMethodSchema]] = Field(default=None, description="Payment methods from OCR")
cash_register_id: Optional[int] = None
cash_register_name: Optional[str] = Field(default=None, max_length=100)
cash_register_account: Optional[str] = Field(default=None, max_length=20)
payment_mode: Optional[str] = Field(default=None, description="Payment mode: casa/banca/avans_decontare")
class ReceiptCreate(ReceiptBase):
"""Schema for creating a receipt."""
pass
class ReceiptUpdate(BaseModel):
"""Schema for updating a receipt (DRAFT only)."""
receipt_type: Optional[ReceiptType] = None
direction: Optional[ReceiptDirection] = None
receipt_number: Optional[str] = Field(default=None, max_length=50)
receipt_series: Optional[str] = Field(default=None, max_length=20)
receipt_date: Optional[date] = None
amount: Optional[Decimal] = Field(default=None, gt=0)
description: Optional[str] = Field(default=None, max_length=500)
# TVA info (multiple entries support)
tva_breakdown: Optional[List[TvaEntrySchema]] = Field(default=None, description="List of TVA entries")
tva_total: Optional[Decimal] = Field(default=None, description="Total TVA amount")
items_count: Optional[int] = Field(default=None, description="Number of items")
vendor_address: Optional[str] = Field(default=None, max_length=500, description="Vendor address")
# Other fields
expense_type_code: Optional[str] = Field(default=None, max_length=20)
# partner_id removed - supplier data is text-only (partner_name, cui)
partner_name: Optional[str] = Field(default=None, max_length=200)
cui: Optional[str] = Field(default=None, max_length=20, description="Fiscal code (CUI) from OCR")
ocr_raw_text: Optional[str] = Field(default=None, description="Raw OCR text for debugging")
payment_methods: Optional[List[PaymentMethodSchema]] = Field(default=None, description="Payment methods from OCR")
cash_register_id: Optional[int] = None
cash_register_name: Optional[str] = Field(default=None, max_length=100)
cash_register_account: Optional[str] = Field(default=None, max_length=20)
payment_mode: Optional[str] = Field(default=None, description="Payment mode: casa/banca/avans_decontare")
class ReceiptResponse(ReceiptBase):
"""Schema for receipt response with all fields."""
model_config = ConfigDict(from_attributes=True)
id: int
# Override amount to allow zero values in response (validation is on input, not output)
amount: Decimal
status: ReceiptStatus
created_by: str
created_at: datetime
updated_at: datetime
submitted_at: Optional[datetime] = None
reviewed_by: Optional[str] = None
reviewed_at: Optional[datetime] = None
rejection_reason: Optional[str] = None
oracle_synced_at: Optional[datetime] = None
oracle_act_id: Optional[int] = None
oracle_error: Optional[str] = None
# Bulk upload batch tracking (US-012)
batch_id: Optional[str] = None
processing_status: Optional[str] = None
processing_error: Optional[str] = None
file_hash: Optional[str] = None
processing_started_at: Optional[datetime] = None
processing_completed_at: Optional[datetime] = None
# Relationships (optional, loaded when needed)
attachments: List[AttachmentResponse] = []
entries: List[AccountingEntryResponse] = []
@field_validator('tva_breakdown', mode='before')
@classmethod
def parse_tva_breakdown(cls, v: Any) -> Optional[List[dict]]:
"""Deserialize tva_breakdown from JSON string if needed."""
if v is None:
return None
if isinstance(v, str):
try:
return json.loads(v)
except (json.JSONDecodeError, TypeError):
return None
if isinstance(v, list):
return v
return None
@field_validator('payment_methods', mode='before')
@classmethod
def parse_payment_methods(cls, v: Any) -> Optional[List[dict]]:
"""Deserialize payment_methods from JSON string if needed."""
if v is None:
return None
if isinstance(v, str):
try:
return json.loads(v)
except (json.JSONDecodeError, TypeError):
return None
if isinstance(v, list):
return v
return None
class ProcessingStats(BaseModel):
"""Statistics for bulk upload processing status (US-012)."""
pending_count: int = 0
processing_count: int = 0
completed_count: int = 0
failed_count: int = 0
class ReceiptListResponse(BaseModel):
"""Schema for paginated receipt list response."""
items: List[ReceiptResponse]
total: int
page: int
page_size: int
pages: int
# Processing stats for bulk upload filtering (US-012)
processing_stats: Optional[ProcessingStats] = None
class ReceiptFilter(BaseModel):
"""Schema for filtering receipts."""
status: Optional[ReceiptStatus] = None
direction: Optional[ReceiptDirection] = None
company_id: Optional[int] = None
created_by: Optional[str] = None
date_from: Optional[date] = None
date_to: Optional[date] = None
search: Optional[str] = None # Search in description, partner_name
# Bulk upload filters (US-012)
processing_status: Optional[str] = None # ProcessingStatus enum value
batch_id: Optional[str] = None # Filter by batch_id
sort_by: Optional[str] = None # Sort field (e.g., "processing_started_at")
# Pagination
page: int = Field(default=1, ge=1)
page_size: int = Field(default=20, ge=1, le=100)
# ============ Workflow Schemas ============
class WorkflowAction(BaseModel):
"""Schema for workflow action response."""
success: bool
message: str
receipt: Optional[ReceiptResponse] = None
class RejectRequest(BaseModel):
"""Schema for rejection request."""
reason: str = Field(min_length=5, max_length=500)
class EntriesUpdateRequest(BaseModel):
"""Schema for bulk updating accounting entries."""
entries: List[AccountingEntryCreate]
# ============ Nomenclature Schemas ============
class PartnerOption(BaseModel):
"""Schema for partner dropdown option (used for autocomplete assistance)."""
name: str
fiscal_code: Optional[str] = None
address: Optional[str] = None
source: str = "oracle" # 'oracle' (synced) or 'local'
class AccountOption(BaseModel):
"""Schema for account dropdown option."""
code: str
name: str
class CashRegisterOption(BaseModel):
"""Schema for cash register dropdown option."""
id: int
name: str
account_code: str # 5311, 5121, etc.
class ExpenseTypeOption(BaseModel):
"""Schema for expense type dropdown option."""
code: str
name: str
account_code: str
has_vat: bool
vat_percent: Decimal = Decimal("19")
# ============ Bulk Delete Schemas (US-024) ============
class BulkDeleteRequest(BaseModel):
"""Request schema for bulk delete endpoint."""
ids: List[int] = Field(..., min_length=1, description="List of receipt IDs to delete")
class BulkDeleteFailure(BaseModel):
"""Schema for a single failed deletion."""
id: int
error: str
class BulkDeleteResponse(BaseModel):
"""Response schema for bulk delete with partial success support."""
deleted: List[int] = Field(default_factory=list, description="IDs of successfully deleted receipts")
failed: List[BulkDeleteFailure] = Field(default_factory=list, description="IDs that failed with error messages")

View File

@@ -0,0 +1,16 @@
# Business logic services
from .receipt_service import ReceiptService
from .nomenclature_service import NomenclatureService
from .expense_types import EXPENSE_TYPES, ExpenseType
from .receipt_auto_create import ReceiptAutoCreateService, ReceiptCreateResult
from . import sse_service
__all__ = [
"ReceiptService",
"NomenclatureService",
"EXPENSE_TYPES",
"ExpenseType",
"ReceiptAutoCreateService",
"ReceiptCreateResult",
"sse_service",
]

View File

@@ -0,0 +1,215 @@
"""
Cleanup service for auto-deleting expired failed receipts.
US-008: Backend - Auto-Cleanup Erori După 7 Zile
- Finds receipts with processing_status='failed' and processing_completed_at < now() - 7 days
- Deletes the receipts and their attached files from storage
- Runs at startup and then daily as a background task
"""
import asyncio
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional
from sqlalchemy import select, and_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from backend.modules.data_entry.db.models.receipt import Receipt, ReceiptAttachment
from backend.modules.data_entry.config import settings
logger = logging.getLogger(__name__)
# Cleanup configuration
CLEANUP_RETENTION_DAYS = 7
CLEANUP_INTERVAL_HOURS = 24
# In-memory storage for last cleanup stats (optional - for login notification)
_last_cleanup_stats: dict = {
"count": 0,
"timestamp": None
}
def get_last_cleanup_stats() -> dict:
"""Get stats from the last cleanup run for notification purposes."""
return _last_cleanup_stats.copy()
async def cleanup_expired_failed_receipts(session: AsyncSession) -> int:
"""
Find and delete receipts with processing_status='failed' older than 7 days.
This function:
1. Queries for failed receipts where processing_completed_at < now() - 7 days
2. Deletes attachment files from disk
3. Deletes the receipt records (cascade deletes attachment records)
Args:
session: AsyncSession for database operations
Returns:
Number of receipts deleted
"""
global _last_cleanup_stats
cutoff_date = datetime.utcnow() - timedelta(days=CLEANUP_RETENTION_DAYS)
# Find expired failed receipts with their attachments
query = select(Receipt).options(
selectinload(Receipt.attachments)
).where(
and_(
Receipt.processing_status == "failed",
Receipt.processing_completed_at.isnot(None),
Receipt.processing_completed_at < cutoff_date
)
)
result = await session.execute(query)
expired_receipts = result.scalars().all()
if not expired_receipts:
logger.debug("[Cleanup] No expired failed receipts found")
return 0
deleted_count = 0
deleted_files = 0
upload_base_path = settings.upload_path_resolved
for receipt in expired_receipts:
try:
# Delete attachment files from disk
for attachment in receipt.attachments:
file_path = upload_base_path / attachment.file_path
if file_path.exists():
try:
file_path.unlink()
deleted_files += 1
logger.debug(f"[Cleanup] Deleted file: {file_path}")
except OSError as e:
logger.warning(f"[Cleanup] Failed to delete file {file_path}: {e}")
# Also try to clean up empty parent directories
parent_dir = file_path.parent
if parent_dir.exists() and parent_dir != upload_base_path:
try:
# Only remove if directory is empty
if not any(parent_dir.iterdir()):
parent_dir.rmdir()
logger.debug(f"[Cleanup] Removed empty directory: {parent_dir}")
except OSError:
pass # Directory not empty or permission issue, skip
# Delete receipt (cascade deletes attachment records in DB)
await session.delete(receipt)
deleted_count += 1
except Exception as e:
logger.error(f"[Cleanup] Error deleting receipt {receipt.id}: {e}")
continue
# Commit all deletions
if deleted_count > 0:
await session.commit()
# Update stats for notification
_last_cleanup_stats = {
"count": deleted_count,
"files_deleted": deleted_files,
"timestamp": datetime.utcnow().isoformat()
}
logger.info(f"[Cleanup] Cleaned up {deleted_count} expired failed receipts ({deleted_files} files)")
return deleted_count
async def run_cleanup_task(get_session_func) -> None:
"""
Background task that runs cleanup at startup and then every 24 hours.
Args:
get_session_func: Async generator function that yields database sessions
"""
logger.info("[Cleanup] Starting cleanup background task")
# Run immediately at startup
try:
async for session in get_session_func():
count = await cleanup_expired_failed_receipts(session)
if count > 0:
logger.info(f"[Cleanup] Initial cleanup: {count} receipts removed")
break
except Exception as e:
logger.error(f"[Cleanup] Initial cleanup failed: {e}")
# Then run every 24 hours
while True:
try:
await asyncio.sleep(CLEANUP_INTERVAL_HOURS * 3600)
async for session in get_session_func():
count = await cleanup_expired_failed_receipts(session)
if count > 0:
logger.info(f"[Cleanup] Daily cleanup: {count} receipts removed")
break
except asyncio.CancelledError:
logger.info("[Cleanup] Cleanup task cancelled")
raise
except Exception as e:
logger.error(f"[Cleanup] Daily cleanup failed: {e}")
# Continue running even if one cleanup fails
# Global reference to cleanup task for graceful shutdown
_cleanup_task: Optional[asyncio.Task] = None
async def start_cleanup_task(get_session_func) -> bool:
"""
Start the cleanup background task.
Args:
get_session_func: Async generator function that yields database sessions
Returns:
True if task started successfully, False otherwise
"""
global _cleanup_task
if _cleanup_task is not None and not _cleanup_task.done():
logger.warning("[Cleanup] Cleanup task already running")
return False
try:
_cleanup_task = asyncio.create_task(run_cleanup_task(get_session_func))
logger.info("[Cleanup] ✅ Cleanup background task started")
return True
except Exception as e:
logger.error(f"[Cleanup] Failed to start cleanup task: {e}")
return False
async def stop_cleanup_task() -> None:
"""Stop the cleanup background task gracefully."""
global _cleanup_task
if _cleanup_task is not None and not _cleanup_task.done():
_cleanup_task.cancel()
try:
await _cleanup_task
except asyncio.CancelledError:
pass
logger.info("[Cleanup] Cleanup task stopped")
_cleanup_task = None
def is_cleanup_task_running() -> bool:
"""Check if the cleanup task is currently running."""
return _cleanup_task is not None and not _cleanup_task.done()

View File

@@ -0,0 +1,101 @@
"""Predefined expense types for automatic accounting entry generation."""
from decimal import Decimal
from dataclasses import dataclass
from typing import Dict, Optional
@dataclass
class ExpenseType:
"""Expense type definition with accounting configuration."""
code: str
name: str
account_code: str
account_name: str
has_vat: bool
vat_percent: Decimal = Decimal("19")
vat_account: str = "4426"
# Predefined expense types
EXPENSE_TYPES: Dict[str, ExpenseType] = {
"FUEL": ExpenseType(
code="FUEL",
name="Combustibil",
account_code="6022",
account_name="Cheltuieli cu combustibilii",
has_vat=True,
),
"MATERIALS": ExpenseType(
code="MATERIALS",
name="Materiale consumabile",
account_code="6028",
account_name="Alte cheltuieli cu materiale consumabile",
has_vat=True,
),
"OFFICE": ExpenseType(
code="OFFICE",
name="Rechizite birou",
account_code="6024",
account_name="Cheltuieli privind materialele pentru ambalat",
has_vat=True,
),
"PHONE": ExpenseType(
code="PHONE",
name="Telefonie / Internet",
account_code="626",
account_name="Cheltuieli postale si taxe de telecomunicatii",
has_vat=True,
),
"PARKING": ExpenseType(
code="PARKING",
name="Parcare",
account_code="6022",
account_name="Cheltuieli cu combustibilii",
has_vat=True,
),
"FOOD": ExpenseType(
code="FOOD",
name="Alimentatie",
account_code="6028",
account_name="Alte cheltuieli cu materiale consumabile",
has_vat=False, # No deductible VAT for food
),
"TRANSPORT": ExpenseType(
code="TRANSPORT",
name="Transport",
account_code="624",
account_name="Cheltuieli cu transportul de bunuri si personal",
has_vat=True,
),
"OTHER": ExpenseType(
code="OTHER",
name="Altele",
account_code="628",
account_name="Alte cheltuieli cu serviciile executate de terti",
has_vat=True,
),
}
def get_expense_type(code: str) -> Optional[ExpenseType]:
"""Get expense type by code."""
return EXPENSE_TYPES.get(code)
def get_all_expense_types() -> Dict[str, ExpenseType]:
"""Get all expense types."""
return EXPENSE_TYPES.copy()
# Default cash register accounts
CASH_REGISTER_ACCOUNTS = {
"CASA": {
"code": "5311",
"name": "Casa in lei",
},
"BANCA": {
"code": "5121",
"name": "Conturi la banci in lei",
},
}

View File

@@ -0,0 +1,366 @@
"""Image preprocessing for optimal OCR results."""
from pathlib import Path
from typing import List
import numpy as np
import cv2
try:
import pdf2image
PDF_AVAILABLE = True
except ImportError:
PDF_AVAILABLE = False
class ImagePreprocessor:
"""Preprocess receipt images for OCR."""
def _add_safety_padding(self, image: np.ndarray, padding: int = 50) -> np.ndarray:
"""Add white padding around image to protect edge content during rotation.
This prevents left/right margin truncation in OCR by ensuring text near
edges isn't lost during deskew rotation.
"""
if len(image.shape) == 2:
# Grayscale
return cv2.copyMakeBorder(
image, padding, padding, padding, padding,
cv2.BORDER_CONSTANT, value=255
)
else:
# Color (BGR)
return cv2.copyMakeBorder(
image, padding, padding, padding, padding,
cv2.BORDER_CONSTANT, value=(255, 255, 255)
)
def load_image(self, path: Path) -> np.ndarray:
"""Load image from file."""
image = cv2.imread(str(path))
if image is None:
raise ValueError(f"Could not load image: {path}")
return image
def pdf_to_images(self, path: Path, dpi: int = 300) -> List[np.ndarray]:
"""
Convert PDF to images.
Args:
path: Path to PDF file
dpi: Resolution (300 = fast & good quality, 400 = better but slower)
"""
if not PDF_AVAILABLE:
raise RuntimeError("pdf2image not available. Install with: pip install pdf2image")
images = pdf2image.convert_from_path(str(path), dpi=dpi)
return [np.array(img) for img in images]
def preprocess(self, image: np.ndarray, high_quality: bool = True) -> np.ndarray:
"""
Apply LIGHT preprocessing - better for clear PDFs.
Heavy binarization can destroy text on clear images.
"""
return self.preprocess_light(image)
def preprocess_light(self, image: np.ndarray) -> np.ndarray:
"""
Light preprocessing for CLEAR images (PDFs, good scans).
Preserves original quality, only enhances contrast.
"""
# 0. Add safety padding to protect edge content during deskew rotation
image = self._add_safety_padding(image)
# 1. Grayscale
if len(image.shape) == 3:
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
else:
gray = image.copy()
# 2a. Scale DOWN if any side exceeds 4000px (PaddleOCR limit)
height, width = gray.shape
max_side = max(height, width)
if max_side > 4000:
scale = 4000 / max_side
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
height, width = gray.shape
# 2b. Scale UP if too small
if width < 1500:
scale = 1500 / width
# Ensure we don't exceed 4000px after upscaling
new_width = int(width * scale)
new_height = int(height * scale)
if max(new_width, new_height) > 4000:
scale = 4000 / max(new_width, new_height)
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
# 3. Deskew
gray = self._deskew(gray)
# 4. Light contrast enhancement only
clahe = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(8, 8))
enhanced = clahe.apply(gray)
# NO binarization, NO morphological ops - preserve original quality
return enhanced
def preprocess_medium(self, image: np.ndarray) -> np.ndarray:
"""
Medium preprocessing for MIXED-QUALITY images.
Balance between Light (too gentle) and Heavy (too aggressive).
Use cases:
- Moderately faded receipts
- Photos with uneven lighting
- Scans with slight blur
Preprocessing steps:
- Moderate contrast enhancement (CLAHE clipLimit=2.0)
- Light denoising (fastNlMeansDenoising h=6)
- Gentle sharpening
- NO binarization (preserves text boundaries)
- NO morphological operations (avoids digit concatenation)
This method was created to replace preprocess_heavy() which caused
digit concatenation errors on high-quality PDFs (85.99 → 859,762.16).
"""
# 0. Add safety padding to protect edge content during deskew rotation
image = self._add_safety_padding(image)
# 1. Grayscale
if len(image.shape) == 3:
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
else:
gray = image.copy()
# 2a. Scale DOWN if any side exceeds 4000px (PaddleOCR limit)
height, width = gray.shape
max_side = max(height, width)
if max_side > 4000:
scale = 4000 / max_side
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
height, width = gray.shape
# 2b. Scale UP if too small
if width < 1500:
scale = 1500 / width
# Ensure we don't exceed 4000px after upscaling
new_width = int(width * scale)
new_height = int(height * scale)
if max(new_width, new_height) > 4000:
scale = 4000 / max(new_width, new_height)
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
# 3. Deskew
gray = self._deskew(gray)
# 4. Moderate contrast enhancement (CLAHE clipLimit=2.0)
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
enhanced = clahe.apply(gray)
# 5. Light denoising (less aggressive than Heavy)
denoised = cv2.fastNlMeansDenoising(enhanced, h=6, templateWindowSize=7, searchWindowSize=15)
# 6. Gentle sharpening
gaussian = cv2.GaussianBlur(denoised, (0, 0), 1.0)
sharpened = cv2.addWeighted(denoised, 1.3, gaussian, -0.3, 0)
# NO binarization, NO morphological operations
# This preserves text boundaries and avoids digit concatenation
return sharpened
def preprocess_heavy(self, image: np.ndarray) -> np.ndarray:
"""
Heavy preprocessing for FADED thermal receipts.
Aggressive binarization to recover faded text.
⚠️ DEPRECATED: Use preprocess_medium() instead.
Heavy preprocessing causes digit concatenation on clear PDFs
(e.g., 85.99 → 859,762.16 due to binarization + morphological operations).
Kept for backward compatibility only.
"""
# 0. Add safety padding to protect edge content during deskew rotation
image = self._add_safety_padding(image)
# 1. Grayscale
if len(image.shape) == 3:
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
else:
gray = image.copy()
# 2a. Scale DOWN if any side exceeds 4000px (PaddleOCR limit)
height, width = gray.shape
max_side = max(height, width)
if max_side > 4000:
scale = 4000 / max_side
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
height, width = gray.shape
# 2b. Scale UP if too small (larger = better OCR)
if width < 1500:
scale = 1500 / width
# Ensure we don't exceed 4000px after upscaling
new_width = int(width * scale)
new_height = int(height * scale)
if max(new_width, new_height) > 4000:
scale = 4000 / max(new_width, new_height)
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
# 3. Deskew
gray = self._deskew(gray)
# 4. Contrast enhancement with CLAHE
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
enhanced = clahe.apply(gray)
# 5. Denoise
denoised = cv2.fastNlMeansDenoising(enhanced, h=8, templateWindowSize=7, searchWindowSize=21)
# 6. Sharpening
gaussian = cv2.GaussianBlur(denoised, (0, 0), 2.0)
sharpened = cv2.addWeighted(denoised, 1.5, gaussian, -0.5, 0)
# 7. Adaptive thresholding (binarization)
binary = cv2.adaptiveThreshold(
sharpened, 255,
cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
cv2.THRESH_BINARY,
blockSize=11, C=5
)
# 8. Morphological operations
kernel_close = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
result = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel_close)
return result
def preprocess_for_tesseract(self, image: np.ndarray, binarize: bool = False,
padding: int = 0, clahe_clip: float = 1.5) -> np.ndarray:
"""
Tesseract-optimized preprocessing (based on comprehensive benchmark).
BENCHMARK FINDINGS:
- DPI 200 is optimal (not 300!)
- Padding 40px fixes left margin truncation issues
- CLAHE 1.5 for most receipts, 2.0 for difficult ones
- NO deskew, NO denoising for clear PDFs
Recommended usage:
- Simple receipts: padding=0, clahe_clip=1.5
- Complex receipts: padding=40, clahe_clip=1.5
- Difficult/faded: padding=40, clahe_clip=2.0, binarize=True
Args:
image: Input image (RGB from pdf2image or BGR from OpenCV)
binarize: Apply Otsu binarization (for faded receipts)
padding: White padding in pixels (40px recommended for edge protection)
clahe_clip: CLAHE clip limit (1.5 normal, 2.0 for difficult)
Returns:
Preprocessed grayscale image
"""
# 1. Grayscale (handle both RGB and BGR)
if len(image.shape) == 3:
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
else:
gray = image.copy()
# 2. Add padding if specified (protects against left margin truncation)
if padding > 0:
gray = cv2.copyMakeBorder(
gray, padding, padding, padding, padding,
cv2.BORDER_CONSTANT, value=255
)
# 3. CLAHE contrast enhancement
clahe = cv2.createCLAHE(clipLimit=clahe_clip, tileGridSize=(8, 8))
enhanced = clahe.apply(gray)
# NO deskew, NO denoising - these DEGRADE quality on clear PDFs!
if not binarize:
return enhanced
# Binarization only for faded receipts
_, binary = cv2.threshold(enhanced, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
# Ensure correct polarity
if np.mean(binary) < 127:
binary = 255 - binary
return binary
def preprocess_for_tesseract_padded(self, image: np.ndarray) -> np.ndarray:
"""
Tesseract preprocessing with optimal padding (40px).
Best for complex receipts where left margin gets truncated.
"""
return self.preprocess_for_tesseract(image, padding=40)
def preprocess_for_tesseract_faded(self, image: np.ndarray) -> np.ndarray:
"""
Tesseract preprocessing for FADED thermal receipts.
Uses binarization to recover faded text.
"""
return self.preprocess_for_tesseract(image, binarize=True)
def get_all_variants(self, image: np.ndarray) -> List[np.ndarray]:
"""
Generate 2 preprocessing variants for OCR (fast mode).
Returns: [light_processed, heavy_processed]
"""
return [
self.preprocess_light(image),
self.preprocess_heavy(image),
]
def _deskew(self, image: np.ndarray) -> np.ndarray:
"""Correct image rotation/skew using Hough lines.
Uses expanded canvas to preserve all content during rotation,
preventing left/right margin truncation.
"""
edges = cv2.Canny(image, 50, 150, apertureSize=3)
lines = cv2.HoughLinesP(
edges, 1, np.pi / 180,
threshold=100, minLineLength=100, maxLineGap=10
)
if lines is None:
return image
angles = []
for line in lines:
x1, y1, x2, y2 = line[0]
angle = np.arctan2(y2 - y1, x2 - x1) * 180 / np.pi
if abs(angle) < 45:
angles.append(angle)
if not angles:
return image
median_angle = np.median(angles)
if abs(median_angle) < 0.5:
return image
h, w = image.shape[:2]
center = (w // 2, h // 2)
M = cv2.getRotationMatrix2D(center, median_angle, 1.0)
# Calculate new canvas size to fit entire rotated image (prevents edge truncation)
cos_angle = abs(np.cos(np.radians(median_angle)))
sin_angle = abs(np.sin(np.radians(median_angle)))
new_w = int(h * sin_angle + w * cos_angle)
new_h = int(h * cos_angle + w * sin_angle)
# Adjust rotation matrix for new canvas center
M[0, 2] += (new_w - w) / 2
M[1, 2] += (new_h - h) / 2
return cv2.warpAffine(
image, M, (new_w, new_h),
flags=cv2.INTER_CUBIC,
borderMode=cv2.BORDER_CONSTANT,
borderValue=255 # White background (grayscale)
)

View File

@@ -0,0 +1,216 @@
"""Service for fetching nomenclatures from Oracle (read-only)."""
from typing import List, Optional
from decimal import Decimal
from sqlmodel import select
from sqlalchemy.ext.asyncio import AsyncSession
from backend.modules.data_entry.schemas.receipt import (
PartnerOption,
AccountOption,
CashRegisterOption,
ExpenseTypeOption,
)
from backend.modules.data_entry.services.expense_types import EXPENSE_TYPES
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier, SyncedCashRegister
class NomenclatureService:
"""
Service for fetching nomenclatures.
In Phase 1 (MVP), some nomenclatures are hardcoded.
In Phase 2, these will be fetched from Oracle.
"""
@staticmethod
async def get_partners(
company_id: int,
search: Optional[str] = None,
session: Optional[AsyncSession] = None
) -> List[PartnerOption]:
"""
Get partners (suppliers/customers) for a company.
Returns synced suppliers from Oracle + local suppliers created from OCR.
If no suppliers exist, returns empty list (frontend will trigger sync).
"""
partners = []
if not session:
return partners
# Get synced suppliers from Oracle
stmt = select(SyncedSupplier).where(SyncedSupplier.company_id == company_id)
if search:
stmt = stmt.where(
(SyncedSupplier.name.ilike(f"%{search}%")) |
(SyncedSupplier.fiscal_code.ilike(f"%{search}%"))
)
stmt = stmt.order_by(SyncedSupplier.name)
result = await session.execute(stmt)
suppliers = result.scalars().all()
for s in suppliers:
partners.append(PartnerOption(
name=s.name,
fiscal_code=s.fiscal_code,
address=s.address,
source="oracle"
))
# Always get local suppliers (not just when synced exist)
local_stmt = select(LocalSupplier).where(LocalSupplier.company_id == company_id)
if search:
local_stmt = local_stmt.where(
(LocalSupplier.name.ilike(f"%{search}%")) |
(LocalSupplier.fiscal_code.ilike(f"%{search}%"))
)
local_stmt = local_stmt.order_by(LocalSupplier.name)
local_result = await session.execute(local_stmt)
local_suppliers = local_result.scalars().all()
for l in local_suppliers:
partners.append(PartnerOption(
name=l.name,
fiscal_code=l.fiscal_code,
address=l.address,
source="local"
))
return partners
@staticmethod
async def get_accounts(company_id: int, prefix: Optional[str] = None) -> List[AccountOption]:
"""
Get chart of accounts for a company.
Phase 1: Returns common expense/income accounts.
Phase 2: Will fetch from Oracle PLAN_CONTURI.
"""
# Common accounts for expenses and receipts
accounts = [
# Expense accounts (Class 6)
AccountOption(code="6022", name="Cheltuieli cu combustibilii"),
AccountOption(code="6024", name="Cheltuieli materiale pentru ambalat"),
AccountOption(code="6028", name="Alte cheltuieli cu materiale consumabile"),
AccountOption(code="624", name="Cheltuieli cu transportul de bunuri si personal"),
AccountOption(code="626", name="Cheltuieli postale si taxe telecomunicatii"),
AccountOption(code="628", name="Alte cheltuieli cu serviciile executate de terti"),
# VAT
AccountOption(code="4426", name="TVA deductibila"),
AccountOption(code="4427", name="TVA colectata"),
# Cash and Bank (Class 5)
AccountOption(code="5311", name="Casa in lei"),
AccountOption(code="5121", name="Conturi la banci in lei"),
# Income accounts (Class 7)
AccountOption(code="7588", name="Alte venituri din exploatare"),
]
if prefix:
accounts = [a for a in accounts if a.code.startswith(prefix)]
return accounts
@staticmethod
async def get_cash_registers(
company_id: int,
session: Optional[AsyncSession] = None
) -> List[CashRegisterOption]:
"""
Get cash registers and bank accounts for a company.
Phase 1: Returns default options.
Phase 2: Returns synced data from SQLite (from Oracle sync).
Phase 3: Will fetch live from Oracle NOM_CASE / NOM_BANCI.
"""
# If session is provided, try to get from synced SQLite data
if session:
stmt = select(SyncedCashRegister).where(SyncedCashRegister.company_id == company_id)
result = await session.execute(stmt)
registers = result.scalars().all()
if registers:
return [
CashRegisterOption(id=r.id, name=r.name, account_code=r.account_code)
for r in registers
]
# Fallback to default cash registers for Phase 1
return [
CashRegisterOption(id=1, name="Casa principala", account_code="5311"),
CashRegisterOption(id=2, name="Cont BCR", account_code="5121"),
CashRegisterOption(id=3, name="Cont BRD", account_code="5121"),
]
@staticmethod
async def get_expense_types() -> List[ExpenseTypeOption]:
"""
Get predefined expense types with their accounting configuration.
"""
return [
ExpenseTypeOption(
code=et.code,
name=et.name,
account_code=et.account_code,
has_vat=et.has_vat,
vat_percent=et.vat_percent,
)
for et in EXPENSE_TYPES.values()
]
@staticmethod
async def get_companies(username: str) -> List[dict]:
"""
Get companies accessible by user.
Phase 1: Returns mock data.
Phase 2: Will fetch from shared auth based on user permissions.
"""
# TODO: Integrate with shared auth to get user's companies
return [
{"id": 1, "name": "SC Test SRL", "cui": "RO12345678"},
{"id": 2, "name": "SC Demo SA", "cui": "RO87654321"},
]
# ============ Phase 2 Oracle Integration Methods ============
@staticmethod
async def _fetch_partners_oracle(company_id: int, search: Optional[str] = None) -> List[PartnerOption]:
"""
Fetch partners from Oracle NOM_PARTENERI.
Will be implemented in Phase 2.
"""
# TODO: Implement using shared oracle_pool
# Example query:
# SELECT ID_PART, DEN_PART, COD_FISCAL
# FROM {schema}.NOM_PARTENERI
# WHERE DEN_PART LIKE :search
raise NotImplementedError("Oracle integration pending - Phase 2")
@staticmethod
async def _fetch_accounts_oracle(company_id: int, prefix: Optional[str] = None) -> List[AccountOption]:
"""
Fetch chart of accounts from Oracle PLAN_CONTURI.
Will be implemented in Phase 2.
"""
# TODO: Implement using shared oracle_pool
raise NotImplementedError("Oracle integration pending - Phase 2")
@staticmethod
async def _fetch_cash_registers_oracle(company_id: int) -> List[CashRegisterOption]:
"""
Fetch cash registers from Oracle NOM_CASE / NOM_BANCI.
Will be implemented in Phase 2.
"""
# TODO: Implement using shared oracle_pool
raise NotImplementedError("Oracle integration pending - Phase 2")

View File

@@ -0,0 +1,42 @@
"""
OCR Services Module
Provides persistent OCR worker pool with job queue for efficient processing.
Components:
- ocr_worker_pool: Manages ProcessPoolExecutor with persistent PaddleOCR
- job_queue: SQLite-based job queue for async processing
- job_worker: Background task that processes queued jobs
- tesseract_engine: Optimized Tesseract with multi-PSM and polarity fix
Architecture:
FastAPI → job_queue.create_job() → SQLite
job_worker loop → ocr_worker_pool.submit_task() → Worker Process
PaddleOCR/Tesseract
"""
from .ocr_worker_pool import ocr_worker_pool, OCRWorkerPool
from .job_queue import job_queue, OCRJobQueue, OCRJob, OCRJobStatus
from .job_worker import start_job_worker, stop_job_worker
from .tesseract_engine import TesseractEngine
from .validation import OCRValidationEngine
__all__ = [
# Worker pool
"ocr_worker_pool",
"OCRWorkerPool",
# Job queue
"job_queue",
"OCRJobQueue",
"OCRJob",
"OCRJobStatus",
# Job worker
"start_job_worker",
"stop_job_worker",
# Engines
"TesseractEngine",
# Validation
"OCRValidationEngine",
]

View File

@@ -0,0 +1,653 @@
"""
SQLite Job Queue Manager for OCR Processing
Provides async job queue for OCR requests:
- Jobs are stored in SQLite for persistence
- Queue position and time estimation
- Automatic expiration after 24 hours
- Statistics for monitoring
Schema:
ocr_jobs (
id TEXT PRIMARY KEY, -- UUID
status TEXT NOT NULL, -- pending, processing, completed, failed
file_path TEXT NOT NULL, -- Path to uploaded file
mime_type TEXT NOT NULL,
engine TEXT DEFAULT 'doctr_plus',
created_at TIMESTAMP,
started_at TIMESTAMP,
completed_at TIMESTAMP,
result_json TEXT, -- JSON extraction result
error_message TEXT,
processing_time_ms INTEGER, -- Total job time (started_at to completed_at)
ocr_time_ms INTEGER, -- Actual OCR engine processing time
created_by TEXT, -- Username
original_filename TEXT,
expires_at TIMESTAMP,
batch_id INTEGER, -- Foreign key to batch_uploads (for bulk processing)
file_hash TEXT -- SHA-256 hash for duplicate detection (US-007)
)
"""
import asyncio
import json
from decimal import Decimal
class DecimalEncoder(json.JSONEncoder):
"""JSON encoder that handles Decimal types."""
def default(self, obj):
if isinstance(obj, Decimal):
return float(obj)
return super().default(obj)
import logging
import os
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional
import aiosqlite
logger = logging.getLogger(__name__)
# Default paths
DEFAULT_QUEUE_DIR = Path(__file__).parent.parent.parent.parent.parent / "data" / "ocr_queue"
DEFAULT_DB_PATH = DEFAULT_QUEUE_DIR / "ocr_jobs.db"
DEFAULT_FILES_DIR = DEFAULT_QUEUE_DIR / "files"
# Job expiration
JOB_EXPIRY_HOURS = 24
# SQLite busy timeout (milliseconds) - prevents "database is locked" errors
SQLITE_BUSY_TIMEOUT_MS = 5000
class OCRJobStatus(str, Enum):
"""Job status enum."""
pending = "pending"
processing = "processing"
completed = "completed"
failed = "failed"
cancelled = "cancelled"
@dataclass
class OCRJob:
"""OCR Job data class."""
id: str
status: OCRJobStatus
file_path: str
mime_type: str
engine: str = "doctr_plus"
created_at: Optional[datetime] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
result_json: Optional[str] = None
error_message: Optional[str] = None
processing_time_ms: Optional[int] = None # Total job time (started_at to completed_at)
ocr_time_ms: Optional[int] = None # Actual OCR engine processing time
created_by: Optional[str] = None
original_filename: Optional[str] = None
expires_at: Optional[datetime] = None
batch_id: Optional[int] = None # Links to batch_uploads table for bulk processing
file_hash: Optional[str] = None # SHA-256 hash for duplicate detection (US-007)
@property
def queue_wait_ms(self) -> Optional[int]:
"""Calculate queue wait time (created_at to started_at)."""
if self.created_at and self.started_at:
delta = self.started_at - self.created_at
return int(delta.total_seconds() * 1000)
return None
@property
def result(self) -> Optional[Dict]:
"""Parse result_json to dict."""
if self.result_json:
try:
return json.loads(self.result_json)
except json.JSONDecodeError:
return None
return None
class OCRJobQueue:
"""
SQLite-based job queue for OCR processing.
Provides async methods for job management with position
tracking and time estimation.
"""
def __init__(
self,
db_path: Optional[Path] = None,
files_dir: Optional[Path] = None
):
"""
Initialize job queue.
Args:
db_path: Path to SQLite database (default: data/ocr_queue/ocr_jobs.db)
files_dir: Path to files directory (default: data/ocr_queue/files/)
"""
self.db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
self.files_dir = Path(files_dir) if files_dir else DEFAULT_FILES_DIR
self._lock = asyncio.Lock()
self._initialized = False
async def initialize(self) -> None:
"""
Initialize database and directories.
Creates SQLite database and tables if they don't exist.
Creates files directory for uploaded files.
"""
if self._initialized:
return
# Create directories
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self.files_dir.mkdir(parents=True, exist_ok=True)
# Create database and tables
async with aiosqlite.connect(str(self.db_path)) as db:
# Enable WAL mode for better concurrency and set busy timeout
await db.execute("PRAGMA journal_mode=WAL")
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
await db.execute('''
CREATE TABLE IF NOT EXISTS ocr_jobs (
id TEXT PRIMARY KEY,
status TEXT NOT NULL DEFAULT 'pending',
file_path TEXT NOT NULL,
mime_type TEXT NOT NULL,
engine TEXT DEFAULT 'doctr_plus',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
started_at TIMESTAMP,
completed_at TIMESTAMP,
result_json TEXT,
error_message TEXT,
processing_time_ms INTEGER,
ocr_time_ms INTEGER,
created_by TEXT,
original_filename TEXT,
expires_at TIMESTAMP,
batch_id INTEGER
)
''')
# Migration: add ocr_time_ms column if it doesn't exist
try:
await db.execute('ALTER TABLE ocr_jobs ADD COLUMN ocr_time_ms INTEGER')
logger.info("[OCRJobQueue] Added ocr_time_ms column to existing table")
except Exception:
pass # Column already exists
# Migration: add batch_id column if it doesn't exist
try:
await db.execute('ALTER TABLE ocr_jobs ADD COLUMN batch_id INTEGER')
logger.info("[OCRJobQueue] Added batch_id column to existing table")
except Exception:
pass # Column already exists
# Migration: add file_hash column if it doesn't exist (US-007)
try:
await db.execute('ALTER TABLE ocr_jobs ADD COLUMN file_hash TEXT')
logger.info("[OCRJobQueue] Added file_hash column to existing table")
except Exception:
pass # Column already exists
# Index for efficient queue queries
await db.execute('''
CREATE INDEX IF NOT EXISTS idx_ocr_jobs_status
ON ocr_jobs(status, created_at)
''')
# Index for expiration cleanup
await db.execute('''
CREATE INDEX IF NOT EXISTS idx_ocr_jobs_expires
ON ocr_jobs(expires_at)
''')
await db.commit()
self._initialized = True
logger.info(f"[OCRJobQueue] Initialized: db={self.db_path}, files={self.files_dir}")
async def create_job(
self,
file_bytes: bytes,
mime_type: str,
engine: str = "doctr_plus",
username: Optional[str] = None,
original_filename: Optional[str] = None,
batch_id: Optional[int] = None,
file_hash: Optional[str] = None
) -> OCRJob:
"""
Create a new OCR job.
Saves file to disk and creates database record.
Args:
file_bytes: Raw file bytes
mime_type: MIME type of file
engine: OCR engine ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
username: Username of requester
original_filename: Original filename from upload
batch_id: Optional batch ID for bulk upload processing
file_hash: Optional SHA-256 hash for duplicate detection (US-007)
Returns:
Created OCRJob instance
"""
await self.initialize()
# Generate job ID
job_id = str(uuid.uuid4())
# Determine file extension
ext_map = {
'image/jpeg': '.jpg',
'image/png': '.png',
'application/pdf': '.pdf',
}
ext = ext_map.get(mime_type, '.bin')
# Save file
file_path = self.files_dir / f"{job_id}{ext}"
with open(file_path, 'wb') as f:
f.write(file_bytes)
# Calculate expiration
now = datetime.utcnow()
expires_at = now + timedelta(hours=JOB_EXPIRY_HOURS)
# Insert job record
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
await db.execute('''
INSERT INTO ocr_jobs (
id, status, file_path, mime_type, engine,
created_at, created_by, original_filename, expires_at, batch_id, file_hash
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
''', (
job_id, OCRJobStatus.pending.value, str(file_path), mime_type, engine,
now.isoformat(), username, original_filename, expires_at.isoformat(), batch_id, file_hash
))
await db.commit()
logger.info(f"[OCRJobQueue] Created job {job_id}: engine={engine}, file={file_path.name}, batch_id={batch_id}")
return OCRJob(
id=job_id,
status=OCRJobStatus.pending,
file_path=str(file_path),
mime_type=mime_type,
engine=engine,
created_at=now,
created_by=username,
original_filename=original_filename,
expires_at=expires_at,
batch_id=batch_id,
file_hash=file_hash
)
async def get_job(self, job_id: str) -> Optional[OCRJob]:
"""
Get job by ID.
Args:
job_id: Job UUID
Returns:
OCRJob or None if not found
"""
await self.initialize()
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
db.row_factory = aiosqlite.Row
async with db.execute(
'SELECT * FROM ocr_jobs WHERE id = ?',
(job_id,)
) as cursor:
row = await cursor.fetchone()
if row:
return self._row_to_job(row)
return None
async def get_queue_position(self, job_id: str) -> Optional[int]:
"""
Get position in queue for a pending job.
Args:
job_id: Job UUID
Returns:
Queue position (1 = next to process) or None if not pending
"""
await self.initialize()
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
# Check if job is pending
async with db.execute(
'SELECT status, created_at FROM ocr_jobs WHERE id = ?',
(job_id,)
) as cursor:
row = await cursor.fetchone()
if not row or row[0] != OCRJobStatus.pending.value:
return None
job_created_at = row[1]
# Count jobs ahead in queue (created before this job)
async with db.execute('''
SELECT COUNT(*) FROM ocr_jobs
WHERE status = 'pending' AND created_at < ?
''', (job_created_at,)) as cursor:
count = await cursor.fetchone()
return (count[0] + 1) if count else 1
async def get_next_pending(self) -> Optional[OCRJob]:
"""
Get the next pending job (oldest first) and atomically mark it as processing.
This prevents race conditions in parallel processing - only one worker
can claim each job.
Returns:
Next OCRJob to process or None if queue empty
"""
await self.initialize()
now = datetime.utcnow()
async with self._lock: # Serialize access to prevent race conditions
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
db.row_factory = aiosqlite.Row
# Get the next pending job
async with db.execute('''
SELECT * FROM ocr_jobs
WHERE status = 'pending'
ORDER BY created_at ASC
LIMIT 1
''') as cursor:
row = await cursor.fetchone()
if not row:
return None
job_id = row['id']
# Atomically mark as processing
await db.execute('''
UPDATE ocr_jobs
SET status = 'processing', started_at = ?
WHERE id = ? AND status = 'pending'
''', (now.isoformat(), job_id))
await db.commit()
# Fetch the updated job
async with db.execute(
'SELECT * FROM ocr_jobs WHERE id = ?',
(job_id,)
) as cursor:
updated_row = await cursor.fetchone()
if updated_row:
return self._row_to_job(updated_row)
return None
async def update_status(
self,
job_id: str,
status: OCRJobStatus,
result: Optional[Dict] = None,
error: Optional[str] = None,
processing_time_ms: Optional[int] = None,
ocr_time_ms: Optional[int] = None
) -> bool:
"""
Update job status.
Args:
job_id: Job UUID
status: New status
result: Extraction result dict (for completed)
error: Error message (for failed)
processing_time_ms: Total job processing time (started_at to completed_at)
ocr_time_ms: Actual OCR engine processing time
Returns:
True if update successful
"""
await self.initialize()
now = datetime.utcnow()
result_json = json.dumps(result, cls=DecimalEncoder) if result else None
# Build update query based on status
if status == OCRJobStatus.processing:
query = '''
UPDATE ocr_jobs
SET status = ?, started_at = ?
WHERE id = ?
'''
params = (status.value, now.isoformat(), job_id)
elif status == OCRJobStatus.completed:
query = '''
UPDATE ocr_jobs
SET status = ?, completed_at = ?, result_json = ?, processing_time_ms = ?, ocr_time_ms = ?
WHERE id = ?
'''
params = (status.value, now.isoformat(), result_json, processing_time_ms, ocr_time_ms, job_id)
elif status == OCRJobStatus.failed:
query = '''
UPDATE ocr_jobs
SET status = ?, completed_at = ?, error_message = ?, processing_time_ms = ?, ocr_time_ms = ?
WHERE id = ?
'''
params = (status.value, now.isoformat(), error, processing_time_ms, ocr_time_ms, job_id)
else:
query = 'UPDATE ocr_jobs SET status = ? WHERE id = ?'
params = (status.value, job_id)
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
cursor = await db.execute(query, params)
await db.commit()
return cursor.rowcount > 0
async def get_average_processing_time(self) -> float:
"""
Calculate average processing time from recent completed jobs.
Uses last 50 completed jobs for accuracy.
Returns:
Average time in seconds (default 7.0 if no data)
"""
await self.initialize()
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
async with db.execute('''
SELECT AVG(processing_time_ms)
FROM (
SELECT processing_time_ms FROM ocr_jobs
WHERE status = 'completed' AND processing_time_ms IS NOT NULL
ORDER BY completed_at DESC
LIMIT 50
)
''') as cursor:
row = await cursor.fetchone()
if row and row[0]:
return row[0] / 1000.0 # Convert ms to seconds
return 7.0 # Default estimate
async def count_pending(self) -> int:
"""Count pending jobs in queue."""
await self.initialize()
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
async with db.execute(
'SELECT COUNT(*) FROM ocr_jobs WHERE status = ?',
(OCRJobStatus.pending.value,)
) as cursor:
row = await cursor.fetchone()
return row[0] if row else 0
async def count_processing(self) -> int:
"""Count currently processing jobs."""
await self.initialize()
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
async with db.execute(
'SELECT COUNT(*) FROM ocr_jobs WHERE status = ?',
(OCRJobStatus.processing.value,)
) as cursor:
row = await cursor.fetchone()
return row[0] if row else 0
async def cleanup_expired(self) -> int:
"""
Delete expired jobs and their files.
Returns:
Number of jobs deleted
"""
await self.initialize()
now = datetime.utcnow()
deleted = 0
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
db.row_factory = aiosqlite.Row
# Get expired jobs
async with db.execute('''
SELECT id, file_path FROM ocr_jobs
WHERE expires_at < ?
''', (now.isoformat(),)) as cursor:
rows = await cursor.fetchall()
for row in rows:
# Delete file
file_path = Path(row['file_path'])
if file_path.exists():
try:
file_path.unlink()
except Exception as e:
logger.warning(f"[OCRJobQueue] Failed to delete file {file_path}: {e}")
# Delete job record
await db.execute('DELETE FROM ocr_jobs WHERE id = ?', (row['id'],))
deleted += 1
await db.commit()
if deleted > 0:
logger.info(f"[OCRJobQueue] Cleaned up {deleted} expired job(s)")
return deleted
async def cleanup_job_file(self, job_id: str) -> bool:
"""
Delete the file associated with a job.
Called after processing to free disk space.
Args:
job_id: Job UUID
Returns:
True if file deleted
"""
job = await self.get_job(job_id)
if job:
file_path = Path(job.file_path)
if file_path.exists():
try:
file_path.unlink()
return True
except Exception as e:
logger.warning(f"[OCRJobQueue] Failed to delete file {file_path}: {e}")
return False
async def get_queue_stats(self) -> Dict[str, Any]:
"""
Get queue statistics.
Returns:
Dict with pending, processing, completed, failed counts
"""
await self.initialize()
stats = {
"pending": 0,
"processing": 0,
"completed": 0,
"failed": 0,
"average_time_seconds": 0.0,
}
async with aiosqlite.connect(str(self.db_path)) as db:
await db.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
async with db.execute('''
SELECT status, COUNT(*) as count
FROM ocr_jobs
GROUP BY status
''') as cursor:
rows = await cursor.fetchall()
for row in rows:
if row[0] in stats:
stats[row[0]] = row[1]
stats["average_time_seconds"] = await self.get_average_processing_time()
return stats
def _row_to_job(self, row: aiosqlite.Row) -> OCRJob:
"""Convert database row to OCRJob."""
def parse_datetime(val):
if val:
try:
return datetime.fromisoformat(val)
except (ValueError, TypeError):
return None
return None
return OCRJob(
id=row['id'],
status=OCRJobStatus(row['status']),
file_path=row['file_path'],
mime_type=row['mime_type'],
engine=row['engine'] or 'doctr_plus',
created_at=parse_datetime(row['created_at']),
started_at=parse_datetime(row['started_at']),
completed_at=parse_datetime(row['completed_at']),
result_json=row['result_json'],
error_message=row['error_message'],
processing_time_ms=row['processing_time_ms'],
ocr_time_ms=row['ocr_time_ms'] if 'ocr_time_ms' in row.keys() else None,
created_by=row['created_by'],
original_filename=row['original_filename'],
expires_at=parse_datetime(row['expires_at']),
batch_id=row['batch_id'] if 'batch_id' in row.keys() else None,
file_hash=row['file_hash'] if 'file_hash' in row.keys() else None,
)
# Singleton instance
job_queue = OCRJobQueue()

View File

@@ -0,0 +1,665 @@
"""
OCR Job Worker - Background Task for Queue Processing
Runs as an asyncio background task in FastAPI.
Continuously polls the job queue and processes OCR requests IN PARALLEL.
Architecture:
FastAPI startup
start_job_worker()
asyncio.create_task(_job_worker_loop())
while True:
# Process up to OCR_WORKERS jobs concurrently
jobs = get_pending_jobs(limit=available_slots)
for job in jobs:
asyncio.create_task(_process_job(job))
await asyncio.sleep(0.1)
"""
import asyncio
import logging
import os
import time
from pathlib import Path
from typing import Optional, Set
from .job_queue import job_queue, OCRJobStatus, OCRJob
from .ocr_worker_pool import ocr_worker_pool
from backend.modules.data_entry.schemas.ocr import ExtractionData
logger = logging.getLogger(__name__)
# Global task reference
_job_worker_task: Optional[asyncio.Task] = None
_cleanup_task: Optional[asyncio.Task] = None
_shutdown_event: Optional[asyncio.Event] = None
_active_tasks: Set[asyncio.Task] = set() # Track active job tasks
_concurrency_semaphore: Optional[asyncio.Semaphore] = None # Limit concurrent jobs
# Configuration
POLL_INTERVAL_SECONDS = 0.1 # How often to check for new jobs (faster for parallel)
CLEANUP_INTERVAL_SECONDS = 3600 # Clean expired jobs every hour
OCR_TIMEOUT_SECONDS = 120 # Max time for OCR processing
async def _job_worker_loop() -> None:
"""
Main worker loop - processes jobs from queue IN PARALLEL.
Runs continuously until shutdown. Uses semaphore to limit
concurrent jobs to OCR_WORKERS count. Launches jobs as
background tasks without waiting for completion.
"""
global _shutdown_event, _active_tasks, _concurrency_semaphore
# Get max concurrent jobs from env (matches worker pool size)
max_concurrent = int(os.getenv('OCR_WORKERS', '2'))
_concurrency_semaphore = asyncio.Semaphore(max_concurrent)
_active_tasks = set()
logger.info(f"[JobWorker] Starting PARALLEL worker loop (max_concurrent={max_concurrent})...")
_shutdown_event = asyncio.Event()
consecutive_errors = 0
max_consecutive_errors = 10
while not _shutdown_event.is_set():
try:
# Clean up completed tasks
done_tasks = {t for t in _active_tasks if t.done()}
for task in done_tasks:
_active_tasks.discard(task)
# Check for exceptions
try:
task.result()
except Exception as e:
logger.error(f"[JobWorker] Task failed: {e}")
# Check if we have capacity for more jobs
active_count = len(_active_tasks)
available_slots = max_concurrent - active_count
if available_slots > 0:
# Get next pending job
job = await job_queue.get_next_pending()
if job:
consecutive_errors = 0
# Launch job processing as background task
task = asyncio.create_task(_process_job_with_semaphore(job))
_active_tasks.add(task)
logger.debug(f"[JobWorker] Launched job {job.id} (active={len(_active_tasks)}/{max_concurrent})")
else:
# No pending jobs - wait briefly
try:
await asyncio.wait_for(
_shutdown_event.wait(),
timeout=POLL_INTERVAL_SECONDS
)
if _shutdown_event.is_set():
break
except asyncio.TimeoutError:
pass
else:
# At capacity - wait for a slot to free up
await asyncio.sleep(POLL_INTERVAL_SECONDS)
except asyncio.CancelledError:
logger.info("[JobWorker] Worker loop cancelled")
break
except Exception as e:
consecutive_errors += 1
logger.error(f"[JobWorker] Error in worker loop ({consecutive_errors}/{max_consecutive_errors}): {e}")
if consecutive_errors >= max_consecutive_errors:
logger.error("[JobWorker] Too many consecutive errors, stopping worker")
break
await asyncio.sleep(min(consecutive_errors * 2, 30))
# Wait for active tasks to complete on shutdown
if _active_tasks:
logger.info(f"[JobWorker] Waiting for {len(_active_tasks)} active tasks to complete...")
await asyncio.gather(*_active_tasks, return_exceptions=True)
logger.info("[JobWorker] Worker loop stopped")
async def _process_job_with_semaphore(job: OCRJob) -> None:
"""
Process job with semaphore to limit concurrency.
Acquires semaphore before processing, releases after.
This ensures we don't exceed OCR_WORKERS concurrent jobs.
"""
global _concurrency_semaphore
async with _concurrency_semaphore:
await _process_job(job)
async def _process_job(job: OCRJob) -> None:
"""
Process a single OCR job.
Reads file, submits to worker pool, updates job status,
and saves metrics for analytics.
Args:
job: OCRJob to process
"""
logger.info(f"[JobWorker] Processing job {job.id}: engine={job.engine}, file={Path(job.file_path).name}")
start_time = time.time()
file_size = 0
file_type = "image/jpeg"
try:
# Note: Job already marked as 'processing' atomically in get_next_pending()
# Read file bytes
file_path = Path(job.file_path)
if not file_path.exists():
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, 'rb') as f:
file_bytes = f.read()
file_size = len(file_bytes)
# Determine file type from job or extension
file_type = getattr(job, 'mime_type', 'image/jpeg') or 'image/jpeg'
# Submit to worker pool
result = await ocr_worker_pool.submit_task(
image_bytes=file_bytes,
engine=job.engine,
preprocessing="auto",
timeout=OCR_TIMEOUT_SECONDS
)
elapsed_ms = int((time.time() - start_time) * 1000)
if result.get("success"):
# Job completed successfully
extraction = result.get("extraction", {})
# Include raw_texts for analysis (from all OCR engine passes)
extraction['raw_texts'] = result.get("raw_texts", [])
# Extract actual OCR processing time from extraction result
ocr_time_ms = extraction.get('processing_time_ms', 0)
# Debug: log suggested_payment_mode
spm = extraction.get('suggested_payment_mode')
logger.info(f"[JobWorker] Job {job.id} extraction has suggested_payment_mode={spm}")
await job_queue.update_status(
job_id=job.id,
status=OCRJobStatus.completed,
result=extraction,
processing_time_ms=elapsed_ms,
ocr_time_ms=ocr_time_ms
)
logger.info(f"[JobWorker] Job {job.id} completed in {elapsed_ms}ms (ocr: {ocr_time_ms}ms)")
# Save metrics for successful job
await _save_job_metrics(
job_id=job.id,
username=job.created_by or 'unknown',
engine_requested=job.engine,
engine_used=extraction.get('ocr_engine', job.engine),
processing_time_ms=elapsed_ms,
file_size_bytes=file_size,
file_type=file_type,
original_filename=job.original_filename,
success=True,
overall_confidence=extraction.get('overall_confidence', 0.0),
fields_extracted=_count_extracted_fields(extraction),
needs_manual_review=extraction.get('needs_manual_review'),
validation_warnings_count=len(extraction.get('validation_warnings', [])),
validation_errors_count=len(extraction.get('validation_errors', [])),
)
# Auto-save receipt for batch jobs
if job.batch_id:
auto_save_result = await _auto_save_batch_receipt(
job=job,
extraction=extraction,
file_path=str(file_path)
)
if not auto_save_result:
# Auto-save failed - mark job as failed
# Note: job_queue status already updated to 'completed' above
# We need to update it back to failed with the auto-save error
logger.warning(
f"[JobWorker] Job {job.id} OCR succeeded but auto-save failed"
)
else:
# Job failed
error_msg = result.get("error", "Unknown error")
await job_queue.update_status(
job_id=job.id,
status=OCRJobStatus.failed,
error=error_msg,
processing_time_ms=elapsed_ms
)
logger.warning(f"[JobWorker] Job {job.id} failed after {elapsed_ms}ms: {error_msg}")
# Save metrics for failed job
await _save_job_metrics(
job_id=job.id,
username=job.created_by or 'unknown',
engine_requested=job.engine,
engine_used=job.engine,
processing_time_ms=elapsed_ms,
file_size_bytes=file_size,
file_type=file_type,
original_filename=job.original_filename,
success=False,
error_message=error_msg,
)
except Exception as e:
elapsed_ms = int((time.time() - start_time) * 1000)
logger.error(f"[JobWorker] Job {job.id} error after {elapsed_ms}ms: {e}")
await job_queue.update_status(
job_id=job.id,
status=OCRJobStatus.failed,
error=str(e),
processing_time_ms=elapsed_ms
)
# Save metrics for error job
await _save_job_metrics(
job_id=job.id,
username=job.created_by or 'unknown',
engine_requested=job.engine,
engine_used=job.engine,
processing_time_ms=elapsed_ms,
file_size_bytes=file_size,
file_type=file_type,
original_filename=job.original_filename,
success=False,
error_message=str(e),
)
finally:
# Cleanup file after processing
try:
await job_queue.cleanup_job_file(job.id)
except Exception as e:
logger.warning(f"[JobWorker] Failed to cleanup file for job {job.id}: {e}")
async def _cleanup_loop() -> None:
"""
Periodic cleanup of expired jobs.
Runs every hour to delete jobs older than 24 hours.
"""
global _shutdown_event
logger.info("[JobWorker] Starting cleanup loop...")
while not _shutdown_event.is_set():
try:
# Wait for interval or shutdown
try:
await asyncio.wait_for(
_shutdown_event.wait(),
timeout=CLEANUP_INTERVAL_SECONDS
)
if _shutdown_event.is_set():
break
except asyncio.TimeoutError:
pass # Normal timeout, do cleanup
# Run cleanup
deleted = await job_queue.cleanup_expired()
if deleted > 0:
logger.info(f"[JobWorker] Cleanup: deleted {deleted} expired jobs")
except asyncio.CancelledError:
logger.info("[JobWorker] Cleanup loop cancelled")
break
except Exception as e:
logger.error(f"[JobWorker] Cleanup error: {e}")
await asyncio.sleep(60) # Retry after 1 minute
logger.info("[JobWorker] Cleanup loop stopped")
async def start_job_worker() -> bool:
"""
Start the job worker background task.
Called at FastAPI startup to begin processing queue.
Returns:
True if started successfully
"""
global _job_worker_task, _cleanup_task, _shutdown_event
if _job_worker_task is not None and not _job_worker_task.done():
logger.warning("[JobWorker] Already running")
return True
try:
# Initialize job queue
await job_queue.initialize()
# Initialize worker pool
if not ocr_worker_pool.initialize():
logger.error("[JobWorker] Failed to initialize worker pool")
return False
# Pre-warm worker pool in BACKGROUND (don't block startup)
# First OCR request may be slower if prewarm isn't done yet
async def _background_prewarm():
logger.info("[JobWorker] Pre-warming OCR worker pool (background)...")
warmup_success = await ocr_worker_pool.prewarm(timeout=90.0)
if warmup_success:
logger.info("[JobWorker] OCR worker pool pre-warmed successfully")
else:
logger.warning("[JobWorker] Worker pool pre-warm failed, first request will be slower")
asyncio.create_task(_background_prewarm())
# Start worker loop
_shutdown_event = asyncio.Event()
_job_worker_task = asyncio.create_task(_job_worker_loop())
# Start cleanup loop
_cleanup_task = asyncio.create_task(_cleanup_loop())
logger.info("[JobWorker] Started successfully")
return True
except Exception as e:
logger.error(f"[JobWorker] Failed to start: {e}")
return False
async def stop_job_worker() -> None:
"""
Stop the job worker background task.
Called at FastAPI shutdown to gracefully stop processing.
"""
global _job_worker_task, _cleanup_task, _shutdown_event
logger.info("[JobWorker] Stopping...")
# Signal shutdown
if _shutdown_event:
_shutdown_event.set()
# Cancel worker task
if _job_worker_task and not _job_worker_task.done():
_job_worker_task.cancel()
try:
await _job_worker_task
except asyncio.CancelledError:
pass
# Cancel cleanup task
if _cleanup_task and not _cleanup_task.done():
_cleanup_task.cancel()
try:
await _cleanup_task
except asyncio.CancelledError:
pass
# Shutdown worker pool
ocr_worker_pool.shutdown(wait=True)
_job_worker_task = None
_cleanup_task = None
_shutdown_event = None
logger.info("[JobWorker] Stopped")
def is_running() -> bool:
"""Check if job worker is running."""
return _job_worker_task is not None and not _job_worker_task.done()
def estimate_wait_time(queue_position: int) -> int:
"""
Estimate wait time for a job in queue.
Args:
queue_position: Position in queue (1 = next)
Returns:
Estimated wait time in seconds
"""
if queue_position <= 0:
return 0
# Get average processing time (synchronous fallback)
# Default ~7 seconds per job if no data
avg_time = 7.0
try:
# Try to get from queue stats
import asyncio
loop = asyncio.get_event_loop()
if loop.is_running():
# Can't use sync call in async context, use default
pass
else:
avg_time = loop.run_until_complete(job_queue.get_average_processing_time())
except Exception:
pass
# Estimate: position * average_time
return int(queue_position * avg_time)
# ============================================================================
# Metrics Helper Functions
# ============================================================================
async def _save_job_metrics(
job_id: str,
username: str,
engine_requested: str,
engine_used: str,
processing_time_ms: int = 0,
file_size_bytes: int = 0,
file_type: str = "image/jpeg",
original_filename: Optional[str] = None,
success: bool = True,
error_message: Optional[str] = None,
overall_confidence: float = 0.0,
fields_extracted: int = 0,
needs_manual_review: Optional[bool] = None,
validation_warnings_count: int = 0,
validation_errors_count: int = 0,
) -> None:
"""
Save OCR job metrics to database for analytics.
Called after each job completes (success or failure).
Errors are logged but don't affect job processing.
"""
try:
from backend.modules.data_entry.db.database import get_db_session
from backend.modules.data_entry.db.crud.ocr_settings import OCRMetricsCRUD
async with await get_db_session() as session:
await OCRMetricsCRUD.create(
session=session,
job_id=job_id,
username=username,
engine_requested=engine_requested,
engine_used=engine_used,
processing_time_ms=processing_time_ms,
file_size_bytes=file_size_bytes,
file_type=file_type,
original_filename=original_filename,
success=success,
error_message=error_message,
overall_confidence=overall_confidence,
fields_extracted=fields_extracted,
needs_manual_review=needs_manual_review,
validation_warnings_count=validation_warnings_count,
validation_errors_count=validation_errors_count,
)
logger.debug(f"[JobWorker] Saved metrics for job {job_id}")
except Exception as e:
# Log but don't fail - metrics are nice-to-have
logger.warning(f"[JobWorker] Failed to save metrics for job {job_id}: {e}")
def _count_extracted_fields(extraction: dict) -> int:
"""
Count number of successfully extracted fields from OCR result.
Counts non-None values in key fields.
"""
key_fields = [
'receipt_number',
'receipt_date',
'amount',
'partner_name',
'cui',
'tva_total',
'address',
'items_count',
]
count = 0
for field in key_fields:
value = extraction.get(field)
if value is not None and value != '' and value != []:
count += 1
# Also count TVA entries if present
tva_entries = extraction.get('tva_entries', [])
if tva_entries and len(tva_entries) > 0:
count += 1
# Count payment methods if present
payment_methods = extraction.get('payment_methods', [])
if payment_methods and len(payment_methods) > 0:
count += 1
return count
# ============================================================================
# Auto-Save Batch Receipt Helper
# ============================================================================
async def _auto_save_batch_receipt(
job: OCRJob,
extraction: dict,
file_path: str
) -> bool:
"""
Automatically create a receipt from OCR result for batch jobs.
Called when a batch job completes successfully. Creates the receipt,
attachment, and accounting entries using ReceiptAutoCreateService.
Args:
job: Completed OCRJob with batch_id set
extraction: OCR extraction result dict
file_path: Path to the original uploaded file
Returns:
True if receipt created successfully, False otherwise
"""
if not job.batch_id:
return True # Not a batch job, nothing to do
logger.info(f"[JobWorker] Auto-saving receipt for batch job {job.id} (batch_id={job.batch_id})")
try:
# Import here to avoid circular imports
from backend.modules.data_entry.db.database import get_db_session
from backend.modules.data_entry.db.models import BatchUpload
from backend.modules.data_entry.services.receipt_auto_create import ReceiptAutoCreateService
from sqlalchemy import select
# Convert extraction dict to ExtractionData schema
ocr_result = ExtractionData(**extraction)
async with await get_db_session() as session:
# Get batch info to retrieve company_id and user_id
batch_result = await session.execute(
select(BatchUpload).where(BatchUpload.id == job.batch_id)
)
batch = batch_result.scalar_one_or_none()
if not batch:
error_msg = f"Batch {job.batch_id} not found"
logger.error(f"[JobWorker] Auto-save failed for job {job.id}: {error_msg}")
await job_queue.update_status(
job_id=job.id,
status=OCRJobStatus.failed,
error=f"Auto-save error: {error_msg}"
)
return False
# Call ReceiptAutoCreateService
result = await ReceiptAutoCreateService.create_from_ocr_result(
session=session,
job_id=job.id,
ocr_result=ocr_result,
username=job.created_by or batch.user_id,
batch_id=job.batch_id,
company_id=batch.company_id,
file_path=file_path,
original_filename=job.original_filename,
file_hash=job.file_hash # Pass file_hash for duplicate detection (US-007)
)
if result.success:
logger.info(
f"[JobWorker] Auto-save successful for job {job.id}: "
f"receipt_id={result.receipt_id}"
)
return True
else:
error_msg = result.error_message or "Unknown error"
logger.warning(
f"[JobWorker] Auto-save validation failed for job {job.id}: {error_msg}"
)
# Update job status to failed with the auto-save error
await job_queue.update_status(
job_id=job.id,
status=OCRJobStatus.failed,
error=f"Auto-save error: {error_msg}"
)
return False
except Exception as e:
error_msg = str(e)
logger.error(f"[JobWorker] Auto-save exception for job {job.id}: {error_msg}")
# Update job status to failed
try:
await job_queue.update_status(
job_id=job.id,
status=OCRJobStatus.failed,
error=f"Auto-save error: {error_msg}"
)
except Exception as update_err:
logger.error(f"[JobWorker] Failed to update job status after auto-save error: {update_err}")
return False

View File

@@ -0,0 +1,561 @@
"""
OCR Worker Pool Manager
Manages a ProcessPoolExecutor with persistent OCR engine initialization.
Key features:
- ProcessPoolExecutor with configurable max_workers (from OCR_WORKERS env)
- Configurable max_tasks_per_child (from OCR_MAX_TASKS_PER_CHILD env, 0=no restart)
- mp_context='spawn' for Windows IIS compatibility
- docTR/PaddleOCR loaded ONCE at worker spawn (not 30s per request)
- atexit + signal handlers for cleanup
- Health check with auto-respawn
- Orphan process cleanup on Windows
Architecture:
Main Process │ Worker Process (PERSISTENT)
──────────────────────│──────────────────────────────────
OCRWorkerPool │ Worker initialized once
↓ │ ↓
submit_task() ────────│────→ process_ocr()
↓ │ ↓
Future.result() ←─────│──── Return result
"""
import asyncio
import atexit
import gc
import logging
import multiprocessing as mp
import os
import signal
import sys
import time
from concurrent.futures import ProcessPoolExecutor, Future, ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Callable, Optional
logger = logging.getLogger(__name__)
# Try to import psutil for orphan process cleanup
try:
import psutil
PSUTIL_AVAILABLE = True
except ImportError:
PSUTIL_AVAILABLE = False
logger.warning("[OCRWorkerPool] psutil not available - orphan cleanup disabled")
class OCRWorkerPool:
"""
Singleton manager for OCR ProcessPoolExecutor.
Ensures OCR engines are loaded once and reused for all requests.
Uses max_tasks_per_child=5 to restart worker every 5 tasks (prevents memory leak).
"""
_instance: Optional["OCRWorkerPool"] = None
_initialized: bool = False
def __new__(cls) -> "OCRWorkerPool":
"""Singleton pattern - only one pool instance."""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
"""Initialize worker pool (runs only once due to singleton)."""
if self._initialized:
return
self._executor: Optional[ProcessPoolExecutor] = None
self._worker_pid: Optional[int] = None
self._is_warming: bool = False
self._is_shutdown: bool = False
self._lock = asyncio.Lock() if asyncio.get_event_loop_policy() else None
self._sync_lock = mp.Lock()
# Register cleanup handlers
# NOTE: Only use atexit, NOT signal handlers!
# Signal handlers interfere with FastAPI's shutdown handling.
# FastAPI's shutdown event calls stop_job_worker() which calls shutdown().
atexit.register(self._cleanup_on_exit)
self._initialized = True
logger.info("[OCRWorkerPool] Singleton instance created")
def initialize(self) -> bool:
"""
Initialize the ProcessPoolExecutor.
Creates executor with spawn context for Windows compatibility.
Uses max_tasks_per_child=5 to restart worker periodically (prevents memory leak).
Returns:
True if initialization successful
"""
if self._executor is not None:
logger.warning("[OCRWorkerPool] Already initialized")
return True
if self._is_shutdown:
logger.error("[OCRWorkerPool] Cannot initialize - pool is shutdown")
return False
try:
# Cleanup any orphan workers from previous runs
self._cleanup_orphan_workers()
# Read configuration from environment
max_workers = int(os.getenv('OCR_WORKERS', '2'))
max_tasks_raw = os.getenv('OCR_MAX_TASKS_PER_CHILD', '0')
# 0 means no restart (None in ProcessPoolExecutor)
max_tasks_per_child = int(max_tasks_raw) if max_tasks_raw and int(max_tasks_raw) > 0 else None
# Create executor with spawn context (Windows compatible)
# Use mp_context='spawn' explicitly for cross-platform consistency
mp_context = mp.get_context('spawn')
# max_tasks_per_child only available in Python 3.11+
executor_kwargs = {
'max_workers': max_workers,
'mp_context': mp_context,
'initializer': _worker_initializer,
}
if sys.version_info >= (3, 11) and max_tasks_per_child is not None:
executor_kwargs['max_tasks_per_child'] = max_tasks_per_child
else:
logger.info(f"[OCRWorkerPool] max_tasks_per_child not supported (Python {sys.version_info.major}.{sys.version_info.minor})")
self._executor = ProcessPoolExecutor(**executor_kwargs)
logger.info(f"[OCRWorkerPool] ProcessPoolExecutor created (spawn context, max_workers={max_workers}, max_tasks_per_child={max_tasks_per_child})")
return True
except Exception as e:
logger.error(f"[OCRWorkerPool] Initialization failed: {e}")
return False
async def prewarm(self, timeout: float = 60.0) -> bool:
"""
Pre-warm the worker by loading PaddleOCR before first request.
This is called at FastAPI startup to avoid 30s delay on first request.
Submits a dummy task that triggers PaddleOCR initialization.
Args:
timeout: Maximum seconds to wait for warmup (default 60s)
Returns:
True if warmup successful, False if timeout or error
"""
if self._executor is None:
logger.error("[OCRWorkerPool] Cannot prewarm - not initialized")
return False
if self._is_warming:
logger.warning("[OCRWorkerPool] Already warming up")
return False
self._is_warming = True
logger.info("[OCRWorkerPool] Starting pre-warm (loading PaddleOCR in worker)...")
start_time = time.time()
try:
# Submit warmup task that initializes PaddleOCR
loop = asyncio.get_event_loop()
future = self._executor.submit(_warmup_task)
# Wait with timeout
result = await loop.run_in_executor(None, future.result, timeout)
elapsed = time.time() - start_time
if result.get("success"):
logger.info(f"[OCRWorkerPool] Pre-warm complete in {elapsed:.1f}s - PaddleOCR ready")
self._worker_pid = result.get("pid")
return True
else:
logger.error(f"[OCRWorkerPool] Pre-warm failed: {result.get('error')}")
return False
except Exception as e:
elapsed = time.time() - start_time
logger.error(f"[OCRWorkerPool] Pre-warm failed after {elapsed:.1f}s: {e}")
return False
finally:
self._is_warming = False
async def submit_task(
self,
image_bytes: bytes,
engine: str = "doctr_plus",
preprocessing: str = "auto",
timeout: float = 120.0
) -> dict:
"""
Submit OCR task to worker process.
Args:
image_bytes: Raw image bytes
engine: OCR engine ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
preprocessing: Preprocessing mode ('light', 'medium', 'heavy', 'auto')
timeout: Maximum processing time in seconds
Returns:
Dict with extraction results
Raises:
RuntimeError: If pool not initialized or task fails
"""
if self._executor is None:
raise RuntimeError("OCR worker pool not initialized")
if self._is_shutdown:
raise RuntimeError("OCR worker pool is shutdown")
logger.info(f"[OCRWorkerPool] Submitting task: engine={engine}, preprocessing={preprocessing}, size={len(image_bytes)} bytes")
try:
loop = asyncio.get_event_loop()
future = self._executor.submit(
_process_ocr_task,
image_bytes,
engine,
preprocessing
)
# Wait for result with timeout
result = await loop.run_in_executor(None, future.result, timeout)
logger.info(f"[OCRWorkerPool] Task complete: success={result.get('success')}")
return result
except TimeoutError:
logger.error(f"[OCRWorkerPool] Task timed out after {timeout}s")
raise RuntimeError(f"OCR processing timed out after {timeout}s")
except Exception as e:
logger.error(f"[OCRWorkerPool] Task failed: {e}")
raise RuntimeError(f"OCR processing failed: {e}")
def is_healthy(self) -> bool:
"""
Check if worker pool is healthy.
Returns:
True if pool is ready to accept tasks
"""
if self._executor is None:
return False
if self._is_shutdown:
return False
# Check if worker process is still alive
if self._worker_pid and PSUTIL_AVAILABLE:
try:
proc = psutil.Process(self._worker_pid)
if not proc.is_running():
logger.warning("[OCRWorkerPool] Worker process died, needs respawn")
return False
except psutil.NoSuchProcess:
logger.warning("[OCRWorkerPool] Worker process not found")
return False
return True
def shutdown(self, wait: bool = True, timeout: float = 10.0) -> None:
"""
Shutdown the worker pool gracefully.
Args:
wait: Wait for pending tasks to complete
timeout: Maximum wait time in seconds
"""
if self._executor is None:
return
logger.info("[OCRWorkerPool] Shutting down...")
self._is_shutdown = True
try:
self._executor.shutdown(wait=wait, cancel_futures=True)
logger.info("[OCRWorkerPool] Executor shutdown complete")
except Exception as e:
logger.error(f"[OCRWorkerPool] Shutdown error: {e}")
self._executor = None
self._worker_pid = None
# Final orphan cleanup
self._cleanup_orphan_workers()
logger.info("[OCRWorkerPool] Shutdown complete")
def _cleanup_orphan_workers(self) -> int:
"""
Clean up orphan Python processes from previous runs.
On Windows with NSSM, orphan processes may remain after service restart.
This finds and kills any python.exe processes that were OCR workers.
Returns:
Number of processes killed
"""
if not PSUTIL_AVAILABLE:
return 0
killed = 0
current_pid = os.getpid()
try:
for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
try:
# Skip self
if proc.pid == current_pid:
continue
# Look for Python processes with OCR-related cmdline
if proc.name().lower() in ('python.exe', 'python3.exe', 'python', 'python3'):
cmdline = ' '.join(proc.cmdline() or [])
# Check if this is an OCR worker process
if 'ocr_worker_process' in cmdline.lower() or 'process_ocr_task' in cmdline.lower():
logger.warning(f"[OCRWorkerPool] Killing orphan worker: PID={proc.pid}")
proc.kill()
proc.wait(timeout=5)
killed += 1
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
continue
except Exception as e:
logger.error(f"[OCRWorkerPool] Orphan cleanup error: {e}")
if killed > 0:
logger.info(f"[OCRWorkerPool] Cleaned up {killed} orphan worker(s)")
return killed
def _cleanup_on_exit(self) -> None:
"""atexit handler for cleanup."""
logger.info("[OCRWorkerPool] atexit cleanup triggered")
self.shutdown(wait=False)
def _signal_handler(self, signum: int, frame: Any) -> None:
"""Signal handler for SIGTERM/SIGINT."""
logger.info(f"[OCRWorkerPool] Received signal {signum}, shutting down...")
self.shutdown(wait=False)
# ============================================================================
# WORKER PROCESS FUNCTIONS
# ============================================================================
# These functions run in the child process, not the main FastAPI process.
# Global engines - persist between tasks in worker process
_paddle_engine = None
_tesseract_engine = None
_doctr_engine = None # docTR engine (PyTorch backend)
_worker_initialized = False
def _worker_initializer() -> None:
"""
Called once when worker process spawns.
Initializes global OCR engines IN PARALLEL for faster startup.
Uses ThreadPoolExecutor to load enabled engines concurrently.
Respects OCR_ENABLE_PADDLEOCR and OCR_ENABLE_TESSERACT from .env.
Total warmup time = max(engine_times) instead of sum(engine_times).
"""
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
if _worker_initialized:
print(f"[Worker {os.getpid()}] Already initialized", flush=True)
return
# Check which engines are enabled via .env
paddle_enabled = os.getenv("OCR_ENABLE_PADDLEOCR", "true").lower() == "true"
tesseract_enabled = os.getenv("OCR_ENABLE_TESSERACT", "true").lower() == "true"
enabled_engines = ["doctr"] # docTR is always loaded (primary engine)
if paddle_enabled:
enabled_engines.append("paddle")
if tesseract_enabled:
enabled_engines.append("tesseract")
print(f"[Worker {os.getpid()}] Initializing OCR engines: {enabled_engines}", flush=True)
if not paddle_enabled:
print(f"[Worker {os.getpid()}] PaddleOCR DISABLED - saving ~800MB RAM", flush=True)
if not tesseract_enabled:
print(f"[Worker {os.getpid()}] Tesseract DISABLED - saving ~50MB RAM", flush=True)
start_time = time.time()
# Define loader functions - each runs in its own thread
def load_doctr():
try:
from backend.modules.data_entry.services.ocr.ocr_worker_process import initialize_doctr_engine
engine = initialize_doctr_engine()
return ("doctr", engine, None)
except Exception as e:
return ("doctr", None, str(e))
def load_paddle():
if not paddle_enabled:
return ("paddle", None, "disabled via OCR_ENABLE_PADDLEOCR=false")
try:
from backend.modules.data_entry.services.ocr.ocr_worker_process import initialize_paddle_engine
engine = initialize_paddle_engine()
return ("paddle", engine, None)
except Exception as e:
return ("paddle", None, str(e))
def load_tesseract():
if not tesseract_enabled:
return ("tesseract", None, "disabled via OCR_ENABLE_TESSERACT=false")
try:
from backend.modules.data_entry.services.ocr.tesseract_engine import TesseractEngine
engine = TesseractEngine()
return ("tesseract", engine, None)
except Exception as e:
return ("tesseract", None, str(e))
# Build list of futures for enabled engines only
futures_to_submit = [load_doctr] # docTR always loaded
if paddle_enabled:
futures_to_submit.append(load_paddle)
if tesseract_enabled:
futures_to_submit.append(load_tesseract)
# Load engines in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=len(futures_to_submit)) as executor:
futures = [executor.submit(fn) for fn in futures_to_submit]
for future in as_completed(futures):
name, engine, error = future.result()
if error and "disabled" not in error:
print(f"[Worker {os.getpid()}] {name} init failed: {error}", flush=True)
elif engine:
print(f"[Worker {os.getpid()}] {name} loaded", flush=True)
if name == "doctr":
_doctr_engine = engine
elif name == "paddle":
_paddle_engine = engine
elif name == "tesseract":
_tesseract_engine = engine
elapsed = time.time() - start_time
_worker_initialized = True
print(f"[Worker {os.getpid()}] Initialization complete in {elapsed:.1f}s (engines: {enabled_engines})", flush=True)
def _warmup_task() -> dict:
"""
Warmup task that ensures engines are loaded.
Called at FastAPI startup to pre-warm the worker.
Returns success status and worker PID.
"""
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
try:
# Ensure initialization
if not _worker_initialized:
_worker_initializer()
# Quick test - create a small dummy image
import numpy as np
dummy_img = np.ones((100, 100, 3), dtype=np.uint8) * 255
# Test docTR if available (fastest engine)
if _doctr_engine is not None:
try:
_doctr_engine([dummy_img])
print(f"[Worker {os.getpid()}] docTR warmup OK", flush=True)
except Exception as e:
print(f"[Worker {os.getpid()}] docTR warmup error: {e}", flush=True)
# Test PaddleOCR if available
if _paddle_engine is not None:
try:
_paddle_engine.predict(dummy_img)
print(f"[Worker {os.getpid()}] PaddleOCR warmup OK", flush=True)
except Exception as e:
print(f"[Worker {os.getpid()}] PaddleOCR warmup error: {e}", flush=True)
# Cleanup
gc.collect()
return {
"success": True,
"pid": os.getpid(),
"doctr_available": _doctr_engine is not None,
"paddle_available": _paddle_engine is not None,
"tesseract_available": _tesseract_engine is not None
}
except Exception as e:
return {
"success": False,
"pid": os.getpid(),
"error": str(e)
}
def _process_ocr_task(
image_bytes: bytes,
engine: str = "doctr_plus",
preprocessing: str = "auto"
) -> dict:
"""
Process OCR task in worker process.
This is the main work function called for each OCR request.
Uses persistent global engines loaded at worker init.
Args:
image_bytes: Raw image bytes
engine: OCR engine choice ('tesseract', 'doctr', 'doctr_plus', 'paddleocr')
preprocessing: Preprocessing mode
Returns:
Dict with extraction results
"""
global _paddle_engine, _tesseract_engine, _doctr_engine, _worker_initialized
try:
# Ensure initialization
if not _worker_initialized:
_worker_initializer()
# Import processing function
from backend.modules.data_entry.services.ocr.ocr_worker_process import process_ocr
# Run OCR
result = process_ocr(
image_bytes=image_bytes,
paddle_engine=_paddle_engine,
tesseract_engine=_tesseract_engine,
engine=engine,
preprocessing=preprocessing,
doctr_engine=_doctr_engine
)
# Cleanup after each task
gc.collect()
return result
except Exception as e:
print(f"[Worker {os.getpid()}] Task error: {e}", flush=True)
import traceback
traceback.print_exc()
return {
"success": False,
"error": str(e),
"pid": os.getpid()
}
# Singleton instance
ocr_worker_pool = OCRWorkerPool()

View File

@@ -0,0 +1,258 @@
# Store Profiles - OCR Extraction
Sistem de profile specifice pentru extracție OCR cu hot-reload.
---
## Quick Start: Adaugă un profil nou
```bash
# 1. Generează profil din PDF-uri (dry-run pentru preview)
python scripts/generate_store_profile.py \
--name "Magazin Nou SRL" \
--cui "12345678" \
--receipts "docs/data-entry/MagazinNou*.pdf" \
--dry-run
# 2. Generează și salvează
python scripts/generate_store_profile.py \
--name "Magazin Nou SRL" \
--cui "12345678" \
--receipts "docs/data-entry/MagazinNou*.pdf" \
--output backend/modules/data_entry/services/ocr/profiles/magazin_nou.py
# 3. Hot-reload (fără restart server)
curl -X POST http://localhost:8000/api/data-entry/ocr/profiles/reload
# 4. Verifică
curl http://localhost:8000/api/data-entry/ocr/profiles
```
---
## Structura directorului
```
profiles/
├── __init__.py # ProfileRegistry + hot-reload (~390 linii)
├── base.py # BaseStoreProfile + pattern-uri generice (~410 linii)
├── lidl.py # Multi-rate TVA (A/B)
├── omv.py # B2B, date YYYY.MM.DD
├── socar.py # B2B, date YYYY.MM.DD
├── brick.py # Standard TVA
├── dedeman.py # E-factura support
├── kineterra.py # Non-VAT payer
├── gama_ink.py # Standard TVA (toner/cartușe)
├── electrobering.py # Standard TVA (electronice)
├── pictus_velum.py # Standard TVA (rechizite)
├── unlimited_keys.py # Standard TVA, NUMERAR payment
├── best_print.py # Non-VAT payer (neplătitor TVA)
├── stepout_market.py # TVA 5% (cărți/librărie)
└── README.md # Acest fișier
```
---
## Profile existente (12 profile)
> **Note**: Pattern-urile TVA sunt **flexibile** și acceptă ORICE cotă (5%, 9%, 11%, 19%, 21%, etc.)
> pentru a gestiona atât datele istorice cât și schimbările viitoare ale legislației.
| Magazin | CUI | Fișier | Caracteristici |
|---------|-----|--------|----------------|
| LIDL DISCOUNT S.R.L. | 22891860 | `lidl.py` | Multi-rate TVA (coduri A, B, C, D) |
| OMV PETROM MARKETING S.R.L. | 11201891 | `omv.py` | B2B (client CUI), date YYYY.MM.DD |
| SOCAR PETROLEUM S.A. | 12546600 | `socar.py` | B2B (client CUI), date YYYY.MM.DD |
| FIVE-HOLDING S.A. (BRICK) | 10562600 | `brick.py` | Standard TVA |
| DEDEMAN SRL | 2816464 | `dedeman.py` | E-factura support |
| KINETERRA CONCEPT SRL | 31180432 | `kineterra.py` | Non-VAT payer (returnează `[]`) |
| GAMA INK SERVICE SRL | 17741882 | `gama_ink.py` | Standard TVA (toner, cartușe) |
| ELECTROBERING S.R.L. | 2744937 | `electrobering.py` | Standard TVA (electronice) |
| PICTUS VELUM SRL | 39634534 | `pictus_velum.py` | Standard TVA (rechizite) |
| UNLIMITED KEYS S.R.L. | 18993187 | `unlimited_keys.py` | Standard TVA, **NUMERAR** plată |
| BEST PRINT TRADE ACTIV SRL | 45417955 | `best_print.py` | **Non-VAT payer** (neplătitor TVA) |
| STEPOUT MARKET SRL | 35532655 | `stepout_market.py` | TVA 5% (cărți, librărie) |
---
## API Endpoints
| Endpoint | Metodă | Descriere |
|----------|--------|-----------|
| `/api/data-entry/ocr/profiles` | GET | Lista toate profilele |
| `/api/data-entry/ocr/profiles/{cui}` | GET | Detalii profil (acceptă RO prefix) |
| `/api/data-entry/ocr/profiles/reload` | POST | Hot-reload toate profilele |
### Exemple API
```bash
# Lista profile
curl http://localhost:8000/api/data-entry/ocr/profiles \
-H "Authorization: Bearer <token>"
# Detalii profil (cu sau fără RO prefix)
curl http://localhost:8000/api/data-entry/ocr/profiles/22891860
curl http://localhost:8000/api/data-entry/ocr/profiles/RO22891860
# Hot-reload după modificări
curl -X POST http://localhost:8000/api/data-entry/ocr/profiles/reload \
-H "Authorization: Bearer <token>"
# Response reload:
{
"success": true,
"reloaded_modules": 12,
"profiles_count": 12,
"registered_cuis": ["22891860", "11201891", "12546600", "10562600", ...],
"last_reload": "2026-01-06T22:37:05.000000"
}
```
---
## Cum funcționează sistemul
### Flow de extracție
```
ReceiptExtractor.extract()
├─► STEP 1: Extrage vendor + CUI
│ └─► _extract_vendor(), _extract_cui()
├─► ProfileRegistry.get_profile(cui)
│ └─► Returnează profil specific sau None
├─► STEP 2: Extracție cu profil (dacă există)
│ ├─► profile.extract_total()
│ ├─► profile.extract_date()
│ ├─► profile.extract_receipt_number()
│ ├─► profile.extract_tva_entries()
│ ├─► profile.extract_payment_methods()
│ └─► profile.extract_client_cui()
└─► STEP 3-4: Validare + post-procesare
```
### Fallback
Dacă nu există profil pentru CUI, se folosește logica generică din `ReceiptExtractor`.
---
## Structura unui profil
```python
from .base import BaseStoreProfile
from . import ProfileRegistry
@ProfileRegistry.register
class MagazinNouProfile(BaseStoreProfile):
"""Docstring cu descriere magazin."""
CUI_LIST = ["12345678"] # Poate avea mai multe CUI-uri
NAME_PATTERNS = ["MAGAZIN", "MAGAZIN NOU", "MAG4ZIN"] # OCR variants
STORE_NAME = "Magazin Nou SRL"
# Override doar ce e diferit de base class
def extract_tva_entries(self, text: str) -> List[dict]:
# Pattern-uri specifice magazinului
...
def get_validation_hints(self) -> Dict[str, Any]:
return {
"has_multi_rate_tva": False,
"card_equals_total": True,
"has_client_cui": False,
"has_efactura": False,
"is_non_vat_payer": False,
}
```
---
## Pattern-uri disponibile în base.py
BaseStoreProfile include pattern-uri generice OCR-tolerant:
| Pattern | Descriere |
|---------|-----------|
| `TOTAL_PATTERNS` | 8 variante pentru TOTAL (TOTAL:, TOTAL DE PLATA, etc.) |
| `DATE_PATTERNS` | 6 variante (DD.MM.YYYY, YYYY-MM-DD, DD/MM/YYYY) |
| `DATE_PATTERNS_OCR_SPACES` | 4 variante cu spații OCR ("2025. 08. 14") |
| `NUMBER_PATTERNS` | 11 variante pentru număr bon (NDS, BF, C3POS) |
| `PAYMENT_PATTERNS` | 8 variante pentru CARD/NUMERAR |
| `CLIENT_MARKERS` | 6 variante pentru secțiune CLIENT |
| `CLIENT_CUI_PATTERNS` | 7 variante pentru CUI client |
### Metode implementate în base class
- `extract_total(text)``Tuple[Decimal, float]`
- `extract_date(text)``Tuple[date, float]`
- `extract_receipt_number(text)``Tuple[str, float]`
- `extract_payment_methods(text)``List[dict]`
- `extract_client_cui(text)``Tuple[str, float]`
- `extract_client_name(text)``Tuple[str, float]`
---
## Când ai nevoie de profil custom?
| Situație | Exemplu | Ce trebuie override |
|----------|---------|---------------------|
| **Multi-rate TVA** | Lidl (TVA A, TVA B) | `extract_tva_entries()` |
| **Format dată special** | OMV/Socar (YYYY.MM.DD) | `DATE_PATTERNS_OCR_SPACES` |
| **B2B receipts** | Benzinării (au client CUI) | `extract_client_cui()` |
| **Non-VAT payer** | Kineterra | `extract_tva_entries()` returnează `[]` |
| **E-factura** | Dedeman | `extract_efactura_reference()` |
---
## Decizii de design
1. **Hot-reload manual** - endpoint `/profiles/reload` apelat când se modifică fișiere
2. **Persistență în Python** - profile în Git, version controlled
3. **Fallback graceful** - dacă nu există profil, folosește logica generică
4. **CUI normalization** - gestionează automat prefixul "RO" și whitespace
5. **Deduplicare TVA** - folosește `seen = set()` pentru a evita duplicate
---
## Comenzi utile
```bash
# Verifică syntax Python pentru toate profilele
for f in backend/modules/data_entry/services/ocr/profiles/*.py; do
python3 -m py_compile "$f" && echo "✓ $(basename $f)"
done
# Lista profile
ls -la backend/modules/data_entry/services/ocr/profiles/
# Pornește backend pentru testare
cd backend && source venv/bin/activate
uvicorn main:app --host 0.0.0.0 --port 8000 --workers 1
# Test OCR pe un PDF
curl -X POST -F "file=@docs/data-entry/test.pdf" \
-H "Authorization: Bearer <token>" \
"http://localhost:8000/api/data-entry/ocr/extract?engine=doctr_plus"
```
---
## Script generare profile
`scripts/generate_store_profile.py` - generator automat de profile
```bash
# Vezi help
python scripts/generate_store_profile.py --help
# Funcționalități:
# - Analizează PDF-uri via OCR API
# - Detectează: TVA format, date format, payment patterns, B2B
# - Generează cod Python cu OCR error variants
# - Suportă glob patterns (*.pdf)
# - Verifică sintaxa după generare
```

View File

@@ -0,0 +1,398 @@
"""
Store Profiles Registry with Hot-Reload Support.
This module provides a registry for store-specific OCR extraction profiles.
Profiles can be reloaded at runtime without restarting the server.
Usage:
from backend.modules.data_entry.services.ocr.profiles import ProfileRegistry
# Get profile for a CUI
profile = ProfileRegistry.get_profile("22891860")
if profile:
tva_entries = profile.extract_tva_entries(text)
# Reload all profiles (after file changes)
count = ProfileRegistry.reload_all()
Architecture:
- ProfileRegistry: Singleton registry with class methods
- BaseStoreProfile: Abstract base class for profiles
- @ProfileRegistry.register: Decorator for profile classes
Hot-Reload Mechanism:
1. Admin calls POST /profiles/reload endpoint
2. Registry clears instance cache
3. importlib.reload() re-executes each profile module
4. @register decorator re-registers classes with new code
"""
from __future__ import annotations
import importlib
import logging
import sys
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Type, TYPE_CHECKING
if TYPE_CHECKING:
from .base import BaseStoreProfile
logger = logging.getLogger(__name__)
# Directory containing profile modules
PROFILES_DIR = Path(__file__).parent
class ProfileRegistry:
"""
Registry for store-specific OCR extraction profiles.
Uses class methods for singleton-like behavior without explicit instantiation.
Supports hot-reload via importlib.reload() for runtime updates.
Attributes:
_profiles: Maps CUI -> profile class (not instance)
_instances: Maps CUI -> profile instance (lazy, cleared on reload)
_last_reload: Timestamp of last reload
_loaded: Whether initial load has been performed
"""
# Class-level storage (singleton pattern via class methods)
_profiles: Dict[str, Type["BaseStoreProfile"]] = {}
_instances: Dict[str, "BaseStoreProfile"] = {}
_last_reload: Optional[datetime] = None
_loaded: bool = False
# -------------------------------------------------------------------------
# Registration
# -------------------------------------------------------------------------
@classmethod
def register(cls, profile_class: Type["BaseStoreProfile"]) -> Type["BaseStoreProfile"]:
"""
Decorator to register a store profile class.
Registers the profile for all CUIs in the class's CUI_LIST.
Safe for re-registration during hot-reload (overwrites existing).
Usage:
@ProfileRegistry.register
class LidlProfile(BaseStoreProfile):
CUI_LIST = ["22891860"]
...
Args:
profile_class: Profile class to register
Returns:
The same class (allows use as decorator)
Raises:
ValueError: If CUI_LIST is empty
"""
cui_list = getattr(profile_class, 'CUI_LIST', [])
store_name = getattr(profile_class, 'STORE_NAME', profile_class.__name__)
if not cui_list:
logger.warning(f"Profile {profile_class.__name__} has empty CUI_LIST, skipping")
return profile_class
# Register for each CUI
for cui in cui_list:
# Normalize CUI (remove RO prefix, strip whitespace)
normalized_cui = cls._normalize_cui(cui)
if normalized_cui in cls._profiles:
old_class = cls._profiles[normalized_cui]
logger.debug(
f"Re-registering CUI {normalized_cui}: "
f"{old_class.__name__} -> {profile_class.__name__}"
)
# Clear cached instance for this CUI
cls._instances.pop(normalized_cui, None)
cls._profiles[normalized_cui] = profile_class
logger.debug(f"Registered profile {profile_class.__name__} for CUI {normalized_cui}")
logger.info(f"Registered {store_name} for CUIs: {cui_list}")
return profile_class
# -------------------------------------------------------------------------
# Lookup
# -------------------------------------------------------------------------
@classmethod
def get_profile(cls, cui: Optional[str]) -> Optional["BaseStoreProfile"]:
"""
Get profile instance for a CUI.
Uses lazy instantiation - creates instance on first access.
Returns None if no profile is registered for this CUI.
Args:
cui: CUI to lookup (with or without RO prefix)
Returns:
Profile instance or None
"""
if not cui:
return None
# Ensure profiles are loaded
if not cls._loaded:
cls._load_all_profiles()
normalized_cui = cls._normalize_cui(cui)
# Check if profile exists
profile_class = cls._profiles.get(normalized_cui)
if not profile_class:
return None
# Lazy instantiation
if normalized_cui not in cls._instances:
try:
cls._instances[normalized_cui] = profile_class()
logger.debug(f"Instantiated {profile_class.__name__} for CUI {normalized_cui}")
except Exception as e:
logger.error(f"Failed to instantiate {profile_class.__name__}: {e}")
return None
return cls._instances[normalized_cui]
@classmethod
def has_profile(cls, cui: Optional[str]) -> bool:
"""Check if a profile exists for this CUI."""
if not cui:
return False
if not cls._loaded:
cls._load_all_profiles()
return cls._normalize_cui(cui) in cls._profiles
# -------------------------------------------------------------------------
# Listing
# -------------------------------------------------------------------------
@classmethod
def list_profiles(cls) -> List[Dict]:
"""
List all registered profiles.
Returns:
List of dicts with cui, class_name, store_name, name_patterns
"""
if not cls._loaded:
cls._load_all_profiles()
result = []
seen_classes = set()
for cui, profile_class in cls._profiles.items():
# Avoid duplicates for profiles with multiple CUIs
if profile_class.__name__ in seen_classes:
continue
seen_classes.add(profile_class.__name__)
result.append({
"cuis": list(getattr(profile_class, 'CUI_LIST', [])),
"class_name": profile_class.__name__,
"store_name": getattr(profile_class, 'STORE_NAME', profile_class.__name__),
"name_patterns": list(getattr(profile_class, 'NAME_PATTERNS', [])),
})
return result
@classmethod
def get_profile_info(cls, cui: str) -> Optional[Dict]:
"""
Get detailed info about a profile.
Args:
cui: CUI to lookup
Returns:
Dict with profile details or None
"""
profile = cls.get_profile(cui)
if not profile:
return None
return {
"cui": cui,
"cuis": list(profile.CUI_LIST),
"class_name": profile.__class__.__name__,
"store_name": profile.STORE_NAME,
"name_patterns": list(profile.NAME_PATTERNS),
"validation_hints": profile.get_validation_hints(),
}
# -------------------------------------------------------------------------
# Hot-Reload
# -------------------------------------------------------------------------
@classmethod
def reload_all(cls) -> int:
"""
Hot-reload all profile modules.
Clears instance cache and reloads all .py files in profiles directory.
Decorator re-registers classes with updated code.
Returns:
Number of modules reloaded
"""
logger.info("Starting profile hot-reload...")
# Clear instance cache (will be recreated on next get_profile)
cls._instances.clear()
# Get list of profile modules (exclude __init__, base)
module_names = cls._get_profile_module_names()
# Determine the module prefix based on how THIS module was imported
base_package = cls.__module__
count = 0
for module_name in module_names:
full_name = f"{base_package}.{module_name}"
try:
if full_name in sys.modules:
# Reload existing module
importlib.reload(sys.modules[full_name])
logger.debug(f"Reloaded module: {module_name}")
else:
# Import new module
importlib.import_module(full_name)
logger.debug(f"Imported new module: {module_name}")
count += 1
except Exception as e:
logger.error(f"Failed to reload {module_name}: {e}")
cls._last_reload = datetime.utcnow()
cls._loaded = True
logger.info(f"Profile hot-reload complete: {count} modules, {len(cls._profiles)} profiles")
return count
@classmethod
def get_reload_status(cls) -> Dict:
"""Get status of the registry including last reload time."""
return {
"loaded": cls._loaded,
"last_reload": cls._last_reload.isoformat() if cls._last_reload else None,
"profiles_count": len(cls._profiles),
"instances_count": len(cls._instances),
"registered_cuis": list(cls._profiles.keys()),
}
# -------------------------------------------------------------------------
# Internal methods
# -------------------------------------------------------------------------
@classmethod
def _normalize_cui(cls, cui: str) -> str:
"""
Normalize CUI for consistent lookup.
- Removes RO prefix (with or without space)
- Strips whitespace
- Converts to uppercase
Args:
cui: Raw CUI string
Returns:
Normalized CUI (digits only)
"""
if not cui:
return ""
cui = str(cui).strip().upper()
# Remove RO prefix (handles "RO12345" and "RO 12345")
if cui.startswith("RO"):
cui = cui[2:].lstrip()
return cui.strip()
@classmethod
def _get_profile_module_names(cls) -> List[str]:
"""
Get list of profile module names from profiles directory.
Excludes __init__.py and base.py.
Returns:
List of module names (without .py extension)
"""
excluded = {"__init__", "base", "__pycache__"}
modules = []
for path in PROFILES_DIR.glob("*.py"):
name = path.stem
if name not in excluded:
modules.append(name)
return sorted(modules)
@classmethod
def _load_all_profiles(cls) -> None:
"""
Initial load of all profile modules.
Called automatically on first get_profile() if not already loaded.
"""
if cls._loaded:
return
logger.info("Loading store profiles...")
module_names = cls._get_profile_module_names()
# Determine the module prefix based on how THIS module was imported
# This handles both:
# - Running from backend dir: "modules.data_entry.services.ocr.profiles"
# - Running from project root: "backend.modules.data_entry.services.ocr.profiles"
this_module = cls.__module__ # e.g. "backend.modules..." or "modules..."
base_package = this_module # Use the same prefix for child modules
for module_name in module_names:
full_name = f"{base_package}.{module_name}"
try:
importlib.import_module(full_name)
logger.debug(f"Loaded module: {module_name}")
except Exception as e:
logger.error(f"Failed to load {module_name}: {e}")
cls._loaded = True
cls._last_reload = datetime.utcnow()
logger.info(f"Loaded {len(cls._profiles)} store profiles")
@classmethod
def clear(cls) -> None:
"""
Clear all registered profiles.
Mainly useful for testing.
"""
cls._profiles.clear()
cls._instances.clear()
cls._loaded = False
cls._last_reload = None
# -------------------------------------------------------------------------
# Module exports
# -------------------------------------------------------------------------
__all__ = [
"ProfileRegistry",
"BaseStoreProfile",
]
# Re-export BaseStoreProfile for convenience
from .base import BaseStoreProfile

View File

@@ -0,0 +1,655 @@
"""
Optimized Tesseract Engine for OCR - SPEED + QUALITY OPTIMIZED
Performance optimizations (vs previous version):
- Single PSM mode (PSM 4) instead of multi-PSM (4 modes × 2 calls = 8x faster)
- Single Tesseract call per image (skip image_to_data for speed)
- Lighter preprocessing (no over-binarization)
- --dpi 300 flag for proper scaling
- OEM 3 (default LSTM+Legacy) for balanced speed/accuracy
Quality optimizations for Romanian receipts:
- PSM 4: Single column layout (optimal for receipts)
- Polarity correction: ensures black text on white background
- Language: Romanian only (-l ron) for faster recognition
- Fallback to PSM 6 if PSM 4 produces poor results
Previous issues fixed:
- Was 8x slower than PaddleOCR due to multi-PSM + dual calls
- Produced gibberish on clear PDFs due to over-binarization
"""
import logging
import os
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
import cv2
import numpy as np
# Check Tesseract availability
try:
import pytesseract
TESSERACT_AVAILABLE = True
except ImportError:
TESSERACT_AVAILABLE = False
pytesseract = None
logger = logging.getLogger(__name__)
@dataclass
class OCRResult:
"""Raw OCR result from Tesseract."""
text: str
confidence: float
boxes: List[dict] = field(default_factory=list)
engine: str = "tesseract"
class TesseractEngine:
"""
Optimized Tesseract engine for receipt OCR.
TESTED OPTIMAL SETTINGS (from comprehensive benchmark):
- DPI 200 for PDF loading (not 300!)
- Padding 40px for edge protection
- PSM 6 for complex receipts, PSM 4 for simple ones
- Multi-pass strategy when quality is critical
SPEED vs QUALITY tradeoff:
- Fast mode (single pass): ~0.9s, ~6-7 keywords
- Quality mode (multi-pass): ~1.7s, ~8-9 keywords (+2 more keywords)
BENCHMARK RESULTS:
- padded_psm6_40: Best for complex receipts (igiena, five-holding)
- baseline_psm4: Best for simple receipts (rechizite, benzina)
- multi-pass: Best overall quality but slower
"""
# PSM modes for receipts
PSM_SINGLE_COLUMN = 4 # Best for simple vertical receipts
PSM_UNIFORM_BLOCK = 6 # Best for complex layouts
PSM_SPARSE_TEXT = 11 # Fallback for difficult receipts
# Optimal padding (from benchmark)
DEFAULT_PADDING = 40
def __init__(self):
"""Initialize Tesseract engine."""
if not TESSERACT_AVAILABLE:
raise RuntimeError("pytesseract not available. Install with: pip install pytesseract")
# Verify Tesseract installation
try:
self._version = pytesseract.get_tesseract_version()
except Exception as e:
raise RuntimeError(f"Tesseract not installed or not in PATH: {e}")
logger.info(f"[TesseractEngine] Initialized (v{self._version})")
def recognize(self, image: np.ndarray, fast_mode: bool = True) -> OCRResult:
"""
Perform OCR recognition on image (OPTIMIZED).
SPEED: Uses single PSM mode + single Tesseract call.
Previously used 4 PSM modes × 2 calls = 8 Tesseract invocations.
Now uses 1-2 calls maximum (with fallback).
Args:
image: Preprocessed grayscale image (DO NOT binarize for clear PDFs!)
fast_mode: If True, skip confidence calculation for maximum speed
Returns:
OCRResult with text and confidence
"""
if not TESSERACT_AVAILABLE:
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
# Ensure grayscale
if len(image.shape) == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
# Fix polarity (black text on white background)
image = self._ensure_correct_polarity(image)
# Try PSM 4 first (single column - best for receipts)
result = self._recognize_fast(image, self.PSM_SINGLE_COLUMN, fast_mode)
# If poor result, try PSM 6 as fallback
if not result.text.strip() or result.confidence < 0.3:
logger.debug(f"[Tesseract] PSM {self.PSM_SINGLE_COLUMN} poor result, trying PSM {self.PSM_UNIFORM_BLOCK}")
fallback = self._recognize_fast(image, self.PSM_UNIFORM_BLOCK, fast_mode)
if len(fallback.text) > len(result.text):
result = fallback
if result.text.strip():
logger.info(f"[TesseractEngine] Result: {len(result.text)} chars, conf={result.confidence:.0%}")
return result
def _recognize_fast(self, image: np.ndarray, psm: int, fast_mode: bool = True) -> OCRResult:
"""
Fast single-call Tesseract recognition.
Optimizations:
- Single call (image_to_string only in fast mode)
- OEM 3 (LSTM+Legacy) - faster than OEM 1
- --dpi 300 for proper scaling
- Romanian only (-l ron)
Args:
image: Grayscale image
psm: Page segmentation mode
fast_mode: Skip confidence calculation for speed
Returns:
OCRResult
"""
# Build optimized config:
# OEM 3 = LSTM + Legacy (faster than pure LSTM)
# --dpi 300 = proper scaling hint
# -l ron = Romanian only (faster, avoids eng confusion)
config = f'--psm {psm} --oem 3 --dpi 300 -l ron'
try:
if fast_mode:
# Fast path: just get text, estimate confidence
text = pytesseract.image_to_string(image, config=config)
# Estimate confidence based on text quality
confidence = self._estimate_confidence(text)
else:
# Accurate path: get text + real confidence
text = pytesseract.image_to_string(image, config=config)
data = pytesseract.image_to_data(
image, config=config, output_type=pytesseract.Output.DICT
)
confidences = [int(c) for c in data['conf'] if int(c) > 0]
confidence = sum(confidences) / len(confidences) / 100 if confidences else 0.0
return OCRResult(
text=text,
confidence=confidence,
boxes=[],
engine="tesseract"
)
except Exception as e:
logger.warning(f"[Tesseract] PSM {psm} error: {e}")
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
def _estimate_confidence(self, text: str) -> float:
"""
Estimate OCR confidence based on text quality.
Heuristics:
- More alphanumeric chars = higher confidence
- Less garbage chars = higher confidence
- Romanian-specific patterns boost confidence
"""
if not text.strip():
return 0.0
# Count valid vs garbage chars
valid_chars = sum(1 for c in text if c.isalnum() or c in '.,;:-/\n ')
total_chars = len(text)
if total_chars == 0:
return 0.0
# Base confidence from char ratio
confidence = valid_chars / total_chars
# Boost for Romanian receipt patterns
text_lower = text.lower()
if any(word in text_lower for word in ['total', 'lei', 'ron', 'buc', 'tva', 'cif', 'bon']):
confidence = min(confidence + 0.1, 1.0)
return confidence
def recognize_multipass(self, image: np.ndarray) -> OCRResult:
"""
Multi-pass OCR for maximum quality (slower but more accurate).
Strategy (from benchmark testing):
- Pass 1: PSM 4 (single column) - no padding, fast baseline
- Pass 2: PSM 6 (uniform block) - with 40px padding, better for complex layouts
- Pass 3: PSM 11 (sparse text) - with 40px padding + stronger CLAHE, for difficult receipts
Merges results: picks the pass with highest keyword count.
On average finds +2.1 more keywords than single-pass (~8.7 vs 6.6).
Time: ~1.7s (vs ~0.9s for single pass)
Args:
image: Input image (RGB or grayscale)
Returns:
OCRResult from the best pass
"""
if not TESSERACT_AVAILABLE:
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
# Ensure grayscale
if len(image.shape) == 3:
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
else:
gray = image.copy()
# Define passes with different settings
passes = [
# Pass 1: Fast baseline (no padding) - good for simple receipts
{"name": "pass1_psm4", "psm": 4, "padding": 0, "clahe_clip": 1.5},
# Pass 2: Padded PSM 6 - good for complex receipts
{"name": "pass2_psm6_padded", "psm": 6, "padding": 40, "clahe_clip": 1.5},
# Pass 3: Sparse text with stronger enhancement - for difficult cases
{"name": "pass3_psm11", "psm": 11, "padding": 40, "clahe_clip": 2.0},
]
best_result = None
best_score = -1
all_keywords = set()
for p in passes:
# Apply preprocessing for this pass
processed = gray.copy()
# Add padding if specified
if p["padding"] > 0:
processed = cv2.copyMakeBorder(
processed, p["padding"], p["padding"], p["padding"], p["padding"],
cv2.BORDER_CONSTANT, value=255
)
# Apply CLAHE
clahe = cv2.createCLAHE(clipLimit=p["clahe_clip"], tileGridSize=(8, 8))
processed = clahe.apply(processed)
# Ensure correct polarity
processed = self._ensure_correct_polarity(processed)
# Run OCR
config = f'--psm {p["psm"]} --oem 3 -l ron'
try:
text = pytesseract.image_to_string(processed, config=config)
confidence = self._estimate_confidence(text)
# Score based on Romanian receipt keywords
text_lower = text.lower()
keywords = ['cif', 'total', 'tva', 'lei', 'ron', 'buc', 'fiscal', 'bon',
'hartie', 'prosop', 'saci', 'creion', 'constanta', 'bucuresti']
found_keywords = [kw for kw in keywords if kw in text_lower]
all_keywords.update(found_keywords)
# Score: keywords + CIF bonus + TOTAL bonus
score = len(found_keywords) * 10
if self._has_cif_pattern(text):
score += 15
if self._has_total_pattern(text):
score += 10
logger.debug(f"[Tesseract] {p['name']}: {len(found_keywords)} keywords, score={score}")
if score > best_score:
best_score = score
best_result = OCRResult(
text=text,
confidence=confidence,
boxes=[],
engine=f"tesseract-multipass-{p['name']}"
)
except Exception as e:
logger.warning(f"[Tesseract] {p['name']} failed: {e}")
continue
if best_result:
logger.info(f"[TesseractEngine] Multi-pass best: {best_result.engine}, "
f"{len(all_keywords)} total keywords found")
return best_result
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract-multipass")
def _has_cif_pattern(self, text: str) -> bool:
"""Check if text contains a valid CIF/CUI pattern."""
import re
text_upper = text.upper()
patterns = [
r'CIF[:\s]*RO?\d{6,10}',
r'CUI[:\s]*RO?\d{6,10}',
r'C\.?I\.?F\.?[:\s]*RO?\d{6,10}',
]
for pattern in patterns:
if re.search(pattern, text_upper):
return True
return bool(re.search(r'RO\d{7,10}', text_upper))
def _has_total_pattern(self, text: str) -> bool:
"""Check if TOTAL is properly recognized (not truncated to BTOTAL/OTAL)."""
import re
text_upper = text.upper()
return bool(re.search(r'(^|\s)TOTAL\s', text_upper, re.MULTILINE))
def recognize_with_boxes(self, image: np.ndarray, psm: int = 4) -> OCRResult:
"""
Recognition with bounding boxes (slower, for debugging/visualization).
Use this only when you need box coordinates.
For normal OCR, use recognize() which is faster.
Args:
image: Grayscale image
psm: Page segmentation mode (default: 4 for receipts)
Returns:
OCRResult with text, confidence, and boxes
"""
if not TESSERACT_AVAILABLE:
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
# Ensure grayscale
if len(image.shape) == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
image = self._ensure_correct_polarity(image)
config = f'--psm {psm} --oem 3 --dpi 300 -l ron'
try:
text = pytesseract.image_to_string(image, config=config)
data = pytesseract.image_to_data(
image, config=config, output_type=pytesseract.Output.DICT
)
confidences = [int(c) for c in data['conf'] if int(c) > 0]
avg_conf = sum(confidences) / len(confidences) / 100 if confidences else 0.0
boxes = []
for i in range(len(data['text'])):
if data['text'][i].strip() and int(data['conf'][i]) > 0:
boxes.append({
'text': data['text'][i],
'confidence': int(data['conf'][i]) / 100,
'box': [data['left'][i], data['top'][i], data['width'][i], data['height'][i]]
})
return OCRResult(text=text, confidence=avg_conf, boxes=boxes, engine="tesseract")
except Exception as e:
logger.warning(f"[Tesseract] recognize_with_boxes error: {e}")
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
def _ensure_correct_polarity(self, image: np.ndarray) -> np.ndarray:
"""
Ensure image has black text on white background.
Receipts should have dark text on light background.
If image is inverted (light text on dark), invert it.
Detection method:
- Calculate mean pixel value
- If mean < 127, image is mostly dark (inverted)
- Invert to correct polarity
Args:
image: Grayscale image
Returns:
Polarity-corrected image
"""
mean_value = np.mean(image)
if mean_value < 127:
# Image is mostly dark = inverted (white text on black)
logger.debug(f"[TesseractEngine] Detected inverted polarity (mean={mean_value:.1f}), correcting...")
return 255 - image
return image
def recognize_numbers_only(self, image: np.ndarray) -> OCRResult:
"""
OCR optimized for numeric content (amounts, totals).
Uses character whitelist to reduce errors on numbers.
Args:
image: Preprocessed grayscale image
Returns:
OCRResult with numeric text
"""
if not TESSERACT_AVAILABLE:
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
# Ensure grayscale
if len(image.shape) == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
# Fix polarity
image = self._ensure_correct_polarity(image)
# Config for numbers only
# Whitelist: digits, comma, period, space, RON, LEI
config = '--psm 6 --oem 1 -c tessedit_char_whitelist=0123456789.,- '
try:
text = pytesseract.image_to_string(image, config=config)
data = pytesseract.image_to_data(
image,
config=config,
output_type=pytesseract.Output.DICT
)
confidences = [int(c) for c in data['conf'] if int(c) > 0]
avg_conf = sum(confidences) / len(confidences) / 100 if confidences else 0.0
return OCRResult(
text=text.strip(),
confidence=avg_conf,
boxes=[],
engine="tesseract-numeric"
)
except Exception as e:
logger.error(f"[TesseractEngine] Numeric OCR error: {e}")
return OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
def recognize_cif_optimized(self, image: np.ndarray) -> Optional[str]:
"""
Optimized CIF extraction using multi-strategy approach.
BENCHMARK RESULTS (from test_critical_fields.py):
- digit_opt_dpi200: 33% accuracy (best)
- digit_whitelist: Works well on specific receipts
- basic_ron_eng: Good backup
Strategy:
1. Try digit-optimized preprocessing (2x scale + Otsu)
2. Try character whitelist (RO + digits only)
3. Try standard ron+eng config
4. Return best match based on CIF pattern validation
Args:
image: Input image (RGB from pdf2image or BGR from OpenCV)
Returns:
Extracted CIF string (e.g., "RO10562600") or None
"""
import re
if not TESSERACT_AVAILABLE:
return None
# Ensure grayscale
if len(image.shape) == 3:
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
else:
gray = image.copy()
# Extract top 35% of image (where CIF is typically found)
height = gray.shape[0]
top_region = gray[:int(height * 0.35), :]
candidates = []
# Strategy 1: Digit-optimized preprocessing (best performer: 33% accuracy)
try:
# Scale up 2x + Otsu binarization
scaled = cv2.resize(top_region, None, fx=2.0, fy=2.0, interpolation=cv2.INTER_CUBIC)
clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
enhanced = clahe.apply(scaled)
_, binary = cv2.threshold(enhanced, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
if np.mean(binary) < 127:
binary = 255 - binary
text = pytesseract.image_to_string(binary, config='--psm 6 --oem 3 -l ron')
cif = self._extract_cif_from_text(text)
if cif:
candidates.append(('digit_opt', cif))
except Exception as e:
logger.debug(f"[TesseractEngine] digit_opt strategy failed: {e}")
# Strategy 2: Character whitelist (RO + digits only)
try:
# Add padding
padded = cv2.copyMakeBorder(top_region, 40, 40, 40, 40, cv2.BORDER_CONSTANT, value=255)
scaled = cv2.resize(padded, None, fx=2.0, fy=2.0, interpolation=cv2.INTER_CUBIC)
config = '--psm 6 --oem 1 -c tessedit_char_whitelist=0123456789ROro'
text = pytesseract.image_to_string(scaled, config=config)
cif = self._extract_cif_from_text(text)
if cif:
candidates.append(('whitelist', cif))
except Exception as e:
logger.debug(f"[TesseractEngine] whitelist strategy failed: {e}")
# Strategy 3: Standard ron+eng config (good backup)
try:
padded = cv2.copyMakeBorder(top_region, 40, 40, 40, 40, cv2.BORDER_CONSTANT, value=255)
clahe = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(8, 8))
enhanced = clahe.apply(padded)
text = pytesseract.image_to_string(enhanced, config='--psm 6 --oem 3 -l ron+eng')
cif = self._extract_cif_from_text(text)
if cif:
candidates.append(('ron_eng', cif))
except Exception as e:
logger.debug(f"[TesseractEngine] ron_eng strategy failed: {e}")
if not candidates:
return None
# Log all candidates
for strategy, cif in candidates:
logger.debug(f"[TesseractEngine] CIF candidate from {strategy}: {cif}")
# Use majority voting if multiple strategies agree
from collections import Counter
cif_counts = Counter(cif for _, cif in candidates)
most_common_cif, count = cif_counts.most_common(1)[0]
if count > 1:
# Multiple strategies agree
logger.info(f"[TesseractEngine] CIF extracted (majority {count} strategies): {most_common_cif}")
return most_common_cif
# No agreement - prefer digit_opt strategy (33% accuracy in benchmarks)
for strategy, cif in candidates:
if strategy == 'digit_opt':
logger.info(f"[TesseractEngine] CIF extracted via digit_opt (preferred): {cif}")
return cif
# Fallback to first candidate
strategy, cif = candidates[0]
logger.info(f"[TesseractEngine] CIF extracted via {strategy}: {cif}")
return cif
def _extract_cif_from_text(self, text: str) -> Optional[str]:
"""Extract CIF/CUI from OCR text."""
import re
text_upper = text.upper().replace(' ', '')
patterns = [
r'CIF[:\s]*R?O?(\d{6,10})',
r'CUI[:\s]*R?O?(\d{6,10})',
r'C\.?I\.?F\.?[:\s]*R?O?(\d{6,10})',
r'RO(\d{7,10})',
r'R\.?O\.?[\s:]*(\d{6,10})',
]
for pattern in patterns:
match = re.search(pattern, text_upper)
if match:
digits = match.group(1).lstrip('0') or '0'
return f"RO{digits}"
return None
@staticmethod
def validate_romanian_cif(cif: str) -> bool:
"""
Validate Romanian CIF/CUI using checksum algorithm.
Romanian CIF format: RO + 2-10 digits
The last digit is a control digit calculated using modulo 11.
Algorithm:
1. Multiply each digit by corresponding weight (from right to left: 2,3,4,5,6,7,2,3,4,5)
2. Sum all products
3. Remainder of sum / 11 is the control digit
4. If remainder is 10, control digit is 0
Args:
cif: CIF string (e.g., "RO10562600", "10562600")
Returns:
True if CIF is valid, False otherwise
"""
# Remove RO prefix and spaces
cif = cif.upper().replace(' ', '').replace('RO', '')
# Must be 2-10 digits
if not cif.isdigit() or len(cif) < 2 or len(cif) > 10:
return False
# Weights for checksum calculation (right to left)
weights = [2, 3, 4, 5, 6, 7, 2, 3, 4, 5]
# Pad with zeros on the left to make it 10 digits
cif_padded = cif.zfill(10)
# Calculate checksum (excluding last digit which is control)
total = 0
for i in range(9):
total += int(cif_padded[i]) * weights[i]
# Control digit
control = total % 11
if control == 10:
control = 0
# Compare with last digit
return int(cif_padded[9]) == control
@staticmethod
def is_available() -> bool:
"""Check if Tesseract is available."""
if not TESSERACT_AVAILABLE:
return False
try:
pytesseract.get_tesseract_version()
return True
except Exception:
return False
@staticmethod
def get_version() -> Optional[str]:
"""Get Tesseract version string."""
if not TESSERACT_AVAILABLE:
return None
try:
return str(pytesseract.get_tesseract_version())
except Exception:
return None

View File

@@ -0,0 +1,476 @@
"""OCR engine wrapper for PaddleOCR, docTR, and Tesseract."""
import os
import logging
import threading
import time
from dataclasses import dataclass
from typing import List, Optional, Tuple
import numpy as np
# Setup logging (respects LOG_LEVEL env var set in main.py)
logger = logging.getLogger(__name__)
# Disable PaddleOCR model source check for faster startup (PaddleX 3.x)
os.environ['PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK'] = 'True'
# Lazy imports - these will be imported on first use
PaddleOCR = None # Will be imported lazily
pytesseract = None # Will be imported lazily
doctr_ocr_predictor = None # Will be imported lazily
# Check availability without importing heavy libraries
def _check_paddle_available() -> bool:
"""Check if paddleocr is installed without importing it."""
try:
import importlib.util
return importlib.util.find_spec("paddleocr") is not None
except Exception:
return False
def _check_tesseract_available() -> bool:
"""Check if pytesseract is installed without importing it."""
try:
import importlib.util
return importlib.util.find_spec("pytesseract") is not None
except Exception:
return False
def _check_doctr_available() -> bool:
"""Check if doctr is installed without importing it."""
try:
import importlib.util
return importlib.util.find_spec("doctr") is not None
except Exception:
return False
PADDLE_AVAILABLE = _check_paddle_available()
TESSERACT_AVAILABLE = _check_tesseract_available()
DOCTR_AVAILABLE = _check_doctr_available()
@dataclass
class OCRResult:
"""Raw OCR result."""
text: str
confidence: float
boxes: List[dict]
engine: str = "" # OCR engine used: paddleocr or tesseract
class OCREngine:
"""Unified OCR engine with fallback support."""
def __init__(self):
self._paddle = None
self._paddle_init_started = False
self._paddle_ready = threading.Event() # Signals when PaddleOCR is FULLY ready
self._paddle_init_lock = threading.Lock()
self._doctr = None
self._doctr_init_started = False
self._doctr_ready = threading.Event() # Signals when docTR is FULLY ready
self._doctr_init_lock = threading.Lock()
def _init_paddle_lazy(self):
"""Lazy initialize PaddleOCR on first use (avoids slow startup)."""
global PaddleOCR
with self._paddle_init_lock:
if self._paddle_init_started:
return # Already initializing or done
self._paddle_init_started = True
if PADDLE_AVAILABLE:
try:
print("Importing PaddleOCR (first use, may take ~15-20 seconds)...", flush=True)
from paddleocr import PaddleOCR as _PaddleOCR
PaddleOCR = _PaddleOCR
print("Initializing PaddleOCR engine...", flush=True)
# PaddleOCR 3.x API - optimized for Romanian receipts
# Note: 'latin' not available in PaddleOCR 3.x, 'en' works well for receipts
self._paddle = PaddleOCR(
lang='en', # 'en' handles Latin alphabet well for receipts
# High quality settings for better accuracy
det_db_thresh=0.3, # Lower threshold = detect more text (default 0.3)
det_db_box_thresh=0.5, # Box confidence threshold (default 0.5)
det_db_unclip_ratio=1.8, # Expand detected boxes slightly (default 1.5)
rec_batch_num=6, # Batch size for recognition
use_angle_cls=True, # Enable text angle classification
)
print("PaddleOCR initialized successfully with high-quality settings", flush=True)
except Exception as e:
print(f"Warning: Failed to initialize PaddleOCR: {e}", flush=True)
self._paddle = None
# Signal that initialization is complete (success or failure)
self._paddle_ready.set()
def _init_doctr_lazy(self):
"""Lazy initialize docTR on first use (avoids slow startup)."""
global doctr_ocr_predictor
with self._doctr_init_lock:
if self._doctr_init_started:
return # Already initializing or done
self._doctr_init_started = True
if DOCTR_AVAILABLE:
try:
print("Importing docTR (first use, may take ~10-15 seconds)...", flush=True)
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
print("Initializing docTR engine (PyTorch backend)...", flush=True)
# Initialize docTR predictor with pretrained models
# Uses db_resnet50 for detection and crnn_vgg16_bn for recognition
self._doctr = ocr_predictor(
det_arch='db_resnet50',
reco_arch='crnn_vgg16_bn',
pretrained=True,
assume_straight_pages=True,
straighten_pages=False,
preserve_aspect_ratio=True,
)
doctr_ocr_predictor = self._doctr
print("docTR initialized successfully with PyTorch backend", flush=True)
except Exception as e:
print(f"Warning: Failed to initialize docTR: {e}", flush=True)
self._doctr = None
# Signal that initialization is complete (success or failure)
self._doctr_ready.set()
def wait_for_doctr(self, timeout: float = 30.0) -> bool:
"""
Wait for docTR to be fully initialized.
Args:
timeout: Max seconds to wait (default 30s)
Returns:
True if docTR is ready, False if timeout or unavailable
"""
if not DOCTR_AVAILABLE:
return False
if self._doctr is not None:
return True # Already ready
if not self._doctr_init_started:
# Start initialization if not already started
self._init_doctr_lazy()
# Wait for initialization to complete
print(f"[OCR] Waiting for docTR to be ready (max {timeout}s)...", flush=True)
start = time.time()
ready = self._doctr_ready.wait(timeout=timeout)
elapsed = time.time() - start
if ready and self._doctr is not None:
print(f"[OCR] docTR ready after {elapsed:.1f}s", flush=True)
return True
else:
print(f"[OCR] docTR not ready after {elapsed:.1f}s (timeout or failed)", flush=True)
return False
def is_doctr_ready(self) -> bool:
"""Check if docTR is ready without waiting."""
return self._doctr is not None
def wait_for_paddle(self, timeout: float = 30.0) -> bool:
"""
Wait for PaddleOCR to be fully initialized.
Args:
timeout: Max seconds to wait (default 30s)
Returns:
True if PaddleOCR is ready, False if timeout or unavailable
"""
if not PADDLE_AVAILABLE:
return False
if self._paddle is not None:
return True # Already ready
if not self._paddle_init_started:
# Start initialization if not already started
self._init_paddle_lazy()
# Wait for initialization to complete
print(f"[OCR] Waiting for PaddleOCR to be ready (max {timeout}s)...", flush=True)
start = time.time()
ready = self._paddle_ready.wait(timeout=timeout)
elapsed = time.time() - start
if ready and self._paddle is not None:
print(f"[OCR] PaddleOCR ready after {elapsed:.1f}s", flush=True)
return True
else:
print(f"[OCR] PaddleOCR not ready after {elapsed:.1f}s (timeout or failed)", flush=True)
return False
def is_paddle_ready(self) -> bool:
"""Check if PaddleOCR is ready without waiting."""
return self._paddle is not None
def recognize(self, image: np.ndarray) -> OCRResult:
"""Perform OCR on preprocessed image."""
logger.info(f"[OCR] Starting recognition, image shape: {image.shape}, dtype: {image.dtype}")
# Lazy init PaddleOCR on first call
self._init_paddle_lazy()
if PADDLE_AVAILABLE and self._paddle:
logger.info("[OCR] Using PaddleOCR engine")
return self._paddle_recognize(image)
elif TESSERACT_AVAILABLE:
logger.info("[OCR] Using Tesseract engine (PaddleOCR not available)")
return self._tesseract_recognize(image)
else:
logger.error("[OCR] No OCR engine available!")
raise RuntimeError(
"No OCR engine available. Install PaddleOCR or Tesseract."
)
def _paddle_recognize(self, image: np.ndarray) -> OCRResult:
"""Recognize text using PaddleOCR 3.x API."""
# Wait for PaddleOCR to be fully ready (handles background init)
if not self.wait_for_paddle(timeout=30.0):
logger.warning("[PaddleOCR] Not ready, falling back to Tesseract")
if TESSERACT_AVAILABLE:
return self._tesseract_recognize(image)
raise RuntimeError("PaddleOCR not ready and Tesseract not available")
try:
logger.info(f"[PaddleOCR] Processing image, shape: {image.shape}")
# PaddleOCR 3.x requires 3-channel images
if len(image.shape) == 2:
# Convert grayscale to 3-channel BGR
import cv2
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
logger.info(f"[PaddleOCR] Converted to BGR, new shape: {image.shape}")
# PaddleOCR 3.x uses predict() with new parameter names
logger.info("[PaddleOCR] Calling predict()...")
result = self._paddle.predict(image, use_textline_orientation=True)
logger.info(f"[PaddleOCR] predict() returned, result type: {type(result)}")
if not result or len(result) == 0:
logger.warning("[PaddleOCR] No results returned")
return OCRResult(text="", confidence=0.0, boxes=[], engine="paddleocr")
# PaddleOCR 3.x returns OCRResult objects with different structure
ocr_result = result[0]
# Extract texts and scores from the new format
rec_texts = ocr_result.get('rec_texts', [])
rec_scores = ocr_result.get('rec_scores', [])
dt_polys = ocr_result.get('dt_polys', [])
if not rec_texts:
return OCRResult(text="", confidence=0.0, boxes=[], engine="paddleocr")
boxes = []
for i, text in enumerate(rec_texts):
conf = rec_scores[i] if i < len(rec_scores) else 0.0
box = dt_polys[i].tolist() if i < len(dt_polys) else []
boxes.append({
'text': text,
'confidence': float(conf),
'box': box
})
avg_conf = sum(rec_scores) / len(rec_scores) if rec_scores else 0.0
text_result = '\n'.join(rec_texts)
logger.info(f"[PaddleOCR] SUCCESS - Found {len(rec_texts)} text lines, avg confidence: {avg_conf:.2%}")
logger.debug(f"[PaddleOCR] Raw text preview: {text_result[:200]}...")
return OCRResult(
text=text_result,
confidence=float(avg_conf),
boxes=boxes,
engine="paddleocr"
)
except Exception as e:
logger.error(f"[PaddleOCR] ERROR: {e}, falling back to Tesseract")
if TESSERACT_AVAILABLE:
return self._tesseract_recognize(image)
raise
def _tesseract_recognize(self, image: np.ndarray) -> OCRResult:
"""Recognize text using Tesseract."""
global pytesseract
logger.info(f"[Tesseract] Processing image, shape: {image.shape}")
# Lazy import pytesseract
if pytesseract is None:
logger.info("[Tesseract] Importing pytesseract...")
import pytesseract as _pytesseract
pytesseract = _pytesseract
# PSM 4: Single column (best for receipts)
config = '--psm 4 -l ron+eng'
text = pytesseract.image_to_string(image, config=config)
# Quick confidence estimate
data = pytesseract.image_to_data(image, config=config, output_type=pytesseract.Output.DICT)
confidences = [int(c) for c in data['conf'] if int(c) > 0]
avg_conf = sum(confidences) / len(confidences) / 100 if confidences else 0.0
logger.info(f"[Tesseract] Done: {len(text)} chars, conf: {avg_conf:.2%}")
return OCRResult(text=text, confidence=avg_conf, boxes=[], engine="tesseract")
def _doctr_recognize(self, image: np.ndarray) -> OCRResult:
"""Recognize text using docTR."""
# Wait for docTR to be fully ready
if not self.wait_for_doctr(timeout=30.0):
logger.warning("[docTR] Not ready, falling back to Tesseract")
if TESSERACT_AVAILABLE:
return self._tesseract_recognize(image)
raise RuntimeError("docTR not ready and Tesseract not available")
try:
logger.info(f"[docTR] Processing image, shape: {image.shape}")
# docTR requires RGB images
import cv2
if len(image.shape) == 2:
# Convert grayscale to RGB
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
logger.info(f"[docTR] Converted grayscale to RGB, new shape: {image.shape}")
elif image.shape[2] == 4:
# Convert RGBA to RGB
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
logger.info(f"[docTR] Converted RGBA to RGB, new shape: {image.shape}")
elif image.shape[2] == 3:
# Check if BGR (from OpenCV) and convert to RGB
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
logger.info(f"[docTR] Converted BGR to RGB, shape: {image.shape}")
# Process image with docTR
logger.info("[docTR] Running prediction...")
from doctr.io import DocumentFile
# docTR expects a document (list of pages as numpy arrays)
result = self._doctr([image])
if not result or not result.pages:
logger.warning("[docTR] No results returned")
return OCRResult(text="", confidence=0.0, boxes=[], engine="doctr")
# Extract text from all pages
all_texts = []
all_confidences = []
boxes = []
for page in result.pages:
for block in page.blocks:
for line in block.lines:
line_text = ' '.join(word.value for word in line.words)
line_confidence = sum(w.confidence for w in line.words) / len(line.words) if line.words else 0.0
all_texts.append(line_text)
all_confidences.append(line_confidence)
# Store word-level boxes
for word in line.words:
boxes.append({
'text': word.value,
'confidence': float(word.confidence),
'box': word.geometry # (xmin, ymin), (xmax, ymax)
})
text_result = '\n'.join(all_texts)
avg_conf = sum(all_confidences) / len(all_confidences) if all_confidences else 0.0
logger.info(f"[docTR] SUCCESS - Found {len(all_texts)} text lines, avg confidence: {avg_conf:.2%}")
logger.debug(f"[docTR] Raw text preview: {text_result[:200]}...")
return OCRResult(
text=text_result,
confidence=float(avg_conf),
boxes=boxes,
engine="doctr"
)
except Exception as e:
logger.error(f"[docTR] ERROR: {e}, falling back to Tesseract")
if TESSERACT_AVAILABLE:
return self._tesseract_recognize(image)
raise
def recognize_dual(self, image: np.ndarray) -> Tuple[OCRResult, Optional[OCRResult]]:
"""
Run both OCR engines and return both results.
Returns:
Tuple of (paddle_result, tesseract_result)
tesseract_result may be None if Tesseract is not available
"""
logger.info(f"[OCR Dual] Starting dual recognition, image shape: {image.shape}")
# Lazy init PaddleOCR
self._init_paddle_lazy()
paddle_result = None
tesseract_result = None
# Run PaddleOCR
if PADDLE_AVAILABLE and self._paddle:
try:
logger.info("[OCR Dual] Running PaddleOCR...")
paddle_result = self._paddle_recognize(image)
logger.info(f"[OCR Dual] PaddleOCR: {len(paddle_result.text)} chars, conf: {paddle_result.confidence:.2%}")
except Exception as e:
logger.error(f"[OCR Dual] PaddleOCR failed: {e}")
paddle_result = OCRResult(text="", confidence=0.0, boxes=[], engine="paddleocr")
# Run Tesseract
if TESSERACT_AVAILABLE:
try:
logger.info("[OCR Dual] Running Tesseract...")
tesseract_result = self._tesseract_recognize(image)
logger.info(f"[OCR Dual] Tesseract: {len(tesseract_result.text)} chars, conf: {tesseract_result.confidence:.2%}")
except Exception as e:
logger.error(f"[OCR Dual] Tesseract failed: {e}")
tesseract_result = OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
# Fallback if PaddleOCR not available
if paddle_result is None:
if tesseract_result:
paddle_result = tesseract_result
else:
raise RuntimeError("No OCR engine available")
return paddle_result, tesseract_result
@staticmethod
def get_available_engines() -> List[str]:
"""
Return list of available OCR engines.
Respects OCR_ENABLE_PADDLEOCR and OCR_ENABLE_TESSERACT from .env.
Engines that are disabled via .env are not returned even if installed.
Available engines: tesseract, doctr, doctr_plus, paddleocr
"""
# Check .env settings
paddle_enabled = os.getenv("OCR_ENABLE_PADDLEOCR", "true").lower() == "true"
tesseract_enabled = os.getenv("OCR_ENABLE_TESSERACT", "true").lower() == "true"
engines = []
# Base engines (only if installed AND enabled)
if TESSERACT_AVAILABLE and tesseract_enabled:
engines.append('tesseract')
if DOCTR_AVAILABLE:
engines.append('doctr')
engines.append('doctr_plus') # docTR with 2-tier sequential + early exit
if PADDLE_AVAILABLE and paddle_enabled:
engines.append('paddleocr')
return engines

View File

@@ -0,0 +1,735 @@
"""Main OCR service coordinating preprocessing, recognition, and extraction."""
import os
import re
import gc
import logging
import threading
# Disable PaddleOCR model source check for faster startup (PaddleX 3.x) - must be set before import
os.environ['PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK'] = 'True'
import time
import asyncio
from concurrent.futures import ThreadPoolExecutor
from decimal import Decimal
from pathlib import Path
from typing import Optional, Tuple
from backend.modules.data_entry.services.ocr_engine import OCREngine
from backend.modules.data_entry.services.ocr_extractor import ReceiptExtractor, ExtractionResult
from backend.modules.data_entry.services.image_preprocessor import ImagePreprocessor
from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
# Setup logging
logger = logging.getLogger(__name__)
def get_memory_usage_mb() -> float:
"""Get current process memory usage in MB."""
try:
import resource
# Get memory in KB, convert to MB
rusage = resource.getrusage(resource.RUSAGE_SELF)
return rusage.ru_maxrss / 1024 # Linux returns KB
except Exception:
return 0.0
class OCRService:
"""Service for OCR processing of receipt images."""
# Single worker to prevent memory accumulation from parallel OCR
_executor = ThreadPoolExecutor(max_workers=1)
# Semaphore to ensure only one OCR operation at a time (memory protection)
_ocr_semaphore = threading.Semaphore(1)
# Memory threshold in MB - if exceeded, force GC before processing
_memory_threshold_mb = 2500
def __init__(self):
self.preprocessor = ImagePreprocessor()
self.ocr_engine = OCREngine()
self.extractor = ReceiptExtractor()
async def process_image(
self,
image_path: Path,
mime_type: str
) -> Tuple[bool, str, Optional[ExtractionResult]]:
"""
Process receipt image and extract structured data.
Args:
image_path: Path to the image file
mime_type: MIME type of the file
Returns:
Tuple of (success, message, extraction_result)
"""
try:
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(
self._executor,
self._process_sync,
image_path,
mime_type
)
return result
except Exception as e:
return False, f"OCR processing failed: {str(e)}", None
def _cleanup_memory(self, *arrays):
"""Explicitly delete numpy arrays and force garbage collection."""
for arr in arrays:
if arr is not None:
try:
del arr
except:
pass
gc.collect()
def _process_sync(
self,
image_path: Path,
mime_type: str
) -> Tuple[bool, str, Optional[ExtractionResult]]:
"""Synchronous processing with ADAPTIVE OCR pipeline."""
# Acquire semaphore to ensure only one OCR at a time
acquired = self._ocr_semaphore.acquire(timeout=120) # 2 min timeout
if not acquired:
return False, "OCR service busy - please try again", None
try:
return self._process_sync_internal(image_path, mime_type)
finally:
# Always release semaphore and cleanup
self._ocr_semaphore.release()
# Force garbage collection after EVERY OCR request
gc.collect()
mem_after = get_memory_usage_mb()
print(f"[OCR Service] Memory after cleanup: {mem_after:.0f}MB", flush=True)
def _process_sync_internal(
self,
image_path: Path,
mime_type: str
) -> Tuple[bool, str, Optional[ExtractionResult]]:
"""Internal processing - called with semaphore held."""
start_time = time.time()
mem_before = get_memory_usage_mb()
print(f"[OCR Service] Starting processing: {image_path}, mime: {mime_type}", flush=True)
print(f"[OCR Service] Memory before: {mem_before:.0f}MB", flush=True)
# Check if memory is high - force GC before processing
if mem_before > self._memory_threshold_mb:
print(f"[OCR Service] WARNING: Memory high ({mem_before:.0f}MB > {self._memory_threshold_mb}MB), forcing GC...", flush=True)
gc.collect()
mem_after_gc = get_memory_usage_mb()
print(f"[OCR Service] Memory after pre-GC: {mem_after_gc:.0f}MB", flush=True)
# Load image
images = None # For cleanup
image = None
if mime_type == 'application/pdf':
try:
images = self.preprocessor.pdf_to_images(image_path)
if not images:
return False, "Failed to extract images from PDF", None
image = images[0]
# Delete other pages immediately to save memory
if len(images) > 1:
for i in range(1, len(images)):
del images[i]
images = [image]
except RuntimeError as e:
return False, str(e), None
else:
try:
image = self.preprocessor.load_image(image_path)
except ValueError as e:
return False, str(e), None
raw_texts = []
extraction = None
# ══════════════════════════════════════════════════════════════
# STEP 1: PaddleOCR + Light (fastest, best for clear PDFs)
# ══════════════════════════════════════════════════════════════
print("=" * 60, flush=True)
print("[OCR] STEP 1: PaddleOCR + Light preprocessing", flush=True)
print("=" * 60, flush=True)
light_img = self.preprocessor.preprocess_light(image)
try:
paddle_light = self.ocr_engine._paddle_recognize(light_img)
# Cleanup light_img immediately after OCR
del light_img
light_img = None
if paddle_light and paddle_light.text:
extraction = self.extractor.extract(paddle_light.text)
extraction.ocr_engine = "paddle-light"
raw_texts.append(f"═══ PaddleOCR (light, conf: {paddle_light.confidence:.0%}) ═══\n{paddle_light.text}")
# Log extraction results
print(f"[OCR] Step 1 Results:", flush=True)
print(f" - OCR Confidence: {paddle_light.confidence:.0%}", flush=True)
print(f" - Amount: {extraction.amount}", flush=True)
print(f" - Date: {extraction.receipt_date}", flush=True)
print(f" - Number: {extraction.receipt_number}", flush=True)
print(f" - CUI: {extraction.cui}", flush=True)
print(f" - TVA: {extraction.tva_total} (entries: {len(extraction.tva_entries) if extraction.tva_entries else 0})", flush=True)
print(f" - Overall Confidence: {extraction.overall_confidence:.0%}", flush=True)
# Early exit if complete
if self._is_extraction_complete(extraction):
extraction.raw_text = "\n\n".join(raw_texts)
elapsed_ms = int((time.time() - start_time) * 1000)
extraction.processing_time_ms = elapsed_ms
print(f"[OCR] *** EARLY EXIT at Step 1 - All fields found! ({elapsed_ms}ms) ***", flush=True)
# Cleanup before return
del image
if images:
del images
return True, "OCR complete (fast mode)", extraction
else:
print("[OCR] -> Step 1 incomplete, continuing to Step 2...", flush=True)
except Exception as e:
print(f"[OCR] PaddleOCR light failed: {e}", flush=True)
extraction = ExtractionResult()
# Cleanup on error
if light_img is not None:
del light_img
# ══════════════════════════════════════════════════════════════
# STEP 2: PaddleOCR + Medium (balanced preprocessing)
# ══════════════════════════════════════════════════════════════
print("=" * 60, flush=True)
print("[OCR] STEP 2: PaddleOCR + Medium preprocessing", flush=True)
print("=" * 60, flush=True)
medium_img = self.preprocessor.preprocess_medium(image)
try:
paddle_medium = self.ocr_engine._paddle_recognize(medium_img)
# Cleanup medium_img immediately after OCR
del medium_img
medium_img = None
if paddle_medium and paddle_medium.text:
extraction_medium = self.extractor.extract(paddle_medium.text)
extraction_medium.ocr_engine = "paddle-medium"
raw_texts.append(f"═══ PaddleOCR (medium, conf: {paddle_medium.confidence:.0%}) ═══\n{paddle_medium.text}")
print(f"[OCR] Step 2 (Medium) Results:", flush=True)
print(f" - OCR Confidence: {paddle_medium.confidence:.0%}", flush=True)
print(f" - Amount: {extraction_medium.amount}", flush=True)
print(f" - Date: {extraction_medium.receipt_date}", flush=True)
print(f" - CUI: {extraction_medium.cui}", flush=True)
# Merge with previous
extraction = self._merge_extractions(extraction, extraction_medium)
print(f"[OCR] After merge:", flush=True)
print(f" - Amount: {extraction.amount}", flush=True)
print(f" - Date: {extraction.receipt_date}", flush=True)
print(f" - Number: {extraction.receipt_number}", flush=True)
print(f" - CUI: {extraction.cui}", flush=True)
print(f" - TVA: {extraction.tva_total}", flush=True)
print(f" - Overall Confidence: {extraction.overall_confidence:.0%}", flush=True)
if self._is_extraction_complete(extraction):
extraction.raw_text = "\n\n".join(raw_texts)
extraction.ocr_engine = "paddle-adaptive"
elapsed_ms = int((time.time() - start_time) * 1000)
extraction.processing_time_ms = elapsed_ms
print(f"[OCR] *** EARLY EXIT at Step 2 - All fields found after merge! ({elapsed_ms}ms) ***", flush=True)
# Cleanup before return
del image
if images:
del images
return True, "OCR complete (paddle dual)", extraction
else:
print("[OCR] -> Step 2 incomplete, continuing to Step 3 (Tesseract)...", flush=True)
except Exception as e:
print(f"[OCR] PaddleOCR medium failed: {e}", flush=True)
# Cleanup on error
if medium_img is not None:
del medium_img
# ══════════════════════════════════════════════════════════════
# STEP 3: Tesseract - ONLY to complete missing fields
# Uses Tesseract-optimized preprocessing (binarized, high contrast)
# ══════════════════════════════════════════════════════════════
print("=" * 60, flush=True)
print("[OCR] STEP 3: Tesseract (complement only, not override)", flush=True)
print("=" * 60, flush=True)
tesseract_img = None
try:
# Use Tesseract-specific preprocessing (Otsu binarization)
tesseract_img = self.preprocessor.preprocess_for_tesseract(image)
tesseract_result = self.ocr_engine._tesseract_recognize(tesseract_img)
# Cleanup tesseract_img immediately after OCR
del tesseract_img
tesseract_img = None
if tesseract_result and tesseract_result.text:
extraction_tess = self.extractor.extract(tesseract_result.text)
extraction_tess.ocr_engine = "tesseract"
raw_texts.append(f"═══ Tesseract (conf: {tesseract_result.confidence:.0%}) ═══\n{tesseract_result.text}")
print(f"[OCR] Step 3 (Tesseract) Results:", flush=True)
print(f" - OCR Confidence: {tesseract_result.confidence:.0%}", flush=True)
print(f" - Amount: {extraction_tess.amount}", flush=True)
print(f" - Date: {extraction_tess.receipt_date}", flush=True)
print(f" - CUI: {extraction_tess.cui}", flush=True)
# IMPORTANT: Tesseract only COMPLETES missing fields, never overrides!
extraction = self._complement_extraction(extraction, extraction_tess)
except Exception as e:
print(f"[OCR] Tesseract failed: {e}", flush=True)
# Cleanup on error
if tesseract_img is not None:
del tesseract_img
# Cleanup original image - no longer needed
del image
if images:
del images
# ══════════════════════════════════════════════════════════════
# FINAL VALIDATION: Fix impossible values
# ══════════════════════════════════════════════════════════════
if extraction:
extraction = self._final_validation(extraction)
# Final result
if extraction is None:
return False, "No text detected", None
extraction.raw_text = "\n\n".join(raw_texts)
extraction.ocr_engine = "adaptive-full"
# Build result message
fields_found = []
if extraction.amount: fields_found.append("amount")
if extraction.receipt_date: fields_found.append("date")
if extraction.receipt_number: fields_found.append("number")
if extraction.cui: fields_found.append("CUI")
if extraction.tva_total or extraction.tva_entries: fields_found.append("TVA")
message = f"OCR complete (full pipeline). Found: {', '.join(fields_found) or 'no fields'}"
elapsed_ms = int((time.time() - start_time) * 1000)
extraction.processing_time_ms = elapsed_ms
print("=" * 60, flush=True)
print(f"[OCR] FINAL RESULT (full pipeline) - {elapsed_ms}ms", flush=True)
print("=" * 60, flush=True)
print(f" - Amount: {extraction.amount}", flush=True)
print(f" - Date: {extraction.receipt_date}", flush=True)
print(f" - Number: {extraction.receipt_number}", flush=True)
print(f" - CUI: {extraction.cui}", flush=True)
print(f" - TVA: {extraction.tva_total}", flush=True)
print(f" - Overall Confidence: {extraction.overall_confidence:.0%}", flush=True)
print(f" - Processing Time: {elapsed_ms}ms", flush=True)
print(f" - Message: {message}", flush=True)
# ══════════════════════════════════════════════════════════════
# VALIDATION: Apply validation rules to final extraction
# ══════════════════════════════════════════════════════════════
print("\n" + "=" * 60, flush=True)
print("[Validation] Applying validation rules...", flush=True)
print("=" * 60, flush=True)
validator = OCRValidationEngine()
# Prepare data for validation with safe type conversions
def safe_float(value) -> Optional[float]:
"""Safely convert Decimal or number to float."""
if value is None:
return None
try:
return float(value)
except (TypeError, ValueError):
return None
def safe_payment_sum(methods: list, method_type: str) -> Optional[float]:
"""Safely sum payment amounts for a given method type."""
if not methods:
return None
try:
total = sum(
float(pm.get('amount', 0) or 0)
for pm in methods
if pm.get('method') == method_type
)
return total if total > 0 else None
except (TypeError, ValueError):
return None
validation_data = {
'amount': safe_float(extraction.amount),
'tva': safe_float(extraction.tva_total),
'cui': extraction.cui,
'card_amount': safe_payment_sum(extraction.payment_methods, 'CARD'),
'cash_amount': safe_payment_sum(extraction.payment_methods, 'NUMERAR'),
'tva_entries': {
entry.get('code', ''): safe_float(entry.get('amount'))
for entry in (extraction.tva_entries or [])
if entry.get('code') and safe_float(entry.get('amount')) is not None
}
}
# Run validation (no light/medium comparison for final result)
validated_result = validator.validate_extraction(validation_data)
# Apply validation results to extraction
extraction.needs_manual_review = validated_result.needs_manual_review
extraction.validation_warnings = validated_result.validation_warnings
extraction.validation_errors = validated_result.validation_errors
extraction.confidence_adjustments = validated_result.confidence_adjustments
extraction.inter_ocr_ratios = validated_result.inter_ocr_ratios
print(f"[Validation] Complete:", flush=True)
print(f" - Warnings: {len(extraction.validation_warnings)}", flush=True)
print(f" - Errors: {len(extraction.validation_errors)}", flush=True)
print(f" - Needs Manual Review: {extraction.needs_manual_review}", flush=True)
if extraction.validation_warnings:
for warning in extraction.validation_warnings:
print(f" [!] {warning}", flush=True)
return True, message, extraction
def _merge_extractions(
self,
paddle: Optional[ExtractionResult],
tesseract: Optional[ExtractionResult]
) -> ExtractionResult:
"""
Merge two extractions, picking best fields from each engine.
Strategy:
- For each field, prefer the one with higher confidence
- Use validation rules (CUI format, date validity, company indicators)
- Combine TVA entries if different
"""
result = ExtractionResult()
# Handle case where one is None
if paddle is None and tesseract is None:
return result
if paddle is None:
return tesseract
if tesseract is None:
return paddle
print("[Merge] Comparing PaddleOCR vs Tesseract extractions...", flush=True)
# === AMOUNT ===
# Pick higher confidence, both must be positive
if paddle.amount and tesseract.amount:
if paddle.confidence_amount >= tesseract.confidence_amount:
result.amount = paddle.amount
result.confidence_amount = paddle.confidence_amount
print(f"[Merge] Amount: PaddleOCR {paddle.amount} (conf: {paddle.confidence_amount:.0%})", flush=True)
else:
result.amount = tesseract.amount
result.confidence_amount = tesseract.confidence_amount
print(f"[Merge] Amount: Tesseract {tesseract.amount} (conf: {tesseract.confidence_amount:.0%})", flush=True)
elif paddle.amount:
result.amount = paddle.amount
result.confidence_amount = paddle.confidence_amount
elif tesseract.amount:
result.amount = tesseract.amount
result.confidence_amount = tesseract.confidence_amount
# === DATE ===
# Pick higher confidence, validate date reasonableness
if paddle.receipt_date and tesseract.receipt_date:
if paddle.confidence_date >= tesseract.confidence_date:
result.receipt_date = paddle.receipt_date
result.confidence_date = paddle.confidence_date
print(f"[Merge] Date: PaddleOCR {paddle.receipt_date}", flush=True)
else:
result.receipt_date = tesseract.receipt_date
result.confidence_date = tesseract.confidence_date
print(f"[Merge] Date: Tesseract {tesseract.receipt_date}", flush=True)
elif paddle.receipt_date:
result.receipt_date = paddle.receipt_date
result.confidence_date = paddle.confidence_date
elif tesseract.receipt_date:
result.receipt_date = tesseract.receipt_date
result.confidence_date = tesseract.confidence_date
# === VENDOR NAME ===
# Prefer one with company indicators (S.R.L., S.A., etc.)
paddle_has_indicator = self._has_company_indicator(paddle.partner_name)
tesseract_has_indicator = self._has_company_indicator(tesseract.partner_name)
if paddle.partner_name and tesseract.partner_name:
if paddle_has_indicator and not tesseract_has_indicator:
result.partner_name = paddle.partner_name
result.confidence_vendor = paddle.confidence_vendor
print(f"[Merge] Vendor: PaddleOCR '{paddle.partner_name}' (has company indicator)", flush=True)
elif tesseract_has_indicator and not paddle_has_indicator:
result.partner_name = tesseract.partner_name
result.confidence_vendor = tesseract.confidence_vendor
print(f"[Merge] Vendor: Tesseract '{tesseract.partner_name}' (has company indicator)", flush=True)
elif paddle.confidence_vendor >= tesseract.confidence_vendor:
result.partner_name = paddle.partner_name
result.confidence_vendor = paddle.confidence_vendor
print(f"[Merge] Vendor: PaddleOCR '{paddle.partner_name}' (higher conf)", flush=True)
else:
result.partner_name = tesseract.partner_name
result.confidence_vendor = tesseract.confidence_vendor
print(f"[Merge] Vendor: Tesseract '{tesseract.partner_name}' (higher conf)", flush=True)
elif paddle.partner_name:
result.partner_name = paddle.partner_name
result.confidence_vendor = paddle.confidence_vendor
elif tesseract.partner_name:
result.partner_name = tesseract.partner_name
result.confidence_vendor = tesseract.confidence_vendor
# === CUI (Fiscal Code) ===
# Validate format: 6-10 digits, prefer valid one
paddle_cui_valid = self._is_valid_cui(paddle.cui)
tesseract_cui_valid = self._is_valid_cui(tesseract.cui)
if paddle.cui and tesseract.cui:
if paddle_cui_valid and not tesseract_cui_valid:
result.cui = paddle.cui
print(f"[Merge] CUI: PaddleOCR {paddle.cui} (valid format)", flush=True)
elif tesseract_cui_valid and not paddle_cui_valid:
result.cui = tesseract.cui
print(f"[Merge] CUI: Tesseract {tesseract.cui} (valid format)", flush=True)
else:
# Both valid or both invalid - prefer PaddleOCR
result.cui = paddle.cui
print(f"[Merge] CUI: PaddleOCR {paddle.cui}", flush=True)
elif paddle.cui and paddle_cui_valid:
result.cui = paddle.cui
elif tesseract.cui and tesseract_cui_valid:
result.cui = tesseract.cui
elif paddle.cui:
result.cui = paddle.cui
elif tesseract.cui:
result.cui = tesseract.cui
# === TVA ENTRIES ===
# Prefer non-empty, use the one with more entries or higher amounts
if paddle.tva_entries and tesseract.tva_entries:
# Compare: prefer the one with actual amounts (not just 0)
paddle_total = sum(e.get('amount', Decimal('0')) for e in paddle.tva_entries)
tesseract_total = sum(e.get('amount', Decimal('0')) for e in tesseract.tva_entries)
if paddle_total >= tesseract_total:
result.tva_entries = paddle.tva_entries
result.tva_total = paddle.tva_total
print(f"[Merge] TVA: PaddleOCR (total: {paddle_total})", flush=True)
else:
result.tva_entries = tesseract.tva_entries
result.tva_total = tesseract.tva_total
print(f"[Merge] TVA: Tesseract (total: {tesseract_total})", flush=True)
elif paddle.tva_entries:
result.tva_entries = paddle.tva_entries
result.tva_total = paddle.tva_total
elif tesseract.tva_entries:
result.tva_entries = tesseract.tva_entries
result.tva_total = tesseract.tva_total
# === OTHER FIELDS ===
# Simple preference: paddle > tesseract
result.receipt_number = paddle.receipt_number or tesseract.receipt_number
result.receipt_series = paddle.receipt_series or tesseract.receipt_series
result.receipt_type = paddle.receipt_type or tesseract.receipt_type
result.items_count = paddle.items_count or tesseract.items_count
result.address = paddle.address or tesseract.address
result.description = paddle.description or tesseract.description
return result
def _has_company_indicator(self, name: Optional[str]) -> bool:
"""Check if vendor name has company type indicator (S.R.L., S.A., etc.)"""
if not name:
return False
name_upper = name.upper()
indicators = [
r'\bS\.?\s*R\.?\s*L\.?\b',
r'\bS\.?\s*A\.?\b',
r'\bS\.?\s*N\.?\s*C\.?\b',
r'\bP\.?\s*F\.?\s*A\.?\b',
r'\bI\.?\s*I\.?\b',
r'\bHOLDING\b',
r'\bGROUP\b',
r'\bCOMPANY\b',
]
for indicator in indicators:
if re.search(indicator, name_upper):
return True
return False
def _is_valid_cui(self, cui: Optional[str]) -> bool:
"""Validate CUI format: 6-10 digits."""
if not cui:
return False
# Remove any RO prefix
cui_clean = re.sub(r'^RO', '', cui.upper())
# Must be 6-10 digits
return bool(re.match(r'^\d{6,10}$', cui_clean))
def _is_extraction_complete(self, ext: ExtractionResult, min_confidence: float = 0.85) -> bool:
"""
Check if extraction has ALL required fields to skip further processing.
Required for early exit (ALL must be true):
- Overall confidence >= 85%
- ALL 5 critical fields present: number, date, amount, TVA, CUI
"""
# Must have high confidence
if ext.overall_confidence < min_confidence:
print(f"[OCR] Confidence {ext.overall_confidence:.0%} < {min_confidence:.0%} - continuing", flush=True)
return False
# Check all required fields
has_number = bool(ext.receipt_number)
has_date = bool(ext.receipt_date)
has_amount = bool(ext.amount)
has_tva = bool(ext.tva_total) or bool(ext.tva_entries)
has_cui = bool(ext.cui)
missing = []
if not has_number: missing.append("number")
if not has_date: missing.append("date")
if not has_amount: missing.append("amount")
if not has_tva: missing.append("TVA")
if not has_cui: missing.append("CUI")
if missing:
print(f"[OCR] Missing: {', '.join(missing)} - continuing", flush=True)
return False
print(f"[OCR] OK: All 5 fields found with {ext.overall_confidence:.0%} confidence", flush=True)
return True
def _complement_extraction(
self,
primary: Optional[ExtractionResult],
secondary: Optional[ExtractionResult]
) -> ExtractionResult:
"""
Complement primary extraction with missing fields from secondary.
NEVER overrides existing values - only fills in gaps.
This is different from _merge_extractions which can override values.
"""
if primary is None and secondary is None:
return ExtractionResult()
if primary is None:
return secondary
if secondary is None:
return primary
print("[Complement] Adding missing fields from Tesseract...", flush=True)
# Only fill missing amount
if not primary.amount and secondary.amount:
primary.amount = secondary.amount
primary.confidence_amount = secondary.confidence_amount
print(f"[Complement] Added amount: {secondary.amount}", flush=True)
# Only fill missing date
if not primary.receipt_date and secondary.receipt_date:
primary.receipt_date = secondary.receipt_date
primary.confidence_date = secondary.confidence_date
print(f"[Complement] Added date: {secondary.receipt_date}", flush=True)
# Only fill missing vendor
if not primary.partner_name and secondary.partner_name:
primary.partner_name = secondary.partner_name
primary.confidence_vendor = secondary.confidence_vendor
print(f"[Complement] Added vendor: {secondary.partner_name}", flush=True)
# Only fill missing CUI
if not primary.cui and secondary.cui and self._is_valid_cui(secondary.cui):
primary.cui = secondary.cui
print(f"[Complement] Added CUI: {secondary.cui}", flush=True)
# Only fill missing TVA
if not primary.tva_entries and secondary.tva_entries:
primary.tva_entries = secondary.tva_entries
primary.tva_total = secondary.tva_total
print(f"[Complement] Added TVA: {secondary.tva_total}", flush=True)
# Only fill missing receipt number
if not primary.receipt_number and secondary.receipt_number:
primary.receipt_number = secondary.receipt_number
print(f"[Complement] Added number: {secondary.receipt_number}", flush=True)
# Only fill missing address
if not primary.address and secondary.address:
primary.address = secondary.address
print(f"[Complement] Added address: {secondary.address}", flush=True)
return primary
def _final_validation(self, extraction: ExtractionResult) -> ExtractionResult:
"""
Final validation and correction of impossible values.
Key rules:
1. TVA cannot be greater than TOTAL (it's always a fraction)
2. If TVA > TOTAL, recalculate TOTAL from TVA using known rates
3. Validate TVA entries sum equals TVA total
"""
print("[Final Validation] Checking extracted values...", flush=True)
# Rule 1: TVA cannot be greater than TOTAL
if extraction.tva_total and extraction.amount:
if extraction.tva_total > extraction.amount:
print(f"[Final Validation] TVA ({extraction.tva_total}) > TOTAL ({extraction.amount}) - IMPOSSIBLE!", flush=True)
# Calculate TOTAL from TVA using reverse formula:
# total = base + tva = tva * (100/rate + 1) = tva * (100 + rate) / rate
# For 9% TVA: total = tva * 109 / 9 = tva * 12.11
# For 19% TVA: total = tva * 119 / 19 = tva * 6.26
# For 21% TVA: total = tva * 121 / 21 = tva * 5.76
rate = 19 # Default rate assumption
if extraction.tva_entries:
# Use the rate from the first entry
rate = extraction.tva_entries[0].get('percent', 19)
if rate > 0:
# Formula: total = tva * (100 + rate) / rate
calculated_total = extraction.tva_total * (Decimal('100') + Decimal(str(rate))) / Decimal(str(rate))
calculated_total = calculated_total.quantize(Decimal('0.01'))
print(f"[Final Validation] Calculated TOTAL from TVA: {calculated_total} (using {rate}% rate)", flush=True)
extraction.amount = calculated_total
extraction.confidence_amount = 0.70 # Lower confidence for calculated value
# Rule 2: TVA cannot be more than ~25% of total (max Romanian rate is 21%)
if extraction.tva_total and extraction.amount:
tva_percent = extraction.tva_total / extraction.amount * Decimal('100')
if tva_percent > Decimal('25'):
print(f"[Final Validation] Warning: TVA is {tva_percent:.1f}% of total - suspicious", flush=True)
# Rule 3: Validate TVA entries sum
if extraction.tva_entries and extraction.tva_total:
entries_sum = sum(e.get('amount', Decimal('0')) for e in extraction.tva_entries)
tolerance = Decimal('0.05')
if abs(entries_sum - extraction.tva_total) > tolerance:
print(f"[Final Validation] TVA entries sum ({entries_sum}) != tva_total ({extraction.tva_total})", flush=True)
# Use the sum as it's more reliable
extraction.tva_total = entries_sum
print(f"[Final Validation] Done. Amount={extraction.amount}, TVA={extraction.tva_total}", flush=True)
return extraction
# Singleton instance
ocr_service = OCRService()

View File

@@ -0,0 +1,385 @@
"""
Auto-create Receipt from OCR results for bulk upload flow.
This service handles automatic creation of Receipt records from OCR extraction
results, enabling end-to-end processing without manual UI intervention.
The service:
1. Maps OCR ExtractionData fields to Receipt fields
2. Creates attachment from the original uploaded file
3. Generates accounting entries
4. Links the receipt back to the batch job for tracking
"""
import logging
import shutil
import uuid
from dataclasses import dataclass
from datetime import date, datetime
from decimal import Decimal
from pathlib import Path
from typing import Optional, List
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from backend.modules.data_entry.db.models.receipt import (
Receipt,
ReceiptAttachment,
ReceiptStatus,
ReceiptType,
ReceiptDirection,
)
from backend.modules.data_entry.db.models.batch import BatchJob
from backend.modules.data_entry.db.crud.receipt import ReceiptCRUD
from backend.modules.data_entry.db.crud.accounting_entry import AccountingEntryCRUD
from backend.modules.data_entry.schemas.receipt import ReceiptCreate, TvaEntrySchema, PaymentMethodSchema
from backend.modules.data_entry.schemas.ocr import ExtractionData
from backend.modules.data_entry.services.receipt_service import ReceiptService
from backend.modules.data_entry.services import sse_service
from backend.config import settings
logger = logging.getLogger(__name__)
@dataclass
class ReceiptCreateResult:
"""Result of auto-create operation."""
success: bool
receipt_id: Optional[int] = None
error_message: Optional[str] = None
class ReceiptAutoCreateService:
"""
Service for automatically creating receipts from OCR results.
Used by the bulk upload flow to create receipts without user intervention.
Created receipts are in DRAFT status and require review before approval.
"""
@staticmethod
def _validate_ocr_result(ocr_result: ExtractionData) -> tuple[bool, str]:
"""
Perform minimal validation on OCR result.
Validates:
- amount > 0 (required for receipt)
- date is valid and not in future
Args:
ocr_result: Extracted data from OCR
Returns:
Tuple of (is_valid, error_message)
"""
# Validate amount exists and is positive
if ocr_result.amount is None:
return False, "Amount not extracted from receipt"
if ocr_result.amount <= 0:
return False, f"Invalid amount: {ocr_result.amount} (must be > 0)"
# Validate date exists and is not in the future
if ocr_result.receipt_date is None:
return False, "Receipt date not extracted"
today = date.today()
if ocr_result.receipt_date > today:
return False, f"Receipt date {ocr_result.receipt_date} is in the future"
return True, ""
@staticmethod
def _map_ocr_to_receipt(
ocr_result: ExtractionData,
company_id: int,
) -> ReceiptCreate:
"""
Map OCR ExtractionData fields to ReceiptCreate schema.
Args:
ocr_result: Extracted data from OCR
company_id: Company ID for the receipt
Returns:
ReceiptCreate schema ready for database insertion
"""
# Map receipt type
receipt_type = ReceiptType.BON_FISCAL
if ocr_result.receipt_type == "chitanta":
receipt_type = ReceiptType.CHITANTA
# Map TVA breakdown from OCR TvaEntry to schema TvaEntrySchema
tva_breakdown: Optional[List[TvaEntrySchema]] = None
if ocr_result.tva_entries:
tva_breakdown = [
TvaEntrySchema(
code=entry.code,
percent=entry.percent,
amount=entry.amount
)
for entry in ocr_result.tva_entries
]
# Map payment methods
payment_methods: Optional[List[PaymentMethodSchema]] = None
if ocr_result.payment_methods:
payment_methods = [
PaymentMethodSchema(
method=pm.method,
amount=pm.amount
)
for pm in ocr_result.payment_methods
]
# Create receipt data
return ReceiptCreate(
receipt_type=receipt_type,
direction=ReceiptDirection.CHELTUIALA, # Default to expense
receipt_number=ocr_result.receipt_number,
receipt_series=ocr_result.receipt_series,
receipt_date=ocr_result.receipt_date,
amount=ocr_result.amount,
description=ocr_result.description,
tva_breakdown=tva_breakdown,
tva_total=ocr_result.tva_total,
items_count=ocr_result.items_count,
vendor_address=ocr_result.address,
company_id=company_id,
partner_name=ocr_result.partner_name,
cui=ocr_result.cui,
ocr_raw_text=ocr_result.raw_text[:5000] if ocr_result.raw_text else None, # Limit size
payment_methods=payment_methods,
payment_mode=ocr_result.suggested_payment_mode,
)
@staticmethod
async def _create_attachment_from_file(
session: AsyncSession,
receipt_id: int,
source_file_path: str,
original_filename: Optional[str] = None,
) -> Optional[ReceiptAttachment]:
"""
Create attachment by copying file from OCR job location.
Args:
session: Database session
receipt_id: Receipt ID to attach to
source_file_path: Path to the original file from OCR job
original_filename: Original filename from upload (optional)
Returns:
Created ReceiptAttachment or None if failed
"""
source_path = Path(source_file_path)
if not source_path.exists():
logger.warning(f"[ReceiptAutoCreate] Source file not found: {source_path}")
return None
# Generate stored filename
ext = source_path.suffix.lower()
stored_filename = f"{uuid.uuid4()}{ext}"
# Determine relative path (organized by year/month)
now = datetime.utcnow()
relative_path = Path(str(now.year)) / f"{now.month:02d}"
# Full destination path
dest_dir = settings.data_entry_upload_path_resolved / relative_path
dest_dir.mkdir(parents=True, exist_ok=True)
dest_path = dest_dir / stored_filename
# Copy file to attachments directory
try:
shutil.copy2(source_path, dest_path)
except Exception as e:
logger.error(f"[ReceiptAutoCreate] Failed to copy file: {e}")
return None
# Get file size
file_size = dest_path.stat().st_size
# Determine MIME type
mime_map = {
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.png': 'image/png',
'.pdf': 'application/pdf',
}
mime_type = mime_map.get(ext, 'application/octet-stream')
# Use original filename if provided, otherwise use source filename
display_filename = original_filename or source_path.name
# Create attachment record
attachment = ReceiptAttachment(
receipt_id=receipt_id,
filename=display_filename,
stored_filename=stored_filename,
file_path=str(relative_path / stored_filename),
file_size=file_size,
mime_type=mime_type,
)
session.add(attachment)
await session.flush()
return attachment
@staticmethod
async def _update_batch_job_receipt_id(
session: AsyncSession,
job_id: str,
receipt_id: int,
) -> None:
"""
Update batch_jobs table with the created receipt_id.
Args:
session: Database session
job_id: OCR job UUID
receipt_id: Created receipt ID
"""
await session.execute(
update(BatchJob)
.where(BatchJob.job_id == job_id)
.values(receipt_id=receipt_id)
)
@staticmethod
async def create_from_ocr_result(
session: AsyncSession,
job_id: str,
ocr_result: ExtractionData,
username: str,
batch_id: int,
company_id: int,
file_path: Optional[str] = None,
original_filename: Optional[str] = None,
file_hash: Optional[str] = None,
) -> ReceiptCreateResult:
"""
Create a receipt from OCR extraction result.
This method:
1. Validates the OCR result (amount > 0, date valid)
2. Maps OCR fields to Receipt fields
3. Creates the Receipt in DRAFT status
4. Creates attachment from original file
5. Generates accounting entries
6. Updates batch_jobs with receipt_id
Args:
session: Database session
job_id: OCR job UUID for tracking
ocr_result: Extracted data from OCR processing
username: User who initiated the upload
batch_id: Batch ID for grouping
company_id: Company ID for the receipt
file_path: Path to the original uploaded file
original_filename: Original filename from upload
file_hash: SHA-256 hash of the file for duplicate detection (US-007)
Returns:
ReceiptCreateResult with success status and receipt_id or error
"""
try:
# Step 1: Validate OCR result
is_valid, error_msg = ReceiptAutoCreateService._validate_ocr_result(ocr_result)
if not is_valid:
logger.warning(f"[ReceiptAutoCreate] Validation failed for job {job_id}: {error_msg}")
return ReceiptCreateResult(
success=False,
error_message=error_msg
)
# Step 2: Map OCR to Receipt schema
receipt_data = ReceiptAutoCreateService._map_ocr_to_receipt(
ocr_result=ocr_result,
company_id=company_id,
)
# Step 3: Create receipt in DRAFT status
receipt = await ReceiptCRUD.create(session, receipt_data, created_by=username)
# Set batch tracking fields (US-007, US-011)
receipt.batch_id = str(batch_id)
receipt.file_hash = file_hash
receipt.processing_status = "completed"
session.add(receipt)
await session.flush()
logger.info(
f"[ReceiptAutoCreate] Created receipt {receipt.id} for job {job_id}: "
f"amount={receipt.amount}, vendor={receipt.partner_name}, file_hash={file_hash[:16] if file_hash else None}..."
)
# Step 4: Create attachment from original file (if path provided)
if file_path:
attachment = await ReceiptAutoCreateService._create_attachment_from_file(
session=session,
receipt_id=receipt.id,
source_file_path=file_path,
original_filename=original_filename,
)
if attachment:
logger.info(f"[ReceiptAutoCreate] Created attachment for receipt {receipt.id}")
else:
logger.warning(f"[ReceiptAutoCreate] Failed to create attachment for receipt {receipt.id}")
# Step 5: Generate accounting entries
# Note: For DRAFT status, entries are generated but not required for validation
try:
entries = ReceiptService.generate_accounting_entries(receipt)
if entries:
await AccountingEntryCRUD.create_bulk(
session, receipt.id, entries, is_auto_generated=True
)
logger.info(
f"[ReceiptAutoCreate] Generated {len(entries)} accounting entries "
f"for receipt {receipt.id}"
)
except Exception as e:
# Don't fail the receipt creation if entry generation fails
logger.warning(
f"[ReceiptAutoCreate] Failed to generate entries for receipt {receipt.id}: {e}"
)
# Step 6: Update batch_jobs with receipt_id
await ReceiptAutoCreateService._update_batch_job_receipt_id(
session=session,
job_id=job_id,
receipt_id=receipt.id,
)
# Commit all changes
await session.commit()
# Broadcast SSE event for real-time updates (US-030)
try:
await sse_service.broadcast_status_change(
receipt_id=receipt.id,
status=receipt.status.value,
processing_status=receipt.processing_status,
batch_id=receipt.batch_id,
)
except Exception as e:
# Don't fail the receipt creation if SSE broadcast fails
logger.warning(f"[ReceiptAutoCreate] SSE broadcast failed for receipt {receipt.id}: {e}")
return ReceiptCreateResult(
success=True,
receipt_id=receipt.id
)
except Exception as e:
logger.error(f"[ReceiptAutoCreate] Failed to create receipt for job {job_id}: {e}")
await session.rollback()
return ReceiptCreateResult(
success=False,
error_message=str(e)
)

View File

@@ -0,0 +1,457 @@
"""Business logic service for receipts workflow."""
from decimal import Decimal, ROUND_HALF_UP
from typing import List, Optional, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from backend.modules.data_entry.db.models.receipt import Receipt, ReceiptStatus, ReceiptDirection
from backend.modules.data_entry.db.models.accounting_entry import EntryType
from backend.modules.data_entry.db.crud.receipt import ReceiptCRUD
from backend.modules.data_entry.db.crud.accounting_entry import AccountingEntryCRUD
from backend.modules.data_entry.schemas.receipt import (
ReceiptCreate,
ReceiptUpdate,
ReceiptFilter,
ReceiptResponse,
ReceiptListResponse,
ProcessingStats,
AccountingEntryCreate,
)
from backend.modules.data_entry.services.expense_types import EXPENSE_TYPES, get_expense_type
# Payment mode to accounting account mapping
PAYMENT_MODE_ACCOUNTS = {
'casa': ('5311', 'Casa in lei'),
'banca': ('5121', 'Conturi la banci in lei'),
'avans_decontare': ('542', 'Avansuri de trezorerie'),
}
class ReceiptService:
"""Service for receipt business logic and workflow."""
@staticmethod
async def create_receipt(
session: AsyncSession,
data: ReceiptCreate,
created_by: str,
) -> Receipt:
"""Create a new receipt in DRAFT status."""
return await ReceiptCRUD.create(session, data, created_by)
@staticmethod
async def get_receipt(
session: AsyncSession,
receipt_id: int,
) -> Optional[Receipt]:
"""Get receipt by ID with all relationships."""
return await ReceiptCRUD.get_by_id(session, receipt_id, include_relations=True)
@staticmethod
async def get_receipts(
session: AsyncSession,
filters: ReceiptFilter,
) -> ReceiptListResponse:
"""Get paginated list of receipts with processing_stats (US-012)."""
receipts, total = await ReceiptCRUD.get_list(session, filters)
pages = (total + filters.page_size - 1) // filters.page_size if total > 0 else 1
# Get processing stats for bulk uploaded receipts (US-012)
stats_dict = await ReceiptCRUD.get_processing_stats(
session,
company_id=filters.company_id,
batch_id=filters.batch_id,
)
processing_stats = ProcessingStats(**stats_dict)
return ReceiptListResponse(
items=[ReceiptResponse.model_validate(r) for r in receipts],
total=total,
page=filters.page,
page_size=filters.page_size,
pages=pages,
processing_stats=processing_stats,
)
@staticmethod
async def update_receipt(
session: AsyncSession,
receipt_id: int,
data: ReceiptUpdate,
username: str,
) -> Tuple[bool, str, Optional[Receipt]]:
"""
Update receipt (only DRAFT status).
Returns (success, message, receipt).
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found", None
if not await ReceiptCRUD.can_edit(receipt, username):
return False, "Cannot edit this receipt", None
updated = await ReceiptCRUD.update(session, receipt, data)
return True, "Receipt updated", updated
@staticmethod
async def delete_receipt(
session: AsyncSession,
receipt_id: int,
username: str,
) -> Tuple[bool, str]:
"""
Delete receipt (only DRAFT status).
Returns (success, message).
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found"
if not await ReceiptCRUD.can_delete(receipt, username):
return False, "Cannot delete this receipt"
await ReceiptCRUD.delete(session, receipt)
return True, "Receipt deleted"
@staticmethod
def generate_accounting_entries(receipt: Receipt) -> List[AccountingEntryCreate]:
"""
Generate accounting entries based on receipt data and expense type.
"""
entries: List[AccountingEntryCreate] = []
# Get expense type configuration
expense_type = get_expense_type(receipt.expense_type_code or "OTHER")
if not expense_type:
expense_type = EXPENSE_TYPES["OTHER"]
amount = Decimal(str(receipt.amount))
if receipt.direction == ReceiptDirection.CHELTUIALA:
# Expense: Debit expense account, Credit cash/bank
if expense_type.has_vat:
# Calculate net and VAT
vat_rate = expense_type.vat_percent / Decimal("100")
net_amount = (amount / (1 + vat_rate)).quantize(
Decimal("0.01"), rounding=ROUND_HALF_UP
)
vat_amount = amount - net_amount
# Debit: Expense account (net)
entries.append(AccountingEntryCreate(
entry_type=EntryType.DEBIT,
account_code=expense_type.account_code,
account_name=expense_type.account_name,
amount=net_amount,
))
# Debit: VAT deductible
entries.append(AccountingEntryCreate(
entry_type=EntryType.DEBIT,
account_code=expense_type.vat_account,
account_name="TVA deductibila",
amount=vat_amount,
))
else:
# No VAT - full amount to expense
entries.append(AccountingEntryCreate(
entry_type=EntryType.DEBIT,
account_code=expense_type.account_code,
account_name=expense_type.account_name,
amount=amount,
))
# Credit entry - based on payment_mode (new) or cash_register (legacy)
if receipt.payment_mode and receipt.payment_mode in PAYMENT_MODE_ACCOUNTS:
credit_account, credit_name = PAYMENT_MODE_ACCOUNTS[receipt.payment_mode]
elif receipt.cash_register_account:
# Backwards compatibility for existing receipts
credit_account = receipt.cash_register_account
credit_name = receipt.cash_register_name or "Casa/Banca"
else:
# Default fallback
credit_account = "5311"
credit_name = "Casa in lei"
entries.append(AccountingEntryCreate(
entry_type=EntryType.CREDIT,
account_code=credit_account,
account_name=credit_name,
amount=amount,
))
else:
# Income: Debit cash/bank, Credit income account
# Based on payment_mode (new) or cash_register (legacy)
if receipt.payment_mode and receipt.payment_mode in PAYMENT_MODE_ACCOUNTS:
cash_account, cash_name = PAYMENT_MODE_ACCOUNTS[receipt.payment_mode]
elif receipt.cash_register_account:
cash_account = receipt.cash_register_account
cash_name = receipt.cash_register_name or "Casa/Banca"
else:
cash_account = "5311"
cash_name = "Casa in lei"
# Debit: Cash/Bank
entries.append(AccountingEntryCreate(
entry_type=EntryType.DEBIT,
account_code=cash_account,
account_name=cash_name,
amount=amount,
))
# Credit: Income account (7xx - to be configured)
entries.append(AccountingEntryCreate(
entry_type=EntryType.CREDIT,
account_code="7588",
account_name="Alte venituri din exploatare",
amount=amount,
))
return entries
@staticmethod
async def submit_for_review(
session: AsyncSession,
receipt_id: int,
username: str,
) -> Tuple[bool, str, Optional[Receipt]]:
"""
Submit receipt for review (DRAFT/REJECTED → PENDING_REVIEW).
Generates accounting entries automatically.
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found", None
if not await ReceiptCRUD.can_submit(receipt, username):
return False, "Cannot submit this receipt", None
# Check if receipt has at least one attachment
if not receipt.attachments:
return False, "Receipt must have at least one attachment", None
# Check required fields
if not receipt.expense_type_code:
return False, "Expense type is required", None
# Validate payment_mode or cash_register (backwards compatibility)
if not receipt.payment_mode and not receipt.cash_register_account:
return False, "Modul de plata este obligatoriu", None
# Generate accounting entries
entries = ReceiptService.generate_accounting_entries(receipt)
# Delete existing entries and create new ones
await AccountingEntryCRUD.delete_all_for_receipt(session, receipt_id)
await AccountingEntryCRUD.create_bulk(session, receipt_id, entries, is_auto_generated=True)
# Refresh receipt to clear stale relationship references after entry deletion
await session.refresh(receipt)
# Update status
updated = await ReceiptCRUD.update_status(
session, receipt, ReceiptStatus.PENDING_REVIEW
)
# Reload with entries
updated = await ReceiptCRUD.get_by_id(session, receipt_id)
return True, "Receipt submitted for review", updated
@staticmethod
async def approve_receipt(
session: AsyncSession,
receipt_id: int,
username: str,
) -> Tuple[bool, str, Optional[Receipt]]:
"""
Approve receipt (PENDING_REVIEW → APPROVED).
Requires valid CUI (fiscal code) for approval.
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found", None
if receipt.status != ReceiptStatus.PENDING_REVIEW:
return False, "Receipt is not pending review", None
# Validate CUI is present (required for Oracle import)
if not receipt.cui:
return False, "Trebuie completat codul fiscal (CUI) pentru aprobare", None
# Validate accounting entries
if not receipt.entries:
return False, "Receipt has no accounting entries", None
# Update status
updated = await ReceiptCRUD.update_status(
session, receipt, ReceiptStatus.APPROVED, reviewed_by=username
)
return True, "Receipt approved", updated
@staticmethod
async def unapprove_receipt(
session: AsyncSession,
receipt_id: int,
username: str,
) -> Tuple[bool, str, Optional[Receipt]]:
"""
Unapprove receipt (APPROVED → PENDING_REVIEW).
Returns receipt to pending review for corrections.
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found", None
if receipt.status != ReceiptStatus.APPROVED:
return False, "Receipt is not approved", None
# Update status back to pending review
updated = await ReceiptCRUD.update_status(
session, receipt, ReceiptStatus.PENDING_REVIEW
)
return True, "Receipt returned to pending review", updated
@staticmethod
async def reject_receipt(
session: AsyncSession,
receipt_id: int,
username: str,
reason: str,
) -> Tuple[bool, str, Optional[Receipt]]:
"""
Reject receipt (PENDING_REVIEW → REJECTED).
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found", None
if receipt.status != ReceiptStatus.PENDING_REVIEW:
return False, "Receipt is not pending review", None
# Update status
updated = await ReceiptCRUD.update_status(
session,
receipt,
ReceiptStatus.REJECTED,
reviewed_by=username,
rejection_reason=reason,
)
return True, "Receipt rejected", updated
@staticmethod
async def resubmit_receipt(
session: AsyncSession,
receipt_id: int,
username: str,
) -> Tuple[bool, str, Optional[Receipt]]:
"""
Resubmit rejected receipt after corrections (REJECTED → PENDING_REVIEW).
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found", None
if receipt.status != ReceiptStatus.REJECTED:
return False, "Receipt is not rejected", None
if receipt.created_by != username:
return False, "Only the creator can resubmit", None
# Re-generate accounting entries
entries = ReceiptService.generate_accounting_entries(receipt)
await AccountingEntryCRUD.delete_all_for_receipt(session, receipt_id)
await AccountingEntryCRUD.create_bulk(session, receipt_id, entries, is_auto_generated=True)
# Refresh receipt to clear stale relationship references after entry deletion
await session.refresh(receipt)
# Update status
updated = await ReceiptCRUD.update_status(
session, receipt, ReceiptStatus.PENDING_REVIEW
)
# Reload with entries
updated = await ReceiptCRUD.get_by_id(session, receipt_id)
return True, "Receipt resubmitted for review", updated
@staticmethod
async def regenerate_entries(
session: AsyncSession,
receipt_id: int,
username: str,
) -> Tuple[bool, str, List[AccountingEntryCreate]]:
"""
Regenerate accounting entries for a receipt.
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found", []
if receipt.status not in [ReceiptStatus.DRAFT, ReceiptStatus.PENDING_REVIEW]:
return False, "Cannot regenerate entries for this receipt status", []
# Generate new entries
entries = ReceiptService.generate_accounting_entries(receipt)
# Replace existing entries
await AccountingEntryCRUD.delete_all_for_receipt(session, receipt_id)
await AccountingEntryCRUD.create_bulk(session, receipt_id, entries, is_auto_generated=True)
return True, "Entries regenerated", entries
@staticmethod
async def update_entries(
session: AsyncSession,
receipt_id: int,
entries: List[AccountingEntryCreate],
username: str,
) -> Tuple[bool, str, List]:
"""
Update accounting entries for a receipt (accountant action).
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found", []
if receipt.status != ReceiptStatus.PENDING_REVIEW:
return False, "Can only modify entries for receipts pending review", []
# Validate entries
is_valid, error = await AccountingEntryCRUD.validate_entries(entries)
if not is_valid:
return False, error, []
# Replace entries
updated_entries = await AccountingEntryCRUD.replace_all_for_receipt(
session, receipt_id, entries, username
)
return True, "Entries updated", updated_entries
@staticmethod
async def get_pending_count(
session: AsyncSession,
company_id: Optional[int] = None,
) -> int:
"""Get count of receipts pending review."""
receipts = await ReceiptCRUD.get_pending_review(session, company_id)
return len(receipts)

View File

@@ -0,0 +1,197 @@
"""
Server-Sent Events (SSE) service for real-time status updates.
This module implements an event broadcaster pattern using asyncio.Queue per client.
When receipt status changes occur (CRUD operations), events are pushed to all
connected clients who are listening for that specific batch or all receipts.
Usage:
# In router endpoint (SSE stream):
async for event in sse_service.subscribe(batch_id=None):
yield event
# When status changes (from CRUD operations):
await sse_service.broadcast_status_change(receipt_id, status, processing_status, batch_id)
"""
import asyncio
import json
import logging
from dataclasses import dataclass, asdict
from typing import AsyncGenerator, Optional
from datetime import datetime
logger = logging.getLogger(__name__)
@dataclass
class StatusChangeEvent:
"""Event data for receipt status changes."""
receipt_id: int
status: str
processing_status: Optional[str] = None
batch_id: Optional[str] = None
timestamp: Optional[str] = None
def __post_init__(self):
if self.timestamp is None:
self.timestamp = datetime.utcnow().isoformat()
def to_sse_data(self) -> str:
"""Format as SSE data line."""
data = asdict(self)
return f"data: {json.dumps(data)}\n\n"
class SSEEventBroadcaster:
"""
Manages SSE client connections and broadcasts events.
Each client gets its own asyncio.Queue. When an event occurs,
it's pushed to all relevant queues based on batch_id filtering.
"""
def __init__(self):
# Dict of {client_id: (queue, batch_id_filter)}
# batch_id_filter is None for clients that want all events
self._clients: dict[str, tuple[asyncio.Queue, Optional[str]]] = {}
self._client_counter = 0
self._lock = asyncio.Lock()
async def _generate_client_id(self) -> str:
"""Generate unique client ID."""
async with self._lock:
self._client_counter += 1
return f"client_{self._client_counter}_{datetime.utcnow().timestamp()}"
async def subscribe(
self,
batch_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
"""
Subscribe to SSE events.
Args:
batch_id: Optional filter - only receive events for this batch.
If None, receives all events.
Yields:
SSE-formatted event strings (ready to send to client).
"""
client_id = await self._generate_client_id()
queue: asyncio.Queue = asyncio.Queue()
# Register client
async with self._lock:
self._clients[client_id] = (queue, batch_id)
logger.info(
f"SSE client {client_id} connected (batch_id filter: {batch_id}). "
f"Total clients: {len(self._clients)}"
)
try:
# Send initial retry hint for reconnection
yield "retry: 3000\n\n"
# Keep connection alive and yield events
while True:
try:
# Wait for events with timeout for keep-alive
event = await asyncio.wait_for(queue.get(), timeout=30.0)
yield event
except asyncio.TimeoutError:
# Send keep-alive comment to prevent connection timeout
yield ": keep-alive\n\n"
except asyncio.CancelledError:
logger.info(f"SSE client {client_id} subscription cancelled")
raise
finally:
# Cleanup: remove client from registry
async with self._lock:
self._clients.pop(client_id, None)
logger.info(
f"SSE client {client_id} disconnected. "
f"Remaining clients: {len(self._clients)}"
)
async def broadcast_status_change(
self,
receipt_id: int,
status: str,
processing_status: Optional[str] = None,
batch_id: Optional[str] = None,
) -> int:
"""
Broadcast a status change event to all relevant clients.
Args:
receipt_id: The receipt ID that changed.
status: New workflow status (DRAFT, PENDING_REVIEW, etc.).
processing_status: New processing status (pending, processing, completed, failed).
batch_id: The batch ID this receipt belongs to (for filtering).
Returns:
Number of clients notified.
"""
event = StatusChangeEvent(
receipt_id=receipt_id,
status=status,
processing_status=processing_status,
batch_id=batch_id,
)
sse_data = event.to_sse_data()
notified = 0
async with self._lock:
for client_id, (queue, client_batch_filter) in self._clients.items():
# Send event if:
# 1. Client has no filter (wants all events), OR
# 2. Client's filter matches the event's batch_id
if client_batch_filter is None or client_batch_filter == batch_id:
try:
queue.put_nowait(sse_data)
notified += 1
except asyncio.QueueFull:
logger.warning(
f"SSE queue full for client {client_id}, dropping event"
)
if notified > 0:
logger.debug(
f"SSE broadcast: receipt_id={receipt_id}, status={status}, "
f"processing_status={processing_status}, notified={notified} clients"
)
return notified
@property
def client_count(self) -> int:
"""Get current number of connected clients."""
return len(self._clients)
# Singleton instance for the application
sse_broadcaster = SSEEventBroadcaster()
# Convenience functions for external use
async def subscribe(batch_id: Optional[str] = None) -> AsyncGenerator[str, None]:
"""Subscribe to SSE status change events."""
async for event in sse_broadcaster.subscribe(batch_id):
yield event
async def broadcast_status_change(
receipt_id: int,
status: str,
processing_status: Optional[str] = None,
batch_id: Optional[str] = None,
) -> int:
"""Broadcast a status change event."""
return await sse_broadcaster.broadcast_status_change(
receipt_id=receipt_id,
status=status,
processing_status=processing_status,
batch_id=batch_id,
)

View File

@@ -0,0 +1,451 @@
"""Service for syncing nomenclatures from Oracle to SQLite."""
import sys
from pathlib import Path
from typing import Optional, List, Tuple
from datetime import datetime
import logging
from sqlmodel import select
from sqlalchemy.ext.asyncio import AsyncSession
# Path setup handled by main.py - this is redundant
# project_root = Path(__file__).parent.parent.parent.parent.parent
# sys.path.insert(0, str(project_root / "shared"))
from shared.database.oracle_pool import oracle_pool
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier, SyncedCashRegister
logger = logging.getLogger(__name__)
# Cache for schema lookups (populated dynamically from Oracle)
# Key format: (server_id, company_id) for multi-server support
_schema_cache: dict[tuple, str] = {}
class SyncService:
"""Service for syncing nomenclatures from Oracle."""
@staticmethod
async def get_schema_for_company(company_id: int, server_id: Optional[str] = None) -> Optional[str]:
"""
Get Oracle schema for company ID from V_NOM_FIRME view.
Results are cached in memory for performance.
Args:
company_id: The company ID to look up
server_id: Optional Oracle server ID for multi-server mode
"""
# Check cache first - use (server_id, company_id) as key for multi-server support
cache_key = (server_id, company_id)
if cache_key in _schema_cache:
return _schema_cache[cache_key]
try:
async with oracle_pool.get_connection(server_id) as connection:
with connection.cursor() as cursor:
cursor.execute("""
SELECT SCHEMA
FROM CONTAFIN_ORACLE.V_NOM_FIRME
WHERE ID_FIRMA = :company_id
""", {'company_id': company_id})
result = cursor.fetchone()
if result:
schema = result[0]
_schema_cache[cache_key] = schema
logger.info(f"Resolved schema for company {company_id} on server {server_id}: {schema}")
return schema
else:
logger.warning(f"No schema found for company {company_id} on server {server_id}")
return None
except Exception as e:
logger.error(f"Error fetching schema for company {company_id} on server {server_id}: {e}")
return None
@staticmethod
async def sync_suppliers(session: AsyncSession, company_id: int, server_id: Optional[str] = None) -> Tuple[int, int]:
"""
Sync suppliers (furnizori, id_tip_part=17) from Oracle to SQLite.
Uses CORESP_TIP_PART joined with VNOM_PARTENERI view.
Returns (synced_count, error_count).
Args:
session: SQLAlchemy async session for SQLite
company_id: The company ID to sync suppliers for
server_id: Optional Oracle server ID for multi-server mode
"""
schema = await SyncService.get_schema_for_company(company_id, server_id)
if not schema:
logger.warning(f"No schema mapping for company {company_id} on server {server_id}")
return 0, 0
synced = 0
errors = 0
try:
async with oracle_pool.get_connection(server_id) as connection:
with connection.cursor() as cursor:
# Fetch active suppliers from Oracle
# id_tip_part = 17 means "furnizori" (suppliers)
# Using CORESP_TIP_PART to filter by partner type
cursor.execute(f"""
SELECT B.ID_PART, B.DENUMIRE, B.COD_FISCAL, B.ADRESA
FROM {schema}.CORESP_TIP_PART A
INNER JOIN {schema}.VNOM_PARTENERI B ON A.ID_PART = B.ID_PART
WHERE A.ID_TIP_PART = 17
AND (B.INACTIV = 0 OR B.INACTIV IS NULL)
AND B.ID_PART IS NOT NULL
ORDER BY B.DENUMIRE
""")
rows = cursor.fetchall()
for row in rows:
try:
oracle_id, name, fiscal_code, address = row
# Check if already exists
stmt = select(SyncedSupplier).where(
SyncedSupplier.oracle_id == oracle_id,
SyncedSupplier.company_id == company_id
)
result = await session.execute(stmt)
existing = result.scalar_one_or_none()
if existing:
# Update existing record
existing.name = name or ""
existing.fiscal_code = fiscal_code
existing.address = address
existing.synced_at = datetime.utcnow()
logger.debug(f"Updated supplier {oracle_id}: {name}")
else:
# Create new record
supplier = SyncedSupplier(
oracle_id=oracle_id,
company_id=company_id,
name=name or "",
fiscal_code=fiscal_code,
address=address,
)
session.add(supplier)
logger.debug(f"Created supplier {oracle_id}: {name}")
synced += 1
except Exception as e:
logger.error(f"Error processing supplier row {row}: {e}")
errors += 1
# Commit all changes
await session.commit()
logger.info(f"Synced {synced} suppliers for company {company_id}, {errors} errors")
except Exception as e:
logger.error(f"Error syncing suppliers for company {company_id}: {e}")
errors += 1
await session.rollback()
return synced, errors
@staticmethod
async def sync_cash_registers(session: AsyncSession, company_id: int, server_id: Optional[str] = None) -> Tuple[int, int]:
"""
Sync cash registers and bank accounts from Oracle to SQLite.
Returns (synced_count, error_count).
Uses CORESP_TIP_PART with:
- id_tip_part = 22: CASA LEI
- id_tip_part = 23: CASA VALUTA
- id_tip_part = 24: BANCA LEI
- id_tip_part = 25: BANCA VALUTA
Args:
session: SQLAlchemy async session for SQLite
company_id: The company ID to sync cash registers for
server_id: Optional Oracle server ID for multi-server mode
"""
schema = await SyncService.get_schema_for_company(company_id, server_id)
if not schema:
logger.warning(f"No schema mapping for company {company_id} on server {server_id}")
return 0, 0
synced = 0
errors = 0
# Partner types mapping
# 22=CASA LEI, 23=CASA VALUTA -> cash
# 24=BANCA LEI, 25=BANCA VALUTA -> bank
partner_types = [22, 23, 24, 25]
try:
async with oracle_pool.get_connection(server_id) as connection:
with connection.cursor() as cursor:
# Fetch cash/bank partners from CORESP_TIP_PART
cursor.execute(f"""
SELECT B.ID_PART, B.DENUMIRE, A.ID_TIP_PART
FROM {schema}.CORESP_TIP_PART A
INNER JOIN {schema}.VNOM_PARTENERI B ON A.ID_PART = B.ID_PART
WHERE A.ID_TIP_PART IN (22, 23, 24, 25)
AND (B.INACTIV = 0 OR B.INACTIV IS NULL)
AND B.ID_PART IS NOT NULL
ORDER BY A.ID_TIP_PART, B.DENUMIRE
""")
rows = cursor.fetchall()
# Type mapping: 22=CASA LEI, 23=CASA VALUTA -> cash; 24=BANCA LEI, 25=BANCA VALUTA -> bank
type_mapping = {
22: ("cash", "CASA_LEI"),
23: ("cash", "CASA_VALUTA"),
24: ("bank", "BANCA_LEI"),
25: ("bank", "BANCA_VALUTA"),
}
for row in rows:
try:
oracle_id, name, tip_part_id = row
# Determine type based on partner type
register_type, account_code = type_mapping.get(tip_part_id, ("cash", "UNKNOWN"))
# Check if already exists
stmt = select(SyncedCashRegister).where(
SyncedCashRegister.oracle_id == oracle_id,
SyncedCashRegister.company_id == company_id
)
result = await session.execute(stmt)
existing = result.scalar_one_or_none()
if existing:
# Update existing record
existing.name = name or ""
existing.account_code = account_code
existing.register_type = register_type
existing.synced_at = datetime.utcnow()
logger.debug(f"Updated cash register {oracle_id}: {name}")
else:
# Create new record
cash_register = SyncedCashRegister(
oracle_id=oracle_id,
company_id=company_id,
name=name or "",
account_code=account_code,
register_type=register_type,
)
session.add(cash_register)
logger.debug(f"Created cash register {oracle_id}: {name}")
synced += 1
except Exception as e:
logger.error(f"Error processing cash register row {row}: {e}")
errors += 1
# Commit all changes
await session.commit()
logger.info(f"Synced {synced} cash registers for company {company_id}, {errors} errors")
except Exception as e:
logger.error(f"Error syncing cash registers for company {company_id}: {e}")
errors += 1
await session.rollback()
return synced, errors
@staticmethod
def _get_fiscal_code_variants(fiscal_code: str) -> list:
"""
Generate all possible variants of a Romanian fiscal code (CUI).
Database may store: "22891860", "RO22891860", "RO 22891860"
OCR may extract: "RO22891860" or "22891860"
"""
import re
# Extract just the digits
digits = re.sub(r'[^0-9]', '', fiscal_code)
if not digits:
return [fiscal_code]
# Generate all variants
variants = [
digits, # Just digits: 22891860
f"RO{digits}", # With RO prefix: RO22891860
f"RO {digits}", # With RO prefix and space: RO 22891860
]
# Also add the original if different
if fiscal_code not in variants:
variants.append(fiscal_code)
return variants
@staticmethod
async def search_supplier(
session: AsyncSession,
company_id: int,
fiscal_code: Optional[str] = None,
name: Optional[str] = None
) -> Tuple[bool, Optional[dict], str]:
"""
Search for supplier in SQLite first, then Oracle if not found.
Returns (found, supplier_data, source).
Source can be: 'synced', 'local', 'not_found'
"""
# 1. Search in synced suppliers
if fiscal_code:
# Search all variants of the fiscal code (with/without RO, with/without space)
variants = SyncService._get_fiscal_code_variants(fiscal_code)
stmt = select(SyncedSupplier).where(
SyncedSupplier.company_id == company_id,
SyncedSupplier.fiscal_code.in_(variants)
)
elif name:
stmt = select(SyncedSupplier).where(
SyncedSupplier.company_id == company_id,
SyncedSupplier.name.ilike(f"%{name}%")
)
else:
return False, None, "no_query"
result = await session.execute(stmt)
supplier = result.scalar_one_or_none()
if supplier:
# Return only text data - no IDs needed for autocomplete
return True, {
"name": supplier.name,
"fiscal_code": supplier.fiscal_code,
"address": supplier.address,
}, "synced"
# 2. Search in local suppliers
if fiscal_code:
# Search all variants of the fiscal code (with/without RO, with/without space)
variants = SyncService._get_fiscal_code_variants(fiscal_code)
stmt = select(LocalSupplier).where(
LocalSupplier.company_id == company_id,
LocalSupplier.fiscal_code.in_(variants)
)
elif name:
stmt = select(LocalSupplier).where(
LocalSupplier.company_id == company_id,
LocalSupplier.name.ilike(f"%{name}%")
)
result = await session.execute(stmt)
local = result.scalar_one_or_none()
if local:
# Return only text data - no IDs needed for autocomplete
return True, {
"name": local.name,
"fiscal_code": local.fiscal_code,
"address": local.address,
}, "local"
# 3. Try live Oracle search (optional fallback for unsynced data)
# This is a fallback - ideally sync should be up to date
# TODO: Implement live Oracle search if needed
return False, None, "not_found"
@staticmethod
async def create_local_supplier(
session: AsyncSession,
company_id: int,
name: str,
fiscal_code: Optional[str],
address: Optional[str],
created_by: str
) -> LocalSupplier:
"""Create a local supplier entry from OCR data."""
supplier = LocalSupplier(
company_id=company_id,
name=name,
fiscal_code=fiscal_code,
address=address,
created_by=created_by,
)
session.add(supplier)
await session.commit()
await session.refresh(supplier)
logger.info(f"Created local supplier: {name} (CUI: {fiscal_code})")
return supplier
@staticmethod
async def get_all_suppliers(
session: AsyncSession,
company_id: int,
search: Optional[str] = None
) -> List[dict]:
"""
Get all suppliers (synced + local) for a company.
Used for dropdown/autocomplete in UI.
"""
suppliers = []
# Get synced suppliers
stmt = select(SyncedSupplier).where(SyncedSupplier.company_id == company_id)
if search:
stmt = stmt.where(
(SyncedSupplier.name.ilike(f"%{search}%")) |
(SyncedSupplier.fiscal_code.ilike(f"%{search}%"))
)
stmt = stmt.limit(50) # Limit results for performance
result = await session.execute(stmt)
synced = result.scalars().all()
for s in synced:
suppliers.append({
"id": s.id,
"oracle_id": s.oracle_id,
"name": s.name,
"fiscal_code": s.fiscal_code,
"source": "synced"
})
# Get local suppliers
stmt = select(LocalSupplier).where(LocalSupplier.company_id == company_id)
if search:
stmt = stmt.where(
(LocalSupplier.name.ilike(f"%{search}%")) |
(LocalSupplier.fiscal_code.ilike(f"%{search}%"))
)
stmt = stmt.limit(50)
result = await session.execute(stmt)
local = result.scalars().all()
for l in local:
suppliers.append({
"id": l.id,
"name": l.name,
"fiscal_code": l.fiscal_code,
"source": "local"
})
return suppliers
@staticmethod
async def get_all_cash_registers(
session: AsyncSession,
company_id: int
) -> List[dict]:
"""
Get all cash registers for a company.
Used for dropdown in UI.
"""
stmt = select(SyncedCashRegister).where(SyncedCashRegister.company_id == company_id)
result = await session.execute(stmt)
registers = result.scalars().all()
return [
{
"id": r.id,
"oracle_id": r.oracle_id,
"name": r.name,
"account_code": r.account_code,
"register_type": r.register_type
}
for r in registers
]

View File

@@ -0,0 +1,66 @@
"""
Cache module for ROA2WEB
Provides hybrid two-tier caching (Memory L1 + SQLite L2)
with performance tracking and event-based invalidation.
Usage:
# Initialize cache at app startup
from app.cache import init_cache
from app.cache.config import CacheConfig
config = CacheConfig.from_env()
await init_cache(config)
# Use @cached decorator in services
from app.cache.decorators import cached
@cached(cache_type='dashboard_summary', key_params=['company', 'username'])
async def get_complete_summary(company: str, username: str):
# ... Oracle query logic ...
# Get cache manager for manual operations
from app.cache import get_cache
cache = get_cache()
await cache.invalidate(company_id=123)
"""
from .config import CacheConfig
from .cache_manager import (
init_cache,
get_cache,
close_cache,
CacheManager
)
from .decorators import cached
from .event_monitor import (
init_event_monitor,
get_event_monitor,
toggle_event_monitor,
preload_all_schema_mappings
)
from .benchmarks import run_baseline_benchmarks
__all__ = [
# Configuration
'CacheConfig',
# Cache Manager
'init_cache',
'get_cache',
'close_cache',
'CacheManager',
# Decorators
'cached',
# Event Monitor
'init_event_monitor',
'get_event_monitor',
'toggle_event_monitor',
'preload_all_schema_mappings',
# Benchmarks
'run_baseline_benchmarks',
]

View File

@@ -0,0 +1,269 @@
"""
Baseline performance benchmarking
Runs at startup to establish baseline Oracle query times
Used for calculating "time saved" by cache
"""
import time
import logging
from typing import Dict
logger = logging.getLogger(__name__)
async def run_baseline_benchmarks() -> Dict[str, float]:
"""
Run baseline benchmarks for Oracle queries (without cache)
Measures typical query times to establish performance baselines
These are used to calculate time saved when cache hits occur
NOTE: This implementation provides a framework. Actual benchmark
implementations need access to Oracle services and sample data.
Returns:
Dictionary mapping cache_type to average query time (ms)
"""
from .cache_manager import get_cache
cache = get_cache()
if not cache:
logger.warning("Cache not initialized - skipping benchmarks")
return {}
logger.info("Starting baseline performance benchmarks...")
benchmarks = {}
try:
# Benchmark: Schema lookup
logger.info("Benchmarking: schema lookup")
schema_times = await _benchmark_schema_lookup()
if schema_times:
avg_schema = sum(schema_times) / len(schema_times)
benchmarks['schema'] = avg_schema
await cache.sqlite.set_benchmark('schema', avg_schema, len(schema_times))
logger.info(f" Schema lookup: {avg_schema:.2f}ms (avg of {len(schema_times)} samples)")
# Benchmark: Companies list
logger.info("Benchmarking: companies list")
companies_time = await _benchmark_companies_list()
if companies_time:
benchmarks['companies'] = companies_time
await cache.sqlite.set_benchmark('companies', companies_time, 1)
logger.info(f" Companies list: {companies_time:.2f}ms")
# Benchmark: Dashboard summary
logger.info("Benchmarking: dashboard summary")
dashboard_time = await _benchmark_dashboard_summary()
if dashboard_time:
benchmarks['dashboard_summary'] = dashboard_time
await cache.sqlite.set_benchmark('dashboard_summary', dashboard_time, 1)
logger.info(f" Dashboard summary: {dashboard_time:.2f}ms")
# Benchmark: Dashboard trends
logger.info("Benchmarking: dashboard trends")
trends_time = await _benchmark_dashboard_trends()
if trends_time:
benchmarks['dashboard_trends'] = trends_time
await cache.sqlite.set_benchmark('dashboard_trends', trends_time, 1)
logger.info(f" Dashboard trends: {trends_time:.2f}ms")
# Benchmark: Invoices
logger.info("Benchmarking: invoices")
invoices_time = await _benchmark_invoices()
if invoices_time:
benchmarks['invoices'] = invoices_time
await cache.sqlite.set_benchmark('invoices', invoices_time, 1)
logger.info(f" Invoices: {invoices_time:.2f}ms")
# Benchmark: Treasury
logger.info("Benchmarking: treasury")
treasury_time = await _benchmark_treasury()
if treasury_time:
benchmarks['treasury'] = treasury_time
await cache.sqlite.set_benchmark('treasury', treasury_time, 1)
logger.info(f" Treasury: {treasury_time:.2f}ms")
logger.info(f"Baseline benchmarks completed: {len(benchmarks)} types measured")
return benchmarks
except Exception as e:
logger.error(f"Benchmark error: {e}", exc_info=True)
return benchmarks
async def _benchmark_schema_lookup() -> list:
"""
Benchmark schema lookup queries
Returns:
List of query times (ms) for multiple samples
"""
try:
# Import here to avoid circular dependency
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..')))
from shared.database.oracle_pool import oracle_pool
# Get sample company IDs to test
sample_companies = await _get_sample_company_ids(limit=10)
if not sample_companies:
logger.warning("No sample companies found for schema benchmark")
return []
times = []
for company_id in sample_companies:
start = time.time()
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute("""
SELECT schema
FROM CONTAFIN_ORACLE.v_nom_firme
WHERE id_firma = :id
""", {'id': company_id})
cursor.fetchone()
elapsed_ms = (time.time() - start) * 1000
times.append(elapsed_ms)
return times
except Exception as e:
logger.error(f"Schema benchmark error: {e}")
return []
async def _benchmark_companies_list() -> float:
"""
Benchmark companies list query
Returns:
Query time (ms)
"""
try:
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..')))
from shared.database.oracle_pool import oracle_pool
# Get sample username
sample_user = await _get_sample_username()
if not sample_user:
return 0
start = time.time()
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute("""
SELECT nf.id_firma, nf.denumire, nf.cui, nf.schema
FROM CONTAFIN_ORACLE.v_nom_firme nf
JOIN CONTAFIN_ORACLE.vdef_util_firme uf ON nf.id_firma = uf.id_firma
WHERE uf.nume_utilizator = :username
ORDER BY nf.denumire
""", {'username': sample_user})
cursor.fetchall()
elapsed_ms = (time.time() - start) * 1000
return elapsed_ms
except Exception as e:
logger.error(f"Companies benchmark error: {e}")
return 0
async def _benchmark_dashboard_summary() -> float:
"""
Benchmark dashboard summary query
Returns:
Query time (ms)
"""
try:
# This requires access to DashboardService
# For now, return estimated value
logger.warning("Dashboard summary benchmark not implemented - using estimate")
return 250.0 # Estimated 250ms based on plan
except Exception as e:
logger.error(f"Dashboard benchmark error: {e}")
return 0
async def _benchmark_dashboard_trends() -> float:
"""Benchmark dashboard trends query"""
try:
logger.warning("Dashboard trends benchmark not implemented - using estimate")
return 400.0 # Estimated 400ms
except Exception as e:
logger.error(f"Trends benchmark error: {e}")
return 0
async def _benchmark_invoices() -> float:
"""Benchmark invoices query"""
try:
logger.warning("Invoices benchmark not implemented - using estimate")
return 180.0 # Estimated 180ms
except Exception as e:
logger.error(f"Invoices benchmark error: {e}")
return 0
async def _benchmark_treasury() -> float:
"""Benchmark treasury query"""
try:
logger.warning("Treasury benchmark not implemented - using estimate")
return 250.0 # Estimated 250ms
except Exception as e:
logger.error(f"Treasury benchmark error: {e}")
return 0
# Helper functions
async def _get_sample_company_ids(limit: int = 10) -> list:
"""Get sample company IDs for testing"""
try:
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..')))
from shared.database.oracle_pool import oracle_pool
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute(f"""
SELECT id_firma
FROM CONTAFIN_ORACLE.v_nom_firme
WHERE ROWNUM <= {limit}
""")
results = cursor.fetchall()
return [row[0] for row in results]
except Exception as e:
logger.error(f"Get sample companies error: {e}")
return []
async def _get_sample_username() -> str:
"""Get sample username for testing"""
try:
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..')))
from shared.database.oracle_pool import oracle_pool
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute("""
SELECT nume_utilizator
FROM CONTAFIN_ORACLE.vdef_util_firme
WHERE ROWNUM <= 1
""")
result = cursor.fetchone()
return result[0] if result else "admin"
except Exception as e:
logger.error(f"Get sample username error: {e}")
return "admin"

View File

@@ -0,0 +1,339 @@
"""
Cache Manager - Orchestrator for hybrid L1 + L2 cache
"""
import logging
import asyncio
from typing import Any, Optional
from .config import CacheConfig
from .memory_cache import MemoryCache
from .sqlite_cache import SQLiteCache
logger = logging.getLogger(__name__)
class CacheManager:
"""
Hybrid cache manager (Memory L1 + SQLite L2)
Features:
- Two-tier caching: fast memory + persistent SQLite
- Automatic TTL management per cache type
- Performance tracking and benchmarking
- Per-user cache enable/disable
- Global cache toggle
"""
def __init__(self, config: CacheConfig):
"""
Initialize cache manager
Args:
config: Cache configuration
"""
self.config = config
self.memory = MemoryCache(max_size=config.memory_max_size)
self.sqlite = SQLiteCache(db_path=config.sqlite_path)
self._cleanup_task: Optional[asyncio.Task] = None
self._initialized = False
self._last_cache_source: Optional[str] = None # Track last cache source (L1/L2)
async def init(self):
"""Initialize cache system"""
if self._initialized:
logger.warning("Cache already initialized")
return
# Initialize SQLite database schema
await self.sqlite.init_db()
# Start cleanup task
if self.config.enabled:
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
self._initialized = True
logger.info(f"Cache initialized: type={self.config.cache_type}, enabled={self.config.enabled}")
async def close(self):
"""Close cache and cleanup"""
if self._cleanup_task:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
# Close SQLite connection manager
if hasattr(self.sqlite, 'close'):
await self.sqlite.close()
logger.info("Cache closed")
async def get(self, key: str, cache_type: str) -> Optional[Any]:
"""
Get value from cache (L1 → L2)
Args:
key: Cache key
cache_type: Type of cache entry
Returns:
Cached value or None if not found
"""
if not self.config.enabled:
self._last_cache_source = None
return None
# Try L1 (Memory) first
value = await self.memory.get(key)
if value is not None:
self._last_cache_source = "L1"
logger.debug(f"Cache HIT L1 (memory): {key}")
return value
# Try L2 (SQLite)
value = await self.sqlite.get(key)
if value is not None:
self._last_cache_source = "L2"
logger.debug(f"Cache HIT L2 (sqlite): {key}")
# Populate L1 for next time
ttl = self.config.get_ttl_for_type(cache_type)
await self.memory.set(key, value, ttl)
return value
# Cache MISS
self._last_cache_source = None
logger.debug(f"Cache MISS: {key}")
return None
def get_last_cache_source(self) -> Optional[str]:
"""
Get source of last cache hit (L1/L2/None)
Returns:
"L1" if last hit was from memory cache
"L2" if last hit was from SQLite cache
None if last call was a cache miss or cache disabled
"""
return self._last_cache_source
async def set(self, key: str, value: Any, cache_type: str, company_id: Optional[int] = None,
ttl: Optional[int] = None):
"""
Set value in cache (both L1 and L2)
Args:
key: Cache key
value: Value to cache
cache_type: Type of cache entry
company_id: Company ID (for company-specific caches)
ttl: Time to live (uses default for cache_type if not provided)
"""
if not self.config.enabled:
return
if ttl is None:
ttl = self.config.get_ttl_for_type(cache_type)
# Store in both L1 and L2
await self.memory.set(key, value, ttl)
await self.sqlite.set(key, value, cache_type, company_id, ttl)
logger.debug(f"Cache SET (L1 + L2): {key} (TTL: {ttl}s)")
async def delete(self, key: str):
"""Delete entry from both L1 and L2"""
await self.memory.delete(key)
await self.sqlite.delete(key)
logger.debug(f"Cache deleted: {key}")
async def invalidate(self, company_id: Optional[int] = None, cache_type: Optional[str] = None):
"""
Invalidate cache entries
Args:
company_id: If provided, clear only this company's cache
cache_type: If provided, clear only this cache type
"""
if company_id is not None and cache_type is not None:
# Clear specific company + type
from .keys import generate_key_pattern
pattern = generate_key_pattern(cache_type, company_id)
await self.memory.clear_by_pattern(pattern)
# SQLite: clear by company + type (needs query)
# For now, just clear by company
await self.sqlite.clear_by_company(company_id)
logger.info(f"Cache invalidated: company={company_id}, type={cache_type}")
elif company_id is not None:
# Clear all for company
from .keys import generate_key_pattern
# Clear all types for this company (pattern match all)
# Memory: need to iterate and match company_id in key
# For simplicity, clear by pattern prefix
await self.memory.clear() # TODO: improve pattern matching
await self.sqlite.clear_by_company(company_id)
logger.info(f"Cache invalidated: company={company_id}")
elif cache_type is not None:
# Clear all for type
from .keys import generate_key_pattern
pattern = generate_key_pattern(cache_type)
await self.memory.clear_by_pattern(pattern)
await self.sqlite.clear_by_type(cache_type)
logger.info(f"Cache invalidated: type={cache_type}")
else:
# Clear everything
await self.memory.clear()
await self.sqlite.clear()
logger.info("Cache invalidated: ALL")
async def is_enabled_for_user(self, username: Optional[str]) -> bool:
"""
Check if cache is enabled for specific user
Args:
username: Username to check
Returns:
True if cache enabled for user, False otherwise
"""
if not self.config.enabled:
return False
if username is None:
return True
# Check per-user setting
return await self.sqlite.get_user_cache_enabled(username)
async def set_user_cache_enabled(self, username: str, enabled: bool):
"""Set user cache enabled/disabled"""
await self.sqlite.set_user_cache_enabled(username, enabled)
logger.info(f"User cache setting: {username} -> {enabled}")
# Benchmarking
async def get_benchmark(self, cache_type: str) -> Optional[float]:
"""Get average benchmark time for cache type"""
return await self.sqlite.get_benchmark(cache_type)
async def update_benchmark(self, cache_type: str, new_time_ms: float):
"""
Update benchmark with new measurement (exponential moving average)
Args:
cache_type: Type of cache
new_time_ms: New measured time in milliseconds
"""
current_avg = await self.sqlite.get_benchmark(cache_type)
if current_avg is None:
# First measurement
new_avg = new_time_ms
sample_count = 1
else:
# Exponential moving average (alpha = 0.1)
new_avg = 0.9 * current_avg + 0.1 * new_time_ms
# Get current sample count (TODO: retrieve from DB)
sample_count = 1 # Simplified for now
await self.sqlite.set_benchmark(cache_type, new_avg, sample_count)
logger.debug(f"Benchmark updated: {cache_type} -> {new_avg:.2f}ms")
# Performance Tracking
async def track_performance(self, cache_type: str, is_hit: bool, actual_time_ms: float,
time_saved_ms: Optional[float] = None,
estimated_oracle_time_ms: Optional[float] = None,
company_id: Optional[int] = None,
username: Optional[str] = None):
"""
Track performance metric
Args:
cache_type: Type of cache
is_hit: True if cache hit, False if cache miss
actual_time_ms: Actual response time
time_saved_ms: Time saved by cache (for hits)
estimated_oracle_time_ms: Estimated Oracle time (for hits)
company_id: Company ID
username: Username
"""
if not self.config.track_performance:
return
await self.sqlite.log_performance(
cache_type=cache_type,
company_id=company_id,
cache_hit=is_hit,
response_time_ms=actual_time_ms,
estimated_oracle_time_ms=estimated_oracle_time_ms,
time_saved_ms=time_saved_ms,
username=username
)
# Statistics
async def get_stats(self) -> dict:
"""Get comprehensive cache statistics"""
memory_stats = self.memory.get_stats()
sqlite_stats = await self.sqlite.get_stats()
return {
'enabled': self.config.enabled,
'cache_type': self.config.cache_type,
'memory': memory_stats,
'sqlite': sqlite_stats,
}
# Cleanup
async def _cleanup_loop(self):
"""Background task to cleanup expired entries"""
while True:
try:
await asyncio.sleep(self.config.cleanup_interval)
await self._cleanup_expired()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Cleanup error: {e}", exc_info=True)
async def _cleanup_expired(self):
"""Remove expired entries from both caches"""
logger.info("Running cache cleanup...")
await self.memory.cleanup_expired()
await self.sqlite.cleanup_expired()
logger.info("Cache cleanup completed")
# Global cache manager instance
_cache_manager: Optional[CacheManager] = None
async def init_cache(config: CacheConfig):
"""Initialize global cache manager"""
global _cache_manager
if _cache_manager is not None:
logger.warning("Cache already initialized")
return
_cache_manager = CacheManager(config)
await _cache_manager.init()
logger.info("Global cache manager initialized")
def get_cache() -> Optional[CacheManager]:
"""Get global cache manager instance"""
return _cache_manager
async def close_cache():
"""Close global cache manager"""
global _cache_manager
if _cache_manager is not None:
await _cache_manager.close()
_cache_manager = None

View File

@@ -0,0 +1,89 @@
"""
Cache configuration from environment variables
"""
import os
from dataclasses import dataclass
from typing import Optional
@dataclass
class CacheConfig:
"""Cache configuration loaded from environment variables"""
# Core Settings
enabled: bool
cache_type: str # 'hybrid', 'memory', 'sqlite', 'disabled'
sqlite_path: str
memory_max_size: int
default_ttl: int
# TTL per Cache Type (seconds)
ttl_schema: int
ttl_companies: int
ttl_dashboard_summary: int
ttl_dashboard_trends: int
ttl_invoices: int
ttl_invoices_summary: int
ttl_treasury: int
ttl_trial_balance: int
ttl_calendar_periods: int
# Maintenance
cleanup_interval: int
# Event-Based Invalidation
auto_invalidate_enabled: bool
check_interval: int
# Performance Tracking
track_performance: bool
benchmark_on_startup: bool
@classmethod
def from_env(cls) -> 'CacheConfig':
"""Load configuration from environment variables"""
return cls(
# Core Settings
enabled=os.getenv('CACHE_ENABLED', 'True').lower() == 'true',
cache_type=os.getenv('CACHE_TYPE', 'hybrid'),
sqlite_path=os.getenv('CACHE_SQLITE_PATH', './data/cache/roa2web_cache.db'),
memory_max_size=int(os.getenv('CACHE_MEMORY_MAX_SIZE', '1000')),
default_ttl=int(os.getenv('CACHE_DEFAULT_TTL', '900')),
# TTL per Cache Type
ttl_schema=int(os.getenv('CACHE_TTL_SCHEMA', '86400')),
ttl_companies=int(os.getenv('CACHE_TTL_COMPANIES', '1800')),
ttl_dashboard_summary=int(os.getenv('CACHE_TTL_DASHBOARD_SUMMARY', '1800')),
ttl_dashboard_trends=int(os.getenv('CACHE_TTL_DASHBOARD_TRENDS', '1800')),
ttl_invoices=int(os.getenv('CACHE_TTL_INVOICES', '600')),
ttl_invoices_summary=int(os.getenv('CACHE_TTL_INVOICES_SUMMARY', '900')),
ttl_treasury=int(os.getenv('CACHE_TTL_TREASURY', '600')),
ttl_trial_balance=int(os.getenv('CACHE_TTL_TRIAL_BALANCE', '600')),
ttl_calendar_periods=int(os.getenv('CACHE_TTL_CALENDAR_PERIODS', '3600')),
# Maintenance
cleanup_interval=int(os.getenv('CACHE_CLEANUP_INTERVAL', '3600')),
# Event-Based Invalidation
auto_invalidate_enabled=os.getenv('CACHE_AUTO_INVALIDATE', 'False').lower() == 'true',
check_interval=int(os.getenv('CACHE_CHECK_INTERVAL', '300')),
# Performance Tracking
track_performance=os.getenv('CACHE_TRACK_PERFORMANCE', 'True').lower() == 'true',
benchmark_on_startup=os.getenv('CACHE_BENCHMARK_ON_STARTUP', 'True').lower() == 'true',
)
def get_ttl_for_type(self, cache_type: str) -> int:
"""Get TTL for specific cache type"""
ttl_map = {
'schema': self.ttl_schema,
'companies': self.ttl_companies,
'dashboard_summary': self.ttl_dashboard_summary,
'dashboard_trends': self.ttl_dashboard_trends,
'invoices': self.ttl_invoices,
'invoices_summary': self.ttl_invoices_summary,
'treasury': self.ttl_treasury,
'trial_balance': self.ttl_trial_balance,
'calendar_periods': self.ttl_calendar_periods,
}
return ttl_map.get(cache_type, self.default_ttl)

View File

@@ -0,0 +1,285 @@
"""
Cache decorators for service methods
"""
import time
import logging
import sqlite3
import asyncio
from functools import wraps
from typing import Callable, Optional, List
from .cache_manager import get_cache
from .keys import generate_cache_key
logger = logging.getLogger(__name__)
# Retry configuration for SQLite locked database errors
SQLITE_MAX_RETRIES = 3
SQLITE_RETRY_BASE_DELAY = 0.1 # 100ms base delay, exponential backoff
def cached(cache_type: str, ttl: Optional[int] = None, key_params: Optional[List[str]] = None):
"""
Decorator for caching service method results with performance tracking
Usage:
@cached(cache_type='dashboard_summary', key_params=['company', 'username'])
async def get_complete_summary(company: str, username: str):
# ... Oracle query logic ...
Features:
- Automatic cache key generation from function parameters
- Performance timing (cache hit vs miss)
- Benchmark tracking for time saved calculation
- Per-user cache enable/disable
- Global cache toggle
- Transparent - zero changes to function logic
Args:
cache_type: Type of cache (used for TTL lookup and stats)
ttl: Optional custom TTL (overrides config default)
key_params: List of parameter names to include in cache key
Returns:
Decorated async function
"""
def decorator(func: Callable):
@wraps(func)
async def wrapper(*args, **kwargs):
start_time = time.time()
cache = get_cache()
# Extract username for per-user settings
username = _extract_username(args, kwargs, key_params)
# Check if cache is enabled (global + per-user)
cache_enabled = await cache.is_enabled_for_user(username) if cache else False
if not cache or not cache_enabled:
# Cache disabled - execute directly
result = await func(*args, **kwargs)
elapsed_ms = (time.time() - start_time) * 1000
# Set metadata in request.state if available (for API responses)
if 'request' in kwargs and hasattr(kwargs['request'], 'state'):
kwargs['request'].state.cache_hit = False
kwargs['request'].state.response_time_ms = elapsed_ms
kwargs['request'].state.cache_source = None
if cache and cache.config.track_performance:
await cache.track_performance(
cache_type=cache_type,
is_hit=False,
actual_time_ms=elapsed_ms,
username=username
)
return result
# Generate cache key from function parameters
cache_key = generate_cache_key(cache_type, key_params, args, kwargs)
# Try to get from cache with retry logic for SQLite locks
cached_value = None
for attempt in range(SQLITE_MAX_RETRIES):
try:
cached_value = await cache.get(cache_key, cache_type)
break
except sqlite3.OperationalError as e:
if "database is locked" in str(e) and attempt < SQLITE_MAX_RETRIES - 1:
delay = SQLITE_RETRY_BASE_DELAY * (attempt + 1)
logger.warning(f"SQLite locked on cache.get, retry {attempt + 1}/{SQLITE_MAX_RETRIES} after {delay}s")
await asyncio.sleep(delay)
else:
logger.error(f"SQLite error after {attempt + 1} retries: {e}")
cached_value = None
break
if cached_value is not None:
# ✅ CACHE HIT
elapsed_ms = (time.time() - start_time) * 1000
# Set metadata in request.state if available (for API responses)
if 'request' in kwargs and hasattr(kwargs['request'], 'state'):
cache_source_value = cache.get_last_cache_source() # L1 or L2
kwargs['request'].state.cache_hit = True
kwargs['request'].state.response_time_ms = elapsed_ms
kwargs['request'].state.cache_source = cache_source_value
# Get benchmark for calculating time saved
benchmark = await cache.get_benchmark(cache_type)
time_saved_ms = (benchmark - elapsed_ms) if benchmark else None
# Track performance
if cache.config.track_performance:
await cache.track_performance(
cache_type=cache_type,
is_hit=True,
actual_time_ms=elapsed_ms,
time_saved_ms=time_saved_ms,
estimated_oracle_time_ms=benchmark,
company_id=_extract_company_id(args, kwargs, key_params),
username=username
)
return cached_value
# ❌ CACHE MISS - execute function (query Oracle)
result = await func(*args, **kwargs)
elapsed_ms = (time.time() - start_time) * 1000
# Set metadata in request.state if available (for API responses)
if 'request' in kwargs and hasattr(kwargs['request'], 'state'):
kwargs['request'].state.cache_hit = False
kwargs['request'].state.response_time_ms = elapsed_ms
kwargs['request'].state.cache_source = None
# Update benchmark with real Oracle time
await cache.update_benchmark(cache_type, elapsed_ms)
# Track performance
if cache.config.track_performance:
await cache.track_performance(
cache_type=cache_type,
is_hit=False,
actual_time_ms=elapsed_ms,
company_id=_extract_company_id(args, kwargs, key_params),
username=username
)
# Store in cache for next time (with retry logic for SQLite locks)
company_id = _extract_company_id(args, kwargs, key_params)
for attempt in range(SQLITE_MAX_RETRIES):
try:
await cache.set(cache_key, result, cache_type, company_id, ttl)
break
except sqlite3.OperationalError as e:
if "database is locked" in str(e) and attempt < SQLITE_MAX_RETRIES - 1:
delay = SQLITE_RETRY_BASE_DELAY * (attempt + 1)
logger.warning(f"SQLite locked on cache.set, retry {attempt + 1}/{SQLITE_MAX_RETRIES} after {delay}s")
await asyncio.sleep(delay)
else:
logger.error(f"SQLite error on cache.set after {attempt + 1} retries: {e}")
# Don't fail the request, just skip caching
break
return result
return wrapper
return decorator
def _extract_username(args, kwargs, key_params: Optional[List[str]]) -> Optional[str]:
"""
Extract username from function parameters (args or kwargs)
Checks:
1. key_params position in args (if username is in key_params)
2. Direct username in kwargs
3. current_user object in kwargs
4. user object in kwargs
5. request.state.user (from AuthenticationMiddleware)
Args:
args: Positional arguments
kwargs: Keyword arguments
key_params: List of parameter names (for finding position in args)
Returns:
Username string or None
"""
# Try to find username in args based on key_params position
if key_params and 'username' in key_params:
try:
username_idx = key_params.index('username')
if username_idx < len(args):
username = args[username_idx]
if username:
return str(username)
except (ValueError, IndexError):
pass
# Direct username parameter in kwargs
if 'username' in kwargs:
return kwargs['username']
# Current user object (from FastAPI Depends)
if 'current_user' in kwargs:
user = kwargs['current_user']
if hasattr(user, 'username'):
return user.username
elif isinstance(user, dict) and 'username' in user:
return user['username']
return str(user)
# User object
if 'user' in kwargs:
user = kwargs['user']
if hasattr(user, 'username'):
return user.username
elif isinstance(user, dict) and 'username' in user:
return user['username']
return str(user)
# Extract from request.state.user (set by AuthenticationMiddleware)
if 'request' in kwargs:
request = kwargs['request']
if hasattr(request, 'state') and hasattr(request.state, 'user'):
user = request.state.user
if hasattr(user, 'username'):
return user.username
elif isinstance(user, dict) and 'username' in user:
return user['username']
return None
def _extract_company_id(args, kwargs, key_params: Optional[List[str]]) -> Optional[int]:
"""
Extract company_id from function parameters for cache indexing
Tries multiple approaches:
1. Direct company_id in kwargs
2. company parameter (converted to int)
3. Positional args based on key_params position
Args:
args: Positional arguments
kwargs: Keyword arguments
key_params: List of parameter names
Returns:
Company ID as integer or None
"""
# Try kwargs first
if 'company_id' in kwargs:
try:
return int(kwargs['company_id'])
except (ValueError, TypeError):
pass
if 'company' in kwargs:
try:
return int(kwargs['company'])
except (ValueError, TypeError):
pass
# Try positional args based on key_params
if key_params:
if 'company_id' in key_params:
idx = key_params.index('company_id')
if idx < len(args):
try:
return int(args[idx])
except (ValueError, TypeError):
pass
elif 'company' in key_params:
idx = key_params.index('company')
if idx < len(args):
try:
return int(args[idx])
except (ValueError, TypeError):
pass
return None

View File

@@ -0,0 +1,333 @@
"""
Event-based cache invalidation monitor
Monitors {schema}.act tables for changes and invalidates cache automatically
"""
import asyncio
import logging
import sys
import os
from typing import Optional
# Path setup handled by main.py - this is redundant but kept for module isolation
# sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..')))
logger = logging.getLogger(__name__)
class EventMonitor:
"""
Monitors schema.act tables for changes to trigger cache invalidation
Runs as background task, checking max(id_act) at configured intervals
Uses permanent schema_mappings cache to avoid repeated schema lookups
"""
def __init__(self, cache_manager, config):
"""
Initialize event monitor
Args:
cache_manager: CacheManager instance
config: CacheConfig instance
"""
self.cache = cache_manager
self.config = config
self.running = False
self.task: Optional[asyncio.Task] = None
async def start(self):
"""Start monitoring task"""
if self.running:
logger.warning("Event monitor already running")
return
self.running = True
self.task = asyncio.create_task(self._monitor_loop())
logger.info(
f"Event monitor started (interval: {self.config.check_interval}s)"
)
async def stop(self):
"""Stop monitoring task"""
if not self.running:
return
self.running = False
if self.task:
self.task.cancel()
try:
await self.task
except asyncio.CancelledError:
pass
logger.info("Event monitor stopped")
async def _monitor_loop(self):
"""Main monitoring loop"""
while self.running:
try:
await self._check_all_companies()
await asyncio.sleep(self.config.check_interval)
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Event monitor error: {e}", exc_info=True)
# Wait 1 minute on error before retrying
await asyncio.sleep(60)
async def _check_all_companies(self):
"""
Check all companies with active cache for changes
Queries max(id_act) from {schema}.act for each cached company
and invalidates cache if changes detected
"""
try:
# Get list of companies with active cache entries
cached_companies = await self.cache.sqlite.get_cached_company_ids()
if not cached_companies:
logger.debug("No cached companies to monitor")
return
logger.info(f"Checking {len(cached_companies)} companies for changes...")
invalidated_count = 0
for company_id in cached_companies:
try:
# Check if company data changed
changed = await self._check_company_changes(company_id)
if changed:
# Invalidate cache for this company
await self.cache.invalidate(company_id=company_id)
invalidated_count += 1
logger.info(
f"Cache invalidated for company {company_id} due to act changes"
)
except Exception as e:
# Error for one company shouldn't stop checking others
logger.error(f"Error checking company {company_id}: {e}")
continue
if invalidated_count > 0:
logger.info(
f"Auto-invalidation complete: {invalidated_count} companies affected"
)
except Exception as e:
logger.error(f"Check all companies error: {e}", exc_info=True)
async def _check_company_changes(self, company_id: int) -> bool:
"""
Check if company data changed (monitor max(id_act) in schema.act)
Args:
company_id: Company ID to check
Returns:
True if cache should be invalidated, False otherwise
"""
try:
# 1. Get schema (from permanent cache)
schema = await self._get_schema_for_company(company_id)
if not schema:
logger.warning(f"Schema not found for company {company_id}")
return False
# 2. Get current max(id_act) from Oracle
current_max = await self._get_max_id_act(schema)
# 3. Get cached watermark
cached_watermark = await self.cache.sqlite.get_watermark(company_id)
# 4. Compare
if cached_watermark is None:
# First time checking - store watermark, no invalidation
await self.cache.sqlite.set_watermark(company_id, schema, current_max)
logger.debug(
f"Watermark initialized for company {company_id}: {current_max}"
)
return False
if current_max > cached_watermark:
# Changes detected!
logger.info(
f"Schema {schema} (company {company_id}): "
f"id_act changed {cached_watermark} -> {current_max}"
)
# Update watermark
await self.cache.sqlite.set_watermark(company_id, schema, current_max)
return True # Invalidate cache
# No changes
return False
except Exception as e:
logger.error(f"Check company {company_id} changes error: {e}")
return False # Don't invalidate on error
async def _get_schema_for_company(self, company_id: int) -> Optional[str]:
"""
Get schema for company (with permanent caching)
First checks permanent schema_mappings cache,
falls back to Oracle query if not cached
Args:
company_id: Company ID
Returns:
Schema name or None
"""
# Check permanent cache first
cached_schema = await self.cache.sqlite.get_schema_mapping(company_id)
if cached_schema:
logger.debug(f"Schema mapping HIT for company {company_id}: {cached_schema}")
return cached_schema
# Cache MISS - query Oracle
logger.info(f"Schema mapping MISS for company {company_id}, querying Oracle...")
try:
from shared.database.oracle_pool import oracle_pool
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute("""
SELECT schema
FROM CONTAFIN_ORACLE.v_nom_firme
WHERE id_firma = :id
""", {'id': company_id})
result = cursor.fetchone()
if not result:
logger.warning(f"Company {company_id} not found in v_nom_firme")
return None
schema = result[0]
# Store PERMANENT in schema_mappings (never expires)
await self.cache.sqlite.set_schema_mapping(company_id, schema)
logger.info(f"Schema mapping stored for company {company_id}: {schema}")
return schema
except Exception as e:
logger.error(f"Get schema for company {company_id} error: {e}")
return None
async def _get_max_id_act(self, schema: str) -> int:
"""
Query max(id_act) from {schema}.act
Args:
schema: Schema name
Returns:
Max id_act value (0 if table empty)
"""
try:
from shared.database.oracle_pool import oracle_pool
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
# IMPORTANT: Schema comes from v_nom_firme (trusted source)
# so it's safe from SQL injection
query = f"SELECT MAX(id_act) FROM {schema}.act"
cursor.execute(query)
result = cursor.fetchone()
max_id_act = result[0] if result and result[0] is not None else 0
return max_id_act
except Exception as e:
logger.error(f"Get max_id_act for schema {schema} error: {e}")
return 0
# Optional: Preload all schema mappings at startup
async def preload_all_schema_mappings():
"""
Preload all schema mappings at startup (optional)
Prevents cache misses on first requests by populating
schema_mappings table with all companies
"""
from .cache_manager import get_cache
cache = get_cache()
if not cache:
logger.warning("Cache not initialized - skipping schema preload")
return
logger.info("Preloading all schema mappings...")
try:
from shared.database.oracle_pool import oracle_pool
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute("""
SELECT id_firma, schema
FROM CONTAFIN_ORACLE.v_nom_firme
""")
results = cursor.fetchall()
for id_firma, schema in results:
await cache.sqlite.set_schema_mapping(id_firma, schema)
logger.info(f"Preloaded {len(results)} schema mappings")
except Exception as e:
logger.error(f"Schema preload error: {e}")
# Global event monitor instance
_event_monitor: Optional[EventMonitor] = None
async def init_event_monitor(cache_manager, config):
"""
Initialize global event monitor
Args:
cache_manager: CacheManager instance
config: CacheConfig instance
"""
global _event_monitor
_event_monitor = EventMonitor(cache_manager, config)
# Start if auto-invalidate enabled
if config.auto_invalidate_enabled:
await _event_monitor.start()
def get_event_monitor() -> Optional[EventMonitor]:
"""Get global event monitor instance"""
return _event_monitor
async def toggle_event_monitor(enabled: bool):
"""
Toggle event monitor on/off
Args:
enabled: True to start monitoring, False to stop
"""
monitor = get_event_monitor()
if not monitor:
logger.warning("Event monitor not initialized")
return
if enabled and not monitor.running:
await monitor.start()
elif not enabled and monitor.running:
await monitor.stop()

View File

@@ -0,0 +1,150 @@
"""
Cache key generation utilities
"""
import hashlib
import json
from typing import Any, List, Optional
def generate_cache_key(cache_type: str, key_params: Optional[List[str]], args: tuple, kwargs: dict) -> str:
"""
Generate cache key from function parameters
Format: "{cache_type}:{param1_value}:{param2_value}:..."
Args:
cache_type: Type of cache (e.g., 'dashboard_summary', 'invoices')
key_params: List of parameter names to include in key
args: Positional arguments from function call
kwargs: Keyword arguments from function call
Returns:
Cache key string
Examples:
generate_cache_key('schema', ['company_id'], (123,), {})
-> "schema:123"
generate_cache_key('dashboard_summary', ['company', 'username'], (), {'company': '123', 'username': 'john'})
-> "dashboard_summary:123:john"
generate_cache_key('invoices', ['company', 'invoice_type', 'status'], (123, 'CLIENTI', 'neplatite'), {})
-> "invoices:123:CLIENTI:neplatite"
"""
key_parts = [cache_type]
if not key_params:
# No specific params - use all args/kwargs (fallback)
if args:
key_parts.extend([str(arg) for arg in args])
if kwargs:
# Sort kwargs for consistent key generation
sorted_kwargs = sorted(kwargs.items())
key_parts.extend([f"{k}={v}" for k, v in sorted_kwargs])
else:
# Extract specific params
for i, param_name in enumerate(key_params):
# Try to get from kwargs first
if param_name in kwargs:
value = kwargs[param_name]
# Then try positional args
elif i < len(args):
value = args[i]
else:
# Parameter not found - use placeholder
value = "none"
key_parts.append(str(value))
return ":".join(key_parts)
def generate_key_pattern(cache_type: str, company_id: Optional[int] = None) -> str:
"""
Generate cache key pattern for matching multiple keys
Used for invalidation by type or company
Args:
cache_type: Type of cache
company_id: Optional company ID to filter by
Returns:
Pattern string (prefix)
Examples:
generate_key_pattern('dashboard_summary')
-> "dashboard_summary:"
generate_key_pattern('dashboard_summary', 123)
-> "dashboard_summary:123"
"""
if company_id is not None:
return f"{cache_type}:{company_id}"
return f"{cache_type}:"
def hash_complex_params(params: dict) -> str:
"""
Generate hash for complex parameters (e.g., filters, queries)
Used when cache key would be too long with full param values
Args:
params: Dictionary of parameters to hash
Returns:
8-character hash string
Example:
filters = {'status': 'neplatite', 'date_from': '2024-01-01', 'date_to': '2024-12-31'}
hash_complex_params(filters)
-> "a3f8b2c1"
"""
# Sort keys for consistent hashing
sorted_params = json.dumps(params, sort_keys=True)
hash_obj = hashlib.sha256(sorted_params.encode())
# Return first 8 characters of hex digest
return hash_obj.hexdigest()[:8]
def extract_company_id_from_key(cache_key: str) -> Optional[int]:
"""
Extract company_id from cache key
Assumes format: "cache_type:company_id:..."
Args:
cache_key: Cache key string
Returns:
Company ID or None if not found
Example:
extract_company_id_from_key("dashboard_summary:123:john")
-> 123
"""
parts = cache_key.split(":")
if len(parts) >= 2:
try:
return int(parts[1])
except (ValueError, TypeError):
pass
return None
def extract_cache_type_from_key(cache_key: str) -> str:
"""
Extract cache_type from cache key
Args:
cache_key: Cache key string
Returns:
Cache type (first part before colon)
Example:
extract_cache_type_from_key("dashboard_summary:123:john")
-> "dashboard_summary"
"""
return cache_key.split(":")[0]

View File

@@ -0,0 +1,180 @@
"""
In-memory cache with TTL (L1 cache)
Fast, limited size, lost on restart
"""
import time
import logging
from typing import Any, Optional, Dict
from collections import OrderedDict
logger = logging.getLogger(__name__)
class MemoryCache:
"""
In-memory LRU cache with TTL support
Features:
- LRU eviction when max_size reached
- Per-entry TTL expiration
- Thread-safe operations
- Fast O(1) get/set operations
"""
def __init__(self, max_size: int = 1000):
"""
Initialize memory cache
Args:
max_size: Maximum number of entries to store
"""
self.max_size = max_size
self._cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
self._stats = {
'hits': 0,
'misses': 0,
'sets': 0,
'evictions': 0
}
async def get(self, key: str) -> Optional[Any]:
"""
Get value from cache
Args:
key: Cache key
Returns:
Cached value or None if not found/expired
"""
if key not in self._cache:
self._stats['misses'] += 1
return None
entry = self._cache[key]
# Check TTL expiration
if entry['expires_at'] < time.time():
# Expired - remove and return None
del self._cache[key]
self._stats['misses'] += 1
logger.debug(f"Memory cache expired: {key}")
return None
# Move to end (LRU - most recently used)
self._cache.move_to_end(key)
self._stats['hits'] += 1
logger.debug(f"Memory cache HIT: {key}")
return entry['value']
async def set(self, key: str, value: Any, ttl: int):
"""
Set value in cache
Args:
key: Cache key
value: Value to cache
ttl: Time to live in seconds
"""
expires_at = time.time() + ttl
# Check if we need to evict (LRU)
if key not in self._cache and len(self._cache) >= self.max_size:
# Evict oldest entry (first item in OrderedDict)
evicted_key = next(iter(self._cache))
del self._cache[evicted_key]
self._stats['evictions'] += 1
logger.debug(f"Memory cache evicted (LRU): {evicted_key}")
# Store entry
self._cache[key] = {
'value': value,
'expires_at': expires_at,
'created_at': time.time()
}
# Move to end (most recently used)
self._cache.move_to_end(key)
self._stats['sets'] += 1
logger.debug(f"Memory cache SET: {key} (TTL: {ttl}s)")
async def delete(self, key: str) -> bool:
"""
Delete entry from cache
Args:
key: Cache key
Returns:
True if deleted, False if not found
"""
if key in self._cache:
del self._cache[key]
logger.debug(f"Memory cache deleted: {key}")
return True
return False
async def clear(self):
"""Clear all entries from cache"""
count = len(self._cache)
self._cache.clear()
logger.info(f"Memory cache cleared: {count} entries removed")
async def clear_by_pattern(self, pattern: str):
"""
Clear entries matching pattern (simple prefix match)
Args:
pattern: Key prefix to match (e.g., "dashboard_summary:123")
"""
keys_to_delete = [key for key in self._cache.keys() if key.startswith(pattern)]
for key in keys_to_delete:
del self._cache[key]
logger.info(f"Memory cache cleared by pattern '{pattern}': {len(keys_to_delete)} entries")
async def cleanup_expired(self):
"""Remove all expired entries"""
now = time.time()
expired_keys = [
key for key, entry in self._cache.items()
if entry['expires_at'] < now
]
for key in expired_keys:
del self._cache[key]
if expired_keys:
logger.info(f"Memory cache cleanup: {len(expired_keys)} expired entries removed")
def get_stats(self) -> Dict[str, Any]:
"""
Get cache statistics
Returns:
Dictionary with stats (hits, misses, size, etc.)
"""
total_requests = self._stats['hits'] + self._stats['misses']
hit_rate = (self._stats['hits'] / total_requests * 100) if total_requests > 0 else 0
return {
'size': len(self._cache),
'max_size': self.max_size,
'hits': self._stats['hits'],
'misses': self._stats['misses'],
'sets': self._stats['sets'],
'evictions': self._stats['evictions'],
'hit_rate': hit_rate,
'total_requests': total_requests
}
def reset_stats(self):
"""Reset statistics counters"""
self._stats = {
'hits': 0,
'misses': 0,
'sets': 0,
'evictions': 0
}

View File

@@ -0,0 +1,594 @@
"""
SQLite persistent cache (L2 cache)
Persistent, survives restarts, unlimited size
Uses singleton connection pattern with asyncio.Lock for write serialization
to prevent "database is locked" errors under concurrent access.
"""
import time
import json
import logging
import asyncio
import aiosqlite
from typing import Any, Optional, List, Dict
from pathlib import Path
from decimal import Decimal
from datetime import datetime, date
# SQLite busy timeout in milliseconds (wait for lock instead of failing immediately)
SQLITE_BUSY_TIMEOUT_MS = 5000
logger = logging.getLogger(__name__)
class CustomJSONEncoder(json.JSONEncoder):
"""Custom JSON encoder that handles Pydantic models, Decimal, datetime, etc."""
def default(self, obj):
# Handle Pydantic models
if hasattr(obj, 'dict'):
return obj.dict()
if hasattr(obj, 'model_dump'): # Pydantic v2
return obj.model_dump()
# Handle Decimal
if isinstance(obj, Decimal):
return float(obj)
# Handle datetime/date
if isinstance(obj, (datetime, date)):
return obj.isoformat()
return super().default(obj)
class SQLiteConnectionManager:
"""
Singleton connection manager with write serialization.
Solves "database is locked" errors by:
1. Maintaining a single persistent connection (instead of N connections per request)
2. Serializing all write operations through an asyncio.Lock
3. Using WAL mode for better concurrent read performance
Architecture:
┌─────────────────────────────────────┐
│ SQLiteConnectionManager │
│ (SINGLETON) │
│ │
│ _connection: aiosqlite.Connection │
│ _write_lock: asyncio.Lock │
└─────────────────────────────────────┘
┌───────────────┼───────────────┐
▼ ▼ ▼
Task 1 Task 2 Task N
cache.get() cache.set() cache.get()
│ │ │
└───────────────┴───────────────┘
async with _write_lock:
(serialized writes)
"""
_instance: Optional['SQLiteConnectionManager'] = None
_instance_lock: asyncio.Lock = None # Will be created on first use
def __init__(self, db_path: str):
"""
Initialize connection manager (called only by get_instance).
Args:
db_path: Path to SQLite database file
"""
self.db_path = db_path
self._connection: Optional[aiosqlite.Connection] = None
self._write_lock: Optional[asyncio.Lock] = None
self._initialized = False
@classmethod
async def get_instance(cls, db_path: str) -> 'SQLiteConnectionManager':
"""
Get or create singleton instance.
Thread-safe singleton pattern using asyncio.Lock.
Args:
db_path: Path to SQLite database file
Returns:
SQLiteConnectionManager singleton instance
"""
# Create instance lock on first call (must be done in async context)
if cls._instance_lock is None:
cls._instance_lock = asyncio.Lock()
async with cls._instance_lock:
if cls._instance is None or cls._instance.db_path != db_path:
cls._instance = cls(db_path)
return cls._instance
async def initialize(self):
"""
Create connection with WAL mode and busy timeout.
Sets up:
- Busy timeout (5 seconds) - wait for locks instead of failing
- WAL journal mode - allows concurrent reads while writing
- Write lock for serializing write operations
"""
if self._initialized:
return
# Create write lock in async context
self._write_lock = asyncio.Lock()
# Create persistent connection
self._connection = await aiosqlite.connect(self.db_path)
await self._connection.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
await self._connection.execute("PRAGMA journal_mode=WAL")
await self._connection.commit()
self._initialized = True
logger.info(f"SQLite connection manager initialized: {self.db_path}")
async def get_connection(self) -> aiosqlite.Connection:
"""
Get the persistent connection, with health check.
If connection is unhealthy (closed or stale), reconnects automatically.
Returns:
Active aiosqlite connection
"""
if self._connection is None or not await self._is_healthy():
await self._reconnect()
return self._connection
async def _is_healthy(self) -> bool:
"""
Check if connection is valid.
Returns:
True if connection can execute queries, False otherwise
"""
try:
async with self._connection.execute("SELECT 1") as cursor:
await cursor.fetchone()
return True
except Exception:
return False
async def _reconnect(self):
"""Reconnect if connection was lost."""
logger.warning("SQLite connection unhealthy, reconnecting...")
# Close old connection if exists
if self._connection:
try:
await self._connection.close()
except Exception:
pass
# Create new connection
self._connection = await aiosqlite.connect(self.db_path)
await self._connection.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}")
await self._connection.execute("PRAGMA journal_mode=WAL")
await self._connection.commit()
logger.info("SQLite connection re-established")
@property
def write_lock(self) -> asyncio.Lock:
"""Get the write lock for serializing write operations."""
return self._write_lock
async def close(self):
"""Close the connection and reset singleton."""
if self._connection:
try:
await self._connection.close()
except Exception as e:
logger.warning(f"Error closing SQLite connection: {e}")
self._connection = None
self._initialized = False
# Reset singleton
SQLiteConnectionManager._instance = None
logger.info("SQLite connection manager closed")
class SQLiteCache:
"""
SQLite-based persistent cache
Features:
- Persistent storage (survives restarts)
- JSON serialization for complex objects
- Schema mappings (permanent cache for company->schema)
- Watermarks for event-based invalidation
- Performance tracking and benchmarks
- Singleton connection with write serialization (prevents "database is locked")
"""
def __init__(self, db_path: str):
"""
Initialize SQLite cache
Args:
db_path: Path to SQLite database file
"""
self.db_path = db_path
self._conn_manager: Optional[SQLiteConnectionManager] = None
self._ensure_db_dir()
def _ensure_db_dir(self):
"""Ensure database directory exists"""
db_dir = Path(self.db_path).parent
db_dir.mkdir(parents=True, exist_ok=True)
async def init_db(self):
"""Initialize database schema with WAL mode enabled"""
# Get or create singleton connection manager
self._conn_manager = await SQLiteConnectionManager.get_instance(self.db_path)
await self._conn_manager.initialize()
# Create tables using the persistent connection
async with self._conn_manager.write_lock:
conn = await self._conn_manager.get_connection()
# Table: cache_entries
await conn.execute("""
CREATE TABLE IF NOT EXISTS cache_entries (
cache_key TEXT PRIMARY KEY,
cache_type TEXT NOT NULL,
company_id INTEGER,
data_json TEXT NOT NULL,
created_at REAL NOT NULL,
expires_at REAL NOT NULL,
hit_count INTEGER DEFAULT 0,
last_accessed REAL
)
""")
await conn.execute("CREATE INDEX IF NOT EXISTS idx_cache_type ON cache_entries(cache_type)")
await conn.execute("CREATE INDEX IF NOT EXISTS idx_company_id ON cache_entries(company_id)")
await conn.execute("CREATE INDEX IF NOT EXISTS idx_expires_at ON cache_entries(expires_at)")
# Table: schema_mappings (PERMANENT)
await conn.execute("""
CREATE TABLE IF NOT EXISTS schema_mappings (
id_firma INTEGER PRIMARY KEY,
schema TEXT NOT NULL,
created_at REAL NOT NULL,
last_verified REAL
)
""")
# Table: query_benchmarks
await conn.execute("""
CREATE TABLE IF NOT EXISTS query_benchmarks (
cache_type TEXT PRIMARY KEY,
avg_time_ms REAL NOT NULL,
sample_count INTEGER DEFAULT 0,
last_updated REAL
)
""")
# Table: performance_log
await conn.execute("""
CREATE TABLE IF NOT EXISTS performance_log (
id INTEGER PRIMARY KEY AUTOINCREMENT,
cache_type TEXT NOT NULL,
company_id INTEGER,
cache_hit BOOLEAN NOT NULL,
response_time_ms REAL NOT NULL,
estimated_oracle_time_ms REAL,
time_saved_ms REAL,
username TEXT,
timestamp REAL NOT NULL
)
""")
await conn.execute("CREATE INDEX IF NOT EXISTS idx_perf_timestamp ON performance_log(timestamp)")
await conn.execute("CREATE INDEX IF NOT EXISTS idx_perf_cache_type ON performance_log(cache_type)")
# Table: user_cache_settings
await conn.execute("""
CREATE TABLE IF NOT EXISTS user_cache_settings (
username TEXT PRIMARY KEY,
cache_enabled BOOLEAN DEFAULT TRUE,
created_at REAL,
updated_at REAL
)
""")
# Table: cache_config
await conn.execute("""
CREATE TABLE IF NOT EXISTS cache_config (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
updated_at REAL
)
""")
# Table: cache_watermarks
await conn.execute("""
CREATE TABLE IF NOT EXISTS cache_watermarks (
company_id INTEGER PRIMARY KEY,
schema TEXT NOT NULL,
max_id_act INTEGER NOT NULL,
checked_at REAL NOT NULL
)
""")
await conn.commit()
logger.info("SQLite cache database initialized")
async def get(self, key: str) -> Optional[Any]:
"""
Get value from cache
Args:
key: Cache key
Returns:
Cached value or None if not found/expired
"""
# Use write lock because we may update hit_count or delete expired entries
async with self._conn_manager.write_lock:
conn = await self._conn_manager.get_connection()
async with conn.execute("""
SELECT data_json, expires_at
FROM cache_entries
WHERE cache_key = ?
""", (key,)) as cursor:
result = await cursor.fetchone()
if not result:
return None
data_json, expires_at = result
# Check TTL expiration
if expires_at < time.time():
# Expired - delete and return None
await conn.execute("DELETE FROM cache_entries WHERE cache_key = ?", (key,))
await conn.commit()
logger.debug(f"SQLite cache expired: {key}")
return None
# Update hit_count and last_accessed
await conn.execute("""
UPDATE cache_entries
SET hit_count = hit_count + 1, last_accessed = ?
WHERE cache_key = ?
""", (time.time(), key))
await conn.commit()
logger.debug(f"SQLite cache HIT: {key}")
return json.loads(data_json)
async def set(self, key: str, value: Any, cache_type: str, company_id: Optional[int], ttl: int):
"""
Set value in cache
Args:
key: Cache key
value: Value to cache
cache_type: Type of cache entry
company_id: Company ID (None for global caches)
ttl: Time to live in seconds
"""
# Use custom encoder to handle Pydantic models, Decimal, datetime, etc.
data_json = json.dumps(value, cls=CustomJSONEncoder)
now = time.time()
expires_at = now + ttl
async with self._conn_manager.write_lock:
conn = await self._conn_manager.get_connection()
await conn.execute("""
INSERT OR REPLACE INTO cache_entries
(cache_key, cache_type, company_id, data_json, created_at, expires_at, hit_count, last_accessed)
VALUES (?, ?, ?, ?, ?, ?, 0, ?)
""", (key, cache_type, company_id, data_json, now, expires_at, now))
await conn.commit()
logger.debug(f"SQLite cache SET: {key} (TTL: {ttl}s)")
async def delete(self, key: str) -> bool:
"""Delete entry from cache"""
async with self._conn_manager.write_lock:
conn = await self._conn_manager.get_connection()
cursor = await conn.execute("DELETE FROM cache_entries WHERE cache_key = ?", (key,))
await conn.commit()
deleted = cursor.rowcount > 0
if deleted:
logger.debug(f"SQLite cache deleted: {key}")
return deleted
async def clear(self):
"""Clear all cache entries"""
async with self._conn_manager.write_lock:
conn = await self._conn_manager.get_connection()
cursor = await conn.execute("DELETE FROM cache_entries")
await conn.commit()
count = cursor.rowcount
logger.info(f"SQLite cache cleared: {count} entries removed")
async def clear_by_company(self, company_id: int):
"""Clear all entries for specific company"""
async with self._conn_manager.write_lock:
conn = await self._conn_manager.get_connection()
cursor = await conn.execute("DELETE FROM cache_entries WHERE company_id = ?", (company_id,))
await conn.commit()
count = cursor.rowcount
logger.info(f"SQLite cache cleared for company {company_id}: {count} entries")
async def clear_by_type(self, cache_type: str):
"""Clear all entries of specific type"""
async with self._conn_manager.write_lock:
conn = await self._conn_manager.get_connection()
cursor = await conn.execute("DELETE FROM cache_entries WHERE cache_type = ?", (cache_type,))
await conn.commit()
count = cursor.rowcount
logger.info(f"SQLite cache cleared for type '{cache_type}': {count} entries")
async def cleanup_expired(self):
"""Remove all expired entries"""
async with self._conn_manager.write_lock:
conn = await self._conn_manager.get_connection()
cursor = await conn.execute("DELETE FROM cache_entries WHERE expires_at < ?", (time.time(),))
await conn.commit()
count = cursor.rowcount
if count > 0:
logger.info(f"SQLite cache cleanup: {count} expired entries removed")
# Schema Mappings (PERMANENT)
async def get_schema_mapping(self, company_id: int) -> Optional[str]:
"""Get permanent cached schema for company (READ-ONLY, no lock needed)"""
conn = await self._conn_manager.get_connection()
async with conn.execute("""
SELECT schema
FROM schema_mappings
WHERE id_firma = ?
""", (company_id,)) as cursor:
result = await cursor.fetchone()
return result[0] if result else None
async def set_schema_mapping(self, company_id: int, schema: str):
"""Set permanent schema mapping (never expires)"""
async with self._conn_manager.write_lock:
conn = await self._conn_manager.get_connection()
await conn.execute("""
INSERT OR REPLACE INTO schema_mappings
(id_firma, schema, created_at, last_verified)
VALUES (?, ?, ?, ?)
""", (company_id, schema, time.time(), time.time()))
await conn.commit()
# Benchmarks
async def get_benchmark(self, cache_type: str) -> Optional[float]:
"""Get average benchmark time for cache type (READ-ONLY, no lock needed)"""
conn = await self._conn_manager.get_connection()
async with conn.execute("""
SELECT avg_time_ms
FROM query_benchmarks
WHERE cache_type = ?
""", (cache_type,)) as cursor:
result = await cursor.fetchone()
return result[0] if result else None
async def set_benchmark(self, cache_type: str, avg_time_ms: float, sample_count: int):
"""Set/update benchmark"""
async with self._conn_manager.write_lock:
conn = await self._conn_manager.get_connection()
await conn.execute("""
INSERT OR REPLACE INTO query_benchmarks
(cache_type, avg_time_ms, sample_count, last_updated)
VALUES (?, ?, ?, ?)
""", (cache_type, avg_time_ms, sample_count, time.time()))
await conn.commit()
# Performance Tracking
async def log_performance(self, cache_type: str, company_id: Optional[int], cache_hit: bool,
response_time_ms: float, estimated_oracle_time_ms: Optional[float],
time_saved_ms: Optional[float], username: Optional[str]):
"""Log performance metric"""
async with self._conn_manager.write_lock:
conn = await self._conn_manager.get_connection()
await conn.execute("""
INSERT INTO performance_log
(cache_type, company_id, cache_hit, response_time_ms, estimated_oracle_time_ms,
time_saved_ms, username, timestamp)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""", (cache_type, company_id, cache_hit, response_time_ms, estimated_oracle_time_ms,
time_saved_ms, username, time.time()))
await conn.commit()
# User Settings
async def get_user_cache_enabled(self, username: str) -> bool:
"""Get user cache setting (default True) - READ-ONLY, no lock needed"""
conn = await self._conn_manager.get_connection()
async with conn.execute("""
SELECT cache_enabled
FROM user_cache_settings
WHERE username = ?
""", (username,)) as cursor:
result = await cursor.fetchone()
return bool(result[0]) if result else True # Default enabled, explicit bool conversion
async def set_user_cache_enabled(self, username: str, enabled: bool):
"""Set user cache setting"""
async with self._conn_manager.write_lock:
conn = await self._conn_manager.get_connection()
await conn.execute("""
INSERT OR REPLACE INTO user_cache_settings
(username, cache_enabled, created_at, updated_at)
VALUES (?, ?, ?, ?)
""", (username, enabled, time.time(), time.time()))
await conn.commit()
# Watermarks
async def get_watermark(self, company_id: int) -> Optional[int]:
"""Get cached watermark (max_id_act) for company - READ-ONLY, no lock needed"""
conn = await self._conn_manager.get_connection()
async with conn.execute("""
SELECT max_id_act
FROM cache_watermarks
WHERE company_id = ?
""", (company_id,)) as cursor:
result = await cursor.fetchone()
return result[0] if result else None
async def set_watermark(self, company_id: int, schema: str, max_id_act: int):
"""Set/update watermark for company"""
async with self._conn_manager.write_lock:
conn = await self._conn_manager.get_connection()
await conn.execute("""
INSERT OR REPLACE INTO cache_watermarks
(company_id, schema, max_id_act, checked_at)
VALUES (?, ?, ?, ?)
""", (company_id, schema, max_id_act, time.time()))
await conn.commit()
async def get_cached_company_ids(self) -> List[int]:
"""Get list of company_ids with active cache entries - READ-ONLY, no lock needed"""
conn = await self._conn_manager.get_connection()
async with conn.execute("""
SELECT DISTINCT company_id
FROM cache_entries
WHERE company_id IS NOT NULL
AND expires_at > ?
""", (time.time(),)) as cursor:
results = await cursor.fetchall()
return [row[0] for row in results]
# Statistics
async def get_stats(self) -> Dict[str, Any]:
"""Get cache statistics - READ-ONLY, no lock needed"""
conn = await self._conn_manager.get_connection()
# Total entries
async with conn.execute("SELECT COUNT(*) FROM cache_entries") as cursor:
total_entries = (await cursor.fetchone())[0]
# Active entries (not expired)
async with conn.execute("""
SELECT COUNT(*) FROM cache_entries WHERE expires_at > ?
""", (time.time(),)) as cursor:
active_entries = (await cursor.fetchone())[0]
return {
'total_entries': total_entries,
'active_entries': active_entries,
'expired_entries': total_entries - active_entries
}
async def close(self):
"""Close the connection manager"""
if self._conn_manager:
await self._conn_manager.close()
self._conn_manager = None

View File

@@ -0,0 +1,19 @@
"""
Calendar period models for accounting period selector
"""
from pydantic import BaseModel
from typing import List, Optional
class CalendarPeriod(BaseModel):
"""Model for an accounting period"""
an: int # Year
luna: int # Month (1-12)
display_name: str # Format: "Decembrie 2025"
class CalendarPeriodsResponse(BaseModel):
"""Response model for calendar periods list"""
periods: List[CalendarPeriod]
current_period: Optional[CalendarPeriod] = None # Most recent period
total_count: int

View File

@@ -0,0 +1,156 @@
from pydantic import BaseModel
from decimal import Decimal
from typing import List, Dict, Optional, Any
class BudgetDebtSubAccount(BaseModel):
"""Cont individual din cadrul unui grup de datorii buget"""
cont: str # ex: "4311"
label: str # ex: "4311 - CAS angajat"
precedent: Decimal # sold luna precedentă (pozitiv=datorie, negativ=creanță)
curent: Decimal # sold luna curentă (pozitiv=datorie, negativ=creanță)
datorat: Decimal = Decimal('0') # datorie din luna precedentă (= preccred - precdeb)
achitat: Decimal = Decimal('0') # plăți efectuate luna curentă (= ruldeb)
sold: Decimal = Decimal('0') # sold final real (= soldcred - solddeb)
class BudgetDebtGroup(BaseModel):
"""Grup de datorii la buget (TVA / BASS / CAM)"""
key: str # 'TVA', 'BASS', 'CAM'
label: str # 'TVA', 'BASS', 'CAM'
precedent: Decimal # total grup luna prec (semn ±)
curent: Decimal # total grup luna crt (semn ±)
sub_accounts: List[BudgetDebtSubAccount] = []
datorat: Decimal = Decimal('0') # total datorie grup luna precedentă
achitat: Decimal = Decimal('0') # total plăți grup luna curentă
sold: Decimal = Decimal('0') # sold final real al grupului
class TreasuryAccount(BaseModel):
"""Cont de trezorerie (bancă/casă)"""
cont: str # 5121, 5124, 5311, 5314
nume_cont: str # "Bancă LEI", "Casă VALUTA" etc
nume_banca: str # Numele băncii din vbalanta_parteneri.nume
sold: Decimal
valuta: str
class TrendData(BaseModel):
"""Model pentru datele de trend - MODEL VECHI"""
labels: List[str]
incasari: List[Decimal]
plati: List[Decimal]
trezorerie: List[Decimal]
incasari_total: Decimal
plati_total: Decimal
trezorerie_total: Decimal
incasari_change: Optional[float] = None
plati_change: Optional[float] = None
trezorerie_change: Optional[float] = None
class TrendsResponse(BaseModel):
"""Model pentru răspunsul endpoint-ului de trenduri - MODEL NOU"""
# Current period data
periods: List[str]
clienti_facturat: List[float]
clienti_incasat: List[float]
furnizori_facturat: List[float]
furnizori_achitat: List[float]
clienti_sold: List[float]
furnizori_sold: List[float]
trezorerie_sold: Optional[List[float]] = None
rata_incasare_clienti: List[float]
rata_achitare_furnizori: List[float]
# Previous period data (for year-over-year comparison in sparklines)
previous_periods: Optional[List[str]] = None
clienti_facturat_prev: Optional[List[float]] = None
clienti_incasat_prev: Optional[List[float]] = None
furnizori_facturat_prev: Optional[List[float]] = None
furnizori_achitat_prev: Optional[List[float]] = None
clienti_sold_prev: Optional[List[float]] = None
furnizori_sold_prev: Optional[List[float]] = None
trezorerie_sold_prev: Optional[List[float]] = None
# Metadata and analytics
metadata: Dict[str, Any]
growth_rates: Optional[Dict[str, float]] = None
# Cache metadata (optional, for Telegram Bot)
cache_hit: Optional[bool] = None
response_time_ms: Optional[float] = None
cache_source: Optional[str] = None
class DashboardSummary(BaseModel):
"""Model pentru toate datele dashboard-ului"""
# CLIENȚI - statistici existente
clienti_total_facturat: Decimal # precdeb + debit (conturi 4111, 461)
clienti_total_incasat: Decimal # preccred + credit (conturi 4111, 461)
clienti_avansuri: Decimal # sold 419 (pasiv): credit - debit
clienti_sold_total: Decimal # (facturat - incasat) - avansuri
clienti_sold_restant: Decimal # sold cu datascad < azi
# CLIENȚI - NOI câmpuri pentru sold în termen
clienti_sold_in_termen: Decimal # sold cu datascad >= azi
# CLIENȚI - NOI detalieri restanțe (sold cu datascad < azi)
clienti_restant_7: Decimal # restant 1-7 zile
clienti_restant_14: Decimal # restant 8-14 zile
clienti_restant_30: Decimal # restant 15-30 zile
clienti_restant_60: Decimal # restant 31-60 zile
clienti_restant_90: Decimal # restant 61-90 zile
clienti_restant_90plus: Decimal # restant 90+ zile
# CLIENȚI - NOI detalieri scadențe (sold cu datascad >= azi)
clienti_scadent_7: Decimal # scadent în 1-7 zile
clienti_scadent_14: Decimal # scadent în 8-14 zile
clienti_scadent_30: Decimal # scadent în 15-30 zile
clienti_scadent_60: Decimal # scadent în 31-60 zile
clienti_scadent_90: Decimal # scadent în 61-90 zile
clienti_scadent_90plus: Decimal # scadent în 90+ zile
# FURNIZORI - statistici existente
furnizori_total_facturat: Decimal # preccred + credit (conturi 401, 404, 462)
furnizori_total_achitat: Decimal # precdeb + debit (conturi 401, 404, 462)
furnizori_avansuri: Decimal # sold 409x (activ): debit - credit
furnizori_sold_total: Decimal # (facturat - achitat) - avansuri
furnizori_sold_restant: Decimal # sold cu datascad < azi
# FURNIZORI - NOI câmpuri pentru sold în termen
furnizori_sold_in_termen: Decimal # sold cu datascad >= azi
# FURNIZORI - NOI detalieri restanțe (sold cu datascad < azi)
furnizori_restant_7: Decimal # restant 1-7 zile
furnizori_restant_14: Decimal # restant 8-14 zile
furnizori_restant_30: Decimal # restant 15-30 zile
furnizori_restant_60: Decimal # restant 31-60 zile
furnizori_restant_90: Decimal # restant 61-90 zile
furnizori_restant_90plus: Decimal # restant 90+ zile
# FURNIZORI - NOI detalieri scadențe (sold cu datascad >= azi)
furnizori_scadent_7: Decimal # scadent în 1-7 zile
furnizori_scadent_14: Decimal # scadent în 8-14 zile
furnizori_scadent_30: Decimal # scadent în 15-30 zile
furnizori_scadent_60: Decimal # scadent în 31-60 zile
furnizori_scadent_90: Decimal # scadent în 61-90 zile
furnizori_scadent_90plus: Decimal # scadent în 90+ zile
# TREZORERIE - existente
treasury_accounts: List[TreasuryAccount]
treasury_totals_by_currency: Dict[str, Decimal]
# DATE SUPLIMENTARE pentru trend analysis
clienti_facturat_luna_anterioara: Optional[Decimal] = Decimal('0')
furnizori_facturat_luna_anterioara: Optional[Decimal] = Decimal('0')
clienti_facturat_an_curent: Optional[Decimal] = Decimal('0')
clienti_facturat_an_anterior: Optional[Decimal] = Decimal('0')
furnizori_facturat_an_curent: Optional[Decimal] = Decimal('0')
furnizori_facturat_an_anterior: Optional[Decimal] = Decimal('0')
# SOLDURI TVA
tva_plata_precedent: Decimal = Decimal('0')
tva_recuperat_precedent: Decimal = Decimal('0')
tva_plata_curent: Decimal = Decimal('0')
tva_recuperat_curent: Decimal = Decimal('0')
# DATORII LA BUGET - breakdown pe grupe (TVA / BASS / CAM) cu sub-conturi
budget_debt_breakdown: List[BudgetDebtGroup] = []
budget_debt_total_precedent: Decimal = Decimal('0') # suma tuturor grupurilor luna prec
budget_debt_total_sold: Decimal = Decimal('0') # sold final total (cât mai rămâne de plată)

View File

@@ -0,0 +1,79 @@
"""
Modele Pydantic pentru facturi - Compatibile cu aplicația Flask existentă
"""
from pydantic import BaseModel, Field, validator
from datetime import date
from typing import Optional, List, Literal
from decimal import Decimal
class InvoiceBase(BaseModel):
"""Model de bază pentru factură - mapează exact pe rezultatul query-ului Flask"""
nume: str = Field(description="Numele partenerului")
nract: int = Field(description="Numărul actului")
dataact: Optional[date] = Field(description="Data actului")
datascad: Optional[date] = Field(description="Data scadentă")
contract: Optional[str] = Field(description="Numărul contractului")
cod_fiscal: Optional[str] = Field(description="Codul fiscal")
reg_comert: Optional[str] = Field(description="Registrul comerțului")
cont: Optional[str] = Field(description="Contul contabil")
valuta: str = Field(default="RON", description="Valuta (RON, EUR, USD, etc.)")
class Invoice(InvoiceBase):
"""Model complet pentru factură cu calcule financiare"""
totctva: Decimal = Field(description="Total cu TVA", decimal_places=2)
achitat: Decimal = Field(description="Suma achitată", decimal_places=2)
soldfinal: Decimal = Field(description="Soldul final", decimal_places=2)
css_class: Literal["", "invoice-paid", "invoice-overdue"] = Field(
default="", description="Clasa CSS pentru stilizare"
)
@validator('css_class', always=True)
def determine_css_class(cls, v, values):
"""Determină automat clasa CSS bazată pe status factură"""
if 'soldfinal' in values and 'datascad' in values:
sold = values['soldfinal']
data_scad = values['datascad']
if sold < 1:
return 'invoice-paid'
elif data_scad and data_scad < date.today() and sold != 0:
return 'invoice-overdue'
return ''
class InvoiceFilter(BaseModel):
"""Filtru pentru căutarea facturilor"""
company: str = Field(description="Codul firmei (schema Oracle)")
partner_type: Literal["CLIENTI", "FURNIZORI"] = Field(description="Tipul partenerului")
luna: Optional[int] = Field(default=None, ge=1, le=12, description="Luna contabilă (1-12)")
an: Optional[int] = Field(default=None, ge=2000, le=2100, description="Anul contabil")
partner_name: Optional[str] = Field(description="Filtru după nume")
cont: Optional[str] = Field(description="Filtru după cont contabil")
only_unpaid: bool = Field(default=True, description="Doar neachitate")
min_amount: Optional[Decimal] = Field(description="Suma minimă")
max_amount: Optional[Decimal] = Field(description="Suma maximă")
page: int = Field(default=1, ge=1, description="Pagina")
page_size: int = Field(default=50, ge=1, le=10000000, description="Mărimea paginii")
class InvoiceListResponse(BaseModel):
"""Răspuns pentru lista de facturi"""
invoices: List[Invoice]
total_count: int
filtered_count: int
total_amount: Decimal
page: int
page_size: int
has_more: bool
accounting_period: Optional[dict] = Field(default=None, description="Perioada contabilă (an, luna)")
# Total sold din TOATE facturile filtrate (nu doar pagina curentă)
total_sold_all: Decimal = Field(default=Decimal('0.00'), description="Total sold din toate facturile filtrate")
class InvoiceSummary(BaseModel):
"""Rezumat pentru facturi - pentru dashboard"""
company: str
partner_type: str
total_invoices: int
total_amount: Decimal
paid_amount: Decimal
outstanding_amount: Decimal
overdue_amount: Decimal
overdue_count: int

View File

@@ -0,0 +1,52 @@
from pydantic import BaseModel
from decimal import Decimal
from datetime import datetime
from typing import Optional, List
class AccountingPeriod(BaseModel):
"""Model pentru perioada contabilă"""
an: Optional[int] = None
luna: Optional[int] = None
class BankCashRegister(BaseModel):
"""Model pentru Registrul de Casă și Bancă"""
nume: str
nract: Optional[int] = None
dataact: Optional[datetime] = None
nume_cont_bancar: str # din vbalanta_parteneri.nume
incasari: Decimal
plati: Decimal
sold: Decimal
valuta: Optional[str] = None
tip_registru: str # "BANCA LEI", "CASA VALUTA" etc
explicatia: str
class RegisterFilter(BaseModel):
"""Filtre pentru registrul de casă și bancă"""
company: str
register_type: Optional[str] = None # BANCA_LEI, BANCA_VALUTA, CASA_LEI, CASA_VALUTA sau None pentru toate
luna: Optional[int] = None # Luna contabilă (1-12) pentru PACK_SESIUNE
an: Optional[int] = None # Anul contabil pentru PACK_SESIUNE
date_from: Optional[datetime] = None
date_to: Optional[datetime] = None
partner_name: Optional[str] = None
bank_account: Optional[str] = None # Filter for specific bank/cash account (bancasa)
page: int = 1
page_size: int = 50
class RegisterListResponse(BaseModel):
"""Răspuns pentru lista din registru"""
registers: List[BankCashRegister]
total_count: int
filtered_count: int
total_incasari: Decimal
total_plati: Decimal
page: int
page_size: int
has_more: bool
accounting_period: Optional[AccountingPeriod] = None
# Totaluri din TOATE înregistrările filtrate (nu doar pagina curentă)
sold_precedent_all: Decimal = Decimal('0.00')
total_incasari_all: Decimal = Decimal('0.00')
total_plati_all: Decimal = Decimal('0.00')
sold_final_all: Decimal = Decimal('0.00')

View File

@@ -0,0 +1,102 @@
"""
Pydantic models for Trial Balance (Balanță de Verificare)
Maps to Oracle VBAL VIEW (exists in each company schema)
"""
from pydantic import BaseModel, Field
from typing import Optional, List
from decimal import Decimal
class TrialBalanceItem(BaseModel):
"""
Individual trial balance record from VBAL VIEW
Real structure from Oracle:
- CONT: account number
- DENUMIRE: account description
- PRECDEB/PRECCRED: previous balance debit/credit
- RULDEB/RULCRED: monthly movement debit/credit
- SOLDDEB/SOLDCRED: final balance debit/credit
"""
cont: str = Field(description="Număr cont contabil (CONT)")
denumire: Optional[str] = Field(default="", description="Denumire cont (DENUMIRE)")
sold_precedent_debit: Decimal = Field(description="Sold precedent debit (PRECDEB)", decimal_places=2)
sold_precedent_credit: Decimal = Field(description="Sold precedent credit (PRECCRED)", decimal_places=2)
rulaj_lunar_debit: Decimal = Field(description="Rulaj lunar debit (RULDEB)", decimal_places=2)
rulaj_lunar_credit: Decimal = Field(description="Rulaj lunar credit (RULCRED)", decimal_places=2)
sold_final_debit: Decimal = Field(description="Sold final debit (SOLDDEB)", decimal_places=2)
sold_final_credit: Decimal = Field(description="Sold final credit (SOLDCRED)", decimal_places=2)
class Config:
from_attributes = True
class TrialBalanceFilters(BaseModel):
"""
Filters applied to trial balance data
"""
luna: int = Field(description="Luna (1-12)")
an: int = Field(description="An")
cont_filter: Optional[str] = Field(default=None, description="Filtru număr cont (partial match)")
denumire_filter: Optional[str] = Field(default=None, description="Filtru denumire cont (partial match, case-insensitive)")
class TrialBalancePagination(BaseModel):
"""
Pagination metadata
"""
total_items: int = Field(description="Total number of items")
total_pages: int = Field(description="Total number of pages")
current_page: int = Field(description="Current page number")
page_size: int = Field(description="Items per page")
class TrialBalanceTotals(BaseModel):
"""
Totals for all 6 columns from all filtered records (not just current page)
"""
total_sold_precedent_debit: Decimal = Decimal('0.00')
total_sold_precedent_credit: Decimal = Decimal('0.00')
total_rulaj_lunar_debit: Decimal = Decimal('0.00')
total_rulaj_lunar_credit: Decimal = Decimal('0.00')
total_sold_final_debit: Decimal = Decimal('0.00')
total_sold_final_credit: Decimal = Decimal('0.00')
class TrialBalanceResponse(BaseModel):
"""
Complete response for trial balance endpoint
"""
success: bool = Field(default=True, description="Request success status")
data: dict = Field(description="Trial balance data with items, pagination, and filters")
class Config:
json_schema_extra = {
"example": {
"success": True,
"data": {
"items": [
{
"cont": "4111",
"dcont": "Furnizori interni",
"sold_precedent_debit": 0.00,
"sold_precedent_credit": 15000.00,
"rulaj_lunar_debit": 5000.00,
"rulaj_lunar_credit": 8000.00,
"sold_final_debit": 0.00,
"sold_final_credit": 18000.00
}
],
"pagination": {
"total_items": 150,
"total_pages": 3,
"current_page": 1,
"page_size": 50
},
"filters_applied": {
"luna": 11,
"an": 2025,
"cont_filter": None,
"denumire_filter": "furnizori"
}
}
}
}

View File

@@ -0,0 +1,36 @@
"""Reports module router factory."""
from fastapi import APIRouter
def create_reports_router() -> APIRouter:
"""
Create and configure Reports module router.
Includes all report-related endpoints:
- /invoices - Invoice management
- /dashboard - Dashboard and metrics
- /treasury - Treasury operations
- /trial-balance - Trial balance reports
- /cache - Cache management
Returns:
APIRouter: Configured router for reports module
"""
router = APIRouter()
# Import routers here to avoid circular imports
from .invoices import router as invoices_router
from .dashboard import router as dashboard_router
from .treasury import router as treasury_router
from .trial_balance import router as trial_balance_router
from .cache import router as cache_router
# Include all sub-routers (no prefix - already prefixed in main.py with /api/reports)
router.include_router(invoices_router, prefix="/invoices", tags=["reports-invoices"])
router.include_router(dashboard_router, prefix="/dashboard", tags=["reports-dashboard"])
router.include_router(treasury_router, prefix="/treasury", tags=["reports-treasury"])
router.include_router(trial_balance_router, prefix="/trial-balance", tags=["reports-trial-balance"])
router.include_router(cache_router, prefix="/cache", tags=["reports-cache"])
return router

View File

@@ -0,0 +1,398 @@
"""
API Router pentru managementul cache-ului
"""
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel
from typing import Optional, Dict, Any
# import sys # Removed - no longer needed
import os
import time
from datetime import datetime, timedelta
from shared.auth.dependencies import get_current_user
from shared.auth.models import CurrentUser
from ..cache import get_cache, get_event_monitor, toggle_event_monitor
router = APIRouter(tags=["cache"])
# Pydantic Models
class CacheStatsResponse(BaseModel):
"""Răspuns statistici cache"""
enabled: bool
global_enabled: bool
user_enabled: bool
cache_type: str
hit_rate: float
total_hits: int
total_misses: int
queries_saved: Dict[str, int]
response_times: Dict[str, Dict[str, Any]]
cache_size: Dict[str, int]
auto_invalidate: bool
last_cleanup: Optional[str] = None
class InvalidateCacheRequest(BaseModel):
"""Request pentru invalidare cache"""
company_id: Optional[int] = None
cache_type: Optional[str] = None
class ToggleUserCacheRequest(BaseModel):
"""Request pentru toggle cache per-user"""
enabled: bool
class ToggleGlobalCacheRequest(BaseModel):
"""Request pentru toggle cache global"""
enabled: bool
class ToggleAutoInvalidateRequest(BaseModel):
"""Request pentru toggle auto-invalidation"""
enabled: bool
# Helper Functions
async def _calculate_cache_stats() -> Dict[str, Any]:
"""Calculate comprehensive cache statistics"""
cache = get_cache()
if not cache:
raise HTTPException(status_code=503, detail="Cache not initialized")
# Get basic cache stats
stats = await cache.get_stats()
# Calculate hit rate
memory_stats = stats.get('memory', {})
total_hits = memory_stats.get('hits', 0)
total_misses = memory_stats.get('misses', 0)
total_requests = total_hits + total_misses
hit_rate = (total_hits / total_requests * 100) if total_requests > 0 else 0
# Calculate queries saved (from performance_log)
queries_saved = await _calculate_queries_saved(cache)
# Calculate response times per cache type
response_times = await _calculate_response_times(cache)
# Get cache sizes
cache_size = {
'memory': memory_stats.get('size', 0),
'sqlite': stats.get('sqlite', {}).get('active_entries', 0)
}
# Get event monitor status
monitor = get_event_monitor()
auto_invalidate = monitor.running if monitor else False
return {
'enabled': cache.config.enabled,
'global_enabled': cache.config.enabled,
'cache_type': cache.config.cache_type,
'hit_rate': round(hit_rate, 1),
'total_hits': total_hits,
'total_misses': total_misses,
'queries_saved': queries_saved,
'response_times': response_times,
'cache_size': cache_size,
'auto_invalidate': auto_invalidate,
'last_cleanup': None # TODO: track last cleanup time
}
async def _calculate_queries_saved(cache) -> Dict[str, int]:
"""Calculate queries saved by time period"""
import aiosqlite
try:
async with aiosqlite.connect(cache.sqlite.db_path) as db:
now = time.time()
today_start = now - 86400 # 24 hours
week_start = now - 604800 # 7 days
# Today
async with db.execute("""
SELECT COUNT(*) FROM performance_log
WHERE cache_hit = 1 AND timestamp >= ?
""", (today_start,)) as cursor:
today = (await cursor.fetchone())[0]
# This week
async with db.execute("""
SELECT COUNT(*) FROM performance_log
WHERE cache_hit = 1 AND timestamp >= ?
""", (week_start,)) as cursor:
week = (await cursor.fetchone())[0]
# All time
async with db.execute("""
SELECT COUNT(*) FROM performance_log
WHERE cache_hit = 1
""") as cursor:
total = (await cursor.fetchone())[0]
return {
'today': today,
'week': week,
'total': total
}
except Exception as e:
return {'today': 0, 'week': 0, 'total': 0}
async def _calculate_response_times(cache) -> Dict[str, Dict[str, Any]]:
"""Calculate average response times per cache type"""
import aiosqlite
try:
async with aiosqlite.connect(cache.sqlite.db_path) as db:
# Get average times per cache type
async with db.execute("""
SELECT
cache_type,
AVG(CASE WHEN cache_hit = 1 THEN response_time_ms ELSE NULL END) as avg_cached,
AVG(CASE WHEN cache_hit = 0 THEN response_time_ms ELSE NULL END) as avg_oracle
FROM performance_log
WHERE timestamp >= ?
GROUP BY cache_type
""", (time.time() - 86400,)) as cursor: # Last 24 hours
results = await cursor.fetchall()
response_times = {}
for row in results:
cache_type, avg_cached, avg_oracle = row
if avg_cached and avg_oracle:
improvement = int((avg_oracle - avg_cached) / avg_oracle * 100)
response_times[cache_type] = {
'cached': int(avg_cached),
'oracle': int(avg_oracle),
'improvement': improvement
}
return response_times
except Exception as e:
return {}
# API Endpoints
@router.get("/stats", response_model=CacheStatsResponse)
async def get_cache_stats(
current_user: CurrentUser = Depends(get_current_user)
):
"""
Obține statistici complete cache
Returns:
- Hit rate, queries saved, response times
- Cache sizes (memory + SQLite)
- Auto-invalidation status
- Per-user cache setting
"""
try:
cache = get_cache()
if not cache:
raise HTTPException(status_code=503, detail="Cache not initialized")
# Get base stats
stats = await _calculate_cache_stats()
# Add user-specific setting
user_enabled = await cache.is_enabled_for_user(current_user.username)
stats['user_enabled'] = user_enabled
return CacheStatsResponse(**stats)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error retrieving cache stats: {str(e)}")
@router.post("/invalidate")
async def invalidate_cache(
request: InvalidateCacheRequest,
current_user: CurrentUser = Depends(get_current_user)
):
"""
Invalidează cache
Args:
company_id: Opțional - invalidează doar pentru această companie
cache_type: Opțional - invalidează doar acest tip de cache
Returns:
Message de confirmare
"""
try:
cache = get_cache()
if not cache:
raise HTTPException(status_code=503, detail="Cache not initialized")
await cache.invalidate(
company_id=request.company_id,
cache_type=request.cache_type
)
if request.company_id and request.cache_type:
message = f"Cache invalidated for company {request.company_id}, type {request.cache_type}"
elif request.company_id:
message = f"Cache invalidated for company {request.company_id}"
elif request.cache_type:
message = f"Cache invalidated for type {request.cache_type}"
else:
message = "All cache invalidated"
return {
"success": True,
"message": message,
"invalidated_at": datetime.now().isoformat()
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error invalidating cache: {str(e)}")
@router.post("/toggle-user")
async def toggle_user_cache(
request: ToggleUserCacheRequest,
current_user: CurrentUser = Depends(get_current_user)
):
"""
Toggle cache per-user
Permite utilizatorului să activeze/dezactiveze cache-ul pentru el
Folosit pentru A/B testing și comparații de performanță
Args:
enabled: True pentru activare, False pentru dezactivare
Returns:
Noul status
"""
try:
cache = get_cache()
if not cache:
raise HTTPException(status_code=503, detail="Cache not initialized")
await cache.set_user_cache_enabled(current_user.username, request.enabled)
return {
"success": True,
"username": current_user.username,
"cache_enabled": request.enabled,
"message": f"Cache {'enabled' if request.enabled else 'disabled'} for user {current_user.username}"
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error toggling user cache: {str(e)}")
@router.post("/toggle-global")
async def toggle_global_cache(
request: ToggleGlobalCacheRequest,
current_user: CurrentUser = Depends(get_current_user)
):
"""
Toggle cache global (ADMIN only)
Activează/dezactivează cache-ul la nivel global pentru toți utilizatorii
Args:
enabled: True pentru activare, False pentru dezactivare
Returns:
Noul status global
"""
try:
# TODO: Add admin permission check
# For now, allow any authenticated user
cache = get_cache()
if not cache:
raise HTTPException(status_code=503, detail="Cache not initialized")
# Update config (NOTE: This is runtime only, .env needs manual update)
cache.config.enabled = request.enabled
return {
"success": True,
"global_enabled": request.enabled,
"message": f"Cache {'enabled' if request.enabled else 'disabled'} globally",
"note": "This change is runtime only. Update .env CACHE_ENABLED for persistence."
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error toggling global cache: {str(e)}")
@router.post("/toggle-auto-invalidate")
async def toggle_auto_invalidation(
request: ToggleAutoInvalidateRequest,
current_user: CurrentUser = Depends(get_current_user)
):
"""
Toggle auto-invalidation monitoring
Activează/dezactivează monitorizarea automată a {schema}.act
pentru invalidarea cache-ului când se detectează modificări
Args:
enabled: True pentru activare, False pentru dezactivare
Returns:
Noul status auto-invalidation
"""
try:
# TODO: Add admin permission check
# For now, allow any authenticated user
await toggle_event_monitor(request.enabled)
return {
"success": True,
"auto_invalidate_enabled": request.enabled,
"message": f"Auto-invalidation {'enabled' if request.enabled else 'disabled'}",
"note": "Monitors max(id_act) in {schema}.act tables for changes"
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error toggling auto-invalidation: {str(e)}")
@router.get("/health")
async def cache_health():
"""
Health check pentru sistemul de cache
Returns:
Status cache, mărime, și uptime
"""
try:
cache = get_cache()
if not cache:
return {
"status": "not_initialized",
"enabled": False
}
stats = await cache.get_stats()
monitor = get_event_monitor()
return {
"status": "healthy",
"enabled": cache.config.enabled,
"cache_type": cache.config.cache_type,
"memory_size": stats.get('memory', {}).get('size', 0),
"sqlite_size": stats.get('sqlite', {}).get('active_entries', 0),
"auto_invalidate_running": monitor.running if monitor else False
}
except Exception as e:
return {
"status": "error",
"error": str(e)
}

View File

@@ -0,0 +1,661 @@
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from typing import Optional
import os
from shared.auth.dependencies import get_current_user
from shared.auth.models import CurrentUser
import logging
logger = logging.getLogger(__name__)
from ..models.dashboard import DashboardSummary, TrendsResponse, TrendData
from ..models.financial_indicators import FinancialIndicatorsResponse
from ..services.dashboard_service import DashboardService
from ..services.financial_indicators_service import FinancialIndicatorsService
from ..cache.decorators import cached
router = APIRouter()
@router.get("/summary")
async def get_dashboard_summary(
request: Request,
company: str = Query(description="Codul firmei"),
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Obține toate datele pentru dashboard într-un singur apel
- Necesită autentificare JWT
- Returnează statistici clienți/furnizori și trezorerie
- Include metadata cache pentru Telegram Bot (X-Include-Cache-Metadata header)
- Suportă filtrare pe luna/an contabil (dacă nu sunt specificate, folosește ultima perioadă)
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if company not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
server_id = getattr(request.state, 'server_id', None)
result = await DashboardService.get_complete_summary(company, current_user.username, luna=luna, an=an, request=request, server_id=server_id)
# Convert Pydantic model to dict for JSON serialization
result_dict = result.dict() if hasattr(result, 'dict') else result
# Add cache metadata if requested (for Telegram Bot)
include_metadata = request.headers.get('X-Include-Cache-Metadata', '').lower() == 'true'
if include_metadata:
cache_hit = getattr(request.state, 'cache_hit', False)
response_time = getattr(request.state, 'response_time_ms', 0)
cache_source = getattr(request.state, 'cache_source', None)
result_dict['cache_hit'] = cache_hit
result_dict['response_time_ms'] = response_time
# Always include cache_source, even if None
result_dict['cache_source'] = cache_source
return result_dict
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Eroare la obținerea datelor dashboard: {str(e)}")
@router.get("/trends", response_model=TrendsResponse)
async def get_dashboard_trends(
request: Request,
company: str = Query(description="Codul firmei"),
period: str = Query(default="30d", description="Perioada pentru trends: 7d, 30d, ytd, 12m"),
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
compare_previous: bool = Query(default=True, description="Compară cu perioada anterioară"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Obține trenduri pentru indicatorii principali (clienți/furnizori)
- period: "7d" (7 zile), "30d" (30 zile), "ytd" (year to date), "12m" (12 luni)
- luna/an: perioada contabilă de referință (dacă nu sunt specificate, folosește ultima perioadă)
- compare_previous: dacă să compare cu perioada anterioară
- Necesită autentificare JWT
- Returnează date pentru grafice de trenduri
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if company not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
# Validează perioada
valid_periods = ["7d", "30d", "ytd", "12m"]
if period not in valid_periods:
raise HTTPException(
status_code=400,
detail=f"Perioadă nevalidă: {period}. Valori permise: {', '.join(valid_periods)}"
)
server_id = getattr(request.state, 'server_id', None)
# Obține datele de trenduri
result = await DashboardService.get_trends(int(company), period, luna=luna, an=an, request=request, server_id=server_id)
# Convert to dict if needed
result_dict = result.dict() if hasattr(result, 'dict') else result
# Add cache metadata if requested (for Telegram Bot)
include_metadata = request.headers.get('X-Include-Cache-Metadata', '').lower() == 'true'
if include_metadata:
cache_hit = getattr(request.state, 'cache_hit', False)
response_time = getattr(request.state, 'response_time_ms', 0)
cache_source = getattr(request.state, 'cache_source', None)
result_dict['cache_hit'] = cache_hit
result_dict['response_time_ms'] = response_time
# Always include cache_source, even if None
result_dict['cache_source'] = cache_source
# Return as TrendsResponse
return TrendsResponse(**result_dict)
except ValueError as e:
logger.error(f"Value error in trends endpoint: {str(e)}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Eroare la obținerea trendurilor: {str(e)}")
raise HTTPException(status_code=500, detail=f"Eroare la obținerea trendurilor: {str(e)}")
@router.get("/detailed-data")
async def get_detailed_data(
request: Request,
company: str = Query(description="Codul firmei"),
data_type: str = Query(description="Tipul de date: clients, suppliers, treasury"),
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
page: int = Query(default=1, ge=1),
page_size: int = Query(default=25, ge=1, le=100),
search: str = Query(default="", description="Termen de căutare"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Obține date detaliate pentru tabelele din dashboard
"""
logger.info(f"[ROUTER] detailed-data called: company={company}, data_type={data_type}")
try:
if company not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
server_id = getattr(request.state, 'server_id', None)
logger.info(f"[ROUTER] Calling DashboardService.get_detailed_data")
result = await DashboardService.get_detailed_data(
company=company,
data_type=data_type,
luna=luna,
an=an,
page=page,
page_size=page_size,
search=search,
server_id=server_id
)
logger.info(f"[ROUTER] Service returned: {len(result.get('data', []))} rows")
return result
except Exception as e:
logger.error(f"Eroare la obținerea datelor detaliate: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/performance")
async def get_performance(
request: Request,
company: int = Query(..., description="ID-ul firmei"),
period: str = Query("7d", regex="^(7d|1m|3m|6m|ytd|12m)$", description="Perioada pentru analiză"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Returnează date performanță pentru perioada selectată
- Necesită autentificare JWT
- Returnează grafice încasări vs plăți pentru perioada selectată
- Calculează indicatori: rata încasării, cash conversion, working capital
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if str(company) not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
server_id = getattr(request.state, 'server_id', None)
result = await DashboardService.get_performance_data(company, period, server_id=server_id)
# Convert to Chart.js compatible format
return {
"labels": result.get("labels", []),
"datasets": [{
"data": result.get("data", []),
"label": result.get("label", "Performance"),
"borderColor": result.get("borderColor", "#3B82F6"),
"backgroundColor": result.get("backgroundColor", "rgba(59, 130, 246, 0.1)"),
"tension": 0.4
}]
}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Eroare la obținerea datelor de performanță: {str(e)}")
raise HTTPException(status_code=500, detail=f"Eroare la obținerea datelor de performanță: {str(e)}")
@router.get("/cashflow")
async def get_cashflow(
request: Request,
company: int = Query(..., description="ID-ul firmei"),
period: str = Query("7d", regex="^(7d|1m|3m|6m)$", description="Perioada pentru previziune"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Returnează previziune cash flow pentru perioada selectată
- Necesită autentificare JWT
- Analizează scadențele viitoare pentru calculul cash flow-ului
- Identifică zilele critice cu deficit de cash
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if str(company) not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
server_id = getattr(request.state, 'server_id', None)
result = await DashboardService.get_cashflow_forecast(company, period, server_id=server_id)
# Convert to Chart.js compatible format
return {
"labels": result.get("labels", []),
"datasets": [{
"data": result.get("data", []),
"label": result.get("label", "Cash Flow"),
"borderColor": result.get("borderColor", "#10B981"),
"backgroundColor": result.get("backgroundColor", "rgba(16, 185, 129, 0.1)"),
"tension": 0.4
}]
}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Eroare la obținerea previziunii cash flow: {str(e)}")
raise HTTPException(status_code=500, detail=f"Eroare la obținerea previziunii cash flow: {str(e)}")
@router.get("/maturity")
async def get_maturity_analysis(
request: Request,
company: int = Query(..., description="ID-ul firmei"),
period: str = Query("7d", regex="^(7d|1m|3m|6m|12m|all)$", description="Orizont de planificare pentru analiza scadențelor"),
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Returnează analiza scadențelor pentru orizontul de planificare selectat
- Necesită autentificare JWT
- Logică: Include TOATE restanțele + scadențele viitoare din perioada selectată
- luna/an: perioada contabilă de referință (dacă nu sunt specificate, folosește ultima perioadă)
- Perioade disponibile:
* 7d: Toate restanțele + scadențe următoarelor 7 zile
* 1m: Toate restanțele + scadențe următoarelor 30 zile
* 3m: Toate restanțele + scadențe următoarelor 90 zile
* 6m: Toate restanțele + scadențe următoarelor 180 zile
* 12m: Toate restanțele + scadențe următoarelor 365 zile
* all: Toate soldurile (fără filtru)
- Compară scadențele clienți vs furnizori
- Calculează balanța și oferă recomandări
- Returnează metadate cu statistici complete
- Include metadata cache pentru Telegram Bot (X-Include-Cache-Metadata header)
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if str(company) not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
server_id = getattr(request.state, 'server_id', None)
result = await DashboardService.get_maturity_analysis(company, period, luna=luna, an=an, request=request, server_id=server_id)
# Convert to dict if needed
result_dict = result.dict() if hasattr(result, 'dict') else result
# Add cache metadata if requested (for Telegram Bot)
include_metadata = request.headers.get('X-Include-Cache-Metadata', '').lower() == 'true'
if include_metadata:
cache_hit = getattr(request.state, 'cache_hit', False)
response_time = getattr(request.state, 'response_time_ms', 0)
cache_source = getattr(request.state, 'cache_source', None)
result_dict['cache_hit'] = cache_hit
result_dict['response_time_ms'] = response_time
# Always include cache_source, even if None
result_dict['cache_source'] = cache_source
return result_dict
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Eroare la obținerea analizei scadențelor: {str(e)}")
raise HTTPException(status_code=500, detail=f"Eroare la obținerea analizei scadențelor: {str(e)}")
@router.get("/monthly-flows")
async def get_monthly_flows(
request: Request,
company: int = Query(..., description="ID-ul firmei"),
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Returnează fluxurile lunare pentru firma selectată
- Necesită autentificare JWT
- Returnează date pentru analiza fluxurilor lunare
- luna/an: perioada contabilă de referință (dacă nu sunt specificate, folosește ultima perioadă)
- Include metadata cache pentru Telegram Bot (X-Include-Cache-Metadata header)
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if str(company) not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
server_id = getattr(request.state, 'server_id', None)
# Apelăm serviciul cu request pentru cache metadata
result = await DashboardService.get_monthly_flows(company, luna=luna, an=an, request=request, server_id=server_id)
# Convert to dict if needed
result_dict = result.dict() if hasattr(result, 'dict') else result
# Add cache metadata if requested (for Telegram Bot / Dashboard)
include_metadata = request.headers.get('X-Include-Cache-Metadata', '').lower() == 'true'
if include_metadata:
cache_hit = getattr(request.state, 'cache_hit', False)
response_time = getattr(request.state, 'response_time_ms', 0)
cache_source = getattr(request.state, 'cache_source', None)
result_dict['cache_hit'] = cache_hit
result_dict['response_time_ms'] = response_time
result_dict['cache_source'] = cache_source
return result_dict
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Eroare la obținerea fluxurilor lunare: {str(e)}")
raise HTTPException(status_code=500, detail=f"Eroare la obținerea fluxurilor lunare: {str(e)}")
@router.get("/treasury-breakdown")
async def get_treasury_breakdown(
request: Request,
company: int = Query(..., description="ID-ul firmei"),
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Returnează defalcarea trezoreriei pentru firma selectată
- Necesită autentificare JWT
- Returnează distribuția soldurilor pe conturi și tipuri
- luna/an: perioada contabilă de referință (dacă nu sunt specificate, folosește ultima perioadă)
- Include metadata cache pentru Telegram Bot (X-Include-Cache-Metadata header)
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if str(company) not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
server_id = getattr(request.state, 'server_id', None)
result = await DashboardService.get_treasury_breakdown(company, luna=luna, an=an, request=request, server_id=server_id)
# Convert to dict if needed
result_dict = result.dict() if hasattr(result, 'dict') else result
# Add cache metadata if requested (for Telegram Bot)
include_metadata = request.headers.get('X-Include-Cache-Metadata', '').lower() == 'true'
if include_metadata:
cache_hit = getattr(request.state, 'cache_hit', False)
response_time = getattr(request.state, 'response_time_ms', 0)
cache_source = getattr(request.state, 'cache_source', None)
result_dict['cache_hit'] = cache_hit
result_dict['response_time_ms'] = response_time
# Always include cache_source, even if None
result_dict['cache_source'] = cache_source
return result_dict
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Eroare la obținerea defalcării trezoreriei: {str(e)}")
raise HTTPException(status_code=500, detail=f"Eroare la obținerea defalcării trezoreriei: {str(e)}")
@router.get("/net-balance-breakdown")
async def get_net_balance_breakdown(
request: Request,
company: int = Query(..., description="ID-ul firmei"),
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Returnează defalcarea balanței nete pentru firma selectată
- Necesită autentificare JWT
- Returnează analiza detaliată a balanței nete
- luna/an: perioada contabilă de referință (dacă nu sunt specificate, folosește ultima perioadă)
- Include metadata cache pentru Telegram Bot (X-Include-Cache-Metadata header)
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if str(company) not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
server_id = getattr(request.state, 'server_id', None)
result = await DashboardService.get_net_balance_breakdown(company, luna=luna, an=an, request=request, server_id=server_id)
# Convert to dict if needed
result_dict = result.dict() if hasattr(result, 'dict') else result
# Add cache metadata if requested (for Telegram Bot)
include_metadata = request.headers.get('X-Include-Cache-Metadata', '').lower() == 'true'
if include_metadata:
cache_hit = getattr(request.state, 'cache_hit', False)
response_time = getattr(request.state, 'response_time_ms', 0)
cache_source = getattr(request.state, 'cache_source', None)
result_dict['cache_hit'] = cache_hit
result_dict['response_time_ms'] = response_time
# Always include cache_source, even if None
result_dict['cache_source'] = cache_source
return result_dict
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Eroare la obținerea defalcării balanței nete: {str(e)}")
raise HTTPException(status_code=500, detail=f"Eroare la obținerea defalcării balanței nete: {str(e)}")
@router.get("/current-period")
async def get_current_period(
request: Request,
company: int = Query(..., description="ID-ul firmei"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Returnează perioada curentă (an și lună) din calendarul Oracle
- Necesită autentificare JWT
- Returnează anul, luna și perioada curentă în format YYYY-MM
- Folosit pentru afișarea lunii curente în dashboard
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if str(company) not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
server_id = getattr(request.state, 'server_id', None)
result = await DashboardService.get_current_period(company, server_id=server_id)
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Eroare la obținerea perioadei curente: {str(e)}")
raise HTTPException(status_code=500, detail=f"Eroare la obținerea perioadei curente: {str(e)}")
@router.get(
"/financial-indicators",
tags=["dashboard"]
)
async def get_financial_indicators(
request: Request,
company: int = Query(..., description="ID-ul firmei (required)"),
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
include_sparklines: bool = Query(True, description="Include date istorice pentru sparklines (12 luni)"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Returnează toți indicatorii financiari calculați pentru firma selectată.
Acest endpoint agregă datele din:
- Lichiditate: Current Ratio, Quick Ratio, Cash Ratio
- Eficiență: DSO, DPO, Cash Conversion Cycle, rate încasare/plată
- Risc: creanțe/datorii restante, raport datorii/trezorerie
- Cash Flow: flux net lunar, YTD, YoY, acoperire
- Dinamică: creștere vânzări/achiziții YoY, marjă implicită
- Altman Z-Score: scor și componente X1-X4
Parametri:
- company (required): ID-ul firmei pentru care se calculează indicatorii
- luna (optional): Luna contabilă (1-12). Dacă nu este specificată,
se folosește ultima perioadă disponibilă.
- an (optional): Anul contabil (2000-2100). Dacă nu este specificat,
se folosește anul curent.
- include_sparklines (optional, default=true): Dacă să includă date istorice
pentru vizualizarea trendului pe ultimele 12 luni (sparkline_data și sparkline_labels
în fiecare indicator)
Cache:
- TTL: 30 minute pentru indicatori curenți (cache_type='financial_indicators')
- TTL: 1 oră pentru date istorice sparkline (cache_type='financial_indicators_historical')
- Se invalidează automat la schimbarea datelor din balanță
Necesită autentificare JWT și acces la firma specificată.
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if str(company) not in current_user.companies:
raise HTTPException(
status_code=403,
detail=f"Nu aveți acces la firma {company}"
)
# Dacă luna/an nu sunt specificate, obținem perioada curentă
# Folosim variabile tipizate explicit pentru a evita erori de tip
resolved_luna: int
resolved_an: int
server_id = getattr(request.state, 'server_id', None)
if luna is None or an is None:
try:
current_period = await DashboardService.get_current_period(company, server_id=server_id)
resolved_luna = luna if luna is not None else current_period.get('luna', 12)
resolved_an = an if an is not None else current_period.get('an', 2024)
except Exception as e:
logger.warning(f"Could not get current period: {e}, using defaults")
from datetime import datetime
resolved_luna = luna if luna is not None else datetime.now().month
resolved_an = an if an is not None else datetime.now().year
else:
resolved_luna = luna
resolved_an = an
# Dacă include_sparklines este True, folosim metoda care include datele istorice
if include_sparklines:
response = await FinancialIndicatorsService.get_indicators_with_sparklines(
company, resolved_luna, resolved_an, months=12, request=request, server_id=server_id
)
# FIX: Cache poate returna dict în loc de obiect Pydantic
# Extragem valorile pentru logging în mod compatibil cu ambele tipuri
if isinstance(response, dict):
zscore_val = response.get('altman_zscore', {}).get('zscore', {}).get('value')
zscore_status = response.get('altman_zscore', {}).get('zscore', {}).get('status')
else:
zscore_val = response.altman_zscore.zscore.value
zscore_status = response.altman_zscore.zscore.status
logger.info(
f"Financial indicators with sparklines for company {company}, "
f"luna={resolved_luna}, an={resolved_an}: "
f"Z-Score={zscore_val} ({zscore_status}), "
f"cache_hit={getattr(request.state, 'cache_hit', False)}, "
f"response_time={getattr(request.state, 'response_time_ms', 0):.1f}ms"
)
# Add cache metadata if requested (for Telegram Bot / Dashboard)
include_metadata = request.headers.get('X-Include-Cache-Metadata', '').lower() == 'true'
if include_metadata:
result_dict = response.dict() if hasattr(response, 'dict') else response
result_dict['cache_hit'] = getattr(request.state, 'cache_hit', False)
result_dict['response_time_ms'] = getattr(request.state, 'response_time_ms', 0)
result_dict['cache_source'] = getattr(request.state, 'cache_source', None)
return result_dict
return response
# Dacă include_sparklines este False, calculăm doar indicatorii curenți
import asyncio
# Apelăm serviciul pentru fiecare categorie de indicatori
lichiditate_task = FinancialIndicatorsService.calculate_liquidity_indicators(
company, resolved_luna, resolved_an, server_id=server_id
)
eficienta_task = FinancialIndicatorsService.calculate_efficiency_indicators(
company, resolved_luna, resolved_an, server_id=server_id
)
risc_task = FinancialIndicatorsService.calculate_risk_indicators(
company, resolved_luna, resolved_an, server_id=server_id
)
cash_flow_task = FinancialIndicatorsService.calculate_cashflow_indicators(
company, resolved_luna, resolved_an, server_id=server_id
)
dinamica_task = FinancialIndicatorsService.calculate_dynamics_indicators(
company, resolved_luna, resolved_an, server_id=server_id
)
altman_task = FinancialIndicatorsService.calculate_altman_zscore(
company, resolved_luna, resolved_an, server_id=server_id
)
profitabilitate_task = FinancialIndicatorsService.calculate_profitability_indicators(
company, resolved_luna, resolved_an, server_id=server_id
)
solvabilitate_task = FinancialIndicatorsService.calculate_solvability_indicators(
company, resolved_luna, resolved_an, server_id=server_id
)
# Executăm toate calculele în paralel pentru performanță
(
lichiditate,
eficienta,
risc,
cash_flow,
dinamica,
altman_zscore,
profitabilitate,
solvabilitate
) = await asyncio.gather(
lichiditate_task,
eficienta_task,
risc_task,
cash_flow_task,
dinamica_task,
altman_task,
profitabilitate_task,
solvabilitate_task
)
# Construim răspunsul
response = FinancialIndicatorsResponse(
lichiditate=lichiditate,
eficienta=eficienta,
risc=risc,
cash_flow=cash_flow,
dinamica=dinamica,
altman_zscore=altman_zscore,
profitabilitate=profitabilitate,
solvabilitate=solvabilitate
)
# FIX: Cache poate returna dict în loc de obiect Pydantic
if isinstance(altman_zscore, dict):
zscore_val = altman_zscore.get('zscore', {}).get('value')
zscore_status = altman_zscore.get('zscore', {}).get('status')
else:
zscore_val = altman_zscore.zscore.value
zscore_status = altman_zscore.zscore.status
logger.info(
f"Financial indicators for company {company}, luna={resolved_luna}, an={resolved_an}: "
f"Z-Score={zscore_val} ({zscore_status})"
)
# Add cache metadata if requested (for Telegram Bot / Dashboard)
include_metadata = request.headers.get('X-Include-Cache-Metadata', '').lower() == 'true'
if include_metadata:
result_dict = response.dict() if hasattr(response, 'dict') else response
result_dict['cache_hit'] = getattr(request.state, 'cache_hit', False)
result_dict['response_time_ms'] = getattr(request.state, 'response_time_ms', 0)
result_dict['cache_source'] = getattr(request.state, 'cache_source', None)
return result_dict
return response
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Eroare la obținerea indicatorilor financiari: {str(e)}")
raise HTTPException(
status_code=500,
detail=f"Eroare la obținerea indicatorilor financiari: {str(e)}"
)

View File

@@ -0,0 +1,140 @@
"""
API Router pentru facturi
"""
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from typing import List, Optional
from datetime import date
# import sys # Removed - no longer needed
import os
from shared.auth.dependencies import get_current_user, require_company_access
from shared.auth.models import CurrentUser
from ..models.invoice import InvoiceFilter, InvoiceListResponse, InvoiceSummary
from ..services.invoice_service import InvoiceService
router = APIRouter()
@router.get("/", response_model=InvoiceListResponse)
async def get_invoices(
request: Request,
company: str = Query(description="Codul firmei"),
partner_type: str = Query("CLIENTI", description="CLIENTI sau FURNIZORI"),
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
partner_name: Optional[str] = Query(None, description="Filtru nume partener"),
cont: Optional[str] = Query(None, description="Filtru după cont contabil"),
only_unpaid: bool = Query(True, description="Doar facturile neachitate"),
min_amount: Optional[float] = Query(None, description="Suma minimă"),
max_amount: Optional[float] = Query(None, description="Suma maximă"),
page: int = Query(1, ge=1, description="Pagina"),
page_size: int = Query(50, ge=1, le=10000000, description="Mărimea paginii"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Obține lista de facturi pentru o firmă
- Necesită autentificare JWT
- Utilizatorul trebuie să aibă acces la firma specificată
- Suportă filtrare după luna/an contabil și paginare
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if company not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
server_id = getattr(request.state, 'server_id', None)
filter_params = InvoiceFilter(
company=company,
partner_type=partner_type,
luna=luna,
an=an,
partner_name=partner_name,
cont=cont,
only_unpaid=only_unpaid,
min_amount=min_amount,
max_amount=max_amount,
page=page,
page_size=page_size
)
result = await InvoiceService.get_invoices(filter_params, current_user.username, server_id=server_id)
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Eroare la obținerea facturilor: {str(e)}")
@router.get("/summary", response_model=InvoiceSummary)
async def get_invoices_summary(
request: Request,
company: str = Query(description="Codul firmei"),
partner_type: str = Query("CLIENTI", description="CLIENTI sau FURNIZORI"),
current_user: CurrentUser = Depends(get_current_user)
):
"""Obține rezumatul facturilor pentru dashboard"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if company not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
server_id = getattr(request.state, 'server_id', None)
result = await InvoiceService.get_invoice_summary(company, partner_type, current_user.username, server_id=server_id)
return result
except Exception as e:
raise HTTPException(status_code=500, detail=f"Eroare la obținerea rezumatului facturilor: {str(e)}")
@router.get("/{invoice_number}")
async def get_invoice_details(
request: Request,
invoice_number: str,
company: str = Query(description="Codul firmei"),
current_user: CurrentUser = Depends(get_current_user)
):
"""Obține detaliile unei facturi specifice"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if company not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
server_id = getattr(request.state, 'server_id', None)
result = await InvoiceService.get_invoice_details(company, invoice_number, current_user.username, server_id=server_id)
return result
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Eroare la obținerea detaliilor facturii: {str(e)}")
@router.get("/export/{format}")
async def export_invoices(
request: Request,
format: str,
company: str = Query(description="Codul firmei"),
partner_type: str = Query("CLIENTI", description="CLIENTI sau FURNIZORI"),
date_from: Optional[str] = Query(None, description="Data început (YYYY-MM-DD)"),
date_to: Optional[str] = Query(None, description="Data sfârșit (YYYY-MM-DD)"),
partner_name: Optional[str] = Query(None, description="Filtru nume partener"),
only_unpaid: bool = Query(True, description="Doar facturile neachitate"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Export facturi în format specificat (excel, pdf, csv)
Această funcție va fi implementată în viitor
"""
# Verifică dacă utilizatorul are acces la firma specificată
if company not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
server_id = getattr(request.state, 'server_id', None) # For future use
# Verifică formatul
if format not in ["excel", "pdf", "csv"]:
raise HTTPException(status_code=400, detail="Format invalid. Formatele suportate sunt: excel, pdf, csv")
# Pentru moment, returnează o eroare că funcția nu este implementată
raise HTTPException(status_code=501, detail=f"Export în format {format} nu este încă implementat")

View File

@@ -0,0 +1,123 @@
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from typing import Optional, List
from datetime import date
# import sys # Removed - no longer needed
import os
from shared.auth.dependencies import get_current_user
from shared.auth.models import CurrentUser
from ..models.treasury import RegisterFilter, RegisterListResponse
from ..services.treasury_service import TreasuryService
router = APIRouter()
@router.get("/bank-cash-register", response_model=RegisterListResponse)
async def get_bank_cash_register(
request: Request,
company: str = Query(description="Codul firmei"),
register_type: Optional[str] = Query(None, description="Tipul registrului: BANCA_LEI, BANCA_VALUTA, CASA_LEI, CASA_VALUTA sau None pentru toate"),
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
date_from: Optional[str] = Query(None, description="Data început (YYYY-MM-DD)"),
date_to: Optional[str] = Query(None, description="Data sfârșit (YYYY-MM-DD)"),
partner_name: Optional[str] = Query(None, description="Filtru nume partener"),
bank_account: Optional[str] = Query(None, description="Filtru cont bancă/casă (bancasa)"),
page: int = Query(1, ge=1, description="Pagina"),
page_size: int = Query(50, ge=1, le=10000000, description="Mărimea paginii"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Obține registrul de casă și bancă
- Necesită autentificare JWT
- Suportă filtrare pe tip registru: BANCA_LEI, BANCA_VALUTA, CASA_LEI, CASA_VALUTA
- Suportă filtrare și paginare
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if company not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
server_id = getattr(request.state, 'server_id', None)
# Validează register_type dacă e specificat
valid_types = ['BANCA_LEI', 'BANCA_VALUTA', 'CASA_LEI', 'CASA_VALUTA']
if register_type and register_type not in valid_types:
raise HTTPException(
status_code=400,
detail=f"Tip registru invalid. Valori acceptate: {', '.join(valid_types)}"
)
# Convertește datele
date_from_obj = None
date_to_obj = None
if date_from:
try:
date_from_obj = date.fromisoformat(date_from)
except ValueError:
raise HTTPException(status_code=400, detail="Format dată început invalid")
if date_to:
try:
date_to_obj = date.fromisoformat(date_to)
except ValueError:
raise HTTPException(status_code=400, detail="Format dată sfârșit invalid")
filter_params = RegisterFilter(
company=company,
register_type=register_type,
luna=luna,
an=an,
date_from=date_from_obj,
date_to=date_to_obj,
partner_name=partner_name,
bank_account=bank_account,
page=page,
page_size=page_size
)
result = await TreasuryService.get_bank_cash_register(filter_params, current_user.username, server_id=server_id)
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Eroare la obținerea registrului: {str(e)}")
@router.get("/bank-cash-accounts", response_model=List[str])
async def get_bank_cash_accounts(
request: Request,
company: str = Query(description="Codul firmei"),
register_type: str = Query(description="Tipul registrului: BANCA_LEI, BANCA_VALUTA, CASA_LEI, CASA_VALUTA"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Obține lista distinctă de conturi bancă/casă pentru dropdown
- Necesită autentificare JWT
- Returnează lista de valori bancasa pentru tipul de registru selectat
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if company not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
server_id = getattr(request.state, 'server_id', None)
# Validează register_type
valid_types = ['BANCA_LEI', 'BANCA_VALUTA', 'CASA_LEI', 'CASA_VALUTA']
if register_type not in valid_types:
raise HTTPException(
status_code=400,
detail=f"Tip registru invalid. Valori acceptate: {', '.join(valid_types)}"
)
result = await TreasuryService.get_bank_cash_accounts(int(company), register_type, server_id=server_id)
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Eroare la obținerea conturilor: {str(e)}")

View File

@@ -0,0 +1,94 @@
"""
API Router for Trial Balance (Balanță de Verificare)
Refactored to use service layer with caching
"""
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from typing import Optional
from datetime import date
# import sys # Removed - no longer needed
import os
from shared.auth.dependencies import get_current_user
from shared.auth.models import CurrentUser
from ..models.trial_balance import TrialBalanceResponse
from ..services.trial_balance_service import TrialBalanceService
import logging
logger = logging.getLogger(__name__)
router = APIRouter()
@router.get("/", response_model=TrialBalanceResponse)
async def get_trial_balance(
request: Request,
company: str = Query(description="Codul firmei (ID)"),
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna (1-12), default: luna curentă"),
an: Optional[int] = Query(None, ge=2000, le=2100, description="An, default: anul curent"),
cont_filter: Optional[str] = Query(None, description="Filtru număr cont (ex: '512', '4111')"),
denumire_filter: Optional[str] = Query(None, description="Filtru denumire cont (partial match, case-insensitive)"),
sort_by: str = Query("CONT", description="Coloană pentru sortare"),
sort_order: str = Query("asc", description="Ordinea sortării (asc | desc)"),
page: int = Query(1, ge=1, description="Pagina"),
page_size: int = Query(50, ge=1, le=1000000, description="Mărimea paginii"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Obține balanța de verificare sintetică pentru o firmă
- Necesită autentificare JWT
- Utilizatorul trebuie să aibă acces la firma specificată
- Suportă filtrare după cont și denumire
- Suportă paginare și sortare
- **CACHED 10 min** - folosește sistem cache two-tier (L1 Memory + L2 SQLite)
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if company not in current_user.companies:
raise HTTPException(
status_code=403,
detail=f"Nu aveți acces la firma {company}"
)
server_id = getattr(request.state, 'server_id', None)
# Setează valorile implicite pentru lună și an (luna și anul curent)
current_date = date.today()
if luna is None:
luna = current_date.month
if an is None:
an = current_date.year
# Convert company to int
company_id = int(company)
# Call service (with caching) - all business logic moved to service
data = await TrialBalanceService.get_trial_balance(
company_id=company_id,
luna=luna,
an=an,
cont_filter=cont_filter,
denumire_filter=denumire_filter,
sort_by=sort_by,
sort_order=sort_order,
page=page,
page_size=page_size,
username=current_user.username,
server_id=server_id
)
return TrialBalanceResponse(
success=True,
data=data
)
except ValueError as e:
# Schema not found or validation error
logger.error(f"Validation error in trial balance: {str(e)}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
# Log unexpected errors
logger.error(f"Error fetching trial balance: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500,
detail=f"Eroare la obținerea balanței de verificare: {str(e)}"
)

View File

@@ -0,0 +1,78 @@
"""
Calendar service for fetching available accounting periods
"""
# import sys # Removed - no longer needed
import os
from typing import Optional
from shared.database.oracle_pool import oracle_pool
from ..models.calendar import CalendarPeriod, CalendarPeriodsResponse
from ..cache.decorators import cached
import logging
logger = logging.getLogger(__name__)
class CalendarService:
"""Service for calendar/accounting period operations"""
# Romanian month names for display
MONTH_NAMES_RO = [
"Ianuarie", "Februarie", "Martie", "Aprilie", "Mai", "Iunie",
"Iulie", "August", "Septembrie", "Octombrie", "Noiembrie", "Decembrie"
]
@staticmethod
@cached(cache_type='schema', key_params=['company_id', 'server_id'])
async def _get_schema(company_id: int, server_id: Optional[str] = None) -> str:
"""Get schema for company (CACHED 24h)"""
async with oracle_pool.get_connection(server_id) as connection:
with connection.cursor() as cursor:
cursor.execute("""
SELECT schema FROM CONTAFIN_ORACLE.v_nom_firme
WHERE id_firma = :company_id
""", {'company_id': company_id})
result = cursor.fetchone()
return result[0] if result else None
@staticmethod
@cached(cache_type='calendar_periods', key_params=['company_id', 'server_id'])
async def get_available_periods(company_id: int, server_id: Optional[str] = None) -> CalendarPeriodsResponse:
"""
Get all available accounting periods for a company (CACHED 1h)
Returns periods ordered by year DESC, month DESC with Romanian month names.
"""
schema = await CalendarService._get_schema(company_id, server_id)
if not schema:
logger.warning(f"Schema not found for company {company_id}")
return CalendarPeriodsResponse(periods=[], current_period=None, total_count=0)
async with oracle_pool.get_connection(server_id) as connection:
with connection.cursor() as cursor:
cursor.execute(f"""
SELECT anul, luna
FROM {schema}.calendar
ORDER BY anul DESC, luna DESC
""")
rows = cursor.fetchall()
periods = []
for row in rows:
an, luna = row[0], row[1]
month_name = CalendarService.MONTH_NAMES_RO[luna - 1]
periods.append(CalendarPeriod(
an=an,
luna=luna,
display_name=f"{month_name} {an}"
))
current_period = periods[0] if periods else None
logger.info(f"Loaded {len(periods)} accounting periods for company {company_id}")
return CalendarPeriodsResponse(
periods=periods,
current_period=current_period,
total_count=len(periods)
)

View File

@@ -0,0 +1,324 @@
"""
Service pentru logica facturi - Portează query-urile din aplicația Flask
"""
# import sys # Removed - no longer needed
import os
from shared.database.oracle_pool import oracle_pool
from typing import List, Tuple, Optional
from ..models.invoice import Invoice, InvoiceFilter, InvoiceListResponse, InvoiceSummary
from ..cache.decorators import cached
from decimal import Decimal
import logging
logger = logging.getLogger(__name__)
class InvoiceService:
"""Service pentru gestionarea facturilor"""
@staticmethod
@cached(cache_type='schema', key_params=['company_id', 'server_id'])
async def _get_schema(company_id: int, server_id: Optional[str] = None) -> str:
"""Obține schema pentru company_id (CACHED PERMANENT)"""
async with oracle_pool.get_connection(server_id) as connection:
with connection.cursor() as cursor:
schema_query = """
SELECT schema
FROM CONTAFIN_ORACLE.v_nom_firme
WHERE id_firma = :company_id
"""
cursor.execute(schema_query, {'company_id': company_id})
schema_result = cursor.fetchone()
if not schema_result:
raise ValueError(f"Schema not found for company {company_id}")
return schema_result[0]
@staticmethod
@cached(cache_type='invoices', key_params=['filter_params', 'username', 'server_id'])
async def get_invoices(filter_params: InvoiceFilter, username: str, server_id: Optional[str] = None) -> InvoiceListResponse:
"""
Obține lista de facturi - Query simplu pentru afișare în tabel (CACHED 10 min)
"""
company_id = int(filter_params.company)
schema = await InvoiceService._get_schema(company_id, server_id)
async with oracle_pool.get_connection(server_id) as connection:
with connection.cursor() as cursor:
# Determină conturile în funcție de partner_type
if filter_params.partner_type == "CLIENTI":
conturi = "'4111', '461'"
elif filter_params.partner_type == "FURNIZORI":
conturi = "'401', '404', '462'"
else:
conturi = "'4111'" # default
# Determine period to use: from params or MAX from calendar
if filter_params.luna and filter_params.an:
period_condition = "vp.an = :an AND vp.luna = :luna"
use_param_period = True
else:
period_condition = f"""vp.an = (SELECT anul FROM {schema}.calendar WHERE anul*12+luna = (SELECT MAX(anul*12+luna) FROM {schema}.calendar))
AND vp.luna = (SELECT luna FROM {schema}.calendar WHERE anul*12+luna = (SELECT MAX(anul*12+luna) FROM {schema}.calendar))"""
use_param_period = False
# Query cu calculele corecte pentru solduri
base_query = f"""
SELECT
vp.NUME,
vp.NRACT,
vp.DATAACT,
vp.DATASCAD,
vp.CONTRACT,
vp.COD_FISCAL,
vp.REG_COMERT,
CASE
WHEN vp.CONT IN ('4111','461') THEN vp.PRECDEB + vp.DEBIT -- Total facturat clienți
WHEN vp.CONT IN ('401','404','462') THEN vp.PRECCRED + vp.CREDIT -- Total facturat furnizori
END as total_facturat,
CASE
WHEN vp.CONT IN ('4111','461') THEN vp.PRECCRED + vp.CREDIT -- Încasat clienți
WHEN vp.CONT IN ('401','404','462') THEN vp.PRECDEB + vp.DEBIT -- Achitat furnizori
END as achitat,
CASE
WHEN vp.CONT IN ('4111','461') THEN
(vp.PRECDEB + vp.DEBIT) - (vp.PRECCRED + vp.CREDIT) -- Sold clienți
WHEN vp.CONT IN ('401','404','462') THEN
(vp.PRECCRED + vp.CREDIT) - (vp.PRECDEB + vp.DEBIT) -- Sold furnizori
END as sold,
vp.CONT,
NVL(vp.NUME_VAL, 'RON') as valuta,
CASE
WHEN vp.DATASCAD < SYSDATE THEN 'restant'
ELSE 'in_termen'
END as status
FROM {schema}.vireg_parteneri vp
WHERE {period_condition}
AND (
(:partner_type = 'CLIENTI' AND vp.cont IN ('4111', '461'))
OR
(:partner_type = 'FURNIZORI' AND vp.cont IN ('401', '404', '462'))
)
"""
params = {'partner_type': filter_params.partner_type}
# Add period params if using explicit period
if use_param_period:
params['an'] = filter_params.an
params['luna'] = filter_params.luna
if filter_params.partner_name:
base_query += " AND UPPER(vp.nume) LIKE UPPER(:partner_name)"
params['partner_name'] = f"%{filter_params.partner_name}%"
if filter_params.cont:
base_query += " AND vp.cont = :cont"
params['cont'] = filter_params.cont
if filter_params.min_amount:
base_query += " AND total_facturat >= :min_amount"
params['min_amount'] = filter_params.min_amount
if filter_params.max_amount:
base_query += " AND total_facturat <= :max_amount"
params['max_amount'] = filter_params.max_amount
if filter_params.only_unpaid:
# Nu putem folosi aliasul "sold" în WHERE în Oracle, trebuie să repetăm calculul
base_query += """ AND (
CASE
WHEN vp.CONT IN ('4111','461') THEN
(vp.PRECDEB + vp.DEBIT) - (vp.PRECCRED + vp.CREDIT)
WHEN vp.CONT IN ('401','404','462') THEN
(vp.PRECCRED + vp.CREDIT) - (vp.PRECDEB + vp.DEBIT)
END
) > 0"""
# Count total pentru paginare
count_query = f"SELECT COUNT(*) FROM ({base_query})"
cursor.execute(count_query, params)
total_count = cursor.fetchone()[0]
# Query pentru TOTAL SOLD din TOATE facturile filtrate (nu doar pagina curentă)
total_sold_query = f"""
SELECT NVL(SUM(sold), 0) as total_sold
FROM ({base_query})
"""
cursor.execute(total_sold_query, params)
total_sold_result = cursor.fetchone()
total_sold_all = Decimal(str(total_sold_result[0])) if total_sold_result else Decimal('0.00')
# Get accounting period - use params if provided, else from calendar
if use_param_period:
accounting_period = {
'an': filter_params.an,
'luna': filter_params.luna
}
else:
period_query = f"""
SELECT anul, luna
FROM {schema}.calendar
WHERE anul*12+luna = (SELECT MAX(anul*12+luna) FROM {schema}.calendar)
"""
cursor.execute(period_query)
period_result = cursor.fetchone()
accounting_period = {
'an': period_result[0] if period_result else None,
'luna': period_result[1] if period_result else None
}
# Adaugă ORDER BY și paginare - Ordonare cronologică (DATAACT, NRACT, NUME)
base_query += " ORDER BY vp.DATAACT ASC, vp.NRACT ASC, vp.NUME"
# Paginare Oracle
offset = (filter_params.page - 1) * filter_params.page_size
limit = offset + filter_params.page_size
paginated_query = f"""
SELECT * FROM (
SELECT ROWNUM as rn, t.* FROM ({base_query}) t WHERE ROWNUM <= :limit
) WHERE rn > :offset
"""
params['offset'] = offset
params['limit'] = limit
cursor.execute(paginated_query, params)
rows = cursor.fetchall()
# Procesează rezultatele cu structura nouă
invoices = []
total_amount = Decimal('0.00')
for row in rows:
# Skip ROWNUM, extrage valorile din query-ul nou
nume = row[1]
nract = row[2]
dataact = row[3]
datascad = row[4]
contract = row[5]
cod_fiscal = row[6]
reg_comert = row[7]
total_facturat = Decimal(str(row[8] or 0))
achitat = Decimal(str(row[9] or 0))
sold = Decimal(str(row[10] or 0))
cont = row[11]
valuta = row[12] or 'RON'
status = row[13]
invoice_data = {
'nume': nume or '',
'nract': nract or 0,
'dataact': dataact,
'datascad': datascad,
'contract': contract,
'cod_fiscal': cod_fiscal,
'reg_comert': reg_comert,
'cont': cont,
'totctva': total_facturat,
'achitat': achitat,
'soldfinal': sold,
'valuta': valuta
}
invoice = Invoice(**invoice_data)
invoices.append(invoice)
total_amount += total_facturat
return InvoiceListResponse(
invoices=invoices,
total_count=total_count,
filtered_count=len(invoices),
total_amount=total_amount,
page=filter_params.page,
page_size=filter_params.page_size,
has_more=len(invoices) == filter_params.page_size,
accounting_period=accounting_period,
# Total sold din TOATE facturile filtrate
total_sold_all=total_sold_all
)
@staticmethod
async def get_invoice_details(company: str, invoice_number: str, username: str, server_id: Optional[str] = None) -> Invoice:
"""
Obține detaliile unei facturi specifice
"""
async with oracle_pool.get_connection(server_id) as connection:
with connection.cursor() as cursor:
# Obține schema din v_nom_firme bazat pe id_firma
company_id = int(company)
schema_query = "SELECT schema FROM CONTAFIN_ORACLE.v_nom_firme WHERE id_firma = :company_id"
cursor.execute(schema_query, {'company_id': company_id})
schema_result = cursor.fetchone()
if not schema_result:
raise ValueError(f"Schema nu a fost găsită pentru id_firma {company_id}")
schema = schema_result[0]
# Query simplu pentru detalii factură
detail_query = f"""
SELECT
NUME,
NRACT,
DATAACT,
DATASCAD,
CONTRACT,
COD_FISCAL,
REG_COMERT,
PRECDEB,
PRECCRED,
DEBIT,
CREDIT,
CONT
FROM {schema}.vireg_parteneri
WHERE nract = :invoice_number
AND an = (select anul from {schema}.calendar where anul*12+luna = (select max(anul*12+luna) as anmax from {schema}.calendar))
AND luna = (select luna from {schema}.calendar where anul*12+luna = (select max(anul*12+luna) as anmax from {schema}.calendar))
"""
cursor.execute(detail_query, {'invoice_number': invoice_number})
row = cursor.fetchone()
if not row:
raise ValueError(f"Factura {invoice_number} nu a fost găsită")
# Extrage valorile
nume = row[0]
nract = row[1]
dataact = row[2]
datascad = row[3]
contract = row[4]
cod_fiscal = row[5]
reg_comert = row[6]
precdeb = Decimal(str(row[7] or 0))
preccred = Decimal(str(row[8] or 0))
debit = Decimal(str(row[9] or 0))
credit = Decimal(str(row[10] or 0))
cont = row[11]
# Calculează valorile în funcție de tipul contului
if cont in ('4111', '461'): # CLIENTI
totctva = precdeb + debit
achitat = preccred + credit
soldfinal = precdeb - preccred + debit - credit
else: # FURNIZORI
totctva = preccred + credit
achitat = precdeb + debit
soldfinal = preccred - precdeb + credit - debit
invoice_data = {
'nume': nume or '',
'nract': nract or 0,
'dataact': dataact,
'datascad': datascad,
'contract': contract,
'cod_fiscal': cod_fiscal,
'reg_comert': reg_comert,
'totctva': totctva,
'achitat': achitat,
'soldfinal': soldfinal
}
return Invoice(**invoice_data)

View File

@@ -0,0 +1,410 @@
# import sys # Removed - no longer needed
import os
from typing import Optional, List, Tuple, Any
import oracledb
from shared.database.oracle_pool import oracle_pool
from ..models.treasury import BankCashRegister, RegisterFilter, RegisterListResponse, AccountingPeriod
from ..cache.decorators import cached
from decimal import Decimal
import logging
logger = logging.getLogger(__name__)
class TreasuryService:
"""Service pentru trezorerie - registru casă și bancă"""
@staticmethod
@cached(cache_type='schema', key_params=['company_id', 'server_id'])
async def _get_schema(company_id: int, server_id: Optional[str] = None) -> str:
"""Obține schema pentru company_id (CACHED PERMANENT)"""
async with oracle_pool.get_connection(server_id) as connection:
with connection.cursor() as cursor:
schema_query = """
SELECT schema
FROM CONTAFIN_ORACLE.v_nom_firme
WHERE id_firma = :company_id
"""
cursor.execute(schema_query, {'company_id': company_id})
schema_result = cursor.fetchone()
if not schema_result:
raise ValueError(f"Schema not found for company {company_id}")
return schema_result[0]
@staticmethod
def _get_view_query(schema: str, register_type: Optional[str] = None) -> str:
"""
Construiește query-ul pentru view-ul vbancasa corespunzător.
Dacă register_type este None, returnează UNION ALL pentru toate tipurile.
NU se filtrează pe incasari/plati > 0 - se afișează TOATE înregistrările!
"""
view_configs = {
'BANCA_LEI': {
'view': f'{schema}.vbancasa_5121_cum',
'incasari_col': 'incasari',
'plati_col': 'plati',
'valuta': "'RON'",
'tip': "'BANCA LEI'"
},
'BANCA_VALUTA': {
'view': f'{schema}.vbancasa_5124_cum',
'incasari_col': 'incasval',
'plati_col': 'platival',
'valuta': "COALESCE(numeval, 'EUR')",
'tip': "'BANCA VALUTA'"
},
'CASA_LEI': {
'view': f'{schema}.vbancasa_5311_cum',
'incasari_col': 'incasari',
'plati_col': 'plati',
'valuta': "'RON'",
'tip': "'CASA LEI'"
},
'CASA_VALUTA': {
'view': f'{schema}.vbancasa_5314_cum',
'incasari_col': 'incasval',
'plati_col': 'platival',
'valuta': "COALESCE(numeval, 'EUR')",
'tip': "'CASA VALUTA'"
}
}
def build_select(config):
# NU se filtrează - se afișează TOATE înregistrările
# SOLD CUMULAT: Running balance per bancasa using window function
# NULL-date rows (opening balance) come first due to NULLS FIRST
return f"""
SELECT
nume, nract, dataact, bancasa,
{config['incasari_col']} as incasari,
{config['plati_col']} as plati,
SUM({config['incasari_col']} - {config['plati_col']}) OVER (
PARTITION BY bancasa
ORDER BY dataact ASC NULLS FIRST, nract ASC NULLS FIRST
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
) as sold,
{config['valuta']} as valuta,
{config['tip']} as tip_registru,
explicatia
FROM {config['view']}
"""
if register_type and register_type in view_configs:
return build_select(view_configs[register_type])
else:
# UNION ALL pentru toate tipurile
queries = [build_select(cfg) for cfg in view_configs.values()]
return " UNION ALL ".join(queries)
@staticmethod
@cached(cache_type='treasury', key_params=['filter_params', 'username', 'server_id'])
async def get_bank_cash_register(filter_params: RegisterFilter, username: str, server_id: Optional[str] = None) -> RegisterListResponse:
"""
Obține registrul de casă și bancă din vbancasa views (CACHED 10 min)
IMPORTANT: PACK_SESIUNE.SETAN și SETLUNA trebuie executate în ACEEAȘI
tranzacție cu SELECT-ul din vbancasa* views!
Folosim un bloc PL/SQL anonim care:
1. Obține anul și luna curentă din calendar
2. Apelează PACK_SESIUNE.SETAN și SETLUNA
3. Execută SELECT-ul din vbancasa*
Toate în aceeași tranzacție!
"""
company_id = int(filter_params.company)
schema = await TreasuryService._get_schema(company_id, server_id)
async with oracle_pool.get_connection(server_id) as connection:
with connection.cursor() as cursor:
# Construiește query-ul pentru tipul de registru selectat
base_select = TreasuryService._get_view_query(schema, filter_params.register_type)
# Construiește WHERE conditions
where_conditions = []
# Date filter preserves NULL-date rows (opening balance)
# for correct cumulative sum calculation
if filter_params.date_from and filter_params.date_to:
where_conditions.append(f"(dataact IS NULL OR (dataact >= TO_DATE('{filter_params.date_from.strftime('%Y-%m-%d')}', 'YYYY-MM-DD') AND dataact <= TO_DATE('{filter_params.date_to.strftime('%Y-%m-%d')}', 'YYYY-MM-DD')))")
elif filter_params.date_from:
where_conditions.append(f"(dataact IS NULL OR dataact >= TO_DATE('{filter_params.date_from.strftime('%Y-%m-%d')}', 'YYYY-MM-DD'))")
elif filter_params.date_to:
where_conditions.append(f"(dataact IS NULL OR dataact <= TO_DATE('{filter_params.date_to.strftime('%Y-%m-%d')}', 'YYYY-MM-DD'))")
if filter_params.partner_name:
# Escape single quotes pentru SQL
partner_escaped = filter_params.partner_name.replace("'", "''")
where_conditions.append(f"UPPER(nume) LIKE UPPER('%{partner_escaped}%')")
if filter_params.bank_account:
# Escape single quotes pentru SQL
bank_escaped = filter_params.bank_account.replace("'", "''")
where_conditions.append(f"bancasa = '{bank_escaped}'")
where_clause = ""
if where_conditions:
where_clause = " WHERE " + " AND ".join(where_conditions)
# Paginare Oracle
offset = (filter_params.page - 1) * filter_params.page_size
limit_val = filter_params.page_size
# Determine period to use: from params or MAX from calendar
if filter_params.luna and filter_params.an:
use_param_period = True
period_select = f"""
v_an := :param_an;
v_luna := :param_luna;
"""
else:
use_param_period = False
period_select = f"""
SELECT anul, luna INTO v_an, v_luna
FROM {schema}.calendar
WHERE anul*12+luna = (SELECT MAX(anul*12+luna) FROM {schema}.calendar);
"""
# Bloc PL/SQL anonim care face totul într-o singură tranzacție:
# 1. Obține anul și luna din params sau calendar
# 2. Setează PACK_SESIUNE.SETAN și SETLUNA
# 3. Returnează datele prin REF CURSOR
# IMPORTANT: Folosim ROW_NUMBER() pentru paginare corectă cu ORDER BY NULLS FIRST
plsql_block = f"""
DECLARE
v_an NUMBER;
v_luna NUMBER;
v_cursor SYS_REFCURSOR;
BEGIN
-- Obține anul și luna din parametri sau calendar
{period_select}
-- Setează contextul de sesiune (OBLIGATORIU înainte de SELECT din vbancasa*)
{schema}.PACK_SESIUNE.SETAN(v_an);
{schema}.PACK_SESIUNE.SETLUNA(v_luna);
-- Return accounting period
:out_an := v_an;
:out_luna := v_luna;
-- Returnează datele prin cursor cu ROW_NUMBER pentru paginare corectă
-- Pentru rânduri cu dataact=NULL (solduri precedente), sortare după bancasa
-- Pentru rânduri cu date, sortare după data, număr, bancasa
OPEN :result_cursor FOR
SELECT * FROM (
SELECT t.*, ROW_NUMBER() OVER (
ORDER BY dataact ASC NULLS FIRST,
CASE WHEN dataact IS NULL THEN bancasa END ASC,
nract ASC NULLS FIRST,
bancasa ASC
) as rn
FROM ({base_select}) t{where_clause}
) WHERE rn > {offset} AND rn <= {offset + limit_val};
END;
"""
# Creează cursor pentru rezultate (oracledb.CURSOR pentru REF CURSOR)
result_cursor = cursor.var(oracledb.CURSOR)
out_an = cursor.var(int)
out_luna = cursor.var(int)
# Build params dict
exec_params = {'result_cursor': result_cursor, 'out_an': out_an, 'out_luna': out_luna}
if use_param_period:
exec_params['param_an'] = filter_params.an
exec_params['param_luna'] = filter_params.luna
# Execută blocul PL/SQL cu REF CURSOR
cursor.execute(plsql_block, exec_params)
# Get accounting period values
accounting_year = out_an.getvalue()
accounting_month = out_luna.getvalue()
# Obține rezultatele din cursor
ref_cursor = result_cursor.getvalue()
rows = ref_cursor.fetchall()
ref_cursor.close()
# Pentru count total, executăm alt bloc PL/SQL
count_plsql = f"""
DECLARE
v_an NUMBER;
v_luna NUMBER;
BEGIN
-- Obține anul și luna din parametri sau calendar
{period_select}
{schema}.PACK_SESIUNE.SETAN(v_an);
{schema}.PACK_SESIUNE.SETLUNA(v_luna);
SELECT COUNT(*) INTO :total_count FROM ({base_select}) sub{where_clause};
END;
"""
total_count_var = cursor.var(int)
count_params = {'total_count': total_count_var}
if use_param_period:
count_params['param_an'] = filter_params.an
count_params['param_luna'] = filter_params.luna
cursor.execute(count_plsql, count_params)
total_count = total_count_var.getvalue()
# Query pentru TOTALURI din TOATE înregistrările filtrate (nu doar pagina curentă)
# sold_precedent = suma sold pentru rânduri cu dataact IS NULL
# total_incasari = suma incasari pentru rânduri cu dataact IS NOT NULL
# total_plati = suma plati pentru rânduri cu dataact IS NOT NULL
# Notă: where_clause poate fi gol sau poate conține "WHERE ..."
# Dacă e gol, adăugăm WHERE; dacă nu, adăugăm AND
dataact_null_cond = " AND dataact IS NULL" if where_clause else " WHERE dataact IS NULL"
dataact_not_null_cond = " AND dataact IS NOT NULL" if where_clause else " WHERE dataact IS NOT NULL"
totals_plsql = f"""
DECLARE
v_an NUMBER;
v_luna NUMBER;
BEGIN
-- Obține anul și luna din parametri sau calendar
{period_select}
{schema}.PACK_SESIUNE.SETAN(v_an);
{schema}.PACK_SESIUNE.SETLUNA(v_luna);
-- Sold precedent: suma sold pentru rânduri fără dată (opening balance)
SELECT NVL(SUM(sold), 0) INTO :sold_precedent_all
FROM ({base_select}) sub{where_clause}{dataact_null_cond};
-- Total încasări: suma incasari pentru rânduri cu dată (transactions)
SELECT NVL(SUM(incasari), 0) INTO :total_incasari_all
FROM ({base_select}) sub{where_clause}{dataact_not_null_cond};
-- Total plăți: suma plati pentru rânduri cu dată (transactions)
SELECT NVL(SUM(plati), 0) INTO :total_plati_all
FROM ({base_select}) sub{where_clause}{dataact_not_null_cond};
END;
"""
sold_precedent_all_var = cursor.var(oracledb.NUMBER)
total_incasari_all_var = cursor.var(oracledb.NUMBER)
total_plati_all_var = cursor.var(oracledb.NUMBER)
totals_params = {
'sold_precedent_all': sold_precedent_all_var,
'total_incasari_all': total_incasari_all_var,
'total_plati_all': total_plati_all_var
}
if use_param_period:
totals_params['param_an'] = filter_params.an
totals_params['param_luna'] = filter_params.luna
cursor.execute(totals_plsql, totals_params)
sold_precedent_all = Decimal(str(sold_precedent_all_var.getvalue() or 0))
total_incasari_all = Decimal(str(total_incasari_all_var.getvalue() or 0))
total_plati_all = Decimal(str(total_plati_all_var.getvalue() or 0))
sold_final_all = sold_precedent_all + total_incasari_all - total_plati_all
# Procesare rezultate
registers = []
total_incasari = Decimal('0.00')
total_plati = Decimal('0.00')
for row in rows:
# Coloane: nume, nract, dataact, bancasa, incasari, plati, sold, valuta, tip_registru, explicatia, rn
# row[0-9] = date, row[10] = rn (ROW_NUMBER la final)
register_data = BankCashRegister(
nume=row[0] or '',
nract=row[1],
dataact=row[2],
nume_cont_bancar=row[3] or '',
incasari=Decimal(str(row[4] or 0)),
plati=Decimal(str(row[5] or 0)),
sold=Decimal(str(row[6] or 0)),
valuta=row[7],
tip_registru=row[8],
explicatia=row[9] or ''
)
registers.append(register_data)
total_incasari += register_data.incasari
total_plati += register_data.plati
logger.info(f"Treasury query for company {company_id}, type={filter_params.register_type}: {len(registers)} records, total={total_count}")
return RegisterListResponse(
registers=registers,
total_count=total_count,
filtered_count=len(registers),
total_incasari=total_incasari,
total_plati=total_plati,
page=filter_params.page,
page_size=filter_params.page_size,
has_more=len(registers) == filter_params.page_size,
accounting_period=AccountingPeriod(an=accounting_year, luna=accounting_month),
# Totaluri din TOATE înregistrările filtrate
sold_precedent_all=sold_precedent_all,
total_incasari_all=total_incasari_all,
total_plati_all=total_plati_all,
sold_final_all=sold_final_all
)
@staticmethod
@cached(cache_type='treasury', key_params=['company_id', 'register_type', 'server_id'])
async def get_bank_cash_accounts(company_id: int, register_type: str, server_id: Optional[str] = None) -> List[str]:
"""
Obține lista distinctă de conturi bancă/casă (bancasa) pentru dropdown.
Cached pentru performanță.
IMPORTANT: Trebuie să setăm contextul PACK_SESIUNE înainte de a accesa vbancasa views!
"""
schema = await TreasuryService._get_schema(company_id, server_id)
# Map register_type to view
view_map = {
'BANCA_LEI': f'{schema}.vbancasa_5121_cum',
'BANCA_VALUTA': f'{schema}.vbancasa_5124_cum',
'CASA_LEI': f'{schema}.vbancasa_5311_cum',
'CASA_VALUTA': f'{schema}.vbancasa_5314_cum'
}
if register_type not in view_map:
return []
view_name = view_map[register_type]
async with oracle_pool.get_connection(server_id) as connection:
with connection.cursor() as cursor:
# PL/SQL block to set session context and get accounts
plsql_block = f"""
DECLARE
v_an NUMBER;
v_luna NUMBER;
BEGIN
-- Get current year and month from calendar
SELECT anul, luna INTO v_an, v_luna
FROM {schema}.calendar
WHERE anul*12+luna = (SELECT MAX(anul*12+luna) FROM {schema}.calendar);
-- Set session context (REQUIRED before accessing vbancasa* views)
{schema}.PACK_SESIUNE.SETAN(v_an);
{schema}.PACK_SESIUNE.SETLUNA(v_luna);
-- Return accounts via cursor
OPEN :result_cursor FOR
SELECT DISTINCT bancasa
FROM {view_name}
WHERE bancasa IS NOT NULL
ORDER BY bancasa;
END;
"""
result_cursor = cursor.var(oracledb.CURSOR)
cursor.execute(plsql_block, {'result_cursor': result_cursor})
ref_cursor = result_cursor.getvalue()
rows = ref_cursor.fetchall()
ref_cursor.close()
accounts = [row[0] for row in rows if row[0]]
logger.info(f"Found {len(accounts)} bank/cash accounts for company {company_id}, type={register_type}")
return accounts

View File

@@ -0,0 +1,219 @@
"""
Service pentru Trial Balance (Balanță de Verificare) - Query VBAL VIEW
Refactored to use caching system for optimal performance
"""
# import sys # Removed - no longer needed
import os
from typing import Dict, Any, Optional
from shared.database.oracle_pool import oracle_pool
from ..models.trial_balance import (
TrialBalanceItem,
TrialBalanceFilters,
TrialBalancePagination,
TrialBalanceResponse
)
from ..cache.decorators import cached
from decimal import Decimal
import math
import logging
logger = logging.getLogger(__name__)
class TrialBalanceService:
"""Service pentru gestionarea balanței de verificare cu cache"""
@staticmethod
@cached(cache_type='schema', key_params=['company_id', 'server_id'])
async def _get_schema(company_id: int, server_id: Optional[str] = None) -> str:
"""
Obține schema pentru company_id (CACHED 24h)
This is cached permanently because company schemas rarely change.
"""
async with oracle_pool.get_connection(server_id) as connection:
with connection.cursor() as cursor:
schema_query = """
SELECT schema
FROM CONTAFIN_ORACLE.v_nom_firme
WHERE id_firma = :company_id
"""
cursor.execute(schema_query, {'company_id': company_id})
schema_result = cursor.fetchone()
if not schema_result:
raise ValueError(f"Schema not found for company {company_id}")
return schema_result[0]
@staticmethod
@cached(cache_type='trial_balance', key_params=['company_id', 'luna', 'an', 'cont_filter',
'denumire_filter', 'sort_by', 'sort_order',
'page', 'page_size', 'username', 'server_id'])
async def get_trial_balance(
company_id: int,
luna: int,
an: int,
cont_filter: str | None,
denumire_filter: str | None,
sort_by: str,
sort_order: str,
page: int,
page_size: int,
username: str,
server_id: Optional[str] = None
) -> Dict[str, Any]:
"""
Obține balanța de verificare sintetică (CACHED 10 min)
Cache key includes all filter parameters to ensure unique cache entries
for different query variations.
Args:
company_id: ID firmei
luna: Luna (1-12)
an: Anul
cont_filter: Filtru număr cont (optional)
denumire_filter: Filtru denumire cont (optional)
sort_by: Coloană pentru sortare
sort_order: Ordinea sortării (asc/desc)
page: Pagina
page_size: Mărimea paginii
username: Username pentru cache tracking
server_id: Optional Oracle server identifier for multi-server support
Returns:
Dictionary cu items, pagination, filters_applied
"""
# Get schema (cached separately)
schema = await TrialBalanceService._get_schema(company_id, server_id)
# Validate sort_order
if sort_order.lower() not in ['asc', 'desc']:
sort_order = 'asc'
# Validate sort_by (prevent SQL injection)
valid_sort_columns = ['CONT', 'DENUMIRE', 'PRECDEB', 'PRECCRED',
'RULDEB', 'RULCRED', 'SOLDDEB', 'SOLDCRED']
if sort_by.upper() not in valid_sort_columns:
sort_by = 'CONT'
async with oracle_pool.get_connection(server_id) as connection:
with connection.cursor() as cursor:
# Build base query for VBAL VIEW
base_query = f"""
SELECT
CONT,
NVL(DENUMIRE, '') as DENUMIRE,
NVL(PRECDEB, 0) as PRECDEB,
NVL(PRECCRED, 0) as PRECCRED,
NVL(RULDEB, 0) as RULDEB,
NVL(RULCRED, 0) as RULCRED,
NVL(SOLDDEB, 0) as SOLDDEB,
NVL(SOLDCRED, 0) as SOLDCRED
FROM {schema}.VBAL
WHERE AN = :an
AND LUNA = :luna
"""
params = {
'an': an,
'luna': luna
}
# Add dynamic filters
if cont_filter:
base_query += " AND CONT LIKE :cont_filter"
params['cont_filter'] = f"{cont_filter}%"
if denumire_filter:
base_query += " AND UPPER(DENUMIRE) LIKE UPPER(:denumire_filter)"
params['denumire_filter'] = f"%{denumire_filter}%"
# Count total for pagination
count_query = f"SELECT COUNT(*) FROM ({base_query})"
cursor.execute(count_query, params)
total_count = cursor.fetchone()[0]
# Query pentru TOTALURI din TOATE înregistrările filtrate (nu doar pagina curentă)
totals_query = f"""
SELECT
NVL(SUM(PRECDEB), 0) as total_prec_deb,
NVL(SUM(PRECCRED), 0) as total_prec_cred,
NVL(SUM(RULDEB), 0) as total_rul_deb,
NVL(SUM(RULCRED), 0) as total_rul_cred,
NVL(SUM(SOLDDEB), 0) as total_sold_deb,
NVL(SUM(SOLDCRED), 0) as total_sold_cred
FROM ({base_query})
"""
cursor.execute(totals_query, params)
totals_row = cursor.fetchone()
totals = {
"total_sold_precedent_debit": Decimal(str(totals_row[0])) if totals_row else Decimal('0.00'),
"total_sold_precedent_credit": Decimal(str(totals_row[1])) if totals_row else Decimal('0.00'),
"total_rulaj_lunar_debit": Decimal(str(totals_row[2])) if totals_row else Decimal('0.00'),
"total_rulaj_lunar_credit": Decimal(str(totals_row[3])) if totals_row else Decimal('0.00'),
"total_sold_final_debit": Decimal(str(totals_row[4])) if totals_row else Decimal('0.00'),
"total_sold_final_credit": Decimal(str(totals_row[5])) if totals_row else Decimal('0.00')
}
# Add sorting
base_query += f" ORDER BY {sort_by.upper()} {sort_order.upper()}"
# Pagination (Oracle ROWNUM with ORDER BY)
offset = (page - 1) * page_size
limit = offset + page_size
paginated_query = f"""
SELECT * FROM (
SELECT a.*, ROWNUM rnum FROM (
{base_query}
) a WHERE ROWNUM <= :limit
) WHERE rnum > :offset
"""
params['offset'] = offset
params['limit'] = limit
cursor.execute(paginated_query, params)
rows = cursor.fetchall()
# Process results
# Index: CONT(0), DENUMIRE(1), PRECDEB(2), PRECCRED(3),
# RULDEB(4), RULCRED(5), SOLDDEB(6), SOLDCRED(7), rnum(8)
items = []
for row in rows:
item = TrialBalanceItem(
cont=row[0] or '',
denumire=row[1] or '',
sold_precedent_debit=Decimal(str(row[2] or 0)),
sold_precedent_credit=Decimal(str(row[3] or 0)),
rulaj_lunar_debit=Decimal(str(row[4] or 0)),
rulaj_lunar_credit=Decimal(str(row[5] or 0)),
sold_final_debit=Decimal(str(row[6] or 0)),
sold_final_credit=Decimal(str(row[7] or 0))
)
items.append(item.dict())
# Calculate pagination
total_pages = math.ceil(total_count / page_size) if page_size > 0 else 0
# Build response
return {
"items": items,
"pagination": {
"total_items": total_count,
"total_pages": total_pages,
"current_page": page,
"page_size": page_size
},
"filters_applied": {
"luna": luna,
"an": an,
"cont_filter": cont_filter,
"denumire_filter": denumire_filter
},
# Totaluri din TOATE înregistrările filtrate (nu doar pagina curentă)
"totals": totals
}

View File

@@ -0,0 +1,313 @@
"""
Session Management for Telegram Bot
This module handles session state for Telegram users, specifically managing
the active company selection for command handlers.
"""
import logging
import json
from typing import Dict, Any, Optional
from datetime import datetime
from backend.modules.telegram.db.operations import (
create_session,
get_user_active_session,
update_session_state,
delete_user_sessions
)
logger = logging.getLogger(__name__)
class ConversationSession:
"""
Manages session state for a single user.
Attributes:
telegram_user_id: Telegram user ID
session_id: UUID of the session
active_company_id: Selected company ID
active_company_name: Selected company name
active_company_cui: Selected company CUI
"""
def __init__(
self,
telegram_user_id: int,
session_id: Optional[str] = None
):
"""
Initialize a session.
Args:
telegram_user_id: Telegram user ID
session_id: Existing session ID (if resuming), or None for new session
"""
self.telegram_user_id = telegram_user_id
self.session_id = session_id
self.created_at = datetime.now()
self.updated_at = datetime.now()
# Active company for this session
self.active_company_id: Optional[int] = None
self.active_company_name: Optional[str] = None
self.active_company_cui: Optional[str] = None
def set_active_company(
self,
company_id: int,
company_name: str,
company_cui: Optional[str] = None
):
"""
Set the active company for this session.
Args:
company_id: Company ID
company_name: Company name
company_cui: Company CUI (optional)
"""
self.active_company_id = company_id
self.active_company_name = company_name
self.active_company_cui = company_cui
self.updated_at = datetime.now()
logger.info(
f"Active company set for user {self.telegram_user_id}: "
f"{company_name} (ID: {company_id})"
)
def get_active_company(self) -> Optional[Dict[str, Any]]:
"""
Get the active company information.
Returns:
Dict with company info (id, name, cui) or None if no company selected
"""
if self.active_company_id is not None:
return {
"id": self.active_company_id,
"name": self.active_company_name,
"cui": self.active_company_cui
}
return None
def clear_active_company(self):
"""
Clear the active company selection.
"""
logger.info(
f"Clearing active company for user {self.telegram_user_id} "
f"(was: {self.active_company_name})"
)
self.active_company_id = None
self.active_company_name = None
self.active_company_cui = None
self.updated_at = datetime.now()
def to_dict(self) -> Dict[str, Any]:
"""
Serialize session to dictionary (for database storage).
Returns:
Dict representation of session
"""
return {
"telegram_user_id": self.telegram_user_id,
"session_id": self.session_id,
"created_at": self.created_at.isoformat(),
"updated_at": self.updated_at.isoformat(),
"active_company_id": self.active_company_id,
"active_company_name": self.active_company_name,
"active_company_cui": self.active_company_cui
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'ConversationSession':
"""
Deserialize session from dictionary.
Args:
data: Dict representation of session
Returns:
ConversationSession instance
"""
session = cls(
telegram_user_id=data["telegram_user_id"],
session_id=data.get("session_id")
)
# Restore active company
session.active_company_id = data.get("active_company_id")
session.active_company_name = data.get("active_company_name")
session.active_company_cui = data.get("active_company_cui")
if "created_at" in data:
session.created_at = datetime.fromisoformat(data["created_at"])
if "updated_at" in data:
session.updated_at = datetime.fromisoformat(data["updated_at"])
return session
class SessionManager:
"""
Manages sessions for all users.
Provides methods to create, retrieve, update, and delete sessions.
Sessions are stored both in memory (for quick access) and in database
(for persistence).
"""
def __init__(self):
"""
Initialize the session manager.
"""
self._sessions: Dict[int, ConversationSession] = {}
logger.info("SessionManager initialized")
async def get_or_create_session(
self,
telegram_user_id: int
) -> ConversationSession:
"""
Get existing session for a user or create a new one.
Args:
telegram_user_id: Telegram user ID
Returns:
ConversationSession for the user
"""
# Check in-memory cache first
if telegram_user_id in self._sessions:
logger.debug(f"Found session in cache for user {telegram_user_id}")
return self._sessions[telegram_user_id]
# Check database for existing session
session_data = await get_user_active_session(telegram_user_id)
if session_data:
# Restore session from database
conversation_state_json = session_data.get('conversation_state')
if conversation_state_json:
try:
session_dict = json.loads(conversation_state_json)
session = ConversationSession.from_dict(session_dict)
session.session_id = session_data['session_id']
self._sessions[telegram_user_id] = session
logger.info(f"Restored session from database for user {telegram_user_id}")
return session
except json.JSONDecodeError as e:
logger.error(f"Failed to parse session state: {e}")
# Create new session
session = ConversationSession(telegram_user_id)
# Save to database
session_id = await create_session(
telegram_user_id=telegram_user_id,
conversation_state=json.dumps(session.to_dict()),
expires_in_hours=24
)
session.session_id = session_id
self._sessions[telegram_user_id] = session
logger.info(f"Created new session for user {telegram_user_id} (ID: {session_id})")
return session
async def save_session(self, telegram_user_id: int) -> bool:
"""
Save session to database.
Args:
telegram_user_id: Telegram user ID
Returns:
bool: True if saved successfully
"""
session = self._sessions.get(telegram_user_id)
if not session or not session.session_id:
logger.warning(f"No session to save for user {telegram_user_id}")
return False
try:
conversation_state = json.dumps(session.to_dict())
success = await update_session_state(
session_id=session.session_id,
conversation_state=conversation_state
)
if success:
logger.debug(f"Saved session for user {telegram_user_id}")
else:
logger.warning(f"Failed to save session for user {telegram_user_id}")
return success
except Exception as e:
logger.error(f"Error saving session for user {telegram_user_id}: {e}")
return False
async def delete_session(self, telegram_user_id: int) -> bool:
"""
Delete session completely (from memory and database).
Args:
telegram_user_id: Telegram user ID
Returns:
bool: True if deleted successfully
"""
# Remove from memory
if telegram_user_id in self._sessions:
del self._sessions[telegram_user_id]
# Delete from database
success = await delete_user_sessions(telegram_user_id)
if success:
logger.info(f"Deleted session for user {telegram_user_id}")
else:
logger.warning(f"Failed to delete session for user {telegram_user_id}")
return success
def get_active_sessions_count(self) -> int:
"""
Get count of active sessions in memory.
Returns:
int: Number of active sessions
"""
return len(self._sessions)
# Singleton instance
_session_manager_instance: Optional[SessionManager] = None
def get_session_manager() -> SessionManager:
"""
Get or create the singleton SessionManager instance.
Returns:
SessionManager: Singleton instance
"""
global _session_manager_instance
if _session_manager_instance is None:
_session_manager_instance = SessionManager()
return _session_manager_instance
# Export main classes and functions
__all__ = [
'ConversationSession',
'SessionManager',
'get_session_manager'
]

View File

@@ -0,0 +1,969 @@
"""
API Client for ROA2WEB Backend Communication
This module provides an async HTTP client for communicating with the FastAPI backend.
Handles authentication, requests, error handling, and response parsing.
"""
import logging
import os
from typing import Optional, Dict, Any, List
from datetime import datetime
import httpx
from httpx import AsyncClient, Response, HTTPError, ConnectError
logger = logging.getLogger(__name__)
# Backend configuration from environment
# Default to port 8000 (production) instead of 8001 (development)
BACKEND_URL = os.getenv("BACKEND_URL", "http://localhost:8000")
REQUEST_TIMEOUT = float(os.getenv("API_TIMEOUT", "30.0")) # 30 seconds default
class BackendAPIClient:
"""
Async HTTP client for ROA2WEB FastAPI backend.
Provides methods for all API endpoints used by the Telegram bot:
- Dashboard data
- Invoices search and retrieval
- Treasury/payment data
- Report exports
- Company listings
- User authentication and token management
"""
def __init__(self, base_url: str = BACKEND_URL):
"""
Initialize the API client.
Args:
base_url: Base URL of the FastAPI backend
"""
self.base_url = base_url.rstrip('/')
self.client: Optional[AsyncClient] = None
logger.info(f"Backend API client initialized with base URL: {self.base_url}")
async def __aenter__(self):
"""Async context manager entry."""
self.client = AsyncClient(
base_url=self.base_url,
timeout=REQUEST_TIMEOUT,
follow_redirects=True
)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit."""
if self.client:
await self.client.aclose()
def _get_auth_headers(self, jwt_token: str) -> Dict[str, str]:
"""
Generate authentication headers with JWT token.
Args:
jwt_token: JWT access token
Returns:
Dict with Authorization header
"""
return {
"Authorization": f"Bearer {jwt_token}",
"Content-Type": "application/json"
}
async def _handle_response(self, response: Response) -> Dict[str, Any]:
"""
Handle API response and extract data.
Args:
response: HTTP response object
Returns:
Dict: Response JSON data
Raises:
HTTPError: If response status is not successful
"""
try:
response.raise_for_status()
return response.json()
except HTTPError as e:
logger.error(f"API request failed: {e}")
logger.error(f"Response body: {response.text}")
raise
except Exception as e:
logger.error(f"Failed to parse response: {e}")
raise
# =========================================================================
# AUTHENTICATION & USER ENDPOINTS
# =========================================================================
async def verify_user(
self,
oracle_username: str,
linking_code: str,
server_id: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""
Verify user exists in Oracle and get JWT token.
Called during Telegram linking process (auto-linking flow).
Args:
oracle_username: Oracle username extracted from linking code
linking_code: The 8-character linking code for validation
Returns:
Dict with:
- success: True if verification succeeded
- access_token: JWT access token
- refresh_token: JWT refresh token
- user: Dict with user_id, username, companies, permissions
- message: Status message
None if user not found or error
Example:
result = await client.verify_user("JOHN.DOE", "ABC12345")
if result and result['success']:
jwt_token = result['access_token']
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
# Flow A: Auto-linking (no password required)
response = await self.client.post(
"/api/telegram/auth/verify-user",
json={
"linking_code": linking_code,
"oracle_username": oracle_username,
"server_id": server_id
}
)
return await self._handle_response(response)
except ConnectError as e:
logger.error(f"Cannot connect to backend at {self.base_url}: {e}")
logger.error("Verify that backend service is running and BACKEND_URL is correct")
return None
except HTTPError as e:
if e.response.status_code == 404:
logger.warning(f"User {oracle_username} not found in Oracle")
return None
logger.error(f"Failed to verify user {oracle_username}: {e}")
return None
except Exception as e:
logger.error(f"Error verifying user: {e}")
return None
async def refresh_token(self, refresh_token: str) -> Optional[str]:
"""
Refresh JWT token for a user.
Args:
refresh_token: JWT refresh token
Returns:
str: New JWT access token, None if failed
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
response = await self.client.post(
"/api/telegram/auth/refresh-token",
json={"refresh_token": refresh_token}
)
data = await self._handle_response(response)
return data.get('access_token')
except Exception as e:
logger.error(f"Failed to refresh token: {e}")
return None
async def verify_email(self, email: str, server_id: Optional[str] = None) -> dict:
"""
Verify if email exists in Oracle database
Args:
email: Email address to verify
server_id: Optional Oracle server ID (for multi-server mode)
Returns:
dict with 'success' (bool), 'username' (str or None), and 'message' (str)
Raises:
httpx.HTTPError: On network or HTTP errors
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
response = await self.client.post(
"/api/telegram/auth/verify-email",
json={"email": email, "server_id": server_id}
)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error verifying email {email}: {e.response.status_code}")
return {
"success": False,
"username": None,
"message": "Eroare la verificarea email-ului"
}
except Exception as e:
logger.error(f"Failed to verify email {email}: {e}")
return {
"success": False,
"username": None,
"message": "Eroare la verificarea email-ului"
}
async def login_with_email(
self,
email: str,
password: str,
telegram_user_id: int,
session_token: str,
server_id: Optional[str] = None
) -> dict:
"""
Login via email + password with session token
Args:
email: User email address
password: Oracle password
telegram_user_id: Telegram user ID
session_token: Signed token from code validation
server_id: Optional Oracle server ID (for multi-server mode)
Returns:
Login response with JWT tokens and user data
Raises:
httpx.HTTPError: On network or HTTP errors
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
response = await self.client.post(
"/api/telegram/auth/login-with-email",
json={
"email": email,
"password": password,
"telegram_user_id": telegram_user_id,
"session_token": session_token,
"server_id": server_id
},
timeout=30.0 # 30 seconds timeout
)
response.raise_for_status()
data = response.json()
logger.info(f"Email login successful for user {telegram_user_id}")
return data
except httpx.HTTPStatusError as e:
logger.error(f"Email login HTTP error: {e.response.status_code} - {e.response.text}")
# Parse error detail if available
try:
error_data = e.response.json()
return {
"success": False,
"message": error_data.get("detail", "Autentificare eșuată")
}
except:
return {
"success": False,
"message": "Autentificare eșuată"
}
except httpx.TimeoutException:
logger.error("Email login timeout")
return {
"success": False,
"message": "Timeout. Te rugăm să încerci din nou."
}
except Exception as e:
logger.error(f"Email login error: {e}", exc_info=True)
return {
"success": False,
"message": "Eroare de conexiune"
}
async def switch_server(
self,
jwt_token: str,
oracle_username: str,
new_server_id: str,
oracle_password: str = None
) -> dict:
"""
Switch the active Oracle server for the authenticated user.
Args:
jwt_token: Current JWT access token (used for authentication)
oracle_username: Oracle username of the current user
new_server_id: Target Oracle server ID
oracle_password: Oracle password on the new server (required if servers have different passwords)
Returns:
Dict with success, access_token, refresh_token, message
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
payload = {"oracle_username": oracle_username, "new_server_id": new_server_id}
if oracle_password:
payload["oracle_password"] = oracle_password
response = await self.client.post(
"/api/telegram/auth/switch-server",
json=payload,
headers=self._get_auth_headers(jwt_token)
)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
logger.error(f"Switch server HTTP error: {e.response.status_code}")
try:
return {"success": False, "message": e.response.json().get("detail", "Eroare")}
except Exception:
return {"success": False, "message": "Eroare la schimbarea serverului"}
except Exception as e:
logger.error(f"Switch server error: {e}")
return {"success": False, "message": "Eroare de conexiune"}
async def get_user_companies(self, jwt_token: str) -> List[Dict[str, Any]]:
"""
Get list of companies the user has access to.
Args:
jwt_token: JWT access token
Returns:
List of company dicts with id, nume_firma, cui, etc.
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
response = await self.client.get(
"/api/companies",
headers=self._get_auth_headers(jwt_token)
)
data = await self._handle_response(response)
# Backend returns {"companies": [...], "total_count": N}
if isinstance(data, dict) and "companies" in data:
return data["companies"]
return data if isinstance(data, list) else []
except Exception as e:
logger.error(f"Failed to get companies: {e}")
return []
# =========================================================================
# DASHBOARD ENDPOINTS
# =========================================================================
async def get_dashboard_data(
self,
company_id: int,
jwt_token: str
) -> Optional[Dict[str, Any]]:
"""
Get dashboard statistics for a company.
Args:
company_id: Company ID
jwt_token: JWT access token
Returns:
Dict with dashboard data (sold_total, facturi, plati, etc.)
Includes _cache_hit and _response_time_ms metadata
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
# Add cache metadata header for Telegram Bot
headers = self._get_auth_headers(jwt_token)
headers['X-Include-Cache-Metadata'] = 'true'
response = await self.client.get(
"/api/reports/dashboard/summary",
params={"company": str(company_id)},
headers=headers
)
return await self._handle_response(response)
except Exception as e:
logger.error(f"Failed to get dashboard data for company {company_id}: {e}")
return None
async def get_treasury_breakdown(
self,
company_id: int,
jwt_token: str
) -> Optional[Dict[str, Any]]:
"""
Get detailed treasury breakdown (casa + banca accounts).
Args:
company_id: Company ID
jwt_token: JWT access token
Returns:
Dict with treasury breakdown data (accounts by type)
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
# Add cache metadata header for Telegram Bot
headers = self._get_auth_headers(jwt_token)
headers['X-Include-Cache-Metadata'] = 'true'
response = await self.client.get(
f"/api/reports/dashboard/treasury-breakdown?company={company_id}",
headers=headers
)
return await self._handle_response(response)
except Exception as e:
logger.error(f"Failed to get treasury breakdown for company {company_id}: {e}")
return None
async def get_detailed_data(
self,
company_id: int,
jwt_token: str,
data_type: str
) -> Optional[Dict[str, Any]]:
"""
Get detailed data for clients or suppliers.
Args:
company_id: Company ID
jwt_token: JWT access token
data_type: Type of data ('clients' or 'suppliers')
Returns:
Dict with detailed data (list of clients/suppliers with balances)
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
# Add cache metadata header for Telegram Bot
headers = self._get_auth_headers(jwt_token)
headers['X-Include-Cache-Metadata'] = 'true'
response = await self.client.get(
f"/api/reports/dashboard/detailed-data?company={company_id}&data_type={data_type}",
headers=headers
)
return await self._handle_response(response)
except Exception as e:
logger.error(f"Failed to get detailed data ({data_type}) for company {company_id}: {e}")
return None
async def get_maturity_data(
self,
company_id: int,
jwt_token: str,
period: str = "all"
) -> Optional[Dict[str, Any]]:
"""
Get maturity data (in term/overdue breakdown).
Args:
company_id: Company ID
jwt_token: JWT access token
period: Period filter ('all', '30', '60', '90')
Returns:
Dict with maturity data (in_term, overdue, total)
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
# Add cache metadata header for Telegram Bot
headers = self._get_auth_headers(jwt_token)
headers['X-Include-Cache-Metadata'] = 'true'
response = await self.client.get(
f"/api/reports/dashboard/maturity?company={company_id}&period={period}",
headers=headers
)
return await self._handle_response(response)
except Exception as e:
logger.error(f"Failed to get maturity data for company {company_id}: {e}")
return None
async def get_performance_data(
self,
company_id: int,
jwt_token: str
) -> Optional[Dict[str, Any]]:
"""
Get performance data (incasari/plati totals).
Args:
company_id: Company ID
jwt_token: JWT access token
Returns:
Dict with performance data (incasari_total, plati_total, net)
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
# Add cache metadata header for Telegram Bot
headers = self._get_auth_headers(jwt_token)
headers['X-Include-Cache-Metadata'] = 'true'
response = await self.client.get(
f"/api/reports/dashboard/performance?company={company_id}",
headers=headers
)
return await self._handle_response(response)
except Exception as e:
logger.error(f"Failed to get performance data for company {company_id}: {e}")
return None
async def get_monthly_flows(
self,
company_id: int,
jwt_token: str,
months: int = 12
) -> Optional[Dict[str, Any]]:
"""
Get monthly cash flows data.
Args:
company_id: Company ID
jwt_token: JWT access token
months: Number of months to retrieve
Returns:
Dict with monthly flows (months, incasari, plati arrays)
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
# Add cache metadata header for Telegram Bot
headers = self._get_auth_headers(jwt_token)
headers['X-Include-Cache-Metadata'] = 'true'
response = await self.client.get(
f"/api/reports/dashboard/monthly-flows?company={company_id}&months={months}",
headers=headers
)
return await self._handle_response(response)
except Exception as e:
logger.error(f"Failed to get monthly flows for company {company_id}: {e}")
return None
async def get_trends(
self,
company_id: int,
jwt_token: str,
period: str = "12m"
) -> Optional[Dict[str, Any]]:
"""
Get trends data (12-month historical data for collections/payments).
Args:
company_id: Company ID
jwt_token: JWT access token
period: Period for trends (e.g., "12m", "6m", "ytd")
Returns:
Dict with trends data including periods, clienti_incasat, furnizori_achitat arrays
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
# Add cache metadata header for Telegram Bot
headers = self._get_auth_headers(jwt_token)
headers['X-Include-Cache-Metadata'] = 'true'
response = await self.client.get(
f"/api/reports/dashboard/trends?company={company_id}&period={period}",
headers=headers
)
return await self._handle_response(response)
except Exception as e:
logger.error(f"Failed to get trends for company {company_id}: {e}")
return None
# =========================================================================
# INVOICES ENDPOINTS
# =========================================================================
async def search_invoices(
self,
company_id: int,
jwt_token: str,
filters: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
"""
Search invoices with optional filters.
Args:
company_id: Company ID
jwt_token: JWT access token
filters: Optional filters dict:
- date_from: str (YYYY-MM-DD)
- date_to: str (YYYY-MM-DD)
- status: str (paid, unpaid, overdue)
- client_name: str
- partner_type: str (CLIENTI, FURNIZORI)
- partner_name: str
- series: str
- number: str
Returns:
List of invoice dicts
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
params = {"company": company_id}
if filters:
params.update(filters)
response = await self.client.get(
"/api/reports/invoices/",
params=params,
headers=self._get_auth_headers(jwt_token)
)
data = await self._handle_response(response)
if isinstance(data, dict) and 'invoices' in data:
invoice_list = data['invoices']
return invoice_list
elif isinstance(data, list):
return data
else:
logger.warning(f"📥 Unexpected response format: {type(data)}")
return []
except Exception as e:
logger.error(f"Failed to search invoices for company {company_id}: {e}")
return []
async def get_invoice_summary(
self,
company_id: int,
jwt_token: str,
partner_type: str = "CLIENTI"
) -> Optional[Dict[str, Any]]:
"""
Get invoice summary statistics.
Args:
company_id: Company ID
jwt_token: JWT access token
Returns:
Dict with summary (total_count, total_amount, paid, unpaid, etc.)
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
response = await self.client.get(
"/api/reports/invoices/summary",
params={
"company": str(company_id),
"partner_type": partner_type
},
headers=self._get_auth_headers(jwt_token)
)
return await self._handle_response(response)
except Exception as e:
logger.error(f"Failed to get invoice summary for company {company_id}: {e}")
return None
# =========================================================================
# TREASURY ENDPOINTS
# =========================================================================
async def get_treasury_data(
self,
company_id: int,
jwt_token: str
) -> Optional[Dict[str, Any]]:
"""
Get treasury/cash flow data for a company.
Args:
company_id: Company ID
jwt_token: JWT access token
Returns:
Dict with treasury data (cash_balance, incoming, outgoing, etc.)
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
response = await self.client.get(
"/api/reports/treasury/bank-cash-register",
params={
"company": str(company_id),
"page": 1,
"page_size": 1000
},
headers=self._get_auth_headers(jwt_token)
)
return await self._handle_response(response)
except Exception as e:
logger.error(f"Failed to get treasury data for company {company_id}: {e}")
return None
# =========================================================================
# EXPORT ENDPOINTS
# =========================================================================
async def export_report(
self,
jwt_token: str,
report_type: str,
company_id: int,
format: str = "xlsx",
filters: Optional[Dict[str, Any]] = None
) -> Optional[bytes]:
"""
Generate and export a report.
Args:
jwt_token: JWT access token
report_type: Type of report ('dashboard', 'invoices', 'treasury')
company_id: Company ID
format: Export format ('xlsx', 'csv', 'pdf')
filters: Optional filters for data
Returns:
bytes: File content, None if failed
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
request_data = {
"type": report_type,
"company_id": company_id,
"format": format,
"filters": filters or {}
}
response = await self.client.post(
"/api/telegram/export",
json=request_data,
headers=self._get_auth_headers(jwt_token)
)
response.raise_for_status()
return response.content
except Exception as e:
logger.error(f"Failed to export report: {e}")
return None
# =========================================================================
# CACHE MANAGEMENT
# =========================================================================
async def invalidate_cache(
self,
jwt_token: str,
company_id: Optional[int] = None,
cache_type: Optional[str] = None
) -> bool:
"""
Invalidate cache entries.
Args:
jwt_token: JWT access token
company_id: Optional company ID (None = all companies)
cache_type: Optional cache type (None = all types)
Returns:
bool: True if successful
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
request_data = {}
if company_id is not None:
request_data['company_id'] = company_id
if cache_type is not None:
request_data['cache_type'] = cache_type
response = await self.client.post(
"/api/reports/cache/invalidate",
json=request_data,
headers=self._get_auth_headers(jwt_token)
)
response.raise_for_status()
logger.info(f"Cache invalidated: company_id={company_id}, cache_type={cache_type}")
return True
except Exception as e:
logger.error(f"Failed to invalidate cache: {e}")
return False
async def toggle_user_cache(
self,
jwt_token: str,
enabled: bool
) -> bool:
"""
Toggle cache for current user.
Args:
jwt_token: JWT access token
enabled: True to enable cache, False to disable
Returns:
bool: True if successful
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
response = await self.client.post(
"/api/reports/cache/toggle-user",
json={"enabled": enabled},
headers=self._get_auth_headers(jwt_token)
)
response.raise_for_status()
logger.info(f"User cache toggled: enabled={enabled}")
return True
except Exception as e:
logger.error(f"Failed to toggle user cache: {e}")
return False
async def get_cache_stats(
self,
jwt_token: str
) -> Optional[Dict[str, Any]]:
"""
Get cache statistics including user-specific settings.
Args:
jwt_token: JWT access token
Returns:
Dict with cache stats including 'user_enabled' field
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
response = await self.client.get(
"/api/reports/cache/stats",
headers=self._get_auth_headers(jwt_token)
)
response.raise_for_status()
return response.json()
except Exception as e:
logger.error(f"Failed to get cache stats: {e}")
return None
# =========================================================================
# HEALTH CHECK
# =========================================================================
async def health_check(self) -> bool:
"""
Check if backend is healthy and reachable.
Returns:
bool: True if backend is healthy
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
response = await self.client.get("/api/telegram/health")
return response.status_code == 200
except Exception as e:
logger.error(f"Backend health check failed: {e}")
return False
# Singleton instance for global use
_backend_client_instance: Optional[BackendAPIClient] = None
def get_backend_client() -> BackendAPIClient:
"""
Get or create the singleton BackendAPIClient instance.
Returns:
BackendAPIClient: Singleton instance
"""
global _backend_client_instance
if _backend_client_instance is None:
_backend_client_instance = BackendAPIClient()
return _backend_client_instance
# Export main classes and functions
__all__ = [
'BackendAPIClient',
'get_backend_client',
'BACKEND_URL'
]

Some files were not shown because too many files have changed in this diff Show More