Skip to content

Commit

Permalink
Fix dtype serialization (#19752)
Browse files Browse the repository at this point in the history
* Fix `dtype` serialization

* Address comments

* Update comments
  • Loading branch information
james77777778 authored May 25, 2024
1 parent 490b1f0 commit 510d406
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 7 deletions.
4 changes: 1 addition & 3 deletions keras/src/dtype_policies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,7 @@ def get(identifier):

if identifier is None:
return dtype_policy.dtype_policy()
if isinstance(
identifier, (DTypePolicy, FloatDTypePolicy, QuantizedDTypePolicy)
):
if isinstance(identifier, DTypePolicy):
return identifier
if isinstance(identifier, dict):
return deserialize(identifier)
Expand Down
25 changes: 21 additions & 4 deletions keras/src/ops/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,17 @@ def __new__(cls, *args, **kwargs):
# Generate a config to be returned by default by `get_config()`.
arg_names = inspect.getfullargspec(cls.__init__).args
kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args)))
if "dtype" in kwargs and isinstance(
kwargs["dtype"], dtype_policies.DTypePolicy
):
kwargs["dtype"] = kwargs["dtype"].get_config()

# Explicitly serialize `dtype` to support auto_config
dtype = kwargs.get("dtype", None)
if dtype is not None and isinstance(dtype, dtype_policies.DTypePolicy):
# For backward compatibility, we use a str (`name`) for
# `FloatDTypePolicy`
if not dtype.is_quantized:
kwargs["dtype"] = dtype.name
# Otherwise, use `dtype_policies.serialize`
else:
kwargs["dtype"] = dtype_policies.serialize(dtype)

# For safety, we only rely on auto-configs for a small set of
# serializable types.
Expand Down Expand Up @@ -198,12 +205,22 @@ def from_config(cls, config):
This method is the reverse of `get_config`, capable of instantiating the
same operation from the config dictionary.
Note: If you override this method, you might receive a serialized dtype
config, which is a `dict`. You can deserialize it as follows:
```python
if "dtype" in config and isinstance(config["dtype"], dict):
policy = dtype_policies.deserialize(config["dtype"])
```
Args:
config: A Python dictionary, typically the output of `get_config`.
Returns:
An operation instance.
"""
# Explicitly deserialize dtype config if needed. This enables users to
# directly interact with the instance of `DTypePolicy`.
if "dtype" in config and isinstance(config["dtype"], dict):
config = config.copy()
config["dtype"] = dtype_policies.deserialize(config["dtype"])
Expand Down
28 changes: 28 additions & 0 deletions keras/src/ops/operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ def compute_output_spec(self, x):

class OpWithCustomDtype(operation.Operation):
def __init__(self, dtype):
if not isinstance(dtype, (str, dtype_policies.DTypePolicy)):
raise AssertionError(
"`dtype` must be a instance of `DTypePolicy` or str. "
f"Received: dtype={dtype} of type {type(dtype)}"
)
super().__init__(dtype=dtype)

def call(self, x):
Expand Down Expand Up @@ -174,9 +179,32 @@ def test_valid_naming(self):
OpWithMultipleOutputs(name="test/op")

def test_dtype(self):
# Test dtype argument
op = OpWithCustomDtype(dtype="bfloat16")
self.assertEqual(op._dtype_policy.name, "bfloat16")

policy = dtype_policies.DTypePolicy("mixed_bfloat16")
op = OpWithCustomDtype(dtype=policy)
self.assertEqual(op._dtype_policy.name, "mixed_bfloat16")

# Test dtype config to ensure it remains unchanged
config = op.get_config()
copied_config = config.copy()
OpWithCustomDtype.from_config(config)
self.assertEqual(config, copied_config)

# Test floating dtype serialization
op = OpWithCustomDtype(dtype="mixed_bfloat16")
config = op.get_config()
self.assertEqual(config["dtype"], "mixed_bfloat16") # A plain string
revived_op = OpWithCustomDtype.from_config(config)
self.assertEqual(op._dtype_policy.name, revived_op._dtype_policy.name)

# Test quantized dtype serialization
policy = dtype_policies.QuantizedDTypePolicy("int8", "bfloat16")
op = OpWithCustomDtype(policy)
self.assertEqual(op._dtype_policy.name, "int8_from_bfloat16")
config = op.get_config() # A serialized config
self.assertEqual(config["dtype"], dtype_policies.serialize(policy))
revived_op = OpWithCustomDtype.from_config(config)
self.assertEqual(op._dtype_policy.name, revived_op._dtype_policy.name)

0 comments on commit 510d406

Please sign in to comment.