Merge pull request #24 from broadinstitute/subsetted-tabular-fix-labels-and-typings

fix(breadbox): Subsetted tabular fix labels and typings
jessica-cheng authored Aug 13, 2024
2 parents 691409a + 26fb114 commit e4288a3
Showing 5 changed files with 210 additions and 39 deletions.
66 changes: 60 additions & 6 deletions breadbox/breadbox/crud/dataset.py
@@ -4,14 +4,18 @@
from typing import Dict, Optional, List, Type, Union, Tuple
from uuid import UUID, uuid4
import warnings
import json

import pandas as pd
import numpy as np
from sqlalchemy import and_, func, or_
from sqlalchemy.orm import aliased, with_polymorphic

from breadbox.db.session import SessionWithUser
- from ..io.data_validation import dimension_label_df_schema
+ from ..io.data_validation import (
+     dimension_label_df_schema,
+     annotation_type_to_pandas_column_type,
+ )
from ..schemas.dataset import (
MatrixDatasetIn,
TabularDatasetIn,
@@ -1506,10 +1510,29 @@ def get_dataset_feature(
return dataset_feature


def _get_column_types(columns_metadata, columns: Optional[List[str]]):
col_and_column_metadata_pairs = columns_metadata.items()
if columns is None:
return {
col: annotation_type_to_pandas_column_type(column_metadata.col_type)
for col, column_metadata in col_and_column_metadata_pairs
}

else:
column_types = {}
for col, column_metadata in col_and_column_metadata_pairs:
if col in columns:
column_types[col] = annotation_type_to_pandas_column_type(
column_metadata.col_type
)

return column_types


def get_subsetted_tabular_dataset_df(
db: SessionWithUser,
user: str,
- dataset: Dataset,
+ dataset: TabularDataset,
tabular_dimensions_info: TabularDimensionsInfo,
strict: bool,
):
@@ -1530,15 +1553,20 @@
)

if tabular_dimensions_info.identifier == FeatureSampleIdentifier.label:
- # Get the corresponding dimension ids for the dimension labels and use the dimension ids to filter values by
+ # Get the corresponding dimension ids for the dimension labels from the dataset's dimension type and use the dimension ids to filter values by
dimension_type: DimensionType = db.query(DimensionType).filter(
DimensionType.name == dataset.index_type_name
).one()

label_filter_statements = [
- TabularColumn.dataset_id == dataset.id,
+ TabularColumn.dataset_id == dimension_type.dataset_id,
TabularColumn.given_id == "label",
]
if tabular_dimensions_info.indices:
label_filter_statements.append(
TabularCell.value.in_(tabular_dimensions_info.indices)
)

ids_by_label = (
db.query(TabularCell)
.join(TabularColumn)
@@ -1604,11 +1632,37 @@

# Need to index by "value" only after checking for an empty result, because an empty result has no "value" key
subsetted_tabular_dataset_df = pivot_df["value"]
# TODO: It seems like None data values are potentially stored as 'nan' in the db. Must fix this!
subsetted_tabular_dataset_df = subsetted_tabular_dataset_df.replace({np.nan: None})
# set typing for columns
col_dtypes = _get_column_types(
dataset.columns_metadata, tabular_dimensions_info.columns
)
subsetted_tabular_dataset_df = _convert_subsetted_tabular_df_dtypes(
subsetted_tabular_dataset_df, col_dtypes, dataset.columns_metadata
)
return subsetted_tabular_dataset_df


def _convert_subsetted_tabular_df_dtypes(
df: pd.DataFrame,
dtype_map: Dict[str, Any],
dataset_columns_metadata: Dict[str, ColumnMetadata],
):
# Replace string boolean values with boolean
for col, dtype in dtype_map.items():
column = df[col]
if dtype == pd.BooleanDtype():
column = column.replace({"True": True, "False": False})
column = column.astype(dtype)
# NOTE: if col type is list string, convert to list. col dtype will be changed to object
if (
dtype == pd.StringDtype()
and dataset_columns_metadata[col].col_type == AnnotationType.list_strings
):
column = column.apply(lambda x: json.loads(x) if x is not pd.NA else x)
df[col] = column
return df


def get_truncated_message(missing_tabular_columns, missing_tabular_indices):
num_missing_cols = len(missing_tabular_columns)
num_missing_indices = len(missing_tabular_indices)
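The core of the dataset.py change is the pair of helpers above: `_get_column_types` looks up the declared `col_type` of each requested column, and `_convert_subsetted_tabular_df_dtypes` coerces the pivoted string values into the corresponding nullable pandas dtypes. A minimal standalone sketch of that conversion pattern (hypothetical column names and values, not Breadbox code):

```python
import json

import numpy as np
import pandas as pd

# Values pivoted out of the tabular cell table arrive as strings, with NaN for gaps.
df = pd.DataFrame(
    {
        "col_3": ["False", np.nan],  # a "binary" column, stored as text
        "col_5": ['["a"]', np.nan],  # a "list_strings" column, stored as JSON text
    },
    index=["ACH-1", "ACH-2"],
)

# Binary: map the string literals onto real booleans, then cast to the nullable
# BooleanDtype so missing values become pd.NA instead of float NaN.
df["col_3"] = (
    df["col_3"].replace({"True": True, "False": False}).astype(pd.BooleanDtype())
)

# list_strings: cast to the nullable StringDtype, then json.loads each present
# value; the column ends up holding Python lists (object dtype).
df["col_5"] = (
    df["col_5"]
    .astype(pd.StringDtype())
    .apply(lambda x: json.loads(x) if x is not pd.NA else x)
)

print(df["col_3"].tolist())  # [False, <NA>]
print(df["col_5"].tolist())  # [['a'], <NA>]
```

With nullable dtypes in place, the endpoint can serialize booleans as JSON booleans and missing cells as null, which is what the updated test expectations below assert.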
26 changes: 14 additions & 12 deletions breadbox/breadbox/io/data_validation.py
@@ -56,6 +56,7 @@ def validate_all_columns_have_types(
)


# TODO: Replace this with annotation_type_to_pandas_column_type. This is used in dimension type tabular datasets only and we should try to make these dtype mappings standard for all tabular datasets
def map_annotation_type_to_pandas_dtype(annotation_type: AnnotationType):
annotation_type_to_pandas_type_mappings = {
AnnotationType.continuous: "float",
@@ -67,6 +68,19 @@ def map_annotation_type_to_pandas_dtype(annotation_type: AnnotationType):
return annotation_type_to_pandas_type_mappings.get(annotation_type)


def annotation_type_to_pandas_column_type(annotation_type: AnnotationType):
annotation_type_to_pandas_type_mappings = {
AnnotationType.continuous: pd.Float64Dtype(),
AnnotationType.categorical: pd.CategoricalDtype(),
AnnotationType.binary: pd.BooleanDtype(),
AnnotationType.text: pd.StringDtype(),
AnnotationType.list_strings: pd.StringDtype(),
}
dtype = annotation_type_to_pandas_type_mappings.get(annotation_type)
assert dtype is not None
return dtype


def _validate_dimension_type_metadata_file(
metadata_file: Optional[UploadFile],
annotation_type_mapping: Dict[str, AnnotationType],
@@ -479,18 +493,6 @@ def read_and_validate_tabular_df(
columns_metadata: Dict[str, ColumnMetadata],
dimension_type_identifier: str,
):
- def annotation_type_to_pandas_column_type(annotation_type: AnnotationType):
-     annotation_type_to_pandas_type_mappings = {
-         AnnotationType.continuous: pd.Float64Dtype(),
-         AnnotationType.categorical: pd.CategoricalDtype(),
-         AnnotationType.binary: pd.BooleanDtype(),
-         AnnotationType.text: pd.StringDtype(),
-         AnnotationType.list_strings: pd.StringDtype(),
-     }
-     dtype = annotation_type_to_pandas_type_mappings.get(annotation_type)
-     assert dtype is not None
-     return dtype

def can_parse_list_strings(val):
example_list_string = '["x", "y"]'
if val is not None and not pd.isnull(val):
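The hoisted `annotation_type_to_pandas_column_type` maps every annotation type to a nullable pandas extension dtype, unlike the legacy `map_annotation_type_to_pandas_dtype` above it, which returns plain numpy dtype names. A small sketch of the practical difference (illustration only, not Breadbox code):

```python
import pandas as pd

# Legacy-style mapping: plain numpy float64. Missing values are NaN, which does
# not round-trip to JSON null without extra handling.
legacy = pd.Series([1.0, None]).astype("float")

# New-style mapping: the nullable Float64 extension dtype. Missing values are
# pd.NA, which converts cleanly to None (JSON null) on serialization.
nullable = pd.Series([1.0, None]).astype(pd.Float64Dtype())

print(legacy.tolist())    # [1.0, nan]
print(nullable.tolist())  # [1.0, <NA>]
```

Note that `list_strings` maps to `StringDtype()` here on purpose: the raw values are JSON-encoded strings, and the parsing into Python lists happens later, in `_convert_subsetted_tabular_df_dtypes`.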
2 changes: 0 additions & 2 deletions breadbox/commands.py
@@ -403,5 +403,3 @@ def main_func():

if __name__ == "__main__":
cli()

- # test commit
2 changes: 1 addition & 1 deletion breadbox/tests/api/test_dataset_uploads.py
@@ -454,7 +454,7 @@ def test_add_tabular_dataset_with_missing_vals(
)
assert subsetted_by_id_res.status_code == 200
expected_res = {
"attr2": {"ACH-1": "False"},
"attr2": {"ACH-1": False},
"attr3": {"ACH-1": None},
"attr4": {"ACH-1": None},
"attr5": {"ACH-1": "cat1"},
153 changes: 135 additions & 18 deletions breadbox/tests/api/test_datasets.py
@@ -2446,22 +2446,57 @@ def test_get_tabular_dataset_data(
private_group: Dict,
settings,
):
- tabular_file_1 = factories.tabular_csv_data_file(
-     cols=["depmap_id", "label", "col_1", "col_2"],
-     row_values=[["ACH-1", "ach1", 1, "hi"], ["ACH-2", "ach2", np.NaN, "bye"]],
admin_headers = {"X-Forwarded-Email": settings.admin_users[0]}

# Give metadata for depmap model
r_add_metadata_for_depmap_model = client.patch(
"/types/sample/depmap_model/metadata",
data={
"name": "depmap model metadata",
"annotation_type_mapping": json.dumps(
{"annotation_type_mapping": {"label": "text", "depmap_id": "text",}}
),
},
files={
"metadata_file": (
"new_feature_metadata",
factories.tabular_csv_data_file(
cols=["label", "depmap_id"],
row_values=[
["ach1", "ACH-1"],
["ach2", "ACH-2"],
["ach3", "ACH-3"],
],
),
"text/csv",
)
},
headers=admin_headers,
)
- tabular_file_2 = factories.tabular_csv_data_file(
-     cols=["depmap_id", "label", "col_1", "col_2"],
-     row_values=[["ACH-1", "ach1", 1, "hi"]],
+ assert_status_ok(r_add_metadata_for_depmap_model)
+ assert r_add_metadata_for_depmap_model.status_code == 200

# Create tabular dataset
tabular_file_1 = factories.tabular_csv_data_file(
cols=[
"depmap_id",
"label",
"col_1",
"col_2",
"col_3",
"col_4",
"col_5",
], # NOTE: Add 'label' col to ensure endpoint only uses 'label' in dim type metadata
row_values=[
["ACH-1", "other_label_1", 1, "hi", False, "cat1", '["a"]'],
["ACH-2", "other_label_2", np.NaN, "bye", np.NaN, "cat2", np.NaN],
],
)

tabular_file_ids_1, tabular_file_1_hash = factories.file_ids_and_md5_hash(
client, tabular_file_1
)
- tabular_file_ids_2, tabular_file_2_hash = factories.file_ids_and_md5_hash(
-     client, tabular_file_2
- )

admin_headers = {"X-Forwarded-Email": settings.admin_users[0]}

tabular_dataset_1_response = client.post(
"/dataset-v2/",
@@ -2480,6 +2515,9 @@
"label": {"col_type": "text"},
"col_1": {"units": "a unit", "col_type": "continuous"},
"col_2": {"col_type": "text"},
"col_3": {"col_type": "binary"},
"col_4": {"col_type": "categorical"},
"col_5": {"col_type": "list_strings"},
},
},
headers=admin_headers,
@@ -2542,9 +2580,12 @@
)
assert res.json() == {
"depmap_id": {"ACH-1": "ACH-1"},
"label": {"ACH-1": "ach1"},
"col_1": {"ACH-1": "1.0"},
"label": {"ACH-1": "other_label_1"},
"col_1": {"ACH-1": 1},
"col_2": {"ACH-1": "hi"},
"col_3": {"ACH-1": False},
"col_4": {"ACH-1": "cat1"},
"col_5": {"ACH-1": ["a"]},
}

# When neither columns nor indices are provided, the entire dataset should be returned
@@ -2555,18 +2596,24 @@
)
assert res.json() == {
"depmap_id": {"ACH-1": "ACH-1", "ACH-2": "ACH-2"},
"label": {"ACH-1": "ach1", "ACH-2": "ach2"},
"col_1": {"ACH-1": "1.0", "ACH-2": None},
"label": {"ACH-1": "other_label_1", "ACH-2": "other_label_2"},
"col_1": {"ACH-1": 1, "ACH-2": None},
"col_2": {"ACH-1": "hi", "ACH-2": "bye"},
"col_3": {"ACH-1": False, "ACH-2": None},
"col_4": {"ACH-1": "cat1", "ACH-2": "cat2"},
"col_5": {"ACH-1": ["a"], "ACH-2": None},
}
res = client.post(
f"/datasets/tabular/{tabular_dataset_1_id}", headers=admin_headers,
)
assert res.json() == {
"depmap_id": {"ACH-1": "ACH-1", "ACH-2": "ACH-2"},
"label": {"ACH-1": "ach1", "ACH-2": "ach2"},
"col_1": {"ACH-1": "1.0", "ACH-2": None},
"label": {"ACH-1": "other_label_1", "ACH-2": "other_label_2"},
"col_1": {"ACH-1": 1, "ACH-2": None},
"col_2": {"ACH-1": "hi", "ACH-2": "bye"},
"col_3": {"ACH-1": False, "ACH-2": None},
"col_4": {"ACH-1": "cat1", "ACH-2": "cat2"},
"col_5": {"ACH-1": ["a"], "ACH-2": None},
}

# Test if no matches found with given query params --> empty df
@@ -2609,7 +2656,7 @@ def test_get_tabular_dataset_data(
},
headers=admin_headers,
)
assert res.json() == {"col_1": {"ACH-1": "1.0"}}
assert res.json() == {"col_1": {"ACH-1": 1}}

# With strict keyword
res = client.post(
@@ -2635,6 +2682,76 @@
)
assert res.status_code == 400

def test_get_tabular_dataset_data_no_index_metadata(
self,
client: TestClient,
minimal_db: SessionWithUser,
mock_celery,
private_group: Dict,
settings,
):
admin_headers = {"X-Forwarded-Email": settings.admin_users[0]}

tabular_file_2 = factories.tabular_csv_data_file(
cols=["depmap_id", "col_1", "col_2"], row_values=[["ACH-1", 1, "hi"]],
)
tabular_file_ids_2, tabular_file_2_hash = factories.file_ids_and_md5_hash(
client, tabular_file_2
)
tabular_dataset_2_response = client.post(
"/dataset-v2/",
json={
"format": "tabular",
"name": "Test Dataset 2",
"index_type": "depmap_model",
"data_type": "User upload",
"file_ids": tabular_file_ids_2,
"dataset_md5": tabular_file_2_hash,
"is_transient": False,
"group_id": private_group["id"],
"dataset_metadata": None,
"columns_metadata": {
"depmap_id": {"col_type": "text",},
"col_1": {"units": "a unit", "col_type": "continuous"},
"col_2": {"col_type": "text"},
},
},
headers=admin_headers,
)
assert_status_ok(tabular_dataset_2_response)
tabular_dataset_2_id = tabular_dataset_2_response.json()["result"]["dataset"][
"id"
]

tabular_dataset_2 = (
minimal_db.query(Dataset).filter_by(id=tabular_dataset_2_id).one()
)
assert tabular_dataset_2

# Get a subset of the tabular dataset by id
res = client.post(
f"/datasets/tabular/{tabular_dataset_2_id}",
json={
"indices": ["ACH-1"],
"identifier": "id",
"columns": ["col_1", "col_2"],
},
headers=admin_headers,
)
assert res.json() == {"col_1": {"ACH-1": 1}, "col_2": {"ACH-1": "hi"}}

# Get a subset of the tabular dataset by label (no data)
res = client.post(
f"/datasets/tabular/{tabular_dataset_2_id}",
json={
"indices": ["ach1"],
"identifier": "label",
"columns": ["col_1", "col_2"],
},
headers=admin_headers,
)
assert res.json() == {}


class TestPatch:
def test_update_dataset(
