From 0c2e920206424aa4338d00239b41ade8fba493b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lera=20Elvo=C3=A9?= Date: Sun, 19 Apr 2026 12:57:59 +0300 Subject: [PATCH] add csrf protection --- app/__init__.py | 16 ++++++++++++- app/auth.py | 36 +++++++++++++++++++++++++++++- app/routes/mod.py | 5 ++++- app/routes/users.py | 15 ++++++++----- app/templates/users/user_page.html | 3 ++- app/util.py | 12 +++++++++- 6 files changed, 77 insertions(+), 10 deletions(-) diff --git a/app/__init__.py b/app/__init__.py index 2cac29f..cb79105 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -10,12 +10,13 @@ from .constants import ( ) from .lib.babycode import babycode_to_html, babycode_to_rssxml, EMOJI, BABYCODE_VERSION from .lib.exceptions import SiteNameMissingException -from .util import get_post_url, dict_to_query_string +from .util import get_post_url, dict_to_query_string, csrf_input, get_csrf_token from datetime import datetime, timezone from flask_caching import Cache import os import time import secrets +import hmac import tomllib import json @@ -238,6 +239,17 @@ def create_app(): session.clear() return redirect(url_for('topics.all_topics')) + @app.before_request + def generate_csrf_token(): + if is_logged_in() and not session.get('csrf'): + rng = secrets.token_bytes(32) + session_key = session['pyrom_session_key'] + message = f'd${len(session_key)}${session_key}@{len(rng)}@{rng.hex()}' + hashed = hmac.digest(app.config['SECRET_KEY'].encode('utf-8'), message.encode('utf-8'), 'SHA256') + csrf_token = f'{hashed.hex()}.{rng.hex()}' + + session['csrf'] = csrf_token + commit = '' with open('.git/refs/heads/main') as f: commit = f.read().strip() @@ -264,6 +276,8 @@ def create_app(): 'is_mod': lambda: is_logged_in() and get_active_user().is_mod(), 'get_active_user': get_active_user, 'get_post_url': get_post_url, + 'csrf_input': csrf_input, + 'get_csrf_token': get_csrf_token, } @app.template_filter('ts_datetime') diff --git a/app/auth.py b/app/auth.py index 4e52dc5..2020960 100644 --- a/app/auth.py +++ b/app/auth.py @@ -1,8 +1,9 @@ -from flask import session, flash, redirect, url_for, abort +from flask import session, flash, redirect, url_for, abort, request, current_app from .models import Sessions, Users from argon2 import PasswordHasher from functools import wraps import secrets +import hmac import time import re @@ -88,3 +89,36 @@ def mod_only(view_func): abort(403) return view_func(*args, **kwargs) return wrapper + +def csrf_verified(view_func): + """ + protects a request with a form against csrf and invalidates the csrf token stored in the session. + + requires @login_requred. + """ + + @wraps(view_func) + def wrapper(*args, **kwargs): + if not session.get('csrf'): + abort(403) + if not request.form.get('csrf'): + abort(403) + + parts = request.form['csrf'].split('.') + if len(parts) != 2: + abort(403) + + given_message = parts[0] + rng = bytes.fromhex(parts[1]) + session_key = session['pyrom_session_key'] + message = f'd${len(session_key)}${session_key}@{len(rng)}@{rng.hex()}' + expected = hmac.digest(current_app.config['SECRET_KEY'].encode('utf-8'), message.encode('utf-8'), 'SHA256').hex() + + if not hmac.compare_digest(given_message, expected): + abort(403) + + session.pop('csrf') + + return view_func(*args, **kwargs) + return wrapper + diff --git a/app/routes/mod.py b/app/routes/mod.py index ce7b5b1..e032cd2 100644 --- a/app/routes/mod.py +++ b/app/routes/mod.py @@ -1,5 +1,5 @@ from flask import Blueprint, abort, redirect, url_for, request, render_template -from ..auth import is_logged_in, get_active_user +from ..auth import is_logged_in, get_active_user, csrf_verified from ..models import Topics, Threads bp = Blueprint('mod', __name__, url_prefix='/mod/') @@ -81,13 +81,16 @@ def sticky_thread(thread_id): return redirect(url_for('threads.thread', slug=thread.slug)) @bp.post('/users//make-guest/') +@csrf_verified def make_user_guest(user_id): return 'stub' @bp.post('/users//make-user/') +@csrf_verified def make_user_regular(user_id): return 'stub' @bp.post('/users//make-mod/') +@csrf_verified def make_user_mod(user_id): return 'stub' diff --git a/app/routes/users.py b/app/routes/users.py index 1bf248a..54dc574 100644 --- a/app/routes/users.py +++ b/app/routes/users.py @@ -2,7 +2,11 @@ from flask import Blueprint, redirect, url_for, render_template, request, sessio from functools import wraps import time -from ..auth import digest, verify, create_session, is_logged_in, parse_username, is_password_valid, login_required +from ..auth import ( + digest, verify, create_session, + is_logged_in, parse_username, is_password_valid, + login_required + ) from ..models import Users from ..constants import PermissionLevel from secrets import compare_digest as compare_timesafe @@ -24,6 +28,11 @@ def redirect_if_logged_in(destination='topics.all_topics'): def log_in(): return render_template('users/log_in.html') +@bp.post('/log-out/') +@login_required +def log_out(): + return 'stub' + @bp.post('/log-in/') @redirect_if_logged_in() def log_in_post(): @@ -124,7 +133,3 @@ def inbox(username): def bookmarks(username): return 'stub' -@bp.post('//log_out/') -@login_required -def log_out(username): - return 'stub' diff --git a/app/templates/users/user_page.html b/app/templates/users/user_page.html index a9dcf59..1ebe825 100644 --- a/app/templates/users/user_page.html +++ b/app/templates/users/user_page.html @@ -9,7 +9,7 @@ {%- if target_user.id == get_active_user().id -%}
Actions -
+
@@ -19,6 +19,7 @@
Moderation actions
+ {{csrf_input() | safe}} {%- if target_user.is_guest() -%} {%- else -%} diff --git a/app/util.py b/app/util.py index 8e823bf..6239dc0 100644 --- a/app/util.py +++ b/app/util.py @@ -1,5 +1,6 @@ -from flask import url_for +from flask import url_for, session from .models import Posts, Threads +from .auth import is_logged_in def get_post_url(post_id, _anchor=False, external=False): post = Posts.find({'id': post_id}) @@ -14,3 +15,12 @@ def get_post_url(post_id, _anchor=False, external=False): def dict_to_query_string(d) -> str: return '?' + '&'.join([f'{key}={str(value)}' for key, value in d.items()]) + +def get_csrf_token(): + if not is_logged_in(): + return '' + + return session.get('csrf', '') + +def csrf_input(): + return f''