Also: calibrate._sample_rgb now snaps to the most-saturated pixel within 15px of the click, so rough clicks still pick up the dot's pure colour. Default dot-colour tolerance bumped 30→60 to absorb anti-aliasing. Test fixture _SAMPLED_RGB recomputed for the new 36/49 dilution (was 24/49 when sampling at the trailing edge). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
225 lines
7.8 KiB
Python
225 lines
7.8 KiB
Python
"""Tests for atm.dryrun."""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from atm.config import CanaryRegion, ColorSpec, Config, DiscordCfg, ROI, TelegramCfg, YAxisCalib
|
|
from atm.dryrun import ConfusionMatrix, DryrunResult, dryrun
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Config fixture
|
|
#
|
|
# The 6x6 dot at x=250..255, y=50..55 in a 100x300 frame is sampled by
# pixel_rgb(box=3) over a 7x7 patch centred on the dot: 36 dot pixels +
# 13 background (0,0,0). Sampled RGB = int(true_RGB * 36/49). Config colors
# match the sampled values so classify_pixel returns the correct label.
|
|
# ---------------------------------------------------------------------------
|
|
|
|
_SCALE = 36 / 49 # fraction of dot pixels in the 7x7 sample box (centre-based)
|
|
|
|
# True BGR paint values → sampled RGB ≈ int(true_RGB * _SCALE)
|
|
_SAMPLED_RGB: dict[str, tuple[int, int, int]] = {
|
|
"turquoise": (0, 146, 146), # true (0, 200, 200)
|
|
"yellow": (187, 187, 0), # true (255, 255, 0)
|
|
"dark_green": (0, 73, 0), # true (0, 100, 0)
|
|
"dark_red": (102, 0, 0), # true (139, 0, 0)
|
|
"light_green": (105, 174, 105), # true (144, 238, 144)
|
|
"light_red": (187, 133, 141), # true (255, 182, 193)
|
|
"gray": (94, 94, 94), # true (128, 128, 128)
|
|
}
|
|
|
|
# True RGB values used when painting frames (before sampling dilution)
|
|
_TRUE_RGB: dict[str, tuple[int, int, int]] = {
|
|
"turquoise": (0, 200, 200),
|
|
"yellow": (255, 255, 0),
|
|
"dark_green": (0, 100, 0),
|
|
"dark_red": (139, 0, 0),
|
|
"light_green": (144, 238, 144),
|
|
"light_red": (255, 182, 193),
|
|
"gray": (128, 128, 128),
|
|
}
|
|
|
|
|
|
def _make_config() -> Config:
|
|
colors = {
|
|
name: ColorSpec(rgb=rgb, tolerance=5)
|
|
for name, rgb in _SAMPLED_RGB.items()
|
|
}
|
|
colors["background"] = ColorSpec(rgb=(0, 0, 0), tolerance=5)
|
|
return Config(
|
|
window_title="test",
|
|
dot_roi=ROI(x=0, y=0, w=300, h=100),
|
|
chart_roi=ROI(x=0, y=0, w=300, h=100),
|
|
colors=colors,
|
|
y_axis=YAxisCalib(p1_y=0, p1_price=100.0, p2_y=100, p2_price=0.0),
|
|
canary=CanaryRegion(
|
|
roi=ROI(x=0, y=0, w=10, h=10),
|
|
baseline_phash="0" * 64,
|
|
),
|
|
discord=DiscordCfg(webhook_url="http://localhost/fake"),
|
|
telegram=TelegramCfg(bot_token="fake_token", chat_id="123"),
|
|
debounce_depth=1,
|
|
)
|
|
|
|
|
|
def _make_dot_frame(rgb: tuple[int, int, int]) -> np.ndarray:
|
|
"""100x300 BGR frame with a 6x6 dot at x=250,y=50."""
|
|
frame = np.zeros((100, 300, 3), dtype=np.uint8)
|
|
frame[50:56, 250:256] = (rgb[2], rgb[1], rgb[0]) # BGR
|
|
return frame
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# 1. Confusion matrix unit test — pure math, no cv2/detector
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_confusion_matrix_math() -> None:
|
|
cm = ConfusionMatrix()
|
|
cm.add("A", "A")
|
|
cm.add("A", "A")
|
|
cm.add("A", "B") # FN for A, FP for B
|
|
cm.add("B", "B")
|
|
|
|
per = cm.per_label()
|
|
|
|
# A: support=3, TP=2, FP=0 (B never predicted as A), FN=1
|
|
assert per["A"]["support"] == 3.0
|
|
assert per["A"]["precision"] == pytest.approx(1.0) # TP/(TP+FP) = 2/2
|
|
assert per["A"]["recall"] == pytest.approx(2 / 3)
|
|
assert per["A"]["f1"] == pytest.approx(2 * 1.0 * (2 / 3) / (1.0 + 2 / 3))
|
|
|
|
# B: support=1, TP=1, FP=1 (one A was predicted as B), FN=0
|
|
assert per["B"]["support"] == 1.0
|
|
assert per["B"]["precision"] == pytest.approx(0.5) # TP/(TP+FP) = 1/2
|
|
assert per["B"]["recall"] == pytest.approx(1.0)
|
|
|
|
# Overall accuracy: 3 correct out of 4
|
|
assert cm.overall_accuracy() == pytest.approx(3 / 4)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# 2-5: integration tests that require atm.detector
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_dryrun_perfect_match(tmp_path: Path) -> None:
|
|
pytest.importorskip("atm.detector")
|
|
cfg = _make_config()
|
|
colors_6 = ["turquoise", "yellow", "dark_green", "dark_red", "light_green", "light_red"]
|
|
|
|
import cv2
|
|
|
|
labels: dict[str, str] = {}
|
|
for idx, name in enumerate(colors_6):
|
|
frame = _make_dot_frame(_TRUE_RGB[name])
|
|
cv2.imwrite(str(tmp_path / f"{idx}.png"), frame)
|
|
labels[str(idx)] = name
|
|
|
|
labels_path = tmp_path / "labels.json"
|
|
labels_path.write_text(json.dumps(labels))
|
|
|
|
result = dryrun(tmp_path, labels_path, cfg)
|
|
|
|
assert result.n_samples == 6
|
|
assert result.n_labeled == 6
|
|
assert result.precision_overall == pytest.approx(1.0)
|
|
assert result.recall_overall == pytest.approx(1.0)
|
|
assert result.acceptance_pass is True
|
|
|
|
# Diagonal-only: each label predicts only itself
|
|
per = result.confusion.per_label()
|
|
for name in colors_6:
|
|
assert result.confusion.counts[name] == {name: 1}, (
|
|
f"Expected diagonal for {name}, got {result.confusion.counts[name]}"
|
|
)
|
|
assert all(m["precision"] == pytest.approx(1.0) for m in per.values())
|
|
assert all(m["recall"] == pytest.approx(1.0) for m in per.values())
|
|
|
|
|
|
def test_dryrun_with_unlabeled_sample(tmp_path: Path) -> None:
|
|
pytest.importorskip("atm.detector")
|
|
cfg = _make_config()
|
|
|
|
import cv2
|
|
|
|
# Write 3 labeled frames + 1 unlabeled
|
|
labels: dict[str, str] = {}
|
|
for idx, name in enumerate(["turquoise", "yellow", "dark_green"]):
|
|
frame = _make_dot_frame(_TRUE_RGB[name])
|
|
cv2.imwrite(str(tmp_path / f"{idx}.png"), frame)
|
|
labels[str(idx)] = name
|
|
|
|
# Frame "3" exists on disk but has NO label entry
|
|
unlabeled_frame = _make_dot_frame(_TRUE_RGB["dark_red"])
|
|
cv2.imwrite(str(tmp_path / "3.png"), unlabeled_frame)
|
|
|
|
labels_path = tmp_path / "labels.json"
|
|
labels_path.write_text(json.dumps(labels))
|
|
|
|
result = dryrun(tmp_path, labels_path, cfg)
|
|
|
|
assert result.n_samples == 4 # 4 PNGs on disk
|
|
assert result.n_labeled == 3 # only 3 labeled
|
|
# "3" not in confusion
|
|
assert "3" not in result.confusion.counts
|
|
# Only the 3 labeled colors appear
|
|
assert set(result.confusion.counts.keys()) == {"turquoise", "yellow", "dark_green"}
|
|
|
|
|
|
def test_dryrun_misclassification_fails_gate(tmp_path: Path) -> None:
|
|
pytest.importorskip("atm.detector")
|
|
cfg = _make_config()
|
|
|
|
import cv2
|
|
|
|
colors_6 = ["turquoise", "yellow", "dark_green", "dark_red", "light_green", "light_red"]
|
|
labels: dict[str, str] = {}
|
|
for idx, name in enumerate(colors_6):
|
|
frame = _make_dot_frame(_TRUE_RGB[name])
|
|
cv2.imwrite(str(tmp_path / f"{idx}.png"), frame)
|
|
labels[str(idx)] = name
|
|
|
|
# Swap label of frame 0 (turquoise dot → labeled as "yellow")
|
|
labels["0"] = "yellow"
|
|
|
|
labels_path = tmp_path / "labels.json"
|
|
labels_path.write_text(json.dumps(labels))
|
|
|
|
result = dryrun(tmp_path, labels_path, cfg)
|
|
|
|
assert result.acceptance_pass is False
|
|
# recall for "yellow" drops: one yellow-labeled frame predicted as turquoise
|
|
per = result.confusion.per_label()
|
|
assert per["yellow"]["recall"] < 1.0
|
|
|
|
|
|
def test_fire_event_captured(tmp_path: Path) -> None:
|
|
pytest.importorskip("atm.detector")
|
|
cfg = _make_config()
|
|
|
|
import cv2
|
|
|
|
# Sequence that triggers a BUY fire: turquoise → gray → dark_green → light_green
|
|
sequence = ["turquoise", "gray", "dark_green", "light_green"]
|
|
labels: dict[str, str] = {}
|
|
for idx, name in enumerate(sequence):
|
|
frame = _make_dot_frame(_TRUE_RGB[name])
|
|
cv2.imwrite(str(tmp_path / f"{idx}.png"), frame)
|
|
labels[str(idx)] = name
|
|
|
|
labels_path = tmp_path / "labels.json"
|
|
labels_path.write_text(json.dumps(labels))
|
|
|
|
result = dryrun(tmp_path, labels_path, cfg)
|
|
|
|
assert len(result.fire_events) == 1
|
|
ev = result.fire_events[0]
|
|
assert ev["direction"] == "BUY"
|
|
assert ev["ts"] == pytest.approx(15.0) # i=3 → ts=3*5.0
|
|
assert ev["sample"] == "3"
|