feat(ocr): Add docTR OCR engine with metrics infrastructure

Add docTR as primary OCR engine with 2-tier sequential processing,
OCR metrics tracking, and simplified engine selection.

Features:
- docTR OCR engine with light+medium preprocessing tiers
- doctr_plus mode with early exit optimization (~65% fast path)
- OCR metrics dashboard with per-engine statistics
- User OCR preference persistence
- Parallel worker pool for OCR processing
- Cross-validation for extraction quality

Engine options: tesseract, doctr, doctr_plus (recommended), paddleocr

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-02 05:37:16 +02:00
parent 74f7aefc26
commit 495790411f
75 changed files with 23349 additions and 1311 deletions

View File

@@ -385,8 +385,81 @@ class CUIChecksumRule(ValidationRule):
result = rule.validate({"cui": "R01879855"})
# result.is_valid = False (checksum mismatch)
Static methods available for direct use:
CUIChecksumRule.calculate_checksum("1056260") -> 0
CUIChecksumRule.validate_checksum("10562600") -> True
CUIChecksumRule.has_ro_prefix("RO10562600") -> True
"""
# Fixed multipliers for 9 positions (Romanian Mod 11)
MULTIPLIERS = [7, 5, 3, 2, 1, 7, 5, 3, 2]
@staticmethod
def calculate_checksum(cui_base: str) -> int:
    """Calculate the expected CUI check digit (Romanian Mod 11 algorithm).

    The base is left-padded with zeros to the length of the weight table,
    each digit is multiplied by the weight in the same position, and the
    check digit is ``(weighted_sum * 10) % 11`` with a result of 10 mapped
    to 0.

    Args:
        cui_base: CUI digits WITHOUT the checksum digit (last digit).

    Returns:
        Expected checksum digit (0-9), or -1 if the input is empty, contains
        anything other than ASCII digits, or has more digits than there are
        weights.
    """
    weights = CUIChecksumRule.MULTIPLIERS

    # str.isdigit() alone also accepts characters such as "²" that int()
    # cannot convert, so restrict the input to ASCII before parsing.
    if not cui_base or not (cui_base.isascii() and cui_base.isdigit()):
        return -1

    # zip() would silently drop the trailing digits of an over-long base and
    # produce a checksum for a different number; report it as invalid instead.
    if len(cui_base) > len(weights):
        return -1

    # Pad from the LEFT so the last base digit lines up with the last weight.
    padded = cui_base.zfill(len(weights))
    weighted_sum = sum(int(digit) * weight for digit, weight in zip(padded, weights))

    checksum = (weighted_sum * 10) % 11
    return 0 if checksum == 10 else checksum
@staticmethod
def validate_checksum(cui_digits: str) -> bool:
    """Check whether the last digit of a CUI matches its computed checksum.

    Args:
        cui_digits: Full CUI digits (including checksum as last digit),
            without any RO prefix.

    Returns:
        True if the checksum is valid. False if it is not, or if the input is
        empty, is not made of ASCII digits only, is shorter than 6 digits, or
        is longer than the weight table plus one check digit.
    """
    # str.isdigit() accepts characters such as "²" that int() rejects; the
    # ASCII check keeps the int() conversion below from raising ValueError.
    if not cui_digits or not (cui_digits.isascii() and cui_digits.isdigit()):
        return False

    # One weight per base digit plus the trailing check digit.
    max_length = len(CUIChecksumRule.MULTIPLIERS) + 1
    if not 6 <= len(cui_digits) <= max_length:
        return False

    declared = int(cui_digits[-1])
    expected = CUIChecksumRule.calculate_checksum(cui_digits[:-1])
    return expected == declared
@staticmethod
def has_ro_prefix(cui: str) -> bool:
"""Check if CUI has RO prefix (proper format for VAT payers)."""
if not cui:
return False
return cui.upper().strip().startswith('RO')
@staticmethod
def extract_digits(cui: str) -> str:
"""Extract digits from CUI, removing RO/R0 prefix."""
if not cui:
return ""
cui = cui.strip().upper()
if cui.startswith("RO"):
cui = cui[2:]
elif cui.startswith("R0"): # Fix OCR error R0 → RO
cui = cui[2:]
return ''.join(c for c in cui if c.isdigit())
@property
def rule_name(self) -> str:
    """Human-readable name of this validation rule."""
    return "CUI Checksum Check (Mod 11)"
@@ -400,15 +473,11 @@ class CUIChecksumRule(ValidationRule):
message="No CUI to validate"
)
# Normalize: remove RO/R0 prefix
cui_clean = cui.strip().upper()
if cui_clean.startswith("RO"):
cui_clean = cui_clean[2:]
elif cui_clean.startswith("R0"):
cui_clean = cui_clean[2:]
# Use static method to extract digits
cui_clean = CUIChecksumRule.extract_digits(cui)
# Check format first
if not cui_clean.isdigit():
if not cui_clean:
return ValidationResult(
is_valid=True, # Don't fail checksum if format invalid (handled by CUIFormatRule)
message="CUI format invalid, skipping checksum"
@@ -420,28 +489,15 @@ class CUIChecksumRule(ValidationRule):
message="CUI length invalid, skipping checksum"
)
# Extract digits
digits = [int(d) for d in cui_clean]
checksum_declared = digits[-1]
base_digits = digits[:-1]
# Multipliers (trim to match base_digits length)
multipliers = [7, 5, 3, 2, 1, 7, 5, 3, 2]
multipliers = multipliers[:len(base_digits)]
# Calculate weighted sum
weighted_sum = sum(d * m for d, m in zip(base_digits, multipliers))
# Calculate expected checksum
checksum_calculated = (weighted_sum * 10) % 11
if checksum_calculated == 10:
checksum_calculated = 0
if checksum_calculated != checksum_declared:
# Use static method to validate checksum
if not CUIChecksumRule.validate_checksum(cui_clean):
# Calculate expected for error message
expected = CUIChecksumRule.calculate_checksum(cui_clean[:-1])
declared = int(cui_clean[-1])
return ValidationResult(
is_valid=False,
confidence_penalty=0.3,
message=f"CUI '{cui}' checksum mismatch: expected {checksum_calculated}, got {checksum_declared}",
message=f"CUI '{cui}' checksum mismatch: expected {expected}, got {declared}",
severity="warning"
)
@@ -451,6 +507,129 @@ class CUIChecksumRule(ValidationRule):
)
class TVABasedTotalRule(ValidationRule):
    """Validate TOTAL using reverse calculation from TVA amount.

    This is a CRITICAL validation that catches cases where OCR extracts
    wrong TOTAL but correct TVA. Since TVA = BASE * rate and TOTAL = BASE + TVA,
    we can calculate expected TOTAL from TVA alone.

    Formula:
        Expected TOTAL = TVA / rate * (1 + rate)
        Or equivalently: Expected TOTAL = TVA * (1 + rate) / rate

    For TVA rate 21%:
        Expected TOTAL = TVA / 0.21 * 1.21 = TVA * 5.7619

    Example (benzina 27 oct):
        TVA = 49.58, rate = 21%
        Expected TOTAL = 49.58 / 0.21 * 1.21 = 285.68
        Extracted TOTAL = 205.66 (WRONG!)
        Rule detects mismatch and flags for escalation

    The formula assumes a single rate. When ``tva_entries`` declares more
    than one distinct rate the check is skipped, because applying one rate to
    the combined TVA would report a mismatch on a correct receipt.

    Usage in multi-tier processing (e.g., doctr_plus):
        If this rule fails, the engine should proceed to next tier
        instead of returning early with potentially wrong data.
    """

    def __init__(self, tolerance_percent: float = 0.02):
        """
        Args:
            tolerance_percent: Allowed difference as percentage (0.02 = 2%)
        """
        self.tolerance_percent = tolerance_percent

    @property
    def rule_name(self) -> str:
        """Human-readable name of this validation rule."""
        return "TVA-Based Total Check"

    @staticmethod
    def _rates_from_entries(tva_entries: Any) -> list[float]:
        """Collect the distinct usable rates (as fractions) found in tva_entries."""
        rates: list[float] = []
        for entry in tva_entries or ():
            if not isinstance(entry, dict):
                continue
            try:
                rate = float(entry.get('percent')) / 100.0
            except (TypeError, ValueError):
                continue
            # Zero, negative, NaN and infinite rates cannot be used as a
            # divisor in the reverse calculation, so they are ignored.
            if 0.0 < rate < float("inf") and rate not in rates:
                rates.append(rate)
        return rates

    def validate(self, data: dict[str, Any]) -> ValidationResult:
        """Compare the extracted TOTAL with the TOTAL implied by the TVA amount.

        Args:
            data: Extraction dict; reads ``amount``, ``tva`` and ``tva_entries``
                (entries are dicts whose ``percent`` key holds the rate).

        Returns:
            An invalid result (severity "error", penalty 0.5) when the relative
            difference exceeds ``tolerance_percent``; a valid result otherwise,
            including every case where the check cannot be applied.
        """
        total = data.get("amount")
        tva = data.get("tva")

        if not total or not tva:
            return ValidationResult(
                is_valid=True,
                message="Insufficient data for TVA-based total validation"
            )

        # Type safety
        try:
            total = float(total)
            tva = float(tva)
        except (TypeError, ValueError):
            return ValidationResult(
                is_valid=True,
                message="Non-numeric values, skipping TVA-based total validation"
            )

        if tva <= 0 or total <= 0:
            return ValidationResult(
                is_valid=True,
                message="Zero or negative values, skipping TVA-based total validation"
            )

        # Prefer the rate declared on the receipt itself.
        declared_rates = self._rates_from_entries(data.get("tva_entries", []))
        if len(declared_rates) > 1:
            # The single-rate formula below does not hold for mixed rates.
            return ValidationResult(
                is_valid=True,
                message="Multiple TVA rates present, skipping TVA-based total validation"
            )
        tva_rate = declared_rates[0] if declared_rates else None

        # Fallback: derive the rate from the TVA/TOTAL ratio.
        if tva_rate is None:
            # TVA = BASE * rate, TOTAL = BASE + TVA  =>  rate = TVA / (TOTAL - TVA)
            base = total - tva
            if base > 0:
                calculated_rate = tva / base
                # Accept only a plausible Romanian TVA rate (5%, 9%, 19%, 21%).
                # A rate derived from TOTAL reproduces TOTAL in the formula
                # below, so this branch only confirms the ratio is plausible.
                if 0.04 <= calculated_rate <= 0.25:
                    tva_rate = calculated_rate

        if tva_rate is None:
            # Assume most common rate: 21%
            tva_rate = 0.21

        # TVA = BASE * rate -> BASE = TVA / rate
        # TOTAL = BASE + TVA = TVA * (1 + rate) / rate
        expected_total = tva * (1 + tva_rate) / tva_rate

        diff = abs(total - expected_total)
        diff_percent = diff / expected_total if expected_total > 0 else 1.0

        if diff_percent > self.tolerance_percent:
            # Significant mismatch - OCR likely extracted TOTAL wrong
            return ValidationResult(
                is_valid=False,
                confidence_penalty=0.5,  # High penalty - this is a critical error
                message=(
                    f"TOTAL mismatch: Extracted {total:.2f} RON vs "
                    f"TVA-calculated {expected_total:.2f} RON "
                    f"(TVA={tva:.2f}, rate={tva_rate:.0%}, diff={diff_percent:.1%}). "
                    f"Likely OCR error on TOTAL."
                ),
                severity="error"
            )

        return ValidationResult(
            is_valid=True,
            message=f"TOTAL {total:.2f} matches TVA-calculated {expected_total:.2f} (diff: {diff_percent:.1%})"
        )
class InterOCRConsistencyRule(ValidationRule):
"""Validate consistency between multiple OCR results.
@@ -562,6 +741,7 @@ class OCRValidationEngine:
TVARatioRule(min_ratio=0.05, max_ratio=0.24),
PaymentSumRule(tolerance=0.02),
TVAEntriesSumRule(tolerance=0.02),
TVABasedTotalRule(tolerance_percent=0.02), # Critical: detect TOTAL errors via TVA
]
# Inter-OCR consistency rules
@@ -699,39 +879,508 @@ class OCRValidationEngine:
inter_ocr_ratios=inter_ocr_ratios
)
def quick_validate_for_hybrid(self, extraction_result: dict[str, Any]) -> tuple[bool, float, list[str]]:
    """Run the critical rule subset used for early-exit decisions (e.g., doctr_plus Tier 1).

    Every rule in the subset is evaluated against the extraction; the
    messages and confidence penalties of the failing ones are collected so
    the caller can decide whether to exit early or continue to the next
    processing tier.

    Args:
        extraction_result: Extraction data dict with fields:
            - amount: Extracted TOTAL
            - tva: Extracted TVA total
            - tva_entries: List of TVA entries with rates

    Returns:
        Tuple of (passes_validation, confidence_penalty, error_messages)
        - passes_validation: True if no rule failed
        - confidence_penalty: Sum of the failing rules' penalties, capped at 1.0
        - error_messages: Messages of the failing rules, in rule order

    Example usage:
        passes, penalty, errors = validation_engine.quick_validate_for_hybrid(extraction_data)
        if not passes:
            print(f"Validation failed: {errors}, proceeding to next tier")
            # Continue to next processing tier instead of early exit
    """
    gate_rules = (
        # Cross-field validations
        TVABasedTotalRule(tolerance_percent=0.02),      # TOTAL vs TOTAL implied by TVA
        PaymentSumRule(tolerance=0.05),                 # TOTAL vs CARD+CASH payments
        TVARatioRule(min_ratio=0.05, max_ratio=0.24),   # TVA share of TOTAL
        TVAEntriesSumRule(tolerance=0.05),              # TVA entries vs TVA total
        # Format & checksum validations
        CUIChecksumRule(),
        CUIFormatRule(),
        # Sanity checks
        AmountRangeRule(min_amount=0.01, max_amount=100_000.0),
    )

    failure_messages = []
    accumulated_penalty = 0.0
    for rule in gate_rules:
        outcome = rule.validate(extraction_result)
        if outcome.is_valid:
            continue
        failure_messages.append(outcome.message)
        accumulated_penalty += outcome.confidence_penalty

    # The combined penalty never exceeds 1.0.
    capped_penalty = min(1.0, accumulated_penalty)
    return not failure_messages, capped_penalty, failure_messages
