pyrom/app/db.py

229 lines
6.0 KiB
Python

import sqlite3
import threading
from contextlib import contextmanager

from flask import current_app
class DB:
    """Small helper around the SQLite database named in the Flask config.

    Each helper call opens a connection to ``current_app.config["DB_PATH"]``
    (rows as ``sqlite3.Row``, foreign keys enabled) and closes it when done.
    Inside :meth:`transaction`, the helpers instead reuse the transaction's
    connection, so all of their statements commit or roll back together.
    """

    def __init__(self):
        # This object is used as a module-level singleton and may be reached
        # from several request-handling threads (e.g. Flask's threaded dev
        # server), while a sqlite3 connection may by default only be used by
        # the thread that created it. Keep open-transaction state per thread
        # so one thread never joins another thread's transaction.
        self._local = threading.local()

    @property
    def _connection(self):
        """Connection of this thread's open transaction, or None."""
        return getattr(self._local, "connection", None)

    @_connection.setter
    def _connection(self, conn):
        self._local.connection = conn

    @property
    def _transaction_depth(self):
        """Number of transactions this thread currently has open."""
        return getattr(self._local, "depth", 0)

    @_transaction_depth.setter
    def _transaction_depth(self, depth):
        self._local.depth = depth

    @contextmanager
    def connection(self, in_transaction=False):
        """Yield a sqlite3 connection with ``Row`` rows and foreign keys on.

        If this thread already has a transaction open, that transaction's
        connection is yielded unchanged; committing or rolling back is left to
        the transaction. Otherwise a new connection is opened and is always
        closed on exit.

        Args:
            in_transaction: When True (and no transaction is open yet), issue
                ``BEGIN``, share the connection with nested calls, commit on
                normal exit and roll back if the ``with`` body raises.

        Raises:
            Any exception raised by the ``with`` body; it is re-raised after
            the rollback.
        """
        shared = self._connection
        if shared is not None:
            # Join the open transaction; its outermost ``with`` decides
            # whether to commit or roll back.
            yield shared
            return

        conn = sqlite3.connect(current_app.config["DB_PATH"])
        try:
            conn.row_factory = sqlite3.Row
            conn.execute("PRAGMA FOREIGN_KEYS = 1")
            if not in_transaction:
                yield conn
                return

            self._connection = conn
            self._transaction_depth += 1
            try:
                conn.execute("BEGIN")
                yield conn
                conn.commit()
            except BaseException:
                conn.rollback()
                # Must propagate: swallowing here would make @contextmanager
                # suppress the error and the caller would see "success".
                raise
            finally:
                self._transaction_depth -= 1
                if self._transaction_depth == 0:
                    self._connection = None
        finally:
            # Also runs if the PRAGMA/BEGIN fails, so the handle never leaks.
            conn.close()

    @contextmanager
    def transaction(self):
        """Run the ``with`` body atomically and yield its connection.

        :meth:`query`, :meth:`insert`, :meth:`execute` and :meth:`fetch_one`
        called inside the body reuse this connection, so their changes are
        committed together on normal exit and rolled back if the body raises.
        Nested ``transaction()`` calls join the outer transaction.
        """
        with self.connection(in_transaction=True) as conn:
            yield conn

    def _commit_unless_in_transaction(self, conn):
        """Commit *conn* unless it belongs to this thread's open transaction.

        Committing a transaction's connection here would make earlier
        statements permanent even if the transaction later rolls back.
        """
        if conn is not self._connection:
            conn.commit()

    def query(self, sql, *args):
        """Execute *sql* with positional *args*; return all rows as dicts."""
        with self.connection() as conn:
            rows = conn.execute(sql, args).fetchall()
        return [dict(row) for row in rows]

    def insert(self, table, columns, *values):
        """Insert one row and return it as a dict.

        ``RETURNING *`` requires SQLite 3.35 or newer.

        Args:
            table: Table name, interpolated into the SQL verbatim.
            columns: Column names as a list/tuple, or an already
                comma-separated string; interpolated verbatim.
            *values: One value per column, bound as ``?`` parameters.

        Returns:
            The inserted row as a dict, or None if no row came back.
        """
        if isinstance(columns, (list, tuple)):
            columns = ", ".join(columns)
        placeholders = ", ".join(["?"] * len(values))
        sql = (
            f"INSERT INTO {table} ({columns}) "
            f"VALUES ({placeholders}) "
            "RETURNING *"
        )
        with self.connection() as conn:
            result = conn.execute(sql, values).fetchone()
            self._commit_unless_in_transaction(conn)
        return dict(result) if result else None

    def execute(self, sql, *args):
        """Execute *sql* with positional *args*, discarding any result rows.

        Commits immediately unless called inside :meth:`transaction`.
        """
        with self.connection() as conn:
            conn.execute(sql, args)
            self._commit_unless_in_transaction(conn)

    def fetch_one(self, sql, *args):
        """Return the first row of *sql* as a dict, or None if there is none."""
        with self.connection() as conn:
            row = conn.execute(sql, args).fetchone()
        return dict(row) if row else None

    class QueryBuilder:
        """Assemble simple SELECT/UPDATE/DELETE statements for one table.

        Conditions are ANDed together. Values are always bound as ``?``
        parameters, but table and column names (and operators) are
        interpolated verbatim and must never come from untrusted input.
        """

        # "col = NULL" is never true in SQL; these operators are rewritten to
        # NULL tests when the compared value is None.
        _NULL_COMPARISONS = {
            "=": "IS NULL",
            "==": "IS NULL",
            "!=": "IS NOT NULL",
            "<>": "IS NOT NULL",
        }

        def __init__(self, table):
            self.table = table
            self._where = []  # (column, operator, value) triples, ANDed
            self._select = "*"

        def _build_where(self):
            """Return ``(" WHERE ...", params)``, or ``("", [])`` if unfiltered."""
            if not self._where:
                return "", []
            conditions = []
            params = []
            for col, op, val in self._where:
                if val is None and op in self._NULL_COMPARISONS:
                    conditions.append(f"{col} {self._NULL_COMPARISONS[op]}")
                else:
                    conditions.append(f"{col} {op} ?")
                    params.append(val)
            return " WHERE " + " AND ".join(conditions), params

        def select(self, columns="*"):
            """Set the SELECT list (a raw SQL fragment); return ``self``."""
            self._select = columns
            return self

        def where(self, condition, operator="="):
            """Add filter conditions and return ``self`` for chaining.

            Args:
                condition: A dict ``{column: value}`` (each pair compared
                    with *operator*), a list of ``(column, operator, value)``
                    triples, or None to add nothing.
                operator: Comparison operator applied to dict conditions.

            Raises:
                TypeError: If *condition* is of any other type. Such input
                    used to be ignored, silently leaving the query unfiltered.
            """
            if condition is None:
                return self
            if isinstance(condition, dict):
                self._where.extend(
                    (key, operator, value) for key, value in condition.items()
                )
            elif isinstance(condition, list):
                self._where.extend(condition)
            else:
                raise TypeError(
                    "where() expects a dict, a list of (column, operator, value) "
                    f"tuples, or None; got {type(condition).__name__}"
                )
            return self

        def build_select(self):
            """Return ``(sql, params)`` for a SELECT over this builder."""
            sql = f"SELECT {self._select} FROM {self.table}"
            where_clause, params = self._build_where()
            return sql + where_clause, params

        def build_update(self, data):
            """Return ``(sql, params)`` setting each ``{column: value}`` in *data*."""
            columns = ", ".join(f"{k} = ?" for k in data)
            sql = f"UPDATE {self.table} SET {columns}"
            where_clause, where_params = self._build_where()
            params = list(data.values()) + list(where_params)
            return sql + where_clause, params

        def build_delete(self):
            """Return ``(sql, params)`` deleting the matching rows."""
            sql = f"DELETE FROM {self.table}"
            where_clause, params = self._build_where()
            return sql + where_clause, params

        def first(self):
            """Run the SELECT with ``LIMIT 1`` via ``db``; return a dict or None."""
            sql, params = self.build_select()
            return db.fetch_one(f"{sql} LIMIT 1", *params)

        def all(self):
            """Run the SELECT via ``db``; return every row as a dict."""
            sql, params = self.build_select()
            return db.query(sql, *params)
