CA1 tetrodes only, handle noise
edeno committed Sep 23, 2023
1 parent cbf98d3 commit 7fca7d9
Showing 1 changed file with 109 additions and 10 deletions.
119 changes: 109 additions & 10 deletions src/spyglass/anna_decoding.py
@@ -9,10 +9,17 @@

import spyglass.common as sgc
import spyglass.position as sgp
from spyglass.common import (
    BrainRegion,
    ElectrodeGroup,
    IntervalList,
    interval_list_intersect,
)
from spyglass.decoding import UnitMarks
from spyglass.decoding.visualization import create_2D_decode_view
from spyglass.lfp import LFPOutput
from spyglass.lfp.analysis.v1 import lfp_band
from spyglass.spikesorting import ArtifactRemovedIntervalList, SortGroup

PROCESSED_DATA_DIR = (
"/cumulus/edeno/spyglass/notebooks/anna_data/Processed-Data"
Expand All @@ -36,7 +43,9 @@ def convert_dio_events_to_start_stop_times(event, event_name):


def load_position(
    nwb_file_name,
    position_interval_name,
    trodes_pos_params_name="default_led0",
):
    trodes_key = (
        sgp.v1.TrodesPosSelection()
@@ -73,17 +82,96 @@ def load_marks(nwb_file_name, epoch_name):
    return position_info


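# Load spike times and waveform features ("marks") for CA1 tetrodes only,
# removing spikes that fall in artifact intervals, outside the valid
# epoch/ephys/position times, or within 2 ms of implausibly large events.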
def load_marks(nwb_file_name, epoch_name, position_interval_name):
    # Join marks with electrode metadata to get each tetrode's brain region.
    marks_table = (
        (
            UnitMarks
            & {
                "nwb_file_name": nwb_file_name,
                "sort_interval_name": epoch_name,
                "curation_id": 0,
            }
        )
        * SortGroup.SortGroupElectrode
        * ElectrodeGroup
    )
    marks_table = marks_table * BrainRegion
    marks_table = pd.DataFrame(marks_table)

    # restrict to CA1
    restriction = marks_table.loc[marks_table.region_name == "ca1"].to_dict(
        orient="records"
    )
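    # One DataFrame per tetrode: index = spike times, columns = waveform
    # features.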
marks = (UnitMarks & restriction).fetch_dataframe()
marks = [(mark.index.to_numpy(), mark.to_numpy()) for mark in marks]
spike_times, spike_waveform_features = zip(*marks)
spike_times = list(spike_times)
spike_waveform_features = list(spike_waveform_features)

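    # Fetch the three relevant interval lists: the epoch itself, the raw
    # ephys valid times, and the position-tracking valid times.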
    epoch_interval = (
        IntervalList
        & {"nwb_file_name": nwb_file_name, "interval_list_name": epoch_name}
    ).fetch1("valid_times")
    ephys_interval = (
        IntervalList
        & {
            "nwb_file_name": nwb_file_name,
            "interval_list_name": "raw data valid times",
        }
    ).fetch1("valid_times")
    position_interval = (
        IntervalList
        & {
            "nwb_file_name": nwb_file_name,
            "interval_list_name": position_interval_name,
        }
    ).fetch1("valid_times")

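    # A time must fall in all three interval lists to be kept.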
    valid_interval = interval_list_intersect(
        interval_list_intersect(epoch_interval, ephys_interval),
        position_interval,
    )

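    # Clean each tetrode's spikes in turn. Note this assumes the
    # ArtifactRemovedIntervalList rows come back in the same tetrode order
    # as the marks fetched above.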
    artifact_times = (ArtifactRemovedIntervalList & restriction).fetch(
        "artifact_times"
    )
    for ind, (
        tetrode_spike_times,
        tetrode_waveform_features,
        tetrode_artifact_times,
    ) in enumerate(zip(spike_times, spike_waveform_features, artifact_times)):
        for start_time, end_time in tetrode_artifact_times:
            is_in_artifact = np.logical_and(
                tetrode_spike_times >= start_time,
                tetrode_spike_times <= end_time,
            )
            tetrode_spike_times = tetrode_spike_times[~is_in_artifact]
            tetrode_waveform_features = tetrode_waveform_features[
                ~is_in_artifact
            ]
        # Keep spikes inside at least one valid interval. Build the mask
        # with a logical OR across intervals; filtering inside the loop
        # instead would intersect the intervals and drop everything when
        # the valid times are disjoint.
        is_in_interval = np.zeros_like(tetrode_spike_times, dtype=bool)
        for start_time, end_time in valid_interval:
            is_in_interval = np.logical_or(
                is_in_interval,
                np.logical_and(
                    tetrode_spike_times >= start_time,
                    tetrode_spike_times <= end_time,
                ),
            )
        tetrode_spike_times = tetrode_spike_times[is_in_interval]
        tetrode_waveform_features = tetrode_waveform_features[
            is_in_interval
        ]

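        # Spikes with any feature magnitude above 1000 are treated as noise;
        # everything within 2 ms of such an event is removed as well.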
        is_bad_spike = np.any(np.abs(tetrode_waveform_features) > 1000, axis=1)
        for bad_spike_time in tetrode_spike_times[is_bad_spike]:
            is_in_interval = np.logical_and(
                tetrode_spike_times >= bad_spike_time - 0.002,
                tetrode_spike_times <= bad_spike_time + 0.002,
            )
            tetrode_spike_times = tetrode_spike_times[~is_in_interval]
            tetrode_waveform_features = tetrode_waveform_features[
                ~is_in_interval
            ]

        spike_times[ind] = tetrode_spike_times
        spike_waveform_features[ind] = tetrode_waveform_features

    return spike_times, spike_waveform_features

@@ -221,7 +309,9 @@ def load_data(
    position_info = load_position(
        nwb_file_name, position_interval_name, trodes_pos_params_name
    )
    spike_times, spike_waveform_features = load_marks(
        nwb_file_name, epoch_name, position_interval_name
    )
    beam_breaks, pump_events, light_events = load_dios(
        nwb_file_name, position_info.index.to_numpy()[[0, -1]]
    )
@@ -231,6 +321,15 @@
        epoch_name,
    )

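    # position_info, lfp_df, and the theta series are all time-indexed, so
    # .loc slicing restricts each to the run period.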
    # only analyze times between first and last beam break
    start_time = beam_breaks.start_time.min()
    end_time = beam_breaks.end_time.max()
    position_info = position_info.loc[start_time:end_time]
    lfp_df = lfp_df.loc[start_time:end_time]
    theta_lfp = theta_lfp.loc[start_time:end_time]
    theta_phase = theta_phase.loc[start_time:end_time]
    theta_power = theta_power.loc[start_time:end_time]

    return (
        position_info,
        spike_times,

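For reference, a minimal sketch of how the updated load_marks might be called; the session and interval names below are hypothetical placeholders, not values from this commit:

# Hypothetical session/interval names -- substitute entries from your database.
nwb_file_name = "example_session_.nwb"
epoch_name = "02_r1"
position_interval_name = "pos 1 valid times"

spike_times, spike_waveform_features = load_marks(
    nwb_file_name, epoch_name, position_interval_name
)

# One list entry per CA1 tetrode: spike_times[i] is a 1D array of spike
# times in seconds; spike_waveform_features[i] has shape (n_spikes, n_features).
assert len(spike_times) == len(spike_waveform_features)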