- Add PWA manifest, icons (192x192, 512x512), and service worker - Register service worker in index.html with Apple mobile web app support - Consolidate CSS variables and design tokens documentation - Update PrimeVue overrides for consistent theming - Refactor data-entry components to use shared CSS patterns - Add frontend-style-auditor agent for style consistency checks - Minor OCR validation and job worker improvements - Update start-prod.sh configuration 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1395 lines
50 KiB
Python
1395 lines
50 KiB
Python
"""
|
|
OCR Data Validation Module
|
|
|
|
Provides multi-layer validation for OCR extraction results to prevent
|
|
incorrect data from entering the system.
|
|
|
|
Validation Layers:
|
|
1. Absolute sanity checks (value ranges)
|
|
2. Cross-field validation (correlation between fields)
|
|
3. Inter-OCR consistency (compare multiple OCR results)
|
|
4. Auto-correction (fix obvious errors)
|
|
|
|
Usage:
|
|
engine = OCRValidationEngine()
|
|
validated_result = engine.validate_extraction(
|
|
merged_result,
|
|
light_ocr_result,
|
|
medium_ocr_result
|
|
)
|
|
"""
|
|
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Optional
|
|
|
|
|
|
@dataclass
class ValidationResult:
    """Outcome produced by running one validation rule.

    Attributes:
        is_valid: True when the rule found nothing wrong.
        confidence_penalty: Amount to subtract from the confidence score,
            between 0.0 (no penalty) and 1.0 (complete rejection).
        message: Human-readable explanation of the outcome.
        severity: One of "info", "warning" or "error".

    Raises:
        ValueError: On construction, if ``confidence_penalty`` is outside
            the closed interval [0.0, 1.0].
    """
    is_valid: bool
    confidence_penalty: float = 0.0
    message: str = ""
    severity: str = "info"  # allowed values: "info", "warning", "error"

    def __post_init__(self):
        """Reject penalties that fall outside the 0.0-1.0 range."""
        penalty_in_range = 0.0 <= self.confidence_penalty <= 1.0
        if penalty_in_range:
            return
        raise ValueError(f"Confidence penalty must be 0.0-1.0, got {self.confidence_penalty}")
|
|
|
|
|
|
class ValidationRule(ABC):
    """Interface that every OCR validation rule implements.

    A concrete rule performs one specific check on the extraction data and
    reports the outcome as a ValidationResult, optionally carrying a
    confidence penalty.
    """

    @property
    @abstractmethod
    def rule_name(self) -> str:
        """Human-readable name of this validation rule."""

    @abstractmethod
    def validate(self, data: dict[str, Any]) -> ValidationResult:
        """Run the rule against the extraction data.

        Args:
            data: Dictionary of extraction fields to check.
                Example: {"amount": 85.99, "tva": 14.92, ...}

        Returns:
            ValidationResult with the is_valid flag and an optional penalty.
        """
|
|
|
|
|
|
# ============================================================================
|
|
# VALIDATION RULES
|
|
# ============================================================================
|
|
|
|
|
|
class AmountRangeRule(ValidationRule):
    """Validate amount is within reasonable bounds for Romanian receipts.

    Romanian receipts rarely exceed 100,000 RON. This catches obvious
    OCR errors like digit concatenation (85.99 → 859,762.16).

    The amount is coerced with ``float()`` before comparison, so numeric
    strings such as "85.99" are accepted. Values that cannot be converted,
    and NaN, are reported as invalid instead of raising or silently passing.

    Example:
        rule = AmountRangeRule(min_amount=0.01, max_amount=100_000.0)
        result = rule.validate({"amount": 859762.16})
        # result.is_valid = False, penalty = 0.5
    """

    def __init__(self, min_amount: float = 0.01, max_amount: float = 100_000.0):
        """
        Args:
            min_amount: Smallest accepted amount (inclusive).
            max_amount: Largest accepted amount (inclusive).
        """
        self.min_amount = min_amount
        self.max_amount = max_amount

    @property
    def rule_name(self) -> str:
        """Human-readable name of this validation rule."""
        return "Amount Range Check"

    def validate(self, data: dict[str, Any]) -> ValidationResult:
        """Check ``data["amount"]`` against the configured bounds.

        Args:
            data: Extraction fields; only the "amount" key is read.

        Returns:
            A valid result when the amount is missing (None) or inside the
            bounds; otherwise an "error" result with a 0.5 penalty.
        """
        amount = data.get("amount")

        if amount is None:
            return ValidationResult(
                is_valid=True,
                message="No amount to validate"
            )

        # OCR output is untrusted: the field may hold a string or another
        # non-numeric object, which would otherwise raise on comparison.
        try:
            value = float(amount)
        except (TypeError, ValueError, OverflowError):
            value = None

        # NaN is the only float that is not equal to itself; it would pass
        # both range comparisons below and be reported as valid.
        if value is None or value != value:
            return ValidationResult(
                is_valid=False,
                confidence_penalty=0.5,
                message=f"Amount {amount!r} is not a valid number",
                severity="error"
            )

        if value < self.min_amount:
            return ValidationResult(
                is_valid=False,
                confidence_penalty=0.5,
                message=f"Amount {value:.2f} RON below minimum {self.min_amount:.2f} RON",
                severity="error"
            )

        if value > self.max_amount:
            return ValidationResult(
                is_valid=False,
                confidence_penalty=0.5,
                message=f"Amount {value:.2f} RON exceeds maximum {self.max_amount:.2f} RON (likely OCR error)",
                severity="error"
            )

        return ValidationResult(
            is_valid=True,
            message=f"Amount {value:.2f} RON within valid range"
        )
|
|
|
|
|
|
class TVARatioRule(ValidationRule):
    """Check that TVA is a plausible fraction of the TOTAL amount.

    Romanian TVA rates: 5%, 9%, 19%, 21% (most common: 19-21%).
    The ratio checked here is TVA divided by the extracted total, so a TVA
    value larger than the total (impossible) is always rejected.

    Example:
        rule = TVARatioRule(min_ratio=0.05, max_ratio=0.24)
        result = rule.validate({"amount": 85.99, "tva": 149.21})
        # result.is_valid = False (149.21 > 85.99!)
    """

    def __init__(self, min_ratio: float = 0.05, max_ratio: float = 0.24):
        """
        Args:
            min_ratio: Lowest accepted TVA/total ratio (inclusive).
            max_ratio: Highest accepted TVA/total ratio (inclusive).
        """
        self.min_ratio = min_ratio
        self.max_ratio = max_ratio

    @property
    def rule_name(self) -> str:
        """Human-readable name of this validation rule."""
        return "TVA Ratio Check"

    def validate(self, data: dict[str, Any]) -> ValidationResult:
        """Compare ``data["tva"] / data["amount"]`` with the allowed band.

        The check is skipped (reported valid) when either field is missing
        or falsy, when either is not an int/float, or when the amount is
        not positive.
        """
        amount = data.get("amount")
        tva = data.get("tva")

        if not (amount and tva):
            return ValidationResult(is_valid=True, message="Insufficient data for TVA correlation")

        # Only plain numbers can be divided safely.
        numeric_types = (int, float)
        both_numeric = isinstance(amount, numeric_types) and isinstance(tva, numeric_types)
        if not both_numeric:
            return ValidationResult(is_valid=True, message="Non-numeric values, skipping TVA correlation")

        # A non-positive total makes the ratio meaningless (and guards the division).
        if amount <= 0:
            return ValidationResult(is_valid=True, message="Amount is zero or negative, skipping TVA ratio")

        tva_ratio = tva / amount
        too_low = tva_ratio < self.min_ratio
        too_high = tva_ratio > self.max_ratio

        if too_low or too_high:
            return ValidationResult(
                is_valid=False,
                confidence_penalty=0.3,
                message=f"TVA ratio {tva_ratio:.1%} outside valid range ({self.min_ratio:.0%}-{self.max_ratio:.0%})",
                severity="warning",
            )

        return ValidationResult(is_valid=True, message=f"TVA ratio {tva_ratio:.1%} valid")