# NOTE: _calculate_cui_checksum and _is_cui_checksum_valid removed
# Use CUIChecksumRule.calculate_checksum() and CUIChecksumRule.validate_checksum() instead
@staticmethod
def _repair_cui_checksum(cui_digits: str) -> Optional[str]:
    """Try to repair a CUI by changing exactly one digit.

    OCR often misreads similar-looking digits:
    - 5 ↔ 8 (most common in receipts)
    - 6 ↔ 0
    - 1 ↔ 7
    - 3 ↔ 8

    Algorithm:
    1. Visit positions in this order: the middle positions (2 .. n-2),
       then position 1, then the checksum digit (last position), and
       finally position 0.
    2. At each position try the common OCR confusions for the current
       digit first, then every other digit.
    3. Return the first candidate whose checksum validates.

    Args:
        cui_digits: Original CUI digits (without RO prefix)

    Returns:
        The input unchanged if its checksum is already valid, the repaired
        digits if a 1-digit fix is found, else None (also for inputs that
        are shorter than 6 characters or not ASCII digits).
    """
    # ASCII check: str.isdigit() also accepts characters such as "²" that
    # the checksum arithmetic cannot convert with int().
    if len(cui_digits) < 6 or not (cui_digits.isascii() and cui_digits.isdigit()):
        return None

    # If already valid, return as-is
    if CUIChecksumRule.validate_checksum(cui_digits):
        return cui_digits

    # Common OCR digit confusions (tried before the remaining digits)
    confusion_pairs = {
        '5': ['8', '6'],
        '8': ['5', '3', '0'],
        '6': ['0', '8'],
        '0': ['6', '8'],
        '1': ['7', '4'],
        '7': ['1'],
        '3': ['8'],
        '4': ['1'],
        '2': ['7'],
        '9': ['0'],
    }

    last_pos = len(cui_digits) - 1  # checksum position
    # Example for an 8-digit CUI: [2, 3, 4, 5, 6, 1, 7 (checksum), 0]
    position_order = [*range(2, last_pos), 1, last_pos, 0]

    for pos in position_order:
        original_digit = cui_digits[pos]
        preferred = confusion_pairs.get(original_digit, [])
        remaining = [d for d in '0123456789' if d != original_digit and d not in preferred]

        for replacement in preferred + remaining:
            candidate = cui_digits[:pos] + replacement + cui_digits[pos+1:]
            if CUIChecksumRule.validate_checksum(candidate):
                print(
                    f"[CUI Repair] Fixed {cui_digits} → {candidate} "
                    f"(position {pos}: {original_digit} → {replacement})",
                    flush=True,
                )
                return candidate

    # No single-digit fix found
    return None
