training UX: automatic generation of make_train_step #8495
Conversation
train_step = interop.jax_jit(train_step, kwargs_for_jax_jit={'donate_argnums': (0, 2)})
Does `donate_argnums` here imply that input buffers are donated to outputs? The `(0, 2)` is pretty cryptic to me. Consider commenting on their meaning. Or better, maybe this could be handled internally? We could jit the function inside `make_train_step`.
You're right. The current issue is that sometimes I want to print out the StableHLO for inspection, so the jax_jit'd object also needs to store the underlying jax function. I'll follow up.
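For context, here is a minimal, self-contained `jax.jit` sketch of what `donate_argnums` does. It is not this PR's actual `train_step` signature; the argument names below are illustrative.

```python
# Minimal sketch of donate_argnums, independent of this PR's train_step.
import jax
import jax.numpy as jnp

def step(params, batch, opt_state):
    new_params = params - 0.1 * batch.mean()
    new_opt_state = opt_state + 1.0
    return new_params, new_opt_state

# donate_argnums=(0, 2) lets XLA reuse the buffers of arguments 0 (params) and
# 2 (opt_state) for the outputs, saving memory; the caller must not read the
# donated inputs after the call.
step = jax.jit(step, donate_argnums=(0, 2))

params, opt_state = jnp.ones((4,)), jnp.zeros((4,))
new_params, new_opt_state = step(params, jnp.arange(4.0), opt_state)
```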
WORKDIR /
RUN git clone https://github.com/pytorch/xla.git
WORKDIR xla/experimental/torch_xla2
RUN git checkout hanq_hybrid_mesh
Is the `hanq_hybrid_mesh` branch intended?
removed.
return jax.make_array_from_single_device_arrays(shape, sharding, x_split)

When running on single-host, `jax.device_put` sufficies. Multi-host need some
nit: suffices
done
When running on single-host, `jax.device_put` sufficies. Multi-host need some
extra encantations so that we split an array to only the shards corresponding
nit: incantations
done.
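For readers following the quoted doc, here is a generic sketch of the pattern being described: on a single host `jax.device_put(x, sharding)` is enough, while on multiple hosts each process builds only the shards it owns and assembles a global array with `jax.make_array_from_single_device_arrays`. Names such as `global_shape` and `per_device_arrays` are illustrative, not from this PR.

```python
# Generic single-process illustration of the multi-host assembly pattern,
# assuming a 1D 'data' mesh over all local devices.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=('data',))
sharding = NamedSharding(mesh, P('data'))

global_shape = (len(jax.devices()) * 2, 4)
x = np.arange(np.prod(global_shape), dtype=np.float32).reshape(global_shape)

# Each addressable device gets the slice of the global array it owns.
per_device_arrays = [
    jax.device_put(x[index], device)
    for device, index in sharding.addressable_devices_indices_map(global_shape).items()
]
global_x = jax.make_array_from_single_device_arrays(
    global_shape, sharding, per_device_arrays)
```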
jax_optimizer = optax.sgd(0.01)
opt_state = torch_view(jax_optimizer.init(jax_view(jittable_mod.params)))

#opt_state = torch_xla2.interop.call_jax(jax_optimizer.init, jittable_mod.params)
nit: obsolete code
done.
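As background for the `opt_state` being built in the quoted snippet, here is a plain optax flow with no torch_xla2 interop; the parameter tree and gradients are illustrative stand-ins.

```python
# Plain optax flow showing what jax_optimizer.init / .update produce.
import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.ones((3,)), 'b': jnp.zeros(())}
jax_optimizer = optax.sgd(0.01)
opt_state = jax_optimizer.init(params)

grads = jax.tree_util.tree_map(jnp.ones_like, params)   # stand-in gradients
updates, opt_state = jax_optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```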
5. `interop.call_jax` API is used whenever we need something from Jax. Those API can be
wrapped and have the "jaxiness" hidden. However, I don't think we need to do such hidding.

6. Precompile: call to `helpers.precompile_step`. This is not needed. If not used, then
nit: helpers.compile_step_func
done
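In plain JAX terms, "precompiling" a step usually means lowering and compiling it ahead of time so the first real training iteration does not pay compilation latency. The sketch below shows that idea only; `helpers.compile_step_func` in this PR presumably does something analogous for the torch-view `train_step`.

```python
# Sketch of ahead-of-time compilation with jax.jit's lower/compile API.
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    return x * 2.0

sample = jnp.zeros((8,))                 # sample input with the real shape/dtype
compiled = step.lower(sample).compile()  # AOT compile for that signature
out = compiled(sample)                   # later calls skip tracing/compiling
```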
jax_optimizer = optax.sgd(0.01)
opt_state = torch_view(jax_optimizer.init(jax_view(jittable_mod.params)))

#opt_state = torch_xla2.interop.call_jax(jax_optimizer.init, jittable_mod.params)
obsolete code?
ptal
def custom_attention(
    query, key, value, attn_mask=None,
    dropout_p=0.0, is_causal=False,
What happens if `dropout_p` is not zero or when `is_causal` is False? Should we assert that their values match the behavior of `splash_attention`?
This bit is to showcase that you can register your own override of an op. The user is writing this bit, so we assume the user knows which special cases apply to them.
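A hedged sketch of the assertion the reviewer suggests for such a user-registered override. The body below is plain softmax attention standing in for the user's splash_attention kernel; the point is the asserts, which make the override fail loudly outside the configurations it supports.

```python
# Sketch only: softmax attention as a stand-in for splash_attention.
import jax
import jax.numpy as jnp

def custom_attention(query, key, value, attn_mask=None,
                     dropout_p=0.0, is_causal=False):
    assert dropout_p == 0.0, "this override assumes no dropout"
    assert attn_mask is None, "this override does not handle explicit masks"
    scale = query.shape[-1] ** -0.5
    scores = jnp.einsum('...qd,...kd->...qk', query, key) * scale
    if is_causal:
        q_len, k_len = scores.shape[-2], scores.shape[-1]
        causal_mask = jnp.tril(jnp.ones((q_len, k_len), dtype=bool))
        scores = jnp.where(causal_mask, scores, -jnp.inf)
    return jnp.einsum('...qk,...kd->...qd', jax.nn.softmax(scores, axis=-1), value)
```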
def make_train_step(model_fn,
                    loss_fn, optax_optimizer,
                    remat_policy=None,
                    mark_fsdp_sharding_axis=None):
This seems fairly specific to the FSDP sharding scheme. What if my model input and output use different sharding schemes? What if I want my model output to be 2D sharded? What if they are PyTrees?
Instead, I wonder if we could separate the sharding concern from `make_train_step`. For example, what if we had

def shard_input(fn, in_shardings) -> fn
def shard_output(fn, out_shardings) -> fn
# Alternatively, just a `shard` function where `in_shardings` and `out_shardings` semantics match what's in https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html
def shard(fn, in_shardings, out_shardings)

That could internally wrap any function (such as the model code) and then annotate the inputs or outputs with sharding annotations. Then the user could apply whatever sharding they want.
moved sharding annotation to the client.
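A hedged sketch of the reviewer's proposal, assuming the annotations are applied with `jax.lax.with_sharding_constraint`. `shard`, `in_shardings`, and `out_shardings` are illustrative names, not an API from this PR, and their tree structures must match the wrapped function's arguments and outputs.

```python
# Sketch: wrap a function with input/output sharding annotations, keeping
# sharding concerns out of make_train_step.
import jax

def shard(fn, in_shardings, out_shardings):
    def wrapped(*args):
        args = jax.tree_util.tree_map(
            jax.lax.with_sharding_constraint, args, in_shardings)
        out = fn(*args)
        return jax.tree_util.tree_map(
            jax.lax.with_sharding_constraint, out, out_shardings)
    # Typically jitted afterwards, e.g. jax.jit(shard(model_fn, ...)).
    return wrapped
```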
h, *rest = args
newh = torch.func.functional_call(self.c.one_mod, weight, args)
# next layer's input; and residual to be added to list
return (newh, *rest), torch.ones(1, device='jax')
Could the `torch.ones(1, ...)` simply be `None`? `None` is the base case of a PyTree and is what I managed to get working in `torch_xla`'s scan.
done.
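A minimal `jax.lax.scan` sketch of that point: `None` is a valid (empty) PyTree, so a scan body with no per-step output can return `(carry, None)` instead of a dummy tensor.

```python
# Scan body returning (carry, None): nothing is collected per step.
import jax
import jax.numpy as jnp

def body(carry, x):
    new_carry = carry + x
    return new_carry, None

final_carry, ys = jax.lax.scan(body, jnp.zeros(()), jnp.arange(5.0))
# final_carry == 10.0, ys is None
```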
Wondering if I should review this again
Yes, PTAL, thanks!
optax_optimizer: the optimizer from optax library. for example, optax.adam
remat_policy: One of jax.ad_checkpoint.checkpoint_policies, specifies how
  to do gradient checkpointing. If None, then it means checkpoint everything.
mark_fsdp_sharding_axis: str. A string name for marking sharding for
nit: obsolete comment
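For the `remat_policy` argument described in the quoted docstring, here is an illustrative use of such a policy in plain JAX; `layer` is a toy stand-in and the policy choice is only an example.

```python
# A policy from jax.checkpoint_policies controls which intermediates are kept
# versus recomputed in the backward pass.
import jax
import jax.numpy as jnp

def layer(x, w):
    return jnp.tanh(x @ w)

# Keep matmul results, recompute everything else during the backward pass.
ckpt_layer = jax.checkpoint(layer, policy=jax.checkpoint_policies.dots_saveable)

w = jnp.ones((4, 4))
grads = jax.grad(lambda x: ckpt_layer(x, w).sum())(jnp.ones((2, 4)))
```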
This PR creates a function that can generate a `train_step` function based on a model and an optimizer. It also introduces a class that can wrap a ModuleList into a scan loop. Then it changes `examples/train_llama_titan` to showcase usage of these two new tools.
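For readers skimming the thread, a hedged, pure-JAX sketch of what a generated `train_step` conceptually looks like. The actual implementation in this PR operates on torch-view tensors through torch_xla2; the names and argument order below are this sketch's own.

```python
# Conceptual sketch of a generated train_step: forward, loss, grad, optax update.
import jax
import optax

def make_train_step(model_fn, loss_fn, optax_optimizer):
    def train_step(params, opt_state, batch, labels):
        def compute_loss(p):
            return loss_fn(model_fn(p, batch), labels)
        loss, grads = jax.value_and_grad(compute_loss)(params)
        updates, new_opt_state = optax_optimizer.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_opt_state, loss
    # Donate params and opt_state buffers, in the spirit of the
    # donate_argnums=(0, 2) call discussed near the top of this conversation
    # (indices here follow this sketch's own argument order).
    return jax.jit(train_step, donate_argnums=(0, 1))
```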