Skip to content

Commit

Permalink
adopt pytest-jupyter in test_server_extension
Browse files Browse the repository at this point in the history
and trade sync requests for async tornado client
  • Loading branch information
minrk committed Sep 20, 2024
1 parent 0a698d0 commit b0e8fad
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 91 deletions.
4 changes: 1 addition & 3 deletions .github/workflows/run-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,14 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .[testing]
pip install -e .[test]
- name: Install SMART sandbox
run: |
git clone https://github.com/smart-on-fhir/smart-launcher-v2.git
cd smart-launcher-v2
git switch -c aa0f3b1 # Fix the version we use for the sandbox
npm ci
npm run build
env:
PORT: 5555
- name: Run tests
run: |
pytest tests/
Expand Down
52 changes: 32 additions & 20 deletions jupyter_smart_on_fhir/server_extension.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import base64
import hashlib
import json
import secrets
from urllib.parse import urlencode, urljoin

import requests
import tornado
from jupyter_server.base.handlers import JupyterHandler
from jupyter_server.extension.application import ExtensionApp
from jupyter_server.utils import url_path_join
from tornado.httpclient import AsyncHTTPClient
from traitlets import List, Unicode

from jupyter_smart_on_fhir.auth import SMARTConfig, generate_state
Expand Down Expand Up @@ -54,7 +56,7 @@ class SMARTAuthHandler(JupyterHandler):
"""Handler for SMART on FHIR authentication"""

@tornado.web.authenticated
def get(self):
async def get(self):
fhir_url = self.get_argument("iss")
smart_config = SMARTConfig.from_url(fhir_url, self.request.full_url())
self.settings["launch"] = self.get_argument("launch")
Expand All @@ -64,11 +66,11 @@ def get(self):
self.settings["next_url"] = self.request.uri
self.redirect(login_path)
else:
data = self.get_data(token)
data = await self.get_data(token)
self.write(f"Authorization success: Fetched {str(data)}")
self.finish()