@staticmethod
def normalize_cui(cui: Optional[str]) -> Optional[str]:
    """Normalize CUI - fix OCR errors but preserve original format.

    Rules:
    - R0 → RO (fix OCR error where O is read as 0)
    - Keep RO prefix if original had it (platitor TVA)
    - Do NOT add RO if original didn't have it (neplatitor TVA)
    - Try to repair 1-digit checksum errors (OCR mistakes like 5↔8)

    Examples:
        45417955 → 45417955 (no prefix = neplatitor TVA, keep as-is)
        R010562600 → RO10562600 (fix R0 OCR error)
        RO10562600 → RO10562600 (unchanged)
        RO10862600 → RO10562600 (repaired: 8→5 at position 2)

    Args:
        cui: Raw CUI string from OCR

    Returns:
        Normalized CUI, or None if the input is empty or does not contain
        between 6 and 10 digits. When the checksum is invalid and no
        1-digit repair is found, the unrepaired digits are returned.
    """
    if not cui:
        return None

    cui = cui.strip().upper()

    # "R0" is the OCR misreading of "RO"; either one counts as the prefix.
    had_ro_prefix = cui.startswith(("RO", "R0"))
    cui_digits = cui[2:] if had_ro_prefix else cui

    # Keep ASCII digits only. str.isdigit() alone would also keep characters
    # such as "²", which the checksum arithmetic cannot convert with int().
    cui_digits = ''.join(c for c in cui_digits if c.isascii() and c.isdigit())

    # Validate length
    if len(cui_digits) < 6 or len(cui_digits) > 10:
        print(f"[CUI Normalize] Invalid length: {len(cui_digits)} digits (expected 6-10)", flush=True)
        return None

    # Try to repair checksum if invalid
    if not CUIChecksumRule.validate_checksum(cui_digits):
        repaired = OCRValidationEngine._repair_cui_checksum(cui_digits)
        if repaired:
            cui_digits = repaired

    # Return with RO prefix only if original had it
    return f"RO{cui_digits}" if had_ro_prefix else cui_digits