|
|
|
|
|
|
class PaymentSumRule(ValidationRule):
    """Validate CARD + NUMERAR = TOTAL BON (within tolerance).

    This is a CRITICAL validation that catches cases where OCR extracts
    wrong TOTAL but correct payment methods.

    The tolerance is inclusive. A tiny slack is added to the comparison
    because binary floating point cannot represent most two-decimal amounts
    exactly: ``abs(1.02 - 1.00)`` evaluates to slightly more than ``0.02``
    and would otherwise be rejected at the default tolerance.

    Example:
        rule = PaymentSumRule(tolerance=0.02)
        result = rule.validate({
            "amount": 859762.16,  # Wrong from OCR
            "card_amount": 85.99,  # Correct
            "cash_amount": 0.0
        })
        # result.is_valid = False, suggests auto-correction
    """

    # Absorbs floating-point representation error at the tolerance boundary;
    # far smaller than one ban (0.01 RON), so it never hides a real mismatch.
    _ROUNDING_SLACK = 1e-9

    def __init__(self, tolerance: float = 0.02):
        """
        Args:
            tolerance: Largest accepted absolute difference, in RON,
                between the total and the sum of the payments.
        """
        self.tolerance = tolerance

    @property
    def rule_name(self) -> str:
        """Human-readable name of this validation rule."""
        return "Payment Sum Check"

    def validate(self, data: dict[str, Any]) -> ValidationResult:
        """Compare ``amount`` with ``card_amount + cash_amount``.

        Args:
            data: Extraction fields; reads "amount", "card_amount" and
                "cash_amount". Missing or None payment fields count as 0.

        Returns:
            A valid result when there is no total, no payment was
            extracted, the values cannot be combined arithmetically, or
            the difference is within tolerance; otherwise an "error"
            result with a 0.4 penalty.
        """
        total = data.get("amount")
        card = data.get("card_amount", 0.0) or 0.0
        cash = data.get("cash_amount", 0.0) or 0.0

        if not total:
            return ValidationResult(
                is_valid=True,
                message="No total amount to validate"
            )

        # OCR output is untrusted; a string in any of the three fields
        # would raise TypeError during the arithmetic below.
        try:
            payment_sum = card + cash

            if payment_sum == 0:
                return ValidationResult(
                    is_valid=True,
                    message="No payment methods extracted"
                )

            diff = abs(total - payment_sum)
        except TypeError:
            return ValidationResult(
                is_valid=True,
                message="Non-numeric values, skipping payment sum check"
            )

        if diff > self.tolerance + self._ROUNDING_SLACK:
            return ValidationResult(
                is_valid=False,
                confidence_penalty=0.4,
                message=f"Payment sum {payment_sum:.2f} RON != Total {total:.2f} RON (diff: {diff:.2f} RON). Consider auto-correction.",
                severity="error"
            )

        return ValidationResult(
            is_valid=True,
            message=f"Payment sum matches total (diff: {diff:.2f} RON)"
        )
|
|
|
|
|
|
class TVAEntriesSumRule(ValidationRule):
    """Validate Σ(TVA entries) = TVA TOTAL (within tolerance).

    TVA breakdown (A, B, C, D rates) should sum to total TVA.

    ``tva_entries`` may be a mapping of rate label to amount, or a plain
    sequence of amounts. Any other shape (for example a list of dicts,
    which TVABasedTotalRule reads from the same key) cannot be summed and
    makes the rule skip the check rather than raise.

    Example:
        rule = TVAEntriesSumRule(tolerance=0.02)
        result = rule.validate({
            "tva": 14.92,
            "tva_entries": {"A": 14.92, "B": 0.0}
        })
        # result.is_valid = True
    """

    # Absorbs floating-point representation error at the tolerance boundary;
    # far smaller than one ban (0.01 RON), so it never hides a real mismatch.
    _ROUNDING_SLACK = 1e-9

    def __init__(self, tolerance: float = 0.02):
        """
        Args:
            tolerance: Largest accepted absolute difference, in RON,
                between the TVA total and the sum of the entries.
        """
        self.tolerance = tolerance

    @property
    def rule_name(self) -> str:
        """Human-readable name of this validation rule."""
        return "TVA Entries Sum Check"

    def validate(self, data: dict[str, Any]) -> ValidationResult:
        """Compare ``data["tva"]`` with the sum of ``data["tva_entries"]``.

        Returns:
            A valid result when the total or the entries are missing, the
            entries cannot be summed, the entries sum to zero, or the
            difference is within tolerance; otherwise a "warning" result
            with a 0.2 penalty.
        """
        tva_total = data.get("tva")
        tva_entries = data.get("tva_entries", {})

        if not tva_total:
            return ValidationResult(
                is_valid=True,
                message="No TVA total to validate"
            )

        if not tva_entries:
            return ValidationResult(
                is_valid=True,
                message="No TVA entries extracted"
            )

        # Mapping-like input: sum its values. Anything else is summed as-is.
        values_method = getattr(tva_entries, "values", None)
        amounts = values_method() if callable(values_method) else tva_entries

        # sum() and the subtraction raise TypeError for non-iterable input
        # and for entries that are not numbers (dicts, strings, None).
        try:
            entries_sum = sum(amounts)

            if entries_sum == 0:
                return ValidationResult(
                    is_valid=True,
                    message="TVA entries sum is zero"
                )

            diff = abs(tva_total - entries_sum)
        except TypeError:
            return ValidationResult(
                is_valid=True,
                message="Unsupported TVA entries format, skipping sum check"
            )

        if diff > self.tolerance + self._ROUNDING_SLACK:
            return ValidationResult(
                is_valid=False,
                confidence_penalty=0.2,
                message=f"TVA entries sum {entries_sum:.2f} RON != TVA total {tva_total:.2f} RON (diff: {diff:.2f} RON)",
                severity="warning"
            )

        return ValidationResult(
            is_valid=True,
            message=f"TVA entries sum matches total (diff: {diff:.2f} RON)"
        )
