@@ -55,6 +55,34 @@ def check_available():
return True
+async def run_agent_collect(prompt, options):
+ """Run a Claude agent and collect its conversation log
+
+ Sends the prompt to a Claude agent, streams output to stdout and
+ collects all text blocks into a conversation log.
+
+ Args:
+ prompt (str): The prompt to send to the agent
+ options (ClaudeAgentOptions): Agent configuration
+
+ Returns:
+ tuple: (success, conversation_log) where success is bool and
+ conversation_log is the agent's output text
+ """
+ conversation_log = []
+ try:
+ async for message in query(prompt=prompt, options=options):
+ if hasattr(message, 'content'):
+ for block in message.content:
+ if hasattr(block, 'text'):
+ print(block.text)
+ conversation_log.append(block.text)
+ return True, '\n\n'.join(conversation_log)
+ except (RuntimeError, ValueError, OSError) as exc:
+ tout.error(f'Agent failed: {exc}')
+ return False, '\n\n'.join(conversation_log)
+
+
def is_qconfig_commit(subject):
"""Check if a commit subject indicates a qconfig resync commit
@@ -228,19 +256,7 @@ this means the series was already applied via a different path. In this case:
tout.info(f'Starting Claude agent to cherry-pick {len(commits)} commits...')
tout.info('')
- conversation_log = []
- try:
- async for message in query(prompt=prompt, options=options):
- # Print agent output and capture it
- if hasattr(message, 'content'):
- for block in message.content:
- if hasattr(block, 'text'):
- print(block.text)
- conversation_log.append(block.text)
- return True, '\n\n'.join(conversation_log)
- except (RuntimeError, ValueError, OSError) as exc:
- tout.error(f'Agent failed: {exc}')
- return False, '\n\n'.join(conversation_log)
+ return await run_agent_collect(prompt, options)
def read_signal_file(repo_path=None):
@@ -492,18 +508,7 @@ async def run_review_agent(mr_iid, branch_name, comments, remote,
tout.info(f'Starting Claude agent to {task_desc}...')
tout.info('')
- conversation_log = []
- try:
- async for message in query(prompt=prompt, options=options):
- if hasattr(message, 'content'):
- for block in message.content:
- if hasattr(block, 'text'):
- print(block.text)
- conversation_log.append(block.text)
- return True, '\n\n'.join(conversation_log)
- except (RuntimeError, ValueError, OSError) as exc:
- tout.error(f'Agent failed: {exc}')
- return False, '\n\n'.join(conversation_log)
+ return await run_agent_collect(prompt, options)
# pylint: disable=too-many-arguments
@@ -6,6 +6,7 @@
# pylint: disable=too-many-lines
"""Tests for pickman."""
+import asyncio
import argparse
import os
import shutil
@@ -2971,6 +2972,67 @@ class TestExecuteApply(unittest.TestCase):
dbs.close()
+class TestRunAgentCollect(unittest.TestCase):
+ """Tests for run_agent_collect function."""
+
+ def test_success(self):
+ """Test successful agent run collects text blocks."""
+ block1 = mock.MagicMock()
+ block1.text = 'hello'
+ block2 = mock.MagicMock()
+ block2.text = 'world'
+ msg = mock.MagicMock()
+ msg.content = [block1, block2]
+
+ async def fake_query(**kwargs):
+ yield msg
+
+ with mock.patch.object(agent, 'query', fake_query, create=True):
+ with terminal.capture():
+ opts = mock.MagicMock()
+ success, log = asyncio.run(
+ agent.run_agent_collect('prompt', opts))
+
+ self.assertTrue(success)
+ self.assertEqual(log, 'hello\n\nworld')
+
+ def test_failure(self):
+ """Test agent failure returns False with partial log."""
+ block = mock.MagicMock()
+ block.text = 'partial'
+ msg = mock.MagicMock()
+ msg.content = [block]
+
+ async def fake_query(**kwargs):
+ yield msg
+ raise RuntimeError('agent crashed')
+
+ with mock.patch.object(agent, 'query', fake_query, create=True):
+ with terminal.capture():
+ opts = mock.MagicMock()
+ success, log = asyncio.run(
+ agent.run_agent_collect('prompt', opts))
+
+ self.assertFalse(success)
+ self.assertEqual(log, 'partial')
+
+ def test_no_content(self):
+ """Test messages without content are skipped."""
+ msg = mock.MagicMock(spec=[])
+
+ async def fake_query(**kwargs):
+ yield msg
+
+ with mock.patch.object(agent, 'query', fake_query, create=True):
+ with terminal.capture():
+ opts = mock.MagicMock()
+ success, log = asyncio.run(
+ agent.run_agent_collect('prompt', opts))
+
+ self.assertTrue(success)
+ self.assertEqual(log, '')
+
+
class TestSignalFile(unittest.TestCase):
"""Tests for signal file handling."""