feat(cli): atm validate-calibration — offline color classification gate
Adds `atm validate-calibration LABEL_FILE` subcommand that runs the Detector on a set of labeled PNG frames and reports per-sample PASS/FAIL with top-3 candidate colors and RGB-distance suggestions for failures. Exits 0 on 100% PASS, 1 on any FAIL, 2 on missing/malformed label file. - New module src/atm/validate.py with ValidationReport + SampleRecord dataclasses; reuses Detector.step(frame), does not reimplement color classification. - main.py: new `validate-calibration` subparser and _cmd_validate_calibration handler wired into the dispatch map. - samples/calibration_labels.json seeded with 3 entries from the 2026-04-17 incident, plus a README describing the schema. - tests/test_validate.py covers the 3 planned cases: PASS, FAIL w/ top-3 + suggestion, missing file (graceful error, no traceback). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
33
samples/calibration_labels.README.md
Normal file
33
samples/calibration_labels.README.md
Normal file
@@ -0,0 +1,33 @@
|
||||
# calibration_labels.json — schema
|
||||
|
||||
Used by `atm validate-calibration` to check that the current color calibration
|
||||
classifies known-good screenshots correctly before a live session.
|
||||
|
||||
## Schema
|
||||
|
||||
A JSON array of entries. Each entry:
|
||||
|
||||
| Field | Type | Required | Description |
|
||||
|------------|---------|----------|----------------------------------------------------------------|
|
||||
| `path` | string | yes | Path to a PNG frame (relative to CWD or absolute). |
|
||||
| `expected` | string | yes | Expected color name: one of `turquoise`, `yellow`, `dark_green`, `dark_red`, `light_green`, `light_red`, `gray`. |
|
||||
| `note` | string | no | Freeform annotation; shown in SUGGESTIONS output. |
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
atm validate-calibration samples/calibration_labels.json
|
||||
```
|
||||
|
||||
Exit codes:
|
||||
- `0` — every sample PASS
|
||||
- `1` — one or more FAIL
|
||||
- `2` — label file missing, malformed JSON, or valid JSON that is not an array
|
||||
|
||||
## Adding new samples
|
||||
|
||||
1. Find a screenshot in `logs/fires/` whose dot color you can verify by eye.
|
||||
2. Append an entry with `path`, `expected`, and an optional `note`.
|
||||
3. Re-run validation. If it FAILs, the SUGGESTIONS section will tell you the
|
||||
RGB distance between the observed pixel and the expected color's center —
|
||||
use that as input for `atm calibrate`.
|
||||
17
samples/calibration_labels.json
Normal file
17
samples/calibration_labels.json
Normal file
@@ -0,0 +1,17 @@
|
||||
[
|
||||
{
|
||||
"path": "logs/fires/20260417_201500_arm_sell.png",
|
||||
"expected": "yellow",
|
||||
"note": "first arm of SELL cycle 2026-04-17"
|
||||
},
|
||||
{
|
||||
"path": "logs/fires/20260417_205302_ss.png",
|
||||
"expected": "dark_red",
|
||||
"note": "user confirmed via screenshot (missed live alert)"
|
||||
},
|
||||
{
|
||||
"path": "logs/fires/20260417_210441_ss.png",
|
||||
"expected": "light_red",
|
||||
"note": "fire phase (missed live alert)"
|
||||
}
|
||||
]
|
||||
@@ -135,6 +135,16 @@ def main(argv=None) -> None:
|
||||
metavar="PATH", help="Journal JSONL file (default: trades.jsonl)",
|
||||
)
|
||||
|
||||
# validate-calibration
|
||||
p_valid = sub.add_parser(
|
||||
"validate-calibration",
|
||||
help="Offline: run Detector on labeled frames and report PASS/FAIL",
|
||||
)
|
||||
p_valid.add_argument(
|
||||
"label_file", type=Path, metavar="LABEL_FILE",
|
||||
help="JSON array with [{path, expected, note?}, ...] entries",
|
||||
)
|
||||
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
_dispatch = {
|
||||
@@ -145,6 +155,7 @@ def main(argv=None) -> None:
|
||||
"debug": _cmd_debug,
|
||||
"journal": _cmd_journal,
|
||||
"report": _cmd_report,
|
||||
"validate-calibration": _cmd_validate_calibration,
|
||||
}
|
||||
_dispatch[args.command](args)
|
||||
|
||||
@@ -418,6 +429,37 @@ def _cmd_report(args) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _cmd_validate_calibration(args) -> None:
    """Run offline calibration validation and exit with a gate status.

    Loads the current config from ``configs/``, runs ``validate_calibration``
    on ``args.label_file``, prints the rendered report to stdout, then exits:

    * ``0`` when ``report.all_pass`` is true,
    * ``1`` when it is false,
    * ``2`` when ``validate_calibration`` raises ``ValidationError``
      (the message is printed to stderr).

    If the validate module cannot be imported or the config is not found,
    the command aborts through ``sys.exit(message)``.
    """
    try:
        from atm.validate import validate_calibration, ValidationError
    except ImportError as exc:
        sys.exit(f"validate module not available: {exc}")

    label_file = Path(args.label_file)
    try:
        cfg = Config.load_current(Path("configs"))
    except FileNotFoundError as exc:
        sys.exit(f"config not found: {exc}")

    # The config name only decorates the report header, so an unreadable
    # pointer file is tolerated. Only I/O and decoding failures are swallowed;
    # anything else is a programming error and should surface.
    config_name = ""
    cur_ptr = Path("configs") / "current.txt"
    try:
        if cur_ptr.exists():
            config_name = cur_ptr.read_text(encoding="utf-8").strip()
    except (OSError, UnicodeDecodeError):
        config_name = ""

    try:
        report = validate_calibration(label_file, cfg, config_name=config_name)
    except ValidationError as exc:
        print(f"error: {exc}", file=sys.stderr)
        sys.exit(2)

    print(report.render())
    sys.exit(0 if report.all_pass else 1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Live loop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
229
src/atm/validate.py
Normal file
229
src/atm/validate.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""Offline calibration validation: run Detector on labeled frames, report PASS/FAIL.
|
||||
|
||||
Used by the `atm validate-calibration` subcommand. Reports per-sample detection
|
||||
results against expected labels, and for failures, computes RGB distance to
|
||||
each color threshold and emits tuning suggestions.
|
||||
|
||||
Reuses `Detector.step(frame)` - does NOT reimplement color classification.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .config import Config
|
||||
|
||||
|
||||
@dataclass
class SampleRecord:
    """Outcome of checking a single labeled frame.

    Field order defines the positional constructor signature.
    """

    # Frame path as given in the label entry.
    path: str
    # Color name the label entry expects for this frame.
    expected: str
    # Color name reported by the detector; None when there was no match or
    # the match was named "UNKNOWN".
    detected: str | None
    # Confidence of the detector's match (0.0 when there was no match).
    confidence: float
    # Observed color taken from the detection result, if any.
    rgb: tuple[int, int, int] | None
    top3: list[tuple[str, float]]  # [(name, score), ...] ranked by RGB distance
    # True when the detected color equals the expected one.
    passed: bool
    # Freeform annotation copied from the label entry.
    note: str = ""
    error: str | None = None  # non-None if frame load failed / schema bad
|
||||
|
||||
|
||||
@dataclass
class ValidationReport:
    """Collected sample records of one validation run, with text rendering."""

    records: list[SampleRecord] = field(default_factory=list)
    config_name: str = ""

    @property
    def total(self) -> int:
        """Number of recorded samples."""
        return len(self.records)

    @property
    def passed(self) -> int:
        """Number of records whose ``passed`` flag is truthy."""
        return len([rec for rec in self.records if rec.passed])

    @property
    def failed(self) -> int:
        """Number of records that did not pass."""
        return self.total - self.passed

    @property
    def all_pass(self) -> bool:
        """True only when there is at least one record and none failed."""
        # An empty run does not count as a pass.
        return bool(self.records) and self.failed == 0

    def render(self) -> str:
        """Return the multi-line, human-readable report.

        Layout: header, one entry per record, a SUMMARY line, and, when
        anything failed, a FAILED list followed by any SUGGESTIONS attached
        to the failing records.
        """
        suffix = f" against config {self.config_name}" if self.config_name else ""
        out: list[str] = [f"Testing {self.total} frames{suffix}...", ""]

        # Per-sample section.
        for rec in self.records:
            label = Path(rec.path).name or rec.path
            if rec.error:
                # Records that could not be evaluated show the error only.
                out += [f" [FAIL] {label}", f" error: {rec.error}"]
                continue

            status = "PASS" if rec.passed else "FAIL"
            rgb_text = "RGB n/a" if rec.rgb is None else f"RGB {rec.rgb}"
            seen = "none" if rec.detected is None else rec.detected
            out.append(f" [{status}] {label}")
            out.append(
                f" expected={rec.expected} detected={seen} "
                f"(conf {rec.confidence:.2f}, {rgb_text})"
            )
            if rec.top3 and not rec.passed:
                ranked = " ".join(
                    f"{cand}({score:.2f})" for cand, score in rec.top3
                )
                out.append(f" Top 3 candidates: {ranked}")

        # Summary line.
        out.append("")
        ratio = self.passed / self.total * 100.0 if self.total else 0.0
        out.append(f"SUMMARY: {self.passed}/{self.total} PASS ({ratio:.0f}%)")

        failures = [rec for rec in self.records if not rec.passed]
        if not failures:
            return "\n".join(out)

        # Failure recap.
        out.append("FAILED:")
        for rec in failures:
            label = Path(rec.path).name or rec.path
            if rec.error:
                out.append(f" - {label}: {rec.error}")
            else:
                seen = "none" if rec.detected is None else rec.detected
                out.append(f" - {label}: expected {rec.expected}, got {seen}")

        # Suggestions travel on the records as an ad-hoc ``_suggestion``
        # attribute; records without one (or with an empty one) are skipped.
        hints = [
            hint
            for hint in (getattr(rec, "_suggestion", "") for rec in failures)
            if hint
        ]
        if hints:
            out.append("")
            out.append("SUGGESTIONS:")
            out.extend(f" - {hint}" for hint in hints)

        return "\n".join(out)

    def __str__(self) -> str:
        """Same text as ``render()``."""
        return self.render()
|
||||
|
||||
|
||||
class ValidationError(Exception):
    """Raised for missing label files or invalid schema.

    `_load_labels` raises it when the label file does not exist, does not
    contain valid JSON, or contains JSON that is not an array.
    """
|
||||
|
||||
|
||||
def _rgb_distance(a: tuple[int, int, int], b: tuple[int, int, int]) -> float:
|
||||
return math.sqrt(sum((a[i] - b[i]) ** 2 for i in range(3)))
|
||||
|
||||
|
||||
def _load_labels(label_file: Path) -> list[dict[str, Any]]:
|
||||
if not label_file.exists():
|
||||
raise ValidationError(f"label file not found: {label_file}")
|
||||
try:
|
||||
data = json.loads(label_file.read_text(encoding="utf-8"))
|
||||
except json.JSONDecodeError as exc:
|
||||
raise ValidationError(f"invalid JSON in {label_file}: {exc}") from exc
|
||||
if not isinstance(data, list):
|
||||
raise ValidationError(
|
||||
f"label file must be a JSON array; got {type(data).__name__}"
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
def _text_field(entry: dict[str, Any], key: str) -> str:
    """Return ``entry[key]`` as text; a missing key or JSON null yields ""."""
    value = entry.get(key)
    return "" if value is None else str(value)


def _error_record(path: str, expected: str, note: str, error: str) -> SampleRecord:
    """Build the failed record used when a sample cannot be evaluated."""
    rec = SampleRecord(
        path=path, expected=expected, detected=None, confidence=0.0,
        rgb=None, top3=[], passed=False, note=note, error=error,
    )
    # ValidationReport.render() looks for this ad-hoc attribute.
    rec._suggestion = ""  # type: ignore[attr-defined]
    return rec


def validate_calibration(
    label_file: Path,
    cfg: Config,
    config_name: str = "",
) -> ValidationReport:
    """Run Detector on each labeled frame; return a ValidationReport.

    Reuses `Detector.step(frame)`. Loads frames via cv2.imread.

    Args:
        label_file: JSON label file, see `_load_labels`.
        cfg: Config whose ``colors`` provide the palette ("background" is
            excluded from candidate ranking) and tolerances.
        config_name: Optional name shown in the report header.

    Returns:
        A report with one record per label entry. Entries that cannot be
        evaluated (not an object, missing ``path``/``expected``, unreadable
        image) are recorded as failures carrying an ``error`` message.

    Raises:
        ValidationError: if the label file is missing or malformed.
    """
    import cv2  # local import keeps module import cheap
    from .detector import Detector

    entries = _load_labels(label_file)
    report = ValidationReport(config_name=config_name)

    palette = {
        name: spec.rgb
        for name, spec in cfg.colors.items()
        if name != "background"
    }

    detector = Detector(cfg=cfg, capture=lambda: None)

    for index, entry in enumerate(entries):
        # A non-object element would crash on .get(); report it per sample.
        if not isinstance(entry, dict):
            report.records.append(_error_record(
                f"<entry #{index}>", "", "",
                f"entry must be a JSON object; got {type(entry).__name__}",
            ))
            continue

        # JSON null is treated like a missing field instead of becoming the
        # literal string "None".
        path = _text_field(entry, "path")
        expected = _text_field(entry, "expected")
        note = _text_field(entry, "note")

        if not path or not expected:
            report.records.append(_error_record(
                path, expected, note, "missing 'path' or 'expected' field",
            ))
            continue

        frame = cv2.imread(path)
        if frame is None:
            report.records.append(_error_record(
                path, expected, note, f"cv2.imread failed for {path}",
            ))
            continue

        result = detector.step(ts=0.0, frame=frame)

        match = result.match
        detected: str | None = None
        confidence = 0.0
        if match is not None:
            # "UNKNOWN" is reported as no detection, but keeps its confidence.
            if match.name != "UNKNOWN":
                detected = match.name
            confidence = match.confidence

        rgb = result.rgb

        # Top 3 candidates: rank palette entries by RGB distance to observed.
        top3: list[tuple[str, float]] = []
        if rgb is not None:
            ranked = sorted(
                ((name, _rgb_distance(rgb, ref)) for name, ref in palette.items()),
                key=lambda pair: pair[1],
            )
            # Score falls from 1.0 at distance 0 to 0.5 at distance 20.
            top3 = [(name, 1.0 / (1.0 + dist / 20.0)) for name, dist in ranked[:3]]

        passed = detected == expected

        rec = SampleRecord(
            path=path, expected=expected, detected=detected,
            confidence=confidence, rgb=rgb, top3=top3, passed=passed, note=note,
        )

        suggestion = ""
        if not passed and rgb is not None and expected in palette:
            ref = palette[expected]
            tol = cfg.colors[expected].tolerance
            dist = _rgb_distance(rgb, ref)
            suggestion = (
                f"{expected} praguri curente: RGB{ref} +/- {tol:.0f}. "
                f"Pixelul observat {rgb} e la distanta {dist:.1f} "
                f"-> recalibreaza cu acest sample."
            )
            # Surface the sample's annotation next to its suggestion so the
            # note from the label entry is visible in the SUGGESTIONS output.
            if note:
                suggestion += f" ({note})"
        rec._suggestion = suggestion  # type: ignore[attr-defined]

        report.records.append(rec)

    return report
|
||||
214
tests/test_validate.py
Normal file
214
tests/test_validate.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""Tests for atm.validate — offline calibration validation.
|
||||
|
||||
Covers the 3 tests from plan section D':
|
||||
17. test_validate_calibration_pass
|
||||
18. test_validate_calibration_fail_reports_top_candidates
|
||||
19. test_validate_calibration_file_not_found
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from atm.config import (
|
||||
CanaryRegion,
|
||||
ColorSpec,
|
||||
Config,
|
||||
DiscordCfg,
|
||||
ROI,
|
||||
TelegramCfg,
|
||||
YAxisCalib,
|
||||
)
|
||||
from atm.detector import DetectionResult
|
||||
from atm.vision import ColorMatch
|
||||
|
||||
|
||||
def _make_config() -> Config:
    """Build a small Config whose palette has enough colors for a top-3 ranking."""
    # Seven signal colors share one tolerance; "background" is added last
    # with its own, so the dict keeps the same key order.
    swatches = {
        "turquoise": (0, 200, 200),
        "yellow": (255, 255, 0),
        "dark_green": (0, 100, 0),
        "dark_red": (165, 42, 42),
        "light_green": (144, 238, 144),
        "light_red": (255, 182, 193),
        "gray": (128, 128, 128),
    }
    colors = {
        name: ColorSpec(rgb=triple, tolerance=30)
        for name, triple in swatches.items()
    }
    colors["background"] = ColorSpec(rgb=(18, 18, 18), tolerance=15)

    y_axis = YAxisCalib(p1_y=0, p1_price=100.0, p2_y=100, p2_price=0.0)
    canary = CanaryRegion(
        roi=ROI(x=0, y=0, w=10, h=10),
        baseline_phash="0" * 64,
    )
    discord = DiscordCfg(webhook_url="http://localhost/fake")
    telegram = TelegramCfg(bot_token="fake_token", chat_id="123")

    return Config(
        window_title="test",
        dot_roi=ROI(x=0, y=0, w=100, h=100),
        chart_roi=ROI(x=0, y=0, w=100, h=100),
        colors=colors,
        y_axis=y_axis,
        canary=canary,
        discord=discord,
        telegram=telegram,
        debounce_depth=1,
    )
|
||||
|
||||
|
||||
def _write_labels(tmp_path: Path, entries: list[dict]) -> Path:
|
||||
f = tmp_path / "labels.json"
|
||||
f.write_text(json.dumps(entries), encoding="utf-8")
|
||||
return f
|
||||
|
||||
|
||||
def _write_blank_png(tmp_path: Path, name: str) -> Path:
    """Create an all-zero 10x10 three-channel image so cv2.imread can load it."""
    import cv2

    target = tmp_path / name
    blank = np.zeros(shape=(10, 10, 3), dtype=np.uint8)
    cv2.imwrite(str(target), blank)
    return target
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 17: PASS path — mocked Detector.step returns expected color
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_validate_calibration_pass(monkeypatch, tmp_path):
    """A sample whose mocked detection equals its label passes; the CLI exits 0."""
    from atm import validate as validate_mod

    # The image only needs to be loadable: Detector.step is replaced below,
    # so its pixels never influence the result.
    img_path = _write_blank_png(tmp_path, "yellow_sample.png")
    labels = _write_labels(
        tmp_path,
        [{"path": str(img_path), "expected": "yellow", "note": "test"}],
    )

    # Stand-in for Detector.step that always reports "yellow".
    def fake_step(self, ts, frame=None):
        return DetectionResult(
            ts=ts,
            window_found=True,
            dot_found=True,
            rgb=(250, 250, 5),
            match=ColorMatch(name="yellow", distance=6.0, confidence=0.94),
            accepted=True,
            color="yellow",
        )

    # Must be patched before validate_calibration runs.
    monkeypatch.setattr("atm.detector.Detector.step", fake_step)

    report = validate_mod.validate_calibration(labels, _make_config())

    assert report.total == 1
    assert report.passed == 1
    assert report.failed == 0
    assert report.all_pass is True
    rec = report.records[0]
    assert rec.passed is True
    assert rec.detected == "yellow"
    assert rec.expected == "yellow"
    assert "[PASS]" in report.render()

    # CLI wiring: exit 0
    import atm.main as _main

    # Minimal stand-in for the parsed argparse namespace.
    class _Args:
        label_file = labels

    # Avoid reading a real configs/ directory from the working directory.
    monkeypatch.setattr("atm.config.Config.load_current", classmethod(lambda cls, d: _make_config()))
    with pytest.raises(SystemExit) as exc_info:
        _main._cmd_validate_calibration(_Args())
    assert exc_info.value.code == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 18: FAIL path — Detector returns wrong color; report lists top 3
|
||||
# candidates and a SUGGESTIONS line with RGB distance.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_validate_calibration_fail_reports_top_candidates(monkeypatch, tmp_path):
    """A mismatching detection fails, lists top-3 candidates and a suggestion; the CLI exits 1."""
    from atm import validate as validate_mod

    # The image only needs to be loadable; Detector.step is replaced below.
    img_path = _write_blank_png(tmp_path, "dark_red_sample.png")
    labels = _write_labels(
        tmp_path,
        [{"path": str(img_path), "expected": "dark_red", "note": "missed dark_red"}],
    )

    # Observed RGB closer to gray than dark_red (like the real 2026-04-17 miss).
    def fake_step(self, ts, frame=None):
        return DetectionResult(
            ts=ts,
            window_found=True,
            dot_found=True,
            rgb=(135, 62, 67),
            match=ColorMatch(name="gray", distance=45.0, confidence=0.12),
            accepted=True,
            color="gray",
        )

    # Must be patched before validate_calibration runs.
    monkeypatch.setattr("atm.detector.Detector.step", fake_step)

    report = validate_mod.validate_calibration(labels, _make_config())

    assert report.total == 1
    assert report.failed == 1
    assert report.all_pass is False

    rec = report.records[0]
    assert rec.passed is False
    assert rec.detected == "gray"
    assert rec.expected == "dark_red"
    # Top 3 candidates populated (name, score) sorted by RGB distance.
    assert len(rec.top3) == 3
    names = [n for n, _ in rec.top3]
    # dark_red should appear in top candidates since observed RGB(135,62,67)
    # is reasonably close to dark_red(165,42,42).
    assert "dark_red" in names

    rendered = report.render()
    assert "[FAIL]" in rendered
    assert "Top 3 candidates:" in rendered
    assert "SUGGESTIONS:" in rendered
    # The suggestion must mention the expected color's RGB and the measured distance.
    assert "dark_red" in rendered
    assert "(165, 42, 42)" in rendered

    # CLI wiring: exit 1
    import atm.main as _main

    # Minimal stand-in for the parsed argparse namespace.
    class _Args:
        label_file = labels

    # Avoid reading a real configs/ directory from the working directory.
    monkeypatch.setattr("atm.config.Config.load_current", classmethod(lambda cls, d: _make_config()))
    with pytest.raises(SystemExit) as exc_info:
        _main._cmd_validate_calibration(_Args())
    assert exc_info.value.code == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 19: missing label file — clean error, non-zero exit, no stack trace
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_validate_calibration_file_not_found(monkeypatch, tmp_path, capsys):
    """A missing label file raises ValidationError; the CLI exits non-zero without a traceback."""
    from atm import validate as validate_mod

    # Path inside the temp dir that is never created.
    missing = tmp_path / "nope.json"

    # Library-level: raises ValidationError (not bare FileNotFoundError).
    with pytest.raises(validate_mod.ValidationError) as exc_info:
        validate_mod.validate_calibration(missing, _make_config())
    assert "not found" in str(exc_info.value).lower()

    # CLI-level: graceful sys.exit with non-zero code, message on stderr.
    import atm.main as _main

    # Minimal stand-in for the parsed argparse namespace.
    class _Args:
        label_file = missing

    # Avoid reading a real configs/ directory from the working directory.
    monkeypatch.setattr("atm.config.Config.load_current", classmethod(lambda cls, d: _make_config()))
    with pytest.raises(SystemExit) as exc_info:
        _main._cmd_validate_calibration(_Args())
    assert exc_info.value.code != 0
    # Read captured output only after the command has run.
    err = capsys.readouterr().err
    assert "not found" in err.lower()
    # Ensure no python traceback leaked through.
    assert "Traceback" not in err
|
||||
Reference in New Issue
Block a user