Merge pull request #107 from Kymer0615/master

Incoherent propagation focal stack for RGBD images

kaanaksit authored Oct 1, 2024
2 parents 3eb6ec5 + 29b5dc5 commit c2294a2
Showing 7 changed files with 234 additions and 0 deletions.
47 changes: 47 additions & 0 deletions odak/learn/perception/util.py
@@ -13,3 +13,50 @@ def check_loss_inputs(loss_name, image, target):
loss_name + """ ERROR: Inputs should have either 1 or 3 channels
(1 channel for grayscale, 3 for RGB or YCrCb).
Ensure inputs have 1 or 3 channels and are in NCHW format.""")

def slice_rgbd_targets(target, depth, depth_plane_positions):
    """
    Slices the target image and depth map into multiple layers based on depth plane positions.

    Parameters
    ----------
    target                : torch.Tensor
                            The target RGB tensor with shape (C, H, W).
    depth                 : torch.Tensor
                            The depth map corresponding to the target image with shape (H, W).
    depth_plane_positions : list or torch.Tensor
                            The positions of the depth planes used for slicing.

    Returns
    -------
    targets               : torch.Tensor
                            A tensor of shape (N, C, H, W), where N is the number of depth planes, containing the sliced target for each depth plane.
    masks                 : torch.Tensor
                            A tensor of shape (N, C, H, W) containing a binary mask for each depth plane.
    """
    device = target.device
    number_of_planes = len(depth_plane_positions) - 1
    targets = torch.zeros(
                          number_of_planes,
                          target.shape[0],
                          target.shape[1],
                          target.shape[2],
                          requires_grad = False,
                          device = device
                         )
    masks = torch.zeros_like(targets, dtype = torch.int).to(device)
    mask_zeros = torch.zeros_like(depth, dtype = torch.int)
    mask_ones = torch.ones_like(depth, dtype = torch.int)
    for i in range(1, number_of_planes + 1):
        pos = depth_plane_positions[i]
        prev_pos = depth_plane_positions[i - 1]
        if i <= (number_of_planes - 1):
            # Interior bins are half-open so that a depth value lands in exactly one bin.
            condition = torch.logical_and(prev_pos <= depth, depth < pos)
        else:
            # The last bin is closed to include depth values equal to the final plane position.
            condition = torch.logical_and(prev_pos <= depth, depth <= pos)
        mask = torch.where(condition, mask_ones, mask_zeros)
        for ch in range(target.shape[0]):
            targets[i - 1, ch] = (target[ch] * mask).squeeze(0)
            masks[i - 1, ch] = mask.detach().clone()
    return targets, masks
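A minimal usage sketch for slice_rgbd_targets, assuming a random toy image and depth map (the sizes and tensors below are illustrative, not part of this commit):

import torch
from odak.learn.perception.util import slice_rgbd_targets

target = torch.rand(3, 64, 64) # toy RGB image, C x H x W
depth = torch.rand(64, 64) # toy depth map with values in [0, 1)
depth_plane_positions = torch.linspace(0, 1, steps = 5) # four depth bins
targets, masks = slice_rgbd_targets(target, depth, depth_plane_positions)
print(targets.shape) # torch.Size([4, 3, 64, 64])
assert bool((masks.sum(dim = 0) == 1).all()) # each pixel belongs to exactly one bin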
1 change: 1 addition & 0 deletions odak/learn/wave/__init__.py
@@ -14,3 +14,4 @@
from .optimizers import *
from .propagators import *
from .util import *
from .incoherent_propagation import *
80 changes: 80 additions & 0 deletions odak/learn/wave/incoherent_propagation.py
@@ -0,0 +1,80 @@
import torch
from odak.learn.wave import calculate_amplitude, wavenumber, propagate_beam
from odak.learn.perception.color_conversion import rgb_to_linear_rgb, linear_rgb_to_rgb


def incoherent_focal_stack_rgbd(targets, masks, distances, dx, wavelengths, zero_padding = [True, False, True], aperture = 1., alpha = 0.5):
    """
    Generate an incoherent focal stack using RGB-D images.
    The occlusion mechanism is inspired by https://github.com/dongyeon93/holographic-parallax/blob/main/Incoherent_focal_stack.py

    Parameters
    ----------
    targets      : torch.tensor
                   Slices of the targets based on the depth masks.
    masks        : torch.tensor
                   Masks based on the depth maps.
    distances    : list
                   A list of propagation distances.
    dx           : float
                   Size of one single pixel in the field grid (in meters).
    wavelengths  : list
                   A list of wavelengths.
    zero_padding : list
                   Zero padding flags for the Fourier domain propagation.
    aperture     : torch.tensor
                   Fourier domain aperture (e.g., pinhole in a typical holographic display).
                   The default is one, but an aperture could be as large as input field [m x n].
    alpha        : float
                   Parameter to control how much the occlusion mask from the previous layer contributes to the current layer's occlusion when computing the focal stack.

    Returns
    -------
    focal_stack  : torch.tensor
                   Incoherent focal stack with shape (N, C, H, W), normalized by its maximum value.
    """
    device = targets.device
    number_of_planes, number_of_channels, nu, nv = targets.shape
    focal_stack = torch.zeros_like(targets, dtype = torch.float32).to(device)
    for ch, wavelength in enumerate(wavelengths):
        for n in range(number_of_planes):
            plane_sum = torch.zeros(nu, nv).to(device)
            occlusion_masks = torch.zeros(number_of_planes, nu, nv).to(device)
            for k in range(number_of_planes):
                # Signed distance from layer k to the focal plane n.
                distance = distances[n] - distances[k]
                mask_k = masks[k]
                # Propagate the binary mask of layer k to plane n to estimate its occlusion footprint there.
                propagated_mask = propagate_beam(
                                                 field = mask_k,
                                                 k = wavenumber(wavelength),
                                                 distance = distance,
                                                 dx = dx,
                                                 wavelength = wavelength,
                                                 propagation_type = 'Incoherent Angular Spectrum',
                                                 zero_padding = zero_padding,
                                                 aperture = aperture
                                                )
                propagated_mask = calculate_amplitude(propagated_mask)
                propagated_mask = torch.mean(propagated_mask, dim = 0)
                # Invert the normalized footprint so that regions covered by layer k suppress previously accumulated light (the ternary guards against an all-zero mask).
                occlusion_mask = 1.0 - propagated_mask / (propagated_mask.max() if propagated_mask.max() else 1e-12)
                occlusion_masks[k, :, :] = torch.nan_to_num(occlusion_mask, 1.0)
                target = targets[k, ch]
                propagated_target = propagate_beam(
                                                   field = target,
                                                   k = wavenumber(wavelength),
                                                   distance = distance,
                                                   dx = dx,
                                                   wavelength = wavelength,
                                                   propagation_type = 'Incoherent Angular Spectrum',
                                                   zero_padding = zero_padding,
                                                   aperture = aperture
                                                  )
                propagated_target = calculate_amplitude(propagated_target)
                # Composite layer k into the running sum for plane n: attenuate what has
                # accumulated so far by a blend of the current and previous occlusion masks,
                # then add the propagated layer.
                if k == 0:
                    plane_sum = (1. * occlusion_mask) * plane_sum + propagated_target
                elif k == (number_of_planes - 1):
                    prev_occlusion_mask = occlusion_masks[k - 1]
                    plane_sum = (alpha * occlusion_mask + (1.0 - alpha) * prev_occlusion_mask) * plane_sum + alpha * propagated_target
                else:
                    prev_occlusion_mask = occlusion_masks[k - 1]
                    plane_sum = (alpha * occlusion_mask + (1.0 - alpha) * prev_occlusion_mask) * plane_sum + propagated_target
            focal_stack[n, ch, :, :] = plane_sum
    return focal_stack / focal_stack.max()
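A condensed end-to-end sketch combining slice_rgbd_targets with incoherent_focal_stack_rgbd, assuming random tensors in place of the sample images (the sizes, the 0.0005 multiplier, and the pixel pitch are illustrative; the test files below exercise the real data):

import torch
from odak.learn.perception.util import slice_rgbd_targets
from odak.learn.wave import incoherent_focal_stack_rgbd

target = torch.rand(3, 128, 128) # toy RGB image
depth = torch.rand(128, 128) # toy depth map
depth_plane_positions = torch.linspace(0, 1, steps = 5)
targets, masks = slice_rgbd_targets(target, depth, depth_plane_positions)
distances = depth_plane_positions * 0.0005 # plane positions scaled to distances in meters
focal_stack = incoherent_focal_stack_rgbd(
                                          targets = targets,
                                          masks = masks,
                                          distances = distances,
                                          dx = 3.74e-6,
                                          wavelengths = [639e-9, 515e-9, 473e-9]
                                         )
print(focal_stack.shape) # torch.Size([4, 3, 128, 128])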
Binary file added test/data/sample_depthmap.png
Binary file added test/data/sample_rgb.png
48 changes: 48 additions & 0 deletions test/test_learn_perception_util_slice_rgbd_targets.py
@@ -0,0 +1,48 @@
import torch
import sys

from os.path import join
from odak.learn.perception.util import slice_rgbd_targets
from odak.learn.tools import load_image, save_image
from odak.tools import check_directory


def test(device = torch.device('cpu'), output_directory = 'test_output'):
    check_directory(output_directory)

    target = load_image(
                        "test/data/sample_rgb.png",
                        normalizeby = 256,
                        torch_style = True
                       ).to(device)

    depth = load_image(
                       "test/data/sample_depthmap.png",
                       normalizeby = 256,
                       torch_style = True
                      ).to(device)
    depth = torch.mean(depth, dim = 0) # Ensure the depth map has the shape of [H x W].

    depth_plane_positions = torch.linspace(0, 1, steps = 5).to(device)
    targets, masks = slice_rgbd_targets(target, depth, depth_plane_positions)
    depth_slices_sum = torch.zeros_like(target)
    for idx, target_slice in enumerate(targets):
        depth_slices_sum += masks[idx]
        save_image(
                   join(output_directory, f"target_{idx}.png"),
                   target_slice,
                   cmin = target_slice.min(),
                   cmax = target_slice.max()
                  )
        save_image(
                   join(output_directory, f"depth_{idx}.png"),
                   masks[idx],
                   cmin = masks[idx].min(),
                   cmax = masks[idx].max()
                  )
    print(depth_slices_sum.mean().item())
    assert depth_slices_sum.mean().item() == 1. # The mean of the depth slices sum should be 1.


if __name__ == '__main__':
    sys.exit(test())
58 changes: 58 additions & 0 deletions test/test_learn_wave_incoherent_focal_stack_rgbd.py
@@ -0,0 +1,58 @@
import torch
import sys

from os.path import join
from odak.tools import check_directory
from odak.learn.tools import load_image, save_image
from odak.learn.perception.util import slice_rgbd_targets
from odak.learn.wave import incoherent_focal_stack_rgbd
from odak.learn.perception import rgb_to_linear_rgb, linear_rgb_to_rgb


def test(device = torch.device('cpu'), output_directory = 'test_output'):
    check_directory(output_directory)

    target = load_image(
                        "test/data/sample_rgb.png",
                        normalizeby = 256,
                        torch_style = True
                       ).to(device)
    target = rgb_to_linear_rgb(target).squeeze(0)

    depth = load_image(
                       "test/data/sample_depthmap.png",
                       normalizeby = 256,
                       torch_style = True
                      ).to(device)
    depth = torch.mean(depth, dim = 0).to(device) # Ensure the depth map has the shape of [H x W].
    depth_plane_positions = torch.linspace(0, 1, steps = 5).to(device)

    wavelengths = torch.tensor([639e-9, 515e-9, 473e-9], dtype = torch.float32).to(device)
    pixel_pitch = 3.74e-6

    targets, masks = slice_rgbd_targets(target, depth, depth_plane_positions)
    distances = depth_plane_positions * 0.0005 # Scale the plane positions to propagation distances; the multiplier controls the amount of defocus blur.

    focal_stack = incoherent_focal_stack_rgbd(
                                              targets = targets,
                                              masks = masks,
                                              distances = distances,
                                              dx = pixel_pitch,
                                              wavelengths = wavelengths,
                                             )
    for idx, focal_image in enumerate(focal_stack):
        min_value = focal_image.min()
        max_value = focal_image.max()
        focal_image = (focal_image - min_value) / (max_value - min_value)
        focal_image = linear_rgb_to_rgb(focal_image)
        save_image(
                   join(output_directory, f"focal_target_{idx}.png"),
                   focal_image,
                   cmin = 0,
                   cmax = 1
                  )
    assert True == True


if __name__ == '__main__':
    sys.exit(test())
