diff --git a/src/adapters/discord_bot.py b/src/adapters/discord_bot.py index 52d0185..b094348 100644 --- a/src/adapters/discord_bot.py +++ b/src/adapters/discord_bot.py @@ -2,12 +2,21 @@ import asyncio import logging +import os +import signal +from pathlib import Path import discord from discord import app_commands from src.config import Config -from src.claude_session import clear_session, get_active_session +from src.claude_session import ( + clear_session, + get_active_session, + set_session_model, + PROJECT_ROOT, + VALID_MODELS, +) from src.router import route_message logger = logging.getLogger("echo-core.discord") @@ -103,6 +112,10 @@ def create_bot(config: Config) -> discord.Client: "`/admin add ` — Add an admin (owner only)", "`/clear` — Clear the session for this channel", "`/status` — Show session status for this channel", + "`/model` — View current model and available models", + "`/model ` — Change model for this channel's session", + "`/logs [n]` — Show last N log lines (default 10)", + "`/restart` — Restart the bot process (owner only)", ] await interaction.response.send_message( "\n".join(lines), ephemeral=True @@ -191,10 +204,12 @@ def create_bot(config: Config) -> discord.Client: @tree.command(name="clear", description="Clear the session for this channel") async def clear(interaction: discord.Interaction) -> None: channel_id = str(interaction.channel_id) + default_model = config.get("bot.default_model", "sonnet") removed = clear_session(channel_id) if removed: await interaction.response.send_message( - "Session cleared.", ephemeral=True + f"Session cleared. Model reset to {default_model}.", + ephemeral=True, ) else: await interaction.response.send_message( @@ -221,6 +236,107 @@ def create_bot(config: Config) -> discord.Client: ephemeral=True, ) + @tree.command(name="model", description="View or change the AI model") + @app_commands.describe(choice="Model to switch to") + @app_commands.choices(choice=[ + app_commands.Choice(name="opus", value="opus"), + app_commands.Choice(name="sonnet", value="sonnet"), + app_commands.Choice(name="haiku", value="haiku"), + ]) + async def model_cmd( + interaction: discord.Interaction, + choice: app_commands.Choice[str] | None = None, + ) -> None: + channel_id = str(interaction.channel_id) + if choice is None: + # Show current model and available models + session = get_active_session(channel_id) + if session: + current = session.get("model", "unknown") + else: + current = config.get("bot.default_model", "sonnet") + available = ", ".join(sorted(VALID_MODELS)) + await interaction.response.send_message( + f"**Current model:** {current}\n" + f"**Available:** {available}", + ephemeral=True, + ) + else: + model = choice.value + session = get_active_session(channel_id) + if session: + set_session_model(channel_id, model) + else: + # No session yet — pre-set in active.json so next message uses it + from src.claude_session import _load_sessions, _save_sessions + from datetime import datetime, timezone + sessions = _load_sessions() + sessions[channel_id] = { + "session_id": "", + "model": model, + "created_at": datetime.now(timezone.utc).isoformat(), + "last_message_at": datetime.now(timezone.utc).isoformat(), + "message_count": 0, + } + _save_sessions(sessions) + await interaction.response.send_message( + f"Model changed to **{model}**.", ephemeral=True + ) + + @tree.command(name="restart", description="Restart the bot process") + async def restart(interaction: discord.Interaction) -> None: + if not is_owner(str(interaction.user.id)): + await interaction.response.send_message( + "Owner only.", ephemeral=True + ) + return + pid_file = PROJECT_ROOT / "echo-core.pid" + if not pid_file.exists(): + await interaction.response.send_message( + "No PID file found (echo-core.pid).", ephemeral=True + ) + return + try: + pid = int(pid_file.read_text().strip()) + os.kill(pid, signal.SIGTERM) + await interaction.response.send_message( + "Restarting...", ephemeral=True + ) + except ProcessLookupError: + await interaction.response.send_message( + f"Process {pid} not found.", ephemeral=True + ) + except ValueError: + await interaction.response.send_message( + "Invalid PID file content.", ephemeral=True + ) + + @tree.command(name="logs", description="Show recent log lines") + @app_commands.describe(n="Number of lines to show (default 10)") + async def logs_cmd( + interaction: discord.Interaction, n: int = 10 + ) -> None: + log_path = PROJECT_ROOT / "logs" / "echo-core.log" + if not log_path.exists(): + await interaction.response.send_message( + "No log file found.", ephemeral=True + ) + return + try: + all_lines = log_path.read_text(encoding="utf-8").splitlines() + tail = all_lines[-n:] if len(all_lines) >= n else all_lines + text = "\n".join(tail) + # Truncate to fit Discord message limit (2000 - code block overhead) + if len(text) > 1900: + text = text[-1900:] + await interaction.response.send_message( + f"```\n{text}\n```", ephemeral=True + ) + except Exception as e: + await interaction.response.send_message( + f"Error reading logs: {e}", ephemeral=True + ) + # --- Events --- @client.event diff --git a/src/claude_session.py b/src/claude_session.py index 2c99875..dc24a2d 100644 --- a/src/claude_session.py +++ b/src/claude_session.py @@ -271,6 +271,20 @@ def get_active_session(channel_id: str) -> dict | None: return sessions.get(channel_id) +def set_session_model(channel_id: str, model: str) -> bool: + """Update the model for a channel's active session. Returns True if session existed.""" + if model not in VALID_MODELS: + raise ValueError( + f"Invalid model '{model}'. Must be one of: {', '.join(sorted(VALID_MODELS))}" + ) + sessions = _load_sessions() + if channel_id not in sessions: + return False + sessions[channel_id]["model"] = model + _save_sessions(sessions) + return True + + def list_sessions() -> dict: """Return all channel→session mappings from active.json.""" return _load_sessions() diff --git a/src/router.py b/src/router.py index fe724bb..80551ce 100644 --- a/src/router.py +++ b/src/router.py @@ -2,7 +2,14 @@ import logging from src.config import Config -from src.claude_session import send_message, clear_session, get_active_session, list_sessions +from src.claude_session import ( + send_message, + clear_session, + get_active_session, + list_sessions, + set_session_model, + VALID_MODELS, +) log = logging.getLogger(__name__) @@ -28,22 +35,30 @@ def route_message(channel_id: str, user_id: str, text: str, model: str | None = # Text-based commands (not slash commands — these work in any adapter) if text.lower() == "/clear": + default_model = _get_config().get("bot.default_model", "sonnet") cleared = clear_session(channel_id) if cleared: - return "Session cleared.", True + return f"Session cleared. Model reset to {default_model}.", True return "No active session.", True if text.lower() == "/status": return _status(channel_id), True + if text.lower().startswith("/model"): + return _model_command(channel_id, text), True + if text.startswith("/"): return f"Unknown command: {text.split()[0]}", True # Regular message → Claude if not model: - # Get channel default model or global default - channel_cfg = _get_channel_config(channel_id) - model = (channel_cfg or {}).get("default_model") or _get_config().get("bot.default_model", "sonnet") + # Check session model first, then channel default, then global default + session = get_active_session(channel_id) + if session and session.get("model"): + model = session["model"] + else: + channel_cfg = _get_channel_config(channel_id) + model = (channel_cfg or {}).get("default_model") or _get_config().get("bot.default_model", "sonnet") try: response = send_message(channel_id, text, model=model) @@ -66,6 +81,43 @@ def _status(channel_id: str) -> str: return f"Model: {model} | Session: {sid}... | Messages: {count}" +def _model_command(channel_id: str, text: str) -> str: + """Handle /model [choice] text command.""" + parts = text.strip().split() + if len(parts) == 1: + # /model — show current + session = get_active_session(channel_id) + if session: + current = session.get("model", "unknown") + else: + channel_cfg = _get_channel_config(channel_id) + current = (channel_cfg or {}).get("default_model") or _get_config().get("bot.default_model", "sonnet") + available = ", ".join(sorted(VALID_MODELS)) + return f"Current model: {current}\nAvailable: {available}" + + choice = parts[1].lower() + if choice not in VALID_MODELS: + return f"Invalid model '{choice}'. Choose from: {', '.join(sorted(VALID_MODELS))}" + + session = get_active_session(channel_id) + if session: + set_session_model(channel_id, choice) + else: + # Pre-set for next message + from src.claude_session import _load_sessions, _save_sessions + from datetime import datetime, timezone + sessions = _load_sessions() + sessions[channel_id] = { + "session_id": "", + "model": choice, + "created_at": datetime.now(timezone.utc).isoformat(), + "last_message_at": datetime.now(timezone.utc).isoformat(), + "message_count": 0, + } + _save_sessions(sessions) + return f"Model changed to {choice}." + + def _get_channel_config(channel_id: str) -> dict | None: """Find channel config by ID.""" channels = _get_config().get("channels", {}) diff --git a/tests/test_claude_session.py b/tests/test_claude_session.py index b4ea90e..e487292 100644 --- a/tests/test_claude_session.py +++ b/tests/test_claude_session.py @@ -20,6 +20,7 @@ from src.claude_session import ( list_sessions, resume_session, send_message, + set_session_model, start_session, ) @@ -574,3 +575,46 @@ class TestListSessions: claude_session, "_SESSIONS_FILE", tmp_path / "nonexistent.json" ) assert list_sessions() == {} + + +# --------------------------------------------------------------------------- +# set_session_model +# --------------------------------------------------------------------------- + + +class TestSetSessionModel: + def test_updates_model_in_active_json(self, tmp_path, monkeypatch): + sessions_dir = tmp_path / "sessions" + sessions_dir.mkdir() + sf = sessions_dir / "active.json" + monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir) + monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf) + + sf.write_text(json.dumps({ + "general": { + "session_id": "abc", + "model": "sonnet", + "message_count": 1, + } + })) + + result = set_session_model("general", "opus") + assert result is True + + data = json.loads(sf.read_text()) + assert data["general"]["model"] == "opus" + + def test_returns_false_when_no_session(self, tmp_path, monkeypatch): + sessions_dir = tmp_path / "sessions" + sessions_dir.mkdir() + sf = sessions_dir / "active.json" + monkeypatch.setattr(claude_session, "SESSIONS_DIR", sessions_dir) + monkeypatch.setattr(claude_session, "_SESSIONS_FILE", sf) + + sf.write_text("{}") + result = set_session_model("general", "opus") + assert result is False + + def test_invalid_model_raises(self): + with pytest.raises(ValueError, match="Invalid model"): + set_session_model("general", "gpt4") diff --git a/tests/test_discord.py b/tests/test_discord.py index 4eea34d..7bc85b9 100644 --- a/tests/test_discord.py +++ b/tests/test_discord.py @@ -2,6 +2,7 @@ import json import logging +import signal import pytest from unittest.mock import AsyncMock, MagicMock, patch @@ -500,3 +501,158 @@ class TestStatusSlashCommand: msg = interaction.response.send_message.call_args assert "no active session" in msg.args[0].lower() + + +# --- /clear mentions model reset --- + + +class TestClearMentionsModelReset: + @pytest.mark.asyncio + @patch("src.adapters.discord_bot.clear_session") + async def test_clear_mentions_model_reset(self, mock_clear, owned_bot): + mock_clear.return_value = True + cmd = _find_command(owned_bot.tree, "clear") + interaction = _mock_interaction(channel_id="900") + await cmd.callback(interaction) + + msg = interaction.response.send_message.call_args + text = msg.args[0] + assert "model reset" in text.lower() + assert "sonnet" in text.lower() + + +# --- /model slash command --- + + +class TestModelSlashCommand: + @pytest.mark.asyncio + @patch("src.adapters.discord_bot.get_active_session") + async def test_model_no_args_shows_current_with_session(self, mock_get, owned_bot): + mock_get.return_value = {"model": "opus"} + cmd = _find_command(owned_bot.tree, "model") + interaction = _mock_interaction(channel_id="900") + await cmd.callback(interaction, choice=None) + + msg = interaction.response.send_message.call_args + text = msg.args[0] + assert "opus" in text.lower() + assert "haiku" in text + assert "sonnet" in text + assert msg.kwargs.get("ephemeral") is True + + @pytest.mark.asyncio + @patch("src.adapters.discord_bot.get_active_session") + async def test_model_no_args_shows_default_without_session(self, mock_get, owned_bot): + mock_get.return_value = None + cmd = _find_command(owned_bot.tree, "model") + interaction = _mock_interaction(channel_id="900") + await cmd.callback(interaction, choice=None) + + msg = interaction.response.send_message.call_args + text = msg.args[0] + assert "sonnet" in text # default from config + + @pytest.mark.asyncio + @patch("src.adapters.discord_bot.set_session_model") + @patch("src.adapters.discord_bot.get_active_session") + async def test_model_with_choice_changes_existing_session(self, mock_get, mock_set, owned_bot): + mock_get.return_value = {"model": "sonnet", "session_id": "abc"} + cmd = _find_command(owned_bot.tree, "model") + interaction = _mock_interaction(channel_id="900") + choice = MagicMock() + choice.value = "opus" + await cmd.callback(interaction, choice=choice) + + mock_set.assert_called_once_with("900", "opus") + msg = interaction.response.send_message.call_args + assert "opus" in msg.args[0] + + @pytest.mark.asyncio + @patch("src.claude_session._save_sessions") + @patch("src.claude_session._load_sessions") + @patch("src.adapters.discord_bot.get_active_session") + async def test_model_with_choice_presets_when_no_session(self, mock_get, mock_load, mock_save, owned_bot): + mock_get.return_value = None + mock_load.return_value = {} + cmd = _find_command(owned_bot.tree, "model") + interaction = _mock_interaction(channel_id="900") + choice = MagicMock() + choice.value = "haiku" + await cmd.callback(interaction, choice=choice) + + mock_save.assert_called_once() + saved_data = mock_save.call_args[0][0] + assert saved_data["900"]["model"] == "haiku" + assert saved_data["900"]["session_id"] == "" + msg = interaction.response.send_message.call_args + assert "haiku" in msg.args[0] + + +# --- /restart slash command --- + + +class TestRestartSlashCommand: + @pytest.mark.asyncio + async def test_restart_owner_succeeds(self, owned_bot, tmp_path): + pid_file = tmp_path / "echo-core.pid" + pid_file.write_text("12345") + with patch.object(discord_bot, "PROJECT_ROOT", tmp_path), \ + patch("os.kill") as mock_kill: + cmd = _find_command(owned_bot.tree, "restart") + interaction = _mock_interaction(user_id="111") + await cmd.callback(interaction) + + mock_kill.assert_called_once_with(12345, signal.SIGTERM) + msg = interaction.response.send_message.call_args + assert "restarting" in msg.args[0].lower() + + @pytest.mark.asyncio + async def test_restart_non_owner_rejected(self, owned_bot): + cmd = _find_command(owned_bot.tree, "restart") + interaction = _mock_interaction(user_id="999") + await cmd.callback(interaction) + + msg = interaction.response.send_message.call_args + assert "owner only" in msg.args[0].lower() + + @pytest.mark.asyncio + async def test_restart_no_pid_file(self, owned_bot, tmp_path): + with patch.object(discord_bot, "PROJECT_ROOT", tmp_path): + cmd = _find_command(owned_bot.tree, "restart") + interaction = _mock_interaction(user_id="111") + await cmd.callback(interaction) + + msg = interaction.response.send_message.call_args + assert "no pid file" in msg.args[0].lower() + + +# --- /logs slash command --- + + +class TestLogsSlashCommand: + @pytest.mark.asyncio + async def test_logs_returns_code_block(self, owned_bot, tmp_path): + log_dir = tmp_path / "logs" + log_dir.mkdir() + log_file = log_dir / "echo-core.log" + log_file.write_text("line1\nline2\nline3\n") + with patch.object(discord_bot, "PROJECT_ROOT", tmp_path): + cmd = _find_command(owned_bot.tree, "logs") + interaction = _mock_interaction() + await cmd.callback(interaction, n=10) + + msg = interaction.response.send_message.call_args + text = msg.args[0] + assert "```" in text + assert "line1" in text + assert "line3" in text + + @pytest.mark.asyncio + async def test_logs_no_file(self, owned_bot, tmp_path): + with patch.object(discord_bot, "PROJECT_ROOT", tmp_path): + cmd = _find_command(owned_bot.tree, "logs") + interaction = _mock_interaction() + await cmd.callback(interaction, n=10) + + msg = interaction.response.send_message.call_args + assert "no log file" in msg.args[0].lower() diff --git a/tests/test_router.py b/tests/test_router.py index edc8a99..e66d11a 100644 --- a/tests/test_router.py +++ b/tests/test_router.py @@ -20,21 +20,94 @@ def reset_router_config(): class TestClearCommand: + @patch("src.router._get_config") @patch("src.router.clear_session") - def test_clear_active_session(self, mock_clear): + def test_clear_active_session(self, mock_clear, mock_get_config): mock_clear.return_value = True + mock_cfg = MagicMock() + mock_cfg.get.return_value = "sonnet" + mock_get_config.return_value = mock_cfg response, is_cmd = route_message("ch-1", "user-1", "/clear") - assert response == "Session cleared." + assert response == "Session cleared. Model reset to sonnet." assert is_cmd is True mock_clear.assert_called_once_with("ch-1") + @patch("src.router._get_config") @patch("src.router.clear_session") - def test_clear_no_session(self, mock_clear): + def test_clear_no_session(self, mock_clear, mock_get_config): mock_clear.return_value = False + mock_cfg = MagicMock() + mock_cfg.get.return_value = "sonnet" + mock_get_config.return_value = mock_cfg response, is_cmd = route_message("ch-1", "user-1", "/clear") assert response == "No active session." assert is_cmd is True + @patch("src.router._get_config") + @patch("src.router.clear_session") + def test_clear_mentions_model_reset(self, mock_clear, mock_get_config): + mock_clear.return_value = True + mock_cfg = MagicMock() + mock_cfg.get.return_value = "opus" + mock_get_config.return_value = mock_cfg + response, is_cmd = route_message("ch-1", "user-1", "/clear") + assert "model reset" in response.lower() + assert "opus" in response + + +# --- /model command --- + + +class TestModelCommand: + @patch("src.router.get_active_session") + def test_model_show_current_with_session(self, mock_get): + mock_get.return_value = {"model": "opus"} + response, is_cmd = route_message("ch-1", "user-1", "/model") + assert is_cmd is True + assert "opus" in response + assert "haiku" in response # available models listed + + @patch("src.router._get_config") + @patch("src.router._get_channel_config") + @patch("src.router.get_active_session") + def test_model_show_current_no_session(self, mock_get, mock_chan_cfg, mock_get_config): + mock_get.return_value = None + mock_chan_cfg.return_value = None + mock_cfg = MagicMock() + mock_cfg.get.return_value = "sonnet" + mock_get_config.return_value = mock_cfg + response, is_cmd = route_message("ch-1", "user-1", "/model") + assert is_cmd is True + assert "sonnet" in response + + @patch("src.router.set_session_model") + @patch("src.router.get_active_session") + def test_model_change_opus(self, mock_get, mock_set): + mock_get.return_value = {"model": "sonnet", "session_id": "abc"} + response, is_cmd = route_message("ch-1", "user-1", "/model opus") + assert is_cmd is True + mock_set.assert_called_once_with("ch-1", "opus") + assert "opus" in response + + def test_model_invalid_choice(self): + response, is_cmd = route_message("ch-1", "user-1", "/model gpt4") + assert is_cmd is True + assert "invalid" in response.lower() + assert "gpt4" in response + + @patch("src.claude_session._save_sessions") + @patch("src.claude_session._load_sessions") + @patch("src.router.get_active_session") + def test_model_change_no_session_presets(self, mock_get, mock_load, mock_save): + mock_get.return_value = None + mock_load.return_value = {} + response, is_cmd = route_message("ch-1", "user-1", "/model haiku") + assert is_cmd is True + mock_save.assert_called_once() + saved = mock_save.call_args[0][0] + assert saved["ch-1"]["model"] == "haiku" + assert "haiku" in response + # --- /status command --- @@ -186,3 +259,13 @@ class TestModelResolution: route_message("ch-1", "user-1", "hello") mock_send.assert_called_once_with("ch-1", "hello", model="sonnet") + + @patch("src.router.get_active_session") + @patch("src.router.send_message") + def test_session_model_takes_priority(self, mock_send, mock_get_session): + """Session model takes priority over channel and global defaults.""" + mock_send.return_value = "ok" + mock_get_session.return_value = {"model": "opus", "session_id": "abc"} + + route_message("ch-1", "user-1", "hello") + mock_send.assert_called_once_with("ch-1", "hello", model="opus")