Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix jsonData not passed correctly #203

Merged
merged 7 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,6 @@ poetry.lock

# Ruff
.ruff_cache/

# JetBrains
.idea/
52 changes: 41 additions & 11 deletions pyDataverse/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,21 +170,20 @@ def post_request(self, url, data=None, auth=False, params=None, files=None):
if isinstance(data, str):
data = json.loads(data)

# Decide whether to use 'data' or 'json' args
request_params = self._check_json_data_form(data)

if self.client is None:
return self._sync_request(
method=httpx.post,
url=url,
json=data,
params=params,
files=files,
method=httpx.post, url=url, params=params, files=files, **request_params
)
else:
return self._async_request(
method=self.client.post,
url=url,
json=data,
params=params,
files=files,
**request_params,
)

def put_request(self, url, data=None, auth=False, params=None):
Expand Down Expand Up @@ -216,19 +215,22 @@ def put_request(self, url, data=None, auth=False, params=None):
if isinstance(data, str):
data = json.loads(data)

# Decide whether to use 'data' or 'json' args
request_params = self._check_json_data_form(data)
JR-1991 marked this conversation as resolved.
Show resolved Hide resolved

if self.client is None:
return self._sync_request(
method=httpx.put,
url=url,
json=data,
params=params,
**request_params,
)
else:
return self._async_request(
method=self.client.put,
url=url,
json=data,
params=params,
**request_params,
)

def delete_request(self, url, auth=False, params=None):
Expand Down Expand Up @@ -268,6 +270,33 @@ def delete_request(self, url, auth=False, params=None):
params=params,
)

@staticmethod
def _check_json_data_form(data: Optional[Dict]):
JR-1991 marked this conversation as resolved.
Show resolved Hide resolved
"""This method checks and distributes given payload to match Dataverse expectations.

In the case of the form-data keyed by "jsonData", Dataverse expects
the payload as a string in a form of a dictionary. This is not possible
using HTTPXs json parameter, so we need to handle this case separately.
"""

if not data:
return {}
elif not isinstance(data, dict):
raise ValueError("Data must be a dictionary.")
elif "jsonData" not in data:
return {"json": data}

assert list(data.keys()) == [
"jsonData"
], "jsonData must be the only key in the dictionary."

# Content of JSON data should ideally be a string
content = data["jsonData"]
if not isinstance(content, str):
data["jsonData"] = json.dumps(content)

return {"data": data}

