Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CSR and IJ connections show unexpectedly different behavior (for CSR connections, multiple connections between the same pair are not taken into account) #667

Open
ZhenyuanJin opened this issue May 1, 2024 · 1 comment
Labels
bug Something isn't working

Comments

@ZhenyuanJin
Copy link

import sys
import matplotlib.pyplot as plt
import numpy as np
import brainpy as bp
import brainpy.math as bm
import os
from scipy.sparse import coo_matrix, csr_matrix

# brainpy version 2.5.0

class EINet(bp.DynSysGroup):
    """Two-population (E and I) network wired by four pluggable projections.

    The neuron classes, projection classes, their keyword arguments and the
    communication objects are all injected by the caller, so the same
    network skeleton can be reused with different connectivity.
    """

    def __init__(self, E_neuron, I_neuron, E_params, I_params,
                 E2E_synapse, E2I_synapse, I2E_synapse, I2I_synapse,
                 E2E_synapse_params, E2I_synapse_params,
                 I2E_synapse_params, I2I_synapse_params,
                 E2E_comm, E2I_comm, I2E_comm, I2I_comm):
        """Build both neuron groups and the four projections between them.

        Args:
            E_neuron, I_neuron: Callables invoked as ``cls(**params)`` to
                create the excitatory and inhibitory groups.
            E_params, I_params: Keyword arguments for the neuron callables.
            E2E_synapse, E2I_synapse, I2E_synapse, I2I_synapse: Callables
                invoked as ``cls(pre=..., post=..., comm=..., **params)``.
            E2E_synapse_params, E2I_synapse_params, I2E_synapse_params,
                I2I_synapse_params: Extra keyword arguments for the
                projection callables.
            E2E_comm, E2I_comm, I2E_comm, I2I_comm: Communication objects
                handed to the matching projection as ``comm``.
        """
        # The dunder names matter: with a plain ``init`` method Python would
        # never call this code when ``EINet(...)`` is instantiated.
        super().__init__()

        # Shallow copies, so the dict objects stored on the instance are not
        # the ones owned by the caller.
        self.E_params = E_params.copy()
        self.I_params = I_params.copy()
        self.E2E_synapse_params = E2E_synapse_params.copy()
        self.E2I_synapse_params = E2I_synapse_params.copy()
        self.I2E_synapse_params = I2E_synapse_params.copy()
        self.I2I_synapse_params = I2I_synapse_params.copy()

        # neurons
        self.E = E_neuron(**self.E_params)
        self.I = I_neuron(**self.I_params)

        # synapses
        self.E2E = E2E_synapse(pre=self.E, post=self.E, comm=E2E_comm, **self.E2E_synapse_params)
        self.E2I = E2I_synapse(pre=self.E, post=self.I, comm=E2I_comm, **self.E2I_synapse_params)
        self.I2E = I2E_synapse(pre=self.I, post=self.E, comm=I2E_comm, **self.I2E_synapse_params)
        self.I2I = I2I_synapse(pre=self.I, post=self.I, comm=I2I_comm, **self.I2I_synapse_params)

    def update(self, E_inp, I_inp):
        """Run one step: the four projections first, then both neuron groups.

        Args:
            E_inp: Input passed to the excitatory group.
            I_inp: Input passed to the inhibitory group.

        Returns:
            The tuple ``(E.spike, I.spike, E.V, I.V)`` for monitoring.
        """
        self.E2E()
        self.E2I()
        self.I2E()
        self.I2I()
        self.E(E_inp)
        self.I(I_inp)

        # monitor
        return self.E.spike, self.I.spike, self.E.V, self.I.V

def get_run_func(EI_net, E_inp_kwargs, I_inp_kwargs, E_size, I_size, input_type='constant'):
    """Return a one-step function ``run_func(i)`` that drives ``EI_net``.

    Args:
        EI_net: Object exposing ``step_run(i, E_inp, I_inp)``.
        E_inp_kwargs: Mapping with key ``'mean'`` (and ``'std'`` for
            ``'wiener'`` input) describing the excitatory input.
        I_inp_kwargs: Same as ``E_inp_kwargs`` for the inhibitory input.
        E_size: Number of samples drawn for the excitatory input
            (``'wiener'`` only).
        I_size: Number of samples drawn for the inhibitory input
            (``'wiener'`` only).
        input_type: ``'constant'`` feeds the means unchanged; ``'wiener'``
            feeds ``randn(size) * std + mean`` drawn on every call.

    Returns:
        A function taking the step index ``i`` and returning whatever
        ``EI_net.step_run`` returns.

    Raises:
        ValueError: If ``input_type`` is not one of the supported values.
    """
    if input_type == 'constant':
        def run_func(i):
            return EI_net.step_run(i, E_inp_kwargs['mean'], I_inp_kwargs['mean'])
    elif input_type == 'wiener':
        def run_func(i):
            # NOTE(review): if this closure is traced/JIT-compiled by the
            # caller, the NumPy draw may happen only once at trace time
            # rather than on every step -- confirm before using std != 0.
            local_E_inp = np.random.randn(E_size) * E_inp_kwargs['std'] + E_inp_kwargs['mean']
            local_I_inp = np.random.randn(I_size) * I_inp_kwargs['std'] + I_inp_kwargs['mean']
            return EI_net.step_run(i, local_E_inp, local_I_inp)
    else:
        # Fail here with a clear message instead of an UnboundLocalError on
        # the return statement below.
        raise ValueError(
            f"unknown input_type {input_type!r}; expected 'constant' or 'wiener'"
        )
    return run_func

# Build the SNN with brainpy

# Only the E->E projection carries a non-zero weight.
E2E_weight = 1
E2I_weight = I2E_weight = I2I_weight = 0
E_size = I_size = 2

# Neuron parameters, identical for both populations apart from the size entry.
E_params = dict(size=E_size, V_th=20.0, V_reset=-5.0, V_rest=0., tau_ref=5.0, R=1.0, tau=10.0)
I_params = dict(size=I_size, V_th=20.0, V_reset=-5.0, V_rest=0., tau_ref=5.0, R=1.0, tau=10.0)

# One independent dict per projection.
E2E_synapse_params = dict(delay=0)
E2I_synapse_params = dict(delay=0)
I2E_synapse_params = dict(delay=0)
I2I_synapse_params = dict(delay=0)

# Input statistics: current is injected into the first E neuron only.
E_inp_mean = np.zeros(E_size)
E_inp_mean[0] = 30
E_inp_std = I_inp_mean = I_inp_std = 0

# Simulation time step.
dt = 1.
bm.set_dt(dt)

# Reproduction: describe the same duplicated 0 -> 1 connection with two
# connector types, simulate each network, and plot the E membrane potentials.
# NOTE(review): indentation was lost in this paste, so the nesting of the lines
# below (in particular the `E2E_conn(...)` call and the plotting code) cannot be
# read off from here -- confirm against the original script.
for mode in ['csr', 'ij']:
if mode == 'csr':
# Build the connection from a CSR matrix.
# Both entries address (row 0, col 1), i.e. the same pair is listed twice.
row_indices = np.array([0, 0])
col_indices = np.array([1, 1])
# NOTE(review): scipy sums duplicate (row, col) entries when building from
# triplets, so this matrix presumably stores one value of 2 at (0, 1) -- confirm.
E2E_csr = csr_matrix((np.ones_like(row_indices), (row_indices, col_indices)), shape=(E_size, E_size))
E2E_conn = bp.connect.SparseMatConn(E2E_csr)
if mode == 'ij':
# Build the connection from explicit pre/post index lists.
# The pair (pre 0, post 1) again appears twice.
pre_list = np.array([0, 0])
post_list = np.array([1, 1])
E2E_conn = bp.conn.IJConn(i=pre_list, j=post_list)
E2E_conn = E2E_conn(pre_size=E_size, post_size=E_size)

# Build the communication object from the connection.
E2E_comm = bp.dnn.EventCSRLinear(conn=E2E_conn, weight=E2E_weight)

# Reuse the E2E connectivity; the weight is 0, so this projection is intended to have no effect.
E2I_conn = E2E_conn
E2I_comm = bp.dnn.EventCSRLinear(conn=E2I_conn, weight=E2I_weight)

# Reuse the E2E connectivity; the weight is 0, so this projection is intended to have no effect.
I2E_conn = E2E_conn
I2E_comm = bp.dnn.EventCSRLinear(conn=I2E_conn, weight=I2E_weight)

# Reuse the E2E connectivity; the weight is 0, so this projection is intended to have no effect.
I2I_conn = E2E_conn
I2I_comm = bp.dnn.EventCSRLinear(conn=I2I_conn, weight=I2I_weight)

# Same neuron and projection classes for every population/pathway; only the comm objects differ.
EI_net = EINet(E_neuron=bp.dyn.LifRef, I_neuron=bp.dyn.LifRef, E_params=E_params, I_params=I_params, E2E_synapse=bp.dyn.FullProjDelta, E2I_synapse=bp.dyn.FullProjDelta, I2E_synapse=bp.dyn.FullProjDelta, I2I_synapse=bp.dyn.FullProjDelta, E2E_synapse_params=E2E_synapse_params, E2I_synapse_params=E2I_synapse_params, I2E_synapse_params=I2E_synapse_params, I2I_synapse_params=I2I_synapse_params, E2E_comm=E2E_comm, E2I_comm=E2I_comm, I2E_comm=I2E_comm, I2I_comm=I2I_comm)

# 'wiener' input is requested, but both std values passed here are 0.
run_func = get_run_func(EI_net, {'mean': E_inp_mean, 'std': E_inp_std}, {'mean': I_inp_mean, 'std': I_inp_std}, E_size, I_size, input_type='wiener')

# Simulate 100 steps; `ts` is the step index scaled by the configured dt.
indices = np.arange(100)
ts = indices * bm.get_dt()
# The four results are unpacked in the order returned by EINet.update
# (E spikes, I spikes, E voltages, I voltages), presumably stacked over steps.
E_spikes, I_spikes, E_V, I_V = bm.for_loop(run_func, indices, progress_bar=True)

# Plot both E neurons' membrane potentials, with horizontal reference lines at
# one and two times the E2E weight.
fig, ax = plt.subplots()
ax.plot(ts, E_V[:, 0], label='E neuron 0', color='blue')
ax.plot(ts, E_V[:, 1], label='E neuron 1', color='red')
ax.axhline(E2E_weight, label='E2E weight', color='black')
ax.axhline(E2E_weight*2, label='2*E2E weight', color='black', linestyle='--')
ax.legend()
ax.set_title(f'Connection mode: {mode}')
@ZhenyuanJin ZhenyuanJin added the bug Something isn't working label May 1, 2024
@Routhleck
Copy link
Member

Apologies for the delayed response. Based on your description, I have some understanding of the issue. It seems that BrainPy does not handle multiple (repeated) connections for CSR matrix connections; indeed, BrainPy's CSRConn does not take repeated connections into account. If we print the connection matrix together with pre_ids and post_ids, the output looks like this:

source code: https://gist.github.com/Routhleck/f37c18283c169ed3148f2ab0ac6a1a08

csr:
 conn_mat: [[False  True]
 [False False]]
 pre ids: [0]
 post ids: [1]

ij:
 conn_mat: [[False  True]
 [False False]]
 pre ids: [0 0]
 post ids: [1 1]

To address this, I suggest considering a custom function to handle multiple connections. This could involve assigning weight values to each connection. In BrainPy, the CSR matrix multiplication operator can accept weights to perform the corresponding calculations.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

No branches or pull requests

2 participants