|
|
|
|
|
|
class CUIFormatRule(ValidationRule):
    """Validate CUI format: RO + 6-10 digits.

    Romanian CUI (Cod Unic de Identificare) format:
    - Optional "RO" prefix (or "R0" from OCR errors)
    - 6-10 numeric digits

    Only ASCII digits 0-9 count as numeric. Non-string values (for example
    an int produced by a JSON parser) are converted with ``str()`` first.

    Example:
        rule = CUIFormatRule()
        result = rule.validate({"cui": "RO10562600"})
        # result.is_valid = True
    """

    @property
    def rule_name(self) -> str:
        """Human-readable name of this validation rule."""
        return "CUI Format Check"

    def validate(self, data: dict[str, Any]) -> ValidationResult:
        """Check that ``data["cui"]`` is an optional RO/R0 prefix plus 6-10 digits.

        Returns:
            A valid result when the CUI is missing or well-formed;
            otherwise a "warning" result with a 0.3 penalty.
        """
        cui = data.get("cui")

        if not cui:
            return ValidationResult(
                is_valid=True,
                message="No CUI to validate"
            )

        # Normalize: str() guards against non-string input, then drop the
        # RO prefix (or its common OCR misreading R0).
        cui_clean = str(cui).strip().upper()
        if cui_clean.startswith(("RO", "R0")):
            cui_clean = cui_clean[2:]

        # str.isdigit() alone also accepts characters such as "²", so
        # require ASCII as well.
        if not (cui_clean.isascii() and cui_clean.isdigit()):
            return ValidationResult(
                is_valid=False,
                confidence_penalty=0.3,
                message=f"CUI '{cui}' contains non-numeric characters",
                severity="warning"
            )

        # Check length
        if len(cui_clean) < 6 or len(cui_clean) > 10:
            return ValidationResult(
                is_valid=False,
                confidence_penalty=0.3,
                message=f"CUI '{cui}' length {len(cui_clean)} outside valid range (6-10 digits)",
                severity="warning"
            )

        return ValidationResult(
            is_valid=True,
            message=f"CUI '{cui}' format valid"
        )
|
|
|
|
|
|
class CUIChecksumRule(ValidationRule):
    """Validate Romanian CIF/CUI using Mod 11 checksum algorithm.

    Algorithm:
    1. Remove RO prefix if present
    2. Extract last digit as declared checksum
    3. Apply multipliers [7,5,3,2,1,7,5,3,2] to first N-1 digits
       (left-padded with zeros to 9 digits)
    4. Calculate: (sum * 10) mod 11
    5. If result = 10, expected checksum = 0
    6. Else, expected checksum = result
    7. Compare with declared checksum

    Only ASCII digits 0-9 are treated as digits. ``str.isdigit()`` alone
    also accepts characters such as "²", which ``int()`` cannot convert.

    Example:
        rule = CUIChecksumRule()
        result = rule.validate({"cui": "RO10562600"})
        # result.is_valid = True (checksum correct)

        result = rule.validate({"cui": "R01879855"})
        # result.is_valid = False (checksum mismatch)

    Static methods available for direct use:
        CUIChecksumRule.calculate_checksum("1056260") -> 0
        CUIChecksumRule.validate_checksum("10562600") -> True
        CUIChecksumRule.has_ro_prefix("RO10562600") -> True
    """

    # Fixed multipliers for 9 positions (Romanian Mod 11)
    MULTIPLIERS = [7, 5, 3, 2, 1, 7, 5, 3, 2]

    @staticmethod
    def _is_ascii_digits(text: Any) -> bool:
        """Return True when ``text`` is a non-empty string of ASCII digits."""
        return isinstance(text, str) and text.isascii() and text.isdigit()

    @staticmethod
    def calculate_checksum(cui_base: str) -> int:
        """Calculate expected CUI checksum using Romanian Mod 11 algorithm.

        Args:
            cui_base: CUI digits WITHOUT the checksum digit (last digit)

        Returns:
            Expected checksum digit (0-9), or -1 if the input is empty,
            contains anything but ASCII digits, or is longer than the
            multiplier table (9 digits).
        """
        if not CUIChecksumRule._is_ascii_digits(cui_base):
            return -1

        width = len(CUIChecksumRule.MULTIPLIERS)
        # zip() would silently drop the extra digits of an over-long base
        # and produce a checksum for a different number.
        if len(cui_base) > width:
            return -1

        # Pad base to 9 digits from LEFT so the digits align with the weights
        base_digits = [int(d) for d in cui_base.zfill(width)]

        weighted_sum = sum(d * m for d, m in zip(base_digits, CUIChecksumRule.MULTIPLIERS))

        checksum = (weighted_sum * 10) % 11
        # A remainder of 10 has no single-digit form; the algorithm maps it to 0
        if checksum == 10:
            checksum = 0

        return checksum

    @staticmethod
    def validate_checksum(cui_digits: str) -> bool:
        """Check if CUI checksum is valid.

        Args:
            cui_digits: Full CUI digits (including checksum as last digit)

        Returns:
            True if checksum is valid, False otherwise (including input
            that is shorter than 6 digits, longer than 10 digits, or not
            made of ASCII digits).
        """
        if not CUIChecksumRule._is_ascii_digits(cui_digits) or len(cui_digits) < 6:
            return False

        declared = int(cui_digits[-1])
        # calculate_checksum returns -1 for an over-long base, which never
        # equals a declared digit.
        expected = CUIChecksumRule.calculate_checksum(cui_digits[:-1])

        return expected == declared

    @staticmethod
    def has_ro_prefix(cui: str) -> bool:
        """Check if CUI has RO prefix (proper format for VAT payers)."""
        if not cui:
            return False
        return str(cui).upper().strip().startswith('RO')

    @staticmethod
    def extract_digits(cui: str) -> str:
        """Extract digits from CUI, removing RO/R0 prefix.

        Every character that is not an ASCII digit is dropped. Returns an
        empty string for empty or None input.
        """
        if not cui:
            return ""
        text = str(cui).strip().upper()
        # "R0" is the common OCR misreading of "RO"
        if text.startswith(("RO", "R0")):
            text = text[2:]
        return ''.join(c for c in text if c in "0123456789")

    @property
    def rule_name(self) -> str:
        """Human-readable name of this validation rule."""
        return "CUI Checksum Check (Mod 11)"

    def validate(self, data: dict[str, Any]) -> ValidationResult:
        """Verify the Mod 11 check digit of ``data["cui"]``.

        Returns:
            A valid result when the CUI is missing, has no digits, or has
            a digit count outside 6-10 (format problems are reported by
            CUIFormatRule) and when the checksum matches; otherwise a
            "warning" result with a 0.3 penalty.
        """
        cui = data.get("cui")

        if not cui:
            return ValidationResult(
                is_valid=True,
                message="No CUI to validate"
            )

        cui_clean = CUIChecksumRule.extract_digits(cui)

        # Check format first
        if not cui_clean:
            return ValidationResult(
                is_valid=True,  # Don't fail checksum if format invalid (handled by CUIFormatRule)
                message="CUI format invalid, skipping checksum"
            )

        if len(cui_clean) < 6 or len(cui_clean) > 10:
            return ValidationResult(
                is_valid=True,
                message="CUI length invalid, skipping checksum"
            )

        if not CUIChecksumRule.validate_checksum(cui_clean):
            # Calculate expected for error message
            expected = CUIChecksumRule.calculate_checksum(cui_clean[:-1])
            declared = int(cui_clean[-1])
            return ValidationResult(
                is_valid=False,
                confidence_penalty=0.3,
                message=f"CUI '{cui}' checksum mismatch: expected {expected}, got {declared}",
                severity="warning"
            )

        return ValidationResult(
            is_valid=True,
            message=f"CUI '{cui}' checksum valid"
        )
