diff --git a/odak/learn/tools/matrix.py b/odak/learn/tools/matrix.py index 40339660..9064d8e6 100644 --- a/odak/learn/tools/matrix.py +++ b/odak/learn/tools/matrix.py @@ -261,3 +261,27 @@ def blur_gaussian(field, kernel_length = [21, 21], nsigma = [3, 3], padding = 's blurred_field.shape[-1] ) return blurred_field + + +def correlation_2d(first_tensor, second_tensor): + """ + Definition to calculate the correlation between two tensors. + + Parameters + ---------- + first_tensor : torch.tensor + First tensor. + second_tensor : torch.tensor + Second tensor. + + Returns + ---------- + correlation : torch.tensor + Correlation between the two tensors. + """ + fft_first_tensor = (torch.fft.fft2(first_tensor)) + fft_second_tensor = (torch.fft.fft2(second_tensor)) + conjugate_second_tensor = torch.conj(fft_second_tensor) + result = torch.fft.ifftshift(torch.fft.ifft2(fft_first_tensor * conjugate_second_tensor)) + return result + diff --git a/odak/learn/wave/classical.py b/odak/learn/wave/classical.py index ab2bf631..31e0a74c 100644 --- a/odak/learn/wave/classical.py +++ b/odak/learn/wave/classical.py @@ -4,7 +4,7 @@ from .util import set_amplitude, generate_complex_field, calculate_amplitude, calculate_phase from .lens import quadratic_phase_function from .util import wavenumber -from ..tools import zero_pad, crop_center, generate_2d_gaussian, circular_binary_mask +from ..tools import zero_pad, crop_center, generate_2d_gaussian, circular_binary_mask, correlation_2d from tqdm import tqdm @@ -73,6 +73,8 @@ def propagate_beam( result = custom(field, kernel, zero_padding[1], aperture = aperture) elif propagation_type == 'Fraunhofer': result = fraunhofer(field, k, distance, dx, wavelength) + elif propagation_type == 'Incoherent Angular Spectrum': + result = incoherent_angular_spectrum(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture) else: logging.warning('Propagation type not recognized') assert True == False @@ -530,6 +532,41 @@ def get_angular_spectrum_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance return H +def get_incoherent_angular_spectrum_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu')): + """ + Helper function for odak.learn.wave.angular_spectrum. + + Parameters + ---------- + nu : int + Resolution at X axis in pixels. + nv : int + Resolution at Y axis in pixels. + dx : float + Pixel pitch in meters. + wavelength : float + Wavelength in meters. + distance : float + Distance in meters. + device : torch.device + Device, for more see torch.device(). + + + Returns + ------- + H : float + Complex kernel in Fourier domain. + """ + distance = torch.tensor([distance]).to(device) + fx = torch.linspace(-1. / 2. / dx, 1. / 2. / dx, nu, dtype = torch.float32, device = device) + fy = torch.linspace(-1. / 2. / dx, 1. / 2. / dx, nv, dtype = torch.float32, device = device) + FY, FX = torch.meshgrid(fx, fy, indexing='ij') + H = torch.exp(1j * distance * (2 * (torch.pi * (1 / wavelength) * torch.sqrt(1. - (wavelength * FX) ** 2 - (wavelength * FY) ** 2)))) + H_ptime = correlation_2d(H, H) + H = H_ptime.to(device) + return H + + def angular_spectrum(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1.): """ A definition to calculate convolution with Angular Spectrum method for beam propagation. @@ -572,6 +609,47 @@ def angular_spectrum(field, k, distance, dx, wavelength, zero_padding = False, a return result +def incoherent_angular_spectrum(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1.): + """ + A definition to calculate incoherent beam propagation with Angular Spectrum method. + + Parameters + ---------- + field : torch.complex + Complex field [m x n]. + k : odak.wave.wavenumber + Wave number of a wave, see odak.wave.wavenumber for more. + distance : float + Propagation distance. + dx : float + Size of one single pixel in the field grid (in meters). + wavelength : float + Wavelength of the electric field. + zero_padding : bool + Zero pad in Fourier domain. + aperture : torch.tensor + Fourier domain aperture (e.g., pinhole in a typical holographic display). + The default is one, but an aperture could be as large as input field [m x n]. + + + Returns + ------- + result : torch.complex + Final complex field [m x n]. + """ + H = get_propagation_kernel( + nu = field.shape[-2], + nv = field.shape[-1], + dx = dx, + wavelength = wavelength, + distance = distance, + propagation_type = 'Incoherent Angular Spectrum', + device = field.device + ) + result = custom(field, H, zero_padding = zero_padding, aperture = aperture) + return result + + def get_band_limited_angular_spectrum_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu')): """ Helper function for odak.learn.wave.band_limited_angular_spectrum.