Use gpu autodetection on targets
cgerum committed Jan 10, 2024
1 parent 851767c commit ded16fc
Showing 3 changed files with 2 additions and 41 deletions.
14 changes: 1 addition & 13 deletions hannah/train.py
@@ -38,13 +38,7 @@
 from . import conf  # noqa
 from .callbacks.optimization import HydraOptCallback
 from .callbacks.prediction_logger import PredictionLogger
-from .utils import (
-    auto_select_gpus,
-    clear_outputs,
-    common_callbacks,
-    git_version,
-    log_execution_env_state,
-)
+from .utils import clear_outputs, common_callbacks, git_version, log_execution_env_state
 from .utils.dvclive import DVCLIVE_AVAILABLE, DVCLogger
 from .utils.logger import JSONLogger

@@ -81,12 +75,6 @@ def train(
     for seed in config.seed:
         seed_everything(seed, workers=True)
 
-        if isinstance(config.trainer.devices, int) and config.trainer.accelerator in [
-            "gpu",
-            "auto",
-        ]:
-            config.trainer.devices = auto_select_gpus(config.trainer.devices)
-
         if not config.trainer.fast_dev_run and not config.get("resume", False):
             clear_outputs()

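With this branch removed, hannah no longer pre-computes GPU indices before building the trainer; device placement is presumably left to Lightning's own auto-detection. A minimal sketch of that behaviour, assuming a plain PyTorch Lightning Trainer (the import path and trainer arguments below are illustrative, not values taken from the hannah configuration):

```python
# Minimal sketch, assuming PyTorch Lightning's built-in device auto-detection
# replaces the removed auto_select_gpus() pre-selection.
# The import path and trainer arguments are illustrative assumptions,
# not values read from the hannah configuration.
import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="auto",  # Lightning picks GPU, CPU, or another accelerator itself
    devices="auto",      # Lightning decides which and how many devices to use
)
```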
4 changes: 1 addition & 3 deletions hannah/utils/__init__.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2023 Hannah contributors.
+# Copyright (c) 2024 Hannah contributors.
 #
 # This file is part of hannah.
 # See https://github.com/ekut-es/hannah for further info.
@@ -18,7 +18,6 @@
 #
 from .imports import lazy_import
 from .utils import (
-    auto_select_gpus,
     clear_outputs,
     common_callbacks,
     extract_from_download_cache,
@@ -33,7 +32,6 @@
     "log_execution_env_state",
     "list_all_files",
     "extract_from_download_cache",
-    "auto_select_gpus",
     "common_callbacks",
     "clear_outputs",
     "fullname",
25 changes: 0 additions & 25 deletions hannah/utils/utils.py
@@ -213,31 +213,6 @@ def extract_from_download_cache(
     )
 
 
-def auto_select_gpus(gpus=1) -> Union[List[int]]:
-    if not torch.cuda.is_available() or torch.cuda.device_count() < 1:
-        return gpus
-
-    num_gpus = gpus
-
-    if not nvsmi.is_nvidia_smi_on_path():
-        return list(range(num_gpus))
-
-    gpus = list(nvsmi.get_gpus())
-
-    gpus = list(
-        sorted(gpus, key=lambda gpu: (gpu.mem_free, 1.0 - gpu.gpu_util), reverse=True)
-    )
-
-    job_num = hydra.core.hydra_config.HydraConfig.get().job.get("num", 0)
-
-    result = []
-    for i in range(num_gpus):
-        num = (i + job_num) % len(gpus)
-        result.append(int(gpus[num].id))
-
-    return result
-
-
 def common_callbacks(config: DictConfig) -> list:
     callbacks: List[Callback] = []
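For context, the deleted helper ranked the visible GPUs by free memory and low utilization (queried through nvsmi) and rotated the choice by the Hydra job number so that parallel multirun jobs would spread across devices. A rough, self-contained illustration of that selection policy (the GPU statistics below are fabricated example data, and pick_gpus is a hypothetical stand-in, not hannah code):

```python
# Rough illustration of the selection policy of the removed auto_select_gpus():
# rank GPUs by free memory and low utilization, then rotate by the job index
# so that concurrent jobs land on different devices.
# gpu_stats is fabricated example data, not the output of nvsmi or nvidia-smi.
gpu_stats = [
    {"id": 0, "mem_free": 2_000, "gpu_util": 0.9},
    {"id": 1, "mem_free": 10_000, "gpu_util": 0.1},
    {"id": 2, "mem_free": 8_000, "gpu_util": 0.4},
]

def pick_gpus(stats, num_gpus=1, job_num=0):
    ranked = sorted(
        stats,
        key=lambda g: (g["mem_free"], 1.0 - g["gpu_util"]),
        reverse=True,
    )
    return [ranked[(i + job_num) % len(ranked)]["id"] for i in range(num_gpus)]

print(pick_gpus(gpu_stats, num_gpus=2, job_num=0))  # -> [1, 2]
```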
