Skip to content

Commit

Permalink
Config: Add experimental torch cuda malloc backend
Browse files Browse the repository at this point in the history
This option saves some VRAM, but does have the chance to error out.
Add this in the experimental config section.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Feb 15, 2024
1 parent 664e2c4 commit 949248f
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 1 deletion.
5 changes: 5 additions & 0 deletions common/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,8 @@ def add_developer_args(parser: argparse.ArgumentParser):
type=str_to_bool,
help="Disables API request streaming",
)
developer_group.add_argument(
    "--cuda-malloc-backend",
    type=str_to_bool,
    # Fixed copy-pasted help text: the previous string described
    # --disable-request-streaming, not this flag.
    help="Enables the experimental torch CUDA malloc backend "
    "(can save VRAM, but may error out)",
)
6 changes: 5 additions & 1 deletion config_sample.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ sampling:
# WARNING: Using this can result in a generation speed penalty
#override_preset:

# Options for development
# Options for development and experimentation
developer:
# Skips exllamav2 version check (default: False)
# It's highly recommended to update your dependencies rather than enabling this flag
Expand All @@ -46,6 +46,10 @@ developer:
# A kill switch for turning off SSE in the API server
#disable_request_streaming: False

# Enable the torch CUDA malloc backend (default: False)
# This can save a few MBs of VRAM, but has a risk of errors. Use at your own risk.
#cuda_malloc_backend: False

# Options for model overrides and loading
model:
# Overrides the directory to look for models (default: models)
Expand Down
6 changes: 6 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""The main tabbyAPI module. Contains the FastAPI server and endpoints."""
import os
import pathlib
import uvicorn
from asyncio import CancelledError
Expand Down Expand Up @@ -600,6 +601,11 @@ def entrypoint(args: Optional[dict] = None):
else:
check_exllama_version()

# Enable CUDA malloc backend
if unwrap(developer_config.get("cuda_malloc_backend"), False):
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"
logger.warning("Enabled the experimental CUDA malloc backend.")

network_config = get_network_config()

# Initialize auth keys
Expand Down

0 comments on commit 949248f

Please sign in to comment.