Skip to content

Commit

Permalink
Refactored code
Browse files Browse the repository at this point in the history
  • Loading branch information
raj1701 committed Aug 24, 2023
1 parent 8952d98 commit 6c55cb7
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 62 deletions.
104 changes: 44 additions & 60 deletions hnn_core/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,16 @@ def gid(self, gid):
raise RuntimeError('Global ID for this cell already assigned!')


def _get_nseg(L):
nseg = 1
if L > 100.: # 100 um
nseg = int(L / 50.)
# make dend.nseg odd for all sections
if not nseg % 2:
nseg += 1
return nseg


class Section:
"""Section class.
Expand Down Expand Up @@ -205,12 +215,7 @@ def __init__(self, L, diam, Ra, cm, end_pts=None):
self.syns = list()

# For distance functionality
self.nseg = 1
if self.L > 100.: # 100 um
self.nseg = int(self.L / 50.)
# make dend.nseg odd for all sections
if not self.nseg % 2:
self.nseg += 1
self.nseg = _get_nseg(self.L)

def __repr__(self):
return f'L={self.L}, diam={self.diam}, cm={self.cm}, Ra={self.Ra}'
Expand Down Expand Up @@ -265,7 +270,7 @@ class Cell:
where sec_name is the name of the section and node_pos is the 0 end
or 1 end. The data structure is the adjacency list represetation of a
tree. The keys of the dict are the parent nodes. The value is the
list of nodes (chldren nodes) connected to the parent node.
list of nodes (children nodes) connected to the parent node.
Attributes
Expand Down Expand Up @@ -303,7 +308,7 @@ class Cell:
where sec_name is the name of the section and node_pos is the 0 end
or 1 end. The data structure is the adjacency list represetation of a
tree. The keys of the dict are the parent nodes. The value is the
list of nodes (chldren nodes) connected to the parent node.
list of nodes (children nodes) connected to the parent node.
Examples
--------
Expand Down Expand Up @@ -342,8 +347,7 @@ def __init__(self, name, pos, sections, synapses, sect_loc, cell_tree,
# Store the tree representation of the cell
self.cell_tree = cell_tree

# self._update_end_pts() # Old implementation
self.update_end_pts() # New implementation
self._update_end_pts() # New implementation

self._compute_section_mechs() # Set mech values of all sections

Expand All @@ -365,13 +369,28 @@ def gid(self, gid):
raise RuntimeError('Global ID for this cell already assigned!')

def distance_section(self, target_sec_name, curr_node):
"""Find distance between the current node and the target section.
Parameters
----------
target_sec_name : string
Name of the target section
curr_node : tuple
Source node from where search begins.
It is of the the form (sec_name, end_pt).
Returns
-------
distance : float
Path distance between source node and mid of the target section.
"""
# Python version of the Neuron distance function
# https://nrn.readthedocs.io/en/latest/python/modelspec/programmatic/topology/geometry.html#distance # noqa
if self.cell_tree is None:
raise TypeError("distance_section() "
"cannot work with cell_tree as None.")
if curr_node not in self.cell_tree:
return 1000000
return np.nan

# Children of the current section
curr_sec_children = self.cell_tree[curr_node]
Expand All @@ -386,7 +405,7 @@ def distance_section(self, target_sec_name, curr_node):
if (target_sec_name, end_pt) in curr_sec_children:
return self.sections[target_sec_name].L / 2

dist = 1000000 # Return large value
dist = np.nan # Return nan

# Recursion to find distance
for node in self.cell_tree[curr_node]:
Expand All @@ -395,7 +414,10 @@ def distance_section(self, target_sec_name, curr_node):
self.sections[node[0]].L)
else:
dist_temp = self.distance_section(target_sec_name, node)
dist = min(dist, dist_temp)
if np.isnan(dist) and np.isnan(dist_temp):
dist = np.nan
else:
dist = np.nanmin([dist, dist_temp])

return dist

Expand Down Expand Up @@ -437,19 +459,10 @@ def _compute_section_mechs(self):
seg_xs, seg_vals = list(), list()
section_distance = self.distance_section(sec_name,
('soma', 0))
# Finding centres of all segments in the section
# If number of segments is 5 then seg_centres will
# be 0.1, 0.3, 0.5, 0.7 and 0.9.
adjacent_seg_dist = 1 / section.nseg
first_seg_centre = (0.5 - (((section.nseg - 1) / 2) *
adjacent_seg_dist))
last_seg_centre = (0.5 + (((section.nseg - 1) / 2) *
adjacent_seg_dist))
seg_centres = list(np.linspace(first_seg_centre,
last_seg_centre,
num=section.nseg))

for seg_x in seg_centres:
seg_centers = (np.linspace(0, 1, section.nseg * 2 + 1)
[1::2])

for seg_x in seg_centers:
# sec_end_dist is distance between 0 end of soma to
# the 0 or 1 end of section (whichever is closer)
sec_end_dist = section_distance - (section.L / 2)
Expand Down Expand Up @@ -479,11 +492,6 @@ def _create_sections(self, sections, cell_tree):
"""
if 'soma' not in self.sections:
raise KeyError('soma must be defined for cell')
# shift cell to self.pos and reorient apical dendrite
# along z direction of self.pos
dx = self.pos[0] - self.sections['soma'].end_pts[0][0]
dy = self.pos[1] - self.sections['soma'].end_pts[0][1]
dz = self.pos[2] - self.sections['soma'].end_pts[0][2]

