diff --git a/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py index acbf3640..25adba96 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py @@ -17,6 +17,7 @@ from __future__ import annotations import dataclasses +import numbers from typing import Any, List, Optional from absl import logging @@ -91,6 +92,11 @@ def _validate_save_state( ): if item is None: raise ValueError('Must provide item to save.') + if isinstance(item, jax.Array | numbers.Number): + raise ValueError( + 'StandardCheckpointHandler / StandardSave does not support single ' + 'arrays or scalars. Use ArrayCheckpointHandler / ArraySave' + ) if save_args is None: save_args = jax.tree.map(lambda x: None, item) diff --git a/checkpoint/orbax/checkpoint/single_host_test.py b/checkpoint/orbax/checkpoint/single_host_test.py index 2eeea371..949ab561 100644 --- a/checkpoint/orbax/checkpoint/single_host_test.py +++ b/checkpoint/orbax/checkpoint/single_host_test.py @@ -23,6 +23,7 @@ import numpy as np from orbax.checkpoint import test_utils from orbax.checkpoint._src.handlers import pytree_checkpoint_handler +from orbax.checkpoint._src.handlers import standard_checkpoint_handler_test_utils from orbax.checkpoint._src.serialization import type_handlers import tensorstore as ts @@ -63,6 +64,13 @@ def test_save_and_restore_jax_array(self, use_zarr3): np.testing.assert_array_equal(x, restored_tree['x']) assert isinstance(restored_tree['x'], jax.Array) + @parameterized.parameters({'x': jnp.array([1, 2])}, {'x': 1}) + def test_save_singular_array_with_standard_checkpoint_handler(self, x): + handler = standard_checkpoint_handler_test_utils.StandardCheckpointHandler() + with self.assertRaisesRegex(ValueError, + '.*Use ArrayCheckpointHandler / ArraySave.*'): + handler.save(self.ckpt_dir, args=standard_checkpoint_handler_test_utils.StandardSaveArgs(x)) + def test_save_and_restore_zarrv3_jax_array_default_chunk_size(self): handler = PyTreeCheckpointHandler(use_zarr3=True) key = jax.random.PRNGKey(0)