Skip to content

Commit

Permalink
OIDC: Plugin-customizable OpenIDProvider class (#982)
Browse files Browse the repository at this point in the history
  • Loading branch information
psrok1 authored Sep 25, 2024
1 parent 930c530 commit a65ec4b
Show file tree
Hide file tree
Showing 7 changed files with 229 additions and 127 deletions.
Empty file added mwdb/core/oauth/__init__.py
Empty file.
5 changes: 5 additions & 0 deletions mwdb/core/oauth.py → mwdb/core/oauth/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@


class OpenIDClient:
"""
Stateful client representing OpenID Connect session using
specified client and provider data
"""

supported_algorithms = ["HS256", "HS384", "HS512", "RS256", "RS384", "RS512"]

def __init__(
Expand Down
122 changes: 122 additions & 0 deletions mwdb/core/oauth/provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import hashlib
from typing import TYPE_CHECKING, Iterator

from authlib.oidc.core import UserInfo
from marshmallow import ValidationError
from sqlalchemy import exists

from mwdb.schema.user import UserLoginSchemaBase

from .client import OpenIDClient

if TYPE_CHECKING:
from mwdb.model import Group, User


class OpenIDProvider:
"""
OpenID Connect Identity Provider representation with generic handlers.
You can override these methods with your own implementation
that is specific for provider.
"""

scope = "openid profile email"

def __init__(
self,
name,
client_id,
client_secret,
authorization_endpoint,
token_endpoint,
userinfo_endpoint,
jwks_uri,
):
self.name = name
self.client = OpenIDClient(
client_id=client_id,
client_secret=client_secret,
grant_type="authorization_code",
response_type="code",
scope=self.scope,
authorization_endpoint=authorization_endpoint,
token_endpoint=token_endpoint,
userinfo_endpoint=userinfo_endpoint,
jwks_uri=jwks_uri,
state=None,
)

def get_group_name(self) -> str:
"""
Group name that is used for registering a new OpenID provider
"""
return ("OpenID_" + self.name)[:32]

def create_provider_group(self) -> "Group":
"""
Creates a Group model object for a new OpenID provider
"""
from mwdb.model import Group

group_name = self.get_group_name()
return Group(name=group_name, immutable=True, workspace=False)

def iter_user_name_variants(self, sub: bytes, userinfo: UserInfo) -> Iterator[str]:
"""
Yield username variants that are used when user registers using OpenID identity
Usernames are yielded starting from most-preferred
"""
login_claims = ["preferred_username", "nickname", "name"]

for claim in login_claims:
username = userinfo.get(claim)
if not username:
continue
yield username
# If no candidates in claims: try fallback login
sub_md5 = hashlib.md5(sub.encode("utf-8")).hexdigest()[:8]
yield f"{self.name}-{sub_md5}"

def get_user_email(self, sub: bytes, userinfo: UserInfo) -> str:
"""
User e-mail that is used when user registers using OpenID identity
"""
if "email" in userinfo.keys():
return userinfo["email"]
else:
return f"{sub}@mwdb.local"

def get_user_description(self, sub: bytes, userinfo: UserInfo) -> str:
"""
User description that is used when user registers using OpenID identity
"""
return "Registered via OpenID Connect protocol"

def create_user(self, sub: bytes, userinfo: UserInfo) -> "User":
"""
Creates a User model object for a new OpenID identity user
"""
from mwdb.model import Group, User, db

for username in self.iter_user_name_variants(sub, userinfo):
try:
UserLoginSchemaBase().load({"login": username})
except ValidationError:
continue
already_exists = db.session.query(
exists().where(Group.name == username)
).scalar()
if not already_exists:
break
else:
raise RuntimeError("Can't find any good username candidate for user")

user_email = self.get_user_email(sub, userinfo)
user_description = self.get_user_description(sub, userinfo)
return User.create(
username,
user_email,
user_description,
)
5 changes: 5 additions & 0 deletions mwdb/core/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

_plugin_handlers = []
loaded_plugins = {}
openid_provider_classes = {}


class PluginAppContext(object):
Expand All @@ -33,6 +34,10 @@ def register_converter(self, converter_name, converter):
def register_schema_spec(self, schema_name, schema):
api.spec.components.schema(schema_name, schema=schema)

def register_openid_provider_class(self, provider_name, provider_class):
global openid_provider_classes
openid_provider_classes[provider_name] = provider_class


def hook_handler_method(meth):
@functools.wraps(meth)
Expand Down
4 changes: 2 additions & 2 deletions mwdb/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def after_cursor_execute(conn, cursor, statement, parameters, context, executema
from .file import File # noqa: E402
from .group import Group, Member # noqa: E402
from .karton import KartonAnalysis, karton_object # noqa: E402
from .oauth import OpenIDProvider, OpenIDUserIdentity # noqa: E402
from .oauth import OpenIDProviderSettings, OpenIDUserIdentity # noqa: E402
from .object import Object, relation # noqa: E402
from .object_permission import ObjectPermission # noqa: E402
from .quick_query import QuickQuery # noqa: E402
Expand All @@ -74,7 +74,7 @@ def after_cursor_execute(conn, cursor, statement, parameters, context, executema
"AttributePermission",
"Object",
"ObjectPermission",
"OpenIDProvider",
"OpenIDProviderSettings",
"OpenIDUserIdentity",
"relation",
"QuickQuery",
Expand Down
28 changes: 15 additions & 13 deletions mwdb/model/oauth.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
from mwdb.core.oauth import OpenIDClient
from typing import Type

from mwdb.core.oauth.provider import OpenIDProvider

from . import db


class OpenIDProvider(db.Model):
def get_oidc_provider_class(provider_name: str) -> Type[OpenIDProvider]:
from mwdb.core.plugins import openid_provider_classes

return openid_provider_classes.get(provider_name, OpenIDProvider)


class OpenIDProviderSettings(db.Model):
__tablename__ = "openid_provider"

id = db.Column(db.Integer, primary_key=True, autoincrement=True)
Expand All @@ -28,24 +36,18 @@ class OpenIDProvider(db.Model):
cascade="all, delete",
)

def get_oidc_client(self):
return OpenIDClient(
def get_oidc_provider(self):
openid_provider_class = get_oidc_provider_class(self.name)
return openid_provider_class(
name=self.name,
client_id=self.client_id,
client_secret=self.client_secret,
scope="openid profile email",
grant_type="authorization_code",
response_type="code",
authorization_endpoint=self.authorization_endpoint,
token_endpoint=self.token_endpoint,
userinfo_endpoint=self.userinfo_endpoint,
jwks_uri=self.jwks_endpoint,
state=None,
)

@property
def group_name(self):
return ("OpenID_" + self.name)[:32]


class OpenIDUserIdentity(db.Model):
__tablename__ = "openid_identity"
Expand All @@ -63,5 +65,5 @@ class OpenIDUserIdentity(db.Model):

user = db.relationship("User", back_populates="openid_identities")
provider = db.relationship(
OpenIDProvider, back_populates="identities", lazy="selectin"
OpenIDProviderSettings, back_populates="identities", lazy="selectin"
)
Loading

0 comments on commit a65ec4b

Please sign in to comment.