Skip to content

Commit

Permalink
Made flash messages work better with page security decorators. Added …
Browse files Browse the repository at this point in the history
…testing for this.
  • Loading branch information
CheeseCake87 committed Sep 20, 2023
1 parent 91c0bc0 commit e475f12
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 6 deletions.
19 changes: 18 additions & 1 deletion app/blueprints/tests/routes/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,24 @@ def set_permissions_value():
return render_template(bp.tmpl("security.html"), permission="admin")


@bp.route("/already-logged-in/bool/with-flash", methods=["GET"])
@login_check("logged_in", True, pass_endpoint="tests.already_logged_in", message="Already logged in")
def already_logged_in_bool():
return render_template(bp.tmpl("security.html"), logged_in_on=session.get("logged_in"))


@bp.route("/must-be-logged-in/bool", methods=["GET"])
@login_check("logged_in", True, "tests.login_failed")
def must_be_logged_in_bool():
return render_template(bp.tmpl("security.html"), logged_in_on=True)


@bp.route("/must-be-logged-in/bool/with-flash", methods=["GET"])
@login_check("logged_in", True, "tests.login_failed", message="Login needed")
def must_be_logged_in_bool_with_flash():
return render_template(bp.tmpl("security.html"), logged_in_on=True)


@bp.route("/must-be-logged-in/str", methods=["GET"])
@login_check("logged_in", "li", "tests.login_failed")
def must_be_logged_in_str():
Expand Down Expand Up @@ -78,7 +90,12 @@ def permission_check_adv():

@bp.route("/login-failed", methods=["GET"])
def login_failed():
return "Login failed"
return render_template(bp.tmpl("login_failed.html"))


@bp.route("/already-logged-in", methods=["GET"])
def already_logged_in():
return render_template(bp.tmpl("already_logged_in.html"))


@bp.route("/permission-failed", methods=["GET"])
Expand Down
9 changes: 9 additions & 0 deletions app/blueprints/tests/templates/tests/already_logged_in.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
<p>Already logged in</p>

{% with messages = get_flashed_messages(with_categories=true) %}
{% if messages %}
{% for category, message in messages %}
{{ category }}:{{ message }},
{% endfor %}
{% endif %}
{% endwith %}
9 changes: 9 additions & 0 deletions app/blueprints/tests/templates/tests/login_failed.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
<p>Login failed</p>

{% with messages = get_flashed_messages(with_categories=true) %}
{% if messages %}
{% for category, message in messages %}
{{ category }}:{{ message }},
{% endfor %}
{% endif %}
{% endwith %}
1 change: 0 additions & 1 deletion app/blueprints/tests/templates/tests/security.html
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
<h3>Security</h3>


<p>Logged_in {{ session.get('logged_in') }}</p>
<p>permissions {{ session.get('permissions') }}</p>
29 changes: 25 additions & 4 deletions src/flask_imp/security.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing as t
from functools import wraps
from functools import partial

from flask import flash, abort
from flask import redirect
Expand Down Expand Up @@ -66,12 +67,23 @@ def login_page():
"""

def login_check_wrapper(func):

@wraps(func)
def inner(*args, **kwargs):
skey = session.get(session_key)

def setup_flash(_message, _message_category):
if _message:
partial_flash = partial(flash, _message)
if _message_category:
partial_flash(_message_category)
else:
partial_flash()

if skey is None:
if fail_endpoint:
setup_flash(message, message_category)

if endpoint_kwargs:
return redirect(url_for(fail_endpoint, **endpoint_kwargs))

Expand All @@ -82,8 +94,7 @@ def inner(*args, **kwargs):
if skey is not None:
if _check_against_values_allowed(skey, values_allowed):
if pass_endpoint:
if message:
flash(message, message_category)
setup_flash(message, message_category)

if endpoint_kwargs:
return redirect(url_for(pass_endpoint, **endpoint_kwargs))
Expand All @@ -93,6 +104,8 @@ def inner(*args, **kwargs):
return func(*args, **kwargs)

if fail_endpoint:
setup_flash(message, message_category)

if endpoint_kwargs:
return redirect(url_for(fail_endpoint, **endpoint_kwargs))

Expand Down Expand Up @@ -141,14 +154,22 @@ def permission_check_wrapper(func):
def inner(*args, **kwargs):
skey = session.get(session_key)

def setup_flash(_message, _message_category):
if _message:
partial_flash = partial(flash, _message)
if _message_category:
partial_flash(_message_category)
else:
partial_flash()

if skey:
if _check_against_values_allowed(skey, values_allowed):
return func(*args, **kwargs)

if message:
flash(message, message_category)
setup_flash(message, message_category)

if fail_endpoint:

if endpoint_kwargs:
return redirect(url_for(fail_endpoint, **endpoint_kwargs))

Expand Down
12 changes: 12 additions & 0 deletions tests/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,18 @@ def test_security_login_fail(client):
assert b"Login failed" in response.data


def test_security_login_fail_with_message(client):
client.get('/tests/logout')
response = client.get('/tests/must-be-logged-in/bool/with-flash', follow_redirects=True)
assert b"message:Login needed" in response.data


def test_security_already_logged_in_pass_with_message(client):
client.get('/tests/login/bool')
response = client.get('/tests/must-be-logged-in/bool/with-flash', follow_redirects=True)
assert b"message:Already logged in" in response.data


def test_permission_list(client):
client.get('/tests/set-permission/list')
response = client.get('/tests/must-have-permissions/std')
Expand Down

0 comments on commit e475f12

Please sign in to comment.