initial commit
This commit is contained in:
218
app/db.py
Normal file
218
app/db.py
Normal file
@ -0,0 +1,218 @@
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from flask import current_app
|
||||
|
||||
class DB:
|
||||
def __init__(self):
|
||||
self._transaction_depth = 0
|
||||
self._connection = None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _get_connection(self):
|
||||
if self._connection and self._transaction_depth > 0:
|
||||
yield self._connection
|
||||
return
|
||||
|
||||
conn = sqlite3.connect(current_app.config["DB_PATH"])
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA FOREIGN_KEYS = 1")
|
||||
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
if self._transaction_depth == 0:
|
||||
conn.close()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def transaction(self):
|
||||
"""Transaction context."""
|
||||
self.begin()
|
||||
try:
|
||||
yield
|
||||
self.commit()
|
||||
except Exception:
|
||||
self.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def begin(self):
|
||||
"""Begins a new transaction."""
|
||||
if self._transaction_depth == 0:
|
||||
if not self._connection:
|
||||
self._connection = sqlite3.connect(current_app.config["DB_PATH"])
|
||||
self._connection.row_factory = sqlite3.Row
|
||||
self._connection.execute("PRAGMA FOREIGN_KEYS = 1")
|
||||
self._connection.execute("BEGIN")
|
||||
self._transaction_depth += 1
|
||||
|
||||
|
||||
def commit(self):
|
||||
"""Commits the current transaction."""
|
||||
if self._transaction_depth > 0:
|
||||
self._transaction_depth -= 1
|
||||
if self._transaction_depth == 0:
|
||||
self._connection.commit()
|
||||
|
||||
|
||||
def rollback(self):
|
||||
"""Rolls back the current transaction."""
|
||||
if self._transaction_depth > 0:
|
||||
self._transaction_depth = 0
|
||||
self._connection.rollback()
|
||||
|
||||
|
||||
def query(self, sql, *args):
|
||||
"""Executes a query and returns a list of dictionaries."""
|
||||
with self._get_connection() as conn:
|
||||
rows = conn.execute(sql, args).fetchall()
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
|
||||
def insert(self, table, columns, *values):
|
||||
if isinstance(columns, (list, tuple)):
|
||||
columns = ", ".join(columns)
|
||||
|
||||
placeholders = ", ".join(["?"] * len(values))
|
||||
sql = f"""
|
||||
INSERT INTO {table} ({columns})
|
||||
VALUES ({placeholders})
|
||||
RETURNING *
|
||||
"""
|
||||
|
||||
with self._get_connection() as conn:
|
||||
result = conn.execute(sql, values).fetchone()
|
||||
conn.commit()
|
||||
return dict(result) if result else None
|
||||
|
||||
|
||||
def execute(self, sql, *args):
|
||||
"""Executes a query without returning."""
|
||||
with self._get_connection() as conn:
|
||||
conn.execute(sql, args)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def fetch_one(self, sql, *args):
|
||||
"""Grabs the first row of a query."""
|
||||
with self._get_connection() as conn:
|
||||
row = conn.execute(sql, args).fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
|
||||
class QueryBuilder:
|
||||
def __init__(self, table):
|
||||
self.table = table
|
||||
self._where = {}
|
||||
self._select = "*"
|
||||
self._params = []
|
||||
|
||||
|
||||
def select(self, columns = "*"):
|
||||
self._select = columns
|
||||
return self
|
||||
|
||||
|
||||
def where(self, condition):
|
||||
self._where.update(condition)
|
||||
return self
|
||||
|
||||
|
||||
def build_select(self):
|
||||
sql = f"SELECT {self._select} FROM {self.table}"
|
||||
if self._where:
|
||||
conditions = " AND ".join(f"{k} = ?" for k in self._where.keys())
|
||||
sql += f" WHERE {conditions}"
|
||||
return sql, list(self._where.values())
|
||||
|
||||
|
||||
def build_update(self, data):
|
||||
columns = ", ".join(f"{k} = ?" for k in data.keys())
|
||||
sql = f"UPDATE {self.table} SET {columns}"
|
||||
if self._where:
|
||||
conditions = " AND ".join(f"{k} = ?" for k in self._where.keys())
|
||||
sql += f" WHERE {conditions}"
|
||||
params = list(data.values()) + list(self._where.values())
|
||||
return sql, params
|
||||
|
||||
|
||||
def build_delete(self):
|
||||
sql = f"DELETE FROM {self.table}"
|
||||
if self._where:
|
||||
conditions = " AND ".join(f"{k} = ?" for k in self._where.keys())
|
||||
sql += f" WHERE {conditions}"
|
||||
return sql, list(self._where.values())
|
||||
|
||||
|
||||
def first(self):
|
||||
sql, params = self.build_select()
|
||||
print(sql, params)
|
||||
return db.fetch_one(f"{sql} LIMIT 1", *params)
|
||||
|
||||
|
||||
def all(self):
|
||||
sql, params = self.build_select()
|
||||
return db.query(sql, *params)
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(self, table):
|
||||
self.table = table
|
||||
self._data = {}
|
||||
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._data[key]
|
||||
|
||||
|
||||
def __getattr__(self, key):
|
||||
try:
|
||||
return self._data[key]
|
||||
except KeyError:
|
||||
raise AttributeError(f"No column '{key}'")
|
||||
|
||||
|
||||
@classmethod
|
||||
def find(cls, condition):
|
||||
row = db.QueryBuilder(cls.table)\
|
||||
.where(condition)\
|
||||
.first()
|
||||
if not row:
|
||||
return None
|
||||
instance = cls(cls.table)
|
||||
instance._data = dict(row)
|
||||
return instance
|
||||
|
||||
|
||||
@classmethod
|
||||
def create(cls, values):
|
||||
if not values:
|
||||
return None
|
||||
|
||||
columns = list(values.keys())
|
||||
row = db.insert(cls.table, columns, *values.values())
|
||||
|
||||
if row:
|
||||
instance = cls(cls.table)
|
||||
instance._data = row
|
||||
return instance
|
||||
return None
|
||||
|
||||
|
||||
def update(self, data):
|
||||
qb = db.QueryBuilder(self.table)\
|
||||
.where({"id": self._data["id"]})
|
||||
sql, params = qb.build_update(data)
|
||||
db.execute(sql, *params)
|
||||
self._data.update(data)
|
||||
|
||||
|
||||
def delete(self):
|
||||
qb = db.QueryBuilder(self.table)\
|
||||
.where({"id": self._data["id"]})
|
||||
sql, params = qb.build_delete()
|
||||
db.execute(sql, *params)
|
||||
self._data = {}
|
||||
|
||||
db = DB()
|
Reference in New Issue
Block a user