Skip to content

Commit

Permalink
Merge pull request #542 from chaoming0625/master
Browse files Browse the repository at this point in the history
[dyn] add `save_state`, `load_state`, `reset_state`, and `clear_input` helpers
  • Loading branch information
chaoming0625 authored Nov 10, 2023
2 parents 8e201f6 + 178a7cc commit 40c2632
Show file tree
Hide file tree
Showing 9 changed files with 119 additions and 48 deletions.
2 changes: 2 additions & 0 deletions brainpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@
# shared parameters
from brainpy._src.context import (share as share)
from brainpy._src.helpers import (reset_state as reset_state,
save_state as save_state,
load_state as load_state,
clear_input as clear_input)


Expand Down
11 changes: 2 additions & 9 deletions brainpy/_src/dependency_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,21 @@
_minimal_taichi_version = (1, 7, 0)

taichi = None
has_import_ti = False
brainpylib_cpu_ops = None
brainpylib_gpu_ops = None


def import_taichi():
global taichi, has_import_ti
if not has_import_ti:
global taichi
if taichi is None:
try:
import taichi as taichi # noqa
has_import_ti = True
except ModuleNotFoundError:
raise ModuleNotFoundError(
'Taichi is needed. Please install taichi through:\n\n'
'> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly'
)

if taichi is None:
raise ModuleNotFoundError(
'Taichi is needed. Please install taichi through:\n\n'
'> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly'
)
if taichi.__version__ < _minimal_taichi_version:
raise RuntimeError(
f'We need taichi>={_minimal_taichi_version}. '
Expand Down
49 changes: 48 additions & 1 deletion brainpy/_src/helpers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from .dynsys import DynamicalSystem, DynView
from typing import Dict

from brainpy._src.dyn.base import IonChaDyn
from brainpy._src.dynsys import DynamicalSystem, DynView
from brainpy._src.math.object_transform.base import StateLoadResult


__all__ = [
'reset_state',
'load_state',
'save_state',
'clear_input',
]

Expand Down Expand Up @@ -30,3 +36,44 @@ def clear_input(target: DynamicalSystem, *args, **kwargs):
"""
for node in target.nodes().subset(DynamicalSystem).not_subset(DynView).unique().values():
node.clear_input(*args, **kwargs)


def load_state(target: DynamicalSystem, state_dict: Dict, **kwargs):
"""Copy parameters and buffers from :attr:`state_dict` into
this module and its descendants.
Args:
target: DynamicalSystem. The dynamical system to load its states.
state_dict: dict. A dict containing parameters and persistent buffers.
Returns:
-------
``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
* **missing_keys** is a list of str containing the missing keys
* **unexpected_keys** is a list of str containing the unexpected keys
"""
nodes = target.nodes().subset(DynamicalSystem).not_subset(DynView).unique()
missing_keys = []
unexpected_keys = []
for name, node in nodes.items():
r = node.load_state(state_dict[name], **kwargs)
if r is not None:
missing, unexpected = r
missing_keys.extend([f'{name}.{key}' for key in missing])
unexpected_keys.extend([f'{name}.{key}' for key in unexpected])
return StateLoadResult(missing_keys, unexpected_keys)


def save_state(target: DynamicalSystem, **kwargs) -> Dict:
"""Save all states in the ``target`` as a dictionary for later disk serialization.
Args:
target: DynamicalSystem. The node to save its states.
Returns:
Dict. The state dict for serialization.
"""
nodes = target.nodes().subset(DynamicalSystem).not_subset(DynView).unique() # retrieve all nodes
return {key: node.save_state(**kwargs) for key, node in nodes.items()}

12 changes: 10 additions & 2 deletions brainpy/_src/math/object_transform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,14 @@ def unique_name(self, name=None, type_=None):
check_name_uniqueness(name=name, obj=self)
return name

def save_state(self, **kwargs) -> Dict:
"""Save states as a dictionary. """
return self.__save_state__(**kwargs)

def load_state(self, state_dict: Dict, **kwargs) -> Optional[Tuple[Sequence[str], Sequence[str]]]:
"""Load states from a dictionary."""
return self.__load_state__(state_dict, **kwargs)

def __save_state__(self, **kwargs) -> Dict:
"""Save states. """
return self.vars(include_self=True, level=0).unique().dict()
Expand All @@ -502,7 +510,7 @@ def state_dict(self, **kwargs) -> dict:
A dictionary containing a whole state of the module.
"""
nodes = self.nodes() # retrieve all nodes
return {key: node.__save_state__(**kwargs) for key, node in nodes.items()}
return {key: node.save_state(**kwargs) for key, node in nodes.items()}

def load_state_dict(
self,
Expand Down Expand Up @@ -544,7 +552,7 @@ def load_state_dict(
missing_keys = []
unexpected_keys = []
for name, node in nodes.items():
r = node.__load_state__(state_dict[name], **kwargs)
r = node.load_state(state_dict[name], **kwargs)
if r is not None:
missing, unexpected = r
missing_keys.extend([f'{name}.{key}' for key in missing])
Expand Down
65 changes: 36 additions & 29 deletions brainpy/_src/math/op_register/tests/test_taichi_based.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import jax
import jax.numpy as jnp
import taichi as ti
import taichi as taichi
import pytest
import platform

import brainpy.math as bm

bm.set_platform('cpu')

if not platform.platform().startswith('Windows'):
pytest.skip(allow_module_level=True)


# @ti.kernel
# def event_ell_cpu(indices: ti.types.ndarray(ndim=2),
# vector: ti.types.ndarray(ndim=1),
Expand All @@ -19,43 +25,44 @@
# for j in range(num_cols):
# out[indices[i, j]] += weight_0

@ti.func
def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32:
return weight[0]
@taichi.func
def get_weight(weight: taichi.types.ndarray(ndim=1)) -> taichi.f32:
return weight[0]

@ti.func
def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32):
out[index] += weight_val

@ti.kernel
def event_ell_cpu(indices: ti.types.ndarray(ndim=2),
vector: ti.types.ndarray(ndim=1),
weight: ti.types.ndarray(ndim=1),
out: ti.types.ndarray(ndim=1)):
weight_val = get_weight(weight)
num_rows, num_cols = indices.shape
ti.loop_config(serialize=True)
for i in range(num_rows):
if vector[i]:
for j in range(num_cols):
update_output(out, indices[i, j], weight_val)
@taichi.func
def update_output(out: taichi.types.ndarray(ndim=1), index: taichi.i32, weight_val: taichi.f32):
out[index] += weight_val


prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu)
@taichi.kernel
def event_ell_cpu(indices: taichi.types.ndarray(ndim=2),
vector: taichi.types.ndarray(ndim=1),
weight: taichi.types.ndarray(ndim=1),
out: taichi.types.ndarray(ndim=1)):
weight_val = get_weight(weight)
num_rows, num_cols = indices.shape
taichi.loop_config(serialize=True)
for i in range(num_rows):
if vector[i]:
for j in range(num_cols):
update_output(out, indices[i, j], weight_val)


# def test_taichi_op_register():
# s = 1000
# indices = bm.random.randint(0, s, (s, 1000))
# vector = bm.random.rand(s) < 0.1
# weight = bm.array([1.0])
prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu)


