diff --git a/app/account_linking.py b/app/account_linking.py index 781d638c9..88ef1fbda 100644 --- a/app/account_linking.py +++ b/app/account_linking.py @@ -276,7 +276,7 @@ def process_link_case( return link_user(link_request, current_user, partner) # There is a SL user registered with the partner. Check if is the current one - if partner_user.id == current_user.id: + if partner_user.user_id == current_user.id: # Update plan set_plan_for_partner_user(partner_user, link_request.plan) # It's the same user. No need to do anything @@ -285,5 +285,4 @@ def process_link_case( strategy="Link", ) else: - return switch_already_linked_user(link_request, partner_user, current_user) diff --git a/app/auth/views/proton.py b/app/auth/views/proton.py index 8e3551b41..52fa148a6 100644 --- a/app/auth/views/proton.py +++ b/app/auth/views/proton.py @@ -23,7 +23,7 @@ Action, ) from app.proton.utils import get_proton_partner -from app.utils import sanitize_next_url +from app.utils import sanitize_next_url, sanitize_scheme _authorization_base_url = PROTON_BASE_URL + "/oauth/authorize" _token_url = PROTON_BASE_URL + "/oauth/token" @@ -34,6 +34,7 @@ SESSION_ACTION_KEY = "oauth_action" SESSION_STATE_KEY = "oauth_state" +DEFAULT_SCHEME = "auth.simplelogin" def get_api_key_for_user(user: User) -> str: @@ -75,6 +76,12 @@ def proton_login(): elif "oauth_next" in session: del session["oauth_next"] + scheme = sanitize_scheme(request.args.get("scheme")) + if scheme: + session["oauth_scheme"] = scheme + elif "oauth_scheme" in session: + del session["oauth_scheme"] + mode = request.args.get("mode", "session") if mode == "apikey": session["oauth_mode"] = "apikey" @@ -146,6 +153,7 @@ def check_status_code(response: requests.Response) -> requests.Response: handler = ProtonCallbackHandler(proton_client) proton_partner = get_proton_partner() + next_url = session.get("oauth_next") if action == Action.Login: res = handler.handle_login(proton_partner) elif action == Action.Link: @@ -156,15 +164,17 @@ def check_status_code(response: requests.Response) -> requests.Response: if res.flash_message is not None: flash(res.flash_message, res.flash_category) + oauth_scheme = session.get("oauth_scheme") if session.get("oauth_mode", "session") == "apikey": apikey = get_api_key_for_user(res.user) - return redirect(f"auth.simplelogin://callback?apikey={apikey}") + scheme = oauth_scheme or DEFAULT_SCHEME + return redirect(f"{scheme}:///login_callback?apikey={apikey}") if res.redirect_to_login: return redirect(url_for("auth.login")) - if res.redirect: - return after_login(res.user, res.redirect, login_from_proton=True) + if next_url and next_url[0] == "/" and oauth_scheme: + next_url = f"{oauth_scheme}://{next_url}" - next_url = session.get("oauth_next") - return after_login(res.user, next_url, login_from_proton=True) + redirect_url = next_url or res.redirect + return after_login(res.user, redirect_url, login_from_proton=True) diff --git a/app/proton/proton_callback_handler.py b/app/proton/proton_callback_handler.py index dd504f6db..53c807631 100644 --- a/app/proton/proton_callback_handler.py +++ b/app/proton/proton_callback_handler.py @@ -64,7 +64,9 @@ def handle_login(self, partner: Partner) -> ProtonCallbackResult: ) def handle_link( - self, current_user: Optional[User], partner: Partner + self, + current_user: Optional[User], + partner: Partner, ) -> ProtonCallbackResult: if current_user is None: raise Exception("Cannot link account with current_user being None") diff --git a/app/utils.py b/app/utils.py index ba5c8657e..4bf954422 100644 --- a/app/utils.py +++ b/app/utils.py @@ -1,3 +1,4 @@ +import re import secrets import string import time @@ -88,6 +89,8 @@ def sanitize(url: Optional[str], allowed_domains: List[str]) -> Optional[str]: else: return None if result.path and result.path[0] == "/" and not result.path.startswith("//"): + if result.query: + return f"{result.path}?{result.query}" return result.path return None @@ -97,6 +100,17 @@ def sanitize_next_url(url: Optional[str]) -> Optional[str]: return NextUrlSanitizer.sanitize(url, ALLOWED_REDIRECT_DOMAINS) +def sanitize_scheme(scheme: Optional[str]) -> Optional[str]: + if not scheme: + return None + if scheme in ["http", "https"]: + return None + scheme_regex = re.compile("^[a-z.]+$") + if scheme_regex.match(scheme): + return scheme + return None + + def query2str(query): """Useful utility method to print out a SQLAlchemy query""" return query.statement.compile(compile_kwargs={"literal_binds": True})