"""
Bulk upload API endpoints for batch receipt processing.

Endpoints:
- POST /upload - Submit multiple files for OCR processing in a single batch
- GET /batches/{batch_id}/status - Get batch status with optional long-polling
- POST /retry/{receipt_id} - Retry OCR processing for a single failed receipt
- POST /retry-batch/{batch_id} - Retry all failed receipts of a batch
- POST /cancel/{job_id} - Cancel a single pending/processing OCR job
- POST /cancel-batch/{batch_id} - Cancel all pending/processing jobs of a batch

Validation:
- Max 100 files per batch
- Max 10MB per file
- Allowed types: PDF, PNG, JPG

Duplicate Detection (US-007):
- SHA-256 hash calculated for each file
- Duplicate files (same hash + company_id) are rejected with 409 Conflict info
- Duplicates reported in error list, non-duplicates processed normally
"""

import asyncio
import hashlib
import logging
from collections import Counter
from datetime import datetime
from decimal import Decimal
from pathlib import Path
from typing import Annotated, List, Optional, Union

from fastapi import APIRouter, HTTPException, UploadFile, File, Depends, Query, Header
from sqlalchemy import select, func, and_
from sqlalchemy.ext.asyncio import AsyncSession

from backend.modules.data_entry.db.database import get_session
from backend.modules.data_entry.db.models import BatchUpload, BatchJob, BatchStatus, Receipt, ReceiptAttachment
from backend.modules.data_entry.schemas.bulk import (
    BulkUploadResponse,
    BulkUploadResponseWithDuplicates,
    BatchStatusResponse,
    BatchJobInfo,
    DuplicateFileInfo,
    RetryResponse,
    BatchRetryResponse,
    CancelJobResponse,
    CancelBatchResponse
)
from backend.modules.data_entry.services.ocr.job_queue import job_queue, OCRJobStatus
from backend.config import settings

# Auth integration
from shared.auth.dependencies import get_current_user
from shared.auth.models import CurrentUser

logger = logging.getLogger(__name__)

router = APIRouter()


# ============ Helper for selected company from header ============

async def get_selected_company(
    current_user: CurrentUser = Depends(get_current_user),
    x_selected_company: Annotated[Optional[str], Header()] = None
) -> int:
    """
    Get selected company from X-Selected-Company header.

    Validates that the user has access to the specified company.
    Falls back to user's first company if no header is provided.

    Returns:
        The selected company ID as an integer

    Raises:
        400: If the header is not an integer, or the user has no usable company
        403: If the user has no access to the requested company
    """
    if x_selected_company:
        try:
            company_id = int(x_selected_company)
        except ValueError:
            # The parse error adds nothing for the client; drop the chained context
            raise HTTPException(
                status_code=400,
                detail=f"Invalid company ID format: {x_selected_company}"
            ) from None

        # current_user.companies holds company IDs as strings
        if str(company_id) in current_user.companies:
            return company_id

        raise HTTPException(
            status_code=403,
            detail=f"Nu aveți acces la firma {company_id}"
        )

    # No header - use first company from user's list
    if current_user.companies:
        try:
            return int(current_user.companies[0])
        except (ValueError, IndexError):
            # Fall through to the 400 below, but leave a trace for operators
            logger.warning(
                f"[Auth] Could not use first company of user {current_user.username} as default"
            )

    raise HTTPException(
        status_code=400,
        detail="Nu aveți nicio firmă asignată"
    )


# Validation constants
MAX_FILES_PER_BATCH = 100
MAX_FILE_SIZE_BYTES = 10 * 1024 * 1024  # 10MB
ALLOWED_MIME_TYPES = {"image/jpeg", "image/png", "application/pdf"}


def compute_file_hash(content: bytes) -> str:
    """
    Compute SHA-256 hash of file content.

    Used for duplicate detection - same file content = same hash.

    Args:
        content: Raw file bytes

    Returns:
        Hexadecimal string of SHA-256 hash (64 characters)
    """
    return hashlib.sha256(content).hexdigest()


