Merge branch 'f/h100_tuning' into 'main'
Some tuning knobs for training on H100

See merge request es/ai/hannah/hannah!368
cgerum committed Jan 30, 2024
2 parents f23247d + 1219bc7 commit d015464
Showing 10 changed files with 35 additions and 15 deletions.

experiments/cifar10/config.yaml (5 changes: 3 additions & 2 deletions)
@@ -15,10 +15,11 @@ monitor:
   direction: maximize

 module:
-  batch_size: 512
+  batch_size: 2048

 trainer:
-  max_epochs: 30
+  max_epochs: 50
+  precision: 16

 optimizer:
   lr: 0.3
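
The batch size quadruples from 512 to 2048 and training switches to 16-bit mixed precision; the memory freed by fp16 activations is what makes the larger batch fit. A minimal sketch of what these trainer keys correspond to in PyTorch Lightning (illustrative only; hannah builds its Trainer from the Hydra config, so the exact construction may differ):

    import pytorch_lightning as pl

    # precision=16 enables mixed-precision (fp16) training, which roughly
    # halves activation memory and lets matmuls run on the H100's tensor cores
    trainer = pl.Trainer(
        max_epochs=50,
        precision=16,
        devices=1,
    )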

experiments/cifar10/experiment/sweep_lr.yaml (3 changes: 2 additions & 1 deletion)
@@ -3,7 +3,8 @@ experiment_id: sweep_lr
 hydra:
   mode: MULTIRUN
   sweep:
-    subdir: lr=${scheduler.max_lr}
+    subdir: ${model.name}/lr=${scheduler.max_lr}
   sweeper:
     params:
       scheduler.max_lr: 0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9
+      model: timm_resnet18,timm_mobilenetv3_small_075,timm_mobilenetv3_small_100,kakao_resnet8
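
With both parameters swept, Hydra's MULTIRUN mode launches the full cartesian product (9 learning rates x 4 models = 36 runs), and the extended subdir pattern keeps each model's runs in its own directory, e.g. timm_resnet18/lr=0.3 next to kakao_resnet8/lr=0.3.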

experiments/cifar10/experiment/sweep_models.yaml (2 changes: 1 addition & 1 deletion)
@@ -6,4 +6,4 @@ hydra:
     subdir: ${model.name}
   sweeper:
     params:
-      model: timm_resnet18,timm_mobilenetv3_small_075,timm_mobilenetv3_small_100
+      model: timm_resnet18,timm_mobilenetv3_small_075,timm_mobilenetv3_small_100,kakao_resnet8

hannah/conf/nas/model_trainer/simple.yaml (3 changes: 2 additions & 1 deletion)
@@ -1 +1,2 @@
-_target_: hannah.nas.search.model_trainer.simple_model_trainer.SimpleModelTrainer
+_target_: hannah.nas.search.model_trainer.simple_model_trainer.SimpleModelTrainer
+per_process_memory_fraction: null
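
per_process_memory_fraction caps how much of a device's memory the CUDA caching allocator may hand to one process. The null default leaves the allocator unlimited, while a value such as 0.5 would let two NAS trainer processes share a single H100; see the sketch after the simple_model_trainer.py diff below.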

hannah/modules/augmentation/batch_augmentation.py (1 change: 0 additions & 1 deletion)
@@ -54,7 +54,6 @@ def forward(self, x) -> torch.Tensor:
         Returns:
             Tuple[torch.Tensor, torch.Tensor]; Batch augmented with `replica` different random augmentations
         """
-
         result = self.transforms(x)

         return result

hannah/modules/base.py (5 changes: 5 additions & 0 deletions)
@@ -70,6 +70,7 @@ def __init__(
         shuffle_all_dataloaders: bool = False,
         augmentation: Optional[DictConfig] = None,
         pseudo_labeling: Optional[DictConfig] = None,
+        log_images: bool = False,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -111,6 +112,8 @@ def __init__(

         self.loss_weights = None

+        self._log_images = log_images
+
     @abstractmethod
     def prepare_data(self) -> Any:
         # get all the necessary data stuff
@@ -408,6 +411,8 @@ def _plot_confusion_matrix(self) -> None:
         )

     def _log_batch_images(self, name: str, batch_idx: int, data: torch.tensor):
+        if not self._log_images:
+            return
         loggers = self._logger_iterator()
         for logger in loggers:
             if hasattr(logger.experiment, "add_image"):
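
log_images defaults to False, so per-batch image logging becomes opt-in rather than always-on; pushing every batch to the experiment loggers is costly at batch size 2048. Re-enabling it should only require setting log_images: true in the module config, since the constructor keyword is populated from that config group the same way batch_size is.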

hannah/modules/vision/base.py (19 changes: 14 additions & 5 deletions)
@@ -286,7 +286,10 @@ def setup_augmentations(self, pipeline_configs):

         self.default_augmentation = torch.nn.Sequential(*default_augment)
         augmentations = {k: torch.nn.Sequential(*v) for k, v in augmentations.items()}
-        self.augmentations = torch.nn.ModuleDict(augmentations)
+
+        return augmentations
+
+        # self.augmentations = torch.nn.ModuleDict(augmentations)

     def _get_dataloader(self, dataset, unlabeled_data=None, shuffle=False):
         batch_size = self.hparams["batch_size"]
@@ -310,29 +313,35 @@ def calc_workers(dataset):
                 else dataset.max_workers
             )
             return result

+        num_workers = calc_workers(dataset)
+
         loader = data.DataLoader(
             dataset,
             batch_size=batch_size,
             drop_last=True,
-            num_workers=calc_workers(dataset),
+            num_workers=num_workers,
             sampler=sampler if not dataset.sequential else None,
             collate_fn=vision_collate_fn,
-            multiprocessing_context="fork" if self.hparams["num_workers"] > 0 else None,
+            multiprocessing_context="fork" if num_workers > 0 else None,
+            persistent_workers = True if num_workers > 0 else False,
+            prefetch_factor = 2 if num_workers > 0 else None,
             pin_memory=True,
         )
         self.batches_per_epoch = len(loader)

         if unlabeled_data:
+            unlabeled_workers = calc_workers(unlabeled_data)
             loader_unlabeled = data.DataLoader(
                 unlabeled_data,
                 batch_size=batch_size,
                 drop_last=True,
-                num_workers=calc_workers(unlabeled_data),
+                num_workers=unlabeled_workers,
                 sampler=data.RandomSampler(unlabeled_data)
                 if not unlabeled_data.sequential
                 else None,
                 multiprocessing_context="fork"
-                if self.hparams["num_workers"] > 0
+                if unlabeled_workers > 0
                 else None,
             )
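
Besides setup_augmentations now returning the augmentation dict instead of registering it as a ModuleDict (the old registration is left behind as a comment), three things change in the dataloaders: the worker count is computed once and reused (the old multiprocessing_context check read self.hparams["num_workers"], which could disagree with what calc_workers actually returned), workers now persist across epochs, and each worker prefetches batches ahead of the GPU. A self-contained sketch of the pattern, with a stand-in dataset and hypothetical sizes:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # stand-in dataset; hannah passes its own vision datasets here
    dataset = TensorDataset(
        torch.randn(1024, 3, 32, 32), torch.randint(0, 10, (1024,))
    )
    num_workers = 4  # hannah derives this per dataset via calc_workers()

    loader = DataLoader(
        dataset,
        batch_size=256,
        drop_last=True,
        num_workers=num_workers,
        # keep worker processes alive between epochs instead of re-forking them
        persistent_workers=num_workers > 0,
        # each worker keeps 2 batches ready ahead of consumption;
        # prefetch_factor must be None when num_workers == 0
        prefetch_factor=2 if num_workers > 0 else None,
        # page-locked host buffers make host-to-GPU copies faster
        pin_memory=True,
    )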

hannah/modules/vision/image_classifier.py (3 changes: 1 addition & 2 deletions)
@@ -35,7 +35,6 @@

 from hannah.datasets.collate import vision_collate_fn

-from ..augmentation.batch_augmentation import BatchAugmentationPipeline
 from ..metrics import Error
 from .base import VisionBaseModule
 from .loss import SemiSupervisedLoss
@@ -144,7 +143,7 @@ def test_step(self, batch, batch_idx):
         if y is not None and preds is not None:
             self.test_confusion(preds, y)

-        if y is not None and step_results.logits.numel() > 0:
+        if y is not None and hasattr(step_results, 'logits') and step_results.logits.numel() > 0:
             probs = torch.softmax(step_results.logits, dim=1)
             self.test_roc(probs, y)
             self.test_pr_curve(probs, y)
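
The hasattr guard lets test_step tolerate step results that carry no logits field at all, where previously only an empty logits tensor was handled; the removed BatchAugmentationPipeline import was apparently unused.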

hannah/nas/search/model_trainer/simple_model_trainer.py (7 changes: 5 additions & 2 deletions)
@@ -40,8 +40,8 @@


 class SimpleModelTrainer:
-    def __init__(self) -> None:
-        pass
+    def __init__(self, per_process_memory_fraction = None) -> None:
+        self.per_process_memory_fraction = per_process_memory_fraction

     def build_model(self, model, parameters):
         # model_instance = deepcopy(model)
@@ -134,4 +134,7 @@ def setup_devices(self, num, config, logger):
             )
             device = device % torch.cuda.device_count()

+        if self.per_process_memory_fraction:
+            torch.cuda.set_per_process_memory_fraction(self.per_process_memory_fraction, device=device)
+
         config.trainer.devices = [device]
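
This is where the new per_process_memory_fraction knob takes effect. A minimal standalone sketch of the underlying PyTorch call (the fraction and device index are illustrative):

    import torch

    if torch.cuda.is_available():
        device = 0
        # Cap this process at 50% of the device's total memory: allocations
        # beyond the cap make the caching allocator raise a CUDA out-of-memory
        # error instead of starving other processes on the same GPU.
        torch.cuda.set_per_process_memory_fraction(0.5, device=device)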

hannah/train.py (2 changes: 2 additions & 0 deletions)
@@ -71,6 +71,8 @@ def train(
     validate_output = False
     if hasattr(config, "validate_output") and isinstance(config.validate_output, bool):
         validate_output = config.validate_output

+    torch.set_float32_matmul_precision('high')
+
     for seed in config.seed:
         seed_everything(seed, workers=True)
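
set_float32_matmul_precision('high') allows float32 matrix multiplications to run as TF32 on Ampere/Hopper tensor cores: inputs are rounded to a ~10-bit mantissa while accumulation stays in float32, which on an H100 can be several times faster than strict fp32 at a small accuracy cost. A quick illustration (requires a CUDA device):

    import torch

    torch.set_float32_matmul_precision("high")  # default is "highest" (strict fp32)

    a = torch.randn(4096, 4096, device="cuda")
    b = torch.randn(4096, 4096, device="cuda")
    c = a @ b  # dispatched to TF32 tensor-core kernels under "high"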
