[Concept,8/9] pickman: Decompose mega-merges into sub-merge batches

Message ID 20260212211626.167191-9-sjg@u-boot.org
State New
Headers
Series pickman: Improve handling of large merges and add rewind |

Commit Message

Simon Glass Feb. 12, 2026, 9:16 p.m. UTC
  From: Simon Glass <simon.glass@canonical.com>

When get_next_commits() encounters a large merge on the first-parent
chain (e.g., "Merge branch 'next'"), it currently collects ALL commits
from prev_commit..merge_hash at once. For mega-merges containing many
sub-merges, this can produce hundreds of commits in a single batch,
overwhelming the agent.

Add detect_sub_merges() to check if a merge commit contains sub-merges
on its second parent's first-parent chain. Add decompose_mega_merge()
to return one sub-merge batch at a time across multiple runs, handling
three phases: mainline commits, individual sub-merge batches, and
remainder commits.

Extract find_unprocessed_commits() from get_next_commits() to walk the
merge list and find the first with unprocessed commits.

Change get_next_commits() to return a 4-tuple with an advance_to field.
When advance_to is a hash, the caller advances the source position to
that hash. When advance_to is None, the source stays put (sub-merge
batch, to be continued next run). Update ApplyInfo, prepare_apply(),
execute_apply(), handle_already_applied(), do_apply(), and do_next_set()
to thread advance_to through the call chain.

Co-developed-by: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Simon Glass <simon.glass@canonical.com>
---

 tools/pickman/control.py | 208 ++++++++++++++++--
 tools/pickman/ftest.py   | 461 ++++++++++++++++++++++++++++++++++++++-
 2 files changed, 640 insertions(+), 29 deletions(-)
  

Patch

diff --git a/tools/pickman/control.py b/tools/pickman/control.py
index 0418fbe2da6..2e6f49f1816 100644
--- a/tools/pickman/control.py
+++ b/tools/pickman/control.py
@@ -91,8 +91,9 @@  AgentCommit = namedtuple('AgentCommit',
 #
 # commits: list of CommitInfo to cherry-pick
 # merge_found: True if these commits came from a merge on the source branch
+# advance_to: hash to advance the source position to, or None to stay put
 NextCommitsInfo = namedtuple('NextCommitsInfo',
-                             ['commits', 'merge_found'])
+                             ['commits', 'merge_found', 'advance_to'])
 
 # Named tuple for prepare_apply() result
 #
@@ -100,9 +101,10 @@  NextCommitsInfo = namedtuple('NextCommitsInfo',
 # branch_name: name of the branch to create for the MR
 # original_branch: branch name before any conflict suffix
 # merge_found: True if these commits came from a merge on the source branch
+# advance_to: hash to advance the source position to, or None to stay put
 ApplyInfo = namedtuple('ApplyInfo',
                        ['commits', 'branch_name', 'original_branch',
-                        'merge_found'])
+                        'merge_found', 'advance_to'])
 
 
 def parse_log_output(log_output, has_parents=False):
@@ -749,7 +751,8 @@  def find_unprocessed_commits(dbs, last_commit, source, merge_hashes):
     """Find the first merge with unprocessed commits
 
     Walks through the merge hashes in order, looking for one that has
-    commits not yet tracked in the database.
+    commits not yet tracked in the database. Decomposes mega-merges
+    (merges containing sub-merges) into individual batches.
 
     Args:
         dbs (Database): Database instance
@@ -761,7 +764,20 @@  def find_unprocessed_commits(dbs, last_commit, source, merge_hashes):
         NextCommitsInfo: Info about the next commits to process
     """
     prev_commit = last_commit
+    skipped_merges = False
     for merge_hash in merge_hashes:
