[Concept,09/24] pickman: Extract execute_apply() from do_apply()

Message ID 20251217022823.392557-10-sjg@u-boot.org
State New
Headers
Series pickman: Refine the feature set |

Commit Message

Simon Glass Dec. 17, 2025, 2:27 a.m. UTC
  From: Simon Glass <simon.glass@canonical.com>

Extract the core 'apply' logic into execute_apply() which handles
database operations, agent invocation, and MR creation. This separates
the non-git logic from do_apply() which now handles setup, history
writing, and branch restoration.

Also move branch restoration to the end of do_apply() using a ret
variable to ensure we always restore the branch even if MR creation
fails.

Add tests for execute_apply() covering success, failure, push, and
push failure cases.

Rename conversation_log to conv_log for brevity.

Co-developed-by: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: Simon Glass <simon.glass@canonical.com>
---

 tools/pickman/control.py |  86 +++++++++++++++++-----------
 tools/pickman/ftest.py   | 120 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 174 insertions(+), 32 deletions(-)
  

Patch

diff --git a/tools/pickman/control.py b/tools/pickman/control.py
index aa3fd65ec8e..973db22e106 100644
--- a/tools/pickman/control.py
+++ b/tools/pickman/control.py
@@ -263,7 +263,7 @@  Commits:
 {commit_list}"""
 
 
-def get_history(fname, source, commits, branch_name, conversation_log):
+def get_history(fname, source, commits, branch_name, conv_log):
     """Read, update and write history file for a cherry-pick operation
 
     Args:
@@ -271,7 +271,7 @@  def get_history(fname, source, commits, branch_name, conversation_log):
         source (str): Source branch name
         commits (list): list of CommitInfo tuples
         branch_name (str): Name of the cherry-pick branch
-        conversation_log (str): The agent's conversation output
+        conv_log (str): The agent's conversation output
 
     Returns:
         tuple: (content, commit_msg) where content is the updated history
@@ -281,7 +281,7 @@  def get_history(fname, source, commits, branch_name, conversation_log):
     entry = f"""{summary}
 
 ### Conversation log
-{conversation_log}
+{conv_log}
 
 ---
 
@@ -309,17 +309,17 @@  def get_history(fname, source, commits, branch_name, conversation_log):
     return content, commit_msg
 
 
-def write_history(source, commits, branch_name, conversation_log):
+def write_history(source, commits, branch_name, conv_log):
     """Write an entry to the pickman history file and commit it
 
     Args:
         source (str): Source branch name
         commits (list): list of CommitInfo tuples
         branch_name (str): Name of the cherry-pick branch
-        conversation_log (str): The agent's conversation output
+        conv_log (str): The agent's conversation output
     """
     _, commit_msg = get_history(HISTORY_FILE, source, commits, branch_name,
-                                conversation_log)
+                                conv_log)
 
     # Commit the history file (use -f in case .gitignore patterns match)
     run_git(['add', '-f', HISTORY_FILE])
@@ -385,25 +385,20 @@  def prepare_apply(dbs, source, branch):
     return ApplyInfo(commits, branch_name, original_branch, merge_found), 0
 
 
-def do_apply(args, dbs):  # pylint: disable=too-many-locals,too-many-branches
-    """Apply the next set of commits using Claude agent
+def execute_apply(dbs, source, commits, branch_name, args):  # pylint: disable=too-many-locals
+    """Execute the apply operation: run agent, update database, push MR
 
     Args:
-        args (Namespace): Parsed arguments with 'source' and 'branch' attributes
         dbs (Database): Database instance
+        source (str): Source branch name
+        commits (list): List of CommitInfo namedtuples
+        branch_name (str): Branch name for cherry-picks
+        args (Namespace): Parsed arguments with 'push', 'remote', 'target'
 
     Returns:
-        int: 0 on success, 1 on failure
+        tuple: (ret, success, conv_log) where ret is 0 on success,
+            1 on failure
     """
-    source = args.source
-    info, ret = prepare_apply(dbs, source, args.branch)
-    if not info:
-        return ret
-
-    commits = info.commits
-    branch_name = info.branch_name
-    original_branch = info.original_branch
-
     # Add commits to database with 'pending' status
     source_id = dbs.source_get_id(source)
     for commit in commits:
@@ -413,7 +408,7 @@  def do_apply(args, dbs):  # pylint: disable=too-many-locals,too-many-branches
 
     # Convert CommitInfo to tuple format expected by agent
     commit_tuples = [(c.hash, c.short_hash, c.subject) for c in commits]
-    success, conversation_log = agent.cherry_pick_commits(commit_tuples, source,
+    success, conv_log = agent.cherry_pick_commits(commit_tuples, source,
                                                           branch_name)
 
     # Update commit status based on result
@@ -422,15 +417,7 @@  def do_apply(args, dbs):  # pylint: disable=too-many-locals,too-many-branches
         dbs.commit_set_status(commit.hash, status)
     dbs.commit()
 
-    # Write history file if successful
-    if success:
-        write_history(source, commits, branch_name, conversation_log)
-
-    # Return to original branch
-    current_branch = run_git(['rev-parse', '--abbrev-ref', 'HEAD'])
-    if current_branch != original_branch:
-        tout.info(f'Returning to {original_branch}')
-        run_git(['checkout', original_branch])
+    ret = 0 if success else 1
 
     if success:
         # Push and create MR if requested
@@ -441,18 +428,53 @@  def do_apply(args, dbs):  # pylint: disable=too-many-locals,too-many-branches
             title = f'[pickman] {commits[-1].subject}'
             # Description matches .pickman-history entry (summary + conversation)
             summary = format_history_summary(source, commits, branch_name)
-            description = f'{summary}\n\n### Conversation log\n{conversation_log}'
+            description = f'{summary}\n\n### Conversation log\n{conv_log}'
 
             mr_url = gitlab_api.push_and_create_mr(
                 remote, branch_name, target, title, description
             )
             if not mr_url:
-                return 1
+                ret = 1
         else:
             tout.info(f"Use 'pickman commit-source {source} "
                       f"{commits[-1].short_hash}' to update the database")
 
-    return 0 if success else 1
+    return ret, success, conv_log
+
+
+def do_apply(args, dbs):
+    """Apply the next set of commits using Claude agent
+
+    Args:
+        args (Namespace): Parsed arguments with 'source' and 'branch' attributes
+        dbs (Database): Database instance
+
+    Returns:
+        int: 0 on success, 1 on failure
+    """
+    source = args.source
+    info, ret = prepare_apply(dbs, source, args.branch)
+    if not info:
+        return ret
+
+    commits = info.commits
+    branch_name = info.branch_name
+    original_branch = info.original_branch
+
+    ret, success, conv_log = execute_apply(dbs, source, commits,
+                                                   branch_name, args)
+
+    # Write history file if successful
+    if success:
+        write_history(source, commits, branch_name, conv_log)
+
+    # Return to original branch
+    current_branch = run_git(['rev-parse', '--abbrev-ref', 'HEAD'])
+    if current_branch != original_branch:
+        tout.info(f'Returning to {original_branch}')
+        run_git(['checkout', original_branch])
+
+    return ret
 
 
 def do_commit_source(args, dbs):
diff --git a/tools/pickman/ftest.py b/tools/pickman/ftest.py
index 1418211d4ae..9c65652f27a 100644
--- a/tools/pickman/ftest.py
+++ b/tools/pickman/ftest.py
@@ -1622,6 +1622,126 @@  class TestPrepareApply(unittest.TestCase):
             dbs.close()
 
 
+class TestExecuteApply(unittest.TestCase):
+    """Tests for execute_apply function."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        fd, self.db_path = tempfile.mkstemp(suffix='.db')
+        os.close(fd)
+        os.unlink(self.db_path)
+        self.old_db_fname = control.DB_FNAME
+        control.DB_FNAME = self.db_path
+        database.Database.instances.clear()
+
+    def tearDown(self):
+        """Clean up test fixtures."""
+        control.DB_FNAME = self.old_db_fname
+        if os.path.exists(self.db_path):
+            os.unlink(self.db_path)
+        database.Database.instances.clear()
+
+    def test_execute_apply_success(self):
+        """Test execute_apply with successful cherry-pick."""
+        with terminal.capture():
+            dbs = database.Database(self.db_path)
+            dbs.start()
+            dbs.source_set('us/next', 'abc123')
+            dbs.commit()
+
+            commits = [control.CommitInfo('aaa111', 'aaa111a', 'Test commit',
+                                          'Author')]
+            args = argparse.Namespace(push=False)
+
+            with mock.patch.object(control.agent, 'cherry_pick_commits',
+                                   return_value=(True, 'conversation log')):
+                ret, success, conv_log = control.execute_apply(
+                    dbs, 'us/next', commits, 'cherry-branch', args)
+
+            self.assertEqual(ret, 0)
+            self.assertTrue(success)
+            self.assertEqual(conv_log, 'conversation log')
+
+            # Check commit was added to database
+            commit_rec = dbs.commit_get('aaa111')
+            self.assertIsNotNone(commit_rec)
+            self.assertEqual(commit_rec[6], 'applied')  # status field
+            dbs.close()
+
+    def test_execute_apply_failure(self):
+        """Test execute_apply with failed cherry-pick."""
+        with terminal.capture():
+            dbs = database.Database(self.db_path)
+            dbs.start()
+            dbs.source_set('us/next', 'abc123')
+            dbs.commit()
+
+            commits = [control.CommitInfo('bbb222', 'bbb222b', 'Test commit',
+                                          'Author')]
+            args = argparse.Namespace(push=False)
+
+            with mock.patch.object(control.agent, 'cherry_pick_commits',
+                                   return_value=(False, 'error log')):
+                ret, success, _ = control.execute_apply(
+                    dbs, 'us/next', commits, 'cherry-branch', args)
+
+            self.assertEqual(ret, 1)
+            self.assertFalse(success)
+
+            # Check commit status is conflict
+            commit_rec = dbs.commit_get('bbb222')
+            self.assertEqual(commit_rec[6], 'conflict')
+            dbs.close()
+
+    def test_execute_apply_with_push(self):
+        """Test execute_apply with push enabled."""
+        with terminal.capture():
+            dbs = database.Database(self.db_path)
+            dbs.start()
+            dbs.source_set('us/next', 'abc123')
+            dbs.commit()
+
+            commits = [control.CommitInfo('ccc333', 'ccc333c', 'Test commit',
+                                          'Author')]
+            args = argparse.Namespace(push=True, remote='origin',
+                                      target='main')
+
+            with mock.patch.object(control.agent, 'cherry_pick_commits',
+                                   return_value=(True, 'log')):
+                with mock.patch.object(gitlab_api, 'push_and_create_mr',
+                                       return_value='https://mr/url'):
+                    ret, success, _ = control.execute_apply(
+                        dbs, 'us/next', commits, 'cherry-branch', args)
+
+            self.assertEqual(ret, 0)
+            self.assertTrue(success)
+            dbs.close()
+
+    def test_execute_apply_push_fails(self):
+        """Test execute_apply when MR creation fails."""
+        with terminal.capture():
+            dbs = database.Database(self.db_path)
+            dbs.start()
+            dbs.source_set('us/next', 'abc123')
+            dbs.commit()
+
+            commits = [control.CommitInfo('ddd444', 'ddd444d', 'Test commit',
+                                          'Author')]
+            args = argparse.Namespace(push=True, remote='origin',
+                                      target='main')
+
+            with mock.patch.object(control.agent, 'cherry_pick_commits',
+                                   return_value=(True, 'log')):
+                with mock.patch.object(gitlab_api, 'push_and_create_mr',
+                                       return_value=None):
+                    ret, success, _ = control.execute_apply(
+                        dbs, 'us/next', commits, 'cherry-branch', args)
+
+            self.assertEqual(ret, 1)
+            self.assertTrue(success)  # cherry-pick succeeded, MR failed
+            dbs.close()
+
+
 class TestGetNextCommitsEmptyLine(unittest.TestCase):
     """Tests for get_next_commits with empty lines."""