|
|
|
|
|
|
class TVABasedTotalRule(ValidationRule):
    """Cross-check TOTAL by recomputing it from the TVA amount.

    This is a CRITICAL validation for the case where OCR reads the TOTAL
    wrong but the TVA right. Because TVA = BASE * rate and
    TOTAL = BASE + TVA, the expected TOTAL follows from TVA alone:

        Expected TOTAL = TVA / rate * (1 + rate)
                       = TVA * (1 + rate) / rate

    For TVA rate 21%:
        Expected TOTAL = TVA / 0.21 * 1.21 = TVA * 5.7619

    Example (benzina 27 oct):
        TVA = 49.58, rate = 21%
        Expected TOTAL = 49.58 / 0.21 * 1.21 = 285.68
        Extracted TOTAL = 205.66 (WRONG!)
        Rule detects mismatch and flags for escalation

    Rate selection, in order:
    1. The first ``tva_entries`` item that is a dict with a parseable,
       truthy 'percent' value.
    2. TVA / (TOTAL - TVA), if that falls between 4% and 25%.
    3. A fixed 21%.

    When the rate comes from step 2 it is derived from the same TOTAL
    being checked, so the recomputed total equals the extracted one (up to
    float rounding) and the check passes by construction.

    Usage in multi-tier processing (e.g., doctr_plus):
        If this rule fails, the engine should proceed to next tier
        instead of returning early with potentially wrong data.
    """

    def __init__(self, tolerance_percent: float = 0.02):
        """
        Args:
            tolerance_percent: Allowed difference as percentage (0.02 = 2%)
        """
        self.tolerance_percent = tolerance_percent

    @property
    def rule_name(self) -> str:
        """Human-readable name of this validation rule."""
        return "TVA-Based Total Check"

    @staticmethod
    def _rate_from_entries(entries):
        """Return the first parseable entry rate as a fraction, else None."""
        for item in entries or ():
            if not isinstance(item, dict):
                continue
            raw_percent = item.get('percent')
            if not raw_percent:
                continue
            try:
                parsed = float(raw_percent)
            except (TypeError, ValueError):
                # Unparseable percent: keep looking at later entries.
                continue
            return parsed / 100.0
        return None

    def validate(self, data: dict[str, Any]) -> ValidationResult:
        """Compare the extracted total with the total implied by TVA.

        The check is skipped (reported valid) when "amount" or "tva" is
        missing or falsy, cannot be converted with ``float()``, or is not
        positive.
        """
        raw_total = data.get("amount")
        raw_tva = data.get("tva")
        entries = data.get("tva_entries", [])

        if not (raw_total and raw_tva):
            return ValidationResult(
                is_valid=True,
                message="Insufficient data for TVA-based total validation",
            )

        # Coerce both values up front so the arithmetic below is float-only.
        try:
            total = float(raw_total)
            tva = float(raw_tva)
        except (TypeError, ValueError):
            return ValidationResult(
                is_valid=True,
                message="Non-numeric values, skipping TVA-based total validation",
            )

        if not (tva > 0 and total > 0):
            return ValidationResult(
                is_valid=True,
                message="Zero or negative values, skipping TVA-based total validation",
            )

        # 1) Rate stated by the receipt's own TVA breakdown.
        tva_rate = self._rate_from_entries(entries)

        # 2) Rate implied by the totals: rate = TVA / BASE, BASE = TOTAL - TVA.
        if not tva_rate:
            base = total - tva
            if base > 0:
                implied_rate = tva / base
                # Accept only something near a Romanian TVA rate (5%, 9%, 19%, 21%).
                if 0.04 <= implied_rate <= 0.25:
                    tva_rate = implied_rate

        # 3) Most common rate as a last resort.
        if not tva_rate:
            tva_rate = 0.21

        # TOTAL = BASE + TVA = TVA / rate + TVA = TVA * (1 + rate) / rate
        expected_total = tva * (1 + tva_rate) / tva_rate

        diff = abs(total - expected_total)
        if expected_total > 0:
            diff_percent = diff / expected_total
        else:
            diff_percent = 1.0

        if diff_percent <= self.tolerance_percent:
            return ValidationResult(
                is_valid=True,
                message=f"TOTAL {total:.2f} matches TVA-calculated {expected_total:.2f} (diff: {diff_percent:.1%})",
            )

        # Significant mismatch: OCR most likely misread the TOTAL.
        return ValidationResult(
            is_valid=False,
            confidence_penalty=0.5,  # high penalty: this is a critical error
            message=(
                f"TOTAL mismatch: Extracted {total:.2f} RON vs "
                f"TVA-calculated {expected_total:.2f} RON "
                f"(TVA={tva:.2f}, rate={tva_rate:.0%}, diff={diff_percent:.1%}). "
                f"Likely OCR error on TOTAL."
            ),
            severity="error",
        )
