diff --git a/official/nlp/modeling/layers/block_sparse_attention.py b/official/nlp/modeling/layers/block_sparse_attention.py
index b6ac35e0d6..de45937300 100644
--- a/official/nlp/modeling/layers/block_sparse_attention.py
+++ b/official/nlp/modeling/layers/block_sparse_attention.py
@@ -84,6 +84,17 @@ def __init__(
             "sigmoid_attn_bias must be specified for sigmoid attn."
         )
 
+  def get_config(self):
+    config = super().get_config()
+    config.update({
+        "src_block_size": self._src_block_size,
+        "tgt_block_size": self._tgt_block_size,
+        "use_sigmoid_attn": self._use_sigmoid_attn,
+        "sigmoid_attn_bias": self._sigmoid_attn_bias,
+        "num_kv_heads": self._num_kv_heads,
+    })
+    return config
+
   def _build_from_signature(self, query, value, key=None):
     # pytype: disable=attribute-error
     super()._build_from_signature(query, value, key)
diff --git a/official/nlp/modeling/layers/multi_query_attention.py b/official/nlp/modeling/layers/multi_query_attention.py
index 3df880d3ed..8bfb12ec14 100644
--- a/official/nlp/modeling/layers/multi_query_attention.py
+++ b/official/nlp/modeling/layers/multi_query_attention.py
@@ -93,6 +93,11 @@ def __init__(self, num_kv_heads=None, **kwargs):
         self._num_heads % self._num_kv_heads == 0
     ), "num_kv_heads needs to divide num_heads exactly."
 
+  def get_config(self):
+    config = super().get_config()
+    config.update({"num_kv_heads": self._num_kv_heads})
+    return config
+
   def _build_from_signature(
       self,
       query: Union[tf.Tensor, tf.TensorShape],
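
The patch wires the layers' extra constructor arguments into Keras's `get_config`/`from_config` serialization protocol, so a layer can be rebuilt from its config alone (and thus survive `tf.keras.models.save_model`/`load_model`). A minimal round-trip sketch of the behavior this enables is below; it assumes the block-sparse layer is exposed as `block_sparse_attention.MultiHeadAttention` and accepts the `num_heads`/`key_dim` arguments of the Keras `MultiHeadAttention` base class plus the fields serialized above. Names and argument values here are illustrative, not taken from the patch.

```python
import tensorflow as tf

from official.nlp.modeling.layers import block_sparse_attention

# Hypothetical construction; argument values are placeholders.
layer = block_sparse_attention.MultiHeadAttention(
    num_heads=8,
    key_dim=64,
    src_block_size=4,
    tgt_block_size=4,
    num_kv_heads=2,
)

# With this patch, get_config() captures the subclass-specific arguments,
# so from_config() can reconstruct an equivalent layer.
config = layer.get_config()
restored = block_sparse_attention.MultiHeadAttention.from_config(config)

# The restored layer should carry the same sparsity and KV-head settings.
assert config["src_block_size"] == 4
assert config["num_kv_heads"] == 2
```

Before this change, `get_config()` fell through to the Keras base class and dropped `src_block_size`, `tgt_block_size`, `use_sigmoid_attn`, `sigmoid_attn_bias`, and `num_kv_heads`, so deserialization would silently rebuild the layer with default values for those fields.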