"""CRUD operations for receipts.""" import json from datetime import datetime, date from decimal import Decimal from typing import Optional, List, Tuple, Any from sqlalchemy import select, func, or_ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from backend.modules.data_entry.db.models.receipt import Receipt, ReceiptStatus from backend.modules.data_entry.schemas.receipt import ReceiptCreate, ReceiptUpdate, ReceiptFilter def _serialize_tva_breakdown(tva_breakdown: Optional[List[Any]]) -> Optional[str]: """Serialize TVA breakdown list to JSON string for SQLite storage.""" if tva_breakdown is None: return None # Convert Decimal to float for JSON serialization serializable = [] for entry in tva_breakdown: if hasattr(entry, 'model_dump'): # Pydantic model item = entry.model_dump() elif isinstance(entry, dict): item = entry.copy() else: item = dict(entry) # Convert Decimal to float if 'amount' in item and isinstance(item['amount'], Decimal): item['amount'] = float(item['amount']) serializable.append(item) return json.dumps(serializable) def _serialize_payment_methods(payment_methods: Optional[List[Any]]) -> Optional[str]: """Serialize payment methods list to JSON string for SQLite storage.""" if payment_methods is None: return None serializable = [] for pm in payment_methods: if hasattr(pm, 'model_dump'): item = pm.model_dump() elif isinstance(pm, dict): item = pm.copy() else: item = dict(pm) # Convert Decimal to float for JSON if 'amount' in item: if hasattr(item['amount'], '__float__'): item['amount'] = float(item['amount']) serializable.append(item) return json.dumps(serializable) class ReceiptCRUD: """CRUD operations for Receipt model.""" @staticmethod async def create( session: AsyncSession, data: ReceiptCreate, created_by: str, ) -> Receipt: """Create a new receipt.""" # Get data as dict and serialize tva_breakdown and payment_methods to JSON string receipt_data = data.model_dump() receipt_data['tva_breakdown'] = _serialize_tva_breakdown(receipt_data.get('tva_breakdown')) receipt_data['payment_methods'] = _serialize_payment_methods(receipt_data.get('payment_methods')) receipt = Receipt( **receipt_data, created_by=created_by, status=ReceiptStatus.DRAFT, ) session.add(receipt) await session.commit() await session.refresh(receipt) # Reload with relationships to avoid lazy loading issues with async return await ReceiptCRUD.get_by_id(session, receipt.id, include_relations=True) @staticmethod async def get_by_id( session: AsyncSession, receipt_id: int, include_relations: bool = True, ) -> Optional[Receipt]: """Get receipt by ID, optionally with relationships.""" query = select(Receipt).where(Receipt.id == receipt_id) if include_relations: query = query.options( selectinload(Receipt.attachments), selectinload(Receipt.entries), ) result = await session.execute(query) return result.scalar_one_or_none() @staticmethod async def get_list( session: AsyncSession, filters: ReceiptFilter, ) -> Tuple[List[Receipt], int]: """Get paginated list of receipts with filters.""" # Base query query = select(Receipt).options( selectinload(Receipt.attachments), selectinload(Receipt.entries), ) # Apply filters if filters.status: query = query.where(Receipt.status == filters.status) if filters.direction: query = query.where(Receipt.direction == filters.direction) if filters.company_id: query = query.where(Receipt.company_id == filters.company_id) if filters.created_by: query = query.where(Receipt.created_by == filters.created_by) if filters.date_from: query = query.where(Receipt.receipt_date >= filters.date_from) if filters.date_to: query = query.where(Receipt.receipt_date <= filters.date_to) if filters.search: search_term = f"%{filters.search}%" query = query.where( or_( Receipt.description.ilike(search_term), Receipt.partner_name.ilike(search_term), Receipt.receipt_number.ilike(search_term), ) ) # Count total count_query = select(func.count()).select_from(query.subquery()) total_result = await session.execute(count_query) total = total_result.scalar() or 0 # Apply pagination and ordering query = query.order_by(Receipt.created_at.desc()) offset = (filters.page - 1) * filters.page_size query = query.offset(offset).limit(filters.page_size) # Execute result = await session.execute(query) receipts = result.scalars().all() return list(receipts), total @staticmethod async def get_pending_review( session: AsyncSession, company_id: Optional[int] = None, ) -> List[Receipt]: """Get all receipts pending review.""" query = select(Receipt).where( Receipt.status == ReceiptStatus.PENDING_REVIEW ).options( selectinload(Receipt.attachments), selectinload(Receipt.entries), ) if company_id: query = query.where(Receipt.company_id == company_id) query = query.order_by(Receipt.submitted_at.asc()) result = await session.execute(query) return list(result.scalars().all()) @staticmethod async def update( session: AsyncSession, receipt: Receipt, data: ReceiptUpdate, ) -> Receipt: """Update receipt fields.""" update_data = data.model_dump(exclude_unset=True) # Recalculate tva_total from tva_breakdown if breakdown is being updated if 'tva_breakdown' in update_data and update_data['tva_breakdown']: tva_total = sum( float(entry.get('amount', 0) if isinstance(entry, dict) else getattr(entry, 'amount', 0)) for entry in update_data['tva_breakdown'] ) update_data['tva_total'] = round(tva_total, 2) # Serialize tva_breakdown and payment_methods to JSON string if present if 'tva_breakdown' in update_data: update_data['tva_breakdown'] = _serialize_tva_breakdown(update_data['tva_breakdown']) if 'payment_methods' in update_data: update_data['payment_methods'] = _serialize_payment_methods(update_data['payment_methods']) for field, value in update_data.items(): setattr(receipt, field, value) receipt.updated_at = datetime.utcnow() session.add(receipt) await session.commit() await session.refresh(receipt) # Reload with relationships to avoid lazy loading issues with async return await ReceiptCRUD.get_by_id(session, receipt.id, include_relations=True) @staticmethod async def update_status( session: AsyncSession, receipt: Receipt, new_status: ReceiptStatus, reviewed_by: Optional[str] = None, rejection_reason: Optional[str] = None, ) -> Receipt: """Update receipt workflow status.""" receipt.status = new_status receipt.updated_at = datetime.utcnow() if new_status == ReceiptStatus.PENDING_REVIEW: receipt.submitted_at = datetime.utcnow() if new_status in [ReceiptStatus.APPROVED, ReceiptStatus.REJECTED]: receipt.reviewed_by = reviewed_by receipt.reviewed_at = datetime.utcnow() if new_status == ReceiptStatus.REJECTED: receipt.rejection_reason = rejection_reason if new_status == ReceiptStatus.DRAFT: # Reset review fields when moving back to draft receipt.rejection_reason = None session.add(receipt) await session.commit() await session.refresh(receipt) # Reload with relationships to avoid lazy loading issues with async return await ReceiptCRUD.get_by_id(session, receipt.id, include_relations=True) @staticmethod async def delete(session: AsyncSession, receipt: Receipt) -> bool: """Delete a receipt (cascade deletes attachments and entries).""" await session.delete(receipt) await session.commit() return True @staticmethod async def can_edit(receipt: Receipt, username: str) -> bool: """Check if user can edit receipt.""" # DRAFT and REJECTED receipts can be edited (to fix and resubmit) if receipt.status not in [ReceiptStatus.DRAFT, ReceiptStatus.REJECTED]: return False # Only creator can edit their own receipts return receipt.created_by == username @staticmethod async def can_delete(receipt: Receipt, username: str) -> bool: """Check if user can delete receipt.""" # Only DRAFT receipts can be deleted if receipt.status != ReceiptStatus.DRAFT: return False # Only creator can delete their own drafts return receipt.created_by == username @staticmethod async def can_submit(receipt: Receipt, username: str) -> bool: """Check if user can submit receipt for review.""" # Only DRAFT or REJECTED receipts can be submitted if receipt.status not in [ReceiptStatus.DRAFT, ReceiptStatus.REJECTED]: return False # Only creator can submit their own receipts return receipt.created_by == username @staticmethod async def get_stats( session: AsyncSession, company_id: int, created_by: Optional[str] = None, ) -> dict: """Get receipt statistics.""" base_query = select( Receipt.status, func.count(Receipt.id).label("count"), func.sum(Receipt.amount).label("total_amount"), ).where( Receipt.company_id == company_id ) if created_by: base_query = base_query.where(Receipt.created_by == created_by) query = base_query.group_by(Receipt.status) result = await session.execute(query) rows = result.all() stats = { "draft": {"count": 0, "amount": 0}, "pending_review": {"count": 0, "amount": 0}, "approved": {"count": 0, "amount": 0}, "rejected": {"count": 0, "amount": 0}, "synced": {"count": 0, "amount": 0}, "total": {"count": 0, "amount": 0}, } for row in rows: status_key = row.status.value stats[status_key] = { "count": row.count, "amount": float(row.total_amount or 0), } stats["total"]["count"] += row.count stats["total"]["amount"] += float(row.total_amount or 0) return stats