|
|
|
|
|
|
class InterOCRConsistencyRule(ValidationRule):
    """Validate consistency between multiple OCR results.

    If Light OCR and Medium OCR produce values that differ by >10x,
    one is clearly wrong (likely digit concatenation error).

    The ratio is computed on magnitudes, so a sign flip between the two
    readings cannot hide a large discrepancy.

    Example:
        rule = InterOCRConsistencyRule(max_ratio=10.0)
        result = rule.validate({
            "light_value": 85.99,
            "medium_value": 859762.16,
            "field_name": "amount"
        })
        # result.is_valid = False (ratio = 10,000x!)
    """

    def __init__(self, max_ratio: float = 10.0):
        """
        Args:
            max_ratio: Largest accepted larger/smaller ratio between the
                two OCR readings.
        """
        self.max_ratio = max_ratio

    @property
    def rule_name(self) -> str:
        """Human-readable name of this validation rule."""
        return "Inter-OCR Consistency Check"

    def validate(self, data: dict[str, Any]) -> ValidationResult:
        """Compare two OCR readings of the same field.

        Args:
            data: Dictionary with "light_value" and "medium_value" (the
                two readings) and an optional "field_name" used only in
                messages (defaults to "value").

        Returns:
            A valid result when either reading is missing or zero, when
            the readings are not numeric, or when their ratio is within
            ``max_ratio``; otherwise a "warning" result with a 0.2 penalty.
        """
        light_value = data.get("light_value")
        medium_value = data.get("medium_value")
        field_name = data.get("field_name", "value")

        # Falsy covers None, missing and 0, so no separate zero test is
        # needed before dividing.
        if not light_value or not medium_value:
            return ValidationResult(
                is_valid=True,
                message="Insufficient OCR results for consistency check"
            )

        try:
            smaller, larger = sorted((abs(light_value), abs(medium_value)))
            ratio = larger / smaller
        except TypeError:
            # OCR output is untrusted; strings and other objects land here.
            return ValidationResult(
                is_valid=True,
                message="Non-numeric values, skipping consistency check"
            )
        except ZeroDivisionError:
            return ValidationResult(
                is_valid=True,
                message="One value is zero, skipping consistency check"
            )

        if ratio > self.max_ratio:
            return ValidationResult(
                is_valid=False,
                confidence_penalty=0.2,
                message=f"{field_name}: OCR results differ by {ratio:.1f}x (Light: {light_value}, Medium: {medium_value})",
                severity="warning"
            )

        return ValidationResult(
            is_valid=True,
            message=f"{field_name}: OCR results consistent (ratio: {ratio:.2f}x)"
        )
|
|
|
|
|
|
# ============================================================================
|
|
# VALIDATION ENGINE
|
|
# ============================================================================
|
|
|
|
|
|
@dataclass
class EnhancedExtractionResult:
    """Extraction data bundled with the outcome of validation.

    The original extraction dictionary is kept as-is in ``data``; the
    remaining fields describe what validation found. Every collection
    field defaults to a fresh empty container per instance.
    """
    # The extraction fields exactly as they were handed to the validator.
    data: dict[str, Any]

    # True when a human should look at this result before it is trusted.
    needs_manual_review: bool = False
    # Messages from failed rules, split by severity.
    validation_warnings: list[str] = field(default_factory=list)
    validation_errors: list[str] = field(default_factory=list)
    # Confidence penalty per field name.
    confidence_adjustments: dict[str, float] = field(default_factory=dict)

    # Ratio between OCR readings, per field name.
    inter_ocr_ratios: dict[str, float] = field(default_factory=dict)
|
|
|
|
|
|
class OCRValidationEngine:
|
|
"""Orchestrate all validation rules for OCR extraction results.
|
|
|
|
This engine applies validation rules in order:
|
|
1. Sanity checks (amount range, format checks)
|
|
2. Cross-field correlation (TVA ratio, payment sum)
|
|
3. Inter-OCR consistency checks
|
|
|
|
Example:
|
|
engine = OCRValidationEngine()
|
|
result = engine.validate_extraction(
|
|
extraction_result=merged_data,
|
|
light_result=light_ocr_data,
|
|
medium_result=medium_ocr_data
|
|
)
|
|
"""
|
|
|
|
def __init__(self):
    """Build the three default rule groups used by validate_extraction."""
    # Group 1: absolute checks on individual values.
    amount_bounds = AmountRangeRule(min_amount=0.01, max_amount=100_000.0)
    self.sanity_rules = [amount_bounds, CUIFormatRule(), CUIChecksumRule()]

    # Group 2: checks that relate several fields to each other.
    tva_ratio = TVARatioRule(min_ratio=0.05, max_ratio=0.24)
    payment_sum = PaymentSumRule(tolerance=0.02)
    tva_entries_sum = TVAEntriesSumRule(tolerance=0.02)
    # Critical: detects TOTAL errors by recomputing the total from TVA.
    tva_based_total = TVABasedTotalRule(tolerance_percent=0.02)
    self.cross_field_rules = [tva_ratio, payment_sum, tva_entries_sum, tva_based_total]

    # Group 3: checks that compare the results of different OCR passes.
    self.inter_ocr_rules = [InterOCRConsistencyRule(max_ratio=10.0)]
|
|
|
|
def validate_extraction(
    self,
    extraction_result: dict[str, Any],
    light_result: Optional[dict[str, Any]] = None,
    medium_result: Optional[dict[str, Any]] = None
) -> EnhancedExtractionResult:
    """Run all validation rules and return enhanced result.

    Steps:
    1. Every rule in ``self.sanity_rules`` on ``extraction_result``.
    2. Every rule in ``self.cross_field_rules`` on ``extraction_result``.
    3. If both ``light_result`` and ``medium_result`` are given and both
       contain "amount", every rule in ``self.inter_ocr_rules`` on the
       pair of amounts.

    Failed rules are recorded as errors (severity "error") or warnings
    (anything else). When several failed rules penalize the same field,
    the largest penalty is kept, so a later, milder rule cannot hide an
    earlier, harsher one. Progress is printed to stdout.

    Args:
        extraction_result: Merged OCR extraction data (required)
        light_result: Light OCR preprocessing results (optional)
        medium_result: Medium OCR preprocessing results (optional)

    Returns:
        EnhancedExtractionResult with validation warnings and metadata
    """
    warnings = []
    errors = []
    confidence_adjustments = {}
    inter_ocr_ratios = {}

    def run_rules(rules, rule_field_map, default_fields):
        """Apply rules to extraction_result and record every failure."""
        for rule in rules:
            result = rule.validate(extraction_result)

            if result.is_valid:
                print(f" [OK] {rule.rule_name}: {result.message}", flush=True)
                continue

            msg = f"[{rule.rule_name}] {result.message}"
            if result.severity == "error":
                errors.append(msg)
            else:
                warnings.append(msg)
            print(f" [X] {msg}", flush=True)

            if result.confidence_penalty > 0:
                # Rules without an explicit mapping penalize the defaults.
                for f in rule_field_map.get(rule.rule_name, default_fields):
                    if f in extraction_result:
                        confidence_adjustments[f] = max(
                            confidence_adjustments.get(f, 0.0),
                            result.confidence_penalty,
                        )

    # Step 1: Sanity checks
    print("\n[Validation] Step 1: Sanity checks...", flush=True)
    run_rules(
        self.sanity_rules,
        {
            "Amount Range Check": ["amount"],
            "CUI Format Check": ["cui"],
            "CUI Checksum Check (Mod 11)": ["cui"],
        },
        ["amount", "tva", "cui"],
    )

    # Step 2: Cross-field validation
    print("\n[Validation] Step 2: Cross-field validation...", flush=True)
    run_rules(
        self.cross_field_rules,
        {
            "TVA Ratio Check": ["tva"],
            "Payment Sum Check": ["amount"],
            "TVA Entries Sum Check": ["tva"],
        },
        ["amount", "tva"],
    )

    # Step 3: Inter-OCR consistency checks
    if light_result and medium_result:
        print("\n[Validation] Step 3: Inter-OCR consistency...", flush=True)

        # Check amount consistency
        if "amount" in light_result and "amount" in medium_result:
            light_amount = light_result["amount"]
            medium_amount = medium_result["amount"]
            consistency_data = {
                "light_value": light_amount,
                "medium_value": medium_amount,
                "field_name": "amount"
            }

            for rule in self.inter_ocr_rules:
                result = rule.validate(consistency_data)

                if result.is_valid:
                    print(f" [OK] {result.message}", flush=True)
                    continue

                msg = f"[Inter-OCR] {result.message}"
                warnings.append(msg)
                print(f" [X] {msg}", flush=True)

                # Store ratio for metadata. Guarded because a custom rule
                # could fail on values that cannot be divided.
                try:
                    smaller, larger = sorted((abs(light_amount), abs(medium_amount)))
                    inter_ocr_ratios["amount"] = larger / smaller
                except (TypeError, ZeroDivisionError):
                    pass

    # Determine if manual review is needed
    # Only flag for review if there are errors OR high-severity warnings
    high_severity_markers = ("[Amount Range", "[Payment Sum", "[Inter-OCR]")
    high_severity_warnings = [
        w for w in warnings
        if any(marker in w for marker in high_severity_markers)
    ]
    needs_manual_review = (
        bool(errors) or
        bool(high_severity_warnings) or
        any(ratio > 10.0 for ratio in inter_ocr_ratios.values())
    )

    print("\n[Validation] Summary:", flush=True)
    print(f" Errors: {len(errors)}", flush=True)
    print(f" Warnings: {len(warnings)}", flush=True)
    print(f" Manual review needed: {needs_manual_review}", flush=True)

    return EnhancedExtractionResult(
        data=extraction_result,
        needs_manual_review=needs_manual_review,
        validation_warnings=warnings,
        validation_errors=errors,
        confidence_adjustments=confidence_adjustments,
        inter_ocr_ratios=inter_ocr_ratios
    )