async def check_duplicate_hashes(
    session: AsyncSession,
    file_hashes: List[str],
    company_id: int
) -> dict[str, int]:
    """
    Check which file hashes already exist in the database for this company.

    Args:
        session: Database session
        file_hashes: List of SHA-256 hashes to check
        company_id: Company ID to scope the duplicate check

    Returns:
        Dict mapping hash -> existing receipt_id for duplicates found.
        If several receipts share a hash, the last row returned wins.
    """
    if not file_hashes:
        return {}

    # Query for existing receipts with these hashes for this company
    result = await session.execute(
        select(Receipt.file_hash, Receipt.id).where(
            and_(
                Receipt.file_hash.in_(file_hashes),
                Receipt.company_id == company_id
            )
        )
    )

    # result.all() is synchronous and returns (file_hash, receipt_id) rows
    return {file_hash: receipt_id for file_hash, receipt_id in result.all()}


async def _cancel_orphaned_jobs(job_ids: List[str]) -> None:
    """
    Best-effort cancellation of OCR jobs whose batch record was rolled back.

    The OCR job queue is a separate store from the SQLAlchemy session, so a
    session rollback does not remove jobs that were already queued. Only jobs
    still pending/processing are touched; failures are logged, never raised.
    """
    cancellable = (OCRJobStatus.pending.value, OCRJobStatus.processing.value)
    for job_id in job_ids:
        try:
            ocr_job = await job_queue.get_job(job_id)
            if ocr_job and ocr_job.status.value in cancellable:
                await job_queue.update_status(
                    job_id=job_id,
                    status=OCRJobStatus.cancelled,
                    error="Batch creation failed"
                )
        except Exception:
            logger.exception(f"[BulkUpload] Failed to cancel orphaned job {job_id}")


