Skip to content

Commit

Permalink
REFACTOR: plot with Biopython to replace R legacy
Browse files Browse the repository at this point in the history
  • Loading branch information
tcztzy committed Feb 21, 2024
1 parent 76cc8ea commit e9a11d8
Show file tree
Hide file tree
Showing 12 changed files with 186 additions and 52 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,6 @@ __pycache__/

# renv
/renv.lock

# matplotlib
/result_images/
3 changes: 3 additions & 0 deletions argweaver/io/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .bed import read_bed, subset_bed

__all__ = ["read_bed", "subset_bed"]
47 changes: 47 additions & 0 deletions argweaver/io/bed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import gzip
from pathlib import Path

import pandas as pd

from argweaver.utils import parse_region

# Public API of this module; argweaver/io/__init__.py re-exports both names.
__all__ = ["read_bed", "subset_bed"]


def read_bed(filename):
    """Read a bed file.

    A file whose suffix is ``.gz`` is opened with {py:obj}`gzip.open`; any
    other file is opened with the built-in {py:obj}`open`. Both are read in
    text mode.

    Parameters
    ----------
    filename : {py:obj}`str`
        Path to the bed file.

    Returns
    -------
    {py:obj}`pandas.DataFrame`
        The file contents, read without a header row and with the columns
        named ``chrom``, ``start``, ``end``, ``iter`` and ``tree``.
    """
    path = Path(filename)
    _open = gzip.open if path.suffix == ".gz" else open
    # pandas does not close a handle it did not open itself, so close it here
    # instead of leaving it to the garbage collector.
    with _open(path, "rt") as handle:
        return pd.read_table(
            handle, header=None, names=["chrom", "start", "end", "iter", "tree"]
        )


def subset_bed(data, region):
    """Keep the rows of a bed data frame that overlap a region.

    Parameters
    ----------
    data : {py:obj}`pandas.DataFrame`
        Bed records; the columns ``chrom``, ``start`` and ``end`` are used.
    region : {py:obj}`str`
        Subset to keep, in format {chrom}:{start}-{end}

    Returns
    -------
    {py:obj}`pandas.DataFrame`
        The rows of ``data`` whose ``chrom`` equals the region's chromosome
        and whose left-closed interval ``[start, end)`` overlaps the
        left-closed interval of the region.
    """
    chrom, start, end = parse_region(region)
    row_spans = pd.IntervalIndex.from_arrays(
        data["start"], data["end"], closed="left"
    )
    query = pd.Interval(start, end, closed="left")
    same_chrom = data["chrom"] == chrom
    overlapping = row_spans.overlaps(query)
    return data[same_chrom & overlapping]
58 changes: 17 additions & 41 deletions argweaver/plot.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,29 @@
"""Plotting functions for ARGweavers."""
import re
import typing
from io import StringIO

import rpy2.robjects as ro
from rpy2.robjects import pandas2ri

from argweaver.r import plotTreesFromBed
from Bio import Phylo

if typing.TYPE_CHECKING:
from os import PathLike
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Union
pass

__all__ = ["plot_trees"]
__all__ = ["plot_tree"]


