-
Notifications
You must be signed in to change notification settings - Fork 26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Unified error calculation #560
Conversation
24ec8b3
to
863ea6f
Compare
863ea6f
to
a9925b2
Compare
df8a019
to
049d51d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @nerkulec , thank you the PR! All around, it looks pretty great, and I think unifying the error calculations is a great idea.
I have some questions to the code (namely one maybe obscure error metric I was using occasionally and validation grpahs) that I left as comments.
Two general things:
- Since this introduces the tqdm package, shouldn't we include it in the requirements.txt? Or is it reliably shipped with another package? This may also affect the cpu_environments.yml.
- The docs are currently failing with this PR, which is due to tqdm not yet included in the autodoc_mock_imports list in the conf.py of sphinx. Sphinx is trying to import tqdm to generate an automated API (which is of course unnecessary) but failing, since it is not installed alongside the other docs packages. By adding it to the autodoc_mock_imports list, sphinx will not attempt to import it.
mala/network/runner.py
Outdated
errors[energy_type] = be_error | ||
except ValueError: | ||
errors[energy_type] = float("inf") | ||
elif energy_type == "band_energy_dft_fe": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there still a use case for calculating energies with DFT Fermi energy? I haven't done it myself in a long time, because I think it does not make that much sense conceptually. If we are overhauling the entire workflow in this aspect anyway, I'd argue for getting rid of it entirely - but if there is still good use for it we can of course keep it!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're right, I removed it
s.wait_stream( | ||
torch.cuda.current_stream( | ||
self.parameters._configuration[ | ||
"device" | ||
] | ||
) | ||
) | ||
# Warmup for graphs | ||
with torch.cuda.stream(s): | ||
for _ in range(20): | ||
with torch.cuda.amp.autocast( | ||
enabled=self.parameters.use_mixed_precision | ||
): | ||
prediction = network(x) | ||
if self.parameters_full.use_ddp: | ||
loss = network.module.calculate_loss( | ||
prediction, y | ||
) | ||
else: | ||
loss = network.calculate_loss( | ||
prediction, y | ||
) | ||
torch.cuda.current_stream( | ||
self.parameters._configuration["device"] | ||
).wait_stream(s) | ||
|
||
# Create static entry point tensors to graph | ||
self.static_input_validation = ( | ||
torch.empty_like(x) | ||
) | ||
self.static_target_validation = ( | ||
torch.empty_like(y) | ||
) | ||
|
||
# Capture graph | ||
self.validation_graph = torch.cuda.CUDAGraph() | ||
with torch.cuda.graph(self.validation_graph): | ||
with torch.cuda.amp.autocast( | ||
enabled=self.parameters.use_mixed_precision | ||
): | ||
self.static_prediction_validation = ( | ||
network( | ||
self.static_input_validation | ||
) | ||
) | ||
if self.parameters_full.use_ddp: | ||
self.static_loss_validation = network.module.calculate_loss( | ||
self.static_prediction_validation, | ||
self.static_target_validation, | ||
) | ||
else: | ||
self.static_loss_validation = network.calculate_loss( | ||
self.static_prediction_validation, | ||
self.static_target_validation, | ||
) | ||
|
||
if self.validation_graph: | ||
self.static_input_validation.copy_(x) | ||
self.static_target_validation.copy_(y) | ||
self.validation_graph.replay() | ||
validation_loss_sum += ( | ||
self.static_loss_validation | ||
) | ||
else: | ||
with torch.cuda.amp.autocast( | ||
enabled=self.parameters.use_mixed_precision | ||
): | ||
prediction = network(x) | ||
if self.parameters_full.use_ddp: | ||
loss = network.module.calculate_loss( | ||
prediction, y | ||
) | ||
else: | ||
loss = network.calculate_loss( | ||
prediction, y | ||
) | ||
validation_loss_sum += loss | ||
if ( | ||
batchid != 0 | ||
and (batchid + 1) % report_freq == 0 | ||
): | ||
torch.cuda.synchronize( | ||
self.parameters._configuration["device"] | ||
) | ||
sample_time = time.time() - tsample | ||
avg_sample_time = sample_time / report_freq | ||
avg_sample_tput = ( | ||
report_freq * x.shape[0] / sample_time | ||
) | ||
printout( | ||
f"batch {batchid + 1}, " # /{total_samples}, " | ||
f"validation avg time: {avg_sample_time} " | ||
f"validation avg throughput: {avg_sample_tput}", | ||
min_verbosity=2, | ||
) | ||
tsample = time.time() | ||
batchid += 1 | ||
torch.cuda.synchronize( | ||
self.parameters._configuration["device"] | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I may be missing something, but I cannot find this portion of the code in the new __validate_network - is there a reason why we would want to get rid of the validation graphs? I thought they were working nicely.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The computation graphs in validate_network only feed the inputs forward through the network, and later invoke calculate_loss on the model which is just the mean square error. No gradients are accumulated or weights updated. I could potentially leave the computation graphs here just for the LDOS metric, but then all other metrics' calculations would still run outside the computation graph, since it's outside of torch. I could re-use the network predictions out of the graph, but I doubt there is any significant speed improvement compared to evaluation in eager mode. I even doubt whether there is any noticeable improvement when you add the MSE on top. That's why I removed them altogether here. If you strongly feel that this hurts performance when using just the LDOS validation metric or done a benchmark that shows the performance difference I can put it back here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alright, got it. I think I didn't look into the actual functionality deep enough, just noticed that code that had been authored by Josh had been deleted and just wanted to know if it was by accident or intentional. Your explanation makes sense, I am OK with deleting this part.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for incorporating all the changes and feedback, this looks great to me now and can be merged from my side!
This PR includes logging overhaul and makes energy error calculations uniform (in units of meV/atom).