Skip to content

Commit

Permalink
added bokeh implementation of view_quilt
Browse files Browse the repository at this point in the history
  • Loading branch information
salehtahini committed Jun 17, 2024
1 parent 398c64f commit 860e8d8
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 9 deletions.
116 changes: 116 additions & 0 deletions caiman/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -1362,3 +1362,119 @@ def view_quilt(template_image: np.ndarray,
ax=ax)

return ax


def create_quilt_patches(patch_rows, patch_cols):
"""
Helper function for nb_view_quilt.
Create patches given the row and column coordinates.
Args:
patch_rows (ndarray): Array of row start and end positions for each patch.
patch_cols (ndarray): Array of column start and end positions for each patch.
Returns:
list: A list of dictionaries, each containing the center coordinates, width,
and height of a patch.
"""
patches = []
for row in patch_rows:
for col in patch_cols:
center_x = (col[0] + col[1]) / 2
center_y = (row[0] + row[1]) / 2
width = col[1] - col[0]
height = row[1] - row[0]
patches.append({'center_x': center_x, 'center_y': center_y, 'width': width, 'height': height})
return patches


def nb_view_quilt(template_image: np.ndarray,
rf: int,
stride_input: int,
color: Optional[Any]='white',
alpha: Optional[float]=0.2):
"""
Bokeh implementation of view_quilt.
Plot patches on the template image given stride and overlap parameters.
Args:
template_image (ndarray): Row x column summary image upon which to draw patches (e.g., correlation image).
rf (int): Half-size of the patches in pixels (patch width is rf*2 + 1).
stride_input (int): Amount of overlap between the patches in pixels (overlap is stride_input + 1).
color (Optional[Any]): Color of the patches, default 'white'.
alpha (Optional[float]): Patch transparency, default 0.2.
"""

width = (rf*2)+1
overlap = stride_input+1
stride = width-overlap

im_dims = template_image.shape
patch_rows, patch_cols = get_rectangle_coords(im_dims, stride, overlap)
patches = create_quilt_patches(patch_rows, patch_cols)

plot = bpl.figure(x_range=(0, im_dims[1]), y_range=(im_dims[0], 0), width=600, height=600)
#plot.y_range.flipped = True
plot.image(image=[template_image], x=0, y=0, dw=im_dims[1], dh=im_dims[0], palette="Greys256")
source = ColumnDataSource(data=dict(
center_x=[patch['center_x'] for patch in patches],
center_y=[patch['center_y'] for patch in patches],
width=[patch['width'] for patch in patches],
height=[patch['height'] for patch in patches]
))
plot.rect(x='center_x', y='center_y', width='width', height='height', source=source, color=color, alpha=alpha)

# Create sliders
stride_slider = Slider(start=1, end=100, value=rf, step=1, title="Patch half-size (rf)")
overlap_slider = Slider(start=0, end=100, value=stride_input, step=1, title="Overlap (stride)")

callback = CustomJS(args=dict(source=source, im_dims=im_dims, stride_slider=stride_slider, overlap_slider=overlap_slider), code="""
function get_rectangle_coords(im_dims, stride, overlap) {
let patch_width = overlap + stride;
let patch_onset_rows = Array.from({length: Math.ceil((im_dims[0] - patch_width) / stride)}, (_, i) => i * stride).concat([im_dims[0] - patch_width]);
let patch_offset_rows = patch_onset_rows.map(x => Math.min(x + patch_width, im_dims[0] - 1));
let patch_rows = patch_onset_rows.map((x, i) => [x, patch_offset_rows[i]]);
let patch_onset_cols = Array.from({length: Math.ceil((im_dims[1] - patch_width) / stride)}, (_, i) => i * stride).concat([im_dims[1] - patch_width]);
let patch_offset_cols = patch_onset_cols.map(x => Math.min(x + patch_width, im_dims[1] - 1));
let patch_cols = patch_onset_cols.map((x, i) => [x, patch_offset_cols[i]]);
return [patch_rows, patch_cols];
}
function create_quilt_patches(patch_rows, patch_cols) {
let patches = [];
for (let row of patch_rows) {
for (let col of patch_cols) {
let center_x = (col[0] + col[1]) / 2;
let center_y = (row[0] + row[1]) / 2;
let width = col[1] - col[0];
let height = row[1] - row[0];
patches.push({'center_x': center_x, 'center_y': center_y, 'width': width, 'height': height});
}
}
return patches;
}
let width = (stride_slider.value * 2) + 1;
let overlap = overlap_slider.value + 1;
let stride = width - overlap
let [patch_rows, patch_cols] = get_rectangle_coords(im_dims, stride, overlap);
let patches = create_quilt_patches(patch_rows, patch_cols);
source.data = {
center_x: patches.map(patch => patch.center_x),
center_y: patches.map(patch => patch.center_y),
width: patches.map(patch => patch.width),
height: patches.map(patch => patch.height)
};
source.change.emit();
""")

stride_slider.js_on_change('value', callback)
overlap_slider.js_on_change('value', callback)


bpl.show(bokeh.layouts.row(plot, bokeh.layouts.column(stride_slider, overlap_slider)))
14 changes: 5 additions & 9 deletions demos/notebooks/demo_pipeline.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
"from caiman.source_extraction.cnmf import cnmf, params\n",
"from caiman.utils.utils import download_demo\n",
"from caiman.utils.visualization import plot_contours, nb_view_patches, nb_plot_contour\n",
"from caiman.utils.visualization import view_quilt\n",
"from caiman.utils.visualization import nb_view_quilt\n",
"\n",
"bpl.output_notebook()\n",
"hv.notebook_extension('bokeh')"
Expand Down Expand Up @@ -741,7 +741,7 @@
"metadata": {},
"source": [
"### Selecting spatial parameters\n",
"To select the spatial parameters (`gSig`, `rf`, `stride`, `K`), you need to look at your movie, or a summary image for your movie, and pick values close to those suggested by the guidelines above. It is helpful to use `view_quilt()` function to see if our key spatial parameters are in the right ballpark (note we recommend running this viewer in interactive qt mode so you can interact with it and get a better feel for the parameters):"
"To select the spatial parameters (`gSig`, `rf`, `stride`, `K`), you need to look at your movie, or a summary image for your movie, and pick values close to those suggested by the guidelines above. It is helpful to use the interactive `nb_view_quilt()` function or `view_quilt()` function to see if our key spatial parameters are in the right ballpark (you can use the sliders in the interactive version to change the `rf` and `stride` parameters and get a better feel for them):"
]
},
{
Expand All @@ -757,13 +757,9 @@
"print(f'Patch width: {cnmf_patch_width} , Stride: {cnmf_patch_stride}, Overlap: {cnmf_patch_overlap}');\n",
"\n",
"# plot the patches\n",
"patch_ax = view_quilt(correlation_image, \n",
" cnmf_patch_stride, \n",
" cnmf_patch_overlap, \n",
" vmin=np.percentile(np.ravel(correlation_image),50), \n",
" vmax=np.percentile(np.ravel(correlation_image),99.5),\n",
" figsize=(4,4));\n",
"patch_ax.set_title(f'CNMF Patches Width {cnmf_patch_width}, Overlap {cnmf_patch_overlap}');"
"patch_ax = nb_view_quilt(correlation_image, \n",
" cnmf_model.params.patch['rf'], \n",
" cnmf_model.params.patch['stride']);"
]
},
{
Expand Down

0 comments on commit 860e8d8

Please sign in to comment.