@staticmethod
async def fuzzy_match_cui_from_db(
    cui: Optional[str],
    db_session
) -> Optional[tuple[str, str]]:
    """Fuzzy match CUI against database of known suppliers.

    This function:
    1. Validates CUI checksum
    2. If valid, looks up in database (exact match) and returns the match,
       or None when the CUI is not in the database
    3. If invalid, tries 1-digit corrections and looks up each candidate
       whose checksum is valid
    4. Returns the first match found in database

    Args:
        cui: Extracted CUI from OCR (may be invalid)
        db_session: SQLAlchemy async session for database lookups

    Returns:
        Tuple of (corrected_cui, supplier_name) if found, else None. The
        returned CUI carries the RO prefix only if the input had RO/R0.

    Usage in OCR extraction:
        from backend.modules.data_entry.services.ocr.validation import OCRValidationEngine

        match = await OCRValidationEngine.fuzzy_match_cui_from_db(extracted_cui, session)
        if match:
            corrected_cui, supplier_name = match
    """
    from sqlalchemy import select, or_
    from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier

    if not cui:
        return None

    cui = cui.strip().upper()

    # "R0" is the OCR misreading of "RO"; either one counts as the prefix.
    had_ro_prefix = cui.startswith(("RO", "R0"))
    cui_digits = cui[2:] if had_ro_prefix else cui

    # Keep ASCII digits only. str.isdigit() alone would also keep characters
    # such as "²", which the checksum arithmetic cannot convert with int().
    cui_digits = ''.join(c for c in cui_digits if c.isascii() and c.isdigit())

    # Validate length
    if len(cui_digits) < 6 or len(cui_digits) > 10:
        return None

    def format_cui(digits: str) -> str:
        """Re-attach the RO prefix when the OCR input carried one."""
        if had_ro_prefix:
            return f"RO{digits}"
        return digits

    async def lookup_cui_in_db(digits: str) -> Optional[tuple[str, str]]:
        """Search both synced and local suppliers for CUI."""
        # Stored codes may or may not carry the RO prefix or leading zeros.
        patterns = [digits, f"RO{digits}"]
        without_leading_zeros = digits.lstrip('0')
        # An all-zero CUI strips to "", which would match suppliers whose
        # fiscal code is empty; only add the variant when something is left.
        if without_leading_zeros and without_leading_zeros != digits:
            patterns.append(without_leading_zeros)

        # synced_suppliers first, then local_suppliers
        for model in (SyncedSupplier, LocalSupplier):
            stmt = select(model.fiscal_code, model.name).where(
                or_(*(model.fiscal_code == pattern for pattern in patterns))
            ).limit(1)
            result = await db_session.execute(stmt)
            row = result.first()
            if row:
                return (format_cui(digits), row.name)

        return None

    # 1. If checksum is valid, check if it exists in database (exact match)
    if CUIChecksumRule.validate_checksum(cui_digits):
        match = await lookup_cui_in_db(cui_digits)
        if match:
            print(f"[Fuzzy CUI] Exact match found: {cui} → {match[0]} ({match[1]})", flush=True)
            return match
        # Valid checksum but not in DB (it might be a new supplier): no match.
        return None

    # 2. Invalid checksum - try 1-digit corrections and verify against database
    print(f"[Fuzzy CUI] Invalid checksum for {cui}, trying corrections...", flush=True)

    # Common OCR digit confusions (tried before the remaining digits)
    confusion_pairs = {
        '5': ['8', '6'],
        '8': ['5', '3', '0'],
        '6': ['0', '8'],
        '0': ['6', '8'],
        '1': ['7', '4'],
        '7': ['1'],
        '3': ['8'],
        '4': ['1'],
        '2': ['7'],
        '9': ['0'],
    }

    last_pos = len(cui_digits) - 1  # checksum position
    # Position check order: middle positions first, then ends
    position_order = [*range(2, last_pos), 1, last_pos, 0]

    for pos in position_order:
        original_digit = cui_digits[pos]
        preferred = confusion_pairs.get(original_digit, [])
        remaining = [d for d in '0123456789' if d != original_digit and d not in preferred]

        for replacement in preferred + remaining:
            candidate = cui_digits[:pos] + replacement + cui_digits[pos+1:]

            # Only candidates with a valid checksum are worth a DB round-trip
            if not CUIChecksumRule.validate_checksum(candidate):
                continue

            match = await lookup_cui_in_db(candidate)
            if match:
                print(
                    f"[Fuzzy CUI] DB match: {cui} → {match[0]} ({match[1]}) "
                    f"[pos {pos}: {original_digit} → {replacement}]",
                    flush=True,
                )
                return match

    # No match found in database
    print(f"[Fuzzy CUI] No database match found for {cui}", flush=True)
    return None
