Skip to content

Commit

Permalink
Use dataclass for TransientView
Browse files Browse the repository at this point in the history
  • Loading branch information
jvansanten committed Jan 8, 2025
1 parent e9394cd commit 87bdd05
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 25 deletions.
21 changes: 5 additions & 16 deletions ampel/view/TransientView.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# Last Modified Date: 17.06.2021
# Last Modified By: valery brinnel <[email protected]>

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

from ampel.view.LightCurve import LightCurve
Expand All @@ -23,24 +24,12 @@
from ampel.view.T2DocView import T2DocView


@dataclass(frozen=True)
class TransientView(SnapView):

__slots__ = "lightcurve",

lightcurve: "None | Sequence[LightCurve]"

def __init__(
self,
id: "StockId",
stock: "None | StockDocument" = None,
origin: "None | OneOrMany[int]" = None,
t0: "None | Sequence[DataPoint]" = None,
t1: "None | Sequence[T1Document]" = None,
t2: "None | Sequence[T2DocView]" = None,
logs: "None | Sequence[LogDocument]" = None,
extra: "None | dict[str, Any]" = None
):
super().__init__(id, stock=stock, origin=origin, t0=t0, t1=t1, t2=t2, logs=logs, extra=extra)
lightcurve: "None | Sequence[LightCurve]" = field(init=False)

def __post_init__(self):

if self.t0 and self.t1:
lightcurve: None | Sequence[LightCurve] = tuple(
Expand Down
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ include = [

[tool.poetry.dependencies]
python = "^3.10"
ampel-interface = {version = ">=0.10.4,<0.11"}
ampel-interface = {version = ">=0.10.4.post0,<0.11"}

astropy = ">=5"

[tool.poetry.dev-dependencies]
Expand Down
24 changes: 20 additions & 4 deletions tests/test_TransientView.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pickle
from functools import reduce

import pytest
Expand All @@ -7,7 +8,7 @@

@pytest.fixture
def view():
return TransientView("stockystock")
return TransientView(id="stockystock")


def gather_slots(typ):
Expand All @@ -17,9 +18,24 @@ def gather_slots(typ):


def test_reduce(view: TransientView):
cls, args = view.__reduce__()
tview = cls(*args)
tview = pickle.loads(pickle.dumps(view))
slots = gather_slots(type(tview))
assert slots
assert len(slots) > 1
for attr in gather_slots(type(tview)):
assert getattr(tview, attr) == getattr(view, attr)

def test_build_lightcurve():
view = TransientView(
id=1,
t0=[
{"id": 1, "body": {"jd": 2450000.0, "mag": 12.0}},
{"id": 2, "body": {"jd": 2450001.0, "mag": 13.0}},
],
t1=[
{"stock": 1, "link": 1, "dps": [1, 2]},
],
)
assert view.lightcurve
assert len(view.lightcurve) == 1
assert view.lightcurve[0].get_values("jd") == [2450000.0, 2450001.0]
assert view.lightcurve[0].get_values("mag") == [12.0, 13.0]

0 comments on commit 87bdd05

Please sign in to comment.