229 lines
6.0 KiB
Python
229 lines
6.0 KiB
Python
import sqlite3
|
|
from contextlib import contextmanager
|
|
from flask import current_app
|
|
|
|
class DB:
    """Thin SQLite access layer with optional transaction scoping.

    Each call opens a short-lived connection unless a transaction is active,
    in which case every helper reuses the transaction's connection so that
    all statements commit or roll back together.

    NOTE(review): the active transaction connection is stored on the instance,
    so one DB object presumably must not be shared between threads while a
    transaction is open -- confirm against how the application serves requests.
    """

    def __init__(self, path=None):
        """Create the access layer.

        Args:
            path: Optional database file path. When omitted, the path is read
                from ``current_app.config["DB_PATH"]`` each time a connection
                is opened.
        """
        self._path = path
        # Connection shared by the currently open transaction, if any.
        self._connection = None
        self._transaction_depth = 0

    def _database_path(self):
        """Return the explicit path, falling back to the Flask app config."""
        if self._path is not None:
            return self._path
        return current_app.config["DB_PATH"]

    def _commit_unless_in_transaction(self, conn):
        """Commit ``conn`` only when it is not the open transaction's connection.

        Committing the shared connection here would end the surrounding
        transaction early and make a later rollback ineffective.
        """
        if conn is not self._connection:
            conn.commit()

    @contextmanager
    def connection(self, in_transaction=False):
        """Yield a SQLite connection, optionally wrapped in a transaction.

        If a transaction is already open on this object, its connection is
        yielded as-is and the caller joins that transaction; the outermost
        block decides about commit and rollback.

        Args:
            in_transaction: When true (and no transaction is open yet), issue
                ``BEGIN``, commit on normal exit and roll back on error.

        Raises:
            Exception: Whatever the body or SQLite raises is propagated after
                cleanup; errors are never swallowed.
        """
        if self._connection is not None:
            yield self._connection
            return

        conn = sqlite3.connect(self._database_path())
        # Tracks whether this call registered itself as the open transaction,
        # so cleanup only undoes bookkeeping that actually happened.
        started = False
        try:
            conn.row_factory = sqlite3.Row
            conn.execute("PRAGMA FOREIGN_KEYS = 1")

            if in_transaction:
                self._connection = conn
                self._transaction_depth += 1
                started = True
                conn.execute("BEGIN")

            yield conn

            if started:
                conn.commit()
        except Exception:
            if started:
                conn.rollback()
            # Re-raise so callers see the failure instead of a silent None.
            raise
        finally:
            if started:
                self._transaction_depth -= 1
                if self._transaction_depth == 0:
                    self._connection = None
            conn.close()

    @contextmanager
    def transaction(self):
        """Transaction context.

        Statements run through this object inside the ``with`` block are
        committed together on success and rolled back if the block raises.
        """
        with self.connection(in_transaction=True) as conn:
            yield conn

    def query(self, sql, *args):
        """Executes a query and returns a list of dictionaries."""
        with self.connection() as conn:
            rows = conn.execute(sql, args).fetchall()
            return [dict(row) for row in rows]

    def insert(self, table, columns, *values):
        """Insert one row and return it as a dictionary.

        Args:
            table: Table name. Interpolated into the SQL text, so it must come
                from trusted code, never from user input.
            columns: Column names as a list/tuple or a ready-made
                comma-separated string (interpolated, same caveat).
            *values: One value per column, bound as parameters.

        Returns:
            The inserted row (via ``RETURNING *``) as a dict, or None when the
            statement returned no row.
        """
        if isinstance(columns, (list, tuple)):
            columns = ", ".join(columns)

        placeholders = ", ".join(["?"] * len(values))
        sql = f"""
            INSERT INTO {table} ({columns})
            VALUES ({placeholders})
            RETURNING *
        """

        with self.connection() as conn:
            # fetchall() runs the statement to completion before any commit.
            rows = conn.execute(sql, values).fetchall()
            self._commit_unless_in_transaction(conn)
            return dict(rows[0]) if rows else None

    def execute(self, sql, *args):
        """Executes a query without returning."""
        with self.connection() as conn:
            conn.execute(sql, args)
            self._commit_unless_in_transaction(conn)

    def fetch_one(self, sql, *args):
        """Grabs the first row of a query, or None when there is no row."""
        with self.connection() as conn:
            row = conn.execute(sql, args).fetchone()
            return dict(row) if row else None

    class QueryBuilder:
        """Builds simple single-table SELECT / UPDATE / DELETE statements.

        Table, column and operator strings are interpolated into the SQL text;
        only the compared values are bound as parameters. Pass identifiers and
        operators from trusted code only.
        """

        def __init__(self, table):
            self.table = table
            self._where = []  # list of (column, operator, value) tuples
            self._select = "*"

        def _build_where(self):
            """Return ``(where_sql, params)``; both empty without conditions."""
            if not self._where:
                return "", []

            conditions = [f"{col} {op} ?" for col, op, _ in self._where]
            params = [val for _, _, val in self._where]
            return " WHERE " + " AND ".join(conditions), params

        def select(self, columns="*"):
            """Set the column list of the SELECT and return self for chaining."""
            self._select = columns
            return self

        def where(self, condition, operator="="):
            """Add AND-ed conditions and return self for chaining.

            Args:
                condition: A dict mapping column to value (compared with
                    ``operator``), a list of ``(column, operator, value)``
                    triples, or None for "no additional condition".
                operator: Comparison used for every entry of a dict condition.

            Raises:
                TypeError: For any other condition type. Ignoring it would
                    silently drop the filter from an UPDATE or DELETE.
                ValueError: If a list entry is not a 3-item sequence.
            """
            if condition is None:
                return self

            if isinstance(condition, dict):
                for column, value in condition.items():
                    self._where.append((column, operator, value))
            elif isinstance(condition, list):
                # Validate every triple before storing any of them.
                triples = []
                for item in condition:
                    column, op, value = item
                    triples.append((column, op, value))
                self._where.extend(triples)
            else:
                raise TypeError(
                    "condition must be a dict, a list of "
                    "(column, operator, value) tuples, or None"
                )
            return self

        def build_select(self):
            """Return ``(sql, params)`` for the SELECT statement."""
            sql = f"SELECT {self._select} FROM {self.table}"
            where_clause, params = self._build_where()
            return sql + where_clause, params

        def build_update(self, data):
            """Return ``(sql, params)`` for an UPDATE setting ``data``.

            Raises:
                ValueError: If ``data`` is empty (there is nothing to SET).
            """
            if not data:
                raise ValueError("build_update() requires at least one column")

            columns = ", ".join(f"{k} = ?" for k in data)
            sql = f"UPDATE {self.table} SET {columns}"
            where_clause, where_params = self._build_where()
            # SET values come first, matching placeholder order in the SQL.
            params = list(data.values()) + list(where_params)
            return sql + where_clause, params

        def build_delete(self):
            """Return ``(sql, params)`` for the DELETE statement."""
            sql = f"DELETE FROM {self.table}"
            where_clause, params = self._build_where()
            return sql + where_clause, params

        def first(self):
            """Run the SELECT with ``LIMIT 1``; return a dict or None."""
            sql, params = self.build_select()
            return db.fetch_one(f"{sql} LIMIT 1", *params)

        def all(self):
            """Run the SELECT and return every row as a list of dicts."""
            sql, params = self.build_select()
            return db.query(sql, *params)

    class Model:
        """Minimal row wrapper around one table.

        The classmethods read ``cls.table``, which this base class does not
        define; presumably subclasses set a class-level ``table`` attribute.
        Column values are reachable as ``row["col"]`` or ``row.col``; real
        attributes and methods (``table``, ``update``, ...) take precedence
        over a column with the same name.
        """

        def __init__(self, table):
            self.table = table
            self._data = {}

        def __getitem__(self, key):
            return self._data[key]

        def __getattr__(self, key):
            # Only called when normal lookup fails. Read _data through
            # __dict__ so a missing _data cannot re-enter this method.
            data = self.__dict__.get("_data")
            if data is not None and key in data:
                return data[key]
            raise AttributeError(f"No column '{key}'")

        @classmethod
        def _from_row(cls, row):
            """Wrap an already-fetched row dictionary in a model instance."""
            instance = cls(cls.table)
            instance._data = dict(row)
            return instance

        @classmethod
        def find(cls, condition):
            """Return the first row matching ``condition`` or None.

            ``condition`` takes the forms accepted by ``QueryBuilder.where``.
            """
            row = db.QueryBuilder(cls.table).where(condition).first()
            if not row:
                return None
            return cls._from_row(row)

        @classmethod
        def create(cls, values):
            """Insert ``values`` (column -> value) and return the new instance.

            Returns None when ``values`` is empty or no row comes back.
            """
            if not values:
                return None

            columns = list(values.keys())
            row = db.insert(cls.table, columns, *values.values())
            return cls._from_row(row) if row else None

        @classmethod
        def count(cls, conditions=None):
            """Return the number of rows, optionally filtered by conditions."""
            qb = db.QueryBuilder(cls.table).select("COUNT(*) AS c")
            if conditions is not None:
                qb.where(conditions)

            result = qb.first()
            return result["c"] if result else 0

        @classmethod
        def select(cls, sel="*"):
            """Return all rows as a list of plain dicts (not model instances)."""
            return db.QueryBuilder(cls.table).select(sel).all()

        def update(self, data):
            """Write ``data`` to this row (matched by ``id``) and cache it.

            An empty ``data`` is a no-op.
            """
            if not data:
                return

            qb = db.QueryBuilder(self.table).where({"id": self._data["id"]})
            sql, params = qb.build_update(data)
            db.execute(sql, *params)
            self._data.update(data)

        def delete(self):
            """Delete this row (matched by ``id``) and clear the cached data."""
            qb = db.QueryBuilder(self.table).where({"id": self._data["id"]})
            sql, params = qb.build_delete()
            db.execute(sql, *params)
            self._data = {}
|
|
|
|
# Module-level DB instance. QueryBuilder.first()/all() and the Model methods
# in this file look it up through the global name `db`.
db = DB()
|