diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py index 080e33be0d..ecc8c9f3c6 100644 --- a/llama_stack/providers/inline/inference/meta_reference/generation.py +++ b/llama_stack/providers/inline/inference/meta_reference/generation.py @@ -91,7 +91,10 @@ def build( llama_model = model.core_model_id.value if not torch.distributed.is_initialized(): - torch.distributed.init_process_group("nccl") + if torch.cuda.is_available(): + torch.distributed.init_process_group("nccl") + else: + torch.distributed.init_process_group("gloo") model_parallel_size = config.model_parallel_size @@ -99,7 +102,14 @@ def build( initialize_model_parallel(model_parallel_size) local_rank = int(os.environ.get("LOCAL_RANK", 0)) - torch.cuda.set_device(local_rank) + if torch.cuda.is_available(): + device = "cuda" + torch.cuda.set_device(local_rank) + elif torch.xpu.is_available(): + device = "xpu" + torch.xpu.set_device(local_rank) + else: + raise NotImplementedError("Devices other than CUDA and XPU are not supported yet") # seed must be the same in all processes if config.torch_seed is not None: @@ -176,10 +186,17 @@ def build( "Currently int4 and fp8 are the only supported quantization methods." ) else: - if torch.cuda.is_bf16_supported(): - torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) - else: - torch.set_default_tensor_type(torch.cuda.HalfTensor) + if device == "cuda": + if torch.cuda.is_bf16_supported(): + torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) + else: + torch.set_default_tensor_type(torch.cuda.HalfTensor) + elif device == "xpu": + torch.set_default_device(device) + if torch.xpu.is_bf16_supported(): + torch.set_default_dtype(torch.bfloat16) + else: + torch.set_default_dtype(torch.half) if model_args.vision_chunk_size > 0: model = CrossAttentionTransformer(model_args) model.setup_cache(model_args.max_batch_size, torch.bfloat16) @@ -187,6 +204,8 @@ def build( model = Transformer(model_args) model.load_state_dict(state_dict, strict=False) + model.to(device) + log.info(f"Loaded in {time.time() - start_time:.2f} seconds") return Llama(model, tokenizer, model_args, llama_model) @@ -195,7 +214,7 @@ def __init__( model: Transformer, tokenizer: Tokenizer, args: ModelArgs, - llama_model: str, + llama_model: str ): self.args = args self.model = model @@ -253,14 +272,14 @@ def generate( ) pad_id = self.tokenizer.pad_id - tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda") + tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long) for k, t in enumerate(prompt_tokens): - tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") + tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long) if logprobs: - token_logprobs = torch.zeros_like(tokens, dtype=torch.float) + token_logprobs = torch.zeros_like(tokens) prev_pos = 0 - eos_reached = torch.tensor([False] * bsz, device="cuda") + eos_reached = torch.tensor([False] * bsz) input_text_mask = tokens != pad_id if min_prompt_len == total_len: # TODO(ashwin): unify this branch with the one below and figure out multimodal crap @@ -272,11 +291,11 @@ def generate( ignore_index=pad_id, ) - stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda") + stop_tokens = torch.tensor(self.tokenizer.stop_tokens) for cur_pos in range(min_prompt_len, total_len): if is_vision: position_ids = torch.arange( - prev_pos, cur_pos, dtype=torch.long, device="cuda" + prev_pos, cur_pos, dtype=torch.long ) logits = self.model.forward( position_ids,