Skip to content

Commit

Permalink
Add checkpoint/_src/metadata/BUILD.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 713155423
  • Loading branch information
liangyaning33 authored and Orbax Authors committed Jan 13, 2025
1 parent 3ef736f commit 6b2353b
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 0 deletions.
130 changes: 130 additions & 0 deletions checkpoint/orbax/checkpoint/_src/metadata/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package(
default_applicable_licenses = ["//:package_license"],
default_visibility = ["//visibility:public"],
)

py_library(
name = "tree_rich_types",
srcs = ["tree_rich_types.py"],
deps = [
":pytree_metadata_options",
":value_metadata_entry",
],
)

py_test(
name = "tree_rich_types_test",
srcs = ["tree_rich_types_test.py"],
deps = [":tree_rich_types"],
)

py_library(
name = "tree",
srcs = ["tree.py"],
deps = [
":empty_values",
":pytree_metadata_options",
":tree_rich_types",
":value",
":value_metadata_entry",
],
)

py_test(
name = "tree_test",
srcs = ["tree_test.py"],
deps = [":tree"],
)

py_library(
name = "value",
srcs = ["value.py"],
deps = [":sharding"],
)

py_library(
name = "sharding",
srcs = ["sharding.py"],
)

py_test(
name = "sharding_test",
srcs = ["sharding_test.py"],
deps = [":sharding"],
)

py_library(
name = "checkpoint",
srcs = ["checkpoint.py"],
)

py_test(
name = "checkpoint_test",
srcs = ["checkpoint_test.py"],
deps = [
":checkpoint",
":root_metadata_serialization",
":step_metadata_serialization",
],
)

py_test(
name = "sharding_tpu_test",
srcs = ["sharding_tpu_test.py"],
python_version = "PY3",
deps = [":sharding"],
)

py_library(
name = "root_metadata_serialization",
srcs = ["root_metadata_serialization.py"],
deps = [
":checkpoint",
":metadata_serialization_utils",
],
)

py_library(
name = "step_metadata_serialization",
srcs = ["step_metadata_serialization.py"],
deps = [
":checkpoint",
":metadata_serialization_utils",
],
)

py_library(
name = "metadata_serialization_utils",
srcs = ["metadata_serialization_utils.py"],
)

py_library(
name = "pytree_metadata_options",
srcs = ["pytree_metadata_options.py"],
)

py_library(
name = "value_metadata_entry",
srcs = ["value_metadata_entry.py"],
deps = [
":empty_values",
":pytree_metadata_options",
],
)

py_library(
name = "empty_values",
srcs = ["empty_values.py"],
deps = [":pytree_metadata_options"],
)

py_test(
name = "empty_values_test",
srcs = ["empty_values_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":empty_values",
":pytree_metadata_options",
],
)
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,9 @@ jaxlib>=0.4.34
portpicker~=1.6
absl-py>=1.0,==1.*
numpy>=1.26.0
orbax-checkpoint>=0.9.0
etils[epath]
simplejson
chex
optax

0 comments on commit 6b2353b

Please sign in to comment.