Skip to content

Commit

Permalink
Merge branch 'Kymer0615-master'
Browse files Browse the repository at this point in the history
  • Loading branch information
kaanaksit committed Oct 15, 2024
2 parents 2cdc988 + 77e5348 commit 272e8bb
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 16 deletions.
12 changes: 6 additions & 6 deletions odak/learn/wave/incoherent_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from odak.learn.perception.color_conversion import rgb_to_linear_rgb, linear_rgb_to_rgb


def incoherent_focal_stack_rgbd(targets, masks, distances, dx, wavelengths, zero_padding = [True, False, True], aperture = 1., alpha = 0.5):
def incoherent_focal_stack_rgbd(targets, masks, distances, dx, wavelengths, zero_padding = [True, False, True], apertures = [1., 1., 1.], alpha = 0.5):
"""
Generate incoherent focal stack using RGB-D images.
Please note that the targets and masks should be ordered from the furthest to the closest.
The occlusion mechanism is inspired by https://github.com/dongyeon93/holographic-parallax/blob/main/Incoherent_focal_stack.py
Parameters
Expand All @@ -22,9 +23,8 @@ def incoherent_focal_stack_rgbd(targets, masks, distances, dx, wavelengths, zero
A list of wavelengths.
zero_padding : list of bool
Zero pad in Fourier domain.
aperture : torch.tensor
Fourier domain aperture (e.g., pinhole in a typical holographic display).
The default is one, but an aperture could be as large as input field [m x n].
apertures : list of float or torch.tensor
Fourier domain apertures (e.g., pinhole in a typical holographic display), one per color channel.
alpha : float
Parameter to control how much the occlusion mask from the previous layer contributes to the current layer's occlusion when computing the focal stack.
"""
Expand All @@ -49,7 +49,7 @@ def incoherent_focal_stack_rgbd(targets, masks, distances, dx, wavelengths, zero
wavelength = wavelength,
propagation_type = 'Incoherent Angular Spectrum',
zero_padding = zero_padding,
aperture = aperture
aperture = apertures[ch]
)
propagated_mask = calculate_amplitude(propagated_mask)
propagated_mask = torch.mean(propagated_mask, dim = 0)
Expand All @@ -64,7 +64,7 @@ def incoherent_focal_stack_rgbd(targets, masks, distances, dx, wavelengths, zero
wavelength = wavelength,
propagation_type = 'Incoherent Angular Spectrum',
zero_padding = zero_padding,
aperture = aperture
aperture = apertures[ch]
)
propagated_target = calculate_amplitude(propagated_target)
if k == 0:
Expand Down
38 changes: 28 additions & 10 deletions test/test_learn_wave_incoherent_focal_stack_rgbd.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
import sys

from os.path import join
from odak.tools import check_directory
from odak.learn.tools import load_image, save_image
Expand All @@ -27,29 +26,48 @@ def test(device = torch.device('cpu'), output_directory = 'test_output'):
depth = torch.mean(depth, dim = 0).to(device) # Ensure the depthmap has the shape of [w x h]
depth_plane_positions = torch.linspace(0, 1, steps=5).to(device)

wavelengths = torch.tensor([639e-9, 515e-9, 473e-9], dtype = torch.float32).to(device)
pixel_pitch = 3.74e-6
wavelengths = torch.tensor([638e-9, 520e-9, 450e-9], dtype = torch.float32).to(device)
pixel_pitch = 8.2e-06
apertures = [1., 1., 1.]

targets, masks = slice_rgbd_targets(target, depth, depth_plane_positions)
distances = depth_plane_positions * 0.0005 # Multiply with some multiplier to control the blurriness
distances = depth_plane_positions * 0.001 # Multiply with some multiplier to control the blurriness

for idx, mask in enumerate(masks):
save_image(
join(output_directory, f"mask_{idx}.png"),
mask,
cmin = 0,
cmax = 1
)

for idx, target in enumerate(targets):
save_image(
join(output_directory, f"target_{idx}.png"),
target,
cmin = 0,
cmax = 1
)

# Please note that the targets and masks are reversed here to maintain the order from the furthest to the closest, ensuring the overlay rendering is accurate.
focal_stack = incoherent_focal_stack_rgbd(
targets = targets,
masks = masks,
targets = targets.flip(dims=(0,)),
masks = masks.flip(dims=(0,)),
distances = distances,
dx = pixel_pitch,
wavelengths = wavelengths,
apertures = apertures
)
focal_stack = focal_stack.flip(dims=(0,))
for idx, focal_image in enumerate(focal_stack):
focal_image = linear_rgb_to_rgb(focal_image)
min_value = focal_image.min()
max_value = focal_image.max()
focal_image = (focal_image - min_value) / (max_value - min_value)
focal_image = linear_rgb_to_rgb(focal_image)
save_image(
join(output_directory, f"focal_target_{idx}.png"),
focal_image,
cmin = 0,
cmax = 1
cmin = min_value,
cmax = max_value
)
assert True == True

Expand Down

0 comments on commit 272e8bb

Please sign in to comment.