Skip to content

Commit

Permalink
Linting
Browse files Browse the repository at this point in the history
  • Loading branch information
ashnair1 committed Jul 2, 2024
1 parent 6f232ca commit a018184
Show file tree
Hide file tree
Showing 9 changed files with 55 additions and 59 deletions.
10 changes: 5 additions & 5 deletions tests/datasets/test_nasa_marine_debris.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NASAMarineDebris:
def test_getitem(self, dataset: NASAMarineDebris) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["bbox_xyxy"], torch.Tensor)
assert x["image"].shape[0] == 3
assert x["bbox_xyxy"].shape[-1] == 4
assert isinstance(x['image'], torch.Tensor)
assert isinstance(x['bbox_xyxy'], torch.Tensor)
assert x['image'].shape[0] == 3
assert x['bbox_xyxy'].shape[-1] == 4

def test_len(self, dataset: NASAMarineDebris) -> None:
assert len(dataset) == 4
Expand Down Expand Up @@ -99,6 +99,6 @@ def test_plot(self, dataset: NASAMarineDebris) -> None:
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
x["prediction_boxes"] = x["bbox_xyxy"].clone()
x['prediction_boxes'] = x['bbox_xyxy'].clone()
dataset.plot(x)
plt.close()
24 changes: 12 additions & 12 deletions tests/datasets/test_vhr10.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ def test_getitem(self, dataset: VHR10) -> None:
for i in range(2):
x = dataset[i]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
if dataset.split == "positive":
assert isinstance(x["class"], torch.Tensor)
assert isinstance(x["bbox_xyxy"], torch.Tensor)
if "mask" in x:
assert isinstance(x["mask"], torch.Tensor)
assert isinstance(x['image'], torch.Tensor)
if dataset.split == 'positive':
assert isinstance(x['class'], torch.Tensor)
assert isinstance(x['bbox_xyxy'], torch.Tensor)
if 'mask' in x:
assert isinstance(x['mask'], torch.Tensor)

def test_len(self, dataset: VHR10) -> None:
if dataset.split == 'positive':
Expand Down Expand Up @@ -91,10 +91,10 @@ def test_plot(self, dataset: VHR10) -> None:
scores = [0.7, 0.3, 0.7]
for i in range(3):
x = dataset[i]
x["prediction_labels"] = x["class"]
x["prediction_boxes"] = x["bbox_xyxy"]
x["prediction_scores"] = torch.Tensor([scores[i]])
if "masks" in x:
x["prediction_masks"] = x["mask"]
dataset.plot(x, show_feats="masks")
x['prediction_labels'] = x['class']
x['prediction_boxes'] = x['bbox_xyxy']
x['prediction_scores'] = torch.Tensor([scores[i]])
if 'masks' in x:
x['prediction_masks'] = x['mask']
dataset.plot(x, show_feats='masks')
plt.close()
1 change: 0 additions & 1 deletion torchgeo/datamodules/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
GridGeoSampler,
RandomBatchGeoSampler,
)
from ..transforms import AugmentationSequential
from .utils import MisconfigurationException


Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datamodules/nasa_marine_debris.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from ..datasets import NASAMarineDebris
from .geo import NonGeoDataModule
from .utils import collate_fn_detection, dataset_split
from .utils import collate_fn_detection


class NASAMarineDebrisDataModule(NonGeoDataModule):
Expand Down
23 changes: 10 additions & 13 deletions torchgeo/datamodules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,11 @@

import math
from collections.abc import Iterable
from typing import Any, Optional, Union
from typing import Any

import numpy as np
import torch
from torch import Generator, Tensor
from torch.utils.data import Subset, TensorDataset, random_split

from ..datasets import NonGeoDataset
from torch import Tensor


