Hybrid grad check #27
base: main
Changes from 11 commits
The first hunk updates the module imports (`product` is dropped in favor of `chain`, and `math` is added):

```diff
@@ -1,11 +1,12 @@
 import abc
 from typing import Any, Callable, Dict, List, Union
-from itertools import product
+from itertools import chain

 from dataclasses import dataclass

 import numpy as np
 import pandas as pd
+import math

 from .constants import (
     # TYPE_DIMENSION,
```
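The new `chain` import is used at the end of the added method to flatten the per-direction result lists before aggregating them (`all(chain(*results_all))`). A minimal sketch of that flattening pattern, with made-up result values:

```python
from itertools import chain

# Hypothetical per-direction verdict lists, mirroring results_all in the diff
results_all = [[True, True], [True, None]]

# chain(*lists) flattens one nesting level; all() then treats None as falsy
flat = list(chain(*results_all))
success = all(flat)

print(flat)     # [True, True, True, None]
print(success)  # False: an undecided (None) verdict fails the overall check
```

Note that a `None` ("can't judge") verdict is falsy, so it counts as a failure in the aggregate.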
The second hunk appends to the end of the file:

```diff
@@ -161,3 +162,109 @@ def method(self, *args, **kwargs):
             success=success,
         )
         return derivative_check_result
```

The rest of the hunk is entirely new code, the `HybridDerivativeCheck` class, shown here as plain Python:

```python
class HybridDerivativeCheck(DerivativeCheck):
    method_id = "hybrid"

    def method(self, *args, **kwargs):
        expected_values = []
        test_values = []
        success = True
        for direction_index, directional_derivative in enumerate(
            self.derivative.directional_derivatives
        ):
            test_value = directional_derivative.value
            test_values.append(test_value)

            expected_value = []
            for output_index in np.ndindex(self.output_indices):
                element = self.expectation[output_index][direction_index]
                expected_value.append(element)
            expected_value = np.array(expected_value).reshape(test_value.shape)
            expected_values.append(expected_value)
```

Review thread on this loop:
Reviewer: "Since this is the same in […]"
Reviewer: "Refactored: ExtractMethod."
Author: "Thanks, I implemented the refactor in […]"
```python
        # debug
        assert len(expected_values) == len(
            test_values
        ), "Mismatch of step sizes"
```

Review thread on the assertion:
Reviewer: "Remove, or raise some specific exception? This is always due to a user error, right?"
Author: "Should always be a user error, removed."
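The reviewer's alternative, raising a specific exception instead of a bare `assert`, could look like the following sketch; the `ValueError` type and the message wording are my assumptions, not the project's actual choice:

```python
def check_lengths(expected_values, test_values):
    # A descriptive exception instead of an assert; asserts are silently
    # stripped when Python runs with the -O flag, so user errors should
    # not rely on them
    if len(expected_values) != len(test_values):
        raise ValueError(
            f"Mismatch of step sizes: {len(expected_values)} expected values "
            f"vs {len(test_values)} test values"
        )

check_lengths([1.0, 2.0], [0.9, 2.1])  # equal lengths: no error raised
```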
```python
        results_all = []
        directional_derivative_check_results = []
        for step_size in range(0, len(expected_values)):
            approxs_for_param = []
            grads_for_param = []
```

Review thread on the two accumulator lists (diff lines +189 to +190):
Reviewer: "Should this be reset in the inner loop, instead of here in the outer loop?"
Author: "Should be okay as-is."
```python
            results = []
            for diff_index, directional_derivative in enumerate(
                self.derivative.directional_derivatives
            ):
                try:
                    for grad, approx in zip(
                        expected_values[diff_index - 1][step_size - 1],
                        test_values[diff_index - 1][step_size - 1],
                    ):
                        approxs_for_param.append(approx)
                        grads_for_param.append(grad)
                    fd_range = np.percentile(approxs_for_param, [0, 100])
```

Review thread on the percentile call:
Reviewer: "Is this just the min and max of […]?"
Author: "Yes."
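As the thread above confirms, the 0th and 100th percentiles are simply the minimum and maximum of the data, so an equivalent (and arguably clearer) formulation would use `min`/`max` directly. A quick check with made-up values:

```python
import numpy as np

approxs_for_param = [0.5, 2.0, 1.25, 0.75]

fd_range = np.percentile(approxs_for_param, [0, 100])

# The 0th/100th percentiles coincide with min/max of the data
assert fd_range[0] == min(approxs_for_param)  # 0.5
assert fd_range[1] == max(approxs_for_param)  # 2.0
```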
```python
                    fd_mean = np.mean(approxs_for_param)
                    grad_mean = np.mean(grads_for_param)
                    if not (fd_range[0] <= grad_mean <= fd_range[1]):
                        if np.any(
                            [
                                abs(x - y) > kwargs["atol"]
                                for i, x in enumerate(approxs_for_param)
                                for j, y in enumerate(approxs_for_param)
                                if i != j
                            ]
                        ):
                            fd_range = abs(fd_range[1] - fd_range[0])
                            if (
                                abs(grad_mean - fd_mean)
                                / abs(fd_range + np.finfo(float).eps)
                            ) > kwargs["rtol"]:
                                results.append(False)
                            else:
                                results.append(False)
```

Review thread on this branch:
Reviewer: "The handling of […]"
Author: "Done."
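The fallback criterion above measures how far the mean analytical gradient sits from the mean finite-difference approximation, scaled by the spread of the FD approximations. A toy illustration of that test (all values and tolerances are made up):

```python
import numpy as np

grad_mean = 1.30   # mean of the analytical directional derivatives
fd_mean = 1.00     # mean of the finite-difference approximations
fd_spread = 0.50   # max - min of the FD approximations
rtol = 0.2

# Deviation of the gradient from the FD mean, relative to the FD spread;
# eps guards against division by zero when all approximations coincide
rel_err = abs(grad_mean - fd_mean) / abs(fd_spread + np.finfo(float).eps)

assert rel_err > rtol  # roughly 0.3 / 0.5 = 0.6 > 0.2: flagged as inconsistent
```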
```python
                        else:
                            results.append(
                                None
                            )  # can't judge consistency / questionable grad approxs
                    else:
                        fd_range = abs(fd_range[1] - fd_range[0])
                        if (
                            math.isinf(fd_range)
                            or math.isnan(fd_range)
                            or math.isinf(fd_mean)
                            or math.isnan(fd_mean)
                        ):
```
```python
                            results.append(None)
                        else:
                            results.append(True)
                except (IndexError, TypeError):
                    # TODO: Fix this, why does this occur?
                    pass
```

Review thread on the exception handler:
Reviewer: "Do you have a reproducible example of this? I could take a look."
Author: "Should be fixed now."
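One plausible contributor to the confusion here (my assumption, not something the review establishes) is the `[diff_index - 1]` / `[step_size - 1]` indexing above: when a loop counter is 0, the resulting index -1 does not raise `IndexError` but silently wraps to the last element, so failures surface elsewhere and later:

```python
values = ["first", "second", "third"]

diff_index = 0
# index -1 does not raise; Python wraps it around to the last element
assert values[diff_index - 1] == "third"

# only an index >= len(values) (or < -len(values)) actually raises
try:
    values[len(values)]
except IndexError:
    print("IndexError raised for out-of-range index")
```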
```python
            directional_derivative_check_result = (
                DirectionalDerivativeCheckResult(
                    direction_id=directional_derivative.id,
                    method_id=self.method_id,
                    test=test_value,
                    expectation=expected_value,
                    output={"return": results},
                    success=all(results),
                )
            )
            directional_derivative_check_results.append(
                directional_derivative_check_result
            )
            results_all.append(results)

        success = all(chain(*results_all))
        derivative_check_result = DerivativeCheckResult(
            method_id=self.method_id,
            directional_derivative_check_results=directional_derivative_check_results,
            test=self.derivative.value,
            expectation=self.expectation,
            success=success,
        )
        return derivative_check_result
```
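Read end to end, the added check accepts a direction when the mean analytical gradient lies inside the envelope of the FD approximations, and otherwise falls back to the spread-scaled comparison. A self-contained sketch of that acceptance rule; the function name and default tolerances are hypothetical, and the sketch treats a small relative error as a pass, which the diff as posted does not (both tolerance branches append `False`; the review thread suggests this was revised):

```python
import numpy as np

def hybrid_verdict(grads, approxs, atol=1e-6, rtol=1e-2):
    """Mimic the hybrid check's per-direction verdict: True/False/None."""
    fd_min, fd_max = min(approxs), max(approxs)
    grad_mean = float(np.mean(grads))
    fd_mean = float(np.mean(approxs))

    if fd_min <= grad_mean <= fd_max:
        return True  # gradient consistent with the FD envelope
    # Do the FD approximations disagree with each other beyond atol?
    if any(abs(x - y) > atol for x in approxs for y in approxs):
        spread = abs(fd_max - fd_min)
        rel_err = abs(grad_mean - fd_mean) / (spread + np.finfo(float).eps)
        return bool(rel_err <= rtol)
    return None  # FD values agree yet exclude the gradient: can't judge

print(hybrid_verdict([1.0, 1.1], [0.9, 1.2]))  # True: mean 1.05 in [0.9, 1.2]
```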
Review thread on the new class:
Reviewer: "Could you doc the logic of this class in a class docstring here? i.e. the formula/algorithm used to determine whether gradients are correct. Will be great for when I write the remaining docs."
Author: "Added. LMK if more is needed."
Reviewer: "Ah, I meant the actual formula, preferably with LaTeX. There's an example of LaTeX in RTD for `amici.import_utils.noise_distribution_to_cost_function` here [1], with the docstring here [2]. But this can also wait until you publish the formula. If you decide to wait, could you add a TODO to do this later? Would be great to add a reference to your paper when it's published, too.
[1] https://amici.readthedocs.io/en/latest/generated/amici.import_utils.html#amici.import_utils.noise_distribution_to_cost_function
[2] https://github.com/AMICI-dev/AMICI/blob/3c5e997df3655c26dde35705ef25b2a0f419fe8b/python/sdist/amici/import_utils.py#L105-L107"
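For the requested docstring formula, one way to state the implemented criterion in LaTeX (a sketch reconstructed from reading the diff, not the authors' published formula; $g_j$ are analytical directional derivatives, $a_i$ finite-difference approximations):

```latex
% Per direction: accept when the mean gradient lies in the FD envelope
\text{consistent} \iff \min_i a_i \le \bar{g} \le \max_i a_i,
\qquad \bar{g} = \frac{1}{n} \sum_j g_j

% Fallback when \bar{g} falls outside the envelope and the a_i disagree
% beyond atol: compare the deviation against the spread of the a_i
\frac{\lvert \bar{g} - \bar{a} \rvert}{\max_i a_i - \min_i a_i + \varepsilon}
  \le \mathrm{rtol}
```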