Skip to content

Commit

Permalink
Fix autograd bug and update test_csrmv_taichi.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Dec 2, 2023
1 parent e8234ae commit ffccc19
Show file tree
Hide file tree
Showing 2 changed files with 447 additions and 441 deletions.
31 changes: 23 additions & 8 deletions brainpy/_src/math/sparse/_csr_mv_taichi.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,21 @@ def _sparse_csr_matvec_jvp(
vector,
shape=shape,
transpose=transpose)
else:
dv = csrmv_taichi(values_dot,
col_indices,
row_ptr,
vector,
shape=shape,
transpose=transpose)
dw = csrmv_taichi(values,
col_indices,
row_ptr,
vector_dot,
shape=shape,
transpose=transpose)
dr = [dv[0] + dw[0]]

return r, dr

def _sparse_csr_matvec_transpose(
Expand Down Expand Up @@ -228,9 +243,9 @@ def csrmv_taichi(
prim = None

if transpose:
prim = _event_csr_matvec_transpose_p
prim = _csr_matvec_transpose_p
else:
prim = _event_csr_matvec_p
prim = _csr_matvec_p

return prim(data,
indices,
Expand All @@ -241,13 +256,13 @@ def csrmv_taichi(
shape=shape)

# transpose
_event_csr_matvec_transpose_p = XLACustomOp(cpu_kernel=_sparse_csr_matvec_transpose_cpu,
_csr_matvec_transpose_p = XLACustomOp(cpu_kernel=_sparse_csr_matvec_transpose_cpu,
gpu_kernel=_sparse_csr_matvec_transpose_gpu)
_event_csr_matvec_transpose_p.def_jvp_rule(_sparse_csr_matvec_jvp)
_event_csr_matvec_transpose_p.def_transpose_rule(_sparse_csr_matvec_transpose)
_csr_matvec_transpose_p.def_jvp_rule(_sparse_csr_matvec_jvp)
_csr_matvec_transpose_p.def_transpose_rule(_sparse_csr_matvec_transpose)

# no transpose
_event_csr_matvec_p = XLACustomOp(cpu_kernel=_sparse_csr_matvec_cpu,
_csr_matvec_p = XLACustomOp(cpu_kernel=_sparse_csr_matvec_cpu,
gpu_kernel=_sparse_csr_matvec_gpu)
_event_csr_matvec_p.def_jvp_rule(_sparse_csr_matvec_jvp)
_event_csr_matvec_p.def_transpose_rule(_sparse_csr_matvec_transpose)
_csr_matvec_p.def_jvp_rule(_sparse_csr_matvec_jvp)
_csr_matvec_p.def_transpose_rule(_sparse_csr_matvec_transpose)
Loading

0 comments on commit ffccc19

Please sign in to comment.