Skip to content

Commit

Permalink
feat: add ability to sort table columns
Browse files Browse the repository at this point in the history
This commit adds table sorting functionality for all resources to the UI. It adds 'sortBy' and
'descending' query parameters to each of the base GET endpoints and updates the services to handle
changes to the query. It adds unit tests to test the new sorting functionality.
  • Loading branch information
henrychoy authored and keithmanville committed Sep 13, 2024
1 parent 6bfbff1 commit 76d5646
Show file tree
Hide file tree
Showing 66 changed files with 1,364 additions and 129 deletions.
6 changes: 6 additions & 0 deletions src/dioptra/restapi/v1/artifacts/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,16 @@ def get(self):
search_string = unquote(parsed_query_params["search"])
page_index = parsed_query_params["index"]
page_length = parsed_query_params["page_length"]
sort_by_string = parsed_query_params["sort_by"]
descending = parsed_query_params["descending"]

artifacts, total_num_artifacts = self._artifact_service.get(
group_id=group_id,
search_string=search_string,
page_index=page_index,
page_length=page_length,
sort_by_string=sort_by_string,
descending=descending,
log=log,
)
return utils.build_paging_envelope(
Expand All @@ -101,6 +105,8 @@ def get(self):
index=page_index,
length=page_length,
total_num_elements=total_num_artifacts,
sort_by=sort_by_string,
descending=descending,
)

@login_required
Expand Down
11 changes: 11 additions & 0 deletions src/dioptra/restapi/v1/artifacts/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ class ArtifactDoesNotExistError(Exception):
"""The requested artifact does not exist."""


class ArtifactSortError(Exception):
"""The requested sortBy column is not a sortable field."""


def register_error_handlers(api: Api) -> None:
@api.errorhandler(ArtifactDoesNotExistError)
def handle_artifact_does_not_exist_error(error):
Expand All @@ -42,3 +46,10 @@ def handle_artifact_already_exists_error(error):
},
400,
)

@api.errorhandler(ArtifactSortError)
def handle_queue_sort_error(error):
return (
{"message": "Bad Request - This column can not be sorted."},
400,
)
2 changes: 2 additions & 0 deletions src/dioptra/restapi/v1/artifacts/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
GroupIdQueryParametersSchema,
PagingQueryParametersSchema,
SearchQueryParametersSchema,
SortByGetQueryParametersSchema,
generate_base_resource_ref_schema,
generate_base_resource_schema,
)
Expand Down Expand Up @@ -84,5 +85,6 @@ class ArtifactGetQueryParameters(
PagingQueryParametersSchema,
GroupIdQueryParametersSchema,
SearchQueryParametersSchema,
SortByGetQueryParametersSchema,
):
"""The query parameters for the GET method of the /artifacts endpoint."""
28 changes: 27 additions & 1 deletion src/dioptra/restapi/v1/artifacts/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@
from dioptra.restapi.v1.jobs.service import ExperimentJobIdService, JobIdService
from dioptra.restapi.v1.shared.search_parser import construct_sql_query_filters

from .errors import ArtifactAlreadyExistsError, ArtifactDoesNotExistError
from .errors import (
ArtifactAlreadyExistsError,
ArtifactDoesNotExistError,
ArtifactSortError,
)

LOGGER: BoundLogger = structlog.stdlib.get_logger()

Expand All @@ -41,6 +45,12 @@
"uri": lambda x: models.Artifact.uri.like(x, escape="/"),
"description": lambda x: models.Artifact.description.like(x, escape="/"),
}
SORTABLE_FIELDS: Final[dict[str, Any]] = {
"uri": models.Artifact.uri,
"createdOn": models.Artifact.created_on,
"lastModifiedOn": models.Resource.last_modified_on,
"description": models.Artifact.description,
}