@router.post("/upload", response_model=Union[BulkUploadResponse, BulkUploadResponseWithDuplicates])
async def bulk_upload(
    files: List[UploadFile] = File(..., description="Multiple files to upload"),
    session: AsyncSession = Depends(get_session),
    current_user: CurrentUser = Depends(get_current_user),
    selected_company: int = Depends(get_selected_company)
):
    """
    Upload multiple files for batch OCR processing.

    Creates a batch record and queues all files as OCR jobs.
    Invalid files cause entire batch rejection (validation errors).
    Duplicate files are reported separately and skipped - non-duplicates are processed.

    Duplicate Detection (US-007):
    - SHA-256 hash calculated for each file before processing
    - Files with existing hash for same company are rejected with 409 info
    - Response includes duplicate details with existing_receipt_id

    Args:
        files: List of image/PDF files (max 100 files, max 10MB each)

    Returns:
        BulkUploadResponse with batch_id and list of job_ids
        BulkUploadResponseWithDuplicates if some files were duplicates

    Raises:
        400: If validation fails (too many files, file too large, invalid type)
        409: If ALL files are duplicates
        500: If job creation fails (already queued jobs are cancelled best-effort)
    """
    # Validate file count
    if not files:
        raise HTTPException(
            status_code=400,
            detail="No files provided"
        )

    if len(files) > MAX_FILES_PER_BATCH:
        raise HTTPException(
            status_code=400,
            detail=f"Too many files. Maximum {MAX_FILES_PER_BATCH} files per batch."
        )

    # Pre-validate all files before creating any jobs (atomic check)
    invalid_files = []
    file_contents = []

    for file in files:
        # Check MIME type
        if file.content_type not in ALLOWED_MIME_TYPES:
            invalid_files.append(f"{file.filename}: Invalid type ({file.content_type})")
            continue

        # Read content and check size
        content = await file.read()
        if len(content) > MAX_FILE_SIZE_BYTES:
            invalid_files.append(f"{file.filename}: File too large ({len(content) // (1024*1024)}MB > 10MB)")
            continue

        # Compute SHA-256 hash for duplicate detection (US-007)
        file_hash = compute_file_hash(content)

        # Store for later processing
        file_contents.append({
            "filename": file.filename,
            "content": content,
            "mime_type": file.content_type,
            "file_hash": file_hash
        })

    # If any files are invalid, reject the entire batch
    if invalid_files:
        raise HTTPException(
            status_code=400,
            detail={
                "message": f"Validation failed for {len(invalid_files)} file(s)",
                "invalid_files": invalid_files
            }
        )

    # Check for duplicates BEFORE creating batch (US-007)
    # NOTE(review): only hashes already stored in the database are detected here;
    # two identical files inside the same upload are both queued. Confirm whether
    # in-batch duplicates should be rejected too (DuplicateFileInfo would need to
    # accept a missing existing_receipt_id).
    all_hashes = [f["file_hash"] for f in file_contents]
    existing_duplicates = await check_duplicate_hashes(session, all_hashes, selected_company)

    # Separate duplicate files from processable files
    duplicate_files: List[DuplicateFileInfo] = []
    processable_files = []

    for file_data in file_contents:
        if file_data["file_hash"] in existing_duplicates:
            existing_receipt_id = existing_duplicates[file_data["file_hash"]]
            duplicate_files.append(DuplicateFileInfo(
                filename=file_data["filename"],
                error="duplicate",
                existing_receipt_id=existing_receipt_id,
                message=f"Fișier duplicat - există deja ca bon #{existing_receipt_id}"
            ))
            logger.info(
                f"[BulkUpload] Duplicate detected: {file_data['filename']} "
                f"(hash={file_data['file_hash'][:16]}...) matches receipt #{existing_receipt_id}"
            )
        else:
            processable_files.append(file_data)

    # If ALL files are duplicates, return 409 Conflict
    if len(duplicate_files) == len(file_contents):
        raise HTTPException(
            status_code=409,
            detail={
                "error": "all_duplicates",
                "message": f"Toate cele {len(duplicate_files)} fișiere sunt duplicate",
                "duplicates": [d.model_dump() for d in duplicate_files]
            }
        )

    # If no processable files remain after filtering (shouldn't happen but be safe)
    if not processable_files:
        raise HTTPException(
            status_code=409,
            detail={
                "error": "no_files_to_process",
                "message": "Nu există fișiere de procesat",
                "duplicates": [d.model_dump() for d in duplicate_files]
            }
        )

    # Create batch record with company_id for auto-save
    batch = BatchUpload(
        user_id=current_user.username,
        company_id=selected_company,
        status=BatchStatus.PENDING,
        total_files=len(processable_files)  # Only count processable files
    )
    session.add(batch)
    await session.flush()  # Get batch.id before creating jobs

    # Create OCR jobs for processable files only
    job_ids = []
    batch_jobs = []

    try:
        for file_data in processable_files:
            # Create OCR job using existing job_queue
            # Pass batch_id and file_hash for tracking
            job = await job_queue.create_job(
                file_bytes=file_data["content"],
                mime_type=file_data["mime_type"],
                engine="doctr_plus",  # Default engine for bulk
                username=current_user.username,
                original_filename=file_data["filename"],
                batch_id=batch.id,  # Link job to batch for auto-save integration
                file_hash=file_data["file_hash"]  # Pass hash for storage in receipt
            )

            job_ids.append(job.id)

            # Create batch_job link
            batch_jobs.append(BatchJob(
                batch_id=batch.id,
                job_id=job.id,
                filename=file_data["filename"]
            ))

        # Add all batch_job records
        session.add_all(batch_jobs)

        # Commit everything atomically
        await session.commit()

        duplicates_note = (
            f", {len(duplicate_files)} duplicates skipped" if duplicate_files else ""
        )
        logger.info(
            f"[BulkUpload] Created batch {batch.id} with {len(job_ids)} jobs "
            f"for user {current_user.username}"
            f"{duplicates_note}"
        )

        # Return response with duplicate info if any duplicates were found
        if duplicate_files:
            return BulkUploadResponseWithDuplicates(
                batch_id=batch.id,
                job_ids=job_ids,
                total_files=len(file_contents),
                processed_files=len(job_ids),
                duplicate_files=len(duplicate_files),
                duplicates=duplicate_files,
                message=f"{len(job_ids)} fișier(e) în procesare, {len(duplicate_files)} duplicate ignorate"
            )

        return BulkUploadResponse(
            batch_id=batch.id,
            job_ids=job_ids,
            total_files=len(job_ids),
            message=f"{len(job_ids)} files queued for processing"
        )

    except Exception as e:
        # Rollback on any error
        await session.rollback()
        # The rollback only covers the SQL session; jobs already queued would
        # otherwise keep running against a batch that no longer exists.
        await _cancel_orphaned_jobs(job_ids)
        logger.exception(f"[BulkUpload] Failed to create batch: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to create batch: {str(e)}"
        ) from e


