[Concept,02/17] pickman: Add argument parsing with compare and test commands

Message ID 20251217022611.389379-3-sjg@u-boot.org
State New
Headers
Series pickman: Add a manager for cherry-picks |

Commit Message

Simon Glass Dec. 17, 2025, 2:25 a.m. UTC
  From: Simon Glass <simon.glass@canonical.com>

Add subcommand support:
- compare: Compare branches (existing functionality)
- test: Run the functional tests

Co-developed-by: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: Simon Glass <simon.glass@canonical.com>
---

 tools/pickman/__main__.py | 32 +++++++++++++++++++---
 tools/pickman/control.py  | 48 ++++++++++++++++++++++++++++-----
 tools/pickman/ftest.py    | 57 ++++++++++++++++++++++++++++++++++++++-
 3 files changed, 126 insertions(+), 11 deletions(-)
  

Patch

diff --git a/tools/pickman/__main__.py b/tools/pickman/__main__.py
index eb0d6e226cc..0984c62d3e6 100755
--- a/tools/pickman/__main__.py
+++ b/tools/pickman/__main__.py
@@ -4,8 +4,9 @@ 
 # Copyright 2025 Canonical Ltd.
 # Written by Simon Glass <simon.glass@canonical.com>
 #
-"""Entry point for pickman - dispatches to control module."""
+"""Entry point for pickman - parses arguments and dispatches to control."""
 
+import argparse
 import os
 import sys
 
@@ -17,9 +18,32 @@  sys.path.insert(0, os.path.join(our_path, '..'))
 from pickman import control
 
 
-def main():
-    """Main function."""
-    return control.do_pickman()
+def parse_args(argv):
+    """Parse command line arguments.
+
+    Args:
+        argv (list): Command line arguments
+
+    Returns:
+        Namespace: Parsed arguments
+    """
+    parser = argparse.ArgumentParser(description='Check commit differences')
+    subparsers = parser.add_subparsers(dest='cmd', required=True)
+
+    subparsers.add_parser('compare', help='Compare branches')
+    subparsers.add_parser('test', help='Run tests')
+
+    return parser.parse_args(argv)
+
+
+def main(argv=None):
+    """Main function to parse args and run commands.
+
+    Args:
+        argv (list): Command line arguments (None for sys.argv[1:])
+    """
+    args = parse_args(argv)
+    return control.do_pickman(args)
 
 
 if __name__ == '__main__':
diff --git a/tools/pickman/control.py b/tools/pickman/control.py
index 990fa1b0729..0ed54dd724c 100644
--- a/tools/pickman/control.py
+++ b/tools/pickman/control.py
@@ -8,12 +8,14 @@ 
 from collections import namedtuple
 import os
 import sys
+import unittest
 
 # Allow 'from pickman import xxx' to work via symlink
 our_path = os.path.dirname(os.path.realpath(__file__))
 sys.path.insert(0, os.path.join(our_path, '..'))
 
 # pylint: disable=wrong-import-position,import-error
+from pickman import ftest
 from u_boot_pylib import command
 from u_boot_pylib import tout
 
@@ -54,14 +56,12 @@  def compare_branches(master, source):
     return count, Commit(full_hash, short_hash, subject, date)
 
 
-def do_pickman():
-    """Main entry point for pickman.
+def do_compare(args):  # pylint: disable=unused-argument
+    """Compare branches and print results.
 
-    Returns:
-        int: 0 on success
+    Args:
+        args (Namespace): Parsed arguments
     """
-    tout.init(tout.INFO)
-
     count, base = compare_branches(BRANCH_MASTER, BRANCH_SOURCE)
 
     tout.info(f'Commits in {BRANCH_SOURCE} not in {BRANCH_MASTER}: {count}')
@@ -72,3 +72,39 @@  def do_pickman():
     tout.info(f'  Date:    {base.date}')
 
     return 0
+
+
+def do_test(args):  # pylint: disable=unused-argument
+    """Run tests for this module.
+
+    Args:
+        args (Namespace): Parsed arguments
+
+    Returns:
+        int: 0 if tests passed, 1 otherwise
+    """
+    loader = unittest.TestLoader()
+    suite = loader.loadTestsFromModule(ftest)
+    runner = unittest.TextTestRunner()
+    result = runner.run(suite)
+
+    return 0 if result.wasSuccessful() else 1
+
+
+def do_pickman(args):
+    """Main entry point for pickman commands.
+
+    Args:
+        args (Namespace): Parsed arguments
+
+    Returns:
+        int: 0 on success, 1 on failure
+    """
+    tout.init(tout.INFO)
+
+    if args.cmd == 'compare':
+        return do_compare(args)
+    if args.cmd == 'test':
+        return do_test(args)
+
+    return 1
diff --git a/tools/pickman/ftest.py b/tools/pickman/ftest.py
index 7b34a260659..eeb19926f76 100644
--- a/tools/pickman/ftest.py
+++ b/tools/pickman/ftest.py
@@ -13,9 +13,11 @@  import unittest
 our_path = os.path.dirname(os.path.realpath(__file__))
 sys.path.insert(0, os.path.join(our_path, '..'))
 
-# pylint: disable=wrong-import-position,import-error
+# pylint: disable=wrong-import-position,import-error,cyclic-import
 from u_boot_pylib import command
+from u_boot_pylib import terminal
 
+from pickman import __main__ as pickman
 from pickman import control
 
 
@@ -97,5 +99,58 @@  class TestCompareBranches(unittest.TestCase):
             command.TEST_RESULT = None
 
 
+class TestParseArgs(unittest.TestCase):
+    """Tests for parse_args function."""
+
+    def test_parse_compare(self):
+        """Test parsing compare command."""
+        args = pickman.parse_args(['compare'])
+        self.assertEqual(args.cmd, 'compare')
+
+    def test_parse_test(self):
+        """Test parsing test command."""
+        args = pickman.parse_args(['test'])
+        self.assertEqual(args.cmd, 'test')
+
+    def test_parse_no_command(self):
+        """Test parsing with no command raises error."""
+        with terminal.capture():
+            with self.assertRaises(SystemExit):
+                pickman.parse_args([])
+
+
+class TestMain(unittest.TestCase):
+    """Tests for main function."""
+
+    def test_main_compare(self):
+        """Test main with compare command."""
+        results = iter([
+            '10',
+            'abc123',
+            'abc123\nabc\nSubject\n2024-01-01 00:00:00 -0000',
+        ])
+
+        def handle_command(**_):
+            return command.CommandResult(stdout=next(results))
+
+        command.TEST_RESULT = handle_command
+        try:
+            with terminal.capture() as (stdout, _):
+                ret = pickman.main(['compare'])
+            self.assertEqual(ret, 0)
+            lines = iter(stdout.getvalue().splitlines())
+            self.assertEqual('Commits in us/next not in ci/master: 10',
+                             next(lines))
+            self.assertEqual('', next(lines))
+            self.assertEqual('Last common commit:', next(lines))
+            self.assertEqual('  Hash:    abc', next(lines))
+            self.assertEqual('  Subject: Subject', next(lines))
+            self.assertEqual('  Date:    2024-01-01 00:00:00 -0000',
+                             next(lines))
+            self.assertRaises(StopIteration, next, lines)
+        finally:
+            command.TEST_RESULT = None
+
+
 if __name__ == '__main__':
     unittest.main()