Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unified error calculation #560

Merged
merged 11 commits into from
Oct 7, 2024
Merged

Conversation

nerkulec
Copy link

This PR includes logging overhaul and makes energy error calculations uniform (in units of meV/atom).

@nerkulec nerkulec marked this pull request as draft July 22, 2024 08:36
@nerkulec nerkulec force-pushed the uniform_error_calculation branch from 24ec8b3 to 863ea6f Compare July 22, 2024 09:19
@nerkulec nerkulec marked this pull request as ready for review July 22, 2024 09:20
@nerkulec nerkulec marked this pull request as draft July 22, 2024 09:52
@nerkulec nerkulec force-pushed the uniform_error_calculation branch from 863ea6f to a9925b2 Compare July 22, 2024 10:41
@nerkulec nerkulec force-pushed the uniform_error_calculation branch from df8a019 to 049d51d Compare July 22, 2024 11:14
@nerkulec nerkulec marked this pull request as ready for review July 24, 2024 10:48
@RandomDefaultUser RandomDefaultUser marked this pull request as draft July 26, 2024 08:37
@RandomDefaultUser RandomDefaultUser marked this pull request as ready for review July 26, 2024 08:37
Copy link
Member

@RandomDefaultUser RandomDefaultUser left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @nerkulec , thank you the PR! All around, it looks pretty great, and I think unifying the error calculations is a great idea.
I have some questions to the code (namely one maybe obscure error metric I was using occasionally and validation grpahs) that I left as comments.

Two general things:

  1. Since this introduces the tqdm package, shouldn't we include it in the requirements.txt? Or is it reliably shipped with another package? This may also affect the cpu_environments.yml.
  2. The docs are currently failing with this PR, which is due to tqdm not yet included in the autodoc_mock_imports list in the conf.py of sphinx. Sphinx is trying to import tqdm to generate an automated API (which is of course unnecessary) but failing, since it is not installed alongside the other docs packages. By adding it to the autodoc_mock_imports list, sphinx will not attempt to import it.

mala/network/runner.py Outdated Show resolved Hide resolved
errors[energy_type] = be_error
except ValueError:
errors[energy_type] = float("inf")
elif energy_type == "band_energy_dft_fe":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there still a use case for calculating energies with DFT Fermi energy? I haven't done it myself in a long time, because I think it does not make that much sense conceptually. If we are overhauling the entire workflow in this aspect anyway, I'd argue for getting rid of it entirely - but if there is still good use for it we can of course keep it!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, I removed it

mala/network/runner.py Show resolved Hide resolved
Comment on lines -960 to -1070
s.wait_stream(
torch.cuda.current_stream(
self.parameters._configuration[
"device"
]
)
)
# Warmup for graphs
with torch.cuda.stream(s):
for _ in range(20):
with torch.cuda.amp.autocast(
enabled=self.parameters.use_mixed_precision
):
prediction = network(x)
if self.parameters_full.use_ddp:
loss = network.module.calculate_loss(
prediction, y
)
else:
loss = network.calculate_loss(
prediction, y
)
torch.cuda.current_stream(
self.parameters._configuration["device"]
).wait_stream(s)

# Create static entry point tensors to graph
self.static_input_validation = (
torch.empty_like(x)
)
self.static_target_validation = (
torch.empty_like(y)
)

# Capture graph
self.validation_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.validation_graph):
with torch.cuda.amp.autocast(
enabled=self.parameters.use_mixed_precision
):
self.static_prediction_validation = (
network(
self.static_input_validation
)
)
if self.parameters_full.use_ddp:
self.static_loss_validation = network.module.calculate_loss(
self.static_prediction_validation,
self.static_target_validation,
)
else:
self.static_loss_validation = network.calculate_loss(
self.static_prediction_validation,
self.static_target_validation,
)

if self.validation_graph:
self.static_input_validation.copy_(x)
self.static_target_validation.copy_(y)
self.validation_graph.replay()
validation_loss_sum += (
self.static_loss_validation
)
else:
with torch.cuda.amp.autocast(
enabled=self.parameters.use_mixed_precision
):
prediction = network(x)
if self.parameters_full.use_ddp:
loss = network.module.calculate_loss(
prediction, y
)
else:
loss = network.calculate_loss(
prediction, y
)
validation_loss_sum += loss
if (
batchid != 0
and (batchid + 1) % report_freq == 0
):
torch.cuda.synchronize(
self.parameters._configuration["device"]
)
sample_time = time.time() - tsample
avg_sample_time = sample_time / report_freq
avg_sample_tput = (
report_freq * x.shape[0] / sample_time
)
printout(
f"batch {batchid + 1}, " # /{total_samples}, "
f"validation avg time: {avg_sample_time} "
f"validation avg throughput: {avg_sample_tput}",
min_verbosity=2,
)
tsample = time.time()
batchid += 1
torch.cuda.synchronize(
self.parameters._configuration["device"]
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I may be missing something, but I cannot find this portion of the code in the new __validate_network - is there a reason why we would want to get rid of the validation graphs? I thought they were working nicely.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The computation graphs in validate_network only feed the inputs forward through the network, and later invoke calculate_loss on the model which is just the mean square error. No gradients are accumulated or weights updated. I could potentially leave the computation graphs here just for the LDOS metric, but then all other metrics' calculations would still run outside the computation graph, since it's outside of torch. I could re-use the network predictions out of the graph, but I doubt there is any significant speed improvement compared to evaluation in eager mode. I even doubt whether there is any noticeable improvement when you add the MSE on top. That's why I removed them altogether here. If you strongly feel that this hurts performance when using just the LDOS validation metric or done a benchmark that shows the performance difference I can put it back here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright, got it. I think I didn't look into the actual functionality deep enough, just noticed that code that had been authored by Josh had been deleted and just wanted to know if it was by accident or intentional. Your explanation makes sense, I am OK with deleting this part.

@RandomDefaultUser RandomDefaultUser self-requested a review October 7, 2024 11:51
Copy link
Member

@RandomDefaultUser RandomDefaultUser left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for incorporating all the changes and feedback, this looks great to me now and can be merged from my side!

@RandomDefaultUser RandomDefaultUser merged commit d56e2d1 into develop Oct 7, 2024
6 checks passed
@RandomDefaultUser RandomDefaultUser deleted the uniform_error_calculation branch October 7, 2024 12:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants