Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Add method to modify synaptic gain #897

Merged
merged 14 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ Changelog
- Add minimum spectral frequency widget to GUI for adjusting spectrogram
frequency axis, by `George Dang`_ in :gh:`894`

- Add method to modify synaptic gains, by `Nick Tolley`_ and `George Dang`_
in :gh:`897`

Bug
~~~
Expand Down
56 changes: 55 additions & 1 deletion hnn_core/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,8 @@ def __eq__(self, other):
'clear_connectivity', 'clear_drives',
'connectivity', 'copy', 'gid_to_type',
'plot_cells', 'set_cell_positions',
'to_dict', 'write_configuration'])
'to_dict', 'write_configuration',
'update_weights'])
attrs_to_check = [x for x in all_attrs if x not in attrs_to_ignore]

for attr in attrs_to_check:
Expand Down Expand Up @@ -1358,6 +1359,8 @@ def add_connection(self, src_gids, target_gids, loc, receptor,
_validate_type(item, (int, float), arg_name, 'int or float')
conn['nc_dict'][key] = item

conn['nc_dict']['gain'] = 1.0

# Probabilistically define connections
if probability != 1.0:
_connection_probability(conn, probability, conn_seed)
Expand Down Expand Up @@ -1428,6 +1431,55 @@ def add_electrode_array(self, name, electrode_pos, *, conductivity=0.3,
method=method,
min_distance=min_distance)})

def update_weights(self, e_e=1.0, e_i=1.0, i_e=1.0, i_i=1.0, copy=True):
"""Update synaptic weights of the network.

Parameters
----------
e_e : float
Synaptic gain of excitatory to excitatory connections (default 1.0)
e_i : float
Synaptic gain of excitatory to inhibitory connections (default 1.0)
i_e : float
Synaptic gain of inhibitory to excitatory connections (default 1.0)
i_i : float
Synaptic gain of inhibitory to inhibitory connections (default 1.0)
copy : bool
If True, create a copy of the network. If False, update the network
in place (default True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

returns ?

"""
_validate_type(copy, bool, 'copy')

net = self.copy() if copy else self

e_conns = pick_connection(self, receptor=['ampa', 'nmda'])
e_cells = np.concatenate([list(net.connectivity[
conn_idx]['src_gids']) for conn_idx in e_conns]).tolist()

i_conns = pick_connection(self, receptor=['gabaa', 'gabab'])
i_cells = np.concatenate([list(net.connectivity[
conn_idx]['src_gids']) for conn_idx in i_conns]).tolist()
conn_types = {
'e_e': (e_e, e_cells, e_cells),
'e_i': (e_i, e_cells, i_cells),
'i_e': (i_e, i_cells, e_cells),
'i_i': (i_i, i_cells, i_cells)
}

for conn_type, (gain, e_vals, i_vals) in conn_types.items():
_validate_type(gain, (int, float), conn_type, 'int or float')
if gain < 0.0:
raise ValueError("Synaptic gains must be non-negative."
f"Got {gain} for '{conn_type}'.")

conn_indices = pick_connection(net, src_gids=e_vals,
target_gids=i_vals)
for conn_idx in conn_indices:
net.connectivity[conn_idx]['nc_dict']['gain'] = gain

if copy:
return net

def plot_cells(self, ax=None, show=True):
"""Plot the cells using Network.pos_dict.

Expand Down Expand Up @@ -1495,6 +1547,8 @@ class _Connectivity(dict):
Synaptic delay in ms.
lamtha : float
Space constant.
gain : float
Multiplicative factor for synaptic weight.
probability : float
Probability of connection between any src-target pair.
Defaults to 1.0 producing an all-to-all pattern.
Expand Down
1 change: 1 addition & 0 deletions hnn_core/network_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ def _connect_celltypes(self):
for conn in connectivity:
loc, receptor = conn['loc'], conn['receptor']
nc_dict = deepcopy(conn['nc_dict'])
nc_dict['A_weight'] *= nc_dict['gain']
ntolley marked this conversation as resolved.
Show resolved Hide resolved
# Gather indices of targets on current node
valid_targets = set()
for src_gid, target_gids in conn['gid_pairs'].items():
Expand Down
Loading
Loading