def plot_trees(
bedfile: "Union[str, Path, PathLike]",
*,
i: "Union[int, Literal['max']]" = "max",
s: "Optional[str]" = None,
):
"""Plot trees from a bed file.
def plot_tree(newick: str, *, name="", **kwargs):
"""Plot a tree
Parameters
----------
bedfile : {py:obj}`str`, {py:obj}`pathlib.Path` or {py:obj}`os.PathLike`
Path to the bed file.
i : {py:obj}`int`, optional
The MCMC iteration to use. Default is -1, which means all intervals.
s : {py:obj}`str`, optional
Subset of the file, in format {chrom}:{start}-{end}
newick : {py:obj}`str`
A newick string containing a tree.
name : {py:obj}`str`, optional
Name of the tree.
**kwargs
Additional keyword arguments to pass to {py:obj}`Bio.Phylo.draw`.
"""
kwargs: "Dict[str, Any]" = {"iter": i}
if s is not None:
mo = re.match(r"(\w+):(\d+)-(\d+)", s)
if mo is None:
raise ValueError(f"Invalid subset string: {s}")
kwargs["chrom"] = mo.group(1)
kwargs["start"] = int(mo.group(2))
kwargs["end"] = int(mo.group(3))
rv = plotTreesFromBed(str(bedfile), **kwargs)
dfs = []
for p in rv[0]:
with (ro.default_converter + pandas2ri.converter).context():
dfs.append(ro.conversion.get_conversion().rpy2py(p))
with (ro.default_converter + pandas2ri.converter).context():
x = ro.conversion.get_conversion().rpy2py(rv[1])
return dfs, x
tree = Phylo.read(StringIO(newick), "newick")
if name:
tree.name = name
tree.ladderize()
Phylo.draw(tree, **kwargs)
53 changes: 49 additions & 4 deletions argweaver/r.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,57 @@
"""R interface for argweaver."""
import typing

import rpy2.robjects as ro
from rpy2.robjects import pandas2ri
from rpy2.robjects.packages import importr

from argweaver.utils import parse_region

if typing.TYPE_CHECKING:
from os import PathLike
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import pandas

__all__ = ["plotTreesFromBed"]

_argweaver = importr("argweaver")

plotTreesFromBed = _argweaver.plotTreesFromBed
"""Plot trees from a bed file.

