
Commit

[efficient] Fix is_short_seq order so that we can also apply the feature transform and apply softmax afterwards.

PiperOrigin-RevId: 383967806
frederick0329 authored and tensorflower-gardener committed Jul 10, 2021
1 parent dbaec32 commit 5f23689
Showing 1 changed file with 30 additions and 27 deletions.
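For orientation, the reordered flow now runs scaling, the feature transform, and masking before branching on is_short_seq, so the softmax path no longer skips the feature transform. Below is a minimal sketch of that flow, assuming [B, T, N, H] queries, [B, S, N, H] keys, and a hypothetical `transform` callable standing in for `_TRANSFORM_MAP[feature_transform]`; it is illustrative, not the Model Garden implementation.

import math
import tensorflow as tf

def compute_attention_sketch(query, key, value, scale, transform,
                             attention_mask=None, is_short_seq=False):
  """Sketch of the new ordering: scale -> transform -> mask -> branch."""
  if is_short_seq:
    # Scale only the query; see the XLA note in the diff below.
    query = query * scale
  else:
    # Split the scale between key and query for the kernel approximation.
    key *= math.sqrt(scale)
    query *= math.sqrt(scale)

  # The feature transform now runs for both paths; this is the point of the fix.
  key = transform(key)
  query = transform(query)

  if attention_mask is not None:
    key = tf.einsum("BSNH,BS->BSNH", key, attention_mask)

  if is_short_seq:
    scores = tf.nn.softmax(tf.einsum("BTNH,BSNH->BTSN", query, key), axis=2)
    return tf.einsum("BTSN,BSNH->BTNH", scores, value)
  kv = tf.einsum("BSNH,BSND->BNDH", key, value)
  denominator = 1.0 / (
      tf.einsum("BTNH,BNH->BTN", query, tf.reduce_sum(key, axis=1)) + 1e-6)
  return tf.einsum("BTNH,BNDH,BTN->BTND", query, kv, denominator)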
official/nlp/modeling/layers/kernel_attention.py: 57 changes (30 additions, 27 deletions)
@@ -134,11 +134,13 @@ def _generalized_kernel(x, projection_matrix, f, h):
     "exp":
         functools.partial(
             _generalized_kernel,
             # Avoid exp explosion by shifting.
-            f=lambda x: tf.math.exp(
-                x - tf.math.reduce_max(x, axis=[1, 2, 3], keepdims=True)),
-            h=lambda x: tf.math.exp(
-                -0.5 * tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32))),),
-    "identity": lambda x, projection_matrix, is_query: x
+            f=lambda x: tf.math.exp(x - tf.math.reduce_max(
+                x, axis=[1, 2, 3], keepdims=True)),
+            h=lambda x: tf.math.exp(-0.5 * tf.math.sqrt(
+                tf.cast(tf.shape(x)[-1], tf.float32))),
+        ),
+    "identity":
+        functools.partial(_generalized_kernel, f=lambda x: x, h=lambda x: 1)
 }
 # pylint: enable=g-long-lambda
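The max-shift inside the "exp" transform above exists because exponentiating raw features can overflow float32. Subtracting the per-example maximum keeps the features finite, and because the attention output is a ratio that is linear in those features, the shared exp(-max) factor cancels out. A small self-contained illustration of the overflow behaviour (toy values, not the library's `_generalized_kernel`):

import tensorflow as tf

x = tf.constant([[100.0, 90.0, 95.0]])  # large pre-activation values

naive = tf.math.exp(x)  # exp(100) exceeds float32 max, so this saturates to inf
shifted = tf.math.exp(x - tf.math.reduce_max(x, axis=-1, keepdims=True))

print(naive.numpy())    # [[inf inf inf]]
print(shifted.numpy())  # roughly [[1.0e+00 4.5e-05 6.7e-03]], all finite
# The shift rescales every feature by the same exp(-max) constant; that constant
# multiplies both the numerator and the denominator of the attention ratio, so
# the final output is unchanged (up to floating-point rounding).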

@@ -260,18 +262,6 @@ def _compute_attention(self,
     Returns:
       attention_output: Multi-headed outputs of attention computation.
     """
-    if is_short_seq:
-      # Note: Applying scalar multiply at the smaller end of einsum improves
-      # XLA performance, but may introduce slight numeric differences in
-      # the Transformer attention head.
-      query = query * self._scale
-      if attention_mask is not None:
-        key = tf.einsum("BSNH,BS->BSNH", key, attention_mask)
-      attention_scores = tf.einsum("BTNH,BSNH->BTSN", query, key)
-      attention_scores = tf.nn.softmax(attention_scores, axis=2)
-      attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
-      return attention_output
-
     projection_matrix = None
     if self._num_random_features > 0:
       if self._redraw and training:
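The XLA note in the removed block above (carried into the new is_short_seq branch below) relies on a scalar factor commuting with the einsum contraction: scaling the query first is mathematically equivalent to scaling the [B, T, S, N] score tensor afterwards, but it touches the smaller tensor when the head size is smaller than the key length. A quick check of that equivalence with toy shapes (illustrative only):

import tensorflow as tf

tf.random.set_seed(0)
q = tf.random.normal([2, 4, 2, 8])  # [B, T, N, H]
k = tf.random.normal([2, 6, 2, 8])  # [B, S, N, H]
scale = 1.0 / 8.0 ** 0.5

pre = tf.einsum("BTNH,BSNH->BTSN", q * scale, k)   # scales B*T*N*H elements
post = tf.einsum("BTNH,BSNH->BTSN", q, k) * scale  # scales B*T*S*N elements

# Equal up to floating-point rounding, which is the "slight numeric
# differences" the comment warns about.
print(tf.reduce_max(tf.abs(pre - post)).numpy())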
@@ -280,23 +270,36 @@
       else:
         projection_matrix = self._projection_matrix

-    # Note: we suspect spliting the scale to key, query yields smaller
-    # approximation variance when random projection is used.
-    # For simplicity, we also split when there's no random projection.
-    key *= math.sqrt(self._scale)
-    query *= math.sqrt(self._scale)
+    if is_short_seq:
+      # Note: Applying scalar multiply at the smaller end of einsum improves
+      # XLA performance, but may introduce slight numeric differences in
+      # the Transformer attention head.
+      query = query * self._scale
+    else:
+      # Note: we suspect spliting the scale to key, query yields smaller
+      # approximation variance when random projection is used.
+      # For simplicity, we also split when there's no random projection.
+      key *= math.sqrt(self._scale)
+      query *= math.sqrt(self._scale)

     key = _TRANSFORM_MAP[feature_transform](key, projection_matrix)
     query = _TRANSFORM_MAP[feature_transform](query, projection_matrix)

     if attention_mask is not None:
       key = tf.einsum("BSNH,BS->BSNH", key, attention_mask)

-    kv = tf.einsum("BSNH,BSND->BNDH", key, value)
-    denominator = 1.0 / (
-        tf.einsum("BTNH,BNH->BTN", query, tf.reduce_sum(key, axis=1)) +
-        _NUMERIC_STABLER)
-    return tf.einsum("BTNH,BNDH,BTN->BTND", query, kv, denominator)
+    if is_short_seq:
+      attention_scores = tf.einsum("BTNH,BSNH->BTSN", query, key)
+      attention_scores = tf.nn.softmax(attention_scores, axis=2)
+      attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
+    else:
+      kv = tf.einsum("BSNH,BSND->BNDH", key, value)
+      denominator = 1.0 / (
+          tf.einsum("BTNH,BNH->BTN", query, tf.reduce_sum(key, axis=1)) +
+          _NUMERIC_STABLER)
+      attention_output = tf.einsum(
+          "BTNH,BNDH,BTN->BTND", query, kv, denominator)
+    return attention_output

   def _build_from_signature(self, query, value, key=None):
     super()._build_from_signature(query=query, value=value, key=key)
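The kernel (else) branch in the hunk above is the linear-attention factorization: kv aggregates keys and values once, and the denominator normalizes by the summed key features, so the full [B, T, S, N] score tensor is never materialized. The sketch below checks numerically that this factorization matches the explicit quadratic computation; shapes are toy values and the epsilon stands in for _NUMERIC_STABLER (assumed to be a small constant):

import tensorflow as tf

tf.random.set_seed(1)
q = tf.math.softplus(tf.random.normal([2, 5, 2, 4]))  # [B, T, N, H], positive features
k = tf.math.softplus(tf.random.normal([2, 7, 2, 4]))  # [B, S, N, H], positive features
v = tf.random.normal([2, 7, 2, 3])                    # [B, S, N, D]
eps = 1e-6

# Factorized form from the diff: never builds the [B, T, S, N] scores.
kv = tf.einsum("BSNH,BSND->BNDH", k, v)
denominator = 1.0 / (
    tf.einsum("BTNH,BNH->BTN", q, tf.reduce_sum(k, axis=1)) + eps)
fast = tf.einsum("BTNH,BNDH,BTN->BTND", q, kv, denominator)

# Explicit quadratic form: build the scores, then normalize by their sum over S.
scores = tf.einsum("BTNH,BSNH->BTSN", q, k)
slow = tf.einsum("BTSN,BSND->BTND", scores, v) / (
    tf.reduce_sum(scores, axis=2)[..., tf.newaxis] + eps)

# Agreement up to floating-point rounding.
print(tf.reduce_max(tf.abs(fast - slow)).numpy())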
