finally properly handle transactions

This commit is contained in:
Lera Elvoé 2025-07-01 00:22:12 +03:00
parent 320b898b29
commit a7f9fbfe90
Signed by: yagich
SSH Key Fingerprint: SHA256:6xjGb6uA7lAVcULa7byPEN//rQ0wPoG+UzYVMfZnbvc

View File

@ -5,10 +5,11 @@ from flask import current_app
class DB: class DB:
def __init__(self): def __init__(self):
self._connection = None self._connection = None
self._transaction_depth = 0
@contextmanager @contextmanager
def _get_connection(self): def connection(self, in_transaction = False):
if self._connection: if self._connection:
yield self._connection yield self._connection
return return
@ -18,29 +19,36 @@ class DB:
conn.execute("PRAGMA FOREIGN_KEYS = 1") conn.execute("PRAGMA FOREIGN_KEYS = 1")
try: try:
if in_transaction:
self._connection = conn
self._transaction_depth += 1
conn.execute("BEGIN")
yield conn yield conn
if in_transaction:
conn.commit()
except Exception:
if in_transaction and self._connection:
conn.rollback()
finally: finally:
if in_transaction:
self._transaction_depth -= 1
if self._transaction_depth == 0:
self._connection = None
conn.close() conn.close()
@contextmanager @contextmanager
def transaction(self): def transaction(self):
"""Transaction context.""" """Transaction context."""
tr_connection = sqlite3.connect(current_app.config["DB_PATH"]) with self.connection(in_transaction=True) as conn:
tr_connection.row_factory = sqlite3.Row yield conn
tr_connection.execute("PRAGMA FOREIGN_KEYS = 1")
tr_connection.execute("BEGIN")
try:
yield
tr_connection.execute("COMMIT")
except Exception:
tr_connection.execute("ROLLBACK")
raise
def query(self, sql, *args): def query(self, sql, *args):
"""Executes a query and returns a list of dictionaries.""" """Executes a query and returns a list of dictionaries."""
with self._get_connection() as conn: with self.connection() as conn:
rows = conn.execute(sql, args).fetchall() rows = conn.execute(sql, args).fetchall()
return [dict(row) for row in rows] return [dict(row) for row in rows]
@ -56,7 +64,7 @@ class DB:
RETURNING * RETURNING *
""" """
with self._get_connection() as conn: with self.connection() as conn:
result = conn.execute(sql, values).fetchone() result = conn.execute(sql, values).fetchone()
conn.commit() conn.commit()
return dict(result) if result else None return dict(result) if result else None
@ -64,14 +72,14 @@ class DB:
def execute(self, sql, *args): def execute(self, sql, *args):
"""Executes a query without returning.""" """Executes a query without returning."""
with self._get_connection() as conn: with self.connection() as conn:
conn.execute(sql, args) conn.execute(sql, args)
conn.commit() conn.commit()
def fetch_one(self, sql, *args): def fetch_one(self, sql, *args):
"""Grabs the first row of a query.""" """Grabs the first row of a query."""
with self._get_connection() as conn: with self.connection() as conn:
row = conn.execute(sql, args).fetchone() row = conn.execute(sql, args).fetchone()
return dict(row) if row else None return dict(row) if row else None