"""Append a validated M2D extraction to ``data/trades.csv``.

Pipeline:

    JSON file --> pydantic validate (M2DExtraction)
              --> load data/_meta.yaml (versions + schema)
              --> compute ora_ro, zi, set, pl_marius, pl_theoretical
              --> dedup on (screenshot_file, source)
              --> atomic CSV write (temp file + os.replace)

Source values:

- ``manual`` : Marius logged by hand
- ``vision`` : produced by the vision subagent
- ``manual_calibration`` : calibration P4 — manual leg
- ``vision_calibration`` : calibration P4 — vision leg

A row with ``source=manual_calibration`` and a row with ``source=vision_calibration``
for the *same* screenshot are allowed to coexist (different dedup keys); a
duplicate ``(screenshot_file, source)`` pair is rejected (or skipped — see
``append_row`` ``on_duplicate`` argument).
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import csv
|
|
import json
|
|
import os
|
|
import tempfile
|
|
from pathlib import Path
|
|
from typing import Any, Literal
|
|
|
|
import yaml
|
|
|
|
from scripts.calendar_parse import calc_set, load_calendar, utc_to_ro
|
|
from scripts.pl_calc import pl_marius, pl_theoretical
|
|
from scripts.vision_schema import M2DExtraction, parse_extraction_dict
|
|
|
|
# Public API surface of this module (consumed by ``from ... import *``).
__all__ = [
    "CSV_COLUMNS",      # canonical column order for data/trades.csv
    "VALID_SOURCES",    # allowed values for the ``source`` column
    "build_row",        # extraction -> full CSV row dict
    "read_rows",        # read existing CSV rows (or [])
    "append_row",       # validated append with dedup + atomic write
    "append_row_from_json",  # convenience: JSON file -> append_row
]
|
|
|
|
|
|
# Static provenance tag for a CSV row (type-checker view of ``VALID_SOURCES``).
Source = Literal["manual", "vision", "manual_calibration", "vision_calibration"]

# Runtime counterpart of ``Source``; ``build_row`` validates against this set.
VALID_SOURCES: frozenset[str] = frozenset(
    ("manual", "vision", "manual_calibration", "vision_calibration")
)
|
|
|
|
|
|
# Canonical column order for ``data/trades.csv``.  ``_atomic_write`` emits
# exactly these fields, in exactly this order; do not reorder without bumping
# ``csv_schema_version`` in data/_meta.yaml.
CSV_COLUMNS: tuple[str, ...] = (
    # identity / provenance
    "screenshot_file", "source",
    # timing (UTC input plus derived Romania-local fields)
    "data", "ora_utc", "ora_ro", "zi", "set",
    # setup description
    "instrument", "directie", "tf_mare", "tf_mic", "calitate",
    # price levels & risk
    "entry", "sl", "tp0", "tp1", "tp2", "risc_pct",
    # trade outcome
    "outcome_path", "max_reached", "be_moved",
    # extraction quality
    "confidence", "ambiguities", "note",
    # computed P/L
    "pl_marius", "pl_theoretical",
    # versions copied from data/_meta.yaml
    "indicator_version", "pl_overlay_version", "csv_schema_version",
)
|
|
|
|
|
|
def _load_meta(meta_path: Path) -> dict[str, Any]:
    """Load ``_meta.yaml`` and verify the three version keys are present.

    Raises ``ValueError`` naming every missing key; an empty/blank YAML file
    is treated as ``{}`` (and therefore fails the key check).
    """
    with meta_path.open("r", encoding="utf-8") as fh:
        loaded = yaml.safe_load(fh)
    meta: dict[str, Any] = loaded or {}
    missing = [
        key
        for key in ("indicator_version", "pl_overlay_version", "csv_schema_version")
        if key not in meta
    ]
    if missing:
        raise ValueError(f"_meta.yaml missing required keys: {missing}")
    return meta
|
|
|
|
|
|
def _format_optional(value: float | None) -> str:
|
|
return "" if value is None else f"{value:.4f}"
|
|
|
|
|
|
def build_row(
    extraction: M2DExtraction,
    source: str,
    meta: dict[str, Any],
    calendar: list[dict[str, Any]],
) -> dict[str, str]:
    """Compute the full CSV row dict for one extraction.

    Args:
        extraction: validated M2D extraction (pydantic model).
        source: provenance tag; must be a member of ``VALID_SOURCES``.
        meta: mapping from ``_load_meta`` carrying the three version keys.
        calendar: parsed calendar consumed by ``calc_set``.

    Returns:
        Mapping of every ``CSV_COLUMNS`` name to its string value.

    Raises:
        ValueError: if *source* is not in ``VALID_SOURCES``.
    """
    if source not in VALID_SOURCES:
        raise ValueError(
            f"invalid source {source!r}; must be one of {sorted(VALID_SOURCES)}"
        )

    # Derived fields: Romania-local date/time + weekday, session set, P/L.
    d_ro, t_ro, zi = utc_to_ro(extraction.data, extraction.ora_utc)
    set_label = calc_set(d_ro, t_ro, zi, calendar)
    pl_m = pl_marius(extraction.outcome_path, extraction.be_moved)
    pl_t = pl_theoretical(extraction.max_reached)

    # Fix: coerce every value to ``str`` so the result actually matches the
    # declared ``dict[str, str]`` return type.  Previously ``zi``, the set
    # label, ``confidence`` and friends were passed through untouched, so a
    # freshly built row could carry different value types than the same row
    # read back via ``read_rows`` (which always yields strings) or returned
    # by ``append_row(on_duplicate="skip")``.  ``str()`` is the identity for
    # values that were already strings, and the CSV output is unchanged
    # because ``csv`` stringifies on write anyway.
    return {
        "screenshot_file": str(extraction.screenshot_file),
        "source": source,
        "data": str(extraction.data),
        "ora_utc": str(extraction.ora_utc),
        "ora_ro": t_ro.strftime("%H:%M"),
        "zi": str(zi),
        "set": str(set_label),
        "instrument": str(extraction.instrument),
        "directie": str(extraction.directie),
        "tf_mare": str(extraction.tf_mare),
        "tf_mic": str(extraction.tf_mic),
        "calitate": str(extraction.calitate),
        "entry": f"{extraction.entry}",
        "sl": f"{extraction.sl}",
        "tp0": f"{extraction.tp0}",
        "tp1": f"{extraction.tp1}",
        "tp2": f"{extraction.tp2}",
        "risc_pct": f"{extraction.risc_pct}",
        "outcome_path": str(extraction.outcome_path),
        "max_reached": str(extraction.max_reached),
        "be_moved": "true" if extraction.be_moved else "false",
        "confidence": str(extraction.confidence),
        "ambiguities": json.dumps(extraction.ambiguities, ensure_ascii=False),
        "note": str(extraction.note),
        "pl_marius": _format_optional(pl_m),
        "pl_theoretical": _format_optional(pl_t),
        "indicator_version": str(meta["indicator_version"]),
        "pl_overlay_version": str(meta["pl_overlay_version"]),
        "csv_schema_version": str(meta["csv_schema_version"]),
    }
|
|
|
|
|
|
def read_rows(csv_path: Path) -> list[dict[str, str]]:
    """Return the rows currently stored in *csv_path*.

    A missing or zero-byte file yields ``[]`` so callers can treat
    "no CSV yet" and "empty CSV" uniformly.
    """
    if not csv_path.exists() or csv_path.stat().st_size == 0:
        return []
    with csv_path.open("r", encoding="utf-8", newline="") as fh:
        return [dict(record) for record in csv.DictReader(fh)]
|
|
|
|
|
|
def _atomic_write(csv_path: Path, rows: list[dict[str, str]]) -> None:
    """Rewrite the whole CSV atomically (temp file in-dir, then os.replace).

    Writing the temp file into the destination directory keeps the final
    ``os.replace`` on one filesystem, so readers always see either the old
    or the new complete file.  Unknown keys in a row are dropped; missing
    keys are written as empty strings.
    """
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp_name = tempfile.mkstemp(
        prefix=csv_path.name + ".", suffix=".tmp", dir=str(csv_path.parent)
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8", newline="") as fh:
            writer = csv.DictWriter(fh, fieldnames=list(CSV_COLUMNS))
            writer.writeheader()
            writer.writerows(
                {col: row.get(col, "") for col in CSV_COLUMNS} for row in rows
            )
        os.replace(tmp_name, csv_path)
    except Exception:
        # Best-effort cleanup of the orphaned temp file; re-raise the cause.
        try:
            os.unlink(tmp_name)
        except OSError:
            pass
        raise
|
|
|
|
|
|
def append_row(
    extraction: M2DExtraction,
    source: str,
    csv_path: Path,
    meta_path: Path,
    calendar_path: Path,
    on_duplicate: Literal["raise", "skip"] = "raise",
) -> dict[str, str]:
    """Append one extraction to the CSV.

    Dedup key: ``(screenshot_file, source)``. If a row with the same key
    already exists, behaviour is controlled by ``on_duplicate``:

    - ``"raise"`` (default): raise ``ValueError``.
    - ``"skip"``: leave the CSV untouched and return the *existing* row.
    """
    meta = _load_meta(meta_path)
    calendar = load_calendar(calendar_path)
    new_row = build_row(extraction, source, meta, calendar)

    current = read_rows(csv_path)
    dedup_key = (new_row["screenshot_file"], new_row["source"])
    clash = next(
        (
            r
            for r in current
            if (r.get("screenshot_file"), r.get("source")) == dedup_key
        ),
        None,
    )
    if clash is not None:
        if on_duplicate == "skip":
            return clash
        raise ValueError(
            f"duplicate row: screenshot_file={dedup_key[0]!r} source={dedup_key[1]!r} "
            f"already exists in {csv_path}"
        )

    current.append(new_row)
    _atomic_write(csv_path, current)
    return new_row
|
|
|
|
|
|
def append_row_from_json(
    json_path: Path,
    source: str,
    csv_path: Path,
    meta_path: Path,
    calendar_path: Path,
    on_duplicate: Literal["raise", "skip"] = "raise",
) -> dict[str, str]:
    """Convenience wrapper: load JSON, validate, append."""
    payload = json.loads(Path(json_path).read_text(encoding="utf-8"))
    return append_row(
        extraction=parse_extraction_dict(payload),
        source=source,
        csv_path=csv_path,
        meta_path=meta_path,
        calendar_path=calendar_path,
        on_duplicate=on_duplicate,
    )
|