[Concept,03/17] pickman: Add database for tracking cherry-pick state

Message ID 20251217022611.389379-4-sjg@u-boot.org
State New
Headers
Series pickman: Add a manager for cherry-picks |

Commit Message

Simon Glass Dec. 17, 2025, 2:25 a.m. UTC
  From: Simon Glass <simon.glass@canonical.com>

Add an sqlite3 database module to track the state of cherry-picking
commits between branches. The database uses .pickman.db and includes:

- source table: tracks source branches and their last cherry-picked
  commit into master
- Schema versioning for future migrations

The database code is mostly lifted from patman

Co-developed-by: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: Simon Glass <simon.glass@canonical.com>
---

 tools/pickman/database.py | 193 ++++++++++++++++++++++++++++++++++++++
 tools/pickman/ftest.py    |  72 ++++++++++++++
 2 files changed, 265 insertions(+)
 create mode 100644 tools/pickman/database.py
  

Patch

diff --git a/tools/pickman/database.py b/tools/pickman/database.py
new file mode 100644
index 00000000000..436734fe1f7
--- /dev/null
+++ b/tools/pickman/database.py
@@ -0,0 +1,193 @@ 
+# SPDX-License-Identifier: GPL-2.0+
+#
+# Copyright 2025 Canonical Ltd.
+# Written by Simon Glass <simon.glass@canonical.com>
+#
+"""Database for pickman - tracks cherry-pick state.
+
+This uses sqlite3 with a local file (.pickman.db).
+
+To adjust the schema, increment LATEST, create a _migrate_to_v<x>() function
+and add code in migrate_to() to call it.
+"""
+
+import os
+import sqlite3
+
+from u_boot_pylib import tools
+from u_boot_pylib import tout
+
+# Schema version (version 0 means there is no database yet)
+LATEST = 1
+
+# Default database filename
+DB_FNAME = '.pickman.db'
+
+
+class Database:
+    """Database of cherry-pick state used by pickman"""
+
+    # dict of databases:
+    #   key: filename
+    #   value: Database object
+    instances = {}
+
+    def __init__(self, db_path):
+        """Set up a new database object
+
+        Args:
+            db_path (str): Path to the database
+        """
+        if db_path in Database.instances:
+            raise ValueError(f"There is already a database for '{db_path}'")
+        self.con = None
+        self.cur = None
+        self.db_path = db_path
+        self.is_open = False
+        Database.instances[db_path] = self
+
+    @staticmethod
+    def get_instance(db_path):
+        """Get the database instance for a path
+
+        Args:
+            db_path (str): Path to the database
+
+        Return:
+            tuple:
+                Database: Database instance, created if necessary
+                bool: True if newly created
+        """
+        dbs = Database.instances.get(db_path)
+        if dbs:
+            return dbs, False
+        return Database(db_path), True
+
+    def start(self):
+        """Open the database ready for use, migrate to latest schema"""
+        self.open_it()
+        self.migrate_to(LATEST)
+
+    def open_it(self):
+        """Open the database, creating it if necessary"""
+        if self.is_open:
+            raise ValueError('Already open')
+        if not os.path.exists(self.db_path):
+            tout.warning(f'Creating new database {self.db_path}')
+        self.con = sqlite3.connect(self.db_path)
+        self.cur = self.con.cursor()
+        self.is_open = True
+        Database.instances[self.db_path] = self
+
+    def close(self):
+        """Close the database"""
+        if not self.is_open:
+            raise ValueError('Already closed')
+        self.con.close()
+        self.cur = None
+        self.con = None
+        self.is_open = False
+        Database.instances.pop(self.db_path, None)
+
+    def _create_v1(self):
+        """Create a database with the v1 schema"""
+        # Table for tracking source branches and their last cherry-picked commit
+        self.cur.execute(
+            'CREATE TABLE source ('
+            'id INTEGER PRIMARY KEY AUTOINCREMENT, '
+            'name TEXT UNIQUE, '
+            'last_commit TEXT)')
+
+        # Schema version table
+        self.cur.execute('CREATE TABLE schema_version (version INTEGER)')
+
+    def migrate_to(self, dest_version):
+        """Migrate the database to the selected version
+
+        Args:
+            dest_version (int): Version to migrate to
+        """
+        while True:
+            version = self.get_schema_version()
+            if version >= dest_version:
+                break
+
+            self.close()
+            tools.write_file(f'{self.db_path}old.v{version}',
+                             tools.read_file(self.db_path))
+
+            version += 1
+            tout.info(f'Update database to v{version}')
+            self.open_it()
+            if version == 1:
+                self._create_v1()
+
+            self.cur.execute('DELETE FROM schema_version')
+            self.cur.execute(
+                'INSERT INTO schema_version (version) VALUES (?)',
+                (version,))
+            self.commit()
+
+    def get_schema_version(self):
+        """Get the version of the database's schema
+
+        Return:
+            int: Database version, 0 means there is no data
+        """
+        try:
+            self.cur.execute('SELECT version FROM schema_version')
+            return self.cur.fetchone()[0]
+        except sqlite3.OperationalError:
+            return 0
+
+    def execute(self, query, parameters=()):
+        """Execute a database query
+
+        Args:
+            query (str): Query string
+            parameters (tuple): Parameters to pass
+
+        Return:
+            Cursor result
+        """
+        return self.cur.execute(query, parameters)
+
+    def commit(self):
+        """Commit changes to the database"""
+        self.con.commit()
+
+    def rollback(self):
+        """Roll back changes to the database"""
+        self.con.rollback()
+
+    # source functions
+
+    def source_get(self, name):
+        """Get the last cherry-picked commit for a source branch
+
+        Args:
+            name (str): Source branch name
+
+        Return:
+            str: Commit hash, or None if not found
+        """
+        res = self.execute(
+            'SELECT last_commit FROM source WHERE name = ?', (name,))
+        rec = res.fetchone()
+        if rec:
+            return rec[0]
+        return None
+
+    def source_set(self, name, commit):
+        """Set the last cherry-picked commit for a source branch
+
+        Args:
+            name (str): Source branch name
+            commit (str): Commit hash
+        """
+        self.execute(
+            'UPDATE source SET last_commit = ? WHERE name = ?', (commit, name))
+        if self.cur.rowcount == 0:
+            self.execute(
+                'INSERT INTO source (name, last_commit) VALUES (?, ?)',
+                (name, commit))
diff --git a/tools/pickman/ftest.py b/tools/pickman/ftest.py
index eeb19926f76..b975b9c6a2b 100644
--- a/tools/pickman/ftest.py
+++ b/tools/pickman/ftest.py
@@ -7,6 +7,7 @@ 
 
 import os
 import sys
