stage-5: full discord-claude chat integration
Message router, typing indicator, emoji reactions, auto start/resume sessions, message splitting >2000 chars. 34 new tests (141 total). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,11 +1,14 @@
|
||||
"""Discord bot adapter — slash commands and event handlers."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import discord
|
||||
from discord import app_commands
|
||||
|
||||
from src.config import Config
|
||||
from src.claude_session import clear_session, get_active_session
|
||||
from src.router import route_message
|
||||
|
||||
logger = logging.getLogger("echo-core.discord")
|
||||
|
||||
@@ -42,6 +45,28 @@ def is_registered_channel(channel_id: str) -> bool:
|
||||
return any(ch.get("id") == channel_id for ch in channels.values())
|
||||
|
||||
|
||||
# --- Message splitting helper ---
|
||||
|
||||
|
||||
def split_message(text: str, limit: int = 2000) -> list[str]:
|
||||
"""Split text into chunks that fit Discord's message limit."""
|
||||
if len(text) <= limit:
|
||||
return [text]
|
||||
|
||||
chunks = []
|
||||
while text:
|
||||
if len(text) <= limit:
|
||||
chunks.append(text)
|
||||
break
|
||||
# Find last newline before limit
|
||||
split_at = text.rfind('\n', 0, limit)
|
||||
if split_at == -1:
|
||||
split_at = limit
|
||||
chunks.append(text[:split_at])
|
||||
text = text[split_at:].lstrip('\n')
|
||||
return chunks
|
||||
|
||||
|
||||
# --- Factory ---
|
||||
|
||||
|
||||
@@ -76,6 +101,8 @@ def create_bot(config: Config) -> discord.Client:
|
||||
"`/channel add <alias>` — Register current channel (owner only)",
|
||||
"`/channels` — List registered channels",
|
||||
"`/admin add <user_id>` — Add an admin (owner only)",
|
||||
"`/clear` — Clear the session for this channel",
|
||||
"`/status` — Show session status for this channel",
|
||||
]
|
||||
await interaction.response.send_message(
|
||||
"\n".join(lines), ephemeral=True
|
||||
@@ -161,6 +188,39 @@ def create_bot(config: Config) -> discord.Client:
|
||||
"\n".join(lines), ephemeral=True
|
||||
)
|
||||
|
||||
@tree.command(name="clear", description="Clear the session for this channel")
|
||||
async def clear(interaction: discord.Interaction) -> None:
|
||||
channel_id = str(interaction.channel_id)
|
||||
removed = clear_session(channel_id)
|
||||
if removed:
|
||||
await interaction.response.send_message(
|
||||
"Session cleared.", ephemeral=True
|
||||
)
|
||||
else:
|
||||
await interaction.response.send_message(
|
||||
"No active session for this channel.", ephemeral=True
|
||||
)
|
||||
|
||||
@tree.command(name="status", description="Show session status")
|
||||
async def status(interaction: discord.Interaction) -> None:
|
||||
channel_id = str(interaction.channel_id)
|
||||
session = get_active_session(channel_id)
|
||||
if session is None:
|
||||
await interaction.response.send_message(
|
||||
"No active session.", ephemeral=True
|
||||
)
|
||||
return
|
||||
sid = session.get("session_id", "?")
|
||||
truncated_sid = sid[:8] + "..." if len(sid) > 8 else sid
|
||||
model = session.get("model", "?")
|
||||
count = session.get("message_count", 0)
|
||||
await interaction.response.send_message(
|
||||
f"**Model:** {model}\n"
|
||||
f"**Session:** `{truncated_sid}`\n"
|
||||
f"**Messages:** {count}",
|
||||
ephemeral=True,
|
||||
)
|
||||
|
||||
# --- Events ---
|
||||
|
||||
@client.event
|
||||
@@ -168,20 +228,51 @@ def create_bot(config: Config) -> discord.Client:
|
||||
await tree.sync()
|
||||
logger.info("Echo Core online as %s", client.user)
|
||||
|
||||
async def _handle_chat(message: discord.Message) -> None:
|
||||
"""Process a chat message through the router and send the response."""
|
||||
channel_id = str(message.channel.id)
|
||||
user_id = str(message.author.id)
|
||||
text = message.content
|
||||
|
||||
# React to acknowledge receipt
|
||||
await message.add_reaction("\U0001f440")
|
||||
|
||||
try:
|
||||
async with message.channel.typing():
|
||||
response = await asyncio.to_thread(
|
||||
route_message, channel_id, user_id, text
|
||||
)
|
||||
|
||||
chunks = split_message(response)
|
||||
for chunk in chunks:
|
||||
await message.channel.send(chunk)
|
||||
except Exception:
|
||||
logger.exception("Error processing message from %s", message.author)
|
||||
await message.channel.send(
|
||||
"Sorry, something went wrong processing your message."
|
||||
)
|
||||
finally:
|
||||
# Remove the eyes reaction
|
||||
try:
|
||||
await message.remove_reaction("\U0001f440", client.user)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
@client.event
|
||||
async def on_message(message: discord.Message) -> None:
|
||||
# Ignore bot's own messages
|
||||
if message.author == client.user:
|
||||
return
|
||||
|
||||
# DM handling: ignore if sender not admin
|
||||
# DM handling: only process if sender is admin
|
||||
if isinstance(message.channel, discord.DMChannel):
|
||||
if not is_admin(str(message.author.id)):
|
||||
return
|
||||
logger.info(
|
||||
"DM from admin %s: %s", message.author, message.content[:100]
|
||||
)
|
||||
return # Stage 5 will add chat integration
|
||||
await _handle_chat(message)
|
||||
return
|
||||
|
||||
# Guild messages: ignore if channel not registered
|
||||
if not is_registered_channel(str(message.channel.id)):
|
||||
@@ -193,6 +284,6 @@ def create_bot(config: Config) -> discord.Client:
|
||||
message.author,
|
||||
message.content[:100],
|
||||
)
|
||||
# Stage 5 will add chat integration here
|
||||
await _handle_chat(message)
|
||||
|
||||
return client
|
||||
|
||||
75
src/router.py
Normal file
75
src/router.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Echo Core message router — routes messages to Claude or commands."""
|
||||
|
||||
import logging
|
||||
from src.config import Config
|
||||
from src.claude_session import send_message, clear_session, get_active_session, list_sessions
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Module-level config instance (lazy singleton)
|
||||
_config: Config | None = None
|
||||
|
||||
|
||||
def _get_config() -> Config:
|
||||
"""Return the module-level config, creating it on first access."""
|
||||
global _config
|
||||
if _config is None:
|
||||
_config = Config()
|
||||
return _config
|
||||
|
||||
|
||||
def route_message(channel_id: str, user_id: str, text: str, model: str | None = None) -> tuple[str, bool]:
|
||||
"""Route an incoming message. Returns (response_text, is_command).
|
||||
|
||||
If text starts with / it's a command (handled here for text-based commands).
|
||||
Otherwise it goes to Claude via send_message (auto start/resume).
|
||||
"""
|
||||
text = text.strip()
|
||||
|
||||
# Text-based commands (not slash commands — these work in any adapter)
|
||||
if text.lower() == "/clear":
|
||||
cleared = clear_session(channel_id)
|
||||
if cleared:
|
||||
return "Session cleared.", True
|
||||
return "No active session.", True
|
||||
|
||||
if text.lower() == "/status":
|
||||
return _status(channel_id), True
|
||||
|
||||
if text.startswith("/"):
|
||||
return f"Unknown command: {text.split()[0]}", True
|
||||
|
||||
# Regular message → Claude
|
||||
if not model:
|
||||
# Get channel default model or global default
|
||||
channel_cfg = _get_channel_config(channel_id)
|
||||
model = (channel_cfg or {}).get("default_model") or _get_config().get("bot.default_model", "sonnet")
|
||||
|
||||
try:
|
||||
response = send_message(channel_id, text, model=model)
|
||||
return response, False
|
||||
except Exception as e:
|
||||
log.error(f"Claude error for channel {channel_id}: {e}")
|
||||
return f"Error: {e}", False
|
||||
|
||||
|
||||
def _status(channel_id: str) -> str:
|
||||
"""Build status message for a channel."""
|
||||
session = get_active_session(channel_id)
|
||||
if not session:
|
||||
return "No active session."
|
||||
|
||||
model = session.get("model", "unknown")
|
||||
sid = session.get("session_id", "unknown")[:12]
|
||||
count = session.get("message_count", 0)
|
||||
|
||||
return f"Model: {model} | Session: {sid}... | Messages: {count}"
|
||||
|
||||
|
||||
def _get_channel_config(channel_id: str) -> dict | None:
|
||||
"""Find channel config by ID."""
|
||||
channels = _get_config().get("channels", {})
|
||||
for alias, ch in channels.items():
|
||||
if ch.get("id") == channel_id:
|
||||
return ch
|
||||
return None
|
||||
@@ -14,6 +14,7 @@ from src.adapters.discord_bot import (
|
||||
is_admin,
|
||||
is_owner,
|
||||
is_registered_channel,
|
||||
split_message,
|
||||
)
|
||||
|
||||
|
||||
@@ -390,3 +391,112 @@ class TestOnMessage:
|
||||
await on_message(message)
|
||||
|
||||
assert "dm from admin" in caplog.text.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.adapters.discord_bot.route_message")
|
||||
async def test_chat_flow(self, mock_route, owned_bot):
|
||||
"""on_message chat flow: reaction, typing, route, send, cleanup."""
|
||||
mock_route.return_value = "Hello from Claude!"
|
||||
|
||||
on_message = self._get_on_message(owned_bot)
|
||||
|
||||
message = AsyncMock(spec=discord.Message)
|
||||
message.author = MagicMock()
|
||||
message.author.id = 555
|
||||
message.author.__eq__ = lambda self, other: False
|
||||
message.channel = AsyncMock(spec=discord.TextChannel)
|
||||
message.channel.id = 900 # registered channel
|
||||
message.content = "test message"
|
||||
|
||||
await on_message(message)
|
||||
|
||||
# Verify eyes reaction added
|
||||
message.add_reaction.assert_awaited_once_with("\U0001f440")
|
||||
|
||||
# Verify typing indicator was triggered
|
||||
message.channel.typing.assert_called_once()
|
||||
|
||||
# Verify response sent
|
||||
message.channel.send.assert_awaited_once_with("Hello from Claude!")
|
||||
|
||||
# Verify eyes reaction removed
|
||||
message.remove_reaction.assert_awaited_once()
|
||||
|
||||
|
||||
# --- split_message ---
|
||||
|
||||
|
||||
class TestSplitMessage:
|
||||
def test_short_text_no_split(self):
|
||||
result = split_message("hello")
|
||||
assert result == ["hello"]
|
||||
|
||||
def test_long_text_split_at_newline(self):
|
||||
text = "a" * 10 + "\n" + "b" * 10
|
||||
result = split_message(text, limit=15)
|
||||
assert result == ["a" * 10, "b" * 10]
|
||||
|
||||
def test_very_long_without_newlines_hard_split(self):
|
||||
text = "a" * 30
|
||||
result = split_message(text, limit=10)
|
||||
assert result == ["a" * 10, "a" * 10, "a" * 10]
|
||||
|
||||
|
||||
# --- /clear slash command ---
|
||||
|
||||
|
||||
class TestClearSlashCommand:
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.adapters.discord_bot.clear_session")
|
||||
async def test_clear_with_session(self, mock_clear, owned_bot):
|
||||
mock_clear.return_value = True
|
||||
cmd = _find_command(owned_bot.tree, "clear")
|
||||
interaction = _mock_interaction(channel_id="900")
|
||||
await cmd.callback(interaction)
|
||||
|
||||
msg = interaction.response.send_message.call_args
|
||||
assert "session cleared" in msg.args[0].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.adapters.discord_bot.clear_session")
|
||||
async def test_clear_no_session(self, mock_clear, owned_bot):
|
||||
mock_clear.return_value = False
|
||||
cmd = _find_command(owned_bot.tree, "clear")
|
||||
interaction = _mock_interaction(channel_id="900")
|
||||
await cmd.callback(interaction)
|
||||
|
||||
msg = interaction.response.send_message.call_args
|
||||
assert "no active session" in msg.args[0].lower()
|
||||
|
||||
|
||||
# --- /status slash command ---
|
||||
|
||||
|
||||
class TestStatusSlashCommand:
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.adapters.discord_bot.get_active_session")
|
||||
async def test_status_with_session(self, mock_get, owned_bot):
|
||||
mock_get.return_value = {
|
||||
"session_id": "abcdef1234567890",
|
||||
"model": "sonnet",
|
||||
"message_count": 3,
|
||||
}
|
||||
cmd = _find_command(owned_bot.tree, "status")
|
||||
interaction = _mock_interaction(channel_id="900")
|
||||
await cmd.callback(interaction)
|
||||
|
||||
msg = interaction.response.send_message.call_args
|
||||
text = msg.args[0] if msg.args else msg.kwargs.get("content", "")
|
||||
assert "sonnet" in text
|
||||
assert "3" in text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.adapters.discord_bot.get_active_session")
|
||||
async def test_status_no_session(self, mock_get, owned_bot):
|
||||
mock_get.return_value = None
|
||||
cmd = _find_command(owned_bot.tree, "status")
|
||||
interaction = _mock_interaction(channel_id="900")
|
||||
await cmd.callback(interaction)
|
||||
|
||||
msg = interaction.response.send_message.call_args
|
||||
assert "no active session" in msg.args[0].lower()
|
||||
|
||||
188
tests/test_router.py
Normal file
188
tests/test_router.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""Tests for src/router.py — message router."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from src.router import route_message, _get_channel_config
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_router_config():
|
||||
"""Reset the module-level _config before each test."""
|
||||
import src.router
|
||||
original = src.router._config
|
||||
src.router._config = None
|
||||
yield
|
||||
src.router._config = original
|
||||
|
||||
|
||||
# --- /clear command ---
|
||||
|
||||
|
||||
class TestClearCommand:
|
||||
@patch("src.router.clear_session")
|
||||
def test_clear_active_session(self, mock_clear):
|
||||
mock_clear.return_value = True
|
||||
response, is_cmd = route_message("ch-1", "user-1", "/clear")
|
||||
assert response == "Session cleared."
|
||||
assert is_cmd is True
|
||||
mock_clear.assert_called_once_with("ch-1")
|
||||
|
||||
@patch("src.router.clear_session")
|
||||
def test_clear_no_session(self, mock_clear):
|
||||
mock_clear.return_value = False
|
||||
response, is_cmd = route_message("ch-1", "user-1", "/clear")
|
||||
assert response == "No active session."
|
||||
assert is_cmd is True
|
||||
|
||||
|
||||
# --- /status command ---
|
||||
|
||||
|
||||
class TestStatusCommand:
|
||||
@patch("src.router.get_active_session")
|
||||
def test_status_active_session(self, mock_get):
|
||||
mock_get.return_value = {
|
||||
"model": "sonnet",
|
||||
"session_id": "abcdef123456789",
|
||||
"message_count": 5,
|
||||
}
|
||||
response, is_cmd = route_message("ch-1", "user-1", "/status")
|
||||
assert is_cmd is True
|
||||
assert "sonnet" in response
|
||||
assert "abcdef123456" in response # first 12 chars
|
||||
assert "5" in response
|
||||
|
||||
@patch("src.router.get_active_session")
|
||||
def test_status_no_session(self, mock_get):
|
||||
mock_get.return_value = None
|
||||
response, is_cmd = route_message("ch-1", "user-1", "/status")
|
||||
assert response == "No active session."
|
||||
assert is_cmd is True
|
||||
|
||||
|
||||
# --- Unknown command ---
|
||||
|
||||
|
||||
class TestUnknownCommand:
|
||||
def test_unknown_command(self):
|
||||
response, is_cmd = route_message("ch-1", "user-1", "/foo")
|
||||
assert response == "Unknown command: /foo"
|
||||
assert is_cmd is True
|
||||
|
||||
def test_unknown_command_with_args(self):
|
||||
response, is_cmd = route_message("ch-1", "user-1", "/bar baz")
|
||||
assert response == "Unknown command: /bar"
|
||||
assert is_cmd is True
|
||||
|
||||
|
||||
# --- Regular messages ---
|
||||
|
||||
|
||||
class TestRegularMessage:
|
||||
@patch("src.router._get_channel_config")
|
||||
@patch("src.router._get_config")
|
||||
@patch("src.router.send_message")
|
||||
def test_sends_to_claude(self, mock_send, mock_get_config, mock_chan_cfg):
|
||||
mock_send.return_value = "Hello from Claude!"
|
||||
mock_chan_cfg.return_value = None
|
||||
mock_cfg = MagicMock()
|
||||
mock_cfg.get.return_value = "sonnet"
|
||||
mock_get_config.return_value = mock_cfg
|
||||
|
||||
response, is_cmd = route_message("ch-1", "user-1", "hello")
|
||||
assert response == "Hello from Claude!"
|
||||
assert is_cmd is False
|
||||
mock_send.assert_called_once_with("ch-1", "hello", model="sonnet")
|
||||
|
||||
@patch("src.router.send_message")
|
||||
def test_model_override(self, mock_send):
|
||||
mock_send.return_value = "Response"
|
||||
response, is_cmd = route_message("ch-1", "user-1", "hello", model="opus")
|
||||
assert response == "Response"
|
||||
assert is_cmd is False
|
||||
mock_send.assert_called_once_with("ch-1", "hello", model="opus")
|
||||
|
||||
@patch("src.router._get_channel_config")
|
||||
@patch("src.router._get_config")
|
||||
@patch("src.router.send_message")
|
||||
def test_claude_error(self, mock_send, mock_get_config, mock_chan_cfg):
|
||||
mock_send.side_effect = RuntimeError("API timeout")
|
||||
mock_chan_cfg.return_value = None
|
||||
mock_cfg = MagicMock()
|
||||
mock_cfg.get.return_value = "sonnet"
|
||||
mock_get_config.return_value = mock_cfg
|
||||
|
||||
response, is_cmd = route_message("ch-1", "user-1", "hello")
|
||||
assert "Error: API timeout" in response
|
||||
assert is_cmd is False
|
||||
|
||||
|
||||
# --- _get_channel_config ---
|
||||
|
||||
|
||||
class TestGetChannelConfig:
|
||||
@patch("src.router._get_config")
|
||||
def test_finds_by_id(self, mock_get_config):
|
||||
mock_cfg = MagicMock()
|
||||
mock_cfg.get.return_value = {
|
||||
"general": {"id": "ch-1", "default_model": "haiku"},
|
||||
"dev": {"id": "ch-2"},
|
||||
}
|
||||
mock_get_config.return_value = mock_cfg
|
||||
|
||||
result = _get_channel_config("ch-1")
|
||||
assert result == {"id": "ch-1", "default_model": "haiku"}
|
||||
|
||||
@patch("src.router._get_config")
|
||||
def test_returns_none_when_not_found(self, mock_get_config):
|
||||
mock_cfg = MagicMock()
|
||||
mock_cfg.get.return_value = {"general": {"id": "ch-1"}}
|
||||
mock_get_config.return_value = mock_cfg
|
||||
|
||||
result = _get_channel_config("ch-999")
|
||||
assert result is None
|
||||
|
||||
|
||||
# --- Model resolution ---
|
||||
|
||||
|
||||
class TestModelResolution:
|
||||
@patch("src.router._get_channel_config")
|
||||
@patch("src.router._get_config")
|
||||
@patch("src.router.send_message")
|
||||
def test_channel_default_model(self, mock_send, mock_get_config, mock_chan_cfg):
|
||||
"""Channel config default_model takes priority."""
|
||||
mock_send.return_value = "ok"
|
||||
mock_chan_cfg.return_value = {"id": "ch-1", "default_model": "haiku"}
|
||||
|
||||
route_message("ch-1", "user-1", "hello")
|
||||
mock_send.assert_called_once_with("ch-1", "hello", model="haiku")
|
||||
|
||||
@patch("src.router._get_channel_config")
|
||||
@patch("src.router._get_config")
|
||||
@patch("src.router.send_message")
|
||||
def test_global_default_model(self, mock_send, mock_get_config, mock_chan_cfg):
|
||||
"""Falls back to bot.default_model when channel has no default."""
|
||||
mock_send.return_value = "ok"
|
||||
mock_chan_cfg.return_value = {"id": "ch-1"} # no default_model
|
||||
mock_cfg = MagicMock()
|
||||
mock_cfg.get.return_value = "opus"
|
||||
mock_get_config.return_value = mock_cfg
|
||||
|
||||
route_message("ch-1", "user-1", "hello")
|
||||
mock_send.assert_called_once_with("ch-1", "hello", model="opus")
|
||||
|
||||
@patch("src.router._get_channel_config")
|
||||
@patch("src.router._get_config")
|
||||
@patch("src.router.send_message")
|
||||
def test_sonnet_fallback(self, mock_send, mock_get_config, mock_chan_cfg):
|
||||
"""Falls back to 'sonnet' when no channel or global default."""
|
||||
mock_send.return_value = "ok"
|
||||
mock_chan_cfg.return_value = None
|
||||
mock_cfg = MagicMock()
|
||||
mock_cfg.get.side_effect = lambda key, default=None: default
|
||||
mock_get_config.return_value = mock_cfg
|
||||
|
||||
route_message("ch-1", "user-1", "hello")
|
||||
mock_send.assert_called_once_with("ch-1", "hello", model="sonnet")
|
||||
Reference in New Issue
Block a user