From 2ae695d4196cde03b0662ff26563cade7f0d4d15 Mon Sep 17 00:00:00 2001 From: Kymer0615 Date: Tue, 15 Oct 2024 14:41:40 +0100 Subject: [PATCH] fix incoherent_focal_stack_rgbd with aperture -> apertures for each of the color channels; fix the corresponding test file with correct targets and masks order --- odak/learn/wave/incoherent_propagation.py | 12 +++--- ..._learn_wave_incoherent_focal_stack_rgbd.py | 38 ++++++++++++++----- 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/odak/learn/wave/incoherent_propagation.py b/odak/learn/wave/incoherent_propagation.py index 389c52a2..7f93fa43 100644 --- a/odak/learn/wave/incoherent_propagation.py +++ b/odak/learn/wave/incoherent_propagation.py @@ -3,9 +3,10 @@ from odak.learn.perception.color_conversion import rgb_to_linear_rgb, linear_rgb_to_rgb -def incoherent_focal_stack_rgbd(targets, masks, distances, dx, wavelengths, zero_padding = [True, False, True], aperture = 1., alpha = 0.5): +def incoherent_focal_stack_rgbd(targets, masks, distances, dx, wavelengths, zero_padding = [True, False, True], apertures = [1., 1., 1.], alpha = 0.5): """ Generate incoherent focal stack using RGB-D images. + Please note that the targets and masks should follow the order from the furthest to the closest. The occlusion mechanism is inspired from https://github.com/dongyeon93/holographic-parallax/blob/main/Incoherent_focal_stack.py Parameters @@ -22,9 +23,8 @@ def incoherent_focal_stack_rgbd(targets, masks, distances, dx, wavelengths, zero A list of wavelengths. zero_padding : bool Zero pad in Fourier domain. - aperture : torch.tensor - Fourier domain aperture (e.g., pinhole in a typical holographic display). - The default is one, but an aperture could be as large as input field [m x n]. + apertures : torch.tensor + Fourier domain apertures (e.g., pinhole in a typical holographic display) for each color channel. 
alpha : float Parameter to control how much the occlusion mask from the previous layer contributes to the current layer's occlusion when computing the focal stack. """ @@ -49,7 +49,7 @@ def incoherent_focal_stack_rgbd(targets, masks, distances, dx, wavelengths, zero wavelength = wavelength, propagation_type = 'Incoherent Angular Spectrum', zero_padding = zero_padding, - aperture = aperture + aperture = apertures[ch] ) propagated_mask = calculate_amplitude(propagated_mask) propagated_mask = torch.mean(propagated_mask, dim = 0) @@ -64,7 +64,7 @@ def incoherent_focal_stack_rgbd(targets, masks, distances, dx, wavelengths, zero wavelength = wavelength, propagation_type = 'Incoherent Angular Spectrum', zero_padding = zero_padding, - aperture = aperture + aperture = apertures[ch] ) propagated_target = calculate_amplitude(propagated_target) if k == 0: diff --git a/test/test_learn_wave_incoherent_focal_stack_rgbd.py b/test/test_learn_wave_incoherent_focal_stack_rgbd.py index dd853487..1fb0eeab 100644 --- a/test/test_learn_wave_incoherent_focal_stack_rgbd.py +++ b/test/test_learn_wave_incoherent_focal_stack_rgbd.py @@ -1,6 +1,5 @@ import torch import sys - from os.path import join from odak.tools import check_directory from odak.learn.tools import load_image, save_image @@ -27,29 +26,48 @@ def test(device = torch.device('cpu'), output_directory = 'test_output'): depth = torch.mean(depth, dim = 0).to(device) # Ensure the depthmap has the shape of [w x h] depth_plane_positions = torch.linspace(0, 1, steps=5).to(device) - wavelengths = torch.tensor([639e-9, 515e-9, 473e-9], dtype = torch.float32).to(device) - pixel_pitch = 3.74e-6 + wavelengths = torch.tensor([638e-9, 520e-9, 450e-9], dtype = torch.float32).to(device) + pixel_pitch = 8.2e-06 + apertures = [1., 1., 1.] 
targets, masks = slice_rgbd_targets(target, depth, depth_plane_positions) - distances = depth_plane_positions * 0.0005 # Multiply with some multiplier to control the blurriness + distances = depth_plane_positions * 0.001 # Multiply with some multiplier to control the blurriness + + for idx, mask in enumerate(masks): + save_image( + join(output_directory, f"mask_{idx}.png"), + mask, + cmin = 0, + cmax = 1 + ) + + for idx, target in enumerate(targets): + save_image( + join(output_directory, f"target_{idx}.png"), + target, + cmin = 0, + cmax = 1 + ) + # Please note that the targets and masks are reversed here to maintain the order from the furthest to the closest, ensuring the overlay rendering is accurate. focal_stack = incoherent_focal_stack_rgbd( - targets = targets, - masks = masks, + targets = targets.flip(dims=(0,)), + masks = masks.flip(dims=(0,)), distances = distances, dx = pixel_pitch, wavelengths = wavelengths, + apertures = apertures ) + focal_stack = focal_stack.flip(dims=(0,)) for idx, focal_image in enumerate(focal_stack): + focal_image = linear_rgb_to_rgb(focal_image) min_value = focal_image.min() max_value = focal_image.max() - focal_image = (focal_image - min_value) / (max_value - min_value) - focal_image = linear_rgb_to_rgb(focal_image) save_image( join(output_directory, f"focal_target_{idx}.png"), focal_image, - cmin = 0, - cmax = 1 + cmin = min_value, + cmax = max_value ) assert True == True