From d976d025b5e7c7d9bf453baa95dd948ed26ff2b5 Mon Sep 17 00:00:00 2001 From: Eva Lott Date: Tue, 20 Aug 2024 14:30:55 +0100 Subject: [PATCH] allowed backends to take pydantic models Also finished a `PvaTable` type. --- src/ophyd_async/core/__init__.py | 2 - src/ophyd_async/core/_device_save_loader.py | 14 +- src/ophyd_async/core/_signal_backend.py | 23 -- src/ophyd_async/core/_soft_signal_backend.py | 26 +- src/ophyd_async/epics/signal/__init__.py | 4 +- src/ophyd_async/epics/signal/_p4p.py | 36 ++- .../epics/signal/_p4p_table_abstraction.py | 70 ++++++ src/ophyd_async/fastcs/panda/__init__.py | 4 - src/ophyd_async/fastcs/panda/_table.py | 235 +++++++++--------- src/ophyd_async/plan_stubs/_fly.py | 41 ++- tests/core/test_device_save_loader.py | 20 +- tests/fastcs/panda/test_panda_connect.py | 157 ++++++------ tests/fastcs/panda/test_panda_utils.py | 26 +- tests/fastcs/panda/test_table.py | 222 +++++++++++------ tests/fastcs/panda/test_trigger.py | 34 ++- tests/test_data/test_yaml_save.yml | 3 +- 16 files changed, 507 insertions(+), 410 deletions(-) create mode 100644 src/ophyd_async/epics/signal/_p4p_table_abstraction.py diff --git a/src/ophyd_async/core/__init__.py b/src/ophyd_async/core/__init__.py index 1b082f2835..d66cbc5ba3 100644 --- a/src/ophyd_async/core/__init__.py +++ b/src/ophyd_async/core/__init__.py @@ -63,7 +63,6 @@ ) from ._signal_backend import ( BackendConverterFactory, - ProtocolDatatypeAbstraction, RuntimeSubsetEnum, SignalBackend, SubsetEnum, @@ -124,7 +123,6 @@ "NameProvider", "PathInfo", "PathProvider", - "ProtocolDatatypeAbstraction", "ShapeProvider", "StaticFilenameProvider", "StaticPathProvider", diff --git a/src/ophyd_async/core/_device_save_loader.py b/src/ophyd_async/core/_device_save_loader.py index 02ceeeb0af..a40c404d50 100644 --- a/src/ophyd_async/core/_device_save_loader.py +++ b/src/ophyd_async/core/_device_save_loader.py @@ -7,10 +7,10 @@ from bluesky.plan_stubs import abs_set, wait from bluesky.protocols import Location from bluesky.utils import Msg +from pydantic import BaseModel from ._device import Device from ._signal import SignalRW -from ._signal_backend import ProtocolDatatypeAbstraction def ndarray_representer(dumper: yaml.Dumper, array: npt.NDArray[Any]) -> yaml.Node: @@ -19,14 +19,12 @@ def ndarray_representer(dumper: yaml.Dumper, array: npt.NDArray[Any]) -> yaml.No ) -def protocol_datatype_abstraction_representer( - dumper: yaml.Dumper, protocol_datatype_abstraction: ProtocolDatatypeAbstraction +def pydantic_model_abstraction_representer( + dumper: yaml.Dumper, model: BaseModel ) -> yaml.Node: """Uses the protocol datatype since it has to be serializable.""" - return dumper.represent_data( - protocol_datatype_abstraction.convert_to_protocol_datatype() - ) + return dumper.represent_data(model.model_dump(mode="python")) class OphydDumper(yaml.Dumper): @@ -146,8 +144,8 @@ def save_to_yaml(phases: Sequence[Dict[str, Any]], save_path: str) -> None: yaml.add_representer(np.ndarray, ndarray_representer, Dumper=yaml.Dumper) yaml.add_multi_representer( - ProtocolDatatypeAbstraction, - protocol_datatype_abstraction_representer, + BaseModel, + pydantic_model_abstraction_representer, Dumper=yaml.Dumper, ) diff --git a/src/ophyd_async/core/_signal_backend.py b/src/ophyd_async/core/_signal_backend.py index 8427ee502c..178ff1edfa 100644 --- a/src/ophyd_async/core/_signal_backend.py +++ b/src/ophyd_async/core/_signal_backend.py @@ -13,29 +13,6 @@ from ._utils import DEFAULT_TIMEOUT, ReadingValueCallback, T -class ProtocolDatatypeAbstraction(ABC, Generic[T]): - @abstractmethod - def __init__(self): - """The abstract datatype must be able to be intialized with no arguments.""" - - @abstractmethod - def convert_to_protocol_datatype(self) -> T: - """ - Convert the abstract datatype to a form which can be sent - over whichever protocol. - - This output will be used when the device is serialized. - """ - - @classmethod - @abstractmethod - def convert_from_protocol_datatype(cls, value: T) -> "ProtocolDatatypeAbstraction": - """ - Convert the datatype received from the protocol to a - higher level abstract datatype. - """ - - class BackendConverterFactory(ABC): """Convert between the signal backend and the signal type""" diff --git a/src/ophyd_async/core/_soft_signal_backend.py b/src/ophyd_async/core/_soft_signal_backend.py index 619400c8fd..26b85272a7 100644 --- a/src/ophyd_async/core/_soft_signal_backend.py +++ b/src/ophyd_async/core/_soft_signal_backend.py @@ -8,11 +8,11 @@ import numpy as np from bluesky.protocols import DataKey, Dtype, Reading +from pydantic import BaseModel from typing_extensions import TypedDict from ._signal_backend import ( BackendConverterFactory, - ProtocolDatatypeAbstraction, RuntimeSubsetEnum, SignalBackend, ) @@ -127,8 +127,10 @@ def make_initial_value(self, datatype: Optional[Type[T]]) -> T: return cast(T, self.choices[0]) -class SoftProtocolDatatypeAbstractionConverter(SoftConverter): - def __init__(self, datatype: Type[ProtocolDatatypeAbstraction]): +class SoftPydanticModelConverter(SoftConverter): + """Necessary for serializing soft signals.""" + + def __init__(self, datatype: Type[BaseModel]): self.datatype = datatype def reading(self, value: T, timestamp: float, severity: int) -> Reading: @@ -136,12 +138,16 @@ def reading(self, value: T, timestamp: float, severity: int) -> Reading: return super().reading(value, timestamp, severity) def value(self, value: Any) -> Any: - if not isinstance(value, self.datatype): - # For the case where we - value = self.datatype.convert_from_protocol_datatype(value) + if isinstance(value, dict): + value = self.datatype(**value) return value def write_value(self, value): + if isinstance(value, dict): + # If the device is being deserialized + return self.datatype(**value).model_dump(mode="python") + if isinstance(value, self.datatype): + return value.model_dump(mode="python") return value def make_initial_value(self, datatype: Type | None) -> Any: @@ -162,16 +168,16 @@ def make_converter(cls, datatype): is_enum = inspect.isclass(datatype) and ( issubclass(datatype, Enum) or issubclass(datatype, RuntimeSubsetEnum) ) - is_convertable_abstract_datatype = inspect.isclass(datatype) and issubclass( - datatype, ProtocolDatatypeAbstraction + is_pydantic_model = inspect.isclass(datatype) and issubclass( + datatype, BaseModel ) if is_array or is_sequence: return SoftArrayConverter() if is_enum: return SoftEnumConverter(datatype) - if is_convertable_abstract_datatype: - return SoftProtocolDatatypeAbstractionConverter(datatype) + if is_pydantic_model: + return SoftPydanticModelConverter(datatype) return SoftConverter() diff --git a/src/ophyd_async/epics/signal/__init__.py b/src/ophyd_async/epics/signal/__init__.py index 5da098b59e..f9bd58306f 100644 --- a/src/ophyd_async/epics/signal/__init__.py +++ b/src/ophyd_async/epics/signal/__init__.py @@ -1,5 +1,6 @@ from ._common import LimitPair, Limits, get_supported_values -from ._p4p import PvaSignalBackend, PvaTableAbstraction +from ._p4p import PvaSignalBackend +from ._p4p_table_abstraction import PvaTable from ._signal import ( epics_signal_r, epics_signal_rw, @@ -13,6 +14,7 @@ "LimitPair", "Limits", "PvaSignalBackend", + "PvaTable", "PvaTableAbstraction", "epics_signal_r", "epics_signal_rw", diff --git a/src/ophyd_async/epics/signal/_p4p.py b/src/ophyd_async/epics/signal/_p4p.py index 58872071bc..59cbff5fa9 100644 --- a/src/ophyd_async/epics/signal/_p4p.py +++ b/src/ophyd_async/epics/signal/_p4p.py @@ -3,7 +3,6 @@ import inspect import logging import time -from abc import abstractmethod from dataclasses import dataclass from enum import Enum from math import isnan, nan @@ -13,12 +12,12 @@ from bluesky.protocols import DataKey, Dtype, Reading from p4p import Value from p4p.client.asyncio import Context, Subscription +from pydantic import BaseModel from ophyd_async.core import ( DEFAULT_TIMEOUT, BackendConverterFactory, NotConnected, - ProtocolDatatypeAbstraction, ReadingValueCallback, RuntimeSubsetEnum, SignalBackend, @@ -288,32 +287,25 @@ def __getattribute__(self, __name: str) -> Any: raise NotImplementedError("No PV has been set as connect() has not been called") -class PvaTableAbstraction(ProtocolDatatypeAbstraction[Dict]): - @abstractmethod - def convert_to_protocol_datatype(self) -> Dict: - """Converts the object to a pva table (dictionary).""" - - @classmethod - @abstractmethod - def convert_from_protocol_datatype(cls, value: Dict) -> "PvaTableAbstraction": - """Converts from a pva table (dictionary) to a Python datatype.""" - - -class PvaTableAbtractionConverter(PvaConverter): - def __init__(self, datatype: PvaTableAbstraction): +class PvaPydanticModelConverter(PvaConverter): + def __init__(self, datatype: BaseModel): self.datatype = datatype def reading(self, value: Value): ts = time.time() - value = self.datatype.convert_from_protocol_datatype(value.todict()) + value = self.value(value) return {"value": value, "timestamp": ts, "alarm_severity": 0} def value(self, value: Value): - return self.datatype.convert_from_protocol_datatype(value.todict()) + return self.datatype(**value.todict()) - def write_value(self, value): + def write_value(self, value: Union[BaseModel, Dict[str, Any]]): + """ + A user can put whichever form to the signal. + This is required for yaml deserialization. + """ if isinstance(value, self.datatype): - return value.convert_to_protocol_datatype() + return value.model_dump(mode="python") return value @@ -327,8 +319,8 @@ class PvaConverterFactory(BackendConverterFactory): np.ndarray, Enum, RuntimeSubsetEnum, + BaseModel, dict, - PvaTableAbstraction, ) @classmethod @@ -411,9 +403,9 @@ def make_converter( if ( datatype and inspect.isclass(datatype) - and issubclass(datatype, PvaTableAbstraction) + and issubclass(datatype, BaseModel) ): - return PvaTableAbtractionConverter(datatype) + return PvaPydanticModelConverter(datatype) return PvaDictConverter() else: raise TypeError(f"{pv}: Unsupported typeid {typeid}") diff --git a/src/ophyd_async/epics/signal/_p4p_table_abstraction.py b/src/ophyd_async/epics/signal/_p4p_table_abstraction.py new file mode 100644 index 0000000000..a6e5ecf566 --- /dev/null +++ b/src/ophyd_async/epics/signal/_p4p_table_abstraction.py @@ -0,0 +1,70 @@ +from typing import Dict + +import numpy as np +from pydantic import BaseModel, ConfigDict, model_validator +from pydantic_numpy.typing import NpNDArray + + +class PvaTable(BaseModel): + """An abstraction of a PVA Table of str to python array.""" + + model_config = ConfigDict(validate_assignment=True, strict=False) + + @classmethod + def row(cls, sub_cls, **kwargs) -> "PvaTable": + arrayified_kwargs = { + field_name: np.concatenate( + ( + (default_arr := field_value.default_factory()), + np.array([kwargs[field_name]], dtype=default_arr.dtype), + ) + ) + for field_name, field_value in sub_cls.model_fields.items() + } + return sub_cls(**arrayified_kwargs) + + def __add__(self, right: "PvaTable") -> "PvaTable": + """Concatinate the arrays in field values.""" + + assert isinstance(right, type(self)), ( + f"{right} is not a `PvaTable`, or is not the same " + f"type of `PvaTable` as {self}." + ) + + return type(self)( + **{ + field_name: np.concatenate( + (getattr(self, field_name), getattr(right, field_name)) + ) + for field_name in self.model_fields + } + ) + + @model_validator(mode="after") + def validate_arrays(self) -> "PvaTable": + first_length = len(next(iter(self))[1]) + assert all( + len(field_value) == first_length for _, field_value in self + ), "Rows should all be of equal size." + + assert 0 <= first_length < 4096, f"Length {first_length} not in range." + + if not all( + np.issubdtype( + self.model_fields[field_name].default_factory().dtype, field_value.dtype + ) + for field_name, field_value in self + ): + raise ValueError( + f"Cannot construct a `{type(self).__name__}`, " + "some rows have incorrect types." + ) + + return self + + def convert_to_pva_datatype(self) -> Dict[str, NpNDArray]: + return self.model_dump(mode="python") + + @classmethod + def convert_from_pva_datatype(cls, pva_table: Dict[str, NpNDArray]): + return cls(**pva_table) diff --git a/src/ophyd_async/fastcs/panda/__init__.py b/src/ophyd_async/fastcs/panda/__init__.py index 3724397053..0dbe7222b0 100644 --- a/src/ophyd_async/fastcs/panda/__init__.py +++ b/src/ophyd_async/fastcs/panda/__init__.py @@ -15,9 +15,7 @@ DatasetTable, PandaHdf5DatasetType, SeqTable, - SeqTableRowType, SeqTrigger, - seq_table_row, ) from ._trigger import ( PcompInfo, @@ -44,9 +42,7 @@ "DatasetTable", "PandaHdf5DatasetType", "SeqTable", - "SeqTableRowType", "SeqTrigger", - "seq_table_row", "PcompInfo", "SeqTableInfo", "StaticPcompTriggerLogic", diff --git a/src/ophyd_async/fastcs/panda/_table.py b/src/ophyd_async/fastcs/panda/_table.py index 6dba645cd1..b1ed6d5729 100644 --- a/src/ophyd_async/fastcs/panda/_table.py +++ b/src/ophyd_async/fastcs/panda/_table.py @@ -1,13 +1,13 @@ from enum import Enum -from typing import Dict, Sequence, Union +from typing import Annotated, Sequence import numpy as np import numpy.typing as npt -import pydantic_numpy as pnd -from pydantic import Field, RootModel, field_validator +from pydantic import Field +from pydantic_numpy.helper.annotation import NpArrayPydanticAnnotation from typing_extensions import TypedDict -from ophyd_async.epics.signal import PvaTableAbstraction +from ophyd_async.epics.signal import PvaTable class PandaHdf5DatasetType(str, Enum): @@ -36,124 +36,119 @@ class SeqTrigger(str, Enum): POSC_LT = "POSC<=POSITION" -SeqTableRowType = np.dtype( - [ - ("repeats", np.int32), - ("trigger", "U14"), # One of the SeqTrigger values - ("position", np.int32), - ("time1", np.int32), - ("outa1", np.bool_), - ("outb1", np.bool_), - ("outc1", np.bool_), - ("outd1", np.bool_), - ("oute1", np.bool_), - ("outf1", np.bool_), - ("time2", np.int32), - ("outa2", np.bool_), - ("outb2", np.bool_), - ("outc2", np.bool_), - ("outd2", np.bool_), - ("oute2", np.bool_), - ("outf2", np.bool_), - ] -) - - -def seq_table_row( - *, - repeats: int = 0, - trigger: str = "", - position: int = 0, - time1: int = 0, - outa1: bool = False, - outb1: bool = False, - outc1: bool = False, - outd1: bool = False, - oute1: bool = False, - outf1: bool = False, - time2: int = 0, - outa2: bool = False, - outb2: bool = False, - outc2: bool = False, - outd2: bool = False, - oute2: bool = False, - outf2: bool = False, -) -> pnd.NpNDArray: - return np.array( - ( - repeats, - trigger, - position, - time1, - outa1, - outb1, - outc1, - outd1, - oute1, - outf1, - time2, - outa2, - outb2, - outc2, - outd2, - oute2, - outf2, - ), - dtype=SeqTableRowType, +PydanticNp1DArrayInt32 = Annotated[ + np.ndarray[tuple[int], np.int32], + NpArrayPydanticAnnotation.factory( + data_type=np.int32, dimensions=1, strict_data_typing=False + ), +] +PydanticNp1DArrayBool = Annotated[ + np.ndarray[tuple[int], np.bool_], + NpArrayPydanticAnnotation.factory( + data_type=np.bool_, dimensions=1, strict_data_typing=False + ), +] + +PydanticNp1DArrayUnicodeString = Annotated[ + np.ndarray[tuple[int], np.unicode_], + NpArrayPydanticAnnotation.factory( + data_type=np.unicode_, dimensions=1, strict_data_typing=False + ), +] + + +class SeqTable(PvaTable): + repeats: PydanticNp1DArrayInt32 = Field( + default_factory=lambda: np.array([], np.int32) ) - - -class SeqTable(RootModel, PvaTableAbstraction): - root: pnd.NpNDArray = Field( - default_factory=lambda: np.array([], dtype=SeqTableRowType), + trigger: PydanticNp1DArrayUnicodeString = Field( + default_factory=lambda: np.array([], dtype=np.dtype(" Dict[str, npt.ArrayLike]: - """Convert root to the column-wise dict representation for backend put""" - - if len(self.root) == 0: - transposed = { # list with empty arrays, each with correct dtype - name: np.array([], dtype=dtype) for name, dtype in SeqTableRowType.descr - } - else: - transposed_list = list(zip(*list(self.root))) - transposed = { - name: np.array(col, dtype=dtype) - for col, (name, dtype) in zip(transposed_list, SeqTableRowType.descr) - } - return transposed @classmethod - def convert_from_protocol_datatype( - cls, pva_table: Dict[str, npt.ArrayLike] + def row( + cls, + *, + repeats: int = 0, + trigger: str = "", + position: int = 0, + time1: int = 0, + outa1: bool = False, + outb1: bool = False, + outc1: bool = False, + outd1: bool = False, + oute1: bool = False, + outf1: bool = False, + time2: int = 0, + outa2: bool = False, + outb2: bool = False, + outc2: bool = False, + outd2: bool = False, + oute2: bool = False, + outf2: bool = False, ) -> "SeqTable": - """Convert a pva table to a row-wise SeqTable.""" - - ordered_columns = [ - np.array(pva_table[name], dtype=dtype) - for name, dtype in SeqTableRowType.descr - ] - - transposed = list(zip(*ordered_columns)) - rows = np.array([tuple(row) for row in transposed], dtype=SeqTableRowType) - return cls(rows) - - @field_validator("root", mode="before") - @classmethod - def check_valid_rows(cls, rows: Union[Sequence, np.ndarray]): - assert isinstance( - rows, (np.ndarray, list) - ), "Rows must be a list or numpy array." - - if not (0 <= len(rows) < 4096): - raise ValueError(f"Length {len(rows)} not in range.") - - if not all(isinstance(row, (np.ndarray, np.void)) for row in rows): - raise ValueError("Cannot construct a SeqTable, some rows are not arrays.") - - if not all(row.dtype is SeqTableRowType for row in rows): - raise ValueError( - "Cannot construct a SeqTable, some rows have incorrect types." - ) - - return np.array(rows, dtype=SeqTableRowType) + return PvaTable.row( + cls, + repeats=repeats, + trigger=trigger, + position=position, + time1=time1, + outa1=outa1, + outb1=outb1, + outc1=outc1, + outd1=outd1, + oute1=oute1, + outf1=outf1, + time2=time2, + outa2=outa2, + outb2=outb2, + outc2=outc2, + outd2=outd2, + oute2=oute2, + outf2=outf2, + ) diff --git a/src/ophyd_async/plan_stubs/_fly.py b/src/ophyd_async/plan_stubs/_fly.py index 04adf046a8..daa686b477 100644 --- a/src/ophyd_async/plan_stubs/_fly.py +++ b/src/ophyd_async/plan_stubs/_fly.py @@ -15,7 +15,6 @@ PcompInfo, SeqTable, SeqTableInfo, - seq_table_row, ) @@ -73,26 +72,26 @@ def prepare_static_seq_table_flyer_and_detectors_with_same_trigger( trigger_time = number_of_frames * (exposure + deadtime) pre_delay = max(period - 2 * shutter_time - trigger_time, 0) - table = SeqTable( - [ - # Wait for pre-delay then open shutter - seq_table_row( - time1=in_micros(pre_delay), - time2=in_micros(shutter_time), - outa2=True, - ), - # Keeping shutter open, do N triggers - seq_table_row( - repeats=number_of_frames, - time1=in_micros(exposure), - outa1=True, - outb1=True, - time2=in_micros(deadtime), - outa2=True, - ), - # Add the shutter close - seq_table_row(time2=in_micros(shutter_time)), - ] + table = ( + # Wait for pre-delay then open shutter + SeqTable.row( + time1=in_micros(pre_delay), + time2=in_micros(shutter_time), + outa2=True, + ) + + + # Keeping shutter open, do N triggers + SeqTable.row( + repeats=number_of_frames, + time1=in_micros(exposure), + outa1=True, + outb1=True, + time2=in_micros(deadtime), + outa2=True, + ) + + + # Add the shutter close + SeqTable.row(time2=in_micros(shutter_time)) ) table_info = SeqTableInfo(sequence_table=table, repeats=repeats) diff --git a/tests/core/test_device_save_loader.py b/tests/core/test_device_save_loader.py index 16d18696ec..eb5d54977b 100644 --- a/tests/core/test_device_save_loader.py +++ b/tests/core/test_device_save_loader.py @@ -8,6 +8,7 @@ import pytest import yaml from bluesky.run_engine import RunEngine +from pydantic import BaseModel, Field from ophyd_async.core import ( Device, @@ -22,7 +23,6 @@ set_signal_values, walk_rw_signals, ) -from ophyd_async.core._signal_backend import ProtocolDatatypeAbstraction from ophyd_async.epics.signal import epics_signal_r, epics_signal_rw @@ -55,16 +55,8 @@ class MyEnum(str, Enum): three = "three" -class SomeProtocolDatatypeAbstraction(ProtocolDatatypeAbstraction): - def __init__(self, value: int): - self.value = value - - def convert_to_protocol_datatype(self) -> int: - return self.value - 1 - - @classmethod - def convert_from_protocol_datatype(cls, value: int) -> "SomeProtocolDatatypeAbstraction": - return cls(value + 1) +class SomePvaPydanticModel(BaseModel): + some_field: int = Field(default=1) class DummyDeviceGroupAllTypes(Device): @@ -86,7 +78,9 @@ def __init__(self, name: str): self.pv_array_float64 = epics_signal_rw(npt.NDArray[np.float64], "PV14") self.pv_array_npstr = epics_signal_rw(npt.NDArray[np.str_], "PV15") self.pv_array_str = epics_signal_rw(Sequence[str], "PV16") - self.pv_protocol_device_abstraction = epics_signal_rw(SomeProtocolDatatypeAbstraction, "PV17") + self.pv_protocol_device_abstraction = epics_signal_rw( + SomePvaPydanticModel, "pva://PV17" + ) @pytest.fixture @@ -170,7 +164,7 @@ async def test_save_device_all_types(RE: RunEngine, device_all_types, tmp_path): ["one", "two", "three"], ) await device_all_types.pv_protocol_device_abstraction.set( - SomeProtocolDatatypeAbstraction(1) + SomePvaPydanticModel(some_field=1) ) # Create save plan from utility functions diff --git a/tests/fastcs/panda/test_panda_connect.py b/tests/fastcs/panda/test_panda_connect.py index a4bc23d41c..b7ce1b233d 100644 --- a/tests/fastcs/panda/test_panda_connect.py +++ b/tests/fastcs/panda/test_panda_connect.py @@ -20,7 +20,6 @@ SeqBlock, SeqTable, SeqTrigger, - seq_table_row, ) @@ -91,85 +90,83 @@ def test_panda_name_set(panda_t): async def test_panda_children_connected(mock_panda): # try to set and retrieve from simulated values... - table = table = SeqTable( - [ - seq_table_row( - repeats=1, - trigger=SeqTrigger.POSA_GT, - position=3222, - time1=5, - outa1=True, - outb1=False, - outc1=False, - outd1=True, - oute1=True, - outf1=True, - time2=0, - outa2=True, - outb2=False, - outc2=False, - outd2=True, - oute2=True, - outf2=True, - ), - seq_table_row( - repeats=1, - trigger=SeqTrigger.POSA_LT, - position=-565, - time1=0, - outa1=False, - outb1=False, - outc1=True, - outd1=True, - oute1=False, - outf1=False, - time2=10, - outa2=False, - outb2=False, - outc2=True, - outd2=True, - oute2=False, - outf2=False, - ), - seq_table_row( - repeats=1, - trigger=SeqTrigger.IMMEDIATE, - position=0, - time1=10, - outa1=False, - outb1=True, - outc1=True, - outd1=False, - oute1=True, - outf1=False, - time2=10, - outa2=False, - outb2=True, - outc2=True, - outd2=False, - oute2=True, - outf2=False, - ), - seq_table_row( - repeats=32, - trigger=SeqTrigger.IMMEDIATE, - position=0, - time1=10, - outa1=True, - outb1=True, - outc1=False, - outd1=True, - oute1=False, - outf1=False, - time2=11, - outa2=True, - outb2=True, - outc2=False, - outd2=True, - oute2=False, - outf2=False, - ), - ] + table = ( + SeqTable.row( + repeats=1, + trigger=SeqTrigger.POSA_GT, + position=3222, + time1=5, + outa1=True, + outb1=False, + outc1=False, + outd1=True, + oute1=True, + outf1=True, + time2=0, + outa2=True, + outb2=False, + outc2=False, + outd2=True, + oute2=True, + outf2=True, + ) + + SeqTable.row( + repeats=1, + trigger=SeqTrigger.POSA_LT, + position=-565, + time1=0, + outa1=False, + outb1=False, + outc1=True, + outd1=True, + oute1=False, + outf1=False, + time2=10, + outa2=False, + outb2=False, + outc2=True, + outd2=True, + oute2=False, + outf2=False, + ) + + SeqTable.row( + repeats=1, + trigger=SeqTrigger.IMMEDIATE, + position=0, + time1=10, + outa1=False, + outb1=True, + outc1=True, + outd1=False, + oute1=True, + outf1=False, + time2=10, + outa2=False, + outb2=True, + outc2=True, + outd2=False, + oute2=True, + outf2=False, + ) + + SeqTable.row( + repeats=32, + trigger=SeqTrigger.IMMEDIATE, + position=0, + time1=10, + outa1=True, + outb1=True, + outc1=False, + outd1=True, + oute1=False, + outf1=False, + time2=11, + outa2=True, + outb2=True, + outc2=False, + outd2=True, + oute2=False, + outf2=False, + ) ) await mock_panda.pulse[1].delay.set(20.0) await mock_panda.seq[1].table.set(table) diff --git a/tests/fastcs/panda/test_panda_utils.py b/tests/fastcs/panda/test_panda_utils.py index ff8a1559f8..48756523df 100644 --- a/tests/fastcs/panda/test_panda_utils.py +++ b/tests/fastcs/panda/test_panda_utils.py @@ -1,4 +1,3 @@ - import numpy as np from bluesky import RunEngine @@ -10,7 +9,6 @@ DataBlock, SeqTable, phase_sorter, - seq_table_row, ) @@ -36,27 +34,25 @@ async def connect(self, mock: bool = False, timeout: float = DEFAULT_TIMEOUT): async def test_save_load_panda(tmp_path, RE: RunEngine): mock_panda1 = await get_mock_panda() - await mock_panda1.seq[1].table.set(SeqTable([seq_table_row(repeats=1)])) + await mock_panda1.seq[1].table.set(SeqTable.row(repeats=1)) RE(save_device(mock_panda1, str(tmp_path / "panda.yaml"), sorter=phase_sorter)) def check_equal_with_seq_tables(actual, expected): - assert set(actual.keys()) == set(expected.keys()) - for key, value1 in actual.items(): - value2 = expected[key] - if isinstance(value1, SeqTable): - assert np.array_equal(value1.root, value2.root) - else: - assert value1 == value2 + assert actual.model_fields_set == expected.model_fields_set + for field_name, field_value1 in actual: + field_value2 = getattr(expected, field_name) + assert np.array_equal(field_value1, field_value2) mock_panda2 = await get_mock_panda() - assert np.array_equal( - (await mock_panda2.seq[1].table.get_value()).root, SeqTable([]).root + check_equal_with_seq_tables( + (await mock_panda2.seq[1].table.get_value()), SeqTable() ) RE(load_device(mock_panda2, str(tmp_path / "panda.yaml"))) - assert np.array_equal( - (await mock_panda2.seq[1].table.get_value()).root, - SeqTable([seq_table_row(repeats=1)]).root, + + check_equal_with_seq_tables( + await mock_panda2.seq[1].table.get_value(), + SeqTable.row(repeats=1), ) """ diff --git a/tests/fastcs/panda/test_table.py b/tests/fastcs/panda/test_table.py index 88bacabca4..1024283618 100644 --- a/tests/fastcs/panda/test_table.py +++ b/tests/fastcs/panda/test_table.py @@ -1,57 +1,86 @@ +from functools import reduce + import numpy as np import pytest from pydantic import ValidationError -from ophyd_async.fastcs.panda import SeqTable, SeqTableRowType, seq_table_row +from ophyd_async.fastcs.panda import SeqTable -@pytest.mark.parametrize( - # factory so that there aren't global errors if seq_table_row() fails - "rows_arg_factory", - [ - lambda: None, - list, - lambda: [seq_table_row(), seq_table_row()], - lambda: np.array([seq_table_row(), seq_table_row()]), - ], -) -def test_seq_table_initialization_allowed_args(rows_arg_factory): - rows_arg = rows_arg_factory() - seq_table = SeqTable() if rows_arg is None else SeqTable(rows_arg) - assert isinstance(seq_table.root, np.ndarray) - assert len(seq_table.root) == (0 if rows_arg is None else len(rows_arg)) +def test_seq_table_converts_lists(): + seq_table_dict_with_lists = {field_name: [] for field_name, _ in SeqTable()} + # Validation passes + seq_table = SeqTable(**seq_table_dict_with_lists) + assert isinstance(seq_table.trigger, np.ndarray) + assert seq_table.trigger.dtype == np.dtype("U32") def test_seq_table_validation_errors(): - with pytest.raises( - ValueError, match="Cannot construct a SeqTable, some rows are not arrays." - ): - SeqTable([seq_table_row().tolist()]) - with pytest.raises(ValidationError, match="Length 4098 not in range."): - SeqTable([seq_table_row() for _ in range(4098)]) + with pytest.raises(ValidationError, match="81 validation errors for SeqTable"): + SeqTable( + repeats=0, + trigger="", + position=0, + time1=0, + outa1=False, + outb1=False, + outc1=False, + outd1=False, + oute1=False, + outf1=False, + time2=0, + outa2=False, + outb2=False, + outc2=False, + outd2=False, + oute2=False, + outf2=False, + ) + + large_seq_table = SeqTable( + repeats=np.zeros(4095, dtype=np.int32), + trigger=np.array([""] * 4095, dtype="U32"), + position=np.zeros(4095, dtype=np.int32), + time1=np.zeros(4095, dtype=np.int32), + outa1=np.zeros(4095, dtype=np.bool_), + outb1=np.zeros(4095, dtype=np.bool_), + outc1=np.zeros(4095, dtype=np.bool_), + outd1=np.zeros(4095, dtype=np.bool_), + oute1=np.zeros(4095, dtype=np.bool_), + outf1=np.zeros(4095, dtype=np.bool_), + time2=np.zeros(4095, dtype=np.int32), + outa2=np.zeros(4095, dtype=np.bool_), + outb2=np.zeros(4095, dtype=np.bool_), + outc2=np.zeros(4095, dtype=np.bool_), + outd2=np.zeros(4095, dtype=np.bool_), + oute2=np.zeros(4095, dtype=np.bool_), + outf2=np.zeros(4095, dtype=np.bool_), + ) with pytest.raises( ValidationError, - match="Cannot construct a SeqTable, some rows have incorrect types.", + match=( + "1 validation error for SeqTable\n " + "Assertion failed, Length 4096 not in range." + ), ): - SeqTable([seq_table_row(), np.array([1, 2, 3]), seq_table_row()]) + large_seq_table + SeqTable.row() with pytest.raises( ValidationError, - match="Cannot construct a SeqTable, some rows have incorrect types.", + match="12 validation errors for SeqTable", ): - SeqTable( - [ - seq_table_row(), - np.array(range(len(seq_table_row().tolist()))), - seq_table_row(), - ] - ) + row_one = SeqTable.row() + wrong_types = { + field_name: field_value.astype(np.unicode_) + for field_name, field_value in row_one + } + SeqTable(**wrong_types) def test_seq_table_pva_conversion(): expected_pva_dict = { "repeats": np.array([1, 2, 3, 4], dtype=np.int32), "trigger": np.array( - ["Immediate", "Immediate", "BITC=0", "Immediate"], dtype="U14" + ["Immediate", "Immediate", "BITC=0", "Immediate"], dtype=np.dtype("U32") ), "position": np.array([1, 2, 3, 4], dtype=np.int32), "time1": np.array([1, 0, 1, 0], dtype=np.int32), @@ -69,51 +98,102 @@ def test_seq_table_pva_conversion(): "oute2": np.array([1, 0, 1, 0], dtype=np.bool_), "outf2": np.array([1, 0, 1, 0], dtype=np.bool_), } - expected_numpy_table = np.array( - [ - (1, "Immediate", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1), - (2, "Immediate", 2, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0), - (3, "BITC=0", 3, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1), - (4, "Immediate", 4, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0), - ], - dtype=SeqTableRowType, - ) + expected_row_wise_dict = [ + { + "repeats": 1, + "trigger": "Immediate", + "position": 1, + "time1": 1, + "outa1": 1, + "outb1": 1, + "outc1": 1, + "outd1": 1, + "oute1": 1, + "outf1": 1, + "time2": 1, + "outa2": 1, + "outb2": 1, + "outc2": 1, + "outd2": 1, + "oute2": 1, + "outf2": 1, + }, + { + "repeats": 2, + "trigger": "Immediate", + "position": 2, + "time1": 0, + "outa1": 0, + "outb1": 0, + "outc1": 0, + "outd1": 0, + "oute1": 0, + "outf1": 0, + "time2": 2, + "outa2": 0, + "outb2": 0, + "outc2": 0, + "outd2": 0, + "oute2": 0, + "outf2": 0, + }, + { + "repeats": 3, + "trigger": "BITC=0", + "position": 3, + "time1": 1, + "outa1": 1, + "outb1": 1, + "outc1": 1, + "outd1": 1, + "oute1": 1, + "outf1": 1, + "time2": 3, + "outa2": 1, + "outb2": 1, + "outc2": 1, + "outd2": 1, + "oute2": 1, + "outf2": 1, + }, + { + "repeats": 4, + "trigger": "Immediate", + "position": 4, + "time1": 0, + "outa1": 0, + "outb1": 0, + "outc1": 0, + "outd1": 0, + "oute1": 0, + "outf1": 0, + "time2": 4, + "outa2": 0, + "outb2": 0, + "outc2": 0, + "outd2": 0, + "oute2": 0, + "outf2": 0, + }, + ] - # Can convert from PVA table - numpy_table_from_pva_dict = SeqTable.convert_from_protocol_datatype( - expected_pva_dict - ) - assert np.array_equal(numpy_table_from_pva_dict.root, expected_numpy_table) - assert ( - numpy_table_from_pva_dict.root.dtype - == expected_numpy_table.dtype - == SeqTableRowType - ) - - # Can convert to PVA table - pva_dict_from_numpy_table = SeqTable( - expected_numpy_table - ).convert_to_protocol_datatype() - for column1, column2 in zip( - pva_dict_from_numpy_table.values(), expected_pva_dict.values() + seq_table_from_pva_dict = SeqTable(**expected_pva_dict) + for (_, column1), column2 in zip( + seq_table_from_pva_dict, expected_pva_dict.values() ): assert np.array_equal(column1, column2) assert column1.dtype == column2.dtype - # Idempotency - applied_twice_to_numpy_table = SeqTable.convert_from_protocol_datatype( - SeqTable(expected_numpy_table).convert_to_protocol_datatype() - ) - assert np.array_equal(applied_twice_to_numpy_table.root, expected_numpy_table) - assert ( - applied_twice_to_numpy_table.root.dtype - == expected_numpy_table.dtype - == SeqTableRowType + seq_table_from_rows = reduce( + lambda x, y: x + y, + [SeqTable.row(**row_kwargs) for row_kwargs in expected_row_wise_dict], ) + for (_, column1), column2 in zip(seq_table_from_rows, expected_pva_dict.values()): + assert np.array_equal(column1, column2) + assert column1.dtype == column2.dtype - applied_twice_to_pva_dict = SeqTable( - SeqTable.convert_from_protocol_datatype(expected_pva_dict).root - ).convert_to_protocol_datatype() + # Idempotency + applied_twice_to_pva_dict = SeqTable(**expected_pva_dict).model_dump(mode="python") for column1, column2 in zip( applied_twice_to_pva_dict.values(), expected_pva_dict.values() ): diff --git a/tests/fastcs/panda/test_trigger.py b/tests/fastcs/panda/test_trigger.py index a12ee32200..d7a6e039fb 100644 --- a/tests/fastcs/panda/test_trigger.py +++ b/tests/fastcs/panda/test_trigger.py @@ -12,7 +12,6 @@ SeqTableInfo, StaticPcompTriggerLogic, StaticSeqTableTriggerLogic, - seq_table_row, ) @@ -38,13 +37,11 @@ async def connect(self, mock: bool = False, timeout: float = DEFAULT_TIMEOUT): async def test_seq_table_trigger_logic(mock_panda): trigger_logic = StaticSeqTableTriggerLogic(mock_panda.seq[1]) - seq_table = SeqTable( - [ - seq_table_row(outa1=True, outa2=True), - seq_table_row(outa1=False, outa2=False), - seq_table_row(outa1=True, outa2=False), - seq_table_row(outa1=False, outa2=True), - ] + seq_table = ( + SeqTable.row(outa1=True, outa2=True) + + SeqTable.row(outa1=False, outa2=False) + + SeqTable.row(outa1=True, outa2=False) + + SeqTable.row(outa1=False, outa2=True) ) seq_table_info = SeqTableInfo(sequence_table=seq_table, repeats=1) @@ -81,7 +78,7 @@ async def set_active(value: bool): [ ( { - "sequence_table": SeqTable([seq_table_row(outc2=1)]), + "sequence_table_factory": lambda: SeqTable.row(outc2=1), "repeats": 0, "prescale_as_us": -1, }, @@ -90,13 +87,11 @@ async def set_active(value: bool): ), ( { - "sequence_table": SeqTable( - [ - seq_table_row(outc2=True), - seq_table_row(outc2=False), - seq_table_row(outc2=True), - seq_table_row(outc2=False), - ] + "sequence_table_factory": lambda: ( + SeqTable.row(outc2=True) + + SeqTable.row(outc2=False) + + SeqTable.row(outc2=True) + + SeqTable.row(outc2=False) ), "repeats": -1, }, @@ -105,15 +100,16 @@ async def set_active(value: bool): ), ( { - "sequence_table": 1, + "sequence_table_factory": lambda: 1, "repeats": 1, }, - "Assertion failed, Rows must be a list or numpy array. " - "[type=assertion_error, input_value=1, input_type=int]", + "Input should be a valid dictionary or instance of SeqTable " + "[type=model_type, input_value=1, input_type=int]", ), ], ) def test_malformed_seq_table_info(kwargs, error_msg): + kwargs["sequence_table"] = kwargs.pop("sequence_table_factory")() with pytest.raises(ValidationError) as exc: SeqTableInfo(**kwargs) assert error_msg in str(exc.value) diff --git a/tests/test_data/test_yaml_save.yml b/tests/test_data/test_yaml_save.yml index 349536ecdc..b6872436f9 100644 --- a/tests/test_data/test_yaml_save.yml +++ b/tests/test_data/test_yaml_save.yml @@ -19,5 +19,6 @@ pv_enum_str: two pv_float: 1.234 pv_int: 1 + pv_protocol_device_abstraction: + some_field: 1 pv_str: test_string - pv_protocol_device_abstraction: 0