diff --git a/src/datachain/lib/arrow.py b/src/datachain/lib/arrow.py
index 950ce34e..4048848d 100644
--- a/src/datachain/lib/arrow.py
+++ b/src/datachain/lib/arrow.py
@@ -1,4 +1,3 @@
-import re
 from collections.abc import Sequence
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Any, Optional
@@ -13,6 +12,7 @@
 from datachain.lib.model_store import ModelStore
 from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf import Generator
+from datachain.lib.utils import normalize_col_names
 
 if TYPE_CHECKING:
     from datasets.features.features import Features
@@ -128,7 +128,7 @@ def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = Non
     signal_schema = _get_datachain_schema(schema)
     if signal_schema:
         return signal_schema.values
-    columns = _convert_col_names(col_names)  # type: ignore[arg-type]
+    columns = list(normalize_col_names(col_names).keys())  # type: ignore[arg-type]
     hf_schema = _get_hf_schema(schema)
     if hf_schema:
         return {
@@ -143,19 +143,6 @@ def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = Non
     return output
 
 
-def _convert_col_names(col_names: Sequence[str]) -> list[str]:
-    default_column = 0
-    converted_col_names = []
-    for column in col_names:
-        column = column.lower()
-        column = re.sub("[^0-9a-z_]+", "", column)
-        if not column:
-            column = f"c{default_column}"
-            default_column += 1
-        converted_col_names.append(column)
-    return converted_col_names
-
-
 def arrow_type_mapper(col_type: pa.DataType, column: str = "") -> type:  # noqa: PLR0911
     """Convert pyarrow types to basic types."""
     from datetime import datetime
diff --git a/src/datachain/lib/data_model.py b/src/datachain/lib/data_model.py
index 67f95f29..94e6d314 100644
--- a/src/datachain/lib/data_model.py
+++ b/src/datachain/lib/data_model.py
@@ -2,9 +2,10 @@
 from datetime import datetime
 from typing import ClassVar, Union, get_args, get_origin
 
-from pydantic import BaseModel, create_model
+from pydantic import BaseModel, Field, create_model
 
 from datachain.lib.model_store import ModelStore
+from datachain.lib.utils import normalize_col_names
 
 StandardType = Union[
     type[int],
@@ -60,7 +61,17 @@ def is_chain_type(t: type) -> bool:
 
 
 def dict_to_data_model(name: str, data_dict: dict[str, DataType]) -> type[BaseModel]:
-    fields = {name: (anno, ...) for name, anno in data_dict.items()}
+    columns = normalize_col_names(list(data_dict.keys()))
+
+    if len(columns.values()) != len(set(columns.values())):
+        raise ValueError(
+            "Can't create a data model for data that has duplicate columns."
+        )
+
+    columns = {v: k for k, v in columns.items()}
+    fields = {
+        columns[name]: (anno, Field(alias=name)) for name, anno in data_dict.items()
+    }
     return create_model(
         name,
         __base__=(DataModel,),  # type: ignore[call-overload]
diff --git a/src/datachain/lib/utils.py b/src/datachain/lib/utils.py
index cd11da9c..0b405978 100644
--- a/src/datachain/lib/utils.py
+++ b/src/datachain/lib/utils.py
@@ -1,4 +1,6 @@
+import re
 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 
 
 class AbstractUDF(ABC):
@@ -28,3 +30,29 @@ def __init__(self, message):
 class DataChainColumnError(DataChainParamsError):
     def __init__(self, col_name, msg):
         super().__init__(f"Error for column {col_name}: {msg}")
+
+
+def normalize_col_names(col_names: Sequence[str]) -> dict[str, str]:
+    gen_col_counter = 0
+    new_col_names = {}
+    org_col_names = set(col_names)
+
+    for org_column in col_names:
+        new_column = org_column.lower()
+        new_column = re.sub("[-_\\s]+", "_", new_column)
+        new_column = new_column.strip("_")
+        new_column = re.sub("[^0-9a-z_]+", "", new_column)
+
+        if not new_column or (new_column != org_column and new_column in org_col_names):
+            while True:
+                generated_column = f"c{gen_col_counter}"
+                gen_col_counter += 1
+                if new_column:
+                    generated_column = f"{generated_column}_{new_column}"
+                if generated_column not in org_col_names:
+                    new_column = generated_column
+                    break
+
+        new_col_names[new_column] = org_column
+
+    return new_col_names
diff --git a/tests/unit/lib/test_arrow.py b/tests/unit/lib/test_arrow.py
index 4d1414b9..15f19977 100644
--- a/tests/unit/lib/test_arrow.py
+++ b/tests/unit/lib/test_arrow.py
@@ -168,13 +168,21 @@ def test_parquet_convert_column_names():
             ("dot.notation.col", pa.int32()),
             ("with-dashes", pa.int32()),
             ("with spaces", pa.int32()),
+            ("with-multiple--dashes", pa.int32()),
+            ("with__underscores", pa.int32()),
+            ("__leading__underscores", pa.int32()),
+            ("trailing__underscores__", pa.int32()),
         ]
     )
     assert list(schema_to_output(schema)) == [
         "uppercasecol",
         "dotnotationcol",
-        "withdashes",
-        "withspaces",
+        "with_dashes",
+        "with_spaces",
+        "with_multiple_dashes",
+        "with_underscores",
+        "leading_underscores",
+        "trailing_underscores",
     ]
 
 
diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py
index 7da36787..a3f6682e 100644
--- a/tests/unit/lib/test_datachain.py
+++ b/tests/unit/lib/test_datachain.py
@@ -36,6 +36,18 @@
     "city": ["New York", "Los Angeles", "Chicago", "Houston", "Phoenix"],
 }
 
+DF_DATA_NESTED_NOT_NORMALIZED = {
+    "nAmE": [
+        {"first-SELECT": "Alice", "l--as@t": "Smith"},
+        {"l--as@t": "Jones", "first-SELECT": "Bob"},
+        {"first-SELECT": "Charlie", "l--as@t": "Brown"},
+        {"first-SELECT": "David", "l--as@t": "White"},
+        {"first-SELECT": "Eva", "l--as@t": "Black"},
+    ],
+    "AgE": [25, 30, 35, 40, 45],
+    "citY": ["New York", "Los Angeles", "Chicago", "Houston", "Phoenix"],
+}
+
 DF_OTHER_DATA = {
     "last_name": ["Smith", "Jones"],
     "country": ["USA", "Russia"],
@@ -984,6 +996,25 @@ def test_parse_tabular_format(tmp_dir, test_session):
     assert df1.equals(df)
 
 
+def test_parse_nested_json(tmp_dir, test_session):
+    df = pd.DataFrame(DF_DATA_NESTED_NOT_NORMALIZED)
+    path = tmp_dir / "test.jsonl"
+    path.write_text(df.to_json(orient="records", lines=True))
+    dc = DataChain.from_storage(path.as_uri(), session=test_session).parse_tabular(
+        format="json"
+    )
+    # Field names are normalized, values are preserved
+    # E.g. nAmE -> name, l--as@t -> l_ast, etc
+    df1 = dc.select("name", "age", "city").to_pandas()
+
+    assert df1["name"]["first_select"].to_list() == [
+        d["first-SELECT"] for d in df["nAmE"].to_list()
+    ]
+    assert df1["name"]["l_ast"].to_list() == [
+        d["l--as@t"] for d in df["nAmE"].to_list()
+    ]
+
+
 def test_parse_tabular_partitions(tmp_dir, test_session):
     df = pd.DataFrame(DF_DATA)
     path = tmp_dir / "test.parquet"
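For reference, a minimal usage sketch of the `normalize_col_names` helper introduced in this patch, assuming `datachain` is installed with the change applied. The sample column names and the expected normalized keys mirror the test expectations above; the helper returns a mapping of normalized name to original name.

```python
from datachain.lib.utils import normalize_col_names

# Normalized names become dict keys, original names are kept as values.
mapping = normalize_col_names(
    ["UpperCaseCol", "with-dashes", "__leading__underscores", "l--as@t"]
)
assert mapping == {
    "uppercasecol": "UpperCaseCol",
    "with_dashes": "with-dashes",
    "leading_underscores": "__leading__underscores",
    "l_ast": "l--as@t",
}
```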