[Concept,08/24] pickman: Extract prepare_apply() from do_apply()

Message ID 20251217022823.392557-9-sjg@u-boot.org
State New
Headers
Series pickman: Refine the feature set |

Commit Message

Simon Glass Dec. 17, 2025, 2:27 a.m. UTC
  From: Simon Glass <simon.glass@canonical.com>

Extract the setup logic from do_apply() into prepare_apply() which:
- Gets the next commits from the source branch
- Validates the source exists and has commits
- Saves the original branch
- Generates or uses provided branch name
- Deletes existing branch if needed
- Prints info about what will be applied

Returns (ApplyInfo, return_code) tuple where ApplyInfo contains commits,
branch_name, original_branch, and merge_found. This makes the setup
logic testable independently from the agent and git operations.

Add tests for prepare_apply():
- test_prepare_apply_error: Test error handling for unknown source
- test_prepare_apply_no_commits: Test when no commits to apply
- test_prepare_apply_with_commits: Test successful preparation
- test_prepare_apply_custom_branch: Test custom branch name

Co-developed-by: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: Simon Glass <simon.glass@canonical.com>
---

 tools/pickman/control.py |  48 ++++++++++++++---
 tools/pickman/ftest.py   | 108 ++++++++++++++++++++++++++++++++++++++-
 2 files changed, 146 insertions(+), 10 deletions(-)
  

Patch

diff --git a/tools/pickman/control.py b/tools/pickman/control.py
index 190f92fc57a..aa3fd65ec8e 100644
--- a/tools/pickman/control.py
+++ b/tools/pickman/control.py
@@ -39,6 +39,11 @@  Commit = namedtuple('Commit', ['hash', 'short_hash', 'subject', 'date'])
 CommitInfo = namedtuple('CommitInfo',
                         ['hash', 'short_hash', 'subject', 'author'])
 
+# Named tuple for prepare_apply result
+ApplyInfo = namedtuple('ApplyInfo',
+                       ['commits', 'branch_name', 'original_branch',
+                        'merge_found'])
+
 
 def run_git(args):
     """Run a git command and return output."""
@@ -323,32 +328,37 @@  def write_history(source, commits, branch_name, conversation_log):
     tout.info(f'Updated {HISTORY_FILE}')
 
 
-def do_apply(args, dbs):  # pylint: disable=too-many-locals,too-many-branches
-    """Apply the next set of commits using Claude agent
+def prepare_apply(dbs, source, branch):
+    """Prepare for applying commits from a source branch
+
+    Gets the next commits, sets up the branch name, and prints info about
+    what will be applied.
 
     Args:
-        args (Namespace): Parsed arguments with 'source' and 'branch' attributes
         dbs (Database): Database instance
+        source (str): Source branch name
+        branch (str): Branch name to use, or None to auto-generate
 
     Returns:
-        int: 0 on success, 1 on failure
+        tuple: (ApplyInfo, return_code) where ApplyInfo is set if there are
+            commits to apply, or None with return_code indicating the result
+            (0 for no commits, 1 for error)
     """
-    source = args.source
     commits, merge_found, error = get_next_commits(dbs, source)
 
     if error:
         tout.error(error)
-        return 1
+        return None, 1
 
     if not commits:
         tout.info('No new commits to cherry-pick')
-        return 0
+        return None, 0
 
     # Save current branch to return to later
     original_branch = run_git(['rev-parse', '--abbrev-ref', 'HEAD'])
 
     # Generate branch name if not provided
-    branch_name = args.branch
+    branch_name = branch
     if not branch_name:
         # Use first commit's short hash as part of branch name
         branch_name = f'cherry-{commits[0].short_hash}'
@@ -372,6 +382,28 @@  def do_apply(args, dbs):  # pylint: disable=too-many-locals,too-many-branches
         tout.info(f'  {commit.short_hash} {commit.subject}')
     tout.info('')
 
+    return ApplyInfo(commits, branch_name, original_branch, merge_found), 0
+
+
+def do_apply(args, dbs):  # pylint: disable=too-many-locals,too-many-branches
+    """Apply the next set of commits using Claude agent
+
+    Args:
+        args (Namespace): Parsed arguments with 'source' and 'branch' attributes
+        dbs (Database): Database instance
+
+    Returns:
+        int: 0 on success, 1 on failure
+    """
+    source = args.source
+    info, ret = prepare_apply(dbs, source, args.branch)
+    if not info:
+        return ret
+
+    commits = info.commits
+    branch_name = info.branch_name
+    original_branch = info.original_branch
+
     # Add commits to database with 'pending' status
     source_id = dbs.source_get_id(source)
     for commit in commits:
diff --git a/tools/pickman/ftest.py b/tools/pickman/ftest.py
index d50be16fea7..1418211d4ae 100644
--- a/tools/pickman/ftest.py
+++ b/tools/pickman/ftest.py
@@ -976,7 +976,7 @@  class TestApply(unittest.TestCase):
 
         database.Database.instances.clear()
 
-        args = argparse.Namespace(cmd='apply', source='unknown')
+        args = argparse.Namespace(cmd='apply', source='unknown', branch=None)
         with terminal.capture() as (_, stderr):
             ret = control.do_pickman(args)
         self.assertEqual(ret, 1)
@@ -994,7 +994,7 @@  class TestApply(unittest.TestCase):
         database.Database.instances.clear()
         command.TEST_RESULT = command.CommandResult(stdout='')
 
-        args = argparse.Namespace(cmd='apply', source='us/next')
+        args = argparse.Namespace(cmd='apply', source='us/next', branch=None)
         with terminal.capture() as (stdout, _):
             ret = control.do_pickman(args)
         self.assertEqual(ret, 0)
@@ -1518,6 +1518,110 @@  Other content
         self.assertIn('- ccc333c Third commit', commit_msg)
 
 
+class TestPrepareApply(unittest.TestCase):
+    """Tests for prepare_apply function."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        fd, self.db_path = tempfile.mkstemp(suffix='.db')
+        os.close(fd)
+        os.unlink(self.db_path)
+        self.old_db_fname = control.DB_FNAME
+        control.DB_FNAME = self.db_path
+        database.Database.instances.clear()
+
+    def tearDown(self):
+        """Clean up test fixtures."""
+        control.DB_FNAME = self.old_db_fname
+        if os.path.exists(self.db_path):
+            os.unlink(self.db_path)
+        database.Database.instances.clear()
+        command.TEST_RESULT = None
+
+    def test_prepare_apply_error(self):
+        """Test prepare_apply returns error code 1 on source not found."""
+        with terminal.capture():
+            dbs = database.Database(self.db_path)
+            dbs.start()
+
+            info, ret = control.prepare_apply(dbs, 'unknown', None)
+
+            self.assertIsNone(info)
+            self.assertEqual(ret, 1)
+            dbs.close()
+
+    def test_prepare_apply_no_commits(self):
+        """Test prepare_apply returns code 0 when no commits."""
+        with terminal.capture():
+            dbs = database.Database(self.db_path)
+            dbs.start()
+            dbs.source_set('us/next', 'abc123')
+            dbs.commit()
+
+            command.TEST_RESULT = command.CommandResult(stdout='')
+
+            info, ret = control.prepare_apply(dbs, 'us/next', None)
+
+            self.assertIsNone(info)
+            self.assertEqual(ret, 0)
+            dbs.close()
+
+    def test_prepare_apply_with_commits(self):
+        """Test prepare_apply returns ApplyInfo with commits."""
+        with terminal.capture():
+            dbs = database.Database(self.db_path)
+            dbs.start()
+            dbs.source_set('us/next', 'abc123')
+            dbs.commit()
+
+            log_output = 'aaa111|aaa111a|Author 1|First commit|abc123\n'
+
+            def mock_git(pipe_list):
+                cmd = pipe_list[0] if pipe_list else []
+                if 'log' in cmd:
+                    return command.CommandResult(stdout=log_output)
+                if 'rev-parse' in cmd:
+                    return command.CommandResult(stdout='master')
+                return command.CommandResult(stdout='')
+
+            command.TEST_RESULT = mock_git
+
+            info, ret = control.prepare_apply(dbs, 'us/next', None)
+
+            self.assertIsNotNone(info)
+            self.assertEqual(ret, 0)
+            self.assertEqual(len(info.commits), 1)
+            self.assertEqual(info.branch_name, 'cherry-aaa111a')
+            self.assertEqual(info.original_branch, 'master')
+            dbs.close()
+
+    def test_prepare_apply_custom_branch(self):
+        """Test prepare_apply uses custom branch name."""
+        with terminal.capture():
+            dbs = database.Database(self.db_path)
+            dbs.start()
+            dbs.source_set('us/next', 'abc123')
+            dbs.commit()
+
+            log_output = 'aaa111|aaa111a|Author 1|First commit|abc123\n'
+
+            def mock_git(pipe_list):
+                cmd = pipe_list[0] if pipe_list else []
+                if 'log' in cmd:
+                    return command.CommandResult(stdout=log_output)
+                if 'rev-parse' in cmd:
+                    return command.CommandResult(stdout='master')
+                return command.CommandResult(stdout='')
+
+            command.TEST_RESULT = mock_git
+
+            info, _ = control.prepare_apply(dbs, 'us/next', 'my-branch')
+
+            self.assertIsNotNone(info)
+            self.assertEqual(info.branch_name, 'my-branch')
+            dbs.close()
+
+
 class TestGetNextCommitsEmptyLine(unittest.TestCase):
     """Tests for get_next_commits with empty lines."""