Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Add method to modify synaptic gain #897

Merged
merged 14 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ Changelog
- Add minimum spectral frequency widget to GUI for adjusting spectrogram
frequency axis, by `George Dang`_ in :gh:`894`

- Add method to modify synaptic gains, by `Nick Tolley`_ and `George Dang`_
in :gh:`897`

Bug
~~~
Expand Down
76 changes: 75 additions & 1 deletion hnn_core/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,8 @@ def __eq__(self, other):
'clear_connectivity', 'clear_drives',
'connectivity', 'copy', 'gid_to_type',
'plot_cells', 'set_cell_positions',
'to_dict', 'write_configuration'])
'to_dict', 'write_configuration',
'update_weights'])
attrs_to_check = [x for x in all_attrs if x not in attrs_to_ignore]

for attr in attrs_to_check:
Expand Down Expand Up @@ -1358,6 +1359,8 @@ def add_connection(self, src_gids, target_gids, loc, receptor,
_validate_type(item, (int, float), arg_name, 'int or float')
conn['nc_dict'][key] = item

conn['nc_dict']['gain'] = 1.0

# Probabilistically define connections
if probability != 1.0:
_connection_probability(conn, probability, conn_seed)
Expand Down Expand Up @@ -1428,6 +1431,75 @@ def add_electrode_array(self, name, electrode_pos, *, conductivity=0.3,
method=method,
min_distance=min_distance)})

def update_weights(self, e_e=None, e_i=None,
i_e=None, i_i=None, copy=True):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the default behavior of the method be copy=False? Will most use cases be updating the weights in-place or saving to a new object?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed! will change to false and then good to merge

"""Update synaptic weights of the network.

Parameters
----------
e_e : float
Synaptic gain of excitatory to excitatory connections
(default None)
e_i : float
Synaptic gain of excitatory to inhibitory connections
(default None)
i_e : float
Synaptic gain of inhibitory to excitatory connections
(default None)
i_i : float
Synaptic gain of inhibitory to inhibitory connections
(default None)
copy : bool
If True, returns a copy of the network. If False,
the network is updated in place with a return of None.

Returns
-------
net : instance of Network
A copy of the instance with updated synaptic gains if copy=True.

Notes
-----
Synaptic gains must be non-negative. The synaptic gains will only be
updated if a float value is provided. If None is provided
(the default), the synapticgain will remain unchanged.

"""
_validate_type(copy, bool, 'copy')

net = self.copy() if copy else self

e_conns = pick_connection(self, receptor=['ampa', 'nmda'])
e_cells = np.concatenate([list(net.connectivity[
conn_idx]['src_gids']) for conn_idx in e_conns]).tolist()

i_conns = pick_connection(self, receptor=['gabaa', 'gabab'])
i_cells = np.concatenate([list(net.connectivity[
conn_idx]['src_gids']) for conn_idx in i_conns]).tolist()
conn_types = {
'e_e': (e_e, e_cells, e_cells),
'e_i': (e_i, e_cells, i_cells),
'i_e': (i_e, i_cells, e_cells),
'i_i': (i_i, i_cells, i_cells)
}

for conn_type, (gain, e_vals, i_vals) in conn_types.items():
if gain is None:
continue

_validate_type(gain, (int, float), conn_type, 'int or float')
if gain < 0.0:
raise ValueError("Synaptic gains must be non-negative."
f"Got {gain} for '{conn_type}'.")

conn_indices = pick_connection(net, src_gids=e_vals,
target_gids=i_vals)
for conn_idx in conn_indices:
net.connectivity[conn_idx]['nc_dict']['gain'] = gain

if copy:
return net

def plot_cells(self, ax=None, show=True):
"""Plot the cells using Network.pos_dict.

Expand Down Expand Up @@ -1495,6 +1567,8 @@ class _Connectivity(dict):
Synaptic delay in ms.
lamtha : float
Space constant.
gain : float
Multiplicative factor for synaptic weight.
probability : float
Probability of connection between any src-target pair.
Defaults to 1.0 producing an all-to-all pattern.
Expand Down
1 change: 1 addition & 0 deletions hnn_core/network_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ def _connect_celltypes(self):
for conn in connectivity:
loc, receptor = conn['loc'], conn['receptor']
nc_dict = deepcopy(conn['nc_dict'])
nc_dict['A_weight'] *= nc_dict['gain']
ntolley marked this conversation as resolved.
Show resolved Hide resolved
# Gather indices of targets on current node
valid_targets = set()
for src_gid, target_gids in conn['gid_pairs'].items():
Expand Down
Loading
Loading