+        # Check for mega-merge (contains sub-merges)
+        sub_merges = detect_sub_merges(merge_hash)
+        if sub_merges:
+            commits, advance_to = decompose_mega_merge(
+                dbs, prev_commit, merge_hash, sub_merges)
+            if commits:
+                return NextCommitsInfo(commits, True, advance_to)
+            # All sub-merges done, skip past this mega-merge
+            prev_commit = merge_hash
+            skipped_merges = True
+            continue
+
         # Get all commits from prev_commit to this merge
         log_output = run_git([
             'log', '--reverse', '--format=%H|%h|%an|%s|%P',
@@ -778,10 +794,11 @@  def find_unprocessed_commits(dbs, last_commit, source, merge_hashes):
                    if not dbs.commit_get(c.hash)]
 
         if commits:
-            return NextCommitsInfo(commits, True)
+            return NextCommitsInfo(commits, True, merge_hash)
 
         # All commits in this merge are processed, skip to next
         prev_commit = merge_hash
+        skipped_merges = True
 
     # No merges with unprocessed commits, check remaining commits
     log_output = run_git([
@@ -790,12 +807,14 @@  def find_unprocessed_commits(dbs, last_commit, source, merge_hashes):
     ])
 
     if not log_output:
-        return NextCommitsInfo([], False)
+        # If we skipped merges, advance past them
+        advance_to = prev_commit if skipped_merges else None
+        return NextCommitsInfo([], False, advance_to)
 
     all_commits = parse_log_output(log_output, has_parents=True)
     commits = [c for c in all_commits if not dbs.commit_get(c.hash)]
 
-    return NextCommitsInfo(commits, False)
+    return NextCommitsInfo(commits, False, None)
 
 
 def get_next_commits(dbs, source):
@@ -804,7 +823,8 @@  def get_next_commits(dbs, source):
     Finds commits between the last cherry-picked commit and the next merge
     commit on the first-parent (mainline) chain of the source branch.
     Skips merges whose commits are already tracked in the database (from
-    pending MRs).
+    pending MRs). Decomposes mega-merges (merges containing sub-merges)
+    into individual sub-merge batches.
 
     Args:
         dbs (Database): Database instance
@@ -827,7 +847,7 @@  def get_next_commits(dbs, source):
     ])
 
     if not fp_output:
-        return NextCommitsInfo([], False), None
+        return NextCommitsInfo([], False, None), None
 
     # Build list of merge hashes on the first-parent chain
     merge_hashes = []
@@ -903,6 +923,130 @@  def get_commits_for_pick(commit_spec):
     return commits, err
 
 
