"""CRUD operations for accounting entries.""" from datetime import datetime from typing import Optional, List from sqlalchemy import select, delete from sqlalchemy.ext.asyncio import AsyncSession from backend.modules.data_entry.db.models.accounting_entry import AccountingEntry, EntryType from backend.modules.data_entry.schemas.receipt import AccountingEntryCreate, AccountingEntryUpdate class AccountingEntryCRUD: """CRUD operations for AccountingEntry model.""" @staticmethod async def create( session: AsyncSession, receipt_id: int, data: AccountingEntryCreate, sort_order: int = 0, is_auto_generated: bool = True, ) -> AccountingEntry: """Create a new accounting entry.""" entry = AccountingEntry( receipt_id=receipt_id, entry_type=data.entry_type, account_code=data.account_code, account_name=data.account_name, amount=data.amount, partner_id=data.partner_id, cost_center_id=data.cost_center_id, is_auto_generated=is_auto_generated, sort_order=sort_order, ) session.add(entry) await session.commit() await session.refresh(entry) return entry @staticmethod async def create_bulk( session: AsyncSession, receipt_id: int, entries: List[AccountingEntryCreate], is_auto_generated: bool = True, ) -> List[AccountingEntry]: """Create multiple accounting entries at once.""" created_entries = [] for idx, entry_data in enumerate(entries): entry = AccountingEntry( receipt_id=receipt_id, entry_type=entry_data.entry_type, account_code=entry_data.account_code, account_name=entry_data.account_name, amount=entry_data.amount, partner_id=entry_data.partner_id, cost_center_id=entry_data.cost_center_id, is_auto_generated=is_auto_generated, sort_order=idx, ) session.add(entry) created_entries.append(entry) await session.commit() for entry in created_entries: await session.refresh(entry) return created_entries @staticmethod async def get_by_id( session: AsyncSession, entry_id: int, ) -> Optional[AccountingEntry]: """Get accounting entry by ID.""" query = select(AccountingEntry).where(AccountingEntry.id == entry_id) result = await session.execute(query) return result.scalar_one_or_none() @staticmethod async def get_by_receipt_id( session: AsyncSession, receipt_id: int, ) -> List[AccountingEntry]: """Get all accounting entries for a receipt.""" query = select(AccountingEntry).where( AccountingEntry.receipt_id == receipt_id ).order_by(AccountingEntry.sort_order.asc()) result = await session.execute(query) return list(result.scalars().all()) @staticmethod async def update( session: AsyncSession, entry: AccountingEntry, data: AccountingEntryUpdate, modified_by: str, ) -> AccountingEntry: """Update an accounting entry.""" update_data = data.model_dump(exclude_unset=True) for field, value in update_data.items(): setattr(entry, field, value) entry.is_auto_generated = False entry.modified_by = modified_by entry.modified_at = datetime.utcnow() session.add(entry) await session.commit() await session.refresh(entry) return entry @staticmethod async def delete(session: AsyncSession, entry: AccountingEntry) -> bool: """Delete an accounting entry.""" await session.delete(entry) await session.commit() return True @staticmethod async def delete_all_for_receipt(session: AsyncSession, receipt_id: int) -> int: """Delete all accounting entries for a receipt.""" query = delete(AccountingEntry).where(AccountingEntry.receipt_id == receipt_id) result = await session.execute(query) await session.commit() return result.rowcount @staticmethod async def replace_all_for_receipt( session: AsyncSession, receipt_id: int, entries: List[AccountingEntryCreate], modified_by: str, ) -> List[AccountingEntry]: """Replace all entries for a receipt with new ones.""" # Delete existing entries await AccountingEntryCRUD.delete_all_for_receipt(session, receipt_id) # Create new entries (marked as manually modified) created_entries = [] for idx, entry_data in enumerate(entries): entry = AccountingEntry( receipt_id=receipt_id, entry_type=entry_data.entry_type, account_code=entry_data.account_code, account_name=entry_data.account_name, amount=entry_data.amount, partner_id=entry_data.partner_id, cost_center_id=entry_data.cost_center_id, is_auto_generated=False, modified_by=modified_by, modified_at=datetime.utcnow(), sort_order=idx, ) session.add(entry) created_entries.append(entry) await session.commit() for entry in created_entries: await session.refresh(entry) return created_entries @staticmethod async def validate_entries(entries: List[AccountingEntryCreate]) -> tuple[bool, str]: """ Validate accounting entries. Returns (is_valid, error_message). """ if not entries: return False, "At least one entry is required" total_debit = sum( e.amount for e in entries if e.entry_type == EntryType.DEBIT ) total_credit = sum( e.amount for e in entries if e.entry_type == EntryType.CREDIT ) # Check balance (debit should equal credit) if abs(total_debit - total_credit) > 0.01: return False, f"Entries not balanced: Debit={total_debit}, Credit={total_credit}" # Check for valid account codes for entry in entries: if not entry.account_code or len(entry.account_code) < 3: return False, f"Invalid account code: {entry.account_code}" return True, ""