# Long-polling constants
MAX_WAIT_SECONDS = 30
POLL_INTERVAL_SECONDS = 0.5


def _user_can_access_batch(batch: BatchUpload, current_user: CurrentUser) -> bool:
    """
    Tell whether the user may see or act on the given batch.

    Batches are created with the uploader's selected company, so membership in
    that company decides access. A batch without a company is only accessible
    to the user recorded as its uploader.
    """
    if batch.company_id is None:
        return batch.user_id == current_user.username
    # current_user.companies holds company IDs as strings
    return str(batch.company_id) in (current_user.companies or ())


async def _get_batch_status_snapshot(
    batch_id: int,
    session: AsyncSession
) -> Optional[dict]:
    """
    Get current batch status snapshot.

    Returns dict with status counts and jobs list, or None if batch not found.

    The ORM queries use populate_existing so that repeated calls on the same
    session (long-polling) refresh already loaded rows instead of returning the
    attribute values cached from the first call.
    """
    # Get batch record
    batch_result = await session.execute(
        select(BatchUpload)
        .where(BatchUpload.id == batch_id)
        .execution_options(populate_existing=True)
    )
    batch = batch_result.scalar_one_or_none()

    if not batch:
        return None

    # Get all batch_jobs for this batch
    batch_jobs_result = await session.execute(
        select(BatchJob)
        .where(BatchJob.batch_id == batch_id)
        .execution_options(populate_existing=True)
    )
    batch_jobs = batch_jobs_result.scalars().all()

    if not batch_jobs:
        return {
            "batch": batch,
            "pending_count": 0,
            "processing_count": 0,
            "completed_count": 0,
            "failed_count": 0,
            "jobs": [],
            "total_amount": None
        }

    # Get job statuses and error_messages from OCR job queue (SQLite)
    job_statuses = {}
    job_errors = {}

    for bj in batch_jobs:
        job = await job_queue.get_job(bj.job_id)
        if job:
            job_statuses[bj.job_id] = job.status.value
            job_errors[bj.job_id] = job.error_message
        else:
            # Job not found in queue - treat as failed
            job_statuses[bj.job_id] = "failed"
            job_errors[bj.job_id] = "Job not found in queue"

    # Count by status in a single pass; Counter yields 0 for absent statuses
    status_counts = Counter(job_statuses.values())
    pending_count = status_counts["pending"]
    processing_count = status_counts["processing"]
    completed_count = status_counts["completed"]
    failed_count = status_counts["failed"]

    # Build jobs list with status info
    jobs_info = [
        {
            "job_id": bj.job_id,
            "filename": bj.filename,
            "status": job_statuses.get(bj.job_id, "failed"),
            "receipt_id": bj.receipt_id,
            "error_message": job_errors.get(bj.job_id)
        }
        for bj in batch_jobs
    ]

    # Calculate total_amount from completed receipts
    total_amount = None
    receipt_ids = [bj.receipt_id for bj in batch_jobs if bj.receipt_id is not None]

    if receipt_ids:
        amount_result = await session.execute(
            select(func.sum(Receipt.amount)).where(Receipt.id.in_(receipt_ids))
        )
        total_sum = amount_result.scalar()
        if total_sum is not None:
            total_amount = float(total_sum)

    return {
        "batch": batch,
        "pending_count": pending_count,
        "processing_count": processing_count,
        "completed_count": completed_count,
        "failed_count": failed_count,
        "jobs": jobs_info,
        "total_amount": total_amount
    }


