Skip to content

Commit

Permalink
ml-dsgp4: python file, tutorials, and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Sceki committed Nov 24, 2024
1 parent 1239a44 commit 696f3b5
Show file tree
Hide file tree
Showing 12 changed files with 342 additions and 11 deletions.
Binary file added doc/_static/dsgp4_backprop_diagram.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ $\partial$SGP4 API
:recursive:

dsgp4
dsgp4.mldsgp4.mldsgp4
dsgp4.plot.plot_orbit
dsgp4.plot.plot_tles
dsgp4.tle.compute_checksum
Expand Down
9 changes: 8 additions & 1 deletion doc/capabilities.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@
"source": [
"# Capabilities\n",
"\n",
"dSGP4 is an open-source project that constitutes a differentiable version of SGP4\n"
"dSGP4 is an open-source project that constitutes a differentiable version of SGP4. It also offers hybrid ML-dSGP4 models to improve the accuracy of SGP4, when simulated or observed precise data is available.\n",
"\n",
"The core capabilities of dSGP4 can be summarized as follows:\n",
"\n",
"* Differentiable version of SGP4 (implemented in PyTorch)\n",
"* Hybrid SGP4 and machine learning propagation: input/output/parameters corrections of SGP4 from accurate simulated or observed data are learned\n",
"* Parallel TLE propagation\n",
"* Use of differentiable SGP4 on several spaceflight mechanics problems (state transition matrix computation, covariance transformation, and propagation, orbit determination, ML hybrid orbit propagation, etc.)"
]
}
],
Expand Down
2 changes: 1 addition & 1 deletion doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# -- Project information -----------------------------------------------------

project = "dsgp4"
copyright = "2022, 2023, 2024, Giacomo Acciarini and Atılım Güneş Baydin and Dario Izzo"
copyright = "2022, 2023, 2024, 2025, Giacomo Acciarini and Atılım Güneş Baydin and Dario Izzo"
author = "Giacomo Acciarini, Atılım Güneş Baydin, Dario Izzo"

# The full version, including alpha/beta/rc tags
Expand Down
4 changes: 2 additions & 2 deletions doc/credits.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
"source": [
"# Credits\n",
"\n",
"$\\partial\\textrm{SGP4}$ was developed during a project sponsored by the University of Oxford, while Giacomo Acciarini was at the [OX4AILab](https://oxai4science.github.io/) collaborating with Dr. Atılım Güneş Baydin.\n",
"$\\partial\\textrm{SGP4}$ was developed during a project sponsored by the University of Oxford, while Giacomo Acciarini was at the [Oxford AI4Science Lab](https://oxai4science.github.io/) collaborating with Dr. Atılım Güneş Baydin.\n",
"\n",
"The main developers are: Giacomo Acciarini ( [email protected] ), Atılım Güneş Baydin ( [email protected] )."
"The main developers is: Giacomo Acciarini ( [email protected] )."
]
}
],
Expand Down
4 changes: 3 additions & 1 deletion doc/index.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
$\partial\textrm{SGP4}$ Documentation
================================