def get_data(self, token: str) -> dict:
async def get_data(self, token: str) -> dict:
headers = {
"Authorization": f"Bearer {token}",
"Accept": "application/fhir+json",
Expand All @@ -77,11 +79,8 @@ def get_data(self, token: str) -> dict:
url = (
f"{self.settings['smart_config'].fhir_url}/Condition" # Endpoint with data
)
f = requests.get(url, headers=headers)
try:
return f.json()
except requests.exceptions.JSONDecodeError:
raise RuntimeError(f.text)
resp = await AsyncHTTPClient().fetch(url, headers=headers)
return json.loads(resp.body.decode("utf8", "replace"))


class SMARTLoginHandler(JupyterHandler):
Expand All @@ -105,7 +104,9 @@ def get(self):
"aud": smart_config.fhir_url,
"state": state["state_id"],
"launch": self.settings["launch"],
"redirect_uri": urljoin(self.request.full_url(), callback_path),
"redirect_uri": urljoin(
self.request.full_url(), url_path_join(self.base_url, callback_path)
),
"client_id": self.settings["client_id"],
"code_challenge": code_challenge,
"code_challenge_method": "S256",
Expand All @@ -118,35 +119,46 @@ def get(self):
class SMARTCallbackHandler(JupyterHandler):
"""Callback handler for SMART on FHIR"""

def token_for_code(self, code: str) -> str:
async def token_for_code(self, code: str) -> str:
data = dict(
client_id=self.settings["client_id"],
grant_type="authorization_code",
code=code,
code_verifier=self.get_signed_cookie("code_verifier"),
redirect_uri=urljoin(self.request.full_url(), callback_path),
code_verifier=self.get_signed_cookie("code_verifier").decode("ascii"),
redirect_uri=urljoin(
self.request.full_url(), url_path_join(self.base_url, callback_path)
),
)
headers = {"Content-Type": "application/x-www-form-urlencoded"}
token_reply = requests.post(
self.settings["smart_config"].token_url, data=data, headers=headers
token_reply = await AsyncHTTPClient().fetch(
self.settings["smart_config"].token_url,
body=urlencode(data),
headers=headers,
method="POST",
)
return token_reply.json()["access_token"]
return json.loads(token_reply.body.decode("utf8", "replace"))["access_token"]

@tornado.web.authenticated
def get(self):
async def get(self):
if "error" in self.request.arguments:
raise tornado.web.HTTPError(400, self.get_argument("error"))
code = self.get_argument("code")
if not code:
raise tornado.web.HTTPError(
400, "Error: no code in response from FHIR server"
)
state_id = self.get_signed_cookie("state_id").decode("utf-8")
if self.get_argument("state") != state_id:
state_id = self.get_signed_cookie("state_id", b"").decode("utf-8")
if not state_id:
raise tornado.web.HTTPError(400, "Error: missing state cookie")
arg_state = self.get_argument("state")
if not arg_state:
raise tornado.web.HTTPError(400, "Error: missing state argument")
if arg_state != state_id:
raise tornado.web.HTTPError(
400, "Error: state received from FHIR server does not match"
)
self.settings["smart_token"] = self.token_for_code(code)
self.settings["smart_token"] = await self.token_for_code(code)
next_url = self.settings["next_url"]
self.redirect(self.settings["next_url"])


Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "jupyter_smart_on_fhir"
version = "0.1.0"
version = "0.1.0.dev"
dependencies = [
"flask",
"tornado",
Expand All @@ -16,7 +16,7 @@ dependencies = [
]

[project.optional-dependencies]
testing = ["pytest"]
test = ["pytest", "pytest-jupyter[server]"]


[tool.setuptools]
Expand Down
4 changes: 3 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@
import pytest
import requests

pytest_plugins = ["pytest_jupyter.jupyter_server"]


@pytest.fixture(scope="function") # module?
def sandbox():
port = 5555
os.environ["PORT"] = str(port)
url = f"http://localhost:{port}"
with subprocess.Popen(
["npm", "run", "start:prod"], cwd=os.environ.get("SANDBOX_DIR", ".")
["npm", "run", "start:prod"], cwd=os.environ["SANDBOX_DIR"]
) as sandbox_proc:
wait_for_server(url)
yield url
Expand Down
148 changes: 83 additions & 65 deletions tests/test_server_extension.py
Original file line number Diff line number Diff line change
@@ -1,66 +1,40 @@
import os
import subprocess
import json
from http.cookies import SimpleCookie
from urllib.parse import parse_qsl, urlparse, urlunparse

import pytest
import requests
from conftest import SandboxConfig, wait_for_server
from conftest import SandboxConfig
from jupyter_server.utils import url_path_join
from tornado.httpclient import AsyncHTTPClient, HTTPClientError
from traitlets.config import Config

from jupyter_smart_on_fhir.server_extension import callback_path, login_path, smart_path

PORT = os.getenv("TEST_PORT", 18888)
ext_url = f"http://localhost:{PORT}"


def request_api(url, session=None, params=None, **kwargs):
query_args = {"token": "secret"}
query_args.update(params or {})
session = session or requests.Session()
return session.get(url, params=query_args, **kwargs)


@pytest.fixture
def jupyterdir(tmpdir):
path = tmpdir.join("jupyter")
path.mkdir()
return str(path)
def client_id():
return "client_id"


@pytest.fixture
def jupyter_server(tmpdir, jupyterdir):
client_id = os.environ["CLIENT_ID"] = "client_id"
env = os.environ.copy()
# avoid interacting with user configuration, state
env["JUPYTER_CONFIG_DIR"] = str(tmpdir / "dotjupyter")
env["JUPYTER_RUNTIME_DIR"] = str(tmpdir / "runjupyter")

extension_command = ["jupyter", "server", "extension"]
command = [
"jupyter-server",
"--ServerApp.token=secret",
f"--SMARTExtensionApp.client_id={client_id}",
f"--port={PORT}",
]
subprocess.check_call(
extension_command + ["enable", "jupyter_smart_on_fhir.server_extension"],
env=env,
)
def jp_server_config(client_id):
c = Config()
c.ServerApp.jpserver_extensions = {"jupyter_smart_on_fhir.server_extension": True}
c.SMARTExtensionApp.client_id = client_id

# launch the server
with subprocess.Popen(command, cwd=jupyterdir, env=env) as jupyter_proc:
wait_for_server(ext_url)
yield jupyter_proc
jupyter_proc.terminate()
return c


def test_uninformed_endpoint(jupyter_server):
response = request_api(ext_url + smart_path)
assert response.status_code == 400
async def test_uninformed_endpoint(jp_fetch):
with pytest.raises(HTTPClientError) as e:
await jp_fetch(smart_path)
assert e.value.code == 400


@pytest.fixture(scope="function")
def public_client():
@pytest.fixture
def public_client(client_id):
return SandboxConfig(
client_id=os.environ["CLIENT_ID"],
client_id=client_id,
client_type=0,
pkce_validation=2,
# setting IDs so we omit login screen in sandbox; unsure I would test that flow
Expand All @@ -69,31 +43,75 @@ def public_client():
)


def test_login_handler(jupyter_server, sandbox, public_client):
"""I think this test can be splitted in three with some engineering. Perhaps useful, not sure"""
session = requests.Session()
async def test_login_handler(
http_server_client, jp_base_url, jp_fetch, jp_serverapp, sandbox, public_client
):
"""I think this test can be split in three with some engineering. Perhaps useful, not sure"""
# Try endpoint and get redirected to login
query = {"iss": f"{sandbox}/v/r4/fhir", "launch": public_client.get_launch_code()}
response = request_api(
ext_url + smart_path, params=query, allow_redirects=False, session=session
)
assert response.status_code == 302
with pytest.raises(HTTPClientError) as exc_info:
response = await jp_fetch(
smart_path,
params=query,
follow_redirects=False,
)
response = exc_info.value.response
assert response.code == 302
assert response.headers["Location"] == login_path

# Login with headers and get redirected to auth url
response = request_api(ext_url + login_path, session=session, allow_redirects=False)
assert response.status_code == 302
with pytest.raises(HTTPClientError) as exc_info:
response = await jp_fetch(login_path, follow_redirects=False)
response = exc_info.value.response
assert response.code == 302
auth_url = response.headers["Location"]
assert auth_url.startswith(sandbox)
cookie = SimpleCookie()
for c in response.headers.get_list("Set-Cookie"):
cookie.load(c)

# Internally, get redirected to provider-auth
response = request_api(auth_url, session=session, allow_redirects=False)
assert response.status_code == 302
with pytest.raises(HTTPClientError) as exc_info:
http_client = AsyncHTTPClient()
response = await http_client.fetch(auth_url, follow_redirects=False)
response = exc_info.value.response
assert response.code == 302
callback_url = response.headers["Location"]
assert callback_url.startswith(ext_url + callback_path)
assert "code=" in callback_url
response = request_api(callback_url, session=session)
assert response.status_code == 200
assert response.url.startswith(ext_url + smart_path)

# TODO: Should I test token existence? And how?
callback_url_parsed = urlparse(callback_url)
# strip proto://host for jp_fetch
server_callback_url = urlunparse(callback_url_parsed._replace(netloc="", scheme=""))
params = dict(parse_qsl(callback_url_parsed.query))
# SMART does different URL escaping
# SMART dev server appears to do some weird unescaping with callback URL
server_callback_url = server_callback_url.replace("@", "%40")
assert server_callback_url.startswith(url_path_join(jp_base_url, callback_path))
assert "code" in params

cookie_header = "; ".join(
f"{morsel.key}={morsel.coded_value}" for morsel in cookie.values()
)
with pytest.raises(HTTPClientError) as exc_info:
await jp_fetch(
callback_path,
params=params,
headers={"Cookie": cookie_header},
follow_redirects=False,
)
response = exc_info.value.response
assert response.code == 302
dest_url = response.headers["Location"]

# TODO: test dest_url?
assert urlparse(dest_url).path.startswith(url_path_join(jp_base_url, smart_path))

# verify that token was issued and works
assert "smart_token" in jp_serverapp.web_app.settings
token = jp_serverapp.web_app.settings["smart_token"]
smart_config = jp_serverapp.web_app.settings["smart_config"]
url = url_path_join(smart_config.fhir_url, "Condition")
resp = await http_client.fetch(url, headers={"Authorization": f"Bearer {token}"})
data = json.loads(resp.body.decode("utf8"))
assert data
assert isinstance(data, dict)
assert "resourceType" in data
assert data["resourceType"] == "Bundle"

0 comments on commit b0e8fad

Please sign in to comment.