- Add ssh-tunnel.ps1: Windows SSH tunnel manager (equivalent to ssh-tunnel.sh)
  - Supports password auth via plink.exe (PuTTY)
  - Supports ssh_hostkey for non-interactive batch mode
  - Commands: start, stop, restart, status
- Add start-backend-service.ps1: NSSM service wrapper
  - Starts SSH tunnels before uvicorn
  - Waits for tunnel ports to be accessible (30s timeout)
  - Configured by Install-ROA2WEB.ps1
- Add start.ps1: Windows equivalent of start.sh
  - Orchestrates SSH tunnel + backend + frontend startup
- Add backend/shared/ssh_tunnel_manager.py: Python monitoring
  - Background asyncio task monitors tunnel health every 30s
  - Auto-restarts tunnels after 2 consecutive failures
  - Exposes status to /health endpoint
- Update ROA2WEB-Console.ps1:
  - Add Deploy-Scripts function
  - Update Update-ServiceToUseVenv to use wrapper script
  - Fix PowerShell reserved variable ($PID -> $tunnelPid)
  - Fix script path detection (scripts/ vs deployment/windows/scripts/)
- Update README.md with ssh_hostkey documentation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
351 lines
13 KiB
Python
351 lines
13 KiB
Python
"""
SSH Tunnel Manager - Cross-Platform Monitoring and Auto-Reconnect

This module provides MONITORING and AUTO-RECONNECT for SSH tunnels.
It does NOT start tunnels - that's the responsibility of:
  - Linux:           start.sh → ssh-tunnel.sh
  - Windows:         start.ps1 → ssh-tunnel.ps1
  - Windows Service: start-backend-service.ps1 → ssh-tunnel.ps1

Responsibilities:
  ✅ Monitor tunnel health via port checks (background asyncio task)
  ✅ Auto-restart tunnels if they go down (calls platform-specific scripts)
  ✅ Expose status for /health endpoint

NOT responsible for:
  ❌ Initial tunnel startup (done by wrapper scripts before backend starts)

Usage in main.py:

    from backend.shared.ssh_tunnel_manager import ssh_tunnel_manager

    @app.on_event("startup")
    async def startup():
        await ssh_tunnel_manager.start_monitoring()

    @app.on_event("shutdown")
    async def shutdown():
        await ssh_tunnel_manager.stop_monitoring()

    @app.get("/health")
    async def health():
        return {
            "ssh_tunnels": ssh_tunnel_manager.get_status()
        }
"""
|
|
import asyncio
import json
import logging
import platform
import subprocess
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

# Module-level logger used by SSHTunnelManager; its messages are tagged "[SSH-MONITOR]".
logger = logging.getLogger(__name__)


