Skip to content

Commit

Permalink
Some typing bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
akoumjian committed Jan 15, 2025
1 parent dccd88f commit ab71446
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 32 deletions.
37 changes: 20 additions & 17 deletions src/sorcha/ephemeris/pixel_dict.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import numpy as np
from collections import defaultdict

import healpy as hp
import numba
import numpy as np

from collections import defaultdict

from sorcha.ephemeris.simulation_geometry import *
from sorcha.ephemeris.simulation_constants import *
from sorcha.ephemeris.simulation_geometry import *


@numba.njit(fastmath=True)
Expand Down Expand Up @@ -62,6 +62,7 @@ def __init__(
nside=128,
nested=True,
n_sub_intervals=101,
use_integrate=False,
):
"""
Initialization function for the class. Computes the initial positions required for the ephemerides interpolation
Expand All @@ -87,6 +88,8 @@ def __init__(
Defines the ordering scheme for the healpix ordering. True (default) means a NESTED ordering
n_sub_intervals: int
Number of sub-intervals for the Lagrange interpolation (default: 101)
use_integrate: boolean
Whether to use the integrator to compute the ephemerides (default: False)
"""
self.nside = nside
self.picket_interval = picket_interval
Expand All @@ -96,7 +99,7 @@ def __init__(
self.sim_dict = sim_dict
self.ephem = ephem
self.observatory = observatory

self.use_integrate = use_integrate
# Set the three times and compute the observatory position
# at those times
# Using a quadratic isn't very general, but that can be
Expand All @@ -115,9 +118,9 @@ def __init__(

self.pixel_dict = defaultdict(list)

self.rho_hat_m_dict = self.get_all_object_unit_vectors(self.r_obs_m, self.tm)
self.rho_hat_0_dict = self.get_all_object_unit_vectors(self.r_obs_0, self.t0)
self.rho_hat_p_dict = self.get_all_object_unit_vectors(self.r_obs_p, self.tp)
self.rho_hat_m_dict = self.get_all_object_unit_vectors(self.r_obs_m, self.tm, use_integrate=self.use_integrate)
self.rho_hat_0_dict = self.get_all_object_unit_vectors(self.r_obs_0, self.t0, use_integrate=self.use_integrate)
self.rho_hat_p_dict = self.get_all_object_unit_vectors(self.r_obs_p, self.tp, use_integrate=self.use_integrate)

self.compute_pixel_traversed()

Expand All @@ -138,7 +141,7 @@ def get_observatory_position(self, t):
r_obs = self.observatory.barycentricObservatory(et, self.obsCode) / AU_KM
return r_obs

def get_object_unit_vectors(self, desigs, r_obs, t, lt0=0.01):
def get_object_unit_vectors(self, desigs, r_obs, t, lt0=0.01, use_integrate=False):
"""
Computes the unit vector (in the equatorial sphere) that point towards the object - observatory vector
for a list of objects, at a given time
Expand All @@ -165,13 +168,13 @@ def get_object_unit_vectors(self, desigs, r_obs, t, lt0=0.01):

# Get the topocentric unit vectors
rho, rho_mag, lt, r_ast, v_ast = integrate_light_time(
sim, ex, t - self.ephem.jd_ref, r_obs, lt0=lt0
sim, ex, t - self.ephem.jd_ref, r_obs, lt0=lt0, use_integrate=use_integrate
)
rho_hat = rho / rho_mag
rho_hat_dict[k] = rho_hat
return rho_hat_dict

def get_all_object_unit_vectors(self, r_obs, t, lt0=0.01):
def get_all_object_unit_vectors(self, r_obs, t, lt0=0.01, use_integrate=False):
"""
Computes the unit vector (in the equatorial sphere) that point towards the object - observatory vector
for *all* objects, at a given time
Expand All @@ -191,7 +194,7 @@ def get_all_object_unit_vectors(self, r_obs, t, lt0=0.01):
"""

desigs = self.sim_dict.keys()
return self.get_object_unit_vectors(desigs, r_obs, t, lt0=lt0)
return self.get_object_unit_vectors(desigs, r_obs, t, lt0=lt0, use_integrate=use_integrate)

def get_interp_factors(self, tm, t0, tp, n_sub_intervals):
"""
Expand Down Expand Up @@ -313,7 +316,7 @@ def update_pickets(self, jd_tdb):

self.tm = self.t0 - self.picket_interval
self.r_obs_m = self.get_observatory_position(self.tm)
self.rho_hat_m_dict = self.get_all_object_unit_vectors(self.r_obs_m, self.tm)
self.rho_hat_m_dict = self.get_all_object_unit_vectors(self.r_obs_m, self.tm, use_integrate=self.use_integrate)

else:
# shift later
Expand All @@ -327,7 +330,7 @@ def update_pickets(self, jd_tdb):

self.tp = self.t0 + self.picket_interval
self.r_obs_p = self.get_observatory_position(self.tp)
self.rho_hat_p_dict = self.get_all_object_unit_vectors(self.r_obs_p, self.tp)
self.rho_hat_p_dict = self.get_all_object_unit_vectors(self.r_obs_p, self.tp, use_integrate=self.use_integrate)

else:
# Need to compute three new sets
Expand All @@ -336,15 +339,15 @@ def update_pickets(self, jd_tdb):
# This is repeated code
self.t0 += n * self.picket_interval
self.r_obs_0 = self.get_observatory_position(self.t0)
self.rho_hat_0_dict = self.get_all_object_unit_vectors(self.r_obs_0, self.t0)
self.rho_hat_0_dict = self.get_all_object_unit_vectors(self.r_obs_0, self.t0, use_integrate=self.use_integrate)

self.tp = self.t0 + self.picket_interval
self.r_obs_p = self.get_observatory_position(self.tp)
self.rho_hat_p_dict = self.get_all_object_unit_vectors(self.r_obs_p, self.tp)
self.rho_hat_p_dict = self.get_all_object_unit_vectors(self.r_obs_p, self.tp, use_integrate=self.use_integrate)

self.tm = self.t0 - self.picket_interval
self.r_obs_m = self.get_observatory_position(self.tm)
self.rho_hat_m_dict = self.get_all_object_unit_vectors(self.r_obs_m, self.tm)
self.rho_hat_m_dict = self.get_all_object_unit_vectors(self.r_obs_m, self.tm, use_integrate=self.use_integrate)

self.compute_pixel_traversed()
else:
Expand Down
8 changes: 2 additions & 6 deletions src/sorcha/ephemeris/simulation_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,6 @@ def create_ephemeris(orbits_df, pointings_df, args, sconfigs):
nside = 2**sconfigs.simulation.ar_healpix_order
n_sub_intervals = sconfigs.simulation.ar_n_sub_intervals

if sconfigs.expert.ar_use_integrate:
# set global variable to use integrate method instead of integrate_or_interpolate
global USE_INTEGRATE
USE_INTEGRATE = True

ephemeris_csv_filename = None
if args.output_ephemeris_file and args.outpath:
ephemeris_csv_filename = os.path.join(args.outpath, args.output_ephemeris_file)
Expand Down Expand Up @@ -181,6 +176,7 @@ def create_ephemeris(orbits_df, pointings_df, args, sconfigs):
picket_interval,
nside,
n_sub_intervals=n_sub_intervals,
use_integrate=sconfigs.expert.ar_use_integrate,
)
for _, pointing in pointings_df.iterrows():
mjd_tai = float(pointing["observationMidpointMJD_TAI"])
Expand Down Expand Up @@ -212,7 +208,7 @@ def create_ephemeris(orbits_df, pointings_df, args, sconfigs):
_,
ephem_geom_params.r_ast,
ephem_geom_params.v_ast,
) = integrate_light_time(sim, ex, pointing["fieldJD_TDB"] - ephem.jd_ref, r_obs, lt0=0.01)
) = integrate_light_time(sim, ex, pointing["fieldJD_TDB"] - ephem.jd_ref, r_obs, lt0=0.01, use_integrate=sconfigs.expert.ar_use_integrate)
ephem_geom_params.rho_hat = ephem_geom_params.rho / ephem_geom_params.rho_mag