def _compute_batch_overall_status(pending: int, processing: int, completed: int, failed: int, total: int) -> str:
    """
    Compute overall batch status from job counts.

    Returns FAILED only when every file failed, COMPLETED when nothing is
    pending/processing anymore, PROCESSING once any job has started or
    finished, and PENDING otherwise.
    """
    if pending + processing == 0:
        # All jobs finished
        if failed == total:
            return BatchStatus.FAILED.value
        return BatchStatus.COMPLETED.value
    elif processing > 0 or completed > 0 or failed > 0:
        return BatchStatus.PROCESSING.value
    else:
        return BatchStatus.PENDING.value


@router.get("/batches/{batch_id}/status", response_model=BatchStatusResponse)
async def get_batch_status(
    batch_id: int,
    wait: Optional[int] = Query(
        default=None,
        ge=0,
        le=MAX_WAIT_SECONDS,
        description="Long-polling wait time in seconds (max 30)"
    ),
    session: AsyncSession = Depends(get_session),
    current_user: CurrentUser = Depends(get_current_user)
):
    """
    Get batch processing status with optional long-polling.

    Returns aggregated status counts and individual job statuses.

    When `wait` parameter is provided, the endpoint will poll until:
    - Status changes from initial snapshot
    - All jobs complete (pending + processing = 0)
    - Timeout is reached

    Args:
        batch_id: Batch ID to query
        wait: Optional wait time in seconds for long-polling (0-30)

    Returns:
        BatchStatusResponse with status counts and job details

    Raises:
        404: If batch not found, or it belongs to a company the user cannot access
    """
    count_keys = ("pending_count", "processing_count", "completed_count", "failed_count")

    # Get initial snapshot
    snapshot = await _get_batch_status_snapshot(batch_id, session)

    # Answer 404 for foreign batches too, so batch IDs cannot be probed
    if snapshot is None or not _user_can_access_batch(snapshot["batch"], current_user):
        raise HTTPException(
            status_code=404,
            detail=f"Batch {batch_id} not found"
        )

    # Only wait if long-polling was requested and jobs are still in progress
    if wait and snapshot["pending_count"] + snapshot["processing_count"] > 0:
        initial_counts = tuple(snapshot[key] for key in count_keys)

        # Use a monotonic deadline: the time spent refreshing the snapshot
        # counts towards the wait budget, not only the sleep intervals.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + wait

        while True:
            remaining = deadline - loop.time()
            if remaining <= 0:
                break

            await asyncio.sleep(min(POLL_INTERVAL_SECONDS, remaining))

            # Refresh snapshot
            snapshot = await _get_batch_status_snapshot(batch_id, session)
            if snapshot is None:
                # Batch deleted during polling (edge case)
                raise HTTPException(status_code=404, detail=f"Batch {batch_id} not found")

            # Any change ends the wait. This also covers "all jobs finished":
            # the initial snapshot had jobs in progress, so reaching zero
            # pending/processing necessarily changes the counts.
            if tuple(snapshot[key] for key in count_keys) != initial_counts:
                break

    # Build response
    batch = snapshot["batch"]
    total_files = batch.total_files

    overall_status = _compute_batch_overall_status(
        snapshot["pending_count"],
        snapshot["processing_count"],
        snapshot["completed_count"],
        snapshot["failed_count"],
        total_files
    )

    jobs = [
        BatchJobInfo(
            job_id=j["job_id"],
            filename=j["filename"],
            status=j["status"],
            receipt_id=j["receipt_id"],
            error_message=j.get("error_message")
        )
        for j in snapshot["jobs"]
    ]

    return BatchStatusResponse(
        batch_id=batch.id,
        status=overall_status,
        total_files=total_files,
        pending_count=snapshot["pending_count"],
        processing_count=snapshot["processing_count"],
        completed_count=snapshot["completed_count"],
        failed_count=snapshot["failed_count"],
        jobs=jobs,
        total_amount=snapshot["total_amount"],
        created_at=batch.created_at
    )


# ============ Retry Endpoints (US-006) ============

