diff --git a/export/orbax/export/export_manager.py b/export/orbax/export/export_manager.py index 3b9da718..5e09a124 100644 --- a/export/orbax/export/export_manager.py +++ b/export/orbax/export/export_manager.py @@ -31,7 +31,6 @@ obx_export_config = config.config maybe_reraise = reraise_utils.maybe_reraise - class ExportManager: """Exports a JAXModule with pre- and post-processors.""" @@ -62,14 +61,6 @@ def __init__( self._serialization_functions = obm_export.ObmExport( self._jax_module, serving_configs ) - obm_module_ = module.orbax_module() - if not isinstance(obm_module_, obm_module.ObmModule): - raise ValueError( - 'module.orbax_module() must return an `ObmModule`. ' - f'Got type: {type(obm_module_)}' - ) - # TODO(bdwalker): Let `ObmExport.__init__() do this `build()` step. - obm_module_.build(serving_configs) else: self._serialization_functions = tensorflow_export.TensorFlowExport( self._jax_module, serving_configs diff --git a/export/orbax/export/modules/obm_module.py b/export/orbax/export/modules/obm_module.py index 3f5a2552..46c05bd3 100644 --- a/export/orbax/export/modules/obm_module.py +++ b/export/orbax/export/modules/obm_module.py @@ -14,42 +14,23 @@ """Wraps JAX functions and parameters into a tf.Module.""" -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Mapping +import copy import logging -from typing import Any, Tuple, Union +from typing import Any, Optional, Union -import jax from jax import export as jax_export from orbax.export import constants -from orbax.export import serving_config as osc from orbax.export import typing as orbax_export_typing +from orbax.export import utils # from orbax.export import utils from orbax.export.modules import orbax_module_base from orbax.export.typing import PyTree import tensorflow as tf - ApplyFn = orbax_export_typing.ApplyFn -def _to_jax_dtype(t): - if isinstance(t, tf.DType): - return t.as_numpy_dtype() - return t - - -def _to_jax_spec(tree: PyTree) -> PyTree: - return jax.tree_util.tree_map( - lambda x: jax.ShapeDtypeStruct(x.shape, _to_jax_dtype(x.dtype)), tree - ) - - -def _to_sequence(a): - if isinstance(a, Sequence): - return a - return (a,) - - class ObmModule(orbax_module_base.OrbaxModuleBase): """A data module for encapsulating the data for a Jax model to be serialized through the Orbax Model export flow.""" @@ -69,7 +50,6 @@ def __init__( 'native_serialization_platform', 'flatten_signature', 'weights_name'and 'checkpoint_path'. """ - self._params = params # It is possible for jax2obm_kwargs to be None if the key is present. if not jax2obm_kwargs: @@ -78,11 +58,30 @@ def __init__( self._apply_fn_map = self._normalize_apply_fn_map( self._normalize_apply_fn_map(apply_fn) ) + + if len(self._apply_fn_map) != 1: + raise NotImplementedError( + 'ObmModule: Currently the ObmExport only supports a single method' + f' for export. Received: {self._apply_fn_map}' + ) + self._native_serialization_platform = ( jax2obm_kwargs[constants.NATIVE_SERIALIZATION_PLATFORM] if constants.NATIVE_SERIALIZATION_PLATFORM in jax2obm_kwargs else None ) + supported_platforms = [ + platform.name for platform in constants.OrbaxNativeSerializationType + ] + if ( + self._native_serialization_platform is not None + and self._native_serialization_platform not in supported_platforms + ): + raise ValueError( + 'native_serialization_platforms must be a sequence containing a' + f' subset of: {supported_platforms}' + ) + self._flatten_signature = ( jax2obm_kwargs[constants.FLATTEN_SIGNATURE] if constants.FLATTEN_SIGNATURE in jax2obm_kwargs @@ -94,45 +93,11 @@ def __init__( if self._support_tf_resources is None: self._support_tf_resources = False - self._params_args_spec = _to_jax_spec(params) + self._params_args_spec = utils.to_jax_spec(params) # Set the Orbax checkpoint path if provided in the jax2obm_kwargs. self._maybe_set_orbax_checkpoint_path(jax2obm_kwargs) - self.built = False - - def build( - self, - serving_configs: Sequence[osc.ServingConfig], - ) -> None: - if self.built: - raise ValueError( - 'The `build` method has already been called.' - ' It can only be called once.' - ) - self._verify_serving_configs(serving_configs) - - # Currently there will only ever be a single item in the mapping. - if len(self._apply_fn_map) != 1: - raise NotImplementedError( - 'ObmModule: Currently the ObmExport only supports a single method' - f' for export. Received: {self._apply_fn_map}' - ) - - model_function_name, jax_fn = next(iter(self._apply_fn_map.items())) - - self._convert_jax_functions_to_obm_functions( - jax_fn=jax_fn, - jax_fn_name=model_function_name, - params_args_spec=self._params_args_spec, - serving_config=serving_configs[0], - native_serialization_platform=self._native_serialization_platform, - flatten_signature=self._flatten_signature, - support_tf_resources=self._support_tf_resources, - ) - - self.built = True - def _normalize_apply_fn_map( self, apply_fn: Union[ApplyFn, Mapping[str, ApplyFn]] ) -> Mapping[str, ApplyFn]: @@ -147,56 +112,13 @@ def _normalize_apply_fn_map( apply_fn_map = apply_fn return apply_fn_map - def _verify_serving_configs( - self, serving_configs: Sequence[osc.ServingConfig] - ): - if not serving_configs or len(serving_configs) != 1: - raise ValueError( - 'ObmModule: A single serving_config must be provided for Orbax' - ' Model export.' - ) - - if not serving_configs[0].input_signature: - # TODO(wangpeng): Infer input_signature from tf_preprocessor. - raise ValueError( - 'ObmModule: The serving_config must have an input_signature set.' - ) - - if not serving_configs[0].signature_key: - raise ValueError( - 'ObmModule: The serving_config must have a signature_key set.' - ) - - def _convert_jax_functions_to_obm_functions( - self, - *, - jax_fn, - jax_fn_name: str, - params_args_spec: PyTree, - serving_config: osc.ServingConfig, - native_serialization_platform, - flatten_signature: bool, - support_tf_resources: bool, - ): - """Converts the JAX functions to OrbaxModel functions.""" - if serving_config.input_signature is None: - raise ValueError('serving_config.input_signature is required.') - if ( - not support_tf_resources - and serving_config.extra_trackable_resources is not None - ): - raise ValueError( - 'serving_config.extra_trackable_resources can only be set when' - ' support_tf_resources is True.' - ) - def _maybe_set_orbax_checkpoint_path(self, jax2obm_kwargs): if constants.CHECKPOINT_PATH not in jax2obm_kwargs: return # TODO: b/374195447 - Add a version check for the Orbax checkpointer. - checkpoint_path = jax2obm_kwargs[constants.CHECKPOINT_PATH] - weights_name = ( + self._checkpoint_path = jax2obm_kwargs[constants.CHECKPOINT_PATH] + self._weights_name = ( jax2obm_kwargs[constants.WEIGHTS_NAME] if constants.WEIGHTS_NAME in jax2obm_kwargs else constants.DEFAULT_WEIGHTS_NAME @@ -212,15 +134,57 @@ def apply_fn_map(self) -> Mapping[str, ApplyFn]: """Returns the apply_fn_map from function name to jit'd apply function.""" return self._apply_fn_map + @property + def native_serialization_platform(self) -> Optional[str]: + """Returns the native serialization platform.""" + return self._native_serialization_platform + + @property + def flatten_signature(self) -> bool: + """Returns the flatten signature.""" + return self._flatten_signature + @property def export_version(self) -> constants.ExportModelType: """Returns the export version.""" return constants.ExportModelType.ORBAX_MODEL + def set_pre_or_post_processor( + self, processor_name: str, processor: obm.Function + ): + """Sets the post processor name.""" + assert processor_name is not None + assert processor is not None + setattr(self._orbax_model_module, processor_name, processor) + + def set_model_function( + self, function_name: str, model_function: obm.ShloFunction + ): + """Sets the model function.""" + assert model_function is not None + assert function_name is not None + setattr(self._orbax_model_module, function_name, model_function) + + def is_checkpoint_path_set(self) -> bool: + """Returns True if the checkpoint path is set.""" + return self._checkpoint_path is not None + + def checkpoint_path(self) -> str: + """Returns the checkpoint path.""" + if not self.is_checkpoint_path_set(): + raise ValueError('Checkpoint path is not set.') + return self._checkpoint_path + + def weights_name(self) -> str: + """Returns the weights name.""" + if not self.is_checkpoint_path_set(): + raise ValueError('Checkpoint path must be set to get the weights name.') + return self._weights_name + @property def model_params(self) -> PyTree: """Returns the model parameter specs.""" - return self._params + return self._params_args_spec def obm_module_to_jax_exported_map( self, diff --git a/export/orbax/export/modules/obm_module_test.py b/export/orbax/export/modules/obm_module_test.py index 44c453a7..c288fcd3 100644 --- a/export/orbax/export/modules/obm_module_test.py +++ b/export/orbax/export/modules/obm_module_test.py @@ -12,13 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import cast from absl.testing import absltest from absl.testing import parameterized import jax import jax.numpy as jnp from orbax.export import constants +from orbax.export import obm_export from orbax.export import serving_config as osc from orbax.export.modules import obm_module +from orbax.export import jax_module import tensorflow as tf diff --git a/export/orbax/export/obm_export.py b/export/orbax/export/obm_export.py index 1c8f9e77..bbeeb08e 100644 --- a/export/orbax/export/obm_export.py +++ b/export/orbax/export/obm_export.py @@ -14,17 +14,74 @@ """Export class that implements the save and load abstract class defined in Export Base for use with the Orbax Model export format.""" -from typing import Any, Callable, Mapping, cast +from typing import Any, Callable, Dict, Mapping, Sequence, Tuple, cast, Union from absl import logging +import jax from orbax.export import constants from orbax.export import export_base +from orbax.export import jax_module +from orbax.export import serving_config as osc +from orbax.export import utils from orbax.export.modules import obm_module +from orbax.export.typing import PyTree +import tensorflow as tf + + +def _callable_to_obm_function( + f: Callable[..., Any], + input_spec: Sequence[PyTree], + name: str | None = None, +) -> Tuple[obm.Function, tf.types.experimental.ConcreteFunction]: + cf = tf.function(f).get_concrete_function(*_to_tf_spec(input_spec)) + if name is None: + obm_fn = tf2obm.tf_concrete_function_to_obm_function(cf) + else: + obm_fn = tf2obm.tf_concrete_function_name_to_obm_function(name, cf) + return obm_fn, cf + + +def _to_tf_dtype(t): + if isinstance(t, obm.ShloDType): + t = obm.shlo_dtype_to_np_dtype(t) + return t + + +def _to_sequence(a): + if isinstance(a, Sequence): + return a + return (a,) + + +def _to_tf_spec(tree: PyTree) -> PyTree: + return jax.tree_util.tree_map( + lambda x: tf.TensorSpec(x.shape, _to_tf_dtype(x.dtype)), tree + ) class ObmExport(export_base.ExportBase): """Defines the save and load methods for exporting a model using Orbax Model export.""" + def __init__( + self, + module: jax_module.JaxModule, + serving_configs: Sequence[osc.ServingConfig], + ): + """Initializes the ObmExport class.""" + obm_model_module = module.export_module() + + self._orchestration = obm.simple_orchestration_pb2.SimpleOrchestration() + self._module = cast(obm_module.ObmModule, obm_model_module) + + self._verify_serving_configs(serving_configs) + self._save_supplemental_info_closure: ( + Callable[[str], Mapping[str, obm.SupplementalInfo]] | None + ) = None + + # Currently only a single serving config is supported. This is verified in + # _verify_serving_configs. + self._process_serving_configs(serving_configs[0]) + def save( self, model_path: str, @@ -38,11 +95,11 @@ def save( arguments are `save_options` and `serving_signatures`. """ - if self._module.export_version() != constants.ExportModelType.ORBAX_MODEL: + if self._module.export_version != constants.ExportModelType.ORBAX_MODEL: raise ValueError( "JaxModule is not of type ORBAX_MODEL. Please use the correct" " export_version. Expected ORBAX_MODEL, got" - f" {self._module.export_version()}" + f" {self._module.export_version}" ) def load(self, model_path: str, **kwargs: Any): @@ -54,3 +111,105 @@ def load(self, model_path: str, **kwargs: Any): def serving_signatures(self) -> Mapping[str, Callable[..., Any]]: """Returns a map of signature keys to serving functions.""" raise NotImplementedError("ObmExport.load not implemented yet.") + + def _verify_serving_configs( + self, serving_configs: Sequence[osc.ServingConfig] + ): + if not serving_configs or len(serving_configs) != 1: + raise ValueError( + "ObmModule: A single serving_config must be provided for Orbax" + " Model export." + ) + + if not serving_configs[0].input_signature: + # TODO(wangpeng): Infer input_signature from tf_preprocessor. + raise ValueError( + "ObmModule: The serving_config must have an input_signature set." + ) + + if not serving_configs[0].signature_key: + raise ValueError( + "ObmModule: The serving_config must have a signature_key set." + ) + + if ( + not self._module.support_tf_resources() + and serving_configs[0].extra_trackable_resources is not None + ): + raise ValueError( + "serving_config.extra_trackable_resources can only be set when" + " support_tf_resources is True." + ) + + def _maybe_add_pre_processor( + self, + serving_config: osc.ServingConfig, + concrete_functions: Dict[str, tf.types.experimental.ConcreteFunction], + ) -> Sequence[Any]: + """Processes the serving configs to return a map of signature keys to serving functions.""" + if serving_config.tf_preprocessor is not None: + pre_processor_name = constants.DEFAULT_PRE_PROCESSOR_NAME + pre_processor, tf_pre_processor = _callable_to_obm_function( + serving_config.tf_preprocessor, + serving_config.input_signature, + name=pre_processor_name + if self._module.support_tf_resources() + else None, + ) + + pre_proc_output_spec = tf_pre_processor.structured_outputs + concrete_functions[pre_processor_name] = tf_pre_processor + self._module.set_pre_or_post_processor(pre_processor_name, pre_processor) + self._orchestration.pre_processor_name = pre_processor_name + return _to_sequence(pre_proc_output_spec) + + return _to_sequence(serving_config.input_signature) + + def _maybe_add_post_processor( + self, + serving_config: osc.ServingConfig, + concrete_functions: Dict[str, tf.types.experimental.ConcreteFunction], + model_function: obm.ShloFunction, + ): + """Processes the serving configs to return a map of signature keys to serving functions.""" + if serving_config.tf_postprocessor is None: + return + post_processor_input_spec = _to_sequence(model_function.output_signature) + post_processor_name = constants.DEFAULT_POST_PROCESSOR_NAME + post_processor, tf_post_processor = _callable_to_obm_function( + serving_config.tf_postprocessor, + post_processor_input_spec, + name=post_processor_name + if self._module.support_tf_resources() + else None, + ) + concrete_functions[post_processor_name] = tf_post_processor + self._module.set_pre_or_post_processor(post_processor_name, post_processor) + self._orchestration.post_processor_name = post_processor_name + + def _create_supplemental_info_closure( + self, + concrete_functions: Dict[str, tf.types.experimental.ConcreteFunction], + extra_trackable_resources: Union[Any, None] = None, + ) -> Callable[..., Any]: + def save_supplemental_info(path: str): + tf2obm.save_tf_concrete_functions( + path, + concrete_functions, + extra_trackable_resources, + ) + tf_global_supplemental = tf2obm.tf_saved_model_as_obm_supplemental() + return { + tf2obm.TF_SAVED_MODEL_SUPPLEMENTAL_NAME: obm.SupplementalInfo( + tf_global_supplemental, None + ) + } + + return save_supplemental_info + + def _process_serving_configs(self, serving_config: osc.ServingConfig): + """Processes the serving configs to return a map of signature keys to serving functions.""" + if serving_config.input_signature is None: + raise ValueError("serving_config.input_signature is required.") + + jax_fn_name, jax_fn = next(iter(self._module.apply_fn_map.items())) diff --git a/export/orbax/export/utils.py b/export/orbax/export/utils.py index 2ecd7419..c7eeeb3c 100644 --- a/export/orbax/export/utils.py +++ b/export/orbax/export/utils.py @@ -491,3 +491,16 @@ def make_e2e_inference_fn( return with_default_args( infer_step_func_map[signature_key], serving_config.get_input_signature() ) + + +def _to_jax_dtype(t): + if isinstance(t, tf.DType): + return t.as_numpy_dtype() + return t + + +def to_jax_spec(tree: PyTree) -> PyTree: + return jax.tree_util.tree_map( + lambda x: jax.ShapeDtypeStruct(x.shape, _to_jax_dtype(x.dtype)), tree + ) +