class ArtifactService(object):
Expand Down Expand Up @@ -128,6 +138,8 @@ def get(
search_string: str,
page_index: int,
page_length: int,
sort_by_string: str,
descending: bool,
**kwargs,
) -> Any:
"""Fetch a list of artifacts, optionally filtering by search string and paging
Expand All @@ -138,6 +150,8 @@ def get(
search_string: A search string used to filter results.
page_index: The index of the first group to be returned.
page_length: The maximum number of artifacts to be returned.
sort_by_string: The name of the column to sort.
descending: Boolean indicating whether to sort by descending or not.
Returns:
A tuple containing a list of artifacts and the total number of artifacts
Expand Down Expand Up @@ -197,6 +211,18 @@ def get(
.offset(page_index)
.limit(page_length)
)

if sort_by_string and sort_by_string in SORTABLE_FIELDS:
sort_column = SORTABLE_FIELDS[sort_by_string]
if descending:
sort_column = sort_column.desc()
else:
sort_column = sort_column.asc()
latest_artifacts_stmt = latest_artifacts_stmt.order_by(sort_column)
elif sort_by_string and sort_by_string not in SORTABLE_FIELDS:
log.debug(f"sort_by_string: '{sort_by_string}' is not in SORTABLE_FIELDS")
raise ArtifactSortError

artifacts = db.session.scalars(latest_artifacts_stmt).all()

drafts_stmt = select(
Expand Down
6 changes: 6 additions & 0 deletions src/dioptra/restapi/v1/entrypoints/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,16 @@ def get(self):
search_string = parsed_query_params["search"]
page_index = parsed_query_params["index"]
page_length = parsed_query_params["page_length"]
sort_by_string = parsed_query_params["sort_by"]
descending = parsed_query_params["descending"]

entrypoints, total_num_entrypoints = self._entrypoint_service.get(
group_id=group_id,
search_string=search_string,
page_index=page_index,
page_length=page_length,
sort_by_string=sort_by_string,
descending=descending,
log=log,
)
return utils.build_paging_envelope(
Expand All @@ -118,6 +122,8 @@ def get(self):
index=page_index,
length=page_length,
total_num_elements=total_num_entrypoints,
sort_by=sort_by_string,
descending=descending,
)

@login_required
Expand Down
11 changes: 11 additions & 0 deletions src/dioptra/restapi/v1/entrypoints/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ class EntrypointParameterNamesNotUniqueError(Exception):
"""Multiple entrypoint parameters share the same name."""


class EntrypointSortError(Exception):
"""The requested sortBy column is not a sortable field."""


def register_error_handlers(api: Api) -> None:
@api.errorhandler(EntrypointDoesNotExistError)
def handle_entrypoint_does_not_exist_error(error):
Expand Down Expand Up @@ -64,3 +68,10 @@ def handle_entrypoint_parameter_names_not_unique_error(error):
"message": "Bad Request - The entrypoint contains multiple parameters "
"with the same name."
}, 400

@api.errorhandler(EntrypointSortError)
def handle_queue_sort_error(error):
return (
{"message": "Bad Request - This column can not be sorted."},
400,
)
2 changes: 2 additions & 0 deletions src/dioptra/restapi/v1/entrypoints/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
GroupIdQueryParametersSchema,
PagingQueryParametersSchema,
SearchQueryParametersSchema,
SortByGetQueryParametersSchema,
generate_base_resource_ref_schema,
generate_base_resource_schema,
)
Expand Down Expand Up @@ -227,5 +228,6 @@ class EntrypointGetQueryParameters(
PagingQueryParametersSchema,
GroupIdQueryParametersSchema,
SearchQueryParametersSchema,
SortByGetQueryParametersSchema,
):
"""The query parameters for the GET method of the /entrypoints endpoint."""
23 changes: 23 additions & 0 deletions src/dioptra/restapi/v1/entrypoints/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
EntrypointDoesNotExistError,
EntrypointParameterNamesNotUniqueError,
EntrypointPluginDoesNotExistError,
EntrypointSortError,
)

LOGGER: BoundLogger = structlog.stdlib.get_logger()
Expand All @@ -50,6 +51,12 @@
"description": lambda x: models.EntryPoint.description.like(x, escape="/"),
"task_graph": lambda x: models.EntryPoint.task_graph.like(x, escape="/"),
}
SORTABLE_FIELDS: Final[dict[str, Any]] = {
"name": models.EntryPoint.name,
"createdOn": models.EntryPoint.created_on,
"lastModifiedOn": models.Resource.last_modified_on,
"description": models.EntryPoint.description,
}


class EntrypointService(object):
Expand Down Expand Up @@ -175,6 +182,8 @@ def get(
search_string: str,
page_index: int,
page_length: int,
sort_by_string: str,
descending: bool,
**kwargs,
) -> tuple[list[utils.EntrypointDict], int]:
"""Fetch a list of entrypoints, optionally filtering by search string and paging
Expand All @@ -185,6 +194,8 @@ def get(
search_string: A search string used to filter results.
page_index: The index of the first group to be returned.
page_length: The maximum number of entrypoints to be returned.
sort_by_string: The name of the column to sort.
descending: Boolean indicating whether to sort by descending or not.
Returns:
A tuple containing a list of entrypoints and the total number of entrypoints
Expand Down Expand Up @@ -243,6 +254,18 @@ def get(
.offset(page_index)
.limit(page_length)
)

if sort_by_string and sort_by_string in SORTABLE_FIELDS:
sort_column = SORTABLE_FIELDS[sort_by_string]
if descending:
sort_column = sort_column.desc()
else:
sort_column = sort_column.asc()
entrypoints_stmt = entrypoints_stmt.order_by(sort_column)
elif sort_by_string and sort_by_string not in SORTABLE_FIELDS:
log.debug(f"sort_by_string: '{sort_by_string}' is not in SORTABLE_FIELDS")
raise EntrypointSortError

entrypoints = list(db.session.scalars(entrypoints_stmt).unique().all())

queue_ids = set(
Expand Down
12 changes: 12 additions & 0 deletions src/dioptra/restapi/v1/experiments/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,16 @@ def get(self):
search_string = parsed_query_params["search"]
page_index = parsed_query_params["index"]
page_length = parsed_query_params["page_length"]
sort_by_string = parsed_query_params["sort_by"]
descending = parsed_query_params["descending"]

experiments, total_num_experiments = self._experiment_service.get(
group_id=group_id,
search_string=search_string,
page_index=page_index,
page_length=page_length,
sort_by_string=sort_by_string,
descending=descending,
log=log,
)
return utils.build_paging_envelope(
Expand All @@ -130,6 +134,8 @@ def get(self):
index=page_index,
length=page_length,
total_num_elements=total_num_experiments,
sort_by=sort_by_string,
descending=descending,
)

@login_required
Expand Down Expand Up @@ -260,12 +266,16 @@ def get(self, id: int):
search_string = unquote(parsed_query_params["search"])
page_index = parsed_query_params["index"]
page_length = parsed_query_params["page_length"]
sort_by_string = unquote(parsed_query_params["sort_by"])
descending = parsed_query_params["descending"]

jobs, total_num_jobs = self._experiment_job_service.get(
experiment_id=id,
search_string=search_string,
page_index=page_index,
page_length=page_length,
sort_by_string=sort_by_string,
descending=descending,
log=log,
)
return utils.build_paging_envelope(
Expand All @@ -278,6 +288,8 @@ def get(self, id: int):
index=page_index,
length=page_length,
total_num_elements=total_num_jobs,
sort_by=sort_by_string,
descending=descending,
)

@login_required
Expand Down
11 changes: 11 additions & 0 deletions src/dioptra/restapi/v1/experiments/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ class ExperimentDoesNotExistError(Exception):
"""The requested experiment does not exist."""


class ExperimentSortError(Exception):
"""The requested sortBy column is not a sortable field."""


def register_error_handlers(api: Api) -> None:
@api.errorhandler(ExperimentAlreadyExistsError)
def handle_experiment_already_exists_error(error):
Expand All @@ -40,3 +44,10 @@ def handle_experiment_already_exists_error(error):
@api.errorhandler(ExperimentDoesNotExistError)
def handle_experiment_does_not_exist_error(error):
return {"message": "Not Found - The requested experiment does not exist"}, 404

@api.errorhandler(ExperimentSortError)
def handle_queue_sort_error(error):
return (
{"message": "Bad Request - This column can not be sorted."},
400,
)
2 changes: 2 additions & 0 deletions src/dioptra/restapi/v1/experiments/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
GroupIdQueryParametersSchema,
PagingQueryParametersSchema,
SearchQueryParametersSchema,
SortByGetQueryParametersSchema,
generate_base_resource_ref_schema,
generate_base_resource_schema,
)
Expand Down Expand Up @@ -120,5 +121,6 @@ class ExperimentGetQueryParameters(
PagingQueryParametersSchema,
GroupIdQueryParametersSchema,
SearchQueryParametersSchema,
SortByGetQueryParametersSchema,
):
"""The query parameters for the GET method of the /experiments endpoint."""
28 changes: 27 additions & 1 deletion src/dioptra/restapi/v1/experiments/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@
from dioptra.restapi.v1.groups.service import GroupIdService
from dioptra.restapi.v1.shared.search_parser import construct_sql_query_filters

from .errors import ExperimentAlreadyExistsError, ExperimentDoesNotExistError
from .errors import (
ExperimentAlreadyExistsError,
ExperimentDoesNotExistError,
ExperimentSortError,
)

LOGGER: BoundLogger = structlog.stdlib.get_logger()

Expand All @@ -44,6 +48,12 @@
"description": lambda x: models.Experiment.description.like(x, escape="/"),
"tag": lambda x: models.Experiment.tags.any(models.Tag.name.like(x, escape="/")),
}
SORTABLE_FIELDS: Final[dict[str, Any]] = {
"name": models.Experiment.name,
"createdOn": models.Experiment.created_on,
"lastModifiedOn": models.Resource.last_modified_on,
"description": models.Experiment.description,
}


