@@ -4,8 +4,9 @@
# Copyright 2025 Canonical Ltd.
# Written by Simon Glass <simon.glass@canonical.com>
#
-"""Entry point for pickman - dispatches to control module."""
+"""Entry point for pickman - parses arguments and dispatches to control."""
+import argparse
import os
import sys
@@ -17,9 +18,32 @@ sys.path.insert(0, os.path.join(our_path, '..'))
from pickman import control
-def main():
- """Main function."""
- return control.do_pickman()
+def parse_args(argv):
+ """Parse command line arguments.
+
+ Args:
+ argv (list): Command line arguments
+
+ Returns:
+ Namespace: Parsed arguments
+ """
+ parser = argparse.ArgumentParser(description='Check commit differences')
+ subparsers = parser.add_subparsers(dest='cmd', required=True)
+
+ subparsers.add_parser('compare', help='Compare branches')
+ subparsers.add_parser('test', help='Run tests')
+
+ return parser.parse_args(argv)
+
+
+def main(argv=None):
+ """Main function to parse args and run commands.
+
+ Args:
+ argv (list): Command line arguments (None for sys.argv[1:])
+ """
+ args = parse_args(argv)
+ return control.do_pickman(args)
if __name__ == '__main__':
@@ -8,12 +8,14 @@
from collections import namedtuple
import os
import sys
+import unittest
# Allow 'from pickman import xxx' to work via symlink
our_path = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, os.path.join(our_path, '..'))
# pylint: disable=wrong-import-position,import-error
+from pickman import ftest
from u_boot_pylib import command
from u_boot_pylib import tout
@@ -54,14 +56,12 @@ def compare_branches(master, source):
return count, Commit(full_hash, short_hash, subject, date)
-def do_pickman():
- """Main entry point for pickman.
+def do_compare(args): # pylint: disable=unused-argument
+ """Compare branches and print results.
- Returns:
- int: 0 on success
+ Args:
+ args (Namespace): Parsed arguments
"""
- tout.init(tout.INFO)
-
count, base = compare_branches(BRANCH_MASTER, BRANCH_SOURCE)
tout.info(f'Commits in {BRANCH_SOURCE} not in {BRANCH_MASTER}: {count}')
@@ -72,3 +72,39 @@ def do_pickman():
tout.info(f' Date: {base.date}')
return 0
+
+
+def do_test(args): # pylint: disable=unused-argument
+ """Run tests for this module.
+
+ Args:
+ args (Namespace): Parsed arguments
+
+ Returns:
+ int: 0 if tests passed, 1 otherwise
+ """
+ loader = unittest.TestLoader()
+ suite = loader.loadTestsFromModule(ftest)
+ runner = unittest.TextTestRunner()
+ result = runner.run(suite)
+
+ return 0 if result.wasSuccessful() else 1
+
+
+def do_pickman(args):
+ """Main entry point for pickman commands.
+
+ Args:
+ args (Namespace): Parsed arguments
+
+ Returns:
+ int: 0 on success, 1 on failure
+ """
+ tout.init(tout.INFO)
+
+ if args.cmd == 'compare':
+ return do_compare(args)
+ if args.cmd == 'test':
+ return do_test(args)
+
+ return 1
@@ -13,9 +13,11 @@ import unittest
our_path = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, os.path.join(our_path, '..'))
-# pylint: disable=wrong-import-position,import-error
+# pylint: disable=wrong-import-position,import-error,cyclic-import
from u_boot_pylib import command
+from u_boot_pylib import terminal
+from pickman import __main__ as pickman
from pickman import control
@@ -97,5 +99,58 @@ class TestCompareBranches(unittest.TestCase):
command.TEST_RESULT = None
+class TestParseArgs(unittest.TestCase):
+ """Tests for parse_args function."""
+
+ def test_parse_compare(self):
+ """Test parsing compare command."""
+ args = pickman.parse_args(['compare'])
+ self.assertEqual(args.cmd, 'compare')
+
+ def test_parse_test(self):
+ """Test parsing test command."""
+ args = pickman.parse_args(['test'])
+ self.assertEqual(args.cmd, 'test')
+
+ def test_parse_no_command(self):
+ """Test parsing with no command raises error."""
+ with terminal.capture():
+ with self.assertRaises(SystemExit):
+ pickman.parse_args([])
+
+
+class TestMain(unittest.TestCase):
+ """Tests for main function."""
+
+ def test_main_compare(self):
+ """Test main with compare command."""
+ results = iter([
+ '10',
+ 'abc123',
+ 'abc123\nabc\nSubject\n2024-01-01 00:00:00 -0000',
+ ])
+
+ def handle_command(**_):
+ return command.CommandResult(stdout=next(results))
+
+ command.TEST_RESULT = handle_command
+ try:
+ with terminal.capture() as (stdout, _):
+ ret = pickman.main(['compare'])
+ self.assertEqual(ret, 0)
+ lines = iter(stdout.getvalue().splitlines())
+ self.assertEqual('Commits in us/next not in ci/master: 10',
+ next(lines))
+ self.assertEqual('', next(lines))
+ self.assertEqual('Last common commit:', next(lines))
+ self.assertEqual(' Hash: abc', next(lines))
+ self.assertEqual(' Subject: Subject', next(lines))
+ self.assertEqual(' Date: 2024-01-01 00:00:00 -0000',
+ next(lines))
+ self.assertRaises(StopIteration, next, lines)
+ finally:
+ command.TEST_RESULT = None
+
+
if __name__ == '__main__':
unittest.main()