Skip to content

Commit

Permalink
OIDC: Reference group by FK, set workspace attribute to False (#974)
Browse files Browse the repository at this point in the history
  • Loading branch information
psrok1 authored Aug 20, 2024
1 parent 34fbcce commit b15ba09
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 20 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""OIDC group referenced by id instead of name + convert to non-workspace
Revision ID: 72a94f88d2b6
Revises: 6fc42e070495
Create Date: 2024-08-20 13:29:36.839985
"""
import logging

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "72a94f88d2b6"
down_revision = "6fc42e070495"
branch_labels = None
depends_on = None

group_helper = sa.Table(
"group",
sa.MetaData(),
sa.Column("id", sa.Integer()),
sa.Column("name", sa.String(32)),
sa.Column("private", sa.Boolean()),
sa.Column("default", sa.Boolean()),
sa.Column("workspace", sa.Boolean()),
)

provider_helper = sa.Table(
"openid_provider",
sa.MetaData(),
sa.Column("id", sa.Integer()),
sa.Column("name", sa.String(64)),
sa.Column("group_id", sa.Integer()),
)

logger = logging.getLogger("alembic")


def group_name_from_provider_name(provider_name):
return ("OpenID_" + provider_name)[:32]


def upgrade():
connection = op.get_bind()
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("openid_provider", sa.Column("group_id", sa.Integer(), nullable=True))
op.create_foreign_key(None, "openid_provider", "group", ["group_id"], ["id"])

# Migrate existing providers
for provider in connection.execute(provider_helper.select()):
group_name = group_name_from_provider_name(provider.name)
group = connection.execute(
group_helper.select().where(group_helper.c.name == group_name)
).first()
connection.execute(
group_helper.update()
.where(group_helper.c.name == group_name)
.values(workspace=False)
)
connection.execute(
provider_helper.update()
.where(provider_helper.c.id == provider.id)
.values(group_id=group.id)
)

# Set group_id as non-nullable
op.alter_column(
"openid_provider", "group_id", existing_type=sa.INTEGER(), nullable=False
)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, "openid_provider", type_="foreignkey")
op.drop_column("openid_provider", "group_id")
# ### end Alembic commands ###
18 changes: 9 additions & 9 deletions mwdb/model/oauth.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
from werkzeug.exceptions import NotFound

from mwdb.core.oauth import OpenIDClient
from mwdb.model import Group

from . import db

Expand All @@ -19,11 +16,17 @@ class OpenIDProvider(db.Model):
jwks_endpoint = db.Column(db.Text, nullable=True)
logout_endpoint = db.Column(db.Text, nullable=True)

group_id = db.Column(db.Integer, db.ForeignKey("group.id"), nullable=False)

identities = db.relationship(
"OpenIDUserIdentity",
back_populates="provider",
cascade="all, delete-orphan",
)
group = db.relationship(
"Group",
cascade="all, delete",
)

def get_oidc_client(self):
return OpenIDClient(
Expand All @@ -39,12 +42,9 @@ def get_oidc_client(self):
state=None,
)

def get_group(self):
group_name = ("OpenID_" + self.name)[:32]
group = db.session.query(Group).filter(Group.name == group_name).first()
if group is None:
raise NotFound("No such group")
return group
@property
def group_name(self):
return ("OpenID_" + self.name)[:32]


class OpenIDUserIdentity(db.Model):
Expand Down
23 changes: 12 additions & 11 deletions mwdb/resources/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,20 +121,22 @@ def post(self):
logout_endpoint=logout_endpoint,
)

group_name = ("OpenID_" + obj["name"])[:32]

group_name_obj = load_schema({"name": group_name}, GroupNameSchemaBase())
group_name_obj = load_schema(
{"name": provider.group_name}, GroupNameSchemaBase()
)

if db.session.query(
exists().where(Group.name == group_name_obj["name"])
).scalar():
raise Conflict("Group exists yet, choose another provider name")

group = Group(name=group_name_obj["name"], immutable=True)

group = Group(name=group_name_obj["name"], immutable=True, workspace=False)
db.session.add(group)
db.session.add(provider)
db.session.flush()
db.session.refresh(group)

provider.group_id = group.id
db.session.add(provider)
db.session.commit()
hooks.on_created_group(group)

Expand Down Expand Up @@ -301,15 +303,14 @@ def delete(self, provider_name):
.filter(OpenIDProvider.name == provider_name)
.first()
)
provider_group_name = provider.group_name
if not provider:
raise NotFound(f"Requested provider name '{provider_name}' not found")
group = provider.get_group()

db.session.delete(provider)
db.session.delete(group)
db.session.commit()

hooks.on_removed_group(("OpenID_" + provider_name)[:32])
hooks.on_removed_group(provider_group_name)
logger.info("Provider was deleted", extra={"provider": provider_name})
schema = OpenIDProviderSuccessResponseSchema()
return schema.dump({"name": provider_name})
Expand Down Expand Up @@ -429,7 +430,7 @@ def post(self, provider_name):
if not provider:
raise NotFound(f"Requested provider name '{provider_name}' not found")

group = provider.get_group()
group = provider.group

schema = OpenIDAuthorizeRequestSchema()
obj = loads_schema(request.get_data(as_text=True), schema)
Expand Down Expand Up @@ -564,7 +565,7 @@ def post(self, provider_name):
if not provider:
raise NotFound(f"Requested provider name '{provider_name}' not found")

group = provider.get_group()
group = provider.group

schema = OpenIDAuthorizeRequestSchema()
obj = loads_schema(request.get_data(as_text=True), schema)
Expand Down

0 comments on commit b15ba09

Please sign in to comment.