Skip to content

Commit

Permalink
Merge branch 'master' into edit_ASM
Browse files Browse the repository at this point in the history
  • Loading branch information
kaanaksit authored Nov 11, 2024
2 parents ed4cbf9 + 8de0360 commit 8f700f1
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 24 deletions.
227 changes: 204 additions & 23 deletions odak/learn/wave/classical.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,21 +62,84 @@ def propagate_beam(
if zero_padding[0]:
field = zero_pad(field)
if propagation_type == 'Angular Spectrum':
result = angular_spectrum(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)
result = angular_spectrum(
field = field,
k = k,
distance = distance,
dx = dx,
wavelength = wavelength,
zero_padding = zero_padding[1],
aperture = aperture
)
elif propagation_type == 'Bandlimited Angular Spectrum':
result = band_limited_angular_spectrum(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)
result = band_limited_angular_spectrum(
field = field,
k = k,
distance = distance,
dx = dx,
wavelength = wavelength,
zero_padding = zero_padding[1],
aperture = aperture
)
elif propagation_type == 'Impulse Response Fresnel':
result = impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture, scale = scale, samples = samples)
result = impulse_response_fresnel(
field = field,
k = k,
distance = distance,
dx = dx,
wavelength = wavelength,
zero_padding = zero_padding[1],
aperture = aperture,
scale = scale,
samples = samples
)
elif propagation_type == 'Seperable Impulse Response Fresnel':
result = seperable_impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture, scale = scale, samples = samples)
result = seperable_impulse_response_fresnel(
field = field,
k = k,
distance = distance,
dx = dx,
wavelength = wavelength,
zero_padding = zero_padding[1],
aperture = aperture,
scale = scale,
samples = samples
)
elif propagation_type == 'Transfer Function Fresnel':
result = transfer_function_fresnel(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)
result = transfer_function_fresnel(
field = field,
k = k,
distance = distance,
dx = dx,
wavelength = wavelength,
zero_padding = zero_padding[1],
aperture = aperture
)
elif propagation_type == 'custom':
result = custom(field, kernel, zero_padding[1], aperture = aperture)
result = custom(
field = field,
kernel = kernel,
zero_padding = zero_padding[1],
aperture = aperture
)
elif propagation_type == 'Fraunhofer':
result = fraunhofer(field, k, distance, dx, wavelength)
result = fraunhofer(
field = field,
k = k,
distance = distance,
dx = dx,
wavelength = wavelength
)
elif propagation_type == 'Incoherent Angular Spectrum':
result = incoherent_angular_spectrum(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)
result = incoherent_angular_spectrum(
field = field,
k = k,
distance = distance,
dx = dx,
wavelength = wavelength,
zero_padding = zero_padding[1],
aperture = aperture
)
else:
logging.warning('Propagation type not recognized')
assert True == False
Expand Down Expand Up @@ -301,7 +364,13 @@ def get_light_kernels(
return light_kernels_amplitude, light_kernels_phase, light_kernels_complex, light_parameters


def fraunhofer(field, k, distance, dx, wavelength):
def fraunhofer(
field,
k,
distance,
dx,
wavelength
):
"""
A definition to calculate light transport usin Fraunhofer approximation.
Expand Down Expand Up @@ -334,7 +403,12 @@ def fraunhofer(field, k, distance, dx, wavelength):
return result


def custom(field, kernel, zero_padding = False, aperture = 1.):
def custom(
field,
kernel,
zero_padding = False,
aperture = 1.
):
"""
A definition to calculate convolution based Fresnel approximation for beam propagation.
Expand Down Expand Up @@ -369,7 +443,16 @@ def custom(field, kernel, zero_padding = False, aperture = 1.):
return result


def get_impulse_response_fresnel_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu'), scale = 1, aperture_samples = [20, 20, 5, 5]):
def get_impulse_response_fresnel_kernel(
nu,
nv,
dx = 8e-6,
wavelength = 515e-9,
distance = 0.,
device = torch.device('cpu'),
scale = 1,
aperture_samples = [20, 20, 5, 5]
):
"""
Helper function for odak.learn.wave.impulse_response_fresnel.
Expand Down Expand Up @@ -557,7 +640,17 @@ def get_point_wise_impulse_response_fresnel_kernel(
return h


def seperable_impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1., scale = 1, samples = [20, 20, 5, 5]):
def seperable_impulse_response_fresnel(
field,
k,
distance,
dx,
wavelength,
zero_padding = False,
aperture = 1.,
scale = 1,
samples = [20, 20, 5, 5]
):
"""
A definition to calculate convolution based Fresnel approximation for beam propagation for a rectangular aperture using the seperable property.
Expand Down Expand Up @@ -614,7 +707,17 @@ def seperable_impulse_response_fresnel(field, k, distance, dx, wavelength, zero_
return result


def impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1., scale = 1, samples = [20, 20, 5, 5]):
def impulse_response_fresnel(
field,
k,
distance,
dx,
wavelength,
zero_padding = False,
aperture = 1.,
scale = 1,
samples = [20, 20, 5, 5]
):
"""
A definition to calculate convolution based Fresnel approximation for beam propagation.
Expand Down Expand Up @@ -671,7 +774,14 @@ def impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding =
return result


def get_transfer_function_fresnel_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu')):
def get_transfer_function_fresnel_kernel(
nu,
nv,
dx = 8e-6,
wavelength = 515e-9,
distance = 0.,
device = torch.device('cpu')
):
"""
Helper function for odak.learn.wave.transfer_function_fresnel.
Expand Down Expand Up @@ -705,7 +815,15 @@ def get_transfer_function_fresnel_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9,
return H


def transfer_function_fresnel(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1.):
def transfer_function_fresnel(
field,
k,
distance,
dx,
wavelength,
zero_padding = False,
aperture = 1.
):
"""
A definition to calculate convolution based Fresnel approximation for beam propagation.
Expand Down Expand Up @@ -747,7 +865,14 @@ def transfer_function_fresnel(field, k, distance, dx, wavelength, zero_padding =
return result


def get_angular_spectrum_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu')):
def get_angular_spectrum_kernel(
nu,
nv,
dx = 8e-6,
wavelength = 515e-9,
distance = 0.,
device = torch.device('cpu')
):
"""
Helper function for odak.learn.wave.angular_spectrum.
Expand Down Expand Up @@ -781,7 +906,14 @@ def get_angular_spectrum_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance
return H


def get_incoherent_angular_spectrum_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu')):
def get_incoherent_angular_spectrum_kernel(
nu,
nv,
dx = 8e-6,
wavelength = 515e-9,
distance = 0.,
device = torch.device('cpu')
):
"""
Helper function for odak.learn.wave.angular_spectrum.
Expand Down Expand Up @@ -816,7 +948,15 @@ def get_incoherent_angular_spectrum_kernel(nu, nv, dx = 8e-6, wavelength = 515e-
return H


def angular_spectrum(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1.):
def angular_spectrum(
field,
k,
distance,
dx,
wavelength,
zero_padding = False,
aperture = 1.
):
"""
A definition to calculate convolution with Angular Spectrum method for beam propagation.
Expand Down Expand Up @@ -858,7 +998,15 @@ def angular_spectrum(field, k, distance, dx, wavelength, zero_padding = False, a
return result


def incoherent_angular_spectrum(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1.):
def incoherent_angular_spectrum(
field,
k,
distance,
dx,
wavelength,
zero_padding = False,
aperture = 1.
):
"""
A definition to calculate incoherent beam propagation with Angular Spectrum method.
Expand Down Expand Up @@ -1009,7 +1157,15 @@ def band_limited_angular_spectrum(
return result


def gerchberg_saxton(field, n_iterations, distance, dx, wavelength, slm_range=6.28, propagation_type='Transfer Function Fresnel'):
def gerchberg_saxton(
field,
n_iterations,
distance,
dx,
wavelength,
slm_range = 6.28,
propagation_type = 'Transfer Function Fresnel'
):
"""
Definition to compute a hologram using an iterative method called Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Gerchberg, Ralph W. "A practical algorithm for the determination of phase from image and diffraction plane pictures." Optik 35 (1972): 237-246.
Expand Down Expand Up @@ -1048,7 +1204,16 @@ def gerchberg_saxton(field, n_iterations, distance, dx, wavelength, slm_range=6.
return hologram, reconstruction


def stochastic_gradient_descent(target, wavelength, distance, pixel_pitch, propagation_type = 'Bandlimited Angular Spectrum', n_iteration = 100, loss_function = None, learning_rate = 0.1):
def stochastic_gradient_descent(
target,
wavelength,
distance,
pixel_pitch,
propagation_type = 'Bandlimited Angular Spectrum',
n_iteration = 100,
loss_function = None,
learning_rate = 0.1
):
"""
Definition to generate phase and reconstruction from target image via stochastic gradient descent.
Expand Down Expand Up @@ -1120,7 +1285,14 @@ def stochastic_gradient_descent(target, wavelength, distance, pixel_pitch, propa
return hologram, reconstruction


def point_wise(target, wavelength, distance, dx, device, lens_size=401):
def point_wise(
target,
wavelength,
distance,
dx,
device,
lens_size=401
):
"""
Naive point-wise hologram calculation method. For more information, refer to Maimone, Andrew, Andreas Georgiou, and Joel S. Kollin. "Holographic near-eye displays for virtual and augmented reality." ACM Transactions on Graphics (TOG) 36.4 (2017): 1-16.
Expand Down Expand Up @@ -1165,7 +1337,16 @@ def point_wise(target, wavelength, distance, dx, device, lens_size=401):
return hologram


def shift_w_double_phase(phase, depth_shift, pixel_pitch, wavelength, propagation_type='Transfer Function Fresnel', kernel_length=4, sigma=0.5, amplitude=None):
def shift_w_double_phase(
phase,
depth_shift,
pixel_pitch,
wavelength,
propagation_type = 'Transfer Function Fresnel',
kernel_length = 4,
sigma = 0.5,
amplitude = None
):
"""
Shift a phase-only hologram by propagating the complex hologram and double phase principle. Coded following in [here](https://github.com/liangs111/tensor_holography/blob/6fdb26561a4e554136c579fa57788bb5fc3cac62/optics.py#L131-L207) and Shi, L., Li, B., Kim, C., Kellnhofer, P., & Matusik, W. (2021). Towards real-time photorealistic 3D holography with deep neural networks. Nature, 591(7849), 234-239.
Expand Down
2 changes: 1 addition & 1 deletion test/test_wave_propagate_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test(output_directory = 'test_output'):
pixeltom = 3.74e-6
distance = 5e-3
resolution = [250, 250]
propagation_types = ['Transfer Function Fresnel', 'Impulse Response Fresnel', 'Bandlimited Angular Spectrum', 'Angular Spectrum']
propagation_types = ['Transfer Function Fresnel', 'Impulse Response Fresnel', 'Bandlimited Angular Spectrum', 'Angular Spectrum', 'Rayleigh-Sommerfeld']

k = wavenumber(wavelength)
sample_field = np.zeros((resolution[0], resolution[1]), dtype=np.complex64)
Expand Down

0 comments on commit 8f700f1

Please sign in to comment.