NNX : Create a Custom Primitive layer that works with the jax.grad or flax.nnx.grad #4434
Unanswered
DiagRisker
asked this question in
Q&A
-
Hey @DiagRisker, for NNX Modules you have to use
-
Hey @DiagRisker, sorry for the slow reply. I looked into this and the issue is a bug in JAX where it doesn't call

def __call__(self, x: jax.Array) -> jax.Array:
    out = self.conv(x, self.kern.value)
    if self.use_bias:
        out += self.bias.value
    return out

Here's the full example using

import jax.numpy as jnp
from flax import nnx
import jax
from jax import lax
from functools import partial
from typing import Any, Callable, Sequence, Union
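

def Custuniform(**args):
    # Uniform sampler that also handles complex64 by drawing the real and
    # imaginary parts as two float32 samples.
    if args['dtype'] == jnp.complex64:
        args['dtype'] = jnp.float32
        # Split the key so the real and imaginary parts are independent draws
        # (reusing one key would make both parts identical).
        key_re, key_im = jax.random.split(args.pop('key'))
        return jax.random.uniform(key_re, **args) + 1j * jax.random.uniform(key_im, **args)
    # if dtype = "quaternion"
    return jax.random.uniform(**args)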


class ConvND(nnx.Module):  # (kernel_size, features, c_in )
    def __init__(
        self,
        #
        kernel_size: tuple,
        features: int = 1,
        c_in: int = 1,  # channels in
        *,
        stride: tuple[int] | int = (1,),
        # dilation options
        input_dilation: Union[None, int, Sequence[int]] = None,
        kernel_dilation: Union[None, int, Sequence[int]] = None,
        #
        precision=jax.lax.Precision('high'),
        use_bias: bool = False,
        rngs: nnx.Rngs,
        dtype=jnp.float32,  # dtype used to initialize the weights (not vector valued)
    ):
        self.kern = nnx.Param(
            Custuniform(
                key=rngs.params(),
                shape=tuple(kernel_size) + (c_in, features),
                dtype=dtype,
            )
        )
        # kernel = nnx.initializers.lecun_normal()(rngs.params(), (2, 3))
        self.use_bias = use_bias
        if use_bias:
            self.bias = nnx.Param(
                Custuniform(key=rngs.params(), shape=(features,), dtype=dtype)
            )
        # stride: broadcast a scalar (or a wrong-length tuple) to one entry per spatial dim
        if not hasattr(stride, '__iter__'):
            stride = (stride,) * len(kernel_size)
        elif len(stride) - len(kernel_size):
            stride = (stride[0],) * len(kernel_size)
        # enforcing the (Batch_size, *spatial_dim, channel_size) I/O convention
        incf = (0, len(kernel_size) + 1) + tuple(range(1, len(kernel_size) + 1))
        kercf = (len(kernel_size) + 1, len(kernel_size)) + tuple(
            range(0, len(kernel_size))
        )
        dimnum = jax.lax.ConvDimensionNumbers(incf, kercf, incf)
        # compiling the convolution with the given parameters
        self.conv = jax.jit(
            partial(
                jax.lax.conv_general_dilated,
                window_strides=stride,
                dimension_numbers=dimnum,
                padding='VALID',
                lhs_dilation=input_dilation,
                rhs_dilation=kernel_dilation,
                precision=precision,
            )
        )
        # the function will be compiled per instance of the class, which is
        # not perfect (depending on the jax/flax.nnx jit caching system)

    def __call__(self, x: jax.Array) -> jax.Array:
        # pass the raw array (.value), not the nnx.Param, to the jitted function
        out = self.conv(x, self.kern.value)
        if self.use_bias:
            out += self.bias.value
        return out


c_in = 5
c_out = 2
rngs = nnx.Rngs(0)
input = Custuniform(key=rngs(), shape=(60, 20, 20, 20, c_in), dtype=jnp.float32)
Kerns = (3, 3, 3)
# nnx.Conv(in_features=c_in, out_features=c_out, kernel_size=Kerns, padding='VALID', rngs=rngs)
model = ConvND(Kerns, c_out, c_in=c_in, rngs=nnx.Rngs(0))
y = model(input)


def loss(model):
    return ((model(input)) ** 2).mean()


grads = nnx.grad(loss)(model)
print(jax.tree.map(jnp.shape, grads))
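
For readers checking the layout logic in ConvND.__init__, here is a minimal sketch (not part of the original reply) that rebuilds the same index tuples for a 3D kernel and compares them with the string spec ('NDHWC', 'DHWIO', 'NDHWC'). The shapes and variable names below are assumptions chosen for illustration; only public jax.lax calls are used.

```python
import jax
import jax.numpy as jnp
from jax import lax

kernel_size = (3, 3, 3)
n = len(kernel_size)

# Same index arithmetic as in ConvND.__init__:
# lhs/out spec lists the positions of (batch, channel, *spatial),
# rhs spec lists the positions of (out_channel, in_channel, *spatial).
lhs_spec = (0, n + 1) + tuple(range(1, n + 1))
rhs_spec = (n + 1, n) + tuple(range(n))
dimnum = lax.ConvDimensionNumbers(lhs_spec, rhs_spec, lhs_spec)

# Illustrative shapes (assumed): batch 2, 8x8x8 volume, 5 channels in, 2 out.
key_x, key_k = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.uniform(key_x, (2, 8, 8, 8, 5))
kern = jax.random.uniform(key_k, kernel_size + (5, 2))

# Convolution with the explicit index-based dimension numbers.
out_idx = lax.conv_general_dilated(
    x, kern, window_strides=(1, 1, 1), padding='VALID',
    dimension_numbers=dimnum,
)
# Same convolution with the equivalent channels-last string spec.
out_str = lax.conv_general_dilated(
    x, kern, window_strides=(1, 1, 1), padding='VALID',
    dimension_numbers=('NDHWC', 'DHWIO', 'NDHWC'),
)
print(out_idx.shape, bool(jnp.allclose(out_idx, out_str, rtol=1e-5)))
```

If the two outputs agree, the index tuples describe a channels-last input with a kernel laid out as (*spatial, c_in, features), which is the shape used for self.kern.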
-
Hi,
As a PhD student, I've recently been investigating what I can do with flax.nnx modules.
The short-term goal was to implement a custom convolution layer, among others (I'll focus on this one for this question). Everything works in forward mode, but neither nnx.grad nor jax.grad works with it. Here is a generic example:
I used the Flax NNX documentation (and forums), notably: https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html
I also checked flax/nnx/nn/linear.py in the repository to make sure my implementation works the same way, without success for this specific need.
Would you have a suggestion (another source), or an idea of why this does not work? @cgarciae
The error message is tied to the Params in the Pytree attribute LG. It looks like this:
"""
tangent = Traced<ShapedArray(float32[3,3,3,5,2])>with<JaxprTrace(level=1/0)>
with pval =(ShapedArray(float32[3,3,3,5,2]), None)
recipe = LambdaBinding()
)' of type <class 'flax.nnx.variablelib.Param'> is not a valid JAX type.
"""
NB: I have a similar implementation in pure JAX that works with automatic differentiation, but there I can't use the nnx style of model construction (building a model in cascade by assigning sub-layers/blocks in a dict):
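
The pure JAX implementation mentioned here is not shown in the thread. As a rough sketch of what such a functional version might look like (an assumption, not the author's code), the kernel can live in a plain dict and be differentiated with jax.grad. The names init_params, conv3d_apply and the shapes are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


def init_params(key, kernel_size, c_in, features):
    # Parameters are a plain dict of arrays, which is already a valid pytree.
    return {'kern': jax.random.uniform(key, tuple(kernel_size) + (c_in, features))}


def conv3d_apply(params, x):
    # Channels-last 3D convolution, unit stride, no padding.
    return lax.conv_general_dilated(
        x, params['kern'], window_strides=(1, 1, 1), padding='VALID',
        dimension_numbers=('NDHWC', 'DHWIO', 'NDHWC'),
    )


def loss(params, x):
    return (conv3d_apply(params, x) ** 2).mean()


key_p, key_x = jax.random.split(jax.random.PRNGKey(0))
params = init_params(key_p, (3, 3, 3), c_in=5, features=2)
x = jax.random.uniform(key_x, (4, 8, 8, 8, 5))

# jax.grad differentiates with respect to the first argument (the dict).
grads = jax.grad(loss)(params, x)
print(jax.tree_util.tree_map(jnp.shape, grads))
```

Because the parameters are ordinary arrays, no unwrapping is needed in this style; the trade-off is that sub-layers have to be wired together by hand instead of being assigned as module attributes.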
Apologies for the long question!
Thanks in advance