From 50c4e2e1b3c40907d200f9c9c1152d1f457384f4 Mon Sep 17 00:00:00 2001 From: Samuel Jones Date: Thu, 9 Jan 2025 09:26:08 +0000 Subject: [PATCH] Add inst scientist auth (#92) * Add AD checks for instrument scientists * Formatting and linting commit * Trial fixes for e2e * Add is_instrument_scientist patchs to e2e tests * Add docstring * Add docstring checks to ruff * DO a whole bunch of docs fixes. * DO a whole bunch of docs fixes. * DO a whole bunch of docs fixes. * Remove support for D100 ruff linter * Formatting and linting commit --------- Co-authored-by: github-actions --- fia_auth/auth.py | 4 +- fia_auth/db.py | 15 +++---- fia_auth/exception_handlers.py | 2 + fia_auth/exceptions.py | 11 ++++-- fia_auth/experiments.py | 4 +- fia_auth/fia_auth.py | 4 +- fia_auth/model.py | 22 +++++------ fia_auth/roles.py | 28 ++++++++++++++ fia_auth/routers.py | 12 +++--- fia_auth/tokens.py | 31 ++++++++------- pyproject.toml | 6 ++- test/e2e/conftest.py | 9 +---- test/e2e/test_auth.py | 37 +++++++++++++----- test/e2e/test_db.py | 12 ++---- test/e2e/test_experiments.py | 5 +-- test/test_auth.py | 1 + test/test_model.py | 30 +++++++++++++- test/test_roles.py | 71 ++++++++++++++++++++++++++++++++++ test/test_tokens.py | 1 + 19 files changed, 217 insertions(+), 88 deletions(-) create mode 100644 fia_auth/roles.py create mode 100644 test/test_roles.py diff --git a/fia_auth/auth.py b/fia_auth/auth.py index 39a3196..57af5d7 100644 --- a/fia_auth/auth.py +++ b/fia_auth/auth.py @@ -1,6 +1,4 @@ -""" -Module containing code to authenticate with the UOWS -""" +"""Module containing code to authenticate with the UOWS""" import logging import os diff --git a/fia_auth/db.py b/fia_auth/db.py index 534be82..ddc0e8e 100644 --- a/fia_auth/db.py +++ b/fia_auth/db.py @@ -1,6 +1,4 @@ -""" -DB Access moculde -""" +"""DB Access moculde""" import logging import os @@ -13,17 +11,15 @@ class Base(DeclarativeBase): - """ - SQLAlchemy Base Model - """ + + """SQLAlchemy Base Model""" id: Mapped[int] = mapped_column(primary_key=True) class Staff(Base): - """ - Staff user - """ + + """Staff user""" __tablename__ = "staff" user_number: Mapped[int] = mapped_column(Integer()) @@ -47,7 +43,6 @@ def is_staff_user(user_number: int) -> bool: :param user_number: The user number to check :return: boolean indicating if it is a staff """ - try: with SESSION() as session: session.execute(select(Staff.user_number).where(Staff.user_number == user_number)).one() diff --git a/fia_auth/exception_handlers.py b/fia_auth/exception_handlers.py index 5f2b85d..37e0e94 100644 --- a/fia_auth/exception_handlers.py +++ b/fia_auth/exception_handlers.py @@ -1,3 +1,5 @@ +"""Error handlers""" + import logging from http import HTTPStatus diff --git a/fia_auth/exceptions.py b/fia_auth/exceptions.py index 15bc44c..ad8b753 100644 --- a/fia_auth/exceptions.py +++ b/fia_auth/exceptions.py @@ -1,23 +1,26 @@ -""" -FIA Auth custom exceptions -""" +"""FIA Auth custom exceptions""" class UOWSError(Exception): + """Problem authenticating with the user office web service""" class ProposalAllocationsError(Exception): + """Problem connecting with the proposal allocations api""" class AuthenticationError(Exception): + """Problem with authentication mechanism""" class BadCredentialsError(AuthenticationError): - """ "Bad Credentials Provided""" + + """Bad Credentials Provided""" class BadJWTError(AuthenticationError): + """Raised when a bad jwt has been given to the service""" diff --git a/fia_auth/experiments.py b/fia_auth/experiments.py index 1412daf..c2e01fe 100644 --- a/fia_auth/experiments.py +++ b/fia_auth/experiments.py @@ -1,6 +1,4 @@ -""" -Module for dealing with experiment user data via proposal allocations API -""" +"""Module for dealing with experiment user data via proposal allocations API""" import logging import os diff --git a/fia_auth/fia_auth.py b/fia_auth/fia_auth.py index 6c0adb4..b3e47e9 100644 --- a/fia_auth/fia_auth.py +++ b/fia_auth/fia_auth.py @@ -1,6 +1,4 @@ -""" -Module containing the fast api app. Uvicorn loads this to start the api -""" +"""Module containing the fast api app. Uvicorn loads this to start the api""" from fastapi import FastAPI from starlette.middleware.cors import CORSMiddleware diff --git a/fia_auth/model.py b/fia_auth/model.py index e8af90f..a7d34f6 100644 --- a/fia_auth/model.py +++ b/fia_auth/model.py @@ -1,6 +1,4 @@ -""" -Internal Models to help abstract and encapsulate the authentication process -""" +"""Internal Models to help abstract and encapsulate the authentication process""" import enum from dataclasses import dataclass @@ -8,21 +6,20 @@ from pydantic import BaseModel from fia_auth.db import is_staff_user +from fia_auth.roles import is_instrument_scientist class UserCredentials(BaseModel): - """ - Pydantic model for user credentials. Allows FastAPI to validate the object recieved in the login endpoint - """ + + """Pydantic model for user credentials. Allows FastAPI to validate the object recieved in the login endpoint""" username: str password: str class Role(enum.Enum): - """ - Role Enum to differentiate between user and staff. It is assumed staff will see all data - """ + + """Role Enum to differentiate between user and staff. It is assumed staff will see all data""" STAFF = "staff" USER = "user" @@ -30,9 +27,8 @@ class Role(enum.Enum): @dataclass class User: - """ - Internal User Model for packing JWTs - """ + + """Internal User Model for packing JWTs""" user_number: int @@ -42,6 +38,6 @@ def role(self) -> Role: Determine and determine the role of the user based on their usernumber :return: """ - if is_staff_user(self.user_number): + if is_staff_user(self.user_number) or is_instrument_scientist(self.user_number): return Role.STAFF return Role.USER diff --git a/fia_auth/roles.py b/fia_auth/roles.py new file mode 100644 index 0000000..170c7bb --- /dev/null +++ b/fia_auth/roles.py @@ -0,0 +1,28 @@ +"""Functions for handling role checks""" + +import os +from http import HTTPStatus + +import requests + + +def is_instrument_scientist(user_number: int) -> bool: + """ + Check if the user number is an instrument scientist according to UOWs (User Office Web Service). + :param user_number: The user number assigned to each user from UOWs + :return: True if the user number is an instrument scientist, false if not or failed connection. + """ + uows_url = os.environ.get("UOWS_URL", "https://devapi.facilities.rl.ac.uk/users-service") + uows_api_key = os.environ.get("UOWS_API_KEY", "") + response = requests.get( + url=f"{uows_url}/v1/role/{user_number}", + headers={"Authorization": f"Api-key {uows_api_key}", "accept": "application/json"}, + timeout=1, + ) + if response.status_code != HTTPStatus.OK: + from fia_auth.auth import logger + + logger.info("User number %s is not an instrument scientist or UOWS API is down", user_number) + return False + roles = response.json() + return {"name": "ISIS Instrument Scientist"} in roles diff --git a/fia_auth/routers.py b/fia_auth/routers.py index b848241..880ccc1 100644 --- a/fia_auth/routers.py +++ b/fia_auth/routers.py @@ -1,6 +1,4 @@ -""" -Module containing the fastapi routers -""" +"""Module containing the fastapi routers""" from __future__ import annotations @@ -31,7 +29,7 @@ async def get_experiments( user_number: int, credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)] ) -> list[int]: - """ + r""" Get the experiment (RB) numbers for the given user number provided by query string parameter \f @@ -46,7 +44,7 @@ async def get_experiments( @ROUTER.post("/api/jwt/authenticate", tags=["auth"]) async def login(credentials: UserCredentials) -> JSONResponse: - """ + r""" Login with facilities account \f :param credentials: username and password @@ -74,7 +72,7 @@ async def login(credentials: UserCredentials) -> JSONResponse: @ROUTER.post("/api/jwt/checkToken") def verify(token: dict[str, Any]) -> Literal["ok"]: - """ + r""" Verify an access token was generated by this auth server and has not expired \f :param token: The JWT @@ -90,7 +88,7 @@ def verify(token: dict[str, Any]) -> Literal["ok"]: def refresh( body: dict[str, Any], refresh_token: Annotated[str | None, Cookie(alias="refresh_token")] = None ) -> JSONResponse: - """ + r""" Refresh an access token based on a refresh token \f :param refresh_token: diff --git a/fia_auth/tokens.py b/fia_auth/tokens.py index 99b92c8..676c47e 100644 --- a/fia_auth/tokens.py +++ b/fia_auth/tokens.py @@ -1,6 +1,4 @@ -""" -Module containing token classes, creation, and loading functions -""" +"""Module containing token classes, creation, and loading functions""" from __future__ import annotations @@ -25,15 +23,14 @@ class Token(ABC): - """ - Abstract token class defines verify method - """ + + """Abstract token class defines verify method""" jwt: str def verify(self) -> None: """ - Verifies the token, ensuring that it has a valid format, signature, and has not expired. Will raise a + Verify the token, ensuring that it has a valid format, signature, and has not expired. Will raise a BadJWTError if verification fails :return: None """ @@ -65,11 +62,15 @@ def _encode(self) -> None: class AccessToken(Token): - """ - Access Token is a short-lived (5 minute) token that stores user information - """ + + """Access Token is a short-lived (5 minute) token that stores user information""" def __init__(self, jwt_token: str | None = None, payload: dict[str, Any] | None = None) -> None: + """ + Create AccessToken, requires jwt_token XOR a payload + :param jwt_token: JWT to populate the payload if JWT provided and no payload. + :param payload: Payload to encode if no JWT provided + """ if payload and not jwt_token: self._payload = payload self._payload["exp"] = datetime.now(UTC) + timedelta(minutes=float(ACCESS_TOKEN_LIFETIME_MINUTES)) @@ -99,11 +100,15 @@ def refresh(self) -> None: class RefreshToken(Token): - """ - Refresh token is a long-lived (12 hour) token that is required to refresh an access token - """ + + """Refresh token is a long-lived (12 hour) token that is required to refresh an access token""" def __init__(self, jwt_token: str | None = None) -> None: + """ + Create the RefreshToken + :param jwt_token: Optionally provide the jwt_token to create the payload else construct payload using default + method. + """ if jwt_token is None: self._payload = {"exp": datetime.now(UTC) + timedelta(hours=12)} self._encode() diff --git a/pyproject.toml b/pyproject.toml index dcbfdb8..add63bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,8 @@ dependencies = [ "psycopg2==2.9.9", "PyJWT==2.8.0", "SQLAlchemy==2.0.30", - "uvicorn==0.30.1" + "uvicorn==0.30.1", + "requests==2.32.3" ] [project.urls] @@ -45,6 +46,7 @@ select = [ "F", # flake8 - Basic initial rules "E", # pycodestyle (Error) - pep8 compliance "W", # pycodestyle (Warning) - pep8 compliance + "D", # pydocstyle - Enforce docstrings present "C90", # mccabe - flags extremely complex functions "I", # isort - Sort imports and flag missing imports "N", # pep8-naming - Ensures pep8 compliance for naming @@ -77,6 +79,8 @@ ignore = [ "S101", # flake8-bandit - Use of assert (all over pytest tests) "ISC001", # Conflicts with the formatter "COM812", # Conflicts with the formatter + "D211", # Conflicts with it's own rules: D203 + "D104", "D205", "D212", "D400", "D415" # Overzealous docstring checker ] [tool.ruff.lint.pylint] diff --git a/test/e2e/conftest.py b/test/e2e/conftest.py index c1e7892..039ba21 100644 --- a/test/e2e/conftest.py +++ b/test/e2e/conftest.py @@ -1,6 +1,4 @@ -""" -e2e session scoped fixtures -""" +"""e2e session scoped fixtures""" import pytest @@ -9,9 +7,6 @@ @pytest.fixture(scope="session", autouse=True) def _setup(): - """ - Setup database pre-testing - :return: - """ + """Set up database pre-testing""" Base.metadata.drop_all(ENGINE) Base.metadata.create_all(ENGINE) diff --git a/test/e2e/test_auth.py b/test/e2e/test_auth.py index 2a290ba..de38659 100644 --- a/test/e2e/test_auth.py +++ b/test/e2e/test_auth.py @@ -1,3 +1,5 @@ +# ruff: noqa: D100, D103 + from http import HTTPStatus from unittest.mock import Mock, patch @@ -11,16 +13,19 @@ client = TestClient(app) -@patch("fia_auth.auth.requests.post") -def test_successful_login(mock_post): - mock_response = Mock() - mock_post.return_value = mock_response +@patch("fia_auth.model.is_instrument_scientist") +@patch("fia_auth.auth.requests") +def test_successful_login(mock_auth_requests, is_instrument_scientist): + mock_auth_response = Mock() + mock_auth_response.status_code = HTTPStatus.CREATED + mock_auth_response.json.return_value = {"userId": 1234} + mock_auth_requests.post.return_value = mock_auth_response + is_instrument_scientist.return_value = False - mock_response.status_code = HTTPStatus.CREATED - mock_response.json.return_value = {"userId": 1234} response = client.post("/api/jwt/authenticate", json={"username": "foo", "password": "foo"}) assert response.json()["token"].startswith("ey") assert response.cookies["refresh_token"].startswith("ey") + is_instrument_scientist.assert_called_once_with(1234) @patch("fia_auth.auth.requests.post") @@ -45,12 +50,15 @@ def test_unsuccessful_login_uows_failure(mock_post): assert response.status_code == HTTPStatus.FORBIDDEN -def test_verify_success(): +@patch("fia_auth.model.is_instrument_scientist") +def test_verify_success(is_instrument_scientist): + is_instrument_scientist.return_value = False user = User(123) access_token = generate_access_token(user) response = client.post("/api/jwt/checkToken", json={"token": access_token.jwt}) assert response.status_code == HTTPStatus.OK + is_instrument_scientist.assert_called_once_with(123) def test_verify_fail_badly_formed_token(): @@ -63,7 +71,9 @@ def test_verify_fail_bad_signature(): assert response.status_code == HTTPStatus.FORBIDDEN -def test_token_refresh_success(): +@patch("fia_auth.model.is_instrument_scientist") +def test_token_refresh_success(is_instrument_scientist): + is_instrument_scientist.return_value = False user = User(123) access_token = generate_access_token(user) refresh_token = generate_refresh_token() @@ -71,9 +81,12 @@ def test_token_refresh_success(): "/api/jwt/refresh", json={"token": access_token.jwt}, cookies={"refresh_token": refresh_token.jwt} ) assert response.json()["token"].startswith("ey") + is_instrument_scientist.assert_called_once_with(123) -def test_token_refresh_no_refresh_token_given(): +@patch("fia_auth.model.is_instrument_scientist") +def test_token_refresh_no_refresh_token_given(is_instrument_scientist): + is_instrument_scientist.return_value = False user = User(123) access_token = generate_access_token(user) response = client.post( @@ -82,9 +95,12 @@ def test_token_refresh_no_refresh_token_given(): ) assert response.status_code == HTTPStatus.FORBIDDEN + is_instrument_scientist.assert_called_once_with(123) -def test_token_refresh_expired_refresh_token(): +@patch("fia_auth.model.is_instrument_scientist") +def test_token_refresh_expired_refresh_token(is_instrument_scientist): + is_instrument_scientist.return_value = False user = User(123) access_token = generate_access_token(user) refresh_token = ( @@ -95,3 +111,4 @@ def test_token_refresh_expired_refresh_token(): ) assert response.status_code == HTTPStatus.FORBIDDEN + is_instrument_scientist.assert_called_once_with(123) diff --git a/test/e2e/test_db.py b/test/e2e/test_db.py index af72c65..5e8eef2 100644 --- a/test/e2e/test_db.py +++ b/test/e2e/test_db.py @@ -1,14 +1,10 @@ -""" -Test cases for db module -""" +"""Test cases for db module""" from fia_auth.db import SESSION, Staff, is_staff_user def test_is_staff_staff_user_exists(): - """ - test is staff returns true when staff exists - """ + """Test is staff returns true when staff exists""" with SESSION() as session: staff = Staff(user_number=54321) session.add(staff) @@ -18,7 +14,5 @@ def test_is_staff_staff_user_exists(): def test_is_staff_user_does_not_exist(): - """ - Test is staff returns false when not staff - """ + """Test is staff returns false when not staff""" assert not is_staff_user(5678) diff --git a/test/e2e/test_experiments.py b/test/e2e/test_experiments.py index 77e791f..cf28fc5 100644 --- a/test/e2e/test_experiments.py +++ b/test/e2e/test_experiments.py @@ -1,6 +1,5 @@ -""" -e2e test cases -""" +"""e2e test cases""" +# ruff: noqa: D103 from http import HTTPStatus from unittest.mock import patch diff --git a/test/test_auth.py b/test/test_auth.py index eafc4c7..48b7169 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -1,3 +1,4 @@ +# ruff: noqa: D100, D103 from http import HTTPStatus from unittest.mock import Mock, patch diff --git a/test/test_model.py b/test/test_model.py index 48b933a..018d572 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -6,18 +6,44 @@ @patch("fia_auth.model.is_staff_user") -def test_role_is_user(mock_is_staff_user): +@patch("fia_auth.model.is_instrument_scientist") +def test_role_is_user(mock_is_staff_user, mock_is_instrument_scientist): """Test user enum assign""" mock_is_staff_user.return_value = False + mock_is_instrument_scientist.return_value = False user = User(user_number=1234) assert user.role == Role.USER @patch("fia_auth.model.is_staff_user") -def test_role_is_staff(mock_is_staff_user): +@patch("fia_auth.model.is_instrument_scientist") +def test_role_is_staff_in_db(mock_is_staff_user, mock_is_instrument_scientist): """Test staff enum assign""" mock_is_staff_user.return_value = True + mock_is_instrument_scientist.return_value = False + + user = User(user_number=1234) + assert user.role == Role.STAFF + + +@patch("fia_auth.model.is_staff_user") +@patch("fia_auth.model.is_instrument_scientist") +def test_role_is_staff_in_db_and_inst_scientist(mock_is_staff_user, mock_is_instrument_scientist): + """Test staff enum assign""" + mock_is_staff_user.return_value = True + mock_is_instrument_scientist.return_value = True + + user = User(user_number=1234) + assert user.role == Role.STAFF + + +@patch("fia_auth.model.is_staff_user") +@patch("fia_auth.model.is_instrument_scientist") +def test_role_is_staff_inst_scientist(mock_is_staff_user, mock_is_instrument_scientist): + """Test staff enum assign""" + mock_is_staff_user.return_value = False + mock_is_instrument_scientist.return_value = True user = User(user_number=1234) assert user.role == Role.STAFF diff --git a/test/test_roles.py b/test/test_roles.py new file mode 100644 index 0000000..67bbd04 --- /dev/null +++ b/test/test_roles.py @@ -0,0 +1,71 @@ +# ruff: noqa: D100, D103 +import os +import random +from http import HTTPStatus +from unittest import mock + +from fia_auth.roles import is_instrument_scientist + + +@mock.patch("fia_auth.roles.requests") +def test_is_instrument_scientist_true(requests): + uows_url = str(mock.MagicMock()) + uows_api_key = str(mock.MagicMock()) + os.environ["UOWS_URL"] = uows_url + os.environ["UOWS_API_KEY"] = uows_api_key + user_number = random.randint(0, 100000) # noqa: S311 + requests.get.return_value.status_code = HTTPStatus.OK + requests.get.return_value.json.return_value = [{"name": "ISIS Instrument Scientist"}] + + result = is_instrument_scientist(user_number) + + requests.get.assert_called_once_with( + url=f"{uows_url}/v1/role/{user_number}", + headers={"Authorization": f"Api-key {uows_api_key}", "accept": "application/json"}, + timeout=1, + ) + assert result + os.environ.pop("UOWS_URL") + os.environ.pop("UOWS_API_KEY") + + +@mock.patch("fia_auth.roles.requests") +def test_is_instrument_scientist_false(requests): + uows_url = str(mock.MagicMock()) + uows_api_key = str(mock.MagicMock()) + os.environ["UOWS_URL"] = uows_url + os.environ["UOWS_API_KEY"] = uows_api_key + user_number = random.randint(0, 100000) # noqa: S311 + requests.get.return_value.status_code = HTTPStatus.OK + requests.get.return_value.json.return_value = [{"name": "Not ISIS Instrument Scientist"}] + + result = is_instrument_scientist(user_number) + + requests.get.assert_called_once_with( + url=f"{uows_url}/v1/role/{user_number}", + headers={"Authorization": f"Api-key {uows_api_key}", "accept": "application/json"}, + timeout=1, + ) + assert not result + os.environ.pop("UOWS_URL") + os.environ.pop("UOWS_API_KEY") + + +@mock.patch("fia_auth.roles.requests") +def test_is_instrument_scientist_false_when_forbidden(requests): + uows_url = str(mock.MagicMock()) + uows_api_key = str(mock.MagicMock()) + os.environ["UOWS_URL"] = uows_url + os.environ["UOWS_API_KEY"] = uows_api_key + user_number = random.randint(0, 100000) # noqa: S311 + requests.get.return_value.status_code = HTTPStatus.FORBIDDEN + result = is_instrument_scientist(user_number) + + requests.get.assert_called_once_with( + url=f"{uows_url}/v1/role/{user_number}", + headers={"Authorization": f"Api-key {uows_api_key}", "accept": "application/json"}, + timeout=1, + ) + assert not result + os.environ.pop("UOWS_URL") + os.environ.pop("UOWS_API_KEY") diff --git a/test/test_tokens.py b/test/test_tokens.py index 2259015..96f4874 100644 --- a/test/test_tokens.py +++ b/test/test_tokens.py @@ -1,3 +1,4 @@ +# ruff: noqa: D100, D103 from datetime import UTC, datetime, timedelta from unittest.mock import Mock, patch