@staticmethod
async def fuzzy_match_by_name_and_cui(
    vendor_name: Optional[str],
    cui: Optional[str],
    db_session
) -> Optional[tuple[str, str]]:
    """Fuzzy match supplier by NAME, then narrow down by CUI.

    Algorithm:
    1. Normalize vendor name (remove S.R.L., S.A., punctuation, etc.)
    2. Search suppliers whose name contains the first keyword (LIKE %word%)
    3. If multiple results, pick the one whose CUI digits best match the
       extracted CUI (at least 70% of positions), else the one sharing at
       least two keywords, else the first result
    4. Return the best match

    Args:
        vendor_name: Extracted vendor name from OCR
        cui: Extracted CUI from OCR (may be invalid/incomplete)
        db_session: SQLAlchemy async session

    Returns:
        Tuple of (matched_cui, supplier_name) if found, else None
    """
    from sqlalchemy import select, func
    from backend.modules.data_entry.db.models.nomenclature import SyncedSupplier, LocalSupplier
    import re

    if not vendor_name or len(vendor_name) < 3:
        return None

    # Company-type abbreviations, compared with their dots removed.
    legal_forms = frozenset({'SRL', 'SA', 'SC', 'IF', 'PFA'})

    def normalize_name(name: str) -> str:
        """Normalize name for fuzzy matching."""
        # Punctuation other than dots becomes a separator; dots are handled
        # per token so that "S.R.L." is recognised as one abbreviation.
        tokens = re.sub(r'[,\-_/\\()"\']', ' ', name.upper()).split()
        kept = []
        for token in tokens:
            # Drop only whole tokens: removing "SA"/"SC"/"IF" as substrings
            # would corrupt names such as "SAMSUNG" or "DISCOUNT".
            if token.replace('.', '') in legal_forms:
                continue
            kept.extend(token.replace('.', ' ').split())
        return ' '.join(kept)

    # Extract key words from vendor name (for fuzzy search)
    normalized_name = normalize_name(vendor_name)
    name_words = [w for w in normalized_name.split() if len(w) >= 3]

    if not name_words:
        return None

    print(f"[Fuzzy Name] Searching for vendor: '{vendor_name}' → keywords: {name_words}", flush=True)

    # Build search pattern - use first significant word
    primary_word = name_words[0]
    search_pattern = f"%{primary_word}%"

    # synced_suppliers first, then local_suppliers; rows without a fiscal
    # code cannot be returned as a CUI match and are skipped.
    candidates = []
    for model in (SyncedSupplier, LocalSupplier):
        stmt = select(model.fiscal_code, model.name).where(
            func.upper(model.name).like(search_pattern)
        ).limit(20)
        result = await db_session.execute(stmt)
        candidates.extend((row.fiscal_code, row.name) for row in result if row.fiscal_code)

    if not candidates:
        print(f"[Fuzzy Name] No name matches found for '{primary_word}'", flush=True)
        return None

    print(f"[Fuzzy Name] Found {len(candidates)} name matches for '{primary_word}'", flush=True)

    # If only one candidate, return it
    if len(candidates) == 1:
        print(f"[Fuzzy Name] Single match: {candidates[0][0]} ({candidates[0][1]})", flush=True)
        return candidates[0]

    # Multiple candidates - try to narrow down by CUI
    if cui:
        cui_digits = ''.join(c for c in cui.upper().replace('RO', '').replace('R0', '') if c.isdigit())

        if len(cui_digits) >= 6:
            def cui_similarity(candidate_cui: str) -> int:
                """Calculate how many digits match in the same position."""
                cand_digits = ''.join(c for c in candidate_cui.upper().replace('RO', '') if c.isdigit())
                # Codes of different length cannot be compared position by position.
                if len(cand_digits) != len(cui_digits):
                    return 0
                return sum(1 for a, b in zip(cand_digits, cui_digits) if a == b)

            # Sort candidates by CUI similarity (descending)
            scored = [(cui_similarity(c[0]), c) for c in candidates]
            scored.sort(key=lambda x: x[0], reverse=True)

            best_score, best_match = scored[0]
            # Require at least 70% digit match for CUI similarity
            min_matching = int(len(cui_digits) * 0.7)

            if best_score >= min_matching:
                print(f"[Fuzzy Name] Best CUI match: {best_match[0]} ({best_match[1]}) - score {best_score}/{len(cui_digits)}", flush=True)
                return best_match

            print(f"[Fuzzy Name] No strong CUI match (best score: {best_score}/{len(cui_digits)})", flush=True)

    # If still multiple and no CUI match, try name similarity
    def name_similarity(candidate_name: str) -> int:
        """Count how many keywords match."""
        norm_cand = normalize_name(candidate_name)
        return sum(1 for w in name_words if w in norm_cand)

    scored = [(name_similarity(c[1]), c) for c in candidates]
    scored.sort(key=lambda x: x[0], reverse=True)

    if scored[0][0] >= 2:  # At least 2 keywords match
        best_match = scored[0][1]
        print(f"[Fuzzy Name] Best name match: {best_match[0]} ({best_match[1]})", flush=True)
        return best_match

    # Return first candidate if nothing else works
    print(f"[Fuzzy Name] Returning first candidate: {candidates[0][0]} ({candidates[0][1]})", flush=True)
    return candidates[0]
@staticmethod
async def fuzzy_match_supplier(
    cui: Optional[str],
    vendor_name: Optional[str],
    db_session
) -> Optional[tuple[str, str]]:
    """Resolve a supplier by CUI first, falling back to NAME + CUI.

    Strategy:
    1. Fuzzy CUI matching via ``fuzzy_match_cui_from_db``
    2. If that yields nothing, fuzzy name matching via
       ``fuzzy_match_by_name_and_cui``

    Args:
        cui: Extracted CUI from OCR (may be invalid/incomplete)
        vendor_name: Extracted vendor name from OCR
        db_session: SQLAlchemy async session

    Returns:
        Tuple of (matched_cui, supplier_name) if found, else None
    """
    engine = OCRValidationEngine

    # The CUI lookup takes precedence over the name lookup.
    by_cui = await engine.fuzzy_match_cui_from_db(cui, db_session)
    if by_cui:
        return by_cui

    by_name = await engine.fuzzy_match_by_name_and_cui(
        vendor_name, cui, db_session
    )
    return by_name if by_name else None