def _sync_request(
self,
method,
Expand Down Expand Up @@ -1813,9 +1842,10 @@ def upload_datafile(self, identifier, filename, json_str=None, is_pid=True):
url += "/datasets/{0}/add".format(identifier)

files = {"file": open(filename, "rb")}
return self.post_request(
url, data={"jsonData": json_str}, files=files, auth=True
)
metadata = {}
if json_str is not None:
metadata["jsonData"] = json_str
return self.post_request(url, data=metadata, files=files, auth=True)

def update_datafile_metadata(self, identifier, json_str=None, is_filepid=False):
"""Update datafile metadata.
Expand Down
2 changes: 0 additions & 2 deletions tests/api/test_async_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@


class TestAsyncAPI:

@pytest.mark.asyncio
async def test_async_api(self, native_api):

async with native_api:
tasks = [native_api.get_info_version() for _ in range(10)]
responses = await asyncio.gather(*tasks)
Expand Down
175 changes: 109 additions & 66 deletions tests/api/test_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import httpx

from pyDataverse.api import NativeApi
from pyDataverse.api import DataAccessApi, NativeApi
from pyDataverse.models import Datafile


Expand Down Expand Up @@ -45,6 +45,41 @@ def test_file_upload(self):
# Assert
assert response.status_code == 200, "File upload failed."

def test_file_upload_without_metadata(self):
    """
    Test case for uploading a file to a dataset without metadata.

    --> json_str will be set as None

    This test case performs the following steps:
    1. Creates a dataset using the provided metadata.
    2. Prepares a file for upload.
    3. Uploads the file to the dataset.
    4. Asserts that the file upload was successful.

    Raises:
        AssertionError: If the file upload fails.

    """
    # Arrange
    BASE_URL = os.getenv("BASE_URL").rstrip("/")
    API_TOKEN = os.getenv("API_TOKEN")

    # Create dataset; the context manager closes the fixture file instead
    # of leaving the handle open until garbage collection.
    with open("tests/data/file_upload_ds_minimum.json", encoding="utf-8") as f:
        metadata = json.load(f)
    pid = self._create_dataset(BASE_URL, API_TOKEN, metadata)
    api = NativeApi(BASE_URL, API_TOKEN)

    # Act
    response = api.upload_datafile(
        identifier=pid,
        filename="tests/data/datafile.txt",
        json_str=None,
    )

    # Assert
    assert response.status_code == 200, "File upload failed."
JR-1991 marked this conversation as resolved.
Show resolved Hide resolved

def test_bulk_file_upload(self, create_mock_file):
"""
Test case for uploading bulk files to a dataset.
Expand Down Expand Up @@ -97,9 +132,66 @@ def test_bulk_file_upload(self, create_mock_file):
# Assert
assert response.status_code == 200, "File upload failed."

def test_file_replacement(self):
def test_file_replacement_wo_metadata(self):
    """
    Test case for replacing a file in a dataset without metadata.

    Steps:
    1. Create a dataset using the provided metadata.
    2. Upload a datafile to the dataset.
    3. Replace the uploaded datafile with a mutated version.
    4. Verify that the file replacement was successful and the content matches the expected content.
    """

    # Arrange
    BASE_URL = os.getenv("BASE_URL").rstrip("/")
    API_TOKEN = os.getenv("API_TOKEN")

    # Create dataset; the context manager closes the fixture file promptly.
    with open("tests/data/file_upload_ds_minimum.json", encoding="utf-8") as f:
        metadata = json.load(f)
    pid = self._create_dataset(BASE_URL, API_TOKEN, metadata)
    api = NativeApi(BASE_URL, API_TOKEN)
    data_api = DataAccessApi(BASE_URL, API_TOKEN)

    # Perform file upload
    df = Datafile({"pid": pid, "filename": "datafile.txt"})
    response = api.upload_datafile(
        identifier=pid,
        filename="tests/data/replace.xyz",
        json_str=df.json(),
    )

    # Retrieve file ID
    file_id = response.json()["data"]["files"][0]["dataFile"]["id"]

    # Act
    with tempfile.TemporaryDirectory() as tempdir:
        with open("tests/data/replace.xyz") as f:
            original = f.read()
        # Change the first character so the replacement differs from the upload.
        mutated = "Z" + original[1:]
        mutated_path = os.path.join(tempdir, "replace.xyz")

        with open(mutated_path, "w") as f:
            f.write(mutated)

        # Empty metadata: the replacement carries no jsonData attributes.
        json_data = {}

        response = api.replace_datafile(
            identifier=file_id,
            filename=mutated_path,
            json_str=json.dumps(json_data),
            is_filepid=False,
        )

    # Assert
    # Check the status first so a failed request is reported as such rather
    # than as a KeyError while indexing the response body.
    assert response.status_code == 200, "File replacement failed."

    file_id = response.json()["data"]["files"][0]["dataFile"]["id"]
    content = data_api.get_datafile(file_id, is_pid=False).text

    assert content == mutated, "File content does not match the expected content."

def test_file_replacement_w_metadata(self):
"""
Test case for replacing a file in a dataset.
Test case for replacing a file in a dataset with metadata.

Steps:
1. Create a dataset using the provided metadata.
Expand All @@ -116,6 +208,7 @@ def test_file_replacement(self):
metadata = json.load(open("tests/data/file_upload_ds_minimum.json"))
pid = self._create_dataset(BASE_URL, API_TOKEN, metadata)
api = NativeApi(BASE_URL, API_TOKEN)
data_api = DataAccessApi(BASE_URL, API_TOKEN)

# Perform file upload
df = Datafile({"pid": pid, "filename": "datafile.txt"})
Expand All @@ -126,7 +219,7 @@ def test_file_replacement(self):
)

# Retrieve file ID
file_id = self._get_file_id(BASE_URL, API_TOKEN, pid)
file_id = response.json()["data"]["files"][0]["dataFile"]["id"]

# Act
with tempfile.TemporaryDirectory() as tempdir:
Expand All @@ -141,6 +234,7 @@ def test_file_replacement(self):
"description": "My description.",
"categories": ["Data"],
"forceReplace": False,
"directoryLabel": "some/other",
JR-1991 marked this conversation as resolved.
Show resolved Hide resolved
}

response = api.replace_datafile(
Expand All @@ -151,17 +245,19 @@ def test_file_replacement(self):
)

# Assert
replaced_id = self._get_file_id(BASE_URL, API_TOKEN, pid)
replaced_content = self._fetch_datafile_content(
BASE_URL,
API_TOKEN,
replaced_id,
)
file_id = response.json()["data"]["files"][0]["dataFile"]["id"]
data_file = api.get_dataset(pid).json()["data"]["latestVersion"]["files"][0]
content = data_api.get_datafile(file_id, is_pid=False).text

assert response.status_code == 200, "File replacement failed."
assert (
replaced_content == mutated
), "File content does not match the expected content."
data_file["description"] == "My description."
), "Description does not match."
assert data_file["categories"] == ["Data"], "Categories do not match."
assert (
data_file["directoryLabel"] == "some/other"
), "Directory label does not match."
assert response.status_code == 200, "File replacement failed."
assert content == mutated, "File content does not match the expected content."

@staticmethod
def _create_dataset(
Expand Down Expand Up @@ -193,56 +289,3 @@ def _create_dataset(
response.raise_for_status()

return response.json()["data"]["persistentId"]

@staticmethod
def _get_file_id(
    BASE_URL: str,
    API_TOKEN: str,
    pid: str,
):
    """
    Look up a datafile ID through the dataset's persistent identifier (PID).

    Args:
        BASE_URL (str): The base URL of the Dataverse instance.
        API_TOKEN (str): The API token sent in the "X-Dataverse-key" header.
        pid (str): The persistent identifier (PID) of the dataset.

    Returns:
        The ID of the first file listed in the dataset's latest version.

    Raises:
        HTTPError: If the HTTP request to retrieve the dataset fails
            (raised by ``response.raise_for_status()``).
    """
    dataset_url = f"{BASE_URL}/api/datasets/:persistentId/?persistentId={pid}"
    auth_headers = {"X-Dataverse-key": API_TOKEN}

    response = httpx.get(url=dataset_url, headers=auth_headers)
    response.raise_for_status()

    # Walk the response payload step by step down to the first file entry.
    latest_version = response.json()["data"]["latestVersion"]
    first_file = latest_version["files"][0]
    return first_file["dataFile"]["id"]

@staticmethod
def _fetch_datafile_content(
    BASE_URL: str,
    API_TOKEN: str,
    id: str,
):
    """
    Download a datafile through the access API and return it as text.

    Args:
        BASE_URL (str): The base URL of the Dataverse instance.
        API_TOKEN (str): The API token sent in the "X-Dataverse-key" header.
        id (str): The ID of the datafile.

    Returns:
        str: The response body decoded as UTF-8.
    """
    response = httpx.get(
        f"{BASE_URL}/api/access/datafile/{id}",
        headers={"X-Dataverse-key": API_TOKEN},
    )
    # Surface HTTP error statuses before touching the body.
    response.raise_for_status()

    raw_bytes = response.content
    return raw_bytes.decode("utf-8")
20 changes: 16 additions & 4 deletions tests/models/test_datafile.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Datafile data model tests."""

import json
import jsonschema
import os
Expand Down Expand Up @@ -244,7 +245,10 @@ def test_datafile_from_json_valid(self):
),
(({json_upload_min()}, {"validate": False}), object_data_min()),
(
({json_upload_min()}, {"filename_schema": "wrong", "validate": False},),
(
{json_upload_min()},
{"filename_schema": "wrong", "validate": False},
),
object_data_min(),
),
(
Expand Down Expand Up @@ -345,7 +349,10 @@ def test_datafile_to_json_valid(self):
json.loads(json_upload_min()),
),
(
(dict_flat_set_min(), {"filename_schema": "wrong", "validate": False},),
(
dict_flat_set_min(),
{"filename_schema": "wrong", "validate": False},
),
json.loads(json_upload_min()),
),
(
Expand Down Expand Up @@ -517,7 +524,10 @@ def test_dataverse_from_json_to_json_valid(self):
({json_upload_full()}, {}),
({json_upload_min()}, {"data_format": "dataverse_upload"}),
({json_upload_min()}, {"validate": False}),
({json_upload_min()}, {"filename_schema": "wrong", "validate": False},),
(
{json_upload_min()},
{"filename_schema": "wrong", "validate": False},
),
(
{json_upload_min()},
{
Expand Down Expand Up @@ -550,4 +560,6 @@ def test_dataverse_from_json_to_json_valid(self):

for key, val in pdv_end.get().items():
assert getattr(pdv_start, key) == getattr(pdv_end, key)
assert len(pdv_start.__dict__) == len(pdv_end.__dict__,)
assert len(pdv_start.__dict__) == len(
pdv_end.__dict__,
)
Loading
Loading