Skip to content

Commit

Permalink
Use a IndexedObservers quivr table in from_codes
Browse files Browse the repository at this point in the history
  • Loading branch information
moeyensj committed Sep 18, 2024
1 parent 14734b8 commit 9e7ea9d
Showing 1 changed file with 17 additions and 72 deletions.
89 changes: 17 additions & 72 deletions src/adam_core/observers/observers.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,90 +84,35 @@ def from_codes(
if not isinstance(codes, pa.Array):
codes = pa.array(codes, type=pa.large_string())

# Create a table with the codes and times and add
# and index column to track the original order
table = pa.Table.from_pydict(
{
"index": pa.array(range(len(codes)), type=pa.uint64()),
"code": codes,
"times.days": times.days,
"times.nanos": times.nanos,
}
)

# Expected observers schema with the addition of a
# column that tracks the original index
observers_schema = pa.schema(
[
pa.field("code", pa.large_string(), nullable=False),
pa.field(
"coordinates",
pa.struct(
[
pa.field("x", pa.float64()),
pa.field("y", pa.float64()),
pa.field("z", pa.float64()),
pa.field("vx", pa.float64()),
pa.field("vy", pa.float64()),
pa.field("vz", pa.float64()),
pa.field(
"time",
pa.struct(
[
pa.field("days", pa.int64()),
pa.field("nanos", pa.int64()),
]
),
),
pa.field(
"covariance",
pa.struct(
[pa.field("values", pa.large_list(pa.float64()))]
),
),
pa.field(
"origin",
pa.struct([pa.field("code", pa.large_string())]),
),
]
),
),
pa.field("index", pa.uint64()),
],
metadata={
"coordinates.time.scale": times.scale,
"coordinates.frame": "ecliptic",
},
)
class IndexedObservers(qv.Table):
index = qv.UInt64Column()
observers = Observers.as_column()

# Create an empty table with the expected schema
observers_table = observers_schema.empty_table()
indexed_observers = IndexedObservers.empty()

# Loop through each unique code and calculate the observer's
# state for each time (these can be non-unique as cls.from_code
# will handle this)
for code in table["code"].unique():
for code in pc.unique(codes):

times_code = table.filter(pc.equal(table["code"], code))
indices = pc.indices_nonzero(pc.equal(codes, code))
times_code = times.take(indices)

observers = cls.from_code(
observers_i = cls.from_code(
code.as_py(),
Timestamp.from_kwargs(
days=times_code["times.days"],
nanos=times_code["times.nanos"],
scale=times.scale,
),
times_code,
)

observers_table_i = observers.table.append_column(
"index", times_code["index"]
indexed_observers_i = IndexedObservers.from_kwargs(
index=indices,
observers=observers_i,
)
observers_table = pa.concat_tables(
[observers_table, observers_table_i]
).combine_chunks()

observers_table = observers_table.sort_by(("index")).drop_columns(["index"])
return cls.from_pyarrow(observers_table)
indexed_observers = qv.concatenate([indexed_observers, indexed_observers_i])
if indexed_observers.fragmented():
indexed_observers = qv.defragment(indexed_observers)

return indexed_observers.sort_by("index").observers

@classmethod
def from_code(cls, code: Union[str, OriginCodes], times: Timestamp) -> Self:
Expand Down

0 comments on commit 9e7ea9d

Please sign in to comment.