Skip to content

Commit

Permalink
Enhancement of SSIM/MS-SSIM and BRISQUE, Refactoring of tests for SSI…
Browse files Browse the repository at this point in the history
…M/MS-SSIM and BRISQUE (#134)

* tests(ssim): fix the randomly failing test

Signed-off-by: Sergey Kastryulin <[email protected]>

* refactoring(ssim): changes to simplify the ssim/ms-ssim

* refactoring(test): changes ssim/ms-ssim to check values on real images

* refactoring(ssim): minor

* refactoring(tests): changes (ms)ssim tests to for better readability

* refactoring(tests): fix memory consumption

* refactoring(tests): docs

* refactoring(tests): fix BRISQUE tests on real images

* refactoring(tests): small fix for better utility

* refact(ssim/brisque): changes proposed by @snk4tr and @zakajd

* refact(tests): Hit 100% coverage for `TVLoss`, 'utls.py'.

* docs: groomed to the single format

* minor(all): changes after merge

* release_commit: v0.5.0

Co-authored-by: Sergey Kastryulin <[email protected]>
  • Loading branch information
denproc and snk4tr authored Jul 14, 2020
1 parent f4b6c68 commit eaeb38b
Show file tree
Hide file tree
Showing 11 changed files with 658 additions and 898 deletions.
2 changes: 1 addition & 1 deletion piq/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.4.1"
__version__ = "0.5.0"

from .ssim import ssim, multi_scale_ssim, SSIMLoss, MultiScaleSSIMLoss
from .msid import MSID
Expand Down
231 changes: 105 additions & 126 deletions piq/brisque.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,109 @@
from torch.utils.model_zoo import load_url
import torch.nn.functional as F
from piq.utils import _adjust_dimensions, _validate_input
from piq.functional import rgb2yiq
from piq.functional import rgb2yiq, gaussian_filter


def brisque(x: torch.Tensor,
kernel_size: int = 7, kernel_sigma: float = 7 / 6,
data_range: Union[int, float] = 1., reduction: str = 'mean',
interpolation: str = 'nearest') -> torch.Tensor:
r"""Interface of BRISQUE index.
Args:
x: Batch of images. Required to be 2D (H, W), 3D (C,H,W) or 4D (N,C,H,W), channels first.
kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
kernel_sigma: Sigma of normal distribution.
data_range: Value range of input images (usually 1.0 or 255).
reduction: Reduction over samples in batch: "mean"|"sum"|"none".
interpolation: Interpolation to be used for scaling.
Returns:
Value of BRISQUE index.
References:
.. [1] Anish Mittal et al. "No-Reference Image Quality Assessment in the Spatial Domain",
https://live.ece.utexas.edu/publications/2012/TIP%20BRISQUE.pdf
"""
_validate_input(input_tensors=x, allow_5d=False)
x = _adjust_dimensions(input_tensors=x)

assert data_range >= x.max(), f'Expected data range greater or equal maximum value, got {data_range} and {x.max()}.'
x = x * 255. / data_range

if x.size(1) == 3:
x = rgb2yiq(x)[:, :1]
features = []
num_of_scales = 2
for _ in range(num_of_scales):
features.append(_natural_scene_statistics(x, kernel_size, kernel_sigma))
x = F.interpolate(x, size=(x.size(2) // 2, x.size(3) // 2), mode=interpolation)

features = torch.cat(features, dim=-1)
scaled_features = _scale_features(features)
score = _score_svr(scaled_features)
if reduction == 'none':
return score

return {'mean': score.mean,
'sum': score.sum
}[reduction](dim=0)


class BRISQUELoss(_Loss):
r"""Creates a criterion that measures the BRISQUE score for input :math:`x`.
:math:`x` is tensor of 2D (H, W), 3D (C,H,W) or 4D (N,C,H,W), channels first.
The sum operation still operates over all the elements, and divides by :math:`n`.
The division by :math:`n` can be avoided by setting ``reduction = 'sum'``.
Args:
kernel_size: By default, the mean and covariance of a pixel is obtained
by convolution with given filter_size.
kernel_sigma: Standard deviation for Gaussian kernel.
data_range: The difference between the maximum and minimum of the pixel value,
i.e., if for image x it holds min(x) = 0 and max(x) = 1, then data_range = 1.
The pixel value interval of both input and output should remain the same.
reduction: Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
``'mean'``: the sum of the output will be divided by the number of
elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``.
interpolation: Interpolation to be used for scaling.
Shape:
- Input: Required to be 2D (H, W), 3D (C,H,W) or 4D (N,C,H,W), channels first.
Examples::
>>> loss = BRISQUELoss()
>>> prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
>>> target = torch.rand(3, 3, 256, 256)
>>> output = loss(prediction)
>>> output.backward()
References:
.. [1] Anish Mittal et al. "No-Reference Image Quality Assessment in the Spatial Domain",
https://live.ece.utexas.edu/publications/2012/TIP%20BRISQUE.pdf
"""
def __init__(self, kernel_size: int = 7, kernel_sigma: float = 7 / 6,
data_range: Union[int, float] = 1., reduction: str = 'mean',
interpolation: str = 'nearest') -> None:
super().__init__()
self.reduction = reduction
self.kernel_size = kernel_size
self.kernel_sigma = kernel_sigma
self.data_range = data_range
self.interpolation = interpolation

def forward(self, prediction: torch.Tensor) -> torch.Tensor:
r"""Computation of BRISQUE score as a loss function.
Args:
prediction: Tensor of prediction of the network.
Returns:
Value of BRISQUE loss to be minimized.
"""
return brisque(prediction, reduction=self.reduction, kernel_size=self.kernel_size,
kernel_sigma=self.kernel_sigma, data_range=self.data_range, interpolation=self.interpolation)


def _ggd_parameters(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
Expand Down Expand Up @@ -65,24 +167,8 @@ def _aggd_parameters(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch
return solution, left_sigma.squeeze(dim=-1), right_sigma.squeeze(dim=-1)


def _gaussian_kernel2d(kernel_size: int = 7, sigma: float = 7 / 6) -> torch.Tensor:
r"""Returns 2D Gaussian kernel N(0,`sigma`)
Args:
kernel_size: Size
sigma: Sigma
Returns:
gaussian_kernel: 2D kernel with shape (kernel_size x kernel_size)
"""
x = torch.arange(- (kernel_size // 2), kernel_size // 2 + 1).view(1, kernel_size)
y = torch.arange(- (kernel_size // 2), kernel_size // 2 + 1).view(kernel_size, 1)
kernel = torch.exp(-(x * x + y * y) / (2.0 * sigma ** 2))
kernel = kernel / torch.sum(kernel)
return kernel


def _natural_scene_statistics(luma: torch.Tensor, kernel_size: int = 7, sigma: float = 7. / 6) -> torch.Tensor:
kernel = _gaussian_kernel2d(kernel_size=kernel_size, sigma=sigma).view(1, 1, kernel_size, kernel_size).to(luma)
kernel = gaussian_filter(size=kernel_size, sigma=sigma).view(1, 1, kernel_size, kernel_size).to(luma)
C = 1
mu = F.conv2d(luma, kernel, padding=kernel_size // 2)
mu_sq = mu ** 2
Expand Down Expand Up @@ -132,9 +218,7 @@ def _scale_features(features: torch.Tensor) -> torch.Tensor:


def _rbf_kernel(features: torch.Tensor, sv: torch.Tensor, gamma: float = 0.05) -> torch.Tensor:
features.unsqueeze_(dim=-1)
sv.unsqueeze_(dim=0)
dist = (features - sv).pow(2).sum(dim=1)
dist = (features.unsqueeze(dim=-1) - sv.unsqueeze(dim=0)).pow(2).sum(dim=1)
return torch.exp(- dist * gamma)


Expand All @@ -151,108 +235,3 @@ def _score_svr(features: torch.Tensor) -> torch.Tensor:
kernel_features = _rbf_kernel(features=features, sv=sv, gamma=gamma)
score = kernel_features @ sv_coef
return score - rho


def brisque(x: torch.Tensor,
kernel_size: int = 7, kernel_sigma: float = 7 / 6,
data_range: Union[int, float] = 1., reduction: str = 'mean',
interpolation: str = 'nearest') -> torch.Tensor:
r"""Interface of SBRISQUE index.
Args:
x: Batch of images. Required to be 2D (H, W), 3D (C,H,W) or 4D (N,C,H,W), channels first.
kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
kernel_sigma: Sigma of normal distribution.
data_range: Value range of input images (usually 1.0 or 255).
reduction: Reduction over samples in batch: "mean"|"sum"|"none".
interpolation: Interpolation to be used for scaling.
Returns:
Value of BRISQUE index.
References:
.. [1] Anish Mittal et al. "No-Reference Image Quality Assessment in the Spatial Domain",
https://live.ece.utexas.edu/publications/2012/TIP%20BRISQUE.pdf
"""
_validate_input(input_tensors=x, allow_5d=False)
x = _adjust_dimensions(input_tensors=x)

assert data_range >= x.max(), f'Expected data range greater or equal maximum value, got {data_range} and {x.max()}.'
x = x * 255. / data_range

if x.size(1) == 3:
x = rgb2yiq(x)[:, :1]
features = []
num_of_scales = 2
for _ in range(num_of_scales):
features.append(_natural_scene_statistics(x, kernel_size, kernel_sigma))
x = F.interpolate(x, scale_factor=0.5, mode=interpolation)

features = torch.cat(features, dim=-1)
scaled_features = _scale_features(features)
score = _score_svr(scaled_features)
if reduction == 'none':
return score

return {'mean': score.mean,
'sum': score.sum
}[reduction](dim=0)


class BRISQUELoss(_Loss):
r"""Creates a criterion that measures the BRISQUE score for input :math:`x`.
:math:`x` is tensor of 2D (H, W), 3D (C,H,W) or 4D (N,C,H,W), channels first.
The sum operation still operates over all the elements, and divides by :math:`n`.
The division by :math:`n` can be avoided by setting ``reduction = 'sum'``.
Args:
kernel_size: By default, the mean and covariance of a pixel is obtained
by convolution with given filter_size.
kernel_sigma: Standard deviation for Gaussian kernel.
data_range: The difference between the maximum and minimum of the pixel value,
i.e., if for image x it holds min(x) = 0 and max(x) = 1, then data_range = 1.
The pixel value interval of both input and output should remain the same.
reduction: Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
``'mean'``: the sum of the output will be divided by the number of
elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``.
interpolation: Interpolation to be used for scaling.
Shape:
- Input: Required to be 2D (H, W), 3D (C,H,W) or 4D (N,C,H,W), channels first.
Examples::
>>> loss = BRISQUELoss()
>>> prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
>>> target = torch.rand(3, 3, 256, 256)
>>> output = loss(prediction)
>>> output.backward()
References:
.. [1] Anish Mittal et al. "No-Reference Image Quality Assessment in the Spatial Domain",
https://live.ece.utexas.edu/publications/2012/TIP%20BRISQUE.pdf
"""
def __init__(self, kernel_size: int = 7, kernel_sigma: float = 7 / 6,
data_range: Union[int, float] = 1., reduction: str = 'mean',
interpolation: str = 'nearest') -> None:
super().__init__()
self.reduction = reduction
self.kernel_size = kernel_size
self.kernel_sigma = kernel_sigma
self.data_range = data_range
self.interpolation = interpolation

def forward(self, prediction: torch.Tensor) -> torch.Tensor:
r"""Computation of BRISQUE score as a loss function.
Args:
prediction: Tensor of prediction of the network.
Returns:
Value of BRISQUE loss to be minimized.
"""

return brisque(prediction, reduction=self.reduction, kernel_size=self.kernel_size,
kernel_sigma=self.kernel_sigma, data_range=self.data_range, interpolation=self.interpolation)
4 changes: 2 additions & 2 deletions piq/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from piq.functional.base import ifftshift, get_meshgrid, similarity_map, gradient_map
from piq.functional.colour_conversion import rgb2lmn, rgb2xyz, xyz2lab, rgb2lab, rgb2yiq
from piq.functional.filters import hann_filter, scharr_filter, prewitt_filter
from piq.functional.filters import hann_filter, scharr_filter, prewitt_filter, gaussian_filter
from piq.functional.layers import L2Pool2d


__all__ = [
'ifftshift', 'get_meshgrid', 'similarity_map', 'gradient_map',
'rgb2lmn', 'rgb2xyz', 'xyz2lab', 'rgb2lab', 'rgb2yiq',
'hann_filter', 'scharr_filter', 'prewitt_filter',
'hann_filter', 'scharr_filter', 'prewitt_filter', 'gaussian_filter',
'L2Pool2d',
]
19 changes: 8 additions & 11 deletions piq/functional/colour_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@


def rgb2lmn(x: torch.Tensor) -> torch.Tensor:
r"""
Convert a batch of RGB images to a batch of LMN images
r"""Convert a batch of RGB images to a batch of LMN images
Args:
x: Batch of 4D (N x 3 x H x W) images in RGB colour space.
Expand All @@ -21,8 +20,7 @@ def rgb2lmn(x: torch.Tensor) -> torch.Tensor:


def rgb2xyz(x: torch.Tensor) -> torch.Tensor:
r"""
Convert a batch of RGB images to a batch of XYZ images
r"""Convert a batch of RGB images to a batch of XYZ images
Args:
x: Batch of 4D (N x 3 x H x W) images in RGB colour space.
Expand All @@ -43,14 +41,14 @@ def rgb2xyz(x: torch.Tensor) -> torch.Tensor:
return x_xyz


def xyz2lab(x: torch.Tensor, illuminant='D50', observer='2') -> torch.Tensor:
r"""
Convert a batch of XYZ images to a batch of LAB images
def xyz2lab(x: torch.Tensor, illuminant: str = 'D50', observer: str = '2') -> torch.Tensor:
r"""Convert a batch of XYZ images to a batch of LAB images
Args:
x: Batch of 4D (N x 3 x H x W) images in XYZ colour space.
illuminant: {“A”, “D50”, “D55”, “D65”, “D75”, “E”}, optional. The name of the illuminant.
observer: {“2”, “10”}, optional. The aperture angle of the observer.
Returns:
Batch of 4D (N x 3 x H x W) images in LAB colour space.
"""
Expand Down Expand Up @@ -88,21 +86,20 @@ def xyz2lab(x: torch.Tensor, illuminant='D50', observer='2') -> torch.Tensor:


def rgb2lab(x: torch.Tensor, data_range: Union[int, float] = 255) -> torch.Tensor:
r"""
Convert a batch of RGB images to a batch of LAB images
r"""Convert a batch of RGB images to a batch of LAB images
Args:
x: Batch of 4D (N x 3 x H x W) images in RGB colour space.
data_range: dynamic range of the input image.
Returns:
Batch of 4D (N x 3 x H x W) images in LAB colour space.
"""
return xyz2lab(rgb2xyz(x / float(data_range)))


def rgb2yiq(x: torch.Tensor) -> torch.Tensor:
r"""
Convert a batch of RGB images to a batch of YIQ images
r"""Convert a batch of RGB images to a batch of YIQ images
Args:
x: Batch of 4D (N x 3 x H x W) images in RGB colour space.
Expand Down
18 changes: 18 additions & 0 deletions piq/functional/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,24 @@ def hann_filter(kernel_size) -> torch.Tensor:
return kernel.view(1, kernel_size, kernel_size) / kernel.sum()


def gaussian_filter(size: int, sigma: float) -> torch.Tensor:
r"""Returns 2D Gaussian kernel N(0,`sigma`^2)
Args:
size: Size of the lernel
sigma: Std of the distribution
Returns:
gaussian_kernel: 2D kernel with shape (1 x kernel_size x kernel_size)
"""
coords = torch.arange(size).to(dtype=torch.float32)
coords -= (size - 1) / 2.

g = coords ** 2
g = (- (g.unsqueeze(0) + g.unsqueeze(1)) / (2 * sigma ** 2)).exp()

g /= g.sum()
return g.unsqueeze(0)


# Gradient operator kernels
def scharr_filter() -> torch.Tensor:
r"""Utility function that returns a normalized 3x3 Scharr kernel in X direction
Expand Down
Loading

0 comments on commit eaeb38b

Please sign in to comment.