class Model:
    """Minimal active-record style base class on top of the ``db`` helper.

    The classmethods read ``cls.table``, which Model itself does not define,
    so they are meant to be called on a subclass that sets ``table`` as a
    class attribute. An instance wraps one row in ``_data``; its columns are
    readable as ``obj["col"]`` or ``obj.col``. :meth:`update` and
    :meth:`delete` locate the row by its ``id`` column.
    """

    def __init__(self, table):
        self.table = table
        self._data = {}  # column name -> value for the wrapped row

    def __getitem__(self, key):
        """Return column *key*; raises KeyError if it is absent."""
        return self._data[key]

    def __getattr__(self, key):
        """Return column *key* as an attribute; raises AttributeError if absent."""
        # __getattr__ only runs when normal lookup fails. Read _data through
        # __dict__: touching self._data on an instance that lacks it (during
        # copy/unpickling, or one made with __new__) would re-enter
        # __getattr__ and recurse until RecursionError.
        data = self.__dict__.get("_data", {})
        try:
            return data[key]
        except KeyError:
            raise AttributeError(f"No column '{key}'") from None

    @classmethod
    def _from_row(cls, row):
        """Build an instance of *cls* holding a copy of the mapping *row*."""
        instance = cls(cls.table)
        instance._data = dict(row)
        return instance

    @classmethod
    def find(cls, condition):
        """Return an instance for the first matching row, or None.

        Args:
            condition: A ``{column: value}`` dict or a list of
                ``(column, operator, value)`` triples (passed to
                ``QueryBuilder.where``). Any other non-None value is looked
                up as the row's ``id``.
        """
        if condition is not None and not isinstance(condition, (dict, list)):
            # A bare value used to be silently ignored, so find(5) returned
            # the table's first row instead of the row whose id is 5.
            condition = {"id": condition}
        row = db.QueryBuilder(cls.table).where(condition).first()
        return cls._from_row(row) if row else None

    @classmethod
    def create(cls, values):
        """Insert *values* (``{column: value}``) as a new row.

        Returns:
            An instance for the inserted row, or None if *values* is empty
            or the insert returned no row.
        """
        if not values:
            return None
        row = db.insert(cls.table, list(values.keys()), *values.values())
        return cls._from_row(row) if row else None

    @classmethod
    def count(cls, conditions=None):
        """Return how many rows match *conditions* (all rows when None)."""
        qb = db.QueryBuilder(cls.table).select("COUNT(*) AS c")
        if conditions is not None:
            qb.where(conditions)
        result = qb.first()
        return result["c"] if result else 0

    @classmethod
    def select(cls, sel="*"):
        """Return every row, selecting *sel*, as plain dicts (not instances)."""
        return db.QueryBuilder(cls.table).select(sel).all() or []

    def update(self, data):
        """Write *data* (``{column: value}``) to this row and merge it locally.

        An empty *data* is a no-op: ``UPDATE ... SET`` with no columns is
        invalid SQL.
        """
        if not data:
            return
        qb = db.QueryBuilder(self.table).where({"id": self._data["id"]})
        sql, params = qb.build_update(data)
        db.execute(sql, *params)
        self._data.update(data)

    def delete(self):
        """Delete this row (matched by ``id``) and clear the local data."""
        qb = db.QueryBuilder(self.table).where({"id": self._data["id"]})
        sql, params = qb.build_delete()
        db.execute(sql, *params)
        self._data = {}
# Module-level database handle. QueryBuilder.first/all and the Model methods
# in this module reach the database through this name (db.fetch_one,
# db.query, db.insert, db.execute, db.QueryBuilder).
db = DB()