diff --git a/compliance-api/src/compliance_api/models/base_model.py b/compliance-api/src/compliance_api/models/base_model.py index b96c804f..fbfb5fa0 100644 --- a/compliance-api/src/compliance_api/models/base_model.py +++ b/compliance-api/src/compliance_api/models/base_model.py @@ -14,7 +14,7 @@ """Super class to handle all operations related to base model.""" from datetime import datetime -from sqlalchemy import Boolean, Column, DateTime, String +from sqlalchemy import Boolean, Column, DateTime, String, asc from .db import db @@ -34,9 +34,28 @@ class BaseModel(db.Model): is_deleted = Column(Boolean, default=False, server_default="f", nullable=False) @classmethod - def get_all(cls): + def get_all(cls, default_filters=True): """Fetch list of users by access type.""" - return cls.query.all() + query = {} + if default_filters and hasattr(cls, "is_active"): + query["is_active"] = True + if hasattr(cls, "is_deleted"): + query["is_deleted"] = False + rows = cls.query.filter_by(**query).all() # pylint: disable=no-member + return rows + + @classmethod + def get_by_params(cls, params: dict, default_filters=True): + """Return based on the params.""" + query = {} + for key, value in params.items(): + query[key] = value + if default_filters and hasattr(cls, "is_active"): + query["is_active"] = True + if hasattr(cls, "is_deleted"): + query["is_deleted"] = False + rows = cls.query.filter_by(**query).order_by(asc("id")).all() + return rows @classmethod def find_by_id(cls, identifier: int): diff --git a/compliance-api/src/compliance_api/models/staff_user.py b/compliance-api/src/compliance_api/models/staff_user.py index 68d58106..4fe2cea8 100644 --- a/compliance-api/src/compliance_api/models/staff_user.py +++ b/compliance-api/src/compliance_api/models/staff_user.py @@ -9,7 +9,7 @@ from typing import Optional from sqlalchemy import Column, ForeignKey, Integer, String -from sqlalchemy.orm import column_property, relationship +from sqlalchemy.orm import relationship from .base_model import BaseModel @@ -34,26 +34,38 @@ class StaffUser(BaseModel): __tablename__ = "staff_users" - id = Column(Integer, primary_key=True, autoincrement=True) - first_name = Column(String(50)) - last_name = Column(String(50)) - full_name = column_property(first_name + " " + last_name) + id = Column( + Integer, + primary_key=True, + autoincrement=True, + comment="The unique identifier of the staff user.", + ) + first_name = Column(String(50), comment="The firstname of the staff user.") + last_name = Column(String(50), comment="The lastname of the staff user.") position_id = Column( Integer, ForeignKey("positions.id", name="staff_users_position_id_fkey"), nullable=False, + comment="The unique identifier of the position of the staff user.", ) deputy_director_id = Column( Integer, ForeignKey("staff_users.id", name="staff_users_deputy_director_id_fkey"), nullable=True, + comment="The unique identifier of the deputy director.", ) supervisor_id = Column( Integer, ForeignKey("staff_users.id", name="staff_users_supervisor_id_fkey"), nullable=True, + comment="The unique identifier of the supervisor.", + ) + auth_user_guid = Column( + String(100), + index=True, + unique=True, + comment="The unique identifier from the identity provider.", ) - auth_user_guid = Column(String(100), index=True, unique=True) position = relationship("Position", foreign_keys=[position_id], lazy="select") @classmethod diff --git a/compliance-api/src/compliance_api/schemas/base_schema.py b/compliance-api/src/compliance_api/schemas/base_schema.py index 213ed38d..1f62b403 100644 --- a/compliance-api/src/compliance_api/schemas/base_schema.py +++ b/compliance-api/src/compliance_api/schemas/base_schema.py @@ -15,29 +15,40 @@ from marshmallow import Schema, fields, post_dump +from compliance_api.models.db import ma + class BaseSchema(Schema): # pylint: disable=too-many-ancestors, too-few-public-methods """Base Schema.""" def __init__(self, *args, **kwargs): """Excludes versions. Otherwise database will query _versions table.""" - if hasattr(self.opts.model, 'versions') and (len(self.opts.fields) == 0): - self.opts.exclude += ('versions',) + meta = getattr(self, "Meta", None) + if ( + meta and hasattr(meta, "model") and hasattr(meta["model"], "versions") and not self.fields + ): + self.exclude = getattr(self.Meta, "exclude", ()) + ("versions",) super().__init__(*args, **kwargs) class Meta: # pylint: disable=too-few-public-methods """Meta class to declare any class attributes.""" - datetimeformat = '%Y-%m-%dT%H:%M:%S+00:00' # Default output date format. + datetimeformat = "%Y-%m-%dT%H:%M:%S+00:00" # Default output date format. created_by = fields.Function( - lambda obj: f'{obj.created_by.firstname} {obj.created_by.lastname}' if getattr(obj, 'created_by', - None) else None + lambda obj: ( + f"{obj.created_by.firstname} {obj.created_by.lastname}" + if getattr(obj, "created_by", None) + else None + ) ) updated_by = fields.Function( - lambda obj: f'{obj.updated_by.firstname} {obj.updated_by.lastname}' if getattr(obj, 'updated_by', - None) else None + lambda obj: ( + f"{obj.updated_by.firstname} {obj.updated_by.lastname}" + if getattr(obj, "updated_by", None) + else None + ) ) @post_dump(pass_many=True) @@ -45,16 +56,40 @@ def _remove_empty(self, data, many): # pylint: disable=no-self-use """Remove all empty values and versions from the dumped dict.""" if not many: for key in list(data): - if key == 'versions': + if key == "versions": data.pop(key) - return { - key: value for key, value in data.items() - if value is not None - } + return {key: value for key, value in data.items() if value is not None} for item in data: for key in list(item): - if (key == 'versions') or (item[key] is None): + if (key == "versions") or (item[key] is None): item.pop(key) return data + + +class AutoSchemaBase(ma.SQLAlchemyAutoSchema): # pylint: disable=too-many-ancestors + """Representation of a base SQL alchemy auto schema with basic functions.""" + + class Meta: # pylint: disable=too-few-public-methods + """Meta information applicable to all schemas.""" + + model = None + exclude = ( + "created_date", + "created_by", + "updated_date", + "updated_by", + "is_deleted", + ) + # abstract=True + + def on_bind_field(self, field_name, field_obj): + """on_bind_field method.""" + # Get the SQLAlchemy column associated with this field + column = self.Meta.model.__table__.columns.get(field_name) + if column is not None and column.comment: + # Set the description meta attribute to the column's comment + field_obj.metadata["description"] = column.comment + + super().on_bind_field(field_name, field_obj) diff --git a/compliance-api/src/compliance_api/schemas/staff_user.py b/compliance-api/src/compliance_api/schemas/staff_user.py index 276f7d0c..4f069ce2 100644 --- a/compliance-api/src/compliance_api/schemas/staff_user.py +++ b/compliance-api/src/compliance_api/schemas/staff_user.py @@ -15,54 +15,28 @@ from marshmallow import EXCLUDE, Schema, fields from marshmallow_enum import EnumField -from compliance_api.models.staff_user import PERMISSION_MAP, PermissionEnum, StaffUser +from compliance_api.models.staff_user import PermissionEnum, StaffUser +from .base_schema import AutoSchemaBase from .common import KeyValueSchema -class StaffUserSchema(Schema): +class StaffUserSchema(AutoSchemaBase): # pylint: disable=too-many-ancestors """Staff User schema.""" - class Meta: # pylint: disable=too-few-public-methods + class Meta(AutoSchemaBase.Meta): # pylint: disable=too-few-public-methods """Exclude unknown fields in the deserialized output.""" unknown = EXCLUDE model = StaffUser include_fk = True - id = fields.Int( - metadata={"description": "The unique identifier of the staff user."} - ) - first_name = fields.Str( - metadata={"description": "The firstname of the staff user."} - ) - last_name = fields.Str(metadata={"description": "The lastname of the staff user."}) - position_id = fields.Int( - metadata={ - "description": "The unique identifier of the position of the staff user." - } - ) - position = fields.Nested( - KeyValueSchema, dump_only=True - ) - deputy_director_id = fields.Int( - metadata={"description": "The unique identifier of the deputy director."} - ) - supervisor_id = fields.Int( - metadata={"description": "The unique identifier of the supervisor."} - ) - auth_user_id = fields.Str( - metadata={"description": "The unique identifier from the identity provider."} - ) - full_name = fields.Str( - metadata={"description": "Fullname of the staff user"} - ) - # permission = fields.Method("get_user_permission", required=True) + position = fields.Nested(KeyValueSchema, dump_only=True) + full_name = fields.Method("get_full_name") - def get_user_permission(self, staff_user: StaffUser): # pylint: disable=no-self-use - """Extract the permission value from the enum.""" - permission_value = PERMISSION_MAP[staff_user.permission] - return permission_value + def get_full_name(self, obj): # pylint: disable=no-self-use + """Derive fullname.""" + return f"{obj.first_name} {obj.last_name}" class StaffUserCreateSchema(Schema): diff --git a/compliance-api/src/compliance_api/services/__init__.py b/compliance-api/src/compliance_api/services/__init__.py index e405832c..f79e2cfd 100644 --- a/compliance-api/src/compliance_api/services/__init__.py +++ b/compliance-api/src/compliance_api/services/__init__.py @@ -12,6 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. """Exposes all of the Services used in the compliance_api.""" -from .auth_service import AuthService from .position import PositionService from .staff_user import StaffUserService diff --git a/compliance-api/src/compliance_api/services/auth_service.py b/compliance-api/src/compliance_api/services/auth_service.py deleted file mode 100644 index 1afa42b1..00000000 --- a/compliance-api/src/compliance_api/services/auth_service.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Service to call epic.authorize endpoints.""" - - -class AuthService: - """Handle service request for epic.authorize.""" - - @staticmethod - def get_epic_user_by_id(user_id: str): - """Return the user representation from epic.authorize.""" - return {"first_name": "Dinesh", "last_name": "Balakrishnan", "user_id": user_id} diff --git a/compliance-api/src/compliance_api/services/authorize_service/auth_service.py b/compliance-api/src/compliance_api/services/authorize_service/auth_service.py new file mode 100644 index 00000000..489f891e --- /dev/null +++ b/compliance-api/src/compliance_api/services/authorize_service/auth_service.py @@ -0,0 +1,57 @@ +"""Service to call epic.authorize endpoints.""" + +import requests +from flask import current_app, g + +from compliance_api.exceptions import BusinessError +from compliance_api.utils.enum import HttpMethod + +from .auth_user_schema import AuthUserSchema +from .constant import API_REQUEST_TIMEOUT + + +class AuthService: + """Handle service request for epic.authorize.""" + + @staticmethod + def get_epic_user_by_id(auth_user_id: str): + """Return the user representation from epic.authorize.""" + auth_user_response_json = _request_auth_service(f"/users/{auth_user_id}") + return AuthUserSchema().load(auth_user_response_json) + + @staticmethod + def update_user_group(auth_user_id: str, payload: dict): + """Update the group of the user in the identity server.""" + update_group_response = _request_auth_service( + f"/users/{auth_user_id}/group", HttpMethod.PATCH, payload + ) + return update_group_response + + +def _request_auth_service( + self, relative_url, http_method: HttpMethod = HttpMethod.GET, data=None +): + """REST Api call to authorize service.""" + token = getattr(g, "access_token", None) + if not token: + raise BusinessError("No access token found", 401) + auth_base_url = current_app.config["AUTH_BASE_URL"] + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {token}", + } + + url = f"{auth_base_url}/{relative_url}" + + if http_method == HttpMethod.GET: + response = requests.get(url, headers=headers, timeout=API_REQUEST_TIMEOUT) + elif http_method == HttpMethod.PUT: + response = requests.put( + url, headers=headers, data=data, timeout=API_REQUEST_TIMEOUT + ) + elif http_method == HttpMethod.DELETE: + response = requests.delete(url, headers=headers, timeout=API_REQUEST_TIMEOUT) + else: + raise ValueError("Invalid HTTP method") + response.raise_for_status() + return response diff --git a/compliance-api/src/compliance_api/services/authorize_service/auth_user_schema.py b/compliance-api/src/compliance_api/services/authorize_service/auth_user_schema.py new file mode 100644 index 00000000..fc22a9d2 --- /dev/null +++ b/compliance-api/src/compliance_api/services/authorize_service/auth_user_schema.py @@ -0,0 +1,23 @@ +# Copyright © 2024 Province of British Columbia +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Schema for user representation from epic.authorize.""" +from marshmallow import Schema, fields + + +class AuthUserSchema(Schema): + """Schema for the auth user.""" + + first_name = fields.Str(metadata={"description": "The first name of the user"}, required=True) + last_name = fields.Str(metadata={"description": "The lastname of the user"}, required=True) + id = fields.Str(metadata={"description": "The unique id of the user"}, required=True) diff --git a/compliance-api/src/compliance_api/services/authorize_service/constant.py b/compliance-api/src/compliance_api/services/authorize_service/constant.py new file mode 100644 index 00000000..980d8117 --- /dev/null +++ b/compliance-api/src/compliance_api/services/authorize_service/constant.py @@ -0,0 +1,3 @@ +"""Constant used for authorize service.""" + +API_REQUEST_TIMEOUT = 60 diff --git a/compliance-api/src/compliance_api/services/staff_user.py b/compliance-api/src/compliance_api/services/staff_user.py index a9709472..0e489a42 100644 --- a/compliance-api/src/compliance_api/services/staff_user.py +++ b/compliance-api/src/compliance_api/services/staff_user.py @@ -1,10 +1,12 @@ """Service for user management.""" from compliance_api.exceptions import UnprocessableEntityError +from compliance_api.models.db import session_scope from compliance_api.models.staff_user import PERMISSION_MAP, PermissionEnum from compliance_api.models.staff_user import StaffUser as UserModel +from compliance_api.utils.constant import AUTH_APP -from .auth_service import AuthService +from .authorize_service.auth_service import AuthService class StaffUserService: @@ -23,7 +25,7 @@ def get_all_users(cls): return users @classmethod - def create_user(cls, user_data): + def create_user(cls, user_data: dict): """Create user.""" auth_user_id = user_data.get("auth_user_id", None) auth_user = AuthService.get_epic_user_by_id(auth_user_id) @@ -31,9 +33,15 @@ def create_user(cls, user_data): raise UnprocessableEntityError( f"No user found from EPIC.Authorize corresponding to the given {auth_user_id}" ) - user_data["first_name"] = auth_user.get("first_name", None) - user_data["last_name"] = auth_user.get("last_name", None) - created_user = UserModel.create_user(user_data) + user_obj = _create_staff_user_object(user_data, auth_user) + group_payload = { + "auth_user_id": auth_user_id, + "app": AUTH_APP, + "group": user_data.get("permission", None), + } + with session_scope() as session: + created_user = UserModel.create_user(user_obj, session) + AuthService.update_user_group(auth_user_id, group_payload) return created_user @classmethod @@ -45,9 +53,15 @@ def update_user(cls, user_id, user_data): raise UnprocessableEntityError( f"No user found from EPIC.Authorize corresponding to the given {auth_user_id}" ) - user_data["first_name"] = auth_user.get("first_name", None) - user_data["last_name"] = auth_user.get("last_name", None) - updated_user = UserModel.update_user(user_id, user_data) + user_obj = _create_staff_user_object(user_data, auth_user) + group_payload = { + "auth_user_id": auth_user_id, + "app": AUTH_APP, + "group": user_data.get("permission", None), + } + with session_scope() as session: + updated_user = UserModel.update_user(user_id, user_obj, session) + AuthService.update_user_group(auth_user_id, group_payload) return updated_user @classmethod @@ -67,3 +81,15 @@ def get_permission_levels(cls): return [ {"id": perm.name, "name": PERMISSION_MAP[perm]} for perm in PermissionEnum ] + + +def _create_staff_user_object(user_data: dict, auth_user: dict): + """Create a staff user object.""" + return { + "first_name": auth_user.get("first_name", None), + "last_name": auth_user.get("last_name", None), + "position_id": user_data.get("position_id", None), + "deputy_director_id": user_data.get("deputy_director_id"), + "supervisor_id": user_data.get("supervisor_id", None), + "auth_user_id": user_data.get("auth_user_id", None), + } diff --git a/compliance-api/src/compliance_api/utils/enum.py b/compliance-api/src/compliance_api/utils/enum.py new file mode 100644 index 00000000..f7f83a5d --- /dev/null +++ b/compliance-api/src/compliance_api/utils/enum.py @@ -0,0 +1,25 @@ +# Copyright © 2019 Province of British Columbia +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Enum definitions.""" +from enum import Enum + + +class HttpMethod(Enum): + """Http methods.""" + + GET = "GET" + PUT = "PUT" + POST = "POST" + PATCH = "PATCH" + DELETE = "DELETE"