Skip to content

Commit

Permalink
Minor updates.
Browse files Browse the repository at this point in the history
  • Loading branch information
kaanaksit committed Aug 19, 2024
1 parent e7946e8 commit 79f9486
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 45 deletions.
2 changes: 1 addition & 1 deletion odak/learn/wave/classical.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def get_impulse_response_fresnel_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9,
for px in pxs:
for py in pys:
r = (X + px - wx) ** 2 + (Y + py - wy) ** 2
h += 1. / (1j * wavelength * distance) * torch.exp(1j * k / (2 * distance) * r)
h += 1. / (1j * wavelength * distance) * torch.exp(1j * k / (2 * distance) * r)
H = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(h))) * dx ** 2 / aperture_samples[0] / aperture_samples[1] / aperture_samples[2] / aperture_samples[3]
return H

Expand Down
24 changes: 11 additions & 13 deletions odak/learn/wave/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def __init__(self,
self.wavelengths = wavelengths
self.resolution = resolution
self.targets = targets
if propagator.propagation_type != 'Impulse Response Fresnel':
scale_factor = 1
self.scale_factor = scale_factor
self.propagator = propagator
self.learning_rate = learning_rate
Expand All @@ -50,7 +52,6 @@ def __init__(self,
self.double_phase = double_phase
self.channel_power_filename = channel_power_filename
self.method = method
self.upsample = torch.nn.Upsample(scale_factor = self.scale_factor, mode = 'nearest')
if self.method != 'conventional' and self.method != 'multi-color':
logging.warning('Unknown optimization method. Options are conventional or multi-color.')
import sys
Expand Down Expand Up @@ -109,12 +110,13 @@ def init_amplitude(self):
"""
Internal function to set the amplitude of the illumination source.
"""
self.amplitude = torch.ones(
self.resolution[0],
self.resolution[1],
requires_grad = False,
device = self.device
)
self.amplitude = torch.zeros(
self.resolution[0] * self.scale_factor,
self.resolution[1] * self.scale_factor,
requires_grad = False,
device = self.device
)
self.amplitude[::self.scale_factor, ::self.scale_factor] = 1.


def init_phase(self):
Expand Down Expand Up @@ -321,12 +323,8 @@ def gradient_descent(self, number_of_iterations=100, weights=[1., 1., 0., 0.]):
)
loss_variation_hologram += loss_phase
for channel_id in range(self.number_of_channels):
if self.scale_factor != 1:
phase_scaled = torch.zeros_like(self.amplitude)
phase_scaled[::self.scale_factor, ::self.scale_factor] = phase
self.amplitude[1::self.scale_factor, 1::self.scale_factor] = 0.
else:
phase_scaled = phase
phase_scaled = torch.zeros_like(self.amplitude)
phase_scaled[::self.scale_factor, ::self.scale_factor] = phase
laser_power = laser_powers[frame_id][channel_id]
hologram = generate_complex_field(
laser_power * self.amplitude,
Expand Down
58 changes: 27 additions & 31 deletions odak/learn/wave/propagators.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,9 @@ def __init__(
self.wavelengths = wavelengths
self.resolution = resolution
self.propagation_type = propagation_type
if self.propagation_type == 'Impulse Response Fresnel':
self.resolution_factor = resolution_factor
else:
self.resolution_factor = 1
if self.propagation_type != 'Impulse Response Fresnel':
resolution_factor = 1
self.resolution_factor = resolution_factor
self.number_of_frames = number_of_frames
self.number_of_depth_layers = number_of_depth_layers
self.number_of_channels = len(self.wavelengths)
Expand Down Expand Up @@ -251,8 +250,8 @@ def __call__(self, input_field, channel_id, depth_id):
if not self.generated_kernels[depth_id, channel_id]:
if self.propagator_type == 'forward':
H = get_propagation_kernel(
nu = input_field.shape[-2] * 2,
nv = input_field.shape[-1] * 2,
nu = self.resolution[0] * 2,
nv = self.resolution[1] * 2,
dx = self.pixel_pitch,
wavelength = self.wavelengths[channel_id],
distance = distance,
Expand All @@ -263,8 +262,8 @@ def __call__(self, input_field, channel_id, depth_id):
)
elif self.propagator_type == 'back and forth':
H_forward = get_propagation_kernel(
nu = input_field.shape[-2] * 2,
nv = input_field.shape[-1] * 2,
nu = self.resolution[0] * 2,
nv = self.resolution[1] * 2,
dx = self.pixel_pitch,
wavelength = self.wavelengths[channel_id],
distance = self.zero_mode_distance,
Expand All @@ -275,8 +274,8 @@ def __call__(self, input_field, channel_id, depth_id):
)
distance_back = -(self.zero_mode_distance + self.image_location_offset - distance)
H_back = get_propagation_kernel(
nu = input_field.shape[-2] * 2,
nv = input_field.shape[-1] * 2,
nu = self.resolution[0] * 2,
nv = self.resolution[1] * 2,
dx = self.pixel_pitch,
wavelength = self.wavelengths[channel_id],
distance = distance_back,
Expand All @@ -290,20 +289,7 @@ def __call__(self, input_field, channel_id, depth_id):
self.generated_kernels[depth_id, channel_id] = True
else:
H = self.kernels[depth_id, channel_id].detach().clone()
if self.resolution_factor > 1:
field_amplitude = calculate_amplitude(input_field)
field_phase = calculate_phase(input_field)
field_scale_amplitude = torch.zeros(
input_field.shape[-2] * self.resolution_factor,
input_field.shape[-1] * self.resolution_factor,
device = input_field.device
)
field_scale_phase = torch.zeros_like(field_scale_amplitude)
field_scale_amplitude[::self.resolution_factor, ::self.resolution_factor] = field_amplitude
field_scale_phase[::self.resolution_factor, ::self.resolution_factor] = field_phase
field_scale = generate_complex_field(field_scale_amplitude, field_scale_phase)
else:
field_scale = input_field
field_scale = input_field
field_scale_padded = zero_pad(field_scale)
output_field_padded = custom(field_scale_padded, H, aperture = self.aperture)
output_field = crop_center(output_field_padded)
Expand Down Expand Up @@ -349,17 +335,27 @@ def reconstruct(self, hologram_phases, amplitude = None, no_grad = True, get_com
device = self.device
)
if isinstance(amplitude, type(None)):
amplitude = torch.ones(
self.number_of_channels,
self.resolution[0],
self.resolution[1],
device = self.device
)
amplitude = torch.zeros(
self.number_of_channels,
self.resolution[0] * self.resolution_factor,
self.resolution[1] * self.resolution_factor,
device = self.device
)
amplitude[:, ::self.resolution_factor, ::self.resolution_factor] = 1.
if self.resolution_factor != 1:
hologram_phases_scaled = torch.zeros_like(amplitude)
hologram_phases_scaled[
:,
::self.resolution_factor,
::self.resolution_factor
] = hologram_phases
else:
hologram_phases_scaled = hologram_phases
for frame_id in range(self.number_of_frames):
for depth_id in range(self.number_of_depth_layers):
for channel_id in range(self.number_of_channels):
laser_power = self.get_laser_powers()[frame_id][channel_id]
phase = hologram_phases[frame_id]
phase = hologram_phases_scaled[frame_id]
hologram = generate_complex_field(
laser_power * amplitude[channel_id],
phase * self.phase_scale[channel_id]
Expand Down

0 comments on commit 79f9486

Please sign in to comment.