Cleanup docstrings in lightly/utils subpackage #1698

Merged
10 commits merged on Oct 18, 2024
60 changes: 46 additions & 14 deletions lightly/utils/bounding_box.py
@@ -1,4 +1,4 @@
""" Bounding Box Utils """
"""Bounding Box Utils"""

from __future__ import annotations

@@ -31,17 +31,26 @@ class BoundingBox:
>>> # (x0, y0, x1, y1) = (10, 20, 30, 40)
>>> W, H = 100, 100 # get image shape
>>> bbox = BoundingBox(10 / W, 20 / H, 30 / W, 40 / H)

"""

def __init__(
self, x0: float, y0: float, x1: float, y1: float, clip_values: bool = True
):
"""
clip_values:
Set to true to clip the values into [0, 1] instead of raising an error if they lie outside.
"""
"""Initializes a BoundingBox object.

Args:
x0:
x0 coordinate relative to image width.
y0:
y0 coordinate relative to image height.
x1:
x1 coordinate relative to image width.
y1:
y1 coordinate relative to image height.
clip_values:
If True, clips the coordinates to [0, 1].

"""
if clip_values:

def clip_to_0_1(value: float) -> float:
@@ -60,14 +69,12 @@ def clip_to_0_1(value: float) -> float:

if x0 >= x1:
raise ValueError(
f"x0 must be smaller than x1 for bounding box "
f"[{x0}, {y0}, {x1}, {y1}]"
f"x0 must be smaller than x1 for bounding box [{x0}, {y0}, {x1}, {y1}]"
)

if y0 >= y1:
raise ValueError(
"y0 must be smaller than y1 for bounding box "
f"[{x0}, {y0}, {x1}, {y1}]"
f"y0 must be smaller than y1 for bounding box [{x0}, {y0}, {x1}, {y1}]"
)

self.x0 = x0
@@ -77,7 +84,20 @@ def clip_to_0_1(value: float) -> float:

@classmethod
def from_x_y_w_h(cls, x: float, y: float, w: float, h: float) -> BoundingBox:
"""Helper to convert from bounding box format with width and height.
"""Creates a BoundingBox from x, y, width, and height.

Args:
x:
x coordinate of the top-left corner relative to image width.
y:
y coordinate of the top-left corner relative to image height.
w:
Width of the bounding box.
h:
Height of the bounding box.

Returns:
BoundingBox: A BoundingBox instance.

Examples:
>>> bbox = BoundingBox.from_x_y_w_h(0.1, 0.2, 0.2, 0.2)
@@ -89,11 +109,23 @@ def from_x_y_w_h(cls, x: float, y: float, w: float, h: float) -> BoundingBox:
def from_yolo_label(
cls, x_center: float, y_center: float, w: float, h: float
) -> BoundingBox:
"""Helper to convert from yolo label format
x_center, y_center, w, h --> x0, y0, x1, y1
"""Creates a BoundingBox from YOLO label format.

Args:
x_center:
x coordinate of the center relative to image width.
y_center:
y coordinate of the center relative to image height.
w:
Width of the bounding box.
h:
Height of the bounding box.

Returns:
BoundingBox: A BoundingBox instance.

Examples:
>>> bbox = BoundingBox.from_yolo(0.5, 0.4, 0.2, 0.3)
>>> bbox = BoundingBox.from_yolo_label(0.5, 0.4, 0.2, 0.3)

"""
return cls(
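For reference, the two classmethods in the diff above reduce to simple arithmetic on normalized coordinates. A minimal usage sketch; the printed values assume the standard corner and center conversions, which the docstring examples imply but the hunk bodies do not show in full:

    from lightly.utils.bounding_box import BoundingBox

    # Top-left corner plus width/height: (x, y, w, h) -> (x, y, x + w, y + h).
    bbox = BoundingBox.from_x_y_w_h(0.1, 0.2, 0.2, 0.2)
    print(bbox.x0, bbox.y0, bbox.x1, bbox.y1)  # approximately 0.1 0.2 0.3 0.4

    # YOLO center format: (xc, yc, w, h) -> (xc - w/2, yc - h/2, xc + w/2, yc + h/2).
    bbox = BoundingBox.from_yolo_label(0.5, 0.4, 0.2, 0.3)
    print(bbox.x0, bbox.y0, bbox.x1, bbox.y1)  # approximately 0.4 0.25 0.6 0.55

Note that the constructor clips coordinates to [0, 1] by default (clip_values=True) and raises a ValueError if x0 >= x1 or y0 >= y1.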
33 changes: 20 additions & 13 deletions lightly/utils/debug.py
@@ -15,7 +15,6 @@
"'pip install lightly[matplotlib]'."
)
except ImportError as ex:
# Matplotlib import can fail if an incompatible dateutil version is installed.
plt = ex


Expand All @@ -24,9 +23,9 @@ def std_of_l2_normalized(z: torch.Tensor) -> torch.Tensor:
"""Calculates the mean of the standard deviation of z along each dimension.

This measure was used by [0] to determine the level of collapse of the
learned representations. If the returned number is 0., the outputs z have
collapsed to a constant vector. "If the output z has a zero-mean isotropic
Gaussian distribution" [0], the returned number should be close to 1/sqrt(d)
learned representations. If the returned value is 0., the outputs z have
collapsed to a constant vector. If the output z has a zero-mean isotropic
Gaussian distribution [0], the returned value should be close to 1/sqrt(d),
where d is the dimensionality of the output.

[0]: https://arxiv.org/abs/2011.10566
@@ -38,9 +37,7 @@ def std_of_l2_normalized(z: torch.Tensor) -> torch.Tensor:
Returns:
The mean of the standard deviation of the l2 normalized tensor z along
each dimension.

"""

if len(z.shape) != 2:
raise ValueError(
f"Input tensor must have two dimensions but has {len(z.shape)}!"
@@ -53,8 +50,18 @@ def std_of_l2_normalized(z: torch.Tensor) -> torch.Tensor:
def apply_transform_without_normalize(
image: Image.Image,
transform,
):
"""Applies the transform to the image but skips ToTensor and Normalize."""
) -> Image.Image:
"""Applies the transform to the image but skips ToTensor and Normalize.

Args:
image:
The input PIL image.
transform:
The transformation to apply, excluding ToTensor and Normalize.

Returns:
The transformed image.
"""
skippable_transforms = (
torchvision.transforms.ToTensor,
torchvision.transforms.Normalize,
@@ -70,10 +77,10 @@ def apply_transform_without_normalize(
def generate_grid_of_augmented_images(
input_images: List[Image.Image],
collate_function: Union[BaseCollateFunction, MultiViewCollateFunction],
):
) -> List[List[Image.Image]]:
"""Returns a grid of augmented images. Images in a column belong together.

This function ignores the transforms ToTensor and Normalize for visualization purposes.
This function ignores the ToTensor and Normalize transforms for visualization purposes.

Args:
input_images:
@@ -116,9 +123,9 @@ def plot_augmented_images(
input_images: List[Image.Image],
collate_function: Union[BaseCollateFunction, MultiViewCollateFunction],
):
"""Returns a figure showing original images in the left column and augmented images to their right.
"""Plots original images and augmented images in a figure.

This function ignores the transforms ToTensor and Normalize for visualization purposes.
This function ignores the ToTensor and Normalize transforms for visualization purposes.

Args:
input_images:
Expand All @@ -134,7 +141,6 @@ def plot_augmented_images(
MultiViewCollateFunctions all the generated views are shown.

"""

_check_matplotlib_available()

if len(input_images) == 0:
@@ -166,5 +172,6 @@


def _check_matplotlib_available() -> None:
"""Checks if matplotlib is available. Raises an error if not."""
if isinstance(plt, Exception):
raise plt
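As a quick sanity check of the collapse measure documented above, here is a hedged sketch of how std_of_l2_normalized behaves at the two extremes its docstring describes (batch sizes and dimensionality are illustrative):

    import torch
    from lightly.utils.debug import std_of_l2_normalized

    d = 128

    # Fully collapsed outputs: every representation is the same constant vector.
    z_collapsed = torch.ones(512, d)
    print(std_of_l2_normalized(z_collapsed))  # tensor(0.)

    # Zero-mean isotropic Gaussian outputs: the value approaches 1/sqrt(d).
    z_gaussian = torch.randn(8192, d)
    print(std_of_l2_normalized(z_gaussian))  # roughly 0.088 for d=128

Since 1/sqrt(128) is about 0.088, values far below that during training suggest the representations are collapsing.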
34 changes: 26 additions & 8 deletions lightly/utils/dependency.py
@@ -3,24 +3,42 @@

@functools.lru_cache(maxsize=1)
def torchvision_vit_available() -> bool:
"""Checks if Vision Transformer (ViT) models are available in torchvision.

This function checks if the `vision_transformer` module is available in torchvision,
which requires torchvision version >= 0.12. It also handles exceptions related to
CUDA version mismatches and installation issues.

Returns:
True if the Vision Transformer (ViT) models are available in torchvision,
otherwise False.
"""
try:
import torchvision.models.vision_transformer # Requires torchvision >=0.12
import torchvision.models.vision_transformer # Requires torchvision >=0.12.
except (
RuntimeError, # Different CUDA versions for torch and torchvision
OSError, # Different CUDA versions for torch and torchvision (old)
ImportError, # No installation or old version of torchvision
RuntimeError, # Different CUDA versions for torch and torchvision.
OSError, # Different CUDA versions for torch and torchvision (old).
ImportError, # No installation or old version of torchvision.
):
return False
else:
return True
return True


@functools.lru_cache(maxsize=1)
def timm_vit_available() -> bool:
"""Checks if Vision Transformer (ViT) models are available in the timm library.

This function checks if the `vision_transformer` module and `LayerType` from timm
are available, which requires timm version >= 0.3.3 and >= 0.9.9, respectively.

Returns:
True if the Vision Transformer (ViT) models are available in timm,
otherwise False.

"""
try:
import timm.models.vision_transformer # Requires timm >= 0.3.3
from timm.layers import LayerType # Requires timm >= 0.9.9
except ImportError:
return False
else:
return True
return True
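A typical call site guards an optional import behind these checks; because both functions are wrapped in functools.lru_cache(maxsize=1), repeated calls cost a single cached lookup. A minimal sketch of the guard pattern (vit_b_16 is a standard torchvision model used here for illustration, not something this diff introduces):

    from lightly.utils.dependency import torchvision_vit_available

    if torchvision_vit_available():
        # Safe to import: torchvision >= 0.12 with matching CUDA builds.
        from torchvision.models import vit_b_16

        model = vit_b_16()
    else:
        raise RuntimeError("ViT models require torchvision >= 0.12.")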
13 changes: 8 additions & 5 deletions lightly/utils/dist.py
@@ -8,19 +8,19 @@
class GatherLayer(torch.autograd.Function):
"""Gather tensors from all processes, supporting backward propagation.

This code was taken and adapted from here:
Adapted from the Solo-Learn project:
https://github.com/vturrisi/solo-learn/blob/b69b4bd27472593919956d9ac58902a301537a4d/solo/utils/misc.py#L187

"""

@staticmethod
def forward(ctx, input: torch.Tensor) -> Tuple[torch.Tensor, ...]: # type: ignore
def forward(ctx: FunctionCtx, input: torch.Tensor) -> Tuple[torch.Tensor, ...]: # type: ignore
output = [torch.empty_like(input) for _ in range(dist.get_world_size())]
dist.all_gather(output, input)
return tuple(output)

@staticmethod
def backward(ctx, *grads) -> torch.Tensor: # type: ignore
def backward(ctx: FunctionCtx, *grads: torch.Tensor) -> torch.Tensor: # type: ignore
all_gradients = torch.stack(grads)
dist.all_reduce(all_gradients)
grad_out = all_gradients[dist.get_rank()]
@@ -38,7 +38,7 @@ def world_size() -> int:


def gather(input: torch.Tensor) -> Tuple[torch.Tensor]:
"""Gathers this tensor from all processes. Supports backprop."""
"""Gathers a tensor from all processes and supports backpropagation."""
return GatherLayer.apply(input) # type: ignore[no-any-return]


@@ -62,6 +62,9 @@ def eye_rank(n: int, device: Optional[torch.device] = None) -> torch.Tensor:
device:
Device on which the matrix should be created.

Returns:
A tensor with the appropriate diagonal filled for this rank.

"""
rows = torch.arange(n, device=device, dtype=torch.long)
cols = rows + rank() * n
@@ -74,7 +77,7 @@ def eye_rank(n: int, device: Optional[torch.device] = None) -> torch.Tensor:


def rank_zero_only(fn: Callable[..., R]) -> Callable[..., Optional[R]]:
"""Decorator that only runs the function on the process with rank 0.
"""Decorator to ensure the function only runs on the process with rank 0.

Example:
>>> @rank_zero_only
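Taken together, these helpers let a contrastive loss see negatives from every GPU while keeping gradients correct. A hedged sketch of the pattern (assuming torch.distributed is initialized under DDP; the batch size, embedding dimension, and mask shape follow the docstrings above and are illustrative):

    import torch
    from lightly.utils import dist

    batch_size, dim = 4, 128
    z = torch.randn(batch_size, dim, requires_grad=True)  # local embeddings

    # Gather embeddings from all processes; GatherLayer keeps backprop intact.
    z_all = torch.cat(dist.gather(z), dim=0)  # (world_size * batch_size, dim)

    # Boolean mask marking this rank's own samples on the global diagonal,
    # e.g. to exclude self-similarities from the negatives.
    mask = dist.eye_rank(batch_size)  # (batch_size, world_size * batch_size)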
23 changes: 15 additions & 8 deletions lightly/utils/embeddings_2d.py
@@ -1,4 +1,4 @@
""" Transform embeddings to two-dimensional space for visualization. """
"""Transforms embeddings to two-dimensional space for visualization."""

# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved
@@ -21,13 +21,18 @@
Number of principal components to keep.
eps:
Epsilon for numerical stability.
mean:
Mean of the data.
w:
Eigenvectors of the covariance matrix.

"""

def __init__(self, n_components: int = 2, eps: float = 1e-10):
self.n_components = n_components
self.eps = eps

self.mean: Optional[NDArray[np.float32]] = None
self.w: Optional[NDArray[np.float32]] = None
self.eps = eps

def fit(self, X: NDArray[np.float32]) -> PCA:
"""Fits PCA to data in X.
@@ -37,7 +42,7 @@
Datapoints stored in numpy array of size n x d.

Returns:
PCA object to transform datapoints.
PCA: The fitted PCA object to transform data points.

"""
X = X.astype(np.float32)
@@ -46,7 +51,7 @@
X = X - self.mean + self.eps
cov = np.cov(X.T) / X.shape[0]
v, w = np.linalg.eig(cov)
idx = v.argsort()[::-1]
idx = v.argsort()[::-1] # Sort eigenvalues in descending order

v, w = v[idx], w[:, idx]
self.w = w
return self
@@ -62,10 +67,13 @@
Numpy array of n x p datapoints where p <= d.

Raises:
ValueError: If PCA was not fitted before.
ValueError:
If PCA is not fitted before calling this method.

"""
if self.mean is None or self.w is None:
raise ValueError("PCA not fitted yet. Call fit() before transform().")

X = X.astype(np.float32)
X = X - self.mean + self.eps
transformed: NDArray[np.float32] = X.dot(self.w)[:, : self.n_components]
Expand All @@ -77,7 +85,7 @@
n_components: int = 2,
fraction: Optional[float] = None,
) -> PCA:
"""Fits PCA to randomly selected subset of embeddings.
"""Fits PCA to a randomly selected subset of embeddings.

For large datasets, it can be unfeasible to perform PCA on the whole data.
This method can fit a PCA on a fraction of the embeddings in order to save
@@ -101,8 +109,7 @@
"""
if fraction is not None:
if fraction < 0.0 or fraction > 1.0:
msg = f"fraction must be in [0, 1] but was {fraction}."
raise ValueError(msg)
raise ValueError(f"fraction must be in [0, 1] but was {fraction}.")


N = embeddings.shape[0]
n = N if fraction is None else min(N, int(N * fraction))
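A short usage sketch tying this file's pieces together (the module-level helper shown in the last hunk is named fit_pca in the lightly source and is assumed here to take the embeddings array as its first argument; array sizes are illustrative):

    import numpy as np
    from lightly.utils.embeddings_2d import PCA, fit_pca

    embeddings = np.random.randn(1000, 64).astype(np.float32)

    # Fit on all embeddings, then project to two dimensions for plotting.
    pca = PCA(n_components=2).fit(embeddings)
    embeddings_2d = pca.transform(embeddings)  # shape (1000, 2)

    # For large datasets, fit on a random 10% subset instead.
    pca = fit_pca(embeddings, n_components=2, fraction=0.1)
    embeddings_2d = pca.transform(embeddings)

Calling transform() before fit() now raises the ValueError added in this diff instead of failing on a None attribute.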