Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement nodes/{node_id}/download endpoint #74

Merged
merged 1 commit into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions aiida_restapi/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@
'disabled': False,
}
}

# The chunks size for streaming data for download
DOWNLOAD_CHUNK_SIZE = 1024
50 changes: 48 additions & 2 deletions aiida_restapi/routers/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@
import os
import tempfile
from pathlib import Path
from typing import Any, List, Optional
from typing import Any, Generator, List, Optional

from aiida import orm
from aiida.cmdline.utils.decorators import with_dbenv
from aiida.common.exceptions import EntryPointError
from aiida.common.exceptions import EntryPointError, LicensingException, NotExistent
from aiida.plugins.entry_point import load_entry_point
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from pydantic import ValidationError

from aiida_restapi import models, resources
from aiida_restapi.config import DOWNLOAD_CHUNK_SIZE

from .auth import get_current_active_user

Expand Down Expand Up @@ -41,6 +43,50 @@ async def get_nodes_download_formats() -> dict[str, Any]:
return resources.get_all_download_formats()


@router.get('/nodes/{nodes_id}/download')
@with_dbenv()
async def download_node(nodes_id: int, download_format: Optional[str] = None) -> StreamingResponse:
"""Get nodes by id."""
from aiida.orm import load_node

try:
node = load_node(nodes_id)
except NotExistent:
raise HTTPException(status_code=404, detail=f'Could no find any node with id {nodes_id}')

if download_format is None:
raise HTTPException(
status_code=422,
detail='Please specify the download format. '
'The available download formats can be '
'queried using the /nodes/download_formats/ endpoint.',
)

elif download_format in node.get_export_formats():
# byteobj, dict with {filename: filecontent}
import io

try:
exported_bytes, _ = node._exportcontent(download_format)
except LicensingException as exc:
raise HTTPException(status_code=500, detail=str(exc))

def stream() -> Generator[bytes, None, None]:
with io.BytesIO(exported_bytes) as handler:
while chunk := handler.read(DOWNLOAD_CHUNK_SIZE):
yield chunk

return StreamingResponse(stream(), media_type=f'application/{download_format}')

else:
raise HTTPException(
status_code=422,
detail='The format {} is not supported. '
'The available download formats can be '
'queried using the /nodes/download_formats/ endpoint.'.format(download_format),
)


@router.get('/nodes/{nodes_id}', response_model=models.Node)
@with_dbenv()
async def read_node(nodes_id: int) -> Optional[models.Node]:
Expand Down
4 changes: 4 additions & 0 deletions docs/source/user_guide/graphql.md
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,10 @@ http://localhost:5000/api/v4/nodes/ffe11/repo/list
```html
http://localhost:5000/api/v4/nodes/ffe11/repo/contents?filename="aiida.in"
```


Not implemented for GraphQL, please use the REST API for this use case.

```html
http://localhost:5000/api/v4/nodes/fafdsf/download?download_format=xsf
```
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ testing = [
'pytest-regressions',
'pytest-cov',
'requests',
'httpx'
'httpx',
'numpy~=1.21'
]

[project.urls]
Expand Down
12 changes: 12 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from datetime import datetime
from typing import Any, Callable, Mapping, MutableMapping, Optional, Union

import numpy as np
import pytest
import pytz
from aiida import orm
Expand Down Expand Up @@ -164,6 +165,17 @@ def default_nodes():
return [node_1.pk, node_2.pk, node_3.pk, node_4.pk]


@pytest.fixture(scope='function')
def array_data_node():
"""Populate database with downloadable node (implmenting a _prepare_* function).
For testing the chunking of the streaming we create an array that needs to be splitted int two chunks."""

from aiida_restapi.config import DOWNLOAD_CHUNK_SIZE

nb_elements = DOWNLOAD_CHUNK_SIZE // 64 + 1
return orm.ArrayData(np.arange(nb_elements, dtype=np.int64)).store()


@pytest.fixture(scope='function')
def authenticate():
"""Authenticate user.
Expand Down
21 changes: 21 additions & 0 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,3 +330,24 @@ def test_create_bool_with_extra(client, authenticate): # pylint: disable=unused
assert check_response.status_code == 200, response.content
assert check_response.json()['extras']['extra_one'] == 'value_1'
assert check_response.json()['extras']['extra_two'] == 'value_2'


@pytest.mark.anyio
async def test_get_download_node(array_data_node, async_client):
"""Test download node /nodes/{nodes_id}/download.
The async client is needed to avoid an error caused by an I/O operation on closed file"""

# Test that array is correctly downloaded as json
response = await async_client.get(f'/nodes/{array_data_node.pk}/download?download_format=json')
assert response.status_code == 200, response.json()
assert response.json().get('default', None) == array_data_node.get_array().tolist()

# Test exception when wrong download format given
response = await async_client.get(f'/nodes/{array_data_node.pk}/download?download_format=cif')
assert response.status_code == 422, response.json()
assert 'format cif is not supported' in response.json()['detail']

# Test exception when no download format given
response = await async_client.get(f'/nodes/{array_data_node.pk}/download')
assert response.status_code == 422, response.json()
assert 'Please specify the download format' in response.json()['detail']
Loading