# Based on lightning_lite.utilities.exceptions
Expand All @@ -32,23 +29,23 @@ def collate_fn_detection(batch: list[dict[str, Tensor]]) -> dict[str, Any]:
.. versionadded:: 0.6
"""
output: dict[str, Any] = {}
output["image"] = torch.stack([sample["image"] for sample in batch])
output['image'] = torch.stack([sample['image'] for sample in batch])
# Get bbox key as it can be one of {"bbox", "bbox_xyxy", "bbox_xywh"}
bbox_key = "boxes"
bbox_key = 'boxes'
for key in batch[0].keys():
if key in {"bbox", "bbox_xyxy", "bbox_xywh"}:
if key in {'bbox', 'bbox_xyxy', 'bbox_xywh'}:
bbox_key = key

output[bbox_key] = [sample[bbox_key].float() for sample in batch]
if "class" in batch[0].keys():
output["class"] = [sample["class"] for sample in batch]
if 'class' in batch[0].keys():
output['class'] = [sample['class'] for sample in batch]
else:
output["class"] = [
output['class'] = [
torch.tensor([1] * len(sample[bbox_key])) for sample in batch
]

if "mask" in batch[0]:
output["mask"] = [sample["mask"] for sample in batch]
if 'mask' in batch[0]:
output['mask'] = [sample['mask'] for sample in batch]
return output


Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datamodules/vhr10.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ def setup(self, stage: str) -> None:
Args:
stage: Either 'fit', 'validate', 'test', or 'predict'.
"""
self.kwargs["transforms"] = K.AugmentationSequential(
self.kwargs['transforms'] = K.AugmentationSequential(
K.Resize(self.patch_size), data_keys=None, keepdim=True
)
self.kwargs["transforms"].keepdim = True
self.kwargs['transforms'].keepdim = True
self.dataset = VHR10(**self.kwargs)
generator = torch.Generator().manual_seed(0)
self.train_dataset, self.val_dataset, self.test_dataset = random_split(
Expand Down
20 changes: 10 additions & 10 deletions torchgeo/datasets/nasa_marine_debris.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,15 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
Returns:
data and labels at that index
"""
image = self._load_image(self.files[index]["image"])
boxes = self._load_target(self.files[index]["target"])
sample = {"image": image, "bbox_xyxy": boxes}
image = self._load_image(self.files[index]['image'])
boxes = self._load_target(self.files[index]['target'])
sample = {'image': image, 'bbox_xyxy': boxes}

# Filter invalid boxes
w_check = (sample["bbox_xyxy"][:, 2] - sample["bbox_xyxy"][:, 0]) > 0
h_check = (sample["bbox_xyxy"][:, 3] - sample["bbox_xyxy"][:, 1]) > 0
w_check = (sample['bbox_xyxy'][:, 2] - sample['bbox_xyxy'][:, 0]) > 0
h_check = (sample['bbox_xyxy'][:, 3] - sample['bbox_xyxy'][:, 1]) > 0
indices = w_check & h_check
sample["bbox_xyxy"] = sample["bbox_xyxy"][indices]
sample['bbox_xyxy'] = sample['bbox_xyxy'][indices]

if self.transforms is not None:
sample = self.transforms(sample)
Expand Down Expand Up @@ -234,11 +234,11 @@ def plot(
"""
ncols = 1

sample["image"] = sample["image"].byte()
image = sample["image"]
if "bbox_xyxy" in sample and len(sample["bbox_xyxy"]):
sample['image'] = sample['image'].byte()
image = sample['image']
if 'bbox_xyxy' in sample and len(sample['bbox_xyxy']):
image = draw_bounding_boxes(
image=sample["image"], boxes=sample["bbox_xyxy"]
image=sample['image'], boxes=sample['bbox_xyxy']
)
image = image.permute((1, 2, 0)).numpy()

Expand Down
18 changes: 9 additions & 9 deletions torchgeo/datasets/vhr10.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,10 @@ def __getitem__(self, index: int) -> dict[str, Any]:

if sample['label']['annotations']:
sample = self.coco_convert(sample)
sample["class"] = sample["label"]["labels"]
sample["bbox_xyxy"] = sample["label"]["boxes"]
sample["mask"] = sample["label"]["masks"].float()
del sample["label"]
sample['class'] = sample['label']['labels']
sample['bbox_xyxy'] = sample['label']['boxes']
sample['mask'] = sample['label']['masks'].float()
del sample['label']

if self.transforms is not None:
sample = self.transforms(sample)
Expand Down Expand Up @@ -401,11 +401,11 @@ def plot(
if show_feats != 'boxes':
skimage = lazy_import('skimage')

image = sample["image"].permute(1, 2, 0).numpy()
boxes = sample["bbox_xyxy"].cpu().numpy()
labels = sample["class"].cpu().numpy()
if "mask" in sample:
masks = [mask.squeeze().cpu().numpy() for mask in sample["mask"]]
boxes = sample['bbox_xyxy'].cpu().numpy()
labels = sample['class'].cpu().numpy()

if 'mask' in sample:
masks = [mask.squeeze().cpu().numpy() for mask in sample['mask']]

n_gt = len(boxes)

Expand Down
12 changes: 6 additions & 6 deletions torchgeo/trainers/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,10 @@ def training_step(
batch_size = x.shape[0]
# Get bbox key as it can be one of {"bbox", "bbox_xyxy", "bbox_xywh"}
for key in batch.keys():
if key in {"bbox", "bbox_xyxy", "bbox_xywh"}:
if key in {'bbox', 'bbox_xyxy', 'bbox_xywh'}:
bbox_key = key
y = [
{"boxes": batch[bbox_key][i], "labels": batch["class"][i]}
{'boxes': batch[bbox_key][i], 'labels': batch['class'][i]}
for i in range(batch_size)
]
loss_dict = self(x, y)
Expand All @@ -264,10 +264,10 @@ def validation_step(
batch_size = x.shape[0]
# Get bbox key as it can be one of {"bbox", "bbox_xyxy", "bbox_xywh"}
for key in batch.keys():
if key in {"bbox", "bbox_xyxy", "bbox_xywh"}:
if key in {'bbox', 'bbox_xyxy', 'bbox_xywh'}:
bbox_key = key
y = [
{"boxes": batch[bbox_key][i], "labels": batch["class"][i]}
{'boxes': batch[bbox_key][i], 'labels': batch['class'][i]}
for i in range(batch_size)
]
y_hat = self(x)
Expand Down Expand Up @@ -322,10 +322,10 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
batch_size = x.shape[0]
# Get bbox key as it can be one of {"bbox", "bbox_xyxy", "bbox_xywh"}
for key in batch.keys():
if key in {"bbox", "bbox_xyxy", "bbox_xywh"}:
if key in {'bbox', 'bbox_xyxy', 'bbox_xywh'}:
bbox_key = key
y = [
{"boxes": batch[bbox_key][i], "labels": batch["class"][i]}
{'boxes': batch[bbox_key][i], 'labels': batch['class'][i]}
for i in range(batch_size)
]
y_hat = self(x)
Expand Down

0 comments on commit a018184

Please sign in to comment.