diff --git a/app/migrations.py b/app/migrations.py index 4523f44..0ae3767 100644 --- a/app/migrations.py +++ b/app/migrations.py @@ -1,9 +1,9 @@ from .db import db -# format: {integer: str|list} -MIGRATIONS = { +# format: [str|tuple(str, any...)|callable] +MIGRATIONS = [ -} +] def run_migrations(): db.execute(""" @@ -16,18 +16,21 @@ def run_migrations(): return print("Running migrations...") ran = 0 - completed = [row["id"] for row in db.query("SELECT id FROM _migrations")] - for migration_id in sorted(MIGRATIONS.keys()): - if migration_id not in completed: - print(f"Running migration #{migration_id}") + completed = {int(row["id"]) for row in db.query("SELECT id FROM _migrations")} + to_run = {idx: migration_obj for idx, migration_obj in enumerate(MIGRATIONS) if idx not in completed} + if not to_run: + print('No migrations need to run.') + return + + with db.transaction(): + for migration_id, migration_obj in to_run.items(): + if isinstance(migration_obj, str): + db.execute(migration_obj) + elif isinstance(migration_obj, tuple): + db.execute(migration_obj[0], *migration_obj[1:]) + elif callable(migration_obj): + migration_obj() + + db.execute('INSERT INTO _migrations (id) VALUES (?)', migration_id) ran += 1 - statements = MIGRATIONS[migration_id] - # support both strings and lists - if isinstance(statements, str): - statements = [statements] - - for sql in statements: - db.execute(sql) - - db.execute("INSERT INTO _migrations (id) VALUES (?)", migration_id) print(f"Ran {ran} migrations.")