Skip to content

Commit

Permalink
Revert mypy changes (#1836)
Browse files Browse the repository at this point in the history
* Revert mypy changes

* Remove unused import
  • Loading branch information
adamjstewart committed Mar 2, 2024
1 parent 454ef65 commit 238d586
Show file tree
Hide file tree
Showing 28 changed files with 39 additions and 77 deletions.
3 changes: 1 addition & 2 deletions torchgeo/datamodules/chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Any, Optional

import kornia.augmentation as K
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
Expand Down Expand Up @@ -113,7 +112,7 @@ def __init__(
self.test_splits = test_splits
self.class_set = class_set
self.use_prior_labels = use_prior_labels
self.prior_smoothing_constant = torch.tensor(prior_smoothing_constant)
self.prior_smoothing_constant = prior_smoothing_constant

if self.use_prior_labels:
self.layers = [
Expand Down
3 changes: 1 addition & 2 deletions torchgeo/datamodules/spacenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Any

import kornia.augmentation as K
import torch
from torch import Tensor

from ..datasets import SpaceNet1
Expand Down Expand Up @@ -88,6 +87,6 @@ def on_after_batch_transfer(
# We add 1 to the mask to map the current {background, building} labels to
# the values {1, 2}. This is necessary because we add 0 padding to the
# mask that we want to ignore in the loss function.
batch["mask"] += torch.tensor(1)
batch["mask"] += 1

return super().on_after_batch_transfer(batch, dataloader_idx)
3 changes: 1 addition & 2 deletions torchgeo/datasets/benin_cashews.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,8 +384,7 @@ def _load_mask(self, transform: rasterio.Affine) -> Tensor:
dtype=np.uint8,
)

mask = torch.from_numpy(mask_data)
mask = mask.long()
mask = torch.from_numpy(mask_data).long()
return mask

def _check_integrity(self) -> bool:
Expand Down
3 changes: 1 addition & 2 deletions torchgeo/datasets/bigearthnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,8 +400,7 @@ def _load_image(self, index: int) -> Tensor:
)
images.append(array)
arrays: "np.typing.NDArray[np.int_]" = np.stack(images, axis=0)
tensor = torch.from_numpy(arrays)
tensor = tensor.float()
tensor = torch.from_numpy(arrays).float()
return tensor

def _load_target(self, index: int) -> Tensor:
Expand Down
3 changes: 1 addition & 2 deletions torchgeo/datasets/biomassters.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,7 @@ def _load_target(self, filename: str) -> Tensor:
with rasterio.open(os.path.join(self.root, "train_agbm", filename), "r") as src:
arr: "np.typing.NDArray[np.float_]" = src.read()

target = torch.from_numpy(arr)
target = target.float()
target = torch.from_numpy(arr).float()
return target

def _verify(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/cloud_cover.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def plot(
else:
n_cols = 2

image, mask = sample["image"] / torch.tensor(3000), sample["mask"]
image, mask = sample["image"] / 3000, sample["mask"]

fig, axs = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, n_cols * 5))

Expand Down
6 changes: 2 additions & 4 deletions torchgeo/datasets/cowc.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,7 @@ def _load_image(self, index: int) -> Tensor:
filename = os.path.join(self.root, self.images[index])
with Image.open(filename) as img:
array: "np.typing.NDArray[np.int_]" = np.array(img)
tensor = torch.from_numpy(array)
tensor = tensor.float()
tensor = torch.from_numpy(array).float()
# Convert from HxWxC to CxHxW
tensor = tensor.permute((2, 0, 1))
return tensor
Expand All @@ -164,8 +163,7 @@ def _load_target(self, index: int) -> Tensor:
the target
"""
target = int(self.targets[index])
tensor = torch.tensor(target)
tensor = tensor.float()
tensor = torch.tensor(target).float()
return tensor

def _check_integrity(self) -> bool:
Expand Down
3 changes: 1 addition & 2 deletions torchgeo/datasets/cyclone.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,7 @@ def _load_image(self, directory: str) -> Tensor:
img = img.resize(size=(self.size, self.size), resample=resample)
array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB"))
tensor = torch.from_numpy(array)
tensor = tensor.permute((2, 0, 1))
tensor = tensor.float()
tensor = tensor.permute((2, 0, 1)).float()
return tensor

def _load_features(self, directory: str) -> dict[str, Any]:
Expand Down
3 changes: 1 addition & 2 deletions torchgeo/datasets/etci2021.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,7 @@ def _load_image(self, path: str) -> Tensor:
filename = os.path.join(path)
with Image.open(filename) as img:
array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB"))
tensor = torch.from_numpy(array)
tensor = tensor.float()
tensor = torch.from_numpy(array).float()
# Convert from HxWxC to CxHxW
tensor = tensor.permute((2, 0, 1))
return tensor
Expand Down
4 changes: 1 addition & 3 deletions torchgeo/datasets/idtrees.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,9 +501,7 @@ def plot(
assert len(hsi_indices) == 3

def normalize(x: Tensor) -> Tensor:
# https://github.com/pytorch/pytorch/issues/116327
out: Tensor = (x - x.min()) / (x.max() - x.min())
return out
return (x - x.min()) / (x.max() - x.min())

ncols = 3

Expand Down
6 changes: 2 additions & 4 deletions torchgeo/datasets/inria.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,7 @@ def _load_image(self, path: str) -> Tensor:
"""
with rio.open(path) as img:
array = img.read().astype(np.int32)
tensor = torch.from_numpy(array)
tensor = tensor.float()
tensor = torch.from_numpy(array).float()
return tensor

def _load_target(self, path: str) -> Tensor:
Expand All @@ -146,8 +145,7 @@ def _load_target(self, path: str) -> Tensor:
with rio.open(path) as img:
array = img.read().astype(np.int32)
array = np.clip(array, 0, 1)
mask = torch.from_numpy(array[0])
mask = mask.long()
mask = torch.from_numpy(array[0]).long()
return mask

def __len__(self) -> int:
Expand Down
6 changes: 2 additions & 4 deletions torchgeo/datasets/landcoverai.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,7 @@ def _load_image(self, id_: str) -> Tensor:
filename = os.path.join(self.root, "output", id_ + ".jpg")
with Image.open(filename) as img:
array: "np.typing.NDArray[np.int_]" = np.array(img)
tensor = torch.from_numpy(array)
tensor = tensor.float()
tensor = torch.from_numpy(array).float()
# Convert from HxWxC to CxHxW
tensor = tensor.permute((2, 0, 1))
return tensor
Expand All @@ -388,8 +387,7 @@ def _load_target(self, id_: str) -> Tensor:
filename = os.path.join(self.root, "output", id_ + "_m.png")
with Image.open(filename) as img:
array: "np.typing.NDArray[np.int_]" = np.array(img.convert("L"))
tensor = torch.from_numpy(array)
tensor = tensor.long()
tensor = torch.from_numpy(array).long()
return tensor

def _verify_data(self) -> bool:
Expand Down
3 changes: 1 addition & 2 deletions torchgeo/datasets/loveda.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,7 @@ def _load_image(self, path: str) -> Tensor:
filename = os.path.join(path)
with Image.open(filename) as img:
array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB"))
tensor = torch.from_numpy(array)
tensor = tensor.float()
tensor = torch.from_numpy(array).float()
# Convert from HxWxC to CxHxW
tensor = tensor.permute((2, 0, 1))
return tensor
Expand Down
3 changes: 1 addition & 2 deletions torchgeo/datasets/mapinwild.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,7 @@ def _load_raster(self, filename: int, source: str) -> Tensor:
array: "np.typing.NDArray[np.int_]" = np.stack(raw_array, axis=0)
if array.dtype == np.uint16:
array = array.astype(np.int32)
tensor = torch.from_numpy(array)
tensor = tensor.float()
tensor = torch.from_numpy(array).float()
return tensor

def _verify(self, url: str, md5: Optional[str] = None) -> None:
Expand Down
3 changes: 1 addition & 2 deletions torchgeo/datasets/nasa_marine_debris.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,7 @@ def _load_image(self, path: str) -> Tensor:
"""
with rasterio.open(path) as f:
array = f.read()
tensor = torch.from_numpy(array)
tensor = tensor.float()
tensor = torch.from_numpy(array).float()
return tensor

def _load_target(self, path: str) -> Tensor:
Expand Down
3 changes: 1 addition & 2 deletions torchgeo/datasets/oscd.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,7 @@ def _load_image(self, paths: Sequence[str]) -> Tensor:
with Image.open(path) as img:
images.append(np.array(img))
array: "np.typing.NDArray[np.int_]" = np.stack(images, axis=0).astype(np.int_)
tensor = torch.from_numpy(array)
tensor = tensor.float()
tensor = torch.from_numpy(array).float()
return tensor

def _load_target(self, path: str) -> Tensor:
Expand Down
3 changes: 1 addition & 2 deletions torchgeo/datasets/pastis.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,7 @@ def _load_semantic_targets(self, index: int) -> Tensor:
# See https://github.com/VSainteuf/pastis-benchmark/blob/main/code/dataloader.py#L201 # noqa: E501
# even though the mask file is 3 bands, we just select the first band
array = np.load(self.files[index]["semantic"])[0].astype(np.uint8)
tensor = torch.from_numpy(array)
tensor = tensor.long()
tensor = torch.from_numpy(array).long()
return tensor

def _load_instance_targets(self, index: int) -> tuple[Tensor, Tensor, Tensor]:
Expand Down
3 changes: 1 addition & 2 deletions torchgeo/datasets/potsdam.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,7 @@ def _load_image(self, index: int) -> Tensor:
path = self.files[index]["image"]
with rasterio.open(path) as f:
array = f.read()
tensor = torch.from_numpy(array)
tensor = tensor.float()
tensor = torch.from_numpy(array).float()
return tensor

def _load_target(self, index: int) -> Tensor:
Expand Down
4 changes: 1 addition & 3 deletions torchgeo/datasets/seasonet.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,9 +350,7 @@ def _load_target(self, index: int) -> Tensor:
path = self.files.iloc[index][0]
with rasterio.open(f"{path}_labels.tif") as f:
array = f.read() - 1
tensor = torch.from_numpy(array)
tensor = tensor.squeeze()
tensor = tensor.long()
tensor = torch.from_numpy(array).squeeze().long()
return tensor

def _verify(self) -> None:
Expand Down
3 changes: 1 addition & 2 deletions torchgeo/datasets/skippd.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,7 @@ def _load_image(self, index: int) -> Tensor:
else:
arr = rearrange(arr, "h w c -> c h w")

tensor = torch.from_numpy(arr)
tensor = tensor.to(torch.float32)
tensor = torch.from_numpy(arr).to(torch.float32)
return tensor

def _load_features(self, index: int) -> dict[str, Union[str, Tensor]]:
Expand Down
6 changes: 2 additions & 4 deletions torchgeo/datasets/spacenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,7 @@ def _load_mask(
dtype=np.uint8,
)

mask = torch.from_numpy(mask_data)
mask = mask.long()
mask = torch.from_numpy(mask_data).long()

return mask

Expand Down Expand Up @@ -733,8 +732,7 @@ def _load_mask(
dtype=np.uint8,
)

mask = torch.from_numpy(mask_data)
mask = mask.long()
mask = torch.from_numpy(mask_data).long()
return mask

def plot(
Expand Down
6 changes: 2 additions & 4 deletions torchgeo/datasets/ssl4eo_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,7 @@ def _load_image(self, path: str) -> Tensor:
image
"""
with rasterio.open(path) as src:
image = torch.from_numpy(src.read())
image = image.float()
image = torch.from_numpy(src.read()).float()
return image

def _load_mask(self, path: str) -> Tensor:
Expand All @@ -328,8 +327,7 @@ def _load_mask(self, path: str) -> Tensor:
mask
"""
with rasterio.open(path) as src:
mask = torch.from_numpy(src.read())
mask = mask.long()
mask = torch.from_numpy(src.read()).long()
mask = self.ordinal_map[mask]
return mask

Expand Down
3 changes: 1 addition & 2 deletions torchgeo/datasets/usavars.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,7 @@ def _load_image(self, path: str) -> Tensor:
"""
with rasterio.open(path) as f:
array: "np.typing.NDArray[np.int_]" = f.read()
tensor = torch.from_numpy(array)
tensor = tensor.float()
tensor = torch.from_numpy(array).float()
return tensor

def _verify(self) -> None:
Expand Down
16 changes: 5 additions & 11 deletions torchgeo/losses/qr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules import Module


Expand All @@ -29,16 +28,12 @@ def forward(self, probs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
qr loss
"""
q = probs
# https://github.com/pytorch/pytorch/issues/116327
q_bar: Tensor = q.mean(dim=(0, 2, 3))
log_q_bar = torch.log(q_bar)
qbar_log_S: Tensor = q_bar * log_q_bar
qbar_log_S = qbar_log_S.sum()
q_bar = q.mean(dim=(0, 2, 3))
qbar_log_S = (q_bar * torch.log(q_bar)).sum()

q_log_p = torch.einsum("bcxy,bcxy->bxy", q, torch.log(target))
q_log_p = q_log_p.mean()
q_log_p = torch.einsum("bcxy,bcxy->bxy", q, torch.log(target)).mean()

loss: Tensor = qbar_log_S - q_log_p
loss = qbar_log_S - q_log_p
return loss


Expand Down Expand Up @@ -67,7 +62,6 @@ def forward(self, probs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
z = q / q.norm(p=1, dim=(0, 2, 3), keepdim=True).clamp_min(1e-12).expand_as(q)
r = F.normalize(z * target, p=1, dim=1)

loss = torch.einsum("bcxy,bcxy->bxy", r, torch.log(r) - torch.log(q))
loss = loss.mean()
loss = torch.einsum("bcxy,bcxy->bxy", r, torch.log(r) - torch.log(q)).mean()

return loss
2 changes: 1 addition & 1 deletion torchgeo/models/rcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
),
)
self.register_buffer(
"biases", torch.zeros(num_patches, requires_grad=False) + torch.tensor(bias)
"biases", torch.zeros(num_patches, requires_grad=False) + bias
)

if mode == "empirical":
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/samplers/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def __init__(
# torch.multinomial requires float probabilities > 0
self.areas = torch.tensor(areas, dtype=torch.float)
if torch.sum(self.areas) == 0:
self.areas += torch.tensor(1)
self.areas += 1

def __iter__(self) -> Iterator[list[BoundingBox]]:
"""Return the indices of a dataset.
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/samplers/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def __init__(
# torch.multinomial requires float probabilities > 0
self.areas = torch.tensor(areas, dtype=torch.float)
if torch.sum(self.areas) == 0:
self.areas += torch.tensor(1)
self.areas += 1

def __iter__(self) -> Iterator[BoundingBox]:
"""Return the index of a dataset.
Expand Down
9 changes: 3 additions & 6 deletions torchgeo/transforms/color.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,8 @@ def apply_transform(
Returns:
The augmented input.
"""
weights = flags["weights"]
weights = weights[..., :, None, None]
weights = weights.to(input.device)
out: Tensor = input * weights
weights = flags["weights"][..., :, None, None].to(input.device)
out = input * weights
out = out.sum(dim=-3)
out = out.unsqueeze(-3)
out = out.expand(input.shape)
out = out.unsqueeze(-3).expand(input.shape)
return out

0 comments on commit 238d586

Please sign in to comment.