From f662aaa6f1e40bc491eeaca59439dbe05cf5340d Mon Sep 17 00:00:00 2001 From: Talon Chandler Date: Sun, 12 May 2024 19:25:00 -0700 Subject: [PATCH] poster scripting --- examples/models/anisotropic_thick_3d.py | 338 ++++++++++++++++++++++++ waveorder/optics.py | 135 ++++++++-- waveorder/util.py | 68 ++++- waveorder/waveorder_reconstructor.py | 11 +- 4 files changed, 523 insertions(+), 29 deletions(-) create mode 100644 examples/models/anisotropic_thick_3d.py diff --git a/examples/models/anisotropic_thick_3d.py b/examples/models/anisotropic_thick_3d.py new file mode 100644 index 0000000..f73cbe3 --- /dev/null +++ b/examples/models/anisotropic_thick_3d.py @@ -0,0 +1,338 @@ +import torch +import napari +import numpy as np +from waveorder import optics, util +from waveorder.models import phase_thick_3d + +# Parameters +# all lengths must use consistent units e.g. um +margin = 50 +simulation_arguments = { + "zyx_shape": (129, 256, 256), + "yx_pixel_size": 6.5 / 65, + "z_pixel_size": 0.1, + "index_of_refraction_media": 1.25, +} +# phantom_arguments = {"index_of_refraction_sample": 1.50, "sphere_radius": 5} +transfer_function_arguments = { + "z_padding": 0, + "wavelength_illumination": 0.5, + "numerical_aperture_illumination": 0.75, # 75, + "numerical_aperture_detection": 1.0, +} +input_jones = torch.tensor([0.0 + 1.0j, 1.0 + 0j]) + +# # Create a phantom +# zyx_phase = phase_thick_3d.generate_test_phantom( +# **simulation_arguments, **phantom_arguments +# ) + +# Convert +zyx_shape = simulation_arguments["zyx_shape"] +yx_pixel_size = simulation_arguments["yx_pixel_size"] +z_pixel_size = simulation_arguments["z_pixel_size"] +index_of_refraction_media = simulation_arguments["index_of_refraction_media"] +z_padding = transfer_function_arguments["z_padding"] +wavelength_illumination = transfer_function_arguments[ + "wavelength_illumination" +] +numerical_aperture_illumination = transfer_function_arguments[ + "numerical_aperture_illumination" +] +numerical_aperture_detection = transfer_function_arguments[ + "numerical_aperture_detection" +] + +# Precalculations +z_total = zyx_shape[0] + 2 * z_padding +z_position_list = torch.fft.ifftshift( + (torch.arange(z_total) - z_total // 2) * z_pixel_size +) + +# Calculate frequencies +y_frequencies, x_frequencies = util.generate_frequencies( + zyx_shape[1:], yx_pixel_size +) +radial_frequencies = np.sqrt(x_frequencies**2 + y_frequencies**2) + +# 2D pupils +ill_pupil = optics.generate_pupil( + radial_frequencies, + numerical_aperture_illumination, + wavelength_illumination, +) +det_pupil = optics.generate_pupil( + radial_frequencies, + numerical_aperture_detection, + wavelength_illumination, +) +pupil = optics.generate_pupil( + radial_frequencies, + index_of_refraction_media, # largest possible NA + wavelength_illumination, +) + +# Defocus pupils +defocus_pupil = optics.generate_propagation_kernel( + radial_frequencies, + pupil, + wavelength_illumination / index_of_refraction_media, + z_position_list, +) + +greens_functions_z = optics.generate_greens_function_z( + radial_frequencies, + pupil, + wavelength_illumination / index_of_refraction_media, + z_position_list, +) + +# Calculate vector defocus pupils +S = optics.generate_vector_source_defocus_pupil( + x_frequencies, + y_frequencies, + z_position_list, + defocus_pupil, + input_jones, + ill_pupil, + wavelength_illumination / index_of_refraction_media, +) + +# Simplified scalar pupil +sP = optics.generate_propagation_kernel( + radial_frequencies, + det_pupil, + wavelength_illumination / index_of_refraction_media, + z_position_list, +) + +P = optics.generate_vector_detection_defocus_pupil( + x_frequencies, + y_frequencies, + z_position_list, + defocus_pupil, + det_pupil, + wavelength_illumination / index_of_refraction_media, +) + +G = optics.generate_defocus_greens_tensor( + x_frequencies, + y_frequencies, + greens_functions_z, + pupil, + lambda_in=wavelength_illumination / index_of_refraction_media, +) + +# window = torch.fft.ifftshift( +# torch.hann_window(z_position_list.shape[0], periodic=False) +# ) + +# ###### LATEST + +# # abs() and *(1j) are hacks to correct for tricky phase shifts +# P_3D = torch.abs(torch.fft.ifft(P, dim=-3)).type(torch.complex64) +# G_3D = torch.abs(torch.fft.ifft(G, dim=-3)) * (-1j) +# S_3D = torch.fft.ifft(S, dim=-3) + +# # Normalize +# P_3D /= torch.amax(torch.abs(P_3D)) +# G_3D /= torch.amax(torch.abs(G_3D)) +# S_3D /= torch.amax(torch.abs(S_3D)) + +# # Main part +# PG_3D = torch.einsum("ijzyx,jpzyx->ipzyx", P_3D, G_3D) +# PS_3D = torch.einsum("jlzyx,lzyx,kzyx->jlzyx", P_3D, S_3D, torch.conj(S_3D)) + +# # PG_3D /= torch.amax(torch.abs(PG_3D)) +# # PS_3D /= torch.amax(torch.abs(PS_3D)) + +# pg = torch.fft.fftn(PG_3D, dim=(-3, -2, -1)) +# ps = torch.fft.fftn(PS_3D, dim=(-3, -2, -1)) + +# H1 = torch.fft.ifftn( +# torch.einsum("ipzyx,jkzyx->ijpkzyx", pg, torch.conj(ps)), +# dim=(-3, -2, -1), +# ) + +# H2 = torch.fft.ifftn( +# torch.einsum("ikzyx,jpzyx->ijpkzyx", ps, torch.conj(pg)), +# dim=(-3, -2, -1), +# ) + +# MAY 12 Simplified +P_3D = torch.abs(torch.fft.ifft(sP, dim=-3)).type(torch.complex64) +G_3D = torch.abs(torch.fft.ifft(G, dim=-3)) * (-1j) +S_3D = torch.fft.ifft(S, dim=-3) + +# Normalize +P_3D /= torch.amax(torch.abs(P_3D)) +G_3D /= torch.amax(torch.abs(G_3D)) +S_3D /= torch.amax(torch.abs(S_3D)) + +# Main part +PG_3D = torch.einsum("zyx,ipzyx->ipzyx", P_3D, G_3D) +PS_3D = torch.einsum("zyx,jzyx,kzyx->jkzyx", P_3D, S_3D, torch.conj(S_3D)) + +PG_3D /= torch.amax(torch.abs(PG_3D)) +PS_3D /= torch.amax(torch.abs(PS_3D)) + +pg = torch.fft.fftn(PG_3D, dim=(-3, -2, -1)) +ps = torch.fft.fftn(PS_3D, dim=(-3, -2, -1)) + +H1 = torch.fft.ifftn( + torch.einsum("ipzyx,jkzyx->ijpkzyx", pg, torch.conj(ps)), + dim=(-3, -2, -1), +) + +H2 = torch.fft.ifftn( + torch.einsum("ikzyx,jpzyx->ijpkzyx", ps, torch.conj(pg)), + dim=(-3, -2, -1), +) + +H_re = H1[1:, 1:] + H2[1:, 1:] +# H_im = 1j * (H1 - H2) + +s = util.pauli() +Y = util.gellmann() + +H_re_stokes = torch.einsum("sik,ikpjzyx,lpj->slzyx", s, H_re, Y) + +print("H_re_stokes: (RE, IM, ABS)") +torch.set_printoptions(precision=1) +print(torch.log10(torch.sum(torch.real(H_re_stokes) ** 2, dim=(-3, -2, -1)))) +print(torch.log10(torch.sum(torch.imag(H_re_stokes) ** 2, dim=(-3, -2, -1)))) +print(torch.log10(torch.sum(torch.abs(H_re_stokes) ** 2, dim=(-3, -2, -1)))) + +# Display transfer function +v = napari.Viewer() + + +def view_transfer_function( + transfer_function, +): + shift_dims = (-3, -2, -1) + lim = 1e-3 + zyx_scale = np.array( + [ + zyx_shape[0] * z_pixel_size, + zyx_shape[1] * yx_pixel_size, + zyx_shape[2] * yx_pixel_size, + ] + ) + + v.add_image( + torch.fft.ifftshift(torch.real(transfer_function), dim=shift_dims) + .cpu() + .numpy(), + colormap="bwr", + contrast_limits=(-lim, lim), + scale=1 / zyx_scale, + ) + if transfer_function.dtype == torch.complex64: + v.add_image( + torch.fft.ifftshift(torch.imag(transfer_function), dim=shift_dims) + .cpu() + .numpy(), + colormap="bwr", + contrast_limits=(-lim, lim), + scale=1 / zyx_scale, + ) + + # v.dims.order = (2, 1, 0) + + +# view_transfer_function(H_re_stokes) +# view_transfer_function(G_3D) +# view_transfer_function(H_re) +# view_transfer_function(P_3D) + +# PLOT transfer function +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.colors as mcolors + + +def plot_data(data, y_slices, filename): + fig, axs = plt.subplots(4, 9, figsize=(20, 10)) # Adjust the size as needed + + for i in range(data.shape[0]): # Stokes parameter + for j in range(data.shape[1]): # Object parameter + for k, y in enumerate(y_slices): # Y slices + z = data[i, j, :, y, :] + hue = np.angle(z) / (2 * np.pi) + 0.5 # Normalize and shift to make red at 0 + sat = np.abs(z) / np.amax(np.abs(z)) + hsv = np.stack((hue, sat, np.ones_like(sat)), axis=-1) + rgb = mcolors.hsv_to_rgb(hsv) + + ax = axs[i, j] + ax.imshow(rgb, aspect='auto') + ax.set_title('') # Remove titles + ax.set_xticks([]) # Remove x-axis ticks + ax.set_yticks([]) # Remove y-axis ticks + ax.spines['top'].set_visible(False) # Hide top spine + ax.spines['right'].set_visible(False) # Hide right spine + ax.spines['bottom'].set_visible(False) # Hide bottom spine + ax.spines['left'].set_visible(False) # Hide left spine + ax.set_xlabel('') # Remove x-axis labels + + plt.tight_layout() + plt.savefig(filename, format='pdf') + +# Adjust y_slices according to your index base (check if your array index starts at 0) +y_center = 128 # Assuming the middle index for Y dimension +y_slices = [y_center - 10, y_center, y_center + 10] +plot_data(torch.fft.ifftshift(H_re_stokes, dim=(-3, -2, -1)).numpy(), y_slices, './output.pdf') + +# Simulate +yx_star, yx_theta, _ = util.generate_star_target( + yx_shape=zyx_shape[1:], + blur_px=1, + margin=margin, +) +c00 = yx_star +c2_2 = -torch.sin(2 * yx_theta) * yx_star +c22 = torch.cos(2 * yx_theta) * yx_star + +# Put in in a center slices of a 3D object +center_slice_object = torch.stack((c00, c2_2, c22), dim=0) +object = torch.zeros((3,) + zyx_shape) +object[:, zyx_shape[0] // 2, ...] = center_slice_object + +# Simulate +object_spectrum = torch.fft.fftn(object, dim=(-3, -2, -1)) +data_spectrum = torch.einsum( + "slzyx,lzyx->szyx", H_re_stokes[:, (0, 4, 8), ...], object_spectrum +) +data = torch.fft.ifftn(data_spectrum, dim=(-3, -2, -1)) + +v.add_image(object.numpy()) +v.add_image(torch.real(data).numpy()) +v.add_image(torch.imag(data).numpy()) + +import pdb + +pdb.set_trace() + + +zyx_data = phase_thick_3d.apply_transfer_function( + zyx_phase, + real_potential_transfer_function, + transfer_function_arguments["z_padding"], + brightness=1e3, +) + +# Reconstruct +zyx_recon = phase_thick_3d.apply_inverse_transfer_function( + zyx_data, + real_potential_transfer_function, + imag_potential_transfer_function, + transfer_function_arguments["z_padding"], +) + +# Display +viewer.add_image(zyx_phase.numpy(), name="Phantom", scale=zyx_scale) +viewer.add_image(zyx_data.numpy(), name="Data", scale=zyx_scale) +viewer.add_image(zyx_recon.numpy(), name="Reconstruction", scale=zyx_scale) +input("Showing object, data, and recon. Press to quit...") + +# %% diff --git a/waveorder/optics.py b/waveorder/optics.py index 4151b9d..8df21b3 100644 --- a/waveorder/optics.py +++ b/waveorder/optics.py @@ -133,7 +133,7 @@ def generate_pupil(frr, NA, lamb_in): numerical aperture of the pupil function (normalized by the refractive index of the immersion media) lamb_in : float - wavelength of the light (inside the immersion media) + wavelength of the light in free space in units of length (inverse of frr's units) Returns @@ -225,6 +225,101 @@ def gen_sector_Pupil(fxx, fyy, NA, lamb_in, sector_angle, rotation_angle): return Pupil_sector +def rotation_matrix(nu_z, nu_y, nu_x, wavelength): + nu_perp_squared = nu_x**2 + nu_y**2 + nu_zz = wavelength * nu_z - 1 + + R_xx = (wavelength * nu_x**2 * nu_z + nu_y**2) / nu_perp_squared + R_yy = (wavelength * nu_y**2 * nu_z + nu_x**2) / nu_perp_squared + R_xy = nu_x * nu_y * nu_zz / nu_perp_squared + + row0 = torch.stack((-wavelength * nu_y, -wavelength * nu_x), dim=0) + row1 = torch.stack((R_yy, R_xy), dim=0) + row2 = torch.stack((R_xy, R_xx), dim=0) + + out = torch.stack((row0, row1, row2), dim=0) + + # KLUDGE to avoid fix nans + out[..., 0, 0] = torch.tensor([[0, 0], [1, 0], [0, 1]])[..., None] + + return out + + +def generate_vector_source_defocus_pupil( + x_frequencies, + y_frequencies, + z_position_list, + defocus_pupil, + input_jones, + ill_pupil, + wavelength, +): + ill_pupil_3d = torch.einsum( + "zyx,yx->zyx", torch.fft.fft(defocus_pupil, dim=0), ill_pupil + ).abs() # make this real + + # Calculate zyx_frequency grid (inelegant) + z_frequencies = torch.fft.ifft(z_position_list) + freq_shape = z_frequencies.shape + x_frequencies.shape + z_broadcast = torch.broadcast_to(z_frequencies[:, None, None], freq_shape) + y_broadcast = torch.broadcast_to(y_frequencies[None, :, :], freq_shape) + x_broadcast = torch.broadcast_to(x_frequencies[None, :, :], freq_shape) + + # Calculate rotation matrix + rotations = rotation_matrix( + z_broadcast, y_broadcast, x_broadcast, wavelength + ).type(torch.complex64) + + # Main calculation in the frequency domain + source_pupil = ( + torch.einsum( + "ijzyx,j,zyx->izyx", rotations, input_jones, ill_pupil_3d + ) # .abs() + # ** 2 + ) # abs here is critical...incoherent pupil + + # Convert back to defocus pupil + source_defocus_pupil = torch.fft.ifft(source_pupil, dim=-3) + + return source_defocus_pupil + + +def generate_vector_detection_defocus_pupil( + x_frequencies, + y_frequencies, + z_position_list, + det_defocus_pupil, + det_pupil, + wavelength, +): + # TODO: refactor redundancy with illumination pupil + det_pupil_3d = torch.einsum( + "zyx,yx->zyx", torch.fft.ifft(det_defocus_pupil, dim=0), det_pupil + ) + + # Calculate zyx_frequency grid (inelegant) + z_frequencies = torch.fft.ifft(z_position_list) + freq_shape = z_frequencies.shape + x_frequencies.shape + z_broadcast = torch.broadcast_to(z_frequencies[:, None, None], freq_shape) + y_broadcast = torch.broadcast_to(y_frequencies[None, :, :], freq_shape) + x_broadcast = torch.broadcast_to(x_frequencies[None, :, :], freq_shape) + + # Calculate rotation matrix + rotations = rotation_matrix( + z_broadcast, y_broadcast, x_broadcast, wavelength + ).type(torch.complex64) + + # Main calculation in the frequency domain + vector_detection_pupil = torch.einsum( + "jizyx,zyx->ijzyx", rotations, det_pupil_3d + ) + + # Convert back to defocus pupil + detection_defocus_pupil = torch.fft.fft(vector_detection_pupil, dim=-3) + + return detection_defocus_pupil + + def Source_subsample(Source_cont, NAx_coord, NAy_coord, subsampled_NA=0.1): """ @@ -310,9 +405,10 @@ def generate_propagation_kernel( """ - oblique_factor = ( - (1 - wavelength**2 * radial_frequencies**2) * pupil_support - ) ** (1 / 2) / wavelength + oblique_factor = ((1 - wavelength**2 * radial_frequencies**2)) ** ( + 1 / 2 + ) / wavelength + oblique_factor = torch.nan_to_num(oblique_factor, nan=0.0) propagation_kernel = pupil_support[None, :, :] * torch.exp( 1j @@ -367,7 +463,7 @@ def generate_greens_function_z( 1j * 2 * np.pi - * torch.tensor(z_position_list)[:, None, None] + * torch.abs(torch.tensor(z_position_list)[:, None, None]) * oblique_factor[None, :, :] ) / (oblique_factor[None, :, :] + 1e-15) @@ -376,23 +472,25 @@ def generate_greens_function_z( return greens_function_z -def gen_dyadic_Greens_tensor_z(fxx, fyy, G_fun_z, Pupil_support, lambda_in): +def generate_defocus_greens_tensor( + fxx, fyy, G_fun_z, Pupil_support, lambda_in +): """ generate forward dyadic Green's function in u_x, u_y, z space Parameters ---------- - fxx : numpy.ndarray + fxx : tensor.Tensor x component of 2D spatial frequency array with the size of (Ny, Nx) - fyy : numpy.ndarray + fyy : tensor.Tensor y component of 2D spatial frequency array with the size of (Ny, Nx) - G_fun_z : numpy.ndarray - forward Green's function in u_x, u_y, z space with size of (Ny, Nx, Nz) + G_fun_z : tensor.Tensor + forward Green's function in u_x, u_y, z space with size of (Nz, Ny, Nx) - Pupil_support : numpy.ndarray + Pupil_support : tensor.Tensor the array that defines the support of the pupil function with the size of (Ny, Nx) lambda_in : float @@ -400,22 +498,21 @@ def gen_dyadic_Greens_tensor_z(fxx, fyy, G_fun_z, Pupil_support, lambda_in): Returns ------- - G_tensor_z : numpy.ndarray - forward dyadic Green's function in u_x, u_y, z space with the size of (3, 3, Ny, Nx, Nz) + G_tensor_z : tensor.Tensor + forward dyadic Green's function in u_x, u_y, z space with the size of (3, 3, Nz, Ny, Nx) """ - N, M = fxx.shape fr = (fxx**2 + fyy**2) ** (1 / 2) oblique_factor = ((1 - lambda_in**2 * fr**2) * Pupil_support) ** ( 1 / 2 ) / lambda_in - diff_filter = np.zeros((3,) + G_fun_z.shape, complex) - diff_filter[0] = (1j * 2 * np.pi * fxx * Pupil_support)[..., np.newaxis] - diff_filter[1] = (1j * 2 * np.pi * fyy * Pupil_support)[..., np.newaxis] - diff_filter[2] = (1j * 2 * np.pi * oblique_factor)[..., np.newaxis] + diff_filter = torch.zeros((3,) + G_fun_z.shape, dtype=torch.complex64) + diff_filter[0] = (1j * 2 * np.pi * oblique_factor)[None, ...] + diff_filter[1] = (1j * 2 * np.pi * fyy * Pupil_support)[None, ...] + diff_filter[2] = (1j * 2 * np.pi * fxx * Pupil_support)[None, ...] - G_tensor_z = np.zeros((3, 3) + G_fun_z.shape, complex) + G_tensor_z = torch.zeros((3, 3) + G_fun_z.shape, dtype=torch.complex64) for i in range(3): for j in range(3): diff --git a/waveorder/util.py b/waveorder/util.py index 613614e..c797863 100644 --- a/waveorder/util.py +++ b/waveorder/util.py @@ -331,12 +331,15 @@ def gen_coordinate(img_dim, ps): return (xx, yy, fxx, fyy) -def generate_radial_frequencies(img_dim, ps): +def generate_frequencies(img_dim, ps): fy = torch.fft.fftfreq(img_dim[0], ps) fx = torch.fft.fftfreq(img_dim[1], ps) - fyy, fxx = torch.meshgrid(fy, fx, indexing="ij") + return fyy, fxx + +def generate_radial_frequencies(img_dim, ps): + fyy, fxx = generate_frequencies(img_dim, ps) return torch.sqrt(fyy**2 + fxx**2) @@ -2239,3 +2242,64 @@ def orientation_3D_continuity_map( retardance_pr_avg /= np.max(retardance_pr_avg) return retardance_pr_avg + + +def pauli(): + # yx order + # trace-orthogonal normalization + # torch.einsum("kij,lji->kl", pauli(), pauli()) == torch.eye(4) + + # intensity, x-y, +45-(-45), LCP-RCP + # yx + # yx + a = 2**-0.5 + sigma = torch.tensor( + [ + [[a, 0], [0, a]], + [[-a, 0], [0, a]], + [[0, a], [a, 0]], + [[0, 1j * a], [-1j * a, 0]], + ] + ) + return sigma +# torch.allclose( +# torch.abs(torch.einsum("kij,lji->kl", s, s) - torch.eye(4)), +# torch.zeros((4, 4)), +# atol=1e-5, +# ) + + +def gellmann(): + # zyx order + # trace-orthogonal normalization + # torch.einsum("kij,lji->kl", gellmann(), gellmann()) == torch.eye(9) + # + # lexicographical order of the Gell-Mann matrices + # 00, 1-1, 10, 11, 2-2, 2-1, 20, 21, 22 + # + # zyx + # zyx + a = 3**-0.5 + b = -1j * 2**-0.5 + c = 2**-0.5 + d = -(6**-0.5) + e = 2 * (6**-0.5) + return torch.tensor( + [ + [[a, 0, 0], [0, a, 0], [0, 0, a]], + [[0, 0, -b], [0, 0, 0], [b, 0, 0]], + [[0, 0, 0], [0, 0, -b], [0, b, 0]], + [[0, -b, 0], [b, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, c], [0, c, 0]], # + [[0, c, 0], [c, 0, 0], [0, 0, 0]], + [[e, 0, 0], [0, d, 0], [0, 0, d]], + [[0, 0, c], [0, 0, 0], [c, 0, 0]], + [[0, 0, 0], [0, -c, 0], [0, 0, c]], # + ] + ) + + # torch.allclose( +# torch.abs(torch.einsum("kij,lji->kl", Y, Y) - torch.eye(9)), +# torch.zeros((9, 9)), +# atol=1e-5, +# ) diff --git a/waveorder/waveorder_reconstructor.py b/waveorder/waveorder_reconstructor.py index df62aaa..5faed08 100644 --- a/waveorder/waveorder_reconstructor.py +++ b/waveorder/waveorder_reconstructor.py @@ -160,7 +160,6 @@ def instrument_matrix_calibration(I_cali_norm, I_meas): class waveorder_microscopy: - """ waveorder_microscopy contains reconstruction algorithms for label-free @@ -732,9 +731,7 @@ def inclination_recon_setup(self, inc_recon): wave_vec_norm_x = self.lambda_illu * self.fxx wave_vec_norm_y = self.lambda_illu * self.fyy wave_vec_norm_z = ( - np.maximum( - 0, 1 - wave_vec_norm_x**2 - wave_vec_norm_y**2 - ) + np.maximum(0, 1 - wave_vec_norm_x**2 - wave_vec_norm_y**2) ) ** (0.5) incident_theta = np.arctan2( @@ -1005,7 +1002,7 @@ def gen_2D_vec_WOTF(self, inc_option=False): .numpy() .transpose((1, 2, 0)) ) - G_tensor_z = gen_dyadic_Greens_tensor_z( + G_tensor_z = generate_defocus_greens_tensor( self.fxx, self.fyy, G_fun_z, self.Pupil_support, self.lambda_illu ) @@ -4017,9 +4014,7 @@ def Fluor_anisotropy_recon(self, S1_stack, S2_stack): S1_stack = cp.array(S1_stack) S2_stack = cp.array(S2_stack) - anisotropy = cp.asnumpy( - 0.5 * cp.sqrt(S1_stack**2 + S2_stack**2) - ) + anisotropy = cp.asnumpy(0.5 * cp.sqrt(S1_stack**2 + S2_stack**2)) orientation = cp.asnumpy( (0.5 * cp.arctan2(S2_stack, S1_stack)) % np.pi )