|
|
|
|
def quick_validate_for_hybrid(self, extraction_result: dict[str, Any]) -> tuple[bool, float, list[str]]:
    """Fast pass/fail validation for early-exit decisions (e.g., doctr_plus Tier 1).

    Runs a fixed set of rules, built fresh on each call, and reports
    whether any of them failed. Intended for deciding whether to stop at
    the current processing tier or continue to the next one.

    Args:
        extraction_result: Extraction data dict with fields:
            - amount: Extracted TOTAL
            - tva: Extracted TVA total
            - tva_entries: List of TVA entries with rates

    Returns:
        Tuple of (passes_validation, confidence_penalty, error_messages)
        - passes_validation: True if no rule failed
        - confidence_penalty: Sum of the failed rules' penalties, capped at 1.0
        - error_messages: Message of every failed rule, in rule order

    Example usage:
        passes, penalty, errors = validation_engine.quick_validate_for_hybrid(extraction_data)
        if not passes:
            print(f"Validation failed: {errors}, proceeding to next tier")
            # Continue to next processing tier instead of early exit
    """
    rules_to_run = (
        # Cross-field validations (most important for detecting OCR errors)
        TVABasedTotalRule(tolerance_percent=0.02),       # TOTAL recomputed from TVA
        PaymentSumRule(tolerance=0.05),                  # TOTAL vs CARD + CASH
        TVARatioRule(min_ratio=0.05, max_ratio=0.24),    # TVA as 5-24% of TOTAL
        TVAEntriesSumRule(tolerance=0.05),               # TVA entries vs TVA total
        # Format & checksum validations
        CUIChecksumRule(),                               # Romanian Mod 11 check digit
        CUIFormatRule(),                                 # 6-10 digits
        # Sanity checks
        AmountRangeRule(min_amount=0.01, max_amount=100_000.0),
    )

    failure_messages = []
    accumulated_penalty = 0.0

    for rule in rules_to_run:
        outcome = rule.validate(extraction_result)
        if outcome.is_valid:
            continue
        failure_messages.append(outcome.message)
        accumulated_penalty += outcome.confidence_penalty

    # The penalty is a confidence fraction, so it never exceeds 1.0.
    capped_penalty = min(1.0, accumulated_penalty)

    return not failure_messages, capped_penalty, failure_messages
|
|
|
|
# NOTE: _calculate_cui_checksum and _is_cui_checksum_valid removed
|
|
# Use CUIChecksumRule.calculate_checksum() and CUIChecksumRule.validate_checksum() instead
|
|
|
|
@staticmethod
def _repair_cui_checksum(cui_digits: str) -> Optional[str]:
    """Search for a single-digit change that makes the CUI checksum valid.

    OCR often misreads similar-looking digits:
    - 5 ↔ 8 (most common in receipts)
    - 6 ↔ 0
    - 1 ↔ 7
    - 3 ↔ 8

    Search order:
    1. Middle positions (2 up to, but not including, the last one)
    2. Position 1
    3. The checksum digit (last position)
    4. Position 0, tried last because the first digit is rarely wrong
    At each position the common OCR confusions for the current digit are
    tried before the remaining digits. The first candidate that passes
    CUIChecksumRule.validate_checksum is returned.

    NOTE(review): the checksum position itself is among the positions
    tried with every alternative digit, so some replacement there should
    always validate for a well-formed input. The final ``return None``
    then looks unreachable, and the returned CUI is not guaranteed to be
    the true one. Confirm against CUIChecksumRule.

    Args:
        cui_digits: Original CUI digits (without RO prefix)

    Returns:
        ``cui_digits`` unchanged if already valid, the repaired digits if
        a 1-digit fix is found, None if the input is shorter than 6
        characters, not all digits, or no fix is found.
    """
    length = len(cui_digits)
    if length < 6 or not cui_digits.isdigit():
        return None

    # Nothing to repair.
    if CUIChecksumRule.validate_checksum(cui_digits):
        return cui_digits

    # For each digit, the digits OCR most often confuses it with (tried first).
    likely_misreads = {
        '5': ['8', '6'],
        '8': ['5', '3', '0'],
        '6': ['0', '8'],
        '0': ['6', '8'],
        '1': ['7', '4'],
        '7': ['1'],
        '3': ['8'],
        '4': ['1'],
        '2': ['7'],
        '9': ['0'],
    }

    checksum_pos = length - 1
    # Example for an 8-digit CUI: [2, 3, 4, 5, 6, 1, 7 (checksum), 0]
    search_order = [*range(2, checksum_pos), 1, checksum_pos, 0]

    for pos in search_order:
        original_digit = cui_digits[pos]

        preferred = likely_misreads.get(original_digit, [])
        remaining = [
            d for d in '0123456789'
            if d != original_digit and d not in preferred
        ]

        prefix, suffix = cui_digits[:pos], cui_digits[pos + 1:]
        for replacement in preferred + remaining:
            candidate = prefix + replacement + suffix
            if not CUIChecksumRule.validate_checksum(candidate):
                continue
            print(f"[CUI Repair] Fixed {cui_digits} -> {candidate} (position {pos}: {original_digit}->{replacement})", flush=True)
            return candidate

    # No single-digit fix found
    return None