class SSHTunnelManager:
    """
    Cross-platform SSH tunnel MONITOR (not starter).

    Tunnels are started by wrapper scripts before the backend boots. This
    class only probes their local ports and, after repeated consecutive
    failures, asks the platform's tunnel script to restart them.

    Timeline:
        T=0      start.sh / wrapper starts
        T=1s     ssh-tunnel.sh / ssh-tunnel.ps1 START
        T=3s     Tunnels active ✅
        T=5s     uvicorn backend starts
        T=7s     Backend startup_event()
        T=8s     ssh_tunnel_manager.start_monitoring()
                 └─ Detects tunnels already active (just monitors, doesn't start)
        T=38s    Monitor check #1 - OK ✅
        ...
        T=XXs    [Tunnel drops]
        T=XX+30  Monitor detects FAIL (1/2)
        T=XX+60  Monitor detects FAIL (2/2) → RESTART via script
    """

    def __init__(self) -> None:
        # Configuration
        self.check_interval: int = 30  # seconds between health checks
        self.max_failures_before_restart: int = 2  # restart after N consecutive failures
        self.restart_cooldown: int = 60  # minimum seconds between restarts

        # State
        self.tunnel_configs: List[Dict] = []
        self.tunnel_status: Dict[str, bool] = {}  # tunnel id -> result of last port check
        self.consecutive_failures: Dict[str, int] = {}
        # Wall-clock time (time.time()) of the last restart attempt; 0 = never.
        # Kept for diagnostics/backward compatibility only - the cooldown uses
        # the monotonic timestamp below.
        self.last_restart_time: float = 0
        # time.monotonic() of the last restart attempt, None if never attempted.
        # Monotonic so system clock corrections cannot stretch or skip the cooldown.
        self._last_restart_monotonic: Optional[float] = None
        self.monitor_task: Optional[asyncio.Task] = None
        self._is_monitoring: bool = False

        # Paths (detected at runtime)
        self._project_root: Optional[Path] = None
        self._config_file: Optional[Path] = None

    def _detect_paths(self) -> bool:
        """Resolve the project root and config file from this file's location.

        Returns:
            True if ``<project root>/backend/ssh-tunnels.json`` exists.
        """
        # This file is at <root>/backend/shared/ssh_tunnel_manager.py, so the
        # project root is three `.parent` steps up from the file itself.
        self._project_root = Path(__file__).parent.parent.parent
        self._config_file = self._project_root / "backend" / "ssh-tunnels.json"
        return self._config_file.exists()

    def _load_config(self) -> List[Dict]:
        """Load tunnel definitions from ssh-tunnels.json.

        Returns:
            The entries (dicts) that define an ``ssh_host``; entries without one
            are direct connections and need no tunnel. Returns [] when the file
            is missing, unreadable, invalid JSON, or not a JSON list.
        """
        if not self._config_file or not self._config_file.exists():
            return []

        try:
            # "utf-8-sig" reads plain UTF-8 and also strips the BOM that Windows
            # PowerShell 5.x writes with `-Encoding UTF8`. Without an explicit
            # encoding, open() would use the locale code page on Windows.
            with open(self._config_file, "r", encoding="utf-8-sig") as f:
                tunnels = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSON and Unicode decode errors
            logger.error(f"[SSH-MONITOR] Failed to load config: {e}")
            return []

        if not isinstance(tunnels, list):
            logger.error(
                "[SSH-MONITOR] Failed to load config: expected a JSON list, "
                f"got {type(tunnels).__name__}"
            )
            return []

        # Keep only tunnel entries; skip direct connections and malformed items
        # instead of discarding the whole configuration.
        return [t for t in tunnels if isinstance(t, dict) and t.get("ssh_host")]

    async def start_monitoring(self) -> bool:
        """
        Start monitoring EXISTING tunnels.

        Does NOT start tunnels - assumes they're already running
        (started by start.sh / start.ps1 / start-backend-service.ps1).

        Returns:
            Always True: a missing or empty configuration is not an error
            (tunnels may simply not be needed).
        """
        if self._is_monitoring:
            logger.warning("[SSH-MONITOR] Already monitoring")
            return True

        # Detect paths and load config
        if not self._detect_paths():
            logger.info("[SSH-MONITOR] No ssh-tunnels.json found, skipping")
            return True

        self.tunnel_configs = self._load_config()
        if not self.tunnel_configs:
            logger.info("[SSH-MONITOR] No SSH tunnels configured (or all are direct connections)")
            return True

        # Rebuild state from scratch so tunnels removed from the config since a
        # previous start/stop cycle do not linger in get_status().
        self.tunnel_status = {}
        self.consecutive_failures = {}

        # Check initial status (tunnels should already be running)
        logger.info(f"[SSH-MONITOR] Checking {len(self.tunnel_configs)} tunnel(s)...")
        for config in self.tunnel_configs:
            tunnel_id = config.get("id", "default")
            port = config.get("local_port", 1521)
            name = config.get("name", tunnel_id)

            is_active = await self._check_port("127.0.0.1", port)
            self.tunnel_status[tunnel_id] = is_active
            self.consecutive_failures[tunnel_id] = 0

            status = "✅ active" if is_active else "❌ NOT active"
            logger.info(f"[SSH-MONITOR] [{tunnel_id}] {name} - localhost:{port} - {status}")

        # Start background monitor loop
        self._is_monitoring = True
        self.monitor_task = asyncio.create_task(self._monitor_loop())
        logger.info(f"[SSH-MONITOR] ✅ Monitoring started (check every {self.check_interval}s)")
        return True

    async def stop_monitoring(self) -> None:
        """Stop the monitoring background task (no-op if not monitoring)."""
        if not self._is_monitoring:
            return

        self._is_monitoring = False

        if self.monitor_task and not self.monitor_task.done():
            self.monitor_task.cancel()
            try:
                await self.monitor_task
            except asyncio.CancelledError:
                pass

        logger.info("[SSH-MONITOR] ✅ Monitoring stopped")

    async def _monitor_loop(self) -> None:
        """Background loop: check tunnel health every N seconds, restart if needed."""
        while self._is_monitoring:
            try:
                await asyncio.sleep(self.check_interval)
                if not self._is_monitoring:
                    break

                # Restart all tunnels if any failed enough times in a row
                if await self._check_all_tunnels():
                    await self._restart_tunnels()

            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the monitor alive, but record the traceback for diagnosis.
                logger.exception(f"[SSH-MONITOR] Monitor loop error: {e}")
                await asyncio.sleep(5)  # Brief pause before retrying

    async def _check_all_tunnels(self) -> bool:
        """Probe every configured tunnel once, updating status and failure counters.

        Returns:
            True if at least one tunnel has now failed
            ``max_failures_before_restart`` consecutive checks.
        """
        needs_restart = False
        for config in self.tunnel_configs:
            tunnel_id = config.get("id", "default")
            port = config.get("local_port", 1521)

            is_healthy = await self._check_port("127.0.0.1", port)
            self.tunnel_status[tunnel_id] = is_healthy

            if is_healthy:
                # Reset failure count on success
                if self.consecutive_failures.get(tunnel_id, 0) > 0:
                    logger.info(f"[SSH-MONITOR] [{tunnel_id}] Recovered ✅")
                self.consecutive_failures[tunnel_id] = 0
                continue

            failures = self.consecutive_failures.get(tunnel_id, 0) + 1
            self.consecutive_failures[tunnel_id] = failures
            logger.warning(
                f"[SSH-MONITOR] [{tunnel_id}] FAIL "
                f"({failures}/{self.max_failures_before_restart})"
            )
            if failures >= self.max_failures_before_restart:
                needs_restart = True
        return needs_restart

    async def _check_port(self, host: str, port: int, timeout: float = 3.0) -> bool:
        """Return True if a TCP connection to ``host:port`` succeeds within ``timeout`` seconds.

        A successful connect is what proves the tunnel's local listener is up,
        so errors raised while closing the probe connection are not treated as
        a failure.
        """
        try:
            # Use asyncio.open_connection for a non-blocking port check
            _reader, writer = await asyncio.wait_for(
                asyncio.open_connection(host, port),
                timeout=timeout,
            )
        except (asyncio.TimeoutError, ConnectionRefusedError, OSError):
            return False
        except Exception as e:
            logger.debug(f"[SSH-MONITOR] Port check error {host}:{port}: {e}")
            return False

        writer.close()
        try:
            await writer.wait_closed()
        except Exception as e:
            # The connection was already established - the port is reachable.
            logger.debug(f"[SSH-MONITOR] Port check close error {host}:{port}: {e}")
        return True

    def _build_restart_command(self) -> Optional[List[str]]:
        """Return the platform-specific tunnel restart command, or None if its script is missing."""
        if platform.system() == "Windows":
            # Deployed layout (scripts/) first, then the development checkout.
            candidates = [
                self._project_root / "scripts" / "ssh-tunnel.ps1",
                self._project_root / "deployment" / "windows" / "scripts" / "ssh-tunnel.ps1",
            ]
            script_path = next((p for p in candidates if p.exists()), None)
            if script_path is None:
                logger.error("[SSH-MONITOR] Script not found in scripts/ or deployment/windows/scripts/")
                return None
            return [
                "powershell.exe",
                # Don't load user profiles (service context) and fail fast instead
                # of blocking on an interactive prompt until the timeout expires.
                "-NoProfile",
                "-NonInteractive",
                "-ExecutionPolicy", "Bypass",
                "-File", str(script_path),  # -File must come last: the rest are script args
                "restart",
            ]

        script_path = self._project_root / "ssh-tunnel.sh"
        if not script_path.exists():
            logger.error(f"[SSH-MONITOR] Script not found: {script_path}")
            return None
        return [str(script_path), "restart"]

    async def _restart_tunnels(self) -> bool:
        """Restart tunnels via the platform-specific script.

        Returns:
            True if the script exited with code 0; False if the restart was
            skipped (cooldown active, script missing) or failed.
        """
        # Check cooldown (monotonic clock: immune to system time changes)
        now = time.monotonic()
        if self._last_restart_monotonic is not None:
            elapsed = now - self._last_restart_monotonic
            if elapsed < self.restart_cooldown:
                remaining = int(self.restart_cooldown - elapsed)
                logger.warning(f"[SSH-MONITOR] Restart cooldown active ({remaining}s remaining)")
                return False

        # An attempt consumes the cooldown even if the script turns out to be missing.
        self._last_restart_monotonic = now
        self.last_restart_time = time.time()
        logger.warning("[SSH-MONITOR] 🔄 Restarting tunnels...")

        cmd = self._build_restart_command()
        if cmd is None:
            return False

        try:
            # subprocess.run blocks, so run it in the default executor to keep
            # the event loop (and the API) responsive during the restart.
            result = await asyncio.get_running_loop().run_in_executor(
                None,
                lambda: subprocess.run(
                    cmd,
                    capture_output=True,
                    text=True,
                    errors="replace",  # script output may not match the locale encoding
                    timeout=60,
                    cwd=str(self._project_root),
                ),
            )
        except subprocess.TimeoutExpired:
            logger.error("[SSH-MONITOR] Restart command timed out")
            return False
        except Exception as e:
            logger.error(f"[SSH-MONITOR] Restart error: {e}")
            return False

        if result.returncode != 0:
            logger.error(f"[SSH-MONITOR] Restart failed (code {result.returncode})")
            if result.stderr:
                logger.error(f"[SSH-MONITOR] stderr: {result.stderr[:500]}")
            return False

        logger.info("[SSH-MONITOR] ✅ Tunnels restarted successfully")
        # Give every tunnel a fresh failure budget after a successful restart
        for tunnel_id in self.consecutive_failures:
            self.consecutive_failures[tunnel_id] = 0
        return True

    def get_status(self) -> Dict[str, Any]:
        """
        Get current tunnel status for /health endpoint.

        Returns:
            {
                "status": "connected" | "degraded" | "disconnected" | "not_configured",
                "tunnels": {
                    "tunnel_id": true/false,
                    ...
                },
                "monitoring": true/false
            }
        """
        if not self.tunnel_configs:
            return {
                "status": "not_configured",
                "tunnels": {},
                "monitoring": False,
            }

        # Determine overall status
        all_connected = all(self.tunnel_status.values()) if self.tunnel_status else False
        any_connected = any(self.tunnel_status.values()) if self.tunnel_status else False

        if all_connected:
            status = "connected"
        elif any_connected:
            status = "degraded"
        else:
            status = "disconnected"

        return {
            "status": status,
            "tunnels": dict(self.tunnel_status),
            "monitoring": self._is_monitoring,
        }

    def is_healthy(self) -> bool:
        """Quick check if all tunnels are healthy."""
        if not self.tunnel_configs:
            return True  # No tunnels configured = healthy (direct connection)
        return all(self.tunnel_status.values()) if self.tunnel_status else False


# Global singleton instance. Import this object (rather than constructing a new
# SSHTunnelManager) so the startup/shutdown hooks and /health share one monitor,
# as in the usage example in the module docstring.
ssh_tunnel_manager = SSHTunnelManager()
|