""" OCR Data Validation Module Provides multi-layer validation for OCR extraction results to prevent incorrect data from entering the system. Validation Layers: 1. Absolute sanity checks (value ranges) 2. Cross-field validation (correlation between fields) 3. Inter-OCR consistency (compare multiple OCR results) 4. Auto-correction (fix obvious errors) Usage: engine = OCRValidationEngine() validated_result = engine.validate_extraction( merged_result, light_ocr_result, medium_ocr_result ) """ from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Any, Optional @dataclass class ValidationResult: """Result of a single validation rule execution. Attributes: is_valid: Whether the validation passed confidence_penalty: Penalty to apply to confidence score (0.0-1.0) 0.0 = no penalty, 1.0 = complete rejection message: Human-readable description of validation result severity: "info" | "warning" | "error" """ is_valid: bool confidence_penalty: float = 0.0 message: str = "" severity: str = "info" # "info" | "warning" | "error" def __post_init__(self): """Validate penalty is in valid range.""" if not 0.0 <= self.confidence_penalty <= 1.0: raise ValueError(f"Confidence penalty must be 0.0-1.0, got {self.confidence_penalty}") class ValidationRule(ABC): """Abstract base class for OCR validation rules. Each rule implements a specific validation check and returns a ValidationResult indicating success/failure with optional confidence penalty. """ @abstractmethod def validate(self, data: dict[str, Any]) -> ValidationResult: """Execute validation rule on extraction data. 
Args: data: Dictionary containing extraction fields to validate Example: {"amount": 85.99, "tva": 14.92, ...} Returns: ValidationResult with is_valid flag and optional penalty """ pass @property @abstractmethod def rule_name(self) -> str: """Human-readable name of this validation rule.""" pass # ============================================================================ # VALIDATION RULES # ============================================================================ class AmountRangeRule(ValidationRule): """Validate amount is within reasonable bounds for Romanian receipts. Romanian receipts rarely exceed 100,000 RON. This catches obvious OCR errors like digit concatenation (85.99 → 859,762.16). Example: rule = AmountRangeRule(min_amount=0.01, max_amount=100_000.0) result = rule.validate({"amount": 859762.16}) # result.is_valid = False, penalty = 0.5 """ def __init__(self, min_amount: float = 0.01, max_amount: float = 100_000.0): self.min_amount = min_amount self.max_amount = max_amount @property def rule_name(self) -> str: return "Amount Range Check" def validate(self, data: dict[str, Any]) -> ValidationResult: amount = data.get("amount") if amount is None: return ValidationResult( is_valid=True, message="No amount to validate" ) if amount < self.min_amount: return ValidationResult( is_valid=False, confidence_penalty=0.5, message=f"Amount {amount:.2f} RON below minimum {self.min_amount:.2f} RON", severity="error" ) if amount > self.max_amount: return ValidationResult( is_valid=False, confidence_penalty=0.5, message=f"Amount {amount:.2f} RON exceeds maximum {self.max_amount:.2f} RON (likely OCR error)", severity="error" ) return ValidationResult( is_valid=True, message=f"Amount {amount:.2f} RON within valid range" ) class TVARatioRule(ValidationRule): """Validate TVA is reasonable percentage of TOTAL amount. Romanian TVA rates: 5%, 9%, 19%, 21% (most common: 19-21%) This catches errors where TVA > TOTAL (impossible). 
Example: rule = TVARatioRule(min_ratio=0.05, max_ratio=0.24) result = rule.validate({"amount": 85.99, "tva": 149.21}) # result.is_valid = False (149.21 > 85.99!) """ def __init__(self, min_ratio: float = 0.05, max_ratio: float = 0.24): self.min_ratio = min_ratio self.max_ratio = max_ratio @property def rule_name(self) -> str: return "TVA Ratio Check" def validate(self, data: dict[str, Any]) -> ValidationResult: amount = data.get("amount") tva = data.get("tva") if not amount or not tva: return ValidationResult( is_valid=True, message="Insufficient data for TVA correlation" ) # Type safety: ensure numeric types before division if not isinstance(amount, (int, float)) or not isinstance(tva, (int, float)): return ValidationResult( is_valid=True, message="Non-numeric values, skipping TVA correlation" ) # Avoid division by zero if amount <= 0: return ValidationResult( is_valid=True, message="Amount is zero or negative, skipping TVA ratio" ) tva_ratio = tva / amount if tva_ratio < self.min_ratio or tva_ratio > self.max_ratio: return ValidationResult( is_valid=False, confidence_penalty=0.3, message=f"TVA ratio {tva_ratio:.1%} outside valid range ({self.min_ratio:.0%}-{self.max_ratio:.0%})", severity="warning" ) return ValidationResult( is_valid=True, message=f"TVA ratio {tva_ratio:.1%} valid" ) class PaymentSumRule(ValidationRule): """Validate CARD + NUMERAR = TOTAL BON (within tolerance). This is a CRITICAL validation that catches cases where OCR extracts wrong TOTAL but correct payment methods. 
Example: rule = PaymentSumRule(tolerance=0.02) result = rule.validate({ "amount": 859762.16, # Wrong from OCR "card_amount": 85.99, # Correct "cash_amount": 0.0 }) # result.is_valid = False, suggests auto-correction """ def __init__(self, tolerance: float = 0.02): self.tolerance = tolerance @property def rule_name(self) -> str: return "Payment Sum Check" def validate(self, data: dict[str, Any]) -> ValidationResult: total = data.get("amount") card = data.get("card_amount", 0.0) or 0.0 cash = data.get("cash_amount", 0.0) or 0.0 if not total: return ValidationResult( is_valid=True, message="No total amount to validate" ) payment_sum = card + cash if payment_sum == 0: return ValidationResult( is_valid=True, message="No payment methods extracted" ) diff = abs(total - payment_sum) if diff > self.tolerance: return ValidationResult( is_valid=False, confidence_penalty=0.4, message=f"Payment sum {payment_sum:.2f} RON ≠ Total {total:.2f} RON (diff: {diff:.2f} RON). Consider auto-correction.", severity="error" ) return ValidationResult( is_valid=True, message=f"Payment sum matches total (diff: {diff:.2f} RON)" ) class TVAEntriesSumRule(ValidationRule): """Validate Σ(TVA entries) = TVA TOTAL (within tolerance). TVA breakdown (A, B, C, D rates) should sum to total TVA. 
Example: rule = TVAEntriesSumRule(tolerance=0.02) result = rule.validate({ "tva": 14.92, "tva_entries": {"A": 14.92, "B": 0.0} }) # result.is_valid = True """ def __init__(self, tolerance: float = 0.02): self.tolerance = tolerance @property def rule_name(self) -> str: return "TVA Entries Sum Check" def validate(self, data: dict[str, Any]) -> ValidationResult: tva_total = data.get("tva") tva_entries = data.get("tva_entries", {}) if not tva_total: return ValidationResult( is_valid=True, message="No TVA total to validate" ) if not tva_entries: return ValidationResult( is_valid=True, message="No TVA entries extracted" ) entries_sum = sum(tva_entries.values()) if entries_sum == 0: return ValidationResult( is_valid=True, message="TVA entries sum is zero" ) diff = abs(tva_total - entries_sum) if diff > self.tolerance: return ValidationResult( is_valid=False, confidence_penalty=0.2, message=f"TVA entries sum {entries_sum:.2f} RON ≠ TVA total {tva_total:.2f} RON (diff: {diff:.2f} RON)", severity="warning" ) return ValidationResult( is_valid=True, message=f"TVA entries sum matches total (diff: {diff:.2f} RON)" ) class CUIFormatRule(ValidationRule): """Validate CUI format: RO + 6-10 digits. 
Romanian CUI (Cod Unic de Identificare) format: - Optional "RO" prefix (or "R0" from OCR errors) - 6-10 numeric digits Example: rule = CUIFormatRule() result = rule.validate({"cui": "RO10562600"}) # result.is_valid = True """ @property def rule_name(self) -> str: return "CUI Format Check" def validate(self, data: dict[str, Any]) -> ValidationResult: cui = data.get("cui") if not cui: return ValidationResult( is_valid=True, message="No CUI to validate" ) # Normalize: remove RO/R0 prefix cui_clean = cui.strip().upper() if cui_clean.startswith("RO"): cui_clean = cui_clean[2:] elif cui_clean.startswith("R0"): cui_clean = cui_clean[2:] # Check if numeric if not cui_clean.isdigit(): return ValidationResult( is_valid=False, confidence_penalty=0.3, message=f"CUI '{cui}' contains non-numeric characters", severity="warning" ) # Check length if len(cui_clean) < 6 or len(cui_clean) > 10: return ValidationResult( is_valid=False, confidence_penalty=0.3, message=f"CUI '{cui}' length {len(cui_clean)} outside valid range (6-10 digits)", severity="warning" ) return ValidationResult( is_valid=True, message=f"CUI '{cui}' format valid" ) class CUIChecksumRule(ValidationRule): """Validate Romanian CIF/CUI using Mod 11 checksum algorithm. Algorithm: 1. Remove RO prefix if present 2. Extract last digit as declared checksum 3. Apply multipliers [7,5,3,2,1,7,5,3,2] to first N-1 digits 4. Calculate: (sum * 10) mod 11 5. If result = 10, expected checksum = 0 6. Else, expected checksum = result 7. 
Compare with declared checksum Example: rule = CUIChecksumRule() result = rule.validate({"cui": "RO10562600"}) # result.is_valid = True (checksum correct) result = rule.validate({"cui": "R01879855"}) # result.is_valid = False (checksum mismatch) Static methods available for direct use: CUIChecksumRule.calculate_checksum("1056260") -> 0 CUIChecksumRule.validate_checksum("10562600") -> True CUIChecksumRule.has_ro_prefix("RO10562600") -> True """ # Fixed multipliers for 9 positions (Romanian Mod 11) MULTIPLIERS = [7, 5, 3, 2, 1, 7, 5, 3, 2] @staticmethod def calculate_checksum(cui_base: str) -> int: """Calculate expected CUI checksum using Romanian Mod 11 algorithm. Args: cui_base: CUI digits WITHOUT the checksum digit (last digit) Returns: Expected checksum digit (0-9), or -1 if invalid input """ if not cui_base or not cui_base.isdigit(): return -1 # Pad base to 9 digits from LEFT base_padded = cui_base.zfill(9) base_digits = [int(d) for d in base_padded] # Calculate weighted sum weighted_sum = sum(d * m for d, m in zip(base_digits, CUIChecksumRule.MULTIPLIERS)) # Calculate checksum checksum = (weighted_sum * 10) % 11 if checksum == 10: checksum = 0 return checksum @staticmethod def validate_checksum(cui_digits: str) -> bool: """Check if CUI checksum is valid. 
Args: cui_digits: Full CUI digits (including checksum as last digit) Returns: True if checksum is valid, False otherwise """ if not cui_digits or len(cui_digits) < 6 or not cui_digits.isdigit(): return False base = cui_digits[:-1] declared = int(cui_digits[-1]) expected = CUIChecksumRule.calculate_checksum(base) return expected == declared @staticmethod def has_ro_prefix(cui: str) -> bool: """Check if CUI has RO prefix (proper format for VAT payers).""" if not cui: return False return cui.upper().strip().startswith('RO') @staticmethod def extract_digits(cui: str) -> str: """Extract digits from CUI, removing RO/R0 prefix.""" if not cui: return "" cui = cui.strip().upper() if cui.startswith("RO"): cui = cui[2:] elif cui.startswith("R0"): # Fix OCR error R0 → RO cui = cui[2:] return ''.join(c for c in cui if c.isdigit()) @property def rule_name(self) -> str: return "CUI Checksum Check (Mod 11)" def validate(self, data: dict[str, Any]) -> ValidationResult: cui = data.get("cui") if not cui: return ValidationResult( is_valid=True, message="No CUI to validate" ) # Use static method to extract digits cui_clean = CUIChecksumRule.extract_digits(cui) # Check format first if not cui_clean: return ValidationResult( is_valid=True, # Don't fail checksum if format invalid (handled by CUIFormatRule) message="CUI format invalid, skipping checksum" ) if len(cui_clean) < 6 or len(cui_clean) > 10: return ValidationResult( is_valid=True, message="CUI length invalid, skipping checksum" ) # Use static method to validate checksum if not CUIChecksumRule.validate_checksum(cui_clean): # Calculate expected for error message expected = CUIChecksumRule.calculate_checksum(cui_clean[:-1]) declared = int(cui_clean[-1]) return ValidationResult( is_valid=False, confidence_penalty=0.3, message=f"CUI '{cui}' checksum mismatch: expected {expected}, got {declared}", severity="warning" ) return ValidationResult( is_valid=True, message=f"CUI '{cui}' checksum valid" ) class 
TVABasedTotalRule(ValidationRule): """Validate TOTAL using reverse calculation from TVA amount. This is a CRITICAL validation that catches cases where OCR extracts wrong TOTAL but correct TVA. Since TVA = BASE * rate and TOTAL = BASE + TVA, we can calculate expected TOTAL from TVA alone. Formula: Expected TOTAL = TVA / rate * (1 + rate) Or equivalently: Expected TOTAL = TVA * (1 + rate) / rate For TVA rate 21%: Expected TOTAL = TVA / 0.21 * 1.21 = TVA * 5.7619 Example (benzina 27 oct): TVA = 49.58, rate = 21% Expected TOTAL = 49.58 / 0.21 * 1.21 = 285.68 Extracted TOTAL = 205.66 (WRONG!) Rule detects mismatch and flags for escalation Usage in multi-tier processing (e.g., doctr_plus): If this rule fails, the engine should proceed to next tier instead of returning early with potentially wrong data. """ def __init__(self, tolerance_percent: float = 0.02): """ Args: tolerance_percent: Allowed difference as percentage (0.02 = 2%) """ self.tolerance_percent = tolerance_percent @property def rule_name(self) -> str: return "TVA-Based Total Check" def validate(self, data: dict[str, Any]) -> ValidationResult: total = data.get("amount") tva = data.get("tva") tva_entries = data.get("tva_entries", []) if not total or not tva: return ValidationResult( is_valid=True, message="Insufficient data for TVA-based total validation" ) # Type safety try: total = float(total) tva = float(tva) except (TypeError, ValueError): return ValidationResult( is_valid=True, message="Non-numeric values, skipping TVA-based total validation" ) if tva <= 0 or total <= 0: return ValidationResult( is_valid=True, message="Zero or negative values, skipping TVA-based total validation" ) # Try to determine TVA rate from entries tva_rate = None # Check tva_entries for rate information if tva_entries: for entry in tva_entries: if isinstance(entry, dict): percent = entry.get('percent') if percent: try: tva_rate = float(percent) / 100.0 break except (TypeError, ValueError): pass # Fallback: try to calculate rate 
from TVA/TOTAL ratio if not tva_rate: # TVA = BASE * rate, TOTAL = BASE + TVA = BASE * (1 + rate) # TVA/TOTAL = rate / (1 + rate) # So rate = TVA / (TOTAL - TVA) = TVA / BASE base = total - tva if base > 0: calculated_rate = tva / base # Validate it's a reasonable Romanian TVA rate (5%, 9%, 19%, 21%) if 0.04 <= calculated_rate <= 0.25: tva_rate = calculated_rate if not tva_rate: # Assume most common rate: 21% tva_rate = 0.21 # Calculate expected TOTAL from TVA # TVA = BASE * rate → BASE = TVA / rate # TOTAL = BASE + TVA = (TVA / rate) + TVA = TVA * (1 + 1/rate) = TVA * (1 + rate) / rate expected_total = tva * (1 + tva_rate) / tva_rate # Calculate difference diff = abs(total - expected_total) diff_percent = diff / expected_total if expected_total > 0 else 1.0 if diff_percent > self.tolerance_percent: # Significant mismatch - OCR likely extracted TOTAL wrong return ValidationResult( is_valid=False, confidence_penalty=0.5, # High penalty - this is a critical error message=( f"TOTAL mismatch: Extracted {total:.2f} RON vs " f"TVA-calculated {expected_total:.2f} RON " f"(TVA={tva:.2f}, rate={tva_rate:.0%}, diff={diff_percent:.1%}). " f"Likely OCR error on TOTAL." ), severity="error" ) return ValidationResult( is_valid=True, message=f"TOTAL {total:.2f} matches TVA-calculated {expected_total:.2f} (diff: {diff_percent:.1%})" ) class InterOCRConsistencyRule(ValidationRule): """Validate consistency between multiple OCR results. If Light OCR and Medium OCR produce values that differ by >10x, one is clearly wrong (likely digit concatenation error). Example: rule = InterOCRConsistencyRule(max_ratio=10.0) result = rule.validate({ "light_amount": 85.99, "medium_amount": 859762.16 }) # result.is_valid = False (ratio = 10,000x!) 
""" def __init__(self, max_ratio: float = 10.0): self.max_ratio = max_ratio @property def rule_name(self) -> str: return "Inter-OCR Consistency Check" def validate(self, data: dict[str, Any]) -> ValidationResult: light_value = data.get("light_value") medium_value = data.get("medium_value") field_name = data.get("field_name", "value") if not light_value or not medium_value: return ValidationResult( is_valid=True, message="Insufficient OCR results for consistency check" ) # Avoid division by zero if light_value == 0 or medium_value == 0: return ValidationResult( is_valid=True, message="One value is zero, skipping consistency check" ) ratio = max(light_value, medium_value) / min(light_value, medium_value) if ratio > self.max_ratio: return ValidationResult( is_valid=False, confidence_penalty=0.2, message=f"{field_name}: OCR results differ by {ratio:.1f}x (Light: {light_value}, Medium: {medium_value})", severity="warning" ) return ValidationResult( is_valid=True, message=f"{field_name}: OCR results consistent (ratio: {ratio:.2f}x)" ) # ============================================================================ # VALIDATION ENGINE # ============================================================================ @dataclass class EnhancedExtractionResult: """Enhanced extraction result with validation metadata. This wraps the original extraction data and adds validation results. """ # Original data data: dict[str, Any] # Validation results needs_manual_review: bool = False validation_warnings: list[str] = field(default_factory=list) validation_errors: list[str] = field(default_factory=list) confidence_adjustments: dict[str, float] = field(default_factory=dict) # Inter-OCR metadata inter_ocr_ratios: dict[str, float] = field(default_factory=dict) class OCRValidationEngine: """Orchestrate all validation rules for OCR extraction results. This engine applies validation rules in order: 1. Sanity checks (amount range, format checks) 2. 
Cross-field correlation (TVA ratio, payment sum) 3. Inter-OCR consistency checks Example: engine = OCRValidationEngine() result = engine.validate_extraction( extraction_result=merged_data, light_result=light_ocr_data, medium_result=medium_ocr_data ) """ def __init__(self): """Initialize validation engine with default rules.""" # Sanity check rules (absolute value validation) self.sanity_rules = [ AmountRangeRule(min_amount=0.01, max_amount=100_000.0), CUIFormatRule(), CUIChecksumRule(), ] # Cross-field validation rules (correlation between fields) self.cross_field_rules = [ TVARatioRule(min_ratio=0.05, max_ratio=0.24), PaymentSumRule(tolerance=0.02), TVAEntriesSumRule(tolerance=0.02), TVABasedTotalRule(tolerance_percent=0.02), # Critical: detect TOTAL errors via TVA ] # Inter-OCR consistency rules self.inter_ocr_rules = [ InterOCRConsistencyRule(max_ratio=10.0), ] def validate_extraction( self, extraction_result: dict[str, Any], light_result: Optional[dict[str, Any]] = None, medium_result: Optional[dict[str, Any]] = None ) -> EnhancedExtractionResult: """Run all validation rules and return enhanced result. 
Args: extraction_result: Merged OCR extraction data (required) light_result: Light OCR preprocessing results (optional) medium_result: Medium OCR preprocessing results (optional) Returns: EnhancedExtractionResult with validation warnings and metadata """ warnings = [] errors = [] confidence_adjustments = {} inter_ocr_ratios = {} # Step 1: Sanity checks print("\n[Validation] Step 1: Sanity checks...", flush=True) for rule in self.sanity_rules: result = rule.validate(extraction_result) if not result.is_valid: msg = f"[{rule.rule_name}] {result.message}" if result.severity == "error": errors.append(msg) else: warnings.append(msg) print(f" ❌ {msg}", flush=True) # Track confidence penalty for the relevant field based on rule if result.confidence_penalty > 0: rule_field_map = { "Amount Range Check": ["amount"], "CUI Format Check": ["cui"], "CUI Checksum Check (Mod 11)": ["cui"], } fields = rule_field_map.get(rule.rule_name, ["amount", "tva", "cui"]) for f in fields: if f in extraction_result: confidence_adjustments[f] = result.confidence_penalty else: print(f" ✅ {rule.rule_name}: {result.message}", flush=True) # Step 2: Cross-field validation print("\n[Validation] Step 2: Cross-field validation...", flush=True) for rule in self.cross_field_rules: result = rule.validate(extraction_result) if not result.is_valid: msg = f"[{rule.rule_name}] {result.message}" if result.severity == "error": errors.append(msg) else: warnings.append(msg) print(f" ❌ {msg}", flush=True) # Track confidence penalty for the relevant field based on rule if result.confidence_penalty > 0: rule_field_map = { "TVA Ratio Check": ["tva"], "Payment Sum Check": ["amount"], "TVA Entries Sum Check": ["tva"], } fields = rule_field_map.get(rule.rule_name, ["amount", "tva"]) for f in fields: if f in extraction_result: confidence_adjustments[f] = result.confidence_penalty else: print(f" ✅ {rule.rule_name}: {result.message}", flush=True) # Step 3: Inter-OCR consistency checks if light_result and medium_result: 
print("\n[Validation] Step 3: Inter-OCR consistency...", flush=True) # Check amount consistency if "amount" in light_result and "amount" in medium_result: consistency_data = { "light_value": light_result["amount"], "medium_value": medium_result["amount"], "field_name": "amount" } result = self.inter_ocr_rules[0].validate(consistency_data) if not result.is_valid: msg = f"[Inter-OCR] {result.message}" warnings.append(msg) print(f" ❌ {msg}", flush=True) # Store ratio for metadata ratio = max( light_result["amount"], medium_result["amount"] ) / min(light_result["amount"], medium_result["amount"]) inter_ocr_ratios["amount"] = ratio else: print(f" ✅ {result.message}", flush=True) # Determine if manual review is needed # Only flag for review if there are errors OR high-severity warnings high_severity_warnings = [w for w in warnings if "[Amount Range" in w or "[Payment Sum" in w or "[Inter-OCR]" in w] needs_manual_review = ( len(errors) > 0 or len(high_severity_warnings) > 0 or any(ratio > 10.0 for ratio in inter_ocr_ratios.values()) ) print(f"\n[Validation] Summary:", flush=True) print(f" Errors: {len(errors)}", flush=True) print(f" Warnings: {len(warnings)}", flush=True) print(f" Manual review needed: {needs_manual_review}", flush=True) return EnhancedExtractionResult( data=extraction_result, needs_manual_review=needs_manual_review, validation_warnings=warnings, validation_errors=errors, confidence_adjustments=confidence_adjustments, inter_ocr_ratios=inter_ocr_ratios ) def quick_validate_for_hybrid(self, extraction_result: dict[str, Any]) -> tuple[bool, float, list[str]]: """Quick validation for early-exit decisions (e.g., doctr_plus Tier 1). Runs critical cross-validation rules to detect obvious OCR errors. Used to decide whether to proceed to next processing tier or exit early. 
Args: extraction_result: Extraction data dict with fields: - amount: Extracted TOTAL - tva: Extracted TVA total - tva_entries: List of TVA entries with rates Returns: Tuple of (passes_validation, confidence_penalty, error_messages) - passes_validation: True if no critical errors detected - confidence_penalty: Cumulative penalty (0.0-1.0) - error_messages: List of validation error messages Example usage: passes, penalty, errors = validation_engine.quick_validate_for_hybrid(extraction_data) if not passes: print(f"Validation failed: {errors}, proceeding to next tier") # Continue to next processing tier instead of early exit """ errors = [] total_penalty = 0.0 # Critical rules for early-exit decision-making # These determine if we can trust the extraction or need to proceed to next tier critical_rules = [ # Cross-field validations (most important for detecting OCR errors) TVABasedTotalRule(tolerance_percent=0.02), # Critical: detect TOTAL errors via TVA calculation PaymentSumRule(tolerance=0.05), # Cross-validate TOTAL vs CARD+CASH payments TVARatioRule(min_ratio=0.05, max_ratio=0.24), # TVA should be 5-24% of TOTAL TVAEntriesSumRule(tolerance=0.05), # Sum of TVA entries should match TVA total # Format & checksum validations CUIChecksumRule(), # Validate CUI/CIF with Romanian Mod11 checksum algorithm CUIFormatRule(), # CUI should be 6-10 digits # Sanity checks AmountRangeRule(min_amount=0.01, max_amount=100_000.0), # Reasonable amount range ] for rule in critical_rules: result = rule.validate(extraction_result) if not result.is_valid: errors.append(result.message) total_penalty += result.confidence_penalty # Cap penalty at 1.0 total_penalty = min(1.0, total_penalty) passes = len(errors) == 0 return passes, total_penalty, errors # NOTE: _calculate_cui_checksum and _is_cui_checksum_valid removed # Use CUIChecksumRule.calculate_checksum() and CUIChecksumRule.validate_checksum() instead @staticmethod def _repair_cui_checksum(cui_digits: str) -> Optional[str]: """Try to 
repair CUI by attempting 1-digit corrections. OCR often misreads similar-looking digits: - 5 ↔ 8 (most common in receipts) - 6 ↔ 0 - 1 ↔ 7 - 3 ↔ 8 Algorithm: 1. Check middle positions first (2,3,4,5...) - OCR errors more common there 2. Skip first digit (position 0) - usually reliable in CUI 3. Check checksum digit (last position) last 4. Prefer common OCR digit confusions (5↔8, 6↔0) Args: cui_digits: Original CUI digits (without RO prefix) Returns: Repaired CUI digits if 1-digit fix found, else None """ if len(cui_digits) < 6 or not cui_digits.isdigit(): return None # If already valid, return as-is if CUIChecksumRule.validate_checksum(cui_digits): return cui_digits # Common OCR digit confusions (try these first) confusion_pairs = { '5': ['8', '6'], # 5 often misread as 8 or 6 '8': ['5', '3', '0'], # 8 often misread as 5, 3, or 0 '6': ['0', '8'], # 6 often misread as 0 or 8 '0': ['6', '8'], # 0 often misread as 6 or 8 '1': ['7', '4'], # 1 often misread as 7 or 4 '7': ['1'], # 7 often misread as 1 '3': ['8'], # 3 often misread as 8 '4': ['1'], # 4 often misread as 1 '2': ['7'], # 2 sometimes misread as 7 '9': ['0'], # 9 sometimes misread as 0 } n = len(cui_digits) last_pos = n - 1 # checksum position # Position check order: middle positions first, then position 1, then 0, then checksum # Skip position 0 (first digit) - it's usually reliable # Example for 8-digit CUI: [2,3,4,5,6, 1, 7(checksum)] middle_positions = list(range(2, last_pos)) # positions 2 to n-2 position_order = middle_positions + [1, last_pos, 0] # check pos 0 last (rarely wrong) for pos in position_order: if pos >= n: continue original_digit = cui_digits[pos] # Try common confusions first for this digit candidates = confusion_pairs.get(original_digit, []) # Then try all other digits all_digits = [d for d in '0123456789' if d != original_digit and d not in candidates] for replacement in candidates + all_digits: candidate = cui_digits[:pos] + replacement + cui_digits[pos+1:] if 
CUIChecksumRule.validate_checksum(candidate): print(f"[CUI Repair] Fixed {cui_digits} -> {candidate} (position {pos}: {original_digit}->{replacement})", flush=True) return candidate # No single-digit fix found return None @staticmethod def normalize_cui(cui: Optional[str]) -> Optional[str]: """Normalize CUI - fix OCR errors but preserve original format. Rules: - R0 → RO (fix OCR error where O is read as 0) - Keep RO prefix if original had it (platitor TVA) - Do NOT add RO if original didn't have it (neplatitor TVA) - Try to repair 1-digit checksum errors (OCR mistakes like 5↔8) Examples: 45417955 → 45417955 (no prefix = neplatitor TVA, keep as-is) R010562600 → RO10562600 (fix R0 OCR error) RO10562600 → RO10562600 (unchanged) RO10862600 → RO10562600 (repaired: 8→5 at position 2) Args: cui: Raw CUI string from OCR Returns: Normalized CUI, or None if invalid """ if not cui: return None cui = cui.strip().upper() # Check if original had RO/R0 prefix had_ro_prefix = cui.startswith("RO") or cui.startswith("R0") # Extract digits if cui.startswith("RO"): cui_digits = cui[2:] elif cui.startswith("R0"): # Fix OCR error R0 → RO cui_digits = cui[2:] else: cui_digits = cui # Remove any non-digit characters cui_digits = ''.join(c for c in cui_digits if c.isdigit()) # Validate length if len(cui_digits) < 6 or len(cui_digits) > 10: print(f"[CUI Normalize] Invalid length: {len(cui_digits)} digits (expected 6-10)", flush=True) return None # Try to repair checksum if invalid if not CUIChecksumRule.validate_checksum(cui_digits): repaired = OCRValidationEngine._repair_cui_checksum(cui_digits) if repaired: cui_digits = repaired # Return with RO prefix only if original had it if had_ro_prefix: return f"RO{cui_digits}" else: return cui_digits @staticmethod async def fuzzy_match_cui_from_db( cui: Optional[str], db_session ) -> Optional[tuple[str, str]]: """Fuzzy match CUI against database of known suppliers. This function: 1. Validates CUI checksum 2. 
If valid, looks up in database (exact match) 3. If invalid, tries 1-digit corrections and looks up each candidate 4. Returns the first match found in database Args: cui: Extracted CUI from OCR (may be invalid) db_session: SQLAlchemy async session for database lookups Returns: Tuple of (corrected_cui, supplier_name) if found, else None Usage in OCR extraction: from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine match = await OCRValidationEngine.fuzzy_match_cui_from_db(extracted_cui, session) if match: corrected_cui, supplier_name = match """ from sqlalchemy import select, or_ from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier if not cui: return None cui = cui.strip().upper() # Check if original had RO/R0 prefix had_ro_prefix = cui.startswith("RO") or cui.startswith("R0") # Extract digits if cui.startswith("RO"): cui_digits = cui[2:] elif cui.startswith("R0"): # Fix OCR error R0 → RO cui_digits = cui[2:] else: cui_digits = cui # Remove any non-digit characters cui_digits = ''.join(c for c in cui_digits if c.isdigit()) # Validate length if len(cui_digits) < 6 or len(cui_digits) > 10: return None # Helper to format CUI with optional RO prefix def format_cui(digits: str) -> str: if had_ro_prefix: return f"RO{digits}" return digits # Helper to search database for CUI async def lookup_cui_in_db(digits: str) -> Optional[tuple[str, str]]: """Search both synced and local suppliers for CUI.""" # Search patterns: with and without RO prefix search_patterns = [digits, f"RO{digits}"] # Search synced_suppliers first (more data) stmt = select(SyncedSupplier.fiscal_code, SyncedSupplier.name).where( or_( SyncedSupplier.fiscal_code == digits, SyncedSupplier.fiscal_code == f"RO{digits}", SyncedSupplier.fiscal_code == digits.lstrip('0'), # Handle leading zeros ) ).limit(1) result = await db_session.execute(stmt) row = result.first() if row: return (format_cui(digits), row.name) # Search local_suppliers stmt = 
select(LocalSupplier.fiscal_code, LocalSupplier.name).where( or_( LocalSupplier.fiscal_code == digits, LocalSupplier.fiscal_code == f"RO{digits}", LocalSupplier.fiscal_code == digits.lstrip('0'), ) ).limit(1) result = await db_session.execute(stmt) row = result.first() if row: return (format_cui(digits), row.name) return None # 1. If checksum is valid, check if it exists in database (exact match) if CUIChecksumRule.validate_checksum(cui_digits): match = await lookup_cui_in_db(cui_digits) if match: print(f"[Fuzzy CUI] Exact match found: {cui} -> {match[0]} ({match[1]})", flush=True) return match # Valid checksum but not in DB - return as-is (it might be a new supplier) return None # 2. Invalid checksum - try 1-digit corrections and verify against database print(f"[Fuzzy CUI] Invalid checksum for {cui}, trying corrections...", flush=True) # Common OCR digit confusions (try these first) confusion_pairs = { '5': ['8', '6'], # 5 often misread as 8 or 6 '8': ['5', '3', '0'], # 8 often misread as 5, 3, or 0 '6': ['0', '8'], # 6 often misread as 0 or 8 '0': ['6', '8'], # 0 often misread as 6 or 8 '1': ['7', '4'], # 1 often misread as 7 or 4 '7': ['1'], # 7 often misread as 1 '3': ['8'], # 3 often misread as 8 '4': ['1'], # 4 often misread as 1 '2': ['7'], # 2 sometimes misread as 7 '9': ['0'], # 9 sometimes misread as 0 } n = len(cui_digits) last_pos = n - 1 # checksum position # Position check order: middle positions first, then ends middle_positions = list(range(2, last_pos)) position_order = middle_positions + [1, last_pos, 0] for pos in position_order: if pos >= n: continue original_digit = cui_digits[pos] # Try common confusions first for this digit candidates = confusion_pairs.get(original_digit, []) # Then try all other digits all_digits = [d for d in '0123456789' if d != original_digit and d not in candidates] for replacement in candidates + all_digits: candidate = cui_digits[:pos] + replacement + cui_digits[pos+1:] # Only consider if checksum is valid if not 
@staticmethod
async def fuzzy_match_by_name_and_cui(
    vendor_name: Optional[str],
    cui: Optional[str],
    db_session
) -> Optional[tuple[str, str]]:
    """Fuzzy match a supplier by NAME, then narrow the result down by CUI.

    Algorithm:
    1. Normalize the vendor name (upper-case, drop legal-form markers such
       as S.R.L. / SA / S.C., replace punctuation with spaces).
    2. Take the first keyword of at least 3 characters and search both
       supplier tables with ``UPPER(name) LIKE %keyword%`` (max 20 rows each).
    3. One candidate -> return it.
    4. Several candidates and a usable CUI (>= 6 digits) -> return the
       candidate whose fiscal code shares the most same-position digits,
       provided at least 70% of the extracted digits match.
    5. Otherwise return the candidate whose normalized name contains at
       least 2 of the extracted keywords.
    6. Otherwise fall back to the first candidate found.

    Args:
        vendor_name: Extracted vendor name from OCR
        cui: Extracted CUI from OCR (may be invalid/incomplete)
        db_session: Session object whose ``execute(stmt)`` is awaited
            (documented upstream as a SQLAlchemy async session)

    Returns:
        Tuple of (matched_cui, supplier_name) if found, else None.
        None is returned without touching the database when the vendor
        name is missing, shorter than 3 characters, or yields no keyword.
    """
    import re

    # Cheap guards first: no database (or SQLAlchemy import) is needed to
    # reject unusable input.
    if not vendor_name or len(vendor_name) < 3:
        return None

    # Dotted legal forms are safe to remove as plain substrings because the
    # dots make an accidental match inside a word practically impossible.
    dotted_legal_forms = ('S.R.L.', 'S.A.', 'S.C.', 'I.F.', 'P.F.A.')
    # Undotted legal forms must only be dropped as WHOLE tokens. Removing
    # them as substrings would mangle ordinary words that merely contain
    # the letters (e.g. "TRANSAVIA" -> "TRANVIA", "CASA" -> "CA").
    legal_form_tokens = frozenset({'SRL', 'SA', 'SC', 'IF', 'PFA'})

    def normalize_name(name: str) -> str:
        """Normalize name for fuzzy matching."""
        name = name.upper()
        for dotted in dotted_legal_forms:
            name = name.replace(dotted, ' ')
        # Replace punctuation with spaces so tokens separate cleanly
        name = re.sub(r'[.,\-_/\\()"\']', ' ', name)
        return ' '.join(
            token for token in name.split() if token not in legal_form_tokens
        )

    # Extract key words from vendor name (for fuzzy search)
    normalized_name = normalize_name(vendor_name)
    name_words = [w for w in normalized_name.split() if len(w) >= 3]
    if not name_words:
        return None

    # Imported lazily, only once a query is actually going to be issued.
    from sqlalchemy import select, func
    from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier

    print(f"[Fuzzy Name] Searching for vendor: '{vendor_name}' -> keywords: {name_words}", flush=True)

    # Build search pattern - use first significant word
    primary_word = name_words[0]
    search_pattern = f"%{primary_word}%"

    # Query synced suppliers first, then local ones; the candidate order
    # matters because the final fallback returns the first candidate.
    candidates = []
    for supplier_model in (SyncedSupplier, LocalSupplier):
        stmt = select(supplier_model.fiscal_code, supplier_model.name).where(
            func.upper(supplier_model.name).like(search_pattern)
        ).limit(20)
        result = await db_session.execute(stmt)
        # Rows without a fiscal code cannot be returned as a CUI match
        candidates.extend(
            (row.fiscal_code, row.name) for row in result if row.fiscal_code
        )

    if not candidates:
        print(f"[Fuzzy Name] No name matches found for '{primary_word}'", flush=True)
        return None

    print(f"[Fuzzy Name] Found {len(candidates)} name matches for '{primary_word}'", flush=True)

    # If only one candidate, return it
    if len(candidates) == 1:
        print(f"[Fuzzy Name] Single match: {candidates[0][0]} ({candidates[0][1]})", flush=True)
        return candidates[0]

    # Multiple candidates - try to narrow down by CUI
    if cui:
        # Dropping 'R0' also discards the zero of an OCR-misread "RO" prefix
        cui_digits = ''.join(
            c for c in cui.upper().replace('RO', '').replace('R0', '') if c.isdigit()
        )

        if len(cui_digits) >= 6:
            def cui_similarity(candidate_cui: str) -> int:
                """Calculate how many digits match in the same position."""
                cand_digits = ''.join(
                    c for c in candidate_cui.upper().replace('RO', '') if c.isdigit()
                )
                # Different lengths cannot be compared position by position
                if len(cand_digits) != len(cui_digits):
                    return 0
                return sum(1 for a, b in zip(cand_digits, cui_digits) if a == b)

            # max() keeps the earliest candidate on ties, exactly like a
            # stable descending sort followed by taking the first element.
            best_match = max(candidates, key=lambda c: cui_similarity(c[0]))
            best_score = cui_similarity(best_match[0])

            # Require at least 70% digit match for CUI similarity
            min_matching = int(len(cui_digits) * 0.7)
            if best_score >= min_matching:
                print(f"[Fuzzy Name] Best CUI match: {best_match[0]} ({best_match[1]}) - score {best_score}/{len(cui_digits)}", flush=True)
                return best_match

            print(f"[Fuzzy Name] No strong CUI match (best score: {best_score}/{len(cui_digits)})", flush=True)

    # If still multiple and no CUI match, try name similarity
    def name_similarity(candidate_name: str) -> int:
        """Count how many keywords match."""
        norm_cand = normalize_name(candidate_name)
        return sum(1 for w in name_words if w in norm_cand)

    best_match = max(candidates, key=lambda c: name_similarity(c[1]))
    if name_similarity(best_match[1]) >= 2:  # At least 2 keywords match
        print(f"[Fuzzy Name] Best name match: {best_match[0]} ({best_match[1]})", flush=True)
        return best_match

    # Return first candidate if nothing else works
    print(f"[Fuzzy Name] Returning first candidate: {candidates[0][0]} ({candidates[0][1]})", flush=True)
    return candidates[0]


@staticmethod
async def fuzzy_match_supplier(
    cui: Optional[str],
    vendor_name: Optional[str],
    db_session
) -> Optional[tuple[str, str]]:
    """Combined fuzzy matching: CUI lookup first, NAME + CUI as fallback.

    Strategy:
    1. Delegate to ``OCRValidationEngine.fuzzy_match_cui_from_db`` with the
       extracted CUI; a truthy result is returned immediately.
    2. Otherwise delegate to
       ``OCRValidationEngine.fuzzy_match_by_name_and_cui`` with the vendor
       name and the same CUI.

    Args:
        cui: Extracted CUI from OCR (may be invalid/incomplete)
        vendor_name: Extracted vendor name from OCR
        db_session: Session object forwarded unchanged to both matchers

    Returns:
        Tuple of (matched_cui, supplier_name) if either matcher succeeds,
        else None
    """
    # Step 1: the CUI-based match takes precedence over any name match
    by_cui = await OCRValidationEngine.fuzzy_match_cui_from_db(cui, db_session)
    if by_cui:
        return by_cui

    # Step 2: name-based search, narrowed by CUI similarity
    by_name = await OCRValidationEngine.fuzzy_match_by_name_and_cui(
        vendor_name, cui, db_session
    )
    return by_name or None