Skip to content

Commit

Permalink
A single update.
Browse files Browse the repository at this point in the history
  • Loading branch information
kaanaksit committed Oct 15, 2024
1 parent 2067e16 commit 2cdc988
Show file tree
Hide file tree
Showing 7 changed files with 613 additions and 9 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Unit tests
output/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
239 changes: 232 additions & 7 deletions odak/learn/wave/classical.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def propagate_beam(
result = band_limited_angular_spectrum(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)
elif propagation_type == 'Impulse Response Fresnel':
result = impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture, scale = scale, samples = samples)
elif propagation_type == 'Seperable Impulse Response Fresnel':
result = seperable_impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture, scale = scale, samples = samples)
elif propagation_type == 'Transfer Function Fresnel':
result = transfer_function_fresnel(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)
elif propagation_type == 'custom':
Expand Down Expand Up @@ -124,7 +126,8 @@ def get_propagation_kernel(
-------
kernel : torch.tensor
Complex kernel for the given propagation type.
"""
"""
logging.warning('Requested propagation kernel size for %s method with %s m distance, %s m pixel pitch, %s m wavelength, %s x %s resolutions, x%s scale and %s samples.'.format(propagation_type, distance, dx, nu, nv, scale, samples))
if propagation_type == 'Bandlimited Angular Spectrum':
kernel = get_band_limited_angular_spectrum_kernel(
nu = nu,
Expand Down Expand Up @@ -172,6 +175,17 @@ def get_propagation_kernel(
distance = distance,
device = device
)
elif propagation_type == 'Seperable Impulse Response Fresnel':
kernel, _, _, _ = get_seperable_impulse_response_fresnel_kernel(
nu = nu,
nv = nv,
dx = dx,
wavelength = wavelength,
distance = distance,
device = device,
scale = scale,
aperture_samples = samples
)
else:
logging.warning('Propagation type not recognized')
assert True == False
Expand Down Expand Up @@ -380,7 +394,7 @@ def get_impulse_response_fresnel_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9,
Returns
-------
H : float
H : torch.complex64
Complex kernel in Fourier domain.
"""
k = wavenumber(wavelength)
Expand All @@ -404,6 +418,202 @@ def get_impulse_response_fresnel_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9,
return H


def get_seperable_impulse_response_fresnel_kernel(
nu,
nv,
dx = 3.74e-6,
wavelength = 515e-9,
distance = 0.,
scale = 1,
aperture_samples = [50, 50, 5, 5],
device = torch.device('cpu')
):
"""
Returns impulse response fresnel kernel in separable form.
Parameters
----------
nu : int
Resolution at X axis in pixels.
nv : int
Resolution at Y axis in pixels.
dx : float
Pixel pitch in meters.
wavelength : float
Wavelength in meters.
distance : float
Distance in meters.
device : torch.device
Device, for more see torch.device().
scale : int
Scale with respect to nu and nv (e.g., scale = 2 leads to 2 x nu and 2 x nv resolution for H).
aperture_samples : list
Number of samples to represent a rectangular pixel. First two is for XY of hologram plane pixels, and second two is for image plane pixels.
Returns
-------
H : torch.complex64
Complex kernel in Fourier domain.
h : torch.complex64
Complex kernel in spatial domain.
h_x : torch.complex64
1D complex kernel in spatial domain along X axis.
h_y : torch.complex64
1D complex kernel in spatial domain along Y axis.
"""
k = wavenumber(wavelength)
distance = torch.as_tensor(distance, device = device)
length_x, length_y = (
torch.tensor(dx * nu, device = device),
torch.tensor(dx * nv, device = device)
)
x = torch.linspace(- length_x / 2., length_x / 2., nu * scale, device = device)
y = torch.linspace(- length_y / 2., length_y / 2., nv * scale, device = device)
wxs = torch.linspace(- dx / 2., dx / 2., aperture_samples[0], device = device).unsqueeze(0).unsqueeze(0)
wys = torch.linspace(- dx / 2., dx / 2., aperture_samples[1], device = device).unsqueeze(0).unsqueeze(-1)
pxs = torch.linspace(- dx / 2., dx / 2., aperture_samples[2], device = device).unsqueeze(0).unsqueeze(-1)
pys = torch.linspace(- dx / 2., dx / 2., aperture_samples[3], device = device).unsqueeze(0).unsqueeze(0)
wxs = (wxs - pxs).reshape(1, -1).unsqueeze(-1)
wys = (wys - pys).reshape(1, -1).unsqueeze(1)

X = x.unsqueeze(-1).unsqueeze(-1)
Y = y[y.shape[0] // 2].unsqueeze(-1).unsqueeze(-1)
r_x = (X + wxs) ** 2
r_y = (Y + wys) ** 2
r = r_x + r_y
h_x = torch.exp(1j * k / (2 * distance) * r)
h_x = torch.sum(h_x, axis = (1, 2))

if nu != nv:
X = x[x.shape[0] // 2].unsqueeze(-1).unsqueeze(-1)
Y = y.unsqueeze(-1).unsqueeze(-1)
r_x = (X + wxs) ** 2
r_y = (Y + wys) ** 2
r = r_x + r_y
h_y = torch.exp(1j * k * r / (2 * distance))
h_y = torch.sum(h_y, axis = (1, 2))
else:
h_y = h_x.detach().clone()
h = torch.exp(1j * k * distance) / (1j * wavelength * distance) * h_x.unsqueeze(1) * h_y.unsqueeze(0)
H = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(h))) * dx ** 2 / aperture_samples[0] / aperture_samples[1] / aperture_samples[2] / aperture_samples[3]
return H, h, h_x, h_y


def get_point_wise_impulse_response_fresnel_kernel(
aperture_points,
aperture_field,
target_points,
resolution,
resolution_factor = 1,
wavelength = 515e-9,
distance = 0.,
randomization = False,
device = torch.device('cpu')
):
"""
This function is a freeform point spread function calculation routine for an aperture defined with a complex field, `aperture_field`, and locations in space, `aperture_points`.
The point spread function is calculated over provided points, `target_points`.
The final result is reshaped to follow the provided `resolution`.
Parameters
----------
aperture_points : torch.tensor
Points representing an aperture in Euler space (XYZ) [m x 3].
aperture_field : torch.tensor
Complex field for each point provided by `aperture_points` [1 x m].
target_points : torch.tensor
Target points where the propagated field will be calculated [n x 1].
resolution : list
Final resolution that the propagated field will be reshaped [X x Y].
resolution_factor : int
Scale with respect to `resolution` (e.g., scale = 2 leads to `2 x resolution` for the final complex field.
wavelength : float
Wavelength in meters.
randomization : bool
If set `True`, this will help generate a noisy response roughly approximating a real life case, where imperfections occur.
distance : float
Distance in meters.
Returns
-------
h : float
Complex field in spatial domain.
"""
device = aperture_field.device
k = wavenumber(wavelength)
if randomization:
pp = [
aperture_points[:, 0].max() - aperture_points[:, 0].min(),
aperture_points[:, 1].max() - aperture_points[:, 1].min()
]
target_points[:, 0] = target_points[:, 0] - torch.randn(target_points[:, 0].shape) * pp[0]
target_points[:, 1] = target_points[:, 1] - torch.randn(target_points[:, 1].shape) * pp[1]
deltaX = aperture_points[:, 0].unsqueeze(0) - target_points[:, 0].unsqueeze(-1)
deltaY = aperture_points[:, 1].unsqueeze(0) - target_points[:, 1].unsqueeze(-1)
r = deltaX ** 2 + deltaY ** 2
h = torch.exp(1j * k / (2 * distance) * r) * aperture_field
h = torch.sum(h, dim = 1).reshape(resolution[0] * resolution_factor, resolution[1] * resolution_factor)
h = 1. / (1j * wavelength * distance) * h
return h


def seperable_impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1., scale = 1, samples = [20, 20, 5, 5]):
"""
A definition to calculate convolution based Fresnel approximation for beam propagation for a rectangular aperture using the seperable property.
Parameters
----------
field : torch.complex
Complex field (MxN).
k : odak.wave.wavenumber
Wave number of a wave, see odak.wave.wavenumber for more.
distance : float
Propagation distance.
dx : float
Size of one single pixel in the field grid (in meters).
wavelength : float
Wavelength of the electric field.
zero_padding : bool
Zero pad in Fourier domain.
aperture : torch.tensor
Fourier domain aperture (e.g., pinhole in a typical holographic display).
The default is one, but an aperture could be as large as input field [m x n].
scale : int
Resolution factor to scale generated kernel.
samples : list
When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. First two is for hologram plane pixel and the last two is for image plane pixel.
Returns
-------
result : torch.complex
Final complex field (MxN).
"""
H = get_propagation_kernel(
nu = field.shape[-2],
nv = field.shape[-1],
dx = dx,
wavelength = wavelength,
distance = distance,
propagation_type = 'Seperable Impulse Response Fresnel',
device = field.device,
scale = scale,
samples = samples
)
if scale > 1:
field_amplitude = calculate_amplitude(field)
field_phase = calculate_phase(field)
field_scale_amplitude = torch.zeros(field.shape[-2] * scale, field.shape[-1] * scale, device = field.device)
field_scale_phase = torch.zeros_like(field_scale_amplitude)
field_scale_amplitude[::scale, ::scale] = field_amplitude
field_scale_phase[::scale, ::scale] = field_phase
field_scale = generate_complex_field(field_scale_amplitude, field_scale_phase)
else:
field_scale = field
result = custom(field_scale, H, zero_padding = zero_padding, aperture = aperture)
return result


def impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1., scale = 1, samples = [20, 20, 5, 5]):
"""
A definition to calculate convolution based Fresnel approximation for beam propagation.
Expand Down Expand Up @@ -483,7 +693,7 @@ def get_transfer_function_fresnel_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9,
Returns
-------
H : float
H : torch.complex64
Complex kernel in Fourier domain.
"""
distance = torch.tensor([distance]).to(device)
Expand Down Expand Up @@ -689,7 +899,14 @@ def incoherent_angular_spectrum(field, k, distance, dx, wavelength, zero_padding
return result


def get_band_limited_angular_spectrum_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu')):
def get_band_limited_angular_spectrum_kernel(
nu,
nv,
dx = 8e-6,
wavelength = 515e-9,
distance = 0.,
device = torch.device('cpu')
):
"""
Helper function for odak.learn.wave.band_limited_angular_spectrum.
Expand All @@ -711,7 +928,7 @@ def get_band_limited_angular_spectrum_kernel(nu, nv, dx = 8e-6, wavelength = 515
Returns
-------
H : float
H : torch.complex64
Complex kernel in Fourier domain.
"""
x = dx * float(nu)
Expand Down Expand Up @@ -741,7 +958,15 @@ def get_band_limited_angular_spectrum_kernel(nu, nv, dx = 8e-6, wavelength = 515
return H


def band_limited_angular_spectrum(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1.):
def band_limited_angular_spectrum(
field,
k,
distance,
dx,
wavelength,
zero_padding = False,
aperture = 1.
):
"""
A definition to calculate bandlimited angular spectrum based beam propagation. For more
`Matsushima, Kyoji, and Tomoyoshi Shimobaba. "Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields." Optics express 17.22 (2009): 19662-19673`.
Expand Down Expand Up @@ -880,7 +1105,7 @@ def stochastic_gradient_descent(target, wavelength, distance, pixel_pitch, propa
loss.backward(retain_graph = True)
optimizer.step()
t.set_description(description)
print(description)
logging.warning(description)
torch.no_grad()
hologram = generate_complex_field(1., phase)
reconstruction = propagate_beam(
Expand Down
3 changes: 2 additions & 1 deletion odak/visualize/plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import plotly
import plotly.graph_objects as go
from plotly.subplots import make_subplots
except:
except Exception as e:
warning = 'odak.visualize.plotly requires certain packages: pip install plotly kaleido'
logging.warning(warning)
logging.warning(e)
import numpy as np
from ..wave import calculate_phase, calculate_amplitude, calculate_intensity

Expand Down
Binary file added test/data/rectangular_aperture.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 2 additions & 1 deletion test/test_learn_wave_compare_beam_propagations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ def test(device = torch.device('cpu'), output_directory = 'test_output'):
wavelength = 532e-9
pixel_pitch = 3.74e-6
distance = 1e-3
aperture_samples = [35, 35, 1, 1] # Replace it with this: [50, 50, 5, 5]
aperture_samples = [50, 50, 1, 1]
propagation_types = [
'Seperable Impulse Response Fresnel',
'Impulse Response Fresnel',
'Transfer Function Fresnel',
'Angular Spectrum',
Expand Down
Loading

0 comments on commit 2cdc988

Please sign in to comment.