finally properly handle transactions
This commit is contained in:
parent
320b898b29
commit
a7f9fbfe90
40
app/db.py
40
app/db.py
@ -5,10 +5,11 @@ from flask import current_app
|
|||||||
class DB:
|
class DB:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._connection = None
|
self._connection = None
|
||||||
|
self._transaction_depth = 0
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _get_connection(self):
|
def connection(self, in_transaction = False):
|
||||||
if self._connection:
|
if self._connection:
|
||||||
yield self._connection
|
yield self._connection
|
||||||
return
|
return
|
||||||
@ -18,29 +19,36 @@ class DB:
|
|||||||
conn.execute("PRAGMA FOREIGN_KEYS = 1")
|
conn.execute("PRAGMA FOREIGN_KEYS = 1")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if in_transaction:
|
||||||
|
self._connection = conn
|
||||||
|
self._transaction_depth += 1
|
||||||
|
conn.execute("BEGIN")
|
||||||
|
|
||||||
yield conn
|
yield conn
|
||||||
|
|
||||||
|
if in_transaction:
|
||||||
|
conn.commit()
|
||||||
|
except Exception:
|
||||||
|
if in_transaction and self._connection:
|
||||||
|
conn.rollback()
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
if in_transaction:
|
||||||
|
self._transaction_depth -= 1
|
||||||
|
if self._transaction_depth == 0:
|
||||||
|
self._connection = None
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def transaction(self):
|
def transaction(self):
|
||||||
"""Transaction context."""
|
"""Transaction context."""
|
||||||
tr_connection = sqlite3.connect(current_app.config["DB_PATH"])
|
with self.connection(in_transaction=True) as conn:
|
||||||
tr_connection.row_factory = sqlite3.Row
|
yield conn
|
||||||
tr_connection.execute("PRAGMA FOREIGN_KEYS = 1")
|
|
||||||
tr_connection.execute("BEGIN")
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
tr_connection.execute("COMMIT")
|
|
||||||
except Exception:
|
|
||||||
tr_connection.execute("ROLLBACK")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
def query(self, sql, *args):
|
def query(self, sql, *args):
|
||||||
"""Executes a query and returns a list of dictionaries."""
|
"""Executes a query and returns a list of dictionaries."""
|
||||||
with self._get_connection() as conn:
|
with self.connection() as conn:
|
||||||
rows = conn.execute(sql, args).fetchall()
|
rows = conn.execute(sql, args).fetchall()
|
||||||
return [dict(row) for row in rows]
|
return [dict(row) for row in rows]
|
||||||
|
|
||||||
@ -56,7 +64,7 @@ class DB:
|
|||||||
RETURNING *
|
RETURNING *
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with self._get_connection() as conn:
|
with self.connection() as conn:
|
||||||
result = conn.execute(sql, values).fetchone()
|
result = conn.execute(sql, values).fetchone()
|
||||||
conn.commit()
|
conn.commit()
|
||||||
return dict(result) if result else None
|
return dict(result) if result else None
|
||||||
@ -64,14 +72,14 @@ class DB:
|
|||||||
|
|
||||||
def execute(self, sql, *args):
|
def execute(self, sql, *args):
|
||||||
"""Executes a query without returning."""
|
"""Executes a query without returning."""
|
||||||
with self._get_connection() as conn:
|
with self.connection() as conn:
|
||||||
conn.execute(sql, args)
|
conn.execute(sql, args)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
def fetch_one(self, sql, *args):
|
def fetch_one(self, sql, *args):
|
||||||
"""Grabs the first row of a query."""
|
"""Grabs the first row of a query."""
|
||||||
with self._get_connection() as conn:
|
with self.connection() as conn:
|
||||||
row = conn.execute(sql, args).fetchone()
|
row = conn.execute(sql, args).fetchone()
|
||||||
return dict(row) if row else None
|
return dict(row) if row else None
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user