Skip to content

Commit

Permalink
feat: 🎸 change format of /rows to fix the features order
Browse files Browse the repository at this point in the history
Note that the datasets library does not
officially support a way to specify the column order, but relying on
info.features gives dataset maintainers the ability to change this
order if they want to.

In practice, a Python dict preserves the order of its items, while a JSON
object does not guarantee key order: we have to pass the features in an
array to ensure the order is preserved.

BREAKING CHANGE: 🧨 format change
  • Loading branch information
severo committed Oct 4, 2021
1 parent fe6ec9d commit bd06c94
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 31 deletions.
40 changes: 31 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -461,33 +461,55 @@ Parameters:

Responses:

- `200`: JSON content that provides the types of the columns (see https://huggingface.co/docs/datasets/about_dataset_features.html) and the data rows, with the following structure:
- `200`: JSON content that provides the types of the columns (see features at https://huggingface.co/docs/datasets/about_dataset_features.html) and the data rows, with the following structure. Note that the features are ordered, and this order can be used, for example, to display the columns in a table.

```json
{
"features": [
{
"dataset": "glue",
"config": "ax",
"features": {
"premise": {
"feature": {
"name": "premise",
"content": {
"dtype": "string",
"id": null,
"_type": "Value"
},
"hypothesis": {
}
}
},
{
"dataset": "glue",
"config": "ax",
"feature": {
"name": "hypothesis",
"content": {
"dtype": "string",
"id": null,
"_type": "Value"
},
"label": {
}
}
},
{
"dataset": "glue",
"config": "ax",
"feature": {
"name": "label",
"content": {
"num_classes": 3,
"names": ["entailment", "neutral", "contradiction"],
"names_file": null,
"id": null,
"_type": "ClassLabel"
},
"idx": {
}
}
},
{
"dataset": "glue",
"config": "ax",
"feature": {
"name": "idx",
"content": {
"dtype": "int32",
"id": null,
"_type": "Value"
Expand Down
20 changes: 10 additions & 10 deletions src/datasets_preview_backend/queries/rows.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from datasets_preview_backend.responses import CachedResponse
from datasets_preview_backend.types import (
ConfigsContent,
FeaturesItem,
FeatureItem,
InfosContent,
RowItem,
RowsContent,
Expand Down Expand Up @@ -45,7 +45,7 @@ def get_rows(
num_rows = EXTRACT_ROWS_LIMIT

rowItems: List[RowItem] = []
featuresItems: List[FeaturesItem] = []
featureItems: List[FeatureItem] = []

if config is not None and split is not None:
try:
Expand Down Expand Up @@ -104,11 +104,12 @@ def get_rows(
infoItem = infoItems[0]
if "features" not in infoItem:
raise Status400Error("a dataset config info should contain a 'features' property")
localFeaturesItems: List[FeaturesItem] = [
{"dataset": dataset, "config": config, "features": infoItem["features"]}
localFeatureItems: List[FeatureItem] = [
{"dataset": dataset, "config": config, "feature": {"name": name, "content": content}}
for (name, content) in infoItem["features"].items()
]

return {"features": localFeaturesItems, "rows": rowItems}
return {"features": localFeatureItems, "rows": rowItems}

if config is None:
content = get_configs_response(dataset=dataset, token=token).content
Expand Down Expand Up @@ -142,12 +143,11 @@ def get_rows(
raise Status400Error("rows could not be found")
rows_content = cast(RowsContent, content)
rowItems += rows_content["rows"]
for featuresItem in rows_content["features"]:
# there should be only one element. Anyway, let's loop
if featuresItem not in featuresItems:
featuresItems.append(featuresItem)
for featureItem in rows_content["features"]:
if featureItem not in featureItems:
featureItems.append(featureItem)

return {"features": featuresItems, "rows": rowItems}
return {"features": featureItems, "rows": rowItems}


@memoize(cache, expire=CACHE_TTL_SECONDS) # type:ignore
Expand Down
11 changes: 8 additions & 3 deletions src/datasets_preview_backend/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,15 @@ class RowItem(TypedDict):
row: Any


class FeaturesItem(TypedDict):
# A single dataset feature (one column): its name plus the raw feature-type
# description emitted by the datasets library (e.g. a Value or ClassLabel dict).
# Declared with the functional TypedDict syntax; equivalent to the class form.
Feature = TypedDict(
    "Feature",
    {
        "name": str,  # column name, e.g. "premise"
        "content": Any,  # feature type payload as produced by datasets info.features
    },
)


class FeatureItem(TypedDict):
dataset: str
config: str
features: Any
feature: Feature


# Content of endpoint responses
Expand All @@ -55,7 +60,7 @@ class SplitsContent(TypedDict):


class RowsContent(TypedDict):
features: List[FeaturesItem]
features: List[FeatureItem]
rows: List[RowItem]


Expand Down
23 changes: 14 additions & 9 deletions tests/queries/test_rows.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,35 +36,40 @@ def test_get_split_features() -> None:
split = "train"
response = get_rows(dataset, config, split)
assert "features" in response
assert len(response["features"]) == 1
featuresItem = response["features"][0]
assert "dataset" in featuresItem
assert "config" in featuresItem
assert "features" in featuresItem
assert featuresItem["features"]["tokens"]["_type"] == "Sequence"
assert len(response["features"]) == 3
featureItem = response["features"][0]
assert "dataset" in featureItem
assert "config" in featureItem
assert "feature" in featureItem
feature = featureItem["feature"]
assert "name" in feature
assert "content" in feature
assert feature["name"] == "id"
assert "_type" in feature["content"]
assert feature["content"]["_type"] == "Value"


def test_get_split_rows_without_split() -> None:
dataset = "acronym_identification"
response = get_rows(dataset, DEFAULT_CONFIG_NAME)
assert len(response["rows"]) == 3 * EXTRACT_ROWS_LIMIT
assert len(response["features"]) == 1
assert len(response["features"]) == 3


def test_get_split_rows_without_config() -> None:
dataset = "acronym_identification"
split = "train"
response1 = get_rows(dataset)
assert len(response1["rows"]) == 1 * 3 * EXTRACT_ROWS_LIMIT
assert len(response1["features"]) == 1
assert len(response1["features"]) == 3

response2 = get_rows(dataset, None, split)
assert response1 == response2

dataset = "adversarial_qa"
response3 = get_rows(dataset)
assert len(response3["rows"]) == 4 * 3 * EXTRACT_ROWS_LIMIT
assert len(response3["features"]) == 4
assert len(response3["features"]) == 4 * 6


def test_get_unknown_dataset() -> None:
Expand Down

0 comments on commit bd06c94

Please sign in to comment.