+import tempfile
 import unittest
 
 # Allow 'from pickman import xxx' to work via symlink
@@ -19,6 +20,7 @@  from u_boot_pylib import terminal
 
 from pickman import __main__ as pickman
 from pickman import control
+from pickman import database
 
 
 class TestCommit(unittest.TestCase):
@@ -152,5 +154,75 @@  class TestMain(unittest.TestCase):
             command.TEST_RESULT = None
 
 
+class TestDatabase(unittest.TestCase):
+    """Tests for Database class."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        fd, self.db_path = tempfile.mkstemp(suffix='.db')
+        os.close(fd)
+        os.unlink(self.db_path)  # Remove so database creates it fresh
+        database.Database.instances.clear()
+
+    def tearDown(self):
+        """Clean up test fixtures."""
+        if os.path.exists(self.db_path):
+            os.unlink(self.db_path)
+        database.Database.instances.clear()
+
+    def test_create_database(self):
+        """Test creating a new database."""
+        with terminal.capture():
+            dbs = database.Database(self.db_path)
+            dbs.start()
+            self.assertTrue(dbs.is_open)
+            self.assertEqual(dbs.get_schema_version(), database.LATEST)
+            dbs.close()
+
+    def test_source_get_empty(self):
+        """Test getting source from empty database."""
+        with terminal.capture():
+            dbs = database.Database(self.db_path)
+            dbs.start()
+            result = dbs.source_get('us/next')
+            self.assertIsNone(result)
+            dbs.close()
+
+    def test_source_set_and_get(self):
+        """Test setting and getting source commit."""
+        with terminal.capture():
+            dbs = database.Database(self.db_path)
+            dbs.start()
+            dbs.source_set('us/next', 'abc123def456')
+            dbs.commit()
+            result = dbs.source_get('us/next')
+            self.assertEqual(result, 'abc123def456')
+            dbs.close()
+
+    def test_source_update(self):
+        """Test updating source commit."""
+        with terminal.capture():
+            dbs = database.Database(self.db_path)
+            dbs.start()
+            dbs.source_set('us/next', 'abc123')
+            dbs.commit()
+            dbs.source_set('us/next', 'def456')
+            dbs.commit()
+            result = dbs.source_get('us/next')
+            self.assertEqual(result, 'def456')
+            dbs.close()
+
+    def test_get_instance(self):
+        """Test get_instance returns same database."""
+        with terminal.capture():
+            dbs1, created1 = database.Database.get_instance(self.db_path)
+            dbs1.start()
+            dbs2, created2 = database.Database.get_instance(self.db_path)
+            self.assertTrue(created1)
+            self.assertFalse(created2)
+            self.assertIs(dbs1, dbs2)
+            dbs1.close()
+
+
 if __name__ == '__main__':
     unittest.main()