+def detect_sub_merges(merge_hash):
+    """Check if a merge commit contains sub-merges
+
+    Examines the second parent's first-parent chain to find merge commits
+    (sub-merges) within a larger merge.
+
+    Args:
+        merge_hash (str): Hash of the merge commit to check
+
+    Returns:
+        list: List of sub-merge hashes in chronological order, or empty
+            list if not a merge or has no sub-merges
+    """
+    # Get parents of the merge
+    try:
+        parents = run_git(['rev-parse', f'{merge_hash}^@'])
+    except command.CommandExc:
+        return []
+
+    parent_list = parents.strip().split('\n')
+    if len(parent_list) < 2:
+        return []
+
+    first_parent = parent_list[0]
+    second_parent = parent_list[1]
+
+    # Find merges on the second parent's first-parent chain
+    try:
+        out = run_git([
+            'log', '--reverse', '--first-parent', '--merges',
+            '--format=%H', f'^{first_parent}', second_parent
+        ])
+    except command.CommandExc:
+        return []
+
+    if not out:
+        return []
+
+    return [line for line in out.split('\n') if line]
+
+
+def decompose_mega_merge(dbs, prev_commit, merge_hash, sub_merges):
+    """Return the next unprocessed batch from a mega-merge
+
+    Handles three phases:
+    1. Mainline commits before the merge (prev_commit..merge^1)
+    2. Sub-merge batches (one at a time, skipping processed ones)
+    3. Remainder commits after the last sub-merge
+
+    Pre-adds the mega-merge commit itself to DB as 'skipped' so it does
+    not appear as an orphan commit.
+
+    Args:
+        dbs (Database): Database instance
+        prev_commit (str): Hash of the last processed commit
+        merge_hash (str): Hash of the mega-merge commit
+        sub_merges (list): List of sub-merge hashes in chronological order
+
+    Returns:
+        tuple: (commits, advance_to) where:
+            commits: list of CommitInfo tuples for the next batch
+            advance_to: hash to advance source to, or None to stay put
+    """
+    parents = run_git(['rev-parse', f'{merge_hash}^@']).strip().split('\n')
+    first_parent = parents[0]
+    second_parent = parents[1]
+
+    # Pre-add the mega-merge commit itself as skipped
+    if not dbs.commit_get(merge_hash):
+        source_id = None
+        sources = dbs.source_get_all()
+        if sources:
+            source_id = dbs.source_get_id(sources[0][0])
+        if source_id:
+            info = run_git(['log', '-1', '--format=%s|%an', merge_hash])
+            parts = info.split('|', 1)
+            subject = parts[0]
+            author = parts[1] if len(parts) > 1 else ''
+            dbs.commit_add(merge_hash, source_id, subject, author,
+                           status='skipped')
+            dbs.commit()
+
+    # Phase 1: mainline commits before the merge
+    log_output = run_git([
+        'log', '--reverse', '--format=%H|%h|%an|%s|%P',
+        f'{prev_commit}..{first_parent}'
+    ])
+    if log_output:
+        all_commits = parse_log_output(log_output, has_parents=True)
+        commits = [c for c in all_commits if not dbs.commit_get(c.hash)]
+        if commits:
+            return commits, first_parent
+
+    # Phase 2: sub-merge batches
+    prev_sub = first_parent
+    for sub_hash in sub_merges:
+        # Get commits for this sub-merge
+        log_output = run_git([
+            'log', '--reverse', '--format=%H|%h|%an|%s|%P',
+            f'^{prev_sub}', sub_hash
+        ])
+        if log_output:
+            all_commits = parse_log_output(log_output, has_parents=True)
+            commits = [c for c in all_commits if not dbs.commit_get(c.hash)]
+            if commits:
+                return commits, None
+        prev_sub = sub_hash
+
+    # Phase 3: remainder after the last sub-merge
+    last_sub = sub_merges[-1] if sub_merges else first_parent
+    log_output = run_git([
+        'log', '--reverse', '--format=%H|%h|%an|%s|%P',
+        f'^{last_sub}', second_parent
+    ])
+    if log_output:
+        all_commits = parse_log_output(log_output, has_parents=True)
+        commits = [c for c in all_commits if not dbs.commit_get(c.hash)]
+        if commits:
+            return commits, None
+
+    # All done
+    return [], None
+
+
 def do_next_set(args, dbs):
     """Show the next set of commits to cherry-pick from a source
 
@@ -1279,6 +1423,14 @@  def prepare_apply(dbs, source, branch):
         return None, 1
 
     if not info.commits:
+        # If advance_to is set, advance source past fully-processed merges
+        if info.advance_to:
+            dbs.source_set(source, info.advance_to)
+            dbs.commit()
+            tout.info(f"Advanced source '{source}' to "
+                      f'{info.advance_to[:12]}')
+            # Retry with updated position
+            return prepare_apply(dbs, source, branch)
         tout.info('No new commits to cherry-pick')
         return None, 0
 
@@ -1313,12 +1465,12 @@  def prepare_apply(dbs, source, branch):
     tout.info('')
 
     return ApplyInfo(commits, branch_name, original_branch,
-                     info.merge_found), 0
+                     info.merge_found, info.advance_to), 0
 
 
 # pylint: disable=too-many-arguments
 def handle_already_applied(dbs, source, commits, branch_name, conv_log, args,
-                           signal_commit):
+                           signal_commit, advance_to=None):
     """Handle the case where commits are already applied to the target branch
 
     Creates an MR with [skip] prefix to record the attempt and updates the