**dsgp4** is a differentiable SGP4 program written leveraging the [PyTorch](https://pytorch.org/) machine learning framework: this enables features like automatic differentiation and batch propagation (across different TLEs) that were not previously available in the original implementation.
**dsgp4** is a differentiable SGP4 program written leveraging the [PyTorch](https://pytorch.org/) machine learning framework: this enables features like automatic differentiation and batch propagation (across different TLEs) that were not previously available in the original implementation. Furthermore, it also offers a hybrid propagation scheme called ML-dSGP4 where dSGP4 and ML models can be combined to enhance SGP4 accuracy when higher-precision simulated (e.g. from a numerical integrator) or observed (e.g. from ephemerides) data is available.

For more details on the model and results, check out our publication: [Acciarini, Giacomo, Atılım Güneş Baydin, and Dario Izzo. "*Closing the Gap Between SGP4 and High-Precision Propagation via Differentiable Programming*" (2024) Vol. 226(1), pages: 694-701](https://doi.org/10.1016/j.actaastro.2024.10.063)


The authors are [Giacomo Acciarini](https://www.esa.int/gsp/ACT/team/giacomo_acciarini/), [Atılım Güneş Baydin](https://gbaydin.github.io/), [Dario Izzo](https://www.esa.int/gsp/ACT/team/dario_izzo/). The main developer is Giacomo Acciarini ([email protected]).
Expand Down
6 changes: 0 additions & 6 deletions doc/install.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,6 @@ Installation

.. _installation_deps:

Dependencies
------------

dSGP4 has the following Python dependencies:



Packages
--------
Expand Down
223 changes: 223 additions & 0 deletions doc/notebooks/mldsgp4.ipynb

Large diffs are not rendered by default.

Binary file added doc/notebooks/mldsgp4_example_model.pth
Binary file not shown.
1 change: 1 addition & 0 deletions doc/tutorials.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ These tutorials include some basic examples on how to use dSGP4 for simple tasks
notebooks/tle_object.ipynb
notebooks/tle_propagation.ipynb
notebooks/sgp4_partial_derivatives.ipynb
notebooks/mldsgp4.ipynb


Advanced
Expand Down
1 change: 1 addition & 0 deletions dsgp4/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
from .plot import plot_orbit, plot_tles
from . import tle
from .tle import TLE
from .mldsgp4 import mldsgp4
102 changes: 102 additions & 0 deletions dsgp4/mldsgp4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import torch
import torch.nn as nn

from . import initialize_tle, propagate, propagate_batch
from torch.nn.parameter import Parameter

class mldsgp4(nn.Module):
def __init__(self,
normalization_R=6958.137,
normalization_V=7.947155867983262,
hidden_size=100,
input_correction=1e-2,
output_correction=0.8):
"""
This class implements the ML-dSGP4 model, where dSGP4 inputs and outputs are corrected via neural networks,
better match simulated or observed higher-precision data.
Parameters:
----------------
normalization_R (``float``): normalization constant for x,y,z coordinates.
normalization_V (``float``): normalization constant for vx,vy,vz coordinates.
hidden_size (``int``): number of neurons in the hidden layers.
input_correction (``float``): correction factor for the input layer.
output_correction (``float``): correction factor for the output layer.
"""
super().__init__()
self.fc1=nn.Linear(6, hidden_size)
self.fc2=nn.Linear(hidden_size,hidden_size)
self.fc3=nn.Linear(hidden_size, 6)
self.fc4=nn.Linear(6,hidden_size)
self.fc5=nn.Linear(hidden_size, hidden_size)
self.fc6=nn.Linear(hidden_size, 6)

self.tanh = nn.Tanh()
self.leaky_relu = nn.LeakyReLU(negative_slope=0.01)
self.normalization_R=normalization_R
self.normalization_V=normalization_V
self.input_correction = Parameter(input_correction*torch.ones((6,)))
self.output_correction = Parameter(output_correction*torch.ones((6,)))

def forward(self, tles, tsinces):
"""
This method computes the forward pass of the ML-dSGP4 model.
It can take either a single or a list of `dsgp4.tle.TLE` objects,
and a torch.tensor of times since the TLE epoch in minutes.
It then returns the propagated state in the TEME coordinate system. The output
is normalized, to unnormalize and obtain km and km/s, you can use self.normalization_R constant for the position
and self.normalization_V constant for the velocity.
Parameters:
----------------
tles (``dsgp4.tle.TLE`` or ``list``): a TLE object or a list of TLE objects.
tsinces (``torch.tensor``): a torch.tensor of times since the TLE epoch in minutes.
Returns:
----------------
(``torch.tensor``): a tensor of len(tsince)x6 representing the corrected satellite position and velocity in normalized units (to unnormalize to km and km/s, use `self.normalization_R` for position, and `self.normalization_V` for velocity).
"""
is_batch=hasattr(tles, '__len__')
if is_batch:
#this is the batch case, so we proceed and initialize the batch:
_,tles=initialize_tle(tles,with_grad=True)
x0 = torch.stack((tles._ecco, tles._argpo, tles._inclo, tles._mo, tles._no_kozai, tles._nodeo), dim=1)
else:
#this handles the case in which a singlee TLE is passed
initialize_tle(tles,with_grad=True)
x0 = torch.stack((tles._ecco, tles._argpo, tles._inclo, tles._mo, tles._no_kozai, tles._nodeo), dim=0).reshape(-1,6)
x=self.leaky_relu(self.fc1(x0))
x=self.leaky_relu(self.fc2(x))
x=x0*(1+self.input_correction*self.tanh(self.fc3(x)))
#now we need to substitute them back into the tles:
tles._ecco=x[:,0]
tles._argpo=x[:,1]
tles._inclo=x[:,2]
tles._mo=x[:,3]
tles._no_kozai=x[:,4]
tles._nodeo=x[:,5]
if is_batch:
#we propagate the batch:
states_teme=propagate_batch(tles,tsinces)
else:
states_teme=propagate(tles,tsinces)
states_teme=states_teme.reshape(-1,6)
#we now extract the output parameters to correct:
x_out=torch.cat((states_teme[:,:3]/self.normalization_R, states_teme[:,3:]/self.normalization_V),dim=1)

x=self.leaky_relu(self.fc4(x_out))
x=self.leaky_relu(self.fc5(x))
x=x_out*(1+self.output_correction*self.tanh(self.fc6(x)))
return x

def load_model(self, path, device='cpu'):
"""
This method loads a model from a file.
Parameters:
----------------
path (``str``): path to the file where the model is stored.
device (``str``): device where the model will be loaded. Default is 'cpu'.
"""
self.load_state_dict(torch.load(path,map_location=torch.device(device)))
self.eval()

0 comments on commit 696f3b5

Please sign in to comment.