diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index b347335..e84526c 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -26,7 +26,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -e .[testing] + pip install -e .[test] - name: Install SMART sandbox run: | git clone https://github.com/smart-on-fhir/smart-launcher-v2.git @@ -34,8 +34,6 @@ jobs: git switch -c aa0f3b1 # Fix the version we use for the sandbox npm ci npm run build - env: - PORT: 5555 - name: Run tests run: | pytest tests/ diff --git a/jupyter_smart_on_fhir/server_extension.py b/jupyter_smart_on_fhir/server_extension.py index 5e30e8e..320473a 100644 --- a/jupyter_smart_on_fhir/server_extension.py +++ b/jupyter_smart_on_fhir/server_extension.py @@ -1,12 +1,14 @@ import base64 import hashlib +import json import secrets from urllib.parse import urlencode, urljoin -import requests import tornado from jupyter_server.base.handlers import JupyterHandler from jupyter_server.extension.application import ExtensionApp +from jupyter_server.utils import url_path_join +from tornado.httpclient import AsyncHTTPClient from traitlets import List, Unicode from jupyter_smart_on_fhir.auth import SMARTConfig, generate_state @@ -54,7 +56,7 @@ class SMARTAuthHandler(JupyterHandler): """Handler for SMART on FHIR authentication""" @tornado.web.authenticated - def get(self): + async def get(self): fhir_url = self.get_argument("iss") smart_config = SMARTConfig.from_url(fhir_url, self.request.full_url()) self.settings["launch"] = self.get_argument("launch") @@ -64,11 +66,11 @@ def get(self): self.settings["next_url"] = self.request.uri self.redirect(login_path) else: - data = self.get_data(token) + data = await self.get_data(token) self.write(f"Authorization success: Fetched {str(data)}") self.finish() - def get_data(self, token: str) -> dict: + async def get_data(self, token: str) -> dict: headers = { "Authorization": f"Bearer {token}", "Accept": "application/fhir+json", @@ -77,11 +79,8 @@ def get_data(self, token: str) -> dict: url = ( f"{self.settings['smart_config'].fhir_url}/Condition" # Endpoint with data ) - f = requests.get(url, headers=headers) - try: - return f.json() - except requests.exceptions.JSONDecodeError: - raise RuntimeError(f.text) + resp = await AsyncHTTPClient().fetch(url, headers=headers) + return json.loads(resp.body.decode("utf8", "replace")) class SMARTLoginHandler(JupyterHandler): @@ -105,7 +104,9 @@ def get(self): "aud": smart_config.fhir_url, "state": state["state_id"], "launch": self.settings["launch"], - "redirect_uri": urljoin(self.request.full_url(), callback_path), + "redirect_uri": urljoin( + self.request.full_url(), url_path_join(self.base_url, callback_path) + ), "client_id": self.settings["client_id"], "code_challenge": code_challenge, "code_challenge_method": "S256", @@ -118,22 +119,27 @@ def get(self): class SMARTCallbackHandler(JupyterHandler): """Callback handler for SMART on FHIR""" - def token_for_code(self, code: str) -> str: + async def token_for_code(self, code: str) -> str: data = dict( client_id=self.settings["client_id"], grant_type="authorization_code", code=code, - code_verifier=self.get_signed_cookie("code_verifier"), - redirect_uri=urljoin(self.request.full_url(), callback_path), + code_verifier=self.get_signed_cookie("code_verifier").decode("ascii"), + redirect_uri=urljoin( + self.request.full_url(), url_path_join(self.base_url, callback_path) + ), ) headers = {"Content-Type": "application/x-www-form-urlencoded"} - token_reply = requests.post( - self.settings["smart_config"].token_url, data=data, headers=headers + token_reply = await AsyncHTTPClient().fetch( + self.settings["smart_config"].token_url, + body=urlencode(data), + headers=headers, + method="POST", ) - return token_reply.json()["access_token"] + return json.loads(token_reply.body.decode("utf8", "replace"))["access_token"] @tornado.web.authenticated - def get(self): + async def get(self): if "error" in self.request.arguments: raise tornado.web.HTTPError(400, self.get_argument("error")) code = self.get_argument("code") @@ -141,12 +147,18 @@ def get(self): raise tornado.web.HTTPError( 400, "Error: no code in response from FHIR server" ) - state_id = self.get_signed_cookie("state_id").decode("utf-8") - if self.get_argument("state") != state_id: + state_id = self.get_signed_cookie("state_id", b"").decode("utf-8") + if not state_id: + raise tornado.web.HTTPError(400, "Error: missing state cookie") + arg_state = self.get_argument("state") + if not arg_state: + raise tornado.web.HTTPError(400, "Error: missing state argument") + if arg_state != state_id: raise tornado.web.HTTPError( 400, "Error: state received from FHIR server does not match" ) - self.settings["smart_token"] = self.token_for_code(code) + self.settings["smart_token"] = await self.token_for_code(code) + next_url = self.settings["next_url"] self.redirect(self.settings["next_url"]) diff --git a/pyproject.toml b/pyproject.toml index 896b337..b1d0b7e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "jupyter_smart_on_fhir" -version = "0.1.0" +version = "0.1.0.dev" dependencies = [ "flask", "tornado", @@ -16,7 +16,7 @@ dependencies = [ ] [project.optional-dependencies] -testing = ["pytest"] +test = ["pytest", "pytest-jupyter[server]"] [tool.setuptools] diff --git a/tests/conftest.py b/tests/conftest.py index 84ccb29..6af8121 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,8 @@ import pytest import requests +pytest_plugins = ["pytest_jupyter.jupyter_server"] + @pytest.fixture(scope="function") # module? def sandbox(): @@ -16,7 +18,7 @@ def sandbox(): os.environ["PORT"] = str(port) url = f"http://localhost:{port}" with subprocess.Popen( - ["npm", "run", "start:prod"], cwd=os.environ.get("SANDBOX_DIR", ".") + ["npm", "run", "start:prod"], cwd=os.environ["SANDBOX_DIR"] ) as sandbox_proc: wait_for_server(url) yield url diff --git a/tests/test_server_extension.py b/tests/test_server_extension.py index 7f36d98..c8be97a 100644 --- a/tests/test_server_extension.py +++ b/tests/test_server_extension.py @@ -1,66 +1,40 @@ -import os -import subprocess +import json +from http.cookies import SimpleCookie +from urllib.parse import parse_qsl, urlparse, urlunparse import pytest -import requests -from conftest import SandboxConfig, wait_for_server +from conftest import SandboxConfig +from jupyter_server.utils import url_path_join +from tornado.httpclient import AsyncHTTPClient, HTTPClientError +from traitlets.config import Config from jupyter_smart_on_fhir.server_extension import callback_path, login_path, smart_path -PORT = os.getenv("TEST_PORT", 18888) -ext_url = f"http://localhost:{PORT}" - - -def request_api(url, session=None, params=None, **kwargs): - query_args = {"token": "secret"} - query_args.update(params or {}) - session = session or requests.Session() - return session.get(url, params=query_args, **kwargs) - @pytest.fixture -def jupyterdir(tmpdir): - path = tmpdir.join("jupyter") - path.mkdir() - return str(path) +def client_id(): + return "client_id" @pytest.fixture -def jupyter_server(tmpdir, jupyterdir): - client_id = os.environ["CLIENT_ID"] = "client_id" - env = os.environ.copy() - # avoid interacting with user configuration, state - env["JUPYTER_CONFIG_DIR"] = str(tmpdir / "dotjupyter") - env["JUPYTER_RUNTIME_DIR"] = str(tmpdir / "runjupyter") - - extension_command = ["jupyter", "server", "extension"] - command = [ - "jupyter-server", - "--ServerApp.token=secret", - f"--SMARTExtensionApp.client_id={client_id}", - f"--port={PORT}", - ] - subprocess.check_call( - extension_command + ["enable", "jupyter_smart_on_fhir.server_extension"], - env=env, - ) +def jp_server_config(client_id): + c = Config() + c.ServerApp.jpserver_extensions = {"jupyter_smart_on_fhir.server_extension": True} + c.SMARTExtensionApp.client_id = client_id - # launch the server - with subprocess.Popen(command, cwd=jupyterdir, env=env) as jupyter_proc: - wait_for_server(ext_url) - yield jupyter_proc - jupyter_proc.terminate() + return c -def test_uninformed_endpoint(jupyter_server): - response = request_api(ext_url + smart_path) - assert response.status_code == 400 +async def test_uninformed_endpoint(jp_fetch): + with pytest.raises(HTTPClientError) as e: + await jp_fetch(smart_path) + assert e.value.code == 400 -@pytest.fixture(scope="function") -def public_client(): +@pytest.fixture +def public_client(client_id): return SandboxConfig( - client_id=os.environ["CLIENT_ID"], + client_id=client_id, client_type=0, pkce_validation=2, # setting IDs so we omit login screen in sandbox; unsure I would test that flow @@ -69,31 +43,75 @@ def public_client(): ) -def test_login_handler(jupyter_server, sandbox, public_client): - """I think this test can be splitted in three with some engineering. Perhaps useful, not sure""" - session = requests.Session() +async def test_login_handler( + http_server_client, jp_base_url, jp_fetch, jp_serverapp, sandbox, public_client +): + """I think this test can be split in three with some engineering. Perhaps useful, not sure""" # Try endpoint and get redirected to login query = {"iss": f"{sandbox}/v/r4/fhir", "launch": public_client.get_launch_code()} - response = request_api( - ext_url + smart_path, params=query, allow_redirects=False, session=session - ) - assert response.status_code == 302 + with pytest.raises(HTTPClientError) as exc_info: + response = await jp_fetch( + smart_path, + params=query, + follow_redirects=False, + ) + response = exc_info.value.response + assert response.code == 302 assert response.headers["Location"] == login_path # Login with headers and get redirected to auth url - response = request_api(ext_url + login_path, session=session, allow_redirects=False) - assert response.status_code == 302 + with pytest.raises(HTTPClientError) as exc_info: + response = await jp_fetch(login_path, follow_redirects=False) + response = exc_info.value.response + assert response.code == 302 auth_url = response.headers["Location"] assert auth_url.startswith(sandbox) + cookie = SimpleCookie() + for c in response.headers.get_list("Set-Cookie"): + cookie.load(c) # Internally, get redirected to provider-auth - response = request_api(auth_url, session=session, allow_redirects=False) - assert response.status_code == 302 + with pytest.raises(HTTPClientError) as exc_info: + http_client = AsyncHTTPClient() + response = await http_client.fetch(auth_url, follow_redirects=False) + response = exc_info.value.response + assert response.code == 302 callback_url = response.headers["Location"] - assert callback_url.startswith(ext_url + callback_path) - assert "code=" in callback_url - response = request_api(callback_url, session=session) - assert response.status_code == 200 - assert response.url.startswith(ext_url + smart_path) - - # TODO: Should I test token existence? And how? + callback_url_parsed = urlparse(callback_url) + # strip proto://host for jp_fetch + server_callback_url = urlunparse(callback_url_parsed._replace(netloc="", scheme="")) + params = dict(parse_qsl(callback_url_parsed.query)) + # SMART does different URL escaping + # SMART dev server appears to do some weird unescaping with callback URL + server_callback_url = server_callback_url.replace("@", "%40") + assert server_callback_url.startswith(url_path_join(jp_base_url, callback_path)) + assert "code" in params + + cookie_header = "; ".join( + f"{morsel.key}={morsel.coded_value}" for morsel in cookie.values() + ) + with pytest.raises(HTTPClientError) as exc_info: + await jp_fetch( + callback_path, + params=params, + headers={"Cookie": cookie_header}, + follow_redirects=False, + ) + response = exc_info.value.response + assert response.code == 302 + dest_url = response.headers["Location"] + + # TODO: test dest_url? + assert urlparse(dest_url).path.startswith(url_path_join(jp_base_url, smart_path)) + + # verify that token was issued and works + assert "smart_token" in jp_serverapp.web_app.settings + token = jp_serverapp.web_app.settings["smart_token"] + smart_config = jp_serverapp.web_app.settings["smart_config"] + url = url_path_join(smart_config.fhir_url, "Condition") + resp = await http_client.fetch(url, headers={"Authorization": f"Bearer {token}"}) + data = json.loads(resp.body.decode("utf8")) + assert data + assert isinstance(data, dict) + assert "resourceType" in data + assert data["resourceType"] == "Bundle"