diff --git a/app/db.py b/app/db.py index 030692a..4da83f7 100644 --- a/app/db.py +++ b/app/db.py @@ -5,10 +5,11 @@ from flask import current_app class DB: def __init__(self): self._connection = None + self._transaction_depth = 0 @contextmanager - def _get_connection(self): + def connection(self, in_transaction = False): if self._connection: yield self._connection return @@ -18,29 +19,36 @@ class DB: conn.execute("PRAGMA FOREIGN_KEYS = 1") try: + if in_transaction: + self._connection = conn + self._transaction_depth += 1 + conn.execute("BEGIN") + yield conn + + if in_transaction: + conn.commit() + except Exception: + if in_transaction and self._connection: + conn.rollback() finally: - conn.close() + if in_transaction: + self._transaction_depth -= 1 + if self._transaction_depth == 0: + self._connection = None + conn.close() @contextmanager def transaction(self): """Transaction context.""" - tr_connection = sqlite3.connect(current_app.config["DB_PATH"]) - tr_connection.row_factory = sqlite3.Row - tr_connection.execute("PRAGMA FOREIGN_KEYS = 1") - tr_connection.execute("BEGIN") - try: - yield - tr_connection.execute("COMMIT") - except Exception: - tr_connection.execute("ROLLBACK") - raise + with self.connection(in_transaction=True) as conn: + yield conn def query(self, sql, *args): """Executes a query and returns a list of dictionaries.""" - with self._get_connection() as conn: + with self.connection() as conn: rows = conn.execute(sql, args).fetchall() return [dict(row) for row in rows] @@ -56,7 +64,7 @@ class DB: RETURNING * """ - with self._get_connection() as conn: + with self.connection() as conn: result = conn.execute(sql, values).fetchone() conn.commit() return dict(result) if result else None @@ -64,14 +72,14 @@ class DB: def execute(self, sql, *args): """Executes a query without returning.""" - with self._get_connection() as conn: + with self.connection() as conn: conn.execute(sql, args) conn.commit() def fetch_one(self, sql, *args): """Grabs the first row of a query.""" - with self._get_connection() as conn: + with self.connection() as conn: row = conn.execute(sql, args).fetchone() return dict(row) if row else None