Skip to content

Commit

Permalink
Epic.auth compliance integration
Browse files Browse the repository at this point in the history
  • Loading branch information
dinesh-aot committed Aug 13, 2024
1 parent a3081ae commit d63c321
Show file tree
Hide file tree
Showing 9 changed files with 233 additions and 41 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""staff_user_unique_constraint
Revision ID: f9a414e34ea4
Revises: 8348b72b25df
Create Date: 2024-08-13 11:06:16.532337
"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = 'f9a414e34ea4'
down_revision = '8348b72b25df'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('staff_users', schema=None) as batch_op:
batch_op.alter_column('id',
existing_type=sa.INTEGER(),
comment='The unique identifier of the staff user.',
existing_nullable=False,
autoincrement=True)
batch_op.alter_column('first_name',
existing_type=sa.VARCHAR(length=50),
comment='The firstname of the staff user.',
existing_nullable=True)
batch_op.alter_column('last_name',
existing_type=sa.VARCHAR(length=50),
comment='The lastname of the staff user.',
existing_nullable=True)
batch_op.alter_column('position_id',
existing_type=sa.INTEGER(),
comment='The unique identifier of the position of the staff user.',
existing_nullable=False)
batch_op.alter_column('deputy_director_id',
existing_type=sa.INTEGER(),
comment='The unique identifier of the deputy director.',
existing_nullable=True)
batch_op.alter_column('supervisor_id',
existing_type=sa.INTEGER(),
comment='The unique identifier of the supervisor.',
existing_nullable=True)
batch_op.alter_column('auth_user_guid',
existing_type=sa.VARCHAR(length=100),
comment='The unique identifier from the identity provider.',
existing_nullable=True)
batch_op.drop_index('ix_staff_users_auth_user_guid')
batch_op.create_index(batch_op.f('ix_staff_users_auth_user_guid'), ['auth_user_guid'], unique=False)
batch_op.create_index('uq_auth_user_guid_is_deleted_false', ['auth_user_guid'], unique=True, postgresql_where=sa.text('is_deleted = false'))

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('staff_users', schema=None) as batch_op:
batch_op.drop_index('uq_auth_user_guid_is_deleted_false', postgresql_where=sa.text('is_deleted = false'))
batch_op.drop_index(batch_op.f('ix_staff_users_auth_user_guid'))
batch_op.create_index('ix_staff_users_auth_user_guid', ['auth_user_guid'], unique=True)
batch_op.alter_column('auth_user_guid',
existing_type=sa.VARCHAR(length=100),
comment=None,
existing_comment='The unique identifier from the identity provider.',
existing_nullable=True)
batch_op.alter_column('supervisor_id',
existing_type=sa.INTEGER(),
comment=None,
existing_comment='The unique identifier of the supervisor.',
existing_nullable=True)
batch_op.alter_column('deputy_director_id',
existing_type=sa.INTEGER(),
comment=None,
existing_comment='The unique identifier of the deputy director.',
existing_nullable=True)
batch_op.alter_column('position_id',
existing_type=sa.INTEGER(),
comment=None,
existing_comment='The unique identifier of the position of the staff user.',
existing_nullable=False)
batch_op.alter_column('last_name',
existing_type=sa.VARCHAR(length=50),
comment=None,
existing_comment='The lastname of the staff user.',
existing_nullable=True)
batch_op.alter_column('first_name',
existing_type=sa.VARCHAR(length=50),
comment=None,
existing_comment='The firstname of the staff user.',
existing_nullable=True)
batch_op.alter_column('id',
existing_type=sa.INTEGER(),
comment=None,
existing_comment='The unique identifier of the staff user.',
existing_nullable=False,
autoincrement=True)

# ### end Alembic commands ###
5 changes: 5 additions & 0 deletions compliance-api/src/compliance_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ def create_app(run_mode=os.getenv("FLASK_ENV", "development")):
@app.before_request
def set_origin():
g.origin_url = request.environ.get("HTTP_ORIGIN", "localhost")
auth_header = request.headers.get('Authorization')
if auth_header and auth_header.startswith('Bearer '):
g.access_token = auth_header.split(' ')[1]
else:
g.access_token = None

build_cache(app)

Expand Down
26 changes: 23 additions & 3 deletions compliance-api/src/compliance_api/models/staff_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import enum
from typing import Optional

from sqlalchemy import Column, ForeignKey, Integer, String
from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String
from sqlalchemy.orm import relationship

