fix(parsing): normalize nested column names
shcheklein committed Oct 26, 2024
1 parent cfe3d9c commit 6ecc5b0
Showing 5 changed files with 84 additions and 19 deletions.
17 changes: 2 additions & 15 deletions src/datachain/lib/arrow.py
@@ -1,4 +1,3 @@
-import re
from collections.abc import Sequence
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Any, Optional
@@ -13,6 +12,7 @@
from datachain.lib.model_store import ModelStore
from datachain.lib.signal_schema import SignalSchema
from datachain.lib.udf import Generator
+from datachain.lib.utils import normalize_col_names

if TYPE_CHECKING:
from datasets.features.features import Features
@@ -128,7 +128,7 @@ def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = None
    signal_schema = _get_datachain_schema(schema)
    if signal_schema:
        return signal_schema.values
-    columns = _convert_col_names(col_names)  # type: ignore[arg-type]
+    columns = list(normalize_col_names(col_names).keys())  # type: ignore[arg-type]
    hf_schema = _get_hf_schema(schema)
    if hf_schema:
        return {
@@ -143,19 +143,6 @@ def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = None
    return output


-def _convert_col_names(col_names: Sequence[str]) -> list[str]:
-    default_column = 0
-    converted_col_names = []
-    for column in col_names:
-        column = column.lower()
-        column = re.sub("[^0-9a-z_]+", "", column)
-        if not column:
-            column = f"c{default_column}"
-            default_column += 1
-        converted_col_names.append(column)
-    return converted_col_names


def arrow_type_mapper(col_type: pa.DataType, column: str = "") -> type: # noqa: PLR0911
"""Convert pyarrow types to basic types."""
from datetime import datetime
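For reference, a small sketch of what the new code path does (the column names below are illustrative, not from this commit): normalize_col_names returns a mapping of normalized name to original name, and schema_to_output keeps only the keys:

from datachain.lib.utils import normalize_col_names

# normalize_col_names returns {normalized: original}; schema_to_output
# uses only the keys as the output column names.
cols = ["UpperCaseCol", "dot.notation.col", "with-dashes"]
print(list(normalize_col_names(cols).keys()))
# -> ['uppercasecol', 'dotnotationcol', 'with_dashes']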
15 changes: 13 additions & 2 deletions src/datachain/lib/data_model.py
@@ -2,9 +2,10 @@
from datetime import datetime
from typing import ClassVar, Union, get_args, get_origin

-from pydantic import BaseModel, create_model
+from pydantic import BaseModel, Field, create_model

from datachain.lib.model_store import ModelStore
+from datachain.lib.utils import normalize_col_names

StandardType = Union[
    type[int],
@@ -60,7 +61,17 @@ def is_chain_type(t: type) -> bool:


def dict_to_data_model(name: str, data_dict: dict[str, DataType]) -> type[BaseModel]:
-    fields = {name: (anno, ...) for name, anno in data_dict.items()}
+    columns = normalize_col_names(list(data_dict.keys()))

+    if len(columns.values()) != len(set(columns.values())):
+        raise ValueError(
+            "Can't create a data model for data that has duplicate columns."
+        )

+    columns = {v: k for k, v in columns.items()}
+    fields = {
+        columns[name]: (anno, Field(alias=name)) for name, anno in data_dict.items()
+    }
    return create_model(
        name,
        __base__=(DataModel,),  # type: ignore[call-overload]
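A hedged sketch of the new dict_to_data_model behavior (the model name and columns below are made up): keys are normalized into valid field names, and the raw names survive as pydantic aliases, so data keyed by the original column names still validates. This assumes pydantic v2 semantics, where aliased fields are populated by alias on init:

from datachain.lib.data_model import dict_to_data_model

# Hypothetical input: "first-SELECT" normalizes to "first_select", "AgE" to "age".
Person = dict_to_data_model("Person", {"first-SELECT": str, "AgE": int})
row = Person(**{"first-SELECT": "Alice", "AgE": 30})  # populated via aliases
print(row.first_select, row.age)  # -> Alice 30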
28 changes: 28 additions & 0 deletions src/datachain/lib/utils.py
@@ -1,4 +1,6 @@
+import re
from abc import ABC, abstractmethod
+from collections.abc import Sequence


class AbstractUDF(ABC):
@@ -28,3 +30,29 @@ def __init__(self, message):
class DataChainColumnError(DataChainParamsError):
    def __init__(self, col_name, msg):
        super().__init__(f"Error for column {col_name}: {msg}")


+def normalize_col_names(col_names: Sequence[str]) -> dict[str, str]:
+    gen_col_counter = 0
+    new_col_names = {}
+    org_col_names = set(col_names)

+    for org_column in col_names:
+        new_column = org_column.lower()
+        new_column = re.sub("[-_\\s]+", "_", new_column)
+        new_column = new_column.strip("_")
+        new_column = re.sub("[^0-9a-z_]+", "", new_column)

+        if not new_column or (new_column != org_column and new_column in org_col_names):
+            while True:
+                generated_column = f"c{gen_col_counter}"
+                gen_col_counter += 1
+                if new_column:
+                    generated_column = f"{generated_column}_{new_column}"
+                if generated_column not in org_col_names:
+                    new_column = generated_column
+                    break

+        new_col_names[new_column] = org_column

+    return new_col_names
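To make the mapping direction concrete, a quick sketch (the inputs are illustrative): the returned dict maps normalized name to original name, and a normalized name that collides with another original column falls back to a generated c&lt;N&gt;-prefixed name:

from datachain.lib.utils import normalize_col_names

print(normalize_col_names(["UpperCase", "l--as@t", "with spaces"]))
# -> {'uppercase': 'UpperCase', 'l_ast': 'l--as@t', 'with_spaces': 'with spaces'}

# "name " also normalizes to "name", which is taken by an original column,
# so a generated fallback like "c0_name" keeps the mapping unique.
print(normalize_col_names(["name", "name "]))
# -> {'name': 'name', 'c0_name': 'name '}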
12 changes: 10 additions & 2 deletions tests/unit/lib/test_arrow.py
@@ -168,13 +168,21 @@ def test_parquet_convert_column_names():
("dot.notation.col", pa.int32()),
("with-dashes", pa.int32()),
("with spaces", pa.int32()),
("with-multiple--dashes", pa.int32()),
("with__underscores", pa.int32()),
("__leading__underscores", pa.int32()),
("trailing__underscores__", pa.int32()),
]
)
assert list(schema_to_output(schema)) == [
"uppercasecol",
"dotnotationcol",
"withdashes",
"withspaces",
"with_dashes",
"with_spaces",
"with_multiple_dashes",
"with_underscores",
"leading_underscores",
"trailing_underscores",
]


31 changes: 31 additions & 0 deletions tests/unit/lib/test_datachain.py
@@ -36,6 +36,18 @@
"city": ["New York", "Los Angeles", "Chicago", "Houston", "Phoenix"],
}

+DF_DATA_NESTED_NOT_NORMALIZED = {
+    "nAmE": [
+        {"first-SELECT": "Alice", "l--as@t": "Smith"},
+        {"l--as@t": "Jones", "first-SELECT": "Bob"},
+        {"first-SELECT": "Charlie", "l--as@t": "Brown"},
+        {"first-SELECT": "David", "l--as@t": "White"},
+        {"first-SELECT": "Eva", "l--as@t": "Black"},
+    ],
+    "AgE": [25, 30, 35, 40, 45],
+    "citY": ["New York", "Los Angeles", "Chicago", "Houston", "Phoenix"],
+}

DF_OTHER_DATA = {
    "last_name": ["Smith", "Jones"],
    "country": ["USA", "Russia"],
@@ -984,6 +996,25 @@ def test_parse_tabular_format(tmp_dir, test_session):
    assert df1.equals(df)


+def test_parse_nested_json(tmp_dir, test_session):
+    df = pd.DataFrame(DF_DATA_NESTED_NOT_NORMALIZED)
+    path = tmp_dir / "test.jsonl"
+    path.write_text(df.to_json(orient="records", lines=True))
+    dc = DataChain.from_storage(path.as_uri(), session=test_session).parse_tabular(
+        format="json"
+    )
+    # Field names are normalized, values are preserved,
+    # e.g. nAmE -> name, l--as@t -> l_ast, etc.
+    df1 = dc.select("name", "age", "city").to_pandas()

+    assert df1["name"]["first_select"].to_list() == [
+        d["first-SELECT"] for d in df["nAmE"].to_list()
+    ]
+    assert df1["name"]["l_ast"].to_list() == [
+        d["l--as@t"] for d in df["nAmE"].to_list()
+    ]


def test_parse_tabular_partitions(tmp_dir, test_session):
    df = pd.DataFrame(DF_DATA)
    path = tmp_dir / "test.parquet"
