Skip to content

Commit

Permalink
Merge branch 'async-requests'
Browse files Browse the repository at this point in the history
  • Loading branch information
JR-1991 committed Apr 24, 2024
2 parents 4511df2 + 6f0e89d commit 072acb7
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 56 deletions.
1 change: 1 addition & 0 deletions requirements/common.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
httpx>=0.26.0
jsonschema>=3.2.0
urllib3>=2.2.1
1 change: 1 addition & 0 deletions requirements/tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
-r common.txt
pytest
pytest-cov
pytest-asyncio
tox
selenium==3.141.0
236 changes: 180 additions & 56 deletions src/pyDataverse/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Dataverse API wrapper for all it's API's."""

import json
from typing import Any, Dict, Optional
import httpx
import subprocess as sp
from urllib.parse import urljoin
Expand All @@ -26,7 +27,7 @@ class Api:
Base URL of Dataverse instance. Without trailing `/` at the end.
e.g. `http://demo.dataverse.org`
api_token : str
Authentication token for the api.
Authenication token for the api.

Check failure on line 30 in src/pyDataverse/api.py

View workflow job for this annotation

GitHub Actions / Check for spelling errors

Authenication ==> Authentication
Attributes
----------
Expand All @@ -37,7 +38,10 @@ class Api:
"""

def __init__(
self, base_url: str, api_token: str = None, api_version: str = "latest"
self,
base_url: str,
api_token: Optional[str] = None,
api_version: str = "latest",
):
"""Init an Api() class.
Expand All @@ -64,6 +68,7 @@ def __init__(
raise ApiUrlError("base_url {0} is not a string.".format(base_url))

self.base_url = base_url
self.client = None

if not isinstance(api_version, ("".__class__, "".__class__)):
raise ApiUrlError("api_version {0} is not a string.".format(api_version))
Expand Down Expand Up @@ -120,28 +125,17 @@ def get_request(self, url, params=None, auth=False):
if self.api_token:
params["key"] = str(self.api_token)

try:
url = urljoin(self.base_url_api, url)
resp = httpx.get(url, params=params, follow_redirects=True)
if resp.status_code == 401:
error_msg = resp.json()["message"]
raise ApiAuthorizationError(
"ERROR: GET - Authorization invalid {0}. MSG: {1}.".format(
url, error_msg
)
)
elif resp.status_code >= 300:
if resp.text:
error_msg = resp.text
raise OperationFailedError(
"ERROR: GET HTTP {0} - {1}. MSG: {2}".format(
resp.status_code, url, error_msg
)
)
return resp
except ConnectError:
raise ConnectError(
"ERROR: GET - Could not establish connection to api {0}.".format(url)
if self.client is None:
return self._sync_request(
method=httpx.get,
url=url,
params=params,
)
else:
return self._async_request(
method=self.client.get,
url=url,
params=params,
)

def post_request(self, url, data=None, auth=False, params=None, files=None):
Expand Down Expand Up @@ -174,19 +168,21 @@ def post_request(self, url, data=None, auth=False, params=None, files=None):
if self.api_token:
params["key"] = self.api_token

try:
resp = httpx.post(url, data=data, params=params, files=files, follow_redirects=True)
if resp.status_code == 401:
error_msg = resp.json()["message"]
raise ApiAuthorizationError(
"ERROR: POST HTTP 401 - Authorization error {0}. MSG: {1}".format(
url, error_msg
)
)
return resp
except ConnectError:
raise ConnectError(
"ERROR: POST - Could not establish connection to API: {0}".format(url)
if self.client is None:
return self._sync_request(
method=httpx.post,
url=url,
data=data,
params=params,
files=files,
)
else:
return self._async_request(
method=self.client.post,
url=url,
data=data,
params=params,
files=files,
)

def put_request(self, url, data=None, auth=False, params=None):
Expand Down Expand Up @@ -215,19 +211,19 @@ def put_request(self, url, data=None, auth=False, params=None):
if self.api_token:
params["key"] = self.api_token

try:
resp = httpx.put(url, data=data, params=params, follow_redirects=True)
if resp.status_code == 401:
error_msg = resp.json()["message"]
raise ApiAuthorizationError(
"ERROR: PUT HTTP 401 - Authorization error {0}. MSG: {1}".format(
url, error_msg
)
)
return resp
except ConnectError:
raise ConnectError(
"ERROR: PUT - Could not establish connection to api '{0}'.".format(url)
if self.client is None:
return self._sync_request(
method=httpx.put,
url=url,
data=data,
params=params,
)
else:
return self._async_request(
method=self.client.put,
url=url,
data=data,
params=params,
)

def delete_request(self, url, auth=False, params=None):
Expand All @@ -254,13 +250,141 @@ def delete_request(self, url, auth=False, params=None):
if self.api_token:
params["key"] = self.api_token

if self.client is None:
return self._sync_request(
method=httpx.delete,
url=url,
params=params,
)
else:
return self._async_request(
method=self.client.delete,
url=url,
params=params,
)

def _sync_request(
self,
method,
**kwargs,
):
"""
Sends a synchronous request to the specified URL using the specified HTTP method.
Args:
method (function): The HTTP method to use for the request.
**kwargs: Additional keyword arguments to be passed to the method.
Returns:
requests.Response: The response object returned by the request.
Raises:
ApiAuthorizationError: If the response status code is 401 (Authorization error).
ConnectError: If a connection to the API cannot be established.
"""
assert "url" in kwargs, "URL is required for a request."

kwargs = self._filter_kwargs(kwargs)

try:
resp = method(**kwargs)

if resp.status_code == 401:
error_msg = resp.json()["message"]
raise ApiAuthorizationError(
"ERROR: HTTP 401 - Authorization error {0}. MSG: {1}".format(
kwargs["url"], error_msg
)
)

return resp

except ConnectError:
raise ConnectError(
"ERROR - Could not establish connection to api '{0}'.".format(
kwargs["url"]
)
)

async def _async_request(
self,
method,
**kwargs,
):
"""
Sends an asynchronous request to the specified URL using the specified HTTP method.
Args:
method (callable): The HTTP method to use for the request.
**kwargs: Additional keyword arguments to be passed to the method.
Raises:
ApiAuthorizationError: If the response status code is 401 (Authorization error).
ConnectError: If a connection to the API cannot be established.
Returns:
The response object.
"""
assert "url" in kwargs, "URL is required for a request."

kwargs = self._filter_kwargs(kwargs)

try:
return httpx.delete(url, params=params, follow_redirects=True)
resp = await method(**kwargs)

if resp.status_code == 401:
error_msg = resp.json()["message"]
raise ApiAuthorizationError(
"ERROR: HTTP 401 - Authorization error {0}. MSG: {1}".format(
kwargs["url"], error_msg
)
)

return resp

except ConnectError:
raise ConnectError(
"ERROR: DELETE could not establish connection to api {}.".format(url)
"ERROR - Could not establish connection to api '{0}'.".format(
kwargs["url"]
)
)

@staticmethod
def _filter_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""
Filters out any keyword arguments that are `None` from the specified dictionary.
Args:
kwargs (Dict[str, Any]): The dictionary to filter.
Returns:
Dict[str, Any]: The filtered dictionary.
"""
return {k: v for k, v in kwargs.items() if v is not None}

