Skip to content

Commit

Permalink
Linearizationv1 debug and tutorial (#695)
Browse files Browse the repository at this point in the history
* Resolve minor discrepencies with PositionOutput dependency

* Project merge_id from PositionOutput to new name in Linearization

* Update make function to use pos_merge_id

* Update Linearization tutorial to v1 pipeline

* Update changelog

* Remove in progress node picker section fo tutorial

* Style cleanup
  • Loading branch information
samuelbray32 authored Nov 29, 2023
1 parent 04ee070 commit 4fa761a
Show file tree
Hide file tree
Showing 4 changed files with 428 additions and 626 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
- Additional documentation. #686
- Refactor input validation in DLC pipeline.
- Clean up following pre-commit checks.
- Minor fixes to LinearizedPositionV1 pipeline #695

## [0.4.3] (November 7, 2023)

Expand Down
866 changes: 361 additions & 505 deletions notebooks/24_Linearization.ipynb

Large diffs are not rendered by default.

174 changes: 58 additions & 116 deletions notebooks/py_scripts/24_Linearization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@
# format_version: '1.5'
# jupytext_version: 1.15.2
# kernelspec:
# display_name: Python 3.10.5 64-bit
# display_name: spyglass
# language: python
# name: python3
# ---

# # Position - Linearization

# ## Overview
#

# _Developer Note:_ if you may make a PR in the future, be sure to copy this
# notebook, and use the `gitignore` prefix `temp` to avoid future conflicts.
Expand All @@ -28,7 +27,7 @@
# inserts, see
# [the Insert Data notebook](./01_Insert_Data.ipynb)
#
# This pipeline takes 2D position data from the `IntervalPositionInfo` table and
# This pipeline takes 2D position data from the `PositionOutput` table and
# "linearizes" it to 1D position. If you haven't already done so, please generate
# input data with either the [Trodes](./20_Position_Trodes.ipynb) or DLC notebooks
# ([1](./21_Position_DLC_1.ipynb), [2](./22_Position_DLC_2.ipynb),
Expand All @@ -54,6 +53,7 @@

import spyglass.common as sgc
import spyglass.position.v1 as sgp
import spyglass.position_linearization.v1 as sgpl

# ignore datajoint+jupyter async warnings
import warnings
Expand All @@ -63,39 +63,47 @@
# -

# ## Retrieve 2D position
#

# To retrieve 2D position data, we'll specify an nwb file, a position time
# interval, and the set of parameters used to compute the position info.

from spyglass.utils.nwb_helper_fn import get_nwb_copy_filename

nwb_file_name = "chimi20200216_new.nwb"
nwb_copy_file_name = sgc.nwb_helper_fn.get_nwb_copy_filename(nwb_file_name)
nwb_copy_file_name = get_nwb_copy_filename(nwb_file_name)
nwb_copy_file_name

# We will fetch the pandas dataframe from the `IntervalPositionInfo`.
# We will fetch the pandas dataframe from the `PositionOutput` table.

position_info = (
sgc.common_position.IntervalPositionInfo()
& {
"nwb_file_name": nwb_copy_file_name,
"interval_list_name": "pos 1 valid times",
"position_info_param_name": "default",
}
).fetch1_dataframe()
# +
from spyglass.position import PositionOutput
import pandas as pd

pos_key = {
"nwb_file_name": nwb_copy_file_name,
"interval_list_name": "pos 1 valid times",
"position_info_param_name": "default",
}

# Note: You'll have to change the part table to the one where your data came from
merge_id = (PositionOutput.TrodesPosV1() & pos_key).fetch1("merge_id")
position_info = (PositionOutput & {"merge_id": merge_id}).fetch1_dataframe()
position_info
# -

# Before linearizing, plotting the head position will help us understand the data.

fig, ax = plt.subplots(1, 1, figsize=(10, 10))
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.plot(
position_info.head_position_x,
position_info.head_position_y,
position_info.position_x,
position_info.position_y,
color="lightgrey",
)
ax.set_xlabel("x-position [cm]", fontsize=18)
ax.set_ylabel("y-position [cm]", fontsize=18)
ax.set_title("Head Position", fontsize=28)


# ## Specifying the track
#

Expand Down Expand Up @@ -167,7 +175,7 @@
# `environment`.

# +
sgc.common_position.TrackGraph.insert1(
sgpl.TrackGraph.insert1(
{
"track_graph_name": "6 arm",
"environment": "6 arm",
Expand All @@ -179,17 +187,17 @@
skip_duplicates=True,
)

graph = sgc.common_position.TrackGraph & {"track_graph_name": "6 arm"}
graph = sgpl.TrackGraph & {"track_graph_name": "6 arm"}
graph
# -

# `TrackGraph` has several methods for visualizing in 2D and 1D space.
# `plot_track_graph` plots in 2D to make sure our layout makes sense.

fig, ax = plt.subplots(1, 1, figsize=(10, 10))
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.plot(
position_info.head_position_x,
position_info.head_position_y,
position_info.position_x,
position_info.position_y,
color="lightgrey",
alpha=0.7,
zorder=0,
Expand All @@ -200,7 +208,7 @@

# `plot_track_graph_as_1D` shows what this looks like in 1D.

fig, ax = plt.subplots(1, 1, figsize=(20, 1))
fig, ax = plt.subplots(1, 1, figsize=(10, 1))
graph.plot_track_graph_as_1D(ax=ax)

# ## Parameters
Expand All @@ -216,58 +224,62 @@
# at intersections or the head position swings closer to a non-target reward well
# while on another edge.

sgc.common_position.LinearizationParameters.insert1(
sgpl.LinearizationParameters.insert1(
{"linearization_param_name": "default"}, skip_duplicates=True
)
sgc.common_position.LinearizationParameters()
sgpl.LinearizationParameters()

# ## Linearization

# With linearization parameters, we specify the position interval we wish to
# linearize.
# linearize from the `PositionOutput` table and create an entry in `LinearizationSelection`

sgc.Session & {"nwb_file_name": nwb_copy_file_name}

# +
sgc.common_position.IntervalLinearizationSelection.insert1(
sgpl.LinearizationSelection.insert1(
{
"position_info_param_name": "default",
"nwb_file_name": nwb_copy_file_name,
"interval_list_name": "pos 1 valid times",
"pos_merge_id": merge_id,
"track_graph_name": "6 arm",
"linearization_param_name": "default",
},
skip_duplicates=True,
)

sgc.common_position.IntervalLinearizationSelection()
sgpl.LinearizationSelection()
# -

# And then run linearization by populating `IntervalLinearizedPosition`.
# And then run linearization by populating `LinearizedPositionV1`.

sgc.common_position.IntervalLinearizedPosition().populate()
sgc.common_position.IntervalLinearizedPosition()
sgpl.LinearizedPositionV1().populate()
sgpl.LinearizedPositionV1()

# ## Examine data
#

# Populating `LinearizedPositionV1` also creates a corresponding entry in the `LinearizedPositionOutput` merge table. For downstream compatibility with alternate versions of the Linearization pipeline, we should fetch our data from here
#
# Running `fetch1_dataframe` will retrieve the linear position data, including...
#
# - `time`: dataframe index
# - `linear_position`: 1D linearized position
# - `track_segment_id`: index number of the edges given to track graph
# - `projected_{x,y}_position`: 2D position projected to the track graph

linear_position_df = (
IntervalLinearizedPosition()
& {
"position_info_param_name": "default",
"nwb_file_name": nwb_copy_file_name,
"interval_list_name": "pos 1 valid times",
"track_graph_name": "6 arm",
"linearization_param_name": "default",
}
).fetch1_dataframe()
# +
linear_key = {
"pos_merge_id": merge_id,
"track_graph_name": "6 arm",
"linearization_param_name": "default",
}

from spyglass.position_linearization import LinearizedPositionOutput

linear_position_df = (LinearizedPositionOutput & linear_key).fetch1_dataframe()
linear_position_df

# -

# We'll plot the 1D position over time, colored by edge, and use the 1D track
# graph layout on the y-axis.

Expand All @@ -292,8 +304,8 @@

fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.plot(
position_info.head_position_x,
position_info.head_position_y,
position_info.position_x,
position_info.position_y,
color="lightgrey",
alpha=0.7,
zorder=0,
Expand All @@ -304,73 +316,3 @@
linear_position_df.projected_x_position,
linear_position_df.projected_y_position,
)

# ## Interactive selection
#

#
# _Note:_ Work in Progress
#
# ### `NodePicker`
#
# Linearization heavily depends on the track graph is specified. To help simplify
# setting the nodes/edges, we can use the `NodePicker` to interactively set node
# positions and edges based on the video.

# +
# %matplotlib widget

key = {
"nwb_file_name": nwb_copy_file_name,
"interval_list_name": "pos 1 valid times",
}

epoch = (
int(
key["interval_list_name"]
.replace("pos ", "")
.replace(" valid times", "")
)
+ 1
)
video_info = (
sgc.common_behav.VideoFile()
& {"nwb_file_name": key["nwb_file_name"], "epoch": epoch}
).fetch1()

io = pynwb.NWBHDF5IO("/stelmo/nwb/raw/" + video_info["nwb_file_name"], "r")
nwb_file = io.read()
nwb_video = nwb_file.objects[video_info["video_file_object_id"]]
video_filename = nwb_video.external_file.value[0]

fig, ax = plt.subplots(figsize=(8, 8))
picker = sgc.common_position.NodePicker(ax=ax, video_filename=video_filename)
# -

# After selection, we can retrieve the data using the `node_positions` and `edges`
# attributes

picker.node_positions

picker.edges

# ### Selector
#
# We can also draw a 2d polygon around the track and attempt to recover the graph.

# +
# %matplotlib widget

fig, ax = plt.subplots(figsize=(8, 8))
selector = sgc.common_position.SelectFromCollection(ax, video_filename)

print("Select points in the figure by enclosing them within a polygon.")
print("Press the 'esc' key to start a new polygon.")
print("Try holding the 'shift' key to move all of the vertices.")
print("Try holding the 'ctrl' key to move a single vertex.")
# -

# ## Up Next
#
# Next, we'll combine ephys and position data in a process called
# [ripple detection](./30_Ripple_Detection.ipynb).
13 changes: 8 additions & 5 deletions src/spyglass/position_linearization/v1/linearization.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import copy

import datajoint as dj
from datajoint.utils import to_camel_case
import numpy as np
from datajoint.utils import to_camel_case
from track_linearization import (
get_linearized_position,
make_track_graph,
Expand Down Expand Up @@ -94,7 +95,7 @@ def plot_track_graph_as_1D(
@schema
class LinearizationSelection(dj.Lookup):
definition = """
-> PositionOutput
-> PositionOutput.proj(pos_merge_id='merge_id')
-> TrackGraph
-> LinearizationParameters
---
Expand All @@ -116,9 +117,11 @@ def make(self, key):
orig_key = copy.deepcopy(key)
print(f"Computing linear position for: {key}")

position_nwb = PositionOutput.fetch_nwb(key)[0]
position_nwb = PositionOutput.fetch_nwb(
{"merge_id": key["pos_merge_id"]}
)[0]
key["analysis_file_name"] = AnalysisNwbfile().create(
key["nwb_file_name"]
position_nwb["nwb_file_name"]
)
position = np.asarray(
position_nwb["position"].get_spatial_series().data
Expand Down Expand Up @@ -164,7 +167,7 @@ def make(self, key):
)

nwb_analysis_file.add(
nwb_file_name=key["nwb_file_name"],
nwb_file_name=position_nwb["nwb_file_name"],
analysis_file_name=key["analysis_file_name"],
)

Expand Down

0 comments on commit 4fa761a

Please sign in to comment.