diff --git a/.gitignore b/.gitignore
index 17247de2..b6e53082 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,6 @@
+# Unit tests
+output/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
diff --git a/odak/learn/wave/classical.py b/odak/learn/wave/classical.py
index b453ed0d..5f1a2a23 100644
--- a/odak/learn/wave/classical.py
+++ b/odak/learn/wave/classical.py
@@ -67,6 +67,8 @@ def propagate_beam(
         result = band_limited_angular_spectrum(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)
     elif propagation_type == 'Impulse Response Fresnel':
         result = impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture, scale = scale, samples = samples)
+    elif propagation_type == 'Seperable Impulse Response Fresnel':
+        result = seperable_impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture, scale = scale, samples = samples)
     elif propagation_type == 'Transfer Function Fresnel':
         result = transfer_function_fresnel(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)
     elif propagation_type == 'custom':
@@ -124,7 +126,8 @@ def get_propagation_kernel(
     -------
     kernel             : torch.tensor
                          Complex kernel for the given propagation type.
-    """
+    """
+    logging.warning('Requested propagation kernel size for %s method with %s m distance, %s m pixel pitch, %s m wavelength, %s x %s resolutions, x%s scale and %s samples.', propagation_type, distance, dx, wavelength, nu, nv, scale, samples)
     if propagation_type == 'Bandlimited Angular Spectrum':
         kernel = get_band_limited_angular_spectrum_kernel(
                                                           nu = nu,
                                                           nv = nv,
                                                           dx = dx,
                                                           wavelength = wavelength,
                                                           distance = distance,
                                                           device = device
                                                          )
@@ -172,6 +175,17 @@ def get_propagation_kernel(
                                                      distance = distance,
                                                      device = device
                                                     )
+    elif propagation_type == 'Seperable Impulse Response Fresnel':
+        kernel, _, _, _ = get_seperable_impulse_response_fresnel_kernel(
+                                                                        nu = nu,
+                                                                        nv = nv,
+                                                                        dx = dx,
+                                                                        wavelength = wavelength,
+                                                                        distance = distance,
+                                                                        device = device,
+                                                                        scale = scale,
+                                                                        aperture_samples = samples
+                                                                       )
     else:
         logging.warning('Propagation type not recognized')
         assert True == False
@@ -380,7 +394,7 @@ def get_impulse_response_fresnel_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9,
 
     Returns
     -------
-    H                  : float
+    H                  : torch.complex64
                          Complex kernel in Fourier domain.
     """
     k = wavenumber(wavelength)
@@ -404,6 +418,202 @@ def get_impulse_response_fresnel_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9,
     return H
+
+
+def get_seperable_impulse_response_fresnel_kernel(
+                                                  nu,
+                                                  nv,
+                                                  dx = 3.74e-6,
+                                                  wavelength = 515e-9,
+                                                  distance = 0.,
+                                                  scale = 1,
+                                                  aperture_samples = [50, 50, 5, 5],
+                                                  device = torch.device('cpu')
+                                                 ):
+    """
+    Returns the impulse response Fresnel kernel in separable form.
+
+    Parameters
+    ----------
+    nu                 : int
+                         Resolution at X axis in pixels.
+    nv                 : int
+                         Resolution at Y axis in pixels.
+    dx                 : float
+                         Pixel pitch in meters.
+    wavelength         : float
+                         Wavelength in meters.
+    distance           : float
+                         Distance in meters.
+    scale              : int
+                         Scale with respect to nu and nv (e.g., scale = 2 leads to 2 x nu and 2 x nv resolution for H).
+    aperture_samples   : list
+                         Number of samples to represent a rectangular pixel. The first two are for the XY of hologram plane pixels, and the last two are for image plane pixels.
+    device             : torch.device
+                         Device, for more see torch.device().
+
+    Returns
+    -------
+    H                  : torch.complex64
+                         Complex kernel in Fourier domain.
+    h                  : torch.complex64
+                         Complex kernel in spatial domain.
+    h_x                : torch.complex64
+                         1D complex kernel in spatial domain along X axis.
+    h_y                : torch.complex64
+                         1D complex kernel in spatial domain along Y axis.
+ """ + k = wavenumber(wavelength) + distance = torch.as_tensor(distance, device = device) + length_x, length_y = ( + torch.tensor(dx * nu, device = device), + torch.tensor(dx * nv, device = device) + ) + x = torch.linspace(- length_x / 2., length_x / 2., nu * scale, device = device) + y = torch.linspace(- length_y / 2., length_y / 2., nv * scale, device = device) + wxs = torch.linspace(- dx / 2., dx / 2., aperture_samples[0], device = device).unsqueeze(0).unsqueeze(0) + wys = torch.linspace(- dx / 2., dx / 2., aperture_samples[1], device = device).unsqueeze(0).unsqueeze(-1) + pxs = torch.linspace(- dx / 2., dx / 2., aperture_samples[2], device = device).unsqueeze(0).unsqueeze(-1) + pys = torch.linspace(- dx / 2., dx / 2., aperture_samples[3], device = device).unsqueeze(0).unsqueeze(0) + wxs = (wxs - pxs).reshape(1, -1).unsqueeze(-1) + wys = (wys - pys).reshape(1, -1).unsqueeze(1) + + X = x.unsqueeze(-1).unsqueeze(-1) + Y = y[y.shape[0] // 2].unsqueeze(-1).unsqueeze(-1) + r_x = (X + wxs) ** 2 + r_y = (Y + wys) ** 2 + r = r_x + r_y + h_x = torch.exp(1j * k / (2 * distance) * r) + h_x = torch.sum(h_x, axis = (1, 2)) + + if nu != nv: + X = x[x.shape[0] // 2].unsqueeze(-1).unsqueeze(-1) + Y = y.unsqueeze(-1).unsqueeze(-1) + r_x = (X + wxs) ** 2 + r_y = (Y + wys) ** 2 + r = r_x + r_y + h_y = torch.exp(1j * k * r / (2 * distance)) + h_y = torch.sum(h_y, axis = (1, 2)) + else: + h_y = h_x.detach().clone() + h = torch.exp(1j * k * distance) / (1j * wavelength * distance) * h_x.unsqueeze(1) * h_y.unsqueeze(0) + H = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(h))) * dx ** 2 / aperture_samples[0] / aperture_samples[1] / aperture_samples[2] / aperture_samples[3] + return H, h, h_x, h_y + + +def get_point_wise_impulse_response_fresnel_kernel( + aperture_points, + aperture_field, + target_points, + resolution, + resolution_factor = 1, + wavelength = 515e-9, + distance = 0., + randomization = False, + device = torch.device('cpu') + ): + """ + This function is a freeform point spread function calculation routine for an aperture defined with a complex field, `aperture_field`, and locations in space, `aperture_points`. + The point spread function is calculated over provided points, `target_points`. + The final result is reshaped to follow the provided `resolution`. + + Parameters + ---------- + aperture_points : torch.tensor + Points representing an aperture in Euler space (XYZ) [m x 3]. + aperture_field : torch.tensor + Complex field for each point provided by `aperture_points` [1 x m]. + target_points : torch.tensor + Target points where the propagated field will be calculated [n x 1]. + resolution : list + Final resolution that the propagated field will be reshaped [X x Y]. + resolution_factor : int + Scale with respect to `resolution` (e.g., scale = 2 leads to `2 x resolution` for the final complex field. + wavelength : float + Wavelength in meters. + randomization : bool + If set `True`, this will help generate a noisy response roughly approximating a real life case, where imperfections occur. + distance : float + Distance in meters. + + Returns + ------- + h : float + Complex field in spatial domain. 
+ """ + device = aperture_field.device + k = wavenumber(wavelength) + if randomization: + pp = [ + aperture_points[:, 0].max() - aperture_points[:, 0].min(), + aperture_points[:, 1].max() - aperture_points[:, 1].min() + ] + target_points[:, 0] = target_points[:, 0] - torch.randn(target_points[:, 0].shape) * pp[0] + target_points[:, 1] = target_points[:, 1] - torch.randn(target_points[:, 1].shape) * pp[1] + deltaX = aperture_points[:, 0].unsqueeze(0) - target_points[:, 0].unsqueeze(-1) + deltaY = aperture_points[:, 1].unsqueeze(0) - target_points[:, 1].unsqueeze(-1) + r = deltaX ** 2 + deltaY ** 2 + h = torch.exp(1j * k / (2 * distance) * r) * aperture_field + h = torch.sum(h, dim = 1).reshape(resolution[0] * resolution_factor, resolution[1] * resolution_factor) + h = 1. / (1j * wavelength * distance) * h + return h + + +def seperable_impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1., scale = 1, samples = [20, 20, 5, 5]): + """ + A definition to calculate convolution based Fresnel approximation for beam propagation for a rectangular aperture using the seperable property. + + Parameters + ---------- + field : torch.complex + Complex field (MxN). + k : odak.wave.wavenumber + Wave number of a wave, see odak.wave.wavenumber for more. + distance : float + Propagation distance. + dx : float + Size of one single pixel in the field grid (in meters). + wavelength : float + Wavelength of the electric field. + zero_padding : bool + Zero pad in Fourier domain. + aperture : torch.tensor + Fourier domain aperture (e.g., pinhole in a typical holographic display). + The default is one, but an aperture could be as large as input field [m x n]. + scale : int + Resolution factor to scale generated kernel. + samples : list + When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. First two is for hologram plane pixel and the last two is for image plane pixel. + + Returns + ------- + result : torch.complex + Final complex field (MxN). + + """ + H = get_propagation_kernel( + nu = field.shape[-2], + nv = field.shape[-1], + dx = dx, + wavelength = wavelength, + distance = distance, + propagation_type = 'Seperable Impulse Response Fresnel', + device = field.device, + scale = scale, + samples = samples + ) + if scale > 1: + field_amplitude = calculate_amplitude(field) + field_phase = calculate_phase(field) + field_scale_amplitude = torch.zeros(field.shape[-2] * scale, field.shape[-1] * scale, device = field.device) + field_scale_phase = torch.zeros_like(field_scale_amplitude) + field_scale_amplitude[::scale, ::scale] = field_amplitude + field_scale_phase[::scale, ::scale] = field_phase + field_scale = generate_complex_field(field_scale_amplitude, field_scale_phase) + else: + field_scale = field + result = custom(field_scale, H, zero_padding = zero_padding, aperture = aperture) + return result + + def impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1., scale = 1, samples = [20, 20, 5, 5]): """ A definition to calculate convolution based Fresnel approximation for beam propagation. @@ -483,7 +693,7 @@ def get_transfer_function_fresnel_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, Returns ------- - H : float + H : torch.complex64 Complex kernel in Fourier domain. 
""" distance = torch.tensor([distance]).to(device) @@ -689,7 +899,14 @@ def incoherent_angular_spectrum(field, k, distance, dx, wavelength, zero_padding return result -def get_band_limited_angular_spectrum_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu')): +def get_band_limited_angular_spectrum_kernel( + nu, + nv, + dx = 8e-6, + wavelength = 515e-9, + distance = 0., + device = torch.device('cpu') + ): """ Helper function for odak.learn.wave.band_limited_angular_spectrum. @@ -711,7 +928,7 @@ def get_band_limited_angular_spectrum_kernel(nu, nv, dx = 8e-6, wavelength = 515 Returns ------- - H : float + H : torch.complex64 Complex kernel in Fourier domain. """ x = dx * float(nu) @@ -741,7 +958,15 @@ def get_band_limited_angular_spectrum_kernel(nu, nv, dx = 8e-6, wavelength = 515 return H -def band_limited_angular_spectrum(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1.): +def band_limited_angular_spectrum( + field, + k, + distance, + dx, + wavelength, + zero_padding = False, + aperture = 1. + ): """ A definition to calculate bandlimited angular spectrum based beam propagation. For more `Matsushima, Kyoji, and Tomoyoshi Shimobaba. "Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields." Optics express 17.22 (2009): 19662-19673`. @@ -880,7 +1105,7 @@ def stochastic_gradient_descent(target, wavelength, distance, pixel_pitch, propa loss.backward(retain_graph = True) optimizer.step() t.set_description(description) - print(description) + logging.warning(description) torch.no_grad() hologram = generate_complex_field(1., phase) reconstruction = propagate_beam( diff --git a/odak/visualize/plotly.py b/odak/visualize/plotly.py index e4b95390..cd364115 100644 --- a/odak/visualize/plotly.py +++ b/odak/visualize/plotly.py @@ -5,9 +5,10 @@ import plotly import plotly.graph_objects as go from plotly.subplots import make_subplots -except: +except Exception as e: warning = 'odak.visualize.plotly requires certain packages: pip install plotly kaleido' logging.warning(warning) + logging.warning(e) import numpy as np from ..wave import calculate_phase, calculate_amplitude, calculate_intensity diff --git a/test/data/rectangular_aperture.png b/test/data/rectangular_aperture.png new file mode 100644 index 00000000..e21dc3e9 Binary files /dev/null and b/test/data/rectangular_aperture.png differ diff --git a/test/test_learn_wave_compare_beam_propagations.py b/test/test_learn_wave_compare_beam_propagations.py index 34570c4b..1b1d664f 100644 --- a/test/test_learn_wave_compare_beam_propagations.py +++ b/test/test_learn_wave_compare_beam_propagations.py @@ -10,8 +10,9 @@ def test(device = torch.device('cpu'), output_directory = 'test_output'): wavelength = 532e-9 pixel_pitch = 3.74e-6 distance = 1e-3 - aperture_samples = [35, 35, 1, 1] # Replace it with this: [50, 50, 5, 5] + aperture_samples = [50, 50, 1, 1] propagation_types = [ + 'Seperable Impulse Response Fresnel', 'Impulse Response Fresnel', 'Transfer Function Fresnel', 'Angular Spectrum', diff --git a/test/test_learn_wave_get_point_wise_impulse_response_fresnel_kernel.py b/test/test_learn_wave_get_point_wise_impulse_response_fresnel_kernel.py new file mode 100644 index 00000000..a78a81f3 --- /dev/null +++ b/test/test_learn_wave_get_point_wise_impulse_response_fresnel_kernel.py @@ -0,0 +1,166 @@ +import odak +import torch +import sys + +from tqdm import tqdm + + +def get_target_plane_points( + resolution = [50, 50], + resolution_factor = 1, + z = 
diff --git a/odak/visualize/plotly.py b/odak/visualize/plotly.py
index e4b95390..cd364115 100644
--- a/odak/visualize/plotly.py
+++ b/odak/visualize/plotly.py
@@ -5,9 +5,10 @@
     import plotly
     import plotly.graph_objects as go
     from plotly.subplots import make_subplots
-except:
+except Exception as e:
     warning = 'odak.visualize.plotly requires certain packages: pip install plotly kaleido'
     logging.warning(warning)
+    logging.warning(e)
 
 import numpy as np
 from ..wave import calculate_phase, calculate_amplitude, calculate_intensity
diff --git a/test/data/rectangular_aperture.png b/test/data/rectangular_aperture.png
new file mode 100644
index 00000000..e21dc3e9
Binary files /dev/null and b/test/data/rectangular_aperture.png differ
diff --git a/test/test_learn_wave_compare_beam_propagations.py b/test/test_learn_wave_compare_beam_propagations.py
index 34570c4b..1b1d664f 100644
--- a/test/test_learn_wave_compare_beam_propagations.py
+++ b/test/test_learn_wave_compare_beam_propagations.py
@@ -10,8 +10,9 @@ def test(device = torch.device('cpu'), output_directory = 'test_output'):
     wavelength = 532e-9
     pixel_pitch = 3.74e-6
     distance = 1e-3
-    aperture_samples = [35, 35, 1, 1] # Replace it with this: [50, 50, 5, 5]
+    aperture_samples = [50, 50, 1, 1]
     propagation_types = [
+                         'Seperable Impulse Response Fresnel',
                          'Impulse Response Fresnel',
                          'Transfer Function Fresnel',
                          'Angular Spectrum',
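
What the separable variant buys in the comparison above: the dense impulse response kernel integrates the pixel aperture jointly over a 2D grid, while the separable form sums two 1D integrals and recombines them as an outer product, `h_x.unsqueeze(1) * h_y.unsqueeze(0)`, scaled by the Fresnel prefactor. That identity holds by construction and can be checked directly; a sketch, with illustrative resolution and distance:

```python
import torch
import odak

wavelength = 532e-9
distance = 1e-3
dx = 3.74e-6
k = odak.learn.wave.wavenumber(wavelength)

H, h, h_x, h_y = odak.learn.wave.get_seperable_impulse_response_fresnel_kernel(
                                                                               nu = 256,
                                                                               nv = 256,
                                                                               dx = dx,
                                                                               wavelength = wavelength,
                                                                               distance = distance,
                                                                               aperture_samples = [50, 50, 1, 1]
                                                                              )
# h is the outer product of the 1D kernels times exp(jkd) / (j * wavelength * d).
prefactor = torch.exp(1j * k * torch.tensor(distance)) / (1j * wavelength * distance)
assert torch.allclose(h, prefactor * h_x.unsqueeze(1) * h_y.unsqueeze(0))
```
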
diff --git a/test/test_learn_wave_get_point_wise_impulse_response_fresnel_kernel.py b/test/test_learn_wave_get_point_wise_impulse_response_fresnel_kernel.py
new file mode 100644
index 00000000..a78a81f3
--- /dev/null
+++ b/test/test_learn_wave_get_point_wise_impulse_response_fresnel_kernel.py
@@ -0,0 +1,166 @@
+import odak
+import torch
+import sys
+
+from tqdm import tqdm
+
+
+def get_target_plane_points(
+                            resolution = [50, 50],
+                            resolution_factor = 1,
+                            z = 1e-3,
+                            pixel_pitch = 3.74e-6,
+                            device = torch.device('cpu')
+                           ):
+    wx = resolution[0] * pixel_pitch
+    wy = resolution[1] * pixel_pitch
+    x = torch.linspace(-wx / 2., wx / 2., resolution[0] * resolution_factor, device = device)
+    y = torch.linspace(-wy / 2., wy / 2., resolution[1] * resolution_factor, device = device)
+    X, Y = torch.meshgrid(x, y, indexing = 'ij')
+    X = X.reshape(-1, 1)
+    Y = Y.reshape(-1, 1)
+    Z = torch.ones_like(X) * z
+    target_plane_points = torch.cat((X, Y, Z), axis = 1)
+    return target_plane_points
+
+
+def get_aperture_points(
+                        aperture_pattern,
+                        z = 0.,
+                        aperture_phase = None,
+                        dimensions = [3.74e-6, 3.74e-6],
+                        device = torch.device('cpu')
+                       ):
+    x = torch.linspace(- dimensions[0] / 2., dimensions[0] / 2., aperture_pattern.shape[-2], device = device)
+    y = torch.linspace(- dimensions[1] / 2., dimensions[1] / 2., aperture_pattern.shape[-1], device = device)
+    X, Y = torch.meshgrid(x, y, indexing = 'ij')
+    X = X[aperture_pattern > 0.]
+    Y = Y[aperture_pattern > 0.]
+    X = X.reshape(-1, 1)
+    Y = Y.reshape(-1, 1)
+    Z = torch.ones_like(X) * z
+    aperture_points = torch.cat((X, Y, Z), axis = 1)
+    if isinstance(aperture_phase, type(None)):
+        aperture_phase = torch.zeros_like(aperture_pattern)
+    aperture_field = odak.learn.wave.generate_complex_field(aperture_pattern, aperture_phase)
+    aperture_field = aperture_field[aperture_pattern > 0.]
+    aperture_field = aperture_field.reshape(1, -1)
+    return aperture_points, aperture_field
+
+
+def main(
+         wavelength = 515e-9,
+         distance = 0.5e-3,
+         pixel_pitch = 3.74e-6,
+         resolution = [512, 512],
+         resolution_factor = 1,
+         randomization = False,
+         nx = 1,
+         ny = 1,
+         aperture_pattern_filename = './test/data/rectangular_aperture.png',
+         device = torch.device('cpu'),
+         output_directory = 'output'
+        ):
+    target_points = get_target_plane_points(
+                                            resolution = resolution,
+                                            resolution_factor = resolution_factor,
+                                            z = distance,
+                                            pixel_pitch = pixel_pitch,
+                                            device = device
+                                           )
+    aperture_pattern = odak.learn.tools.load_image(aperture_pattern_filename, normalizeby = 255., torch_style = True).to(device)[0]
+    aperture_points, aperture_field = get_aperture_points(
+                                                          aperture_pattern,
+                                                          dimensions = [pixel_pitch, pixel_pitch],
+                                                          device = device
+                                                         )
+    h = torch.zeros(
+                    resolution[0] * resolution_factor,
+                    resolution[1] * resolution_factor,
+                    dtype = torch.complex64,
+                    device = device
+                   )
+    for i in range(nx):
+        for j in range(ny):
+            shift = torch.tensor(
+                                 [[
+                                   pixel_pitch / nx * i - pixel_pitch / 2.,
+                                   pixel_pitch / ny * j - pixel_pitch / 2.,
+                                   0.,
+                                 ]],
+                                 device = device
+                                )
+            h += odak.learn.wave.get_point_wise_impulse_response_fresnel_kernel(
+                                                                                aperture_points = aperture_points,
+                                                                                aperture_field = aperture_field,
+                                                                                target_points = target_points + shift,
+                                                                                resolution = resolution,
+                                                                                resolution_factor = resolution_factor,
+                                                                                wavelength = wavelength,
+                                                                                distance = distance,
+                                                                                randomization = randomization,
+                                                                                device = device
+                                                                               )
+    h = h / nx / ny
+    save_psfs(
+              h,
+              directory = '{}/aperture_psfs/'.format(output_directory)
+             )
+    assert True == True
+
+
+def save_psfs(kernel, directory, wavelength_id = 0, distance_id = 0, pixel_pitch_id = 0):
+    odak.tools.check_directory(directory)
+    kernel_amplitude = odak.learn.wave.calculate_amplitude(kernel)
+    kernel_intensity = kernel_amplitude ** 2
+    kernel_amplitude = kernel_amplitude / kernel_amplitude.max()
+    kernel_phase = odak.learn.wave.calculate_phase(kernel) % (2 * torch.pi)
+    kernel_weighted = kernel_amplitude * kernel_phase
+    odak.learn.tools.save_image(
+                                '{}/intensity_w{:03d}_d{:03d}_p{:03d}.png'.format(
+                                                                                  directory,
+                                                                                  wavelength_id,
+                                                                                  distance_id,
+                                                                                  pixel_pitch_id,
+                                                                                 ),
+                                kernel_intensity,
+                                cmin = 0.,
+                                cmax = kernel_intensity.max()
+                               )
+    odak.learn.tools.save_image(
+                                '{}/amplitude_w{:03d}_d{:03d}_p{:03d}.png'.format(
+                                                                                  directory,
+                                                                                  wavelength_id,
+                                                                                  distance_id,
+                                                                                  pixel_pitch_id,
+                                                                                 ),
+                                kernel_amplitude,
+                                cmin = 0.,
+                                cmax = kernel_amplitude.max()
+                               )
+    odak.learn.tools.save_image(
+                                '{}/phase_w{:03d}_d{:03d}_p{:03d}.png'.format(
+                                                                              directory,
+                                                                              wavelength_id,
+                                                                              distance_id,
+                                                                              pixel_pitch_id,
+                                                                             ),
+                                kernel_phase,
+                                cmin = 0.,
+                                cmax = 2 * torch.pi
+                               )
+    odak.learn.tools.save_image(
+                                '{}/weighted_w{:03d}_d{:03d}_p{:03d}.png'.format(
+                                                                                 directory,
+                                                                                 wavelength_id,
+                                                                                 distance_id,
+                                                                                 pixel_pitch_id,
+                                                                                ),
+                                kernel_weighted,
+                                cmin = 0.,
+                                cmax = kernel_weighted.max()
+                               )
+    return True
+
+
+if __name__ == '__main__':
+    sys.exit(main())
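
The point-wise kernel does not require the checked-in aperture image; any set of XYZ points with a complex field works. A minimal sketch using a synthetic, fully open pixel in place of `test/data/rectangular_aperture.png` (grid sizes here are arbitrary choices for illustration):

```python
import torch
import odak

pixel_pitch = 3.74e-6
resolution = [64, 64]
distance = 0.5e-3
wavelength = 515e-9

# Synthetic aperture: one fully open pixel sampled on a 16 x 16 grid at z = 0.
n = 16
a = torch.linspace(- pixel_pitch / 2., pixel_pitch / 2., n)
AX, AY = torch.meshgrid(a, a, indexing = 'ij')
aperture_points = torch.cat(
                            (AX.reshape(-1, 1), AY.reshape(-1, 1), torch.zeros(n * n, 1)),
                            dim = 1
                           )
aperture_field = torch.ones(1, n * n, dtype = torch.complex64)

# Target plane at z = distance, mirroring get_target_plane_points() above.
wx = resolution[0] * pixel_pitch
x = torch.linspace(- wx / 2., wx / 2., resolution[0])
TX, TY = torch.meshgrid(x, x, indexing = 'ij')
target_points = torch.cat(
                          (
                           TX.reshape(-1, 1),
                           TY.reshape(-1, 1),
                           torch.ones(resolution[0] * resolution[1], 1) * distance
                          ),
                          dim = 1
                         )

h = odak.learn.wave.get_point_wise_impulse_response_fresnel_kernel(
                                                                   aperture_points = aperture_points,
                                                                   aperture_field = aperture_field,
                                                                   target_points = target_points,
                                                                   resolution = resolution,
                                                                   wavelength = wavelength,
                                                                   distance = distance
                                                                  )
psf = odak.learn.wave.calculate_amplitude(h) ** 2 # Intensity point spread function.
```
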
diff --git a/test/test_learn_wave_get_seperable_impulse_response_fresnel.py b/test/test_learn_wave_get_seperable_impulse_response_fresnel.py
new file mode 100644
index 00000000..87cd2de9
--- /dev/null
+++ b/test/test_learn_wave_get_seperable_impulse_response_fresnel.py
@@ -0,0 +1,208 @@
+import odak
+import torch
+import sys
+
+from tqdm import tqdm
+
+
+def get_1D_kernels(
+                   resolution,
+                   pixel_pitches,
+                   wavelengths,
+                   distances,
+                   scale = 1,
+                   aperture_samples = [50, 50, 5, 5],
+                   device = torch.device('cpu')
+                  ):
+    kernels_x = torch.zeros(
+                            len(wavelengths),
+                            len(distances),
+                            len(pixel_pitches),
+                            resolution[0] * scale,
+                            dtype = torch.complex64,
+                            device = device
+                           )
+    kernels_y = torch.zeros(
+                            len(wavelengths),
+                            len(distances),
+                            len(pixel_pitches),
+                            resolution[1] * scale,
+                            dtype = torch.complex64,
+                            device = device
+                           )
+    kernels = torch.zeros(
+                          len(wavelengths),
+                          len(distances),
+                          len(pixel_pitches),
+                          resolution[0] * scale,
+                          resolution[1] * scale,
+                          dtype = torch.complex64,
+                          device = device
+                         )
+    for wavelength_id, wavelength in enumerate(wavelengths):
+        for dx_id, dx in enumerate(pixel_pitches):
+            for distance_id in tqdm(range(len(distances))):
+                distance = distances[distance_id]
+                _, kernel, kernel_x, kernel_y = odak.learn.wave.get_seperable_impulse_response_fresnel_kernel(
+                                                                                                              nu = resolution[0],
+                                                                                                              nv = resolution[1],
+                                                                                                              dx = dx,
+                                                                                                              wavelength = wavelength,
+                                                                                                              distance = distance,
+                                                                                                              scale = scale,
+                                                                                                              aperture_samples = aperture_samples,
+                                                                                                              device = device
+                                                                                                             )
+                kernels[
+                        wavelength_id,
+                        distance_id,
+                        dx_id
+                       ] = kernel.detach().clone()
+                kernels_x[
+                          wavelength_id,
+                          distance_id,
+                          dx_id
+                         ] = kernel_x.detach().clone()
+                kernels_y[
+                          wavelength_id,
+                          distance_id,
+                          dx_id
+                         ] = kernel_y.detach().clone()
+    return kernels, kernels_x, kernels_y
+
+
+def main(
+         wavelengths = [515e-9],
+         distance_range = [1e-4, 10.0e-3],
+         distance_no = 1,
+         pixel_pitches = [3.74e-6],
+         resolution = [512, 512],
+         resolution_factor = 1,
+         propagation_type = 'Impulse Response Fresnel',
+         samples = [50, 50, 1, 1],
+         device = torch.device('cpu'),
+         output_directory = 'output'
+        ):
+    odak.tools.check_directory(output_directory)
+    distances = torch.linspace(distance_range[0], distance_range[1], distance_no)
+    light_kernels, light_kernels_x, light_kernels_y = get_1D_kernels(
+                                                                     resolution = resolution,
+                                                                     pixel_pitches = pixel_pitches,
+                                                                     wavelengths = wavelengths,
+                                                                     distances = distances,
+                                                                     scale = resolution_factor,
+                                                                     aperture_samples = samples,
+                                                                     device = device
+                                                                    )
+    save(
+         light_kernels_x,
+         directory = '{}/vectorized/'.format(output_directory)
+        )
+    save_psfs(
+              light_kernels,
+              directory = '{}/vectorized/psfs/'.format(output_directory)
+             )
+    assert True == True
+
+
+def save_psfs(kernels, directory):
+    odak.tools.check_directory(directory)
+    for wavelength_id in range(kernels.shape[0]):
+        for distance_id in range(kernels.shape[1]):
+            for pixel_pitch_id in range(kernels.shape[2]):
+                kernel = kernels[wavelength_id, distance_id, pixel_pitch_id]
+                kernel_amplitude = odak.learn.wave.calculate_amplitude(kernel)
+                kernel_amplitude = kernel_amplitude / kernel_amplitude.max()
+                kernel_intensity = kernel_amplitude ** 2
+                kernel_phase = odak.learn.wave.calculate_phase(kernel) % (2 * torch.pi)
+                kernel_weighted = kernel_amplitude * kernel_phase
+                odak.learn.tools.save_image(
+                                            '{}/intensity_w{:03d}_d{:03d}_p{:03d}.png'.format(
+                                                                                              directory,
+                                                                                              wavelength_id,
+                                                                                              distance_id,
+                                                                                              pixel_pitch_id,
+                                                                                             ),
+                                            kernel_intensity,
+                                            cmin = 0.,
+                                            cmax = kernel_intensity.max()
+                                           )
+                odak.learn.tools.save_image(
+                                            '{}/amplitude_w{:03d}_d{:03d}_p{:03d}.png'.format(
+                                                                                              directory,
+                                                                                              wavelength_id,
+                                                                                              distance_id,
+                                                                                              pixel_pitch_id,
+                                                                                             ),
+                                            kernel_amplitude,
+                                            cmin = 0.,
+                                            cmax = kernel_amplitude.max()
+                                           )
+                odak.learn.tools.save_image(
+                                            '{}/phase_w{:03d}_d{:03d}_p{:03d}.png'.format(
+                                                                                          directory,
+                                                                                          wavelength_id,
+                                                                                          distance_id,
+                                                                                          pixel_pitch_id,
+                                                                                         ),
+                                            kernel_phase,
+                                            cmin = 0.,
+                                            cmax = 2 * torch.pi
+                                           )
+                odak.learn.tools.save_image(
+                                            '{}/weighted_w{:03d}_d{:03d}_p{:03d}.png'.format(
+                                                                                             directory,
+                                                                                             wavelength_id,
+                                                                                             distance_id,
+                                                                                             pixel_pitch_id,
+                                                                                            ),
+                                            kernel_weighted,
+                                            cmin = 0.,
+                                            cmax = kernel_weighted.max()
+                                           )
+    return True
+
+
+def save(kernels, directory):
+    odak.tools.check_directory(directory)
+    for wavelength_id in range(kernels.shape[0]):
+        for pixel_pitch_id in range(kernels.shape[2]):
+            kernel = kernels[wavelength_id, :, pixel_pitch_id]
+            kernel_amplitude = odak.learn.wave.calculate_amplitude(kernel)
+            kernel_amplitude = kernel_amplitude / kernel_amplitude.max()
+            kernel_phase = odak.learn.wave.calculate_phase(kernel) % (2 * torch.pi)
+            kernel_weighted = kernel_amplitude * kernel_phase
+            odak.learn.tools.save_image(
+                                        '{}/amplitude_w{:03d}_p{:03d}.png'.format(
+                                                                                  directory,
+                                                                                  wavelength_id,
+                                                                                  pixel_pitch_id,
+                                                                                 ),
+                                        kernel_amplitude,
+                                        cmin = 0.,
+                                        cmax = kernel_amplitude.max()
+                                       )
+            odak.learn.tools.save_image(
+                                        '{}/phase_w{:03d}_p{:03d}.png'.format(
+                                                                              directory,
+                                                                              wavelength_id,
+                                                                              pixel_pitch_id,
+                                                                             ),
+                                        kernel_phase,
+                                        cmin = 0.,
+                                        cmax = 2 * torch.pi
+                                       )
+            odak.learn.tools.save_image(
+                                        '{}/weighted_w{:03d}_p{:03d}.png'.format(
+                                                                                 directory,
+                                                                                 wavelength_id,
+                                                                                 pixel_pitch_id,
+                                                                                ),
+                                        kernel_weighted,
+                                        cmin = 0.,
+                                        cmax = kernel_weighted.max()
+                                       )
+    return True
+
+
+if __name__ == '__main__':
+    sys.exit(main())
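
Both new test modules also run standalone through their `main()` entry points, writing kernel and PSF images under `output/` (hence the new `.gitignore` entry). A quick smoke run with reduced settings, assuming it is launched from the repository root with `test/` made importable:

```python
import sys
sys.path.append('test')

import test_learn_wave_get_seperable_impulse_response_fresnel as seperable_test

# A single distance and a smaller grid keep the run short; images land under output/.
seperable_test.main(
                    resolution = [128, 128],
                    distance_no = 1,
                    output_directory = 'output'
                   )
```
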