|
|
|
|
@staticmethod
|
|
def normalize_cui(cui: Optional[str]) -> Optional[str]:
    """Clean up a raw OCR CUI while keeping the prefix style it arrived with.

    What happens to the input:
        - It is trimmed and upper-cased.
        - A leading "RO" or "R0" (OCR reading the letter O as zero) is
          treated as the VAT-payer prefix and is emitted as "RO".
        - A value without that prefix never gains one (non-VAT payer).
        - Every non-digit character after the prefix is discarded.
        - If the checksum fails, a single-digit repair is attempted; when no
          repair is found the unrepaired digits are returned unchanged.

    Examples:
        45417955   → 45417955   (no prefix, kept as-is)
        R010562600 → RO10562600 (R0 corrected to RO)
        RO10562600 → RO10562600 (unchanged)
        RO10862600 → RO10562600 (repaired: 8→5 at position 2)

    Args:
        cui: Raw CUI string from OCR.

    Returns:
        The normalized CUI, or None when the input is empty or does not
        contain between 6 and 10 digits.
    """
    if not cui:
        return None

    cleaned = cui.strip().upper()

    # Both spellings count as "the original carried a prefix".
    prefixed = cleaned.startswith(("RO", "R0"))
    tail = cleaned[2:] if prefixed else cleaned

    # Keep digits only (drops spaces, dashes, stray letters).
    digits = ''.join(filter(str.isdigit, tail))

    digit_count = len(digits)
    if not 6 <= digit_count <= 10:
        print(f"[CUI Normalize] Invalid length: {digit_count} digits (expected 6-10)", flush=True)
        return None

    # Only attempt a repair when the checksum is actually wrong.
    if not CUIChecksumRule.validate_checksum(digits):
        fixed = OCRValidationEngine._repair_cui_checksum(digits)
        if fixed:
            digits = fixed

    return f"RO{digits}" if prefixed else digits
|
|
|
|
@staticmethod
|
|
async def fuzzy_match_cui_from_db(
    cui: Optional[str],
    db_session
) -> Optional[tuple[str, str]]:
    """Fuzzy match CUI against database of known suppliers.

    This function:
    1. Validates CUI checksum
    2. If valid, looks up in database (exact match)
    3. If invalid, tries 1-digit corrections and looks up each candidate
    4. Returns the first match found in database

    A CUI with a valid checksum that is not in the database yields None
    (it may simply be a new supplier); no corrections are attempted then.

    Args:
        cui: Extracted CUI from OCR (may be invalid)
        db_session: SQLAlchemy async session for database lookups

    Returns:
        Tuple of (corrected_cui, supplier_name) if found, else None.
        The returned CUI carries an "RO" prefix only if the input had an
        "RO"/"R0" prefix; it is not the raw value stored in the database.

    Usage in OCR extraction:
        from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine
        match = await OCRValidationEngine.fuzzy_match_cui_from_db(extracted_cui, session)
        if match:
            corrected_cui, supplier_name = match
    """
    if not cui:
        return None

    cui = cui.strip().upper()

    # "R0" is the OCR misreading of "RO"; either spelling marks a prefixed CUI.
    had_ro_prefix = cui.startswith(("RO", "R0"))
    cui_digits = cui[2:] if had_ro_prefix else cui

    # Remove any non-digit characters
    cui_digits = ''.join(c for c in cui_digits if c.isdigit())

    # Validate length
    if len(cui_digits) < 6 or len(cui_digits) > 10:
        return None

    # Imported only once the cheap guards have passed, so trivially invalid
    # input never pays for (or depends on) the ORM imports.
    from sqlalchemy import select, or_
    from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier

    def format_cui(digits: str) -> str:
        """Re-apply the RO prefix when the original input carried one."""
        if had_ro_prefix:
            return f"RO{digits}"
        return digits

    def build_search_patterns(digits: str) -> list[str]:
        """Return the fiscal_code spellings to look for.

        Database may have: "22891860", "RO22891860", "RO 22891860", or the
        value without its leading zeros.
        """
        patterns = [digits, f"RO{digits}", f"RO {digits}"]
        without_zeros = digits.lstrip('0')
        # An all-zero string strips to "", which would match every supplier
        # whose fiscal_code is blank; an unchanged string is just a duplicate.
        if without_zeros and without_zeros != digits:
            patterns.append(without_zeros)
        return patterns

    async def lookup_cui_in_db(digits: str) -> Optional[tuple[str, str]]:
        """Search both synced and local suppliers for CUI."""
        patterns = build_search_patterns(digits)

        # synced_suppliers is searched first (more data), then local_suppliers.
        for model in (SyncedSupplier, LocalSupplier):
            stmt = select(model.fiscal_code, model.name).where(
                or_(*(model.fiscal_code == pattern for pattern in patterns))
            ).limit(1)
            result = await db_session.execute(stmt)
            row = result.first()
            if row:
                return (format_cui(digits), row.name)

        return None

    # 1. If checksum is valid, check if it exists in database (exact match)
    if CUIChecksumRule.validate_checksum(cui_digits):
        match = await lookup_cui_in_db(cui_digits)
        if match:
            print(f"[Fuzzy CUI] Exact match found: {cui} -> {match[0]} ({match[1]})", flush=True)
            return match
        # Valid checksum but not in DB - it might be a new supplier
        return None

    # 2. Invalid checksum - try 1-digit corrections and verify against database
    print(f"[Fuzzy CUI] Invalid checksum for {cui}, trying corrections...", flush=True)

    # Common OCR digit confusions (try these first)
    confusion_pairs = {
        '5': ['8', '6'],       # 5 often misread as 8 or 6
        '8': ['5', '3', '0'],  # 8 often misread as 5, 3, or 0
        '6': ['0', '8'],       # 6 often misread as 0 or 8
        '0': ['6', '8'],       # 0 often misread as 6 or 8
        '1': ['7', '4'],       # 1 often misread as 7 or 4
        '7': ['1'],            # 7 often misread as 1
        '3': ['8'],            # 3 often misread as 8
        '4': ['1'],            # 4 often misread as 1
        '2': ['7'],            # 2 sometimes misread as 7
        '9': ['0'],            # 9 sometimes misread as 0
    }

    n = len(cui_digits)
    last_pos = n - 1  # checksum position

    # Position check order: middle positions first, then ends.
    # n is 6..10 here, so every listed position is a valid index.
    middle_positions = list(range(2, last_pos))
    position_order = middle_positions + [1, last_pos, 0]

    for pos in position_order:
        original_digit = cui_digits[pos]

        # Try common confusions first for this digit, then all other digits
        candidates = confusion_pairs.get(original_digit, [])
        all_digits = [d for d in '0123456789' if d != original_digit and d not in candidates]

        for replacement in candidates + all_digits:
            candidate = cui_digits[:pos] + replacement + cui_digits[pos+1:]

            # Only candidates with a valid checksum are worth a DB round-trip
            if not CUIChecksumRule.validate_checksum(candidate):
                continue

            # Check if this corrected CUI exists in database
            match = await lookup_cui_in_db(candidate)
            if match:
                print(f"[Fuzzy CUI] DB match: {cui} -> {match[0]} ({match[1]}) [pos {pos}: {original_digit}->{replacement}]", flush=True)
                return match

    # No match found in database
    print(f"[Fuzzy CUI] No database match found for {cui}", flush=True)
    return None