async def _retry_single_receipt(
    session: AsyncSession,
    receipt: Receipt,
    username: str
) -> tuple[bool, Optional[str], Optional[str]]:
    """
    Retry processing for a single receipt.

    Finds the original file from attachments, resets processing status,
    and creates a new OCR job. Changes are flushed, not committed - the
    caller owns the commit.

    Args:
        session: Database session
        receipt: Receipt to retry
        username: Username for the new OCR job

    Returns:
        Tuple of (success, job_id, error_message)
    """
    # Get the first attachment to find the source file
    attachments_result = await session.execute(
        select(ReceiptAttachment)
        .where(ReceiptAttachment.receipt_id == receipt.id)
        .limit(1)
    )
    attachment = attachments_result.scalar_one_or_none()

    if not attachment:
        return False, None, "Bonul nu are fișier atașat"

    # Construct full path to attachment file
    file_path = settings.data_entry_upload_path_resolved / attachment.file_path

    # Read in a worker thread so disk I/O does not block the event loop, and
    # let the read itself report a missing file instead of checking first.
    try:
        file_bytes = await asyncio.to_thread(file_path.read_bytes)
    except FileNotFoundError:
        return False, None, "Fișierul original nu mai este disponibil"
    except OSError as e:
        logger.error(f"[Retry] Failed to read file {file_path}: {e}")
        return False, None, f"Eroare la citirea fișierului: {str(e)}"

    # Create new OCR job
    try:
        job = await job_queue.create_job(
            file_bytes=file_bytes,
            mime_type=attachment.mime_type,
            engine="doctr_plus",
            username=username,
            original_filename=attachment.filename,
            batch_id=None,  # No batch for retry - direct processing
            file_hash=receipt.file_hash
        )

        # Reset receipt processing status
        receipt.processing_status = "pending"
        receipt.processing_error = None
        receipt.processing_started_at = datetime.utcnow()
        receipt.processing_completed_at = None

        await session.flush()

        logger.info(f"[Retry] Receipt {receipt.id} requeued as job {job.id}")
        return True, job.id, None

    except Exception as e:
        logger.error(f"[Retry] Failed to create job for receipt {receipt.id}: {e}")
        return False, None, f"Eroare la crearea job-ului OCR: {str(e)}"


@router.post("/retry/{receipt_id}", response_model=RetryResponse)
async def retry_receipt(
    receipt_id: int,
    session: AsyncSession = Depends(get_session),
    current_user: CurrentUser = Depends(get_current_user),
    selected_company: int = Depends(get_selected_company)
):
    """
    Retry OCR processing for a single failed receipt.

    Resets the receipt's processing_status to 'pending' and creates a new
    OCR job using the original attachment file.

    Args:
        receipt_id: ID of the receipt to retry

    Returns:
        RetryResponse with success status and new job ID

    Raises:
        404: If receipt not found for the selected company
        400: If receipt is not in 'failed' status
        400: If original file is not available
    """
    # Get the receipt, scoped to the selected company
    result = await session.execute(
        select(Receipt).where(
            and_(
                Receipt.id == receipt_id,
                Receipt.company_id == selected_company
            )
        )
    )
    receipt = result.scalar_one_or_none()

    if not receipt:
        raise HTTPException(
            status_code=404,
            detail=f"Bonul #{receipt_id} nu a fost găsit"
        )

    # Verify receipt is in failed status
    if receipt.processing_status != "failed":
        raise HTTPException(
            status_code=400,
            detail=f"Bonul nu este în stare de eroare (status actual: {receipt.processing_status})"
        )

    # Attempt retry
    success, job_id, error = await _retry_single_receipt(
        session, receipt, current_user.username
    )

    if not success:
        raise HTTPException(
            status_code=400,
            detail=error or "Eroare necunoscută la reîncărcare"
        )

    await session.commit()

    return RetryResponse(
        success=True,
        receipt_id=receipt_id,
        job_id=job_id,
        message="Bon reîncarcat în procesare"
    )