from .base_model import BaseModel
Expand Down Expand Up @@ -63,10 +63,22 @@ class StaffUser(BaseModel):
auth_user_guid = Column(
String(100),
index=True,
unique=True,
comment="The unique identifier from the identity provider.",
)
position = relationship("Position", foreign_keys=[position_id], lazy="select")
deputy_director = relationship(
"StaffUser", foreign_keys=[deputy_director_id], lazy="select"
)
supervisor = relationship("StaffUser", foreign_keys=[supervisor_id], lazy="select")
is_deleted = Column(Boolean, default=False, server_default="f", nullable=False)
__table_args__ = (
Index(
"uq_auth_user_guid_is_deleted_false",
"auth_user_guid",
unique=True,
postgresql_where=(is_deleted is False),
),
)

@classmethod
def create_user(cls, user_data, session=None) -> StaffUser:
Expand All @@ -84,11 +96,19 @@ def update_user(cls, user_id, user_dict, session=None) -> Optional[StaffUser]:
"""Update user."""
query = StaffUser.query.filter_by(id=user_id)
user: StaffUser = query.first()
if not user:
if not user or user.is_deleted:
return None
query.update(user_dict)
if session:
session.flush()
else:
cls.session.commit()
return user

@classmethod
def get_staff_user_by_auth_guid(cls, auth_guid: str) -> StaffUser:
"""Retrieve the staff user by auth_guid."""
staff_user = StaffUser.query.filter_by(
auth_user_guid=auth_guid, is_deleted=False
).first()
return staff_user
2 changes: 1 addition & 1 deletion compliance-api/src/compliance_api/resources/staff_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def get():
@auth.require
@ApiHelper.swagger_decorators(API, endpoint_description="Create a user")
@API.expect(user_request_model)
@API.response(code=201, model=user_request_model, description="UserCreated")
@API.response(code=201, model=user_list_model, description="UserCreated")
@API.response(400, "Bad Request")
def post():
"""Create a user."""
Expand Down
7 changes: 6 additions & 1 deletion compliance-api/src/compliance_api/schemas/base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Super class to handle all operations related to base schema."""

from flask import json
from marshmallow import Schema, fields, post_dump

from compliance_api.exceptions import BadRequestError
from compliance_api.models.db import ma


Expand All @@ -30,6 +31,10 @@ def __init__(self, *args, **kwargs):
self.exclude = getattr(self.Meta, "exclude", ()) + ("versions",)
super().__init__(*args, **kwargs)

def handle_error(self, error, data, **kwargs):
"""Log and raise our custom exception when validation fails."""
raise BadRequestError(json.dumps(error.messages))

class Meta: # pylint: disable=too-few-public-methods
"""Meta class to declare any class attributes."""

Expand Down
53 changes: 43 additions & 10 deletions compliance-api/src/compliance_api/schemas/staff_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,55 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Staff User Schema."""
from marshmallow import EXCLUDE, Schema, fields
from marshmallow import EXCLUDE, fields, post_dump, post_load
from marshmallow_enum import EnumField

from compliance_api.models.staff_user import PermissionEnum, StaffUser

from .base_schema import AutoSchemaBase
from .base_schema import AutoSchemaBase, BaseSchema
from .common import KeyValueSchema


class StaffUserSchema(AutoSchemaBase): # pylint: disable=too-many-ancestors
"""Staff User schema."""
class StaffUserSchemaSkeleton(AutoSchemaBase): # pylint: disable=too-many-ancestors
"""Basic schema for staff user."""

class Meta(AutoSchemaBase.Meta): # pylint: disable=too-few-public-methods
"""Exclude unknown fields in the deserialized output."""
"""Meta."""

unknown = EXCLUDE
model = StaffUser
include_fk = True

position = fields.Nested(KeyValueSchema, dump_only=True)
permission = fields.Str(
metadata={"description": "The permission level of the user in the app"}
)
full_name = fields.Method("get_full_name")

def get_full_name(self, obj): # pylint: disable=no-self-use
"""Derive fullname."""
return f"{obj.first_name} {obj.last_name}"
return f"{obj.first_name} {obj.last_name}" if obj else ""

@post_dump
def nullify_nested(
self, data, **kwargs
): # pylint: disable=no-self-use, unused-argument
"""Make nested objects null if the referenced ID is null."""
if data.get("deputy_director_id") is None:
data["deputy_director"] = None
if data.get("supervisor_id") is None:
data["supervisor"] = None
return data

class StaffUserCreateSchema(Schema):

class StaffUserSchema(StaffUserSchemaSkeleton): # pylint: disable=too-many-ancestors
"""Staff User schema."""

