diff --git a/src/claude_session.py b/src/claude_session.py index 1f3d6f3..f287998 100644 --- a/src/claude_session.py +++ b/src/claude_session.py @@ -157,7 +157,12 @@ def _save_sessions(data: dict) -> None: def _run_claude(cmd: list[str], timeout: int) -> dict: - """Run a Claude CLI command and return the parsed JSON output.""" + """Run a Claude CLI command and return parsed output. + + Expects ``--output-format stream-json --verbose``. Parses the newline- + delimited JSON stream, collecting every text block from ``assistant`` + messages and metadata from the final ``result`` line. + """ if not shutil.which(CLAUDE_BIN): raise FileNotFoundError( "Claude CLI not found. " @@ -184,12 +189,50 @@ def _run_claude(cmd: list[str], timeout: int) -> dict: f"Claude CLI error (exit {proc.returncode}): {detail}" ) - try: - data = json.loads(proc.stdout) - except json.JSONDecodeError as exc: - raise RuntimeError(f"Failed to parse Claude CLI output: {exc}") + # --- Parse stream-json output --- + text_blocks: list[str] = [] + result_obj: dict | None = None - return data + for line in proc.stdout.splitlines(): + line = line.strip() + if not line: + continue + try: + obj = json.loads(line) + except json.JSONDecodeError: + continue + + msg_type = obj.get("type") + + if msg_type == "assistant": + # Extract text from content blocks + message = obj.get("message", {}) + for block in message.get("content", []): + if block.get("type") == "text": + text = block.get("text", "").strip() + if text: + text_blocks.append(text) + + elif msg_type == "result": + result_obj = obj + + if result_obj is None: + raise RuntimeError( + "Failed to parse Claude CLI output: no result line in stream" + ) + + # Build a dict compatible with the old json output format + combined_text = "\n\n".join(text_blocks) if text_blocks else result_obj.get("result", "") + + return { + "result": combined_text, + "session_id": result_obj.get("session_id", ""), + "usage": result_obj.get("usage", {}), + "total_cost_usd": result_obj.get("total_cost_usd", 0), + "cost_usd": result_obj.get("cost_usd", 0), + "duration_ms": result_obj.get("duration_ms", 0), + "num_turns": result_obj.get("num_turns", 0), + } # --------------------------------------------------------------------------- @@ -248,7 +291,7 @@ def start_session( cmd = [ CLAUDE_BIN, "-p", wrapped_message, "--model", model, - "--output-format", "json", + "--output-format", "stream-json", "--verbose", "--system-prompt", system_prompt, "--allowedTools", *ALLOWED_TOOLS, ] @@ -258,7 +301,7 @@ def start_session( _elapsed_ms = int((time.monotonic() - _t0) * 1000) for field in ("result", "session_id"): - if field not in data: + if not data.get(field): raise RuntimeError( f"Claude CLI response missing required field: {field}" ) @@ -317,7 +360,7 @@ def resume_session( cmd = [ CLAUDE_BIN, "-p", wrapped_message, "--resume", session_id, - "--output-format", "json", + "--output-format", "stream-json", "--verbose", "--allowedTools", *ALLOWED_TOOLS, ] @@ -325,7 +368,7 @@ def resume_session( data = _run_claude(cmd, timeout) _elapsed_ms = int((time.monotonic() - _t0) * 1000) - if "result" not in data: + if not data.get("result"): raise RuntimeError( "Claude CLI response missing required field: result" ) diff --git a/tests/test_claude_session.py b/tests/test_claude_session.py index 77f653d..4668a2e 100644 --- a/tests/test_claude_session.py +++ b/tests/test_claude_session.py @@ -28,21 +28,45 @@ from src.claude_session import ( # Helpers # --------------------------------------------------------------------------- -FAKE_CLI_RESPONSE = { +FAKE_RESULT_LINE = { "type": "result", "subtype": "success", "session_id": "sess-abc-123", "result": "Hello from Claude!", "cost_usd": 0.004, + "total_cost_usd": 0.004, "duration_ms": 1500, "num_turns": 1, + "usage": {"input_tokens": 100, "output_tokens": 50}, +} + +FAKE_ASSISTANT_LINE = { + "type": "assistant", + "message": { + "content": [{"type": "text", "text": "Hello from Claude!"}], + }, } +def _make_stream(*assistant_texts, result_override=None): + """Build stream-json stdout with assistant + result lines.""" + lines = [] + for text in assistant_texts: + lines.append(json.dumps({ + "type": "assistant", + "message": {"content": [{"type": "text", "text": text}]}, + })) + result = dict(FAKE_RESULT_LINE) + if result_override: + result.update(result_override) + lines.append(json.dumps(result)) + return "\n".join(lines) + + def _make_proc(stdout=None, returncode=0, stderr=""): - """Build a fake subprocess.CompletedProcess.""" + """Build a fake subprocess.CompletedProcess with stream-json output.""" if stdout is None: - stdout = json.dumps(FAKE_CLI_RESPONSE) + stdout = _make_stream("Hello from Claude!") proc = MagicMock(spec=subprocess.CompletedProcess) proc.stdout = stdout proc.stderr = stderr @@ -153,10 +177,20 @@ class TestSafeEnv: class TestRunClaude: @patch("shutil.which", return_value="/usr/bin/claude") @patch("subprocess.run") - def test_returns_parsed_json(self, mock_run, mock_which): + def test_returns_parsed_stream(self, mock_run, mock_which): mock_run.return_value = _make_proc() result = _run_claude(["claude", "-p", "hi"], timeout=30) - assert result == FAKE_CLI_RESPONSE + assert result["result"] == "Hello from Claude!" + assert result["session_id"] == "sess-abc-123" + assert "usage" in result + + @patch("shutil.which", return_value="/usr/bin/claude") + @patch("subprocess.run") + def test_collects_multiple_text_blocks(self, mock_run, mock_which): + stdout = _make_stream("First message", "Second message", "Third message") + mock_run.return_value = _make_proc(stdout=stdout) + result = _run_claude(["claude", "-p", "hi"], timeout=30) + assert result["result"] == "First message\n\nSecond message\n\nThird message" @patch("shutil.which", return_value="/usr/bin/claude") @patch("subprocess.run") @@ -176,9 +210,11 @@ class TestRunClaude: @patch("shutil.which", return_value="/usr/bin/claude") @patch("subprocess.run") - def test_invalid_json_raises(self, mock_run, mock_which): - mock_run.return_value = _make_proc(stdout="not json {{{") - with pytest.raises(RuntimeError, match="Failed to parse"): + def test_no_result_line_raises(self, mock_run, mock_which): + # Stream with only an assistant line but no result line + stdout = json.dumps({"type": "assistant", "message": {"content": []}}) + mock_run.return_value = _make_proc(stdout=stdout) + with pytest.raises(RuntimeError, match="no result line"): _run_claude(["claude", "-p", "hi"], timeout=30) @patch("shutil.which", return_value=None) @@ -299,7 +335,7 @@ class TestStartSession: @patch("shutil.which", return_value="/usr/bin/claude") @patch("subprocess.run") - def test_missing_result_field_raises( + def test_missing_result_line_raises( self, mock_run, mock_which, tmp_path, monkeypatch ): sessions_dir = tmp_path / "sessions" @@ -308,15 +344,16 @@ class TestStartSession: monkeypatch.setattr( claude_session, "_SESSIONS_FILE", sessions_dir / "active.json" ) - bad_response = {"session_id": "abc"} # missing "result" - mock_run.return_value = _make_proc(stdout=json.dumps(bad_response)) + # Stream with no result line at all + bad_stream = json.dumps({"type": "assistant", "message": {"content": []}}) + mock_run.return_value = _make_proc(stdout=bad_stream) - with pytest.raises(RuntimeError, match="missing required field"): + with pytest.raises(RuntimeError, match="no result line"): start_session("general", "Hello") @patch("shutil.which", return_value="/usr/bin/claude") @patch("subprocess.run") - def test_missing_session_id_field_raises( + def test_missing_session_id_gives_empty_string( self, mock_run, mock_which, tmp_path, monkeypatch ): sessions_dir = tmp_path / "sessions" @@ -325,8 +362,10 @@ class TestStartSession: monkeypatch.setattr( claude_session, "_SESSIONS_FILE", sessions_dir / "active.json" ) - bad_response = {"result": "hello"} # missing "session_id" - mock_run.return_value = _make_proc(stdout=json.dumps(bad_response)) + # Result line without session_id → _run_claude returns "" for session_id + # → start_session checks for empty session_id + bad_stream = _make_stream("hello", result_override={"session_id": None}) + mock_run.return_value = _make_proc(stdout=bad_stream) with pytest.raises(RuntimeError, match="missing required field"): start_session("general", "Hello")