@router.post("/retry-batch/{batch_id}", response_model=BatchRetryResponse)
async def retry_batch_failed(
    batch_id: str,
    session: AsyncSession = Depends(get_session),
    current_user: CurrentUser = Depends(get_current_user),
    selected_company: int = Depends(get_selected_company)
):
    """
    Retry all failed receipts in a batch.

    Finds all receipts with batch_id matching and processing_status='failed',
    then attempts to retry each one.

    Args:
        batch_id: Batch ID (UUID string from receipt.batch_id)

    Returns:
        BatchRetryResponse with counts of successful and failed retries

    Raises:
        404: If no failed receipts found for batch
    """
    # Find all failed receipts in this batch, scoped to the selected company
    result = await session.execute(
        select(Receipt).where(
            and_(
                Receipt.batch_id == batch_id,
                Receipt.company_id == selected_company,
                Receipt.processing_status == "failed"
            )
        )
    )
    failed_receipts = result.scalars().all()

    if not failed_receipts:
        raise HTTPException(
            status_code=404,
            detail=f"Nu există bonuri cu erori în batch-ul {batch_id}"
        )

    # Retry each receipt
    retried_count = 0
    failed_count = 0
    errors = []

    for receipt in failed_receipts:
        success, _, error = await _retry_single_receipt(
            session, receipt, current_user.username
        )

        if success:
            retried_count += 1
        else:
            failed_count += 1
            errors.append(f"Bon #{receipt.id}: {error}")

    await session.commit()

    return BatchRetryResponse(
        success=retried_count > 0,
        batch_id=batch_id,
        retried_count=retried_count,
        failed_count=failed_count,
        errors=errors,
        message=f"{retried_count} bonuri reîncarcate în procesare" +
                (f", {failed_count} erori" if failed_count > 0 else "")
    )


# ============ Cancel Endpoints (US-014) ============

@router.post("/cancel/{job_id}", response_model=CancelJobResponse)
async def cancel_job(
    job_id: str,
    session: AsyncSession = Depends(get_session),
    current_user: CurrentUser = Depends(get_current_user)
):
    """
    Cancel a single OCR processing job.

    Only jobs with status 'pending' or 'processing' can be cancelled.
    Jobs with status 'completed' or 'failed' cannot be cancelled.

    Important: If a receipt has already been created from this job,
    it will NOT be deleted - receipts are preserved for audit purposes.

    Args:
        job_id: The UUID of the OCR job to cancel

    Returns:
        CancelJobResponse with cancellation details

    Raises:
        404: If job not found in batch_jobs table, or its batch is not
             accessible to the current user
        400: If job has already completed, failed or been cancelled
        500: If the job queue refuses the status update
    """
    # Find the job in batch_jobs table
    batch_job_result = await session.execute(
        select(BatchJob).where(BatchJob.job_id == job_id)
    )
    batch_job = batch_job_result.scalar_one_or_none()

    if not batch_job:
        raise HTTPException(
            status_code=404,
            detail=f"Job {job_id} nu a fost găsit"
        )

    # Jobs carry no tenant info of their own; authorize through the owning
    # batch and answer 404 so foreign job IDs cannot be probed.
    batch_result = await session.execute(
        select(BatchUpload).where(BatchUpload.id == batch_job.batch_id)
    )
    batch = batch_result.scalar_one_or_none()

    if batch is None or not _user_can_access_batch(batch, current_user):
        raise HTTPException(
            status_code=404,
            detail=f"Job {job_id} nu a fost găsit"
        )

    # Get the OCR job from job_queue to check current status
    ocr_job = await job_queue.get_job(job_id)

    if not ocr_job:
        raise HTTPException(
            status_code=404,
            detail=f"Job {job_id} nu există în coada de procesare"
        )

    # Check if job can be cancelled
    current_status = ocr_job.status.value

    if current_status == OCRJobStatus.completed.value:
        raise HTTPException(
            status_code=400,
            detail="Job-ul a fost deja procesat cu succes. Nu poate fi anulat."
        )

    if current_status == OCRJobStatus.failed.value:
        raise HTTPException(
            status_code=400,
            detail="Job-ul a eșuat deja. Folosiți opțiunea de reîncercare în loc de anulare."
        )

    if current_status == OCRJobStatus.cancelled.value:
        raise HTTPException(
            status_code=400,
            detail="Job-ul a fost deja anulat."
        )

    # Update job status to cancelled in job_queue (SQLite)
    cancelled_at = datetime.utcnow()
    success = await job_queue.update_status(
        job_id=job_id,
        status=OCRJobStatus.cancelled,
        error="Cancelled by user"
    )

    if not success:
        raise HTTPException(
            status_code=500,
            detail="Eroare la anularea job-ului"
        )

    logger.info(
        f"[CancelJob] Job {job_id} cancelled by {current_user.username} "
        f"(previous status: {current_status})"
    )

    return CancelJobResponse(
        success=True,
        job_id=job_id,
        cancelled_at=cancelled_at,
        message="Job anulat cu succes"
    )


