Skip to content

Commit

Permalink
add header auth (#19)
Browse files Browse the repository at this point in the history
* add header auth

* first functioning version

* break after success validate & some style issue
  • Loading branch information
LeoQuote authored Jun 18, 2020
1 parent 136b5a6 commit fac3842
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 3 deletions.
3 changes: 2 additions & 1 deletion helpdesk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from starlette.middleware.gzip import GZipMiddleware
from starlette.middleware.authentication import AuthenticationMiddleware

from helpdesk.libs.auth import SessionAuthBackend
from helpdesk.libs.auth import SessionAuthBackend, BearerAuthMiddleware
from helpdesk.libs.proxy import ProxyHeadersMiddleware
from helpdesk.config import DEBUG, SESSION_SECRET_KEY, SESSION_TTL, SENTRY_DSN, TRUSTED_HOSTS
from helpdesk.views.api import bp as api_bp
Expand All @@ -34,6 +34,7 @@ def create_app():
middleware = [
Middleware(ProxyHeadersMiddleware, trusted_hosts=TRUSTED_HOSTS),
Middleware(SessionMiddleware, secret_key=SESSION_SECRET_KEY, max_age=SESSION_TTL),
Middleware(BearerAuthMiddleware),
Middleware(AuthenticationMiddleware, backend=SessionAuthBackend()),
Middleware(GZipMiddleware),
Middleware(SentryMiddleware),
Expand Down
102 changes: 102 additions & 0 deletions helpdesk/libs/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,82 @@

import logging

import requests
from starlette.authentication import (
AuthenticationBackend,
AuthCredentials,
UnauthenticatedUser,
)
from starlette.middleware.base import BaseHTTPMiddleware
from authlib.jose import jwt
from authlib.jose.errors import JoseError, ExpiredTokenError

from helpdesk.config import OPENID_PRIVIDERS, oauth_username_func
from helpdesk.libs.sentry import report
from helpdesk.models.user import User

logger = logging.getLogger(__name__)


# load auth providers
class Validator:
def __init__(self, metadata_url=None, client_id=None, *args, **kwargs):
self.metadata_url = metadata_url
self.client_id = client_id
if not self.client_id:
raise ValueError('Init validator failed, client_id not set')
self.client_kwargs = kwargs.get('client_kwargs')
self.fetch_jwk()

def fetch_jwk(self):
# Fetch the public key for validating Bearer token
server_metadata = self.get(self.metadata_url)
self.jwk = self.get(server_metadata['jwks_uri'])

def get(self, *args, **kwargs):
if self.client_kwargs:
r = requests.get(*args, **kwargs, **self.client_kwargs)
else:
r = requests.get(*args, **kwargs)
r.raise_for_status()
return r.json()

def valide_token(self, token: str):
"""validate token string, return a parsed token if valid, return None if not valid
:return tuple (is_valid -> bool, id_token or None)
"""
try:
if "https://accounts.google.com" in self.metadata_url:
# google's certs would change from time to time, let's refetch it before every try
self.fetch_jwk()
token = jwt.decode(token, self.jwk)
except ValueError as e:
if str(e) == 'Invalid JWK kid':
logger.info(
'This token cannot be decoded with current provider, will try another provider if available.')
return None, None
else:
raise e

try:
token.validate()
return True, token
except ExpiredTokenError as e:
logger.info('Auth header expired, %s', e)
return True, None
except JoseError as e:
logger.debug('Jose error: %s', e)
report()
return None, None


registed_validator = {}

for provider, info in OPENID_PRIVIDERS.items():
client = Validator(metadata_url=info['server_metadata_url'], **info)
registed_validator[provider] = client


# ref: https://www.starlette.io/authentication/
class SessionAuthBackend(AuthenticationBackend):
async def authenticate(self, request):
Expand All @@ -27,5 +94,40 @@ async def authenticate(self, request):
return AuthCredentials([]), UnauthenticatedUser()


class BearerAuthMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
authheader = request.headers.get("Authorization")
if authheader and authheader.lower().startswith("bearer "):
_, token_str = authheader.split(" ", 1)
if token_str:
for validator_name, validator in registed_validator.items():
logger.info("Trying to validate token with %s", validator_name)
valid, id_token = validator.valide_token(token_str)
if not id_token and not valid:
# not valid in this provider, try next
continue
if not id_token and valid:
# valid in this provider, but expired, stop trying other provider
break
# check aud and iss
aud = id_token.get('aud')
if id_token.get('azp') != validator.client_id and (not aud or validator.client_id not in aud):
logger.info('Token is valid, not expired, but not belonged to this client')
break
username = oauth_username_func(id_token)
email = id_token.get('email', '')
roles = []
access = id_token.get('resource_access', {})
for rs in access.values():
roles.extend(rs.get('roles', []))

user = User(username, email, roles, id_token.get('picture', ''))

request.session['user'] = user.to_json()
break
response = await call_next(request)
return response


def unauth(request):
return request.session.pop('user', None)
2 changes: 0 additions & 2 deletions helpdesk/views/auth/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ async def callback(request):

username = oauth_username_func(id_token)
email = id_token['email']
if not User.validate_email(email):
return HTMLResponse("invalid email", 403)

roles = []
access = id_token.get('resource_access', {})
Expand Down

0 comments on commit fac3842

Please sign in to comment.