new file mode 100644
@@ -0,0 +1,193 @@
+# SPDX-License-Identifier: GPL-2.0+
+#
+# Copyright 2025 Canonical Ltd.
+# Written by Simon Glass <simon.glass@canonical.com>
+#
+"""Database for pickman - tracks cherry-pick state.
+
+This uses sqlite3 with a local file (.pickman.db).
+
+To adjust the schema, increment LATEST, create a _migrate_to_v<x>() function
+and add code in migrate_to() to call it.
+"""
+
+import os
+import sqlite3
+
+from u_boot_pylib import tools
+from u_boot_pylib import tout
+
+# Schema version (version 0 means there is no database yet)
+LATEST = 1
+
+# Default database filename
+DB_FNAME = '.pickman.db'
+
+
+class Database:
+ """Database of cherry-pick state used by pickman"""
+
+ # dict of databases:
+ # key: filename
+ # value: Database object
+ instances = {}
+
+ def __init__(self, db_path):
+ """Set up a new database object
+
+ Args:
+ db_path (str): Path to the database
+ """
+ if db_path in Database.instances:
+ raise ValueError(f"There is already a database for '{db_path}'")
+ self.con = None
+ self.cur = None
+ self.db_path = db_path
+ self.is_open = False
+ Database.instances[db_path] = self
+
+ @staticmethod
+ def get_instance(db_path):
+ """Get the database instance for a path
+
+ Args:
+ db_path (str): Path to the database
+
+ Return:
+ tuple:
+ Database: Database instance, created if necessary
+ bool: True if newly created
+ """
+ dbs = Database.instances.get(db_path)
+ if dbs:
+ return dbs, False
+ return Database(db_path), True
+
+ def start(self):
+ """Open the database ready for use, migrate to latest schema"""
+ self.open_it()
+ self.migrate_to(LATEST)
+
+ def open_it(self):
+ """Open the database, creating it if necessary"""
+ if self.is_open:
+ raise ValueError('Already open')
+ if not os.path.exists(self.db_path):
+ tout.warning(f'Creating new database {self.db_path}')
+ self.con = sqlite3.connect(self.db_path)
+ self.cur = self.con.cursor()
+ self.is_open = True
+ Database.instances[self.db_path] = self
+
+ def close(self):
+ """Close the database"""
+ if not self.is_open:
+ raise ValueError('Already closed')
+ self.con.close()
+ self.cur = None
+ self.con = None
+ self.is_open = False
+ Database.instances.pop(self.db_path, None)
+
+ def _create_v1(self):
+ """Create a database with the v1 schema"""
+ # Table for tracking source branches and their last cherry-picked commit
+ self.cur.execute(
+ 'CREATE TABLE source ('
+ 'id INTEGER PRIMARY KEY AUTOINCREMENT, '
+ 'name TEXT UNIQUE, '
+ 'last_commit TEXT)')
+
+ # Schema version table
+ self.cur.execute('CREATE TABLE schema_version (version INTEGER)')
+
+ def migrate_to(self, dest_version):
+ """Migrate the database to the selected version
+
+ Args:
+ dest_version (int): Version to migrate to
+ """
+ while True:
+ version = self.get_schema_version()
+ if version >= dest_version:
+ break
+
+ self.close()
+ tools.write_file(f'{self.db_path}old.v{version}',
+ tools.read_file(self.db_path))
+
+ version += 1
+ tout.info(f'Update database to v{version}')
+ self.open_it()
+ if version == 1:
+ self._create_v1()
+
+ self.cur.execute('DELETE FROM schema_version')
+ self.cur.execute(
+ 'INSERT INTO schema_version (version) VALUES (?)',
+ (version,))
+ self.commit()
+
+ def get_schema_version(self):
+ """Get the version of the database's schema
+
+ Return:
+ int: Database version, 0 means there is no data
+ """
+ try:
+ self.cur.execute('SELECT version FROM schema_version')
+ return self.cur.fetchone()[0]
+ except sqlite3.OperationalError:
+ return 0
+
+ def execute(self, query, parameters=()):
+ """Execute a database query
+
+ Args:
+ query (str): Query string
+ parameters (tuple): Parameters to pass
+
+ Return:
+ Cursor result
+ """
+ return self.cur.execute(query, parameters)
+
+ def commit(self):
+ """Commit changes to the database"""
+ self.con.commit()
+
+ def rollback(self):
+ """Roll back changes to the database"""
+ self.con.rollback()
+
+ # source functions
+
+ def source_get(self, name):
+ """Get the last cherry-picked commit for a source branch
+
+ Args:
+ name (str): Source branch name
+
+ Return:
+ str: Commit hash, or None if not found
+ """
+ res = self.execute(
+ 'SELECT last_commit FROM source WHERE name = ?', (name,))
+ rec = res.fetchone()
+ if rec:
+ return rec[0]
+ return None
+
+ def source_set(self, name, commit):
+ """Set the last cherry-picked commit for a source branch
+
+ Args:
+ name (str): Source branch name
+ commit (str): Commit hash
+ """
+ self.execute(
+ 'UPDATE source SET last_commit = ? WHERE name = ?', (commit, name))
+ if self.cur.rowcount == 0:
+ self.execute(
+ 'INSERT INTO source (name, last_commit) VALUES (?, ?)',
+ (name, commit))
@@ -7,6 +7,7 @@
import os
import sys
+import tempfile
import unittest
# Allow 'from pickman import xxx' to work via symlink
@@ -19,6 +20,7 @@ from u_boot_pylib import terminal
from pickman import __main__ as pickman
from pickman import control
+from pickman import database
class TestCommit(unittest.TestCase):
@@ -152,5 +154,75 @@ class TestMain(unittest.TestCase):
command.TEST_RESULT = None
+class TestDatabase(unittest.TestCase):
+ """Tests for Database class."""
+
+ def setUp(self):
+ """Set up test fixtures."""
+ fd, self.db_path = tempfile.mkstemp(suffix='.db')
+ os.close(fd)
+ os.unlink(self.db_path) # Remove so database creates it fresh
+ database.Database.instances.clear()
+
+ def tearDown(self):
+ """Clean up test fixtures."""
+ if os.path.exists(self.db_path):
+ os.unlink(self.db_path)
+ database.Database.instances.clear()
+
+ def test_create_database(self):
+ """Test creating a new database."""
+ with terminal.capture():
+ dbs = database.Database(self.db_path)
+ dbs.start()
+ self.assertTrue(dbs.is_open)
+ self.assertEqual(dbs.get_schema_version(), database.LATEST)
+ dbs.close()
+
+ def test_source_get_empty(self):
+ """Test getting source from empty database."""
+ with terminal.capture():
+ dbs = database.Database(self.db_path)
+ dbs.start()
+ result = dbs.source_get('us/next')
+ self.assertIsNone(result)
+ dbs.close()
+
+ def test_source_set_and_get(self):
+ """Test setting and getting source commit."""
+ with terminal.capture():
+ dbs = database.Database(self.db_path)
+ dbs.start()
+ dbs.source_set('us/next', 'abc123def456')
+ dbs.commit()
+ result = dbs.source_get('us/next')
+ self.assertEqual(result, 'abc123def456')
+ dbs.close()
+
+ def test_source_update(self):
+ """Test updating source commit."""
+ with terminal.capture():
+ dbs = database.Database(self.db_path)
+ dbs.start()
+ dbs.source_set('us/next', 'abc123')
+ dbs.commit()
+ dbs.source_set('us/next', 'def456')
+ dbs.commit()
+ result = dbs.source_get('us/next')
+ self.assertEqual(result, 'def456')
+ dbs.close()
+
+ def test_get_instance(self):
+ """Test get_instance returns same database."""
+ with terminal.capture():
+ dbs1, created1 = database.Database.get_instance(self.db_path)
+ dbs1.start()
+ dbs2, created2 = database.Database.get_instance(self.db_path)
+ self.assertTrue(created1)
+ self.assertFalse(created2)
+ self.assertIs(dbs1, dbs2)
+ dbs1.close()
+
+
if __name__ == '__main__':
unittest.main()