Skip to content

Commit

Permalink
Removes the build file out of the obm_module and instead processes th…
Browse files Browse the repository at this point in the history
…e serving_configs in the constructor of obm_export to match what is done with the Tensorflow Export.

PiperOrigin-RevId: 696955640
  • Loading branch information
Orbax Authors committed Nov 22, 2024
1 parent 5048d35 commit 8c4aa3e
Show file tree
Hide file tree
Showing 5 changed files with 247 additions and 117 deletions.
9 changes: 0 additions & 9 deletions export/orbax/export/export_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
obx_export_config = config.config
maybe_reraise = reraise_utils.maybe_reraise


class ExportManager:
"""Exports a JAXModule with pre- and post-processors."""

Expand Down Expand Up @@ -62,14 +61,6 @@ def __init__(
self._serialization_functions = obm_export.ObmExport(
self._jax_module, serving_configs
)
obm_module_ = module.orbax_module()
if not isinstance(obm_module_, obm_module.ObmModule):
raise ValueError(
'module.orbax_module() must return an `ObmModule`. '
f'Got type: {type(obm_module_)}'
)
# TODO(bdwalker): Let `ObmExport.__init__() do this `build()` step.
obm_module_.build(serving_configs)
else:
self._serialization_functions = tensorflow_export.TensorFlowExport(
self._jax_module, serving_configs
Expand Down
174 changes: 69 additions & 105 deletions export/orbax/export/modules/obm_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,42 +14,23 @@

"""Wraps JAX functions and parameters into a tf.Module."""

from collections.abc import Callable, Mapping, Sequence
from collections.abc import Callable, Mapping
import copy
import logging
from typing import Any, Tuple, Union
from typing import Any, Optional, Union

import jax
from jax import export as jax_export
from orbax.export import constants
from orbax.export import serving_config as osc
from orbax.export import typing as orbax_export_typing
from orbax.export import utils
# from orbax.export import utils
from orbax.export.modules import orbax_module_base
from orbax.export.typing import PyTree
import tensorflow as tf


ApplyFn = orbax_export_typing.ApplyFn


def _to_jax_dtype(t):
if isinstance(t, tf.DType):
return t.as_numpy_dtype()
return t


def _to_jax_spec(tree: PyTree) -> PyTree:
return jax.tree_util.tree_map(
lambda x: jax.ShapeDtypeStruct(x.shape, _to_jax_dtype(x.dtype)), tree
)


def _to_sequence(a):
if isinstance(a, Sequence):
return a
return (a,)


class ObmModule(orbax_module_base.OrbaxModuleBase):
"""A data module for encapsulating the data for a Jax model to be serialized through the Orbax Model export flow."""

Expand All @@ -69,7 +50,6 @@ def __init__(
'native_serialization_platform', 'flatten_signature', 'weights_name'and
'checkpoint_path'.
"""
self._params = params

# It is possible for jax2obm_kwargs to be None if the key is present.
if not jax2obm_kwargs:
Expand All @@ -78,11 +58,30 @@ def __init__(
self._apply_fn_map = self._normalize_apply_fn_map(
self._normalize_apply_fn_map(apply_fn)
)

if len(self._apply_fn_map) != 1:
raise NotImplementedError(
'ObmModule: Currently the ObmExport only supports a single method'
f' for export. Received: {self._apply_fn_map}'
)

self._native_serialization_platform = (
jax2obm_kwargs[constants.NATIVE_SERIALIZATION_PLATFORM]
if constants.NATIVE_SERIALIZATION_PLATFORM in jax2obm_kwargs
else None
)
supported_platforms = [
platform.name for platform in constants.OrbaxNativeSerializationType
]
if (
self._native_serialization_platform is not None
and self._native_serialization_platform not in supported_platforms
):
raise ValueError(
'native_serialization_platforms must be a sequence containing a'
f' subset of: {supported_platforms}'
)

self._flatten_signature = (
jax2obm_kwargs[constants.FLATTEN_SIGNATURE]
if constants.FLATTEN_SIGNATURE in jax2obm_kwargs
Expand All @@ -94,45 +93,11 @@ def __init__(
if self._support_tf_resources is None:
self._support_tf_resources = False

self._params_args_spec = _to_jax_spec(params)
self._params_args_spec = utils.to_jax_spec(params)

# Set the Orbax checkpoint path if provided in the jax2obm_kwargs.
self._maybe_set_orbax_checkpoint_path(jax2obm_kwargs)

self.built = False

def build(
self,
serving_configs: Sequence[osc.ServingConfig],
) -> None:
if self.built:
raise ValueError(
'The `build` method has already been called.'
' It can only be called once.'
)
self._verify_serving_configs(serving_configs)

# Currently there will only ever be a single item in the mapping.
if len(self._apply_fn_map) != 1:
raise NotImplementedError(
'ObmModule: Currently the ObmExport only supports a single method'
f' for export. Received: {self._apply_fn_map}'
)

model_function_name, jax_fn = next(iter(self._apply_fn_map.items()))

self._convert_jax_functions_to_obm_functions(
jax_fn=jax_fn,
jax_fn_name=model_function_name,
params_args_spec=self._params_args_spec,
serving_config=serving_configs[0],
native_serialization_platform=self._native_serialization_platform,
flatten_signature=self._flatten_signature,
support_tf_resources=self._support_tf_resources,
)

self.built = True

def _normalize_apply_fn_map(
self, apply_fn: Union[ApplyFn, Mapping[str, ApplyFn]]
) -> Mapping[str, ApplyFn]:
Expand All @@ -147,56 +112,13 @@ def _normalize_apply_fn_map(
apply_fn_map = apply_fn
return apply_fn_map

def _verify_serving_configs(
self, serving_configs: Sequence[osc.ServingConfig]
):
if not serving_configs or len(serving_configs) != 1:
raise ValueError(
'ObmModule: A single serving_config must be provided for Orbax'
' Model export.'
)

if not serving_configs[0].input_signature:
# TODO(wangpeng): Infer input_signature from tf_preprocessor.
raise ValueError(
'ObmModule: The serving_config must have an input_signature set.'
)

if not serving_configs[0].signature_key:
raise ValueError(
'ObmModule: The serving_config must have a signature_key set.'
)

def _convert_jax_functions_to_obm_functions(
self,
*,
jax_fn,
jax_fn_name: str,
params_args_spec: PyTree,
serving_config: osc.ServingConfig,
native_serialization_platform,
flatten_signature: bool,
support_tf_resources: bool,
):
"""Converts the JAX functions to OrbaxModel functions."""
if serving_config.input_signature is None:
raise ValueError('serving_config.input_signature is required.')
if (
not support_tf_resources
and serving_config.extra_trackable_resources is not None
):
raise ValueError(
'serving_config.extra_trackable_resources can only be set when'
' support_tf_resources is True.'
)

def _maybe_set_orbax_checkpoint_path(self, jax2obm_kwargs):
if constants.CHECKPOINT_PATH not in jax2obm_kwargs:
return

# TODO: b/374195447 - Add a version check for the Orbax checkpointer.
checkpoint_path = jax2obm_kwargs[constants.CHECKPOINT_PATH]
weights_name = (
self._checkpoint_path = jax2obm_kwargs[constants.CHECKPOINT_PATH]
self._weights_name = (
jax2obm_kwargs[constants.WEIGHTS_NAME]
if constants.WEIGHTS_NAME in jax2obm_kwargs
else constants.DEFAULT_WEIGHTS_NAME
Expand All @@ -212,15 +134,57 @@ def apply_fn_map(self) -> Mapping[str, ApplyFn]:
"""Returns the apply_fn_map from function name to jit'd apply function."""
return self._apply_fn_map

@property
def native_serialization_platform(self) -> Optional[str]:
"""Returns the native serialization platform."""
return self._native_serialization_platform

@property
def flatten_signature(self) -> bool:
"""Returns the flatten signature."""
return self._flatten_signature

@property
def export_version(self) -> constants.ExportModelType:
"""Returns the export version."""
return constants.ExportModelType.ORBAX_MODEL

def set_pre_or_post_processor(
self, processor_name: str, processor: obm.Function
):
"""Sets the post processor name."""
assert processor_name is not None
assert processor is not None
setattr(self._orbax_model_module, processor_name, processor)

def set_model_function(
self, function_name: str, model_function: obm.ShloFunction
):
"""Sets the model function."""
assert model_function is not None
assert function_name is not None
setattr(self._orbax_model_module, function_name, model_function)

def is_checkpoint_path_set(self) -> bool:
"""Returns True if the checkpoint path is set."""
return self._checkpoint_path is not None

def checkpoint_path(self) -> str:
"""Returns the checkpoint path."""
if not self.is_checkpoint_path_set():
raise ValueError('Checkpoint path is not set.')
return self._checkpoint_path

def weights_name(self) -> str:
"""Returns the weights name."""
if not self.is_checkpoint_path_set():
raise ValueError('Checkpoint path must be set to get the weights name.')
return self._weights_name

@property
def model_params(self) -> PyTree:
"""Returns the model parameter specs."""
return self._params
return self._params_args_spec

def obm_module_to_jax_exported_map(
self,
Expand Down
3 changes: 3 additions & 0 deletions export/orbax/export/modules/obm_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import cast
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
from orbax.export import constants
from orbax.export import obm_export
from orbax.export import serving_config as osc
from orbax.export.modules import obm_module
from orbax.export import jax_module
import tensorflow as tf


Expand Down
Loading

0 comments on commit 8c4aa3e

Please sign in to comment.