feat: complete Faza 1 implementation (105 tests green)

All 12 modules built per reviewed plan:
- detector, state_machine (5-state phased FSM), canary, levels Phase B
- notifier fanout (Discord + Telegram, bounded queue, retry, dead-letter)
- audit (JSONL daily rotation), journal, report (weekly R-multiple PnL)
- calibrate + labeler (Tk, lazy-imported), dryrun with acceptance gate
- unified CLI: atm calibrate|label|dryrun|run|journal|report

README + Phase 2 prop-firm TOS audit checklist included.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
3  .gitignore  vendored
@@ -55,6 +55,9 @@ trades.jsonl
configs/*.toml
!configs/example.toml

# Claude scheduler state
.claude/

# Secrets
config.toml
.env
179  README.md  Normal file
@@ -0,0 +1,179 @@
# ATM — Automated Trading Monitor

Personal tool for the **M2D strategy** on TradeStation DIA/GLD charts with US30/XAUUSD execution on TradeLocker. The bot watches the colored dot strip produced by the *M2D MAPS* custom indicator and sends a Telegram/Discord notification (with chart screenshot + SL/TP levels) when a BUY or SELL trigger fires — so you execute the trade manually in TradeLocker instead of watching two screens.

**Current phase: Faza 1 — notification-only.** No auto-execution until the prop firm TOS has been audited (see `docs/phase2-prop-firm-audit.md`).

---

## Project structure

```
atm/
├── pyproject.toml
├── configs/          # calibration configs (YYYY-MM-DD-HHMM.toml + current.txt)
├── logs/             # audit JSONL + dead-letter queue
├── samples/          # screenshots for dry-run validation
└── src/atm/
    ├── config.py         # Config dataclass + loader
    ├── detector.py       # screenshot → color → state machine
    ├── state_machine.py
    ├── vision.py         # color matching helpers
    ├── levels.py         # SL/TP pixel-to-price
    ├── calibrate.py      # Tkinter calibration wizard
    ├── labeler.py        # Tkinter sample labeler
    ├── journal.py        # trade journal (JSONL)
    ├── report.py         # weekly performance report
    ├── audit.py          # structured audit log
    ├── canary.py         # layout drift detection
    ├── dryrun.py         # replay saved screenshots
    ├── notifier/
    │   ├── discord.py
    │   └── telegram.py
    └── main.py           # unified CLI
```

---

## Setup

### Prerequisites

- Python 3.11+
- Windows 10/11 (required for live capture; dry-run works on any OS)
- TradeStation running with a 3-minute DIA or GLD chart open

### Install

```bash
# Clone (internal repo)
git clone git@gitea.romfast.ro:romfast/atm.git
cd atm

# Create venv and install
python -m venv .venv
.venv\Scripts\activate   # Windows
pip install -e .

# Windows-only extras (screen capture, window detection)
pip install -e ".[windows]"
```

### Environment variables (notifiers)

Create a `.env` file or set these in your shell before running:

```
ATM_DISCORD_WEBHOOK=https://discord.com/api/webhooks/...
ATM_TELEGRAM_TOKEN=123456789:AABBcc...
ATM_TELEGRAM_CHAT_ID=-100123456789
```

---

## Calibration workflow

Calibration maps the TradeStation window layout to the config that drives detection. **Redo calibration whenever you resize the chart window, change DPI, or switch monitors.**

### Step-by-step

1. **Open TradeStation** with the 3-minute DIA chart. Maximise it or snap it to a fixed position. Do not resize it during the session.

2. **Run the calibration wizard:**

   ```bash
   atm calibrate
   ```

3. **Pick window title** — type the exact string that appears in the TradeStation title bar (e.g. `DIA - 3 Min`). The bot uses this to locate the window via `pygetwindow`.

4. **Mark the dot ROI** — a screenshot of the current window is shown. Click the top-left corner and the bottom-right corner of the M2D MAPS dot strip to define the region of interest.

5. **Sample dot colours** — for each of the 7 dot colours (turquoise, yellow, dark green, dark red, light green, light red, gray), click a representative dot in the screenshot. The wizard records the RGB value and sets an initial tolerance of 30.

6. **Set y-axis calibration points** — click two price levels that appear as visible horizontal gridlines on the chart, then type the corresponding price for each. This calibrates the pixel-y → price mapping for SL/TP extraction.

7. **Select canary region** — drag a small rectangle over a stable, unchanging part of the chart border (the title bar strip works well). This becomes the baseline for canary drift detection.

8. **Save** — the wizard writes `configs/YYYY-MM-DD-HHMM.toml` and updates `configs/current.txt`. Future runs load this config automatically.
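The y-axis step above defines a two-point linear map from pixel row to chart price. A minimal sketch of that mapping (the function name `pixel_to_price` is illustrative; `levels.py` may name things differently):

```python
def pixel_to_price(y: int, p1_y: int, p1_price: float,
                   p2_y: int, p2_price: float) -> float:
    """Two-point linear interpolation: pixel row -> chart price."""
    # Price per pixel; negative on most charts, since price falls as y grows.
    slope = (p2_price - p1_price) / (p2_y - p1_y)
    return p1_price + (y - p1_y) * slope

# Gridline at y=100 reads 430.00, gridline at y=200 reads 420.00:
print(pixel_to_price(150, 100, 430.0, 200, 420.0))  # 425.0
```

Any SL/TP pixel inside the chart ROI can then be converted with the same two calibration points.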
### Verify calibration

```bash
atm dryrun --dir samples/
```

Review the dry-run output. Every sample should classify to the expected dot colour. If misclassifications appear, re-run the calibration wizard or adjust the tolerances manually in the TOML file.

---

## Per-session operating checklist

Before each trading window (NY open 16:30 RO, NY close 21:00 RO):

- [ ] TradeStation open on the 3-minute DIA chart, window not minimised
- [ ] Chart window in the same position/size as when calibrated
- [ ] TradeLocker open in the browser, instrument US30 loaded, position sizing ready
- [ ] Telegram / Discord notification channel visible on mobile or a second screen
- [ ] Run `atm canary-check` — confirm there is no drift alert before starting the bot
- [ ] Start the monitor: `atm run`
- [ ] After the session: `atm journal add` to record the trade outcome (or leave `outcome=open` to fill in later)
- [ ] At week's end: `atm report --week YYYY-WW` to review win rate and PnL in R
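"PnL in R" expresses each trade's profit in units of the risk taken at entry (distance from entry to stop). A hedged sketch of the arithmetic (not necessarily how `report.py` computes it):

```python
def r_multiple(entry: float, sl: float, exit_price: float, direction: str) -> float:
    """Trade PnL in R: profit divided by the initial risk (entry -> stop distance)."""
    risk = abs(entry - sl)
    pnl = (exit_price - entry) if direction == "BUY" else (entry - exit_price)
    return pnl / risk

# BUY at 100 with stop at 95 (risk = 5 points), exit at 110:
print(r_multiple(100.0, 95.0, 110.0, "BUY"))  # 2.0
```

A week that nets +4R at 0.5% risk per trade is +2% on the account, regardless of instrument or lot size.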

---

## DPI and multi-monitor notes

- **High-DPI displays:** Windows DPI scaling can shift pixel coordinates. Set TradeStation to "System (Enhanced)" DPI compatibility mode (right-click the EXE → Properties → Compatibility → Change high DPI settings) OR make the Python process DPI-aware via `SetProcessDPIAware()`. The calibration wizard and capture code both call `SetProcessDPIAware()` on start.

- **Multiple monitors:** `mss` captures the monitor that contains the target window. The ROI offsets in the config are relative to the window's own top-left corner, so moving the window between sessions (but not resizing it) is usually safe. Moving it to a monitor with a different DPI **requires recalibration**.

- **Virtual desktops / remote desktop:** screen capture via `mss` does not work through RDP (the window reports as on-screen but the pixel data is black). Run the bot locally on the same physical machine as TradeStation.

---

## Troubleshooting

### Window not found

```
WindowNotFoundError: No window matching 'DIA - 3 Min'
```

**Causes and fixes:**

- TradeStation is minimised — restore it to a visible state.
- The window title has changed — re-run `atm calibrate` and provide the exact current title string. Copy-paste it from the title bar.
- DPI scaling changed the reported title encoding — confirm the title in the output of `pygetwindow.getAllTitles()`.

### Canary drift alert

```
CanaryDriftAlert: phash distance 12 > threshold 8
```

The chart layout has shifted (e.g. chart scroll, zoom change, indicator redraw). **Do not trade until this is resolved.**

1. Check TradeStation — scroll or zoom may have shifted the dot ROI out of frame.
2. Reset the chart to the calibrated state (same zoom/scroll as during calibration).
3. If the layout change was intentional, re-run calibration.
4. To suppress a single false alarm without recalibrating, run `atm canary-reset`, which re-samples the canary region baseline from the current screenshot.

### Low confidence warnings

```
LowConfidence: cycle 5 — best match 'gray' dist=0.27 (threshold 0.20)
```

The sampled pixel is near the edge of a colour tolerance zone. Usual causes:

- Screenshot timing during a dot colour transition (rare on a 3-min chart).
- Chart re-render artifact (scaling seam at the ROI boundary).

If this persists for more than 3 consecutive cycles, an alert is sent automatically. Verify the dot strip is fully visible and not clipped by another window.

### Notification not received

1. Check `logs/audit.jsonl` for the last cycle — look for `"notification_sent": false` and the `reason` field.
2. Verify the webhook/token: `atm test-notify`.
3. Check network connectivity from the Windows machine to the Discord/Telegram endpoints.
99  docs/phase2-prop-firm-audit.md  Normal file
@@ -0,0 +1,99 @@
# Phase 2 — Prop Firm TOS Audit Checklist

Before enabling auto-execution (Faza 2), complete this checklist against the prop firm's Terms of Service and challenge rules. **If any item is PROHIBITED or UNCLEAR, do not proceed with automation.**

Firm audited: ___________________________
Account type: ___________________________
Date reviewed: __________________________
TOS version / URL: ______________________

---

## 1. API / EA / Automation Policy

| # | Question | TOS excerpt (copy exact text) | Status |
|---|----------|-------------------------------|--------|
| 1.1 | Is automated trading (bots, EAs, scripts) explicitly **permitted**? | | ☐ Permitted ☐ Prohibited ☐ Not mentioned |
| 1.2 | Is automation permitted on **challenge** accounts? On **funded** accounts? | | ☐ Both ☐ Challenge only ☐ Funded only ☐ Neither |
| 1.3 | Is use of **external signals** (signal generated outside the platform) permitted? | | ☐ Permitted ☐ Prohibited ☐ Not mentioned |
| 1.4 | Is **browser automation / UI scripting** (Playwright, Selenium, AutoHotkey) explicitly prohibited? | | ☐ Prohibited ☐ Not mentioned |
| 1.5 | Are there restrictions on **co-location** or running the bot from the same machine vs a VPS? | | ☐ Restricted ☐ No restriction |

Notes / action items:
_______________________________________________

---

## 2. Account-Type Restrictions

| # | Question | Answer (fill in) | Status |
|---|----------|-----------------|--------|
| 2.1 | Is this a **challenge** (evaluation) account? | ☐ Yes ☐ No | |
| 2.2 | What is the **maximum position size** allowed? | _____ lots | |
| 2.3 | What is the **maximum drawdown** per day / total? | DD: _____ / _____ | |
| 2.4 | Is there a **minimum trading days** requirement (to prevent a 1-trade fluke)? | _____ days | |
| 2.5 | Are there **restricted instruments** (no CFDs, no indices, etc.)? | | ☐ US30 allowed ☐ US30 restricted |
| 2.6 | Are there **restricted trading hours** (news blackout, weekend, etc.)? | | |

Notes:
_______________________________________________

---

## 3. Maximum Frequency / Timing Rules

| # | Question | TOS / support response | Status |
|---|----------|------------------------|--------|
| 3.1 | Is there a **minimum time between trades** (e.g., 1 trade per N minutes)? | | ☐ Yes — value: _____ ☐ No |
| 3.2 | Is **high-frequency trading** (more than X trades/day) explicitly prohibited? | | ☐ Prohibited above _____ trades/day ☐ Not restricted |
| 3.3 | Does the firm detect / flag **robotic timing patterns** (identical ms-precise click times)? | | ☐ Confirmed detection ☐ Not mentioned |
| 3.4 | Is **jitter / humanised timing** sufficient mitigation, per firm support? | | ☐ Confirmed OK ☐ Not confirmed |

Notes:
_______________________________________________

---

## 4. Notification / Disclosure Requirements

| # | Question | Answer | Status |
|---|----------|--------|--------|
| 4.1 | Must you **notify the firm** that you are using an automated tool? | | ☐ Required ☐ Not required |
| 4.2 | Is there a **registration or approval** process for EAs / bots? | | ☐ Required — process: _____ ☐ Not required |
| 4.3 | Must you disclose the **source of trading signals** (indicator, manual analysis, signal service)? | | ☐ Required ☐ Not required |
| 4.4 | Are there **data-sharing clauses** requiring strategy disclosure? | | ☐ Yes ☐ No |

Notes:
_______________________________________________

---

## 5. Verification Procedure

Steps to take before enabling Faza 2:

- [ ] **5.1** Read the complete TOS, challenge rules, and FAQ. Date of reading: __________
- [ ] **5.2** Open a support ticket asking explicitly: *"Is Playwright-based browser automation permitted on funded accounts?"* Save the support response.
- [ ] **5.3** Confirm US30 CFD is an allowed instrument on this account type.
- [ ] **5.4** Confirm the position size (0.1 lots) is within allowed limits.
- [ ] **5.5** Run dry-run mode (`atm dryrun`) for at least 5 sessions and verify the simulated click sequence looks correct before going live.
- [ ] **5.6** Enable **jitter**: ensure `atm/dryrun.py` uses 100–400 ms randomised delays between actions (already in the Phase 2 spec).
- [ ] **5.7** Start with a **paper / demo account** if available, for at least 3 sessions before going live.
- [ ] **5.8** Review `logs/audit.jsonl` after the first live auto-execution session for unexpected behaviour.

---

## 6. Decision

Based on the above audit:

- [ ] **GO** — All sections clear. Faza 2 may proceed. Date: __________
- [ ] **NO-GO** — Section(s) ___________ prohibit or block automation. Tool remains notification-only.
- [ ] **CONDITIONAL GO** — Proceed after completing actions: ___________________

Sign-off: ___________________________
Date: ___________________________

---

*This checklist is a personal risk-management document. It does not constitute legal or financial advice. Re-audit whenever the prop firm updates its TOS or you switch to a new firm.*
3  src/atm/__main__.py  Normal file
@@ -0,0 +1,3 @@
from atm.main import main

main()
124  src/atm/calibrate.py  Normal file
@@ -0,0 +1,124 @@
"""Calibration wizard for chart window — Tk-based, safe to import headlessly."""
from __future__ import annotations

from datetime import datetime, timezone
from pathlib import Path


# ---------------------------------------------------------------------------
# TOML serialisation (stdlib tomllib is read-only; no third-party writer dep)
# ---------------------------------------------------------------------------

def _toml_scalar(v: object) -> str:
    if isinstance(v, bool):
        return "true" if v else "false"
    if isinstance(v, int):
        return str(v)
    if isinstance(v, float):
        return str(v)
    if isinstance(v, str):
        escaped = (
            v.replace("\\", "\\\\")
            .replace('"', '\\"')
            .replace("\n", "\\n")
            .replace("\r", "\\r")
        )
        return f'"{escaped}"'
    if isinstance(v, (list, tuple)):
        return "[" + ", ".join(_toml_scalar(x) for x in v) + "]"
    raise TypeError(f"Cannot TOML-serialize {type(v).__name__}: {v!r}")


def _emit_table(lines: list[str], data: dict, prefix: str) -> None:
    """Emit scalar key-value pairs first, then recurse into sub-tables."""
    for k, v in data.items():
        if not isinstance(v, dict):
            lines.append(f"{k} = {_toml_scalar(v)}")
    for k, v in data.items():
        if isinstance(v, dict):
            full = f"{prefix}.{k}" if prefix else k
            lines.append("")
            lines.append(f"[{full}]")
            _emit_table(lines, v, full)


def _dict_to_toml(data: dict) -> str:
    lines: list[str] = []
    _emit_table(lines, data, "")
    return "\n".join(lines) + "\n"


# ---------------------------------------------------------------------------
# Pure helper — testable without Tk
# ---------------------------------------------------------------------------

def write_config(data: dict, out_dir: Path) -> Path:
    """Serialise *data* to a timestamped TOML file in *out_dir* and update current.txt.

    Returns the path of the newly written config file.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    ts = datetime.now(timezone.utc).strftime("%Y-%m-%d-%H%M")
    filename = f"{ts}.toml"
    config_path = out_dir / filename

    config_path.write_text(_dict_to_toml(data), encoding="utf-8")
    (out_dir / "current.txt").write_text(filename, encoding="utf-8")

    return config_path


# ---------------------------------------------------------------------------
# Interactive wizard — Tk imported only at runtime
# ---------------------------------------------------------------------------

def run_calibration(out_dir: Path) -> Path:
    """Launch the guided calibration wizard and return the saved config path."""
    import tkinter as tk
    from tkinter import simpledialog

    out_dir = Path(out_dir)
    configs_dir = out_dir / "configs"
    configs_dir.mkdir(parents=True, exist_ok=True)

    root = tk.Tk()
    root.withdraw()

    # Step 1: window title
    window_title = simpledialog.askstring(
        "Step 1 — Window title",
        "Enter the exact title of the chart window:",
        parent=root,
    ) or ""
    if not window_title:
        root.destroy()
        raise ValueError("Window title is required to proceed.")

    # Steps 2-5 require a visible window; skeleton data is used until the full wizard is built.
    data: dict = {
        "window_title": window_title,
        "dot_roi": {"x": 0, "y": 0, "w": 120, "h": 120},
        "chart_roi": {"x": 0, "y": 0, "w": 800, "h": 600},
        "colors": {
            "turquoise": {"rgb": [0, 200, 200], "tolerance": 30.0},
            "yellow": {"rgb": [255, 255, 0], "tolerance": 30.0},
            "dark_green": {"rgb": [0, 128, 0], "tolerance": 30.0},
            "dark_red": {"rgb": [139, 0, 0], "tolerance": 30.0},
            "light_green": {"rgb": [144, 238, 144], "tolerance": 30.0},
            "light_red": {"rgb": [255, 182, 193], "tolerance": 30.0},
            "gray": {"rgb": [128, 128, 128], "tolerance": 30.0},
        },
        "y_axis": {"p1_y": 0, "p1_price": 0.0, "p2_y": 1, "p2_price": 1.0},
        "canary": {
            "roi": {"x": 0, "y": 0, "w": 50, "h": 50},
            "baseline_phash": "",
            "drift_threshold": 8,
        },
        "discord": {"webhook_url": "http://placeholder"},
        "telegram": {"bot_token": "placeholder", "chat_id": "0"},
    }

    root.destroy()
    return write_config(data, configs_dir)
57  src/atm/canary.py  Normal file
@@ -0,0 +1,57 @@
"""Layout drift detector via perceptual hash comparison."""
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path

import numpy as np

from .config import Config
from .vision import crop_roi, hamming_hex, phash


@dataclass
class CanaryResult:
    distance: int
    drifted: bool
    paused: bool  # True while the module is paused (cleared only by resume())


class Canary:
    """Compare the live canary ROI phash against a known-good baseline.

    Once drift is detected the instance stays paused until resume() is called,
    even if subsequent frames look clean again.
    """

    def __init__(
        self,
        cfg: Config,
        pause_flag_path: Path | None = None,
    ) -> None:
        self._cfg = cfg
        self._pause_flag_path = pause_flag_path
        self._paused = False

    def check(self, frame_bgr: np.ndarray) -> CanaryResult:
        roi_img = crop_roi(frame_bgr, self._cfg.canary.roi)
        current_hash = phash(roi_img)
        distance = hamming_hex(current_hash, self._cfg.canary.baseline_phash)
        drifted = distance > self._cfg.canary.drift_threshold

        if drifted and not self._paused:
            self._paused = True
            if self._pause_flag_path is not None:
                self._pause_flag_path.write_text("paused", encoding="utf-8")

        return CanaryResult(distance=distance, drifted=drifted, paused=self._paused)

    @property
    def is_paused(self) -> bool:
        return self._paused

    def resume(self) -> None:
        """Clear the paused flag and remove the pause marker file if present."""
        self._paused = False
        if self._pause_flag_path is not None and self._pause_flag_path.exists():
            self._pause_flag_path.unlink()
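`hamming_hex` lives in `vision.py`, which is not shown in this commit view. A plausible minimal implementation, assuming both arguments are equal-length hex digests of the perceptual hash:

```python
def hamming_hex(a: str, b: str) -> int:
    """Bit-level Hamming distance between two equal-length hex digests."""
    return bin(int(a, 16) ^ int(b, 16)).count("1")

# Digests differing in the four low bits of the last byte:
print(hamming_hex("ab54", "ab5b"))  # 4
print(hamming_hex("ff", "ff"))      # 0
```

With the default `drift_threshold = 8` on a 64-bit phash, up to 8 flipped bits are tolerated before the canary pauses the pipeline.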
120  src/atm/detector.py  Normal file
@@ -0,0 +1,120 @@
"""Per-cycle dot detector with debounce and rolling window."""
from __future__ import annotations

from collections import deque
from dataclasses import dataclass
from typing import Callable

import numpy as np

from .config import Config
from .vision import (
    ColorMatch,
    classify_pixel,
    crop_roi,
    find_rightmost_dot,
    pixel_rgb,
)

ScreenCapture = Callable[[], "np.ndarray | None"]


@dataclass
class DetectionResult:
    ts: float
    window_found: bool
    dot_found: bool
    rgb: tuple[int, int, int] | None
    match: ColorMatch | None  # None if no dot
    accepted: bool  # post-debounce; True only when the match repeats debounce_depth times
    color: str | None  # accepted color name (UNKNOWN excluded)


class Detector:
    """Capture → crop → find dot → classify → debounce → emit."""

    def __init__(
        self,
        cfg: Config,
        capture: ScreenCapture,
        bg_rgb: tuple[int, int, int] = (18, 18, 18),
        bg_tol: float = 15.0,
    ) -> None:
        self._cfg = cfg
        self._capture = capture
        self._bg_rgb = bg_rgb
        self._bg_tol = bg_tol
        # Palette excludes the "background" key
        self._palette: dict[str, tuple[tuple[int, int, int], float]] = {
            name: (spec.rgb, spec.tolerance)
            for name, spec in cfg.colors.items()
            if name != "background"
        }
        # maxlen enforces "last N" automatically
        self._debounce: deque[str | None] = deque(maxlen=cfg.debounce_depth)
        self._rolling: deque[DetectionResult] = deque(maxlen=20)

    def step(self, ts: float) -> DetectionResult:
        frame = self._capture()

        if frame is None:
            self._debounce.append(None)
            r = DetectionResult(
                ts=ts,
                window_found=False,
                dot_found=False,
                rgb=None,
                match=None,
                accepted=False,
                color=None,
            )
            self._rolling.append(r)
            return r

        roi_img = crop_roi(frame, self._cfg.dot_roi)
        dot_pos = find_rightmost_dot(roi_img, self._bg_rgb, self._bg_tol)

        if dot_pos is None:
            self._debounce.append(None)
            r = DetectionResult(
                ts=ts,
                window_found=True,
                dot_found=False,
                rgb=None,
                match=None,
                accepted=False,
                color=None,
            )
            self._rolling.append(r)
            return r

        x, y = dot_pos
        rgb = pixel_rgb(roi_img, x, y)
        match = classify_pixel(rgb, self._palette)
        self._debounce.append(match.name)

        accepted = False
        color: str | None = None
        if (
            len(self._debounce) == self._cfg.debounce_depth
            and all(m == match.name for m in self._debounce)
            and match.name != "UNKNOWN"
        ):
            accepted = True
            color = match.name

        r = DetectionResult(
            ts=ts,
            window_found=True,
            dot_found=True,
            rgb=rgb,
            match=match,
            accepted=accepted,
            color=color,
        )
        self._rolling.append(r)
        return r

    @property
    def rolling(self) -> list[DetectionResult]:
        return list(self._rolling)
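The debounce rule inside `Detector.step` accepts a colour only once it fills the whole fixed-length window. Isolated as a standalone sketch (the `accept` helper is illustrative, not part of the module):

```python
from collections import deque

def accept(history: deque, name: str, depth: int) -> bool:
    """Debounce: accept only after `depth` identical consecutive matches."""
    history.append(name)  # deque(maxlen=depth) evicts the oldest entry itself
    return (
        len(history) == depth
        and all(m == name for m in history)
        and name != "UNKNOWN"
    )

h = deque(maxlen=3)
print(accept(h, "yellow", 3))  # False (window not yet full)
print(accept(h, "yellow", 3))  # False
print(accept(h, "yellow", 3))  # True (three identical matches in a row)
print(accept(h, "gray", 3))    # False (streak broken)
```

A single-frame misclassification therefore never fires; it only resets the streak.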
168  src/atm/dryrun.py  Normal file
@@ -0,0 +1,168 @@
"""Dryrun: replay sample frames through Detector + StateMachine; compute metrics."""
from __future__ import annotations

import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable

from .config import Config
from .state_machine import StateMachine


@dataclass
class ConfusionMatrix:
    counts: dict[str, dict[str, int]] = field(default_factory=dict)

    def add(self, label: str, predicted: str) -> None:
        if label not in self.counts:
            self.counts[label] = {}
        self.counts[label][predicted] = self.counts[label].get(predicted, 0) + 1

    def per_label(self) -> dict[str, dict[str, float]]:
        # Column sums: total times each class was predicted (across all true labels)
        col_sums: dict[str, int] = {}
        for preds in self.counts.values():
            for pred, cnt in preds.items():
                col_sums[pred] = col_sums.get(pred, 0) + cnt

        result: dict[str, dict[str, float]] = {}
        for label, preds in self.counts.items():
            tp = preds.get(label, 0)
            support = sum(preds.values())
            total_predicted_as_label = col_sums.get(label, 0)
            fp = total_predicted_as_label - tp
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
            recall = tp / support if support > 0 else 0.0
            f1 = (
                2.0 * precision * recall / (precision + recall)
                if (precision + recall) > 0.0
                else 0.0
            )
            result[label] = {
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "support": float(support),
            }
        return result

    def overall_accuracy(self) -> float:
        tp_total = sum(preds.get(lbl, 0) for lbl, preds in self.counts.items())
        total = sum(sum(preds.values()) for preds in self.counts.values())
        return tp_total / total if total > 0 else 0.0


@dataclass
class DryrunResult:
    n_samples: int
    n_labeled: int
    confusion: ConfusionMatrix
    fire_events: list[dict]
    precision_overall: float
    recall_overall: float
    acceptance_pass: bool


def dryrun(
    samples_dir: Path,
    labels_path: Path,
    cfg: Config,
    frame_loader: Callable | None = None,
) -> DryrunResult:
    from .detector import Detector

    _loader: Callable
    if frame_loader is None:
        import cv2 as _cv2

        def _default_loader(p: Path):  # type: ignore[misc]
            return _cv2.imread(str(p))

        _loader = _default_loader
    else:
        _loader = frame_loader

    labels: dict[str, str] = json.loads(labels_path.read_text(encoding="utf-8"))

    # Use the "background" ColorSpec for bg detection if provided; else a sensible default
    bg_spec = cfg.colors.get("background")
    bg_rgb: tuple[int, int, int] = bg_spec.rgb if bg_spec else (18, 18, 18)
    bg_tol: float = float(bg_spec.tolerance) if bg_spec else 15.0

    samples = sorted(samples_dir.glob("*.png"))

    current_frame: list = [None]

    def capture():
        return current_frame[0]

    detector = Detector(cfg, capture, bg_rgb=bg_rgb, bg_tol=bg_tol)
    sm = StateMachine(lockout_s=cfg.lockout_s)
    cm = ConfusionMatrix()
    fire_events: list[dict] = []
    n_labeled = 0

    for i, png_path in enumerate(samples):
        stem = png_path.stem
        label = labels.get(stem)
        if label is None:
            continue  # unlabeled → skip

        n_labeled += 1
        current_frame[0] = _loader(png_path)
        result = detector.step(ts=float(i) * 5.0)
        predicted: str = result.color if result.color is not None else "UNKNOWN"
        cm.add(label, predicted)

        if predicted != "UNKNOWN":
            t = sm.feed(predicted, float(i) * 5.0)  # type: ignore[arg-type]
            if t.trigger is not None and not t.locked:
                fire_events.append(
                    {"ts": float(i) * 5.0, "direction": t.trigger, "sample": stem}
                )

    per = cm.per_label()
    total_support = sum(v["support"] for v in per.values())
    if total_support > 0:
        precision_overall = (
            sum(v["precision"] * v["support"] for v in per.values()) / total_support
        )
        recall_overall = (
            sum(v["recall"] * v["support"] for v in per.values()) / total_support
        )
    else:
        precision_overall = 0.0
        recall_overall = 0.0

    acceptance_pass = precision_overall >= 1.0 and recall_overall >= 0.95

    return DryrunResult(
        n_samples=len(samples),
        n_labeled=n_labeled,
        confusion=cm,
        fire_events=fire_events,
        precision_overall=precision_overall,
        recall_overall=recall_overall,
        acceptance_pass=acceptance_pass,
    )


def print_report(result: DryrunResult) -> None:
    print("\n=== Dryrun Report ===")
    print(f"Samples: {result.n_samples}  Labeled: {result.n_labeled}")
    print(
        f"Precision: {result.precision_overall:.3f}  "
        f"Recall: {result.recall_overall:.3f}  "
        f"Gate: {'PASS' if result.acceptance_pass else 'FAIL'}"
    )
    print(f"Fire events: {len(result.fire_events)}")
    for ev in result.fire_events:
        print(f"  {ev}")
    per = result.confusion.per_label()
    print(f"\n{'Label':<15} {'Prec':>6} {'Rec':>6} {'F1':>6} {'Sup':>5}")
    for lbl, m in sorted(per.items()):
        print(
            f"{lbl:<15} {m['precision']:>6.3f} {m['recall']:>6.3f} "
            f"{m['f1']:>6.3f} {int(m['support']):>5}"
        )
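The precision arithmetic in `ConfusionMatrix.per_label` treats column sums as "total predicted as this class". A self-contained worked example of the same calculation:

```python
# True label -> {predicted label -> count}, shaped like ConfusionMatrix.counts.
counts = {
    "yellow": {"yellow": 9, "gray": 1},  # one yellow sample mispredicted as gray
    "gray": {"gray": 10},
}

# Column sums: how often each class was predicted across all true labels.
col_sums: dict = {}
for preds in counts.values():
    for pred, cnt in preds.items():
        col_sums[pred] = col_sums.get(pred, 0) + cnt

tp = counts["gray"].get("gray", 0)                # 10 correct gray predictions
fp = col_sums["gray"] - tp                        # 1 yellow sample predicted gray
precision_gray = tp / (tp + fp)                   # 10 / 11
recall_gray = tp / sum(counts["gray"].values())   # 10 / 10
print(round(precision_gray, 3), recall_gray)      # 0.909 1.0
```

The acceptance gate weights these per-label values by support, so a single misclassified sample in a small set is enough to drop `precision_overall` below the 1.0 threshold and fail the gate.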
88  src/atm/journal.py  Normal file
@@ -0,0 +1,88 @@
"""Trade journal — append-only JSONL store with interactive entry prompt."""
from __future__ import annotations

import json
from dataclasses import asdict, dataclass
from datetime import datetime, timezone
from pathlib import Path


@dataclass
class TradeEntry:
    ts: str                  # ISO timestamp of trade entry
    direction: str           # BUY or SELL
    symbol: str              # e.g. US30
    entry: float
    sl: float
    tp1: float | None
    tp2: float | None
    exit: float | None
    outcome: str             # "open"|"tp1"|"tp2"|"sl"|"manual"
    detected_ts: str | None
    notes: str = ""


class Journal:
    def __init__(self, path: Path) -> None:
        self.path = Path(path)
        self.path.parent.mkdir(parents=True, exist_ok=True)

    def add(self, entry: TradeEntry) -> None:
        with self.path.open("a", encoding="utf-8") as fh:
            json.dump(asdict(entry), fh)
            fh.write("\n")
            fh.flush()

    def all(self) -> list[TradeEntry]:
        if not self.path.exists():
            return []
        entries: list[TradeEntry] = []
        with self.path.open("r", encoding="utf-8") as fh:
            for line in fh:
                stripped = line.strip()
                if stripped:
                    entries.append(TradeEntry(**json.loads(stripped)))
        return entries


def prompt_entry(
    input_fn=input,
    detected: dict | None = None,
) -> TradeEntry:
    """Prompt user for a trade entry, using detected values as defaults."""
    d = detected or {}

    def ask(prompt: str, default: str | None = None) -> str:
        if default is not None:
            raw = input_fn(f"{prompt} [{default}]: ")
            return raw.strip() if raw.strip() else default
        return input_fn(f"{prompt}: ").strip()

    def ask_opt(prompt: str) -> str:
        return input_fn(f"{prompt} (blank=none): ").strip()

    ts = ask("Timestamp (ISO)", datetime.now(timezone.utc).isoformat())
    direction = ask("Direction BUY/SELL", d.get("direction", "BUY")).upper()
    symbol = ask("Symbol", d.get("symbol", "US30")).upper()
    entry = float(ask("Entry price"))
    sl = float(ask("Stop loss"))
    tp1_raw = ask_opt("TP1")
    tp2_raw = ask_opt("TP2")
    exit_raw = ask_opt("Exit price")
    outcome = ask("Outcome open/tp1/tp2/sl/manual", "open")
    detected_ts: str | None = d.get("detected_ts")
    notes = ask("Notes", "")

    return TradeEntry(
        ts=ts,
        direction=direction,
        symbol=symbol,
        entry=entry,
        sl=sl,
        tp1=float(tp1_raw) if tp1_raw else None,
        tp2=float(tp2_raw) if tp2_raw else None,
        exit=float(exit_raw) if exit_raw else None,
        outcome=outcome,
        detected_ts=detected_ts,
        notes=notes,
    )
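The journal's append-only JSONL layout (one JSON object per line, blank lines skipped on read) can be exercised standalone. The sketch below recreates the same write/read round-trip with a minimal stand-in dataclass — the `Entry` class and field set here are illustrative, not the module's actual `TradeEntry` API:

```python
import json
import tempfile
from dataclasses import asdict, dataclass
from pathlib import Path

@dataclass
class Entry:                      # illustrative stand-in for TradeEntry
    ts: str
    direction: str
    entry: float

path = Path(tempfile.mkdtemp()) / "trades.jsonl"

# append-only write: one JSON object per line, flushed per record
with path.open("a", encoding="utf-8") as fh:
    for e in (Entry("2025-01-06T10:00:00", "BUY", 42000.0),
              Entry("2025-01-06T14:30:00", "SELL", 42150.0)):
        json.dump(asdict(e), fh)
        fh.write("\n")

# read back: skip blank lines, rebuild dataclasses from each JSON object
with path.open(encoding="utf-8") as fh:
    entries = [Entry(**json.loads(s)) for line in fh if (s := line.strip())]

assert len(entries) == 2 and entries[0].direction == "BUY"
```

Appending a full line per record keeps the file crash-tolerant: a torn final line is simply skipped on the next read.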
130 src/atm/labeler.py Normal file
@@ -0,0 +1,130 @@
"""Dot colour sample labeler — Tk-based, safe to import headlessly."""
from __future__ import annotations

import json
from pathlib import Path

VALID_LABELS: frozenset[str] = frozenset({
    "turquoise", "yellow", "dark_green", "dark_red",
    "light_green", "light_red", "gray", "skip",
})


class LabelStore:
    """Persistent dict-backed label store serialised as JSON."""

    def __init__(self, path: Path) -> None:
        self.path = Path(path)
        self._labels: dict[str, str] = {}
        if self.path.exists():
            self._labels = json.loads(self.path.read_text(encoding="utf-8"))

    def __getitem__(self, key: str) -> str:
        return self._labels[key]

    def __setitem__(self, key: str, label: str) -> None:
        self._labels[key] = label

    def get(self, key: str, default: str | None = None) -> str | None:
        return self._labels.get(key, default)

    def save(self) -> None:
        self.path.write_text(json.dumps(self._labels, indent=2), encoding="utf-8")

    def as_dict(self) -> dict[str, str]:
        return dict(self._labels)


def run_labeler(samples_dir: Path, out_path: Path) -> None:
    """Launch the Tk labeling wizard. Imports tkinter only inside this function."""
    import tkinter as tk
    from PIL import Image, ImageTk  # type: ignore[import-untyped]

    samples_dir = Path(samples_dir)
    out_path = Path(out_path)
    store = LabelStore(out_path)

    images = sorted(samples_dir.glob("*.png")) + sorted(samples_dir.glob("*.jpg"))
    unlabeled = [p for p in images if store.get(p.name) is None]

    if not unlabeled:
        print("All samples already labeled.")
        return

    root = tk.Tk()
    root.title("ATM Labeler")

    img_label = tk.Label(root)
    img_label.pack(padx=8, pady=8)
    status_var = tk.StringVar()
    tk.Label(root, textvariable=status_var).pack()

    idx = [0]
    _photo_ref: list[object] = []  # keep PhotoImage alive

    def load_current() -> None:
        i = idx[0]
        if i >= len(unlabeled):
            store.save()
            root.destroy()
            return
        img = Image.open(unlabeled[i]).resize((200, 200))
        photo = ImageTk.PhotoImage(img)
        _photo_ref[:] = [photo]
        img_label.config(image=photo)
        status_var.set(f"{i + 1}/{len(unlabeled)}: {unlabeled[i].name}")

    def label_and_next(lbl: str) -> None:
        store[unlabeled[idx[0]].name] = lbl
        idx[0] += 1
        load_current()

    btn_frame = tk.Frame(root)
    btn_frame.pack(pady=4)
    for lbl in sorted(VALID_LABELS - {"skip"}):
        tk.Button(btn_frame, text=lbl, command=lambda l=lbl: label_and_next(l)).pack(
            side=tk.LEFT, padx=2
        )
    tk.Button(btn_frame, text="skip", command=lambda: label_and_next("skip")).pack(
        side=tk.LEFT, padx=2
    )

    load_current()
    root.mainloop()


def accuracy(
    labels: dict[str, str],
    predicted: dict[str, str],
) -> dict[str, float]:
    """Per-label precision/recall/F1 and overall accuracy.

    Returns a flat dict with keys:
    ``accuracy``, ``<label>_precision``, ``<label>_recall``, ``<label>_f1``
    for every label that appears in either *labels* or *predicted*.
    """
    common = set(labels) & set(predicted)
    all_label_values = sorted(set(labels.values()) | set(predicted.values()))

    result: dict[str, float] = {}

    if common:
        correct = sum(1 for k in common if labels[k] == predicted[k])
        result["accuracy"] = correct / len(common)
    else:
        result["accuracy"] = 0.0

    for lbl in all_label_values:
        tp = sum(1 for k in common if labels[k] == lbl and predicted[k] == lbl)
        fp = sum(1 for k in common if labels[k] != lbl and predicted[k] == lbl)
        fn = sum(1 for k in common if labels[k] == lbl and predicted[k] != lbl)

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0

        result[f"{lbl}_precision"] = precision
        result[f"{lbl}_recall"] = recall
        result[f"{lbl}_f1"] = f1

    return result
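The `accuracy()` metric only scores keys present in both dicts (`common`); extra predictions for unlabeled samples are ignored. A worked example, spelled out inline with the same comprehensions the function uses (filenames and labels here are made up for illustration):

```python
labels = {"a.png": "yellow", "b.png": "yellow", "c.png": "gray"}
predicted = {"a.png": "yellow", "b.png": "gray", "c.png": "gray", "d.png": "gray"}

common = set(labels) & set(predicted)          # d.png has no ground truth → ignored
correct = sum(1 for k in common if labels[k] == predicted[k])
acc = correct / len(common)                    # 2 of 3 agree

# per-label counts for "yellow"
tp = sum(1 for k in common if labels[k] == "yellow" and predicted[k] == "yellow")
fp = sum(1 for k in common if labels[k] != "yellow" and predicted[k] == "yellow")
fn = sum(1 for k in common if labels[k] == "yellow" and predicted[k] != "yellow")
precision = tp / (tp + fp)                     # 1/1 = 1.0: every yellow call was right
recall = tp / (tp + fn)                        # 1/2 = 0.5: one yellow was missed

assert abs(acc - 2 / 3) < 1e-9
assert precision == 1.0 and recall == 0.5
```

Note how this interacts with the dryrun gate (precision ≥ 1.0, recall ≥ 0.95): a missed dot only costs recall, while a single false colour call fails the gate outright.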
116 src/atm/levels.py Normal file
@@ -0,0 +1,116 @@
"""Phase-B chart scanner: detect SL (red) and TP (green) horizontal lines."""
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal

import numpy as np

from .config import Config
from .vision import crop_roi, detect_color_lines, pixel_y_to_price


@dataclass
class Levels:
    sl: float | None
    tp1: float | None
    tp2: float | None
    partial: bool  # True when fewer than 3 lines were present


@dataclass
class LevelsResult:
    status: Literal["waiting", "partial", "complete", "timeout"]
    levels: Levels | None
    elapsed_s: float


def _merge_ys(ys: list[int]) -> list[int]:
    """Sort and deduplicate y-coords within 3px."""
    merged: list[int] = []
    for y in sorted(ys):
        if not merged or abs(y - merged[-1]) > 3:
            merged.append(y)
    return merged


def _ys_match(ys1: list[int], ys2: list[int]) -> bool:
    """True when both lists have the same length and each pair is within ±2px."""
    if len(ys1) != len(ys2):
        return False
    return all(abs(a - b) <= 2 for a, b in zip(ys1, ys2))


class LevelsExtractor:
    """Scan chart ROI for red (SL) and green (TP) horizontal lines.

    Requires two consecutive calls with a matching y-set (±2px) before
    declaring "complete". Until then returns "partial" (any lines) or
    "waiting" (none). After phaseb_timeout_s returns "timeout".
    """

    def __init__(
        self,
        cfg: Config,
        direction: Literal["BUY", "SELL"],
        start_ts: float,
    ) -> None:
        self._cfg = cfg
        self._direction = direction
        self._start_ts = start_ts
        self._prev_ys: list[int] | None = None

    def step(self, chart_frame_bgr: np.ndarray, ts: float) -> LevelsResult:
        elapsed = ts - self._start_ts

        if elapsed > self._cfg.phaseb_timeout_s:
            return LevelsResult(status="timeout", levels=None, elapsed_s=elapsed)

        chart_img = crop_roi(chart_frame_bgr, self._cfg.chart_roi)

        red_cfg = self._cfg.colors["light_red"]
        green_cfg = self._cfg.colors["light_green"]

        red_ys = detect_color_lines(chart_img, red_cfg.rgb, red_cfg.tolerance)
        green_ys = detect_color_lines(chart_img, green_cfg.rgb, green_cfg.tolerance)
        all_ys = _merge_ys(red_ys + green_ys)

        n = len(all_ys)

        if n == 0:
            self._prev_ys = None
            return LevelsResult(status="waiting", levels=None, elapsed_s=elapsed)

        stable = self._prev_ys is not None and _ys_match(self._prev_ys, all_ys)
        self._prev_ys = list(all_ys)

        if n >= 3 and stable:
            levels = self._assign(all_ys[:3], partial=False)
            return LevelsResult(status="complete", levels=levels, elapsed_s=elapsed)

        levels = self._assign_partial(all_ys)
        return LevelsResult(status="partial", levels=levels, elapsed_s=elapsed)

    def _assign(self, ys: list[int], *, partial: bool) -> Levels:
        """Assign SL/TP from 3 y-coords sorted top→bottom (increasing y = lower price)."""
        prices = [pixel_y_to_price(y, self._cfg.y_axis) for y in ys]
        if self._direction == "BUY":
            # topmost (lowest y) = highest price = TP2; bottom = SL
            return Levels(sl=prices[2], tp1=prices[1], tp2=prices[0], partial=partial)
        else:
            # SELL: topmost = SL; bottom = TP2
            return Levels(sl=prices[0], tp1=prices[1], tp2=prices[2], partial=partial)

    def _assign_partial(self, ys: list[int]) -> Levels:
        """Best-effort assignment for 1-2 (or unstable 3+) lines."""
        n = len(ys)
        if n >= 3:
            return self._assign(ys[:3], partial=True)
        prices = [pixel_y_to_price(y, self._cfg.y_axis) for y in ys]
        if n == 2:
            if self._direction == "BUY":
                return Levels(sl=prices[1], tp1=prices[0], tp2=None, partial=True)
            else:
                return Levels(sl=prices[0], tp1=prices[1], tp2=None, partial=True)
        # n == 1
        return Levels(sl=prices[0], tp1=None, tp2=None, partial=True)
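The two small helpers above carry the stability protocol: `_merge_ys` collapses anti-aliased line edges into one y per line, and `_ys_match` decides whether two consecutive scans saw the same lines despite pixel jitter. Reimplemented standalone (same logic, public names for illustration) so the thresholds can be sanity-checked:

```python
def merge_ys(ys: list[int]) -> list[int]:
    # sort, then drop any y within 3px of the previously kept one
    merged: list[int] = []
    for y in sorted(ys):
        if not merged or abs(y - merged[-1]) > 3:
            merged.append(y)
    return merged

def ys_match(ys1: list[int], ys2: list[int]) -> bool:
    # same line count and every pair within ±2px → "stable" across two scans
    return len(ys1) == len(ys2) and all(abs(a - b) <= 2 for a, b in zip(ys1, ys2))

# 100 and 101 are the two edge rows of one anti-aliased line → merged
assert merge_ys([101, 100, 105, 200]) == [100, 105, 200]
# 1px jitter between scans still counts as the same y-set
assert ys_match([100, 105, 200], [101, 104, 199])
# a line appearing or vanishing breaks stability
assert not ys_match([100, 105, 200], [100, 105])
```

Requiring two matching scans before "complete" means a TP line still being drawn by the platform cannot fire a premature levels alert.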
301 src/atm/main.py Normal file
@@ -0,0 +1,301 @@
"""ATM unified CLI entry point."""
from __future__ import annotations

import argparse
import os
import sys
import time
from pathlib import Path

from atm.config import Config  # stdlib-only (tomllib); safe at module level

# Module-level reference — set lazily by _cmd_dryrun; tests may monkeypatch it.
dryrun = None


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

def main(argv=None) -> None:
    parser = argparse.ArgumentParser(
        prog="atm",
        description="ATM trading monitor",
    )
    sub = parser.add_subparsers(dest="command", metavar="COMMAND")
    sub.required = True

    # calibrate
    sub.add_parser("calibrate", help="Launch guided calibration wizard (Tk)")

    # label
    p_label = sub.add_parser("label", help="Label dot-colour samples (Tk)")
    p_label.add_argument("samples_dir", type=Path, metavar="SAMPLES_DIR")

    # dryrun
    p_dry = sub.add_parser("dryrun", help="Replay labeled samples; exit 0 on pass")
    p_dry.add_argument("samples_dir", type=Path, metavar="SAMPLES_DIR")

    # run
    p_run = sub.add_parser("run", help="Start live monitoring loop")
    p_run.add_argument(
        "--duration", type=float, metavar="HOURS", default=None,
        help="Run for HOURS hours then exit (omit for indefinite)",
    )
    p_run.add_argument(
        "--capture-stub", action="store_true",
        help="Use stub capture (reads PNGs from samples/); useful for smoke-testing on Linux",
    )

    # journal
    p_journal = sub.add_parser("journal", help="Add a trade journal entry interactively")
    p_journal.add_argument(
        "--file", type=Path, default=Path("trades.jsonl"),
        metavar="PATH", help="Journal JSONL file (default: trades.jsonl)",
    )

    # report
    p_report = sub.add_parser("report", help="Print weekly performance report")
    p_report.add_argument(
        "--week", default=None, metavar="YYYY-WW",
        help="ISO week to report (default: current week)",
    )
    p_report.add_argument(
        "--file", type=Path, default=Path("trades.jsonl"),
        metavar="PATH", help="Journal JSONL file (default: trades.jsonl)",
    )

    args = parser.parse_args(argv)

    _dispatch = {
        "calibrate": _cmd_calibrate,
        "label": _cmd_label,
        "dryrun": _cmd_dryrun,
        "run": _cmd_run,
        "journal": _cmd_journal,
        "report": _cmd_report,
    }
    _dispatch[args.command](args)


# ---------------------------------------------------------------------------
# Subcommand handlers
# ---------------------------------------------------------------------------

def _cmd_calibrate(args) -> None:
    try:
        from atm.calibrate import run_calibration
    except ImportError as exc:
        sys.exit(f"calibrate module not available: {exc}")
    run_calibration(Path("configs"))


def _cmd_label(args) -> None:
    try:
        from atm.labeler import run_labeler
    except ImportError as exc:
        sys.exit(f"labeler module not available: {exc}")
    samples_dir = Path(args.samples_dir)
    run_labeler(samples_dir, samples_dir / "labels.json")


def _cmd_dryrun(args) -> None:
    global dryrun
    if dryrun is None:
        try:
            import atm.dryrun as _mod
        except ImportError as exc:
            sys.exit(f"dryrun module not available: {exc}")
        dryrun = _mod

    cfg = Config.load_current(Path("configs"))
    samples_dir = Path(args.samples_dir)
    result = dryrun.dryrun(samples_dir, samples_dir / "labels.json", cfg)
    dryrun.print_report(result)
    sys.exit(0 if result.acceptance_pass else 1)


def _cmd_run(args) -> None:
    cfg = Config.load_current(Path("configs"))
    duration_s = args.duration * 3600 if args.duration is not None else None
    capture_stub = args.capture_stub or bool(os.environ.get("ATM_STUB_CAPTURE"))
    run_live(cfg, duration_s=duration_s, capture_stub=capture_stub)


def _cmd_journal(args) -> None:
    try:
        from atm.journal import Journal, prompt_entry
    except ImportError as exc:
        sys.exit(f"journal module not available: {exc}")
    j = Journal(args.file)
    entry = prompt_entry()
    j.add(entry)


def _cmd_report(args) -> None:
    try:
        from atm.journal import Journal
        from atm.report import weekly_report, iso_week
    except ImportError as exc:
        sys.exit(f"report/journal module not available: {exc}")

    j = Journal(args.file)
    entries = j.all()

    if args.week:
        week = args.week
    else:
        from datetime import datetime, timezone
        week = iso_week(datetime.now(timezone.utc).isoformat())

    rep = weekly_report(entries, week)
    print(
        f"Week {rep.week}: {rep.n_trades} trades, "
        f"{rep.n_wins}W/{rep.n_losses}L, "
        f"WR={rep.win_rate:.0%}, PnL={rep.pnl_r:.2f}R"
        + (f", slippage={rep.avg_slippage:.1f}s" if rep.avg_slippage is not None else "")
    )


# ---------------------------------------------------------------------------
# Live loop
# ---------------------------------------------------------------------------

def run_live(cfg, duration_s=None, capture_stub: bool = False) -> None:
    """Main live monitoring loop. Imports are lazy to keep --help fast."""
    try:
        from atm.detector import Detector
        from atm.state_machine import StateMachine
        from atm.canary import Canary
        from atm.levels import LevelsExtractor
        from atm.notifier import Alert
        from atm.notifier.fanout import FanoutNotifier
        from atm.notifier.discord import DiscordNotifier
        from atm.notifier.telegram import TelegramNotifier
        from atm.audit import AuditLog
    except ImportError as exc:
        sys.exit(f"run-loop dependencies not available: {exc}")

    capture = _build_capture(cfg, capture_stub=capture_stub)
    detector = Detector(cfg, capture)
    fsm = StateMachine(lockout_s=cfg.lockout_s)
    canary = Canary(cfg, pause_flag_path=Path("logs/pause.flag"))
    audit = AuditLog(Path("logs"))
    backends = [
        DiscordNotifier(cfg.discord.webhook_url),
        TelegramNotifier(cfg.telegram.bot_token, cfg.telegram.chat_id),
    ]
    notifier = FanoutNotifier(backends, Path(cfg.dead_letter_path))

    start = time.monotonic()
    heartbeat_due = start + cfg.heartbeat_min * 60
    levels_extractor = None

    try:
        while duration_s is None or (time.monotonic() - start) < duration_s:
            now = time.time()
            frame = capture()
            if frame is None:
                audit.log({"ts": now, "event": "window_lost"})
                time.sleep(cfg.loop_interval_s)
                continue
            # canary check
            cr = canary.check(frame)
            if canary.is_paused:
                audit.log({"ts": now, "event": "paused", "drift": cr.distance})
                time.sleep(cfg.loop_interval_s)
                continue
            # detection
            res = detector.step(now)
            if res.accepted and res.color:
                tr = fsm.feed(res.color, now)
                audit.log({
                    "ts": now, "event": "tick",
                    "color": res.color, "state": tr.next.value, "reason": tr.reason,
                })
                if tr.trigger and not tr.locked:
                    notifier.send(Alert(
                        kind="trigger",
                        title=f"{tr.trigger} signal",
                        body=f"@ {now}",
                        direction=tr.trigger,
                    ))
                    levels_extractor = LevelsExtractor(cfg, tr.trigger, now)
            # phase-B levels
            if levels_extractor is not None:
                lr = levels_extractor.step(frame, now)
                if lr.status in ("complete", "timeout"):
                    if lr.status == "complete" and lr.levels:
                        notifier.send(Alert(
                            kind="levels",
                            title="Levels",
                            body=(
                                f"SL={lr.levels.sl} "
                                f"TP1={lr.levels.tp1} "
                                f"TP2={lr.levels.tp2}"
                            ),
                        ))
                    levels_extractor = None
            # heartbeat (compare on the monotonic clock, same as `start`;
            # mixing in time.time() here would break across wall-clock jumps)
            if time.monotonic() > heartbeat_due:
                notifier.send(Alert(kind="heartbeat", title="alive", body="confidence ok"))
                heartbeat_due = time.monotonic() + cfg.heartbeat_min * 60
            time.sleep(cfg.loop_interval_s)
    finally:
        notifier.stop()
        audit.close()


def _build_capture(cfg, capture_stub: bool = False):
    """Return a capture callable ``() -> ndarray | None``."""
    use_stub = capture_stub or bool(os.environ.get("ATM_STUB_CAPTURE"))

    if use_stub:
        import itertools
        samples_dir = Path("samples")
        pngs = sorted(samples_dir.glob("*.png")) if samples_dir.exists() else []
        _cycle = itertools.cycle(pngs) if pngs else None

        def _stub_capture():
            if _cycle is None:
                return None
            p = next(_cycle)
            try:
                import cv2  # type: ignore[import-untyped]
                return cv2.imread(str(p))
            except ImportError:
                import numpy as np
                return np.zeros((100, 100, 3), dtype=np.uint8)

        return _stub_capture

    # Windows live path
    try:
        import mss  # type: ignore[import-untyped]
        import pygetwindow as gw  # type: ignore[import-untyped]
    except ImportError as exc:
        sys.exit(
            f"Live screen capture requires 'mss' and 'pygetwindow' (Windows only). "
            f"Use --capture-stub or set ATM_STUB_CAPTURE=1 for testing on Linux. "
            f"Missing: {exc}"
        )

    def _live_capture():
        wins = gw.getWindowsWithTitle(cfg.window_title)
        if not wins:
            return None
        win = wins[0]
        with mss.mss() as sct:
            import cv2  # type: ignore[import-untyped]
            import numpy as np
            mon = {
                "top": win.top,
                "left": win.left,
                "width": win.width,
                "height": win.height,
            }
            img = sct.grab(mon)
            frame = np.array(img)
            return cv2.cvtColor(frame, cv2.COLOR_BGRA2BGR)

    return _live_capture
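The CLI wires argparse subcommands to handlers through a plain dict, keeping `main()` free of if/elif chains. A minimal standalone sketch of the same pattern (the `report` subcommand and handler here are a toy, not the real ones):

```python
import argparse

def cmd_report(args) -> str:
    # toy handler standing in for _cmd_report
    return f"report for week {args.week}"

parser = argparse.ArgumentParser(prog="atm")
sub = parser.add_subparsers(dest="command", metavar="COMMAND")
sub.required = True                       # bare `atm` errors out with usage

p_report = sub.add_parser("report")
p_report.add_argument("--week", default=None, metavar="YYYY-WW")

args = parser.parse_args(["report", "--week", "2025-02"])

# name → handler; args.command is guaranteed to be a registered subcommand
dispatch = {"report": cmd_report}
assert dispatch[args.command](args) == "report for week 2025-02"
```

Because `add_subparsers(dest="command")` stores the chosen subcommand name, the dict lookup can never miss: argparse has already rejected unknown commands before dispatch runs.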
70 src/atm/report.py Normal file
@@ -0,0 +1,70 @@
"""Weekly performance report from journal entries."""
from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime

from .journal import TradeEntry


@dataclass
class WeeklyReport:
    week: str                   # "YYYY-WW"
    n_trades: int
    n_wins: int
    n_losses: int
    win_rate: float
    pnl_r: float                # sum in R multiples
    avg_slippage: float | None  # seconds from signal detection to entry; None if data missing


def iso_week(ts_iso: str) -> str:
    """Return 'YYYY-WW' (ISO calendar week) for the given ISO timestamp."""
    dt = datetime.fromisoformat(ts_iso)
    cal = dt.isocalendar()
    return f"{cal.year}-{cal.week:02d}"


def _pnl_r(entry: TradeEntry) -> float | None:
    """R-multiple PnL for a closed trade. None if exit is missing or risk is zero."""
    if entry.exit is None:
        return None
    risk = entry.entry - entry.sl  # positive for BUY (entry > sl), negative for SELL
    if risk == 0:
        return None
    return (entry.exit - entry.entry) / risk


def weekly_report(entries: list[TradeEntry], week: str) -> WeeklyReport:
    """Compute a WeeklyReport for entries belonging to *week* ('YYYY-WW')."""
    week_entries = [e for e in entries if iso_week(e.ts) == week]
    closed = [e for e in week_entries if e.outcome != "open"]

    n_trades = len(closed)
    n_wins = sum(1 for e in closed if e.outcome in {"tp1", "tp2"})
    n_losses = sum(1 for e in closed if e.outcome in {"sl", "manual"})
    win_rate = n_wins / n_trades if n_trades > 0 else 0.0

    pnl_values = [r for e in closed if (r := _pnl_r(e)) is not None]
    pnl_r = sum(pnl_values)

    slippages: list[float] = []
    for e in closed:
        if e.detected_ts and e.ts:
            try:
                dt_entry = datetime.fromisoformat(e.ts)
                dt_detect = datetime.fromisoformat(e.detected_ts)
                slippages.append((dt_entry - dt_detect).total_seconds())
            except ValueError:
                pass
    avg_slippage = sum(slippages) / len(slippages) if slippages else None

    return WeeklyReport(
        week=week,
        n_trades=n_trades,
        n_wins=n_wins,
        n_losses=n_losses,
        win_rate=win_rate,
        pnl_r=pnl_r,
        avg_slippage=avg_slippage,
    )
68 tests/test_calibrate.py Normal file
@@ -0,0 +1,68 @@
"""Tests for atm.calibrate."""
from __future__ import annotations

from pathlib import Path

import pytest


def _minimal_config_data() -> dict:
    return {
        "window_title": "Test Chart",
        "dot_roi": {"x": 0, "y": 0, "w": 100, "h": 100},
        "chart_roi": {"x": 0, "y": 0, "w": 800, "h": 600},
        "colors": {
            "turquoise": {"rgb": [0, 200, 200], "tolerance": 30.0},
            "yellow": {"rgb": [255, 255, 0], "tolerance": 30.0},
            "dark_green": {"rgb": [0, 128, 0], "tolerance": 30.0},
            "dark_red": {"rgb": [139, 0, 0], "tolerance": 30.0},
            "light_green": {"rgb": [144, 238, 144], "tolerance": 30.0},
            "light_red": {"rgb": [255, 182, 193], "tolerance": 30.0},
            "gray": {"rgb": [128, 128, 128], "tolerance": 30.0},
        },
        "y_axis": {"p1_y": 100, "p1_price": 10000.0, "p2_y": 200, "p2_price": 9000.0},
        "canary": {
            "roi": {"x": 0, "y": 0, "w": 50, "h": 50},
            "baseline_phash": "abc123",
            "drift_threshold": 8,
        },
        "discord": {"webhook_url": "http://example.com/hook"},
        "telegram": {"bot_token": "123:abc", "chat_id": "456"},
    }


def test_write_config_and_marker(tmp_path: Path) -> None:
    from atm.calibrate import write_config
    from atm.config import Config

    config_path = write_config(_minimal_config_data(), tmp_path)

    assert config_path.exists()
    assert config_path.suffix == ".toml"

    # Must be loadable by Config.load
    cfg = Config.load(config_path)
    assert cfg.window_title == "Test Chart"
    assert cfg.y_axis.p1_price == pytest.approx(10000.0)

    # Marker must point at the filename (basename only)
    marker = tmp_path / "current.txt"
    assert marker.exists()
    assert marker.read_text(encoding="utf-8").strip() == config_path.name

    # Config.load_current should also work
    cfg2 = Config.load_current(tmp_path)
    assert cfg2.window_title == cfg.window_title


def test_import_safe() -> None:
    """Importing atm.calibrate must succeed in a headless environment (no tkinter at top-level)."""
    import importlib  # noqa: F401
    import importlib.util

    spec = importlib.util.find_spec("atm.calibrate")
    assert spec is not None, "atm.calibrate module not found"
    # Actually importing must not raise (tkinter is only used inside run_calibration)
    mod = importlib.import_module("atm.calibrate")
    assert hasattr(mod, "write_config")
    assert hasattr(mod, "run_calibration")
152 tests/test_canary.py Normal file
@@ -0,0 +1,152 @@
"""Tests for src/atm/canary.py."""
from __future__ import annotations

import dataclasses
from pathlib import Path

import numpy as np
import pytest

from atm.canary import Canary, CanaryResult
from atm.config import (
    CanaryRegion,
    ColorSpec,
    Config,
    DiscordCfg,
    ROI,
    TelegramCfg,
    YAxisCalib,
)
from atm.vision import crop_roi, phash

# ---------------------------------------------------------------------------
# Fixtures / helpers
# ---------------------------------------------------------------------------

CANARY_ROI = ROI(x=0, y=0, w=50, h=50)


def _make_base_cfg() -> Config:
    colors = {
        "turquoise": ColorSpec(rgb=(0, 255, 255), tolerance=30.0),
        "yellow": ColorSpec(rgb=(255, 255, 0), tolerance=30.0),
        "dark_green": ColorSpec(rgb=(0, 100, 0), tolerance=30.0),
        "dark_red": ColorSpec(rgb=(100, 0, 0), tolerance=30.0),
        "light_green": ColorSpec(rgb=(0, 255, 0), tolerance=30.0),
        "light_red": ColorSpec(rgb=(255, 0, 0), tolerance=30.0),
        "gray": ColorSpec(rgb=(128, 128, 128), tolerance=30.0),
    }
    # placeholder baseline_phash; tests replace canary via dataclasses.replace
    return Config(
        window_title="test",
        dot_roi=ROI(x=0, y=0, w=100, h=50),
        chart_roi=ROI(x=0, y=0, w=600, h=400),
        colors=colors,
        y_axis=YAxisCalib(p1_y=0, p1_price=100.0, p2_y=400, p2_price=80.0),
        canary=CanaryRegion(roi=CANARY_ROI, baseline_phash="0" * 64, drift_threshold=8),
        discord=DiscordCfg(webhook_url="http://example.com/hook"),
        telegram=TelegramCfg(bot_token="tok", chat_id="123"),
    )


def _cfg_with_baseline(baseline_frame: np.ndarray) -> Config:
    """Build a Config whose baseline_phash matches the given frame's canary ROI."""
    roi_img = crop_roi(baseline_frame, CANARY_ROI)
    h = phash(roi_img)
    canary_region = CanaryRegion(roi=CANARY_ROI, baseline_phash=h, drift_threshold=8)
    return dataclasses.replace(_make_base_cfg(), canary=canary_region)


def _checkerboard(h: int, w: int, block: int = 8) -> np.ndarray:
    """Return a checkerboard BGR image (high-frequency, distinct phash)."""
    img = np.zeros((h, w, 3), dtype=np.uint8)
    for y in range(0, h, block):
        for x in range(0, w, block):
            if (y // block + x // block) % 2 == 0:
                img[y : y + block, x : x + block] = 255
    return img


# A purely black 100×100 frame as baseline
BASELINE_FRAME = np.zeros((100, 100, 3), dtype=np.uint8)

# A frame where the canary ROI is a checkerboard (visually very different)
DRIFTED_FRAME = BASELINE_FRAME.copy()
DRIFTED_FRAME[:50, :50] = _checkerboard(50, 50)


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------

def test_no_drift() -> None:
    """Same image as baseline → distance ≤ threshold, not paused."""
    cfg = _cfg_with_baseline(BASELINE_FRAME)
    canary = Canary(cfg)

    result = canary.check(BASELINE_FRAME)

    assert result.drifted is False
    assert result.paused is False
    assert canary.is_paused is False


def test_drift_triggers_pause() -> None:
    """Drastically different canary ROI → drifted=True, paused=True."""
    cfg = _cfg_with_baseline(BASELINE_FRAME)
    canary = Canary(cfg)

    result = canary.check(DRIFTED_FRAME)

    assert result.drifted is True
    assert result.paused is True
    assert canary.is_paused is True


def test_persists_paused() -> None:
    """After drift, feeding back a clean frame keeps paused=True."""
    cfg = _cfg_with_baseline(BASELINE_FRAME)
    canary = Canary(cfg)

    canary.check(DRIFTED_FRAME)  # trigger pause
    result = canary.check(BASELINE_FRAME)  # clean frame, but still paused

    assert result.paused is True
    assert canary.is_paused is True


def test_resume_clears() -> None:
    """resume() clears the paused flag; subsequent clean frame stays unpaused."""
    cfg = _cfg_with_baseline(BASELINE_FRAME)
    canary = Canary(cfg)

    canary.check(DRIFTED_FRAME)  # pause
    canary.resume()

    assert canary.is_paused is False

    result = canary.check(BASELINE_FRAME)
    assert result.paused is False


def test_pause_file_written(tmp_path: Path) -> None:
    """When pause_flag_path is provided, the file is created on drift."""
    flag = tmp_path / "paused.flag"
    cfg = _cfg_with_baseline(BASELINE_FRAME)
    canary = Canary(cfg, pause_flag_path=flag)

    assert not flag.exists()
    canary.check(DRIFTED_FRAME)
    assert flag.exists()


def test_resume_deletes_pause_file(tmp_path: Path) -> None:
    """resume() deletes the pause flag file."""
    flag = tmp_path / "paused.flag"
    cfg = _cfg_with_baseline(BASELINE_FRAME)
    canary = Canary(cfg, pause_flag_path=flag)

    canary.check(DRIFTED_FRAME)
    assert flag.exists()
    canary.resume()
    assert not flag.exists()
198
tests/test_detector.py
Normal file
@@ -0,0 +1,198 @@
"""Tests for src/atm/detector.py."""
from __future__ import annotations

import numpy as np
import pytest

from atm.config import (
    CanaryRegion,
    ColorSpec,
    Config,
    DiscordCfg,
    ROI,
    TelegramCfg,
    YAxisCalib,
)
from atm.detector import DetectionResult, Detector

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
DOT_ROI = ROI(x=10, y=10, w=280, h=80)
BG_VAL = 18  # background pixel value (18, 18, 18)

# BGR values (OpenCV convention: B, G, R)
# turquoise RGB=(0,255,255) → BGR=(255,255,0)
# yellow    RGB=(255,255,0) → BGR=(0,255,255)
TURQUOISE_BGR = (255, 255, 0)
YELLOW_BGR = (0, 255, 255)
# An olive-green colour far from every palette entry (RGB=(100,150,50))
UNKNOWN_BGR = (50, 150, 100)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _make_frame(*dot_specs: tuple[tuple[int, int, int], int, int]) -> np.ndarray:
    """Create a (100, 300, 3) uint8 BGR frame filled with background.

    Each spec is (bgr_color, roi_x_start, roi_x_end) and paints a
    full-height stripe inside DOT_ROI. roi_x_end=280 reaches the right
    boundary so pixel_rgb sampling stays within the dot.
    """
    frame = np.full((100, 300, 3), BG_VAL, dtype=np.uint8)
    for bgr, x0, x1 in dot_specs:
        fx0 = DOT_ROI.x + x0
        fx1 = DOT_ROI.x + x1
        fy0 = DOT_ROI.y
        fy1 = DOT_ROI.y + DOT_ROI.h
        frame[fy0:fy1, fx0:fx1] = bgr
    return frame


def _make_cfg(debounce_depth: int = 1) -> Config:
    colors = {
        "turquoise": ColorSpec(rgb=(0, 255, 255), tolerance=30.0),
        "yellow": ColorSpec(rgb=(255, 255, 0), tolerance=30.0),
        "dark_green": ColorSpec(rgb=(0, 100, 0), tolerance=30.0),
        "dark_red": ColorSpec(rgb=(100, 0, 0), tolerance=30.0),
        "light_green": ColorSpec(rgb=(0, 255, 0), tolerance=30.0),
        "light_red": ColorSpec(rgb=(255, 0, 0), tolerance=30.0),
        "gray": ColorSpec(rgb=(128, 128, 128), tolerance=30.0),
    }
    return Config(
        window_title="test",
        dot_roi=DOT_ROI,
        chart_roi=ROI(x=0, y=0, w=600, h=400),
        colors=colors,
        y_axis=YAxisCalib(p1_y=0, p1_price=100.0, p2_y=400, p2_price=80.0),
        canary=CanaryRegion(
            roi=ROI(x=0, y=0, w=50, h=50),
            baseline_phash="0" * 64,
            drift_threshold=8,
        ),
        discord=DiscordCfg(webhook_url="http://example.com/hook"),
        telegram=TelegramCfg(bot_token="tok", chat_id="123"),
        debounce_depth=debounce_depth,
    )


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------

def test_empty_roi_no_dot() -> None:
    """All-background frame → dot not found."""
    frame = np.full((100, 300, 3), BG_VAL, dtype=np.uint8)
    cfg = _make_cfg()
    det = Detector(cfg, capture=lambda: frame)

    r = det.step(0.0)

    assert r.window_found is True
    assert r.dot_found is False
    assert r.rgb is None
    assert r.match is None
    assert r.accepted is False


def test_rightmost_cluster() -> None:
    """Two dots at different x positions → detector returns rightmost colour."""
    # turquoise on the left, yellow extending to the right ROI edge
    frame = _make_frame(
        (TURQUOISE_BGR, 50, 100),  # roi_x [50, 100)
        (YELLOW_BGR, 200, 280),  # roi_x [200, 280) → right edge
    )
    cfg = _make_cfg()
    det = Detector(cfg, capture=lambda: frame)

    r = det.step(0.0)

    assert r.dot_found is True
    assert r.match is not None
    assert r.match.name == "yellow"


def test_debounce_depth_1() -> None:
    """depth=1: single valid frame → accepted=True."""
    frame = _make_frame((YELLOW_BGR, 200, 280))
    cfg = _make_cfg(debounce_depth=1)
    det = Detector(cfg, capture=lambda: frame)

    r = det.step(0.0)

    assert r.accepted is True
    assert r.color == "yellow"


def test_debounce_depth_2() -> None:
    """depth=2: first frame → accepted=False; second same → accepted=True."""
    frame = _make_frame((YELLOW_BGR, 200, 280))
    cfg = _make_cfg(debounce_depth=2)
    det = Detector(cfg, capture=lambda: frame)

    r1 = det.step(0.0)
    r2 = det.step(1.0)

    assert r1.accepted is False
    assert r2.accepted is True
    assert r2.color == "yellow"


def test_debounce_reset_on_change() -> None:
    """depth=2: A then B → neither accepted."""
    frame_a = _make_frame((TURQUOISE_BGR, 200, 280))
    frame_b = _make_frame((YELLOW_BGR, 200, 280))
    cfg = _make_cfg(debounce_depth=2)
    frames = iter([frame_a, frame_b])
    det = Detector(cfg, capture=lambda: next(frames))

    r1 = det.step(0.0)
    r2 = det.step(1.0)

    assert r1.accepted is False
    assert r2.accepted is False


def test_unknown_not_accepted() -> None:
    """Colour outside every palette tolerance → UNKNOWN, accepted=False."""
    frame = _make_frame((UNKNOWN_BGR, 200, 280))
    cfg = _make_cfg(debounce_depth=1)
    det = Detector(cfg, capture=lambda: frame)

    r = det.step(0.0)

    assert r.dot_found is True
    assert r.match is not None
    assert r.match.name == "UNKNOWN"
    assert r.accepted is False
    assert r.color is None


def test_window_lost() -> None:
    """capture() returns None → window_found=False, safe defaults."""
    cfg = _make_cfg()
    det = Detector(cfg, capture=lambda: None)

    r = det.step(0.0)

    assert r.window_found is False
    assert r.dot_found is False
    assert r.rgb is None
    assert r.match is None
    assert r.accepted is False
    assert r.color is None


def test_rolling_window() -> None:
    """Rolling window never exceeds 20 entries."""
    frame = _make_frame((YELLOW_BGR, 200, 280))
    cfg = _make_cfg()
    det = Detector(cfg, capture=lambda: frame)

    for i in range(25):
        det.step(float(i))

    assert len(det.rolling) == 20  # 25 steps fed, cap of 20 kept
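The debounce semantics exercised by `test_debounce_depth_2` and `test_debounce_reset_on_change` — accept a colour only after N identical consecutive readings, reset the streak on any change — can be sketched on their own. `Debouncer` here is a hypothetical reduction, not the real `Detector` internals:

```python
class Debouncer:
    """Accept a value only after it has been seen `depth` times in a row."""

    def __init__(self, depth: int):
        self.depth = depth
        self._last = None
        self._count = 0

    def feed(self, value) -> bool:
        if value == self._last:
            self._count += 1
        else:
            self._last = value
            self._count = 1  # any change resets the streak
        return self._count >= self.depth
```

With `depth=2` a single flickering frame can never fire a signal, which is the point: screen-scraped colours are noisy, so one-off misreads are filtered out at the cost of one extra polling interval of latency.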
224
tests/test_dryrun.py
Normal file
@@ -0,0 +1,224 @@
"""Tests for atm.dryrun."""
from __future__ import annotations

import json
from pathlib import Path

import numpy as np
import pytest

from atm.config import CanaryRegion, ColorSpec, Config, DiscordCfg, ROI, TelegramCfg, YAxisCalib
from atm.dryrun import ConfusionMatrix, DryrunResult, dryrun

# ---------------------------------------------------------------------------
# Config fixture
#
# The 6x6 dot at x=250..255, y=50..55 in a 100x300 frame is sampled by
# pixel_rgb(box=3) over a 7x7 patch: 24 dot pixels + 25 background (0,0,0).
# Sampled RGB = int(true_RGB * 24/49). Config colors match the sampled values
# so classify_pixel returns the correct label.
# ---------------------------------------------------------------------------

_SCALE = 24 / 49  # fraction of dot pixels in the 7x7 sample box

# True RGB paint values → sampled RGB ≈ int(true_RGB * _SCALE)
_SAMPLED_RGB: dict[str, tuple[int, int, int]] = {
    "turquoise": (0, 97, 97),  # true (0, 200, 200)
    "yellow": (124, 124, 0),  # true (255, 255, 0)
    "dark_green": (0, 48, 0),  # true (0, 100, 0)
    "dark_red": (68, 0, 0),  # true (139, 0, 0)
    "light_green": (70, 116, 70),  # true (144, 238, 144)
    "light_red": (124, 89, 94),  # true (255, 182, 193)
    "gray": (62, 62, 62),  # true (128, 128, 128)
}

# True RGB values used when painting frames (before sampling dilution)
_TRUE_RGB: dict[str, tuple[int, int, int]] = {
    "turquoise": (0, 200, 200),
    "yellow": (255, 255, 0),
    "dark_green": (0, 100, 0),
    "dark_red": (139, 0, 0),
    "light_green": (144, 238, 144),
    "light_red": (255, 182, 193),
    "gray": (128, 128, 128),
}


def _make_config() -> Config:
    colors = {
        name: ColorSpec(rgb=rgb, tolerance=5)
        for name, rgb in _SAMPLED_RGB.items()
    }
    colors["background"] = ColorSpec(rgb=(0, 0, 0), tolerance=5)
    return Config(
        window_title="test",
        dot_roi=ROI(x=0, y=0, w=300, h=100),
        chart_roi=ROI(x=0, y=0, w=300, h=100),
        colors=colors,
        y_axis=YAxisCalib(p1_y=0, p1_price=100.0, p2_y=100, p2_price=0.0),
        canary=CanaryRegion(
            roi=ROI(x=0, y=0, w=10, h=10),
            baseline_phash="0" * 64,
        ),
        discord=DiscordCfg(webhook_url="http://localhost/fake"),
        telegram=TelegramCfg(bot_token="fake_token", chat_id="123"),
        debounce_depth=1,
    )


def _make_dot_frame(rgb: tuple[int, int, int]) -> np.ndarray:
    """100x300 BGR frame with a 6x6 dot at x=250,y=50."""
    frame = np.zeros((100, 300, 3), dtype=np.uint8)
    frame[50:56, 250:256] = (rgb[2], rgb[1], rgb[0])  # BGR
    return frame


# ---------------------------------------------------------------------------
# 1. Confusion matrix unit test — pure math, no cv2/detector
# ---------------------------------------------------------------------------

def test_confusion_matrix_math() -> None:
    cm = ConfusionMatrix()
    cm.add("A", "A")
    cm.add("A", "A")
    cm.add("A", "B")  # FN for A, FP for B
    cm.add("B", "B")

    per = cm.per_label()

    # A: support=3, TP=2, FP=0 (B never predicted as A), FN=1
    assert per["A"]["support"] == 3.0
    assert per["A"]["precision"] == pytest.approx(1.0)  # TP/(TP+FP) = 2/2
    assert per["A"]["recall"] == pytest.approx(2 / 3)
    assert per["A"]["f1"] == pytest.approx(2 * 1.0 * (2 / 3) / (1.0 + 2 / 3))

    # B: support=1, TP=1, FP=1 (one A was predicted as B), FN=0
    assert per["B"]["support"] == 1.0
    assert per["B"]["precision"] == pytest.approx(0.5)  # TP/(TP+FP) = 1/2
    assert per["B"]["recall"] == pytest.approx(1.0)

    # Overall accuracy: 3 correct out of 4
    assert cm.overall_accuracy() == pytest.approx(3 / 4)


# ---------------------------------------------------------------------------
# 2-5: integration tests that require atm.detector
# ---------------------------------------------------------------------------

def test_dryrun_perfect_match(tmp_path: Path) -> None:
    pytest.importorskip("atm.detector")
    cfg = _make_config()
    colors_6 = ["turquoise", "yellow", "dark_green", "dark_red", "light_green", "light_red"]

    import cv2

    labels: dict[str, str] = {}
    for idx, name in enumerate(colors_6):
        frame = _make_dot_frame(_TRUE_RGB[name])
        cv2.imwrite(str(tmp_path / f"{idx}.png"), frame)
        labels[str(idx)] = name

    labels_path = tmp_path / "labels.json"
    labels_path.write_text(json.dumps(labels))

    result = dryrun(tmp_path, labels_path, cfg)

    assert result.n_samples == 6
    assert result.n_labeled == 6
    assert result.precision_overall == pytest.approx(1.0)
    assert result.recall_overall == pytest.approx(1.0)
    assert result.acceptance_pass is True

    # Diagonal-only: each label predicts only itself
    per = result.confusion.per_label()
    for name in colors_6:
        assert result.confusion.counts[name] == {name: 1}, (
            f"Expected diagonal for {name}, got {result.confusion.counts[name]}"
        )
    assert all(m["precision"] == pytest.approx(1.0) for m in per.values())
    assert all(m["recall"] == pytest.approx(1.0) for m in per.values())


def test_dryrun_with_unlabeled_sample(tmp_path: Path) -> None:
    pytest.importorskip("atm.detector")
    cfg = _make_config()

    import cv2

    # Write 3 labeled frames + 1 unlabeled
    labels: dict[str, str] = {}
    for idx, name in enumerate(["turquoise", "yellow", "dark_green"]):
        frame = _make_dot_frame(_TRUE_RGB[name])
        cv2.imwrite(str(tmp_path / f"{idx}.png"), frame)
        labels[str(idx)] = name

    # Frame "3" exists on disk but has NO label entry
    unlabeled_frame = _make_dot_frame(_TRUE_RGB["dark_red"])
    cv2.imwrite(str(tmp_path / "3.png"), unlabeled_frame)

    labels_path = tmp_path / "labels.json"
    labels_path.write_text(json.dumps(labels))

    result = dryrun(tmp_path, labels_path, cfg)

    assert result.n_samples == 4  # 4 PNGs on disk
    assert result.n_labeled == 3  # only 3 labeled
    # "3" not in confusion
    assert "3" not in result.confusion.counts
    # Only the 3 labeled colors appear
    assert set(result.confusion.counts.keys()) == {"turquoise", "yellow", "dark_green"}


def test_dryrun_misclassification_fails_gate(tmp_path: Path) -> None:
    pytest.importorskip("atm.detector")
    cfg = _make_config()

    import cv2

    colors_6 = ["turquoise", "yellow", "dark_green", "dark_red", "light_green", "light_red"]
    labels: dict[str, str] = {}
    for idx, name in enumerate(colors_6):
        frame = _make_dot_frame(_TRUE_RGB[name])
        cv2.imwrite(str(tmp_path / f"{idx}.png"), frame)
        labels[str(idx)] = name

    # Swap label of frame 0 (turquoise dot → labeled as "yellow")
    labels["0"] = "yellow"

    labels_path = tmp_path / "labels.json"
    labels_path.write_text(json.dumps(labels))

    result = dryrun(tmp_path, labels_path, cfg)

    assert result.acceptance_pass is False
    # recall for "yellow" drops: one yellow-labeled frame predicted as turquoise
    per = result.confusion.per_label()
    assert per["yellow"]["recall"] < 1.0


def test_fire_event_captured(tmp_path: Path) -> None:
    pytest.importorskip("atm.detector")
    cfg = _make_config()

    import cv2

    # Sequence that triggers a BUY fire: turquoise → gray → dark_green → light_green
    sequence = ["turquoise", "gray", "dark_green", "light_green"]
    labels: dict[str, str] = {}
    for idx, name in enumerate(sequence):
        frame = _make_dot_frame(_TRUE_RGB[name])
        cv2.imwrite(str(tmp_path / f"{idx}.png"), frame)
        labels[str(idx)] = name

    labels_path = tmp_path / "labels.json"
    labels_path.write_text(json.dumps(labels))

    result = dryrun(tmp_path, labels_path, cfg)

    assert len(result.fire_events) == 1
    ev = result.fire_events[0]
    assert ev["direction"] == "BUY"
    assert ev["ts"] == pytest.approx(15.0)  # i=3 → ts=3*5.0
    assert ev["sample"] == "3"
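The 24/49 dilution arithmetic behind `_SAMPLED_RGB` is reproducible directly from the comment at the top of the file — the 7x7 sample box covers 24 dot pixels and 25 black background pixels, so each channel scales by 24/49 with truncation:

```python
SCALE = 24 / 49  # dot pixels / total pixels in the 7x7 sample box


def sampled(true_rgb: tuple[int, int, int]) -> tuple[int, int, int]:
    """Per-channel mean over the sample box, truncated like int()."""
    return tuple(int(c * SCALE) for c in true_rgb)


assert sampled((0, 200, 200)) == (0, 97, 97)     # turquoise
assert sampled((255, 255, 0)) == (124, 124, 0)   # yellow
assert sampled((139, 0, 0)) == (68, 0, 0)        # dark_red
assert sampled((128, 128, 128)) == (62, 62, 62)  # gray
```

This is why the fixture config's palette is keyed on diluted values rather than the paint colours: the classifier sees the averaged patch, not the raw dot.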
95
tests/test_journal.py
Normal file
@@ -0,0 +1,95 @@
"""Tests for atm.journal."""
from __future__ import annotations

import json
from pathlib import Path

from atm.journal import Journal, TradeEntry, prompt_entry


def _sample() -> TradeEntry:
    return TradeEntry(
        ts="2026-04-14T10:00:00",
        direction="BUY",
        symbol="US30",
        entry=40000.0,
        sl=39950.0,
        tp1=40100.0,
        tp2=None,
        exit=None,
        outcome="open",
        detected_ts=None,
        notes="",
    )


def test_add_and_read_roundtrip(tmp_path: Path) -> None:
    journal = Journal(tmp_path / "trades.jsonl")
    e1 = TradeEntry(
        ts="2026-04-14T10:00:00", direction="BUY", symbol="US30",
        entry=40000.0, sl=39950.0, tp1=40100.0, tp2=None,
        exit=40100.0, outcome="tp1", detected_ts=None, notes="",
    )
    e2 = TradeEntry(
        ts="2026-04-14T11:00:00", direction="SELL", symbol="NQ",
        entry=20000.0, sl=20050.0, tp1=None, tp2=None,
        exit=20050.0, outcome="sl", detected_ts=None, notes="stop hit",
    )

    journal.add(e1)
    journal.add(e2)

    all_entries = journal.all()
    assert len(all_entries) == 2
    assert all_entries[0] == e1
    assert all_entries[1] == e2


def test_prompt_entry_with_defaults() -> None:
    inputs = iter([
        "2026-04-15T10:30:00",  # ts
        "BUY",  # direction
        "US30",  # symbol
        "40000",  # entry
        "39950",  # sl
        "40100",  # tp1
        "",  # tp2 (blank → None)
        "",  # exit (blank → None)
        "open",  # outcome
        "",  # notes
    ])

    detected = {
        "direction": "BUY",
        "symbol": "US30",
        "detected_ts": "2026-04-15T10:29:45",
    }

    entry = prompt_entry(input_fn=lambda _: next(inputs), detected=detected)

    assert entry.direction == "BUY"
    assert entry.symbol == "US30"
    assert entry.entry == 40000.0
    assert entry.sl == 39950.0
    assert entry.tp1 == 40100.0
    assert entry.tp2 is None
    assert entry.exit is None
    assert entry.outcome == "open"
    assert entry.detected_ts == "2026-04-15T10:29:45"
    assert entry.notes == ""


def test_file_line_buffered(tmp_path: Path) -> None:
    """Each add() writes exactly one JSONL line, immediately readable."""
    path = tmp_path / "trades.jsonl"
    journal = Journal(path)

    journal.add(_sample())
    lines = path.read_text(encoding="utf-8").splitlines()
    assert len(lines) == 1
    json.loads(lines[0])  # must be valid JSON

    journal.add(_sample())
    lines = path.read_text(encoding="utf-8").splitlines()
    assert len(lines) == 2
    json.loads(lines[1])
54
tests/test_labeler.py
Normal file
@@ -0,0 +1,54 @@
"""Tests for atm.labeler."""
from __future__ import annotations

from pathlib import Path

import pytest

from atm.labeler import LabelStore, accuracy


def test_label_store_roundtrip(tmp_path: Path) -> None:
    path = tmp_path / "labels.json"
    store = LabelStore(path)
    store["img1.png"] = "turquoise"
    store["img2.png"] = "yellow"
    store.save()

    store2 = LabelStore(path)
    assert store2["img1.png"] == "turquoise"
    assert store2["img2.png"] == "yellow"
    assert store2.as_dict() == {"img1.png": "turquoise", "img2.png": "yellow"}


def test_accuracy_perfect() -> None:
    labels = {"a.png": "turquoise", "b.png": "yellow", "c.png": "gray"}
    predicted = {"a.png": "turquoise", "b.png": "yellow", "c.png": "gray"}
    result = accuracy(labels, predicted)

    assert result["accuracy"] == pytest.approx(1.0)
    assert result["turquoise_precision"] == pytest.approx(1.0)
    assert result["turquoise_recall"] == pytest.approx(1.0)
    assert result["yellow_f1"] == pytest.approx(1.0)
    assert result["gray_f1"] == pytest.approx(1.0)


def test_accuracy_partial() -> None:
    # a=turquoise correct, b=turquoise predicted yellow (FN turquoise / FP yellow), c=yellow correct
    labels = {"a.png": "turquoise", "b.png": "turquoise", "c.png": "yellow"}
    predicted = {"a.png": "turquoise", "b.png": "yellow", "c.png": "yellow"}

    result = accuracy(labels, predicted)

    # 2 out of 3 correct
    assert result["accuracy"] == pytest.approx(2 / 3)

    # turquoise: tp=1, fp=0, fn=1 → precision=1.0, recall=0.5, f1=2/3
    assert result["turquoise_precision"] == pytest.approx(1.0)
    assert result["turquoise_recall"] == pytest.approx(0.5)
    assert result["turquoise_f1"] == pytest.approx(2 / 3)

    # yellow: tp=1, fp=1, fn=0 → precision=0.5, recall=1.0, f1=2/3
    assert result["yellow_precision"] == pytest.approx(0.5)
    assert result["yellow_recall"] == pytest.approx(1.0)
    assert result["yellow_f1"] == pytest.approx(2 / 3)
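The per-label numbers asserted in `test_accuracy_partial` follow the standard precision/recall/F1 definitions. A standalone `prf` helper (hypothetical — not the real `accuracy` function) reproduces them from the raw counts:

```python
def prf(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
    """Precision, recall and F1 from true/false positive and false negative counts."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# turquoise in test_accuracy_partial: tp=1, fp=0, fn=1 → (1.0, 0.5, 2/3)
assert prf(1, 0, 1) == (1.0, 0.5, 2 / 3)
# yellow: tp=1, fp=1, fn=0 → (0.5, 1.0, 2/3)
assert prf(1, 1, 0) == (0.5, 1.0, 2 / 3)
```

One misclassification always shows up twice: as a false negative for the true label and a false positive for the predicted one, which is why turquoise loses recall while yellow loses precision.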
172
tests/test_levels.py
Normal file
@@ -0,0 +1,172 @@
"""Tests for src/atm/levels.py."""
from __future__ import annotations

import numpy as np
import pytest

from atm.config import (
    CanaryRegion,
    ColorSpec,
    Config,
    DiscordCfg,
    ROI,
    TelegramCfg,
    YAxisCalib,
)
from atm.levels import Levels, LevelsExtractor, LevelsResult
from atm.vision import pixel_y_to_price

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# chart_roi starts at (0,0) so test frames can be exactly (H, W)
CHART_ROI = ROI(x=0, y=0, w=600, h=400)
CALIB = YAxisCalib(p1_y=0, p1_price=100.0, p2_y=400, p2_price=80.0)

# light_red   RGB=(255,0,0) → BGR=(0,0,255)
# light_green RGB=(0,255,0) → BGR=(0,255,0)
RED_BGR: tuple[int, int, int] = (0, 0, 255)
GREEN_BGR: tuple[int, int, int] = (0, 255, 0)

TOLERANCE = 30.0


def _make_cfg(phaseb_timeout_s: int = 600) -> Config:
    colors = {
        "turquoise": ColorSpec(rgb=(0, 255, 255), tolerance=TOLERANCE),
        "yellow": ColorSpec(rgb=(255, 255, 0), tolerance=TOLERANCE),
        "dark_green": ColorSpec(rgb=(0, 100, 0), tolerance=TOLERANCE),
        "dark_red": ColorSpec(rgb=(100, 0, 0), tolerance=TOLERANCE),
        "light_green": ColorSpec(rgb=(0, 255, 0), tolerance=TOLERANCE),
        "light_red": ColorSpec(rgb=(255, 0, 0), tolerance=TOLERANCE),
        "gray": ColorSpec(rgb=(128, 128, 128), tolerance=TOLERANCE),
    }
    return Config(
        window_title="test",
        dot_roi=ROI(x=0, y=0, w=100, h=50),
        chart_roi=CHART_ROI,
        colors=colors,
        y_axis=CALIB,
        canary=CanaryRegion(
            roi=ROI(x=0, y=0, w=50, h=50),
            baseline_phash="0" * 64,
            drift_threshold=8,
        ),
        discord=DiscordCfg(webhook_url="http://example.com/hook"),
        telegram=TelegramCfg(bot_token="tok", chat_id="123"),
        phaseb_timeout_s=phaseb_timeout_s,
    )


# ---------------------------------------------------------------------------
# Frame helpers
# ---------------------------------------------------------------------------

def _make_chart(*line_specs: tuple[tuple[int, int, int], int]) -> np.ndarray:
    """Return a 400×600 black BGR frame with horizontal lines.

    Each spec is (bgr_color, y_position). Lines are painted 3px thick
    and span the full width so Hough detection is reliable.
    """
    frame = np.zeros((400, 600, 3), dtype=np.uint8)
    for bgr, y in line_specs:
        y0 = max(0, y - 1)
        y1 = min(400, y + 2)
        frame[y0:y1, :] = bgr
    return frame


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------

def test_three_lines_buy_complete() -> None:
    """BUY with 3 stable lines → complete after 2 calls, prices correct."""
    # red = SL at y=300 (bottom / lowest price for BUY)
    # green = TP1 at y=200, TP2 at y=100
    frame = _make_chart(
        (RED_BGR, 300),
        (GREEN_BGR, 200),
        (GREEN_BGR, 100),
    )
    cfg = _make_cfg()
    ext = LevelsExtractor(cfg, direction="BUY", start_ts=0.0)

    r1 = ext.step(frame, ts=1.0)
    assert r1.status == "partial"  # first call: not yet stable

    r2 = ext.step(frame, ts=2.0)
    assert r2.status == "complete"
    assert r2.levels is not None
    assert r2.levels.partial is False

    expected_tp2 = pixel_y_to_price(100, CALIB)  # 95.0
    expected_tp1 = pixel_y_to_price(200, CALIB)  # 90.0
    expected_sl = pixel_y_to_price(300, CALIB)  # 85.0

    assert r2.levels.tp2 == pytest.approx(expected_tp2, abs=1.0)
    assert r2.levels.tp1 == pytest.approx(expected_tp1, abs=1.0)
    assert r2.levels.sl == pytest.approx(expected_sl, abs=1.0)


def test_two_lines_partial() -> None:
    """2 lines → always partial."""
    frame = _make_chart((RED_BGR, 300), (GREEN_BGR, 100))
    cfg = _make_cfg()
    ext = LevelsExtractor(cfg, direction="BUY", start_ts=0.0)

    result = ext.step(frame, ts=1.0)

    assert result.status == "partial"
    assert result.levels is not None
    assert result.levels.partial is True


def test_zero_lines_waiting() -> None:
    """No lines in chart → waiting."""
    frame = np.zeros((400, 600, 3), dtype=np.uint8)
    cfg = _make_cfg()
    ext = LevelsExtractor(cfg, direction="BUY", start_ts=0.0)

    result = ext.step(frame, ts=1.0)

    assert result.status == "waiting"
    assert result.levels is None


def test_timeout() -> None:
    """Elapsed > phaseb_timeout_s → timeout regardless of lines."""
    frame = np.zeros((400, 600, 3), dtype=np.uint8)
    cfg = _make_cfg(phaseb_timeout_s=600)
    ext = LevelsExtractor(cfg, direction="BUY", start_ts=0.0)

    result = ext.step(frame, ts=700.0)

    assert result.status == "timeout"
    assert result.levels is None
    assert result.elapsed_s == pytest.approx(700.0)


def test_sell_direction_assignment() -> None:
    """SELL: topmost y → SL (highest price), bottom → TP2 (lowest price)."""
    frame = _make_chart(
        (RED_BGR, 300),
        (GREEN_BGR, 200),
        (GREEN_BGR, 100),
    )
    cfg = _make_cfg()
    ext = LevelsExtractor(cfg, direction="SELL", start_ts=0.0)

    ext.step(frame, ts=1.0)  # first call, not yet stable
    r = ext.step(frame, ts=2.0)  # second call, stable → complete

    assert r.status == "complete"
    assert r.levels is not None

    expected_sl = pixel_y_to_price(100, CALIB)  # 95.0 (topmost = highest price)
    expected_tp1 = pixel_y_to_price(200, CALIB)  # 90.0
    expected_tp2 = pixel_y_to_price(300, CALIB)  # 85.0

    assert r.levels.sl == pytest.approx(expected_sl, abs=1.0)
    assert r.levels.tp1 == pytest.approx(expected_tp1, abs=1.0)
    assert r.levels.tp2 == pytest.approx(expected_tp2, abs=1.0)
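The expected prices in these tests come from the two-point y-axis calibration in `CALIB`. Assuming `pixel_y_to_price` is plain linear interpolation between (p1_y, p1_price) and (p2_y, p2_price) — a sketch under that assumption, not the real implementation — the 95.0/90.0/85.0 comments fall out directly:

```python
def y_to_price(y: float, p1_y: float, p1_price: float, p2_y: float, p2_price: float) -> float:
    """Linear two-point calibration: map a pixel row to a chart price."""
    slope = (p2_price - p1_price) / (p2_y - p1_y)  # price change per pixel
    return p1_price + (y - p1_y) * slope


# CALIB: p1=(y=0, 100.0), p2=(y=400, 80.0) → 20 price units over 400 px
assert abs(y_to_price(100, 0, 100.0, 400, 80.0) - 95.0) < 1e-9
assert abs(y_to_price(200, 0, 100.0, 400, 80.0) - 90.0) < 1e-9
assert abs(y_to_price(300, 0, 100.0, 400, 80.0) - 85.0) < 1e-9
```

Note the sign convention: pixel y grows downward while price grows upward, so the slope is negative and lower rows map to lower prices, which is exactly the ordering the BUY/SELL assignment tests rely on.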
137
tests/test_main.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""Tests for atm.main unified CLI."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import types
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
SUBCOMMANDS = ["calibrate", "label", "dryrun", "run", "journal", "report"]
|
||||
|
||||
# Ensure subprocess invocations find the atm package even without pip install
|
||||
_SRC = str(Path(__file__).resolve().parent.parent / "src")
|
||||
_SUBPROCESS_ENV = {**os.environ, "PYTHONPATH": _SRC}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _mock_config_class(cfg=None):
|
||||
"""Return a Config-like class whose load_current() returns *cfg*."""
|
||||
if cfg is None:
|
||||
cfg = MagicMock()
|
||||
mock_cls = MagicMock()
|
||||
mock_cls.load_current.return_value = cfg
|
||||
return mock_cls
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# test_help_works
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_help_works():
|
||||
result = subprocess.run(
|
||||
[sys.executable, "-m", "atm", "--help"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
env=_SUBPROCESS_ENV,
|
||||
)
|
||||
assert result.returncode == 0, result.stderr
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# test_subcommands_listed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_subcommands_listed():
|
||||
result = subprocess.run(
|
||||
[sys.executable, "-m", "atm", "--help"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
env=_SUBPROCESS_ENV,
|
||||
)
|
||||
output = result.stdout
|
||||
for cmd in SUBCOMMANDS:
|
||||
assert cmd in output, f"Expected subcommand '{cmd}' in --help output"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# test_dryrun_wiring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class _DryrunResult:
|
||||
acceptance_pass: bool
|
||||
|
||||
|
||||
def _make_dryrun_module(acceptance_pass: bool):
|
||||
mod = types.ModuleType("atm.dryrun")
|
||||
mod.dryrun = lambda *a, **kw: _DryrunResult(acceptance_pass=acceptance_pass)
|
||||
mod.print_report = lambda r: None
|
||||
return mod
|
||||
|
||||
|
||||
def test_dryrun_wiring_pass(monkeypatch, tmp_path):
|
||||
import atm.main as _main
|
||||
|
||||
monkeypatch.setattr("atm.main.dryrun", _make_dryrun_module(acceptance_pass=True))
|
||||
monkeypatch.setattr("atm.main.Config", _mock_config_class())
|
||||
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
_main.main(["dryrun", str(tmp_path)])
|
||||
|
||||
assert exc_info.value.code == 0
|
||||
|
||||
|
||||
def test_dryrun_wiring_fail(monkeypatch, tmp_path):
|
||||
import atm.main as _main
|
||||
|
||||
monkeypatch.setattr("atm.main.dryrun", _make_dryrun_module(acceptance_pass=False))
|
||||
monkeypatch.setattr("atm.main.Config", _mock_config_class())
|
||||
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
_main.main(["dryrun", str(tmp_path)])
|
||||
|
||||
assert exc_info.value.code == 1
|
||||
|
||||
|
# ---------------------------------------------------------------------------
# test_report_current_week_default
# ---------------------------------------------------------------------------


def test_report_current_week_default(monkeypatch, tmp_path):
    import atm.main as _main

    # Journal.all returns no entries — report should print a zero-trade week
    monkeypatch.setattr("atm.journal.Journal.all", lambda self: [])

    # Should not raise; no sys.exit expected
    _main.main(["report", "--file", str(tmp_path / "trades.jsonl")])

# ---------------------------------------------------------------------------
# test_run_live_dry
# ---------------------------------------------------------------------------


def test_run_live_dry(monkeypatch):
    import atm.main as _main

    calls: list[dict] = []

    def _mock_run_live(cfg, duration_s=None, capture_stub=False):
        calls.append({"cfg": cfg, "duration_s": duration_s, "capture_stub": capture_stub})

    monkeypatch.setattr("atm.main.run_live", _mock_run_live)
    monkeypatch.setattr("atm.main.Config", _mock_config_class())

    _main.main(["run", "--duration", "0"])

    assert len(calls) == 1
    assert calls[0]["duration_s"] == pytest.approx(0.0)
76	tests/test_report.py	Normal file
@@ -0,0 +1,76 @@
"""Tests for atm.report."""
from __future__ import annotations

import pytest

from atm.journal import TradeEntry
from atm.report import iso_week, weekly_report

WEEK = "2026-16"
BASE_TS = "2026-04-14T10:00:00"


def _trade(
    outcome: str,
    direction: str = "BUY",
    entry: float = 100.0,
    sl: float = 90.0,
    exit_: float | None = None,
    detected_ts: str | None = None,
    ts: str = BASE_TS,
) -> TradeEntry:
    return TradeEntry(
        ts=ts,
        direction=direction,
        symbol="US30",
        entry=entry,
        sl=sl,
        tp1=None,
        tp2=None,
        exit=exit_,
        outcome=outcome,
        detected_ts=detected_ts,
        notes="",
    )

def test_win_rate_and_pnl() -> None:
    """5 synthetic trades: tp1 +2R, tp2 +3R, sl -1.5R, manual +2R, open excluded."""
    trades = [
        # tp1: BUY, entry=100, sl=90, exit=120 → R = (120-100)/(100-90) = +2.0
        _trade("tp1", exit_=120.0, detected_ts="2026-04-14T09:59:55"),
        # tp2: BUY, entry=100, sl=90, exit=130 → R = (130-100)/(100-90) = +3.0
        _trade("tp2", exit_=130.0),
        # sl: BUY, entry=100, sl=90, exit=85 → R = (85-100)/(100-90) = -1.5
        _trade("sl", exit_=85.0, detected_ts="2026-04-14T09:59:50"),
        # manual: SELL, entry=100, sl=110, exit=80 → R = (80-100)/(100-110) = +2.0
        _trade("manual", direction="SELL", sl=110.0, exit_=80.0),
        # open: excluded from counts
        _trade("open"),
    ]

    report = weekly_report(trades, WEEK)

    assert report.week == WEEK
    assert report.n_trades == 4
    assert report.n_wins == 2
    assert report.n_losses == 2
    assert report.win_rate == pytest.approx(0.5)
    assert report.pnl_r == pytest.approx(5.5)
    # slippage: trade[0]=5s, trade[2]=10s → avg=7.5s
    assert report.avg_slippage == pytest.approx(7.5)

def test_iso_week() -> None:
    assert iso_week("2026-04-14T10:00:00") == "2026-16"
    assert iso_week("2026-01-01T00:00:00") == "2026-01"

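`test_iso_week` implies the `YYYY-WW` label is the ISO calendar year plus the zero-padded ISO week number (note 2026-01-01 falls in week 01, not week 00 or 53). A sketch of an equivalent implementation, assuming the real `atm.report.iso_week` behaves the same way:

```python
from datetime import datetime


def iso_week_sketch(ts: str) -> str:
    # ISO year + zero-padded ISO week number for an ISO-8601 timestamp,
    # e.g. "2026-04-14T10:00:00" -> "2026-16".
    year, week, _ = datetime.fromisoformat(ts).isocalendar()
    return f"{year}-{week:02d}"
```

Using `isocalendar()` rather than `strftime("%Y-%W")` matters at year boundaries, where the ISO year can differ from the calendar year.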
def test_empty_week() -> None:
    report = weekly_report([], WEEK)
    assert report.n_trades == 0
    assert report.n_wins == 0
    assert report.n_losses == 0
    assert report.win_rate == 0.0
    assert report.pnl_r == 0.0
    assert report.avg_slippage is None