class ExperimentService(object):
Expand Down Expand Up @@ -147,6 +157,8 @@ def get(
search_string: str,
page_index: int,
page_length: int,
sort_by_string: str,
descending: bool,
**kwargs,
) -> tuple[list[utils.ExperimentDict], int]:
"""Fetch a list of experiments, optionally filtering by search string and paging
Expand All @@ -157,6 +169,8 @@ def get(
search_string: A search string used to filter results.
page_index: The index of the first page to be returned.
page_length: The maximum number of experiments to be returned.
sort_by_string: The name of the column to sort.
descending: Boolean indicating whether to sort by descending or not.
Returns:
A tuple containing a list of experiments and the total number of experiments
Expand Down Expand Up @@ -214,6 +228,18 @@ def get(
.offset(page_index)
.limit(page_length)
)

if sort_by_string and sort_by_string in SORTABLE_FIELDS:
sort_column = SORTABLE_FIELDS[sort_by_string]
if descending:
sort_column = sort_column.desc()
else:
sort_column = sort_column.asc()
experiments_stmt = experiments_stmt.order_by(sort_column)
elif sort_by_string and sort_by_string not in SORTABLE_FIELDS:
log.debug(f"sort_by_string: '{sort_by_string}' is not in SORTABLE_FIELDS")
raise ExperimentSortError

experiments = list(db.session.scalars(experiments_stmt).all())

entrypoint_ids = {
Expand Down
Loading

0 comments on commit 76d5646

Please sign in to comment.