From 67ccf5774f8e6ab409bc206dc0d288ea84a6c594 Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Tue, 22 Oct 2024 21:53:48 -0700 Subject: [PATCH 01/11] add support for converting from safetensors --- hf_olmo/convert_olmo_to_hf.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/hf_olmo/convert_olmo_to_hf.py b/hf_olmo/convert_olmo_to_hf.py index 2e0a9e074..48011a3ef 100644 --- a/hf_olmo/convert_olmo_to_hf.py +++ b/hf_olmo/convert_olmo_to_hf.py @@ -9,15 +9,16 @@ from urllib.parse import urlparse import torch +from olmo import ModelConfig, Tokenizer, TrainConfig +from olmo.checkpoint import build_sharded_checkpointer +from olmo.util import _get_s3_client from omegaconf import OmegaConf as om +from safetensors.torch import load_file from tqdm import tqdm from hf_olmo.configuration_olmo import OLMoConfig from hf_olmo.modeling_olmo import OLMoForCausalLM from hf_olmo.tokenization_olmo_fast import OLMoTokenizerFast -from olmo import ModelConfig, Tokenizer, TrainConfig -from olmo.checkpoint import build_sharded_checkpointer -from olmo.util import _get_s3_client logger = logging.getLogger(__name__) @@ -67,10 +68,16 @@ def write_model(checkpoint_dir: str, ignore_olmo_compatibility: bool = False): # For device_map = "auto", etc. the models are loaded in a way that start_prefix is not computed correctly. # So, we explicitly store the model with the expected prefix. - old_model_path = os.path.join(checkpoint_dir, "model.pt") - new_model_path = os.path.join(checkpoint_dir, "pytorch_model.bin") + if os.path.exists(os.path.join(checkpoint_dir, "model.pt")): + old_model_path = os.path.join(checkpoint_dir, "model.pt") + state_dict = torch.load(old_model_path, map_location="cpu") + elif os.path.exists(os.path.join(checkpoint_dir, "model.safetensors")): + old_model_path = os.path.join(checkpoint_dir, "model.safetensors") + state_dict = load_file(old_model_path, device="cpu") + else: + raise ValueError(f"No model found in {checkpoint_dir}") - state_dict = torch.load(old_model_path, map_location="cpu") + new_model_path = os.path.join(checkpoint_dir, "pytorch_model.bin") # this takes care of the case where the model was saved with a different prefix, # typically due to unsharding. @@ -233,7 +240,9 @@ def upload_local_checkpoint(local_checkpoint_dir: str, destination_dir: str): def maybe_unshard(checkpoint_dir: str): - if os.path.exists(os.path.join(checkpoint_dir, "model.pt")): + if os.path.exists(os.path.join(checkpoint_dir, "model.pt")) or os.path.exists( + os.path.join(checkpoint_dir, "model.safetensors") + ): return print(f"Unsharding {checkpoint_dir}...") From 087e253e2135d5a420fbb3e07b66b834a93957c1 Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Tue, 22 Oct 2024 21:54:59 -0700 Subject: [PATCH 02/11] changelog --- CHANGELOG.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b73eeae96..9f0472818 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +- Added support for safetensors in `hf_olmo` conversion script. + ## [v0.5.1](https://github.com/allenai/OLMo/releases/tag/v0.5.1) - 2024-10-17 ### Added @@ -45,7 +47,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Swapped in correct flan data mix. - Fix bug where the attention norm, when applied before the attention block, was modifying the residual stream. - Fixed `OLMo.from_checkpoint()` so that it correctly loads `olmo_core` and `torch_new` style checkpoints. -- Fixed `preserve_rng_state` being incorrectly set to False when doing gradient checkpointing with dropout +- Fixed `preserve_rng_state` being incorrectly set to False when doing gradient checkpointing with dropout ## [v0.4.0](https://github.com/allenai/OLMo/releases/tag/v0.4.0) - 2024-07-11 From 48a4e8e0ba942cd52554f9f669ecd28b8c56352c Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Tue, 29 Oct 2024 19:51:34 -0700 Subject: [PATCH 03/11] sorted imports --- hf_olmo/convert_olmo_to_hf.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/hf_olmo/convert_olmo_to_hf.py b/hf_olmo/convert_olmo_to_hf.py index 48011a3ef..0c87d0e58 100644 --- a/hf_olmo/convert_olmo_to_hf.py +++ b/hf_olmo/convert_olmo_to_hf.py @@ -9,16 +9,16 @@ from urllib.parse import urlparse import torch -from olmo import ModelConfig, Tokenizer, TrainConfig -from olmo.checkpoint import build_sharded_checkpointer -from olmo.util import _get_s3_client from omegaconf import OmegaConf as om -from safetensors.torch import load_file from tqdm import tqdm from hf_olmo.configuration_olmo import OLMoConfig from hf_olmo.modeling_olmo import OLMoForCausalLM from hf_olmo.tokenization_olmo_fast import OLMoTokenizerFast +from olmo import ModelConfig, Tokenizer, TrainConfig +from olmo.checkpoint import build_sharded_checkpointer +from olmo.safetensors_util import safetensors_file_to_state_dict +from olmo.util import _get_s3_client logger = logging.getLogger(__name__) @@ -68,12 +68,10 @@ def write_model(checkpoint_dir: str, ignore_olmo_compatibility: bool = False): # For device_map = "auto", etc. the models are loaded in a way that start_prefix is not computed correctly. # So, we explicitly store the model with the expected prefix. - if os.path.exists(os.path.join(checkpoint_dir, "model.pt")): - old_model_path = os.path.join(checkpoint_dir, "model.pt") + if os.path.exists(old_model_path := os.path.join(checkpoint_dir, "model.pt")): state_dict = torch.load(old_model_path, map_location="cpu") - elif os.path.exists(os.path.join(checkpoint_dir, "model.safetensors")): - old_model_path = os.path.join(checkpoint_dir, "model.safetensors") - state_dict = load_file(old_model_path, device="cpu") + elif os.path.exists(old_model_path := os.path.join(checkpoint_dir, "model.safetensors")): + state_dict = safetensors_file_to_state_dict(old_model_path, map_location="cpu") else: raise ValueError(f"No model found in {checkpoint_dir}") From 747453136c6f2534f268bf8674bb2e56eda3d325 Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Wed, 30 Oct 2024 17:44:50 -0700 Subject: [PATCH 04/11] gcs support --- hf_olmo/convert_olmo_to_hf.py | 111 ++++++++++++++++++++++++---------- olmo/util.py | 43 +++++++++++-- 2 files changed, 118 insertions(+), 36 deletions(-) diff --git a/hf_olmo/convert_olmo_to_hf.py b/hf_olmo/convert_olmo_to_hf.py index 0c87d0e58..74056333e 100644 --- a/hf_olmo/convert_olmo_to_hf.py +++ b/hf_olmo/convert_olmo_to_hf.py @@ -5,6 +5,7 @@ import shutil import tempfile from hashlib import md5 +from pathlib import Path from typing import Iterable, Optional from urllib.parse import urlparse @@ -18,7 +19,7 @@ from olmo import ModelConfig, Tokenizer, TrainConfig from olmo.checkpoint import build_sharded_checkpointer from olmo.safetensors_util import safetensors_file_to_state_dict -from olmo.util import _get_s3_client +from olmo.util import _get_gcs_client, _get_s3_client, walk_local_path logger = logging.getLogger(__name__) @@ -148,6 +149,21 @@ def fix_tokenizer(checkpoint_dir: str, tokenizer_name_or_path: Optional[str] = N om.save(conf, path) +def download_gcs_directory(bucket_name: str, prefix: str, local_dir: str): + path_local = Path(local_dir) + path_prefix = Path(prefix) + + gcs_client = _get_s3_client() + bucket = gcs_client.bucket(bucket_name) + + path_local.mkdir(parents=True, exist_ok=True) + + for elem in bucket.list_blobs(prefix=prefix): + local_destination = path_local / Path(elem.name).relative_to(path_prefix) + local_destination.parent.mkdir(parents=True, exist_ok=True) + elem.download_to_filename(local_destination) + + def download_s3_directory(bucket_name: str, prefix: str, local_dir: str, ignore: str | None = None): # Create S3 client s3_client = _get_s3_client("s3") @@ -183,7 +199,7 @@ def download_s3_directory(bucket_name: str, prefix: str, local_dir: str, ignore: def make_local_checkpoint(checkpoint_dir: str) -> str: parsed_dir = urlparse(checkpoint_dir) - assert parsed_dir.scheme in ["s3", ""], "Only s3 and local paths are supported." + assert parsed_dir.scheme in ["s3", "gs", ""], "Only s3 and local paths are supported." if os.path.exists(checkpoint_dir): return checkpoint_dir @@ -194,12 +210,22 @@ def make_local_checkpoint(checkpoint_dir: str) -> str: try: os.makedirs(temp_dir, exist_ok=True) print(f"Downloading checkpoint to {temp_dir}...") - download_s3_directory( - bucket_name=parsed_dir.netloc, - prefix=parsed_dir.path.lstrip("/"), - local_dir=temp_dir, - ignore=r"/(optim|train)/", - ) + + if parsed_dir.scheme == "gs": + download_gcs_directory( + bucket_name=parsed_dir.netloc, + prefix=parsed_dir.path.lstrip("/"), + local_dir=temp_dir, + ) + elif parsed_dir.scheme == "s3": + download_s3_directory( + bucket_name=parsed_dir.netloc, + prefix=parsed_dir.path.lstrip("/"), + local_dir=temp_dir, + ignore=r"/(optim|train)/", + ) + else: + raise ValueError(f"Unsupported: {checkpoint_dir}. Only s3://, gs://, and local are supported.") except Exception as e: logger.error(f"Error downloading checkpoint: {e}") shutil.rmtree(temp_dir) @@ -208,33 +234,54 @@ def make_local_checkpoint(checkpoint_dir: str) -> str: return temp_dir +def upload_s3_directory(local_checkpoint_dir: str, destination_dir: str): + parsed_destination = urlparse(destination_dir) + if parsed_destination.scheme != "s3": + raise ValueError(f"Unsupported destination: {destination_dir}. Only s3 paths are supported.") + + s3_client = _get_s3_client("s3") + s3_bucket_name = parsed_destination.netloc + s3_prefix = Path(parsed_destination.path) + local_checkpoint_path = Path(local_checkpoint_dir) + local_paths = [ + Path(path / fn) for path, _, filenames in walk_local_path(local_checkpoint_path) for fn in filenames + ] + + for local_path in tqdm(local_paths, desc="Uploading files to S3"): + destination = s3_prefix / local_path.relative_to(local_checkpoint_path) + s3_client.upload_file(local_path, s3_bucket_name, str(destination)) + + +def upload_gcs_directory(local_checkpoint_dir: str, destination_dir: str): + parsed_destination = urlparse(destination_dir) + if parsed_destination.scheme != "gs": + raise ValueError(f"Unsupported destination: {destination_dir}. Only gs paths are supported.") + + gcs_client = _get_gcs_client() + bucket_name = parsed_destination.netloc + prefix = Path(parsed_destination.path) + local_checkpoint_path = Path(local_checkpoint_dir) + local_paths = [ + Path(path / fn) for path, _, filenames in walk_local_path(local_checkpoint_path) for fn in filenames + ] + + for local_path in tqdm(local_paths, desc="Uploading files to GCS"): + destination = prefix / local_path.relative_to(local_checkpoint_path) + blob = gcs_client.bucket(bucket_name).blob(str(destination)) + blob.upload_from_filename(local_path) + + def upload_local_checkpoint(local_checkpoint_dir: str, destination_dir: str): if destination_dir == local_checkpoint_dir: return - elif (parsed_url := urlparse(destination_dir)).scheme == "s3": - s3_bucket_name = parsed_url.netloc - s3_prefix = parsed_url.path[1:] - - local_paths = [ - os.path.join(root, post_fn) - for root, _, files in os.walk(local_checkpoint_dir) - for post_fn in files - if os.path.basename(post_fn) in HF_FILENAMES - ] - dest_paths = [ - os.path.join(s3_prefix, os.path.relpath(local_path, local_checkpoint_dir)) - for local_path in local_paths - ] - - s3_client = _get_s3_client("s3") - for local_path, dest_path in tqdm( - zip(local_paths, dest_paths), desc="Uploading files", total=len(local_paths) - ): - s3_client.upload_file(local_path, s3_bucket_name, dest_path) - elif parsed_url.scheme == "": - shutil.copytree(local_checkpoint_dir, destination_dir) - else: - raise ValueError(f"Unsupported destination: {destination_dir}. Only s3 and local paths are supported.") + + if (parsed_url := urlparse(destination_dir)).scheme == "s3": + return upload_s3_directory(local_checkpoint_dir, destination_dir) + + elif parsed_url.scheme == "gs": + return upload_gcs_directory(local_checkpoint_dir, destination_dir) + + raise ValueError(f"Unsupported protocol: {destination_dir}. Only s3://, gs://, and local are supported.") def maybe_unshard(checkpoint_dir: str): diff --git a/olmo/util.py b/olmo/util.py index aad77eb1c..694688926 100644 --- a/olmo/util.py +++ b/olmo/util.py @@ -313,6 +313,12 @@ def get_progress_bar() -> Progress: return get_download_progress() +def walk_local_path(path: PathOrStr, top_down=True, on_error=None, follow_symlinks=False): + """Necessary because Path.walk() was only added in python 3.12""" + for root, dirs, files in os.walk(path, topdown=top_down, onerror=on_error, followlinks=follow_symlinks): + yield Path(root), dirs, files + + def resource_path( folder: PathOrStr, fname: str, local_cache: Optional[PathOrStr] = None, progress: Optional[Progress] = None ) -> Path: @@ -503,6 +509,30 @@ def _get_s3_endpoint_url(scheme: str) -> Optional[str]: raise NotImplementedError(f"Cannot get endpoint url for scheme {scheme}") +@cache +def _get_gcs_client(): + from google.auth import default + from google.auth.credentials import TokenState + from google.auth.exceptions import DefaultCredentialsError + from google.cloud import storage as gcs + from google.oauth2 import service_account + + try: + credentials, _ = default() + if not getattr(credentials, "service_account_email", None): + raise DefaultCredentialsError("Cannot get GCS credentials") + if getattr(credentials, "token_state", None) != TokenState.FRESH: + raise DefaultCredentialsError("Cannot get GCS credentials") + except DefaultCredentialsError: + pass + + if credentials_path := os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", None): + credentials = service_account.Credentials.from_service_account_file(credentials_path) + return gcs.Client(credentials=credentials) + + raise DefaultCredentialsError("Cannot get GCS credentials") + + @cache def _get_s3_client(scheme: str): session = boto3.Session(profile_name=_get_s3_profile_name(scheme)) @@ -637,7 +667,11 @@ def _http_file_size(scheme: str, host_name: str, path: str) -> int: import requests response = requests.head(f"{scheme}://{host_name}/{path}", allow_redirects=True) - return int(response.headers.get("content-length")) + + if (content_length := response.headers.get("content-length")) is not None: + return int(content_length) + + raise OLMoNetworkError(f"Failed to get {scheme} file size") def _http_get_bytes_range(scheme: str, host_name: str, path: str, bytes_start: int, num_bytes: int) -> bytes: @@ -647,9 +681,10 @@ def _http_get_bytes_range(scheme: str, host_name: str, path: str, bytes_start: i f"{scheme}://{host_name}/{path}", headers={"Range": f"bytes={bytes_start}-{bytes_start+num_bytes-1}"} ) result = response.content - assert ( - len(result) == num_bytes - ), f"expected {num_bytes} bytes, got {len(result)}" # Some web servers silently ignore range requests and send everything + + # Some web servers silently ignore range requests and send everything + assert len(result) == num_bytes, f"expected {num_bytes} bytes, got {len(result)}" + return result From 6d1e26d55c85c65457259af705aa7838f0f66f0b Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Wed, 30 Oct 2024 20:27:31 -0700 Subject: [PATCH 05/11] simplified client --- hf_olmo/convert_olmo_to_hf.py | 21 ++++++++++++++++----- olmo/util.py | 19 +------------------ 2 files changed, 17 insertions(+), 23 deletions(-) diff --git a/hf_olmo/convert_olmo_to_hf.py b/hf_olmo/convert_olmo_to_hf.py index 74056333e..1406bdf75 100644 --- a/hf_olmo/convert_olmo_to_hf.py +++ b/hf_olmo/convert_olmo_to_hf.py @@ -17,9 +17,10 @@ from hf_olmo.modeling_olmo import OLMoForCausalLM from hf_olmo.tokenization_olmo_fast import OLMoTokenizerFast from olmo import ModelConfig, Tokenizer, TrainConfig +from olmo.aliases import PathOrStr from olmo.checkpoint import build_sharded_checkpointer from olmo.safetensors_util import safetensors_file_to_state_dict -from olmo.util import _get_gcs_client, _get_s3_client, walk_local_path +from olmo.util import _get_gcs_client, _get_s3_client logger = logging.getLogger(__name__) @@ -32,6 +33,12 @@ } +def walk_local_path(path: PathOrStr, top_down=True, on_error=None, follow_symlinks=False): + """Necessary because Path.walk() was only added in python 3.12""" + for root, dirs, files in os.walk(path, topdown=top_down, onerror=on_error, followlinks=follow_symlinks): + yield Path(root), dirs, files + + def longest_common_prefix(strs: Iterable[str]) -> str: """ Finds the longest common prefix among a list of strings. @@ -153,12 +160,14 @@ def download_gcs_directory(bucket_name: str, prefix: str, local_dir: str): path_local = Path(local_dir) path_prefix = Path(prefix) - gcs_client = _get_s3_client() + gcs_client = _get_gcs_client() bucket = gcs_client.bucket(bucket_name) path_local.mkdir(parents=True, exist_ok=True) - for elem in bucket.list_blobs(prefix=prefix): + files_to_download = list(bucket.list_blobs(prefix=prefix)) + + for elem in tqdm(files_to_download, desc="Downloading files from GCS"): local_destination = path_local / Path(elem.name).relative_to(path_prefix) local_destination.parent.mkdir(parents=True, exist_ok=True) elem.download_to_filename(local_destination) @@ -183,7 +192,7 @@ def download_s3_directory(bucket_name: str, prefix: str, local_dir: str, ignore: files_to_download.append(obj["Key"]) # Initialize the progress bar - for s3_key in tqdm(files_to_download, desc="Downloading files"): + for s3_key in tqdm(files_to_download, desc="Downloading files from S3"): # Construct the full local path local_file_path = os.path.join(local_dir, os.path.relpath(s3_key, prefix)) local_file_dir = os.path.dirname(local_file_path) @@ -265,9 +274,11 @@ def upload_gcs_directory(local_checkpoint_dir: str, destination_dir: str): Path(path / fn) for path, _, filenames in walk_local_path(local_checkpoint_path) for fn in filenames ] + bucket = gcs_client.bucket(bucket_name) + for local_path in tqdm(local_paths, desc="Uploading files to GCS"): destination = prefix / local_path.relative_to(local_checkpoint_path) - blob = gcs_client.bucket(bucket_name).blob(str(destination)) + blob = bucket.blob(str(destination)) blob.upload_from_filename(local_path) diff --git a/olmo/util.py b/olmo/util.py index 694688926..bd7c6504a 100644 --- a/olmo/util.py +++ b/olmo/util.py @@ -511,26 +511,9 @@ def _get_s3_endpoint_url(scheme: str) -> Optional[str]: @cache def _get_gcs_client(): - from google.auth import default - from google.auth.credentials import TokenState - from google.auth.exceptions import DefaultCredentialsError from google.cloud import storage as gcs - from google.oauth2 import service_account - try: - credentials, _ = default() - if not getattr(credentials, "service_account_email", None): - raise DefaultCredentialsError("Cannot get GCS credentials") - if getattr(credentials, "token_state", None) != TokenState.FRESH: - raise DefaultCredentialsError("Cannot get GCS credentials") - except DefaultCredentialsError: - pass - - if credentials_path := os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", None): - credentials = service_account.Credentials.from_service_account_file(credentials_path) - return gcs.Client(credentials=credentials) - - raise DefaultCredentialsError("Cannot get GCS credentials") + return gcs.Client() @cache From 646750488fed923bddf0b72ab7b20afc3cba80ca Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Wed, 30 Oct 2024 20:29:05 -0700 Subject: [PATCH 06/11] type --- hf_olmo/convert_olmo_to_hf.py | 2 +- olmo/util.py | 6 ------ 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/hf_olmo/convert_olmo_to_hf.py b/hf_olmo/convert_olmo_to_hf.py index 1406bdf75..1b6b9ef0d 100644 --- a/hf_olmo/convert_olmo_to_hf.py +++ b/hf_olmo/convert_olmo_to_hf.py @@ -35,7 +35,7 @@ def walk_local_path(path: PathOrStr, top_down=True, on_error=None, follow_symlinks=False): """Necessary because Path.walk() was only added in python 3.12""" - for root, dirs, files in os.walk(path, topdown=top_down, onerror=on_error, followlinks=follow_symlinks): + for root, dirs, files in os.walk(str(path), topdown=top_down, onerror=on_error, followlinks=follow_symlinks): yield Path(root), dirs, files diff --git a/olmo/util.py b/olmo/util.py index bd7c6504a..5ba85a9a6 100644 --- a/olmo/util.py +++ b/olmo/util.py @@ -313,12 +313,6 @@ def get_progress_bar() -> Progress: return get_download_progress() -def walk_local_path(path: PathOrStr, top_down=True, on_error=None, follow_symlinks=False): - """Necessary because Path.walk() was only added in python 3.12""" - for root, dirs, files in os.walk(path, topdown=top_down, onerror=on_error, followlinks=follow_symlinks): - yield Path(root), dirs, files - - def resource_path( folder: PathOrStr, fname: str, local_cache: Optional[PathOrStr] = None, progress: Optional[Progress] = None ) -> Path: From d4a123049a5a22624e3532ac8a50446b0dbbf24d Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Mon, 4 Nov 2024 11:10:13 -0800 Subject: [PATCH 07/11] Create olmo_soup.py --- scripts/olmo_soup.py | 119 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 scripts/olmo_soup.py diff --git a/scripts/olmo_soup.py b/scripts/olmo_soup.py new file mode 100644 index 000000000..83c63de9a --- /dev/null +++ b/scripts/olmo_soup.py @@ -0,0 +1,119 @@ +''' +Soups OLMo checkpoints. + +Example usage: + +```bash + python scripts/olmo_soup.py -c \ + /weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-moremath-dclm07-fw2-se-flan/step11931 \ + /weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-moremath-dclm07-fw2-se-flan-seed2/step11931 \ + /weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-moremath-dclm07-fw2-se-flan-seed3/step11931 \ + -o /weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-moremath-dclm07-fw2-se-flan-soup/step11931 +``` + +Author: Luca Soldaini (@soldni) + +''' # noqa + + +import argparse +from enum import Enum +from pathlib import Path + +import torch +from tqdm import tqdm +from olmo.config import TrainConfig +from olmo.checkpoint import build_sharded_checkpointer +from olmo.safetensors_util import safetensors_file_to_state_dict + + +class SoupType(Enum): + uniform = "uniform" + + +def load_checkpoint(path: Path) -> dict[str, torch.Tensor]: + if path.exists() and path.is_file(): + return torch.load(path, map_location="cpu", weights_only=True) + + if (path / "model.pt").exists(): + return torch.load(path / "model.pt", map_location="cpu", weights_only=True) + + if (path / "model.safetensors").exists(): + safetensors_file_to_state_dict(path / "model.safetensors") + + if (path / "model").exists() and (config_path := (path / "config.yaml")).exists(): + train_config = TrainConfig.load(config_path) + checkpointer = build_sharded_checkpointer(train_config) + model_state, _, _ = checkpointer.unshard_checkpoint( + load_path=str(path), load_optimizer_state=False, load_trainer_state=False + ) + return model_state + + raise FileNotFoundError(f"Could not find checkpoint in {path}") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Soup OLMo checkponts") + parser.add_argument( + "-c", "--checkpoints", + type=Path, + required=True, + nargs="+", + help="Path to checkpoint(s) to soup", + ) + parser.add_argument( + "-s", "--soup-type", + type=SoupType, + default=SoupType.uniform, + help=f"Methods for checkpoint souping. Choose from: {', '.join(SoupType.__members__.keys())}", + ) + parser.add_argument( + "-o", "--output", + type=Path, + required=True, + help="Path to save the souped checkpoint", + ) + opts = parser.parse_args() + return opts + + +def main(): + args = parse_args() + + checkpoint_average: dict[str, torch.Tensor] = {} + + for path in tqdm(args.checkpoints, desc="Loading checkpoints", position=0): + state_dict = load_checkpoint(path) + + if len(checkpoint_average) == 0: + # initialize checkpoint_average with zeros + checkpoint_average = {k: torch.zeros_like(v) for k, v in state_dict.items()} + + if ( + any(k not in state_dict for k in checkpoint_average.keys()) or + any(k not in checkpoint_average for k in state_dict.keys()) + ): + raise ValueError(f"Checkpoint {path} has different keys") + + for k in tqdm(state_dict, desc="Summing checkpoints", position=1): + if state_dict[k].shape != checkpoint_average[k].shape: + raise ValueError(f"Checkpoint {path} has different shape for key {k}") + checkpoint_average[k] += state_dict[k] / len(args.checkpoints) + + # free memory + del state_dict + + print(f"Saving averaged checkpoint to {args.output}") + # save the averaged checkpoint + args.output.mkdir(parents=True, exist_ok=True) + torch.save(checkpoint_average, args.output / "model.pt") + + print("Copying config.yaml") + # copy the config file + if (config_path := args.checkpoints[0] / "config.yaml").exists(): + config_path.rename(args.output / "config.yaml") + print("Done!") + + +if __name__ == "__main__": + main() From 133e2ec56916d2d4165193229b4c654fd10b362d Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Mon, 4 Nov 2024 13:20:03 -0800 Subject: [PATCH 08/11] style --- scripts/olmo_soup.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/scripts/olmo_soup.py b/scripts/olmo_soup.py index 83c63de9a..3e995b39d 100644 --- a/scripts/olmo_soup.py +++ b/scripts/olmo_soup.py @@ -1,4 +1,4 @@ -''' +""" Soups OLMo checkpoints. Example usage: @@ -13,7 +13,7 @@ Author: Luca Soldaini (@soldni) -''' # noqa +""" # noqa import argparse @@ -22,8 +22,9 @@ import torch from tqdm import tqdm -from olmo.config import TrainConfig + from olmo.checkpoint import build_sharded_checkpointer +from olmo.config import TrainConfig from olmo.safetensors_util import safetensors_file_to_state_dict @@ -55,20 +56,23 @@ def load_checkpoint(path: Path) -> dict[str, torch.Tensor]: def parse_args(): parser = argparse.ArgumentParser(description="Soup OLMo checkponts") parser.add_argument( - "-c", "--checkpoints", + "-c", + "--checkpoints", type=Path, required=True, nargs="+", help="Path to checkpoint(s) to soup", ) parser.add_argument( - "-s", "--soup-type", + "-s", + "--soup-type", type=SoupType, default=SoupType.uniform, help=f"Methods for checkpoint souping. Choose from: {', '.join(SoupType.__members__.keys())}", ) parser.add_argument( - "-o", "--output", + "-o", + "--output", type=Path, required=True, help="Path to save the souped checkpoint", @@ -89,9 +93,8 @@ def main(): # initialize checkpoint_average with zeros checkpoint_average = {k: torch.zeros_like(v) for k, v in state_dict.items()} - if ( - any(k not in state_dict for k in checkpoint_average.keys()) or - any(k not in checkpoint_average for k in state_dict.keys()) + if any(k not in state_dict for k in checkpoint_average.keys()) or any( + k not in checkpoint_average for k in state_dict.keys() ): raise ValueError(f"Checkpoint {path} has different keys") From 2c8ccc95abd54b61e3872ee3958b489bd0ce7def Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Mon, 11 Nov 2024 13:26:32 -0800 Subject: [PATCH 09/11] tokenizer --- hf_olmo/convert_olmo_to_hf.py | 154 +++++++++++++++++++++------------- 1 file changed, 96 insertions(+), 58 deletions(-) diff --git a/hf_olmo/convert_olmo_to_hf.py b/hf_olmo/convert_olmo_to_hf.py index 1b6b9ef0d..e00f4cbe0 100644 --- a/hf_olmo/convert_olmo_to_hf.py +++ b/hf_olmo/convert_olmo_to_hf.py @@ -1,5 +1,4 @@ import argparse -import logging import os import re import shutil @@ -22,7 +21,6 @@ from olmo.safetensors_util import safetensors_file_to_state_dict from olmo.util import _get_gcs_client, _get_s3_client -logger = logging.getLogger(__name__) HF_FILENAMES = { "config.json", @@ -57,22 +55,26 @@ def longest_common_prefix(strs: Iterable[str]) -> str: return shortest_str -def write_config(checkpoint_dir: str): +def write_config(checkpoint_dir: str, destination_dir: str): # save config as HF config - logger.info(f"Loading checkpoint from {checkpoint_dir}") + print(f"Loading checkpoint from {checkpoint_dir}") + + if os.path.exists(os.path.join(destination_dir, "config.yaml")): + config_path = os.path.join(destination_dir, "config.yaml") + else: + config_path = os.path.join(checkpoint_dir, "config.yaml") - config_path = os.path.join(checkpoint_dir, "config.yaml") model_config = ModelConfig.load(config_path, key="model") config_kwargs = model_config.asdict() config_kwargs["use_cache"] = True config = OLMoConfig(**config_kwargs) - logger.info(f"Saving HF-compatible config to {os.path.join(checkpoint_dir, 'config.json')}") - config.save_pretrained(checkpoint_dir) + print(f"Saving HF-compatible config to {os.path.join(destination_dir, 'config.json')}") + config.save_pretrained(destination_dir) -def write_model(checkpoint_dir: str, ignore_olmo_compatibility: bool = False): +def write_model(checkpoint_dir: str, destination_dir: str, ignore_olmo_compatibility: bool = False): # For device_map = "auto", etc. the models are loaded in a way that start_prefix is not computed correctly. # So, we explicitly store the model with the expected prefix. @@ -83,7 +85,7 @@ def write_model(checkpoint_dir: str, ignore_olmo_compatibility: bool = False): else: raise ValueError(f"No model found in {checkpoint_dir}") - new_model_path = os.path.join(checkpoint_dir, "pytorch_model.bin") + new_model_path = os.path.join(destination_dir, "pytorch_model.bin") # this takes care of the case where the model was saved with a different prefix, # typically due to unsharding. @@ -98,7 +100,7 @@ def write_model(checkpoint_dir: str, ignore_olmo_compatibility: bool = False): os.remove(old_model_path) -def write_tokenizer(checkpoint_dir: str): +def write_tokenizer(checkpoint_dir: str, destination_dir: str): tokenizer_raw = Tokenizer.from_checkpoint(checkpoint_dir) tokenizer = OLMoTokenizerFast( tokenizer_object=tokenizer_raw.base_tokenizer, @@ -109,33 +111,37 @@ def write_tokenizer(checkpoint_dir: str): tokenizer.model_input_names = ["input_ids", "attention_mask"] tokenizer.pad_token_id = tokenizer_raw.pad_token_id tokenizer.eos_token_id = tokenizer_raw.eos_token_id - - tokenizer.save_pretrained(checkpoint_dir) + tokenizer.save_pretrained(destination_dir) -def convert_checkpoint(checkpoint_dir: str, ignore_olmo_compatibility: bool = False): +def convert_checkpoint(checkpoint_dir: str, destination_dir: str, ignore_olmo_compatibility: bool = False): print("Converting checkpoint to HF format...") - write_config(checkpoint_dir) + write_config(checkpoint_dir=checkpoint_dir, destination_dir=destination_dir) print("Saving model to checkpoint...") - write_model(checkpoint_dir, ignore_olmo_compatibility=ignore_olmo_compatibility) + write_model( + checkpoint_dir=checkpoint_dir, + destination_dir=destination_dir, + ignore_olmo_compatibility=ignore_olmo_compatibility + ) print("Saving tokenizer to checkpoint...") - write_tokenizer(checkpoint_dir) + write_tokenizer(checkpoint_dir=checkpoint_dir, destination_dir=destination_dir) # Cannot remove it before writing the tokenizer if ignore_olmo_compatibility: - os.remove(os.path.join(checkpoint_dir, "config.yaml")) + os.remove(os.path.join(destination_dir, "config.yaml")) -def fix_tokenizer(checkpoint_dir: str, tokenizer_name_or_path: Optional[str] = None): - path = os.path.join(checkpoint_dir, "config.yaml") - conf = om.load(path) +def fix_tokenizer(checkpoint_dir: str, destination_dir: str, tokenizer_name_or_path: Optional[str] = None): + Path(destination_dir).mkdir(parents=True, exist_ok=True) - print("Saving tokenizer to checkpoint...") + source_path = os.path.join(checkpoint_dir, "config.yaml") + dest_path = os.path.join(destination_dir, "config.yaml") + conf = om.load(source_path) + print(f"Saving saving new tokenizer configuration to {dest_path}") tokenizer_name_or_path = str(tokenizer_name_or_path or conf["tokenizer"]["identifier"]) # pyright: ignore - try: if os.path.exists(tokenizer_name_or_path): Tokenizer.from_file(tokenizer_name_or_path) @@ -143,7 +149,7 @@ def fix_tokenizer(checkpoint_dir: str, tokenizer_name_or_path: Optional[str] = N Tokenizer.from_pretrained(tokenizer_name_or_path) except Exception as e: # the tokenizer is not valid - logger.error(f"Invalid tokenizer: {tokenizer_name_or_path}. Error: {e}") + print(f"Invalid tokenizer: {tokenizer_name_or_path}. Error: {e}") raise e conf["tokenizer"]["identifier"] = tokenizer_name_or_path # pyright: ignore @@ -153,7 +159,7 @@ def fix_tokenizer(checkpoint_dir: str, tokenizer_name_or_path: Optional[str] = N ): conf["model"]["eos_token_id"] = 50279 # pyright: ignore - om.save(conf, path) + om.save(conf, dest_path) def download_gcs_directory(bucket_name: str, prefix: str, local_dir: str): @@ -208,7 +214,7 @@ def download_s3_directory(bucket_name: str, prefix: str, local_dir: str, ignore: def make_local_checkpoint(checkpoint_dir: str) -> str: parsed_dir = urlparse(checkpoint_dir) - assert parsed_dir.scheme in ["s3", "gs", ""], "Only s3 and local paths are supported." + assert parsed_dir.scheme in ["s3", "gs", "", "file"], "Only s3, gcs, and local paths are supported." if os.path.exists(checkpoint_dir): return checkpoint_dir @@ -236,7 +242,7 @@ def make_local_checkpoint(checkpoint_dir: str) -> str: else: raise ValueError(f"Unsupported: {checkpoint_dir}. Only s3://, gs://, and local are supported.") except Exception as e: - logger.error(f"Error downloading checkpoint: {e}") + print(f"Error downloading checkpoint: {e}") shutil.rmtree(temp_dir) raise e @@ -292,13 +298,32 @@ def upload_local_checkpoint(local_checkpoint_dir: str, destination_dir: str): elif parsed_url.scheme == "gs": return upload_gcs_directory(local_checkpoint_dir, destination_dir) + # if parsed_url.scheme in ("file", ""): + + breakpoint() + raise ValueError(f"Unsupported protocol: {destination_dir}. Only s3://, gs://, and local are supported.") -def maybe_unshard(checkpoint_dir: str): - if os.path.exists(os.path.join(checkpoint_dir, "model.pt")) or os.path.exists( - os.path.join(checkpoint_dir, "model.safetensors") - ): +def maybe_unshard(checkpoint_dir: str, destination_dir: str): + if os.path.exists(os.path.join(checkpoint_dir, "model.pt")): + # copy the model.pt to the destination directory + if checkpoint_dir != destination_dir: + print("Copying model.pt to destination directory...") + shutil.copy(os.path.join(checkpoint_dir, "model.pt"), os.path.join(destination_dir, "model.pt")) + + print("model.pt found; skipping unsharding.") + return + + if os.path.exists(os.path.join(checkpoint_dir, "model.safetensors")): + # copy the model.safetensors to the destination directory + if checkpoint_dir != destination_dir: + print("Copying model.safetensors to destination directory...") + shutil.copy( + os.path.join(checkpoint_dir, "model.safetensors"), + os.path.join(destination_dir, "model.safetensors") + ) + print("model.savetensors found; skipping unsharding.") return print(f"Unsharding {checkpoint_dir}...") @@ -333,12 +358,6 @@ def main(): help="Ignore compatibility with the olmo codebase. " "This will remove files that are needed specifically for olmo codebase, eg. config.yaml, etc.", ) - parser.add_argument( - "--logger-level", - default="warning", - help="Set the logger level.", - ) - parser.add_argument( "--tokenizer", help="Override the tokenizer to use for the checkpoint.", @@ -350,29 +369,48 @@ def main(): ) args = parser.parse_args() + local_destination_dir = args.destination_dir or args.checkpoint_dir - args.destination_dir = args.destination_dir or args.checkpoint_dir - logging.basicConfig() - logger.setLevel(logging.getLevelName(args.logger_level.upper())) - - local_checkpoint_dir = make_local_checkpoint(args.checkpoint_dir) - args.checkpoint_dir = local_checkpoint_dir - maybe_unshard(local_checkpoint_dir) - - fix_tokenizer(checkpoint_dir=local_checkpoint_dir, tokenizer_name_or_path=args.tokenizer) - convert_checkpoint(args.checkpoint_dir, args.ignore_olmo_compatibility) - - if not args.keep_olmo_artifacts: - print("Removing non-HF artifacts...") - os.remove(os.path.join(local_checkpoint_dir, "config.yaml")) - os.remove(os.path.join(local_checkpoint_dir, "model.pt")) - shutil.rmtree(os.path.join(local_checkpoint_dir, "optim"), ignore_errors=True) - shutil.rmtree(os.path.join(local_checkpoint_dir, "model"), ignore_errors=True) - shutil.rmtree(os.path.join(local_checkpoint_dir, "train"), ignore_errors=True) - - upload_local_checkpoint(local_checkpoint_dir, args.destination_dir) - - print(f"Converted checkpoint saved to {args.destination_dir}") + try: + local_checkpoint_dir = make_local_checkpoint(args.checkpoint_dir) + + if local_checkpoint_dir != args.checkpoint_dir: + # if using a remote checkpoint, save the converted checkpoint locally + print("Remote checkpoint; using local directory as destination.") + local_destination_dir = local_checkpoint_dir + + Path(args.destination_dir).mkdir(parents=True, exist_ok=True) + maybe_unshard(checkpoint_dir=local_checkpoint_dir, destination_dir=local_destination_dir) + + fix_tokenizer( + checkpoint_dir=local_checkpoint_dir, + destination_dir=local_destination_dir, + tokenizer_name_or_path=args.tokenizer + ) + + convert_checkpoint( + checkpoint_dir=args.checkpoint_dir, + destination_dir=local_destination_dir, + ignore_olmo_compatibility=args.ignore_olmo_compatibility + ) + + if not args.keep_olmo_artifacts: + print("Removing non-HF artifacts...") + os.remove(os.path.join(local_checkpoint_dir, "config.yaml")) + os.remove(os.path.join(local_checkpoint_dir, "model.pt")) + shutil.rmtree(os.path.join(local_checkpoint_dir, "optim"), ignore_errors=True) + shutil.rmtree(os.path.join(local_checkpoint_dir, "model"), ignore_errors=True) + shutil.rmtree(os.path.join(local_checkpoint_dir, "train"), ignore_errors=True) + + upload_local_checkpoint(local_destination_dir, args.destination_dir) + + print(f"Converted checkpoint saved to {args.destination_dir}") + except Exception as e: + print(f"Error converting checkpoint: {e}") + if args.checkpoint_dir != local_destination_dir: + print("Removing partially converted checkpoint...") + shutil.rmtree(args.destination_dir) + raise e if __name__ == "__main__": From 5fae50e5e68f29355cfd9b38b3872c20693f9bb0 Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Mon, 11 Nov 2024 15:21:40 -0800 Subject: [PATCH 10/11] fixed souping moving yaml --- scripts/olmo_soup.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/scripts/olmo_soup.py b/scripts/olmo_soup.py index 3e995b39d..9f99edf2e 100644 --- a/scripts/olmo_soup.py +++ b/scripts/olmo_soup.py @@ -11,6 +11,24 @@ -o /weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-moremath-dclm07-fw2-se-flan-soup/step11931 ``` +Merge any three checkpoints out of five: + +```bash +for i in $(seq 1 5); do + for j in $(seq $((i+1)) 5); do + for k in $(seq $((j+1)) 5); do + echo "Merging $i $j $k" + python scripts/olmo_soup.py -c \ + /weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-moremath-dclm07-fw2-se-flan-seed$i/step11931 \ + /weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-moremath-dclm07-fw2-se-flan-seed$j/step11931 \ + /weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-moremath-dclm07-fw2-se-flan-seed$k/step11931 \ + -o /weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-moremath-dclm07-fw2-se-flan-soup-$i$j$k/step11931 + done + done +done +``` + + Author: Luca Soldaini (@soldni) """ # noqa @@ -114,7 +132,8 @@ def main(): print("Copying config.yaml") # copy the config file if (config_path := args.checkpoints[0] / "config.yaml").exists(): - config_path.rename(args.output / "config.yaml") + with open(config_path, "r") as src_f, open(args.output / "config.yaml", "w") as dst_f: + dst_f.write(src_f.read()) print("Done!") From 7fdd21f9c9b2c07a38b6bf6f4871526b6ec48d15 Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Mon, 11 Nov 2024 17:17:55 -0800 Subject: [PATCH 11/11] instr --- scripts/olmo_soup.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/scripts/olmo_soup.py b/scripts/olmo_soup.py index 9f99edf2e..64e037ede 100644 --- a/scripts/olmo_soup.py +++ b/scripts/olmo_soup.py @@ -11,24 +11,6 @@ -o /weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-moremath-dclm07-fw2-se-flan-soup/step11931 ``` -Merge any three checkpoints out of five: - -```bash -for i in $(seq 1 5); do - for j in $(seq $((i+1)) 5); do - for k in $(seq $((j+1)) 5); do - echo "Merging $i $j $k" - python scripts/olmo_soup.py -c \ - /weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-moremath-dclm07-fw2-se-flan-seed$i/step11931 \ - /weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-moremath-dclm07-fw2-se-flan-seed$j/step11931 \ - /weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-moremath-dclm07-fw2-se-flan-seed$k/step11931 \ - -o /weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-moremath-dclm07-fw2-se-flan-soup-$i$j$k/step11931 - done - done -done -``` - - Author: Luca Soldaini (@soldni) """ # noqa