@@ -39,6 +39,11 @@ Commit = namedtuple('Commit', ['hash', 'short_hash', 'subject', 'date'])
CommitInfo = namedtuple('CommitInfo',
['hash', 'short_hash', 'subject', 'author'])
+# Named tuple for prepare_apply result
+ApplyInfo = namedtuple('ApplyInfo',
+ ['commits', 'branch_name', 'original_branch',
+ 'merge_found'])
+
def run_git(args):
"""Run a git command and return output."""
@@ -323,32 +328,37 @@ def write_history(source, commits, branch_name, conversation_log):
tout.info(f'Updated {HISTORY_FILE}')
-def do_apply(args, dbs): # pylint: disable=too-many-locals,too-many-branches
- """Apply the next set of commits using Claude agent
+def prepare_apply(dbs, source, branch):
+ """Prepare for applying commits from a source branch
+
+ Gets the next commits, sets up the branch name, and prints info about
+ what will be applied.
Args:
- args (Namespace): Parsed arguments with 'source' and 'branch' attributes
dbs (Database): Database instance
+ source (str): Source branch name
+ branch (str): Branch name to use, or None to auto-generate
Returns:
- int: 0 on success, 1 on failure
+ tuple: (ApplyInfo, return_code) where ApplyInfo is set if there are
+ commits to apply, or None with return_code indicating the result
+ (0 for no commits, 1 for error)
"""
- source = args.source
commits, merge_found, error = get_next_commits(dbs, source)
if error:
tout.error(error)
- return 1
+ return None, 1
if not commits:
tout.info('No new commits to cherry-pick')
- return 0
+ return None, 0
# Save current branch to return to later
original_branch = run_git(['rev-parse', '--abbrev-ref', 'HEAD'])
# Generate branch name if not provided
- branch_name = args.branch
+ branch_name = branch
if not branch_name:
# Use first commit's short hash as part of branch name
branch_name = f'cherry-{commits[0].short_hash}'
@@ -372,6 +382,28 @@ def do_apply(args, dbs): # pylint: disable=too-many-locals,too-many-branches
tout.info(f' {commit.short_hash} {commit.subject}')
tout.info('')
+ return ApplyInfo(commits, branch_name, original_branch, merge_found), 0
+
+
+def do_apply(args, dbs): # pylint: disable=too-many-locals,too-many-branches
+ """Apply the next set of commits using Claude agent
+
+ Args:
+ args (Namespace): Parsed arguments with 'source' and 'branch' attributes
+ dbs (Database): Database instance
+
+ Returns:
+ int: 0 on success, 1 on failure
+ """
+ source = args.source
+ info, ret = prepare_apply(dbs, source, args.branch)
+ if not info:
+ return ret
+
+ commits = info.commits
+ branch_name = info.branch_name
+ original_branch = info.original_branch
+
# Add commits to database with 'pending' status
source_id = dbs.source_get_id(source)
for commit in commits:
@@ -976,7 +976,7 @@ class TestApply(unittest.TestCase):
database.Database.instances.clear()
- args = argparse.Namespace(cmd='apply', source='unknown')
+ args = argparse.Namespace(cmd='apply', source='unknown', branch=None)
with terminal.capture() as (_, stderr):
ret = control.do_pickman(args)
self.assertEqual(ret, 1)
@@ -994,7 +994,7 @@ class TestApply(unittest.TestCase):
database.Database.instances.clear()
command.TEST_RESULT = command.CommandResult(stdout='')
- args = argparse.Namespace(cmd='apply', source='us/next')
+ args = argparse.Namespace(cmd='apply', source='us/next', branch=None)
with terminal.capture() as (stdout, _):
ret = control.do_pickman(args)
self.assertEqual(ret, 0)
@@ -1518,6 +1518,110 @@ Other content
self.assertIn('- ccc333c Third commit', commit_msg)
+class TestPrepareApply(unittest.TestCase):
+ """Tests for prepare_apply function."""
+
+ def setUp(self):
+ """Set up test fixtures."""
+ fd, self.db_path = tempfile.mkstemp(suffix='.db')
+ os.close(fd)
+ os.unlink(self.db_path)
+ self.old_db_fname = control.DB_FNAME
+ control.DB_FNAME = self.db_path
+ database.Database.instances.clear()
+
+ def tearDown(self):
+ """Clean up test fixtures."""
+ control.DB_FNAME = self.old_db_fname
+ if os.path.exists(self.db_path):
+ os.unlink(self.db_path)
+ database.Database.instances.clear()
+ command.TEST_RESULT = None
+
+ def test_prepare_apply_error(self):
+ """Test prepare_apply returns error code 1 on source not found."""
+ with terminal.capture():
+ dbs = database.Database(self.db_path)
+ dbs.start()
+
+ info, ret = control.prepare_apply(dbs, 'unknown', None)
+
+ self.assertIsNone(info)
+ self.assertEqual(ret, 1)
+ dbs.close()
+
+ def test_prepare_apply_no_commits(self):
+ """Test prepare_apply returns code 0 when no commits."""
+ with terminal.capture():
+ dbs = database.Database(self.db_path)
+ dbs.start()
+ dbs.source_set('us/next', 'abc123')
+ dbs.commit()
+
+ command.TEST_RESULT = command.CommandResult(stdout='')
+
+ info, ret = control.prepare_apply(dbs, 'us/next', None)
+
+ self.assertIsNone(info)
+ self.assertEqual(ret, 0)
+ dbs.close()
+
+ def test_prepare_apply_with_commits(self):
+ """Test prepare_apply returns ApplyInfo with commits."""
+ with terminal.capture():
+ dbs = database.Database(self.db_path)
+ dbs.start()
+ dbs.source_set('us/next', 'abc123')
+ dbs.commit()
+
+ log_output = 'aaa111|aaa111a|Author 1|First commit|abc123\n'
+
+ def mock_git(pipe_list):
+ cmd = pipe_list[0] if pipe_list else []
+ if 'log' in cmd:
+ return command.CommandResult(stdout=log_output)
+ if 'rev-parse' in cmd:
+ return command.CommandResult(stdout='master')
+ return command.CommandResult(stdout='')
+
+ command.TEST_RESULT = mock_git
+
+ info, ret = control.prepare_apply(dbs, 'us/next', None)
+
+ self.assertIsNotNone(info)
+ self.assertEqual(ret, 0)
+ self.assertEqual(len(info.commits), 1)
+ self.assertEqual(info.branch_name, 'cherry-aaa111a')
+ self.assertEqual(info.original_branch, 'master')
+ dbs.close()
+
+ def test_prepare_apply_custom_branch(self):
+ """Test prepare_apply uses custom branch name."""
+ with terminal.capture():
+ dbs = database.Database(self.db_path)
+ dbs.start()
+ dbs.source_set('us/next', 'abc123')
+ dbs.commit()
+
+ log_output = 'aaa111|aaa111a|Author 1|First commit|abc123\n'
+
+ def mock_git(pipe_list):
+ cmd = pipe_list[0] if pipe_list else []
+ if 'log' in cmd:
+ return command.CommandResult(stdout=log_output)
+ if 'rev-parse' in cmd:
+ return command.CommandResult(stdout='master')
+ return command.CommandResult(stdout='')
+
+ command.TEST_RESULT = mock_git
+
+ info, _ = control.prepare_apply(dbs, 'us/next', 'my-branch')
+
+ self.assertIsNotNone(info)
+ self.assertEqual(info.branch_name, 'my-branch')
+ dbs.close()
+
+
class TestGetNextCommitsEmptyLine(unittest.TestCase):
"""Tests for get_next_commits with empty lines."""