Skip to content

Commit

Permalink
Merge pull request #495 from RandomDefaultUser/fix_mpi_hyperopt
Browse files Browse the repository at this point in the history
Fixed distributed Optuna running on multiple GPUs
  • Loading branch information
RandomDefaultUser authored Dec 22, 2023
2 parents 8af52ff + 45f0749 commit 7254c5a
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions mala/network/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def train_network(self):
self.data.training_data_sets[0].shuffle()

if self.parameters._configuration["gpu"]:
torch.cuda.synchronize()
torch.cuda.synchronize(self.parameters._configuration["device"])
tsample = time.time()
t0 = time.time()
batchid = 0
Expand Down Expand Up @@ -309,7 +309,7 @@ def train_network(self):
training_loss_sum += loss

if batchid != 0 and (batchid + 1) % self.parameters.training_report_frequency == 0:
torch.cuda.synchronize()
torch.cuda.synchronize(self.parameters._configuration["device"])
sample_time = time.time() - tsample
avg_sample_time = sample_time / self.parameters.training_report_frequency
avg_sample_tput = self.parameters.training_report_frequency * inputs.shape[0] / sample_time
Expand All @@ -319,14 +319,14 @@ def train_network(self):
min_verbosity=2)
tsample = time.time()
batchid += 1
torch.cuda.synchronize()
torch.cuda.synchronize(self.parameters._configuration["device"])
t1 = time.time()
printout(f"training time: {t1 - t0}", min_verbosity=2)

training_loss = training_loss_sum.item() / batchid

# Calculate the validation loss. and output it.
torch.cuda.synchronize()
torch.cuda.synchronize(self.parameters._configuration["device"])
else:
batchid = 0
for loader in self.training_data_loaders:
Expand Down Expand Up @@ -375,14 +375,14 @@ def train_network(self):
self.tensor_board.close()

if self.parameters._configuration["gpu"]:
torch.cuda.synchronize()
torch.cuda.synchronize(self.parameters._configuration["device"])

# Mix the DataSets up (this function only does something
# in the lazy loading case).
if self.parameters.use_shuffling_for_samplers:
self.data.mix_datasets()
if self.parameters._configuration["gpu"]:
torch.cuda.synchronize()
torch.cuda.synchronize(self.parameters._configuration["device"])

# If a scheduler is used, update it.
if self.scheduler is not None:
Expand Down Expand Up @@ -636,8 +636,8 @@ def __process_mini_batch(self, network, input_data, target_data):
if self.parameters._configuration["gpu"]:
if self.parameters.use_graphs and self.train_graph is None:
printout("Capturing CUDA graph for training.", min_verbosity=2)
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
s = torch.cuda.Stream(self.parameters._configuration["device"])
s.wait_stream(torch.cuda.current_stream(self.parameters._configuration["device"]))
# Warmup for graphs
with torch.cuda.stream(s):
for _ in range(20):
Expand All @@ -651,7 +651,7 @@ def __process_mini_batch(self, network, input_data, target_data):
self.gradscaler.scale(loss).backward()
else:
loss.backward()
torch.cuda.current_stream().wait_stream(s)
torch.cuda.current_stream(self.parameters._configuration["device"]).wait_stream(s)

# Create static entry point tensors to graph
self.static_input_data = torch.empty_like(input_data)
Expand Down Expand Up @@ -742,7 +742,7 @@ def __validate_network(self, network, data_set_type, validation_type):
with torch.no_grad():
if self.parameters._configuration["gpu"]:
report_freq = self.parameters.training_report_frequency
torch.cuda.synchronize()
torch.cuda.synchronize(self.parameters._configuration["device"])
tsample = time.time()
batchid = 0
for loader in data_loaders:
Expand All @@ -754,15 +754,15 @@ def __validate_network(self, network, data_set_type, validation_type):

if self.parameters.use_graphs and self.validation_graph is None:
printout("Capturing CUDA graph for validation.", min_verbosity=2)
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
s = torch.cuda.Stream(self.parameters._configuration["device"])
s.wait_stream(torch.cuda.current_stream(self.parameters._configuration["device"]))
# Warmup for graphs
with torch.cuda.stream(s):
for _ in range(20):
with torch.cuda.amp.autocast(enabled=self.parameters.use_mixed_precision):
prediction = network(x)
loss = network.calculate_loss(prediction, y)
torch.cuda.current_stream().wait_stream(s)
torch.cuda.current_stream(self.parameters._configuration["device"]).wait_stream(s)

# Create static entry point tensors to graph
self.static_input_validation = torch.empty_like(x)
Expand All @@ -786,7 +786,7 @@ def __validate_network(self, network, data_set_type, validation_type):
loss = network.calculate_loss(prediction, y)
validation_loss_sum += loss
if batchid != 0 and (batchid + 1) % report_freq == 0:
torch.cuda.synchronize()
torch.cuda.synchronize(self.parameters._configuration["device"])
sample_time = time.time() - tsample
avg_sample_time = sample_time / report_freq
avg_sample_tput = report_freq * x.shape[0] / sample_time
Expand All @@ -796,7 +796,7 @@ def __validate_network(self, network, data_set_type, validation_type):
min_verbosity=2)
tsample = time.time()
batchid += 1
torch.cuda.synchronize()
torch.cuda.synchronize(self.parameters._configuration["device"])
else:
batchid = 0
for loader in data_loaders:
Expand Down

0 comments on commit 7254c5a

Please sign in to comment.