Skip to content

Commit

Permalink
Add cuda support when loading local onnx model
Browse files Browse the repository at this point in the history
  • Loading branch information
jiafatom committed Dec 13, 2024
1 parent 63f11db commit 004cb07
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/turnkeyml/llm/tools/ort_genai/oga.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
oga_model_builder_cache_path = "model_builder"

# Mapping from processor to executiion provider, used in pathnames and by model_builder
execution_providers = {"cpu": "cpu", "npu": "npu", "igpu": "dml"}
execution_providers = {"cpu": "cpu", "npu": "npu", "igpu": "dml", "cuda": "cuda"}


class OrtGenaiTokenizer(TokenizerAdapter):
Expand Down Expand Up @@ -248,7 +248,7 @@ def parser(add_help: bool = True) -> argparse.ArgumentParser:
parser.add_argument(
"-d",
"--device",
choices=["igpu", "npu", "cpu"],
choices=["igpu", "npu", "cpu", "cuda"],
default="igpu",
help="Which device to load the model on to (default: igpu)",
)
Expand Down

0 comments on commit 004cb07

Please sign in to comment.