add csrf protection
This commit is contained in:
@@ -10,12 +10,13 @@ from .constants import (
|
||||
)
|
||||
from .lib.babycode import babycode_to_html, babycode_to_rssxml, EMOJI, BABYCODE_VERSION
|
||||
from .lib.exceptions import SiteNameMissingException
|
||||
from .util import get_post_url, dict_to_query_string
|
||||
from .util import get_post_url, dict_to_query_string, csrf_input, get_csrf_token
|
||||
from datetime import datetime, timezone
|
||||
from flask_caching import Cache
|
||||
import os
|
||||
import time
|
||||
import secrets
|
||||
import hmac
|
||||
import tomllib
|
||||
import json
|
||||
|
||||
@@ -238,6 +239,17 @@ def create_app():
|
||||
session.clear()
|
||||
return redirect(url_for('topics.all_topics'))
|
||||
|
||||
@app.before_request
|
||||
def generate_csrf_token():
|
||||
if is_logged_in() and not session.get('csrf'):
|
||||
rng = secrets.token_bytes(32)
|
||||
session_key = session['pyrom_session_key']
|
||||
message = f'd${len(session_key)}${session_key}@{len(rng)}@{rng.hex()}'
|
||||
hashed = hmac.digest(app.config['SECRET_KEY'].encode('utf-8'), message.encode('utf-8'), 'SHA256')
|
||||
csrf_token = f'{hashed.hex()}.{rng.hex()}'
|
||||
|
||||
session['csrf'] = csrf_token
|
||||
|
||||
commit = ''
|
||||
with open('.git/refs/heads/main') as f:
|
||||
commit = f.read().strip()
|
||||
@@ -264,6 +276,8 @@ def create_app():
|
||||
'is_mod': lambda: is_logged_in() and get_active_user().is_mod(),
|
||||
'get_active_user': get_active_user,
|
||||
'get_post_url': get_post_url,
|
||||
'csrf_input': csrf_input,
|
||||
'get_csrf_token': get_csrf_token,
|
||||
}
|
||||
|
||||
@app.template_filter('ts_datetime')
|
||||
|
||||
36
app/auth.py
36
app/auth.py
@@ -1,8 +1,9 @@
|
||||
from flask import session, flash, redirect, url_for, abort
|
||||
from flask import session, flash, redirect, url_for, abort, request, current_app
|
||||
from .models import Sessions, Users
|
||||
from argon2 import PasswordHasher
|
||||
from functools import wraps
|
||||
import secrets
|
||||
import hmac
|
||||
import time
|
||||
import re
|
||||
|
||||
@@ -88,3 +89,36 @@ def mod_only(view_func):
|
||||
abort(403)
|
||||
return view_func(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
def csrf_verified(view_func):
|
||||
"""
|
||||
protects a request with a form against csrf and invalidates the csrf token stored in the session.
|
||||
|
||||
requires @login_requred.
|
||||
"""
|
||||
|
||||
@wraps(view_func)
|
||||
def wrapper(*args, **kwargs):
|
||||
if not session.get('csrf'):
|
||||
abort(403)
|
||||
if not request.form.get('csrf'):
|
||||
abort(403)
|
||||
|
||||
parts = request.form['csrf'].split('.')
|
||||
if len(parts) != 2:
|
||||
abort(403)
|
||||
|
||||
given_message = parts[0]
|
||||
rng = bytes.fromhex(parts[1])
|
||||
session_key = session['pyrom_session_key']
|
||||
message = f'd${len(session_key)}${session_key}@{len(rng)}@{rng.hex()}'
|
||||
expected = hmac.digest(current_app.config['SECRET_KEY'].encode('utf-8'), message.encode('utf-8'), 'SHA256').hex()
|
||||
|
||||
if not hmac.compare_digest(given_message, expected):
|
||||
abort(403)
|
||||
|
||||
session.pop('csrf')
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from flask import Blueprint, abort, redirect, url_for, request, render_template
|
||||
from ..auth import is_logged_in, get_active_user
|
||||
from ..auth import is_logged_in, get_active_user, csrf_verified
|
||||
from ..models import Topics, Threads
|
||||
bp = Blueprint('mod', __name__, url_prefix='/mod/')
|
||||
|
||||
@@ -81,13 +81,16 @@ def sticky_thread(thread_id):
|
||||
return redirect(url_for('threads.thread', slug=thread.slug))
|
||||
|
||||
@bp.post('/users/<int:user_id>/make-guest/')
|
||||
@csrf_verified
|
||||
def make_user_guest(user_id):
|
||||
return 'stub'
|
||||
|
||||
@bp.post('/users/<int:user_id>/make-user/')
|
||||
@csrf_verified
|
||||
def make_user_regular(user_id):
|
||||
return 'stub'
|
||||
|
||||
@bp.post('/users/<int:user_id>/make-mod/')
|
||||
@csrf_verified
|
||||
def make_user_mod(user_id):
|
||||
return 'stub'
|
||||
|
||||
@@ -2,7 +2,11 @@ from flask import Blueprint, redirect, url_for, render_template, request, sessio
|
||||
from functools import wraps
|
||||
import time
|
||||
|
||||
from ..auth import digest, verify, create_session, is_logged_in, parse_username, is_password_valid, login_required
|
||||
from ..auth import (
|
||||
digest, verify, create_session,
|
||||
is_logged_in, parse_username, is_password_valid,
|
||||
login_required
|
||||
)
|
||||
from ..models import Users
|
||||
from ..constants import PermissionLevel
|
||||
from secrets import compare_digest as compare_timesafe
|
||||
@@ -24,6 +28,11 @@ def redirect_if_logged_in(destination='topics.all_topics'):
|
||||
def log_in():
|
||||
return render_template('users/log_in.html')
|
||||
|
||||
@bp.post('/log-out/')
|
||||
@login_required
|
||||
def log_out():
|
||||
return 'stub'
|
||||
|
||||
@bp.post('/log-in/')
|
||||
@redirect_if_logged_in()
|
||||
def log_in_post():
|
||||
@@ -124,7 +133,3 @@ def inbox(username):
|
||||
def bookmarks(username):
|
||||
return 'stub'
|
||||
|
||||
@bp.post('/<username>/log_out/')
|
||||
@login_required
|
||||
def log_out(username):
|
||||
return 'stub'
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
{%- if target_user.id == get_active_user().id -%}
|
||||
<fieldset class="plank even no-shadow minimal thread-actions">
|
||||
<legend>Actions</legend>
|
||||
<form action="{{url_for('users.log_out', username=target_user.username)}}" method="POST">
|
||||
<form action="{{url_for('users.log_out')}}" method="POST">
|
||||
<input type="submit" class="warn" value="Log out">
|
||||
</form>
|
||||
</fieldset>
|
||||
@@ -19,6 +19,7 @@
|
||||
<fieldset class="plank even no-shadow minimal thread-actions">
|
||||
<legend>Moderation actions</legend>
|
||||
<form method="POST">
|
||||
{{csrf_input() | safe}}
|
||||
{%- if target_user.is_guest() -%}
|
||||
<input class="warn" type="submit" value="Approve user" formaction="{{url_for('mod.make_user_regular', user_id=target_user.id)}}">
|
||||
{%- else -%}
|
||||
|
||||
12
app/util.py
12
app/util.py
@@ -1,5 +1,6 @@
|
||||
from flask import url_for
|
||||
from flask import url_for, session
|
||||
from .models import Posts, Threads
|
||||
from .auth import is_logged_in
|
||||
|
||||
def get_post_url(post_id, _anchor=False, external=False):
|
||||
post = Posts.find({'id': post_id})
|
||||
@@ -14,3 +15,12 @@ def get_post_url(post_id, _anchor=False, external=False):
|
||||
|
||||
def dict_to_query_string(d) -> str:
|
||||
return '?' + '&'.join([f'{key}={str(value)}' for key, value in d.items()])
|
||||
|
||||
def get_csrf_token():
|
||||
if not is_logged_in():
|
||||
return ''
|
||||
|
||||
return session.get('csrf', '')
|
||||
|
||||
def csrf_input():
|
||||
return f'<input type="hidden" name="csrf" value="{get_csrf_token()}">'
|
||||
|
||||
Reference in New Issue
Block a user