finally properly handle transactions
This commit is contained in:
		
							
								
								
									
										40
									
								
								app/db.py
									
									
									
									
									
								
							
							
						
						
									
										40
									
								
								app/db.py
									
									
									
									
									
								
							@@ -5,10 +5,11 @@ from flask import current_app
 | 
			
		||||
class DB:
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self._connection = None
 | 
			
		||||
        self._transaction_depth = 0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    @contextmanager
 | 
			
		||||
    def _get_connection(self):
 | 
			
		||||
    def connection(self, in_transaction = False):
 | 
			
		||||
        if self._connection:
 | 
			
		||||
            yield self._connection
 | 
			
		||||
            return
 | 
			
		||||
@@ -18,29 +19,36 @@ class DB:
 | 
			
		||||
        conn.execute("PRAGMA FOREIGN_KEYS = 1")
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            if in_transaction:
 | 
			
		||||
                self._connection = conn
 | 
			
		||||
                self._transaction_depth += 1
 | 
			
		||||
                conn.execute("BEGIN")
 | 
			
		||||
 | 
			
		||||
            yield conn
 | 
			
		||||
 | 
			
		||||
            if in_transaction:
 | 
			
		||||
                conn.commit()
 | 
			
		||||
        except Exception:
 | 
			
		||||
            if in_transaction and self._connection:
 | 
			
		||||
                conn.rollback()
 | 
			
		||||
        finally:
 | 
			
		||||
            conn.close()
 | 
			
		||||
            if in_transaction:
 | 
			
		||||
                self._transaction_depth -= 1
 | 
			
		||||
                if self._transaction_depth == 0:
 | 
			
		||||
                    self._connection = None
 | 
			
		||||
                    conn.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    @contextmanager
 | 
			
		||||
    def transaction(self):
 | 
			
		||||
        """Transaction context."""
 | 
			
		||||
        tr_connection = sqlite3.connect(current_app.config["DB_PATH"])
 | 
			
		||||
        tr_connection.row_factory = sqlite3.Row
 | 
			
		||||
        tr_connection.execute("PRAGMA FOREIGN_KEYS = 1")
 | 
			
		||||
        tr_connection.execute("BEGIN")
 | 
			
		||||
        try:
 | 
			
		||||
            yield
 | 
			
		||||
            tr_connection.execute("COMMIT")
 | 
			
		||||
        except Exception:
 | 
			
		||||
            tr_connection.execute("ROLLBACK")
 | 
			
		||||
            raise
 | 
			
		||||
        with self.connection(in_transaction=True) as conn:
 | 
			
		||||
            yield conn
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    def query(self, sql, *args):
 | 
			
		||||
        """Executes a query and returns a list of dictionaries."""
 | 
			
		||||
        with self._get_connection() as conn:
 | 
			
		||||
        with self.connection() as conn:
 | 
			
		||||
            rows = conn.execute(sql, args).fetchall()
 | 
			
		||||
            return [dict(row) for row in rows]
 | 
			
		||||
 | 
			
		||||
@@ -56,7 +64,7 @@ class DB:
 | 
			
		||||
            RETURNING *
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        with self._get_connection() as conn:
 | 
			
		||||
        with self.connection() as conn:
 | 
			
		||||
            result = conn.execute(sql, values).fetchone()
 | 
			
		||||
            conn.commit()
 | 
			
		||||
            return dict(result) if result else None
 | 
			
		||||
@@ -64,14 +72,14 @@ class DB:
 | 
			
		||||
 | 
			
		||||
    def execute(self, sql, *args):
 | 
			
		||||
        """Executes a query without returning."""
 | 
			
		||||
        with self._get_connection() as conn:
 | 
			
		||||
        with self.connection() as conn:
 | 
			
		||||
            conn.execute(sql, args)
 | 
			
		||||
            conn.commit()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    def fetch_one(self, sql, *args):
 | 
			
		||||
        """Grabs the first row of a query."""
 | 
			
		||||
        with self._get_connection() as conn:
 | 
			
		||||
        with self.connection() as conn:
 | 
			
		||||
            row = conn.execute(sql, args).fetchone()
 | 
			
		||||
            return dict(row) if row else None
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user