diff --git a/cli/cli_tests.sh b/cli/cli_tests.sh index 4834233e7..733e4b293 100755 --- a/cli/cli_tests.sh +++ b/cli/cli_tests.sh @@ -18,6 +18,7 @@ CERT_FILE="${AUTH_CERT:-$(realpath server/cert.crt)}" MEDPERF_STORAGE=~/.medperf MEDPERF_SUBSTORAGE="$MEDPERF_STORAGE/$(echo $SERVER_URL | cut -d '/' -f 3 | sed -e 's/[.:]/_/g')" MEDPERF_LOG_STORAGE="$MEDPERF_SUBSTORAGE/logs/medperf.log" +VERSION_PREFIX="/api/v0" echo "Server URL: $SERVER_URL" echo "Storage location: $MEDPERF_SUBSTORAGE" @@ -83,15 +84,15 @@ METRIC_PARAMS="$ASSETS_URL/metrics/mlcube/workspace/parameters.yaml" METRICS_SING_IMAGE="$ASSETS_URL/metrics/mlcube/workspace/.image/image.tar.gz" # admin token -ADMIN_TOKEN=$(curl -sk -X POST https://127.0.0.1:8000/auth-token/ -d '{"username": "admin", "password": "admin"}' -H 'Content-Type: application/json' | jq -r '.token') +ADMIN_TOKEN=$(curl -sk -X POST $SERVER_URL$VERSION_PREFIX/auth-token/ -d '{"username": "admin", "password": "admin"}' -H 'Content-Type: application/json' | jq -r '.token') # create users MODELOWNER="mockmodelowner" DATAOWNER="mockdataowner" BENCHMARKOWNER="mockbenchmarkowner" -curl -sk -X POST https://127.0.0.1:8000/users/ -d '{"first_name": "model", "last_name": "owner", "username": "'"$MODELOWNER"'", "password": "test", "email": "model@owner.com"}' -H 'Content-Type: application/json' -H "Authorization: Token $ADMIN_TOKEN" -curl -sk -X POST https://127.0.0.1:8000/users/ -d '{"first_name": "bmk", "last_name": "owner", "username": "'"$BENCHMARKOWNER"'", "password": "test", "email": "bmk@owner.com"}' -H 'Content-Type: application/json' -H "Authorization: Token $ADMIN_TOKEN" -curl -sk -X POST https://127.0.0.1:8000/users/ -d '{"first_name": "data", "last_name": "owner", "username": "'"$DATAOWNER"'", "password": "test", "email": "data@owner.com"}' -H 'Content-Type: application/json' -H "Authorization: Token $ADMIN_TOKEN" +curl -sk -X POST $SERVER_URL$VERSION_PREFIX/users/ -d '{"first_name": "model", "last_name": "owner", "username": "'"$MODELOWNER"'", "password": "test", "email": "model@owner.com"}' -H 'Content-Type: application/json' -H "Authorization: Token $ADMIN_TOKEN" +curl -sk -X POST $SERVER_URL$VERSION_PREFIX/users/ -d '{"first_name": "bmk", "last_name": "owner", "username": "'"$BENCHMARKOWNER"'", "password": "test", "email": "bmk@owner.com"}' -H 'Content-Type: application/json' -H "Authorization: Token $ADMIN_TOKEN" +curl -sk -X POST $SERVER_URL$VERSION_PREFIX/users/ -d '{"first_name": "data", "last_name": "owner", "username": "'"$DATAOWNER"'", "password": "test", "email": "data@owner.com"}' -H 'Content-Type: application/json' -H "Authorization: Token $ADMIN_TOKEN" ########################################################## ################### Start Testing ######################## @@ -184,7 +185,7 @@ medperf benchmark submit --name bmk --description bmk --demo-url $DEMO_URL --dat checkFailed "Benchmark submission failed" BMK_UID=$(medperf benchmark ls | tail -n 1 | tr -s ' ' | cut -d ' ' -f 2) -curl -sk -X PUT https://127.0.0.1:8000/benchmarks/$BMK_UID/ -d '{"approval_status": "APPROVED"}' -H 'Content-Type: application/json' -H "Authorization: Token $ADMIN_TOKEN" +curl -sk -X PUT $SERVER_URL$VERSION_PREFIX/benchmarks/$BMK_UID/ -d '{"approval_status": "APPROVED"}' -H 'Content-Type: application/json' -H "Authorization: Token $ADMIN_TOKEN" checkFailed "Benchmark approval failed" ########################################################## diff --git a/cli/medperf/__init__.py b/cli/medperf/__init__.py index e69de29bb..ad5cc752c 100644 --- a/cli/medperf/__init__.py +++ b/cli/medperf/__init__.py @@ -0,0 +1 @@ +from ._version import __version__ # noqa diff --git a/cli/medperf/__main__.py b/cli/medperf/__main__.py index 214da93f1..56c0fc64b 100644 --- a/cli/medperf/__main__.py +++ b/cli/medperf/__main__.py @@ -3,6 +3,7 @@ import logging.handlers from os.path import expanduser, abspath +from medperf import __version__ import medperf.config as config from medperf.ui.factory import UIFactory from medperf.decorators import clean_except, configurable @@ -173,11 +174,12 @@ def main(ctx: typer.Context): log = config.loglevel.upper() log_lvl = getattr(logging, log) setup_logging(log_lvl) + logging.info(f"Running MedPerf v{__version__} on {log_lvl} logging level") config.ui = UIFactory.create_ui(config.ui) config.comms = CommsFactory.create_comms(config.comms, config.server) - config.ui.print(f"MedPerf {config.version}") + config.ui.print(f"MedPerf {__version__}") if __name__ == "__main__": diff --git a/cli/medperf/_version.py b/cli/medperf/_version.py new file mode 100644 index 000000000..3dc1f76bc --- /dev/null +++ b/cli/medperf/_version.py @@ -0,0 +1 @@ +__version__ = "0.1.0" diff --git a/cli/medperf/comms/interface.py b/cli/medperf/comms/interface.py index 79f531190..c40fed43d 100644 --- a/cli/medperf/comms/interface.py +++ b/cli/medperf/comms/interface.py @@ -15,6 +15,19 @@ def __init__(self, source: str, ui: UI, token: str = None): token (str, Optional): authentication token to be used throughout communication. Defaults to None. """ + @classmethod + @abstractmethod + def parse_url(self, url: str) -> str: + """Parse the source URL so that it can be used by the comms implementation. + It should handle protocols and versioning to be able to communicate with the API. + + Args: + url (str): base URL + + Returns: + str: parsed URL with protocol and version + """ + @abstractmethod def login(self, ui: UI): """Authenticate the comms instance for further interactions diff --git a/cli/medperf/comms/rest.py b/cli/medperf/comms/rest.py index be3085ab3..cf454391d 100644 --- a/cli/medperf/comms/rest.py +++ b/cli/medperf/comms/rest.py @@ -40,20 +40,31 @@ def log_response_error(res, warn=False): class REST(Comms): def __init__(self, source: str, token=None): - self.server_url = self.__parse_url(source) + self.server_url = self.parse_url(source) self.token = token self.cert = config.certificate if self.cert is None: # No certificate provided, default to normal verification self.cert = True - def __parse_url(self, url): + @classmethod + def parse_url(cls, url: str) -> str: + """Parse the source URL so that it can be used by the comms implementation. + It should handle protocols and versioning to be able to communicate with the API. + + Args: + url (str): base URL + + Returns: + str: parsed URL with protocol and version + """ url_sections = url.split("://") + api_path = f"/api/v{config.major_version}" # Remove protocol if passed if len(url_sections) > 1: url = "".join(url_sections[1:]) - return f"https://{url}" + return f"https://{url}{api_path}" def login(self, user: str, pwd: str): """Authenticates the user with the server. Required for most endpoints diff --git a/cli/medperf/config.py b/cli/medperf/config.py index 394763a41..b70c5ba03 100644 --- a/cli/medperf/config.py +++ b/cli/medperf/config.py @@ -1,6 +1,8 @@ +from ._version import __version__ from os.path import expanduser, abspath -version = "0.0.0" +major_version, minor_version, patch_version = __version__.split(".") + server = "https://api.medperf.org" certificate = None diff --git a/cli/medperf/setup.py b/cli/medperf/setup.py deleted file mode 100644 index 31816bee6..000000000 --- a/cli/medperf/setup.py +++ /dev/null @@ -1,24 +0,0 @@ -from setuptools import setup - -with open("requirements.txt", "r") as f: - requires = [] - for line in f: - req = line.split("#", 1)[0].strip() - if req and not req.startswith("--"): - requires.append(req) - -setup( - name="medperf", - version="0.0.0", - description="CLI Tool for federated benchmarking on medical private data", - url="https://github.com/mlcommons/medical", - author="MLCommons", - license="Apache 2.0", - packages=["medperf"], - install_requires=requires, - python_requires=">=3.6", - entry_points=""" - [console_scripts] - medperf=medperf.__main__:app - """, -) diff --git a/cli/medperf/tests/comms/test_rest.py b/cli/medperf/tests/comms/test_rest.py index 90fc7eaaa..54c8fe8f8 100644 --- a/cli/medperf/tests/comms/test_rest.py +++ b/cli/medperf/tests/comms/test_rest.py @@ -10,6 +10,7 @@ from medperf.tests.mocks import MockResponse url = "https://mock.url" +full_url = REST.parse_url(url) patch_server = "medperf.comms.rest.{}" @@ -22,24 +23,24 @@ def server(mocker, ui): @pytest.mark.parametrize( "method_params", [ - ("get_benchmark", "get", 200, [1], {}, (f"{url}/benchmarks/1",), {}), + ("get_benchmark", "get", 200, [1], {}, (f"{full_url}/benchmarks/1",), {}), ( "get_benchmark_models", "get_list", 200, [1], [], - (f"{url}/benchmarks/1/models",), + (f"{full_url}/benchmarks/1/models",), {}, ), - ("get_cube_metadata", "get", 200, [1], {}, (f"{url}/mlcubes/1/",), {}), + ("get_cube_metadata", "get", 200, [1], {}, (f"{full_url}/mlcubes/1/",), {}), ( "upload_dataset", "post", 201, [{}], {"id": 1}, - (f"{url}/datasets/",), + (f"{full_url}/datasets/",), {"json": {}}, ), ( @@ -48,7 +49,7 @@ def server(mocker, ui): 201, [{}], {"id": 1}, - (f"{url}/results/",), + (f"{full_url}/results/",), {"json": {}}, ), ( @@ -57,7 +58,7 @@ def server(mocker, ui): 201, [1, 1], {}, - (f"{url}/datasets/benchmarks/",), + (f"{full_url}/datasets/benchmarks/",), { "json": { "benchmark": 1, @@ -71,18 +72,18 @@ def server(mocker, ui): "_REST__set_approval_status", "put", 200, - [f"{url}/mlcubes/1/benchmarks/1", Status.APPROVED.value], + [f"{full_url}/mlcubes/1/benchmarks/1", Status.APPROVED.value], {}, - (f"{url}/mlcubes/1/benchmarks/1",), + (f"{full_url}/mlcubes/1/benchmarks/1",), {"json": {"approval_status": Status.APPROVED.value}}, ), ( "_REST__set_approval_status", "put", 200, - [f"{url}/mlcubes/1/benchmarks/1", Status.REJECTED.value], + [f"{full_url}/mlcubes/1/benchmarks/1", Status.REJECTED.value], {}, - (f"{url}/mlcubes/1/benchmarks/1",), + (f"{full_url}/mlcubes/1/benchmarks/1",), {"json": {"approval_status": Status.REJECTED.value}}, ), ( @@ -91,7 +92,7 @@ def server(mocker, ui): 200, ["pwd"], {}, - (f"{url}/me/password/",), + (f"{full_url}/me/password/",), {"json": {"password": "pwd"}}, ), ], @@ -149,7 +150,7 @@ def test_login_with_user_and_pwd(mocker, server, ui, uname, pwd): res = MockResponse({"token": ""}, 200) spy = mocker.patch("requests.post", return_value=res) exp_body = {"username": uname, "password": pwd} - exp_path = f"{url}/auth-token/" + exp_path = f"{full_url}/auth-token/" cert_verify = config.certificate or True # Act @@ -249,12 +250,12 @@ def test__req_sanitizes_json(mocker, server): def test__get_list_uses_default_page_size(mocker, server): # Arrange exp_page_size = config.default_page_size - exp_url = f"{url}?limit={exp_page_size}&offset=0" + exp_url = f"{full_url}?limit={exp_page_size}&offset=0" ret_body = MockResponse({"count": 1, "next": None, "results": []}, 200) spy = mocker.patch.object(server, "_REST__auth_get", return_value=ret_body) # Act - server._REST__get_list(url) + server._REST__get_list(full_url) # Assert spy.assert_called_once_with(exp_url) @@ -336,7 +337,7 @@ def test_get_benchmarks_calls_benchmarks_path(mocker, server, body): bmarks = server.get_benchmarks() # Assert - spy.assert_called_once_with(f"{url}/benchmarks/") + spy.assert_called_once_with(f"{full_url}/benchmarks/") assert bmarks == [body] @@ -367,7 +368,7 @@ def test_get_user_benchmarks_calls_auth_get_for_expected_path(mocker, server): server.get_user_benchmarks() # Assert - spy.assert_called_once_with(f"{url}/me/benchmarks/") + spy.assert_called_once_with(f"{full_url}/me/benchmarks/") def test_get_user_benchmarks_returns_benchmarks(mocker, server): @@ -394,7 +395,7 @@ def test_get_mlcubes_calls_mlcubes_path(mocker, server, body): cubes = server.get_cubes() # Assert - spy.assert_called_once_with(f"{url}/mlcubes/") + spy.assert_called_once_with(f"{full_url}/mlcubes/") assert cubes == [body] @@ -484,7 +485,7 @@ def test_get_user_cubes_calls_auth_get_for_expected_path(mocker, server): server.get_user_cubes() # Assert - spy.assert_called_once_with(f"{url}/me/mlcubes/") + spy.assert_called_once_with(f"{full_url}/me/mlcubes/") def test_get_cube_file_calls_download_direct_link_method(mocker, server): @@ -512,7 +513,7 @@ def test_get_datasets_calls_datasets_path(mocker, server, body): dsets = server.get_datasets() # Assert - spy.assert_called_once_with(f"{url}/datasets/") + spy.assert_called_once_with(f"{full_url}/datasets/") assert dsets == [body] @@ -527,7 +528,7 @@ def test_get_dataset_calls_specific_dataset_path(mocker, server, uid, body): dset = server.get_dataset(uid) # Assert - spy.assert_called_once_with(f"{url}/datasets/{uid}/") + spy.assert_called_once_with(f"{full_url}/datasets/{uid}/") assert dset == body @@ -543,7 +544,7 @@ def test_get_user_datasets_calls_auth_get_for_expected_path(mocker, server): server.get_user_datasets() # Assert - spy.assert_called_once_with(f"{url}/me/datasets/") + spy.assert_called_once_with(f"{full_url}/me/datasets/") @pytest.mark.parametrize("body", [{"mlcube": 1}, {}, {"test": "test"}]) @@ -616,7 +617,7 @@ def test_set_dataset_association_approval_sets_approval( spy = mocker.patch( patch_server.format("REST._REST__set_approval_status"), return_value=res ) - exp_url = f"{url}/datasets/{dataset_uid}/benchmarks/{benchmark_uid}/" + exp_url = f"{full_url}/datasets/{dataset_uid}/benchmarks/{benchmark_uid}/" # Act server.set_dataset_association_approval(benchmark_uid, dataset_uid, status) @@ -636,7 +637,7 @@ def test_set_mlcube_association_approval_sets_approval( spy = mocker.patch( patch_server.format("REST._REST__set_approval_status"), return_value=res ) - exp_url = f"{url}/mlcubes/{mlcube_uid}/benchmarks/{benchmark_uid}/" + exp_url = f"{full_url}/mlcubes/{mlcube_uid}/benchmarks/{benchmark_uid}/" # Act server.set_mlcube_association_approval(benchmark_uid, mlcube_uid, status) @@ -648,7 +649,7 @@ def test_set_mlcube_association_approval_sets_approval( def test_get_datasets_associations_gets_associations(mocker, server): # Arrange spy = mocker.patch(patch_server.format("REST._REST__get_list"), return_value=[]) - exp_path = f"{url}/me/datasets/associations/" + exp_path = f"{full_url}/me/datasets/associations/" # Act server.get_datasets_associations() @@ -660,7 +661,7 @@ def test_get_datasets_associations_gets_associations(mocker, server): def test_get_cubes_associations_gets_associations(mocker, server): # Arrange spy = mocker.patch(patch_server.format("REST._REST__get_list"), return_value=[]) - exp_path = f"{url}/me/mlcubes/associations/" + exp_path = f"{full_url}/me/mlcubes/associations/" # Act server.get_cubes_associations() @@ -675,7 +676,7 @@ def test_get_result_calls_specified_path(mocker, server, uid, body): # Arrange res = MockResponse(body, 200) spy = mocker.patch(patch_server.format("REST._REST__auth_get"), return_value=res) - exp_path = f"{url}/results/{uid}/" + exp_path = f"{full_url}/results/{uid}/" # Act result = server.get_result(uid) @@ -707,7 +708,7 @@ def test_set_mlcube_association_priority_sets_priority( # Arrange res = MockResponse({}, 200) spy = mocker.patch(patch_server.format("REST._REST__auth_put"), return_value=res) - exp_url = f"{url}/mlcubes/{mlcube_uid}/benchmarks/{benchmark_uid}/" + exp_url = f"{full_url}/mlcubes/{mlcube_uid}/benchmarks/{benchmark_uid}/" # Act server.set_mlcube_association_priority(benchmark_uid, mlcube_uid, priority) diff --git a/cli/medperf/utils.py b/cli/medperf/utils.py index 7aa766744..9681bcbe3 100644 --- a/cli/medperf/utils.py +++ b/cli/medperf/utils.py @@ -41,7 +41,6 @@ def setup_logging(log_lvl): requests_logger = logging.getLogger("requests") requests_logger.addHandler(handler) requests_logger.setLevel(log_lvl) - logging.info(f"Running MedPerf v{config.version} on {log_lvl} logging level") def delete_credentials(): diff --git a/cli/setup.py b/cli/setup.py index c5790dc2b..7f0e8b104 100644 --- a/cli/setup.py +++ b/cli/setup.py @@ -1,4 +1,5 @@ from setuptools import setup +from medperf._version import __version__ with open("requirements.txt", "r") as f: requires = [] @@ -9,7 +10,7 @@ setup( name="medperf", - version="0.0.0", + version=__version__, description="CLI Tool for federated benchmarking on medical private data", url="https://github.com/aristizabal95/medperf", author="MLCommons", diff --git a/server/benchmark/urls.py b/server/benchmark/urls.py index b24521a8b..cb120f4f5 100644 --- a/server/benchmark/urls.py +++ b/server/benchmark/urls.py @@ -1,6 +1,7 @@ from django.urls import path from . import views +app_name = "Benchmark" urlpatterns = [ path("", views.BenchmarkList.as_view()), diff --git a/server/benchmark/views.py b/server/benchmark/views.py index 6e654cae0..a6c7e67b6 100644 --- a/server/benchmark/views.py +++ b/server/benchmark/views.py @@ -5,6 +5,7 @@ from rest_framework.generics import GenericAPIView from rest_framework.response import Response from rest_framework import status +from drf_spectacular.utils import extend_schema from .models import Benchmark from .serializers import BenchmarkSerializer, BenchmarkApprovalSerializer @@ -15,6 +16,7 @@ class BenchmarkList(GenericAPIView): serializer_class = BenchmarkSerializer queryset = "" + @extend_schema(operation_id="benchmarks_retrieve_all") def get(self, request, format=None): """ List all benchmarks diff --git a/server/benchmarkdataset/views.py b/server/benchmarkdataset/views.py index 39f7d3ed3..f4ab47b50 100644 --- a/server/benchmarkdataset/views.py +++ b/server/benchmarkdataset/views.py @@ -3,6 +3,7 @@ from rest_framework.generics import GenericAPIView from rest_framework.response import Response from rest_framework import status +from drf_spectacular.utils import extend_schema from .permissions import IsAdmin, IsDatasetOwner, IsBenchmarkOwner from .serializers import ( @@ -39,6 +40,7 @@ def get_object(self, pk): except BenchmarkDataset.DoesNotExist: raise Http404 + @extend_schema(operation_id="datasets_benchmarks_retrieve_all") def get(self, request, pk, format=None): """ Retrieve all benchmarks associated with a dataset diff --git a/server/benchmarkmodel/views.py b/server/benchmarkmodel/views.py index 463b1ff32..ccf841022 100644 --- a/server/benchmarkmodel/views.py +++ b/server/benchmarkmodel/views.py @@ -3,6 +3,7 @@ from rest_framework.generics import GenericAPIView from rest_framework.response import Response from rest_framework import status +from drf_spectacular.utils import extend_schema from .permissions import IsAdmin, IsMlCubeOwner, IsBenchmarkOwner from .serializers import ( @@ -39,6 +40,7 @@ def get_object(self, pk): except BenchmarkModel.DoesNotExist: raise Http404 + @extend_schema(operation_id="mlcubes_benchmarks_retrieve_all") def get(self, request, pk, format=None): """ Retrieve all benchmarks associated with a model diff --git a/server/dataset/tests.py b/server/dataset/tests.py index 5fdef49f7..d94289b21 100644 --- a/server/dataset/tests.py +++ b/server/dataset/tests.py @@ -1,5 +1,6 @@ import string import random +from django.conf import settings from django.contrib.auth.models import User from rest_framework.test import APIClient from rest_framework import status @@ -12,14 +13,14 @@ class DatasetTest(MedPerfTest): def setUp(self): super(DatasetTest, self).setUp() - username = "dataowner" password = "".join(random.choice(string.ascii_letters) for m in range(10)) user = User.objects.create_user(username=username, password=password,) user.save() + self.api_prefix = "/api/" + settings.SERVER_API_VERSION self.client = APIClient() response = self.client.post( - "/auth-token/", {"username": username, "password": password}, format="json", + self.api_prefix + "/auth-token/", {"username": username, "password": password}, format="json", ) self.assertEqual(response.status_code, status.HTTP_200_OK) self.token = response.data["token"] @@ -37,21 +38,21 @@ def setUp(self): "metadata": {"key": "value"}, } - response = self.client.post("/mlcubes/", data_preproc_mlcube, format="json") + response = self.client.post(self.api_prefix + "/mlcubes/", data_preproc_mlcube, format="json") self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.data_preproc_mlcube_id = response.data["id"] def test_unauthenticated_user(self): client = APIClient() - response = client.get("/datasets/1/") + response = client.get(self.api_prefix + "/datasets/1/") self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - response = client.delete("/datasets/1/") + response = client.delete(self.api_prefix + "/datasets/1/") self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - response = client.put("/datasets/1/") + response = client.put(self.api_prefix + "/datasets/1/") self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - response = client.post("/datasets/", {}) + response = client.post(self.api_prefix + "/datasets/", {}) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - response = client.get("/datasets/") + response = client.get(self.api_prefix + "/datasets/") self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) def test_crud_user(self): @@ -66,17 +67,17 @@ def test_crud_user(self): "data_preparation_mlcube": self.data_preproc_mlcube_id, } - response = self.client.post("/datasets/", testdataset, format="json") + response = self.client.post(self.api_prefix + "/datasets/", testdataset, format="json") self.assertEqual(response.status_code, status.HTTP_201_CREATED) uid = response.data["id"] - response = self.client.get("/datasets/{0}/".format(uid)) + response = self.client.get(self.api_prefix + "/datasets/{0}/".format(uid)) self.assertEqual(response.status_code, status.HTTP_200_OK) for k, v in response.data.items(): if k in testdataset: self.assertEqual(testdataset[k], v) - response = self.client.get("/datasets/") + response = self.client.get(self.api_prefix + "/datasets/") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data["results"]), 1) @@ -92,10 +93,10 @@ def test_crud_user(self): } response = self.client.put( - "/datasets/{0}/".format(uid), newtestdataset, format="json" + self.api_prefix + "/datasets/{0}/".format(uid), newtestdataset, format="json" ) self.assertEqual(response.status_code, status.HTTP_200_OK) - response = self.client.get("/datasets/{0}/".format(uid)) + response = self.client.get(self.api_prefix + "/datasets/{0}/".format(uid)) self.assertEqual(response.status_code, status.HTTP_200_OK) for k, v in response.data.items(): @@ -103,15 +104,15 @@ def test_crud_user(self): self.assertEqual(newtestdataset[k], v) # TODO Revisit when delete permissions are fixed - # response = self.client.delete("/datasets/{0}/".format(uid)) + # response = self.client.delete(self.api_prefix + "/datasets/{0}/".format(uid)) # self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) - # response = self.client.get("/datasets/{0}/".format(uid)) + # response = self.client.get(self.api_prefix + "/datasets/{0}/".format(uid)) # self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_invalid_dataset(self): invalid_id = 9999 - response = self.client.get("/datasets/{0}/".format(invalid_id)) + response = self.client.get(self.api_prefix + "/datasets/{0}/".format(invalid_id)) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_duplicate_gen_uid(self): @@ -127,10 +128,10 @@ def test_duplicate_gen_uid(self): "data_preparation_mlcube": self.data_preproc_mlcube_id, } - response = self.client.post("/datasets/", testdataset, format="json") + response = self.client.post(self.api_prefix + "/datasets/", testdataset, format="json") self.assertEqual(response.status_code, status.HTTP_201_CREATED) - response = self.client.post("/datasets/", testdataset, format="json") + response = self.client.post(self.api_prefix + "/datasets/", testdataset, format="json") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) def test_optional_fields(self): diff --git a/server/dataset/urls.py b/server/dataset/urls.py index 8df29c459..9b8ce2601 100644 --- a/server/dataset/urls.py +++ b/server/dataset/urls.py @@ -2,6 +2,8 @@ from . import views from benchmarkdataset import views as bviews +app_name = "Dataset" + urlpatterns = [ path("", views.DatasetList.as_view()), path("/", views.DatasetDetail.as_view()), diff --git a/server/dataset/views.py b/server/dataset/views.py index af122dc13..82f207202 100644 --- a/server/dataset/views.py +++ b/server/dataset/views.py @@ -2,6 +2,7 @@ from rest_framework.generics import GenericAPIView from rest_framework.response import Response from rest_framework import status +from drf_spectacular.utils import extend_schema from .models import Dataset from .permissions import IsAdmin, IsDatasetOwner @@ -12,6 +13,7 @@ class DatasetList(GenericAPIView): serializer_class = DatasetSerializer queryset = "" + @extend_schema(operation_id="datasets_retrieve_all") def get(self, request, format=None): """ List all datasets diff --git a/server/medperf/settings.py b/server/medperf/settings.py index a3c9fd424..61f5ee874 100644 --- a/server/medperf/settings.py +++ b/server/medperf/settings.py @@ -92,7 +92,8 @@ "result", "rest_framework", "rest_framework.authtoken", - "drf_yasg", + "drf_spectacular", + "drf_spectacular_sidecar", "corsheaders", ] @@ -207,23 +208,47 @@ # https://docs.djangoproject.com/en/3.2/ref/settings/#default-auto-field DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField" +# Set this to supported api version. +# This will be the default version for unversioned apis picked by the swagger schema. +SERVER_API_VERSION = "v0" REST_FRAMEWORK = { - "DEFAULT_SCHEMA_CLASS": "rest_framework.schemas.coreapi.AutoSchema", + "DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema", "DEFAULT_AUTHENTICATION_CLASSES": [ "rest_framework.authentication.TokenAuthentication", ], "DEFAULT_PERMISSION_CLASSES": ["rest_framework.permissions.IsAuthenticated"], "DEFAULT_PAGINATION_CLASS": "rest_framework.pagination.LimitOffsetPagination", + "DEFAULT_VERSIONING_CLASS": "rest_framework.versioning.NamespaceVersioning", + "DEFAULT_PARSER_CLASSES": [ + "rest_framework.parsers.JSONParser", + ], + "DEFAULT_VERSION": SERVER_API_VERSION, "PAGE_SIZE": 32, } -SWAGGER_SETTINGS = { - "USE_SESSION_AUTH": False, - "SECURITY_DEFINITIONS": { - "api_key": {"type": "apiKey", "in": "header", "name": "Authorization"} +SPECTACULAR_SETTINGS = { + "SWAGGER_UI_SETTINGS": { + "deepLinking": True, + "displayRequestDuration": True, + "tryItOutEnabled": True, + "filter": True, + "syntaxHighlight.activate": True, + "syntaxHighlight.theme": "monokai", + # other swagger settings }, - "JSON_EDITOR": True, + "TITLE": "MedPerf API", + "DESCRIPTION": "MedPerf API description", + "VERSION": None, + "SERVE_INCLUDE_SCHEMA": True, + "PARSER_WHITELIST": [ + "rest_framework.parsers.JSONParser", + ], + 'SCHEMA_PATH_PREFIX': r'/api/v[0-9]', + "SWAGGER_UI_DIST": "SIDECAR", # shorthand to use the sidecar instead + "SWAGGER_UI_FAVICON_HREF": "SIDECAR", + "REDOC_DIST": "SIDECAR", + # other spectacular settings } # Setup support for proxy headers diff --git a/server/medperf/urls.py b/server/medperf/urls.py index cc175b638..93c0f6493 100644 --- a/server/medperf/urls.py +++ b/server/medperf/urls.py @@ -16,38 +16,27 @@ from django.contrib import admin from django.urls import include, re_path, path from rest_framework.authtoken.views import obtain_auth_token -from rest_framework import permissions -from drf_yasg.views import get_schema_view -from drf_yasg import openapi +from drf_spectacular.views import SpectacularAPIView, SpectacularSwaggerView, SpectacularRedocView +from django.conf import settings -schema_view = get_schema_view( - openapi.Info( - title="MedPerf API", - default_version="v1", - description="MedPerf API description", - ), - public=True, - permission_classes=(permissions.AllowAny,), -) +from utils.views import ServerAPIVersion + +API_VERSION = settings.SERVER_API_VERSION +API_PREFIX = 'api/' + API_VERSION + '/' urlpatterns = [ - re_path( - r"^swagger(?P\.json|\.yaml)$", - schema_view.without_ui(cache_timeout=0), - name="schema-json", - ), - re_path( - r"^swagger/$", - schema_view.with_ui("swagger", cache_timeout=0), - name="schema-swagger-ui", - ), - re_path(r"^$", schema_view.with_ui("redoc", cache_timeout=0), name="schema-redoc",), - path("admin/", admin.site.urls), - path("benchmarks/", include("benchmark.urls")), - path("mlcubes/", include("mlcube.urls")), - path("datasets/", include("dataset.urls")), - path("results/", include("result.urls")), - path("users/", include("user.urls")), - path("me/", include("utils.urls")), - path("auth-token/", obtain_auth_token, name="auth-token"), + path("schema/", SpectacularAPIView.as_view(api_version=API_VERSION), name="schema"), + path("swagger/", SpectacularSwaggerView.as_view(), name="swagger-ui"), + re_path(r"^$", SpectacularRedocView.as_view(), name="redoc"), + path("admin/", admin.site.urls, name="admin"), + path("version", ServerAPIVersion.as_view(), name="get-version"), + path(API_PREFIX, include([ + path("benchmarks/", include("benchmark.urls", namespace=API_VERSION), name="benchmark"), + path("mlcubes/", include("mlcube.urls", namespace=API_VERSION), name="mlcube"), + path("datasets/", include("dataset.urls", namespace=API_VERSION), name="dataset"), + path("results/", include("result.urls", namespace=API_VERSION), name="result"), + path("users/", include("user.urls", namespace=API_VERSION), name="users"), + path("me/", include("utils.urls", namespace=API_VERSION), name="me"), + path("auth-token/", obtain_auth_token, name="auth-token"), + ])), ] diff --git a/server/mlcube/tests.py b/server/mlcube/tests.py index af69f3308..20ab032fe 100644 --- a/server/mlcube/tests.py +++ b/server/mlcube/tests.py @@ -1,5 +1,6 @@ import string import random +from django.conf import settings from django.contrib.auth.models import User from rest_framework.test import APIClient from rest_framework import status @@ -12,14 +13,14 @@ class MlCubeTest(MedPerfTest): def setUp(self): super(MlCubeTest, self).setUp() - username = "mlcubeowner" password = "".join(random.choice(string.ascii_letters) for m in range(10)) user = User.objects.create_user(username=username, password=password,) user.save() + self.api_prefix = "/api/" + settings.SERVER_API_VERSION self.client = APIClient() response = self.client.post( - "/auth-token/", {"username": username, "password": password}, format="json", + self.api_prefix + "/auth-token/", {"username": username, "password": password}, format="json", ) self.assertEqual(response.status_code, status.HTTP_200_OK) self.token = response.data["token"] @@ -27,15 +28,15 @@ def setUp(self): def test_unauthenticated_user(self): client = APIClient() - response = client.get("/mlcubes/1/") + response = client.get(self.api_prefix + "/mlcubes/1/") self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - response = client.delete("/mlcubes/1/") + response = client.delete(self.api_prefix + "/mlcubes/1/") self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - response = client.put("/mlcubes/1/") + response = client.put(self.api_prefix + "/mlcubes/1/") self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - response = client.post("/mlcubes/", {}) + response = client.post(self.api_prefix + "/mlcubes/", {}) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - response = client.get("/mlcubes/") + response = client.get(self.api_prefix + "/mlcubes/") self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) def test_crud_user(self): @@ -52,17 +53,17 @@ def test_crud_user(self): "metadata": {"key": "value"}, } - response = self.client.post("/mlcubes/", testmlcube, format="json") + response = self.client.post(self.api_prefix + "/mlcubes/", testmlcube, format="json") self.assertEqual(response.status_code, status.HTTP_201_CREATED) uid = response.data["id"] - response = self.client.get("/mlcubes/{0}/".format(uid)) + response = self.client.get(self.api_prefix + "/mlcubes/{0}/".format(uid)) self.assertEqual(response.status_code, status.HTTP_200_OK) for k, v in response.data.items(): if k in testmlcube: self.assertEqual(testmlcube[k], v) - response = self.client.get("/mlcubes/") + response = self.client.get(self.api_prefix + "/mlcubes/") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data["results"]), 1) @@ -76,10 +77,10 @@ def test_crud_user(self): } response = self.client.put( - "/mlcubes/{0}/".format(uid), newmlcube, format="json" + self.api_prefix + "/mlcubes/{0}/".format(uid), newmlcube, format="json" ) self.assertEqual(response.status_code, status.HTTP_200_OK) - response = self.client.get("/mlcubes/{0}/".format(uid)) + response = self.client.get(self.api_prefix + "/mlcubes/{0}/".format(uid)) self.assertEqual(response.status_code, status.HTTP_200_OK) for k, v in response.data.items(): @@ -87,15 +88,15 @@ def test_crud_user(self): self.assertEqual(newmlcube[k], v) # TODO Revisit when delete permissions are fixed - # response = self.client.delete("/mlcubes/{0}/".format(uid)) + # response = self.client.delete(self.api_prefix + "/mlcubes/{0}/".format(uid)) # self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) - # response = self.client.get("/mlcubes/{0}/".format(uid)) + # response = self.client.get(self.api_prefix + "/mlcubes/{0}/".format(uid)) # self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_invalid_mlcube(self): invalid_id = 9999 - response = self.client.get("/mlcubes/{0}/".format(invalid_id)) + response = self.client.get(self.api_prefix + "/mlcubes/{0}/".format(invalid_id)) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_optional_fields(self): diff --git a/server/mlcube/urls.py b/server/mlcube/urls.py index 15341e4a3..50cfd99b0 100644 --- a/server/mlcube/urls.py +++ b/server/mlcube/urls.py @@ -2,6 +2,8 @@ from benchmarkmodel import views as bviews from . import views +app_name = "MLCube" + urlpatterns = [ path("", views.MlCubeList.as_view()), path("/", views.MlCubeDetail.as_view()), diff --git a/server/mlcube/views.py b/server/mlcube/views.py index b2e52c6cf..933cfc5c5 100644 --- a/server/mlcube/views.py +++ b/server/mlcube/views.py @@ -2,9 +2,10 @@ from rest_framework.generics import GenericAPIView from rest_framework.response import Response from rest_framework import status +from drf_spectacular.utils import extend_schema + from .models import MlCube from .serializers import MlCubeSerializer, MlCubeDetailSerializer - from .permissions import IsAdmin, IsMlCubeOwner @@ -12,6 +13,7 @@ class MlCubeList(GenericAPIView): serializer_class = MlCubeSerializer queryset = "" + @extend_schema(operation_id="mlcubes_retrieve_all") def get(self, request, format=None): """ List all mlcubes diff --git a/server/requirements.txt b/server/requirements.txt index b0e2ab711..a87d36107 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -1,6 +1,7 @@ Django==3.2.16 -djangorestframework==3.13.1 -drf-yasg==1.20.0 +djangorestframework==3.14.0 +drf-spectacular==0.25.1 +drf-spectacular-sidecar==2022.12.1 django-storages[google]==1.12.3 django-environ==0.8.1 django-cors-headers==3.10.1 @@ -12,4 +13,4 @@ pyOpenSSL==22.0.0 Werkzeug==2.0.2 django-extensions==3.1.5 #Temporary fix for pyOpenSSL. https://github.com/aws/aws-sam-cli/issues/4527#issuecomment-1369776818 -cryptography==38.0.4 \ No newline at end of file +cryptography==38.0.4 diff --git a/server/result/urls.py b/server/result/urls.py index aa8d2a78b..57f1cacfd 100644 --- a/server/result/urls.py +++ b/server/result/urls.py @@ -1,6 +1,7 @@ from django.urls import path from . import views +app_name = "Result" urlpatterns = [ path("", views.ModelResultList.as_view()), diff --git a/server/result/views.py b/server/result/views.py index bc1800331..b650045d8 100644 --- a/server/result/views.py +++ b/server/result/views.py @@ -2,6 +2,8 @@ from rest_framework.generics import GenericAPIView from rest_framework.response import Response from rest_framework import status +from drf_spectacular.utils import extend_schema + from .models import ModelResult from .serializers import ModelResultSerializer from .permissions import IsAdmin, IsBenchmarkOwner, IsDatasetOwner, IsResultOwner @@ -18,6 +20,7 @@ def get_permissions(self): self.permission_classes = [IsAdmin | IsDatasetOwner] return super(self.__class__, self).get_permissions() + @extend_schema(operation_id="results_retrieve_all") def get(self, request, format=None): """ List all results diff --git a/server/seed.py b/server/seed.py index 38fd96653..6571212ee 100644 --- a/server/seed.py +++ b/server/seed.py @@ -15,6 +15,29 @@ def __init__(self, host, cert): self.host = host self.cert = cert + def validate(self, verify=False, version=None): + try: + resp = requests.request( + method="GET", + url=self.host + "/version", + verify=self.cert, + ) + except requests.exceptions.RequestException as e: + raise SystemExit(e) + if resp.status_code != 200: + sys.exit("Response code is " + str(resp.status_code)) + + res = json.loads(resp.text) + if "version" not in res: + sys.exit("Version response is empty") + print("Server running at version " + res["version"]) + if verify: + if res["version"] != version: + sys.exit("Server version do not match with the client argument") + print("Server version match with client version") + self.version = res["version"] + return res["version"] + def request(self, endpoint, method, token, data, out_field=None): headers = {} if token: @@ -26,7 +49,7 @@ def request(self, endpoint, method, token, data, out_field=None): resp = requests.request( method=method, headers=headers, - url=self.host + endpoint, + url=self.host + "/api/" + self.version + endpoint, data=json.dumps(data), verify=self.cert, ) @@ -52,6 +75,10 @@ def request(self, endpoint, method, token, data, out_field=None): def seed(args): # Get Admin API token using admin credentials api_server = Server(host=args.server, cert=args.cert) + if args.version: + api_server.validate(True, args.version) + else: + api_server.validate(False) admin_token = api_server.request( "/auth-token/", "POST", @@ -383,5 +410,6 @@ def seed(args): parser.add_argument("--username", type=str, help="Admin username", default="admin") parser.add_argument("--password", type=str, help="Admin password", default="admin") parser.add_argument("--cert", type=str, help="Server certificate") + parser.add_argument("--version", type=str, help="Server version") args = parser.parse_args() seed(args) diff --git a/server/user/tests.py b/server/user/tests.py index c57da3876..5488b28dc 100644 --- a/server/user/tests.py +++ b/server/user/tests.py @@ -1,5 +1,6 @@ from rest_framework.test import APIClient from rest_framework import status +from django.conf import settings from medperf.tests import MedPerfTest @@ -9,12 +10,12 @@ class UserTest(MedPerfTest): def setUp(self): super(UserTest, self).setUp() - username = "admin" password = "admin" + self.api_prefix = "/api/" + settings.SERVER_API_VERSION self.client = APIClient() response = self.client.post( - "/auth-token/", {"username": username, "password": password}, format="json", + self.api_prefix + "/auth-token/", {"username": username, "password": password}, format="json", ) self.assertEqual(response.status_code, status.HTTP_200_OK) self.token = response.data["token"] @@ -22,15 +23,15 @@ def setUp(self): def test_unauthenticated_user(self): client = APIClient() - response = client.get("/users/1/") + response = client.get(self.api_prefix + "/users/1/") self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - response = client.delete("/users/1/") + response = client.delete(self.api_prefix + "/users/1/") self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - response = client.put("/users/1/") + response = client.put(self.api_prefix + "/users/1/") self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - response = client.post("/users/", {}) + response = client.post(self.api_prefix + "/users/", {}) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - response = client.get("/users/") + response = client.get(self.api_prefix + "/users/") self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) def test_crud_user(self): @@ -42,24 +43,24 @@ def test_crud_user(self): "last_name": "owner", } - response = self.client.post("/users/", testuser, format="json") + response = self.client.post(self.api_prefix + "/users/", testuser, format="json") self.assertEqual(response.status_code, status.HTTP_201_CREATED) uid = response.data["id"] - response = self.client.get("/users/{0}/".format(uid)) + response = self.client.get(self.api_prefix + "/users/{0}/".format(uid)) self.assertEqual(response.status_code, status.HTTP_200_OK) for k, v in response.data.items(): if k in testuser: self.assertEqual(testuser[k], v) - response = self.client.get("/users/") + response = self.client.get(self.api_prefix + "/users/") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data["results"]), 2) - response = self.client.delete("/users/{0}/".format(uid)) + response = self.client.delete(self.api_prefix + "/users/{0}/".format(uid)) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) - response = self.client.get("/users/{0}/".format(uid)) + response = self.client.get(self.api_prefix + "/users/{0}/".format(uid)) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_duplicate_usernames(self): @@ -71,10 +72,10 @@ def test_duplicate_usernames(self): "last_name": "owner", } - response = self.client.post("/users/", testuser, format="json") + response = self.client.post(self.api_prefix + "/users/", testuser, format="json") self.assertEqual(response.status_code, status.HTTP_201_CREATED) - response = self.client.post("/users/", testuser, format="json") + response = self.client.post(self.api_prefix + "/users/", testuser, format="json") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) def test_duplicate_emails(self): @@ -86,7 +87,7 @@ def test_duplicate_emails(self): "last_name": "owner", } - response = self.client.post("/users/", testuser, format="json") + response = self.client.post(self.api_prefix + "/users/", testuser, format="json") self.assertEqual(response.status_code, status.HTTP_201_CREATED) testuser = { @@ -97,12 +98,12 @@ def test_duplicate_emails(self): "last_name": "owner", } - response = self.client.post("/users/", testuser, format="json") + response = self.client.post(self.api_prefix + "/users/", testuser, format="json") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) def test_invalid_user(self): invalid_uid = 9999 - response = self.client.get("/users/{0}/".format(invalid_uid)) + response = self.client.get(self.api_prefix + "/users/{0}/".format(invalid_uid)) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_optional_fields(self): diff --git a/server/user/urls.py b/server/user/urls.py index b2670b53a..8c277fb3a 100644 --- a/server/user/urls.py +++ b/server/user/urls.py @@ -1,6 +1,8 @@ from django.urls import path from . import views +app_name = "User" + urlpatterns = [ path("", views.UserList.as_view()), path("/", views.UserDetail.as_view()), diff --git a/server/user/views.py b/server/user/views.py index 29168acf6..60f6fd1b5 100644 --- a/server/user/views.py +++ b/server/user/views.py @@ -3,6 +3,7 @@ from rest_framework.response import Response from rest_framework import status from django.contrib.auth.models import User +from drf_spectacular.utils import extend_schema from .serializers import UserSerializer from .permissions import IsAdmin, IsOwnUser @@ -13,6 +14,7 @@ class UserList(GenericAPIView): serializer_class = UserSerializer queryset = "" + @extend_schema(operation_id="users_retrieve_all") def get(self, request, format=None): """ List all users diff --git a/server/utils/urls.py b/server/utils/urls.py index 95b4901ed..f9eb68f2d 100644 --- a/server/utils/urls.py +++ b/server/utils/urls.py @@ -1,6 +1,7 @@ from django.urls import path from . import views +app_name = "MyUser" urlpatterns = [ path("", views.User.as_view()), diff --git a/server/utils/views.py b/server/utils/views.py index 351c86484..f1e3c815b 100644 --- a/server/utils/views.py +++ b/server/utils/views.py @@ -12,10 +12,14 @@ from benchmarkmodel.models import BenchmarkModel from benchmarkdataset.models import BenchmarkDataset from django.http import Http404 +from django.conf import settings from django.db.models import Q from rest_framework.generics import GenericAPIView from rest_framework.response import Response from rest_framework import status +from rest_framework.permissions import AllowAny +from drf_spectacular.utils import extend_schema, inline_serializer +from rest_framework import serializers class User(GenericAPIView): @@ -172,3 +176,25 @@ def get(self, request, format=None): benchmarkmodels = self.paginate_queryset(benchmarkmodels) serializer = BenchmarkModelListSerializer(benchmarkmodels, many=True) return self.get_paginated_response(serializer.data) + + +class ServerAPIVersion(GenericAPIView): + permission_classes = (AllowAny,) + queryset = "" + + @extend_schema( + responses={ + 200: inline_serializer( + name='VersionResponse', + fields={ + 'version': serializers.CharField(), + } + ) + } + ) + def get(self, request=None, format=None): + """ + Retrieve version of Server API + """ + result = {"version": settings.SERVER_API_VERSION} + return Response(result)