diff --git a/server/src/endpoints/google_oauth.py b/server/src/endpoints/google_oauth.py index fe519a8..8c0ff3b 100644 --- a/server/src/endpoints/google_oauth.py +++ b/server/src/endpoints/google_oauth.py @@ -9,6 +9,7 @@ from starlette.config import Config from starlette.responses import RedirectResponse, HTMLResponse from starlette.middleware.sessions import SessionMiddleware +from starlette.datastructures import URL from src.core.database import get_db from src.crud.user import user_crud @@ -60,15 +61,24 @@ async def login(request: Request): ) callback_url = request.url_for("callback") - redirect_uri = ( - f"{request.scope['scheme']}://{callback_url.netloc}{callback_url.path}" - ) - print("redirect_uri", redirect_uri) + + # workaround to recover protocol for callback when called via https without a reverse proxy + prompton_env = request.query_params.get("prompton_env") + if prompton_env is None: + redirect_url = ( + f"{request.scope['scheme']}://{callback_url.netloc}{callback_url.path}" + ) + else: + prompton_env_url = URL(prompton_env) + redirect_url = ( + f"{prompton_env_url.scheme}://{callback_url.netloc}{callback_url.path}" + ) + request.session["original_query_params"] = urlencode( request.query_params.multi_items(), doseq=True ) - return await oauth.google.authorize_redirect(request, redirect_uri) + return await oauth.google.authorize_redirect(request, redirect_url) @app.get("/callback") @@ -81,21 +91,23 @@ async def callback(request: Request, db=Depends(get_db)): token = await oauth.google.authorize_access_token(request) user = token.get("userinfo") if user: + if not user.get("email_verified"): + raise OAuthSignInFailed("Google Email is not verified.") + request.session["user_email"] = user.get("email") request.session["user"] = user original_query_params = parse_qs( - request.session.get("original_query_params", "") + request.session.get("original_query_params", None) ) - logged_in_redirect_uri = original_query_params.get( - "logged_in_redirect_uri", ["/oauth/logged_in"] - )[0] - - del original_query_params["logged_in_redirect_uri"] - - if not user.get("email_verified"): - raise OAuthSignInFailed("Google Email is not verified.") + if "logged_in_redirect_uri" in original_query_params: + logged_in_redirect_uri = original_query_params.get( + "logged_in_redirect_uri", ["/oauth/logged_in"] + )[0] + del original_query_params["logged_in_redirect_uri"] + else: + logged_in_redirect_uri = "/oauth/logged_in" headers = {}