diff --git a/CHANGELOG.md b/CHANGELOG.md
index b73eeae96..9f0472818 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## Unreleased
 
+- Added support for safetensors in `hf_olmo` conversion script.
+
 ## [v0.5.1](https://github.com/allenai/OLMo/releases/tag/v0.5.1) - 2024-10-17
 
 ### Added
@@ -45,7 +47,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Swapped in correct flan data mix.
 - Fix bug where the attention norm, when applied before the attention block, was modifying the residual stream.
 - Fixed `OLMo.from_checkpoint()` so that it correctly loads `olmo_core` and `torch_new` style checkpoints.
-- Fixed `preserve_rng_state` being incorrectly set to False when doing gradient checkpointing with dropout 
+- Fixed `preserve_rng_state` being incorrectly set to False when doing gradient checkpointing with dropout
 
 ## [v0.4.0](https://github.com/allenai/OLMo/releases/tag/v0.4.0) - 2024-07-11
 
diff --git a/hf_olmo/convert_olmo_to_hf.py b/hf_olmo/convert_olmo_to_hf.py
index 2e0a9e074..e00f4cbe0 100644
--- a/hf_olmo/convert_olmo_to_hf.py
+++ b/hf_olmo/convert_olmo_to_hf.py
@@ -1,10 +1,10 @@
 import argparse
-import logging
 import os
 import re
 import shutil
 import tempfile
 from hashlib import md5
+from pathlib import Path
 from typing import Iterable, Optional
 from urllib.parse import urlparse
 
@@ -16,10 +16,11 @@
 from hf_olmo.modeling_olmo import OLMoForCausalLM
 from hf_olmo.tokenization_olmo_fast import OLMoTokenizerFast
 from olmo import ModelConfig, Tokenizer, TrainConfig
+from olmo.aliases import PathOrStr
 from olmo.checkpoint import build_sharded_checkpointer
-from olmo.util import _get_s3_client
+from olmo.safetensors_util import safetensors_file_to_state_dict
+from olmo.util import _get_gcs_client, _get_s3_client
 
-logger = logging.getLogger(__name__)
 
 HF_FILENAMES = {
     "config.json",
@@ -30,6 +31,12 @@
 }
 
 
+def walk_local_path(path: PathOrStr, top_down=True, on_error=None, follow_symlinks=False):
+    """Necessary because Path.walk() was only added in python 3.12"""
+    for root, dirs, files in os.walk(str(path), topdown=top_down, onerror=on_error, followlinks=follow_symlinks):
+        yield Path(root), dirs, files
+
+
 def longest_common_prefix(strs: Iterable[str]) -> str:
     """
     Finds the longest common prefix among a list of strings.
@@ -48,29 +55,37 @@ def longest_common_prefix(strs: Iterable[str]) -> str:
     return shortest_str
 
 
-def write_config(checkpoint_dir: str):
+def write_config(checkpoint_dir: str, destination_dir: str):
     # save config as HF config
-    logger.info(f"Loading checkpoint from {checkpoint_dir}")
+    print(f"Loading checkpoint from {checkpoint_dir}")
+
+    if os.path.exists(os.path.join(destination_dir, "config.yaml")):
+        config_path = os.path.join(destination_dir, "config.yaml")
+    else:
+        config_path = os.path.join(checkpoint_dir, "config.yaml")
 
-    config_path = os.path.join(checkpoint_dir, "config.yaml")
     model_config = ModelConfig.load(config_path, key="model")
     config_kwargs = model_config.asdict()
     config_kwargs["use_cache"] = True
     config = OLMoConfig(**config_kwargs)
 
-    logger.info(f"Saving HF-compatible config to {os.path.join(checkpoint_dir, 'config.json')}")
-    config.save_pretrained(checkpoint_dir)
+    print(f"Saving HF-compatible config to {os.path.join(destination_dir, 'config.json')}")
+    config.save_pretrained(destination_dir)
 
 
-def write_model(checkpoint_dir: str, ignore_olmo_compatibility: bool = False):
+def write_model(checkpoint_dir: str, destination_dir: str, ignore_olmo_compatibility: bool = False):
     # For device_map = "auto", etc. the models are loaded in a way that start_prefix is not computed correctly.
     # So, we explicitly store the model with the expected prefix.
-    old_model_path = os.path.join(checkpoint_dir, "model.pt")
-    new_model_path = os.path.join(checkpoint_dir, "pytorch_model.bin")
+    if os.path.exists(old_model_path := os.path.join(checkpoint_dir, "model.pt")):
+        state_dict = torch.load(old_model_path, map_location="cpu")
+    elif os.path.exists(old_model_path := os.path.join(checkpoint_dir, "model.safetensors")):
+        state_dict = safetensors_file_to_state_dict(old_model_path, map_location="cpu")
+    else:
+        raise ValueError(f"No model found in {checkpoint_dir}")
 
-    state_dict = torch.load(old_model_path, map_location="cpu")
+    new_model_path = os.path.join(destination_dir, "pytorch_model.bin")
 
     # this takes care of the case where the model was saved with a different prefix,
     # typically due to unsharding.
@@ -85,7 +100,7 @@ def write_model(checkpoint_dir: str, ignore_olmo_compatibility: bool = False):
     os.remove(old_model_path)
 
 
-def write_tokenizer(checkpoint_dir: str):
+def write_tokenizer(checkpoint_dir: str, destination_dir: str):
     tokenizer_raw = Tokenizer.from_checkpoint(checkpoint_dir)
     tokenizer = OLMoTokenizerFast(
         tokenizer_object=tokenizer_raw.base_tokenizer,
@@ -96,33 +111,37 @@ def write_tokenizer(checkpoint_dir: str):
     tokenizer.model_input_names = ["input_ids", "attention_mask"]
     tokenizer.pad_token_id = tokenizer_raw.pad_token_id
     tokenizer.eos_token_id = tokenizer_raw.eos_token_id
-
-    tokenizer.save_pretrained(checkpoint_dir)
+    tokenizer.save_pretrained(destination_dir)
 
 
-def convert_checkpoint(checkpoint_dir: str, ignore_olmo_compatibility: bool = False):
+def convert_checkpoint(checkpoint_dir: str, destination_dir: str, ignore_olmo_compatibility: bool = False):
     print("Converting checkpoint to HF format...")
-    write_config(checkpoint_dir)
+    write_config(checkpoint_dir=checkpoint_dir, destination_dir=destination_dir)
 
     print("Saving model to checkpoint...")
-    write_model(checkpoint_dir, ignore_olmo_compatibility=ignore_olmo_compatibility)
+    write_model(
+        checkpoint_dir=checkpoint_dir,
+        destination_dir=destination_dir,
+        ignore_olmo_compatibility=ignore_olmo_compatibility
+    )
 
     print("Saving tokenizer to checkpoint...")
-    write_tokenizer(checkpoint_dir)
+    write_tokenizer(checkpoint_dir=checkpoint_dir, destination_dir=destination_dir)
 
     # Cannot remove it before writing the tokenizer
     if ignore_olmo_compatibility:
-        os.remove(os.path.join(checkpoint_dir, "config.yaml"))
+        os.remove(os.path.join(destination_dir, "config.yaml"))
 
 
-def fix_tokenizer(checkpoint_dir: str, tokenizer_name_or_path: Optional[str] = None):
-    path = os.path.join(checkpoint_dir, "config.yaml")
-    conf = om.load(path)
+def fix_tokenizer(checkpoint_dir: str, destination_dir: str, tokenizer_name_or_path: Optional[str] = None):
+    Path(destination_dir).mkdir(parents=True, exist_ok=True)
 
-    print("Saving tokenizer to checkpoint...")
+    source_path = os.path.join(checkpoint_dir, "config.yaml")
+    dest_path = os.path.join(destination_dir, "config.yaml")
+    conf = om.load(source_path)
 
+    print(f"Saving new tokenizer configuration to {dest_path}")
     tokenizer_name_or_path = str(tokenizer_name_or_path or conf["tokenizer"]["identifier"])  # pyright: ignore
-
     try:
         if os.path.exists(tokenizer_name_or_path):
             Tokenizer.from_file(tokenizer_name_or_path)
@@ -130,7 +149,7 @@ def fix_tokenizer(checkpoint_dir: str, tokenizer_name_or_path: Optional[str] = N
             Tokenizer.from_pretrained(tokenizer_name_or_path)
     except Exception as e:
         # the tokenizer is not valid
-        logger.error(f"Invalid tokenizer: {tokenizer_name_or_path}. Error: {e}")
+        print(f"Invalid tokenizer: {tokenizer_name_or_path}. Error: {e}")
Error: {e}") raise e conf["tokenizer"]["identifier"] = tokenizer_name_or_path # pyright: ignore @@ -140,7 +159,24 @@ def fix_tokenizer(checkpoint_dir: str, tokenizer_name_or_path: Optional[str] = N ): conf["model"]["eos_token_id"] = 50279 # pyright: ignore - om.save(conf, path) + om.save(conf, dest_path) + + +def download_gcs_directory(bucket_name: str, prefix: str, local_dir: str): + path_local = Path(local_dir) + path_prefix = Path(prefix) + + gcs_client = _get_gcs_client() + bucket = gcs_client.bucket(bucket_name) + + path_local.mkdir(parents=True, exist_ok=True) + + files_to_download = list(bucket.list_blobs(prefix=prefix)) + + for elem in tqdm(files_to_download, desc="Downloading files from GCS"): + local_destination = path_local / Path(elem.name).relative_to(path_prefix) + local_destination.parent.mkdir(parents=True, exist_ok=True) + elem.download_to_filename(local_destination) def download_s3_directory(bucket_name: str, prefix: str, local_dir: str, ignore: str | None = None): @@ -162,7 +198,7 @@ def download_s3_directory(bucket_name: str, prefix: str, local_dir: str, ignore: files_to_download.append(obj["Key"]) # Initialize the progress bar - for s3_key in tqdm(files_to_download, desc="Downloading files"): + for s3_key in tqdm(files_to_download, desc="Downloading files from S3"): # Construct the full local path local_file_path = os.path.join(local_dir, os.path.relpath(s3_key, prefix)) local_file_dir = os.path.dirname(local_file_path) @@ -178,7 +214,7 @@ def download_s3_directory(bucket_name: str, prefix: str, local_dir: str, ignore: def make_local_checkpoint(checkpoint_dir: str) -> str: parsed_dir = urlparse(checkpoint_dir) - assert parsed_dir.scheme in ["s3", ""], "Only s3 and local paths are supported." + assert parsed_dir.scheme in ["s3", "gs", "", "file"], "Only s3, gcs, and local paths are supported." if os.path.exists(checkpoint_dir): return checkpoint_dir @@ -189,51 +225,105 @@ def make_local_checkpoint(checkpoint_dir: str) -> str: try: os.makedirs(temp_dir, exist_ok=True) print(f"Downloading checkpoint to {temp_dir}...") - download_s3_directory( - bucket_name=parsed_dir.netloc, - prefix=parsed_dir.path.lstrip("/"), - local_dir=temp_dir, - ignore=r"/(optim|train)/", - ) + + if parsed_dir.scheme == "gs": + download_gcs_directory( + bucket_name=parsed_dir.netloc, + prefix=parsed_dir.path.lstrip("/"), + local_dir=temp_dir, + ) + elif parsed_dir.scheme == "s3": + download_s3_directory( + bucket_name=parsed_dir.netloc, + prefix=parsed_dir.path.lstrip("/"), + local_dir=temp_dir, + ignore=r"/(optim|train)/", + ) + else: + raise ValueError(f"Unsupported: {checkpoint_dir}. Only s3://, gs://, and local are supported.") except Exception as e: - logger.error(f"Error downloading checkpoint: {e}") + print(f"Error downloading checkpoint: {e}") shutil.rmtree(temp_dir) raise e return temp_dir +def upload_s3_directory(local_checkpoint_dir: str, destination_dir: str): + parsed_destination = urlparse(destination_dir) + if parsed_destination.scheme != "s3": + raise ValueError(f"Unsupported destination: {destination_dir}. 
+
+    s3_client = _get_s3_client("s3")
+    s3_bucket_name = parsed_destination.netloc
+    s3_prefix = Path(parsed_destination.path.lstrip("/"))
+    local_checkpoint_path = Path(local_checkpoint_dir)
+    local_paths = [
+        Path(path / fn) for path, _, filenames in walk_local_path(local_checkpoint_path) for fn in filenames
+    ]
+
+    for local_path in tqdm(local_paths, desc="Uploading files to S3"):
+        destination = s3_prefix / local_path.relative_to(local_checkpoint_path)
+        s3_client.upload_file(str(local_path), s3_bucket_name, str(destination))
+
+
+def upload_gcs_directory(local_checkpoint_dir: str, destination_dir: str):
+    parsed_destination = urlparse(destination_dir)
+    if parsed_destination.scheme != "gs":
+        raise ValueError(f"Unsupported destination: {destination_dir}. Only gs paths are supported.")
+
+    gcs_client = _get_gcs_client()
+    bucket_name = parsed_destination.netloc
+    prefix = Path(parsed_destination.path.lstrip("/"))
+    local_checkpoint_path = Path(local_checkpoint_dir)
+    local_paths = [
+        Path(path / fn) for path, _, filenames in walk_local_path(local_checkpoint_path) for fn in filenames
+    ]
+
+    bucket = gcs_client.bucket(bucket_name)
+
+    for local_path in tqdm(local_paths, desc="Uploading files to GCS"):
+        destination = prefix / local_path.relative_to(local_checkpoint_path)
+        blob = bucket.blob(str(destination))
+        blob.upload_from_filename(local_path)
+
+
 def upload_local_checkpoint(local_checkpoint_dir: str, destination_dir: str):
     if destination_dir == local_checkpoint_dir:
         return
-    elif (parsed_url := urlparse(destination_dir)).scheme == "s3":
-        s3_bucket_name = parsed_url.netloc
-        s3_prefix = parsed_url.path[1:]
-
-        local_paths = [
-            os.path.join(root, post_fn)
-            for root, _, files in os.walk(local_checkpoint_dir)
-            for post_fn in files
-            if os.path.basename(post_fn) in HF_FILENAMES
-        ]
-        dest_paths = [
-            os.path.join(s3_prefix, os.path.relpath(local_path, local_checkpoint_dir))
-            for local_path in local_paths
-        ]
-
-        s3_client = _get_s3_client("s3")
-        for local_path, dest_path in tqdm(
-            zip(local_paths, dest_paths), desc="Uploading files", total=len(local_paths)
-        ):
-            s3_client.upload_file(local_path, s3_bucket_name, dest_path)
-    elif parsed_url.scheme == "":
-        shutil.copytree(local_checkpoint_dir, destination_dir)
-    else:
-        raise ValueError(f"Unsupported destination: {destination_dir}. Only s3 and local paths are supported.")
+    if (parsed_url := urlparse(destination_dir)).scheme == "s3":
+        return upload_s3_directory(local_checkpoint_dir, destination_dir)
+
+    elif parsed_url.scheme == "gs":
+        return upload_gcs_directory(local_checkpoint_dir, destination_dir)
 
-def maybe_unshard(checkpoint_dir: str):
+    elif parsed_url.scheme == "":
+        # local destination: copy the converted files directly
+        shutil.copytree(local_checkpoint_dir, destination_dir, dirs_exist_ok=True)
+        return
+
+    raise ValueError(f"Unsupported protocol: {destination_dir}. Only s3://, gs://, and local are supported.")
+
+
+def maybe_unshard(checkpoint_dir: str, destination_dir: str):
     if os.path.exists(os.path.join(checkpoint_dir, "model.pt")):
+        # copy the model.pt to the destination directory
+        if checkpoint_dir != destination_dir:
+            print("Copying model.pt to destination directory...")
+            shutil.copy(os.path.join(checkpoint_dir, "model.pt"), os.path.join(destination_dir, "model.pt"))
+
+        print("model.pt found; skipping unsharding.")
+        return
+
+    if os.path.exists(os.path.join(checkpoint_dir, "model.safetensors")):
+        # copy the model.safetensors to the destination directory
+        if checkpoint_dir != destination_dir:
+            print("Copying model.safetensors to destination directory...")
+            shutil.copy(
+                os.path.join(checkpoint_dir, "model.safetensors"),
+                os.path.join(destination_dir, "model.safetensors")
+            )
+        print("model.safetensors found; skipping unsharding.")
         return
 
     print(f"Unsharding {checkpoint_dir}...")
@@ -268,12 +358,6 @@ def main():
         help="Ignore compatibility with the olmo codebase. "
         "This will remove files that are needed specifically for olmo codebase, eg. config.yaml, etc.",
     )
-    parser.add_argument(
-        "--logger-level",
-        default="warning",
-        help="Set the logger level.",
-    )
-
     parser.add_argument(
         "--tokenizer",
         help="Override the tokenizer to use for the checkpoint.",
@@ -285,29 +369,48 @@
     )
     args = parser.parse_args()
 
-    args.destination_dir = args.destination_dir or args.checkpoint_dir
-    logging.basicConfig()
-    logger.setLevel(logging.getLevelName(args.logger_level.upper()))
+    # default to the source checkpoint location when no destination is given
+    args.destination_dir = args.destination_dir or args.checkpoint_dir
+    local_destination_dir = args.destination_dir
 
+    try:
+        local_checkpoint_dir = make_local_checkpoint(args.checkpoint_dir)
 
-    local_checkpoint_dir = make_local_checkpoint(args.checkpoint_dir)
-    args.checkpoint_dir = local_checkpoint_dir
-    maybe_unshard(local_checkpoint_dir)
+        if local_checkpoint_dir != args.checkpoint_dir:
+            # if using a remote checkpoint, save the converted checkpoint locally
+            print("Remote checkpoint; using local directory as destination.")
+            local_destination_dir = local_checkpoint_dir
 
-    fix_tokenizer(checkpoint_dir=local_checkpoint_dir, tokenizer_name_or_path=args.tokenizer)
-    convert_checkpoint(args.checkpoint_dir, args.ignore_olmo_compatibility)
+        Path(local_destination_dir).mkdir(parents=True, exist_ok=True)
+        maybe_unshard(checkpoint_dir=local_checkpoint_dir, destination_dir=local_destination_dir)
 
-    if not args.keep_olmo_artifacts:
-        print("Removing non-HF artifacts...")
-        os.remove(os.path.join(local_checkpoint_dir, "config.yaml"))
-        os.remove(os.path.join(local_checkpoint_dir, "model.pt"))
-        shutil.rmtree(os.path.join(local_checkpoint_dir, "optim"), ignore_errors=True)
-        shutil.rmtree(os.path.join(local_checkpoint_dir, "model"), ignore_errors=True)
-        shutil.rmtree(os.path.join(local_checkpoint_dir, "train"), ignore_errors=True)
+        fix_tokenizer(
+            checkpoint_dir=local_checkpoint_dir,
+            destination_dir=local_destination_dir,
+            tokenizer_name_or_path=args.tokenizer
+        )
 
-    upload_local_checkpoint(local_checkpoint_dir, args.destination_dir)
+        convert_checkpoint(
+            checkpoint_dir=local_checkpoint_dir,
+            destination_dir=local_destination_dir,
+            ignore_olmo_compatibility=args.ignore_olmo_compatibility
+        )
+
"model"), ignore_errors=True) + shutil.rmtree(os.path.join(local_checkpoint_dir, "train"), ignore_errors=True) + + upload_local_checkpoint(local_destination_dir, args.destination_dir) - print(f"Converted checkpoint saved to {args.destination_dir}") + print(f"Converted checkpoint saved to {args.destination_dir}") + except Exception as e: + print(f"Error converting checkpoint: {e}") + if args.checkpoint_dir != local_destination_dir: + print("Removing partially converted checkpoint...") + shutil.rmtree(args.destination_dir) + raise e if __name__ == "__main__": diff --git a/olmo/util.py b/olmo/util.py index aad77eb1c..5ba85a9a6 100644 --- a/olmo/util.py +++ b/olmo/util.py @@ -503,6 +503,13 @@ def _get_s3_endpoint_url(scheme: str) -> Optional[str]: raise NotImplementedError(f"Cannot get endpoint url for scheme {scheme}") +@cache +def _get_gcs_client(): + from google.cloud import storage as gcs + + return gcs.Client() + + @cache def _get_s3_client(scheme: str): session = boto3.Session(profile_name=_get_s3_profile_name(scheme)) @@ -637,7 +644,11 @@ def _http_file_size(scheme: str, host_name: str, path: str) -> int: import requests response = requests.head(f"{scheme}://{host_name}/{path}", allow_redirects=True) - return int(response.headers.get("content-length")) + + if (content_length := response.headers.get("content-length")) is not None: + return int(content_length) + + raise OLMoNetworkError(f"Failed to get {scheme} file size") def _http_get_bytes_range(scheme: str, host_name: str, path: str, bytes_start: int, num_bytes: int) -> bytes: @@ -647,9 +658,10 @@ def _http_get_bytes_range(scheme: str, host_name: str, path: str, bytes_start: i f"{scheme}://{host_name}/{path}", headers={"Range": f"bytes={bytes_start}-{bytes_start+num_bytes-1}"} ) result = response.content - assert ( - len(result) == num_bytes - ), f"expected {num_bytes} bytes, got {len(result)}" # Some web servers silently ignore range requests and send everything + + # Some web servers silently ignore range requests and send everything + assert len(result) == num_bytes, f"expected {num_bytes} bytes, got {len(result)}" + return result diff --git a/scripts/olmo_soup.py b/scripts/olmo_soup.py new file mode 100644 index 000000000..64e037ede --- /dev/null +++ b/scripts/olmo_soup.py @@ -0,0 +1,123 @@ +""" +Soups OLMo checkpoints. 
+
+Example usage:
+
+```bash
+    python scripts/olmo_soup.py -c \
+        /weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-moremath-dclm07-fw2-se-flan/step11931 \
+        /weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-moremath-dclm07-fw2-se-flan-seed2/step11931 \
+        /weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-moremath-dclm07-fw2-se-flan-seed3/step11931 \
+    -o /weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-moremath-dclm07-fw2-se-flan-soup/step11931
+```
+
+Author: Luca Soldaini (@soldni)
+
+"""  # noqa
+
+
+import argparse
+from enum import Enum
+from pathlib import Path
+
+import torch
+from tqdm import tqdm
+
+from olmo.checkpoint import build_sharded_checkpointer
+from olmo.config import TrainConfig
+from olmo.safetensors_util import safetensors_file_to_state_dict
+
+
+class SoupType(Enum):
+    uniform = "uniform"
+
+
+def load_checkpoint(path: Path) -> dict[str, torch.Tensor]:
+    if path.exists() and path.is_file():
+        return torch.load(path, map_location="cpu", weights_only=True)
+
+    if (path / "model.pt").exists():
+        return torch.load(path / "model.pt", map_location="cpu", weights_only=True)
+
+    if (path / "model.safetensors").exists():
+        return safetensors_file_to_state_dict(path / "model.safetensors")
+
+    if (path / "model").exists() and (config_path := (path / "config.yaml")).exists():
+        train_config = TrainConfig.load(config_path)
+        checkpointer = build_sharded_checkpointer(train_config)
+        model_state, _, _ = checkpointer.unshard_checkpoint(
+            load_path=str(path), load_optimizer_state=False, load_trainer_state=False
+        )
+        return model_state
+
+    raise FileNotFoundError(f"Could not find checkpoint in {path}")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Soup OLMo checkpoints")
+    parser.add_argument(
+        "-c",
+        "--checkpoints",
+        type=Path,
+        required=True,
+        nargs="+",
+        help="Path to checkpoint(s) to soup",
+    )
+    parser.add_argument(
+        "-s",
+        "--soup-type",
+        type=SoupType,
+        default=SoupType.uniform,
+        help=f"Method for checkpoint souping. Choose from: {', '.join(SoupType.__members__.keys())}",
+    )
+    parser.add_argument(
+        "-o",
+        "--output",
+        type=Path,
+        required=True,
+        help="Path to save the souped checkpoint",
+    )
+    opts = parser.parse_args()
+    return opts
+
+
+def main():
+    args = parse_args()
+
+    checkpoint_average: dict[str, torch.Tensor] = {}
+
+    for path in tqdm(args.checkpoints, desc="Loading checkpoints", position=0):
+        state_dict = load_checkpoint(path)
+
+        if len(checkpoint_average) == 0:
+            # initialize checkpoint_average with zeros
+            checkpoint_average = {k: torch.zeros_like(v) for k, v in state_dict.items()}
+
+        if any(k not in state_dict for k in checkpoint_average.keys()) or any(
+            k not in checkpoint_average for k in state_dict.keys()
+        ):
+            raise ValueError(f"Checkpoint {path} has different keys")
+
+        for k in tqdm(state_dict, desc="Summing checkpoints", position=1):
+            if state_dict[k].shape != checkpoint_average[k].shape:
+                raise ValueError(f"Checkpoint {path} has different shape for key {k}")
+            checkpoint_average[k] += state_dict[k] / len(args.checkpoints)
+
+        # free memory
+        del state_dict
+
+    print(f"Saving averaged checkpoint to {args.output}")
+    # save the averaged checkpoint
+    args.output.mkdir(parents=True, exist_ok=True)
+    torch.save(checkpoint_average, args.output / "model.pt")
+
+    print("Copying config.yaml")
+    # copy the config file
+    if (config_path := args.checkpoints[0] / "config.yaml").exists():
+        with open(config_path, "r") as src_f, open(args.output / "config.yaml", "w") as dst_f:
+            dst_f.write(src_f.read())
+    print("Done!")
+
+
+if __name__ == "__main__":
+    main()
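
Usage note (not part of the diff above): once `convert_olmo_to_hf.py` has written `config.json`, `pytorch_model.bin`, and the tokenizer files into the destination directory, the result can be loaded through the `hf_olmo` wrappers imported at the top of the conversion script. A minimal sketch, assuming a hypothetical local destination directory `./converted/step11931`; remote `s3://` or `gs://` destinations would need to be downloaded locally first, since `from_pretrained` does not read those schemes:

```python
# Minimal loading sketch; the destination path below is illustrative, not taken from the diff.
from hf_olmo.modeling_olmo import OLMoForCausalLM
from hf_olmo.tokenization_olmo_fast import OLMoTokenizerFast

checkpoint_dir = "./converted/step11931"  # hypothetical output of convert_olmo_to_hf.py

model = OLMoForCausalLM.from_pretrained(checkpoint_dir)
tokenizer = OLMoTokenizerFast.from_pretrained(checkpoint_dir)

inputs = tokenizer("Language modeling is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```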