Commit 339161b

Simplify

Signed-off-by: Joaquin Anton Guirao <[email protected]>

jantonguirao committed Jan 15, 2025
1 parent 249a90f commit 339161b
Showing 4 changed files with 67 additions and 70 deletions.
@@ -19,38 +19,38 @@

from nvidia.dali.auto_aug import auto_augment, trivial_augment

@pipeline_def(enable_conditionals=True, exec_dynamic=True)
def training_pipe(data_dir, interpolation, image_size, output_layout, automatic_augmentation,
dali_device="gpu", rank=0, world_size=1, send_filepaths=False):
rng = fn.random.coin_flip(probability=0.5)

def input_pipeline(data_dir, send_filepaths, random_shuffle=False, rank=0, world_size=1):
if data_dir is None:
if send_filepaths:
filepaths = fn.external_source(name="images", no_copy=True)
jpegs = fn.io.file.read(filepaths)
else:
jpegs = fn.external_source(name="images", no_copy=True)
return jpegs, None
else:
jpegs, labels = fn.readers.file(name="Reader", file_root=data_dir, shard_id=rank,
num_shards=world_size, random_shuffle=True, pad_last_batch=True)
num_shards=world_size, random_shuffle=random_shuffle,
pad_last_batch=True)
return jpegs, labels

if dali_device == "gpu":
decoder_device = "mixed"
resize_device = "gpu"
else:
decoder_device = "cpu"
resize_device = "cpu"
@pipeline_def(enable_conditionals=True, exec_dynamic=True)
def training_pipe(data_dir, interpolation, image_size, output_layout, automatic_augmentation,
dali_device="gpu", rank=0, world_size=1, send_filepaths=False):
jpegs, labels = input_pipeline(
data_dir, send_filepaths, random_shuffle=True, rank=rank, world_size=world_size)

decoder_device = "mixed" if dali_device == "gpu" else "cpu"
images = fn.decoders.image_random_crop(jpegs, device=decoder_device, output_type=types.RGB,
random_aspect_ratio=[0.75, 4.0 / 3.0],
random_area=[0.08, 1.0])

images = fn.resize(images, device=resize_device, size=[image_size, image_size],
images = fn.resize(images, size=[image_size, image_size],
interp_type=interpolation, antialias=False)

# Make sure that from this point on we are processing on the GPU, regardless of the dali_device parameter
images = images.gpu()

rng = fn.random.coin_flip(probability=0.5)
images = fn.flip(images, horizontal=rng)

# Based on the specification, apply the automatic augmentation policy. Note that from the point
@@ -70,35 +70,22 @@ def training_pipe(data_dir, interpolation, image_size, output_layout, automatic_
mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
std=[0.229 * 255, 0.224 * 255, 0.225 * 255])

if data_dir is None:
return output
else:
if labels:
return output, labels
else:
return output


@pipeline_def(exec_dynamic=True)
def validation_pipe(data_dir, interpolation, image_size, image_crop, output_layout,
dali_device="gpu", rank=0, world_size=1, send_filepaths=False):
if data_dir is None:
if send_filepaths:
filepaths = fn.external_source(name="images", no_copy=True)
jpegs = fn.io.file.read(filepaths)
else:
jpegs = fn.external_source(name="images", no_copy=True)
else:
jpegs, label = fn.readers.file(name="Reader", file_root=data_dir, shard_id=rank,
num_shards=world_size, random_shuffle=False, pad_last_batch=True)

if dali_device == "gpu":
decoder_device = "mixed"
resize_device = "gpu"
else:
decoder_device = "cpu"
resize_device = "cpu"
jpegs, labels = input_pipeline(
data_dir, send_filepaths, random_shuffle=False, rank=rank, world_size=world_size)

decoder_device = "mixed" if dali_device == "gpu" else "cpu"
images = fn.decoders.image(jpegs, device=decoder_device, output_type=types.RGB)

images = fn.resize(images, device=resize_device, resize_shorter=image_size,
images = fn.resize(images, resize_shorter=image_size,
interp_type=interpolation, antialias=False)

images = images.gpu()
@@ -108,7 +95,7 @@ def validation_pipe(data_dir, interpolation, image_size, image_crop, output_layo
mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
std=[0.229 * 255, 0.224 * 255, 0.225 * 255])

if data_dir is None:
return output
if labels:
return output, labels
else:
return output, label
return output
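
Note: with this change both training_pipe and validation_pipe delegate their input stage to the shared input_pipeline helper, and the separate resize_device bookkeeping is gone (fn.resize simply runs on whatever device its input lives on). A minimal sketch of how the refactored training pipeline could be built and run; the argument values below are illustrative assumptions, not taken from this commit:

    import nvidia.dali.types as types

    pipe = training_pipe(
        data_dir="/data/imagenet/train",    # hypothetical dataset path
        interpolation=types.INTERP_LINEAR,  # assumed; any DALI interp type
        image_size=224,
        output_layout="CHW",
        automatic_augmentation="autoaugment",
        dali_device="gpu",
        batch_size=128, num_threads=4, device_id=0,  # @pipeline_def kwargs
    )
    pipe.build()
    images, labels = pipe.run()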
@@ -164,11 +164,9 @@ def gdtl(
pipe, reader_name="Reader", fill_last_batch=False
)

dali_server = None
return (
DALIWrapper(train_loader, num_classes, one_hot, memory_format),
int(pipe.epoch_size("Reader") / (world_size * batch_size)),
dali_server,
)

return gdtl
@@ -220,12 +218,9 @@ def gdvl(
val_loader = DALIClassificationIterator(
pipe, reader_name="Reader", fill_last_batch=False
)

dali_server = None
return (
DALIWrapper(val_loader, num_classes, one_hot, memory_format),
int(pipe.epoch_size("Reader") / (world_size * batch_size)),
dali_server,
)

return gdvl
@@ -376,11 +371,9 @@ def get_pytorch_train_loader(
persistent_workers=True,
prefetch_factor=prefetch_factor,
)
dali_server = None
return (
PrefetchedWrapper(train_loader, start_epoch, num_classes, one_hot, True, memory_format, "CHW"),
len(train_loader),
dali_server,
)


@@ -433,11 +426,9 @@ def get_pytorch_val_loader(
persistent_workers=True,
prefetch_factor=prefetch_factor,
)
dali_server = None
return (
PrefetchedWrapper(val_loader, 0, num_classes, one_hot, True, memory_format, "CHW"),
len(val_loader),
dali_server
)

def read_file(path):
@@ -516,7 +507,6 @@ def get_impl(data_path,
return (
PrefetchedWrapper(train_loader, start_epoch, num_classes, one_hot, False, memory_format, output_layout),
len(train_loader),
dali_server,
)
return get_impl

@@ -590,7 +580,6 @@ def get_impl(data_path,
return (
PrefetchedWrapper(val_loader, 0, num_classes, one_hot, False, memory_format, output_layout),
len(val_loader),
dali_server,
)
return get_impl

@@ -641,7 +630,6 @@ def get_synthetic_loader(
memory_format=torch.contiguous_format,
**kwargs,
):
dali_server = None
return (
SynteticDataLoader(
batch_size,
@@ -653,5 +641,4 @@
memory_format=memory_format,
),
-1,
dali_server,
)
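
Note: every loader factory in this file (gdtl, gdvl, get_pytorch_train_loader, get_pytorch_val_loader, the get_impl variants, and get_synthetic_loader) now returns a two-element tuple; the dali_server slot is dropped, and any DALI server travels inside the loader object itself, to be rediscovered by type name (see the find_children_by_type_name helper and the main.py changes below). A hedged sketch of the new calling convention, with the remaining arguments elided:

    # before this commit:
    # train_loader, train_loader_len, dali_server = get_train_loader(...)
    # after:
    train_loader, train_loader_len = get_train_loader(...)
    val_loader, val_loader_len = get_val_loader(...)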
@@ -181,3 +181,34 @@ def calc_ips(batch_size, time):
)
tbs = world_size * batch_size
return tbs / time

def find_children_by_type_name(obj, target_type_name):
"""
Recursively find all child instances of a target type by its name.
:param obj: The object to inspect.
:param target_type_name: The name of the target type as a string.
:return: A list of instances of the target type found in the object.
"""
found = []

# Check if the object's type is a match
if type(obj).__name__ == target_type_name:
found.append(obj)

# If the object is a dictionary, check its values
elif isinstance(obj, dict):
for value in obj.values():
found.extend(find_children_by_type_name(value, target_type_name))

# If the object is a list, tuple, or set, check its elements
elif isinstance(obj, (list, tuple, set)):
for item in obj:
found.extend(find_children_by_type_name(item, target_type_name))

# If the object has a __dict__, check its attributes (for custom objects)
elif hasattr(obj, "__dict__"):
for attr_value in vars(obj).values():
found.extend(find_children_by_type_name(attr_value, target_type_name))

return found
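
A quick, self-contained illustration of how the new helper recurses through tuples, lists, dicts, and object attributes. The classes below are stand-ins invented for this example; matching is done by type name, so the real DALIServer class is not needed:

    class DALIServer:                 # stand-in: only the class name matters
        pass

    class LoaderWrapper:
        def __init__(self, server):
            self.server = server      # reached through vars(obj)

    loaders = (LoaderWrapper(DALIServer()), [LoaderWrapper(DALIServer()), 42])
    servers = find_children_by_type_name(loaders, "DALIServer")
    assert len(servers) == 2          # both nested servers are found

One caveat: the traversal has no cycle detection, so a self-referential object graph would recurse without bound.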
34 changes: 13 additions & 21 deletions docs/examples/use_cases/pytorch/efficientnet/main.py
@@ -35,8 +35,6 @@

import argparse
import random
from copy import deepcopy
from contextlib import nullcontext

import torch.backends.cudnn as cudnn
import torch.distributed as dist
@@ -61,7 +59,7 @@
)
from image_classification.gpu_affinity import set_affinity, AffinityMode
import dllogger

from contextlib import ExitStack

def available_models():
models = {m.name: m for m in [efficientnet_b0]}
@@ -513,7 +511,7 @@ def _worker_init_fn(id):
print("Bad databackend picked")
exit(1)

train_loader, train_loader_len, dali_server_train = get_train_loader(
train_loader, train_loader_len = get_train_loader(
args.data,
image_size,
args.batch_size,
@@ -530,7 +528,7 @@ def _worker_init_fn(id):
if args.mixup != 0.0:
train_loader = MixUpWrapper(args.mixup, train_loader)

val_loader, val_loader_len, dali_server_val = get_val_loader(
val_loader, val_loader_len = get_val_loader(
args.data,
image_size,
args.batch_size,
@@ -604,37 +602,29 @@ def _worker_init_fn(id):
lr_policy,
train_loader,
train_loader_len,
dali_server_train,
val_loader,
dali_server_val,
logger,
start_epoch,
best_prec1,
)


def conditional_with(resource):
if hasattr(resource, '__enter__') and hasattr(resource, '__exit__'):
return resource
return nullcontext(resource)


def main(args, model_args, model_arch):
exp_start_time = time.time()
(
trainer,
lr_policy,
train_loader,
train_loader_len,
dali_server_train,
val_loader,
dali_server_val,
logger,
start_epoch,
best_prec1,
) = prepare_for_training(args, model_args, model_arch)

with conditional_with(dali_server_train), conditional_with(dali_server_val):
with ExitStack() as stack:
for obj in find_children_by_type_name((train_loader, val_loader), "DALIServer"):

Check failure (Code scanning / CodeQL): Local variable 'train_loader' may be used before it is initialized.
Check failure (Code scanning / CodeQL): Local variable 'val_loader' may be used before it is initialized.
stack.enter_context(obj)
train_loop(
trainer,
lr_policy,
@@ -658,11 +648,10 @@ def main(args, model_args, model_arch):
topk=args.topk,
data_loader_only=args.data_loader_only,
)
exp_duration = time.time() - exp_start_time
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
logger.end()
print("Experiment ended")

exp_duration = time.time() - exp_start_time

Check notice (Code scanning / CodeQL): Variable exp_duration is not used.
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
logger.end()
print("Experiment ended")

if __name__ == "__main__":
epilog = [
@@ -690,3 +679,6 @@ def main(args, model_args, model_arch):
cudnn.benchmark = True

main(args, model_args, model_arch)
print("Done")
import sys
sys.exit(0)
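
Note: the deleted conditional_with helper existed to wrap a possibly-None dali_server in nullcontext; the replacement collects however many DALIServer objects are nested inside the loaders and enters them all through one ExitStack. A minimal sketch of the pattern in isolation, assuming find_children_by_type_name and loaders as above (train_loop stands in for the real training call):

    from contextlib import ExitStack

    with ExitStack() as stack:
        # Enter every discovered server context; zero matches is fine too.
        for obj in find_children_by_type_name((train_loader, val_loader), "DALIServer"):
            stack.enter_context(obj)
        train_loop(...)  # runs while all servers are active
    # On exit, ExitStack closes the entered contexts in reverse order.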
