diff --git a/Histogram_refactoring.ipynb b/Histogram_refactoring.ipynb new file mode 100644 index 0000000..a1a725d --- /dev/null +++ b/Histogram_refactoring.ipynb @@ -0,0 +1,1070 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Refactoring HistoGAN: Validation and Benchmark\n", + "\n", + "The refactored histogram blocks extracted common functionality and -data into a separate base class.\n", + "Here we test, whether the results match the original implementation and assess potential performance benefits.\n", + "\n", + "## Reference Implementation\n", + "\n", + "Below are the reference implemenations of all three histogram blocks: RGB-uv, rg-chroma, and Lab." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import numpy as np\n", + "\n", + "EPS = 1e-6\n", + "\n", + "class OriginalRGBuvHistBlock(nn.Module):\n", + " def __init__(self, h=64, insz=150, resizing='interpolation',\n", + " method='inverse-quadratic', sigma=0.02, intensity_scale=True,\n", + " hist_boundary=None, green_only=False, device='cuda'):\n", + " super().__init__()\n", + " self.h = h\n", + " self.insz = insz\n", + " self.device = device\n", + " self.resizing = resizing\n", + " self.method = method\n", + " self.intensity_scale = intensity_scale\n", + " self.green_only = green_only\n", + " if hist_boundary is None:\n", + " hist_boundary = [-3, 3]\n", + " hist_boundary.sort()\n", + " self.hist_boundary = hist_boundary\n", + " if self.method == 'thresholding':\n", + " self.eps = (abs(hist_boundary[0]) + abs(hist_boundary[1])) / h\n", + " else:\n", + " self.sigma = sigma\n", + "\n", + " def forward(self, x):\n", + " x = torch.clamp(x, 0, 1)\n", + " if x.shape[2] > self.insz or x.shape[3] > self.insz:\n", + " if self.resizing == 'interpolation':\n", + " x_sampled = F.interpolate(x, size=(self.insz, self.insz),\n", + " mode='bilinear', align_corners=False)\n", + " elif self.resizing == 'sampling':\n", + " inds_1 = torch.LongTensor(\n", + " np.linspace(0, x.shape[2], self.h, endpoint=False)).to(\n", + " device=self.device)\n", + " inds_2 = torch.LongTensor(\n", + " np.linspace(0, x.shape[3], self.h, endpoint=False)).to(\n", + " device=self.device)\n", + " x_sampled = x.index_select(2, inds_1)\n", + " x_sampled = x_sampled.index_select(3, inds_2)\n", + " else:\n", + " raise Exception(\n", + " f'Wrong resizing method. It should be: interpolation or sampling. '\n", + " f'But the given value is {self.resizing}.')\n", + " else:\n", + " x_sampled = x\n", + "\n", + " L = x_sampled.shape[0] # size of mini-batch\n", + " if x_sampled.shape[1] > 3:\n", + " x_sampled = x_sampled[:, :3, :, :]\n", + " X = torch.unbind(x_sampled, dim=0)\n", + " hists = torch.zeros((x_sampled.shape[0], 1 + int(not self.green_only) * 2,\n", + " self.h, self.h)).to(device=self.device)\n", + " for l in range(L):\n", + " I = torch.t(torch.reshape(X[l], (3, -1)))\n", + " II = torch.pow(I, 2)\n", + " if self.intensity_scale:\n", + " Iy = torch.unsqueeze(torch.sqrt(II[:, 0] + II[:, 1] + II[:, 2] + EPS),\n", + " dim=1)\n", + " else:\n", + " Iy = 1\n", + " if not self.green_only:\n", + " Iu0 = torch.unsqueeze(torch.log(I[:, 0] + EPS) - torch.log(I[:, 1] +\n", + " EPS), dim=1)\n", + " Iv0 = torch.unsqueeze(torch.log(I[:, 0] + EPS) - torch.log(I[:, 2] +\n", + " EPS), dim=1)\n", + " diff_u0 = abs(\n", + " Iu0 - torch.unsqueeze(torch.tensor(np.linspace(\n", + " self.hist_boundary[0], self.hist_boundary[1], num=self.h)),\n", + " dim=0).to(self.device))\n", + " diff_v0 = abs(\n", + " Iv0 - torch.unsqueeze(torch.tensor(np.linspace(\n", + " self.hist_boundary[0], self.hist_boundary[1], num=self.h)),\n", + " dim=0).to(self.device))\n", + " if self.method == 'thresholding':\n", + " diff_u0 = torch.reshape(diff_u0, (-1, self.h)) <= self.eps / 2\n", + " diff_v0 = torch.reshape(diff_v0, (-1, self.h)) <= self.eps / 2\n", + " elif self.method == 'RBF':\n", + " diff_u0 = torch.pow(torch.reshape(diff_u0, (-1, self.h)),\n", + " 2) / self.sigma ** 2\n", + " diff_v0 = torch.pow(torch.reshape(diff_v0, (-1, self.h)),\n", + " 2) / self.sigma ** 2\n", + " diff_u0 = torch.exp(-diff_u0) # Radial basis function\n", + " diff_v0 = torch.exp(-diff_v0)\n", + " elif self.method == 'inverse-quadratic':\n", + " diff_u0 = torch.pow(torch.reshape(diff_u0, (-1, self.h)),\n", + " 2) / self.sigma ** 2\n", + " diff_v0 = torch.pow(torch.reshape(diff_v0, (-1, self.h)),\n", + " 2) / self.sigma ** 2\n", + " diff_u0 = 1 / (1 + diff_u0) # Inverse quadratic\n", + " diff_v0 = 1 / (1 + diff_v0)\n", + " else:\n", + " raise Exception(\n", + " f'Wrong kernel method. It should be either thresholding, RBF,'\n", + " f' inverse-quadratic. But the given value is {self.method}.')\n", + " diff_u0 = diff_u0.type(torch.float32)\n", + " diff_v0 = diff_v0.type(torch.float32)\n", + " a = torch.t(Iy * diff_u0)\n", + " hists[l, 0, :, :] = torch.mm(a, diff_v0)\n", + "\n", + " Iu1 = torch.unsqueeze(torch.log(I[:, 1] + EPS) - torch.log(I[:, 0] + EPS),\n", + " dim=1)\n", + " Iv1 = torch.unsqueeze(torch.log(I[:, 1] + EPS) - torch.log(I[:, 2] + EPS),\n", + " dim=1)\n", + " diff_u1 = abs(\n", + " Iu1 - torch.unsqueeze(torch.tensor(np.linspace(\n", + " self.hist_boundary[0], self.hist_boundary[1], num=self.h)),\n", + " dim=0).to(self.device))\n", + " diff_v1 = abs(\n", + " Iv1 - torch.unsqueeze(torch.tensor(np.linspace(\n", + " self.hist_boundary[0], self.hist_boundary[1], num=self.h)),\n", + " dim=0).to(self.device))\n", + "\n", + " if self.method == 'thresholding':\n", + " diff_u1 = torch.reshape(diff_u1, (-1, self.h)) <= self.eps / 2\n", + " diff_v1 = torch.reshape(diff_v1, (-1, self.h)) <= self.eps / 2\n", + " elif self.method == 'RBF':\n", + " diff_u1 = torch.pow(torch.reshape(diff_u1, (-1, self.h)),\n", + " 2) / self.sigma ** 2\n", + " diff_v1 = torch.pow(torch.reshape(diff_v1, (-1, self.h)),\n", + " 2) / self.sigma ** 2\n", + " diff_u1 = torch.exp(-diff_u1) # Gaussian\n", + " diff_v1 = torch.exp(-diff_v1)\n", + " elif self.method == 'inverse-quadratic':\n", + " diff_u1 = torch.pow(torch.reshape(diff_u1, (-1, self.h)),\n", + " 2) / self.sigma ** 2\n", + " diff_v1 = torch.pow(torch.reshape(diff_v1, (-1, self.h)),\n", + " 2) / self.sigma ** 2\n", + " diff_u1 = 1 / (1 + diff_u1) # Inverse quadratic\n", + " diff_v1 = 1 / (1 + diff_v1)\n", + "\n", + " diff_u1 = diff_u1.type(torch.float32)\n", + " diff_v1 = diff_v1.type(torch.float32)\n", + " a = torch.t(Iy * diff_u1)\n", + " if not self.green_only:\n", + " hists[l, 1, :, :] = torch.mm(a, diff_v1)\n", + " else:\n", + " hists[l, 0, :, :] = torch.mm(a, diff_v1)\n", + "\n", + " if not self.green_only:\n", + " Iu2 = torch.unsqueeze(torch.log(I[:, 2] + EPS) - torch.log(I[:, 0] +\n", + " EPS), dim=1)\n", + " Iv2 = torch.unsqueeze(torch.log(I[:, 2] + EPS) - torch.log(I[:, 1] +\n", + " EPS), dim=1)\n", + " diff_u2 = abs(\n", + " Iu2 - torch.unsqueeze(torch.tensor(np.linspace(\n", + " self.hist_boundary[0], self.hist_boundary[1], num=self.h)),\n", + " dim=0).to(self.device))\n", + " diff_v2 = abs(\n", + " Iv2 - torch.unsqueeze(torch.tensor(np.linspace(\n", + " self.hist_boundary[0], self.hist_boundary[1], num=self.h)),\n", + " dim=0).to(self.device))\n", + " if self.method == 'thresholding':\n", + " diff_u2 = torch.reshape(diff_u2, (-1, self.h)) <= self.eps / 2\n", + " diff_v2 = torch.reshape(diff_v2, (-1, self.h)) <= self.eps / 2\n", + " elif self.method == 'RBF':\n", + " diff_u2 = torch.pow(torch.reshape(diff_u2, (-1, self.h)),\n", + " 2) / self.sigma ** 2\n", + " diff_v2 = torch.pow(torch.reshape(diff_v2, (-1, self.h)),\n", + " 2) / self.sigma ** 2\n", + " diff_u2 = torch.exp(-diff_u2) # Gaussian\n", + " diff_v2 = torch.exp(-diff_v2)\n", + " elif self.method == 'inverse-quadratic':\n", + " diff_u2 = torch.pow(torch.reshape(diff_u2, (-1, self.h)),\n", + " 2) / self.sigma ** 2\n", + " diff_v2 = torch.pow(torch.reshape(diff_v2, (-1, self.h)),\n", + " 2) / self.sigma ** 2\n", + " diff_u2 = 1 / (1 + diff_u2) # Inverse quadratic\n", + " diff_v2 = 1 / (1 + diff_v2)\n", + " diff_u2 = diff_u2.type(torch.float32)\n", + " diff_v2 = diff_v2.type(torch.float32)\n", + " a = torch.t(Iy * diff_u2)\n", + " hists[l, 2, :, :] = torch.mm(a, diff_v2)\n", + "\n", + " # normalization\n", + " hists_normalized = hists / (\n", + " ((hists.sum(dim=1)).sum(dim=1)).sum(dim=1).view(-1, 1, 1, 1) + EPS)\n", + "\n", + " return hists_normalized\n", + "\n", + "class OriginalrgChromaHistBlock(nn.Module):\n", + " def __init__(self, h=64, insz=150, resizing='interpolation',\n", + " method='inverse-quadratic', sigma=0.02, intensity_scale=False,\n", + " hist_boundary=None, device='cuda'):\n", + " super().__init__()\n", + " self.h = h\n", + " self.insz = insz\n", + " self.device = device\n", + " self.resizing = resizing\n", + " self.method = method\n", + " self.intensity_scale = intensity_scale\n", + " if hist_boundary is None:\n", + " hist_boundary = [0, 1]\n", + " hist_boundary.sort()\n", + " self.hist_boundary = hist_boundary\n", + " if self.method == 'thresholding':\n", + " self.eps = (abs(hist_boundary[0]) + abs(hist_boundary[1])) / h\n", + " else:\n", + " self.sigma = sigma\n", + "\n", + " def forward(self, x):\n", + " x = torch.clamp(x, 0, 1)\n", + " if x.shape[2] > self.insz or x.shape[3] > self.insz:\n", + " if self.resizing == 'interpolation':\n", + " x_sampled = F.interpolate(x, size=(self.insz, self.insz),\n", + " mode='bilinear', align_corners=False)\n", + " elif self.resizing == 'sampling':\n", + " inds_1 = torch.LongTensor(\n", + " np.linspace(0, x.shape[2], self.h, endpoint=False)).to(\n", + " device=self.device)\n", + " inds_2 = torch.LongTensor(\n", + " np.linspace(0, x.shape[3], self.h, endpoint=False)).to(\n", + " device=self.device)\n", + " x_sampled = x.index_select(2, inds_1)\n", + " x_sampled = x_sampled.index_select(3, inds_2)\n", + " else:\n", + " raise Exception(\n", + " f'Wrong resizing method. It should be: interpolation or sampling. '\n", + " f'But the given value is {self.resizing}.')\n", + " else:\n", + " x_sampled = x\n", + "\n", + " L = x_sampled.shape[0] # size of mini-batch\n", + " if x_sampled.shape[1] > 3:\n", + " x_sampled = x_sampled[:, :3, :, :]\n", + " X = torch.unbind(x_sampled, dim=0)\n", + " hists = torch.zeros((x_sampled.shape[0], 1, self.h, self.h)).to(\n", + " device=self.device)\n", + " for l in range(L):\n", + " I = torch.t(torch.reshape(X[l], (3, -1)))\n", + " II = torch.pow(I, 2)\n", + " if self.intensity_scale:\n", + " Iy = torch.unsqueeze(torch.sqrt(II[:, 0] + II[:, 1] + II[:, 2] + EPS),\n", + " dim=1)\n", + " else:\n", + " Iy = 1\n", + "\n", + " Ir = torch.unsqueeze(I[:, 0] / (torch.sum(I, dim=-1) + EPS), dim=1)\n", + " Ig = torch.unsqueeze(I[:, 1] / (torch.sum(I, dim=-1) + EPS), dim=1)\n", + "\n", + " diff_r = abs(Ir - torch.unsqueeze(torch.tensor(np.linspace(\n", + " self.hist_boundary[0], self.hist_boundary[1], num=self.h)),\n", + " dim=0).to(self.device))\n", + " diff_g = abs(Ig - torch.unsqueeze(torch.tensor(np.linspace(\n", + " self.hist_boundary[0], self.hist_boundary[1], num=self.h)),\n", + " dim=0).to(self.device))\n", + "\n", + " if self.method == 'thresholding':\n", + " diff_r = torch.reshape(diff_r, (-1, self.h)) <= self.eps / 2\n", + " diff_g = torch.reshape(diff_g, (-1, self.h)) <= self.eps / 2\n", + " elif self.method == 'RBF':\n", + " diff_r = torch.pow(torch.reshape(diff_r, (-1, self.h)),\n", + " 2) / self.sigma ** 2\n", + " diff_g = torch.pow(torch.reshape(diff_g, (-1, self.h)),\n", + " 2) / self.sigma ** 2\n", + " diff_r = torch.exp(-diff_r) # Gaussian\n", + " diff_g = torch.exp(-diff_g)\n", + " elif self.method == 'inverse-quadratic':\n", + " diff_r = torch.pow(torch.reshape(diff_r, (-1, self.h)),\n", + " 2) / self.sigma ** 2\n", + " diff_g = torch.pow(torch.reshape(diff_g, (-1, self.h)),\n", + " 2) / self.sigma ** 2\n", + " diff_r = 1 / (1 + diff_r) # Inverse quadratic\n", + " diff_g = 1 / (1 + diff_g)\n", + "\n", + " diff_r = diff_r.type(torch.float32)\n", + " diff_g = diff_g.type(torch.float32)\n", + " a = torch.t(Iy * diff_r)\n", + "\n", + " hists[l, 0, :, :] = torch.mm(a, diff_g)\n", + "\n", + " # normalization\n", + " hists_normalized = hists / (\n", + " ((hists.sum(dim=1)).sum(dim=1)).sum(dim=1).view(-1, 1, 1, 1) + EPS)\n", + "\n", + " return hists_normalized\n", + "\n", + "class OriginalLabHistBlock(nn.Module):\n", + " def __init__(self, h=64, insz=150, resizing='interpolation',\n", + " method='inverse-quadratic', sigma=0.02, intensity_scale=False,\n", + " hist_boundary=None, device='cuda'):\n", + " super().__init__()\n", + " self.h = h\n", + " self.insz = insz\n", + " self.device = device\n", + " self.resizing = resizing\n", + " self.method = method\n", + " self.intensity_scale = intensity_scale\n", + " if hist_boundary is None:\n", + " hist_boundary = [0, 1]\n", + " hist_boundary.sort()\n", + " self.hist_boundary = hist_boundary\n", + " if self.method == 'thresholding':\n", + " self.eps = (abs(hist_boundary[0]) + abs(hist_boundary[1])) / h\n", + " else:\n", + " self.sigma = sigma\n", + "\n", + " def forward(self, x):\n", + " x = torch.clamp(x, 0, 1)\n", + " if x.shape[2] > self.insz or x.shape[3] > self.insz:\n", + " if self.resizing == 'interpolation':\n", + " x_sampled = F.interpolate(x, size=(self.insz, self.insz),\n", + " mode='bilinear', align_corners=False)\n", + " elif self.resizing == 'sampling':\n", + " inds_1 = torch.LongTensor(\n", + " np.linspace(0, x.shape[2], self.h, endpoint=False)).to(\n", + " device=self.device)\n", + " inds_2 = torch.LongTensor(\n", + " np.linspace(0, x.shape[3], self.h, endpoint=False)).to(\n", + " device=self.device)\n", + " x_sampled = x.index_select(2, inds_1)\n", + " x_sampled = x_sampled.index_select(3, inds_2)\n", + " else:\n", + " raise Exception(\n", + " f'Wrong resizing method. It should be: interpolation or sampling. '\n", + " f'But the given value is {self.resizing}.')\n", + " else:\n", + " x_sampled = x\n", + "\n", + " L = x_sampled.shape[0] # size of mini-batch\n", + " if x_sampled.shape[1] > 3:\n", + " x_sampled = x_sampled[:, :3, :, :]\n", + " X = torch.unbind(x_sampled, dim=0)\n", + " hists = torch.zeros((x_sampled.shape[0], 1, self.h, self.h)).to(\n", + " device=self.device)\n", + " for l in range(L):\n", + " I = torch.t(torch.reshape(X[l], (3, -1)))\n", + " if self.intensity_scale:\n", + " Il = torch.unsqueeze(I[:, 0], dim=1)\n", + " else:\n", + " Il = 1\n", + "\n", + " Ia = torch.unsqueeze(I[:, 1], dim=1)\n", + " Ib = torch.unsqueeze(I[:, 2], dim=1)\n", + "\n", + " diff_a = abs(Ia - torch.unsqueeze(torch.tensor(np.linspace(\n", + " self.hist_boundary[0], self.hist_boundary[1], num=self.h)),\n", + " dim=0).to(self.device))\n", + " diff_b = abs(Ib - torch.unsqueeze(torch.tensor(np.linspace(\n", + " self.hist_boundary[0], self.hist_boundary[1], num=self.h)),\n", + " dim=0).to(self.device))\n", + "\n", + " if self.method == 'thresholding':\n", + " diff_a = torch.reshape(diff_a, (-1, self.h)) <= self.eps / 2\n", + " diff_b = torch.reshape(diff_b, (-1, self.h)) <= self.eps / 2\n", + " elif self.method == 'RBF':\n", + " diff_a = torch.pow(torch.reshape(diff_a, (-1, self.h)),\n", + " 2) / self.sigma ** 2\n", + " diff_b = torch.pow(torch.reshape(diff_b, (-1, self.h)),\n", + " 2) / self.sigma ** 2\n", + " diff_a = torch.exp(-diff_a) # Gaussian\n", + " diff_b = torch.exp(-diff_b)\n", + " elif self.method == 'inverse-quadratic':\n", + " diff_a = torch.pow(torch.reshape(diff_a, (-1, self.h)),\n", + " 2) / self.sigma ** 2\n", + " diff_b = torch.pow(torch.reshape(diff_b, (-1, self.h)),\n", + " 2) / self.sigma ** 2\n", + " diff_a = 1 / (1 + diff_a) # Inverse quadratic\n", + " diff_b = 1 / (1 + diff_b)\n", + "\n", + " diff_a = diff_a.type(torch.float32)\n", + " diff_b = diff_b.type(torch.float32)\n", + " a = torch.t(Il * diff_a)\n", + "\n", + " hists[l, 0, :, :] = torch.mm(a, diff_b)\n", + "\n", + " # normalization\n", + " hists_normalized = hists / (\n", + " ((hists.sum(dim=1)).sum(dim=1)).sum(dim=1).view(-1, 1, 1, 1) + EPS)\n", + "\n", + " return hists_normalized" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Refactored Implementation\n", + "\n", + "Next, let's define the refactored versions.\n", + "Since pixel sampling, histogram value scaling, and kernel methods are the same across all histogram blocks, it makes sense to extract them into separate functions.\n", + "We can also do some micro-optimizations here, like replacing divisions with multiplications where applicable." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Kernel Functions\n", + "\n", + "We start by defining the kernel functions for intensity scaling, resizing, and pixel counting.\n", + "\n", + "Minor changes: first, squaring the input for intensity scaling has been moved into the intensity scaling function itself.\n", + "This means that when `intensity_scale` is set to `False`, the histogram block's `forward()` method no longer includes this\n", + "calculation, which saves on memory and computation time (only calculate what we need and when we actually need it).\n", + "\n", + "Another change can be found in the sampling method. Here, we no longer use `LongTensor`, which creates 64-bit indexes and\n", + "is slow on many consumer as well as professional GPU devices. It's also unnecessary, since the index values only exceed\n", + "the 32-bit range once we get to image sizes beyond 45k x 45k pixels, i.e. √(2^31) by √(2^31) images. Given that a 3-channel\n", + "image of this size would require ~6GiB or VRAM, it's reasonable to assume that we can limit ourselves to sub-2 gigapixel\n", + "images for the time being." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Union\n", + "\n", + "Device = Union[str, torch.device]\n", + "\n", + "\n", + "def no_scaling(_: torch.Tensor) -> int:\n", + " return 1\n", + "\n", + "\n", + "def intensity_scaling(X: torch.Tensor) -> torch.Tensor:\n", + " XX = X ** 2\n", + " return (XX[:, 0] + XX[:, 1] + XX[:, 2] + EPS).sqrt().unsqueeze(dim=1)\n", + "\n", + "\n", + "def resizing_interpolate(max_size: int, X: torch.Tensor) -> torch.Tensor:\n", + " H, W = X.shape[2:]\n", + " if H > max_size or W > max_size:\n", + " return F.interpolate(\n", + " X, size=(max_size, max_size), mode='bilinear', align_corners=False\n", + " )\n", + " return X\n", + "\n", + "\n", + "def resizing_sample(\n", + " h: int, max_size: int, device: Device, X: torch.Tensor\n", + ") -> torch.Tensor:\n", + " H, W = X.shape[2:]\n", + " if H > max_size or W > max_size:\n", + " index_H = torch.linspace(0, H - H/h, h, dtype=torch.int32).to(device)\n", + " index_W = torch.linspace(0, W - W/h, h, dtype=torch.int32).to(device)\n", + " sampled = X.index_select(dim=2, index=index_H)\n", + " return sampled.index_select(dim=3, index=index_W)\n", + " return X\n", + "\n", + "\n", + "def thresholding_kernel(h: int, eps: float, X: torch.Tensor) -> torch.Tensor:\n", + " return (X.reshape(-1, h) <= eps).float()\n", + "\n", + "\n", + "def rbf_kernel(h: int, inv_sigma_sq: float, X: torch.Tensor) -> torch.Tensor:\n", + " Y = (X.reshape(-1, h) ** 2) * inv_sigma_sq\n", + " return (-Y).exp()\n", + "\n", + "\n", + "def inverse_quadratic_kernel(\n", + " h: int, inv_sigma_sq: float, X: torch.Tensor\n", + ") -> torch.Tensor:\n", + " Y = (X.reshape(-1, h) ** 2) * inv_sigma_sq\n", + " return 1. / (1. + Y)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### HistBlock Base Class\n", + "\n", + "Next we define a base class for all histogram blocks. The base class ctor selects the kernel functions depending on the provided parameter and precalculates tensors that only depend on ctor arguments. This includes the delta-values used for calculating differences.\n", + "\n", + "We can compute these once and upload them onto the device in a suitable data format. Factory functions are used to map function names to actual kernel functions.\n", + "Partial function application helps setting kernel function parameters that don't depend on the input tensor and thus can be precalculated." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from functools import partial\n", + "from typing import Callable, List, Sequence, Union\n", + "\n", + "_KernelMethod = Callable[[torch.Tensor], torch.Tensor]\n", + "_Device = Union[str, torch.device]\n", + "\n", + "def _get_resizing(\n", + " mode: str, h: int, max_size: int, device: Device\n", + ") -> _KernelMethod:\n", + " if mode == 'interpolation':\n", + " return partial(resizing_interpolate, max_size)\n", + " elif mode == 'sampling':\n", + " return partial(resizing_sample, h, max_size, device)\n", + " else:\n", + " raise ValueError(\n", + " f'Unknown resizing method: \"{mode}\". Supported methods are '\n", + " '\"interpolation\" or \"sampling\"'\n", + " )\n", + "\n", + "\n", + "def _get_scaling(intensity_scale: bool):\n", + " return intensity_scaling if intensity_scale else no_scaling \n", + "\n", + "\n", + "def _get_kernel(\n", + " method: str, h: int, sigma: float, boundary: Sequence[int]\n", + ") -> _KernelMethod: \n", + " if method == 'thresholding':\n", + " eps = (boundary[1] - boundary[0]) / (2 * h)\n", + " return partial(thresholding_kernel, h, eps)\n", + " elif method == 'RBF':\n", + " inv_sigma_sq = 1 / sigma ** 2\n", + " return partial(rbf_kernel, h, inv_sigma_sq)\n", + " elif method == 'inverse-quadratic':\n", + " inv_sigma_sq = 1 / sigma ** 2\n", + " return partial(inverse_quadratic_kernel, h, inv_sigma_sq)\n", + " else:\n", + " raise ValueError(\n", + " f'Unknown kernel method: \"{method}\". Supported methods are '\n", + " '\"thresholding\", \"RBF\", or \"inverse-quadratic\".'\n", + " )\n", + "\n", + "class HistBlock(nn.Module):\n", + " def __init__(\n", + " self, h: int, insz: int, resizing: str, method: str, sigma: float,\n", + " intensity_scale: str, hist_boundary: List[int], device: _Device\n", + " ) -> None:\n", + " super().__init__()\n", + " hist_boundary.sort()\n", + " start, end = hist_boundary[:2]\n", + " self.h = h\n", + " self.device = torch.device(device)\n", + " self.resize = _get_resizing(resizing, h, insz, self.device)\n", + " self.kernel = _get_kernel(method, h, sigma, hist_boundary)\n", + " self.scaling = _get_scaling(intensity_scale)\n", + " self.delta = torch.linspace(\n", + " start, end, steps=h, device=self.device, dtype=torch.float32\n", + " ).unsqueeze(dim=0)\n", + "\n", + " def forward(self, _: torch.Tensor) -> torch.Tensor:\n", + " raise NotImplementedError()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Refactored Histogram Blocks\n", + "\n", + "With all the pieces in place, we can modify the original histogram blocks to make use of our common components.\n", + "\n", + "The RGB-uv block can be simplified by observing that the difference calculations for the channels only differ in\n", + "tensor indexing. We can extract the calculation into a function and pass these indexes as arguments to the difference\n", + "calculation.\n", + "\n", + "We can also speed up historgram normalization by only summing elements once. This will cause a little more numeric\n", + "instability due to loss of significance with unsorted tensor values. Performance is a bit better, though, and the\n", + "differences should be minimal, but we get back to later in the validation part.\n", + "\n", + "Just from personal prefence and for sake of consistency, member functions are used on tensors where appropriate\n", + "(type hints would also help a lot, but I didn't want to change too much)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class RGBuvHistBlock(HistBlock):\n", + " def __init__(self, h=64, insz=150, resizing='interpolation',\n", + " method='inverse-quadratic', sigma=0.02, intensity_scale=True,\n", + " hist_boundary=None, green_only=False, device='cuda'):\n", + " super().__init__(\n", + " h, insz, resizing, method, sigma, intensity_scale,\n", + " hist_boundary or [-3, 3], device\n", + " )\n", + " self.green_only = green_only\n", + "\n", + " def forward(self, x):\n", + " x_sampled = self.resize(x.clamp(0, 1))\n", + "\n", + " N = x_sampled.shape[0] # size of mini-batch\n", + " if x_sampled.shape[1] > 3:\n", + " x_sampled = x_sampled[:, :3, :, :]\n", + " X = torch.unbind(x_sampled, dim=0)\n", + " C = 1 + int(not self.green_only) * 2\n", + " hists = torch.zeros(N, C, self.h, self.h, device=self.device)\n", + " for n in range(N):\n", + " Ix = X[n].reshape(3, -1).t()\n", + " Iy = self.scaling(Ix)\n", + " if not self.green_only:\n", + " Du, Dv = self._diff_uv(Ix, i=0, j=1, k=2)\n", + " a = (Iy * Du).t()\n", + " hists[n, 0, :, :] = torch.mm(a, Dv)\n", + "\n", + " Du, Dv = self._diff_uv(Ix, i=1, j=0, k=2)\n", + " a = (Iy * Du).t()\n", + " hists[n, int(not self.green_only), :, :] = torch.mm(a, Dv)\n", + "\n", + " if not self.green_only:\n", + " Du, Dv = self._diff_uv(Ix, i=2, j=0, k=1)\n", + " a = (Iy * Du).t()\n", + " hists[n, 2, :, :] = torch.mm(a, Dv)\n", + "\n", + " # normalization\n", + " norm = hists.view(-1, C * self.h * self.h).sum(dim=1).view(-1, 1, 1, 1)\n", + " hists_normalized = hists / (norm + EPS)\n", + "\n", + " return hists_normalized\n", + "\n", + " def _diff_uv(self, X: torch.Tensor, i: int, j: int, k: int):\n", + " U = ((X[:, i] + EPS).log() - (X[:, j] + EPS).log()).unsqueeze(dim=1)\n", + " V = ((X[:, i] + EPS).log() - (X[:, k] + EPS).log()).unsqueeze(dim=1)\n", + " Du = (U - self.delta).abs()\n", + " Dv = (V - self.delta).abs()\n", + " Du = self.kernel(Du)\n", + " Dv = self.kernel(Dv)\n", + " return Du, Dv\n", + "\n", + "class rgChromaHistBlock(HistBlock):\n", + " def __init__(self, h=64, insz=150, resizing='interpolation',\n", + " method='inverse-quadratic', sigma=0.02, intensity_scale=False,\n", + " hist_boundary=None, device='cuda'):\n", + " super().__init__(\n", + " h, insz, resizing, method, sigma, intensity_scale, \n", + " hist_boundary or [0, 1], device\n", + " )\n", + "\n", + " def forward(self, x):\n", + " x_sampled = self.resize(x.clamp(0, 1))\n", + "\n", + " N = x_sampled.shape[0] # size of mini-batch\n", + " if x_sampled.shape[1] > 3:\n", + " x_sampled = x_sampled[:, :3, :, :]\n", + " X = torch.unbind(x_sampled, dim=0)\n", + " hists = torch.zeros(N, 1, self.h, self.h, device=self.device)\n", + " for n in range(N):\n", + " Ix = X[n].reshape(3, -1).t()\n", + " Inorm = Ix.sum(dim=-1) + EPS\n", + " Ir = (Ix[:, 0] / Inorm).unsqueeze(dim=1)\n", + " Ig = (Ix[:, 1] / Inorm).unsqueeze(dim=1)\n", + "\n", + " diff_r = (Ir - self.delta).abs()\n", + " diff_g = (Ig - self.delta).abs()\n", + " diff_r = self.kernel(diff_r)\n", + " diff_g = self.kernel(diff_g)\n", + " Iy = self.scaling(Ix)\n", + " a = torch.t(Iy * diff_r)\n", + "\n", + " hists[n, 0, :, :] = torch.mm(a, diff_g)\n", + "\n", + " # normalization\n", + " norm = hists.view(-1, self.h * self.h).sum(dim=1).view(-1, 1, 1, 1) + EPS\n", + " hists_normalized = hists / norm\n", + "\n", + " return hists_normalized\n", + "\n", + "class LabHistBlock(HistBlock):\n", + " def __init__(self, h=64, insz=150, resizing='interpolation',\n", + " method='inverse-quadratic', sigma=0.02, intensity_scale=False,\n", + " hist_boundary=None, device='cuda'):\n", + " super().__init__(\n", + " h, insz, resizing, method, sigma, intensity_scale, \n", + " hist_boundary or [0, 1], device\n", + " )\n", + "\n", + " def forward(self, x):\n", + " x_sampled = self.resize(x.clamp(0, 1))\n", + "\n", + " N = x_sampled.shape[0] # size of mini-batch\n", + " if x_sampled.shape[1] > 3:\n", + " x_sampled = x_sampled[:, :3, :, :]\n", + " X = torch.unbind(x_sampled, dim=0)\n", + " hists = torch.zeros(N, 1, self.h, self.h, device=self.device)\n", + " for n in range(N):\n", + " Ix = X[n].reshape(3, -1).t()\n", + "\n", + " Ia = Ix[:, 1].unsqueeze(dim=1)\n", + " Ib = Ix[:, 2].unsqueeze(dim=1)\n", + "\n", + " diff_a = (Ia - self.delta).abs()\n", + " diff_b = (Ib - self.delta).abd()\n", + "\n", + " diff_a = self.kernel(diff_a)\n", + " diff_b = self.kernel(diff_b)\n", + " Iy = self.scaling(Ix)\n", + " a = torch.t(Iy * diff_a)\n", + "\n", + " hists[n, 0, :, :] = torch.mm(a, diff_b)\n", + "\n", + " # normalization\n", + " norm = hists.view(-1, self.h * self.h).sum(dim=1).view(-1, 1, 1, 1) + EPS\n", + " hists_normalized = hists / norm\n", + "\n", + " return hists_normalized" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Validation\n", + "\n", + "In order to validate our work, let's run A-B-tests for each possible parameter combination.\n", + "We can keep most values at their default, but resizing, sampling, and intensity scaling options should be thoroughly tested.\n", + "We can define a `dict` that holds all test cases - histogram block classes and the tested parameters.\n", + "\n", + "Next we run both reference and refactored models with a batch of randomly generated images and compare the results.\n", + "For the comparison, we use the *arctangent absolute percentage error* (AAPE) as proposed in\n", + "\n", + " Sungil Kim, Heeyoung Kim,\n", + " \"A new metric of absolute percentage error for intermittent demand forecasts\",\n", + " International Journal of Forecasting,\n", + " Volume 32, Issue 3,\n", + " 2016,\n", + " Pages 669-679,\n", + " https://doi.org/10.1016/j.ijforecast.2015.12.003\n", + "\n", + "with the AAPE rescaled from its original [0, ½π] range to [0, 100] to obtain more readable percentages." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "DEVICE = 'cuda' # device to run the tests on (e.g. 'cuda' or 'cpu')\n", + "BATCHES = 8 # how many samples per mini-batch\n", + "SAMPLE_SIZE = 256 # sample image size in pixels\n", + "ERR_THRESHOLD = 0.05 # validation error threshold in percent\n", + "RANDOM_SEED = 4793 # for reproducibility we seed the rng, use torch.random.seed() instead to explore\n", + " # the selected seed produces a more colourful output ;)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from dataclasses import dataclass, field\n", + "from itertools import product, repeat\n", + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "RESIZING_VALS = ['interpolation', 'sampling']\n", + "BOOL_VALS = [False, True]\n", + "METHOD_VALS = ['thresholding', 'RBF', 'inverse-quadratic']\n", + "TESTS = {\n", + " 'RGB-uv': {\n", + " 'A': OriginalRGBuvHistBlock, 'B': RGBuvHistBlock,\n", + " 'params': {'resizing': RESIZING_VALS, 'method': METHOD_VALS, 'intensity_scale': BOOL_VALS, 'green_only': BOOL_VALS}\n", + " },\n", + " 'rg-chroma': {\n", + " 'A': OriginalrgChromaHistBlock, 'B': rgChromaHistBlock,\n", + " 'params': {'resizing': RESIZING_VALS, 'method': METHOD_VALS, 'intensity_scale': BOOL_VALS}\n", + " },\n", + " 'Lab': {\n", + " 'A': OriginalrgChromaHistBlock, 'B': rgChromaHistBlock,\n", + " 'params': {'resizing': RESIZING_VALS, 'method': METHOD_VALS, 'intensity_scale': BOOL_VALS}\n", + " }\n", + "}\n", + "\n", + "def _to_dict(names, values):\n", + " return {key: val for key, val in zip(names, values)}\n", + "\n", + "def _param_info(p):\n", + " if isinstance(p[1], bool):\n", + " return f\"{'+' if p[1] else '-'}{p[0][:3].upper()}\"\n", + " return p[1][:4].upper()\n", + "\n", + "@dataclass\n", + "class Result:\n", + " test: str\n", + " params: str\n", + " min_err: float\n", + " max_err: float\n", + " avg_err: float\n", + " median_err: float\n", + " outcome: str = field(init=False)\n", + "\n", + " def __post_init__(self):\n", + " valid = self.avg_err < ERR_THRESHOLD\n", + " self.outcome = 'PASS' if valid else 'FAILED'\n", + "\n", + "def _to_row(result: Result):\n", + " return [\n", + " result.test, result.params, f'{result.min_err:.6f}', f'{result.max_err:.6f}',\n", + " f'{result.avg_err:.6f}', f'{result.median_err:.6f}', result.outcome\n", + " ]\n", + "\n", + "def _aape(X, Y):\n", + " eps = torch.full_like(input=X, fill_value=1.1921e-7).float()\n", + " phi = ((X - Y).abs() / torch.maximum(X.abs(), eps)).arctan()\n", + " return phi * (2 / torch.pi) * 100\n", + "\n", + "torch.random.manual_seed(RANDOM_SEED)\n", + "samples = torch.randint(low=0, high=256, size=(BATCHES, 3, SAMPLE_SIZE, SAMPLE_SIZE))\n", + "samples = (samples / 255.).float().to(DEVICE)\n", + "\n", + "def _validate(name, A_, B_, model_args):\n", + " A = A_(**model_args).eval().to(DEVICE)\n", + " B = B_(**model_args).eval().to(DEVICE)\n", + " with torch.no_grad():\n", + " Ay = A(samples)\n", + " By = B(samples)\n", + " err = _aape(Ay, By)\n", + " params = ','.join(map(_param_info, model_args.items()))\n", + " return Result(name, params, err.min().item(), err.max().item(), err.mean().item(), err.median().item())\n", + "\n", + "validation_results: List[Result] = []\n", + "for name, args in TESTS.items():\n", + " A, B, params = args['A'], args['B'], args['params']\n", + " test_params = map(_to_dict, repeat(params), product(*params.values()))\n", + " items = map(_validate, repeat(name), repeat(A), repeat(B), test_params)\n", + " validation_results.extend(items)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's visualize the validation results next - a table will do for now.\n", + "\n", + "The tested parameter combinations are listed by their first four letters in caps. Boolean flags are indicated by -*FLAG*\n", + "if the parameter is `False` and +*FLAG* if the parameter is set to `True`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if any(filter(lambda i: i.outcome == 'FAILED', validation_results)):\n", + " print(f'Got some failures; RANDOM_SEED to reproduce: {RANDOM_SEED}')\n", + "\n", + "# plot the results\n", + "fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n", + "col_labels = ['HIST BLOCK', 'PARAMS', 'MIN ERR %', 'MAX ERR %', 'AVG ERR %', 'MEDIAN ERR %', 'RESULT']\n", + "ax.axis('tight')\n", + "ax.axis('off')\n", + "tbl = ax.table(\n", + " cellText=list(map(_to_row, validation_results)), colLabels=col_labels, loc='center',\n", + " colColours=['slategrey']*7\n", + ")\n", + "# format table\n", + "_ = list(map(lambda col: tbl[(0, col)].set_text_props(fontweight='bold'), range(7)))\n", + "for row, item in enumerate(validation_results):\n", + " tbl[(row+1, 0)].set_facecolor('lightsteelblue')\n", + " if item.min_err > ERR_THRESHOLD:\n", + " tbl[(row+1, 2)].set_facecolor('darkorange')\n", + " if item.max_err > ERR_THRESHOLD:\n", + " tbl[(row+1, 3)].set_facecolor('darkorange')\n", + " if item.avg_err > ERR_THRESHOLD:\n", + " tbl[(row+1, 4)].set_facecolor('darkorange')\n", + " if item.median_err > ERR_THRESHOLD:\n", + " tbl[(row+1, 5)].set_facecolor('darkorange')\n", + " tbl[(row+1, 6)].set_facecolor('g' if item.outcome=='PASS' else 'r')\n", + " tbl[(row+1, 6)].set_text_props(fontweight='bold')\n", + "tbl.auto_set_font_size(False)\n", + "tbl.set_fontsize(10)\n", + "tbl.scale(2, 2)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Validation Result Discussion\n", + "\n", + "The chosen criterion for passing validation is a somewhat arbitrary but low average error threshold\n", + "(across all samples in the mini-batch).\n", + "The reported maximum error might exceed this threshold (I've seen values as high as 10%), but we\n", + "need to keep two things in mind here:\n", + "\n", + "First, the maximum error refers to a single bucket value in the histogram. A single outlier in a\n", + "64x64 histogram shouldn't account for a validation failure.\n", + "\n", + "Secondly, I noticed that the deviations occur in conjunction with interpolation only. Match the\n", + "sample dimensions (`SAMPLE_SIZE`) with the maximum histogram input size (`insz`) and the errors\n", + "go away. Since the resize function is identical to the original version in every way, I'm a bit at a\n", + "loss as to why that is.\n", + "\n", + "## Benchmark\n", + "\n", + "With that out of the way, let's get some performance numbers to see whether we actually improved things.\n", + "We assess the performance differences by running each histogram block a given number of times on a\n", + "mini-batch of random samples. Inference time is measured and results are plotted to a diagram." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ITERATIONS = 100 # number of benchmark passes per model\n", + "BATCHES = 16 # mini-batch size\n", + "SAMPLE_SIZE = 256 # sample image size for benchmarking\n", + "DEVICE = 'cuda' # computation device to run the benchmark on (e.g. 'cuda' or 'cpu')\n", + "RANDOM_SEED = 0 # random seed for producing sample data (again, for reproducibility)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from tqdm import tqdm\n", + "from time import perf_counter\n", + "from typing import Dict\n", + "import numpy as np\n", + "\n", + "\n", + "BASELINE = [(key, val['A']) for key, val in TESTS.items()]\n", + "REFACTORED = [(key, val['B']) for key, val in TESTS.items()]\n", + "\n", + "torch.random.manual_seed(RANDOM_SEED)\n", + "\n", + "\n", + "def _gen_minibatch():\n", + " while True:\n", + " X = torch.randint(low=0, high=256, size=(BATCHES, 3, SAMPLE_SIZE, SAMPLE_SIZE))\n", + " yield (X / 255.).float().to(DEVICE)\n", + "\n", + "\n", + "def _benchmark(model, sample):\n", + " start = perf_counter()\n", + " _ = model(sample)\n", + " return (perf_counter() - start) * 1_000\n", + "\n", + "\n", + "baseline_results: Dict[str, float] = { }\n", + "for name, Model in BASELINE:\n", + " with torch.no_grad():\n", + " models = repeat(Model().eval().to(DEVICE), times=ITERATIONS)\n", + " runs = tqdm(models, total=ITERATIONS, desc=f'Benchmarking baseline {name}')\n", + " baseline_results[name] = list(map(_benchmark, runs, _gen_minibatch()))\n", + "\n", + "refactored_results: Dict[str, float] = { }\n", + "for name, Model in REFACTORED:\n", + " with torch.no_grad():\n", + " models = repeat(Model().eval().to(DEVICE), times=ITERATIONS)\n", + " runs = tqdm(models, total=ITERATIONS, desc=f'Benchmarking refactored {name}')\n", + " refactored_results[name] = list(map(_benchmark, runs, _gen_minibatch()))\n", + "\n", + "a = np.array(list(baseline_results.values()))\n", + "a_mins = a.min(axis=1)\n", + "a_maxes = a.max(axis=1)\n", + "a_means = a.mean(axis=1)\n", + "a_std = a.std(axis=1)\n", + "\n", + "b = np.array(list(refactored_results.values()))\n", + "b_mins = b.min(axis=1)\n", + "b_maxes = b.max(axis=1)\n", + "b_means = b.mean(axis=1)\n", + "b_std = b.std(axis=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With our data collected, let's print the mean relative execution time differences and plot some charts. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from itertools import chain\n", + "\n", + "N = len(baseline_results)\n", + "\n", + "for name, speedup in zip(TESTS, a_means / b_means):\n", + " print(f'Refactored {name}: {speedup:.1f}x faster on average')\n", + "\n", + "fig, ax = plt.subplots(3, 1, figsize=(12, 18))\n", + "\n", + "tick_labels = [(f'{n} (baseline)', f'{n} (refactored)') for n in TESTS]\n", + "tick_labels = list(chain(*tick_labels))\n", + "ax[0].errorbar(np.arange(N) * 2, a_means, a_std, fmt='_k', lw=3, ms=11, capsize=3)\n", + "ax[0].errorbar(np.arange(N) * 2, a_means, [a_means - a_mins, a_maxes - a_means], fmt='.k', ecolor='grey', lw=1, capsize=3)\n", + "ax[0].errorbar(np.arange(N) * 2 + 1, b_means, b_std, fmt='_b', lw=3, ms=11, capsize=3)\n", + "ax[0].errorbar(np.arange(N) * 2 + 1, b_means, [b_means - b_mins, b_maxes - b_means], fmt='.k', ecolor='lightsteelblue', lw=1, capsize=3)\n", + "ax[0].set_xticks(np.arange(2*N), minor=False)\n", + "ax[0].set_xtick_labels(tick_labels)\n", + "ax[0].set_title(f'Benchmark results for {ITERATIONS} iterations and mini-batch size of {BATCHES}')\n", + "ax[0].set_ylabel('Iteration time in ms')\n", + "\n", + "labels = [name for name in TESTS]\n", + "width = 0.35\n", + "ax[1].bar(labels, a_means, width, yerr=a_std, label='Baseline', capsize=3)\n", + "ax[1].bar(labels, b_means, width, yerr=b_std, label='Refactored', capsize=3)\n", + "ax[1].set_ylabel('Iteration in ms')\n", + "ax[1].set_title('Execution time difference')\n", + "ax[1].legend()\n", + "\n", + "\n", + "a_mean_its = 1_000 / a_means\n", + "b_mean_its = 1_000 / b_means\n", + "a_it_std = (1_000 / a).std(axis=1)\n", + "b_it_std = (1_000 / b).std(axis=1)\n", + "width = 0.35\n", + "ax[2].bar(labels, a_mean_its, width, yerr=a_it_std, label='Baseline', capsize=3)\n", + "ax[2].bar(labels, b_mean_its, width, yerr=b_it_std, label='Refactored', bottom=a_mean_its, capsize=3)\n", + "ax[2].set_ylabel('Iteration per second')\n", + "ax[2].set_title('Performance difference')\n", + "ax[2].legend()\n", + "\n", + "plt.show()" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "57164f1c0f1b3bb0c5f993d1ba49aa53e928357509b375fbcec01d49bae2dae4" + }, + "kernelspec": { + "display_name": "Python 3.9.12 ('pytorch')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.12" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/histogram_classes/HistBlock.py b/histogram_classes/HistBlock.py new file mode 100644 index 0000000..915d19c --- /dev/null +++ b/histogram_classes/HistBlock.py @@ -0,0 +1,191 @@ +"""Base class for color histogram blocks. + +##### Copyright 2021 Mahmoud Afifi. + + If you find this code useful, please cite our paper: + + Mahmoud Afifi, Marcus A. Brubaker, and Michael S. Brown. "HistoGAN: + Controlling Colors of GAN-Generated and Real Images via Color Histograms." + In CVPR, 2021. + + @inproceedings{afifi2021histogan, + title={Histo{GAN}: Controlling Colors of {GAN}-Generated and Real Images via + Color Histograms}, + author={Afifi, Mahmoud and Brubaker, Marcus A. and Brown, Michael S.}, + booktitle={CVPR}, + year={2021} +} +#### + +Portions Copyright (c) 2022 Patrick Levin. +""" +from functools import partial +from typing import Callable, List, Sequence, Union +import torch +import torch.nn as nn +import torch.nn.functional as F + + +_Device = Union[str, torch.device] + + +EPS = 1e-6 + + +class HistBlock(nn.Module): + """Histogram block for calculating colour histograms for 3-channel + tensors. + + Args: + h (int): histogram dimension size (number of bins). + + insz (int): Maximum size of the input image; if it is larger than + this size, the image will be resized to (`insz`, `insz`). + + resizing (str): resizing method if applicable. Options are: + 'interpolation' and 'sampling'. + + method (str): the method used to count the number of pixels for + each bin in the histogram feature. Options are: + 'thresholding', 'RBF' (radial basis function), and + 'inverse-quadratic' + + sigma (float): if the method value is 'RBF' or 'inverse-quadratic', + then this is the sigma parameter of the kernel function. + + intensity_scale (Scaling): intensity scale method to obtain scaling + values (I_y in Equation 2). + + hist_boundary (list[int]): A list of histogram boundary values. + The list must have two entries; additional values are ignored + if present. + + device (str|device): computation device (name or instance) + + Methods: + forward: accepts input image and returns its histogram feature. + Note that unless the method is `Method.THRESHOLDING`, this is a + differentiable function and can be easily integrated with + the loss function. As mentioned in the paper, + `Method.INVERSE_QUADTRATIC` was found more stable than + `Method.RADIAL_BASIS_FUNCTION`. + """ + def __init__( + self, + h: int, + insz: int, + resizing: str, + method: str, + sigma: float, + intensity_scale: str, + hist_boundary: List[int], + device: _Device + ) -> None: + super().__init__() + hist_boundary.sort() + start, end = hist_boundary[:2] + self.h = h + self.device = torch.device(device) + self.resize = _get_resizing(resizing, h, insz, self.device) + self.kernel = _get_kernel(method, h, sigma, hist_boundary) + self.scaling = _get_scaling(intensity_scale) + self.delta = torch.linspace( + start, end, steps=h, device=self.device, dtype=torch.float32 + ).unsqueeze(dim=0) + + def forward(self, _: torch.Tensor) -> torch.Tensor: + raise NotImplementedError() + + +_KernelMethod = Callable[[torch.Tensor], torch.Tensor] + + +# ---------------------------- Factory Functions ----------------------------- + +def _get_resizing( + mode: str, h: int, max_size: int, device: _Device +) -> _KernelMethod: + if mode == 'interpolation': + return partial(_resizing_interpolate, max_size) + elif mode == 'sampling': + return partial(_resizing_sample, h, max_size, device) + else: + raise ValueError( + f'Unknown resizing method: "{mode}". Supported methods are ' + '"interpolation" or "sampling"' + ) + + +def _get_scaling(intensity_scale: bool): + return _intensity_scaling if intensity_scale else _no_scaling + + +def _get_kernel( + method: str, h: int, sigma: float, boundary: Sequence[int] +) -> _KernelMethod: + if method == 'thresholding': + eps = (boundary[1] - boundary[0]) / (2 * h) + return partial(_thresholding_kernel, h, eps) + elif method == 'RBF': + inv_sigma_sq = 1 / sigma ** 2 + return partial(_rbf_kernel, h, inv_sigma_sq) + elif method == 'inverse-quadratic': + inv_sigma_sq = 1 / sigma ** 2 + return partial(_inverse_quadratic_kernel, h, inv_sigma_sq) + else: + raise ValueError( + f'Unknown kernel method: "{method}". Supported methods are ' + '"thresholding", "RBF", or "inverse-quadratic".' + ) + + +# ----------------------------- Resizing Kernels ----------------------------- + +def _resizing_interpolate(max_size: int, X: torch.Tensor) -> torch.Tensor: + H, W = X.shape[2:] + if H > max_size or W > max_size: + return F.interpolate( + X, size=(max_size, max_size), mode='bilinear', align_corners=False + ) + return X + + +def _resizing_sample( + h: int, max_size: int, device: _Device, X: torch.Tensor +) -> torch.Tensor: + H, W = X.shape[2:] + if H > max_size or W > max_size: + index_H = torch.linspace(0, H - H/h, h, dtype=torch.int32).to(device) + index_W = torch.linspace(0, W - W/h, h, dtype=torch.int32).to(device) + sampled = X.index_select(dim=2, index=index_H) + return sampled.index_select(dim=3, index=index_W) + return X + + +# ---------------------------- Scaling Functions ----------------------------- + +def _no_scaling(_: torch.Tensor) -> int: + return 1 + + +def _intensity_scaling(X: torch.Tensor) -> torch.Tensor: + XX = X ** 2 + return (XX[:, 0] + XX[:, 1] + XX[:, 2] + EPS).sqrt().unsqueeze(dim=1) + + +# ----------------------------- Kernel Functions ----------------------------- + +def _thresholding_kernel(h: int, eps: float, X: torch.Tensor) -> torch.Tensor: + return (X.reshape(-1, h) <= eps).float() + + +def _rbf_kernel(h: int, inv_sigma_sq: float, X: torch.Tensor) -> torch.Tensor: + Y = (X.reshape(-1, h) ** 2) * inv_sigma_sq + return (-Y).exp() + + +def _inverse_quadratic_kernel( + h: int, inv_sigma_sq: float, X: torch.Tensor +) -> torch.Tensor: + Y = (X.reshape(-1, h) ** 2) * inv_sigma_sq + return 1. / (1. + Y) diff --git a/histogram_classes/LabHistBlock.py b/histogram_classes/LabHistBlock.py index 6d0bff2..d63d75e 100644 --- a/histogram_classes/LabHistBlock.py +++ b/histogram_classes/LabHistBlock.py @@ -15,18 +15,13 @@ year={2021} } #### +Portions Copyright (c) 2022 Patrick Levin. """ - +from histogram_classes.HistBlock import EPS, HistBlock import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np - -EPS = 1e-6 - -class LabHistBlock(nn.Module): +class LabHistBlock(HistBlock): def __init__(self, h=64, insz=150, resizing='interpolation', method='inverse-quadratic', sigma=0.02, intensity_scale=False, hist_boundary=None, device='cuda'): @@ -54,93 +49,37 @@ def __init__(self, h=64, insz=150, resizing='interpolation', paper, the 'inverse-quadratic' was found more stable than 'RBF' in our training. """ - super(LabHistBlock, self).__init__() - self.h = h - self.insz = insz - self.device = device - self.resizing = resizing - self.method = method - self.intensity_scale = intensity_scale - if hist_boundary is None: - hist_boundary = [0, 1] - hist_boundary.sort() - self.hist_boundary = hist_boundary - if self.method == 'thresholding': - self.eps = (abs(hist_boundary[0]) + abs(hist_boundary[1])) / h - else: - self.sigma = sigma + super().__init__( + h, insz, resizing, method, sigma, intensity_scale, + hist_boundary or [0, 1], device + ) def forward(self, x): - x = torch.clamp(x, 0, 1) - if x.shape[2] > self.insz or x.shape[3] > self.insz: - if self.resizing == 'interpolation': - x_sampled = F.interpolate(x, size=(self.insz, self.insz), - mode='bilinear', align_corners=False) - elif self.resizing == 'sampling': - inds_1 = torch.LongTensor( - np.linspace(0, x.shape[2], self.h, endpoint=False)).to( - device=self.device) - inds_2 = torch.LongTensor( - np.linspace(0, x.shape[3], self.h, endpoint=False)).to( - device=self.device) - x_sampled = x.index_select(2, inds_1) - x_sampled = x_sampled.index_select(3, inds_2) - else: - raise Exception( - f'Wrong resizing method. It should be: interpolation or sampling. ' - f'But the given value is {self.resizing}.') - else: - x_sampled = x + x_sampled = self.resize(x.clamp(0, 1)) - L = x_sampled.shape[0] # size of mini-batch + N = x_sampled.shape[0] # size of mini-batch if x_sampled.shape[1] > 3: x_sampled = x_sampled[:, :3, :, :] X = torch.unbind(x_sampled, dim=0) - hists = torch.zeros((x_sampled.shape[0], 1, self.h, self.h)).to( - device=self.device) - for l in range(L): - I = torch.t(torch.reshape(X[l], (3, -1))) - if self.intensity_scale: - Il = torch.unsqueeze(I[:, 0], dim=1) - else: - Il = 1 - - Ia = torch.unsqueeze(I[:, 1], dim=1) - Ib = torch.unsqueeze(I[:, 2], dim=1) + hists = torch.zeros(N, 1, self.h, self.h, device=self.device) + for n in range(N): + Ix = X[n].reshape(3, -1).t() - diff_a = abs(Ia - torch.unsqueeze(torch.tensor(np.linspace( - self.hist_boundary[0], self.hist_boundary[1], num=self.h)), - dim=0).to(self.device)) - diff_b = abs(Ib - torch.unsqueeze(torch.tensor(np.linspace( - self.hist_boundary[0], self.hist_boundary[1], num=self.h)), - dim=0).to(self.device)) + Ia = Ix[:, 1].unsqueeze(dim=1) + Ib = Ix[:, 2].unsqueeze(dim=1) - if self.method == 'thresholding': - diff_a = torch.reshape(diff_a, (-1, self.h)) <= self.eps / 2 - diff_b = torch.reshape(diff_b, (-1, self.h)) <= self.eps / 2 - elif self.method == 'RBF': - diff_a = torch.pow(torch.reshape(diff_a, (-1, self.h)), - 2) / self.sigma ** 2 - diff_b = torch.pow(torch.reshape(diff_b, (-1, self.h)), - 2) / self.sigma ** 2 - diff_a = torch.exp(-diff_a) # Gaussian - diff_b = torch.exp(-diff_b) - elif self.method == 'inverse-quadratic': - diff_a = torch.pow(torch.reshape(diff_a, (-1, self.h)), - 2) / self.sigma ** 2 - diff_b = torch.pow(torch.reshape(diff_b, (-1, self.h)), - 2) / self.sigma ** 2 - diff_a = 1 / (1 + diff_a) # Inverse quadratic - diff_b = 1 / (1 + diff_b) + diff_a = (Ia - self.delta).abs() + diff_b = (Ib - self.delta).abd() - diff_a = diff_a.type(torch.float32) - diff_b = diff_b.type(torch.float32) - a = torch.t(Il * diff_a) + diff_a = self.kernel(diff_a) + diff_b = self.kernel(diff_b) + Iy = self.scaling(Ix) + a = torch.t(Iy * diff_a) - hists[l, 0, :, :] = torch.mm(a, diff_b) + hists[n, 0, :, :] = torch.mm(a, diff_b) # normalization - hists_normalized = hists / ( - ((hists.sum(dim=1)).sum(dim=1)).sum(dim=1).view(-1, 1, 1, 1) + EPS) + norm = hists.view(-1, self.h * self.h).sum(dim=1).view(-1, 1, 1, 1) + EPS + hists_normalized = hists / norm return hists_normalized \ No newline at end of file diff --git a/histogram_classes/RGBuvHistBlock.py b/histogram_classes/RGBuvHistBlock.py index 76ec9a2..491377c 100644 --- a/histogram_classes/RGBuvHistBlock.py +++ b/histogram_classes/RGBuvHistBlock.py @@ -15,17 +15,15 @@ year={2021} } #### + +Portions Copyright (c) 2022 Patrick Levin. """ +from histogram_classes.HistBlock import EPS, HistBlock import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np - -EPS = 1e-6 -class RGBuvHistBlock(nn.Module): +class RGBuvHistBlock(HistBlock): def __init__(self, h=64, insz=150, resizing='interpolation', method='inverse-quadratic', sigma=0.02, intensity_scale=True, hist_boundary=None, green_only=False, device='cuda'): @@ -55,174 +53,49 @@ def __init__(self, h=64, insz=150, resizing='interpolation', paper, the 'inverse-quadratic' was found more stable than 'RBF' in our training. """ - super(RGBuvHistBlock, self).__init__() - self.h = h - self.insz = insz - self.device = device - self.resizing = resizing - self.method = method - self.intensity_scale = intensity_scale + super().__init__( + h, insz, resizing, method, sigma, intensity_scale, + hist_boundary or [-3, 3], device + ) self.green_only = green_only - if hist_boundary is None: - hist_boundary = [-3, 3] - hist_boundary.sort() - self.hist_boundary = hist_boundary - if self.method == 'thresholding': - self.eps = (abs(hist_boundary[0]) + abs(hist_boundary[1])) / h - else: - self.sigma = sigma def forward(self, x): - x = torch.clamp(x, 0, 1) - if x.shape[2] > self.insz or x.shape[3] > self.insz: - if self.resizing == 'interpolation': - x_sampled = F.interpolate(x, size=(self.insz, self.insz), - mode='bilinear', align_corners=False) - elif self.resizing == 'sampling': - inds_1 = torch.LongTensor( - np.linspace(0, x.shape[2], self.h, endpoint=False)).to( - device=self.device) - inds_2 = torch.LongTensor( - np.linspace(0, x.shape[3], self.h, endpoint=False)).to( - device=self.device) - x_sampled = x.index_select(2, inds_1) - x_sampled = x_sampled.index_select(3, inds_2) - else: - raise Exception( - f'Wrong resizing method. It should be: interpolation or sampling. ' - f'But the given value is {self.resizing}.') - else: - x_sampled = x + x_sampled = self.resize(x.clamp(0, 1)) - L = x_sampled.shape[0] # size of mini-batch + N = x_sampled.shape[0] # size of mini-batch if x_sampled.shape[1] > 3: x_sampled = x_sampled[:, :3, :, :] X = torch.unbind(x_sampled, dim=0) - hists = torch.zeros((x_sampled.shape[0], 1 + int(not self.green_only) * 2, - self.h, self.h)).to(device=self.device) - for l in range(L): - I = torch.t(torch.reshape(X[l], (3, -1))) - II = torch.pow(I, 2) - if self.intensity_scale: - Iy = torch.unsqueeze(torch.sqrt(II[:, 0] + II[:, 1] + II[:, 2] + EPS), - dim=1) - else: - Iy = 1 + C = 1 + int(not self.green_only) * 2 + hists = torch.zeros(N, C, self.h, self.h, device=self.device) + for n in range(N): + Ix = X[n].reshape(3, -1).t() + Iy = self.scaling(Ix) if not self.green_only: - Iu0 = torch.unsqueeze(torch.log(I[:, 0] + EPS) - torch.log(I[:, 1] + - EPS), dim=1) - Iv0 = torch.unsqueeze(torch.log(I[:, 0] + EPS) - torch.log(I[:, 2] + - EPS), dim=1) - diff_u0 = abs( - Iu0 - torch.unsqueeze(torch.tensor(np.linspace( - self.hist_boundary[0], self.hist_boundary[1], num=self.h)), - dim=0).to(self.device)) - diff_v0 = abs( - Iv0 - torch.unsqueeze(torch.tensor(np.linspace( - self.hist_boundary[0], self.hist_boundary[1], num=self.h)), - dim=0).to(self.device)) - if self.method == 'thresholding': - diff_u0 = torch.reshape(diff_u0, (-1, self.h)) <= self.eps / 2 - diff_v0 = torch.reshape(diff_v0, (-1, self.h)) <= self.eps / 2 - elif self.method == 'RBF': - diff_u0 = torch.pow(torch.reshape(diff_u0, (-1, self.h)), - 2) / self.sigma ** 2 - diff_v0 = torch.pow(torch.reshape(diff_v0, (-1, self.h)), - 2) / self.sigma ** 2 - diff_u0 = torch.exp(-diff_u0) # Radial basis function - diff_v0 = torch.exp(-diff_v0) - elif self.method == 'inverse-quadratic': - diff_u0 = torch.pow(torch.reshape(diff_u0, (-1, self.h)), - 2) / self.sigma ** 2 - diff_v0 = torch.pow(torch.reshape(diff_v0, (-1, self.h)), - 2) / self.sigma ** 2 - diff_u0 = 1 / (1 + diff_u0) # Inverse quadratic - diff_v0 = 1 / (1 + diff_v0) - else: - raise Exception( - f'Wrong kernel method. It should be either thresholding, RBF,' - f' inverse-quadratic. But the given value is {self.method}.') - diff_u0 = diff_u0.type(torch.float32) - diff_v0 = diff_v0.type(torch.float32) - a = torch.t(Iy * diff_u0) - hists[l, 0, :, :] = torch.mm(a, diff_v0) - - Iu1 = torch.unsqueeze(torch.log(I[:, 1] + EPS) - torch.log(I[:, 0] + EPS), - dim=1) - Iv1 = torch.unsqueeze(torch.log(I[:, 1] + EPS) - torch.log(I[:, 2] + EPS), - dim=1) - diff_u1 = abs( - Iu1 - torch.unsqueeze(torch.tensor(np.linspace( - self.hist_boundary[0], self.hist_boundary[1], num=self.h)), - dim=0).to(self.device)) - diff_v1 = abs( - Iv1 - torch.unsqueeze(torch.tensor(np.linspace( - self.hist_boundary[0], self.hist_boundary[1], num=self.h)), - dim=0).to(self.device)) + Du, Dv = self._diff_uv(Ix, i=0, j=1, k=2) + a = (Iy * Du).t() + hists[n, 0, :, :] = torch.mm(a, Dv) - if self.method == 'thresholding': - diff_u1 = torch.reshape(diff_u1, (-1, self.h)) <= self.eps / 2 - diff_v1 = torch.reshape(diff_v1, (-1, self.h)) <= self.eps / 2 - elif self.method == 'RBF': - diff_u1 = torch.pow(torch.reshape(diff_u1, (-1, self.h)), - 2) / self.sigma ** 2 - diff_v1 = torch.pow(torch.reshape(diff_v1, (-1, self.h)), - 2) / self.sigma ** 2 - diff_u1 = torch.exp(-diff_u1) # Gaussian - diff_v1 = torch.exp(-diff_v1) - elif self.method == 'inverse-quadratic': - diff_u1 = torch.pow(torch.reshape(diff_u1, (-1, self.h)), - 2) / self.sigma ** 2 - diff_v1 = torch.pow(torch.reshape(diff_v1, (-1, self.h)), - 2) / self.sigma ** 2 - diff_u1 = 1 / (1 + diff_u1) # Inverse quadratic - diff_v1 = 1 / (1 + diff_v1) + Du, Dv = self._diff_uv(Ix, i=1, j=0, k=2) + a = (Iy * Du).t() + hists[n, int(not self.green_only), :, :] = torch.mm(a, Dv) - diff_u1 = diff_u1.type(torch.float32) - diff_v1 = diff_v1.type(torch.float32) - a = torch.t(Iy * diff_u1) if not self.green_only: - hists[l, 1, :, :] = torch.mm(a, diff_v1) - else: - hists[l, 0, :, :] = torch.mm(a, diff_v1) - - if not self.green_only: - Iu2 = torch.unsqueeze(torch.log(I[:, 2] + EPS) - torch.log(I[:, 0] + - EPS), dim=1) - Iv2 = torch.unsqueeze(torch.log(I[:, 2] + EPS) - torch.log(I[:, 1] + - EPS), dim=1) - diff_u2 = abs( - Iu2 - torch.unsqueeze(torch.tensor(np.linspace( - self.hist_boundary[0], self.hist_boundary[1], num=self.h)), - dim=0).to(self.device)) - diff_v2 = abs( - Iv2 - torch.unsqueeze(torch.tensor(np.linspace( - self.hist_boundary[0], self.hist_boundary[1], num=self.h)), - dim=0).to(self.device)) - if self.method == 'thresholding': - diff_u2 = torch.reshape(diff_u2, (-1, self.h)) <= self.eps / 2 - diff_v2 = torch.reshape(diff_v2, (-1, self.h)) <= self.eps / 2 - elif self.method == 'RBF': - diff_u2 = torch.pow(torch.reshape(diff_u2, (-1, self.h)), - 2) / self.sigma ** 2 - diff_v2 = torch.pow(torch.reshape(diff_v2, (-1, self.h)), - 2) / self.sigma ** 2 - diff_u2 = torch.exp(-diff_u2) # Gaussian - diff_v2 = torch.exp(-diff_v2) - elif self.method == 'inverse-quadratic': - diff_u2 = torch.pow(torch.reshape(diff_u2, (-1, self.h)), - 2) / self.sigma ** 2 - diff_v2 = torch.pow(torch.reshape(diff_v2, (-1, self.h)), - 2) / self.sigma ** 2 - diff_u2 = 1 / (1 + diff_u2) # Inverse quadratic - diff_v2 = 1 / (1 + diff_v2) - diff_u2 = diff_u2.type(torch.float32) - diff_v2 = diff_v2.type(torch.float32) - a = torch.t(Iy * diff_u2) - hists[l, 2, :, :] = torch.mm(a, diff_v2) + Du, Dv = self._diff_uv(Ix, i=2, j=0, k=1) + a = (Iy * Du).t() + hists[n, 2, :, :] = torch.mm(a, Dv) # normalization - hists_normalized = hists / ( - ((hists.sum(dim=1)).sum(dim=1)).sum(dim=1).view(-1, 1, 1, 1) + EPS) - - return hists_normalized \ No newline at end of file + norm = hists.view(-1, C * self.h * self.h).sum(dim=1).view(-1, 1, 1, 1) + hists_normalized = hists / (norm + EPS) + + return hists_normalized + + def _diff_uv(self, X: torch.Tensor, i: int, j: int, k: int): + U = ((X[:, i] + EPS).log() - (X[:, j] + EPS).log()).unsqueeze(dim=1) + V = ((X[:, i] + EPS).log() - (X[:, k] + EPS).log()).unsqueeze(dim=1) + Du = (U - self.delta).abs() + Dv = (V - self.delta).abs() + Du = self.kernel(Du) + Dv = self.kernel(Dv) + return Du, Dv diff --git a/histogram_classes/rgChromaHistBlock.py b/histogram_classes/rgChromaHistBlock.py index 5ac8a07..4c7ddab 100644 --- a/histogram_classes/rgChromaHistBlock.py +++ b/histogram_classes/rgChromaHistBlock.py @@ -15,17 +15,13 @@ year={2021} } #### +Portions Copyright (c) 2022 Patrick Levin. """ - +from histogram_classes.HistBlock import EPS, HistBlock import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np - -EPS = 1e-6 -class rgChromaHistBlock(nn.Module): +class rgChromaHistBlock(HistBlock): def __init__(self, h=64, insz=150, resizing='interpolation', method='inverse-quadratic', sigma=0.02, intensity_scale=False, hist_boundary=None, device='cuda'): @@ -53,95 +49,36 @@ def __init__(self, h=64, insz=150, resizing='interpolation', paper, the 'inverse-quadratic' was found more stable than 'RBF' in our training. """ - super(rgChromaHistBlock, self).__init__() - self.h = h - self.insz = insz - self.device = device - self.resizing = resizing - self.method = method - self.intensity_scale = intensity_scale - if hist_boundary is None: - hist_boundary = [0, 1] - hist_boundary.sort() - self.hist_boundary = hist_boundary - if self.method == 'thresholding': - self.eps = (abs(hist_boundary[0]) + abs(hist_boundary[1])) / h - else: - self.sigma = sigma + super().__init__( + h, insz, resizing, method, sigma, intensity_scale, + hist_boundary or [0, 1], device + ) def forward(self, x): - x = torch.clamp(x, 0, 1) - if x.shape[2] > self.insz or x.shape[3] > self.insz: - if self.resizing == 'interpolation': - x_sampled = F.interpolate(x, size=(self.insz, self.insz), - mode='bilinear', align_corners=False) - elif self.resizing == 'sampling': - inds_1 = torch.LongTensor( - np.linspace(0, x.shape[2], self.h, endpoint=False)).to( - device=self.device) - inds_2 = torch.LongTensor( - np.linspace(0, x.shape[3], self.h, endpoint=False)).to( - device=self.device) - x_sampled = x.index_select(2, inds_1) - x_sampled = x_sampled.index_select(3, inds_2) - else: - raise Exception( - f'Wrong resizing method. It should be: interpolation or sampling. ' - f'But the given value is {self.resizing}.') - else: - x_sampled = x + x_sampled = self.resize(x.clamp(0, 1)) - L = x_sampled.shape[0] # size of mini-batch + N = x_sampled.shape[0] # size of mini-batch if x_sampled.shape[1] > 3: x_sampled = x_sampled[:, :3, :, :] X = torch.unbind(x_sampled, dim=0) - hists = torch.zeros((x_sampled.shape[0], 1, self.h, self.h)).to( - device=self.device) - for l in range(L): - I = torch.t(torch.reshape(X[l], (3, -1))) - II = torch.pow(I, 2) - if self.intensity_scale: - Iy = torch.unsqueeze(torch.sqrt(II[:, 0] + II[:, 1] + II[:, 2] + EPS), - dim=1) - else: - Iy = 1 - - Ir = torch.unsqueeze(I[:, 0] / (torch.sum(I, dim=-1) + EPS), dim=1) - Ig = torch.unsqueeze(I[:, 1] / (torch.sum(I, dim=-1) + EPS), dim=1) - - diff_r = abs(Ir - torch.unsqueeze(torch.tensor(np.linspace( - self.hist_boundary[0], self.hist_boundary[1], num=self.h)), - dim=0).to(self.device)) - diff_g = abs(Ig - torch.unsqueeze(torch.tensor(np.linspace( - self.hist_boundary[0], self.hist_boundary[1], num=self.h)), - dim=0).to(self.device)) - - if self.method == 'thresholding': - diff_r = torch.reshape(diff_r, (-1, self.h)) <= self.eps / 2 - diff_g = torch.reshape(diff_g, (-1, self.h)) <= self.eps / 2 - elif self.method == 'RBF': - diff_r = torch.pow(torch.reshape(diff_r, (-1, self.h)), - 2) / self.sigma ** 2 - diff_g = torch.pow(torch.reshape(diff_g, (-1, self.h)), - 2) / self.sigma ** 2 - diff_r = torch.exp(-diff_r) # Gaussian - diff_g = torch.exp(-diff_g) - elif self.method == 'inverse-quadratic': - diff_r = torch.pow(torch.reshape(diff_r, (-1, self.h)), - 2) / self.sigma ** 2 - diff_g = torch.pow(torch.reshape(diff_g, (-1, self.h)), - 2) / self.sigma ** 2 - diff_r = 1 / (1 + diff_r) # Inverse quadratic - diff_g = 1 / (1 + diff_g) - - diff_r = diff_r.type(torch.float32) - diff_g = diff_g.type(torch.float32) + hists = torch.zeros(N, 1, self.h, self.h, device=self.device) + for n in range(N): + Ix = X[n].reshape(3, -1).t() + Inorm = Ix.sum(dim=-1) + EPS + Ir = (Ix[:, 0] / Inorm).unsqueeze(dim=1) + Ig = (Ix[:, 1] / Inorm).unsqueeze(dim=1) + + diff_r = (Ir - self.delta).abs() + diff_g = (Ig - self.delta).abs() + diff_r = self.kernel(diff_r) + diff_g = self.kernel(diff_g) + Iy = self.scaling(Ix) a = torch.t(Iy * diff_r) - hists[l, 0, :, :] = torch.mm(a, diff_g) + hists[n, 0, :, :] = torch.mm(a, diff_g) # normalization - hists_normalized = hists / ( - ((hists.sum(dim=1)).sum(dim=1)).sum(dim=1).view(-1, 1, 1, 1) + EPS) + norm = hists.view(-1, self.h * self.h).sum(dim=1).view(-1, 1, 1, 1) + EPS + hists_normalized = hists / norm return hists_normalized \ No newline at end of file