From 004cb079702bf7317bb5ba688211a8d40b86f24a Mon Sep 17 00:00:00 2001
From: David Fan
Date: Fri, 13 Dec 2024 16:57:03 +0000
Subject: [PATCH] Add CUDA support when loading local ONNX model

---
 src/turnkeyml/llm/tools/ort_genai/oga.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/turnkeyml/llm/tools/ort_genai/oga.py b/src/turnkeyml/llm/tools/ort_genai/oga.py
index de5a14a..633708c 100644
--- a/src/turnkeyml/llm/tools/ort_genai/oga.py
+++ b/src/turnkeyml/llm/tools/ort_genai/oga.py
@@ -35,7 +35,7 @@
 oga_model_builder_cache_path = "model_builder"
 
 # Mapping from processor to executiion provider, used in pathnames and by model_builder
-execution_providers = {"cpu": "cpu", "npu": "npu", "igpu": "dml"}
+execution_providers = {"cpu": "cpu", "npu": "npu", "igpu": "dml", "cuda": "cuda"}
 
 
 class OrtGenaiTokenizer(TokenizerAdapter):
@@ -248,7 +248,7 @@ def parser(add_help: bool = True) -> argparse.ArgumentParser:
         parser.add_argument(
             "-d",
             "--device",
-            choices=["igpu", "npu", "cpu"],
+            choices=["igpu", "npu", "cpu", "cuda"],
             default="igpu",
             help="Which device to load the model on to (default: igpu)",
         )
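
Note for reviewers (not part of the patch): below is a minimal sketch of how the extended execution_providers mapping could be exercised, based on the "used in pathnames and by model_builder" comment in oga.py. The build_model_dir helper and the cache-directory layout are illustrative assumptions for review purposes only, not code taken from the repository.

# Illustrative sketch only -- build_model_dir and the cache layout are
# assumptions, not code from oga.py.
import os

# Mapping as it stands after this patch: "cuda" now resolves to the
# "cuda" ONNX Runtime execution provider.
execution_providers = {"cpu": "cpu", "npu": "npu", "igpu": "dml", "cuda": "cuda"}

def build_model_dir(cache_dir: str, checkpoint: str, device: str) -> str:
    """Map a --device choice to its execution provider and build a cache path."""
    try:
        provider = execution_providers[device]
    except KeyError:
        raise ValueError(
            f"Unsupported device {device!r}; expected one of {sorted(execution_providers)}"
        )
    # The provider name becomes part of the model_builder cache pathname,
    # mirroring the "used in pathnames and by model_builder" comment.
    return os.path.join(cache_dir, "model_builder", checkpoint.replace("/", "_"), provider)

# Example: a CUDA build of a checkpoint lands in a cuda-specific directory.
print(build_model_dir("oga_models_cache", "microsoft/Phi-3-mini-4k-instruct", "cuda"))

Keeping CLI device names separate from provider names (e.g. igpu maps to dml) lets the --device interface stay stable while the underlying ONNX Runtime execution provider varies per platform; cuda happens to map to itself.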