diff --git a/samples/calibration_labels.README.md b/samples/calibration_labels.README.md new file mode 100644 index 0000000..216b9d0 --- /dev/null +++ b/samples/calibration_labels.README.md @@ -0,0 +1,33 @@ +# calibration_labels.json — schema + +Used by `atm validate-calibration` to check that the current color calibration +classifies known-good screenshots correctly before a live session. + +## Schema + +A JSON array of entries. Each entry: + +| Field | Type | Required | Description | +|------------|---------|----------|----------------------------------------------------------------| +| `path` | string | yes | Path to a PNG frame (relative to CWD or absolute). | +| `expected` | string | yes | Expected color name: one of `turquoise`, `yellow`, `dark_green`, `dark_red`, `light_green`, `light_red`, `gray`. | +| `note` | string | no | Freeform annotation for humans editing this file; stored per sample but not currently rendered in the report output. | + +## Usage + +```bash +atm validate-calibration samples/calibration_labels.json +``` + +Exit codes: +- `0` — every sample PASS +- `1` — one or more FAIL +- `2` — label file missing or malformed JSON + +## Adding new samples + +1. Find a screenshot in `logs/fires/` whose dot color you can verify by eye. +2. Append an entry with `path`, `expected`, and an optional `note`. +3. Re-run validation. If it FAILs, the SUGGESTIONS section will tell you the + RGB distance between the observed pixel and the expected color's center — + use that as input for `atm calibrate`. 
diff --git a/samples/calibration_labels.json b/samples/calibration_labels.json new file mode 100644 index 0000000..031a0df --- /dev/null +++ b/samples/calibration_labels.json @@ -0,0 +1,17 @@ +[ + { + "path": "logs/fires/20260417_201500_arm_sell.png", + "expected": "yellow", + "note": "first arm of SELL cycle 2026-04-17" + }, + { + "path": "logs/fires/20260417_205302_ss.png", + "expected": "dark_red", + "note": "user confirmed via screenshot (missed live alert)" + }, + { + "path": "logs/fires/20260417_210441_ss.png", + "expected": "light_red", + "note": "fire phase (missed live alert)" + } +] diff --git a/src/atm/main.py b/src/atm/main.py index c82815f..c129b8b 100644 --- a/src/atm/main.py +++ b/src/atm/main.py @@ -135,6 +135,16 @@ def main(argv=None) -> None: metavar="PATH", help="Journal JSONL file (default: trades.jsonl)", ) + # validate-calibration + p_valid = sub.add_parser( + "validate-calibration", + help="Offline: run Detector on labeled frames and report PASS/FAIL", + ) + p_valid.add_argument( + "label_file", type=Path, metavar="LABEL_FILE", + help="JSON array with [{path, expected, note?}, ...] 
entries", + ) + args = parser.parse_args(argv) _dispatch = { @@ -145,6 +155,7 @@ def main(argv=None) -> None: "debug": _cmd_debug, "journal": _cmd_journal, "report": _cmd_report, + "validate-calibration": _cmd_validate_calibration, } _dispatch[args.command](args) @@ -418,6 +429,37 @@ def _cmd_report(args) -> None: ) +def _cmd_validate_calibration(args) -> None: + """Run offline calibration validation; exit 0 on 100% PASS, 1 otherwise.""" + try: + from atm.validate import validate_calibration, ValidationError + except ImportError as exc: + sys.exit(f"validate module not available: {exc}") + + label_file = Path(args.label_file) + try: + cfg = Config.load_current(Path("configs")) + except FileNotFoundError as exc: + sys.exit(f"config not found: {exc}") + + try: + config_name = "" + cur_ptr = Path("configs") / "current.txt" + if cur_ptr.exists(): + config_name = cur_ptr.read_text(encoding="utf-8").strip() + except Exception: + config_name = "" + + try: + report = validate_calibration(label_file, cfg, config_name=config_name) + except ValidationError as exc: + print(f"error: {exc}", file=sys.stderr) + sys.exit(2) + + print(report.render()) + sys.exit(0 if report.all_pass else 1) + + # --------------------------------------------------------------------------- # Live loop # --------------------------------------------------------------------------- diff --git a/src/atm/validate.py b/src/atm/validate.py new file mode 100644 index 0000000..2b89ff8 --- /dev/null +++ b/src/atm/validate.py @@ -0,0 +1,229 @@ +"""Offline calibration validation: run Detector on labeled frames, report PASS/FAIL. + +Used by the `atm validate-calibration` subcommand. Reports per-sample detection +results against expected labels, and for failures, computes RGB distance to +each color threshold and emits tuning suggestions. + +Reuses `Detector.step(frame)` - does NOT reimplement color classification. 
+""" +from __future__ import annotations + +import json +import math +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from .config import Config + + +@dataclass +class SampleRecord: + path: str + expected: str + detected: str | None + confidence: float + rgb: tuple[int, int, int] | None + top3: list[tuple[str, float]] # [(name, score), ...] ranked by RGB distance + passed: bool + note: str = "" + error: str | None = None # non-None if frame load failed / schema bad + + +@dataclass +class ValidationReport: + records: list[SampleRecord] = field(default_factory=list) + config_name: str = "" + + @property + def total(self) -> int: + return len(self.records) + + @property + def passed(self) -> int: + return sum(1 for r in self.records if r.passed) + + @property + def failed(self) -> int: + return self.total - self.passed + + @property + def all_pass(self) -> bool: + return self.total > 0 and self.failed == 0 + + def render(self) -> str: + lines: list[str] = [] + hdr = f"Testing {self.total} frames" + if self.config_name: + hdr += f" against config {self.config_name}" + hdr += "..." 
+ lines.append(hdr) + lines.append("") + + for r in self.records: + name = Path(r.path).name or r.path + if r.error: + lines.append(f" [FAIL] {name}") + lines.append(f" error: {r.error}") + continue + tag = "PASS" if r.passed else "FAIL" + rgb_str = f"RGB {r.rgb}" if r.rgb is not None else "RGB n/a" + detected = r.detected if r.detected is not None else "none" + lines.append(f" [{tag}] {name}") + lines.append( + f" expected={r.expected} detected={detected} " + f"(conf {r.confidence:.2f}, {rgb_str})" + ) + if not r.passed and r.top3: + top3_str = " ".join(f"{n}({c:.2f})" for n, c in r.top3) + lines.append(f" Top 3 candidates: {top3_str}") + + lines.append("") + pct = (self.passed / self.total * 100.0) if self.total else 0.0 + lines.append(f"SUMMARY: {self.passed}/{self.total} PASS ({pct:.0f}%)") + + fails = [r for r in self.records if not r.passed] + if fails: + lines.append("FAILED:") + for r in fails: + name = Path(r.path).name or r.path + if r.error: + lines.append(f" - {name}: {r.error}") + continue + detected = r.detected if r.detected is not None else "none" + lines.append( + f" - {name}: expected {r.expected}, got {detected}" + ) + + sug_lines = [ + r._suggestion # type: ignore[attr-defined] + for r in fails + if getattr(r, "_suggestion", "") + ] + if sug_lines: + lines.append("") + lines.append("SUGGESTIONS:") + for s in sug_lines: + lines.append(f" - {s}") + + return "\n".join(lines) + + def __str__(self) -> str: + return self.render() + + +class ValidationError(Exception): + """Raised for missing label files or invalid schema.""" + + +def _rgb_distance(a: tuple[int, int, int], b: tuple[int, int, int]) -> float: + return math.sqrt(sum((a[i] - b[i]) ** 2 for i in range(3))) + + +def _load_labels(label_file: Path) -> list[dict[str, Any]]: + if not label_file.exists(): + raise ValidationError(f"label file not found: {label_file}") + try: + data = json.loads(label_file.read_text(encoding="utf-8")) + except json.JSONDecodeError as exc: + raise 
ValidationError(f"invalid JSON in {label_file}: {exc}") from exc + if not isinstance(data, list): + raise ValidationError( + f"label file must be a JSON array; got {type(data).__name__}" + ) + return data + + +def validate_calibration( + label_file: Path, + cfg: Config, + config_name: str = "", +) -> ValidationReport: + """Run Detector on each labeled frame; return a ValidationReport. + + Reuses `Detector.step(frame)`. Loads frames via cv2.imread. + Raises ValidationError if the label file is missing or malformed. + """ + import cv2 # local import keeps module import cheap + from .detector import Detector + + entries = _load_labels(label_file) + report = ValidationReport(config_name=config_name) + + palette = { + name: spec.rgb + for name, spec in cfg.colors.items() + if name != "background" + } + + detector = Detector(cfg=cfg, capture=lambda: None) + + for entry in entries: + path = str(entry.get("path", "")) + expected = str(entry.get("expected", "")) + note = str(entry.get("note", "")) + + if not path or not expected: + rec = SampleRecord( + path=path, expected=expected, detected=None, confidence=0.0, + rgb=None, top3=[], passed=False, note=note, + error="missing 'path' or 'expected' field", + ) + rec._suggestion = "" # type: ignore[attr-defined] + report.records.append(rec) + continue + + frame = cv2.imread(path) + if frame is None: + rec = SampleRecord( + path=path, expected=expected, detected=None, confidence=0.0, + rgb=None, top3=[], passed=False, note=note, + error=f"cv2.imread failed for {path}", + ) + rec._suggestion = "" # type: ignore[attr-defined] + report.records.append(rec) + continue + + result = detector.step(ts=0.0, frame=frame) + + match = result.match + if match is None: + detected: str | None = None + confidence = 0.0 + else: + detected = match.name if match.name != "UNKNOWN" else None + confidence = match.confidence + + rgb = result.rgb + + # Top 3 candidates: rank palette entries by RGB distance to observed. 
+ top3: list[tuple[str, float]] = [] + if rgb is not None: + scored: list[tuple[str, float]] = [] + for name, ref in palette.items(): + scored.append((name, _rgb_distance(rgb, ref))) + scored.sort(key=lambda t: t[1]) + top3 = [(n, 1.0 / (1.0 + d / 20.0)) for n, d in scored[:3]] + + passed = detected == expected + + rec = SampleRecord( + path=path, expected=expected, detected=detected, + confidence=confidence, rgb=rgb, top3=top3, passed=passed, note=note, + ) + + if not passed and rgb is not None and expected in palette: + ref = palette[expected] + tol = cfg.colors[expected].tolerance + dist = _rgb_distance(rgb, ref) + rec._suggestion = ( # type: ignore[attr-defined] + f"{expected} praguri curente: RGB{ref} +/- {tol:.0f}. " + f"Pixelul observat {rgb} e la distanta {dist:.1f} " + f"-> recalibreaza cu acest sample." + ) + else: + rec._suggestion = "" # type: ignore[attr-defined] + + report.records.append(rec) + + return report diff --git a/tests/test_validate.py b/tests/test_validate.py new file mode 100644 index 0000000..c8e8dfc --- /dev/null +++ b/tests/test_validate.py @@ -0,0 +1,214 @@ +"""Tests for atm.validate — offline calibration validation. + +Covers the 3 tests from plan section D': + 17. test_validate_calibration_pass + 18. test_validate_calibration_fail_reports_top_candidates + 19. 
test_validate_calibration_file_not_found +""" +from __future__ import annotations + +import json +from pathlib import Path + +import numpy as np +import pytest + +from atm.config import ( + CanaryRegion, + ColorSpec, + Config, + DiscordCfg, + ROI, + TelegramCfg, + YAxisCalib, +) +from atm.detector import DetectionResult +from atm.vision import ColorMatch + + +def _make_config() -> Config: + """Minimal Config with a palette large enough to support top-3 candidates.""" + colors = { + "turquoise": ColorSpec(rgb=(0, 200, 200), tolerance=30), + "yellow": ColorSpec(rgb=(255, 255, 0), tolerance=30), + "dark_green": ColorSpec(rgb=(0, 100, 0), tolerance=30), + "dark_red": ColorSpec(rgb=(165, 42, 42), tolerance=30), + "light_green": ColorSpec(rgb=(144, 238, 144), tolerance=30), + "light_red": ColorSpec(rgb=(255, 182, 193), tolerance=30), + "gray": ColorSpec(rgb=(128, 128, 128), tolerance=30), + "background": ColorSpec(rgb=(18, 18, 18), tolerance=15), + } + return Config( + window_title="test", + dot_roi=ROI(x=0, y=0, w=100, h=100), + chart_roi=ROI(x=0, y=0, w=100, h=100), + colors=colors, + y_axis=YAxisCalib(p1_y=0, p1_price=100.0, p2_y=100, p2_price=0.0), + canary=CanaryRegion( + roi=ROI(x=0, y=0, w=10, h=10), + baseline_phash="0" * 64, + ), + discord=DiscordCfg(webhook_url="http://localhost/fake"), + telegram=TelegramCfg(bot_token="fake_token", chat_id="123"), + debounce_depth=1, + ) + + +def _write_labels(tmp_path: Path, entries: list[dict]) -> Path: + f = tmp_path / "labels.json" + f.write_text(json.dumps(entries), encoding="utf-8") + return f + + +def _write_blank_png(tmp_path: Path, name: str) -> Path: + """Write a trivially-valid 10x10 BGR image so cv2.imread returns non-None.""" + import cv2 + p = tmp_path / name + arr = np.zeros((10, 10, 3), dtype=np.uint8) + cv2.imwrite(str(p), arr) + return p + + +# --------------------------------------------------------------------------- +# Test 17: PASS path — mocked Detector.step returns expected color +# 
--------------------------------------------------------------------------- + +def test_validate_calibration_pass(monkeypatch, tmp_path): + from atm import validate as validate_mod + + img_path = _write_blank_png(tmp_path, "yellow_sample.png") + labels = _write_labels( + tmp_path, + [{"path": str(img_path), "expected": "yellow", "note": "test"}], + ) + + def fake_step(self, ts, frame=None): + return DetectionResult( + ts=ts, + window_found=True, + dot_found=True, + rgb=(250, 250, 5), + match=ColorMatch(name="yellow", distance=6.0, confidence=0.94), + accepted=True, + color="yellow", + ) + + monkeypatch.setattr("atm.detector.Detector.step", fake_step) + + report = validate_mod.validate_calibration(labels, _make_config()) + + assert report.total == 1 + assert report.passed == 1 + assert report.failed == 0 + assert report.all_pass is True + rec = report.records[0] + assert rec.passed is True + assert rec.detected == "yellow" + assert rec.expected == "yellow" + assert "[PASS]" in report.render() + + # CLI wiring: exit 0 + import atm.main as _main + + class _Args: + label_file = labels + + monkeypatch.setattr("atm.config.Config.load_current", classmethod(lambda cls, d: _make_config())) + with pytest.raises(SystemExit) as exc_info: + _main._cmd_validate_calibration(_Args()) + assert exc_info.value.code == 0 + + +# --------------------------------------------------------------------------- +# Test 18: FAIL path — Detector returns wrong color; report lists top 3 +# candidates and a SUGGESTIONS line with RGB distance. 
+# --------------------------------------------------------------------------- + +def test_validate_calibration_fail_reports_top_candidates(monkeypatch, tmp_path): + from atm import validate as validate_mod + + img_path = _write_blank_png(tmp_path, "dark_red_sample.png") + labels = _write_labels( + tmp_path, + [{"path": str(img_path), "expected": "dark_red", "note": "missed dark_red"}], + ) + + # Observed RGB closer to gray than dark_red (like the real 2026-04-17 miss). + def fake_step(self, ts, frame=None): + return DetectionResult( + ts=ts, + window_found=True, + dot_found=True, + rgb=(135, 62, 67), + match=ColorMatch(name="gray", distance=45.0, confidence=0.12), + accepted=True, + color="gray", + ) + + monkeypatch.setattr("atm.detector.Detector.step", fake_step) + + report = validate_mod.validate_calibration(labels, _make_config()) + + assert report.total == 1 + assert report.failed == 1 + assert report.all_pass is False + + rec = report.records[0] + assert rec.passed is False + assert rec.detected == "gray" + assert rec.expected == "dark_red" + # Top 3 candidates populated (name, score) sorted by RGB distance. + assert len(rec.top3) == 3 + names = [n for n, _ in rec.top3] + # dark_red should appear in top candidates since observed RGB(135,62,67) + # is reasonably close to dark_red(165,42,42). + assert "dark_red" in names + + rendered = report.render() + assert "[FAIL]" in rendered + assert "Top 3 candidates:" in rendered + assert "SUGGESTIONS:" in rendered + # The suggestion must mention the expected color's RGB and the measured distance. 
+ assert "dark_red" in rendered + assert "(165, 42, 42)" in rendered + + # CLI wiring: exit 1 + import atm.main as _main + + class _Args: + label_file = labels + + monkeypatch.setattr("atm.config.Config.load_current", classmethod(lambda cls, d: _make_config())) + with pytest.raises(SystemExit) as exc_info: + _main._cmd_validate_calibration(_Args()) + assert exc_info.value.code == 1 + + +# --------------------------------------------------------------------------- +# Test 19: missing label file — clean error, non-zero exit, no stack trace +# --------------------------------------------------------------------------- + +def test_validate_calibration_file_not_found(monkeypatch, tmp_path, capsys): + from atm import validate as validate_mod + + missing = tmp_path / "nope.json" + + # Library-level: raises ValidationError (not bare FileNotFoundError). + with pytest.raises(validate_mod.ValidationError) as exc_info: + validate_mod.validate_calibration(missing, _make_config()) + assert "not found" in str(exc_info.value).lower() + + # CLI-level: graceful sys.exit with non-zero code, message on stderr. + import atm.main as _main + + class _Args: + label_file = missing + + monkeypatch.setattr("atm.config.Config.load_current", classmethod(lambda cls, d: _make_config())) + with pytest.raises(SystemExit) as exc_info: + _main._cmd_validate_calibration(_Args()) + assert exc_info.value.code != 0 + err = capsys.readouterr().err + assert "not found" in err.lower() + # Ensure no python traceback leaked through. + assert "Traceback" not in err