Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

datashader speedup and bugfixes #309

Merged
merged 34 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
50808b3
speed up datashader by using canvas size equal to image size
Jul 17, 2024
759ecd8
_cax definition bugfix
Jul 17, 2024
ea718bf
attempt mypy error fixes
Jul 17, 2024
eef7b8b
delete comments
Jul 17, 2024
702a226
continuous agg with mean instead of sum and use linear cmap instead o…
Jul 19, 2024
420772a
Merge branch 'main' into feature/296-datashader-canvas-size
LucaMarconato Jul 22, 2024
48c1c52
switch back to sum() and move duplicate datashader code into private …
Jul 23, 2024
fc89462
make ds reduction a kwarg and fix ds colorbar limits for continuous c…
Jul 30, 2024
a0dac08
adapt ds px using dpi/100
Jul 31, 2024
c8b0b34
spread how kw refactoring
Aug 6, 2024
e1662e8
add tests and minor adaptations
Sep 10, 2024
d3a4c14
Merge branch 'main' into feature/296-datashader-canvas-size
Sonja-Stockhaus Sep 10, 2024
f0074c9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 10, 2024
6dfaf51
fix merge conflict resolution
Sep 10, 2024
3d890a6
update test images
Sep 10, 2024
a2b66e1
update changelog
Sep 10, 2024
8b0b24d
minor refactor
timtreis Sep 14, 2024
982c627
added tests for remaining methods
timtreis Sep 14, 2024
0d124d7
added images from runner
timtreis Sep 14, 2024
ceb4fd2
remove m2 reduction, fix datashader coloring shapes by given color
Sep 18, 2024
59a19da
add ds shapes outlines, fix shapes fill_alpha behavior, fix std/var/a…
Oct 1, 2024
22cdabc
add test images
Oct 1, 2024
1d07544
update test images
Oct 1, 2024
febd424
changelog update and documentation
Oct 2, 2024
2a95236
Merge branch 'main' into feature/296-datashader-canvas-size
Sonja-Stockhaus Oct 11, 2024
1d871ed
fix resolved merge conflict
Oct 11, 2024
0bf8c35
fix resolved merge conflict
Oct 11, 2024
4391b81
Wouter's feedback
Oct 14, 2024
a06400b
fix coordinate system
Oct 14, 2024
4b52fe7
passing Normalize with datashader has an effect now
Oct 16, 2024
cd0f68b
fix norm behavior
Oct 16, 2024
214cb9d
Tim's feedback
Oct 16, 2024
ddc7927
fix test
Oct 16, 2024
73568ec
add test image
Oct 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 85 additions & 19 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import matplotlib
import matplotlib.transforms as mtransforms
import numpy as np
import numpy.ma as ma
import pandas as pd
import scanpy as sc
import spatialdata as sd
Expand All @@ -20,11 +21,12 @@
from matplotlib.colors import ListedColormap, Normalize
from scanpy._settings import settings as sc_settings
from spatialdata import get_extent
from spatialdata.models import PointsModel, get_table_keys
from spatialdata.models import Image2DModel, PointsModel, get_table_keys
from spatialdata.transformations import (
get_transformation,
set_transformation,
)
from spatialdata.transformations.transformations import Scale

from spatialdata_plot._logging import logger
from spatialdata_plot.pl.render_params import (
Expand Down Expand Up @@ -164,16 +166,29 @@ def _render_shapes(
if method == "datashader":
trans = mtransforms.Affine2D(matrix=affine_trans) + ax.transData
Sonja-Stockhaus marked this conversation as resolved.
Show resolved Hide resolved

extent = get_extent(sdata.shapes[element])
x_ext = extent["x"][1]
y_ext = extent["y"][1]
x_range = [0, x_ext]
y_range = [0, y_ext]
# round because we need integers
plot_width = int(np.round(x_range[1] - x_range[0]))
plot_height = int(np.round(y_range[1] - y_range[0]))
extent = get_extent(sdata_filt.shapes[element], coordinate_system=coordinate_system)
x_ext = [min(0, extent["x"][0]), extent["x"][1]]
y_ext = [min(0, extent["y"][0]), extent["y"][1]]
previous_xlim = ax.get_xlim()
previous_ylim = ax.get_ylim()
# increase range if sth larger was rendered before
if _mpl_ax_contains_elements(ax):
x_ext = [min(x_ext[0], previous_xlim[0]), max(x_ext[1], previous_xlim[1])]
if ax.yaxis_inverted(): # case for e.g. images
y_ext = [min(y_ext[0], previous_ylim[1]), max(y_ext[1], previous_ylim[0])]
else: # case for e.g. labels
y_ext = [min(y_ext[0], previous_ylim[0]), max(y_ext[1], previous_ylim[1])]

# compute canvas size in pixels close to the actual image size to speed up computation
plot_width = x_ext[1] - x_ext[0]
plot_height = y_ext[1] - y_ext[0]
plot_width_px = int(round(fig_params.fig.get_size_inches()[0] * fig_params.fig.dpi))
plot_height_px = int(round(fig_params.fig.get_size_inches()[1] * fig_params.fig.dpi))
factor = np.min([plot_width / plot_width_px, plot_height / plot_height_px])
plot_width = int(np.round(plot_width / factor))
plot_height = int(np.round(plot_height / factor))

cvs = ds.Canvas(plot_width=plot_width, plot_height=plot_height, x_range=x_range, y_range=y_range)
cvs = ds.Canvas(plot_width=plot_width, plot_height=plot_height, x_range=x_ext, y_range=y_ext)

_geometry = shapes["geometry"]
is_point = _geometry.type == "Point"
Expand Down Expand Up @@ -223,16 +238,44 @@ def _render_shapes(
cmap=render_params.cmap_params.cmap,
)
)
rgba_image = np.transpose(ds_result.to_numpy().base, (0, 1, 2))
_cax = ax.imshow(rgba_image, cmap=palette, zorder=render_params.zorder)
_cax.set_transform(trans)
cax = ax.add_image(_cax)

# create SpatialImage to get it back to original size
rgba_image = np.transpose(ds_result.to_numpy().base, (2, 0, 1))
rgba_image = Image2DModel.parse(
rgba_image,
dims=("c", "y", "x"),
transformations={"global": Scale([1, factor, factor], ("c", "y", "x"))},
)

# prepare transformation
trans = get_transformation(rgba_image, get_all=True)["global"]
affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))
trans = mtransforms.Affine2D(matrix=affine_trans)
trans_data = trans + ax.transData

