Skip to content

Commit

Permalink
[dnn] Improve error handling in dnn/linear module (#704)
Browse files Browse the repository at this point in the history
Improve error handling in `dnn/linear` module
  • Loading branch information
Routhleck authored Dec 5, 2024
1 parent a354bf2 commit 2bb2f9e
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 33 deletions.
16 changes: 8 additions & 8 deletions brainpy/_src/dnn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from brainpy import math as bm
from brainpy._src import connect, initialize as init
from brainpy._src.context import share
from brainpy._src.dependency_check import import_taichi, import_braintaichi
from brainpy._src.dependency_check import import_taichi, import_braintaichi, raise_braintaichi_not_found
from brainpy._src.dnn.base import Layer
from brainpy._src.mixin import SupportOnline, SupportOffline, SupportSTDP
from brainpy.check import is_initializer
Expand Down Expand Up @@ -241,7 +241,7 @@ def update(self, x):
return x


if ti is not None:
if ti is not None and bti is not None:

# @numba.njit(nogil=True, fastmath=True, parallel=False)
# def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w):
Expand Down Expand Up @@ -321,7 +321,7 @@ def _dense_on_pre(

def dense_on_pre(weight, spike, trace, w_min, w_max):
if dense_on_pre_prim is None:
raise PackageMissingError.by_purpose('taichi', 'custom operators')
raise_braintaichi_not_found()

if w_min is None:
w_min = -np.inf
Expand All @@ -341,7 +341,7 @@ def dense_on_pre(weight, spike, trace, w_min, w_max):

def dense_on_post(weight, spike, trace, w_min, w_max):
if dense_on_post_prim is None:
raise PackageMissingError.by_purpose('taichi', 'custom operators')
raise_braintaichi_not_found()

if w_min is None:
w_min = -np.inf
Expand Down Expand Up @@ -728,7 +728,7 @@ def _batch_csrmv(self, x):
transpose=self.transpose)


if ti is not None:
if ti is not None and bti is not None:
@ti.kernel
def _csr_on_pre_update(
old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
Expand Down Expand Up @@ -852,7 +852,7 @@ def _csc_on_post_update(

def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None):
if csr_on_pre_update_prim is None:
raise PackageMissingError.by_purpose('taichi', 'customized operators')
raise_braintaichi_not_found()

if w_min is None:
w_min = -np.inf
Expand All @@ -874,7 +874,7 @@ def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None):

def coo_on_pre_update(w, pre_ids, post_ids, spike, trace, w_min=None, w_max=None):
if coo_on_pre_update_prim is None:
raise PackageMissingError.by_purpose('taichi', 'customized operators')
raise_braintaichi_not_found()

if w_min is None:
w_min = -np.inf
Expand All @@ -897,7 +897,7 @@ def coo_on_pre_update(w, pre_ids, post_ids, spike, trace, w_min=None, w_max=None

def csc_on_post_update(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min=None, w_max=None):
if csc_on_post_update_prim is None:
raise PackageMissingError.by_purpose('taichi', 'customized operators')
raise_braintaichi_not_found()

if w_min is None:
w_min = -np.inf
Expand Down
2 changes: 1 addition & 1 deletion docker/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ jax
jaxlib
scipy>=1.1.0
brainpy
brainpylib
brainpy_datasets
h5py
pathos
braintaichi

# test requirements
pytest
Expand Down
24 changes: 0 additions & 24 deletions docs/quickstart/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ To install brainpy with minimum requirements (only depends on ``jax``), you can
# or
pip install brainpy[cuda11_mini] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # for CUDA 11.0
pip install brainpy[cuda12_mini] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # for CUDA 12.0
# or
Expand Down Expand Up @@ -64,7 +63,6 @@ To install a GPU-only version of BrainPy, you can run
.. code-block:: bash
pip install brainpy[cuda12] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # for CUDA 12.0
pip install brainpy[cuda11] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # for CUDA 11.0
Expand All @@ -79,25 +77,3 @@ you can run the following in your cloud TPU VM:
pip install brainpy[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html # for google TPU
``brainpylib``
--------------


``brainpylib`` defines a set of useful operators for building and simulating spiking neural networks.


To install the ``brainpylib`` package on CPU devices, you can run

.. code-block:: bash
pip install brainpylib
To install the ``brainpylib`` package on CUDA (Linux only), you can run


.. code-block:: bash
pip install brainpylib

0 comments on commit 2bb2f9e

Please sign in to comment.