Compare commits
6 Commits
153196f762
...
8bae507bbd
| Author | SHA1 | Date | |
|---|---|---|---|
| 8bae507bbd | |||
| 23865776e3 | |||
| 54f55752c1 | |||
| 8b53b8d3c9 | |||
| 9cf49caf8a | |||
| c5024ce600 |
@@ -81,6 +81,24 @@ low_conf_run = 3
|
||||
phaseb_timeout_s = 600
|
||||
dead_letter_path = "logs/dead_letter.jsonl"
|
||||
|
||||
# Alert-behavior toggles (not screenshot-attachment; see attach_screenshots below).
|
||||
# fire_on_phase_skip: emit a backstop "PHASE SKIP" alert when the FSM observes
|
||||
# ARMED → light_green/light_red directly (skipping the dark prime). Default on
|
||||
# because missing a fire is worse than a false-positive phase-skip alert.
|
||||
[options.alerts]
|
||||
fire_on_phase_skip = true
|
||||
|
||||
# Operating hours — detection only runs on allowed weekdays + HH:MM window.
|
||||
# Timezone is the source of truth (NYSE local); the runtime converts tick
|
||||
# timestamps to this zone so DST rollovers stay aligned with the exchange.
|
||||
# Override from CLI with --tz / --weekdays / --oh-start / --oh-stop.
|
||||
[options.operating_hours]
|
||||
enabled = false
|
||||
timezone = "America/New_York"
|
||||
weekdays = ["MON", "TUE", "WED", "THU", "FRI"]
|
||||
start_hhmm = "09:30"
|
||||
stop_hhmm = "16:00"
|
||||
|
||||
# Per-kind screenshot-attach toggles. All default to true on upgrade.
|
||||
# Accepts either a bare bool (legacy: attach_screenshots = true) or this table.
|
||||
[options.attach_screenshots]
|
||||
|
||||
33
samples/calibration_labels.README.md
Normal file
33
samples/calibration_labels.README.md
Normal file
@@ -0,0 +1,33 @@
|
||||
# calibration_labels.json — schema
|
||||
|
||||
Used by `atm validate-calibration` to check that the current color calibration
|
||||
classifies known-good screenshots correctly before a live session.
|
||||
|
||||
## Schema
|
||||
|
||||
A JSON array of entries. Each entry:
|
||||
|
||||
| Field | Type | Required | Description |
|
||||
|------------|---------|----------|----------------------------------------------------------------|
|
||||
| `path` | string | yes | Path to a PNG frame (relative to CWD or absolute). |
|
||||
| `expected` | string | yes | Expected color name: one of `turquoise`, `yellow`, `dark_green`, `dark_red`, `light_green`, `light_red`, `gray`. |
|
||||
| `note` | string | no | Freeform annotation; shown in SUGGESTIONS output. |
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
atm validate-calibration samples/calibration_labels.json
|
||||
```
|
||||
|
||||
Exit codes:
|
||||
- `0` — every sample PASS
|
||||
- `1` — one or more FAIL
|
||||
- `2` — label file missing or malformed JSON
|
||||
|
||||
## Adding new samples
|
||||
|
||||
1. Find a screenshot in `logs/fires/` whose dot color you can verify by eye.
|
||||
2. Append an entry with `path`, `expected`, and an optional `note`.
|
||||
3. Re-run validation. If it FAILs, the SUGGESTIONS section will tell you the
|
||||
RGB distance between the observed pixel and the expected color's center —
|
||||
use that as input for `atm calibrate`.
|
||||
17
samples/calibration_labels.json
Normal file
17
samples/calibration_labels.json
Normal file
@@ -0,0 +1,17 @@
|
||||
[
|
||||
{
|
||||
"path": "logs/fires/20260417_201500_arm_sell.png",
|
||||
"expected": "yellow",
|
||||
"note": "first arm of SELL cycle 2026-04-17"
|
||||
},
|
||||
{
|
||||
"path": "logs/fires/20260417_205302_ss.png",
|
||||
"expected": "dark_red",
|
||||
"note": "user confirmed via screenshot (missed live alert)"
|
||||
},
|
||||
{
|
||||
"path": "logs/fires/20260417_210441_ss.png",
|
||||
"expected": "light_red",
|
||||
"note": "fire phase (missed live alert)"
|
||||
}
|
||||
]
|
||||
@@ -1,14 +1,18 @@
|
||||
"""Layout drift detector via perceptual hash comparison."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .config import Config
|
||||
from .vision import crop_roi, hamming_hex, phash
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CanaryResult:
|
||||
@@ -28,10 +32,15 @@ class Canary:
|
||||
self,
|
||||
cfg: Config,
|
||||
pause_flag_path: Path | None = None,
|
||||
on_pause_callback: Callable[[int], None] | None = None,
|
||||
) -> None:
|
||||
self._cfg = cfg
|
||||
self._pause_flag_path = pause_flag_path
|
||||
self._paused = False
|
||||
# Single-shot callback invoked exactly once per not_paused→paused transition.
|
||||
# Wrapped in try/except at call site so a faulty notifier never breaks
|
||||
# the detection cycle.
|
||||
self._on_pause = on_pause_callback
|
||||
|
||||
def check(self, frame_bgr: np.ndarray) -> CanaryResult:
|
||||
roi_img = crop_roi(frame_bgr, self._cfg.canary.roi)
|
||||
@@ -43,6 +52,12 @@ class Canary:
|
||||
self._paused = True
|
||||
if self._pause_flag_path is not None:
|
||||
self._pause_flag_path.write_text("paused", encoding="utf-8")
|
||||
if self._on_pause is not None:
|
||||
try:
|
||||
self._on_pause(distance)
|
||||
except Exception as exc:
|
||||
# Never let a notifier hiccup abort the detection cycle.
|
||||
logger.warning("canary on_pause_callback raised: %s", exc)
|
||||
|
||||
return CanaryResult(distance=distance, drifted=drifted, paused=self._paused)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CommandAction = Literal["set_interval", "stop", "status", "ss"]
|
||||
CommandAction = Literal["set_interval", "stop", "status", "ss", "pause", "resume"]
|
||||
|
||||
_BASE = "https://api.telegram.org/bot{token}/{method}"
|
||||
|
||||
@@ -154,6 +154,13 @@ class TelegramPoller:
|
||||
return Command(action="status")
|
||||
if t in ("ss", "screenshot"):
|
||||
return Command(action="ss")
|
||||
if t == "pause":
|
||||
return Command(action="pause")
|
||||
if t == "resume":
|
||||
return Command(action="resume")
|
||||
if t == "resume force":
|
||||
# value=1 signals force: also lift canary drift-pause, not just user pause.
|
||||
return Command(action="resume", value=1)
|
||||
# "3" → set_interval 3 minutes → 180s; "interval 3" also accepted
|
||||
parts = t.split()
|
||||
if len(parts) == 1 and parts[0].isdigit():
|
||||
|
||||
@@ -5,6 +5,9 @@ import tomllib
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
_VALID_WEEKDAYS: tuple[str, ...] = ("MON", "TUE", "WED", "THU", "FRI", "SAT", "SUN")
|
||||
|
||||
DotColor = Literal[
|
||||
"turquoise", "yellow",
|
||||
@@ -97,6 +100,43 @@ class AlertsCfg:
|
||||
trigger: bool = True
|
||||
|
||||
|
||||
@dataclass
class OperatingHoursCfg:
    """Session window: only run detection on allowed weekdays within HH:MM range.

    Timezone is the source of truth for the exchange (default America/New_York
    for NYSE). Start/stop are compared against the clock in that timezone.
    Weekday check uses datetime.weekday() + a fixed MON..SUN list to stay
    locale-independent (strftime('%a') returns localized names).

    The ZoneInfo is cached at config load time so the detection loop doesn't
    pay per-tick lookup cost.

    NOTE: this dataclass is mutable (non-frozen) so Config._from_dict can stash
    the resolved ZoneInfo onto `_tz_cache` after validation. Treat fields as
    read-only at runtime.
    """
    # Master switch; when False the window check is a no-op.
    enabled: bool = False
    # IANA zone name; validated via ZoneInfo at config load when enabled.
    timezone: str = "America/New_York"
    # Uppercase three-letter day names; validated against MON..SUN at load.
    weekdays: tuple[str, ...] = ("MON", "TUE", "WED", "THU", "FRI")
    # Inclusive start / exclusive stop, compared as zero-padded "HH:MM" strings.
    start_hhmm: str = "09:30"
    stop_hhmm: str = "16:00"
    # Populated by Config._from_dict; None for disabled or failed-load cases.
    _tz_cache: ZoneInfo | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class AlertBehaviorCfg:
    """Alert behavior knobs (not screenshot toggles).

    `fire_on_phase_skip`: backstop alert when FSM observes ARMED→light_{green,red}
    directly (skipping the dark prime phase — often means dark color was
    mis-classified as gray). Default True: missing a fire is worse than a noisy
    phase-skip alert. Disable via `[options.alerts] fire_on_phase_skip = false`.
    """
    # Read in _handle_tick's phase_skip branch; defaults on for safety.
    fire_on_phase_skip: bool = True
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Config:
|
||||
window_title: str
|
||||
@@ -117,6 +157,8 @@ class Config:
|
||||
phaseb_timeout_s: int = 600
|
||||
dead_letter_path: str = "logs/dead_letter.jsonl"
|
||||
attach_screenshots: AlertsCfg = field(default_factory=AlertsCfg)
|
||||
alerts: AlertBehaviorCfg = field(default_factory=AlertBehaviorCfg)
|
||||
operating_hours: OperatingHoursCfg = field(default_factory=OperatingHoursCfg)
|
||||
config_version: str = "unknown"
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
@@ -184,6 +226,36 @@ class Config:
|
||||
)
|
||||
else:
|
||||
attach = AlertsCfg()
|
||||
|
||||
alerts_dict = opts.get("alerts", {}) or {}
|
||||
alert_behavior = AlertBehaviorCfg(
|
||||
fire_on_phase_skip=bool(alerts_dict.get("fire_on_phase_skip", True)),
|
||||
)
|
||||
|
||||
oh_dict = opts.get("operating_hours", {}) or {}
|
||||
oh_weekdays = tuple(
|
||||
str(w).upper() for w in oh_dict.get("weekdays", ("MON", "TUE", "WED", "THU", "FRI"))
|
||||
)
|
||||
for wd in oh_weekdays:
|
||||
if wd not in _VALID_WEEKDAYS:
|
||||
raise ValueError(
|
||||
f"operating_hours.weekdays contains invalid day {wd!r}; "
|
||||
f"expected any of {_VALID_WEEKDAYS}"
|
||||
)
|
||||
oh = OperatingHoursCfg(
|
||||
enabled=bool(oh_dict.get("enabled", False)),
|
||||
timezone=str(oh_dict.get("timezone", "America/New_York")),
|
||||
weekdays=oh_weekdays,
|
||||
start_hhmm=str(oh_dict.get("start_hhmm", "09:30")),
|
||||
stop_hhmm=str(oh_dict.get("stop_hhmm", "16:00")),
|
||||
)
|
||||
if oh.enabled:
|
||||
try:
|
||||
oh._tz_cache = ZoneInfo(oh.timezone)
|
||||
except ZoneInfoNotFoundError as exc:
|
||||
raise ValueError(
|
||||
f"operating_hours.timezone {oh.timezone!r} invalid: {exc}"
|
||||
) from exc
|
||||
return cls(
|
||||
window_title=data["window_title"],
|
||||
dot_roi=roi,
|
||||
@@ -203,5 +275,7 @@ class Config:
|
||||
phaseb_timeout_s=int(opts.get("phaseb_timeout_s", 600)),
|
||||
dead_letter_path=opts.get("dead_letter_path", "logs/dead_letter.jsonl"),
|
||||
attach_screenshots=attach,
|
||||
alerts=alert_behavior,
|
||||
operating_hours=oh,
|
||||
config_version=version,
|
||||
)
|
||||
|
||||
636
src/atm/main.py
636
src/atm/main.py
@@ -92,6 +92,23 @@ def main(argv=None) -> None:
|
||||
help="Stop at local HH:MM (overrides --duration). If the time is in "
|
||||
"the past when the loop starts, rolls over to tomorrow.",
|
||||
)
|
||||
p_run.add_argument(
|
||||
"--tz", metavar="ZONE", default=None,
|
||||
help="Override operating_hours.timezone (e.g. America/New_York).",
|
||||
)
|
||||
p_run.add_argument(
|
||||
"--weekdays", metavar="DAYS", default=None,
|
||||
help="Override operating_hours.weekdays. Accepts comma list "
|
||||
"(MON,TUE) or range (MON-FRI).",
|
||||
)
|
||||
p_run.add_argument(
|
||||
"--oh-start", metavar="HH:MM", default=None,
|
||||
help="Override operating_hours.start_hhmm (exchange-local).",
|
||||
)
|
||||
p_run.add_argument(
|
||||
"--oh-stop", metavar="HH:MM", default=None,
|
||||
help="Override operating_hours.stop_hhmm (exchange-local).",
|
||||
)
|
||||
|
||||
# journal
|
||||
p_journal = sub.add_parser("journal", help="Add a trade journal entry interactively")
|
||||
@@ -118,6 +135,16 @@ def main(argv=None) -> None:
|
||||
metavar="PATH", help="Journal JSONL file (default: trades.jsonl)",
|
||||
)
|
||||
|
||||
# validate-calibration
|
||||
p_valid = sub.add_parser(
|
||||
"validate-calibration",
|
||||
help="Offline: run Detector on labeled frames and report PASS/FAIL",
|
||||
)
|
||||
p_valid.add_argument(
|
||||
"label_file", type=Path, metavar="LABEL_FILE",
|
||||
help="JSON array with [{path, expected, note?}, ...] entries",
|
||||
)
|
||||
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
_dispatch = {
|
||||
@@ -128,6 +155,7 @@ def main(argv=None) -> None:
|
||||
"debug": _cmd_debug,
|
||||
"journal": _cmd_journal,
|
||||
"report": _cmd_report,
|
||||
"validate-calibration": _cmd_validate_calibration,
|
||||
}
|
||||
_dispatch[args.command](args)
|
||||
|
||||
@@ -171,6 +199,7 @@ def _cmd_dryrun(args) -> None:
|
||||
|
||||
def _cmd_run(args) -> None:
|
||||
cfg = Config.load_current(Path("configs"))
|
||||
cfg = _apply_operating_hours_cli_overrides(cfg, args)
|
||||
capture_stub = args.capture_stub or bool(os.environ.get("ATM_STUB_CAPTURE"))
|
||||
|
||||
# --start-at HH:MM: sleep until the next occurrence of that local wall-clock time
|
||||
@@ -230,6 +259,66 @@ def _cmd_run(args) -> None:
|
||||
run_live(cfg, duration_s=duration_s, capture_stub=capture_stub)
|
||||
|
||||
|
||||
_WEEKDAY_ORDER = ("MON", "TUE", "WED", "THU", "FRI", "SAT", "SUN")
|
||||
|
||||
|
||||
def _parse_weekdays_arg(raw: str) -> tuple[str, ...]:
|
||||
"""Accept 'MON,TUE,WED' or 'MON-FRI'. Case-insensitive."""
|
||||
txt = raw.strip().upper()
|
||||
if "-" in txt and "," not in txt:
|
||||
a, b = (p.strip() for p in txt.split("-", 1))
|
||||
if a not in _WEEKDAY_ORDER or b not in _WEEKDAY_ORDER:
|
||||
raise ValueError(f"unknown weekday(s) in range {raw!r}")
|
||||
i, j = _WEEKDAY_ORDER.index(a), _WEEKDAY_ORDER.index(b)
|
||||
if i > j:
|
||||
raise ValueError(f"weekday range reversed: {raw!r}")
|
||||
return tuple(_WEEKDAY_ORDER[i : j + 1])
|
||||
days = tuple(d.strip() for d in txt.split(",") if d.strip())
|
||||
for d in days:
|
||||
if d not in _WEEKDAY_ORDER:
|
||||
raise ValueError(f"unknown weekday {d!r} (valid: {_WEEKDAY_ORDER})")
|
||||
return days
|
||||
|
||||
|
||||
def _apply_operating_hours_cli_overrides(cfg, args):
|
||||
"""Return cfg (possibly new) with operating_hours overridden by CLI flags.
|
||||
|
||||
Config is a frozen dataclass, but operating_hours is non-frozen by design
|
||||
so we can tweak it in-place and recompute the tz cache. CLI flags implicitly
|
||||
enable operating_hours even if the TOML had it disabled.
|
||||
"""
|
||||
import dataclasses as _dc
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
oh = cfg.operating_hours
|
||||
any_override = any(
|
||||
getattr(args, k, None)
|
||||
for k in ("tz", "weekdays", "oh_start", "oh_stop")
|
||||
)
|
||||
if not any_override:
|
||||
return cfg
|
||||
|
||||
new_tz = args.tz if args.tz else oh.timezone
|
||||
try:
|
||||
tz_cache = ZoneInfo(new_tz)
|
||||
except ZoneInfoNotFoundError as exc:
|
||||
sys.exit(f"--tz {new_tz!r} invalid: {exc}")
|
||||
|
||||
new_weekdays = _parse_weekdays_arg(args.weekdays) if args.weekdays else oh.weekdays
|
||||
new_start = args.oh_start if args.oh_start else oh.start_hhmm
|
||||
new_stop = args.oh_stop if args.oh_stop else oh.stop_hhmm
|
||||
oh.enabled = True
|
||||
oh.timezone = new_tz
|
||||
oh.weekdays = new_weekdays
|
||||
oh.start_hhmm = new_start
|
||||
oh.stop_hhmm = new_stop
|
||||
oh._tz_cache = tz_cache
|
||||
# Config is frozen but operating_hours is a mutable field object —
|
||||
# mutating it in place is sufficient; no dataclasses.replace needed.
|
||||
_ = _dc # keep import for future use
|
||||
return cfg
|
||||
|
||||
|
||||
def _cmd_journal(args) -> None:
|
||||
try:
|
||||
from atm.journal import Journal, prompt_entry
|
||||
@@ -340,6 +429,37 @@ def _cmd_report(args) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _cmd_validate_calibration(args) -> None:
    """Run offline calibration validation; exit 0 on 100% PASS, 1 otherwise."""
    try:
        from atm.validate import validate_calibration, ValidationError
    except ImportError as exc:
        sys.exit(f"validate module not available: {exc}")

    labels_path = Path(args.label_file)

    # The active config is mandatory; a missing config is a hard CLI error.
    try:
        cfg = Config.load_current(Path("configs"))
    except FileNotFoundError as exc:
        sys.exit(f"config not found: {exc}")

    # Best-effort: resolve the active config's name for the report header.
    config_name = ""
    pointer = Path("configs") / "current.txt"
    try:
        if pointer.exists():
            config_name = pointer.read_text(encoding="utf-8").strip()
    except Exception:
        config_name = ""

    try:
        report = validate_calibration(labels_path, cfg, config_name=config_name)
    except ValidationError as exc:
        # Exit 2: label file missing or malformed JSON.
        print(f"error: {exc}", file=sys.stderr)
        sys.exit(2)

    print(report.render())
    sys.exit(0 if report.all_pass else 1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Live loop
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -415,6 +535,7 @@ def _handle_tick(
|
||||
audit: _AuditLike,
|
||||
first_accepted: bool,
|
||||
snapshot: Snapshot | None = None,
|
||||
cfg: Any = None,
|
||||
) -> Transition | None:
|
||||
"""Feed FSM for a single accepted color and dispatch arm/prime/late_start
|
||||
alerts. Returns the final Transition, or None when the color triggered a
|
||||
@@ -518,6 +639,31 @@ def _handle_tick(
|
||||
image_path=snap(prime_kind, prime_label),
|
||||
direction=direction,
|
||||
))
|
||||
# PHASE_SKIP fire backstop: ARMED→light_{green,red} directly (dark was missed).
|
||||
# Emits a fire-equivalent alert when cfg.alerts.fire_on_phase_skip (default True).
|
||||
# Uses public FSM lockout API (is_locked/record_fire) to reuse the standard
|
||||
# 240s dedupe window so bouncing detectors do not spam the user.
|
||||
elif tr.reason == "phase_skip" and color in ("light_green", "light_red"):
|
||||
flag_on = True
|
||||
if cfg is not None:
|
||||
alerts_cfg = getattr(cfg, "alerts", None)
|
||||
if alerts_cfg is not None:
|
||||
flag_on = bool(getattr(alerts_cfg, "fire_on_phase_skip", True))
|
||||
if flag_on:
|
||||
direction = "BUY" if color == "light_green" else "SELL"
|
||||
if not fsm.is_locked(direction, now):
|
||||
fsm.record_fire(direction, now)
|
||||
dark_name = "dark_green" if direction == "BUY" else "dark_red"
|
||||
notifier.send(Alert(
|
||||
kind="phase_skip_fire",
|
||||
title=f"PHASE SKIP {direction} — {dark_name} nu a fost detectat",
|
||||
body=(
|
||||
"Verifică chart-ul manual. Posibil necalibrare culoare "
|
||||
f"(observat {color} direct după armare)."
|
||||
),
|
||||
image_path=snap("phase_skip", f"phase_skip_{direction.lower()}"),
|
||||
direction=direction,
|
||||
))
|
||||
return tr
|
||||
|
||||
|
||||
@@ -531,6 +677,127 @@ class _TickSyncResult:
|
||||
new_color: str | None = None # corpus sample color when changed
|
||||
|
||||
|
||||
@dataclass
class RunContext:
    """Dependencies passed to module-scope detection-loop helpers.

    Keeps `_run_tick`, `_handle_fsm_result`, `_drain_cmd_queue`, and
    `_dispatch_command` at module scope so they are directly unit-testable
    without reconstructing `run_live_async`.
    """
    cfg: Any
    capture: Callable  # frame grabber; run via asyncio.to_thread off the event loop
    canary: Any  # exposes is_paused / resume() for drift-pause handling
    detector: Any  # exposes .rolling, read by /status
    fsm: Any  # exposes .state, read by /status
    notifier: _NotifierLike
    audit: _AuditLike
    detection_log: _AuditLike
    scheduler: Any  # exposes start()/stop()/is_running/interval_s
    samples_dir: Path
    fires_dir: Path
    cmd_queue: Any  # asyncio.Queue[Command]
    state: Any  # carries first_accepted, last_saved_color, levels_extractor, fire_count, start
    levels_extractor_factory: Callable  # builds LevelsExtractor(cfg, trigger, now)
    lifecycle: Any = None  # LifecycleState — window + user_paused tracking
|
||||
|
||||
|
||||
@dataclass
class _LoopState:
    """Per-loop mutable state (previously closure nonlocals)."""
    first_accepted: bool = True  # True until the first accepted color is consumed
    last_saved_color: str | None = None  # last corpus-sample color saved to samples_dir
    levels_extractor: Any = None  # active LevelsExtractor, or None when idle
    fire_count: int = 0  # fires this session; surfaced by /status as "Semnale"
    start: float = 0.0  # uptime baseline — compared against time.monotonic() in /status
|
||||
|
||||
|
||||
@dataclass
class LifecycleState:
    """Tracks user-pause / out-of-window state across detection ticks.

    last_window_state: None at startup so _maybe_log_transition can seed it
    without emitting a spurious market_open alert on the first in-window tick.
    """
    user_paused: bool = False  # set by /pause, cleared by /resume
    last_window_state: str | None = None  # "open" / "closed" / None (uninitialized)
|
||||
|
||||
|
||||
# Locale-independent weekday names; index matches datetime.weekday() (MON=0).
|
||||
_WEEKDAY_NAMES: tuple[str, ...] = ("MON", "TUE", "WED", "THU", "FRI", "SAT", "SUN")
|
||||
|
||||
|
||||
def _should_skip(now_ts: float, state: LifecycleState, cfg, canary) -> str | None:
|
||||
"""Return a reason string if detection should be skipped, else None.
|
||||
|
||||
Order: user_paused > canary drift > operating-hours window. Uses the
|
||||
ZoneInfo cached on cfg.operating_hours._tz_cache (populated at config load)
|
||||
to avoid per-tick tz lookup cost.
|
||||
"""
|
||||
if state.user_paused:
|
||||
return "user_paused"
|
||||
if getattr(canary, "is_paused", False):
|
||||
return "drift_paused"
|
||||
oh = getattr(cfg, "operating_hours", None)
|
||||
if oh is None or not oh.enabled:
|
||||
return None
|
||||
tz = getattr(oh, "_tz_cache", None)
|
||||
if tz is None:
|
||||
# Enabled but no tz resolved — skip the check rather than crash mid-loop.
|
||||
return None
|
||||
now_exchange = datetime.fromtimestamp(now_ts, tz=tz)
|
||||
# weekday() = 0..6 (MON..SUN). Locale-free; strftime('%a') is not.
|
||||
if _WEEKDAY_NAMES[now_exchange.weekday()] not in oh.weekdays:
|
||||
return "out_of_window_weekend"
|
||||
hhmm = now_exchange.strftime("%H:%M")
|
||||
if hhmm < oh.start_hhmm or hhmm >= oh.stop_hhmm:
|
||||
return "out_of_window_hours"
|
||||
return None
|
||||
|
||||
|
||||
def _maybe_log_transition(
|
||||
reason: str | None,
|
||||
state: LifecycleState,
|
||||
now: float,
|
||||
audit: _AuditLike,
|
||||
notifier: _NotifierLike,
|
||||
) -> None:
|
||||
"""Log market_open / market_closed exactly once per transition.
|
||||
|
||||
Startup guard (R2): when last_window_state is None we just seed it; no
|
||||
alert/audit event is emitted for the initial evaluation. This prevents a
|
||||
spurious market_open alert when run_live_async starts in-window.
|
||||
"""
|
||||
if reason is None:
|
||||
window_reason = "open"
|
||||
elif reason.startswith("out_of_window"):
|
||||
window_reason = "closed"
|
||||
else:
|
||||
# user_paused / drift_paused don't change market window state
|
||||
return
|
||||
|
||||
if window_reason == state.last_window_state:
|
||||
return
|
||||
|
||||
if state.last_window_state is None:
|
||||
state.last_window_state = window_reason
|
||||
return
|
||||
|
||||
event_name = "market_open" if window_reason == "open" else "market_closed"
|
||||
audit.log({"ts": now, "event": event_name, "reason": reason})
|
||||
body = (
|
||||
"Piața închisă — monitorizare pauzată până la următoarea deschidere"
|
||||
if event_name == "market_closed"
|
||||
else "Piața deschisă — monitorizare reluată"
|
||||
)
|
||||
notifier.send(Alert(
|
||||
kind="status",
|
||||
title=event_name.replace("_", " ").title(),
|
||||
body=body,
|
||||
))
|
||||
state.last_window_state = window_reason
|
||||
|
||||
|
||||
def _sync_detection_tick(
|
||||
capture: Callable,
|
||||
canary: Any,
|
||||
@@ -584,7 +851,7 @@ def _sync_detection_tick(
|
||||
canary_ok=True,
|
||||
)
|
||||
|
||||
tr = _handle_tick(fsm, res.color, now, notifier, audit, is_first, snapshot=_snapshot)
|
||||
tr = _handle_tick(fsm, res.color, now, notifier, audit, is_first, snapshot=_snapshot, cfg=cfg)
|
||||
|
||||
if tr is None:
|
||||
return _TickSyncResult(frame=frame, res=res, first_consumed=is_first, late_start=True)
|
||||
@@ -622,6 +889,208 @@ def _sync_detection_tick(
|
||||
)
|
||||
|
||||
|
||||
async def _run_tick(ctx: RunContext) -> _TickSyncResult:
    """Execute one `_sync_detection_tick` in a thread; returns result or empty.

    Lifecycle gating (user pause / operating hours / drift) happens here, not
    inside the sync tick, so the async loop can still drain commands and emit
    market_open / market_closed transitions even when the heavy detection
    work is skipped.
    """
    now = time.time()
    if ctx.lifecycle is not None:
        skip = _should_skip(now, ctx.lifecycle, ctx.cfg, ctx.canary)
        # Evaluate the transition logger unconditionally so open/closed
        # edges are observed exactly once, even on non-skipped ticks.
        _maybe_log_transition(skip, ctx.lifecycle, now, ctx.audit, ctx.notifier)
        if skip is not None:
            # No detection this tick. Empty result → _handle_fsm_result no-op.
            return _TickSyncResult()
    # Heavy capture+detect work runs off the event loop in a worker thread.
    # NOTE(review): argument order must match _sync_detection_tick's signature.
    return await asyncio.to_thread(
        _sync_detection_tick,
        ctx.capture, ctx.canary, ctx.cfg, ctx.detector, ctx.fsm,
        ctx.notifier, ctx.audit, ctx.detection_log,
        ctx.fires_dir, ctx.state.first_accepted, ctx.state.last_saved_color,
        now, ctx.samples_dir,
    )
|
||||
|
||||
|
||||
async def _handle_fsm_result(ctx: RunContext, result: _TickSyncResult) -> None:
    """Scheduler start/stop + levels extraction. No-op if res is None/late_start."""
    # Commit per-tick bookkeeping first — these apply even for late_start results.
    if result.first_consumed:
        ctx.state.first_accepted = False
    if result.new_color is not None:
        ctx.state.last_saved_color = result.new_color

    tr = result.tr
    res = result.res

    if result.late_start or res is None:
        return

    # Scheduler lifecycle: start on prime, stop once the cycle resolves.
    if tr is not None and getattr(res, "accepted", False) and getattr(res, "color", None):
        if tr.reason == "prime" and not ctx.scheduler.is_running:
            ctx.scheduler.start(ctx.cfg.telegram.auto_poll_interval_s)
            ctx.audit.log({"ts": time.time(), "event": "scheduler_started", "reason": "primed"})
        elif tr.reason in ("fire", "cooled", "phase_skip", "opposite_rearm") and ctx.scheduler.is_running:
            ctx.scheduler.stop()
            ctx.audit.log({"ts": time.time(), "event": "scheduler_stopped", "reason": tr.reason})

    # Unlocked trigger: count the fire, stop any polling, start levels extraction.
    if tr is not None and tr.trigger and not tr.locked:
        ctx.state.fire_count += 1
        if ctx.scheduler.is_running:
            ctx.scheduler.stop()
            ctx.audit.log({"ts": time.time(), "event": "scheduler_stopped", "reason": "fire"})
        ctx.state.levels_extractor = ctx.levels_extractor_factory(ctx.cfg, tr.trigger, time.time())

    # Feed the active levels extractor; alert on completion, discard on timeout.
    if ctx.state.levels_extractor is not None and result.frame is not None:
        lr = ctx.state.levels_extractor.step(result.frame, time.time())
        if lr.status in ("complete", "timeout"):
            if lr.status == "complete" and lr.levels:
                ctx.notifier.send(Alert(
                    kind="levels",
                    title="Niveluri",
                    body=(
                        f"SL={lr.levels.sl} "
                        f"TP1={lr.levels.tp1} "
                        f"TP2={lr.levels.tp2}"
                    ),
                ))
            # Extractor is one-shot: clear it whether complete or timed out.
            ctx.state.levels_extractor = None
|
||||
|
||||
|
||||
async def _dispatch_command(ctx: RunContext, cmd) -> None:
    """Process a single Command. Exceptions bubble — caller wraps in try/except."""
    cfg = ctx.cfg
    if cmd.action == "set_interval":
        # Explicit interval wins; otherwise fall back to the configured default.
        secs = cmd.value or cfg.telegram.auto_poll_interval_s
        ctx.scheduler.start(secs)
        ctx.audit.log({"ts": time.time(), "event": "scheduler_started", "reason": "set_interval", "interval_s": secs})
        ctx.notifier.send(Alert(kind="status", title=f"Polling activ — interval {secs // 60} min", body=""))
    elif cmd.action == "stop":
        if ctx.scheduler.is_running:
            ctx.scheduler.stop()
            ctx.audit.log({"ts": time.time(), "event": "scheduler_stopped", "reason": "command_stop"})
            ctx.notifier.send(Alert(kind="status", title="Polling oprit", body=""))
        else:
            ctx.notifier.send(Alert(kind="status", title="Polling nu este activ", body=""))
    elif cmd.action == "status":
        # Compose a one-message health summary from loop / detector / scheduler state.
        uptime_s = time.monotonic() - ctx.state.start
        last_roll = ctx.detector.rolling[-1] if ctx.detector.rolling else None
        last_conf = f"{last_roll.match.confidence:.2f}" if last_roll and last_roll.match else "—"
        last_color = (
            (last_roll.color or last_roll.match.name) if last_roll and last_roll.match else "—"
        ) if last_roll else "—"
        sched_info = (
            f"activ @{ctx.scheduler.interval_s // 60}min" if ctx.scheduler.interval_s else "activ"
        ) if ctx.scheduler.is_running else "oprit"
        canary_info = "drift (pauze)" if ctx.canary.is_paused else "ok"

        # Active / pause reason + window state
        active_info = "activ"
        window_info = "—"
        if ctx.lifecycle is not None:
            skip = _should_skip(time.time(), ctx.lifecycle, ctx.cfg, ctx.canary)
            if skip is not None:
                active_info = f"pauzat:{skip}"
            oh = getattr(ctx.cfg, "operating_hours", None)
            if oh is not None and oh.enabled:
                window_info = ctx.lifecycle.last_window_state or "—"
            else:
                window_info = "always_on"

        body = (
            f"Stare: {ctx.fsm.state.value}\n"
            f"Activ: {active_info} | Fereastră: {window_info}\n"
            f"Ultima detecție: {last_color} (conf {last_conf})\n"
            f"Uptime: {uptime_s / 3600:.1f}h | Semnale: {ctx.state.fire_count}\n"
            f"Poller: {sched_info} | Canary: {canary_info}"
        )
        ctx.notifier.send(Alert(kind="status", title="ATM Status", body=body))
    elif cmd.action == "ss":
        # Manual screenshot: capture + annotate off the event loop.
        now_ss = time.time()
        frame_ss = await asyncio.to_thread(ctx.capture)
        if frame_ss is None:
            ctx.notifier.send(Alert(
                kind="warn",
                title="Captură eșuată — verificați fereastra TradeStation",
                body="",
            ))
            return
        path_ss = await asyncio.to_thread(
            _save_annotated_frame, frame_ss, ctx.cfg, ctx.fires_dir, "ss", now_ss, ctx.audit,
        )
        ctx.audit.log({"ts": now_ss, "event": "screenshot_sent", "path": str(path_ss) if path_ss else None})
        ctx.notifier.send(Alert(kind="screenshot", title="Screenshot manual", body="", image_path=path_ss))
    elif cmd.action == "pause":
        # User manually stops monitoring. Canary drift state is untouched.
        if ctx.lifecycle is not None:
            ctx.lifecycle.user_paused = True
        ctx.audit.log({"ts": time.time(), "event": "user_paused"})
        ctx.notifier.send(Alert(
            kind="status",
            title="Monitorizare oprită manual",
            body="Folosește /resume pentru a relua.",
        ))
    elif cmd.action == "resume":
        # R2: /resume clears only user_paused. Canary drift requires
        # /resume force (value == 1) so the user acknowledges the risk.
        was_drift = bool(getattr(ctx.canary, "is_paused", False))
        was_user = bool(ctx.lifecycle.user_paused) if ctx.lifecycle is not None else False
        force = cmd.value == 1
        if ctx.lifecycle is not None:
            ctx.lifecycle.user_paused = False
        if force and was_drift:
            ctx.canary.resume()
        ctx.audit.log({
            "ts": time.time(), "event": "user_resumed",
            "was_drift": was_drift, "was_user": was_user, "force": force,
        })
        # Adaptive response
        if was_drift and not force:
            title = "Pauză user eliminată — dar Canary drift activ"
            body = (
                "Trimite /resume force pentru a anula drift-pause. "
                "Recalibrează dacă driftul persistă."
            )
        elif force and was_drift:
            title = "Drift-pause anulat manual (force)"
            body = "Dacă driftul persistă, Canary va repauza."
        else:
            # Plain resume: warn if the market window is closed right now.
            skip_now = None
            if ctx.lifecycle is not None:
                skip_now = _should_skip(time.time(), ctx.lifecycle, ctx.cfg, ctx.canary)
            if skip_now and skip_now.startswith("out_of_window"):
                title = "Pauză eliminată — piața e închisă acum"
                body = "Monitorizarea va porni la următoarea fereastră."
            else:
                title = "Monitorizare reluată"
                body = ""
        ctx.notifier.send(Alert(kind="status", title=title, body=body))
|
||||
|
||||
|
||||
async def _drain_cmd_queue(ctx: RunContext) -> None:
    """Pump every queued Telegram command through `_dispatch_command`.

    Runs until the queue is empty; each dispatch is wrapped in its own
    try/except so one failing command cannot drop the commands behind it.

    CRITICAL: the main loop must call this on every iteration, without
    exception — even when the detection tick produced nothing (canary
    paused, out-of-window, etc.). A previous version `continue`'d past the
    drain when res=None, so commands accumulated while canary was drifted.
    """
    while True:
        try:
            pending = ctx.cmd_queue.get_nowait()
        except asyncio.QueueEmpty:
            # Queue exhausted — done for this loop iteration.
            return
        try:
            await _dispatch_command(ctx, pending)
        except Exception as err:
            # Isolate the failure: audit it, print it, alert, keep draining.
            ctx.audit.log({
                "ts": time.time(), "event": "command_error",
                "action": pending.action, "error": str(err),
            })
            print(f"ERR command_dispatch /{pending.action}: {err}", flush=True)
            ctx.notifier.send(Alert(kind="warn", title=f"Eroare comandă /{pending.action}", body=str(err)))
|
||||
|
||||
|
||||
def run_live(cfg, duration_s=None, capture_stub: bool = False) -> None:
    """Blocking entry point: spin up an event loop and run the async monitor."""
    coro = run_live_async(cfg, duration_s=duration_s, capture_stub=capture_stub)
    asyncio.run(coro)
|
||||
@@ -645,8 +1114,30 @@ async def run_live_async(cfg, duration_s=None, capture_stub: bool = False) -> No
|
||||
capture = _build_capture(cfg, capture_stub=capture_stub)
|
||||
detector = Detector(cfg, capture)
|
||||
fsm = StateMachine(lockout_s=cfg.lockout_s)
|
||||
canary = Canary(cfg, pause_flag_path=Path("logs/pause.flag"))
|
||||
audit = AuditLog(Path("logs"))
|
||||
|
||||
# Forward-declare notifier so the canary pause callback can close over it.
|
||||
# The notifier is constructed a few lines below once backends exist.
|
||||
_notifier_ref: dict = {}
|
||||
|
||||
def _on_canary_pause(distance: int) -> None:
|
||||
audit.log({"ts": time.time(), "event": "canary_drift_paused", "distance": distance})
|
||||
n = _notifier_ref.get("n")
|
||||
if n is not None:
|
||||
n.send(Alert(
|
||||
kind="warn",
|
||||
title=f"Canary drift={distance} — monitorizare pauzată",
|
||||
body=(
|
||||
"Fereastra/paleta s-a schimbat. Trimite /resume pentru a relua "
|
||||
"sau recalibrează."
|
||||
),
|
||||
))
|
||||
|
||||
canary = Canary(
|
||||
cfg,
|
||||
pause_flag_path=Path("logs/pause.flag"),
|
||||
on_pause_callback=_on_canary_pause,
|
||||
)
|
||||
detection_log = AuditLog(Path("logs/detections"))
|
||||
backends = [
|
||||
DiscordNotifier(cfg.discord.webhook_url),
|
||||
@@ -663,6 +1154,7 @@ async def run_live_async(cfg, duration_s=None, capture_stub: bool = False) -> No
|
||||
})
|
||||
|
||||
notifier = FanoutNotifier(backends, Path(cfg.dead_letter_path), on_drop=_on_drop)
|
||||
_notifier_ref["n"] = notifier
|
||||
|
||||
# Initial frame + canary check
|
||||
first_frame = capture()
|
||||
@@ -699,10 +1191,8 @@ async def run_live_async(cfg, duration_s=None, capture_stub: bool = False) -> No
|
||||
pass
|
||||
|
||||
cmd_queue: asyncio.Queue[Command] = asyncio.Queue()
|
||||
first_accepted = True
|
||||
last_saved_color: str | None = None
|
||||
levels_extractor = None
|
||||
fire_count = 0
|
||||
loop_state = _LoopState(first_accepted=True, last_saved_color=None,
|
||||
levels_extractor=None, fire_count=0, start=start)
|
||||
|
||||
def _bound_save(frame: Any, label: str, now: float) -> "Path | None":
|
||||
return _save_annotated_frame(frame, cfg, fires_dir, label, now, audit=audit)
|
||||
@@ -715,8 +1205,23 @@ async def run_live_async(cfg, duration_s=None, capture_stub: bool = False) -> No
|
||||
)
|
||||
poller = TelegramPoller(cfg.telegram, cmd_queue, audit)
|
||||
|
||||
lifecycle = LifecycleState()
|
||||
# Seed lifecycle.last_window_state with the current status so we don't emit
|
||||
# a spurious market_open alert on the very first tick (R2).
|
||||
_pre_skip = _should_skip(time.time(), lifecycle, cfg, canary)
|
||||
_maybe_log_transition(_pre_skip, lifecycle, time.time(), audit, notifier)
|
||||
|
||||
ctx = RunContext(
|
||||
cfg=cfg, capture=capture, canary=canary, detector=detector, fsm=fsm,
|
||||
notifier=notifier, audit=audit, detection_log=detection_log,
|
||||
scheduler=scheduler, samples_dir=samples_dir, fires_dir=fires_dir,
|
||||
cmd_queue=cmd_queue, state=loop_state,
|
||||
levels_extractor_factory=lambda _cfg, trigger, now_ts: LevelsExtractor(_cfg, trigger, now_ts),
|
||||
lifecycle=lifecycle,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Nested async coroutines — capture nonlocal state from run_live_async
|
||||
# Nested async coroutines — heartbeat captures notifier + heartbeat_due
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _heartbeat_loop() -> None:
|
||||
@@ -738,124 +1243,13 @@ async def run_live_async(cfg, duration_s=None, capture_stub: bool = False) -> No
|
||||
notifier.send(Alert(kind="heartbeat", title="activ", body="încredere ok"))
|
||||
heartbeat_due = time.monotonic() + cfg.heartbeat_min * 60
|
||||
|
||||
async def _dispatch_command(cmd: Command) -> None:
|
||||
nonlocal fire_count
|
||||
if cmd.action == "set_interval":
|
||||
secs = cmd.value or cfg.telegram.auto_poll_interval_s
|
||||
scheduler.start(secs)
|
||||
audit.log({"ts": time.time(), "event": "scheduler_started", "reason": "set_interval", "interval_s": secs})
|
||||
notifier.send(Alert(kind="status", title=f"Polling activ — interval {secs // 60} min", body=""))
|
||||
elif cmd.action == "stop":
|
||||
if scheduler.is_running:
|
||||
scheduler.stop()
|
||||
audit.log({"ts": time.time(), "event": "scheduler_stopped", "reason": "command_stop"})
|
||||
notifier.send(Alert(kind="status", title="Polling oprit", body=""))
|
||||
else:
|
||||
notifier.send(Alert(kind="status", title="Polling nu este activ", body=""))
|
||||
elif cmd.action == "status":
|
||||
uptime_s = time.monotonic() - start
|
||||
last_roll = detector.rolling[-1] if detector.rolling else None
|
||||
last_conf = f"{last_roll.match.confidence:.2f}" if last_roll and last_roll.match else "—"
|
||||
last_color = (
|
||||
(last_roll.color or last_roll.match.name) if last_roll and last_roll.match else "—"
|
||||
) if last_roll else "—"
|
||||
sched_info = (
|
||||
f"activ @{scheduler.interval_s // 60}min" if scheduler.interval_s else "activ"
|
||||
) if scheduler.is_running else "oprit"
|
||||
canary_info = "drift (pauze)" if canary.is_paused else "ok"
|
||||
body = (
|
||||
f"Stare: {fsm.state.value}\n"
|
||||
f"Ultima detecție: {last_color} (conf {last_conf})\n"
|
||||
f"Uptime: {uptime_s / 3600:.1f}h | Semnale: {fire_count}\n"
|
||||
f"Poller: {sched_info} | Canary: {canary_info}"
|
||||
)
|
||||
notifier.send(Alert(kind="status", title="ATM Status", body=body))
|
||||
elif cmd.action == "ss":
|
||||
now_ss = time.time()
|
||||
frame_ss = await asyncio.to_thread(capture)
|
||||
if frame_ss is None:
|
||||
notifier.send(Alert(
|
||||
kind="warn",
|
||||
title="Captură eșuată — verificați fereastra TradeStation",
|
||||
body="",
|
||||
))
|
||||
return
|
||||
path_ss = await asyncio.to_thread(
|
||||
_save_annotated_frame, frame_ss, cfg, fires_dir, "ss", now_ss, audit,
|
||||
)
|
||||
audit.log({"ts": now_ss, "event": "screenshot_sent", "path": str(path_ss) if path_ss else None})
|
||||
notifier.send(Alert(kind="screenshot", title="Screenshot manual", body="", image_path=path_ss))
|
||||
|
||||
async def _detection_loop() -> None:
|
||||
nonlocal first_accepted, last_saved_color, levels_extractor, fire_count
|
||||
|
||||
while True:
|
||||
if duration_s is not None and (time.monotonic() - start) >= duration_s:
|
||||
break
|
||||
|
||||
now = time.time()
|
||||
|
||||
result: _TickSyncResult = await asyncio.to_thread(
|
||||
_sync_detection_tick,
|
||||
capture, canary, cfg, detector, fsm, notifier, audit, detection_log,
|
||||
fires_dir, first_accepted, last_saved_color, now, samples_dir,
|
||||
)
|
||||
|
||||
if result.first_consumed:
|
||||
first_accepted = False
|
||||
if result.new_color is not None:
|
||||
last_saved_color = result.new_color
|
||||
|
||||
tr = result.tr
|
||||
res = result.res
|
||||
|
||||
if result.late_start or res is None:
|
||||
await asyncio.sleep(cfg.loop_interval_s)
|
||||
continue
|
||||
|
||||
if tr is not None and res.accepted and res.color:
|
||||
if tr.reason == "prime" and not scheduler.is_running:
|
||||
scheduler.start(cfg.telegram.auto_poll_interval_s)
|
||||
audit.log({"ts": time.time(), "event": "scheduler_started", "reason": "primed"})
|
||||
elif tr.reason in ("fire", "cooled", "phase_skip", "opposite_rearm") and scheduler.is_running:
|
||||
scheduler.stop()
|
||||
audit.log({"ts": time.time(), "event": "scheduler_stopped", "reason": tr.reason})
|
||||
|
||||
if tr is not None and tr.trigger and not tr.locked:
|
||||
fire_count += 1
|
||||
if scheduler.is_running:
|
||||
scheduler.stop()
|
||||
audit.log({"ts": time.time(), "event": "scheduler_stopped", "reason": "fire"})
|
||||
levels_extractor = LevelsExtractor(cfg, tr.trigger, now)
|
||||
|
||||
if levels_extractor is not None and result.frame is not None:
|
||||
lr = levels_extractor.step(result.frame, now)
|
||||
if lr.status in ("complete", "timeout"):
|
||||
if lr.status == "complete" and lr.levels:
|
||||
notifier.send(Alert(
|
||||
kind="levels",
|
||||
title="Niveluri",
|
||||
body=(
|
||||
f"SL={lr.levels.sl} "
|
||||
f"TP1={lr.levels.tp1} "
|
||||
f"TP2={lr.levels.tp2}"
|
||||
),
|
||||
))
|
||||
levels_extractor = None
|
||||
|
||||
while True:
|
||||
try:
|
||||
cmd = cmd_queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
try:
|
||||
await _dispatch_command(cmd)
|
||||
except Exception as _cmd_exc:
|
||||
_msg = f"/{cmd.action}: {_cmd_exc}"
|
||||
audit.log({"ts": time.time(), "event": "command_error", "action": cmd.action, "error": str(_cmd_exc)})
|
||||
print(f"ERR command_dispatch {_msg}", flush=True)
|
||||
notifier.send(Alert(kind="warn", title=f"Eroare comandă /{cmd.action}", body=str(_cmd_exc)))
|
||||
|
||||
result = await _run_tick(ctx)
|
||||
await _handle_fsm_result(ctx, result)
|
||||
await _drain_cmd_queue(ctx) # UNCONDITIONAL — fix for command hang
|
||||
await asyncio.sleep(cfg.loop_interval_s)
|
||||
|
||||
# Launch background tasks
|
||||
|
||||
@@ -232,3 +232,20 @@ class StateMachine:
|
||||
if last is None:
|
||||
return False
|
||||
return (ts - last) < self._lockout_s
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public lockout API — used by fire_on_phase_skip handler outside the
|
||||
# FSM. Mirrors _is_locked / _last_fire without leaking private attrs.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
    def is_locked(self, direction: str, ts: float) -> bool:
        """True if a FIRE in `direction` at ts would be within the lockout window.

        Public façade over the private `_is_locked` check so backstop code
        (e.g. the fire_on_phase_skip handler) never reaches into FSM internals.
        """
        return self._is_locked(direction, ts)
|
||||
|
||||
    def record_fire(self, direction: str, ts: float) -> None:
        """Mark a FIRE for `direction` at ts, starting the lockout timer.

        Used by backstop handlers (e.g. fire_on_phase_skip) that emit a
        fire-equivalent alert without going through the natural FSM path.
        Subsequent `is_locked(direction, t)` calls return True while
        t - ts stays below the configured lockout.
        """
        # Per-direction timestamp map; only the most recent fire matters.
        self._last_fire[direction] = ts
|
||||
|
||||
229
src/atm/validate.py
Normal file
229
src/atm/validate.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""Offline calibration validation: run Detector on labeled frames, report PASS/FAIL.
|
||||
|
||||
Used by the `atm validate-calibration` subcommand. Reports per-sample detection
|
||||
results against expected labels, and for failures, computes RGB distance to
|
||||
each color threshold and emits tuning suggestions.
|
||||
|
||||
Reuses `Detector.step(frame)` - does NOT reimplement color classification.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .config import Config
|
||||
|
||||
|
||||
@dataclass
class SampleRecord:
    """Outcome of validating one labeled frame against the calibration.

    One record per entry in the labels file; rendered into the PASS/FAIL
    report by ValidationReport.render().
    """

    path: str  # filesystem path of the sample image
    expected: str  # label this frame is supposed to classify as
    detected: str | None  # Detector verdict; None when nothing matched
    confidence: float  # Detector match confidence (0.0 when no match)
    rgb: tuple[int, int, int] | None  # observed pixel RGB, if a frame was read
    top3: list[tuple[str, float]]  # [(name, score), ...] ranked by RGB distance
    passed: bool  # True iff detected == expected
    note: str = ""  # free-form note copied from the label entry
    error: str | None = None  # non-None if frame load failed / schema bad
|
||||
|
||||
|
||||
@dataclass
class ValidationReport:
    """Aggregated validation outcome over all labeled samples."""

    records: list[SampleRecord] = field(default_factory=list)
    config_name: str = ""

    @property
    def total(self) -> int:
        """Number of samples examined."""
        return len(self.records)

    @property
    def passed(self) -> int:
        """Count of samples whose detection matched the expected label."""
        return len([r for r in self.records if r.passed])

    @property
    def failed(self) -> int:
        """Count of samples that mismatched or errored."""
        return self.total - self.passed

    @property
    def all_pass(self) -> bool:
        """True only when at least one sample ran and none failed."""
        return self.failed == 0 and self.total > 0

    def render(self) -> str:
        """Format the human-readable PASS/FAIL report as one string."""
        out: list[str] = []

        header = f"Testing {self.total} frames"
        if self.config_name:
            header = f"{header} against config {self.config_name}"
        out.append(header + "...")
        out.append("")

        # Per-sample lines: load errors short-circuit to a FAIL + error line.
        for rec in self.records:
            display = Path(rec.path).name or rec.path
            if rec.error:
                out.append(f" [FAIL] {display}")
                out.append(f" error: {rec.error}")
                continue
            verdict = "PASS" if rec.passed else "FAIL"
            rgb_part = "RGB n/a" if rec.rgb is None else f"RGB {rec.rgb}"
            shown = "none" if rec.detected is None else rec.detected
            out.append(f" [{verdict}] {display}")
            out.append(
                f" expected={rec.expected} detected={shown} "
                f"(conf {rec.confidence:.2f}, {rgb_part})"
            )
            if rec.top3 and not rec.passed:
                ranked = " ".join(f"{n}({c:.2f})" for n, c in rec.top3)
                out.append(f" Top 3 candidates: {ranked}")

        out.append("")
        pct = self.passed / self.total * 100.0 if self.total else 0.0
        out.append(f"SUMMARY: {self.passed}/{self.total} PASS ({pct:.0f}%)")

        failures = [r for r in self.records if not r.passed]
        if failures:
            out.append("FAILED:")
            for rec in failures:
                display = Path(rec.path).name or rec.path
                if rec.error:
                    out.append(f" - {display}: {rec.error}")
                    continue
                shown = "none" if rec.detected is None else rec.detected
                out.append(f" - {display}: expected {rec.expected}, got {shown}")

        # Tuning suggestions ride on a private attribute set by the validator.
        suggestions = [
            r._suggestion  # type: ignore[attr-defined]
            for r in failures
            if getattr(r, "_suggestion", "")
        ]
        if suggestions:
            out.append("")
            out.append("SUGGESTIONS:")
            out.extend(f" - {s}" for s in suggestions)

        return "\n".join(out)

    def __str__(self) -> str:
        return self.render()
|
||||
|
||||
|
||||
class ValidationError(Exception):
    """Signals an unusable label file: absent on disk, bad JSON, or wrong schema."""
|
||||
|
||||
|
||||
def _rgb_distance(a: tuple[int, int, int], b: tuple[int, int, int]) -> float:
|
||||
return math.sqrt(sum((a[i] - b[i]) ** 2 for i in range(3)))
|
||||
|
||||
|
||||
def _load_labels(label_file: Path) -> list[dict[str, Any]]:
|
||||
if not label_file.exists():
|
||||
raise ValidationError(f"label file not found: {label_file}")
|
||||
try:
|
||||
data = json.loads(label_file.read_text(encoding="utf-8"))
|
||||
except json.JSONDecodeError as exc:
|
||||
raise ValidationError(f"invalid JSON in {label_file}: {exc}") from exc
|
||||
if not isinstance(data, list):
|
||||
raise ValidationError(
|
||||
f"label file must be a JSON array; got {type(data).__name__}"
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
def validate_calibration(
    label_file: Path,
    cfg: Config,
    config_name: str = "",
) -> ValidationReport:
    """Run Detector on each labeled frame; return a ValidationReport.

    Reuses `Detector.step(frame)` — does NOT reimplement classification.
    Frames are loaded via cv2.imread; a sample whose frame cannot be read
    (or whose label entry is incomplete) becomes a FAIL record with `error`
    set rather than aborting the whole run.

    Raises ValidationError if the label file is missing or malformed.
    """
    import cv2  # local import keeps module import cheap
    from .detector import Detector

    entries = _load_labels(label_file)
    report = ValidationReport(config_name=config_name)

    # Reference palette for distance ranking; "background" is excluded
    # because it is never a valid detection label.
    palette = {
        name: spec.rgb
        for name, spec in cfg.colors.items()
        if name != "background"
    }

    # NOTE(review): a single Detector instance is reused across all samples.
    # If Detector.step keeps rolling history between calls, earlier frames
    # could bleed into later verdicts — confirm, or re-create per sample.
    detector = Detector(cfg=cfg, capture=lambda: None)

    for entry in entries:
        path = str(entry.get("path", ""))
        expected = str(entry.get("expected", ""))
        note = str(entry.get("note", ""))

        # Schema guard: both fields are mandatory per the label schema.
        if not path or not expected:
            rec = SampleRecord(
                path=path, expected=expected, detected=None, confidence=0.0,
                rgb=None, top3=[], passed=False, note=note,
                error="missing 'path' or 'expected' field",
            )
            rec._suggestion = ""  # type: ignore[attr-defined]
            report.records.append(rec)
            continue

        frame = cv2.imread(path)
        if frame is None:
            # cv2.imread returns None (no exception) for unreadable files.
            rec = SampleRecord(
                path=path, expected=expected, detected=None, confidence=0.0,
                rgb=None, top3=[], passed=False, note=note,
                error=f"cv2.imread failed for {path}",
            )
            rec._suggestion = ""  # type: ignore[attr-defined]
            report.records.append(rec)
            continue

        # ts=0.0: offline frames carry no wall-clock timestamp — presumably
        # Detector ignores ts for classification; verify if debouncing is added.
        result = detector.step(ts=0.0, frame=frame)

        match = result.match
        if match is None:
            detected: str | None = None
            confidence = 0.0
        else:
            # UNKNOWN is normalized to "no detection" for label comparison.
            detected = match.name if match.name != "UNKNOWN" else None
            confidence = match.confidence

        rgb = result.rgb

        # Top 3 candidates: rank palette entries by RGB distance to observed.
        top3: list[tuple[str, float]] = []
        if rgb is not None:
            scored: list[tuple[str, float]] = []
            for name, ref in palette.items():
                scored.append((name, _rgb_distance(rgb, ref)))
            scored.sort(key=lambda t: t[1])
            # Map distance d to a pseudo-confidence in (0, 1]: d=0 → 1.0,
            # halving roughly every 20 RGB units.
            top3 = [(n, 1.0 / (1.0 + d / 20.0)) for n, d in scored[:3]]

        passed = detected == expected

        rec = SampleRecord(
            path=path, expected=expected, detected=detected,
            confidence=confidence, rgb=rgb, top3=top3, passed=passed, note=note,
        )

        # For failures with a known expected color, attach a tuning hint
        # (Romanian, matching the project's user-facing language).
        if not passed and rgb is not None and expected in palette:
            ref = palette[expected]
            tol = cfg.colors[expected].tolerance
            dist = _rgb_distance(rgb, ref)
            rec._suggestion = (  # type: ignore[attr-defined]
                f"{expected} praguri curente: RGB{ref} +/- {tol:.0f}. "
                f"Pixelul observat {rgb} e la distanta {dist:.1f} "
                f"-> recalibreaza cu acest sample."
            )
        else:
            rec._suggestion = ""  # type: ignore[attr-defined]

        report.records.append(rec)

    return report
|
||||
@@ -140,6 +140,52 @@ def test_pause_file_written(tmp_path: Path) -> None:
|
||||
assert flag.exists()
|
||||
|
||||
|
||||
def test_canary_pause_callback_fires_once() -> None:
    """Callback is edge-triggered: one call per not_paused→paused transition."""
    cfg = _cfg_with_baseline(BASELINE_FRAME)
    seen: list[int] = []

    canary = Canary(cfg, on_pause_callback=seen.append)

    canary.check(DRIFTED_FRAME)   # edge → exactly one callback
    canary.check(DRIFTED_FRAME)   # still paused → suppressed
    canary.check(BASELINE_FRAME)  # clean frame, pause latched → suppressed

    assert len(seen) == 1
    assert seen[0] > 0  # reported drift distance must be positive
|
||||
|
||||
|
||||
def test_canary_resume_allows_new_pause_notification() -> None:
    """resume() re-arms the edge: the next drift pause must notify again."""
    cfg = _cfg_with_baseline(BASELINE_FRAME)
    seen: list[int] = []

    canary = Canary(cfg, on_pause_callback=seen.append)

    canary.check(DRIFTED_FRAME)
    assert len(seen) == 1

    canary.resume()
    canary.check(DRIFTED_FRAME)  # fresh not_paused→paused edge

    assert len(seen) == 2
|
||||
|
||||
|
||||
def test_canary_pause_callback_exception_does_not_crash_check() -> None:
    """canary.check must survive a raising callback — detection can't die on notify."""
    cfg = _cfg_with_baseline(BASELINE_FRAME)

    def _failing_callback(_distance: int) -> None:
        raise RuntimeError("notifier down")

    canary = Canary(cfg, on_pause_callback=_failing_callback)

    # check() swallows (and logs) the callback error instead of propagating.
    result = canary.check(DRIFTED_FRAME)
    assert result.paused is True
    assert canary.is_paused is True
|
||||
|
||||
|
||||
def test_resume_deletes_pause_file(tmp_path: Path) -> None:
|
||||
"""resume() deletes the pause flag file."""
|
||||
flag = tmp_path / "paused.flag"
|
||||
|
||||
45
tests/test_commands.py
Normal file
45
tests/test_commands.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Tests for atm.commands — /pause /resume parsing (Commit 5)."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from atm.commands import Command, TelegramPoller
|
||||
|
||||
|
||||
def _make_poller() -> TelegramPoller:
    """Poller wired to mocks — just enough config for _parse_command tests."""
    cfg = MagicMock(
        bot_token="tok",
        chat_id="123",
        allowed_chat_ids=("123",),
        poll_timeout_s=1,
    )
    return TelegramPoller(cfg, MagicMock(), MagicMock())
|
||||
|
||||
|
||||
def test_parse_pause():
    """Both bare and slash-prefixed forms map to the pause action."""
    poller = _make_poller()
    for text in ("pause", "/pause"):
        assert poller._parse_command(text) == Command(action="pause")
|
||||
|
||||
|
||||
def test_parse_resume_plain():
    """Plain resume (with or without slash) carries no force flag."""
    poller = _make_poller()
    for text in ("resume", "/resume"):
        assert poller._parse_command(text) == Command(action="resume")
|
||||
|
||||
|
||||
def test_parse_resume_force():
    """'resume force' sets value=1 — the canary force-resume acknowledgement."""
    poller = _make_poller()
    cmd = poller._parse_command("resume force")
    assert cmd is not None
    assert (cmd.action, cmd.value) == ("resume", 1)
|
||||
|
||||
|
||||
def test_parse_existing_commands_still_work():
    """Regression: pause/resume parsing must not disturb the legacy commands."""
    poller = _make_poller()
    expectations = {
        "stop": Command(action="stop"),
        "status": Command(action="status"),
        "ss": Command(action="ss"),
        "3": Command(action="set_interval", value=180),
    }
    for text, expected in expectations.items():
        assert poller._parse_command(text) == expected
|
||||
@@ -97,3 +97,59 @@ def test_attach_screenshots_unknown_keys_ignored() -> None:
|
||||
}))
|
||||
assert cfg.attach_screenshots.arm is False
|
||||
# Should not raise even with unknown key
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Commit 3: AlertBehaviorCfg (fire_on_phase_skip)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_alerts_default_fire_on_phase_skip_true() -> None:
    """No [options.alerts] table → the phase-skip backstop defaults to on."""
    parsed = Config._from_dict(_with_opts({}))
    assert parsed.alerts.fire_on_phase_skip is True
|
||||
|
||||
|
||||
def test_alerts_fire_on_phase_skip_can_be_disabled() -> None:
    """Explicit false in the alerts table turns the backstop off."""
    opts = {"alerts": {"fire_on_phase_skip": False}}
    parsed = Config._from_dict(_with_opts(opts))
    assert parsed.alerts.fire_on_phase_skip is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Commit 4: OperatingHoursCfg parsing + tz cache
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_operating_hours_default_disabled() -> None:
    """Absent [options.operating_hours] → disabled, NYSE tz, no cached zone."""
    parsed = Config._from_dict(_with_opts({}))
    oh = parsed.operating_hours
    assert oh.enabled is False
    assert oh.timezone == "America/New_York"
    assert oh._tz_cache is None
|
||||
|
||||
|
||||
def test_operating_hours_enabled_caches_tz() -> None:
    """Enabling the window eagerly resolves and caches the timezone object."""
    table = {
        "enabled": True,
        "timezone": "America/New_York",
        "weekdays": ["MON", "TUE", "WED", "THU", "FRI"],
        "start_hhmm": "09:30",
        "stop_hhmm": "16:00",
    }
    parsed = Config._from_dict(_with_opts({"operating_hours": table}))
    oh = parsed.operating_hours
    assert oh.enabled is True
    assert oh._tz_cache is not None
    assert str(oh._tz_cache) == "America/New_York"
|
||||
|
||||
|
||||
def test_operating_hours_invalid_tz_raises_valueerror() -> None:
    """An unresolvable timezone name must fail fast at parse time."""
    import pytest
    bad = {"operating_hours": {"enabled": True, "timezone": "Not/A_Zone"}}
    with pytest.raises(ValueError, match="operating_hours.timezone"):
        Config._from_dict(_with_opts(bad))
|
||||
|
||||
|
||||
def test_operating_hours_invalid_weekday_raises_valueerror() -> None:
    """Weekday tokens outside MON..SUN must fail fast at parse time."""
    import pytest
    bad = {"operating_hours": {"enabled": True, "weekdays": ["XYZ"]}}
    with pytest.raises(ValueError, match="weekdays"):
        Config._from_dict(_with_opts(bad))
|
||||
|
||||
@@ -10,6 +10,8 @@ Covers the six cases from the arm+prime notification plan:
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
from atm.main import _handle_tick
|
||||
from atm.notifier import Alert
|
||||
from atm.state_machine import State, StateMachine
|
||||
@@ -486,3 +488,82 @@ def test_save_annotated_frame_succeeds(tmp_path, monkeypatch):
|
||||
assert "BUY" in result.name
|
||||
assert len(written) == 1
|
||||
assert not any(e.get("event") == "snapshot_fail" for e in audit.events)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Commit 3: fire_on_phase_skip backstop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _cfg_with_flag(enabled: bool):
|
||||
return SimpleNamespace(alerts=SimpleNamespace(fire_on_phase_skip=enabled))
|
||||
|
||||
|
||||
def test_phase_skip_fire_when_flag_on():
    """ARMED_SELL jumping straight to light_red with the flag on emits the backstop."""
    fsm = StateMachine(lockout_s=240)
    notifier = FakeNotifier()
    audit = FakeAudit()
    cfg = _cfg_with_flag(True)

    # yellow from IDLE arms the SELL side
    _handle_tick(fsm, "yellow", 1.0, notifier, audit, first_accepted=False, cfg=cfg)
    assert fsm.state == State.ARMED_SELL
    notifier.alerts.clear()

    # light_red without the dark_red prime → phase_skip transition
    tr = _handle_tick(fsm, "light_red", 2.0, notifier, audit,
                      first_accepted=False, cfg=cfg)
    assert tr is not None and tr.reason == "phase_skip"

    backstops = [a for a in notifier.alerts if a.kind == "phase_skip_fire"]
    assert len(backstops) == 1
    assert backstops[0].direction == "SELL"
    assert "SELL" in backstops[0].title
|
||||
|
||||
|
||||
def test_phase_skip_no_fire_when_flag_off():
    """The identical arm → skip sequence stays silent with the toggle off."""
    fsm = StateMachine(lockout_s=240)
    notifier = FakeNotifier()
    audit = FakeAudit()
    cfg = _cfg_with_flag(False)

    _handle_tick(fsm, "yellow", 1.0, notifier, audit, first_accepted=False, cfg=cfg)
    notifier.alerts.clear()

    _handle_tick(fsm, "light_red", 2.0, notifier, audit, first_accepted=False, cfg=cfg)

    assert [a for a in notifier.alerts if a.kind == "phase_skip_fire"] == []
|
||||
|
||||
|
||||
def test_phase_skip_lockout_suppresses_spam():
    """A second phase_skip inside lockout_s must not alert again."""
    fsm = StateMachine(lockout_s=240)
    notifier = FakeNotifier()
    audit = FakeAudit()
    cfg = _cfg_with_flag(True)

    # Two arm → skip cycles, the second well within the 240s lockout.
    for arm_ts, skip_ts in ((1.0, 2.0), (60.0, 61.0)):
        _handle_tick(fsm, "yellow", arm_ts, notifier, audit,
                     first_accepted=False, cfg=cfg)
        _handle_tick(fsm, "light_red", skip_ts, notifier, audit,
                     first_accepted=False, cfg=cfg)

    backstops = [a for a in notifier.alerts if a.kind == "phase_skip_fire"]
    assert len(backstops) == 1, (
        f"expected 1 phase_skip_fire (lockout), got {len(backstops)}"
    )
|
||||
|
||||
|
||||
def test_state_machine_is_locked_and_record_fire_public_api():
    """is_locked / record_fire expose the private lockout bookkeeping faithfully."""
    fsm = StateMachine(lockout_s=100)
    assert fsm.is_locked("BUY", 0.0) is False  # nothing recorded yet

    fsm.record_fire("BUY", 10.0)
    assert fsm.is_locked("BUY", 50.0) is True    # 40s elapsed < 100s lockout
    assert fsm.is_locked("BUY", 150.0) is False  # lockout expired
    assert fsm.is_locked("SELL", 50.0) is False  # per-direction isolation
|
||||
|
||||
@@ -401,3 +401,581 @@ async def test_lifecycle_idle_armed_primed_autopoll_fire_stop(monkeypatch, tmp_p
|
||||
start_idx = scheduler_events.index("start:180")
|
||||
stop_idx = scheduler_events.index("stop")
|
||||
assert start_idx < stop_idx, "scheduler started after it stopped"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Commit 1 regression tests: _drain_cmd_queue MUST run unconditionally,
|
||||
# even when canary is paused or when detection is otherwise skipped.
|
||||
# Prior bug: `continue` past the drain loop caused commands to pile up.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_ctx_for_drain(cmd_queue, dispatched: list):
    """Build a minimal RunContext where the fakes only record what happens.

    `dispatched` is accepted for call-site symmetry but unused here — the
    tests patch _dispatch_command themselves and record through the patch.
    """
    import atm.main as _main

    class _FakeAudit:
        def __init__(self):
            self.events = []

        def log(self, e):
            self.events.append(e)

    class _FakeNotifier:
        def __init__(self):
            self.alerts = []

        def send(self, a):
            self.alerts.append(a)

    class _FakeCanary:
        def __init__(self, paused=True):
            self.is_paused = paused

    class _FakeScheduler:
        is_running = False
        interval_s = None

        def start(self, s):
            pass

        def stop(self):
            pass

    loop_state = _main._LoopState(start=0.0)
    return _main.RunContext(
        cfg=MagicMock(),
        capture=lambda: None,
        canary=_FakeCanary(paused=True),
        detector=MagicMock(),
        fsm=MagicMock(),
        notifier=_FakeNotifier(),
        audit=_FakeAudit(),
        detection_log=_FakeAudit(),
        scheduler=_FakeScheduler(),
        samples_dir=Path("."),
        fires_dir=Path("."),
        cmd_queue=cmd_queue,
        state=loop_state,
        levels_extractor_factory=lambda *a, **kw: None,
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_drain_works_when_canary_paused(monkeypatch):
    """Regression: commands must still dispatch while canary.is_paused.

    The old loop `continue`'d before the drain whenever the tick yielded
    res=None (canary paused), so queued commands were never serviced.
    """
    import atm.main as _main
    from atm.commands import Command

    q: asyncio.Queue = asyncio.Queue()
    for action in ("status", "ss"):
        await q.put(Command(action=action))

    dispatched: list = []

    async def _record(ctx, cmd):
        dispatched.append(cmd.action)

    monkeypatch.setattr(_main, "_dispatch_command", _record)

    ctx = _make_ctx_for_drain(q, dispatched)
    await _main._drain_cmd_queue(ctx)

    assert dispatched == ["status", "ss"]
    assert q.empty()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_drain_works_when_out_of_window(monkeypatch):
    """Drain still runs after a skipped tick (e.g. outside operating hours).

    The refactored loop calls _drain_cmd_queue unconditionally after every
    tick, whatever the _TickSyncResult contained.
    """
    import atm.main as _main
    from atm.commands import Command

    q: asyncio.Queue = asyncio.Queue()
    await q.put(Command(action="stop"))

    dispatched: list = []

    async def _record(ctx, cmd):
        dispatched.append(cmd.action)

    monkeypatch.setattr(_main, "_dispatch_command", _record)

    ctx = _make_ctx_for_drain(q, dispatched)
    # An empty result stands in for an out-of-window tick that produced nothing.
    await _main._handle_fsm_result(ctx, _main._TickSyncResult())
    await _main._drain_cmd_queue(ctx)

    assert dispatched == ["stop"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_drain_isolates_dispatch_exceptions(monkeypatch):
    """If one command raises, remaining commands still drain + warn alert sent."""
    import atm.main as _main
    from atm.commands import Command

    q: asyncio.Queue = asyncio.Queue()
    await q.put(Command(action="status"))
    await q.put(Command(action="ss"))

    attempts: list = []

    # Dispatcher stub: records every attempt, fails only on the first command.
    async def _fake_dispatch(ctx, cmd):
        attempts.append(cmd.action)
        if cmd.action == "status":
            raise RuntimeError("boom")

    monkeypatch.setattr(_main, "_dispatch_command", _fake_dispatch)

    ctx = _make_ctx_for_drain(q, attempts)
    await _main._drain_cmd_queue(ctx)

    # The failure of "status" did not stop "ss" from being dispatched.
    assert attempts == ["status", "ss"]
    # warn alert for the failed command
    warn_titles = [a.title for a in ctx.notifier.alerts if a.kind == "warn"]
    assert any("status" in t for t in warn_titles)
    # command_error audit event
    errs = [e for e in ctx.audit.events if e.get("event") == "command_error"]
    assert len(errs) == 1 and errs[0]["action"] == "status"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Commit 4: operating hours + LifecycleState transitions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from zoneinfo import ZoneInfo as _ZI # noqa: E402
|
||||
import datetime as _dt # noqa: E402
|
||||
|
||||
|
||||
def _oh_cfg(enabled=True, weekdays=("MON", "TUE", "WED", "THU", "FRI"),
|
||||
start="09:30", stop="16:00", tz="America/New_York"):
|
||||
"""Build a lightweight cfg-like object with operating_hours populated."""
|
||||
oh = types.SimpleNamespace(
|
||||
enabled=enabled,
|
||||
timezone=tz,
|
||||
weekdays=weekdays,
|
||||
start_hhmm=start,
|
||||
stop_hhmm=stop,
|
||||
_tz_cache=_ZI(tz) if enabled else None,
|
||||
)
|
||||
return types.SimpleNamespace(operating_hours=oh)
|
||||
|
||||
|
||||
def _fake_canary(paused=False):
|
||||
return types.SimpleNamespace(is_paused=paused)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "local_dt,expected",
    [
        # Monday 09:30 NY — exact open → active (None)
        (_dt.datetime(2026, 4, 20, 9, 30), None),
        # Monday 16:00 NY — exact close → inactive (>= stop)
        (_dt.datetime(2026, 4, 20, 16, 0), "out_of_window_hours"),
        # Monday 08:00 NY — before open
        (_dt.datetime(2026, 4, 20, 8, 0), "out_of_window_hours"),
        # Monday 12:00 NY — active
        (_dt.datetime(2026, 4, 20, 12, 0), None),
        # Saturday 12:00 NY — weekend
        (_dt.datetime(2026, 4, 18, 12, 0), "out_of_window_weekend"),
        # Sunday 23:00 NY — weekend
        (_dt.datetime(2026, 4, 19, 23, 0), "out_of_window_weekend"),
    ],
)
def test_operating_hours_skip_matrix(local_dt, expected):
    """Timezone-aware start/stop + weekday checks.

    `expected` is None when detection should run, otherwise the skip-reason
    string that _should_skip must return.
    """
    import atm.main as _main

    cfg = _oh_cfg()
    tz = cfg.operating_hours._tz_cache
    # Attach the exchange zone to the naive local time, then go to epoch.
    now_ts = local_dt.replace(tzinfo=tz).timestamp()

    lifecycle = _main.LifecycleState()
    result = _main._should_skip(now_ts, lifecycle, cfg, _fake_canary())
    assert result == expected
|
||||
|
||||
|
||||
def test_market_open_close_transitions_logged_once():
    """Crossing a boundary emits exactly one market_open / market_closed event."""
    import atm.main as _main

    audit_events = []
    alerts = []

    # Minimal audit/notifier doubles capturing what the transition logger emits.
    class _A:
        def log(self, e): audit_events.append(e)

    class _N:
        def send(self, a): alerts.append(a)

    cfg = _oh_cfg()
    tz = cfg.operating_hours._tz_cache
    lifecycle = _main.LifecycleState()
    canary = _fake_canary()

    # Prime as closed (before open, Monday 08:00)
    pre_open = _dt.datetime(2026, 4, 20, 8, 0, tzinfo=tz).timestamp()
    skip_pre = _main._should_skip(pre_open, lifecycle, cfg, canary)
    _main._maybe_log_transition(skip_pre, lifecycle, pre_open, _A(), _N())
    # First evaluation seeds state, no alert yet.
    assert lifecycle.last_window_state == "closed"
    assert alerts == []
    assert audit_events == []

    # Transition to open
    mid = _dt.datetime(2026, 4, 20, 12, 0, tzinfo=tz).timestamp()
    skip_mid = _main._should_skip(mid, lifecycle, cfg, canary)
    _main._maybe_log_transition(skip_mid, lifecycle, mid, _A(), _N())
    assert lifecycle.last_window_state == "open"
    assert len(alerts) == 1
    assert any(e.get("event") == "market_open" for e in audit_events)

    # Repeated open tick — no duplicate log
    alerts.clear()
    audit_events.clear()
    skip_mid2 = _main._should_skip(mid + 60, lifecycle, cfg, canary)
    _main._maybe_log_transition(skip_mid2, lifecycle, mid + 60, _A(), _N())
    assert alerts == []
    assert audit_events == []

    # Transition to close (17:00 is past the 16:00 stop)
    close = _dt.datetime(2026, 4, 20, 17, 0, tzinfo=tz).timestamp()
    skip_close = _main._should_skip(close, lifecycle, cfg, canary)
    _main._maybe_log_transition(skip_close, lifecycle, close, _A(), _N())
    assert lifecycle.last_window_state == "closed"
    assert any(e.get("event") == "market_closed" for e in audit_events)
|
||||
|
||||
|
||||
def test_market_transition_sends_notification():
    """market_open / market_closed transitions produce kind=status alerts."""
    import atm.main as _main

    alerts = []

    # Audit is a no-op here; only the notifier output is under test.
    class _A:
        def log(self, e): pass

    class _N:
        def send(self, a): alerts.append(a)

    cfg = _oh_cfg()
    tz = cfg.operating_hours._tz_cache
    # Start from a seeded "closed" state so the open transition fires.
    lifecycle = _main.LifecycleState(last_window_state="closed")

    mid = _dt.datetime(2026, 4, 20, 12, 0, tzinfo=tz).timestamp()
    # skip_reason=None means "in window" → closed→open transition expected.
    _main._maybe_log_transition(None, lifecycle, mid, _A(), _N())
    assert len(alerts) == 1
    assert alerts[0].kind == "status"
    # Accept either English "market" or the Romanian UI string "piața".
    assert "market" in alerts[0].title.lower() or "piața" in alerts[0].body.lower()
|
||||
|
||||
|
||||
def test_startup_in_window_suppresses_market_open():
    """R2 #20: first evaluation in-window just seeds state; no alert fires."""
    import atm.main as _main

    alerts = []
    events = []

    # Capture doubles for audit events and alerts.
    class _A:
        def log(self, e): events.append(e)

    class _N:
        def send(self, a): alerts.append(a)

    cfg = _oh_cfg()
    tz = cfg.operating_hours._tz_cache
    lifecycle = _main.LifecycleState()  # last_window_state is None

    in_window = _dt.datetime(2026, 4, 20, 12, 0, tzinfo=tz).timestamp()
    skip = _main._should_skip(in_window, lifecycle, cfg, _fake_canary())
    assert skip is None
    _main._maybe_log_transition(skip, lifecycle, in_window, _A(), _N())

    # Seeded silently: state recorded, but no market_open alert/event.
    assert lifecycle.last_window_state == "open"
    assert alerts == []
    assert not any(e.get("event") == "market_open" for e in events)

    # Two more ticks, still in-window → no spurious alert
    for _ in range(2):
        skip = _main._should_skip(in_window + 60, lifecycle, cfg, _fake_canary())
        _main._maybe_log_transition(skip, lifecycle, in_window + 60, _A(), _N())
    assert alerts == []
|
||||
|
||||
|
||||
def test_operating_hours_weekday_locale_independent():
    """R2 #22: weekday check must not depend on process locale (strftime('%a'))."""
    import locale as _locale
    import atm.main as _main

    cfg = _oh_cfg()
    tz = cfg.operating_hours._tz_cache
    # Saturday 12:00 NY
    sat = _dt.datetime(2026, 4, 18, 12, 0, tzinfo=tz).timestamp()

    # Remember the process locale so we can restore it after the test.
    original = _locale.setlocale(_locale.LC_TIME)
    try:
        # German abbreviates Saturday as "Sa" — a locale-sensitive strftime('%a')
        # comparison against "SAT" would wrongly treat it as a weekday.
        for loc in ("C", "de_DE.UTF-8"):
            try:
                _locale.setlocale(_locale.LC_TIME, loc)
            except _locale.Error:
                continue  # locale not installed → skip gracefully
            lifecycle = _main.LifecycleState()
            result = _main._should_skip(sat, lifecycle, cfg, _fake_canary())
            assert result == "out_of_window_weekend", (
                f"locale={loc} returned {result!r}"
            )
    finally:
        # Restore the original locale; fall back to "C" if it is gone.
        try:
            _locale.setlocale(_locale.LC_TIME, original)
        except _locale.Error:
            _locale.setlocale(_locale.LC_TIME, "C")
|
||||
|
||||
|
||||
def test_should_skip_user_paused_wins():
    """A user-initiated pause forces a skip even inside an open window."""
    import atm.main as _main

    cfg = _oh_cfg()
    zone = cfg.operating_hours._tz_cache
    # Monday noon NY — squarely inside the 09:30-16:00 operating window.
    in_window_ts = _dt.datetime(2026, 4, 20, 12, 0, tzinfo=zone).timestamp()
    paused_lifecycle = _main.LifecycleState(user_paused=True)
    reason = _main._should_skip(in_window_ts, paused_lifecycle, cfg, _fake_canary())
    assert reason == "user_paused"
|
||||
|
||||
|
||||
def test_should_skip_canary_drift_wins_over_window():
    """A drift-paused canary forces a skip even inside an open window."""
    import atm.main as _main

    cfg = _oh_cfg()
    zone = cfg.operating_hours._tz_cache
    # Monday noon NY — the window alone would let detection run.
    in_window_ts = _dt.datetime(2026, 4, 20, 12, 0, tzinfo=zone).timestamp()
    fresh_lifecycle = _main.LifecycleState()
    drifted_canary = _fake_canary(paused=True)
    reason = _main._should_skip(in_window_ts, fresh_lifecycle, cfg, drifted_canary)
    assert reason == "drift_paused"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Commit 5: /pause /resume dispatch (plan tests #11-15, #16, R2 #21)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _dispatch_ctx(canary=None, lifecycle=None, cfg=None):
    """Minimal RunContext for _dispatch_command unit tests.

    Any of canary/lifecycle/cfg may be injected; defaults are inert doubles.
    The returned ctx records audit events (ctx.audit.events) and alerts
    (ctx.notifier.alerts) for assertions.
    """
    import atm.main as _main

    # Recording audit double.
    class _A:
        def __init__(self): self.events = []
        def log(self, e): self.events.append(e)

    # Recording notifier double.
    class _N:
        def __init__(self): self.alerts = []
        def send(self, a): self.alerts.append(a)

    # Scheduler double whose running flag tracks start/stop calls.
    class _S:
        is_running = False
        interval_s = None
        def start(self, s): self.is_running = True
        def stop(self): self.is_running = False

    if canary is None:
        canary = types.SimpleNamespace(is_paused=False, resume=lambda: None)
    if lifecycle is None:
        lifecycle = _main.LifecycleState()
    if cfg is None:
        cfg = MagicMock()
        # Concrete values for the fields _dispatch_command actually reads.
        cfg.telegram.auto_poll_interval_s = 180
        cfg.operating_hours = types.SimpleNamespace(enabled=False, _tz_cache=None)

    state = _main._LoopState(start=0.0)
    ctx = _main.RunContext(
        cfg=cfg, capture=lambda: None, canary=canary,
        detector=MagicMock(), fsm=MagicMock(),
        notifier=_N(), audit=_A(), detection_log=_A(),
        scheduler=_S(), samples_dir=Path("."), fires_dir=Path("."),
        cmd_queue=MagicMock(), state=state,
        levels_extractor_factory=lambda *a, **kw: None,
        lifecycle=lifecycle,
    )
    return ctx
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_pause_command_sets_user_paused_and_skips_detection():
    """/pause sets lifecycle.user_paused and makes _should_skip report it."""
    import atm.main as _main
    from atm.commands import Command

    ctx = _dispatch_ctx()
    await _main._dispatch_command(ctx, Command(action="pause"))

    assert ctx.lifecycle.user_paused is True
    # When combined with _should_skip, we get user_paused
    assert _main._should_skip(0.0, ctx.lifecycle, ctx.cfg, ctx.canary) == "user_paused"
    # Audit + notif ("oprit" is the Romanian UI string for "stopped")
    assert any(e.get("event") == "user_paused" for e in ctx.audit.events)
    assert any(a.kind == "status" and "oprit" in a.title.lower() for a in ctx.notifier.alerts)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resume_clears_user_paused_and_canary_when_forced():
    """/resume with force (value=1) clears BOTH user_paused and the drift pause.

    The canary exposes ``is_paused`` as a property backed by mutable state so
    the effect of ``resume()`` is observable through the same attribute the
    dispatcher reads.

    (Cleanup: an earlier draft first built a SimpleNamespace canary plus a
    ``canary_state`` dict and immediately shadowed both with the property-based
    class below; that dead setup has been removed.)
    """
    import atm.main as _main
    from atm.commands import Command

    class _Canary:
        def __init__(self): self._p = True
        @property
        def is_paused(self): return self._p
        def resume(self): self._p = False

    canary = _Canary()

    ctx = _dispatch_ctx(canary=canary)
    ctx.lifecycle.user_paused = True

    await _main._dispatch_command(ctx, Command(action="resume", value=1))

    # Forced resume clears the user pause AND the canary drift pause.
    assert ctx.lifecycle.user_paused is False
    assert canary.is_paused is False
    # The audit trail records the resume with force=True.
    force_events = [e for e in ctx.audit.events if e.get("event") == "user_resumed"]
    assert force_events and force_events[0]["force"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resume_during_drift_keeps_canary_paused_without_force():
    """R2 #21: plain /resume during drift clears user_paused but NOT canary."""
    import atm.main as _main
    from atm.commands import Command

    # Canary whose is_paused property reflects calls to resume().
    class _Canary:
        def __init__(self): self._p = True
        @property
        def is_paused(self): return self._p
        def resume(self): self._p = False
    canary = _Canary()

    ctx = _dispatch_ctx(canary=canary)
    ctx.lifecycle.user_paused = True

    await _main._dispatch_command(ctx, Command(action="resume"))  # no force

    assert ctx.lifecycle.user_paused is False
    assert canary.is_paused is True  # still drift-paused
    # Message must mention drift
    status = [a for a in ctx.notifier.alerts if a.kind == "status"]
    assert status and ("drift" in (status[0].title + status[0].body).lower())

    # Now force — /resume with value=1 also clears the canary pause.
    ctx.notifier.alerts.clear()
    await _main._dispatch_command(ctx, Command(action="resume", value=1))
    assert canary.is_paused is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resume_out_of_window_responds_with_pending_message():
    """/resume while operating-hours window is closed → special body."""
    import atm.main as _main
    from atm.commands import Command

    cfg = _oh_cfg()
    tz = cfg.operating_hours._tz_cache
    lifecycle = _main.LifecycleState(user_paused=True, last_window_state="closed")
    canary = types.SimpleNamespace(is_paused=False, resume=lambda: None)

    ctx = _dispatch_ctx(canary=canary, lifecycle=lifecycle, cfg=cfg)

    # Pin time to Saturday by swapping atm.main's `time` module attribute;
    # restored in the finally block below (no monkeypatch fixture here).
    import atm.main as _mm
    real_time = _mm.time
    fake_ts = _dt.datetime(2026, 4, 18, 12, 0, tzinfo=tz).timestamp()
    class _FakeTime:
        def time(self): return fake_ts
        def monotonic(self): return 0.0
    _mm.time = _FakeTime()
    try:
        await _main._dispatch_command(ctx, Command(action="resume"))
    finally:
        _mm.time = real_time

    assert ctx.lifecycle.user_paused is False
    status = [a for a in ctx.notifier.alerts if a.kind == "status"]
    assert status
    combined = (status[0].title + status[0].body).lower()
    # Romanian UI strings: "închis" = closed, "piața" = market,
    # "ferestr" = stem of "fereastră" (window).
    assert "închis" in combined or "piața" in combined or "ferestr" in combined
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_status_command_reports_pause_reason():
    """/status body must mention pause reason + window state."""
    import atm.main as _main
    from atm.commands import Command

    ctx = _dispatch_ctx()
    ctx.lifecycle.user_paused = True
    # Stub detector.rolling for status
    ctx.detector.rolling = []
    # The status handler reads fsm.state.value; give it a concrete string.
    ctx.fsm.state = types.SimpleNamespace(value="IDLE")

    await _main._dispatch_command(ctx, Command(action="status"))

    status = [a for a in ctx.notifier.alerts if a.kind == "status"]
    assert status
    body = status[0].body
    # Accept either the bare reason or the "pauzat:<reason>" form.
    assert "user_paused" in body or "pauzat:user_paused" in body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_lifecycle_with_drift_then_resume_then_fire(monkeypatch, tmp_path):
    """E2E #16: drift paused → /resume force → dark_red/light_red produce FIRE alert.

    This test verifies the full command-driven lifecycle in isolation:
    - canary starts drift-paused, _should_skip returns drift_paused
    - /resume force clears canary + user_paused
    - subsequent detection produces SELL fire through normal FSM path
    """
    import atm.main as _main
    from atm.commands import Command

    # Canary with mutable pause state
    class _Canary:
        def __init__(self): self._p = True
        @property
        def is_paused(self): return self._p
        def resume(self): self._p = False

    canary = _Canary()
    cfg = MagicMock()
    # Concrete values for fields the dispatcher/skip logic read.
    cfg.telegram.auto_poll_interval_s = 180
    cfg.operating_hours = types.SimpleNamespace(enabled=False, _tz_cache=None)

    ctx = _dispatch_ctx(canary=canary, cfg=cfg)

    # 1. While drift-paused, _should_skip returns drift_paused
    assert _main._should_skip(0.0, ctx.lifecycle, cfg, canary) == "drift_paused"

    # 2. User issues /resume force
    await _main._dispatch_command(ctx, Command(action="resume", value=1))
    assert canary.is_paused is False
    assert _main._should_skip(0.0, ctx.lifecycle, cfg, canary) is None

    # 3. Feed a yellow→light_red sequence through _handle_tick (FSM path)
    from atm.state_machine import StateMachine, State
    fsm = StateMachine(lockout_s=60)

    # Fresh recording doubles for the tick handler (separate from ctx's).
    class _N:
        def __init__(self): self.alerts = []
        def send(self, a): self.alerts.append(a)

    class _A:
        def log(self, _e): pass

    notif = _N()
    audit = _A()
    cfg_mock = types.SimpleNamespace(alerts=types.SimpleNamespace(fire_on_phase_skip=True))

    # yellow (arm) → dark_red (prime) → light_red (fire).
    _main._handle_tick(fsm, "yellow", 1.0, notif, audit, first_accepted=False, cfg=cfg_mock)
    _main._handle_tick(fsm, "dark_red", 2.0, notif, audit, first_accepted=False, cfg=cfg_mock)
    tr = _main._handle_tick(fsm, "light_red", 3.0, notif, audit, first_accepted=False, cfg=cfg_mock)

    # FSM reached fire via normal path and reset to IDLE afterwards.
    assert tr is not None and tr.trigger == "SELL"
    assert fsm.state == State.IDLE
|
||||
|
||||
# ---------------------------------------------------------------------------
# File boundary (diff-viewer artifact): tests/test_validate.py — new file,
# 214 lines (@@ -0,0 +1,214 @@) — begins below.
# ---------------------------------------------------------------------------
|
||||
"""Tests for atm.validate — offline calibration validation.
|
||||
|
||||
Covers the 3 tests from plan section D':
|
||||
17. test_validate_calibration_pass
|
||||
18. test_validate_calibration_fail_reports_top_candidates
|
||||
19. test_validate_calibration_file_not_found
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from atm.config import (
|
||||
CanaryRegion,
|
||||
ColorSpec,
|
||||
Config,
|
||||
DiscordCfg,
|
||||
ROI,
|
||||
TelegramCfg,
|
||||
YAxisCalib,
|
||||
)
|
||||
from atm.detector import DetectionResult
|
||||
from atm.vision import ColorMatch
|
||||
|
||||
|
||||
def _make_config() -> Config:
    """Minimal Config with a palette large enough to support top-3 candidates.

    Eight named colors ensure a failing record can always list 3 candidates;
    ROIs, y-axis calibration, canary, and webhook/bot credentials are inert
    placeholders — validation never touches the network or a live window.
    """
    colors = {
        "turquoise": ColorSpec(rgb=(0, 200, 200), tolerance=30),
        "yellow": ColorSpec(rgb=(255, 255, 0), tolerance=30),
        "dark_green": ColorSpec(rgb=(0, 100, 0), tolerance=30),
        "dark_red": ColorSpec(rgb=(165, 42, 42), tolerance=30),
        "light_green": ColorSpec(rgb=(144, 238, 144), tolerance=30),
        "light_red": ColorSpec(rgb=(255, 182, 193), tolerance=30),
        "gray": ColorSpec(rgb=(128, 128, 128), tolerance=30),
        # Tighter tolerance for the near-black chart background.
        "background": ColorSpec(rgb=(18, 18, 18), tolerance=15),
    }
    return Config(
        window_title="test",
        dot_roi=ROI(x=0, y=0, w=100, h=100),
        chart_roi=ROI(x=0, y=0, w=100, h=100),
        colors=colors,
        y_axis=YAxisCalib(p1_y=0, p1_price=100.0, p2_y=100, p2_price=0.0),
        canary=CanaryRegion(
            roi=ROI(x=0, y=0, w=10, h=10),
            baseline_phash="0" * 64,
        ),
        discord=DiscordCfg(webhook_url="http://localhost/fake"),
        telegram=TelegramCfg(bot_token="fake_token", chat_id="123"),
        debounce_depth=1,
    )
|
||||
|
||||
|
||||
def _write_labels(tmp_path: Path, entries: list[dict]) -> Path:
|
||||
f = tmp_path / "labels.json"
|
||||
f.write_text(json.dumps(entries), encoding="utf-8")
|
||||
return f
|
||||
|
||||
|
||||
def _write_blank_png(tmp_path: Path, name: str) -> Path:
    """Write a trivially-valid 10x10 BGR image so cv2.imread returns non-None."""
    import cv2

    out_path = tmp_path / name
    blank = np.zeros((10, 10, 3), dtype=np.uint8)
    cv2.imwrite(str(out_path), blank)
    return out_path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 17: PASS path — mocked Detector.step returns expected color
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_validate_calibration_pass(monkeypatch, tmp_path):
    """PASS path: detector agrees with the label → report passes, CLI exits 0."""
    from atm import validate as validate_mod

    img_path = _write_blank_png(tmp_path, "yellow_sample.png")
    labels = _write_labels(
        tmp_path,
        [{"path": str(img_path), "expected": "yellow", "note": "test"}],
    )

    # Detector stub: always classifies the sample as yellow with high confidence.
    def fake_step(self, ts, frame=None):
        return DetectionResult(
            ts=ts,
            window_found=True,
            dot_found=True,
            rgb=(250, 250, 5),
            match=ColorMatch(name="yellow", distance=6.0, confidence=0.94),
            accepted=True,
            color="yellow",
        )

    monkeypatch.setattr("atm.detector.Detector.step", fake_step)

    report = validate_mod.validate_calibration(labels, _make_config())

    # Aggregate counters and the per-record verdict all reflect the match.
    assert report.total == 1
    assert report.passed == 1
    assert report.failed == 0
    assert report.all_pass is True
    rec = report.records[0]
    assert rec.passed is True
    assert rec.detected == "yellow"
    assert rec.expected == "yellow"
    assert "[PASS]" in report.render()

    # CLI wiring: exit 0
    import atm.main as _main

    class _Args:
        label_file = labels

    # Bypass real config loading so the CLI path uses the same test config.
    monkeypatch.setattr("atm.config.Config.load_current", classmethod(lambda cls, d: _make_config()))
    with pytest.raises(SystemExit) as exc_info:
        _main._cmd_validate_calibration(_Args())
    assert exc_info.value.code == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 18: FAIL path — Detector returns wrong color; report lists top 3
|
||||
# candidates and a SUGGESTIONS line with RGB distance.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_validate_calibration_fail_reports_top_candidates(monkeypatch, tmp_path):
    """FAIL path: wrong detection → report lists top-3 candidates + suggestions."""
    from atm import validate as validate_mod

    img_path = _write_blank_png(tmp_path, "dark_red_sample.png")
    labels = _write_labels(
        tmp_path,
        [{"path": str(img_path), "expected": "dark_red", "note": "missed dark_red"}],
    )

    # Observed RGB closer to gray than dark_red (like the real 2026-04-17 miss).
    def fake_step(self, ts, frame=None):
        return DetectionResult(
            ts=ts,
            window_found=True,
            dot_found=True,
            rgb=(135, 62, 67),
            match=ColorMatch(name="gray", distance=45.0, confidence=0.12),
            accepted=True,
            color="gray",
        )

    monkeypatch.setattr("atm.detector.Detector.step", fake_step)

    report = validate_mod.validate_calibration(labels, _make_config())

    assert report.total == 1
    assert report.failed == 1
    assert report.all_pass is False

    rec = report.records[0]
    assert rec.passed is False
    assert rec.detected == "gray"
    assert rec.expected == "dark_red"
    # Top 3 candidates populated (name, score) sorted by RGB distance.
    assert len(rec.top3) == 3
    names = [n for n, _ in rec.top3]
    # dark_red should appear in top candidates since observed RGB(135,62,67)
    # is reasonably close to dark_red(165,42,42).
    assert "dark_red" in names

    rendered = report.render()
    assert "[FAIL]" in rendered
    assert "Top 3 candidates:" in rendered
    assert "SUGGESTIONS:" in rendered
    # The suggestion must mention the expected color's RGB and the measured distance.
    assert "dark_red" in rendered
    assert "(165, 42, 42)" in rendered

    # CLI wiring: exit 1
    import atm.main as _main

    class _Args:
        label_file = labels

    # Bypass real config loading so the CLI path uses the same test config.
    monkeypatch.setattr("atm.config.Config.load_current", classmethod(lambda cls, d: _make_config()))
    with pytest.raises(SystemExit) as exc_info:
        _main._cmd_validate_calibration(_Args())
    assert exc_info.value.code == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 19: missing label file — clean error, non-zero exit, no stack trace
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_validate_calibration_file_not_found(monkeypatch, tmp_path, capsys):
    """Missing label file → clean error, non-zero exit, no stack trace."""
    from atm import validate as validate_mod

    missing = tmp_path / "nope.json"

    # Library-level: raises ValidationError (not bare FileNotFoundError).
    with pytest.raises(validate_mod.ValidationError) as exc_info:
        validate_mod.validate_calibration(missing, _make_config())
    assert "not found" in str(exc_info.value).lower()

    # CLI-level: graceful sys.exit with non-zero code, message on stderr.
    import atm.main as _main

    class _Args:
        label_file = missing

    # Bypass real config loading; the failure under test is the label file.
    monkeypatch.setattr("atm.config.Config.load_current", classmethod(lambda cls, d: _make_config()))
    with pytest.raises(SystemExit) as exc_info:
        _main._cmd_validate_calibration(_Args())
    assert exc_info.value.code != 0
    err = capsys.readouterr().err
    assert "not found" in err.lower()
    # Ensure no python traceback leaked through.
    assert "Traceback" not in err
|
||||
# (end of pasted diff — web-UI footer removed)