
[1/n] torchtune <> llama-stack integration skeleton #540

Merged: 21 commits merged into main on Dec 13, 2024
Conversation

@SLR722 (Contributor) commented on Nov 27, 2024:

Context

This is the first of a series of PRs that integrate torchtune with llama-stack as the meta-reference post-training implementation. For the MVP, we focus on single-device LoRA SFT.

Though this PR is still WIP, we want early feedback on the high-level design of the skeleton while we keep working on the details.

Scope

To limit the scope of this PR, we focus on the skeleton of the implementation.

What is included?

  • refine the post-training SFT APIs
  • skeleton of the supervised_fine_tune implementation. We verified that we can call the supervised_fine_tune API successfully from the llama-stack client SDK (client-side PR: post training CLI llama-stack-client-python#51)
  • a very basic single-device LoRA training recipe built on torchtune core components (a rough sketch follows this list)
  • parity check against the torchtune library and a post-training API unit test
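For orientation, the sketch below shows roughly what a single-device LoRA SFT step built on torchtune core components looks like. This is not the recipe code in this PR; the model builder, LoRA settings, and training-loop details are illustrative assumptions.

```python
# Hedged sketch of a single-device LoRA SFT step using torchtune core
# components; not the recipe in this PR. Model choice, LoRA target modules,
# and hyperparameters are illustrative.
import torch
from torchtune.models.llama3_2 import lora_llama3_2_3b
from torchtune.modules.peft import get_adapter_params, set_trainable_params

# Build a Llama 3.2 3B model with LoRA adapters on the attention projections.
model = lora_llama3_2_3b(
    lora_attn_modules=["q_proj", "v_proj"],  # illustrative target modules
    lora_rank=8,
    lora_alpha=16,
)

# Freeze base weights and train only the adapter parameters.
set_trainable_params(model, get_adapter_params(model))

optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=3e-4
)
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)

def train_step(tokens: torch.Tensor, labels: torch.Tensor) -> float:
    """One SFT step: forward pass, token-level cross-entropy, backprop."""
    logits = model(tokens)  # [batch, seq_len, vocab]
    loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```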

What is not included?

  • implementation of the other job-management and get-training-artifacts APIs (separate PR)
  • refactor of the meta-reference inference logic to support eval on the fine-tuned model (separate PR)
  • several pieces of functionality needed in the training recipe, such as logging and validation (separate PR)
  • interop with telemetry for tracing and metrics logging; for now we temporarily log to local disk (separate PR)

Testing

E2E test
Although we haven't added detailed tests or a numerical parity check with torchtune yet, we ran a simple E2E test from client to server:

  1. Set up the server with `llama stack build --template experimental-post-training --image-type conda` and `llama stack run experimental-post-training`.
  2. On the client, run `llama-stack-client --endpoint http://devgpu018.nha2.facebook.com:5000 post_training supervised_fine_tune` (a Python SDK equivalent is sketched after this list).
  3. Training finishes successfully. On the server side, the fine-tuned checkpoints show up under the output dir; on the client side, we get back the job uuid.
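For reference, the same call can be made programmatically from the llama-stack client SDK. This is a hedged sketch: the method name follows the CLI subcommand above, but the exact parameter names and value shapes are assumptions, not the final API.

```python
# Hedged sketch of invoking the supervised_fine_tune API from the Python
# client SDK; parameter names and value shapes are illustrative assumptions.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")  # assumed local server

response = client.post_training.supervised_fine_tune(
    job_uuid="sft-job-0001",                       # assumed client-chosen job id
    model="Llama3.2-3B-Instruct",                  # model used in the parity check below
    algorithm_config={"type": "LoRA", "rank": 8},  # illustrative LoRA settings
    training_config={"n_epochs": 1},               # illustrative training knobs
)
print(response)  # expected to echo back a handle containing the job uuid
```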

server: [screenshot]

client: [screenshot]

parity check
The torchtune dataloader output and the llama-stack post-training dataloader output are the same. [screenshot]

torchtune LoRA SFT and llama-stack post-training LoRA SFT on the alpaca dataset with the Llama 3.2 3B Instruct model are a numerical match. [screenshots]

unit test: [screenshot]

@facebook-github-bot added the CLA Signed label on Nov 27, 2024
@SLR722 changed the title from "Post training v2" to "[WIP] torchtune <> llama-stack integration" on Nov 27, 2024
@SLR722 changed the title to "[1/n] torchtune <> llama-stack integration skeleton" on Dec 3, 2024
@raghotham (Contributor) commented: added some initial comments

@SLR722 marked this pull request as ready for review on December 3, 2024
Inline review comment on the dataset schema definition:

EXPECTED_DATASET_SCHEMA: Dict[str, List[Dict[str, ParamType]]] = {
    "alpaca": [

A reviewer (Contributor) asked:

A few questions:

  1. what do these three options mean?
  2. what does instruction mean? does it mean system_prompt?
  3. do you think we can use the types we have in the rest of our system -- for example, how is a dialog represented? We should be able to re-use the UserMessage, SystemMessage types we have in the rest of the system. Evals uses some of them.

@SLR722 (Author) replied on Dec 11, 2024:

> what do these three options mean?

The three options are the three eligible alpaca dataset schemas: the 'input' and 'text' columns are optional for the alpaca schema (see https://github.com/pytorch/torchtune/blob/9cfa28835246a4c1ac4449e703eae8f49227db55/torchtune/data/_messages.py#L696 and https://huggingface.co/datasets/tatsu-lab/alpaca?row=0). The sketch below spells out one plausible reading.
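A hedged sketch of those three variants; StringType and the exact import path are assumptions rather than the code in this PR.

```python
# One plausible spelling of the three eligible alpaca schemas described above:
# the full column set, then variants with the optional 'text' and 'input'
# columns dropped. StringType and the import path are assumptions.
from llama_stack.apis.common.type_system import StringType

EXPECTED_DATASET_SCHEMA = {
    "alpaca": [
        {  # full alpaca schema
            "instruction": StringType(),
            "input": StringType(),
            "output": StringType(),
            "text": StringType(),
        },
        {  # optional 'text' column omitted
            "instruction": StringType(),
            "input": StringType(),
            "output": StringType(),
        },
        {  # both optional columns omitted
            "instruction": StringType(),
            "output": StringType(),
        },
    ]
}
```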

> what does instruction mean? does it mean system_prompt?

instruction is different from system_prompt here. In the alpaca dataset, 'instruction' pairs with 'input' to form the user prompt (example: https://github.com/pytorch/torchtune/blob/9cfa28835246a4c1ac4449e703eae8f49227db55/torchtune/data/_messages.py#L696).

> do you think we can use the types we have in the rest of our system -- for example, how is a dialog represented? We should be able to re-use the UserMessage, SystemMessage types we have in the rest of the system. Evals uses some of them.

torchtune has its own Message definition in its data transforms (https://github.com/pytorch/torchtune/blob/9cfa28835246a4c1ac4449e703eae8f49227db55/torchtune/data/_messages.py#L724). I lean toward directly importing the torchtune data transform into the stack and reusing its Message type. For dataset schema validation, I'm following how eval does it:

async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
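For concreteness, a minimal sketch of what the analogous check could look like on the post-training side, assuming the implementation holds a datasets-API handle; the method, attribute, and helper names below are illustrative, not the code in this PR.

```python
# Hedged sketch of dataset-schema validation in the spirit of the eval method
# referenced above; the datasets-API call and attribute names are assumptions.
async def validate_input_dataset_schema(self, dataset_id: str) -> None:
    dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
    if not dataset_def.dataset_schema:
        raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")

    # Accept any of the eligible alpaca schema variants (see sketch above).
    if dataset_def.dataset_schema not in EXPECTED_DATASET_SCHEMA["alpaca"]:
        raise ValueError(
            f"Dataset {dataset_id} does not match any expected alpaca schema."
        )
```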

@ashwinb (Contributor) left a comment:

let's get this in!!!!

@ashwinb merged commit aeb7639 into main on Dec 13, 2024 (2 checks passed)
@ashwinb deleted the post_training_v2 branch on December 13, 2024