Skip to content

Commit

Permalink
RasterDataset: remove plot method (#476)
Browse files Browse the repository at this point in the history
* RasterDataset: remove plot method

* Remove RasterDataset plot tests

* Remove plotting tests for landsat/naip
  • Loading branch information
adamjstewart authored Mar 30, 2022
1 parent f20f02a commit 9f96cdd
Show file tree
Hide file tree
Showing 14 changed files with 9 additions and 82 deletions.
6 changes: 0 additions & 6 deletions tests/datasets/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,12 +196,6 @@ def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(FileNotFoundError, match="No RasterDataset data was found"):
RasterDataset(str(tmp_path))

def test_plot_with_cmap(self, custom_dtype_ds: RasterDataset) -> None:
custom_dtype_ds.cmap = {i: (0, 0, 0, 255) for i in range(256)}
custom_dtype_ds.is_image = False
x = custom_dtype_ds[custom_dtype_ds.bounds]
custom_dtype_ds.plot(x["mask"])


class TestVectorDataset:
@pytest.fixture(scope="class")
Expand Down
7 changes: 0 additions & 7 deletions tests/datasets/test_landsat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import os
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
Expand All @@ -17,7 +16,6 @@
class TestLandsat8:
@pytest.fixture
def dataset(self, monkeypatch: MonkeyPatch) -> Landsat8:
monkeypatch.setattr(plt, "show", lambda *args: None)
root = os.path.join("tests", "data", "landsat8")
bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7"]
transforms = nn.Identity() # type: ignore[no-untyped-call]
Expand All @@ -40,11 +38,6 @@ def test_or(self, dataset: Landsat8) -> None:
ds = dataset | dataset
assert isinstance(ds, UnionDataset)

def test_plot(self, dataset: Landsat8) -> None:
query = dataset.bounds
x = dataset[query]
dataset.plot(x["image"])

def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(FileNotFoundError, match="No Landsat8 data was found in "):
Landsat8(str(tmp_path))
Expand Down
7 changes: 0 additions & 7 deletions tests/datasets/test_naip.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import os
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
Expand All @@ -17,7 +16,6 @@
class TestNAIP:
@pytest.fixture
def dataset(self, monkeypatch: MonkeyPatch) -> NAIP:
monkeypatch.setattr(plt, "show", lambda *args: None)
root = os.path.join("tests", "data", "naip")
transforms = nn.Identity() # type: ignore[no-untyped-call]
return NAIP(root, transforms=transforms)
Expand All @@ -36,11 +34,6 @@ def test_or(self, dataset: NAIP) -> None:
ds = dataset | dataset
assert isinstance(ds, UnionDataset)

def test_plot(self, dataset: NAIP) -> None:
query = dataset.bounds
x = dataset[query]
dataset.plot(x["image"])

def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(FileNotFoundError, match="No NAIP data was found in "):
NAIP(str(tmp_path))
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/agb_live_woody_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def _download(self) -> None:
item["properties"]["tile_id"] + ".tif",
)

def plot( # type: ignore[override]
def plot(
self,
sample: Dict[str, Any],
show_titles: bool = True,
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/astergdem.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _verify(self) -> None:
"have manually downloaded dataset tiles as suggested in the documentation."
)

def plot( # type: ignore[override]
def plot(
self,
sample: Dict[str, Any],
show_titles: bool = True,
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/cdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def _extract(self) -> None:
for zipfile in glob.iglob(pathname):
extract_archive(zipfile)

def plot( # type: ignore[override]
def plot(
self,
sample: Dict[str, Any],
show_titles: bool = True,
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def _extract(self) -> None:
"""Extract the dataset."""
extract_archive(os.path.join(self.root, self.zipfile))

def plot( # type: ignore[override]
def plot(
self,
sample: Dict[str, Tensor],
show_titles: bool = True,
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/cms_mangrove_canopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def _extract(self) -> None:
pathname = os.path.join(self.root, self.zipfile)
extract_archive(pathname)

def plot( # type: ignore[override]
def plot(
self,
sample: Dict[str, Any],
show_titles: bool = True,
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/esri2020.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _extract(self) -> None:
"""Extract the dataset."""
extract_archive(os.path.join(self.root, self.zipfile))

def plot( # type: ignore[override]
def plot(
self,
sample: Dict[str, Any],
show_titles: bool = True,
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/eudem.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def _verify(self) -> None:
"have manually downloaded the dataset as suggested in the documentation."
)

def plot( # type: ignore[override]
def plot(
self,
sample: Dict[str, Any],
show_titles: bool = True,
Expand Down
52 changes: 0 additions & 52 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import fiona
import fiona.transform
import matplotlib.pyplot as plt
import numpy as np
import pyproj
import rasterio
Expand Down Expand Up @@ -303,10 +302,6 @@ class RasterDataset(GeoDataset):
#: Names of RGB bands in the dataset, used for plotting
rgb_bands: List[str] = []

#: If True, stretch the image from the 2nd percentile to the 98th percentile,
#: used for plotting
stretch = False

#: Color map for the dataset, used for plotting
cmap: Dict[int, Tuple[int, int, int, int]] = {}

Expand Down Expand Up @@ -503,53 +498,6 @@ def _load_warp_file(self, filepath: str) -> DatasetReader:
else:
return src

def plot(self, data: Tensor) -> None:
"""Plot a data sample.
Args:
data: the data to plot
Raises:
AssertionError: if ``is_image`` is True and ``data`` has a different number
of channels than expected
"""
array = data.squeeze().numpy()

if self.is_image:
bands = getattr(self, "bands", self.all_bands)
assert array.shape[0] == len(bands)

# Only plot RGB bands
if bands and self.rgb_bands:
indices: "np.typing.NDArray[np.int_]" = np.array(
[bands.index(band) for band in self.rgb_bands]
)
array = array[indices]

# Convert from CxHxW to HxWxC
array = np.rollaxis(array, 0, 3)

if self.cmap:
# Convert from class labels to RGBA values
cmap: "np.typing.NDArray[np.int_]" = np.array(
[self.cmap[i] for i in range(len(self.cmap))]
)
array = cmap[array]

if self.stretch:
# Stretch to the range of 2nd to 98th percentile
per02 = np.percentile(array, 2)
per98 = np.percentile(array, 98)
array = (array - per02) / (per98 - per02)
array = np.clip(array, 0, 1)

# Plot the data
ax = plt.axes()
ax.imshow(array)
ax.axis("off")
plt.show()
plt.close()


class VectorDataset(GeoDataset):
"""Abstract base class for :class:`GeoDataset` stored as vector files."""
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/globbiomass.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def _verify(self) -> None:
"have manually downloaded the dataset as suggested in the documentation."
)

def plot( # type: ignore[override]
def plot(
self,
sample: Dict[str, Any],
show_titles: bool = True,
Expand Down
1 change: 0 additions & 1 deletion torchgeo/datasets/landsat.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ class Landsat(RasterDataset, abc.ABC):
rgb_bands: List[str] = []

separate_files = True
stretch = True

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/sentinel.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def __init__(

super().__init__(root, crs, res, transforms, cache)

def plot( # type: ignore[override]
def plot(
self,
sample: Dict[str, Tensor],
show_titles: bool = True,
Expand Down

0 comments on commit 9f96cdd

Please sign in to comment.