@@ -27,8 +27,14 @@ SIGNAL_SUCCESS = 'success'
SIGNAL_APPLIED = 'already_applied'
SIGNAL_CONFLICT = 'conflict'
-# Maximum buffer size for agent responses
-MAX_BUFFER_SIZE = 10 * 1024 * 1024 # 10MB
+# Import common Claude agent utilities from shared module
+from u_boot_pylib.claude import (
+ AGENT_AVAILABLE, MAX_BUFFER_SIZE, check_available, run_agent_collect,
+)
+
+ClaudeAgentOptions = None
+if AGENT_AVAILABLE:
+ from u_boot_pylib.claude import ClaudeAgentOptions # pylint: disable=C0412
# Commits that need special handling (regenerate instead of cherry-pick)
# These run savedefconfig on all boards and depend on target branch
@@ -37,54 +43,6 @@ QCONFIG_SUBJECTS = [
'configs: Resync with savedefconfig',
]
-# Check if claude_agent_sdk is available
-try:
- from claude_agent_sdk import query, ClaudeAgentOptions
- AGENT_AVAILABLE = True
-except ImportError:
- AGENT_AVAILABLE = False
-
-
-def check_available():
- """Check if the Claude Agent SDK is available
-
- Returns:
- bool: True if available, False otherwise
- """
- if not AGENT_AVAILABLE:
- tout.error('Claude Agent SDK not available')
- tout.error('Install with: pip install claude-agent-sdk')
- return False
- return True
-
-
-async def run_agent_collect(prompt, options):
- """Run a Claude agent and collect its conversation log
-
- Sends the prompt to a Claude agent, streams output to stdout and
- collects all text blocks into a conversation log.
-
- Args:
- prompt (str): The prompt to send to the agent
- options (ClaudeAgentOptions): Agent configuration
-
- Returns:
- tuple: (success, conversation_log) where success is bool and
- conversation_log is the agent's output text
- """
- conversation_log = []
- try:
- async for message in query(prompt=prompt, options=options):
- if hasattr(message, 'content'):
- for block in message.content:
- if hasattr(block, 'text'):
- print(block.text)
- conversation_log.append(block.text)
- return True, '\n\n'.join(conversation_log)
- except (RuntimeError, ValueError, OSError) as exc:
- tout.error(f'Agent failed: {exc}')
- return False, '\n\n'.join(conversation_log)
-
def is_qconfig_commit(subject):
"""Check if a commit subject indicates a qconfig resync commit
@@ -3121,7 +3121,8 @@ class TestRunAgentCollect(unittest.TestCase):
async def fake_query(**kwargs):
yield msg
- with mock.patch.object(agent, 'query', fake_query, create=True):
+ with mock.patch('u_boot_pylib.claude.query', fake_query,
+ create=True):
with terminal.capture():
opts = mock.MagicMock()
success, log = asyncio.run(
@@ -3141,7 +3142,8 @@ class TestRunAgentCollect(unittest.TestCase):
yield msg
raise RuntimeError('agent crashed')
- with mock.patch.object(agent, 'query', fake_query, create=True):
+ with mock.patch('u_boot_pylib.claude.query', fake_query,
+ create=True):
with terminal.capture():
opts = mock.MagicMock()
success, log = asyncio.run(
@@ -3157,7 +3159,8 @@ class TestRunAgentCollect(unittest.TestCase):
async def fake_query(**kwargs):
yield msg
- with mock.patch.object(agent, 'query', fake_query, create=True):
+ with mock.patch('u_boot_pylib.claude.query', fake_query,
+ create=True):
with terminal.capture():
opts = mock.MagicMock()
success, log = asyncio.run(
@@ -6610,14 +6613,14 @@ class TestResolveSubtreeConflicts(unittest.TestCase):
"""Test successful conflict resolution."""
mock_collect = mock.AsyncMock(return_value=(True, 'resolved'))
with terminal.capture():
- with mock.patch.object(agent, 'AGENT_AVAILABLE', True):
- with mock.patch.object(agent, 'run_agent_collect',
- mock_collect):
- with mock.patch.object(agent, 'ClaudeAgentOptions',
- create=True):
- success, log = agent.resolve_subtree_conflicts(
- 'dts', 'v6.15-dts', 'dts/upstream',
- '/tmp/test')
+ with mock.patch('u_boot_pylib.claude.AGENT_AVAILABLE', True), \
+ mock.patch.object(agent, 'run_agent_collect',
+ mock_collect), \
+ mock.patch.object(agent, 'ClaudeAgentOptions',
+ create=True):
+ success, log = agent.resolve_subtree_conflicts(
+ 'dts', 'v6.15-dts', 'dts/upstream',
+ '/tmp/test')
self.assertTrue(success)
self.assertEqual(log, 'resolved')
@@ -6625,20 +6628,20 @@ class TestResolveSubtreeConflicts(unittest.TestCase):
"""Test failed conflict resolution."""
mock_collect = mock.AsyncMock(return_value=(False, 'failed'))
with terminal.capture():
- with mock.patch.object(agent, 'AGENT_AVAILABLE', True):
- with mock.patch.object(agent, 'run_agent_collect',
- mock_collect):
- with mock.patch.object(agent, 'ClaudeAgentOptions',
- create=True):
- success, log = agent.resolve_subtree_conflicts(
- 'dts', 'v6.15-dts', 'dts/upstream',
- '/tmp/test')
+ with mock.patch('u_boot_pylib.claude.AGENT_AVAILABLE', True), \
+ mock.patch.object(agent, 'run_agent_collect',
+ mock_collect), \
+ mock.patch.object(agent, 'ClaudeAgentOptions',
+ create=True):
+ success, log = agent.resolve_subtree_conflicts(
+ 'dts', 'v6.15-dts', 'dts/upstream',
+ '/tmp/test')
self.assertFalse(success)
def test_sdk_unavailable(self):
"""Test returns failure when SDK is not available."""
with terminal.capture():
- with mock.patch.object(agent, 'AGENT_AVAILABLE', False):
+ with mock.patch('u_boot_pylib.claude.AGENT_AVAILABLE', False):
success, log = agent.resolve_subtree_conflicts(
'dts', 'v6.15-dts', 'dts/upstream', '/tmp/test')
self.assertFalse(success)
@@ -1,4 +1,4 @@
# SPDX-License-Identifier: GPL-2.0+
-__all__ = ['command', 'cros_subprocess', 'gitutil', 'terminal', 'test_util',
- 'tools', 'tout']
+__all__ = ['claude', 'command', 'cros_subprocess', 'gitutil', 'terminal',
+ 'test_util', 'tools', 'tout']
@@ -28,12 +28,14 @@ def run_tests():
help='Verbose output')
args = parser.parse_args()
+ from u_boot_pylib import test_claude
+
to_run = args.testname if args.testname not in [None, 'test'] else None
result = test_util.run_test_suites(
'u_boot_pylib', False, args.verbose, False,
False, None, to_run, None,
['u_boot_pylib.terminal', 'u_boot_pylib.gitutil',
- cros_subprocess.TestSubprocess])
+ cros_subprocess.TestSubprocess, test_claude.TestClaude])
sys.exit(0 if result.wasSuccessful() else 1)
new file mode 100644
@@ -0,0 +1,64 @@
+# SPDX-License-Identifier: GPL-2.0+
+#
+# Copyright 2025 Canonical Ltd.
+# Written by Simon Glass <simon.glass@canonical.com>
+#
+
+"""Common Claude Agent SDK utilities.
+
+Provides shared functions for running Claude agents across tools that need
+AI assistance (e.g. pickman, patman review).
+"""
+
+from u_boot_pylib import tout
+
+# Maximum buffer size for agent responses
+MAX_BUFFER_SIZE = 10 * 1024 * 1024 # 10MB
+
+# Check if claude_agent_sdk is available
+try:
+ from claude_agent_sdk import query, ClaudeAgentOptions
+ AGENT_AVAILABLE = True
+except ImportError:
+ AGENT_AVAILABLE = False
+
+
+def check_available():
+ """Check if the Claude Agent SDK is available
+
+ Returns:
+ bool: True if available, False otherwise
+ """
+ if not AGENT_AVAILABLE:
+ tout.error('Claude Agent SDK not available')
+ tout.error('Install with: pip install claude-agent-sdk')
+ return False
+ return True
+
+
+async def run_agent_collect(prompt, options):
+ """Run a Claude agent and collect its conversation log
+
+ Sends the prompt to a Claude agent, streams output to stdout and
+ collects all text blocks into a conversation log.
+
+ Args:
+ prompt (str): The prompt to send to the agent
+ options (ClaudeAgentOptions): Agent configuration
+
+ Returns:
+ tuple: (success, conversation_log) where success is bool and
+ conversation_log is the agent's output text
+ """
+ conversation_log = []
+ try:
+ async for message in query(prompt=prompt, options=options):
+ if hasattr(message, 'content'):
+ for block in message.content:
+ if hasattr(block, 'text'):
+ print(block.text)
+ conversation_log.append(block.text)
+ return True, '\n\n'.join(conversation_log)
+ except (RuntimeError, ValueError, OSError) as exc:
+ tout.error(f'Agent failed: {exc}')
+ return False, '\n\n'.join(conversation_log)
new file mode 100644
@@ -0,0 +1,111 @@
+# SPDX-License-Identifier: GPL-2.0+
+#
+# Copyright 2025 Canonical Ltd.
+#
+
+"""Tests for the Claude Agent SDK utilities module."""
+
+import asyncio
+import unittest
+from unittest.mock import MagicMock
+
+from u_boot_pylib import claude
+from u_boot_pylib import terminal
+
+
+class TestClaude(unittest.TestCase):
+ """Tests for u_boot_pylib.claude"""
+
+ def test_check_available_when_sdk_missing(self):
+ """check_available() returns False when SDK is not installed"""
+ if not claude.AGENT_AVAILABLE:
+ with terminal.capture():
+ self.assertFalse(claude.check_available())
+
+ def test_check_available_when_sdk_present(self):
+ """check_available() returns True when SDK is installed"""
+ old = claude.AGENT_AVAILABLE
+ try:
+ claude.AGENT_AVAILABLE = True
+ self.assertTrue(claude.check_available())
+ finally:
+ claude.AGENT_AVAILABLE = old
+
+ def test_max_buffer_size(self):
+ """MAX_BUFFER_SIZE is defined and reasonable"""
+ self.assertEqual(claude.MAX_BUFFER_SIZE, 10 * 1024 * 1024)
+
+ def _setup_claude_with_mock_query(self, mock_query):
+ """Inject a mock query function into the claude module"""
+ claude.query = mock_query
+
+ def test_run_agent_collect_success(self):
+ """run_agent_collect() collects text from agent messages"""
+ block1 = MagicMock()
+ block1.text = 'Hello'
+ msg1 = MagicMock()
+ msg1.content = [block1]
+
+ block2 = MagicMock()
+ block2.text = 'World'
+ msg2 = MagicMock()
+ msg2.content = [block2]
+
+ # pylint: disable=W0613
+ async def mock_query(**kwargs):
+ for msg in [msg1, msg2]:
+ yield msg
+
+ self._setup_claude_with_mock_query(mock_query)
+ loop = asyncio.new_event_loop()
+ with terminal.capture():
+ success, log = loop.run_until_complete(
+ claude.run_agent_collect('test prompt', MagicMock()))
+ loop.close()
+
+ self.assertTrue(success)
+ self.assertIn('Hello', log)
+ self.assertIn('World', log)
+
+ def test_run_agent_collect_handles_error(self):
+ """run_agent_collect() returns False on agent failure"""
+ # pylint: disable=W0613
+ async def mock_query(**kwargs):
+ raise RuntimeError('Agent crashed')
+ yield # pylint: disable=W0101
+
+ self._setup_claude_with_mock_query(mock_query)
+ loop = asyncio.new_event_loop()
+ with terminal.capture():
+ success, _ = loop.run_until_complete(
+ claude.run_agent_collect('test prompt', MagicMock()))
+ loop.close()
+
+ self.assertFalse(success)
+
+ def test_run_agent_collect_skips_non_text_blocks(self):
+ """run_agent_collect() ignores blocks without text attribute"""
+ text_block = MagicMock()
+ text_block.text = 'Real text'
+ tool_block = MagicMock(spec=[]) # No text attribute
+
+ msg = MagicMock()
+ msg.content = [tool_block, text_block]
+
+ # pylint: disable=W0613
+ async def mock_query(**kwargs):
+ yield msg
+
+ self._setup_claude_with_mock_query(mock_query)
+ loop = asyncio.new_event_loop()
+ with terminal.capture():
+ success, log = loop.run_until_complete(
+ claude.run_agent_collect('test prompt', MagicMock()))
+ loop.close()
+
+ self.assertTrue(success)
+ self.assertIn('Real text', log)
+
+
+if __name__ == '__main__':
+ unittest.main()