-
-
Notifications
You must be signed in to change notification settings - Fork 53
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #107 from Kymer0615/master
Incoherent propagation focal stack for RGBD images
- Loading branch information
Showing
7 changed files
with
234 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,3 +14,4 @@ | |
from .optimizers import * | ||
from .propagators import * | ||
from .util import * | ||
from .incoherent_propagation import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import torch | ||
from odak.learn.wave import calculate_amplitude, wavenumber, propagate_beam | ||
from odak.learn.perception.color_conversion import rgb_to_linear_rgb, linear_rgb_to_rgb | ||
|
||
|
||
def incoherent_focal_stack_rgbd(targets, masks, distances, dx, wavelengths, zero_padding = [True, False, True], aperture = 1., alpha = 0.5):
    """
    Generate an incoherent focal stack from RGB-D images.
    The occlusion mechanism is inspired from https://github.com/dongyeon93/holographic-parallax/blob/main/Incoherent_focal_stack.py

    Parameters
    ----------
    targets            : torch.tensor
                         Slices of the targets based on the depth masks [number of planes x number of channels x m x n].
    masks              : torch.tensor
                         Masks based on the depthmaps.
    distances          : list
                         A list of propagation distances (one per plane).
    dx                 : float
                         Size of one single pixel in the field grid (in meters).
    wavelengths        : list
                         A list of wavelengths (one per color channel).
    zero_padding       : list
                         Zero padding flags handed to the propagator (pad, propagate, unpad).
                         The default is [True, False, True].
    aperture           : torch.tensor
                         Fourier domain aperture (e.g., pinhole in a typical holographic display).
                         The default is one, but an aperture could be as large as input field [m x n].
    alpha              : float
                         Parameter to control how much the occlusion mask from the previous layer contributes
                         to the current layer's occlusion when computing the focal stack.

    Returns
    -------
    focal_stack        : torch.tensor
                         Incoherent focal stack with the same shape as `targets`, normalized by its global maximum.
    """
    device = targets.device
    number_of_planes, number_of_channels, nu, nv = targets.shape
    focal_stack = torch.zeros_like(targets, dtype = torch.float32).to(device)
    for ch, wavelength in enumerate(wavelengths):
        for n in range(number_of_planes):
            plane_sum = torch.zeros(nu, nv).to(device)
            occlusion_masks = torch.zeros(number_of_planes, nu, nv).to(device)
            for k in range(number_of_planes):
                distance = distances[n] - distances[k]
                mask_k = masks[k]
                propagated_mask = propagate_beam(
                                                 field = mask_k,
                                                 k = wavenumber(wavelength),
                                                 distance = distance,
                                                 dx = dx,
                                                 wavelength = wavelength,
                                                 propagation_type = 'Incoherent Angular Spectrum',
                                                 zero_padding = zero_padding,
                                                 aperture = aperture
                                                )
                propagated_mask = calculate_amplitude(propagated_mask)
                propagated_mask = torch.mean(propagated_mask, dim = 0)
                # Hoist the reduction: compute the peak once instead of twice per iteration.
                peak = propagated_mask.max()
                # Normalize and invert to get occlusion; guard against a zero peak.
                occlusion_mask = 1.0 - propagated_mask / (peak if peak else 1e-12)
                occlusion_masks[k, :, :] = torch.nan_to_num(occlusion_mask, 1.0)
                target = targets[k, ch]
                propagated_target = propagate_beam(
                                                   field = target,
                                                   k = wavenumber(wavelength),
                                                   distance = distance,
                                                   dx = dx,
                                                   wavelength = wavelength,
                                                   propagation_type = 'Incoherent Angular Spectrum',
                                                   zero_padding = zero_padding,
                                                   aperture = aperture
                                                  )
                propagated_target = calculate_amplitude(propagated_target)
                # Accumulate planes front-to-back, attenuating earlier content by the
                # occlusion of the current layer (blended with the previous layer's mask).
                if k == 0:
                    plane_sum = (1. * occlusion_mask) * plane_sum + propagated_target
                elif k == (number_of_planes - 1):
                    prev_occlusion_mask = occlusion_masks[k - 1]
                    plane_sum = (alpha * occlusion_mask + (1.0 - alpha) * prev_occlusion_mask) * plane_sum + alpha * propagated_target
                else:
                    prev_occlusion_mask = occlusion_masks[k - 1]
                    plane_sum = (alpha * occlusion_mask + (1.0 - alpha) * prev_occlusion_mask) * plane_sum + propagated_target
            focal_stack[n, ch, :, :] = plane_sum
    return focal_stack / focal_stack.max()
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import torch | ||
import sys | ||
|
||
from os.path import join | ||
from odak.learn.perception.util import slice_rgbd_targets | ||
from odak.learn.tools import load_image, save_image | ||
from odak.tools import check_directory | ||
|
||
|
||
def test(device = torch.device('cpu'), output_directory = 'test_output'):
    """
    Slice an RGB-D sample into depth planes and verify that the produced
    masks partition the image (their sum is one everywhere).
    """
    check_directory(output_directory)
    target = load_image(
                        "test/data/sample_rgb.png",
                        normalizeby = 256,
                        torch_style = True
                       ).to(device)
    depth = load_image(
                       "test/data/sample_depthmap.png",
                       normalizeby = 256,
                       torch_style = True
                      ).to(device)
    depth = torch.mean(depth, dim = 0) # Collapse channels so the depthmap has the shape of [w x h].
    depth_plane_positions = torch.linspace(0, 1, steps = 5).to(device)
    targets, masks = slice_rgbd_targets(target, depth, depth_plane_positions)
    masks_sum = torch.zeros_like(target) # Accumulates the depth masks, not the sliced targets.
    # Use a distinct loop variable so the outer `target` image is not shadowed.
    for idx, target_slice in enumerate(targets):
        masks_sum += masks[idx]
        save_image(
                   join(output_directory, f"target_{idx}.png"),
                   target_slice,
                   cmin = target_slice.min(),
                   cmax = target_slice.max()
                  )
        save_image(
                   join(output_directory, f"depth_{idx}.png"),
                   masks[idx],
                   cmin = masks[idx].min(),
                   cmax = masks[idx].max()
                  )
    print(masks_sum.mean().item())
    # The mean of the summed masks should be 1; avoid exact float equality.
    assert abs(masks_sum.mean().item() - 1.) < 1e-6
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import torch | ||
import sys | ||
|
||
from os.path import join | ||
from odak.tools import check_directory | ||
from odak.learn.tools import load_image, save_image | ||
from odak.learn.perception.util import slice_rgbd_targets | ||
from odak.learn.wave import incoherent_focal_stack_rgbd | ||
from odak.learn.perception import rgb_to_linear_rgb, linear_rgb_to_rgb | ||
|
||
|
||
def test(device = torch.device('cpu'), output_directory = 'test_output'):
    """
    Build an incoherent focal stack from an RGB-D sample, save each focal
    image, and verify the stack matches the sliced targets in shape.
    """
    check_directory(output_directory)
    target = load_image(
                        "test/data/sample_rgb.png",
                        normalizeby = 256,
                        torch_style = True
                       ).to(device)
    target = rgb_to_linear_rgb(target).squeeze(0)
    depth = load_image(
                       "test/data/sample_depthmap.png",
                       normalizeby = 256,
                       torch_style = True
                      ).to(device)
    depth = torch.mean(depth, dim = 0).to(device) # Collapse channels so the depthmap has the shape of [w x h].
    depth_plane_positions = torch.linspace(0, 1, steps = 5).to(device)
    wavelengths = torch.tensor([639e-9, 515e-9, 473e-9], dtype = torch.float32).to(device)
    pixel_pitch = 3.74e-6
    targets, masks = slice_rgbd_targets(target, depth, depth_plane_positions)
    distances = depth_plane_positions * 0.0005 # Multiply with some multiplier to control the blurriness.
    focal_stack = incoherent_focal_stack_rgbd(
                                              targets = targets,
                                              masks = masks,
                                              distances = distances,
                                              dx = pixel_pitch,
                                              wavelengths = wavelengths,
                                             )
    for idx, focal_image in enumerate(focal_stack):
        min_value = focal_image.min()
        max_value = focal_image.max()
        focal_image = (focal_image - min_value) / (max_value - min_value)
        focal_image = linear_rgb_to_rgb(focal_image)
        save_image(
                   join(output_directory, f"focal_target_{idx}.png"),
                   focal_image,
                   cmin = 0,
                   cmax = 1
                  )
    # A vacuous `assert True == True` verifies nothing; check the stack shape instead.
    assert focal_stack.shape == targets.shape