Skip to content

Commit

Permalink
Merge pull request #2 from YosefLab/labels
Browse files Browse the repository at this point in the history
text as label
  • Loading branch information
colganwi authored May 16, 2024
2 parents 138835a + 2d2d2f6 commit 6f3aa23
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 33 deletions.
57 changes: 35 additions & 22 deletions src/pycea/pl/plot_tree.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from collections.abc import Mapping, Sequence

import cycler
Expand Down Expand Up @@ -72,6 +73,12 @@ def branches(
-------
ax - The axes that the plot was drawn on.
""" # noqa: D205
# Setup
if not ax:
ax = plt.gca()
if (ax.name == "polar" and not polar) or (ax.name != "polar" and polar):
warnings.warn("Polar setting of axes does not match requested type. Creating new axes.", stacklevel=2)
fig, ax = plt.subplots(subplot_kw={"projection": "polar"} if polar else None)
kwargs = kwargs if kwargs else {}
if not key:
key = next(iter(tdata.obst.keys()))
Expand Down Expand Up @@ -116,18 +123,20 @@ def branches(
else:
raise ValueError("Invalid linewidth value. Must be int, float, or an str specifying an attribute of the edges.")
# Plot
if not ax:
subplot_kw = {"projection": "polar"} if polar else None
fig, ax = plt.subplots(subplot_kw=subplot_kw)
elif (ax.name == "polar") != polar:
raise ValueError("Provided axis does not match the requested 'polar' setting.")
ax.add_collection(LineCollection(zorder=1, **kwargs))
# Configure plot
lat_lim = (-0.2, depth)
lon_lim = (0, 2 * np.pi)
ax.set_xlim(lon_lim if polar else lat_lim)
ax.set_ylim(lat_lim if polar else lon_lim)
ax.axis("off")
if polar:
ax.set_ylim((-depth * 0.05, depth * 1.05))
ax.spines["polar"].set_visible(False)
else:
ax.set_ylim((-0.03 * np.pi, 2.03 * np.pi))
ax.set_xlim((-depth * 0.05, depth * 1.05))
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.spines["left"].set_visible(False)
ax.spines["bottom"].set_visible(False)
ax.tick_params(length=0)
ax.set_xticks([])
ax.set_yticks([])
ax._attrs = {
"node_coords": node_coords,
"leaves": leaves,
Expand Down Expand Up @@ -388,19 +397,23 @@ def annotation(
# Plot
if attrs["polar"]:
ax.pcolormesh(lons, lats, rgb_array.swapaxes(0, 1), zorder=2, **kwargs)
ax.set_ylim(-0.2, end_lat)
ax.set_ylim(-attrs["depth"] * 0.05, end_lat)
else:
ax.pcolormesh(lats, lons, rgb_array, zorder=2, **kwargs)
ax.set_xlim(-0.2, end_lat)
labels_lats = np.linspace(start_lat, end_lat, len(labels) + 1)
labels_lats = labels_lats + (end_lat - start_lat) / (len(labels) * 2)
for idx, label in enumerate(labels):
if is_array and len(labels) == 1:
ax.text(labels_lats[idx], -0.1, label, ha="center", va="top")
ax.set_ylim(-0.5, 2 * np.pi)
else:
ax.text(labels_lats[idx], -0.1, label, ha="center", va="top", rotation=90)
ax.set_ylim(-1, 2 * np.pi)
ax.set_xlim(-attrs["depth"] * 0.05, end_lat)
# Add labels
if labels and len(labels) > 0:
labels_lats = np.linspace(start_lat, end_lat, len(labels) + 1)
labels_lats = labels_lats + (end_lat - start_lat) / (len(labels) * 2)
existing_ticks = ax.get_xticks()
existing_labels = [label.get_text() for label in ax.get_xticklabels()]
ax.set_xticks(np.append(existing_ticks, labels_lats[:-1]))
ax.set_xticklabels(existing_labels + labels)
for label in ax.get_xticklabels()[len(existing_ticks) :]:
if is_array and len(labels) == 1:
label.set_rotation(0)
else:
label.set_rotation(90)
ax._attrs.update({"offset": end_lat})
return ax

Expand Down
21 changes: 10 additions & 11 deletions tests/test_plot_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def test_polar_with_clades(tdata):
fig, ax = plt.subplots(dpi=600, subplot_kw={"polar": True})
fig, ax = plt.subplots(dpi=300, subplot_kw={"polar": True})
pycea.pl.branches(tdata, key="tree", polar=True, color="clade", palette="Set1", na_color="black", ax=ax)
pycea.pl.nodes(tdata, color="clade", palette="Set1", style="clade", ax=ax)
pycea.pl.annotation(tdata, keys="clade", ax=ax)
Expand All @@ -18,19 +18,18 @@ def test_polar_with_clades(tdata):


def test_angled_numeric_annotations(tdata):
fig, ax = plt.subplots(dpi=600)
pycea.pl.branches(
tdata, key="tree", polar=False, color="length", cmap="hsv", linewidth="length", angled_branches=True, ax=ax
tdata, key="tree", polar=False, color="length", cmap="hsv", linewidth="length", angled_branches=True
)
pycea.pl.nodes(tdata, nodes="all", color="time", style="s", size=20, ax=ax)
pycea.pl.annotation(tdata, keys=["x", "y"], cmap="magma", width=0.1, gap=0.05, ax=ax)
pycea.pl.annotation(tdata, keys=["0", "1", "2", "3", "4", "5"], label="genes", ax=ax)
pycea.pl.nodes(tdata, nodes="all", color="time", style="s", size=20)
pycea.pl.annotation(tdata, keys=["x", "y"], cmap="magma", width=0.1, gap=0.05)
pycea.pl.annotation(tdata, keys=["0", "1", "2", "3", "4", "5"], label="genes")
plt.savefig(plot_path / "angled_numeric.png")
plt.close()


def test_matrix_annotation(tdata):
fig, ax = plt.subplots(dpi=600)
fig, ax = plt.subplots(dpi=300)
pycea.pl.tree(
tdata,
key="tree",
Expand All @@ -44,19 +43,19 @@ def test_matrix_annotation(tdata):
plt.close()


def test_branches_invalid_input(tdata):
def test_branches_bad_input(tdata):
fig, ax = plt.subplots()
with pytest.raises(ValueError):
pycea.pl.branches(tdata, key="tree", color=["bad"] * 5)
with pytest.raises(ValueError):
pycea.pl.branches(tdata, key="tree", linewidth=["bad"] * 5)
# Can't plot polar with non-polar axis
with pytest.raises(ValueError):
# Warns about polar
with pytest.warns(match="Polar"):
pycea.pl.branches(tdata, key="tree", polar=True, ax=ax)
plt.close()


def test_annotation_invalid_input(tdata):
def test_annotation_bad_input(tdata):
# Need to plot branches first
fig, ax = plt.subplots()
with pytest.raises(ValueError):
Expand Down

0 comments on commit 6f3aa23

Please sign in to comment.