Skip to content

Commit

Permalink
Cleanup docstrings in lightly/utils subpackage (#1698)
Browse files Browse the repository at this point in the history
  • Loading branch information
HarshitVashisht11 authored Oct 18, 2024
1 parent e80dda7 commit c9b84fd
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 62 deletions.
60 changes: 46 additions & 14 deletions lightly/utils/bounding_box.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
""" Bounding Box Utils """
"""Bounding Box Utils"""

from __future__ import annotations

Expand Down Expand Up @@ -31,17 +31,26 @@ class BoundingBox:
>>> # (x0, y0, x1, y1) = (10, 20, 30, 40)
>>> W, H = 100, 100 # get image shape
>>> bbox = BoundingBox(10 / W, 20 / H, 30 / W, 40 / H)
"""

def __init__(
self, x0: float, y0: float, x1: float, y1: float, clip_values: bool = True
):
"""
clip_values:
Set to true to clip the values into [0, 1] instead of raising an error if they lie outside.
"""
"""Initializes a BoundingBox object.
Args:
x0:
x0 coordinate relative to image width.
y0:
y0 coordinate relative to image height.
x1:
x1 coordinate relative to image width.
y1:
y1 coordinate relative to image height.
clip_values:
If True, clips the coordinates to [0, 1].
"""
if clip_values:

def clip_to_0_1(value: float) -> float:
Expand All @@ -60,14 +69,12 @@ def clip_to_0_1(value: float) -> float:

if x0 >= x1:
raise ValueError(
f"x0 must be smaller than x1 for bounding box "
f"[{x0}, {y0}, {x1}, {y1}]"
f"x0 must be smaller than x1 for bounding box [{x0}, {y0}, {x1}, {y1}]"
)

if y0 >= y1:
raise ValueError(
"y0 must be smaller than y1 for bounding box "
f"[{x0}, {y0}, {x1}, {y1}]"
f"y0 must be smaller than y1 for bounding box [{x0}, {y0}, {x1}, {y1}]"
)

self.x0 = x0
Expand All @@ -77,7 +84,20 @@ def clip_to_0_1(value: float) -> float:

@classmethod
def from_x_y_w_h(cls, x: float, y: float, w: float, h: float) -> BoundingBox:
"""Helper to convert from bounding box format with width and height.
"""Creates a BoundingBox from x, y, width, and height.
Args:
x:
x coordinate of the top-left corner relative to image width.
y:
y coordinate of the top-left corner relative to image height.
w:
Width of the bounding box relative to image width.
h:
Height of the bounding box relative to image height.
Returns:
BoundingBox: A BoundingBox instance.
Examples:
>>> bbox = BoundingBox.from_x_y_w_h(0.1, 0.2, 0.2, 0.2)
Expand All @@ -89,11 +109,23 @@ def from_x_y_w_h(cls, x: float, y: float, w: float, h: float) -> BoundingBox:
def from_yolo_label(
cls, x_center: float, y_center: float, w: float, h: float
) -> BoundingBox:
"""Helper to convert from yolo label format
x_center, y_center, w, h --> x0, y0, x1, y1
"""Creates a BoundingBox from YOLO label format.
Args:
x_center:
x coordinate of the center relative to image width.
y_center:
y coordinate of the center relative to image height.
w:
Width of the bounding box relative to image width.
h:
Height of the bounding box relative to image height.
Returns:
BoundingBox: A BoundingBox instance.
Examples:
>>> bbox = BoundingBox.from_yolo(0.5, 0.4, 0.2, 0.3)
>>> bbox = BoundingBox.from_yolo_label(0.5, 0.4, 0.2, 0.3)
"""
return cls(
Expand Down
33 changes: 20 additions & 13 deletions lightly/utils/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"'pip install lightly[matplotlib]'."
)
except ImportError as ex:
# Matplotlib import can fail if an incompatible dateutil version is installed.
plt = ex


Expand All @@ -24,9 +23,9 @@ def std_of_l2_normalized(z: torch.Tensor) -> torch.Tensor:
"""Calculates the mean of the standard deviation of z along each dimension.
This measure was used by [0] to determine the level of collapse of the
learned representations. If the returned number is 0., the outputs z have
collapsed to a constant vector. "If the output z has a zero-mean isotropic
Gaussian distribution" [0], the returned number should be close to 1/sqrt(d)
learned representations. If the returned value is 0., the outputs z have
collapsed to a constant vector. If the output z has a zero-mean isotropic
Gaussian distribution [0], the returned value should be close to 1/sqrt(d),
where d is the dimensionality of the output.
[0]: https://arxiv.org/abs/2011.10566
Expand All @@ -38,9 +37,7 @@ def std_of_l2_normalized(z: torch.Tensor) -> torch.Tensor:
Returns:
The mean of the standard deviation of the l2 normalized tensor z along
each dimension.
"""

if len(z.shape) != 2:
raise ValueError(
f"Input tensor must have two dimensions but has {len(z.shape)}!"
Expand All @@ -53,8 +50,18 @@ def std_of_l2_normalized(z: torch.Tensor) -> torch.Tensor:
def apply_transform_without_normalize(
image: Image.Image,
transform,
):
"""Applies the transform to the image but skips ToTensor and Normalize."""
) -> Image.Image:
"""Applies the transform to the image but skips ToTensor and Normalize.
Args:
image:
The input PIL image.
transform:
The transformation to apply, excluding ToTensor and Normalize.
Returns:
The transformed image.
"""
skippable_transforms = (
torchvision.transforms.ToTensor,
torchvision.transforms.Normalize,
Expand All @@ -70,10 +77,10 @@ def apply_transform_without_normalize(
def generate_grid_of_augmented_images(
input_images: List[Image.Image],
collate_function: Union[BaseCollateFunction, MultiViewCollateFunction],
):
) -> List[List[Image.Image]]:
"""Returns a grid of augmented images. Images in a column belong together.
This function ignores the transforms ToTensor and Normalize for visualization purposes.
This function ignores the ToTensor and Normalize transforms for visualization purposes.
Args:
input_images:
Expand Down Expand Up @@ -116,9 +123,9 @@ def plot_augmented_images(
input_images: List[Image.Image],
collate_function: Union[BaseCollateFunction, MultiViewCollateFunction],
):
"""Returns a figure showing original images in the left column and augmented images to their right.
"""Plots original images and augmented images in a figure.
This function ignores the transforms ToTensor and Normalize for visualization purposes.
This function ignores the ToTensor and Normalize transforms for visualization purposes.
Args:
input_images:
Expand All @@ -134,7 +141,6 @@ def plot_augmented_images(
MultiViewCollateFunctions all the generated views are shown.
"""

_check_matplotlib_available()

if len(input_images) == 0:
Expand Down Expand Up @@ -166,5 +172,6 @@ def plot_augmented_images(


def _check_matplotlib_available() -> None:
    """Raises the stored import error if matplotlib could not be imported.

    The module-level name `plt` is bound to the caught exception instance when
    the matplotlib import at the top of the module fails. In that case the
    exception is re-raised here; otherwise the function returns without doing
    anything.

    Raises:
        Exception:
            The exception instance stored in `plt`, if there is one.
    """
    if not isinstance(plt, Exception):
        return
    raise plt
34 changes: 26 additions & 8 deletions lightly/utils/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,42 @@

@functools.lru_cache(maxsize=1)
def torchvision_vit_available() -> bool:
    """Checks if Vision Transformer (ViT) models are available in torchvision.

    Tries to import the `torchvision.models.vision_transformer` module, which
    requires torchvision >= 0.12. Import failures caused by CUDA version
    mismatches or a missing/old installation are caught and reported as
    "not available". The result is cached, so the import is attempted only
    once until the cache is cleared.

    Returns:
        True if the Vision Transformer (ViT) models are available in
        torchvision, otherwise False.
    """
    try:
        import torchvision.models.vision_transformer  # Requires torchvision >=0.12.
    except (
        RuntimeError,  # Different CUDA versions for torch and torchvision.
        OSError,  # Different CUDA versions for torch and torchvision (old).
        ImportError,  # No installation or old version of torchvision.
    ):
        return False
    return True


@functools.lru_cache(maxsize=1)
def timm_vit_available() -> bool:
    """Checks if Vision Transformer (ViT) models are available in the timm library.

    Tries to import the `timm.models.vision_transformer` module and
    `LayerType` from `timm.layers`, which require timm >= 0.3.3 and
    timm >= 0.9.9, respectively. The result is cached, so the imports are
    attempted only once until the cache is cleared.

    Returns:
        True if both imports succeed, otherwise False.
    """
    try:
        import timm.models.vision_transformer  # Requires timm >= 0.3.3
        from timm.layers import LayerType  # Requires timm >= 0.9.9
    except ImportError:
        return False
    return True
13 changes: 8 additions & 5 deletions lightly/utils/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@
class GatherLayer(torch.autograd.Function):
"""Gather tensors from all processes, supporting backward propagation.
This code was taken and adapted from here:
Adapted from the Solo-Learn project:
https://github.com/vturrisi/solo-learn/blob/b69b4bd27472593919956d9ac58902a301537a4d/solo/utils/misc.py#L187
"""

@staticmethod
def forward(ctx, input: torch.Tensor) -> Tuple[torch.Tensor, ...]: # type: ignore
def forward(ctx: FunctionCtx, input: torch.Tensor) -> Tuple[torch.Tensor, ...]: # type: ignore
output = [torch.empty_like(input) for _ in range(dist.get_world_size())]
dist.all_gather(output, input)
return tuple(output)

@staticmethod
def backward(ctx, *grads) -> torch.Tensor: # type: ignore
def backward(ctx: FunctionCtx, *grads: torch.Tensor) -> torch.Tensor: # type: ignore
all_gradients = torch.stack(grads)
dist.all_reduce(all_gradients)
grad_out = all_gradients[dist.get_rank()]
Expand All @@ -38,7 +38,7 @@ def world_size() -> int:


def gather(input: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    """Gathers a tensor from all processes and supports backpropagation.

    Thin wrapper around `GatherLayer.apply`.

    Args:
        input:
            The local tensor to gather from every process.

    Returns:
        The result of `GatherLayer.apply(input)`; `GatherLayer.forward`
        builds a tuple with one tensor per process in the process group.
    """
    return GatherLayer.apply(input)  # type: ignore[no-any-return]


Expand All @@ -62,6 +62,9 @@ def eye_rank(n: int, device: Optional[torch.device] = None) -> torch.Tensor:
device:
Device on which the matrix should be created.
Returns:
A tensor with the appropriate diagonal filled for this rank.
"""
rows = torch.arange(n, device=device, dtype=torch.long)
cols = rows + rank() * n
Expand All @@ -74,7 +77,7 @@ def eye_rank(n: int, device: Optional[torch.device] = None) -> torch.Tensor:


def rank_zero_only(fn: Callable[..., R]) -> Callable[..., Optional[R]]:
"""Decorator that only runs the function on the process with rank 0.
"""Decorator to ensure the function only runs on the process with rank 0.
Example:
>>> @rank_zero_only
Expand Down
23 changes: 15 additions & 8 deletions lightly/utils/embeddings_2d.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
""" Transform embeddings to two-dimensional space for visualization. """
"""Transforms embeddings to two-dimensional space for visualization."""

# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved
Expand All @@ -21,13 +21,18 @@ class PCA(object):
Number of principal components to keep.
eps:
Epsilon for numerical stability.
mean:
Mean of the data.
w:
Eigenvectors of the covariance matrix.
"""

def __init__(self, n_components: int = 2, eps: float = 1e-10):
self.n_components = n_components
self.eps = eps
self.mean: Optional[NDArray[np.float32]] = None
self.w: Optional[NDArray[np.float32]] = None
self.eps = eps

def fit(self, X: NDArray[np.float32]) -> PCA:
"""Fits PCA to data in X.
Expand All @@ -37,7 +42,7 @@ def fit(self, X: NDArray[np.float32]) -> PCA:
Datapoints stored in numpy array of size n x d.
Returns:
PCA object to transform datapoints.
PCA: The fitted PCA object to transform data points.
"""
X = X.astype(np.float32)
Expand All @@ -46,7 +51,7 @@ def fit(self, X: NDArray[np.float32]) -> PCA:
X = X - self.mean + self.eps
cov = np.cov(X.T) / X.shape[0]
v, w = np.linalg.eig(cov)
idx = v.argsort()[::-1]
idx = v.argsort()[::-1] # Sort eigenvalues in descending order
v, w = v[idx], w[:, idx]
self.w = w
return self
Expand All @@ -62,10 +67,13 @@ def transform(self, X: NDArray[np.float32]) -> NDArray[np.float32]:
Numpy array of n x p datapoints where p <= d.
Raises:
ValueError: If PCA was not fitted before.
ValueError:
If PCA is not fitted before calling this method.
"""
if self.mean is None or self.w is None:
raise ValueError("PCA not fitted yet. Call fit() before transform().")

X = X.astype(np.float32)
X = X - self.mean + self.eps
transformed: NDArray[np.float32] = X.dot(self.w)[:, : self.n_components]
Expand All @@ -77,7 +85,7 @@ def fit_pca(
n_components: int = 2,
fraction: Optional[float] = None,
) -> PCA:
"""Fits PCA to randomly selected subset of embeddings.
"""Fits PCA to a randomly selected subset of embeddings.
For large datasets, it can be unfeasible to perform PCA on the whole data.
This method can fit a PCA on a fraction of the embeddings in order to save
Expand All @@ -101,8 +109,7 @@ def fit_pca(
"""
if fraction is not None:
if fraction < 0.0 or fraction > 1.0:
msg = f"fraction must be in [0, 1] but was {fraction}."
raise ValueError(msg)
raise ValueError(f"fraction must be in [0, 1] but was {fraction}.")

N = embeddings.shape[0]
n = N if fraction is None else min(N, int(N * fraction))
Expand Down
Loading

0 comments on commit c9b84fd

Please sign in to comment.