Skip to content

Commit

Permalink
Add docstrings to SortedSpikesGroup and Decoding methods (#996)
Browse files Browse the repository at this point in the history
* Add docstrings

* update changelog

* fix spelling

---------

Co-authored-by: Samuel Bray <[email protected]>
  • Loading branch information
samuelbray32 and Samuel Bray authored Jun 4, 2024
1 parent 04ec37a commit 6b49c2d
Show file tree
Hide file tree
Showing 4 changed files with 262 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
- Don't insert lab member when creating lab team #983
- Spikesorting
- Allow user to set smoothing timescale in `SortedSpikesGroup.get_firing_rate` #994
- Update docstrings #996

## [0.5.2] (April 22, 2024)

Expand Down
112 changes: 111 additions & 1 deletion src/spyglass/decoding/v1/clusterless.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,14 +274,33 @@ def make(self, key):

DecodingOutput.insert1(orig_key, skip_duplicates=True)

def fetch_results(self):
def fetch_results(self) -> xr.Dataset:
"""Retrieve the decoding results
Returns
-------
xr.Dataset
The decoding results (posteriors, etc.)
"""
return ClusterlessDetector.load_results(self.fetch1("results_path"))

def fetch_model(self):
return ClusterlessDetector.load_model(self.fetch1("classifier_path"))

@staticmethod
def fetch_environments(key):
"""Fetch the environments for the decoding model
Parameters
----------
key : dict
The decoding selection key
Returns
-------
List[TrackGraph]
list of track graphs in the trained model
"""
model_params = (
DecodingParameters
& {"decoding_param_name": key["decoding_param_name"]}
Expand Down Expand Up @@ -309,6 +328,18 @@ def fetch_environments(key):

@staticmethod
def _get_interval_range(key):
"""Get the maximum range of model times in the encoding and decoding intervals
Parameters
----------
key : dict
The decoding selection key
Returns
-------
Tuple[float, float]
The minimum and maximum times for the model
"""
encoding_interval = (
IntervalList
& {
Expand Down Expand Up @@ -338,6 +369,18 @@ def _get_interval_range(key):

@staticmethod
def fetch_position_info(key):
"""Fetch the position information for the decoding model
Parameters
----------
key : dict
The decoding selection key
Returns
-------
Tuple[pd.DataFrame, List[str]]
The position information and the names of the position variables
"""
position_group_key = {
"position_group_name": key["position_group_name"],
"nwb_file_name": key["nwb_file_name"],
Expand All @@ -363,6 +406,18 @@ def fetch_position_info(key):

@staticmethod
def fetch_linear_position_info(key):
"""Fetch the position information and project it onto the track graph
Parameters
----------
key : dict
The decoding selection key
Returns
-------
pd.DataFrame
The linearized position information
"""
environment = ClusterlessDecodingV1.fetch_environments(key)[0]

position_df = ClusterlessDecodingV1.fetch_position_info(key)[0]
Expand Down Expand Up @@ -391,6 +446,22 @@ def fetch_linear_position_info(key):

@staticmethod
def fetch_spike_data(key, filter_by_interval=True):
"""Fetch the spike times for the decoding model
Parameters
----------
key : dict
The decoding selection key
filter_by_interval : bool, optional
Whether to filter for spike times in the model interval, by default True
Returns
-------
list[np.ndarray]
List of spike times for each unit in the model's spike group
"""
waveform_keys = (
(
UnitWaveformFeaturesGroup.UnitFeatures
Expand Down Expand Up @@ -426,6 +497,20 @@ def fetch_spike_data(key, filter_by_interval=True):

@classmethod
def get_spike_indicator(cls, key, time):
"""get spike indicator matrix for the group
Parameters
----------
key : dict
key to identify the group
time : np.ndarray
time vector for which to calculate the spike indicator matrix
Returns
-------
np.ndarray
spike indicator matrix with shape (len(time), n_units)
"""
time = np.asarray(time)
min_time, max_time = time[[0, -1]]
spike_times = cls.fetch_spike_data(key)[0]
Expand All @@ -442,6 +527,24 @@ def get_spike_indicator(cls, key, time):

@classmethod
def get_firing_rate(cls, key, time, multiunit=False):
"""get time-dependent firing rate for units in the group
Parameters
----------
key : dict
key to identify the group
time : np.ndarray
time vector for which to calculate the firing rate
multiunit : bool, optional
if True, return the multiunit firing rate for units in the group, by default False
smoothing_sigma : float, optional
standard deviation of gaussian filter to smooth firing rates in seconds, by default 0.015
Returns
-------
np.ndarray
    Smoothed time-dependent firing rate with shape (len(time), n_units),
    or a single multiunit rate column if `multiunit` is True.
"""
spike_indicator = cls.get_spike_indicator(key, time)
if spike_indicator.ndim == 1:
spike_indicator = spike_indicator[:, np.newaxis]
Expand All @@ -461,6 +564,13 @@ def get_firing_rate(cls, key, time, multiunit=False):
)

def get_ahead_behind_distance(self):
"""get the ahead-behind distance for the decoding model
Returns
-------
distance_metrics : np.ndarray
Information about the distance of the animal to the mental position.
"""
# TODO: allow specification of specific time interval
# TODO: allow specification of track graph
# TODO: Handle decode intervals, store in table
Expand Down
95 changes: 94 additions & 1 deletion src/spyglass/decoding/v1/sorted_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,33 @@ def make(self, key):

DecodingOutput.insert1(orig_key, skip_duplicates=True)

def fetch_results(self):
def fetch_results(self) -> xr.Dataset:
"""Retrieve the decoding results
Returns
-------
xr.Dataset
The decoding results (posteriors, etc.)
"""
return SortedSpikesDetector.load_results(self.fetch1("results_path"))

def fetch_model(self):
return SortedSpikesDetector.load_model(self.fetch1("classifier_path"))

@staticmethod
def fetch_environments(key):
"""Fetch the environments for the decoding model
Parameters
----------
key : dict
The decoding selection key
Returns
-------
List[TrackGraph]
list of track graphs in the trained model
"""
model_params = (
DecodingParameters
& {"decoding_param_name": key["decoding_param_name"]}
Expand Down Expand Up @@ -273,6 +292,18 @@ def fetch_environments(key):

@staticmethod
def _get_interval_range(key):
"""Get the maximum range of model times in the encoding and decoding intervals
Parameters
----------
key : dict
The decoding selection key
Returns
-------
Tuple[float, float]
The minimum and maximum times for the model
"""
encoding_interval = (
IntervalList
& {
Expand Down Expand Up @@ -302,6 +333,18 @@ def _get_interval_range(key):

@staticmethod
def fetch_position_info(key):
"""Fetch the position information for the decoding model
Parameters
----------
key : dict
The decoding selection key
Returns
-------
Tuple[pd.DataFrame, List[str]]
The position information and the names of the position variables
"""
position_group_key = {
"position_group_name": key["position_group_name"],
"nwb_file_name": key["nwb_file_name"],
Expand All @@ -326,6 +369,18 @@ def fetch_position_info(key):

@staticmethod
def fetch_linear_position_info(key):
"""Fetch the position information and project it onto the track graph
Parameters
----------
key : dict
The decoding selection key
Returns
-------
pd.DataFrame
The linearized position information
"""
environment = SortedSpikesDecodingV1.fetch_environments(key)[0]

position_df = SortedSpikesDecodingV1.fetch_position_info(key)[0]
Expand All @@ -352,6 +407,22 @@ def fetch_linear_position_info(key):

@staticmethod
def fetch_spike_data(key, filter_by_interval=True, time_slice=None):
"""Fetch the spike times for the decoding model
Parameters
----------
key : dict
The decoding selection key
filter_by_interval : bool, optional
Whether to filter for spike times in the model interval, by default True
time_slice : Slice, optional
User provided slice of time to restrict spikes to, by default None
Returns
-------
list[np.ndarray]
List of spike times for each unit in the model's spike group
"""
spike_times = SortedSpikesGroup.fetch_spike_data(key)
if not filter_by_interval:
return spike_times
Expand All @@ -371,6 +442,13 @@ def fetch_spike_data(key, filter_by_interval=True, time_slice=None):
return new_spike_times

def spike_times_sorted_by_place_field_peak(self, time_slice=None):
"""Spike times of units sorted by place field peak location
Parameters
----------
time_slice : Slice, optional
time range to limit returned spikes to, by default None
"""
if time_slice is None:
time_slice = slice(-np.inf, np.inf)

Expand All @@ -395,8 +473,23 @@ def spike_times_sorted_by_place_field_peak(self, time_slice=None):
]
for neuron_ind in neuron_sort_ind
]
return new_spike_times

def get_ahead_behind_distance(self, track_graph=None, time_slice=None):
"""Get the ahead-behind distance of the decoded position from the animal's actual position
Parameters
----------
track_graph : TrackGraph, optional
environment track graph to project position on, by default None
time_slice : Slice, optional
time interval to restrict to, by default None
Returns
-------
distance_metrics : np.ndarray
Information about the distance of the animal to the mental position.
"""
# TODO: store in table

if time_slice is None:
Expand Down
Loading

0 comments on commit 6b49c2d

Please sign in to comment.