Skip to content

Commit

Permalink
Support next with Proton Link (#1226)
Browse files Browse the repository at this point in the history
* Support next with Proton Link

* Add support for double next

* Fix bug on account relink
  • Loading branch information
cquintana92 authored Aug 11, 2022
1 parent 3a75686 commit 596dd0b
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 9 deletions.
3 changes: 1 addition & 2 deletions app/account_linking.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def process_link_case(
return link_user(link_request, current_user, partner)

# There is a SL user registered with the partner. Check if is the current one
if partner_user.id == current_user.id:
if partner_user.user_id == current_user.id:
# Update plan
set_plan_for_partner_user(partner_user, link_request.plan)
# It's the same user. No need to do anything
Expand All @@ -285,5 +285,4 @@ def process_link_case(
strategy="Link",
)
else:

return switch_already_linked_user(link_request, partner_user, current_user)
22 changes: 16 additions & 6 deletions app/auth/views/proton.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
Action,
)
from app.proton.utils import get_proton_partner
from app.utils import sanitize_next_url
from app.utils import sanitize_next_url, sanitize_scheme

_authorization_base_url = PROTON_BASE_URL + "/oauth/authorize"
_token_url = PROTON_BASE_URL + "/oauth/token"
Expand All @@ -34,6 +34,7 @@

SESSION_ACTION_KEY = "oauth_action"
SESSION_STATE_KEY = "oauth_state"
DEFAULT_SCHEME = "auth.simplelogin"


def get_api_key_for_user(user: User) -> str:
Expand Down Expand Up @@ -75,6 +76,12 @@ def proton_login():
elif "oauth_next" in session:
del session["oauth_next"]

scheme = sanitize_scheme(request.args.get("scheme"))
if scheme:
session["oauth_scheme"] = scheme
elif "oauth_scheme" in session:
del session["oauth_scheme"]

mode = request.args.get("mode", "session")
if mode == "apikey":
session["oauth_mode"] = "apikey"
Expand Down Expand Up @@ -146,6 +153,7 @@ def check_status_code(response: requests.Response) -> requests.Response:
handler = ProtonCallbackHandler(proton_client)
proton_partner = get_proton_partner()

next_url = session.get("oauth_next")
if action == Action.Login:
res = handler.handle_login(proton_partner)
elif action == Action.Link:
Expand All @@ -156,15 +164,17 @@ def check_status_code(response: requests.Response) -> requests.Response:
if res.flash_message is not None:
flash(res.flash_message, res.flash_category)

oauth_scheme = session.get("oauth_scheme")
if session.get("oauth_mode", "session") == "apikey":
apikey = get_api_key_for_user(res.user)
return redirect(f"auth.simplelogin://callback?apikey={apikey}")
scheme = oauth_scheme or DEFAULT_SCHEME
return redirect(f"{scheme}:///login_callback?apikey={apikey}")

if res.redirect_to_login:
return redirect(url_for("auth.login"))

if res.redirect:
return after_login(res.user, res.redirect, login_from_proton=True)
if next_url and next_url[0] == "/" and oauth_scheme:
next_url = f"{oauth_scheme}://{next_url}"

next_url = session.get("oauth_next")
return after_login(res.user, next_url, login_from_proton=True)
redirect_url = next_url or res.redirect
return after_login(res.user, redirect_url, login_from_proton=True)
4 changes: 3 additions & 1 deletion app/proton/proton_callback_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def handle_login(self, partner: Partner) -> ProtonCallbackResult:
)

def handle_link(
self, current_user: Optional[User], partner: Partner
self,
current_user: Optional[User],
partner: Partner,
) -> ProtonCallbackResult:
if current_user is None:
raise Exception("Cannot link account with current_user being None")
Expand Down
14 changes: 14 additions & 0 deletions app/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import secrets
import string
import time
Expand Down Expand Up @@ -88,6 +89,8 @@ def sanitize(url: Optional[str], allowed_domains: List[str]) -> Optional[str]:
else:
return None
if result.path and result.path[0] == "/" and not result.path.startswith("//"):
if result.query:
return f"{result.path}?{result.query}"
return result.path

return None
Expand All @@ -97,6 +100,17 @@ def sanitize_next_url(url: Optional[str]) -> Optional[str]:
return NextUrlSanitizer.sanitize(url, ALLOWED_REDIRECT_DOMAINS)


def sanitize_scheme(scheme: Optional[str]) -> Optional[str]:
if not scheme:
return None
if scheme in ["http", "https"]:
return None
scheme_regex = re.compile("^[a-z.]+$")
if scheme_regex.match(scheme):
return scheme
return None


def query2str(query):
"""Useful utility method to print out a SQLAlchemy query"""
return query.statement.compile(compile_kwargs={"literal_binds": True})
Expand Down

0 comments on commit 596dd0b

Please sign in to comment.