|
|
|
|
@staticmethod
|
|
async def fuzzy_match_by_name_and_cui(
    vendor_name: Optional[str],
    cui: Optional[str],
    db_session
) -> Optional[tuple[str, str]]:
    """Fuzzy match supplier by NAME, then narrow down by CUI.

    Algorithm:
    1. Normalize vendor name (remove S.R.L., S.A., punctuation, etc.)
    2. Search suppliers whose upper-cased name contains the first keyword
       of at least 3 characters (LIKE %word%, at most 20 rows per table)
    3. If multiple results, use fuzzy CUI matching to pick best one
    4. Otherwise fall back to keyword overlap, then to the first candidate

    Args:
        vendor_name: Extracted vendor name from OCR
        cui: Extracted CUI from OCR (may be invalid/incomplete)
        db_session: SQLAlchemy async session

    Returns:
        Tuple of (matched_cui, supplier_name) as stored in the database if
        found, else None
    """
    import re

    if not vendor_name or len(vendor_name) < 3:
        return None

    # Legal-form markers (dotted or undotted) are removed only when they stand
    # as whole tokens. The lookarounds reject a neighbouring letter or digit,
    # so names that merely contain the letters ("SAMSUNG", "SCANDIA", "MESA")
    # are left intact; "_" and punctuation count as token separators.
    legal_form_re = re.compile(
        r'(?<![^\W_])(?:S\.?R\.?L|S\.?A|S\.?C|I\.?F|P\.?F\.?A)\.?(?![^\W_])'
    )

    def normalize_name(name: str) -> str:
        """Normalize name for fuzzy matching."""
        name = legal_form_re.sub(' ', name.upper())
        # Remove punctuation and extra spaces
        name = re.sub(r'[.,\-_/\\()"\']', ' ', name)
        return ' '.join(name.split())

    # Extract key words from vendor name (for fuzzy search)
    normalized_name = normalize_name(vendor_name)
    name_words = [w for w in normalized_name.split() if len(w) >= 3]

    if not name_words:
        return None

    # Imported only once there is something to search for.
    from sqlalchemy import select, func
    from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier

    print(f"[Fuzzy Name] Searching for vendor: '{vendor_name}' -> keywords: {name_words}", flush=True)

    # Build search pattern - use first significant word
    primary_word = name_words[0]
    search_pattern = f"%{primary_word}%"

    # synced_suppliers first, then local_suppliers; rows without a fiscal
    # code are useless as a match result and are skipped.
    candidates = []
    for model in (SyncedSupplier, LocalSupplier):
        stmt = select(model.fiscal_code, model.name).where(
            func.upper(model.name).like(search_pattern)
        ).limit(20)
        result = await db_session.execute(stmt)
        candidates.extend(
            (row.fiscal_code, row.name) for row in result if row.fiscal_code
        )

    if not candidates:
        print(f"[Fuzzy Name] No name matches found for '{primary_word}'", flush=True)
        return None

    print(f"[Fuzzy Name] Found {len(candidates)} name matches for '{primary_word}'", flush=True)

    # If only one candidate, return it
    if len(candidates) == 1:
        print(f"[Fuzzy Name] Single match: {candidates[0][0]} ({candidates[0][1]})", flush=True)
        return candidates[0]

    # Multiple candidates - try to narrow down by CUI
    if cui:
        cui_digits = ''.join(c for c in cui.upper().replace('RO', '').replace('R0', '') if c.isdigit())

        if len(cui_digits) >= 6:
            def cui_similarity(candidate_cui: str) -> int:
                """Count digits equal at the same position (0 if lengths differ)."""
                cand_digits = ''.join(c for c in candidate_cui if c.isdigit())
                if len(cand_digits) != len(cui_digits):
                    return 0
                return sum(1 for a, b in zip(cand_digits, cui_digits) if a == b)

            # max() keeps the earliest candidate on ties, like a stable sort.
            best_match = max(candidates, key=lambda c: cui_similarity(c[0]))
            best_score = cui_similarity(best_match[0])

            # Require at least 70% digit match for CUI similarity
            min_matching = int(len(cui_digits) * 0.7)

            if best_score >= min_matching:
                print(f"[Fuzzy Name] Best CUI match: {best_match[0]} ({best_match[1]}) - score {best_score}/{len(cui_digits)}", flush=True)
                return best_match

            print(f"[Fuzzy Name] No strong CUI match (best score: {best_score}/{len(cui_digits)})", flush=True)

    # If still multiple and no CUI match, try name similarity
    def name_similarity(candidate_name: str) -> int:
        """Count how many keywords appear in the normalized candidate name."""
        norm_cand = normalize_name(candidate_name)
        return sum(1 for w in name_words if w in norm_cand)

    best_match = max(candidates, key=lambda c: name_similarity(c[1]))

    if name_similarity(best_match[1]) >= 2:  # At least 2 keywords match
        print(f"[Fuzzy Name] Best name match: {best_match[0]} ({best_match[1]})", flush=True)
        return best_match

    # Return first candidate if nothing else works.
    # NOTE(review): this picks an arbitrary supplier among several weak
    # matches - confirm callers treat the result as a suggestion only.
    print(f"[Fuzzy Name] Returning first candidate: {candidates[0][0]} ({candidates[0][1]})", flush=True)
    return candidates[0]
|
|
|
|
@staticmethod
|
|
async def fuzzy_match_supplier(
    cui: Optional[str],
    vendor_name: Optional[str],
    db_session
) -> Optional[tuple[str, str]]:
    """Resolve a supplier from OCR output, preferring the CUI over the name.

    Two lookups are chained; the second runs only when the first comes
    back empty:
        1. fuzzy_match_cui_from_db - CUI lookup with 1-digit corrections
           that are validated by checksum.
        2. fuzzy_match_by_name_and_cui - vendor-name search, narrowed by
           CUI similarity.

    Args:
        cui: Extracted CUI from OCR (may be invalid/incomplete)
        vendor_name: Extracted vendor name from OCR
        db_session: SQLAlchemy async session

    Returns:
        Tuple of (matched_cui, supplier_name) if found, else None
    """
    engine = OCRValidationEngine

    # The CUI is the stronger identifier, so it gets the first attempt.
    by_cui = await engine.fuzzy_match_cui_from_db(cui, db_session)
    if by_cui:
        return by_cui

    # Nothing found via the CUI: fall back to the vendor name.
    by_name = await engine.fuzzy_match_by_name_and_cui(vendor_name, cui, db_session)
    return by_name if by_name else None
|