rgba_image = np.transpose(rgba_image.data.compute(), (1, 2, 0)) # type: ignore[attr-defined]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here (and in render_shapes), we access the image as numpy array from the SpatialImage. mypy doesn't believe that compute() exists...

rgba_image = ma.masked_array(rgba_image) # type conversion for mypy
_cax = _ax_show_and_transform(
rgba_image, trans_data, ax, zorder=render_params.zorder, alpha=render_params.fill_alpha
)

cax = None
if aggregate_with_sum is not None:
cax = ScalarMappable(
norm=matplotlib.colors.Normalize(vmin=aggregate_with_sum[0], vmax=aggregate_with_sum[1]),
cmap=render_params.cmap_params.cmap,
)

# rgba_image = np.transpose(ds_result.to_numpy().base, (0, 1, 2))
# _cax = ax.imshow(rgba_image, cmap=palette, zorder=render_params.zorder)
# _cax.set_transform(trans)
# cax = ax.add_image(_cax)
# if aggregate_with_sum is not None:
# cax = ScalarMappable(
# norm=matplotlib.colors.Normalize(vmin=aggregate_with_sum[0], vmax=aggregate_with_sum[1]),
# cmap=render_params.cmap_params.cmap,
# )

elif method == "matplotlib":
_cax = _get_collection_shape(
shapes=shapes,
Expand Down Expand Up @@ -416,9 +459,15 @@ def _render_points(
y_ext = [min(y_ext[0], previous_ylim[1]), max(y_ext[1], previous_ylim[0])]
else: # case for e.g. labels
y_ext = [min(y_ext[0], previous_ylim[0]), max(y_ext[1], previous_ylim[1])]
# round because we need integers
plot_width = int(np.round(x_ext[1] - x_ext[0]))
plot_height = int(np.round(y_ext[1] - y_ext[0]))

# compute canvas size in pixels close to the actual image size to speed up computation
plot_width = x_ext[1] - x_ext[0]
plot_height = y_ext[1] - y_ext[0]
plot_width_px = int(round(fig_params.fig.get_size_inches()[0] * fig_params.fig.dpi))
plot_height_px = int(round(fig_params.fig.get_size_inches()[1] * fig_params.fig.dpi))
factor = np.min([plot_width / plot_width_px, plot_height / plot_height_px])
plot_width = int(np.round(plot_width / factor))
plot_height = int(np.round(plot_height / factor))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd consider bundling this code in a private function since it's duplicate from above.


# use datashader for the visualization of points
cvs = ds.Canvas(plot_width=plot_width, plot_height=plot_height, x_range=x_ext, y_range=y_ext)
Expand Down Expand Up @@ -456,8 +505,25 @@ def _render_points(
cmap=render_params.cmap_params.cmap,
)

rbga_image = np.transpose(ds_result.to_numpy().base, (0, 1, 2))
cax = ax.imshow(rbga_image, zorder=render_params.zorder, alpha=render_params.alpha)
# create SpatialImage to get it back to original size
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also this code is a duplicate, I'd consider it refactoring it into a private function.

rgba_image = np.transpose(ds_result.to_numpy().base, (2, 0, 1))
rgba_image = Image2DModel.parse(
rgba_image,
dims=("c", "y", "x"),
transformations={"global": Scale([1, factor, factor], ("c", "y", "x"))},
)

# prepare transformation
trans = get_transformation(rgba_image, get_all=True)["global"]
affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))
trans = mtransforms.Affine2D(matrix=affine_trans)
trans_data = trans + ax.transData

rgba_image = np.transpose(rgba_image.data.compute(), (1, 2, 0)) # type: ignore[attr-defined]
rgba_image = ma.masked_array(rgba_image) # type conversion for mypy
_ax_show_and_transform(rgba_image, trans_data, ax, zorder=render_params.zorder, alpha=render_params.alpha)

cax = None
if aggregate_with_sum is not None:
cax = ScalarMappable(
norm=matplotlib.colors.Normalize(vmin=aggregate_with_sum[0], vmax=aggregate_with_sum[1]),
Expand Down
3 changes: 2 additions & 1 deletion src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1963,7 +1963,7 @@ def _ax_show_and_transform(
alpha: float | None = None,
cmap: ListedColormap | LinearSegmentedColormap | None = None,
zorder: int = 0,
) -> None:
) -> matplotlib.image.AxesImage:
if not cmap and alpha is not None:
im = ax.imshow(
array,
Expand All @@ -1978,6 +1978,7 @@ def _ax_show_and_transform(
zorder=zorder,
)
im.set_transform(trans_data)
return im


def set_zero_in_cmap_to_transparent(cmap: Colormap | str, steps: int | None = None) -> ListedColormap:
Expand Down
Loading