Skip to content

Commit

Permalink
feat: enable xpu support for meta-reference stack
Browse files (browse the repository at this point in the history)
Requires: meta-llama/llama-models#165
Signed-off-by: Dmitry Rogozhkin <[email protected]>
  • Loading branch information
dvrogozh committed Dec 2, 2024
1 parent 6bcd1bd commit cc78805
Showing 1 changed file with 26 additions and 12 deletions.
38 changes: 26 additions & 12 deletions llama_stack/providers/inline/inference/meta_reference/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,25 @@ def build(
llama_model = model.core_model_id.value

if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
if torch.cuda.is_available():
torch.distributed.init_process_group("nccl")
else:
torch.distributed.init_process_group("gloo")

model_parallel_size = config.model_parallel_size

if not model_parallel_is_initialized():
initialize_model_parallel(model_parallel_size)

local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
if torch.cuda.is_available():
device = 'cuda'
torch.cuda.set_device(local_rank)
elif torch.xpu.is_available():
device = 'xpu'
torch.xpu.set_device(local_rank)
else:
raise NotImplementedError("Devices other than CUDA and XPU are not supported yet")

# seed must be the same in all processes
if config.torch_seed is not None:
Expand Down Expand Up @@ -176,32 +186,36 @@ def build(
"Currently int4 and fp8 are the only supported quantization methods."
)
else:
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
if torch.cuda.is_bf16_supported() or torch.xpu.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
torch.set_default_dtype(torch.half)
if model_args.vision_chunk_size > 0:
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
else:
model = Transformer(model_args)
model = Transformer(model_args, device=device)
model.load_state_dict(state_dict, strict=False)

model.to(device)

log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer, model_args, llama_model)
return Llama(model, tokenizer, model_args, llama_model, device)

def __init__(
    self,
    model: Transformer,
    tokenizer: Tokenizer,
    args: ModelArgs,
    llama_model: str,
    device: "str | torch.device" = torch.device("cuda"),
):
    """Bundle a loaded model with its tokenizer, arguments and target device.

    Args:
        model: The already-constructed model to run generation with.
        tokenizer: Tokenizer paired with ``model``; also used to build the
            chat formatter.
        args: Model arguments the model was built with.
        llama_model: Identifier of the Llama model being served.
        device: Device on which generation-time tensors are allocated.
            Accepts either a ``torch.device`` or a device string such as
            ``"cuda"`` or ``"xpu"``. Defaults to CUDA.
    """
    self.args = args
    self.model = model
    self.tokenizer = tokenizer
    self.formatter = ChatFormat(tokenizer)
    self.llama_model = llama_model
    # Callers may pass a plain string ("cuda" / "xpu") while the default is
    # a torch.device; normalize so self.device always has one type.
    self.device = torch.device(device)

@torch.inference_mode()
def generate(
Expand Down Expand Up @@ -253,14 +267,14 @@ def generate(
)

pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=self.device)
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=self.device)
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

prev_pos = 0
eos_reached = torch.tensor([False] * bsz, device="cuda")
eos_reached = torch.tensor([False] * bsz, device=self.device)
input_text_mask = tokens != pad_id
if min_prompt_len == total_len:
# TODO(ashwin): unify this branch with the one below and figure out multimodal crap
Expand All @@ -272,11 +286,11 @@ def generate(
ignore_index=pad_id,
)

stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device=self.device)
for cur_pos in range(min_prompt_len, total_len):
if is_vision:
position_ids = torch.arange(
prev_pos, cur_pos, dtype=torch.long, device="cuda"
prev_pos, cur_pos, dtype=torch.long, device=self.device
)
logits = self.model.forward(
position_ids,
Expand Down

0 comments on commit cc78805

Please sign in to comment.