feat: Migrate to ultrathin monolith architecture
Consolidate 3 separate applications (reports-app, data-entry-app, telegram-bot) into a unified
architecture with single backend and frontend:
Backend Changes:
- Unified FastAPI backend at backend/ with modular structure
- Modules: reports, data_entry, telegram in backend/modules/
- Centralized config.py and main.py with all routers registered
- Single worker mode (--workers 1) for Telegram bot compatibility
- Shared Oracle connection pool and JWT authentication
- Unified requirements.txt and environment configuration
Frontend Changes:
- Single Vue.js SPA with module-based routing
- Unified frontend at src/ with modules in src/modules/{reports,data-entry}/
- Shared components and stores in src/shared/
- Error boundaries for module isolation
- Dual API proxy in Vite for module communication
Infrastructure:
- New unified startup scripts: start-prod.sh, start-test.sh, start-backend.sh
- Environment templates: .env.dev.example, .env.test.example, .env.prod.example
- Updated deployment scripts for Windows IIS
- Simplified SSH tunnel management
Documentation:
- Comprehensive CLAUDE.md with architecture overview
- Module-specific docs in docs/{data-entry,telegram}/
- Architecture decision records in docs/ARCHITECTURE-DECISIONS.md
- Deployment guides consolidated in deployment/windows/docs/
This migration reduces complexity, improves maintainability, and enables easier
deployment while maintaining all existing functionality.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
147
backend/.env.dev.example
Normal file
147
backend/.env.dev.example
Normal file
@@ -0,0 +1,147 @@
|
||||
# ============================================================================
|
||||
# ROA2WEB Unified Backend - Environment Configuration (Development)
|
||||
# ============================================================================
|
||||
# Single backend process serving Reports, Data Entry, and Telegram modules
|
||||
# IMPORTANT: Never commit this file to git!
|
||||
|
||||
# ============================================================================
|
||||
# ORACLE DATABASE CONFIGURATION (REQUIRED - Shared by all modules)
|
||||
# ============================================================================
|
||||
# Connection to CONTAFIN_ORACLE schema for authentication and user management
|
||||
# Each company is a separate schema in Oracle Database
|
||||
# Development: Through SSH tunnel (localhost:1526)
|
||||
|
||||
ORACLE_USER=CONTAFIN_ORACLE
|
||||
ORACLE_PASSWORD=your_oracle_password_here
|
||||
ORACLE_HOST=localhost
|
||||
ORACLE_PORT=1526
|
||||
ORACLE_SID=ROA
|
||||
|
||||
# Development: Start SSH tunnel before running backend
|
||||
# ./ssh_tunnel.sh start (production) or ./ssh-tunnel-test.sh start (test)
|
||||
|
||||
# ============================================================================
|
||||
# JWT AUTHENTICATION (REQUIRED - Shared by all modules)
|
||||
# ============================================================================
|
||||
# Used for JWT token generation and validation (shared/auth/jwt_handler.py)
|
||||
|
||||
JWT_SECRET_KEY=generate_with_secrets_token_urlsafe_32
|
||||
JWT_ALGORITHM=HS256
|
||||
|
||||
# Token expiration settings (used by shared/auth/jwt_handler.py)
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||
REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
|
||||
# ============================================================================
|
||||
# SESSION SECURITY - EMAIL 2FA (REQUIRED for Telegram email login)
|
||||
# ============================================================================
|
||||
# Used by Telegram module for session token validation
|
||||
# Must match between backend and Telegram bot
|
||||
|
||||
AUTH_SESSION_SECRET=generate_with_secrets_token_urlsafe_32
|
||||
|
||||
# ============================================================================
|
||||
# SERVER CONFIGURATION
|
||||
# ============================================================================
|
||||
# Unified backend server settings
|
||||
|
||||
API_HOST=0.0.0.0
|
||||
API_PORT=8000
|
||||
DEBUG=true
|
||||
|
||||
# CORS Origins (comma-separated, includes both old and new frontend ports)
|
||||
CORS_ORIGINS=http://localhost:3000,http://localhost:3010,http://localhost:5173
|
||||
|
||||
# ============================================================================
|
||||
# REPORTS MODULE - CACHE CONFIGURATION (OPTIONAL - defaults provided)
|
||||
# ============================================================================
|
||||
# Two-tier hybrid cache system (L1: in-memory LRU, L2: SQLite persistent)
|
||||
# Used by backend/modules/reports/cache/config.py
|
||||
|
||||
# Core Settings
|
||||
CACHE_ENABLED=True
|
||||
CACHE_TYPE=hybrid
|
||||
CACHE_SQLITE_PATH=./data/cache/roa2web_cache.db
|
||||
CACHE_MEMORY_MAX_SIZE=1000
|
||||
CACHE_DEFAULT_TTL=900
|
||||
|
||||
# TTL per Cache Type (seconds)
|
||||
CACHE_TTL_SCHEMA=86400
|
||||
CACHE_TTL_COMPANIES=1800
|
||||
CACHE_TTL_DASHBOARD_SUMMARY=1800
|
||||
CACHE_TTL_DASHBOARD_TRENDS=1800
|
||||
CACHE_TTL_INVOICES=600
|
||||
CACHE_TTL_INVOICES_SUMMARY=900
|
||||
CACHE_TTL_TREASURY=600
|
||||
|
||||
# Maintenance
|
||||
CACHE_CLEANUP_INTERVAL=3600
|
||||
|
||||
# Event-Based Invalidation (experimental)
|
||||
CACHE_AUTO_INVALIDATE=False
|
||||
CACHE_CHECK_INTERVAL=300
|
||||
|
||||
# Performance Tracking
|
||||
CACHE_TRACK_PERFORMANCE=True
|
||||
CACHE_BENCHMARK_ON_STARTUP=False
|
||||
|
||||
# ============================================================================
|
||||
# DATA ENTRY MODULE - CONFIGURATION
|
||||
# ============================================================================
|
||||
# Data Entry module settings (receipts, OCR, etc.)
|
||||
|
||||
# Environment identifier (dev/test/prod)
|
||||
ORACLE_ENV=dev
|
||||
|
||||
# SQLite Database (development)
|
||||
SQLITE_DATABASE_PATH=data/receipts/receipts_dev.db
|
||||
|
||||
# File uploads
|
||||
UPLOAD_PATH=data/receipts/uploads
|
||||
MAX_UPLOAD_SIZE_MB=10
|
||||
|
||||
# Test company (for development testing)
|
||||
TEST_COMPANY_ID=110
|
||||
TEST_COMPANY_SCHEMA=MARIUSM_AUTO
|
||||
|
||||
# ============================================================================
|
||||
# TELEGRAM MODULE - BOT CONFIGURATION (REQUIRED for Telegram features)
|
||||
# ============================================================================
|
||||
# Obtain bot token from @BotFather on Telegram
|
||||
|
||||
TELEGRAM_BOT_TOKEN=your_bot_token_from_botfather
|
||||
|
||||
# Backend URL for bot to communicate with API
|
||||
BACKEND_URL=http://localhost:8000
|
||||
|
||||
# Internal API port (bot's internal API for backend callbacks)
|
||||
INTERNAL_API_PORT=8002
|
||||
|
||||
# Enable internal API documentation (development only)
|
||||
ENABLE_DOCS=false
|
||||
|
||||
# ============================================================================
|
||||
# TELEGRAM MODULE - EMAIL AUTHENTICATION (SMTP) (REQUIRED for email 2FA)
|
||||
# ============================================================================
|
||||
# Required for email-based 2FA authentication flow
|
||||
# Users can login with email + password instead of web app linking
|
||||
|
||||
# SMTP Server Configuration
|
||||
SMTP_HOST=mail.romfast.ro
|
||||
SMTP_PORT=587
|
||||
SMTP_USER=ups@romfast.ro
|
||||
SMTP_PASSWORD=your_smtp_password_here
|
||||
SMTP_FROM_EMAIL=ups@romfast.ro
|
||||
SMTP_FROM_NAME=ROA2WEB
|
||||
SMTP_USE_TLS=true
|
||||
|
||||
# Email Retry Settings
|
||||
EMAIL_MAX_RETRIES=3
|
||||
EMAIL_RETRY_DELAY=2.0
|
||||
|
||||
# ============================================================================
|
||||
# TELEGRAM MODULE - DATABASE (SQLite for bot data)
|
||||
# ============================================================================
|
||||
# Separate SQLite database for Telegram bot auth codes and sessions
|
||||
|
||||
TELEGRAM_SQLITE_DATABASE_PATH=data/telegram/telegram.db
|
||||
146
backend/.env.example
Normal file
146
backend/.env.example
Normal file
@@ -0,0 +1,146 @@
|
||||
# ============================================================================
|
||||
# ROA2WEB Unified Backend - Environment Configuration Template
|
||||
# ============================================================================
|
||||
# Single backend process serving Reports, Data Entry, and Telegram modules
|
||||
#
|
||||
# SETUP INSTRUCTIONS:
|
||||
# 1. Copy this template: cp .env.example .env.dev
|
||||
# 2. Fill in your actual values in .env.dev
|
||||
# 3. Run: ./start-dev.sh (auto-copies .env.dev to .env)
|
||||
#
|
||||
# ENVIRONMENT FILES:
|
||||
# - .env.dev → Development config (committed to git with real values)
|
||||
# - .env.test → Test config (committed to git)
|
||||
# - .env.prod → Production config template (committed, use placeholders!)
|
||||
# - .env → Active config (auto-generated, NOT committed)
|
||||
#
|
||||
# IMPORTANT: Never manually edit .env - edit .env.dev instead!
|
||||
|
||||
# ============================================================================
|
||||
# ORACLE DATABASE CONFIGURATION (REQUIRED - Shared by all modules)
|
||||
# ============================================================================
|
||||
# Connection to CONTAFIN_ORACLE schema for authentication and user management
|
||||
# Each company is a separate schema in Oracle Database
|
||||
# Development: Through SSH tunnel (localhost:1526)
|
||||
# Windows Production: Direct connection to Oracle server
|
||||
|
||||
ORACLE_USER=CONTAFIN_ORACLE
|
||||
ORACLE_PASSWORD=SET_IN_PRODUCTION_ENV
|
||||
ORACLE_HOST=localhost
|
||||
ORACLE_PORT=1526
|
||||
ORACLE_SID=ROA
|
||||
|
||||
# Development Only: Start SSH tunnel before running backend
|
||||
# ./ssh_tunnel.sh start
|
||||
# ./ssh_tunnel.sh status
|
||||
|
||||
# ============================================================================
|
||||
# JWT AUTHENTICATION (REQUIRED - Shared by all modules)
|
||||
# ============================================================================
|
||||
# Used for JWT token generation and validation (shared/auth/jwt_handler.py)
|
||||
# Generate strong secret: python3 -c "import secrets; print(secrets.token_urlsafe(32))"
|
||||
|
||||
JWT_SECRET_KEY=GENERATE_STRONG_SECRET_IN_PRODUCTION
|
||||
JWT_ALGORITHM=HS256
|
||||
|
||||
# Token expiration settings (used by shared/auth/jwt_handler.py)
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||
REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
|
||||
# ============================================================================
|
||||
# SESSION SECURITY - EMAIL 2FA (REQUIRED for Telegram email login)
|
||||
# ============================================================================
|
||||
# Used by Telegram module for session token validation
|
||||
# Generate with: python3 -c "import secrets; print(secrets.token_urlsafe(32))"
|
||||
|
||||
AUTH_SESSION_SECRET=your-secure-random-secret-here-min-32-chars
|
||||
|
||||
# ============================================================================
|
||||
# SERVER CONFIGURATION
|
||||
# ============================================================================
|
||||
# Unified backend server settings
|
||||
|
||||
API_HOST=0.0.0.0
|
||||
API_PORT=8000
|
||||
DEBUG=false
|
||||
|
||||
# CORS Origins (comma-separated)
|
||||
CORS_ORIGINS=http://localhost:3000,http://localhost:5173
|
||||
|
||||
# ============================================================================
|
||||
# REPORTS MODULE - CACHE CONFIGURATION (OPTIONAL - defaults provided)
|
||||
# ============================================================================
|
||||
# Two-tier hybrid cache system (L1: in-memory LRU, L2: SQLite persistent)
|
||||
# Used by backend/modules/reports/cache/config.py
|
||||
|
||||
# Core Settings
|
||||
REPORTS_CACHE_ENABLED=True
|
||||
REPORTS_CACHE_TYPE=hybrid
|
||||
REPORTS_CACHE_SQLITE_PATH=./data/cache/roa2web_cache.db
|
||||
REPORTS_CACHE_MEMORY_MAX_SIZE=1000
|
||||
REPORTS_CACHE_DEFAULT_TTL=900
|
||||
|
||||
# TTL per Cache Type (seconds)
|
||||
REPORTS_CACHE_TTL_SCHEMA=86400
|
||||
REPORTS_CACHE_TTL_COMPANIES=1800
|
||||
REPORTS_CACHE_TTL_DASHBOARD_SUMMARY=1800
|
||||
REPORTS_CACHE_TTL_DASHBOARD_TRENDS=1800
|
||||
REPORTS_CACHE_TTL_INVOICES=600
|
||||
REPORTS_CACHE_TTL_INVOICES_SUMMARY=900
|
||||
REPORTS_CACHE_TTL_TREASURY=600
|
||||
|
||||
# Maintenance
|
||||
REPORTS_CACHE_CLEANUP_INTERVAL=3600
|
||||
|
||||
# Event-Based Invalidation (experimental)
|
||||
REPORTS_CACHE_AUTO_INVALIDATE=False
|
||||
REPORTS_CACHE_CHECK_INTERVAL=300
|
||||
|
||||
# Performance Tracking
|
||||
REPORTS_CACHE_TRACK_PERFORMANCE=True
|
||||
REPORTS_CACHE_BENCHMARK_ON_STARTUP=False
|
||||
|
||||
# ============================================================================
|
||||
# DATA ENTRY MODULE - CONFIGURATION
|
||||
# ============================================================================
|
||||
# Data Entry module settings (receipts, OCR, etc.)
|
||||
|
||||
# SQLite Database
|
||||
DATA_ENTRY_SQLITE_DATABASE_PATH=data/receipts/receipts.db
|
||||
|
||||
# File uploads
|
||||
DATA_ENTRY_UPLOAD_PATH=data/receipts/uploads
|
||||
DATA_ENTRY_MAX_UPLOAD_SIZE_MB=10
|
||||
|
||||
# ============================================================================
|
||||
# TELEGRAM MODULE - BOT CONFIGURATION (REQUIRED for Telegram features)
|
||||
# ============================================================================
|
||||
# Obtain bot token from @BotFather on Telegram
|
||||
|
||||
TELEGRAM_BOT_TOKEN=your_bot_token_here
|
||||
|
||||
# ============================================================================
|
||||
# TELEGRAM MODULE - EMAIL AUTHENTICATION (SMTP) (REQUIRED for email 2FA)
|
||||
# ============================================================================
|
||||
# Required for email-based 2FA authentication flow
|
||||
# Users can login with email + password instead of web app linking
|
||||
|
||||
# SMTP Server Configuration
|
||||
TELEGRAM_SMTP_HOST=mail.romfast.ro
|
||||
TELEGRAM_SMTP_PORT=587
|
||||
TELEGRAM_SMTP_USER=ups@romfast.ro
|
||||
TELEGRAM_SMTP_PASSWORD=your_smtp_password_here
|
||||
TELEGRAM_SMTP_FROM_EMAIL=ups@romfast.ro
|
||||
TELEGRAM_SMTP_FROM_NAME=ROA2WEB
|
||||
TELEGRAM_SMTP_USE_TLS=true
|
||||
|
||||
# Email Retry Settings
|
||||
TELEGRAM_EMAIL_MAX_RETRIES=3
|
||||
TELEGRAM_EMAIL_RETRY_DELAY=2.0
|
||||
|
||||
# ============================================================================
|
||||
# TELEGRAM MODULE - DATABASE (SQLite for bot data)
|
||||
# ============================================================================
|
||||
# Separate SQLite database for Telegram bot auth codes and sessions
|
||||
|
||||
TELEGRAM_SQLITE_DATABASE_PATH=data/telegram/telegram.db
|
||||
139
backend/.env.prod.example
Normal file
139
backend/.env.prod.example
Normal file
@@ -0,0 +1,139 @@
|
||||
# ============================================================================
|
||||
# ROA2WEB Unified Backend - Environment Configuration (PRODUCTION)
|
||||
# ============================================================================
|
||||
# Single backend process serving Reports, Data Entry, and Telegram modules
|
||||
# IMPORTANT: This is a TEMPLATE - fill in production values before deploying!
|
||||
|
||||
# ============================================================================
|
||||
# ORACLE DATABASE CONFIGURATION (REQUIRED - Shared by all modules)
|
||||
# ============================================================================
|
||||
# Connection to CONTAFIN_ORACLE schema for authentication and user management
|
||||
# PRODUCTION: Direct connection to Oracle server (no SSH tunnel)
|
||||
|
||||
ORACLE_USER=CONTAFIN_ORACLE
|
||||
ORACLE_PASSWORD=CHANGE_IN_PRODUCTION
|
||||
ORACLE_HOST=your_oracle_server_ip_or_hostname
|
||||
ORACLE_PORT=1521
|
||||
ORACLE_SID=ROA
|
||||
|
||||
# ============================================================================
|
||||
# JWT AUTHENTICATION (REQUIRED - Shared by all modules)
|
||||
# ============================================================================
|
||||
# CRITICAL: Generate new secrets for production!
|
||||
# python3 -c "import secrets; print(secrets.token_urlsafe(32))"
|
||||
|
||||
JWT_SECRET_KEY=GENERATE_NEW_SECRET_FOR_PRODUCTION
|
||||
JWT_ALGORITHM=HS256
|
||||
|
||||
# Token expiration settings
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||
REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
|
||||
# ============================================================================
|
||||
# SESSION SECURITY - EMAIL 2FA (REQUIRED for Telegram email login)
|
||||
# ============================================================================
|
||||
# CRITICAL: Generate new secret for production!
|
||||
# python3 -c "import secrets; print(secrets.token_urlsafe(32))"
|
||||
|
||||
AUTH_SESSION_SECRET=GENERATE_NEW_SECRET_FOR_PRODUCTION
|
||||
|
||||
# ============================================================================
|
||||
# SERVER CONFIGURATION
|
||||
# ============================================================================
|
||||
# Unified backend server settings
|
||||
|
||||
API_HOST=0.0.0.0
|
||||
API_PORT=8000
|
||||
DEBUG=false
|
||||
|
||||
# CORS Origins (comma-separated) - Update with production frontend URL
|
||||
CORS_ORIGINS=https://your-production-domain.com,http://localhost:3000
|
||||
|
||||
# ============================================================================
|
||||
# REPORTS MODULE - CACHE CONFIGURATION (OPTIONAL - defaults provided)
|
||||
# ============================================================================
|
||||
# Two-tier hybrid cache system (L1: in-memory LRU, L2: SQLite persistent)
|
||||
|
||||
# Core Settings
|
||||
CACHE_ENABLED=True
|
||||
CACHE_TYPE=hybrid
|
||||
CACHE_SQLITE_PATH=./data/cache/roa2web_cache_prod.db
|
||||
CACHE_MEMORY_MAX_SIZE=1000
|
||||
CACHE_DEFAULT_TTL=900
|
||||
|
||||
# TTL per Cache Type (seconds)
|
||||
CACHE_TTL_SCHEMA=86400
|
||||
CACHE_TTL_COMPANIES=1800
|
||||
CACHE_TTL_DASHBOARD_SUMMARY=1800
|
||||
CACHE_TTL_DASHBOARD_TRENDS=1800
|
||||
CACHE_TTL_INVOICES=600
|
||||
CACHE_TTL_INVOICES_SUMMARY=900
|
||||
CACHE_TTL_TREASURY=600
|
||||
|
||||
# Maintenance
|
||||
CACHE_CLEANUP_INTERVAL=3600
|
||||
|
||||
# Event-Based Invalidation (experimental)
|
||||
CACHE_AUTO_INVALIDATE=False
|
||||
CACHE_CHECK_INTERVAL=300
|
||||
|
||||
# Performance Tracking
|
||||
CACHE_TRACK_PERFORMANCE=True
|
||||
CACHE_BENCHMARK_ON_STARTUP=False
|
||||
|
||||
# ============================================================================
|
||||
# DATA ENTRY MODULE - CONFIGURATION
|
||||
# ============================================================================
|
||||
# Data Entry module settings (receipts, OCR, etc.)
|
||||
|
||||
# Environment identifier
|
||||
ORACLE_ENV=prod
|
||||
|
||||
# SQLite Database (production)
|
||||
SQLITE_DATABASE_PATH=data/receipts/receipts_prod.db
|
||||
|
||||
# File uploads
|
||||
UPLOAD_PATH=data/receipts/uploads
|
||||
MAX_UPLOAD_SIZE_MB=10
|
||||
|
||||
# ============================================================================
|
||||
# TELEGRAM MODULE - BOT CONFIGURATION (REQUIRED for Telegram features)
|
||||
# ============================================================================
|
||||
# Obtain bot token from @BotFather on Telegram
|
||||
# CRITICAL: Use production bot token, not development!
|
||||
|
||||
TELEGRAM_BOT_TOKEN=your_bot_token_from_botfather
|
||||
|
||||
# Backend URL for bot to communicate with API
|
||||
BACKEND_URL=http://localhost:8000
|
||||
|
||||
# Internal API port (bot's internal API for backend callbacks)
|
||||
INTERNAL_API_PORT=8002
|
||||
|
||||
# Enable internal API documentation (DISABLE in production!)
|
||||
ENABLE_DOCS=false
|
||||
|
||||
# ============================================================================
|
||||
# TELEGRAM MODULE - EMAIL AUTHENTICATION (SMTP) (REQUIRED for email 2FA)
|
||||
# ============================================================================
|
||||
# CRITICAL: Update with production SMTP credentials
|
||||
|
||||
# SMTP Server Configuration
|
||||
SMTP_HOST=mail.romfast.ro
|
||||
SMTP_PORT=587
|
||||
SMTP_USER=ups@romfast.ro
|
||||
SMTP_PASSWORD=CHANGE_IN_PRODUCTION
|
||||
SMTP_FROM_EMAIL=ups@romfast.ro
|
||||
SMTP_FROM_NAME=ROA2WEB
|
||||
SMTP_USE_TLS=true
|
||||
|
||||
# Email Retry Settings
|
||||
EMAIL_MAX_RETRIES=3
|
||||
EMAIL_RETRY_DELAY=2.0
|
||||
|
||||
# ============================================================================
|
||||
# TELEGRAM MODULE - DATABASE (SQLite for bot data)
|
||||
# ============================================================================
|
||||
# Separate SQLite database for Telegram bot auth codes and sessions
|
||||
|
||||
TELEGRAM_SQLITE_DATABASE_PATH=data/telegram/telegram_prod.db
|
||||
147
backend/.env.test.example
Normal file
147
backend/.env.test.example
Normal file
@@ -0,0 +1,147 @@
|
||||
# ============================================================================
|
||||
# ROA2WEB Unified Backend - Environment Configuration (TEST)
|
||||
# ============================================================================
|
||||
# TEST environment using Oracle TEST server (10.0.20.121)
|
||||
# Single backend process serving Reports, Data Entry, and Telegram modules
|
||||
# IMPORTANT: Never commit this file to git!
|
||||
|
||||
# ============================================================================
|
||||
# ORACLE DATABASE CONFIGURATION (REQUIRED - Shared by all modules)
|
||||
# ============================================================================
|
||||
# Connection to CONTAFIN_ORACLE schema for authentication and user management
|
||||
# TEST: Through SSH tunnel to 10.0.20.121 (localhost:1526)
|
||||
|
||||
ORACLE_USER=CONTAFIN_ORACLE
|
||||
ORACLE_PASSWORD=your_oracle_password_here
|
||||
ORACLE_HOST=localhost
|
||||
ORACLE_PORT=1526
|
||||
ORACLE_SID=roa
|
||||
|
||||
# TEST: Start SSH tunnel before running backend
|
||||
# ./ssh-tunnel-test.sh start
|
||||
|
||||
# ============================================================================
|
||||
# JWT AUTHENTICATION (REQUIRED - Shared by all modules)
|
||||
# ============================================================================
|
||||
# Used for JWT token generation and validation (shared/auth/jwt_handler.py)
|
||||
|
||||
JWT_SECRET_KEY=generate_with_secrets_token_urlsafe_32
|
||||
JWT_ALGORITHM=HS256
|
||||
|
||||
# Token expiration settings (used by shared/auth/jwt_handler.py)
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=480
|
||||
REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
|
||||
# ============================================================================
|
||||
# SESSION SECURITY - EMAIL 2FA (REQUIRED for Telegram email login)
|
||||
# ============================================================================
|
||||
# Used by Telegram module for session token validation
|
||||
# Must match between backend and Telegram bot
|
||||
|
||||
AUTH_SESSION_SECRET=generate_with_secrets_token_urlsafe_32
|
||||
|
||||
# ============================================================================
|
||||
# SERVER CONFIGURATION
|
||||
# ============================================================================
|
||||
# Unified backend server settings
|
||||
|
||||
API_HOST=0.0.0.0
|
||||
API_PORT=8000
|
||||
DEBUG=true
|
||||
|
||||
# CORS Origins (comma-separated, includes both old and new frontend ports)
|
||||
CORS_ORIGINS=http://localhost:3000,http://localhost:3010,http://localhost:5173
|
||||
|
||||
# ============================================================================
|
||||
# REPORTS MODULE - CACHE CONFIGURATION (OPTIONAL - defaults provided)
|
||||
# ============================================================================
|
||||
# Two-tier hybrid cache system (L1: in-memory LRU, L2: SQLite persistent)
|
||||
# Used by backend/modules/reports/cache/config.py
|
||||
|
||||
# Core Settings
|
||||
CACHE_ENABLED=True
|
||||
CACHE_TYPE=hybrid
|
||||
CACHE_SQLITE_PATH=./data/cache/roa2web_cache_test.db
|
||||
CACHE_MEMORY_MAX_SIZE=1000
|
||||
CACHE_DEFAULT_TTL=900
|
||||
|
||||
# TTL per Cache Type (seconds)
|
||||
CACHE_TTL_SCHEMA=86400
|
||||
CACHE_TTL_COMPANIES=1800
|
||||
CACHE_TTL_DASHBOARD_SUMMARY=1800
|
||||
CACHE_TTL_DASHBOARD_TRENDS=1800
|
||||
CACHE_TTL_INVOICES=600
|
||||
CACHE_TTL_INVOICES_SUMMARY=900
|
||||
CACHE_TTL_TREASURY=600
|
||||
|
||||
# Maintenance
|
||||
CACHE_CLEANUP_INTERVAL=3600
|
||||
|
||||
# Event-Based Invalidation (experimental)
|
||||
CACHE_AUTO_INVALIDATE=False
|
||||
CACHE_CHECK_INTERVAL=300
|
||||
|
||||
# Performance Tracking
|
||||
CACHE_TRACK_PERFORMANCE=True
|
||||
CACHE_BENCHMARK_ON_STARTUP=False
|
||||
|
||||
# ============================================================================
|
||||
# DATA ENTRY MODULE - CONFIGURATION
|
||||
# ============================================================================
|
||||
# Data Entry module settings (receipts, OCR, etc.)
|
||||
|
||||
# Environment identifier (dev/test/prod)
|
||||
ORACLE_ENV=test
|
||||
|
||||
# SQLite Database (test)
|
||||
SQLITE_DATABASE_PATH=data/receipts/receipts_test.db
|
||||
|
||||
# File uploads
|
||||
UPLOAD_PATH=data/receipts/uploads
|
||||
MAX_UPLOAD_SIZE_MB=10
|
||||
|
||||
# Test company (for testing)
|
||||
TEST_COMPANY_ID=110
|
||||
TEST_COMPANY_SCHEMA=MARIUSM_AUTO
|
||||
|
||||
# ============================================================================
|
||||
# TELEGRAM MODULE - BOT CONFIGURATION (REQUIRED for Telegram features)
|
||||
# ============================================================================
|
||||
# Obtain bot token from @BotFather on Telegram
|
||||
|
||||
TELEGRAM_BOT_TOKEN=your_bot_token_from_botfather
|
||||
|
||||
# Backend URL for bot to communicate with API
|
||||
BACKEND_URL=http://localhost:8000
|
||||
|
||||
# Internal API port (bot's internal API for backend callbacks)
|
||||
INTERNAL_API_PORT=8002
|
||||
|
||||
# Enable internal API documentation (development only)
|
||||
ENABLE_DOCS=false
|
||||
|
||||
# ============================================================================
|
||||
# TELEGRAM MODULE - EMAIL AUTHENTICATION (SMTP) (REQUIRED for email 2FA)
|
||||
# ============================================================================
|
||||
# Required for email-based 2FA authentication flow
|
||||
# Users can login with email + password instead of web app linking
|
||||
|
||||
# SMTP Server Configuration
|
||||
SMTP_HOST=mail.romfast.ro
|
||||
SMTP_PORT=587
|
||||
SMTP_USER=ups@romfast.ro
|
||||
SMTP_PASSWORD=your_smtp_password_here
|
||||
SMTP_FROM_EMAIL=ups@romfast.ro
|
||||
SMTP_FROM_NAME=ROA2WEB
|
||||
SMTP_USE_TLS=true
|
||||
|
||||
# Email Retry Settings
|
||||
EMAIL_MAX_RETRIES=3
|
||||
EMAIL_RETRY_DELAY=2.0
|
||||
|
||||
# ============================================================================
|
||||
# TELEGRAM MODULE - DATABASE (SQLite for bot data)
|
||||
# ============================================================================
|
||||
# Separate SQLite database for Telegram bot auth codes and sessions
|
||||
|
||||
TELEGRAM_SQLITE_DATABASE_PATH=data/telegram/telegram_test.db
|
||||
212
backend/ENV-SETUP.md
Normal file
212
backend/ENV-SETUP.md
Normal file
@@ -0,0 +1,212 @@
|
||||
# Environment Configuration Guide
|
||||
|
||||
## Overview
|
||||
|
||||
The unified backend uses environment-specific configuration files that are automatically loaded by startup scripts.
|
||||
|
||||
**SECURITY**: All `.env*` files (except `.env*.example`) contain real credentials and are **NEVER committed to git**.
|
||||
|
||||
## File Structure
|
||||
|
||||
```
|
||||
backend/
|
||||
├── .env.prod.example # Production template (COMMITTED - no credentials)
|
||||
├── .env.test.example # Test template (COMMITTED - no credentials)
|
||||
├── .env.prod.example # Production template (COMMITTED - no credentials)
|
||||
├── .env.example # Generic template (COMMITTED)
|
||||
├── .env.prod # Production config (IGNORED - real credentials)
|
||||
├── .env.test # Test config (IGNORED - real credentials)
|
||||
├── .env.prod # Production config (IGNORED - real credentials)
|
||||
└── .env # Active config (IGNORED - auto-generated)
|
||||
```
|
||||
|
||||
## First-Time Setup
|
||||
|
||||
### Production
|
||||
```bash
|
||||
# 1. Copy template
|
||||
cp backend/.env.prod.example backend/.env.prod
|
||||
|
||||
# 2. Edit with your credentials
|
||||
vim backend/.env.prod
|
||||
|
||||
# 3. Fill in:
|
||||
# - ORACLE_PASSWORD
|
||||
# - JWT_SECRET_KEY (generate with: python3 -c "import secrets; print(secrets.token_urlsafe(32))")
|
||||
# - AUTH_SESSION_SECRET (generate with: python3 -c "import secrets; print(secrets.token_urlsafe(32))")
|
||||
# - TELEGRAM_BOT_TOKEN (from @BotFather)
|
||||
# - SMTP_PASSWORD
|
||||
|
||||
# 4. Start
|
||||
./start-prod.sh
|
||||
```
|
||||
|
||||
### Test
|
||||
```bash
|
||||
# Same process with .env.test
|
||||
cp backend/.env.test.example backend/.env.test
|
||||
vim backend/.env.test
|
||||
# Fill in TEST credentials (separate from dev!)
|
||||
./start-test.sh
|
||||
```
|
||||
|
||||
### Production
|
||||
```bash
|
||||
# Same process with .env.prod
|
||||
cp backend/.env.prod.example backend/.env.prod
|
||||
vim backend/.env.prod
|
||||
# Fill in PRODUCTION credentials (generate NEW secrets!)
|
||||
./start-backend.sh start
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
### Production
|
||||
```bash
|
||||
./start-prod.sh # Checks for .env.prod → copies to .env → starts backend
|
||||
```
|
||||
|
||||
### Test
|
||||
```bash
|
||||
./start-test.sh # Checks for .env.test → copies to .env → starts backend
|
||||
```
|
||||
|
||||
### Production
|
||||
```bash
|
||||
# Manual setup (one-time)
|
||||
cp .env.prod.example .env.prod
|
||||
vim .env.prod # Fill in credentials
|
||||
# Then start
|
||||
./start-backend.sh start
|
||||
```
|
||||
|
||||
## Important Rules
|
||||
|
||||
### ✅ DO
|
||||
- Copy `.env.*.example` to `.env.*` and fill in real credentials
|
||||
- Edit `.env.prod` for production changes
|
||||
- Edit `.env.test` for test environment changes
|
||||
- Edit `.env.prod` for production
|
||||
- Generate **new** secrets for each environment
|
||||
- Keep `.env.prod`, `.env.test`, `.env.prod` **local only** (never commit!)
|
||||
|
||||
### ❌ DON'T
|
||||
- Don't commit `.env`, `.env.prod`, `.env.test`, or `.env.prod` (they're in .gitignore)
|
||||
- Don't manually edit `.env` (it's auto-generated!)
|
||||
- Don't use same secrets across environments
|
||||
- Don't share credentials via git (use secure channels)
|
||||
- Don't put real credentials in `.env*.example` files
|
||||
|
||||
## Environment Differences
|
||||
|
||||
| Setting | .env.prod | .env.test | .env.prod |
|
||||
|---------|----------|-----------|-----------|
|
||||
| Oracle SID | `ROA` | `roa` | `ROA` |
|
||||
| JWT Expire | 30 min | 480 min | 30 min |
|
||||
| DEBUG | `true` | `true` | `false` |
|
||||
| Cache DB | `roa2web_cache.db` | `roa2web_cache_test.db` | `roa2web_cache_prod.db` |
|
||||
| Receipts DB | `receipts_dev.db` | `receipts_test.db` | `receipts_prod.db` |
|
||||
| Telegram DB | `telegram.db` | `telegram_test.db` | `telegram_prod.db` |
|
||||
|
||||
## Security Notes
|
||||
|
||||
### Template Files (.env.*.example)
|
||||
These contain **placeholders only**:
|
||||
- ✅ Safe to commit to git
|
||||
- ✅ Shared across team
|
||||
- ✅ No real credentials
|
||||
- 📖 Used as reference for first-time setup
|
||||
|
||||
### Actual Config Files (.env.prod, .env.test, .env.prod)
|
||||
These contain **real credentials**:
|
||||
- ❌ **NEVER commit to git** (in .gitignore)
|
||||
- ❌ Never share via email/chat
|
||||
- ✅ Keep local only
|
||||
- ✅ Generate unique secrets per environment
|
||||
- 🔐 Share securely if needed (encrypted vault, 1Password, etc.)
|
||||
|
||||
### Active Config (.env)
|
||||
This is **auto-generated** and **ignored by git**:
|
||||
- ❌ Never commit to git
|
||||
- 🔄 Auto-overwritten by startup scripts
|
||||
- 📝 Edit source files (.env.prod, .env.test) instead
|
||||
|
||||
## Generating Secrets
|
||||
|
||||
For `JWT_SECRET_KEY` and `AUTH_SESSION_SECRET`:
|
||||
```bash
|
||||
python3 -c "import secrets; print(secrets.token_urlsafe(32))"
|
||||
```
|
||||
|
||||
Generate **different** secrets for dev, test, and production!
|
||||
|
||||
## Quick Reference
|
||||
|
||||
### First Time Setup
|
||||
```bash
|
||||
# 1. Copy template
|
||||
cp backend/.env.prod.example backend/.env.prod
|
||||
|
||||
# 2. Fill credentials
|
||||
vim backend/.env.prod
|
||||
|
||||
# 3. Start
|
||||
./start-prod.sh
|
||||
```
|
||||
|
||||
### Changing Configuration
|
||||
```bash
|
||||
# 1. Edit source file
|
||||
vim backend/.env.prod
|
||||
|
||||
# 2. Restart to apply
|
||||
./start-prod.sh
|
||||
```
|
||||
|
||||
### Production Deployment
|
||||
```bash
|
||||
# 1. Copy template
|
||||
cp backend/.env.prod.example backend/.env.prod
|
||||
|
||||
# 2. Fill in PRODUCTION values
|
||||
vim backend/.env.prod
|
||||
|
||||
# 3. Generate NEW secrets
|
||||
python3 -c "import secrets; print(secrets.token_urlsafe(32))"
|
||||
|
||||
# 4. Start backend
|
||||
./start-backend.sh start
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "Wrong database" error
|
||||
Check that you're using the correct startup script:
|
||||
- Production: `./start-prod.sh` (uses `.env.prod`)
|
||||
- Test: `./start-test.sh` (uses `.env.test`)
|
||||
|
||||
### ".env.prod not found" error
|
||||
First-time setup required:
|
||||
```bash
|
||||
cp backend/.env.prod.example backend/.env.prod
|
||||
vim backend/.env.prod # Fill in your credentials
|
||||
```
|
||||
|
||||
### Changes not taking effect
|
||||
The `.env` file is regenerated on each start. Edit the source file (`.env.prod` or `.env.test`) instead.
|
||||
|
||||
### Checking what will be committed
|
||||
```bash
|
||||
git status backend/.env*
|
||||
# Should show:
|
||||
# modified: .env.prod.example (if you changed template)
|
||||
# nothing else!
|
||||
```
|
||||
|
||||
## Team Sharing
|
||||
|
||||
**Templates only** are committed to git:
|
||||
- Share configuration structure via `.env*.example`
|
||||
- Each developer creates their own `.env.prod` from template
|
||||
- Never commit actual credentials
|
||||
- Use secure channels for sharing sensitive values (1Password, encrypted vault, etc.)
|
||||
102
backend/QUICK-ENV-REFERENCE.md
Normal file
102
backend/QUICK-ENV-REFERENCE.md
Normal file
@@ -0,0 +1,102 @@
|
||||
# Quick Environment Reference
|
||||
|
||||
## 🔒 SECURITY FIRST
|
||||
|
||||
**All `.env*` files (except `.env*.example`) contain real credentials and are NEVER committed to git!**
|
||||
|
||||
## 🚀 First-Time Setup
|
||||
|
||||
```bash
|
||||
# 1. Copy template with real credentials
|
||||
cp backend/.env.prod.example backend/.env.prod
|
||||
|
||||
# 2. Edit with YOUR credentials
|
||||
vim backend/.env.prod
|
||||
|
||||
# 3. Fill in the placeholders:
|
||||
# - ORACLE_PASSWORD
|
||||
# - JWT_SECRET_KEY
|
||||
# - AUTH_SESSION_SECRET
|
||||
# - TELEGRAM_BOT_TOKEN
|
||||
# - SMTP_PASSWORD
|
||||
|
||||
# 4. Start production
|
||||
./start-prod.sh
|
||||
```
|
||||
|
||||
## 📋 Daily Usage
|
||||
|
||||
```bash
|
||||
# Production (uses .env.prod automatically)
|
||||
./start-prod.sh
|
||||
|
||||
# Test Environment (uses .env.test automatically)
|
||||
./start-test.sh
|
||||
|
||||
# Quick Restart (uses existing .env)
|
||||
./start-backend.sh restart
|
||||
```
|
||||
|
||||
## ✏️ Changing Configuration
|
||||
|
||||
```bash
|
||||
# 1. Edit the source file (NOT .env!)
|
||||
vim backend/.env.prod # Production
|
||||
vim backend/.env.test # Test
|
||||
|
||||
# 2. Restart to apply changes
|
||||
./start-prod.sh
|
||||
```
|
||||
|
||||
## 📁 Which File to Edit?
|
||||
|
||||
| You Want To... | Edit This File |
|
||||
|----------------|----------------|
|
||||
| Change dev database password | `backend/.env.prod` |
|
||||
| Update test server settings | `backend/.env.test` |
|
||||
| Add new environment variable | Templates: `.env*.example` + your `.env.prod`/`.env.test` |
|
||||
| Create production config | Copy `.env.prod.example` to `.env.prod` and fill secrets |
|
||||
|
||||
## 🔑 Generating Secrets
|
||||
|
||||
```bash
|
||||
# For JWT_SECRET_KEY and AUTH_SESSION_SECRET
|
||||
python3 -c "import secrets; print(secrets.token_urlsafe(32))"
|
||||
```
|
||||
|
||||
**Generate DIFFERENT secrets for each environment (dev, test, prod)!**
|
||||
|
||||
## ⚠️ Important
|
||||
|
||||
- **Never edit** `backend/.env` directly (it's auto-generated!)
|
||||
- **Always edit** `backend/.env.prod` or `.env.test`
|
||||
- **Never commit** `.env`, `.env.prod`, `.env.test`, `.env.prod`
|
||||
- **Only commit** `.env*.example` (templates with placeholders)
|
||||
- Restart after changes for them to take effect
|
||||
|
||||
## 🛡️ Git Behavior
|
||||
|
||||
| File | Git Status | Contains |
|
||||
|------|-----------|----------|
|
||||
| `.env.prod.example` | ✅ Committed | Template (placeholders) |
|
||||
| `.env.test.example` | ✅ Committed | Template (placeholders) |
|
||||
| `.env.prod.example` | ✅ Committed | Template (placeholders) |
|
||||
| `.env.example` | ✅ Committed | Generic template |
|
||||
| `.env.prod` | ❌ Ignored | **Real dev credentials** |
|
||||
| `.env.test` | ❌ Ignored | **Real test credentials** |
|
||||
| `.env.prod` | ❌ Ignored | **Real prod credentials** |
|
||||
| `.env` | ❌ Ignored | Auto-generated (current) |
|
||||
|
||||
## ✅ Quick Check
|
||||
|
||||
```bash
|
||||
# See what git will commit
|
||||
git status backend/.env*
|
||||
|
||||
# Should show ONLY .env*.example files
|
||||
# If .env.prod or .env.test appear, they're NOT properly ignored!
|
||||
```
|
||||
|
||||
## 📖 More Info
|
||||
|
||||
See `backend/ENV-SETUP.md` for complete documentation.
|
||||
0
backend/__init__.py
Normal file
0
backend/__init__.py
Normal file
173
backend/config.py
Normal file
173
backend/config.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""
|
||||
Unified Configuration for ROA2WEB Backend
|
||||
Consolidates settings from Reports, Data Entry, and Telegram modules
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from pydantic_settings import BaseSettings
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
class UnifiedSettings(BaseSettings):
|
||||
"""Unified application settings for all modules."""
|
||||
|
||||
# ============================================================================
|
||||
# GENERAL APPLICATION SETTINGS
|
||||
# ============================================================================
|
||||
app_name: str = "ROA2WEB Unified Backend"
|
||||
app_version: str = "1.0.0"
|
||||
debug: bool = False
|
||||
api_host: str = "0.0.0.0"
|
||||
api_port: int = 8000
|
||||
|
||||
# ============================================================================
|
||||
# ORACLE DATABASE (Shared by all modules)
|
||||
# ============================================================================
|
||||
oracle_user: str = ""
|
||||
oracle_password: str = ""
|
||||
oracle_host: str = "localhost"
|
||||
oracle_port: int = 1526
|
||||
oracle_sid: str = "ROA"
|
||||
|
||||
# ============================================================================
|
||||
# JWT AUTHENTICATION (Shared by all modules)
|
||||
# ============================================================================
|
||||
jwt_secret_key: str = "change-me-in-production"
|
||||
jwt_algorithm: str = "HS256"
|
||||
access_token_expire_minutes: int = 30
|
||||
refresh_token_expire_days: int = 7
|
||||
|
||||
# ============================================================================
|
||||
# SESSION SECURITY - EMAIL 2FA (Telegram module)
|
||||
# ============================================================================
|
||||
auth_session_secret: str = "change-me-in-production"
|
||||
|
||||
# ============================================================================
|
||||
# CORS
|
||||
# ============================================================================
|
||||
cors_origins: str = "http://localhost:3000,http://localhost:5173"
|
||||
|
||||
# ============================================================================
|
||||
# REPORTS MODULE - CACHE CONFIGURATION
|
||||
# ============================================================================
|
||||
reports_cache_enabled: bool = True
|
||||
reports_cache_type: str = "hybrid"
|
||||
reports_cache_sqlite_path: str = "./data/cache/roa2web_cache.db"
|
||||
reports_cache_memory_max_size: int = 1000
|
||||
reports_cache_default_ttl: int = 900
|
||||
|
||||
# Cache TTL per type (seconds)
|
||||
reports_cache_ttl_schema: int = 86400
|
||||
reports_cache_ttl_companies: int = 1800
|
||||
reports_cache_ttl_dashboard_summary: int = 1800
|
||||
reports_cache_ttl_dashboard_trends: int = 1800
|
||||
reports_cache_ttl_invoices: int = 600
|
||||
reports_cache_ttl_invoices_summary: int = 900
|
||||
reports_cache_ttl_treasury: int = 600
|
||||
|
||||
# Cache maintenance
|
||||
reports_cache_cleanup_interval: int = 3600
|
||||
reports_cache_auto_invalidate: bool = False
|
||||
reports_cache_check_interval: int = 300
|
||||
reports_cache_track_performance: bool = True
|
||||
reports_cache_benchmark_on_startup: bool = False
|
||||
|
||||
# ============================================================================
|
||||
# DATA ENTRY MODULE - CONFIGURATION
|
||||
# ============================================================================
|
||||
data_entry_sqlite_database_path: str = "data/receipts/receipts.db"
|
||||
data_entry_upload_path: str = "data/receipts/uploads"
|
||||
data_entry_max_upload_size_mb: int = 10
|
||||
data_entry_allowed_mime_types: List[str] = [
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/gif",
|
||||
"image/webp",
|
||||
"application/pdf",
|
||||
]
|
||||
|
||||
# ============================================================================
|
||||
# TELEGRAM MODULE - BOT CONFIGURATION
|
||||
# ============================================================================
|
||||
telegram_bot_token: str = ""
|
||||
telegram_smtp_host: str = ""
|
||||
telegram_smtp_port: int = 587
|
||||
telegram_smtp_user: str = ""
|
||||
telegram_smtp_password: str = ""
|
||||
telegram_smtp_from_email: str = ""
|
||||
telegram_smtp_from_name: str = "ROA2WEB"
|
||||
telegram_smtp_use_tls: bool = True
|
||||
telegram_email_max_retries: int = 3
|
||||
telegram_email_retry_delay: float = 2.0
|
||||
telegram_sqlite_database_path: str = "data/telegram/telegram.db"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
extra = "ignore"
|
||||
case_sensitive = False
|
||||
|
||||
# ============================================================================
|
||||
# COMPUTED PROPERTIES
|
||||
# ============================================================================
|
||||
|
||||
@property
|
||||
def oracle_dsn(self) -> str:
|
||||
"""Get Oracle DSN string."""
|
||||
return f"{self.oracle_host}:{self.oracle_port}/{self.oracle_sid}"
|
||||
|
||||
@property
|
||||
def cors_origins_list(self) -> List[str]:
|
||||
"""Get CORS origins as list."""
|
||||
return [origin.strip() for origin in self.cors_origins.split(",")]
|
||||
|
||||
# Data Entry properties
|
||||
@property
|
||||
def data_entry_database_url(self) -> str:
|
||||
"""Get SQLite database URL for async (Data Entry)."""
|
||||
return f"sqlite+aiosqlite:///{self.data_entry_sqlite_database_path}"
|
||||
|
||||
@property
|
||||
def data_entry_sync_database_url(self) -> str:
|
||||
"""Get SQLite database URL for sync operations (Alembic)."""
|
||||
return f"sqlite:///{self.data_entry_sqlite_database_path}"
|
||||
|
||||
@property
|
||||
def data_entry_upload_path_resolved(self) -> Path:
|
||||
"""Get resolved upload path."""
|
||||
path = Path(self.data_entry_upload_path)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def data_entry_max_upload_size_bytes(self) -> int:
|
||||
"""Get max upload size in bytes."""
|
||||
return self.data_entry_max_upload_size_mb * 1024 * 1024
|
||||
|
||||
# Reports cache properties
|
||||
@property
|
||||
def reports_cache_sqlite_path_resolved(self) -> Path:
|
||||
"""Get resolved cache SQLite path."""
|
||||
path = Path(self.reports_cache_sqlite_path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
# Telegram properties
|
||||
@property
|
||||
def telegram_sqlite_path_resolved(self) -> Path:
|
||||
"""Get resolved Telegram SQLite path."""
|
||||
path = Path(self.telegram_sqlite_database_path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_settings() -> UnifiedSettings:
|
||||
"""Get cached settings instance."""
|
||||
return UnifiedSettings()
|
||||
|
||||
|
||||
# Convenience instance
|
||||
settings = get_settings()
|
||||
45
backend/data/README.md
Normal file
45
backend/data/README.md
Normal file
@@ -0,0 +1,45 @@
|
||||
# Backend Runtime Data
|
||||
|
||||
This directory contains runtime data generated by the unified backend.
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
data/
|
||||
├── cache/ # Reports module cache (hybrid L1+L2)
|
||||
│ └── *.db # SQLite L2 cache database
|
||||
├── receipts/ # Data Entry module data
|
||||
│ ├── *.db # SQLite receipts database
|
||||
│ └── uploads/ # User-uploaded files (receipts, attachments)
|
||||
└── telegram/ # Telegram bot data
|
||||
└── *.db # SQLite bot auth/session database
|
||||
```
|
||||
|
||||
## Git Behavior
|
||||
|
||||
- **Ignored**: All `*.db` files and `uploads/` contents
|
||||
- **Committed**: Only `.gitkeep` files to preserve directory structure
|
||||
|
||||
## Environment-Specific Databases
|
||||
|
||||
Different environments use separate databases:
|
||||
|
||||
- **Development** (`.env.prod`):
|
||||
- Cache: `roa2web_cache.db`
|
||||
- Receipts: `receipts_dev.db`
|
||||
- Telegram: `telegram.db`
|
||||
|
||||
- **Test** (`.env.test`):
|
||||
- Cache: `roa2web_cache_test.db`
|
||||
- Receipts: `receipts_test.db`
|
||||
- Telegram: `telegram_test.db`
|
||||
|
||||
- **Production** (`.env.prod`):
|
||||
- Cache: `roa2web_cache_prod.db`
|
||||
- Receipts: `receipts_prod.db`
|
||||
- Telegram: `telegram_prod.db`
|
||||
|
||||
## Auto-Created
|
||||
|
||||
All databases and directories are created automatically on first run.
|
||||
No manual setup required.
|
||||
0
backend/data/cache/.gitkeep
vendored
Normal file
0
backend/data/cache/.gitkeep
vendored
Normal file
0
backend/data/receipts/.gitkeep
Normal file
0
backend/data/receipts/.gitkeep
Normal file
0
backend/data/telegram/.gitkeep
Normal file
0
backend/data/telegram/.gitkeep
Normal file
428
backend/main.py
Normal file
428
backend/main.py
Normal file
@@ -0,0 +1,428 @@
|
||||
"""
|
||||
ROA2WEB Unified Backend - Single FastAPI Application
|
||||
Consolidates Reports, Data Entry, and Telegram modules into one process
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Add project root and shared modules to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root)) # Enable 'from backend.xxx import yyy'
|
||||
sys.path.insert(0, str(project_root / "shared")) # Enable 'from shared.xxx import yyy'
|
||||
|
||||
# Import configuration
|
||||
from backend.config import settings
|
||||
|
||||
# Import shared infrastructure
|
||||
from shared.database.oracle_pool import oracle_pool
|
||||
from shared.auth.middleware import AuthenticationMiddleware
|
||||
from shared.auth.routes import create_auth_router
|
||||
from shared.routes.companies import create_companies_router
|
||||
from shared.routes.calendar import create_calendar_router
|
||||
|
||||
# Import module router factories
|
||||
from backend.modules.reports.routers import create_reports_router
|
||||
from backend.modules.data_entry.routers import create_data_entry_router
|
||||
from backend.modules.telegram.routers import create_telegram_router
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
datefmt='%H:%M:%S'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global variables for background tasks
|
||||
telegram_bot_task = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# INITIALIZATION FUNCTIONS
|
||||
# ============================================================================
|
||||
|
||||
async def init_oracle_pool():
|
||||
"""Initialize Oracle connection pool (shared by all modules)."""
|
||||
logger.info("[ORACLE] Initializing connection pool...")
|
||||
await oracle_pool.initialize()
|
||||
logger.info("[ORACLE] ✅ Pool initialized successfully")
|
||||
|
||||
|
||||
async def init_reports_cache():
|
||||
"""Initialize Reports cache system."""
|
||||
logger.info("[REPORTS] Initializing cache system...")
|
||||
try:
|
||||
from backend.modules.reports.cache import init_cache, init_event_monitor, get_cache
|
||||
from backend.modules.reports.cache.config import CacheConfig
|
||||
|
||||
cache_config = CacheConfig.from_env()
|
||||
await init_cache(cache_config)
|
||||
logger.info(f"[REPORTS] ✅ Cache initialized: type={cache_config.cache_type}, enabled={cache_config.enabled}")
|
||||
|
||||
# Initialize event monitor
|
||||
cache = get_cache()
|
||||
await init_event_monitor(cache, cache_config)
|
||||
if cache_config.auto_invalidate_enabled:
|
||||
logger.info("[REPORTS] Event-based auto-invalidation ENABLED")
|
||||
else:
|
||||
logger.info("[REPORTS] Event-based auto-invalidation DISABLED")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[REPORTS] ⚠️ Cache initialization error: {e}", exc_info=True)
|
||||
logger.warning("[REPORTS] Continuing without cache")
|
||||
|
||||
|
||||
async def init_data_entry_db():
|
||||
"""Initialize Data Entry SQLite database."""
|
||||
logger.info("[DATA-ENTRY] Initializing SQLite database...")
|
||||
try:
|
||||
from backend.modules.data_entry.db.database import init_db
|
||||
await init_db()
|
||||
logger.info(f"[DATA-ENTRY] ✅ Database initialized: {settings.data_entry_sqlite_database_path}")
|
||||
|
||||
# Ensure upload directory exists
|
||||
settings.data_entry_upload_path_resolved
|
||||
logger.info(f"[DATA-ENTRY] Upload path: {settings.data_entry_upload_path_resolved}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[DATA-ENTRY] ❌ Database initialization error: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
async def init_telegram_db():
|
||||
"""Initialize Telegram SQLite database."""
|
||||
logger.info("[TELEGRAM] Initializing SQLite database...")
|
||||
try:
|
||||
from backend.modules.telegram.db import init_database, cleanup_expired_codes, cleanup_expired_sessions, cleanup_expired_email_codes
|
||||
|
||||
await init_database()
|
||||
logger.info(f"[TELEGRAM] ✅ Database initialized: {settings.telegram_sqlite_database_path}")
|
||||
|
||||
# Cleanup expired data
|
||||
expired_codes = await cleanup_expired_codes()
|
||||
expired_sessions = await cleanup_expired_sessions()
|
||||
expired_email_codes = await cleanup_expired_email_codes()
|
||||
logger.info(f"[TELEGRAM] Cleanup: {expired_codes} codes, {expired_sessions} sessions, {expired_email_codes} email codes removed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[TELEGRAM] ❌ Database initialization error: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def init_paddle_ocr_background():
|
||||
"""Initialize PaddleOCR in background thread (takes 15-20s)."""
|
||||
try:
|
||||
logger.info("[DATA-ENTRY] Pre-loading OCR engine (background)...")
|
||||
from backend.modules.data_entry.services.ocr_service import ocr_service
|
||||
ocr_service.ocr_engine._init_paddle_lazy()
|
||||
logger.info("[DATA-ENTRY] ✅ OCR engine ready")
|
||||
except Exception as e:
|
||||
logger.warning(f"[DATA-ENTRY] ⚠️ OCR engine pre-load failed: {e}")
|
||||
|
||||
|
||||
async def run_telegram_bot():
|
||||
"""Run Telegram bot as background task."""
|
||||
logger.info("[TELEGRAM] Starting bot...")
|
||||
try:
|
||||
from telegram.ext import Application, CommandHandler, CallbackQueryHandler, MessageHandler, filters
|
||||
from backend.modules.telegram.bot.handlers import (
|
||||
start_command, help_command, clear_command, companies_command,
|
||||
unlink_command, selectcompany_command, dashboard_command, sold_command,
|
||||
facturi_command, trezorerie_command, menu_command, trezorerie_casa_command,
|
||||
trezorerie_banca_command, clienti_command, furnizori_command, evolutie_command,
|
||||
clearcache_command, togglecache_command, handle_text_message, button_callback,
|
||||
error_handler
|
||||
)
|
||||
from backend.modules.telegram.bot.email_handlers import email_login_handler
|
||||
|
||||
# Create Telegram application
|
||||
application = Application.builder().token(settings.telegram_bot_token).build()
|
||||
|
||||
# Register handlers
|
||||
application.add_handler(email_login_handler)
|
||||
application.add_handler(CommandHandler("start", start_command))
|
||||
application.add_handler(CommandHandler("menu", menu_command))
|
||||
application.add_handler(CommandHandler("help", help_command))
|
||||
application.add_handler(CommandHandler("unlink", unlink_command))
|
||||
application.add_handler(CommandHandler("clear", clear_command))
|
||||
application.add_handler(CommandHandler("companies", companies_command))
|
||||
application.add_handler(CommandHandler("selectcompany", selectcompany_command))
|
||||
application.add_handler(CommandHandler("dashboard", dashboard_command))
|
||||
application.add_handler(CommandHandler("sold", sold_command))
|
||||
application.add_handler(CommandHandler("facturi", facturi_command))
|
||||
application.add_handler(CommandHandler("trezorerie", trezorerie_command))
|
||||
application.add_handler(CommandHandler("trezorerie_casa", trezorerie_casa_command))
|
||||
application.add_handler(CommandHandler("trezorerie_banca", trezorerie_banca_command))
|
||||
application.add_handler(CommandHandler("clienti", clienti_command))
|
||||
application.add_handler(CommandHandler("furnizori", furnizori_command))
|
||||
application.add_handler(CommandHandler("evolutie", evolutie_command))
|
||||
application.add_handler(CommandHandler("clearcache", clearcache_command))
|
||||
application.add_handler(CommandHandler("togglecache", togglecache_command))
|
||||
application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, handle_text_message))
|
||||
application.add_handler(CallbackQueryHandler(button_callback))
|
||||
application.add_error_handler(error_handler)
|
||||
|
||||
# Initialize and start
|
||||
await application.initialize()
|
||||
await application.start()
|
||||
await application.updater.start_polling(drop_pending_updates=True)
|
||||
|
||||
bot_info = await application.bot.get_me()
|
||||
logger.info(f"[TELEGRAM] ✅ Bot running: @{bot_info.username}")
|
||||
|
||||
# Keep bot running
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("[TELEGRAM] Bot task cancelled, stopping...")
|
||||
if 'application' in locals():
|
||||
await application.updater.stop()
|
||||
await application.stop()
|
||||
await application.shutdown()
|
||||
logger.info("[TELEGRAM] ✅ Bot stopped")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[TELEGRAM] ❌ Bot error: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# FASTAPI APPLICATION
|
||||
# ============================================================================
|
||||
|
||||
app = FastAPI(
|
||||
title="ROA2WEB Unified Backend",
|
||||
description="Unified FastAPI backend for Reports, Data Entry, and Telegram modules",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# STARTUP/SHUTDOWN EVENT HANDLERS
|
||||
# ============================================================================
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Application startup - Initialize all resources."""
|
||||
global telegram_bot_task
|
||||
|
||||
logger.info("=" * 80)
|
||||
logger.info("[STARTUP] ROA2WEB Unified Backend")
|
||||
logger.info("=" * 80)
|
||||
|
||||
try:
|
||||
# Step 1: Initialize Oracle pool (shared by all modules)
|
||||
await init_oracle_pool()
|
||||
|
||||
# Step 2: Parallel initialization of module-specific resources
|
||||
logger.info("[STARTUP] Initializing module resources in parallel...")
|
||||
await asyncio.gather(
|
||||
init_reports_cache(),
|
||||
init_data_entry_db(),
|
||||
init_telegram_db(),
|
||||
)
|
||||
|
||||
# Step 3: Start PaddleOCR initialization in background thread
|
||||
import threading
|
||||
threading.Thread(target=init_paddle_ocr_background, daemon=True).start()
|
||||
|
||||
# Step 4: Start Telegram bot as background task
|
||||
if settings.telegram_bot_token:
|
||||
telegram_bot_task = asyncio.create_task(run_telegram_bot())
|
||||
logger.info("[STARTUP] ✅ Telegram bot task created")
|
||||
else:
|
||||
logger.warning("[STARTUP] ⚠️ TELEGRAM_BOT_TOKEN not set, bot disabled")
|
||||
|
||||
logger.info("=" * 80)
|
||||
logger.info("[STARTUP] ✅ All modules initialized successfully")
|
||||
logger.info(f"[STARTUP] ✅ Server running on http://{settings.api_host}:{settings.api_port}")
|
||||
logger.info("=" * 80)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[STARTUP] ❌ Initialization failed: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
"""Application shutdown - Cleanup resources."""
|
||||
global telegram_bot_task
|
||||
|
||||
logger.info("=" * 80)
|
||||
logger.info("[SHUTDOWN] Stopping ROA2WEB Unified Backend...")
|
||||
logger.info("=" * 80)
|
||||
|
||||
try:
|
||||
# Stop Telegram bot
|
||||
if telegram_bot_task and not telegram_bot_task.done():
|
||||
logger.info("[SHUTDOWN] Stopping Telegram bot...")
|
||||
telegram_bot_task.cancel()
|
||||
try:
|
||||
await telegram_bot_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Stop Reports cache event monitor
|
||||
try:
|
||||
from backend.modules.reports.cache import close_cache, get_event_monitor
|
||||
monitor = get_event_monitor()
|
||||
if monitor:
|
||||
await monitor.stop()
|
||||
logger.info("[SHUTDOWN] Reports cache monitor stopped")
|
||||
|
||||
await close_cache()
|
||||
logger.info("[SHUTDOWN] Reports cache closed")
|
||||
except Exception as e:
|
||||
logger.error(f"[SHUTDOWN] Cache error: {e}")
|
||||
|
||||
# Close Oracle pool
|
||||
await oracle_pool.close_pool()
|
||||
logger.info("[SHUTDOWN] Oracle pool closed")
|
||||
|
||||
logger.info("=" * 80)
|
||||
logger.info("[SHUTDOWN] ✅ Shutdown complete")
|
||||
logger.info("=" * 80)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SHUTDOWN] Error during shutdown: {e}", exc_info=True)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# MIDDLEWARE
|
||||
# ============================================================================
|
||||
|
||||
# CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # Allow all origins for production deployment
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Authentication middleware
|
||||
app.add_middleware(
|
||||
AuthenticationMiddleware,
|
||||
excluded_paths=[
|
||||
"/", "/docs", "/health", "/redoc", "/openapi.json",
|
||||
"/api/auth/login", "/api/auth/refresh",
|
||||
"/api/telegram/auth/verify-user",
|
||||
"/api/telegram/auth/verify-email",
|
||||
"/api/telegram/auth/login-with-email",
|
||||
"/api/telegram/auth/refresh-token",
|
||||
"/api/telegram/health"
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ROUTER REGISTRATION
|
||||
# ============================================================================
|
||||
|
||||
# Module routers with prefixes
|
||||
app.include_router(create_reports_router(), prefix="/api/reports", tags=["reports"])
|
||||
app.include_router(create_data_entry_router(), prefix="/api/data-entry", tags=["data-entry"])
|
||||
app.include_router(create_telegram_router(), prefix="/api/telegram", tags=["telegram"])
|
||||
|
||||
# Shared routers
|
||||
auth_router = create_auth_router(prefix="", tags=["authentication"])
|
||||
app.include_router(auth_router, prefix="/api/auth")
|
||||
|
||||
companies_router = create_companies_router(oracle_pool, tags=["companies"])
|
||||
app.include_router(companies_router, prefix="/api/companies")
|
||||
|
||||
calendar_router = create_calendar_router(oracle_pool, tags=["calendar"])
|
||||
app.include_router(calendar_router, prefix="/api/calendar")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ROOT & HEALTH ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint - API information."""
|
||||
return {
|
||||
"name": settings.app_name,
|
||||
"version": settings.app_version,
|
||||
"status": "running",
|
||||
"modules": ["reports", "data-entry", "telegram"],
|
||||
"docs": "/docs",
|
||||
"health": "/health"
|
||||
}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint with module status."""
|
||||
health_status = {
|
||||
"api": "healthy",
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"modules": {}
|
||||
}
|
||||
|
||||
# Check Oracle connection
|
||||
try:
|
||||
async with oracle_pool.get_connection() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute("SELECT 1 FROM DUAL")
|
||||
health_status["modules"]["oracle"] = "connected"
|
||||
except Exception as e:
|
||||
health_status["modules"]["oracle"] = f"error: {str(e)}"
|
||||
|
||||
# Check Reports cache
|
||||
try:
|
||||
from backend.modules.reports.cache import get_cache
|
||||
cache = get_cache()
|
||||
health_status["modules"]["reports_cache"] = "initialized" if cache else "disabled"
|
||||
except Exception as e:
|
||||
health_status["modules"]["reports_cache"] = f"error: {str(e)}"
|
||||
|
||||
# Check Data Entry DB
|
||||
try:
|
||||
db_path = Path(settings.data_entry_sqlite_database_path)
|
||||
health_status["modules"]["data_entry_db"] = "exists" if db_path.exists() else "missing"
|
||||
except Exception as e:
|
||||
health_status["modules"]["data_entry_db"] = f"error: {str(e)}"
|
||||
|
||||
# Check Telegram bot
|
||||
global telegram_bot_task
|
||||
if telegram_bot_task:
|
||||
if telegram_bot_task.done():
|
||||
health_status["modules"]["telegram_bot"] = "stopped"
|
||||
else:
|
||||
health_status["modules"]["telegram_bot"] = "running"
|
||||
else:
|
||||
health_status["modules"]["telegram_bot"] = "disabled"
|
||||
|
||||
return health_status
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# MAIN ENTRY POINT
|
||||
# ============================================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(
|
||||
"backend.main:app",
|
||||
host=settings.api_host,
|
||||
port=settings.api_port,
|
||||
reload=False,
|
||||
log_level="info"
|
||||
)
|
||||
0
backend/modules/__init__.py
Normal file
0
backend/modules/__init__.py
Normal file
0
backend/modules/data_entry/__init__.py
Normal file
0
backend/modules/data_entry/__init__.py
Normal file
96
backend/modules/data_entry/config.py
Normal file
96
backend/modules/data_entry/config.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""Application configuration using pydantic-settings."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from pydantic_settings import BaseSettings
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings loaded from environment variables."""
|
||||
|
||||
# App info
|
||||
app_name: str = "Data Entry API"
|
||||
app_version: str = "1.0.0"
|
||||
debug: bool = False
|
||||
|
||||
# API
|
||||
api_host: str = "0.0.0.0"
|
||||
api_port: int = 8003
|
||||
|
||||
# SQLite Database
|
||||
sqlite_database_path: str = "data/receipts/receipts.db"
|
||||
|
||||
# File uploads
|
||||
upload_path: str = "data/uploads"
|
||||
max_upload_size_mb: int = 10
|
||||
allowed_mime_types: List[str] = [
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/gif",
|
||||
"image/webp",
|
||||
"application/pdf",
|
||||
]
|
||||
|
||||
# Oracle Database (for nomenclatures)
|
||||
oracle_user: str = ""
|
||||
oracle_password: str = ""
|
||||
oracle_host: str = "localhost"
|
||||
oracle_port: int = 1526
|
||||
oracle_sid: str = "ROA"
|
||||
|
||||
# JWT Authentication
|
||||
jwt_secret_key: str = "change-me-in-production"
|
||||
jwt_algorithm: str = "HS256"
|
||||
jwt_expire_minutes: int = 480
|
||||
|
||||
# CORS
|
||||
cors_origins: str = "http://localhost:3010,http://localhost:3000"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
extra = "ignore"
|
||||
|
||||
@property
|
||||
def database_url(self) -> str:
|
||||
"""Get SQLite database URL for async."""
|
||||
return f"sqlite+aiosqlite:///{self.sqlite_database_path}"
|
||||
|
||||
@property
|
||||
def sync_database_url(self) -> str:
|
||||
"""Get SQLite database URL for sync operations (Alembic)."""
|
||||
return f"sqlite:///{self.sqlite_database_path}"
|
||||
|
||||
@property
|
||||
def upload_path_resolved(self) -> Path:
|
||||
"""Get resolved upload path."""
|
||||
path = Path(self.upload_path)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def max_upload_size_bytes(self) -> int:
|
||||
"""Get max upload size in bytes."""
|
||||
return self.max_upload_size_mb * 1024 * 1024
|
||||
|
||||
@property
|
||||
def cors_origins_list(self) -> List[str]:
|
||||
"""Get CORS origins as list."""
|
||||
return [origin.strip() for origin in self.cors_origins.split(",")]
|
||||
|
||||
@property
|
||||
def oracle_dsn(self) -> str:
|
||||
"""Get Oracle DSN string."""
|
||||
return f"{self.oracle_host}:{self.oracle_port}/{self.oracle_sid}"
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_settings() -> Settings:
|
||||
"""Get cached settings instance."""
|
||||
return Settings()
|
||||
|
||||
|
||||
# Convenience instance
|
||||
settings = get_settings()
|
||||
4
backend/modules/data_entry/db/__init__.py
Normal file
4
backend/modules/data_entry/db/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Database module
|
||||
from .database import get_session, init_db, engine
|
||||
|
||||
__all__ = ["get_session", "init_db", "engine"]
|
||||
10
backend/modules/data_entry/db/crud/__init__.py
Normal file
10
backend/modules/data_entry/db/crud/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# CRUD operations
|
||||
from .receipt import ReceiptCRUD
|
||||
from .attachment import AttachmentCRUD
|
||||
from .accounting_entry import AccountingEntryCRUD
|
||||
|
||||
__all__ = [
|
||||
"ReceiptCRUD",
|
||||
"AttachmentCRUD",
|
||||
"AccountingEntryCRUD",
|
||||
]
|
||||
197
backend/modules/data_entry/db/crud/accounting_entry.py
Normal file
197
backend/modules/data_entry/db/crud/accounting_entry.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""CRUD operations for accounting entries."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
|
||||
from sqlalchemy import select, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from backend.modules.data_entry.db.models.accounting_entry import AccountingEntry, EntryType
|
||||
from backend.modules.data_entry.schemas.receipt import AccountingEntryCreate, AccountingEntryUpdate
|
||||
|
||||
|
||||
class AccountingEntryCRUD:
|
||||
"""CRUD operations for AccountingEntry model."""
|
||||
|
||||
@staticmethod
|
||||
async def create(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
data: AccountingEntryCreate,
|
||||
sort_order: int = 0,
|
||||
is_auto_generated: bool = True,
|
||||
) -> AccountingEntry:
|
||||
"""Create a new accounting entry."""
|
||||
entry = AccountingEntry(
|
||||
receipt_id=receipt_id,
|
||||
entry_type=data.entry_type,
|
||||
account_code=data.account_code,
|
||||
account_name=data.account_name,
|
||||
amount=data.amount,
|
||||
partner_id=data.partner_id,
|
||||
cost_center_id=data.cost_center_id,
|
||||
is_auto_generated=is_auto_generated,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
session.add(entry)
|
||||
await session.commit()
|
||||
await session.refresh(entry)
|
||||
return entry
|
||||
|
||||
@staticmethod
|
||||
async def create_bulk(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
entries: List[AccountingEntryCreate],
|
||||
is_auto_generated: bool = True,
|
||||
) -> List[AccountingEntry]:
|
||||
"""Create multiple accounting entries at once."""
|
||||
created_entries = []
|
||||
|
||||
for idx, entry_data in enumerate(entries):
|
||||
entry = AccountingEntry(
|
||||
receipt_id=receipt_id,
|
||||
entry_type=entry_data.entry_type,
|
||||
account_code=entry_data.account_code,
|
||||
account_name=entry_data.account_name,
|
||||
amount=entry_data.amount,
|
||||
partner_id=entry_data.partner_id,
|
||||
cost_center_id=entry_data.cost_center_id,
|
||||
is_auto_generated=is_auto_generated,
|
||||
sort_order=idx,
|
||||
)
|
||||
session.add(entry)
|
||||
created_entries.append(entry)
|
||||
|
||||
await session.commit()
|
||||
|
||||
for entry in created_entries:
|
||||
await session.refresh(entry)
|
||||
|
||||
return created_entries
|
||||
|
||||
@staticmethod
|
||||
async def get_by_id(
|
||||
session: AsyncSession,
|
||||
entry_id: int,
|
||||
) -> Optional[AccountingEntry]:
|
||||
"""Get accounting entry by ID."""
|
||||
query = select(AccountingEntry).where(AccountingEntry.id == entry_id)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
async def get_by_receipt_id(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
) -> List[AccountingEntry]:
|
||||
"""Get all accounting entries for a receipt."""
|
||||
query = select(AccountingEntry).where(
|
||||
AccountingEntry.receipt_id == receipt_id
|
||||
).order_by(AccountingEntry.sort_order.asc())
|
||||
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@staticmethod
|
||||
async def update(
|
||||
session: AsyncSession,
|
||||
entry: AccountingEntry,
|
||||
data: AccountingEntryUpdate,
|
||||
modified_by: str,
|
||||
) -> AccountingEntry:
|
||||
"""Update an accounting entry."""
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(entry, field, value)
|
||||
|
||||
entry.is_auto_generated = False
|
||||
entry.modified_by = modified_by
|
||||
entry.modified_at = datetime.utcnow()
|
||||
|
||||
session.add(entry)
|
||||
await session.commit()
|
||||
await session.refresh(entry)
|
||||
return entry
|
||||
|
||||
@staticmethod
|
||||
async def delete(session: AsyncSession, entry: AccountingEntry) -> bool:
|
||||
"""Delete an accounting entry."""
|
||||
await session.delete(entry)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def delete_all_for_receipt(session: AsyncSession, receipt_id: int) -> int:
|
||||
"""Delete all accounting entries for a receipt."""
|
||||
query = delete(AccountingEntry).where(AccountingEntry.receipt_id == receipt_id)
|
||||
result = await session.execute(query)
|
||||
await session.commit()
|
||||
return result.rowcount
|
||||
|
||||
@staticmethod
|
||||
async def replace_all_for_receipt(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
entries: List[AccountingEntryCreate],
|
||||
modified_by: str,
|
||||
) -> List[AccountingEntry]:
|
||||
"""Replace all entries for a receipt with new ones."""
|
||||
# Delete existing entries
|
||||
await AccountingEntryCRUD.delete_all_for_receipt(session, receipt_id)
|
||||
|
||||
# Create new entries (marked as manually modified)
|
||||
created_entries = []
|
||||
|
||||
for idx, entry_data in enumerate(entries):
|
||||
entry = AccountingEntry(
|
||||
receipt_id=receipt_id,
|
||||
entry_type=entry_data.entry_type,
|
||||
account_code=entry_data.account_code,
|
||||
account_name=entry_data.account_name,
|
||||
amount=entry_data.amount,
|
||||
partner_id=entry_data.partner_id,
|
||||
cost_center_id=entry_data.cost_center_id,
|
||||
is_auto_generated=False,
|
||||
modified_by=modified_by,
|
||||
modified_at=datetime.utcnow(),
|
||||
sort_order=idx,
|
||||
)
|
||||
session.add(entry)
|
||||
created_entries.append(entry)
|
||||
|
||||
await session.commit()
|
||||
|
||||
for entry in created_entries:
|
||||
await session.refresh(entry)
|
||||
|
||||
return created_entries
|
||||
|
||||
@staticmethod
|
||||
async def validate_entries(entries: List[AccountingEntryCreate]) -> tuple[bool, str]:
|
||||
"""
|
||||
Validate accounting entries.
|
||||
Returns (is_valid, error_message).
|
||||
"""
|
||||
if not entries:
|
||||
return False, "At least one entry is required"
|
||||
|
||||
total_debit = sum(
|
||||
e.amount for e in entries if e.entry_type == EntryType.DEBIT
|
||||
)
|
||||
total_credit = sum(
|
||||
e.amount for e in entries if e.entry_type == EntryType.CREDIT
|
||||
)
|
||||
|
||||
# Check balance (debit should equal credit)
|
||||
if abs(total_debit - total_credit) > 0.01:
|
||||
return False, f"Entries not balanced: Debit={total_debit}, Credit={total_credit}"
|
||||
|
||||
# Check for valid account codes
|
||||
for entry in entries:
|
||||
if not entry.account_code or len(entry.account_code) < 3:
|
||||
return False, f"Invalid account code: {entry.account_code}"
|
||||
|
||||
return True, ""
|
||||
140
backend/modules/data_entry/db/crud/attachment.py
Normal file
140
backend/modules/data_entry/db/crud/attachment.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""CRUD operations for receipt attachments."""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
import aiofiles
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from fastapi import UploadFile
|
||||
|
||||
from backend.modules.data_entry.db.models.receipt import ReceiptAttachment
|
||||
from backend.modules.data_entry.config import settings
|
||||
|
||||
|
||||
class AttachmentCRUD:
|
||||
"""CRUD operations for ReceiptAttachment model."""
|
||||
|
||||
@staticmethod
|
||||
def _generate_stored_filename(original_filename: str) -> str:
|
||||
"""Generate unique filename for storage."""
|
||||
ext = Path(original_filename).suffix.lower()
|
||||
return f"{uuid.uuid4()}{ext}"
|
||||
|
||||
@staticmethod
|
||||
def _get_upload_path(stored_filename: str) -> Path:
|
||||
"""Get full path for storing file, organized by year/month."""
|
||||
now = datetime.utcnow()
|
||||
relative_path = Path(str(now.year)) / f"{now.month:02d}"
|
||||
full_path = settings.upload_path_resolved / relative_path
|
||||
|
||||
# Ensure directory exists
|
||||
full_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return relative_path / stored_filename
|
||||
|
||||
@staticmethod
|
||||
async def create(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
file: UploadFile,
|
||||
) -> ReceiptAttachment:
|
||||
"""Create attachment by saving file and creating DB record."""
|
||||
# Generate stored filename
|
||||
stored_filename = AttachmentCRUD._generate_stored_filename(file.filename or "upload")
|
||||
|
||||
# Get relative path
|
||||
relative_path = AttachmentCRUD._get_upload_path(stored_filename)
|
||||
|
||||
# Full path for saving
|
||||
full_path = settings.upload_path_resolved / relative_path
|
||||
|
||||
# Read file content
|
||||
content = await file.read()
|
||||
file_size = len(content)
|
||||
|
||||
# Validate file size
|
||||
if file_size > settings.max_upload_size_bytes:
|
||||
raise ValueError(f"File too large. Maximum size is {settings.max_upload_size_mb}MB")
|
||||
|
||||
# Validate MIME type
|
||||
mime_type = file.content_type or "application/octet-stream"
|
||||
if mime_type not in settings.allowed_mime_types:
|
||||
raise ValueError(f"File type not allowed: {mime_type}")
|
||||
|
||||
# Save file
|
||||
async with aiofiles.open(full_path, "wb") as f:
|
||||
await f.write(content)
|
||||
|
||||
# Create DB record
|
||||
attachment = ReceiptAttachment(
|
||||
receipt_id=receipt_id,
|
||||
filename=file.filename or "upload",
|
||||
stored_filename=stored_filename,
|
||||
file_path=str(relative_path),
|
||||
file_size=file_size,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
|
||||
session.add(attachment)
|
||||
await session.commit()
|
||||
await session.refresh(attachment)
|
||||
|
||||
return attachment
|
||||
|
||||
@staticmethod
|
||||
async def get_by_id(
|
||||
session: AsyncSession,
|
||||
attachment_id: int,
|
||||
) -> Optional[ReceiptAttachment]:
|
||||
"""Get attachment by ID."""
|
||||
query = select(ReceiptAttachment).where(ReceiptAttachment.id == attachment_id)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
async def get_by_receipt_id(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
) -> List[ReceiptAttachment]:
|
||||
"""Get all attachments for a receipt."""
|
||||
query = select(ReceiptAttachment).where(
|
||||
ReceiptAttachment.receipt_id == receipt_id
|
||||
).order_by(ReceiptAttachment.uploaded_at.asc())
|
||||
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@staticmethod
|
||||
def get_file_path(attachment: ReceiptAttachment) -> Path:
|
||||
"""Get full file path for an attachment."""
|
||||
return settings.upload_path_resolved / attachment.file_path
|
||||
|
||||
@staticmethod
|
||||
async def delete(session: AsyncSession, attachment: ReceiptAttachment) -> bool:
|
||||
"""Delete attachment (file and DB record)."""
|
||||
# Delete file
|
||||
file_path = AttachmentCRUD.get_file_path(attachment)
|
||||
if file_path.exists():
|
||||
os.remove(file_path)
|
||||
|
||||
# Delete DB record
|
||||
await session.delete(attachment)
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def delete_all_for_receipt(session: AsyncSession, receipt_id: int) -> int:
|
||||
"""Delete all attachments for a receipt."""
|
||||
attachments = await AttachmentCRUD.get_by_receipt_id(session, receipt_id)
|
||||
count = 0
|
||||
|
||||
for attachment in attachments:
|
||||
await AttachmentCRUD.delete(session, attachment)
|
||||
count += 1
|
||||
|
||||
return count
|
||||
324
backend/modules/data_entry/db/crud/receipt.py
Normal file
324
backend/modules/data_entry/db/crud/receipt.py
Normal file
@@ -0,0 +1,324 @@
|
||||
"""CRUD operations for receipts."""
|
||||
|
||||
import json
|
||||
from datetime import datetime, date
|
||||
from decimal import Decimal
|
||||
from typing import Optional, List, Tuple, Any
|
||||
from sqlalchemy import select, func, or_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from backend.modules.data_entry.db.models.receipt import Receipt, ReceiptStatus
|
||||
from backend.modules.data_entry.schemas.receipt import ReceiptCreate, ReceiptUpdate, ReceiptFilter
|
||||
|
||||
|
||||
def _serialize_tva_breakdown(tva_breakdown: Optional[List[Any]]) -> Optional[str]:
|
||||
"""Serialize TVA breakdown list to JSON string for SQLite storage."""
|
||||
if tva_breakdown is None:
|
||||
return None
|
||||
|
||||
# Convert Decimal to float for JSON serialization
|
||||
serializable = []
|
||||
for entry in tva_breakdown:
|
||||
if hasattr(entry, 'model_dump'):
|
||||
# Pydantic model
|
||||
item = entry.model_dump()
|
||||
elif isinstance(entry, dict):
|
||||
item = entry.copy()
|
||||
else:
|
||||
item = dict(entry)
|
||||
|
||||
# Convert Decimal to float
|
||||
if 'amount' in item and isinstance(item['amount'], Decimal):
|
||||
item['amount'] = float(item['amount'])
|
||||
|
||||
serializable.append(item)
|
||||
|
||||
return json.dumps(serializable)
|
||||
|
||||
|
||||
def _serialize_payment_methods(payment_methods: Optional[List[Any]]) -> Optional[str]:
|
||||
"""Serialize payment methods list to JSON string for SQLite storage."""
|
||||
if payment_methods is None:
|
||||
return None
|
||||
|
||||
serializable = []
|
||||
for pm in payment_methods:
|
||||
if hasattr(pm, 'model_dump'):
|
||||
item = pm.model_dump()
|
||||
elif isinstance(pm, dict):
|
||||
item = pm.copy()
|
||||
else:
|
||||
item = dict(pm)
|
||||
|
||||
# Convert Decimal to float for JSON
|
||||
if 'amount' in item:
|
||||
if hasattr(item['amount'], '__float__'):
|
||||
item['amount'] = float(item['amount'])
|
||||
|
||||
serializable.append(item)
|
||||
|
||||
return json.dumps(serializable)
|
||||
|
||||
|
||||
class ReceiptCRUD:
|
||||
"""CRUD operations for Receipt model."""
|
||||
|
||||
@staticmethod
|
||||
async def create(
|
||||
session: AsyncSession,
|
||||
data: ReceiptCreate,
|
||||
created_by: str,
|
||||
) -> Receipt:
|
||||
"""Create a new receipt."""
|
||||
# Get data as dict and serialize tva_breakdown and payment_methods to JSON string
|
||||
receipt_data = data.model_dump()
|
||||
receipt_data['tva_breakdown'] = _serialize_tva_breakdown(receipt_data.get('tva_breakdown'))
|
||||
receipt_data['payment_methods'] = _serialize_payment_methods(receipt_data.get('payment_methods'))
|
||||
|
||||
receipt = Receipt(
|
||||
**receipt_data,
|
||||
created_by=created_by,
|
||||
status=ReceiptStatus.DRAFT,
|
||||
)
|
||||
session.add(receipt)
|
||||
await session.commit()
|
||||
await session.refresh(receipt)
|
||||
|
||||
# Reload with relationships to avoid lazy loading issues with async
|
||||
return await ReceiptCRUD.get_by_id(session, receipt.id, include_relations=True)
|
||||
|
||||
@staticmethod
|
||||
async def get_by_id(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
include_relations: bool = True,
|
||||
) -> Optional[Receipt]:
|
||||
"""Get receipt by ID, optionally with relationships."""
|
||||
query = select(Receipt).where(Receipt.id == receipt_id)
|
||||
|
||||
if include_relations:
|
||||
query = query.options(
|
||||
selectinload(Receipt.attachments),
|
||||
selectinload(Receipt.entries),
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
async def get_list(
|
||||
session: AsyncSession,
|
||||
filters: ReceiptFilter,
|
||||
) -> Tuple[List[Receipt], int]:
|
||||
"""Get paginated list of receipts with filters."""
|
||||
# Base query
|
||||
query = select(Receipt).options(
|
||||
selectinload(Receipt.attachments),
|
||||
selectinload(Receipt.entries),
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if filters.status:
|
||||
query = query.where(Receipt.status == filters.status)
|
||||
|
||||
if filters.direction:
|
||||
query = query.where(Receipt.direction == filters.direction)
|
||||
|
||||
if filters.company_id:
|
||||
query = query.where(Receipt.company_id == filters.company_id)
|
||||
|
||||
if filters.created_by:
|
||||
query = query.where(Receipt.created_by == filters.created_by)
|
||||
|
||||
if filters.date_from:
|
||||
query = query.where(Receipt.receipt_date >= filters.date_from)
|
||||
|
||||
if filters.date_to:
|
||||
query = query.where(Receipt.receipt_date <= filters.date_to)
|
||||
|
||||
if filters.search:
|
||||
search_term = f"%{filters.search}%"
|
||||
query = query.where(
|
||||
or_(
|
||||
Receipt.description.ilike(search_term),
|
||||
Receipt.partner_name.ilike(search_term),
|
||||
Receipt.receipt_number.ilike(search_term),
|
||||
)
|
||||
)
|
||||
|
||||
# Count total
|
||||
count_query = select(func.count()).select_from(query.subquery())
|
||||
total_result = await session.execute(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
# Apply pagination and ordering
|
||||
query = query.order_by(Receipt.created_at.desc())
|
||||
offset = (filters.page - 1) * filters.page_size
|
||||
query = query.offset(offset).limit(filters.page_size)
|
||||
|
||||
# Execute
|
||||
result = await session.execute(query)
|
||||
receipts = result.scalars().all()
|
||||
|
||||
return list(receipts), total
|
||||
|
||||
@staticmethod
|
||||
async def get_pending_review(
|
||||
session: AsyncSession,
|
||||
company_id: Optional[int] = None,
|
||||
) -> List[Receipt]:
|
||||
"""Get all receipts pending review."""
|
||||
query = select(Receipt).where(
|
||||
Receipt.status == ReceiptStatus.PENDING_REVIEW
|
||||
).options(
|
||||
selectinload(Receipt.attachments),
|
||||
selectinload(Receipt.entries),
|
||||
)
|
||||
|
||||
if company_id:
|
||||
query = query.where(Receipt.company_id == company_id)
|
||||
|
||||
query = query.order_by(Receipt.submitted_at.asc())
|
||||
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@staticmethod
|
||||
async def update(
|
||||
session: AsyncSession,
|
||||
receipt: Receipt,
|
||||
data: ReceiptUpdate,
|
||||
) -> Receipt:
|
||||
"""Update receipt fields."""
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
|
||||
# Serialize tva_breakdown and payment_methods to JSON string if present
|
||||
if 'tva_breakdown' in update_data:
|
||||
update_data['tva_breakdown'] = _serialize_tva_breakdown(update_data['tva_breakdown'])
|
||||
if 'payment_methods' in update_data:
|
||||
update_data['payment_methods'] = _serialize_payment_methods(update_data['payment_methods'])
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(receipt, field, value)
|
||||
|
||||
receipt.updated_at = datetime.utcnow()
|
||||
|
||||
session.add(receipt)
|
||||
await session.commit()
|
||||
await session.refresh(receipt)
|
||||
|
||||
# Reload with relationships to avoid lazy loading issues with async
|
||||
return await ReceiptCRUD.get_by_id(session, receipt.id, include_relations=True)
|
||||
|
||||
@staticmethod
|
||||
async def update_status(
|
||||
session: AsyncSession,
|
||||
receipt: Receipt,
|
||||
new_status: ReceiptStatus,
|
||||
reviewed_by: Optional[str] = None,
|
||||
rejection_reason: Optional[str] = None,
|
||||
) -> Receipt:
|
||||
"""Update receipt workflow status."""
|
||||
receipt.status = new_status
|
||||
receipt.updated_at = datetime.utcnow()
|
||||
|
||||
if new_status == ReceiptStatus.PENDING_REVIEW:
|
||||
receipt.submitted_at = datetime.utcnow()
|
||||
|
||||
if new_status in [ReceiptStatus.APPROVED, ReceiptStatus.REJECTED]:
|
||||
receipt.reviewed_by = reviewed_by
|
||||
receipt.reviewed_at = datetime.utcnow()
|
||||
|
||||
if new_status == ReceiptStatus.REJECTED:
|
||||
receipt.rejection_reason = rejection_reason
|
||||
|
||||
if new_status == ReceiptStatus.DRAFT:
|
||||
# Reset review fields when moving back to draft
|
||||
receipt.rejection_reason = None
|
||||
|
||||
session.add(receipt)
|
||||
await session.commit()
|
||||
await session.refresh(receipt)
|
||||
|
||||
# Reload with relationships to avoid lazy loading issues with async
|
||||
return await ReceiptCRUD.get_by_id(session, receipt.id, include_relations=True)
|
||||
|
||||
@staticmethod
|
||||
async def delete(session: AsyncSession, receipt: Receipt) -> bool:
|
||||
"""Delete a receipt (cascade deletes attachments and entries)."""
|
||||
await session.delete(receipt)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def can_edit(receipt: Receipt, username: str) -> bool:
|
||||
"""Check if user can edit receipt."""
|
||||
# DRAFT and REJECTED receipts can be edited (to fix and resubmit)
|
||||
if receipt.status not in [ReceiptStatus.DRAFT, ReceiptStatus.REJECTED]:
|
||||
return False
|
||||
|
||||
# Only creator can edit their own receipts
|
||||
return receipt.created_by == username
|
||||
|
||||
@staticmethod
|
||||
async def can_delete(receipt: Receipt, username: str) -> bool:
|
||||
"""Check if user can delete receipt."""
|
||||
# Only DRAFT receipts can be deleted
|
||||
if receipt.status != ReceiptStatus.DRAFT:
|
||||
return False
|
||||
|
||||
# Only creator can delete their own drafts
|
||||
return receipt.created_by == username
|
||||
|
||||
@staticmethod
|
||||
async def can_submit(receipt: Receipt, username: str) -> bool:
|
||||
"""Check if user can submit receipt for review."""
|
||||
# Only DRAFT or REJECTED receipts can be submitted
|
||||
if receipt.status not in [ReceiptStatus.DRAFT, ReceiptStatus.REJECTED]:
|
||||
return False
|
||||
|
||||
# Only creator can submit their own receipts
|
||||
return receipt.created_by == username
|
||||
|
||||
@staticmethod
|
||||
async def get_stats(
|
||||
session: AsyncSession,
|
||||
company_id: int,
|
||||
created_by: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Get receipt statistics."""
|
||||
base_query = select(
|
||||
Receipt.status,
|
||||
func.count(Receipt.id).label("count"),
|
||||
func.sum(Receipt.amount).label("total_amount"),
|
||||
).where(
|
||||
Receipt.company_id == company_id
|
||||
)
|
||||
|
||||
if created_by:
|
||||
base_query = base_query.where(Receipt.created_by == created_by)
|
||||
|
||||
query = base_query.group_by(Receipt.status)
|
||||
result = await session.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
stats = {
|
||||
"draft": {"count": 0, "amount": 0},
|
||||
"pending_review": {"count": 0, "amount": 0},
|
||||
"approved": {"count": 0, "amount": 0},
|
||||
"rejected": {"count": 0, "amount": 0},
|
||||
"synced": {"count": 0, "amount": 0},
|
||||
"total": {"count": 0, "amount": 0},
|
||||
}
|
||||
|
||||
for row in rows:
|
||||
status_key = row.status.value
|
||||
stats[status_key] = {
|
||||
"count": row.count,
|
||||
"amount": float(row.total_amount or 0),
|
||||
}
|
||||
stats["total"]["count"] += row.count
|
||||
stats["total"]["amount"] += float(row.total_amount or 0)
|
||||
|
||||
return stats
|
||||
49
backend/modules/data_entry/db/database.py
Normal file
49
backend/modules/data_entry/db/database.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""Database configuration and session management using SQLModel."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from backend.modules.data_entry.config import settings
|
||||
|
||||
|
||||
# Create async engine
|
||||
engine = create_async_engine(
|
||||
settings.database_url,
|
||||
echo=settings.debug,
|
||||
future=True,
|
||||
)
|
||||
|
||||
# Create async session factory
|
||||
async_session_maker = sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
|
||||
async def init_db() -> None:
|
||||
"""Initialize database - create tables if they don't exist."""
|
||||
# Ensure data directory exists
|
||||
db_path = Path(settings.sqlite_database_path)
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Get async database session for dependency injection."""
|
||||
async with async_session_maker() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
# Convenience function for manual session usage
|
||||
async def get_db_session() -> AsyncSession:
|
||||
"""Get a new database session (manual management)."""
|
||||
return async_session_maker()
|
||||
17
backend/modules/data_entry/db/models/__init__.py
Normal file
17
backend/modules/data_entry/db/models/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# Database models
|
||||
from .receipt import Receipt, ReceiptAttachment, ReceiptStatus, ReceiptType, ReceiptDirection
|
||||
from .accounting_entry import AccountingEntry, EntryType
|
||||
from .nomenclature import SyncedSupplier, LocalSupplier, SyncedCashRegister
|
||||
|
||||
__all__ = [
|
||||
"Receipt",
|
||||
"ReceiptAttachment",
|
||||
"ReceiptStatus",
|
||||
"ReceiptType",
|
||||
"ReceiptDirection",
|
||||
"AccountingEntry",
|
||||
"EntryType",
|
||||
"SyncedSupplier",
|
||||
"LocalSupplier",
|
||||
"SyncedCashRegister",
|
||||
]
|
||||
49
backend/modules/data_entry/db/models/accounting_entry.py
Normal file
49
backend/modules/data_entry/db/models/accounting_entry.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""AccountingEntry SQLModel model for proposed accounting entries."""
|
||||
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
from sqlmodel import SQLModel, Field, Relationship
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .receipt import Receipt
|
||||
|
||||
|
||||
class EntryType(str, Enum):
|
||||
"""Type of accounting entry."""
|
||||
DEBIT = "debit"
|
||||
CREDIT = "credit"
|
||||
|
||||
|
||||
class AccountingEntry(SQLModel, table=True):
|
||||
"""Proposed accounting entry for a receipt."""
|
||||
|
||||
__tablename__ = "accounting_entries"
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
receipt_id: int = Field(foreign_key="receipts.id", index=True)
|
||||
|
||||
# Account
|
||||
entry_type: EntryType
|
||||
account_code: str = Field(max_length=20) # e.g., 6022, 5311, 4426
|
||||
account_name: Optional[str] = Field(default=None, max_length=200) # Cache: "Cheltuieli combustibil"
|
||||
|
||||
# Amount
|
||||
amount: Decimal = Field(decimal_places=2, max_digits=15)
|
||||
|
||||
# Analytics (optional)
|
||||
partner_id: Optional[int] = Field(default=None)
|
||||
cost_center_id: Optional[int] = Field(default=None)
|
||||
|
||||
# Entry metadata
|
||||
is_auto_generated: bool = Field(default=True) # True if system-generated
|
||||
modified_by: Optional[str] = Field(default=None, max_length=100) # Username if modified
|
||||
modified_at: Optional[datetime] = Field(default=None)
|
||||
|
||||
# Order for display
|
||||
sort_order: int = Field(default=0)
|
||||
|
||||
# Relationship
|
||||
receipt: Optional["Receipt"] = Relationship(back_populates="entries")
|
||||
46
backend/modules/data_entry/db/models/nomenclature.py
Normal file
46
backend/modules/data_entry/db/models/nomenclature.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Nomenclature models for synced and local data."""
|
||||
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from sqlmodel import SQLModel, Field
|
||||
|
||||
|
||||
class SyncedSupplier(SQLModel, table=True):
|
||||
"""Suppliers synced from Oracle NOM_PARTENERI."""
|
||||
__tablename__ = "synced_suppliers"
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
oracle_id: int = Field(index=True) # Original Oracle ID
|
||||
company_id: int = Field(index=True) # Company this supplier belongs to
|
||||
name: str = Field(max_length=200)
|
||||
fiscal_code: Optional[str] = Field(default=None, max_length=50, index=True) # CUI/CIF
|
||||
address: Optional[str] = Field(default=None, max_length=500)
|
||||
synced_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class LocalSupplier(SQLModel, table=True):
|
||||
"""Suppliers created locally from OCR (not in Oracle)."""
|
||||
__tablename__ = "local_suppliers"
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
company_id: int = Field(index=True)
|
||||
name: str = Field(max_length=200)
|
||||
fiscal_code: Optional[str] = Field(default=None, max_length=50, index=True)
|
||||
address: Optional[str] = Field(default=None, max_length=500)
|
||||
created_by: str = Field(max_length=100) # Username who created it
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
# Flag to indicate if it should be synced to Oracle later
|
||||
pending_oracle_sync: bool = Field(default=True)
|
||||
|
||||
|
||||
class SyncedCashRegister(SQLModel, table=True):
|
||||
"""Cash registers and bank accounts synced from Oracle."""
|
||||
__tablename__ = "synced_cash_registers"
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
oracle_id: int = Field(index=True)
|
||||
company_id: int = Field(index=True)
|
||||
name: str = Field(max_length=100)
|
||||
account_code: str = Field(max_length=20) # 5311, 5121, etc.
|
||||
register_type: str = Field(max_length=10) # 'cash' or 'bank'
|
||||
synced_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
127
backend/modules/data_entry/db/models/receipt.py
Normal file
127
backend/modules/data_entry/db/models/receipt.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""Receipt and ReceiptAttachment SQLModel models."""
|
||||
|
||||
from datetime import datetime, date
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Optional, List, TYPE_CHECKING
|
||||
|
||||
from sqlmodel import SQLModel, Field, Relationship
|
||||
|
||||
|
||||
class ReceiptType(str, Enum):
|
||||
"""Type of receipt document."""
|
||||
BON_FISCAL = "bon_fiscal"
|
||||
CHITANTA = "chitanta"
|
||||
|
||||
|
||||
class ReceiptDirection(str, Enum):
|
||||
"""Direction of receipt - expense or income."""
|
||||
CHELTUIALA = "cheltuiala" # Expense (receipt from supplier)
|
||||
INCASARE = "incasare" # Income (receipt issued to client)
|
||||
|
||||
|
||||
class ReceiptStatus(str, Enum):
|
||||
"""Workflow status of receipt."""
|
||||
DRAFT = "draft" # User is filling in data
|
||||
PENDING_REVIEW = "pending_review" # Awaiting accountant approval
|
||||
APPROVED = "approved" # Approved by accountant
|
||||
REJECTED = "rejected" # Rejected by accountant
|
||||
SYNCED = "synced" # Synced to Oracle (Phase 2)
|
||||
|
||||
|
||||
class PaymentMode(str, Enum):
|
||||
"""Payment mode - how the expense was paid."""
|
||||
CASA = "casa" # Numerar firma (5311)
|
||||
BANCA = "banca" # Virament/POS (5121)
|
||||
AVANS_DECONTARE = "avans_decontare" # Decont angajat (542)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .accounting_entry import AccountingEntry
|
||||
|
||||
|
||||
class Receipt(SQLModel, table=True):
|
||||
"""Receipt (Bon Fiscal / Chitanta) with approval workflow."""
|
||||
|
||||
__tablename__ = "receipts"
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
|
||||
# Document identification
|
||||
receipt_type: ReceiptType = Field(default=ReceiptType.BON_FISCAL)
|
||||
direction: ReceiptDirection = Field(default=ReceiptDirection.CHELTUIALA)
|
||||
receipt_number: Optional[str] = Field(default=None, max_length=50)
|
||||
receipt_series: Optional[str] = Field(default=None, max_length=20)
|
||||
|
||||
# Main data
|
||||
receipt_date: date
|
||||
amount: Decimal = Field(decimal_places=2, max_digits=15)
|
||||
description: Optional[str] = Field(default=None, max_length=500)
|
||||
|
||||
# TVA info (extracted from OCR) - stored as JSON for multiple entries
|
||||
tva_breakdown: Optional[str] = Field(default=None, max_length=1000) # JSON: [{"code":"A","percent":19,"amount":"15.20"}]
|
||||
tva_total: Optional[Decimal] = Field(default=None, decimal_places=2, max_digits=15)
|
||||
items_count: Optional[int] = Field(default=None)
|
||||
vendor_address: Optional[str] = Field(default=None, max_length=500)
|
||||
|
||||
# Expense type (for auto-generating accounting entries)
|
||||
expense_type_code: Optional[str] = Field(default=None, max_length=20)
|
||||
|
||||
# Oracle references (nomenclatures)
|
||||
company_id: int
|
||||
# partner_id removed - supplier data is text-only (partner_name, cui)
|
||||
partner_name: Optional[str] = Field(default=None, max_length=200) # Supplier name from OCR/selection
|
||||
cui: Optional[str] = Field(default=None, max_length=20) # Fiscal code from OCR
|
||||
ocr_raw_text: Optional[str] = Field(default=None) # Raw OCR text for debugging
|
||||
payment_methods: Optional[str] = Field(default=None, max_length=500) # JSON: [{"method":"CARD","amount":"50.00"}]
|
||||
cash_register_id: Optional[int] = Field(default=None) # Cash/Bank ID from Oracle
|
||||
cash_register_name: Optional[str] = Field(default=None, max_length=100) # Cache for display
|
||||
cash_register_account: Optional[str] = Field(default=None, max_length=20) # Account code (5311, 5121)
|
||||
payment_mode: Optional[str] = Field(default=None, max_length=20) # PaymentMode value: casa/banca/avans_decontare
|
||||
|
||||
# Workflow
|
||||
status: ReceiptStatus = Field(default=ReceiptStatus.DRAFT)
|
||||
created_by: str = Field(max_length=100) # Username of creator
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
submitted_at: Optional[datetime] = Field(default=None) # When submitted for approval
|
||||
|
||||
# Approval
|
||||
reviewed_by: Optional[str] = Field(default=None, max_length=100) # Accountant username
|
||||
reviewed_at: Optional[datetime] = Field(default=None)
|
||||
rejection_reason: Optional[str] = Field(default=None, max_length=500) # Reason for rejection
|
||||
|
||||
# Phase 2 - Oracle sync
|
||||
oracle_synced_at: Optional[datetime] = Field(default=None)
|
||||
oracle_act_id: Optional[int] = Field(default=None)
|
||||
oracle_error: Optional[str] = Field(default=None, max_length=500)
|
||||
|
||||
# Relationships
|
||||
attachments: List["ReceiptAttachment"] = Relationship(
|
||||
back_populates="receipt",
|
||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||
)
|
||||
entries: List["AccountingEntry"] = Relationship(
|
||||
back_populates="receipt",
|
||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||
)
|
||||
|
||||
|
||||
class ReceiptAttachment(SQLModel, table=True):
|
||||
"""Attachment (photo or PDF) for a receipt."""
|
||||
|
||||
__tablename__ = "receipt_attachments"
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
receipt_id: int = Field(foreign_key="receipts.id", index=True)
|
||||
|
||||
# File info
|
||||
filename: str = Field(max_length=255) # Original filename
|
||||
stored_filename: str = Field(max_length=255) # Filename on disk (UUID)
|
||||
file_path: str = Field(max_length=500) # Relative path
|
||||
file_size: int # Size in bytes
|
||||
mime_type: str = Field(max_length=100) # MIME type (image/jpeg, application/pdf)
|
||||
uploaded_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
# Relationship
|
||||
receipt: Optional[Receipt] = Relationship(back_populates="attachments")
|
||||
89
backend/modules/data_entry/migrations/env.py
Normal file
89
backend/modules/data_entry/migrations/env.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Alembic environment configuration."""
|
||||
|
||||
import os
|
||||
from logging.config import fileConfig
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
|
||||
from alembic import context
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
# Import all models to ensure they're registered with SQLModel
|
||||
from backend.modules.data_entry.db.models.receipt import Receipt, ReceiptAttachment
|
||||
from backend.modules.data_entry.db.models.accounting_entry import AccountingEntry
|
||||
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier, SyncedCashRegister
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
# Override sqlalchemy.url from environment variable if set
|
||||
db_path = os.getenv("SQLITE_DATABASE_PATH", "data/receipts/receipts.db")
|
||||
config.set_main_option("sqlalchemy.url", f"sqlite:///{db_path}")
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
target_metadata = SQLModel.metadata
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
"""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
render_as_batch=True, # Required for SQLite ALTER TABLE support
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode.
|
||||
|
||||
In this scenario we need to create an Engine
|
||||
and associate a connection with the context.
|
||||
"""
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
render_as_batch=True, # Required for SQLite ALTER TABLE support
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
27
backend/modules/data_entry/migrations/script.py.mako
Normal file
27
backend/modules/data_entry/migrations/script.py.mako
Normal file
@@ -0,0 +1,27 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
@@ -0,0 +1,112 @@
|
||||
"""Initial receipts schema
|
||||
|
||||
Revision ID: 001_initial
|
||||
Revises:
|
||||
Create Date: 2024-12-11
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '001_initial'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create receipts table
|
||||
op.create_table(
|
||||
'receipts',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('receipt_type', sa.Enum('BON_FISCAL', 'CHITANTA', name='receipttype'), nullable=False),
|
||||
sa.Column('direction', sa.Enum('CHELTUIALA', 'INCASARE', name='receiptdirection'), nullable=False),
|
||||
sa.Column('receipt_number', sa.String(length=50), nullable=True),
|
||||
sa.Column('receipt_series', sa.String(length=20), nullable=True),
|
||||
sa.Column('receipt_date', sa.Date(), nullable=False),
|
||||
sa.Column('amount', sa.Numeric(precision=15, scale=2), nullable=False),
|
||||
sa.Column('description', sa.String(length=500), nullable=True),
|
||||
sa.Column('expense_type_code', sa.String(length=20), nullable=True),
|
||||
sa.Column('company_id', sa.Integer(), nullable=False),
|
||||
sa.Column('partner_id', sa.Integer(), nullable=True),
|
||||
sa.Column('partner_name', sa.String(length=200), nullable=True),
|
||||
sa.Column('cash_register_id', sa.Integer(), nullable=True),
|
||||
sa.Column('cash_register_name', sa.String(length=100), nullable=True),
|
||||
sa.Column('cash_register_account', sa.String(length=20), nullable=True),
|
||||
sa.Column('status', sa.Enum('DRAFT', 'PENDING_REVIEW', 'APPROVED', 'REJECTED', 'SYNCED', name='receiptstatus'), nullable=False),
|
||||
sa.Column('created_by', sa.String(length=100), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('submitted_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('reviewed_by', sa.String(length=100), nullable=True),
|
||||
sa.Column('reviewed_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('rejection_reason', sa.String(length=500), nullable=True),
|
||||
sa.Column('oracle_synced_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('oracle_act_id', sa.Integer(), nullable=True),
|
||||
sa.Column('oracle_error', sa.String(length=500), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_receipts_company_id'), 'receipts', ['company_id'], unique=False)
|
||||
op.create_index(op.f('ix_receipts_status'), 'receipts', ['status'], unique=False)
|
||||
op.create_index(op.f('ix_receipts_created_by'), 'receipts', ['created_by'], unique=False)
|
||||
op.create_index(op.f('ix_receipts_receipt_date'), 'receipts', ['receipt_date'], unique=False)
|
||||
|
||||
# Create receipt_attachments table
|
||||
op.create_table(
|
||||
'receipt_attachments',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('receipt_id', sa.Integer(), nullable=False),
|
||||
sa.Column('filename', sa.String(length=255), nullable=False),
|
||||
sa.Column('stored_filename', sa.String(length=255), nullable=False),
|
||||
sa.Column('file_path', sa.String(length=500), nullable=False),
|
||||
sa.Column('file_size', sa.Integer(), nullable=False),
|
||||
sa.Column('mime_type', sa.String(length=100), nullable=False),
|
||||
sa.Column('uploaded_at', sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['receipt_id'], ['receipts.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_receipt_attachments_receipt_id'), 'receipt_attachments', ['receipt_id'], unique=False)
|
||||
|
||||
# Create accounting_entries table
|
||||
op.create_table(
|
||||
'accounting_entries',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('receipt_id', sa.Integer(), nullable=False),
|
||||
sa.Column('entry_type', sa.Enum('DEBIT', 'CREDIT', name='entrytype'), nullable=False),
|
||||
sa.Column('account_code', sa.String(length=20), nullable=False),
|
||||
sa.Column('account_name', sa.String(length=200), nullable=True),
|
||||
sa.Column('amount', sa.Numeric(precision=15, scale=2), nullable=False),
|
||||
sa.Column('partner_id', sa.Integer(), nullable=True),
|
||||
sa.Column('cost_center_id', sa.Integer(), nullable=True),
|
||||
sa.Column('is_auto_generated', sa.Boolean(), nullable=False),
|
||||
sa.Column('modified_by', sa.String(length=100), nullable=True),
|
||||
sa.Column('modified_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('sort_order', sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['receipt_id'], ['receipts.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_accounting_entries_receipt_id'), 'accounting_entries', ['receipt_id'], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(op.f('ix_accounting_entries_receipt_id'), table_name='accounting_entries')
|
||||
op.drop_table('accounting_entries')
|
||||
|
||||
op.drop_index(op.f('ix_receipt_attachments_receipt_id'), table_name='receipt_attachments')
|
||||
op.drop_table('receipt_attachments')
|
||||
|
||||
op.drop_index(op.f('ix_receipts_receipt_date'), table_name='receipts')
|
||||
op.drop_index(op.f('ix_receipts_created_by'), table_name='receipts')
|
||||
op.drop_index(op.f('ix_receipts_status'), table_name='receipts')
|
||||
op.drop_index(op.f('ix_receipts_company_id'), table_name='receipts')
|
||||
op.drop_table('receipts')
|
||||
|
||||
# Drop enums (SQLite doesn't actually use these, but for consistency)
|
||||
op.execute("DROP TYPE IF EXISTS receipttype")
|
||||
op.execute("DROP TYPE IF EXISTS receiptdirection")
|
||||
op.execute("DROP TYPE IF EXISTS receiptstatus")
|
||||
op.execute("DROP TYPE IF EXISTS entrytype")
|
||||
@@ -0,0 +1,37 @@
|
||||
"""add_tva_breakdown_to_receipt
|
||||
|
||||
Revision ID: 1cfb423c6953
|
||||
Revises: 001_initial
|
||||
Create Date: 2025-12-12 14:04:22.464289+00:00
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '1cfb423c6953'
|
||||
down_revision: Union[str, None] = '001_initial'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add TVA-related columns to receipts table
|
||||
with op.batch_alter_table('receipts', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('tva_breakdown', sqlmodel.sql.sqltypes.AutoString(length=1000), nullable=True))
|
||||
batch_op.add_column(sa.Column('tva_total', sa.Numeric(precision=15, scale=2), nullable=True))
|
||||
batch_op.add_column(sa.Column('items_count', sa.Integer(), nullable=True))
|
||||
batch_op.add_column(sa.Column('vendor_address', sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove TVA-related columns from receipts table
|
||||
with op.batch_alter_table('receipts', schema=None) as batch_op:
|
||||
batch_op.drop_column('vendor_address')
|
||||
batch_op.drop_column('items_count')
|
||||
batch_op.drop_column('tva_total')
|
||||
batch_op.drop_column('tva_breakdown')
|
||||
@@ -0,0 +1,89 @@
|
||||
"""add nomenclature tables
|
||||
|
||||
Revision ID: 3a653da79002
|
||||
Revises: 1cfb423c6953
|
||||
Create Date: 2025-12-13 00:28:05.719430+00:00
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '3a653da79002'
|
||||
down_revision: Union[str, None] = '1cfb423c6953'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('local_suppliers',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('company_id', sa.Integer(), nullable=False),
|
||||
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
|
||||
sa.Column('fiscal_code', sqlmodel.sql.sqltypes.AutoString(length=50), nullable=True),
|
||||
sa.Column('address', sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True),
|
||||
sa.Column('created_by', sqlmodel.sql.sqltypes.AutoString(length=100), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('pending_oracle_sync', sa.Boolean(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
with op.batch_alter_table('local_suppliers', schema=None) as batch_op:
|
||||
batch_op.create_index(batch_op.f('ix_local_suppliers_company_id'), ['company_id'], unique=False)
|
||||
batch_op.create_index(batch_op.f('ix_local_suppliers_fiscal_code'), ['fiscal_code'], unique=False)
|
||||
|
||||
op.create_table('synced_cash_registers',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('oracle_id', sa.Integer(), nullable=False),
|
||||
sa.Column('company_id', sa.Integer(), nullable=False),
|
||||
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=100), nullable=False),
|
||||
sa.Column('account_code', sqlmodel.sql.sqltypes.AutoString(length=20), nullable=False),
|
||||
sa.Column('register_type', sqlmodel.sql.sqltypes.AutoString(length=10), nullable=False),
|
||||
sa.Column('synced_at', sa.DateTime(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
with op.batch_alter_table('synced_cash_registers', schema=None) as batch_op:
|
||||
batch_op.create_index(batch_op.f('ix_synced_cash_registers_company_id'), ['company_id'], unique=False)
|
||||
batch_op.create_index(batch_op.f('ix_synced_cash_registers_oracle_id'), ['oracle_id'], unique=False)
|
||||
|
||||
op.create_table('synced_suppliers',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('oracle_id', sa.Integer(), nullable=False),
|
||||
sa.Column('company_id', sa.Integer(), nullable=False),
|
||||
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
|
||||
sa.Column('fiscal_code', sqlmodel.sql.sqltypes.AutoString(length=50), nullable=True),
|
||||
sa.Column('address', sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True),
|
||||
sa.Column('synced_at', sa.DateTime(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
with op.batch_alter_table('synced_suppliers', schema=None) as batch_op:
|
||||
batch_op.create_index(batch_op.f('ix_synced_suppliers_company_id'), ['company_id'], unique=False)
|
||||
batch_op.create_index(batch_op.f('ix_synced_suppliers_fiscal_code'), ['fiscal_code'], unique=False)
|
||||
batch_op.create_index(batch_op.f('ix_synced_suppliers_oracle_id'), ['oracle_id'], unique=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('synced_suppliers', schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f('ix_synced_suppliers_oracle_id'))
|
||||
batch_op.drop_index(batch_op.f('ix_synced_suppliers_fiscal_code'))
|
||||
batch_op.drop_index(batch_op.f('ix_synced_suppliers_company_id'))
|
||||
|
||||
op.drop_table('synced_suppliers')
|
||||
with op.batch_alter_table('synced_cash_registers', schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f('ix_synced_cash_registers_oracle_id'))
|
||||
batch_op.drop_index(batch_op.f('ix_synced_cash_registers_company_id'))
|
||||
|
||||
op.drop_table('synced_cash_registers')
|
||||
with op.batch_alter_table('local_suppliers', schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f('ix_local_suppliers_fiscal_code'))
|
||||
batch_op.drop_index(batch_op.f('ix_local_suppliers_company_id'))
|
||||
|
||||
op.drop_table('local_suppliers')
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,35 @@
|
||||
"""add_ocr_fields_to_receipt
|
||||
|
||||
Revision ID: 4b8e5f2a1d93
|
||||
Revises: 3a653da79002
|
||||
Create Date: 2025-12-15 10:00:00.000000+00:00
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '4b8e5f2a1d93'
|
||||
down_revision: Union[str, None] = '3a653da79002'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add OCR-related columns to receipts table
|
||||
with op.batch_alter_table('receipts', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('cui', sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True))
|
||||
batch_op.add_column(sa.Column('ocr_raw_text', sa.Text(), nullable=True))
|
||||
batch_op.add_column(sa.Column('payment_methods', sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove OCR-related columns from receipts table
|
||||
with op.batch_alter_table('receipts', schema=None) as batch_op:
|
||||
batch_op.drop_column('payment_methods')
|
||||
batch_op.drop_column('ocr_raw_text')
|
||||
batch_op.drop_column('cui')
|
||||
@@ -0,0 +1,29 @@
|
||||
"""Remove partner_id from receipts - supplier data is text-only
|
||||
|
||||
Revision ID: 20251215_remove_partner_id
|
||||
Revises: 20251216_payment_mode
|
||||
Create Date: 2025-12-15
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '20251215_remove_partner_id'
|
||||
down_revision: Union[str, None] = '20251216_payment_mode'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Remove partner_id column - supplier data is now text-only (partner_name, cui)."""
|
||||
# Drop the partner_id column
|
||||
op.drop_column('receipts', 'partner_id')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Re-add partner_id column."""
|
||||
op.add_column('receipts', sa.Column('partner_id', sa.Integer(), nullable=True))
|
||||
@@ -0,0 +1,44 @@
|
||||
"""Add payment_mode field to receipts table.
|
||||
|
||||
Revision ID: 20251216_payment_mode
|
||||
Revises: 4b8e5f2a1d93
|
||||
Create Date: 2024-12-16
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '20251216_payment_mode'
|
||||
down_revision = '4b8e5f2a1d93'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add payment_mode column and migrate existing data."""
|
||||
with op.batch_alter_table('receipts', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('payment_mode', sa.String(length=20), nullable=True))
|
||||
|
||||
# Migrate existing data based on cash_register_account
|
||||
op.execute("""
|
||||
UPDATE receipts
|
||||
SET payment_mode = 'casa'
|
||||
WHERE cash_register_account LIKE '531%' AND payment_mode IS NULL
|
||||
""")
|
||||
op.execute("""
|
||||
UPDATE receipts
|
||||
SET payment_mode = 'banca'
|
||||
WHERE cash_register_account LIKE '512%' AND payment_mode IS NULL
|
||||
""")
|
||||
op.execute("""
|
||||
UPDATE receipts
|
||||
SET payment_mode = 'avans_decontare'
|
||||
WHERE cash_register_account LIKE '542%' AND payment_mode IS NULL
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove payment_mode column."""
|
||||
with op.batch_alter_table('receipts', schema=None) as batch_op:
|
||||
batch_op.drop_column('payment_mode')
|
||||
30
backend/modules/data_entry/routers/__init__.py
Normal file
30
backend/modules/data_entry/routers/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Data Entry module router factory."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
|
||||
def create_data_entry_router() -> APIRouter:
|
||||
"""
|
||||
Create and configure Data Entry module router.
|
||||
|
||||
Includes all data entry endpoints:
|
||||
- /receipts - Receipt CRUD and workflow
|
||||
- /ocr - OCR processing for receipts
|
||||
- /nomenclature - Nomenclature syncing from Oracle
|
||||
|
||||
Returns:
|
||||
APIRouter: Configured router for data entry module
|
||||
"""
|
||||
router = APIRouter()
|
||||
|
||||
# Import routers here to avoid circular imports
|
||||
from .receipts import router as receipts_router
|
||||
from .ocr import router as ocr_router
|
||||
from .nomenclature import router as nomenclature_router
|
||||
|
||||
# Include all sub-routers (no prefix - already prefixed in main.py with /api/data-entry)
|
||||
router.include_router(receipts_router, prefix="/receipts", tags=["data-entry-receipts"])
|
||||
router.include_router(ocr_router, prefix="/ocr", tags=["data-entry-ocr"])
|
||||
router.include_router(nomenclature_router, prefix="/nomenclature", tags=["data-entry-nomenclature"])
|
||||
|
||||
return router
|
||||
254
backend/modules/data_entry/routers/nomenclature.py
Normal file
254
backend/modules/data_entry/routers/nomenclature.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""Nomenclature API endpoints."""
|
||||
|
||||
from typing import Optional, List, Annotated
|
||||
from fastapi import APIRouter, Depends, HTTPException, Header
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.modules.data_entry.db.database import get_session
|
||||
from backend.modules.data_entry.services.sync_service import SyncService
|
||||
|
||||
# Import auth dependencies
|
||||
import sys
|
||||
from pathlib import Path
|
||||
# Path setup handled by main.py - this is redundant
|
||||
# project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||
# sys.path.insert(0, str(project_root / "shared"))
|
||||
|
||||
from shared.auth.dependencies import get_current_user
|
||||
from shared.auth.models import CurrentUser
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============ Selected Company Dependency ============
|
||||
|
||||
async def get_selected_company(
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
x_selected_company: Annotated[Optional[str], Header()] = None
|
||||
) -> int:
|
||||
"""
|
||||
Get selected company from X-Selected-Company header.
|
||||
Validates user access. Falls back to first company if no header.
|
||||
"""
|
||||
if x_selected_company:
|
||||
try:
|
||||
company_id = int(x_selected_company)
|
||||
except ValueError:
|
||||
raise HTTPException(400, f"Invalid company ID: {x_selected_company}")
|
||||
|
||||
if str(company_id) in current_user.companies:
|
||||
return company_id
|
||||
raise HTTPException(403, f"Nu aveți acces la firma {company_id}")
|
||||
|
||||
if current_user.companies:
|
||||
try:
|
||||
return int(current_user.companies[0])
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
raise HTTPException(400, "Nu aveți nicio firmă asignată")
|
||||
|
||||
|
||||
SelectedCompany = Annotated[int, Depends(get_selected_company)]
|
||||
|
||||
|
||||
# Request/Response Models
|
||||
class SupplierSearchResult(BaseModel):
|
||||
found: bool
|
||||
supplier: Optional[dict] = None
|
||||
source: str # 'synced', 'local', 'not_found'
|
||||
|
||||
|
||||
class LocalSupplierCreate(BaseModel):
|
||||
name: str
|
||||
fiscal_code: Optional[str] = None
|
||||
address: Optional[str] = None
|
||||
|
||||
|
||||
class LocalSupplierResponse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
fiscal_code: Optional[str]
|
||||
address: Optional[str]
|
||||
is_local: bool = True
|
||||
|
||||
|
||||
class SyncResult(BaseModel):
|
||||
synced: int
|
||||
errors: int
|
||||
message: str
|
||||
|
||||
|
||||
class SupplierOption(BaseModel):
|
||||
id: int
|
||||
oracle_id: Optional[int] = None
|
||||
name: str
|
||||
fiscal_code: Optional[str]
|
||||
source: str # 'synced' or 'local'
|
||||
|
||||
|
||||
class CashRegisterOption(BaseModel):
|
||||
id: int
|
||||
oracle_id: int
|
||||
name: str
|
||||
account_code: str
|
||||
register_type: str
|
||||
|
||||
|
||||
# Endpoints
|
||||
@router.get("/suppliers/search", response_model=SupplierSearchResult)
|
||||
async def search_supplier(
|
||||
fiscal_code: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
company_id: Optional[int] = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
selected_company: SelectedCompany = None,
|
||||
):
|
||||
"""Search for supplier by fiscal code or name."""
|
||||
if not fiscal_code and not name:
|
||||
raise HTTPException(status_code=400, detail="Provide fiscal_code or name")
|
||||
|
||||
cid = company_id or selected_company
|
||||
|
||||
found, supplier, source = await SyncService.search_supplier(
|
||||
session, cid, fiscal_code, name
|
||||
)
|
||||
|
||||
return SupplierSearchResult(found=found, supplier=supplier, source=source)
|
||||
|
||||
|
||||
@router.get("/suppliers", response_model=List[SupplierOption])
|
||||
async def get_suppliers(
|
||||
search: Optional[str] = None,
|
||||
company_id: Optional[int] = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
selected_company: SelectedCompany = None,
|
||||
):
|
||||
"""Get all suppliers (synced + local) for dropdown/autocomplete."""
|
||||
cid = company_id or selected_company
|
||||
|
||||
suppliers = await SyncService.get_all_suppliers(session, cid, search)
|
||||
|
||||
return [
|
||||
SupplierOption(
|
||||
id=s["id"],
|
||||
oracle_id=s.get("oracle_id"),
|
||||
name=s["name"],
|
||||
fiscal_code=s.get("fiscal_code"),
|
||||
source=s["source"]
|
||||
)
|
||||
for s in suppliers
|
||||
]
|
||||
|
||||
|
||||
@router.post("/suppliers/local", response_model=LocalSupplierResponse)
|
||||
async def create_local_supplier(
|
||||
data: LocalSupplierCreate,
|
||||
company_id: Optional[int] = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
selected_company: SelectedCompany = None,
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Create a local supplier from OCR data."""
|
||||
cid = company_id or selected_company
|
||||
|
||||
supplier = await SyncService.create_local_supplier(
|
||||
session, cid, data.name, data.fiscal_code, data.address, current_user.username
|
||||
)
|
||||
|
||||
return LocalSupplierResponse(
|
||||
id=supplier.id,
|
||||
name=supplier.name,
|
||||
fiscal_code=supplier.fiscal_code,
|
||||
address=supplier.address,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/cash-registers", response_model=List[CashRegisterOption])
|
||||
async def get_cash_registers(
|
||||
company_id: Optional[int] = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
selected_company: SelectedCompany = None,
|
||||
):
|
||||
"""Get all cash registers for a company."""
|
||||
cid = company_id or selected_company
|
||||
|
||||
registers = await SyncService.get_all_cash_registers(session, cid)
|
||||
|
||||
return [
|
||||
CashRegisterOption(
|
||||
id=r["id"],
|
||||
oracle_id=r["oracle_id"],
|
||||
name=r["name"],
|
||||
account_code=r["account_code"],
|
||||
register_type=r["register_type"]
|
||||
)
|
||||
for r in registers
|
||||
]
|
||||
|
||||
|
||||
@router.post("/sync/suppliers", response_model=SyncResult)
|
||||
async def sync_suppliers(
|
||||
company_id: Optional[int] = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
selected_company: SelectedCompany = None,
|
||||
):
|
||||
"""Manually trigger supplier sync from Oracle."""
|
||||
cid = company_id or selected_company
|
||||
|
||||
synced, errors = await SyncService.sync_suppliers(session, cid)
|
||||
|
||||
return SyncResult(
|
||||
synced=synced,
|
||||
errors=errors,
|
||||
message=f"Synced {synced} suppliers with {errors} errors"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/sync/cash-registers", response_model=SyncResult)
|
||||
async def sync_cash_registers(
|
||||
company_id: Optional[int] = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
selected_company: SelectedCompany = None,
|
||||
):
|
||||
"""Manually trigger cash register sync from Oracle."""
|
||||
cid = company_id or selected_company
|
||||
|
||||
synced, errors = await SyncService.sync_cash_registers(session, cid)
|
||||
|
||||
return SyncResult(
|
||||
synced=synced,
|
||||
errors=errors,
|
||||
message=f"Synced {synced} cash registers with {errors} errors"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/sync/all", response_model=dict)
|
||||
async def sync_all_nomenclatures(
|
||||
company_id: Optional[int] = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
selected_company: SelectedCompany = None,
|
||||
):
|
||||
"""Sync all nomenclatures (suppliers + cash registers) from Oracle."""
|
||||
cid = company_id or selected_company
|
||||
|
||||
# Sync suppliers
|
||||
suppliers_synced, suppliers_errors = await SyncService.sync_suppliers(session, cid)
|
||||
|
||||
# Sync cash registers
|
||||
registers_synced, registers_errors = await SyncService.sync_cash_registers(session, cid)
|
||||
|
||||
return {
|
||||
"suppliers": {
|
||||
"synced": suppliers_synced,
|
||||
"errors": suppliers_errors
|
||||
},
|
||||
"cash_registers": {
|
||||
"synced": registers_synced,
|
||||
"errors": registers_errors
|
||||
},
|
||||
"total_synced": suppliers_synced + registers_synced,
|
||||
"total_errors": suppliers_errors + registers_errors,
|
||||
"message": f"Synced {suppliers_synced} suppliers and {registers_synced} cash registers"
|
||||
}
|
||||
218
backend/modules/data_entry/routers/ocr.py
Normal file
218
backend/modules/data_entry/routers/ocr.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""OCR API endpoints."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from backend.modules.data_entry.db.database import get_session
|
||||
from backend.modules.data_entry.db.crud.attachment import AttachmentCRUD
|
||||
from backend.modules.data_entry.services.ocr_service import ocr_service
|
||||
from backend.modules.data_entry.services.ocr_engine import OCREngine
|
||||
from backend.modules.data_entry.schemas.ocr import OCRResponse, OCRStatusResponse, ExtractionData, TvaEntry, PaymentMethod
|
||||
|
||||
# Auth integration (will be protected by middleware)
|
||||
from shared.auth.dependencies import get_current_user
|
||||
from shared.auth.models import CurrentUser
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/status", response_model=OCRStatusResponse)
|
||||
async def get_ocr_status():
|
||||
"""Check OCR service status and available engines."""
|
||||
engines = OCREngine.get_available_engines()
|
||||
available = len(engines) > 0
|
||||
|
||||
if available:
|
||||
message = f"OCR service ready with engines: {', '.join(engines)}"
|
||||
else:
|
||||
message = "No OCR engines available. Install PaddleOCR or Tesseract."
|
||||
|
||||
return OCRStatusResponse(
|
||||
available=available,
|
||||
engines=engines,
|
||||
message=message
|
||||
)
|
||||
|
||||
|
||||
@router.post("/extract", response_model=OCRResponse)
|
||||
async def extract_from_image(file: UploadFile = File(...)):
|
||||
"""
|
||||
Extract receipt data from uploaded image.
|
||||
|
||||
Accepts JPG, PNG, or PDF files (max 10MB).
|
||||
Returns extracted fields with confidence scores.
|
||||
"""
|
||||
allowed_types = ['image/jpeg', 'image/png', 'application/pdf']
|
||||
|
||||
if file.content_type not in allowed_types:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"File type not supported: {file.content_type}. Allowed: JPG, PNG, PDF"
|
||||
)
|
||||
|
||||
# Get file extension
|
||||
suffix = Path(file.filename).suffix.lower() if file.filename else '.jpg'
|
||||
if suffix not in ['.jpg', '.jpeg', '.png', '.pdf']:
|
||||
suffix = '.jpg'
|
||||
|
||||
# Save to temp file
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
|
||||
content = await file.read()
|
||||
|
||||
# Check file size (10MB limit)
|
||||
if len(content) > 10 * 1024 * 1024:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="File too large. Maximum size is 10MB."
|
||||
)
|
||||
|
||||
tmp.write(content)
|
||||
tmp_path = Path(tmp.name)
|
||||
|
||||
try:
|
||||
success, message, result = await ocr_service.process_image(
|
||||
tmp_path, file.content_type
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=422, detail=message)
|
||||
|
||||
# Convert ExtractionResult to ExtractionData schema
|
||||
# Convert tva_entries from dict to TvaEntry objects
|
||||
tva_entries_schema = [
|
||||
TvaEntry(code=e.get('code'), percent=e['percent'], amount=e['amount'])
|
||||
for e in result.tva_entries
|
||||
] if result.tva_entries else []
|
||||
|
||||
# Convert payment_methods from dict to PaymentMethod objects
|
||||
from decimal import Decimal
|
||||
payment_methods_list = [
|
||||
PaymentMethod(method=pm['method'], amount=Decimal(str(pm['amount'])))
|
||||
for pm in result.payment_methods
|
||||
] if result.payment_methods else []
|
||||
|
||||
# Auto-suggest payment_mode based on detected methods
|
||||
suggested_payment_mode = None
|
||||
if payment_methods_list:
|
||||
has_card = any(pm.method == 'CARD' for pm in payment_methods_list)
|
||||
if has_card:
|
||||
suggested_payment_mode = 'banca'
|
||||
# NUMERAR -> no auto-suggestion, user chooses between casa/avans
|
||||
|
||||
data = ExtractionData(
|
||||
receipt_type=result.receipt_type,
|
||||
receipt_number=result.receipt_number,
|
||||
receipt_series=result.receipt_series,
|
||||
receipt_date=result.receipt_date,
|
||||
amount=result.amount,
|
||||
partner_name=result.partner_name,
|
||||
cui=result.cui,
|
||||
description=result.description,
|
||||
tva_entries=tva_entries_schema,
|
||||
tva_total=result.tva_total,
|
||||
address=result.address,
|
||||
items_count=result.items_count,
|
||||
payment_methods=payment_methods_list,
|
||||
suggested_payment_mode=suggested_payment_mode,
|
||||
confidence_amount=result.confidence_amount,
|
||||
confidence_date=result.confidence_date,
|
||||
confidence_vendor=result.confidence_vendor,
|
||||
overall_confidence=result.overall_confidence,
|
||||
raw_text=result.raw_text,
|
||||
ocr_engine=result.ocr_engine,
|
||||
processing_time_ms=result.processing_time_ms,
|
||||
)
|
||||
|
||||
return OCRResponse(success=True, message=message, data=data)
|
||||
|
||||
finally:
|
||||
# Clean up temp file
|
||||
if tmp_path.exists():
|
||||
os.unlink(tmp_path)
|
||||
|
||||
|
||||
@router.post("/extract-attachment/{attachment_id}", response_model=OCRResponse)
|
||||
async def extract_from_attachment(
|
||||
attachment_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""
|
||||
Extract receipt data from an existing attachment.
|
||||
|
||||
Re-processes an already uploaded file with OCR.
|
||||
"""
|
||||
attachment = await AttachmentCRUD.get_by_id(session, attachment_id)
|
||||
|
||||
if not attachment:
|
||||
raise HTTPException(status_code=404, detail="Attachment not found")
|
||||
|
||||
file_path = AttachmentCRUD.get_file_path(attachment)
|
||||
|
||||
if not file_path.exists():
|
||||
raise HTTPException(status_code=404, detail="File not found on disk")
|
||||
|
||||
# Check if file type is supported
|
||||
if attachment.mime_type not in ['image/jpeg', 'image/png', 'application/pdf']:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"File type not supported for OCR: {attachment.mime_type}"
|
||||
)
|
||||
|
||||
success, message, result = await ocr_service.process_image(
|
||||
file_path, attachment.mime_type
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=422, detail=message)
|
||||
|
||||
# Convert ExtractionResult to ExtractionData schema
|
||||
# Convert tva_entries from dict to TvaEntry objects
|
||||
tva_entries_schema = [
|
||||
TvaEntry(code=e.get('code'), percent=e['percent'], amount=e['amount'])
|
||||
for e in result.tva_entries
|
||||
] if result.tva_entries else []
|
||||
|
||||
# Convert payment_methods from dict to PaymentMethod objects
|
||||
from decimal import Decimal
|
||||
payment_methods_list = [
|
||||
PaymentMethod(method=pm['method'], amount=Decimal(str(pm['amount'])))
|
||||
for pm in result.payment_methods
|
||||
] if result.payment_methods else []
|
||||
|
||||
# Auto-suggest payment_mode based on detected methods
|
||||
suggested_payment_mode = None
|
||||
if payment_methods_list:
|
||||
has_card = any(pm.method == 'CARD' for pm in payment_methods_list)
|
||||
if has_card:
|
||||
suggested_payment_mode = 'banca'
|
||||
# NUMERAR -> no auto-suggestion, user chooses between casa/avans
|
||||
|
||||
data = ExtractionData(
|
||||
receipt_type=result.receipt_type,
|
||||
receipt_number=result.receipt_number,
|
||||
receipt_series=result.receipt_series,
|
||||
receipt_date=result.receipt_date,
|
||||
amount=result.amount,
|
||||
partner_name=result.partner_name,
|
||||
cui=result.cui,
|
||||
description=result.description,
|
||||
tva_entries=tva_entries_schema,
|
||||
tva_total=result.tva_total,
|
||||
address=result.address,
|
||||
items_count=result.items_count,
|
||||
payment_methods=payment_methods_list,
|
||||
suggested_payment_mode=suggested_payment_mode,
|
||||
confidence_amount=result.confidence_amount,
|
||||
confidence_date=result.confidence_date,
|
||||
confidence_vendor=result.confidence_vendor,
|
||||
overall_confidence=result.overall_confidence,
|
||||
raw_text=result.raw_text,
|
||||
ocr_engine=result.ocr_engine,
|
||||
processing_time_ms=result.processing_time_ms,
|
||||
)
|
||||
|
||||
return OCRResponse(success=True, message=message, data=data)
|
||||
517
backend/modules/data_entry/routers/receipts.py
Normal file
517
backend/modules/data_entry/routers/receipts.py
Normal file
@@ -0,0 +1,517 @@
|
||||
"""API endpoints for receipts."""
|
||||
|
||||
from typing import List, Optional, Annotated
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Query, Header
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from backend.modules.data_entry.db.database import get_session
|
||||
from backend.modules.data_entry.db.crud.receipt import ReceiptCRUD
|
||||
from backend.modules.data_entry.db.crud.attachment import AttachmentCRUD
|
||||
from backend.modules.data_entry.db.crud.accounting_entry import AccountingEntryCRUD
|
||||
from backend.modules.data_entry.services.receipt_service import ReceiptService
|
||||
from backend.modules.data_entry.services.nomenclature_service import NomenclatureService
|
||||
from backend.modules.data_entry.schemas.receipt import (
|
||||
ReceiptCreate,
|
||||
ReceiptUpdate,
|
||||
ReceiptResponse,
|
||||
ReceiptListResponse,
|
||||
ReceiptFilter,
|
||||
AttachmentResponse,
|
||||
AccountingEntryResponse,
|
||||
WorkflowAction,
|
||||
RejectRequest,
|
||||
EntriesUpdateRequest,
|
||||
PartnerOption,
|
||||
AccountOption,
|
||||
CashRegisterOption,
|
||||
ExpenseTypeOption,
|
||||
)
|
||||
from backend.modules.data_entry.db.models.receipt import ReceiptStatus, ReceiptDirection
|
||||
|
||||
# Auth integration
|
||||
from shared.auth.dependencies import get_current_user
|
||||
from shared.auth.models import CurrentUser
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============ Helper for selected company from header ============
|
||||
|
||||
async def get_selected_company(
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
x_selected_company: Annotated[Optional[str], Header()] = None
|
||||
) -> int:
|
||||
"""
|
||||
Get selected company from X-Selected-Company header.
|
||||
|
||||
Validates that the user has access to the specified company.
|
||||
Falls back to user's first company if no header is provided.
|
||||
|
||||
Raises:
|
||||
HTTPException 403: If user doesn't have access to specified company
|
||||
HTTPException 400: If user has no companies assigned
|
||||
"""
|
||||
if x_selected_company:
|
||||
try:
|
||||
company_id = int(x_selected_company)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid company ID format: {x_selected_company}"
|
||||
)
|
||||
|
||||
# Validate user has access to this company
|
||||
# Auth stores companies as strings
|
||||
if str(company_id) in current_user.companies:
|
||||
return company_id
|
||||
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Nu aveți acces la firma {company_id}"
|
||||
)
|
||||
|
||||
# No header - use first company from user's list
|
||||
if current_user.companies:
|
||||
try:
|
||||
return int(current_user.companies[0])
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Nu aveți nicio firmă asignată"
|
||||
)
|
||||
|
||||
|
||||
# Dependency for injection
|
||||
SelectedCompany = Annotated[int, Depends(get_selected_company)]
|
||||
|
||||
|
||||
# Legacy function for backwards compatibility (deprecated)
|
||||
def get_current_user_company(current_user: CurrentUser) -> int:
|
||||
"""
|
||||
DEPRECATED: Use get_selected_company() dependency instead.
|
||||
This function returns the first company, ignoring X-Selected-Company header.
|
||||
"""
|
||||
if current_user.companies:
|
||||
try:
|
||||
return int(current_user.companies[0])
|
||||
except (ValueError, IndexError):
|
||||
return 1
|
||||
return 1
|
||||
|
||||
|
||||
# ============ Receipt CRUD Endpoints ============
|
||||
|
||||
@router.post("/", response_model=ReceiptResponse)
|
||||
async def create_receipt(
|
||||
data: ReceiptCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Create a new receipt in DRAFT status."""
|
||||
receipt = await ReceiptService.create_receipt(session, data, current_user.username)
|
||||
return ReceiptResponse.model_validate(receipt)
|
||||
|
||||
|
||||
@router.get("/", response_model=ReceiptListResponse)
|
||||
async def list_receipts(
|
||||
status: Optional[ReceiptStatus] = None,
|
||||
direction: Optional[ReceiptDirection] = None,
|
||||
company_id: Optional[int] = None,
|
||||
created_by: Optional[str] = None,
|
||||
date_from: Optional[str] = None,
|
||||
date_to: Optional[str] = None,
|
||||
search: Optional[str] = None,
|
||||
page: int = Query(default=1, ge=1),
|
||||
page_size: int = Query(default=20, ge=1, le=100),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
selected_company: SelectedCompany = None,
|
||||
):
|
||||
"""Get paginated list of receipts with filters."""
|
||||
from datetime import date as date_type
|
||||
|
||||
filters = ReceiptFilter(
|
||||
status=status,
|
||||
direction=direction,
|
||||
company_id=company_id or selected_company,
|
||||
created_by=created_by,
|
||||
date_from=date_type.fromisoformat(date_from) if date_from else None,
|
||||
date_to=date_type.fromisoformat(date_to) if date_to else None,
|
||||
search=search,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return await ReceiptService.get_receipts(session, filters)
|
||||
|
||||
|
||||
@router.get("/pending", response_model=List[ReceiptResponse])
|
||||
async def list_pending_receipts(
|
||||
company_id: Optional[int] = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
selected_company: SelectedCompany = None,
|
||||
):
|
||||
"""Get all receipts pending review (for accountant view)."""
|
||||
receipts = await ReceiptCRUD.get_pending_review(
|
||||
session, company_id or selected_company
|
||||
)
|
||||
return [ReceiptResponse.model_validate(r) for r in receipts]
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_receipt_stats(
|
||||
company_id: Optional[int] = None,
|
||||
my_receipts: bool = False,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
selected_company: SelectedCompany = None,
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Get receipt statistics."""
|
||||
return await ReceiptCRUD.get_stats(
|
||||
session,
|
||||
company_id or selected_company,
|
||||
created_by=current_user.username if my_receipts else None,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{receipt_id}", response_model=ReceiptResponse)
|
||||
async def get_receipt(
|
||||
receipt_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Get receipt details with attachments and accounting entries."""
|
||||
receipt = await ReceiptService.get_receipt(session, receipt_id)
|
||||
|
||||
if not receipt:
|
||||
raise HTTPException(status_code=404, detail="Receipt not found")
|
||||
|
||||
return ReceiptResponse.model_validate(receipt)
|
||||
|
||||
|
||||
@router.put("/{receipt_id}", response_model=ReceiptResponse)
|
||||
async def update_receipt(
|
||||
receipt_id: int,
|
||||
data: ReceiptUpdate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Update receipt (only DRAFT status, only by creator)."""
|
||||
success, message, receipt = await ReceiptService.update_receipt(
|
||||
session, receipt_id, data, current_user.username
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail=message)
|
||||
|
||||
return ReceiptResponse.model_validate(receipt)
|
||||
|
||||
|
||||
@router.delete("/{receipt_id}")
|
||||
async def delete_receipt(
|
||||
receipt_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Delete receipt (only DRAFT status, only by creator)."""
|
||||
success, message = await ReceiptService.delete_receipt(
|
||||
session, receipt_id, current_user.username
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail=message)
|
||||
|
||||
return {"success": True, "message": message}
|
||||
|
||||
|
||||
# ============ Workflow Endpoints ============
|
||||
|
||||
@router.post("/{receipt_id}/submit", response_model=WorkflowAction)
|
||||
async def submit_receipt(
|
||||
receipt_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Submit receipt for review (DRAFT → PENDING_REVIEW)."""
|
||||
success, message, receipt = await ReceiptService.submit_for_review(
|
||||
session, receipt_id, current_user.username
|
||||
)
|
||||
|
||||
return WorkflowAction(
|
||||
success=success,
|
||||
message=message,
|
||||
receipt=ReceiptResponse.model_validate(receipt) if receipt else None,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{receipt_id}/approve", response_model=WorkflowAction)
|
||||
async def approve_receipt(
|
||||
receipt_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Approve receipt (PENDING_REVIEW → APPROVED). Accountant action."""
|
||||
success, message, receipt = await ReceiptService.approve_receipt(
|
||||
session, receipt_id, current_user.username
|
||||
)
|
||||
|
||||
return WorkflowAction(
|
||||
success=success,
|
||||
message=message,
|
||||
receipt=ReceiptResponse.model_validate(receipt) if receipt else None,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{receipt_id}/reject", response_model=WorkflowAction)
|
||||
async def reject_receipt(
|
||||
receipt_id: int,
|
||||
data: RejectRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Reject receipt (PENDING_REVIEW → REJECTED). Accountant action."""
|
||||
success, message, receipt = await ReceiptService.reject_receipt(
|
||||
session, receipt_id, current_user.username, data.reason
|
||||
)
|
||||
|
||||
return WorkflowAction(
|
||||
success=success,
|
||||
message=message,
|
||||
receipt=ReceiptResponse.model_validate(receipt) if receipt else None,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{receipt_id}/resubmit", response_model=WorkflowAction)
|
||||
async def resubmit_receipt(
|
||||
receipt_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Resubmit rejected receipt after corrections (REJECTED → PENDING_REVIEW)."""
|
||||
success, message, receipt = await ReceiptService.resubmit_receipt(
|
||||
session, receipt_id, current_user.username
|
||||
)
|
||||
|
||||
return WorkflowAction(
|
||||
success=success,
|
||||
message=message,
|
||||
receipt=ReceiptResponse.model_validate(receipt) if receipt else None,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{receipt_id}/unapprove", response_model=WorkflowAction)
|
||||
async def unapprove_receipt(
|
||||
receipt_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Unapprove receipt (APPROVED → PENDING_REVIEW). Returns to pending for corrections."""
|
||||
success, message, receipt = await ReceiptService.unapprove_receipt(
|
||||
session, receipt_id, current_user.username
|
||||
)
|
||||
|
||||
return WorkflowAction(
|
||||
success=success,
|
||||
message=message,
|
||||
receipt=ReceiptResponse.model_validate(receipt) if receipt else None,
|
||||
)
|
||||
|
||||
|
||||
# ============ Accounting Entries Endpoints ============
|
||||
|
||||
@router.get("/{receipt_id}/entries", response_model=List[AccountingEntryResponse])
|
||||
async def get_receipt_entries(
|
||||
receipt_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Get accounting entries for a receipt."""
|
||||
entries = await AccountingEntryCRUD.get_by_receipt_id(session, receipt_id)
|
||||
return [AccountingEntryResponse.model_validate(e) for e in entries]
|
||||
|
||||
|
||||
@router.put("/{receipt_id}/entries", response_model=List[AccountingEntryResponse])
|
||||
async def update_receipt_entries(
|
||||
receipt_id: int,
|
||||
data: EntriesUpdateRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Update accounting entries for a receipt (accountant action)."""
|
||||
success, message, entries = await ReceiptService.update_entries(
|
||||
session, receipt_id, data.entries, current_user.username
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail=message)
|
||||
|
||||
return [AccountingEntryResponse.model_validate(e) for e in entries]
|
||||
|
||||
|
||||
@router.post("/{receipt_id}/entries/regenerate", response_model=List[AccountingEntryResponse])
|
||||
async def regenerate_entries(
|
||||
receipt_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Regenerate accounting entries based on receipt data."""
|
||||
success, message, _ = await ReceiptService.regenerate_entries(
|
||||
session, receipt_id, current_user.username
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail=message)
|
||||
|
||||
entries = await AccountingEntryCRUD.get_by_receipt_id(session, receipt_id)
|
||||
return [AccountingEntryResponse.model_validate(e) for e in entries]
|
||||
|
||||
|
||||
# ============ Attachment Endpoints ============
|
||||
|
||||
@router.post("/{receipt_id}/attachments", response_model=AttachmentResponse)
|
||||
async def upload_attachment(
|
||||
receipt_id: int,
|
||||
file: UploadFile = File(...),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Upload attachment for a receipt."""
|
||||
# Check receipt exists and user can modify it
|
||||
receipt = await ReceiptCRUD.get_by_id(session, receipt_id, include_relations=False)
|
||||
|
||||
if not receipt:
|
||||
raise HTTPException(status_code=404, detail="Receipt not found")
|
||||
|
||||
# Only allow uploads for DRAFT and REJECTED receipts
|
||||
if receipt.status not in [ReceiptStatus.DRAFT, ReceiptStatus.REJECTED]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot upload attachments for this receipt status"
|
||||
)
|
||||
|
||||
# Only creator can upload
|
||||
if receipt.created_by != current_user.username:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Only the creator can upload attachments"
|
||||
)
|
||||
|
||||
try:
|
||||
attachment = await AttachmentCRUD.create(session, receipt_id, file)
|
||||
return AttachmentResponse.model_validate(attachment)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{receipt_id}/attachments", response_model=List[AttachmentResponse])
|
||||
async def list_attachments(
|
||||
receipt_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Get all attachments for a receipt."""
|
||||
attachments = await AttachmentCRUD.get_by_receipt_id(session, receipt_id)
|
||||
return [AttachmentResponse.model_validate(a) for a in attachments]
|
||||
|
||||
|
||||
@router.get("/attachments/{attachment_id}/download")
|
||||
async def download_attachment(
|
||||
attachment_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Download an attachment file."""
|
||||
attachment = await AttachmentCRUD.get_by_id(session, attachment_id)
|
||||
|
||||
if not attachment:
|
||||
raise HTTPException(status_code=404, detail="Attachment not found")
|
||||
|
||||
file_path = AttachmentCRUD.get_file_path(attachment)
|
||||
|
||||
if not file_path.exists():
|
||||
raise HTTPException(status_code=404, detail="File not found on disk")
|
||||
|
||||
return FileResponse(
|
||||
path=str(file_path),
|
||||
filename=attachment.filename,
|
||||
media_type=attachment.mime_type,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/attachments/{attachment_id}")
|
||||
async def delete_attachment(
|
||||
attachment_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""Delete an attachment."""
|
||||
attachment = await AttachmentCRUD.get_by_id(session, attachment_id)
|
||||
|
||||
if not attachment:
|
||||
raise HTTPException(status_code=404, detail="Attachment not found")
|
||||
|
||||
# Get receipt to check permissions
|
||||
receipt = await ReceiptCRUD.get_by_id(session, attachment.receipt_id, include_relations=False)
|
||||
|
||||
if not receipt:
|
||||
raise HTTPException(status_code=404, detail="Receipt not found")
|
||||
|
||||
# Only allow deletion for DRAFT receipts by creator
|
||||
if receipt.status != ReceiptStatus.DRAFT:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot delete attachments for this receipt status"
|
||||
)
|
||||
|
||||
if receipt.created_by != current_user.username:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Only the creator can delete attachments"
|
||||
)
|
||||
|
||||
await AttachmentCRUD.delete(session, attachment)
|
||||
return {"success": True, "message": "Attachment deleted"}
|
||||
|
||||
|
||||
# ============ Nomenclature Endpoints ============
|
||||
|
||||
@router.get("/nomenclature/partners", response_model=List[PartnerOption])
|
||||
async def get_partners(
|
||||
search: Optional[str] = None,
|
||||
company_id: Optional[int] = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
selected_company: SelectedCompany = None,
|
||||
):
|
||||
"""Get partners (suppliers/customers) for dropdown."""
|
||||
return await NomenclatureService.get_partners(
|
||||
company_id or selected_company, search, session
|
||||
)
|
||||
|
||||
|
||||
@router.get("/nomenclature/accounts", response_model=List[AccountOption])
|
||||
async def get_accounts(
|
||||
prefix: Optional[str] = None,
|
||||
company_id: Optional[int] = None,
|
||||
selected_company: SelectedCompany = None,
|
||||
):
|
||||
"""Get chart of accounts for dropdown."""
|
||||
return await NomenclatureService.get_accounts(
|
||||
company_id or selected_company, prefix
|
||||
)
|
||||
|
||||
|
||||
@router.get("/nomenclature/cash-registers", response_model=List[CashRegisterOption])
|
||||
async def get_cash_registers(
|
||||
company_id: Optional[int] = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
selected_company: SelectedCompany = None,
|
||||
):
|
||||
"""Get cash registers and bank accounts for dropdown."""
|
||||
return await NomenclatureService.get_cash_registers(company_id or selected_company, session)
|
||||
|
||||
|
||||
@router.get("/nomenclature/expense-types", response_model=List[ExpenseTypeOption])
|
||||
async def get_expense_types():
|
||||
"""Get predefined expense types for dropdown."""
|
||||
return await NomenclatureService.get_expense_types()
|
||||
28
backend/modules/data_entry/schemas/__init__.py
Normal file
28
backend/modules/data_entry/schemas/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# Pydantic schemas
|
||||
from .receipt import (
|
||||
ReceiptCreate,
|
||||
ReceiptUpdate,
|
||||
ReceiptResponse,
|
||||
ReceiptListResponse,
|
||||
ReceiptFilter,
|
||||
AttachmentResponse,
|
||||
AccountingEntryCreate,
|
||||
AccountingEntryUpdate,
|
||||
AccountingEntryResponse,
|
||||
WorkflowAction,
|
||||
RejectRequest,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ReceiptCreate",
|
||||
"ReceiptUpdate",
|
||||
"ReceiptResponse",
|
||||
"ReceiptListResponse",
|
||||
"ReceiptFilter",
|
||||
"AttachmentResponse",
|
||||
"AccountingEntryCreate",
|
||||
"AccountingEntryUpdate",
|
||||
"AccountingEntryResponse",
|
||||
"WorkflowAction",
|
||||
"RejectRequest",
|
||||
]
|
||||
122
backend/modules/data_entry/schemas/ocr.py
Normal file
122
backend/modules/data_entry/schemas/ocr.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Pydantic schemas for OCR API."""
|
||||
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
from typing import Optional, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TvaEntry(BaseModel):
|
||||
"""Single TVA entry with code, percentage and amount."""
|
||||
code: Optional[str] = Field(default=None, description="TVA code: A, B, C, D")
|
||||
percent: int = Field(description="TVA percentage: 0, 5, 9, 19, 21")
|
||||
amount: Decimal = Field(description="TVA amount for this rate")
|
||||
|
||||
|
||||
class PaymentMethod(BaseModel):
|
||||
"""Payment method entry from OCR."""
|
||||
method: str = Field(description="CARD or NUMERAR")
|
||||
amount: Decimal = Field(description="Amount paid")
|
||||
|
||||
|
||||
class ExtractionData(BaseModel):
|
||||
"""Extracted receipt data from OCR."""
|
||||
|
||||
receipt_type: str = Field(default='bon_fiscal', description="Receipt type: bon_fiscal or chitanta")
|
||||
receipt_number: Optional[str] = Field(default=None, description="Receipt number")
|
||||
receipt_series: Optional[str] = Field(default=None, description="Receipt series")
|
||||
receipt_date: Optional[date] = Field(default=None, description="Receipt date")
|
||||
amount: Optional[Decimal] = Field(default=None, description="Total amount")
|
||||
partner_name: Optional[str] = Field(default=None, description="Vendor/partner name")
|
||||
cui: Optional[str] = Field(default=None, description="CUI (fiscal identification code)")
|
||||
description: Optional[str] = Field(default=None, description="Optional description")
|
||||
|
||||
# Additional extracted fields - Multiple TVA entries support
|
||||
tva_entries: List[TvaEntry] = Field(default=[], description="List of TVA entries by rate (A, B, C, D)")
|
||||
tva_total: Optional[Decimal] = Field(default=None, description="Total TVA amount")
|
||||
address: Optional[str] = Field(default=None, description="Vendor address")
|
||||
items_count: Optional[int] = Field(default=None, description="Number of items/articles")
|
||||
|
||||
# Payment methods extracted from receipt
|
||||
payment_methods: List[PaymentMethod] = Field(default=[], description="Payment methods from receipt (CARD, NUMERAR)")
|
||||
suggested_payment_mode: Optional[str] = Field(default=None, description="Auto-suggested payment mode based on OCR (casa/banca)")
|
||||
|
||||
# Client data (for B2B receipts - buyer information)
|
||||
client_name: Optional[str] = Field(default=None, description="Client/customer company name")
|
||||
client_cui: Optional[str] = Field(default=None, description="Client CUI/CIF fiscal code")
|
||||
client_address: Optional[str] = Field(default=None, description="Client address")
|
||||
|
||||
confidence_amount: float = Field(default=0.0, ge=0, le=1, description="Amount extraction confidence")
|
||||
confidence_date: float = Field(default=0.0, ge=0, le=1, description="Date extraction confidence")
|
||||
confidence_vendor: float = Field(default=0.0, ge=0, le=1, description="Vendor extraction confidence")
|
||||
confidence_client: float = Field(default=0.0, ge=0, le=1, description="Client extraction confidence")
|
||||
overall_confidence: float = Field(default=0.0, ge=0, le=1, description="Overall confidence score")
|
||||
raw_text: str = Field(default="", description="Raw OCR text")
|
||||
ocr_engine: str = Field(default="", description="OCR engine used: paddleocr or tesseract")
|
||||
processing_time_ms: int = Field(default=0, ge=0, description="Processing time in milliseconds")
|
||||
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"receipt_type": "bon_fiscal",
|
||||
"receipt_number": "1360760",
|
||||
"receipt_series": "0146",
|
||||
"receipt_date": "2025-10-11",
|
||||
"amount": 186.16,
|
||||
"partner_name": "FIVE-HOLDING S.A.",
|
||||
"cui": "10562600",
|
||||
"description": None,
|
||||
"tva_entries": [
|
||||
{"code": "A", "percent": 19, "amount": 25.00},
|
||||
{"code": "B", "percent": 9, "amount": 7.31}
|
||||
],
|
||||
"tva_total": 32.31,
|
||||
"address": "JUD. CONSTANTA, MUN. CONSTANTA, STR. ION ROATA NR. 3",
|
||||
"items_count": 17,
|
||||
"confidence_amount": 0.98,
|
||||
"confidence_date": 0.98,
|
||||
"confidence_vendor": 0.95,
|
||||
"overall_confidence": 0.97,
|
||||
"raw_text": "FIVE-HOLDING S.A.\nCIF: RO10562600\n..."
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class OCRResponse(BaseModel):
|
||||
"""OCR API response."""
|
||||
|
||||
success: bool = Field(description="Whether OCR processing was successful")
|
||||
message: str = Field(description="Status message")
|
||||
data: Optional[ExtractionData] = Field(default=None, description="Extracted data")
|
||||
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"success": True,
|
||||
"message": "OCR processing successful. Found: amount, date, vendor",
|
||||
"data": {
|
||||
"receipt_type": "bon_fiscal",
|
||||
"receipt_number": "12345",
|
||||
"receipt_date": "2024-01-15",
|
||||
"amount": 125.50,
|
||||
"partner_name": "MEGA IMAGE SRL",
|
||||
"cui": "12345678",
|
||||
"confidence_amount": 0.95,
|
||||
"confidence_date": 0.90,
|
||||
"confidence_vendor": 0.75,
|
||||
"overall_confidence": 0.87,
|
||||
"raw_text": "BON FISCAL\nMEGA IMAGE SRL\n..."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class OCRStatusResponse(BaseModel):
|
||||
"""OCR service status response."""
|
||||
|
||||
available: bool = Field(description="Whether OCR service is available")
|
||||
engines: list[str] = Field(description="Available OCR engines")
|
||||
message: str = Field(description="Status message")
|
||||
269
backend/modules/data_entry/schemas/receipt.py
Normal file
269
backend/modules/data_entry/schemas/receipt.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""Pydantic schemas for receipts API."""
|
||||
|
||||
import json
|
||||
from datetime import datetime, date
|
||||
from decimal import Decimal
|
||||
from typing import Optional, List, Any, Union
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_validator
|
||||
|
||||
from backend.modules.data_entry.db.models.receipt import ReceiptType, ReceiptDirection, ReceiptStatus
|
||||
from backend.modules.data_entry.db.models.accounting_entry import EntryType
|
||||
|
||||
|
||||
# ============ Accounting Entry Schemas ============
|
||||
|
||||
class AccountingEntryBase(BaseModel):
|
||||
"""Base schema for accounting entry."""
|
||||
entry_type: EntryType
|
||||
account_code: str = Field(max_length=20)
|
||||
account_name: Optional[str] = Field(default=None, max_length=200)
|
||||
amount: Decimal
|
||||
partner_id: Optional[int] = None
|
||||
cost_center_id: Optional[int] = None
|
||||
|
||||
|
||||
class AccountingEntryCreate(AccountingEntryBase):
|
||||
"""Schema for creating an accounting entry."""
|
||||
pass
|
||||
|
||||
|
||||
class AccountingEntryUpdate(BaseModel):
|
||||
"""Schema for updating an accounting entry."""
|
||||
entry_type: Optional[EntryType] = None
|
||||
account_code: Optional[str] = Field(default=None, max_length=20)
|
||||
account_name: Optional[str] = Field(default=None, max_length=200)
|
||||
amount: Optional[Decimal] = None
|
||||
partner_id: Optional[int] = None
|
||||
cost_center_id: Optional[int] = None
|
||||
|
||||
|
||||
class AccountingEntryResponse(AccountingEntryBase):
|
||||
"""Schema for accounting entry response."""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
receipt_id: int
|
||||
is_auto_generated: bool
|
||||
modified_by: Optional[str] = None
|
||||
modified_at: Optional[datetime] = None
|
||||
sort_order: int
|
||||
|
||||
|
||||
# ============ Attachment Schemas ============
|
||||
|
||||
class AttachmentResponse(BaseModel):
|
||||
"""Schema for attachment response."""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
receipt_id: int
|
||||
filename: str
|
||||
stored_filename: str
|
||||
file_path: str
|
||||
file_size: int
|
||||
mime_type: str
|
||||
uploaded_at: datetime
|
||||
|
||||
|
||||
# ============ TVA Schema ============
|
||||
|
||||
class TvaEntrySchema(BaseModel):
|
||||
"""Single TVA entry with code, percentage and amount."""
|
||||
code: Optional[str] = Field(default=None, description="TVA code: A, B, C, D")
|
||||
percent: int = Field(description="TVA percentage: 0, 5, 9, 19, 21")
|
||||
amount: Decimal = Field(description="TVA amount for this rate")
|
||||
|
||||
|
||||
class PaymentMethodSchema(BaseModel):
|
||||
"""Payment method entry (CARD/NUMERAR)."""
|
||||
method: str = Field(description="Payment method: CARD or NUMERAR")
|
||||
amount: Decimal = Field(description="Amount paid with this method")
|
||||
|
||||
|
||||
# ============ Receipt Schemas ============
|
||||
|
||||
class ReceiptBase(BaseModel):
|
||||
"""Base schema for receipt."""
|
||||
receipt_type: ReceiptType = ReceiptType.BON_FISCAL
|
||||
direction: ReceiptDirection = ReceiptDirection.CHELTUIALA
|
||||
receipt_number: Optional[str] = Field(default=None, max_length=50)
|
||||
receipt_series: Optional[str] = Field(default=None, max_length=20)
|
||||
receipt_date: date
|
||||
amount: Decimal = Field(gt=0)
|
||||
description: Optional[str] = Field(default=None, max_length=500)
|
||||
# TVA info (multiple entries support)
|
||||
tva_breakdown: Optional[List[TvaEntrySchema]] = Field(default=None, description="List of TVA entries")
|
||||
tva_total: Optional[Decimal] = Field(default=None, description="Total TVA amount")
|
||||
items_count: Optional[int] = Field(default=None, description="Number of items")
|
||||
vendor_address: Optional[str] = Field(default=None, max_length=500, description="Vendor address")
|
||||
# Other fields
|
||||
expense_type_code: Optional[str] = Field(default=None, max_length=20)
|
||||
company_id: int
|
||||
# partner_id removed - supplier data is text-only (partner_name, cui)
|
||||
partner_name: Optional[str] = Field(default=None, max_length=200)
|
||||
cui: Optional[str] = Field(default=None, max_length=20, description="Fiscal code (CUI) from OCR")
|
||||
ocr_raw_text: Optional[str] = Field(default=None, description="Raw OCR text for debugging")
|
||||
payment_methods: Optional[List[PaymentMethodSchema]] = Field(default=None, description="Payment methods from OCR")
|
||||
cash_register_id: Optional[int] = None
|
||||
cash_register_name: Optional[str] = Field(default=None, max_length=100)
|
||||
cash_register_account: Optional[str] = Field(default=None, max_length=20)
|
||||
payment_mode: Optional[str] = Field(default=None, description="Payment mode: casa/banca/avans_decontare")
|
||||
|
||||
|
||||
class ReceiptCreate(ReceiptBase):
|
||||
"""Schema for creating a receipt."""
|
||||
pass
|
||||
|
||||
|
||||
class ReceiptUpdate(BaseModel):
|
||||
"""Schema for updating a receipt (DRAFT only)."""
|
||||
receipt_type: Optional[ReceiptType] = None
|
||||
direction: Optional[ReceiptDirection] = None
|
||||
receipt_number: Optional[str] = Field(default=None, max_length=50)
|
||||
receipt_series: Optional[str] = Field(default=None, max_length=20)
|
||||
receipt_date: Optional[date] = None
|
||||
amount: Optional[Decimal] = Field(default=None, gt=0)
|
||||
description: Optional[str] = Field(default=None, max_length=500)
|
||||
# TVA info (multiple entries support)
|
||||
tva_breakdown: Optional[List[TvaEntrySchema]] = Field(default=None, description="List of TVA entries")
|
||||
tva_total: Optional[Decimal] = Field(default=None, description="Total TVA amount")
|
||||
items_count: Optional[int] = Field(default=None, description="Number of items")
|
||||
vendor_address: Optional[str] = Field(default=None, max_length=500, description="Vendor address")
|
||||
# Other fields
|
||||
expense_type_code: Optional[str] = Field(default=None, max_length=20)
|
||||
# partner_id removed - supplier data is text-only (partner_name, cui)
|
||||
partner_name: Optional[str] = Field(default=None, max_length=200)
|
||||
cui: Optional[str] = Field(default=None, max_length=20, description="Fiscal code (CUI) from OCR")
|
||||
ocr_raw_text: Optional[str] = Field(default=None, description="Raw OCR text for debugging")
|
||||
payment_methods: Optional[List[PaymentMethodSchema]] = Field(default=None, description="Payment methods from OCR")
|
||||
cash_register_id: Optional[int] = None
|
||||
cash_register_name: Optional[str] = Field(default=None, max_length=100)
|
||||
cash_register_account: Optional[str] = Field(default=None, max_length=20)
|
||||
payment_mode: Optional[str] = Field(default=None, description="Payment mode: casa/banca/avans_decontare")
|
||||
|
||||
|
||||
class ReceiptResponse(ReceiptBase):
|
||||
"""Schema for receipt response with all fields."""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
# Override amount to allow zero values in response (validation is on input, not output)
|
||||
amount: Decimal
|
||||
status: ReceiptStatus
|
||||
created_by: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
submitted_at: Optional[datetime] = None
|
||||
reviewed_by: Optional[str] = None
|
||||
reviewed_at: Optional[datetime] = None
|
||||
rejection_reason: Optional[str] = None
|
||||
oracle_synced_at: Optional[datetime] = None
|
||||
oracle_act_id: Optional[int] = None
|
||||
oracle_error: Optional[str] = None
|
||||
|
||||
# Relationships (optional, loaded when needed)
|
||||
attachments: List[AttachmentResponse] = []
|
||||
entries: List[AccountingEntryResponse] = []
|
||||
|
||||
@field_validator('tva_breakdown', mode='before')
|
||||
@classmethod
|
||||
def parse_tva_breakdown(cls, v: Any) -> Optional[List[dict]]:
|
||||
"""Deserialize tva_breakdown from JSON string if needed."""
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return json.loads(v)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None
|
||||
if isinstance(v, list):
|
||||
return v
|
||||
return None
|
||||
|
||||
@field_validator('payment_methods', mode='before')
|
||||
@classmethod
|
||||
def parse_payment_methods(cls, v: Any) -> Optional[List[dict]]:
|
||||
"""Deserialize payment_methods from JSON string if needed."""
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return json.loads(v)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None
|
||||
if isinstance(v, list):
|
||||
return v
|
||||
return None
|
||||
|
||||
|
||||
class ReceiptListResponse(BaseModel):
|
||||
"""Schema for paginated receipt list response."""
|
||||
items: List[ReceiptResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
pages: int
|
||||
|
||||
|
||||
class ReceiptFilter(BaseModel):
|
||||
"""Schema for filtering receipts."""
|
||||
status: Optional[ReceiptStatus] = None
|
||||
direction: Optional[ReceiptDirection] = None
|
||||
company_id: Optional[int] = None
|
||||
created_by: Optional[str] = None
|
||||
date_from: Optional[date] = None
|
||||
date_to: Optional[date] = None
|
||||
search: Optional[str] = None # Search in description, partner_name
|
||||
page: int = Field(default=1, ge=1)
|
||||
page_size: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
# ============ Workflow Schemas ============
|
||||
|
||||
class WorkflowAction(BaseModel):
|
||||
"""Schema for workflow action response."""
|
||||
success: bool
|
||||
message: str
|
||||
receipt: Optional[ReceiptResponse] = None
|
||||
|
||||
|
||||
class RejectRequest(BaseModel):
|
||||
"""Schema for rejection request."""
|
||||
reason: str = Field(min_length=5, max_length=500)
|
||||
|
||||
|
||||
class EntriesUpdateRequest(BaseModel):
|
||||
"""Schema for bulk updating accounting entries."""
|
||||
entries: List[AccountingEntryCreate]
|
||||
|
||||
|
||||
# ============ Nomenclature Schemas ============
|
||||
|
||||
class PartnerOption(BaseModel):
|
||||
"""Schema for partner dropdown option (used for autocomplete assistance)."""
|
||||
name: str
|
||||
fiscal_code: Optional[str] = None
|
||||
address: Optional[str] = None
|
||||
source: str = "oracle" # 'oracle' (synced) or 'local'
|
||||
|
||||
|
||||
class AccountOption(BaseModel):
|
||||
"""Schema for account dropdown option."""
|
||||
code: str
|
||||
name: str
|
||||
|
||||
|
||||
class CashRegisterOption(BaseModel):
|
||||
"""Schema for cash register dropdown option."""
|
||||
id: int
|
||||
name: str
|
||||
account_code: str # 5311, 5121, etc.
|
||||
|
||||
|
||||
class ExpenseTypeOption(BaseModel):
|
||||
"""Schema for expense type dropdown option."""
|
||||
code: str
|
||||
name: str
|
||||
account_code: str
|
||||
has_vat: bool
|
||||
vat_percent: Decimal = Decimal("19")
|
||||
11
backend/modules/data_entry/services/__init__.py
Normal file
11
backend/modules/data_entry/services/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# Business logic services
|
||||
from .receipt_service import ReceiptService
|
||||
from .nomenclature_service import NomenclatureService
|
||||
from .expense_types import EXPENSE_TYPES, ExpenseType
|
||||
|
||||
__all__ = [
|
||||
"ReceiptService",
|
||||
"NomenclatureService",
|
||||
"EXPENSE_TYPES",
|
||||
"ExpenseType",
|
||||
]
|
||||
101
backend/modules/data_entry/services/expense_types.py
Normal file
101
backend/modules/data_entry/services/expense_types.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""Predefined expense types for automatic accounting entry generation."""
|
||||
|
||||
from decimal import Decimal
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExpenseType:
|
||||
"""Expense type definition with accounting configuration."""
|
||||
code: str
|
||||
name: str
|
||||
account_code: str
|
||||
account_name: str
|
||||
has_vat: bool
|
||||
vat_percent: Decimal = Decimal("19")
|
||||
vat_account: str = "4426"
|
||||
|
||||
|
||||
# Predefined expense types
|
||||
EXPENSE_TYPES: Dict[str, ExpenseType] = {
|
||||
"FUEL": ExpenseType(
|
||||
code="FUEL",
|
||||
name="Combustibil",
|
||||
account_code="6022",
|
||||
account_name="Cheltuieli cu combustibilii",
|
||||
has_vat=True,
|
||||
),
|
||||
"MATERIALS": ExpenseType(
|
||||
code="MATERIALS",
|
||||
name="Materiale consumabile",
|
||||
account_code="6028",
|
||||
account_name="Alte cheltuieli cu materiale consumabile",
|
||||
has_vat=True,
|
||||
),
|
||||
"OFFICE": ExpenseType(
|
||||
code="OFFICE",
|
||||
name="Rechizite birou",
|
||||
account_code="6024",
|
||||
account_name="Cheltuieli privind materialele pentru ambalat",
|
||||
has_vat=True,
|
||||
),
|
||||
"PHONE": ExpenseType(
|
||||
code="PHONE",
|
||||
name="Telefonie / Internet",
|
||||
account_code="626",
|
||||
account_name="Cheltuieli postale si taxe de telecomunicatii",
|
||||
has_vat=True,
|
||||
),
|
||||
"PARKING": ExpenseType(
|
||||
code="PARKING",
|
||||
name="Parcare",
|
||||
account_code="6022",
|
||||
account_name="Cheltuieli cu combustibilii",
|
||||
has_vat=True,
|
||||
),
|
||||
"FOOD": ExpenseType(
|
||||
code="FOOD",
|
||||
name="Alimentatie",
|
||||
account_code="6028",
|
||||
account_name="Alte cheltuieli cu materiale consumabile",
|
||||
has_vat=False, # No deductible VAT for food
|
||||
),
|
||||
"TRANSPORT": ExpenseType(
|
||||
code="TRANSPORT",
|
||||
name="Transport",
|
||||
account_code="624",
|
||||
account_name="Cheltuieli cu transportul de bunuri si personal",
|
||||
has_vat=True,
|
||||
),
|
||||
"OTHER": ExpenseType(
|
||||
code="OTHER",
|
||||
name="Altele",
|
||||
account_code="628",
|
||||
account_name="Alte cheltuieli cu serviciile executate de terti",
|
||||
has_vat=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def get_expense_type(code: str) -> Optional[ExpenseType]:
|
||||
"""Get expense type by code."""
|
||||
return EXPENSE_TYPES.get(code)
|
||||
|
||||
|
||||
def get_all_expense_types() -> Dict[str, ExpenseType]:
|
||||
"""Get all expense types."""
|
||||
return EXPENSE_TYPES.copy()
|
||||
|
||||
|
||||
# Default cash register accounts
|
||||
CASH_REGISTER_ACCOUNTS = {
|
||||
"CASA": {
|
||||
"code": "5311",
|
||||
"name": "Casa in lei",
|
||||
},
|
||||
"BANCA": {
|
||||
"code": "5121",
|
||||
"name": "Conturi la banci in lei",
|
||||
},
|
||||
}
|
||||
270
backend/modules/data_entry/services/image_preprocessor.py
Normal file
270
backend/modules/data_entry/services/image_preprocessor.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""Image preprocessing for optimal OCR results."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
try:
|
||||
import pdf2image
|
||||
PDF_AVAILABLE = True
|
||||
except ImportError:
|
||||
PDF_AVAILABLE = False
|
||||
|
||||
|
||||
class ImagePreprocessor:
|
||||
"""Preprocess receipt images for OCR."""
|
||||
|
||||
def _add_safety_padding(self, image: np.ndarray, padding: int = 50) -> np.ndarray:
|
||||
"""Add white padding around image to protect edge content during rotation.
|
||||
|
||||
This prevents left/right margin truncation in OCR by ensuring text near
|
||||
edges isn't lost during deskew rotation.
|
||||
"""
|
||||
if len(image.shape) == 2:
|
||||
# Grayscale
|
||||
return cv2.copyMakeBorder(
|
||||
image, padding, padding, padding, padding,
|
||||
cv2.BORDER_CONSTANT, value=255
|
||||
)
|
||||
else:
|
||||
# Color (BGR)
|
||||
return cv2.copyMakeBorder(
|
||||
image, padding, padding, padding, padding,
|
||||
cv2.BORDER_CONSTANT, value=(255, 255, 255)
|
||||
)
|
||||
|
||||
def load_image(self, path: Path) -> np.ndarray:
|
||||
"""Load image from file."""
|
||||
image = cv2.imread(str(path))
|
||||
if image is None:
|
||||
raise ValueError(f"Could not load image: {path}")
|
||||
return image
|
||||
|
||||
def pdf_to_images(self, path: Path, dpi: int = 300) -> List[np.ndarray]:
|
||||
"""
|
||||
Convert PDF to images.
|
||||
|
||||
Args:
|
||||
path: Path to PDF file
|
||||
dpi: Resolution (300 = fast & good quality, 400 = better but slower)
|
||||
"""
|
||||
if not PDF_AVAILABLE:
|
||||
raise RuntimeError("pdf2image not available. Install with: pip install pdf2image")
|
||||
images = pdf2image.convert_from_path(str(path), dpi=dpi)
|
||||
return [np.array(img) for img in images]
|
||||
|
||||
def preprocess(self, image: np.ndarray, high_quality: bool = True) -> np.ndarray:
|
||||
"""
|
||||
Apply LIGHT preprocessing - better for clear PDFs.
|
||||
Heavy binarization can destroy text on clear images.
|
||||
"""
|
||||
return self.preprocess_light(image)
|
||||
|
||||
def preprocess_light(self, image: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Light preprocessing for CLEAR images (PDFs, good scans).
|
||||
Preserves original quality, only enhances contrast.
|
||||
"""
|
||||
# 0. Add safety padding to protect edge content during deskew rotation
|
||||
image = self._add_safety_padding(image)
|
||||
|
||||
# 1. Grayscale
|
||||
if len(image.shape) == 3:
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
else:
|
||||
gray = image.copy()
|
||||
|
||||
# 2a. Scale DOWN if any side exceeds 4000px (PaddleOCR limit)
|
||||
height, width = gray.shape
|
||||
max_side = max(height, width)
|
||||
if max_side > 4000:
|
||||
scale = 4000 / max_side
|
||||
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
|
||||
height, width = gray.shape
|
||||
|
||||
# 2b. Scale UP if too small
|
||||
if width < 1500:
|
||||
scale = 1500 / width
|
||||
# Ensure we don't exceed 4000px after upscaling
|
||||
new_width = int(width * scale)
|
||||
new_height = int(height * scale)
|
||||
if max(new_width, new_height) > 4000:
|
||||
scale = 4000 / max(new_width, new_height)
|
||||
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
|
||||
|
||||
# 3. Deskew
|
||||
gray = self._deskew(gray)
|
||||
|
||||
# 4. Light contrast enhancement only
|
||||
clahe = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(8, 8))
|
||||
enhanced = clahe.apply(gray)
|
||||
|
||||
# NO binarization, NO morphological ops - preserve original quality
|
||||
return enhanced
|
||||
|
||||
def preprocess_heavy(self, image: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Heavy preprocessing for FADED thermal receipts.
|
||||
Aggressive binarization to recover faded text.
|
||||
"""
|
||||
# 0. Add safety padding to protect edge content during deskew rotation
|
||||
image = self._add_safety_padding(image)
|
||||
|
||||
# 1. Grayscale
|
||||
if len(image.shape) == 3:
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
else:
|
||||
gray = image.copy()
|
||||
|
||||
# 2a. Scale DOWN if any side exceeds 4000px (PaddleOCR limit)
|
||||
height, width = gray.shape
|
||||
max_side = max(height, width)
|
||||
if max_side > 4000:
|
||||
scale = 4000 / max_side
|
||||
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
|
||||
height, width = gray.shape
|
||||
|
||||
# 2b. Scale UP if too small (larger = better OCR)
|
||||
if width < 1500:
|
||||
scale = 1500 / width
|
||||
# Ensure we don't exceed 4000px after upscaling
|
||||
new_width = int(width * scale)
|
||||
new_height = int(height * scale)
|
||||
if max(new_width, new_height) > 4000:
|
||||
scale = 4000 / max(new_width, new_height)
|
||||
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
|
||||
|
||||
# 3. Deskew
|
||||
gray = self._deskew(gray)
|
||||
|
||||
# 4. Contrast enhancement with CLAHE
|
||||
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
||||
enhanced = clahe.apply(gray)
|
||||
|
||||
# 5. Denoise
|
||||
denoised = cv2.fastNlMeansDenoising(enhanced, h=8, templateWindowSize=7, searchWindowSize=21)
|
||||
|
||||
# 6. Sharpening
|
||||
gaussian = cv2.GaussianBlur(denoised, (0, 0), 2.0)
|
||||
sharpened = cv2.addWeighted(denoised, 1.5, gaussian, -0.5, 0)
|
||||
|
||||
# 7. Adaptive thresholding (binarization)
|
||||
binary = cv2.adaptiveThreshold(
|
||||
sharpened, 255,
|
||||
cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
|
||||
cv2.THRESH_BINARY,
|
||||
blockSize=11, C=5
|
||||
)
|
||||
|
||||
# 8. Morphological operations
|
||||
kernel_close = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
|
||||
result = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel_close)
|
||||
|
||||
return result
|
||||
|
||||
def preprocess_for_tesseract(self, image: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Tesseract-optimized preprocessing.
|
||||
Tesseract works best with:
|
||||
- Clean black text on white background (binarized)
|
||||
- High DPI (scale up small images)
|
||||
- Otsu thresholding (better than adaptive for clean documents)
|
||||
"""
|
||||
# 0. Add safety padding to protect edge content during deskew rotation
|
||||
image = self._add_safety_padding(image)
|
||||
|
||||
# 1. Grayscale
|
||||
if len(image.shape) == 3:
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
else:
|
||||
gray = image.copy()
|
||||
|
||||
# 2. Scale for optimal Tesseract (target ~2000px width for receipts)
|
||||
height, width = gray.shape
|
||||
if width < 2000:
|
||||
scale = 2000 / width
|
||||
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
|
||||
elif width > 3000:
|
||||
scale = 3000 / width
|
||||
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
|
||||
|
||||
# 3. Deskew
|
||||
gray = self._deskew(gray)
|
||||
|
||||
# 4. Strong contrast enhancement
|
||||
clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
|
||||
enhanced = clahe.apply(gray)
|
||||
|
||||
# 5. Denoise before binarization
|
||||
denoised = cv2.fastNlMeansDenoising(enhanced, h=10, templateWindowSize=7, searchWindowSize=21)
|
||||
|
||||
# 6. Otsu binarization (better than adaptive for clean PDFs)
|
||||
_, binary = cv2.threshold(denoised, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
||||
|
||||
# 7. Light morphological cleanup
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 1))
|
||||
cleaned = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
|
||||
|
||||
return cleaned
|
||||
|
||||
def get_all_variants(self, image: np.ndarray) -> List[np.ndarray]:
|
||||
"""
|
||||
Generate 2 preprocessing variants for OCR (fast mode).
|
||||
Returns: [light_processed, heavy_processed]
|
||||
"""
|
||||
return [
|
||||
self.preprocess_light(image),
|
||||
self.preprocess_heavy(image),
|
||||
]
|
||||
|
||||
def _deskew(self, image: np.ndarray) -> np.ndarray:
|
||||
"""Correct image rotation/skew using Hough lines.
|
||||
|
||||
Uses expanded canvas to preserve all content during rotation,
|
||||
preventing left/right margin truncation.
|
||||
"""
|
||||
edges = cv2.Canny(image, 50, 150, apertureSize=3)
|
||||
lines = cv2.HoughLinesP(
|
||||
edges, 1, np.pi / 180,
|
||||
threshold=100, minLineLength=100, maxLineGap=10
|
||||
)
|
||||
|
||||
if lines is None:
|
||||
return image
|
||||
|
||||
angles = []
|
||||
for line in lines:
|
||||
x1, y1, x2, y2 = line[0]
|
||||
angle = np.arctan2(y2 - y1, x2 - x1) * 180 / np.pi
|
||||
if abs(angle) < 45:
|
||||
angles.append(angle)
|
||||
|
||||
if not angles:
|
||||
return image
|
||||
|
||||
median_angle = np.median(angles)
|
||||
if abs(median_angle) < 0.5:
|
||||
return image
|
||||
|
||||
h, w = image.shape[:2]
|
||||
center = (w // 2, h // 2)
|
||||
M = cv2.getRotationMatrix2D(center, median_angle, 1.0)
|
||||
|
||||
# Calculate new canvas size to fit entire rotated image (prevents edge truncation)
|
||||
cos_angle = abs(np.cos(np.radians(median_angle)))
|
||||
sin_angle = abs(np.sin(np.radians(median_angle)))
|
||||
new_w = int(h * sin_angle + w * cos_angle)
|
||||
new_h = int(h * cos_angle + w * sin_angle)
|
||||
|
||||
# Adjust rotation matrix for new canvas center
|
||||
M[0, 2] += (new_w - w) / 2
|
||||
M[1, 2] += (new_h - h) / 2
|
||||
|
||||
return cv2.warpAffine(
|
||||
image, M, (new_w, new_h),
|
||||
flags=cv2.INTER_CUBIC,
|
||||
borderMode=cv2.BORDER_CONSTANT,
|
||||
borderValue=255 # White background (grayscale)
|
||||
)
|
||||
234
backend/modules/data_entry/services/nomenclature_service.py
Normal file
234
backend/modules/data_entry/services/nomenclature_service.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""Service for fetching nomenclatures from Oracle (read-only)."""
|
||||
|
||||
from typing import List, Optional
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from backend.modules.data_entry.schemas.receipt import (
|
||||
PartnerOption,
|
||||
AccountOption,
|
||||
CashRegisterOption,
|
||||
ExpenseTypeOption,
|
||||
)
|
||||
from backend.modules.data_entry.services.expense_types import EXPENSE_TYPES
|
||||
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier, SyncedCashRegister
|
||||
|
||||
|
||||
class NomenclatureService:
|
||||
"""
|
||||
Service for fetching nomenclatures.
|
||||
|
||||
In Phase 1 (MVP), some nomenclatures are hardcoded.
|
||||
In Phase 2, these will be fetched from Oracle.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
async def get_partners(
|
||||
company_id: int,
|
||||
search: Optional[str] = None,
|
||||
session: Optional[AsyncSession] = None
|
||||
) -> List[PartnerOption]:
|
||||
"""
|
||||
Get partners (suppliers/customers) for a company.
|
||||
|
||||
Phase 1: Returns mock data.
|
||||
Phase 2: Returns synced data from SQLite (from Oracle sync).
|
||||
Phase 3: Will fetch live from Oracle.
|
||||
"""
|
||||
# If session is provided, try to get from synced SQLite data
|
||||
if session:
|
||||
# Try to get from SQLite synced data
|
||||
stmt = select(SyncedSupplier).where(SyncedSupplier.company_id == company_id)
|
||||
if search:
|
||||
stmt = stmt.where(
|
||||
(SyncedSupplier.name.ilike(f"%{search}%")) |
|
||||
(SyncedSupplier.fiscal_code.ilike(f"%{search}%"))
|
||||
)
|
||||
stmt = stmt.order_by(SyncedSupplier.name) # Order alphabetically, no limit for AutoComplete
|
||||
|
||||
result = await session.execute(stmt)
|
||||
suppliers = result.scalars().all()
|
||||
|
||||
if suppliers:
|
||||
# Also get local suppliers
|
||||
local_stmt = select(LocalSupplier).where(LocalSupplier.company_id == company_id)
|
||||
if search:
|
||||
local_stmt = local_stmt.where(
|
||||
(LocalSupplier.name.ilike(f"%{search}%")) |
|
||||
(LocalSupplier.fiscal_code.ilike(f"%{search}%"))
|
||||
)
|
||||
local_stmt = local_stmt.order_by(LocalSupplier.name) # Order alphabetically
|
||||
|
||||
local_result = await session.execute(local_stmt)
|
||||
local_suppliers = local_result.scalars().all()
|
||||
|
||||
# Combine both - no IDs needed, just text data for autocomplete
|
||||
partners = []
|
||||
for s in suppliers:
|
||||
partners.append(PartnerOption(
|
||||
name=s.name,
|
||||
fiscal_code=s.fiscal_code,
|
||||
address=s.address,
|
||||
source="oracle"
|
||||
))
|
||||
for l in local_suppliers:
|
||||
partners.append(PartnerOption(
|
||||
name=l.name, # No suffix - must match search results
|
||||
fiscal_code=l.fiscal_code,
|
||||
address=l.address,
|
||||
source="local"
|
||||
))
|
||||
|
||||
return partners
|
||||
|
||||
# Fallback to mock data for Phase 1 (when no synced data)
|
||||
mock_partners = [
|
||||
PartnerOption(name="OMV Petrom", fiscal_code="RO123456", source="mock"),
|
||||
PartnerOption(name="Dedeman", fiscal_code="RO789012", source="mock"),
|
||||
PartnerOption(name="Kaufland", fiscal_code="RO345678", source="mock"),
|
||||
PartnerOption(name="Emag", fiscal_code="RO901234", source="mock"),
|
||||
PartnerOption(name="Altex", fiscal_code="RO567890", source="mock"),
|
||||
]
|
||||
|
||||
if search:
|
||||
search_lower = search.lower()
|
||||
mock_partners = [
|
||||
p for p in mock_partners
|
||||
if search_lower in p.name.lower() or (p.fiscal_code and search_lower in p.fiscal_code.lower())
|
||||
]
|
||||
|
||||
return mock_partners
|
||||
|
||||
@staticmethod
|
||||
async def get_accounts(company_id: int, prefix: Optional[str] = None) -> List[AccountOption]:
|
||||
"""
|
||||
Get chart of accounts for a company.
|
||||
|
||||
Phase 1: Returns common expense/income accounts.
|
||||
Phase 2: Will fetch from Oracle PLAN_CONTURI.
|
||||
"""
|
||||
# Common accounts for expenses and receipts
|
||||
accounts = [
|
||||
# Expense accounts (Class 6)
|
||||
AccountOption(code="6022", name="Cheltuieli cu combustibilii"),
|
||||
AccountOption(code="6024", name="Cheltuieli materiale pentru ambalat"),
|
||||
AccountOption(code="6028", name="Alte cheltuieli cu materiale consumabile"),
|
||||
AccountOption(code="624", name="Cheltuieli cu transportul de bunuri si personal"),
|
||||
AccountOption(code="626", name="Cheltuieli postale si taxe telecomunicatii"),
|
||||
AccountOption(code="628", name="Alte cheltuieli cu serviciile executate de terti"),
|
||||
|
||||
# VAT
|
||||
AccountOption(code="4426", name="TVA deductibila"),
|
||||
AccountOption(code="4427", name="TVA colectata"),
|
||||
|
||||
# Cash and Bank (Class 5)
|
||||
AccountOption(code="5311", name="Casa in lei"),
|
||||
AccountOption(code="5121", name="Conturi la banci in lei"),
|
||||
|
||||
# Income accounts (Class 7)
|
||||
AccountOption(code="7588", name="Alte venituri din exploatare"),
|
||||
]
|
||||
|
||||
if prefix:
|
||||
accounts = [a for a in accounts if a.code.startswith(prefix)]
|
||||
|
||||
return accounts
|
||||
|
||||
@staticmethod
|
||||
async def get_cash_registers(
|
||||
company_id: int,
|
||||
session: Optional[AsyncSession] = None
|
||||
) -> List[CashRegisterOption]:
|
||||
"""
|
||||
Get cash registers and bank accounts for a company.
|
||||
|
||||
Phase 1: Returns default options.
|
||||
Phase 2: Returns synced data from SQLite (from Oracle sync).
|
||||
Phase 3: Will fetch live from Oracle NOM_CASE / NOM_BANCI.
|
||||
"""
|
||||
# If session is provided, try to get from synced SQLite data
|
||||
if session:
|
||||
stmt = select(SyncedCashRegister).where(SyncedCashRegister.company_id == company_id)
|
||||
result = await session.execute(stmt)
|
||||
registers = result.scalars().all()
|
||||
|
||||
if registers:
|
||||
return [
|
||||
CashRegisterOption(id=r.id, name=r.name, account_code=r.account_code)
|
||||
for r in registers
|
||||
]
|
||||
|
||||
# Fallback to default cash registers for Phase 1
|
||||
return [
|
||||
CashRegisterOption(id=1, name="Casa principala", account_code="5311"),
|
||||
CashRegisterOption(id=2, name="Cont BCR", account_code="5121"),
|
||||
CashRegisterOption(id=3, name="Cont BRD", account_code="5121"),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
async def get_expense_types() -> List[ExpenseTypeOption]:
|
||||
"""
|
||||
Get predefined expense types with their accounting configuration.
|
||||
"""
|
||||
return [
|
||||
ExpenseTypeOption(
|
||||
code=et.code,
|
||||
name=et.name,
|
||||
account_code=et.account_code,
|
||||
has_vat=et.has_vat,
|
||||
vat_percent=et.vat_percent,
|
||||
)
|
||||
for et in EXPENSE_TYPES.values()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
async def get_companies(username: str) -> List[dict]:
|
||||
"""
|
||||
Get companies accessible by user.
|
||||
|
||||
Phase 1: Returns mock data.
|
||||
Phase 2: Will fetch from shared auth based on user permissions.
|
||||
"""
|
||||
# TODO: Integrate with shared auth to get user's companies
|
||||
return [
|
||||
{"id": 1, "name": "SC Test SRL", "cui": "RO12345678"},
|
||||
{"id": 2, "name": "SC Demo SA", "cui": "RO87654321"},
|
||||
]
|
||||
|
||||
# ============ Phase 2 Oracle Integration Methods ============
|
||||
|
||||
@staticmethod
|
||||
async def _fetch_partners_oracle(company_id: int, search: Optional[str] = None) -> List[PartnerOption]:
|
||||
"""
|
||||
Fetch partners from Oracle NOM_PARTENERI.
|
||||
|
||||
Will be implemented in Phase 2.
|
||||
"""
|
||||
# TODO: Implement using shared oracle_pool
|
||||
# Example query:
|
||||
# SELECT ID_PART, DEN_PART, COD_FISCAL
|
||||
# FROM {schema}.NOM_PARTENERI
|
||||
# WHERE DEN_PART LIKE :search
|
||||
raise NotImplementedError("Oracle integration pending - Phase 2")
|
||||
|
||||
@staticmethod
|
||||
async def _fetch_accounts_oracle(company_id: int, prefix: Optional[str] = None) -> List[AccountOption]:
|
||||
"""
|
||||
Fetch chart of accounts from Oracle PLAN_CONTURI.
|
||||
|
||||
Will be implemented in Phase 2.
|
||||
"""
|
||||
# TODO: Implement using shared oracle_pool
|
||||
raise NotImplementedError("Oracle integration pending - Phase 2")
|
||||
|
||||
@staticmethod
|
||||
async def _fetch_cash_registers_oracle(company_id: int) -> List[CashRegisterOption]:
|
||||
"""
|
||||
Fetch cash registers from Oracle NOM_CASE / NOM_BANCI.
|
||||
|
||||
Will be implemented in Phase 2.
|
||||
"""
|
||||
# TODO: Implement using shared oracle_pool
|
||||
raise NotImplementedError("Oracle integration pending - Phase 2")
|
||||
295
backend/modules/data_entry/services/ocr_engine.py
Normal file
295
backend/modules/data_entry/services/ocr_engine.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""OCR engine wrapper for PaddleOCR and Tesseract."""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Setup logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO) # Ensure logs are visible
|
||||
|
||||
# Disable PaddleOCR model source check for faster startup (PaddleX 3.x)
|
||||
os.environ['PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK'] = 'True'
|
||||
|
||||
# Lazy imports - these will be imported on first use
|
||||
PaddleOCR = None # Will be imported lazily
|
||||
pytesseract = None # Will be imported lazily
|
||||
|
||||
# Check availability without importing heavy libraries
|
||||
def _check_paddle_available() -> bool:
|
||||
"""Check if paddleocr is installed without importing it."""
|
||||
try:
|
||||
import importlib.util
|
||||
return importlib.util.find_spec("paddleocr") is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _check_tesseract_available() -> bool:
|
||||
"""Check if pytesseract is installed without importing it."""
|
||||
try:
|
||||
import importlib.util
|
||||
return importlib.util.find_spec("pytesseract") is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
PADDLE_AVAILABLE = _check_paddle_available()
|
||||
TESSERACT_AVAILABLE = _check_tesseract_available()
|
||||
|
||||
|
||||
@dataclass
|
||||
class OCRResult:
|
||||
"""Raw OCR result."""
|
||||
text: str
|
||||
confidence: float
|
||||
boxes: List[dict]
|
||||
engine: str = "" # OCR engine used: paddleocr or tesseract
|
||||
|
||||
|
||||
class OCREngine:
|
||||
"""Unified OCR engine with fallback support."""
|
||||
|
||||
def __init__(self):
|
||||
self._paddle = None
|
||||
self._paddle_init_started = False
|
||||
self._paddle_ready = threading.Event() # Signals when PaddleOCR is FULLY ready
|
||||
self._paddle_init_lock = threading.Lock()
|
||||
|
||||
def _init_paddle_lazy(self):
|
||||
"""Lazy initialize PaddleOCR on first use (avoids slow startup)."""
|
||||
global PaddleOCR
|
||||
|
||||
with self._paddle_init_lock:
|
||||
if self._paddle_init_started:
|
||||
return # Already initializing or done
|
||||
self._paddle_init_started = True
|
||||
|
||||
if PADDLE_AVAILABLE:
|
||||
try:
|
||||
print("Importing PaddleOCR (first use, may take ~15-20 seconds)...", flush=True)
|
||||
from paddleocr import PaddleOCR as _PaddleOCR
|
||||
PaddleOCR = _PaddleOCR
|
||||
|
||||
print("Initializing PaddleOCR engine...", flush=True)
|
||||
# PaddleOCR 3.x API - optimized for Romanian receipts
|
||||
# Note: 'latin' not available in PaddleOCR 3.x, 'en' works well for receipts
|
||||
self._paddle = PaddleOCR(
|
||||
lang='en', # 'en' handles Latin alphabet well for receipts
|
||||
# High quality settings for better accuracy
|
||||
det_db_thresh=0.3, # Lower threshold = detect more text (default 0.3)
|
||||
det_db_box_thresh=0.5, # Box confidence threshold (default 0.5)
|
||||
det_db_unclip_ratio=1.8, # Expand detected boxes slightly (default 1.5)
|
||||
rec_batch_num=6, # Batch size for recognition
|
||||
use_angle_cls=True, # Enable text angle classification
|
||||
)
|
||||
print("PaddleOCR initialized successfully with high-quality settings", flush=True)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to initialize PaddleOCR: {e}", flush=True)
|
||||
self._paddle = None
|
||||
|
||||
# Signal that initialization is complete (success or failure)
|
||||
self._paddle_ready.set()
|
||||
|
||||
def wait_for_paddle(self, timeout: float = 30.0) -> bool:
|
||||
"""
|
||||
Wait for PaddleOCR to be fully initialized.
|
||||
|
||||
Args:
|
||||
timeout: Max seconds to wait (default 30s)
|
||||
|
||||
Returns:
|
||||
True if PaddleOCR is ready, False if timeout or unavailable
|
||||
"""
|
||||
if not PADDLE_AVAILABLE:
|
||||
return False
|
||||
|
||||
if self._paddle is not None:
|
||||
return True # Already ready
|
||||
|
||||
if not self._paddle_init_started:
|
||||
# Start initialization if not already started
|
||||
self._init_paddle_lazy()
|
||||
|
||||
# Wait for initialization to complete
|
||||
print(f"[OCR] Waiting for PaddleOCR to be ready (max {timeout}s)...", flush=True)
|
||||
start = time.time()
|
||||
ready = self._paddle_ready.wait(timeout=timeout)
|
||||
elapsed = time.time() - start
|
||||
|
||||
if ready and self._paddle is not None:
|
||||
print(f"[OCR] PaddleOCR ready after {elapsed:.1f}s", flush=True)
|
||||
return True
|
||||
else:
|
||||
print(f"[OCR] PaddleOCR not ready after {elapsed:.1f}s (timeout or failed)", flush=True)
|
||||
return False
|
||||
|
||||
def is_paddle_ready(self) -> bool:
|
||||
"""Check if PaddleOCR is ready without waiting."""
|
||||
return self._paddle is not None
|
||||
|
||||
def recognize(self, image: np.ndarray) -> OCRResult:
|
||||
"""Perform OCR on preprocessed image."""
|
||||
logger.info(f"[OCR] Starting recognition, image shape: {image.shape}, dtype: {image.dtype}")
|
||||
|
||||
# Lazy init PaddleOCR on first call
|
||||
self._init_paddle_lazy()
|
||||
|
||||
if PADDLE_AVAILABLE and self._paddle:
|
||||
logger.info("[OCR] Using PaddleOCR engine")
|
||||
return self._paddle_recognize(image)
|
||||
elif TESSERACT_AVAILABLE:
|
||||
logger.info("[OCR] Using Tesseract engine (PaddleOCR not available)")
|
||||
return self._tesseract_recognize(image)
|
||||
else:
|
||||
logger.error("[OCR] No OCR engine available!")
|
||||
raise RuntimeError(
|
||||
"No OCR engine available. Install PaddleOCR or Tesseract."
|
||||
)
|
||||
|
||||
def _paddle_recognize(self, image: np.ndarray) -> OCRResult:
|
||||
"""Recognize text using PaddleOCR 3.x API."""
|
||||
# Wait for PaddleOCR to be fully ready (handles background init)
|
||||
if not self.wait_for_paddle(timeout=30.0):
|
||||
logger.warning("[PaddleOCR] Not ready, falling back to Tesseract")
|
||||
if TESSERACT_AVAILABLE:
|
||||
return self._tesseract_recognize(image)
|
||||
raise RuntimeError("PaddleOCR not ready and Tesseract not available")
|
||||
|
||||
try:
|
||||
logger.info(f"[PaddleOCR] Processing image, shape: {image.shape}")
|
||||
|
||||
# PaddleOCR 3.x requires 3-channel images
|
||||
if len(image.shape) == 2:
|
||||
# Convert grayscale to 3-channel BGR
|
||||
import cv2
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
|
||||
logger.info(f"[PaddleOCR] Converted to BGR, new shape: {image.shape}")
|
||||
|
||||
# PaddleOCR 3.x uses predict() with new parameter names
|
||||
logger.info("[PaddleOCR] Calling predict()...")
|
||||
result = self._paddle.predict(image, use_textline_orientation=True)
|
||||
logger.info(f"[PaddleOCR] predict() returned, result type: {type(result)}")
|
||||
|
||||
if not result or len(result) == 0:
|
||||
logger.warning("[PaddleOCR] No results returned")
|
||||
return OCRResult(text="", confidence=0.0, boxes=[], engine="paddleocr")
|
||||
|
||||
# PaddleOCR 3.x returns OCRResult objects with different structure
|
||||
ocr_result = result[0]
|
||||
|
||||
# Extract texts and scores from the new format
|
||||
rec_texts = ocr_result.get('rec_texts', [])
|
||||
rec_scores = ocr_result.get('rec_scores', [])
|
||||
dt_polys = ocr_result.get('dt_polys', [])
|
||||
|
||||
if not rec_texts:
|
||||
return OCRResult(text="", confidence=0.0, boxes=[], engine="paddleocr")
|
||||
|
||||
boxes = []
|
||||
for i, text in enumerate(rec_texts):
|
||||
conf = rec_scores[i] if i < len(rec_scores) else 0.0
|
||||
box = dt_polys[i].tolist() if i < len(dt_polys) else []
|
||||
boxes.append({
|
||||
'text': text,
|
||||
'confidence': float(conf),
|
||||
'box': box
|
||||
})
|
||||
|
||||
avg_conf = sum(rec_scores) / len(rec_scores) if rec_scores else 0.0
|
||||
text_result = '\n'.join(rec_texts)
|
||||
logger.info(f"[PaddleOCR] SUCCESS - Found {len(rec_texts)} text lines, avg confidence: {avg_conf:.2%}")
|
||||
logger.debug(f"[PaddleOCR] Raw text preview: {text_result[:200]}...")
|
||||
return OCRResult(
|
||||
text=text_result,
|
||||
confidence=float(avg_conf),
|
||||
boxes=boxes,
|
||||
engine="paddleocr"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[PaddleOCR] ERROR: {e}, falling back to Tesseract")
|
||||
if TESSERACT_AVAILABLE:
|
||||
return self._tesseract_recognize(image)
|
||||
raise
|
||||
|
||||
def _tesseract_recognize(self, image: np.ndarray) -> OCRResult:
|
||||
"""Recognize text using Tesseract."""
|
||||
global pytesseract
|
||||
|
||||
logger.info(f"[Tesseract] Processing image, shape: {image.shape}")
|
||||
|
||||
# Lazy import pytesseract
|
||||
if pytesseract is None:
|
||||
logger.info("[Tesseract] Importing pytesseract...")
|
||||
import pytesseract as _pytesseract
|
||||
pytesseract = _pytesseract
|
||||
|
||||
# PSM 4: Single column (best for receipts)
|
||||
config = '--psm 4 -l ron+eng'
|
||||
text = pytesseract.image_to_string(image, config=config)
|
||||
|
||||
# Quick confidence estimate
|
||||
data = pytesseract.image_to_data(image, config=config, output_type=pytesseract.Output.DICT)
|
||||
confidences = [int(c) for c in data['conf'] if int(c) > 0]
|
||||
avg_conf = sum(confidences) / len(confidences) / 100 if confidences else 0.0
|
||||
|
||||
logger.info(f"[Tesseract] Done: {len(text)} chars, conf: {avg_conf:.2%}")
|
||||
return OCRResult(text=text, confidence=avg_conf, boxes=[], engine="tesseract")
|
||||
|
||||
def recognize_dual(self, image: np.ndarray) -> Tuple[OCRResult, Optional[OCRResult]]:
|
||||
"""
|
||||
Run both OCR engines and return both results.
|
||||
|
||||
Returns:
|
||||
Tuple of (paddle_result, tesseract_result)
|
||||
tesseract_result may be None if Tesseract is not available
|
||||
"""
|
||||
logger.info(f"[OCR Dual] Starting dual recognition, image shape: {image.shape}")
|
||||
|
||||
# Lazy init PaddleOCR
|
||||
self._init_paddle_lazy()
|
||||
|
||||
paddle_result = None
|
||||
tesseract_result = None
|
||||
|
||||
# Run PaddleOCR
|
||||
if PADDLE_AVAILABLE and self._paddle:
|
||||
try:
|
||||
logger.info("[OCR Dual] Running PaddleOCR...")
|
||||
paddle_result = self._paddle_recognize(image)
|
||||
logger.info(f"[OCR Dual] PaddleOCR: {len(paddle_result.text)} chars, conf: {paddle_result.confidence:.2%}")
|
||||
except Exception as e:
|
||||
logger.error(f"[OCR Dual] PaddleOCR failed: {e}")
|
||||
paddle_result = OCRResult(text="", confidence=0.0, boxes=[], engine="paddleocr")
|
||||
|
||||
# Run Tesseract
|
||||
if TESSERACT_AVAILABLE:
|
||||
try:
|
||||
logger.info("[OCR Dual] Running Tesseract...")
|
||||
tesseract_result = self._tesseract_recognize(image)
|
||||
logger.info(f"[OCR Dual] Tesseract: {len(tesseract_result.text)} chars, conf: {tesseract_result.confidence:.2%}")
|
||||
except Exception as e:
|
||||
logger.error(f"[OCR Dual] Tesseract failed: {e}")
|
||||
tesseract_result = OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
|
||||
|
||||
# Fallback if PaddleOCR not available
|
||||
if paddle_result is None:
|
||||
if tesseract_result:
|
||||
paddle_result = tesseract_result
|
||||
else:
|
||||
raise RuntimeError("No OCR engine available")
|
||||
|
||||
return paddle_result, tesseract_result
|
||||
|
||||
@staticmethod
|
||||
def get_available_engines() -> List[str]:
|
||||
"""Return list of available OCR engines."""
|
||||
engines = []
|
||||
if PADDLE_AVAILABLE:
|
||||
engines.append('paddleocr')
|
||||
if TESSERACT_AVAILABLE:
|
||||
engines.append('tesseract')
|
||||
return engines
|
||||
1501
backend/modules/data_entry/services/ocr_extractor.py
Normal file
1501
backend/modules/data_entry/services/ocr_extractor.py
Normal file
File diff suppressed because it is too large
Load Diff
569
backend/modules/data_entry/services/ocr_service.py
Normal file
569
backend/modules/data_entry/services/ocr_service.py
Normal file
@@ -0,0 +1,569 @@
|
||||
"""Main OCR service coordinating preprocessing, recognition, and extraction."""
|
||||
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
|
||||
# Disable PaddleOCR model source check for faster startup (PaddleX 3.x) - must be set before import
|
||||
os.environ['PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK'] = 'True'
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from backend.modules.data_entry.services.ocr_engine import OCREngine
|
||||
from backend.modules.data_entry.services.ocr_extractor import ReceiptExtractor, ExtractionResult
|
||||
from backend.modules.data_entry.services.image_preprocessor import ImagePreprocessor
|
||||
|
||||
# Setup logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OCRService:
|
||||
"""Service for OCR processing of receipt images."""
|
||||
|
||||
_executor = ThreadPoolExecutor(max_workers=2)
|
||||
|
||||
def __init__(self):
|
||||
self.preprocessor = ImagePreprocessor()
|
||||
self.ocr_engine = OCREngine()
|
||||
self.extractor = ReceiptExtractor()
|
||||
|
||||
async def process_image(
|
||||
self,
|
||||
image_path: Path,
|
||||
mime_type: str
|
||||
) -> Tuple[bool, str, Optional[ExtractionResult]]:
|
||||
"""
|
||||
Process receipt image and extract structured data.
|
||||
|
||||
Args:
|
||||
image_path: Path to the image file
|
||||
mime_type: MIME type of the file
|
||||
|
||||
Returns:
|
||||
Tuple of (success, message, extraction_result)
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(
|
||||
self._executor,
|
||||
self._process_sync,
|
||||
image_path,
|
||||
mime_type
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
return False, f"OCR processing failed: {str(e)}", None
|
||||
|
||||
def _process_sync(
|
||||
self,
|
||||
image_path: Path,
|
||||
mime_type: str
|
||||
) -> Tuple[bool, str, Optional[ExtractionResult]]:
|
||||
"""Synchronous processing with ADAPTIVE OCR pipeline."""
|
||||
|
||||
start_time = time.time()
|
||||
print(f"[OCR Service] Starting processing: {image_path}, mime: {mime_type}", flush=True)
|
||||
|
||||
# Load image
|
||||
if mime_type == 'application/pdf':
|
||||
try:
|
||||
images = self.preprocessor.pdf_to_images(image_path)
|
||||
if not images:
|
||||
return False, "Failed to extract images from PDF", None
|
||||
image = images[0]
|
||||
except RuntimeError as e:
|
||||
return False, str(e), None
|
||||
else:
|
||||
try:
|
||||
image = self.preprocessor.load_image(image_path)
|
||||
except ValueError as e:
|
||||
return False, str(e), None
|
||||
|
||||
raw_texts = []
|
||||
extraction = None
|
||||
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
# STEP 1: PaddleOCR + Light (fastest, best for clear PDFs)
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
print("=" * 60, flush=True)
|
||||
print("[OCR] STEP 1: PaddleOCR + Light preprocessing", flush=True)
|
||||
print("=" * 60, flush=True)
|
||||
light_img = self.preprocessor.preprocess_light(image)
|
||||
|
||||
try:
|
||||
paddle_light = self.ocr_engine._paddle_recognize(light_img)
|
||||
if paddle_light and paddle_light.text:
|
||||
extraction = self.extractor.extract(paddle_light.text)
|
||||
extraction.ocr_engine = "paddle-light"
|
||||
raw_texts.append(f"═══ PaddleOCR (light, conf: {paddle_light.confidence:.0%}) ═══\n{paddle_light.text}")
|
||||
|
||||
# Log extraction results
|
||||
print(f"[OCR] Step 1 Results:", flush=True)
|
||||
print(f" - OCR Confidence: {paddle_light.confidence:.0%}", flush=True)
|
||||
print(f" - Amount: {extraction.amount}", flush=True)
|
||||
print(f" - Date: {extraction.receipt_date}", flush=True)
|
||||
print(f" - Number: {extraction.receipt_number}", flush=True)
|
||||
print(f" - CUI: {extraction.cui}", flush=True)
|
||||
print(f" - TVA: {extraction.tva_total} (entries: {len(extraction.tva_entries) if extraction.tva_entries else 0})", flush=True)
|
||||
print(f" - Overall Confidence: {extraction.overall_confidence:.0%}", flush=True)
|
||||
|
||||
# Early exit if complete
|
||||
if self._is_extraction_complete(extraction):
|
||||
extraction.raw_text = "\n\n".join(raw_texts)
|
||||
elapsed_ms = int((time.time() - start_time) * 1000)
|
||||
extraction.processing_time_ms = elapsed_ms
|
||||
print(f"[OCR] ✓✓✓ EARLY EXIT at Step 1 - All fields found! ({elapsed_ms}ms) ✓✓✓", flush=True)
|
||||
return True, "OCR complete (fast mode)", extraction
|
||||
else:
|
||||
print("[OCR] → Step 1 incomplete, continuing to Step 2...", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[OCR] PaddleOCR light failed: {e}", flush=True)
|
||||
extraction = ExtractionResult()
|
||||
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
# STEP 2: PaddleOCR + Heavy (for faded thermal receipts)
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
print("=" * 60, flush=True)
|
||||
print("[OCR] STEP 2: PaddleOCR + Heavy preprocessing", flush=True)
|
||||
print("=" * 60, flush=True)
|
||||
heavy_img = self.preprocessor.preprocess_heavy(image)
|
||||
|
||||
try:
|
||||
paddle_heavy = self.ocr_engine._paddle_recognize(heavy_img)
|
||||
if paddle_heavy and paddle_heavy.text:
|
||||
extraction_heavy = self.extractor.extract(paddle_heavy.text)
|
||||
extraction_heavy.ocr_engine = "paddle-heavy"
|
||||
raw_texts.append(f"═══ PaddleOCR (heavy, conf: {paddle_heavy.confidence:.0%}) ═══\n{paddle_heavy.text}")
|
||||
|
||||
print(f"[OCR] Step 2 (Heavy) Results:", flush=True)
|
||||
print(f" - OCR Confidence: {paddle_heavy.confidence:.0%}", flush=True)
|
||||
print(f" - Amount: {extraction_heavy.amount}", flush=True)
|
||||
print(f" - Date: {extraction_heavy.receipt_date}", flush=True)
|
||||
print(f" - CUI: {extraction_heavy.cui}", flush=True)
|
||||
|
||||
# Merge with previous
|
||||
extraction = self._merge_extractions(extraction, extraction_heavy)
|
||||
|
||||
print(f"[OCR] After merge:", flush=True)
|
||||
print(f" - Amount: {extraction.amount}", flush=True)
|
||||
print(f" - Date: {extraction.receipt_date}", flush=True)
|
||||
print(f" - Number: {extraction.receipt_number}", flush=True)
|
||||
print(f" - CUI: {extraction.cui}", flush=True)
|
||||
print(f" - TVA: {extraction.tva_total}", flush=True)
|
||||
print(f" - Overall Confidence: {extraction.overall_confidence:.0%}", flush=True)
|
||||
|
||||
if self._is_extraction_complete(extraction):
|
||||
extraction.raw_text = "\n\n".join(raw_texts)
|
||||
extraction.ocr_engine = "paddle-adaptive"
|
||||
elapsed_ms = int((time.time() - start_time) * 1000)
|
||||
extraction.processing_time_ms = elapsed_ms
|
||||
print(f"[OCR] ✓✓✓ EARLY EXIT at Step 2 - All fields found after merge! ({elapsed_ms}ms) ✓✓✓", flush=True)
|
||||
return True, "OCR complete (paddle dual)", extraction
|
||||
else:
|
||||
print("[OCR] → Step 2 incomplete, continuing to Step 3 (Tesseract)...", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[OCR] PaddleOCR heavy failed: {e}", flush=True)
|
||||
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
# STEP 3: Tesseract - ONLY to complete missing fields
|
||||
# Uses Tesseract-optimized preprocessing (binarized, high contrast)
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
print("=" * 60, flush=True)
|
||||
print("[OCR] STEP 3: Tesseract (complement only, not override)", flush=True)
|
||||
print("=" * 60, flush=True)
|
||||
|
||||
try:
|
||||
# Use Tesseract-specific preprocessing (Otsu binarization)
|
||||
tesseract_img = self.preprocessor.preprocess_for_tesseract(image)
|
||||
tesseract_result = self.ocr_engine._tesseract_recognize(tesseract_img)
|
||||
if tesseract_result and tesseract_result.text:
|
||||
extraction_tess = self.extractor.extract(tesseract_result.text)
|
||||
extraction_tess.ocr_engine = "tesseract"
|
||||
raw_texts.append(f"═══ Tesseract (conf: {tesseract_result.confidence:.0%}) ═══\n{tesseract_result.text}")
|
||||
|
||||
print(f"[OCR] Step 3 (Tesseract) Results:", flush=True)
|
||||
print(f" - OCR Confidence: {tesseract_result.confidence:.0%}", flush=True)
|
||||
print(f" - Amount: {extraction_tess.amount}", flush=True)
|
||||
print(f" - Date: {extraction_tess.receipt_date}", flush=True)
|
||||
print(f" - CUI: {extraction_tess.cui}", flush=True)
|
||||
|
||||
# IMPORTANT: Tesseract only COMPLETES missing fields, never overrides!
|
||||
extraction = self._complement_extraction(extraction, extraction_tess)
|
||||
except Exception as e:
|
||||
print(f"[OCR] Tesseract failed: {e}", flush=True)
|
||||
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
# FINAL VALIDATION: Fix impossible values
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
if extraction:
|
||||
extraction = self._final_validation(extraction)
|
||||
|
||||
# Final result
|
||||
if extraction is None:
|
||||
return False, "No text detected", None
|
||||
|
||||
extraction.raw_text = "\n\n".join(raw_texts)
|
||||
extraction.ocr_engine = "adaptive-full"
|
||||
|
||||
# Build result message
|
||||
fields_found = []
|
||||
if extraction.amount: fields_found.append("amount")
|
||||
if extraction.receipt_date: fields_found.append("date")
|
||||
if extraction.receipt_number: fields_found.append("number")
|
||||
if extraction.cui: fields_found.append("CUI")
|
||||
if extraction.tva_total or extraction.tva_entries: fields_found.append("TVA")
|
||||
|
||||
message = f"OCR complete (full pipeline). Found: {', '.join(fields_found) or 'no fields'}"
|
||||
|
||||
elapsed_ms = int((time.time() - start_time) * 1000)
|
||||
extraction.processing_time_ms = elapsed_ms
|
||||
|
||||
print("=" * 60, flush=True)
|
||||
print(f"[OCR] FINAL RESULT (full pipeline) - {elapsed_ms}ms", flush=True)
|
||||
print("=" * 60, flush=True)
|
||||
print(f" - Amount: {extraction.amount}", flush=True)
|
||||
print(f" - Date: {extraction.receipt_date}", flush=True)
|
||||
print(f" - Number: {extraction.receipt_number}", flush=True)
|
||||
print(f" - CUI: {extraction.cui}", flush=True)
|
||||
print(f" - TVA: {extraction.tva_total}", flush=True)
|
||||
print(f" - Overall Confidence: {extraction.overall_confidence:.0%}", flush=True)
|
||||
print(f" - Processing Time: {elapsed_ms}ms", flush=True)
|
||||
print(f" - Message: {message}", flush=True)
|
||||
|
||||
return True, message, extraction
|
||||
|
||||
def _merge_extractions(
|
||||
self,
|
||||
paddle: Optional[ExtractionResult],
|
||||
tesseract: Optional[ExtractionResult]
|
||||
) -> ExtractionResult:
|
||||
"""
|
||||
Merge two extractions, picking best fields from each engine.
|
||||
|
||||
Strategy:
|
||||
- For each field, prefer the one with higher confidence
|
||||
- Use validation rules (CUI format, date validity, company indicators)
|
||||
- Combine TVA entries if different
|
||||
"""
|
||||
result = ExtractionResult()
|
||||
|
||||
# Handle case where one is None
|
||||
if paddle is None and tesseract is None:
|
||||
return result
|
||||
if paddle is None:
|
||||
return tesseract
|
||||
if tesseract is None:
|
||||
return paddle
|
||||
|
||||
print("[Merge] Comparing PaddleOCR vs Tesseract extractions...", flush=True)
|
||||
|
||||
# === AMOUNT ===
|
||||
# Pick higher confidence, both must be positive
|
||||
if paddle.amount and tesseract.amount:
|
||||
if paddle.confidence_amount >= tesseract.confidence_amount:
|
||||
result.amount = paddle.amount
|
||||
result.confidence_amount = paddle.confidence_amount
|
||||
print(f"[Merge] Amount: PaddleOCR {paddle.amount} (conf: {paddle.confidence_amount:.0%})", flush=True)
|
||||
else:
|
||||
result.amount = tesseract.amount
|
||||
result.confidence_amount = tesseract.confidence_amount
|
||||
print(f"[Merge] Amount: Tesseract {tesseract.amount} (conf: {tesseract.confidence_amount:.0%})", flush=True)
|
||||
elif paddle.amount:
|
||||
result.amount = paddle.amount
|
||||
result.confidence_amount = paddle.confidence_amount
|
||||
elif tesseract.amount:
|
||||
result.amount = tesseract.amount
|
||||
result.confidence_amount = tesseract.confidence_amount
|
||||
|
||||
# === DATE ===
|
||||
# Pick higher confidence, validate date reasonableness
|
||||
if paddle.receipt_date and tesseract.receipt_date:
|
||||
if paddle.confidence_date >= tesseract.confidence_date:
|
||||
result.receipt_date = paddle.receipt_date
|
||||
result.confidence_date = paddle.confidence_date
|
||||
print(f"[Merge] Date: PaddleOCR {paddle.receipt_date}", flush=True)
|
||||
else:
|
||||
result.receipt_date = tesseract.receipt_date
|
||||
result.confidence_date = tesseract.confidence_date
|
||||
print(f"[Merge] Date: Tesseract {tesseract.receipt_date}", flush=True)
|
||||
elif paddle.receipt_date:
|
||||
result.receipt_date = paddle.receipt_date
|
||||
result.confidence_date = paddle.confidence_date
|
||||
elif tesseract.receipt_date:
|
||||
result.receipt_date = tesseract.receipt_date
|
||||
result.confidence_date = tesseract.confidence_date
|
||||
|
||||
# === VENDOR NAME ===
|
||||
# Prefer one with company indicators (S.R.L., S.A., etc.)
|
||||
paddle_has_indicator = self._has_company_indicator(paddle.partner_name)
|
||||
tesseract_has_indicator = self._has_company_indicator(tesseract.partner_name)
|
||||
|
||||
if paddle.partner_name and tesseract.partner_name:
|
||||
if paddle_has_indicator and not tesseract_has_indicator:
|
||||
result.partner_name = paddle.partner_name
|
||||
result.confidence_vendor = paddle.confidence_vendor
|
||||
print(f"[Merge] Vendor: PaddleOCR '{paddle.partner_name}' (has company indicator)", flush=True)
|
||||
elif tesseract_has_indicator and not paddle_has_indicator:
|
||||
result.partner_name = tesseract.partner_name
|
||||
result.confidence_vendor = tesseract.confidence_vendor
|
||||
print(f"[Merge] Vendor: Tesseract '{tesseract.partner_name}' (has company indicator)", flush=True)
|
||||
elif paddle.confidence_vendor >= tesseract.confidence_vendor:
|
||||
result.partner_name = paddle.partner_name
|
||||
result.confidence_vendor = paddle.confidence_vendor
|
||||
print(f"[Merge] Vendor: PaddleOCR '{paddle.partner_name}' (higher conf)", flush=True)
|
||||
else:
|
||||
result.partner_name = tesseract.partner_name
|
||||
result.confidence_vendor = tesseract.confidence_vendor
|
||||
print(f"[Merge] Vendor: Tesseract '{tesseract.partner_name}' (higher conf)", flush=True)
|
||||
elif paddle.partner_name:
|
||||
result.partner_name = paddle.partner_name
|
||||
result.confidence_vendor = paddle.confidence_vendor
|
||||
elif tesseract.partner_name:
|
||||
result.partner_name = tesseract.partner_name
|
||||
result.confidence_vendor = tesseract.confidence_vendor
|
||||
|
||||
# === CUI (Fiscal Code) ===
|
||||
# Validate format: 6-10 digits, prefer valid one
|
||||
paddle_cui_valid = self._is_valid_cui(paddle.cui)
|
||||
tesseract_cui_valid = self._is_valid_cui(tesseract.cui)
|
||||
|
||||
if paddle.cui and tesseract.cui:
|
||||
if paddle_cui_valid and not tesseract_cui_valid:
|
||||
result.cui = paddle.cui
|
||||
print(f"[Merge] CUI: PaddleOCR {paddle.cui} (valid format)", flush=True)
|
||||
elif tesseract_cui_valid and not paddle_cui_valid:
|
||||
result.cui = tesseract.cui
|
||||
print(f"[Merge] CUI: Tesseract {tesseract.cui} (valid format)", flush=True)
|
||||
else:
|
||||
# Both valid or both invalid - prefer PaddleOCR
|
||||
result.cui = paddle.cui
|
||||
print(f"[Merge] CUI: PaddleOCR {paddle.cui}", flush=True)
|
||||
elif paddle.cui and paddle_cui_valid:
|
||||
result.cui = paddle.cui
|
||||
elif tesseract.cui and tesseract_cui_valid:
|
||||
result.cui = tesseract.cui
|
||||
elif paddle.cui:
|
||||
result.cui = paddle.cui
|
||||
elif tesseract.cui:
|
||||
result.cui = tesseract.cui
|
||||
|
||||
# === TVA ENTRIES ===
|
||||
# Prefer non-empty, use the one with more entries or higher amounts
|
||||
if paddle.tva_entries and tesseract.tva_entries:
|
||||
# Compare: prefer the one with actual amounts (not just 0)
|
||||
paddle_total = sum(e.get('amount', Decimal('0')) for e in paddle.tva_entries)
|
||||
tesseract_total = sum(e.get('amount', Decimal('0')) for e in tesseract.tva_entries)
|
||||
|
||||
if paddle_total >= tesseract_total:
|
||||
result.tva_entries = paddle.tva_entries
|
||||
result.tva_total = paddle.tva_total
|
||||
print(f"[Merge] TVA: PaddleOCR (total: {paddle_total})", flush=True)
|
||||
else:
|
||||
result.tva_entries = tesseract.tva_entries
|
||||
result.tva_total = tesseract.tva_total
|
||||
print(f"[Merge] TVA: Tesseract (total: {tesseract_total})", flush=True)
|
||||
elif paddle.tva_entries:
|
||||
result.tva_entries = paddle.tva_entries
|
||||
result.tva_total = paddle.tva_total
|
||||
elif tesseract.tva_entries:
|
||||
result.tva_entries = tesseract.tva_entries
|
||||
result.tva_total = tesseract.tva_total
|
||||
|
||||
# === OTHER FIELDS ===
|
||||
# Simple preference: paddle > tesseract
|
||||
result.receipt_number = paddle.receipt_number or tesseract.receipt_number
|
||||
result.receipt_series = paddle.receipt_series or tesseract.receipt_series
|
||||
result.receipt_type = paddle.receipt_type or tesseract.receipt_type
|
||||
result.items_count = paddle.items_count or tesseract.items_count
|
||||
result.address = paddle.address or tesseract.address
|
||||
result.description = paddle.description or tesseract.description
|
||||
|
||||
return result
|
||||
|
||||
def _has_company_indicator(self, name: Optional[str]) -> bool:
|
||||
"""Check if vendor name has company type indicator (S.R.L., S.A., etc.)"""
|
||||
if not name:
|
||||
return False
|
||||
name_upper = name.upper()
|
||||
indicators = [
|
||||
r'\bS\.?\s*R\.?\s*L\.?\b',
|
||||
r'\bS\.?\s*A\.?\b',
|
||||
r'\bS\.?\s*N\.?\s*C\.?\b',
|
||||
r'\bP\.?\s*F\.?\s*A\.?\b',
|
||||
r'\bI\.?\s*I\.?\b',
|
||||
r'\bHOLDING\b',
|
||||
r'\bGROUP\b',
|
||||
r'\bCOMPANY\b',
|
||||
]
|
||||
for indicator in indicators:
|
||||
if re.search(indicator, name_upper):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_valid_cui(self, cui: Optional[str]) -> bool:
|
||||
"""Validate CUI format: 6-10 digits."""
|
||||
if not cui:
|
||||
return False
|
||||
# Remove any RO prefix
|
||||
cui_clean = re.sub(r'^RO', '', cui.upper())
|
||||
# Must be 6-10 digits
|
||||
return bool(re.match(r'^\d{6,10}$', cui_clean))
|
||||
|
||||
def _is_extraction_complete(self, ext: ExtractionResult, min_confidence: float = 0.85) -> bool:
|
||||
"""
|
||||
Check if extraction has ALL required fields to skip further processing.
|
||||
|
||||
Required for early exit (ALL must be true):
|
||||
- Overall confidence >= 85%
|
||||
- ALL 5 critical fields present: number, date, amount, TVA, CUI
|
||||
"""
|
||||
# Must have high confidence
|
||||
if ext.overall_confidence < min_confidence:
|
||||
print(f"[OCR] Confidence {ext.overall_confidence:.0%} < {min_confidence:.0%} - continuing", flush=True)
|
||||
return False
|
||||
|
||||
# Check all required fields
|
||||
has_number = bool(ext.receipt_number)
|
||||
has_date = bool(ext.receipt_date)
|
||||
has_amount = bool(ext.amount)
|
||||
has_tva = bool(ext.tva_total) or bool(ext.tva_entries)
|
||||
has_cui = bool(ext.cui)
|
||||
|
||||
missing = []
|
||||
if not has_number: missing.append("number")
|
||||
if not has_date: missing.append("date")
|
||||
if not has_amount: missing.append("amount")
|
||||
if not has_tva: missing.append("TVA")
|
||||
if not has_cui: missing.append("CUI")
|
||||
|
||||
if missing:
|
||||
print(f"[OCR] Missing: {', '.join(missing)} - continuing", flush=True)
|
||||
return False
|
||||
|
||||
print(f"[OCR] ✓ All 5 fields found with {ext.overall_confidence:.0%} confidence", flush=True)
|
||||
return True
|
||||
|
||||
def _complement_extraction(
|
||||
self,
|
||||
primary: Optional[ExtractionResult],
|
||||
secondary: Optional[ExtractionResult]
|
||||
) -> ExtractionResult:
|
||||
"""
|
||||
Complement primary extraction with missing fields from secondary.
|
||||
NEVER overrides existing values - only fills in gaps.
|
||||
|
||||
This is different from _merge_extractions which can override values.
|
||||
"""
|
||||
if primary is None and secondary is None:
|
||||
return ExtractionResult()
|
||||
if primary is None:
|
||||
return secondary
|
||||
if secondary is None:
|
||||
return primary
|
||||
|
||||
print("[Complement] Adding missing fields from Tesseract...", flush=True)
|
||||
|
||||
# Only fill missing amount
|
||||
if not primary.amount and secondary.amount:
|
||||
primary.amount = secondary.amount
|
||||
primary.confidence_amount = secondary.confidence_amount
|
||||
print(f"[Complement] Added amount: {secondary.amount}", flush=True)
|
||||
|
||||
# Only fill missing date
|
||||
if not primary.receipt_date and secondary.receipt_date:
|
||||
primary.receipt_date = secondary.receipt_date
|
||||
primary.confidence_date = secondary.confidence_date
|
||||
print(f"[Complement] Added date: {secondary.receipt_date}", flush=True)
|
||||
|
||||
# Only fill missing vendor
|
||||
if not primary.partner_name and secondary.partner_name:
|
||||
primary.partner_name = secondary.partner_name
|
||||
primary.confidence_vendor = secondary.confidence_vendor
|
||||
print(f"[Complement] Added vendor: {secondary.partner_name}", flush=True)
|
||||
|
||||
# Only fill missing CUI
|
||||
if not primary.cui and secondary.cui and self._is_valid_cui(secondary.cui):
|
||||
primary.cui = secondary.cui
|
||||
print(f"[Complement] Added CUI: {secondary.cui}", flush=True)
|
||||
|
||||
# Only fill missing TVA
|
||||
if not primary.tva_entries and secondary.tva_entries:
|
||||
primary.tva_entries = secondary.tva_entries
|
||||
primary.tva_total = secondary.tva_total
|
||||
print(f"[Complement] Added TVA: {secondary.tva_total}", flush=True)
|
||||
|
||||
# Only fill missing receipt number
|
||||
if not primary.receipt_number and secondary.receipt_number:
|
||||
primary.receipt_number = secondary.receipt_number
|
||||
print(f"[Complement] Added number: {secondary.receipt_number}", flush=True)
|
||||
|
||||
# Only fill missing address
|
||||
if not primary.address and secondary.address:
|
||||
primary.address = secondary.address
|
||||
print(f"[Complement] Added address: {secondary.address}", flush=True)
|
||||
|
||||
return primary
|
||||
|
||||
def _final_validation(self, extraction: ExtractionResult) -> ExtractionResult:
|
||||
"""
|
||||
Final validation and correction of impossible values.
|
||||
|
||||
Key rules:
|
||||
1. TVA cannot be greater than TOTAL (it's always a fraction)
|
||||
2. If TVA > TOTAL, recalculate TOTAL from TVA using known rates
|
||||
3. Validate TVA entries sum equals TVA total
|
||||
"""
|
||||
print("[Final Validation] Checking extracted values...", flush=True)
|
||||
|
||||
# Rule 1: TVA cannot be greater than TOTAL
|
||||
if extraction.tva_total and extraction.amount:
|
||||
if extraction.tva_total > extraction.amount:
|
||||
print(f"[Final Validation] TVA ({extraction.tva_total}) > TOTAL ({extraction.amount}) - IMPOSSIBLE!", flush=True)
|
||||
|
||||
# Calculate TOTAL from TVA using reverse formula:
|
||||
# total = base + tva = tva * (100/rate + 1) = tva * (100 + rate) / rate
|
||||
# For 9% TVA: total = tva * 109 / 9 = tva * 12.11
|
||||
# For 19% TVA: total = tva * 119 / 19 = tva * 6.26
|
||||
# For 21% TVA: total = tva * 121 / 21 = tva * 5.76
|
||||
|
||||
rate = 19 # Default rate assumption
|
||||
if extraction.tva_entries:
|
||||
# Use the rate from the first entry
|
||||
rate = extraction.tva_entries[0].get('percent', 19)
|
||||
|
||||
if rate > 0:
|
||||
# Formula: total = tva * (100 + rate) / rate
|
||||
calculated_total = extraction.tva_total * (Decimal('100') + Decimal(str(rate))) / Decimal(str(rate))
|
||||
calculated_total = calculated_total.quantize(Decimal('0.01'))
|
||||
|
||||
print(f"[Final Validation] Calculated TOTAL from TVA: {calculated_total} (using {rate}% rate)", flush=True)
|
||||
|
||||
extraction.amount = calculated_total
|
||||
extraction.confidence_amount = 0.70 # Lower confidence for calculated value
|
||||
|
||||
# Rule 2: TVA cannot be more than ~25% of total (max Romanian rate is 21%)
|
||||
if extraction.tva_total and extraction.amount:
|
||||
tva_percent = extraction.tva_total / extraction.amount * Decimal('100')
|
||||
if tva_percent > Decimal('25'):
|
||||
print(f"[Final Validation] Warning: TVA is {tva_percent:.1f}% of total - suspicious", flush=True)
|
||||
|
||||
# Rule 3: Validate TVA entries sum
|
||||
if extraction.tva_entries and extraction.tva_total:
|
||||
entries_sum = sum(e.get('amount', Decimal('0')) for e in extraction.tva_entries)
|
||||
tolerance = Decimal('0.05')
|
||||
if abs(entries_sum - extraction.tva_total) > tolerance:
|
||||
print(f"[Final Validation] TVA entries sum ({entries_sum}) != tva_total ({extraction.tva_total})", flush=True)
|
||||
# Use the sum as it's more reliable
|
||||
extraction.tva_total = entries_sum
|
||||
|
||||
print(f"[Final Validation] Done. Amount={extraction.amount}, TVA={extraction.tva_total}", flush=True)
|
||||
return extraction
|
||||
|
||||
|
||||
# Singleton instance
|
||||
ocr_service = OCRService()
|
||||
447
backend/modules/data_entry/services/receipt_service.py
Normal file
447
backend/modules/data_entry/services/receipt_service.py
Normal file
@@ -0,0 +1,447 @@
|
||||
"""Business logic service for receipts workflow."""
|
||||
|
||||
from decimal import Decimal, ROUND_HALF_UP
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from backend.modules.data_entry.db.models.receipt import Receipt, ReceiptStatus, ReceiptDirection
|
||||
from backend.modules.data_entry.db.models.accounting_entry import EntryType
|
||||
from backend.modules.data_entry.db.crud.receipt import ReceiptCRUD
|
||||
from backend.modules.data_entry.db.crud.accounting_entry import AccountingEntryCRUD
|
||||
from backend.modules.data_entry.schemas.receipt import (
|
||||
ReceiptCreate,
|
||||
ReceiptUpdate,
|
||||
ReceiptFilter,
|
||||
ReceiptResponse,
|
||||
ReceiptListResponse,
|
||||
AccountingEntryCreate,
|
||||
)
|
||||
from backend.modules.data_entry.services.expense_types import EXPENSE_TYPES, get_expense_type
|
||||
|
||||
|
||||
# Payment mode to accounting account mapping
|
||||
PAYMENT_MODE_ACCOUNTS = {
|
||||
'casa': ('5311', 'Casa in lei'),
|
||||
'banca': ('5121', 'Conturi la banci in lei'),
|
||||
'avans_decontare': ('542', 'Avansuri de trezorerie'),
|
||||
}
|
||||
|
||||
|
||||
class ReceiptService:
|
||||
"""Service for receipt business logic and workflow."""
|
||||
|
||||
@staticmethod
|
||||
async def create_receipt(
|
||||
session: AsyncSession,
|
||||
data: ReceiptCreate,
|
||||
created_by: str,
|
||||
) -> Receipt:
|
||||
"""Create a new receipt in DRAFT status."""
|
||||
return await ReceiptCRUD.create(session, data, created_by)
|
||||
|
||||
@staticmethod
|
||||
async def get_receipt(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
) -> Optional[Receipt]:
|
||||
"""Get receipt by ID with all relationships."""
|
||||
return await ReceiptCRUD.get_by_id(session, receipt_id, include_relations=True)
|
||||
|
||||
@staticmethod
|
||||
async def get_receipts(
|
||||
session: AsyncSession,
|
||||
filters: ReceiptFilter,
|
||||
) -> ReceiptListResponse:
|
||||
"""Get paginated list of receipts."""
|
||||
receipts, total = await ReceiptCRUD.get_list(session, filters)
|
||||
|
||||
pages = (total + filters.page_size - 1) // filters.page_size if total > 0 else 1
|
||||
|
||||
return ReceiptListResponse(
|
||||
items=[ReceiptResponse.model_validate(r) for r in receipts],
|
||||
total=total,
|
||||
page=filters.page,
|
||||
page_size=filters.page_size,
|
||||
pages=pages,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_receipt(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
data: ReceiptUpdate,
|
||||
username: str,
|
||||
) -> Tuple[bool, str, Optional[Receipt]]:
|
||||
"""
|
||||
Update receipt (only DRAFT status).
|
||||
Returns (success, message, receipt).
|
||||
"""
|
||||
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
|
||||
|
||||
if not receipt:
|
||||
return False, "Receipt not found", None
|
||||
|
||||
if not await ReceiptCRUD.can_edit(receipt, username):
|
||||
return False, "Cannot edit this receipt", None
|
||||
|
||||
updated = await ReceiptCRUD.update(session, receipt, data)
|
||||
return True, "Receipt updated", updated
|
||||
|
||||
@staticmethod
|
||||
async def delete_receipt(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
username: str,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
Delete receipt (only DRAFT status).
|
||||
Returns (success, message).
|
||||
"""
|
||||
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
|
||||
|
||||
if not receipt:
|
||||
return False, "Receipt not found"
|
||||
|
||||
if not await ReceiptCRUD.can_delete(receipt, username):
|
||||
return False, "Cannot delete this receipt"
|
||||
|
||||
await ReceiptCRUD.delete(session, receipt)
|
||||
return True, "Receipt deleted"
|
||||
|
||||
@staticmethod
|
||||
def generate_accounting_entries(receipt: Receipt) -> List[AccountingEntryCreate]:
|
||||
"""
|
||||
Generate accounting entries based on receipt data and expense type.
|
||||
"""
|
||||
entries: List[AccountingEntryCreate] = []
|
||||
|
||||
# Get expense type configuration
|
||||
expense_type = get_expense_type(receipt.expense_type_code or "OTHER")
|
||||
if not expense_type:
|
||||
expense_type = EXPENSE_TYPES["OTHER"]
|
||||
|
||||
amount = Decimal(str(receipt.amount))
|
||||
|
||||
if receipt.direction == ReceiptDirection.CHELTUIALA:
|
||||
# Expense: Debit expense account, Credit cash/bank
|
||||
if expense_type.has_vat:
|
||||
# Calculate net and VAT
|
||||
vat_rate = expense_type.vat_percent / Decimal("100")
|
||||
net_amount = (amount / (1 + vat_rate)).quantize(
|
||||
Decimal("0.01"), rounding=ROUND_HALF_UP
|
||||
)
|
||||
vat_amount = amount - net_amount
|
||||
|
||||
# Debit: Expense account (net)
|
||||
entries.append(AccountingEntryCreate(
|
||||
entry_type=EntryType.DEBIT,
|
||||
account_code=expense_type.account_code,
|
||||
account_name=expense_type.account_name,
|
||||
amount=net_amount,
|
||||
))
|
||||
|
||||
# Debit: VAT deductible
|
||||
entries.append(AccountingEntryCreate(
|
||||
entry_type=EntryType.DEBIT,
|
||||
account_code=expense_type.vat_account,
|
||||
account_name="TVA deductibila",
|
||||
amount=vat_amount,
|
||||
))
|
||||
else:
|
||||
# No VAT - full amount to expense
|
||||
entries.append(AccountingEntryCreate(
|
||||
entry_type=EntryType.DEBIT,
|
||||
account_code=expense_type.account_code,
|
||||
account_name=expense_type.account_name,
|
||||
amount=amount,
|
||||
))
|
||||
|
||||
# Credit entry - based on payment_mode (new) or cash_register (legacy)
|
||||
if receipt.payment_mode and receipt.payment_mode in PAYMENT_MODE_ACCOUNTS:
|
||||
credit_account, credit_name = PAYMENT_MODE_ACCOUNTS[receipt.payment_mode]
|
||||
elif receipt.cash_register_account:
|
||||
# Backwards compatibility for existing receipts
|
||||
credit_account = receipt.cash_register_account
|
||||
credit_name = receipt.cash_register_name or "Casa/Banca"
|
||||
else:
|
||||
# Default fallback
|
||||
credit_account = "5311"
|
||||
credit_name = "Casa in lei"
|
||||
|
||||
entries.append(AccountingEntryCreate(
|
||||
entry_type=EntryType.CREDIT,
|
||||
account_code=credit_account,
|
||||
account_name=credit_name,
|
||||
amount=amount,
|
||||
))
|
||||
|
||||
else:
|
||||
# Income: Debit cash/bank, Credit income account
|
||||
# Based on payment_mode (new) or cash_register (legacy)
|
||||
if receipt.payment_mode and receipt.payment_mode in PAYMENT_MODE_ACCOUNTS:
|
||||
cash_account, cash_name = PAYMENT_MODE_ACCOUNTS[receipt.payment_mode]
|
||||
elif receipt.cash_register_account:
|
||||
cash_account = receipt.cash_register_account
|
||||
cash_name = receipt.cash_register_name or "Casa/Banca"
|
||||
else:
|
||||
cash_account = "5311"
|
||||
cash_name = "Casa in lei"
|
||||
|
||||
# Debit: Cash/Bank
|
||||
entries.append(AccountingEntryCreate(
|
||||
entry_type=EntryType.DEBIT,
|
||||
account_code=cash_account,
|
||||
account_name=cash_name,
|
||||
amount=amount,
|
||||
))
|
||||
|
||||
# Credit: Income account (7xx - to be configured)
|
||||
entries.append(AccountingEntryCreate(
|
||||
entry_type=EntryType.CREDIT,
|
||||
account_code="7588",
|
||||
account_name="Alte venituri din exploatare",
|
||||
amount=amount,
|
||||
))
|
||||
|
||||
return entries
|
||||
|
||||
@staticmethod
|
||||
async def submit_for_review(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
username: str,
|
||||
) -> Tuple[bool, str, Optional[Receipt]]:
|
||||
"""
|
||||
Submit receipt for review (DRAFT/REJECTED → PENDING_REVIEW).
|
||||
Generates accounting entries automatically.
|
||||
"""
|
||||
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
|
||||
|
||||
if not receipt:
|
||||
return False, "Receipt not found", None
|
||||
|
||||
if not await ReceiptCRUD.can_submit(receipt, username):
|
||||
return False, "Cannot submit this receipt", None
|
||||
|
||||
# Check if receipt has at least one attachment
|
||||
if not receipt.attachments:
|
||||
return False, "Receipt must have at least one attachment", None
|
||||
|
||||
# Check required fields
|
||||
if not receipt.expense_type_code:
|
||||
return False, "Expense type is required", None
|
||||
|
||||
# Validate payment_mode or cash_register (backwards compatibility)
|
||||
if not receipt.payment_mode and not receipt.cash_register_account:
|
||||
return False, "Modul de plata este obligatoriu", None
|
||||
|
||||
# Generate accounting entries
|
||||
entries = ReceiptService.generate_accounting_entries(receipt)
|
||||
|
||||
# Delete existing entries and create new ones
|
||||
await AccountingEntryCRUD.delete_all_for_receipt(session, receipt_id)
|
||||
await AccountingEntryCRUD.create_bulk(session, receipt_id, entries, is_auto_generated=True)
|
||||
|
||||
# Refresh receipt to clear stale relationship references after entry deletion
|
||||
await session.refresh(receipt)
|
||||
|
||||
# Update status
|
||||
updated = await ReceiptCRUD.update_status(
|
||||
session, receipt, ReceiptStatus.PENDING_REVIEW
|
||||
)
|
||||
|
||||
# Reload with entries
|
||||
updated = await ReceiptCRUD.get_by_id(session, receipt_id)
|
||||
|
||||
return True, "Receipt submitted for review", updated
|
||||
|
||||
@staticmethod
|
||||
async def approve_receipt(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
username: str,
|
||||
) -> Tuple[bool, str, Optional[Receipt]]:
|
||||
"""
|
||||
Approve receipt (PENDING_REVIEW → APPROVED).
|
||||
Requires valid CUI (fiscal code) for approval.
|
||||
"""
|
||||
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
|
||||
|
||||
if not receipt:
|
||||
return False, "Receipt not found", None
|
||||
|
||||
if receipt.status != ReceiptStatus.PENDING_REVIEW:
|
||||
return False, "Receipt is not pending review", None
|
||||
|
||||
# Validate CUI is present (required for Oracle import)
|
||||
if not receipt.cui:
|
||||
return False, "Trebuie completat codul fiscal (CUI) pentru aprobare", None
|
||||
|
||||
# Validate accounting entries
|
||||
if not receipt.entries:
|
||||
return False, "Receipt has no accounting entries", None
|
||||
|
||||
# Update status
|
||||
updated = await ReceiptCRUD.update_status(
|
||||
session, receipt, ReceiptStatus.APPROVED, reviewed_by=username
|
||||
)
|
||||
|
||||
return True, "Receipt approved", updated
|
||||
|
||||
@staticmethod
|
||||
async def unapprove_receipt(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
username: str,
|
||||
) -> Tuple[bool, str, Optional[Receipt]]:
|
||||
"""
|
||||
Unapprove receipt (APPROVED → PENDING_REVIEW).
|
||||
Returns receipt to pending review for corrections.
|
||||
"""
|
||||
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
|
||||
|
||||
if not receipt:
|
||||
return False, "Receipt not found", None
|
||||
|
||||
if receipt.status != ReceiptStatus.APPROVED:
|
||||
return False, "Receipt is not approved", None
|
||||
|
||||
# Update status back to pending review
|
||||
updated = await ReceiptCRUD.update_status(
|
||||
session, receipt, ReceiptStatus.PENDING_REVIEW
|
||||
)
|
||||
|
||||
return True, "Receipt returned to pending review", updated
|
||||
|
||||
@staticmethod
|
||||
async def reject_receipt(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
username: str,
|
||||
reason: str,
|
||||
) -> Tuple[bool, str, Optional[Receipt]]:
|
||||
"""
|
||||
Reject receipt (PENDING_REVIEW → REJECTED).
|
||||
"""
|
||||
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
|
||||
|
||||
if not receipt:
|
||||
return False, "Receipt not found", None
|
||||
|
||||
if receipt.status != ReceiptStatus.PENDING_REVIEW:
|
||||
return False, "Receipt is not pending review", None
|
||||
|
||||
# Update status
|
||||
updated = await ReceiptCRUD.update_status(
|
||||
session,
|
||||
receipt,
|
||||
ReceiptStatus.REJECTED,
|
||||
reviewed_by=username,
|
||||
rejection_reason=reason,
|
||||
)
|
||||
|
||||
return True, "Receipt rejected", updated
|
||||
|
||||
@staticmethod
|
||||
async def resubmit_receipt(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
username: str,
|
||||
) -> Tuple[bool, str, Optional[Receipt]]:
|
||||
"""
|
||||
Resubmit rejected receipt after corrections (REJECTED → PENDING_REVIEW).
|
||||
"""
|
||||
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
|
||||
|
||||
if not receipt:
|
||||
return False, "Receipt not found", None
|
||||
|
||||
if receipt.status != ReceiptStatus.REJECTED:
|
||||
return False, "Receipt is not rejected", None
|
||||
|
||||
if receipt.created_by != username:
|
||||
return False, "Only the creator can resubmit", None
|
||||
|
||||
# Re-generate accounting entries
|
||||
entries = ReceiptService.generate_accounting_entries(receipt)
|
||||
await AccountingEntryCRUD.delete_all_for_receipt(session, receipt_id)
|
||||
await AccountingEntryCRUD.create_bulk(session, receipt_id, entries, is_auto_generated=True)
|
||||
|
||||
# Refresh receipt to clear stale relationship references after entry deletion
|
||||
await session.refresh(receipt)
|
||||
|
||||
# Update status
|
||||
updated = await ReceiptCRUD.update_status(
|
||||
session, receipt, ReceiptStatus.PENDING_REVIEW
|
||||
)
|
||||
|
||||
# Reload with entries
|
||||
updated = await ReceiptCRUD.get_by_id(session, receipt_id)
|
||||
|
||||
return True, "Receipt resubmitted for review", updated
|
||||
|
||||
@staticmethod
|
||||
async def regenerate_entries(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
username: str,
|
||||
) -> Tuple[bool, str, List[AccountingEntryCreate]]:
|
||||
"""
|
||||
Regenerate accounting entries for a receipt.
|
||||
"""
|
||||
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
|
||||
|
||||
if not receipt:
|
||||
return False, "Receipt not found", []
|
||||
|
||||
if receipt.status not in [ReceiptStatus.DRAFT, ReceiptStatus.PENDING_REVIEW]:
|
||||
return False, "Cannot regenerate entries for this receipt status", []
|
||||
|
||||
# Generate new entries
|
||||
entries = ReceiptService.generate_accounting_entries(receipt)
|
||||
|
||||
# Replace existing entries
|
||||
await AccountingEntryCRUD.delete_all_for_receipt(session, receipt_id)
|
||||
await AccountingEntryCRUD.create_bulk(session, receipt_id, entries, is_auto_generated=True)
|
||||
|
||||
return True, "Entries regenerated", entries
|
||||
|
||||
@staticmethod
|
||||
async def update_entries(
|
||||
session: AsyncSession,
|
||||
receipt_id: int,
|
||||
entries: List[AccountingEntryCreate],
|
||||
username: str,
|
||||
) -> Tuple[bool, str, List]:
|
||||
"""
|
||||
Update accounting entries for a receipt (accountant action).
|
||||
"""
|
||||
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
|
||||
|
||||
if not receipt:
|
||||
return False, "Receipt not found", []
|
||||
|
||||
if receipt.status != ReceiptStatus.PENDING_REVIEW:
|
||||
return False, "Can only modify entries for receipts pending review", []
|
||||
|
||||
# Validate entries
|
||||
is_valid, error = await AccountingEntryCRUD.validate_entries(entries)
|
||||
if not is_valid:
|
||||
return False, error, []
|
||||
|
||||
# Replace entries
|
||||
updated_entries = await AccountingEntryCRUD.replace_all_for_receipt(
|
||||
session, receipt_id, entries, username
|
||||
)
|
||||
|
||||
return True, "Entries updated", updated_entries
|
||||
|
||||
@staticmethod
|
||||
async def get_pending_count(
|
||||
session: AsyncSession,
|
||||
company_id: Optional[int] = None,
|
||||
) -> int:
|
||||
"""Get count of receipts pending review."""
|
||||
receipts = await ReceiptCRUD.get_pending_review(session, company_id)
|
||||
return len(receipts)
|
||||
406
backend/modules/data_entry/services/sync_service.py
Normal file
406
backend/modules/data_entry/services/sync_service.py
Normal file
@@ -0,0 +1,406 @@
|
||||
"""Service for syncing nomenclatures from Oracle to SQLite."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Tuple
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
# Path setup handled by main.py - this is redundant
|
||||
# project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||
# sys.path.insert(0, str(project_root / "shared"))
|
||||
|
||||
from shared.database.oracle_pool import oracle_pool
|
||||
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier, SyncedCashRegister
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache for schema lookups (populated dynamically from Oracle)
|
||||
_schema_cache: dict[int, str] = {}
|
||||
|
||||
|
||||
class SyncService:
|
||||
"""Service for syncing nomenclatures from Oracle."""
|
||||
|
||||
@staticmethod
|
||||
async def get_schema_for_company(company_id: int) -> Optional[str]:
|
||||
"""
|
||||
Get Oracle schema for company ID from V_NOM_FIRME view.
|
||||
Results are cached in memory for performance.
|
||||
"""
|
||||
# Check cache first
|
||||
if company_id in _schema_cache:
|
||||
return _schema_cache[company_id]
|
||||
|
||||
try:
|
||||
async with oracle_pool.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute("""
|
||||
SELECT SCHEMA
|
||||
FROM CONTAFIN_ORACLE.V_NOM_FIRME
|
||||
WHERE ID_FIRMA = :company_id
|
||||
""", {'company_id': company_id})
|
||||
result = cursor.fetchone()
|
||||
|
||||
if result:
|
||||
schema = result[0]
|
||||
_schema_cache[company_id] = schema
|
||||
logger.info(f"Resolved schema for company {company_id}: {schema}")
|
||||
return schema
|
||||
else:
|
||||
logger.warning(f"No schema found for company {company_id}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching schema for company {company_id}: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def sync_suppliers(session: AsyncSession, company_id: int) -> Tuple[int, int]:
|
||||
"""
|
||||
Sync suppliers (furnizori, id_tip_part=17) from Oracle to SQLite.
|
||||
Uses CORESP_TIP_PART joined with VNOM_PARTENERI view.
|
||||
Returns (synced_count, error_count).
|
||||
"""
|
||||
schema = await SyncService.get_schema_for_company(company_id)
|
||||
if not schema:
|
||||
logger.warning(f"No schema mapping for company {company_id}")
|
||||
return 0, 0
|
||||
|
||||
synced = 0
|
||||
errors = 0
|
||||
|
||||
try:
|
||||
async with oracle_pool.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
# Fetch active suppliers from Oracle
|
||||
# id_tip_part = 17 means "furnizori" (suppliers)
|
||||
# Using CORESP_TIP_PART to filter by partner type
|
||||
cursor.execute(f"""
|
||||
SELECT B.ID_PART, B.DENUMIRE, B.COD_FISCAL, B.ADRESA
|
||||
FROM {schema}.CORESP_TIP_PART A
|
||||
INNER JOIN {schema}.VNOM_PARTENERI B ON A.ID_PART = B.ID_PART
|
||||
WHERE A.ID_TIP_PART = 17
|
||||
AND (B.INACTIV = 0 OR B.INACTIV IS NULL)
|
||||
AND B.ID_PART IS NOT NULL
|
||||
ORDER BY B.DENUMIRE
|
||||
""")
|
||||
rows = cursor.fetchall()
|
||||
|
||||
for row in rows:
|
||||
try:
|
||||
oracle_id, name, fiscal_code, address = row
|
||||
|
||||
# Check if already exists
|
||||
stmt = select(SyncedSupplier).where(
|
||||
SyncedSupplier.oracle_id == oracle_id,
|
||||
SyncedSupplier.company_id == company_id
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
# Update existing record
|
||||
existing.name = name or ""
|
||||
existing.fiscal_code = fiscal_code
|
||||
existing.address = address
|
||||
existing.synced_at = datetime.utcnow()
|
||||
logger.debug(f"Updated supplier {oracle_id}: {name}")
|
||||
else:
|
||||
# Create new record
|
||||
supplier = SyncedSupplier(
|
||||
oracle_id=oracle_id,
|
||||
company_id=company_id,
|
||||
name=name or "",
|
||||
fiscal_code=fiscal_code,
|
||||
address=address,
|
||||
)
|
||||
session.add(supplier)
|
||||
logger.debug(f"Created supplier {oracle_id}: {name}")
|
||||
|
||||
synced += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing supplier row {row}: {e}")
|
||||
errors += 1
|
||||
|
||||
# Commit all changes
|
||||
await session.commit()
|
||||
logger.info(f"Synced {synced} suppliers for company {company_id}, {errors} errors")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error syncing suppliers for company {company_id}: {e}")
|
||||
errors += 1
|
||||
await session.rollback()
|
||||
|
||||
return synced, errors
|
||||
|
||||
@staticmethod
|
||||
async def sync_cash_registers(session: AsyncSession, company_id: int) -> Tuple[int, int]:
|
||||
"""
|
||||
Sync cash registers and bank accounts from Oracle to SQLite.
|
||||
Returns (synced_count, error_count).
|
||||
|
||||
Uses CORESP_TIP_PART with:
|
||||
- id_tip_part = 22: CASA LEI
|
||||
- id_tip_part = 23: CASA VALUTA
|
||||
- id_tip_part = 24: BANCA LEI
|
||||
- id_tip_part = 25: BANCA VALUTA
|
||||
"""
|
||||
schema = await SyncService.get_schema_for_company(company_id)
|
||||
if not schema:
|
||||
logger.warning(f"No schema mapping for company {company_id}")
|
||||
return 0, 0
|
||||
|
||||
synced = 0
|
||||
errors = 0
|
||||
|
||||
# Partner types mapping
|
||||
# 22=CASA LEI, 23=CASA VALUTA -> cash
|
||||
# 24=BANCA LEI, 25=BANCA VALUTA -> bank
|
||||
partner_types = [22, 23, 24, 25]
|
||||
|
||||
try:
|
||||
async with oracle_pool.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
# Fetch cash/bank partners from CORESP_TIP_PART
|
||||
cursor.execute(f"""
|
||||
SELECT B.ID_PART, B.DENUMIRE, A.ID_TIP_PART
|
||||
FROM {schema}.CORESP_TIP_PART A
|
||||
INNER JOIN {schema}.VNOM_PARTENERI B ON A.ID_PART = B.ID_PART
|
||||
WHERE A.ID_TIP_PART IN (22, 23, 24, 25)
|
||||
AND (B.INACTIV = 0 OR B.INACTIV IS NULL)
|
||||
AND B.ID_PART IS NOT NULL
|
||||
ORDER BY A.ID_TIP_PART, B.DENUMIRE
|
||||
""")
|
||||
rows = cursor.fetchall()
|
||||
|
||||
# Type mapping: 22=CASA LEI, 23=CASA VALUTA -> cash; 24=BANCA LEI, 25=BANCA VALUTA -> bank
|
||||
type_mapping = {
|
||||
22: ("cash", "CASA_LEI"),
|
||||
23: ("cash", "CASA_VALUTA"),
|
||||
24: ("bank", "BANCA_LEI"),
|
||||
25: ("bank", "BANCA_VALUTA"),
|
||||
}
|
||||
|
||||
for row in rows:
|
||||
try:
|
||||
oracle_id, name, tip_part_id = row
|
||||
|
||||
# Determine type based on partner type
|
||||
register_type, account_code = type_mapping.get(tip_part_id, ("cash", "UNKNOWN"))
|
||||
|
||||
# Check if already exists
|
||||
stmt = select(SyncedCashRegister).where(
|
||||
SyncedCashRegister.oracle_id == oracle_id,
|
||||
SyncedCashRegister.company_id == company_id
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
# Update existing record
|
||||
existing.name = name or ""
|
||||
existing.account_code = account_code
|
||||
existing.register_type = register_type
|
||||
existing.synced_at = datetime.utcnow()
|
||||
logger.debug(f"Updated cash register {oracle_id}: {name}")
|
||||
else:
|
||||
# Create new record
|
||||
cash_register = SyncedCashRegister(
|
||||
oracle_id=oracle_id,
|
||||
company_id=company_id,
|
||||
name=name or "",
|
||||
account_code=account_code,
|
||||
register_type=register_type,
|
||||
)
|
||||
session.add(cash_register)
|
||||
logger.debug(f"Created cash register {oracle_id}: {name}")
|
||||
|
||||
synced += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing cash register row {row}: {e}")
|
||||
errors += 1
|
||||
|
||||
# Commit all changes
|
||||
await session.commit()
|
||||
logger.info(f"Synced {synced} cash registers for company {company_id}, {errors} errors")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error syncing cash registers for company {company_id}: {e}")
|
||||
errors += 1
|
||||
await session.rollback()
|
||||
|
||||
return synced, errors
|
||||
|
||||
@staticmethod
|
||||
async def search_supplier(
|
||||
session: AsyncSession,
|
||||
company_id: int,
|
||||
fiscal_code: Optional[str] = None,
|
||||
name: Optional[str] = None
|
||||
) -> Tuple[bool, Optional[dict], str]:
|
||||
"""
|
||||
Search for supplier in SQLite first, then Oracle if not found.
|
||||
Returns (found, supplier_data, source).
|
||||
Source can be: 'synced', 'local', 'not_found'
|
||||
"""
|
||||
# 1. Search in synced suppliers
|
||||
if fiscal_code:
|
||||
stmt = select(SyncedSupplier).where(
|
||||
SyncedSupplier.company_id == company_id,
|
||||
SyncedSupplier.fiscal_code == fiscal_code
|
||||
)
|
||||
elif name:
|
||||
stmt = select(SyncedSupplier).where(
|
||||
SyncedSupplier.company_id == company_id,
|
||||
SyncedSupplier.name.ilike(f"%{name}%")
|
||||
)
|
||||
else:
|
||||
return False, None, "no_query"
|
||||
|
||||
result = await session.execute(stmt)
|
||||
supplier = result.scalar_one_or_none()
|
||||
|
||||
if supplier:
|
||||
# Return only text data - no IDs needed for autocomplete
|
||||
return True, {
|
||||
"name": supplier.name,
|
||||
"fiscal_code": supplier.fiscal_code,
|
||||
"address": supplier.address,
|
||||
}, "synced"
|
||||
|
||||
# 2. Search in local suppliers
|
||||
if fiscal_code:
|
||||
stmt = select(LocalSupplier).where(
|
||||
LocalSupplier.company_id == company_id,
|
||||
LocalSupplier.fiscal_code == fiscal_code
|
||||
)
|
||||
elif name:
|
||||
stmt = select(LocalSupplier).where(
|
||||
LocalSupplier.company_id == company_id,
|
||||
LocalSupplier.name.ilike(f"%{name}%")
|
||||
)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
local = result.scalar_one_or_none()
|
||||
|
||||
if local:
|
||||
# Return only text data - no IDs needed for autocomplete
|
||||
return True, {
|
||||
"name": local.name,
|
||||
"fiscal_code": local.fiscal_code,
|
||||
"address": local.address,
|
||||
}, "local"
|
||||
|
||||
# 3. Try live Oracle search (optional fallback for unsynced data)
|
||||
# This is a fallback - ideally sync should be up to date
|
||||
# TODO: Implement live Oracle search if needed
|
||||
|
||||
return False, None, "not_found"
|
||||
|
||||
@staticmethod
|
||||
async def create_local_supplier(
|
||||
session: AsyncSession,
|
||||
company_id: int,
|
||||
name: str,
|
||||
fiscal_code: Optional[str],
|
||||
address: Optional[str],
|
||||
created_by: str
|
||||
) -> LocalSupplier:
|
||||
"""Create a local supplier entry from OCR data."""
|
||||
supplier = LocalSupplier(
|
||||
company_id=company_id,
|
||||
name=name,
|
||||
fiscal_code=fiscal_code,
|
||||
address=address,
|
||||
created_by=created_by,
|
||||
)
|
||||
session.add(supplier)
|
||||
await session.commit()
|
||||
await session.refresh(supplier)
|
||||
logger.info(f"Created local supplier: {name} (CUI: {fiscal_code})")
|
||||
return supplier
|
||||
|
||||
@staticmethod
|
||||
async def get_all_suppliers(
|
||||
session: AsyncSession,
|
||||
company_id: int,
|
||||
search: Optional[str] = None
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Get all suppliers (synced + local) for a company.
|
||||
Used for dropdown/autocomplete in UI.
|
||||
"""
|
||||
suppliers = []
|
||||
|
||||
# Get synced suppliers
|
||||
stmt = select(SyncedSupplier).where(SyncedSupplier.company_id == company_id)
|
||||
if search:
|
||||
stmt = stmt.where(
|
||||
(SyncedSupplier.name.ilike(f"%{search}%")) |
|
||||
(SyncedSupplier.fiscal_code.ilike(f"%{search}%"))
|
||||
)
|
||||
stmt = stmt.limit(50) # Limit results for performance
|
||||
|
||||
result = await session.execute(stmt)
|
||||
synced = result.scalars().all()
|
||||
|
||||
for s in synced:
|
||||
suppliers.append({
|
||||
"id": s.id,
|
||||
"oracle_id": s.oracle_id,
|
||||
"name": s.name,
|
||||
"fiscal_code": s.fiscal_code,
|
||||
"source": "synced"
|
||||
})
|
||||
|
||||
# Get local suppliers
|
||||
stmt = select(LocalSupplier).where(LocalSupplier.company_id == company_id)
|
||||
if search:
|
||||
stmt = stmt.where(
|
||||
(LocalSupplier.name.ilike(f"%{search}%")) |
|
||||
(LocalSupplier.fiscal_code.ilike(f"%{search}%"))
|
||||
)
|
||||
stmt = stmt.limit(50)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
local = result.scalars().all()
|
||||
|
||||
for l in local:
|
||||
suppliers.append({
|
||||
"id": l.id,
|
||||
"name": l.name,
|
||||
"fiscal_code": l.fiscal_code,
|
||||
"source": "local"
|
||||
})
|
||||
|
||||
return suppliers
|
||||
|
||||
@staticmethod
|
||||
async def get_all_cash_registers(
|
||||
session: AsyncSession,
|
||||
company_id: int
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Get all cash registers for a company.
|
||||
Used for dropdown in UI.
|
||||
"""
|
||||
stmt = select(SyncedCashRegister).where(SyncedCashRegister.company_id == company_id)
|
||||
result = await session.execute(stmt)
|
||||
registers = result.scalars().all()
|
||||
|
||||
return [
|
||||
{
|
||||
"id": r.id,
|
||||
"oracle_id": r.oracle_id,
|
||||
"name": r.name,
|
||||
"account_code": r.account_code,
|
||||
"register_type": r.register_type
|
||||
}
|
||||
for r in registers
|
||||
]
|
||||
0
backend/modules/reports/__init__.py
Normal file
0
backend/modules/reports/__init__.py
Normal file
66
backend/modules/reports/cache/__init__.py
vendored
Normal file
66
backend/modules/reports/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
Cache module for ROA2WEB
|
||||
|
||||
Provides hybrid two-tier caching (Memory L1 + SQLite L2)
|
||||
with performance tracking and event-based invalidation.
|
||||
|
||||
Usage:
|
||||
# Initialize cache at app startup
|
||||
from app.cache import init_cache
|
||||
from app.cache.config import CacheConfig
|
||||
|
||||
config = CacheConfig.from_env()
|
||||
await init_cache(config)
|
||||
|
||||
# Use @cached decorator in services
|
||||
from app.cache.decorators import cached
|
||||
|
||||
@cached(cache_type='dashboard_summary', key_params=['company', 'username'])
|
||||
async def get_complete_summary(company: str, username: str):
|
||||
# ... Oracle query logic ...
|
||||
|
||||
# Get cache manager for manual operations
|
||||
from app.cache import get_cache
|
||||
|
||||
cache = get_cache()
|
||||
await cache.invalidate(company_id=123)
|
||||
"""
|
||||
|
||||
from .config import CacheConfig
|
||||
from .cache_manager import (
|
||||
init_cache,
|
||||
get_cache,
|
||||
close_cache,
|
||||
CacheManager
|
||||
)
|
||||
from .decorators import cached
|
||||
from .event_monitor import (
|
||||
init_event_monitor,
|
||||
get_event_monitor,
|
||||
toggle_event_monitor,
|
||||
preload_all_schema_mappings
|
||||
)
|
||||
from .benchmarks import run_baseline_benchmarks
|
||||
|
||||
__all__ = [
|
||||
# Configuration
|
||||
'CacheConfig',
|
||||
|
||||
# Cache Manager
|
||||
'init_cache',
|
||||
'get_cache',
|
||||
'close_cache',
|
||||
'CacheManager',
|
||||
|
||||
# Decorators
|
||||
'cached',
|
||||
|
||||
# Event Monitor
|
||||
'init_event_monitor',
|
||||
'get_event_monitor',
|
||||
'toggle_event_monitor',
|
||||
'preload_all_schema_mappings',
|
||||
|
||||
# Benchmarks
|
||||
'run_baseline_benchmarks',
|
||||
]
|
||||
269
backend/modules/reports/cache/benchmarks.py
vendored
Normal file
269
backend/modules/reports/cache/benchmarks.py
vendored
Normal file
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
Baseline performance benchmarking
|
||||
|
||||
Runs at startup to establish baseline Oracle query times
|
||||
Used for calculating "time saved" by cache
|
||||
"""
|
||||
import time
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def run_baseline_benchmarks() -> Dict[str, float]:
|
||||
"""
|
||||
Run baseline benchmarks for Oracle queries (without cache)
|
||||
|
||||
Measures typical query times to establish performance baselines
|
||||
These are used to calculate time saved when cache hits occur
|
||||
|
||||
NOTE: This implementation provides a framework. Actual benchmark
|
||||
implementations need access to Oracle services and sample data.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping cache_type to average query time (ms)
|
||||
"""
|
||||
from .cache_manager import get_cache
|
||||
|
||||
cache = get_cache()
|
||||
if not cache:
|
||||
logger.warning("Cache not initialized - skipping benchmarks")
|
||||
return {}
|
||||
|
||||
logger.info("Starting baseline performance benchmarks...")
|
||||
benchmarks = {}
|
||||
|
||||
try:
|
||||
# Benchmark: Schema lookup
|
||||
logger.info("Benchmarking: schema lookup")
|
||||
schema_times = await _benchmark_schema_lookup()
|
||||
if schema_times:
|
||||
avg_schema = sum(schema_times) / len(schema_times)
|
||||
benchmarks['schema'] = avg_schema
|
||||
await cache.sqlite.set_benchmark('schema', avg_schema, len(schema_times))
|
||||
logger.info(f" Schema lookup: {avg_schema:.2f}ms (avg of {len(schema_times)} samples)")
|
||||
|
||||
# Benchmark: Companies list
|
||||
logger.info("Benchmarking: companies list")
|
||||
companies_time = await _benchmark_companies_list()
|
||||
if companies_time:
|
||||
benchmarks['companies'] = companies_time
|
||||
await cache.sqlite.set_benchmark('companies', companies_time, 1)
|
||||
logger.info(f" Companies list: {companies_time:.2f}ms")
|
||||
|
||||
# Benchmark: Dashboard summary
|
||||
logger.info("Benchmarking: dashboard summary")
|
||||
dashboard_time = await _benchmark_dashboard_summary()
|
||||
if dashboard_time:
|
||||
benchmarks['dashboard_summary'] = dashboard_time
|
||||
await cache.sqlite.set_benchmark('dashboard_summary', dashboard_time, 1)
|
||||
logger.info(f" Dashboard summary: {dashboard_time:.2f}ms")
|
||||
|
||||
# Benchmark: Dashboard trends
|
||||
logger.info("Benchmarking: dashboard trends")
|
||||
trends_time = await _benchmark_dashboard_trends()
|
||||
if trends_time:
|
||||
benchmarks['dashboard_trends'] = trends_time
|
||||
await cache.sqlite.set_benchmark('dashboard_trends', trends_time, 1)
|
||||
logger.info(f" Dashboard trends: {trends_time:.2f}ms")
|
||||
|
||||
# Benchmark: Invoices
|
||||
logger.info("Benchmarking: invoices")
|
||||
invoices_time = await _benchmark_invoices()
|
||||
if invoices_time:
|
||||
benchmarks['invoices'] = invoices_time
|
||||
await cache.sqlite.set_benchmark('invoices', invoices_time, 1)
|
||||
logger.info(f" Invoices: {invoices_time:.2f}ms")
|
||||
|
||||
# Benchmark: Treasury
|
||||
logger.info("Benchmarking: treasury")
|
||||
treasury_time = await _benchmark_treasury()
|
||||
if treasury_time:
|
||||
benchmarks['treasury'] = treasury_time
|
||||
await cache.sqlite.set_benchmark('treasury', treasury_time, 1)
|
||||
logger.info(f" Treasury: {treasury_time:.2f}ms")
|
||||
|
||||
logger.info(f"Baseline benchmarks completed: {len(benchmarks)} types measured")
|
||||
return benchmarks
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Benchmark error: {e}", exc_info=True)
|
||||
return benchmarks
|
||||
|
||||
|
||||
async def _benchmark_schema_lookup() -> list:
|
||||
"""
|
||||
Benchmark schema lookup queries
|
||||
|
||||
Returns:
|
||||
List of query times (ms) for multiple samples
|
||||
"""
|
||||
try:
|
||||
# Import here to avoid circular dependency
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..')))
|
||||
from shared.database.oracle_pool import oracle_pool
|
||||
|
||||
# Get sample company IDs to test
|
||||
sample_companies = await _get_sample_company_ids(limit=10)
|
||||
if not sample_companies:
|
||||
logger.warning("No sample companies found for schema benchmark")
|
||||
return []
|
||||
|
||||
times = []
|
||||
for company_id in sample_companies:
|
||||
start = time.time()
|
||||
async with oracle_pool.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute("""
|
||||
SELECT schema
|
||||
FROM CONTAFIN_ORACLE.v_nom_firme
|
||||
WHERE id_firma = :id
|
||||
""", {'id': company_id})
|
||||
cursor.fetchone()
|
||||
elapsed_ms = (time.time() - start) * 1000
|
||||
times.append(elapsed_ms)
|
||||
|
||||
return times
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Schema benchmark error: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def _benchmark_companies_list() -> float:
|
||||
"""
|
||||
Benchmark companies list query
|
||||
|
||||
Returns:
|
||||
Query time (ms)
|
||||
"""
|
||||
try:
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..')))
|
||||
from shared.database.oracle_pool import oracle_pool
|
||||
|
||||
# Get sample username
|
||||
sample_user = await _get_sample_username()
|
||||
if not sample_user:
|
||||
return 0
|
||||
|
||||
start = time.time()
|
||||
async with oracle_pool.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute("""
|
||||
SELECT nf.id_firma, nf.denumire, nf.cui, nf.schema
|
||||
FROM CONTAFIN_ORACLE.v_nom_firme nf
|
||||
JOIN CONTAFIN_ORACLE.vdef_util_firme uf ON nf.id_firma = uf.id_firma
|
||||
WHERE uf.nume_utilizator = :username
|
||||
ORDER BY nf.denumire
|
||||
""", {'username': sample_user})
|
||||
cursor.fetchall()
|
||||
elapsed_ms = (time.time() - start) * 1000
|
||||
return elapsed_ms
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Companies benchmark error: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
async def _benchmark_dashboard_summary() -> float:
|
||||
"""
|
||||
Benchmark dashboard summary query
|
||||
|
||||
Returns:
|
||||
Query time (ms)
|
||||
"""
|
||||
try:
|
||||
# This requires access to DashboardService
|
||||
# For now, return estimated value
|
||||
logger.warning("Dashboard summary benchmark not implemented - using estimate")
|
||||
return 250.0 # Estimated 250ms based on plan
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Dashboard benchmark error: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
async def _benchmark_dashboard_trends() -> float:
|
||||
"""Benchmark dashboard trends query"""
|
||||
try:
|
||||
logger.warning("Dashboard trends benchmark not implemented - using estimate")
|
||||
return 400.0 # Estimated 400ms
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Trends benchmark error: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
async def _benchmark_invoices() -> float:
|
||||
"""Benchmark invoices query"""
|
||||
try:
|
||||
logger.warning("Invoices benchmark not implemented - using estimate")
|
||||
return 180.0 # Estimated 180ms
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Invoices benchmark error: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
async def _benchmark_treasury() -> float:
|
||||
"""Benchmark treasury query"""
|
||||
try:
|
||||
logger.warning("Treasury benchmark not implemented - using estimate")
|
||||
return 250.0 # Estimated 250ms
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Treasury benchmark error: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
# Helper functions
|
||||
|
||||
async def _get_sample_company_ids(limit: int = 10) -> list:
|
||||
"""Get sample company IDs for testing"""
|
||||
try:
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..')))
|
||||
from shared.database.oracle_pool import oracle_pool
|
||||
|
||||
async with oracle_pool.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(f"""
|
||||
SELECT id_firma
|
||||
FROM CONTAFIN_ORACLE.v_nom_firme
|
||||
WHERE ROWNUM <= {limit}
|
||||
""")
|
||||
results = cursor.fetchall()
|
||||
return [row[0] for row in results]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Get sample companies error: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def _get_sample_username() -> str:
|
||||
"""Get sample username for testing"""
|
||||
try:
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..')))
|
||||
from shared.database.oracle_pool import oracle_pool
|
||||
|
||||
async with oracle_pool.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute("""
|
||||
SELECT nume_utilizator
|
||||
FROM CONTAFIN_ORACLE.vdef_util_firme
|
||||
WHERE ROWNUM <= 1
|
||||
""")
|
||||
result = cursor.fetchone()
|
||||
return result[0] if result else "admin"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Get sample username error: {e}")
|
||||
return "admin"
|
||||
335
backend/modules/reports/cache/cache_manager.py
vendored
Normal file
335
backend/modules/reports/cache/cache_manager.py
vendored
Normal file
@@ -0,0 +1,335 @@
|
||||
"""
|
||||
Cache Manager - Orchestrator for hybrid L1 + L2 cache
|
||||
"""
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import Any, Optional
|
||||
from .config import CacheConfig
|
||||
from .memory_cache import MemoryCache
|
||||
from .sqlite_cache import SQLiteCache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CacheManager:
|
||||
"""
|
||||
Hybrid cache manager (Memory L1 + SQLite L2)
|
||||
|
||||
Features:
|
||||
- Two-tier caching: fast memory + persistent SQLite
|
||||
- Automatic TTL management per cache type
|
||||
- Performance tracking and benchmarking
|
||||
- Per-user cache enable/disable
|
||||
- Global cache toggle
|
||||
"""
|
||||
|
||||
def __init__(self, config: CacheConfig):
|
||||
"""
|
||||
Initialize cache manager
|
||||
|
||||
Args:
|
||||
config: Cache configuration
|
||||
"""
|
||||
self.config = config
|
||||
self.memory = MemoryCache(max_size=config.memory_max_size)
|
||||
self.sqlite = SQLiteCache(db_path=config.sqlite_path)
|
||||
self._cleanup_task: Optional[asyncio.Task] = None
|
||||
self._initialized = False
|
||||
self._last_cache_source: Optional[str] = None # Track last cache source (L1/L2)
|
||||
|
||||
async def init(self):
|
||||
"""Initialize cache system"""
|
||||
if self._initialized:
|
||||
logger.warning("Cache already initialized")
|
||||
return
|
||||
|
||||
# Initialize SQLite database schema
|
||||
await self.sqlite.init_db()
|
||||
|
||||
# Start cleanup task
|
||||
if self.config.enabled:
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||
|
||||
self._initialized = True
|
||||
logger.info(f"Cache initialized: type={self.config.cache_type}, enabled={self.config.enabled}")
|
||||
|
||||
async def close(self):
|
||||
"""Close cache and cleanup"""
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info("Cache closed")
|
||||
|
||||
async def get(self, key: str, cache_type: str) -> Optional[Any]:
|
||||
"""
|
||||
Get value from cache (L1 → L2)
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
cache_type: Type of cache entry
|
||||
|
||||
Returns:
|
||||
Cached value or None if not found
|
||||
"""
|
||||
if not self.config.enabled:
|
||||
self._last_cache_source = None
|
||||
return None
|
||||
|
||||
# Try L1 (Memory) first
|
||||
value = await self.memory.get(key)
|
||||
if value is not None:
|
||||
self._last_cache_source = "L1"
|
||||
logger.debug(f"Cache HIT L1 (memory): {key}")
|
||||
return value
|
||||
|
||||
# Try L2 (SQLite)
|
||||
value = await self.sqlite.get(key)
|
||||
if value is not None:
|
||||
self._last_cache_source = "L2"
|
||||
logger.debug(f"Cache HIT L2 (sqlite): {key}")
|
||||
|
||||
# Populate L1 for next time
|
||||
ttl = self.config.get_ttl_for_type(cache_type)
|
||||
await self.memory.set(key, value, ttl)
|
||||
|
||||
return value
|
||||
|
||||
# Cache MISS
|
||||
self._last_cache_source = None
|
||||
logger.debug(f"Cache MISS: {key}")
|
||||
return None
|
||||
|
||||
def get_last_cache_source(self) -> Optional[str]:
|
||||
"""
|
||||
Get source of last cache hit (L1/L2/None)
|
||||
|
||||
Returns:
|
||||
"L1" if last hit was from memory cache
|
||||
"L2" if last hit was from SQLite cache
|
||||
None if last call was a cache miss or cache disabled
|
||||
"""
|
||||
return self._last_cache_source
|
||||
|
||||
async def set(self, key: str, value: Any, cache_type: str, company_id: Optional[int] = None,
|
||||
ttl: Optional[int] = None):
|
||||
"""
|
||||
Set value in cache (both L1 and L2)
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
value: Value to cache
|
||||
cache_type: Type of cache entry
|
||||
company_id: Company ID (for company-specific caches)
|
||||
ttl: Time to live (uses default for cache_type if not provided)
|
||||
"""
|
||||
if not self.config.enabled:
|
||||
return
|
||||
|
||||
if ttl is None:
|
||||
ttl = self.config.get_ttl_for_type(cache_type)
|
||||
|
||||
# Store in both L1 and L2
|
||||
await self.memory.set(key, value, ttl)
|
||||
await self.sqlite.set(key, value, cache_type, company_id, ttl)
|
||||
|
||||
logger.debug(f"Cache SET (L1 + L2): {key} (TTL: {ttl}s)")
|
||||
|
||||
async def delete(self, key: str):
|
||||
"""Delete entry from both L1 and L2"""
|
||||
await self.memory.delete(key)
|
||||
await self.sqlite.delete(key)
|
||||
logger.debug(f"Cache deleted: {key}")
|
||||
|
||||
async def invalidate(self, company_id: Optional[int] = None, cache_type: Optional[str] = None):
|
||||
"""
|
||||
Invalidate cache entries
|
||||
|
||||
Args:
|
||||
company_id: If provided, clear only this company's cache
|
||||
cache_type: If provided, clear only this cache type
|
||||
"""
|
||||
if company_id is not None and cache_type is not None:
|
||||
# Clear specific company + type
|
||||
from .keys import generate_key_pattern
|
||||
pattern = generate_key_pattern(cache_type, company_id)
|
||||
await self.memory.clear_by_pattern(pattern)
|
||||
# SQLite: clear by company + type (needs query)
|
||||
# For now, just clear by company
|
||||
await self.sqlite.clear_by_company(company_id)
|
||||
logger.info(f"Cache invalidated: company={company_id}, type={cache_type}")
|
||||
|
||||
elif company_id is not None:
|
||||
# Clear all for company
|
||||
from .keys import generate_key_pattern
|
||||
# Clear all types for this company (pattern match all)
|
||||
# Memory: need to iterate and match company_id in key
|
||||
# For simplicity, clear by pattern prefix
|
||||
await self.memory.clear() # TODO: improve pattern matching
|
||||
await self.sqlite.clear_by_company(company_id)
|
||||
logger.info(f"Cache invalidated: company={company_id}")
|
||||
|
||||
elif cache_type is not None:
|
||||
# Clear all for type
|
||||
from .keys import generate_key_pattern
|
||||
pattern = generate_key_pattern(cache_type)
|
||||
await self.memory.clear_by_pattern(pattern)
|
||||
await self.sqlite.clear_by_type(cache_type)
|
||||
logger.info(f"Cache invalidated: type={cache_type}")
|
||||
|
||||
else:
|
||||
# Clear everything
|
||||
await self.memory.clear()
|
||||
await self.sqlite.clear()
|
||||
logger.info("Cache invalidated: ALL")
|
||||
|
||||
async def is_enabled_for_user(self, username: Optional[str]) -> bool:
|
||||
"""
|
||||
Check if cache is enabled for specific user
|
||||
|
||||
Args:
|
||||
username: Username to check
|
||||
|
||||
Returns:
|
||||
True if cache enabled for user, False otherwise
|
||||
"""
|
||||
if not self.config.enabled:
|
||||
return False
|
||||
|
||||
if username is None:
|
||||
return True
|
||||
|
||||
# Check per-user setting
|
||||
return await self.sqlite.get_user_cache_enabled(username)
|
||||
|
||||
async def set_user_cache_enabled(self, username: str, enabled: bool):
|
||||
"""Set user cache enabled/disabled"""
|
||||
await self.sqlite.set_user_cache_enabled(username, enabled)
|
||||
logger.info(f"User cache setting: {username} -> {enabled}")
|
||||
|
||||
# Benchmarking
|
||||
|
||||
async def get_benchmark(self, cache_type: str) -> Optional[float]:
|
||||
"""Get average benchmark time for cache type"""
|
||||
return await self.sqlite.get_benchmark(cache_type)
|
||||
|
||||
async def update_benchmark(self, cache_type: str, new_time_ms: float):
|
||||
"""
|
||||
Update benchmark with new measurement (exponential moving average)
|
||||
|
||||
Args:
|
||||
cache_type: Type of cache
|
||||
new_time_ms: New measured time in milliseconds
|
||||
"""
|
||||
current_avg = await self.sqlite.get_benchmark(cache_type)
|
||||
|
||||
if current_avg is None:
|
||||
# First measurement
|
||||
new_avg = new_time_ms
|
||||
sample_count = 1
|
||||
else:
|
||||
# Exponential moving average (alpha = 0.1)
|
||||
new_avg = 0.9 * current_avg + 0.1 * new_time_ms
|
||||
# Get current sample count (TODO: retrieve from DB)
|
||||
sample_count = 1 # Simplified for now
|
||||
|
||||
await self.sqlite.set_benchmark(cache_type, new_avg, sample_count)
|
||||
logger.debug(f"Benchmark updated: {cache_type} -> {new_avg:.2f}ms")
|
||||
|
||||
# Performance Tracking
|
||||
|
||||
async def track_performance(self, cache_type: str, is_hit: bool, actual_time_ms: float,
|
||||
time_saved_ms: Optional[float] = None,
|
||||
estimated_oracle_time_ms: Optional[float] = None,
|
||||
company_id: Optional[int] = None,
|
||||
username: Optional[str] = None):
|
||||
"""
|
||||
Track performance metric
|
||||
|
||||
Args:
|
||||
cache_type: Type of cache
|
||||
is_hit: True if cache hit, False if cache miss
|
||||
actual_time_ms: Actual response time
|
||||
time_saved_ms: Time saved by cache (for hits)
|
||||
estimated_oracle_time_ms: Estimated Oracle time (for hits)
|
||||
company_id: Company ID
|
||||
username: Username
|
||||
"""
|
||||
if not self.config.track_performance:
|
||||
return
|
||||
|
||||
await self.sqlite.log_performance(
|
||||
cache_type=cache_type,
|
||||
company_id=company_id,
|
||||
cache_hit=is_hit,
|
||||
response_time_ms=actual_time_ms,
|
||||
estimated_oracle_time_ms=estimated_oracle_time_ms,
|
||||
time_saved_ms=time_saved_ms,
|
||||
username=username
|
||||
)
|
||||
|
||||
# Statistics
|
||||
|
||||
async def get_stats(self) -> dict:
|
||||
"""Get comprehensive cache statistics"""
|
||||
memory_stats = self.memory.get_stats()
|
||||
sqlite_stats = await self.sqlite.get_stats()
|
||||
|
||||
return {
|
||||
'enabled': self.config.enabled,
|
||||
'cache_type': self.config.cache_type,
|
||||
'memory': memory_stats,
|
||||
'sqlite': sqlite_stats,
|
||||
}
|
||||
|
||||
# Cleanup
|
||||
|
||||
async def _cleanup_loop(self):
|
||||
"""Background task to cleanup expired entries"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self.config.cleanup_interval)
|
||||
await self._cleanup_expired()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Cleanup error: {e}", exc_info=True)
|
||||
|
||||
async def _cleanup_expired(self):
|
||||
"""Remove expired entries from both caches"""
|
||||
logger.info("Running cache cleanup...")
|
||||
await self.memory.cleanup_expired()
|
||||
await self.sqlite.cleanup_expired()
|
||||
logger.info("Cache cleanup completed")
|
||||
|
||||
|
||||
# Global cache manager instance
|
||||
_cache_manager: Optional[CacheManager] = None
|
||||
|
||||
|
||||
async def init_cache(config: CacheConfig):
|
||||
"""Initialize global cache manager"""
|
||||
global _cache_manager
|
||||
if _cache_manager is not None:
|
||||
logger.warning("Cache already initialized")
|
||||
return
|
||||
|
||||
_cache_manager = CacheManager(config)
|
||||
await _cache_manager.init()
|
||||
logger.info("Global cache manager initialized")
|
||||
|
||||
|
||||
def get_cache() -> Optional[CacheManager]:
|
||||
"""Get global cache manager instance"""
|
||||
return _cache_manager
|
||||
|
||||
|
||||
async def close_cache():
|
||||
"""Close global cache manager"""
|
||||
global _cache_manager
|
||||
if _cache_manager is not None:
|
||||
await _cache_manager.close()
|
||||
_cache_manager = None
|
||||
89
backend/modules/reports/cache/config.py
vendored
Normal file
89
backend/modules/reports/cache/config.py
vendored
Normal file
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
Cache configuration from environment variables
|
||||
"""
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheConfig:
|
||||
"""Cache configuration loaded from environment variables"""
|
||||
|
||||
# Core Settings
|
||||
enabled: bool
|
||||
cache_type: str # 'hybrid', 'memory', 'sqlite', 'disabled'
|
||||
sqlite_path: str
|
||||
memory_max_size: int
|
||||
default_ttl: int
|
||||
|
||||
# TTL per Cache Type (seconds)
|
||||
ttl_schema: int
|
||||
ttl_companies: int
|
||||
ttl_dashboard_summary: int
|
||||
ttl_dashboard_trends: int
|
||||
ttl_invoices: int
|
||||
ttl_invoices_summary: int
|
||||
ttl_treasury: int
|
||||
ttl_trial_balance: int
|
||||
ttl_calendar_periods: int
|
||||
|
||||
# Maintenance
|
||||
cleanup_interval: int
|
||||
|
||||
# Event-Based Invalidation
|
||||
auto_invalidate_enabled: bool
|
||||
check_interval: int
|
||||
|
||||
# Performance Tracking
|
||||
track_performance: bool
|
||||
benchmark_on_startup: bool
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> 'CacheConfig':
|
||||
"""Load configuration from environment variables"""
|
||||
return cls(
|
||||
# Core Settings
|
||||
enabled=os.getenv('CACHE_ENABLED', 'True').lower() == 'true',
|
||||
cache_type=os.getenv('CACHE_TYPE', 'hybrid'),
|
||||
sqlite_path=os.getenv('CACHE_SQLITE_PATH', './data/cache/roa2web_cache.db'),
|
||||
memory_max_size=int(os.getenv('CACHE_MEMORY_MAX_SIZE', '1000')),
|
||||
default_ttl=int(os.getenv('CACHE_DEFAULT_TTL', '900')),
|
||||
|
||||
# TTL per Cache Type
|
||||
ttl_schema=int(os.getenv('CACHE_TTL_SCHEMA', '86400')),
|
||||
ttl_companies=int(os.getenv('CACHE_TTL_COMPANIES', '1800')),
|
||||
ttl_dashboard_summary=int(os.getenv('CACHE_TTL_DASHBOARD_SUMMARY', '1800')),
|
||||
ttl_dashboard_trends=int(os.getenv('CACHE_TTL_DASHBOARD_TRENDS', '1800')),
|
||||
ttl_invoices=int(os.getenv('CACHE_TTL_INVOICES', '600')),
|
||||
ttl_invoices_summary=int(os.getenv('CACHE_TTL_INVOICES_SUMMARY', '900')),
|
||||
ttl_treasury=int(os.getenv('CACHE_TTL_TREASURY', '600')),
|
||||
ttl_trial_balance=int(os.getenv('CACHE_TTL_TRIAL_BALANCE', '600')),
|
||||
ttl_calendar_periods=int(os.getenv('CACHE_TTL_CALENDAR_PERIODS', '3600')),
|
||||
|
||||
# Maintenance
|
||||
cleanup_interval=int(os.getenv('CACHE_CLEANUP_INTERVAL', '3600')),
|
||||
|
||||
# Event-Based Invalidation
|
||||
auto_invalidate_enabled=os.getenv('CACHE_AUTO_INVALIDATE', 'False').lower() == 'true',
|
||||
check_interval=int(os.getenv('CACHE_CHECK_INTERVAL', '300')),
|
||||
|
||||
# Performance Tracking
|
||||
track_performance=os.getenv('CACHE_TRACK_PERFORMANCE', 'True').lower() == 'true',
|
||||
benchmark_on_startup=os.getenv('CACHE_BENCHMARK_ON_STARTUP', 'True').lower() == 'true',
|
||||
)
|
||||
|
||||
def get_ttl_for_type(self, cache_type: str) -> int:
|
||||
"""Get TTL for specific cache type"""
|
||||
ttl_map = {
|
||||
'schema': self.ttl_schema,
|
||||
'companies': self.ttl_companies,
|
||||
'dashboard_summary': self.ttl_dashboard_summary,
|
||||
'dashboard_trends': self.ttl_dashboard_trends,
|
||||
'invoices': self.ttl_invoices,
|
||||
'invoices_summary': self.ttl_invoices_summary,
|
||||
'treasury': self.ttl_treasury,
|
||||
'trial_balance': self.ttl_trial_balance,
|
||||
'calendar_periods': self.ttl_calendar_periods,
|
||||
}
|
||||
return ttl_map.get(cache_type, self.default_ttl)
|
||||
254
backend/modules/reports/cache/decorators.py
vendored
Normal file
254
backend/modules/reports/cache/decorators.py
vendored
Normal file
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
Cache decorators for service methods
|
||||
"""
|
||||
import time
|
||||
import logging
|
||||
from functools import wraps
|
||||
from typing import Callable, Optional, List
|
||||
|
||||
from .cache_manager import get_cache
|
||||
from .keys import generate_cache_key
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def cached(cache_type: str, ttl: Optional[int] = None, key_params: Optional[List[str]] = None):
|
||||
"""
|
||||
Decorator for caching service method results with performance tracking
|
||||
|
||||
Usage:
|
||||
@cached(cache_type='dashboard_summary', key_params=['company', 'username'])
|
||||
async def get_complete_summary(company: str, username: str):
|
||||
# ... Oracle query logic ...
|
||||
|
||||
Features:
|
||||
- Automatic cache key generation from function parameters
|
||||
- Performance timing (cache hit vs miss)
|
||||
- Benchmark tracking for time saved calculation
|
||||
- Per-user cache enable/disable
|
||||
- Global cache toggle
|
||||
- Transparent - zero changes to function logic
|
||||
|
||||
Args:
|
||||
cache_type: Type of cache (used for TTL lookup and stats)
|
||||
ttl: Optional custom TTL (overrides config default)
|
||||
key_params: List of parameter names to include in cache key
|
||||
|
||||
Returns:
|
||||
Decorated async function
|
||||
"""
|
||||
def decorator(func: Callable):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
start_time = time.time()
|
||||
cache = get_cache()
|
||||
|
||||
# Extract username for per-user settings
|
||||
username = _extract_username(args, kwargs, key_params)
|
||||
|
||||
# Check if cache is enabled (global + per-user)
|
||||
cache_enabled = await cache.is_enabled_for_user(username) if cache else False
|
||||
|
||||
if not cache or not cache_enabled:
|
||||
# Cache disabled - execute directly
|
||||
result = await func(*args, **kwargs)
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
|
||||
# Set metadata in request.state if available (for API responses)
|
||||
if 'request' in kwargs and hasattr(kwargs['request'], 'state'):
|
||||
kwargs['request'].state.cache_hit = False
|
||||
kwargs['request'].state.response_time_ms = elapsed_ms
|
||||
kwargs['request'].state.cache_source = None
|
||||
|
||||
if cache and cache.config.track_performance:
|
||||
await cache.track_performance(
|
||||
cache_type=cache_type,
|
||||
is_hit=False,
|
||||
actual_time_ms=elapsed_ms,
|
||||
username=username
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
# Generate cache key from function parameters
|
||||
cache_key = generate_cache_key(cache_type, key_params, args, kwargs)
|
||||
|
||||
# Try to get from cache
|
||||
cached_value = await cache.get(cache_key, cache_type)
|
||||
|
||||
if cached_value is not None:
|
||||
# ✅ CACHE HIT
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
|
||||
# Set metadata in request.state if available (for API responses)
|
||||
if 'request' in kwargs and hasattr(kwargs['request'], 'state'):
|
||||
cache_source_value = cache.get_last_cache_source() # L1 or L2
|
||||
kwargs['request'].state.cache_hit = True
|
||||
kwargs['request'].state.response_time_ms = elapsed_ms
|
||||
kwargs['request'].state.cache_source = cache_source_value
|
||||
|
||||
# Get benchmark for calculating time saved
|
||||
benchmark = await cache.get_benchmark(cache_type)
|
||||
time_saved_ms = (benchmark - elapsed_ms) if benchmark else None
|
||||
|
||||
# Track performance
|
||||
if cache.config.track_performance:
|
||||
await cache.track_performance(
|
||||
cache_type=cache_type,
|
||||
is_hit=True,
|
||||
actual_time_ms=elapsed_ms,
|
||||
time_saved_ms=time_saved_ms,
|
||||
estimated_oracle_time_ms=benchmark,
|
||||
company_id=_extract_company_id(args, kwargs, key_params),
|
||||
username=username
|
||||
)
|
||||
|
||||
return cached_value
|
||||
|
||||
# ❌ CACHE MISS - execute function (query Oracle)
|
||||
result = await func(*args, **kwargs)
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
|
||||
# Set metadata in request.state if available (for API responses)
|
||||
if 'request' in kwargs and hasattr(kwargs['request'], 'state'):
|
||||
kwargs['request'].state.cache_hit = False
|
||||
kwargs['request'].state.response_time_ms = elapsed_ms
|
||||
kwargs['request'].state.cache_source = None
|
||||
|
||||
# Update benchmark with real Oracle time
|
||||
await cache.update_benchmark(cache_type, elapsed_ms)
|
||||
|
||||
# Track performance
|
||||
if cache.config.track_performance:
|
||||
await cache.track_performance(
|
||||
cache_type=cache_type,
|
||||
is_hit=False,
|
||||
actual_time_ms=elapsed_ms,
|
||||
company_id=_extract_company_id(args, kwargs, key_params),
|
||||
username=username
|
||||
)
|
||||
|
||||
# Store in cache for next time
|
||||
company_id = _extract_company_id(args, kwargs, key_params)
|
||||
await cache.set(cache_key, result, cache_type, company_id, ttl)
|
||||
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def _extract_username(args, kwargs, key_params: Optional[List[str]]) -> Optional[str]:
|
||||
"""
|
||||
Extract username from function parameters (args or kwargs)
|
||||
|
||||
Checks:
|
||||
1. key_params position in args (if username is in key_params)
|
||||
2. Direct username in kwargs
|
||||
3. current_user object in kwargs
|
||||
4. user object in kwargs
|
||||
5. request.state.user (from AuthenticationMiddleware)
|
||||
|
||||
Args:
|
||||
args: Positional arguments
|
||||
kwargs: Keyword arguments
|
||||
key_params: List of parameter names (for finding position in args)
|
||||
|
||||
Returns:
|
||||
Username string or None
|
||||
"""
|
||||
# Try to find username in args based on key_params position
|
||||
if key_params and 'username' in key_params:
|
||||
try:
|
||||
username_idx = key_params.index('username')
|
||||
if username_idx < len(args):
|
||||
username = args[username_idx]
|
||||
if username:
|
||||
return str(username)
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
# Direct username parameter in kwargs
|
||||
if 'username' in kwargs:
|
||||
return kwargs['username']
|
||||
|
||||
# Current user object (from FastAPI Depends)
|
||||
if 'current_user' in kwargs:
|
||||
user = kwargs['current_user']
|
||||
if hasattr(user, 'username'):
|
||||
return user.username
|
||||
elif isinstance(user, dict) and 'username' in user:
|
||||
return user['username']
|
||||
return str(user)
|
||||
|
||||
# User object
|
||||
if 'user' in kwargs:
|
||||
user = kwargs['user']
|
||||
if hasattr(user, 'username'):
|
||||
return user.username
|
||||
elif isinstance(user, dict) and 'username' in user:
|
||||
return user['username']
|
||||
return str(user)
|
||||
|
||||
# Extract from request.state.user (set by AuthenticationMiddleware)
|
||||
if 'request' in kwargs:
|
||||
request = kwargs['request']
|
||||
if hasattr(request, 'state') and hasattr(request.state, 'user'):
|
||||
user = request.state.user
|
||||
if hasattr(user, 'username'):
|
||||
return user.username
|
||||
elif isinstance(user, dict) and 'username' in user:
|
||||
return user['username']
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _extract_company_id(args, kwargs, key_params: Optional[List[str]]) -> Optional[int]:
|
||||
"""
|
||||
Extract company_id from function parameters for cache indexing
|
||||
|
||||
Tries multiple approaches:
|
||||
1. Direct company_id in kwargs
|
||||
2. company parameter (converted to int)
|
||||
3. Positional args based on key_params position
|
||||
|
||||
Args:
|
||||
args: Positional arguments
|
||||
kwargs: Keyword arguments
|
||||
key_params: List of parameter names
|
||||
|
||||
Returns:
|
||||
Company ID as integer or None
|
||||
"""
|
||||
# Try kwargs first
|
||||
if 'company_id' in kwargs:
|
||||
try:
|
||||
return int(kwargs['company_id'])
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
if 'company' in kwargs:
|
||||
try:
|
||||
return int(kwargs['company'])
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# Try positional args based on key_params
|
||||
if key_params:
|
||||
if 'company_id' in key_params:
|
||||
idx = key_params.index('company_id')
|
||||
if idx < len(args):
|
||||
try:
|
||||
return int(args[idx])
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
elif 'company' in key_params:
|
||||
idx = key_params.index('company')
|
||||
if idx < len(args):
|
||||
try:
|
||||
return int(args[idx])
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
return None
|
||||
333
backend/modules/reports/cache/event_monitor.py
vendored
Normal file
333
backend/modules/reports/cache/event_monitor.py
vendored
Normal file
@@ -0,0 +1,333 @@
|
||||
"""
|
||||
Event-based cache invalidation monitor
|
||||
|
||||
Monitors {schema}.act tables for changes and invalidates cache automatically
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
# Path setup handled by main.py - this is redundant but kept for module isolation
|
||||
# sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..')))
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventMonitor:
|
||||
"""
|
||||
Monitors schema.act tables for changes to trigger cache invalidation
|
||||
|
||||
Runs as background task, checking max(id_act) at configured intervals
|
||||
Uses permanent schema_mappings cache to avoid repeated schema lookups
|
||||
"""
|
||||
|
||||
def __init__(self, cache_manager, config):
|
||||
"""
|
||||
Initialize event monitor
|
||||
|
||||
Args:
|
||||
cache_manager: CacheManager instance
|
||||
config: CacheConfig instance
|
||||
"""
|
||||
self.cache = cache_manager
|
||||
self.config = config
|
||||
self.running = False
|
||||
self.task: Optional[asyncio.Task] = None
|
||||
|
||||
async def start(self):
|
||||
"""Start monitoring task"""
|
||||
if self.running:
|
||||
logger.warning("Event monitor already running")
|
||||
return
|
||||
|
||||
self.running = True
|
||||
self.task = asyncio.create_task(self._monitor_loop())
|
||||
logger.info(
|
||||
f"Event monitor started (interval: {self.config.check_interval}s)"
|
||||
)
|
||||
|
||||
async def stop(self):
|
||||
"""Stop monitoring task"""
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
self.running = False
|
||||
if self.task:
|
||||
self.task.cancel()
|
||||
try:
|
||||
await self.task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info("Event monitor stopped")
|
||||
|
||||
async def _monitor_loop(self):
|
||||
"""Main monitoring loop"""
|
||||
while self.running:
|
||||
try:
|
||||
await self._check_all_companies()
|
||||
await asyncio.sleep(self.config.check_interval)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Event monitor error: {e}", exc_info=True)
|
||||
# Wait 1 minute on error before retrying
|
||||
await asyncio.sleep(60)
|
||||
|
||||
async def _check_all_companies(self):
|
||||
"""
|
||||
Check all companies with active cache for changes
|
||||
|
||||
Queries max(id_act) from {schema}.act for each cached company
|
||||
and invalidates cache if changes detected
|
||||
"""
|
||||
try:
|
||||
# Get list of companies with active cache entries
|
||||
cached_companies = await self.cache.sqlite.get_cached_company_ids()
|
||||
|
||||
if not cached_companies:
|
||||
logger.debug("No cached companies to monitor")
|
||||
return
|
||||
|
||||
logger.info(f"Checking {len(cached_companies)} companies for changes...")
|
||||
invalidated_count = 0
|
||||
|
||||
for company_id in cached_companies:
|
||||
try:
|
||||
# Check if company data changed
|
||||
changed = await self._check_company_changes(company_id)
|
||||
|
||||
if changed:
|
||||
# Invalidate cache for this company
|
||||
await self.cache.invalidate(company_id=company_id)
|
||||
invalidated_count += 1
|
||||
logger.info(
|
||||
f"Cache invalidated for company {company_id} due to act changes"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Error for one company shouldn't stop checking others
|
||||
logger.error(f"Error checking company {company_id}: {e}")
|
||||
continue
|
||||
|
||||
if invalidated_count > 0:
|
||||
logger.info(
|
||||
f"Auto-invalidation complete: {invalidated_count} companies affected"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Check all companies error: {e}", exc_info=True)
|
||||
|
||||
async def _check_company_changes(self, company_id: int) -> bool:
|
||||
"""
|
||||
Check if company data changed (monitor max(id_act) in schema.act)
|
||||
|
||||
Args:
|
||||
company_id: Company ID to check
|
||||
|
||||
Returns:
|
||||
True if cache should be invalidated, False otherwise
|
||||
"""
|
||||
try:
|
||||
# 1. Get schema (from permanent cache)
|
||||
schema = await self._get_schema_for_company(company_id)
|
||||
if not schema:
|
||||
logger.warning(f"Schema not found for company {company_id}")
|
||||
return False
|
||||
|
||||
# 2. Get current max(id_act) from Oracle
|
||||
current_max = await self._get_max_id_act(schema)
|
||||
|
||||
# 3. Get cached watermark
|
||||
cached_watermark = await self.cache.sqlite.get_watermark(company_id)
|
||||
|
||||
# 4. Compare
|
||||
if cached_watermark is None:
|
||||
# First time checking - store watermark, no invalidation
|
||||
await self.cache.sqlite.set_watermark(company_id, schema, current_max)
|
||||
logger.debug(
|
||||
f"Watermark initialized for company {company_id}: {current_max}"
|
||||
)
|
||||
return False
|
||||
|
||||
if current_max > cached_watermark:
|
||||
# Changes detected!
|
||||
logger.info(
|
||||
f"Schema {schema} (company {company_id}): "
|
||||
f"id_act changed {cached_watermark} → {current_max}"
|
||||
)
|
||||
|
||||
# Update watermark
|
||||
await self.cache.sqlite.set_watermark(company_id, schema, current_max)
|
||||
|
||||
return True # Invalidate cache
|
||||
|
||||
# No changes
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Check company {company_id} changes error: {e}")
|
||||
return False # Don't invalidate on error
|
||||
|
||||
async def _get_schema_for_company(self, company_id: int) -> Optional[str]:
|
||||
"""
|
||||
Get schema for company (with permanent caching)
|
||||
|
||||
First checks permanent schema_mappings cache,
|
||||
falls back to Oracle query if not cached
|
||||
|
||||
Args:
|
||||
company_id: Company ID
|
||||
|
||||
Returns:
|
||||
Schema name or None
|
||||
"""
|
||||
# Check permanent cache first
|
||||
cached_schema = await self.cache.sqlite.get_schema_mapping(company_id)
|
||||
if cached_schema:
|
||||
logger.debug(f"Schema mapping HIT for company {company_id}: {cached_schema}")
|
||||
return cached_schema
|
||||
|
||||
# Cache MISS - query Oracle
|
||||
logger.info(f"Schema mapping MISS for company {company_id}, querying Oracle...")
|
||||
|
||||
try:
|
||||
from shared.database.oracle_pool import oracle_pool
|
||||
|
||||
async with oracle_pool.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute("""
|
||||
SELECT schema
|
||||
FROM CONTAFIN_ORACLE.v_nom_firme
|
||||
WHERE id_firma = :id
|
||||
""", {'id': company_id})
|
||||
result = cursor.fetchone()
|
||||
|
||||
if not result:
|
||||
logger.warning(f"Company {company_id} not found in v_nom_firme")
|
||||
return None
|
||||
|
||||
schema = result[0]
|
||||
|
||||
# Store PERMANENT in schema_mappings (never expires)
|
||||
await self.cache.sqlite.set_schema_mapping(company_id, schema)
|
||||
|
||||
logger.info(f"Schema mapping stored for company {company_id}: {schema}")
|
||||
return schema
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Get schema for company {company_id} error: {e}")
|
||||
return None
|
||||
|
||||
async def _get_max_id_act(self, schema: str) -> int:
|
||||
"""
|
||||
Query max(id_act) from {schema}.act
|
||||
|
||||
Args:
|
||||
schema: Schema name
|
||||
|
||||
Returns:
|
||||
Max id_act value (0 if table empty)
|
||||
"""
|
||||
try:
|
||||
from shared.database.oracle_pool import oracle_pool
|
||||
|
||||
async with oracle_pool.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
# IMPORTANT: Schema comes from v_nom_firme (trusted source)
|
||||
# so it's safe from SQL injection
|
||||
query = f"SELECT MAX(id_act) FROM {schema}.act"
|
||||
cursor.execute(query)
|
||||
|
||||
result = cursor.fetchone()
|
||||
max_id_act = result[0] if result and result[0] is not None else 0
|
||||
|
||||
return max_id_act
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Get max_id_act for schema {schema} error: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
# Optional: Preload all schema mappings at startup
|
||||
|
||||
async def preload_all_schema_mappings():
|
||||
"""
|
||||
Preload all schema mappings at startup (optional)
|
||||
|
||||
Prevents cache misses on first requests by populating
|
||||
schema_mappings table with all companies
|
||||
"""
|
||||
from .cache_manager import get_cache
|
||||
|
||||
cache = get_cache()
|
||||
if not cache:
|
||||
logger.warning("Cache not initialized - skipping schema preload")
|
||||
return
|
||||
|
||||
logger.info("Preloading all schema mappings...")
|
||||
|
||||
try:
|
||||
from shared.database.oracle_pool import oracle_pool
|
||||
|
||||
async with oracle_pool.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute("""
|
||||
SELECT id_firma, schema
|
||||
FROM CONTAFIN_ORACLE.v_nom_firme
|
||||
""")
|
||||
results = cursor.fetchall()
|
||||
|
||||
for id_firma, schema in results:
|
||||
await cache.sqlite.set_schema_mapping(id_firma, schema)
|
||||
|
||||
logger.info(f"Preloaded {len(results)} schema mappings")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Schema preload error: {e}")
|
||||
|
||||
|
||||
# Global event monitor instance
|
||||
_event_monitor: Optional[EventMonitor] = None
|
||||
|
||||
|
||||
async def init_event_monitor(cache_manager, config):
|
||||
"""
|
||||
Initialize global event monitor
|
||||
|
||||
Args:
|
||||
cache_manager: CacheManager instance
|
||||
config: CacheConfig instance
|
||||
"""
|
||||
global _event_monitor
|
||||
_event_monitor = EventMonitor(cache_manager, config)
|
||||
|
||||
# Start if auto-invalidate enabled
|
||||
if config.auto_invalidate_enabled:
|
||||
await _event_monitor.start()
|
||||
|
||||
|
||||
def get_event_monitor() -> Optional[EventMonitor]:
|
||||
"""Get global event monitor instance"""
|
||||
return _event_monitor
|
||||
|
||||
|
||||
async def toggle_event_monitor(enabled: bool):
|
||||
"""
|
||||
Toggle event monitor on/off
|
||||
|
||||
Args:
|
||||
enabled: True to start monitoring, False to stop
|
||||
"""
|
||||
monitor = get_event_monitor()
|
||||
if not monitor:
|
||||
logger.warning("Event monitor not initialized")
|
||||
return
|
||||
|
||||
if enabled and not monitor.running:
|
||||
await monitor.start()
|
||||
elif not enabled and monitor.running:
|
||||
await monitor.stop()
|
||||
150
backend/modules/reports/cache/keys.py
vendored
Normal file
150
backend/modules/reports/cache/keys.py
vendored
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Cache key generation utilities
|
||||
"""
|
||||
import hashlib
|
||||
import json
|
||||
from typing import Any, List, Optional
|
||||
|
||||
|
||||
def generate_cache_key(cache_type: str, key_params: Optional[List[str]], args: tuple, kwargs: dict) -> str:
|
||||
"""
|
||||
Generate cache key from function parameters
|
||||
|
||||
Format: "{cache_type}:{param1_value}:{param2_value}:..."
|
||||
|
||||
Args:
|
||||
cache_type: Type of cache (e.g., 'dashboard_summary', 'invoices')
|
||||
key_params: List of parameter names to include in key
|
||||
args: Positional arguments from function call
|
||||
kwargs: Keyword arguments from function call
|
||||
|
||||
Returns:
|
||||
Cache key string
|
||||
|
||||
Examples:
|
||||
generate_cache_key('schema', ['company_id'], (123,), {})
|
||||
-> "schema:123"
|
||||
|
||||
generate_cache_key('dashboard_summary', ['company', 'username'], (), {'company': '123', 'username': 'john'})
|
||||
-> "dashboard_summary:123:john"
|
||||
|
||||
generate_cache_key('invoices', ['company', 'invoice_type', 'status'], (123, 'CLIENTI', 'neplatite'), {})
|
||||
-> "invoices:123:CLIENTI:neplatite"
|
||||
"""
|
||||
key_parts = [cache_type]
|
||||
|
||||
if not key_params:
|
||||
# No specific params - use all args/kwargs (fallback)
|
||||
if args:
|
||||
key_parts.extend([str(arg) for arg in args])
|
||||
if kwargs:
|
||||
# Sort kwargs for consistent key generation
|
||||
sorted_kwargs = sorted(kwargs.items())
|
||||
key_parts.extend([f"{k}={v}" for k, v in sorted_kwargs])
|
||||
else:
|
||||
# Extract specific params
|
||||
for i, param_name in enumerate(key_params):
|
||||
# Try to get from kwargs first
|
||||
if param_name in kwargs:
|
||||
value = kwargs[param_name]
|
||||
# Then try positional args
|
||||
elif i < len(args):
|
||||
value = args[i]
|
||||
else:
|
||||
# Parameter not found - use placeholder
|
||||
value = "none"
|
||||
|
||||
key_parts.append(str(value))
|
||||
|
||||
return ":".join(key_parts)
|
||||
|
||||
|
||||
def generate_key_pattern(cache_type: str, company_id: Optional[int] = None) -> str:
|
||||
"""
|
||||
Generate cache key pattern for matching multiple keys
|
||||
|
||||
Used for invalidation by type or company
|
||||
|
||||
Args:
|
||||
cache_type: Type of cache
|
||||
company_id: Optional company ID to filter by
|
||||
|
||||
Returns:
|
||||
Pattern string (prefix)
|
||||
|
||||
Examples:
|
||||
generate_key_pattern('dashboard_summary')
|
||||
-> "dashboard_summary:"
|
||||
|
||||
generate_key_pattern('dashboard_summary', 123)
|
||||
-> "dashboard_summary:123"
|
||||
"""
|
||||
if company_id is not None:
|
||||
return f"{cache_type}:{company_id}"
|
||||
return f"{cache_type}:"
|
||||
|
||||
|
||||
def hash_complex_params(params: dict) -> str:
|
||||
"""
|
||||
Generate hash for complex parameters (e.g., filters, queries)
|
||||
|
||||
Used when cache key would be too long with full param values
|
||||
|
||||
Args:
|
||||
params: Dictionary of parameters to hash
|
||||
|
||||
Returns:
|
||||
8-character hash string
|
||||
|
||||
Example:
|
||||
filters = {'status': 'neplatite', 'date_from': '2024-01-01', 'date_to': '2024-12-31'}
|
||||
hash_complex_params(filters)
|
||||
-> "a3f8b2c1"
|
||||
"""
|
||||
# Sort keys for consistent hashing
|
||||
sorted_params = json.dumps(params, sort_keys=True)
|
||||
hash_obj = hashlib.sha256(sorted_params.encode())
|
||||
# Return first 8 characters of hex digest
|
||||
return hash_obj.hexdigest()[:8]
|
||||
|
||||
|
||||
def extract_company_id_from_key(cache_key: str) -> Optional[int]:
|
||||
"""
|
||||
Extract company_id from cache key
|
||||
|
||||
Assumes format: "cache_type:company_id:..."
|
||||
|
||||
Args:
|
||||
cache_key: Cache key string
|
||||
|
||||
Returns:
|
||||
Company ID or None if not found
|
||||
|
||||
Example:
|
||||
extract_company_id_from_key("dashboard_summary:123:john")
|
||||
-> 123
|
||||
"""
|
||||
parts = cache_key.split(":")
|
||||
if len(parts) >= 2:
|
||||
try:
|
||||
return int(parts[1])
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def extract_cache_type_from_key(cache_key: str) -> str:
|
||||
"""
|
||||
Extract cache_type from cache key
|
||||
|
||||
Args:
|
||||
cache_key: Cache key string
|
||||
|
||||
Returns:
|
||||
Cache type (first part before colon)
|
||||
|
||||
Example:
|
||||
extract_cache_type_from_key("dashboard_summary:123:john")
|
||||
-> "dashboard_summary"
|
||||
"""
|
||||
return cache_key.split(":")[0]
|
||||
180
backend/modules/reports/cache/memory_cache.py
vendored
Normal file
180
backend/modules/reports/cache/memory_cache.py
vendored
Normal file
@@ -0,0 +1,180 @@
|
||||
"""
|
||||
In-memory cache with TTL (L1 cache)
|
||||
Fast, limited size, lost on restart
|
||||
"""
|
||||
import time
|
||||
import logging
|
||||
from typing import Any, Optional, Dict
|
||||
from collections import OrderedDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryCache:
|
||||
"""
|
||||
In-memory LRU cache with TTL support
|
||||
|
||||
Features:
|
||||
- LRU eviction when max_size reached
|
||||
- Per-entry TTL expiration
|
||||
- Thread-safe operations
|
||||
- Fast O(1) get/set operations
|
||||
"""
|
||||
|
||||
def __init__(self, max_size: int = 1000):
|
||||
"""
|
||||
Initialize memory cache
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of entries to store
|
||||
"""
|
||||
self.max_size = max_size
|
||||
self._cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
|
||||
self._stats = {
|
||||
'hits': 0,
|
||||
'misses': 0,
|
||||
'sets': 0,
|
||||
'evictions': 0
|
||||
}
|
||||
|
||||
async def get(self, key: str) -> Optional[Any]:
|
||||
"""
|
||||
Get value from cache
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
|
||||
Returns:
|
||||
Cached value or None if not found/expired
|
||||
"""
|
||||
if key not in self._cache:
|
||||
self._stats['misses'] += 1
|
||||
return None
|
||||
|
||||
entry = self._cache[key]
|
||||
|
||||
# Check TTL expiration
|
||||
if entry['expires_at'] < time.time():
|
||||
# Expired - remove and return None
|
||||
del self._cache[key]
|
||||
self._stats['misses'] += 1
|
||||
logger.debug(f"Memory cache expired: {key}")
|
||||
return None
|
||||
|
||||
# Move to end (LRU - most recently used)
|
||||
self._cache.move_to_end(key)
|
||||
|
||||
self._stats['hits'] += 1
|
||||
logger.debug(f"Memory cache HIT: {key}")
|
||||
return entry['value']
|
||||
|
||||
async def set(self, key: str, value: Any, ttl: int):
|
||||
"""
|
||||
Set value in cache
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
value: Value to cache
|
||||
ttl: Time to live in seconds
|
||||
"""
|
||||
expires_at = time.time() + ttl
|
||||
|
||||
# Check if we need to evict (LRU)
|
||||
if key not in self._cache and len(self._cache) >= self.max_size:
|
||||
# Evict oldest entry (first item in OrderedDict)
|
||||
evicted_key = next(iter(self._cache))
|
||||
del self._cache[evicted_key]
|
||||
self._stats['evictions'] += 1
|
||||
logger.debug(f"Memory cache evicted (LRU): {evicted_key}")
|
||||
|
||||
# Store entry
|
||||
self._cache[key] = {
|
||||
'value': value,
|
||||
'expires_at': expires_at,
|
||||
'created_at': time.time()
|
||||
}
|
||||
|
||||
# Move to end (most recently used)
|
||||
self._cache.move_to_end(key)
|
||||
|
||||
self._stats['sets'] += 1
|
||||
logger.debug(f"Memory cache SET: {key} (TTL: {ttl}s)")
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""
|
||||
Delete entry from cache
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
if key in self._cache:
|
||||
del self._cache[key]
|
||||
logger.debug(f"Memory cache deleted: {key}")
|
||||
return True
|
||||
return False
|
||||
|
||||
async def clear(self):
|
||||
"""Clear all entries from cache"""
|
||||
count = len(self._cache)
|
||||
self._cache.clear()
|
||||
logger.info(f"Memory cache cleared: {count} entries removed")
|
||||
|
||||
async def clear_by_pattern(self, pattern: str):
|
||||
"""
|
||||
Clear entries matching pattern (simple prefix match)
|
||||
|
||||
Args:
|
||||
pattern: Key prefix to match (e.g., "dashboard_summary:123")
|
||||
"""
|
||||
keys_to_delete = [key for key in self._cache.keys() if key.startswith(pattern)]
|
||||
for key in keys_to_delete:
|
||||
del self._cache[key]
|
||||
|
||||
logger.info(f"Memory cache cleared by pattern '{pattern}': {len(keys_to_delete)} entries")
|
||||
|
||||
async def cleanup_expired(self):
|
||||
"""Remove all expired entries"""
|
||||
now = time.time()
|
||||
expired_keys = [
|
||||
key for key, entry in self._cache.items()
|
||||
if entry['expires_at'] < now
|
||||
]
|
||||
|
||||
for key in expired_keys:
|
||||
del self._cache[key]
|
||||
|
||||
if expired_keys:
|
||||
logger.info(f"Memory cache cleanup: {len(expired_keys)} expired entries removed")
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get cache statistics
|
||||
|
||||
Returns:
|
||||
Dictionary with stats (hits, misses, size, etc.)
|
||||
"""
|
||||
total_requests = self._stats['hits'] + self._stats['misses']
|
||||
hit_rate = (self._stats['hits'] / total_requests * 100) if total_requests > 0 else 0
|
||||
|
||||
return {
|
||||
'size': len(self._cache),
|
||||
'max_size': self.max_size,
|
||||
'hits': self._stats['hits'],
|
||||
'misses': self._stats['misses'],
|
||||
'sets': self._stats['sets'],
|
||||
'evictions': self._stats['evictions'],
|
||||
'hit_rate': hit_rate,
|
||||
'total_requests': total_requests
|
||||
}
|
||||
|
||||
def reset_stats(self):
|
||||
"""Reset statistics counters"""
|
||||
self._stats = {
|
||||
'hits': 0,
|
||||
'misses': 0,
|
||||
'sets': 0,
|
||||
'evictions': 0
|
||||
}
|
||||
404
backend/modules/reports/cache/sqlite_cache.py
vendored
Normal file
404
backend/modules/reports/cache/sqlite_cache.py
vendored
Normal file
@@ -0,0 +1,404 @@
|
||||
"""
|
||||
SQLite persistent cache (L2 cache)
|
||||
Persistent, survives restarts, unlimited size
|
||||
"""
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
import aiosqlite
|
||||
from typing import Any, Optional, List, Dict
|
||||
from pathlib import Path
|
||||
from decimal import Decimal
|
||||
from datetime import datetime, date
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CustomJSONEncoder(json.JSONEncoder):
|
||||
"""Custom JSON encoder that handles Pydantic models, Decimal, datetime, etc."""
|
||||
def default(self, obj):
|
||||
# Handle Pydantic models
|
||||
if hasattr(obj, 'dict'):
|
||||
return obj.dict()
|
||||
if hasattr(obj, 'model_dump'): # Pydantic v2
|
||||
return obj.model_dump()
|
||||
# Handle Decimal
|
||||
if isinstance(obj, Decimal):
|
||||
return float(obj)
|
||||
# Handle datetime/date
|
||||
if isinstance(obj, (datetime, date)):
|
||||
return obj.isoformat()
|
||||
return super().default(obj)
|
||||
|
||||
|
||||
class SQLiteCache:
|
||||
"""
|
||||
SQLite-based persistent cache
|
||||
|
||||
Features:
|
||||
- Persistent storage (survives restarts)
|
||||
- JSON serialization for complex objects
|
||||
- Schema mappings (permanent cache for company->schema)
|
||||
- Watermarks for event-based invalidation
|
||||
- Performance tracking and benchmarks
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: str):
|
||||
"""
|
||||
Initialize SQLite cache
|
||||
|
||||
Args:
|
||||
db_path: Path to SQLite database file
|
||||
"""
|
||||
self.db_path = db_path
|
||||
self._ensure_db_dir()
|
||||
|
||||
def _ensure_db_dir(self):
|
||||
"""Ensure database directory exists"""
|
||||
db_dir = Path(self.db_path).parent
|
||||
db_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def init_db(self):
|
||||
"""Initialize database schema with WAL mode enabled"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
# Enable Write-Ahead Logging (WAL) mode for better concurrency
|
||||
await db.execute("PRAGMA journal_mode=WAL")
|
||||
await db.commit()
|
||||
|
||||
# Table: cache_entries
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS cache_entries (
|
||||
cache_key TEXT PRIMARY KEY,
|
||||
cache_type TEXT NOT NULL,
|
||||
company_id INTEGER,
|
||||
data_json TEXT NOT NULL,
|
||||
created_at REAL NOT NULL,
|
||||
expires_at REAL NOT NULL,
|
||||
hit_count INTEGER DEFAULT 0,
|
||||
last_accessed REAL
|
||||
)
|
||||
""")
|
||||
await db.execute("CREATE INDEX IF NOT EXISTS idx_cache_type ON cache_entries(cache_type)")
|
||||
await db.execute("CREATE INDEX IF NOT EXISTS idx_company_id ON cache_entries(company_id)")
|
||||
await db.execute("CREATE INDEX IF NOT EXISTS idx_expires_at ON cache_entries(expires_at)")
|
||||
|
||||
# Table: schema_mappings (PERMANENT)
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS schema_mappings (
|
||||
id_firma INTEGER PRIMARY KEY,
|
||||
schema TEXT NOT NULL,
|
||||
created_at REAL NOT NULL,
|
||||
last_verified REAL
|
||||
)
|
||||
""")
|
||||
|
||||
# Table: query_benchmarks
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS query_benchmarks (
|
||||
cache_type TEXT PRIMARY KEY,
|
||||
avg_time_ms REAL NOT NULL,
|
||||
sample_count INTEGER DEFAULT 0,
|
||||
last_updated REAL
|
||||
)
|
||||
""")
|
||||
|
||||
# Table: performance_log
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS performance_log (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
cache_type TEXT NOT NULL,
|
||||
company_id INTEGER,
|
||||
cache_hit BOOLEAN NOT NULL,
|
||||
response_time_ms REAL NOT NULL,
|
||||
estimated_oracle_time_ms REAL,
|
||||
time_saved_ms REAL,
|
||||
username TEXT,
|
||||
timestamp REAL NOT NULL
|
||||
)
|
||||
""")
|
||||
await db.execute("CREATE INDEX IF NOT EXISTS idx_perf_timestamp ON performance_log(timestamp)")
|
||||
await db.execute("CREATE INDEX IF NOT EXISTS idx_perf_cache_type ON performance_log(cache_type)")
|
||||
|
||||
# Table: user_cache_settings
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS user_cache_settings (
|
||||
username TEXT PRIMARY KEY,
|
||||
cache_enabled BOOLEAN DEFAULT TRUE,
|
||||
created_at REAL,
|
||||
updated_at REAL
|
||||
)
|
||||
""")
|
||||
|
||||
# Table: cache_config
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS cache_config (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT NOT NULL,
|
||||
updated_at REAL
|
||||
)
|
||||
""")
|
||||
|
||||
# Table: cache_watermarks
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS cache_watermarks (
|
||||
company_id INTEGER PRIMARY KEY,
|
||||
schema TEXT NOT NULL,
|
||||
max_id_act INTEGER NOT NULL,
|
||||
checked_at REAL NOT NULL
|
||||
)
|
||||
""")
|
||||
|
||||
await db.commit()
|
||||
logger.info("SQLite cache database initialized")
|
||||
|
||||
async def get(self, key: str) -> Optional[Any]:
|
||||
"""
|
||||
Get value from cache
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
|
||||
Returns:
|
||||
Cached value or None if not found/expired
|
||||
"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
async with db.execute("""
|
||||
SELECT data_json, expires_at
|
||||
FROM cache_entries
|
||||
WHERE cache_key = ?
|
||||
""", (key,)) as cursor:
|
||||
result = await cursor.fetchone()
|
||||
|
||||
if not result:
|
||||
return None
|
||||
|
||||
data_json, expires_at = result
|
||||
|
||||
# Check TTL expiration
|
||||
if expires_at < time.time():
|
||||
# Expired - delete and return None
|
||||
await db.execute("DELETE FROM cache_entries WHERE cache_key = ?", (key,))
|
||||
await db.commit()
|
||||
logger.debug(f"SQLite cache expired: {key}")
|
||||
return None
|
||||
|
||||
# Update hit_count and last_accessed
|
||||
await db.execute("""
|
||||
UPDATE cache_entries
|
||||
SET hit_count = hit_count + 1, last_accessed = ?
|
||||
WHERE cache_key = ?
|
||||
""", (time.time(), key))
|
||||
await db.commit()
|
||||
|
||||
logger.debug(f"SQLite cache HIT: {key}")
|
||||
return json.loads(data_json)
|
||||
|
||||
async def set(self, key: str, value: Any, cache_type: str, company_id: Optional[int], ttl: int):
|
||||
"""
|
||||
Set value in cache
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
value: Value to cache
|
||||
cache_type: Type of cache entry
|
||||
company_id: Company ID (None for global caches)
|
||||
ttl: Time to live in seconds
|
||||
"""
|
||||
# Use custom encoder to handle Pydantic models, Decimal, datetime, etc.
|
||||
data_json = json.dumps(value, cls=CustomJSONEncoder)
|
||||
now = time.time()
|
||||
expires_at = now + ttl
|
||||
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute("""
|
||||
INSERT OR REPLACE INTO cache_entries
|
||||
(cache_key, cache_type, company_id, data_json, created_at, expires_at, hit_count, last_accessed)
|
||||
VALUES (?, ?, ?, ?, ?, ?, 0, ?)
|
||||
""", (key, cache_type, company_id, data_json, now, expires_at, now))
|
||||
await db.commit()
|
||||
|
||||
logger.debug(f"SQLite cache SET: {key} (TTL: {ttl}s)")
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""Delete entry from cache"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
cursor = await db.execute("DELETE FROM cache_entries WHERE cache_key = ?", (key,))
|
||||
await db.commit()
|
||||
deleted = cursor.rowcount > 0
|
||||
if deleted:
|
||||
logger.debug(f"SQLite cache deleted: {key}")
|
||||
return deleted
|
||||
|
||||
async def clear(self):
|
||||
"""Clear all cache entries"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
cursor = await db.execute("DELETE FROM cache_entries")
|
||||
await db.commit()
|
||||
count = cursor.rowcount
|
||||
logger.info(f"SQLite cache cleared: {count} entries removed")
|
||||
|
||||
async def clear_by_company(self, company_id: int):
|
||||
"""Clear all entries for specific company"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
cursor = await db.execute("DELETE FROM cache_entries WHERE company_id = ?", (company_id,))
|
||||
await db.commit()
|
||||
count = cursor.rowcount
|
||||
logger.info(f"SQLite cache cleared for company {company_id}: {count} entries")
|
||||
|
||||
async def clear_by_type(self, cache_type: str):
|
||||
"""Clear all entries of specific type"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
cursor = await db.execute("DELETE FROM cache_entries WHERE cache_type = ?", (cache_type,))
|
||||
await db.commit()
|
||||
count = cursor.rowcount
|
||||
logger.info(f"SQLite cache cleared for type '{cache_type}': {count} entries")
|
||||
|
||||
async def cleanup_expired(self):
|
||||
"""Remove all expired entries"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
cursor = await db.execute("DELETE FROM cache_entries WHERE expires_at < ?", (time.time(),))
|
||||
await db.commit()
|
||||
count = cursor.rowcount
|
||||
if count > 0:
|
||||
logger.info(f"SQLite cache cleanup: {count} expired entries removed")
|
||||
|
||||
# Schema Mappings (PERMANENT)
|
||||
|
||||
async def get_schema_mapping(self, company_id: int) -> Optional[str]:
|
||||
"""Get permanent cached schema for company"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
async with db.execute("""
|
||||
SELECT schema
|
||||
FROM schema_mappings
|
||||
WHERE id_firma = ?
|
||||
""", (company_id,)) as cursor:
|
||||
result = await cursor.fetchone()
|
||||
return result[0] if result else None
|
||||
|
||||
async def set_schema_mapping(self, company_id: int, schema: str):
|
||||
"""Set permanent schema mapping (never expires)"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute("""
|
||||
INSERT OR REPLACE INTO schema_mappings
|
||||
(id_firma, schema, created_at, last_verified)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""", (company_id, schema, time.time(), time.time()))
|
||||
await db.commit()
|
||||
|
||||
# Benchmarks
|
||||
|
||||
async def get_benchmark(self, cache_type: str) -> Optional[float]:
|
||||
"""Get average benchmark time for cache type"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
async with db.execute("""
|
||||
SELECT avg_time_ms
|
||||
FROM query_benchmarks
|
||||
WHERE cache_type = ?
|
||||
""", (cache_type,)) as cursor:
|
||||
result = await cursor.fetchone()
|
||||
return result[0] if result else None
|
||||
|
||||
async def set_benchmark(self, cache_type: str, avg_time_ms: float, sample_count: int):
|
||||
"""Set/update benchmark"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute("""
|
||||
INSERT OR REPLACE INTO query_benchmarks
|
||||
(cache_type, avg_time_ms, sample_count, last_updated)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""", (cache_type, avg_time_ms, sample_count, time.time()))
|
||||
await db.commit()
|
||||
|
||||
# Performance Tracking
|
||||
|
||||
async def log_performance(self, cache_type: str, company_id: Optional[int], cache_hit: bool,
|
||||
response_time_ms: float, estimated_oracle_time_ms: Optional[float],
|
||||
time_saved_ms: Optional[float], username: Optional[str]):
|
||||
"""Log performance metric"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute("""
|
||||
INSERT INTO performance_log
|
||||
(cache_type, company_id, cache_hit, response_time_ms, estimated_oracle_time_ms,
|
||||
time_saved_ms, username, timestamp)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (cache_type, company_id, cache_hit, response_time_ms, estimated_oracle_time_ms,
|
||||
time_saved_ms, username, time.time()))
|
||||
await db.commit()
|
||||
|
||||
# User Settings
|
||||
|
||||
async def get_user_cache_enabled(self, username: str) -> bool:
|
||||
"""Get user cache setting (default True)"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
async with db.execute("""
|
||||
SELECT cache_enabled
|
||||
FROM user_cache_settings
|
||||
WHERE username = ?
|
||||
""", (username,)) as cursor:
|
||||
result = await cursor.fetchone()
|
||||
return bool(result[0]) if result else True # Default enabled, explicit bool conversion
|
||||
|
||||
async def set_user_cache_enabled(self, username: str, enabled: bool):
|
||||
"""Set user cache setting"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute("""
|
||||
INSERT OR REPLACE INTO user_cache_settings
|
||||
(username, cache_enabled, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""", (username, enabled, time.time(), time.time()))
|
||||
await db.commit()
|
||||
|
||||
# Watermarks
|
||||
|
||||
async def get_watermark(self, company_id: int) -> Optional[int]:
|
||||
"""Get cached watermark (max_id_act) for company"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
async with db.execute("""
|
||||
SELECT max_id_act
|
||||
FROM cache_watermarks
|
||||
WHERE company_id = ?
|
||||
""", (company_id,)) as cursor:
|
||||
result = await cursor.fetchone()
|
||||
return result[0] if result else None
|
||||
|
||||
async def set_watermark(self, company_id: int, schema: str, max_id_act: int):
|
||||
"""Set/update watermark for company"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute("""
|
||||
INSERT OR REPLACE INTO cache_watermarks
|
||||
(company_id, schema, max_id_act, checked_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""", (company_id, schema, max_id_act, time.time()))
|
||||
await db.commit()
|
||||
|
||||
async def get_cached_company_ids(self) -> List[int]:
|
||||
"""Get list of company_ids with active cache entries"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
async with db.execute("""
|
||||
SELECT DISTINCT company_id
|
||||
FROM cache_entries
|
||||
WHERE company_id IS NOT NULL
|
||||
AND expires_at > ?
|
||||
""", (time.time(),)) as cursor:
|
||||
results = await cursor.fetchall()
|
||||
return [row[0] for row in results]
|
||||
|
||||
# Statistics
|
||||
|
||||
async def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get cache statistics"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
# Total entries
|
||||
async with db.execute("SELECT COUNT(*) FROM cache_entries") as cursor:
|
||||
total_entries = (await cursor.fetchone())[0]
|
||||
|
||||
# Active entries (not expired)
|
||||
async with db.execute("""
|
||||
SELECT COUNT(*) FROM cache_entries WHERE expires_at > ?
|
||||
""", (time.time(),)) as cursor:
|
||||
active_entries = (await cursor.fetchone())[0]
|
||||
|
||||
return {
|
||||
'total_entries': total_entries,
|
||||
'active_entries': active_entries,
|
||||
'expired_entries': total_entries - active_entries
|
||||
}
|
||||
0
backend/modules/reports/models/__init__.py
Normal file
0
backend/modules/reports/models/__init__.py
Normal file
19
backend/modules/reports/models/calendar.py
Normal file
19
backend/modules/reports/models/calendar.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
Calendar period models for accounting period selector
|
||||
"""
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class CalendarPeriod(BaseModel):
|
||||
"""Model for an accounting period"""
|
||||
an: int # Year
|
||||
luna: int # Month (1-12)
|
||||
display_name: str # Format: "Decembrie 2025"
|
||||
|
||||
|
||||
class CalendarPeriodsResponse(BaseModel):
|
||||
"""Response model for calendar periods list"""
|
||||
periods: List[CalendarPeriod]
|
||||
current_period: Optional[CalendarPeriod] = None # Most recent period
|
||||
total_count: int
|
||||
129
backend/modules/reports/models/dashboard.py
Normal file
129
backend/modules/reports/models/dashboard.py
Normal file
@@ -0,0 +1,129 @@
|
||||
from pydantic import BaseModel
|
||||
from decimal import Decimal
|
||||
from typing import List, Dict, Optional, Any
|
||||
|
||||
class TreasuryAccount(BaseModel):
|
||||
"""Cont de trezorerie (bancă/casă)"""
|
||||
cont: str # 5121, 5124, 5311, 5314
|
||||
nume_cont: str # "Bancă LEI", "Casă VALUTA" etc
|
||||
nume_banca: str # Numele băncii din vbalanta_parteneri.nume
|
||||
sold: Decimal
|
||||
valuta: str
|
||||
|
||||
class TrendData(BaseModel):
|
||||
"""Model pentru datele de trend - MODEL VECHI"""
|
||||
labels: List[str]
|
||||
incasari: List[Decimal]
|
||||
plati: List[Decimal]
|
||||
trezorerie: List[Decimal]
|
||||
incasari_total: Decimal
|
||||
plati_total: Decimal
|
||||
trezorerie_total: Decimal
|
||||
incasari_change: Optional[float] = None
|
||||
plati_change: Optional[float] = None
|
||||
trezorerie_change: Optional[float] = None
|
||||
|
||||
class TrendsResponse(BaseModel):
|
||||
"""Model pentru răspunsul endpoint-ului de trenduri - MODEL NOU"""
|
||||
# Current period data
|
||||
periods: List[str]
|
||||
clienti_facturat: List[float]
|
||||
clienti_incasat: List[float]
|
||||
furnizori_facturat: List[float]
|
||||
furnizori_achitat: List[float]
|
||||
clienti_sold: List[float]
|
||||
furnizori_sold: List[float]
|
||||
trezorerie_sold: Optional[List[float]] = None
|
||||
rata_incasare_clienti: List[float]
|
||||
rata_achitare_furnizori: List[float]
|
||||
|
||||
# Previous period data (for year-over-year comparison in sparklines)
|
||||
previous_periods: Optional[List[str]] = None
|
||||
clienti_facturat_prev: Optional[List[float]] = None
|
||||
clienti_incasat_prev: Optional[List[float]] = None
|
||||
furnizori_facturat_prev: Optional[List[float]] = None
|
||||
furnizori_achitat_prev: Optional[List[float]] = None
|
||||
clienti_sold_prev: Optional[List[float]] = None
|
||||
furnizori_sold_prev: Optional[List[float]] = None
|
||||
trezorerie_sold_prev: Optional[List[float]] = None
|
||||
|
||||
# Metadata and analytics
|
||||
metadata: Dict[str, Any]
|
||||
growth_rates: Optional[Dict[str, float]] = None
|
||||
|
||||
# Cache metadata (optional, for Telegram Bot)
|
||||
cache_hit: Optional[bool] = None
|
||||
response_time_ms: Optional[float] = None
|
||||
cache_source: Optional[str] = None
|
||||
|
||||
class DashboardSummary(BaseModel):
|
||||
"""Model pentru toate datele dashboard-ului"""
|
||||
# CLIENȚI - statistici existente
|
||||
clienti_total_facturat: Decimal # precdeb + debit (conturi 4111, 461)
|
||||
clienti_total_incasat: Decimal # preccred + credit (conturi 4111, 461)
|
||||
clienti_avansuri: Decimal # sold 419 (pasiv): credit - debit
|
||||
clienti_sold_total: Decimal # (facturat - incasat) - avansuri
|
||||
clienti_sold_restant: Decimal # sold cu datascad < azi
|
||||
|
||||
# CLIENȚI - NOI câmpuri pentru sold în termen
|
||||
clienti_sold_in_termen: Decimal # sold cu datascad >= azi
|
||||
|
||||
# CLIENȚI - NOI detalieri restanțe (sold cu datascad < azi)
|
||||
clienti_restant_7: Decimal # restant 1-7 zile
|
||||
clienti_restant_14: Decimal # restant 8-14 zile
|
||||
clienti_restant_30: Decimal # restant 15-30 zile
|
||||
clienti_restant_60: Decimal # restant 31-60 zile
|
||||
clienti_restant_90: Decimal # restant 61-90 zile
|
||||
clienti_restant_90plus: Decimal # restant 90+ zile
|
||||
|
||||
# CLIENȚI - NOI detalieri scadențe (sold cu datascad >= azi)
|
||||
clienti_scadent_7: Decimal # scadent în 1-7 zile
|
||||
clienti_scadent_14: Decimal # scadent în 8-14 zile
|
||||
clienti_scadent_30: Decimal # scadent în 15-30 zile
|
||||
clienti_scadent_60: Decimal # scadent în 31-60 zile
|
||||
clienti_scadent_90: Decimal # scadent în 61-90 zile
|
||||
clienti_scadent_90plus: Decimal # scadent în 90+ zile
|
||||
|
||||
# FURNIZORI - statistici existente
|
||||
furnizori_total_facturat: Decimal # preccred + credit (conturi 401, 404, 462)
|
||||
furnizori_total_achitat: Decimal # precdeb + debit (conturi 401, 404, 462)
|
||||
furnizori_avansuri: Decimal # sold 409x (activ): debit - credit
|
||||
furnizori_sold_total: Decimal # (facturat - achitat) - avansuri
|
||||
furnizori_sold_restant: Decimal # sold cu datascad < azi
|
||||
|
||||
# FURNIZORI - NOI câmpuri pentru sold în termen
|
||||
furnizori_sold_in_termen: Decimal # sold cu datascad >= azi
|
||||
|
||||
# FURNIZORI - NOI detalieri restanțe (sold cu datascad < azi)
|
||||
furnizori_restant_7: Decimal # restant 1-7 zile
|
||||
furnizori_restant_14: Decimal # restant 8-14 zile
|
||||
furnizori_restant_30: Decimal # restant 15-30 zile
|
||||
furnizori_restant_60: Decimal # restant 31-60 zile
|
||||
furnizori_restant_90: Decimal # restant 61-90 zile
|
||||
furnizori_restant_90plus: Decimal # restant 90+ zile
|
||||
|
||||
# FURNIZORI - NOI detalieri scadențe (sold cu datascad >= azi)
|
||||
furnizori_scadent_7: Decimal # scadent în 1-7 zile
|
||||
furnizori_scadent_14: Decimal # scadent în 8-14 zile
|
||||
furnizori_scadent_30: Decimal # scadent în 15-30 zile
|
||||
furnizori_scadent_60: Decimal # scadent în 31-60 zile
|
||||
furnizori_scadent_90: Decimal # scadent în 61-90 zile
|
||||
furnizori_scadent_90plus: Decimal # scadent în 90+ zile
|
||||
|
||||
# TREZORERIE - existente
|
||||
treasury_accounts: List[TreasuryAccount]
|
||||
treasury_totals_by_currency: Dict[str, Decimal]
|
||||
|
||||
# DATE SUPLIMENTARE pentru trend analysis
|
||||
clienti_facturat_luna_anterioara: Optional[Decimal] = Decimal('0')
|
||||
furnizori_facturat_luna_anterioara: Optional[Decimal] = Decimal('0')
|
||||
clienti_facturat_an_curent: Optional[Decimal] = Decimal('0')
|
||||
clienti_facturat_an_anterior: Optional[Decimal] = Decimal('0')
|
||||
furnizori_facturat_an_curent: Optional[Decimal] = Decimal('0')
|
||||
furnizori_facturat_an_anterior: Optional[Decimal] = Decimal('0')
|
||||
|
||||
# SOLDURI TVA
|
||||
tva_plata_precedent: Decimal = Decimal('0')
|
||||
tva_recuperat_precedent: Decimal = Decimal('0')
|
||||
tva_plata_curent: Decimal = Decimal('0')
|
||||
tva_recuperat_curent: Decimal = Decimal('0')
|
||||
79
backend/modules/reports/models/invoice.py
Normal file
79
backend/modules/reports/models/invoice.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
Modele Pydantic pentru facturi - Compatibile cu aplicația Flask existentă
|
||||
"""
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from datetime import date
|
||||
from typing import Optional, List, Literal
|
||||
from decimal import Decimal
|
||||
|
||||
class InvoiceBase(BaseModel):
|
||||
"""Model de bază pentru factură - mapează exact pe rezultatul query-ului Flask"""
|
||||
nume: str = Field(description="Numele partenerului")
|
||||
nract: int = Field(description="Numărul actului")
|
||||
dataact: Optional[date] = Field(description="Data actului")
|
||||
datascad: Optional[date] = Field(description="Data scadentă")
|
||||
contract: Optional[str] = Field(description="Numărul contractului")
|
||||
cod_fiscal: Optional[str] = Field(description="Codul fiscal")
|
||||
reg_comert: Optional[str] = Field(description="Registrul comerțului")
|
||||
cont: Optional[str] = Field(description="Contul contabil")
|
||||
valuta: str = Field(default="RON", description="Valuta (RON, EUR, USD, etc.)")
|
||||
|
||||
class Invoice(InvoiceBase):
|
||||
"""Model complet pentru factură cu calcule financiare"""
|
||||
totctva: Decimal = Field(description="Total cu TVA", decimal_places=2)
|
||||
achitat: Decimal = Field(description="Suma achitată", decimal_places=2)
|
||||
soldfinal: Decimal = Field(description="Soldul final", decimal_places=2)
|
||||
css_class: Literal["", "invoice-paid", "invoice-overdue"] = Field(
|
||||
default="", description="Clasa CSS pentru stilizare"
|
||||
)
|
||||
|
||||
@validator('css_class', always=True)
|
||||
def determine_css_class(cls, v, values):
|
||||
"""Determină automat clasa CSS bazată pe status factură"""
|
||||
if 'soldfinal' in values and 'datascad' in values:
|
||||
sold = values['soldfinal']
|
||||
data_scad = values['datascad']
|
||||
|
||||
if sold < 1:
|
||||
return 'invoice-paid'
|
||||
elif data_scad and data_scad < date.today() and sold != 0:
|
||||
return 'invoice-overdue'
|
||||
return ''
|
||||
|
||||
class InvoiceFilter(BaseModel):
|
||||
"""Filtru pentru căutarea facturilor"""
|
||||
company: str = Field(description="Codul firmei (schema Oracle)")
|
||||
partner_type: Literal["CLIENTI", "FURNIZORI"] = Field(description="Tipul partenerului")
|
||||
luna: Optional[int] = Field(default=None, ge=1, le=12, description="Luna contabilă (1-12)")
|
||||
an: Optional[int] = Field(default=None, ge=2000, le=2100, description="Anul contabil")
|
||||
partner_name: Optional[str] = Field(description="Filtru după nume")
|
||||
cont: Optional[str] = Field(description="Filtru după cont contabil")
|
||||
only_unpaid: bool = Field(default=True, description="Doar neachitate")
|
||||
min_amount: Optional[Decimal] = Field(description="Suma minimă")
|
||||
max_amount: Optional[Decimal] = Field(description="Suma maximă")
|
||||
page: int = Field(default=1, ge=1, description="Pagina")
|
||||
page_size: int = Field(default=50, ge=1, le=10000000, description="Mărimea paginii")
|
||||
|
||||
class InvoiceListResponse(BaseModel):
|
||||
"""Răspuns pentru lista de facturi"""
|
||||
invoices: List[Invoice]
|
||||
total_count: int
|
||||
filtered_count: int
|
||||
total_amount: Decimal
|
||||
page: int
|
||||
page_size: int
|
||||
has_more: bool
|
||||
accounting_period: Optional[dict] = Field(default=None, description="Perioada contabilă (an, luna)")
|
||||
# Total sold din TOATE facturile filtrate (nu doar pagina curentă)
|
||||
total_sold_all: Decimal = Field(default=Decimal('0.00'), description="Total sold din toate facturile filtrate")
|
||||
|
||||
class InvoiceSummary(BaseModel):
|
||||
"""Rezumat pentru facturi - pentru dashboard"""
|
||||
company: str
|
||||
partner_type: str
|
||||
total_invoices: int
|
||||
total_amount: Decimal
|
||||
paid_amount: Decimal
|
||||
outstanding_amount: Decimal
|
||||
overdue_amount: Decimal
|
||||
overdue_count: int
|
||||
52
backend/modules/reports/models/treasury.py
Normal file
52
backend/modules/reports/models/treasury.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from pydantic import BaseModel
|
||||
from decimal import Decimal
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
|
||||
class AccountingPeriod(BaseModel):
|
||||
"""Model pentru perioada contabilă"""
|
||||
an: Optional[int] = None
|
||||
luna: Optional[int] = None
|
||||
|
||||
class BankCashRegister(BaseModel):
|
||||
"""Model pentru Registrul de Casă și Bancă"""
|
||||
nume: str
|
||||
nract: Optional[int] = None
|
||||
dataact: Optional[datetime] = None
|
||||
nume_cont_bancar: str # din vbalanta_parteneri.nume
|
||||
incasari: Decimal
|
||||
plati: Decimal
|
||||
sold: Decimal
|
||||
valuta: Optional[str] = None
|
||||
tip_registru: str # "BANCA LEI", "CASA VALUTA" etc
|
||||
explicatia: str
|
||||
|
||||
class RegisterFilter(BaseModel):
|
||||
"""Filtre pentru registrul de casă și bancă"""
|
||||
company: str
|
||||
register_type: Optional[str] = None # BANCA_LEI, BANCA_VALUTA, CASA_LEI, CASA_VALUTA sau None pentru toate
|
||||
luna: Optional[int] = None # Luna contabilă (1-12) pentru PACK_SESIUNE
|
||||
an: Optional[int] = None # Anul contabil pentru PACK_SESIUNE
|
||||
date_from: Optional[datetime] = None
|
||||
date_to: Optional[datetime] = None
|
||||
partner_name: Optional[str] = None
|
||||
bank_account: Optional[str] = None # Filter for specific bank/cash account (bancasa)
|
||||
page: int = 1
|
||||
page_size: int = 50
|
||||
|
||||
class RegisterListResponse(BaseModel):
|
||||
"""Răspuns pentru lista din registru"""
|
||||
registers: List[BankCashRegister]
|
||||
total_count: int
|
||||
filtered_count: int
|
||||
total_incasari: Decimal
|
||||
total_plati: Decimal
|
||||
page: int
|
||||
page_size: int
|
||||
has_more: bool
|
||||
accounting_period: Optional[AccountingPeriod] = None
|
||||
# Totaluri din TOATE înregistrările filtrate (nu doar pagina curentă)
|
||||
sold_precedent_all: Decimal = Decimal('0.00')
|
||||
total_incasari_all: Decimal = Decimal('0.00')
|
||||
total_plati_all: Decimal = Decimal('0.00')
|
||||
sold_final_all: Decimal = Decimal('0.00')
|
||||
102
backend/modules/reports/models/trial_balance.py
Normal file
102
backend/modules/reports/models/trial_balance.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""
|
||||
Pydantic models for Trial Balance (Balanță de Verificare)
|
||||
Maps to Oracle VBAL VIEW (exists in each company schema)
|
||||
"""
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, List
|
||||
from decimal import Decimal
|
||||
|
||||
class TrialBalanceItem(BaseModel):
|
||||
"""
|
||||
Individual trial balance record from VBAL VIEW
|
||||
Real structure from Oracle:
|
||||
- CONT: account number
|
||||
- DENUMIRE: account description
|
||||
- PRECDEB/PRECCRED: previous balance debit/credit
|
||||
- RULDEB/RULCRED: monthly movement debit/credit
|
||||
- SOLDDEB/SOLDCRED: final balance debit/credit
|
||||
"""
|
||||
cont: str = Field(description="Număr cont contabil (CONT)")
|
||||
denumire: Optional[str] = Field(default="", description="Denumire cont (DENUMIRE)")
|
||||
sold_precedent_debit: Decimal = Field(description="Sold precedent debit (PRECDEB)", decimal_places=2)
|
||||
sold_precedent_credit: Decimal = Field(description="Sold precedent credit (PRECCRED)", decimal_places=2)
|
||||
rulaj_lunar_debit: Decimal = Field(description="Rulaj lunar debit (RULDEB)", decimal_places=2)
|
||||
rulaj_lunar_credit: Decimal = Field(description="Rulaj lunar credit (RULCRED)", decimal_places=2)
|
||||
sold_final_debit: Decimal = Field(description="Sold final debit (SOLDDEB)", decimal_places=2)
|
||||
sold_final_credit: Decimal = Field(description="Sold final credit (SOLDCRED)", decimal_places=2)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class TrialBalanceFilters(BaseModel):
|
||||
"""
|
||||
Filters applied to trial balance data
|
||||
"""
|
||||
luna: int = Field(description="Luna (1-12)")
|
||||
an: int = Field(description="An")
|
||||
cont_filter: Optional[str] = Field(default=None, description="Filtru număr cont (partial match)")
|
||||
denumire_filter: Optional[str] = Field(default=None, description="Filtru denumire cont (partial match, case-insensitive)")
|
||||
|
||||
|
||||
class TrialBalancePagination(BaseModel):
|
||||
"""
|
||||
Pagination metadata
|
||||
"""
|
||||
total_items: int = Field(description="Total number of items")
|
||||
total_pages: int = Field(description="Total number of pages")
|
||||
current_page: int = Field(description="Current page number")
|
||||
page_size: int = Field(description="Items per page")
|
||||
|
||||
|
||||
class TrialBalanceTotals(BaseModel):
|
||||
"""
|
||||
Totals for all 6 columns from all filtered records (not just current page)
|
||||
"""
|
||||
total_sold_precedent_debit: Decimal = Decimal('0.00')
|
||||
total_sold_precedent_credit: Decimal = Decimal('0.00')
|
||||
total_rulaj_lunar_debit: Decimal = Decimal('0.00')
|
||||
total_rulaj_lunar_credit: Decimal = Decimal('0.00')
|
||||
total_sold_final_debit: Decimal = Decimal('0.00')
|
||||
total_sold_final_credit: Decimal = Decimal('0.00')
|
||||
|
||||
|
||||
class TrialBalanceResponse(BaseModel):
|
||||
"""
|
||||
Complete response for trial balance endpoint
|
||||
"""
|
||||
success: bool = Field(default=True, description="Request success status")
|
||||
data: dict = Field(description="Trial balance data with items, pagination, and filters")
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"success": True,
|
||||
"data": {
|
||||
"items": [
|
||||
{
|
||||
"cont": "4111",
|
||||
"dcont": "Furnizori interni",
|
||||
"sold_precedent_debit": 0.00,
|
||||
"sold_precedent_credit": 15000.00,
|
||||
"rulaj_lunar_debit": 5000.00,
|
||||
"rulaj_lunar_credit": 8000.00,
|
||||
"sold_final_debit": 0.00,
|
||||
"sold_final_credit": 18000.00
|
||||
}
|
||||
],
|
||||
"pagination": {
|
||||
"total_items": 150,
|
||||
"total_pages": 3,
|
||||
"current_page": 1,
|
||||
"page_size": 50
|
||||
},
|
||||
"filters_applied": {
|
||||
"luna": 11,
|
||||
"an": 2025,
|
||||
"cont_filter": None,
|
||||
"denumire_filter": "furnizori"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
36
backend/modules/reports/routers/__init__.py
Normal file
36
backend/modules/reports/routers/__init__.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Reports module router factory."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
|
||||
def create_reports_router() -> APIRouter:
|
||||
"""
|
||||
Create and configure Reports module router.
|
||||
|
||||
Includes all report-related endpoints:
|
||||
- /invoices - Invoice management
|
||||
- /dashboard - Dashboard and metrics
|
||||
- /treasury - Treasury operations
|
||||
- /trial-balance - Trial balance reports
|
||||
- /cache - Cache management
|
||||
|
||||
Returns:
|
||||
APIRouter: Configured router for reports module
|
||||
"""
|
||||
router = APIRouter()
|
||||
|
||||
# Import routers here to avoid circular imports
|
||||
from .invoices import router as invoices_router
|
||||
from .dashboard import router as dashboard_router
|
||||
from .treasury import router as treasury_router
|
||||
from .trial_balance import router as trial_balance_router
|
||||
from .cache import router as cache_router
|
||||
|
||||
# Include all sub-routers (no prefix - already prefixed in main.py with /api/reports)
|
||||
router.include_router(invoices_router, prefix="/invoices", tags=["reports-invoices"])
|
||||
router.include_router(dashboard_router, prefix="/dashboard", tags=["reports-dashboard"])
|
||||
router.include_router(treasury_router, prefix="/treasury", tags=["reports-treasury"])
|
||||
router.include_router(trial_balance_router, prefix="/trial-balance", tags=["reports-trial-balance"])
|
||||
router.include_router(cache_router, prefix="/cache", tags=["reports-cache"])
|
||||
|
||||
return router
|
||||
398
backend/modules/reports/routers/cache.py
Normal file
398
backend/modules/reports/routers/cache.py
Normal file
@@ -0,0 +1,398 @@
|
||||
"""
|
||||
API Router pentru managementul cache-ului
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, Dict, Any
|
||||
# import sys # Removed - no longer needed
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
from shared.auth.dependencies import get_current_user
|
||||
from shared.auth.models import CurrentUser
|
||||
from ..cache import get_cache, get_event_monitor, toggle_event_monitor
|
||||
|
||||
router = APIRouter(prefix="/cache", tags=["cache"])
|
||||
|
||||
|
||||
# Pydantic Models
|
||||
|
||||
class CacheStatsResponse(BaseModel):
|
||||
"""Răspuns statistici cache"""
|
||||
enabled: bool
|
||||
global_enabled: bool
|
||||
user_enabled: bool
|
||||
cache_type: str
|
||||
hit_rate: float
|
||||
total_hits: int
|
||||
total_misses: int
|
||||
queries_saved: Dict[str, int]
|
||||
response_times: Dict[str, Dict[str, Any]]
|
||||
cache_size: Dict[str, int]
|
||||
auto_invalidate: bool
|
||||
last_cleanup: Optional[str] = None
|
||||
|
||||
|
||||
class InvalidateCacheRequest(BaseModel):
|
||||
"""Request pentru invalidare cache"""
|
||||
company_id: Optional[int] = None
|
||||
cache_type: Optional[str] = None
|
||||
|
||||
|
||||
class ToggleUserCacheRequest(BaseModel):
|
||||
"""Request pentru toggle cache per-user"""
|
||||
enabled: bool
|
||||
|
||||
|
||||
class ToggleGlobalCacheRequest(BaseModel):
|
||||
"""Request pentru toggle cache global"""
|
||||
enabled: bool
|
||||
|
||||
|
||||
class ToggleAutoInvalidateRequest(BaseModel):
|
||||
"""Request pentru toggle auto-invalidation"""
|
||||
enabled: bool
|
||||
|
||||
|
||||
# Helper Functions
|
||||
|
||||
async def _calculate_cache_stats() -> Dict[str, Any]:
|
||||
"""Calculate comprehensive cache statistics"""
|
||||
cache = get_cache()
|
||||
if not cache:
|
||||
raise HTTPException(status_code=503, detail="Cache not initialized")
|
||||
|
||||
# Get basic cache stats
|
||||
stats = await cache.get_stats()
|
||||
|
||||
# Calculate hit rate
|
||||
memory_stats = stats.get('memory', {})
|
||||
total_hits = memory_stats.get('hits', 0)
|
||||
total_misses = memory_stats.get('misses', 0)
|
||||
total_requests = total_hits + total_misses
|
||||
hit_rate = (total_hits / total_requests * 100) if total_requests > 0 else 0
|
||||
|
||||
# Calculate queries saved (from performance_log)
|
||||
queries_saved = await _calculate_queries_saved(cache)
|
||||
|
||||
# Calculate response times per cache type
|
||||
response_times = await _calculate_response_times(cache)
|
||||
|
||||
# Get cache sizes
|
||||
cache_size = {
|
||||
'memory': memory_stats.get('size', 0),
|
||||
'sqlite': stats.get('sqlite', {}).get('active_entries', 0)
|
||||
}
|
||||
|
||||
# Get event monitor status
|
||||
monitor = get_event_monitor()
|
||||
auto_invalidate = monitor.running if monitor else False
|
||||
|
||||
return {
|
||||
'enabled': cache.config.enabled,
|
||||
'global_enabled': cache.config.enabled,
|
||||
'cache_type': cache.config.cache_type,
|
||||
'hit_rate': round(hit_rate, 1),
|
||||
'total_hits': total_hits,
|
||||
'total_misses': total_misses,
|
||||
'queries_saved': queries_saved,
|
||||
'response_times': response_times,
|
||||
'cache_size': cache_size,
|
||||
'auto_invalidate': auto_invalidate,
|
||||
'last_cleanup': None # TODO: track last cleanup time
|
||||
}
|
||||
|
||||
|
||||
async def _calculate_queries_saved(cache) -> Dict[str, int]:
|
||||
"""Calculate queries saved by time period"""
|
||||
import aiosqlite
|
||||
|
||||
try:
|
||||
async with aiosqlite.connect(cache.sqlite.db_path) as db:
|
||||
now = time.time()
|
||||
today_start = now - 86400 # 24 hours
|
||||
week_start = now - 604800 # 7 days
|
||||
|
||||
# Today
|
||||
async with db.execute("""
|
||||
SELECT COUNT(*) FROM performance_log
|
||||
WHERE cache_hit = 1 AND timestamp >= ?
|
||||
""", (today_start,)) as cursor:
|
||||
today = (await cursor.fetchone())[0]
|
||||
|
||||
# This week
|
||||
async with db.execute("""
|
||||
SELECT COUNT(*) FROM performance_log
|
||||
WHERE cache_hit = 1 AND timestamp >= ?
|
||||
""", (week_start,)) as cursor:
|
||||
week = (await cursor.fetchone())[0]
|
||||
|
||||
# All time
|
||||
async with db.execute("""
|
||||
SELECT COUNT(*) FROM performance_log
|
||||
WHERE cache_hit = 1
|
||||
""") as cursor:
|
||||
total = (await cursor.fetchone())[0]
|
||||
|
||||
return {
|
||||
'today': today,
|
||||
'week': week,
|
||||
'total': total
|
||||
}
|
||||
except Exception as e:
|
||||
return {'today': 0, 'week': 0, 'total': 0}
|
||||
|
||||
|
||||
async def _calculate_response_times(cache) -> Dict[str, Dict[str, Any]]:
|
||||
"""Calculate average response times per cache type"""
|
||||
import aiosqlite
|
||||
|
||||
try:
|
||||
async with aiosqlite.connect(cache.sqlite.db_path) as db:
|
||||
# Get average times per cache type
|
||||
async with db.execute("""
|
||||
SELECT
|
||||
cache_type,
|
||||
AVG(CASE WHEN cache_hit = 1 THEN response_time_ms ELSE NULL END) as avg_cached,
|
||||
AVG(CASE WHEN cache_hit = 0 THEN response_time_ms ELSE NULL END) as avg_oracle
|
||||
FROM performance_log
|
||||
WHERE timestamp >= ?
|
||||
GROUP BY cache_type
|
||||
""", (time.time() - 86400,)) as cursor: # Last 24 hours
|
||||
results = await cursor.fetchall()
|
||||
|
||||
response_times = {}
|
||||
for row in results:
|
||||
cache_type, avg_cached, avg_oracle = row
|
||||
if avg_cached and avg_oracle:
|
||||
improvement = int((avg_oracle - avg_cached) / avg_oracle * 100)
|
||||
response_times[cache_type] = {
|
||||
'cached': int(avg_cached),
|
||||
'oracle': int(avg_oracle),
|
||||
'improvement': improvement
|
||||
}
|
||||
|
||||
return response_times
|
||||
except Exception as e:
|
||||
return {}
|
||||
|
||||
|
||||
# API Endpoints
|
||||
|
||||
@router.get("/stats", response_model=CacheStatsResponse)
|
||||
async def get_cache_stats(
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Obține statistici complete cache
|
||||
|
||||
Returns:
|
||||
- Hit rate, queries saved, response times
|
||||
- Cache sizes (memory + SQLite)
|
||||
- Auto-invalidation status
|
||||
- Per-user cache setting
|
||||
"""
|
||||
try:
|
||||
cache = get_cache()
|
||||
if not cache:
|
||||
raise HTTPException(status_code=503, detail="Cache not initialized")
|
||||
|
||||
# Get base stats
|
||||
stats = await _calculate_cache_stats()
|
||||
|
||||
# Add user-specific setting
|
||||
user_enabled = await cache.is_enabled_for_user(current_user.username)
|
||||
stats['user_enabled'] = user_enabled
|
||||
|
||||
return CacheStatsResponse(**stats)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error retrieving cache stats: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/invalidate")
|
||||
async def invalidate_cache(
|
||||
request: InvalidateCacheRequest,
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Invalidează cache
|
||||
|
||||
Args:
|
||||
company_id: Opțional - invalidează doar pentru această companie
|
||||
cache_type: Opțional - invalidează doar acest tip de cache
|
||||
|
||||
Returns:
|
||||
Message de confirmare
|
||||
"""
|
||||
try:
|
||||
cache = get_cache()
|
||||
if not cache:
|
||||
raise HTTPException(status_code=503, detail="Cache not initialized")
|
||||
|
||||
await cache.invalidate(
|
||||
company_id=request.company_id,
|
||||
cache_type=request.cache_type
|
||||
)
|
||||
|
||||
if request.company_id and request.cache_type:
|
||||
message = f"Cache invalidated for company {request.company_id}, type {request.cache_type}"
|
||||
elif request.company_id:
|
||||
message = f"Cache invalidated for company {request.company_id}"
|
||||
elif request.cache_type:
|
||||
message = f"Cache invalidated for type {request.cache_type}"
|
||||
else:
|
||||
message = "All cache invalidated"
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": message,
|
||||
"invalidated_at": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error invalidating cache: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/toggle-user")
|
||||
async def toggle_user_cache(
|
||||
request: ToggleUserCacheRequest,
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Toggle cache per-user
|
||||
|
||||
Permite utilizatorului să activeze/dezactiveze cache-ul pentru el
|
||||
Folosit pentru A/B testing și comparații de performanță
|
||||
|
||||
Args:
|
||||
enabled: True pentru activare, False pentru dezactivare
|
||||
|
||||
Returns:
|
||||
Noul status
|
||||
"""
|
||||
try:
|
||||
cache = get_cache()
|
||||
if not cache:
|
||||
raise HTTPException(status_code=503, detail="Cache not initialized")
|
||||
|
||||
await cache.set_user_cache_enabled(current_user.username, request.enabled)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"username": current_user.username,
|
||||
"cache_enabled": request.enabled,
|
||||
"message": f"Cache {'enabled' if request.enabled else 'disabled'} for user {current_user.username}"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error toggling user cache: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/toggle-global")
|
||||
async def toggle_global_cache(
|
||||
request: ToggleGlobalCacheRequest,
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Toggle cache global (ADMIN only)
|
||||
|
||||
Activează/dezactivează cache-ul la nivel global pentru toți utilizatorii
|
||||
|
||||
Args:
|
||||
enabled: True pentru activare, False pentru dezactivare
|
||||
|
||||
Returns:
|
||||
Noul status global
|
||||
"""
|
||||
try:
|
||||
# TODO: Add admin permission check
|
||||
# For now, allow any authenticated user
|
||||
|
||||
cache = get_cache()
|
||||
if not cache:
|
||||
raise HTTPException(status_code=503, detail="Cache not initialized")
|
||||
|
||||
# Update config (NOTE: This is runtime only, .env needs manual update)
|
||||
cache.config.enabled = request.enabled
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"global_enabled": request.enabled,
|
||||
"message": f"Cache {'enabled' if request.enabled else 'disabled'} globally",
|
||||
"note": "This change is runtime only. Update .env CACHE_ENABLED for persistence."
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error toggling global cache: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/toggle-auto-invalidate")
|
||||
async def toggle_auto_invalidation(
|
||||
request: ToggleAutoInvalidateRequest,
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Toggle auto-invalidation monitoring
|
||||
|
||||
Activează/dezactivează monitorizarea automată a {schema}.act
|
||||
pentru invalidarea cache-ului când se detectează modificări
|
||||
|
||||
Args:
|
||||
enabled: True pentru activare, False pentru dezactivare
|
||||
|
||||
Returns:
|
||||
Noul status auto-invalidation
|
||||
"""
|
||||
try:
|
||||
# TODO: Add admin permission check
|
||||
# For now, allow any authenticated user
|
||||
|
||||
await toggle_event_monitor(request.enabled)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"auto_invalidate_enabled": request.enabled,
|
||||
"message": f"Auto-invalidation {'enabled' if request.enabled else 'disabled'}",
|
||||
"note": "Monitors max(id_act) in {schema}.act tables for changes"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error toggling auto-invalidation: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def cache_health():
|
||||
"""
|
||||
Health check pentru sistemul de cache
|
||||
|
||||
Returns:
|
||||
Status cache, mărime, și uptime
|
||||
"""
|
||||
try:
|
||||
cache = get_cache()
|
||||
if not cache:
|
||||
return {
|
||||
"status": "not_initialized",
|
||||
"enabled": False
|
||||
}
|
||||
|
||||
stats = await cache.get_stats()
|
||||
monitor = get_event_monitor()
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
"enabled": cache.config.enabled,
|
||||
"cache_type": cache.config.cache_type,
|
||||
"memory_size": stats.get('memory', {}).get('size', 0),
|
||||
"sqlite_size": stats.get('sqlite', {}).get('active_entries', 0),
|
||||
"auto_invalidate_running": monitor.running if monitor else False
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e)
|
||||
}
|
||||
430
backend/modules/reports/routers/dashboard.py
Normal file
430
backend/modules/reports/routers/dashboard.py
Normal file
@@ -0,0 +1,430 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from typing import Optional
|
||||
# import sys # Removed - no longer needed
|
||||
import os
|
||||
|
||||
from shared.auth.dependencies import get_current_user
|
||||
from shared.auth.models import CurrentUser
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from ..models.dashboard import DashboardSummary, TrendsResponse, TrendData
|
||||
from ..services.dashboard_service import DashboardService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/summary")
|
||||
async def get_dashboard_summary(
|
||||
request: Request,
|
||||
company: str = Query(description="Codul firmei"),
|
||||
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
|
||||
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Obține toate datele pentru dashboard într-un singur apel
|
||||
|
||||
- Necesită autentificare JWT
|
||||
- Returnează statistici clienți/furnizori și trezorerie
|
||||
- Include metadata cache pentru Telegram Bot (X-Include-Cache-Metadata header)
|
||||
- Suportă filtrare pe luna/an contabil (dacă nu sunt specificate, folosește ultima perioadă)
|
||||
"""
|
||||
try:
|
||||
# Verifică dacă utilizatorul are acces la firma specificată
|
||||
if company not in current_user.companies:
|
||||
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
|
||||
|
||||
result = await DashboardService.get_complete_summary(company, current_user.username, luna=luna, an=an, request=request)
|
||||
|
||||
# Convert Pydantic model to dict for JSON serialization
|
||||
result_dict = result.dict() if hasattr(result, 'dict') else result
|
||||
|
||||
# Add cache metadata if requested (for Telegram Bot)
|
||||
include_metadata = request.headers.get('X-Include-Cache-Metadata', '').lower() == 'true'
|
||||
if include_metadata:
|
||||
cache_hit = getattr(request.state, 'cache_hit', False)
|
||||
response_time = getattr(request.state, 'response_time_ms', 0)
|
||||
cache_source = getattr(request.state, 'cache_source', None)
|
||||
result_dict['cache_hit'] = cache_hit
|
||||
result_dict['response_time_ms'] = response_time
|
||||
# Always include cache_source, even if None
|
||||
result_dict['cache_source'] = cache_source
|
||||
|
||||
return result_dict
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Eroare la obținerea datelor dashboard: {str(e)}")
|
||||
|
||||
@router.get("/trends", response_model=TrendsResponse)
|
||||
async def get_dashboard_trends(
|
||||
request: Request,
|
||||
company: str = Query(description="Codul firmei"),
|
||||
period: str = Query(default="30d", description="Perioada pentru trends: 7d, 30d, ytd, 12m"),
|
||||
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
|
||||
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
|
||||
compare_previous: bool = Query(default=True, description="Compară cu perioada anterioară"),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Obține trenduri pentru indicatorii principali (clienți/furnizori)
|
||||
|
||||
- period: "7d" (7 zile), "30d" (30 zile), "ytd" (year to date), "12m" (12 luni)
|
||||
- luna/an: perioada contabilă de referință (dacă nu sunt specificate, folosește ultima perioadă)
|
||||
- compare_previous: dacă să compare cu perioada anterioară
|
||||
- Necesită autentificare JWT
|
||||
- Returnează date pentru grafice de trenduri
|
||||
"""
|
||||
try:
|
||||
# Verifică dacă utilizatorul are acces la firma specificată
|
||||
if company not in current_user.companies:
|
||||
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
|
||||
|
||||
# Validează perioada
|
||||
valid_periods = ["7d", "30d", "ytd", "12m"]
|
||||
if period not in valid_periods:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Perioadă nevalidă: {period}. Valori permise: {', '.join(valid_periods)}"
|
||||
)
|
||||
|
||||
# Obține datele de trenduri
|
||||
result = await DashboardService.get_trends(int(company), period, luna=luna, an=an, request=request)
|
||||
|
||||
# Convert to dict if needed
|
||||
result_dict = result.dict() if hasattr(result, 'dict') else result
|
||||
|
||||
# Add cache metadata if requested (for Telegram Bot)
|
||||
include_metadata = request.headers.get('X-Include-Cache-Metadata', '').lower() == 'true'
|
||||
if include_metadata:
|
||||
cache_hit = getattr(request.state, 'cache_hit', False)
|
||||
response_time = getattr(request.state, 'response_time_ms', 0)
|
||||
cache_source = getattr(request.state, 'cache_source', None)
|
||||
result_dict['cache_hit'] = cache_hit
|
||||
result_dict['response_time_ms'] = response_time
|
||||
# Always include cache_source, even if None
|
||||
result_dict['cache_source'] = cache_source
|
||||
|
||||
# Return as TrendsResponse
|
||||
return TrendsResponse(**result_dict)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Value error in trends endpoint: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Eroare la obținerea trendurilor: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Eroare la obținerea trendurilor: {str(e)}")
|
||||
|
||||
@router.get("/detailed-data")
|
||||
async def get_detailed_data(
|
||||
company: str = Query(description="Codul firmei"),
|
||||
data_type: str = Query(description="Tipul de date: clients, suppliers, treasury"),
|
||||
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
|
||||
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
|
||||
page: int = Query(default=1, ge=1),
|
||||
page_size: int = Query(default=25, ge=1, le=100),
|
||||
search: str = Query(default="", description="Termen de căutare"),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Obține date detaliate pentru tabelele din dashboard
|
||||
"""
|
||||
logger.info(f"[ROUTER] detailed-data called: company={company}, data_type={data_type}")
|
||||
try:
|
||||
if company not in current_user.companies:
|
||||
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
|
||||
|
||||
logger.info(f"[ROUTER] Calling DashboardService.get_detailed_data")
|
||||
result = await DashboardService.get_detailed_data(
|
||||
company=company,
|
||||
data_type=data_type,
|
||||
luna=luna,
|
||||
an=an,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
search=search
|
||||
)
|
||||
|
||||
logger.info(f"[ROUTER] Service returned: {len(result.get('data', []))} rows")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Eroare la obținerea datelor detaliate: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.get("/performance")
|
||||
async def get_performance(
|
||||
company: int = Query(..., description="ID-ul firmei"),
|
||||
period: str = Query("7d", regex="^(7d|1m|3m|6m|ytd|12m)$", description="Perioada pentru analiză"),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Returnează date performanță pentru perioada selectată
|
||||
|
||||
- Necesită autentificare JWT
|
||||
- Returnează grafice încasări vs plăți pentru perioada selectată
|
||||
- Calculează indicatori: rata încasării, cash conversion, working capital
|
||||
"""
|
||||
try:
|
||||
# Verifică dacă utilizatorul are acces la firma specificată
|
||||
if str(company) not in current_user.companies:
|
||||
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
|
||||
|
||||
result = await DashboardService.get_performance_data(company, period)
|
||||
|
||||
# Convert to Chart.js compatible format
|
||||
return {
|
||||
"labels": result.get("labels", []),
|
||||
"datasets": [{
|
||||
"data": result.get("data", []),
|
||||
"label": result.get("label", "Performance"),
|
||||
"borderColor": result.get("borderColor", "#3B82F6"),
|
||||
"backgroundColor": result.get("backgroundColor", "rgba(59, 130, 246, 0.1)"),
|
||||
"tension": 0.4
|
||||
}]
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Eroare la obținerea datelor de performanță: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Eroare la obținerea datelor de performanță: {str(e)}")
|
||||
|
||||
@router.get("/cashflow")
|
||||
async def get_cashflow(
|
||||
company: int = Query(..., description="ID-ul firmei"),
|
||||
period: str = Query("7d", regex="^(7d|1m|3m|6m)$", description="Perioada pentru previziune"),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Returnează previziune cash flow pentru perioada selectată
|
||||
|
||||
- Necesită autentificare JWT
|
||||
- Analizează scadențele viitoare pentru calculul cash flow-ului
|
||||
- Identifică zilele critice cu deficit de cash
|
||||
"""
|
||||
try:
|
||||
# Verifică dacă utilizatorul are acces la firma specificată
|
||||
if str(company) not in current_user.companies:
|
||||
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
|
||||
|
||||
result = await DashboardService.get_cashflow_forecast(company, period)
|
||||
|
||||
# Convert to Chart.js compatible format
|
||||
return {
|
||||
"labels": result.get("labels", []),
|
||||
"datasets": [{
|
||||
"data": result.get("data", []),
|
||||
"label": result.get("label", "Cash Flow"),
|
||||
"borderColor": result.get("borderColor", "#10B981"),
|
||||
"backgroundColor": result.get("backgroundColor", "rgba(16, 185, 129, 0.1)"),
|
||||
"tension": 0.4
|
||||
}]
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Eroare la obținerea previziunii cash flow: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Eroare la obținerea previziunii cash flow: {str(e)}")
|
||||
|
||||
@router.get("/maturity")
|
||||
async def get_maturity_analysis(
|
||||
request: Request,
|
||||
company: int = Query(..., description="ID-ul firmei"),
|
||||
period: str = Query("7d", regex="^(7d|1m|3m|6m|12m|all)$", description="Orizont de planificare pentru analiza scadențelor"),
|
||||
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
|
||||
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Returnează analiza scadențelor pentru orizontul de planificare selectat
|
||||
|
||||
- Necesită autentificare JWT
|
||||
- Logică: Include TOATE restanțele + scadențele viitoare din perioada selectată
|
||||
- luna/an: perioada contabilă de referință (dacă nu sunt specificate, folosește ultima perioadă)
|
||||
- Perioade disponibile:
|
||||
* 7d: Toate restanțele + scadențe următoarelor 7 zile
|
||||
* 1m: Toate restanțele + scadențe următoarelor 30 zile
|
||||
* 3m: Toate restanțele + scadențe următoarelor 90 zile
|
||||
* 6m: Toate restanțele + scadențe următoarelor 180 zile
|
||||
* 12m: Toate restanțele + scadențe următoarelor 365 zile
|
||||
* all: Toate soldurile (fără filtru)
|
||||
- Compară scadențele clienți vs furnizori
|
||||
- Calculează balanța și oferă recomandări
|
||||
- Returnează metadate cu statistici complete
|
||||
- Include metadata cache pentru Telegram Bot (X-Include-Cache-Metadata header)
|
||||
"""
|
||||
try:
|
||||
# Verifică dacă utilizatorul are acces la firma specificată
|
||||
if str(company) not in current_user.companies:
|
||||
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
|
||||
|
||||
result = await DashboardService.get_maturity_analysis(company, period, luna=luna, an=an, request=request)
|
||||
|
||||
# Convert to dict if needed
|
||||
result_dict = result.dict() if hasattr(result, 'dict') else result
|
||||
|
||||
# Add cache metadata if requested (for Telegram Bot)
|
||||
include_metadata = request.headers.get('X-Include-Cache-Metadata', '').lower() == 'true'
|
||||
if include_metadata:
|
||||
cache_hit = getattr(request.state, 'cache_hit', False)
|
||||
response_time = getattr(request.state, 'response_time_ms', 0)
|
||||
cache_source = getattr(request.state, 'cache_source', None)
|
||||
result_dict['cache_hit'] = cache_hit
|
||||
result_dict['response_time_ms'] = response_time
|
||||
# Always include cache_source, even if None
|
||||
result_dict['cache_source'] = cache_source
|
||||
|
||||
return result_dict
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Eroare la obținerea analizei scadențelor: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Eroare la obținerea analizei scadențelor: {str(e)}")
|
||||
|
||||
@router.get("/monthly-flows")
|
||||
async def get_monthly_flows(
|
||||
company: int = Query(..., description="ID-ul firmei"),
|
||||
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
|
||||
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Returnează fluxurile lunare pentru firma selectată
|
||||
|
||||
- Necesită autentificare JWT
|
||||
- Returnează date pentru analiza fluxurilor lunare
|
||||
- luna/an: perioada contabilă de referință (dacă nu sunt specificate, folosește ultima perioadă)
|
||||
"""
|
||||
try:
|
||||
# Verifică dacă utilizatorul are acces la firma specificată
|
||||
if str(company) not in current_user.companies:
|
||||
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
|
||||
|
||||
result = await DashboardService.get_monthly_flows(company, luna=luna, an=an)
|
||||
return result
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Eroare la obținerea fluxurilor lunare: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Eroare la obținerea fluxurilor lunare: {str(e)}")
|
||||
|
||||
@router.get("/treasury-breakdown")
|
||||
async def get_treasury_breakdown(
|
||||
request: Request,
|
||||
company: int = Query(..., description="ID-ul firmei"),
|
||||
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
|
||||
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Returnează defalcarea trezoreriei pentru firma selectată
|
||||
|
||||
- Necesită autentificare JWT
|
||||
- Returnează distribuția soldurilor pe conturi și tipuri
|
||||
- luna/an: perioada contabilă de referință (dacă nu sunt specificate, folosește ultima perioadă)
|
||||
- Include metadata cache pentru Telegram Bot (X-Include-Cache-Metadata header)
|
||||
"""
|
||||
try:
|
||||
# Verifică dacă utilizatorul are acces la firma specificată
|
||||
if str(company) not in current_user.companies:
|
||||
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
|
||||
|
||||
result = await DashboardService.get_treasury_breakdown(company, luna=luna, an=an, request=request)
|
||||
|
||||
# Convert to dict if needed
|
||||
result_dict = result.dict() if hasattr(result, 'dict') else result
|
||||
|
||||
# Add cache metadata if requested (for Telegram Bot)
|
||||
include_metadata = request.headers.get('X-Include-Cache-Metadata', '').lower() == 'true'
|
||||
if include_metadata:
|
||||
cache_hit = getattr(request.state, 'cache_hit', False)
|
||||
response_time = getattr(request.state, 'response_time_ms', 0)
|
||||
cache_source = getattr(request.state, 'cache_source', None)
|
||||
result_dict['cache_hit'] = cache_hit
|
||||
result_dict['response_time_ms'] = response_time
|
||||
# Always include cache_source, even if None
|
||||
result_dict['cache_source'] = cache_source
|
||||
|
||||
return result_dict
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Eroare la obținerea defalcării trezoreriei: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Eroare la obținerea defalcării trezoreriei: {str(e)}")
|
||||
|
||||
@router.get("/net-balance-breakdown")
|
||||
async def get_net_balance_breakdown(
|
||||
request: Request,
|
||||
company: int = Query(..., description="ID-ul firmei"),
|
||||
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
|
||||
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Returnează defalcarea balanței nete pentru firma selectată
|
||||
|
||||
- Necesită autentificare JWT
|
||||
- Returnează analiza detaliată a balanței nete
|
||||
- luna/an: perioada contabilă de referință (dacă nu sunt specificate, folosește ultima perioadă)
|
||||
- Include metadata cache pentru Telegram Bot (X-Include-Cache-Metadata header)
|
||||
"""
|
||||
try:
|
||||
# Verifică dacă utilizatorul are acces la firma specificată
|
||||
if str(company) not in current_user.companies:
|
||||
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
|
||||
|
||||
result = await DashboardService.get_net_balance_breakdown(company, luna=luna, an=an, request=request)
|
||||
|
||||
# Convert to dict if needed
|
||||
result_dict = result.dict() if hasattr(result, 'dict') else result
|
||||
|
||||
# Add cache metadata if requested (for Telegram Bot)
|
||||
include_metadata = request.headers.get('X-Include-Cache-Metadata', '').lower() == 'true'
|
||||
if include_metadata:
|
||||
cache_hit = getattr(request.state, 'cache_hit', False)
|
||||
response_time = getattr(request.state, 'response_time_ms', 0)
|
||||
cache_source = getattr(request.state, 'cache_source', None)
|
||||
result_dict['cache_hit'] = cache_hit
|
||||
result_dict['response_time_ms'] = response_time
|
||||
# Always include cache_source, even if None
|
||||
result_dict['cache_source'] = cache_source
|
||||
|
||||
return result_dict
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Eroare la obținerea defalcării balanței nete: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Eroare la obținerea defalcării balanței nete: {str(e)}")
|
||||
|
||||
@router.get("/current-period")
|
||||
async def get_current_period(
|
||||
company: int = Query(..., description="ID-ul firmei"),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Returnează perioada curentă (an și lună) din calendarul Oracle
|
||||
|
||||
- Necesită autentificare JWT
|
||||
- Returnează anul, luna și perioada curentă în format YYYY-MM
|
||||
- Folosit pentru afișarea lunii curente în dashboard
|
||||
"""
|
||||
try:
|
||||
# Verifică dacă utilizatorul are acces la firma specificată
|
||||
if str(company) not in current_user.companies:
|
||||
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
|
||||
|
||||
result = await DashboardService.get_current_period(company)
|
||||
return result
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Eroare la obținerea perioadei curente: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Eroare la obținerea perioadei curente: {str(e)}")
|
||||
128
backend/modules/reports/routers/invoices.py
Normal file
128
backend/modules/reports/routers/invoices.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
API Router pentru facturi
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from typing import List, Optional
|
||||
from datetime import date
|
||||
# import sys # Removed - no longer needed
|
||||
import os
|
||||
|
||||
from shared.auth.dependencies import get_current_user, require_company_access
|
||||
from shared.auth.models import CurrentUser
|
||||
from ..models.invoice import InvoiceFilter, InvoiceListResponse, InvoiceSummary
|
||||
from ..services.invoice_service import InvoiceService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/", response_model=InvoiceListResponse)
|
||||
async def get_invoices(
|
||||
company: str = Query(description="Codul firmei"),
|
||||
partner_type: str = Query("CLIENTI", description="CLIENTI sau FURNIZORI"),
|
||||
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
|
||||
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
|
||||
partner_name: Optional[str] = Query(None, description="Filtru nume partener"),
|
||||
cont: Optional[str] = Query(None, description="Filtru după cont contabil"),
|
||||
only_unpaid: bool = Query(True, description="Doar facturile neachitate"),
|
||||
min_amount: Optional[float] = Query(None, description="Suma minimă"),
|
||||
max_amount: Optional[float] = Query(None, description="Suma maximă"),
|
||||
page: int = Query(1, ge=1, description="Pagina"),
|
||||
page_size: int = Query(50, ge=1, le=10000000, description="Mărimea paginii"),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Obține lista de facturi pentru o firmă
|
||||
|
||||
- Necesită autentificare JWT
|
||||
- Utilizatorul trebuie să aibă acces la firma specificată
|
||||
- Suportă filtrare după luna/an contabil și paginare
|
||||
"""
|
||||
try:
|
||||
# Verifică dacă utilizatorul are acces la firma specificată
|
||||
if company not in current_user.companies:
|
||||
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
|
||||
|
||||
filter_params = InvoiceFilter(
|
||||
company=company,
|
||||
partner_type=partner_type,
|
||||
luna=luna,
|
||||
an=an,
|
||||
partner_name=partner_name,
|
||||
cont=cont,
|
||||
only_unpaid=only_unpaid,
|
||||
min_amount=min_amount,
|
||||
max_amount=max_amount,
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
|
||||
result = await InvoiceService.get_invoices(filter_params, current_user.username)
|
||||
return result
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Eroare la obținerea facturilor: {str(e)}")
|
||||
|
||||
@router.get("/summary", response_model=InvoiceSummary)
|
||||
async def get_invoices_summary(
|
||||
company: str = Query(description="Codul firmei"),
|
||||
partner_type: str = Query("CLIENTI", description="CLIENTI sau FURNIZORI"),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""Obține rezumatul facturilor pentru dashboard"""
|
||||
try:
|
||||
# Verifică dacă utilizatorul are acces la firma specificată
|
||||
if company not in current_user.companies:
|
||||
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
|
||||
|
||||
result = await InvoiceService.get_invoice_summary(company, partner_type, current_user.username)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Eroare la obținerea rezumatului facturilor: {str(e)}")
|
||||
|
||||
@router.get("/{invoice_number}")
|
||||
async def get_invoice_details(
|
||||
invoice_number: str,
|
||||
company: str = Query(description="Codul firmei"),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""Obține detaliile unei facturi specifice"""
|
||||
try:
|
||||
# Verifică dacă utilizatorul are acces la firma specificată
|
||||
if company not in current_user.companies:
|
||||
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
|
||||
|
||||
result = await InvoiceService.get_invoice_details(company, invoice_number, current_user.username)
|
||||
return result
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Eroare la obținerea detaliilor facturii: {str(e)}")
|
||||
|
||||
@router.get("/export/{format}")
|
||||
async def export_invoices(
|
||||
format: str,
|
||||
company: str = Query(description="Codul firmei"),
|
||||
partner_type: str = Query("CLIENTI", description="CLIENTI sau FURNIZORI"),
|
||||
date_from: Optional[str] = Query(None, description="Data început (YYYY-MM-DD)"),
|
||||
date_to: Optional[str] = Query(None, description="Data sfârșit (YYYY-MM-DD)"),
|
||||
partner_name: Optional[str] = Query(None, description="Filtru nume partener"),
|
||||
only_unpaid: bool = Query(True, description="Doar facturile neachitate"),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Export facturi în format specificat (excel, pdf, csv)
|
||||
Această funcție va fi implementată în viitor
|
||||
"""
|
||||
# Verifică dacă utilizatorul are acces la firma specificată
|
||||
if company not in current_user.companies:
|
||||
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
|
||||
|
||||
# Verifică formatul
|
||||
if format not in ["excel", "pdf", "csv"]:
|
||||
raise HTTPException(status_code=400, detail="Format invalid. Formatele suportate sunt: excel, pdf, csv")
|
||||
|
||||
# Pentru moment, returnează o eroare că funcția nu este implementată
|
||||
raise HTTPException(status_code=501, detail=f"Export în format {format} nu este încă implementat")
|
||||
117
backend/modules/reports/routers/treasury.py
Normal file
117
backend/modules/reports/routers/treasury.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from typing import Optional, List
|
||||
from datetime import date
|
||||
# import sys # Removed - no longer needed
|
||||
import os
|
||||
|
||||
from shared.auth.dependencies import get_current_user
|
||||
from shared.auth.models import CurrentUser
|
||||
from ..models.treasury import RegisterFilter, RegisterListResponse
|
||||
from ..services.treasury_service import TreasuryService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/bank-cash-register", response_model=RegisterListResponse)
|
||||
async def get_bank_cash_register(
|
||||
company: str = Query(description="Codul firmei"),
|
||||
register_type: Optional[str] = Query(None, description="Tipul registrului: BANCA_LEI, BANCA_VALUTA, CASA_LEI, CASA_VALUTA sau None pentru toate"),
|
||||
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
|
||||
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
|
||||
date_from: Optional[str] = Query(None, description="Data început (YYYY-MM-DD)"),
|
||||
date_to: Optional[str] = Query(None, description="Data sfârșit (YYYY-MM-DD)"),
|
||||
partner_name: Optional[str] = Query(None, description="Filtru nume partener"),
|
||||
bank_account: Optional[str] = Query(None, description="Filtru cont bancă/casă (bancasa)"),
|
||||
page: int = Query(1, ge=1, description="Pagina"),
|
||||
page_size: int = Query(50, ge=1, le=10000000, description="Mărimea paginii"),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Obține registrul de casă și bancă
|
||||
|
||||
- Necesită autentificare JWT
|
||||
- Suportă filtrare pe tip registru: BANCA_LEI, BANCA_VALUTA, CASA_LEI, CASA_VALUTA
|
||||
- Suportă filtrare și paginare
|
||||
"""
|
||||
try:
|
||||
# Verifică dacă utilizatorul are acces la firma specificată
|
||||
if company not in current_user.companies:
|
||||
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
|
||||
|
||||
# Validează register_type dacă e specificat
|
||||
valid_types = ['BANCA_LEI', 'BANCA_VALUTA', 'CASA_LEI', 'CASA_VALUTA']
|
||||
if register_type and register_type not in valid_types:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Tip registru invalid. Valori acceptate: {', '.join(valid_types)}"
|
||||
)
|
||||
|
||||
# Convertește datele
|
||||
date_from_obj = None
|
||||
date_to_obj = None
|
||||
|
||||
if date_from:
|
||||
try:
|
||||
date_from_obj = date.fromisoformat(date_from)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Format dată început invalid")
|
||||
|
||||
if date_to:
|
||||
try:
|
||||
date_to_obj = date.fromisoformat(date_to)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Format dată sfârșit invalid")
|
||||
|
||||
filter_params = RegisterFilter(
|
||||
company=company,
|
||||
register_type=register_type,
|
||||
luna=luna,
|
||||
an=an,
|
||||
date_from=date_from_obj,
|
||||
date_to=date_to_obj,
|
||||
partner_name=partner_name,
|
||||
bank_account=bank_account,
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
|
||||
result = await TreasuryService.get_bank_cash_register(filter_params, current_user.username)
|
||||
return result
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Eroare la obținerea registrului: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/bank-cash-accounts", response_model=List[str])
|
||||
async def get_bank_cash_accounts(
|
||||
company: str = Query(description="Codul firmei"),
|
||||
register_type: str = Query(description="Tipul registrului: BANCA_LEI, BANCA_VALUTA, CASA_LEI, CASA_VALUTA"),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Obține lista distinctă de conturi bancă/casă pentru dropdown
|
||||
|
||||
- Necesită autentificare JWT
|
||||
- Returnează lista de valori bancasa pentru tipul de registru selectat
|
||||
"""
|
||||
try:
|
||||
# Verifică dacă utilizatorul are acces la firma specificată
|
||||
if company not in current_user.companies:
|
||||
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
|
||||
|
||||
# Validează register_type
|
||||
valid_types = ['BANCA_LEI', 'BANCA_VALUTA', 'CASA_LEI', 'CASA_VALUTA']
|
||||
if register_type not in valid_types:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Tip registru invalid. Valori acceptate: {', '.join(valid_types)}"
|
||||
)
|
||||
|
||||
result = await TreasuryService.get_bank_cash_accounts(int(company), register_type)
|
||||
return result
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Eroare la obținerea conturilor: {str(e)}")
|
||||
90
backend/modules/reports/routers/trial_balance.py
Normal file
90
backend/modules/reports/routers/trial_balance.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
API Router for Trial Balance (Balanță de Verificare)
|
||||
Refactored to use service layer with caching
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from typing import Optional
|
||||
from datetime import date
|
||||
# import sys # Removed - no longer needed
|
||||
import os
|
||||
|
||||
from shared.auth.dependencies import get_current_user
|
||||
from shared.auth.models import CurrentUser
|
||||
from ..models.trial_balance import TrialBalanceResponse
|
||||
from ..services.trial_balance_service import TrialBalanceService
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/", response_model=TrialBalanceResponse)
|
||||
async def get_trial_balance(
|
||||
company: str = Query(description="Codul firmei (ID)"),
|
||||
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna (1-12), default: luna curentă"),
|
||||
an: Optional[int] = Query(None, ge=2000, le=2100, description="An, default: anul curent"),
|
||||
cont_filter: Optional[str] = Query(None, description="Filtru număr cont (ex: '512', '4111')"),
|
||||
denumire_filter: Optional[str] = Query(None, description="Filtru denumire cont (partial match, case-insensitive)"),
|
||||
sort_by: str = Query("CONT", description="Coloană pentru sortare"),
|
||||
sort_order: str = Query("asc", description="Ordinea sortării (asc | desc)"),
|
||||
page: int = Query(1, ge=1, description="Pagina"),
|
||||
page_size: int = Query(50, ge=1, le=1000000, description="Mărimea paginii"),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Obține balanța de verificare sintetică pentru o firmă
|
||||
|
||||
- Necesită autentificare JWT
|
||||
- Utilizatorul trebuie să aibă acces la firma specificată
|
||||
- Suportă filtrare după cont și denumire
|
||||
- Suportă paginare și sortare
|
||||
- **CACHED 10 min** - folosește sistem cache two-tier (L1 Memory + L2 SQLite)
|
||||
"""
|
||||
try:
|
||||
# Verifică dacă utilizatorul are acces la firma specificată
|
||||
if company not in current_user.companies:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Nu aveți acces la firma {company}"
|
||||
)
|
||||
|
||||
# Setează valorile implicite pentru lună și an (luna și anul curent)
|
||||
current_date = date.today()
|
||||
if luna is None:
|
||||
luna = current_date.month
|
||||
if an is None:
|
||||
an = current_date.year
|
||||
|
||||
# Convert company to int
|
||||
company_id = int(company)
|
||||
|
||||
# Call service (with caching) - all business logic moved to service
|
||||
data = await TrialBalanceService.get_trial_balance(
|
||||
company_id=company_id,
|
||||
luna=luna,
|
||||
an=an,
|
||||
cont_filter=cont_filter,
|
||||
denumire_filter=denumire_filter,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
username=current_user.username
|
||||
)
|
||||
|
||||
return TrialBalanceResponse(
|
||||
success=True,
|
||||
data=data
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
# Schema not found or validation error
|
||||
logger.error(f"Validation error in trial balance: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
# Log unexpected errors
|
||||
logger.error(f"Error fetching trial balance: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Eroare la obținerea balanței de verificare: {str(e)}"
|
||||
)
|
||||
0
backend/modules/reports/schemas/__init__.py
Normal file
0
backend/modules/reports/schemas/__init__.py
Normal file
0
backend/modules/reports/services/__init__.py
Normal file
0
backend/modules/reports/services/__init__.py
Normal file
77
backend/modules/reports/services/calendar_service.py
Normal file
77
backend/modules/reports/services/calendar_service.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
Calendar service for fetching available accounting periods
|
||||
"""
|
||||
# import sys # Removed - no longer needed
|
||||
import os
|
||||
|
||||
from shared.database.oracle_pool import oracle_pool
|
||||
from ..models.calendar import CalendarPeriod, CalendarPeriodsResponse
|
||||
from ..cache.decorators import cached
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CalendarService:
|
||||
"""Service for calendar/accounting period operations"""
|
||||
|
||||
# Romanian month names for display
|
||||
MONTH_NAMES_RO = [
|
||||
"Ianuarie", "Februarie", "Martie", "Aprilie", "Mai", "Iunie",
|
||||
"Iulie", "August", "Septembrie", "Octombrie", "Noiembrie", "Decembrie"
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
@cached(cache_type='schema', key_params=['company_id'])
|
||||
async def _get_schema(company_id: int) -> str:
|
||||
"""Get schema for company (CACHED 24h)"""
|
||||
async with oracle_pool.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute("""
|
||||
SELECT schema FROM CONTAFIN_ORACLE.v_nom_firme
|
||||
WHERE id_firma = :company_id
|
||||
""", {'company_id': company_id})
|
||||
result = cursor.fetchone()
|
||||
return result[0] if result else None
|
||||
|
||||
@staticmethod
|
||||
@cached(cache_type='calendar_periods', key_params=['company_id'])
|
||||
async def get_available_periods(company_id: int) -> CalendarPeriodsResponse:
|
||||
"""
|
||||
Get all available accounting periods for a company (CACHED 1h)
|
||||
|
||||
Returns periods ordered by year DESC, month DESC with Romanian month names.
|
||||
"""
|
||||
schema = await CalendarService._get_schema(company_id)
|
||||
if not schema:
|
||||
logger.warning(f"Schema not found for company {company_id}")
|
||||
return CalendarPeriodsResponse(periods=[], current_period=None, total_count=0)
|
||||
|
||||
async with oracle_pool.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(f"""
|
||||
SELECT anul, luna
|
||||
FROM {schema}.calendar
|
||||
ORDER BY anul DESC, luna DESC
|
||||
""")
|
||||
rows = cursor.fetchall()
|
||||
|
||||
periods = []
|
||||
for row in rows:
|
||||
an, luna = row[0], row[1]
|
||||
month_name = CalendarService.MONTH_NAMES_RO[luna - 1]
|
||||
periods.append(CalendarPeriod(
|
||||
an=an,
|
||||
luna=luna,
|
||||
display_name=f"{month_name} {an}"
|
||||
))
|
||||
|
||||
current_period = periods[0] if periods else None
|
||||
|
||||
logger.info(f"Loaded {len(periods)} accounting periods for company {company_id}")
|
||||
|
||||
return CalendarPeriodsResponse(
|
||||
periods=periods,
|
||||
current_period=current_period,
|
||||
total_count=len(periods)
|
||||
)
|
||||
1995
backend/modules/reports/services/dashboard_service.py
Normal file
1995
backend/modules/reports/services/dashboard_service.py
Normal file
File diff suppressed because it is too large
Load Diff
324
backend/modules/reports/services/invoice_service.py
Normal file
324
backend/modules/reports/services/invoice_service.py
Normal file
@@ -0,0 +1,324 @@
|
||||
"""
|
||||
Service pentru logica facturi - Portează query-urile din aplicația Flask
|
||||
"""
|
||||
# import sys # Removed - no longer needed
|
||||
import os
|
||||
|
||||
from shared.database.oracle_pool import oracle_pool
|
||||
from typing import List, Tuple
|
||||
from ..models.invoice import Invoice, InvoiceFilter, InvoiceListResponse, InvoiceSummary
|
||||
from ..cache.decorators import cached
|
||||
from decimal import Decimal
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class InvoiceService:
|
||||
"""Service pentru gestionarea facturilor"""
|
||||
|
||||
@staticmethod
|
||||
@cached(cache_type='schema', key_params=['company_id'])
|
||||
async def _get_schema(company_id: int) -> str:
|
||||
"""Obține schema pentru company_id (CACHED PERMANENT)"""
|
||||
async with oracle_pool.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
schema_query = """
|
||||
SELECT schema
|
||||
FROM CONTAFIN_ORACLE.v_nom_firme
|
||||
WHERE id_firma = :company_id
|
||||
"""
|
||||
cursor.execute(schema_query, {'company_id': company_id})
|
||||
schema_result = cursor.fetchone()
|
||||
|
||||
if not schema_result:
|
||||
raise ValueError(f"Schema not found for company {company_id}")
|
||||
|
||||
return schema_result[0]
|
||||
|
||||
@staticmethod
|
||||
@cached(cache_type='invoices', key_params=['filter_params', 'username'])
|
||||
async def get_invoices(filter_params: InvoiceFilter, username: str) -> InvoiceListResponse:
|
||||
"""
|
||||
Obține lista de facturi - Query simplu pentru afișare în tabel (CACHED 10 min)
|
||||
"""
|
||||
company_id = int(filter_params.company)
|
||||
schema = await InvoiceService._get_schema(company_id)
|
||||
|
||||
async with oracle_pool.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
|
||||
# Determină conturile în funcție de partner_type
|
||||
if filter_params.partner_type == "CLIENTI":
|
||||
conturi = "'4111', '461'"
|
||||
elif filter_params.partner_type == "FURNIZORI":
|
||||
conturi = "'401', '404', '462'"
|
||||
else:
|
||||
conturi = "'4111'" # default
|
||||
|
||||
# Determine period to use: from params or MAX from calendar
|
||||
if filter_params.luna and filter_params.an:
|
||||
period_condition = "vp.an = :an AND vp.luna = :luna"
|
||||
use_param_period = True
|
||||
else:
|
||||
period_condition = f"""vp.an = (SELECT anul FROM {schema}.calendar WHERE anul*12+luna = (SELECT MAX(anul*12+luna) FROM {schema}.calendar))
|
||||
AND vp.luna = (SELECT luna FROM {schema}.calendar WHERE anul*12+luna = (SELECT MAX(anul*12+luna) FROM {schema}.calendar))"""
|
||||
use_param_period = False
|
||||
|
||||
# Query cu calculele corecte pentru solduri
|
||||
base_query = f"""
|
||||
SELECT
|
||||
vp.NUME,
|
||||
vp.NRACT,
|
||||
vp.DATAACT,
|
||||
vp.DATASCAD,
|
||||
vp.CONTRACT,
|
||||
vp.COD_FISCAL,
|
||||
vp.REG_COMERT,
|
||||
CASE
|
||||
WHEN vp.CONT IN ('4111','461') THEN vp.PRECDEB + vp.DEBIT -- Total facturat clienți
|
||||
WHEN vp.CONT IN ('401','404','462') THEN vp.PRECCRED + vp.CREDIT -- Total facturat furnizori
|
||||
END as total_facturat,
|
||||
CASE
|
||||
WHEN vp.CONT IN ('4111','461') THEN vp.PRECCRED + vp.CREDIT -- Încasat clienți
|
||||
WHEN vp.CONT IN ('401','404','462') THEN vp.PRECDEB + vp.DEBIT -- Achitat furnizori
|
||||
END as achitat,
|
||||
CASE
|
||||
WHEN vp.CONT IN ('4111','461') THEN
|
||||
(vp.PRECDEB + vp.DEBIT) - (vp.PRECCRED + vp.CREDIT) -- Sold clienți
|
||||
WHEN vp.CONT IN ('401','404','462') THEN
|
||||
(vp.PRECCRED + vp.CREDIT) - (vp.PRECDEB + vp.DEBIT) -- Sold furnizori
|
||||
END as sold,
|
||||
vp.CONT,
|
||||
NVL(vp.NUME_VAL, 'RON') as valuta,
|
||||
CASE
|
||||
WHEN vp.DATASCAD < SYSDATE THEN 'restant'
|
||||
ELSE 'in_termen'
|
||||
END as status
|
||||
FROM {schema}.vireg_parteneri vp
|
||||
WHERE {period_condition}
|
||||
AND (
|
||||
(:partner_type = 'CLIENTI' AND vp.cont IN ('4111', '461'))
|
||||
OR
|
||||
(:partner_type = 'FURNIZORI' AND vp.cont IN ('401', '404', '462'))
|
||||
)
|
||||
"""
|
||||
|
||||
params = {'partner_type': filter_params.partner_type}
|
||||
|
||||
# Add period params if using explicit period
|
||||
if use_param_period:
|
||||
params['an'] = filter_params.an
|
||||
params['luna'] = filter_params.luna
|
||||
|
||||
if filter_params.partner_name:
|
||||
base_query += " AND UPPER(vp.nume) LIKE UPPER(:partner_name)"
|
||||
params['partner_name'] = f"%{filter_params.partner_name}%"
|
||||
|
||||
if filter_params.cont:
|
||||
base_query += " AND vp.cont = :cont"
|
||||
params['cont'] = filter_params.cont
|
||||
|
||||
if filter_params.min_amount:
|
||||
base_query += " AND total_facturat >= :min_amount"
|
||||
params['min_amount'] = filter_params.min_amount
|
||||
|
||||
if filter_params.max_amount:
|
||||
base_query += " AND total_facturat <= :max_amount"
|
||||
params['max_amount'] = filter_params.max_amount
|
||||
|
||||
if filter_params.only_unpaid:
|
||||
# Nu putem folosi aliasul "sold" în WHERE în Oracle, trebuie să repetăm calculul
|
||||
base_query += """ AND (
|
||||
CASE
|
||||
WHEN vp.CONT IN ('4111','461') THEN
|
||||
(vp.PRECDEB + vp.DEBIT) - (vp.PRECCRED + vp.CREDIT)
|
||||
WHEN vp.CONT IN ('401','404','462') THEN
|
||||
(vp.PRECCRED + vp.CREDIT) - (vp.PRECDEB + vp.DEBIT)
|
||||
END
|
||||
) > 0"""
|
||||
|
||||
# Count total pentru paginare
|
||||
count_query = f"SELECT COUNT(*) FROM ({base_query})"
|
||||
cursor.execute(count_query, params)
|
||||
total_count = cursor.fetchone()[0]
|
||||
|
||||
# Query pentru TOTAL SOLD din TOATE facturile filtrate (nu doar pagina curentă)
|
||||
total_sold_query = f"""
|
||||
SELECT NVL(SUM(sold), 0) as total_sold
|
||||
FROM ({base_query})
|
||||
"""
|
||||
cursor.execute(total_sold_query, params)
|
||||
total_sold_result = cursor.fetchone()
|
||||
total_sold_all = Decimal(str(total_sold_result[0])) if total_sold_result else Decimal('0.00')
|
||||
|
||||
# Get accounting period - use params if provided, else from calendar
|
||||
if use_param_period:
|
||||
accounting_period = {
|
||||
'an': filter_params.an,
|
||||
'luna': filter_params.luna
|
||||
}
|
||||
else:
|
||||
period_query = f"""
|
||||
SELECT anul, luna
|
||||
FROM {schema}.calendar
|
||||
WHERE anul*12+luna = (SELECT MAX(anul*12+luna) FROM {schema}.calendar)
|
||||
"""
|
||||
cursor.execute(period_query)
|
||||
period_result = cursor.fetchone()
|
||||
accounting_period = {
|
||||
'an': period_result[0] if period_result else None,
|
||||
'luna': period_result[1] if period_result else None
|
||||
}
|
||||
|
||||
# Adaugă ORDER BY și paginare - Ordonare cronologică (DATAACT, NRACT, NUME)
|
||||
base_query += " ORDER BY vp.DATAACT ASC, vp.NRACT ASC, vp.NUME"
|
||||
|
||||
# Paginare Oracle
|
||||
offset = (filter_params.page - 1) * filter_params.page_size
|
||||
limit = offset + filter_params.page_size
|
||||
paginated_query = f"""
|
||||
SELECT * FROM (
|
||||
SELECT ROWNUM as rn, t.* FROM ({base_query}) t WHERE ROWNUM <= :limit
|
||||
) WHERE rn > :offset
|
||||
"""
|
||||
params['offset'] = offset
|
||||
params['limit'] = limit
|
||||
|
||||
cursor.execute(paginated_query, params)
|
||||
rows = cursor.fetchall()
|
||||
|
||||
# Procesează rezultatele cu structura nouă
|
||||
invoices = []
|
||||
total_amount = Decimal('0.00')
|
||||
|
||||
for row in rows:
|
||||
# Skip ROWNUM, extrage valorile din query-ul nou
|
||||
nume = row[1]
|
||||
nract = row[2]
|
||||
dataact = row[3]
|
||||
datascad = row[4]
|
||||
contract = row[5]
|
||||
cod_fiscal = row[6]
|
||||
reg_comert = row[7]
|
||||
total_facturat = Decimal(str(row[8] or 0))
|
||||
achitat = Decimal(str(row[9] or 0))
|
||||
sold = Decimal(str(row[10] or 0))
|
||||
cont = row[11]
|
||||
valuta = row[12] or 'RON'
|
||||
status = row[13]
|
||||
|
||||
invoice_data = {
|
||||
'nume': nume or '',
|
||||
'nract': nract or 0,
|
||||
'dataact': dataact,
|
||||
'datascad': datascad,
|
||||
'contract': contract,
|
||||
'cod_fiscal': cod_fiscal,
|
||||
'reg_comert': reg_comert,
|
||||
'cont': cont,
|
||||
'totctva': total_facturat,
|
||||
'achitat': achitat,
|
||||
'soldfinal': sold,
|
||||
'valuta': valuta
|
||||
}
|
||||
|
||||
invoice = Invoice(**invoice_data)
|
||||
invoices.append(invoice)
|
||||
total_amount += total_facturat
|
||||
|
||||
return InvoiceListResponse(
|
||||
invoices=invoices,
|
||||
total_count=total_count,
|
||||
filtered_count=len(invoices),
|
||||
total_amount=total_amount,
|
||||
page=filter_params.page,
|
||||
page_size=filter_params.page_size,
|
||||
has_more=len(invoices) == filter_params.page_size,
|
||||
accounting_period=accounting_period,
|
||||
# Total sold din TOATE facturile filtrate
|
||||
total_sold_all=total_sold_all
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_invoice_details(company: str, invoice_number: str, username: str) -> Invoice:
|
||||
"""
|
||||
Obține detaliile unei facturi specifice
|
||||
"""
|
||||
async with oracle_pool.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
# Obține schema din v_nom_firme bazat pe id_firma
|
||||
company_id = int(company)
|
||||
schema_query = "SELECT schema FROM CONTAFIN_ORACLE.v_nom_firme WHERE id_firma = :company_id"
|
||||
cursor.execute(schema_query, {'company_id': company_id})
|
||||
schema_result = cursor.fetchone()
|
||||
|
||||
if not schema_result:
|
||||
raise ValueError(f"Schema nu a fost găsită pentru id_firma {company_id}")
|
||||
|
||||
schema = schema_result[0]
|
||||
|
||||
# Query simplu pentru detalii factură
|
||||
detail_query = f"""
|
||||
SELECT
|
||||
NUME,
|
||||
NRACT,
|
||||
DATAACT,
|
||||
DATASCAD,
|
||||
CONTRACT,
|
||||
COD_FISCAL,
|
||||
REG_COMERT,
|
||||
PRECDEB,
|
||||
PRECCRED,
|
||||
DEBIT,
|
||||
CREDIT,
|
||||
CONT
|
||||
FROM {schema}.vireg_parteneri
|
||||
WHERE nract = :invoice_number
|
||||
AND an = (select anul from {schema}.calendar where anul*12+luna = (select max(anul*12+luna) as anmax from {schema}.calendar))
|
||||
AND luna = (select luna from {schema}.calendar where anul*12+luna = (select max(anul*12+luna) as anmax from {schema}.calendar))
|
||||
"""
|
||||
|
||||
cursor.execute(detail_query, {'invoice_number': invoice_number})
|
||||
row = cursor.fetchone()
|
||||
|
||||
if not row:
|
||||
raise ValueError(f"Factura {invoice_number} nu a fost găsită")
|
||||
|
||||
# Extrage valorile
|
||||
nume = row[0]
|
||||
nract = row[1]
|
||||
dataact = row[2]
|
||||
datascad = row[3]
|
||||
contract = row[4]
|
||||
cod_fiscal = row[5]
|
||||
reg_comert = row[6]
|
||||
precdeb = Decimal(str(row[7] or 0))
|
||||
preccred = Decimal(str(row[8] or 0))
|
||||
debit = Decimal(str(row[9] or 0))
|
||||
credit = Decimal(str(row[10] or 0))
|
||||
cont = row[11]
|
||||
|
||||
# Calculează valorile în funcție de tipul contului
|
||||
if cont in ('4111', '461'): # CLIENTI
|
||||
totctva = precdeb + debit
|
||||
achitat = preccred + credit
|
||||
soldfinal = precdeb - preccred + debit - credit
|
||||
else: # FURNIZORI
|
||||
totctva = preccred + credit
|
||||
achitat = precdeb + debit
|
||||
soldfinal = preccred - precdeb + credit - debit
|
||||
|
||||
invoice_data = {
|
||||
'nume': nume or '',
|
||||
'nract': nract or 0,
|
||||
'dataact': dataact,
|
||||
'datascad': datascad,
|
||||
'contract': contract,
|
||||
'cod_fiscal': cod_fiscal,
|
||||
'reg_comert': reg_comert,
|
||||
'totctva': totctva,
|
||||
'achitat': achitat,
|
||||
'soldfinal': soldfinal
|
||||
}
|
||||
|
||||
return Invoice(**invoice_data)
|
||||
410
backend/modules/reports/services/treasury_service.py
Normal file
410
backend/modules/reports/services/treasury_service.py
Normal file
@@ -0,0 +1,410 @@
|
||||
# import sys # Removed - no longer needed
|
||||
import os
|
||||
|
||||
import oracledb
|
||||
from shared.database.oracle_pool import oracle_pool
|
||||
from ..models.treasury import BankCashRegister, RegisterFilter, RegisterListResponse, AccountingPeriod
|
||||
from ..cache.decorators import cached
|
||||
from decimal import Decimal
|
||||
from typing import Optional, List, Tuple, Any
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TreasuryService:
|
||||
"""Service pentru trezorerie - registru casă și bancă"""
|
||||
|
||||
@staticmethod
|
||||
@cached(cache_type='schema', key_params=['company_id'])
|
||||
async def _get_schema(company_id: int) -> str:
|
||||
"""Obține schema pentru company_id (CACHED PERMANENT)"""
|
||||
async with oracle_pool.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
schema_query = """
|
||||
SELECT schema
|
||||
FROM CONTAFIN_ORACLE.v_nom_firme
|
||||
WHERE id_firma = :company_id
|
||||
"""
|
||||
cursor.execute(schema_query, {'company_id': company_id})
|
||||
schema_result = cursor.fetchone()
|
||||
|
||||
if not schema_result:
|
||||
raise ValueError(f"Schema not found for company {company_id}")
|
||||
|
||||
return schema_result[0]
|
||||
|
||||
@staticmethod
|
||||
def _get_view_query(schema: str, register_type: Optional[str] = None) -> str:
|
||||
"""
|
||||
Construiește query-ul pentru view-ul vbancasa corespunzător.
|
||||
Dacă register_type este None, returnează UNION ALL pentru toate tipurile.
|
||||
NU se filtrează pe incasari/plati > 0 - se afișează TOATE înregistrările!
|
||||
"""
|
||||
view_configs = {
|
||||
'BANCA_LEI': {
|
||||
'view': f'{schema}.vbancasa_5121_cum',
|
||||
'incasari_col': 'incasari',
|
||||
'plati_col': 'plati',
|
||||
'valuta': "'RON'",
|
||||
'tip': "'BANCA LEI'"
|
||||
},
|
||||
'BANCA_VALUTA': {
|
||||
'view': f'{schema}.vbancasa_5124_cum',
|
||||
'incasari_col': 'incasval',
|
||||
'plati_col': 'platival',
|
||||
'valuta': "COALESCE(numeval, 'EUR')",
|
||||
'tip': "'BANCA VALUTA'"
|
||||
},
|
||||
'CASA_LEI': {
|
||||
'view': f'{schema}.vbancasa_5311_cum',
|
||||
'incasari_col': 'incasari',
|
||||
'plati_col': 'plati',
|
||||
'valuta': "'RON'",
|
||||
'tip': "'CASA LEI'"
|
||||
},
|
||||
'CASA_VALUTA': {
|
||||
'view': f'{schema}.vbancasa_5314_cum',
|
||||
'incasari_col': 'incasval',
|
||||
'plati_col': 'platival',
|
||||
'valuta': "COALESCE(numeval, 'EUR')",
|
||||
'tip': "'CASA VALUTA'"
|
||||
}
|
||||
}
|
||||
|
||||
def build_select(config):
|
||||
# NU se filtrează - se afișează TOATE înregistrările
|
||||
# SOLD CUMULAT: Running balance per bancasa using window function
|
||||
# NULL-date rows (opening balance) come first due to NULLS FIRST
|
||||
return f"""
|
||||
SELECT
|
||||
nume, nract, dataact, bancasa,
|
||||
{config['incasari_col']} as incasari,
|
||||
{config['plati_col']} as plati,
|
||||
SUM({config['incasari_col']} - {config['plati_col']}) OVER (
|
||||
PARTITION BY bancasa
|
||||
ORDER BY dataact ASC NULLS FIRST, nract ASC NULLS FIRST
|
||||
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
|
||||
) as sold,
|
||||
{config['valuta']} as valuta,
|
||||
{config['tip']} as tip_registru,
|
||||
explicatia
|
||||
FROM {config['view']}
|
||||
"""
|
||||
|
||||
if register_type and register_type in view_configs:
|
||||
return build_select(view_configs[register_type])
|
||||
else:
|
||||
# UNION ALL pentru toate tipurile
|
||||
queries = [build_select(cfg) for cfg in view_configs.values()]
|
||||
return " UNION ALL ".join(queries)
|
||||
|
||||
@staticmethod
|
||||
@cached(cache_type='treasury', key_params=['filter_params', 'username'])
|
||||
async def get_bank_cash_register(filter_params: RegisterFilter, username: str) -> RegisterListResponse:
|
||||
"""
|
||||
Obține registrul de casă și bancă din vbancasa views (CACHED 10 min)
|
||||
|
||||
IMPORTANT: PACK_SESIUNE.SETAN și SETLUNA trebuie executate în ACEEAȘI
|
||||
tranzacție cu SELECT-ul din vbancasa* views!
|
||||
|
||||
Folosim un bloc PL/SQL anonim care:
|
||||
1. Obține anul și luna curentă din calendar
|
||||
2. Apelează PACK_SESIUNE.SETAN și SETLUNA
|
||||
3. Execută SELECT-ul din vbancasa*
|
||||
Toate în aceeași tranzacție!
|
||||
"""
|
||||
company_id = int(filter_params.company)
|
||||
schema = await TreasuryService._get_schema(company_id)
|
||||
|
||||
async with oracle_pool.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
|
||||
# Construiește query-ul pentru tipul de registru selectat
|
||||
base_select = TreasuryService._get_view_query(schema, filter_params.register_type)
|
||||
|
||||
# Construiește WHERE conditions
|
||||
where_conditions = []
|
||||
|
||||
# Date filter preserves NULL-date rows (opening balance)
|
||||
# for correct cumulative sum calculation
|
||||
if filter_params.date_from and filter_params.date_to:
|
||||
where_conditions.append(f"(dataact IS NULL OR (dataact >= TO_DATE('{filter_params.date_from.strftime('%Y-%m-%d')}', 'YYYY-MM-DD') AND dataact <= TO_DATE('{filter_params.date_to.strftime('%Y-%m-%d')}', 'YYYY-MM-DD')))")
|
||||
elif filter_params.date_from:
|
||||
where_conditions.append(f"(dataact IS NULL OR dataact >= TO_DATE('{filter_params.date_from.strftime('%Y-%m-%d')}', 'YYYY-MM-DD'))")
|
||||
elif filter_params.date_to:
|
||||
where_conditions.append(f"(dataact IS NULL OR dataact <= TO_DATE('{filter_params.date_to.strftime('%Y-%m-%d')}', 'YYYY-MM-DD'))")
|
||||
|
||||
if filter_params.partner_name:
|
||||
# Escape single quotes pentru SQL
|
||||
partner_escaped = filter_params.partner_name.replace("'", "''")
|
||||
where_conditions.append(f"UPPER(nume) LIKE UPPER('%{partner_escaped}%')")
|
||||
|
||||
if filter_params.bank_account:
|
||||
# Escape single quotes pentru SQL
|
||||
bank_escaped = filter_params.bank_account.replace("'", "''")
|
||||
where_conditions.append(f"bancasa = '{bank_escaped}'")
|
||||
|
||||
where_clause = ""
|
||||
if where_conditions:
|
||||
where_clause = " WHERE " + " AND ".join(where_conditions)
|
||||
|
||||
# Paginare Oracle
|
||||
offset = (filter_params.page - 1) * filter_params.page_size
|
||||
limit_val = filter_params.page_size
|
||||
|
||||
# Determine period to use: from params or MAX from calendar
|
||||
if filter_params.luna and filter_params.an:
|
||||
use_param_period = True
|
||||
period_select = f"""
|
||||
v_an := :param_an;
|
||||
v_luna := :param_luna;
|
||||
"""
|
||||
else:
|
||||
use_param_period = False
|
||||
period_select = f"""
|
||||
SELECT anul, luna INTO v_an, v_luna
|
||||
FROM {schema}.calendar
|
||||
WHERE anul*12+luna = (SELECT MAX(anul*12+luna) FROM {schema}.calendar);
|
||||
"""
|
||||
|
||||
# Bloc PL/SQL anonim care face totul într-o singură tranzacție:
|
||||
# 1. Obține anul și luna din params sau calendar
|
||||
# 2. Setează PACK_SESIUNE.SETAN și SETLUNA
|
||||
# 3. Returnează datele prin REF CURSOR
|
||||
# IMPORTANT: Folosim ROW_NUMBER() pentru paginare corectă cu ORDER BY NULLS FIRST
|
||||
plsql_block = f"""
|
||||
DECLARE
|
||||
v_an NUMBER;
|
||||
v_luna NUMBER;
|
||||
v_cursor SYS_REFCURSOR;
|
||||
BEGIN
|
||||
-- Obține anul și luna din parametri sau calendar
|
||||
{period_select}
|
||||
|
||||
-- Setează contextul de sesiune (OBLIGATORIU înainte de SELECT din vbancasa*)
|
||||
{schema}.PACK_SESIUNE.SETAN(v_an);
|
||||
{schema}.PACK_SESIUNE.SETLUNA(v_luna);
|
||||
|
||||
-- Return accounting period
|
||||
:out_an := v_an;
|
||||
:out_luna := v_luna;
|
||||
|
||||
-- Returnează datele prin cursor cu ROW_NUMBER pentru paginare corectă
|
||||
-- Pentru rânduri cu dataact=NULL (solduri precedente), sortare după bancasa
|
||||
-- Pentru rânduri cu date, sortare după data, număr, bancasa
|
||||
OPEN :result_cursor FOR
|
||||
SELECT * FROM (
|
||||
SELECT t.*, ROW_NUMBER() OVER (
|
||||
ORDER BY dataact ASC NULLS FIRST,
|
||||
CASE WHEN dataact IS NULL THEN bancasa END ASC,
|
||||
nract ASC NULLS FIRST,
|
||||
bancasa ASC
|
||||
) as rn
|
||||
FROM ({base_select}) t{where_clause}
|
||||
) WHERE rn > {offset} AND rn <= {offset + limit_val};
|
||||
END;
|
||||
"""
|
||||
|
||||
# Creează cursor pentru rezultate (oracledb.CURSOR pentru REF CURSOR)
|
||||
result_cursor = cursor.var(oracledb.CURSOR)
|
||||
out_an = cursor.var(int)
|
||||
out_luna = cursor.var(int)
|
||||
|
||||
# Build params dict
|
||||
exec_params = {'result_cursor': result_cursor, 'out_an': out_an, 'out_luna': out_luna}
|
||||
if use_param_period:
|
||||
exec_params['param_an'] = filter_params.an
|
||||
exec_params['param_luna'] = filter_params.luna
|
||||
|
||||
# Execută blocul PL/SQL cu REF CURSOR
|
||||
cursor.execute(plsql_block, exec_params)
|
||||
|
||||
# Get accounting period values
|
||||
accounting_year = out_an.getvalue()
|
||||
accounting_month = out_luna.getvalue()
|
||||
|
||||
# Obține rezultatele din cursor
|
||||
ref_cursor = result_cursor.getvalue()
|
||||
rows = ref_cursor.fetchall()
|
||||
ref_cursor.close()
|
||||
|
||||
# Pentru count total, executăm alt bloc PL/SQL
|
||||
count_plsql = f"""
|
||||
DECLARE
|
||||
v_an NUMBER;
|
||||
v_luna NUMBER;
|
||||
BEGIN
|
||||
-- Obține anul și luna din parametri sau calendar
|
||||
{period_select}
|
||||
|
||||
{schema}.PACK_SESIUNE.SETAN(v_an);
|
||||
{schema}.PACK_SESIUNE.SETLUNA(v_luna);
|
||||
|
||||
SELECT COUNT(*) INTO :total_count FROM ({base_select}) sub{where_clause};
|
||||
END;
|
||||
"""
|
||||
|
||||
total_count_var = cursor.var(int)
|
||||
count_params = {'total_count': total_count_var}
|
||||
if use_param_period:
|
||||
count_params['param_an'] = filter_params.an
|
||||
count_params['param_luna'] = filter_params.luna
|
||||
cursor.execute(count_plsql, count_params)
|
||||
total_count = total_count_var.getvalue()
|
||||
|
||||
# Query pentru TOTALURI din TOATE înregistrările filtrate (nu doar pagina curentă)
|
||||
# sold_precedent = suma sold pentru rânduri cu dataact IS NULL
|
||||
# total_incasari = suma incasari pentru rânduri cu dataact IS NOT NULL
|
||||
# total_plati = suma plati pentru rânduri cu dataact IS NOT NULL
|
||||
# Notă: where_clause poate fi gol sau poate conține "WHERE ..."
|
||||
# Dacă e gol, adăugăm WHERE; dacă nu, adăugăm AND
|
||||
dataact_null_cond = " AND dataact IS NULL" if where_clause else " WHERE dataact IS NULL"
|
||||
dataact_not_null_cond = " AND dataact IS NOT NULL" if where_clause else " WHERE dataact IS NOT NULL"
|
||||
|
||||
totals_plsql = f"""
|
||||
DECLARE
|
||||
v_an NUMBER;
|
||||
v_luna NUMBER;
|
||||
BEGIN
|
||||
-- Obține anul și luna din parametri sau calendar
|
||||
{period_select}
|
||||
|
||||
{schema}.PACK_SESIUNE.SETAN(v_an);
|
||||
{schema}.PACK_SESIUNE.SETLUNA(v_luna);
|
||||
|
||||
-- Sold precedent: suma sold pentru rânduri fără dată (opening balance)
|
||||
SELECT NVL(SUM(sold), 0) INTO :sold_precedent_all
|
||||
FROM ({base_select}) sub{where_clause}{dataact_null_cond};
|
||||
|
||||
-- Total încasări: suma incasari pentru rânduri cu dată (transactions)
|
||||
SELECT NVL(SUM(incasari), 0) INTO :total_incasari_all
|
||||
FROM ({base_select}) sub{where_clause}{dataact_not_null_cond};
|
||||
|
||||
-- Total plăți: suma plati pentru rânduri cu dată (transactions)
|
||||
SELECT NVL(SUM(plati), 0) INTO :total_plati_all
|
||||
FROM ({base_select}) sub{where_clause}{dataact_not_null_cond};
|
||||
END;
|
||||
"""
|
||||
|
||||
sold_precedent_all_var = cursor.var(oracledb.NUMBER)
|
||||
total_incasari_all_var = cursor.var(oracledb.NUMBER)
|
||||
total_plati_all_var = cursor.var(oracledb.NUMBER)
|
||||
|
||||
totals_params = {
|
||||
'sold_precedent_all': sold_precedent_all_var,
|
||||
'total_incasari_all': total_incasari_all_var,
|
||||
'total_plati_all': total_plati_all_var
|
||||
}
|
||||
if use_param_period:
|
||||
totals_params['param_an'] = filter_params.an
|
||||
totals_params['param_luna'] = filter_params.luna
|
||||
|
||||
cursor.execute(totals_plsql, totals_params)
|
||||
|
||||
sold_precedent_all = Decimal(str(sold_precedent_all_var.getvalue() or 0))
|
||||
total_incasari_all = Decimal(str(total_incasari_all_var.getvalue() or 0))
|
||||
total_plati_all = Decimal(str(total_plati_all_var.getvalue() or 0))
|
||||
sold_final_all = sold_precedent_all + total_incasari_all - total_plati_all
|
||||
|
||||
# Procesare rezultate
|
||||
registers = []
|
||||
total_incasari = Decimal('0.00')
|
||||
total_plati = Decimal('0.00')
|
||||
|
||||
for row in rows:
|
||||
# Coloane: nume, nract, dataact, bancasa, incasari, plati, sold, valuta, tip_registru, explicatia, rn
|
||||
# row[0-9] = date, row[10] = rn (ROW_NUMBER la final)
|
||||
register_data = BankCashRegister(
|
||||
nume=row[0] or '',
|
||||
nract=row[1],
|
||||
dataact=row[2],
|
||||
nume_cont_bancar=row[3] or '',
|
||||
incasari=Decimal(str(row[4] or 0)),
|
||||
plati=Decimal(str(row[5] or 0)),
|
||||
sold=Decimal(str(row[6] or 0)),
|
||||
valuta=row[7],
|
||||
tip_registru=row[8],
|
||||
explicatia=row[9] or ''
|
||||
)
|
||||
registers.append(register_data)
|
||||
total_incasari += register_data.incasari
|
||||
total_plati += register_data.plati
|
||||
|
||||
logger.info(f"Treasury query for company {company_id}, type={filter_params.register_type}: {len(registers)} records, total={total_count}")
|
||||
|
||||
return RegisterListResponse(
|
||||
registers=registers,
|
||||
total_count=total_count,
|
||||
filtered_count=len(registers),
|
||||
total_incasari=total_incasari,
|
||||
total_plati=total_plati,
|
||||
page=filter_params.page,
|
||||
page_size=filter_params.page_size,
|
||||
has_more=len(registers) == filter_params.page_size,
|
||||
accounting_period=AccountingPeriod(an=accounting_year, luna=accounting_month),
|
||||
# Totaluri din TOATE înregistrările filtrate
|
||||
sold_precedent_all=sold_precedent_all,
|
||||
total_incasari_all=total_incasari_all,
|
||||
total_plati_all=total_plati_all,
|
||||
sold_final_all=sold_final_all
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@cached(cache_type='treasury', key_params=['company_id', 'register_type'])
|
||||
async def get_bank_cash_accounts(company_id: int, register_type: str) -> List[str]:
|
||||
"""
|
||||
Obține lista distinctă de conturi bancă/casă (bancasa) pentru dropdown.
|
||||
Cached pentru performanță.
|
||||
IMPORTANT: Trebuie să setăm contextul PACK_SESIUNE înainte de a accesa vbancasa views!
|
||||
"""
|
||||
schema = await TreasuryService._get_schema(company_id)
|
||||
|
||||
# Map register_type to view
|
||||
view_map = {
|
||||
'BANCA_LEI': f'{schema}.vbancasa_5121_cum',
|
||||
'BANCA_VALUTA': f'{schema}.vbancasa_5124_cum',
|
||||
'CASA_LEI': f'{schema}.vbancasa_5311_cum',
|
||||
'CASA_VALUTA': f'{schema}.vbancasa_5314_cum'
|
||||
}
|
||||
|
||||
if register_type not in view_map:
|
||||
return []
|
||||
|
||||
view_name = view_map[register_type]
|
||||
|
||||
async with oracle_pool.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
# PL/SQL block to set session context and get accounts
|
||||
plsql_block = f"""
|
||||
DECLARE
|
||||
v_an NUMBER;
|
||||
v_luna NUMBER;
|
||||
BEGIN
|
||||
-- Get current year and month from calendar
|
||||
SELECT anul, luna INTO v_an, v_luna
|
||||
FROM {schema}.calendar
|
||||
WHERE anul*12+luna = (SELECT MAX(anul*12+luna) FROM {schema}.calendar);
|
||||
|
||||
-- Set session context (REQUIRED before accessing vbancasa* views)
|
||||
{schema}.PACK_SESIUNE.SETAN(v_an);
|
||||
{schema}.PACK_SESIUNE.SETLUNA(v_luna);
|
||||
|
||||
-- Return accounts via cursor
|
||||
OPEN :result_cursor FOR
|
||||
SELECT DISTINCT bancasa
|
||||
FROM {view_name}
|
||||
WHERE bancasa IS NOT NULL
|
||||
ORDER BY bancasa;
|
||||
END;
|
||||
"""
|
||||
|
||||
result_cursor = cursor.var(oracledb.CURSOR)
|
||||
cursor.execute(plsql_block, {'result_cursor': result_cursor})
|
||||
|
||||
ref_cursor = result_cursor.getvalue()
|
||||
rows = ref_cursor.fetchall()
|
||||
ref_cursor.close()
|
||||
|
||||
accounts = [row[0] for row in rows if row[0]]
|
||||
logger.info(f"Found {len(accounts)} bank/cash accounts for company {company_id}, type={register_type}")
|
||||
return accounts
|
||||
217
backend/modules/reports/services/trial_balance_service.py
Normal file
217
backend/modules/reports/services/trial_balance_service.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
Service pentru Trial Balance (Balanță de Verificare) - Query VBAL VIEW
|
||||
Refactored to use caching system for optimal performance
|
||||
"""
|
||||
# import sys # Removed - no longer needed
|
||||
import os
|
||||
|
||||
from shared.database.oracle_pool import oracle_pool
|
||||
from typing import Dict, Any
|
||||
from ..models.trial_balance import (
|
||||
TrialBalanceItem,
|
||||
TrialBalanceFilters,
|
||||
TrialBalancePagination,
|
||||
TrialBalanceResponse
|
||||
)
|
||||
from ..cache.decorators import cached
|
||||
from decimal import Decimal
|
||||
import math
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrialBalanceService:
|
||||
"""Service pentru gestionarea balanței de verificare cu cache"""
|
||||
|
||||
@staticmethod
|
||||
@cached(cache_type='schema', key_params=['company_id'])
|
||||
async def _get_schema(company_id: int) -> str:
|
||||
"""
|
||||
Obține schema pentru company_id (CACHED 24h)
|
||||
|
||||
This is cached permanently because company schemas rarely change.
|
||||
"""
|
||||
async with oracle_pool.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
schema_query = """
|
||||
SELECT schema
|
||||
FROM CONTAFIN_ORACLE.v_nom_firme
|
||||
WHERE id_firma = :company_id
|
||||
"""
|
||||
cursor.execute(schema_query, {'company_id': company_id})
|
||||
schema_result = cursor.fetchone()
|
||||
|
||||
if not schema_result:
|
||||
raise ValueError(f"Schema not found for company {company_id}")
|
||||
|
||||
return schema_result[0]
|
||||
|
||||
@staticmethod
|
||||
@cached(cache_type='trial_balance', key_params=['company_id', 'luna', 'an', 'cont_filter',
|
||||
'denumire_filter', 'sort_by', 'sort_order',
|
||||
'page', 'page_size', 'username'])
|
||||
async def get_trial_balance(
|
||||
company_id: int,
|
||||
luna: int,
|
||||
an: int,
|
||||
cont_filter: str | None,
|
||||
denumire_filter: str | None,
|
||||
sort_by: str,
|
||||
sort_order: str,
|
||||
page: int,
|
||||
page_size: int,
|
||||
username: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Obține balanța de verificare sintetică (CACHED 10 min)
|
||||
|
||||
Cache key includes all filter parameters to ensure unique cache entries
|
||||
for different query variations.
|
||||
|
||||
Args:
|
||||
company_id: ID firmei
|
||||
luna: Luna (1-12)
|
||||
an: Anul
|
||||
cont_filter: Filtru număr cont (optional)
|
||||
denumire_filter: Filtru denumire cont (optional)
|
||||
sort_by: Coloană pentru sortare
|
||||
sort_order: Ordinea sortării (asc/desc)
|
||||
page: Pagina
|
||||
page_size: Mărimea paginii
|
||||
username: Username pentru cache tracking
|
||||
|
||||
Returns:
|
||||
Dictionary cu items, pagination, filters_applied
|
||||
"""
|
||||
# Get schema (cached separately)
|
||||
schema = await TrialBalanceService._get_schema(company_id)
|
||||
|
||||
# Validate sort_order
|
||||
if sort_order.lower() not in ['asc', 'desc']:
|
||||
sort_order = 'asc'
|
||||
|
||||
# Validate sort_by (prevent SQL injection)
|
||||
valid_sort_columns = ['CONT', 'DENUMIRE', 'PRECDEB', 'PRECCRED',
|
||||
'RULDEB', 'RULCRED', 'SOLDDEB', 'SOLDCRED']
|
||||
if sort_by.upper() not in valid_sort_columns:
|
||||
sort_by = 'CONT'
|
||||
|
||||
async with oracle_pool.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
# Build base query for VBAL VIEW
|
||||
base_query = f"""
|
||||
SELECT
|
||||
CONT,
|
||||
NVL(DENUMIRE, '') as DENUMIRE,
|
||||
NVL(PRECDEB, 0) as PRECDEB,
|
||||
NVL(PRECCRED, 0) as PRECCRED,
|
||||
NVL(RULDEB, 0) as RULDEB,
|
||||
NVL(RULCRED, 0) as RULCRED,
|
||||
NVL(SOLDDEB, 0) as SOLDDEB,
|
||||
NVL(SOLDCRED, 0) as SOLDCRED
|
||||
FROM {schema}.VBAL
|
||||
WHERE AN = :an
|
||||
AND LUNA = :luna
|
||||
"""
|
||||
|
||||
params = {
|
||||
'an': an,
|
||||
'luna': luna
|
||||
}
|
||||
|
||||
# Add dynamic filters
|
||||
if cont_filter:
|
||||
base_query += " AND CONT LIKE :cont_filter"
|
||||
params['cont_filter'] = f"{cont_filter}%"
|
||||
|
||||
if denumire_filter:
|
||||
base_query += " AND UPPER(DENUMIRE) LIKE UPPER(:denumire_filter)"
|
||||
params['denumire_filter'] = f"%{denumire_filter}%"
|
||||
|
||||
# Count total for pagination
|
||||
count_query = f"SELECT COUNT(*) FROM ({base_query})"
|
||||
cursor.execute(count_query, params)
|
||||
total_count = cursor.fetchone()[0]
|
||||
|
||||
# Query pentru TOTALURI din TOATE înregistrările filtrate (nu doar pagina curentă)
|
||||
totals_query = f"""
|
||||
SELECT
|
||||
NVL(SUM(PRECDEB), 0) as total_prec_deb,
|
||||
NVL(SUM(PRECCRED), 0) as total_prec_cred,
|
||||
NVL(SUM(RULDEB), 0) as total_rul_deb,
|
||||
NVL(SUM(RULCRED), 0) as total_rul_cred,
|
||||
NVL(SUM(SOLDDEB), 0) as total_sold_deb,
|
||||
NVL(SUM(SOLDCRED), 0) as total_sold_cred
|
||||
FROM ({base_query})
|
||||
"""
|
||||
cursor.execute(totals_query, params)
|
||||
totals_row = cursor.fetchone()
|
||||
|
||||
totals = {
|
||||
"total_sold_precedent_debit": Decimal(str(totals_row[0])) if totals_row else Decimal('0.00'),
|
||||
"total_sold_precedent_credit": Decimal(str(totals_row[1])) if totals_row else Decimal('0.00'),
|
||||
"total_rulaj_lunar_debit": Decimal(str(totals_row[2])) if totals_row else Decimal('0.00'),
|
||||
"total_rulaj_lunar_credit": Decimal(str(totals_row[3])) if totals_row else Decimal('0.00'),
|
||||
"total_sold_final_debit": Decimal(str(totals_row[4])) if totals_row else Decimal('0.00'),
|
||||
"total_sold_final_credit": Decimal(str(totals_row[5])) if totals_row else Decimal('0.00')
|
||||
}
|
||||
|
||||
# Add sorting
|
||||
base_query += f" ORDER BY {sort_by.upper()} {sort_order.upper()}"
|
||||
|
||||
# Pagination (Oracle ROWNUM with ORDER BY)
|
||||
offset = (page - 1) * page_size
|
||||
limit = offset + page_size
|
||||
|
||||
paginated_query = f"""
|
||||
SELECT * FROM (
|
||||
SELECT a.*, ROWNUM rnum FROM (
|
||||
{base_query}
|
||||
) a WHERE ROWNUM <= :limit
|
||||
) WHERE rnum > :offset
|
||||
"""
|
||||
params['offset'] = offset
|
||||
params['limit'] = limit
|
||||
|
||||
cursor.execute(paginated_query, params)
|
||||
rows = cursor.fetchall()
|
||||
|
||||
# Process results
|
||||
# Index: CONT(0), DENUMIRE(1), PRECDEB(2), PRECCRED(3),
|
||||
# RULDEB(4), RULCRED(5), SOLDDEB(6), SOLDCRED(7), rnum(8)
|
||||
items = []
|
||||
for row in rows:
|
||||
item = TrialBalanceItem(
|
||||
cont=row[0] or '',
|
||||
denumire=row[1] or '',
|
||||
sold_precedent_debit=Decimal(str(row[2] or 0)),
|
||||
sold_precedent_credit=Decimal(str(row[3] or 0)),
|
||||
rulaj_lunar_debit=Decimal(str(row[4] or 0)),
|
||||
rulaj_lunar_credit=Decimal(str(row[5] or 0)),
|
||||
sold_final_debit=Decimal(str(row[6] or 0)),
|
||||
sold_final_credit=Decimal(str(row[7] or 0))
|
||||
)
|
||||
items.append(item.dict())
|
||||
|
||||
# Calculate pagination
|
||||
total_pages = math.ceil(total_count / page_size) if page_size > 0 else 0
|
||||
|
||||
# Build response
|
||||
return {
|
||||
"items": items,
|
||||
"pagination": {
|
||||
"total_items": total_count,
|
||||
"total_pages": total_pages,
|
||||
"current_page": page,
|
||||
"page_size": page_size
|
||||
},
|
||||
"filters_applied": {
|
||||
"luna": luna,
|
||||
"an": an,
|
||||
"cont_filter": cont_filter,
|
||||
"denumire_filter": denumire_filter
|
||||
},
|
||||
# Totaluri din TOATE înregistrările filtrate (nu doar pagina curentă)
|
||||
"totals": totals
|
||||
}
|
||||
0
backend/modules/telegram/__init__.py
Normal file
0
backend/modules/telegram/__init__.py
Normal file
0
backend/modules/telegram/agent/__init__.py
Normal file
0
backend/modules/telegram/agent/__init__.py
Normal file
313
backend/modules/telegram/agent/session.py
Normal file
313
backend/modules/telegram/agent/session.py
Normal file
@@ -0,0 +1,313 @@
|
||||
"""
|
||||
Session Management for Telegram Bot
|
||||
|
||||
This module handles session state for Telegram users, specifically managing
|
||||
the active company selection for command handlers.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import json
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from backend.modules.telegram.db.operations import (
|
||||
create_session,
|
||||
get_user_active_session,
|
||||
update_session_state,
|
||||
delete_user_sessions
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConversationSession:
|
||||
"""
|
||||
Manages session state for a single user.
|
||||
|
||||
Attributes:
|
||||
telegram_user_id: Telegram user ID
|
||||
session_id: UUID of the session
|
||||
active_company_id: Selected company ID
|
||||
active_company_name: Selected company name
|
||||
active_company_cui: Selected company CUI
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
telegram_user_id: int,
|
||||
session_id: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Initialize a session.
|
||||
|
||||
Args:
|
||||
telegram_user_id: Telegram user ID
|
||||
session_id: Existing session ID (if resuming), or None for new session
|
||||
"""
|
||||
self.telegram_user_id = telegram_user_id
|
||||
self.session_id = session_id
|
||||
self.created_at = datetime.now()
|
||||
self.updated_at = datetime.now()
|
||||
|
||||
# Active company for this session
|
||||
self.active_company_id: Optional[int] = None
|
||||
self.active_company_name: Optional[str] = None
|
||||
self.active_company_cui: Optional[str] = None
|
||||
|
||||
def set_active_company(
|
||||
self,
|
||||
company_id: int,
|
||||
company_name: str,
|
||||
company_cui: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Set the active company for this session.
|
||||
|
||||
Args:
|
||||
company_id: Company ID
|
||||
company_name: Company name
|
||||
company_cui: Company CUI (optional)
|
||||
"""
|
||||
self.active_company_id = company_id
|
||||
self.active_company_name = company_name
|
||||
self.active_company_cui = company_cui
|
||||
self.updated_at = datetime.now()
|
||||
logger.info(
|
||||
f"Active company set for user {self.telegram_user_id}: "
|
||||
f"{company_name} (ID: {company_id})"
|
||||
)
|
||||
|
||||
def get_active_company(self) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get the active company information.
|
||||
|
||||
Returns:
|
||||
Dict with company info (id, name, cui) or None if no company selected
|
||||
"""
|
||||
if self.active_company_id is not None:
|
||||
return {
|
||||
"id": self.active_company_id,
|
||||
"name": self.active_company_name,
|
||||
"cui": self.active_company_cui
|
||||
}
|
||||
return None
|
||||
|
||||
def clear_active_company(self):
|
||||
"""
|
||||
Clear the active company selection.
|
||||
"""
|
||||
logger.info(
|
||||
f"Clearing active company for user {self.telegram_user_id} "
|
||||
f"(was: {self.active_company_name})"
|
||||
)
|
||||
self.active_company_id = None
|
||||
self.active_company_name = None
|
||||
self.active_company_cui = None
|
||||
self.updated_at = datetime.now()
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Serialize session to dictionary (for database storage).
|
||||
|
||||
Returns:
|
||||
Dict representation of session
|
||||
"""
|
||||
return {
|
||||
"telegram_user_id": self.telegram_user_id,
|
||||
"session_id": self.session_id,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
"active_company_id": self.active_company_id,
|
||||
"active_company_name": self.active_company_name,
|
||||
"active_company_cui": self.active_company_cui
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'ConversationSession':
|
||||
"""
|
||||
Deserialize session from dictionary.
|
||||
|
||||
Args:
|
||||
data: Dict representation of session
|
||||
|
||||
Returns:
|
||||
ConversationSession instance
|
||||
"""
|
||||
session = cls(
|
||||
telegram_user_id=data["telegram_user_id"],
|
||||
session_id=data.get("session_id")
|
||||
)
|
||||
|
||||
# Restore active company
|
||||
session.active_company_id = data.get("active_company_id")
|
||||
session.active_company_name = data.get("active_company_name")
|
||||
session.active_company_cui = data.get("active_company_cui")
|
||||
|
||||
if "created_at" in data:
|
||||
session.created_at = datetime.fromisoformat(data["created_at"])
|
||||
if "updated_at" in data:
|
||||
session.updated_at = datetime.fromisoformat(data["updated_at"])
|
||||
|
||||
return session
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""
|
||||
Manages sessions for all users.
|
||||
|
||||
Provides methods to create, retrieve, update, and delete sessions.
|
||||
Sessions are stored both in memory (for quick access) and in database
|
||||
(for persistence).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize the session manager.
|
||||
"""
|
||||
self._sessions: Dict[int, ConversationSession] = {}
|
||||
logger.info("SessionManager initialized")
|
||||
|
||||
async def get_or_create_session(
|
||||
self,
|
||||
telegram_user_id: int
|
||||
) -> ConversationSession:
|
||||
"""
|
||||
Get existing session for a user or create a new one.
|
||||
|
||||
Args:
|
||||
telegram_user_id: Telegram user ID
|
||||
|
||||
Returns:
|
||||
ConversationSession for the user
|
||||
"""
|
||||
# Check in-memory cache first
|
||||
if telegram_user_id in self._sessions:
|
||||
logger.debug(f"Found session in cache for user {telegram_user_id}")
|
||||
return self._sessions[telegram_user_id]
|
||||
|
||||
# Check database for existing session
|
||||
session_data = await get_user_active_session(telegram_user_id)
|
||||
|
||||
if session_data:
|
||||
# Restore session from database
|
||||
conversation_state_json = session_data.get('conversation_state')
|
||||
|
||||
if conversation_state_json:
|
||||
try:
|
||||
session_dict = json.loads(conversation_state_json)
|
||||
session = ConversationSession.from_dict(session_dict)
|
||||
session.session_id = session_data['session_id']
|
||||
self._sessions[telegram_user_id] = session
|
||||
logger.info(f"Restored session from database for user {telegram_user_id}")
|
||||
return session
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse session state: {e}")
|
||||
|
||||
# Create new session
|
||||
session = ConversationSession(telegram_user_id)
|
||||
|
||||
# Save to database
|
||||
session_id = await create_session(
|
||||
telegram_user_id=telegram_user_id,
|
||||
conversation_state=json.dumps(session.to_dict()),
|
||||
expires_in_hours=24
|
||||
)
|
||||
|
||||
session.session_id = session_id
|
||||
self._sessions[telegram_user_id] = session
|
||||
|
||||
logger.info(f"Created new session for user {telegram_user_id} (ID: {session_id})")
|
||||
return session
|
||||
|
||||
async def save_session(self, telegram_user_id: int) -> bool:
|
||||
"""
|
||||
Save session to database.
|
||||
|
||||
Args:
|
||||
telegram_user_id: Telegram user ID
|
||||
|
||||
Returns:
|
||||
bool: True if saved successfully
|
||||
"""
|
||||
session = self._sessions.get(telegram_user_id)
|
||||
|
||||
if not session or not session.session_id:
|
||||
logger.warning(f"No session to save for user {telegram_user_id}")
|
||||
return False
|
||||
|
||||
try:
|
||||
conversation_state = json.dumps(session.to_dict())
|
||||
|
||||
success = await update_session_state(
|
||||
session_id=session.session_id,
|
||||
conversation_state=conversation_state
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.debug(f"Saved session for user {telegram_user_id}")
|
||||
else:
|
||||
logger.warning(f"Failed to save session for user {telegram_user_id}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving session for user {telegram_user_id}: {e}")
|
||||
return False
|
||||
|
||||
async def delete_session(self, telegram_user_id: int) -> bool:
|
||||
"""
|
||||
Delete session completely (from memory and database).
|
||||
|
||||
Args:
|
||||
telegram_user_id: Telegram user ID
|
||||
|
||||
Returns:
|
||||
bool: True if deleted successfully
|
||||
"""
|
||||
# Remove from memory
|
||||
if telegram_user_id in self._sessions:
|
||||
del self._sessions[telegram_user_id]
|
||||
|
||||
# Delete from database
|
||||
success = await delete_user_sessions(telegram_user_id)
|
||||
|
||||
if success:
|
||||
logger.info(f"Deleted session for user {telegram_user_id}")
|
||||
else:
|
||||
logger.warning(f"Failed to delete session for user {telegram_user_id}")
|
||||
|
||||
return success
|
||||
|
||||
def get_active_sessions_count(self) -> int:
|
||||
"""
|
||||
Get count of active sessions in memory.
|
||||
|
||||
Returns:
|
||||
int: Number of active sessions
|
||||
"""
|
||||
return len(self._sessions)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_session_manager_instance: Optional[SessionManager] = None
|
||||
|
||||
|
||||
def get_session_manager() -> SessionManager:
|
||||
"""
|
||||
Get or create the singleton SessionManager instance.
|
||||
|
||||
Returns:
|
||||
SessionManager: Singleton instance
|
||||
"""
|
||||
global _session_manager_instance
|
||||
if _session_manager_instance is None:
|
||||
_session_manager_instance = SessionManager()
|
||||
return _session_manager_instance
|
||||
|
||||
|
||||
# Export main classes and functions
|
||||
__all__ = [
|
||||
'ConversationSession',
|
||||
'SessionManager',
|
||||
'get_session_manager'
|
||||
]
|
||||
0
backend/modules/telegram/api/__init__.py
Normal file
0
backend/modules/telegram/api/__init__.py
Normal file
917
backend/modules/telegram/api/client.py
Normal file
917
backend/modules/telegram/api/client.py
Normal file
@@ -0,0 +1,917 @@
|
||||
"""
|
||||
API Client for ROA2WEB Backend Communication
|
||||
|
||||
This module provides an async HTTP client for communicating with the FastAPI backend.
|
||||
Handles authentication, requests, error handling, and response parsing.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
|
||||
import httpx
|
||||
from httpx import AsyncClient, Response, HTTPError, ConnectError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Backend configuration from environment
|
||||
# Default to port 8000 (production) instead of 8001 (development)
|
||||
BACKEND_URL = os.getenv("BACKEND_URL", "http://localhost:8000")
|
||||
REQUEST_TIMEOUT = float(os.getenv("API_TIMEOUT", "30.0")) # 30 seconds default
|
||||
|
||||
|
||||
class BackendAPIClient:
|
||||
"""
|
||||
Async HTTP client for ROA2WEB FastAPI backend.
|
||||
|
||||
Provides methods for all API endpoints used by the Telegram bot:
|
||||
- Dashboard data
|
||||
- Invoices search and retrieval
|
||||
- Treasury/payment data
|
||||
- Report exports
|
||||
- Company listings
|
||||
- User authentication and token management
|
||||
"""
|
||||
|
||||
def __init__(self, base_url: str = BACKEND_URL):
|
||||
"""
|
||||
Initialize the API client.
|
||||
|
||||
Args:
|
||||
base_url: Base URL of the FastAPI backend
|
||||
"""
|
||||
self.base_url = base_url.rstrip('/')
|
||||
self.client: Optional[AsyncClient] = None
|
||||
logger.info(f"Backend API client initialized with base URL: {self.base_url}")
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry."""
|
||||
self.client = AsyncClient(
|
||||
base_url=self.base_url,
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
follow_redirects=True
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Async context manager exit."""
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
|
||||
def _get_auth_headers(self, jwt_token: str) -> Dict[str, str]:
|
||||
"""
|
||||
Generate authentication headers with JWT token.
|
||||
|
||||
Args:
|
||||
jwt_token: JWT access token
|
||||
|
||||
Returns:
|
||||
Dict with Authorization header
|
||||
"""
|
||||
return {
|
||||
"Authorization": f"Bearer {jwt_token}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
async def _handle_response(self, response: Response) -> Dict[str, Any]:
|
||||
"""
|
||||
Handle API response and extract data.
|
||||
|
||||
Args:
|
||||
response: HTTP response object
|
||||
|
||||
Returns:
|
||||
Dict: Response JSON data
|
||||
|
||||
Raises:
|
||||
HTTPError: If response status is not successful
|
||||
"""
|
||||
try:
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except HTTPError as e:
|
||||
logger.error(f"API request failed: {e}")
|
||||
logger.error(f"Response body: {response.text}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse response: {e}")
|
||||
raise
|
||||
|
||||
# =========================================================================
|
||||
# AUTHENTICATION & USER ENDPOINTS
|
||||
# =========================================================================
|
||||
|
||||
async def verify_user(
|
||||
self,
|
||||
oracle_username: str,
|
||||
linking_code: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Verify user exists in Oracle and get JWT token.
|
||||
Called during Telegram linking process (auto-linking flow).
|
||||
|
||||
Args:
|
||||
oracle_username: Oracle username extracted from linking code
|
||||
linking_code: The 8-character linking code for validation
|
||||
|
||||
Returns:
|
||||
Dict with:
|
||||
- success: True if verification succeeded
|
||||
- access_token: JWT access token
|
||||
- refresh_token: JWT refresh token
|
||||
- user: Dict with user_id, username, companies, permissions
|
||||
- message: Status message
|
||||
|
||||
None if user not found or error
|
||||
|
||||
Example:
|
||||
result = await client.verify_user("JOHN.DOE", "ABC12345")
|
||||
if result and result['success']:
|
||||
jwt_token = result['access_token']
|
||||
"""
|
||||
try:
|
||||
if not self.client or self.client.is_closed:
|
||||
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
|
||||
|
||||
# Flow A: Auto-linking (no password required)
|
||||
response = await self.client.post(
|
||||
"/api/telegram/auth/verify-user",
|
||||
json={
|
||||
"linking_code": linking_code,
|
||||
"oracle_username": oracle_username
|
||||
}
|
||||
)
|
||||
|
||||
return await self._handle_response(response)
|
||||
|
||||
except ConnectError as e:
|
||||
logger.error(f"Cannot connect to backend at {self.base_url}: {e}")
|
||||
logger.error("Verify that backend service is running and BACKEND_URL is correct")
|
||||
return None
|
||||
except HTTPError as e:
|
||||
if e.response.status_code == 404:
|
||||
logger.warning(f"User {oracle_username} not found in Oracle")
|
||||
return None
|
||||
logger.error(f"Failed to verify user {oracle_username}: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error verifying user: {e}")
|
||||
return None
|
||||
|
||||
async def refresh_token(self, refresh_token: str) -> Optional[str]:
|
||||
"""
|
||||
Refresh JWT token for a user.
|
||||
|
||||
Args:
|
||||
refresh_token: JWT refresh token
|
||||
|
||||
Returns:
|
||||
str: New JWT access token, None if failed
|
||||
"""
|
||||
try:
|
||||
if not self.client or self.client.is_closed:
|
||||
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
|
||||
|
||||
response = await self.client.post(
|
||||
"/api/telegram/auth/refresh-token",
|
||||
json={"refresh_token": refresh_token}
|
||||
)
|
||||
|
||||
data = await self._handle_response(response)
|
||||
return data.get('access_token')
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to refresh token: {e}")
|
||||
return None
|
||||
|
||||
async def verify_email(self, email: str) -> dict:
|
||||
"""
|
||||
Verify if email exists in Oracle database
|
||||
|
||||
Args:
|
||||
email: Email address to verify
|
||||
|
||||
Returns:
|
||||
dict with 'success' (bool), 'username' (str or None), and 'message' (str)
|
||||
|
||||
Raises:
|
||||
httpx.HTTPError: On network or HTTP errors
|
||||
"""
|
||||
try:
|
||||
if not self.client or self.client.is_closed:
|
||||
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
|
||||
|
||||
response = await self.client.post(
|
||||
"/api/telegram/auth/verify-email",
|
||||
json={"email": email}
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"HTTP error verifying email {email}: {e.response.status_code}")
|
||||
return {
|
||||
"success": False,
|
||||
"username": None,
|
||||
"message": "Eroare la verificarea email-ului"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to verify email {email}: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"username": None,
|
||||
"message": "Eroare la verificarea email-ului"
|
||||
}
|
||||
|
||||
async def login_with_email(
|
||||
self,
|
||||
email: str,
|
||||
password: str,
|
||||
telegram_user_id: int,
|
||||
session_token: str
|
||||
) -> dict:
|
||||
"""
|
||||
Login via email + password with session token
|
||||
|
||||
Args:
|
||||
email: User email address
|
||||
password: Oracle password
|
||||
telegram_user_id: Telegram user ID
|
||||
session_token: Signed token from code validation
|
||||
|
||||
Returns:
|
||||
Login response with JWT tokens and user data
|
||||
|
||||
Raises:
|
||||
httpx.HTTPError: On network or HTTP errors
|
||||
"""
|
||||
try:
|
||||
if not self.client or self.client.is_closed:
|
||||
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
|
||||
|
||||
response = await self.client.post(
|
||||
"/api/telegram/auth/login-with-email",
|
||||
json={
|
||||
"email": email,
|
||||
"password": password,
|
||||
"telegram_user_id": telegram_user_id,
|
||||
"session_token": session_token
|
||||
},
|
||||
timeout=30.0 # 30 seconds timeout
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
logger.info(f"Email login successful for user {telegram_user_id}")
|
||||
|
||||
return data
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"Email login HTTP error: {e.response.status_code} - {e.response.text}")
|
||||
|
||||
# Parse error detail if available
|
||||
try:
|
||||
error_data = e.response.json()
|
||||
return {
|
||||
"success": False,
|
||||
"message": error_data.get("detail", "Autentificare eșuată")
|
||||
}
|
||||
except:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "Autentificare eșuată"
|
||||
}
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.error("Email login timeout")
|
||||
return {
|
||||
"success": False,
|
||||
"message": "Timeout. Te rugăm să încerci din nou."
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Email login error: {e}", exc_info=True)
|
||||
return {
|
||||
"success": False,
|
||||
"message": "Eroare de conexiune"
|
||||
}
|
||||
|
||||
async def get_user_companies(self, jwt_token: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get list of companies the user has access to.
|
||||
|
||||
Args:
|
||||
jwt_token: JWT access token
|
||||
|
||||
Returns:
|
||||
List of company dicts with id, nume_firma, cui, etc.
|
||||
"""
|
||||
try:
|
||||
if not self.client or self.client.is_closed:
|
||||
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
|
||||
|
||||
response = await self.client.get(
|
||||
"/api/companies",
|
||||
headers=self._get_auth_headers(jwt_token)
|
||||
)
|
||||
|
||||
data = await self._handle_response(response)
|
||||
|
||||
# Backend returns {"companies": [...], "total_count": N}
|
||||
if isinstance(data, dict) and "companies" in data:
|
||||
return data["companies"]
|
||||
|
||||
return data if isinstance(data, list) else []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get companies: {e}")
|
||||
return []
|
||||
|
||||
# =========================================================================
|
||||
# DASHBOARD ENDPOINTS
|
||||
# =========================================================================
|
||||
|
||||
async def get_dashboard_data(
|
||||
self,
|
||||
company_id: int,
|
||||
jwt_token: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get dashboard statistics for a company.
|
||||
|
||||
Args:
|
||||
company_id: Company ID
|
||||
jwt_token: JWT access token
|
||||
|
||||
Returns:
|
||||
Dict with dashboard data (sold_total, facturi, plati, etc.)
|
||||
Includes _cache_hit and _response_time_ms metadata
|
||||
"""
|
||||
try:
|
||||
if not self.client or self.client.is_closed:
|
||||
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
|
||||
|
||||
# Add cache metadata header for Telegram Bot
|
||||
headers = self._get_auth_headers(jwt_token)
|
||||
headers['X-Include-Cache-Metadata'] = 'true'
|
||||
|
||||
response = await self.client.get(
|
||||
"/api/reports/dashboard/summary",
|
||||
params={"company": str(company_id)},
|
||||
headers=headers
|
||||
)
|
||||
|
||||
return await self._handle_response(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get dashboard data for company {company_id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_treasury_breakdown(
|
||||
self,
|
||||
company_id: int,
|
||||
jwt_token: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get detailed treasury breakdown (casa + banca accounts).
|
||||
|
||||
Args:
|
||||
company_id: Company ID
|
||||
jwt_token: JWT access token
|
||||
|
||||
Returns:
|
||||
Dict with treasury breakdown data (accounts by type)
|
||||
"""
|
||||
try:
|
||||
if not self.client or self.client.is_closed:
|
||||
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
|
||||
|
||||
# Add cache metadata header for Telegram Bot
|
||||
headers = self._get_auth_headers(jwt_token)
|
||||
headers['X-Include-Cache-Metadata'] = 'true'
|
||||
|
||||
response = await self.client.get(
|
||||
f"/api/reports/dashboard/treasury-breakdown?company={company_id}",
|
||||
headers=headers
|
||||
)
|
||||
|
||||
return await self._handle_response(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get treasury breakdown for company {company_id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_detailed_data(
|
||||
self,
|
||||
company_id: int,
|
||||
jwt_token: str,
|
||||
data_type: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get detailed data for clients or suppliers.
|
||||
|
||||
Args:
|
||||
company_id: Company ID
|
||||
jwt_token: JWT access token
|
||||
data_type: Type of data ('clients' or 'suppliers')
|
||||
|
||||
Returns:
|
||||
Dict with detailed data (list of clients/suppliers with balances)
|
||||
"""
|
||||
try:
|
||||
if not self.client or self.client.is_closed:
|
||||
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
|
||||
|
||||
# Add cache metadata header for Telegram Bot
|
||||
headers = self._get_auth_headers(jwt_token)
|
||||
headers['X-Include-Cache-Metadata'] = 'true'
|
||||
|
||||
response = await self.client.get(
|
||||
f"/api/reports/dashboard/detailed-data?company={company_id}&data_type={data_type}",
|
||||
headers=headers
|
||||
)
|
||||
|
||||
return await self._handle_response(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get detailed data ({data_type}) for company {company_id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_maturity_data(
|
||||
self,
|
||||
company_id: int,
|
||||
jwt_token: str,
|
||||
period: str = "all"
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get maturity data (in term/overdue breakdown).
|
||||
|
||||
Args:
|
||||
company_id: Company ID
|
||||
jwt_token: JWT access token
|
||||
period: Period filter ('all', '30', '60', '90')
|
||||
|
||||
Returns:
|
||||
Dict with maturity data (in_term, overdue, total)
|
||||
"""
|
||||
try:
|
||||
if not self.client or self.client.is_closed:
|
||||
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
|
||||
|
||||
# Add cache metadata header for Telegram Bot
|
||||
headers = self._get_auth_headers(jwt_token)
|
||||
headers['X-Include-Cache-Metadata'] = 'true'
|
||||
|
||||
response = await self.client.get(
|
||||
f"/api/reports/dashboard/maturity?company={company_id}&period={period}",
|
||||
headers=headers
|
||||
)
|
||||
|
||||
return await self._handle_response(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get maturity data for company {company_id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_performance_data(
|
||||
self,
|
||||
company_id: int,
|
||||
jwt_token: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get performance data (incasari/plati totals).
|
||||
|
||||
Args:
|
||||
company_id: Company ID
|
||||
jwt_token: JWT access token
|
||||
|
||||
Returns:
|
||||
Dict with performance data (incasari_total, plati_total, net)
|
||||
"""
|
||||
try:
|
||||
if not self.client or self.client.is_closed:
|
||||
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
|
||||
|
||||
# Add cache metadata header for Telegram Bot
|
||||
headers = self._get_auth_headers(jwt_token)
|
||||
headers['X-Include-Cache-Metadata'] = 'true'
|
||||
|
||||
response = await self.client.get(
|
||||
f"/api/reports/dashboard/performance?company={company_id}",
|
||||
headers=headers
|
||||
)
|
||||
|
||||
return await self._handle_response(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get performance data for company {company_id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_monthly_flows(
|
||||
self,
|
||||
company_id: int,
|
||||
jwt_token: str,
|
||||
months: int = 12
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get monthly cash flows data.
|
||||
|
||||
Args:
|
||||
company_id: Company ID
|
||||
jwt_token: JWT access token
|
||||
months: Number of months to retrieve
|
||||
|
||||
Returns:
|
||||
Dict with monthly flows (months, incasari, plati arrays)
|
||||
"""
|
||||
try:
|
||||
if not self.client or self.client.is_closed:
|
||||
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
|
||||
|
||||
# Add cache metadata header for Telegram Bot
|
||||
headers = self._get_auth_headers(jwt_token)
|
||||
headers['X-Include-Cache-Metadata'] = 'true'
|
||||
|
||||
response = await self.client.get(
|
||||
f"/api/reports/dashboard/monthly-flows?company={company_id}&months={months}",
|
||||
headers=headers
|
||||
)
|
||||
|
||||
return await self._handle_response(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get monthly flows for company {company_id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_trends(
|
||||
self,
|
||||
company_id: int,
|
||||
jwt_token: str,
|
||||
period: str = "12m"
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get trends data (12-month historical data for collections/payments).
|
||||
|
||||
Args:
|
||||
company_id: Company ID
|
||||
jwt_token: JWT access token
|
||||
period: Period for trends (e.g., "12m", "6m", "ytd")
|
||||
|
||||
Returns:
|
||||
Dict with trends data including periods, clienti_incasat, furnizori_achitat arrays
|
||||
"""
|
||||
try:
|
||||
if not self.client or self.client.is_closed:
|
||||
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
|
||||
|
||||
# Add cache metadata header for Telegram Bot
|
||||
headers = self._get_auth_headers(jwt_token)
|
||||
headers['X-Include-Cache-Metadata'] = 'true'
|
||||
|
||||
response = await self.client.get(
|
||||
f"/api/reports/dashboard/trends?company={company_id}&period={period}",
|
||||
headers=headers
|
||||
)
|
||||
|
||||
return await self._handle_response(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get trends for company {company_id}: {e}")
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# INVOICES ENDPOINTS
|
||||
# =========================================================================
|
||||
|
||||
async def search_invoices(
|
||||
self,
|
||||
company_id: int,
|
||||
jwt_token: str,
|
||||
filters: Optional[Dict[str, Any]] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Search invoices with optional filters.
|
||||
|
||||
Args:
|
||||
company_id: Company ID
|
||||
jwt_token: JWT access token
|
||||
filters: Optional filters dict:
|
||||
- date_from: str (YYYY-MM-DD)
|
||||
- date_to: str (YYYY-MM-DD)
|
||||
- status: str (paid, unpaid, overdue)
|
||||
- client_name: str
|
||||
- partner_type: str (CLIENTI, FURNIZORI)
|
||||
- partner_name: str
|
||||
- series: str
|
||||
- number: str
|
||||
|
||||
Returns:
|
||||
List of invoice dicts
|
||||
"""
|
||||
try:
|
||||
if not self.client or self.client.is_closed:
|
||||
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
|
||||
|
||||
params = {"company": company_id}
|
||||
if filters:
|
||||
params.update(filters)
|
||||
|
||||
response = await self.client.get(
|
||||
"/api/reports/invoices/",
|
||||
params=params,
|
||||
headers=self._get_auth_headers(jwt_token)
|
||||
)
|
||||
|
||||
data = await self._handle_response(response)
|
||||
|
||||
if isinstance(data, dict) and 'invoices' in data:
|
||||
invoice_list = data['invoices']
|
||||
return invoice_list
|
||||
elif isinstance(data, list):
|
||||
return data
|
||||
else:
|
||||
logger.warning(f"📥 Unexpected response format: {type(data)}")
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to search invoices for company {company_id}: {e}")
|
||||
return []
|
||||
|
||||
async def get_invoice_summary(
|
||||
self,
|
||||
company_id: int,
|
||||
jwt_token: str,
|
||||
partner_type: str = "CLIENTI"
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get invoice summary statistics.
|
||||
|
||||
Args:
|
||||
company_id: Company ID
|
||||
jwt_token: JWT access token
|
||||
|
||||
Returns:
|
||||
Dict with summary (total_count, total_amount, paid, unpaid, etc.)
|
||||
"""
|
||||
try:
|
||||
if not self.client or self.client.is_closed:
|
||||
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
|
||||
|
||||
response = await self.client.get(
|
||||
"/api/reports/invoices/summary",
|
||||
params={
|
||||
"company": str(company_id),
|
||||
"partner_type": partner_type
|
||||
},
|
||||
headers=self._get_auth_headers(jwt_token)
|
||||
)
|
||||
|
||||
return await self._handle_response(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get invoice summary for company {company_id}: {e}")
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# TREASURY ENDPOINTS
|
||||
# =========================================================================
|
||||
|
||||
async def get_treasury_data(
|
||||
self,
|
||||
company_id: int,
|
||||
jwt_token: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get treasury/cash flow data for a company.
|
||||
|
||||
Args:
|
||||
company_id: Company ID
|
||||
jwt_token: JWT access token
|
||||
|
||||
Returns:
|
||||
Dict with treasury data (cash_balance, incoming, outgoing, etc.)
|
||||
"""
|
||||
try:
|
||||
if not self.client or self.client.is_closed:
|
||||
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
|
||||
|
||||
response = await self.client.get(
|
||||
"/api/reports/treasury/bank-cash-register",
|
||||
params={
|
||||
"company": str(company_id),
|
||||
"page": 1,
|
||||
"page_size": 1000
|
||||
},
|
||||
headers=self._get_auth_headers(jwt_token)
|
||||
)
|
||||
|
||||
return await self._handle_response(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get treasury data for company {company_id}: {e}")
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# EXPORT ENDPOINTS
|
||||
# =========================================================================
|
||||
|
||||
async def export_report(
|
||||
self,
|
||||
jwt_token: str,
|
||||
report_type: str,
|
||||
company_id: int,
|
||||
format: str = "xlsx",
|
||||
filters: Optional[Dict[str, Any]] = None
|
||||
) -> Optional[bytes]:
|
||||
"""
|
||||
Generate and export a report.
|
||||
|
||||
Args:
|
||||
jwt_token: JWT access token
|
||||
report_type: Type of report ('dashboard', 'invoices', 'treasury')
|
||||
company_id: Company ID
|
||||
format: Export format ('xlsx', 'csv', 'pdf')
|
||||
filters: Optional filters for data
|
||||
|
||||
Returns:
|
||||
bytes: File content, None if failed
|
||||
"""
|
||||
try:
|
||||
if not self.client or self.client.is_closed:
|
||||
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
|
||||
|
||||
request_data = {
|
||||
"type": report_type,
|
||||
"company_id": company_id,
|
||||
"format": format,
|
||||
"filters": filters or {}
|
||||
}
|
||||
|
||||
response = await self.client.post(
|
||||
"/api/telegram/export",
|
||||
json=request_data,
|
||||
headers=self._get_auth_headers(jwt_token)
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to export report: {e}")
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# CACHE MANAGEMENT
|
||||
# =========================================================================
|
||||
|
||||
async def invalidate_cache(
|
||||
self,
|
||||
jwt_token: str,
|
||||
company_id: Optional[int] = None,
|
||||
cache_type: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Invalidate cache entries.
|
||||
|
||||
Args:
|
||||
jwt_token: JWT access token
|
||||
company_id: Optional company ID (None = all companies)
|
||||
cache_type: Optional cache type (None = all types)
|
||||
|
||||
Returns:
|
||||
bool: True if successful
|
||||
"""
|
||||
try:
|
||||
if not self.client or self.client.is_closed:
|
||||
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
|
||||
|
||||
request_data = {}
|
||||
if company_id is not None:
|
||||
request_data['company_id'] = company_id
|
||||
if cache_type is not None:
|
||||
request_data['cache_type'] = cache_type
|
||||
|
||||
response = await self.client.post(
|
||||
"/api/reports/cache/invalidate",
|
||||
json=request_data,
|
||||
headers=self._get_auth_headers(jwt_token)
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
logger.info(f"Cache invalidated: company_id={company_id}, cache_type={cache_type}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to invalidate cache: {e}")
|
||||
return False
|
||||
|
||||
async def toggle_user_cache(
|
||||
self,
|
||||
jwt_token: str,
|
||||
enabled: bool
|
||||
) -> bool:
|
||||
"""
|
||||
Toggle cache for current user.
|
||||
|
||||
Args:
|
||||
jwt_token: JWT access token
|
||||
enabled: True to enable cache, False to disable
|
||||
|
||||
Returns:
|
||||
bool: True if successful
|
||||
"""
|
||||
try:
|
||||
if not self.client or self.client.is_closed:
|
||||
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
|
||||
|
||||
response = await self.client.post(
|
||||
"/api/reports/cache/toggle-user",
|
||||
json={"enabled": enabled},
|
||||
headers=self._get_auth_headers(jwt_token)
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
logger.info(f"User cache toggled: enabled={enabled}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to toggle user cache: {e}")
|
||||
return False
|
||||
|
||||
async def get_cache_stats(
|
||||
self,
|
||||
jwt_token: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get cache statistics including user-specific settings.
|
||||
|
||||
Args:
|
||||
jwt_token: JWT access token
|
||||
|
||||
Returns:
|
||||
Dict with cache stats including 'user_enabled' field
|
||||
"""
|
||||
try:
|
||||
if not self.client or self.client.is_closed:
|
||||
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
|
||||
|
||||
response = await self.client.get(
|
||||
"/api/reports/cache/stats",
|
||||
headers=self._get_auth_headers(jwt_token)
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get cache stats: {e}")
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# HEALTH CHECK
|
||||
# =========================================================================
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""
|
||||
Check if backend is healthy and reachable.
|
||||
|
||||
Returns:
|
||||
bool: True if backend is healthy
|
||||
"""
|
||||
try:
|
||||
if not self.client or self.client.is_closed:
|
||||
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
|
||||
|
||||
response = await self.client.get("/api/telegram/health")
|
||||
return response.status_code == 200
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Backend health check failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# Singleton instance for global use
|
||||
_backend_client_instance: Optional[BackendAPIClient] = None
|
||||
|
||||
|
||||
def get_backend_client() -> BackendAPIClient:
|
||||
"""
|
||||
Get or create the singleton BackendAPIClient instance.
|
||||
|
||||
Returns:
|
||||
BackendAPIClient: Singleton instance
|
||||
"""
|
||||
global _backend_client_instance
|
||||
if _backend_client_instance is None:
|
||||
_backend_client_instance = BackendAPIClient()
|
||||
return _backend_client_instance
|
||||
|
||||
|
||||
# Export main classes and functions
|
||||
__all__ = [
|
||||
'BackendAPIClient',
|
||||
'get_backend_client',
|
||||
'BACKEND_URL'
|
||||
]
|
||||
0
backend/modules/telegram/auth/__init__.py
Normal file
0
backend/modules/telegram/auth/__init__.py
Normal file
171
backend/modules/telegram/auth/email_auth.py
Normal file
171
backend/modules/telegram/auth/email_auth.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""
|
||||
Email authentication logic with crypto-secure code generation
|
||||
"""
|
||||
import secrets
|
||||
import re
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict
|
||||
from collections import defaultdict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ============================================================================
|
||||
# RATE LIMITING (In-Memory)
|
||||
# ============================================================================
|
||||
# NOTE: For production with multiple bot instances, migrate to Redis
|
||||
# See "Optional Future Enhancements" section in plan
|
||||
|
||||
_rate_limit_store: Dict[str, list] = defaultdict(list)
|
||||
|
||||
|
||||
async def check_rate_limit(
|
||||
identifier: str,
|
||||
max_attempts: int = 3,
|
||||
window_minutes: int = 60
|
||||
) -> bool:
|
||||
"""
|
||||
Check if identifier is within rate limit
|
||||
|
||||
Args:
|
||||
identifier: Email or telegram_user_id (as string)
|
||||
max_attempts: Maximum attempts allowed
|
||||
window_minutes: Time window in minutes
|
||||
|
||||
Returns:
|
||||
True if within limit (can proceed), False if exceeded
|
||||
|
||||
NOTE: In-memory implementation - resets on bot restart
|
||||
"""
|
||||
now = datetime.now()
|
||||
cutoff = now - timedelta(minutes=window_minutes)
|
||||
|
||||
# Clean old attempts
|
||||
_rate_limit_store[identifier] = [
|
||||
attempt for attempt in _rate_limit_store[identifier]
|
||||
if attempt > cutoff
|
||||
]
|
||||
|
||||
# Check limit
|
||||
if len(_rate_limit_store[identifier]) >= max_attempts:
|
||||
logger.warning(f"Rate limit exceeded for {identifier}")
|
||||
return False
|
||||
|
||||
# Add new attempt
|
||||
_rate_limit_store[identifier].append(now)
|
||||
return True
|
||||
|
||||
|
||||
def clear_rate_limit(identifier: str) -> None:
|
||||
"""Clear rate limit for identifier (e.g., after successful auth)"""
|
||||
if identifier in _rate_limit_store:
|
||||
del _rate_limit_store[identifier]
|
||||
logger.debug(f"Rate limit cleared for {identifier}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CODE GENERATION (Crypto-Secure)
|
||||
# ============================================================================
|
||||
|
||||
def generate_email_code() -> str:
|
||||
"""
|
||||
Generate crypto-secure 6-digit code
|
||||
|
||||
Uses secrets module (not random) for cryptographic security
|
||||
|
||||
Returns:
|
||||
6-digit string (000000 - 999999)
|
||||
"""
|
||||
# Generate 6-digit code using secrets (crypto-secure)
|
||||
code = ''.join(secrets.choice('0123456789') for _ in range(6))
|
||||
|
||||
logger.debug(f"Generated email auth code (length: {len(code)})")
|
||||
return code
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# EMAIL VALIDATION
|
||||
# ============================================================================
|
||||
|
||||
def is_valid_email_format(email: str) -> bool:
|
||||
"""
|
||||
Validate email format (basic regex)
|
||||
|
||||
Args:
|
||||
email: Email address to validate
|
||||
|
||||
Returns:
|
||||
True if format is valid
|
||||
"""
|
||||
pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
|
||||
return bool(re.match(pattern, email))
|
||||
|
||||
|
||||
async def verify_email_in_oracle(email: str) -> Optional[str]:
|
||||
"""
|
||||
Verify email exists in Oracle UTILIZATORI table via backend API
|
||||
|
||||
Args:
|
||||
email: Email address to check
|
||||
|
||||
Returns:
|
||||
Oracle username if found and active, None otherwise
|
||||
|
||||
NOTE: Uses backend API endpoint /api/telegram/auth/verify-email
|
||||
"""
|
||||
try:
|
||||
from backend.modules.telegram.api.client import get_backend_client
|
||||
|
||||
backend_client = get_backend_client()
|
||||
|
||||
# Call backend API to verify email
|
||||
response = await backend_client.verify_email(email)
|
||||
|
||||
if response.get('success'):
|
||||
username = response.get('username')
|
||||
logger.info(f"Email verified via backend: {email} -> {username}")
|
||||
return username
|
||||
else:
|
||||
logger.warning(f"Email not found or inactive: {email}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error verifying email via backend: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SESSION TOKEN GENERATION (Prevent User ID Spoofing)
|
||||
# ============================================================================
|
||||
|
||||
def generate_session_token(telegram_user_id: int, email: str) -> str:
|
||||
"""
|
||||
Generate signed session token for backend verification
|
||||
|
||||
This prevents user ID spoofing attacks where malicious clients
|
||||
could impersonate Telegram users by sending arbitrary user IDs
|
||||
|
||||
Args:
|
||||
telegram_user_id: Telegram user ID
|
||||
email: Verified email address
|
||||
|
||||
Returns:
|
||||
Signed token (simple implementation - upgrade to JWT in future)
|
||||
|
||||
NOTE: For production, use proper JWT signing with shared secret
|
||||
"""
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
# Get secret from env (should match backend)
|
||||
secret = os.getenv("AUTH_SESSION_SECRET", "change-me-in-production")
|
||||
|
||||
# Create signature: HMAC-like hash
|
||||
payload = f"{telegram_user_id}:{email}:{secret}"
|
||||
signature = hashlib.sha256(payload.encode()).hexdigest()[:16]
|
||||
|
||||
# Token format: user_id:email:signature
|
||||
token = f"{telegram_user_id}:{email}:{signature}"
|
||||
|
||||
logger.debug(f"Generated session token for user {telegram_user_id}")
|
||||
return token
|
||||
350
backend/modules/telegram/auth/linking.py
Normal file
350
backend/modules/telegram/auth/linking.py
Normal file
@@ -0,0 +1,350 @@
|
||||
"""
|
||||
Authentication and User Linking Logic
|
||||
|
||||
This module handles the linking process between Telegram users and Oracle ERP accounts.
|
||||
It manages authentication codes, verifies users through the backend API, and maintains
|
||||
user sessions with JWT tokens.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from telegram import User as TelegramUser
|
||||
|
||||
from backend.modules.telegram.db.operations import (
|
||||
get_user,
|
||||
create_or_update_user,
|
||||
link_user_to_oracle,
|
||||
update_user_tokens,
|
||||
verify_and_use_auth_code,
|
||||
is_user_linked
|
||||
)
|
||||
from backend.modules.telegram.api.client import get_backend_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def link_telegram_account(
|
||||
telegram_user: TelegramUser,
|
||||
auth_code: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Link a Telegram account to an Oracle ERP account using an authentication code.
|
||||
|
||||
Flow:
|
||||
1. Verify auth code in database (check exists, not used, not expired)
|
||||
2. Extract oracle_username from code
|
||||
3. Call backend API to verify user in Oracle and get JWT token
|
||||
4. Create/update Telegram user record
|
||||
5. Link user to Oracle account with JWT tokens
|
||||
6. Return success with user data
|
||||
|
||||
Args:
|
||||
telegram_user: Telegram User object from python-telegram-bot
|
||||
auth_code: 8-character authentication code from web frontend
|
||||
|
||||
Returns:
|
||||
Dict with:
|
||||
- success: True if linking succeeded
|
||||
- username: Oracle username
|
||||
- jwt_token: JWT access token
|
||||
- companies: List of companies user has access to
|
||||
OR None if linking failed
|
||||
|
||||
Example:
|
||||
result = await link_telegram_account(telegram_user, "ABC12345")
|
||||
if result:
|
||||
print(f"Linked to {result['username']}")
|
||||
else:
|
||||
print("Linking failed")
|
||||
"""
|
||||
try:
|
||||
telegram_user_id = telegram_user.id
|
||||
telegram_username = telegram_user.username
|
||||
first_name = telegram_user.first_name
|
||||
last_name = telegram_user.last_name
|
||||
|
||||
logger.info(
|
||||
f"Attempting to link Telegram user {telegram_user_id} "
|
||||
f"(@{telegram_username}) with code {auth_code}"
|
||||
)
|
||||
|
||||
# Step 1: Verify auth code
|
||||
code_data = await verify_and_use_auth_code(auth_code)
|
||||
|
||||
if not code_data:
|
||||
logger.warning(f"Invalid, expired, or already used auth code: {auth_code}")
|
||||
return None
|
||||
|
||||
oracle_username = code_data.get('oracle_username')
|
||||
logger.info(f"Auth code valid for Oracle user: {oracle_username}")
|
||||
|
||||
# Step 2: Create/update Telegram user record (basic info)
|
||||
user_created = await create_or_update_user(
|
||||
telegram_user_id=telegram_user_id,
|
||||
username=telegram_username,
|
||||
first_name=first_name,
|
||||
last_name=last_name
|
||||
)
|
||||
|
||||
if not user_created:
|
||||
logger.error(f"Failed to create/update Telegram user {telegram_user_id}")
|
||||
return None
|
||||
|
||||
# Step 3: Verify user in Oracle and get JWT token via backend API (auto-linking flow)
|
||||
backend_client = get_backend_client()
|
||||
async with backend_client:
|
||||
user_data = await backend_client.verify_user(
|
||||
oracle_username=oracle_username,
|
||||
linking_code=auth_code
|
||||
)
|
||||
|
||||
if not user_data or not user_data.get('success'):
|
||||
logger.error(f"Failed to verify Oracle user {oracle_username} via backend")
|
||||
return None
|
||||
|
||||
# Extract tokens and user info from response
|
||||
jwt_token = user_data.get('access_token')
|
||||
jwt_refresh_token = user_data.get('refresh_token', jwt_token)
|
||||
user_info = user_data.get('user', {})
|
||||
companies = user_info.get('companies', [])
|
||||
permissions = user_info.get('permissions', [])
|
||||
|
||||
# Token expiration (typically 30 minutes for access token)
|
||||
token_expires_at = datetime.now() + timedelta(minutes=30)
|
||||
|
||||
# Step 4: Link Telegram user to Oracle account
|
||||
linked = await link_user_to_oracle(
|
||||
telegram_user_id=telegram_user_id,
|
||||
oracle_username=oracle_username,
|
||||
jwt_token=jwt_token,
|
||||
jwt_refresh_token=jwt_refresh_token,
|
||||
token_expires_at=token_expires_at
|
||||
)
|
||||
|
||||
if not linked:
|
||||
logger.error(f"Failed to link user {telegram_user_id} to Oracle account")
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"Successfully linked Telegram user {telegram_user_id} "
|
||||
f"to Oracle user {oracle_username}"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"telegram_user_id": telegram_user_id,
|
||||
"username": oracle_username,
|
||||
"jwt_token": jwt_token,
|
||||
"jwt_refresh_token": jwt_refresh_token,
|
||||
"companies": companies,
|
||||
"permissions": permissions,
|
||||
"linked_at": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error linking Telegram account: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
async def get_user_auth_data(telegram_user_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get authentication data for a linked Telegram user.
|
||||
|
||||
This function retrieves the user's Oracle account information and JWT tokens.
|
||||
If the token is expired, it automatically refreshes it.
|
||||
|
||||
Args:
|
||||
telegram_user_id: Telegram user ID
|
||||
|
||||
Returns:
|
||||
Dict with:
|
||||
- telegram_user_id: Telegram user ID
|
||||
- username: Oracle username
|
||||
- jwt_token: Valid JWT access token (refreshed if needed)
|
||||
- jwt_refresh_token: JWT refresh token
|
||||
- companies: List of companies (fetched if not cached)
|
||||
OR None if user is not linked or error occurred
|
||||
|
||||
Example:
|
||||
auth_data = await get_user_auth_data(12345)
|
||||
if auth_data:
|
||||
jwt = auth_data['jwt_token']
|
||||
# Use JWT for API calls
|
||||
"""
|
||||
try:
|
||||
# Get user from database
|
||||
user_data = await get_user(telegram_user_id)
|
||||
|
||||
if not user_data:
|
||||
logger.warning(f"User {telegram_user_id} not found in database")
|
||||
return None
|
||||
|
||||
if not user_data.get('oracle_username'):
|
||||
logger.warning(f"User {telegram_user_id} is not linked to Oracle account")
|
||||
return None
|
||||
|
||||
oracle_username = user_data['oracle_username']
|
||||
jwt_token = user_data['jwt_token']
|
||||
jwt_refresh_token = user_data['jwt_refresh_token']
|
||||
token_expires_at_str = user_data['token_expires_at']
|
||||
|
||||
# Parse token expiration
|
||||
token_expires_at = datetime.fromisoformat(token_expires_at_str) if token_expires_at_str else None
|
||||
|
||||
# Check if token is expired or about to expire (< 5 minutes remaining)
|
||||
token_expired = (
|
||||
token_expires_at is None or
|
||||
datetime.now() >= token_expires_at - timedelta(minutes=5)
|
||||
)
|
||||
|
||||
if token_expired:
|
||||
logger.info(f"Token expired for user {telegram_user_id}, refreshing...")
|
||||
|
||||
# Refresh token via backend API
|
||||
backend_client = get_backend_client()
|
||||
async with backend_client:
|
||||
new_token = await backend_client.refresh_token(jwt_refresh_token)
|
||||
|
||||
if new_token:
|
||||
# Update token in database
|
||||
new_expires_at = datetime.now() + timedelta(minutes=30)
|
||||
await update_user_tokens(
|
||||
telegram_user_id=telegram_user_id,
|
||||
jwt_token=new_token,
|
||||
jwt_refresh_token=jwt_refresh_token, # Keep same refresh token
|
||||
token_expires_at=new_expires_at
|
||||
)
|
||||
|
||||
jwt_token = new_token
|
||||
logger.info(f"Token refreshed for user {telegram_user_id}")
|
||||
else:
|
||||
logger.error(f"Failed to refresh token for user {telegram_user_id}")
|
||||
return None
|
||||
|
||||
# Fetch user companies (fresh from backend)
|
||||
backend_client = get_backend_client()
|
||||
async with backend_client:
|
||||
companies = await backend_client.get_user_companies(jwt_token)
|
||||
|
||||
return {
|
||||
"telegram_user_id": telegram_user_id,
|
||||
"username": oracle_username,
|
||||
"jwt_token": jwt_token,
|
||||
"jwt_refresh_token": jwt_refresh_token,
|
||||
"companies": companies
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user auth data: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
async def check_user_linked(telegram_user_id: int) -> bool:
|
||||
"""
|
||||
Check if a Telegram user is linked to an Oracle account.
|
||||
|
||||
Args:
|
||||
telegram_user_id: Telegram user ID
|
||||
|
||||
Returns:
|
||||
bool: True if user is linked, False otherwise
|
||||
|
||||
Example:
|
||||
if await check_user_linked(12345):
|
||||
print("User is linked")
|
||||
else:
|
||||
print("User needs to link account")
|
||||
"""
|
||||
try:
|
||||
return await is_user_linked(telegram_user_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking if user is linked: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_user_companies(telegram_user_id: int) -> Optional[list]:
|
||||
"""
|
||||
Get list of companies a user has access to.
|
||||
|
||||
This is a convenience function that fetches user auth data and returns
|
||||
just the companies list.
|
||||
|
||||
Args:
|
||||
telegram_user_id: Telegram user ID
|
||||
|
||||
Returns:
|
||||
List of company dicts, or None if user not linked
|
||||
|
||||
Example:
|
||||
companies = await get_user_companies(12345)
|
||||
if companies:
|
||||
for company in companies:
|
||||
print(f"{company['id']}: {company['nume_firma']}")
|
||||
"""
|
||||
try:
|
||||
auth_data = await get_user_auth_data(telegram_user_id)
|
||||
|
||||
if auth_data:
|
||||
return auth_data.get('companies', [])
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user companies: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def unlink_user(telegram_user_id: int) -> bool:
|
||||
"""
|
||||
Unlink a Telegram user from their Oracle account.
|
||||
|
||||
This removes the linking but keeps the Telegram user record.
|
||||
Used for account disconnection or security purposes.
|
||||
|
||||
Args:
|
||||
telegram_user_id: Telegram user ID
|
||||
|
||||
Returns:
|
||||
bool: True if successfully unlinked
|
||||
|
||||
Example:
|
||||
if await unlink_user(12345):
|
||||
print("Account unlinked")
|
||||
"""
|
||||
try:
|
||||
# Set Oracle username and tokens to NULL
|
||||
from backend.modules.telegram.db.database import DB_PATH
|
||||
import aiosqlite
|
||||
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
await db.execute("""
|
||||
UPDATE telegram_users
|
||||
SET oracle_username = NULL,
|
||||
jwt_token = NULL,
|
||||
jwt_refresh_token = NULL,
|
||||
token_expires_at = NULL,
|
||||
linked_at = NULL
|
||||
WHERE telegram_user_id = ?
|
||||
""", (telegram_user_id,))
|
||||
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"User {telegram_user_id} unlinked from Oracle account")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error unlinking user {telegram_user_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# Export main functions
|
||||
__all__ = [
|
||||
'link_telegram_account',
|
||||
'get_user_auth_data',
|
||||
'check_user_linked',
|
||||
'get_user_companies',
|
||||
'unlink_user'
|
||||
]
|
||||
0
backend/modules/telegram/bot/__init__.py
Normal file
0
backend/modules/telegram/bot/__init__.py
Normal file
768
backend/modules/telegram/bot/email_handlers.py
Normal file
768
backend/modules/telegram/bot/email_handlers.py
Normal file
@@ -0,0 +1,768 @@
|
||||
"""
|
||||
Telegram bot handlers for email-based authentication flow
|
||||
"""
|
||||
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
|
||||
from telegram.ext import (
|
||||
ContextTypes,
|
||||
ConversationHandler,
|
||||
CommandHandler,
|
||||
MessageHandler,
|
||||
CallbackQueryHandler,
|
||||
filters
|
||||
)
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import asyncio
|
||||
|
||||
from backend.modules.telegram.auth.email_auth import (
|
||||
is_valid_email_format,
|
||||
verify_email_in_oracle,
|
||||
generate_email_code,
|
||||
generate_session_token,
|
||||
check_rate_limit,
|
||||
clear_rate_limit
|
||||
)
|
||||
from backend.modules.telegram.utils.email_service import get_email_service
|
||||
from backend.modules.telegram.db.operations import (
|
||||
create_email_auth_code,
|
||||
get_email_auth_code,
|
||||
get_pending_email_code,
|
||||
mark_email_code_used,
|
||||
increment_failed_attempts,
|
||||
delete_user_email_codes,
|
||||
is_user_authenticated,
|
||||
link_user_to_oracle,
|
||||
create_or_update_user
|
||||
)
|
||||
from backend.modules.telegram.api.client import get_backend_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Conversation states
|
||||
AWAITING_EMAIL, AWAITING_CODE, AWAITING_PASSWORD = range(3)
|
||||
|
||||
# Constants
|
||||
MAX_CODE_ATTEMPTS = 3
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# HELPER FUNCTIONS
|
||||
# ============================================================================
|
||||
|
||||
async def edit_login_message(
|
||||
context: ContextTypes.DEFAULT_TYPE,
|
||||
chat_id: int,
|
||||
text: str,
|
||||
reply_markup=None,
|
||||
parse_mode="Markdown"
|
||||
):
|
||||
"""
|
||||
Helper function to edit the login message stored in context.
|
||||
If message_id is not stored, creates a new message instead.
|
||||
"""
|
||||
message_id = context.user_data.get('login_message_id')
|
||||
|
||||
if message_id:
|
||||
try:
|
||||
await context.bot.edit_message_text(
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
text=text,
|
||||
reply_markup=reply_markup,
|
||||
parse_mode=parse_mode
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not edit message {message_id}: {e}")
|
||||
# Fallback: send new message and update ID
|
||||
msg = await context.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=text,
|
||||
reply_markup=reply_markup,
|
||||
parse_mode=parse_mode
|
||||
)
|
||||
context.user_data['login_message_id'] = msg.message_id
|
||||
else:
|
||||
# No message ID stored - create new message
|
||||
msg = await context.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=text,
|
||||
reply_markup=reply_markup,
|
||||
parse_mode=parse_mode
|
||||
)
|
||||
context.user_data['login_message_id'] = msg.message_id
|
||||
|
||||
|
||||
async def delete_login_message(context: ContextTypes.DEFAULT_TYPE, chat_id: int):
|
||||
"""Delete the login message and clear the message_id from context"""
|
||||
message_id = context.user_data.get('login_message_id')
|
||||
|
||||
if message_id:
|
||||
try:
|
||||
await context.bot.delete_message(chat_id=chat_id, message_id=message_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not delete message {message_id}: {e}")
|
||||
|
||||
# Clear from context
|
||||
context.user_data.pop('login_message_id', None)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ENTRY POINTS: /login command and action:login button
|
||||
# ============================================================================
|
||||
|
||||
async def login_command(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
"""
|
||||
Handler pentru /login command
|
||||
Oferă opțiuni de autentificare: Email sau Web App
|
||||
"""
|
||||
user = update.effective_user
|
||||
|
||||
# Check dacă e deja autentificat
|
||||
if await is_user_authenticated(user.id):
|
||||
await update.message.reply_text(
|
||||
"Ești deja autentificat.\n\n"
|
||||
"Folosește:\n"
|
||||
"• /companies - Vezi companiile tale\n"
|
||||
"• /help - Comenzi disponibile\n"
|
||||
"• /unlink - Deautentifică-te"
|
||||
)
|
||||
return ConversationHandler.END
|
||||
|
||||
# Check rate limiting (3 requests per hour)
|
||||
if not await check_rate_limit(f"login_{user.id}", max_attempts=3, window_minutes=60):
|
||||
await update.message.reply_text(
|
||||
"Prea multe încercări de autentificare.\n\n"
|
||||
"Te rugăm să aștepți 60 de minute înainte de a încerca din nou."
|
||||
)
|
||||
return ConversationHandler.END
|
||||
|
||||
# Afișează opțiuni de autentificare
|
||||
keyboard = [
|
||||
[InlineKeyboardButton("Login cu Email + Parolă", callback_data="email_login")],
|
||||
[InlineKeyboardButton("Login din Web App", callback_data="web_login_info")],
|
||||
[InlineKeyboardButton("Anulează", callback_data="cancel")]
|
||||
]
|
||||
reply_markup = InlineKeyboardMarkup(keyboard)
|
||||
|
||||
# CREATE message and SAVE message_id
|
||||
msg = await update.message.reply_text(
|
||||
"Alege metoda de autentificare:",
|
||||
reply_markup=reply_markup,
|
||||
parse_mode="Markdown"
|
||||
)
|
||||
|
||||
# Save message ID for future edits
|
||||
context.user_data['login_message_id'] = msg.message_id
|
||||
|
||||
return AWAITING_EMAIL
|
||||
|
||||
|
||||
async def action_login_callback(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
"""
|
||||
Handler pentru butonul Login din meniu (action:login)
|
||||
Oferă opțiuni de autentificare: Email sau Web App
|
||||
"""
|
||||
query = update.callback_query
|
||||
user = update.effective_user
|
||||
|
||||
logger.info(f"[EMAIL_AUTH] action_login_callback triggered for user {user.id}")
|
||||
|
||||
await query.answer()
|
||||
|
||||
# Check dacă e deja autentificat
|
||||
if await is_user_authenticated(user.id):
|
||||
await query.edit_message_text(
|
||||
"Ești deja autentificat.\n\n"
|
||||
"Folosește:\n"
|
||||
"• /companies - Vezi companiile tale\n"
|
||||
"• /help - Comenzi disponibile\n"
|
||||
"• /unlink - Deautentifică-te"
|
||||
)
|
||||
return ConversationHandler.END
|
||||
|
||||
# Check rate limiting (3 requests per hour)
|
||||
if not await check_rate_limit(f"login_{user.id}", max_attempts=3, window_minutes=60):
|
||||
await query.edit_message_text(
|
||||
"Prea multe încercări de autentificare.\n\n"
|
||||
"Te rugăm să aștepți 60 de minute înainte de a încerca din nou."
|
||||
)
|
||||
return ConversationHandler.END
|
||||
|
||||
# Afișează opțiuni de autentificare
|
||||
keyboard = [
|
||||
[InlineKeyboardButton("Login cu Email + Parolă", callback_data="email_login")],
|
||||
[InlineKeyboardButton("Login din Web App", callback_data="web_login_info")],
|
||||
[InlineKeyboardButton("Anulează", callback_data="cancel")]
|
||||
]
|
||||
reply_markup = InlineKeyboardMarkup(keyboard)
|
||||
|
||||
# EDIT existing menu message and SAVE message_id
|
||||
await query.edit_message_text(
|
||||
"Alege metoda de autentificare:",
|
||||
reply_markup=reply_markup,
|
||||
parse_mode="Markdown"
|
||||
)
|
||||
|
||||
# Save message ID for future edits
|
||||
context.user_data['login_message_id'] = query.message.message_id
|
||||
|
||||
return AWAITING_EMAIL
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CALLBACK: Email Login
|
||||
# ============================================================================
|
||||
|
||||
async def email_login_callback(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
"""Callback pentru butonul 'Login cu Email'"""
|
||||
query = update.callback_query
|
||||
user = update.effective_user
|
||||
|
||||
logger.info(f"[EMAIL_AUTH] email_login_callback triggered for user {user.id}")
|
||||
|
||||
await query.answer()
|
||||
|
||||
# IMPORTANT: Salvează message_id înainte de a edita
|
||||
context.user_data['login_message_id'] = query.message.message_id
|
||||
|
||||
# EDIT same message - remove buttons, ask for email
|
||||
await query.edit_message_text(
|
||||
text="Introdu adresa de email ROA:",
|
||||
parse_mode="Markdown"
|
||||
)
|
||||
|
||||
return AWAITING_EMAIL
|
||||
|
||||
|
||||
async def web_login_info_callback(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
"""Info despre web app login"""
|
||||
query = update.callback_query
|
||||
await query.answer()
|
||||
|
||||
await query.edit_message_text(
|
||||
"**Login din Web App**\n\n"
|
||||
"Pentru această metodă:\n\n"
|
||||
"1. Accesează aplicația web ROA2WEB\n"
|
||||
"2. Autentifică-te cu username + parolă\n"
|
||||
"3. Apasă butonul \"Link Telegram\"\n"
|
||||
"4. Copiază codul generat (8 caractere)\n"
|
||||
"5. Trimite-mi codul: /start ABC123XY\n\n"
|
||||
"Vei fi autentificat automat.",
|
||||
parse_mode="Markdown"
|
||||
)
|
||||
|
||||
# IMPORTANT: Salvează message_id pentru ca /start să poată edita același mesaj
|
||||
context.user_data['web_login_message_id'] = query.message.message_id
|
||||
|
||||
return ConversationHandler.END
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# STATE: AWAITING_EMAIL
|
||||
# ============================================================================
|
||||
|
||||
async def receive_email(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
"""Handler pentru primirea email-ului"""
|
||||
email = update.message.text.strip().lower()
|
||||
user_id = update.effective_user.id
|
||||
|
||||
# ȘTERG mesajul utilizatorului imediat (chat curat)
|
||||
try:
|
||||
await update.message.delete()
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not delete email message: {e}")
|
||||
|
||||
# Validare format email
|
||||
if not is_valid_email_format(email):
|
||||
# Show error in main message
|
||||
await edit_login_message(
|
||||
context=context,
|
||||
chat_id=update.effective_chat.id,
|
||||
text="Email invalid\n\nIntrodu o adresă validă (nume@domeniu.ro)",
|
||||
reply_markup=InlineKeyboardMarkup([
|
||||
[InlineKeyboardButton("Anulează", callback_data="cancel")]
|
||||
])
|
||||
)
|
||||
return AWAITING_EMAIL
|
||||
|
||||
# Check for existing pending code
|
||||
existing_code = await get_pending_email_code(user_id)
|
||||
if existing_code:
|
||||
# Delete old pending code
|
||||
await delete_user_email_codes(user_id)
|
||||
logger.info(f"Deleted existing pending code for user {user_id}")
|
||||
|
||||
# EDIT login message to show loading
|
||||
await edit_login_message(
|
||||
context=context,
|
||||
chat_id=update.effective_chat.id,
|
||||
text="Verificare email...",
|
||||
reply_markup=None
|
||||
)
|
||||
|
||||
try:
|
||||
# Verifică email în Oracle
|
||||
username = await verify_email_in_oracle(email)
|
||||
|
||||
# IMPORTANT: Generic response to prevent email enumeration
|
||||
# We always say "code sent" even if email doesn't exist
|
||||
|
||||
if username:
|
||||
# Email exists - generate and send code
|
||||
code = generate_email_code()
|
||||
|
||||
# Save code in database
|
||||
code_saved = await create_email_auth_code(
|
||||
code=code,
|
||||
email=email,
|
||||
username=username,
|
||||
telegram_user_id=user_id,
|
||||
expiry_minutes=5
|
||||
)
|
||||
|
||||
if not code_saved:
|
||||
await edit_login_message(
|
||||
context=context,
|
||||
chat_id=update.effective_chat.id,
|
||||
text="Eroare la salvarea codului.\n\nIncearcă din nou cu /login"
|
||||
)
|
||||
return ConversationHandler.END
|
||||
|
||||
# Send email (async with retry)
|
||||
email_service = get_email_service()
|
||||
email_sent = await email_service.send_auth_code(email, code, username)
|
||||
|
||||
if not email_sent:
|
||||
logger.error(f"Failed to send email to {email}")
|
||||
# Don't reveal this to user - they'll timeout naturally
|
||||
|
||||
# Wait 1 second for better UX (looks like verification happened)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# ALWAYS show this message (prevent enumeration)
|
||||
# EDIT same message with success + buttons
|
||||
await edit_login_message(
|
||||
context=context,
|
||||
chat_id=update.effective_chat.id,
|
||||
text=f"Cod trimis pe {email}\n\nIntrodu codul primit pe email:",
|
||||
reply_markup=InlineKeyboardMarkup([
|
||||
[InlineKeyboardButton("Retrimite Cod", callback_data=f"resend:{email}")],
|
||||
[InlineKeyboardButton("Anulează", callback_data="cancel")]
|
||||
])
|
||||
)
|
||||
|
||||
# Save email in context for resend functionality
|
||||
context.user_data['pending_email'] = email
|
||||
context.user_data['pending_username'] = username
|
||||
|
||||
return AWAITING_CODE
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in receive_email: {e}", exc_info=True)
|
||||
await edit_login_message(
|
||||
context=context,
|
||||
chat_id=update.effective_chat.id,
|
||||
text="Eroare internă.\n\nIncearcă din nou mai târziu."
|
||||
)
|
||||
return ConversationHandler.END
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# STATE: AWAITING_CODE
|
||||
# ============================================================================
|
||||
|
||||
async def receive_code(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
"""Handler pentru primirea codului din email"""
|
||||
code = update.message.text.strip()
|
||||
user_id = update.effective_user.id
|
||||
|
||||
# ȘTERG mesajul utilizatorului imediat (chat curat)
|
||||
try:
|
||||
await update.message.delete()
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not delete code message: {e}")
|
||||
|
||||
# Validare format cod (6 digits)
|
||||
if not (code.isdigit() and len(code) == 6):
|
||||
await edit_login_message(
|
||||
context=context,
|
||||
chat_id=update.effective_chat.id,
|
||||
text="Cod invalid\n\nIntrodu cele 6 cifre din email.",
|
||||
reply_markup=InlineKeyboardMarkup([
|
||||
[InlineKeyboardButton("Retrimite Cod", callback_data=f"resend:{context.user_data.get('pending_email', '')}")],
|
||||
[InlineKeyboardButton("Anulează", callback_data="cancel")]
|
||||
])
|
||||
)
|
||||
return AWAITING_CODE
|
||||
|
||||
# Verifică cod în DB
|
||||
try:
|
||||
code_data = await get_email_auth_code(code)
|
||||
|
||||
if not code_data:
|
||||
# EDIT login message to show error
|
||||
await edit_login_message(
|
||||
context=context,
|
||||
chat_id=update.effective_chat.id,
|
||||
text="Cod invalid sau expirat\n\nIncearcă din nou cu /login"
|
||||
)
|
||||
return ConversationHandler.END
|
||||
|
||||
# Verificări de securitate
|
||||
|
||||
# 1. Check if already used
|
||||
if code_data['used']:
|
||||
await edit_login_message(
|
||||
context=context,
|
||||
chat_id=update.effective_chat.id,
|
||||
text="Cod deja folosit\n\nIncearcă din nou cu /login"
|
||||
)
|
||||
return ConversationHandler.END
|
||||
|
||||
# 2. Check if expired
|
||||
if datetime.now() > code_data['expires_at']:
|
||||
await edit_login_message(
|
||||
context=context,
|
||||
chat_id=update.effective_chat.id,
|
||||
text="Cod expirat\n\nIncearcă din nou cu /login"
|
||||
)
|
||||
return ConversationHandler.END
|
||||
|
||||
# 3. Check if belongs to this user
|
||||
if code_data['telegram_user_id'] != user_id:
|
||||
logger.warning(
|
||||
f"User {user_id} tried to use code belonging to "
|
||||
f"user {code_data['telegram_user_id']}"
|
||||
)
|
||||
await edit_login_message(
|
||||
context=context,
|
||||
chat_id=update.effective_chat.id,
|
||||
text="Cod invalid"
|
||||
)
|
||||
return ConversationHandler.END
|
||||
|
||||
# 4. Check failed attempts (max 3)
|
||||
if code_data['failed_attempts'] >= MAX_CODE_ATTEMPTS:
|
||||
await edit_login_message(
|
||||
context=context,
|
||||
chat_id=update.effective_chat.id,
|
||||
text="Prea multe încercări greșite\n\nIncearcă din nou cu /login"
|
||||
)
|
||||
return ConversationHandler.END
|
||||
|
||||
# Cod valid - Marchează ca folosit
|
||||
await mark_email_code_used(code)
|
||||
|
||||
# Salvează date verificate în context
|
||||
context.user_data['verified_username'] = code_data['oracle_username']
|
||||
context.user_data['verified_email'] = code_data['email']
|
||||
context.user_data['session_token'] = generate_session_token(
|
||||
user_id,
|
||||
code_data['email']
|
||||
)
|
||||
|
||||
# EDIT same message - ask for password
|
||||
await edit_login_message(
|
||||
context=context,
|
||||
chat_id=update.effective_chat.id,
|
||||
text="Cod validat!\n\nIntroduci parola ROA:",
|
||||
reply_markup=None # No buttons for security
|
||||
)
|
||||
|
||||
return AWAITING_PASSWORD
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating code: {e}", exc_info=True)
|
||||
await edit_login_message(
|
||||
context=context,
|
||||
chat_id=update.effective_chat.id,
|
||||
text="Eroare la validarea codului.\n\nIncearcă din nou."
|
||||
)
|
||||
return ConversationHandler.END
|
||||
|
||||
|
||||
async def resend_code_callback(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
"""Retrimite codul pe email"""
|
||||
query = update.callback_query
|
||||
await query.answer("Retrimitem codul...")
|
||||
|
||||
# Extract email from callback data
|
||||
callback_data = query.data # Format: "resend:email@example.com"
|
||||
if not callback_data.startswith("resend:"):
|
||||
await edit_login_message(
|
||||
context=context,
|
||||
chat_id=update.effective_chat.id,
|
||||
text="Eroare\n\nIncearcă din nou cu /login"
|
||||
)
|
||||
return ConversationHandler.END
|
||||
|
||||
email = callback_data.split(":", 1)[1]
|
||||
user_id = update.effective_user.id
|
||||
|
||||
# Check rate limiting for resend (max 2 per 10 minutes)
|
||||
if not await check_rate_limit(f"resend_{user_id}", max_attempts=2, window_minutes=10):
|
||||
await edit_login_message(
|
||||
context=context,
|
||||
chat_id=update.effective_chat.id,
|
||||
text="Prea multe solicitări\n\nAșteaptă 10 minute."
|
||||
)
|
||||
return ConversationHandler.END
|
||||
|
||||
# Get username from context or re-verify
|
||||
username = context.user_data.get('pending_username')
|
||||
|
||||
if not username:
|
||||
username = await verify_email_in_oracle(email)
|
||||
if not username:
|
||||
await edit_login_message(
|
||||
context=context,
|
||||
chat_id=update.effective_chat.id,
|
||||
text="Eroare\n\nIncearcă din nou cu /login"
|
||||
)
|
||||
return ConversationHandler.END
|
||||
|
||||
# Delete old code and generate new one
|
||||
await delete_user_email_codes(user_id)
|
||||
|
||||
code = generate_email_code()
|
||||
|
||||
# Save new code
|
||||
await create_email_auth_code(
|
||||
code=code,
|
||||
email=email,
|
||||
username=username,
|
||||
telegram_user_id=user_id,
|
||||
expiry_minutes=5
|
||||
)
|
||||
|
||||
# Send email
|
||||
email_service = get_email_service()
|
||||
await email_service.send_auth_code(email, code, username)
|
||||
|
||||
# FIX BUG: EDIT message and KEEP buttons!
|
||||
await edit_login_message(
|
||||
context=context,
|
||||
chat_id=update.effective_chat.id,
|
||||
text=f"Cod retrimis pe {email}\n\nIntrodu codul primit pe email:",
|
||||
reply_markup=InlineKeyboardMarkup([
|
||||
[InlineKeyboardButton("Retrimite Cod", callback_data=f"resend:{email}")],
|
||||
[InlineKeyboardButton("Anulează", callback_data="cancel")]
|
||||
])
|
||||
)
|
||||
|
||||
return AWAITING_CODE
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# STATE: AWAITING_PASSWORD
|
||||
# ============================================================================
|
||||
|
||||
async def receive_password(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
"""Handler pentru primirea parolei Oracle"""
|
||||
password = update.message.text.strip()
|
||||
user_id = update.effective_user.id
|
||||
|
||||
# Șterge IMEDIAT mesajul cu parola (securitate)
|
||||
try:
|
||||
await update.message.delete()
|
||||
logger.info(f"Password message deleted for user {user_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not delete password message: {e}")
|
||||
|
||||
# Get verified data from context
|
||||
username = context.user_data.get('verified_username')
|
||||
email = context.user_data.get('verified_email')
|
||||
session_token = context.user_data.get('session_token')
|
||||
|
||||
if not all([username, email, session_token]):
|
||||
await edit_login_message(
|
||||
context=context,
|
||||
chat_id=update.effective_chat.id,
|
||||
text="Sesiune expirată\n\nIncearcă din nou cu /login"
|
||||
)
|
||||
return ConversationHandler.END
|
||||
|
||||
# EDIT login message to show loading
|
||||
await edit_login_message(
|
||||
context=context,
|
||||
chat_id=update.effective_chat.id,
|
||||
text="Verificare...",
|
||||
reply_markup=None
|
||||
)
|
||||
|
||||
try:
|
||||
# Call backend endpoint pentru verificare parolă + JWT
|
||||
backend_client = get_backend_client()
|
||||
|
||||
response = await backend_client.login_with_email(
|
||||
email=email,
|
||||
password=password,
|
||||
telegram_user_id=user_id,
|
||||
session_token=session_token
|
||||
)
|
||||
|
||||
if not response.get('success'):
|
||||
await edit_login_message(
|
||||
context=context,
|
||||
chat_id=update.effective_chat.id,
|
||||
text="Credențiale invalide\n\nParolă incorectă sau cont inactiv.\n\nIncearcă din nou cu /login"
|
||||
)
|
||||
return ConversationHandler.END
|
||||
|
||||
# Success - Salvează user în telegram_users
|
||||
# First create or update user record
|
||||
await create_or_update_user(
|
||||
telegram_user_id=user_id,
|
||||
username=update.effective_user.username,
|
||||
first_name=update.effective_user.first_name,
|
||||
last_name=update.effective_user.last_name
|
||||
)
|
||||
|
||||
# Then link to Oracle
|
||||
from datetime import datetime, timedelta
|
||||
token_expires_at = datetime.now() + timedelta(minutes=30) # Default expiry
|
||||
|
||||
await link_user_to_oracle(
|
||||
telegram_user_id=user_id,
|
||||
oracle_username=response['username'],
|
||||
jwt_token=response['access_token'],
|
||||
jwt_refresh_token=response['refresh_token'],
|
||||
token_expires_at=token_expires_at
|
||||
)
|
||||
|
||||
# Clear rate limits on successful auth
|
||||
clear_rate_limit(f"login_{user_id}")
|
||||
clear_rate_limit(f"resend_{user_id}")
|
||||
|
||||
# Get session and active company BEFORE editing message
|
||||
from backend.modules.telegram.agent.session import get_session_manager
|
||||
from backend.modules.telegram.bot.menus import create_main_menu, pad_message_for_wide_buttons
|
||||
|
||||
session_manager = get_session_manager()
|
||||
session = await session_manager.get_or_create_session(user_id)
|
||||
company = session.get_active_company()
|
||||
|
||||
company_name = company['name'] if company else None
|
||||
company_cui = company.get('cui') if company else None
|
||||
|
||||
# Create main menu keyboard
|
||||
keyboard = create_main_menu(
|
||||
company_name=company_name,
|
||||
company_cui=company_cui,
|
||||
is_authenticated=True, # Now authenticated
|
||||
cache_enabled=True # Default enabled
|
||||
)
|
||||
|
||||
# Menu message with company info
|
||||
companies_count = len(response.get('companies', []))
|
||||
|
||||
if company_name:
|
||||
menu_text = f"{company_name}"
|
||||
else:
|
||||
menu_text = f"Companii disponibile: {companies_count}\n\nSelectează o companie pentru a continua"
|
||||
|
||||
menu_message = pad_message_for_wide_buttons(menu_text)
|
||||
|
||||
# EDIT login message to show menu (no deletion, direct edit)
|
||||
await edit_login_message(
|
||||
context=context,
|
||||
chat_id=update.effective_chat.id,
|
||||
text=menu_message,
|
||||
reply_markup=keyboard
|
||||
)
|
||||
|
||||
# Clear sensitive data from context
|
||||
context.user_data.clear()
|
||||
|
||||
logger.info(f"User {user_id} authenticated successfully via email")
|
||||
|
||||
return ConversationHandler.END
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during password verification: {e}", exc_info=True)
|
||||
await edit_login_message(
|
||||
context=context,
|
||||
chat_id=update.effective_chat.id,
|
||||
text="Eroare la autentificare.\n\nIncearcă din nou cu /login"
|
||||
)
|
||||
return ConversationHandler.END
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CANCEL HANDLER
|
||||
# ============================================================================
|
||||
|
||||
async def cancel_login(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
"""Cancel conversation"""
|
||||
|
||||
# EDIT login message to show cancellation (don't delete)
|
||||
if update.callback_query:
|
||||
# Called from button
|
||||
await edit_login_message(
|
||||
context=context,
|
||||
chat_id=update.effective_chat.id,
|
||||
text="Login anulat",
|
||||
reply_markup=None
|
||||
)
|
||||
elif update.message:
|
||||
# Called from /cancel command
|
||||
await edit_login_message(
|
||||
context=context,
|
||||
chat_id=update.effective_chat.id,
|
||||
text="Login anulat",
|
||||
reply_markup=None
|
||||
)
|
||||
|
||||
# Clear context
|
||||
context.user_data.clear()
|
||||
|
||||
return ConversationHandler.END
|
||||
|
||||
|
||||
async def conversation_timeout(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
"""Handler for conversation timeout"""
|
||||
|
||||
# EDIT login message to show timeout
|
||||
await edit_login_message(
|
||||
context=context,
|
||||
chat_id=update.effective_chat.id,
|
||||
text="Sesiune expirată\n\nConversația a expirat după 5 minute.\n\nIncearcă din nou cu /login"
|
||||
)
|
||||
|
||||
# Clear context
|
||||
context.user_data.clear()
|
||||
|
||||
return ConversationHandler.END
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CONVERSATION HANDLER SETUP
|
||||
# ============================================================================
|
||||
|
||||
email_login_handler = ConversationHandler(
|
||||
entry_points=[
|
||||
CommandHandler('login', login_command),
|
||||
CallbackQueryHandler(action_login_callback, pattern='^action:login$'),
|
||||
CallbackQueryHandler(email_login_callback, pattern='^email_login$')
|
||||
],
|
||||
states={
|
||||
AWAITING_EMAIL: [
|
||||
MessageHandler(filters.TEXT & ~filters.COMMAND, receive_email)
|
||||
],
|
||||
AWAITING_CODE: [
|
||||
MessageHandler(filters.TEXT & ~filters.COMMAND, receive_code),
|
||||
CallbackQueryHandler(resend_code_callback, pattern='^resend:')
|
||||
],
|
||||
AWAITING_PASSWORD: [
|
||||
MessageHandler(filters.TEXT & ~filters.COMMAND, receive_password)
|
||||
],
|
||||
},
|
||||
fallbacks=[
|
||||
CommandHandler('cancel', cancel_login),
|
||||
CallbackQueryHandler(cancel_login, pattern='^cancel$'),
|
||||
CallbackQueryHandler(web_login_info_callback, pattern='^web_login_info$')
|
||||
],
|
||||
per_message=False, # Track conversation per user, not per message
|
||||
allow_reentry=True, # Allow starting new conversation even if previous one is active
|
||||
name="email_login_conversation"
|
||||
)
|
||||
629
backend/modules/telegram/bot/formatters.py
Normal file
629
backend/modules/telegram/bot/formatters.py
Normal file
@@ -0,0 +1,629 @@
|
||||
"""
|
||||
Response formatters for bot commands.
|
||||
Formats API responses into user-friendly Telegram messages.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Any
|
||||
|
||||
|
||||
def format_dashboard_response(data: Dict[str, Any], company_name: str = None) -> str:
|
||||
"""
|
||||
Format dashboard data for Telegram (content only, no header).
|
||||
|
||||
Note: company_name parameter kept for backwards compatibility but not used.
|
||||
Use format_response_with_company() in handlers to add company header.
|
||||
"""
|
||||
text = ""
|
||||
|
||||
# Sold total trezorerie (casa + banca) - rotunjit la leu
|
||||
treasury_totals = data.get('treasury_totals_by_currency', {})
|
||||
sold_trezorerie = round(float(treasury_totals.get('RON', 0)))
|
||||
text += f"**Sold Trezorerie:** {sold_trezorerie:,} RON\n\n"
|
||||
|
||||
# Sold Clienți - rotunjit la leu
|
||||
clienti_sold = round(float(data.get('clienti_sold_total', 0)))
|
||||
clienti_in_termen = round(float(data.get('clienti_sold_in_termen', 0)))
|
||||
clienti_restant = round(float(data.get('clienti_sold_restant', 0)))
|
||||
|
||||
text += f"**Sold Clienți:** {clienti_sold:,} RON\n"
|
||||
text += f" - În termen: {clienti_in_termen:,} RON\n"
|
||||
text += f" - Restanță: {clienti_restant:,} RON\n\n"
|
||||
|
||||
# Sold Furnizori BRUT (pentru consistență cu detaliile) - rotunjit la leu
|
||||
furnizori_in_termen = round(float(data.get('furnizori_sold_in_termen', 0)))
|
||||
furnizori_restant = round(float(data.get('furnizori_sold_restant', 0)))
|
||||
furnizori_sold_brut = furnizori_in_termen + furnizori_restant
|
||||
furnizori_avansuri = round(float(data.get('furnizori_avansuri', 0)))
|
||||
furnizori_sold_net = round(float(data.get('furnizori_sold_total', 0)))
|
||||
|
||||
text += f"**Sold Furnizori:** {furnizori_sold_brut:,} RON\n"
|
||||
text += f" - În termen: {furnizori_in_termen:,} RON\n"
|
||||
text += f" - Restanță: {furnizori_restant:,} RON\n"
|
||||
if furnizori_avansuri != 0:
|
||||
text += f" - Avansuri: {furnizori_avansuri:,} RON\n"
|
||||
text += f" - Net (după avansuri): {furnizori_sold_net:,} RON"
|
||||
else:
|
||||
text += f" - Net: {furnizori_sold_net:,} RON"
|
||||
|
||||
# Solduri TVA - rotunjit la leu
|
||||
tva_plata_prec = round(float(data.get('tva_plata_precedent', 0)))
|
||||
tva_recup_prec = round(float(data.get('tva_recuperat_precedent', 0)))
|
||||
tva_plata_cur = round(float(data.get('tva_plata_curent', 0)))
|
||||
tva_recup_cur = round(float(data.get('tva_recuperat_curent', 0)))
|
||||
|
||||
# Afișează secțiunea doar dacă există cel puțin o valoare > 0
|
||||
if tva_plata_prec > 0 or tva_recup_prec > 0 or tva_plata_cur > 0 or tva_recup_cur > 0:
|
||||
text += "\n\n**Solduri TVA:**\n"
|
||||
if tva_plata_prec > 0:
|
||||
text += f" - TVA de plată precedent: {tva_plata_prec:,} RON\n"
|
||||
if tva_recup_prec > 0:
|
||||
text += f" - TVA de recuperat precedent: {tva_recup_prec:,} RON\n"
|
||||
if tva_plata_cur > 0:
|
||||
text += f" - TVA de plată curent: {tva_plata_cur:,} RON\n"
|
||||
if tva_recup_cur > 0:
|
||||
text += f" - TVA de recuperat curent: {tva_recup_cur:,} RON\n"
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def format_invoices_response(
|
||||
invoices: List[Dict[str, Any]],
|
||||
company_name: str = None,
|
||||
limit: int = 10
|
||||
) -> str:
|
||||
"""
|
||||
Format invoices list for Telegram - COMPACT TABLE FORMAT.
|
||||
|
||||
Args:
|
||||
invoices: List of invoice dicts
|
||||
company_name: Company name (kept for compatibility, not used)
|
||||
limit: Maximum number of invoices to display
|
||||
|
||||
Returns:
|
||||
Formatted Markdown string for Telegram (compact, no emojis)
|
||||
"""
|
||||
if not invoices:
|
||||
return "Nu s-au gasit facturi cu aceste criterii."
|
||||
|
||||
# Header (o singură dată)
|
||||
text = f"**Facturi** ({len(invoices)} total)\n\n"
|
||||
text += "Nr | Client | Suma | Status\n"
|
||||
text += "---|--------|------|-------\n"
|
||||
|
||||
# Lista facturi - compact, o linie per factură
|
||||
for idx, inv in enumerate(invoices[:limit], 1):
|
||||
seria = inv.get('seria', '')
|
||||
numar = inv.get('numar', '')
|
||||
client = inv.get('client', 'N/A')
|
||||
suma = inv.get('suma_totala', 0)
|
||||
status = inv.get('status', 'N/A')
|
||||
|
||||
# Truncate long client names for compact display
|
||||
client_short = client[:20] + "..." if len(client) > 20 else client
|
||||
|
||||
# Status marker (no emoji)
|
||||
status_marker = "PLATIT" if status == "platit" else "NEPLATIT"
|
||||
|
||||
text += f"{seria}{numar} | {client_short} | {suma:,.0f} | {status_marker}\n"
|
||||
|
||||
if len(invoices) > limit:
|
||||
text += f"\n+{len(invoices) - limit} facturi"
|
||||
|
||||
return text
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# FAZA 2: New Formatter Functions for Button Interface
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def format_treasury_casa_response(data: Dict[str, Any], company_name: str = None) -> str:
|
||||
"""
|
||||
Format treasury CASH data for Telegram (content only, no header).
|
||||
|
||||
Args:
|
||||
data: Dict with casa accounts and total from treasury breakdown
|
||||
company_name: Company name (kept for compatibility, not used)
|
||||
|
||||
Returns:
|
||||
Formatted Markdown string for Telegram
|
||||
|
||||
Example:
|
||||
data = {'accounts': [...], 'total': 5000}
|
||||
text = format_treasury_casa_response(data)
|
||||
"""
|
||||
text = ""
|
||||
|
||||
# Total cash balance - rotunjit la leu (0 zecimale)
|
||||
total_cash = round(data.get('total', 0))
|
||||
text += f"**Sold Total Cash:** {total_cash:,} RON\n\n"
|
||||
|
||||
# Cash accounts
|
||||
casa_accounts = data.get('accounts', [])
|
||||
if casa_accounts:
|
||||
text += "**Conturi de Casa:**\n"
|
||||
for acc in casa_accounts: # Show all accounts
|
||||
name = acc.get('name', 'N/A')
|
||||
balance = round(acc.get('balance', 0))
|
||||
text += f" - {name}: {balance:,} RON\n"
|
||||
else:
|
||||
text += "Nu exista conturi de casa configurate."
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def format_treasury_banca_response(data: Dict[str, Any], company_name: str = None) -> str:
|
||||
"""
|
||||
Format treasury BANK data for Telegram (content only, no header).
|
||||
|
||||
Args:
|
||||
data: Dict with banca accounts and total from treasury breakdown
|
||||
company_name: Company name (kept for compatibility, not used)
|
||||
|
||||
Returns:
|
||||
Formatted Markdown string for Telegram
|
||||
|
||||
Example:
|
||||
data = {'accounts': [...], 'total': 15000}
|
||||
text = format_treasury_banca_response(data)
|
||||
"""
|
||||
text = ""
|
||||
|
||||
# Total bank balance - rotunjit la leu (0 zecimale)
|
||||
total_bank = round(data.get('total', 0))
|
||||
text += f"**Sold Total Banca:** {total_bank:,} RON\n\n"
|
||||
|
||||
# Bank accounts
|
||||
bank_accounts = data.get('accounts', [])
|
||||
if bank_accounts:
|
||||
text += "**Conturi Bancare:**\n"
|
||||
for acc in bank_accounts: # Show all accounts
|
||||
name = acc.get('name', 'N/A')
|
||||
balance = round(acc.get('balance', 0))
|
||||
text += f" - {name}: {balance:,} RON\n"
|
||||
else:
|
||||
text += "Nu exista conturi bancare configurate."
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def format_clients_balance_response(
|
||||
clients: List[Dict[str, Any]],
|
||||
maturity_data: Dict[str, Any],
|
||||
company_name: str = None
|
||||
) -> str:
|
||||
"""
|
||||
Format clients balance with maturity breakdown (content only, no header).
|
||||
|
||||
Args:
|
||||
clients: List of client dicts with id, name, balance
|
||||
maturity_data: Dict with in_term, overdue, total
|
||||
company_name: Company name (kept for compatibility, not used)
|
||||
|
||||
Returns:
|
||||
Formatted Markdown string for Telegram
|
||||
|
||||
Example:
|
||||
clients = [{'id': 1, 'name': 'Client A', 'balance': 15000}]
|
||||
maturity = {'in_term': 10000, 'overdue': 5000, 'total': 15000}
|
||||
text = format_clients_balance_response(clients, maturity)
|
||||
"""
|
||||
text = ""
|
||||
|
||||
# Maturity breakdown - rotunjit la leu (0 zecimale)
|
||||
total = round(maturity_data.get('total', 0))
|
||||
in_term = round(maturity_data.get('in_term', 0))
|
||||
overdue = round(maturity_data.get('overdue', 0))
|
||||
|
||||
text += f"**Sold Total:** {total:,} RON\n\n"
|
||||
|
||||
text += "**Defalcare:**\n"
|
||||
text += f" - In termen: {in_term:,} RON\n"
|
||||
text += f" - Restanta: {overdue:,} RON\n\n"
|
||||
|
||||
# Top 10 clients
|
||||
if clients:
|
||||
text += f"**Top 10 Clienti** ({len(clients)} total):\n"
|
||||
# Sort by balance descending
|
||||
sorted_clients = sorted(
|
||||
clients,
|
||||
key=lambda x: x.get('balance', 0),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
for idx, client in enumerate(sorted_clients[:10], 1):
|
||||
name = client.get('name', 'N/A')
|
||||
balance = round(client.get('balance', 0))
|
||||
text += f"{idx}. {name}: {balance:,} RON\n"
|
||||
|
||||
if len(clients) > 10:
|
||||
text += f"\nApasa butonul pentru lista completa"
|
||||
else:
|
||||
text += "Nu exista clienti cu solduri."
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def format_suppliers_balance_response(
|
||||
suppliers: List[Dict[str, Any]],
|
||||
maturity_data: Dict[str, Any],
|
||||
company_name: str = None
|
||||
) -> str:
|
||||
"""
|
||||
Format suppliers balance with maturity breakdown (content only, no header).
|
||||
|
||||
Args:
|
||||
suppliers: List of supplier dicts with id, name, balance
|
||||
maturity_data: Dict with in_term, overdue, total
|
||||
company_name: Company name (kept for compatibility, not used)
|
||||
|
||||
Returns:
|
||||
Formatted Markdown string for Telegram
|
||||
|
||||
Example:
|
||||
suppliers = [{'id': 1, 'name': 'Supplier A', 'balance': 5000}]
|
||||
maturity = {'in_term': 4000, 'overdue': 1000, 'total': 5000}
|
||||
text = format_suppliers_balance_response(suppliers, maturity)
|
||||
"""
|
||||
text = ""
|
||||
|
||||
# Maturity breakdown - rotunjit la leu (0 zecimale)
|
||||
total = round(maturity_data.get('total', 0))
|
||||
in_term = round(maturity_data.get('in_term', 0))
|
||||
overdue = round(maturity_data.get('overdue', 0))
|
||||
|
||||
text += f"**Sold Total:** {total:,} RON\n\n"
|
||||
|
||||
text += "**Defalcare:**\n"
|
||||
text += f" - In termen: {in_term:,} RON\n"
|
||||
text += f" - Restanta: {overdue:,} RON\n\n"
|
||||
|
||||
# Top 10 suppliers
|
||||
if suppliers:
|
||||
text += f"**Top 10 Furnizori** ({len(suppliers)} total):\n"
|
||||
# Sort by balance descending
|
||||
sorted_suppliers = sorted(
|
||||
suppliers,
|
||||
key=lambda x: x.get('balance', 0),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
for idx, supplier in enumerate(sorted_suppliers[:10], 1):
|
||||
name = supplier.get('name', 'N/A')
|
||||
balance = round(supplier.get('balance', 0))
|
||||
text += f"{idx}. {name}: {balance:,} RON\n"
|
||||
|
||||
if len(suppliers) > 10:
|
||||
text += f"\nApasa butonul pentru lista completa"
|
||||
else:
|
||||
text += "Nu exista furnizori cu solduri."
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def format_cashflow_evolution_response(
|
||||
performance_data: Dict[str, Any],
|
||||
monthly_data: Dict[str, Any],
|
||||
company_name: str = None
|
||||
) -> str:
|
||||
"""
|
||||
Format cash flow evolution data - Table format with mini-charts.
|
||||
|
||||
Args:
|
||||
performance_data: Dict with current_year and previous_year YTD data
|
||||
monthly_data: Dict with months, incasari, plati arrays + prev year data
|
||||
company_name: Company name (kept for compatibility, not used)
|
||||
|
||||
Returns:
|
||||
Formatted Markdown string for Telegram (monospace table)
|
||||
|
||||
Example:
|
||||
YTD 2024 vs 2023:
|
||||
2024 2023 Δ Trend
|
||||
Inc: 500,000 480,000 +4.2% ████░
|
||||
Plt: 450,000 440,000 +2.3% ███░
|
||||
Net: 50,000 40,000 +25.0% █████
|
||||
"""
|
||||
text = ""
|
||||
|
||||
# Helper functions
|
||||
def calc_percent_change(current: float, previous: float) -> str:
|
||||
"""Calculate percentage change: +4.2% or -3.5%"""
|
||||
if previous == 0:
|
||||
return "+100%" if current > 0 else "0.0%"
|
||||
change = ((current - previous) / previous) * 100
|
||||
sign = "+" if change >= 0 else ""
|
||||
return f"{sign}{change:.1f}%"
|
||||
|
||||
def create_mini_chart(current: float, previous: float, width: int = 5) -> str:
|
||||
"""Create mini bar chart: ████░ (proportional bars)"""
|
||||
if current == 0 and previous == 0:
|
||||
return "─" * width
|
||||
|
||||
max_val = max(current, previous)
|
||||
if max_val == 0:
|
||||
return "─" * width
|
||||
|
||||
curr_bars = int((current / max_val) * width)
|
||||
prev_bars = int((previous / max_val) * width)
|
||||
|
||||
# Use filled and light blocks
|
||||
filled = "█" * curr_bars
|
||||
light = "░" * (width - curr_bars)
|
||||
return filled + light
|
||||
|
||||
def get_trend_arrow(current: float, previous: float) -> str:
|
||||
"""Get trend arrow: ↑ or ↓ or →"""
|
||||
if current > previous * 1.02: # More than 2% increase
|
||||
return "↑"
|
||||
elif current < previous * 0.98: # More than 2% decrease
|
||||
return "↓"
|
||||
else:
|
||||
return "→"
|
||||
|
||||
# Extract YTD data
|
||||
current = performance_data.get('current_year', {})
|
||||
previous = performance_data.get('previous_year', {})
|
||||
|
||||
current_year = current.get('year', '2024')
|
||||
previous_year = previous.get('year', '2023')
|
||||
|
||||
inc_cur = round(current.get('incasari', 0))
|
||||
plt_cur = round(current.get('plati', 0))
|
||||
net_cur = round(current.get('net', 0))
|
||||
|
||||
inc_prev = round(previous.get('incasari', 0))
|
||||
plt_prev = round(previous.get('plati', 0))
|
||||
net_prev = round(previous.get('net', 0))
|
||||
|
||||
# YTD Table Header
|
||||
text += f"**YTD {current_year} vs {previous_year}:**\n"
|
||||
text += f"` {current_year:>10} {previous_year:>10} Δ `\n"
|
||||
|
||||
# YTD Rows
|
||||
inc_pct = calc_percent_change(inc_cur, inc_prev)
|
||||
text += f"`Inc: {inc_cur:>10,} {inc_prev:>10,} {inc_pct:>6}`\n"
|
||||
|
||||
plt_pct = calc_percent_change(plt_cur, plt_prev)
|
||||
text += f"`Plt: {plt_cur:>10,} {plt_prev:>10,} {plt_pct:>6}`\n"
|
||||
|
||||
net_pct = calc_percent_change(net_cur, net_prev)
|
||||
text += f"`Net: {net_cur:>10,} {net_prev:>10,} {net_pct:>6}`\n\n"
|
||||
|
||||
# Monthly Evolution Table - Simplified (Net only)
|
||||
months = monthly_data.get('months', [])
|
||||
incasari = monthly_data.get('incasari', [])
|
||||
plati = monthly_data.get('plati', [])
|
||||
incasari_prev = monthly_data.get('incasari_prev', [])
|
||||
plati_prev = monthly_data.get('plati_prev', [])
|
||||
|
||||
if months and len(months) > 0:
|
||||
text += "**Evolutie Net (12 luni):**\n"
|
||||
text += f"` {current_year:>10} {previous_year:>10} Δ `\n"
|
||||
|
||||
for i, month in enumerate(months):
|
||||
inc = incasari[i] if i < len(incasari) else 0
|
||||
plt = plati[i] if i < len(plati) else 0
|
||||
inc_p = incasari_prev[i] if i < len(incasari_prev) else 0
|
||||
plt_p = plati_prev[i] if i < len(plati_prev) else 0
|
||||
|
||||
net = inc - plt
|
||||
net_p = inc_p - plt_p
|
||||
|
||||
# Extract short month name (first 3 chars before apostrophe)
|
||||
month_short = month.split("'")[0][:3] if "'" in month else month[:3]
|
||||
|
||||
# Calculate percentage change
|
||||
net_pct = calc_percent_change(net, net_p)
|
||||
|
||||
# Format row: Luna Net'current Net'prev Δ (aligned with YTD)
|
||||
text += f"`{month_short:<4} {int(net):>10,} {int(net_p):>10,} {net_pct:>6}`\n"
|
||||
else:
|
||||
text += "Nu exista date lunare disponibile."
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def format_client_detail_response(
|
||||
client: Dict[str, Any],
|
||||
invoices: List[Dict[str, Any]],
|
||||
company_name: str = None
|
||||
) -> str:
|
||||
"""
|
||||
Format client details with invoices - COMPACT TABLE FORMAT.
|
||||
|
||||
Args:
|
||||
client: Dict with client info (id, name, balance)
|
||||
invoices: List of invoice dicts for this client
|
||||
company_name: Company name (kept for compatibility, not used)
|
||||
|
||||
Returns:
|
||||
Formatted Markdown string for Telegram (compact, no emojis)
|
||||
|
||||
Example:
|
||||
client = {'id': 1, 'name': 'Client A', 'balance': 15000}
|
||||
invoices = [{'id': 1, 'number': 'FV001', 'amount': 5000, 'status': 'unpaid'}]
|
||||
text = format_client_detail_response(client, invoices)
|
||||
"""
|
||||
client_name = client.get('name', 'N/A')
|
||||
balance = client.get('balance', 0)
|
||||
|
||||
# Header with client info
|
||||
text = f"**{client_name}**\n"
|
||||
text += f"**Sold total: {balance:,.2f} RON**"
|
||||
if invoices and len(invoices) > 1:
|
||||
text += f" • {len(invoices)} facturi"
|
||||
text += "\n\n"
|
||||
|
||||
# Invoices - compact table format (no emojis)
|
||||
if invoices:
|
||||
from datetime import datetime
|
||||
|
||||
# Sort invoices by date (most recent first)
|
||||
sorted_invoices = sorted(invoices, key=lambda x: x.get('dataact') or datetime.min, reverse=True)
|
||||
|
||||
# Invoice list - simple format without table
|
||||
text += "Facturi cu sold:\n"
|
||||
text += "━━━━━━━━━━━━━━━━━━━━\n"
|
||||
|
||||
# Invoice rows - one line each, simple format
|
||||
for inv in sorted_invoices[:10]:
|
||||
# Backend returns: nract, totctva, soldfinal, datascad, dataact, achitat
|
||||
number = str(inv.get('nract', 'N/A'))
|
||||
dataact = inv.get('dataact')
|
||||
|
||||
# Parse date - handle various formats to ensure dd.mm.yyyy
|
||||
if dataact:
|
||||
if isinstance(dataact, str):
|
||||
try:
|
||||
# Try ISO format first: "2024-10-25" or "2024-10-25 00:00:00"
|
||||
if '-' in dataact and len(dataact) >= 10:
|
||||
parsed_date = datetime.strptime(dataact[:10], '%Y-%m-%d')
|
||||
date_str = parsed_date.strftime('%d.%m.%Y')
|
||||
# Already in dd.mm.yyyy format
|
||||
elif '.' in dataact:
|
||||
date_str = dataact.split()[0][:10] # Take just date part
|
||||
else:
|
||||
date_str = dataact[:10] if len(dataact) >= 10 else dataact
|
||||
except:
|
||||
date_str = dataact[:10] if len(dataact) >= 10 else dataact
|
||||
else:
|
||||
# Datetime object - format as dd.mm.yyyy
|
||||
date_str = dataact.strftime('%d.%m.%Y')
|
||||
else:
|
||||
date_str = 'N/A'
|
||||
|
||||
sold = float(inv.get('soldfinal', 0) or 0)
|
||||
|
||||
# Simple format: Nr • Data • Sold
|
||||
text += f"Nr {number} • {date_str} • {sold:,.2f} RON\n"
|
||||
|
||||
if len(invoices) > 10:
|
||||
text += f"\n\n+{len(invoices) - 10} facturi"
|
||||
else:
|
||||
text += "Nu exista facturi neachitate"
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def format_supplier_detail_response(
|
||||
supplier: Dict[str, Any],
|
||||
invoices: List[Dict[str, Any]],
|
||||
company_name: str = None
|
||||
) -> str:
|
||||
"""
|
||||
Format supplier details with invoices - COMPACT TABLE FORMAT.
|
||||
|
||||
Args:
|
||||
supplier: Dict with supplier info (id, name, balance)
|
||||
invoices: List of invoice dicts for this supplier
|
||||
company_name: Company name (kept for compatibility, not used)
|
||||
|
||||
Returns:
|
||||
Formatted Markdown string for Telegram (compact, no emojis)
|
||||
|
||||
Example:
|
||||
supplier = {'id': 1, 'name': 'Supplier A', 'balance': 5000}
|
||||
invoices = [{'id': 1, 'number': 'FC001', 'amount': 2000, 'status': 'unpaid'}]
|
||||
text = format_supplier_detail_response(supplier, invoices)
|
||||
"""
|
||||
supplier_name = supplier.get('name', 'N/A')
|
||||
balance = supplier.get('balance', 0)
|
||||
|
||||
# Header with supplier info
|
||||
text = f"**{supplier_name}**\n"
|
||||
text += f"**Sold total: {balance:,.2f} RON**"
|
||||
if invoices and len(invoices) > 1:
|
||||
text += f" • {len(invoices)} facturi"
|
||||
text += "\n\n"
|
||||
|
||||
# Invoices - compact table format (no emojis)
|
||||
if invoices:
|
||||
from datetime import datetime
|
||||
|
||||
# Sort invoices by date (most recent first)
|
||||
sorted_invoices = sorted(invoices, key=lambda x: x.get('dataact') or datetime.min, reverse=True)
|
||||
|
||||
# Invoice list - simple format without table
|
||||
text += "Facturi cu sold:\n"
|
||||
text += "━━━━━━━━━━━━━━━━━━━━\n"
|
||||
|
||||
# Invoice rows - one line each, simple format
|
||||
for inv in sorted_invoices[:10]:
|
||||
# Backend returns: nract, totctva, soldfinal, datascad, dataact, achitat
|
||||
number = str(inv.get('nract', 'N/A'))
|
||||
dataact = inv.get('dataact')
|
||||
|
||||
# Parse date - handle various formats to ensure dd.mm.yyyy
|
||||
if dataact:
|
||||
if isinstance(dataact, str):
|
||||
try:
|
||||
# Try ISO format first: "2024-10-25" or "2024-10-25 00:00:00"
|
||||
if '-' in dataact and len(dataact) >= 10:
|
||||
parsed_date = datetime.strptime(dataact[:10], '%Y-%m-%d')
|
||||
date_str = parsed_date.strftime('%d.%m.%Y')
|
||||
# Already in dd.mm.yyyy format
|
||||
elif '.' in dataact:
|
||||
date_str = dataact.split()[0][:10] # Take just date part
|
||||
else:
|
||||
date_str = dataact[:10] if len(dataact) >= 10 else dataact
|
||||
except:
|
||||
date_str = dataact[:10] if len(dataact) >= 10 else dataact
|
||||
else:
|
||||
# Datetime object - format as dd.mm.yyyy
|
||||
date_str = dataact.strftime('%d.%m.%Y')
|
||||
else:
|
||||
date_str = 'N/A'
|
||||
|
||||
sold = float(inv.get('soldfinal', 0) or 0)
|
||||
|
||||
# Simple format: Nr • Data • Sold
|
||||
text += f"Nr {number} • {date_str} • {sold:,.2f} RON\n"
|
||||
|
||||
if len(invoices) > 10:
|
||||
text += f"\n\n+{len(invoices) - 10} facturi"
|
||||
else:
|
||||
text += "Nu exista facturi neachitate"
|
||||
|
||||
return text
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# FAZA 6: Performance Footer for Cache Monitoring
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def add_performance_footer(message: str, cache_hit: bool, time_ms: float, cache_source: str = None) -> str:
|
||||
"""
|
||||
Add compact performance footer to bot responses.
|
||||
|
||||
Shows data source (cached L1/L2 or database) and response time.
|
||||
Format: "cached L1 | 15ms", "cached L2 | 25ms" or "db | 285ms"
|
||||
|
||||
Args:
|
||||
message: Existing message text
|
||||
cache_hit: True if data came from cache
|
||||
time_ms: Response time in milliseconds
|
||||
cache_source: Cache source ("L1" for memory, "L2" for SQLite) if cache_hit is True
|
||||
|
||||
Returns:
|
||||
Message with performance footer appended
|
||||
|
||||
Example:
|
||||
>>> add_performance_footer("Dashboard data...", True, 52.3, "L1")
|
||||
"Dashboard data...\n\ncached L1 | 52ms"
|
||||
>>> add_performance_footer("Dashboard data...", True, 25.8, "L2")
|
||||
"Dashboard data...\n\ncached L2 | 26ms"
|
||||
>>> add_performance_footer("Dashboard data...", False, 285.7)
|
||||
"Dashboard data...\n\ndb | 286ms"
|
||||
"""
|
||||
if cache_hit and cache_source:
|
||||
source = f"cached {cache_source}"
|
||||
elif cache_hit:
|
||||
source = "cached" # Fallback if source not provided
|
||||
else:
|
||||
source = "db"
|
||||
|
||||
footer = f"\n\n`{source} | {time_ms:.0f}ms`"
|
||||
return message + footer
|
||||
|
||||
2777
backend/modules/telegram/bot/handlers.py
Normal file
2777
backend/modules/telegram/bot/handlers.py
Normal file
File diff suppressed because it is too large
Load Diff
814
backend/modules/telegram/bot/helpers.py
Normal file
814
backend/modules/telegram/bot/helpers.py
Normal file
@@ -0,0 +1,814 @@
|
||||
"""
|
||||
Helper functions for Telegram bot command handlers.
|
||||
Provides utilities for company selection, API calls, and response formatting.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Dict, List, Any
|
||||
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
|
||||
|
||||
from backend.modules.telegram.api.client import get_backend_client
|
||||
from backend.modules.telegram.agent.session import SessionManager
|
||||
from backend.modules.telegram.bot.menus import pad_message_for_wide_buttons
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_active_company_or_prompt(
|
||||
update: Update,
|
||||
session_manager: SessionManager,
|
||||
telegram_user_id: int
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get active company from session or prompt user to select one with buttons.
|
||||
|
||||
This function checks if the user has an active company set in their session.
|
||||
If not, it fetches companies and displays selection buttons directly.
|
||||
|
||||
Args:
|
||||
update: Telegram Update object (for sending messages)
|
||||
session_manager: SessionManager instance
|
||||
telegram_user_id: Telegram user ID
|
||||
|
||||
Returns:
|
||||
Dict with company info (id, name, cui) if set, None if user needs to select
|
||||
|
||||
Example:
|
||||
company = await get_active_company_or_prompt(update, session_manager, user_id)
|
||||
if not company:
|
||||
return # User was shown company selection buttons
|
||||
# Continue with company operations...
|
||||
"""
|
||||
session = await session_manager.get_or_create_session(telegram_user_id)
|
||||
company = session.get_active_company()
|
||||
|
||||
if not company:
|
||||
# Get auth data and companies
|
||||
from backend.modules.telegram.auth.linking import get_user_auth_data
|
||||
auth_data = await get_user_auth_data(telegram_user_id)
|
||||
jwt_token = auth_data['jwt_token']
|
||||
|
||||
client = get_backend_client()
|
||||
async with client:
|
||||
companies = await client.get_user_companies(jwt_token=jwt_token)
|
||||
|
||||
if companies:
|
||||
keyboard = create_company_selection_keyboard_paginated(companies, page=0)
|
||||
message = (
|
||||
f"**Selecteaza mai intai o companie**\n\n"
|
||||
f"Companiile tale ({len(companies)}):"
|
||||
)
|
||||
# Apply padding to make inline keyboard buttons wider
|
||||
message = pad_message_for_wide_buttons(message)
|
||||
await update.message.reply_text(
|
||||
message,
|
||||
reply_markup=keyboard,
|
||||
parse_mode="Markdown"
|
||||
)
|
||||
else:
|
||||
await update.message.reply_text(
|
||||
"Nu ai acces la nicio companie.\n"
|
||||
"Contacteaza administratorul.",
|
||||
parse_mode="Markdown"
|
||||
)
|
||||
return None
|
||||
|
||||
return company
|
||||
|
||||
|
||||
async def search_companies_by_name(
|
||||
name_query: str,
|
||||
jwt_token: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Search companies by partial name match (case-insensitive).
|
||||
|
||||
Fetches all companies from backend and filters them by name.
|
||||
Uses case-insensitive partial matching for flexible search.
|
||||
|
||||
Args:
|
||||
name_query: Search term (partial match, e.g., "ACME")
|
||||
jwt_token: JWT authentication token
|
||||
|
||||
Returns:
|
||||
List of matching company dicts (each with id, nume_firma, cui, etc.)
|
||||
|
||||
Example:
|
||||
companies = await search_companies_by_name("acme", token)
|
||||
# Returns all companies with "acme" in their name (case-insensitive)
|
||||
"""
|
||||
client = get_backend_client()
|
||||
async with client:
|
||||
all_companies = await client.get_user_companies(jwt_token=jwt_token)
|
||||
|
||||
# Filter by name (case-insensitive partial match)
|
||||
query_lower = name_query.lower()
|
||||
matches = [
|
||||
comp for comp in all_companies
|
||||
if query_lower in comp.get('name', comp.get('nume_firma', '')).lower()
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"Search '{name_query}': {len(matches)} matches out of {len(all_companies)} total"
|
||||
)
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
def create_company_selection_keyboard(
|
||||
companies: List[Dict[str, Any]],
|
||||
max_buttons: int = 10
|
||||
) -> InlineKeyboardMarkup:
|
||||
"""
|
||||
Create inline keyboard for company selection (legacy - without pagination).
|
||||
|
||||
Generates a vertical list of buttons, one per company.
|
||||
Each button shows company name and CUI, and triggers a callback.
|
||||
|
||||
NOTE: This function is deprecated in favor of create_company_selection_keyboard_paginated.
|
||||
It's kept for backwards compatibility only.
|
||||
|
||||
Args:
|
||||
companies: List of company dicts (with id, nume_firma, cui)
|
||||
max_buttons: Maximum number of buttons to show (default: 10)
|
||||
|
||||
Returns:
|
||||
InlineKeyboardMarkup with company selection buttons
|
||||
|
||||
Example:
|
||||
keyboard = create_company_selection_keyboard(companies)
|
||||
await update.message.reply_text("Select company:", reply_markup=keyboard)
|
||||
"""
|
||||
keyboard = []
|
||||
|
||||
for company in companies[:max_buttons]:
|
||||
company_id = company.get('id_firma', company.get('id'))
|
||||
company_name = company.get('name', company.get('nume_firma', 'N/A'))
|
||||
company_cui = company.get('fiscal_code', company.get('cui', ''))
|
||||
|
||||
# Button text: "ACME SRL (CUI: 12345)"
|
||||
button_text = f"{company_name}"
|
||||
if company_cui:
|
||||
button_text += f" ({company_cui})"
|
||||
|
||||
# Callback data: "select_company:123"
|
||||
callback_data = f"select_company:{company_id}"
|
||||
|
||||
keyboard.append([InlineKeyboardButton(button_text, callback_data=callback_data)])
|
||||
|
||||
# Add overflow indicator if there are more companies
|
||||
if len(companies) > max_buttons:
|
||||
keyboard.append([InlineKeyboardButton(
|
||||
f"... și încă {len(companies) - max_buttons} companii",
|
||||
callback_data="noop"
|
||||
)])
|
||||
|
||||
return InlineKeyboardMarkup(keyboard)
|
||||
|
||||
|
||||
def create_company_selection_keyboard_paginated(
|
||||
companies: List[Dict[str, Any]],
|
||||
page: int = 0,
|
||||
per_page: int = 10
|
||||
) -> InlineKeyboardMarkup:
|
||||
"""
|
||||
Create paginated inline keyboard for company selection.
|
||||
|
||||
Generates a vertical list of buttons for one page of companies,
|
||||
with navigation buttons for previous/next pages.
|
||||
|
||||
Args:
|
||||
companies: Full list of company dicts (with id, nume_firma, cui)
|
||||
page: Current page number (0-indexed)
|
||||
per_page: Number of companies per page (default: 10)
|
||||
|
||||
Returns:
|
||||
InlineKeyboardMarkup with company buttons and pagination controls
|
||||
|
||||
Example:
|
||||
keyboard = create_company_selection_keyboard_paginated(companies, page=0)
|
||||
await update.message.reply_text("Select company:", reply_markup=keyboard)
|
||||
"""
|
||||
keyboard = []
|
||||
|
||||
# Calculate pagination
|
||||
total_companies = len(companies)
|
||||
total_pages = (total_companies + per_page - 1) // per_page # Ceiling division
|
||||
start_idx = page * per_page
|
||||
end_idx = min(start_idx + per_page, total_companies)
|
||||
|
||||
# Display companies for current page
|
||||
page_companies = companies[start_idx:end_idx]
|
||||
|
||||
for company in page_companies:
|
||||
company_id = company.get('id_firma', company.get('id'))
|
||||
company_name = company.get('name', company.get('nume_firma', 'N/A'))
|
||||
company_cui = company.get('fiscal_code', company.get('cui', ''))
|
||||
|
||||
# Button text: "ACME SRL (CUI: 12345)"
|
||||
button_text = f"{company_name}"
|
||||
if company_cui:
|
||||
button_text += f" ({company_cui})"
|
||||
|
||||
# Callback data: "select_company:123"
|
||||
callback_data = f"select_company:{company_id}"
|
||||
|
||||
keyboard.append([InlineKeyboardButton(button_text, callback_data=callback_data)])
|
||||
|
||||
# Pagination controls (only if more than one page)
|
||||
if total_pages > 1:
|
||||
nav_buttons = []
|
||||
|
||||
# Previous button
|
||||
if page > 0:
|
||||
nav_buttons.append(
|
||||
InlineKeyboardButton("< Anterior", callback_data=f"select_company_page:{page-1}")
|
||||
)
|
||||
|
||||
# Page indicator (non-clickable)
|
||||
nav_buttons.append(
|
||||
InlineKeyboardButton(f"Pagina {page+1}/{total_pages}", callback_data="noop")
|
||||
)
|
||||
|
||||
# Next button
|
||||
if page < total_pages - 1:
|
||||
nav_buttons.append(
|
||||
InlineKeyboardButton("Urmator >", callback_data=f"select_company_page:{page+1}")
|
||||
)
|
||||
|
||||
keyboard.append(nav_buttons)
|
||||
|
||||
# Back to menu button
|
||||
keyboard.append([
|
||||
InlineKeyboardButton("< Inapoi la Meniu", callback_data="action:menu")
|
||||
])
|
||||
|
||||
return InlineKeyboardMarkup(keyboard)
|
||||
|
||||
|
||||
def format_company_context_footer(company_name: str) -> str:
|
||||
"""
|
||||
Format discrete footer with company context.
|
||||
|
||||
Adds a subtle footer to command responses showing the active company
|
||||
and a quick link to change it.
|
||||
|
||||
Args:
|
||||
company_name: Active company name
|
||||
|
||||
Returns:
|
||||
Formatted footer string with separator and company name
|
||||
|
||||
Example:
|
||||
footer = format_company_context_footer("ACME SRL")
|
||||
message = f"Dashboard data...\n{footer}"
|
||||
# Output: "Dashboard data...\n\n━━━━━━━━━━━━━━\nCompanie: ACME SRL"
|
||||
"""
|
||||
return f"\n\n━━━━━━━━━━━━━━\nCompanie: {company_name}"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# FAZA 2: New Helper Functions for Button Interface
|
||||
# =========================================================================
|
||||
|
||||
|
||||
async def get_treasury_breakdown_split(
|
||||
company_id: int,
|
||||
jwt_token: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get treasury breakdown split into casa and banca.
|
||||
|
||||
Fetches treasury breakdown from backend and transforms it
|
||||
to the format expected by formatters.
|
||||
|
||||
Backend returns:
|
||||
{
|
||||
"total": float,
|
||||
"breakdown": {
|
||||
"casa": {"total": float, "items": [{"nume": str, "cont": str, "sold": float}]},
|
||||
"banca": {"total": float, "items": [{"nume": str, "cont": str, "sold": float}]}
|
||||
},
|
||||
"currency": "RON"
|
||||
}
|
||||
|
||||
Args:
|
||||
company_id: Company ID
|
||||
jwt_token: JWT authentication token
|
||||
|
||||
Returns:
|
||||
Dict with two keys:
|
||||
- 'casa': Dict with 'accounts' (list) and 'total' (float)
|
||||
- 'banca': Dict with 'accounts' (list) and 'total' (float)
|
||||
|
||||
None if request fails
|
||||
|
||||
Example:
|
||||
data = await get_treasury_breakdown_split(1, token)
|
||||
casa_total = data['casa']['total'] # Total cash balance
|
||||
bank_accounts = data['banca']['accounts'] # List of bank accounts
|
||||
"""
|
||||
try:
|
||||
client = get_backend_client()
|
||||
async with client:
|
||||
breakdown = await client.get_treasury_breakdown(
|
||||
company_id=company_id,
|
||||
jwt_token=jwt_token
|
||||
)
|
||||
|
||||
if not breakdown:
|
||||
return None
|
||||
|
||||
# Backend already splits data into casa and banca
|
||||
# Transform backend structure to match formatter expectations
|
||||
breakdown_data = breakdown.get('breakdown', {})
|
||||
casa_data = breakdown_data.get('casa', {})
|
||||
banca_data = breakdown_data.get('banca', {})
|
||||
|
||||
# Transform items to accounts format (nume->name, sold->balance)
|
||||
casa_accounts = [
|
||||
{
|
||||
'name': item.get('nume', f"Cont {item.get('cont', 'N/A')}"),
|
||||
'balance': float(item.get('sold', 0)),
|
||||
'cont': item.get('cont', '')
|
||||
}
|
||||
for item in casa_data.get('items', [])
|
||||
]
|
||||
|
||||
banca_accounts = [
|
||||
{
|
||||
'name': item.get('nume', f"Cont {item.get('cont', 'N/A')}"),
|
||||
'balance': float(item.get('sold', 0)),
|
||||
'cont': item.get('cont', '')
|
||||
}
|
||||
for item in banca_data.get('items', [])
|
||||
]
|
||||
|
||||
result = {
|
||||
'casa': {
|
||||
'accounts': casa_accounts,
|
||||
'total': float(casa_data.get('total', 0))
|
||||
},
|
||||
'banca': {
|
||||
'accounts': banca_accounts,
|
||||
'total': float(banca_data.get('total', 0))
|
||||
}
|
||||
}
|
||||
|
||||
# Pass through cache metadata if present
|
||||
if 'cache_hit' in breakdown:
|
||||
result['cache_hit'] = breakdown['cache_hit']
|
||||
if 'response_time_ms' in breakdown:
|
||||
result['response_time_ms'] = breakdown['response_time_ms']
|
||||
if 'cache_source' in breakdown:
|
||||
result['cache_source'] = breakdown['cache_source']
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting treasury breakdown split: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
async def get_clients_with_maturity(
|
||||
company_id: int,
|
||||
jwt_token: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get clients list with maturity breakdown.
|
||||
|
||||
Uses maturity analysis endpoint which returns client summaries
|
||||
with amounts and overdue status.
|
||||
|
||||
Backend returns:
|
||||
{
|
||||
"clients": [{"name": str, "amount": float, "dueDate": str, "daysOverdue": int}],
|
||||
"suppliers": [...],
|
||||
"balance": float,
|
||||
"metadata": {...}
|
||||
}
|
||||
|
||||
Args:
|
||||
company_id: Company ID
|
||||
jwt_token: JWT authentication token
|
||||
|
||||
Returns:
|
||||
Dict with:
|
||||
- 'clients': List of client dicts (id, name, balance)
|
||||
- 'maturity': Dict with 'in_term', 'overdue', 'total' amounts
|
||||
|
||||
None if request fails
|
||||
|
||||
Example:
|
||||
data = await get_clients_with_maturity(1, token)
|
||||
clients = data['clients'] # List of all clients
|
||||
overdue = data['maturity']['overdue'] # Overdue amount
|
||||
"""
|
||||
try:
|
||||
client = get_backend_client()
|
||||
async with client:
|
||||
# Get maturity analysis (contains client summaries)
|
||||
maturity_response = await client.get_maturity_data(
|
||||
company_id=company_id,
|
||||
jwt_token=jwt_token,
|
||||
period='all'
|
||||
)
|
||||
|
||||
if not maturity_response:
|
||||
return None
|
||||
|
||||
# Extract clients from maturity response
|
||||
clients_raw = maturity_response.get('clients', [])
|
||||
|
||||
# Transform to expected format: amount → balance
|
||||
clients = [
|
||||
{
|
||||
'name': c.get('name', 'N/A'),
|
||||
'balance': float(c.get('amount', 0)),
|
||||
'daysOverdue': c.get('daysOverdue', 0)
|
||||
}
|
||||
for c in clients_raw
|
||||
]
|
||||
|
||||
# Calculate maturity breakdown from clients data
|
||||
total = sum(c['balance'] for c in clients)
|
||||
overdue = sum(c['balance'] for c in clients if c.get('daysOverdue', 0) > 0)
|
||||
in_term = total - overdue
|
||||
|
||||
result = {
|
||||
'clients': clients,
|
||||
'maturity': {
|
||||
'in_term': in_term,
|
||||
'overdue': overdue,
|
||||
'total': total
|
||||
}
|
||||
}
|
||||
|
||||
# Pass through cache metadata if present
|
||||
if 'cache_hit' in maturity_response:
|
||||
result['cache_hit'] = maturity_response['cache_hit']
|
||||
if 'response_time_ms' in maturity_response:
|
||||
result['response_time_ms'] = maturity_response['response_time_ms']
|
||||
if 'cache_source' in maturity_response:
|
||||
result['cache_source'] = maturity_response['cache_source']
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting clients with maturity: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
async def get_suppliers_with_maturity(
|
||||
company_id: int,
|
||||
jwt_token: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get suppliers list with maturity breakdown.
|
||||
|
||||
Uses maturity analysis endpoint which returns supplier summaries
|
||||
with amounts and overdue status.
|
||||
|
||||
Backend returns:
|
||||
{
|
||||
"clients": [...],
|
||||
"suppliers": [{"name": str, "amount": float, "dueDate": str, "daysOverdue": int}],
|
||||
"balance": float,
|
||||
"metadata": {...}
|
||||
}
|
||||
|
||||
Args:
|
||||
company_id: Company ID
|
||||
jwt_token: JWT authentication token
|
||||
|
||||
Returns:
|
||||
Dict with:
|
||||
- 'suppliers': List of supplier dicts (id, name, balance)
|
||||
- 'maturity': Dict with 'in_term', 'overdue', 'total' amounts
|
||||
|
||||
None if request fails
|
||||
|
||||
Example:
|
||||
data = await get_suppliers_with_maturity(1, token)
|
||||
suppliers = data['suppliers'] # List of all suppliers
|
||||
in_term = data['maturity']['in_term'] # In-term amount
|
||||
"""
|
||||
try:
|
||||
client = get_backend_client()
|
||||
async with client:
|
||||
# Get maturity analysis (contains supplier summaries)
|
||||
maturity_response = await client.get_maturity_data(
|
||||
company_id=company_id,
|
||||
jwt_token=jwt_token,
|
||||
period='all'
|
||||
)
|
||||
|
||||
if not maturity_response:
|
||||
return None
|
||||
|
||||
# Extract suppliers from maturity response
|
||||
suppliers_raw = maturity_response.get('suppliers', [])
|
||||
|
||||
# Transform to expected format: amount → balance
|
||||
suppliers = [
|
||||
{
|
||||
'name': s.get('name', 'N/A'),
|
||||
'balance': float(s.get('amount', 0)),
|
||||
'daysOverdue': s.get('daysOverdue', 0)
|
||||
}
|
||||
for s in suppliers_raw
|
||||
]
|
||||
|
||||
# Calculate maturity breakdown from suppliers data
|
||||
total = sum(s['balance'] for s in suppliers)
|
||||
overdue = sum(s['balance'] for s in suppliers if s.get('daysOverdue', 0) > 0)
|
||||
in_term = total - overdue
|
||||
|
||||
result = {
|
||||
'suppliers': suppliers,
|
||||
'maturity': {
|
||||
'in_term': in_term,
|
||||
'overdue': overdue,
|
||||
'total': total
|
||||
}
|
||||
}
|
||||
|
||||
# Pass through cache metadata if present
|
||||
if 'cache_hit' in maturity_response:
|
||||
result['cache_hit'] = maturity_response['cache_hit']
|
||||
if 'response_time_ms' in maturity_response:
|
||||
result['response_time_ms'] = maturity_response['response_time_ms']
|
||||
if 'cache_source' in maturity_response:
|
||||
result['cache_source'] = maturity_response['cache_source']
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting suppliers with maturity: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
async def get_cashflow_evolution_data(
|
||||
company_id: int,
|
||||
jwt_token: str,
|
||||
period: str = "12m"
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get cash flow evolution data with YTD comparison.
|
||||
|
||||
Uses trends endpoint which returns 12-month historical data for current and previous year.
|
||||
Calculates YTD for comparison and extracts last 12 months in reverse chronological order.
|
||||
|
||||
Args:
|
||||
company_id: Company ID
|
||||
jwt_token: JWT authentication token
|
||||
period: Period for trends data (default: "12m")
|
||||
|
||||
Returns:
|
||||
Dict with:
|
||||
- 'performance': Dict with YTD data for current and previous year
|
||||
- 'monthly': Dict with last 12 months data (reverse chronological) + prev year comparison
|
||||
|
||||
None if request fails
|
||||
|
||||
Example:
|
||||
data = await get_cashflow_evolution_data(1, token)
|
||||
ytd_2025 = data['performance']['current_year']
|
||||
ytd_2024 = data['performance']['previous_year']
|
||||
"""
|
||||
try:
|
||||
client = get_backend_client()
|
||||
async with client:
|
||||
# Get trends data (12 months of historical data)
|
||||
trends_data = await client.get_trends(
|
||||
company_id=company_id,
|
||||
jwt_token=jwt_token,
|
||||
period="12m"
|
||||
)
|
||||
|
||||
if not trends_data:
|
||||
return None
|
||||
|
||||
# Extract current year data
|
||||
periods = trends_data.get('periods', []) # ["2024-01", "2024-02", ...]
|
||||
clienti_incasat = trends_data.get('clienti_incasat', [])
|
||||
furnizori_achitat = trends_data.get('furnizori_achitat', [])
|
||||
|
||||
# Extract previous year data
|
||||
previous_periods = trends_data.get('previous_periods', [])
|
||||
clienti_incasat_prev = trends_data.get('clienti_incasat_prev', [])
|
||||
furnizori_achitat_prev = trends_data.get('furnizori_achitat_prev', [])
|
||||
|
||||
if not periods or not clienti_incasat or not furnizori_achitat:
|
||||
logger.warning("Trends data missing required fields")
|
||||
return None
|
||||
|
||||
# Calculate YTD (Year-To-Date) = sum of all available months
|
||||
incasari_ytd = sum(clienti_incasat)
|
||||
plati_ytd = sum(furnizori_achitat)
|
||||
net_ytd = incasari_ytd - plati_ytd
|
||||
|
||||
incasari_ytd_prev = sum(clienti_incasat_prev) if clienti_incasat_prev else 0
|
||||
plati_ytd_prev = sum(furnizori_achitat_prev) if furnizori_achitat_prev else 0
|
||||
net_ytd_prev = incasari_ytd_prev - plati_ytd_prev
|
||||
|
||||
# Extract years from periods
|
||||
current_year = periods[-1].split('-')[0] if periods else "2025"
|
||||
previous_year = previous_periods[-1].split('-')[0] if previous_periods else "2024"
|
||||
|
||||
# Take last 12 months (current year)
|
||||
last_12_periods = periods[-12:]
|
||||
last_12_incasari = clienti_incasat[-12:]
|
||||
last_12_plati = furnizori_achitat[-12:]
|
||||
|
||||
# Take corresponding previous year months
|
||||
last_12_periods_prev = previous_periods[-12:] if previous_periods else []
|
||||
last_12_incasari_prev = clienti_incasat_prev[-12:] if clienti_incasat_prev else [0] * 12
|
||||
last_12_plati_prev = furnizori_achitat_prev[-12:] if furnizori_achitat_prev else [0] * 12
|
||||
|
||||
# Month abbreviations (Romanian)
|
||||
month_abbr = {
|
||||
'01': 'Ian', '02': 'Feb', '03': 'Mar', '04': 'Apr',
|
||||
'05': 'Mai', '06': 'Iun', '07': 'Iul', '08': 'Aug',
|
||||
'09': 'Sep', '10': 'Oct', '11': 'Noi', '12': 'Dec'
|
||||
}
|
||||
|
||||
# Format months as "Noi'25/'24"
|
||||
formatted_months = []
|
||||
for i, period_str in enumerate(last_12_periods):
|
||||
if '-' in period_str:
|
||||
year = period_str.split('-')[0][-2:] # Last 2 digits: "25"
|
||||
month_num = period_str.split('-')[1]
|
||||
month_name = month_abbr.get(month_num, month_num)
|
||||
|
||||
# Get previous year month
|
||||
prev_year = previous_year[-2:] if previous_year else "24"
|
||||
|
||||
formatted_months.append(f"{month_name}'{year}/'{prev_year}")
|
||||
else:
|
||||
formatted_months.append(period_str)
|
||||
|
||||
# Reverse chronological order (newest first)
|
||||
formatted_months.reverse()
|
||||
last_12_incasari.reverse()
|
||||
last_12_plati.reverse()
|
||||
last_12_incasari_prev.reverse()
|
||||
last_12_plati_prev.reverse()
|
||||
|
||||
# Build performance summary (YTD)
|
||||
performance = {
|
||||
'current_year': {
|
||||
'year': current_year,
|
||||
'incasari': incasari_ytd,
|
||||
'plati': plati_ytd,
|
||||
'net': net_ytd
|
||||
},
|
||||
'previous_year': {
|
||||
'year': previous_year,
|
||||
'incasari': incasari_ytd_prev,
|
||||
'plati': plati_ytd_prev,
|
||||
'net': net_ytd_prev
|
||||
}
|
||||
}
|
||||
|
||||
# Build monthly breakdown (reverse chronological with prev year comparison)
|
||||
monthly = {
|
||||
'months': formatted_months,
|
||||
'incasari': last_12_incasari,
|
||||
'plati': last_12_plati,
|
||||
'incasari_prev': last_12_incasari_prev,
|
||||
'plati_prev': last_12_plati_prev
|
||||
}
|
||||
|
||||
result = {
|
||||
'performance': performance,
|
||||
'monthly': monthly
|
||||
}
|
||||
|
||||
# Pass through cache metadata if present
|
||||
if 'cache_hit' in trends_data:
|
||||
result['cache_hit'] = trends_data['cache_hit']
|
||||
if 'response_time_ms' in trends_data:
|
||||
result['response_time_ms'] = trends_data['response_time_ms']
|
||||
if 'cache_source' in trends_data:
|
||||
result['cache_source'] = trends_data['cache_source']
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting cashflow evolution data: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
async def get_client_invoices(
|
||||
company_id: int,
|
||||
client_name: str,
|
||||
jwt_token: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get invoices for a specific client.
|
||||
|
||||
Args:
|
||||
company_id: Company ID
|
||||
client_name: Client name to filter by
|
||||
jwt_token: JWT authentication token
|
||||
|
||||
Returns:
|
||||
List of invoice dicts for the specified client
|
||||
|
||||
Example:
|
||||
invoices = await get_client_invoices(1, "ACME Corp", token)
|
||||
for inv in invoices:
|
||||
print(inv['number'], inv['amount'])
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Fetching invoices for client '{client_name}' (company_id={company_id})")
|
||||
|
||||
client = get_backend_client()
|
||||
async with client:
|
||||
# Filter only by unpaid invoices (with balance > 0)
|
||||
invoices = await client.search_invoices(
|
||||
company_id=company_id,
|
||||
jwt_token=jwt_token,
|
||||
filters={
|
||||
'partner_type': 'CLIENTI',
|
||||
'partner_name': client_name,
|
||||
'only_unpaid': True # Only show unpaid invoices (matching balance > 0)
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Found {len(invoices) if invoices else 0} invoices for client '{client_name}'")
|
||||
|
||||
if invoices:
|
||||
logger.debug(f"First invoice sample: {invoices[0]}")
|
||||
|
||||
return invoices or []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting client invoices for '{client_name}': {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
|
||||
async def get_supplier_invoices(
|
||||
company_id: int,
|
||||
supplier_name: str,
|
||||
jwt_token: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get invoices for a specific supplier.
|
||||
|
||||
Args:
|
||||
company_id: Company ID
|
||||
supplier_name: Supplier name to filter by
|
||||
jwt_token: JWT authentication token
|
||||
|
||||
Returns:
|
||||
List of invoice dicts for the specified supplier
|
||||
|
||||
Example:
|
||||
invoices = await get_supplier_invoices(1, "Supplier Inc", token)
|
||||
for inv in invoices:
|
||||
print(inv['number'], inv['amount'])
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Fetching invoices for supplier '{supplier_name}' (company_id={company_id})")
|
||||
|
||||
client = get_backend_client()
|
||||
async with client:
|
||||
# Filter only by unpaid invoices (with balance > 0)
|
||||
invoices = await client.search_invoices(
|
||||
company_id=company_id,
|
||||
jwt_token=jwt_token,
|
||||
filters={
|
||||
'partner_type': 'FURNIZORI',
|
||||
'partner_name': supplier_name,
|
||||
'only_unpaid': True # Only show unpaid invoices (matching balance > 0)
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Found {len(invoices) if invoices else 0} invoices for supplier '{supplier_name}'")
|
||||
|
||||
if invoices:
|
||||
logger.debug(f"First invoice sample: {invoices[0]}")
|
||||
|
||||
return invoices or []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting supplier invoices for '{supplier_name}': {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
|
||||
# Export all helper functions
|
||||
__all__ = [
|
||||
'get_active_company_or_prompt',
|
||||
'search_companies_by_name',
|
||||
'create_company_selection_keyboard',
|
||||
'create_company_selection_keyboard_paginated',
|
||||
'format_company_context_footer',
|
||||
'get_treasury_breakdown_split',
|
||||
'get_clients_with_maturity',
|
||||
'get_suppliers_with_maturity',
|
||||
'get_cashflow_evolution_data',
|
||||
'get_client_invoices',
|
||||
'get_supplier_invoices'
|
||||
]
|
||||
0
backend/modules/telegram/bot/keyboards.py
Normal file
0
backend/modules/telegram/bot/keyboards.py
Normal file
607
backend/modules/telegram/bot/menus.py
Normal file
607
backend/modules/telegram/bot/menus.py
Normal file
@@ -0,0 +1,607 @@
|
||||
"""
|
||||
Menu builders for Telegram bot inline keyboards.
|
||||
|
||||
This module provides functions to create InlineKeyboardMarkup objects
|
||||
for different menu levels and navigation patterns in the bot.
|
||||
|
||||
NOTE: All button texts are plain text WITHOUT emojis/icons as per requirements.
|
||||
|
||||
BUTTON WIDTH: Inline keyboard width is determined by the message text width.
|
||||
To make buttons wider, we pad message text with invisible characters.
|
||||
"""
|
||||
|
||||
from telegram import InlineKeyboardButton, InlineKeyboardMarkup
|
||||
from typing import List, Dict, Optional
|
||||
from datetime import datetime
|
||||
|
||||
# ============================================================================
|
||||
# IMPORTANT: BUTTON WIDTH CONFIGURATION
|
||||
# ============================================================================
|
||||
# Inline keyboard button width is determined by MESSAGE TEXT WIDTH!
|
||||
# DO NOT REMOVE PADDING - it makes buttons wide like BotFather!
|
||||
# ============================================================================
|
||||
|
||||
# Zero-Width Joiner character - invisible but prevents Telegram from trimming spaces
|
||||
# This character has ZERO width (invisible) but prevents space trimming
|
||||
ZERO_WIDTH_JOINER = '\u200D'
|
||||
|
||||
# Target character count per line to make buttons VERY WIDE
|
||||
# Higher value = wider buttons (BotFather uses ~45-50 chars)
|
||||
# DO NOT DECREASE THIS VALUE - buttons will become narrow!
|
||||
TARGET_WIDTH = 50 # Increased from 40 to make buttons WIDER
|
||||
|
||||
# Enable/disable padding globally (useful for testing)
|
||||
# KEEP THIS TRUE - disabling makes buttons narrow!
|
||||
ENABLE_BUTTON_PADDING = True
|
||||
|
||||
|
||||
def _get_current_month_ro() -> str:
|
||||
"""Get current month name in Romanian."""
|
||||
months_ro = {
|
||||
1: "Ianuarie", 2: "Februarie", 3: "Martie", 4: "Aprilie",
|
||||
5: "Mai", 6: "Iunie", 7: "Iulie", 8: "August",
|
||||
9: "Septembrie", 10: "Octombrie", 11: "Noiembrie", 12: "Decembrie"
|
||||
}
|
||||
now = datetime.now()
|
||||
return f"{months_ro[now.month]} {now.year}"
|
||||
|
||||
|
||||
def _pad_line_for_wide_buttons(text: str, target_width: int = TARGET_WIDTH) -> str:
|
||||
"""
|
||||
Pad a single line of text with invisible characters to make inline buttons wider.
|
||||
|
||||
⚠️ CRITICAL: DO NOT REMOVE THIS FUNCTION - it makes buttons wide!
|
||||
The width of InlineKeyboardMarkup buttons is determined by the message text width.
|
||||
By padding text with spaces + zero-width joiner, we force wider buttons.
|
||||
|
||||
How it works:
|
||||
1. Calculate how many characters needed to reach target_width
|
||||
2. Add spaces + Zero-Width Joiner (invisible character)
|
||||
3. Result: wider message = wider buttons (like BotFather)
|
||||
|
||||
Args:
|
||||
text: The text line to pad
|
||||
target_width: Target character count (default 50 for VERY WIDE buttons)
|
||||
|
||||
Returns:
|
||||
Padded text with invisible characters (user sees normal text, Telegram sees wider text)
|
||||
"""
|
||||
current_length = len(text)
|
||||
if current_length >= target_width:
|
||||
return text
|
||||
|
||||
# ⚠️ DO NOT REMOVE: Add spaces + zero-width joiner at the end
|
||||
# This makes buttons WIDE without changing visible text!
|
||||
padding_needed = target_width - current_length
|
||||
padding = ' ' * padding_needed + ZERO_WIDTH_JOINER
|
||||
|
||||
return text + padding
|
||||
|
||||
|
||||
def pad_message_for_wide_buttons(message: str, target_width: int = TARGET_WIDTH, force: bool = False) -> str:
|
||||
"""
|
||||
Pad all lines in a message to make inline keyboard buttons wider.
|
||||
|
||||
⚠️ CRITICAL: DO NOT REMOVE THIS FUNCTION - it makes buttons wide!
|
||||
This is the MAIN function that applies padding to ALL messages with keyboards.
|
||||
|
||||
Why we need this:
|
||||
- Telegram determines button width based on MESSAGE TEXT width
|
||||
- Short messages = narrow buttons
|
||||
- Wide messages (with invisible padding) = WIDE buttons like BotFather
|
||||
|
||||
Args:
|
||||
message: Multi-line message text
|
||||
target_width: Target character count per line (default 50)
|
||||
force: Force padding even if ENABLE_BUTTON_PADDING is False
|
||||
|
||||
Returns:
|
||||
Message with all lines padded (if enabled or forced)
|
||||
"""
|
||||
# ⚠️ DO NOT REMOVE: Check if padding is enabled
|
||||
if not ENABLE_BUTTON_PADDING and not force:
|
||||
return message
|
||||
|
||||
# ⚠️ DO NOT REMOVE: Apply padding to each line
|
||||
lines = message.split('\n')
|
||||
padded_lines = [_pad_line_for_wide_buttons(line, target_width) for line in lines]
|
||||
return '\n'.join(padded_lines)
|
||||
|
||||
|
||||
def format_response_with_company(
|
||||
content: str,
|
||||
company_name: Optional[str] = None,
|
||||
apply_padding: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Format a response with company name at the top (simplified format).
|
||||
|
||||
⚠️ IMPORTANT: Applies padding by default to make buttons WIDE!
|
||||
|
||||
Format:
|
||||
Company Name
|
||||
|
||||
[Content]
|
||||
|
||||
Args:
|
||||
content: The main content text
|
||||
company_name: Company name to show at top (if None, just returns content)
|
||||
apply_padding: Whether to apply invisible padding for wider buttons (default TRUE)
|
||||
|
||||
Returns:
|
||||
Formatted response with company name header AND padding for wide buttons
|
||||
"""
|
||||
if company_name:
|
||||
message = f"{company_name}\n\n{content}"
|
||||
else:
|
||||
message = content
|
||||
|
||||
# ⚠️ DO NOT REMOVE: Apply padding to make inline keyboard buttons WIDE!
|
||||
# Without this, buttons become narrow like before
|
||||
if apply_padding:
|
||||
message = pad_message_for_wide_buttons(message)
|
||||
|
||||
return message
|
||||
|
||||
|
||||
def get_menu_message(
|
||||
company_name: Optional[str] = None,
|
||||
company_cui: Optional[str] = None,
|
||||
apply_padding: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Get the menu message text with company details (simplified format).
|
||||
|
||||
⚠️ IMPORTANT: Applies padding by default to make menu buttons WIDE!
|
||||
|
||||
Format without labels - just values:
|
||||
- Line 1: Company name
|
||||
- Line 2: CUI
|
||||
- Line 3: Accounting month
|
||||
|
||||
Args:
|
||||
company_name: Active company name
|
||||
company_cui: Company fiscal code (CUI)
|
||||
apply_padding: Whether to apply invisible padding for wider buttons (default TRUE)
|
||||
|
||||
Returns:
|
||||
Formatted message text for menu WITH padding for wide buttons
|
||||
"""
|
||||
if company_name:
|
||||
# Simplified format: just values, no labels
|
||||
message = f"{company_name}\n"
|
||||
if company_cui:
|
||||
message += f"{company_cui}\n"
|
||||
message += f"{_get_current_month_ro()}"
|
||||
else:
|
||||
# No company selected - just prompt
|
||||
message = "Selectează o companie pentru a continua"
|
||||
|
||||
# ⚠️ DO NOT REMOVE: Apply padding to make inline keyboard buttons WIDE!
|
||||
# This makes buttons look like BotFather (wide, not narrow)
|
||||
if apply_padding:
|
||||
message = pad_message_for_wide_buttons(message)
|
||||
|
||||
return message
|
||||
|
||||
|
||||
def create_main_menu(
|
||||
company_name: Optional[str] = None,
|
||||
company_cui: Optional[str] = None,
|
||||
is_authenticated: bool = True,
|
||||
cache_enabled: Optional[bool] = None
|
||||
) -> InlineKeyboardMarkup:
|
||||
"""
|
||||
Create main menu keyboard (Level 1) with financial options.
|
||||
|
||||
Layout: Full-width buttons with company selection at top
|
||||
|
||||
Args:
|
||||
company_name: Active company name, or None if no company selected
|
||||
company_cui: Company fiscal code (CUI), or None
|
||||
is_authenticated: Whether user is authenticated (affects Login/Logout button)
|
||||
cache_enabled: Cache state for user (True=ON, False=OFF, None=unknown)
|
||||
|
||||
Returns:
|
||||
InlineKeyboardMarkup with main menu buttons
|
||||
"""
|
||||
keyboard = []
|
||||
|
||||
# Only show financial menu if authenticated
|
||||
if is_authenticated:
|
||||
# Row 1: Company selection (full width, single line - InlineKeyboardButton doesn't support multiline)
|
||||
if company_name:
|
||||
# Short company name for button (CUI and month will be shown in message text)
|
||||
# Truncate long names to fit in button
|
||||
max_length = 35
|
||||
display_name = company_name if len(company_name) <= max_length else company_name[:max_length-3] + "..."
|
||||
|
||||
keyboard.append([
|
||||
InlineKeyboardButton(
|
||||
f"{display_name}",
|
||||
callback_data="menu:select_company"
|
||||
)
|
||||
])
|
||||
else:
|
||||
keyboard.append([
|
||||
InlineKeyboardButton(
|
||||
"Selectare Companie",
|
||||
callback_data="menu:select_company"
|
||||
)
|
||||
])
|
||||
|
||||
# Rows 2-4: Financial options (2 buttons per row, made wide by message text padding)
|
||||
keyboard.extend([
|
||||
[
|
||||
InlineKeyboardButton("Sold Companie", callback_data="menu:sold"),
|
||||
InlineKeyboardButton("Trezorerie Casa", callback_data="menu:casa")
|
||||
],
|
||||
[
|
||||
InlineKeyboardButton("Trezorerie Banca", callback_data="menu:banca"),
|
||||
InlineKeyboardButton("Sold Clienti", callback_data="menu:clienti")
|
||||
],
|
||||
[
|
||||
InlineKeyboardButton("Sold Furnizori", callback_data="menu:furnizori"),
|
||||
InlineKeyboardButton("Evolutie Incasari", callback_data="menu:evolutie")
|
||||
]
|
||||
])
|
||||
|
||||
# Row 5: Cache options (2 buttons per row, only if authenticated)
|
||||
if is_authenticated:
|
||||
# Dynamic cache toggle button showing current state
|
||||
if cache_enabled is None:
|
||||
cache_button_text = "Toggle Cache"
|
||||
elif cache_enabled:
|
||||
cache_button_text = "Cache: ON"
|
||||
else:
|
||||
cache_button_text = "Cache: OFF"
|
||||
|
||||
keyboard.append([
|
||||
InlineKeyboardButton(cache_button_text, callback_data="menu:togglecache"),
|
||||
InlineKeyboardButton("Clear Cache", callback_data="menu:clearcache")
|
||||
])
|
||||
|
||||
# Row 6: Help/Logout buttons (authenticated) or Login button (non-authenticated)
|
||||
if is_authenticated:
|
||||
keyboard.append([
|
||||
InlineKeyboardButton("Help", callback_data="action:help"),
|
||||
InlineKeyboardButton("Logout", callback_data="action:logout")
|
||||
])
|
||||
else:
|
||||
keyboard.append([
|
||||
InlineKeyboardButton("Login", callback_data="action:login")
|
||||
])
|
||||
|
||||
return InlineKeyboardMarkup(keyboard)
|
||||
|
||||
|
||||
def create_action_buttons(current_view: str, show_export: bool = True, show_back: bool = False, show_refresh: bool = True) -> InlineKeyboardMarkup:
|
||||
"""
|
||||
Create action buttons for responses (Refresh, Export, Back, Menu).
|
||||
|
||||
Layout (buttons made wide by message text padding):
|
||||
[Refresh] [Export] (if show_refresh=True and show_export=True)
|
||||
[Refresh] (if show_refresh=True and show_export=False)
|
||||
[Înapoi] (if show_back=True, full width)
|
||||
[Menu] (full width, always shown)
|
||||
|
||||
Args:
|
||||
current_view: View identifier for refresh callback (e.g., "sold", "clienti")
|
||||
show_export: Whether to show Export button
|
||||
show_back: Whether to show Back button to list
|
||||
show_refresh: Whether to show Refresh button
|
||||
|
||||
Returns:
|
||||
InlineKeyboardMarkup with action buttons
|
||||
"""
|
||||
keyboard = []
|
||||
|
||||
# Row 1: Refresh and optionally Export (only if show_refresh is True)
|
||||
if show_refresh:
|
||||
if show_export:
|
||||
keyboard.append([
|
||||
InlineKeyboardButton("Refresh", callback_data=f"action:refresh:{current_view}"),
|
||||
InlineKeyboardButton("Export", callback_data=f"action:export:{current_view}")
|
||||
])
|
||||
else:
|
||||
keyboard.append([
|
||||
InlineKeyboardButton("Refresh", callback_data=f"action:refresh:{current_view}")
|
||||
])
|
||||
|
||||
# Row 2: Back to List (if show_back is True)
|
||||
if show_back:
|
||||
# Determine back callback based on current view
|
||||
# ✅ FIX: Handle detail views (client_detail:name, supplier_detail:name)
|
||||
if current_view.startswith("client_detail:"):
|
||||
back_callback = "menu:clienti" # Back to client list
|
||||
elif current_view.startswith("supplier_detail:"):
|
||||
back_callback = "menu:furnizori" # Back to supplier list
|
||||
elif current_view == "clienti":
|
||||
back_callback = "clients_page:0" # Match handlers.py:1689
|
||||
elif current_view == "furnizori":
|
||||
back_callback = "suppliers_page:0" # Match handlers.py:1721
|
||||
else:
|
||||
back_callback = "action:menu" # Fallback to menu
|
||||
|
||||
keyboard.append([
|
||||
InlineKeyboardButton("« Înapoi", callback_data=back_callback)
|
||||
])
|
||||
|
||||
# Row 3: Back to Menu (full width)
|
||||
keyboard.append([
|
||||
InlineKeyboardButton("Meniu Principal", callback_data="action:menu")
|
||||
])
|
||||
|
||||
return InlineKeyboardMarkup(keyboard)
|
||||
|
||||
|
||||
def create_client_list_keyboard(clients: List[Dict], max_items: int = 10, page: int = 0) -> InlineKeyboardMarkup:
|
||||
"""
|
||||
Create client list keyboard (Level 2) with client buttons and pagination.
|
||||
|
||||
Layout: 1 column for clients, pagination controls, 2 columns for navigation
|
||||
|
||||
Args:
|
||||
clients: List of client dicts with keys: id, name, balance
|
||||
max_items: Maximum number of clients per page (default: 10)
|
||||
page: Current page number (0-indexed)
|
||||
|
||||
Returns:
|
||||
InlineKeyboardMarkup with client list buttons and pagination
|
||||
"""
|
||||
keyboard = []
|
||||
|
||||
# Sort clients alphabetically by name
|
||||
sorted_clients = sorted(clients, key=lambda x: x.get('name', '').lower())
|
||||
|
||||
# Calculate pagination
|
||||
total_clients = len(sorted_clients)
|
||||
total_pages = (total_clients + max_items - 1) // max_items # Ceiling division
|
||||
start_idx = page * max_items
|
||||
end_idx = min(start_idx + max_items, total_clients)
|
||||
|
||||
# Display clients for current page
|
||||
display_clients = sorted_clients[start_idx:end_idx]
|
||||
|
||||
# Add client buttons (1 per row)
|
||||
for client in display_clients:
|
||||
client_name = client.get('name', 'N/A')
|
||||
balance = client.get('balance', 0)
|
||||
|
||||
# Format balance with thousands separator
|
||||
balance_str = f"{balance:,.0f}" if balance else "0"
|
||||
|
||||
button_text = f"{client_name} - {balance_str} RON"
|
||||
|
||||
# Limit callback_data to 64 bytes (Telegram limit)
|
||||
# Use only first 40 chars of name to stay within limit
|
||||
safe_name = client_name[:40] if len(client_name) > 40 else client_name
|
||||
|
||||
keyboard.append([
|
||||
InlineKeyboardButton(
|
||||
button_text,
|
||||
callback_data=f"details:client:{safe_name}:0" # name:page
|
||||
)
|
||||
])
|
||||
|
||||
# Pagination controls (only if more than one page)
|
||||
if total_pages > 1:
|
||||
nav_buttons = []
|
||||
|
||||
# Previous button
|
||||
if page > 0:
|
||||
nav_buttons.append(
|
||||
InlineKeyboardButton("< Anterior", callback_data=f"clients_page:{page-1}")
|
||||
)
|
||||
|
||||
# Page indicator (non-clickable)
|
||||
nav_buttons.append(
|
||||
InlineKeyboardButton(f"Pagina {page+1}/{total_pages}", callback_data="noop")
|
||||
)
|
||||
|
||||
# Next button
|
||||
if page < total_pages - 1:
|
||||
nav_buttons.append(
|
||||
InlineKeyboardButton("Următor >", callback_data=f"clients_page:{page+1}")
|
||||
)
|
||||
|
||||
keyboard.append(nav_buttons)
|
||||
|
||||
# Navigation row: Back button only
|
||||
keyboard.append([
|
||||
InlineKeyboardButton("< Înapoi", callback_data="action:menu")
|
||||
])
|
||||
|
||||
return InlineKeyboardMarkup(keyboard)
|
||||
|
||||
|
||||
def create_supplier_list_keyboard(suppliers: List[Dict], max_items: int = 10, page: int = 0) -> InlineKeyboardMarkup:
|
||||
"""
|
||||
Create supplier list keyboard (Level 2) with supplier buttons and pagination.
|
||||
|
||||
Layout: 1 column for suppliers, pagination controls, 2 columns for navigation
|
||||
|
||||
Args:
|
||||
suppliers: List of supplier dicts with keys: id, name, balance
|
||||
max_items: Maximum number of suppliers per page (default: 10)
|
||||
page: Current page number (0-indexed)
|
||||
|
||||
Returns:
|
||||
InlineKeyboardMarkup with supplier list buttons and pagination
|
||||
"""
|
||||
keyboard = []
|
||||
|
||||
# Sort suppliers alphabetically by name
|
||||
sorted_suppliers = sorted(suppliers, key=lambda x: x.get('name', '').lower())
|
||||
|
||||
# Calculate pagination
|
||||
total_suppliers = len(sorted_suppliers)
|
||||
total_pages = (total_suppliers + max_items - 1) // max_items # Ceiling division
|
||||
start_idx = page * max_items
|
||||
end_idx = min(start_idx + max_items, total_suppliers)
|
||||
|
||||
# Display suppliers for current page
|
||||
display_suppliers = sorted_suppliers[start_idx:end_idx]
|
||||
|
||||
# Add supplier buttons (1 per row)
|
||||
for supplier in display_suppliers:
|
||||
supplier_name = supplier.get('name', 'N/A')
|
||||
balance = supplier.get('balance', 0)
|
||||
|
||||
# Format balance with thousands separator
|
||||
balance_str = f"{balance:,.0f}" if balance else "0"
|
||||
|
||||
button_text = f"{supplier_name} - {balance_str} RON"
|
||||
|
||||
# Limit callback_data to 64 bytes (Telegram limit)
|
||||
# Use only first 40 chars of name to stay within limit
|
||||
safe_name = supplier_name[:40] if len(supplier_name) > 40 else supplier_name
|
||||
|
||||
keyboard.append([
|
||||
InlineKeyboardButton(
|
||||
button_text,
|
||||
callback_data=f"details:supplier:{safe_name}:0" # name:page
|
||||
)
|
||||
])
|
||||
|
||||
# Pagination controls (only if more than one page)
|
||||
if total_pages > 1:
|
||||
nav_buttons = []
|
||||
|
||||
# Previous button
|
||||
if page > 0:
|
||||
nav_buttons.append(
|
||||
InlineKeyboardButton("< Anterior", callback_data=f"suppliers_page:{page-1}")
|
||||
)
|
||||
|
||||
# Page indicator (non-clickable)
|
||||
nav_buttons.append(
|
||||
InlineKeyboardButton(f"Pagina {page+1}/{total_pages}", callback_data="noop")
|
||||
)
|
||||
|
||||
# Next button
|
||||
if page < total_pages - 1:
|
||||
nav_buttons.append(
|
||||
InlineKeyboardButton("Următor >", callback_data=f"suppliers_page:{page+1}")
|
||||
)
|
||||
|
||||
keyboard.append(nav_buttons)
|
||||
|
||||
# Navigation row: Back button only
|
||||
keyboard.append([
|
||||
InlineKeyboardButton("< Înapoi", callback_data="action:menu")
|
||||
])
|
||||
|
||||
return InlineKeyboardMarkup(keyboard)
|
||||
|
||||
|
||||
def create_invoice_list_keyboard(
|
||||
invoices: List[Dict],
|
||||
partner_type: str,
|
||||
partner_name: str,
|
||||
max_items: int = 10,
|
||||
page: int = 0
|
||||
) -> InlineKeyboardMarkup:
|
||||
"""
|
||||
Create invoice list keyboard (Level 3) with invoice buttons and pagination.
|
||||
|
||||
Layout: 1 column for invoices, pagination controls, 2 columns for navigation
|
||||
|
||||
Args:
|
||||
invoices: List of invoice dicts with keys: id, number, amount, status
|
||||
partner_type: "CLIENTI" or "FURNIZORI"
|
||||
partner_name: Client/supplier name (for back navigation)
|
||||
max_items: Maximum number of invoices per page (default: 10)
|
||||
page: Current page number (0-indexed)
|
||||
|
||||
Returns:
|
||||
InlineKeyboardMarkup with invoice list buttons and pagination
|
||||
"""
|
||||
keyboard = []
|
||||
|
||||
# Limit partner_name to 30 chars for Telegram callback_data limit (64 bytes)
|
||||
safe_partner_name = partner_name[:30] if len(partner_name) > 30 else partner_name
|
||||
|
||||
# Calculate pagination
|
||||
total_invoices = len(invoices)
|
||||
total_pages = (total_invoices + max_items - 1) // max_items # Ceiling division
|
||||
start_idx = page * max_items
|
||||
end_idx = min(start_idx + max_items, total_invoices)
|
||||
|
||||
# Display invoices for current page
|
||||
display_invoices = invoices[start_idx:end_idx]
|
||||
|
||||
# Add invoice buttons (1 per row)
|
||||
for invoice in display_invoices:
|
||||
invoice_id = invoice.get('id', 0)
|
||||
invoice_number = invoice.get('number', 'N/A')
|
||||
amount = invoice.get('amount', 0)
|
||||
status = invoice.get('status', 'unknown')
|
||||
|
||||
# Format amount with thousands separator
|
||||
amount_str = f"{amount:,.0f}" if amount else "0"
|
||||
|
||||
# Status text indicator (no emojis)
|
||||
status_text = "[NEPLATIT]" if status in ['unpaid', 'overdue'] else "[PLATIT]"
|
||||
|
||||
button_text = f"{status_text} {invoice_number} - {amount_str} RON"
|
||||
keyboard.append([
|
||||
InlineKeyboardButton(
|
||||
button_text,
|
||||
callback_data=f"invoice:{partner_type}:{invoice_id}"
|
||||
)
|
||||
])
|
||||
|
||||
# Pagination controls (only if more than one page)
|
||||
if total_pages > 1:
|
||||
nav_buttons = []
|
||||
|
||||
# Previous button
|
||||
if page > 0:
|
||||
nav_buttons.append(
|
||||
InlineKeyboardButton("< Anterior", callback_data=f"invoices_page:{partner_type}:{safe_partner_name}:{page-1}")
|
||||
)
|
||||
|
||||
# Page indicator (non-clickable)
|
||||
nav_buttons.append(
|
||||
InlineKeyboardButton(f"Pagina {page+1}/{total_pages}", callback_data="noop")
|
||||
)
|
||||
|
||||
# Next button
|
||||
if page < total_pages - 1:
|
||||
nav_buttons.append(
|
||||
InlineKeyboardButton("Următor >", callback_data=f"invoices_page:{partner_type}:{safe_partner_name}:{page+1}")
|
||||
)
|
||||
|
||||
keyboard.append(nav_buttons)
|
||||
|
||||
# Navigation row: Back and Export (2 buttons per row)
|
||||
back_target = "clienti" if partner_type == "CLIENTI" else "furnizori"
|
||||
keyboard.append([
|
||||
InlineKeyboardButton("< Înapoi", callback_data=f"nav:back:{back_target}"),
|
||||
InlineKeyboardButton("Export", callback_data=f"action:export:{partner_type.lower()}")
|
||||
])
|
||||
|
||||
return InlineKeyboardMarkup(keyboard)
|
||||
|
||||
|
||||
def create_navigation_buttons(back_to: str) -> InlineKeyboardMarkup:
|
||||
"""
|
||||
Create simple navigation buttons (just Back button).
|
||||
|
||||
Args:
|
||||
back_to: Target location identifier (e.g., "menu", "clienti", "furnizori")
|
||||
|
||||
Returns:
|
||||
InlineKeyboardMarkup with navigation button
|
||||
"""
|
||||
keyboard = [
|
||||
[
|
||||
InlineKeyboardButton(
|
||||
f"< Înapoi la {back_to}",
|
||||
callback_data=f"nav:back:{back_to}"
|
||||
)
|
||||
]
|
||||
]
|
||||
|
||||
return InlineKeyboardMarkup(keyboard)
|
||||
316
backend/modules/telegram/bot_main.py
Normal file
316
backend/modules/telegram/bot_main.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""
|
||||
Main entry point for ROA2WEB Telegram Bot
|
||||
|
||||
This bot provides access to the ROA2WEB ERP system through Telegram
|
||||
using direct command handlers for financial data queries.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
import uvicorn
|
||||
from threading import Thread
|
||||
|
||||
# ============================================================================
|
||||
# LOAD ENVIRONMENT VARIABLES FIRST - BEFORE ANY APP IMPORTS
|
||||
# ============================================================================
|
||||
# This ensures all modules can access environment variables at import time
|
||||
env_path = Path(__file__).parent.parent / '.env'
|
||||
load_dotenv(env_path)
|
||||
|
||||
# Telegram imports
|
||||
from telegram.ext import (
|
||||
Application,
|
||||
CommandHandler,
|
||||
CallbackQueryHandler,
|
||||
MessageHandler,
|
||||
filters
|
||||
)
|
||||
|
||||
# Import database initialization
|
||||
from backend.modules.telegram.db import (
|
||||
init_database,
|
||||
cleanup_expired_codes,
|
||||
cleanup_expired_sessions,
|
||||
cleanup_expired_email_codes
|
||||
)
|
||||
|
||||
# Import bot handlers
|
||||
from backend.modules.telegram.bot.handlers import (
|
||||
start_command,
|
||||
help_command,
|
||||
clear_command,
|
||||
companies_command,
|
||||
unlink_command,
|
||||
selectcompany_command,
|
||||
dashboard_command,
|
||||
sold_command,
|
||||
facturi_command,
|
||||
trezorerie_command,
|
||||
# FAZA 3: New command handlers with button interface
|
||||
menu_command,
|
||||
trezorerie_casa_command,
|
||||
trezorerie_banca_command,
|
||||
clienti_command,
|
||||
furnizori_command,
|
||||
evolutie_command,
|
||||
# FAZA 6: Cache management commands
|
||||
clearcache_command,
|
||||
togglecache_command,
|
||||
# Text message handlers
|
||||
handle_text_message,
|
||||
# FAZA 4: Callback and error handlers
|
||||
button_callback,
|
||||
error_handler
|
||||
)
|
||||
|
||||
# Import email authentication handler
|
||||
from backend.modules.telegram.bot.email_handlers import email_login_handler
|
||||
|
||||
# Import internal API
|
||||
from backend.modules.telegram.internal_api import internal_api
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Environment variables (already loaded above)
|
||||
TELEGRAM_BOT_TOKEN = os.getenv('TELEGRAM_BOT_TOKEN')
|
||||
BACKEND_URL = os.getenv('BACKEND_URL', 'http://localhost:8000')
|
||||
INTERNAL_API_PORT = int(os.getenv('INTERNAL_API_PORT', '8002'))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TELEGRAM BOT SETUP
|
||||
# ============================================================================
|
||||
|
||||
def create_telegram_application() -> Application:
|
||||
"""
|
||||
Create and configure the Telegram bot application.
|
||||
|
||||
Returns:
|
||||
Application: Configured Telegram application
|
||||
"""
|
||||
logger.info("Creating Telegram application...")
|
||||
|
||||
# Create application
|
||||
application = Application.builder().token(TELEGRAM_BOT_TOKEN).build()
|
||||
|
||||
# Register email authentication conversation handler (must be before other handlers)
|
||||
application.add_handler(email_login_handler)
|
||||
|
||||
# Register essential command handlers
|
||||
application.add_handler(CommandHandler("start", start_command))
|
||||
application.add_handler(CommandHandler("menu", menu_command))
|
||||
application.add_handler(CommandHandler("help", help_command))
|
||||
application.add_handler(CommandHandler("unlink", unlink_command))
|
||||
|
||||
# =========================================================================
|
||||
# LEGACY COMMAND HANDLERS (kept for backwards compatibility, hidden from help)
|
||||
# =========================================================================
|
||||
# NOTE: These commands are redundant with the button interface.
|
||||
# They're kept for users who already know them, but we push buttons in help.
|
||||
# Consider removing completely if migration is successful.
|
||||
|
||||
application.add_handler(CommandHandler("clear", clear_command))
|
||||
application.add_handler(CommandHandler("companies", companies_command))
|
||||
application.add_handler(CommandHandler("selectcompany", selectcompany_command))
|
||||
application.add_handler(CommandHandler("dashboard", dashboard_command))
|
||||
application.add_handler(CommandHandler("sold", sold_command))
|
||||
application.add_handler(CommandHandler("facturi", facturi_command))
|
||||
application.add_handler(CommandHandler("trezorerie", trezorerie_command))
|
||||
application.add_handler(CommandHandler("trezorerie_casa", trezorerie_casa_command))
|
||||
application.add_handler(CommandHandler("trezorerie_banca", trezorerie_banca_command))
|
||||
application.add_handler(CommandHandler("clienti", clienti_command))
|
||||
application.add_handler(CommandHandler("furnizori", furnizori_command))
|
||||
application.add_handler(CommandHandler("evolutie", evolutie_command))
|
||||
|
||||
# FAZA 6: Cache management commands
|
||||
application.add_handler(CommandHandler("clearcache", clearcache_command))
|
||||
application.add_handler(CommandHandler("togglecache", togglecache_command))
|
||||
|
||||
# Text message handler (for direct code input and future NLP)
|
||||
# IMPORTANT: This must be registered BEFORE CallbackQueryHandler
|
||||
# filters.TEXT & ~filters.COMMAND ensures we only process non-command text messages
|
||||
application.add_handler(MessageHandler(
|
||||
filters.TEXT & ~filters.COMMAND,
|
||||
handle_text_message
|
||||
))
|
||||
|
||||
# FAZA 4: Register callback query handler (for inline buttons)
|
||||
application.add_handler(CallbackQueryHandler(button_callback))
|
||||
|
||||
# Register error handler
|
||||
application.add_error_handler(error_handler)
|
||||
|
||||
logger.info("Telegram application configured with all handlers")
|
||||
|
||||
return application
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# INTERNAL API SERVER
|
||||
# ============================================================================
|
||||
|
||||
def run_internal_api():
|
||||
"""
|
||||
Run the internal FastAPI server in a separate thread.
|
||||
|
||||
This API handles communication from the backend (saving auth codes).
|
||||
"""
|
||||
logger.info(f"Starting internal API on port {INTERNAL_API_PORT}...")
|
||||
|
||||
uvicorn.run(
|
||||
internal_api,
|
||||
host="0.0.0.0",
|
||||
port=INTERNAL_API_PORT,
|
||||
log_level="info"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# STARTUP/SHUTDOWN
|
||||
# ============================================================================
|
||||
|
||||
async def startup():
|
||||
"""
|
||||
Initialize the bot application on startup.
|
||||
"""
|
||||
logger.info("🚀 ROA2WEB Telegram Bot - Starting up...")
|
||||
|
||||
# Initialize database
|
||||
try:
|
||||
logger.info("Initializing SQLite database...")
|
||||
await init_database()
|
||||
logger.info("✅ Database initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to initialize database: {e}")
|
||||
raise
|
||||
|
||||
# Cleanup expired data
|
||||
try:
|
||||
logger.info("Cleaning up expired data...")
|
||||
expired_codes = await cleanup_expired_codes()
|
||||
expired_sessions = await cleanup_expired_sessions()
|
||||
expired_email_codes = await cleanup_expired_email_codes()
|
||||
logger.info(f"✅ Cleanup complete: {expired_codes} codes, {expired_sessions} sessions, {expired_email_codes} email codes removed")
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ Cleanup failed (non-critical): {e}")
|
||||
|
||||
logger.info("✅ Startup complete")
|
||||
|
||||
|
||||
async def shutdown():
|
||||
"""
|
||||
Clean up resources on shutdown.
|
||||
"""
|
||||
logger.info("👋 ROA2WEB Telegram Bot - Shutting down...")
|
||||
logger.info("✅ Shutdown complete")
|
||||
|
||||
|
||||
async def scheduled_cleanup():
|
||||
"""
|
||||
Background task to periodically clean up expired data.
|
||||
Runs every hour to remove expired auth codes, sessions, and email codes.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(3600) # Sleep for 1 hour
|
||||
logger.info("🧹 Running scheduled cleanup...")
|
||||
expired_codes = await cleanup_expired_codes()
|
||||
expired_sessions = await cleanup_expired_sessions()
|
||||
expired_email_codes = await cleanup_expired_email_codes()
|
||||
logger.info(f"✅ Scheduled cleanup: {expired_codes} codes, {expired_sessions} sessions, {expired_email_codes} email codes removed")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error in scheduled cleanup: {e}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# MAIN APPLICATION
|
||||
# ============================================================================
|
||||
|
||||
async def main():
|
||||
"""
|
||||
Main application entry point.
|
||||
|
||||
Runs both the Telegram bot and internal API server concurrently.
|
||||
"""
|
||||
try:
|
||||
# Run startup
|
||||
await startup()
|
||||
|
||||
# Create Telegram application
|
||||
telegram_app = create_telegram_application()
|
||||
|
||||
# Start internal API in a separate thread
|
||||
api_thread = Thread(target=run_internal_api, daemon=True)
|
||||
api_thread.start()
|
||||
logger.info(f"✅ Internal API started on port {INTERNAL_API_PORT}")
|
||||
|
||||
# Start scheduled cleanup task in background
|
||||
cleanup_task = asyncio.create_task(scheduled_cleanup())
|
||||
logger.info("✅ Scheduled cleanup task started")
|
||||
|
||||
# Initialize and start Telegram bot
|
||||
logger.info("🤖 Starting Telegram bot polling...")
|
||||
await telegram_app.initialize()
|
||||
await telegram_app.start()
|
||||
await telegram_app.updater.start_polling(drop_pending_updates=True)
|
||||
|
||||
logger.info("✅ Telegram bot is now running and polling for updates")
|
||||
logger.info(f"📱 Bot ready to receive messages at @{(await telegram_app.bot.get_me()).username}")
|
||||
logger.info("🎯 Bot is operational with direct command handlers!")
|
||||
|
||||
# Keep running until interrupted
|
||||
await asyncio.Event().wait()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("⚠️ Received interrupt signal")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Fatal error: {e}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
# Stop Telegram bot gracefully
|
||||
try:
|
||||
if 'telegram_app' in locals():
|
||||
logger.info("Stopping Telegram bot...")
|
||||
await telegram_app.updater.stop()
|
||||
await telegram_app.stop()
|
||||
await telegram_app.shutdown()
|
||||
logger.info("✅ Telegram bot stopped")
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping Telegram bot: {e}")
|
||||
|
||||
await shutdown()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ENTRY POINT
|
||||
# ============================================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Check required environment variables
|
||||
if not os.getenv('TELEGRAM_BOT_TOKEN'):
|
||||
logger.error("❌ TELEGRAM_BOT_TOKEN is required")
|
||||
logger.error("Please set it in .env file")
|
||||
exit(1)
|
||||
|
||||
# Display startup banner
|
||||
logger.info("=" * 60)
|
||||
logger.info(" ROA2WEB TELEGRAM BOT")
|
||||
logger.info(" Financial ERP Assistant with Direct Commands")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Run the main application
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
logger.info("👋 Application stopped by user")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Application failed: {e}", exc_info=True)
|
||||
exit(1)
|
||||
86
backend/modules/telegram/db/__init__.py
Normal file
86
backend/modules/telegram/db/__init__.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""
|
||||
Database module for Telegram Bot
|
||||
|
||||
Provides SQLite database operations for:
|
||||
- User management and Oracle account linking
|
||||
- Authentication code management
|
||||
- Conversation session management
|
||||
"""
|
||||
|
||||
from .database import (
|
||||
init_database,
|
||||
get_db_connection,
|
||||
cleanup_expired_codes,
|
||||
cleanup_expired_sessions,
|
||||
cleanup_expired_email_codes,
|
||||
get_database_stats,
|
||||
DB_PATH,
|
||||
)
|
||||
|
||||
from .operations import (
|
||||
# User operations
|
||||
create_or_update_user,
|
||||
get_user,
|
||||
link_user_to_oracle,
|
||||
update_user_tokens,
|
||||
update_user_last_active,
|
||||
is_user_linked,
|
||||
is_user_authenticated,
|
||||
# Auth code operations
|
||||
create_auth_code,
|
||||
get_auth_code,
|
||||
verify_and_use_auth_code,
|
||||
get_pending_codes_for_user,
|
||||
# Email auth code operations
|
||||
get_pending_email_code,
|
||||
create_email_auth_code,
|
||||
get_email_auth_code,
|
||||
increment_failed_attempts,
|
||||
mark_email_code_used,
|
||||
delete_user_email_codes,
|
||||
# Session operations
|
||||
create_session,
|
||||
get_session,
|
||||
get_user_active_session,
|
||||
update_session_state,
|
||||
delete_session,
|
||||
delete_user_sessions,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Database setup
|
||||
'init_database',
|
||||
'get_db_connection',
|
||||
'cleanup_expired_codes',
|
||||
'cleanup_expired_sessions',
|
||||
'cleanup_expired_email_codes',
|
||||
'get_database_stats',
|
||||
'DB_PATH',
|
||||
# User operations
|
||||
'create_or_update_user',
|
||||
'get_user',
|
||||
'link_user_to_oracle',
|
||||
'update_user_tokens',
|
||||
'update_user_last_active',
|
||||
'is_user_linked',
|
||||
'is_user_authenticated',
|
||||
# Auth code operations
|
||||
'create_auth_code',
|
||||
'get_auth_code',
|
||||
'verify_and_use_auth_code',
|
||||
'get_pending_codes_for_user',
|
||||
# Email auth code operations
|
||||
'get_pending_email_code',
|
||||
'create_email_auth_code',
|
||||
'get_email_auth_code',
|
||||
'increment_failed_attempts',
|
||||
'mark_email_code_used',
|
||||
'delete_user_email_codes',
|
||||
# Session operations
|
||||
'create_session',
|
||||
'get_session',
|
||||
'get_user_active_session',
|
||||
'update_session_state',
|
||||
'delete_session',
|
||||
'delete_user_sessions',
|
||||
]
|
||||
310
backend/modules/telegram/db/database.py
Normal file
310
backend/modules/telegram/db/database.py
Normal file
@@ -0,0 +1,310 @@
|
||||
"""
|
||||
SQLite Database Setup for Telegram Bot
|
||||
|
||||
This module handles database connection, initialization, and schema creation.
|
||||
Uses aiosqlite for async SQLite operations.
|
||||
"""
|
||||
|
||||
import aiosqlite
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Database file location
|
||||
DB_DIR = Path(__file__).parent.parent.parent / "data"
|
||||
DB_PATH = DB_DIR / "telegram_bot.db"
|
||||
|
||||
|
||||
async def get_db_connection() -> aiosqlite.Connection:
|
||||
"""
|
||||
Get a database connection.
|
||||
|
||||
Returns:
|
||||
aiosqlite.Connection: Database connection
|
||||
"""
|
||||
conn = await aiosqlite.connect(DB_PATH)
|
||||
conn.row_factory = aiosqlite.Row # Enable column access by name
|
||||
return conn
|
||||
|
||||
|
||||
async def init_database() -> None:
|
||||
"""
|
||||
Initialize the database and create all tables.
|
||||
Safe to call multiple times - only creates tables if they don't exist.
|
||||
"""
|
||||
try:
|
||||
# Ensure data directory exists
|
||||
DB_DIR.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"Database directory: {DB_DIR}")
|
||||
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
# Enable foreign keys
|
||||
await db.execute("PRAGMA foreign_keys = ON")
|
||||
|
||||
# Create telegram_users table
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS telegram_users (
|
||||
telegram_user_id INTEGER PRIMARY KEY,
|
||||
username TEXT,
|
||||
first_name TEXT NOT NULL,
|
||||
last_name TEXT,
|
||||
oracle_username TEXT,
|
||||
jwt_token TEXT,
|
||||
jwt_refresh_token TEXT,
|
||||
token_expires_at TIMESTAMP,
|
||||
linked_at TIMESTAMP,
|
||||
last_active_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
is_active BOOLEAN DEFAULT 1
|
||||
)
|
||||
""")
|
||||
|
||||
# Create telegram_auth_codes table
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS telegram_auth_codes (
|
||||
code TEXT PRIMARY KEY,
|
||||
telegram_user_id INTEGER,
|
||||
oracle_username TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
used BOOLEAN DEFAULT 0,
|
||||
used_at TIMESTAMP,
|
||||
FOREIGN KEY (telegram_user_id) REFERENCES telegram_users(telegram_user_id)
|
||||
)
|
||||
""")
|
||||
|
||||
# Create telegram_sessions table
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS telegram_sessions (
|
||||
session_id TEXT PRIMARY KEY,
|
||||
telegram_user_id INTEGER NOT NULL,
|
||||
conversation_state TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
FOREIGN KEY (telegram_user_id) REFERENCES telegram_users(telegram_user_id)
|
||||
)
|
||||
""")
|
||||
|
||||
# Create email_auth_codes table (email-based authentication)
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS email_auth_codes (
|
||||
code TEXT PRIMARY KEY,
|
||||
email TEXT NOT NULL,
|
||||
oracle_username TEXT NOT NULL,
|
||||
telegram_user_id INTEGER NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
used INTEGER DEFAULT 0,
|
||||
used_at TIMESTAMP,
|
||||
failed_attempts INTEGER DEFAULT 0,
|
||||
FOREIGN KEY (telegram_user_id) REFERENCES telegram_users(telegram_user_id)
|
||||
)
|
||||
""")
|
||||
|
||||
# Create indexes for better query performance
|
||||
await db.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_auth_codes_telegram_user
|
||||
ON telegram_auth_codes(telegram_user_id)
|
||||
""")
|
||||
|
||||
await db.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_auth_codes_expires
|
||||
ON telegram_auth_codes(expires_at)
|
||||
""")
|
||||
|
||||
await db.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_telegram_user
|
||||
ON telegram_sessions(telegram_user_id)
|
||||
""")
|
||||
|
||||
await db.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_expires
|
||||
ON telegram_sessions(expires_at)
|
||||
""")
|
||||
|
||||
# Create indexes for email_auth_codes table
|
||||
await db.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_email_auth_email
|
||||
ON email_auth_codes(email)
|
||||
""")
|
||||
|
||||
await db.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_email_auth_telegram_user
|
||||
ON email_auth_codes(telegram_user_id)
|
||||
""")
|
||||
|
||||
await db.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_email_auth_expires
|
||||
ON email_auth_codes(expires_at)
|
||||
""")
|
||||
|
||||
await db.commit()
|
||||
logger.info("Database initialized successfully")
|
||||
|
||||
# Log table info
|
||||
cursor = await db.execute("""
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='table'
|
||||
ORDER BY name
|
||||
""")
|
||||
tables = await cursor.fetchall()
|
||||
logger.info(f"Existing tables: {[t[0] for t in tables]}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize database: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def cleanup_expired_codes() -> int:
|
||||
"""
|
||||
Delete expired authentication codes from the database.
|
||||
This should be called periodically (e.g., every hour).
|
||||
|
||||
Returns:
|
||||
int: Number of expired codes deleted
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("""
|
||||
DELETE FROM telegram_auth_codes
|
||||
WHERE expires_at < ?
|
||||
""", (datetime.now(),))
|
||||
|
||||
await db.commit()
|
||||
deleted = cursor.rowcount
|
||||
|
||||
if deleted > 0:
|
||||
logger.info(f"Cleaned up {deleted} expired auth codes")
|
||||
|
||||
return deleted
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cleanup expired codes: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
async def cleanup_expired_sessions() -> int:
|
||||
"""
|
||||
Delete expired sessions from the database.
|
||||
This should be called periodically (e.g., daily).
|
||||
|
||||
Returns:
|
||||
int: Number of expired sessions deleted
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("""
|
||||
DELETE FROM telegram_sessions
|
||||
WHERE expires_at < ?
|
||||
""", (datetime.now(),))
|
||||
|
||||
await db.commit()
|
||||
deleted = cursor.rowcount
|
||||
|
||||
if deleted > 0:
|
||||
logger.info(f"Cleaned up {deleted} expired sessions")
|
||||
|
||||
return deleted
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cleanup expired sessions: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
async def cleanup_expired_email_codes() -> int:
|
||||
"""
|
||||
Delete expired and old used email codes from the database.
|
||||
This should be called periodically (e.g., hourly).
|
||||
|
||||
Returns:
|
||||
int: Number of email codes deleted
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
# Delete expired codes or used codes older than 1 day
|
||||
cursor = await db.execute("""
|
||||
DELETE FROM email_auth_codes
|
||||
WHERE expires_at < ?
|
||||
OR (used = 1 AND used_at < ?)
|
||||
""", (
|
||||
datetime.now(),
|
||||
datetime.now() - timedelta(days=1)
|
||||
))
|
||||
|
||||
await db.commit()
|
||||
deleted = cursor.rowcount
|
||||
|
||||
if deleted > 0:
|
||||
logger.info(f"Cleaned up {deleted} expired/old email auth codes")
|
||||
|
||||
return deleted
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cleanup email codes: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
async def get_database_stats() -> dict:
|
||||
"""
|
||||
Get database statistics for monitoring.
|
||||
|
||||
Returns:
|
||||
dict: Database statistics
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
stats = {}
|
||||
|
||||
# Count users
|
||||
cursor = await db.execute("SELECT COUNT(*) FROM telegram_users")
|
||||
stats['total_users'] = (await cursor.fetchone())[0]
|
||||
|
||||
cursor = await db.execute(
|
||||
"SELECT COUNT(*) FROM telegram_users WHERE is_active = 1"
|
||||
)
|
||||
stats['active_users'] = (await cursor.fetchone())[0]
|
||||
|
||||
# Count pending codes
|
||||
cursor = await db.execute("""
|
||||
SELECT COUNT(*) FROM telegram_auth_codes
|
||||
WHERE used = 0 AND expires_at > ?
|
||||
""", (datetime.now(),))
|
||||
stats['pending_codes'] = (await cursor.fetchone())[0]
|
||||
|
||||
# Count active sessions
|
||||
cursor = await db.execute("""
|
||||
SELECT COUNT(*) FROM telegram_sessions
|
||||
WHERE expires_at > ?
|
||||
""", (datetime.now(),))
|
||||
stats['active_sessions'] = (await cursor.fetchone())[0]
|
||||
|
||||
# Database file size
|
||||
if DB_PATH.exists():
|
||||
stats['db_size_mb'] = DB_PATH.stat().st_size / (1024 * 1024)
|
||||
else:
|
||||
stats['db_size_mb'] = 0
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get database stats: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
# Export main functions
|
||||
__all__ = [
|
||||
'get_db_connection',
|
||||
'init_database',
|
||||
'cleanup_expired_codes',
|
||||
'cleanup_expired_sessions',
|
||||
'cleanup_expired_email_codes',
|
||||
'get_database_stats',
|
||||
'DB_PATH',
|
||||
]
|
||||
813
backend/modules/telegram/db/operations.py
Normal file
813
backend/modules/telegram/db/operations.py
Normal file
@@ -0,0 +1,813 @@
|
||||
"""
|
||||
Database Operations for Telegram Bot
|
||||
|
||||
This module provides CRUD operations for:
|
||||
- telegram_users: Telegram user management and Oracle account linking
|
||||
- telegram_auth_codes: Authentication code management
|
||||
- telegram_sessions: Conversation session management
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from .database import DB_PATH
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TELEGRAM USERS OPERATIONS
|
||||
# ============================================================================
|
||||
|
||||
async def create_or_update_user(
|
||||
telegram_user_id: int,
|
||||
username: Optional[str],
|
||||
first_name: str,
|
||||
last_name: Optional[str]
|
||||
) -> bool:
|
||||
"""
|
||||
Create or update a Telegram user record.
|
||||
|
||||
Args:
|
||||
telegram_user_id: Telegram user ID
|
||||
username: Telegram username (without @)
|
||||
first_name: User's first name
|
||||
last_name: User's last name
|
||||
|
||||
Returns:
|
||||
bool: True if successful
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
await db.execute("""
|
||||
INSERT INTO telegram_users (
|
||||
telegram_user_id, username, first_name, last_name, last_active_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(telegram_user_id) DO UPDATE SET
|
||||
username = excluded.username,
|
||||
first_name = excluded.first_name,
|
||||
last_name = excluded.last_name,
|
||||
last_active_at = excluded.last_active_at
|
||||
""", (telegram_user_id, username, first_name, last_name, datetime.now()))
|
||||
|
||||
await db.commit()
|
||||
logger.info(f"User {telegram_user_id} created/updated")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create/update user {telegram_user_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_user(telegram_user_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get user information by Telegram user ID.
|
||||
|
||||
Args:
|
||||
telegram_user_id: Telegram user ID
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: User data or None if not found
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("""
|
||||
SELECT * FROM telegram_users
|
||||
WHERE telegram_user_id = ?
|
||||
""", (telegram_user_id,))
|
||||
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return dict(row)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get user {telegram_user_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def link_user_to_oracle(
|
||||
telegram_user_id: int,
|
||||
oracle_username: str,
|
||||
jwt_token: str,
|
||||
jwt_refresh_token: str,
|
||||
token_expires_at: datetime
|
||||
) -> bool:
|
||||
"""
|
||||
Link a Telegram user to an Oracle account and save JWT tokens.
|
||||
|
||||
Args:
|
||||
telegram_user_id: Telegram user ID
|
||||
oracle_username: Oracle username
|
||||
jwt_token: JWT access token
|
||||
jwt_refresh_token: JWT refresh token
|
||||
token_expires_at: Token expiration timestamp
|
||||
|
||||
Returns:
|
||||
bool: True if successful
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
await db.execute("""
|
||||
UPDATE telegram_users
|
||||
SET oracle_username = ?,
|
||||
jwt_token = ?,
|
||||
jwt_refresh_token = ?,
|
||||
token_expires_at = ?,
|
||||
linked_at = ?,
|
||||
is_active = 1
|
||||
WHERE telegram_user_id = ?
|
||||
""", (
|
||||
oracle_username,
|
||||
jwt_token,
|
||||
jwt_refresh_token,
|
||||
token_expires_at,
|
||||
datetime.now(),
|
||||
telegram_user_id
|
||||
))
|
||||
|
||||
await db.commit()
|
||||
logger.info(f"User {telegram_user_id} linked to Oracle user {oracle_username}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to link user {telegram_user_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def update_user_tokens(
|
||||
telegram_user_id: int,
|
||||
jwt_token: str,
|
||||
jwt_refresh_token: str,
|
||||
token_expires_at: datetime
|
||||
) -> bool:
|
||||
"""
|
||||
Update JWT tokens for a user.
|
||||
|
||||
Args:
|
||||
telegram_user_id: Telegram user ID
|
||||
jwt_token: New JWT access token
|
||||
jwt_refresh_token: New JWT refresh token
|
||||
token_expires_at: New token expiration timestamp
|
||||
|
||||
Returns:
|
||||
bool: True if successful
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
await db.execute("""
|
||||
UPDATE telegram_users
|
||||
SET jwt_token = ?,
|
||||
jwt_refresh_token = ?,
|
||||
token_expires_at = ?
|
||||
WHERE telegram_user_id = ?
|
||||
""", (jwt_token, jwt_refresh_token, token_expires_at, telegram_user_id))
|
||||
|
||||
await db.commit()
|
||||
logger.info(f"Tokens updated for user {telegram_user_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update tokens for user {telegram_user_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def update_user_last_active(telegram_user_id: int) -> bool:
|
||||
"""
|
||||
Update the last active timestamp for a user.
|
||||
|
||||
Args:
|
||||
telegram_user_id: Telegram user ID
|
||||
|
||||
Returns:
|
||||
bool: True if successful
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
await db.execute("""
|
||||
UPDATE telegram_users
|
||||
SET last_active_at = ?
|
||||
WHERE telegram_user_id = ?
|
||||
""", (datetime.now(), telegram_user_id))
|
||||
|
||||
await db.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update last active for user {telegram_user_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def is_user_linked(telegram_user_id: int) -> bool:
|
||||
"""
|
||||
Check if a user is linked to an Oracle account.
|
||||
|
||||
Args:
|
||||
telegram_user_id: Telegram user ID
|
||||
|
||||
Returns:
|
||||
bool: True if user is linked
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("""
|
||||
SELECT oracle_username FROM telegram_users
|
||||
WHERE telegram_user_id = ? AND oracle_username IS NOT NULL
|
||||
""", (telegram_user_id,))
|
||||
|
||||
row = await cursor.fetchone()
|
||||
return row is not None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check if user {telegram_user_id} is linked: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def is_user_authenticated(telegram_user_id: int) -> bool:
|
||||
"""
|
||||
Check if a user is authenticated (linked and has valid token).
|
||||
|
||||
Args:
|
||||
telegram_user_id: Telegram user ID
|
||||
|
||||
Returns:
|
||||
bool: True if user is authenticated
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("""
|
||||
SELECT oracle_username, jwt_token, token_expires_at
|
||||
FROM telegram_users
|
||||
WHERE telegram_user_id = ?
|
||||
AND oracle_username IS NOT NULL
|
||||
AND jwt_token IS NOT NULL
|
||||
""", (telegram_user_id,))
|
||||
|
||||
row = await cursor.fetchone()
|
||||
if not row:
|
||||
return False
|
||||
|
||||
# Check if token is expired (with some buffer)
|
||||
if row[2]: # token_expires_at
|
||||
expires_at = datetime.fromisoformat(row[2])
|
||||
# Token should have at least 5 minutes remaining
|
||||
if expires_at < datetime.now() + timedelta(minutes=5):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check if user {telegram_user_id} is authenticated: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# AUTHENTICATION CODES OPERATIONS
|
||||
# ============================================================================
|
||||
|
||||
async def create_auth_code(
|
||||
code: str,
|
||||
telegram_user_id: int,
|
||||
oracle_username: str,
|
||||
expires_in_minutes: int = 5
|
||||
) -> bool:
|
||||
"""
|
||||
Create a new authentication code for linking.
|
||||
|
||||
Args:
|
||||
code: 8-character authentication code
|
||||
telegram_user_id: Telegram user ID
|
||||
oracle_username: Oracle username to link
|
||||
expires_in_minutes: Code expiration time in minutes (default: 5)
|
||||
|
||||
Returns:
|
||||
bool: True if successful
|
||||
"""
|
||||
try:
|
||||
expires_at = datetime.now() + timedelta(minutes=expires_in_minutes)
|
||||
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
await db.execute("""
|
||||
INSERT INTO telegram_auth_codes (
|
||||
code, telegram_user_id, oracle_username, expires_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""", (code, telegram_user_id, oracle_username, expires_at))
|
||||
|
||||
await db.commit()
|
||||
logger.info(f"Auth code created for user {telegram_user_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create auth code: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_auth_code(code: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get authentication code information.
|
||||
|
||||
Args:
|
||||
code: 8-character authentication code
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: Code data or None if not found
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("""
|
||||
SELECT * FROM telegram_auth_codes
|
||||
WHERE code = ?
|
||||
""", (code,))
|
||||
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return dict(row)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get auth code: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def verify_and_use_auth_code(code: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Verify an authentication code and mark it as used.
|
||||
|
||||
Args:
|
||||
code: 8-character authentication code
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: Code data if valid, None if invalid/expired
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
# Check if code exists, is not used, and not expired
|
||||
cursor = await db.execute("""
|
||||
SELECT * FROM telegram_auth_codes
|
||||
WHERE code = ?
|
||||
AND used = 0
|
||||
AND expires_at > ?
|
||||
""", (code, datetime.now()))
|
||||
|
||||
row = await cursor.fetchone()
|
||||
if not row:
|
||||
logger.warning(f"Invalid or expired code: {code}")
|
||||
return None
|
||||
|
||||
# Mark code as used
|
||||
await db.execute("""
|
||||
UPDATE telegram_auth_codes
|
||||
SET used = 1, used_at = ?
|
||||
WHERE code = ?
|
||||
""", (datetime.now(), code))
|
||||
|
||||
await db.commit()
|
||||
logger.info(f"Auth code {code} verified and used")
|
||||
|
||||
return dict(row)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to verify auth code: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def get_pending_codes_for_user(telegram_user_id: int) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get all pending (unused, non-expired) codes for a user.
|
||||
|
||||
Args:
|
||||
telegram_user_id: Telegram user ID
|
||||
|
||||
Returns:
|
||||
List[Dict]: List of pending codes
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("""
|
||||
SELECT * FROM telegram_auth_codes
|
||||
WHERE telegram_user_id = ?
|
||||
AND used = 0
|
||||
AND expires_at > ?
|
||||
ORDER BY created_at DESC
|
||||
""", (telegram_user_id, datetime.now()))
|
||||
|
||||
rows = await cursor.fetchall()
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get pending codes for user {telegram_user_id}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# EMAIL AUTHENTICATION CODES OPERATIONS
|
||||
# ============================================================================
|
||||
|
||||
async def get_pending_email_code(
|
||||
telegram_user_id: int
|
||||
) -> Optional[Dict]:
|
||||
"""
|
||||
Get pending (non-expired, non-used) email code for user
|
||||
|
||||
Returns:
|
||||
Code data dict or None if no pending code
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("""
|
||||
SELECT code, email, oracle_username, expires_at, failed_attempts
|
||||
FROM email_auth_codes
|
||||
WHERE telegram_user_id = ?
|
||||
AND used = 0
|
||||
AND expires_at > ?
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1
|
||||
""", (telegram_user_id, datetime.now()))
|
||||
|
||||
row = await cursor.fetchone()
|
||||
|
||||
if row:
|
||||
return {
|
||||
'code': row[0],
|
||||
'email': row[1],
|
||||
'oracle_username': row[2],
|
||||
'expires_at': datetime.fromisoformat(row[3]),
|
||||
'failed_attempts': row[4]
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get pending email code: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def create_email_auth_code(
|
||||
code: str,
|
||||
email: str,
|
||||
username: str,
|
||||
telegram_user_id: int,
|
||||
expiry_minutes: int = 5
|
||||
) -> bool:
|
||||
"""
|
||||
Create new email authentication code
|
||||
|
||||
NOTE: Caller should check for existing pending codes first
|
||||
"""
|
||||
expires_at = datetime.now() + timedelta(minutes=expiry_minutes)
|
||||
|
||||
try:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
await db.execute("""
|
||||
INSERT INTO email_auth_codes
|
||||
(code, email, oracle_username, telegram_user_id, expires_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""", (code, email, username, telegram_user_id, expires_at))
|
||||
|
||||
await db.commit()
|
||||
|
||||
logger.info(
|
||||
f"Email auth code created for user {telegram_user_id}, "
|
||||
f"expires at {expires_at.isoformat()}"
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating email auth code: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
async def get_email_auth_code(code: str) -> Optional[Dict]:
|
||||
"""Get email auth code details"""
|
||||
try:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("""
|
||||
SELECT code, email, oracle_username, telegram_user_id,
|
||||
created_at, expires_at, used, used_at, failed_attempts
|
||||
FROM email_auth_codes
|
||||
WHERE code = ?
|
||||
""", (code,))
|
||||
|
||||
row = await cursor.fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
return {
|
||||
'code': row[0],
|
||||
'email': row[1],
|
||||
'oracle_username': row[2],
|
||||
'telegram_user_id': row[3],
|
||||
'created_at': datetime.fromisoformat(row[4]),
|
||||
'expires_at': datetime.fromisoformat(row[5]),
|
||||
'used': bool(row[6]),
|
||||
'used_at': datetime.fromisoformat(row[7]) if row[7] else None,
|
||||
'failed_attempts': row[8]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get email auth code: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def increment_failed_attempts(code: str) -> bool:
|
||||
"""Increment failed validation attempts for code"""
|
||||
try:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
await db.execute("""
|
||||
UPDATE email_auth_codes
|
||||
SET failed_attempts = failed_attempts + 1
|
||||
WHERE code = ?
|
||||
""", (code,))
|
||||
|
||||
await db.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error incrementing failed attempts: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def mark_email_code_used(code: str) -> bool:
|
||||
"""Mark email code as used"""
|
||||
try:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
await db.execute("""
|
||||
UPDATE email_auth_codes
|
||||
SET used = 1, used_at = ?
|
||||
WHERE code = ?
|
||||
""", (datetime.now(), code))
|
||||
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"Email auth code marked as used: {code}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error marking email code as used: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def delete_user_email_codes(telegram_user_id: int) -> int:
|
||||
"""Delete all email codes for user (cleanup)"""
|
||||
try:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("""
|
||||
DELETE FROM email_auth_codes
|
||||
WHERE telegram_user_id = ?
|
||||
""", (telegram_user_id,))
|
||||
|
||||
await db.commit()
|
||||
|
||||
deleted = cursor.rowcount
|
||||
logger.info(f"Deleted {deleted} email codes for user {telegram_user_id}")
|
||||
return deleted
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting user email codes: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SESSION OPERATIONS
|
||||
# ============================================================================
|
||||
|
||||
async def create_session(
|
||||
telegram_user_id: int,
|
||||
conversation_state: Optional[str] = None,
|
||||
expires_in_hours: int = 24
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Create a new conversation session.
|
||||
|
||||
Args:
|
||||
telegram_user_id: Telegram user ID
|
||||
conversation_state: JSON string of conversation state
|
||||
expires_in_hours: Session expiration time in hours (default: 24)
|
||||
|
||||
Returns:
|
||||
Optional[str]: Session ID if successful, None otherwise
|
||||
"""
|
||||
try:
|
||||
session_id = str(uuid.uuid4())
|
||||
expires_at = datetime.now() + timedelta(hours=expires_in_hours)
|
||||
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
await db.execute("""
|
||||
INSERT INTO telegram_sessions (
|
||||
session_id, telegram_user_id, conversation_state, expires_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""", (session_id, telegram_user_id, conversation_state, expires_at))
|
||||
|
||||
await db.commit()
|
||||
logger.info(f"Session {session_id} created for user {telegram_user_id}")
|
||||
return session_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create session: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def get_session(session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get session information.
|
||||
|
||||
Args:
|
||||
session_id: Session UUID
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: Session data or None if not found/expired
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("""
|
||||
SELECT * FROM telegram_sessions
|
||||
WHERE session_id = ?
|
||||
AND expires_at > ?
|
||||
""", (session_id, datetime.now()))
|
||||
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return dict(row)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get session {session_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def get_user_active_session(telegram_user_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get the most recent active session for a user.
|
||||
|
||||
Args:
|
||||
telegram_user_id: Telegram user ID
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: Session data or None if no active session
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("""
|
||||
SELECT * FROM telegram_sessions
|
||||
WHERE telegram_user_id = ?
|
||||
AND expires_at > ?
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT 1
|
||||
""", (telegram_user_id, datetime.now()))
|
||||
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return dict(row)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get active session for user {telegram_user_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def update_session_state(
|
||||
session_id: str,
|
||||
conversation_state: str
|
||||
) -> bool:
|
||||
"""
|
||||
Update the conversation state for a session.
|
||||
|
||||
Args:
|
||||
session_id: Session UUID
|
||||
conversation_state: JSON string of conversation state
|
||||
|
||||
Returns:
|
||||
bool: True if successful
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
await db.execute("""
|
||||
UPDATE telegram_sessions
|
||||
SET conversation_state = ?,
|
||||
updated_at = ?
|
||||
WHERE session_id = ?
|
||||
""", (conversation_state, datetime.now(), session_id))
|
||||
|
||||
await db.commit()
|
||||
logger.info(f"Session {session_id} state updated")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def delete_session(session_id: str) -> bool:
|
||||
"""
|
||||
Delete a session.
|
||||
|
||||
Args:
|
||||
session_id: Session UUID
|
||||
|
||||
Returns:
|
||||
bool: True if successful
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
await db.execute("""
|
||||
DELETE FROM telegram_sessions
|
||||
WHERE session_id = ?
|
||||
""", (session_id,))
|
||||
|
||||
await db.commit()
|
||||
logger.info(f"Session {session_id} deleted")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def delete_user_sessions(telegram_user_id: int) -> bool:
|
||||
"""
|
||||
Delete all sessions for a user.
|
||||
|
||||
Args:
|
||||
telegram_user_id: Telegram user ID
|
||||
|
||||
Returns:
|
||||
bool: True if successful
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("""
|
||||
DELETE FROM telegram_sessions
|
||||
WHERE telegram_user_id = ?
|
||||
""", (telegram_user_id,))
|
||||
|
||||
await db.commit()
|
||||
deleted = cursor.rowcount
|
||||
logger.info(f"Deleted {deleted} sessions for user {telegram_user_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete sessions for user {telegram_user_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# Export all functions
|
||||
__all__ = [
|
||||
# User operations
|
||||
'create_or_update_user',
|
||||
'get_user',
|
||||
'link_user_to_oracle',
|
||||
'update_user_tokens',
|
||||
'update_user_last_active',
|
||||
'is_user_linked',
|
||||
'is_user_authenticated',
|
||||
# Auth code operations
|
||||
'create_auth_code',
|
||||
'get_auth_code',
|
||||
'verify_and_use_auth_code',
|
||||
'get_pending_codes_for_user',
|
||||
# Email auth code operations
|
||||
'get_pending_email_code',
|
||||
'create_email_auth_code',
|
||||
'get_email_auth_code',
|
||||
'increment_failed_attempts',
|
||||
'mark_email_code_used',
|
||||
'delete_user_email_codes',
|
||||
# Session operations
|
||||
'create_session',
|
||||
'get_session',
|
||||
'get_user_active_session',
|
||||
'update_session_state',
|
||||
'delete_session',
|
||||
'delete_user_sessions',
|
||||
]
|
||||
32
backend/modules/telegram/routers/__init__.py
Normal file
32
backend/modules/telegram/routers/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Telegram module router factory."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
|
||||
def create_telegram_router() -> APIRouter:
|
||||
"""
|
||||
Create and configure Telegram module router.
|
||||
|
||||
Includes all Telegram bot internal API endpoints:
|
||||
- /auth/verify-user - Verify Telegram user authentication
|
||||
- /auth/generate-code - Generate auth code for linking
|
||||
- /auth/verify-code - Verify auth code
|
||||
- /stats - Bot database statistics
|
||||
|
||||
Returns:
|
||||
APIRouter: Configured router for Telegram module
|
||||
"""
|
||||
router = APIRouter()
|
||||
|
||||
# Import routers here to avoid circular imports
|
||||
from .auth_codes import router as auth_codes_router
|
||||
from .internal_api import internal_api as internal_api_router
|
||||
|
||||
# Include all sub-routers (no prefix - already prefixed in main.py with /api/telegram)
|
||||
# Auth codes router provides /auth/* endpoints
|
||||
router.include_router(auth_codes_router, tags=["telegram-auth"])
|
||||
|
||||
# Internal API router provides additional endpoints like /stats
|
||||
router.include_router(internal_api_router, tags=["telegram-internal"])
|
||||
|
||||
return router
|
||||
840
backend/modules/telegram/routers/auth_codes.py
Normal file
840
backend/modules/telegram/routers/auth_codes.py
Normal file
@@ -0,0 +1,840 @@
|
||||
"""
|
||||
API Router pentru Telegram Bot Integration
|
||||
Furnizează endpoint-uri pentru autentificare, linking și export rapoarte pentru Telegram bot
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from typing import List, Optional, Dict, Any
|
||||
# import sys # Removed - no longer needed
|
||||
import os
|
||||
import secrets
|
||||
import string
|
||||
import httpx
|
||||
from datetime import datetime, timedelta
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
from shared.auth.dependencies import get_current_user
|
||||
from shared.auth.models import CurrentUser
|
||||
from shared.auth.jwt_handler import jwt_handler
|
||||
from shared.database.oracle_pool import oracle_pool
|
||||
|
||||
# Telegram bot internal API URL (running on same server)
|
||||
TELEGRAM_BOT_INTERNAL_API = os.getenv("TELEGRAM_BOT_INTERNAL_API", "http://localhost:8002")
|
||||
|
||||
router = APIRouter(redirect_slashes=False)
|
||||
|
||||
# ==================== Schemas ====================
|
||||
|
||||
class GenerateCodeRequest(BaseModel):
|
||||
"""Request pentru generarea unui cod de linking"""
|
||||
telegram_user_id: int = Field(description="ID-ul utilizatorului Telegram")
|
||||
telegram_username: Optional[str] = Field(default=None, description="Username-ul Telegram")
|
||||
telegram_first_name: Optional[str] = Field(default=None, description="Prenumele utilizatorului")
|
||||
telegram_last_name: Optional[str] = Field(default=None, description="Numele utilizatorului")
|
||||
|
||||
|
||||
class GenerateCodeResponse(BaseModel):
|
||||
"""Response pentru generarea unui cod de linking"""
|
||||
linking_code: str = Field(description="Codul de linking generat (8 caractere)")
|
||||
expires_at: datetime = Field(description="Data și ora expirării codului")
|
||||
expires_in_minutes: int = Field(description="Minutele până la expirare")
|
||||
|
||||
|
||||
class VerifyUserRequest(BaseModel):
|
||||
"""
|
||||
Request pentru verificarea utilizatorului în Oracle
|
||||
|
||||
Suportă 2 flow-uri:
|
||||
1. Auto-linking (recomandat): doar linking_code și oracle_username
|
||||
- Bot-ul verifică codul în SQLite, extrage oracle_username
|
||||
- Backend face lookup în Oracle fără verificare parolă
|
||||
- Codul valid este proof-of-authorization
|
||||
|
||||
2. Full verification (opțional): username, password, linking_code
|
||||
- Verificare completă cu parolă în Oracle
|
||||
"""
|
||||
linking_code: str = Field(description="Codul de linking de la /generate-code")
|
||||
oracle_username: Optional[str] = Field(default=None, description="Username Oracle (pentru auto-linking)")
|
||||
username: Optional[str] = Field(default=None, description="Username pentru verificare completă")
|
||||
password: Optional[str] = Field(default=None, description="Parolă pentru verificare completă")
|
||||
|
||||
|
||||
class VerifyUserResponse(BaseModel):
|
||||
"""Response pentru verificarea utilizatorului"""
|
||||
success: bool = Field(description="True dacă verificarea a avut succes")
|
||||
access_token: Optional[str] = Field(default=None, description="JWT access token")
|
||||
refresh_token: Optional[str] = Field(default=None, description="JWT refresh token")
|
||||
user: Optional[Dict[str, Any]] = Field(default=None, description="Detalii utilizator")
|
||||
message: str = Field(description="Mesaj de status")
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
"""Request pentru refresh JWT token"""
|
||||
refresh_token: str = Field(description="Refresh token-ul obținut la autentificare")
|
||||
|
||||
|
||||
class RefreshTokenResponse(BaseModel):
|
||||
"""Response pentru refresh token"""
|
||||
access_token: str = Field(description="Noul JWT access token")
|
||||
expires_in: int = Field(description="Timpul de expirare în secunde")
|
||||
token_type: str = Field(default="bearer", description="Tipul token-ului")
|
||||
|
||||
|
||||
class ExportReportRequest(BaseModel):
|
||||
"""Request pentru exportul unui raport"""
|
||||
company_id: int = Field(description="ID-ul firmei")
|
||||
report_type: str = Field(description="Tipul raportului (invoices, payments, dashboard)")
|
||||
format: str = Field(default="excel", description="Formatul exportului (excel, pdf, csv)")
|
||||
filters: Optional[Dict[str, Any]] = Field(default=None, description="Filtre pentru raport")
|
||||
|
||||
|
||||
class ExportReportResponse(BaseModel):
|
||||
"""Response pentru exportul raportului"""
|
||||
success: bool = Field(description="True dacă exportul a avut succes")
|
||||
file_url: Optional[str] = Field(default=None, description="URL-ul fișierului generat")
|
||||
file_name: Optional[str] = Field(default=None, description="Numele fișierului generat")
|
||||
file_size_bytes: Optional[int] = Field(default=None, description="Mărimea fișierului în bytes")
|
||||
message: str = Field(description="Mesaj de status")
|
||||
|
||||
|
||||
class VerifyEmailRequest(BaseModel):
|
||||
"""Request pentru verificarea email-ului în Oracle"""
|
||||
email: str = Field(description="Adresa de email Oracle")
|
||||
|
||||
|
||||
class VerifyEmailResponse(BaseModel):
|
||||
"""Response pentru verificarea email-ului"""
|
||||
success: bool = Field(description="True dacă email-ul există și este activ")
|
||||
username: Optional[str] = Field(default=None, description="Username-ul Oracle asociat")
|
||||
message: str = Field(description="Mesaj de status")
|
||||
|
||||
|
||||
class TelegramEmailLoginRequest(BaseModel):
|
||||
"""Request pentru autentificare prin email + parolă"""
|
||||
email: str = Field(description="Adresa de email Oracle")
|
||||
password: str = Field(description="Parola Oracle")
|
||||
telegram_user_id: int = Field(description="ID-ul utilizatorului Telegram")
|
||||
session_token: str = Field(description="Token de sesiune pentru preveni spoofing")
|
||||
|
||||
|
||||
class TelegramEmailLoginResponse(BaseModel):
|
||||
"""Response pentru autentificare prin email + parolă"""
|
||||
success: bool = Field(description="True dacă autentificarea a avut succes")
|
||||
access_token: Optional[str] = Field(default=None, description="JWT access token")
|
||||
refresh_token: Optional[str] = Field(default=None, description="JWT refresh token")
|
||||
token_type: str = Field(default="bearer", description="Tipul token-ului")
|
||||
user_id: Optional[int] = Field(default=None, description="ID-ul utilizatorului Oracle")
|
||||
username: Optional[str] = Field(default=None, description="Username-ul Oracle")
|
||||
companies: List[Dict[str, Any]] = Field(default_factory=list, description="Lista companiilor")
|
||||
message: str = Field(description="Mesaj de status")
|
||||
|
||||
|
||||
# ==================== Helper Functions ====================
|
||||
|
||||
# Rate limiting storage (in-memory)
|
||||
from collections import defaultdict
|
||||
_endpoint_rate_limits = defaultdict(list)
|
||||
|
||||
|
||||
def check_endpoint_rate_limit(
|
||||
identifier: str,
|
||||
max_attempts: int = 5,
|
||||
window_minutes: int = 5
|
||||
) -> bool:
|
||||
"""Backend rate limiting for sensitive endpoints"""
|
||||
now = datetime.now()
|
||||
cutoff = now - timedelta(minutes=window_minutes)
|
||||
|
||||
# Clean old attempts
|
||||
_endpoint_rate_limits[identifier] = [
|
||||
attempt for attempt in _endpoint_rate_limits[identifier]
|
||||
if attempt > cutoff
|
||||
]
|
||||
|
||||
# Check limit
|
||||
if len(_endpoint_rate_limits[identifier]) >= max_attempts:
|
||||
return False
|
||||
|
||||
# Add attempt
|
||||
_endpoint_rate_limits[identifier].append(now)
|
||||
return True
|
||||
|
||||
|
||||
def verify_session_token(
|
||||
telegram_user_id: int,
|
||||
email: str,
|
||||
token: str
|
||||
) -> bool:
|
||||
"""
|
||||
Verify session token from bot to prevent user ID spoofing
|
||||
|
||||
Token format: user_id:email:signature
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
try:
|
||||
parts = token.split(":")
|
||||
if len(parts) != 3:
|
||||
return False
|
||||
|
||||
token_user_id, token_email, signature = parts
|
||||
|
||||
# Verify user ID and email match
|
||||
if int(token_user_id) != telegram_user_id or token_email != email:
|
||||
return False
|
||||
|
||||
# Verify signature
|
||||
secret = os.getenv("AUTH_SESSION_SECRET", "change-me-in-production")
|
||||
payload = f"{telegram_user_id}:{email}:{secret}"
|
||||
expected_signature = hashlib.sha256(payload.encode()).hexdigest()[:16]
|
||||
|
||||
if signature != expected_signature:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def generate_linking_code(length: int = 8) -> str:
|
||||
"""
|
||||
Generează un cod alfanumeric aleatoriu pentru linking
|
||||
|
||||
Args:
|
||||
length: Lungimea codului (default: 8)
|
||||
|
||||
Returns:
|
||||
Codul generat (uppercase alphanumeric)
|
||||
"""
|
||||
alphabet = string.ascii_uppercase + string.digits
|
||||
# Exclude caractere care pot fi confundate: 0, O, I, 1
|
||||
alphabet = alphabet.replace('0', '').replace('O', '').replace('I', '').replace('1', '')
|
||||
return ''.join(secrets.choice(alphabet) for _ in range(length))
|
||||
|
||||
|
||||
async def get_oracle_user_by_username(username: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Obține informații despre utilizator din Oracle FĂRĂ verificare parolă.
|
||||
|
||||
Folosit pentru auto-linking când utilizatorul a fost deja autentificat
|
||||
prin generarea unui linking code valid în aplicația web.
|
||||
|
||||
Args:
|
||||
username: Username-ul utilizatorului Oracle
|
||||
|
||||
Returns:
|
||||
Dict cu informații despre utilizator sau None dacă nu există
|
||||
"""
|
||||
try:
|
||||
async with oracle_pool.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
# Obține detalii utilizator
|
||||
cursor.execute("""
|
||||
SELECT ID_UTIL, UTILIZATOR
|
||||
FROM UTILIZATORI
|
||||
WHERE UPPER(UTILIZATOR) = :username
|
||||
""", {'username': username.upper()})
|
||||
|
||||
user_row = cursor.fetchone()
|
||||
if not user_row:
|
||||
return None
|
||||
|
||||
user_id = user_row[0]
|
||||
actual_username = user_row[1]
|
||||
|
||||
# Obține companiile utilizatorului
|
||||
cursor.execute("""
|
||||
SELECT A.ID_FIRMA, A.FIRMA
|
||||
FROM V_NOM_FIRME A
|
||||
WHERE A.ID_FIRMA IN (
|
||||
SELECT ID_FIRMA
|
||||
FROM VDEF_UTIL_FIRME
|
||||
WHERE ID_PROGRAM = 2
|
||||
AND ID_UTIL = :user_id
|
||||
)
|
||||
ORDER BY A.FIRMA
|
||||
""", {'user_id': user_id})
|
||||
|
||||
companies_result = cursor.fetchall()
|
||||
companies = [str(row[0]) for row in companies_result]
|
||||
|
||||
return {
|
||||
'user_id': user_id,
|
||||
'username': actual_username,
|
||||
'companies': companies,
|
||||
'permissions': ['read', 'reports']
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting Oracle user by username: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def verify_oracle_user(username: str, password: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Verifică utilizatorul în Oracle folosind pack_drepturi.verificautilizator
|
||||
|
||||
Args:
|
||||
username: Username-ul utilizatorului
|
||||
password: Parola utilizatorului
|
||||
|
||||
Returns:
|
||||
Dict cu informații despre utilizator sau None dacă verificarea eșuează
|
||||
"""
|
||||
try:
|
||||
async with oracle_pool.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
# Verifică autentificarea
|
||||
cursor.execute("""
|
||||
SELECT pack_drepturi.verificautilizator(:username, :password)
|
||||
FROM DUAL
|
||||
""", {
|
||||
'username': username.upper(),
|
||||
'password': password
|
||||
})
|
||||
|
||||
result = cursor.fetchone()
|
||||
verification_result = result[0] if result else -1
|
||||
|
||||
if verification_result == -1:
|
||||
return None
|
||||
|
||||
# Obține detalii utilizator
|
||||
cursor.execute("""
|
||||
SELECT ID_UTIL, UTILIZATOR
|
||||
FROM UTILIZATORI
|
||||
WHERE UPPER(UTILIZATOR) = :username
|
||||
""", {'username': username.upper()})
|
||||
|
||||
user_row = cursor.fetchone()
|
||||
if not user_row:
|
||||
return None
|
||||
|
||||
user_id = user_row[0]
|
||||
|
||||
# Obține companiile utilizatorului
|
||||
cursor.execute("""
|
||||
SELECT A.ID_FIRMA, A.FIRMA
|
||||
FROM V_NOM_FIRME A
|
||||
WHERE A.ID_FIRMA IN (
|
||||
SELECT ID_FIRMA
|
||||
FROM VDEF_UTIL_FIRME
|
||||
WHERE ID_PROGRAM = 2
|
||||
AND ID_UTIL = :user_id
|
||||
)
|
||||
ORDER BY A.FIRMA
|
||||
""", {'user_id': user_id})
|
||||
|
||||
companies_result = cursor.fetchall()
|
||||
companies = [str(row[0]) for row in companies_result]
|
||||
|
||||
return {
|
||||
'user_id': user_id,
|
||||
'username': username,
|
||||
'companies': companies,
|
||||
'permissions': ['read', 'reports']
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error verifying Oracle user: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# ==================== Endpoints ====================
|
||||
|
||||
@router.post("/auth/generate-code", response_model=GenerateCodeResponse)
|
||||
async def generate_linking_code_endpoint(
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Generează un cod de linking pentru conectarea unui utilizator Telegram
|
||||
|
||||
Flow:
|
||||
1. Utilizatorul autentificat în aplicație solicită un cod
|
||||
2. Se generează un cod unic de 8 caractere
|
||||
3. Codul este trimis la Telegram bot pentru salvare în SQLite cu TTL de 15 minute
|
||||
4. Utilizatorul introduce codul în Telegram bot pentru linking
|
||||
|
||||
Note:
|
||||
- Acest endpoint necesită autentificare JWT (utilizatorul trebuie să fie logat în aplicație)
|
||||
- Codul expiră după 15 minute
|
||||
- Fiecare request generează un cod nou (codurile vechi devin invalide)
|
||||
- Nu este nevoie de telegram_user_id în acest moment (utilizatorul nu e încă conectat la Telegram)
|
||||
"""
|
||||
try:
|
||||
# Generează cod unic
|
||||
linking_code = generate_linking_code()
|
||||
|
||||
# Setează expirarea la 15 minute
|
||||
expires_at = datetime.utcnow() + timedelta(minutes=15)
|
||||
expires_in_minutes = 15
|
||||
|
||||
# Salvează codul în database-ul Telegram bot (SQLite) via internal API
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
save_code_response = await client.post(
|
||||
f"{TELEGRAM_BOT_INTERNAL_API}/internal/save-code",
|
||||
json={
|
||||
"code": linking_code,
|
||||
"telegram_user_id": 0, # Not known yet (user hasn't linked)
|
||||
"oracle_username": current_user.username,
|
||||
"expires_in_minutes": expires_in_minutes
|
||||
}
|
||||
)
|
||||
|
||||
# Accept both 200 (OK) and 201 (Created) as success
|
||||
if save_code_response.status_code not in [200, 201]:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to save code to Telegram bot: {save_code_response.text}"
|
||||
)
|
||||
except httpx.TimeoutException:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Telegram bot service is not responding. Please try again later."
|
||||
)
|
||||
except httpx.ConnectError:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Cannot connect to Telegram bot service. Please contact administrator."
|
||||
)
|
||||
|
||||
return GenerateCodeResponse(
|
||||
linking_code=linking_code,
|
||||
expires_at=expires_at,
|
||||
expires_in_minutes=expires_in_minutes
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Eroare la generarea codului de linking: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/auth/verify-user", response_model=VerifyUserResponse)
|
||||
async def verify_user_endpoint(request: VerifyUserRequest):
|
||||
"""
|
||||
Verifică utilizatorul în Oracle și returnează JWT tokens
|
||||
|
||||
Suportă 2 flow-uri de autentificare:
|
||||
|
||||
Flow A - Auto-linking (RECOMANDAT):
|
||||
1. Bot verifică linking_code în SQLite (code valid = user s-a autentificat în web app)
|
||||
2. Bot extrage oracle_username din cod
|
||||
3. Bot trimite: {linking_code, oracle_username}
|
||||
4. Backend face lookup în Oracle (FĂRĂ verificare parolă)
|
||||
5. Backend generează și returnează JWT tokens
|
||||
|
||||
Flow B - Full verification (OPȚIONAL):
|
||||
1. Bot cere username și parolă de la user în Telegram
|
||||
2. Bot trimite: {linking_code, username, password}
|
||||
3. Backend verifică credențialele în Oracle
|
||||
4. Backend generează și returnează JWT tokens
|
||||
|
||||
Note:
|
||||
- Acest endpoint NU necesită autentificare JWT (este public pentru bot)
|
||||
- Flow A oferă UX superior (fără re-introducere parolă)
|
||||
- Linking code-ul valid este proof-of-authorization
|
||||
"""
|
||||
try:
|
||||
# Flow A: Auto-linking (oracle_username provided, no password)
|
||||
if request.oracle_username and not request.password:
|
||||
user_data = await get_oracle_user_by_username(request.oracle_username)
|
||||
|
||||
if not user_data:
|
||||
return VerifyUserResponse(
|
||||
success=False,
|
||||
message=f"Utilizatorul {request.oracle_username} nu există în Oracle"
|
||||
)
|
||||
|
||||
# Flow B: Full verification (username + password provided)
|
||||
elif request.username and request.password:
|
||||
user_data = await verify_oracle_user(request.username, request.password)
|
||||
|
||||
if not user_data:
|
||||
return VerifyUserResponse(
|
||||
success=False,
|
||||
message="Username sau parolă incorectă"
|
||||
)
|
||||
|
||||
# Invalid request (missing required fields)
|
||||
else:
|
||||
return VerifyUserResponse(
|
||||
success=False,
|
||||
message="Trebuie furnizat fie oracle_username (auto-linking) fie username+password (verificare completă)"
|
||||
)
|
||||
|
||||
# Generează JWT tokens
|
||||
access_token = jwt_handler.create_access_token(
|
||||
username=user_data['username'],
|
||||
companies=user_data['companies'],
|
||||
user_id=user_data['user_id'],
|
||||
permissions=user_data['permissions']
|
||||
)
|
||||
|
||||
refresh_token = jwt_handler.create_refresh_token(
|
||||
username=user_data['username'],
|
||||
user_id=user_data['user_id']
|
||||
)
|
||||
|
||||
return VerifyUserResponse(
|
||||
success=True,
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
user={
|
||||
'user_id': user_data['user_id'],
|
||||
'username': user_data['username'],
|
||||
'companies': user_data['companies'],
|
||||
'permissions': user_data['permissions']
|
||||
},
|
||||
message="Autentificare reușită"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Eroare la verificarea utilizatorului: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/auth/refresh-token", response_model=RefreshTokenResponse)
|
||||
async def refresh_token_endpoint(request: RefreshTokenRequest):
|
||||
"""
|
||||
Refresh-uiește un JWT access token folosind refresh token-ul
|
||||
|
||||
Acest endpoint este folosit de Telegram bot pentru a obține un nou access token
|
||||
când cel curent expiră, fără a solicita din nou username/password.
|
||||
|
||||
Flow:
|
||||
1. Botul Telegram detectează că access token-ul a expirat
|
||||
2. Trimite refresh token-ul la acest endpoint
|
||||
3. Se validează refresh token-ul și se generează un nou access token
|
||||
4. Botul stochează noul access token în SQLite
|
||||
|
||||
Note:
|
||||
- Refresh token-ul este valid 7 zile (vs 30 minute pentru access token)
|
||||
- Dacă refresh token-ul expiră, utilizatorul trebuie să se re-autentifice
|
||||
"""
|
||||
try:
|
||||
# Verifică refresh token-ul
|
||||
token_data = jwt_handler.verify_token(request.refresh_token)
|
||||
|
||||
if not token_data or token_data.token_type != "refresh":
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Refresh token invalid sau expirat"
|
||||
)
|
||||
|
||||
# Obține companiile actualizate din Oracle
|
||||
async with oracle_pool.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute("""
|
||||
SELECT A.ID_FIRMA
|
||||
FROM V_NOM_FIRME A
|
||||
WHERE A.ID_FIRMA IN (
|
||||
SELECT ID_FIRMA
|
||||
FROM VDEF_UTIL_FIRME
|
||||
WHERE ID_PROGRAM = 2
|
||||
AND ID_UTIL = :user_id
|
||||
)
|
||||
ORDER BY A.FIRMA
|
||||
""", {'user_id': token_data.user_id})
|
||||
|
||||
companies_result = cursor.fetchall()
|
||||
companies = [str(row[0]) for row in companies_result]
|
||||
|
||||
# Generează nou access token
|
||||
new_access_token = jwt_handler.create_access_token(
|
||||
username=token_data.username,
|
||||
companies=companies,
|
||||
user_id=token_data.user_id,
|
||||
permissions=token_data.permissions
|
||||
)
|
||||
|
||||
return RefreshTokenResponse(
|
||||
access_token=new_access_token,
|
||||
expires_in=jwt_handler.access_token_expire_minutes * 60,
|
||||
token_type="bearer"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Eroare la refresh token: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/auth/verify-email", response_model=VerifyEmailResponse)
|
||||
async def verify_email_endpoint(request: VerifyEmailRequest):
|
||||
"""
|
||||
Verify if email exists in Oracle UTILIZATORI table (PUBLIC endpoint)
|
||||
|
||||
This is a PUBLIC endpoint used by the telegram bot during email authentication.
|
||||
Returns username if email exists and user is active.
|
||||
|
||||
Security: Generic error messages to prevent email enumeration.
|
||||
"""
|
||||
try:
|
||||
async with oracle_pool.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
# Query to find username by email
|
||||
cursor.execute("""
|
||||
SELECT UTILIZATOR
|
||||
FROM CONTAFIN_ORACLE.UTILIZATORI
|
||||
WHERE UPPER(EMAIL) = UPPER(:email)
|
||||
AND INACTIV = 0
|
||||
AND STERS = 0
|
||||
""", {"email": request.email})
|
||||
|
||||
row = cursor.fetchone()
|
||||
|
||||
if row:
|
||||
username = row[0]
|
||||
return VerifyEmailResponse(
|
||||
success=True,
|
||||
username=username,
|
||||
message="Email verificat cu succes"
|
||||
)
|
||||
else:
|
||||
# Generic message (no enumeration)
|
||||
return VerifyEmailResponse(
|
||||
success=False,
|
||||
username=None,
|
||||
message="Email invalid sau inactiv"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Generic error message (no details exposed)
|
||||
return VerifyEmailResponse(
|
||||
success=False,
|
||||
username=None,
|
||||
message="Eroare la verificarea email-ului"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/auth/login-with-email", response_model=TelegramEmailLoginResponse)
|
||||
async def login_with_email_endpoint(request: TelegramEmailLoginRequest):
|
||||
"""
|
||||
Telegram email + password authentication endpoint
|
||||
|
||||
Security features:
|
||||
- Rate limiting: 5 attempts per 5 minutes
|
||||
- Session token verification (prevent user ID spoofing)
|
||||
- Generic error messages (no username/email enumeration)
|
||||
- Password verification in Oracle (not stored)
|
||||
"""
|
||||
|
||||
# 1. Rate limiting
|
||||
rate_limit_key = f"email_login_{request.telegram_user_id}"
|
||||
if not check_endpoint_rate_limit(rate_limit_key, max_attempts=5, window_minutes=5):
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="Prea multe încercări. Te rugăm să aștepți 5 minute."
|
||||
)
|
||||
|
||||
# 2. Verify session token (prevent user ID spoofing)
|
||||
if not verify_session_token(
|
||||
request.telegram_user_id,
|
||||
request.email,
|
||||
request.session_token
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Sesiune invalidă. Te rugăm să reîncepi autentificarea."
|
||||
)
|
||||
|
||||
try:
|
||||
async with oracle_pool.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
# 3. Find username by email
|
||||
cursor.execute("""
|
||||
SELECT ID_UTIL, UTILIZATOR, INACTIV, STERS
|
||||
FROM CONTAFIN_ORACLE.UTILIZATORI
|
||||
WHERE UPPER(EMAIL) = UPPER(:email)
|
||||
""", {"email": request.email})
|
||||
|
||||
user_row = cursor.fetchone()
|
||||
|
||||
# SECURITY: Generic error message (no email enumeration)
|
||||
if not user_row:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Credențiale invalide" # Generic message
|
||||
)
|
||||
|
||||
user_id, username, inactiv, sters = user_row
|
||||
|
||||
# Check if user is active (INACTIV=0 means active, STERS=0 means not deleted)
|
||||
if inactiv != 0 or sters != 0:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Credențiale invalide" # Generic message
|
||||
)
|
||||
|
||||
# 4. Verify password via Oracle stored procedure
|
||||
# NOTE: This procedure returns a verification code, NOT the user_id!
|
||||
# Returns -1 if authentication fails, any other value means success
|
||||
cursor.execute("""
|
||||
SELECT pack_drepturi.verificautilizator(:username, :password)
|
||||
FROM DUAL
|
||||
""", {
|
||||
"username": username.upper(), # IMPORTANT: Oracle usernames are uppercase
|
||||
"password": request.password
|
||||
})
|
||||
|
||||
verification_result = cursor.fetchone()[0]
|
||||
|
||||
# SECURITY: Generic error message (no username leak)
|
||||
if verification_result == -1:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Credențiale invalide" # Generic message
|
||||
)
|
||||
|
||||
# 5. Get user companies
|
||||
cursor.execute("""
|
||||
SELECT A.ID_FIRMA, A.FIRMA
|
||||
FROM V_NOM_FIRME A
|
||||
WHERE A.ID_FIRMA IN (
|
||||
SELECT ID_FIRMA
|
||||
FROM VDEF_UTIL_FIRME
|
||||
WHERE ID_PROGRAM = 2
|
||||
AND ID_UTIL = :user_id
|
||||
)
|
||||
ORDER BY A.FIRMA
|
||||
""", {'user_id': user_id})
|
||||
|
||||
companies_result = cursor.fetchall()
|
||||
companies = [
|
||||
{"id": str(row[0]), "name": row[1]}
|
||||
for row in companies_result
|
||||
]
|
||||
company_ids = [str(row[0]) for row in companies_result]
|
||||
|
||||
# 6. Get user permissions (default for Telegram)
|
||||
permissions = ['read', 'reports']
|
||||
|
||||
# 7. Generate JWT tokens
|
||||
token_data = {
|
||||
"username": username,
|
||||
"user_id": user_id,
|
||||
"companies": company_ids,
|
||||
"permissions": permissions
|
||||
}
|
||||
|
||||
access_token = jwt_handler.create_access_token(**token_data)
|
||||
refresh_token = jwt_handler.create_refresh_token(
|
||||
username=username,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
return TelegramEmailLoginResponse(
|
||||
success=True,
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
companies=companies,
|
||||
message="Autentificare reușită"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in login_with_email: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Eroare internă. Te rugăm să încerci din nou mai târziu."
|
||||
)
|
||||
|
||||
|
||||
@router.post("/export", response_model=ExportReportResponse)
|
||||
async def export_report_endpoint(
|
||||
request: ExportReportRequest,
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Exportă un raport în format Excel, PDF sau CSV
|
||||
|
||||
Acest endpoint este folosit de Telegram bot pentru a genera rapoarte
|
||||
și a le trimite utilizatorului.
|
||||
|
||||
Flow:
|
||||
1. Botul trimite cerere de export cu parametrii raportului
|
||||
2. Se validează că utilizatorul are acces la firma specificată
|
||||
3. Se generează raportul în formatul solicitat
|
||||
4. Se returnează URL-ul sau conținutul fișierului
|
||||
|
||||
Tipuri de rapoarte suportate:
|
||||
- invoices: Facturi (cu filtre: dată, status, client)
|
||||
- payments: Încasări (cu filtre: dată, metodă plată)
|
||||
- dashboard: Statistici dashboard (rezumat)
|
||||
|
||||
Formate suportate:
|
||||
- excel: XLSX (cel mai complet)
|
||||
- pdf: PDF (pentru printing)
|
||||
- csv: CSV (pentru import în alte sisteme)
|
||||
|
||||
Note:
|
||||
- Utilizatorul trebuie să aibă acces la firma specificată
|
||||
- Fișierele generate sunt temporare (șterse după 1 oră)
|
||||
"""
|
||||
try:
|
||||
# Verifică accesul la firmă
|
||||
company_id_str = str(request.company_id)
|
||||
if company_id_str not in current_user.companies:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Nu aveți acces la firma {request.company_id}"
|
||||
)
|
||||
|
||||
# TODO: Implementare export în funcție de report_type și format
|
||||
# Deocamdată returnăm un placeholder
|
||||
|
||||
return ExportReportResponse(
|
||||
success=True,
|
||||
file_url=f"/api/telegram/downloads/report_{request.report_type}_{request.company_id}.{request.format}",
|
||||
file_name=f"raport_{request.report_type}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.{request.format}",
|
||||
file_size_bytes=0,
|
||||
message=f"Raport {request.report_type} generat cu succes în format {request.format}"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Eroare la generarea raportului: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def telegram_health_check():
|
||||
"""
|
||||
Health check pentru routerul Telegram
|
||||
Verifică conectivitatea la Oracle și disponibilitatea serviciilor
|
||||
"""
|
||||
try:
|
||||
async with oracle_pool.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute("SELECT 1 FROM DUAL")
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "telegram-router",
|
||||
"database": "connected",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "degraded",
|
||||
"service": "telegram-router",
|
||||
"database": f"error: {str(e)}",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user