ang_from_center = 180 / np.pi * np.arccos(np.dot(ephem_geom_params.rho_hat, visit_vector))
Expand Down
11 changes: 2 additions & 9 deletions src/sorcha/ephemeris/simulation_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def equatorial_to_ecliptic(v, rot_mat=EQ_TO_ECL_ROTATION_MATRIX):
return np.dot(v, rot_mat)


def integrate_light_time(sim, ex, t, r_obs, lt0=0, iter=3, speed_of_light=SPEED_OF_LIGHT):
def integrate_light_time(sim, ex, t, r_obs, lt0=0, iter=3, speed_of_light=SPEED_OF_LIGHT, use_integrate=False):
"""
Performs the light travel time correction between object and observatory iteratively for the object at a given reference time
Expand Down Expand Up @@ -78,16 +78,9 @@ def integrate_light_time(sim, ex, t, r_obs, lt0=0, iter=3, speed_of_light=SPEED_
Object velocity at t-lt
"""
lt = lt0
global USE_INTEGRATE
try:
USE_INTEGRATE
except NameError:
USE_INTEGRATE = False

print("USE_INTEGRATE: ", USE_INTEGRATE)

for i in range(iter):
if USE_INTEGRATE:
if use_integrate:
sim.integrate(t - lt)
else:
ex.integrate_or_interpolate(t - lt)
Expand Down

0 comments on commit ab71446

Please sign in to comment.