@router.post("/cancel-batch/{batch_id}", response_model=CancelBatchResponse)
async def cancel_batch(
    batch_id: int,
    session: AsyncSession = Depends(get_session),
    current_user: CurrentUser = Depends(get_current_user)
):
    """
    Cancel all pending/processing jobs in a batch.

    Finds all jobs with status 'pending' or 'processing' in the specified batch
    and marks them as 'cancelled'. Jobs with status 'completed' or 'failed'
    are not affected.

    Important: Receipts that have already been created from completed jobs
    will NOT be deleted - they are preserved for audit purposes.

    Args:
        batch_id: The batch ID to cancel

    Returns:
        CancelBatchResponse with counts of cancelled and skipped jobs

    Raises:
        404: If batch not found, not accessible to the current user, or no
             jobs exist for batch
    """
    # Verify batch exists and belongs to a company the user can access;
    # both cases answer 404 so foreign batch IDs cannot be probed.
    batch_result = await session.execute(
        select(BatchUpload).where(BatchUpload.id == batch_id)
    )
    batch = batch_result.scalar_one_or_none()

    if batch is None or not _user_can_access_batch(batch, current_user):
        raise HTTPException(
            status_code=404,
            detail=f"Batch {batch_id} nu a fost găsit"
        )

    # Get all batch_jobs for this batch
    batch_jobs_result = await session.execute(
        select(BatchJob).where(BatchJob.batch_id == batch_id)
    )
    batch_jobs = batch_jobs_result.scalars().all()

    if not batch_jobs:
        raise HTTPException(
            status_code=404,
            detail=f"Nu există job-uri în batch-ul {batch_id}"
        )

    # Process each job - cancel pending/processing, skip completed/failed
    cancellable = (OCRJobStatus.pending.value, OCRJobStatus.processing.value)
    cancelled_count = 0
    skipped_count = 0

    for batch_job in batch_jobs:
        # Get current job status from OCR job queue
        ocr_job = await job_queue.get_job(batch_job.job_id)

        if not ocr_job:
            # Job not found in queue - treat as skipped
            skipped_count += 1
            continue

        if ocr_job.status.value not in cancellable:
            # Job is completed, failed, or already cancelled - skip it
            skipped_count += 1
            continue

        success = await job_queue.update_status(
            job_id=batch_job.job_id,
            status=OCRJobStatus.cancelled,
            error="Cancelled by user (batch cancel)"
        )

        if success:
            cancelled_count += 1
            logger.debug(f"[CancelBatch] Cancelled job {batch_job.job_id}")
        else:
            # Failed to cancel - count as skipped
            skipped_count += 1
            logger.warning(
                f"[CancelBatch] Failed to cancel job {batch_job.job_id}"
            )

    logger.info(
        f"[CancelBatch] Batch {batch_id} cancelled by {current_user.username}: "
        f"{cancelled_count} cancelled, {skipped_count} skipped"
    )

    # Build message
    if cancelled_count == 0:
        message = f"Nu există job-uri de anulat în batch-ul {batch_id}"
    elif skipped_count == 0:
        message = f"{cancelled_count} job-uri anulate"
    else:
        message = f"{cancelled_count} job-uri anulate, {skipped_count} ignorate (deja procesate)"

    return CancelBatchResponse(
        success=cancelled_count > 0,
        batch_id=batch_id,
        cancelled_count=cancelled_count,
        skipped_count=skipped_count,
        message=message
    )