training UX: automatic generating make_train_step #8495

Merged: 11 commits from hanq_train into master on Jan 22, 2025
Conversation

@qihqi (Collaborator) commented Dec 17, 2024

This PR creates a function that can generate a train_step function from a model and an optimizer.

It also introduces a class that can wrap a ModuleList into a scan loop.

Finally, it changes examples/train_llama_titan to showcase the usage of these two new tools.
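For orientation, a rough usage sketch is shown below. The call follows the `make_train_step(model_fn, loss_fn, optax_optimizer, ...)` signature quoted later in this review; the import path and the placeholder `model_fn`/`loss_fn` bodies are assumptions, not the PR's actual example code.

```
import optax
from torch_xla2.train import make_train_step  # import path assumed

def model_fn(weights, buffers, args):
    # Functional forward pass: parameters are passed in explicitly so the
    # whole step can be traced and jitted.
    ...

def loss_fn(logits, labels):
    ...

# Bundle forward pass, loss, backward pass and optimizer update into a
# single step function.
train_step = make_train_step(model_fn, loss_fn, optax.sgd(0.01))
```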

@qihqi force-pushed the hanq_train branch 2 times, most recently from 7ebd346 to 0769c21 on December 18, 2024 20:05
@qihqi requested a review from tengyifei on December 21, 2024 00:07
@qihqi force-pushed the hanq_train branch 2 times, most recently from eb3acbd to 0b0f2b5 on December 23, 2024 18:21
@qihqi requested a review from lsy323 on December 23, 2024 18:22

```
train_step = interop.jax_jit(train_step, kwargs_for_jax_jit={'donate_argnums': (0, 2)})
```

Collaborator:

Does donate_argnums here imply that input buffers are donated to outputs? The (0, 2) is pretty cryptic to me. Consider commenting on their meaning.

Or better, maybe this could be handled internally? We could jit the function inside make_train_step.

qihqi (Collaborator Author):

You're right. The current issue is that sometimes I want to print out the StableHLO for inspection, so the jax_jit'd object also needs to store the underlying jax function. I'll follow up.
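For readers puzzled by the `(0, 2)`, the plain `jax.jit` equivalent of the `interop.jax_jit` call above looks roughly like this; the argument order here is hypothetical and the PR's actual `train_step` signature may differ.

```
import jax

def train_step(params, batch, opt_state):
    # ... compute loss, gradients and the optimizer update here ...
    return params, opt_state, 0.0  # new params, new opt state, loss

# donate_argnums=(0, 2) tells XLA it may reuse the buffers of arguments
# 0 and 2 (here: params and opt_state) for the outputs, so the largest
# arrays are updated without an extra copy on every step.
train_step = jax.jit(train_step, donate_argnums=(0, 2))
```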

```
WORKDIR /
RUN git clone https://github.com/pytorch/xla.git
WORKDIR xla/experimental/torch_xla2
RUN git checkout hanq_hybrid_mesh
```

Collaborator:

Is the hanq_hybrid_mesh branch intended?

qihqi (Collaborator Author):

removed.

```
return jax.make_array_from_single_device_arrays(shape, sharding, x_split)
```

When running on single-host, `jax.device_put` sufficies. Multi-host need some
Collaborator:

nit: suffices

qihqi (Collaborator Author):

done

When running on single-host, `jax.device_put` sufficies. Multi-host need some
extra encantations so that we split an array to only the shards corresponding

Collaborator:

nit: incantations

qihqi (Collaborator Author):

done.
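As background for the doc text quoted above, here is a minimal standalone JAX sketch of the single-host vs. multi-host difference. The mesh axis name and the even split across devices are assumptions for illustration.

```
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('data',))
sharding = NamedSharding(mesh, P('data'))
x = np.arange(32.0)

# Single host: device_put with a sharding suffices.
x_single = jax.device_put(x, sharding)

# Multi host: each process only holds its slice of the global array, so it
# places its local shards on its local devices and assembles the global
# array from them.
local_devices = jax.local_devices()
shards = np.split(x, len(local_devices))  # stand-in for the local shards
x_split = [jax.device_put(s, d) for s, d in zip(shards, local_devices)]
x_global = jax.make_array_from_single_device_arrays(x.shape, sharding, x_split)
```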

```
jax_optimizer = optax.sgd(0.01)
opt_state = torch_view(jax_optimizer.init(jax_view(jittable_mod.params)))

#opt_state = torch_xla2.interop.call_jax(jax_optimizer.init, jittable_mod.params)
```

Collaborator:

nit: obsolete code

qihqi (Collaborator Author):

done.

5. `interop.call_jax` API is used whenever we need something from Jax. Those API can be
wrapped and have the "jaxiness" hidden. However, I don't think we need to do such hidding.

6. Precompile: call to `helpers.precompile_step`. This is not needed. If not used, then
Collaborator:

nit: helpers.compile_step_func

qihqi (Collaborator Author):

done

```
jax_optimizer = optax.sgd(0.01)
opt_state = torch_view(jax_optimizer.init(jax_view(jittable_mod.params)))

#opt_state = torch_xla2.interop.call_jax(jax_optimizer.init, jittable_mod.params)
```

Collaborator:

obsolete code?

Collaborator:

ptal


```
def custom_attention(
    query, key, value, attn_mask=None,
    dropout_p=0.0, is_causal=False,
```

Collaborator:

What happens if dropout_p is not zero or when is_causal is False? Should we assert that their values match the behavior of splash_attention?

qihqi (Collaborator Author):

This bit is to showcase that you can register your own override of an op. The user writes this part themselves, so we assume they know which special cases apply to them.
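For illustration, one way the suggested guard could look (a sketch; whether the splash_attention kernel used in the example is strictly causal is an assumption):

```
def custom_attention(query, key, value, attn_mask=None,
                     dropout_p=0.0, is_causal=False):
    # Fail loudly on arguments this override would silently ignore.
    assert dropout_p == 0.0, 'this override does not implement dropout'
    assert is_causal, 'the splash_attention kernel is assumed to be causal'
    # ... dispatch to the splash_attention kernel as in the PR's example ...
```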

```
def make_train_step(model_fn,
                    loss_fn, optax_optimizer,
                    remat_policy=None,
                    mark_fsdp_sharding_axis=None):
```

Collaborator:

This seems fairly specific to the FSDP sharding scheme. What if my model input and output use different sharding schemes? What if I want my model output to be 2D sharded? What if they are PyTrees?

Instead, I wonder if we could separate the sharding concern from make_train_step. For example, what if we had

```
def shard_input(fn, in_shardings) -> fn
def shard_output(fn, out_shardings) -> fn

# Alternatively, just a `shard` function where `in_shardings` and `out_shardings`
# semantics matches what's in https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html
def shard(fn, in_shardings, out_shardings)
```

That could internally wrap any function (such as the model code) and then annotate the inputs or outputs with sharding annotations? Then the user could apply whatever sharding they want.
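A minimal sketch of what that wrapper could look like, using `jax.lax.with_sharding_constraint` under a surrounding jit (the names mirror the suggestion above; this is not an existing torch_xla2 API):

```
import jax

def shard(fn, in_shardings, out_shardings):
    # Annotate inputs and outputs with shardings; in_shardings/out_shardings
    # are pytrees of jax.sharding.Sharding matching fn's args and outputs.
    def wrapped(*args):
        args = jax.lax.with_sharding_constraint(args, in_shardings)
        out = fn(*args)
        return jax.lax.with_sharding_constraint(out, out_shardings)
    return wrapped
```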

qihqi (Collaborator Author):

moved sharding annotation to the client.

```
h, *rest = args
newh = torch.func.functional_call(self.c.one_mod, weight, args)
# next layer's input; and residual to be added to list
return (newh, *rest), torch.ones(1, device='jax')
```

Collaborator:

Could the torch.ones(1, ...) simply be None? None is the base case of a PyTree and is what I managed to get working in torch_xla's scan.

qihqi (Collaborator Author):

done.
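As background on why None works, a minimal standalone `jax.lax.scan` sketch (not the PR's torch-side wrapper) where the per-step output is None:

```
import jax
import jax.numpy as jnp

def body(carry, layer_weight):
    # Only the carried hidden state matters, so the per-step output is None.
    new_carry = jnp.tanh(carry @ layer_weight)
    return new_carry, None

h0 = jnp.ones((4,))
weights = jnp.stack([jnp.eye(4) for _ in range(3)])  # 3 stacked "layers"
final_h, ys = jax.lax.scan(body, h0, weights)
# ys is None because the body returned None each step.
```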

@qihqi force-pushed the hanq_train branch 2 times, most recently from dafa6b2 to 075cfe5 on January 16, 2025 18:39
@tengyifei (Collaborator) commented:

Wondering if I should review this again

@qihqi (Collaborator Author) commented Jan 17, 2025:

> Wondering if I should review this again

Yes, PTAL, thanks!

@qihqi requested a review from tengyifei on January 17, 2025 01:00
```
optax_optimizer: the optimizer from optax library. for example, optax.adam
remat_policy: One of jax.ad_checkpoint.checkpoint_policies, specifies how
  to do gradient checkpointing. If None, then it means checkpoint everything.
mark_fsdp_sharding_axis: str. A string name for marking sharding for
```

Collaborator:

nit: obsolete comment

@tengyifei merged commit f5b33c5 into master on Jan 22, 2025. 2 checks passed.