From c494017bd6f3f935a02ad13d25c6d5067a25157b Mon Sep 17 00:00:00 2001 From: Vincent Chen Date: Thu, 12 Dec 2024 17:59:26 -0800 Subject: [PATCH] Catch multiple slashes in source dataset into one slash (#1697) Co-authored-by: Vincent Chen --- llmfoundry/utils/config_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 997273de7f..252841cb50 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -8,6 +8,7 @@ import os import warnings from dataclasses import dataclass, fields +from pathlib import Path from typing import ( Any, Callable, @@ -703,6 +704,8 @@ def _process_data_source( true_split (str): The split of the dataset to be added (i.e. train or eval) data_paths (List[Tuple[str, str, str]]): A list of tuples formatted as (data type, path, split) """ + if source_dataset_path: + source_dataset_path = str(Path(source_dataset_path)) # Check for Delta table if source_dataset_path and len(source_dataset_path.split('.')) == 3: data_paths.append(('delta_table', source_dataset_path, true_split)) @@ -788,7 +791,6 @@ def log_dataset_uri(cfg: dict[str, Any]) -> None: # Map data source types to their respective MLFlow DataSource. for dataset_type, path, split in data_paths: - if dataset_type in dataset_source_mapping: source_class = dataset_source_mapping[dataset_type] if dataset_type == 'delta_table':