for sec_name in sections:
sec = h.Section(name=f'{self.name}_{sec_name}')
Expand All @@ -492,20 +500,13 @@ def _create_sections(self, sections, cell_tree):
h.pt3dclear(sec=sec)
h.pt3dconst(0, sec=sec) # be explicit, see documentation
for pt in sections[sec_name].end_pts:
h.pt3dadd(pt[0] + dx,
pt[1] + dy,
pt[2] + dz, 1, sec=sec)
h.pt3dadd(pt[0], pt[1], pt[2], 1, sec=sec)
# with pt3dconst==0, these will alter the 3d points defined above!
sec.L = sections[sec_name].L
sec.diam = sections[sec_name].diam
sec.Ra = sections[sec_name].Ra
sec.cm = sections[sec_name].cm

if sec.L > 100.: # 100 um
sec.nseg = int(sec.L / 50.)
# make dend.nseg odd for all sections
if not sec.nseg % 2:
sec.nseg += 1
sec.nseg = sections[sec_name].nseg

if cell_tree is None:
cell_tree = dict()
Expand Down Expand Up @@ -536,8 +537,6 @@ def build(self, sec_name_apical=None):
with this section. The section should belong to the apical dendrite
of a pyramidal neuron.
"""
from .network_builder import load_custom_mechanisms
load_custom_mechanisms()
self._create_sections(self.sections, self.cell_tree)
self._create_synapses(self.sections, self.synapses)
self._set_biophysics(self.sections)
Expand Down Expand Up @@ -774,23 +773,6 @@ def plot_morphology(self, ax=None, color=None, show=True):
"""
return plot_cell_morphology(self, ax=ax, color=color, show=show)

def _update_end_pts(self):
""""Create cell and copy coordinates to Section.end_pts"""
self._create_sections(self.sections, self.cell_tree)
section_names = list(self.sections.keys())

for name in section_names:
nrn_pts = self._nrn_sections[name].psection()['morphology'][
'pts3d']

del self._nrn_sections[name]

x0, y0, z0 = nrn_pts[0][0], nrn_pts[0][1], nrn_pts[0][2]
x1, y1, z1 = nrn_pts[1][0], nrn_pts[1][1], nrn_pts[1][2]
self.sections[name]._end_pts = [[x0, y0, z0], [x1, y1, z1]]

self._nrn_sections = dict()

def _update_section_end_pts_L(self, node, dpt):
if self.cell_tree is None:
return
Expand All @@ -811,6 +793,7 @@ def _update_section_end_pts_L(self, node, dpt):

def define_shape(self, node):
"""Redefines end_pts according to section lengths.
Detects change in section lengths of the sections in the
subtree of the input node.
Expand Down Expand Up @@ -864,8 +847,9 @@ def define_shape(self, node):
for child_node in self.cell_tree[node]:
self.define_shape(child_node)

def update_end_pts(self):
def _update_end_pts(self):
"""Update all end pts according to the length of the sections.
Can be used whenever length of any section is updated
Returns
Expand Down
4 changes: 2 additions & 2 deletions hnn_core/tests/test_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_cell():
end_pts_original.append(section.end_pts)
section._L = section._L * 2
cell1.sections[sec_name] = section
cell1.update_end_pts()
cell1._update_end_pts()
for sec_name in cell1.sections.keys():
end_pts_new.append(cell1.sections[sec_name].end_pts)

Expand All @@ -133,7 +133,7 @@ def test_cell():
cell1.sections[sec_name] = section

end_pts_new = list()
cell1.update_end_pts()
cell1._update_end_pts()
for sec_name in cell1.sections.keys():
section = cell1.sections[sec_name]
cell1.sections[sec_name] = section
Expand Down

0 comments on commit 6c55cb7

Please sign in to comment.