Skip to content

Commit

Permalink
[ROCm] Implement RNN support
Browse files Browse the repository at this point in the history
  • Loading branch information
Ruturaj4 committed Jan 10, 2025
1 parent a7f384c commit 1b871e2
Show file tree
Hide file tree
Showing 7 changed files with 424 additions and 128 deletions.
32 changes: 31 additions & 1 deletion jax/experimental/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,31 @@ def init_lstm_weight(rng: PRNGKeyArray, input_size: int, hidden_size: int,
return jax.random.uniform(
rng, shape=(param_count,), dtype=jnp.float32, minval=-k, maxval=k)

def swap_lstm_gates(weights, input_size, hidden_size, num_layers, bidirectional):
"""Swaps the weights for the input and output gates for an LSTM model."""
weights = jnp.asarray(weights) # Ensure weights are JAX arrays
flat_shapes = _get_params_shapes_in_lstm(input_size, hidden_size, num_layers, bidirectional)
num_directions = 2 if bidirectional else 1

w_offsets = 0
for l in range(num_layers):
for direction in range(num_directions):
# Iterate through all weight and bias gate names to swap gates in both weights and biases
for gate_name in ["W_ih", "W_hh", "b_ih", "b_hh"]:
shape = flat_shapes.pop(0) # Get the current shape and remove it from the list
num_elems = math.prod(shape)
matrix = weights[w_offsets:w_offsets + num_elems].reshape(shape)

# Swap between the input and output gates (third and fourth gates)
gates = jnp.split(matrix, 4, axis=0)
swapped_matrix = jnp.concatenate([gates[0], gates[1], gates[3], gates[2]], axis=0)

# Update the weights with swapped matrix
weights = weights.at[w_offsets:w_offsets + num_elems].set(swapped_matrix.flatten())
w_offsets += num_elems

return weights


def unpack_lstm_weights(
weights: Array, input_size: int, hidden_size: int, num_layers: int,
Expand Down Expand Up @@ -437,7 +462,9 @@ def _gpu_lowering_strip_tf32(fn, *args, cudnn_allow_tf32, **kw):
rnn_fwd_p.def_impl(partial(xla.apply_primitive, rnn_fwd_p))
rnn_fwd_p.def_abstract_eval(rnn_abstract_eval)
if gpu_rnn:
mlir.register_lowering(rnn_fwd_p, gpu_rnn.cudnn_rnn_lowering, platform='cuda')
mlir.register_lowering(rnn_fwd_p, gpu_rnn.cudnn_rnn_fwd_lowering, platform='cuda')
if hasattr(gpu_rnn, "miopen_rnn_fwd_lowering"):
mlir.register_lowering(rnn_fwd_p, gpu_rnn.miopen_rnn_fwd_lowering, platform='rocm')


def lstm_bwd(input_size: int, hidden_size: int, num_layers: int, dropout: float,
Expand Down Expand Up @@ -481,5 +508,8 @@ def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval,
if gpu_rnn:
mlir.register_lowering(
rnn_bwd_p, gpu_rnn.cudnn_rnn_bwd_lowering, platform='cuda')
if hasattr(gpu_rnn, "miopen_rnn_fwd_lowering"):
mlir.register_lowering(
rnn_bwd_p, gpu_rnn.miopen_rnn_bwd_lowering, platform='rocm')

lstm.defvjp(lstm_fwd, lstm_bwd)
Loading

0 comments on commit 1b871e2

Please sign in to comment.