diff --git a/app/blueprints/tests/routes/security.py b/app/blueprints/tests/routes/security.py index 067df472..10fb3f85 100644 --- a/app/blueprints/tests/routes/security.py +++ b/app/blueprints/tests/routes/security.py @@ -40,12 +40,24 @@ def set_permissions_value(): return render_template(bp.tmpl("security.html"), permission="admin") +@bp.route("/already-logged-in/bool/with-flash", methods=["GET"]) +@login_check("logged_in", True, pass_endpoint="tests.already_logged_in", message="Already logged in") +def already_logged_in_bool(): + return render_template(bp.tmpl("security.html"), logged_in_on=session.get("logged_in")) + + @bp.route("/must-be-logged-in/bool", methods=["GET"]) @login_check("logged_in", True, "tests.login_failed") def must_be_logged_in_bool(): return render_template(bp.tmpl("security.html"), logged_in_on=True) +@bp.route("/must-be-logged-in/bool/with-flash", methods=["GET"]) +@login_check("logged_in", True, "tests.login_failed", message="Login needed") +def must_be_logged_in_bool_with_flash(): + return render_template(bp.tmpl("security.html"), logged_in_on=True) + + @bp.route("/must-be-logged-in/str", methods=["GET"]) @login_check("logged_in", "li", "tests.login_failed") def must_be_logged_in_str(): @@ -78,7 +90,12 @@ def permission_check_adv(): @bp.route("/login-failed", methods=["GET"]) def login_failed(): - return "Login failed" + return render_template(bp.tmpl("login_failed.html")) + + +@bp.route("/already-logged-in", methods=["GET"]) +def already_logged_in(): + return render_template(bp.tmpl("already_logged_in.html")) @bp.route("/permission-failed", methods=["GET"]) diff --git a/app/blueprints/tests/templates/tests/already_logged_in.html b/app/blueprints/tests/templates/tests/already_logged_in.html new file mode 100644 index 00000000..ea9fab13 --- /dev/null +++ b/app/blueprints/tests/templates/tests/already_logged_in.html @@ -0,0 +1,9 @@ +

Already logged in

+ +{% with messages = get_flashed_messages(with_categories=true) %} + {% if messages %} + {% for category, message in messages %} + {{ category }}:{{ message }}, + {% endfor %} + {% endif %} +{% endwith %} \ No newline at end of file diff --git a/app/blueprints/tests/templates/tests/login_failed.html b/app/blueprints/tests/templates/tests/login_failed.html new file mode 100644 index 00000000..9977d0fe --- /dev/null +++ b/app/blueprints/tests/templates/tests/login_failed.html @@ -0,0 +1,9 @@ +

Login failed

+ +{% with messages = get_flashed_messages(with_categories=true) %} + {% if messages %} + {% for category, message in messages %} + {{ category }}:{{ message }}, + {% endfor %} + {% endif %} +{% endwith %} \ No newline at end of file diff --git a/app/blueprints/tests/templates/tests/security.html b/app/blueprints/tests/templates/tests/security.html index 81a6bfd0..cc443ed4 100644 --- a/app/blueprints/tests/templates/tests/security.html +++ b/app/blueprints/tests/templates/tests/security.html @@ -1,5 +1,4 @@

Security

-

Logged_in {{ session.get('logged_in') }}

permissions {{ session.get('permissions') }}

diff --git a/src/flask_imp/security.py b/src/flask_imp/security.py index 47d69353..2d4c5a60 100644 --- a/src/flask_imp/security.py +++ b/src/flask_imp/security.py @@ -1,5 +1,6 @@ import typing as t from functools import wraps +from functools import partial from flask import flash, abort from flask import redirect @@ -66,12 +67,23 @@ def login_page(): """ def login_check_wrapper(func): + @wraps(func) def inner(*args, **kwargs): skey = session.get(session_key) + def setup_flash(_message, _message_category): + if _message: + partial_flash = partial(flash, _message) + if _message_category: + partial_flash(_message_category) + else: + partial_flash() + if skey is None: if fail_endpoint: + setup_flash(message, message_category) + if endpoint_kwargs: return redirect(url_for(fail_endpoint, **endpoint_kwargs)) @@ -82,8 +94,7 @@ def inner(*args, **kwargs): if skey is not None: if _check_against_values_allowed(skey, values_allowed): if pass_endpoint: - if message: - flash(message, message_category) + setup_flash(message, message_category) if endpoint_kwargs: return redirect(url_for(pass_endpoint, **endpoint_kwargs)) @@ -93,6 +104,8 @@ def inner(*args, **kwargs): return func(*args, **kwargs) if fail_endpoint: + setup_flash(message, message_category) + if endpoint_kwargs: return redirect(url_for(fail_endpoint, **endpoint_kwargs)) @@ -141,14 +154,22 @@ def permission_check_wrapper(func): def inner(*args, **kwargs): skey = session.get(session_key) + def setup_flash(_message, _message_category): + if _message: + partial_flash = partial(flash, _message) + if _message_category: + partial_flash(_message_category) + else: + partial_flash() + if skey: if _check_against_values_allowed(skey, values_allowed): return func(*args, **kwargs) - if message: - flash(message, message_category) + setup_flash(message, message_category) if fail_endpoint: + if endpoint_kwargs: return redirect(url_for(fail_endpoint, **endpoint_kwargs)) diff --git a/tests/test_group.py b/tests/test_group.py index 6d919649..23f263dc 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -92,6 +92,18 @@ def test_security_login_fail(client): assert b"Login failed" in response.data +def test_security_login_fail_with_message(client): + client.get('/tests/logout') + response = client.get('/tests/must-be-logged-in/bool/with-flash', follow_redirects=True) + assert b"message:Login needed" in response.data + + +def test_security_already_logged_in_pass_with_message(client): + client.get('/tests/login/bool') + response = client.get('/tests/must-be-logged-in/bool/with-flash', follow_redirects=True) + assert b"message:Already logged in" in response.data + + def test_permission_list(client): client.get('/tests/set-permission/list') response = client.get('/tests/must-have-permissions/std')