Consolidate the stateless llama logic #729
base: main
Conversation
Why are we pulling this into @monorimet's branch? This should be fine standalone; @monorimet can rebase after merging it. That way we get good test coverage.
```
device_inputs = [
    ireert.asdevicearray(self.device, input_tensor)
]
if self.first_input:  # or not self.streaming_llm:
```
Commented-out streaming LLM code? Is this from the original?
I redid some of the logic to remove non-streaming since I thought our plan was to only support streaming, but I think that's not actually the case. I'll add the support back in.
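For reference, a minimal sketch of what the restored branch could look like, assuming exported entry points named `run_initialize` (full prompt) and `run_cached_initialize` (cached follow-up); the method name and entry-point names are guesses, not taken from this diff:

```python
def generate(self, input_tensor):
    device_inputs = [ireert.asdevicearray(self.device, input_tensor)]
    if self.first_input or not self.streaming_llm:
        # Non-streaming (or first) call: build the KV cache from the full prompt.
        token = self.model["run_initialize"](*device_inputs)
        self.first_input = False
    else:
        # Streaming follow-up: append to the cached KV state instead.
        token = self.model["run_cached_initialize"](*device_inputs)
    return token
```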
```
@@ -3,6 +3,7 @@
import re
import json
from turbine_models.turbine_tank import turbine_tank
from pathlib import Path
```
Unused import?
```
@@ -489,26 +491,362 @@ def evict_kvcache_space(self):
        return blob_name, tokenizer


llm_model_map = {
    "meta-llama/Llama-2-7b-chat-hf": {
```
This doesn't support the larger models we care about, like 13B and 70B.
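If the map stays keyed on Hugging Face model IDs, the larger chat variants could slot in next to the 7B entry. A sketch (the entry bodies are placeholders, since the diff doesn't show the real fields):

```python
llm_model_map = {
    "meta-llama/Llama-2-7b-chat-hf": {
        # ... existing 7B settings ...
    },
    # The larger chat variants use the same Hugging Face naming scheme:
    "meta-llama/Llama-2-13b-chat-hf": {
        # same knobs as the 7B entry, with sizes adjusted
    },
    "meta-llama/Llama-2-70b-chat-hf": {
        # 70B uses grouped-query attention, so KV-cache shapes differ
    },
}
```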
```
@@ -489,26 +491,362 @@ def evict_kvcache_space(self):
        return blob_name, tokenizer


llm_model_map = {
```
This might belong in a separate config file to reduce clutter. It's also limiting to hard-code this without some default setup.
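As a sketch of that suggestion, the map could be loaded from a bundled JSON file with an optional override; the file name and schema here are assumptions:

```python
import json
from pathlib import Path

# Hypothetical bundled defaults shipped alongside the module.
_DEFAULT_CONFIG = Path(__file__).parent / "llm_model_configs.json"

def load_llm_model_map(config_path: str | None = None) -> dict:
    """Load per-model settings from JSON, falling back to bundled defaults."""
    path = Path(config_path) if config_path else _DEFAULT_CONFIG
    with open(path) as f:
        return json.load(f)

llm_model_map = load_llm_model_map()
```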
```
pipeline_dir: str | Path = "./shark_vmfbs",
external_weights_dir: str | Path = "./shark_weights",
external_weights: str = "safetensors",
custom_vae: str = None,
```
Remove custom_vae and the other unnecessary flags that look to have come over from the SD code (scheduler, etc.).
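A trimmed constructor might look something like this (which parameters survive is a guess):

```python
from pathlib import Path

class StatelessLlamaPipeline:
    def __init__(
        self,
        hf_model_name: str,
        pipeline_dir: str | Path = "./shark_vmfbs",
        external_weights: str = "safetensors",
    ):
        # No custom_vae, scheduler, or other SD-only options.
        self.hf_model_name = hf_model_name
        self.pipeline_dir = Path(pipeline_dir)
        self.external_weights = external_weights
```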
```
}


class StatelessLlamaPipeline:
```
Llama doesn't really have a pipeline; we might want to remove the pipeline references.
```
# FILE MANAGEMENT AND PIPELINE SETUP

def check_prepared(
```
The file management looks to be copied from the SD code. Can we combine the two to reduce repeated code?
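For example, both pipelines could call a small shared helper along these lines (module path and function name are hypothetical):

```python
# e.g. turbine_models/utils/artifacts.py -- hypothetical shared module
from pathlib import Path

def check_prepared_artifacts(pipeline_dir, names):
    """Return (found, missing) artifacts under pipeline_dir.

    Usable by both the SD and llama pipelines instead of two copies.
    """
    pipeline_dir = Path(pipeline_dir)
    found, missing = [], []
    for name in names:
        path = pipeline_dir / name
        if path.exists():
            found.append(path)
        else:
            missing.append(name)
    return found, missing
```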
```
# RUN

def chat(self, prompt):
```
Not used? It looks like this should be part of llm_runner.py; stateless_llama.py should just be for tracing, generating IR, and/or compiling vmfbs.
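In other words, the split might look like this; `export_transformer_model` appears in this PR, while `run_llm` is a hypothetical runner entry point:

```python
# stateless_llama.py -- export/compile only
def export_transformer_model(hf_model_name, **kwargs):
    """Trace the model, emit MLIR, and optionally compile a vmfb."""

# llm_runner.py -- runtime only
def run_llm(vmfb_path, prompt, device="local-task"):
    """Load the compiled vmfb and drive token generation; chat lives here."""
```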
Force-pushed from 5a9aaa0 to 16ee249.
.github/workflows/test_models.yml (outdated)
```
pytest -v models/turbine_models/tests/sdxl_test.py --device vulkan --rt_device vulkan --iree_target_triple rdna3-unknown-linux
pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16
```
Any SDXL-related changes should be moved to a different PR.
```
@@ -0,0 +1,169 @@
// Copyright 2024 The IREE Authors
```
This is only for argmax, right? We should also pull in all of the transform spec changes from the SDXL spec file that apply here.
```
##############################################################################

p.add_argument(
    "--seed", type=float, default=0, help="Seed for random number/latents generation."
```
Llama doesn't need a seed, does it?
```
    help="Path to location of vmfb files.",
)

p.add_argument(
```
These weight flags are unnecessary for llama. We're only using one external weight file, so we could remove external_weights_dir, and I don't think we need external_weight_file below either.
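A consolidated version could keep a single flag; the flag name here is an assumption:

```python
import argparse

p = argparse.ArgumentParser()
# One flag instead of external_weights_dir + external_weight_file:
p.add_argument(
    "--external_weight_path",
    type=str,
    default="",
    help="Path to the single external weight file (e.g. a .safetensors file).",
)
```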
```
@@ -8,7 +8,7 @@
from iree.compiler.ir import Context
```
Let's keep the separate model updates in separate patches. That makes it easier to track and revert patches if ever needed.
```
@@ -90,20 +90,42 @@ def test_vmfb_comparison(self):

    upload_ir_var = os.environ.get("TURBINE_TANK_ACTION", "not_upload")

    blob_name = llama.export_transformer_model(
    # blob_name = llama.export_transformer_model(
```
Commented-out code?
(And accidentally undo some cleanup, oops)
Force-pushed from 0bfa2d9 to 6db0f19.
It's producing vmfbs now; it just needs some more cleanup, plus vmfb runner logic if we want to do that here.
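For the runner piece, a minimal sketch of loading and invoking a compiled vmfb with the IREE runtime; the vmfb path, module name, and entry-point name are assumptions:

```python
import numpy as np
import iree.runtime as ireert

# Assumed output path for illustration only.
VMFB_PATH = "shark_vmfbs/stateless_llama.vmfb"

config = ireert.Config("local-task")  # CPU driver; swap for "vulkan", "rocm", etc.
vm_module = ireert.VmModule.mmap(config.vm_instance, VMFB_PATH)
ctx = ireert.SystemContext(vm_modules=[vm_module], config=config)

# The exported entry-point name depends on how the model was traced;
# "run_initialize" is a guess based on the export code in this PR.
main = ctx.modules.module["run_initialize"]
input_ids = np.zeros((1, 16), dtype=np.int64)  # dummy token ids
result = main(ireert.asdevicearray(config.device, input_ids))
print(result.to_host())
```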