@@ -1332,6 +1484,9 @@  def handle_already_applied(dbs, source, commits, branch_name, conv_log, args,
         conv_log (str): Conversation log from the agent
         args (Namespace): Parsed arguments with 'push', 'remote', 'target'
         signal_commit (str): Last commit hash from signal file
+        advance_to (str): Hash to advance source to, or None to use last
+            commit. If explicitly None (sub-merge batch), source is not
+            advanced.
 
     Returns:
         int: 0 on success, 1 on failure
@@ -1343,11 +1498,20 @@  def handle_already_applied(dbs, source, commits, branch_name, conv_log, args,
         dbs.commit_set_status(commit.hash, 'skipped')
     dbs.commit()
 
-    # Update source position to the last commit (or signal_commit if provided)
-    last_hash = signal_commit if signal_commit else commits[-1].hash
-    dbs.source_set(source, last_hash)
-    dbs.commit()
-    tout.info(f"Updated source '{source}' to {last_hash[:12]}")
+    # Update source position
+    if advance_to is not None:
+        dbs.source_set(source, advance_to)
+        dbs.commit()
+        tout.info(f"Updated source '{source}' to {advance_to[:12]}")
+    elif signal_commit:
+        dbs.source_set(source, signal_commit)
+        dbs.commit()
+        tout.info(f"Updated source '{source}' to {signal_commit[:12]}")
+    else:
+        last_hash = commits[-1].hash
+        dbs.source_set(source, last_hash)
+        dbs.commit()
+        tout.info(f"Updated source '{source}' to {last_hash[:12]}")
 
     # Push and create MR with [skip] prefix if requested
     if args.push:
@@ -1382,7 +1546,7 @@  def handle_already_applied(dbs, source, commits, branch_name, conv_log, args,
     return 0
 
 
-def execute_apply(dbs, source, commits, branch_name, args):  # pylint: disable=too-many-locals
+def execute_apply(dbs, source, commits, branch_name, args, advance_to=None):  # pylint: disable=too-many-locals
     """Execute the apply operation: run agent, update database, push MR
 
     Args:
@@ -1391,6 +1555,8 @@  def execute_apply(dbs, source, commits, branch_name, args):  # pylint: disable=t
         commits (list): List of CommitInfo namedtuples
         branch_name (str): Branch name for cherry-picks
         args (Namespace): Parsed arguments with 'push', 'remote', 'target'
+        advance_to (str): Hash to advance source to after success, or None
+            to skip source advancement (sub-merge batch)
 
     Returns:
         tuple: (ret, success, conv_log) where ret is 0 on success,
@@ -1416,7 +1582,8 @@  def execute_apply(dbs, source, commits, branch_name, args):  # pylint: disable=t
     signal_status, signal_commit = agent.read_signal_file()
     if signal_status == agent.SIGNAL_APPLIED:
         ret = handle_already_applied(dbs, source, commits, branch_name,
-                                     conv_log, args, signal_commit)
+                                     conv_log, args, signal_commit,
+                                     advance_to)
         return ret, False, conv_log
 
     # Verify the branch actually exists - agent may have aborted and deleted it
@@ -1449,8 +1616,8 @@  def execute_apply(dbs, source, commits, branch_name, args):  # pylint: disable=t
                       f"{commits[-1].chash}' to update the database")
 
     # Update database with the last processed commit if successful
-    if success:
-        dbs.source_set(source, commits[-1].hash)
+    if success and advance_to is not None:
+        dbs.source_set(source, advance_to)
         dbs.commit()
 
     return ret, success, conv_log
@@ -1476,7 +1643,8 @@  def do_apply(args, dbs):
     original_branch = info.original_branch
 
     ret, success, conv_log = execute_apply(dbs, source, commits,
-                                                   branch_name, args)
+                                           branch_name, args,
+                                           info.advance_to)
 
     # Write history file if successful
     if success:
diff --git a/tools/pickman/ftest.py b/tools/pickman/ftest.py
index 38e5cef5306..eb69bde96a4 100644
--- a/tools/pickman/ftest.py
+++ b/tools/pickman/ftest.py
@@ -951,8 +951,14 @@  class TestNextSet(unittest.TestCase):
 
         def mock_git(pipe_list):
             cmd = pipe_list[0] if pipe_list else []
+            if '--first-parent' in cmd and '--merges' in cmd:
+                # detect_sub_merges: no sub-merges
+                return command.CommandResult(stdout='')
             if '--first-parent' in cmd:
                 return command.CommandResult(stdout=fp_log_output)
+            if 'rev-parse' in cmd:
+                # detect_sub_merges: return two parents (it's a merge)
+                return command.CommandResult(stdout='bbb222\nddd444\n')
             return command.CommandResult(stdout=full_log_output)
 
         command.TEST_RESULT = mock_git
@@ -1235,8 +1241,14 @@  class TestGetNextCommits(unittest.TestCase):
 
             def mock_git(pipe_list):
                 cmd = pipe_list[0] if pipe_list else []
+                if '--first-parent' in cmd and '--merges' in cmd:
+                    # detect_sub_merges: no sub-merges
+                    return command.CommandResult(stdout='')
                 if '--first-parent' in cmd:
                     return command.CommandResult(stdout=fp_log_output)
+                if 'rev-parse' in cmd:
+                    # detect_sub_merges: return parents
+                    return command.CommandResult(stdout='aaa111\nccc333\n')
                 return command.CommandResult(stdout=full_log_output)
 
             command.TEST_RESULT = mock_git
@@ -2965,18 +2977,22 @@  class TestGetNextCommitsEmptyLine(unittest.TestCase):
                 'merge2|merge2m|Author 4|Second merge|ccc333 side2\n'
             )
 
-            call_count = [0]
-
-            # pylint: disable=unused-argument
             def mock_git(pipe_list):
-                call_count[0] += 1
-                # First call: get first-parent log
-                if call_count[0] == 1:
+                cmd = pipe_list[0] if pipe_list else []
+                if '--first-parent' in cmd and '--merges' in cmd:
+                    # detect_sub_merges: no sub-merges
+                    return command.CommandResult(stdout='')
+                if '--first-parent' in cmd:
                     return command.CommandResult(stdout=fp_log)
-                # Second call: get commits for first merge
-                if call_count[0] == 2:
+                if 'rev-parse' in cmd:
+                    # detect_sub_merges: return parents for merges
+                    return command.CommandResult(stdout='aaa111\nside1\n')
+                # Determine which merge range by checking the cmd args
+                cmd_str = ' '.join(cmd)
+                if 'merge1' in cmd_str and 'abc123' in cmd_str:
                     return command.CommandResult(stdout=merge1_log)
-                # Third call: get commits for second merge
+                if 'merge2' in cmd_str and 'merge1' in cmd_str:
+                    return command.CommandResult(stdout=merge2_log)
                 return command.CommandResult(stdout=merge2_log)
 
             command.TEST_RESULT = mock_git
@@ -2991,6 +3007,433 @@  class TestGetNextCommitsEmptyLine(unittest.TestCase):
             dbs.close()
 
 
+class TestDetectSubMerges(unittest.TestCase):
+    """Tests for detect_sub_merges function."""
+
+    def tearDown(self):
+        """Clean up test fixtures."""
+        command.TEST_RESULT = None
+
+    def test_not_a_merge(self):
+        """Test detect_sub_merges returns empty for non-merge commit."""
+        # Single parent means not a merge
+        command.TEST_RESULT = command.CommandResult(stdout='abc123\n')
+        result = control.detect_sub_merges('abc123')
+        self.assertEqual(result, [])
+
+    def test_no_sub_merges(self):
+        """Test detect_sub_merges returns empty when no sub-merges exist."""
+        call_count = [0]
+
+        def mock_git(pipe_list):  # pylint: disable=unused-argument
+            call_count[0] += 1
+            if call_count[0] == 1:
+                # rev-parse ^@ returns two parents (it's a merge)
+                return command.CommandResult(stdout='parent1\nparent2\n')
+            # log --merges returns empty (no sub-merges)
+            return command.CommandResult(stdout='')
+
+        command.TEST_RESULT = mock_git
+        result = control.detect_sub_merges('merge123')
+        self.assertEqual(result, [])
+
+    def test_found_sub_merges(self):
+        """Test detect_sub_merges finds sub-merges."""
+        call_count = [0]
+
+        def mock_git(pipe_list):  # pylint: disable=unused-argument
+            call_count[0] += 1
+            if call_count[0] == 1:
+                # rev-parse ^@ returns two parents
+                return command.CommandResult(stdout='parent1\nparent2\n')
+            # log --merges returns sub-merge hashes
+            return command.CommandResult(
+                stdout='sub_merge1\nsub_merge2\nsub_merge3\n')
+
+        command.TEST_RESULT = mock_git
+        result = control.detect_sub_merges('mega_merge')
+        self.assertEqual(result, ['sub_merge1', 'sub_merge2', 'sub_merge3'])
+
+    def test_error_handling(self):
+        """Test detect_sub_merges returns empty on git error."""
+        def mock_git_fail(**_kwargs):
+            raise command.CommandExc('git error', command.CommandResult())
+
+        command.TEST_RESULT = mock_git_fail
+        result = control.detect_sub_merges('bad_hash')
+        self.assertEqual(result, [])
+
+
+class TestDecomposeMegaMerge(unittest.TestCase):
+    """Tests for decompose_mega_merge function."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        fd, self.db_path = tempfile.mkstemp(suffix='.db')
+        os.close(fd)
+        os.unlink(self.db_path)
+        database.Database.instances.clear()
+
+    def tearDown(self):
+        """Clean up test fixtures."""
+        if os.path.exists(self.db_path):
+            os.unlink(self.db_path)
+        database.Database.instances.clear()
+        command.TEST_RESULT = None
+
+    def test_first_batch_mainline(self):
+        """Test decompose returns mainline commits first."""
+        with terminal.capture():
+            dbs = database.Database(self.db_path)
+            dbs.start()
+            dbs.source_set('us/next', 'base')
+            dbs.commit()
+
+            call_count = [0]
+
+            def mock_git(pipe_list):  # pylint: disable=unused-argument
+                call_count[0] += 1
+                if call_count[0] == 1:
+                    # rev-parse ^@ for mega-merge parents
+                    return command.CommandResult(
+                        stdout='first_parent\nsecond_parent\n')
+                if call_count[0] == 2:
+                    # log -1 for mega-merge subject/author (pre-add)
+                    return command.CommandResult(
+                        stdout='Mega merge subject|Author\n')
+                if call_count[0] == 3:
+                    # Mainline commits (prev..first_parent)
+                    return command.CommandResult(
+                        stdout='aaa|aaa1|A|Mainline commit|base\n')
+                return command.CommandResult(stdout='')
+
+            command.TEST_RESULT = mock_git
+
+            commits, advance_to = control.decompose_mega_merge(
+                dbs, 'base', 'mega_hash', ['sub1', 'sub2'])
+
+            self.assertEqual(len(commits), 1)
+            self.assertEqual(commits[0].chash, 'aaa1')
+            self.assertEqual(advance_to, 'first_parent')
+            dbs.close()
+
+    def test_sub_merge_batch(self):
+        """Test decompose returns sub-merge batch when mainline is done."""
+        with terminal.capture():
+            dbs = database.Database(self.db_path)
+            dbs.start()
+            dbs.source_set('us/next', 'base')
+            dbs.commit()
+
+            call_count = [0]
+
+            def mock_git(pipe_list):  # pylint: disable=unused-argument
+                call_count[0] += 1
+                if call_count[0] == 1:
+                    # rev-parse ^@ for mega-merge parents
+                    return command.CommandResult(
+                        stdout='first_parent\nsecond_parent\n')
+                if call_count[0] == 2:
+                    # log -1 for mega-merge subject/author
+                    return command.CommandResult(
+                        stdout='Mega merge|Author\n')
+                if call_count[0] == 3:
+                    # Mainline commits - empty (already processed)
+                    return command.CommandResult(stdout='')
+                if call_count[0] == 4:
+                    # Sub-merge 1 commits
+                    return command.CommandResult(
+                        stdout='bbb|bbb1|B|Sub commit 1|first_parent\n'
+                               'ccc|ccc1|C|Sub commit 2|bbb\n')
+                return command.CommandResult(stdout='')
+
+            command.TEST_RESULT = mock_git
+
+            commits, advance_to = control.decompose_mega_merge(
+                dbs, 'base', 'mega_hash', ['sub1', 'sub2'])
+
+            self.assertEqual(len(commits), 2)
+            self.assertEqual(commits[0].chash, 'bbb1')
+            self.assertIsNone(advance_to)
+            dbs.close()
+
+    def test_skips_processed_sub_merge(self):
+        """Test decompose skips sub-merges already in database."""
+        with terminal.capture():
+            dbs = database.Database(self.db_path)
+            dbs.start()
+            dbs.source_set('us/next', 'base')
+            dbs.commit()
+
+            # Add sub-merge 1 commits to database
+            source_id = dbs.source_get_id('us/next')
+            dbs.commit_add('bbb', source_id, 'Sub commit 1', 'B',
+                           status='applied')
+            dbs.commit()
+
+            call_count = [0]
+
+            def mock_git(pipe_list):  # pylint: disable=unused-argument
+                call_count[0] += 1
+                if call_count[0] == 1:
+                    return command.CommandResult(
+                        stdout='first_parent\nsecond_parent\n')
+                if call_count[0] == 2:
+                    return command.CommandResult(
+                        stdout='Mega merge|Author\n')
+                if call_count[0] == 3:
+                    # Mainline - empty
+                    return command.CommandResult(stdout='')
+                if call_count[0] == 4:
+                    # Sub-merge 1 commits (already in DB)
+                    return command.CommandResult(
+                        stdout='bbb|bbb1|B|Sub commit 1|first_parent\n')
+                if call_count[0] == 5:
+                    # Sub-merge 2 commits (not in DB)
+                    return command.CommandResult(
+                        stdout='ddd|ddd1|D|Sub commit 3|sub1\n')
+                return command.CommandResult(stdout='')
+
+            command.TEST_RESULT = mock_git
+
+            commits, advance_to = control.decompose_mega_merge(
+                dbs, 'base', 'mega_hash', ['sub1', 'sub2'])
+
+            self.assertEqual(len(commits), 1)
+            self.assertEqual(commits[0].chash, 'ddd1')
+            self.assertIsNone(advance_to)
+            dbs.close()
+
+    def test_all_done(self):
+        """Test decompose returns empty when all sub-merges are processed."""
+        with terminal.capture():
+            dbs = database.Database(self.db_path)
+            dbs.start()
+            dbs.source_set('us/next', 'base')
+            dbs.commit()
+
+            # Add all commits to database
+            source_id = dbs.source_get_id('us/next')
+            dbs.commit_add('bbb', source_id, 'Sub commit 1', 'B',
+                           status='applied')
+            dbs.commit_add('ddd', source_id, 'Sub commit 2', 'D',
+                           status='applied')
+            dbs.commit()
+
+            call_count = [0]
+
+            def mock_git(pipe_list):  # pylint: disable=too-many-return-statements,unused-argument
+                call_count[0] += 1
+                if call_count[0] == 1:
+                    return command.CommandResult(
+                        stdout='first_parent\nsecond_parent\n')
+                if call_count[0] == 2:
+                    return command.CommandResult(
+                        stdout='Mega merge|Author\n')
+                if call_count[0] == 3:
+                    # Mainline - empty
+                    return command.CommandResult(stdout='')
+                if call_count[0] == 4:
+                    # Sub-merge 1
+                    return command.CommandResult(
+                        stdout='bbb|bbb1|B|Sub commit 1|first_parent\n')
+                if call_count[0] == 5:
+                    # Sub-merge 2
+                    return command.CommandResult(
+                        stdout='ddd|ddd1|D|Sub commit 2|sub1\n')
+                if call_count[0] == 6:
+                    # Remainder - empty
+                    return command.CommandResult(stdout='')
+                return command.CommandResult(stdout='')
+
+            command.TEST_RESULT = mock_git
+
+            commits, advance_to = control.decompose_mega_merge(
+                dbs, 'base', 'mega_hash', ['sub1', 'sub2'])
+
+            self.assertEqual(len(commits), 0)
+            self.assertIsNone(advance_to)
+            dbs.close()
+
+
+class TestGetNextCommitsMegaMerge(unittest.TestCase):
+    """Tests for get_next_commits with mega-merges."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        fd, self.db_path = tempfile.mkstemp(suffix='.db')
+        os.close(fd)
+        os.unlink(self.db_path)
+        database.Database.instances.clear()
+
+    def tearDown(self):
+        """Clean up test fixtures."""
+        if os.path.exists(self.db_path):
+            os.unlink(self.db_path)
+        database.Database.instances.clear()
+        command.TEST_RESULT = None
+
+    def test_returns_sub_batch(self):
+        """Test get_next_commits returns sub-merge batch for mega-merge."""
+        with terminal.capture():
+            dbs = database.Database(self.db_path)
+            dbs.start()
+            dbs.source_set('us/next', 'base')
+            dbs.commit()
+
+            call_count = [0]
+
+            def mock_git(pipe_list):  # pylint: disable=too-many-return-statements,unused-argument
+                call_count[0] += 1
+                if call_count[0] == 1:
+                    # First-parent log shows one mega-merge
+                    return command.CommandResult(
+                        stdout='mega|mega1|A|Merge branch next|'
+                               'base second_parent\n')
+                if call_count[0] == 2:
+                    # detect_sub_merges: rev-parse ^@
+                    return command.CommandResult(
+                        stdout='base\nsecond_parent\n')
+                if call_count[0] == 3:
+                    # detect_sub_merges: log --merges (found sub-merges)
+                    return command.CommandResult(stdout='sub1\n')
+                if call_count[0] == 4:
+                    # decompose: rev-parse ^@ for mega-merge
+                    return command.CommandResult(
+                        stdout='base\nsecond_parent\n')
+                if call_count[0] == 5:
+                    # decompose: log -1 for mega-merge info
+                    return command.CommandResult(
+                        stdout='Mega merge|Author\n')
+                if call_count[0] == 6:
+                    # decompose: mainline commits (empty)
+                    return command.CommandResult(stdout='')
+                if call_count[0] == 7:
+                    # decompose: sub-merge 1 commits
+                    return command.CommandResult(
+                        stdout='aaa|aaa1|A|Sub commit|base\n')
+                return command.CommandResult(stdout='')
+
+            command.TEST_RESULT = mock_git
+
+            info, err = control.get_next_commits(dbs, 'us/next')
+
+            self.assertIsNone(err)
+            self.assertTrue(info.merge_found)
+            self.assertEqual(len(info.commits), 1)
+            self.assertEqual(info.commits[0].chash, 'aaa1')
+            # Sub-merge batch: advance_to should be None
+            self.assertIsNone(info.advance_to)
+            dbs.close()
+
+    def test_all_done_advances_past(self):
+        """Test get_next_commits advances past fully-processed mega-merge."""
+        with terminal.capture():
+            dbs = database.Database(self.db_path)
+            dbs.start()
+            dbs.source_set('us/next', 'base')
+            dbs.commit()
+
+            # Add all sub-merge commits to database
+            source_id = dbs.source_get_id('us/next')
+            dbs.commit_add('aaa', source_id, 'Sub commit', 'A',
+                           status='applied')
+            dbs.commit()
+
+            call_count = [0]
+
+            def mock_git(pipe_list):  # pylint: disable=too-many-return-statements,unused-argument
+                call_count[0] += 1
+                if call_count[0] == 1:
+                    # First-parent log shows mega-merge
+                    return command.CommandResult(
+                        stdout='mega|mega1|A|Merge branch next|'
+                               'base second_parent\n')
+                if call_count[0] == 2:
+                    # detect_sub_merges: rev-parse ^@
+                    return command.CommandResult(
+                        stdout='base\nsecond_parent\n')
+                if call_count[0] == 3:
+                    # detect_sub_merges: log --merges
+                    return command.CommandResult(stdout='sub1\n')
+                if call_count[0] == 4:
+                    # decompose: rev-parse ^@
+                    return command.CommandResult(
+                        stdout='base\nsecond_parent\n')
+                if call_count[0] == 5:
+                    # decompose: log -1 for mega-merge info
+                    return command.CommandResult(
+                        stdout='Mega merge|Author\n')
+                if call_count[0] == 6:
+                    # decompose: mainline (empty)
+                    return command.CommandResult(stdout='')
+                if call_count[0] == 7:
+                    # decompose: sub-merge 1 (in DB)
+                    return command.CommandResult(
+                        stdout='aaa|aaa1|A|Sub commit|base\n')
+                if call_count[0] == 8:
+                    # decompose: remainder (empty)
+                    return command.CommandResult(stdout='')
+                if call_count[0] == 9:
+                    # Remaining commits after mega-merge (empty)
+                    return command.CommandResult(stdout='')
+                return command.CommandResult(stdout='')
+
+            command.TEST_RESULT = mock_git
+
+            info, err = control.get_next_commits(dbs, 'us/next')
+
+            self.assertIsNone(err)
+            self.assertFalse(info.merge_found)
+            self.assertEqual(len(info.commits), 0)
+            # Should advance past the mega-merge
+            self.assertEqual(info.advance_to, 'mega')
+            dbs.close()
+
+    def test_normal_merge_returns_advance_to(self):
+        """Test get_next_commits returns advance_to for normal merges."""
+        with terminal.capture():
+            dbs = database.Database(self.db_path)
+            dbs.start()
+            dbs.source_set('us/next', 'base')
+            dbs.commit()
+
+            call_count = [0]
+
+            def mock_git(pipe_list):  # pylint: disable=unused-argument
+                call_count[0] += 1
+                if call_count[0] == 1:
+                    # First-parent log shows a normal merge
+                    return command.CommandResult(
+                        stdout='merge1|m1|A|Merge branch feat|'
+                               'base side1\n')
+                if call_count[0] == 2:
+                    # detect_sub_merges: rev-parse ^@
+                    return command.CommandResult(
+                        stdout='base\nside1\n')
+                if call_count[0] == 3:
+                    # detect_sub_merges: log --merges (no sub-merges)
+                    return command.CommandResult(stdout='')
+                if call_count[0] == 4:
+                    # Commits for this merge
+                    return command.CommandResult(
+                        stdout='aaa|aaa1|A|Commit 1|base\n'
+                               'merge1|m1|A|Merge branch feat|'
+                               'base side1\n')
+                return command.CommandResult(stdout='')
+
+            command.TEST_RESULT = mock_git
+
+            info, err = control.get_next_commits(dbs, 'us/next')
+
+            self.assertIsNone(err)
+            self.assertTrue(info.merge_found)
+            self.assertEqual(len(info.commits), 2)
+            # Normal merge: advance_to is the merge hash
+            self.assertEqual(info.advance_to, 'merge1')
+            dbs.close()
+
+
 class TestDoCommitSourceResolveError(unittest.TestCase):
     """Tests for do_commit_source error handling."""