Skip to content

Commit

Permalink
Resubmit PR 1304.
Browse files Browse the repository at this point in the history
  • Loading branch information
cpgaffney1 committed Nov 12, 2024
1 parent d2c1b86 commit 8989abc
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import dataclasses
import numbers
from typing import Any, List, Optional

from absl import logging
Expand Down Expand Up @@ -91,6 +92,11 @@ def _validate_save_state(
):
if item is None:
raise ValueError('Must provide item to save.')
if isinstance(item, jax.Array | numbers.Number):
raise ValueError(
'StandardCheckpointHandler / StandardSave does not support single '
'arrays or scalars. Use ArrayCheckpointHandler / ArraySave'
)
if save_args is None:
save_args = jax.tree.map(lambda x: None, item)

Expand Down
1 change: 1 addition & 0 deletions checkpoint/orbax/checkpoint/single_host_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import numpy as np
from orbax.checkpoint import test_utils
from orbax.checkpoint._src.handlers import pytree_checkpoint_handler
from orbax.checkpoint._src.handlers import standard_checkpoint_handler_test_utils
from orbax.checkpoint._src.serialization import type_handlers
import tensorstore as ts

Expand Down

0 comments on commit 8989abc

Please sign in to comment.