126 lines
4.3 KiB
Python
126 lines
4.3 KiB
Python
"""Pydantic schema for the M2D vision-extraction JSON returned by the vision subagent."""
|
|
from __future__ import annotations
|
|
|
|
import re
|
|
from datetime import date as date_type, datetime, timezone
|
|
from typing import Literal
|
|
|
|
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
|
|
|
|
|
# Shape checks for ``data`` ("YYYY-MM-DD") and ``ora_utc`` ("HH:MM").
# ``[0-9]`` instead of ``\d``: in ``str`` patterns ``\d`` matches any Unicode
# decimal digit (e.g. Arabic-Indic "٣"), which must not leak into stored values.
# ``\A``/``\Z`` instead of ``^``/``$``: ``$`` also matches just before a
# trailing newline, so "2024-01-15\n" would pass a ``match`` check.
_DATA_PATTERN = re.compile(r"\A[0-9]{4}-[0-9]{2}-[0-9]{2}\Z")
_ORA_PATTERN = re.compile(r"\A[0-9]{2}:[0-9]{2}\Z")
|
|
|
|
|
|
class M2DExtraction(BaseModel):
    """One M2D trade setup as extracted by the vision subagent from a screenshot.

    Field names are largely Romanian: ``data`` = date, ``ora_utc`` = time (UTC),
    ``directie`` = direction, ``tf_mare``/``tf_mic`` = larger/smaller timeframe,
    ``calitate`` = quality, ``risc_pct`` = risk percent.

    Beyond the per-field types, the ``mode="after"`` model validators enforce:

    * ``data`` is a real ``YYYY-MM-DD`` date not later than today (UTC), and
      ``ora_utc`` is a real ``HH:MM`` time.
    * ``entry`` differs from ``sl``, and prices are strictly ordered for the
      direction (Buy: ``sl < entry < tp0 < tp1 < tp2``; Sell: the reverse).
    * ``outcome_path`` and ``max_reached`` do not contradict each other.

    Unknown keys and NaN/infinite float values are rejected.
    """

    # extra="forbid": unexpected keys raise instead of being silently dropped.
    # allow_inf_nan=False: without it, e.g. sl=-inf on a Buy would satisfy the
    # strict price-ordering check below.
    model_config = ConfigDict(extra="forbid", allow_inf_nan=False)

    screenshot_file: str
    data: str  # trade date, "YYYY-MM-DD"
    ora_utc: str  # trade time, "HH:MM"
    instrument: Literal["DIA", "US30", "other"]
    directie: Literal["Buy", "Sell"]
    tf_mare: Literal["5min", "15min"]
    tf_mic: Literal["1min", "3min"]
    calitate: Literal["Clară", "Mai mare ca impuls", "Slabă", "n/a"]
    entry: float
    sl: float
    tp0: float
    tp1: float
    tp2: float
    risc_pct: float
    outcome_path: Literal[
        "SL", "TP0→SL", "TP0→TP1", "TP0→TP2", "TP0→pending", "pending"
    ]
    max_reached: Literal["SL_first", "TP0", "TP1", "TP2"]
    be_moved: bool
    confidence: Literal["high", "medium", "low"]
    ambiguities: list[str] = Field(default_factory=list)
    note: str = ""

    @model_validator(mode="after")
    def _validate_data_format(self) -> "M2DExtraction":
        """Require ``data`` to be a valid YYYY-MM-DD date that is not in the future (UTC).

        Raises:
            ValueError: if the shape is wrong, the date does not exist, or it is
                later than the current UTC date.
        """
        # fullmatch: the whole string must match, so no trailing newline slips in.
        if not _DATA_PATTERN.fullmatch(self.data):
            raise ValueError(
                f"data must match YYYY-MM-DD, got {self.data!r}"
            )
        # The regex checks only the shape; fromisoformat rejects e.g. 2024-02-30.
        try:
            parsed = date_type.fromisoformat(self.data)
        except ValueError as exc:
            raise ValueError(f"data is not a valid ISO date: {self.data!r}") from exc
        today = datetime.now(timezone.utc).date()
        if parsed > today:
            raise ValueError(
                f"data {self.data!r} is in the future (today UTC: {today.isoformat()})"
            )
        return self

    @model_validator(mode="after")
    def _validate_ora_utc_format(self) -> "M2DExtraction":
        """Require ``ora_utc`` to be a valid two-digit ``HH:MM`` 24-hour time.

        Raises:
            ValueError: if the shape is wrong or the time does not exist (e.g. 25:00).
        """
        if not _ORA_PATTERN.fullmatch(self.ora_utc):
            raise ValueError(
                f"ora_utc must match HH:MM, got {self.ora_utc!r}"
            )
        # The regex checks only the shape; strptime rejects out-of-range values.
        try:
            datetime.strptime(self.ora_utc, "%H:%M")
        except ValueError as exc:
            raise ValueError(
                f"ora_utc is not a valid HH:MM time: {self.ora_utc!r}"
            ) from exc
        return self

    @model_validator(mode="after")
    def _validate_entry_ne_sl(self) -> "M2DExtraction":
        """Reject ``entry == sl``, i.e. a setup with zero risk distance."""
        if self.entry == self.sl:
            raise ValueError("entry must not equal sl (zero risk distance)")
        return self

    @model_validator(mode="after")
    def _validate_tp_ordering(self) -> "M2DExtraction":
        """Require prices to be strictly ordered in the trade's direction.

        Buy: ``sl < entry < tp0 < tp1 < tp2``; Sell: ``sl > entry > tp0 > tp1 > tp2``.
        """
        if self.directie == "Buy":
            if not (self.sl < self.entry < self.tp0 < self.tp1 < self.tp2):
                raise ValueError(
                    "for Buy, required: sl < entry < tp0 < tp1 < tp2 "
                    f"(got sl={self.sl}, entry={self.entry}, tp0={self.tp0}, "
                    f"tp1={self.tp1}, tp2={self.tp2})"
                )
        else:
            # directie is a Buy/Sell Literal, so this branch is "Sell".
            if not (self.sl > self.entry > self.tp0 > self.tp1 > self.tp2):
                raise ValueError(
                    "for Sell, required: sl > entry > tp0 > tp1 > tp2 "
                    f"(got sl={self.sl}, entry={self.entry}, tp0={self.tp0}, "
                    f"tp1={self.tp1}, tp2={self.tp2})"
                )
        return self

    @model_validator(mode="after")
    def _validate_outcome_max_consistency(self) -> "M2DExtraction":
        """Check that ``max_reached`` is compatible with ``outcome_path``.

        * ``"SL"`` requires ``max_reached == "SL_first"``.
        * A path starting with ``"TP0"`` requires a TP level. When the path ends
          in a TP, ``max_reached`` must be at least that level: ``"TP0→TP1"``
          needs TP1 or TP2, and ``"TP0→TP2"`` needs TP2.
        * ``"pending"`` accepts any ``max_reached``.
        """
        op = self.outcome_path
        mr = self.max_reached
        tp_rank = {"TP0": 0, "TP1": 1, "TP2": 2}
        if op == "SL":
            if mr != "SL_first":
                raise ValueError(
                    f"outcome_path='SL' requires max_reached='SL_first', got {mr!r}"
                )
        elif op.startswith("TP0"):
            if mr not in tp_rank:
                raise ValueError(
                    f"outcome_path={op!r} requires max_reached in "
                    f"{{TP0, TP1, TP2}}, got {mr!r}"
                )
            # Final leg of the path: "SL", "pending", or a TP level. A path
            # ending in a TP records that level as reached, so a lower
            # max_reached is contradictory.
            last_leg = op.rpartition("→")[2]
            if last_leg in tp_rank and tp_rank[mr] < tp_rank[last_leg]:
                raise ValueError(
                    f"outcome_path={op!r} requires max_reached at or beyond "
                    f"{last_leg}, got {mr!r}"
                )
        # op == "pending" → any max_reached accepted
        return self
|
|
|
|
|
|
def parse_extraction(json_str: str) -> M2DExtraction:
    """Build an M2DExtraction from raw JSON text.

    Decoding and validation happen in one pydantic call, so every field type
    and model validator on M2DExtraction is applied.

    Raises pydantic.ValidationError on invalid input.
    """
    extraction = M2DExtraction.model_validate_json(json_str)
    return extraction
|
|
|
|
|
|
def parse_extraction_dict(d: dict) -> M2DExtraction:
    """Check an already-decoded dict against the M2DExtraction schema.

    Returns the validated model. pydantic.ValidationError propagates when the
    data does not satisfy the schema.
    """
    validated = M2DExtraction.model_validate(d)
    return validated
|