See [](../../rapidocs/argweaver/plotTreesFromBed.md).
"""
def plotTreesFromBed(
    bedfile: "Union[str, Path, PathLike]",
    *,
    i: "Union[int, Literal['max']]" = "max",
    s: "Optional[str]" = None,
) -> "Tuple[List[pandas.DataFrame], pandas.DataFrame]":
    """Call the R function ``plotTreesFromBed`` and convert its result.

    See [](../../rapidocs/argweaver/plotTreesFromBed.md).

    Parameters
    ----------
    bedfile : {py:obj}`str`, {py:obj}`pathlib.Path` or {py:obj}`os.PathLike`
        Path to the bed file; passed to R as a string.
    i : {py:obj}`int` or ``"max"``, optional
        The MCMC iteration to use, forwarded to R as ``iter``.
        Default is ``"max"``.
    s : {py:obj}`str`, optional
        Subset of the file, in format {chrom}:{start}-{end}. When given it is
        parsed and forwarded to R as ``chrom``, ``start`` and ``end``.

    Returns
    -------
    {py:obj}`tuple`
        The converted elements of the first item of the R result, as a list,
        followed by the converted second item of the R result.

    Raises
    ------
    ValueError
        If ``s`` is given but cannot be parsed as a region.
    """

    def _pandas_context():
        # A fresh converter context for every conversion, as each conversion
        # is meant to run under its own pandas-aware rules.
        return (ro.default_converter + pandas2ri.converter).context()

    r_args: "Dict[str, Any]" = {"iter": i}
    if s is not None:
        chrom, start, end = parse_region(s)
        r_args.update(chrom=chrom, start=start, end=end)

    result = _argweaver.plotTreesFromBed(str(bedfile), **r_args)

    frames = []
    # The elements are pulled out of the R list outside the converter context;
    # only the conversion itself happens inside it.
    for r_frame in result[0]:
        with _pandas_context():
            frames.append(ro.conversion.get_conversion().rpy2py(r_frame))

    with _pandas_context():
        table = ro.conversion.get_conversion().rpy2py(result[1])

    return frames, table
15 changes: 15 additions & 0 deletions argweaver/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import re
import typing

if typing.TYPE_CHECKING:
from typing import Tuple


def parse_region(region: str) -> "Tuple[str, int, int]":
    """Parse a region string in format {chrom}:{start}-{end}.

    Parameters
    ----------
    region : {py:obj}`str`
        Region string, e.g. ``"chr:1000-2000"``. The chromosome consists of
        word characters; start and end are non-negative decimal integers.

    Returns
    -------
    {py:obj}`tuple`
        ``(chrom, start, end)`` with ``start`` and ``end`` as integers.

    Raises
    ------
    ValueError
        If the whole string does not have the expected format, or if ``end``
        is smaller than ``start``.
    """
    # fullmatch rather than match: match only anchors at the beginning, so a
    # valid prefix followed by junk ("chr:1-2junk", "chr:10-20.5") would be
    # accepted and silently truncated.
    mo = re.fullmatch(r"(\w+):(\d+)-(\d+)", region)
    if mo is None:
        raise ValueError(f"Invalid region string: {region}")
    chrom, start, end = mo.group(1), int(mo.group(2)), int(mo.group(3))
    if end < start:
        raise ValueError(f"Invalid region string: {region}")
    return chrom, start, end
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ dependencies = [
"pyarrow>=15.0.0",
"rpy2>=3.5.15",
"pandas>=2.0.3",
"matplotlib>=3.7.5",
"biopython>=1.83",
]

[project.scripts]
Expand All @@ -52,6 +54,7 @@ dev-dependencies = [
"esbonio-extensions>=0.2.2",
"myst-parser>=2.0.0",
"sphinx-autodoc2>=0.5.0",
"ipykernel>=6.29.2",
]
[tool.rye.scripts]
cov-report = { chain = ["htmlcov", "htmlcov-serve"] }
Expand All @@ -63,6 +66,12 @@ features = ["extension-module"]
module-name = "argweaver.s"
include = ["bin/*"]

[[tool.mypy.overrides]]
module = [
"rpy2.*",
]
ignore_missing_imports = true

[tool.ruff]
extend-select = ["I"]
target-version = "py38"
Expand Down
3 changes: 3 additions & 0 deletions tests/baseline_images/test_plot/test_plot_tree.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
12 changes: 12 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from argweaver.io import read_bed, subset_bed


def test_read_bed(bedfile):
    """The fixture bed file loads as 1828 rows by 5 columns."""
    n_rows, n_cols = read_bed(bedfile).shape
    assert (n_rows, n_cols) == (1828, 5)


def test_subset_bed(bedfile):
    """Subsetting the fixture to chr:1000-2000 keeps 24 rows and all 5 columns."""
    selected = subset_bed(read_bed(bedfile), "chr:1000-2000")
    assert selected.shape == (24, 5)
14 changes: 7 additions & 7 deletions tests/test_plot.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import pytest
import matplotlib.pyplot as plt
from matplotlib.testing.decorators import image_comparison

from argweaver.plot import plot_trees
from argweaver.plot import plot_tree


def test_plot_trees(bedfile):
with pytest.raises(ValueError):
plot_trees(bedfile, s="not valid")
rv = plot_trees(bedfile, s="chr:1000-2000")
assert len(rv[0]) == len(rv[1])
@image_comparison(baseline_images=["test_plot_tree"], extensions=["png"])
def test_plot_tree():
    """Draw a small three-leaf tree; the decorator compares it to the baseline."""
    _, axes = plt.subplots()
    plot_tree(
        "((A:1,B:1):1,C:2);",
        name="Test Plot",
        axes=axes,
        do_show=False,
    )
10 changes: 10 additions & 0 deletions tests/test_r.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import pytest

from argweaver.r import plotTreesFromBed


def test_plot_trees(bedfile):
    """An invalid subset string is rejected; a valid one gives equal-length results."""
    with pytest.raises(ValueError):
        plotTreesFromBed(bedfile, s="not valid")
    frames, table = plotTreesFromBed(bedfile, s="chr:1000-2000")
    assert len(frames) == len(table)
11 changes: 11 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import pytest

from argweaver.utils import parse_region


def test_parse_region():
    """A valid region parses to a tuple; malformed or reversed regions raise."""
    parsed = parse_region("chr:1000-2000")
    assert parsed == ("chr", 1000, 2000)
    for rejected in ("not valid", "chr:2000-1000"):
        with pytest.raises(ValueError):
            parse_region(rejected)

0 comments on commit e9a11d8

Please sign in to comment.