meta-llama: update llama-stack instruction to 6bcd1bd
Signed-off-by: Dmitry Rogozhkin <[email protected]>
Showing 2 changed files with 13 additions and 13 deletions.
```diff
@@ -1,4 +1,4 @@
-From 3295a112e6e40c1f9cf80374833a20ebad648848 Mon Sep 17 00:00:00 2001
+From cc788054276114390871e5172b1b1e360f14b365 Mon Sep 17 00:00:00 2001
 From: Dmitry Rogozhkin <[email protected]>
 Date: Mon, 18 Nov 2024 16:00:55 -0800
 Subject: [PATCH] feat: enable xpu support for meta-reference stack
@@ -10,10 +10,10 @@ Signed-off-by: Dmitry Rogozhkin <[email protected]>
   1 file changed, 26 insertions(+), 12 deletions(-)
 
 diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py
-index 38c9824..aec503c 100644
+index 080e33b..fbced7c 100644
 --- a/llama_stack/providers/inline/inference/meta_reference/generation.py
 +++ b/llama_stack/providers/inline/inference/meta_reference/generation.py
-@@ -89,7 +89,10 @@ class Llama:
+@@ -91,7 +91,10 @@ class Llama:
          llama_model = model.core_model_id.value
 
          if not torch.distributed.is_initialized():
@@ -25,7 +25,7 @@ index 38c9824..aec503c 100644
 
          model_parallel_size = config.model_parallel_size
 
-@@ -97,7 +100,14 @@ class Llama:
+@@ -99,7 +102,14 @@ class Llama:
              initialize_model_parallel(model_parallel_size)
 
          local_rank = int(os.environ.get("LOCAL_RANK", 0))
@@ -41,7 +41,7 @@ index 38c9824..aec503c 100644
 
          # seed must be the same in all processes
          if config.torch_seed is not None:
-@@ -175,19 +185,21 @@ class Llama:
+@@ -176,19 +186,21 @@ class Llama:
                  "Currently int4 and fp8 are the only supported quantization methods."
              )
          else:
@@ -62,13 +62,13 @@ index 38c9824..aec503c 100644
 
 +        model.to(device)
 +
-         print(f"Loaded in {time.time() - start_time:.2f} seconds")
+         log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
 -        return Llama(model, tokenizer, model_args, llama_model)
 +        return Llama(model, tokenizer, model_args, llama_model, device)
 
      def __init__(
          self,
-@@ -195,12 +207,14 @@ class Llama:
+@@ -196,12 +208,14 @@ class Llama:
          tokenizer: Tokenizer,
          args: ModelArgs,
          llama_model: str,
@@ -83,7 +83,7 @@ index 38c9824..aec503c 100644
 
      @torch.inference_mode()
      def generate(
-@@ -254,14 +268,14 @@ class Llama:
+@@ -253,14 +267,14 @@ class Llama:
          )
 
          pad_id = self.tokenizer.pad_id
@@ -101,7 +101,7 @@ index 38c9824..aec503c 100644
          input_text_mask = tokens != pad_id
          if min_prompt_len == total_len:
              # TODO(ashwin): unify this branch with the one below and figure out multimodal crap
-@@ -273,11 +287,11 @@ class Llama:
+@@ -272,11 +286,11 @@ class Llama:
              ignore_index=pad_id,
          )
```
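The patch being rebased here adds Intel XPU support to llama-stack's meta-reference inference provider: instead of assuming CUDA, `Llama.build()` picks a device, moves the model onto it (`model.to(device)` in the context above), and threads the device through the `Llama` constructor. The snippet below is a minimal sketch of that device-selection pattern, not the patch itself; helper names such as `pick_device` and `set_local_device` are illustrative and not taken from llama-stack.

```python
# Illustrative sketch only: the kind of device selection the xpu patch implies
# (prefer CUDA, fall back to Intel XPU, then CPU). Names are hypothetical.
import os

import torch


def pick_device() -> torch.device:
    """Return the best available accelerator device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.xpu exists in PyTorch builds with Intel XPU support.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    return torch.device("cpu")


def set_local_device(device: torch.device) -> torch.device:
    """Bind this process to the device indexed by LOCAL_RANK, as a
    model-parallel worker would."""
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if device.type == "cuda":
        torch.cuda.set_device(local_rank)
        return torch.device(f"cuda:{local_rank}")
    if device.type == "xpu":
        torch.xpu.set_device(local_rank)
        return torch.device(f"xpu:{local_rank}")
    return device


if __name__ == "__main__":
    device = set_local_device(pick_device())
    # A model built on CPU would then be moved onto the chosen device,
    # mirroring the `model.to(device)` line visible in the patch context.
    model = torch.nn.Linear(8, 8).to(device)
    print(f"running on {device}")
```

The distributed backend also has to match the device type (e.g. NCCL for CUDA versus a gloo/CCL backend for XPU), which is presumably what the additions after the `torch.distributed.is_initialized()` check in the patch handle.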