async def __aenter__(self):
"""
Context manager method that initializes an instance of httpx.AsyncClient.
Returns:
httpx.AsyncClient: An instance of httpx.AsyncClient.
"""
self.client = httpx.AsyncClient()

async def __aexit__(self, exc_type, exc_value, traceback):
"""
Closes the client connection when exiting a context manager.
Args:
exc_type (type): The type of the exception raised, if any.
exc_value (Exception): The exception raised, if any.
traceback (traceback): The traceback object associated with the exception, if any.
"""

await self.client.aclose()
self.client = None


class DataAccessApi(Api):
"""Class to access Dataverse's Data Access API.
Expand Down Expand Up @@ -338,13 +462,13 @@ def get_datafile(
"""
is_first_param = True
if is_pid:
url = "{0}/datafile/:persistentId/?persistentId={1}".format(
self.base_url_api_data_access, identifier
)
else:
url = "{0}/datafile/{1}".format(self.base_url_api_data_access, identifier)
if data_format or no_var_header or image_thumb:
url += "?"
else:
url = "{0}/datafile/:persistentId/?persistentId={1}".format(
self.base_url_api_data_access, identifier
)
if data_format:
url += "format={0}".format(data_format)
is_first_param = False
Expand Down
16 changes: 16 additions & 0 deletions tests/api/test_async_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import asyncio
import pytest


class TestAsyncAPI:

@pytest.mark.asyncio
async def test_async_api(self, native_api):

async with native_api:
tasks = [native_api.get_info_version() for _ in range(10)]
responses = await asyncio.gather(*tasks)

assert len(responses) == 10
for response in responses:
assert response.status_code == 200, "Request failed."

0 comments on commit 072acb7

Please sign in to comment.