From c6231ac4c7c9ef509012d84c21eb53fa239c67f3 Mon Sep 17 00:00:00 2001
From: Dmitry Rogozhkin
Date: Thu, 19 Sep 2024 16:47:55 -0700
Subject: [PATCH] feat: enable pytorch xpu support for non-attention models

The XPU backend is available natively (without IPEX) in PyTorch starting
from PyTorch 2.4. This commit extends TGI to cover the case when the user
has XPU support through PyTorch 2.4 but does not have IPEX installed.
Models which don't require attention can work. For models that require
attention, more work is needed to provide an attention implementation.

Tested with the following models:
* teknium/OpenHermes-2.5-Mistral-7B
* bigscience/bloom-560m
* google/gemma-7b
* google/flan-t5-xxl

Signed-off-by: Dmitry Rogozhkin
---
 .../models/causal_lm.py   | 26 +++++++++++--------
 .../models/seq2seq_lm.py  | 25 +++++++++++-------
 .../utils/import_utils.py |  5 ++++
 3 files changed, 35 insertions(+), 21 deletions(-)

diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 28534d0f73b..de2c065160b 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -517,14 +517,13 @@ def __init__(
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = default_dtype if dtype is None else dtype
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            device = torch.device(f"xpu:{rank}")
+            dtype = default_dtype if dtype is None else dtype
         elif SYSTEM == "ipex":
-            if hasattr(torch, "xpu") and torch.xpu.is_available():
-                device = torch.device(f"xpu:{rank}")
-                dtype = default_dtype if dtype is None else dtype
-            else:
-                device = torch.device("cpu")
-                # Float16 doesn't exist on target.
-                dtype = torch.bfloat16 if dtype is None else dtype
+            device = torch.device("cpu")
+            # Float16 doesn't exist on target.
+            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
@@ -593,8 +592,14 @@ def fallback(
         if speculator:
             raise RuntimeError("Speculator decoding is not enabled for AutoModel")
 
+        device_count = 0
         if torch.cuda.is_available():
             device = torch.device("cuda")
+            device_count = torch.cuda.device_count()
+            dtype = torch.float16 if dtype is None else dtype
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            device = torch.device("xpu")
+            device_count = torch.xpu.device_count()
             dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
@@ -616,18 +621,17 @@ def fallback(
             torch_dtype=dtype,
             device_map=(
                 "auto"
-                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                if device_count > 1
                 else None
             ),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
         if (
-            torch.cuda.is_available()
-            and torch.cuda.device_count() == 1
+            device_count == 1
             and quantize != "bitsandbytes"
         ):
-            model = model.cuda()
+            model = model.to(device)
 
         if tokenizer.pad_token_id is None:
             if model.config.pad_token_id is not None:
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 04d4c28ba3e..94f87d02350 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -558,14 +558,13 @@ def __init__(
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = default_dtype if dtype is None else dtype
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            device = torch.device(f"xpu:{rank}")
+            dtype = default_dtype if dtype is None else dtype
         elif SYSTEM == "ipex":
-            if hasattr(torch, "xpu") and torch.xpu.is_available():
-                device = torch.device(f"xpu:{rank}")
-                dtype = default_dtype if dtype is None else dtype
-            else:
-                device = torch.device("cpu")
-                # Float16 doesn't exist on target.
-                dtype = torch.bfloat16 if dtype is None else dtype
+            device = torch.device("cpu")
+            # Float16 doesn't exist on target.
+            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
@@ -630,8 +629,14 @@ def fallback(
         if speculator:
             raise RuntimeError("Speculator decoding is not enabled for AutoModel")
 
+        device_count = 0
         if torch.cuda.is_available():
             device = torch.device("cuda")
+            device_count = torch.cuda.device_count()
+            dtype = torch.float16 if dtype is None else dtype
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            device = torch.device("xpu")
+            device_count = torch.xpu.device_count()
             dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
@@ -646,14 +651,14 @@ def fallback(
             torch_dtype=dtype,
             device_map=(
                 "auto"
-                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                if device_count > 1
                 else None
             ),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
-        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
-            model = model.cuda()
+        if device_count == 1:
+            model = model.to(device)
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py
index 782b4f15b46..b693258c84d 100644
--- a/server/text_generation_server/utils/import_utils.py
+++ b/server/text_generation_server/utils/import_utils.py
@@ -66,6 +66,11 @@ def noop(*args, **kwargs):
         empty_cache = noop
         synchronize = noop
         get_free_memory = get_cpu_free_memory
+elif hasattr(torch, "xpu") and torch.xpu.is_available():
+    SYSTEM = "xpu"
+    empty_cache = torch.xpu.empty_cache
+    synchronize = torch.xpu.synchronize
+    get_free_memory = get_xpu_free_memory
 else:
     SYSTEM = "cpu"
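
As an illustrative aside, not part of the patch itself: below is a minimal
standalone sketch of the device/dtype selection pattern this commit introduces,
assuming PyTorch >= 2.4 (where the native torch.xpu backend is available
without IPEX). The helper name select_device_and_dtype is hypothetical and does
not exist in TGI.

    # Hypothetical sketch of the selection logic added by this patch;
    # assumes PyTorch >= 2.4 (native torch.xpu, no IPEX required).
    from typing import Optional

    import torch


    def select_device_and_dtype(rank: int = 0, dtype: Optional[torch.dtype] = None):
        if torch.cuda.is_available():
            # CUDA path, unchanged by the patch.
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            # New path: native XPU support shipped with PyTorch 2.4+.
            device = torch.device(f"xpu:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            # CPU fallback keeps float32 as before.
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype
        return device, dtype


    if __name__ == "__main__":
        device, dtype = select_device_and_dtype()
        print(f"device={device}, dtype={dtype}")

On a host with neither CUDA nor XPU this prints the CPU/float32 fallback, which
mirrors the unchanged else branch in causal_lm.py and seq2seq_lm.py.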