Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speedup calc_scores() by caching outputs of calc_error(). #17

Open
wants to merge 1 commit into
base: old
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion src/scenicplus/BASCA.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,20 @@ def calc_error(vect, P, i, j):
return np.sum(((vect - z) ** 2)[0: N])


@numba.jit(nopython=True)
def calc_errors(calc_errors_matrix, calc_errors_is_cached, vect, P, i, j):
# Check if we have the cost_ab(vect, a, b) value cached.
if calc_errors_is_cached[i, j] == np.bool_(True):
return calc_errors_matrix[i, j]

# Else calculate cost_ab(vect, a, b) and cache it for next time.
current_calc_error = calc_error(vect, P, i, j)
calc_errors_matrix[i, j] = current_calc_error
calc_errors_is_cached[i, j] = np.bool_(True)

return current_calc_error


@numba.jit(nopython=True)
def moving_block_bootstrap(v):
N = v.shape[0]
Expand Down Expand Up @@ -177,13 +191,16 @@ def calc_scores(vect, P):
# stores the index of the discontinuity with the maximum score for each step function
ind_Q_max = np.zeros(N, dtype=np.int64)

calc_errors_matrix = np.zeros((N, N - 1), np.float64)
calc_errors_is_cached = np.zeros((N, N - 1), np.bool_)

for j in range(0, N):
q_max = -1
ind_q_max = -1
for i in range(0, j + 1):
# calculate jump height
h = calc_jump_height(vect, P, i, j)
e = calc_error(vect, P, i, j)
e = calc_errors(calc_errors_matrix, calc_errors_is_cached, vect, P, i, j)
q = h / e
if q > q_max:
q_max = q
Expand Down