# out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])
def test_taichi_op_register():
s = 1000
indices = bm.random.randint(0, s, (s, 1000))
vector = bm.random.rand(s) < 0.1
weight = bm.array([1.0])

# out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])
out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])

# print(out)
# bm.clear_buffer_memory()
out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])

print(out)
bm.clear_buffer_memory()

# test_taichi_op_register()
21 changes: 16 additions & 5 deletions brainpy/_src/math/scales.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-


from typing import Sequence, Union

__all__ = [
'Scaling',
'IdScaling',
Expand All @@ -13,11 +15,20 @@ def __init__(self, scale, bias):
self.bias = bias

@classmethod
def transform(cls, V_range:list, scaled_V_range:list):
'''
V_range: [V_min, V_max]
scaled_V_range: [scaled_V_min, scaled_V_max]
'''
def transform(
cls,
V_range: Sequence[Union[float, int]],
scaled_V_range: Sequence[Union[float, int]] = (0., 1.)
) -> 'Scaling':
"""Transform the membrane potential range to a ``Scaling`` instance.
Args:
V_range: [V_min, V_max]
scaled_V_range: [scaled_V_min, scaled_V_max]
Returns:
The instanced scaling object.
"""
V_min, V_max = V_range
scaled_V_min, scaled_V_max = scaled_V_range
scale = (V_max - V_min) / (scaled_V_max - scaled_V_min)
Expand Down
5 changes: 5 additions & 0 deletions docs/apis/brainpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
:local:
:depth: 1


Numerical Differential Integration
----------------------------------

Expand Down Expand Up @@ -77,5 +78,9 @@ Dynamical System Helpers
:template: classtemplate.rst

LoopOverTime
reset_state
save_state
load_state
clear_input


1 change: 0 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ jaxlib
matplotlib>=3.4
msgpack
tqdm
taichi

# test requirements
pytest
Expand Down
1 change: 0 additions & 1 deletion requirements-doc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ jaxlib
matplotlib>=3.4
scipy>=1.1.0
numba
taichi

# document requirements
pandoc
Expand Down

0 comments on commit 40c2632

Please sign in to comment.