From 234f063756b1ab9f502cc95f99e347d58a8a6f13 Mon Sep 17 00:00:00 2001 From: Alec Koumjian Date: Tue, 10 Dec 2024 13:13:00 -0500 Subject: [PATCH 1/3] Ensure some time scales and coordinate origins are correct --- src/adam_core/dynamics/propagation.py | 5 +++-- src/adam_core/propagator/propagator.py | 10 ++++++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/adam_core/dynamics/propagation.py b/src/adam_core/dynamics/propagation.py index 75fdf4f..77791b4 100644 --- a/src/adam_core/dynamics/propagation.py +++ b/src/adam_core/dynamics/propagation.py @@ -159,8 +159,9 @@ def propagate_2body( else: cartesian_covariances = None - origin_code = np.empty(n_orbits * n_times, dtype="object") - origin_code.fill("SUN") + origin_code = np.repeat( + orbits.coordinates.origin.code.to_numpy(zero_copy_only=False), n_times + ) # Convert from the jax array to a numpy array orbits_propagated = np.asarray(orbits_propagated) diff --git a/src/adam_core/propagator/propagator.py b/src/adam_core/propagator/propagator.py index 49095ab..ce52097 100644 --- a/src/adam_core/propagator/propagator.py +++ b/src/adam_core/propagator/propagator.py @@ -1,6 +1,6 @@ import logging from abc import ABC, abstractmethod -from typing import List, Literal, Optional, Type, Union +from typing import List, Literal, Optional, Tuple, Type, Union import numpy as np import numpy.typing as npt @@ -103,7 +103,7 @@ def _add_light_time( observers, lt_tol: float = 1e-12, max_iter: int = 10, - ): + ) -> Tuple[Orbits, np.ndarray]: orbits_aberrated = Orbits.empty() lts = np.zeros(len(orbits)) for i, (orbit, observer) in enumerate(zip(orbits, observers)): @@ -574,6 +574,12 @@ def propagate_orbits( if propagated_variants is not None: propagated = propagated_variants.collapse(propagated) + # Preserve the time scale of the requested times + propagated = propagated.set_column( + "coordinates.time", + propagated.coordinates.time.rescale(times.scale), + ) + # Return the results with the original origin and frame # Preserve the original output origin for the input orbits # by orbit id From 6cf20743eda280522f33fa64d1d5f1fec9c897fc Mon Sep 17 00:00:00 2001 From: Alec Koumjian Date: Tue, 10 Dec 2024 13:23:49 -0500 Subject: [PATCH 2/3] fetch times from store if ray objectref --- src/adam_core/propagator/propagator.py | 5 +++-- src/adam_core/propagator/tests/test_propagator.py | 1 - 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/adam_core/propagator/propagator.py b/src/adam_core/propagator/propagator.py index ce52097..fe5230b 100644 --- a/src/adam_core/propagator/propagator.py +++ b/src/adam_core/propagator/propagator.py @@ -433,8 +433,8 @@ def _propagate_orbits(self, orbits: OrbitType, times: TimestampType) -> OrbitTyp def propagate_orbits( self, - orbits: OrbitType, - times: TimestampType, + orbits: Union[OrbitType, ObjectRef], + times: Union[TimestampType, ObjectRef], covariance: bool = False, covariance_method: Literal[ "auto", "sigma-point", "monte-carlo" @@ -495,6 +495,7 @@ def propagate_orbits( times_ref = ray.put(times) else: times_ref = times + times = ray.get(times_ref) if not isinstance(orbits, ObjectRef): orbits_ref = ray.put(orbits) diff --git a/src/adam_core/propagator/tests/test_propagator.py b/src/adam_core/propagator/tests/test_propagator.py index c3192e8..7ff1a99 100644 --- a/src/adam_core/propagator/tests/test_propagator.py +++ b/src/adam_core/propagator/tests/test_propagator.py @@ -90,7 +90,6 @@ def test_propagator_single_worker(): pass -@pytest.mark.skipif(RAY_INSTALLED is False, reason="Ray is not installed.") def test_propagator_multiple_workers_ray(): orbits = make_real_orbits(10) times = Timestamp.from_iso8601(["2020-01-01T00:00:00", "2020-01-01T00:00:01"]) From 83d0e6609e7d8f5650d24f3a55f0d010f6ea296e Mon Sep 17 00:00:00 2001 From: Alec Koumjian Date: Tue, 10 Dec 2024 13:28:22 -0500 Subject: [PATCH 3/3] remove unused import --- pyproject.toml | 2 +- src/adam_core/propagator/tests/test_propagator.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7fdac6f..983c6c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,7 +102,7 @@ lint = { composite = [ "black --check ./src/adam_core", "isort --check-only ./src/adam_core", ] } -fix = "ruff ./src/adam_core --fix" +fix = "ruff check ./src/adam_core --fix" typecheck = "mypy --strict ./src/adam_core" test = "pytest --benchmark-skip -m 'not profile' {args}" diff --git a/src/adam_core/propagator/tests/test_propagator.py b/src/adam_core/propagator/tests/test_propagator.py index 7ff1a99..a273f21 100644 --- a/src/adam_core/propagator/tests/test_propagator.py +++ b/src/adam_core/propagator/tests/test_propagator.py @@ -1,6 +1,5 @@ import numpy as np import pyarrow as pa -import pytest import quivr as qv from ...coordinates.cartesian import CartesianCoordinates