Skip to content

Commit

Permalink
Better gpu autoselection
Browse files Browse the repository at this point in the history
  • Loading branch information
cgerum committed Jan 10, 2024
1 parent a818c69 commit 851767c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
3 changes: 1 addition & 2 deletions hannah/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,7 @@ def train(
"gpu",
"auto",
]:
if torch.cuda.is_available():
config.trainer.devices = auto_select_gpus(config.trainer.devices)
config.trainer.devices = auto_select_gpus(config.trainer.devices)

if not config.trainer.fast_dev_run and not config.get("resume", False):
clear_outputs()
Expand Down
9 changes: 6 additions & 3 deletions hannah/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2023 Hannah contributors.
# Copyright (c) 2024 Hannah contributors.
#
# This file is part of hannah.
# See https://github.com/ekut-es/hannah for further info.
Expand Down Expand Up @@ -28,7 +28,7 @@
import time
from contextlib import _GeneratorContextManager, contextmanager
from pathlib import Path
from typing import Any, Callable, Iterator, List, Type, TypeVar
from typing import Any, Callable, Iterator, List, Type, TypeVar, Union

import hydra
import numpy as np
Expand Down Expand Up @@ -213,7 +213,10 @@ def extract_from_download_cache(
)


def auto_select_gpus(gpus=1) -> List[int]:
def auto_select_gpus(gpus=1) -> Union[List[int]]:
if not torch.cuda.is_available() or torch.cuda.device_count() < 1:
return gpus

num_gpus = gpus

if not nvsmi.is_nvidia_smi_on_path():
Expand Down

0 comments on commit 851767c

Please sign in to comment.