Skip to content

Commit

Permalink
put profiling in its own special group
Browse files Browse the repository at this point in the history
  • Loading branch information
akoumjian committed Dec 4, 2024
1 parent b3ba5a3 commit 9ae26d6
Show file tree
Hide file tree
Showing 9 changed files with 29 additions and 39 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pip-build-lint-test-coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
pip install pdm
- name: Install Testing Dependencies
run: |
pdm install -G dev -G test
pdm install -G test
- name: Lint
run: pdm run lint
- name: Test with coverage
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@ lint = { composite = [
fix = "ruff ./src/adam_core --fix"
typecheck = "mypy --strict ./src/adam_core"

test = "pytest --benchmark-disable {args}"
test = "pytest --benchmark-disable -m 'not profile' {args}"
doctest = "pytest --doctest-plus --doctest-only"
benchmark = "pytest --benchmark-only"
coverage = "pytest --cov=adam_core --cov-report=xml"
coverage = "pytest --cov=adam_core -m 'not profile' --cov-report=xml"


19 changes: 7 additions & 12 deletions src/adam_core/dynamics/ephemeris.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ..orbits.ephemeris import Ephemeris
from ..orbits.orbits import Orbits
from .aberrations import _add_light_time, add_stellar_aberration
from .propagation import pad_to_fixed_size, process_in_chunks
from .propagation import process_in_chunks


@jit
Expand Down Expand Up @@ -178,7 +178,9 @@ def generate_ephemeris_2body(
Topocentric ephemerides for each propagated orbit as observed by the given observers.
"""
num_entries = len(observers)
assert len(propagated_orbits) == num_entries, "Orbits and observers must be paired and orbits must be propagated to observer times."
assert (
len(propagated_orbits) == num_entries
), "Orbits and observers must be paired and orbits must be propagated to observer times."

# Transform both the orbits and observers to the barycenter if they are not already.
propagated_orbits_barycentric = propagated_orbits.set_column(
Expand All @@ -200,16 +202,9 @@ def generate_ephemeris_2body(
),
)

# Stack the observer coordinates and codes for each orbit in the propagated orbits
# num_orbits = len(propagated_orbits_barycentric.orbit_id.unique())
# observer_coordinates = np.tile(
# observers_barycentric.coordinates.values, (num_orbits, 1)
# )
observer_coordinates = observers_barycentric.coordinates.values
observer_codes = observers_barycentric.code.to_numpy(zero_copy_only=False)
# observer_codes = np.tile(observers.code.to_numpy(zero_copy_only=False), num_orbits)
mu = observers_barycentric.coordinates.origin.mu()
# mu = np.tile(mu, num_orbits)
times = propagated_orbits.coordinates.time.mjd().to_numpy(zero_copy_only=False)

# Define chunk size
Expand All @@ -218,8 +213,8 @@ def generate_ephemeris_2body(
# Process in chunks
ephemeris_chunks = []
light_time_chunks = []
for (orbits_chunk, times_chunk, observer_coords_chunk, mu_chunk) in zip(

for orbits_chunk, times_chunk, observer_coords_chunk, mu_chunk in zip(
process_in_chunks(propagated_orbits_barycentric.coordinates.values, chunk_size),
process_in_chunks(times, chunk_size),
process_in_chunks(observer_coordinates, chunk_size),
Expand All @@ -241,7 +236,7 @@ def generate_ephemeris_2body(
# Concatenate chunks and remove padding
ephemeris_spherical = jnp.concatenate(ephemeris_chunks, axis=0)[:num_entries]
light_time = jnp.concatenate(light_time_chunks, axis=0)[:num_entries]

ephemeris_spherical = np.array(ephemeris_spherical)
light_time = np.array(light_time)

Expand Down
6 changes: 3 additions & 3 deletions src/adam_core/dynamics/propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def process_in_chunks(array, chunk_size):
"""
n = array.shape[0]
for i in range(0, n, chunk_size):
chunk = array[i:i + chunk_size]
chunk = array[i : i + chunk_size]
if chunk.shape[0] < chunk_size:
chunk = pad_to_fixed_size(chunk, (chunk_size,) + chunk.shape[1:])
yield chunk
Expand Down Expand Up @@ -145,7 +145,7 @@ def propagate_2body(
process_in_chunks(orbits_array_, chunk_size),
process_in_chunks(t0_, chunk_size),
process_in_chunks(t1_, chunk_size),
process_in_chunks(mu, chunk_size)
process_in_chunks(mu, chunk_size),
):
orbits_propagated_chunk = _propagate_2body_vmap(
orbits_chunk, t0_chunk, t1_chunk, mu_chunk, max_iter, tol
Expand All @@ -156,7 +156,7 @@ def propagate_2body(
orbits_propagated = jnp.concatenate(orbits_propagated_chunks, axis=0)

# Remove padding
orbits_propagated = orbits_propagated[:n_orbits * n_times]
orbits_propagated = orbits_propagated[: n_orbits * n_times]

if not orbits.coordinates.covariance.is_all_nan():
cartesian_covariances = orbits.coordinates.covariance.to_matrix()
Expand Down
12 changes: 7 additions & 5 deletions src/adam_core/dynamics/tests/test_ephemeris.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import cProfile
import itertools

import jax
import numpy as np
Expand Down Expand Up @@ -140,9 +139,12 @@ def test_generate_ephemeris_2body(object_id, propagated_orbits, ephemeris):
assert pc.all(pc.is_null(ephemeris_orbit_2body.aberrated_coordinates.vy)).as_py()
assert pc.all(pc.is_null(ephemeris_orbit_2body.aberrated_coordinates.vz)).as_py()


@pytest.mark.profile
def test_profile_generate_ephemeris_2body_matrix(propagated_orbits, tmp_path):
"""Profile the generate_ephemeris_2body function with different combinations of orbits,
observers and times. Results are saved to a stats file that can be visualized with snakeviz."""
"""Profile the generate_ephemeris_2body function with different combinations of orbits,
observers and times. Results are saved to a stats file that can be visualized with snakeviz.
"""
# Clear the jax cache
jax.clear_caches()
# Create profiler
Expand Down Expand Up @@ -176,8 +178,8 @@ def to_profile():
profiler.enable()
to_profile()
profiler.disable()

# Save and print results
stats_file = tmp_path / "ephemeris_profile.prof"
profiler.dump_stats(stats_file)
print(f"Run 'snakeviz {stats_file}' to view the profile results.")
print(f"Run 'snakeviz {stats_file}' to view the profile results.")
12 changes: 5 additions & 7 deletions src/adam_core/dynamics/tests/test_propagation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import cProfile
import itertools
import pstats
from pstats import SortKey

import jax
import numpy as np
Expand Down Expand Up @@ -449,11 +447,10 @@ def test_benchmark_propagate_2body(benchmark, orbital_elements):


@pytest.mark.benchmark(group="propagate_2body")
def test_benchmark_propagate_2body_matrix(
benchmark, propagated_orbits
):
def test_benchmark_propagate_2body_matrix(benchmark, propagated_orbits):
# Clear the jax cache
jax.clear_caches()

def benchmark_function():
n_orbits = [1, 5, 20]
n_times = [1, 10, 100]
Expand All @@ -468,12 +465,13 @@ def benchmark_function():
benchmark(benchmark_function)


@pytest.mark.profile
def test_profile_propagate_2body_matrix(propagated_orbits, tmp_path):
"""Profile the propagate_2body function with different combinations of orbits and times.
Results are saved to a stats file that can be visualized with snakeviz."""
# Clear the jax cache
jax.clear_caches()

# Create profiler
profiler = cProfile.Profile(subcalls=True, builtins=True)
profiler.bias = 0
Expand All @@ -488,7 +486,7 @@ def test_profile_propagate_2body_matrix(propagated_orbits, tmp_path):
)
propagate_2body(propagated_orbits[:n_orbits_i], times=times)
profiler.disable()

# Save and print results
stats_file = tmp_path / "precovery_profile.prof"
profiler.dump_stats(stats_file)
Expand Down
5 changes: 1 addition & 4 deletions src/adam_core/orbit_determination/tests/test_iod.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
import importlib.util
import os
import sys

import numpy as np
import pyarrow.compute as pc
import pytest
Expand Down Expand Up @@ -60,6 +56,7 @@ def real_data():

return orbit, observations


def test_iod(real_data):
orbit, observations = real_data
# Call the iod function
Expand Down
4 changes: 0 additions & 4 deletions src/adam_core/orbit_determination/tests/test_od.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
import importlib.util
import os
import sys

import numpy as np
import pyarrow.compute as pc
import pytest
Expand Down
4 changes: 3 additions & 1 deletion src/adam_core/propagator/propagator.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ def _add_light_time(

# Calculate the new epoch and propagate the initial orbit to that epoch
# Should be sufficient to use 2body propagation for this
orbit_i = propagate_2body(orbit, Timestamp.from_mjd([t0 - lt], scale="tdb"))
orbit_i = propagate_2body(
orbit, Timestamp.from_mjd([t0 - lt], scale="tdb")
)

# Update the previous light travel time to this iteration's light travel time
lt_prev = lt
Expand Down

0 comments on commit 9ae26d6

Please sign in to comment.