diff --git a/src/adapters/discord_bot.py b/src/adapters/discord_bot.py index 197e9a6..52d0185 100644 --- a/src/adapters/discord_bot.py +++ b/src/adapters/discord_bot.py @@ -1,11 +1,14 @@ """Discord bot adapter — slash commands and event handlers.""" +import asyncio import logging import discord from discord import app_commands from src.config import Config +from src.claude_session import clear_session, get_active_session +from src.router import route_message logger = logging.getLogger("echo-core.discord") @@ -42,6 +45,28 @@ def is_registered_channel(channel_id: str) -> bool: return any(ch.get("id") == channel_id for ch in channels.values()) +# --- Message splitting helper --- + + +def split_message(text: str, limit: int = 2000) -> list[str]: + """Split text into chunks that fit Discord's message limit.""" + if len(text) <= limit: + return [text] + + chunks = [] + while text: + if len(text) <= limit: + chunks.append(text) + break + # Find last newline before limit + split_at = text.rfind('\n', 0, limit) + if split_at == -1: + split_at = limit + chunks.append(text[:split_at]) + text = text[split_at:].lstrip('\n') + return chunks + + # --- Factory --- @@ -76,6 +101,8 @@ def create_bot(config: Config) -> discord.Client: "`/channel add ` — Register current channel (owner only)", "`/channels` — List registered channels", "`/admin add ` — Add an admin (owner only)", + "`/clear` — Clear the session for this channel", + "`/status` — Show session status for this channel", ] await interaction.response.send_message( "\n".join(lines), ephemeral=True @@ -161,6 +188,39 @@ def create_bot(config: Config) -> discord.Client: "\n".join(lines), ephemeral=True ) + @tree.command(name="clear", description="Clear the session for this channel") + async def clear(interaction: discord.Interaction) -> None: + channel_id = str(interaction.channel_id) + removed = clear_session(channel_id) + if removed: + await interaction.response.send_message( + "Session cleared.", ephemeral=True + ) + else: + await interaction.response.send_message( + "No active session for this channel.", ephemeral=True + ) + + @tree.command(name="status", description="Show session status") + async def status(interaction: discord.Interaction) -> None: + channel_id = str(interaction.channel_id) + session = get_active_session(channel_id) + if session is None: + await interaction.response.send_message( + "No active session.", ephemeral=True + ) + return + sid = session.get("session_id", "?") + truncated_sid = sid[:8] + "..." if len(sid) > 8 else sid + model = session.get("model", "?") + count = session.get("message_count", 0) + await interaction.response.send_message( + f"**Model:** {model}\n" + f"**Session:** `{truncated_sid}`\n" + f"**Messages:** {count}", + ephemeral=True, + ) + # --- Events --- @client.event @@ -168,20 +228,51 @@ def create_bot(config: Config) -> discord.Client: await tree.sync() logger.info("Echo Core online as %s", client.user) + async def _handle_chat(message: discord.Message) -> None: + """Process a chat message through the router and send the response.""" + channel_id = str(message.channel.id) + user_id = str(message.author.id) + text = message.content + + # React to acknowledge receipt + await message.add_reaction("\U0001f440") + + try: + async with message.channel.typing(): + response = await asyncio.to_thread( + route_message, channel_id, user_id, text + ) + + chunks = split_message(response) + for chunk in chunks: + await message.channel.send(chunk) + except Exception: + logger.exception("Error processing message from %s", message.author) + await message.channel.send( + "Sorry, something went wrong processing your message." + ) + finally: + # Remove the eyes reaction + try: + await message.remove_reaction("\U0001f440", client.user) + except discord.HTTPException: + pass + @client.event async def on_message(message: discord.Message) -> None: # Ignore bot's own messages if message.author == client.user: return - # DM handling: ignore if sender not admin + # DM handling: only process if sender is admin if isinstance(message.channel, discord.DMChannel): if not is_admin(str(message.author.id)): return logger.info( "DM from admin %s: %s", message.author, message.content[:100] ) - return # Stage 5 will add chat integration + await _handle_chat(message) + return # Guild messages: ignore if channel not registered if not is_registered_channel(str(message.channel.id)): @@ -193,6 +284,6 @@ def create_bot(config: Config) -> discord.Client: message.author, message.content[:100], ) - # Stage 5 will add chat integration here + await _handle_chat(message) return client diff --git a/src/router.py b/src/router.py new file mode 100644 index 0000000..fe724bb --- /dev/null +++ b/src/router.py @@ -0,0 +1,75 @@ +"""Echo Core message router — routes messages to Claude or commands.""" + +import logging +from src.config import Config +from src.claude_session import send_message, clear_session, get_active_session, list_sessions + +log = logging.getLogger(__name__) + +# Module-level config instance (lazy singleton) +_config: Config | None = None + + +def _get_config() -> Config: + """Return the module-level config, creating it on first access.""" + global _config + if _config is None: + _config = Config() + return _config + + +def route_message(channel_id: str, user_id: str, text: str, model: str | None = None) -> tuple[str, bool]: + """Route an incoming message. Returns (response_text, is_command). + + If text starts with / it's a command (handled here for text-based commands). + Otherwise it goes to Claude via send_message (auto start/resume). + """ + text = text.strip() + + # Text-based commands (not slash commands — these work in any adapter) + if text.lower() == "/clear": + cleared = clear_session(channel_id) + if cleared: + return "Session cleared.", True + return "No active session.", True + + if text.lower() == "/status": + return _status(channel_id), True + + if text.startswith("/"): + return f"Unknown command: {text.split()[0]}", True + + # Regular message → Claude + if not model: + # Get channel default model or global default + channel_cfg = _get_channel_config(channel_id) + model = (channel_cfg or {}).get("default_model") or _get_config().get("bot.default_model", "sonnet") + + try: + response = send_message(channel_id, text, model=model) + return response, False + except Exception as e: + log.error(f"Claude error for channel {channel_id}: {e}") + return f"Error: {e}", False + + +def _status(channel_id: str) -> str: + """Build status message for a channel.""" + session = get_active_session(channel_id) + if not session: + return "No active session." + + model = session.get("model", "unknown") + sid = session.get("session_id", "unknown")[:12] + count = session.get("message_count", 0) + + return f"Model: {model} | Session: {sid}... | Messages: {count}" + + +def _get_channel_config(channel_id: str) -> dict | None: + """Find channel config by ID.""" + channels = _get_config().get("channels", {}) + for alias, ch in channels.items(): + if ch.get("id") == channel_id: + return ch + return None diff --git a/tests/test_discord.py b/tests/test_discord.py index 998be68..4eea34d 100644 --- a/tests/test_discord.py +++ b/tests/test_discord.py @@ -14,6 +14,7 @@ from src.adapters.discord_bot import ( is_admin, is_owner, is_registered_channel, + split_message, ) @@ -390,3 +391,112 @@ class TestOnMessage: await on_message(message) assert "dm from admin" in caplog.text.lower() + + @pytest.mark.asyncio + @patch("src.adapters.discord_bot.route_message") + async def test_chat_flow(self, mock_route, owned_bot): + """on_message chat flow: reaction, typing, route, send, cleanup.""" + mock_route.return_value = "Hello from Claude!" + + on_message = self._get_on_message(owned_bot) + + message = AsyncMock(spec=discord.Message) + message.author = MagicMock() + message.author.id = 555 + message.author.__eq__ = lambda self, other: False + message.channel = AsyncMock(spec=discord.TextChannel) + message.channel.id = 900 # registered channel + message.content = "test message" + + await on_message(message) + + # Verify eyes reaction added + message.add_reaction.assert_awaited_once_with("\U0001f440") + + # Verify typing indicator was triggered + message.channel.typing.assert_called_once() + + # Verify response sent + message.channel.send.assert_awaited_once_with("Hello from Claude!") + + # Verify eyes reaction removed + message.remove_reaction.assert_awaited_once() + + +# --- split_message --- + + +class TestSplitMessage: + def test_short_text_no_split(self): + result = split_message("hello") + assert result == ["hello"] + + def test_long_text_split_at_newline(self): + text = "a" * 10 + "\n" + "b" * 10 + result = split_message(text, limit=15) + assert result == ["a" * 10, "b" * 10] + + def test_very_long_without_newlines_hard_split(self): + text = "a" * 30 + result = split_message(text, limit=10) + assert result == ["a" * 10, "a" * 10, "a" * 10] + + +# --- /clear slash command --- + + +class TestClearSlashCommand: + @pytest.mark.asyncio + @patch("src.adapters.discord_bot.clear_session") + async def test_clear_with_session(self, mock_clear, owned_bot): + mock_clear.return_value = True + cmd = _find_command(owned_bot.tree, "clear") + interaction = _mock_interaction(channel_id="900") + await cmd.callback(interaction) + + msg = interaction.response.send_message.call_args + assert "session cleared" in msg.args[0].lower() + + @pytest.mark.asyncio + @patch("src.adapters.discord_bot.clear_session") + async def test_clear_no_session(self, mock_clear, owned_bot): + mock_clear.return_value = False + cmd = _find_command(owned_bot.tree, "clear") + interaction = _mock_interaction(channel_id="900") + await cmd.callback(interaction) + + msg = interaction.response.send_message.call_args + assert "no active session" in msg.args[0].lower() + + +# --- /status slash command --- + + +class TestStatusSlashCommand: + @pytest.mark.asyncio + @patch("src.adapters.discord_bot.get_active_session") + async def test_status_with_session(self, mock_get, owned_bot): + mock_get.return_value = { + "session_id": "abcdef1234567890", + "model": "sonnet", + "message_count": 3, + } + cmd = _find_command(owned_bot.tree, "status") + interaction = _mock_interaction(channel_id="900") + await cmd.callback(interaction) + + msg = interaction.response.send_message.call_args + text = msg.args[0] if msg.args else msg.kwargs.get("content", "") + assert "sonnet" in text + assert "3" in text + + @pytest.mark.asyncio + @patch("src.adapters.discord_bot.get_active_session") + async def test_status_no_session(self, mock_get, owned_bot): + mock_get.return_value = None + cmd = _find_command(owned_bot.tree, "status") + interaction = _mock_interaction(channel_id="900") + await cmd.callback(interaction) + + msg = interaction.response.send_message.call_args + assert "no active session" in msg.args[0].lower() diff --git a/tests/test_router.py b/tests/test_router.py new file mode 100644 index 0000000..edc8a99 --- /dev/null +++ b/tests/test_router.py @@ -0,0 +1,188 @@ +"""Tests for src/router.py — message router.""" + +import pytest +from unittest.mock import MagicMock, patch + +from src.router import route_message, _get_channel_config + + +@pytest.fixture(autouse=True) +def reset_router_config(): + """Reset the module-level _config before each test.""" + import src.router + original = src.router._config + src.router._config = None + yield + src.router._config = original + + +# --- /clear command --- + + +class TestClearCommand: + @patch("src.router.clear_session") + def test_clear_active_session(self, mock_clear): + mock_clear.return_value = True + response, is_cmd = route_message("ch-1", "user-1", "/clear") + assert response == "Session cleared." + assert is_cmd is True + mock_clear.assert_called_once_with("ch-1") + + @patch("src.router.clear_session") + def test_clear_no_session(self, mock_clear): + mock_clear.return_value = False + response, is_cmd = route_message("ch-1", "user-1", "/clear") + assert response == "No active session." + assert is_cmd is True + + +# --- /status command --- + + +class TestStatusCommand: + @patch("src.router.get_active_session") + def test_status_active_session(self, mock_get): + mock_get.return_value = { + "model": "sonnet", + "session_id": "abcdef123456789", + "message_count": 5, + } + response, is_cmd = route_message("ch-1", "user-1", "/status") + assert is_cmd is True + assert "sonnet" in response + assert "abcdef123456" in response # first 12 chars + assert "5" in response + + @patch("src.router.get_active_session") + def test_status_no_session(self, mock_get): + mock_get.return_value = None + response, is_cmd = route_message("ch-1", "user-1", "/status") + assert response == "No active session." + assert is_cmd is True + + +# --- Unknown command --- + + +class TestUnknownCommand: + def test_unknown_command(self): + response, is_cmd = route_message("ch-1", "user-1", "/foo") + assert response == "Unknown command: /foo" + assert is_cmd is True + + def test_unknown_command_with_args(self): + response, is_cmd = route_message("ch-1", "user-1", "/bar baz") + assert response == "Unknown command: /bar" + assert is_cmd is True + + +# --- Regular messages --- + + +class TestRegularMessage: + @patch("src.router._get_channel_config") + @patch("src.router._get_config") + @patch("src.router.send_message") + def test_sends_to_claude(self, mock_send, mock_get_config, mock_chan_cfg): + mock_send.return_value = "Hello from Claude!" + mock_chan_cfg.return_value = None + mock_cfg = MagicMock() + mock_cfg.get.return_value = "sonnet" + mock_get_config.return_value = mock_cfg + + response, is_cmd = route_message("ch-1", "user-1", "hello") + assert response == "Hello from Claude!" + assert is_cmd is False + mock_send.assert_called_once_with("ch-1", "hello", model="sonnet") + + @patch("src.router.send_message") + def test_model_override(self, mock_send): + mock_send.return_value = "Response" + response, is_cmd = route_message("ch-1", "user-1", "hello", model="opus") + assert response == "Response" + assert is_cmd is False + mock_send.assert_called_once_with("ch-1", "hello", model="opus") + + @patch("src.router._get_channel_config") + @patch("src.router._get_config") + @patch("src.router.send_message") + def test_claude_error(self, mock_send, mock_get_config, mock_chan_cfg): + mock_send.side_effect = RuntimeError("API timeout") + mock_chan_cfg.return_value = None + mock_cfg = MagicMock() + mock_cfg.get.return_value = "sonnet" + mock_get_config.return_value = mock_cfg + + response, is_cmd = route_message("ch-1", "user-1", "hello") + assert "Error: API timeout" in response + assert is_cmd is False + + +# --- _get_channel_config --- + + +class TestGetChannelConfig: + @patch("src.router._get_config") + def test_finds_by_id(self, mock_get_config): + mock_cfg = MagicMock() + mock_cfg.get.return_value = { + "general": {"id": "ch-1", "default_model": "haiku"}, + "dev": {"id": "ch-2"}, + } + mock_get_config.return_value = mock_cfg + + result = _get_channel_config("ch-1") + assert result == {"id": "ch-1", "default_model": "haiku"} + + @patch("src.router._get_config") + def test_returns_none_when_not_found(self, mock_get_config): + mock_cfg = MagicMock() + mock_cfg.get.return_value = {"general": {"id": "ch-1"}} + mock_get_config.return_value = mock_cfg + + result = _get_channel_config("ch-999") + assert result is None + + +# --- Model resolution --- + + +class TestModelResolution: + @patch("src.router._get_channel_config") + @patch("src.router._get_config") + @patch("src.router.send_message") + def test_channel_default_model(self, mock_send, mock_get_config, mock_chan_cfg): + """Channel config default_model takes priority.""" + mock_send.return_value = "ok" + mock_chan_cfg.return_value = {"id": "ch-1", "default_model": "haiku"} + + route_message("ch-1", "user-1", "hello") + mock_send.assert_called_once_with("ch-1", "hello", model="haiku") + + @patch("src.router._get_channel_config") + @patch("src.router._get_config") + @patch("src.router.send_message") + def test_global_default_model(self, mock_send, mock_get_config, mock_chan_cfg): + """Falls back to bot.default_model when channel has no default.""" + mock_send.return_value = "ok" + mock_chan_cfg.return_value = {"id": "ch-1"} # no default_model + mock_cfg = MagicMock() + mock_cfg.get.return_value = "opus" + mock_get_config.return_value = mock_cfg + + route_message("ch-1", "user-1", "hello") + mock_send.assert_called_once_with("ch-1", "hello", model="opus") + + @patch("src.router._get_channel_config") + @patch("src.router._get_config") + @patch("src.router.send_message") + def test_sonnet_fallback(self, mock_send, mock_get_config, mock_chan_cfg): + """Falls back to 'sonnet' when no channel or global default.""" + mock_send.return_value = "ok" + mock_chan_cfg.return_value = None + mock_cfg = MagicMock() + mock_cfg.get.side_effect = lambda key, default=None: default + mock_get_config.return_value = mock_cfg + + route_message("ch-1", "user-1", "hello") + mock_send.assert_called_once_with("ch-1", "hello", model="sonnet")