Skip to content

Commit

Permalink
Merge pull request #3 from YosefLab/plotting-update
Browse files Browse the repository at this point in the history
annotation borders
  • Loading branch information
colganwi authored May 20, 2024
2 parents 6f3aa23 + a6415d0 commit e94521e
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 18 deletions.
33 changes: 24 additions & 9 deletions src/pycea/pl/plot_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
)
def branches(
tdata: td.TreeData,
key: str = None,
keys: str | Sequence[str] = None,
polar: bool = False,
extend_branches: bool = False,
angled_branches: bool = False,
Expand All @@ -49,8 +49,8 @@ def branches(
----------
tdata
The `treedata.TreeData` object.
key
The `obst` key of the tree to plot.
keys
The `obst` key or keys of the trees to plot.
polar
Whether to plot the tree in polar coordinates.
extend_branches
Expand Down Expand Up @@ -80,8 +80,12 @@ def branches(
warnings.warn("Polar setting of axes does not match requested type. Creating new axes.", stacklevel=2)
fig, ax = plt.subplots(subplot_kw={"projection": "polar"} if polar else None)
kwargs = kwargs if kwargs else {}
if not key:
if not keys:
key = next(iter(tdata.obst.keys()))
elif isinstance(keys, str):
key = keys
else:
raise ValueError("Passing a list of keys not implemented. Please pass a single key.")
tree = tdata.obst[key]
# Get layout
node_coords, branch_coords, leaves, depth = layout_tree(
Expand Down Expand Up @@ -308,6 +312,7 @@ def annotation(
width: int | float = 0.05,
gap: int | float = 0.03,
label: bool | str | Sequence[str] = True,
border_width: int | float = 0,
cmap: str | mcolors.Colormap = None,
palette: cycler.Cycler | mcolors.ListedColormap | Sequence[str] | Mapping[str] | None = None,
vmax: int | float | None = None,
Expand All @@ -332,6 +337,8 @@ def annotation(
label
Annotation labels. If `True`, the keys are used as labels.
If a string or a sequence of strings, the strings are used as labels.
border_width
The width of the border around the annotation bar.
{common_plot_args}
na_color
The color to use for annotations with missing data.
Expand Down Expand Up @@ -398,9 +405,17 @@ def annotation(
if attrs["polar"]:
ax.pcolormesh(lons, lats, rgb_array.swapaxes(0, 1), zorder=2, **kwargs)
ax.set_ylim(-attrs["depth"] * 0.05, end_lat)
# ax.plot([0, np.pi, np.pi, 0, 0], [start_lat, start_lat, end_lat, end_lat, start_lat], color="black")
else:
ax.pcolormesh(lats, lons, rgb_array, zorder=2, **kwargs)
ax.set_xlim(-attrs["depth"] * 0.05, end_lat)
ax.set_xlim(-attrs["depth"] * 0.05, end_lat + attrs["depth"] * 0.05)
# Add border
ax.plot(
[lats[0], lats[0], lats[-1], lats[-1], lats[0]],
[lons[0], lons[-1], lons[-1], lons[0], lons[0]],
color="black",
linewidth=border_width,
)
# Add labels
if labels and len(labels) > 0:
labels_lats = np.linspace(start_lat, end_lat, len(labels) + 1)
Expand All @@ -427,7 +442,7 @@ def annotation(
)
def tree(
tdata: td.TreeData,
key: str = None,
keys: str | Sequence[str] = None,
nodes: str | Sequence[str] = None,
annotation_keys: str | Sequence[str] = None,
polar: bool = False,
Expand All @@ -451,8 +466,8 @@ def tree(
----------
tdata
The TreeData object.
key
The `obst` key of the tree to plot.
keys
The `obst` key or keys of the trees to plot.
nodes
Either "all", "leaves", "internal", or a list of nodes to plot.
annotation_keys
Expand Down Expand Up @@ -484,7 +499,7 @@ def tree(
# Plot branches
ax = _branches(
tdata,
key=key,
keys=keys,
polar=polar,
extend_branches=extend_branches,
angled_branches=angled_branches,
Expand Down
18 changes: 9 additions & 9 deletions tests/test_plot_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

def test_polar_with_clades(tdata):
fig, ax = plt.subplots(dpi=300, subplot_kw={"polar": True})
pycea.pl.branches(tdata, key="tree", polar=True, color="clade", palette="Set1", na_color="black", ax=ax)
pycea.pl.branches(tdata, keys="tree", polar=True, color="clade", palette="Set1", na_color="black", ax=ax)
pycea.pl.nodes(tdata, color="clade", palette="Set1", style="clade", ax=ax)
pycea.pl.annotation(tdata, keys="clade", ax=ax)
plt.savefig(plot_path / "polar_clades.png")
Expand All @@ -19,11 +19,11 @@ def test_polar_with_clades(tdata):

def test_angled_numeric_annotations(tdata):
pycea.pl.branches(
tdata, key="tree", polar=False, color="length", cmap="hsv", linewidth="length", angled_branches=True
tdata, keys="tree", polar=False, color="length", cmap="hsv", linewidth="length", angled_branches=True
)
pycea.pl.nodes(tdata, nodes="all", color="time", style="s", size=20)
pycea.pl.annotation(tdata, keys=["x", "y"], cmap="magma", width=0.1, gap=0.05)
pycea.pl.annotation(tdata, keys=["0", "1", "2", "3", "4", "5"], label="genes")
pycea.pl.annotation(tdata, keys=["x", "y"], cmap="magma", width=0.1, gap=0.05, border_width=2)
pycea.pl.annotation(tdata, keys=["0", "1", "2", "3", "4", "5"], label="genes", border_width=2)
plt.savefig(plot_path / "angled_numeric.png")
plt.close()

Expand All @@ -32,7 +32,7 @@ def test_matrix_annotation(tdata):
fig, ax = plt.subplots(dpi=300)
pycea.pl.tree(
tdata,
key="tree",
keys="tree",
nodes="internal",
node_color="clade",
node_size="time",
Expand All @@ -46,12 +46,12 @@ def test_matrix_annotation(tdata):
def test_branches_bad_input(tdata):
fig, ax = plt.subplots()
with pytest.raises(ValueError):
pycea.pl.branches(tdata, key="tree", color=["bad"] * 5)
pycea.pl.branches(tdata, keys="tree", color=["bad"] * 5)
with pytest.raises(ValueError):
pycea.pl.branches(tdata, key="tree", linewidth=["bad"] * 5)
pycea.pl.branches(tdata, keys="tree", linewidth=["bad"] * 5)
# Warns about polar
with pytest.warns(match="Polar"):
pycea.pl.branches(tdata, key="tree", polar=True, ax=ax)
pycea.pl.branches(tdata, keys="tree", polar=True, ax=ax)
plt.close()


Expand All @@ -60,7 +60,7 @@ def test_annotation_bad_input(tdata):
fig, ax = plt.subplots()
with pytest.raises(ValueError):
pycea.pl.annotation(tdata, keys="clade")
pycea.pl.branches(tdata, key="tree", ax=ax)
pycea.pl.branches(tdata, keys="tree", ax=ax)
with pytest.raises(ValueError):
pycea.pl.annotation(tdata, keys=None, ax=ax)
with pytest.raises(ValueError):
Expand Down

0 comments on commit e94521e

Please sign in to comment.