Files
roa2web-service-auto/backend/modules/data_entry/services/ocr/validation.py
Marius Mutu ab160b628d feat(ocr): Add validation system and CLIENT CUI extraction
OCR Data Extraction Validation System:
- Add 7 validation rules (amount range, TVA ratio, payment sum, etc.)
- Add Medium preprocessing to replace Heavy (fixes digit concatenation)
- Add validation warnings to API responses
- Flag receipts needing manual review (needs_manual_review field)
- Add database migration for needs_manual_review column

CLIENT CUI Extraction Improvements:
- Support all format variations: CIF CLIENT:, CLIENT C.U.I/C.I.F., etc.
- Handle OCR errors (R0 vs RO, C1F vs CIF)
- Add client_name, client_cui, client_address to API response
- Add validation fields to API response (was missing)

QA Review: 12 issues found, 9 fixed (5 errors + 4 warnings)
- Fixed type safety in validation rules
- Fixed ZeroDivisionError risk
- Fixed schema mismatch (Optional[bool] for needs_manual_review)
- All 37 unit tests passing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-30 19:12:52 +02:00

738 lines
24 KiB
Python

"""
OCR Data Validation Module
Provides multi-layer validation for OCR extraction results to prevent
incorrect data from entering the system.
Validation Layers:
1. Absolute sanity checks (value ranges)
2. Cross-field validation (correlation between fields)
3. Inter-OCR consistency (compare multiple OCR results)
4. Auto-correction hints (e.g. the payment-sum check suggests a likely fix; values are not rewritten automatically)
Usage:
engine = OCRValidationEngine()
validated_result = engine.validate_extraction(
merged_result,
light_ocr_result,
medium_ocr_result
)
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Optional
@dataclass
class ValidationResult:
    """Outcome produced by running a single validation rule.

    Attributes:
        is_valid: True when the rule's check passed.
        confidence_penalty: How much to reduce the confidence score, in the
            closed range 0.0-1.0 (0.0 = no penalty, 1.0 = complete rejection).
        message: Human-readable explanation of the outcome.
        severity: One of "info", "warning" or "error".
    """

    is_valid: bool
    confidence_penalty: float = 0.0
    message: str = ""
    severity: str = "info"  # "info" | "warning" | "error"

    def __post_init__(self):
        """Reject penalties outside 0.0-1.0 (NaN is rejected as well)."""
        penalty = self.confidence_penalty
        # Chained comparison is False for NaN, so NaN is reported as invalid.
        in_range = 0.0 <= penalty <= 1.0
        if not in_range:
            raise ValueError(f"Confidence penalty must be 0.0-1.0, got {penalty}")
class ValidationRule(ABC):
    """Interface shared by every OCR validation rule.

    A concrete rule exposes a display name through ``rule_name`` and
    implements ``validate``, which inspects an extraction dict and reports
    pass/fail plus an optional confidence penalty as a ``ValidationResult``.
    """

    @property
    @abstractmethod
    def rule_name(self) -> str:
        """Display name of the rule (used as a prefix in messages)."""
        ...

    @abstractmethod
    def validate(self, data: dict[str, Any]) -> ValidationResult:
        """Run this rule against extracted OCR fields.

        Args:
            data: Extraction fields to check, for example
                {"amount": 85.99, "tva": 14.92, ...}.

        Returns:
            A ValidationResult carrying the is_valid flag and any penalty.
        """
        ...
# ============================================================================
# VALIDATION RULES
# ============================================================================
class AmountRangeRule(ValidationRule):
    """Validate that the receipt amount lies within plausible bounds.

    Romanian receipts rarely exceed 100,000 RON, so an amount above the upper
    bound usually indicates an OCR error such as digit concatenation
    (85.99 -> 859,762.16). Amounts below the lower bound (including zero and
    negative values) are rejected too.

    Example:
        rule = AmountRangeRule(min_amount=0.01, max_amount=100_000.0)
        result = rule.validate({"amount": 859762.16})
        # result.is_valid = False, penalty = 0.5

    Args:
        min_amount: Smallest accepted amount in RON (inclusive).
        max_amount: Largest accepted amount in RON (inclusive).
    """

    def __init__(self, min_amount: float = 0.01, max_amount: float = 100_000.0):
        self.min_amount = min_amount
        self.max_amount = max_amount

    @property
    def rule_name(self) -> str:
        return "Amount Range Check"

    def validate(self, data: dict[str, Any]) -> ValidationResult:
        """Check ``data["amount"]`` against ``[min_amount, max_amount]``.

        A missing or non-numeric amount is reported as valid (skipped) so the
        rule never raises on unparsed OCR output.
        """
        amount = data.get("amount")
        if amount is None:
            return ValidationResult(
                is_valid=True,
                message="No amount to validate"
            )
        # Type safety: an unparsed OCR value such as "85,99" would otherwise
        # raise TypeError when compared against the float bounds.
        if not isinstance(amount, (int, float)):
            return ValidationResult(
                is_valid=True,
                message="Non-numeric amount, skipping range check"
            )
        if amount < self.min_amount:
            return ValidationResult(
                is_valid=False,
                confidence_penalty=0.5,
                message=f"Amount {amount:.2f} RON below minimum {self.min_amount:.2f} RON",
                severity="error"
            )
        if amount > self.max_amount:
            return ValidationResult(
                is_valid=False,
                confidence_penalty=0.5,
                message=f"Amount {amount:.2f} RON exceeds maximum {self.max_amount:.2f} RON (likely OCR error)",
                severity="error"
            )
        return ValidationResult(
            is_valid=True,
            message=f"Amount {amount:.2f} RON within valid range"
        )
class TVARatioRule(ValidationRule):
    """Check that the TVA amount is a plausible fraction of the receipt TOTAL.

    Romanian TVA rates are 5%, 9%, 19% and 21%. A TVA value larger than the
    TOTAL is impossible and points to an OCR error.

    NOTE(review): the ratio is computed as ``tva / amount``. Assuming
    ``amount`` is the TVA-inclusive receipt total, a single-rate receipt at
    rate r yields r / (1 + r): ~17.4% at 21%, ~8.3% at 9% and ~4.8% at 5%.
    The 5% case falls just below the default 5% floor, so pure 5% receipts
    would be flagged. Confirm the intended bounds.

    Example:
        rule = TVARatioRule(min_ratio=0.05, max_ratio=0.24)
        result = rule.validate({"amount": 85.99, "tva": 149.21})
        # result.is_valid = False (149.21 > 85.99!)

    Args:
        min_ratio: Lowest accepted tva/amount ratio.
        max_ratio: Highest accepted tva/amount ratio.
    """

    def __init__(self, min_ratio: float = 0.05, max_ratio: float = 0.24):
        self.min_ratio, self.max_ratio = min_ratio, max_ratio

    @property
    def rule_name(self) -> str:
        return "TVA Ratio Check"

    def validate(self, data: dict[str, Any]) -> ValidationResult:
        amount, tva = data.get("amount"), data.get("tva")

        # Both values must be present and non-zero to form a ratio.
        if not (amount and tva):
            return ValidationResult(
                is_valid=True,
                message="Insufficient data for TVA correlation"
            )

        numeric_types = (int, float)
        both_numeric = isinstance(amount, numeric_types) and isinstance(tva, numeric_types)
        if not both_numeric:
            return ValidationResult(
                is_valid=True,
                message="Non-numeric values, skipping TVA correlation"
            )

        # Guard the division below (a negative total is not meaningful either).
        if amount <= 0:
            return ValidationResult(
                is_valid=True,
                message="Amount is zero or negative, skipping TVA ratio"
            )

        ratio = tva / amount
        out_of_range = ratio < self.min_ratio or ratio > self.max_ratio
        if not out_of_range:
            return ValidationResult(
                is_valid=True,
                message=f"TVA ratio {ratio:.1%} valid"
            )
        return ValidationResult(
            is_valid=False,
            confidence_penalty=0.3,
            message=f"TVA ratio {ratio:.1%} outside valid range ({self.min_ratio:.0%}-{self.max_ratio:.0%})",
            severity="warning"
        )
class PaymentSumRule(ValidationRule):
    """Validate CARD + NUMERAR = TOTAL BON (within tolerance).

    This is a CRITICAL validation that catches cases where OCR extracts a
    wrong TOTAL but correct payment methods.

    NOTE(review): if ``cash_amount`` holds the cash tendered rather than the
    cash applied (receipts that print change / "REST"), the payment sum will
    exceed the total and be flagged. Confirm what the extractor stores.

    Example:
        rule = PaymentSumRule(tolerance=0.02)
        result = rule.validate({
            "amount": 859762.16,  # Wrong from OCR
            "card_amount": 85.99,  # Correct
            "cash_amount": 0.0
        })
        # result.is_valid = False, suggests auto-correction

    Args:
        tolerance: Maximum accepted absolute difference in RON.
    """

    def __init__(self, tolerance: float = 0.02):
        self.tolerance = tolerance

    @property
    def rule_name(self) -> str:
        return "Payment Sum Check"

    def validate(self, data: dict[str, Any]) -> ValidationResult:
        total = data.get("amount")
        # Missing / None payment fields count as 0.0.
        card = data.get("card_amount", 0.0) or 0.0
        cash = data.get("cash_amount", 0.0) or 0.0

        if not total:
            return ValidationResult(
                is_valid=True,
                message="No total amount to validate"
            )

        # Type safety: unparsed OCR strings would raise TypeError (or, for two
        # strings, silently concatenate) in the arithmetic below.
        if not all(isinstance(v, (int, float)) for v in (total, card, cash)):
            return ValidationResult(
                is_valid=True,
                message="Non-numeric payment values, skipping payment sum check"
            )

        payment_sum = card + cash
        if payment_sum == 0:
            return ValidationResult(
                is_valid=True,
                message="No payment methods extracted"
            )

        diff = abs(total - payment_sum)
        if diff > self.tolerance:
            return ValidationResult(
                is_valid=False,
                confidence_penalty=0.4,
                message=f"Payment sum {payment_sum:.2f} RON ≠ Total {total:.2f} RON (diff: {diff:.2f} RON). Consider auto-correction.",
                severity="error"
            )
        return ValidationResult(
            is_valid=True,
            message=f"Payment sum matches total (diff: {diff:.2f} RON)"
        )
class TVAEntriesSumRule(ValidationRule):
    """Validate Σ(TVA entries) = TVA TOTAL (within tolerance).

    The TVA breakdown (A, B, C, D rates) should sum to the total TVA.

    Example:
        rule = TVAEntriesSumRule(tolerance=0.02)
        result = rule.validate({
            "tva": 14.92,
            "tva_entries": {"A": 14.92, "B": 0.0}
        })
        # result.is_valid = True

    Args:
        tolerance: Maximum accepted absolute difference in RON.
    """

    def __init__(self, tolerance: float = 0.02):
        self.tolerance = tolerance

    @property
    def rule_name(self) -> str:
        return "TVA Entries Sum Check"

    def validate(self, data: dict[str, Any]) -> ValidationResult:
        tva_total = data.get("tva")
        tva_entries = data.get("tva_entries", {})

        if not tva_total:
            return ValidationResult(
                is_valid=True,
                message="No TVA total to validate"
            )
        if not tva_entries:
            return ValidationResult(
                is_valid=True,
                message="No TVA entries extracted"
            )
        # Expect a {rate_letter: amount} mapping; anything else cannot be summed.
        if not isinstance(tva_entries, dict):
            return ValidationResult(
                is_valid=True,
                message="TVA entries are not a mapping, skipping sum check"
            )

        # A None entry means that rate was not extracted; treat it as absent
        # instead of crashing sum().
        entry_values = [v for v in tva_entries.values() if v is not None]
        numeric = (int, float)
        if not isinstance(tva_total, numeric) or not all(isinstance(v, numeric) for v in entry_values):
            return ValidationResult(
                is_valid=True,
                message="Non-numeric TVA values, skipping TVA entries sum check"
            )

        entries_sum = sum(entry_values)
        if entries_sum == 0:
            return ValidationResult(
                is_valid=True,
                message="TVA entries sum is zero"
            )

        diff = abs(tva_total - entries_sum)
        if diff > self.tolerance:
            return ValidationResult(
                is_valid=False,
                confidence_penalty=0.2,
                message=f"TVA entries sum {entries_sum:.2f} RON ≠ TVA total {tva_total:.2f} RON (diff: {diff:.2f} RON)",
                severity="warning"
            )
        return ValidationResult(
            is_valid=True,
            message=f"TVA entries sum matches total (diff: {diff:.2f} RON)"
        )
class CUIFormatRule(ValidationRule):
    """Validate CUI format: optional RO prefix + 6-10 digits.

    Romanian CUI (Cod Unic de Identificare) format accepted here:
    - Optional "RO" prefix (or "R0", a common OCR misread), optionally
      followed by whitespace ("RO 10562600")
    - 6-10 ASCII digits

    Example:
        rule = CUIFormatRule()
        result = rule.validate({"cui": "RO10562600"})
        # result.is_valid = True
    """

    @property
    def rule_name(self) -> str:
        return "CUI Format Check"

    def validate(self, data: dict[str, Any]) -> ValidationResult:
        cui = data.get("cui")
        if not cui:
            return ValidationResult(
                is_valid=True,
                message="No CUI to validate"
            )

        # str() tolerates a CUI that was parsed as an int.
        cui_clean = str(cui).strip().upper()
        # Strip a single RO / R0 prefix plus any space after it.
        for prefix in ("RO", "R0"):
            if cui_clean.startswith(prefix):
                cui_clean = cui_clean[len(prefix):].lstrip()
                break

        # isdigit() alone accepts Unicode digit-like characters (e.g. "²");
        # a CUI must consist of ASCII digits only.
        if not (cui_clean.isascii() and cui_clean.isdigit()):
            return ValidationResult(
                is_valid=False,
                confidence_penalty=0.3,
                message=f"CUI '{cui}' contains non-numeric characters",
                severity="warning"
            )

        if len(cui_clean) < 6 or len(cui_clean) > 10:
            return ValidationResult(
                is_valid=False,
                confidence_penalty=0.3,
                message=f"CUI '{cui}' length {len(cui_clean)} outside valid range (6-10 digits)",
                severity="warning"
            )

        return ValidationResult(
            is_valid=True,
            message=f"CUI '{cui}' format valid"
        )
class CUIChecksumRule(ValidationRule):
    """Validate a Romanian CIF/CUI control digit (Mod 11 algorithm).

    Algorithm:
    1. Remove the RO prefix if present (also the "R0" OCR misread).
    2. The last digit is the declared control digit; the rest is the base.
    3. Multiply the base digits by the test key 7,5,3,2,1,7,5,3,2 aligned
       to the RIGHT (equivalently: left-pad the base with zeros to 9 digits),
       so the digit just before the control digit is always weighted by 2.
       Left-aligning the key is only correct for 10-digit codes.
    4. Compute (weighted_sum * 10) mod 11.
    5. If the result is 10, the expected control digit is 0; otherwise it is
       the result itself.
    6. Compare with the declared control digit.

    Example:
        rule = CUIChecksumRule()
        rule.validate({"cui": "RO10562600"}).is_valid  # True
        rule.validate({"cui": "RO14399840"}).is_valid  # True (7-digit base)
        rule.validate({"cui": "RO14399841"}).is_valid  # False: expected 0, got 1
    """

    # Official test key; right-aligned against the base digits.
    _CONTROL_KEY = (7, 5, 3, 2, 1, 7, 5, 3, 2)

    @property
    def rule_name(self) -> str:
        return "CUI Checksum Check (Mod 11)"

    def validate(self, data: dict[str, Any]) -> ValidationResult:
        cui = data.get("cui")
        if not cui:
            return ValidationResult(
                is_valid=True,
                message="No CUI to validate"
            )

        # Normalize: tolerate int input, strip one RO/R0 prefix and any space after it.
        cui_clean = str(cui).strip().upper()
        for prefix in ("RO", "R0"):
            if cui_clean.startswith(prefix):
                cui_clean = cui_clean[len(prefix):].lstrip()
                break

        # ASCII digits only: int() would raise on characters like "²" that
        # str.isdigit() accepts. Format problems are reported by the format
        # rule, so this rule just skips.
        if not (cui_clean.isascii() and cui_clean.isdigit()):
            return ValidationResult(
                is_valid=True,
                message="CUI format invalid, skipping checksum"
            )
        if len(cui_clean) < 6 or len(cui_clean) > 10:
            return ValidationResult(
                is_valid=True,
                message="CUI length invalid, skipping checksum"
            )

        digits = [int(d) for d in cui_clean]
        checksum_declared = digits[-1]
        base_digits = digits[:-1]  # 5..9 digits given the length check above

        # Right-align the key with the base digits.
        key = self._CONTROL_KEY[-len(base_digits):]
        weighted_sum = sum(d * m for d, m in zip(base_digits, key))

        checksum_calculated = (weighted_sum * 10) % 11
        if checksum_calculated == 10:
            checksum_calculated = 0

        if checksum_calculated != checksum_declared:
            return ValidationResult(
                is_valid=False,
                confidence_penalty=0.3,
                message=f"CUI '{cui}' checksum mismatch: expected {checksum_calculated}, got {checksum_declared}",
                severity="warning"
            )
        return ValidationResult(
            is_valid=True,
            message=f"CUI '{cui}' checksum valid"
        )
class InterOCRConsistencyRule(ValidationRule):
    """Validate consistency between two OCR results for the same field.

    If the Light and Medium OCR values differ in magnitude by more than
    ``max_ratio``, one of them is clearly wrong (likely a digit-concatenation
    error).

    Expected input keys: "light_value", "medium_value" and an optional
    "field_name" (used only in messages).

    Example:
        rule = InterOCRConsistencyRule(max_ratio=10.0)
        result = rule.validate({
            "light_value": 85.99,
            "medium_value": 859762.16
        })
        # result.is_valid = False (ratio = 10,000x!)

    Args:
        max_ratio: Largest accepted larger/smaller magnitude ratio.
    """

    def __init__(self, max_ratio: float = 10.0):
        self.max_ratio = max_ratio

    @property
    def rule_name(self) -> str:
        return "Inter-OCR Consistency Check"

    def validate(self, data: dict[str, Any]) -> ValidationResult:
        light_value = data.get("light_value")
        medium_value = data.get("medium_value")
        field_name = data.get("field_name", "value")

        # Falsy covers both a missing value and 0, so the division below is safe.
        if not light_value or not medium_value:
            return ValidationResult(
                is_valid=True,
                message="Insufficient OCR results for consistency check"
            )
        if not isinstance(light_value, (int, float)) or not isinstance(medium_value, (int, float)):
            return ValidationResult(
                is_valid=True,
                message="Non-numeric OCR values, skipping consistency check"
            )

        # Compare magnitudes so the ratio is always >= 1, even for negative values.
        light_mag, medium_mag = abs(light_value), abs(medium_value)
        ratio = max(light_mag, medium_mag) / min(light_mag, medium_mag)

        if ratio > self.max_ratio:
            return ValidationResult(
                is_valid=False,
                confidence_penalty=0.2,
                message=f"{field_name}: OCR results differ by {ratio:.1f}x (Light: {light_value}, Medium: {medium_value})",
                severity="warning"
            )
        return ValidationResult(
            is_valid=True,
            message=f"{field_name}: OCR results consistent (ratio: {ratio:.2f}x)"
        )
# ============================================================================
# VALIDATION ENGINE
# ============================================================================
@dataclass
class EnhancedExtractionResult:
    """Extraction data bundled together with validation findings.

    ``data`` carries the extraction dict as given; every other field
    describes what validation found about it.
    """

    # The extraction dict that was validated.
    data: dict[str, Any]

    # True when a human should double-check the extraction.
    needs_manual_review: bool = field(default=False)
    # Human-readable messages from failed warning-level checks.
    validation_warnings: list[str] = field(default_factory=list)
    # Human-readable messages from failed error-level checks.
    validation_errors: list[str] = field(default_factory=list)
    # Field name -> confidence penalty (0.0-1.0) from failed rules.
    confidence_adjustments: dict[str, float] = field(default_factory=dict)
    # Field name -> larger/smaller ratio between Light and Medium OCR values.
    inter_ocr_ratios: dict[str, float] = field(default_factory=dict)
class OCRValidationEngine:
    """Orchestrate all validation rules for OCR extraction results.

    Rules are applied in three stages:
    1. Sanity checks (amount range, CUI format and checksum)
    2. Cross-field correlation (TVA ratio, payment sum, TVA entries sum)
    3. Inter-OCR consistency checks (Light vs Medium OCR amounts)

    The extraction dict is returned unchanged inside the result; findings are
    reported as warnings/errors, per-field confidence penalties and a
    ``needs_manual_review`` flag.

    Example:
        engine = OCRValidationEngine()
        result = engine.validate_extraction(
            extraction_result=merged_data,
            light_result=light_ocr_data,
            medium_result=medium_ocr_data
        )
    """

    # Fields penalized when a sanity rule fails, keyed by rule_name.
    _SANITY_RULE_FIELDS: dict[str, list[str]] = {
        "Amount Range Check": ["amount"],
        "CUI Format Check": ["cui"],
        "CUI Checksum Check (Mod 11)": ["cui"],
    }
    # Fields penalized when a cross-field rule fails, keyed by rule_name.
    _CROSS_FIELD_RULE_FIELDS: dict[str, list[str]] = {
        "TVA Ratio Check": ["tva"],
        "Payment Sum Check": ["amount"],
        "TVA Entries Sum Check": ["tva"],
    }

    def __init__(self):
        """Initialize validation engine with default rules."""
        # Sanity check rules (absolute value validation)
        self.sanity_rules = [
            AmountRangeRule(min_amount=0.01, max_amount=100_000.0),
            CUIFormatRule(),
            CUIChecksumRule(),
        ]
        # Cross-field validation rules (correlation between fields)
        self.cross_field_rules = [
            TVARatioRule(min_ratio=0.05, max_ratio=0.24),
            PaymentSumRule(tolerance=0.02),
            TVAEntriesSumRule(tolerance=0.02),
        ]
        # Inter-OCR consistency rules
        self.inter_ocr_rules = [
            InterOCRConsistencyRule(max_ratio=10.0),
        ]

    def validate_extraction(
        self,
        extraction_result: dict[str, Any],
        light_result: Optional[dict[str, Any]] = None,
        medium_result: Optional[dict[str, Any]] = None
    ) -> EnhancedExtractionResult:
        """Run all validation rules and return an enhanced result.

        Args:
            extraction_result: Merged OCR extraction data (required).
            light_result: Light OCR preprocessing results (optional).
            medium_result: Medium OCR preprocessing results (optional). The
                inter-OCR stage runs only when both optional results are
                non-empty.

        Returns:
            EnhancedExtractionResult with validation warnings and metadata.
        """
        warnings: list[str] = []
        errors: list[str] = []
        confidence_adjustments: dict[str, float] = {}
        inter_ocr_ratios: dict[str, float] = {}

        # Step 1: Sanity checks
        print("\n[Validation] Step 1: Sanity checks...", flush=True)
        self._apply_rules(
            self.sanity_rules, extraction_result,
            self._SANITY_RULE_FIELDS, ["amount", "tva", "cui"],
            errors, warnings, confidence_adjustments,
        )

        # Step 2: Cross-field validation
        print("\n[Validation] Step 2: Cross-field validation...", flush=True)
        self._apply_rules(
            self.cross_field_rules, extraction_result,
            self._CROSS_FIELD_RULE_FIELDS, ["amount", "tva"],
            errors, warnings, confidence_adjustments,
        )

        # Step 3: Inter-OCR consistency checks
        if light_result and medium_result:
            print("\n[Validation] Step 3: Inter-OCR consistency...", flush=True)
            self._check_inter_ocr(light_result, medium_result, warnings, inter_ocr_ratios)

        # Flag for review on any error or a high-severity warning.
        high_severity_warnings = [
            w for w in warnings
            if "[Amount Range" in w or "[Payment Sum" in w or "[Inter-OCR]" in w
        ]
        needs_manual_review = (
            len(errors) > 0
            or len(high_severity_warnings) > 0
            # 10.0 mirrors the default InterOCRConsistencyRule threshold.
            or any(ratio > 10.0 for ratio in inter_ocr_ratios.values())
        )

        print("\n[Validation] Summary:", flush=True)
        print(f"  Errors: {len(errors)}", flush=True)
        print(f"  Warnings: {len(warnings)}", flush=True)
        print(f"  Manual review needed: {needs_manual_review}", flush=True)

        return EnhancedExtractionResult(
            data=extraction_result,
            needs_manual_review=needs_manual_review,
            validation_warnings=warnings,
            validation_errors=errors,
            confidence_adjustments=confidence_adjustments,
            inter_ocr_ratios=inter_ocr_ratios
        )

    @staticmethod
    def _apply_rules(
        rules: list[ValidationRule],
        extraction_result: dict[str, Any],
        field_map: dict[str, list[str]],
        default_fields: list[str],
        errors: list[str],
        warnings: list[str],
        confidence_adjustments: dict[str, float],
    ) -> None:
        """Run ``rules`` on the extraction and record failures in the given containers."""
        for rule in rules:
            result = rule.validate(extraction_result)
            if result.is_valid:
                print(f"{rule.rule_name}: {result.message}", flush=True)
                continue

            msg = f"[{rule.rule_name}] {result.message}"
            if result.severity == "error":
                errors.append(msg)
            else:
                warnings.append(msg)
            print(f"{msg}", flush=True)

            if result.confidence_penalty > 0:
                for field_name in field_map.get(rule.rule_name, default_fields):
                    if field_name in extraction_result:
                        # Several rules can penalize the same field (e.g. amount
                        # range 0.5 and payment sum 0.4); keep the largest penalty
                        # instead of letting a later, smaller one overwrite it.
                        confidence_adjustments[field_name] = max(
                            confidence_adjustments.get(field_name, 0.0),
                            result.confidence_penalty,
                        )

    def _check_inter_ocr(
        self,
        light_result: dict[str, Any],
        medium_result: dict[str, Any],
        warnings: list[str],
        inter_ocr_ratios: dict[str, float],
    ) -> None:
        """Compare Light vs Medium OCR amounts; record warnings and divergence ratios."""
        if "amount" not in light_result or "amount" not in medium_result:
            return
        light_amount = light_result["amount"]
        medium_amount = medium_result["amount"]
        consistency_data = {
            "light_value": light_amount,
            "medium_value": medium_amount,
            "field_name": "amount"
        }
        for rule in self.inter_ocr_rules:
            result = rule.validate(consistency_data)
            if result.is_valid:
                print(f"{result.message}", flush=True)
                continue
            msg = f"[Inter-OCR] {result.message}"
            warnings.append(msg)
            print(f"{msg}", flush=True)
            # Store the ratio as metadata (only for flagged fields).
            ratio = self._magnitude_ratio(light_amount, medium_amount)
            if ratio is not None:
                inter_ocr_ratios["amount"] = ratio

    @staticmethod
    def _magnitude_ratio(a: Any, b: Any) -> Optional[float]:
        """Return max(|a|, |b|) / min(|a|, |b|), or None if it cannot be computed."""
        try:
            low, high = sorted((abs(a), abs(b)))
        except TypeError:
            return None
        if low == 0:
            return None
        return high / low

    @staticmethod
    def normalize_cui(cui: Optional[str]) -> Optional[str]:
        """Normalize a CUI to "RO" prefix + digits format.

        Examples:
            10562600    -> RO10562600
            R010562600  -> RO10562600 (fix R0 OCR error)
            RO10562600  -> RO10562600 (unchanged)
            RO 10562600 -> RO10562600

        Args:
            cui: Raw CUI from OCR (a non-string such as an int is converted
                with str()).

        Returns:
            Normalized CUI with RO prefix, or None if empty or if the digit
            count is outside 6-10.
        """
        if not cui:
            return None
        text = str(cui).strip().upper()

        # Remove one existing prefix if present.
        for prefix in ("RO", "R0"):
            if text.startswith(prefix):
                text = text[len(prefix):]
                break

        # Keep ASCII digits only: drops spaces/punctuation and Unicode
        # digit-like characters (e.g. superscripts) that isdigit() accepts.
        cui_digits = "".join(c for c in text if "0" <= c <= "9")

        if len(cui_digits) < 6 or len(cui_digits) > 10:
            print(f"[CUI Normalize] Invalid length: {len(cui_digits)} digits (expected 6-10)", flush=True)
            return None

        return f"RO{cui_digits}"