feat: Migrate to ultrathin monolith architecture

Consolidate 3 separate applications (reports-app, data-entry-app, telegram-bot) into a unified
architecture with single backend and frontend:

Backend Changes:
- Unified FastAPI backend at backend/ with modular structure
- Modules: reports, data_entry, telegram in backend/modules/
- Centralized config.py and main.py with all routers registered
- Single worker mode (--workers 1) for Telegram bot compatibility
- Shared Oracle connection pool and JWT authentication
- Unified requirements.txt and environment configuration

Frontend Changes:
- Single Vue.js SPA with module-based routing
- Unified frontend at src/ with modules in src/modules/{reports,data-entry}/
- Shared components and stores in src/shared/
- Error boundaries for module isolation
- Dual API proxy in Vite for module communication

Infrastructure:
- New unified startup scripts: start-prod.sh, start-test.sh, start-backend.sh
- Environment templates: .env.dev.example, .env.test.example, .env.prod.example
- Updated deployment scripts for Windows IIS
- Simplified SSH tunnel management

Documentation:
- Comprehensive CLAUDE.md with architecture overview
- Module-specific docs in docs/{data-entry,telegram}/
- Architecture decision records in docs/ARCHITECTURE-DECISIONS.md
- Deployment guides consolidated in deployment/windows/docs/

This migration reduces complexity, improves maintainability, and enables easier
deployment while maintaining all existing functionality.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
2025-12-29 23:48:14 +02:00
parent 2a101f1ef5
commit c5e051ad80
378 changed files with 7566 additions and 73730 deletions

147
backend/.env.dev.example Normal file
View File

@@ -0,0 +1,147 @@
# ============================================================================
# ROA2WEB Unified Backend - Environment Configuration (Development)
# ============================================================================
# Single backend process serving Reports, Data Entry, and Telegram modules
# IMPORTANT: Never commit this file to git!
# ============================================================================
# ORACLE DATABASE CONFIGURATION (REQUIRED - Shared by all modules)
# ============================================================================
# Connection to CONTAFIN_ORACLE schema for authentication and user management
# Each company is a separate schema in Oracle Database
# Development: Through SSH tunnel (localhost:1526)
ORACLE_USER=CONTAFIN_ORACLE
ORACLE_PASSWORD=your_oracle_password_here
ORACLE_HOST=localhost
ORACLE_PORT=1526
ORACLE_SID=ROA
# Development: Start SSH tunnel before running backend
# ./ssh_tunnel.sh start (production) or ./ssh-tunnel-test.sh start (test)
# ============================================================================
# JWT AUTHENTICATION (REQUIRED - Shared by all modules)
# ============================================================================
# Used for JWT token generation and validation (shared/auth/jwt_handler.py)
JWT_SECRET_KEY=generate_with_secrets_token_urlsafe_32
JWT_ALGORITHM=HS256
# Token expiration settings (used by shared/auth/jwt_handler.py)
ACCESS_TOKEN_EXPIRE_MINUTES=30
REFRESH_TOKEN_EXPIRE_DAYS=7
# ============================================================================
# SESSION SECURITY - EMAIL 2FA (REQUIRED for Telegram email login)
# ============================================================================
# Used by Telegram module for session token validation
# Must match between backend and Telegram bot
AUTH_SESSION_SECRET=generate_with_secrets_token_urlsafe_32
# ============================================================================
# SERVER CONFIGURATION
# ============================================================================
# Unified backend server settings
API_HOST=0.0.0.0
API_PORT=8000
DEBUG=true
# CORS Origins (comma-separated, includes both old and new frontend ports)
CORS_ORIGINS=http://localhost:3000,http://localhost:3010,http://localhost:5173
# ============================================================================
# REPORTS MODULE - CACHE CONFIGURATION (OPTIONAL - defaults provided)
# ============================================================================
# Two-tier hybrid cache system (L1: in-memory LRU, L2: SQLite persistent)
# Used by backend/modules/reports/cache/config.py
# Core Settings
CACHE_ENABLED=True
CACHE_TYPE=hybrid
CACHE_SQLITE_PATH=./data/cache/roa2web_cache.db
CACHE_MEMORY_MAX_SIZE=1000
CACHE_DEFAULT_TTL=900
# TTL per Cache Type (seconds)
CACHE_TTL_SCHEMA=86400
CACHE_TTL_COMPANIES=1800
CACHE_TTL_DASHBOARD_SUMMARY=1800
CACHE_TTL_DASHBOARD_TRENDS=1800
CACHE_TTL_INVOICES=600
CACHE_TTL_INVOICES_SUMMARY=900
CACHE_TTL_TREASURY=600
# Maintenance
CACHE_CLEANUP_INTERVAL=3600
# Event-Based Invalidation (experimental)
CACHE_AUTO_INVALIDATE=False
CACHE_CHECK_INTERVAL=300
# Performance Tracking
CACHE_TRACK_PERFORMANCE=True
CACHE_BENCHMARK_ON_STARTUP=False
# ============================================================================
# DATA ENTRY MODULE - CONFIGURATION
# ============================================================================
# Data Entry module settings (receipts, OCR, etc.)
# Environment identifier (dev/test/prod)
ORACLE_ENV=dev
# SQLite Database (development)
SQLITE_DATABASE_PATH=data/receipts/receipts_dev.db
# File uploads
UPLOAD_PATH=data/receipts/uploads
MAX_UPLOAD_SIZE_MB=10
# Test company (for development testing)
TEST_COMPANY_ID=110
TEST_COMPANY_SCHEMA=MARIUSM_AUTO
# ============================================================================
# TELEGRAM MODULE - BOT CONFIGURATION (REQUIRED for Telegram features)
# ============================================================================
# Obtain bot token from @BotFather on Telegram
TELEGRAM_BOT_TOKEN=your_bot_token_from_botfather
# Backend URL for bot to communicate with API
BACKEND_URL=http://localhost:8000
# Internal API port (bot's internal API for backend callbacks)
INTERNAL_API_PORT=8002
# Enable internal API documentation (development only)
ENABLE_DOCS=false
# ============================================================================
# TELEGRAM MODULE - EMAIL AUTHENTICATION (SMTP) (REQUIRED for email 2FA)
# ============================================================================
# Required for email-based 2FA authentication flow
# Users can login with email + password instead of web app linking
# SMTP Server Configuration
SMTP_HOST=mail.romfast.ro
SMTP_PORT=587
SMTP_USER=ups@romfast.ro
SMTP_PASSWORD=your_smtp_password_here
SMTP_FROM_EMAIL=ups@romfast.ro
SMTP_FROM_NAME=ROA2WEB
SMTP_USE_TLS=true
# Email Retry Settings
EMAIL_MAX_RETRIES=3
EMAIL_RETRY_DELAY=2.0
# ============================================================================
# TELEGRAM MODULE - DATABASE (SQLite for bot data)
# ============================================================================
# Separate SQLite database for Telegram bot auth codes and sessions
TELEGRAM_SQLITE_DATABASE_PATH=data/telegram/telegram.db

146
backend/.env.example Normal file
View File

@@ -0,0 +1,146 @@
# ============================================================================
# ROA2WEB Unified Backend - Environment Configuration Template
# ============================================================================
# Single backend process serving Reports, Data Entry, and Telegram modules
#
# SETUP INSTRUCTIONS:
# 1. Copy this template: cp .env.example .env.dev
# 2. Fill in your actual values in .env.dev
# 3. Run: ./start-dev.sh (auto-copies .env.dev to .env)
#
# ENVIRONMENT FILES:
# - .env.dev → Development config (committed to git with real values)
# - .env.test → Test config (committed to git)
# - .env.prod → Production config template (committed, use placeholders!)
# - .env → Active config (auto-generated, NOT committed)
#
# IMPORTANT: Never manually edit .env - edit .env.dev instead!
# ============================================================================
# ORACLE DATABASE CONFIGURATION (REQUIRED - Shared by all modules)
# ============================================================================
# Connection to CONTAFIN_ORACLE schema for authentication and user management
# Each company is a separate schema in Oracle Database
# Development: Through SSH tunnel (localhost:1526)
# Windows Production: Direct connection to Oracle server
ORACLE_USER=CONTAFIN_ORACLE
ORACLE_PASSWORD=SET_IN_PRODUCTION_ENV
ORACLE_HOST=localhost
ORACLE_PORT=1526
ORACLE_SID=ROA
# Development Only: Start SSH tunnel before running backend
# ./ssh_tunnel.sh start
# ./ssh_tunnel.sh status
# ============================================================================
# JWT AUTHENTICATION (REQUIRED - Shared by all modules)
# ============================================================================
# Used for JWT token generation and validation (shared/auth/jwt_handler.py)
# Generate strong secret: python3 -c "import secrets; print(secrets.token_urlsafe(32))"
JWT_SECRET_KEY=GENERATE_STRONG_SECRET_IN_PRODUCTION
JWT_ALGORITHM=HS256
# Token expiration settings (used by shared/auth/jwt_handler.py)
ACCESS_TOKEN_EXPIRE_MINUTES=30
REFRESH_TOKEN_EXPIRE_DAYS=7
# ============================================================================
# SESSION SECURITY - EMAIL 2FA (REQUIRED for Telegram email login)
# ============================================================================
# Used by Telegram module for session token validation
# Generate with: python3 -c "import secrets; print(secrets.token_urlsafe(32))"
AUTH_SESSION_SECRET=your-secure-random-secret-here-min-32-chars
# ============================================================================
# SERVER CONFIGURATION
# ============================================================================
# Unified backend server settings
API_HOST=0.0.0.0
API_PORT=8000
DEBUG=false
# CORS Origins (comma-separated)
CORS_ORIGINS=http://localhost:3000,http://localhost:5173
# ============================================================================
# REPORTS MODULE - CACHE CONFIGURATION (OPTIONAL - defaults provided)
# ============================================================================
# Two-tier hybrid cache system (L1: in-memory LRU, L2: SQLite persistent)
# Used by backend/modules/reports/cache/config.py
# Core Settings
REPORTS_CACHE_ENABLED=True
REPORTS_CACHE_TYPE=hybrid
REPORTS_CACHE_SQLITE_PATH=./data/cache/roa2web_cache.db
REPORTS_CACHE_MEMORY_MAX_SIZE=1000
REPORTS_CACHE_DEFAULT_TTL=900
# TTL per Cache Type (seconds)
REPORTS_CACHE_TTL_SCHEMA=86400
REPORTS_CACHE_TTL_COMPANIES=1800
REPORTS_CACHE_TTL_DASHBOARD_SUMMARY=1800
REPORTS_CACHE_TTL_DASHBOARD_TRENDS=1800
REPORTS_CACHE_TTL_INVOICES=600
REPORTS_CACHE_TTL_INVOICES_SUMMARY=900
REPORTS_CACHE_TTL_TREASURY=600
# Maintenance
REPORTS_CACHE_CLEANUP_INTERVAL=3600
# Event-Based Invalidation (experimental)
REPORTS_CACHE_AUTO_INVALIDATE=False
REPORTS_CACHE_CHECK_INTERVAL=300
# Performance Tracking
REPORTS_CACHE_TRACK_PERFORMANCE=True
REPORTS_CACHE_BENCHMARK_ON_STARTUP=False
# ============================================================================
# DATA ENTRY MODULE - CONFIGURATION
# ============================================================================
# Data Entry module settings (receipts, OCR, etc.)
# SQLite Database
DATA_ENTRY_SQLITE_DATABASE_PATH=data/receipts/receipts.db
# File uploads
DATA_ENTRY_UPLOAD_PATH=data/receipts/uploads
DATA_ENTRY_MAX_UPLOAD_SIZE_MB=10
# ============================================================================
# TELEGRAM MODULE - BOT CONFIGURATION (REQUIRED for Telegram features)
# ============================================================================
# Obtain bot token from @BotFather on Telegram
TELEGRAM_BOT_TOKEN=your_bot_token_here
# ============================================================================
# TELEGRAM MODULE - EMAIL AUTHENTICATION (SMTP) (REQUIRED for email 2FA)
# ============================================================================
# Required for email-based 2FA authentication flow
# Users can login with email + password instead of web app linking
# SMTP Server Configuration
TELEGRAM_SMTP_HOST=mail.romfast.ro
TELEGRAM_SMTP_PORT=587
TELEGRAM_SMTP_USER=ups@romfast.ro
TELEGRAM_SMTP_PASSWORD=your_smtp_password_here
TELEGRAM_SMTP_FROM_EMAIL=ups@romfast.ro
TELEGRAM_SMTP_FROM_NAME=ROA2WEB
TELEGRAM_SMTP_USE_TLS=true
# Email Retry Settings
TELEGRAM_EMAIL_MAX_RETRIES=3
TELEGRAM_EMAIL_RETRY_DELAY=2.0
# ============================================================================
# TELEGRAM MODULE - DATABASE (SQLite for bot data)
# ============================================================================
# Separate SQLite database for Telegram bot auth codes and sessions
TELEGRAM_SQLITE_DATABASE_PATH=data/telegram/telegram.db

139
backend/.env.prod.example Normal file
View File

@@ -0,0 +1,139 @@
# ============================================================================
# ROA2WEB Unified Backend - Environment Configuration (PRODUCTION)
# ============================================================================
# Single backend process serving Reports, Data Entry, and Telegram modules
# IMPORTANT: This is a TEMPLATE - fill in production values before deploying!
# ============================================================================
# ORACLE DATABASE CONFIGURATION (REQUIRED - Shared by all modules)
# ============================================================================
# Connection to CONTAFIN_ORACLE schema for authentication and user management
# PRODUCTION: Direct connection to Oracle server (no SSH tunnel)
ORACLE_USER=CONTAFIN_ORACLE
ORACLE_PASSWORD=CHANGE_IN_PRODUCTION
ORACLE_HOST=your_oracle_server_ip_or_hostname
ORACLE_PORT=1521
ORACLE_SID=ROA
# ============================================================================
# JWT AUTHENTICATION (REQUIRED - Shared by all modules)
# ============================================================================
# CRITICAL: Generate new secrets for production!
# python3 -c "import secrets; print(secrets.token_urlsafe(32))"
JWT_SECRET_KEY=GENERATE_NEW_SECRET_FOR_PRODUCTION
JWT_ALGORITHM=HS256
# Token expiration settings
ACCESS_TOKEN_EXPIRE_MINUTES=30
REFRESH_TOKEN_EXPIRE_DAYS=7
# ============================================================================
# SESSION SECURITY - EMAIL 2FA (REQUIRED for Telegram email login)
# ============================================================================
# CRITICAL: Generate new secret for production!
# python3 -c "import secrets; print(secrets.token_urlsafe(32))"
AUTH_SESSION_SECRET=GENERATE_NEW_SECRET_FOR_PRODUCTION
# ============================================================================
# SERVER CONFIGURATION
# ============================================================================
# Unified backend server settings
API_HOST=0.0.0.0
API_PORT=8000
DEBUG=false
# CORS Origins (comma-separated) - Update with production frontend URL
CORS_ORIGINS=https://your-production-domain.com,http://localhost:3000
# ============================================================================
# REPORTS MODULE - CACHE CONFIGURATION (OPTIONAL - defaults provided)
# ============================================================================
# Two-tier hybrid cache system (L1: in-memory LRU, L2: SQLite persistent)
# Core Settings
CACHE_ENABLED=True
CACHE_TYPE=hybrid
CACHE_SQLITE_PATH=./data/cache/roa2web_cache_prod.db
CACHE_MEMORY_MAX_SIZE=1000
CACHE_DEFAULT_TTL=900
# TTL per Cache Type (seconds)
CACHE_TTL_SCHEMA=86400
CACHE_TTL_COMPANIES=1800
CACHE_TTL_DASHBOARD_SUMMARY=1800
CACHE_TTL_DASHBOARD_TRENDS=1800
CACHE_TTL_INVOICES=600
CACHE_TTL_INVOICES_SUMMARY=900
CACHE_TTL_TREASURY=600
# Maintenance
CACHE_CLEANUP_INTERVAL=3600
# Event-Based Invalidation (experimental)
CACHE_AUTO_INVALIDATE=False
CACHE_CHECK_INTERVAL=300
# Performance Tracking
CACHE_TRACK_PERFORMANCE=True
CACHE_BENCHMARK_ON_STARTUP=False
# ============================================================================
# DATA ENTRY MODULE - CONFIGURATION
# ============================================================================
# Data Entry module settings (receipts, OCR, etc.)
# Environment identifier
ORACLE_ENV=prod
# SQLite Database (production)
SQLITE_DATABASE_PATH=data/receipts/receipts_prod.db
# File uploads
UPLOAD_PATH=data/receipts/uploads
MAX_UPLOAD_SIZE_MB=10
# ============================================================================
# TELEGRAM MODULE - BOT CONFIGURATION (REQUIRED for Telegram features)
# ============================================================================
# Obtain bot token from @BotFather on Telegram
# CRITICAL: Use production bot token, not development!
TELEGRAM_BOT_TOKEN=your_bot_token_from_botfather
# Backend URL for bot to communicate with API
BACKEND_URL=http://localhost:8000
# Internal API port (bot's internal API for backend callbacks)
INTERNAL_API_PORT=8002
# Enable internal API documentation (DISABLE in production!)
ENABLE_DOCS=false
# ============================================================================
# TELEGRAM MODULE - EMAIL AUTHENTICATION (SMTP) (REQUIRED for email 2FA)
# ============================================================================
# CRITICAL: Update with production SMTP credentials
# SMTP Server Configuration
SMTP_HOST=mail.romfast.ro
SMTP_PORT=587
SMTP_USER=ups@romfast.ro
SMTP_PASSWORD=CHANGE_IN_PRODUCTION
SMTP_FROM_EMAIL=ups@romfast.ro
SMTP_FROM_NAME=ROA2WEB
SMTP_USE_TLS=true
# Email Retry Settings
EMAIL_MAX_RETRIES=3
EMAIL_RETRY_DELAY=2.0
# ============================================================================
# TELEGRAM MODULE - DATABASE (SQLite for bot data)
# ============================================================================
# Separate SQLite database for Telegram bot auth codes and sessions
TELEGRAM_SQLITE_DATABASE_PATH=data/telegram/telegram_prod.db

147
backend/.env.test.example Normal file
View File

@@ -0,0 +1,147 @@
# ============================================================================
# ROA2WEB Unified Backend - Environment Configuration (TEST)
# ============================================================================
# TEST environment using Oracle TEST server (10.0.20.121)
# Single backend process serving Reports, Data Entry, and Telegram modules
# IMPORTANT: Never commit this file to git!
# ============================================================================
# ORACLE DATABASE CONFIGURATION (REQUIRED - Shared by all modules)
# ============================================================================
# Connection to CONTAFIN_ORACLE schema for authentication and user management
# TEST: Through SSH tunnel to 10.0.20.121 (localhost:1526)
ORACLE_USER=CONTAFIN_ORACLE
ORACLE_PASSWORD=your_oracle_password_here
ORACLE_HOST=localhost
ORACLE_PORT=1526
ORACLE_SID=roa
# TEST: Start SSH tunnel before running backend
# ./ssh-tunnel-test.sh start
# ============================================================================
# JWT AUTHENTICATION (REQUIRED - Shared by all modules)
# ============================================================================
# Used for JWT token generation and validation (shared/auth/jwt_handler.py)
JWT_SECRET_KEY=generate_with_secrets_token_urlsafe_32
JWT_ALGORITHM=HS256
# Token expiration settings (used by shared/auth/jwt_handler.py)
ACCESS_TOKEN_EXPIRE_MINUTES=480
REFRESH_TOKEN_EXPIRE_DAYS=7
# ============================================================================
# SESSION SECURITY - EMAIL 2FA (REQUIRED for Telegram email login)
# ============================================================================
# Used by Telegram module for session token validation
# Must match between backend and Telegram bot
AUTH_SESSION_SECRET=generate_with_secrets_token_urlsafe_32
# ============================================================================
# SERVER CONFIGURATION
# ============================================================================
# Unified backend server settings
API_HOST=0.0.0.0
API_PORT=8000
DEBUG=true
# CORS Origins (comma-separated, includes both old and new frontend ports)
CORS_ORIGINS=http://localhost:3000,http://localhost:3010,http://localhost:5173
# ============================================================================
# REPORTS MODULE - CACHE CONFIGURATION (OPTIONAL - defaults provided)
# ============================================================================
# Two-tier hybrid cache system (L1: in-memory LRU, L2: SQLite persistent)
# Used by backend/modules/reports/cache/config.py
# Core Settings
CACHE_ENABLED=True
CACHE_TYPE=hybrid
CACHE_SQLITE_PATH=./data/cache/roa2web_cache_test.db
CACHE_MEMORY_MAX_SIZE=1000
CACHE_DEFAULT_TTL=900
# TTL per Cache Type (seconds)
CACHE_TTL_SCHEMA=86400
CACHE_TTL_COMPANIES=1800
CACHE_TTL_DASHBOARD_SUMMARY=1800
CACHE_TTL_DASHBOARD_TRENDS=1800
CACHE_TTL_INVOICES=600
CACHE_TTL_INVOICES_SUMMARY=900
CACHE_TTL_TREASURY=600
# Maintenance
CACHE_CLEANUP_INTERVAL=3600
# Event-Based Invalidation (experimental)
CACHE_AUTO_INVALIDATE=False
CACHE_CHECK_INTERVAL=300
# Performance Tracking
CACHE_TRACK_PERFORMANCE=True
CACHE_BENCHMARK_ON_STARTUP=False
# ============================================================================
# DATA ENTRY MODULE - CONFIGURATION
# ============================================================================
# Data Entry module settings (receipts, OCR, etc.)
# Environment identifier (dev/test/prod)
ORACLE_ENV=test
# SQLite Database (test)
SQLITE_DATABASE_PATH=data/receipts/receipts_test.db
# File uploads
UPLOAD_PATH=data/receipts/uploads
MAX_UPLOAD_SIZE_MB=10
# Test company (for testing)
TEST_COMPANY_ID=110
TEST_COMPANY_SCHEMA=MARIUSM_AUTO
# ============================================================================
# TELEGRAM MODULE - BOT CONFIGURATION (REQUIRED for Telegram features)
# ============================================================================
# Obtain bot token from @BotFather on Telegram
TELEGRAM_BOT_TOKEN=your_bot_token_from_botfather
# Backend URL for bot to communicate with API
BACKEND_URL=http://localhost:8000
# Internal API port (bot's internal API for backend callbacks)
INTERNAL_API_PORT=8002
# Enable internal API documentation (development only)
ENABLE_DOCS=false
# ============================================================================
# TELEGRAM MODULE - EMAIL AUTHENTICATION (SMTP) (REQUIRED for email 2FA)
# ============================================================================
# Required for email-based 2FA authentication flow
# Users can login with email + password instead of web app linking
# SMTP Server Configuration
SMTP_HOST=mail.romfast.ro
SMTP_PORT=587
SMTP_USER=ups@romfast.ro
SMTP_PASSWORD=your_smtp_password_here
SMTP_FROM_EMAIL=ups@romfast.ro
SMTP_FROM_NAME=ROA2WEB
SMTP_USE_TLS=true
# Email Retry Settings
EMAIL_MAX_RETRIES=3
EMAIL_RETRY_DELAY=2.0
# ============================================================================
# TELEGRAM MODULE - DATABASE (SQLite for bot data)
# ============================================================================
# Separate SQLite database for Telegram bot auth codes and sessions
TELEGRAM_SQLITE_DATABASE_PATH=data/telegram/telegram_test.db

212
backend/ENV-SETUP.md Normal file
View File

@@ -0,0 +1,212 @@
# Environment Configuration Guide
## Overview
The unified backend uses environment-specific configuration files that are automatically loaded by startup scripts.
**SECURITY**: All `.env*` files (except `.env*.example`) contain real credentials and are **NEVER committed to git**.
## File Structure
```
backend/
├── .env.prod.example # Production template (COMMITTED - no credentials)
├── .env.test.example # Test template (COMMITTED - no credentials)
├── .env.prod.example # Production template (COMMITTED - no credentials)
├── .env.example # Generic template (COMMITTED)
├── .env.prod # Production config (IGNORED - real credentials)
├── .env.test # Test config (IGNORED - real credentials)
├── .env.prod # Production config (IGNORED - real credentials)
└── .env # Active config (IGNORED - auto-generated)
```
## First-Time Setup
### Production
```bash
# 1. Copy template
cp backend/.env.prod.example backend/.env.prod
# 2. Edit with your credentials
vim backend/.env.prod
# 3. Fill in:
# - ORACLE_PASSWORD
# - JWT_SECRET_KEY (generate with: python3 -c "import secrets; print(secrets.token_urlsafe(32))")
# - AUTH_SESSION_SECRET (generate with: python3 -c "import secrets; print(secrets.token_urlsafe(32))")
# - TELEGRAM_BOT_TOKEN (from @BotFather)
# - SMTP_PASSWORD
# 4. Start
./start-prod.sh
```
### Test
```bash
# Same process with .env.test
cp backend/.env.test.example backend/.env.test
vim backend/.env.test
# Fill in TEST credentials (separate from dev!)
./start-test.sh
```
### Production
```bash
# Same process with .env.prod
cp backend/.env.prod.example backend/.env.prod
vim backend/.env.prod
# Fill in PRODUCTION credentials (generate NEW secrets!)
./start-backend.sh start
```
## How It Works
### Production
```bash
./start-prod.sh # Checks for .env.prod → copies to .env → starts backend
```
### Test
```bash
./start-test.sh # Checks for .env.test → copies to .env → starts backend
```
### Production
```bash
# Manual setup (one-time)
cp .env.prod.example .env.prod
vim .env.prod # Fill in credentials
# Then start
./start-backend.sh start
```
## Important Rules
### ✅ DO
- Copy `.env.*.example` to `.env.*` and fill in real credentials
- Edit `.env.prod` for production changes
- Edit `.env.test` for test environment changes
- Edit `.env.prod` for production
- Generate **new** secrets for each environment
- Keep `.env.prod`, `.env.test`, `.env.prod` **local only** (never commit!)
### ❌ DON'T
- Don't commit `.env`, `.env.prod`, `.env.test`, or `.env.prod` (they're in .gitignore)
- Don't manually edit `.env` (it's auto-generated!)
- Don't use same secrets across environments
- Don't share credentials via git (use secure channels)
- Don't put real credentials in `.env*.example` files
## Environment Differences
| Setting | .env.prod | .env.test | .env.prod |
|---------|----------|-----------|-----------|
| Oracle SID | `ROA` | `roa` | `ROA` |
| JWT Expire | 30 min | 480 min | 30 min |
| DEBUG | `true` | `true` | `false` |
| Cache DB | `roa2web_cache.db` | `roa2web_cache_test.db` | `roa2web_cache_prod.db` |
| Receipts DB | `receipts_dev.db` | `receipts_test.db` | `receipts_prod.db` |
| Telegram DB | `telegram.db` | `telegram_test.db` | `telegram_prod.db` |
## Security Notes
### Template Files (.env.*.example)
These contain **placeholders only**:
- ✅ Safe to commit to git
- ✅ Shared across team
- ✅ No real credentials
- 📖 Used as reference for first-time setup
### Actual Config Files (.env.prod, .env.test, .env.prod)
These contain **real credentials**:
-**NEVER commit to git** (in .gitignore)
- ❌ Never share via email/chat
- ✅ Keep local only
- ✅ Generate unique secrets per environment
- 🔐 Share securely if needed (encrypted vault, 1Password, etc.)
### Active Config (.env)
This is **auto-generated** and **ignored by git**:
- ❌ Never commit to git
- 🔄 Auto-overwritten by startup scripts
- 📝 Edit source files (.env.prod, .env.test) instead
## Generating Secrets
For `JWT_SECRET_KEY` and `AUTH_SESSION_SECRET`:
```bash
python3 -c "import secrets; print(secrets.token_urlsafe(32))"
```
Generate **different** secrets for dev, test, and production!
## Quick Reference
### First Time Setup
```bash
# 1. Copy template
cp backend/.env.prod.example backend/.env.prod
# 2. Fill credentials
vim backend/.env.prod
# 3. Start
./start-prod.sh
```
### Changing Configuration
```bash
# 1. Edit source file
vim backend/.env.prod
# 2. Restart to apply
./start-prod.sh
```
### Production Deployment
```bash
# 1. Copy template
cp backend/.env.prod.example backend/.env.prod
# 2. Fill in PRODUCTION values
vim backend/.env.prod
# 3. Generate NEW secrets
python3 -c "import secrets; print(secrets.token_urlsafe(32))"
# 4. Start backend
./start-backend.sh start
```
## Troubleshooting
### "Wrong database" error
Check that you're using the correct startup script:
- Production: `./start-prod.sh` (uses `.env.prod`)
- Test: `./start-test.sh` (uses `.env.test`)
### ".env.prod not found" error
First-time setup required:
```bash
cp backend/.env.prod.example backend/.env.prod
vim backend/.env.prod # Fill in your credentials
```
### Changes not taking effect
The `.env` file is regenerated on each start. Edit the source file (`.env.prod` or `.env.test`) instead.
### Checking what will be committed
```bash
git status backend/.env*
# Should show:
# modified: .env.prod.example (if you changed template)
# nothing else!
```
## Team Sharing
**Templates only** are committed to git:
- Share configuration structure via `.env*.example`
- Each developer creates their own `.env.prod` from template
- Never commit actual credentials
- Use secure channels for sharing sensitive values (1Password, encrypted vault, etc.)

View File

@@ -0,0 +1,102 @@
# Quick Environment Reference
## 🔒 SECURITY FIRST
**All `.env*` files (except `.env*.example`) contain real credentials and are NEVER committed to git!**
## 🚀 First-Time Setup
```bash
# 1. Copy template with real credentials
cp backend/.env.prod.example backend/.env.prod
# 2. Edit with YOUR credentials
vim backend/.env.prod
# 3. Fill in the placeholders:
# - ORACLE_PASSWORD
# - JWT_SECRET_KEY
# - AUTH_SESSION_SECRET
# - TELEGRAM_BOT_TOKEN
# - SMTP_PASSWORD
# 4. Start production
./start-prod.sh
```
## 📋 Daily Usage
```bash
# Production (uses .env.prod automatically)
./start-prod.sh
# Test Environment (uses .env.test automatically)
./start-test.sh
# Quick Restart (uses existing .env)
./start-backend.sh restart
```
## ✏️ Changing Configuration
```bash
# 1. Edit the source file (NOT .env!)
vim backend/.env.prod # Production
vim backend/.env.test # Test
# 2. Restart to apply changes
./start-prod.sh
```
## 📁 Which File to Edit?
| You Want To... | Edit This File |
|----------------|----------------|
| Change dev database password | `backend/.env.prod` |
| Update test server settings | `backend/.env.test` |
| Add new environment variable | Templates: `.env*.example` + your `.env.prod`/`.env.test` |
| Create production config | Copy `.env.prod.example` to `.env.prod` and fill secrets |
## 🔑 Generating Secrets
```bash
# For JWT_SECRET_KEY and AUTH_SESSION_SECRET
python3 -c "import secrets; print(secrets.token_urlsafe(32))"
```
**Generate DIFFERENT secrets for each environment (dev, test, prod)!**
## ⚠️ Important
- **Never edit** `backend/.env` directly (it's auto-generated!)
- **Always edit** `backend/.env.prod` or `.env.test`
- **Never commit** `.env`, `.env.prod`, `.env.test`, `.env.prod`
- **Only commit** `.env*.example` (templates with placeholders)
- Restart after changes for them to take effect
## 🛡️ Git Behavior
| File | Git Status | Contains |
|------|-----------|----------|
| `.env.prod.example` | ✅ Committed | Template (placeholders) |
| `.env.test.example` | ✅ Committed | Template (placeholders) |
| `.env.prod.example` | ✅ Committed | Template (placeholders) |
| `.env.example` | ✅ Committed | Generic template |
| `.env.prod` | ❌ Ignored | **Real dev credentials** |
| `.env.test` | ❌ Ignored | **Real test credentials** |
| `.env.prod` | ❌ Ignored | **Real prod credentials** |
| `.env` | ❌ Ignored | Auto-generated (current) |
## ✅ Quick Check
```bash
# See what git will commit
git status backend/.env*
# Should show ONLY .env*.example files
# If .env.prod or .env.test appear, they're NOT properly ignored!
```
## 📖 More Info
See `backend/ENV-SETUP.md` for complete documentation.

0
backend/__init__.py Normal file
View File

173
backend/config.py Normal file
View File

@@ -0,0 +1,173 @@
"""
Unified Configuration for ROA2WEB Backend
Consolidates settings from Reports, Data Entry, and Telegram modules
"""
import os
from pathlib import Path
from typing import List
from pydantic_settings import BaseSettings
from functools import lru_cache
class UnifiedSettings(BaseSettings):
"""Unified application settings for all modules."""
# ============================================================================
# GENERAL APPLICATION SETTINGS
# ============================================================================
app_name: str = "ROA2WEB Unified Backend"
app_version: str = "1.0.0"
debug: bool = False
api_host: str = "0.0.0.0"
api_port: int = 8000
# ============================================================================
# ORACLE DATABASE (Shared by all modules)
# ============================================================================
oracle_user: str = ""
oracle_password: str = ""
oracle_host: str = "localhost"
oracle_port: int = 1526
oracle_sid: str = "ROA"
# ============================================================================
# JWT AUTHENTICATION (Shared by all modules)
# ============================================================================
jwt_secret_key: str = "change-me-in-production"
jwt_algorithm: str = "HS256"
access_token_expire_minutes: int = 30
refresh_token_expire_days: int = 7
# ============================================================================
# SESSION SECURITY - EMAIL 2FA (Telegram module)
# ============================================================================
auth_session_secret: str = "change-me-in-production"
# ============================================================================
# CORS
# ============================================================================
cors_origins: str = "http://localhost:3000,http://localhost:5173"
# ============================================================================
# REPORTS MODULE - CACHE CONFIGURATION
# ============================================================================
reports_cache_enabled: bool = True
reports_cache_type: str = "hybrid"
reports_cache_sqlite_path: str = "./data/cache/roa2web_cache.db"
reports_cache_memory_max_size: int = 1000
reports_cache_default_ttl: int = 900
# Cache TTL per type (seconds)
reports_cache_ttl_schema: int = 86400
reports_cache_ttl_companies: int = 1800
reports_cache_ttl_dashboard_summary: int = 1800
reports_cache_ttl_dashboard_trends: int = 1800
reports_cache_ttl_invoices: int = 600
reports_cache_ttl_invoices_summary: int = 900
reports_cache_ttl_treasury: int = 600
# Cache maintenance
reports_cache_cleanup_interval: int = 3600
reports_cache_auto_invalidate: bool = False
reports_cache_check_interval: int = 300
reports_cache_track_performance: bool = True
reports_cache_benchmark_on_startup: bool = False
# ============================================================================
# DATA ENTRY MODULE - CONFIGURATION
# ============================================================================
data_entry_sqlite_database_path: str = "data/receipts/receipts.db"
data_entry_upload_path: str = "data/receipts/uploads"
data_entry_max_upload_size_mb: int = 10
data_entry_allowed_mime_types: List[str] = [
"image/jpeg",
"image/png",
"image/gif",
"image/webp",
"application/pdf",
]
# ============================================================================
# TELEGRAM MODULE - BOT CONFIGURATION
# ============================================================================
telegram_bot_token: str = ""
telegram_smtp_host: str = ""
telegram_smtp_port: int = 587
telegram_smtp_user: str = ""
telegram_smtp_password: str = ""
telegram_smtp_from_email: str = ""
telegram_smtp_from_name: str = "ROA2WEB"
telegram_smtp_use_tls: bool = True
telegram_email_max_retries: int = 3
telegram_email_retry_delay: float = 2.0
telegram_sqlite_database_path: str = "data/telegram/telegram.db"
class Config:
env_file = ".env"
env_file_encoding = "utf-8"
extra = "ignore"
case_sensitive = False
# ============================================================================
# COMPUTED PROPERTIES
# ============================================================================
@property
def oracle_dsn(self) -> str:
"""Get Oracle DSN string."""
return f"{self.oracle_host}:{self.oracle_port}/{self.oracle_sid}"
@property
def cors_origins_list(self) -> List[str]:
"""Get CORS origins as list."""
return [origin.strip() for origin in self.cors_origins.split(",")]
# Data Entry properties
@property
def data_entry_database_url(self) -> str:
"""Get SQLite database URL for async (Data Entry)."""
return f"sqlite+aiosqlite:///{self.data_entry_sqlite_database_path}"
@property
def data_entry_sync_database_url(self) -> str:
"""Get SQLite database URL for sync operations (Alembic)."""
return f"sqlite:///{self.data_entry_sqlite_database_path}"
@property
def data_entry_upload_path_resolved(self) -> Path:
"""Get resolved upload path."""
path = Path(self.data_entry_upload_path)
path.mkdir(parents=True, exist_ok=True)
return path
@property
def data_entry_max_upload_size_bytes(self) -> int:
"""Get max upload size in bytes."""
return self.data_entry_max_upload_size_mb * 1024 * 1024
# Reports cache properties
@property
def reports_cache_sqlite_path_resolved(self) -> Path:
"""Get resolved cache SQLite path."""
path = Path(self.reports_cache_sqlite_path)
path.parent.mkdir(parents=True, exist_ok=True)
return path
# Telegram properties
@property
def telegram_sqlite_path_resolved(self) -> Path:
"""Get resolved Telegram SQLite path."""
path = Path(self.telegram_sqlite_database_path)
path.parent.mkdir(parents=True, exist_ok=True)
return path
@lru_cache()
def get_settings() -> UnifiedSettings:
"""Get cached settings instance."""
return UnifiedSettings()
# Convenience instance
settings = get_settings()

45
backend/data/README.md Normal file
View File

@@ -0,0 +1,45 @@
# Backend Runtime Data
This directory contains runtime data generated by the unified backend.
## Directory Structure
```
data/
├── cache/ # Reports module cache (hybrid L1+L2)
│ └── *.db # SQLite L2 cache database
├── receipts/ # Data Entry module data
│ ├── *.db # SQLite receipts database
│ └── uploads/ # User-uploaded files (receipts, attachments)
└── telegram/ # Telegram bot data
└── *.db # SQLite bot auth/session database
```
## Git Behavior
- **Ignored**: All `*.db` files and `uploads/` contents
- **Committed**: Only `.gitkeep` files to preserve directory structure
## Environment-Specific Databases
Different environments use separate databases:
- **Development** (`.env.prod`):
- Cache: `roa2web_cache.db`
- Receipts: `receipts_dev.db`
- Telegram: `telegram.db`
- **Test** (`.env.test`):
- Cache: `roa2web_cache_test.db`
- Receipts: `receipts_test.db`
- Telegram: `telegram_test.db`
- **Production** (`.env.prod`):
- Cache: `roa2web_cache_prod.db`
- Receipts: `receipts_prod.db`
- Telegram: `telegram_prod.db`
## Auto-Created
All databases and directories are created automatically on first run.
No manual setup required.

0
backend/data/cache/.gitkeep vendored Normal file
View File

View File

View File

428
backend/main.py Normal file
View File

@@ -0,0 +1,428 @@
"""
ROA2WEB Unified Backend - Single FastAPI Application
Consolidates Reports, Data Entry, and Telegram modules into one process
"""
import asyncio
import logging
import os
import sys
from datetime import datetime
from pathlib import Path
from contextlib import asynccontextmanager
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
# Load environment variables
load_dotenv()
# Add project root and shared modules to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root)) # Enable 'from backend.xxx import yyy'
sys.path.insert(0, str(project_root / "shared")) # Enable 'from shared.xxx import yyy'
# Import configuration
from backend.config import settings
# Import shared infrastructure
from shared.database.oracle_pool import oracle_pool
from shared.auth.middleware import AuthenticationMiddleware
from shared.auth.routes import create_auth_router
from shared.routes.companies import create_companies_router
from shared.routes.calendar import create_calendar_router
# Import module router factories
from backend.modules.reports.routers import create_reports_router
from backend.modules.data_entry.routers import create_data_entry_router
from backend.modules.telegram.routers import create_telegram_router
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%H:%M:%S'
)
logger = logging.getLogger(__name__)
# Global variables for background tasks
telegram_bot_task = None
# ============================================================================
# INITIALIZATION FUNCTIONS
# ============================================================================
async def init_oracle_pool():
"""Initialize Oracle connection pool (shared by all modules)."""
logger.info("[ORACLE] Initializing connection pool...")
await oracle_pool.initialize()
logger.info("[ORACLE] ✅ Pool initialized successfully")
async def init_reports_cache():
"""Initialize Reports cache system."""
logger.info("[REPORTS] Initializing cache system...")
try:
from backend.modules.reports.cache import init_cache, init_event_monitor, get_cache
from backend.modules.reports.cache.config import CacheConfig
cache_config = CacheConfig.from_env()
await init_cache(cache_config)
logger.info(f"[REPORTS] ✅ Cache initialized: type={cache_config.cache_type}, enabled={cache_config.enabled}")
# Initialize event monitor
cache = get_cache()
await init_event_monitor(cache, cache_config)
if cache_config.auto_invalidate_enabled:
logger.info("[REPORTS] Event-based auto-invalidation ENABLED")
else:
logger.info("[REPORTS] Event-based auto-invalidation DISABLED")
except Exception as e:
logger.error(f"[REPORTS] ⚠️ Cache initialization error: {e}", exc_info=True)
logger.warning("[REPORTS] Continuing without cache")
async def init_data_entry_db():
"""Initialize Data Entry SQLite database."""
logger.info("[DATA-ENTRY] Initializing SQLite database...")
try:
from backend.modules.data_entry.db.database import init_db
await init_db()
logger.info(f"[DATA-ENTRY] ✅ Database initialized: {settings.data_entry_sqlite_database_path}")
# Ensure upload directory exists
settings.data_entry_upload_path_resolved
logger.info(f"[DATA-ENTRY] Upload path: {settings.data_entry_upload_path_resolved}")
except Exception as e:
logger.error(f"[DATA-ENTRY] ❌ Database initialization error: {e}", exc_info=True)
raise
async def init_telegram_db():
"""Initialize Telegram SQLite database."""
logger.info("[TELEGRAM] Initializing SQLite database...")
try:
from backend.modules.telegram.db import init_database, cleanup_expired_codes, cleanup_expired_sessions, cleanup_expired_email_codes
await init_database()
logger.info(f"[TELEGRAM] ✅ Database initialized: {settings.telegram_sqlite_database_path}")
# Cleanup expired data
expired_codes = await cleanup_expired_codes()
expired_sessions = await cleanup_expired_sessions()
expired_email_codes = await cleanup_expired_email_codes()
logger.info(f"[TELEGRAM] Cleanup: {expired_codes} codes, {expired_sessions} sessions, {expired_email_codes} email codes removed")
except Exception as e:
logger.error(f"[TELEGRAM] ❌ Database initialization error: {e}", exc_info=True)
raise
def init_paddle_ocr_background():
"""Initialize PaddleOCR in background thread (takes 15-20s)."""
try:
logger.info("[DATA-ENTRY] Pre-loading OCR engine (background)...")
from backend.modules.data_entry.services.ocr_service import ocr_service
ocr_service.ocr_engine._init_paddle_lazy()
logger.info("[DATA-ENTRY] ✅ OCR engine ready")
except Exception as e:
logger.warning(f"[DATA-ENTRY] ⚠️ OCR engine pre-load failed: {e}")
async def run_telegram_bot():
"""Run Telegram bot as background task."""
logger.info("[TELEGRAM] Starting bot...")
try:
from telegram.ext import Application, CommandHandler, CallbackQueryHandler, MessageHandler, filters
from backend.modules.telegram.bot.handlers import (
start_command, help_command, clear_command, companies_command,
unlink_command, selectcompany_command, dashboard_command, sold_command,
facturi_command, trezorerie_command, menu_command, trezorerie_casa_command,
trezorerie_banca_command, clienti_command, furnizori_command, evolutie_command,
clearcache_command, togglecache_command, handle_text_message, button_callback,
error_handler
)
from backend.modules.telegram.bot.email_handlers import email_login_handler
# Create Telegram application
application = Application.builder().token(settings.telegram_bot_token).build()
# Register handlers
application.add_handler(email_login_handler)
application.add_handler(CommandHandler("start", start_command))
application.add_handler(CommandHandler("menu", menu_command))
application.add_handler(CommandHandler("help", help_command))
application.add_handler(CommandHandler("unlink", unlink_command))
application.add_handler(CommandHandler("clear", clear_command))
application.add_handler(CommandHandler("companies", companies_command))
application.add_handler(CommandHandler("selectcompany", selectcompany_command))
application.add_handler(CommandHandler("dashboard", dashboard_command))
application.add_handler(CommandHandler("sold", sold_command))
application.add_handler(CommandHandler("facturi", facturi_command))
application.add_handler(CommandHandler("trezorerie", trezorerie_command))
application.add_handler(CommandHandler("trezorerie_casa", trezorerie_casa_command))
application.add_handler(CommandHandler("trezorerie_banca", trezorerie_banca_command))
application.add_handler(CommandHandler("clienti", clienti_command))
application.add_handler(CommandHandler("furnizori", furnizori_command))
application.add_handler(CommandHandler("evolutie", evolutie_command))
application.add_handler(CommandHandler("clearcache", clearcache_command))
application.add_handler(CommandHandler("togglecache", togglecache_command))
application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, handle_text_message))
application.add_handler(CallbackQueryHandler(button_callback))
application.add_error_handler(error_handler)
# Initialize and start
await application.initialize()
await application.start()
await application.updater.start_polling(drop_pending_updates=True)
bot_info = await application.bot.get_me()
logger.info(f"[TELEGRAM] ✅ Bot running: @{bot_info.username}")
# Keep bot running
while True:
await asyncio.sleep(1)
except asyncio.CancelledError:
logger.info("[TELEGRAM] Bot task cancelled, stopping...")
if 'application' in locals():
await application.updater.stop()
await application.stop()
await application.shutdown()
logger.info("[TELEGRAM] ✅ Bot stopped")
raise
except Exception as e:
logger.error(f"[TELEGRAM] ❌ Bot error: {e}", exc_info=True)
raise
# ============================================================================
# FASTAPI APPLICATION
# ============================================================================
app = FastAPI(
title="ROA2WEB Unified Backend",
description="Unified FastAPI backend for Reports, Data Entry, and Telegram modules",
version="1.0.0"
)
# ============================================================================
# STARTUP/SHUTDOWN EVENT HANDLERS
# ============================================================================
@app.on_event("startup")
async def startup_event():
"""Application startup - Initialize all resources."""
global telegram_bot_task
logger.info("=" * 80)
logger.info("[STARTUP] ROA2WEB Unified Backend")
logger.info("=" * 80)
try:
# Step 1: Initialize Oracle pool (shared by all modules)
await init_oracle_pool()
# Step 2: Parallel initialization of module-specific resources
logger.info("[STARTUP] Initializing module resources in parallel...")
await asyncio.gather(
init_reports_cache(),
init_data_entry_db(),
init_telegram_db(),
)
# Step 3: Start PaddleOCR initialization in background thread
import threading
threading.Thread(target=init_paddle_ocr_background, daemon=True).start()
# Step 4: Start Telegram bot as background task
if settings.telegram_bot_token:
telegram_bot_task = asyncio.create_task(run_telegram_bot())
logger.info("[STARTUP] ✅ Telegram bot task created")
else:
logger.warning("[STARTUP] ⚠️ TELEGRAM_BOT_TOKEN not set, bot disabled")
logger.info("=" * 80)
logger.info("[STARTUP] ✅ All modules initialized successfully")
logger.info(f"[STARTUP] ✅ Server running on http://{settings.api_host}:{settings.api_port}")
logger.info("=" * 80)
except Exception as e:
logger.error(f"[STARTUP] ❌ Initialization failed: {e}", exc_info=True)
raise
@app.on_event("shutdown")
async def shutdown_event():
"""Application shutdown - Cleanup resources."""
global telegram_bot_task
logger.info("=" * 80)
logger.info("[SHUTDOWN] Stopping ROA2WEB Unified Backend...")
logger.info("=" * 80)
try:
# Stop Telegram bot
if telegram_bot_task and not telegram_bot_task.done():
logger.info("[SHUTDOWN] Stopping Telegram bot...")
telegram_bot_task.cancel()
try:
await telegram_bot_task
except asyncio.CancelledError:
pass
# Stop Reports cache event monitor
try:
from backend.modules.reports.cache import close_cache, get_event_monitor
monitor = get_event_monitor()
if monitor:
await monitor.stop()
logger.info("[SHUTDOWN] Reports cache monitor stopped")
await close_cache()
logger.info("[SHUTDOWN] Reports cache closed")
except Exception as e:
logger.error(f"[SHUTDOWN] Cache error: {e}")
# Close Oracle pool
await oracle_pool.close_pool()
logger.info("[SHUTDOWN] Oracle pool closed")
logger.info("=" * 80)
logger.info("[SHUTDOWN] ✅ Shutdown complete")
logger.info("=" * 80)
except Exception as e:
logger.error(f"[SHUTDOWN] Error during shutdown: {e}", exc_info=True)
# ============================================================================
# MIDDLEWARE
# ============================================================================
# CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allow all origins for production deployment
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Authentication middleware
app.add_middleware(
AuthenticationMiddleware,
excluded_paths=[
"/", "/docs", "/health", "/redoc", "/openapi.json",
"/api/auth/login", "/api/auth/refresh",
"/api/telegram/auth/verify-user",
"/api/telegram/auth/verify-email",
"/api/telegram/auth/login-with-email",
"/api/telegram/auth/refresh-token",
"/api/telegram/health"
]
)
# ============================================================================
# ROUTER REGISTRATION
# ============================================================================
# Module routers with prefixes
app.include_router(create_reports_router(), prefix="/api/reports", tags=["reports"])
app.include_router(create_data_entry_router(), prefix="/api/data-entry", tags=["data-entry"])
app.include_router(create_telegram_router(), prefix="/api/telegram", tags=["telegram"])
# Shared routers
auth_router = create_auth_router(prefix="", tags=["authentication"])
app.include_router(auth_router, prefix="/api/auth")
companies_router = create_companies_router(oracle_pool, tags=["companies"])
app.include_router(companies_router, prefix="/api/companies")
calendar_router = create_calendar_router(oracle_pool, tags=["calendar"])
app.include_router(calendar_router, prefix="/api/calendar")
# ============================================================================
# ROOT & HEALTH ENDPOINTS
# ============================================================================
@app.get("/")
async def root():
"""Root endpoint - API information."""
return {
"name": settings.app_name,
"version": settings.app_version,
"status": "running",
"modules": ["reports", "data-entry", "telegram"],
"docs": "/docs",
"health": "/health"
}
@app.get("/health")
async def health_check():
"""Health check endpoint with module status."""
health_status = {
"api": "healthy",
"timestamp": datetime.utcnow().isoformat(),
"modules": {}
}
# Check Oracle connection
try:
async with oracle_pool.get_connection() as conn:
with conn.cursor() as cursor:
cursor.execute("SELECT 1 FROM DUAL")
health_status["modules"]["oracle"] = "connected"
except Exception as e:
health_status["modules"]["oracle"] = f"error: {str(e)}"
# Check Reports cache
try:
from backend.modules.reports.cache import get_cache
cache = get_cache()
health_status["modules"]["reports_cache"] = "initialized" if cache else "disabled"
except Exception as e:
health_status["modules"]["reports_cache"] = f"error: {str(e)}"
# Check Data Entry DB
try:
db_path = Path(settings.data_entry_sqlite_database_path)
health_status["modules"]["data_entry_db"] = "exists" if db_path.exists() else "missing"
except Exception as e:
health_status["modules"]["data_entry_db"] = f"error: {str(e)}"
# Check Telegram bot
global telegram_bot_task
if telegram_bot_task:
if telegram_bot_task.done():
health_status["modules"]["telegram_bot"] = "stopped"
else:
health_status["modules"]["telegram_bot"] = "running"
else:
health_status["modules"]["telegram_bot"] = "disabled"
return health_status
# ============================================================================
# MAIN ENTRY POINT
# ============================================================================
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"backend.main:app",
host=settings.api_host,
port=settings.api_port,
reload=False,
log_level="info"
)

View File

View File

View File

@@ -0,0 +1,96 @@
"""Application configuration using pydantic-settings."""
import os
from pathlib import Path
from typing import List
from pydantic_settings import BaseSettings
from functools import lru_cache
class Settings(BaseSettings):
"""Application settings loaded from environment variables."""
# App info
app_name: str = "Data Entry API"
app_version: str = "1.0.0"
debug: bool = False
# API
api_host: str = "0.0.0.0"
api_port: int = 8003
# SQLite Database
sqlite_database_path: str = "data/receipts/receipts.db"
# File uploads
upload_path: str = "data/uploads"
max_upload_size_mb: int = 10
allowed_mime_types: List[str] = [
"image/jpeg",
"image/png",
"image/gif",
"image/webp",
"application/pdf",
]
# Oracle Database (for nomenclatures)
oracle_user: str = ""
oracle_password: str = ""
oracle_host: str = "localhost"
oracle_port: int = 1526
oracle_sid: str = "ROA"
# JWT Authentication
jwt_secret_key: str = "change-me-in-production"
jwt_algorithm: str = "HS256"
jwt_expire_minutes: int = 480
# CORS
cors_origins: str = "http://localhost:3010,http://localhost:3000"
class Config:
env_file = ".env"
env_file_encoding = "utf-8"
extra = "ignore"
@property
def database_url(self) -> str:
"""Get SQLite database URL for async."""
return f"sqlite+aiosqlite:///{self.sqlite_database_path}"
@property
def sync_database_url(self) -> str:
"""Get SQLite database URL for sync operations (Alembic)."""
return f"sqlite:///{self.sqlite_database_path}"
@property
def upload_path_resolved(self) -> Path:
"""Get resolved upload path."""
path = Path(self.upload_path)
path.mkdir(parents=True, exist_ok=True)
return path
@property
def max_upload_size_bytes(self) -> int:
"""Get max upload size in bytes."""
return self.max_upload_size_mb * 1024 * 1024
@property
def cors_origins_list(self) -> List[str]:
"""Get CORS origins as list."""
return [origin.strip() for origin in self.cors_origins.split(",")]
@property
def oracle_dsn(self) -> str:
"""Get Oracle DSN string."""
return f"{self.oracle_host}:{self.oracle_port}/{self.oracle_sid}"
@lru_cache()
def get_settings() -> Settings:
"""Get cached settings instance."""
return Settings()
# Convenience instance
settings = get_settings()

View File

@@ -0,0 +1,4 @@
# Database module
from .database import get_session, init_db, engine
__all__ = ["get_session", "init_db", "engine"]

View File

@@ -0,0 +1,10 @@
# CRUD operations
from .receipt import ReceiptCRUD
from .attachment import AttachmentCRUD
from .accounting_entry import AccountingEntryCRUD
__all__ = [
"ReceiptCRUD",
"AttachmentCRUD",
"AccountingEntryCRUD",
]

View File

@@ -0,0 +1,197 @@
"""CRUD operations for accounting entries."""
from datetime import datetime
from typing import Optional, List
from sqlalchemy import select, delete
from sqlalchemy.ext.asyncio import AsyncSession
from backend.modules.data_entry.db.models.accounting_entry import AccountingEntry, EntryType
from backend.modules.data_entry.schemas.receipt import AccountingEntryCreate, AccountingEntryUpdate
class AccountingEntryCRUD:
"""CRUD operations for AccountingEntry model."""
@staticmethod
async def create(
session: AsyncSession,
receipt_id: int,
data: AccountingEntryCreate,
sort_order: int = 0,
is_auto_generated: bool = True,
) -> AccountingEntry:
"""Create a new accounting entry."""
entry = AccountingEntry(
receipt_id=receipt_id,
entry_type=data.entry_type,
account_code=data.account_code,
account_name=data.account_name,
amount=data.amount,
partner_id=data.partner_id,
cost_center_id=data.cost_center_id,
is_auto_generated=is_auto_generated,
sort_order=sort_order,
)
session.add(entry)
await session.commit()
await session.refresh(entry)
return entry
@staticmethod
async def create_bulk(
session: AsyncSession,
receipt_id: int,
entries: List[AccountingEntryCreate],
is_auto_generated: bool = True,
) -> List[AccountingEntry]:
"""Create multiple accounting entries at once."""
created_entries = []
for idx, entry_data in enumerate(entries):
entry = AccountingEntry(
receipt_id=receipt_id,
entry_type=entry_data.entry_type,
account_code=entry_data.account_code,
account_name=entry_data.account_name,
amount=entry_data.amount,
partner_id=entry_data.partner_id,
cost_center_id=entry_data.cost_center_id,
is_auto_generated=is_auto_generated,
sort_order=idx,
)
session.add(entry)
created_entries.append(entry)
await session.commit()
for entry in created_entries:
await session.refresh(entry)
return created_entries
@staticmethod
async def get_by_id(
session: AsyncSession,
entry_id: int,
) -> Optional[AccountingEntry]:
"""Get accounting entry by ID."""
query = select(AccountingEntry).where(AccountingEntry.id == entry_id)
result = await session.execute(query)
return result.scalar_one_or_none()
@staticmethod
async def get_by_receipt_id(
session: AsyncSession,
receipt_id: int,
) -> List[AccountingEntry]:
"""Get all accounting entries for a receipt."""
query = select(AccountingEntry).where(
AccountingEntry.receipt_id == receipt_id
).order_by(AccountingEntry.sort_order.asc())
result = await session.execute(query)
return list(result.scalars().all())
@staticmethod
async def update(
session: AsyncSession,
entry: AccountingEntry,
data: AccountingEntryUpdate,
modified_by: str,
) -> AccountingEntry:
"""Update an accounting entry."""
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(entry, field, value)
entry.is_auto_generated = False
entry.modified_by = modified_by
entry.modified_at = datetime.utcnow()
session.add(entry)
await session.commit()
await session.refresh(entry)
return entry
@staticmethod
async def delete(session: AsyncSession, entry: AccountingEntry) -> bool:
"""Delete an accounting entry."""
await session.delete(entry)
await session.commit()
return True
@staticmethod
async def delete_all_for_receipt(session: AsyncSession, receipt_id: int) -> int:
"""Delete all accounting entries for a receipt."""
query = delete(AccountingEntry).where(AccountingEntry.receipt_id == receipt_id)
result = await session.execute(query)
await session.commit()
return result.rowcount
@staticmethod
async def replace_all_for_receipt(
session: AsyncSession,
receipt_id: int,
entries: List[AccountingEntryCreate],
modified_by: str,
) -> List[AccountingEntry]:
"""Replace all entries for a receipt with new ones."""
# Delete existing entries
await AccountingEntryCRUD.delete_all_for_receipt(session, receipt_id)
# Create new entries (marked as manually modified)
created_entries = []
for idx, entry_data in enumerate(entries):
entry = AccountingEntry(
receipt_id=receipt_id,
entry_type=entry_data.entry_type,
account_code=entry_data.account_code,
account_name=entry_data.account_name,
amount=entry_data.amount,
partner_id=entry_data.partner_id,
cost_center_id=entry_data.cost_center_id,
is_auto_generated=False,
modified_by=modified_by,
modified_at=datetime.utcnow(),
sort_order=idx,
)
session.add(entry)
created_entries.append(entry)
await session.commit()
for entry in created_entries:
await session.refresh(entry)
return created_entries
@staticmethod
async def validate_entries(entries: List[AccountingEntryCreate]) -> tuple[bool, str]:
"""
Validate accounting entries.
Returns (is_valid, error_message).
"""
if not entries:
return False, "At least one entry is required"
total_debit = sum(
e.amount for e in entries if e.entry_type == EntryType.DEBIT
)
total_credit = sum(
e.amount for e in entries if e.entry_type == EntryType.CREDIT
)
# Check balance (debit should equal credit)
if abs(total_debit - total_credit) > 0.01:
return False, f"Entries not balanced: Debit={total_debit}, Credit={total_credit}"
# Check for valid account codes
for entry in entries:
if not entry.account_code or len(entry.account_code) < 3:
return False, f"Invalid account code: {entry.account_code}"
return True, ""

View File

@@ -0,0 +1,140 @@
"""CRUD operations for receipt attachments."""
import os
import uuid
import aiofiles
from datetime import datetime
from pathlib import Path
from typing import Optional, List
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi import UploadFile
from backend.modules.data_entry.db.models.receipt import ReceiptAttachment
from backend.modules.data_entry.config import settings
class AttachmentCRUD:
"""CRUD operations for ReceiptAttachment model."""
@staticmethod
def _generate_stored_filename(original_filename: str) -> str:
"""Generate unique filename for storage."""
ext = Path(original_filename).suffix.lower()
return f"{uuid.uuid4()}{ext}"
@staticmethod
def _get_upload_path(stored_filename: str) -> Path:
"""Get full path for storing file, organized by year/month."""
now = datetime.utcnow()
relative_path = Path(str(now.year)) / f"{now.month:02d}"
full_path = settings.upload_path_resolved / relative_path
# Ensure directory exists
full_path.mkdir(parents=True, exist_ok=True)
return relative_path / stored_filename
@staticmethod
async def create(
session: AsyncSession,
receipt_id: int,
file: UploadFile,
) -> ReceiptAttachment:
"""Create attachment by saving file and creating DB record."""
# Generate stored filename
stored_filename = AttachmentCRUD._generate_stored_filename(file.filename or "upload")
# Get relative path
relative_path = AttachmentCRUD._get_upload_path(stored_filename)
# Full path for saving
full_path = settings.upload_path_resolved / relative_path
# Read file content
content = await file.read()
file_size = len(content)
# Validate file size
if file_size > settings.max_upload_size_bytes:
raise ValueError(f"File too large. Maximum size is {settings.max_upload_size_mb}MB")
# Validate MIME type
mime_type = file.content_type or "application/octet-stream"
if mime_type not in settings.allowed_mime_types:
raise ValueError(f"File type not allowed: {mime_type}")
# Save file
async with aiofiles.open(full_path, "wb") as f:
await f.write(content)
# Create DB record
attachment = ReceiptAttachment(
receipt_id=receipt_id,
filename=file.filename or "upload",
stored_filename=stored_filename,
file_path=str(relative_path),
file_size=file_size,
mime_type=mime_type,
)
session.add(attachment)
await session.commit()
await session.refresh(attachment)
return attachment
@staticmethod
async def get_by_id(
session: AsyncSession,
attachment_id: int,
) -> Optional[ReceiptAttachment]:
"""Get attachment by ID."""
query = select(ReceiptAttachment).where(ReceiptAttachment.id == attachment_id)
result = await session.execute(query)
return result.scalar_one_or_none()
@staticmethod
async def get_by_receipt_id(
session: AsyncSession,
receipt_id: int,
) -> List[ReceiptAttachment]:
"""Get all attachments for a receipt."""
query = select(ReceiptAttachment).where(
ReceiptAttachment.receipt_id == receipt_id
).order_by(ReceiptAttachment.uploaded_at.asc())
result = await session.execute(query)
return list(result.scalars().all())
@staticmethod
def get_file_path(attachment: ReceiptAttachment) -> Path:
"""Get full file path for an attachment."""
return settings.upload_path_resolved / attachment.file_path
@staticmethod
async def delete(session: AsyncSession, attachment: ReceiptAttachment) -> bool:
"""Delete attachment (file and DB record)."""
# Delete file
file_path = AttachmentCRUD.get_file_path(attachment)
if file_path.exists():
os.remove(file_path)
# Delete DB record
await session.delete(attachment)
await session.commit()
return True
@staticmethod
async def delete_all_for_receipt(session: AsyncSession, receipt_id: int) -> int:
"""Delete all attachments for a receipt."""
attachments = await AttachmentCRUD.get_by_receipt_id(session, receipt_id)
count = 0
for attachment in attachments:
await AttachmentCRUD.delete(session, attachment)
count += 1
return count

View File

@@ -0,0 +1,324 @@
"""CRUD operations for receipts."""
import json
from datetime import datetime, date
from decimal import Decimal
from typing import Optional, List, Tuple, Any
from sqlalchemy import select, func, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from backend.modules.data_entry.db.models.receipt import Receipt, ReceiptStatus
from backend.modules.data_entry.schemas.receipt import ReceiptCreate, ReceiptUpdate, ReceiptFilter
def _serialize_tva_breakdown(tva_breakdown: Optional[List[Any]]) -> Optional[str]:
"""Serialize TVA breakdown list to JSON string for SQLite storage."""
if tva_breakdown is None:
return None
# Convert Decimal to float for JSON serialization
serializable = []
for entry in tva_breakdown:
if hasattr(entry, 'model_dump'):
# Pydantic model
item = entry.model_dump()
elif isinstance(entry, dict):
item = entry.copy()
else:
item = dict(entry)
# Convert Decimal to float
if 'amount' in item and isinstance(item['amount'], Decimal):
item['amount'] = float(item['amount'])
serializable.append(item)
return json.dumps(serializable)
def _serialize_payment_methods(payment_methods: Optional[List[Any]]) -> Optional[str]:
"""Serialize payment methods list to JSON string for SQLite storage."""
if payment_methods is None:
return None
serializable = []
for pm in payment_methods:
if hasattr(pm, 'model_dump'):
item = pm.model_dump()
elif isinstance(pm, dict):
item = pm.copy()
else:
item = dict(pm)
# Convert Decimal to float for JSON
if 'amount' in item:
if hasattr(item['amount'], '__float__'):
item['amount'] = float(item['amount'])
serializable.append(item)
return json.dumps(serializable)
class ReceiptCRUD:
"""CRUD operations for Receipt model."""
@staticmethod
async def create(
session: AsyncSession,
data: ReceiptCreate,
created_by: str,
) -> Receipt:
"""Create a new receipt."""
# Get data as dict and serialize tva_breakdown and payment_methods to JSON string
receipt_data = data.model_dump()
receipt_data['tva_breakdown'] = _serialize_tva_breakdown(receipt_data.get('tva_breakdown'))
receipt_data['payment_methods'] = _serialize_payment_methods(receipt_data.get('payment_methods'))
receipt = Receipt(
**receipt_data,
created_by=created_by,
status=ReceiptStatus.DRAFT,
)
session.add(receipt)
await session.commit()
await session.refresh(receipt)
# Reload with relationships to avoid lazy loading issues with async
return await ReceiptCRUD.get_by_id(session, receipt.id, include_relations=True)
@staticmethod
async def get_by_id(
session: AsyncSession,
receipt_id: int,
include_relations: bool = True,
) -> Optional[Receipt]:
"""Get receipt by ID, optionally with relationships."""
query = select(Receipt).where(Receipt.id == receipt_id)
if include_relations:
query = query.options(
selectinload(Receipt.attachments),
selectinload(Receipt.entries),
)
result = await session.execute(query)
return result.scalar_one_or_none()
@staticmethod
async def get_list(
session: AsyncSession,
filters: ReceiptFilter,
) -> Tuple[List[Receipt], int]:
"""Get paginated list of receipts with filters."""
# Base query
query = select(Receipt).options(
selectinload(Receipt.attachments),
selectinload(Receipt.entries),
)
# Apply filters
if filters.status:
query = query.where(Receipt.status == filters.status)
if filters.direction:
query = query.where(Receipt.direction == filters.direction)
if filters.company_id:
query = query.where(Receipt.company_id == filters.company_id)
if filters.created_by:
query = query.where(Receipt.created_by == filters.created_by)
if filters.date_from:
query = query.where(Receipt.receipt_date >= filters.date_from)
if filters.date_to:
query = query.where(Receipt.receipt_date <= filters.date_to)
if filters.search:
search_term = f"%{filters.search}%"
query = query.where(
or_(
Receipt.description.ilike(search_term),
Receipt.partner_name.ilike(search_term),
Receipt.receipt_number.ilike(search_term),
)
)
# Count total
count_query = select(func.count()).select_from(query.subquery())
total_result = await session.execute(count_query)
total = total_result.scalar() or 0
# Apply pagination and ordering
query = query.order_by(Receipt.created_at.desc())
offset = (filters.page - 1) * filters.page_size
query = query.offset(offset).limit(filters.page_size)
# Execute
result = await session.execute(query)
receipts = result.scalars().all()
return list(receipts), total
@staticmethod
async def get_pending_review(
session: AsyncSession,
company_id: Optional[int] = None,
) -> List[Receipt]:
"""Get all receipts pending review."""
query = select(Receipt).where(
Receipt.status == ReceiptStatus.PENDING_REVIEW
).options(
selectinload(Receipt.attachments),
selectinload(Receipt.entries),
)
if company_id:
query = query.where(Receipt.company_id == company_id)
query = query.order_by(Receipt.submitted_at.asc())
result = await session.execute(query)
return list(result.scalars().all())
@staticmethod
async def update(
session: AsyncSession,
receipt: Receipt,
data: ReceiptUpdate,
) -> Receipt:
"""Update receipt fields."""
update_data = data.model_dump(exclude_unset=True)
# Serialize tva_breakdown and payment_methods to JSON string if present
if 'tva_breakdown' in update_data:
update_data['tva_breakdown'] = _serialize_tva_breakdown(update_data['tva_breakdown'])
if 'payment_methods' in update_data:
update_data['payment_methods'] = _serialize_payment_methods(update_data['payment_methods'])
for field, value in update_data.items():
setattr(receipt, field, value)
receipt.updated_at = datetime.utcnow()
session.add(receipt)
await session.commit()
await session.refresh(receipt)
# Reload with relationships to avoid lazy loading issues with async
return await ReceiptCRUD.get_by_id(session, receipt.id, include_relations=True)
@staticmethod
async def update_status(
session: AsyncSession,
receipt: Receipt,
new_status: ReceiptStatus,
reviewed_by: Optional[str] = None,
rejection_reason: Optional[str] = None,
) -> Receipt:
"""Update receipt workflow status."""
receipt.status = new_status
receipt.updated_at = datetime.utcnow()
if new_status == ReceiptStatus.PENDING_REVIEW:
receipt.submitted_at = datetime.utcnow()
if new_status in [ReceiptStatus.APPROVED, ReceiptStatus.REJECTED]:
receipt.reviewed_by = reviewed_by
receipt.reviewed_at = datetime.utcnow()
if new_status == ReceiptStatus.REJECTED:
receipt.rejection_reason = rejection_reason
if new_status == ReceiptStatus.DRAFT:
# Reset review fields when moving back to draft
receipt.rejection_reason = None
session.add(receipt)
await session.commit()
await session.refresh(receipt)
# Reload with relationships to avoid lazy loading issues with async
return await ReceiptCRUD.get_by_id(session, receipt.id, include_relations=True)
@staticmethod
async def delete(session: AsyncSession, receipt: Receipt) -> bool:
"""Delete a receipt (cascade deletes attachments and entries)."""
await session.delete(receipt)
await session.commit()
return True
@staticmethod
async def can_edit(receipt: Receipt, username: str) -> bool:
"""Check if user can edit receipt."""
# DRAFT and REJECTED receipts can be edited (to fix and resubmit)
if receipt.status not in [ReceiptStatus.DRAFT, ReceiptStatus.REJECTED]:
return False
# Only creator can edit their own receipts
return receipt.created_by == username
@staticmethod
async def can_delete(receipt: Receipt, username: str) -> bool:
"""Check if user can delete receipt."""
# Only DRAFT receipts can be deleted
if receipt.status != ReceiptStatus.DRAFT:
return False
# Only creator can delete their own drafts
return receipt.created_by == username
@staticmethod
async def can_submit(receipt: Receipt, username: str) -> bool:
"""Check if user can submit receipt for review."""
# Only DRAFT or REJECTED receipts can be submitted
if receipt.status not in [ReceiptStatus.DRAFT, ReceiptStatus.REJECTED]:
return False
# Only creator can submit their own receipts
return receipt.created_by == username
@staticmethod
async def get_stats(
session: AsyncSession,
company_id: int,
created_by: Optional[str] = None,
) -> dict:
"""Get receipt statistics."""
base_query = select(
Receipt.status,
func.count(Receipt.id).label("count"),
func.sum(Receipt.amount).label("total_amount"),
).where(
Receipt.company_id == company_id
)
if created_by:
base_query = base_query.where(Receipt.created_by == created_by)
query = base_query.group_by(Receipt.status)
result = await session.execute(query)
rows = result.all()
stats = {
"draft": {"count": 0, "amount": 0},
"pending_review": {"count": 0, "amount": 0},
"approved": {"count": 0, "amount": 0},
"rejected": {"count": 0, "amount": 0},
"synced": {"count": 0, "amount": 0},
"total": {"count": 0, "amount": 0},
}
for row in rows:
status_key = row.status.value
stats[status_key] = {
"count": row.count,
"amount": float(row.total_amount or 0),
}
stats["total"]["count"] += row.count
stats["total"]["amount"] += float(row.total_amount or 0)
return stats

View File

@@ -0,0 +1,49 @@
"""Database configuration and session management using SQLModel."""
from pathlib import Path
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from sqlmodel import SQLModel
from backend.modules.data_entry.config import settings
# Create async engine
engine = create_async_engine(
settings.database_url,
echo=settings.debug,
future=True,
)
# Create async session factory
async_session_maker = sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False,
)
async def init_db() -> None:
"""Initialize database - create tables if they don't exist."""
# Ensure data directory exists
db_path = Path(settings.sqlite_database_path)
db_path.parent.mkdir(parents=True, exist_ok=True)
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
async def get_session() -> AsyncGenerator[AsyncSession, None]:
"""Get async database session for dependency injection."""
async with async_session_maker() as session:
try:
yield session
finally:
await session.close()
# Convenience function for manual session usage
async def get_db_session() -> AsyncSession:
"""Get a new database session (manual management)."""
return async_session_maker()

View File

@@ -0,0 +1,17 @@
# Database models
from .receipt import Receipt, ReceiptAttachment, ReceiptStatus, ReceiptType, ReceiptDirection
from .accounting_entry import AccountingEntry, EntryType
from .nomenclature import SyncedSupplier, LocalSupplier, SyncedCashRegister
__all__ = [
"Receipt",
"ReceiptAttachment",
"ReceiptStatus",
"ReceiptType",
"ReceiptDirection",
"AccountingEntry",
"EntryType",
"SyncedSupplier",
"LocalSupplier",
"SyncedCashRegister",
]

View File

@@ -0,0 +1,49 @@
"""AccountingEntry SQLModel model for proposed accounting entries."""
from datetime import datetime
from decimal import Decimal
from enum import Enum
from typing import Optional, TYPE_CHECKING
from sqlmodel import SQLModel, Field, Relationship
if TYPE_CHECKING:
from .receipt import Receipt
class EntryType(str, Enum):
"""Type of accounting entry."""
DEBIT = "debit"
CREDIT = "credit"
class AccountingEntry(SQLModel, table=True):
"""Proposed accounting entry for a receipt."""
__tablename__ = "accounting_entries"
id: Optional[int] = Field(default=None, primary_key=True)
receipt_id: int = Field(foreign_key="receipts.id", index=True)
# Account
entry_type: EntryType
account_code: str = Field(max_length=20) # e.g., 6022, 5311, 4426
account_name: Optional[str] = Field(default=None, max_length=200) # Cache: "Cheltuieli combustibil"
# Amount
amount: Decimal = Field(decimal_places=2, max_digits=15)
# Analytics (optional)
partner_id: Optional[int] = Field(default=None)
cost_center_id: Optional[int] = Field(default=None)
# Entry metadata
is_auto_generated: bool = Field(default=True) # True if system-generated
modified_by: Optional[str] = Field(default=None, max_length=100) # Username if modified
modified_at: Optional[datetime] = Field(default=None)
# Order for display
sort_order: int = Field(default=0)
# Relationship
receipt: Optional["Receipt"] = Relationship(back_populates="entries")

View File

@@ -0,0 +1,46 @@
"""Nomenclature models for synced and local data."""
from typing import Optional
from datetime import datetime
from sqlmodel import SQLModel, Field
class SyncedSupplier(SQLModel, table=True):
"""Suppliers synced from Oracle NOM_PARTENERI."""
__tablename__ = "synced_suppliers"
id: Optional[int] = Field(default=None, primary_key=True)
oracle_id: int = Field(index=True) # Original Oracle ID
company_id: int = Field(index=True) # Company this supplier belongs to
name: str = Field(max_length=200)
fiscal_code: Optional[str] = Field(default=None, max_length=50, index=True) # CUI/CIF
address: Optional[str] = Field(default=None, max_length=500)
synced_at: datetime = Field(default_factory=datetime.utcnow)
class LocalSupplier(SQLModel, table=True):
"""Suppliers created locally from OCR (not in Oracle)."""
__tablename__ = "local_suppliers"
id: Optional[int] = Field(default=None, primary_key=True)
company_id: int = Field(index=True)
name: str = Field(max_length=200)
fiscal_code: Optional[str] = Field(default=None, max_length=50, index=True)
address: Optional[str] = Field(default=None, max_length=500)
created_by: str = Field(max_length=100) # Username who created it
created_at: datetime = Field(default_factory=datetime.utcnow)
# Flag to indicate if it should be synced to Oracle later
pending_oracle_sync: bool = Field(default=True)
class SyncedCashRegister(SQLModel, table=True):
"""Cash registers and bank accounts synced from Oracle."""
__tablename__ = "synced_cash_registers"
id: Optional[int] = Field(default=None, primary_key=True)
oracle_id: int = Field(index=True)
company_id: int = Field(index=True)
name: str = Field(max_length=100)
account_code: str = Field(max_length=20) # 5311, 5121, etc.
register_type: str = Field(max_length=10) # 'cash' or 'bank'
synced_at: datetime = Field(default_factory=datetime.utcnow)

View File

@@ -0,0 +1,127 @@
"""Receipt and ReceiptAttachment SQLModel models."""
from datetime import datetime, date
from decimal import Decimal
from enum import Enum
from typing import Optional, List, TYPE_CHECKING
from sqlmodel import SQLModel, Field, Relationship
class ReceiptType(str, Enum):
"""Type of receipt document."""
BON_FISCAL = "bon_fiscal"
CHITANTA = "chitanta"
class ReceiptDirection(str, Enum):
"""Direction of receipt - expense or income."""
CHELTUIALA = "cheltuiala" # Expense (receipt from supplier)
INCASARE = "incasare" # Income (receipt issued to client)
class ReceiptStatus(str, Enum):
"""Workflow status of receipt."""
DRAFT = "draft" # User is filling in data
PENDING_REVIEW = "pending_review" # Awaiting accountant approval
APPROVED = "approved" # Approved by accountant
REJECTED = "rejected" # Rejected by accountant
SYNCED = "synced" # Synced to Oracle (Phase 2)
class PaymentMode(str, Enum):
"""Payment mode - how the expense was paid."""
CASA = "casa" # Numerar firma (5311)
BANCA = "banca" # Virament/POS (5121)
AVANS_DECONTARE = "avans_decontare" # Decont angajat (542)
if TYPE_CHECKING:
from .accounting_entry import AccountingEntry
class Receipt(SQLModel, table=True):
"""Receipt (Bon Fiscal / Chitanta) with approval workflow."""
__tablename__ = "receipts"
id: Optional[int] = Field(default=None, primary_key=True)
# Document identification
receipt_type: ReceiptType = Field(default=ReceiptType.BON_FISCAL)
direction: ReceiptDirection = Field(default=ReceiptDirection.CHELTUIALA)
receipt_number: Optional[str] = Field(default=None, max_length=50)
receipt_series: Optional[str] = Field(default=None, max_length=20)
# Main data
receipt_date: date
amount: Decimal = Field(decimal_places=2, max_digits=15)
description: Optional[str] = Field(default=None, max_length=500)
# TVA info (extracted from OCR) - stored as JSON for multiple entries
tva_breakdown: Optional[str] = Field(default=None, max_length=1000) # JSON: [{"code":"A","percent":19,"amount":"15.20"}]
tva_total: Optional[Decimal] = Field(default=None, decimal_places=2, max_digits=15)
items_count: Optional[int] = Field(default=None)
vendor_address: Optional[str] = Field(default=None, max_length=500)
# Expense type (for auto-generating accounting entries)
expense_type_code: Optional[str] = Field(default=None, max_length=20)
# Oracle references (nomenclatures)
company_id: int
# partner_id removed - supplier data is text-only (partner_name, cui)
partner_name: Optional[str] = Field(default=None, max_length=200) # Supplier name from OCR/selection
cui: Optional[str] = Field(default=None, max_length=20) # Fiscal code from OCR
ocr_raw_text: Optional[str] = Field(default=None) # Raw OCR text for debugging
payment_methods: Optional[str] = Field(default=None, max_length=500) # JSON: [{"method":"CARD","amount":"50.00"}]
cash_register_id: Optional[int] = Field(default=None) # Cash/Bank ID from Oracle
cash_register_name: Optional[str] = Field(default=None, max_length=100) # Cache for display
cash_register_account: Optional[str] = Field(default=None, max_length=20) # Account code (5311, 5121)
payment_mode: Optional[str] = Field(default=None, max_length=20) # PaymentMode value: casa/banca/avans_decontare
# Workflow
status: ReceiptStatus = Field(default=ReceiptStatus.DRAFT)
created_by: str = Field(max_length=100) # Username of creator
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
submitted_at: Optional[datetime] = Field(default=None) # When submitted for approval
# Approval
reviewed_by: Optional[str] = Field(default=None, max_length=100) # Accountant username
reviewed_at: Optional[datetime] = Field(default=None)
rejection_reason: Optional[str] = Field(default=None, max_length=500) # Reason for rejection
# Phase 2 - Oracle sync
oracle_synced_at: Optional[datetime] = Field(default=None)
oracle_act_id: Optional[int] = Field(default=None)
oracle_error: Optional[str] = Field(default=None, max_length=500)
# Relationships
attachments: List["ReceiptAttachment"] = Relationship(
back_populates="receipt",
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
)
entries: List["AccountingEntry"] = Relationship(
back_populates="receipt",
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
)
class ReceiptAttachment(SQLModel, table=True):
"""Attachment (photo or PDF) for a receipt."""
__tablename__ = "receipt_attachments"
id: Optional[int] = Field(default=None, primary_key=True)
receipt_id: int = Field(foreign_key="receipts.id", index=True)
# File info
filename: str = Field(max_length=255) # Original filename
stored_filename: str = Field(max_length=255) # Filename on disk (UUID)
file_path: str = Field(max_length=500) # Relative path
file_size: int # Size in bytes
mime_type: str = Field(max_length=100) # MIME type (image/jpeg, application/pdf)
uploaded_at: datetime = Field(default_factory=datetime.utcnow)
# Relationship
receipt: Optional[Receipt] = Relationship(back_populates="attachments")

View File

@@ -0,0 +1,89 @@
"""Alembic environment configuration."""
import os
from logging.config import fileConfig
from dotenv import load_dotenv
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from sqlmodel import SQLModel
# Load environment variables from .env file
load_dotenv()
# Import all models to ensure they're registered with SQLModel
from backend.modules.data_entry.db.models.receipt import Receipt, ReceiptAttachment
from backend.modules.data_entry.db.models.accounting_entry import AccountingEntry
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier, SyncedCashRegister
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Override sqlalchemy.url from environment variable if set
db_path = os.getenv("SQLITE_DATABASE_PATH", "data/receipts/receipts.db")
config.set_main_option("sqlalchemy.url", f"sqlite:///{db_path}")
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
target_metadata = SQLModel.metadata
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
render_as_batch=True, # Required for SQLite ALTER TABLE support
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection,
target_metadata=target_metadata,
render_as_batch=True, # Required for SQLite ALTER TABLE support
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@@ -0,0 +1,27 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,112 @@
"""Initial receipts schema
Revision ID: 001_initial
Revises:
Create Date: 2024-12-11
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = '001_initial'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Create receipts table
op.create_table(
'receipts',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('receipt_type', sa.Enum('BON_FISCAL', 'CHITANTA', name='receipttype'), nullable=False),
sa.Column('direction', sa.Enum('CHELTUIALA', 'INCASARE', name='receiptdirection'), nullable=False),
sa.Column('receipt_number', sa.String(length=50), nullable=True),
sa.Column('receipt_series', sa.String(length=20), nullable=True),
sa.Column('receipt_date', sa.Date(), nullable=False),
sa.Column('amount', sa.Numeric(precision=15, scale=2), nullable=False),
sa.Column('description', sa.String(length=500), nullable=True),
sa.Column('expense_type_code', sa.String(length=20), nullable=True),
sa.Column('company_id', sa.Integer(), nullable=False),
sa.Column('partner_id', sa.Integer(), nullable=True),
sa.Column('partner_name', sa.String(length=200), nullable=True),
sa.Column('cash_register_id', sa.Integer(), nullable=True),
sa.Column('cash_register_name', sa.String(length=100), nullable=True),
sa.Column('cash_register_account', sa.String(length=20), nullable=True),
sa.Column('status', sa.Enum('DRAFT', 'PENDING_REVIEW', 'APPROVED', 'REJECTED', 'SYNCED', name='receiptstatus'), nullable=False),
sa.Column('created_by', sa.String(length=100), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('submitted_at', sa.DateTime(), nullable=True),
sa.Column('reviewed_by', sa.String(length=100), nullable=True),
sa.Column('reviewed_at', sa.DateTime(), nullable=True),
sa.Column('rejection_reason', sa.String(length=500), nullable=True),
sa.Column('oracle_synced_at', sa.DateTime(), nullable=True),
sa.Column('oracle_act_id', sa.Integer(), nullable=True),
sa.Column('oracle_error', sa.String(length=500), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_receipts_company_id'), 'receipts', ['company_id'], unique=False)
op.create_index(op.f('ix_receipts_status'), 'receipts', ['status'], unique=False)
op.create_index(op.f('ix_receipts_created_by'), 'receipts', ['created_by'], unique=False)
op.create_index(op.f('ix_receipts_receipt_date'), 'receipts', ['receipt_date'], unique=False)
# Create receipt_attachments table
op.create_table(
'receipt_attachments',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('receipt_id', sa.Integer(), nullable=False),
sa.Column('filename', sa.String(length=255), nullable=False),
sa.Column('stored_filename', sa.String(length=255), nullable=False),
sa.Column('file_path', sa.String(length=500), nullable=False),
sa.Column('file_size', sa.Integer(), nullable=False),
sa.Column('mime_type', sa.String(length=100), nullable=False),
sa.Column('uploaded_at', sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(['receipt_id'], ['receipts.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_receipt_attachments_receipt_id'), 'receipt_attachments', ['receipt_id'], unique=False)
# Create accounting_entries table
op.create_table(
'accounting_entries',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('receipt_id', sa.Integer(), nullable=False),
sa.Column('entry_type', sa.Enum('DEBIT', 'CREDIT', name='entrytype'), nullable=False),
sa.Column('account_code', sa.String(length=20), nullable=False),
sa.Column('account_name', sa.String(length=200), nullable=True),
sa.Column('amount', sa.Numeric(precision=15, scale=2), nullable=False),
sa.Column('partner_id', sa.Integer(), nullable=True),
sa.Column('cost_center_id', sa.Integer(), nullable=True),
sa.Column('is_auto_generated', sa.Boolean(), nullable=False),
sa.Column('modified_by', sa.String(length=100), nullable=True),
sa.Column('modified_at', sa.DateTime(), nullable=True),
sa.Column('sort_order', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['receipt_id'], ['receipts.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_accounting_entries_receipt_id'), 'accounting_entries', ['receipt_id'], unique=False)
def downgrade() -> None:
op.drop_index(op.f('ix_accounting_entries_receipt_id'), table_name='accounting_entries')
op.drop_table('accounting_entries')
op.drop_index(op.f('ix_receipt_attachments_receipt_id'), table_name='receipt_attachments')
op.drop_table('receipt_attachments')
op.drop_index(op.f('ix_receipts_receipt_date'), table_name='receipts')
op.drop_index(op.f('ix_receipts_created_by'), table_name='receipts')
op.drop_index(op.f('ix_receipts_status'), table_name='receipts')
op.drop_index(op.f('ix_receipts_company_id'), table_name='receipts')
op.drop_table('receipts')
# Drop enums (SQLite doesn't actually use these, but for consistency)
op.execute("DROP TYPE IF EXISTS receipttype")
op.execute("DROP TYPE IF EXISTS receiptdirection")
op.execute("DROP TYPE IF EXISTS receiptstatus")
op.execute("DROP TYPE IF EXISTS entrytype")

View File

@@ -0,0 +1,37 @@
"""add_tva_breakdown_to_receipt
Revision ID: 1cfb423c6953
Revises: 001_initial
Create Date: 2025-12-12 14:04:22.464289+00:00
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = '1cfb423c6953'
down_revision: Union[str, None] = '001_initial'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add TVA-related columns to receipts table
with op.batch_alter_table('receipts', schema=None) as batch_op:
batch_op.add_column(sa.Column('tva_breakdown', sqlmodel.sql.sqltypes.AutoString(length=1000), nullable=True))
batch_op.add_column(sa.Column('tva_total', sa.Numeric(precision=15, scale=2), nullable=True))
batch_op.add_column(sa.Column('items_count', sa.Integer(), nullable=True))
batch_op.add_column(sa.Column('vendor_address', sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True))
def downgrade() -> None:
# Remove TVA-related columns from receipts table
with op.batch_alter_table('receipts', schema=None) as batch_op:
batch_op.drop_column('vendor_address')
batch_op.drop_column('items_count')
batch_op.drop_column('tva_total')
batch_op.drop_column('tva_breakdown')

View File

@@ -0,0 +1,89 @@
"""add nomenclature tables
Revision ID: 3a653da79002
Revises: 1cfb423c6953
Create Date: 2025-12-13 00:28:05.719430+00:00
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = '3a653da79002'
down_revision: Union[str, None] = '1cfb423c6953'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('local_suppliers',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('company_id', sa.Integer(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
sa.Column('fiscal_code', sqlmodel.sql.sqltypes.AutoString(length=50), nullable=True),
sa.Column('address', sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True),
sa.Column('created_by', sqlmodel.sql.sqltypes.AutoString(length=100), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('pending_oracle_sync', sa.Boolean(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
with op.batch_alter_table('local_suppliers', schema=None) as batch_op:
batch_op.create_index(batch_op.f('ix_local_suppliers_company_id'), ['company_id'], unique=False)
batch_op.create_index(batch_op.f('ix_local_suppliers_fiscal_code'), ['fiscal_code'], unique=False)
op.create_table('synced_cash_registers',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('oracle_id', sa.Integer(), nullable=False),
sa.Column('company_id', sa.Integer(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=100), nullable=False),
sa.Column('account_code', sqlmodel.sql.sqltypes.AutoString(length=20), nullable=False),
sa.Column('register_type', sqlmodel.sql.sqltypes.AutoString(length=10), nullable=False),
sa.Column('synced_at', sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
with op.batch_alter_table('synced_cash_registers', schema=None) as batch_op:
batch_op.create_index(batch_op.f('ix_synced_cash_registers_company_id'), ['company_id'], unique=False)
batch_op.create_index(batch_op.f('ix_synced_cash_registers_oracle_id'), ['oracle_id'], unique=False)
op.create_table('synced_suppliers',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('oracle_id', sa.Integer(), nullable=False),
sa.Column('company_id', sa.Integer(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
sa.Column('fiscal_code', sqlmodel.sql.sqltypes.AutoString(length=50), nullable=True),
sa.Column('address', sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True),
sa.Column('synced_at', sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
with op.batch_alter_table('synced_suppliers', schema=None) as batch_op:
batch_op.create_index(batch_op.f('ix_synced_suppliers_company_id'), ['company_id'], unique=False)
batch_op.create_index(batch_op.f('ix_synced_suppliers_fiscal_code'), ['fiscal_code'], unique=False)
batch_op.create_index(batch_op.f('ix_synced_suppliers_oracle_id'), ['oracle_id'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('synced_suppliers', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_synced_suppliers_oracle_id'))
batch_op.drop_index(batch_op.f('ix_synced_suppliers_fiscal_code'))
batch_op.drop_index(batch_op.f('ix_synced_suppliers_company_id'))
op.drop_table('synced_suppliers')
with op.batch_alter_table('synced_cash_registers', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_synced_cash_registers_oracle_id'))
batch_op.drop_index(batch_op.f('ix_synced_cash_registers_company_id'))
op.drop_table('synced_cash_registers')
with op.batch_alter_table('local_suppliers', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_local_suppliers_fiscal_code'))
batch_op.drop_index(batch_op.f('ix_local_suppliers_company_id'))
op.drop_table('local_suppliers')
# ### end Alembic commands ###

View File

@@ -0,0 +1,35 @@
"""add_ocr_fields_to_receipt
Revision ID: 4b8e5f2a1d93
Revises: 3a653da79002
Create Date: 2025-12-15 10:00:00.000000+00:00
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = '4b8e5f2a1d93'
down_revision: Union[str, None] = '3a653da79002'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add OCR-related columns to receipts table
with op.batch_alter_table('receipts', schema=None) as batch_op:
batch_op.add_column(sa.Column('cui', sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True))
batch_op.add_column(sa.Column('ocr_raw_text', sa.Text(), nullable=True))
batch_op.add_column(sa.Column('payment_methods', sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True))
def downgrade() -> None:
# Remove OCR-related columns from receipts table
with op.batch_alter_table('receipts', schema=None) as batch_op:
batch_op.drop_column('payment_methods')
batch_op.drop_column('ocr_raw_text')
batch_op.drop_column('cui')

View File

@@ -0,0 +1,29 @@
"""Remove partner_id from receipts - supplier data is text-only
Revision ID: 20251215_remove_partner_id
Revises: 20251216_payment_mode
Create Date: 2025-12-15
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '20251215_remove_partner_id'
down_revision: Union[str, None] = '20251216_payment_mode'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Remove partner_id column - supplier data is now text-only (partner_name, cui)."""
# Drop the partner_id column
op.drop_column('receipts', 'partner_id')
def downgrade() -> None:
"""Re-add partner_id column."""
op.add_column('receipts', sa.Column('partner_id', sa.Integer(), nullable=True))

View File

@@ -0,0 +1,44 @@
"""Add payment_mode field to receipts table.
Revision ID: 20251216_payment_mode
Revises: 4b8e5f2a1d93
Create Date: 2024-12-16
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '20251216_payment_mode'
down_revision = '4b8e5f2a1d93'
branch_labels = None
depends_on = None
def upgrade() -> None:
"""Add payment_mode column and migrate existing data."""
with op.batch_alter_table('receipts', schema=None) as batch_op:
batch_op.add_column(sa.Column('payment_mode', sa.String(length=20), nullable=True))
# Migrate existing data based on cash_register_account
op.execute("""
UPDATE receipts
SET payment_mode = 'casa'
WHERE cash_register_account LIKE '531%' AND payment_mode IS NULL
""")
op.execute("""
UPDATE receipts
SET payment_mode = 'banca'
WHERE cash_register_account LIKE '512%' AND payment_mode IS NULL
""")
op.execute("""
UPDATE receipts
SET payment_mode = 'avans_decontare'
WHERE cash_register_account LIKE '542%' AND payment_mode IS NULL
""")
def downgrade() -> None:
"""Remove payment_mode column."""
with op.batch_alter_table('receipts', schema=None) as batch_op:
batch_op.drop_column('payment_mode')

View File

@@ -0,0 +1,30 @@
"""Data Entry module router factory."""
from fastapi import APIRouter
def create_data_entry_router() -> APIRouter:
"""
Create and configure Data Entry module router.
Includes all data entry endpoints:
- /receipts - Receipt CRUD and workflow
- /ocr - OCR processing for receipts
- /nomenclature - Nomenclature syncing from Oracle
Returns:
APIRouter: Configured router for data entry module
"""
router = APIRouter()
# Import routers here to avoid circular imports
from .receipts import router as receipts_router
from .ocr import router as ocr_router
from .nomenclature import router as nomenclature_router
# Include all sub-routers (no prefix - already prefixed in main.py with /api/data-entry)
router.include_router(receipts_router, prefix="/receipts", tags=["data-entry-receipts"])
router.include_router(ocr_router, prefix="/ocr", tags=["data-entry-ocr"])
router.include_router(nomenclature_router, prefix="/nomenclature", tags=["data-entry-nomenclature"])
return router

View File

@@ -0,0 +1,254 @@
"""Nomenclature API endpoints."""
from typing import Optional, List, Annotated
from fastapi import APIRouter, Depends, HTTPException, Header
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel
from backend.modules.data_entry.db.database import get_session
from backend.modules.data_entry.services.sync_service import SyncService
# Import auth dependencies
import sys
from pathlib import Path
# Path setup handled by main.py - this is redundant
# project_root = Path(__file__).parent.parent.parent.parent.parent
# sys.path.insert(0, str(project_root / "shared"))
from shared.auth.dependencies import get_current_user
from shared.auth.models import CurrentUser
router = APIRouter()
# ============ Selected Company Dependency ============
async def get_selected_company(
current_user: CurrentUser = Depends(get_current_user),
x_selected_company: Annotated[Optional[str], Header()] = None
) -> int:
"""
Get selected company from X-Selected-Company header.
Validates user access. Falls back to first company if no header.
"""
if x_selected_company:
try:
company_id = int(x_selected_company)
except ValueError:
raise HTTPException(400, f"Invalid company ID: {x_selected_company}")
if str(company_id) in current_user.companies:
return company_id
raise HTTPException(403, f"Nu aveți acces la firma {company_id}")
if current_user.companies:
try:
return int(current_user.companies[0])
except (ValueError, IndexError):
pass
raise HTTPException(400, "Nu aveți nicio firmă asignată")
SelectedCompany = Annotated[int, Depends(get_selected_company)]
# Request/Response Models
class SupplierSearchResult(BaseModel):
found: bool
supplier: Optional[dict] = None
source: str # 'synced', 'local', 'not_found'
class LocalSupplierCreate(BaseModel):
name: str
fiscal_code: Optional[str] = None
address: Optional[str] = None
class LocalSupplierResponse(BaseModel):
id: int
name: str
fiscal_code: Optional[str]
address: Optional[str]
is_local: bool = True
class SyncResult(BaseModel):
synced: int
errors: int
message: str
class SupplierOption(BaseModel):
id: int
oracle_id: Optional[int] = None
name: str
fiscal_code: Optional[str]
source: str # 'synced' or 'local'
class CashRegisterOption(BaseModel):
id: int
oracle_id: int
name: str
account_code: str
register_type: str
# Endpoints
@router.get("/suppliers/search", response_model=SupplierSearchResult)
async def search_supplier(
fiscal_code: Optional[str] = None,
name: Optional[str] = None,
company_id: Optional[int] = None,
session: AsyncSession = Depends(get_session),
selected_company: SelectedCompany = None,
):
"""Search for supplier by fiscal code or name."""
if not fiscal_code and not name:
raise HTTPException(status_code=400, detail="Provide fiscal_code or name")
cid = company_id or selected_company
found, supplier, source = await SyncService.search_supplier(
session, cid, fiscal_code, name
)
return SupplierSearchResult(found=found, supplier=supplier, source=source)
@router.get("/suppliers", response_model=List[SupplierOption])
async def get_suppliers(
search: Optional[str] = None,
company_id: Optional[int] = None,
session: AsyncSession = Depends(get_session),
selected_company: SelectedCompany = None,
):
"""Get all suppliers (synced + local) for dropdown/autocomplete."""
cid = company_id or selected_company
suppliers = await SyncService.get_all_suppliers(session, cid, search)
return [
SupplierOption(
id=s["id"],
oracle_id=s.get("oracle_id"),
name=s["name"],
fiscal_code=s.get("fiscal_code"),
source=s["source"]
)
for s in suppliers
]
@router.post("/suppliers/local", response_model=LocalSupplierResponse)
async def create_local_supplier(
data: LocalSupplierCreate,
company_id: Optional[int] = None,
session: AsyncSession = Depends(get_session),
selected_company: SelectedCompany = None,
current_user: CurrentUser = Depends(get_current_user),
):
"""Create a local supplier from OCR data."""
cid = company_id or selected_company
supplier = await SyncService.create_local_supplier(
session, cid, data.name, data.fiscal_code, data.address, current_user.username
)
return LocalSupplierResponse(
id=supplier.id,
name=supplier.name,
fiscal_code=supplier.fiscal_code,
address=supplier.address,
)
@router.get("/cash-registers", response_model=List[CashRegisterOption])
async def get_cash_registers(
company_id: Optional[int] = None,
session: AsyncSession = Depends(get_session),
selected_company: SelectedCompany = None,
):
"""Get all cash registers for a company."""
cid = company_id or selected_company
registers = await SyncService.get_all_cash_registers(session, cid)
return [
CashRegisterOption(
id=r["id"],
oracle_id=r["oracle_id"],
name=r["name"],
account_code=r["account_code"],
register_type=r["register_type"]
)
for r in registers
]
@router.post("/sync/suppliers", response_model=SyncResult)
async def sync_suppliers(
company_id: Optional[int] = None,
session: AsyncSession = Depends(get_session),
selected_company: SelectedCompany = None,
):
"""Manually trigger supplier sync from Oracle."""
cid = company_id or selected_company
synced, errors = await SyncService.sync_suppliers(session, cid)
return SyncResult(
synced=synced,
errors=errors,
message=f"Synced {synced} suppliers with {errors} errors"
)
@router.post("/sync/cash-registers", response_model=SyncResult)
async def sync_cash_registers(
company_id: Optional[int] = None,
session: AsyncSession = Depends(get_session),
selected_company: SelectedCompany = None,
):
"""Manually trigger cash register sync from Oracle."""
cid = company_id or selected_company
synced, errors = await SyncService.sync_cash_registers(session, cid)
return SyncResult(
synced=synced,
errors=errors,
message=f"Synced {synced} cash registers with {errors} errors"
)
@router.post("/sync/all", response_model=dict)
async def sync_all_nomenclatures(
company_id: Optional[int] = None,
session: AsyncSession = Depends(get_session),
selected_company: SelectedCompany = None,
):
"""Sync all nomenclatures (suppliers + cash registers) from Oracle."""
cid = company_id or selected_company
# Sync suppliers
suppliers_synced, suppliers_errors = await SyncService.sync_suppliers(session, cid)
# Sync cash registers
registers_synced, registers_errors = await SyncService.sync_cash_registers(session, cid)
return {
"suppliers": {
"synced": suppliers_synced,
"errors": suppliers_errors
},
"cash_registers": {
"synced": registers_synced,
"errors": registers_errors
},
"total_synced": suppliers_synced + registers_synced,
"total_errors": suppliers_errors + registers_errors,
"message": f"Synced {suppliers_synced} suppliers and {registers_synced} cash registers"
}

View File

@@ -0,0 +1,218 @@
"""OCR API endpoints."""
import os
import tempfile
from pathlib import Path
from fastapi import APIRouter, HTTPException, UploadFile, File, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from backend.modules.data_entry.db.database import get_session
from backend.modules.data_entry.db.crud.attachment import AttachmentCRUD
from backend.modules.data_entry.services.ocr_service import ocr_service
from backend.modules.data_entry.services.ocr_engine import OCREngine
from backend.modules.data_entry.schemas.ocr import OCRResponse, OCRStatusResponse, ExtractionData, TvaEntry, PaymentMethod
# Auth integration (will be protected by middleware)
from shared.auth.dependencies import get_current_user
from shared.auth.models import CurrentUser
router = APIRouter()
@router.get("/status", response_model=OCRStatusResponse)
async def get_ocr_status():
"""Check OCR service status and available engines."""
engines = OCREngine.get_available_engines()
available = len(engines) > 0
if available:
message = f"OCR service ready with engines: {', '.join(engines)}"
else:
message = "No OCR engines available. Install PaddleOCR or Tesseract."
return OCRStatusResponse(
available=available,
engines=engines,
message=message
)
@router.post("/extract", response_model=OCRResponse)
async def extract_from_image(file: UploadFile = File(...)):
"""
Extract receipt data from uploaded image.
Accepts JPG, PNG, or PDF files (max 10MB).
Returns extracted fields with confidence scores.
"""
allowed_types = ['image/jpeg', 'image/png', 'application/pdf']
if file.content_type not in allowed_types:
raise HTTPException(
status_code=400,
detail=f"File type not supported: {file.content_type}. Allowed: JPG, PNG, PDF"
)
# Get file extension
suffix = Path(file.filename).suffix.lower() if file.filename else '.jpg'
if suffix not in ['.jpg', '.jpeg', '.png', '.pdf']:
suffix = '.jpg'
# Save to temp file
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
content = await file.read()
# Check file size (10MB limit)
if len(content) > 10 * 1024 * 1024:
raise HTTPException(
status_code=400,
detail="File too large. Maximum size is 10MB."
)
tmp.write(content)
tmp_path = Path(tmp.name)
try:
success, message, result = await ocr_service.process_image(
tmp_path, file.content_type
)
if not success:
raise HTTPException(status_code=422, detail=message)
# Convert ExtractionResult to ExtractionData schema
# Convert tva_entries from dict to TvaEntry objects
tva_entries_schema = [
TvaEntry(code=e.get('code'), percent=e['percent'], amount=e['amount'])
for e in result.tva_entries
] if result.tva_entries else []
# Convert payment_methods from dict to PaymentMethod objects
from decimal import Decimal
payment_methods_list = [
PaymentMethod(method=pm['method'], amount=Decimal(str(pm['amount'])))
for pm in result.payment_methods
] if result.payment_methods else []
# Auto-suggest payment_mode based on detected methods
suggested_payment_mode = None
if payment_methods_list:
has_card = any(pm.method == 'CARD' for pm in payment_methods_list)
if has_card:
suggested_payment_mode = 'banca'
# NUMERAR -> no auto-suggestion, user chooses between casa/avans
data = ExtractionData(
receipt_type=result.receipt_type,
receipt_number=result.receipt_number,
receipt_series=result.receipt_series,
receipt_date=result.receipt_date,
amount=result.amount,
partner_name=result.partner_name,
cui=result.cui,
description=result.description,
tva_entries=tva_entries_schema,
tva_total=result.tva_total,
address=result.address,
items_count=result.items_count,
payment_methods=payment_methods_list,
suggested_payment_mode=suggested_payment_mode,
confidence_amount=result.confidence_amount,
confidence_date=result.confidence_date,
confidence_vendor=result.confidence_vendor,
overall_confidence=result.overall_confidence,
raw_text=result.raw_text,
ocr_engine=result.ocr_engine,
processing_time_ms=result.processing_time_ms,
)
return OCRResponse(success=True, message=message, data=data)
finally:
# Clean up temp file
if tmp_path.exists():
os.unlink(tmp_path)
@router.post("/extract-attachment/{attachment_id}", response_model=OCRResponse)
async def extract_from_attachment(
attachment_id: int,
session: AsyncSession = Depends(get_session),
):
"""
Extract receipt data from an existing attachment.
Re-processes an already uploaded file with OCR.
"""
attachment = await AttachmentCRUD.get_by_id(session, attachment_id)
if not attachment:
raise HTTPException(status_code=404, detail="Attachment not found")
file_path = AttachmentCRUD.get_file_path(attachment)
if not file_path.exists():
raise HTTPException(status_code=404, detail="File not found on disk")
# Check if file type is supported
if attachment.mime_type not in ['image/jpeg', 'image/png', 'application/pdf']:
raise HTTPException(
status_code=400,
detail=f"File type not supported for OCR: {attachment.mime_type}"
)
success, message, result = await ocr_service.process_image(
file_path, attachment.mime_type
)
if not success:
raise HTTPException(status_code=422, detail=message)
# Convert ExtractionResult to ExtractionData schema
# Convert tva_entries from dict to TvaEntry objects
tva_entries_schema = [
TvaEntry(code=e.get('code'), percent=e['percent'], amount=e['amount'])
for e in result.tva_entries
] if result.tva_entries else []
# Convert payment_methods from dict to PaymentMethod objects
from decimal import Decimal
payment_methods_list = [
PaymentMethod(method=pm['method'], amount=Decimal(str(pm['amount'])))
for pm in result.payment_methods
] if result.payment_methods else []
# Auto-suggest payment_mode based on detected methods
suggested_payment_mode = None
if payment_methods_list:
has_card = any(pm.method == 'CARD' for pm in payment_methods_list)
if has_card:
suggested_payment_mode = 'banca'
# NUMERAR -> no auto-suggestion, user chooses between casa/avans
data = ExtractionData(
receipt_type=result.receipt_type,
receipt_number=result.receipt_number,
receipt_series=result.receipt_series,
receipt_date=result.receipt_date,
amount=result.amount,
partner_name=result.partner_name,
cui=result.cui,
description=result.description,
tva_entries=tva_entries_schema,
tva_total=result.tva_total,
address=result.address,
items_count=result.items_count,
payment_methods=payment_methods_list,
suggested_payment_mode=suggested_payment_mode,
confidence_amount=result.confidence_amount,
confidence_date=result.confidence_date,
confidence_vendor=result.confidence_vendor,
overall_confidence=result.overall_confidence,
raw_text=result.raw_text,
ocr_engine=result.ocr_engine,
processing_time_ms=result.processing_time_ms,
)
return OCRResponse(success=True, message=message, data=data)

View File

@@ -0,0 +1,517 @@
"""API endpoints for receipts."""
from typing import List, Optional, Annotated
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Query, Header
from fastapi.responses import FileResponse
from sqlalchemy.ext.asyncio import AsyncSession
from backend.modules.data_entry.db.database import get_session
from backend.modules.data_entry.db.crud.receipt import ReceiptCRUD
from backend.modules.data_entry.db.crud.attachment import AttachmentCRUD
from backend.modules.data_entry.db.crud.accounting_entry import AccountingEntryCRUD
from backend.modules.data_entry.services.receipt_service import ReceiptService
from backend.modules.data_entry.services.nomenclature_service import NomenclatureService
from backend.modules.data_entry.schemas.receipt import (
ReceiptCreate,
ReceiptUpdate,
ReceiptResponse,
ReceiptListResponse,
ReceiptFilter,
AttachmentResponse,
AccountingEntryResponse,
WorkflowAction,
RejectRequest,
EntriesUpdateRequest,
PartnerOption,
AccountOption,
CashRegisterOption,
ExpenseTypeOption,
)
from backend.modules.data_entry.db.models.receipt import ReceiptStatus, ReceiptDirection
# Auth integration
from shared.auth.dependencies import get_current_user
from shared.auth.models import CurrentUser
router = APIRouter()
# ============ Helper for selected company from header ============
async def get_selected_company(
current_user: CurrentUser = Depends(get_current_user),
x_selected_company: Annotated[Optional[str], Header()] = None
) -> int:
"""
Get selected company from X-Selected-Company header.
Validates that the user has access to the specified company.
Falls back to user's first company if no header is provided.
Raises:
HTTPException 403: If user doesn't have access to specified company
HTTPException 400: If user has no companies assigned
"""
if x_selected_company:
try:
company_id = int(x_selected_company)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Invalid company ID format: {x_selected_company}"
)
# Validate user has access to this company
# Auth stores companies as strings
if str(company_id) in current_user.companies:
return company_id
raise HTTPException(
status_code=403,
detail=f"Nu aveți acces la firma {company_id}"
)
# No header - use first company from user's list
if current_user.companies:
try:
return int(current_user.companies[0])
except (ValueError, IndexError):
pass
raise HTTPException(
status_code=400,
detail="Nu aveți nicio firmă asignată"
)
# Dependency for injection
SelectedCompany = Annotated[int, Depends(get_selected_company)]
# Legacy function for backwards compatibility (deprecated)
def get_current_user_company(current_user: CurrentUser) -> int:
"""
DEPRECATED: Use get_selected_company() dependency instead.
This function returns the first company, ignoring X-Selected-Company header.
"""
if current_user.companies:
try:
return int(current_user.companies[0])
except (ValueError, IndexError):
return 1
return 1
# ============ Receipt CRUD Endpoints ============
@router.post("/", response_model=ReceiptResponse)
async def create_receipt(
data: ReceiptCreate,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
"""Create a new receipt in DRAFT status."""
receipt = await ReceiptService.create_receipt(session, data, current_user.username)
return ReceiptResponse.model_validate(receipt)
@router.get("/", response_model=ReceiptListResponse)
async def list_receipts(
status: Optional[ReceiptStatus] = None,
direction: Optional[ReceiptDirection] = None,
company_id: Optional[int] = None,
created_by: Optional[str] = None,
date_from: Optional[str] = None,
date_to: Optional[str] = None,
search: Optional[str] = None,
page: int = Query(default=1, ge=1),
page_size: int = Query(default=20, ge=1, le=100),
session: AsyncSession = Depends(get_session),
selected_company: SelectedCompany = None,
):
"""Get paginated list of receipts with filters."""
from datetime import date as date_type
filters = ReceiptFilter(
status=status,
direction=direction,
company_id=company_id or selected_company,
created_by=created_by,
date_from=date_type.fromisoformat(date_from) if date_from else None,
date_to=date_type.fromisoformat(date_to) if date_to else None,
search=search,
page=page,
page_size=page_size,
)
return await ReceiptService.get_receipts(session, filters)
@router.get("/pending", response_model=List[ReceiptResponse])
async def list_pending_receipts(
company_id: Optional[int] = None,
session: AsyncSession = Depends(get_session),
selected_company: SelectedCompany = None,
):
"""Get all receipts pending review (for accountant view)."""
receipts = await ReceiptCRUD.get_pending_review(
session, company_id or selected_company
)
return [ReceiptResponse.model_validate(r) for r in receipts]
@router.get("/stats")
async def get_receipt_stats(
company_id: Optional[int] = None,
my_receipts: bool = False,
session: AsyncSession = Depends(get_session),
selected_company: SelectedCompany = None,
current_user: CurrentUser = Depends(get_current_user),
):
"""Get receipt statistics."""
return await ReceiptCRUD.get_stats(
session,
company_id or selected_company,
created_by=current_user.username if my_receipts else None,
)
@router.get("/{receipt_id}", response_model=ReceiptResponse)
async def get_receipt(
receipt_id: int,
session: AsyncSession = Depends(get_session),
):
"""Get receipt details with attachments and accounting entries."""
receipt = await ReceiptService.get_receipt(session, receipt_id)
if not receipt:
raise HTTPException(status_code=404, detail="Receipt not found")
return ReceiptResponse.model_validate(receipt)
@router.put("/{receipt_id}", response_model=ReceiptResponse)
async def update_receipt(
receipt_id: int,
data: ReceiptUpdate,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
"""Update receipt (only DRAFT status, only by creator)."""
success, message, receipt = await ReceiptService.update_receipt(
session, receipt_id, data, current_user.username
)
if not success:
raise HTTPException(status_code=400, detail=message)
return ReceiptResponse.model_validate(receipt)
@router.delete("/{receipt_id}")
async def delete_receipt(
receipt_id: int,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
"""Delete receipt (only DRAFT status, only by creator)."""
success, message = await ReceiptService.delete_receipt(
session, receipt_id, current_user.username
)
if not success:
raise HTTPException(status_code=400, detail=message)
return {"success": True, "message": message}
# ============ Workflow Endpoints ============
@router.post("/{receipt_id}/submit", response_model=WorkflowAction)
async def submit_receipt(
receipt_id: int,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
"""Submit receipt for review (DRAFT → PENDING_REVIEW)."""
success, message, receipt = await ReceiptService.submit_for_review(
session, receipt_id, current_user.username
)
return WorkflowAction(
success=success,
message=message,
receipt=ReceiptResponse.model_validate(receipt) if receipt else None,
)
@router.post("/{receipt_id}/approve", response_model=WorkflowAction)
async def approve_receipt(
receipt_id: int,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
"""Approve receipt (PENDING_REVIEW → APPROVED). Accountant action."""
success, message, receipt = await ReceiptService.approve_receipt(
session, receipt_id, current_user.username
)
return WorkflowAction(
success=success,
message=message,
receipt=ReceiptResponse.model_validate(receipt) if receipt else None,
)
@router.post("/{receipt_id}/reject", response_model=WorkflowAction)
async def reject_receipt(
receipt_id: int,
data: RejectRequest,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
"""Reject receipt (PENDING_REVIEW → REJECTED). Accountant action."""
success, message, receipt = await ReceiptService.reject_receipt(
session, receipt_id, current_user.username, data.reason
)
return WorkflowAction(
success=success,
message=message,
receipt=ReceiptResponse.model_validate(receipt) if receipt else None,
)
@router.post("/{receipt_id}/resubmit", response_model=WorkflowAction)
async def resubmit_receipt(
receipt_id: int,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
"""Resubmit rejected receipt after corrections (REJECTED → PENDING_REVIEW)."""
success, message, receipt = await ReceiptService.resubmit_receipt(
session, receipt_id, current_user.username
)
return WorkflowAction(
success=success,
message=message,
receipt=ReceiptResponse.model_validate(receipt) if receipt else None,
)
@router.post("/{receipt_id}/unapprove", response_model=WorkflowAction)
async def unapprove_receipt(
receipt_id: int,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
"""Unapprove receipt (APPROVED → PENDING_REVIEW). Returns to pending for corrections."""
success, message, receipt = await ReceiptService.unapprove_receipt(
session, receipt_id, current_user.username
)
return WorkflowAction(
success=success,
message=message,
receipt=ReceiptResponse.model_validate(receipt) if receipt else None,
)
# ============ Accounting Entries Endpoints ============
@router.get("/{receipt_id}/entries", response_model=List[AccountingEntryResponse])
async def get_receipt_entries(
receipt_id: int,
session: AsyncSession = Depends(get_session),
):
"""Get accounting entries for a receipt."""
entries = await AccountingEntryCRUD.get_by_receipt_id(session, receipt_id)
return [AccountingEntryResponse.model_validate(e) for e in entries]
@router.put("/{receipt_id}/entries", response_model=List[AccountingEntryResponse])
async def update_receipt_entries(
receipt_id: int,
data: EntriesUpdateRequest,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
"""Update accounting entries for a receipt (accountant action)."""
success, message, entries = await ReceiptService.update_entries(
session, receipt_id, data.entries, current_user.username
)
if not success:
raise HTTPException(status_code=400, detail=message)
return [AccountingEntryResponse.model_validate(e) for e in entries]
@router.post("/{receipt_id}/entries/regenerate", response_model=List[AccountingEntryResponse])
async def regenerate_entries(
receipt_id: int,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
"""Regenerate accounting entries based on receipt data."""
success, message, _ = await ReceiptService.regenerate_entries(
session, receipt_id, current_user.username
)
if not success:
raise HTTPException(status_code=400, detail=message)
entries = await AccountingEntryCRUD.get_by_receipt_id(session, receipt_id)
return [AccountingEntryResponse.model_validate(e) for e in entries]
# ============ Attachment Endpoints ============
@router.post("/{receipt_id}/attachments", response_model=AttachmentResponse)
async def upload_attachment(
receipt_id: int,
file: UploadFile = File(...),
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
"""Upload attachment for a receipt."""
# Check receipt exists and user can modify it
receipt = await ReceiptCRUD.get_by_id(session, receipt_id, include_relations=False)
if not receipt:
raise HTTPException(status_code=404, detail="Receipt not found")
# Only allow uploads for DRAFT and REJECTED receipts
if receipt.status not in [ReceiptStatus.DRAFT, ReceiptStatus.REJECTED]:
raise HTTPException(
status_code=400,
detail="Cannot upload attachments for this receipt status"
)
# Only creator can upload
if receipt.created_by != current_user.username:
raise HTTPException(
status_code=403,
detail="Only the creator can upload attachments"
)
try:
attachment = await AttachmentCRUD.create(session, receipt_id, file)
return AttachmentResponse.model_validate(attachment)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/{receipt_id}/attachments", response_model=List[AttachmentResponse])
async def list_attachments(
receipt_id: int,
session: AsyncSession = Depends(get_session),
):
"""Get all attachments for a receipt."""
attachments = await AttachmentCRUD.get_by_receipt_id(session, receipt_id)
return [AttachmentResponse.model_validate(a) for a in attachments]
@router.get("/attachments/{attachment_id}/download")
async def download_attachment(
attachment_id: int,
session: AsyncSession = Depends(get_session),
):
"""Download an attachment file."""
attachment = await AttachmentCRUD.get_by_id(session, attachment_id)
if not attachment:
raise HTTPException(status_code=404, detail="Attachment not found")
file_path = AttachmentCRUD.get_file_path(attachment)
if not file_path.exists():
raise HTTPException(status_code=404, detail="File not found on disk")
return FileResponse(
path=str(file_path),
filename=attachment.filename,
media_type=attachment.mime_type,
)
@router.delete("/attachments/{attachment_id}")
async def delete_attachment(
attachment_id: int,
session: AsyncSession = Depends(get_session),
current_user: CurrentUser = Depends(get_current_user),
):
"""Delete an attachment."""
attachment = await AttachmentCRUD.get_by_id(session, attachment_id)
if not attachment:
raise HTTPException(status_code=404, detail="Attachment not found")
# Get receipt to check permissions
receipt = await ReceiptCRUD.get_by_id(session, attachment.receipt_id, include_relations=False)
if not receipt:
raise HTTPException(status_code=404, detail="Receipt not found")
# Only allow deletion for DRAFT receipts by creator
if receipt.status != ReceiptStatus.DRAFT:
raise HTTPException(
status_code=400,
detail="Cannot delete attachments for this receipt status"
)
if receipt.created_by != current_user.username:
raise HTTPException(
status_code=403,
detail="Only the creator can delete attachments"
)
await AttachmentCRUD.delete(session, attachment)
return {"success": True, "message": "Attachment deleted"}
# ============ Nomenclature Endpoints ============
@router.get("/nomenclature/partners", response_model=List[PartnerOption])
async def get_partners(
search: Optional[str] = None,
company_id: Optional[int] = None,
session: AsyncSession = Depends(get_session),
selected_company: SelectedCompany = None,
):
"""Get partners (suppliers/customers) for dropdown."""
return await NomenclatureService.get_partners(
company_id or selected_company, search, session
)
@router.get("/nomenclature/accounts", response_model=List[AccountOption])
async def get_accounts(
prefix: Optional[str] = None,
company_id: Optional[int] = None,
selected_company: SelectedCompany = None,
):
"""Get chart of accounts for dropdown."""
return await NomenclatureService.get_accounts(
company_id or selected_company, prefix
)
@router.get("/nomenclature/cash-registers", response_model=List[CashRegisterOption])
async def get_cash_registers(
company_id: Optional[int] = None,
session: AsyncSession = Depends(get_session),
selected_company: SelectedCompany = None,
):
"""Get cash registers and bank accounts for dropdown."""
return await NomenclatureService.get_cash_registers(company_id or selected_company, session)
@router.get("/nomenclature/expense-types", response_model=List[ExpenseTypeOption])
async def get_expense_types():
"""Get predefined expense types for dropdown."""
return await NomenclatureService.get_expense_types()

View File

@@ -0,0 +1,28 @@
# Pydantic schemas
from .receipt import (
ReceiptCreate,
ReceiptUpdate,
ReceiptResponse,
ReceiptListResponse,
ReceiptFilter,
AttachmentResponse,
AccountingEntryCreate,
AccountingEntryUpdate,
AccountingEntryResponse,
WorkflowAction,
RejectRequest,
)
__all__ = [
"ReceiptCreate",
"ReceiptUpdate",
"ReceiptResponse",
"ReceiptListResponse",
"ReceiptFilter",
"AttachmentResponse",
"AccountingEntryCreate",
"AccountingEntryUpdate",
"AccountingEntryResponse",
"WorkflowAction",
"RejectRequest",
]

View File

@@ -0,0 +1,122 @@
"""Pydantic schemas for OCR API."""
from datetime import date
from decimal import Decimal
from typing import Optional, List
from pydantic import BaseModel, Field
class TvaEntry(BaseModel):
"""Single TVA entry with code, percentage and amount."""
code: Optional[str] = Field(default=None, description="TVA code: A, B, C, D")
percent: int = Field(description="TVA percentage: 0, 5, 9, 19, 21")
amount: Decimal = Field(description="TVA amount for this rate")
class PaymentMethod(BaseModel):
"""Payment method entry from OCR."""
method: str = Field(description="CARD or NUMERAR")
amount: Decimal = Field(description="Amount paid")
class ExtractionData(BaseModel):
"""Extracted receipt data from OCR."""
receipt_type: str = Field(default='bon_fiscal', description="Receipt type: bon_fiscal or chitanta")
receipt_number: Optional[str] = Field(default=None, description="Receipt number")
receipt_series: Optional[str] = Field(default=None, description="Receipt series")
receipt_date: Optional[date] = Field(default=None, description="Receipt date")
amount: Optional[Decimal] = Field(default=None, description="Total amount")
partner_name: Optional[str] = Field(default=None, description="Vendor/partner name")
cui: Optional[str] = Field(default=None, description="CUI (fiscal identification code)")
description: Optional[str] = Field(default=None, description="Optional description")
# Additional extracted fields - Multiple TVA entries support
tva_entries: List[TvaEntry] = Field(default=[], description="List of TVA entries by rate (A, B, C, D)")
tva_total: Optional[Decimal] = Field(default=None, description="Total TVA amount")
address: Optional[str] = Field(default=None, description="Vendor address")
items_count: Optional[int] = Field(default=None, description="Number of items/articles")
# Payment methods extracted from receipt
payment_methods: List[PaymentMethod] = Field(default=[], description="Payment methods from receipt (CARD, NUMERAR)")
suggested_payment_mode: Optional[str] = Field(default=None, description="Auto-suggested payment mode based on OCR (casa/banca)")
# Client data (for B2B receipts - buyer information)
client_name: Optional[str] = Field(default=None, description="Client/customer company name")
client_cui: Optional[str] = Field(default=None, description="Client CUI/CIF fiscal code")
client_address: Optional[str] = Field(default=None, description="Client address")
confidence_amount: float = Field(default=0.0, ge=0, le=1, description="Amount extraction confidence")
confidence_date: float = Field(default=0.0, ge=0, le=1, description="Date extraction confidence")
confidence_vendor: float = Field(default=0.0, ge=0, le=1, description="Vendor extraction confidence")
confidence_client: float = Field(default=0.0, ge=0, le=1, description="Client extraction confidence")
overall_confidence: float = Field(default=0.0, ge=0, le=1, description="Overall confidence score")
raw_text: str = Field(default="", description="Raw OCR text")
ocr_engine: str = Field(default="", description="OCR engine used: paddleocr or tesseract")
processing_time_ms: int = Field(default=0, ge=0, description="Processing time in milliseconds")
class Config:
"""Pydantic config."""
json_schema_extra = {
"example": {
"receipt_type": "bon_fiscal",
"receipt_number": "1360760",
"receipt_series": "0146",
"receipt_date": "2025-10-11",
"amount": 186.16,
"partner_name": "FIVE-HOLDING S.A.",
"cui": "10562600",
"description": None,
"tva_entries": [
{"code": "A", "percent": 19, "amount": 25.00},
{"code": "B", "percent": 9, "amount": 7.31}
],
"tva_total": 32.31,
"address": "JUD. CONSTANTA, MUN. CONSTANTA, STR. ION ROATA NR. 3",
"items_count": 17,
"confidence_amount": 0.98,
"confidence_date": 0.98,
"confidence_vendor": 0.95,
"overall_confidence": 0.97,
"raw_text": "FIVE-HOLDING S.A.\nCIF: RO10562600\n..."
}
}
class OCRResponse(BaseModel):
"""OCR API response."""
success: bool = Field(description="Whether OCR processing was successful")
message: str = Field(description="Status message")
data: Optional[ExtractionData] = Field(default=None, description="Extracted data")
class Config:
"""Pydantic config."""
json_schema_extra = {
"example": {
"success": True,
"message": "OCR processing successful. Found: amount, date, vendor",
"data": {
"receipt_type": "bon_fiscal",
"receipt_number": "12345",
"receipt_date": "2024-01-15",
"amount": 125.50,
"partner_name": "MEGA IMAGE SRL",
"cui": "12345678",
"confidence_amount": 0.95,
"confidence_date": 0.90,
"confidence_vendor": 0.75,
"overall_confidence": 0.87,
"raw_text": "BON FISCAL\nMEGA IMAGE SRL\n..."
}
}
}
class OCRStatusResponse(BaseModel):
"""OCR service status response."""
available: bool = Field(description="Whether OCR service is available")
engines: list[str] = Field(description="Available OCR engines")
message: str = Field(description="Status message")

View File

@@ -0,0 +1,269 @@
"""Pydantic schemas for receipts API."""
import json
from datetime import datetime, date
from decimal import Decimal
from typing import Optional, List, Any, Union
from pydantic import BaseModel, Field, ConfigDict, field_validator
from backend.modules.data_entry.db.models.receipt import ReceiptType, ReceiptDirection, ReceiptStatus
from backend.modules.data_entry.db.models.accounting_entry import EntryType
# ============ Accounting Entry Schemas ============
class AccountingEntryBase(BaseModel):
"""Base schema for accounting entry."""
entry_type: EntryType
account_code: str = Field(max_length=20)
account_name: Optional[str] = Field(default=None, max_length=200)
amount: Decimal
partner_id: Optional[int] = None
cost_center_id: Optional[int] = None
class AccountingEntryCreate(AccountingEntryBase):
"""Schema for creating an accounting entry."""
pass
class AccountingEntryUpdate(BaseModel):
"""Schema for updating an accounting entry."""
entry_type: Optional[EntryType] = None
account_code: Optional[str] = Field(default=None, max_length=20)
account_name: Optional[str] = Field(default=None, max_length=200)
amount: Optional[Decimal] = None
partner_id: Optional[int] = None
cost_center_id: Optional[int] = None
class AccountingEntryResponse(AccountingEntryBase):
"""Schema for accounting entry response."""
model_config = ConfigDict(from_attributes=True)
id: int
receipt_id: int
is_auto_generated: bool
modified_by: Optional[str] = None
modified_at: Optional[datetime] = None
sort_order: int
# ============ Attachment Schemas ============
class AttachmentResponse(BaseModel):
"""Schema for attachment response."""
model_config = ConfigDict(from_attributes=True)
id: int
receipt_id: int
filename: str
stored_filename: str
file_path: str
file_size: int
mime_type: str
uploaded_at: datetime
# ============ TVA Schema ============
class TvaEntrySchema(BaseModel):
"""Single TVA entry with code, percentage and amount."""
code: Optional[str] = Field(default=None, description="TVA code: A, B, C, D")
percent: int = Field(description="TVA percentage: 0, 5, 9, 19, 21")
amount: Decimal = Field(description="TVA amount for this rate")
class PaymentMethodSchema(BaseModel):
"""Payment method entry (CARD/NUMERAR)."""
method: str = Field(description="Payment method: CARD or NUMERAR")
amount: Decimal = Field(description="Amount paid with this method")
# ============ Receipt Schemas ============
class ReceiptBase(BaseModel):
"""Base schema for receipt."""
receipt_type: ReceiptType = ReceiptType.BON_FISCAL
direction: ReceiptDirection = ReceiptDirection.CHELTUIALA
receipt_number: Optional[str] = Field(default=None, max_length=50)
receipt_series: Optional[str] = Field(default=None, max_length=20)
receipt_date: date
amount: Decimal = Field(gt=0)
description: Optional[str] = Field(default=None, max_length=500)
# TVA info (multiple entries support)
tva_breakdown: Optional[List[TvaEntrySchema]] = Field(default=None, description="List of TVA entries")
tva_total: Optional[Decimal] = Field(default=None, description="Total TVA amount")
items_count: Optional[int] = Field(default=None, description="Number of items")
vendor_address: Optional[str] = Field(default=None, max_length=500, description="Vendor address")
# Other fields
expense_type_code: Optional[str] = Field(default=None, max_length=20)
company_id: int
# partner_id removed - supplier data is text-only (partner_name, cui)
partner_name: Optional[str] = Field(default=None, max_length=200)
cui: Optional[str] = Field(default=None, max_length=20, description="Fiscal code (CUI) from OCR")
ocr_raw_text: Optional[str] = Field(default=None, description="Raw OCR text for debugging")
payment_methods: Optional[List[PaymentMethodSchema]] = Field(default=None, description="Payment methods from OCR")
cash_register_id: Optional[int] = None
cash_register_name: Optional[str] = Field(default=None, max_length=100)
cash_register_account: Optional[str] = Field(default=None, max_length=20)
payment_mode: Optional[str] = Field(default=None, description="Payment mode: casa/banca/avans_decontare")
class ReceiptCreate(ReceiptBase):
"""Schema for creating a receipt."""
pass
class ReceiptUpdate(BaseModel):
"""Schema for updating a receipt (DRAFT only)."""
receipt_type: Optional[ReceiptType] = None
direction: Optional[ReceiptDirection] = None
receipt_number: Optional[str] = Field(default=None, max_length=50)
receipt_series: Optional[str] = Field(default=None, max_length=20)
receipt_date: Optional[date] = None
amount: Optional[Decimal] = Field(default=None, gt=0)
description: Optional[str] = Field(default=None, max_length=500)
# TVA info (multiple entries support)
tva_breakdown: Optional[List[TvaEntrySchema]] = Field(default=None, description="List of TVA entries")
tva_total: Optional[Decimal] = Field(default=None, description="Total TVA amount")
items_count: Optional[int] = Field(default=None, description="Number of items")
vendor_address: Optional[str] = Field(default=None, max_length=500, description="Vendor address")
# Other fields
expense_type_code: Optional[str] = Field(default=None, max_length=20)
# partner_id removed - supplier data is text-only (partner_name, cui)
partner_name: Optional[str] = Field(default=None, max_length=200)
cui: Optional[str] = Field(default=None, max_length=20, description="Fiscal code (CUI) from OCR")
ocr_raw_text: Optional[str] = Field(default=None, description="Raw OCR text for debugging")
payment_methods: Optional[List[PaymentMethodSchema]] = Field(default=None, description="Payment methods from OCR")
cash_register_id: Optional[int] = None
cash_register_name: Optional[str] = Field(default=None, max_length=100)
cash_register_account: Optional[str] = Field(default=None, max_length=20)
payment_mode: Optional[str] = Field(default=None, description="Payment mode: casa/banca/avans_decontare")
class ReceiptResponse(ReceiptBase):
"""Schema for receipt response with all fields."""
model_config = ConfigDict(from_attributes=True)
id: int
# Override amount to allow zero values in response (validation is on input, not output)
amount: Decimal
status: ReceiptStatus
created_by: str
created_at: datetime
updated_at: datetime
submitted_at: Optional[datetime] = None
reviewed_by: Optional[str] = None
reviewed_at: Optional[datetime] = None
rejection_reason: Optional[str] = None
oracle_synced_at: Optional[datetime] = None
oracle_act_id: Optional[int] = None
oracle_error: Optional[str] = None
# Relationships (optional, loaded when needed)
attachments: List[AttachmentResponse] = []
entries: List[AccountingEntryResponse] = []
@field_validator('tva_breakdown', mode='before')
@classmethod
def parse_tva_breakdown(cls, v: Any) -> Optional[List[dict]]:
"""Deserialize tva_breakdown from JSON string if needed."""
if v is None:
return None
if isinstance(v, str):
try:
return json.loads(v)
except (json.JSONDecodeError, TypeError):
return None
if isinstance(v, list):
return v
return None
@field_validator('payment_methods', mode='before')
@classmethod
def parse_payment_methods(cls, v: Any) -> Optional[List[dict]]:
"""Deserialize payment_methods from JSON string if needed."""
if v is None:
return None
if isinstance(v, str):
try:
return json.loads(v)
except (json.JSONDecodeError, TypeError):
return None
if isinstance(v, list):
return v
return None
class ReceiptListResponse(BaseModel):
"""Schema for paginated receipt list response."""
items: List[ReceiptResponse]
total: int
page: int
page_size: int
pages: int
class ReceiptFilter(BaseModel):
"""Schema for filtering receipts."""
status: Optional[ReceiptStatus] = None
direction: Optional[ReceiptDirection] = None
company_id: Optional[int] = None
created_by: Optional[str] = None
date_from: Optional[date] = None
date_to: Optional[date] = None
search: Optional[str] = None # Search in description, partner_name
page: int = Field(default=1, ge=1)
page_size: int = Field(default=20, ge=1, le=100)
# ============ Workflow Schemas ============
class WorkflowAction(BaseModel):
"""Schema for workflow action response."""
success: bool
message: str
receipt: Optional[ReceiptResponse] = None
class RejectRequest(BaseModel):
"""Schema for rejection request."""
reason: str = Field(min_length=5, max_length=500)
class EntriesUpdateRequest(BaseModel):
"""Schema for bulk updating accounting entries."""
entries: List[AccountingEntryCreate]
# ============ Nomenclature Schemas ============
class PartnerOption(BaseModel):
"""Schema for partner dropdown option (used for autocomplete assistance)."""
name: str
fiscal_code: Optional[str] = None
address: Optional[str] = None
source: str = "oracle" # 'oracle' (synced) or 'local'
class AccountOption(BaseModel):
"""Schema for account dropdown option."""
code: str
name: str
class CashRegisterOption(BaseModel):
"""Schema for cash register dropdown option."""
id: int
name: str
account_code: str # 5311, 5121, etc.
class ExpenseTypeOption(BaseModel):
"""Schema for expense type dropdown option."""
code: str
name: str
account_code: str
has_vat: bool
vat_percent: Decimal = Decimal("19")

View File

@@ -0,0 +1,11 @@
# Business logic services
from .receipt_service import ReceiptService
from .nomenclature_service import NomenclatureService
from .expense_types import EXPENSE_TYPES, ExpenseType
__all__ = [
"ReceiptService",
"NomenclatureService",
"EXPENSE_TYPES",
"ExpenseType",
]

View File

@@ -0,0 +1,101 @@
"""Predefined expense types for automatic accounting entry generation."""
from decimal import Decimal
from dataclasses import dataclass
from typing import Dict, Optional
@dataclass
class ExpenseType:
"""Expense type definition with accounting configuration."""
code: str
name: str
account_code: str
account_name: str
has_vat: bool
vat_percent: Decimal = Decimal("19")
vat_account: str = "4426"
# Predefined expense types
EXPENSE_TYPES: Dict[str, ExpenseType] = {
"FUEL": ExpenseType(
code="FUEL",
name="Combustibil",
account_code="6022",
account_name="Cheltuieli cu combustibilii",
has_vat=True,
),
"MATERIALS": ExpenseType(
code="MATERIALS",
name="Materiale consumabile",
account_code="6028",
account_name="Alte cheltuieli cu materiale consumabile",
has_vat=True,
),
"OFFICE": ExpenseType(
code="OFFICE",
name="Rechizite birou",
account_code="6024",
account_name="Cheltuieli privind materialele pentru ambalat",
has_vat=True,
),
"PHONE": ExpenseType(
code="PHONE",
name="Telefonie / Internet",
account_code="626",
account_name="Cheltuieli postale si taxe de telecomunicatii",
has_vat=True,
),
"PARKING": ExpenseType(
code="PARKING",
name="Parcare",
account_code="6022",
account_name="Cheltuieli cu combustibilii",
has_vat=True,
),
"FOOD": ExpenseType(
code="FOOD",
name="Alimentatie",
account_code="6028",
account_name="Alte cheltuieli cu materiale consumabile",
has_vat=False, # No deductible VAT for food
),
"TRANSPORT": ExpenseType(
code="TRANSPORT",
name="Transport",
account_code="624",
account_name="Cheltuieli cu transportul de bunuri si personal",
has_vat=True,
),
"OTHER": ExpenseType(
code="OTHER",
name="Altele",
account_code="628",
account_name="Alte cheltuieli cu serviciile executate de terti",
has_vat=True,
),
}
def get_expense_type(code: str) -> Optional[ExpenseType]:
"""Get expense type by code."""
return EXPENSE_TYPES.get(code)
def get_all_expense_types() -> Dict[str, ExpenseType]:
"""Get all expense types."""
return EXPENSE_TYPES.copy()
# Default cash register accounts
CASH_REGISTER_ACCOUNTS = {
"CASA": {
"code": "5311",
"name": "Casa in lei",
},
"BANCA": {
"code": "5121",
"name": "Conturi la banci in lei",
},
}

View File

@@ -0,0 +1,270 @@
"""Image preprocessing for optimal OCR results."""
from pathlib import Path
from typing import List
import numpy as np
import cv2
try:
import pdf2image
PDF_AVAILABLE = True
except ImportError:
PDF_AVAILABLE = False
class ImagePreprocessor:
"""Preprocess receipt images for OCR."""
def _add_safety_padding(self, image: np.ndarray, padding: int = 50) -> np.ndarray:
"""Add white padding around image to protect edge content during rotation.
This prevents left/right margin truncation in OCR by ensuring text near
edges isn't lost during deskew rotation.
"""
if len(image.shape) == 2:
# Grayscale
return cv2.copyMakeBorder(
image, padding, padding, padding, padding,
cv2.BORDER_CONSTANT, value=255
)
else:
# Color (BGR)
return cv2.copyMakeBorder(
image, padding, padding, padding, padding,
cv2.BORDER_CONSTANT, value=(255, 255, 255)
)
def load_image(self, path: Path) -> np.ndarray:
"""Load image from file."""
image = cv2.imread(str(path))
if image is None:
raise ValueError(f"Could not load image: {path}")
return image
def pdf_to_images(self, path: Path, dpi: int = 300) -> List[np.ndarray]:
"""
Convert PDF to images.
Args:
path: Path to PDF file
dpi: Resolution (300 = fast & good quality, 400 = better but slower)
"""
if not PDF_AVAILABLE:
raise RuntimeError("pdf2image not available. Install with: pip install pdf2image")
images = pdf2image.convert_from_path(str(path), dpi=dpi)
return [np.array(img) for img in images]
def preprocess(self, image: np.ndarray, high_quality: bool = True) -> np.ndarray:
"""
Apply LIGHT preprocessing - better for clear PDFs.
Heavy binarization can destroy text on clear images.
"""
return self.preprocess_light(image)
def preprocess_light(self, image: np.ndarray) -> np.ndarray:
"""
Light preprocessing for CLEAR images (PDFs, good scans).
Preserves original quality, only enhances contrast.
"""
# 0. Add safety padding to protect edge content during deskew rotation
image = self._add_safety_padding(image)
# 1. Grayscale
if len(image.shape) == 3:
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
else:
gray = image.copy()
# 2a. Scale DOWN if any side exceeds 4000px (PaddleOCR limit)
height, width = gray.shape
max_side = max(height, width)
if max_side > 4000:
scale = 4000 / max_side
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
height, width = gray.shape
# 2b. Scale UP if too small
if width < 1500:
scale = 1500 / width
# Ensure we don't exceed 4000px after upscaling
new_width = int(width * scale)
new_height = int(height * scale)
if max(new_width, new_height) > 4000:
scale = 4000 / max(new_width, new_height)
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
# 3. Deskew
gray = self._deskew(gray)
# 4. Light contrast enhancement only
clahe = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(8, 8))
enhanced = clahe.apply(gray)
# NO binarization, NO morphological ops - preserve original quality
return enhanced
def preprocess_heavy(self, image: np.ndarray) -> np.ndarray:
"""
Heavy preprocessing for FADED thermal receipts.
Aggressive binarization to recover faded text.
"""
# 0. Add safety padding to protect edge content during deskew rotation
image = self._add_safety_padding(image)
# 1. Grayscale
if len(image.shape) == 3:
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
else:
gray = image.copy()
# 2a. Scale DOWN if any side exceeds 4000px (PaddleOCR limit)
height, width = gray.shape
max_side = max(height, width)
if max_side > 4000:
scale = 4000 / max_side
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
height, width = gray.shape
# 2b. Scale UP if too small (larger = better OCR)
if width < 1500:
scale = 1500 / width
# Ensure we don't exceed 4000px after upscaling
new_width = int(width * scale)
new_height = int(height * scale)
if max(new_width, new_height) > 4000:
scale = 4000 / max(new_width, new_height)
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
# 3. Deskew
gray = self._deskew(gray)
# 4. Contrast enhancement with CLAHE
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
enhanced = clahe.apply(gray)
# 5. Denoise
denoised = cv2.fastNlMeansDenoising(enhanced, h=8, templateWindowSize=7, searchWindowSize=21)
# 6. Sharpening
gaussian = cv2.GaussianBlur(denoised, (0, 0), 2.0)
sharpened = cv2.addWeighted(denoised, 1.5, gaussian, -0.5, 0)
# 7. Adaptive thresholding (binarization)
binary = cv2.adaptiveThreshold(
sharpened, 255,
cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
cv2.THRESH_BINARY,
blockSize=11, C=5
)
# 8. Morphological operations
kernel_close = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
result = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel_close)
return result
def preprocess_for_tesseract(self, image: np.ndarray) -> np.ndarray:
"""
Tesseract-optimized preprocessing.
Tesseract works best with:
- Clean black text on white background (binarized)
- High DPI (scale up small images)
- Otsu thresholding (better than adaptive for clean documents)
"""
# 0. Add safety padding to protect edge content during deskew rotation
image = self._add_safety_padding(image)
# 1. Grayscale
if len(image.shape) == 3:
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
else:
gray = image.copy()
# 2. Scale for optimal Tesseract (target ~2000px width for receipts)
height, width = gray.shape
if width < 2000:
scale = 2000 / width
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
elif width > 3000:
scale = 3000 / width
gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
# 3. Deskew
gray = self._deskew(gray)
# 4. Strong contrast enhancement
clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
enhanced = clahe.apply(gray)
# 5. Denoise before binarization
denoised = cv2.fastNlMeansDenoising(enhanced, h=10, templateWindowSize=7, searchWindowSize=21)
# 6. Otsu binarization (better than adaptive for clean PDFs)
_, binary = cv2.threshold(denoised, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
# 7. Light morphological cleanup
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 1))
cleaned = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
return cleaned
def get_all_variants(self, image: np.ndarray) -> List[np.ndarray]:
"""
Generate 2 preprocessing variants for OCR (fast mode).
Returns: [light_processed, heavy_processed]
"""
return [
self.preprocess_light(image),
self.preprocess_heavy(image),
]
def _deskew(self, image: np.ndarray) -> np.ndarray:
"""Correct image rotation/skew using Hough lines.
Uses expanded canvas to preserve all content during rotation,
preventing left/right margin truncation.
"""
edges = cv2.Canny(image, 50, 150, apertureSize=3)
lines = cv2.HoughLinesP(
edges, 1, np.pi / 180,
threshold=100, minLineLength=100, maxLineGap=10
)
if lines is None:
return image
angles = []
for line in lines:
x1, y1, x2, y2 = line[0]
angle = np.arctan2(y2 - y1, x2 - x1) * 180 / np.pi
if abs(angle) < 45:
angles.append(angle)
if not angles:
return image
median_angle = np.median(angles)
if abs(median_angle) < 0.5:
return image
h, w = image.shape[:2]
center = (w // 2, h // 2)
M = cv2.getRotationMatrix2D(center, median_angle, 1.0)
# Calculate new canvas size to fit entire rotated image (prevents edge truncation)
cos_angle = abs(np.cos(np.radians(median_angle)))
sin_angle = abs(np.sin(np.radians(median_angle)))
new_w = int(h * sin_angle + w * cos_angle)
new_h = int(h * cos_angle + w * sin_angle)
# Adjust rotation matrix for new canvas center
M[0, 2] += (new_w - w) / 2
M[1, 2] += (new_h - h) / 2
return cv2.warpAffine(
image, M, (new_w, new_h),
flags=cv2.INTER_CUBIC,
borderMode=cv2.BORDER_CONSTANT,
borderValue=255 # White background (grayscale)
)

View File

@@ -0,0 +1,234 @@
"""Service for fetching nomenclatures from Oracle (read-only)."""
from typing import List, Optional
from decimal import Decimal
from sqlmodel import select
from sqlalchemy.ext.asyncio import AsyncSession
from backend.modules.data_entry.schemas.receipt import (
PartnerOption,
AccountOption,
CashRegisterOption,
ExpenseTypeOption,
)
from backend.modules.data_entry.services.expense_types import EXPENSE_TYPES
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier, SyncedCashRegister
class NomenclatureService:
"""
Service for fetching nomenclatures.
In Phase 1 (MVP), some nomenclatures are hardcoded.
In Phase 2, these will be fetched from Oracle.
"""
@staticmethod
async def get_partners(
company_id: int,
search: Optional[str] = None,
session: Optional[AsyncSession] = None
) -> List[PartnerOption]:
"""
Get partners (suppliers/customers) for a company.
Phase 1: Returns mock data.
Phase 2: Returns synced data from SQLite (from Oracle sync).
Phase 3: Will fetch live from Oracle.
"""
# If session is provided, try to get from synced SQLite data
if session:
# Try to get from SQLite synced data
stmt = select(SyncedSupplier).where(SyncedSupplier.company_id == company_id)
if search:
stmt = stmt.where(
(SyncedSupplier.name.ilike(f"%{search}%")) |
(SyncedSupplier.fiscal_code.ilike(f"%{search}%"))
)
stmt = stmt.order_by(SyncedSupplier.name) # Order alphabetically, no limit for AutoComplete
result = await session.execute(stmt)
suppliers = result.scalars().all()
if suppliers:
# Also get local suppliers
local_stmt = select(LocalSupplier).where(LocalSupplier.company_id == company_id)
if search:
local_stmt = local_stmt.where(
(LocalSupplier.name.ilike(f"%{search}%")) |
(LocalSupplier.fiscal_code.ilike(f"%{search}%"))
)
local_stmt = local_stmt.order_by(LocalSupplier.name) # Order alphabetically
local_result = await session.execute(local_stmt)
local_suppliers = local_result.scalars().all()
# Combine both - no IDs needed, just text data for autocomplete
partners = []
for s in suppliers:
partners.append(PartnerOption(
name=s.name,
fiscal_code=s.fiscal_code,
address=s.address,
source="oracle"
))
for l in local_suppliers:
partners.append(PartnerOption(
name=l.name, # No suffix - must match search results
fiscal_code=l.fiscal_code,
address=l.address,
source="local"
))
return partners
# Fallback to mock data for Phase 1 (when no synced data)
mock_partners = [
PartnerOption(name="OMV Petrom", fiscal_code="RO123456", source="mock"),
PartnerOption(name="Dedeman", fiscal_code="RO789012", source="mock"),
PartnerOption(name="Kaufland", fiscal_code="RO345678", source="mock"),
PartnerOption(name="Emag", fiscal_code="RO901234", source="mock"),
PartnerOption(name="Altex", fiscal_code="RO567890", source="mock"),
]
if search:
search_lower = search.lower()
mock_partners = [
p for p in mock_partners
if search_lower in p.name.lower() or (p.fiscal_code and search_lower in p.fiscal_code.lower())
]
return mock_partners
@staticmethod
async def get_accounts(company_id: int, prefix: Optional[str] = None) -> List[AccountOption]:
"""
Get chart of accounts for a company.
Phase 1: Returns common expense/income accounts.
Phase 2: Will fetch from Oracle PLAN_CONTURI.
"""
# Common accounts for expenses and receipts
accounts = [
# Expense accounts (Class 6)
AccountOption(code="6022", name="Cheltuieli cu combustibilii"),
AccountOption(code="6024", name="Cheltuieli materiale pentru ambalat"),
AccountOption(code="6028", name="Alte cheltuieli cu materiale consumabile"),
AccountOption(code="624", name="Cheltuieli cu transportul de bunuri si personal"),
AccountOption(code="626", name="Cheltuieli postale si taxe telecomunicatii"),
AccountOption(code="628", name="Alte cheltuieli cu serviciile executate de terti"),
# VAT
AccountOption(code="4426", name="TVA deductibila"),
AccountOption(code="4427", name="TVA colectata"),
# Cash and Bank (Class 5)
AccountOption(code="5311", name="Casa in lei"),
AccountOption(code="5121", name="Conturi la banci in lei"),
# Income accounts (Class 7)
AccountOption(code="7588", name="Alte venituri din exploatare"),
]
if prefix:
accounts = [a for a in accounts if a.code.startswith(prefix)]
return accounts
@staticmethod
async def get_cash_registers(
company_id: int,
session: Optional[AsyncSession] = None
) -> List[CashRegisterOption]:
"""
Get cash registers and bank accounts for a company.
Phase 1: Returns default options.
Phase 2: Returns synced data from SQLite (from Oracle sync).
Phase 3: Will fetch live from Oracle NOM_CASE / NOM_BANCI.
"""
# If session is provided, try to get from synced SQLite data
if session:
stmt = select(SyncedCashRegister).where(SyncedCashRegister.company_id == company_id)
result = await session.execute(stmt)
registers = result.scalars().all()
if registers:
return [
CashRegisterOption(id=r.id, name=r.name, account_code=r.account_code)
for r in registers
]
# Fallback to default cash registers for Phase 1
return [
CashRegisterOption(id=1, name="Casa principala", account_code="5311"),
CashRegisterOption(id=2, name="Cont BCR", account_code="5121"),
CashRegisterOption(id=3, name="Cont BRD", account_code="5121"),
]
@staticmethod
async def get_expense_types() -> List[ExpenseTypeOption]:
"""
Get predefined expense types with their accounting configuration.
"""
return [
ExpenseTypeOption(
code=et.code,
name=et.name,
account_code=et.account_code,
has_vat=et.has_vat,
vat_percent=et.vat_percent,
)
for et in EXPENSE_TYPES.values()
]
@staticmethod
async def get_companies(username: str) -> List[dict]:
"""
Get companies accessible by user.
Phase 1: Returns mock data.
Phase 2: Will fetch from shared auth based on user permissions.
"""
# TODO: Integrate with shared auth to get user's companies
return [
{"id": 1, "name": "SC Test SRL", "cui": "RO12345678"},
{"id": 2, "name": "SC Demo SA", "cui": "RO87654321"},
]
# ============ Phase 2 Oracle Integration Methods ============
@staticmethod
async def _fetch_partners_oracle(company_id: int, search: Optional[str] = None) -> List[PartnerOption]:
"""
Fetch partners from Oracle NOM_PARTENERI.
Will be implemented in Phase 2.
"""
# TODO: Implement using shared oracle_pool
# Example query:
# SELECT ID_PART, DEN_PART, COD_FISCAL
# FROM {schema}.NOM_PARTENERI
# WHERE DEN_PART LIKE :search
raise NotImplementedError("Oracle integration pending - Phase 2")
@staticmethod
async def _fetch_accounts_oracle(company_id: int, prefix: Optional[str] = None) -> List[AccountOption]:
"""
Fetch chart of accounts from Oracle PLAN_CONTURI.
Will be implemented in Phase 2.
"""
# TODO: Implement using shared oracle_pool
raise NotImplementedError("Oracle integration pending - Phase 2")
@staticmethod
async def _fetch_cash_registers_oracle(company_id: int) -> List[CashRegisterOption]:
"""
Fetch cash registers from Oracle NOM_CASE / NOM_BANCI.
Will be implemented in Phase 2.
"""
# TODO: Implement using shared oracle_pool
raise NotImplementedError("Oracle integration pending - Phase 2")

View File

@@ -0,0 +1,295 @@
"""OCR engine wrapper for PaddleOCR and Tesseract."""
import os
import logging
import threading
import time
from dataclasses import dataclass
from typing import List, Optional, Tuple
import numpy as np
# Setup logging
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO) # Ensure logs are visible
# Disable PaddleOCR model source check for faster startup (PaddleX 3.x)
os.environ['PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK'] = 'True'
# Lazy imports - these will be imported on first use
PaddleOCR = None # Will be imported lazily
pytesseract = None # Will be imported lazily
# Check availability without importing heavy libraries
def _check_paddle_available() -> bool:
"""Check if paddleocr is installed without importing it."""
try:
import importlib.util
return importlib.util.find_spec("paddleocr") is not None
except Exception:
return False
def _check_tesseract_available() -> bool:
"""Check if pytesseract is installed without importing it."""
try:
import importlib.util
return importlib.util.find_spec("pytesseract") is not None
except Exception:
return False
PADDLE_AVAILABLE = _check_paddle_available()
TESSERACT_AVAILABLE = _check_tesseract_available()
@dataclass
class OCRResult:
"""Raw OCR result."""
text: str
confidence: float
boxes: List[dict]
engine: str = "" # OCR engine used: paddleocr or tesseract
class OCREngine:
"""Unified OCR engine with fallback support."""
def __init__(self):
self._paddle = None
self._paddle_init_started = False
self._paddle_ready = threading.Event() # Signals when PaddleOCR is FULLY ready
self._paddle_init_lock = threading.Lock()
def _init_paddle_lazy(self):
"""Lazy initialize PaddleOCR on first use (avoids slow startup)."""
global PaddleOCR
with self._paddle_init_lock:
if self._paddle_init_started:
return # Already initializing or done
self._paddle_init_started = True
if PADDLE_AVAILABLE:
try:
print("Importing PaddleOCR (first use, may take ~15-20 seconds)...", flush=True)
from paddleocr import PaddleOCR as _PaddleOCR
PaddleOCR = _PaddleOCR
print("Initializing PaddleOCR engine...", flush=True)
# PaddleOCR 3.x API - optimized for Romanian receipts
# Note: 'latin' not available in PaddleOCR 3.x, 'en' works well for receipts
self._paddle = PaddleOCR(
lang='en', # 'en' handles Latin alphabet well for receipts
# High quality settings for better accuracy
det_db_thresh=0.3, # Lower threshold = detect more text (default 0.3)
det_db_box_thresh=0.5, # Box confidence threshold (default 0.5)
det_db_unclip_ratio=1.8, # Expand detected boxes slightly (default 1.5)
rec_batch_num=6, # Batch size for recognition
use_angle_cls=True, # Enable text angle classification
)
print("PaddleOCR initialized successfully with high-quality settings", flush=True)
except Exception as e:
print(f"Warning: Failed to initialize PaddleOCR: {e}", flush=True)
self._paddle = None
# Signal that initialization is complete (success or failure)
self._paddle_ready.set()
def wait_for_paddle(self, timeout: float = 30.0) -> bool:
"""
Wait for PaddleOCR to be fully initialized.
Args:
timeout: Max seconds to wait (default 30s)
Returns:
True if PaddleOCR is ready, False if timeout or unavailable
"""
if not PADDLE_AVAILABLE:
return False
if self._paddle is not None:
return True # Already ready
if not self._paddle_init_started:
# Start initialization if not already started
self._init_paddle_lazy()
# Wait for initialization to complete
print(f"[OCR] Waiting for PaddleOCR to be ready (max {timeout}s)...", flush=True)
start = time.time()
ready = self._paddle_ready.wait(timeout=timeout)
elapsed = time.time() - start
if ready and self._paddle is not None:
print(f"[OCR] PaddleOCR ready after {elapsed:.1f}s", flush=True)
return True
else:
print(f"[OCR] PaddleOCR not ready after {elapsed:.1f}s (timeout or failed)", flush=True)
return False
def is_paddle_ready(self) -> bool:
"""Check if PaddleOCR is ready without waiting."""
return self._paddle is not None
def recognize(self, image: np.ndarray) -> OCRResult:
"""Perform OCR on preprocessed image."""
logger.info(f"[OCR] Starting recognition, image shape: {image.shape}, dtype: {image.dtype}")
# Lazy init PaddleOCR on first call
self._init_paddle_lazy()
if PADDLE_AVAILABLE and self._paddle:
logger.info("[OCR] Using PaddleOCR engine")
return self._paddle_recognize(image)
elif TESSERACT_AVAILABLE:
logger.info("[OCR] Using Tesseract engine (PaddleOCR not available)")
return self._tesseract_recognize(image)
else:
logger.error("[OCR] No OCR engine available!")
raise RuntimeError(
"No OCR engine available. Install PaddleOCR or Tesseract."
)
def _paddle_recognize(self, image: np.ndarray) -> OCRResult:
"""Recognize text using PaddleOCR 3.x API."""
# Wait for PaddleOCR to be fully ready (handles background init)
if not self.wait_for_paddle(timeout=30.0):
logger.warning("[PaddleOCR] Not ready, falling back to Tesseract")
if TESSERACT_AVAILABLE:
return self._tesseract_recognize(image)
raise RuntimeError("PaddleOCR not ready and Tesseract not available")
try:
logger.info(f"[PaddleOCR] Processing image, shape: {image.shape}")
# PaddleOCR 3.x requires 3-channel images
if len(image.shape) == 2:
# Convert grayscale to 3-channel BGR
import cv2
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
logger.info(f"[PaddleOCR] Converted to BGR, new shape: {image.shape}")
# PaddleOCR 3.x uses predict() with new parameter names
logger.info("[PaddleOCR] Calling predict()...")
result = self._paddle.predict(image, use_textline_orientation=True)
logger.info(f"[PaddleOCR] predict() returned, result type: {type(result)}")
if not result or len(result) == 0:
logger.warning("[PaddleOCR] No results returned")
return OCRResult(text="", confidence=0.0, boxes=[], engine="paddleocr")
# PaddleOCR 3.x returns OCRResult objects with different structure
ocr_result = result[0]
# Extract texts and scores from the new format
rec_texts = ocr_result.get('rec_texts', [])
rec_scores = ocr_result.get('rec_scores', [])
dt_polys = ocr_result.get('dt_polys', [])
if not rec_texts:
return OCRResult(text="", confidence=0.0, boxes=[], engine="paddleocr")
boxes = []
for i, text in enumerate(rec_texts):
conf = rec_scores[i] if i < len(rec_scores) else 0.0
box = dt_polys[i].tolist() if i < len(dt_polys) else []
boxes.append({
'text': text,
'confidence': float(conf),
'box': box
})
avg_conf = sum(rec_scores) / len(rec_scores) if rec_scores else 0.0
text_result = '\n'.join(rec_texts)
logger.info(f"[PaddleOCR] SUCCESS - Found {len(rec_texts)} text lines, avg confidence: {avg_conf:.2%}")
logger.debug(f"[PaddleOCR] Raw text preview: {text_result[:200]}...")
return OCRResult(
text=text_result,
confidence=float(avg_conf),
boxes=boxes,
engine="paddleocr"
)
except Exception as e:
logger.error(f"[PaddleOCR] ERROR: {e}, falling back to Tesseract")
if TESSERACT_AVAILABLE:
return self._tesseract_recognize(image)
raise
def _tesseract_recognize(self, image: np.ndarray) -> OCRResult:
"""Recognize text using Tesseract."""
global pytesseract
logger.info(f"[Tesseract] Processing image, shape: {image.shape}")
# Lazy import pytesseract
if pytesseract is None:
logger.info("[Tesseract] Importing pytesseract...")
import pytesseract as _pytesseract
pytesseract = _pytesseract
# PSM 4: Single column (best for receipts)
config = '--psm 4 -l ron+eng'
text = pytesseract.image_to_string(image, config=config)
# Quick confidence estimate
data = pytesseract.image_to_data(image, config=config, output_type=pytesseract.Output.DICT)
confidences = [int(c) for c in data['conf'] if int(c) > 0]
avg_conf = sum(confidences) / len(confidences) / 100 if confidences else 0.0
logger.info(f"[Tesseract] Done: {len(text)} chars, conf: {avg_conf:.2%}")
return OCRResult(text=text, confidence=avg_conf, boxes=[], engine="tesseract")
def recognize_dual(self, image: np.ndarray) -> Tuple[OCRResult, Optional[OCRResult]]:
"""
Run both OCR engines and return both results.
Returns:
Tuple of (paddle_result, tesseract_result)
tesseract_result may be None if Tesseract is not available
"""
logger.info(f"[OCR Dual] Starting dual recognition, image shape: {image.shape}")
# Lazy init PaddleOCR
self._init_paddle_lazy()
paddle_result = None
tesseract_result = None
# Run PaddleOCR
if PADDLE_AVAILABLE and self._paddle:
try:
logger.info("[OCR Dual] Running PaddleOCR...")
paddle_result = self._paddle_recognize(image)
logger.info(f"[OCR Dual] PaddleOCR: {len(paddle_result.text)} chars, conf: {paddle_result.confidence:.2%}")
except Exception as e:
logger.error(f"[OCR Dual] PaddleOCR failed: {e}")
paddle_result = OCRResult(text="", confidence=0.0, boxes=[], engine="paddleocr")
# Run Tesseract
if TESSERACT_AVAILABLE:
try:
logger.info("[OCR Dual] Running Tesseract...")
tesseract_result = self._tesseract_recognize(image)
logger.info(f"[OCR Dual] Tesseract: {len(tesseract_result.text)} chars, conf: {tesseract_result.confidence:.2%}")
except Exception as e:
logger.error(f"[OCR Dual] Tesseract failed: {e}")
tesseract_result = OCRResult(text="", confidence=0.0, boxes=[], engine="tesseract")
# Fallback if PaddleOCR not available
if paddle_result is None:
if tesseract_result:
paddle_result = tesseract_result
else:
raise RuntimeError("No OCR engine available")
return paddle_result, tesseract_result
@staticmethod
def get_available_engines() -> List[str]:
"""Return list of available OCR engines."""
engines = []
if PADDLE_AVAILABLE:
engines.append('paddleocr')
if TESSERACT_AVAILABLE:
engines.append('tesseract')
return engines

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,569 @@
"""Main OCR service coordinating preprocessing, recognition, and extraction."""
import os
import re
import logging
# Disable PaddleOCR model source check for faster startup (PaddleX 3.x) - must be set before import
os.environ['PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK'] = 'True'
import time
import asyncio
from concurrent.futures import ThreadPoolExecutor
from decimal import Decimal
from pathlib import Path
from typing import Optional, Tuple
from backend.modules.data_entry.services.ocr_engine import OCREngine
from backend.modules.data_entry.services.ocr_extractor import ReceiptExtractor, ExtractionResult
from backend.modules.data_entry.services.image_preprocessor import ImagePreprocessor
# Setup logging
logger = logging.getLogger(__name__)
class OCRService:
"""Service for OCR processing of receipt images."""
_executor = ThreadPoolExecutor(max_workers=2)
def __init__(self):
self.preprocessor = ImagePreprocessor()
self.ocr_engine = OCREngine()
self.extractor = ReceiptExtractor()
async def process_image(
self,
image_path: Path,
mime_type: str
) -> Tuple[bool, str, Optional[ExtractionResult]]:
"""
Process receipt image and extract structured data.
Args:
image_path: Path to the image file
mime_type: MIME type of the file
Returns:
Tuple of (success, message, extraction_result)
"""
try:
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(
self._executor,
self._process_sync,
image_path,
mime_type
)
return result
except Exception as e:
return False, f"OCR processing failed: {str(e)}", None
def _process_sync(
self,
image_path: Path,
mime_type: str
) -> Tuple[bool, str, Optional[ExtractionResult]]:
"""Synchronous processing with ADAPTIVE OCR pipeline."""
start_time = time.time()
print(f"[OCR Service] Starting processing: {image_path}, mime: {mime_type}", flush=True)
# Load image
if mime_type == 'application/pdf':
try:
images = self.preprocessor.pdf_to_images(image_path)
if not images:
return False, "Failed to extract images from PDF", None
image = images[0]
except RuntimeError as e:
return False, str(e), None
else:
try:
image = self.preprocessor.load_image(image_path)
except ValueError as e:
return False, str(e), None
raw_texts = []
extraction = None
# ══════════════════════════════════════════════════════════════
# STEP 1: PaddleOCR + Light (fastest, best for clear PDFs)
# ══════════════════════════════════════════════════════════════
print("=" * 60, flush=True)
print("[OCR] STEP 1: PaddleOCR + Light preprocessing", flush=True)
print("=" * 60, flush=True)
light_img = self.preprocessor.preprocess_light(image)
try:
paddle_light = self.ocr_engine._paddle_recognize(light_img)
if paddle_light and paddle_light.text:
extraction = self.extractor.extract(paddle_light.text)
extraction.ocr_engine = "paddle-light"
raw_texts.append(f"═══ PaddleOCR (light, conf: {paddle_light.confidence:.0%}) ═══\n{paddle_light.text}")
# Log extraction results
print(f"[OCR] Step 1 Results:", flush=True)
print(f" - OCR Confidence: {paddle_light.confidence:.0%}", flush=True)
print(f" - Amount: {extraction.amount}", flush=True)
print(f" - Date: {extraction.receipt_date}", flush=True)
print(f" - Number: {extraction.receipt_number}", flush=True)
print(f" - CUI: {extraction.cui}", flush=True)
print(f" - TVA: {extraction.tva_total} (entries: {len(extraction.tva_entries) if extraction.tva_entries else 0})", flush=True)
print(f" - Overall Confidence: {extraction.overall_confidence:.0%}", flush=True)
# Early exit if complete
if self._is_extraction_complete(extraction):
extraction.raw_text = "\n\n".join(raw_texts)
elapsed_ms = int((time.time() - start_time) * 1000)
extraction.processing_time_ms = elapsed_ms
print(f"[OCR] ✓✓✓ EARLY EXIT at Step 1 - All fields found! ({elapsed_ms}ms) ✓✓✓", flush=True)
return True, "OCR complete (fast mode)", extraction
else:
print("[OCR] → Step 1 incomplete, continuing to Step 2...", flush=True)
except Exception as e:
print(f"[OCR] PaddleOCR light failed: {e}", flush=True)
extraction = ExtractionResult()
# ══════════════════════════════════════════════════════════════
# STEP 2: PaddleOCR + Heavy (for faded thermal receipts)
# ══════════════════════════════════════════════════════════════
print("=" * 60, flush=True)
print("[OCR] STEP 2: PaddleOCR + Heavy preprocessing", flush=True)
print("=" * 60, flush=True)
heavy_img = self.preprocessor.preprocess_heavy(image)
try:
paddle_heavy = self.ocr_engine._paddle_recognize(heavy_img)
if paddle_heavy and paddle_heavy.text:
extraction_heavy = self.extractor.extract(paddle_heavy.text)
extraction_heavy.ocr_engine = "paddle-heavy"
raw_texts.append(f"═══ PaddleOCR (heavy, conf: {paddle_heavy.confidence:.0%}) ═══\n{paddle_heavy.text}")
print(f"[OCR] Step 2 (Heavy) Results:", flush=True)
print(f" - OCR Confidence: {paddle_heavy.confidence:.0%}", flush=True)
print(f" - Amount: {extraction_heavy.amount}", flush=True)
print(f" - Date: {extraction_heavy.receipt_date}", flush=True)
print(f" - CUI: {extraction_heavy.cui}", flush=True)
# Merge with previous
extraction = self._merge_extractions(extraction, extraction_heavy)
print(f"[OCR] After merge:", flush=True)
print(f" - Amount: {extraction.amount}", flush=True)
print(f" - Date: {extraction.receipt_date}", flush=True)
print(f" - Number: {extraction.receipt_number}", flush=True)
print(f" - CUI: {extraction.cui}", flush=True)
print(f" - TVA: {extraction.tva_total}", flush=True)
print(f" - Overall Confidence: {extraction.overall_confidence:.0%}", flush=True)
if self._is_extraction_complete(extraction):
extraction.raw_text = "\n\n".join(raw_texts)
extraction.ocr_engine = "paddle-adaptive"
elapsed_ms = int((time.time() - start_time) * 1000)
extraction.processing_time_ms = elapsed_ms
print(f"[OCR] ✓✓✓ EARLY EXIT at Step 2 - All fields found after merge! ({elapsed_ms}ms) ✓✓✓", flush=True)
return True, "OCR complete (paddle dual)", extraction
else:
print("[OCR] → Step 2 incomplete, continuing to Step 3 (Tesseract)...", flush=True)
except Exception as e:
print(f"[OCR] PaddleOCR heavy failed: {e}", flush=True)
# ══════════════════════════════════════════════════════════════
# STEP 3: Tesseract - ONLY to complete missing fields
# Uses Tesseract-optimized preprocessing (binarized, high contrast)
# ══════════════════════════════════════════════════════════════
print("=" * 60, flush=True)
print("[OCR] STEP 3: Tesseract (complement only, not override)", flush=True)
print("=" * 60, flush=True)
try:
# Use Tesseract-specific preprocessing (Otsu binarization)
tesseract_img = self.preprocessor.preprocess_for_tesseract(image)
tesseract_result = self.ocr_engine._tesseract_recognize(tesseract_img)
if tesseract_result and tesseract_result.text:
extraction_tess = self.extractor.extract(tesseract_result.text)
extraction_tess.ocr_engine = "tesseract"
raw_texts.append(f"═══ Tesseract (conf: {tesseract_result.confidence:.0%}) ═══\n{tesseract_result.text}")
print(f"[OCR] Step 3 (Tesseract) Results:", flush=True)
print(f" - OCR Confidence: {tesseract_result.confidence:.0%}", flush=True)
print(f" - Amount: {extraction_tess.amount}", flush=True)
print(f" - Date: {extraction_tess.receipt_date}", flush=True)
print(f" - CUI: {extraction_tess.cui}", flush=True)
# IMPORTANT: Tesseract only COMPLETES missing fields, never overrides!
extraction = self._complement_extraction(extraction, extraction_tess)
except Exception as e:
print(f"[OCR] Tesseract failed: {e}", flush=True)
# ══════════════════════════════════════════════════════════════
# FINAL VALIDATION: Fix impossible values
# ══════════════════════════════════════════════════════════════
if extraction:
extraction = self._final_validation(extraction)
# Final result
if extraction is None:
return False, "No text detected", None
extraction.raw_text = "\n\n".join(raw_texts)
extraction.ocr_engine = "adaptive-full"
# Build result message
fields_found = []
if extraction.amount: fields_found.append("amount")
if extraction.receipt_date: fields_found.append("date")
if extraction.receipt_number: fields_found.append("number")
if extraction.cui: fields_found.append("CUI")
if extraction.tva_total or extraction.tva_entries: fields_found.append("TVA")
message = f"OCR complete (full pipeline). Found: {', '.join(fields_found) or 'no fields'}"
elapsed_ms = int((time.time() - start_time) * 1000)
extraction.processing_time_ms = elapsed_ms
print("=" * 60, flush=True)
print(f"[OCR] FINAL RESULT (full pipeline) - {elapsed_ms}ms", flush=True)
print("=" * 60, flush=True)
print(f" - Amount: {extraction.amount}", flush=True)
print(f" - Date: {extraction.receipt_date}", flush=True)
print(f" - Number: {extraction.receipt_number}", flush=True)
print(f" - CUI: {extraction.cui}", flush=True)
print(f" - TVA: {extraction.tva_total}", flush=True)
print(f" - Overall Confidence: {extraction.overall_confidence:.0%}", flush=True)
print(f" - Processing Time: {elapsed_ms}ms", flush=True)
print(f" - Message: {message}", flush=True)
return True, message, extraction
def _merge_extractions(
self,
paddle: Optional[ExtractionResult],
tesseract: Optional[ExtractionResult]
) -> ExtractionResult:
"""
Merge two extractions, picking best fields from each engine.
Strategy:
- For each field, prefer the one with higher confidence
- Use validation rules (CUI format, date validity, company indicators)
- Combine TVA entries if different
"""
result = ExtractionResult()
# Handle case where one is None
if paddle is None and tesseract is None:
return result
if paddle is None:
return tesseract
if tesseract is None:
return paddle
print("[Merge] Comparing PaddleOCR vs Tesseract extractions...", flush=True)
# === AMOUNT ===
# Pick higher confidence, both must be positive
if paddle.amount and tesseract.amount:
if paddle.confidence_amount >= tesseract.confidence_amount:
result.amount = paddle.amount
result.confidence_amount = paddle.confidence_amount
print(f"[Merge] Amount: PaddleOCR {paddle.amount} (conf: {paddle.confidence_amount:.0%})", flush=True)
else:
result.amount = tesseract.amount
result.confidence_amount = tesseract.confidence_amount
print(f"[Merge] Amount: Tesseract {tesseract.amount} (conf: {tesseract.confidence_amount:.0%})", flush=True)
elif paddle.amount:
result.amount = paddle.amount
result.confidence_amount = paddle.confidence_amount
elif tesseract.amount:
result.amount = tesseract.amount
result.confidence_amount = tesseract.confidence_amount
# === DATE ===
# Pick higher confidence, validate date reasonableness
if paddle.receipt_date and tesseract.receipt_date:
if paddle.confidence_date >= tesseract.confidence_date:
result.receipt_date = paddle.receipt_date
result.confidence_date = paddle.confidence_date
print(f"[Merge] Date: PaddleOCR {paddle.receipt_date}", flush=True)
else:
result.receipt_date = tesseract.receipt_date
result.confidence_date = tesseract.confidence_date
print(f"[Merge] Date: Tesseract {tesseract.receipt_date}", flush=True)
elif paddle.receipt_date:
result.receipt_date = paddle.receipt_date
result.confidence_date = paddle.confidence_date
elif tesseract.receipt_date:
result.receipt_date = tesseract.receipt_date
result.confidence_date = tesseract.confidence_date
# === VENDOR NAME ===
# Prefer one with company indicators (S.R.L., S.A., etc.)
paddle_has_indicator = self._has_company_indicator(paddle.partner_name)
tesseract_has_indicator = self._has_company_indicator(tesseract.partner_name)
if paddle.partner_name and tesseract.partner_name:
if paddle_has_indicator and not tesseract_has_indicator:
result.partner_name = paddle.partner_name
result.confidence_vendor = paddle.confidence_vendor
print(f"[Merge] Vendor: PaddleOCR '{paddle.partner_name}' (has company indicator)", flush=True)
elif tesseract_has_indicator and not paddle_has_indicator:
result.partner_name = tesseract.partner_name
result.confidence_vendor = tesseract.confidence_vendor
print(f"[Merge] Vendor: Tesseract '{tesseract.partner_name}' (has company indicator)", flush=True)
elif paddle.confidence_vendor >= tesseract.confidence_vendor:
result.partner_name = paddle.partner_name
result.confidence_vendor = paddle.confidence_vendor
print(f"[Merge] Vendor: PaddleOCR '{paddle.partner_name}' (higher conf)", flush=True)
else:
result.partner_name = tesseract.partner_name
result.confidence_vendor = tesseract.confidence_vendor
print(f"[Merge] Vendor: Tesseract '{tesseract.partner_name}' (higher conf)", flush=True)
elif paddle.partner_name:
result.partner_name = paddle.partner_name
result.confidence_vendor = paddle.confidence_vendor
elif tesseract.partner_name:
result.partner_name = tesseract.partner_name
result.confidence_vendor = tesseract.confidence_vendor
# === CUI (Fiscal Code) ===
# Validate format: 6-10 digits, prefer valid one
paddle_cui_valid = self._is_valid_cui(paddle.cui)
tesseract_cui_valid = self._is_valid_cui(tesseract.cui)
if paddle.cui and tesseract.cui:
if paddle_cui_valid and not tesseract_cui_valid:
result.cui = paddle.cui
print(f"[Merge] CUI: PaddleOCR {paddle.cui} (valid format)", flush=True)
elif tesseract_cui_valid and not paddle_cui_valid:
result.cui = tesseract.cui
print(f"[Merge] CUI: Tesseract {tesseract.cui} (valid format)", flush=True)
else:
# Both valid or both invalid - prefer PaddleOCR
result.cui = paddle.cui
print(f"[Merge] CUI: PaddleOCR {paddle.cui}", flush=True)
elif paddle.cui and paddle_cui_valid:
result.cui = paddle.cui
elif tesseract.cui and tesseract_cui_valid:
result.cui = tesseract.cui
elif paddle.cui:
result.cui = paddle.cui
elif tesseract.cui:
result.cui = tesseract.cui
# === TVA ENTRIES ===
# Prefer non-empty, use the one with more entries or higher amounts
if paddle.tva_entries and tesseract.tva_entries:
# Compare: prefer the one with actual amounts (not just 0)
paddle_total = sum(e.get('amount', Decimal('0')) for e in paddle.tva_entries)
tesseract_total = sum(e.get('amount', Decimal('0')) for e in tesseract.tva_entries)
if paddle_total >= tesseract_total:
result.tva_entries = paddle.tva_entries
result.tva_total = paddle.tva_total
print(f"[Merge] TVA: PaddleOCR (total: {paddle_total})", flush=True)
else:
result.tva_entries = tesseract.tva_entries
result.tva_total = tesseract.tva_total
print(f"[Merge] TVA: Tesseract (total: {tesseract_total})", flush=True)
elif paddle.tva_entries:
result.tva_entries = paddle.tva_entries
result.tva_total = paddle.tva_total
elif tesseract.tva_entries:
result.tva_entries = tesseract.tva_entries
result.tva_total = tesseract.tva_total
# === OTHER FIELDS ===
# Simple preference: paddle > tesseract
result.receipt_number = paddle.receipt_number or tesseract.receipt_number
result.receipt_series = paddle.receipt_series or tesseract.receipt_series
result.receipt_type = paddle.receipt_type or tesseract.receipt_type
result.items_count = paddle.items_count or tesseract.items_count
result.address = paddle.address or tesseract.address
result.description = paddle.description or tesseract.description
return result
def _has_company_indicator(self, name: Optional[str]) -> bool:
"""Check if vendor name has company type indicator (S.R.L., S.A., etc.)"""
if not name:
return False
name_upper = name.upper()
indicators = [
r'\bS\.?\s*R\.?\s*L\.?\b',
r'\bS\.?\s*A\.?\b',
r'\bS\.?\s*N\.?\s*C\.?\b',
r'\bP\.?\s*F\.?\s*A\.?\b',
r'\bI\.?\s*I\.?\b',
r'\bHOLDING\b',
r'\bGROUP\b',
r'\bCOMPANY\b',
]
for indicator in indicators:
if re.search(indicator, name_upper):
return True
return False
def _is_valid_cui(self, cui: Optional[str]) -> bool:
"""Validate CUI format: 6-10 digits."""
if not cui:
return False
# Remove any RO prefix
cui_clean = re.sub(r'^RO', '', cui.upper())
# Must be 6-10 digits
return bool(re.match(r'^\d{6,10}$', cui_clean))
def _is_extraction_complete(self, ext: ExtractionResult, min_confidence: float = 0.85) -> bool:
"""
Check if extraction has ALL required fields to skip further processing.
Required for early exit (ALL must be true):
- Overall confidence >= 85%
- ALL 5 critical fields present: number, date, amount, TVA, CUI
"""
# Must have high confidence
if ext.overall_confidence < min_confidence:
print(f"[OCR] Confidence {ext.overall_confidence:.0%} < {min_confidence:.0%} - continuing", flush=True)
return False
# Check all required fields
has_number = bool(ext.receipt_number)
has_date = bool(ext.receipt_date)
has_amount = bool(ext.amount)
has_tva = bool(ext.tva_total) or bool(ext.tva_entries)
has_cui = bool(ext.cui)
missing = []
if not has_number: missing.append("number")
if not has_date: missing.append("date")
if not has_amount: missing.append("amount")
if not has_tva: missing.append("TVA")
if not has_cui: missing.append("CUI")
if missing:
print(f"[OCR] Missing: {', '.join(missing)} - continuing", flush=True)
return False
print(f"[OCR] ✓ All 5 fields found with {ext.overall_confidence:.0%} confidence", flush=True)
return True
def _complement_extraction(
self,
primary: Optional[ExtractionResult],
secondary: Optional[ExtractionResult]
) -> ExtractionResult:
"""
Complement primary extraction with missing fields from secondary.
NEVER overrides existing values - only fills in gaps.
This is different from _merge_extractions which can override values.
"""
if primary is None and secondary is None:
return ExtractionResult()
if primary is None:
return secondary
if secondary is None:
return primary
print("[Complement] Adding missing fields from Tesseract...", flush=True)
# Only fill missing amount
if not primary.amount and secondary.amount:
primary.amount = secondary.amount
primary.confidence_amount = secondary.confidence_amount
print(f"[Complement] Added amount: {secondary.amount}", flush=True)
# Only fill missing date
if not primary.receipt_date and secondary.receipt_date:
primary.receipt_date = secondary.receipt_date
primary.confidence_date = secondary.confidence_date
print(f"[Complement] Added date: {secondary.receipt_date}", flush=True)
# Only fill missing vendor
if not primary.partner_name and secondary.partner_name:
primary.partner_name = secondary.partner_name
primary.confidence_vendor = secondary.confidence_vendor
print(f"[Complement] Added vendor: {secondary.partner_name}", flush=True)
# Only fill missing CUI
if not primary.cui and secondary.cui and self._is_valid_cui(secondary.cui):
primary.cui = secondary.cui
print(f"[Complement] Added CUI: {secondary.cui}", flush=True)
# Only fill missing TVA
if not primary.tva_entries and secondary.tva_entries:
primary.tva_entries = secondary.tva_entries
primary.tva_total = secondary.tva_total
print(f"[Complement] Added TVA: {secondary.tva_total}", flush=True)
# Only fill missing receipt number
if not primary.receipt_number and secondary.receipt_number:
primary.receipt_number = secondary.receipt_number
print(f"[Complement] Added number: {secondary.receipt_number}", flush=True)
# Only fill missing address
if not primary.address and secondary.address:
primary.address = secondary.address
print(f"[Complement] Added address: {secondary.address}", flush=True)
return primary
def _final_validation(self, extraction: ExtractionResult) -> ExtractionResult:
"""
Final validation and correction of impossible values.
Key rules:
1. TVA cannot be greater than TOTAL (it's always a fraction)
2. If TVA > TOTAL, recalculate TOTAL from TVA using known rates
3. Validate TVA entries sum equals TVA total
"""
print("[Final Validation] Checking extracted values...", flush=True)
# Rule 1: TVA cannot be greater than TOTAL
if extraction.tva_total and extraction.amount:
if extraction.tva_total > extraction.amount:
print(f"[Final Validation] TVA ({extraction.tva_total}) > TOTAL ({extraction.amount}) - IMPOSSIBLE!", flush=True)
# Calculate TOTAL from TVA using reverse formula:
# total = base + tva = tva * (100/rate + 1) = tva * (100 + rate) / rate
# For 9% TVA: total = tva * 109 / 9 = tva * 12.11
# For 19% TVA: total = tva * 119 / 19 = tva * 6.26
# For 21% TVA: total = tva * 121 / 21 = tva * 5.76
rate = 19 # Default rate assumption
if extraction.tva_entries:
# Use the rate from the first entry
rate = extraction.tva_entries[0].get('percent', 19)
if rate > 0:
# Formula: total = tva * (100 + rate) / rate
calculated_total = extraction.tva_total * (Decimal('100') + Decimal(str(rate))) / Decimal(str(rate))
calculated_total = calculated_total.quantize(Decimal('0.01'))
print(f"[Final Validation] Calculated TOTAL from TVA: {calculated_total} (using {rate}% rate)", flush=True)
extraction.amount = calculated_total
extraction.confidence_amount = 0.70 # Lower confidence for calculated value
# Rule 2: TVA cannot be more than ~25% of total (max Romanian rate is 21%)
if extraction.tva_total and extraction.amount:
tva_percent = extraction.tva_total / extraction.amount * Decimal('100')
if tva_percent > Decimal('25'):
print(f"[Final Validation] Warning: TVA is {tva_percent:.1f}% of total - suspicious", flush=True)
# Rule 3: Validate TVA entries sum
if extraction.tva_entries and extraction.tva_total:
entries_sum = sum(e.get('amount', Decimal('0')) for e in extraction.tva_entries)
tolerance = Decimal('0.05')
if abs(entries_sum - extraction.tva_total) > tolerance:
print(f"[Final Validation] TVA entries sum ({entries_sum}) != tva_total ({extraction.tva_total})", flush=True)
# Use the sum as it's more reliable
extraction.tva_total = entries_sum
print(f"[Final Validation] Done. Amount={extraction.amount}, TVA={extraction.tva_total}", flush=True)
return extraction
# Singleton instance
ocr_service = OCRService()

View File

@@ -0,0 +1,447 @@
"""Business logic service for receipts workflow."""
from decimal import Decimal, ROUND_HALF_UP
from typing import List, Optional, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from backend.modules.data_entry.db.models.receipt import Receipt, ReceiptStatus, ReceiptDirection
from backend.modules.data_entry.db.models.accounting_entry import EntryType
from backend.modules.data_entry.db.crud.receipt import ReceiptCRUD
from backend.modules.data_entry.db.crud.accounting_entry import AccountingEntryCRUD
from backend.modules.data_entry.schemas.receipt import (
ReceiptCreate,
ReceiptUpdate,
ReceiptFilter,
ReceiptResponse,
ReceiptListResponse,
AccountingEntryCreate,
)
from backend.modules.data_entry.services.expense_types import EXPENSE_TYPES, get_expense_type
# Payment mode to accounting account mapping
PAYMENT_MODE_ACCOUNTS = {
'casa': ('5311', 'Casa in lei'),
'banca': ('5121', 'Conturi la banci in lei'),
'avans_decontare': ('542', 'Avansuri de trezorerie'),
}
class ReceiptService:
"""Service for receipt business logic and workflow."""
@staticmethod
async def create_receipt(
session: AsyncSession,
data: ReceiptCreate,
created_by: str,
) -> Receipt:
"""Create a new receipt in DRAFT status."""
return await ReceiptCRUD.create(session, data, created_by)
@staticmethod
async def get_receipt(
session: AsyncSession,
receipt_id: int,
) -> Optional[Receipt]:
"""Get receipt by ID with all relationships."""
return await ReceiptCRUD.get_by_id(session, receipt_id, include_relations=True)
@staticmethod
async def get_receipts(
session: AsyncSession,
filters: ReceiptFilter,
) -> ReceiptListResponse:
"""Get paginated list of receipts."""
receipts, total = await ReceiptCRUD.get_list(session, filters)
pages = (total + filters.page_size - 1) // filters.page_size if total > 0 else 1
return ReceiptListResponse(
items=[ReceiptResponse.model_validate(r) for r in receipts],
total=total,
page=filters.page,
page_size=filters.page_size,
pages=pages,
)
@staticmethod
async def update_receipt(
session: AsyncSession,
receipt_id: int,
data: ReceiptUpdate,
username: str,
) -> Tuple[bool, str, Optional[Receipt]]:
"""
Update receipt (only DRAFT status).
Returns (success, message, receipt).
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found", None
if not await ReceiptCRUD.can_edit(receipt, username):
return False, "Cannot edit this receipt", None
updated = await ReceiptCRUD.update(session, receipt, data)
return True, "Receipt updated", updated
@staticmethod
async def delete_receipt(
session: AsyncSession,
receipt_id: int,
username: str,
) -> Tuple[bool, str]:
"""
Delete receipt (only DRAFT status).
Returns (success, message).
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found"
if not await ReceiptCRUD.can_delete(receipt, username):
return False, "Cannot delete this receipt"
await ReceiptCRUD.delete(session, receipt)
return True, "Receipt deleted"
@staticmethod
def generate_accounting_entries(receipt: Receipt) -> List[AccountingEntryCreate]:
"""
Generate accounting entries based on receipt data and expense type.
"""
entries: List[AccountingEntryCreate] = []
# Get expense type configuration
expense_type = get_expense_type(receipt.expense_type_code or "OTHER")
if not expense_type:
expense_type = EXPENSE_TYPES["OTHER"]
amount = Decimal(str(receipt.amount))
if receipt.direction == ReceiptDirection.CHELTUIALA:
# Expense: Debit expense account, Credit cash/bank
if expense_type.has_vat:
# Calculate net and VAT
vat_rate = expense_type.vat_percent / Decimal("100")
net_amount = (amount / (1 + vat_rate)).quantize(
Decimal("0.01"), rounding=ROUND_HALF_UP
)
vat_amount = amount - net_amount
# Debit: Expense account (net)
entries.append(AccountingEntryCreate(
entry_type=EntryType.DEBIT,
account_code=expense_type.account_code,
account_name=expense_type.account_name,
amount=net_amount,
))
# Debit: VAT deductible
entries.append(AccountingEntryCreate(
entry_type=EntryType.DEBIT,
account_code=expense_type.vat_account,
account_name="TVA deductibila",
amount=vat_amount,
))
else:
# No VAT - full amount to expense
entries.append(AccountingEntryCreate(
entry_type=EntryType.DEBIT,
account_code=expense_type.account_code,
account_name=expense_type.account_name,
amount=amount,
))
# Credit entry - based on payment_mode (new) or cash_register (legacy)
if receipt.payment_mode and receipt.payment_mode in PAYMENT_MODE_ACCOUNTS:
credit_account, credit_name = PAYMENT_MODE_ACCOUNTS[receipt.payment_mode]
elif receipt.cash_register_account:
# Backwards compatibility for existing receipts
credit_account = receipt.cash_register_account
credit_name = receipt.cash_register_name or "Casa/Banca"
else:
# Default fallback
credit_account = "5311"
credit_name = "Casa in lei"
entries.append(AccountingEntryCreate(
entry_type=EntryType.CREDIT,
account_code=credit_account,
account_name=credit_name,
amount=amount,
))
else:
# Income: Debit cash/bank, Credit income account
# Based on payment_mode (new) or cash_register (legacy)
if receipt.payment_mode and receipt.payment_mode in PAYMENT_MODE_ACCOUNTS:
cash_account, cash_name = PAYMENT_MODE_ACCOUNTS[receipt.payment_mode]
elif receipt.cash_register_account:
cash_account = receipt.cash_register_account
cash_name = receipt.cash_register_name or "Casa/Banca"
else:
cash_account = "5311"
cash_name = "Casa in lei"
# Debit: Cash/Bank
entries.append(AccountingEntryCreate(
entry_type=EntryType.DEBIT,
account_code=cash_account,
account_name=cash_name,
amount=amount,
))
# Credit: Income account (7xx - to be configured)
entries.append(AccountingEntryCreate(
entry_type=EntryType.CREDIT,
account_code="7588",
account_name="Alte venituri din exploatare",
amount=amount,
))
return entries
@staticmethod
async def submit_for_review(
session: AsyncSession,
receipt_id: int,
username: str,
) -> Tuple[bool, str, Optional[Receipt]]:
"""
Submit receipt for review (DRAFT/REJECTED → PENDING_REVIEW).
Generates accounting entries automatically.
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found", None
if not await ReceiptCRUD.can_submit(receipt, username):
return False, "Cannot submit this receipt", None
# Check if receipt has at least one attachment
if not receipt.attachments:
return False, "Receipt must have at least one attachment", None
# Check required fields
if not receipt.expense_type_code:
return False, "Expense type is required", None
# Validate payment_mode or cash_register (backwards compatibility)
if not receipt.payment_mode and not receipt.cash_register_account:
return False, "Modul de plata este obligatoriu", None
# Generate accounting entries
entries = ReceiptService.generate_accounting_entries(receipt)
# Delete existing entries and create new ones
await AccountingEntryCRUD.delete_all_for_receipt(session, receipt_id)
await AccountingEntryCRUD.create_bulk(session, receipt_id, entries, is_auto_generated=True)
# Refresh receipt to clear stale relationship references after entry deletion
await session.refresh(receipt)
# Update status
updated = await ReceiptCRUD.update_status(
session, receipt, ReceiptStatus.PENDING_REVIEW
)
# Reload with entries
updated = await ReceiptCRUD.get_by_id(session, receipt_id)
return True, "Receipt submitted for review", updated
@staticmethod
async def approve_receipt(
session: AsyncSession,
receipt_id: int,
username: str,
) -> Tuple[bool, str, Optional[Receipt]]:
"""
Approve receipt (PENDING_REVIEW → APPROVED).
Requires valid CUI (fiscal code) for approval.
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found", None
if receipt.status != ReceiptStatus.PENDING_REVIEW:
return False, "Receipt is not pending review", None
# Validate CUI is present (required for Oracle import)
if not receipt.cui:
return False, "Trebuie completat codul fiscal (CUI) pentru aprobare", None
# Validate accounting entries
if not receipt.entries:
return False, "Receipt has no accounting entries", None
# Update status
updated = await ReceiptCRUD.update_status(
session, receipt, ReceiptStatus.APPROVED, reviewed_by=username
)
return True, "Receipt approved", updated
@staticmethod
async def unapprove_receipt(
session: AsyncSession,
receipt_id: int,
username: str,
) -> Tuple[bool, str, Optional[Receipt]]:
"""
Unapprove receipt (APPROVED → PENDING_REVIEW).
Returns receipt to pending review for corrections.
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found", None
if receipt.status != ReceiptStatus.APPROVED:
return False, "Receipt is not approved", None
# Update status back to pending review
updated = await ReceiptCRUD.update_status(
session, receipt, ReceiptStatus.PENDING_REVIEW
)
return True, "Receipt returned to pending review", updated
@staticmethod
async def reject_receipt(
session: AsyncSession,
receipt_id: int,
username: str,
reason: str,
) -> Tuple[bool, str, Optional[Receipt]]:
"""
Reject receipt (PENDING_REVIEW → REJECTED).
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found", None
if receipt.status != ReceiptStatus.PENDING_REVIEW:
return False, "Receipt is not pending review", None
# Update status
updated = await ReceiptCRUD.update_status(
session,
receipt,
ReceiptStatus.REJECTED,
reviewed_by=username,
rejection_reason=reason,
)
return True, "Receipt rejected", updated
@staticmethod
async def resubmit_receipt(
session: AsyncSession,
receipt_id: int,
username: str,
) -> Tuple[bool, str, Optional[Receipt]]:
"""
Resubmit rejected receipt after corrections (REJECTED → PENDING_REVIEW).
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found", None
if receipt.status != ReceiptStatus.REJECTED:
return False, "Receipt is not rejected", None
if receipt.created_by != username:
return False, "Only the creator can resubmit", None
# Re-generate accounting entries
entries = ReceiptService.generate_accounting_entries(receipt)
await AccountingEntryCRUD.delete_all_for_receipt(session, receipt_id)
await AccountingEntryCRUD.create_bulk(session, receipt_id, entries, is_auto_generated=True)
# Refresh receipt to clear stale relationship references after entry deletion
await session.refresh(receipt)
# Update status
updated = await ReceiptCRUD.update_status(
session, receipt, ReceiptStatus.PENDING_REVIEW
)
# Reload with entries
updated = await ReceiptCRUD.get_by_id(session, receipt_id)
return True, "Receipt resubmitted for review", updated
@staticmethod
async def regenerate_entries(
session: AsyncSession,
receipt_id: int,
username: str,
) -> Tuple[bool, str, List[AccountingEntryCreate]]:
"""
Regenerate accounting entries for a receipt.
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found", []
if receipt.status not in [ReceiptStatus.DRAFT, ReceiptStatus.PENDING_REVIEW]:
return False, "Cannot regenerate entries for this receipt status", []
# Generate new entries
entries = ReceiptService.generate_accounting_entries(receipt)
# Replace existing entries
await AccountingEntryCRUD.delete_all_for_receipt(session, receipt_id)
await AccountingEntryCRUD.create_bulk(session, receipt_id, entries, is_auto_generated=True)
return True, "Entries regenerated", entries
@staticmethod
async def update_entries(
session: AsyncSession,
receipt_id: int,
entries: List[AccountingEntryCreate],
username: str,
) -> Tuple[bool, str, List]:
"""
Update accounting entries for a receipt (accountant action).
"""
receipt = await ReceiptCRUD.get_by_id(session, receipt_id)
if not receipt:
return False, "Receipt not found", []
if receipt.status != ReceiptStatus.PENDING_REVIEW:
return False, "Can only modify entries for receipts pending review", []
# Validate entries
is_valid, error = await AccountingEntryCRUD.validate_entries(entries)
if not is_valid:
return False, error, []
# Replace entries
updated_entries = await AccountingEntryCRUD.replace_all_for_receipt(
session, receipt_id, entries, username
)
return True, "Entries updated", updated_entries
@staticmethod
async def get_pending_count(
session: AsyncSession,
company_id: Optional[int] = None,
) -> int:
"""Get count of receipts pending review."""
receipts = await ReceiptCRUD.get_pending_review(session, company_id)
return len(receipts)

View File

@@ -0,0 +1,406 @@
"""Service for syncing nomenclatures from Oracle to SQLite."""
import sys
from pathlib import Path
from typing import Optional, List, Tuple
from datetime import datetime
import logging
from sqlmodel import select
from sqlalchemy.ext.asyncio import AsyncSession
# Path setup handled by main.py - this is redundant
# project_root = Path(__file__).parent.parent.parent.parent.parent
# sys.path.insert(0, str(project_root / "shared"))
from shared.database.oracle_pool import oracle_pool
from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier, SyncedCashRegister
logger = logging.getLogger(__name__)
# Cache for schema lookups (populated dynamically from Oracle)
_schema_cache: dict[int, str] = {}
class SyncService:
"""Service for syncing nomenclatures from Oracle."""
@staticmethod
async def get_schema_for_company(company_id: int) -> Optional[str]:
"""
Get Oracle schema for company ID from V_NOM_FIRME view.
Results are cached in memory for performance.
"""
# Check cache first
if company_id in _schema_cache:
return _schema_cache[company_id]
try:
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute("""
SELECT SCHEMA
FROM CONTAFIN_ORACLE.V_NOM_FIRME
WHERE ID_FIRMA = :company_id
""", {'company_id': company_id})
result = cursor.fetchone()
if result:
schema = result[0]
_schema_cache[company_id] = schema
logger.info(f"Resolved schema for company {company_id}: {schema}")
return schema
else:
logger.warning(f"No schema found for company {company_id}")
return None
except Exception as e:
logger.error(f"Error fetching schema for company {company_id}: {e}")
return None
@staticmethod
async def sync_suppliers(session: AsyncSession, company_id: int) -> Tuple[int, int]:
"""
Sync suppliers (furnizori, id_tip_part=17) from Oracle to SQLite.
Uses CORESP_TIP_PART joined with VNOM_PARTENERI view.
Returns (synced_count, error_count).
"""
schema = await SyncService.get_schema_for_company(company_id)
if not schema:
logger.warning(f"No schema mapping for company {company_id}")
return 0, 0
synced = 0
errors = 0
try:
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
# Fetch active suppliers from Oracle
# id_tip_part = 17 means "furnizori" (suppliers)
# Using CORESP_TIP_PART to filter by partner type
cursor.execute(f"""
SELECT B.ID_PART, B.DENUMIRE, B.COD_FISCAL, B.ADRESA
FROM {schema}.CORESP_TIP_PART A
INNER JOIN {schema}.VNOM_PARTENERI B ON A.ID_PART = B.ID_PART
WHERE A.ID_TIP_PART = 17
AND (B.INACTIV = 0 OR B.INACTIV IS NULL)
AND B.ID_PART IS NOT NULL
ORDER BY B.DENUMIRE
""")
rows = cursor.fetchall()
for row in rows:
try:
oracle_id, name, fiscal_code, address = row
# Check if already exists
stmt = select(SyncedSupplier).where(
SyncedSupplier.oracle_id == oracle_id,
SyncedSupplier.company_id == company_id
)
result = await session.execute(stmt)
existing = result.scalar_one_or_none()
if existing:
# Update existing record
existing.name = name or ""
existing.fiscal_code = fiscal_code
existing.address = address
existing.synced_at = datetime.utcnow()
logger.debug(f"Updated supplier {oracle_id}: {name}")
else:
# Create new record
supplier = SyncedSupplier(
oracle_id=oracle_id,
company_id=company_id,
name=name or "",
fiscal_code=fiscal_code,
address=address,
)
session.add(supplier)
logger.debug(f"Created supplier {oracle_id}: {name}")
synced += 1
except Exception as e:
logger.error(f"Error processing supplier row {row}: {e}")
errors += 1
# Commit all changes
await session.commit()
logger.info(f"Synced {synced} suppliers for company {company_id}, {errors} errors")
except Exception as e:
logger.error(f"Error syncing suppliers for company {company_id}: {e}")
errors += 1
await session.rollback()
return synced, errors
@staticmethod
async def sync_cash_registers(session: AsyncSession, company_id: int) -> Tuple[int, int]:
"""
Sync cash registers and bank accounts from Oracle to SQLite.
Returns (synced_count, error_count).
Uses CORESP_TIP_PART with:
- id_tip_part = 22: CASA LEI
- id_tip_part = 23: CASA VALUTA
- id_tip_part = 24: BANCA LEI
- id_tip_part = 25: BANCA VALUTA
"""
schema = await SyncService.get_schema_for_company(company_id)
if not schema:
logger.warning(f"No schema mapping for company {company_id}")
return 0, 0
synced = 0
errors = 0
# Partner types mapping
# 22=CASA LEI, 23=CASA VALUTA -> cash
# 24=BANCA LEI, 25=BANCA VALUTA -> bank
partner_types = [22, 23, 24, 25]
try:
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
# Fetch cash/bank partners from CORESP_TIP_PART
cursor.execute(f"""
SELECT B.ID_PART, B.DENUMIRE, A.ID_TIP_PART
FROM {schema}.CORESP_TIP_PART A
INNER JOIN {schema}.VNOM_PARTENERI B ON A.ID_PART = B.ID_PART
WHERE A.ID_TIP_PART IN (22, 23, 24, 25)
AND (B.INACTIV = 0 OR B.INACTIV IS NULL)
AND B.ID_PART IS NOT NULL
ORDER BY A.ID_TIP_PART, B.DENUMIRE
""")
rows = cursor.fetchall()
# Type mapping: 22=CASA LEI, 23=CASA VALUTA -> cash; 24=BANCA LEI, 25=BANCA VALUTA -> bank
type_mapping = {
22: ("cash", "CASA_LEI"),
23: ("cash", "CASA_VALUTA"),
24: ("bank", "BANCA_LEI"),
25: ("bank", "BANCA_VALUTA"),
}
for row in rows:
try:
oracle_id, name, tip_part_id = row
# Determine type based on partner type
register_type, account_code = type_mapping.get(tip_part_id, ("cash", "UNKNOWN"))
# Check if already exists
stmt = select(SyncedCashRegister).where(
SyncedCashRegister.oracle_id == oracle_id,
SyncedCashRegister.company_id == company_id
)
result = await session.execute(stmt)
existing = result.scalar_one_or_none()
if existing:
# Update existing record
existing.name = name or ""
existing.account_code = account_code
existing.register_type = register_type
existing.synced_at = datetime.utcnow()
logger.debug(f"Updated cash register {oracle_id}: {name}")
else:
# Create new record
cash_register = SyncedCashRegister(
oracle_id=oracle_id,
company_id=company_id,
name=name or "",
account_code=account_code,
register_type=register_type,
)
session.add(cash_register)
logger.debug(f"Created cash register {oracle_id}: {name}")
synced += 1
except Exception as e:
logger.error(f"Error processing cash register row {row}: {e}")
errors += 1
# Commit all changes
await session.commit()
logger.info(f"Synced {synced} cash registers for company {company_id}, {errors} errors")
except Exception as e:
logger.error(f"Error syncing cash registers for company {company_id}: {e}")
errors += 1
await session.rollback()
return synced, errors
@staticmethod
async def search_supplier(
session: AsyncSession,
company_id: int,
fiscal_code: Optional[str] = None,
name: Optional[str] = None
) -> Tuple[bool, Optional[dict], str]:
"""
Search for supplier in SQLite first, then Oracle if not found.
Returns (found, supplier_data, source).
Source can be: 'synced', 'local', 'not_found'
"""
# 1. Search in synced suppliers
if fiscal_code:
stmt = select(SyncedSupplier).where(
SyncedSupplier.company_id == company_id,
SyncedSupplier.fiscal_code == fiscal_code
)
elif name:
stmt = select(SyncedSupplier).where(
SyncedSupplier.company_id == company_id,
SyncedSupplier.name.ilike(f"%{name}%")
)
else:
return False, None, "no_query"
result = await session.execute(stmt)
supplier = result.scalar_one_or_none()
if supplier:
# Return only text data - no IDs needed for autocomplete
return True, {
"name": supplier.name,
"fiscal_code": supplier.fiscal_code,
"address": supplier.address,
}, "synced"
# 2. Search in local suppliers
if fiscal_code:
stmt = select(LocalSupplier).where(
LocalSupplier.company_id == company_id,
LocalSupplier.fiscal_code == fiscal_code
)
elif name:
stmt = select(LocalSupplier).where(
LocalSupplier.company_id == company_id,
LocalSupplier.name.ilike(f"%{name}%")
)
result = await session.execute(stmt)
local = result.scalar_one_or_none()
if local:
# Return only text data - no IDs needed for autocomplete
return True, {
"name": local.name,
"fiscal_code": local.fiscal_code,
"address": local.address,
}, "local"
# 3. Try live Oracle search (optional fallback for unsynced data)
# This is a fallback - ideally sync should be up to date
# TODO: Implement live Oracle search if needed
return False, None, "not_found"
@staticmethod
async def create_local_supplier(
session: AsyncSession,
company_id: int,
name: str,
fiscal_code: Optional[str],
address: Optional[str],
created_by: str
) -> LocalSupplier:
"""Create a local supplier entry from OCR data."""
supplier = LocalSupplier(
company_id=company_id,
name=name,
fiscal_code=fiscal_code,
address=address,
created_by=created_by,
)
session.add(supplier)
await session.commit()
await session.refresh(supplier)
logger.info(f"Created local supplier: {name} (CUI: {fiscal_code})")
return supplier
@staticmethod
async def get_all_suppliers(
session: AsyncSession,
company_id: int,
search: Optional[str] = None
) -> List[dict]:
"""
Get all suppliers (synced + local) for a company.
Used for dropdown/autocomplete in UI.
"""
suppliers = []
# Get synced suppliers
stmt = select(SyncedSupplier).where(SyncedSupplier.company_id == company_id)
if search:
stmt = stmt.where(
(SyncedSupplier.name.ilike(f"%{search}%")) |
(SyncedSupplier.fiscal_code.ilike(f"%{search}%"))
)
stmt = stmt.limit(50) # Limit results for performance
result = await session.execute(stmt)
synced = result.scalars().all()
for s in synced:
suppliers.append({
"id": s.id,
"oracle_id": s.oracle_id,
"name": s.name,
"fiscal_code": s.fiscal_code,
"source": "synced"
})
# Get local suppliers
stmt = select(LocalSupplier).where(LocalSupplier.company_id == company_id)
if search:
stmt = stmt.where(
(LocalSupplier.name.ilike(f"%{search}%")) |
(LocalSupplier.fiscal_code.ilike(f"%{search}%"))
)
stmt = stmt.limit(50)
result = await session.execute(stmt)
local = result.scalars().all()
for l in local:
suppliers.append({
"id": l.id,
"name": l.name,
"fiscal_code": l.fiscal_code,
"source": "local"
})
return suppliers
@staticmethod
async def get_all_cash_registers(
session: AsyncSession,
company_id: int
) -> List[dict]:
"""
Get all cash registers for a company.
Used for dropdown in UI.
"""
stmt = select(SyncedCashRegister).where(SyncedCashRegister.company_id == company_id)
result = await session.execute(stmt)
registers = result.scalars().all()
return [
{
"id": r.id,
"oracle_id": r.oracle_id,
"name": r.name,
"account_code": r.account_code,
"register_type": r.register_type
}
for r in registers
]

View File

View File

@@ -0,0 +1,66 @@
"""
Cache module for ROA2WEB
Provides hybrid two-tier caching (Memory L1 + SQLite L2)
with performance tracking and event-based invalidation.
Usage:
# Initialize cache at app startup
from app.cache import init_cache
from app.cache.config import CacheConfig
config = CacheConfig.from_env()
await init_cache(config)
# Use @cached decorator in services
from app.cache.decorators import cached
@cached(cache_type='dashboard_summary', key_params=['company', 'username'])
async def get_complete_summary(company: str, username: str):
# ... Oracle query logic ...
# Get cache manager for manual operations
from app.cache import get_cache
cache = get_cache()
await cache.invalidate(company_id=123)
"""
from .config import CacheConfig
from .cache_manager import (
init_cache,
get_cache,
close_cache,
CacheManager
)
from .decorators import cached
from .event_monitor import (
init_event_monitor,
get_event_monitor,
toggle_event_monitor,
preload_all_schema_mappings
)
from .benchmarks import run_baseline_benchmarks
__all__ = [
# Configuration
'CacheConfig',
# Cache Manager
'init_cache',
'get_cache',
'close_cache',
'CacheManager',
# Decorators
'cached',
# Event Monitor
'init_event_monitor',
'get_event_monitor',
'toggle_event_monitor',
'preload_all_schema_mappings',
# Benchmarks
'run_baseline_benchmarks',
]

View File

@@ -0,0 +1,269 @@
"""
Baseline performance benchmarking
Runs at startup to establish baseline Oracle query times
Used for calculating "time saved" by cache
"""
import time
import logging
from typing import Dict
logger = logging.getLogger(__name__)
async def run_baseline_benchmarks() -> Dict[str, float]:
"""
Run baseline benchmarks for Oracle queries (without cache)
Measures typical query times to establish performance baselines
These are used to calculate time saved when cache hits occur
NOTE: This implementation provides a framework. Actual benchmark
implementations need access to Oracle services and sample data.
Returns:
Dictionary mapping cache_type to average query time (ms)
"""
from .cache_manager import get_cache
cache = get_cache()
if not cache:
logger.warning("Cache not initialized - skipping benchmarks")
return {}
logger.info("Starting baseline performance benchmarks...")
benchmarks = {}
try:
# Benchmark: Schema lookup
logger.info("Benchmarking: schema lookup")
schema_times = await _benchmark_schema_lookup()
if schema_times:
avg_schema = sum(schema_times) / len(schema_times)
benchmarks['schema'] = avg_schema
await cache.sqlite.set_benchmark('schema', avg_schema, len(schema_times))
logger.info(f" Schema lookup: {avg_schema:.2f}ms (avg of {len(schema_times)} samples)")
# Benchmark: Companies list
logger.info("Benchmarking: companies list")
companies_time = await _benchmark_companies_list()
if companies_time:
benchmarks['companies'] = companies_time
await cache.sqlite.set_benchmark('companies', companies_time, 1)
logger.info(f" Companies list: {companies_time:.2f}ms")
# Benchmark: Dashboard summary
logger.info("Benchmarking: dashboard summary")
dashboard_time = await _benchmark_dashboard_summary()
if dashboard_time:
benchmarks['dashboard_summary'] = dashboard_time
await cache.sqlite.set_benchmark('dashboard_summary', dashboard_time, 1)
logger.info(f" Dashboard summary: {dashboard_time:.2f}ms")
# Benchmark: Dashboard trends
logger.info("Benchmarking: dashboard trends")
trends_time = await _benchmark_dashboard_trends()
if trends_time:
benchmarks['dashboard_trends'] = trends_time
await cache.sqlite.set_benchmark('dashboard_trends', trends_time, 1)
logger.info(f" Dashboard trends: {trends_time:.2f}ms")
# Benchmark: Invoices
logger.info("Benchmarking: invoices")
invoices_time = await _benchmark_invoices()
if invoices_time:
benchmarks['invoices'] = invoices_time
await cache.sqlite.set_benchmark('invoices', invoices_time, 1)
logger.info(f" Invoices: {invoices_time:.2f}ms")
# Benchmark: Treasury
logger.info("Benchmarking: treasury")
treasury_time = await _benchmark_treasury()
if treasury_time:
benchmarks['treasury'] = treasury_time
await cache.sqlite.set_benchmark('treasury', treasury_time, 1)
logger.info(f" Treasury: {treasury_time:.2f}ms")
logger.info(f"Baseline benchmarks completed: {len(benchmarks)} types measured")
return benchmarks
except Exception as e:
logger.error(f"Benchmark error: {e}", exc_info=True)
return benchmarks
async def _benchmark_schema_lookup() -> list:
"""
Benchmark schema lookup queries
Returns:
List of query times (ms) for multiple samples
"""
try:
# Import here to avoid circular dependency
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..')))
from shared.database.oracle_pool import oracle_pool
# Get sample company IDs to test
sample_companies = await _get_sample_company_ids(limit=10)
if not sample_companies:
logger.warning("No sample companies found for schema benchmark")
return []
times = []
for company_id in sample_companies:
start = time.time()
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute("""
SELECT schema
FROM CONTAFIN_ORACLE.v_nom_firme
WHERE id_firma = :id
""", {'id': company_id})
cursor.fetchone()
elapsed_ms = (time.time() - start) * 1000
times.append(elapsed_ms)
return times
except Exception as e:
logger.error(f"Schema benchmark error: {e}")
return []
async def _benchmark_companies_list() -> float:
"""
Benchmark companies list query
Returns:
Query time (ms)
"""
try:
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..')))
from shared.database.oracle_pool import oracle_pool
# Get sample username
sample_user = await _get_sample_username()
if not sample_user:
return 0
start = time.time()
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute("""
SELECT nf.id_firma, nf.denumire, nf.cui, nf.schema
FROM CONTAFIN_ORACLE.v_nom_firme nf
JOIN CONTAFIN_ORACLE.vdef_util_firme uf ON nf.id_firma = uf.id_firma
WHERE uf.nume_utilizator = :username
ORDER BY nf.denumire
""", {'username': sample_user})
cursor.fetchall()
elapsed_ms = (time.time() - start) * 1000
return elapsed_ms
except Exception as e:
logger.error(f"Companies benchmark error: {e}")
return 0
async def _benchmark_dashboard_summary() -> float:
"""
Benchmark dashboard summary query
Returns:
Query time (ms)
"""
try:
# This requires access to DashboardService
# For now, return estimated value
logger.warning("Dashboard summary benchmark not implemented - using estimate")
return 250.0 # Estimated 250ms based on plan
except Exception as e:
logger.error(f"Dashboard benchmark error: {e}")
return 0
async def _benchmark_dashboard_trends() -> float:
"""Benchmark dashboard trends query"""
try:
logger.warning("Dashboard trends benchmark not implemented - using estimate")
return 400.0 # Estimated 400ms
except Exception as e:
logger.error(f"Trends benchmark error: {e}")
return 0
async def _benchmark_invoices() -> float:
"""Benchmark invoices query"""
try:
logger.warning("Invoices benchmark not implemented - using estimate")
return 180.0 # Estimated 180ms
except Exception as e:
logger.error(f"Invoices benchmark error: {e}")
return 0
async def _benchmark_treasury() -> float:
"""Benchmark treasury query"""
try:
logger.warning("Treasury benchmark not implemented - using estimate")
return 250.0 # Estimated 250ms
except Exception as e:
logger.error(f"Treasury benchmark error: {e}")
return 0
# Helper functions
async def _get_sample_company_ids(limit: int = 10) -> list:
"""Get sample company IDs for testing"""
try:
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..')))
from shared.database.oracle_pool import oracle_pool
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute(f"""
SELECT id_firma
FROM CONTAFIN_ORACLE.v_nom_firme
WHERE ROWNUM <= {limit}
""")
results = cursor.fetchall()
return [row[0] for row in results]
except Exception as e:
logger.error(f"Get sample companies error: {e}")
return []
async def _get_sample_username() -> str:
"""Get sample username for testing"""
try:
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..')))
from shared.database.oracle_pool import oracle_pool
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute("""
SELECT nume_utilizator
FROM CONTAFIN_ORACLE.vdef_util_firme
WHERE ROWNUM <= 1
""")
result = cursor.fetchone()
return result[0] if result else "admin"
except Exception as e:
logger.error(f"Get sample username error: {e}")
return "admin"

View File

@@ -0,0 +1,335 @@
"""
Cache Manager - Orchestrator for hybrid L1 + L2 cache
"""
import logging
import asyncio
from typing import Any, Optional
from .config import CacheConfig
from .memory_cache import MemoryCache
from .sqlite_cache import SQLiteCache
logger = logging.getLogger(__name__)
class CacheManager:
"""
Hybrid cache manager (Memory L1 + SQLite L2)
Features:
- Two-tier caching: fast memory + persistent SQLite
- Automatic TTL management per cache type
- Performance tracking and benchmarking
- Per-user cache enable/disable
- Global cache toggle
"""
def __init__(self, config: CacheConfig):
"""
Initialize cache manager
Args:
config: Cache configuration
"""
self.config = config
self.memory = MemoryCache(max_size=config.memory_max_size)
self.sqlite = SQLiteCache(db_path=config.sqlite_path)
self._cleanup_task: Optional[asyncio.Task] = None
self._initialized = False
self._last_cache_source: Optional[str] = None # Track last cache source (L1/L2)
async def init(self):
"""Initialize cache system"""
if self._initialized:
logger.warning("Cache already initialized")
return
# Initialize SQLite database schema
await self.sqlite.init_db()
# Start cleanup task
if self.config.enabled:
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
self._initialized = True
logger.info(f"Cache initialized: type={self.config.cache_type}, enabled={self.config.enabled}")
async def close(self):
"""Close cache and cleanup"""
if self._cleanup_task:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
logger.info("Cache closed")
async def get(self, key: str, cache_type: str) -> Optional[Any]:
"""
Get value from cache (L1 → L2)
Args:
key: Cache key
cache_type: Type of cache entry
Returns:
Cached value or None if not found
"""
if not self.config.enabled:
self._last_cache_source = None
return None
# Try L1 (Memory) first
value = await self.memory.get(key)
if value is not None:
self._last_cache_source = "L1"
logger.debug(f"Cache HIT L1 (memory): {key}")
return value
# Try L2 (SQLite)
value = await self.sqlite.get(key)
if value is not None:
self._last_cache_source = "L2"
logger.debug(f"Cache HIT L2 (sqlite): {key}")
# Populate L1 for next time
ttl = self.config.get_ttl_for_type(cache_type)
await self.memory.set(key, value, ttl)
return value
# Cache MISS
self._last_cache_source = None
logger.debug(f"Cache MISS: {key}")
return None
def get_last_cache_source(self) -> Optional[str]:
"""
Get source of last cache hit (L1/L2/None)
Returns:
"L1" if last hit was from memory cache
"L2" if last hit was from SQLite cache
None if last call was a cache miss or cache disabled
"""
return self._last_cache_source
async def set(self, key: str, value: Any, cache_type: str, company_id: Optional[int] = None,
ttl: Optional[int] = None):
"""
Set value in cache (both L1 and L2)
Args:
key: Cache key
value: Value to cache
cache_type: Type of cache entry
company_id: Company ID (for company-specific caches)
ttl: Time to live (uses default for cache_type if not provided)
"""
if not self.config.enabled:
return
if ttl is None:
ttl = self.config.get_ttl_for_type(cache_type)
# Store in both L1 and L2
await self.memory.set(key, value, ttl)
await self.sqlite.set(key, value, cache_type, company_id, ttl)
logger.debug(f"Cache SET (L1 + L2): {key} (TTL: {ttl}s)")
async def delete(self, key: str):
"""Delete entry from both L1 and L2"""
await self.memory.delete(key)
await self.sqlite.delete(key)
logger.debug(f"Cache deleted: {key}")
async def invalidate(self, company_id: Optional[int] = None, cache_type: Optional[str] = None):
"""
Invalidate cache entries
Args:
company_id: If provided, clear only this company's cache
cache_type: If provided, clear only this cache type
"""
if company_id is not None and cache_type is not None:
# Clear specific company + type
from .keys import generate_key_pattern
pattern = generate_key_pattern(cache_type, company_id)
await self.memory.clear_by_pattern(pattern)
# SQLite: clear by company + type (needs query)
# For now, just clear by company
await self.sqlite.clear_by_company(company_id)
logger.info(f"Cache invalidated: company={company_id}, type={cache_type}")
elif company_id is not None:
# Clear all for company
from .keys import generate_key_pattern
# Clear all types for this company (pattern match all)
# Memory: need to iterate and match company_id in key
# For simplicity, clear by pattern prefix
await self.memory.clear() # TODO: improve pattern matching
await self.sqlite.clear_by_company(company_id)
logger.info(f"Cache invalidated: company={company_id}")
elif cache_type is not None:
# Clear all for type
from .keys import generate_key_pattern
pattern = generate_key_pattern(cache_type)
await self.memory.clear_by_pattern(pattern)
await self.sqlite.clear_by_type(cache_type)
logger.info(f"Cache invalidated: type={cache_type}")
else:
# Clear everything
await self.memory.clear()
await self.sqlite.clear()
logger.info("Cache invalidated: ALL")
async def is_enabled_for_user(self, username: Optional[str]) -> bool:
"""
Check if cache is enabled for specific user
Args:
username: Username to check
Returns:
True if cache enabled for user, False otherwise
"""
if not self.config.enabled:
return False
if username is None:
return True
# Check per-user setting
return await self.sqlite.get_user_cache_enabled(username)
async def set_user_cache_enabled(self, username: str, enabled: bool):
"""Set user cache enabled/disabled"""
await self.sqlite.set_user_cache_enabled(username, enabled)
logger.info(f"User cache setting: {username} -> {enabled}")
# Benchmarking
async def get_benchmark(self, cache_type: str) -> Optional[float]:
"""Get average benchmark time for cache type"""
return await self.sqlite.get_benchmark(cache_type)
async def update_benchmark(self, cache_type: str, new_time_ms: float):
"""
Update benchmark with new measurement (exponential moving average)
Args:
cache_type: Type of cache
new_time_ms: New measured time in milliseconds
"""
current_avg = await self.sqlite.get_benchmark(cache_type)
if current_avg is None:
# First measurement
new_avg = new_time_ms
sample_count = 1
else:
# Exponential moving average (alpha = 0.1)
new_avg = 0.9 * current_avg + 0.1 * new_time_ms
# Get current sample count (TODO: retrieve from DB)
sample_count = 1 # Simplified for now
await self.sqlite.set_benchmark(cache_type, new_avg, sample_count)
logger.debug(f"Benchmark updated: {cache_type} -> {new_avg:.2f}ms")
# Performance Tracking
async def track_performance(self, cache_type: str, is_hit: bool, actual_time_ms: float,
time_saved_ms: Optional[float] = None,
estimated_oracle_time_ms: Optional[float] = None,
company_id: Optional[int] = None,
username: Optional[str] = None):
"""
Track performance metric
Args:
cache_type: Type of cache
is_hit: True if cache hit, False if cache miss
actual_time_ms: Actual response time
time_saved_ms: Time saved by cache (for hits)
estimated_oracle_time_ms: Estimated Oracle time (for hits)
company_id: Company ID
username: Username
"""
if not self.config.track_performance:
return
await self.sqlite.log_performance(
cache_type=cache_type,
company_id=company_id,
cache_hit=is_hit,
response_time_ms=actual_time_ms,
estimated_oracle_time_ms=estimated_oracle_time_ms,
time_saved_ms=time_saved_ms,
username=username
)
# Statistics
async def get_stats(self) -> dict:
"""Get comprehensive cache statistics"""
memory_stats = self.memory.get_stats()
sqlite_stats = await self.sqlite.get_stats()
return {
'enabled': self.config.enabled,
'cache_type': self.config.cache_type,
'memory': memory_stats,
'sqlite': sqlite_stats,
}
# Cleanup
async def _cleanup_loop(self):
"""Background task to cleanup expired entries"""
while True:
try:
await asyncio.sleep(self.config.cleanup_interval)
await self._cleanup_expired()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Cleanup error: {e}", exc_info=True)
async def _cleanup_expired(self):
"""Remove expired entries from both caches"""
logger.info("Running cache cleanup...")
await self.memory.cleanup_expired()
await self.sqlite.cleanup_expired()
logger.info("Cache cleanup completed")
# Global cache manager instance
_cache_manager: Optional[CacheManager] = None
async def init_cache(config: CacheConfig):
"""Initialize global cache manager"""
global _cache_manager
if _cache_manager is not None:
logger.warning("Cache already initialized")
return
_cache_manager = CacheManager(config)
await _cache_manager.init()
logger.info("Global cache manager initialized")
def get_cache() -> Optional[CacheManager]:
"""Get global cache manager instance"""
return _cache_manager
async def close_cache():
"""Close global cache manager"""
global _cache_manager
if _cache_manager is not None:
await _cache_manager.close()
_cache_manager = None

89
backend/modules/reports/cache/config.py vendored Normal file
View File

@@ -0,0 +1,89 @@
"""
Cache configuration from environment variables
"""
import os
from dataclasses import dataclass
from typing import Optional
@dataclass
class CacheConfig:
"""Cache configuration loaded from environment variables"""
# Core Settings
enabled: bool
cache_type: str # 'hybrid', 'memory', 'sqlite', 'disabled'
sqlite_path: str
memory_max_size: int
default_ttl: int
# TTL per Cache Type (seconds)
ttl_schema: int
ttl_companies: int
ttl_dashboard_summary: int
ttl_dashboard_trends: int
ttl_invoices: int
ttl_invoices_summary: int
ttl_treasury: int
ttl_trial_balance: int
ttl_calendar_periods: int
# Maintenance
cleanup_interval: int
# Event-Based Invalidation
auto_invalidate_enabled: bool
check_interval: int
# Performance Tracking
track_performance: bool
benchmark_on_startup: bool
@classmethod
def from_env(cls) -> 'CacheConfig':
"""Load configuration from environment variables"""
return cls(
# Core Settings
enabled=os.getenv('CACHE_ENABLED', 'True').lower() == 'true',
cache_type=os.getenv('CACHE_TYPE', 'hybrid'),
sqlite_path=os.getenv('CACHE_SQLITE_PATH', './data/cache/roa2web_cache.db'),
memory_max_size=int(os.getenv('CACHE_MEMORY_MAX_SIZE', '1000')),
default_ttl=int(os.getenv('CACHE_DEFAULT_TTL', '900')),
# TTL per Cache Type
ttl_schema=int(os.getenv('CACHE_TTL_SCHEMA', '86400')),
ttl_companies=int(os.getenv('CACHE_TTL_COMPANIES', '1800')),
ttl_dashboard_summary=int(os.getenv('CACHE_TTL_DASHBOARD_SUMMARY', '1800')),
ttl_dashboard_trends=int(os.getenv('CACHE_TTL_DASHBOARD_TRENDS', '1800')),
ttl_invoices=int(os.getenv('CACHE_TTL_INVOICES', '600')),
ttl_invoices_summary=int(os.getenv('CACHE_TTL_INVOICES_SUMMARY', '900')),
ttl_treasury=int(os.getenv('CACHE_TTL_TREASURY', '600')),
ttl_trial_balance=int(os.getenv('CACHE_TTL_TRIAL_BALANCE', '600')),
ttl_calendar_periods=int(os.getenv('CACHE_TTL_CALENDAR_PERIODS', '3600')),
# Maintenance
cleanup_interval=int(os.getenv('CACHE_CLEANUP_INTERVAL', '3600')),
# Event-Based Invalidation
auto_invalidate_enabled=os.getenv('CACHE_AUTO_INVALIDATE', 'False').lower() == 'true',
check_interval=int(os.getenv('CACHE_CHECK_INTERVAL', '300')),
# Performance Tracking
track_performance=os.getenv('CACHE_TRACK_PERFORMANCE', 'True').lower() == 'true',
benchmark_on_startup=os.getenv('CACHE_BENCHMARK_ON_STARTUP', 'True').lower() == 'true',
)
def get_ttl_for_type(self, cache_type: str) -> int:
"""Get TTL for specific cache type"""
ttl_map = {
'schema': self.ttl_schema,
'companies': self.ttl_companies,
'dashboard_summary': self.ttl_dashboard_summary,
'dashboard_trends': self.ttl_dashboard_trends,
'invoices': self.ttl_invoices,
'invoices_summary': self.ttl_invoices_summary,
'treasury': self.ttl_treasury,
'trial_balance': self.ttl_trial_balance,
'calendar_periods': self.ttl_calendar_periods,
}
return ttl_map.get(cache_type, self.default_ttl)

View File

@@ -0,0 +1,254 @@
"""
Cache decorators for service methods
"""
import time
import logging
from functools import wraps
from typing import Callable, Optional, List
from .cache_manager import get_cache
from .keys import generate_cache_key
logger = logging.getLogger(__name__)
def cached(cache_type: str, ttl: Optional[int] = None, key_params: Optional[List[str]] = None):
"""
Decorator for caching service method results with performance tracking
Usage:
@cached(cache_type='dashboard_summary', key_params=['company', 'username'])
async def get_complete_summary(company: str, username: str):
# ... Oracle query logic ...
Features:
- Automatic cache key generation from function parameters
- Performance timing (cache hit vs miss)
- Benchmark tracking for time saved calculation
- Per-user cache enable/disable
- Global cache toggle
- Transparent - zero changes to function logic
Args:
cache_type: Type of cache (used for TTL lookup and stats)
ttl: Optional custom TTL (overrides config default)
key_params: List of parameter names to include in cache key
Returns:
Decorated async function
"""
def decorator(func: Callable):
@wraps(func)
async def wrapper(*args, **kwargs):
start_time = time.time()
cache = get_cache()
# Extract username for per-user settings
username = _extract_username(args, kwargs, key_params)
# Check if cache is enabled (global + per-user)
cache_enabled = await cache.is_enabled_for_user(username) if cache else False
if not cache or not cache_enabled:
# Cache disabled - execute directly
result = await func(*args, **kwargs)
elapsed_ms = (time.time() - start_time) * 1000
# Set metadata in request.state if available (for API responses)
if 'request' in kwargs and hasattr(kwargs['request'], 'state'):
kwargs['request'].state.cache_hit = False
kwargs['request'].state.response_time_ms = elapsed_ms
kwargs['request'].state.cache_source = None
if cache and cache.config.track_performance:
await cache.track_performance(
cache_type=cache_type,
is_hit=False,
actual_time_ms=elapsed_ms,
username=username
)
return result
# Generate cache key from function parameters
cache_key = generate_cache_key(cache_type, key_params, args, kwargs)
# Try to get from cache
cached_value = await cache.get(cache_key, cache_type)
if cached_value is not None:
# ✅ CACHE HIT
elapsed_ms = (time.time() - start_time) * 1000
# Set metadata in request.state if available (for API responses)
if 'request' in kwargs and hasattr(kwargs['request'], 'state'):
cache_source_value = cache.get_last_cache_source() # L1 or L2
kwargs['request'].state.cache_hit = True
kwargs['request'].state.response_time_ms = elapsed_ms
kwargs['request'].state.cache_source = cache_source_value
# Get benchmark for calculating time saved
benchmark = await cache.get_benchmark(cache_type)
time_saved_ms = (benchmark - elapsed_ms) if benchmark else None
# Track performance
if cache.config.track_performance:
await cache.track_performance(
cache_type=cache_type,
is_hit=True,
actual_time_ms=elapsed_ms,
time_saved_ms=time_saved_ms,
estimated_oracle_time_ms=benchmark,
company_id=_extract_company_id(args, kwargs, key_params),
username=username
)
return cached_value
# ❌ CACHE MISS - execute function (query Oracle)
result = await func(*args, **kwargs)
elapsed_ms = (time.time() - start_time) * 1000
# Set metadata in request.state if available (for API responses)
if 'request' in kwargs and hasattr(kwargs['request'], 'state'):
kwargs['request'].state.cache_hit = False
kwargs['request'].state.response_time_ms = elapsed_ms
kwargs['request'].state.cache_source = None
# Update benchmark with real Oracle time
await cache.update_benchmark(cache_type, elapsed_ms)
# Track performance
if cache.config.track_performance:
await cache.track_performance(
cache_type=cache_type,
is_hit=False,
actual_time_ms=elapsed_ms,
company_id=_extract_company_id(args, kwargs, key_params),
username=username
)
# Store in cache for next time
company_id = _extract_company_id(args, kwargs, key_params)
await cache.set(cache_key, result, cache_type, company_id, ttl)
return result
return wrapper
return decorator
def _extract_username(args, kwargs, key_params: Optional[List[str]]) -> Optional[str]:
"""
Extract username from function parameters (args or kwargs)
Checks:
1. key_params position in args (if username is in key_params)
2. Direct username in kwargs
3. current_user object in kwargs
4. user object in kwargs
5. request.state.user (from AuthenticationMiddleware)
Args:
args: Positional arguments
kwargs: Keyword arguments
key_params: List of parameter names (for finding position in args)
Returns:
Username string or None
"""
# Try to find username in args based on key_params position
if key_params and 'username' in key_params:
try:
username_idx = key_params.index('username')
if username_idx < len(args):
username = args[username_idx]
if username:
return str(username)
except (ValueError, IndexError):
pass
# Direct username parameter in kwargs
if 'username' in kwargs:
return kwargs['username']
# Current user object (from FastAPI Depends)
if 'current_user' in kwargs:
user = kwargs['current_user']
if hasattr(user, 'username'):
return user.username
elif isinstance(user, dict) and 'username' in user:
return user['username']
return str(user)
# User object
if 'user' in kwargs:
user = kwargs['user']
if hasattr(user, 'username'):
return user.username
elif isinstance(user, dict) and 'username' in user:
return user['username']
return str(user)
# Extract from request.state.user (set by AuthenticationMiddleware)
if 'request' in kwargs:
request = kwargs['request']
if hasattr(request, 'state') and hasattr(request.state, 'user'):
user = request.state.user
if hasattr(user, 'username'):
return user.username
elif isinstance(user, dict) and 'username' in user:
return user['username']
return None
def _extract_company_id(args, kwargs, key_params: Optional[List[str]]) -> Optional[int]:
"""
Extract company_id from function parameters for cache indexing
Tries multiple approaches:
1. Direct company_id in kwargs
2. company parameter (converted to int)
3. Positional args based on key_params position
Args:
args: Positional arguments
kwargs: Keyword arguments
key_params: List of parameter names
Returns:
Company ID as integer or None
"""
# Try kwargs first
if 'company_id' in kwargs:
try:
return int(kwargs['company_id'])
except (ValueError, TypeError):
pass
if 'company' in kwargs:
try:
return int(kwargs['company'])
except (ValueError, TypeError):
pass
# Try positional args based on key_params
if key_params:
if 'company_id' in key_params:
idx = key_params.index('company_id')
if idx < len(args):
try:
return int(args[idx])
except (ValueError, TypeError):
pass
elif 'company' in key_params:
idx = key_params.index('company')
if idx < len(args):
try:
return int(args[idx])
except (ValueError, TypeError):
pass
return None

View File

@@ -0,0 +1,333 @@
"""
Event-based cache invalidation monitor
Monitors {schema}.act tables for changes and invalidates cache automatically
"""
import asyncio
import logging
import sys
import os
from typing import Optional
# Path setup handled by main.py - this is redundant but kept for module isolation
# sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..')))
logger = logging.getLogger(__name__)
class EventMonitor:
"""
Monitors schema.act tables for changes to trigger cache invalidation
Runs as background task, checking max(id_act) at configured intervals
Uses permanent schema_mappings cache to avoid repeated schema lookups
"""
def __init__(self, cache_manager, config):
"""
Initialize event monitor
Args:
cache_manager: CacheManager instance
config: CacheConfig instance
"""
self.cache = cache_manager
self.config = config
self.running = False
self.task: Optional[asyncio.Task] = None
async def start(self):
"""Start monitoring task"""
if self.running:
logger.warning("Event monitor already running")
return
self.running = True
self.task = asyncio.create_task(self._monitor_loop())
logger.info(
f"Event monitor started (interval: {self.config.check_interval}s)"
)
async def stop(self):
"""Stop monitoring task"""
if not self.running:
return
self.running = False
if self.task:
self.task.cancel()
try:
await self.task
except asyncio.CancelledError:
pass
logger.info("Event monitor stopped")
async def _monitor_loop(self):
"""Main monitoring loop"""
while self.running:
try:
await self._check_all_companies()
await asyncio.sleep(self.config.check_interval)
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Event monitor error: {e}", exc_info=True)
# Wait 1 minute on error before retrying
await asyncio.sleep(60)
async def _check_all_companies(self):
"""
Check all companies with active cache for changes
Queries max(id_act) from {schema}.act for each cached company
and invalidates cache if changes detected
"""
try:
# Get list of companies with active cache entries
cached_companies = await self.cache.sqlite.get_cached_company_ids()
if not cached_companies:
logger.debug("No cached companies to monitor")
return
logger.info(f"Checking {len(cached_companies)} companies for changes...")
invalidated_count = 0
for company_id in cached_companies:
try:
# Check if company data changed
changed = await self._check_company_changes(company_id)
if changed:
# Invalidate cache for this company
await self.cache.invalidate(company_id=company_id)
invalidated_count += 1
logger.info(
f"Cache invalidated for company {company_id} due to act changes"
)
except Exception as e:
# Error for one company shouldn't stop checking others
logger.error(f"Error checking company {company_id}: {e}")
continue
if invalidated_count > 0:
logger.info(
f"Auto-invalidation complete: {invalidated_count} companies affected"
)
except Exception as e:
logger.error(f"Check all companies error: {e}", exc_info=True)
async def _check_company_changes(self, company_id: int) -> bool:
"""
Check if company data changed (monitor max(id_act) in schema.act)
Args:
company_id: Company ID to check
Returns:
True if cache should be invalidated, False otherwise
"""
try:
# 1. Get schema (from permanent cache)
schema = await self._get_schema_for_company(company_id)
if not schema:
logger.warning(f"Schema not found for company {company_id}")
return False
# 2. Get current max(id_act) from Oracle
current_max = await self._get_max_id_act(schema)
# 3. Get cached watermark
cached_watermark = await self.cache.sqlite.get_watermark(company_id)
# 4. Compare
if cached_watermark is None:
# First time checking - store watermark, no invalidation
await self.cache.sqlite.set_watermark(company_id, schema, current_max)
logger.debug(
f"Watermark initialized for company {company_id}: {current_max}"
)
return False
if current_max > cached_watermark:
# Changes detected!
logger.info(
f"Schema {schema} (company {company_id}): "
f"id_act changed {cached_watermark}{current_max}"
)
# Update watermark
await self.cache.sqlite.set_watermark(company_id, schema, current_max)
return True # Invalidate cache
# No changes
return False
except Exception as e:
logger.error(f"Check company {company_id} changes error: {e}")
return False # Don't invalidate on error
async def _get_schema_for_company(self, company_id: int) -> Optional[str]:
"""
Get schema for company (with permanent caching)
First checks permanent schema_mappings cache,
falls back to Oracle query if not cached
Args:
company_id: Company ID
Returns:
Schema name or None
"""
# Check permanent cache first
cached_schema = await self.cache.sqlite.get_schema_mapping(company_id)
if cached_schema:
logger.debug(f"Schema mapping HIT for company {company_id}: {cached_schema}")
return cached_schema
# Cache MISS - query Oracle
logger.info(f"Schema mapping MISS for company {company_id}, querying Oracle...")
try:
from shared.database.oracle_pool import oracle_pool
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute("""
SELECT schema
FROM CONTAFIN_ORACLE.v_nom_firme
WHERE id_firma = :id
""", {'id': company_id})
result = cursor.fetchone()
if not result:
logger.warning(f"Company {company_id} not found in v_nom_firme")
return None
schema = result[0]
# Store PERMANENT in schema_mappings (never expires)
await self.cache.sqlite.set_schema_mapping(company_id, schema)
logger.info(f"Schema mapping stored for company {company_id}: {schema}")
return schema
except Exception as e:
logger.error(f"Get schema for company {company_id} error: {e}")
return None
async def _get_max_id_act(self, schema: str) -> int:
"""
Query max(id_act) from {schema}.act
Args:
schema: Schema name
Returns:
Max id_act value (0 if table empty)
"""
try:
from shared.database.oracle_pool import oracle_pool
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
# IMPORTANT: Schema comes from v_nom_firme (trusted source)
# so it's safe from SQL injection
query = f"SELECT MAX(id_act) FROM {schema}.act"
cursor.execute(query)
result = cursor.fetchone()
max_id_act = result[0] if result and result[0] is not None else 0
return max_id_act
except Exception as e:
logger.error(f"Get max_id_act for schema {schema} error: {e}")
return 0
# Optional: Preload all schema mappings at startup
async def preload_all_schema_mappings():
"""
Preload all schema mappings at startup (optional)
Prevents cache misses on first requests by populating
schema_mappings table with all companies
"""
from .cache_manager import get_cache
cache = get_cache()
if not cache:
logger.warning("Cache not initialized - skipping schema preload")
return
logger.info("Preloading all schema mappings...")
try:
from shared.database.oracle_pool import oracle_pool
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute("""
SELECT id_firma, schema
FROM CONTAFIN_ORACLE.v_nom_firme
""")
results = cursor.fetchall()
for id_firma, schema in results:
await cache.sqlite.set_schema_mapping(id_firma, schema)
logger.info(f"Preloaded {len(results)} schema mappings")
except Exception as e:
logger.error(f"Schema preload error: {e}")
# Global event monitor instance
_event_monitor: Optional[EventMonitor] = None
async def init_event_monitor(cache_manager, config):
"""
Initialize global event monitor
Args:
cache_manager: CacheManager instance
config: CacheConfig instance
"""
global _event_monitor
_event_monitor = EventMonitor(cache_manager, config)
# Start if auto-invalidate enabled
if config.auto_invalidate_enabled:
await _event_monitor.start()
def get_event_monitor() -> Optional[EventMonitor]:
"""Get global event monitor instance"""
return _event_monitor
async def toggle_event_monitor(enabled: bool):
"""
Toggle event monitor on/off
Args:
enabled: True to start monitoring, False to stop
"""
monitor = get_event_monitor()
if not monitor:
logger.warning("Event monitor not initialized")
return
if enabled and not monitor.running:
await monitor.start()
elif not enabled and monitor.running:
await monitor.stop()

150
backend/modules/reports/cache/keys.py vendored Normal file
View File

@@ -0,0 +1,150 @@
"""
Cache key generation utilities
"""
import hashlib
import json
from typing import Any, List, Optional
def generate_cache_key(cache_type: str, key_params: Optional[List[str]], args: tuple, kwargs: dict) -> str:
"""
Generate cache key from function parameters
Format: "{cache_type}:{param1_value}:{param2_value}:..."
Args:
cache_type: Type of cache (e.g., 'dashboard_summary', 'invoices')
key_params: List of parameter names to include in key
args: Positional arguments from function call
kwargs: Keyword arguments from function call
Returns:
Cache key string
Examples:
generate_cache_key('schema', ['company_id'], (123,), {})
-> "schema:123"
generate_cache_key('dashboard_summary', ['company', 'username'], (), {'company': '123', 'username': 'john'})
-> "dashboard_summary:123:john"
generate_cache_key('invoices', ['company', 'invoice_type', 'status'], (123, 'CLIENTI', 'neplatite'), {})
-> "invoices:123:CLIENTI:neplatite"
"""
key_parts = [cache_type]
if not key_params:
# No specific params - use all args/kwargs (fallback)
if args:
key_parts.extend([str(arg) for arg in args])
if kwargs:
# Sort kwargs for consistent key generation
sorted_kwargs = sorted(kwargs.items())
key_parts.extend([f"{k}={v}" for k, v in sorted_kwargs])
else:
# Extract specific params
for i, param_name in enumerate(key_params):
# Try to get from kwargs first
if param_name in kwargs:
value = kwargs[param_name]
# Then try positional args
elif i < len(args):
value = args[i]
else:
# Parameter not found - use placeholder
value = "none"
key_parts.append(str(value))
return ":".join(key_parts)
def generate_key_pattern(cache_type: str, company_id: Optional[int] = None) -> str:
"""
Generate cache key pattern for matching multiple keys
Used for invalidation by type or company
Args:
cache_type: Type of cache
company_id: Optional company ID to filter by
Returns:
Pattern string (prefix)
Examples:
generate_key_pattern('dashboard_summary')
-> "dashboard_summary:"
generate_key_pattern('dashboard_summary', 123)
-> "dashboard_summary:123"
"""
if company_id is not None:
return f"{cache_type}:{company_id}"
return f"{cache_type}:"
def hash_complex_params(params: dict) -> str:
"""
Generate hash for complex parameters (e.g., filters, queries)
Used when cache key would be too long with full param values
Args:
params: Dictionary of parameters to hash
Returns:
8-character hash string
Example:
filters = {'status': 'neplatite', 'date_from': '2024-01-01', 'date_to': '2024-12-31'}
hash_complex_params(filters)
-> "a3f8b2c1"
"""
# Sort keys for consistent hashing
sorted_params = json.dumps(params, sort_keys=True)
hash_obj = hashlib.sha256(sorted_params.encode())
# Return first 8 characters of hex digest
return hash_obj.hexdigest()[:8]
def extract_company_id_from_key(cache_key: str) -> Optional[int]:
"""
Extract company_id from cache key
Assumes format: "cache_type:company_id:..."
Args:
cache_key: Cache key string
Returns:
Company ID or None if not found
Example:
extract_company_id_from_key("dashboard_summary:123:john")
-> 123
"""
parts = cache_key.split(":")
if len(parts) >= 2:
try:
return int(parts[1])
except (ValueError, TypeError):
pass
return None
def extract_cache_type_from_key(cache_key: str) -> str:
"""
Extract cache_type from cache key
Args:
cache_key: Cache key string
Returns:
Cache type (first part before colon)
Example:
extract_cache_type_from_key("dashboard_summary:123:john")
-> "dashboard_summary"
"""
return cache_key.split(":")[0]

View File

@@ -0,0 +1,180 @@
"""
In-memory cache with TTL (L1 cache)
Fast, limited size, lost on restart
"""
import time
import logging
from typing import Any, Optional, Dict
from collections import OrderedDict
logger = logging.getLogger(__name__)
class MemoryCache:
"""
In-memory LRU cache with TTL support
Features:
- LRU eviction when max_size reached
- Per-entry TTL expiration
- Thread-safe operations
- Fast O(1) get/set operations
"""
def __init__(self, max_size: int = 1000):
"""
Initialize memory cache
Args:
max_size: Maximum number of entries to store
"""
self.max_size = max_size
self._cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
self._stats = {
'hits': 0,
'misses': 0,
'sets': 0,
'evictions': 0
}
async def get(self, key: str) -> Optional[Any]:
"""
Get value from cache
Args:
key: Cache key
Returns:
Cached value or None if not found/expired
"""
if key not in self._cache:
self._stats['misses'] += 1
return None
entry = self._cache[key]
# Check TTL expiration
if entry['expires_at'] < time.time():
# Expired - remove and return None
del self._cache[key]
self._stats['misses'] += 1
logger.debug(f"Memory cache expired: {key}")
return None
# Move to end (LRU - most recently used)
self._cache.move_to_end(key)
self._stats['hits'] += 1
logger.debug(f"Memory cache HIT: {key}")
return entry['value']
async def set(self, key: str, value: Any, ttl: int):
"""
Set value in cache
Args:
key: Cache key
value: Value to cache
ttl: Time to live in seconds
"""
expires_at = time.time() + ttl
# Check if we need to evict (LRU)
if key not in self._cache and len(self._cache) >= self.max_size:
# Evict oldest entry (first item in OrderedDict)
evicted_key = next(iter(self._cache))
del self._cache[evicted_key]
self._stats['evictions'] += 1
logger.debug(f"Memory cache evicted (LRU): {evicted_key}")
# Store entry
self._cache[key] = {
'value': value,
'expires_at': expires_at,
'created_at': time.time()
}
# Move to end (most recently used)
self._cache.move_to_end(key)
self._stats['sets'] += 1
logger.debug(f"Memory cache SET: {key} (TTL: {ttl}s)")
async def delete(self, key: str) -> bool:
"""
Delete entry from cache
Args:
key: Cache key
Returns:
True if deleted, False if not found
"""
if key in self._cache:
del self._cache[key]
logger.debug(f"Memory cache deleted: {key}")
return True
return False
async def clear(self):
"""Clear all entries from cache"""
count = len(self._cache)
self._cache.clear()
logger.info(f"Memory cache cleared: {count} entries removed")
async def clear_by_pattern(self, pattern: str):
"""
Clear entries matching pattern (simple prefix match)
Args:
pattern: Key prefix to match (e.g., "dashboard_summary:123")
"""
keys_to_delete = [key for key in self._cache.keys() if key.startswith(pattern)]
for key in keys_to_delete:
del self._cache[key]
logger.info(f"Memory cache cleared by pattern '{pattern}': {len(keys_to_delete)} entries")
async def cleanup_expired(self):
"""Remove all expired entries"""
now = time.time()
expired_keys = [
key for key, entry in self._cache.items()
if entry['expires_at'] < now
]
for key in expired_keys:
del self._cache[key]
if expired_keys:
logger.info(f"Memory cache cleanup: {len(expired_keys)} expired entries removed")
def get_stats(self) -> Dict[str, Any]:
"""
Get cache statistics
Returns:
Dictionary with stats (hits, misses, size, etc.)
"""
total_requests = self._stats['hits'] + self._stats['misses']
hit_rate = (self._stats['hits'] / total_requests * 100) if total_requests > 0 else 0
return {
'size': len(self._cache),
'max_size': self.max_size,
'hits': self._stats['hits'],
'misses': self._stats['misses'],
'sets': self._stats['sets'],
'evictions': self._stats['evictions'],
'hit_rate': hit_rate,
'total_requests': total_requests
}
def reset_stats(self):
"""Reset statistics counters"""
self._stats = {
'hits': 0,
'misses': 0,
'sets': 0,
'evictions': 0
}

View File

@@ -0,0 +1,404 @@
"""
SQLite persistent cache (L2 cache)
Persistent, survives restarts, unlimited size
"""
import time
import json
import logging
import aiosqlite
from typing import Any, Optional, List, Dict
from pathlib import Path
from decimal import Decimal
from datetime import datetime, date
logger = logging.getLogger(__name__)
class CustomJSONEncoder(json.JSONEncoder):
"""Custom JSON encoder that handles Pydantic models, Decimal, datetime, etc."""
def default(self, obj):
# Handle Pydantic models
if hasattr(obj, 'dict'):
return obj.dict()
if hasattr(obj, 'model_dump'): # Pydantic v2
return obj.model_dump()
# Handle Decimal
if isinstance(obj, Decimal):
return float(obj)
# Handle datetime/date
if isinstance(obj, (datetime, date)):
return obj.isoformat()
return super().default(obj)
class SQLiteCache:
"""
SQLite-based persistent cache
Features:
- Persistent storage (survives restarts)
- JSON serialization for complex objects
- Schema mappings (permanent cache for company->schema)
- Watermarks for event-based invalidation
- Performance tracking and benchmarks
"""
def __init__(self, db_path: str):
"""
Initialize SQLite cache
Args:
db_path: Path to SQLite database file
"""
self.db_path = db_path
self._ensure_db_dir()
def _ensure_db_dir(self):
"""Ensure database directory exists"""
db_dir = Path(self.db_path).parent
db_dir.mkdir(parents=True, exist_ok=True)
async def init_db(self):
"""Initialize database schema with WAL mode enabled"""
async with aiosqlite.connect(self.db_path) as db:
# Enable Write-Ahead Logging (WAL) mode for better concurrency
await db.execute("PRAGMA journal_mode=WAL")
await db.commit()
# Table: cache_entries
await db.execute("""
CREATE TABLE IF NOT EXISTS cache_entries (
cache_key TEXT PRIMARY KEY,
cache_type TEXT NOT NULL,
company_id INTEGER,
data_json TEXT NOT NULL,
created_at REAL NOT NULL,
expires_at REAL NOT NULL,
hit_count INTEGER DEFAULT 0,
last_accessed REAL
)
""")
await db.execute("CREATE INDEX IF NOT EXISTS idx_cache_type ON cache_entries(cache_type)")
await db.execute("CREATE INDEX IF NOT EXISTS idx_company_id ON cache_entries(company_id)")
await db.execute("CREATE INDEX IF NOT EXISTS idx_expires_at ON cache_entries(expires_at)")
# Table: schema_mappings (PERMANENT)
await db.execute("""
CREATE TABLE IF NOT EXISTS schema_mappings (
id_firma INTEGER PRIMARY KEY,
schema TEXT NOT NULL,
created_at REAL NOT NULL,
last_verified REAL
)
""")
# Table: query_benchmarks
await db.execute("""
CREATE TABLE IF NOT EXISTS query_benchmarks (
cache_type TEXT PRIMARY KEY,
avg_time_ms REAL NOT NULL,
sample_count INTEGER DEFAULT 0,
last_updated REAL
)
""")
# Table: performance_log
await db.execute("""
CREATE TABLE IF NOT EXISTS performance_log (
id INTEGER PRIMARY KEY AUTOINCREMENT,
cache_type TEXT NOT NULL,
company_id INTEGER,
cache_hit BOOLEAN NOT NULL,
response_time_ms REAL NOT NULL,
estimated_oracle_time_ms REAL,
time_saved_ms REAL,
username TEXT,
timestamp REAL NOT NULL
)
""")
await db.execute("CREATE INDEX IF NOT EXISTS idx_perf_timestamp ON performance_log(timestamp)")
await db.execute("CREATE INDEX IF NOT EXISTS idx_perf_cache_type ON performance_log(cache_type)")
# Table: user_cache_settings
await db.execute("""
CREATE TABLE IF NOT EXISTS user_cache_settings (
username TEXT PRIMARY KEY,
cache_enabled BOOLEAN DEFAULT TRUE,
created_at REAL,
updated_at REAL
)
""")
# Table: cache_config
await db.execute("""
CREATE TABLE IF NOT EXISTS cache_config (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
updated_at REAL
)
""")
# Table: cache_watermarks
await db.execute("""
CREATE TABLE IF NOT EXISTS cache_watermarks (
company_id INTEGER PRIMARY KEY,
schema TEXT NOT NULL,
max_id_act INTEGER NOT NULL,
checked_at REAL NOT NULL
)
""")
await db.commit()
logger.info("SQLite cache database initialized")
async def get(self, key: str) -> Optional[Any]:
"""
Get value from cache
Args:
key: Cache key
Returns:
Cached value or None if not found/expired
"""
async with aiosqlite.connect(self.db_path) as db:
async with db.execute("""
SELECT data_json, expires_at
FROM cache_entries
WHERE cache_key = ?
""", (key,)) as cursor:
result = await cursor.fetchone()
if not result:
return None
data_json, expires_at = result
# Check TTL expiration
if expires_at < time.time():
# Expired - delete and return None
await db.execute("DELETE FROM cache_entries WHERE cache_key = ?", (key,))
await db.commit()
logger.debug(f"SQLite cache expired: {key}")
return None
# Update hit_count and last_accessed
await db.execute("""
UPDATE cache_entries
SET hit_count = hit_count + 1, last_accessed = ?
WHERE cache_key = ?
""", (time.time(), key))
await db.commit()
logger.debug(f"SQLite cache HIT: {key}")
return json.loads(data_json)
async def set(self, key: str, value: Any, cache_type: str, company_id: Optional[int], ttl: int):
"""
Set value in cache
Args:
key: Cache key
value: Value to cache
cache_type: Type of cache entry
company_id: Company ID (None for global caches)
ttl: Time to live in seconds
"""
# Use custom encoder to handle Pydantic models, Decimal, datetime, etc.
data_json = json.dumps(value, cls=CustomJSONEncoder)
now = time.time()
expires_at = now + ttl
async with aiosqlite.connect(self.db_path) as db:
await db.execute("""
INSERT OR REPLACE INTO cache_entries
(cache_key, cache_type, company_id, data_json, created_at, expires_at, hit_count, last_accessed)
VALUES (?, ?, ?, ?, ?, ?, 0, ?)
""", (key, cache_type, company_id, data_json, now, expires_at, now))
await db.commit()
logger.debug(f"SQLite cache SET: {key} (TTL: {ttl}s)")
async def delete(self, key: str) -> bool:
"""Delete entry from cache"""
async with aiosqlite.connect(self.db_path) as db:
cursor = await db.execute("DELETE FROM cache_entries WHERE cache_key = ?", (key,))
await db.commit()
deleted = cursor.rowcount > 0
if deleted:
logger.debug(f"SQLite cache deleted: {key}")
return deleted
async def clear(self):
"""Clear all cache entries"""
async with aiosqlite.connect(self.db_path) as db:
cursor = await db.execute("DELETE FROM cache_entries")
await db.commit()
count = cursor.rowcount
logger.info(f"SQLite cache cleared: {count} entries removed")
async def clear_by_company(self, company_id: int):
"""Clear all entries for specific company"""
async with aiosqlite.connect(self.db_path) as db:
cursor = await db.execute("DELETE FROM cache_entries WHERE company_id = ?", (company_id,))
await db.commit()
count = cursor.rowcount
logger.info(f"SQLite cache cleared for company {company_id}: {count} entries")
async def clear_by_type(self, cache_type: str):
"""Clear all entries of specific type"""
async with aiosqlite.connect(self.db_path) as db:
cursor = await db.execute("DELETE FROM cache_entries WHERE cache_type = ?", (cache_type,))
await db.commit()
count = cursor.rowcount
logger.info(f"SQLite cache cleared for type '{cache_type}': {count} entries")
async def cleanup_expired(self):
"""Remove all expired entries"""
async with aiosqlite.connect(self.db_path) as db:
cursor = await db.execute("DELETE FROM cache_entries WHERE expires_at < ?", (time.time(),))
await db.commit()
count = cursor.rowcount
if count > 0:
logger.info(f"SQLite cache cleanup: {count} expired entries removed")
# Schema Mappings (PERMANENT)
async def get_schema_mapping(self, company_id: int) -> Optional[str]:
"""Get permanent cached schema for company"""
async with aiosqlite.connect(self.db_path) as db:
async with db.execute("""
SELECT schema
FROM schema_mappings
WHERE id_firma = ?
""", (company_id,)) as cursor:
result = await cursor.fetchone()
return result[0] if result else None
async def set_schema_mapping(self, company_id: int, schema: str):
"""Set permanent schema mapping (never expires)"""
async with aiosqlite.connect(self.db_path) as db:
await db.execute("""
INSERT OR REPLACE INTO schema_mappings
(id_firma, schema, created_at, last_verified)
VALUES (?, ?, ?, ?)
""", (company_id, schema, time.time(), time.time()))
await db.commit()
# Benchmarks
async def get_benchmark(self, cache_type: str) -> Optional[float]:
"""Get average benchmark time for cache type"""
async with aiosqlite.connect(self.db_path) as db:
async with db.execute("""
SELECT avg_time_ms
FROM query_benchmarks
WHERE cache_type = ?
""", (cache_type,)) as cursor:
result = await cursor.fetchone()
return result[0] if result else None
async def set_benchmark(self, cache_type: str, avg_time_ms: float, sample_count: int):
"""Set/update benchmark"""
async with aiosqlite.connect(self.db_path) as db:
await db.execute("""
INSERT OR REPLACE INTO query_benchmarks
(cache_type, avg_time_ms, sample_count, last_updated)
VALUES (?, ?, ?, ?)
""", (cache_type, avg_time_ms, sample_count, time.time()))
await db.commit()
# Performance Tracking
async def log_performance(self, cache_type: str, company_id: Optional[int], cache_hit: bool,
response_time_ms: float, estimated_oracle_time_ms: Optional[float],
time_saved_ms: Optional[float], username: Optional[str]):
"""Log performance metric"""
async with aiosqlite.connect(self.db_path) as db:
await db.execute("""
INSERT INTO performance_log
(cache_type, company_id, cache_hit, response_time_ms, estimated_oracle_time_ms,
time_saved_ms, username, timestamp)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""", (cache_type, company_id, cache_hit, response_time_ms, estimated_oracle_time_ms,
time_saved_ms, username, time.time()))
await db.commit()
# User Settings
async def get_user_cache_enabled(self, username: str) -> bool:
"""Get user cache setting (default True)"""
async with aiosqlite.connect(self.db_path) as db:
async with db.execute("""
SELECT cache_enabled
FROM user_cache_settings
WHERE username = ?
""", (username,)) as cursor:
result = await cursor.fetchone()
return bool(result[0]) if result else True # Default enabled, explicit bool conversion
async def set_user_cache_enabled(self, username: str, enabled: bool):
"""Set user cache setting"""
async with aiosqlite.connect(self.db_path) as db:
await db.execute("""
INSERT OR REPLACE INTO user_cache_settings
(username, cache_enabled, created_at, updated_at)
VALUES (?, ?, ?, ?)
""", (username, enabled, time.time(), time.time()))
await db.commit()
# Watermarks
async def get_watermark(self, company_id: int) -> Optional[int]:
"""Get cached watermark (max_id_act) for company"""
async with aiosqlite.connect(self.db_path) as db:
async with db.execute("""
SELECT max_id_act
FROM cache_watermarks
WHERE company_id = ?
""", (company_id,)) as cursor:
result = await cursor.fetchone()
return result[0] if result else None
async def set_watermark(self, company_id: int, schema: str, max_id_act: int):
"""Set/update watermark for company"""
async with aiosqlite.connect(self.db_path) as db:
await db.execute("""
INSERT OR REPLACE INTO cache_watermarks
(company_id, schema, max_id_act, checked_at)
VALUES (?, ?, ?, ?)
""", (company_id, schema, max_id_act, time.time()))
await db.commit()
async def get_cached_company_ids(self) -> List[int]:
"""Get list of company_ids with active cache entries"""
async with aiosqlite.connect(self.db_path) as db:
async with db.execute("""
SELECT DISTINCT company_id
FROM cache_entries
WHERE company_id IS NOT NULL
AND expires_at > ?
""", (time.time(),)) as cursor:
results = await cursor.fetchall()
return [row[0] for row in results]
# Statistics
async def get_stats(self) -> Dict[str, Any]:
"""Get cache statistics"""
async with aiosqlite.connect(self.db_path) as db:
# Total entries
async with db.execute("SELECT COUNT(*) FROM cache_entries") as cursor:
total_entries = (await cursor.fetchone())[0]
# Active entries (not expired)
async with db.execute("""
SELECT COUNT(*) FROM cache_entries WHERE expires_at > ?
""", (time.time(),)) as cursor:
active_entries = (await cursor.fetchone())[0]
return {
'total_entries': total_entries,
'active_entries': active_entries,
'expired_entries': total_entries - active_entries
}

View File

@@ -0,0 +1,19 @@
"""
Calendar period models for accounting period selector
"""
from pydantic import BaseModel
from typing import List, Optional
class CalendarPeriod(BaseModel):
"""Model for an accounting period"""
an: int # Year
luna: int # Month (1-12)
display_name: str # Format: "Decembrie 2025"
class CalendarPeriodsResponse(BaseModel):
"""Response model for calendar periods list"""
periods: List[CalendarPeriod]
current_period: Optional[CalendarPeriod] = None # Most recent period
total_count: int

View File

@@ -0,0 +1,129 @@
from pydantic import BaseModel
from decimal import Decimal
from typing import List, Dict, Optional, Any
class TreasuryAccount(BaseModel):
"""Cont de trezorerie (bancă/casă)"""
cont: str # 5121, 5124, 5311, 5314
nume_cont: str # "Bancă LEI", "Casă VALUTA" etc
nume_banca: str # Numele băncii din vbalanta_parteneri.nume
sold: Decimal
valuta: str
class TrendData(BaseModel):
"""Model pentru datele de trend - MODEL VECHI"""
labels: List[str]
incasari: List[Decimal]
plati: List[Decimal]
trezorerie: List[Decimal]
incasari_total: Decimal
plati_total: Decimal
trezorerie_total: Decimal
incasari_change: Optional[float] = None
plati_change: Optional[float] = None
trezorerie_change: Optional[float] = None
class TrendsResponse(BaseModel):
"""Model pentru răspunsul endpoint-ului de trenduri - MODEL NOU"""
# Current period data
periods: List[str]
clienti_facturat: List[float]
clienti_incasat: List[float]
furnizori_facturat: List[float]
furnizori_achitat: List[float]
clienti_sold: List[float]
furnizori_sold: List[float]
trezorerie_sold: Optional[List[float]] = None
rata_incasare_clienti: List[float]
rata_achitare_furnizori: List[float]
# Previous period data (for year-over-year comparison in sparklines)
previous_periods: Optional[List[str]] = None
clienti_facturat_prev: Optional[List[float]] = None
clienti_incasat_prev: Optional[List[float]] = None
furnizori_facturat_prev: Optional[List[float]] = None
furnizori_achitat_prev: Optional[List[float]] = None
clienti_sold_prev: Optional[List[float]] = None
furnizori_sold_prev: Optional[List[float]] = None
trezorerie_sold_prev: Optional[List[float]] = None
# Metadata and analytics
metadata: Dict[str, Any]
growth_rates: Optional[Dict[str, float]] = None
# Cache metadata (optional, for Telegram Bot)
cache_hit: Optional[bool] = None
response_time_ms: Optional[float] = None
cache_source: Optional[str] = None
class DashboardSummary(BaseModel):
"""Model pentru toate datele dashboard-ului"""
# CLIENȚI - statistici existente
clienti_total_facturat: Decimal # precdeb + debit (conturi 4111, 461)
clienti_total_incasat: Decimal # preccred + credit (conturi 4111, 461)
clienti_avansuri: Decimal # sold 419 (pasiv): credit - debit
clienti_sold_total: Decimal # (facturat - incasat) - avansuri
clienti_sold_restant: Decimal # sold cu datascad < azi
# CLIENȚI - NOI câmpuri pentru sold în termen
clienti_sold_in_termen: Decimal # sold cu datascad >= azi
# CLIENȚI - NOI detalieri restanțe (sold cu datascad < azi)
clienti_restant_7: Decimal # restant 1-7 zile
clienti_restant_14: Decimal # restant 8-14 zile
clienti_restant_30: Decimal # restant 15-30 zile
clienti_restant_60: Decimal # restant 31-60 zile
clienti_restant_90: Decimal # restant 61-90 zile
clienti_restant_90plus: Decimal # restant 90+ zile
# CLIENȚI - NOI detalieri scadențe (sold cu datascad >= azi)
clienti_scadent_7: Decimal # scadent în 1-7 zile
clienti_scadent_14: Decimal # scadent în 8-14 zile
clienti_scadent_30: Decimal # scadent în 15-30 zile
clienti_scadent_60: Decimal # scadent în 31-60 zile
clienti_scadent_90: Decimal # scadent în 61-90 zile
clienti_scadent_90plus: Decimal # scadent în 90+ zile
# FURNIZORI - statistici existente
furnizori_total_facturat: Decimal # preccred + credit (conturi 401, 404, 462)
furnizori_total_achitat: Decimal # precdeb + debit (conturi 401, 404, 462)
furnizori_avansuri: Decimal # sold 409x (activ): debit - credit
furnizori_sold_total: Decimal # (facturat - achitat) - avansuri
furnizori_sold_restant: Decimal # sold cu datascad < azi
# FURNIZORI - NOI câmpuri pentru sold în termen
furnizori_sold_in_termen: Decimal # sold cu datascad >= azi
# FURNIZORI - NOI detalieri restanțe (sold cu datascad < azi)
furnizori_restant_7: Decimal # restant 1-7 zile
furnizori_restant_14: Decimal # restant 8-14 zile
furnizori_restant_30: Decimal # restant 15-30 zile
furnizori_restant_60: Decimal # restant 31-60 zile
furnizori_restant_90: Decimal # restant 61-90 zile
furnizori_restant_90plus: Decimal # restant 90+ zile
# FURNIZORI - NOI detalieri scadențe (sold cu datascad >= azi)
furnizori_scadent_7: Decimal # scadent în 1-7 zile
furnizori_scadent_14: Decimal # scadent în 8-14 zile
furnizori_scadent_30: Decimal # scadent în 15-30 zile
furnizori_scadent_60: Decimal # scadent în 31-60 zile
furnizori_scadent_90: Decimal # scadent în 61-90 zile
furnizori_scadent_90plus: Decimal # scadent în 90+ zile
# TREZORERIE - existente
treasury_accounts: List[TreasuryAccount]
treasury_totals_by_currency: Dict[str, Decimal]
# DATE SUPLIMENTARE pentru trend analysis
clienti_facturat_luna_anterioara: Optional[Decimal] = Decimal('0')
furnizori_facturat_luna_anterioara: Optional[Decimal] = Decimal('0')
clienti_facturat_an_curent: Optional[Decimal] = Decimal('0')
clienti_facturat_an_anterior: Optional[Decimal] = Decimal('0')
furnizori_facturat_an_curent: Optional[Decimal] = Decimal('0')
furnizori_facturat_an_anterior: Optional[Decimal] = Decimal('0')
# SOLDURI TVA
tva_plata_precedent: Decimal = Decimal('0')
tva_recuperat_precedent: Decimal = Decimal('0')
tva_plata_curent: Decimal = Decimal('0')
tva_recuperat_curent: Decimal = Decimal('0')

View File

@@ -0,0 +1,79 @@
"""
Modele Pydantic pentru facturi - Compatibile cu aplicația Flask existentă
"""
from pydantic import BaseModel, Field, validator
from datetime import date
from typing import Optional, List, Literal
from decimal import Decimal
class InvoiceBase(BaseModel):
"""Model de bază pentru factură - mapează exact pe rezultatul query-ului Flask"""
nume: str = Field(description="Numele partenerului")
nract: int = Field(description="Numărul actului")
dataact: Optional[date] = Field(description="Data actului")
datascad: Optional[date] = Field(description="Data scadentă")
contract: Optional[str] = Field(description="Numărul contractului")
cod_fiscal: Optional[str] = Field(description="Codul fiscal")
reg_comert: Optional[str] = Field(description="Registrul comerțului")
cont: Optional[str] = Field(description="Contul contabil")
valuta: str = Field(default="RON", description="Valuta (RON, EUR, USD, etc.)")
class Invoice(InvoiceBase):
"""Model complet pentru factură cu calcule financiare"""
totctva: Decimal = Field(description="Total cu TVA", decimal_places=2)
achitat: Decimal = Field(description="Suma achitată", decimal_places=2)
soldfinal: Decimal = Field(description="Soldul final", decimal_places=2)
css_class: Literal["", "invoice-paid", "invoice-overdue"] = Field(
default="", description="Clasa CSS pentru stilizare"
)
@validator('css_class', always=True)
def determine_css_class(cls, v, values):
"""Determină automat clasa CSS bazată pe status factură"""
if 'soldfinal' in values and 'datascad' in values:
sold = values['soldfinal']
data_scad = values['datascad']
if sold < 1:
return 'invoice-paid'
elif data_scad and data_scad < date.today() and sold != 0:
return 'invoice-overdue'
return ''
class InvoiceFilter(BaseModel):
"""Filtru pentru căutarea facturilor"""
company: str = Field(description="Codul firmei (schema Oracle)")
partner_type: Literal["CLIENTI", "FURNIZORI"] = Field(description="Tipul partenerului")
luna: Optional[int] = Field(default=None, ge=1, le=12, description="Luna contabilă (1-12)")
an: Optional[int] = Field(default=None, ge=2000, le=2100, description="Anul contabil")
partner_name: Optional[str] = Field(description="Filtru după nume")
cont: Optional[str] = Field(description="Filtru după cont contabil")
only_unpaid: bool = Field(default=True, description="Doar neachitate")
min_amount: Optional[Decimal] = Field(description="Suma minimă")
max_amount: Optional[Decimal] = Field(description="Suma maximă")
page: int = Field(default=1, ge=1, description="Pagina")
page_size: int = Field(default=50, ge=1, le=10000000, description="Mărimea paginii")
class InvoiceListResponse(BaseModel):
"""Răspuns pentru lista de facturi"""
invoices: List[Invoice]
total_count: int
filtered_count: int
total_amount: Decimal
page: int
page_size: int
has_more: bool
accounting_period: Optional[dict] = Field(default=None, description="Perioada contabilă (an, luna)")
# Total sold din TOATE facturile filtrate (nu doar pagina curentă)
total_sold_all: Decimal = Field(default=Decimal('0.00'), description="Total sold din toate facturile filtrate")
class InvoiceSummary(BaseModel):
"""Rezumat pentru facturi - pentru dashboard"""
company: str
partner_type: str
total_invoices: int
total_amount: Decimal
paid_amount: Decimal
outstanding_amount: Decimal
overdue_amount: Decimal
overdue_count: int

View File

@@ -0,0 +1,52 @@
from pydantic import BaseModel
from decimal import Decimal
from datetime import datetime
from typing import Optional, List
class AccountingPeriod(BaseModel):
"""Model pentru perioada contabilă"""
an: Optional[int] = None
luna: Optional[int] = None
class BankCashRegister(BaseModel):
"""Model pentru Registrul de Casă și Bancă"""
nume: str
nract: Optional[int] = None
dataact: Optional[datetime] = None
nume_cont_bancar: str # din vbalanta_parteneri.nume
incasari: Decimal
plati: Decimal
sold: Decimal
valuta: Optional[str] = None
tip_registru: str # "BANCA LEI", "CASA VALUTA" etc
explicatia: str
class RegisterFilter(BaseModel):
"""Filtre pentru registrul de casă și bancă"""
company: str
register_type: Optional[str] = None # BANCA_LEI, BANCA_VALUTA, CASA_LEI, CASA_VALUTA sau None pentru toate
luna: Optional[int] = None # Luna contabilă (1-12) pentru PACK_SESIUNE
an: Optional[int] = None # Anul contabil pentru PACK_SESIUNE
date_from: Optional[datetime] = None
date_to: Optional[datetime] = None
partner_name: Optional[str] = None
bank_account: Optional[str] = None # Filter for specific bank/cash account (bancasa)
page: int = 1
page_size: int = 50
class RegisterListResponse(BaseModel):
"""Răspuns pentru lista din registru"""
registers: List[BankCashRegister]
total_count: int
filtered_count: int
total_incasari: Decimal
total_plati: Decimal
page: int
page_size: int
has_more: bool
accounting_period: Optional[AccountingPeriod] = None
# Totaluri din TOATE înregistrările filtrate (nu doar pagina curentă)
sold_precedent_all: Decimal = Decimal('0.00')
total_incasari_all: Decimal = Decimal('0.00')
total_plati_all: Decimal = Decimal('0.00')
sold_final_all: Decimal = Decimal('0.00')

View File

@@ -0,0 +1,102 @@
"""
Pydantic models for Trial Balance (Balanță de Verificare)
Maps to Oracle VBAL VIEW (exists in each company schema)
"""
from pydantic import BaseModel, Field
from typing import Optional, List
from decimal import Decimal
class TrialBalanceItem(BaseModel):
"""
Individual trial balance record from VBAL VIEW
Real structure from Oracle:
- CONT: account number
- DENUMIRE: account description
- PRECDEB/PRECCRED: previous balance debit/credit
- RULDEB/RULCRED: monthly movement debit/credit
- SOLDDEB/SOLDCRED: final balance debit/credit
"""
cont: str = Field(description="Număr cont contabil (CONT)")
denumire: Optional[str] = Field(default="", description="Denumire cont (DENUMIRE)")
sold_precedent_debit: Decimal = Field(description="Sold precedent debit (PRECDEB)", decimal_places=2)
sold_precedent_credit: Decimal = Field(description="Sold precedent credit (PRECCRED)", decimal_places=2)
rulaj_lunar_debit: Decimal = Field(description="Rulaj lunar debit (RULDEB)", decimal_places=2)
rulaj_lunar_credit: Decimal = Field(description="Rulaj lunar credit (RULCRED)", decimal_places=2)
sold_final_debit: Decimal = Field(description="Sold final debit (SOLDDEB)", decimal_places=2)
sold_final_credit: Decimal = Field(description="Sold final credit (SOLDCRED)", decimal_places=2)
class Config:
from_attributes = True
class TrialBalanceFilters(BaseModel):
"""
Filters applied to trial balance data
"""
luna: int = Field(description="Luna (1-12)")
an: int = Field(description="An")
cont_filter: Optional[str] = Field(default=None, description="Filtru număr cont (partial match)")
denumire_filter: Optional[str] = Field(default=None, description="Filtru denumire cont (partial match, case-insensitive)")
class TrialBalancePagination(BaseModel):
"""
Pagination metadata
"""
total_items: int = Field(description="Total number of items")
total_pages: int = Field(description="Total number of pages")
current_page: int = Field(description="Current page number")
page_size: int = Field(description="Items per page")
class TrialBalanceTotals(BaseModel):
"""
Totals for all 6 columns from all filtered records (not just current page)
"""
total_sold_precedent_debit: Decimal = Decimal('0.00')
total_sold_precedent_credit: Decimal = Decimal('0.00')
total_rulaj_lunar_debit: Decimal = Decimal('0.00')
total_rulaj_lunar_credit: Decimal = Decimal('0.00')
total_sold_final_debit: Decimal = Decimal('0.00')
total_sold_final_credit: Decimal = Decimal('0.00')
class TrialBalanceResponse(BaseModel):
"""
Complete response for trial balance endpoint
"""
success: bool = Field(default=True, description="Request success status")
data: dict = Field(description="Trial balance data with items, pagination, and filters")
class Config:
json_schema_extra = {
"example": {
"success": True,
"data": {
"items": [
{
"cont": "4111",
"dcont": "Furnizori interni",
"sold_precedent_debit": 0.00,
"sold_precedent_credit": 15000.00,
"rulaj_lunar_debit": 5000.00,
"rulaj_lunar_credit": 8000.00,
"sold_final_debit": 0.00,
"sold_final_credit": 18000.00
}
],
"pagination": {
"total_items": 150,
"total_pages": 3,
"current_page": 1,
"page_size": 50
},
"filters_applied": {
"luna": 11,
"an": 2025,
"cont_filter": None,
"denumire_filter": "furnizori"
}
}
}
}

View File

@@ -0,0 +1,36 @@
"""Reports module router factory."""
from fastapi import APIRouter
def create_reports_router() -> APIRouter:
"""
Create and configure Reports module router.
Includes all report-related endpoints:
- /invoices - Invoice management
- /dashboard - Dashboard and metrics
- /treasury - Treasury operations
- /trial-balance - Trial balance reports
- /cache - Cache management
Returns:
APIRouter: Configured router for reports module
"""
router = APIRouter()
# Import routers here to avoid circular imports
from .invoices import router as invoices_router
from .dashboard import router as dashboard_router
from .treasury import router as treasury_router
from .trial_balance import router as trial_balance_router
from .cache import router as cache_router
# Include all sub-routers (no prefix - already prefixed in main.py with /api/reports)
router.include_router(invoices_router, prefix="/invoices", tags=["reports-invoices"])
router.include_router(dashboard_router, prefix="/dashboard", tags=["reports-dashboard"])
router.include_router(treasury_router, prefix="/treasury", tags=["reports-treasury"])
router.include_router(trial_balance_router, prefix="/trial-balance", tags=["reports-trial-balance"])
router.include_router(cache_router, prefix="/cache", tags=["reports-cache"])
return router

View File

@@ -0,0 +1,398 @@
"""
API Router pentru managementul cache-ului
"""
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel
from typing import Optional, Dict, Any
# import sys # Removed - no longer needed
import os
import time
from datetime import datetime, timedelta
from shared.auth.dependencies import get_current_user
from shared.auth.models import CurrentUser
from ..cache import get_cache, get_event_monitor, toggle_event_monitor
router = APIRouter(prefix="/cache", tags=["cache"])
# Pydantic Models
class CacheStatsResponse(BaseModel):
"""Răspuns statistici cache"""
enabled: bool
global_enabled: bool
user_enabled: bool
cache_type: str
hit_rate: float
total_hits: int
total_misses: int
queries_saved: Dict[str, int]
response_times: Dict[str, Dict[str, Any]]
cache_size: Dict[str, int]
auto_invalidate: bool
last_cleanup: Optional[str] = None
class InvalidateCacheRequest(BaseModel):
"""Request pentru invalidare cache"""
company_id: Optional[int] = None
cache_type: Optional[str] = None
class ToggleUserCacheRequest(BaseModel):
"""Request pentru toggle cache per-user"""
enabled: bool
class ToggleGlobalCacheRequest(BaseModel):
"""Request pentru toggle cache global"""
enabled: bool
class ToggleAutoInvalidateRequest(BaseModel):
"""Request pentru toggle auto-invalidation"""
enabled: bool
# Helper Functions
async def _calculate_cache_stats() -> Dict[str, Any]:
"""Calculate comprehensive cache statistics"""
cache = get_cache()
if not cache:
raise HTTPException(status_code=503, detail="Cache not initialized")
# Get basic cache stats
stats = await cache.get_stats()
# Calculate hit rate
memory_stats = stats.get('memory', {})
total_hits = memory_stats.get('hits', 0)
total_misses = memory_stats.get('misses', 0)
total_requests = total_hits + total_misses
hit_rate = (total_hits / total_requests * 100) if total_requests > 0 else 0
# Calculate queries saved (from performance_log)
queries_saved = await _calculate_queries_saved(cache)
# Calculate response times per cache type
response_times = await _calculate_response_times(cache)
# Get cache sizes
cache_size = {
'memory': memory_stats.get('size', 0),
'sqlite': stats.get('sqlite', {}).get('active_entries', 0)
}
# Get event monitor status
monitor = get_event_monitor()
auto_invalidate = monitor.running if monitor else False
return {
'enabled': cache.config.enabled,
'global_enabled': cache.config.enabled,
'cache_type': cache.config.cache_type,
'hit_rate': round(hit_rate, 1),
'total_hits': total_hits,
'total_misses': total_misses,
'queries_saved': queries_saved,
'response_times': response_times,
'cache_size': cache_size,
'auto_invalidate': auto_invalidate,
'last_cleanup': None # TODO: track last cleanup time
}
async def _calculate_queries_saved(cache) -> Dict[str, int]:
"""Calculate queries saved by time period"""
import aiosqlite
try:
async with aiosqlite.connect(cache.sqlite.db_path) as db:
now = time.time()
today_start = now - 86400 # 24 hours
week_start = now - 604800 # 7 days
# Today
async with db.execute("""
SELECT COUNT(*) FROM performance_log
WHERE cache_hit = 1 AND timestamp >= ?
""", (today_start,)) as cursor:
today = (await cursor.fetchone())[0]
# This week
async with db.execute("""
SELECT COUNT(*) FROM performance_log
WHERE cache_hit = 1 AND timestamp >= ?
""", (week_start,)) as cursor:
week = (await cursor.fetchone())[0]
# All time
async with db.execute("""
SELECT COUNT(*) FROM performance_log
WHERE cache_hit = 1
""") as cursor:
total = (await cursor.fetchone())[0]
return {
'today': today,
'week': week,
'total': total
}
except Exception as e:
return {'today': 0, 'week': 0, 'total': 0}
async def _calculate_response_times(cache) -> Dict[str, Dict[str, Any]]:
"""Calculate average response times per cache type"""
import aiosqlite
try:
async with aiosqlite.connect(cache.sqlite.db_path) as db:
# Get average times per cache type
async with db.execute("""
SELECT
cache_type,
AVG(CASE WHEN cache_hit = 1 THEN response_time_ms ELSE NULL END) as avg_cached,
AVG(CASE WHEN cache_hit = 0 THEN response_time_ms ELSE NULL END) as avg_oracle
FROM performance_log
WHERE timestamp >= ?
GROUP BY cache_type
""", (time.time() - 86400,)) as cursor: # Last 24 hours
results = await cursor.fetchall()
response_times = {}
for row in results:
cache_type, avg_cached, avg_oracle = row
if avg_cached and avg_oracle:
improvement = int((avg_oracle - avg_cached) / avg_oracle * 100)
response_times[cache_type] = {
'cached': int(avg_cached),
'oracle': int(avg_oracle),
'improvement': improvement
}
return response_times
except Exception as e:
return {}
# API Endpoints
@router.get("/stats", response_model=CacheStatsResponse)
async def get_cache_stats(
current_user: CurrentUser = Depends(get_current_user)
):
"""
Obține statistici complete cache
Returns:
- Hit rate, queries saved, response times
- Cache sizes (memory + SQLite)
- Auto-invalidation status
- Per-user cache setting
"""
try:
cache = get_cache()
if not cache:
raise HTTPException(status_code=503, detail="Cache not initialized")
# Get base stats
stats = await _calculate_cache_stats()
# Add user-specific setting
user_enabled = await cache.is_enabled_for_user(current_user.username)
stats['user_enabled'] = user_enabled
return CacheStatsResponse(**stats)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error retrieving cache stats: {str(e)}")
@router.post("/invalidate")
async def invalidate_cache(
request: InvalidateCacheRequest,
current_user: CurrentUser = Depends(get_current_user)
):
"""
Invalidează cache
Args:
company_id: Opțional - invalidează doar pentru această companie
cache_type: Opțional - invalidează doar acest tip de cache
Returns:
Message de confirmare
"""
try:
cache = get_cache()
if not cache:
raise HTTPException(status_code=503, detail="Cache not initialized")
await cache.invalidate(
company_id=request.company_id,
cache_type=request.cache_type
)
if request.company_id and request.cache_type:
message = f"Cache invalidated for company {request.company_id}, type {request.cache_type}"
elif request.company_id:
message = f"Cache invalidated for company {request.company_id}"
elif request.cache_type:
message = f"Cache invalidated for type {request.cache_type}"
else:
message = "All cache invalidated"
return {
"success": True,
"message": message,
"invalidated_at": datetime.now().isoformat()
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error invalidating cache: {str(e)}")
@router.post("/toggle-user")
async def toggle_user_cache(
request: ToggleUserCacheRequest,
current_user: CurrentUser = Depends(get_current_user)
):
"""
Toggle cache per-user
Permite utilizatorului să activeze/dezactiveze cache-ul pentru el
Folosit pentru A/B testing și comparații de performanță
Args:
enabled: True pentru activare, False pentru dezactivare
Returns:
Noul status
"""
try:
cache = get_cache()
if not cache:
raise HTTPException(status_code=503, detail="Cache not initialized")
await cache.set_user_cache_enabled(current_user.username, request.enabled)
return {
"success": True,
"username": current_user.username,
"cache_enabled": request.enabled,
"message": f"Cache {'enabled' if request.enabled else 'disabled'} for user {current_user.username}"
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error toggling user cache: {str(e)}")
@router.post("/toggle-global")
async def toggle_global_cache(
request: ToggleGlobalCacheRequest,
current_user: CurrentUser = Depends(get_current_user)
):
"""
Toggle cache global (ADMIN only)
Activează/dezactivează cache-ul la nivel global pentru toți utilizatorii
Args:
enabled: True pentru activare, False pentru dezactivare
Returns:
Noul status global
"""
try:
# TODO: Add admin permission check
# For now, allow any authenticated user
cache = get_cache()
if not cache:
raise HTTPException(status_code=503, detail="Cache not initialized")
# Update config (NOTE: This is runtime only, .env needs manual update)
cache.config.enabled = request.enabled
return {
"success": True,
"global_enabled": request.enabled,
"message": f"Cache {'enabled' if request.enabled else 'disabled'} globally",
"note": "This change is runtime only. Update .env CACHE_ENABLED for persistence."
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error toggling global cache: {str(e)}")
@router.post("/toggle-auto-invalidate")
async def toggle_auto_invalidation(
request: ToggleAutoInvalidateRequest,
current_user: CurrentUser = Depends(get_current_user)
):
"""
Toggle auto-invalidation monitoring
Activează/dezactivează monitorizarea automată a {schema}.act
pentru invalidarea cache-ului când se detectează modificări
Args:
enabled: True pentru activare, False pentru dezactivare
Returns:
Noul status auto-invalidation
"""
try:
# TODO: Add admin permission check
# For now, allow any authenticated user
await toggle_event_monitor(request.enabled)
return {
"success": True,
"auto_invalidate_enabled": request.enabled,
"message": f"Auto-invalidation {'enabled' if request.enabled else 'disabled'}",
"note": "Monitors max(id_act) in {schema}.act tables for changes"
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error toggling auto-invalidation: {str(e)}")
@router.get("/health")
async def cache_health():
"""
Health check pentru sistemul de cache
Returns:
Status cache, mărime, și uptime
"""
try:
cache = get_cache()
if not cache:
return {
"status": "not_initialized",
"enabled": False
}
stats = await cache.get_stats()
monitor = get_event_monitor()
return {
"status": "healthy",
"enabled": cache.config.enabled,
"cache_type": cache.config.cache_type,
"memory_size": stats.get('memory', {}).get('size', 0),
"sqlite_size": stats.get('sqlite', {}).get('active_entries', 0),
"auto_invalidate_running": monitor.running if monitor else False
}
except Exception as e:
return {
"status": "error",
"error": str(e)
}

View File

@@ -0,0 +1,430 @@
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from typing import Optional
# import sys # Removed - no longer needed
import os
from shared.auth.dependencies import get_current_user
from shared.auth.models import CurrentUser
import logging
logger = logging.getLogger(__name__)
from ..models.dashboard import DashboardSummary, TrendsResponse, TrendData
from ..services.dashboard_service import DashboardService
router = APIRouter()
@router.get("/summary")
async def get_dashboard_summary(
request: Request,
company: str = Query(description="Codul firmei"),
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Obține toate datele pentru dashboard într-un singur apel
- Necesită autentificare JWT
- Returnează statistici clienți/furnizori și trezorerie
- Include metadata cache pentru Telegram Bot (X-Include-Cache-Metadata header)
- Suportă filtrare pe luna/an contabil (dacă nu sunt specificate, folosește ultima perioadă)
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if company not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
result = await DashboardService.get_complete_summary(company, current_user.username, luna=luna, an=an, request=request)
# Convert Pydantic model to dict for JSON serialization
result_dict = result.dict() if hasattr(result, 'dict') else result
# Add cache metadata if requested (for Telegram Bot)
include_metadata = request.headers.get('X-Include-Cache-Metadata', '').lower() == 'true'
if include_metadata:
cache_hit = getattr(request.state, 'cache_hit', False)
response_time = getattr(request.state, 'response_time_ms', 0)
cache_source = getattr(request.state, 'cache_source', None)
result_dict['cache_hit'] = cache_hit
result_dict['response_time_ms'] = response_time
# Always include cache_source, even if None
result_dict['cache_source'] = cache_source
return result_dict
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Eroare la obținerea datelor dashboard: {str(e)}")
@router.get("/trends", response_model=TrendsResponse)
async def get_dashboard_trends(
request: Request,
company: str = Query(description="Codul firmei"),
period: str = Query(default="30d", description="Perioada pentru trends: 7d, 30d, ytd, 12m"),
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
compare_previous: bool = Query(default=True, description="Compară cu perioada anterioară"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Obține trenduri pentru indicatorii principali (clienți/furnizori)
- period: "7d" (7 zile), "30d" (30 zile), "ytd" (year to date), "12m" (12 luni)
- luna/an: perioada contabilă de referință (dacă nu sunt specificate, folosește ultima perioadă)
- compare_previous: dacă să compare cu perioada anterioară
- Necesită autentificare JWT
- Returnează date pentru grafice de trenduri
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if company not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
# Validează perioada
valid_periods = ["7d", "30d", "ytd", "12m"]
if period not in valid_periods:
raise HTTPException(
status_code=400,
detail=f"Perioadă nevalidă: {period}. Valori permise: {', '.join(valid_periods)}"
)
# Obține datele de trenduri
result = await DashboardService.get_trends(int(company), period, luna=luna, an=an, request=request)
# Convert to dict if needed
result_dict = result.dict() if hasattr(result, 'dict') else result
# Add cache metadata if requested (for Telegram Bot)
include_metadata = request.headers.get('X-Include-Cache-Metadata', '').lower() == 'true'
if include_metadata:
cache_hit = getattr(request.state, 'cache_hit', False)
response_time = getattr(request.state, 'response_time_ms', 0)
cache_source = getattr(request.state, 'cache_source', None)
result_dict['cache_hit'] = cache_hit
result_dict['response_time_ms'] = response_time
# Always include cache_source, even if None
result_dict['cache_source'] = cache_source
# Return as TrendsResponse
return TrendsResponse(**result_dict)
except ValueError as e:
logger.error(f"Value error in trends endpoint: {str(e)}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Eroare la obținerea trendurilor: {str(e)}")
raise HTTPException(status_code=500, detail=f"Eroare la obținerea trendurilor: {str(e)}")
@router.get("/detailed-data")
async def get_detailed_data(
company: str = Query(description="Codul firmei"),
data_type: str = Query(description="Tipul de date: clients, suppliers, treasury"),
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
page: int = Query(default=1, ge=1),
page_size: int = Query(default=25, ge=1, le=100),
search: str = Query(default="", description="Termen de căutare"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Obține date detaliate pentru tabelele din dashboard
"""
logger.info(f"[ROUTER] detailed-data called: company={company}, data_type={data_type}")
try:
if company not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
logger.info(f"[ROUTER] Calling DashboardService.get_detailed_data")
result = await DashboardService.get_detailed_data(
company=company,
data_type=data_type,
luna=luna,
an=an,
page=page,
page_size=page_size,
search=search
)
logger.info(f"[ROUTER] Service returned: {len(result.get('data', []))} rows")
return result
except Exception as e:
logger.error(f"Eroare la obținerea datelor detaliate: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/performance")
async def get_performance(
company: int = Query(..., description="ID-ul firmei"),
period: str = Query("7d", regex="^(7d|1m|3m|6m|ytd|12m)$", description="Perioada pentru analiză"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Returnează date performanță pentru perioada selectată
- Necesită autentificare JWT
- Returnează grafice încasări vs plăți pentru perioada selectată
- Calculează indicatori: rata încasării, cash conversion, working capital
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if str(company) not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
result = await DashboardService.get_performance_data(company, period)
# Convert to Chart.js compatible format
return {
"labels": result.get("labels", []),
"datasets": [{
"data": result.get("data", []),
"label": result.get("label", "Performance"),
"borderColor": result.get("borderColor", "#3B82F6"),
"backgroundColor": result.get("backgroundColor", "rgba(59, 130, 246, 0.1)"),
"tension": 0.4
}]
}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Eroare la obținerea datelor de performanță: {str(e)}")
raise HTTPException(status_code=500, detail=f"Eroare la obținerea datelor de performanță: {str(e)}")
@router.get("/cashflow")
async def get_cashflow(
company: int = Query(..., description="ID-ul firmei"),
period: str = Query("7d", regex="^(7d|1m|3m|6m)$", description="Perioada pentru previziune"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Returnează previziune cash flow pentru perioada selectată
- Necesită autentificare JWT
- Analizează scadențele viitoare pentru calculul cash flow-ului
- Identifică zilele critice cu deficit de cash
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if str(company) not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
result = await DashboardService.get_cashflow_forecast(company, period)
# Convert to Chart.js compatible format
return {
"labels": result.get("labels", []),
"datasets": [{
"data": result.get("data", []),
"label": result.get("label", "Cash Flow"),
"borderColor": result.get("borderColor", "#10B981"),
"backgroundColor": result.get("backgroundColor", "rgba(16, 185, 129, 0.1)"),
"tension": 0.4
}]
}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Eroare la obținerea previziunii cash flow: {str(e)}")
raise HTTPException(status_code=500, detail=f"Eroare la obținerea previziunii cash flow: {str(e)}")
@router.get("/maturity")
async def get_maturity_analysis(
request: Request,
company: int = Query(..., description="ID-ul firmei"),
period: str = Query("7d", regex="^(7d|1m|3m|6m|12m|all)$", description="Orizont de planificare pentru analiza scadențelor"),
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Returnează analiza scadențelor pentru orizontul de planificare selectat
- Necesită autentificare JWT
- Logică: Include TOATE restanțele + scadențele viitoare din perioada selectată
- luna/an: perioada contabilă de referință (dacă nu sunt specificate, folosește ultima perioadă)
- Perioade disponibile:
* 7d: Toate restanțele + scadențe următoarelor 7 zile
* 1m: Toate restanțele + scadențe următoarelor 30 zile
* 3m: Toate restanțele + scadențe următoarelor 90 zile
* 6m: Toate restanțele + scadențe următoarelor 180 zile
* 12m: Toate restanțele + scadențe următoarelor 365 zile
* all: Toate soldurile (fără filtru)
- Compară scadențele clienți vs furnizori
- Calculează balanța și oferă recomandări
- Returnează metadate cu statistici complete
- Include metadata cache pentru Telegram Bot (X-Include-Cache-Metadata header)
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if str(company) not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
result = await DashboardService.get_maturity_analysis(company, period, luna=luna, an=an, request=request)
# Convert to dict if needed
result_dict = result.dict() if hasattr(result, 'dict') else result
# Add cache metadata if requested (for Telegram Bot)
include_metadata = request.headers.get('X-Include-Cache-Metadata', '').lower() == 'true'
if include_metadata:
cache_hit = getattr(request.state, 'cache_hit', False)
response_time = getattr(request.state, 'response_time_ms', 0)
cache_source = getattr(request.state, 'cache_source', None)
result_dict['cache_hit'] = cache_hit
result_dict['response_time_ms'] = response_time
# Always include cache_source, even if None
result_dict['cache_source'] = cache_source
return result_dict
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Eroare la obținerea analizei scadențelor: {str(e)}")
raise HTTPException(status_code=500, detail=f"Eroare la obținerea analizei scadențelor: {str(e)}")
@router.get("/monthly-flows")
async def get_monthly_flows(
company: int = Query(..., description="ID-ul firmei"),
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Returnează fluxurile lunare pentru firma selectată
- Necesită autentificare JWT
- Returnează date pentru analiza fluxurilor lunare
- luna/an: perioada contabilă de referință (dacă nu sunt specificate, folosește ultima perioadă)
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if str(company) not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
result = await DashboardService.get_monthly_flows(company, luna=luna, an=an)
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Eroare la obținerea fluxurilor lunare: {str(e)}")
raise HTTPException(status_code=500, detail=f"Eroare la obținerea fluxurilor lunare: {str(e)}")
@router.get("/treasury-breakdown")
async def get_treasury_breakdown(
request: Request,
company: int = Query(..., description="ID-ul firmei"),
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Returnează defalcarea trezoreriei pentru firma selectată
- Necesită autentificare JWT
- Returnează distribuția soldurilor pe conturi și tipuri
- luna/an: perioada contabilă de referință (dacă nu sunt specificate, folosește ultima perioadă)
- Include metadata cache pentru Telegram Bot (X-Include-Cache-Metadata header)
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if str(company) not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
result = await DashboardService.get_treasury_breakdown(company, luna=luna, an=an, request=request)
# Convert to dict if needed
result_dict = result.dict() if hasattr(result, 'dict') else result
# Add cache metadata if requested (for Telegram Bot)
include_metadata = request.headers.get('X-Include-Cache-Metadata', '').lower() == 'true'
if include_metadata:
cache_hit = getattr(request.state, 'cache_hit', False)
response_time = getattr(request.state, 'response_time_ms', 0)
cache_source = getattr(request.state, 'cache_source', None)
result_dict['cache_hit'] = cache_hit
result_dict['response_time_ms'] = response_time
# Always include cache_source, even if None
result_dict['cache_source'] = cache_source
return result_dict
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Eroare la obținerea defalcării trezoreriei: {str(e)}")
raise HTTPException(status_code=500, detail=f"Eroare la obținerea defalcării trezoreriei: {str(e)}")
@router.get("/net-balance-breakdown")
async def get_net_balance_breakdown(
request: Request,
company: int = Query(..., description="ID-ul firmei"),
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Returnează defalcarea balanței nete pentru firma selectată
- Necesită autentificare JWT
- Returnează analiza detaliată a balanței nete
- luna/an: perioada contabilă de referință (dacă nu sunt specificate, folosește ultima perioadă)
- Include metadata cache pentru Telegram Bot (X-Include-Cache-Metadata header)
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if str(company) not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
result = await DashboardService.get_net_balance_breakdown(company, luna=luna, an=an, request=request)
# Convert to dict if needed
result_dict = result.dict() if hasattr(result, 'dict') else result
# Add cache metadata if requested (for Telegram Bot)
include_metadata = request.headers.get('X-Include-Cache-Metadata', '').lower() == 'true'
if include_metadata:
cache_hit = getattr(request.state, 'cache_hit', False)
response_time = getattr(request.state, 'response_time_ms', 0)
cache_source = getattr(request.state, 'cache_source', None)
result_dict['cache_hit'] = cache_hit
result_dict['response_time_ms'] = response_time
# Always include cache_source, even if None
result_dict['cache_source'] = cache_source
return result_dict
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Eroare la obținerea defalcării balanței nete: {str(e)}")
raise HTTPException(status_code=500, detail=f"Eroare la obținerea defalcării balanței nete: {str(e)}")
@router.get("/current-period")
async def get_current_period(
company: int = Query(..., description="ID-ul firmei"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Returnează perioada curentă (an și lună) din calendarul Oracle
- Necesită autentificare JWT
- Returnează anul, luna și perioada curentă în format YYYY-MM
- Folosit pentru afișarea lunii curente în dashboard
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if str(company) not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
result = await DashboardService.get_current_period(company)
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Eroare la obținerea perioadei curente: {str(e)}")
raise HTTPException(status_code=500, detail=f"Eroare la obținerea perioadei curente: {str(e)}")

View File

@@ -0,0 +1,128 @@
"""
API Router pentru facturi
"""
from fastapi import APIRouter, Depends, HTTPException, Query
from typing import List, Optional
from datetime import date
# import sys # Removed - no longer needed
import os
from shared.auth.dependencies import get_current_user, require_company_access
from shared.auth.models import CurrentUser
from ..models.invoice import InvoiceFilter, InvoiceListResponse, InvoiceSummary
from ..services.invoice_service import InvoiceService
router = APIRouter()
@router.get("/", response_model=InvoiceListResponse)
async def get_invoices(
company: str = Query(description="Codul firmei"),
partner_type: str = Query("CLIENTI", description="CLIENTI sau FURNIZORI"),
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
partner_name: Optional[str] = Query(None, description="Filtru nume partener"),
cont: Optional[str] = Query(None, description="Filtru după cont contabil"),
only_unpaid: bool = Query(True, description="Doar facturile neachitate"),
min_amount: Optional[float] = Query(None, description="Suma minimă"),
max_amount: Optional[float] = Query(None, description="Suma maximă"),
page: int = Query(1, ge=1, description="Pagina"),
page_size: int = Query(50, ge=1, le=10000000, description="Mărimea paginii"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Obține lista de facturi pentru o firmă
- Necesită autentificare JWT
- Utilizatorul trebuie să aibă acces la firma specificată
- Suportă filtrare după luna/an contabil și paginare
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if company not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
filter_params = InvoiceFilter(
company=company,
partner_type=partner_type,
luna=luna,
an=an,
partner_name=partner_name,
cont=cont,
only_unpaid=only_unpaid,
min_amount=min_amount,
max_amount=max_amount,
page=page,
page_size=page_size
)
result = await InvoiceService.get_invoices(filter_params, current_user.username)
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Eroare la obținerea facturilor: {str(e)}")
@router.get("/summary", response_model=InvoiceSummary)
async def get_invoices_summary(
company: str = Query(description="Codul firmei"),
partner_type: str = Query("CLIENTI", description="CLIENTI sau FURNIZORI"),
current_user: CurrentUser = Depends(get_current_user)
):
"""Obține rezumatul facturilor pentru dashboard"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if company not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
result = await InvoiceService.get_invoice_summary(company, partner_type, current_user.username)
return result
except Exception as e:
raise HTTPException(status_code=500, detail=f"Eroare la obținerea rezumatului facturilor: {str(e)}")
@router.get("/{invoice_number}")
async def get_invoice_details(
invoice_number: str,
company: str = Query(description="Codul firmei"),
current_user: CurrentUser = Depends(get_current_user)
):
"""Obține detaliile unei facturi specifice"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if company not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
result = await InvoiceService.get_invoice_details(company, invoice_number, current_user.username)
return result
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Eroare la obținerea detaliilor facturii: {str(e)}")
@router.get("/export/{format}")
async def export_invoices(
format: str,
company: str = Query(description="Codul firmei"),
partner_type: str = Query("CLIENTI", description="CLIENTI sau FURNIZORI"),
date_from: Optional[str] = Query(None, description="Data început (YYYY-MM-DD)"),
date_to: Optional[str] = Query(None, description="Data sfârșit (YYYY-MM-DD)"),
partner_name: Optional[str] = Query(None, description="Filtru nume partener"),
only_unpaid: bool = Query(True, description="Doar facturile neachitate"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Export facturi în format specificat (excel, pdf, csv)
Această funcție va fi implementată în viitor
"""
# Verifică dacă utilizatorul are acces la firma specificată
if company not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
# Verifică formatul
if format not in ["excel", "pdf", "csv"]:
raise HTTPException(status_code=400, detail="Format invalid. Formatele suportate sunt: excel, pdf, csv")
# Pentru moment, returnează o eroare că funcția nu este implementată
raise HTTPException(status_code=501, detail=f"Export în format {format} nu este încă implementat")

View File

@@ -0,0 +1,117 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from typing import Optional, List
from datetime import date
# import sys # Removed - no longer needed
import os
from shared.auth.dependencies import get_current_user
from shared.auth.models import CurrentUser
from ..models.treasury import RegisterFilter, RegisterListResponse
from ..services.treasury_service import TreasuryService
router = APIRouter()
@router.get("/bank-cash-register", response_model=RegisterListResponse)
async def get_bank_cash_register(
company: str = Query(description="Codul firmei"),
register_type: Optional[str] = Query(None, description="Tipul registrului: BANCA_LEI, BANCA_VALUTA, CASA_LEI, CASA_VALUTA sau None pentru toate"),
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna contabilă (1-12)"),
an: Optional[int] = Query(None, ge=2000, le=2100, description="Anul contabil"),
date_from: Optional[str] = Query(None, description="Data început (YYYY-MM-DD)"),
date_to: Optional[str] = Query(None, description="Data sfârșit (YYYY-MM-DD)"),
partner_name: Optional[str] = Query(None, description="Filtru nume partener"),
bank_account: Optional[str] = Query(None, description="Filtru cont bancă/casă (bancasa)"),
page: int = Query(1, ge=1, description="Pagina"),
page_size: int = Query(50, ge=1, le=10000000, description="Mărimea paginii"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Obține registrul de casă și bancă
- Necesită autentificare JWT
- Suportă filtrare pe tip registru: BANCA_LEI, BANCA_VALUTA, CASA_LEI, CASA_VALUTA
- Suportă filtrare și paginare
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if company not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
# Validează register_type dacă e specificat
valid_types = ['BANCA_LEI', 'BANCA_VALUTA', 'CASA_LEI', 'CASA_VALUTA']
if register_type and register_type not in valid_types:
raise HTTPException(
status_code=400,
detail=f"Tip registru invalid. Valori acceptate: {', '.join(valid_types)}"
)
# Convertește datele
date_from_obj = None
date_to_obj = None
if date_from:
try:
date_from_obj = date.fromisoformat(date_from)
except ValueError:
raise HTTPException(status_code=400, detail="Format dată început invalid")
if date_to:
try:
date_to_obj = date.fromisoformat(date_to)
except ValueError:
raise HTTPException(status_code=400, detail="Format dată sfârșit invalid")
filter_params = RegisterFilter(
company=company,
register_type=register_type,
luna=luna,
an=an,
date_from=date_from_obj,
date_to=date_to_obj,
partner_name=partner_name,
bank_account=bank_account,
page=page,
page_size=page_size
)
result = await TreasuryService.get_bank_cash_register(filter_params, current_user.username)
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Eroare la obținerea registrului: {str(e)}")
@router.get("/bank-cash-accounts", response_model=List[str])
async def get_bank_cash_accounts(
company: str = Query(description="Codul firmei"),
register_type: str = Query(description="Tipul registrului: BANCA_LEI, BANCA_VALUTA, CASA_LEI, CASA_VALUTA"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Obține lista distinctă de conturi bancă/casă pentru dropdown
- Necesită autentificare JWT
- Returnează lista de valori bancasa pentru tipul de registru selectat
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if company not in current_user.companies:
raise HTTPException(status_code=403, detail=f"Nu aveți acces la firma {company}")
# Validează register_type
valid_types = ['BANCA_LEI', 'BANCA_VALUTA', 'CASA_LEI', 'CASA_VALUTA']
if register_type not in valid_types:
raise HTTPException(
status_code=400,
detail=f"Tip registru invalid. Valori acceptate: {', '.join(valid_types)}"
)
result = await TreasuryService.get_bank_cash_accounts(int(company), register_type)
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Eroare la obținerea conturilor: {str(e)}")

View File

@@ -0,0 +1,90 @@
"""
API Router for Trial Balance (Balanță de Verificare)
Refactored to use service layer with caching
"""
from fastapi import APIRouter, Depends, HTTPException, Query
from typing import Optional
from datetime import date
# import sys # Removed - no longer needed
import os
from shared.auth.dependencies import get_current_user
from shared.auth.models import CurrentUser
from ..models.trial_balance import TrialBalanceResponse
from ..services.trial_balance_service import TrialBalanceService
import logging
logger = logging.getLogger(__name__)
router = APIRouter()
@router.get("/", response_model=TrialBalanceResponse)
async def get_trial_balance(
company: str = Query(description="Codul firmei (ID)"),
luna: Optional[int] = Query(None, ge=1, le=12, description="Luna (1-12), default: luna curentă"),
an: Optional[int] = Query(None, ge=2000, le=2100, description="An, default: anul curent"),
cont_filter: Optional[str] = Query(None, description="Filtru număr cont (ex: '512', '4111')"),
denumire_filter: Optional[str] = Query(None, description="Filtru denumire cont (partial match, case-insensitive)"),
sort_by: str = Query("CONT", description="Coloană pentru sortare"),
sort_order: str = Query("asc", description="Ordinea sortării (asc | desc)"),
page: int = Query(1, ge=1, description="Pagina"),
page_size: int = Query(50, ge=1, le=1000000, description="Mărimea paginii"),
current_user: CurrentUser = Depends(get_current_user)
):
"""
Obține balanța de verificare sintetică pentru o firmă
- Necesită autentificare JWT
- Utilizatorul trebuie să aibă acces la firma specificată
- Suportă filtrare după cont și denumire
- Suportă paginare și sortare
- **CACHED 10 min** - folosește sistem cache two-tier (L1 Memory + L2 SQLite)
"""
try:
# Verifică dacă utilizatorul are acces la firma specificată
if company not in current_user.companies:
raise HTTPException(
status_code=403,
detail=f"Nu aveți acces la firma {company}"
)
# Setează valorile implicite pentru lună și an (luna și anul curent)
current_date = date.today()
if luna is None:
luna = current_date.month
if an is None:
an = current_date.year
# Convert company to int
company_id = int(company)
# Call service (with caching) - all business logic moved to service
data = await TrialBalanceService.get_trial_balance(
company_id=company_id,
luna=luna,
an=an,
cont_filter=cont_filter,
denumire_filter=denumire_filter,
sort_by=sort_by,
sort_order=sort_order,
page=page,
page_size=page_size,
username=current_user.username
)
return TrialBalanceResponse(
success=True,
data=data
)
except ValueError as e:
# Schema not found or validation error
logger.error(f"Validation error in trial balance: {str(e)}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
# Log unexpected errors
logger.error(f"Error fetching trial balance: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500,
detail=f"Eroare la obținerea balanței de verificare: {str(e)}"
)

View File

@@ -0,0 +1,77 @@
"""
Calendar service for fetching available accounting periods
"""
# import sys # Removed - no longer needed
import os
from shared.database.oracle_pool import oracle_pool
from ..models.calendar import CalendarPeriod, CalendarPeriodsResponse
from ..cache.decorators import cached
import logging
logger = logging.getLogger(__name__)
class CalendarService:
"""Service for calendar/accounting period operations"""
# Romanian month names for display
MONTH_NAMES_RO = [
"Ianuarie", "Februarie", "Martie", "Aprilie", "Mai", "Iunie",
"Iulie", "August", "Septembrie", "Octombrie", "Noiembrie", "Decembrie"
]
@staticmethod
@cached(cache_type='schema', key_params=['company_id'])
async def _get_schema(company_id: int) -> str:
"""Get schema for company (CACHED 24h)"""
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute("""
SELECT schema FROM CONTAFIN_ORACLE.v_nom_firme
WHERE id_firma = :company_id
""", {'company_id': company_id})
result = cursor.fetchone()
return result[0] if result else None
@staticmethod
@cached(cache_type='calendar_periods', key_params=['company_id'])
async def get_available_periods(company_id: int) -> CalendarPeriodsResponse:
"""
Get all available accounting periods for a company (CACHED 1h)
Returns periods ordered by year DESC, month DESC with Romanian month names.
"""
schema = await CalendarService._get_schema(company_id)
if not schema:
logger.warning(f"Schema not found for company {company_id}")
return CalendarPeriodsResponse(periods=[], current_period=None, total_count=0)
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute(f"""
SELECT anul, luna
FROM {schema}.calendar
ORDER BY anul DESC, luna DESC
""")
rows = cursor.fetchall()
periods = []
for row in rows:
an, luna = row[0], row[1]
month_name = CalendarService.MONTH_NAMES_RO[luna - 1]
periods.append(CalendarPeriod(
an=an,
luna=luna,
display_name=f"{month_name} {an}"
))
current_period = periods[0] if periods else None
logger.info(f"Loaded {len(periods)} accounting periods for company {company_id}")
return CalendarPeriodsResponse(
periods=periods,
current_period=current_period,
total_count=len(periods)
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,324 @@
"""
Service pentru logica facturi - Portează query-urile din aplicația Flask
"""
# import sys # Removed - no longer needed
import os
from shared.database.oracle_pool import oracle_pool
from typing import List, Tuple
from ..models.invoice import Invoice, InvoiceFilter, InvoiceListResponse, InvoiceSummary
from ..cache.decorators import cached
from decimal import Decimal
import logging
logger = logging.getLogger(__name__)
class InvoiceService:
"""Service pentru gestionarea facturilor"""
@staticmethod
@cached(cache_type='schema', key_params=['company_id'])
async def _get_schema(company_id: int) -> str:
"""Obține schema pentru company_id (CACHED PERMANENT)"""
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
schema_query = """
SELECT schema
FROM CONTAFIN_ORACLE.v_nom_firme
WHERE id_firma = :company_id
"""
cursor.execute(schema_query, {'company_id': company_id})
schema_result = cursor.fetchone()
if not schema_result:
raise ValueError(f"Schema not found for company {company_id}")
return schema_result[0]
@staticmethod
@cached(cache_type='invoices', key_params=['filter_params', 'username'])
async def get_invoices(filter_params: InvoiceFilter, username: str) -> InvoiceListResponse:
"""
Obține lista de facturi - Query simplu pentru afișare în tabel (CACHED 10 min)
"""
company_id = int(filter_params.company)
schema = await InvoiceService._get_schema(company_id)
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
# Determină conturile în funcție de partner_type
if filter_params.partner_type == "CLIENTI":
conturi = "'4111', '461'"
elif filter_params.partner_type == "FURNIZORI":
conturi = "'401', '404', '462'"
else:
conturi = "'4111'" # default
# Determine period to use: from params or MAX from calendar
if filter_params.luna and filter_params.an:
period_condition = "vp.an = :an AND vp.luna = :luna"
use_param_period = True
else:
period_condition = f"""vp.an = (SELECT anul FROM {schema}.calendar WHERE anul*12+luna = (SELECT MAX(anul*12+luna) FROM {schema}.calendar))
AND vp.luna = (SELECT luna FROM {schema}.calendar WHERE anul*12+luna = (SELECT MAX(anul*12+luna) FROM {schema}.calendar))"""
use_param_period = False
# Query cu calculele corecte pentru solduri
base_query = f"""
SELECT
vp.NUME,
vp.NRACT,
vp.DATAACT,
vp.DATASCAD,
vp.CONTRACT,
vp.COD_FISCAL,
vp.REG_COMERT,
CASE
WHEN vp.CONT IN ('4111','461') THEN vp.PRECDEB + vp.DEBIT -- Total facturat clienți
WHEN vp.CONT IN ('401','404','462') THEN vp.PRECCRED + vp.CREDIT -- Total facturat furnizori
END as total_facturat,
CASE
WHEN vp.CONT IN ('4111','461') THEN vp.PRECCRED + vp.CREDIT -- Încasat clienți
WHEN vp.CONT IN ('401','404','462') THEN vp.PRECDEB + vp.DEBIT -- Achitat furnizori
END as achitat,
CASE
WHEN vp.CONT IN ('4111','461') THEN
(vp.PRECDEB + vp.DEBIT) - (vp.PRECCRED + vp.CREDIT) -- Sold clienți
WHEN vp.CONT IN ('401','404','462') THEN
(vp.PRECCRED + vp.CREDIT) - (vp.PRECDEB + vp.DEBIT) -- Sold furnizori
END as sold,
vp.CONT,
NVL(vp.NUME_VAL, 'RON') as valuta,
CASE
WHEN vp.DATASCAD < SYSDATE THEN 'restant'
ELSE 'in_termen'
END as status
FROM {schema}.vireg_parteneri vp
WHERE {period_condition}
AND (
(:partner_type = 'CLIENTI' AND vp.cont IN ('4111', '461'))
OR
(:partner_type = 'FURNIZORI' AND vp.cont IN ('401', '404', '462'))
)
"""
params = {'partner_type': filter_params.partner_type}
# Add period params if using explicit period
if use_param_period:
params['an'] = filter_params.an
params['luna'] = filter_params.luna
if filter_params.partner_name:
base_query += " AND UPPER(vp.nume) LIKE UPPER(:partner_name)"
params['partner_name'] = f"%{filter_params.partner_name}%"
if filter_params.cont:
base_query += " AND vp.cont = :cont"
params['cont'] = filter_params.cont
if filter_params.min_amount:
base_query += " AND total_facturat >= :min_amount"
params['min_amount'] = filter_params.min_amount
if filter_params.max_amount:
base_query += " AND total_facturat <= :max_amount"
params['max_amount'] = filter_params.max_amount
if filter_params.only_unpaid:
# Nu putem folosi aliasul "sold" în WHERE în Oracle, trebuie să repetăm calculul
base_query += """ AND (
CASE
WHEN vp.CONT IN ('4111','461') THEN
(vp.PRECDEB + vp.DEBIT) - (vp.PRECCRED + vp.CREDIT)
WHEN vp.CONT IN ('401','404','462') THEN
(vp.PRECCRED + vp.CREDIT) - (vp.PRECDEB + vp.DEBIT)
END
) > 0"""
# Count total pentru paginare
count_query = f"SELECT COUNT(*) FROM ({base_query})"
cursor.execute(count_query, params)
total_count = cursor.fetchone()[0]
# Query pentru TOTAL SOLD din TOATE facturile filtrate (nu doar pagina curentă)
total_sold_query = f"""
SELECT NVL(SUM(sold), 0) as total_sold
FROM ({base_query})
"""
cursor.execute(total_sold_query, params)
total_sold_result = cursor.fetchone()
total_sold_all = Decimal(str(total_sold_result[0])) if total_sold_result else Decimal('0.00')
# Get accounting period - use params if provided, else from calendar
if use_param_period:
accounting_period = {
'an': filter_params.an,
'luna': filter_params.luna
}
else:
period_query = f"""
SELECT anul, luna
FROM {schema}.calendar
WHERE anul*12+luna = (SELECT MAX(anul*12+luna) FROM {schema}.calendar)
"""
cursor.execute(period_query)
period_result = cursor.fetchone()
accounting_period = {
'an': period_result[0] if period_result else None,
'luna': period_result[1] if period_result else None
}
# Adaugă ORDER BY și paginare - Ordonare cronologică (DATAACT, NRACT, NUME)
base_query += " ORDER BY vp.DATAACT ASC, vp.NRACT ASC, vp.NUME"
# Paginare Oracle
offset = (filter_params.page - 1) * filter_params.page_size
limit = offset + filter_params.page_size
paginated_query = f"""
SELECT * FROM (
SELECT ROWNUM as rn, t.* FROM ({base_query}) t WHERE ROWNUM <= :limit
) WHERE rn > :offset
"""
params['offset'] = offset
params['limit'] = limit
cursor.execute(paginated_query, params)
rows = cursor.fetchall()
# Procesează rezultatele cu structura nouă
invoices = []
total_amount = Decimal('0.00')
for row in rows:
# Skip ROWNUM, extrage valorile din query-ul nou
nume = row[1]
nract = row[2]
dataact = row[3]
datascad = row[4]
contract = row[5]
cod_fiscal = row[6]
reg_comert = row[7]
total_facturat = Decimal(str(row[8] or 0))
achitat = Decimal(str(row[9] or 0))
sold = Decimal(str(row[10] or 0))
cont = row[11]
valuta = row[12] or 'RON'
status = row[13]
invoice_data = {
'nume': nume or '',
'nract': nract or 0,
'dataact': dataact,
'datascad': datascad,
'contract': contract,
'cod_fiscal': cod_fiscal,
'reg_comert': reg_comert,
'cont': cont,
'totctva': total_facturat,
'achitat': achitat,
'soldfinal': sold,
'valuta': valuta
}
invoice = Invoice(**invoice_data)
invoices.append(invoice)
total_amount += total_facturat
return InvoiceListResponse(
invoices=invoices,
total_count=total_count,
filtered_count=len(invoices),
total_amount=total_amount,
page=filter_params.page,
page_size=filter_params.page_size,
has_more=len(invoices) == filter_params.page_size,
accounting_period=accounting_period,
# Total sold din TOATE facturile filtrate
total_sold_all=total_sold_all
)
@staticmethod
async def get_invoice_details(company: str, invoice_number: str, username: str) -> Invoice:
"""
Obține detaliile unei facturi specifice
"""
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
# Obține schema din v_nom_firme bazat pe id_firma
company_id = int(company)
schema_query = "SELECT schema FROM CONTAFIN_ORACLE.v_nom_firme WHERE id_firma = :company_id"
cursor.execute(schema_query, {'company_id': company_id})
schema_result = cursor.fetchone()
if not schema_result:
raise ValueError(f"Schema nu a fost găsită pentru id_firma {company_id}")
schema = schema_result[0]
# Query simplu pentru detalii factură
detail_query = f"""
SELECT
NUME,
NRACT,
DATAACT,
DATASCAD,
CONTRACT,
COD_FISCAL,
REG_COMERT,
PRECDEB,
PRECCRED,
DEBIT,
CREDIT,
CONT
FROM {schema}.vireg_parteneri
WHERE nract = :invoice_number
AND an = (select anul from {schema}.calendar where anul*12+luna = (select max(anul*12+luna) as anmax from {schema}.calendar))
AND luna = (select luna from {schema}.calendar where anul*12+luna = (select max(anul*12+luna) as anmax from {schema}.calendar))
"""
cursor.execute(detail_query, {'invoice_number': invoice_number})
row = cursor.fetchone()
if not row:
raise ValueError(f"Factura {invoice_number} nu a fost găsită")
# Extrage valorile
nume = row[0]
nract = row[1]
dataact = row[2]
datascad = row[3]
contract = row[4]
cod_fiscal = row[5]
reg_comert = row[6]
precdeb = Decimal(str(row[7] or 0))
preccred = Decimal(str(row[8] or 0))
debit = Decimal(str(row[9] or 0))
credit = Decimal(str(row[10] or 0))
cont = row[11]
# Calculează valorile în funcție de tipul contului
if cont in ('4111', '461'): # CLIENTI
totctva = precdeb + debit
achitat = preccred + credit
soldfinal = precdeb - preccred + debit - credit
else: # FURNIZORI
totctva = preccred + credit
achitat = precdeb + debit
soldfinal = preccred - precdeb + credit - debit
invoice_data = {
'nume': nume or '',
'nract': nract or 0,
'dataact': dataact,
'datascad': datascad,
'contract': contract,
'cod_fiscal': cod_fiscal,
'reg_comert': reg_comert,
'totctva': totctva,
'achitat': achitat,
'soldfinal': soldfinal
}
return Invoice(**invoice_data)

View File

@@ -0,0 +1,410 @@
# import sys # Removed - no longer needed
import os
import oracledb
from shared.database.oracle_pool import oracle_pool
from ..models.treasury import BankCashRegister, RegisterFilter, RegisterListResponse, AccountingPeriod
from ..cache.decorators import cached
from decimal import Decimal
from typing import Optional, List, Tuple, Any
import logging
logger = logging.getLogger(__name__)
class TreasuryService:
"""Service pentru trezorerie - registru casă și bancă"""
@staticmethod
@cached(cache_type='schema', key_params=['company_id'])
async def _get_schema(company_id: int) -> str:
"""Obține schema pentru company_id (CACHED PERMANENT)"""
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
schema_query = """
SELECT schema
FROM CONTAFIN_ORACLE.v_nom_firme
WHERE id_firma = :company_id
"""
cursor.execute(schema_query, {'company_id': company_id})
schema_result = cursor.fetchone()
if not schema_result:
raise ValueError(f"Schema not found for company {company_id}")
return schema_result[0]
@staticmethod
def _get_view_query(schema: str, register_type: Optional[str] = None) -> str:
"""
Construiește query-ul pentru view-ul vbancasa corespunzător.
Dacă register_type este None, returnează UNION ALL pentru toate tipurile.
NU se filtrează pe incasari/plati > 0 - se afișează TOATE înregistrările!
"""
view_configs = {
'BANCA_LEI': {
'view': f'{schema}.vbancasa_5121_cum',
'incasari_col': 'incasari',
'plati_col': 'plati',
'valuta': "'RON'",
'tip': "'BANCA LEI'"
},
'BANCA_VALUTA': {
'view': f'{schema}.vbancasa_5124_cum',
'incasari_col': 'incasval',
'plati_col': 'platival',
'valuta': "COALESCE(numeval, 'EUR')",
'tip': "'BANCA VALUTA'"
},
'CASA_LEI': {
'view': f'{schema}.vbancasa_5311_cum',
'incasari_col': 'incasari',
'plati_col': 'plati',
'valuta': "'RON'",
'tip': "'CASA LEI'"
},
'CASA_VALUTA': {
'view': f'{schema}.vbancasa_5314_cum',
'incasari_col': 'incasval',
'plati_col': 'platival',
'valuta': "COALESCE(numeval, 'EUR')",
'tip': "'CASA VALUTA'"
}
}
def build_select(config):
# NU se filtrează - se afișează TOATE înregistrările
# SOLD CUMULAT: Running balance per bancasa using window function
# NULL-date rows (opening balance) come first due to NULLS FIRST
return f"""
SELECT
nume, nract, dataact, bancasa,
{config['incasari_col']} as incasari,
{config['plati_col']} as plati,
SUM({config['incasari_col']} - {config['plati_col']}) OVER (
PARTITION BY bancasa
ORDER BY dataact ASC NULLS FIRST, nract ASC NULLS FIRST
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
) as sold,
{config['valuta']} as valuta,
{config['tip']} as tip_registru,
explicatia
FROM {config['view']}
"""
if register_type and register_type in view_configs:
return build_select(view_configs[register_type])
else:
# UNION ALL pentru toate tipurile
queries = [build_select(cfg) for cfg in view_configs.values()]
return " UNION ALL ".join(queries)
@staticmethod
@cached(cache_type='treasury', key_params=['filter_params', 'username'])
async def get_bank_cash_register(filter_params: RegisterFilter, username: str) -> RegisterListResponse:
"""
Obține registrul de casă și bancă din vbancasa views (CACHED 10 min)
IMPORTANT: PACK_SESIUNE.SETAN și SETLUNA trebuie executate în ACEEAȘI
tranzacție cu SELECT-ul din vbancasa* views!
Folosim un bloc PL/SQL anonim care:
1. Obține anul și luna curentă din calendar
2. Apelează PACK_SESIUNE.SETAN și SETLUNA
3. Execută SELECT-ul din vbancasa*
Toate în aceeași tranzacție!
"""
company_id = int(filter_params.company)
schema = await TreasuryService._get_schema(company_id)
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
# Construiește query-ul pentru tipul de registru selectat
base_select = TreasuryService._get_view_query(schema, filter_params.register_type)
# Construiește WHERE conditions
where_conditions = []
# Date filter preserves NULL-date rows (opening balance)
# for correct cumulative sum calculation
if filter_params.date_from and filter_params.date_to:
where_conditions.append(f"(dataact IS NULL OR (dataact >= TO_DATE('{filter_params.date_from.strftime('%Y-%m-%d')}', 'YYYY-MM-DD') AND dataact <= TO_DATE('{filter_params.date_to.strftime('%Y-%m-%d')}', 'YYYY-MM-DD')))")
elif filter_params.date_from:
where_conditions.append(f"(dataact IS NULL OR dataact >= TO_DATE('{filter_params.date_from.strftime('%Y-%m-%d')}', 'YYYY-MM-DD'))")
elif filter_params.date_to:
where_conditions.append(f"(dataact IS NULL OR dataact <= TO_DATE('{filter_params.date_to.strftime('%Y-%m-%d')}', 'YYYY-MM-DD'))")
if filter_params.partner_name:
# Escape single quotes pentru SQL
partner_escaped = filter_params.partner_name.replace("'", "''")
where_conditions.append(f"UPPER(nume) LIKE UPPER('%{partner_escaped}%')")
if filter_params.bank_account:
# Escape single quotes pentru SQL
bank_escaped = filter_params.bank_account.replace("'", "''")
where_conditions.append(f"bancasa = '{bank_escaped}'")
where_clause = ""
if where_conditions:
where_clause = " WHERE " + " AND ".join(where_conditions)
# Paginare Oracle
offset = (filter_params.page - 1) * filter_params.page_size
limit_val = filter_params.page_size
# Determine period to use: from params or MAX from calendar
if filter_params.luna and filter_params.an:
use_param_period = True
period_select = f"""
v_an := :param_an;
v_luna := :param_luna;
"""
else:
use_param_period = False
period_select = f"""
SELECT anul, luna INTO v_an, v_luna
FROM {schema}.calendar
WHERE anul*12+luna = (SELECT MAX(anul*12+luna) FROM {schema}.calendar);
"""
# Bloc PL/SQL anonim care face totul într-o singură tranzacție:
# 1. Obține anul și luna din params sau calendar
# 2. Setează PACK_SESIUNE.SETAN și SETLUNA
# 3. Returnează datele prin REF CURSOR
# IMPORTANT: Folosim ROW_NUMBER() pentru paginare corectă cu ORDER BY NULLS FIRST
plsql_block = f"""
DECLARE
v_an NUMBER;
v_luna NUMBER;
v_cursor SYS_REFCURSOR;
BEGIN
-- Obține anul și luna din parametri sau calendar
{period_select}
-- Setează contextul de sesiune (OBLIGATORIU înainte de SELECT din vbancasa*)
{schema}.PACK_SESIUNE.SETAN(v_an);
{schema}.PACK_SESIUNE.SETLUNA(v_luna);
-- Return accounting period
:out_an := v_an;
:out_luna := v_luna;
-- Returnează datele prin cursor cu ROW_NUMBER pentru paginare corectă
-- Pentru rânduri cu dataact=NULL (solduri precedente), sortare după bancasa
-- Pentru rânduri cu date, sortare după data, număr, bancasa
OPEN :result_cursor FOR
SELECT * FROM (
SELECT t.*, ROW_NUMBER() OVER (
ORDER BY dataact ASC NULLS FIRST,
CASE WHEN dataact IS NULL THEN bancasa END ASC,
nract ASC NULLS FIRST,
bancasa ASC
) as rn
FROM ({base_select}) t{where_clause}
) WHERE rn > {offset} AND rn <= {offset + limit_val};
END;
"""
# Creează cursor pentru rezultate (oracledb.CURSOR pentru REF CURSOR)
result_cursor = cursor.var(oracledb.CURSOR)
out_an = cursor.var(int)
out_luna = cursor.var(int)
# Build params dict
exec_params = {'result_cursor': result_cursor, 'out_an': out_an, 'out_luna': out_luna}
if use_param_period:
exec_params['param_an'] = filter_params.an
exec_params['param_luna'] = filter_params.luna
# Execută blocul PL/SQL cu REF CURSOR
cursor.execute(plsql_block, exec_params)
# Get accounting period values
accounting_year = out_an.getvalue()
accounting_month = out_luna.getvalue()
# Obține rezultatele din cursor
ref_cursor = result_cursor.getvalue()
rows = ref_cursor.fetchall()
ref_cursor.close()
# Pentru count total, executăm alt bloc PL/SQL
count_plsql = f"""
DECLARE
v_an NUMBER;
v_luna NUMBER;
BEGIN
-- Obține anul și luna din parametri sau calendar
{period_select}
{schema}.PACK_SESIUNE.SETAN(v_an);
{schema}.PACK_SESIUNE.SETLUNA(v_luna);
SELECT COUNT(*) INTO :total_count FROM ({base_select}) sub{where_clause};
END;
"""
total_count_var = cursor.var(int)
count_params = {'total_count': total_count_var}
if use_param_period:
count_params['param_an'] = filter_params.an
count_params['param_luna'] = filter_params.luna
cursor.execute(count_plsql, count_params)
total_count = total_count_var.getvalue()
# Query pentru TOTALURI din TOATE înregistrările filtrate (nu doar pagina curentă)
# sold_precedent = suma sold pentru rânduri cu dataact IS NULL
# total_incasari = suma incasari pentru rânduri cu dataact IS NOT NULL
# total_plati = suma plati pentru rânduri cu dataact IS NOT NULL
# Notă: where_clause poate fi gol sau poate conține "WHERE ..."
# Dacă e gol, adăugăm WHERE; dacă nu, adăugăm AND
dataact_null_cond = " AND dataact IS NULL" if where_clause else " WHERE dataact IS NULL"
dataact_not_null_cond = " AND dataact IS NOT NULL" if where_clause else " WHERE dataact IS NOT NULL"
totals_plsql = f"""
DECLARE
v_an NUMBER;
v_luna NUMBER;
BEGIN
-- Obține anul și luna din parametri sau calendar
{period_select}
{schema}.PACK_SESIUNE.SETAN(v_an);
{schema}.PACK_SESIUNE.SETLUNA(v_luna);
-- Sold precedent: suma sold pentru rânduri fără dată (opening balance)
SELECT NVL(SUM(sold), 0) INTO :sold_precedent_all
FROM ({base_select}) sub{where_clause}{dataact_null_cond};
-- Total încasări: suma incasari pentru rânduri cu dată (transactions)
SELECT NVL(SUM(incasari), 0) INTO :total_incasari_all
FROM ({base_select}) sub{where_clause}{dataact_not_null_cond};
-- Total plăți: suma plati pentru rânduri cu dată (transactions)
SELECT NVL(SUM(plati), 0) INTO :total_plati_all
FROM ({base_select}) sub{where_clause}{dataact_not_null_cond};
END;
"""
sold_precedent_all_var = cursor.var(oracledb.NUMBER)
total_incasari_all_var = cursor.var(oracledb.NUMBER)
total_plati_all_var = cursor.var(oracledb.NUMBER)
totals_params = {
'sold_precedent_all': sold_precedent_all_var,
'total_incasari_all': total_incasari_all_var,
'total_plati_all': total_plati_all_var
}
if use_param_period:
totals_params['param_an'] = filter_params.an
totals_params['param_luna'] = filter_params.luna
cursor.execute(totals_plsql, totals_params)
sold_precedent_all = Decimal(str(sold_precedent_all_var.getvalue() or 0))
total_incasari_all = Decimal(str(total_incasari_all_var.getvalue() or 0))
total_plati_all = Decimal(str(total_plati_all_var.getvalue() or 0))
sold_final_all = sold_precedent_all + total_incasari_all - total_plati_all
# Procesare rezultate
registers = []
total_incasari = Decimal('0.00')
total_plati = Decimal('0.00')
for row in rows:
# Coloane: nume, nract, dataact, bancasa, incasari, plati, sold, valuta, tip_registru, explicatia, rn
# row[0-9] = date, row[10] = rn (ROW_NUMBER la final)
register_data = BankCashRegister(
nume=row[0] or '',
nract=row[1],
dataact=row[2],
nume_cont_bancar=row[3] or '',
incasari=Decimal(str(row[4] or 0)),
plati=Decimal(str(row[5] or 0)),
sold=Decimal(str(row[6] or 0)),
valuta=row[7],
tip_registru=row[8],
explicatia=row[9] or ''
)
registers.append(register_data)
total_incasari += register_data.incasari
total_plati += register_data.plati
logger.info(f"Treasury query for company {company_id}, type={filter_params.register_type}: {len(registers)} records, total={total_count}")
return RegisterListResponse(
registers=registers,
total_count=total_count,
filtered_count=len(registers),
total_incasari=total_incasari,
total_plati=total_plati,
page=filter_params.page,
page_size=filter_params.page_size,
has_more=len(registers) == filter_params.page_size,
accounting_period=AccountingPeriod(an=accounting_year, luna=accounting_month),
# Totaluri din TOATE înregistrările filtrate
sold_precedent_all=sold_precedent_all,
total_incasari_all=total_incasari_all,
total_plati_all=total_plati_all,
sold_final_all=sold_final_all
)
@staticmethod
@cached(cache_type='treasury', key_params=['company_id', 'register_type'])
async def get_bank_cash_accounts(company_id: int, register_type: str) -> List[str]:
"""
Obține lista distinctă de conturi bancă/casă (bancasa) pentru dropdown.
Cached pentru performanță.
IMPORTANT: Trebuie să setăm contextul PACK_SESIUNE înainte de a accesa vbancasa views!
"""
schema = await TreasuryService._get_schema(company_id)
# Map register_type to view
view_map = {
'BANCA_LEI': f'{schema}.vbancasa_5121_cum',
'BANCA_VALUTA': f'{schema}.vbancasa_5124_cum',
'CASA_LEI': f'{schema}.vbancasa_5311_cum',
'CASA_VALUTA': f'{schema}.vbancasa_5314_cum'
}
if register_type not in view_map:
return []
view_name = view_map[register_type]
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
# PL/SQL block to set session context and get accounts
plsql_block = f"""
DECLARE
v_an NUMBER;
v_luna NUMBER;
BEGIN
-- Get current year and month from calendar
SELECT anul, luna INTO v_an, v_luna
FROM {schema}.calendar
WHERE anul*12+luna = (SELECT MAX(anul*12+luna) FROM {schema}.calendar);
-- Set session context (REQUIRED before accessing vbancasa* views)
{schema}.PACK_SESIUNE.SETAN(v_an);
{schema}.PACK_SESIUNE.SETLUNA(v_luna);
-- Return accounts via cursor
OPEN :result_cursor FOR
SELECT DISTINCT bancasa
FROM {view_name}
WHERE bancasa IS NOT NULL
ORDER BY bancasa;
END;
"""
result_cursor = cursor.var(oracledb.CURSOR)
cursor.execute(plsql_block, {'result_cursor': result_cursor})
ref_cursor = result_cursor.getvalue()
rows = ref_cursor.fetchall()
ref_cursor.close()
accounts = [row[0] for row in rows if row[0]]
logger.info(f"Found {len(accounts)} bank/cash accounts for company {company_id}, type={register_type}")
return accounts

View File

@@ -0,0 +1,217 @@
"""
Service pentru Trial Balance (Balanță de Verificare) - Query VBAL VIEW
Refactored to use caching system for optimal performance
"""
# import sys # Removed - no longer needed
import os
from shared.database.oracle_pool import oracle_pool
from typing import Dict, Any
from ..models.trial_balance import (
TrialBalanceItem,
TrialBalanceFilters,
TrialBalancePagination,
TrialBalanceResponse
)
from ..cache.decorators import cached
from decimal import Decimal
import math
import logging
logger = logging.getLogger(__name__)
class TrialBalanceService:
"""Service pentru gestionarea balanței de verificare cu cache"""
@staticmethod
@cached(cache_type='schema', key_params=['company_id'])
async def _get_schema(company_id: int) -> str:
"""
Obține schema pentru company_id (CACHED 24h)
This is cached permanently because company schemas rarely change.
"""
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
schema_query = """
SELECT schema
FROM CONTAFIN_ORACLE.v_nom_firme
WHERE id_firma = :company_id
"""
cursor.execute(schema_query, {'company_id': company_id})
schema_result = cursor.fetchone()
if not schema_result:
raise ValueError(f"Schema not found for company {company_id}")
return schema_result[0]
@staticmethod
@cached(cache_type='trial_balance', key_params=['company_id', 'luna', 'an', 'cont_filter',
'denumire_filter', 'sort_by', 'sort_order',
'page', 'page_size', 'username'])
async def get_trial_balance(
company_id: int,
luna: int,
an: int,
cont_filter: str | None,
denumire_filter: str | None,
sort_by: str,
sort_order: str,
page: int,
page_size: int,
username: str
) -> Dict[str, Any]:
"""
Obține balanța de verificare sintetică (CACHED 10 min)
Cache key includes all filter parameters to ensure unique cache entries
for different query variations.
Args:
company_id: ID firmei
luna: Luna (1-12)
an: Anul
cont_filter: Filtru număr cont (optional)
denumire_filter: Filtru denumire cont (optional)
sort_by: Coloană pentru sortare
sort_order: Ordinea sortării (asc/desc)
page: Pagina
page_size: Mărimea paginii
username: Username pentru cache tracking
Returns:
Dictionary cu items, pagination, filters_applied
"""
# Get schema (cached separately)
schema = await TrialBalanceService._get_schema(company_id)
# Validate sort_order
if sort_order.lower() not in ['asc', 'desc']:
sort_order = 'asc'
# Validate sort_by (prevent SQL injection)
valid_sort_columns = ['CONT', 'DENUMIRE', 'PRECDEB', 'PRECCRED',
'RULDEB', 'RULCRED', 'SOLDDEB', 'SOLDCRED']
if sort_by.upper() not in valid_sort_columns:
sort_by = 'CONT'
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
# Build base query for VBAL VIEW
base_query = f"""
SELECT
CONT,
NVL(DENUMIRE, '') as DENUMIRE,
NVL(PRECDEB, 0) as PRECDEB,
NVL(PRECCRED, 0) as PRECCRED,
NVL(RULDEB, 0) as RULDEB,
NVL(RULCRED, 0) as RULCRED,
NVL(SOLDDEB, 0) as SOLDDEB,
NVL(SOLDCRED, 0) as SOLDCRED
FROM {schema}.VBAL
WHERE AN = :an
AND LUNA = :luna
"""
params = {
'an': an,
'luna': luna
}
# Add dynamic filters
if cont_filter:
base_query += " AND CONT LIKE :cont_filter"
params['cont_filter'] = f"{cont_filter}%"
if denumire_filter:
base_query += " AND UPPER(DENUMIRE) LIKE UPPER(:denumire_filter)"
params['denumire_filter'] = f"%{denumire_filter}%"
# Count total for pagination
count_query = f"SELECT COUNT(*) FROM ({base_query})"
cursor.execute(count_query, params)
total_count = cursor.fetchone()[0]
# Query pentru TOTALURI din TOATE înregistrările filtrate (nu doar pagina curentă)
totals_query = f"""
SELECT
NVL(SUM(PRECDEB), 0) as total_prec_deb,
NVL(SUM(PRECCRED), 0) as total_prec_cred,
NVL(SUM(RULDEB), 0) as total_rul_deb,
NVL(SUM(RULCRED), 0) as total_rul_cred,
NVL(SUM(SOLDDEB), 0) as total_sold_deb,
NVL(SUM(SOLDCRED), 0) as total_sold_cred
FROM ({base_query})
"""
cursor.execute(totals_query, params)
totals_row = cursor.fetchone()
totals = {
"total_sold_precedent_debit": Decimal(str(totals_row[0])) if totals_row else Decimal('0.00'),
"total_sold_precedent_credit": Decimal(str(totals_row[1])) if totals_row else Decimal('0.00'),
"total_rulaj_lunar_debit": Decimal(str(totals_row[2])) if totals_row else Decimal('0.00'),
"total_rulaj_lunar_credit": Decimal(str(totals_row[3])) if totals_row else Decimal('0.00'),
"total_sold_final_debit": Decimal(str(totals_row[4])) if totals_row else Decimal('0.00'),
"total_sold_final_credit": Decimal(str(totals_row[5])) if totals_row else Decimal('0.00')
}
# Add sorting
base_query += f" ORDER BY {sort_by.upper()} {sort_order.upper()}"
# Pagination (Oracle ROWNUM with ORDER BY)
offset = (page - 1) * page_size
limit = offset + page_size
paginated_query = f"""
SELECT * FROM (
SELECT a.*, ROWNUM rnum FROM (
{base_query}
) a WHERE ROWNUM <= :limit
) WHERE rnum > :offset
"""
params['offset'] = offset
params['limit'] = limit
cursor.execute(paginated_query, params)
rows = cursor.fetchall()
# Process results
# Index: CONT(0), DENUMIRE(1), PRECDEB(2), PRECCRED(3),
# RULDEB(4), RULCRED(5), SOLDDEB(6), SOLDCRED(7), rnum(8)
items = []
for row in rows:
item = TrialBalanceItem(
cont=row[0] or '',
denumire=row[1] or '',
sold_precedent_debit=Decimal(str(row[2] or 0)),
sold_precedent_credit=Decimal(str(row[3] or 0)),
rulaj_lunar_debit=Decimal(str(row[4] or 0)),
rulaj_lunar_credit=Decimal(str(row[5] or 0)),
sold_final_debit=Decimal(str(row[6] or 0)),
sold_final_credit=Decimal(str(row[7] or 0))
)
items.append(item.dict())
# Calculate pagination
total_pages = math.ceil(total_count / page_size) if page_size > 0 else 0
# Build response
return {
"items": items,
"pagination": {
"total_items": total_count,
"total_pages": total_pages,
"current_page": page,
"page_size": page_size
},
"filters_applied": {
"luna": luna,
"an": an,
"cont_filter": cont_filter,
"denumire_filter": denumire_filter
},
# Totaluri din TOATE înregistrările filtrate (nu doar pagina curentă)
"totals": totals
}

View File

View File

@@ -0,0 +1,313 @@
"""
Session Management for Telegram Bot
This module handles session state for Telegram users, specifically managing
the active company selection for command handlers.
"""
import logging
import json
from typing import Dict, Any, Optional
from datetime import datetime
from backend.modules.telegram.db.operations import (
create_session,
get_user_active_session,
update_session_state,
delete_user_sessions
)
logger = logging.getLogger(__name__)
class ConversationSession:
"""
Manages session state for a single user.
Attributes:
telegram_user_id: Telegram user ID
session_id: UUID of the session
active_company_id: Selected company ID
active_company_name: Selected company name
active_company_cui: Selected company CUI
"""
def __init__(
self,
telegram_user_id: int,
session_id: Optional[str] = None
):
"""
Initialize a session.
Args:
telegram_user_id: Telegram user ID
session_id: Existing session ID (if resuming), or None for new session
"""
self.telegram_user_id = telegram_user_id
self.session_id = session_id
self.created_at = datetime.now()
self.updated_at = datetime.now()
# Active company for this session
self.active_company_id: Optional[int] = None
self.active_company_name: Optional[str] = None
self.active_company_cui: Optional[str] = None
def set_active_company(
self,
company_id: int,
company_name: str,
company_cui: Optional[str] = None
):
"""
Set the active company for this session.
Args:
company_id: Company ID
company_name: Company name
company_cui: Company CUI (optional)
"""
self.active_company_id = company_id
self.active_company_name = company_name
self.active_company_cui = company_cui
self.updated_at = datetime.now()
logger.info(
f"Active company set for user {self.telegram_user_id}: "
f"{company_name} (ID: {company_id})"
)
def get_active_company(self) -> Optional[Dict[str, Any]]:
"""
Get the active company information.
Returns:
Dict with company info (id, name, cui) or None if no company selected
"""
if self.active_company_id is not None:
return {
"id": self.active_company_id,
"name": self.active_company_name,
"cui": self.active_company_cui
}
return None
def clear_active_company(self):
"""
Clear the active company selection.
"""
logger.info(
f"Clearing active company for user {self.telegram_user_id} "
f"(was: {self.active_company_name})"
)
self.active_company_id = None
self.active_company_name = None
self.active_company_cui = None
self.updated_at = datetime.now()
def to_dict(self) -> Dict[str, Any]:
"""
Serialize session to dictionary (for database storage).
Returns:
Dict representation of session
"""
return {
"telegram_user_id": self.telegram_user_id,
"session_id": self.session_id,
"created_at": self.created_at.isoformat(),
"updated_at": self.updated_at.isoformat(),
"active_company_id": self.active_company_id,
"active_company_name": self.active_company_name,
"active_company_cui": self.active_company_cui
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'ConversationSession':
"""
Deserialize session from dictionary.
Args:
data: Dict representation of session
Returns:
ConversationSession instance
"""
session = cls(
telegram_user_id=data["telegram_user_id"],
session_id=data.get("session_id")
)
# Restore active company
session.active_company_id = data.get("active_company_id")
session.active_company_name = data.get("active_company_name")
session.active_company_cui = data.get("active_company_cui")
if "created_at" in data:
session.created_at = datetime.fromisoformat(data["created_at"])
if "updated_at" in data:
session.updated_at = datetime.fromisoformat(data["updated_at"])
return session
class SessionManager:
"""
Manages sessions for all users.
Provides methods to create, retrieve, update, and delete sessions.
Sessions are stored both in memory (for quick access) and in database
(for persistence).
"""
def __init__(self):
"""
Initialize the session manager.
"""
self._sessions: Dict[int, ConversationSession] = {}
logger.info("SessionManager initialized")
async def get_or_create_session(
self,
telegram_user_id: int
) -> ConversationSession:
"""
Get existing session for a user or create a new one.
Args:
telegram_user_id: Telegram user ID
Returns:
ConversationSession for the user
"""
# Check in-memory cache first
if telegram_user_id in self._sessions:
logger.debug(f"Found session in cache for user {telegram_user_id}")
return self._sessions[telegram_user_id]
# Check database for existing session
session_data = await get_user_active_session(telegram_user_id)
if session_data:
# Restore session from database
conversation_state_json = session_data.get('conversation_state')
if conversation_state_json:
try:
session_dict = json.loads(conversation_state_json)
session = ConversationSession.from_dict(session_dict)
session.session_id = session_data['session_id']
self._sessions[telegram_user_id] = session
logger.info(f"Restored session from database for user {telegram_user_id}")
return session
except json.JSONDecodeError as e:
logger.error(f"Failed to parse session state: {e}")
# Create new session
session = ConversationSession(telegram_user_id)
# Save to database
session_id = await create_session(
telegram_user_id=telegram_user_id,
conversation_state=json.dumps(session.to_dict()),
expires_in_hours=24
)
session.session_id = session_id
self._sessions[telegram_user_id] = session
logger.info(f"Created new session for user {telegram_user_id} (ID: {session_id})")
return session
async def save_session(self, telegram_user_id: int) -> bool:
"""
Save session to database.
Args:
telegram_user_id: Telegram user ID
Returns:
bool: True if saved successfully
"""
session = self._sessions.get(telegram_user_id)
if not session or not session.session_id:
logger.warning(f"No session to save for user {telegram_user_id}")
return False
try:
conversation_state = json.dumps(session.to_dict())
success = await update_session_state(
session_id=session.session_id,
conversation_state=conversation_state
)
if success:
logger.debug(f"Saved session for user {telegram_user_id}")
else:
logger.warning(f"Failed to save session for user {telegram_user_id}")
return success
except Exception as e:
logger.error(f"Error saving session for user {telegram_user_id}: {e}")
return False
async def delete_session(self, telegram_user_id: int) -> bool:
"""
Delete session completely (from memory and database).
Args:
telegram_user_id: Telegram user ID
Returns:
bool: True if deleted successfully
"""
# Remove from memory
if telegram_user_id in self._sessions:
del self._sessions[telegram_user_id]
# Delete from database
success = await delete_user_sessions(telegram_user_id)
if success:
logger.info(f"Deleted session for user {telegram_user_id}")
else:
logger.warning(f"Failed to delete session for user {telegram_user_id}")
return success
def get_active_sessions_count(self) -> int:
"""
Get count of active sessions in memory.
Returns:
int: Number of active sessions
"""
return len(self._sessions)
# Singleton instance
_session_manager_instance: Optional[SessionManager] = None
def get_session_manager() -> SessionManager:
"""
Get or create the singleton SessionManager instance.
Returns:
SessionManager: Singleton instance
"""
global _session_manager_instance
if _session_manager_instance is None:
_session_manager_instance = SessionManager()
return _session_manager_instance
# Export main classes and functions
__all__ = [
'ConversationSession',
'SessionManager',
'get_session_manager'
]

View File

View File

@@ -0,0 +1,917 @@
"""
API Client for ROA2WEB Backend Communication
This module provides an async HTTP client for communicating with the FastAPI backend.
Handles authentication, requests, error handling, and response parsing.
"""
import logging
import os
from typing import Optional, Dict, Any, List
from datetime import datetime
import httpx
from httpx import AsyncClient, Response, HTTPError, ConnectError
logger = logging.getLogger(__name__)
# Backend configuration from environment
# Default to port 8000 (production) instead of 8001 (development)
BACKEND_URL = os.getenv("BACKEND_URL", "http://localhost:8000")
REQUEST_TIMEOUT = float(os.getenv("API_TIMEOUT", "30.0")) # 30 seconds default
class BackendAPIClient:
"""
Async HTTP client for ROA2WEB FastAPI backend.
Provides methods for all API endpoints used by the Telegram bot:
- Dashboard data
- Invoices search and retrieval
- Treasury/payment data
- Report exports
- Company listings
- User authentication and token management
"""
def __init__(self, base_url: str = BACKEND_URL):
"""
Initialize the API client.
Args:
base_url: Base URL of the FastAPI backend
"""
self.base_url = base_url.rstrip('/')
self.client: Optional[AsyncClient] = None
logger.info(f"Backend API client initialized with base URL: {self.base_url}")
async def __aenter__(self):
"""Async context manager entry."""
self.client = AsyncClient(
base_url=self.base_url,
timeout=REQUEST_TIMEOUT,
follow_redirects=True
)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit."""
if self.client:
await self.client.aclose()
def _get_auth_headers(self, jwt_token: str) -> Dict[str, str]:
"""
Generate authentication headers with JWT token.
Args:
jwt_token: JWT access token
Returns:
Dict with Authorization header
"""
return {
"Authorization": f"Bearer {jwt_token}",
"Content-Type": "application/json"
}
async def _handle_response(self, response: Response) -> Dict[str, Any]:
"""
Handle API response and extract data.
Args:
response: HTTP response object
Returns:
Dict: Response JSON data
Raises:
HTTPError: If response status is not successful
"""
try:
response.raise_for_status()
return response.json()
except HTTPError as e:
logger.error(f"API request failed: {e}")
logger.error(f"Response body: {response.text}")
raise
except Exception as e:
logger.error(f"Failed to parse response: {e}")
raise
# =========================================================================
# AUTHENTICATION & USER ENDPOINTS
# =========================================================================
async def verify_user(
self,
oracle_username: str,
linking_code: str
) -> Optional[Dict[str, Any]]:
"""
Verify user exists in Oracle and get JWT token.
Called during Telegram linking process (auto-linking flow).
Args:
oracle_username: Oracle username extracted from linking code
linking_code: The 8-character linking code for validation
Returns:
Dict with:
- success: True if verification succeeded
- access_token: JWT access token
- refresh_token: JWT refresh token
- user: Dict with user_id, username, companies, permissions
- message: Status message
None if user not found or error
Example:
result = await client.verify_user("JOHN.DOE", "ABC12345")
if result and result['success']:
jwt_token = result['access_token']
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
# Flow A: Auto-linking (no password required)
response = await self.client.post(
"/api/telegram/auth/verify-user",
json={
"linking_code": linking_code,
"oracle_username": oracle_username
}
)
return await self._handle_response(response)
except ConnectError as e:
logger.error(f"Cannot connect to backend at {self.base_url}: {e}")
logger.error("Verify that backend service is running and BACKEND_URL is correct")
return None
except HTTPError as e:
if e.response.status_code == 404:
logger.warning(f"User {oracle_username} not found in Oracle")
return None
logger.error(f"Failed to verify user {oracle_username}: {e}")
return None
except Exception as e:
logger.error(f"Error verifying user: {e}")
return None
async def refresh_token(self, refresh_token: str) -> Optional[str]:
"""
Refresh JWT token for a user.
Args:
refresh_token: JWT refresh token
Returns:
str: New JWT access token, None if failed
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
response = await self.client.post(
"/api/telegram/auth/refresh-token",
json={"refresh_token": refresh_token}
)
data = await self._handle_response(response)
return data.get('access_token')
except Exception as e:
logger.error(f"Failed to refresh token: {e}")
return None
async def verify_email(self, email: str) -> dict:
"""
Verify if email exists in Oracle database
Args:
email: Email address to verify
Returns:
dict with 'success' (bool), 'username' (str or None), and 'message' (str)
Raises:
httpx.HTTPError: On network or HTTP errors
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
response = await self.client.post(
"/api/telegram/auth/verify-email",
json={"email": email}
)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error verifying email {email}: {e.response.status_code}")
return {
"success": False,
"username": None,
"message": "Eroare la verificarea email-ului"
}
except Exception as e:
logger.error(f"Failed to verify email {email}: {e}")
return {
"success": False,
"username": None,
"message": "Eroare la verificarea email-ului"
}
async def login_with_email(
self,
email: str,
password: str,
telegram_user_id: int,
session_token: str
) -> dict:
"""
Login via email + password with session token
Args:
email: User email address
password: Oracle password
telegram_user_id: Telegram user ID
session_token: Signed token from code validation
Returns:
Login response with JWT tokens and user data
Raises:
httpx.HTTPError: On network or HTTP errors
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
response = await self.client.post(
"/api/telegram/auth/login-with-email",
json={
"email": email,
"password": password,
"telegram_user_id": telegram_user_id,
"session_token": session_token
},
timeout=30.0 # 30 seconds timeout
)
response.raise_for_status()
data = response.json()
logger.info(f"Email login successful for user {telegram_user_id}")
return data
except httpx.HTTPStatusError as e:
logger.error(f"Email login HTTP error: {e.response.status_code} - {e.response.text}")
# Parse error detail if available
try:
error_data = e.response.json()
return {
"success": False,
"message": error_data.get("detail", "Autentificare eșuată")
}
except:
return {
"success": False,
"message": "Autentificare eșuată"
}
except httpx.TimeoutException:
logger.error("Email login timeout")
return {
"success": False,
"message": "Timeout. Te rugăm să încerci din nou."
}
except Exception as e:
logger.error(f"Email login error: {e}", exc_info=True)
return {
"success": False,
"message": "Eroare de conexiune"
}
async def get_user_companies(self, jwt_token: str) -> List[Dict[str, Any]]:
"""
Get list of companies the user has access to.
Args:
jwt_token: JWT access token
Returns:
List of company dicts with id, nume_firma, cui, etc.
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
response = await self.client.get(
"/api/companies",
headers=self._get_auth_headers(jwt_token)
)
data = await self._handle_response(response)
# Backend returns {"companies": [...], "total_count": N}
if isinstance(data, dict) and "companies" in data:
return data["companies"]
return data if isinstance(data, list) else []
except Exception as e:
logger.error(f"Failed to get companies: {e}")
return []
# =========================================================================
# DASHBOARD ENDPOINTS
# =========================================================================
async def get_dashboard_data(
self,
company_id: int,
jwt_token: str
) -> Optional[Dict[str, Any]]:
"""
Get dashboard statistics for a company.
Args:
company_id: Company ID
jwt_token: JWT access token
Returns:
Dict with dashboard data (sold_total, facturi, plati, etc.)
Includes _cache_hit and _response_time_ms metadata
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
# Add cache metadata header for Telegram Bot
headers = self._get_auth_headers(jwt_token)
headers['X-Include-Cache-Metadata'] = 'true'
response = await self.client.get(
"/api/reports/dashboard/summary",
params={"company": str(company_id)},
headers=headers
)
return await self._handle_response(response)
except Exception as e:
logger.error(f"Failed to get dashboard data for company {company_id}: {e}")
return None
async def get_treasury_breakdown(
self,
company_id: int,
jwt_token: str
) -> Optional[Dict[str, Any]]:
"""
Get detailed treasury breakdown (casa + banca accounts).
Args:
company_id: Company ID
jwt_token: JWT access token
Returns:
Dict with treasury breakdown data (accounts by type)
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
# Add cache metadata header for Telegram Bot
headers = self._get_auth_headers(jwt_token)
headers['X-Include-Cache-Metadata'] = 'true'
response = await self.client.get(
f"/api/reports/dashboard/treasury-breakdown?company={company_id}",
headers=headers
)
return await self._handle_response(response)
except Exception as e:
logger.error(f"Failed to get treasury breakdown for company {company_id}: {e}")
return None
async def get_detailed_data(
self,
company_id: int,
jwt_token: str,
data_type: str
) -> Optional[Dict[str, Any]]:
"""
Get detailed data for clients or suppliers.
Args:
company_id: Company ID
jwt_token: JWT access token
data_type: Type of data ('clients' or 'suppliers')
Returns:
Dict with detailed data (list of clients/suppliers with balances)
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
# Add cache metadata header for Telegram Bot
headers = self._get_auth_headers(jwt_token)
headers['X-Include-Cache-Metadata'] = 'true'
response = await self.client.get(
f"/api/reports/dashboard/detailed-data?company={company_id}&data_type={data_type}",
headers=headers
)
return await self._handle_response(response)
except Exception as e:
logger.error(f"Failed to get detailed data ({data_type}) for company {company_id}: {e}")
return None
async def get_maturity_data(
self,
company_id: int,
jwt_token: str,
period: str = "all"
) -> Optional[Dict[str, Any]]:
"""
Get maturity data (in term/overdue breakdown).
Args:
company_id: Company ID
jwt_token: JWT access token
period: Period filter ('all', '30', '60', '90')
Returns:
Dict with maturity data (in_term, overdue, total)
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
# Add cache metadata header for Telegram Bot
headers = self._get_auth_headers(jwt_token)
headers['X-Include-Cache-Metadata'] = 'true'
response = await self.client.get(
f"/api/reports/dashboard/maturity?company={company_id}&period={period}",
headers=headers
)
return await self._handle_response(response)
except Exception as e:
logger.error(f"Failed to get maturity data for company {company_id}: {e}")
return None
async def get_performance_data(
self,
company_id: int,
jwt_token: str
) -> Optional[Dict[str, Any]]:
"""
Get performance data (incasari/plati totals).
Args:
company_id: Company ID
jwt_token: JWT access token
Returns:
Dict with performance data (incasari_total, plati_total, net)
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
# Add cache metadata header for Telegram Bot
headers = self._get_auth_headers(jwt_token)
headers['X-Include-Cache-Metadata'] = 'true'
response = await self.client.get(
f"/api/reports/dashboard/performance?company={company_id}",
headers=headers
)
return await self._handle_response(response)
except Exception as e:
logger.error(f"Failed to get performance data for company {company_id}: {e}")
return None
async def get_monthly_flows(
self,
company_id: int,
jwt_token: str,
months: int = 12
) -> Optional[Dict[str, Any]]:
"""
Get monthly cash flows data.
Args:
company_id: Company ID
jwt_token: JWT access token
months: Number of months to retrieve
Returns:
Dict with monthly flows (months, incasari, plati arrays)
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
# Add cache metadata header for Telegram Bot
headers = self._get_auth_headers(jwt_token)
headers['X-Include-Cache-Metadata'] = 'true'
response = await self.client.get(
f"/api/reports/dashboard/monthly-flows?company={company_id}&months={months}",
headers=headers
)
return await self._handle_response(response)
except Exception as e:
logger.error(f"Failed to get monthly flows for company {company_id}: {e}")
return None
async def get_trends(
self,
company_id: int,
jwt_token: str,
period: str = "12m"
) -> Optional[Dict[str, Any]]:
"""
Get trends data (12-month historical data for collections/payments).
Args:
company_id: Company ID
jwt_token: JWT access token
period: Period for trends (e.g., "12m", "6m", "ytd")
Returns:
Dict with trends data including periods, clienti_incasat, furnizori_achitat arrays
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
# Add cache metadata header for Telegram Bot
headers = self._get_auth_headers(jwt_token)
headers['X-Include-Cache-Metadata'] = 'true'
response = await self.client.get(
f"/api/reports/dashboard/trends?company={company_id}&period={period}",
headers=headers
)
return await self._handle_response(response)
except Exception as e:
logger.error(f"Failed to get trends for company {company_id}: {e}")
return None
# =========================================================================
# INVOICES ENDPOINTS
# =========================================================================
async def search_invoices(
self,
company_id: int,
jwt_token: str,
filters: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
"""
Search invoices with optional filters.
Args:
company_id: Company ID
jwt_token: JWT access token
filters: Optional filters dict:
- date_from: str (YYYY-MM-DD)
- date_to: str (YYYY-MM-DD)
- status: str (paid, unpaid, overdue)
- client_name: str
- partner_type: str (CLIENTI, FURNIZORI)
- partner_name: str
- series: str
- number: str
Returns:
List of invoice dicts
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
params = {"company": company_id}
if filters:
params.update(filters)
response = await self.client.get(
"/api/reports/invoices/",
params=params,
headers=self._get_auth_headers(jwt_token)
)
data = await self._handle_response(response)
if isinstance(data, dict) and 'invoices' in data:
invoice_list = data['invoices']
return invoice_list
elif isinstance(data, list):
return data
else:
logger.warning(f"📥 Unexpected response format: {type(data)}")
return []
except Exception as e:
logger.error(f"Failed to search invoices for company {company_id}: {e}")
return []
async def get_invoice_summary(
self,
company_id: int,
jwt_token: str,
partner_type: str = "CLIENTI"
) -> Optional[Dict[str, Any]]:
"""
Get invoice summary statistics.
Args:
company_id: Company ID
jwt_token: JWT access token
Returns:
Dict with summary (total_count, total_amount, paid, unpaid, etc.)
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
response = await self.client.get(
"/api/reports/invoices/summary",
params={
"company": str(company_id),
"partner_type": partner_type
},
headers=self._get_auth_headers(jwt_token)
)
return await self._handle_response(response)
except Exception as e:
logger.error(f"Failed to get invoice summary for company {company_id}: {e}")
return None
# =========================================================================
# TREASURY ENDPOINTS
# =========================================================================
async def get_treasury_data(
self,
company_id: int,
jwt_token: str
) -> Optional[Dict[str, Any]]:
"""
Get treasury/cash flow data for a company.
Args:
company_id: Company ID
jwt_token: JWT access token
Returns:
Dict with treasury data (cash_balance, incoming, outgoing, etc.)
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
response = await self.client.get(
"/api/reports/treasury/bank-cash-register",
params={
"company": str(company_id),
"page": 1,
"page_size": 1000
},
headers=self._get_auth_headers(jwt_token)
)
return await self._handle_response(response)
except Exception as e:
logger.error(f"Failed to get treasury data for company {company_id}: {e}")
return None
# =========================================================================
# EXPORT ENDPOINTS
# =========================================================================
async def export_report(
self,
jwt_token: str,
report_type: str,
company_id: int,
format: str = "xlsx",
filters: Optional[Dict[str, Any]] = None
) -> Optional[bytes]:
"""
Generate and export a report.
Args:
jwt_token: JWT access token
report_type: Type of report ('dashboard', 'invoices', 'treasury')
company_id: Company ID
format: Export format ('xlsx', 'csv', 'pdf')
filters: Optional filters for data
Returns:
bytes: File content, None if failed
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
request_data = {
"type": report_type,
"company_id": company_id,
"format": format,
"filters": filters or {}
}
response = await self.client.post(
"/api/telegram/export",
json=request_data,
headers=self._get_auth_headers(jwt_token)
)
response.raise_for_status()
return response.content
except Exception as e:
logger.error(f"Failed to export report: {e}")
return None
# =========================================================================
# CACHE MANAGEMENT
# =========================================================================
async def invalidate_cache(
self,
jwt_token: str,
company_id: Optional[int] = None,
cache_type: Optional[str] = None
) -> bool:
"""
Invalidate cache entries.
Args:
jwt_token: JWT access token
company_id: Optional company ID (None = all companies)
cache_type: Optional cache type (None = all types)
Returns:
bool: True if successful
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
request_data = {}
if company_id is not None:
request_data['company_id'] = company_id
if cache_type is not None:
request_data['cache_type'] = cache_type
response = await self.client.post(
"/api/reports/cache/invalidate",
json=request_data,
headers=self._get_auth_headers(jwt_token)
)
response.raise_for_status()
logger.info(f"Cache invalidated: company_id={company_id}, cache_type={cache_type}")
return True
except Exception as e:
logger.error(f"Failed to invalidate cache: {e}")
return False
async def toggle_user_cache(
self,
jwt_token: str,
enabled: bool
) -> bool:
"""
Toggle cache for current user.
Args:
jwt_token: JWT access token
enabled: True to enable cache, False to disable
Returns:
bool: True if successful
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
response = await self.client.post(
"/api/reports/cache/toggle-user",
json={"enabled": enabled},
headers=self._get_auth_headers(jwt_token)
)
response.raise_for_status()
logger.info(f"User cache toggled: enabled={enabled}")
return True
except Exception as e:
logger.error(f"Failed to toggle user cache: {e}")
return False
async def get_cache_stats(
self,
jwt_token: str
) -> Optional[Dict[str, Any]]:
"""
Get cache statistics including user-specific settings.
Args:
jwt_token: JWT access token
Returns:
Dict with cache stats including 'user_enabled' field
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
response = await self.client.get(
"/api/reports/cache/stats",
headers=self._get_auth_headers(jwt_token)
)
response.raise_for_status()
return response.json()
except Exception as e:
logger.error(f"Failed to get cache stats: {e}")
return None
# =========================================================================
# HEALTH CHECK
# =========================================================================
async def health_check(self) -> bool:
"""
Check if backend is healthy and reachable.
Returns:
bool: True if backend is healthy
"""
try:
if not self.client or self.client.is_closed:
self.client = AsyncClient(base_url=self.base_url, timeout=REQUEST_TIMEOUT)
response = await self.client.get("/api/telegram/health")
return response.status_code == 200
except Exception as e:
logger.error(f"Backend health check failed: {e}")
return False
# Singleton instance for global use
_backend_client_instance: Optional[BackendAPIClient] = None
def get_backend_client() -> BackendAPIClient:
"""
Get or create the singleton BackendAPIClient instance.
Returns:
BackendAPIClient: Singleton instance
"""
global _backend_client_instance
if _backend_client_instance is None:
_backend_client_instance = BackendAPIClient()
return _backend_client_instance
# Export main classes and functions
__all__ = [
'BackendAPIClient',
'get_backend_client',
'BACKEND_URL'
]

View File

@@ -0,0 +1,171 @@
"""
Email authentication logic with crypto-secure code generation
"""
import secrets
import re
import logging
from datetime import datetime, timedelta
from typing import Optional, Dict
from collections import defaultdict
logger = logging.getLogger(__name__)
# ============================================================================
# RATE LIMITING (In-Memory)
# ============================================================================
# NOTE: For production with multiple bot instances, migrate to Redis
# See "Optional Future Enhancements" section in plan
_rate_limit_store: Dict[str, list] = defaultdict(list)
async def check_rate_limit(
identifier: str,
max_attempts: int = 3,
window_minutes: int = 60
) -> bool:
"""
Check if identifier is within rate limit
Args:
identifier: Email or telegram_user_id (as string)
max_attempts: Maximum attempts allowed
window_minutes: Time window in minutes
Returns:
True if within limit (can proceed), False if exceeded
NOTE: In-memory implementation - resets on bot restart
"""
now = datetime.now()
cutoff = now - timedelta(minutes=window_minutes)
# Clean old attempts
_rate_limit_store[identifier] = [
attempt for attempt in _rate_limit_store[identifier]
if attempt > cutoff
]
# Check limit
if len(_rate_limit_store[identifier]) >= max_attempts:
logger.warning(f"Rate limit exceeded for {identifier}")
return False
# Add new attempt
_rate_limit_store[identifier].append(now)
return True
def clear_rate_limit(identifier: str) -> None:
"""Clear rate limit for identifier (e.g., after successful auth)"""
if identifier in _rate_limit_store:
del _rate_limit_store[identifier]
logger.debug(f"Rate limit cleared for {identifier}")
# ============================================================================
# CODE GENERATION (Crypto-Secure)
# ============================================================================
def generate_email_code() -> str:
"""
Generate crypto-secure 6-digit code
Uses secrets module (not random) for cryptographic security
Returns:
6-digit string (000000 - 999999)
"""
# Generate 6-digit code using secrets (crypto-secure)
code = ''.join(secrets.choice('0123456789') for _ in range(6))
logger.debug(f"Generated email auth code (length: {len(code)})")
return code
# ============================================================================
# EMAIL VALIDATION
# ============================================================================
def is_valid_email_format(email: str) -> bool:
"""
Validate email format (basic regex)
Args:
email: Email address to validate
Returns:
True if format is valid
"""
pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
return bool(re.match(pattern, email))
async def verify_email_in_oracle(email: str) -> Optional[str]:
"""
Verify email exists in Oracle UTILIZATORI table via backend API
Args:
email: Email address to check
Returns:
Oracle username if found and active, None otherwise
NOTE: Uses backend API endpoint /api/telegram/auth/verify-email
"""
try:
from backend.modules.telegram.api.client import get_backend_client
backend_client = get_backend_client()
# Call backend API to verify email
response = await backend_client.verify_email(email)
if response.get('success'):
username = response.get('username')
logger.info(f"Email verified via backend: {email} -> {username}")
return username
else:
logger.warning(f"Email not found or inactive: {email}")
return None
except Exception as e:
logger.error(f"Error verifying email via backend: {e}", exc_info=True)
return None
# ============================================================================
# SESSION TOKEN GENERATION (Prevent User ID Spoofing)
# ============================================================================
def generate_session_token(telegram_user_id: int, email: str) -> str:
"""
Generate signed session token for backend verification
This prevents user ID spoofing attacks where malicious clients
could impersonate Telegram users by sending arbitrary user IDs
Args:
telegram_user_id: Telegram user ID
email: Verified email address
Returns:
Signed token (simple implementation - upgrade to JWT in future)
NOTE: For production, use proper JWT signing with shared secret
"""
import hashlib
import os
# Get secret from env (should match backend)
secret = os.getenv("AUTH_SESSION_SECRET", "change-me-in-production")
# Create signature: HMAC-like hash
payload = f"{telegram_user_id}:{email}:{secret}"
signature = hashlib.sha256(payload.encode()).hexdigest()[:16]
# Token format: user_id:email:signature
token = f"{telegram_user_id}:{email}:{signature}"
logger.debug(f"Generated session token for user {telegram_user_id}")
return token

View File

@@ -0,0 +1,350 @@
"""
Authentication and User Linking Logic
This module handles the linking process between Telegram users and Oracle ERP accounts.
It manages authentication codes, verifies users through the backend API, and maintains
user sessions with JWT tokens.
"""
import logging
from typing import Optional, Dict, Any
from datetime import datetime, timedelta
from telegram import User as TelegramUser
from backend.modules.telegram.db.operations import (
get_user,
create_or_update_user,
link_user_to_oracle,
update_user_tokens,
verify_and_use_auth_code,
is_user_linked
)
from backend.modules.telegram.api.client import get_backend_client
logger = logging.getLogger(__name__)
async def link_telegram_account(
telegram_user: TelegramUser,
auth_code: str
) -> Optional[Dict[str, Any]]:
"""
Link a Telegram account to an Oracle ERP account using an authentication code.
Flow:
1. Verify auth code in database (check exists, not used, not expired)
2. Extract oracle_username from code
3. Call backend API to verify user in Oracle and get JWT token
4. Create/update Telegram user record
5. Link user to Oracle account with JWT tokens
6. Return success with user data
Args:
telegram_user: Telegram User object from python-telegram-bot
auth_code: 8-character authentication code from web frontend
Returns:
Dict with:
- success: True if linking succeeded
- username: Oracle username
- jwt_token: JWT access token
- companies: List of companies user has access to
OR None if linking failed
Example:
result = await link_telegram_account(telegram_user, "ABC12345")
if result:
print(f"Linked to {result['username']}")
else:
print("Linking failed")
"""
try:
telegram_user_id = telegram_user.id
telegram_username = telegram_user.username
first_name = telegram_user.first_name
last_name = telegram_user.last_name
logger.info(
f"Attempting to link Telegram user {telegram_user_id} "
f"(@{telegram_username}) with code {auth_code}"
)
# Step 1: Verify auth code
code_data = await verify_and_use_auth_code(auth_code)
if not code_data:
logger.warning(f"Invalid, expired, or already used auth code: {auth_code}")
return None
oracle_username = code_data.get('oracle_username')
logger.info(f"Auth code valid for Oracle user: {oracle_username}")
# Step 2: Create/update Telegram user record (basic info)
user_created = await create_or_update_user(
telegram_user_id=telegram_user_id,
username=telegram_username,
first_name=first_name,
last_name=last_name
)
if not user_created:
logger.error(f"Failed to create/update Telegram user {telegram_user_id}")
return None
# Step 3: Verify user in Oracle and get JWT token via backend API (auto-linking flow)
backend_client = get_backend_client()
async with backend_client:
user_data = await backend_client.verify_user(
oracle_username=oracle_username,
linking_code=auth_code
)
if not user_data or not user_data.get('success'):
logger.error(f"Failed to verify Oracle user {oracle_username} via backend")
return None
# Extract tokens and user info from response
jwt_token = user_data.get('access_token')
jwt_refresh_token = user_data.get('refresh_token', jwt_token)
user_info = user_data.get('user', {})
companies = user_info.get('companies', [])
permissions = user_info.get('permissions', [])
# Token expiration (typically 30 minutes for access token)
token_expires_at = datetime.now() + timedelta(minutes=30)
# Step 4: Link Telegram user to Oracle account
linked = await link_user_to_oracle(
telegram_user_id=telegram_user_id,
oracle_username=oracle_username,
jwt_token=jwt_token,
jwt_refresh_token=jwt_refresh_token,
token_expires_at=token_expires_at
)
if not linked:
logger.error(f"Failed to link user {telegram_user_id} to Oracle account")
return None
logger.info(
f"Successfully linked Telegram user {telegram_user_id} "
f"to Oracle user {oracle_username}"
)
return {
"success": True,
"telegram_user_id": telegram_user_id,
"username": oracle_username,
"jwt_token": jwt_token,
"jwt_refresh_token": jwt_refresh_token,
"companies": companies,
"permissions": permissions,
"linked_at": datetime.now().isoformat()
}
except Exception as e:
logger.error(f"Error linking Telegram account: {e}", exc_info=True)
return None
async def get_user_auth_data(telegram_user_id: int) -> Optional[Dict[str, Any]]:
"""
Get authentication data for a linked Telegram user.
This function retrieves the user's Oracle account information and JWT tokens.
If the token is expired, it automatically refreshes it.
Args:
telegram_user_id: Telegram user ID
Returns:
Dict with:
- telegram_user_id: Telegram user ID
- username: Oracle username
- jwt_token: Valid JWT access token (refreshed if needed)
- jwt_refresh_token: JWT refresh token
- companies: List of companies (fetched if not cached)
OR None if user is not linked or error occurred
Example:
auth_data = await get_user_auth_data(12345)
if auth_data:
jwt = auth_data['jwt_token']
# Use JWT for API calls
"""
try:
# Get user from database
user_data = await get_user(telegram_user_id)
if not user_data:
logger.warning(f"User {telegram_user_id} not found in database")
return None
if not user_data.get('oracle_username'):
logger.warning(f"User {telegram_user_id} is not linked to Oracle account")
return None
oracle_username = user_data['oracle_username']
jwt_token = user_data['jwt_token']
jwt_refresh_token = user_data['jwt_refresh_token']
token_expires_at_str = user_data['token_expires_at']
# Parse token expiration
token_expires_at = datetime.fromisoformat(token_expires_at_str) if token_expires_at_str else None
# Check if token is expired or about to expire (< 5 minutes remaining)
token_expired = (
token_expires_at is None or
datetime.now() >= token_expires_at - timedelta(minutes=5)
)
if token_expired:
logger.info(f"Token expired for user {telegram_user_id}, refreshing...")
# Refresh token via backend API
backend_client = get_backend_client()
async with backend_client:
new_token = await backend_client.refresh_token(jwt_refresh_token)
if new_token:
# Update token in database
new_expires_at = datetime.now() + timedelta(minutes=30)
await update_user_tokens(
telegram_user_id=telegram_user_id,
jwt_token=new_token,
jwt_refresh_token=jwt_refresh_token, # Keep same refresh token
token_expires_at=new_expires_at
)
jwt_token = new_token
logger.info(f"Token refreshed for user {telegram_user_id}")
else:
logger.error(f"Failed to refresh token for user {telegram_user_id}")
return None
# Fetch user companies (fresh from backend)
backend_client = get_backend_client()
async with backend_client:
companies = await backend_client.get_user_companies(jwt_token)
return {
"telegram_user_id": telegram_user_id,
"username": oracle_username,
"jwt_token": jwt_token,
"jwt_refresh_token": jwt_refresh_token,
"companies": companies
}
except Exception as e:
logger.error(f"Error getting user auth data: {e}", exc_info=True)
return None
async def check_user_linked(telegram_user_id: int) -> bool:
"""
Check if a Telegram user is linked to an Oracle account.
Args:
telegram_user_id: Telegram user ID
Returns:
bool: True if user is linked, False otherwise
Example:
if await check_user_linked(12345):
print("User is linked")
else:
print("User needs to link account")
"""
try:
return await is_user_linked(telegram_user_id)
except Exception as e:
logger.error(f"Error checking if user is linked: {e}")
return False
async def get_user_companies(telegram_user_id: int) -> Optional[list]:
"""
Get list of companies a user has access to.
This is a convenience function that fetches user auth data and returns
just the companies list.
Args:
telegram_user_id: Telegram user ID
Returns:
List of company dicts, or None if user not linked
Example:
companies = await get_user_companies(12345)
if companies:
for company in companies:
print(f"{company['id']}: {company['nume_firma']}")
"""
try:
auth_data = await get_user_auth_data(telegram_user_id)
if auth_data:
return auth_data.get('companies', [])
return None
except Exception as e:
logger.error(f"Error getting user companies: {e}")
return None
async def unlink_user(telegram_user_id: int) -> bool:
"""
Unlink a Telegram user from their Oracle account.
This removes the linking but keeps the Telegram user record.
Used for account disconnection or security purposes.
Args:
telegram_user_id: Telegram user ID
Returns:
bool: True if successfully unlinked
Example:
if await unlink_user(12345):
print("Account unlinked")
"""
try:
# Set Oracle username and tokens to NULL
from backend.modules.telegram.db.database import DB_PATH
import aiosqlite
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
await db.execute("""
UPDATE telegram_users
SET oracle_username = NULL,
jwt_token = NULL,
jwt_refresh_token = NULL,
token_expires_at = NULL,
linked_at = NULL
WHERE telegram_user_id = ?
""", (telegram_user_id,))
await db.commit()
logger.info(f"User {telegram_user_id} unlinked from Oracle account")
return True
except Exception as e:
logger.error(f"Error unlinking user {telegram_user_id}: {e}")
return False
# Export main functions
__all__ = [
'link_telegram_account',
'get_user_auth_data',
'check_user_linked',
'get_user_companies',
'unlink_user'
]

View File

View File

@@ -0,0 +1,768 @@
"""
Telegram bot handlers for email-based authentication flow
"""
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from telegram.ext import (
ContextTypes,
ConversationHandler,
CommandHandler,
MessageHandler,
CallbackQueryHandler,
filters
)
import logging
from datetime import datetime
import asyncio
from backend.modules.telegram.auth.email_auth import (
is_valid_email_format,
verify_email_in_oracle,
generate_email_code,
generate_session_token,
check_rate_limit,
clear_rate_limit
)
from backend.modules.telegram.utils.email_service import get_email_service
from backend.modules.telegram.db.operations import (
create_email_auth_code,
get_email_auth_code,
get_pending_email_code,
mark_email_code_used,
increment_failed_attempts,
delete_user_email_codes,
is_user_authenticated,
link_user_to_oracle,
create_or_update_user
)
from backend.modules.telegram.api.client import get_backend_client
logger = logging.getLogger(__name__)
# Conversation states
AWAITING_EMAIL, AWAITING_CODE, AWAITING_PASSWORD = range(3)
# Constants
MAX_CODE_ATTEMPTS = 3
# ============================================================================
# HELPER FUNCTIONS
# ============================================================================
async def edit_login_message(
context: ContextTypes.DEFAULT_TYPE,
chat_id: int,
text: str,
reply_markup=None,
parse_mode="Markdown"
):
"""
Helper function to edit the login message stored in context.
If message_id is not stored, creates a new message instead.
"""
message_id = context.user_data.get('login_message_id')
if message_id:
try:
await context.bot.edit_message_text(
chat_id=chat_id,
message_id=message_id,
text=text,
reply_markup=reply_markup,
parse_mode=parse_mode
)
except Exception as e:
logger.warning(f"Could not edit message {message_id}: {e}")
# Fallback: send new message and update ID
msg = await context.bot.send_message(
chat_id=chat_id,
text=text,
reply_markup=reply_markup,
parse_mode=parse_mode
)
context.user_data['login_message_id'] = msg.message_id
else:
# No message ID stored - create new message
msg = await context.bot.send_message(
chat_id=chat_id,
text=text,
reply_markup=reply_markup,
parse_mode=parse_mode
)
context.user_data['login_message_id'] = msg.message_id
async def delete_login_message(context: ContextTypes.DEFAULT_TYPE, chat_id: int):
"""Delete the login message and clear the message_id from context"""
message_id = context.user_data.get('login_message_id')
if message_id:
try:
await context.bot.delete_message(chat_id=chat_id, message_id=message_id)
except Exception as e:
logger.warning(f"Could not delete message {message_id}: {e}")
# Clear from context
context.user_data.pop('login_message_id', None)
# ============================================================================
# ENTRY POINTS: /login command and action:login button
# ============================================================================
async def login_command(update: Update, context: ContextTypes.DEFAULT_TYPE):
"""
Handler pentru /login command
Oferă opțiuni de autentificare: Email sau Web App
"""
user = update.effective_user
# Check dacă e deja autentificat
if await is_user_authenticated(user.id):
await update.message.reply_text(
"Ești deja autentificat.\n\n"
"Folosește:\n"
"• /companies - Vezi companiile tale\n"
"• /help - Comenzi disponibile\n"
"• /unlink - Deautentifică-te"
)
return ConversationHandler.END
# Check rate limiting (3 requests per hour)
if not await check_rate_limit(f"login_{user.id}", max_attempts=3, window_minutes=60):
await update.message.reply_text(
"Prea multe încercări de autentificare.\n\n"
"Te rugăm să aștepți 60 de minute înainte de a încerca din nou."
)
return ConversationHandler.END
# Afișează opțiuni de autentificare
keyboard = [
[InlineKeyboardButton("Login cu Email + Parolă", callback_data="email_login")],
[InlineKeyboardButton("Login din Web App", callback_data="web_login_info")],
[InlineKeyboardButton("Anulează", callback_data="cancel")]
]
reply_markup = InlineKeyboardMarkup(keyboard)
# CREATE message and SAVE message_id
msg = await update.message.reply_text(
"Alege metoda de autentificare:",
reply_markup=reply_markup,
parse_mode="Markdown"
)
# Save message ID for future edits
context.user_data['login_message_id'] = msg.message_id
return AWAITING_EMAIL
async def action_login_callback(update: Update, context: ContextTypes.DEFAULT_TYPE):
"""
Handler pentru butonul Login din meniu (action:login)
Oferă opțiuni de autentificare: Email sau Web App
"""
query = update.callback_query
user = update.effective_user
logger.info(f"[EMAIL_AUTH] action_login_callback triggered for user {user.id}")
await query.answer()
# Check dacă e deja autentificat
if await is_user_authenticated(user.id):
await query.edit_message_text(
"Ești deja autentificat.\n\n"
"Folosește:\n"
"• /companies - Vezi companiile tale\n"
"• /help - Comenzi disponibile\n"
"• /unlink - Deautentifică-te"
)
return ConversationHandler.END
# Check rate limiting (3 requests per hour)
if not await check_rate_limit(f"login_{user.id}", max_attempts=3, window_minutes=60):
await query.edit_message_text(
"Prea multe încercări de autentificare.\n\n"
"Te rugăm să aștepți 60 de minute înainte de a încerca din nou."
)
return ConversationHandler.END
# Afișează opțiuni de autentificare
keyboard = [
[InlineKeyboardButton("Login cu Email + Parolă", callback_data="email_login")],
[InlineKeyboardButton("Login din Web App", callback_data="web_login_info")],
[InlineKeyboardButton("Anulează", callback_data="cancel")]
]
reply_markup = InlineKeyboardMarkup(keyboard)
# EDIT existing menu message and SAVE message_id
await query.edit_message_text(
"Alege metoda de autentificare:",
reply_markup=reply_markup,
parse_mode="Markdown"
)
# Save message ID for future edits
context.user_data['login_message_id'] = query.message.message_id
return AWAITING_EMAIL
# ============================================================================
# CALLBACK: Email Login
# ============================================================================
async def email_login_callback(update: Update, context: ContextTypes.DEFAULT_TYPE):
"""Callback pentru butonul 'Login cu Email'"""
query = update.callback_query
user = update.effective_user
logger.info(f"[EMAIL_AUTH] email_login_callback triggered for user {user.id}")
await query.answer()
# IMPORTANT: Salvează message_id înainte de a edita
context.user_data['login_message_id'] = query.message.message_id
# EDIT same message - remove buttons, ask for email
await query.edit_message_text(
text="Introdu adresa de email ROA:",
parse_mode="Markdown"
)
return AWAITING_EMAIL
async def web_login_info_callback(update: Update, context: ContextTypes.DEFAULT_TYPE):
"""Info despre web app login"""
query = update.callback_query
await query.answer()
await query.edit_message_text(
"**Login din Web App**\n\n"
"Pentru această metodă:\n\n"
"1. Accesează aplicația web ROA2WEB\n"
"2. Autentifică-te cu username + parolă\n"
"3. Apasă butonul \"Link Telegram\"\n"
"4. Copiază codul generat (8 caractere)\n"
"5. Trimite-mi codul: /start ABC123XY\n\n"
"Vei fi autentificat automat.",
parse_mode="Markdown"
)
# IMPORTANT: Salvează message_id pentru ca /start să poată edita același mesaj
context.user_data['web_login_message_id'] = query.message.message_id
return ConversationHandler.END
# ============================================================================
# STATE: AWAITING_EMAIL
# ============================================================================
async def receive_email(update: Update, context: ContextTypes.DEFAULT_TYPE):
"""Handler pentru primirea email-ului"""
email = update.message.text.strip().lower()
user_id = update.effective_user.id
# ȘTERG mesajul utilizatorului imediat (chat curat)
try:
await update.message.delete()
except Exception as e:
logger.warning(f"Could not delete email message: {e}")
# Validare format email
if not is_valid_email_format(email):
# Show error in main message
await edit_login_message(
context=context,
chat_id=update.effective_chat.id,
text="Email invalid\n\nIntrodu o adresă validă (nume@domeniu.ro)",
reply_markup=InlineKeyboardMarkup([
[InlineKeyboardButton("Anulează", callback_data="cancel")]
])
)
return AWAITING_EMAIL
# Check for existing pending code
existing_code = await get_pending_email_code(user_id)
if existing_code:
# Delete old pending code
await delete_user_email_codes(user_id)
logger.info(f"Deleted existing pending code for user {user_id}")
# EDIT login message to show loading
await edit_login_message(
context=context,
chat_id=update.effective_chat.id,
text="Verificare email...",
reply_markup=None
)
try:
# Verifică email în Oracle
username = await verify_email_in_oracle(email)
# IMPORTANT: Generic response to prevent email enumeration
# We always say "code sent" even if email doesn't exist
if username:
# Email exists - generate and send code
code = generate_email_code()
# Save code in database
code_saved = await create_email_auth_code(
code=code,
email=email,
username=username,
telegram_user_id=user_id,
expiry_minutes=5
)
if not code_saved:
await edit_login_message(
context=context,
chat_id=update.effective_chat.id,
text="Eroare la salvarea codului.\n\nIncearcă din nou cu /login"
)
return ConversationHandler.END
# Send email (async with retry)
email_service = get_email_service()
email_sent = await email_service.send_auth_code(email, code, username)
if not email_sent:
logger.error(f"Failed to send email to {email}")
# Don't reveal this to user - they'll timeout naturally
# Wait 1 second for better UX (looks like verification happened)
await asyncio.sleep(1)
# ALWAYS show this message (prevent enumeration)
# EDIT same message with success + buttons
await edit_login_message(
context=context,
chat_id=update.effective_chat.id,
text=f"Cod trimis pe {email}\n\nIntrodu codul primit pe email:",
reply_markup=InlineKeyboardMarkup([
[InlineKeyboardButton("Retrimite Cod", callback_data=f"resend:{email}")],
[InlineKeyboardButton("Anulează", callback_data="cancel")]
])
)
# Save email in context for resend functionality
context.user_data['pending_email'] = email
context.user_data['pending_username'] = username
return AWAITING_CODE
except Exception as e:
logger.error(f"Error in receive_email: {e}", exc_info=True)
await edit_login_message(
context=context,
chat_id=update.effective_chat.id,
text="Eroare internă.\n\nIncearcă din nou mai târziu."
)
return ConversationHandler.END
# ============================================================================
# STATE: AWAITING_CODE
# ============================================================================
async def receive_code(update: Update, context: ContextTypes.DEFAULT_TYPE):
"""Handler pentru primirea codului din email"""
code = update.message.text.strip()
user_id = update.effective_user.id
# ȘTERG mesajul utilizatorului imediat (chat curat)
try:
await update.message.delete()
except Exception as e:
logger.warning(f"Could not delete code message: {e}")
# Validare format cod (6 digits)
if not (code.isdigit() and len(code) == 6):
await edit_login_message(
context=context,
chat_id=update.effective_chat.id,
text="Cod invalid\n\nIntrodu cele 6 cifre din email.",
reply_markup=InlineKeyboardMarkup([
[InlineKeyboardButton("Retrimite Cod", callback_data=f"resend:{context.user_data.get('pending_email', '')}")],
[InlineKeyboardButton("Anulează", callback_data="cancel")]
])
)
return AWAITING_CODE
# Verifică cod în DB
try:
code_data = await get_email_auth_code(code)
if not code_data:
# EDIT login message to show error
await edit_login_message(
context=context,
chat_id=update.effective_chat.id,
text="Cod invalid sau expirat\n\nIncearcă din nou cu /login"
)
return ConversationHandler.END
# Verificări de securitate
# 1. Check if already used
if code_data['used']:
await edit_login_message(
context=context,
chat_id=update.effective_chat.id,
text="Cod deja folosit\n\nIncearcă din nou cu /login"
)
return ConversationHandler.END
# 2. Check if expired
if datetime.now() > code_data['expires_at']:
await edit_login_message(
context=context,
chat_id=update.effective_chat.id,
text="Cod expirat\n\nIncearcă din nou cu /login"
)
return ConversationHandler.END
# 3. Check if belongs to this user
if code_data['telegram_user_id'] != user_id:
logger.warning(
f"User {user_id} tried to use code belonging to "
f"user {code_data['telegram_user_id']}"
)
await edit_login_message(
context=context,
chat_id=update.effective_chat.id,
text="Cod invalid"
)
return ConversationHandler.END
# 4. Check failed attempts (max 3)
if code_data['failed_attempts'] >= MAX_CODE_ATTEMPTS:
await edit_login_message(
context=context,
chat_id=update.effective_chat.id,
text="Prea multe încercări greșite\n\nIncearcă din nou cu /login"
)
return ConversationHandler.END
# Cod valid - Marchează ca folosit
await mark_email_code_used(code)
# Salvează date verificate în context
context.user_data['verified_username'] = code_data['oracle_username']
context.user_data['verified_email'] = code_data['email']
context.user_data['session_token'] = generate_session_token(
user_id,
code_data['email']
)
# EDIT same message - ask for password
await edit_login_message(
context=context,
chat_id=update.effective_chat.id,
text="Cod validat!\n\nIntroduci parola ROA:",
reply_markup=None # No buttons for security
)
return AWAITING_PASSWORD
except Exception as e:
logger.error(f"Error validating code: {e}", exc_info=True)
await edit_login_message(
context=context,
chat_id=update.effective_chat.id,
text="Eroare la validarea codului.\n\nIncearcă din nou."
)
return ConversationHandler.END
async def resend_code_callback(update: Update, context: ContextTypes.DEFAULT_TYPE):
"""Retrimite codul pe email"""
query = update.callback_query
await query.answer("Retrimitem codul...")
# Extract email from callback data
callback_data = query.data # Format: "resend:email@example.com"
if not callback_data.startswith("resend:"):
await edit_login_message(
context=context,
chat_id=update.effective_chat.id,
text="Eroare\n\nIncearcă din nou cu /login"
)
return ConversationHandler.END
email = callback_data.split(":", 1)[1]
user_id = update.effective_user.id
# Check rate limiting for resend (max 2 per 10 minutes)
if not await check_rate_limit(f"resend_{user_id}", max_attempts=2, window_minutes=10):
await edit_login_message(
context=context,
chat_id=update.effective_chat.id,
text="Prea multe solicitări\n\nAșteaptă 10 minute."
)
return ConversationHandler.END
# Get username from context or re-verify
username = context.user_data.get('pending_username')
if not username:
username = await verify_email_in_oracle(email)
if not username:
await edit_login_message(
context=context,
chat_id=update.effective_chat.id,
text="Eroare\n\nIncearcă din nou cu /login"
)
return ConversationHandler.END
# Delete old code and generate new one
await delete_user_email_codes(user_id)
code = generate_email_code()
# Save new code
await create_email_auth_code(
code=code,
email=email,
username=username,
telegram_user_id=user_id,
expiry_minutes=5
)
# Send email
email_service = get_email_service()
await email_service.send_auth_code(email, code, username)
# FIX BUG: EDIT message and KEEP buttons!
await edit_login_message(
context=context,
chat_id=update.effective_chat.id,
text=f"Cod retrimis pe {email}\n\nIntrodu codul primit pe email:",
reply_markup=InlineKeyboardMarkup([
[InlineKeyboardButton("Retrimite Cod", callback_data=f"resend:{email}")],
[InlineKeyboardButton("Anulează", callback_data="cancel")]
])
)
return AWAITING_CODE
# ============================================================================
# STATE: AWAITING_PASSWORD
# ============================================================================
async def receive_password(update: Update, context: ContextTypes.DEFAULT_TYPE):
"""Handler pentru primirea parolei Oracle"""
password = update.message.text.strip()
user_id = update.effective_user.id
# Șterge IMEDIAT mesajul cu parola (securitate)
try:
await update.message.delete()
logger.info(f"Password message deleted for user {user_id}")
except Exception as e:
logger.warning(f"Could not delete password message: {e}")
# Get verified data from context
username = context.user_data.get('verified_username')
email = context.user_data.get('verified_email')
session_token = context.user_data.get('session_token')
if not all([username, email, session_token]):
await edit_login_message(
context=context,
chat_id=update.effective_chat.id,
text="Sesiune expirată\n\nIncearcă din nou cu /login"
)
return ConversationHandler.END
# EDIT login message to show loading
await edit_login_message(
context=context,
chat_id=update.effective_chat.id,
text="Verificare...",
reply_markup=None
)
try:
# Call backend endpoint pentru verificare parolă + JWT
backend_client = get_backend_client()
response = await backend_client.login_with_email(
email=email,
password=password,
telegram_user_id=user_id,
session_token=session_token
)
if not response.get('success'):
await edit_login_message(
context=context,
chat_id=update.effective_chat.id,
text="Credențiale invalide\n\nParolă incorectă sau cont inactiv.\n\nIncearcă din nou cu /login"
)
return ConversationHandler.END
# Success - Salvează user în telegram_users
# First create or update user record
await create_or_update_user(
telegram_user_id=user_id,
username=update.effective_user.username,
first_name=update.effective_user.first_name,
last_name=update.effective_user.last_name
)
# Then link to Oracle
from datetime import datetime, timedelta
token_expires_at = datetime.now() + timedelta(minutes=30) # Default expiry
await link_user_to_oracle(
telegram_user_id=user_id,
oracle_username=response['username'],
jwt_token=response['access_token'],
jwt_refresh_token=response['refresh_token'],
token_expires_at=token_expires_at
)
# Clear rate limits on successful auth
clear_rate_limit(f"login_{user_id}")
clear_rate_limit(f"resend_{user_id}")
# Get session and active company BEFORE editing message
from backend.modules.telegram.agent.session import get_session_manager
from backend.modules.telegram.bot.menus import create_main_menu, pad_message_for_wide_buttons
session_manager = get_session_manager()
session = await session_manager.get_or_create_session(user_id)
company = session.get_active_company()
company_name = company['name'] if company else None
company_cui = company.get('cui') if company else None
# Create main menu keyboard
keyboard = create_main_menu(
company_name=company_name,
company_cui=company_cui,
is_authenticated=True, # Now authenticated
cache_enabled=True # Default enabled
)
# Menu message with company info
companies_count = len(response.get('companies', []))
if company_name:
menu_text = f"{company_name}"
else:
menu_text = f"Companii disponibile: {companies_count}\n\nSelectează o companie pentru a continua"
menu_message = pad_message_for_wide_buttons(menu_text)
# EDIT login message to show menu (no deletion, direct edit)
await edit_login_message(
context=context,
chat_id=update.effective_chat.id,
text=menu_message,
reply_markup=keyboard
)
# Clear sensitive data from context
context.user_data.clear()
logger.info(f"User {user_id} authenticated successfully via email")
return ConversationHandler.END
except Exception as e:
logger.error(f"Error during password verification: {e}", exc_info=True)
await edit_login_message(
context=context,
chat_id=update.effective_chat.id,
text="Eroare la autentificare.\n\nIncearcă din nou cu /login"
)
return ConversationHandler.END
# ============================================================================
# CANCEL HANDLER
# ============================================================================
async def cancel_login(update: Update, context: ContextTypes.DEFAULT_TYPE):
"""Cancel conversation"""
# EDIT login message to show cancellation (don't delete)
if update.callback_query:
# Called from button
await edit_login_message(
context=context,
chat_id=update.effective_chat.id,
text="Login anulat",
reply_markup=None
)
elif update.message:
# Called from /cancel command
await edit_login_message(
context=context,
chat_id=update.effective_chat.id,
text="Login anulat",
reply_markup=None
)
# Clear context
context.user_data.clear()
return ConversationHandler.END
async def conversation_timeout(update: Update, context: ContextTypes.DEFAULT_TYPE):
"""Handler for conversation timeout"""
# EDIT login message to show timeout
await edit_login_message(
context=context,
chat_id=update.effective_chat.id,
text="Sesiune expirată\n\nConversația a expirat după 5 minute.\n\nIncearcă din nou cu /login"
)
# Clear context
context.user_data.clear()
return ConversationHandler.END
# ============================================================================
# CONVERSATION HANDLER SETUP
# ============================================================================
email_login_handler = ConversationHandler(
entry_points=[
CommandHandler('login', login_command),
CallbackQueryHandler(action_login_callback, pattern='^action:login$'),
CallbackQueryHandler(email_login_callback, pattern='^email_login$')
],
states={
AWAITING_EMAIL: [
MessageHandler(filters.TEXT & ~filters.COMMAND, receive_email)
],
AWAITING_CODE: [
MessageHandler(filters.TEXT & ~filters.COMMAND, receive_code),
CallbackQueryHandler(resend_code_callback, pattern='^resend:')
],
AWAITING_PASSWORD: [
MessageHandler(filters.TEXT & ~filters.COMMAND, receive_password)
],
},
fallbacks=[
CommandHandler('cancel', cancel_login),
CallbackQueryHandler(cancel_login, pattern='^cancel$'),
CallbackQueryHandler(web_login_info_callback, pattern='^web_login_info$')
],
per_message=False, # Track conversation per user, not per message
allow_reentry=True, # Allow starting new conversation even if previous one is active
name="email_login_conversation"
)

View File

@@ -0,0 +1,629 @@
"""
Response formatters for bot commands.
Formats API responses into user-friendly Telegram messages.
"""
from typing import Dict, List, Any
def format_dashboard_response(data: Dict[str, Any], company_name: str = None) -> str:
"""
Format dashboard data for Telegram (content only, no header).
Note: company_name parameter kept for backwards compatibility but not used.
Use format_response_with_company() in handlers to add company header.
"""
text = ""
# Sold total trezorerie (casa + banca) - rotunjit la leu
treasury_totals = data.get('treasury_totals_by_currency', {})
sold_trezorerie = round(float(treasury_totals.get('RON', 0)))
text += f"**Sold Trezorerie:** {sold_trezorerie:,} RON\n\n"
# Sold Clienți - rotunjit la leu
clienti_sold = round(float(data.get('clienti_sold_total', 0)))
clienti_in_termen = round(float(data.get('clienti_sold_in_termen', 0)))
clienti_restant = round(float(data.get('clienti_sold_restant', 0)))
text += f"**Sold Clienți:** {clienti_sold:,} RON\n"
text += f" - În termen: {clienti_in_termen:,} RON\n"
text += f" - Restanță: {clienti_restant:,} RON\n\n"
# Sold Furnizori BRUT (pentru consistență cu detaliile) - rotunjit la leu
furnizori_in_termen = round(float(data.get('furnizori_sold_in_termen', 0)))
furnizori_restant = round(float(data.get('furnizori_sold_restant', 0)))
furnizori_sold_brut = furnizori_in_termen + furnizori_restant
furnizori_avansuri = round(float(data.get('furnizori_avansuri', 0)))
furnizori_sold_net = round(float(data.get('furnizori_sold_total', 0)))
text += f"**Sold Furnizori:** {furnizori_sold_brut:,} RON\n"
text += f" - În termen: {furnizori_in_termen:,} RON\n"
text += f" - Restanță: {furnizori_restant:,} RON\n"
if furnizori_avansuri != 0:
text += f" - Avansuri: {furnizori_avansuri:,} RON\n"
text += f" - Net (după avansuri): {furnizori_sold_net:,} RON"
else:
text += f" - Net: {furnizori_sold_net:,} RON"
# Solduri TVA - rotunjit la leu
tva_plata_prec = round(float(data.get('tva_plata_precedent', 0)))
tva_recup_prec = round(float(data.get('tva_recuperat_precedent', 0)))
tva_plata_cur = round(float(data.get('tva_plata_curent', 0)))
tva_recup_cur = round(float(data.get('tva_recuperat_curent', 0)))
# Afișează secțiunea doar dacă există cel puțin o valoare > 0
if tva_plata_prec > 0 or tva_recup_prec > 0 or tva_plata_cur > 0 or tva_recup_cur > 0:
text += "\n\n**Solduri TVA:**\n"
if tva_plata_prec > 0:
text += f" - TVA de plată precedent: {tva_plata_prec:,} RON\n"
if tva_recup_prec > 0:
text += f" - TVA de recuperat precedent: {tva_recup_prec:,} RON\n"
if tva_plata_cur > 0:
text += f" - TVA de plată curent: {tva_plata_cur:,} RON\n"
if tva_recup_cur > 0:
text += f" - TVA de recuperat curent: {tva_recup_cur:,} RON\n"
return text
def format_invoices_response(
invoices: List[Dict[str, Any]],
company_name: str = None,
limit: int = 10
) -> str:
"""
Format invoices list for Telegram - COMPACT TABLE FORMAT.
Args:
invoices: List of invoice dicts
company_name: Company name (kept for compatibility, not used)
limit: Maximum number of invoices to display
Returns:
Formatted Markdown string for Telegram (compact, no emojis)
"""
if not invoices:
return "Nu s-au gasit facturi cu aceste criterii."
# Header (o singură dată)
text = f"**Facturi** ({len(invoices)} total)\n\n"
text += "Nr | Client | Suma | Status\n"
text += "---|--------|------|-------\n"
# Lista facturi - compact, o linie per factură
for idx, inv in enumerate(invoices[:limit], 1):
seria = inv.get('seria', '')
numar = inv.get('numar', '')
client = inv.get('client', 'N/A')
suma = inv.get('suma_totala', 0)
status = inv.get('status', 'N/A')
# Truncate long client names for compact display
client_short = client[:20] + "..." if len(client) > 20 else client
# Status marker (no emoji)
status_marker = "PLATIT" if status == "platit" else "NEPLATIT"
text += f"{seria}{numar} | {client_short} | {suma:,.0f} | {status_marker}\n"
if len(invoices) > limit:
text += f"\n+{len(invoices) - limit} facturi"
return text
# =========================================================================
# FAZA 2: New Formatter Functions for Button Interface
# =========================================================================
def format_treasury_casa_response(data: Dict[str, Any], company_name: str = None) -> str:
"""
Format treasury CASH data for Telegram (content only, no header).
Args:
data: Dict with casa accounts and total from treasury breakdown
company_name: Company name (kept for compatibility, not used)
Returns:
Formatted Markdown string for Telegram
Example:
data = {'accounts': [...], 'total': 5000}
text = format_treasury_casa_response(data)
"""
text = ""
# Total cash balance - rotunjit la leu (0 zecimale)
total_cash = round(data.get('total', 0))
text += f"**Sold Total Cash:** {total_cash:,} RON\n\n"
# Cash accounts
casa_accounts = data.get('accounts', [])
if casa_accounts:
text += "**Conturi de Casa:**\n"
for acc in casa_accounts: # Show all accounts
name = acc.get('name', 'N/A')
balance = round(acc.get('balance', 0))
text += f" - {name}: {balance:,} RON\n"
else:
text += "Nu exista conturi de casa configurate."
return text
def format_treasury_banca_response(data: Dict[str, Any], company_name: str = None) -> str:
"""
Format treasury BANK data for Telegram (content only, no header).
Args:
data: Dict with banca accounts and total from treasury breakdown
company_name: Company name (kept for compatibility, not used)
Returns:
Formatted Markdown string for Telegram
Example:
data = {'accounts': [...], 'total': 15000}
text = format_treasury_banca_response(data)
"""
text = ""
# Total bank balance - rotunjit la leu (0 zecimale)
total_bank = round(data.get('total', 0))
text += f"**Sold Total Banca:** {total_bank:,} RON\n\n"
# Bank accounts
bank_accounts = data.get('accounts', [])
if bank_accounts:
text += "**Conturi Bancare:**\n"
for acc in bank_accounts: # Show all accounts
name = acc.get('name', 'N/A')
balance = round(acc.get('balance', 0))
text += f" - {name}: {balance:,} RON\n"
else:
text += "Nu exista conturi bancare configurate."
return text
def format_clients_balance_response(
clients: List[Dict[str, Any]],
maturity_data: Dict[str, Any],
company_name: str = None
) -> str:
"""
Format clients balance with maturity breakdown (content only, no header).
Args:
clients: List of client dicts with id, name, balance
maturity_data: Dict with in_term, overdue, total
company_name: Company name (kept for compatibility, not used)
Returns:
Formatted Markdown string for Telegram
Example:
clients = [{'id': 1, 'name': 'Client A', 'balance': 15000}]
maturity = {'in_term': 10000, 'overdue': 5000, 'total': 15000}
text = format_clients_balance_response(clients, maturity)
"""
text = ""
# Maturity breakdown - rotunjit la leu (0 zecimale)
total = round(maturity_data.get('total', 0))
in_term = round(maturity_data.get('in_term', 0))
overdue = round(maturity_data.get('overdue', 0))
text += f"**Sold Total:** {total:,} RON\n\n"
text += "**Defalcare:**\n"
text += f" - In termen: {in_term:,} RON\n"
text += f" - Restanta: {overdue:,} RON\n\n"
# Top 10 clients
if clients:
text += f"**Top 10 Clienti** ({len(clients)} total):\n"
# Sort by balance descending
sorted_clients = sorted(
clients,
key=lambda x: x.get('balance', 0),
reverse=True
)
for idx, client in enumerate(sorted_clients[:10], 1):
name = client.get('name', 'N/A')
balance = round(client.get('balance', 0))
text += f"{idx}. {name}: {balance:,} RON\n"
if len(clients) > 10:
text += f"\nApasa butonul pentru lista completa"
else:
text += "Nu exista clienti cu solduri."
return text
def format_suppliers_balance_response(
suppliers: List[Dict[str, Any]],
maturity_data: Dict[str, Any],
company_name: str = None
) -> str:
"""
Format suppliers balance with maturity breakdown (content only, no header).
Args:
suppliers: List of supplier dicts with id, name, balance
maturity_data: Dict with in_term, overdue, total
company_name: Company name (kept for compatibility, not used)
Returns:
Formatted Markdown string for Telegram
Example:
suppliers = [{'id': 1, 'name': 'Supplier A', 'balance': 5000}]
maturity = {'in_term': 4000, 'overdue': 1000, 'total': 5000}
text = format_suppliers_balance_response(suppliers, maturity)
"""
text = ""
# Maturity breakdown - rotunjit la leu (0 zecimale)
total = round(maturity_data.get('total', 0))
in_term = round(maturity_data.get('in_term', 0))
overdue = round(maturity_data.get('overdue', 0))
text += f"**Sold Total:** {total:,} RON\n\n"
text += "**Defalcare:**\n"
text += f" - In termen: {in_term:,} RON\n"
text += f" - Restanta: {overdue:,} RON\n\n"
# Top 10 suppliers
if suppliers:
text += f"**Top 10 Furnizori** ({len(suppliers)} total):\n"
# Sort by balance descending
sorted_suppliers = sorted(
suppliers,
key=lambda x: x.get('balance', 0),
reverse=True
)
for idx, supplier in enumerate(sorted_suppliers[:10], 1):
name = supplier.get('name', 'N/A')
balance = round(supplier.get('balance', 0))
text += f"{idx}. {name}: {balance:,} RON\n"
if len(suppliers) > 10:
text += f"\nApasa butonul pentru lista completa"
else:
text += "Nu exista furnizori cu solduri."
return text
def format_cashflow_evolution_response(
performance_data: Dict[str, Any],
monthly_data: Dict[str, Any],
company_name: str = None
) -> str:
"""
Format cash flow evolution data - Table format with mini-charts.
Args:
performance_data: Dict with current_year and previous_year YTD data
monthly_data: Dict with months, incasari, plati arrays + prev year data
company_name: Company name (kept for compatibility, not used)
Returns:
Formatted Markdown string for Telegram (monospace table)
Example:
YTD 2024 vs 2023:
2024 2023 Δ Trend
Inc: 500,000 480,000 +4.2% ████░
Plt: 450,000 440,000 +2.3% ███░
Net: 50,000 40,000 +25.0% █████
"""
text = ""
# Helper functions
def calc_percent_change(current: float, previous: float) -> str:
"""Calculate percentage change: +4.2% or -3.5%"""
if previous == 0:
return "+100%" if current > 0 else "0.0%"
change = ((current - previous) / previous) * 100
sign = "+" if change >= 0 else ""
return f"{sign}{change:.1f}%"
def create_mini_chart(current: float, previous: float, width: int = 5) -> str:
"""Create mini bar chart: ████░ (proportional bars)"""
if current == 0 and previous == 0:
return "" * width
max_val = max(current, previous)
if max_val == 0:
return "" * width
curr_bars = int((current / max_val) * width)
prev_bars = int((previous / max_val) * width)
# Use filled and light blocks
filled = "" * curr_bars
light = "" * (width - curr_bars)
return filled + light
def get_trend_arrow(current: float, previous: float) -> str:
"""Get trend arrow: ↑ or ↓ or →"""
if current > previous * 1.02: # More than 2% increase
return ""
elif current < previous * 0.98: # More than 2% decrease
return ""
else:
return ""
# Extract YTD data
current = performance_data.get('current_year', {})
previous = performance_data.get('previous_year', {})
current_year = current.get('year', '2024')
previous_year = previous.get('year', '2023')
inc_cur = round(current.get('incasari', 0))
plt_cur = round(current.get('plati', 0))
net_cur = round(current.get('net', 0))
inc_prev = round(previous.get('incasari', 0))
plt_prev = round(previous.get('plati', 0))
net_prev = round(previous.get('net', 0))
# YTD Table Header
text += f"**YTD {current_year} vs {previous_year}:**\n"
text += f"` {current_year:>10} {previous_year:>10} Δ `\n"
# YTD Rows
inc_pct = calc_percent_change(inc_cur, inc_prev)
text += f"`Inc: {inc_cur:>10,} {inc_prev:>10,} {inc_pct:>6}`\n"
plt_pct = calc_percent_change(plt_cur, plt_prev)
text += f"`Plt: {plt_cur:>10,} {plt_prev:>10,} {plt_pct:>6}`\n"
net_pct = calc_percent_change(net_cur, net_prev)
text += f"`Net: {net_cur:>10,} {net_prev:>10,} {net_pct:>6}`\n\n"
# Monthly Evolution Table - Simplified (Net only)
months = monthly_data.get('months', [])
incasari = monthly_data.get('incasari', [])
plati = monthly_data.get('plati', [])
incasari_prev = monthly_data.get('incasari_prev', [])
plati_prev = monthly_data.get('plati_prev', [])
if months and len(months) > 0:
text += "**Evolutie Net (12 luni):**\n"
text += f"` {current_year:>10} {previous_year:>10} Δ `\n"
for i, month in enumerate(months):
inc = incasari[i] if i < len(incasari) else 0
plt = plati[i] if i < len(plati) else 0
inc_p = incasari_prev[i] if i < len(incasari_prev) else 0
plt_p = plati_prev[i] if i < len(plati_prev) else 0
net = inc - plt
net_p = inc_p - plt_p
# Extract short month name (first 3 chars before apostrophe)
month_short = month.split("'")[0][:3] if "'" in month else month[:3]
# Calculate percentage change
net_pct = calc_percent_change(net, net_p)
# Format row: Luna Net'current Net'prev Δ (aligned with YTD)
text += f"`{month_short:<4} {int(net):>10,} {int(net_p):>10,} {net_pct:>6}`\n"
else:
text += "Nu exista date lunare disponibile."
return text
def format_client_detail_response(
client: Dict[str, Any],
invoices: List[Dict[str, Any]],
company_name: str = None
) -> str:
"""
Format client details with invoices - COMPACT TABLE FORMAT.
Args:
client: Dict with client info (id, name, balance)
invoices: List of invoice dicts for this client
company_name: Company name (kept for compatibility, not used)
Returns:
Formatted Markdown string for Telegram (compact, no emojis)
Example:
client = {'id': 1, 'name': 'Client A', 'balance': 15000}
invoices = [{'id': 1, 'number': 'FV001', 'amount': 5000, 'status': 'unpaid'}]
text = format_client_detail_response(client, invoices)
"""
client_name = client.get('name', 'N/A')
balance = client.get('balance', 0)
# Header with client info
text = f"**{client_name}**\n"
text += f"**Sold total: {balance:,.2f} RON**"
if invoices and len(invoices) > 1:
text += f"{len(invoices)} facturi"
text += "\n\n"
# Invoices - compact table format (no emojis)
if invoices:
from datetime import datetime
# Sort invoices by date (most recent first)
sorted_invoices = sorted(invoices, key=lambda x: x.get('dataact') or datetime.min, reverse=True)
# Invoice list - simple format without table
text += "Facturi cu sold:\n"
text += "━━━━━━━━━━━━━━━━━━━━\n"
# Invoice rows - one line each, simple format
for inv in sorted_invoices[:10]:
# Backend returns: nract, totctva, soldfinal, datascad, dataact, achitat
number = str(inv.get('nract', 'N/A'))
dataact = inv.get('dataact')
# Parse date - handle various formats to ensure dd.mm.yyyy
if dataact:
if isinstance(dataact, str):
try:
# Try ISO format first: "2024-10-25" or "2024-10-25 00:00:00"
if '-' in dataact and len(dataact) >= 10:
parsed_date = datetime.strptime(dataact[:10], '%Y-%m-%d')
date_str = parsed_date.strftime('%d.%m.%Y')
# Already in dd.mm.yyyy format
elif '.' in dataact:
date_str = dataact.split()[0][:10] # Take just date part
else:
date_str = dataact[:10] if len(dataact) >= 10 else dataact
except:
date_str = dataact[:10] if len(dataact) >= 10 else dataact
else:
# Datetime object - format as dd.mm.yyyy
date_str = dataact.strftime('%d.%m.%Y')
else:
date_str = 'N/A'
sold = float(inv.get('soldfinal', 0) or 0)
# Simple format: Nr • Data • Sold
text += f"Nr {number}{date_str}{sold:,.2f} RON\n"
if len(invoices) > 10:
text += f"\n\n+{len(invoices) - 10} facturi"
else:
text += "Nu exista facturi neachitate"
return text
def format_supplier_detail_response(
supplier: Dict[str, Any],
invoices: List[Dict[str, Any]],
company_name: str = None
) -> str:
"""
Format supplier details with invoices - COMPACT TABLE FORMAT.
Args:
supplier: Dict with supplier info (id, name, balance)
invoices: List of invoice dicts for this supplier
company_name: Company name (kept for compatibility, not used)
Returns:
Formatted Markdown string for Telegram (compact, no emojis)
Example:
supplier = {'id': 1, 'name': 'Supplier A', 'balance': 5000}
invoices = [{'id': 1, 'number': 'FC001', 'amount': 2000, 'status': 'unpaid'}]
text = format_supplier_detail_response(supplier, invoices)
"""
supplier_name = supplier.get('name', 'N/A')
balance = supplier.get('balance', 0)
# Header with supplier info
text = f"**{supplier_name}**\n"
text += f"**Sold total: {balance:,.2f} RON**"
if invoices and len(invoices) > 1:
text += f"{len(invoices)} facturi"
text += "\n\n"
# Invoices - compact table format (no emojis)
if invoices:
from datetime import datetime
# Sort invoices by date (most recent first)
sorted_invoices = sorted(invoices, key=lambda x: x.get('dataact') or datetime.min, reverse=True)
# Invoice list - simple format without table
text += "Facturi cu sold:\n"
text += "━━━━━━━━━━━━━━━━━━━━\n"
# Invoice rows - one line each, simple format
for inv in sorted_invoices[:10]:
# Backend returns: nract, totctva, soldfinal, datascad, dataact, achitat
number = str(inv.get('nract', 'N/A'))
dataact = inv.get('dataact')
# Parse date - handle various formats to ensure dd.mm.yyyy
if dataact:
if isinstance(dataact, str):
try:
# Try ISO format first: "2024-10-25" or "2024-10-25 00:00:00"
if '-' in dataact and len(dataact) >= 10:
parsed_date = datetime.strptime(dataact[:10], '%Y-%m-%d')
date_str = parsed_date.strftime('%d.%m.%Y')
# Already in dd.mm.yyyy format
elif '.' in dataact:
date_str = dataact.split()[0][:10] # Take just date part
else:
date_str = dataact[:10] if len(dataact) >= 10 else dataact
except:
date_str = dataact[:10] if len(dataact) >= 10 else dataact
else:
# Datetime object - format as dd.mm.yyyy
date_str = dataact.strftime('%d.%m.%Y')
else:
date_str = 'N/A'
sold = float(inv.get('soldfinal', 0) or 0)
# Simple format: Nr • Data • Sold
text += f"Nr {number}{date_str}{sold:,.2f} RON\n"
if len(invoices) > 10:
text += f"\n\n+{len(invoices) - 10} facturi"
else:
text += "Nu exista facturi neachitate"
return text
# =========================================================================
# FAZA 6: Performance Footer for Cache Monitoring
# =========================================================================
def add_performance_footer(message: str, cache_hit: bool, time_ms: float, cache_source: str = None) -> str:
"""
Add compact performance footer to bot responses.
Shows data source (cached L1/L2 or database) and response time.
Format: "cached L1 | 15ms", "cached L2 | 25ms" or "db | 285ms"
Args:
message: Existing message text
cache_hit: True if data came from cache
time_ms: Response time in milliseconds
cache_source: Cache source ("L1" for memory, "L2" for SQLite) if cache_hit is True
Returns:
Message with performance footer appended
Example:
>>> add_performance_footer("Dashboard data...", True, 52.3, "L1")
"Dashboard data...\n\ncached L1 | 52ms"
>>> add_performance_footer("Dashboard data...", True, 25.8, "L2")
"Dashboard data...\n\ncached L2 | 26ms"
>>> add_performance_footer("Dashboard data...", False, 285.7)
"Dashboard data...\n\ndb | 286ms"
"""
if cache_hit and cache_source:
source = f"cached {cache_source}"
elif cache_hit:
source = "cached" # Fallback if source not provided
else:
source = "db"
footer = f"\n\n`{source} | {time_ms:.0f}ms`"
return message + footer

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,814 @@
"""
Helper functions for Telegram bot command handlers.
Provides utilities for company selection, API calls, and response formatting.
"""
import logging
from typing import Optional, Dict, List, Any
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from backend.modules.telegram.api.client import get_backend_client
from backend.modules.telegram.agent.session import SessionManager
from backend.modules.telegram.bot.menus import pad_message_for_wide_buttons
logger = logging.getLogger(__name__)
async def get_active_company_or_prompt(
update: Update,
session_manager: SessionManager,
telegram_user_id: int
) -> Optional[Dict[str, Any]]:
"""
Get active company from session or prompt user to select one with buttons.
This function checks if the user has an active company set in their session.
If not, it fetches companies and displays selection buttons directly.
Args:
update: Telegram Update object (for sending messages)
session_manager: SessionManager instance
telegram_user_id: Telegram user ID
Returns:
Dict with company info (id, name, cui) if set, None if user needs to select
Example:
company = await get_active_company_or_prompt(update, session_manager, user_id)
if not company:
return # User was shown company selection buttons
# Continue with company operations...
"""
session = await session_manager.get_or_create_session(telegram_user_id)
company = session.get_active_company()
if not company:
# Get auth data and companies
from backend.modules.telegram.auth.linking import get_user_auth_data
auth_data = await get_user_auth_data(telegram_user_id)
jwt_token = auth_data['jwt_token']
client = get_backend_client()
async with client:
companies = await client.get_user_companies(jwt_token=jwt_token)
if companies:
keyboard = create_company_selection_keyboard_paginated(companies, page=0)
message = (
f"**Selecteaza mai intai o companie**\n\n"
f"Companiile tale ({len(companies)}):"
)
# Apply padding to make inline keyboard buttons wider
message = pad_message_for_wide_buttons(message)
await update.message.reply_text(
message,
reply_markup=keyboard,
parse_mode="Markdown"
)
else:
await update.message.reply_text(
"Nu ai acces la nicio companie.\n"
"Contacteaza administratorul.",
parse_mode="Markdown"
)
return None
return company
async def search_companies_by_name(
name_query: str,
jwt_token: str
) -> List[Dict[str, Any]]:
"""
Search companies by partial name match (case-insensitive).
Fetches all companies from backend and filters them by name.
Uses case-insensitive partial matching for flexible search.
Args:
name_query: Search term (partial match, e.g., "ACME")
jwt_token: JWT authentication token
Returns:
List of matching company dicts (each with id, nume_firma, cui, etc.)
Example:
companies = await search_companies_by_name("acme", token)
# Returns all companies with "acme" in their name (case-insensitive)
"""
client = get_backend_client()
async with client:
all_companies = await client.get_user_companies(jwt_token=jwt_token)
# Filter by name (case-insensitive partial match)
query_lower = name_query.lower()
matches = [
comp for comp in all_companies
if query_lower in comp.get('name', comp.get('nume_firma', '')).lower()
]
logger.info(
f"Search '{name_query}': {len(matches)} matches out of {len(all_companies)} total"
)
return matches
def create_company_selection_keyboard(
companies: List[Dict[str, Any]],
max_buttons: int = 10
) -> InlineKeyboardMarkup:
"""
Create inline keyboard for company selection (legacy - without pagination).
Generates a vertical list of buttons, one per company.
Each button shows company name and CUI, and triggers a callback.
NOTE: This function is deprecated in favor of create_company_selection_keyboard_paginated.
It's kept for backwards compatibility only.
Args:
companies: List of company dicts (with id, nume_firma, cui)
max_buttons: Maximum number of buttons to show (default: 10)
Returns:
InlineKeyboardMarkup with company selection buttons
Example:
keyboard = create_company_selection_keyboard(companies)
await update.message.reply_text("Select company:", reply_markup=keyboard)
"""
keyboard = []
for company in companies[:max_buttons]:
company_id = company.get('id_firma', company.get('id'))
company_name = company.get('name', company.get('nume_firma', 'N/A'))
company_cui = company.get('fiscal_code', company.get('cui', ''))
# Button text: "ACME SRL (CUI: 12345)"
button_text = f"{company_name}"
if company_cui:
button_text += f" ({company_cui})"
# Callback data: "select_company:123"
callback_data = f"select_company:{company_id}"
keyboard.append([InlineKeyboardButton(button_text, callback_data=callback_data)])
# Add overflow indicator if there are more companies
if len(companies) > max_buttons:
keyboard.append([InlineKeyboardButton(
f"... și încă {len(companies) - max_buttons} companii",
callback_data="noop"
)])
return InlineKeyboardMarkup(keyboard)
def create_company_selection_keyboard_paginated(
companies: List[Dict[str, Any]],
page: int = 0,
per_page: int = 10
) -> InlineKeyboardMarkup:
"""
Create paginated inline keyboard for company selection.
Generates a vertical list of buttons for one page of companies,
with navigation buttons for previous/next pages.
Args:
companies: Full list of company dicts (with id, nume_firma, cui)
page: Current page number (0-indexed)
per_page: Number of companies per page (default: 10)
Returns:
InlineKeyboardMarkup with company buttons and pagination controls
Example:
keyboard = create_company_selection_keyboard_paginated(companies, page=0)
await update.message.reply_text("Select company:", reply_markup=keyboard)
"""
keyboard = []
# Calculate pagination
total_companies = len(companies)
total_pages = (total_companies + per_page - 1) // per_page # Ceiling division
start_idx = page * per_page
end_idx = min(start_idx + per_page, total_companies)
# Display companies for current page
page_companies = companies[start_idx:end_idx]
for company in page_companies:
company_id = company.get('id_firma', company.get('id'))
company_name = company.get('name', company.get('nume_firma', 'N/A'))
company_cui = company.get('fiscal_code', company.get('cui', ''))
# Button text: "ACME SRL (CUI: 12345)"
button_text = f"{company_name}"
if company_cui:
button_text += f" ({company_cui})"
# Callback data: "select_company:123"
callback_data = f"select_company:{company_id}"
keyboard.append([InlineKeyboardButton(button_text, callback_data=callback_data)])
# Pagination controls (only if more than one page)
if total_pages > 1:
nav_buttons = []
# Previous button
if page > 0:
nav_buttons.append(
InlineKeyboardButton("< Anterior", callback_data=f"select_company_page:{page-1}")
)
# Page indicator (non-clickable)
nav_buttons.append(
InlineKeyboardButton(f"Pagina {page+1}/{total_pages}", callback_data="noop")
)
# Next button
if page < total_pages - 1:
nav_buttons.append(
InlineKeyboardButton("Urmator >", callback_data=f"select_company_page:{page+1}")
)
keyboard.append(nav_buttons)
# Back to menu button
keyboard.append([
InlineKeyboardButton("< Inapoi la Meniu", callback_data="action:menu")
])
return InlineKeyboardMarkup(keyboard)
def format_company_context_footer(company_name: str) -> str:
"""
Format discrete footer with company context.
Adds a subtle footer to command responses showing the active company
and a quick link to change it.
Args:
company_name: Active company name
Returns:
Formatted footer string with separator and company name
Example:
footer = format_company_context_footer("ACME SRL")
message = f"Dashboard data...\n{footer}"
# Output: "Dashboard data...\n\n━━━━━━━━━━━━━━\nCompanie: ACME SRL"
"""
return f"\n\n━━━━━━━━━━━━━━\nCompanie: {company_name}"
# =========================================================================
# FAZA 2: New Helper Functions for Button Interface
# =========================================================================
async def get_treasury_breakdown_split(
company_id: int,
jwt_token: str
) -> Optional[Dict[str, Any]]:
"""
Get treasury breakdown split into casa and banca.
Fetches treasury breakdown from backend and transforms it
to the format expected by formatters.
Backend returns:
{
"total": float,
"breakdown": {
"casa": {"total": float, "items": [{"nume": str, "cont": str, "sold": float}]},
"banca": {"total": float, "items": [{"nume": str, "cont": str, "sold": float}]}
},
"currency": "RON"
}
Args:
company_id: Company ID
jwt_token: JWT authentication token
Returns:
Dict with two keys:
- 'casa': Dict with 'accounts' (list) and 'total' (float)
- 'banca': Dict with 'accounts' (list) and 'total' (float)
None if request fails
Example:
data = await get_treasury_breakdown_split(1, token)
casa_total = data['casa']['total'] # Total cash balance
bank_accounts = data['banca']['accounts'] # List of bank accounts
"""
try:
client = get_backend_client()
async with client:
breakdown = await client.get_treasury_breakdown(
company_id=company_id,
jwt_token=jwt_token
)
if not breakdown:
return None
# Backend already splits data into casa and banca
# Transform backend structure to match formatter expectations
breakdown_data = breakdown.get('breakdown', {})
casa_data = breakdown_data.get('casa', {})
banca_data = breakdown_data.get('banca', {})
# Transform items to accounts format (nume->name, sold->balance)
casa_accounts = [
{
'name': item.get('nume', f"Cont {item.get('cont', 'N/A')}"),
'balance': float(item.get('sold', 0)),
'cont': item.get('cont', '')
}
for item in casa_data.get('items', [])
]
banca_accounts = [
{
'name': item.get('nume', f"Cont {item.get('cont', 'N/A')}"),
'balance': float(item.get('sold', 0)),
'cont': item.get('cont', '')
}
for item in banca_data.get('items', [])
]
result = {
'casa': {
'accounts': casa_accounts,
'total': float(casa_data.get('total', 0))
},
'banca': {
'accounts': banca_accounts,
'total': float(banca_data.get('total', 0))
}
}
# Pass through cache metadata if present
if 'cache_hit' in breakdown:
result['cache_hit'] = breakdown['cache_hit']
if 'response_time_ms' in breakdown:
result['response_time_ms'] = breakdown['response_time_ms']
if 'cache_source' in breakdown:
result['cache_source'] = breakdown['cache_source']
return result
except Exception as e:
logger.error(f"Error getting treasury breakdown split: {e}", exc_info=True)
return None
async def get_clients_with_maturity(
company_id: int,
jwt_token: str
) -> Optional[Dict[str, Any]]:
"""
Get clients list with maturity breakdown.
Uses maturity analysis endpoint which returns client summaries
with amounts and overdue status.
Backend returns:
{
"clients": [{"name": str, "amount": float, "dueDate": str, "daysOverdue": int}],
"suppliers": [...],
"balance": float,
"metadata": {...}
}
Args:
company_id: Company ID
jwt_token: JWT authentication token
Returns:
Dict with:
- 'clients': List of client dicts (id, name, balance)
- 'maturity': Dict with 'in_term', 'overdue', 'total' amounts
None if request fails
Example:
data = await get_clients_with_maturity(1, token)
clients = data['clients'] # List of all clients
overdue = data['maturity']['overdue'] # Overdue amount
"""
try:
client = get_backend_client()
async with client:
# Get maturity analysis (contains client summaries)
maturity_response = await client.get_maturity_data(
company_id=company_id,
jwt_token=jwt_token,
period='all'
)
if not maturity_response:
return None
# Extract clients from maturity response
clients_raw = maturity_response.get('clients', [])
# Transform to expected format: amount → balance
clients = [
{
'name': c.get('name', 'N/A'),
'balance': float(c.get('amount', 0)),
'daysOverdue': c.get('daysOverdue', 0)
}
for c in clients_raw
]
# Calculate maturity breakdown from clients data
total = sum(c['balance'] for c in clients)
overdue = sum(c['balance'] for c in clients if c.get('daysOverdue', 0) > 0)
in_term = total - overdue
result = {
'clients': clients,
'maturity': {
'in_term': in_term,
'overdue': overdue,
'total': total
}
}
# Pass through cache metadata if present
if 'cache_hit' in maturity_response:
result['cache_hit'] = maturity_response['cache_hit']
if 'response_time_ms' in maturity_response:
result['response_time_ms'] = maturity_response['response_time_ms']
if 'cache_source' in maturity_response:
result['cache_source'] = maturity_response['cache_source']
return result
except Exception as e:
logger.error(f"Error getting clients with maturity: {e}", exc_info=True)
return None
async def get_suppliers_with_maturity(
company_id: int,
jwt_token: str
) -> Optional[Dict[str, Any]]:
"""
Get suppliers list with maturity breakdown.
Uses maturity analysis endpoint which returns supplier summaries
with amounts and overdue status.
Backend returns:
{
"clients": [...],
"suppliers": [{"name": str, "amount": float, "dueDate": str, "daysOverdue": int}],
"balance": float,
"metadata": {...}
}
Args:
company_id: Company ID
jwt_token: JWT authentication token
Returns:
Dict with:
- 'suppliers': List of supplier dicts (id, name, balance)
- 'maturity': Dict with 'in_term', 'overdue', 'total' amounts
None if request fails
Example:
data = await get_suppliers_with_maturity(1, token)
suppliers = data['suppliers'] # List of all suppliers
in_term = data['maturity']['in_term'] # In-term amount
"""
try:
client = get_backend_client()
async with client:
# Get maturity analysis (contains supplier summaries)
maturity_response = await client.get_maturity_data(
company_id=company_id,
jwt_token=jwt_token,
period='all'
)
if not maturity_response:
return None
# Extract suppliers from maturity response
suppliers_raw = maturity_response.get('suppliers', [])
# Transform to expected format: amount → balance
suppliers = [
{
'name': s.get('name', 'N/A'),
'balance': float(s.get('amount', 0)),
'daysOverdue': s.get('daysOverdue', 0)
}
for s in suppliers_raw
]
# Calculate maturity breakdown from suppliers data
total = sum(s['balance'] for s in suppliers)
overdue = sum(s['balance'] for s in suppliers if s.get('daysOverdue', 0) > 0)
in_term = total - overdue
result = {
'suppliers': suppliers,
'maturity': {
'in_term': in_term,
'overdue': overdue,
'total': total
}
}
# Pass through cache metadata if present
if 'cache_hit' in maturity_response:
result['cache_hit'] = maturity_response['cache_hit']
if 'response_time_ms' in maturity_response:
result['response_time_ms'] = maturity_response['response_time_ms']
if 'cache_source' in maturity_response:
result['cache_source'] = maturity_response['cache_source']
return result
except Exception as e:
logger.error(f"Error getting suppliers with maturity: {e}", exc_info=True)
return None
async def get_cashflow_evolution_data(
company_id: int,
jwt_token: str,
period: str = "12m"
) -> Optional[Dict[str, Any]]:
"""
Get cash flow evolution data with YTD comparison.
Uses trends endpoint which returns 12-month historical data for current and previous year.
Calculates YTD for comparison and extracts last 12 months in reverse chronological order.
Args:
company_id: Company ID
jwt_token: JWT authentication token
period: Period for trends data (default: "12m")
Returns:
Dict with:
- 'performance': Dict with YTD data for current and previous year
- 'monthly': Dict with last 12 months data (reverse chronological) + prev year comparison
None if request fails
Example:
data = await get_cashflow_evolution_data(1, token)
ytd_2025 = data['performance']['current_year']
ytd_2024 = data['performance']['previous_year']
"""
try:
client = get_backend_client()
async with client:
# Get trends data (12 months of historical data)
trends_data = await client.get_trends(
company_id=company_id,
jwt_token=jwt_token,
period="12m"
)
if not trends_data:
return None
# Extract current year data
periods = trends_data.get('periods', []) # ["2024-01", "2024-02", ...]
clienti_incasat = trends_data.get('clienti_incasat', [])
furnizori_achitat = trends_data.get('furnizori_achitat', [])
# Extract previous year data
previous_periods = trends_data.get('previous_periods', [])
clienti_incasat_prev = trends_data.get('clienti_incasat_prev', [])
furnizori_achitat_prev = trends_data.get('furnizori_achitat_prev', [])
if not periods or not clienti_incasat or not furnizori_achitat:
logger.warning("Trends data missing required fields")
return None
# Calculate YTD (Year-To-Date) = sum of all available months
incasari_ytd = sum(clienti_incasat)
plati_ytd = sum(furnizori_achitat)
net_ytd = incasari_ytd - plati_ytd
incasari_ytd_prev = sum(clienti_incasat_prev) if clienti_incasat_prev else 0
plati_ytd_prev = sum(furnizori_achitat_prev) if furnizori_achitat_prev else 0
net_ytd_prev = incasari_ytd_prev - plati_ytd_prev
# Extract years from periods
current_year = periods[-1].split('-')[0] if periods else "2025"
previous_year = previous_periods[-1].split('-')[0] if previous_periods else "2024"
# Take last 12 months (current year)
last_12_periods = periods[-12:]
last_12_incasari = clienti_incasat[-12:]
last_12_plati = furnizori_achitat[-12:]
# Take corresponding previous year months
last_12_periods_prev = previous_periods[-12:] if previous_periods else []
last_12_incasari_prev = clienti_incasat_prev[-12:] if clienti_incasat_prev else [0] * 12
last_12_plati_prev = furnizori_achitat_prev[-12:] if furnizori_achitat_prev else [0] * 12
# Month abbreviations (Romanian)
month_abbr = {
'01': 'Ian', '02': 'Feb', '03': 'Mar', '04': 'Apr',
'05': 'Mai', '06': 'Iun', '07': 'Iul', '08': 'Aug',
'09': 'Sep', '10': 'Oct', '11': 'Noi', '12': 'Dec'
}
# Format months as "Noi'25/'24"
formatted_months = []
for i, period_str in enumerate(last_12_periods):
if '-' in period_str:
year = period_str.split('-')[0][-2:] # Last 2 digits: "25"
month_num = period_str.split('-')[1]
month_name = month_abbr.get(month_num, month_num)
# Get previous year month
prev_year = previous_year[-2:] if previous_year else "24"
formatted_months.append(f"{month_name}'{year}/'{prev_year}")
else:
formatted_months.append(period_str)
# Reverse chronological order (newest first)
formatted_months.reverse()
last_12_incasari.reverse()
last_12_plati.reverse()
last_12_incasari_prev.reverse()
last_12_plati_prev.reverse()
# Build performance summary (YTD)
performance = {
'current_year': {
'year': current_year,
'incasari': incasari_ytd,
'plati': plati_ytd,
'net': net_ytd
},
'previous_year': {
'year': previous_year,
'incasari': incasari_ytd_prev,
'plati': plati_ytd_prev,
'net': net_ytd_prev
}
}
# Build monthly breakdown (reverse chronological with prev year comparison)
monthly = {
'months': formatted_months,
'incasari': last_12_incasari,
'plati': last_12_plati,
'incasari_prev': last_12_incasari_prev,
'plati_prev': last_12_plati_prev
}
result = {
'performance': performance,
'monthly': monthly
}
# Pass through cache metadata if present
if 'cache_hit' in trends_data:
result['cache_hit'] = trends_data['cache_hit']
if 'response_time_ms' in trends_data:
result['response_time_ms'] = trends_data['response_time_ms']
if 'cache_source' in trends_data:
result['cache_source'] = trends_data['cache_source']
return result
except Exception as e:
logger.error(f"Error getting cashflow evolution data: {e}", exc_info=True)
return None
async def get_client_invoices(
company_id: int,
client_name: str,
jwt_token: str
) -> List[Dict[str, Any]]:
"""
Get invoices for a specific client.
Args:
company_id: Company ID
client_name: Client name to filter by
jwt_token: JWT authentication token
Returns:
List of invoice dicts for the specified client
Example:
invoices = await get_client_invoices(1, "ACME Corp", token)
for inv in invoices:
print(inv['number'], inv['amount'])
"""
try:
logger.info(f"Fetching invoices for client '{client_name}' (company_id={company_id})")
client = get_backend_client()
async with client:
# Filter only by unpaid invoices (with balance > 0)
invoices = await client.search_invoices(
company_id=company_id,
jwt_token=jwt_token,
filters={
'partner_type': 'CLIENTI',
'partner_name': client_name,
'only_unpaid': True # Only show unpaid invoices (matching balance > 0)
}
)
logger.info(f"Found {len(invoices) if invoices else 0} invoices for client '{client_name}'")
if invoices:
logger.debug(f"First invoice sample: {invoices[0]}")
return invoices or []
except Exception as e:
logger.error(f"Error getting client invoices for '{client_name}': {e}", exc_info=True)
return []
async def get_supplier_invoices(
company_id: int,
supplier_name: str,
jwt_token: str
) -> List[Dict[str, Any]]:
"""
Get invoices for a specific supplier.
Args:
company_id: Company ID
supplier_name: Supplier name to filter by
jwt_token: JWT authentication token
Returns:
List of invoice dicts for the specified supplier
Example:
invoices = await get_supplier_invoices(1, "Supplier Inc", token)
for inv in invoices:
print(inv['number'], inv['amount'])
"""
try:
logger.info(f"Fetching invoices for supplier '{supplier_name}' (company_id={company_id})")
client = get_backend_client()
async with client:
# Filter only by unpaid invoices (with balance > 0)
invoices = await client.search_invoices(
company_id=company_id,
jwt_token=jwt_token,
filters={
'partner_type': 'FURNIZORI',
'partner_name': supplier_name,
'only_unpaid': True # Only show unpaid invoices (matching balance > 0)
}
)
logger.info(f"Found {len(invoices) if invoices else 0} invoices for supplier '{supplier_name}'")
if invoices:
logger.debug(f"First invoice sample: {invoices[0]}")
return invoices or []
except Exception as e:
logger.error(f"Error getting supplier invoices for '{supplier_name}': {e}", exc_info=True)
return []
# Export all helper functions
__all__ = [
'get_active_company_or_prompt',
'search_companies_by_name',
'create_company_selection_keyboard',
'create_company_selection_keyboard_paginated',
'format_company_context_footer',
'get_treasury_breakdown_split',
'get_clients_with_maturity',
'get_suppliers_with_maturity',
'get_cashflow_evolution_data',
'get_client_invoices',
'get_supplier_invoices'
]

View File

@@ -0,0 +1,607 @@
"""
Menu builders for Telegram bot inline keyboards.
This module provides functions to create InlineKeyboardMarkup objects
for different menu levels and navigation patterns in the bot.
NOTE: All button texts are plain text WITHOUT emojis/icons as per requirements.
BUTTON WIDTH: Inline keyboard width is determined by the message text width.
To make buttons wider, we pad message text with invisible characters.
"""
from telegram import InlineKeyboardButton, InlineKeyboardMarkup
from typing import List, Dict, Optional
from datetime import datetime
# ============================================================================
# IMPORTANT: BUTTON WIDTH CONFIGURATION
# ============================================================================
# Inline keyboard button width is determined by MESSAGE TEXT WIDTH!
# DO NOT REMOVE PADDING - it makes buttons wide like BotFather!
# ============================================================================
# Zero-Width Joiner character - invisible but prevents Telegram from trimming spaces
# This character has ZERO width (invisible) but prevents space trimming
ZERO_WIDTH_JOINER = '\u200D'
# Target character count per line to make buttons VERY WIDE
# Higher value = wider buttons (BotFather uses ~45-50 chars)
# DO NOT DECREASE THIS VALUE - buttons will become narrow!
TARGET_WIDTH = 50 # Increased from 40 to make buttons WIDER
# Enable/disable padding globally (useful for testing)
# KEEP THIS TRUE - disabling makes buttons narrow!
ENABLE_BUTTON_PADDING = True
def _get_current_month_ro() -> str:
"""Get current month name in Romanian."""
months_ro = {
1: "Ianuarie", 2: "Februarie", 3: "Martie", 4: "Aprilie",
5: "Mai", 6: "Iunie", 7: "Iulie", 8: "August",
9: "Septembrie", 10: "Octombrie", 11: "Noiembrie", 12: "Decembrie"
}
now = datetime.now()
return f"{months_ro[now.month]} {now.year}"
def _pad_line_for_wide_buttons(text: str, target_width: int = TARGET_WIDTH) -> str:
"""
Pad a single line of text with invisible characters to make inline buttons wider.
⚠️ CRITICAL: DO NOT REMOVE THIS FUNCTION - it makes buttons wide!
The width of InlineKeyboardMarkup buttons is determined by the message text width.
By padding text with spaces + zero-width joiner, we force wider buttons.
How it works:
1. Calculate how many characters needed to reach target_width
2. Add spaces + Zero-Width Joiner (invisible character)
3. Result: wider message = wider buttons (like BotFather)
Args:
text: The text line to pad
target_width: Target character count (default 50 for VERY WIDE buttons)
Returns:
Padded text with invisible characters (user sees normal text, Telegram sees wider text)
"""
current_length = len(text)
if current_length >= target_width:
return text
# ⚠️ DO NOT REMOVE: Add spaces + zero-width joiner at the end
# This makes buttons WIDE without changing visible text!
padding_needed = target_width - current_length
padding = ' ' * padding_needed + ZERO_WIDTH_JOINER
return text + padding
def pad_message_for_wide_buttons(message: str, target_width: int = TARGET_WIDTH, force: bool = False) -> str:
"""
Pad all lines in a message to make inline keyboard buttons wider.
⚠️ CRITICAL: DO NOT REMOVE THIS FUNCTION - it makes buttons wide!
This is the MAIN function that applies padding to ALL messages with keyboards.
Why we need this:
- Telegram determines button width based on MESSAGE TEXT width
- Short messages = narrow buttons
- Wide messages (with invisible padding) = WIDE buttons like BotFather
Args:
message: Multi-line message text
target_width: Target character count per line (default 50)
force: Force padding even if ENABLE_BUTTON_PADDING is False
Returns:
Message with all lines padded (if enabled or forced)
"""
# ⚠️ DO NOT REMOVE: Check if padding is enabled
if not ENABLE_BUTTON_PADDING and not force:
return message
# ⚠️ DO NOT REMOVE: Apply padding to each line
lines = message.split('\n')
padded_lines = [_pad_line_for_wide_buttons(line, target_width) for line in lines]
return '\n'.join(padded_lines)
def format_response_with_company(
content: str,
company_name: Optional[str] = None,
apply_padding: bool = True
) -> str:
"""
Format a response with company name at the top (simplified format).
⚠️ IMPORTANT: Applies padding by default to make buttons WIDE!
Format:
Company Name
[Content]
Args:
content: The main content text
company_name: Company name to show at top (if None, just returns content)
apply_padding: Whether to apply invisible padding for wider buttons (default TRUE)
Returns:
Formatted response with company name header AND padding for wide buttons
"""
if company_name:
message = f"{company_name}\n\n{content}"
else:
message = content
# ⚠️ DO NOT REMOVE: Apply padding to make inline keyboard buttons WIDE!
# Without this, buttons become narrow like before
if apply_padding:
message = pad_message_for_wide_buttons(message)
return message
def get_menu_message(
company_name: Optional[str] = None,
company_cui: Optional[str] = None,
apply_padding: bool = True
) -> str:
"""
Get the menu message text with company details (simplified format).
⚠️ IMPORTANT: Applies padding by default to make menu buttons WIDE!
Format without labels - just values:
- Line 1: Company name
- Line 2: CUI
- Line 3: Accounting month
Args:
company_name: Active company name
company_cui: Company fiscal code (CUI)
apply_padding: Whether to apply invisible padding for wider buttons (default TRUE)
Returns:
Formatted message text for menu WITH padding for wide buttons
"""
if company_name:
# Simplified format: just values, no labels
message = f"{company_name}\n"
if company_cui:
message += f"{company_cui}\n"
message += f"{_get_current_month_ro()}"
else:
# No company selected - just prompt
message = "Selectează o companie pentru a continua"
# ⚠️ DO NOT REMOVE: Apply padding to make inline keyboard buttons WIDE!
# This makes buttons look like BotFather (wide, not narrow)
if apply_padding:
message = pad_message_for_wide_buttons(message)
return message
def create_main_menu(
company_name: Optional[str] = None,
company_cui: Optional[str] = None,
is_authenticated: bool = True,
cache_enabled: Optional[bool] = None
) -> InlineKeyboardMarkup:
"""
Create main menu keyboard (Level 1) with financial options.
Layout: Full-width buttons with company selection at top
Args:
company_name: Active company name, or None if no company selected
company_cui: Company fiscal code (CUI), or None
is_authenticated: Whether user is authenticated (affects Login/Logout button)
cache_enabled: Cache state for user (True=ON, False=OFF, None=unknown)
Returns:
InlineKeyboardMarkup with main menu buttons
"""
keyboard = []
# Only show financial menu if authenticated
if is_authenticated:
# Row 1: Company selection (full width, single line - InlineKeyboardButton doesn't support multiline)
if company_name:
# Short company name for button (CUI and month will be shown in message text)
# Truncate long names to fit in button
max_length = 35
display_name = company_name if len(company_name) <= max_length else company_name[:max_length-3] + "..."
keyboard.append([
InlineKeyboardButton(
f"{display_name}",
callback_data="menu:select_company"
)
])
else:
keyboard.append([
InlineKeyboardButton(
"Selectare Companie",
callback_data="menu:select_company"
)
])
# Rows 2-4: Financial options (2 buttons per row, made wide by message text padding)
keyboard.extend([
[
InlineKeyboardButton("Sold Companie", callback_data="menu:sold"),
InlineKeyboardButton("Trezorerie Casa", callback_data="menu:casa")
],
[
InlineKeyboardButton("Trezorerie Banca", callback_data="menu:banca"),
InlineKeyboardButton("Sold Clienti", callback_data="menu:clienti")
],
[
InlineKeyboardButton("Sold Furnizori", callback_data="menu:furnizori"),
InlineKeyboardButton("Evolutie Incasari", callback_data="menu:evolutie")
]
])
# Row 5: Cache options (2 buttons per row, only if authenticated)
if is_authenticated:
# Dynamic cache toggle button showing current state
if cache_enabled is None:
cache_button_text = "Toggle Cache"
elif cache_enabled:
cache_button_text = "Cache: ON"
else:
cache_button_text = "Cache: OFF"
keyboard.append([
InlineKeyboardButton(cache_button_text, callback_data="menu:togglecache"),
InlineKeyboardButton("Clear Cache", callback_data="menu:clearcache")
])
# Row 6: Help/Logout buttons (authenticated) or Login button (non-authenticated)
if is_authenticated:
keyboard.append([
InlineKeyboardButton("Help", callback_data="action:help"),
InlineKeyboardButton("Logout", callback_data="action:logout")
])
else:
keyboard.append([
InlineKeyboardButton("Login", callback_data="action:login")
])
return InlineKeyboardMarkup(keyboard)
def create_action_buttons(current_view: str, show_export: bool = True, show_back: bool = False, show_refresh: bool = True) -> InlineKeyboardMarkup:
"""
Create action buttons for responses (Refresh, Export, Back, Menu).
Layout (buttons made wide by message text padding):
[Refresh] [Export] (if show_refresh=True and show_export=True)
[Refresh] (if show_refresh=True and show_export=False)
[Înapoi] (if show_back=True, full width)
[Menu] (full width, always shown)
Args:
current_view: View identifier for refresh callback (e.g., "sold", "clienti")
show_export: Whether to show Export button
show_back: Whether to show Back button to list
show_refresh: Whether to show Refresh button
Returns:
InlineKeyboardMarkup with action buttons
"""
keyboard = []
# Row 1: Refresh and optionally Export (only if show_refresh is True)
if show_refresh:
if show_export:
keyboard.append([
InlineKeyboardButton("Refresh", callback_data=f"action:refresh:{current_view}"),
InlineKeyboardButton("Export", callback_data=f"action:export:{current_view}")
])
else:
keyboard.append([
InlineKeyboardButton("Refresh", callback_data=f"action:refresh:{current_view}")
])
# Row 2: Back to List (if show_back is True)
if show_back:
# Determine back callback based on current view
# ✅ FIX: Handle detail views (client_detail:name, supplier_detail:name)
if current_view.startswith("client_detail:"):
back_callback = "menu:clienti" # Back to client list
elif current_view.startswith("supplier_detail:"):
back_callback = "menu:furnizori" # Back to supplier list
elif current_view == "clienti":
back_callback = "clients_page:0" # Match handlers.py:1689
elif current_view == "furnizori":
back_callback = "suppliers_page:0" # Match handlers.py:1721
else:
back_callback = "action:menu" # Fallback to menu
keyboard.append([
InlineKeyboardButton("« Înapoi", callback_data=back_callback)
])
# Row 3: Back to Menu (full width)
keyboard.append([
InlineKeyboardButton("Meniu Principal", callback_data="action:menu")
])
return InlineKeyboardMarkup(keyboard)
def create_client_list_keyboard(clients: List[Dict], max_items: int = 10, page: int = 0) -> InlineKeyboardMarkup:
"""
Create client list keyboard (Level 2) with client buttons and pagination.
Layout: 1 column for clients, pagination controls, 2 columns for navigation
Args:
clients: List of client dicts with keys: id, name, balance
max_items: Maximum number of clients per page (default: 10)
page: Current page number (0-indexed)
Returns:
InlineKeyboardMarkup with client list buttons and pagination
"""
keyboard = []
# Sort clients alphabetically by name
sorted_clients = sorted(clients, key=lambda x: x.get('name', '').lower())
# Calculate pagination
total_clients = len(sorted_clients)
total_pages = (total_clients + max_items - 1) // max_items # Ceiling division
start_idx = page * max_items
end_idx = min(start_idx + max_items, total_clients)
# Display clients for current page
display_clients = sorted_clients[start_idx:end_idx]
# Add client buttons (1 per row)
for client in display_clients:
client_name = client.get('name', 'N/A')
balance = client.get('balance', 0)
# Format balance with thousands separator
balance_str = f"{balance:,.0f}" if balance else "0"
button_text = f"{client_name} - {balance_str} RON"
# Limit callback_data to 64 bytes (Telegram limit)
# Use only first 40 chars of name to stay within limit
safe_name = client_name[:40] if len(client_name) > 40 else client_name
keyboard.append([
InlineKeyboardButton(
button_text,
callback_data=f"details:client:{safe_name}:0" # name:page
)
])
# Pagination controls (only if more than one page)
if total_pages > 1:
nav_buttons = []
# Previous button
if page > 0:
nav_buttons.append(
InlineKeyboardButton("< Anterior", callback_data=f"clients_page:{page-1}")
)
# Page indicator (non-clickable)
nav_buttons.append(
InlineKeyboardButton(f"Pagina {page+1}/{total_pages}", callback_data="noop")
)
# Next button
if page < total_pages - 1:
nav_buttons.append(
InlineKeyboardButton("Următor >", callback_data=f"clients_page:{page+1}")
)
keyboard.append(nav_buttons)
# Navigation row: Back button only
keyboard.append([
InlineKeyboardButton("< Înapoi", callback_data="action:menu")
])
return InlineKeyboardMarkup(keyboard)
def create_supplier_list_keyboard(suppliers: List[Dict], max_items: int = 10, page: int = 0) -> InlineKeyboardMarkup:
"""
Create supplier list keyboard (Level 2) with supplier buttons and pagination.
Layout: 1 column for suppliers, pagination controls, 2 columns for navigation
Args:
suppliers: List of supplier dicts with keys: id, name, balance
max_items: Maximum number of suppliers per page (default: 10)
page: Current page number (0-indexed)
Returns:
InlineKeyboardMarkup with supplier list buttons and pagination
"""
keyboard = []
# Sort suppliers alphabetically by name
sorted_suppliers = sorted(suppliers, key=lambda x: x.get('name', '').lower())
# Calculate pagination
total_suppliers = len(sorted_suppliers)
total_pages = (total_suppliers + max_items - 1) // max_items # Ceiling division
start_idx = page * max_items
end_idx = min(start_idx + max_items, total_suppliers)
# Display suppliers for current page
display_suppliers = sorted_suppliers[start_idx:end_idx]
# Add supplier buttons (1 per row)
for supplier in display_suppliers:
supplier_name = supplier.get('name', 'N/A')
balance = supplier.get('balance', 0)
# Format balance with thousands separator
balance_str = f"{balance:,.0f}" if balance else "0"
button_text = f"{supplier_name} - {balance_str} RON"
# Limit callback_data to 64 bytes (Telegram limit)
# Use only first 40 chars of name to stay within limit
safe_name = supplier_name[:40] if len(supplier_name) > 40 else supplier_name
keyboard.append([
InlineKeyboardButton(
button_text,
callback_data=f"details:supplier:{safe_name}:0" # name:page
)
])
# Pagination controls (only if more than one page)
if total_pages > 1:
nav_buttons = []
# Previous button
if page > 0:
nav_buttons.append(
InlineKeyboardButton("< Anterior", callback_data=f"suppliers_page:{page-1}")
)
# Page indicator (non-clickable)
nav_buttons.append(
InlineKeyboardButton(f"Pagina {page+1}/{total_pages}", callback_data="noop")
)
# Next button
if page < total_pages - 1:
nav_buttons.append(
InlineKeyboardButton("Următor >", callback_data=f"suppliers_page:{page+1}")
)
keyboard.append(nav_buttons)
# Navigation row: Back button only
keyboard.append([
InlineKeyboardButton("< Înapoi", callback_data="action:menu")
])
return InlineKeyboardMarkup(keyboard)
def create_invoice_list_keyboard(
invoices: List[Dict],
partner_type: str,
partner_name: str,
max_items: int = 10,
page: int = 0
) -> InlineKeyboardMarkup:
"""
Create invoice list keyboard (Level 3) with invoice buttons and pagination.
Layout: 1 column for invoices, pagination controls, 2 columns for navigation
Args:
invoices: List of invoice dicts with keys: id, number, amount, status
partner_type: "CLIENTI" or "FURNIZORI"
partner_name: Client/supplier name (for back navigation)
max_items: Maximum number of invoices per page (default: 10)
page: Current page number (0-indexed)
Returns:
InlineKeyboardMarkup with invoice list buttons and pagination
"""
keyboard = []
# Limit partner_name to 30 chars for Telegram callback_data limit (64 bytes)
safe_partner_name = partner_name[:30] if len(partner_name) > 30 else partner_name
# Calculate pagination
total_invoices = len(invoices)
total_pages = (total_invoices + max_items - 1) // max_items # Ceiling division
start_idx = page * max_items
end_idx = min(start_idx + max_items, total_invoices)
# Display invoices for current page
display_invoices = invoices[start_idx:end_idx]
# Add invoice buttons (1 per row)
for invoice in display_invoices:
invoice_id = invoice.get('id', 0)
invoice_number = invoice.get('number', 'N/A')
amount = invoice.get('amount', 0)
status = invoice.get('status', 'unknown')
# Format amount with thousands separator
amount_str = f"{amount:,.0f}" if amount else "0"
# Status text indicator (no emojis)
status_text = "[NEPLATIT]" if status in ['unpaid', 'overdue'] else "[PLATIT]"
button_text = f"{status_text} {invoice_number} - {amount_str} RON"
keyboard.append([
InlineKeyboardButton(
button_text,
callback_data=f"invoice:{partner_type}:{invoice_id}"
)
])
# Pagination controls (only if more than one page)
if total_pages > 1:
nav_buttons = []
# Previous button
if page > 0:
nav_buttons.append(
InlineKeyboardButton("< Anterior", callback_data=f"invoices_page:{partner_type}:{safe_partner_name}:{page-1}")
)
# Page indicator (non-clickable)
nav_buttons.append(
InlineKeyboardButton(f"Pagina {page+1}/{total_pages}", callback_data="noop")
)
# Next button
if page < total_pages - 1:
nav_buttons.append(
InlineKeyboardButton("Următor >", callback_data=f"invoices_page:{partner_type}:{safe_partner_name}:{page+1}")
)
keyboard.append(nav_buttons)
# Navigation row: Back and Export (2 buttons per row)
back_target = "clienti" if partner_type == "CLIENTI" else "furnizori"
keyboard.append([
InlineKeyboardButton("< Înapoi", callback_data=f"nav:back:{back_target}"),
InlineKeyboardButton("Export", callback_data=f"action:export:{partner_type.lower()}")
])
return InlineKeyboardMarkup(keyboard)
def create_navigation_buttons(back_to: str) -> InlineKeyboardMarkup:
"""
Create simple navigation buttons (just Back button).
Args:
back_to: Target location identifier (e.g., "menu", "clienti", "furnizori")
Returns:
InlineKeyboardMarkup with navigation button
"""
keyboard = [
[
InlineKeyboardButton(
f"< Înapoi la {back_to}",
callback_data=f"nav:back:{back_to}"
)
]
]
return InlineKeyboardMarkup(keyboard)

View File

@@ -0,0 +1,316 @@
"""
Main entry point for ROA2WEB Telegram Bot
This bot provides access to the ROA2WEB ERP system through Telegram
using direct command handlers for financial data queries.
"""
import asyncio
import logging
import os
from pathlib import Path
from dotenv import load_dotenv
import uvicorn
from threading import Thread
# ============================================================================
# LOAD ENVIRONMENT VARIABLES FIRST - BEFORE ANY APP IMPORTS
# ============================================================================
# This ensures all modules can access environment variables at import time
env_path = Path(__file__).parent.parent / '.env'
load_dotenv(env_path)
# Telegram imports
from telegram.ext import (
Application,
CommandHandler,
CallbackQueryHandler,
MessageHandler,
filters
)
# Import database initialization
from backend.modules.telegram.db import (
init_database,
cleanup_expired_codes,
cleanup_expired_sessions,
cleanup_expired_email_codes
)
# Import bot handlers
from backend.modules.telegram.bot.handlers import (
start_command,
help_command,
clear_command,
companies_command,
unlink_command,
selectcompany_command,
dashboard_command,
sold_command,
facturi_command,
trezorerie_command,
# FAZA 3: New command handlers with button interface
menu_command,
trezorerie_casa_command,
trezorerie_banca_command,
clienti_command,
furnizori_command,
evolutie_command,
# FAZA 6: Cache management commands
clearcache_command,
togglecache_command,
# Text message handlers
handle_text_message,
# FAZA 4: Callback and error handlers
button_callback,
error_handler
)
# Import email authentication handler
from backend.modules.telegram.bot.email_handlers import email_login_handler
# Import internal API
from backend.modules.telegram.internal_api import internal_api
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Environment variables (already loaded above)
TELEGRAM_BOT_TOKEN = os.getenv('TELEGRAM_BOT_TOKEN')
BACKEND_URL = os.getenv('BACKEND_URL', 'http://localhost:8000')
INTERNAL_API_PORT = int(os.getenv('INTERNAL_API_PORT', '8002'))
# ============================================================================
# TELEGRAM BOT SETUP
# ============================================================================
def create_telegram_application() -> Application:
"""
Create and configure the Telegram bot application.
Returns:
Application: Configured Telegram application
"""
logger.info("Creating Telegram application...")
# Create application
application = Application.builder().token(TELEGRAM_BOT_TOKEN).build()
# Register email authentication conversation handler (must be before other handlers)
application.add_handler(email_login_handler)
# Register essential command handlers
application.add_handler(CommandHandler("start", start_command))
application.add_handler(CommandHandler("menu", menu_command))
application.add_handler(CommandHandler("help", help_command))
application.add_handler(CommandHandler("unlink", unlink_command))
# =========================================================================
# LEGACY COMMAND HANDLERS (kept for backwards compatibility, hidden from help)
# =========================================================================
# NOTE: These commands are redundant with the button interface.
# They're kept for users who already know them, but we push buttons in help.
# Consider removing completely if migration is successful.
application.add_handler(CommandHandler("clear", clear_command))
application.add_handler(CommandHandler("companies", companies_command))
application.add_handler(CommandHandler("selectcompany", selectcompany_command))
application.add_handler(CommandHandler("dashboard", dashboard_command))
application.add_handler(CommandHandler("sold", sold_command))
application.add_handler(CommandHandler("facturi", facturi_command))
application.add_handler(CommandHandler("trezorerie", trezorerie_command))
application.add_handler(CommandHandler("trezorerie_casa", trezorerie_casa_command))
application.add_handler(CommandHandler("trezorerie_banca", trezorerie_banca_command))
application.add_handler(CommandHandler("clienti", clienti_command))
application.add_handler(CommandHandler("furnizori", furnizori_command))
application.add_handler(CommandHandler("evolutie", evolutie_command))
# FAZA 6: Cache management commands
application.add_handler(CommandHandler("clearcache", clearcache_command))
application.add_handler(CommandHandler("togglecache", togglecache_command))
# Text message handler (for direct code input and future NLP)
# IMPORTANT: This must be registered BEFORE CallbackQueryHandler
# filters.TEXT & ~filters.COMMAND ensures we only process non-command text messages
application.add_handler(MessageHandler(
filters.TEXT & ~filters.COMMAND,
handle_text_message
))
# FAZA 4: Register callback query handler (for inline buttons)
application.add_handler(CallbackQueryHandler(button_callback))
# Register error handler
application.add_error_handler(error_handler)
logger.info("Telegram application configured with all handlers")
return application
# ============================================================================
# INTERNAL API SERVER
# ============================================================================
def run_internal_api():
"""
Run the internal FastAPI server in a separate thread.
This API handles communication from the backend (saving auth codes).
"""
logger.info(f"Starting internal API on port {INTERNAL_API_PORT}...")
uvicorn.run(
internal_api,
host="0.0.0.0",
port=INTERNAL_API_PORT,
log_level="info"
)
# ============================================================================
# STARTUP/SHUTDOWN
# ============================================================================
async def startup():
"""
Initialize the bot application on startup.
"""
logger.info("🚀 ROA2WEB Telegram Bot - Starting up...")
# Initialize database
try:
logger.info("Initializing SQLite database...")
await init_database()
logger.info("✅ Database initialized successfully")
except Exception as e:
logger.error(f"❌ Failed to initialize database: {e}")
raise
# Cleanup expired data
try:
logger.info("Cleaning up expired data...")
expired_codes = await cleanup_expired_codes()
expired_sessions = await cleanup_expired_sessions()
expired_email_codes = await cleanup_expired_email_codes()
logger.info(f"✅ Cleanup complete: {expired_codes} codes, {expired_sessions} sessions, {expired_email_codes} email codes removed")
except Exception as e:
logger.warning(f"⚠️ Cleanup failed (non-critical): {e}")
logger.info("✅ Startup complete")
async def shutdown():
"""
Clean up resources on shutdown.
"""
logger.info("👋 ROA2WEB Telegram Bot - Shutting down...")
logger.info("✅ Shutdown complete")
async def scheduled_cleanup():
"""
Background task to periodically clean up expired data.
Runs every hour to remove expired auth codes, sessions, and email codes.
"""
while True:
try:
await asyncio.sleep(3600) # Sleep for 1 hour
logger.info("🧹 Running scheduled cleanup...")
expired_codes = await cleanup_expired_codes()
expired_sessions = await cleanup_expired_sessions()
expired_email_codes = await cleanup_expired_email_codes()
logger.info(f"✅ Scheduled cleanup: {expired_codes} codes, {expired_sessions} sessions, {expired_email_codes} email codes removed")
except Exception as e:
logger.error(f"❌ Error in scheduled cleanup: {e}")
# ============================================================================
# MAIN APPLICATION
# ============================================================================
async def main():
"""
Main application entry point.
Runs both the Telegram bot and internal API server concurrently.
"""
try:
# Run startup
await startup()
# Create Telegram application
telegram_app = create_telegram_application()
# Start internal API in a separate thread
api_thread = Thread(target=run_internal_api, daemon=True)
api_thread.start()
logger.info(f"✅ Internal API started on port {INTERNAL_API_PORT}")
# Start scheduled cleanup task in background
cleanup_task = asyncio.create_task(scheduled_cleanup())
logger.info("✅ Scheduled cleanup task started")
# Initialize and start Telegram bot
logger.info("🤖 Starting Telegram bot polling...")
await telegram_app.initialize()
await telegram_app.start()
await telegram_app.updater.start_polling(drop_pending_updates=True)
logger.info("✅ Telegram bot is now running and polling for updates")
logger.info(f"📱 Bot ready to receive messages at @{(await telegram_app.bot.get_me()).username}")
logger.info("🎯 Bot is operational with direct command handlers!")
# Keep running until interrupted
await asyncio.Event().wait()
except KeyboardInterrupt:
logger.info("⚠️ Received interrupt signal")
except Exception as e:
logger.error(f"❌ Fatal error: {e}", exc_info=True)
raise
finally:
# Stop Telegram bot gracefully
try:
if 'telegram_app' in locals():
logger.info("Stopping Telegram bot...")
await telegram_app.updater.stop()
await telegram_app.stop()
await telegram_app.shutdown()
logger.info("✅ Telegram bot stopped")
except Exception as e:
logger.error(f"Error stopping Telegram bot: {e}")
await shutdown()
# ============================================================================
# ENTRY POINT
# ============================================================================
if __name__ == "__main__":
# Check required environment variables
if not os.getenv('TELEGRAM_BOT_TOKEN'):
logger.error("❌ TELEGRAM_BOT_TOKEN is required")
logger.error("Please set it in .env file")
exit(1)
# Display startup banner
logger.info("=" * 60)
logger.info(" ROA2WEB TELEGRAM BOT")
logger.info(" Financial ERP Assistant with Direct Commands")
logger.info("=" * 60)
# Run the main application
try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("👋 Application stopped by user")
except Exception as e:
logger.error(f"❌ Application failed: {e}", exc_info=True)
exit(1)

View File

@@ -0,0 +1,86 @@
"""
Database module for Telegram Bot
Provides SQLite database operations for:
- User management and Oracle account linking
- Authentication code management
- Conversation session management
"""
from .database import (
init_database,
get_db_connection,
cleanup_expired_codes,
cleanup_expired_sessions,
cleanup_expired_email_codes,
get_database_stats,
DB_PATH,
)
from .operations import (
# User operations
create_or_update_user,
get_user,
link_user_to_oracle,
update_user_tokens,
update_user_last_active,
is_user_linked,
is_user_authenticated,
# Auth code operations
create_auth_code,
get_auth_code,
verify_and_use_auth_code,
get_pending_codes_for_user,
# Email auth code operations
get_pending_email_code,
create_email_auth_code,
get_email_auth_code,
increment_failed_attempts,
mark_email_code_used,
delete_user_email_codes,
# Session operations
create_session,
get_session,
get_user_active_session,
update_session_state,
delete_session,
delete_user_sessions,
)
__all__ = [
# Database setup
'init_database',
'get_db_connection',
'cleanup_expired_codes',
'cleanup_expired_sessions',
'cleanup_expired_email_codes',
'get_database_stats',
'DB_PATH',
# User operations
'create_or_update_user',
'get_user',
'link_user_to_oracle',
'update_user_tokens',
'update_user_last_active',
'is_user_linked',
'is_user_authenticated',
# Auth code operations
'create_auth_code',
'get_auth_code',
'verify_and_use_auth_code',
'get_pending_codes_for_user',
# Email auth code operations
'get_pending_email_code',
'create_email_auth_code',
'get_email_auth_code',
'increment_failed_attempts',
'mark_email_code_used',
'delete_user_email_codes',
# Session operations
'create_session',
'get_session',
'get_user_active_session',
'update_session_state',
'delete_session',
'delete_user_sessions',
]

View File

@@ -0,0 +1,310 @@
"""
SQLite Database Setup for Telegram Bot
This module handles database connection, initialization, and schema creation.
Uses aiosqlite for async SQLite operations.
"""
import aiosqlite
import logging
from pathlib import Path
from datetime import datetime, timedelta
from typing import Optional
logger = logging.getLogger(__name__)
# Database file location
DB_DIR = Path(__file__).parent.parent.parent / "data"
DB_PATH = DB_DIR / "telegram_bot.db"
async def get_db_connection() -> aiosqlite.Connection:
"""
Get a database connection.
Returns:
aiosqlite.Connection: Database connection
"""
conn = await aiosqlite.connect(DB_PATH)
conn.row_factory = aiosqlite.Row # Enable column access by name
return conn
async def init_database() -> None:
"""
Initialize the database and create all tables.
Safe to call multiple times - only creates tables if they don't exist.
"""
try:
# Ensure data directory exists
DB_DIR.mkdir(parents=True, exist_ok=True)
logger.info(f"Database directory: {DB_DIR}")
async with aiosqlite.connect(DB_PATH) as db:
# Enable foreign keys
await db.execute("PRAGMA foreign_keys = ON")
# Create telegram_users table
await db.execute("""
CREATE TABLE IF NOT EXISTS telegram_users (
telegram_user_id INTEGER PRIMARY KEY,
username TEXT,
first_name TEXT NOT NULL,
last_name TEXT,
oracle_username TEXT,
jwt_token TEXT,
jwt_refresh_token TEXT,
token_expires_at TIMESTAMP,
linked_at TIMESTAMP,
last_active_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
is_active BOOLEAN DEFAULT 1
)
""")
# Create telegram_auth_codes table
await db.execute("""
CREATE TABLE IF NOT EXISTS telegram_auth_codes (
code TEXT PRIMARY KEY,
telegram_user_id INTEGER,
oracle_username TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP NOT NULL,
used BOOLEAN DEFAULT 0,
used_at TIMESTAMP,
FOREIGN KEY (telegram_user_id) REFERENCES telegram_users(telegram_user_id)
)
""")
# Create telegram_sessions table
await db.execute("""
CREATE TABLE IF NOT EXISTS telegram_sessions (
session_id TEXT PRIMARY KEY,
telegram_user_id INTEGER NOT NULL,
conversation_state TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP NOT NULL,
FOREIGN KEY (telegram_user_id) REFERENCES telegram_users(telegram_user_id)
)
""")
# Create email_auth_codes table (email-based authentication)
await db.execute("""
CREATE TABLE IF NOT EXISTS email_auth_codes (
code TEXT PRIMARY KEY,
email TEXT NOT NULL,
oracle_username TEXT NOT NULL,
telegram_user_id INTEGER NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP NOT NULL,
used INTEGER DEFAULT 0,
used_at TIMESTAMP,
failed_attempts INTEGER DEFAULT 0,
FOREIGN KEY (telegram_user_id) REFERENCES telegram_users(telegram_user_id)
)
""")
# Create indexes for better query performance
await db.execute("""
CREATE INDEX IF NOT EXISTS idx_auth_codes_telegram_user
ON telegram_auth_codes(telegram_user_id)
""")
await db.execute("""
CREATE INDEX IF NOT EXISTS idx_auth_codes_expires
ON telegram_auth_codes(expires_at)
""")
await db.execute("""
CREATE INDEX IF NOT EXISTS idx_sessions_telegram_user
ON telegram_sessions(telegram_user_id)
""")
await db.execute("""
CREATE INDEX IF NOT EXISTS idx_sessions_expires
ON telegram_sessions(expires_at)
""")
# Create indexes for email_auth_codes table
await db.execute("""
CREATE INDEX IF NOT EXISTS idx_email_auth_email
ON email_auth_codes(email)
""")
await db.execute("""
CREATE INDEX IF NOT EXISTS idx_email_auth_telegram_user
ON email_auth_codes(telegram_user_id)
""")
await db.execute("""
CREATE INDEX IF NOT EXISTS idx_email_auth_expires
ON email_auth_codes(expires_at)
""")
await db.commit()
logger.info("Database initialized successfully")
# Log table info
cursor = await db.execute("""
SELECT name FROM sqlite_master
WHERE type='table'
ORDER BY name
""")
tables = await cursor.fetchall()
logger.info(f"Existing tables: {[t[0] for t in tables]}")
except Exception as e:
logger.error(f"Failed to initialize database: {e}")
raise
async def cleanup_expired_codes() -> int:
"""
Delete expired authentication codes from the database.
This should be called periodically (e.g., every hour).
Returns:
int: Number of expired codes deleted
"""
try:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute("""
DELETE FROM telegram_auth_codes
WHERE expires_at < ?
""", (datetime.now(),))
await db.commit()
deleted = cursor.rowcount
if deleted > 0:
logger.info(f"Cleaned up {deleted} expired auth codes")
return deleted
except Exception as e:
logger.error(f"Failed to cleanup expired codes: {e}")
return 0
async def cleanup_expired_sessions() -> int:
"""
Delete expired sessions from the database.
This should be called periodically (e.g., daily).
Returns:
int: Number of expired sessions deleted
"""
try:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute("""
DELETE FROM telegram_sessions
WHERE expires_at < ?
""", (datetime.now(),))
await db.commit()
deleted = cursor.rowcount
if deleted > 0:
logger.info(f"Cleaned up {deleted} expired sessions")
return deleted
except Exception as e:
logger.error(f"Failed to cleanup expired sessions: {e}")
return 0
async def cleanup_expired_email_codes() -> int:
"""
Delete expired and old used email codes from the database.
This should be called periodically (e.g., hourly).
Returns:
int: Number of email codes deleted
"""
try:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
# Delete expired codes or used codes older than 1 day
cursor = await db.execute("""
DELETE FROM email_auth_codes
WHERE expires_at < ?
OR (used = 1 AND used_at < ?)
""", (
datetime.now(),
datetime.now() - timedelta(days=1)
))
await db.commit()
deleted = cursor.rowcount
if deleted > 0:
logger.info(f"Cleaned up {deleted} expired/old email auth codes")
return deleted
except Exception as e:
logger.error(f"Failed to cleanup email codes: {e}")
return 0
async def get_database_stats() -> dict:
"""
Get database statistics for monitoring.
Returns:
dict: Database statistics
"""
try:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
stats = {}
# Count users
cursor = await db.execute("SELECT COUNT(*) FROM telegram_users")
stats['total_users'] = (await cursor.fetchone())[0]
cursor = await db.execute(
"SELECT COUNT(*) FROM telegram_users WHERE is_active = 1"
)
stats['active_users'] = (await cursor.fetchone())[0]
# Count pending codes
cursor = await db.execute("""
SELECT COUNT(*) FROM telegram_auth_codes
WHERE used = 0 AND expires_at > ?
""", (datetime.now(),))
stats['pending_codes'] = (await cursor.fetchone())[0]
# Count active sessions
cursor = await db.execute("""
SELECT COUNT(*) FROM telegram_sessions
WHERE expires_at > ?
""", (datetime.now(),))
stats['active_sessions'] = (await cursor.fetchone())[0]
# Database file size
if DB_PATH.exists():
stats['db_size_mb'] = DB_PATH.stat().st_size / (1024 * 1024)
else:
stats['db_size_mb'] = 0
return stats
except Exception as e:
logger.error(f"Failed to get database stats: {e}")
return {}
# Export main functions
__all__ = [
'get_db_connection',
'init_database',
'cleanup_expired_codes',
'cleanup_expired_sessions',
'cleanup_expired_email_codes',
'get_database_stats',
'DB_PATH',
]

View File

@@ -0,0 +1,813 @@
"""
Database Operations for Telegram Bot
This module provides CRUD operations for:
- telegram_users: Telegram user management and Oracle account linking
- telegram_auth_codes: Authentication code management
- telegram_sessions: Conversation session management
"""
import logging
import uuid
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List
import aiosqlite
from .database import DB_PATH
logger = logging.getLogger(__name__)
# ============================================================================
# TELEGRAM USERS OPERATIONS
# ============================================================================
async def create_or_update_user(
telegram_user_id: int,
username: Optional[str],
first_name: str,
last_name: Optional[str]
) -> bool:
"""
Create or update a Telegram user record.
Args:
telegram_user_id: Telegram user ID
username: Telegram username (without @)
first_name: User's first name
last_name: User's last name
Returns:
bool: True if successful
"""
try:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
await db.execute("""
INSERT INTO telegram_users (
telegram_user_id, username, first_name, last_name, last_active_at
)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(telegram_user_id) DO UPDATE SET
username = excluded.username,
first_name = excluded.first_name,
last_name = excluded.last_name,
last_active_at = excluded.last_active_at
""", (telegram_user_id, username, first_name, last_name, datetime.now()))
await db.commit()
logger.info(f"User {telegram_user_id} created/updated")
return True
except Exception as e:
logger.error(f"Failed to create/update user {telegram_user_id}: {e}")
return False
async def get_user(telegram_user_id: int) -> Optional[Dict[str, Any]]:
"""
Get user information by Telegram user ID.
Args:
telegram_user_id: Telegram user ID
Returns:
Optional[Dict]: User data or None if not found
"""
try:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute("""
SELECT * FROM telegram_users
WHERE telegram_user_id = ?
""", (telegram_user_id,))
row = await cursor.fetchone()
if row:
return dict(row)
return None
except Exception as e:
logger.error(f"Failed to get user {telegram_user_id}: {e}")
return None
async def link_user_to_oracle(
telegram_user_id: int,
oracle_username: str,
jwt_token: str,
jwt_refresh_token: str,
token_expires_at: datetime
) -> bool:
"""
Link a Telegram user to an Oracle account and save JWT tokens.
Args:
telegram_user_id: Telegram user ID
oracle_username: Oracle username
jwt_token: JWT access token
jwt_refresh_token: JWT refresh token
token_expires_at: Token expiration timestamp
Returns:
bool: True if successful
"""
try:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
await db.execute("""
UPDATE telegram_users
SET oracle_username = ?,
jwt_token = ?,
jwt_refresh_token = ?,
token_expires_at = ?,
linked_at = ?,
is_active = 1
WHERE telegram_user_id = ?
""", (
oracle_username,
jwt_token,
jwt_refresh_token,
token_expires_at,
datetime.now(),
telegram_user_id
))
await db.commit()
logger.info(f"User {telegram_user_id} linked to Oracle user {oracle_username}")
return True
except Exception as e:
logger.error(f"Failed to link user {telegram_user_id}: {e}")
return False
async def update_user_tokens(
telegram_user_id: int,
jwt_token: str,
jwt_refresh_token: str,
token_expires_at: datetime
) -> bool:
"""
Update JWT tokens for a user.
Args:
telegram_user_id: Telegram user ID
jwt_token: New JWT access token
jwt_refresh_token: New JWT refresh token
token_expires_at: New token expiration timestamp
Returns:
bool: True if successful
"""
try:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
await db.execute("""
UPDATE telegram_users
SET jwt_token = ?,
jwt_refresh_token = ?,
token_expires_at = ?
WHERE telegram_user_id = ?
""", (jwt_token, jwt_refresh_token, token_expires_at, telegram_user_id))
await db.commit()
logger.info(f"Tokens updated for user {telegram_user_id}")
return True
except Exception as e:
logger.error(f"Failed to update tokens for user {telegram_user_id}: {e}")
return False
async def update_user_last_active(telegram_user_id: int) -> bool:
"""
Update the last active timestamp for a user.
Args:
telegram_user_id: Telegram user ID
Returns:
bool: True if successful
"""
try:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
await db.execute("""
UPDATE telegram_users
SET last_active_at = ?
WHERE telegram_user_id = ?
""", (datetime.now(), telegram_user_id))
await db.commit()
return True
except Exception as e:
logger.error(f"Failed to update last active for user {telegram_user_id}: {e}")
return False
async def is_user_linked(telegram_user_id: int) -> bool:
"""
Check if a user is linked to an Oracle account.
Args:
telegram_user_id: Telegram user ID
Returns:
bool: True if user is linked
"""
try:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute("""
SELECT oracle_username FROM telegram_users
WHERE telegram_user_id = ? AND oracle_username IS NOT NULL
""", (telegram_user_id,))
row = await cursor.fetchone()
return row is not None
except Exception as e:
logger.error(f"Failed to check if user {telegram_user_id} is linked: {e}")
return False
async def is_user_authenticated(telegram_user_id: int) -> bool:
"""
Check if a user is authenticated (linked and has valid token).
Args:
telegram_user_id: Telegram user ID
Returns:
bool: True if user is authenticated
"""
try:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute("""
SELECT oracle_username, jwt_token, token_expires_at
FROM telegram_users
WHERE telegram_user_id = ?
AND oracle_username IS NOT NULL
AND jwt_token IS NOT NULL
""", (telegram_user_id,))
row = await cursor.fetchone()
if not row:
return False
# Check if token is expired (with some buffer)
if row[2]: # token_expires_at
expires_at = datetime.fromisoformat(row[2])
# Token should have at least 5 minutes remaining
if expires_at < datetime.now() + timedelta(minutes=5):
return False
return True
except Exception as e:
logger.error(f"Failed to check if user {telegram_user_id} is authenticated: {e}")
return False
# ============================================================================
# AUTHENTICATION CODES OPERATIONS
# ============================================================================
async def create_auth_code(
code: str,
telegram_user_id: int,
oracle_username: str,
expires_in_minutes: int = 5
) -> bool:
"""
Create a new authentication code for linking.
Args:
code: 8-character authentication code
telegram_user_id: Telegram user ID
oracle_username: Oracle username to link
expires_in_minutes: Code expiration time in minutes (default: 5)
Returns:
bool: True if successful
"""
try:
expires_at = datetime.now() + timedelta(minutes=expires_in_minutes)
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
await db.execute("""
INSERT INTO telegram_auth_codes (
code, telegram_user_id, oracle_username, expires_at
)
VALUES (?, ?, ?, ?)
""", (code, telegram_user_id, oracle_username, expires_at))
await db.commit()
logger.info(f"Auth code created for user {telegram_user_id}")
return True
except Exception as e:
logger.error(f"Failed to create auth code: {e}")
return False
async def get_auth_code(code: str) -> Optional[Dict[str, Any]]:
"""
Get authentication code information.
Args:
code: 8-character authentication code
Returns:
Optional[Dict]: Code data or None if not found
"""
try:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute("""
SELECT * FROM telegram_auth_codes
WHERE code = ?
""", (code,))
row = await cursor.fetchone()
if row:
return dict(row)
return None
except Exception as e:
logger.error(f"Failed to get auth code: {e}")
return None
async def verify_and_use_auth_code(code: str) -> Optional[Dict[str, Any]]:
"""
Verify an authentication code and mark it as used.
Args:
code: 8-character authentication code
Returns:
Optional[Dict]: Code data if valid, None if invalid/expired
"""
try:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
# Check if code exists, is not used, and not expired
cursor = await db.execute("""
SELECT * FROM telegram_auth_codes
WHERE code = ?
AND used = 0
AND expires_at > ?
""", (code, datetime.now()))
row = await cursor.fetchone()
if not row:
logger.warning(f"Invalid or expired code: {code}")
return None
# Mark code as used
await db.execute("""
UPDATE telegram_auth_codes
SET used = 1, used_at = ?
WHERE code = ?
""", (datetime.now(), code))
await db.commit()
logger.info(f"Auth code {code} verified and used")
return dict(row)
except Exception as e:
logger.error(f"Failed to verify auth code: {e}")
return None
async def get_pending_codes_for_user(telegram_user_id: int) -> List[Dict[str, Any]]:
"""
Get all pending (unused, non-expired) codes for a user.
Args:
telegram_user_id: Telegram user ID
Returns:
List[Dict]: List of pending codes
"""
try:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute("""
SELECT * FROM telegram_auth_codes
WHERE telegram_user_id = ?
AND used = 0
AND expires_at > ?
ORDER BY created_at DESC
""", (telegram_user_id, datetime.now()))
rows = await cursor.fetchall()
return [dict(row) for row in rows]
except Exception as e:
logger.error(f"Failed to get pending codes for user {telegram_user_id}: {e}")
return []
# ============================================================================
# EMAIL AUTHENTICATION CODES OPERATIONS
# ============================================================================
async def get_pending_email_code(
telegram_user_id: int
) -> Optional[Dict]:
"""
Get pending (non-expired, non-used) email code for user
Returns:
Code data dict or None if no pending code
"""
try:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute("""
SELECT code, email, oracle_username, expires_at, failed_attempts
FROM email_auth_codes
WHERE telegram_user_id = ?
AND used = 0
AND expires_at > ?
ORDER BY created_at DESC
LIMIT 1
""", (telegram_user_id, datetime.now()))
row = await cursor.fetchone()
if row:
return {
'code': row[0],
'email': row[1],
'oracle_username': row[2],
'expires_at': datetime.fromisoformat(row[3]),
'failed_attempts': row[4]
}
return None
except Exception as e:
logger.error(f"Failed to get pending email code: {e}")
return None
async def create_email_auth_code(
code: str,
email: str,
username: str,
telegram_user_id: int,
expiry_minutes: int = 5
) -> bool:
"""
Create new email authentication code
NOTE: Caller should check for existing pending codes first
"""
expires_at = datetime.now() + timedelta(minutes=expiry_minutes)
try:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
await db.execute("""
INSERT INTO email_auth_codes
(code, email, oracle_username, telegram_user_id, expires_at)
VALUES (?, ?, ?, ?, ?)
""", (code, email, username, telegram_user_id, expires_at))
await db.commit()
logger.info(
f"Email auth code created for user {telegram_user_id}, "
f"expires at {expires_at.isoformat()}"
)
return True
except Exception as e:
logger.error(f"Error creating email auth code: {e}", exc_info=True)
return False
async def get_email_auth_code(code: str) -> Optional[Dict]:
"""Get email auth code details"""
try:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute("""
SELECT code, email, oracle_username, telegram_user_id,
created_at, expires_at, used, used_at, failed_attempts
FROM email_auth_codes
WHERE code = ?
""", (code,))
row = await cursor.fetchone()
if not row:
return None
return {
'code': row[0],
'email': row[1],
'oracle_username': row[2],
'telegram_user_id': row[3],
'created_at': datetime.fromisoformat(row[4]),
'expires_at': datetime.fromisoformat(row[5]),
'used': bool(row[6]),
'used_at': datetime.fromisoformat(row[7]) if row[7] else None,
'failed_attempts': row[8]
}
except Exception as e:
logger.error(f"Failed to get email auth code: {e}")
return None
async def increment_failed_attempts(code: str) -> bool:
"""Increment failed validation attempts for code"""
try:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
await db.execute("""
UPDATE email_auth_codes
SET failed_attempts = failed_attempts + 1
WHERE code = ?
""", (code,))
await db.commit()
return True
except Exception as e:
logger.error(f"Error incrementing failed attempts: {e}")
return False
async def mark_email_code_used(code: str) -> bool:
"""Mark email code as used"""
try:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
await db.execute("""
UPDATE email_auth_codes
SET used = 1, used_at = ?
WHERE code = ?
""", (datetime.now(), code))
await db.commit()
logger.info(f"Email auth code marked as used: {code}")
return True
except Exception as e:
logger.error(f"Error marking email code as used: {e}")
return False
async def delete_user_email_codes(telegram_user_id: int) -> int:
"""Delete all email codes for user (cleanup)"""
try:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute("""
DELETE FROM email_auth_codes
WHERE telegram_user_id = ?
""", (telegram_user_id,))
await db.commit()
deleted = cursor.rowcount
logger.info(f"Deleted {deleted} email codes for user {telegram_user_id}")
return deleted
except Exception as e:
logger.error(f"Error deleting user email codes: {e}")
return 0
# ============================================================================
# SESSION OPERATIONS
# ============================================================================
async def create_session(
telegram_user_id: int,
conversation_state: Optional[str] = None,
expires_in_hours: int = 24
) -> Optional[str]:
"""
Create a new conversation session.
Args:
telegram_user_id: Telegram user ID
conversation_state: JSON string of conversation state
expires_in_hours: Session expiration time in hours (default: 24)
Returns:
Optional[str]: Session ID if successful, None otherwise
"""
try:
session_id = str(uuid.uuid4())
expires_at = datetime.now() + timedelta(hours=expires_in_hours)
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
await db.execute("""
INSERT INTO telegram_sessions (
session_id, telegram_user_id, conversation_state, expires_at
)
VALUES (?, ?, ?, ?)
""", (session_id, telegram_user_id, conversation_state, expires_at))
await db.commit()
logger.info(f"Session {session_id} created for user {telegram_user_id}")
return session_id
except Exception as e:
logger.error(f"Failed to create session: {e}")
return None
async def get_session(session_id: str) -> Optional[Dict[str, Any]]:
"""
Get session information.
Args:
session_id: Session UUID
Returns:
Optional[Dict]: Session data or None if not found/expired
"""
try:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute("""
SELECT * FROM telegram_sessions
WHERE session_id = ?
AND expires_at > ?
""", (session_id, datetime.now()))
row = await cursor.fetchone()
if row:
return dict(row)
return None
except Exception as e:
logger.error(f"Failed to get session {session_id}: {e}")
return None
async def get_user_active_session(telegram_user_id: int) -> Optional[Dict[str, Any]]:
"""
Get the most recent active session for a user.
Args:
telegram_user_id: Telegram user ID
Returns:
Optional[Dict]: Session data or None if no active session
"""
try:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute("""
SELECT * FROM telegram_sessions
WHERE telegram_user_id = ?
AND expires_at > ?
ORDER BY updated_at DESC
LIMIT 1
""", (telegram_user_id, datetime.now()))
row = await cursor.fetchone()
if row:
return dict(row)
return None
except Exception as e:
logger.error(f"Failed to get active session for user {telegram_user_id}: {e}")
return None
async def update_session_state(
session_id: str,
conversation_state: str
) -> bool:
"""
Update the conversation state for a session.
Args:
session_id: Session UUID
conversation_state: JSON string of conversation state
Returns:
bool: True if successful
"""
try:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
await db.execute("""
UPDATE telegram_sessions
SET conversation_state = ?,
updated_at = ?
WHERE session_id = ?
""", (conversation_state, datetime.now(), session_id))
await db.commit()
logger.info(f"Session {session_id} state updated")
return True
except Exception as e:
logger.error(f"Failed to update session {session_id}: {e}")
return False
async def delete_session(session_id: str) -> bool:
"""
Delete a session.
Args:
session_id: Session UUID
Returns:
bool: True if successful
"""
try:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
await db.execute("""
DELETE FROM telegram_sessions
WHERE session_id = ?
""", (session_id,))
await db.commit()
logger.info(f"Session {session_id} deleted")
return True
except Exception as e:
logger.error(f"Failed to delete session {session_id}: {e}")
return False
async def delete_user_sessions(telegram_user_id: int) -> bool:
"""
Delete all sessions for a user.
Args:
telegram_user_id: Telegram user ID
Returns:
bool: True if successful
"""
try:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute("""
DELETE FROM telegram_sessions
WHERE telegram_user_id = ?
""", (telegram_user_id,))
await db.commit()
deleted = cursor.rowcount
logger.info(f"Deleted {deleted} sessions for user {telegram_user_id}")
return True
except Exception as e:
logger.error(f"Failed to delete sessions for user {telegram_user_id}: {e}")
return False
# Export all functions
__all__ = [
# User operations
'create_or_update_user',
'get_user',
'link_user_to_oracle',
'update_user_tokens',
'update_user_last_active',
'is_user_linked',
'is_user_authenticated',
# Auth code operations
'create_auth_code',
'get_auth_code',
'verify_and_use_auth_code',
'get_pending_codes_for_user',
# Email auth code operations
'get_pending_email_code',
'create_email_auth_code',
'get_email_auth_code',
'increment_failed_attempts',
'mark_email_code_used',
'delete_user_email_codes',
# Session operations
'create_session',
'get_session',
'get_user_active_session',
'update_session_state',
'delete_session',
'delete_user_sessions',
]

View File

@@ -0,0 +1,32 @@
"""Telegram module router factory."""
from fastapi import APIRouter
def create_telegram_router() -> APIRouter:
"""
Create and configure Telegram module router.
Includes all Telegram bot internal API endpoints:
- /auth/verify-user - Verify Telegram user authentication
- /auth/generate-code - Generate auth code for linking
- /auth/verify-code - Verify auth code
- /stats - Bot database statistics
Returns:
APIRouter: Configured router for Telegram module
"""
router = APIRouter()
# Import routers here to avoid circular imports
from .auth_codes import router as auth_codes_router
from .internal_api import internal_api as internal_api_router
# Include all sub-routers (no prefix - already prefixed in main.py with /api/telegram)
# Auth codes router provides /auth/* endpoints
router.include_router(auth_codes_router, tags=["telegram-auth"])
# Internal API router provides additional endpoints like /stats
router.include_router(internal_api_router, tags=["telegram-internal"])
return router

View File

@@ -0,0 +1,840 @@
"""
API Router pentru Telegram Bot Integration
Furnizează endpoint-uri pentru autentificare, linking și export rapoarte pentru Telegram bot
"""
from fastapi import APIRouter, Depends, HTTPException, Request
from typing import List, Optional, Dict, Any
# import sys # Removed - no longer needed
import os
import secrets
import string
import httpx
from datetime import datetime, timedelta
from pydantic import BaseModel, Field
from shared.auth.dependencies import get_current_user
from shared.auth.models import CurrentUser
from shared.auth.jwt_handler import jwt_handler
from shared.database.oracle_pool import oracle_pool
# Telegram bot internal API URL (running on same server)
TELEGRAM_BOT_INTERNAL_API = os.getenv("TELEGRAM_BOT_INTERNAL_API", "http://localhost:8002")
router = APIRouter(redirect_slashes=False)
# ==================== Schemas ====================
class GenerateCodeRequest(BaseModel):
"""Request pentru generarea unui cod de linking"""
telegram_user_id: int = Field(description="ID-ul utilizatorului Telegram")
telegram_username: Optional[str] = Field(default=None, description="Username-ul Telegram")
telegram_first_name: Optional[str] = Field(default=None, description="Prenumele utilizatorului")
telegram_last_name: Optional[str] = Field(default=None, description="Numele utilizatorului")
class GenerateCodeResponse(BaseModel):
"""Response pentru generarea unui cod de linking"""
linking_code: str = Field(description="Codul de linking generat (8 caractere)")
expires_at: datetime = Field(description="Data și ora expirării codului")
expires_in_minutes: int = Field(description="Minutele până la expirare")
class VerifyUserRequest(BaseModel):
"""
Request pentru verificarea utilizatorului în Oracle
Suportă 2 flow-uri:
1. Auto-linking (recomandat): doar linking_code și oracle_username
- Bot-ul verifică codul în SQLite, extrage oracle_username
- Backend face lookup în Oracle fără verificare parolă
- Codul valid este proof-of-authorization
2. Full verification (opțional): username, password, linking_code
- Verificare completă cu parolă în Oracle
"""
linking_code: str = Field(description="Codul de linking de la /generate-code")
oracle_username: Optional[str] = Field(default=None, description="Username Oracle (pentru auto-linking)")
username: Optional[str] = Field(default=None, description="Username pentru verificare completă")
password: Optional[str] = Field(default=None, description="Parolă pentru verificare completă")
class VerifyUserResponse(BaseModel):
"""Response pentru verificarea utilizatorului"""
success: bool = Field(description="True dacă verificarea a avut succes")
access_token: Optional[str] = Field(default=None, description="JWT access token")
refresh_token: Optional[str] = Field(default=None, description="JWT refresh token")
user: Optional[Dict[str, Any]] = Field(default=None, description="Detalii utilizator")
message: str = Field(description="Mesaj de status")
class RefreshTokenRequest(BaseModel):
"""Request pentru refresh JWT token"""
refresh_token: str = Field(description="Refresh token-ul obținut la autentificare")
class RefreshTokenResponse(BaseModel):
"""Response pentru refresh token"""
access_token: str = Field(description="Noul JWT access token")
expires_in: int = Field(description="Timpul de expirare în secunde")
token_type: str = Field(default="bearer", description="Tipul token-ului")
class ExportReportRequest(BaseModel):
"""Request pentru exportul unui raport"""
company_id: int = Field(description="ID-ul firmei")
report_type: str = Field(description="Tipul raportului (invoices, payments, dashboard)")
format: str = Field(default="excel", description="Formatul exportului (excel, pdf, csv)")
filters: Optional[Dict[str, Any]] = Field(default=None, description="Filtre pentru raport")
class ExportReportResponse(BaseModel):
"""Response pentru exportul raportului"""
success: bool = Field(description="True dacă exportul a avut succes")
file_url: Optional[str] = Field(default=None, description="URL-ul fișierului generat")
file_name: Optional[str] = Field(default=None, description="Numele fișierului generat")
file_size_bytes: Optional[int] = Field(default=None, description="Mărimea fișierului în bytes")
message: str = Field(description="Mesaj de status")
class VerifyEmailRequest(BaseModel):
"""Request pentru verificarea email-ului în Oracle"""
email: str = Field(description="Adresa de email Oracle")
class VerifyEmailResponse(BaseModel):
"""Response pentru verificarea email-ului"""
success: bool = Field(description="True dacă email-ul există și este activ")
username: Optional[str] = Field(default=None, description="Username-ul Oracle asociat")
message: str = Field(description="Mesaj de status")
class TelegramEmailLoginRequest(BaseModel):
"""Request pentru autentificare prin email + parolă"""
email: str = Field(description="Adresa de email Oracle")
password: str = Field(description="Parola Oracle")
telegram_user_id: int = Field(description="ID-ul utilizatorului Telegram")
session_token: str = Field(description="Token de sesiune pentru preveni spoofing")
class TelegramEmailLoginResponse(BaseModel):
"""Response pentru autentificare prin email + parolă"""
success: bool = Field(description="True dacă autentificarea a avut succes")
access_token: Optional[str] = Field(default=None, description="JWT access token")
refresh_token: Optional[str] = Field(default=None, description="JWT refresh token")
token_type: str = Field(default="bearer", description="Tipul token-ului")
user_id: Optional[int] = Field(default=None, description="ID-ul utilizatorului Oracle")
username: Optional[str] = Field(default=None, description="Username-ul Oracle")
companies: List[Dict[str, Any]] = Field(default_factory=list, description="Lista companiilor")
message: str = Field(description="Mesaj de status")
# ==================== Helper Functions ====================
# Rate limiting storage (in-memory)
from collections import defaultdict
_endpoint_rate_limits = defaultdict(list)
def check_endpoint_rate_limit(
identifier: str,
max_attempts: int = 5,
window_minutes: int = 5
) -> bool:
"""Backend rate limiting for sensitive endpoints"""
now = datetime.now()
cutoff = now - timedelta(minutes=window_minutes)
# Clean old attempts
_endpoint_rate_limits[identifier] = [
attempt for attempt in _endpoint_rate_limits[identifier]
if attempt > cutoff
]
# Check limit
if len(_endpoint_rate_limits[identifier]) >= max_attempts:
return False
# Add attempt
_endpoint_rate_limits[identifier].append(now)
return True
def verify_session_token(
telegram_user_id: int,
email: str,
token: str
) -> bool:
"""
Verify session token from bot to prevent user ID spoofing
Token format: user_id:email:signature
"""
import hashlib
try:
parts = token.split(":")
if len(parts) != 3:
return False
token_user_id, token_email, signature = parts
# Verify user ID and email match
if int(token_user_id) != telegram_user_id or token_email != email:
return False
# Verify signature
secret = os.getenv("AUTH_SESSION_SECRET", "change-me-in-production")
payload = f"{telegram_user_id}:{email}:{secret}"
expected_signature = hashlib.sha256(payload.encode()).hexdigest()[:16]
if signature != expected_signature:
return False
return True
except Exception:
return False
def generate_linking_code(length: int = 8) -> str:
"""
Generează un cod alfanumeric aleatoriu pentru linking
Args:
length: Lungimea codului (default: 8)
Returns:
Codul generat (uppercase alphanumeric)
"""
alphabet = string.ascii_uppercase + string.digits
# Exclude caractere care pot fi confundate: 0, O, I, 1
alphabet = alphabet.replace('0', '').replace('O', '').replace('I', '').replace('1', '')
return ''.join(secrets.choice(alphabet) for _ in range(length))
async def get_oracle_user_by_username(username: str) -> Optional[Dict[str, Any]]:
"""
Obține informații despre utilizator din Oracle FĂRĂ verificare parolă.
Folosit pentru auto-linking când utilizatorul a fost deja autentificat
prin generarea unui linking code valid în aplicația web.
Args:
username: Username-ul utilizatorului Oracle
Returns:
Dict cu informații despre utilizator sau None dacă nu există
"""
try:
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
# Obține detalii utilizator
cursor.execute("""
SELECT ID_UTIL, UTILIZATOR
FROM UTILIZATORI
WHERE UPPER(UTILIZATOR) = :username
""", {'username': username.upper()})
user_row = cursor.fetchone()
if not user_row:
return None
user_id = user_row[0]
actual_username = user_row[1]
# Obține companiile utilizatorului
cursor.execute("""
SELECT A.ID_FIRMA, A.FIRMA
FROM V_NOM_FIRME A
WHERE A.ID_FIRMA IN (
SELECT ID_FIRMA
FROM VDEF_UTIL_FIRME
WHERE ID_PROGRAM = 2
AND ID_UTIL = :user_id
)
ORDER BY A.FIRMA
""", {'user_id': user_id})
companies_result = cursor.fetchall()
companies = [str(row[0]) for row in companies_result]
return {
'user_id': user_id,
'username': actual_username,
'companies': companies,
'permissions': ['read', 'reports']
}
except Exception as e:
print(f"Error getting Oracle user by username: {e}")
return None
async def verify_oracle_user(username: str, password: str) -> Optional[Dict[str, Any]]:
"""
Verifică utilizatorul în Oracle folosind pack_drepturi.verificautilizator
Args:
username: Username-ul utilizatorului
password: Parola utilizatorului
Returns:
Dict cu informații despre utilizator sau None dacă verificarea eșuează
"""
try:
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
# Verifică autentificarea
cursor.execute("""
SELECT pack_drepturi.verificautilizator(:username, :password)
FROM DUAL
""", {
'username': username.upper(),
'password': password
})
result = cursor.fetchone()
verification_result = result[0] if result else -1
if verification_result == -1:
return None
# Obține detalii utilizator
cursor.execute("""
SELECT ID_UTIL, UTILIZATOR
FROM UTILIZATORI
WHERE UPPER(UTILIZATOR) = :username
""", {'username': username.upper()})
user_row = cursor.fetchone()
if not user_row:
return None
user_id = user_row[0]
# Obține companiile utilizatorului
cursor.execute("""
SELECT A.ID_FIRMA, A.FIRMA
FROM V_NOM_FIRME A
WHERE A.ID_FIRMA IN (
SELECT ID_FIRMA
FROM VDEF_UTIL_FIRME
WHERE ID_PROGRAM = 2
AND ID_UTIL = :user_id
)
ORDER BY A.FIRMA
""", {'user_id': user_id})
companies_result = cursor.fetchall()
companies = [str(row[0]) for row in companies_result]
return {
'user_id': user_id,
'username': username,
'companies': companies,
'permissions': ['read', 'reports']
}
except Exception as e:
print(f"Error verifying Oracle user: {e}")
return None
# ==================== Endpoints ====================
@router.post("/auth/generate-code", response_model=GenerateCodeResponse)
async def generate_linking_code_endpoint(
current_user: CurrentUser = Depends(get_current_user)
):
"""
Generează un cod de linking pentru conectarea unui utilizator Telegram
Flow:
1. Utilizatorul autentificat în aplicație solicită un cod
2. Se generează un cod unic de 8 caractere
3. Codul este trimis la Telegram bot pentru salvare în SQLite cu TTL de 15 minute
4. Utilizatorul introduce codul în Telegram bot pentru linking
Note:
- Acest endpoint necesită autentificare JWT (utilizatorul trebuie să fie logat în aplicație)
- Codul expiră după 15 minute
- Fiecare request generează un cod nou (codurile vechi devin invalide)
- Nu este nevoie de telegram_user_id în acest moment (utilizatorul nu e încă conectat la Telegram)
"""
try:
# Generează cod unic
linking_code = generate_linking_code()
# Setează expirarea la 15 minute
expires_at = datetime.utcnow() + timedelta(minutes=15)
expires_in_minutes = 15
# Salvează codul în database-ul Telegram bot (SQLite) via internal API
try:
async with httpx.AsyncClient(timeout=5.0) as client:
save_code_response = await client.post(
f"{TELEGRAM_BOT_INTERNAL_API}/internal/save-code",
json={
"code": linking_code,
"telegram_user_id": 0, # Not known yet (user hasn't linked)
"oracle_username": current_user.username,
"expires_in_minutes": expires_in_minutes
}
)
# Accept both 200 (OK) and 201 (Created) as success
if save_code_response.status_code not in [200, 201]:
raise HTTPException(
status_code=500,
detail=f"Failed to save code to Telegram bot: {save_code_response.text}"
)
except httpx.TimeoutException:
raise HTTPException(
status_code=503,
detail="Telegram bot service is not responding. Please try again later."
)
except httpx.ConnectError:
raise HTTPException(
status_code=503,
detail="Cannot connect to Telegram bot service. Please contact administrator."
)
return GenerateCodeResponse(
linking_code=linking_code,
expires_at=expires_at,
expires_in_minutes=expires_in_minutes
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Eroare la generarea codului de linking: {str(e)}"
)
@router.post("/auth/verify-user", response_model=VerifyUserResponse)
async def verify_user_endpoint(request: VerifyUserRequest):
"""
Verifică utilizatorul în Oracle și returnează JWT tokens
Suportă 2 flow-uri de autentificare:
Flow A - Auto-linking (RECOMANDAT):
1. Bot verifică linking_code în SQLite (code valid = user s-a autentificat în web app)
2. Bot extrage oracle_username din cod
3. Bot trimite: {linking_code, oracle_username}
4. Backend face lookup în Oracle (FĂRĂ verificare parolă)
5. Backend generează și returnează JWT tokens
Flow B - Full verification (OPȚIONAL):
1. Bot cere username și parolă de la user în Telegram
2. Bot trimite: {linking_code, username, password}
3. Backend verifică credențialele în Oracle
4. Backend generează și returnează JWT tokens
Note:
- Acest endpoint NU necesită autentificare JWT (este public pentru bot)
- Flow A oferă UX superior (fără re-introducere parolă)
- Linking code-ul valid este proof-of-authorization
"""
try:
# Flow A: Auto-linking (oracle_username provided, no password)
if request.oracle_username and not request.password:
user_data = await get_oracle_user_by_username(request.oracle_username)
if not user_data:
return VerifyUserResponse(
success=False,
message=f"Utilizatorul {request.oracle_username} nu există în Oracle"
)
# Flow B: Full verification (username + password provided)
elif request.username and request.password:
user_data = await verify_oracle_user(request.username, request.password)
if not user_data:
return VerifyUserResponse(
success=False,
message="Username sau parolă incorectă"
)
# Invalid request (missing required fields)
else:
return VerifyUserResponse(
success=False,
message="Trebuie furnizat fie oracle_username (auto-linking) fie username+password (verificare completă)"
)
# Generează JWT tokens
access_token = jwt_handler.create_access_token(
username=user_data['username'],
companies=user_data['companies'],
user_id=user_data['user_id'],
permissions=user_data['permissions']
)
refresh_token = jwt_handler.create_refresh_token(
username=user_data['username'],
user_id=user_data['user_id']
)
return VerifyUserResponse(
success=True,
access_token=access_token,
refresh_token=refresh_token,
user={
'user_id': user_data['user_id'],
'username': user_data['username'],
'companies': user_data['companies'],
'permissions': user_data['permissions']
},
message="Autentificare reușită"
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Eroare la verificarea utilizatorului: {str(e)}"
)
@router.post("/auth/refresh-token", response_model=RefreshTokenResponse)
async def refresh_token_endpoint(request: RefreshTokenRequest):
"""
Refresh-uiește un JWT access token folosind refresh token-ul
Acest endpoint este folosit de Telegram bot pentru a obține un nou access token
când cel curent expiră, fără a solicita din nou username/password.
Flow:
1. Botul Telegram detectează că access token-ul a expirat
2. Trimite refresh token-ul la acest endpoint
3. Se validează refresh token-ul și se generează un nou access token
4. Botul stochează noul access token în SQLite
Note:
- Refresh token-ul este valid 7 zile (vs 30 minute pentru access token)
- Dacă refresh token-ul expiră, utilizatorul trebuie să se re-autentifice
"""
try:
# Verifică refresh token-ul
token_data = jwt_handler.verify_token(request.refresh_token)
if not token_data or token_data.token_type != "refresh":
raise HTTPException(
status_code=401,
detail="Refresh token invalid sau expirat"
)
# Obține companiile actualizate din Oracle
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute("""
SELECT A.ID_FIRMA
FROM V_NOM_FIRME A
WHERE A.ID_FIRMA IN (
SELECT ID_FIRMA
FROM VDEF_UTIL_FIRME
WHERE ID_PROGRAM = 2
AND ID_UTIL = :user_id
)
ORDER BY A.FIRMA
""", {'user_id': token_data.user_id})
companies_result = cursor.fetchall()
companies = [str(row[0]) for row in companies_result]
# Generează nou access token
new_access_token = jwt_handler.create_access_token(
username=token_data.username,
companies=companies,
user_id=token_data.user_id,
permissions=token_data.permissions
)
return RefreshTokenResponse(
access_token=new_access_token,
expires_in=jwt_handler.access_token_expire_minutes * 60,
token_type="bearer"
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Eroare la refresh token: {str(e)}"
)
@router.post("/auth/verify-email", response_model=VerifyEmailResponse)
async def verify_email_endpoint(request: VerifyEmailRequest):
"""
Verify if email exists in Oracle UTILIZATORI table (PUBLIC endpoint)
This is a PUBLIC endpoint used by the telegram bot during email authentication.
Returns username if email exists and user is active.
Security: Generic error messages to prevent email enumeration.
"""
try:
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
# Query to find username by email
cursor.execute("""
SELECT UTILIZATOR
FROM CONTAFIN_ORACLE.UTILIZATORI
WHERE UPPER(EMAIL) = UPPER(:email)
AND INACTIV = 0
AND STERS = 0
""", {"email": request.email})
row = cursor.fetchone()
if row:
username = row[0]
return VerifyEmailResponse(
success=True,
username=username,
message="Email verificat cu succes"
)
else:
# Generic message (no enumeration)
return VerifyEmailResponse(
success=False,
username=None,
message="Email invalid sau inactiv"
)
except Exception as e:
# Generic error message (no details exposed)
return VerifyEmailResponse(
success=False,
username=None,
message="Eroare la verificarea email-ului"
)
@router.post("/auth/login-with-email", response_model=TelegramEmailLoginResponse)
async def login_with_email_endpoint(request: TelegramEmailLoginRequest):
"""
Telegram email + password authentication endpoint
Security features:
- Rate limiting: 5 attempts per 5 minutes
- Session token verification (prevent user ID spoofing)
- Generic error messages (no username/email enumeration)
- Password verification in Oracle (not stored)
"""
# 1. Rate limiting
rate_limit_key = f"email_login_{request.telegram_user_id}"
if not check_endpoint_rate_limit(rate_limit_key, max_attempts=5, window_minutes=5):
raise HTTPException(
status_code=429,
detail="Prea multe încercări. Te rugăm să aștepți 5 minute."
)
# 2. Verify session token (prevent user ID spoofing)
if not verify_session_token(
request.telegram_user_id,
request.email,
request.session_token
):
raise HTTPException(
status_code=401,
detail="Sesiune invalidă. Te rugăm să reîncepi autentificarea."
)
try:
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
# 3. Find username by email
cursor.execute("""
SELECT ID_UTIL, UTILIZATOR, INACTIV, STERS
FROM CONTAFIN_ORACLE.UTILIZATORI
WHERE UPPER(EMAIL) = UPPER(:email)
""", {"email": request.email})
user_row = cursor.fetchone()
# SECURITY: Generic error message (no email enumeration)
if not user_row:
raise HTTPException(
status_code=401,
detail="Credențiale invalide" # Generic message
)
user_id, username, inactiv, sters = user_row
# Check if user is active (INACTIV=0 means active, STERS=0 means not deleted)
if inactiv != 0 or sters != 0:
raise HTTPException(
status_code=401,
detail="Credențiale invalide" # Generic message
)
# 4. Verify password via Oracle stored procedure
# NOTE: This procedure returns a verification code, NOT the user_id!
# Returns -1 if authentication fails, any other value means success
cursor.execute("""
SELECT pack_drepturi.verificautilizator(:username, :password)
FROM DUAL
""", {
"username": username.upper(), # IMPORTANT: Oracle usernames are uppercase
"password": request.password
})
verification_result = cursor.fetchone()[0]
# SECURITY: Generic error message (no username leak)
if verification_result == -1:
raise HTTPException(
status_code=401,
detail="Credențiale invalide" # Generic message
)
# 5. Get user companies
cursor.execute("""
SELECT A.ID_FIRMA, A.FIRMA
FROM V_NOM_FIRME A
WHERE A.ID_FIRMA IN (
SELECT ID_FIRMA
FROM VDEF_UTIL_FIRME
WHERE ID_PROGRAM = 2
AND ID_UTIL = :user_id
)
ORDER BY A.FIRMA
""", {'user_id': user_id})
companies_result = cursor.fetchall()
companies = [
{"id": str(row[0]), "name": row[1]}
for row in companies_result
]
company_ids = [str(row[0]) for row in companies_result]
# 6. Get user permissions (default for Telegram)
permissions = ['read', 'reports']
# 7. Generate JWT tokens
token_data = {
"username": username,
"user_id": user_id,
"companies": company_ids,
"permissions": permissions
}
access_token = jwt_handler.create_access_token(**token_data)
refresh_token = jwt_handler.create_refresh_token(
username=username,
user_id=user_id
)
return TelegramEmailLoginResponse(
success=True,
access_token=access_token,
refresh_token=refresh_token,
user_id=user_id,
username=username,
companies=companies,
message="Autentificare reușită"
)
except HTTPException:
raise
except Exception as e:
print(f"Error in login_with_email: {e}")
raise HTTPException(
status_code=500,
detail="Eroare internă. Te rugăm să încerci din nou mai târziu."
)
@router.post("/export", response_model=ExportReportResponse)
async def export_report_endpoint(
request: ExportReportRequest,
current_user: CurrentUser = Depends(get_current_user)
):
"""
Exportă un raport în format Excel, PDF sau CSV
Acest endpoint este folosit de Telegram bot pentru a genera rapoarte
și a le trimite utilizatorului.
Flow:
1. Botul trimite cerere de export cu parametrii raportului
2. Se validează că utilizatorul are acces la firma specificată
3. Se generează raportul în formatul solicitat
4. Se returnează URL-ul sau conținutul fișierului
Tipuri de rapoarte suportate:
- invoices: Facturi (cu filtre: dată, status, client)
- payments: Încasări (cu filtre: dată, metodă plată)
- dashboard: Statistici dashboard (rezumat)
Formate suportate:
- excel: XLSX (cel mai complet)
- pdf: PDF (pentru printing)
- csv: CSV (pentru import în alte sisteme)
Note:
- Utilizatorul trebuie să aibă acces la firma specificată
- Fișierele generate sunt temporare (șterse după 1 oră)
"""
try:
# Verifică accesul la firmă
company_id_str = str(request.company_id)
if company_id_str not in current_user.companies:
raise HTTPException(
status_code=403,
detail=f"Nu aveți acces la firma {request.company_id}"
)
# TODO: Implementare export în funcție de report_type și format
# Deocamdată returnăm un placeholder
return ExportReportResponse(
success=True,
file_url=f"/api/telegram/downloads/report_{request.report_type}_{request.company_id}.{request.format}",
file_name=f"raport_{request.report_type}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.{request.format}",
file_size_bytes=0,
message=f"Raport {request.report_type} generat cu succes în format {request.format}"
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Eroare la generarea raportului: {str(e)}"
)
@router.get("/health")
async def telegram_health_check():
"""
Health check pentru routerul Telegram
Verifică conectivitatea la Oracle și disponibilitatea serviciilor
"""
try:
async with oracle_pool.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute("SELECT 1 FROM DUAL")
return {
"status": "healthy",
"service": "telegram-router",
"database": "connected",
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
return {
"status": "degraded",
"service": "telegram-router",
"database": f"error: {str(e)}",
"timestamp": datetime.utcnow().isoformat()
}

Some files were not shown because too many files have changed in this diff Show More