
Commit

Merge pull request #617 from RandomDefaultUser/fix_ddp_validation
Recovering DDP scalability
RandomDefaultUser authored Nov 29, 2024
2 parents dddccd4 + 20a06f2 commit 03f6b96
Showing 1 changed file with 193 additions and 33 deletions.
226 changes: 193 additions & 33 deletions mala/network/trainer.py
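
The diff below repeatedly branches on self.parameters_full.use_ddp, because a DistributedDataParallel-wrapped model only exposes custom methods such as calculate_loss on its .module attribute. As a rough, standalone illustration of that dispatch pattern (not MALA code; net and compute_loss are hypothetical names):

from torch.nn.parallel import DistributedDataParallel as DDP

def compute_loss(net, prediction, target):
    # forward() passes through the DDP wrapper transparently, but custom
    # methods defined on the underlying model (here: calculate_loss) are
    # only reachable via the wrapped module.
    base = net.module if isinstance(net, DDP) else net
    return base.calculate_loss(prediction, target)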
@@ -675,47 +675,207 @@ def _validate_network(self, data_set_fractions, metrics):
                    )
                    loader_id += 1
            else:
                # If only the LDOS is in the validation metrics (as is the
                # case for, e.g., distributed network trainings), we can
                # use a faster (or at least better parallelizing) code path.
                if (
                    len(self.parameters.validation_metrics) == 1
                    and self.parameters.validation_metrics[0] == "ldos"
                ):
                    errors[data_set_type]["ldos"] = (
                        self.__calculate_validation_error_ldos_only(
                            data_loaders
                        )
                    )
                else:
                    with torch.no_grad():
                        for snapshot_number in trange(
                            offset_snapshots,
                            number_of_snapshots + offset_snapshots,
                            desc="Validation",
                            disable=self.parameters_full.verbosity < 2,
                        ):
                            # Get optimal batch size and number of batches
                            # per snapshot.
                            grid_size = (
                                self.data.parameters.snapshot_directories_list[
                                    snapshot_number
                                ].grid_size
                            )
                            optimal_batch_size = self._correct_batch_size(
                                grid_size, self.parameters.mini_batch_size
                            )
                            number_of_batches_per_snapshot = int(
                                grid_size / optimal_batch_size
                            )

                            actual_outputs, predicted_outputs = (
                                self._forward_entire_snapshot(
                                    snapshot_number,
                                    data_sets[0],
                                    data_set_type[0:2],
                                    number_of_batches_per_snapshot,
                                    optimal_batch_size,
                                )
                            )
                            calculated_errors = self._calculate_errors(
                                actual_outputs,
                                predicted_outputs,
                                metrics,
                                snapshot_number,
                            )
                            for metric in metrics:
                                errors[data_set_type][metric].append(
                                    calculated_errors[metric]
                                )
        return errors

    def __calculate_validation_error_ldos_only(self, data_loaders):
        """Compute the mean LDOS validation loss over all given data loaders."""
        validation_loss_sum = torch.zeros(
            1, device=self.parameters._configuration["device"]
        )
        with torch.no_grad():
            if self.parameters._configuration["gpu"]:
                report_freq = self.parameters.training_log_interval
                torch.cuda.synchronize(
                    self.parameters._configuration["device"]
                )
                tsample = time.time()
                batchid = 0
                for loader in data_loaders:
                    for x, y in loader:
                        x = x.to(
                            self.parameters._configuration["device"],
                            non_blocking=True,
                        )
                        y = y.to(
                            self.parameters._configuration["device"],
                            non_blocking=True,
                        )

                        if (
                            self.parameters.use_graphs
                            and self._validation_graph is None
                        ):
                            printout(
                                "Capturing CUDA graph for validation.",
                                min_verbosity=2,
                            )
                            s = torch.cuda.Stream(
                                self.parameters._configuration["device"]
                            )
                            s.wait_stream(
                                torch.cuda.current_stream(
                                    self.parameters._configuration["device"]
                                )
                            )
                            # Warmup for graphs
                            with torch.cuda.stream(s):
                                for _ in range(20):
                                    with torch.cuda.amp.autocast(
                                        enabled=self.parameters.use_mixed_precision
                                    ):
                                        prediction = self.network(x)
                                        if self.parameters_full.use_ddp:
                                            loss = self.network.module.calculate_loss(
                                                prediction, y
                                            )
                                        else:
                                            loss = self.network.calculate_loss(
                                                prediction, y
                                            )
                            torch.cuda.current_stream(
                                self.parameters._configuration["device"]
                            ).wait_stream(s)

                            # Create static entry point tensors to graph
                            self.static_input_validation = torch.empty_like(x)
                            self.static_target_validation = torch.empty_like(
                                y
                            )

                            # Capture graph
                            self._validation_graph = torch.cuda.CUDAGraph()
                            with torch.cuda.graph(self._validation_graph):
                                with torch.cuda.amp.autocast(
                                    enabled=self.parameters.use_mixed_precision
                                ):
                                    self.static_prediction_validation = (
                                        self.network(
                                            self.static_input_validation
                                        )
                                    )
                                    if self.parameters_full.use_ddp:
                                        self.static_loss_validation = self.network.module.calculate_loss(
                                            self.static_prediction_validation,
                                            self.static_target_validation,
                                        )
                                    else:
                                        self.static_loss_validation = self.network.calculate_loss(
                                            self.static_prediction_validation,
                                            self.static_target_validation,
                                        )

                        if self._validation_graph:
                            self.static_input_validation.copy_(x)
                            self.static_target_validation.copy_(y)
                            self._validation_graph.replay()
                            validation_loss_sum += self.static_loss_validation
                        else:
                            with torch.cuda.amp.autocast(
                                enabled=self.parameters.use_mixed_precision
                            ):
                                prediction = self.network(x)
                                if self.parameters_full.use_ddp:
                                    loss = self.network.module.calculate_loss(
                                        prediction, y
                                    )
                                else:
                                    loss = self.network.calculate_loss(
                                        prediction, y
                                    )
                                validation_loss_sum += loss

                        if batchid != 0 and (batchid + 1) % report_freq == 0:
                            torch.cuda.synchronize(
                                self.parameters._configuration["device"]
                            )
                            sample_time = time.time() - tsample
                            avg_sample_time = sample_time / report_freq
                            avg_sample_tput = (
                                report_freq * x.shape[0] / sample_time
                            )
                            printout(
                                f"batch {batchid + 1}, "  # /{total_samples}, "
                                f"validation avg time: {avg_sample_time} "
                                f"validation avg throughput: {avg_sample_tput}",
                                min_verbosity=2,
                            )
                            tsample = time.time()
                        batchid += 1
                torch.cuda.synchronize(
                    self.parameters._configuration["device"]
                )
            else:
                batchid = 0
                for loader in data_loaders:
                    for x, y in loader:
                        x = x.to(self.parameters._configuration["device"])
                        y = y.to(self.parameters._configuration["device"])
                        prediction = self.network(x)
                        if self.parameters_full.use_ddp:
                            validation_loss_sum += (
                                self.network.module.calculate_loss(
                                    prediction, y
                                ).item()
                            )
                        else:
                            validation_loss_sum += (
                                self.network.calculate_loss(
                                    prediction, y
                                ).item()
                            )
                        batchid += 1

        return validation_loss_sum.item() / batchid

    def __prepare_to_train(self, optimizer_dict):
        """Prepare everything for training."""
