Skip to content

Commit

Permalink
Merge pull request #104 from KorayKavakli/master
Browse files Browse the repository at this point in the history
adding incoherent angular spectrum method
  • Loading branch information
kaanaksit authored Sep 28, 2024
2 parents ef1253d + 76e32c6 commit 12d36a1
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 1 deletion.
24 changes: 24 additions & 0 deletions odak/learn/tools/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,27 @@ def blur_gaussian(field, kernel_length = [21, 21], nsigma = [3, 3], padding = 's
blurred_field.shape[-1]
)
return blurred_field


def correlation_2d(first_tensor, second_tensor):
"""
Definition to calculate the correlation between two tensors.
Parameters
----------
first_tensor : torch.tensor
First tensor.
second_tensor : torch.tensor
Second tensor.
Returns
----------
correlation : torch.tensor
Correlation between the two tensors.
"""
fft_first_tensor = (torch.fft.fft2(first_tensor))
fft_second_tensor = (torch.fft.fft2(second_tensor))
conjugate_second_tensor = torch.conj(fft_second_tensor)
result = torch.fft.ifftshift(torch.fft.ifft2(fft_first_tensor * conjugate_second_tensor))
return result

80 changes: 79 additions & 1 deletion odak/learn/wave/classical.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .util import set_amplitude, generate_complex_field, calculate_amplitude, calculate_phase
from .lens import quadratic_phase_function
from .util import wavenumber
from ..tools import zero_pad, crop_center, generate_2d_gaussian, circular_binary_mask
from ..tools import zero_pad, crop_center, generate_2d_gaussian, circular_binary_mask, correlation_2d
from tqdm import tqdm


Expand Down Expand Up @@ -73,6 +73,8 @@ def propagate_beam(
result = custom(field, kernel, zero_padding[1], aperture = aperture)
elif propagation_type == 'Fraunhofer':
result = fraunhofer(field, k, distance, dx, wavelength)
elif propagation_type == 'Incoherent Angular Spectrum':
result = incoherent_angular_spectrum(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)
else:
logging.warning('Propagation type not recognized')
assert True == False
Expand Down Expand Up @@ -530,6 +532,41 @@ def get_angular_spectrum_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance
return H


def get_incoherent_angular_spectrum_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu')):
"""
Helper function for odak.learn.wave.angular_spectrum.
Parameters
----------
nu : int
Resolution at X axis in pixels.
nv : int
Resolution at Y axis in pixels.
dx : float
Pixel pitch in meters.
wavelength : float
Wavelength in meters.
distance : float
Distance in meters.
device : torch.device
Device, for more see torch.device().
Returns
-------
H : float
Complex kernel in Fourier domain.
"""
distance = torch.tensor([distance]).to(device)
fx = torch.linspace(-1. / 2. / dx, 1. / 2. / dx, nu, dtype = torch.float32, device = device)
fy = torch.linspace(-1. / 2. / dx, 1. / 2. / dx, nv, dtype = torch.float32, device = device)
FY, FX = torch.meshgrid(fx, fy, indexing='ij')
H = torch.exp(1j * distance * (2 * (torch.pi * (1 / wavelength) * torch.sqrt(1. - (wavelength * FX) ** 2 - (wavelength * FY) ** 2))))
H_ptime = correlation_2d(H, H)
H = H_ptime.to(device)
return H


def angular_spectrum(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1.):
"""
A definition to calculate convolution with Angular Spectrum method for beam propagation.
Expand Down Expand Up @@ -572,6 +609,47 @@ def angular_spectrum(field, k, distance, dx, wavelength, zero_padding = False, a
return result


def incoherent_angular_spectrum(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1.):
"""
A definition to calculate incoherent beam propagation with Angular Spectrum method.
Parameters
----------
field : torch.complex
Complex field [m x n].
k : odak.wave.wavenumber
Wave number of a wave, see odak.wave.wavenumber for more.
distance : float
Propagation distance.
dx : float
Size of one single pixel in the field grid (in meters).
wavelength : float
Wavelength of the electric field.
zero_padding : bool
Zero pad in Fourier domain.
aperture : torch.tensor
Fourier domain aperture (e.g., pinhole in a typical holographic display).
The default is one, but an aperture could be as large as input field [m x n].
Returns
-------
result : torch.complex
Final complex field [m x n].
"""
H = get_propagation_kernel(
nu = field.shape[-2],
nv = field.shape[-1],
dx = dx,
wavelength = wavelength,
distance = distance,
propagation_type = 'Incoherent Angular Spectrum',
device = field.device
)
result = custom(field, H, zero_padding = zero_padding, aperture = aperture)
return result


def get_band_limited_angular_spectrum_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu')):
"""
Helper function for odak.learn.wave.band_limited_angular_spectrum.
Expand Down

0 comments on commit 12d36a1

Please sign in to comment.