[Concept,09/16] pickman: Create a function to run an agent

Message ID 20260222154303.2851319-10-sjg@u-boot.org
State New
Headers
Series pickman: Support monitoring and fixing pipeline failures |

Commit Message

Simon Glass Feb. 22, 2026, 3:42 p.m. UTC
  From: Simon Glass <simon.glass@canonical.com>

The agent-message-streaming pattern (async iteration, text extraction
and conversation-log collection) is duplicated in run() and
run_review_agent().

Extract it into a shared run_agent_collect() helper and add tests for it.

Co-developed-by: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Simon Glass <simon.glass@canonical.com>
---

 tools/pickman/agent.py | 55 ++++++++++++++++++++-----------------
 tools/pickman/ftest.py | 62 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 92 insertions(+), 25 deletions(-)
  

Patch

diff --git a/tools/pickman/agent.py b/tools/pickman/agent.py
index 85f8efee1df..63952c1c005 100644
--- a/tools/pickman/agent.py
+++ b/tools/pickman/agent.py
@@ -55,6 +55,34 @@  def check_available():
     return True
 
 
+async def run_agent_collect(prompt, options):
+    """Run a Claude agent and collect its conversation log
+
+    Sends the prompt to a Claude agent, streams output to stdout and
+    collects all text blocks into a conversation log.
+
+    Args:
+        prompt (str): The prompt to send to the agent
+        options (ClaudeAgentOptions): Agent configuration
+
+    Returns:
+        tuple: (success, conversation_log) where success is bool and
+            conversation_log is the agent's output text
+    """
+    conversation_log = []
+    try:
+        async for message in query(prompt=prompt, options=options):
+            if hasattr(message, 'content'):
+                for block in message.content:
+                    if hasattr(block, 'text'):
+                        print(block.text)
+                        conversation_log.append(block.text)
+        return True, '\n\n'.join(conversation_log)
+    except (RuntimeError, ValueError, OSError) as exc:
+        tout.error(f'Agent failed: {exc}')
+        return False, '\n\n'.join(conversation_log)
+
+
 def is_qconfig_commit(subject):
     """Check if a commit subject indicates a qconfig resync commit
 
@@ -228,19 +256,7 @@  this means the series was already applied via a different path. In this case:
     tout.info(f'Starting Claude agent to cherry-pick {len(commits)} commits...')
     tout.info('')
 
-    conversation_log = []
-    try:
-        async for message in query(prompt=prompt, options=options):
-            # Print agent output and capture it
-            if hasattr(message, 'content'):
-                for block in message.content:
-                    if hasattr(block, 'text'):
-                        print(block.text)
-                        conversation_log.append(block.text)
-        return True, '\n\n'.join(conversation_log)
-    except (RuntimeError, ValueError, OSError) as exc:
-        tout.error(f'Agent failed: {exc}')
-        return False, '\n\n'.join(conversation_log)
+    return await run_agent_collect(prompt, options)
 
 
 def read_signal_file(repo_path=None):
@@ -492,18 +508,7 @@  async def run_review_agent(mr_iid, branch_name, comments, remote,
     tout.info(f'Starting Claude agent to {task_desc}...')
     tout.info('')
 
-    conversation_log = []
-    try:
-        async for message in query(prompt=prompt, options=options):
-            if hasattr(message, 'content'):
-                for block in message.content:
-                    if hasattr(block, 'text'):
-                        print(block.text)
-                        conversation_log.append(block.text)
-        return True, '\n\n'.join(conversation_log)
-    except (RuntimeError, ValueError, OSError) as exc:
-        tout.error(f'Agent failed: {exc}')
-        return False, '\n\n'.join(conversation_log)
+    return await run_agent_collect(prompt, options)
 
 
 # pylint: disable=too-many-arguments
diff --git a/tools/pickman/ftest.py b/tools/pickman/ftest.py
index de6bce40614..42ce05962e2 100644
--- a/tools/pickman/ftest.py
+++ b/tools/pickman/ftest.py
@@ -6,6 +6,7 @@ 
 # pylint: disable=too-many-lines
 """Tests for pickman."""
 
+import asyncio
 import argparse
 import os
 import shutil
@@ -2971,6 +2972,67 @@  class TestExecuteApply(unittest.TestCase):
             dbs.close()
 
 
+class TestRunAgentCollect(unittest.TestCase):
+    """Tests for run_agent_collect function."""
+
+    def test_success(self):
+        """Test successful agent run collects text blocks."""
+        block1 = mock.MagicMock()
+        block1.text = 'hello'
+        block2 = mock.MagicMock()
+        block2.text = 'world'
+        msg = mock.MagicMock()
+        msg.content = [block1, block2]
+
+        async def fake_query(**kwargs):
+            yield msg
+
+        with mock.patch.object(agent, 'query', fake_query, create=True):
+            with terminal.capture():
+                opts = mock.MagicMock()
+                success, log = asyncio.run(
+                    agent.run_agent_collect('prompt', opts))
+
+        self.assertTrue(success)
+        self.assertEqual(log, 'hello\n\nworld')
+
+    def test_failure(self):
+        """Test agent failure returns False with partial log."""
+        block = mock.MagicMock()
+        block.text = 'partial'
+        msg = mock.MagicMock()
+        msg.content = [block]
+
+        async def fake_query(**kwargs):
+            yield msg
+            raise RuntimeError('agent crashed')
+
+        with mock.patch.object(agent, 'query', fake_query, create=True):
+            with terminal.capture():
+                opts = mock.MagicMock()
+                success, log = asyncio.run(
+                    agent.run_agent_collect('prompt', opts))
+
+        self.assertFalse(success)
+        self.assertEqual(log, 'partial')
+
+    def test_no_content(self):
+        """Test messages without content are skipped."""
+        msg = mock.MagicMock(spec=[])
+
+        async def fake_query(**kwargs):
+            yield msg
+
+        with mock.patch.object(agent, 'query', fake_query, create=True):
+            with terminal.capture():
+                opts = mock.MagicMock()
+                success, log = asyncio.run(
+                    agent.run_agent_collect('prompt', opts))
+
+        self.assertTrue(success)
+        self.assertEqual(log, '')
+
+
 class TestSignalFile(unittest.TestCase):
     """Tests for signal file handling."""