deputy_director = fields.Nested(StaffUserSchemaSkeleton, dump_only=True)
supervisor = fields.Nested(StaffUserSchemaSkeleton, dump_only=True)


class StaffUserCreateSchema(BaseSchema):
"""User create Schema."""

class Meta: # pylint: disable=too-few-public-methods
Expand All @@ -54,12 +75,14 @@ class Meta: # pylint: disable=too-few-public-methods
required=True,
)
deputy_director_id = fields.Int(
metadata={"description": "The unique identifier of the deputy director."}
metadata={"description": "The unique identifier of the deputy director."},
allow_none=True,
)
supervisor_id = fields.Int(
metadata={"description": "The unique identifier of the supervisor."}
metadata={"description": "The unique identifier of the supervisor."},
allow_none=True,
)
auth_user_id = fields.Str(
auth_user_guid = fields.Str(
metadata={"description": "The unique identifier from the identity provider."},
required=True,
)
Expand All @@ -69,3 +92,13 @@ class Meta: # pylint: disable=too-few-public-methods
by_value=True,
required=True,
)

@post_load
def extract_permission_value(
self, data, **kwargs
): # pylint: disable=no-self-use, unused-argument
"""Extract the value of the permission enum."""
permission_enum = data.get("permission")
if permission_enum:
data["permission"] = permission_enum.value
return data
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,30 @@ class AuthService:
"""Handle service request for epic.authorize."""

@staticmethod
def get_epic_user_by_id(auth_user_id: str):
def get_epic_user_by_guid(auth_user_guid: str):
"""Return the user representation from epic.authorize."""
auth_user_response_json = _request_auth_service(f"/users/{auth_user_id}")
return AuthUserSchema().load(auth_user_response_json)
auth_user_response = _request_auth_service(f"users/{auth_user_guid}")
if auth_user_response.status_code != 200:
raise BusinessError(
f"Error finding user with ID {auth_user_guid} from auth server"
)
return AuthUserSchema().load(auth_user_response.json())

@staticmethod
def update_user_group(auth_user_id: str, payload: dict):
def update_user_group(auth_user_guid: str, payload: dict):
"""Update the group of the user in the identity server."""
update_group_response = _request_auth_service(
f"/users/{auth_user_id}/group", HttpMethod.PATCH, payload
f"users/{auth_user_guid}/groups", HttpMethod.PUT, payload
)
if update_group_response.status_code != 204:
raise BusinessError(
f"Update group in the auth server failed for user : {auth_user_guid}"
)
return update_group_response


def _request_auth_service(
self, relative_url, http_method: HttpMethod = HttpMethod.GET, data=None
relative_url, http_method: HttpMethod = HttpMethod.GET, data=None
):
"""REST Api call to authorize service."""
token = getattr(g, "access_token", None)
Expand All @@ -41,13 +49,17 @@ def _request_auth_service(
"Authorization": f"Bearer {token}",
}

url = f"{auth_base_url}/{relative_url}"
url = f"{auth_base_url}/api/{relative_url}"

if http_method == HttpMethod.GET:
response = requests.get(url, headers=headers, timeout=API_REQUEST_TIMEOUT)
elif http_method == HttpMethod.PUT:
response = requests.put(
url, headers=headers, data=data, timeout=API_REQUEST_TIMEOUT
url, headers=headers, json=data, timeout=API_REQUEST_TIMEOUT
)
elif http_method == HttpMethod.PATCH:
response = requests.patch(
url, headers=headers, json=data, timeout=API_REQUEST_TIMEOUT
)
elif http_method == HttpMethod.DELETE:
response = requests.delete(url, headers=headers, timeout=API_REQUEST_TIMEOUT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Schema for user representation from epic.authorize."""
from marshmallow import Schema, fields
from marshmallow import EXCLUDE, Schema, fields


class AuthUserSchema(Schema):
"""Schema for the auth user."""

first_name = fields.Str(metadata={"description": "The first name of the user"}, required=True)
last_name = fields.Str(metadata={"description": "The lastname of the user"}, required=True)
id = fields.Str(metadata={"description": "The unique id of the user"}, required=True)
class Meta:
"""Meta for AuthUserSchema."""

unknown = EXCLUDE

first_name = fields.Str(
metadata={"description": "The first name of the user"}, required=True
)
last_name = fields.Str(
metadata={"description": "The lastname of the user"}, required=True
)
id = fields.Str(
data_key="username",
metadata={"description": "The unique id of the user"},
required=True,
)
Loading

0 comments on commit d63c321

Please sign in to comment.