From 63e50ca6b6782186e3e8ff7d842d8bbac6c9bd08 Mon Sep 17 00:00:00 2001 From: Eva Lott Date: Tue, 13 Aug 2024 15:14:06 +0100 Subject: [PATCH 01/11] WIP: converted to row-wise numpy `SeqTable` decided to include this in the PR for #310 --- src/ophyd_async/fastcs/panda/__init__.py | 18 +- src/ophyd_async/fastcs/panda/_block.py | 7 +- src/ophyd_async/fastcs/panda/_table.py | 226 +++++++++++------------ src/ophyd_async/fastcs/panda/_trigger.py | 10 +- src/ophyd_async/plan_stubs/_fly.py | 13 +- tests/fastcs/panda/test_panda_connect.py | 59 +++--- tests/fastcs/panda/test_table.py | 3 + tests/fastcs/panda/test_trigger.py | 31 +++- 8 files changed, 193 insertions(+), 174 deletions(-) diff --git a/src/ophyd_async/fastcs/panda/__init__.py b/src/ophyd_async/fastcs/panda/__init__.py index 9d1c1d429f..a46baed3a0 100644 --- a/src/ophyd_async/fastcs/panda/__init__.py +++ b/src/ophyd_async/fastcs/panda/__init__.py @@ -14,11 +14,12 @@ from ._table import ( DatasetTable, PandaHdf5DatasetType, - SeqTable, - SeqTableRow, + SeqTablePvaTable, + SeqTableRowType, SeqTrigger, - seq_table_from_arrays, - seq_table_from_rows, + convert_seq_table_to_columnwise_pva_table, + create_seq_table, + seq_table_row, ) from ._trigger import ( PcompInfo, @@ -44,11 +45,12 @@ "PandaPcapController", "DatasetTable", "PandaHdf5DatasetType", - "SeqTable", - "SeqTableRow", + "create_seq_table", + "convert_seq_table_to_columnwise_pva_table", + "SeqTablePvaTable", + "SeqTableRowType", "SeqTrigger", - "seq_table_from_arrays", - "seq_table_from_rows", + "seq_table_row", "PcompInfo", "SeqTableInfo", "StaticPcompTriggerLogic", diff --git a/src/ophyd_async/fastcs/panda/_block.py b/src/ophyd_async/fastcs/panda/_block.py index 9deff70015..37a3bc35f3 100644 --- a/src/ophyd_async/fastcs/panda/_block.py +++ b/src/ophyd_async/fastcs/panda/_block.py @@ -1,10 +1,13 @@ from __future__ import annotations from enum import Enum +from typing import Dict + +from pydantic_numpy import NpNDArray from ophyd_async.core import Device, DeviceVector, SignalR, SignalRW, SubsetEnum -from ._table import DatasetTable, SeqTable +from ._table import DatasetTable class DataBlock(Device): @@ -52,7 +55,7 @@ class TimeUnits(str, Enum): class SeqBlock(Device): - table: SignalRW[SeqTable] + table: SignalRW[Dict[str, NpNDArray]] active: SignalRW[bool] repeats: SignalRW[int] prescale: SignalRW[float] diff --git a/src/ophyd_async/fastcs/panda/_table.py b/src/ophyd_async/fastcs/panda/_table.py index ec2c1a5b8b..9a62b42f42 100644 --- a/src/ophyd_async/fastcs/panda/_table.py +++ b/src/ophyd_async/fastcs/panda/_table.py @@ -1,11 +1,10 @@ -from dataclasses import dataclass from enum import Enum -from typing import Optional, Sequence, Type, TypeVar +from typing import NotRequired, Sequence import numpy as np import numpy.typing as npt -import pydantic_numpy.typing as pnd -from typing_extensions import NotRequired, TypedDict +import pydantic_numpy as pnd +from typing_extensions import TypedDict class PandaHdf5DatasetType(str, Enum): @@ -34,28 +33,99 @@ class SeqTrigger(str, Enum): POSC_LT = "POSC<=POSITION" -@dataclass -class SeqTableRow: - repeats: int = 1 - trigger: SeqTrigger = SeqTrigger.IMMEDIATE - position: int = 0 - time1: int = 0 - outa1: bool = False - outb1: bool = False - outc1: bool = False - outd1: bool = False - oute1: bool = False - outf1: bool = False - time2: int = 0 - outa2: bool = False - outb2: bool = False - outc2: bool = False - outd2: bool = False - oute2: bool = False - outf2: bool = False - - -class SeqTable(TypedDict): +SeqTableRowType = np.dtype( + [ + 
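# One structured record per sequencer row: repeats/trigger/position,
+        # then the phase 1 time and outputs (time1, outa1-outf1) and the
+        # phase 2 time and outputs (time2, outa2-outf2).
+        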
("repeats", np.int32), + ("trigger", "U14"), # One of the SeqTrigger values + ("position", np.int32), + ("time1", np.int32), + ("outa1", np.bool_), + ("outb1", np.bool_), + ("outc1", np.bool_), + ("outd1", np.bool_), + ("oute1", np.bool_), + ("outf1", np.bool_), + ("time2", np.int32), + ("outa2", np.bool_), + ("outb2", np.bool_), + ("outc2", np.bool_), + ("outd2", np.bool_), + ("oute2", np.bool_), + ("outf2", np.bool_), + ] +) + + +def seq_table_row( + *, + repeats: int = 0, + trigger: str = "", + position: int = 0, + time1: int = 0, + outa1: bool = False, + outb1: bool = False, + outc1: bool = False, + outd1: bool = False, + oute1: bool = False, + outf1: bool = False, + time2: int = 0, + outa2: bool = False, + outb2: bool = False, + outc2: bool = False, + outd2: bool = False, + oute2: bool = False, + outf2: bool = False, +) -> pnd.NpNDArray: + return np.array( + ( + repeats, + trigger, + position, + time1, + outa1, + outb1, + outc1, + outd1, + oute1, + outf1, + time2, + outa2, + outb2, + outc2, + outd2, + oute2, + outf2, + ), + dtype=SeqTableRowType, + ) + + +_SEQ_TABLE_ROW_SHAPE = seq_table_row().shape +_SEQ_TABLE_COLUMN_NAMES = [x[0] for x in SeqTableRowType.names] + + +def create_seq_table(*rows: pnd.NpNDArray) -> pnd.NpNDArray: + if not (0 < len(rows) < 4096): + raise ValueError(f"Length {len(rows)} not in range.") + + if not all(isinstance(row, np.ndarray) for row in rows): + for row in rows: + if not isinstance(row, np.void): + raise ValueError( + f"Cannot construct a SeqTable, some rows {row} are not arrays {type(row)}." + ) + raise ValueError("Cannot construct a SeqTable, some rows are not arrays.") + if not all(row.shape == _SEQ_TABLE_ROW_SHAPE for row in rows): + raise ValueError( + "Cannot construct a SeqTable, some rows have incorrect shapes." + ) + if not all(row.dtype is SeqTableRowType for row in rows): + raise ValueError("Cannot construct a SeqTable, some rows have incorrect types.") + + return np.array(rows) + + +class SeqTablePvaTable(TypedDict): repeats: NotRequired[pnd.Np1DArrayUint16] trigger: NotRequired[Sequence[SeqTrigger]] position: NotRequired[pnd.Np1DArrayInt32] @@ -75,96 +145,14 @@ class SeqTable(TypedDict): outf2: NotRequired[pnd.Np1DArrayBool] -def seq_table_from_rows(*rows: SeqTableRow): - """ - Constructs a sequence table from a series of rows. 
- """ - return seq_table_from_arrays( - repeats=np.array([row.repeats for row in rows], dtype=np.uint16), - trigger=[row.trigger for row in rows], - position=np.array([row.position for row in rows], dtype=np.int32), - time1=np.array([row.time1 for row in rows], dtype=np.uint32), - outa1=np.array([row.outa1 for row in rows], dtype=np.bool_), - outb1=np.array([row.outb1 for row in rows], dtype=np.bool_), - outc1=np.array([row.outc1 for row in rows], dtype=np.bool_), - outd1=np.array([row.outd1 for row in rows], dtype=np.bool_), - oute1=np.array([row.oute1 for row in rows], dtype=np.bool_), - outf1=np.array([row.outf1 for row in rows], dtype=np.bool_), - time2=np.array([row.time2 for row in rows], dtype=np.uint32), - outa2=np.array([row.outa2 for row in rows], dtype=np.bool_), - outb2=np.array([row.outb2 for row in rows], dtype=np.bool_), - outc2=np.array([row.outc2 for row in rows], dtype=np.bool_), - outd2=np.array([row.outd2 for row in rows], dtype=np.bool_), - oute2=np.array([row.oute2 for row in rows], dtype=np.bool_), - outf2=np.array([row.outf2 for row in rows], dtype=np.bool_), - ) - - -T = TypeVar("T", bound=np.generic) - - -def seq_table_from_arrays( - *, - repeats: Optional[npt.NDArray[np.uint16]] = None, - trigger: Optional[Sequence[SeqTrigger]] = None, - position: Optional[npt.NDArray[np.int32]] = None, - time1: Optional[npt.NDArray[np.uint32]] = None, - outa1: Optional[npt.NDArray[np.bool_]] = None, - outb1: Optional[npt.NDArray[np.bool_]] = None, - outc1: Optional[npt.NDArray[np.bool_]] = None, - outd1: Optional[npt.NDArray[np.bool_]] = None, - oute1: Optional[npt.NDArray[np.bool_]] = None, - outf1: Optional[npt.NDArray[np.bool_]] = None, - time2: npt.NDArray[np.uint32], - outa2: Optional[npt.NDArray[np.bool_]] = None, - outb2: Optional[npt.NDArray[np.bool_]] = None, - outc2: Optional[npt.NDArray[np.bool_]] = None, - outd2: Optional[npt.NDArray[np.bool_]] = None, - oute2: Optional[npt.NDArray[np.bool_]] = None, - outf2: Optional[npt.NDArray[np.bool_]] = None, -) -> SeqTable: - """ - Constructs a sequence table from a series of columns as arrays. - time2 is the only required argument and must not be None. - All other provided arguments must be of equal length to time2. 
- If any other argument is not given, or else given as None or empty, - an array of length len(time2) filled with the following is defaulted: - repeats: 1 - trigger: SeqTrigger.IMMEDIATE - all others: 0/False as appropriate - """ - assert time2 is not None, "time2 must be provided" - length = len(time2) - assert 0 < length < 4096, f"Length {length} not in range" - - def or_default( - value: Optional[npt.NDArray[T]], dtype: Type[T], default_value: int = 0 - ) -> npt.NDArray[T]: - if value is None or len(value) == 0: - return np.full(length, default_value, dtype=dtype) - return value - - table = SeqTable( - repeats=or_default(repeats, np.uint16, 1), - trigger=trigger or [SeqTrigger.IMMEDIATE] * length, - position=or_default(position, np.int32), - time1=or_default(time1, np.uint32), - outa1=or_default(outa1, np.bool_), - outb1=or_default(outb1, np.bool_), - outc1=or_default(outc1, np.bool_), - outd1=or_default(outd1, np.bool_), - oute1=or_default(oute1, np.bool_), - outf1=or_default(outf1, np.bool_), - time2=time2, - outa2=or_default(outa2, np.bool_), - outb2=or_default(outb2, np.bool_), - outc2=or_default(outc2, np.bool_), - outd2=or_default(outd2, np.bool_), - oute2=or_default(oute2, np.bool_), - outf2=or_default(outf2, np.bool_), - ) - for k, v in table.items(): - size = len(v) # type: ignore - if size != length: - raise ValueError(f"{k}: has length {size} not {length}") - return table +def convert_seq_table_to_columnwise_pva_table( + seq_table: pnd.NpNDArray, +) -> SeqTablePvaTable: + if seq_table.dtype != SeqTableRowType: + raise ValueError( + f"Cannot convert a SeqTable to a columnwise dictionary, " + f"input is not a SeqTable {seq_table.dtype}." + ) + print(seq_table) + transposed = seq_table.transpose(axis=1) + return dict(zip(_SEQ_TABLE_COLUMN_NAMES, transposed)) diff --git a/src/ophyd_async/fastcs/panda/_trigger.py b/src/ophyd_async/fastcs/panda/_trigger.py index c79988a381..977e5781f3 100644 --- a/src/ophyd_async/fastcs/panda/_trigger.py +++ b/src/ophyd_async/fastcs/panda/_trigger.py @@ -2,15 +2,16 @@ from typing import Optional from pydantic import BaseModel, Field +from pydantic_numpy import NpNDArray from ophyd_async.core import TriggerLogic, wait_for_value from ._block import PcompBlock, PcompDirectionOptions, SeqBlock, TimeUnits -from ._table import SeqTable +from ._table import convert_seq_table_to_columnwise_pva_table class SeqTableInfo(BaseModel): - sequence_table: SeqTable = Field(strict=True) + sequence_table: NpNDArray = Field(strict=True) repeats: int = Field(ge=0) prescale_as_us: float = Field(default=1, ge=0) # microseconds @@ -24,10 +25,13 @@ async def prepare(self, value: SeqTableInfo): self.seq.prescale_units.set(TimeUnits.us), self.seq.enable.set("ZERO"), ) + seq_table_pva_table = convert_seq_table_to_columnwise_pva_table( + value.sequence_table + ) await asyncio.gather( self.seq.prescale.set(value.prescale_as_us), self.seq.repeats.set(value.repeats), - self.seq.table.set(value.sequence_table), + self.seq.table.set(seq_table_pva_table), ) async def kickoff(self) -> None: diff --git a/src/ophyd_async/plan_stubs/_fly.py b/src/ophyd_async/plan_stubs/_fly.py index 087ec62dd1..95394a6863 100644 --- a/src/ophyd_async/plan_stubs/_fly.py +++ b/src/ophyd_async/plan_stubs/_fly.py @@ -13,10 +13,9 @@ from ophyd_async.fastcs.panda import ( PcompDirectionOptions, PcompInfo, - SeqTable, SeqTableInfo, - SeqTableRow, - seq_table_from_rows, + create_seq_table, + seq_table_row, ) @@ -74,15 +73,15 @@ def prepare_static_seq_table_flyer_and_detectors_with_same_trigger( trigger_time = 
number_of_frames * (exposure + deadtime) pre_delay = max(period - 2 * shutter_time - trigger_time, 0) - table: SeqTable = seq_table_from_rows( + table = create_seq_table( # Wait for pre-delay then open shutter - SeqTableRow( + seq_table_row( time1=in_micros(pre_delay), time2=in_micros(shutter_time), outa2=True, ), # Keeping shutter open, do N triggers - SeqTableRow( + seq_table_row( repeats=number_of_frames, time1=in_micros(exposure), outa1=True, @@ -91,7 +90,7 @@ def prepare_static_seq_table_flyer_and_detectors_with_same_trigger( outa2=True, ), # Add the shutter close - SeqTableRow(time2=in_micros(shutter_time)), + seq_table_row(time2=in_micros(shutter_time)), ) table_info = SeqTableInfo(sequence_table=table, repeats=repeats) diff --git a/tests/fastcs/panda/test_panda_connect.py b/tests/fastcs/panda/test_panda_connect.py index 2685f3c66c..4204e649f6 100644 --- a/tests/fastcs/panda/test_panda_connect.py +++ b/tests/fastcs/panda/test_panda_connect.py @@ -19,8 +19,11 @@ PcapBlock, PulseBlock, SeqBlock, - SeqTable, + SeqTablePvaTable, + convert_seq_table_to_columnwise_pva_table, SeqTrigger, + create_seq_table, + seq_table_row, ) @@ -91,32 +94,36 @@ def test_panda_name_set(panda_t): async def test_panda_children_connected(mock_panda): # try to set and retrieve from simulated values... - table = SeqTable( - repeats=np.array([1, 1, 1, 32]).astype(np.uint16), - trigger=( - SeqTrigger.POSA_GT, - SeqTrigger.POSA_LT, - SeqTrigger.IMMEDIATE, - SeqTrigger.IMMEDIATE, - ), - position=np.array([3222, -565, 0, 0], dtype=np.int32), - time1=np.array([5, 0, 10, 10]).astype(np.uint32), # TODO: change below syntax. - outa1=np.array([1, 0, 0, 1]).astype(np.bool_), - outb1=np.array([0, 0, 1, 1]).astype(np.bool_), - outc1=np.array([0, 1, 1, 0]).astype(np.bool_), - outd1=np.array([1, 1, 0, 1]).astype(np.bool_), - oute1=np.array([1, 0, 1, 0]).astype(np.bool_), - outf1=np.array([1, 0, 0, 0]).astype(np.bool_), - time2=np.array([0, 10, 10, 11]).astype(np.uint32), - outa2=np.array([1, 0, 0, 1]).astype(np.bool_), - outb2=np.array([0, 0, 1, 1]).astype(np.bool_), - outc2=np.array([0, 1, 1, 0]).astype(np.bool_), - outd2=np.array([1, 1, 0, 1]).astype(np.bool_), - oute2=np.array([1, 0, 1, 0]).astype(np.bool_), - outf2=np.array([1, 0, 0, 0]).astype(np.bool_), + table = create_seq_table( + seq_table_row( + repeats=np.array([1, 1, 1, 32]).astype(np.uint16), + trigger=( + SeqTrigger.POSA_GT, + SeqTrigger.POSA_LT, + SeqTrigger.IMMEDIATE, + SeqTrigger.IMMEDIATE, + ), + position=np.array([3222, -565, 0, 0], dtype=np.int32), + time1=np.array([5, 0, 10, 10]).astype( + np.uint32 + ), # TODO: change below syntax. 
+ outa1=np.array([1, 0, 0, 1]).astype(np.bool_), + outb1=np.array([0, 0, 1, 1]).astype(np.bool_), + outc1=np.array([0, 1, 1, 0]).astype(np.bool_), + outd1=np.array([1, 1, 0, 1]).astype(np.bool_), + oute1=np.array([1, 0, 1, 0]).astype(np.bool_), + outf1=np.array([1, 0, 0, 0]).astype(np.bool_), + time2=np.array([0, 10, 10, 11]).astype(np.uint32), + outa2=np.array([1, 0, 0, 1]).astype(np.bool_), + outb2=np.array([0, 0, 1, 1]).astype(np.bool_), + outc2=np.array([0, 1, 1, 0]).astype(np.bool_), + outd2=np.array([1, 1, 0, 1]).astype(np.bool_), + oute2=np.array([1, 0, 1, 0]).astype(np.bool_), + outf2=np.array([1, 0, 0, 0]).astype(np.bool_), + ) ) await mock_panda.pulse[1].delay.set(20.0) - await mock_panda.seq[1].table.set(table) + await mock_panda.seq[1].table.set(convert_seq_table_to_columnwise_pva_table(table)) readback_pulse = await mock_panda.pulse[1].delay.get_value() readback_seq = await mock_panda.seq[1].table.get_value() @@ -164,7 +171,7 @@ async def test_panda_gets_types_from_common_class(panda_pva, panda_t): assert panda.pcap.active._backend.datatype is bool # works with custom datatypes - assert panda.seq[1].table._backend.datatype is SeqTable + assert panda.seq[1].table._backend.datatype is SeqTablePvaTable # others are given the None datatype assert panda.pcap.newsignal._backend.datatype is None diff --git a/tests/fastcs/panda/test_table.py b/tests/fastcs/panda/test_table.py index ad92683bbd..9b795f1292 100644 --- a/tests/fastcs/panda/test_table.py +++ b/tests/fastcs/panda/test_table.py @@ -1,3 +1,4 @@ +""" import numpy as np import pytest @@ -29,3 +30,5 @@ def test_from_arrays_too_long(): time2 = np.zeros(4097) with pytest.raises(AssertionError, match="Length 4097 not in range"): seq_table_from_arrays(time2=time2) + +""" diff --git a/tests/fastcs/panda/test_trigger.py b/tests/fastcs/panda/test_trigger.py index 1a76614afa..005daccaf9 100644 --- a/tests/fastcs/panda/test_trigger.py +++ b/tests/fastcs/panda/test_trigger.py @@ -1,6 +1,5 @@ import asyncio -import numpy as np import pytest from pydantic import ValidationError @@ -9,10 +8,11 @@ from ophyd_async.fastcs.panda import ( CommonPandaBlocks, PcompInfo, - SeqTable, SeqTableInfo, StaticPcompTriggerLogic, StaticSeqTableTriggerLogic, + create_seq_table, + seq_table_row, ) @@ -38,8 +38,11 @@ async def connect(self, mock: bool = False, timeout: float = DEFAULT_TIMEOUT): async def test_seq_table_trigger_logic(mock_panda): trigger_logic = StaticSeqTableTriggerLogic(mock_panda.seq[1]) - seq_table = SeqTable( - outa1=np.array([1, 2, 3, 4, 5]), outa2=np.array([1, 2, 3, 4, 5]) + seq_table = create_seq_table( + seq_table_row(outa1=True, outa2=True), + seq_table_row(outa1=False, outa2=False), + seq_table_row(outa1=True, outa2=False), + seq_table_row(outa1=False, outa2=True), ) seq_table_info = SeqTableInfo(sequence_table=seq_table, repeats=1) @@ -75,13 +78,22 @@ async def set_active(value: bool): ["kwargs", "error_msg"], [ ( - {"sequence_table": {}, "repeats": 0, "prescale_as_us": -1}, + { + "sequence_table": create_seq_table(seq_table_row(outc2=1)), + "repeats": 0, + "prescale_as_us": -1, + }, "Input should be greater than or equal to 0 " "[type=greater_than_equal, input_value=-1, input_type=int]", ), ( { - "sequence_table": SeqTable(outc2=np.array([1, 0, 1, 0], dtype=bool)), + "sequence_table": create_seq_table( + seq_table_row(outc2=True), + seq_table_row(outc2=False), + seq_table_row(outc2=True), + seq_table_row(outc2=False), + ), "repeats": -1, }, "Input should be greater than or equal to 0 " @@ -89,11 +101,12 @@ async def 
set_active(value: bool):
         ),
         (
             {
-                "sequence_table": SeqTable(outc2=1),
+                "sequence_table": 1,
                 "repeats": 1,
             },
-            "Input should be an instance of ndarray "
-            "[type=is_instance_of, input_value=1, input_type=int]",
+            "Value error, Cannot construct a SeqTable, "
+            "input is not an unpacked tuple of `SeqTableRowType`. "
+            "[type=value_error, input_value=1, input_type=int]",
         ),
     ],
 )

From 28b5e2a613d08230f13dc31a9d24e84176f0963a Mon Sep 17 00:00:00 2001
From: Eva Lott
Date: Wed, 14 Aug 2024 15:03:49 +0100
Subject: [PATCH 02/11] created the `PvaTableAbstraction` and used it in the
 `SeqTable`

Need to add json serialization.
---
 src/ophyd_async/core/__init__.py             |  10 +-
 src/ophyd_async/core/_signal_backend.py      |  49 ++++-
 src/ophyd_async/core/_soft_signal_backend.py |  73 +++++--
 src/ophyd_async/epics/signal/__init__.py     |   3 +-
 src/ophyd_async/epics/signal/_aioca.py       | 143 ++++++++------
 src/ophyd_async/epics/signal/_p4p.py         | 198 ++++++++++++-------
 src/ophyd_async/fastcs/panda/__init__.py     |   8 +-
 src/ophyd_async/fastcs/panda/_block.py       |   7 +-
 src/ophyd_async/fastcs/panda/_table.py       | 120 +++++------
 src/ophyd_async/fastcs/panda/_trigger.py     |  10 +-
 src/ophyd_async/plan_stubs/_fly.py           |   6 +-
 tests/core/test_subset_enum.py               |   8 +-
 tests/fastcs/panda/test_panda_connect.py     | 113 ++++++++---
 tests/fastcs/panda/test_panda_utils.py       |  32 +--
 tests/fastcs/panda/test_table.py             | 141 ++++++++++---
 tests/fastcs/panda/test_trigger.py           |  17 +-
 16 files changed, 641 insertions(+), 297 deletions(-)

diff --git a/src/ophyd_async/core/__init__.py b/src/ophyd_async/core/__init__.py
index 3f88752fdd..1b082f2835 100644
--- a/src/ophyd_async/core/__init__.py
+++ b/src/ophyd_async/core/__init__.py
@@ -61,7 +61,13 @@
     soft_signal_rw,
     wait_for_value,
 )
-from ._signal_backend import RuntimeSubsetEnum, SignalBackend, SubsetEnum
+from ._signal_backend import (
+    BackendConverterFactory,
+    ProtocolDatatypeAbstraction,
+    RuntimeSubsetEnum,
+    SignalBackend,
+    SubsetEnum,
+)
 from ._soft_signal_backend import SignalMetadata, SoftSignalBackend
 from ._status import AsyncStatus, WatchableAsyncStatus
 from ._utils import (
@@ -103,6 +109,7 @@
     "MockSignalBackend",
     "callback_on_mock_put",
     "get_mock_put",
+    "BackendConverterFactory",
     "mock_puts_blocked",
     "reset_mock_put_calls",
     "set_mock_put_proceeds",
@@ -117,6 +124,7 @@
     "NameProvider",
     "PathInfo",
     "PathProvider",
+    "ProtocolDatatypeAbstraction",
     "ShapeProvider",
     "StaticFilenameProvider",
     "StaticPathProvider",
diff --git a/src/ophyd_async/core/_signal_backend.py b/src/ophyd_async/core/_signal_backend.py
index 41e9fbcbd3..c66cc0b50a 100644
--- a/src/ophyd_async/core/_signal_backend.py
+++ b/src/ophyd_async/core/_signal_backend.py
@@ -1,10 +1,55 @@
-from abc import abstractmethod
-from typing import TYPE_CHECKING, ClassVar, Generic, Literal, Optional, Tuple, Type
+from abc import ABC, abstractmethod
+from typing import (
+    TYPE_CHECKING,
+    ClassVar,
+    Generic,
+    Literal,
+    Optional,
+    Tuple,
+    Type,
+)
 
 from ._protocol import DataKey, Reading
 from ._utils import DEFAULT_TIMEOUT, ReadingValueCallback, T
 
 
+class ProtocolDatatypeAbstraction(ABC, Generic[T]):
+    @abstractmethod
+    def __init__(self):
+        """The abstract datatype must be able to be initialized with no arguments."""
+
+    @abstractmethod
+    def convert_to_protocol_datatype(self) -> T:
+        """
+        Convert the abstract datatype to a form which can be sent
+        over whichever protocol.
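+        For example, a table is sent over PVA as a dict mapping each
+        column name to a numpy array of that column's values.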
+ """ + + @classmethod + @abstractmethod + def convert_from_protocol_datatype(cls, value: T) -> "ProtocolDatatypeAbstraction": + """ + Convert the datatype received from the protocol to a + higher level abstract datatype. + """ + + +class BackendConverterFactory(ABC): + """Convert between the signal backend and the signal type""" + + _ALLOWED_TYPES: ClassVar[Tuple[Type]] + + @classmethod + @abstractmethod + def datatype_allowed(cls, datatype: Type) -> bool: + """Check if the datatype is allowed.""" + + @classmethod + @abstractmethod + def make_converter(self, datatype: Type): + """Updates the object with callables `to_signal` and `from_signal`.""" + + class SignalBackend(Generic[T]): """A read/write/monitor backend for a Signals""" diff --git a/src/ophyd_async/core/_soft_signal_backend.py b/src/ophyd_async/core/_soft_signal_backend.py index 62bafd5bb1..3e6b821f09 100644 --- a/src/ophyd_async/core/_soft_signal_backend.py +++ b/src/ophyd_async/core/_soft_signal_backend.py @@ -4,13 +4,18 @@ import time from collections import abc from enum import Enum -from typing import Dict, Generic, Optional, Tuple, Type, Union, cast, get_origin +from typing import Any, Dict, Generic, Optional, Tuple, Type, Union, cast, get_origin import numpy as np from bluesky.protocols import DataKey, Dtype, Reading from typing_extensions import TypedDict -from ._signal_backend import RuntimeSubsetEnum, SignalBackend +from ._signal_backend import ( + BackendConverterFactory, + ProtocolDatatypeAbstraction, + RuntimeSubsetEnum, + SignalBackend, +) from ._utils import DEFAULT_TIMEOUT, ReadingValueCallback, T, get_dtype primitive_dtypes: Dict[type, Dtype] = { @@ -94,7 +99,7 @@ def make_initial_value(self, datatype: Optional[Type[T]]) -> T: class SoftEnumConverter(SoftConverter): choices: Tuple[str, ...] - def __init__(self, datatype: Union[RuntimeSubsetEnum, Enum]): + def __init__(self, datatype: Union[RuntimeSubsetEnum, Type[Enum]]): if issubclass(datatype, Enum): self.choices = tuple(v.value for v in datatype) else: @@ -122,19 +127,55 @@ def make_initial_value(self, datatype: Optional[Type[T]]) -> T: return cast(T, self.choices[0]) -def make_converter(datatype): - is_array = get_dtype(datatype) is not None - is_sequence = get_origin(datatype) == abc.Sequence - is_enum = inspect.isclass(datatype) and ( - issubclass(datatype, Enum) or issubclass(datatype, RuntimeSubsetEnum) - ) +class SoftProtocolDatatypeAbstractionConverter(SoftConverter): + """ + No conversion is necessary for ProtocolDatatypeAbstraction datatypes in soft + signals. 
+ """ + def __init__(self, datatype: Type[ProtocolDatatypeAbstraction]): + self.datatype = datatype + + def reading(self, value: T, timestamp: float, severity: int) -> Reading: + return super().reading(value, timestamp, severity) + + def value(self, value: Any) -> Any: + return value + + def write_value(self, value): + return value + + def make_initial_value(self, datatype: Type | None) -> Any: + return super().make_initial_value(datatype) - if is_array or is_sequence: - return SoftArrayConverter() - if is_enum: - return SoftEnumConverter(datatype) - return SoftConverter() +class SoftSignalConverterFactory(BackendConverterFactory): + _ALLOWED_TYPES = (object,) # Any type is allowed + + @classmethod + def datatype_allowed(cls, datatype: Type) -> bool: + return True # Any value allowed in a soft signal + + + @classmethod + def make_converter(cls, datatype): + is_array = get_dtype(datatype) is not None + is_sequence = get_origin(datatype) == abc.Sequence + is_enum = inspect.isclass(datatype) and ( + issubclass(datatype, Enum) or issubclass(datatype, RuntimeSubsetEnum) + ) + is_convertable_abstract_datatype = inspect.isclass(datatype) and issubclass( + datatype, + ProtocolDatatypeAbstraction + ) + + if is_array or is_sequence: + return SoftArrayConverter() + if is_enum: + return SoftEnumConverter(datatype) + if is_convertable_abstract_datatype: + return SoftProtocolDatatypeAbstractionConverter(datatype) + + return SoftConverter() class SoftSignalBackend(SignalBackend[T]): @@ -154,7 +195,9 @@ def __init__( self.datatype = datatype self._initial_value = initial_value self._metadata = metadata or {} - self.converter: SoftConverter = make_converter(datatype) + self.converter: SoftConverter = SoftSignalConverterFactory.make_converter( + datatype + ) if self._initial_value is None: self._initial_value = self.converter.make_initial_value(self.datatype) else: diff --git a/src/ophyd_async/epics/signal/__init__.py b/src/ophyd_async/epics/signal/__init__.py index 8d7628bf01..5da098b59e 100644 --- a/src/ophyd_async/epics/signal/__init__.py +++ b/src/ophyd_async/epics/signal/__init__.py @@ -1,5 +1,5 @@ from ._common import LimitPair, Limits, get_supported_values -from ._p4p import PvaSignalBackend +from ._p4p import PvaSignalBackend, PvaTableAbstraction from ._signal import ( epics_signal_r, epics_signal_rw, @@ -13,6 +13,7 @@ "LimitPair", "Limits", "PvaSignalBackend", + "PvaTableAbstraction", "epics_signal_r", "epics_signal_rw", "epics_signal_rw_rbv", diff --git a/src/ophyd_async/epics/signal/_aioca.py b/src/ophyd_async/epics/signal/_aioca.py index 78052d448d..ce57acba80 100644 --- a/src/ophyd_async/epics/signal/_aioca.py +++ b/src/ophyd_async/epics/signal/_aioca.py @@ -1,9 +1,10 @@ +import inspect import logging import sys from dataclasses import dataclass from enum import Enum from math import isnan, nan -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Sequence, Type, Union, get_origin import numpy as np from aioca import ( @@ -22,8 +23,10 @@ from ophyd_async.core import ( DEFAULT_TIMEOUT, + BackendConverterFactory, NotConnected, ReadingValueCallback, + RuntimeSubsetEnum, SignalBackend, T, get_dtype, @@ -183,63 +186,89 @@ def __getattribute__(self, __name: str) -> Any: raise NotImplementedError("No PV has been set as connect() has not been called") -def make_converter( - datatype: Optional[Type], values: Dict[str, AugmentedValue] -) -> CaConverter: - pv = list(values)[0] - pv_dbr = get_unique({k: v.datatype for k, v in values.items()}, "datatypes") 
- is_array = bool([v for v in values.values() if v.element_count > 1]) - if is_array and datatype is str and pv_dbr == dbr.DBR_CHAR: - # Override waveform of chars to be treated as string - return CaLongStrConverter() - elif is_array and pv_dbr == dbr.DBR_STRING: - # Waveform of strings, check we wanted this - if datatype: - datatype_dtype = get_dtype(datatype) - if not datatype_dtype or not np.can_cast(datatype_dtype, np.str_): - raise TypeError(f"{pv} has type [str] not {datatype.__name__}") - return CaArrayConverter(pv_dbr, None) - elif is_array: - pv_dtype = get_unique({k: v.dtype for k, v in values.items()}, "dtypes") - # This is an array - if datatype: - # Check we wanted an array of this type - dtype = get_dtype(datatype) - if not dtype: - raise TypeError(f"{pv} has type [{pv_dtype}] not {datatype.__name__}") - if dtype != pv_dtype: - raise TypeError(f"{pv} has type [{pv_dtype}] not [{dtype}]") - return CaArrayConverter(pv_dbr, None) - elif pv_dbr == dbr.DBR_ENUM and datatype is bool: - # Database can't do bools, so are often representated as enums, CA can do int - pv_choices_len = get_unique( - {k: len(v.enums) for k, v in values.items()}, "number of choices" - ) - if pv_choices_len != 2: - raise TypeError(f"{pv} has {pv_choices_len} choices, can't map to bool") - return CaBoolConverter(dbr.DBR_SHORT, dbr.DBR_SHORT) - elif pv_dbr == dbr.DBR_ENUM: - # This is an Enum - pv_choices = get_unique( - {k: tuple(v.enums) for k, v in values.items()}, "choices" +class CaConverterFactory(BackendConverterFactory): + _ALLOWED_TYPES = ( + bool, + int, + float, + str, + Sequence, + Enum, + RuntimeSubsetEnum, + np.ndarray + ) + + @classmethod + def datatype_allowed(cls, datatype: Type) -> bool: + stripped_origin = get_origin(datatype) or datatype + return inspect.isclass(stripped_origin) and issubclass( + stripped_origin, cls._ALLOWED_TYPES ) - supported_values = get_supported_values(pv, datatype, pv_choices) - return CaEnumConverter(dbr.DBR_STRING, None, supported_values) - else: - value = list(values.values())[0] - # Done the dbr check, so enough to check one of the values - if datatype and not isinstance(value, datatype): - # Allow int signals to represent float records when prec is 0 - is_prec_zero_float = ( - isinstance(value, float) - and get_unique({k: v.precision for k, v in values.items()}, "precision") - == 0 + + @classmethod + def make_converter( + cls, datatype: Optional[Type], values: Dict[str, AugmentedValue] + ) -> CaConverter: + if datatype is not None and not cls.datatype_allowed(datatype): + raise TypeError(f"Given datatype {datatype} unsupported in CA.") + + pv = list(values)[0] + pv_dbr = get_unique({k: v.datatype for k, v in values.items()}, "datatypes") + is_array = bool([v for v in values.values() if v.element_count > 1]) + if is_array and datatype is str and pv_dbr == dbr.DBR_CHAR: + # Override waveform of chars to be treated as string + return CaLongStrConverter() + elif is_array and pv_dbr == dbr.DBR_STRING: + # Waveform of strings, check we wanted this + if datatype: + datatype_dtype = get_dtype(datatype) + if not datatype_dtype or not np.can_cast(datatype_dtype, np.str_): + raise TypeError(f"{pv} has type [str] not {datatype.__name__}") + return CaArrayConverter(pv_dbr, None) + elif is_array: + pv_dtype = get_unique({k: v.dtype for k, v in values.items()}, "dtypes") + # This is an array + if datatype: + # Check we wanted an array of this type + dtype = get_dtype(datatype) + if not dtype: + raise TypeError( + f"{pv} has type [{pv_dtype}] not {datatype.__name__}" + ) + if 
dtype != pv_dtype: + raise TypeError(f"{pv} has type [{pv_dtype}] not [{dtype}]") + return CaArrayConverter(pv_dbr, None) + elif pv_dbr == dbr.DBR_ENUM and datatype is bool: + # Database can't do bools, so are often representated as enums, + # CA can do int + pv_choices_len = get_unique( + {k: len(v.enums) for k, v in values.items()}, "number of choices" ) - if not (datatype is int and is_prec_zero_float): - raise TypeError( - f"{pv} has type {type(value).__name__.replace('ca_', '')} " - + f"not {datatype.__name__}" + if pv_choices_len != 2: + raise TypeError(f"{pv} has {pv_choices_len} choices, can't map to bool") + return CaBoolConverter(dbr.DBR_SHORT, dbr.DBR_SHORT) + elif pv_dbr == dbr.DBR_ENUM: + # This is an Enum + pv_choices = get_unique( + {k: tuple(v.enums) for k, v in values.items()}, "choices" + ) + supported_values = get_supported_values(pv, datatype, pv_choices) + return CaEnumConverter(dbr.DBR_STRING, None, supported_values) + else: + value = list(values.values())[0] + # Done the dbr check, so enough to check one of the values + if datatype and not isinstance(value, datatype): + # Allow int signals to represent float records when prec is 0 + is_prec_zero_float = ( + isinstance(value, float) + and get_unique({k: v.precision for k, v in values.items()}, "precision") + == 0 ) + if not (datatype is int and is_prec_zero_float): + raise TypeError( + f"{pv} has type {type(value).__name__.replace('ca_', '')} " + + f"not {datatype.__name__}" + ) return CaConverter(pv_dbr, None) @@ -287,7 +316,9 @@ async def connect(self, timeout: float = DEFAULT_TIMEOUT): else: # The same, so only need to connect one await self._store_initial_value(self.read_pv, timeout=timeout) - self.converter = make_converter(self.datatype, self.initial_values) + self.converter = CaConverterFactory.make_converter( + self.datatype, self.initial_values + ) async def put(self, value: Optional[T], wait=True, timeout=None): if value is None: diff --git a/src/ophyd_async/epics/signal/_p4p.py b/src/ophyd_async/epics/signal/_p4p.py index 28ec8fe6ab..3f95740522 100644 --- a/src/ophyd_async/epics/signal/_p4p.py +++ b/src/ophyd_async/epics/signal/_p4p.py @@ -6,14 +6,17 @@ from dataclasses import dataclass from enum import Enum from math import isnan, nan -from typing import Any, Dict, List, Optional, Sequence, Type, Union +from typing import Any, Dict, List, Optional, Sequence, Type, Union, get_origin +import numpy as np from bluesky.protocols import DataKey, Dtype, Reading from p4p import Value from p4p.client.asyncio import Context, Subscription +from pydantic import BaseModel from ophyd_async.core import ( DEFAULT_TIMEOUT, + BackendConverterFactory, NotConnected, ReadingValueCallback, RuntimeSubsetEnum, @@ -64,7 +67,7 @@ def _data_key_from_value( *, shape: Optional[list[int]] = None, choices: Optional[list[str]] = None, - dtype: Optional[Dtype] = None, + dtype: Optional[str] = None, ) -> DataKey: """ Args: @@ -284,75 +287,131 @@ def __getattribute__(self, __name: str) -> Any: raise NotImplementedError("No PV has been set as connect() has not been called") -def make_converter(datatype: Optional[Type], values: Dict[str, Any]) -> PvaConverter: - pv = list(values)[0] - typeid = get_unique({k: v.getID() for k, v in values.items()}, "typeids") - typ = get_unique( - {k: type(v.get("value")) for k, v in values.items()}, "value types" +class PvaPydanticModelConverter(PvaConverter): + def __init__(self, datatype: BaseModel): + self.datatype = datatype + + def reading(self, value: Value): + ts = time.time() + value = 
self.value(value) + return {"value": value, "timestamp": ts, "alarm_severity": 0} + + def value(self, value: Value): + return self.datatype(**value.todict()) + + def write_value(self, value: Union[BaseModel, Dict[str, Any]]): + """ + A user can put whichever form to the signal. + This is required for yaml deserialization. + """ + if isinstance(value, self.datatype): + return value.model_dump(mode="python") + return value + + +class PvaConverterFactory(BackendConverterFactory): + _ALLOWED_TYPES = ( + bool, + int, + float, + str, + Sequence, + np.ndarray, + Enum, + RuntimeSubsetEnum, + BaseModel, + dict, ) - if "NTScalarArray" in typeid and typ is list: - # Waveform of strings, check we wanted this - if datatype and datatype != Sequence[str]: - raise TypeError(f"{pv} has type [str] not {datatype.__name__}") - return PvaArrayConverter() - elif "NTScalarArray" in typeid or "NTNDArray" in typeid: - pv_dtype = get_unique( - {k: v["value"].dtype for k, v in values.items()}, "dtypes" - ) - # This is an array - if datatype: - # Check we wanted an array of this type - dtype = get_dtype(datatype) - if not dtype: - raise TypeError(f"{pv} has type [{pv_dtype}] not {datatype.__name__}") - if dtype != pv_dtype: - raise TypeError(f"{pv} has type [{pv_dtype}] not [{dtype}]") - if "NTNDArray" in typeid: - return PvaNDArrayConverter() - else: - return PvaArrayConverter() - elif "NTEnum" in typeid and datatype is bool: - # Wanted a bool, but database represents as an enum - pv_choices_len = get_unique( - {k: len(v["value"]["choices"]) for k, v in values.items()}, - "number of choices", + + @classmethod + def datatype_allowed(cls, datatype: Optional[Type]) -> bool: + stripped_origin = get_origin(datatype) or datatype + if datatype is None: + return True + return inspect.isclass(stripped_origin) and issubclass( + stripped_origin, cls._ALLOWED_TYPES ) - if pv_choices_len != 2: - raise TypeError(f"{pv} has {pv_choices_len} choices, can't map to bool") - return PvaEmumBoolConverter() - elif "NTEnum" in typeid: - # This is an Enum - pv_choices = get_unique( - {k: tuple(v["value"]["choices"]) for k, v in values.items()}, "choices" + + @classmethod + def make_converter( + cls, datatype: Optional[Type], values: Dict[str, Any] + ) -> PvaConverter: + pv = list(values)[0] + typeid = get_unique({k: v.getID() for k, v in values.items()}, "typeids") + typ = get_unique( + {k: type(v.get("value")) for k, v in values.items()}, "value types" ) - return PvaEnumConverter(get_supported_values(pv, datatype, pv_choices)) - elif "NTScalar" in typeid: - if ( - typ is str - and inspect.isclass(datatype) - and issubclass(datatype, RuntimeSubsetEnum) - ): - return PvaEnumConverter( - get_supported_values(pv, datatype, datatype.choices) + if "NTScalarArray" in typeid and typ is list: + # Waveform of strings, check we wanted this + if datatype and datatype != Sequence[str]: + raise TypeError(f"{pv} has type [str] not {datatype.__name__}") + return PvaArrayConverter() + elif "NTScalarArray" in typeid or "NTNDArray" in typeid: + pv_dtype = get_unique( + {k: v["value"].dtype for k, v in values.items()}, "dtypes" ) - elif datatype and not issubclass(typ, datatype): - # Allow int signals to represent float records when prec is 0 - is_prec_zero_float = typ is float and ( - get_unique( - {k: v["display"]["precision"] for k, v in values.items()}, - "precision", - ) - == 0 + # This is an array + if datatype: + # Check we wanted an array of this type + dtype = get_dtype(datatype) + if not dtype: + raise TypeError( + f"{pv} has type [{pv_dtype}] not 
{datatype.__name__}" + ) + if dtype != pv_dtype: + raise TypeError(f"{pv} has type [{pv_dtype}] not [{dtype}]") + if "NTNDArray" in typeid: + return PvaNDArrayConverter() + else: + return PvaArrayConverter() + elif "NTEnum" in typeid and datatype is bool: + # Wanted a bool, but database represents as an enum + pv_choices_len = get_unique( + {k: len(v["value"]["choices"]) for k, v in values.items()}, + "number of choices", ) - if not (datatype is int and is_prec_zero_float): - raise TypeError(f"{pv} has type {typ.__name__} not {datatype.__name__}") - return PvaConverter() - elif "NTTable" in typeid: - return PvaTableConverter() - elif "structure" in typeid: - return PvaDictConverter() - else: - raise TypeError(f"{pv}: Unsupported typeid {typeid}") + if pv_choices_len != 2: + raise TypeError(f"{pv} has {pv_choices_len} choices, can't map to bool") + return PvaEmumBoolConverter() + elif "NTEnum" in typeid: + # This is an Enum + pv_choices = get_unique( + {k: tuple(v["value"]["choices"]) for k, v in values.items()}, "choices" + ) + return PvaEnumConverter(get_supported_values(pv, datatype, pv_choices)) + elif "NTScalar" in typeid: + if ( + typ is str + and inspect.isclass(datatype) + and issubclass(datatype, RuntimeSubsetEnum) + ): + return PvaEnumConverter( + get_supported_values(pv, datatype, datatype.choices) + ) + elif datatype and not issubclass(typ, datatype): + # Allow int signals to represent float records when prec is 0 + is_prec_zero_float = typ is float and ( + get_unique( + {k: v["display"]["precision"] for k, v in values.items()}, + "precision", + ) + == 0 + ) + if not (datatype is int and is_prec_zero_float): + raise TypeError(f"{pv} has type {typ.__name__} not {datatype.__name__}") + return PvaConverter() + elif "NTTable" in typeid: + return PvaTableConverter() + elif "structure" in typeid: + if ( + datatype + and inspect.isclass(datatype) + and issubclass(datatype, BaseModel) + ): + return PvaPydanticModelConverter(datatype) + return PvaDictConverter() + else: + raise TypeError(f"{pv}: Unsupported typeid {typeid}") class PvaSignalBackend(SignalBackend[T]): @@ -360,6 +419,9 @@ class PvaSignalBackend(SignalBackend[T]): def __init__(self, datatype: Optional[Type[T]], read_pv: str, write_pv: str): self.datatype = datatype + if not PvaConverterFactory.datatype_allowed(self.datatype): + raise TypeError(f"Given datatype {self.datatype} unsupported in PVA.") + self.read_pv = read_pv self.write_pv = write_pv self.initial_values: Dict[str, Any] = {} @@ -402,7 +464,9 @@ async def connect(self, timeout: float = DEFAULT_TIMEOUT): else: # The same, so only need to connect one await self._store_initial_value(self.read_pv, timeout=timeout) - self.converter = make_converter(self.datatype, self.initial_values) + self.converter = PvaConverterFactory.make_converter( + self.datatype, self.initial_values + ) async def put(self, value: Optional[T], wait=True, timeout=None): if value is None: diff --git a/src/ophyd_async/fastcs/panda/__init__.py b/src/ophyd_async/fastcs/panda/__init__.py index a46baed3a0..3724397053 100644 --- a/src/ophyd_async/fastcs/panda/__init__.py +++ b/src/ophyd_async/fastcs/panda/__init__.py @@ -14,11 +14,9 @@ from ._table import ( DatasetTable, PandaHdf5DatasetType, - SeqTablePvaTable, + SeqTable, SeqTableRowType, SeqTrigger, - convert_seq_table_to_columnwise_pva_table, - create_seq_table, seq_table_row, ) from ._trigger import ( @@ -45,9 +43,7 @@ "PandaPcapController", "DatasetTable", "PandaHdf5DatasetType", - "create_seq_table", - "convert_seq_table_to_columnwise_pva_table", 
- "SeqTablePvaTable", + "SeqTable", "SeqTableRowType", "SeqTrigger", "seq_table_row", diff --git a/src/ophyd_async/fastcs/panda/_block.py b/src/ophyd_async/fastcs/panda/_block.py index 37a3bc35f3..9deff70015 100644 --- a/src/ophyd_async/fastcs/panda/_block.py +++ b/src/ophyd_async/fastcs/panda/_block.py @@ -1,13 +1,10 @@ from __future__ import annotations from enum import Enum -from typing import Dict - -from pydantic_numpy import NpNDArray from ophyd_async.core import Device, DeviceVector, SignalR, SignalRW, SubsetEnum -from ._table import DatasetTable +from ._table import DatasetTable, SeqTable class DataBlock(Device): @@ -55,7 +52,7 @@ class TimeUnits(str, Enum): class SeqBlock(Device): - table: SignalRW[Dict[str, NpNDArray]] + table: SignalRW[SeqTable] active: SignalRW[bool] repeats: SignalRW[int] prescale: SignalRW[float] diff --git a/src/ophyd_async/fastcs/panda/_table.py b/src/ophyd_async/fastcs/panda/_table.py index 9a62b42f42..12d201df22 100644 --- a/src/ophyd_async/fastcs/panda/_table.py +++ b/src/ophyd_async/fastcs/panda/_table.py @@ -1,11 +1,14 @@ from enum import Enum -from typing import NotRequired, Sequence +from typing import Dict, Sequence, Union import numpy as np import numpy.typing as npt import pydantic_numpy as pnd +from pydantic import Field, RootModel, field_validator from typing_extensions import TypedDict +from ophyd_async.epics.signal import PvaTableAbstraction + class PandaHdf5DatasetType(str, Enum): FLOAT_64 = "float64" @@ -52,6 +55,7 @@ class SeqTrigger(str, Enum): ("outd2", np.bool_), ("oute2", np.bool_), ("outf2", np.bool_), + ] ) @@ -100,59 +104,61 @@ def seq_table_row( ) -_SEQ_TABLE_ROW_SHAPE = seq_table_row().shape -_SEQ_TABLE_COLUMN_NAMES = [x[0] for x in SeqTableRowType.names] - - -def create_seq_table(*rows: pnd.NpNDArray) -> pnd.NpNDArray: - if not (0 < len(rows) < 4096): - raise ValueError(f"Length {len(rows)} not in range.") - - if not all(isinstance(row, np.ndarray) for row in rows): - for row in rows: - if not isinstance(row, np.void): - raise ValueError( - f"Cannot construct a SeqTable, some rows {row} are not arrays {type(row)}." - ) - raise ValueError("Cannot construct a SeqTable, some rows are not arrays.") - if not all(row.shape == _SEQ_TABLE_ROW_SHAPE for row in rows): - raise ValueError( - "Cannot construct a SeqTable, some rows have incorrect shapes." - ) - if not all(row.dtype is SeqTableRowType for row in rows): - raise ValueError("Cannot construct a SeqTable, some rows have incorrect types.") - - return np.array(rows) - - -class SeqTablePvaTable(TypedDict): - repeats: NotRequired[pnd.Np1DArrayUint16] - trigger: NotRequired[Sequence[SeqTrigger]] - position: NotRequired[pnd.Np1DArrayInt32] - time1: NotRequired[pnd.Np1DArrayUint32] - outa1: NotRequired[pnd.Np1DArrayBool] - outb1: NotRequired[pnd.Np1DArrayBool] - outc1: NotRequired[pnd.Np1DArrayBool] - outd1: NotRequired[pnd.Np1DArrayBool] - oute1: NotRequired[pnd.Np1DArrayBool] - outf1: NotRequired[pnd.Np1DArrayBool] - time2: NotRequired[pnd.Np1DArrayUint32] - outa2: NotRequired[pnd.Np1DArrayBool] - outb2: NotRequired[pnd.Np1DArrayBool] - outc2: NotRequired[pnd.Np1DArrayBool] - outd2: NotRequired[pnd.Np1DArrayBool] - oute2: NotRequired[pnd.Np1DArrayBool] - outf2: NotRequired[pnd.Np1DArrayBool] - - -def convert_seq_table_to_columnwise_pva_table( - seq_table: pnd.NpNDArray, -) -> SeqTablePvaTable: - if seq_table.dtype != SeqTableRowType: - raise ValueError( - f"Cannot convert a SeqTable to a columnwise dictionary, " - f"input is not a SeqTable {seq_table.dtype}." 
- ) - print(seq_table) - transposed = seq_table.transpose(axis=1) - return dict(zip(_SEQ_TABLE_COLUMN_NAMES, transposed)) + + +class SeqTable(RootModel, PvaTableAbstraction): + root: pnd.NpNDArray = Field( + default_factory=lambda: np.array([], dtype=SeqTableRowType), + ) + + def convert_to_protocol_datatype(self) -> Dict[str, npt.ArrayLike]: + """Convert root to the column-wise dict representation for backend put""" + + if len(self.root) == 0: + transposed = { # list with empty arrays, each with correct dtype + name: np.array([], dtype=dtype) for name, dtype in SeqTableRowType.descr + } + else: + transposed_list = list(zip(*list(self.root))) + transposed = { + name: np.array(col, dtype=dtype) + for col, (name, dtype) in zip(transposed_list, SeqTableRowType.descr) + } + return transposed + + @classmethod + def convert_from_protocol_datatype( + cls, pva_table: Dict[str, npt.ArrayLike] + ) -> "SeqTable": + """Convert a pva table to a row-wise SeqTable.""" + + ordered_columns = [ + np.array(pva_table[name], dtype=dtype) + for name, dtype in SeqTableRowType.descr + ] + + transposed = list(zip(*ordered_columns)) + rows = np.array([tuple(row) for row in transposed], dtype=SeqTableRowType) + return cls(rows) + + @field_validator("root", mode="before") + @classmethod + def check_valid_rows(cls, rows: Union[Sequence, np.ndarray]): + assert isinstance( + rows, (np.ndarray, list) + ), "Rows must be a list or numpy array." + + if not (0 <= len(rows) < 4096): + raise ValueError(f"Length {len(rows)} not in range.") + + if not all(isinstance(row, (np.ndarray, np.void)) for row in rows): + raise ValueError( + "Cannot construct a SeqTable, some rows are not arrays." + ) + + if not all(row.dtype is SeqTableRowType for row in rows): + raise ValueError( + "Cannot construct a SeqTable, some rows have incorrect types." 
+ ) + + return np.array(rows, dtype=SeqTableRowType) diff --git a/src/ophyd_async/fastcs/panda/_trigger.py b/src/ophyd_async/fastcs/panda/_trigger.py index 977e5781f3..c79988a381 100644 --- a/src/ophyd_async/fastcs/panda/_trigger.py +++ b/src/ophyd_async/fastcs/panda/_trigger.py @@ -2,16 +2,15 @@ from typing import Optional from pydantic import BaseModel, Field -from pydantic_numpy import NpNDArray from ophyd_async.core import TriggerLogic, wait_for_value from ._block import PcompBlock, PcompDirectionOptions, SeqBlock, TimeUnits -from ._table import convert_seq_table_to_columnwise_pva_table +from ._table import SeqTable class SeqTableInfo(BaseModel): - sequence_table: NpNDArray = Field(strict=True) + sequence_table: SeqTable = Field(strict=True) repeats: int = Field(ge=0) prescale_as_us: float = Field(default=1, ge=0) # microseconds @@ -25,13 +24,10 @@ async def prepare(self, value: SeqTableInfo): self.seq.prescale_units.set(TimeUnits.us), self.seq.enable.set("ZERO"), ) - seq_table_pva_table = convert_seq_table_to_columnwise_pva_table( - value.sequence_table - ) await asyncio.gather( self.seq.prescale.set(value.prescale_as_us), self.seq.repeats.set(value.repeats), - self.seq.table.set(seq_table_pva_table), + self.seq.table.set(value.sequence_table), ) async def kickoff(self) -> None: diff --git a/src/ophyd_async/plan_stubs/_fly.py b/src/ophyd_async/plan_stubs/_fly.py index 95394a6863..5ea184834e 100644 --- a/src/ophyd_async/plan_stubs/_fly.py +++ b/src/ophyd_async/plan_stubs/_fly.py @@ -13,8 +13,8 @@ from ophyd_async.fastcs.panda import ( PcompDirectionOptions, PcompInfo, + SeqTable, SeqTableInfo, - create_seq_table, seq_table_row, ) @@ -73,7 +73,8 @@ def prepare_static_seq_table_flyer_and_detectors_with_same_trigger( trigger_time = number_of_frames * (exposure + deadtime) pre_delay = max(period - 2 * shutter_time - trigger_time, 0) - table = create_seq_table( + table = SeqTable( + [ # Wait for pre-delay then open shutter seq_table_row( time1=in_micros(pre_delay), @@ -91,6 +92,7 @@ def prepare_static_seq_table_flyer_and_detectors_with_same_trigger( ), # Add the shutter close seq_table_row(time2=in_micros(shutter_time)), + ] ) table_info = SeqTableInfo(sequence_table=table, repeats=repeats) diff --git a/tests/core/test_subset_enum.py b/tests/core/test_subset_enum.py index 41af248aac..2fad24435f 100644 --- a/tests/core/test_subset_enum.py +++ b/tests/core/test_subset_enum.py @@ -7,8 +7,8 @@ from ophyd_async.epics.signal import epics_signal_rw # Allow these imports from private modules for tests -from ophyd_async.epics.signal._aioca import make_converter as aioca_make_converter -from ophyd_async.epics.signal._p4p import make_converter as p4p_make_converter +from ophyd_async.epics.signal._aioca import CaConverterFactory +from ophyd_async.epics.signal._p4p import PvaConverterFactory async def test_runtime_enum_behaviour(): @@ -52,7 +52,7 @@ def __init__(self): epics_value = EpicsValue() rt_enum = SubsetEnum["A", "B"] - converter = aioca_make_converter( + converter = CaConverterFactory.make_converter( rt_enum, values={"READ_PV": epics_value, "WRITE_PV": epics_value} ) assert converter.choices == {"A": "A", "B": "B", "C": "C"} @@ -68,7 +68,7 @@ async def test_pva_runtime_enum_converter(): }, ) rt_enum = SubsetEnum["A", "B"] - converter = p4p_make_converter( + converter = PvaConverterFactory.make_converter( rt_enum, values={"READ_PV": epics_value, "WRITE_PV": epics_value} ) assert {"A", "B"}.issubset(set(converter.choices)) diff --git a/tests/fastcs/panda/test_panda_connect.py 
b/tests/fastcs/panda/test_panda_connect.py index 4204e649f6..1827b7e5bb 100644 --- a/tests/fastcs/panda/test_panda_connect.py +++ b/tests/fastcs/panda/test_panda_connect.py @@ -3,7 +3,6 @@ import copy from typing import Dict -import numpy as np import pytest from ophyd_async.core import ( @@ -19,10 +18,8 @@ PcapBlock, PulseBlock, SeqBlock, - SeqTablePvaTable, - convert_seq_table_to_columnwise_pva_table, + SeqTable, SeqTrigger, - create_seq_table, seq_table_row, ) @@ -94,36 +91,88 @@ def test_panda_name_set(panda_t): async def test_panda_children_connected(mock_panda): # try to set and retrieve from simulated values... - table = create_seq_table( + table = table = SeqTable( + [ seq_table_row( - repeats=np.array([1, 1, 1, 32]).astype(np.uint16), - trigger=( - SeqTrigger.POSA_GT, - SeqTrigger.POSA_LT, - SeqTrigger.IMMEDIATE, - SeqTrigger.IMMEDIATE, - ), - position=np.array([3222, -565, 0, 0], dtype=np.int32), - time1=np.array([5, 0, 10, 10]).astype( - np.uint32 - ), # TODO: change below syntax. - outa1=np.array([1, 0, 0, 1]).astype(np.bool_), - outb1=np.array([0, 0, 1, 1]).astype(np.bool_), - outc1=np.array([0, 1, 1, 0]).astype(np.bool_), - outd1=np.array([1, 1, 0, 1]).astype(np.bool_), - oute1=np.array([1, 0, 1, 0]).astype(np.bool_), - outf1=np.array([1, 0, 0, 0]).astype(np.bool_), - time2=np.array([0, 10, 10, 11]).astype(np.uint32), - outa2=np.array([1, 0, 0, 1]).astype(np.bool_), - outb2=np.array([0, 0, 1, 1]).astype(np.bool_), - outc2=np.array([0, 1, 1, 0]).astype(np.bool_), - outd2=np.array([1, 1, 0, 1]).astype(np.bool_), - oute2=np.array([1, 0, 1, 0]).astype(np.bool_), - outf2=np.array([1, 0, 0, 0]).astype(np.bool_), - ) + repeats=1, + trigger=SeqTrigger.POSA_GT, + position=3222, + time1=5, + outa1=True, + outb1=False, + outc1=False, + outd1=True, + oute1=True, + outf1=True, + time2=0, + outa2=True, + outb2=False, + outc2=False, + outd2=True, + oute2=True, + outf2=True, + ), + seq_table_row( + repeats=1, + trigger=SeqTrigger.POSA_LT, + position=-565, + time1=0, + outa1=False, + outb1=False, + outc1=True, + outd1=True, + oute1=False, + outf1=False, + time2=10, + outa2=False, + outb2=False, + outc2=True, + outd2=True, + oute2=False, + outf2=False, + ), + seq_table_row( + repeats=1, + trigger=SeqTrigger.IMMEDIATE, + position=0, + time1=10, + outa1=False, + outb1=True, + outc1=True, + outd1=False, + oute1=True, + outf1=False, + time2=10, + outa2=False, + outb2=True, + outc2=True, + outd2=False, + oute2=True, + outf2=False, + ), + seq_table_row( + repeats=32, + trigger=SeqTrigger.IMMEDIATE, + position=0, + time1=10, + outa1=True, + outb1=True, + outc1=False, + outd1=True, + oute1=False, + outf1=False, + time2=11, + outa2=True, + outb2=True, + outc2=False, + outd2=True, + oute2=False, + outf2=False, + ), + ] ) await mock_panda.pulse[1].delay.set(20.0) - await mock_panda.seq[1].table.set(convert_seq_table_to_columnwise_pva_table(table)) + await mock_panda.seq[1].table.set(table) readback_pulse = await mock_panda.pulse[1].delay.get_value() readback_seq = await mock_panda.seq[1].table.get_value() @@ -171,7 +220,7 @@ async def test_panda_gets_types_from_common_class(panda_pva, panda_t): assert panda.pcap.active._backend.datatype is bool # works with custom datatypes - assert panda.seq[1].table._backend.datatype is SeqTablePvaTable + assert panda.seq[1].table._backend.datatype is SeqTable # others are given the None datatype assert panda.pcap.newsignal._backend.datatype is None diff --git a/tests/fastcs/panda/test_panda_utils.py b/tests/fastcs/panda/test_panda_utils.py index f5d4e02600..609cbe7aa9 
100644 --- a/tests/fastcs/panda/test_panda_utils.py +++ b/tests/fastcs/panda/test_panda_utils.py @@ -1,5 +1,6 @@ from unittest.mock import patch +import numpy as np import pytest from bluesky import RunEngine @@ -10,6 +11,7 @@ CommonPandaBlocks, DataBlock, PcompDirectionOptions, + SeqTable, TimeUnits, phase_sorter, ) @@ -41,13 +43,22 @@ async def connect(self, mock: bool = False, timeout: float = DEFAULT_TIMEOUT): async def test_save_panda(mock_save_to_yaml, mock_panda, RE: RunEngine): RE(save_device(mock_panda, "path", sorter=phase_sorter)) mock_save_to_yaml.assert_called_once() - assert mock_save_to_yaml.call_args[0] == ( - [ - { - "phase_1_signal_units": 0, - "seq.1.prescale_units": TimeUnits("min"), - "seq.2.prescale_units": TimeUnits("min"), - }, + + def check_equal_with_seq_tables(actual, expected): + assert set(actual.keys()) == set(expected.keys()) + for key, value1 in actual.items(): + value2 = expected[key] + if isinstance(value1, SeqTable): + assert np.array_equal(value1.root, value2.root) + else: + assert value1 == value2 + + assert mock_save_to_yaml.call_args[0][0][0] == { + "phase_1_signal_units": 0, + "seq.1.prescale_units": TimeUnits("min"), + "seq.2.prescale_units": TimeUnits("min"), + } + check_equal_with_seq_tables(mock_save_to_yaml.call_args[0][0][1], { "data.capture": False, "data.create_directory": 0, @@ -73,16 +84,15 @@ async def test_save_panda(mock_save_to_yaml, mock_panda, RE: RunEngine): "pulse.2.delay": 0.0, "pulse.2.width": 0.0, "seq.1.active": False, - "seq.1.table": {}, + "seq.1.table": SeqTable([]), "seq.1.repeats": 0, "seq.1.prescale": 0.0, "seq.1.enable": "ZERO", - "seq.2.table": {}, + "seq.2.table": SeqTable([]), "seq.2.active": False, "seq.2.repeats": 0, "seq.2.prescale": 0.0, "seq.2.enable": "ZERO", }, - ], - "path", ) + assert mock_save_to_yaml.call_args[0][1] == "path" diff --git a/tests/fastcs/panda/test_table.py b/tests/fastcs/panda/test_table.py index 9b795f1292..67e11c6354 100644 --- a/tests/fastcs/panda/test_table.py +++ b/tests/fastcs/panda/test_table.py @@ -1,34 +1,127 @@ -""" import numpy as np import pytest +from pydantic import ValidationError -from ophyd_async.fastcs.panda import seq_table_from_arrays +from ophyd_async.fastcs.panda import SeqTable, SeqTableRowType, seq_table_row -def test_from_arrays_inconsistent_lengths(): - length = 4 - time2 = np.zeros(length) - time1 = np.zeros(length + 1) - with pytest.raises(ValueError, match="time1: has length 5 not 4"): - seq_table_from_arrays(time2=time2, time1=time1) - time1 = np.zeros(length - 1) - with pytest.raises(ValueError, match="time1: has length 3 not 4"): - seq_table_from_arrays(time2=time2, time1=time1) +@pytest.mark.parametrize( + # factory so that there aren't global errors if seq_table_row() fails + "rows_arg_factory", + [ + lambda: None, + list, + lambda: [seq_table_row(), seq_table_row()], + lambda: np.array([seq_table_row(), seq_table_row()]) + ] +) +def test_seq_table_initialization_allowed_args(rows_arg_factory): + rows_arg = rows_arg_factory() + seq_table = SeqTable() if rows_arg is None else SeqTable(rows_arg) + assert isinstance(seq_table.root, np.ndarray) + assert len(seq_table.root) == (0 if rows_arg is None else len(rows_arg)) -def test_from_arrays_no_time(): - with pytest.raises(AssertionError, match="time2 must be provided"): - seq_table_from_arrays(time2=None) # type: ignore - with pytest.raises(TypeError, match="required keyword-only argument: 'time2'"): - seq_table_from_arrays() # type: ignore - time2 = np.zeros(0) - with pytest.raises(AssertionError, match="Length 0 
not in range"): - seq_table_from_arrays(time2=time2) +def test_seq_table_validation_errors(): + + with pytest.raises( + ValueError, + match="Cannot construct a SeqTable, some rows are not arrays." + ): + SeqTable([seq_table_row().tolist()]) + with pytest.raises(ValidationError, match="Length 4098 not in range."): + SeqTable([seq_table_row() for _ in range(4098)]) + with pytest.raises( + ValidationError, + match="Cannot construct a SeqTable, some rows have incorrect types." + ): + SeqTable([seq_table_row(), np.array([1,2,3]), seq_table_row()]) + with pytest.raises( + ValidationError, + match="Cannot construct a SeqTable, some rows have incorrect types." + ): + SeqTable( + [ + seq_table_row(), + np.array(range(len(seq_table_row().tolist()))), + seq_table_row() + ] + ) + + +def test_seq_table_pva_conversion(): + expected_pva_dict = { + "repeats": np.array([1, 2, 3, 4], dtype=np.int32), + "trigger": np.array( + [ + "Immediate","Immediate","BITC=0", "Immediate" + ], dtype="U14" + ), + "position": np.array([1, 2, 3, 4], dtype=np.int32), + "time1": np.array([1, 0, 1, 0], dtype=np.int32), + "outa1": np.array([1, 0, 1, 0], dtype=np.bool_), + "outb1": np.array([1, 0, 1, 0], dtype=np.bool_), + "outc1": np.array([1, 0, 1, 0], dtype=np.bool_), + "outd1": np.array([1, 0, 1, 0], dtype=np.bool_), + "oute1": np.array([1, 0, 1, 0], dtype=np.bool_), + "outf1": np.array([1, 0, 1, 0], dtype=np.bool_), + "time2": np.array([1, 2, 3, 4], dtype=np.int32), + "outa2": np.array([1, 0, 1, 0], dtype=np.bool_), + "outb2": np.array([1, 0, 1, 0], dtype=np.bool_), + "outc2": np.array([1, 0, 1, 0], dtype=np.bool_), + "outd2": np.array([1, 0, 1, 0], dtype=np.bool_), + "oute2": np.array([1, 0, 1, 0], dtype=np.bool_), + "outf2": np.array([1, 0, 1, 0], dtype=np.bool_), + } + expected_numpy_table = np.array([ + (1, "Immediate", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1), + (2, "Immediate", 2, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0), + (3, "BITC=0", 3, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1), + (4, "Immediate", 4, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0), + ], dtype = SeqTableRowType) + + # Can convert from PVA table + numpy_table_from_pva_dict = SeqTable.convert_from_protocol_datatype( + expected_pva_dict + ) + assert np.array_equal(numpy_table_from_pva_dict.root, expected_numpy_table) + assert ( + numpy_table_from_pva_dict.root.dtype == + expected_numpy_table.dtype == + SeqTableRowType + ) + + # Can convert to PVA table + pva_dict_from_numpy_table = SeqTable( + expected_numpy_table + ).convert_to_protocol_datatype() + for column1, column2 in zip( + pva_dict_from_numpy_table.values(), + expected_pva_dict.values() + ): + assert np.array_equal(column1, column2) + assert column1.dtype == column2.dtype + + # Idempotency + applied_twice_to_numpy_table = SeqTable.convert_from_protocol_datatype( + SeqTable(expected_numpy_table).convert_to_protocol_datatype() + ) + assert np.array_equal(applied_twice_to_numpy_table.root, expected_numpy_table) + assert ( + applied_twice_to_numpy_table.root.dtype == + expected_numpy_table.dtype == + SeqTableRowType + ) + + applied_twice_to_pva_dict = SeqTable( + SeqTable.convert_from_protocol_datatype(expected_pva_dict).root + ).convert_to_protocol_datatype() + for column1, column2 in zip( + applied_twice_to_pva_dict.values(), + expected_pva_dict.values() + ): + assert np.array_equal(column1, column2) + assert column1.dtype == column2.dtype -def test_from_arrays_too_long(): - time2 = np.zeros(4097) - with pytest.raises(AssertionError, match="Length 4097 not in range"): - 
seq_table_from_arrays(time2=time2) -""" diff --git a/tests/fastcs/panda/test_trigger.py b/tests/fastcs/panda/test_trigger.py index 005daccaf9..5450b23634 100644 --- a/tests/fastcs/panda/test_trigger.py +++ b/tests/fastcs/panda/test_trigger.py @@ -8,10 +8,10 @@ from ophyd_async.fastcs.panda import ( CommonPandaBlocks, PcompInfo, + SeqTable, SeqTableInfo, StaticPcompTriggerLogic, StaticSeqTableTriggerLogic, - create_seq_table, seq_table_row, ) @@ -38,11 +38,13 @@ async def connect(self, mock: bool = False, timeout: float = DEFAULT_TIMEOUT): async def test_seq_table_trigger_logic(mock_panda): trigger_logic = StaticSeqTableTriggerLogic(mock_panda.seq[1]) - seq_table = create_seq_table( + seq_table = SeqTable( + [ seq_table_row(outa1=True, outa2=True), seq_table_row(outa1=False, outa2=False), seq_table_row(outa1=True, outa2=False), seq_table_row(outa1=False, outa2=True), + ] ) seq_table_info = SeqTableInfo(sequence_table=seq_table, repeats=1) @@ -79,7 +81,7 @@ async def set_active(value: bool): [ ( { - "sequence_table": create_seq_table(seq_table_row(outc2=1)), + "sequence_table": SeqTable([seq_table_row(outc2=1)]), "repeats": 0, "prescale_as_us": -1, }, @@ -88,11 +90,13 @@ async def set_active(value: bool): ), ( { - "sequence_table": create_seq_table( + "sequence_table": SeqTable( + [ seq_table_row(outc2=True), seq_table_row(outc2=False), seq_table_row(outc2=True), seq_table_row(outc2=False), + ] ), "repeats": -1, }, @@ -104,9 +108,8 @@ async def set_active(value: bool): "sequence_table": 1, "repeats": 1, }, - "Value error, Cannot construct a SeqTable, " - "input is not an unpacked tuple of `SeqTableRowType`. " - "[type=value_error, input_value=1, input_type=int]", + "Assertion failed, Rows must be a list or numpy array. " + "[type=assertion_error, input_value=1, input_type=int]" ), ], ) From 7d54c852c52c151a5776f92b4fc5cb4c2c616cb0 Mon Sep 17 00:00:00 2001 From: Eva Lott Date: Fri, 16 Aug 2024 13:33:30 +0100 Subject: [PATCH 03/11] changed to prevent signal initialisation with unknown datatype --- src/ophyd_async/core/_device_save_loader.py | 16 ++ src/ophyd_async/core/_signal_backend.py | 2 + src/ophyd_async/core/_soft_signal_backend.py | 11 +- src/ophyd_async/core/_utils.py | 2 +- src/ophyd_async/epics/signal/_aioca.py | 14 +- src/ophyd_async/epics/signal/_p4p.py | 1 + src/ophyd_async/fastcs/panda/_table.py | 7 +- src/ophyd_async/plan_stubs/_fly.py | 34 ++-- tests/core/test_signal.py | 23 +++ tests/fastcs/panda/test_panda_connect.py | 156 +++++++++---------- tests/fastcs/panda/test_panda_utils.py | 34 ++-- tests/fastcs/panda/test_table.py | 58 ++++--- tests/fastcs/panda/test_trigger.py | 18 +-- 13 files changed, 210 insertions(+), 166 deletions(-) diff --git a/src/ophyd_async/core/_device_save_loader.py b/src/ophyd_async/core/_device_save_loader.py index 5b81228264..02ceeeb0af 100644 --- a/src/ophyd_async/core/_device_save_loader.py +++ b/src/ophyd_async/core/_device_save_loader.py @@ -10,6 +10,7 @@ from ._device import Device from ._signal import SignalRW +from ._signal_backend import ProtocolDatatypeAbstraction def ndarray_representer(dumper: yaml.Dumper, array: npt.NDArray[Any]) -> yaml.Node: @@ -18,6 +19,16 @@ def ndarray_representer(dumper: yaml.Dumper, array: npt.NDArray[Any]) -> yaml.No ) +def protocol_datatype_abstraction_representer( + dumper: yaml.Dumper, protocol_datatype_abstraction: ProtocolDatatypeAbstraction +) -> yaml.Node: + """Uses the protocol datatype since it has to be serializable.""" + + return dumper.represent_data( + 
protocol_datatype_abstraction.convert_to_protocol_datatype() + ) + + class OphydDumper(yaml.Dumper): def represent_data(self, data: Any) -> Any: if isinstance(data, Enum): @@ -134,6 +145,11 @@ def save_to_yaml(phases: Sequence[Dict[str, Any]], save_path: str) -> None: """ yaml.add_representer(np.ndarray, ndarray_representer, Dumper=yaml.Dumper) + yaml.add_multi_representer( + ProtocolDatatypeAbstraction, + protocol_datatype_abstraction_representer, + Dumper=yaml.Dumper, + ) with open(save_path, "w") as file: yaml.dump(phases, file, Dumper=OphydDumper, default_flow_style=False) diff --git a/src/ophyd_async/core/_signal_backend.py b/src/ophyd_async/core/_signal_backend.py index c66cc0b50a..8427ee502c 100644 --- a/src/ophyd_async/core/_signal_backend.py +++ b/src/ophyd_async/core/_signal_backend.py @@ -23,6 +23,8 @@ def convert_to_protocol_datatype(self) -> T: """ Convert the abstract datatype to a form which can be sent over whichever protocol. + + This output will be used when the device is serialized. """ @classmethod diff --git a/src/ophyd_async/core/_soft_signal_backend.py b/src/ophyd_async/core/_soft_signal_backend.py index 3e6b821f09..5d64d270cc 100644 --- a/src/ophyd_async/core/_soft_signal_backend.py +++ b/src/ophyd_async/core/_soft_signal_backend.py @@ -132,13 +132,17 @@ class SoftProtocolDatatypeAbstractionConverter(SoftConverter): No conversion is necessary for ProtocolDatatypeAbstraction datatypes in soft signals. """ + def __init__(self, datatype: Type[ProtocolDatatypeAbstraction]): self.datatype = datatype def reading(self, value: T, timestamp: float, severity: int) -> Reading: + value = self.value(value) return super().reading(value, timestamp, severity) def value(self, value: Any) -> Any: + if not issubclass(type(value), ProtocolDatatypeAbstraction): + value = self.datatype.convert_from_protocol_datatype(value) return value def write_value(self, value): @@ -153,8 +157,7 @@ class SoftSignalConverterFactory(BackendConverterFactory): @classmethod def datatype_allowed(cls, datatype: Type) -> bool: - return True # Any value allowed in a soft signal - + return True # Any value allowed in a soft signal @classmethod def make_converter(cls, datatype): @@ -164,8 +167,7 @@ def make_converter(cls, datatype): issubclass(datatype, Enum) or issubclass(datatype, RuntimeSubsetEnum) ) is_convertable_abstract_datatype = inspect.isclass(datatype) and issubclass( - datatype, - ProtocolDatatypeAbstraction + datatype, ProtocolDatatypeAbstraction ) if is_array or is_sequence: @@ -195,6 +197,7 @@ def __init__( self.datatype = datatype self._initial_value = initial_value self._metadata = metadata or {} + self.converter_factory = SoftSignalConverterFactory self.converter: SoftConverter = SoftSignalConverterFactory.make_converter( datatype ) diff --git a/src/ophyd_async/core/_utils.py b/src/ophyd_async/core/_utils.py index f5098ce717..d081ed008f 100644 --- a/src/ophyd_async/core/_utils.py +++ b/src/ophyd_async/core/_utils.py @@ -145,7 +145,7 @@ def get_dtype(typ: Type) -> Optional[np.dtype]: def get_unique(values: Dict[str, T], types: str) -> T: - """If all values are the same, return that value, otherwise return TypeError + """If all values are the same, return that value, otherwise raise TypeError >>> get_unique({"a": 1, "b": 1}, "integers") 1 diff --git a/src/ophyd_async/epics/signal/_aioca.py b/src/ophyd_async/epics/signal/_aioca.py index ce57acba80..aa15d5a570 100644 --- a/src/ophyd_async/epics/signal/_aioca.py +++ b/src/ophyd_async/epics/signal/_aioca.py @@ -195,23 +195,23 @@ class 
CaConverterFactory(BackendConverterFactory): Sequence, Enum, RuntimeSubsetEnum, - np.ndarray + np.ndarray, ) @classmethod - def datatype_allowed(cls, datatype: Type) -> bool: + def datatype_allowed(cls, datatype: Optional[Type]) -> bool: stripped_origin = get_origin(datatype) or datatype + if datatype is None: + return True + return inspect.isclass(stripped_origin) and issubclass( - stripped_origin, cls._ALLOWED_TYPES + stripped_origin, cls._ALLOWED_TYPES ) @classmethod def make_converter( cls, datatype: Optional[Type], values: Dict[str, AugmentedValue] ) -> CaConverter: - if datatype is not None and not cls.datatype_allowed(datatype): - raise TypeError(f"Given datatype {datatype} unsupported in CA.") - pv = list(values)[0] pv_dbr = get_unique({k: v.datatype for k, v in values.items()}, "datatypes") is_array = bool([v for v in values.values() if v.element_count > 1]) @@ -287,6 +287,8 @@ def _use_pyepics_context_if_imported(): class CaSignalBackend(SignalBackend[T]): def __init__(self, datatype: Optional[Type[T]], read_pv: str, write_pv: str): self.datatype = datatype + if not CaConverterFactory.datatype_allowed(self.datatype): + raise TypeError(f"Given datatype {self.datatype} unsupported in CA.") self.read_pv = read_pv self.write_pv = write_pv self.initial_values: Dict[str, AugmentedValue] = {} diff --git a/src/ophyd_async/epics/signal/_p4p.py b/src/ophyd_async/epics/signal/_p4p.py index 3f95740522..ac220122e5 100644 --- a/src/ophyd_async/epics/signal/_p4p.py +++ b/src/ophyd_async/epics/signal/_p4p.py @@ -425,6 +425,7 @@ def __init__(self, datatype: Optional[Type[T]], read_pv: str, write_pv: str): self.read_pv = read_pv self.write_pv = write_pv self.initial_values: Dict[str, Any] = {} + self.converter_factory = PvaConverterFactory() self.converter: PvaConverter = DisconnectedPvaConverter() self.subscription: Optional[Subscription] = None diff --git a/src/ophyd_async/fastcs/panda/_table.py b/src/ophyd_async/fastcs/panda/_table.py index 12d201df22..6dba645cd1 100644 --- a/src/ophyd_async/fastcs/panda/_table.py +++ b/src/ophyd_async/fastcs/panda/_table.py @@ -55,7 +55,6 @@ class SeqTrigger(str, Enum): ("outd2", np.bool_), ("oute2", np.bool_), ("outf2", np.bool_), - ] ) @@ -104,8 +103,6 @@ def seq_table_row( ) - - class SeqTable(RootModel, PvaTableAbstraction): root: pnd.NpNDArray = Field( default_factory=lambda: np.array([], dtype=SeqTableRowType), @@ -152,9 +149,7 @@ def check_valid_rows(cls, rows: Union[Sequence, np.ndarray]): raise ValueError(f"Length {len(rows)} not in range.") if not all(isinstance(row, (np.ndarray, np.void)) for row in rows): - raise ValueError( - "Cannot construct a SeqTable, some rows are not arrays." 
-        )
+        raise ValueError("Cannot construct a SeqTable, some rows are not arrays.")
 
     if not all(row.dtype is SeqTableRowType for row in rows):
         raise ValueError(
diff --git a/src/ophyd_async/plan_stubs/_fly.py b/src/ophyd_async/plan_stubs/_fly.py
index 5ea184834e..04adf046a8 100644
--- a/src/ophyd_async/plan_stubs/_fly.py
+++ b/src/ophyd_async/plan_stubs/_fly.py
@@ -75,23 +75,23 @@ def prepare_static_seq_table_flyer_and_detectors_with_same_trigger(
 
     table = SeqTable(
         [
-        # Wait for pre-delay then open shutter
-        seq_table_row(
-            time1=in_micros(pre_delay),
-            time2=in_micros(shutter_time),
-            outa2=True,
-        ),
-        # Keeping shutter open, do N triggers
-        seq_table_row(
-            repeats=number_of_frames,
-            time1=in_micros(exposure),
-            outa1=True,
-            outb1=True,
-            time2=in_micros(deadtime),
-            outa2=True,
-        ),
-        # Add the shutter close
-        seq_table_row(time2=in_micros(shutter_time)),
+            # Wait for pre-delay then open shutter
+            seq_table_row(
+                time1=in_micros(pre_delay),
+                time2=in_micros(shutter_time),
+                outa2=True,
+            ),
+            # Keeping shutter open, do N triggers
+            seq_table_row(
+                repeats=number_of_frames,
+                time1=in_micros(exposure),
+                outa1=True,
+                outb1=True,
+                time2=in_micros(deadtime),
+                outa2=True,
+            ),
+            # Add the shutter close
+            seq_table_row(time2=in_micros(shutter_time)),
         ]
     )
diff --git a/tests/core/test_signal.py b/tests/core/test_signal.py
index 3b4c4934f4..3adafdd9f8 100644
--- a/tests/core/test_signal.py
+++ b/tests/core/test_signal.py
@@ -403,3 +403,26 @@ async def test_subscription_logs(caplog):
     assert "Making subscription" in caplog.text
     mock_signal_rw.clear_sub(cbs.append)
     assert "Closing subscription on source" in caplog.text
+
+
+async def test_signal_unknown_datatype():
+    class SomeClass:
+        def __init__(self):
+            self.some_attribute = "some_attribute"
+
+        def some_function(self):
+            pass
+
+    # with pytest.raises(ValueError, match="Unknown datatype 'SomeClass'"):
+    err_str = (
+        "Given datatype "
+        "<class 'test_signal.test_signal_unknown_datatype.<locals>.SomeClass'>"
+        " unsupported in %s."
+    )
+    with pytest.raises(TypeError, match=err_str % ("PVA",)):
+        epics_signal_rw(SomeClass, "pva://mock_signal", name="mock_signal")
+    with pytest.raises(TypeError, match=err_str % ("CA",)):
+        epics_signal_rw(SomeClass, "ca://mock_signal", name="mock_signal")
+
+    # Any dtype allowed in soft signal
+    soft_signal_rw(SomeClass, SomeClass(), "soft_signal")
diff --git a/tests/fastcs/panda/test_panda_connect.py b/tests/fastcs/panda/test_panda_connect.py
index 1827b7e5bb..a4bc23d41c 100644
--- a/tests/fastcs/panda/test_panda_connect.py
+++ b/tests/fastcs/panda/test_panda_connect.py
@@ -20,7 +20,6 @@
     SeqBlock,
     SeqTable,
     SeqTrigger,
-    seq_table_row,
 )
 
 
@@ -91,85 +90,83 @@ def test_panda_name_set(panda_t):
 
 async def test_panda_children_connected(mock_panda):
     # try to set and retrieve from simulated values...
table = table = SeqTable( - [ - seq_table_row( - repeats=1, - trigger=SeqTrigger.POSA_GT, - position=3222, - time1=5, - outa1=True, - outb1=False, - outc1=False, - outd1=True, - oute1=True, - outf1=True, - time2=0, - outa2=True, - outb2=False, - outc2=False, - outd2=True, - oute2=True, - outf2=True, - ), - seq_table_row( - repeats=1, - trigger=SeqTrigger.POSA_LT, - position=-565, - time1=0, - outa1=False, - outb1=False, - outc1=True, - outd1=True, - oute1=False, - outf1=False, - time2=10, - outa2=False, - outb2=False, - outc2=True, - outd2=True, - oute2=False, - outf2=False, - ), - seq_table_row( - repeats=1, - trigger=SeqTrigger.IMMEDIATE, - position=0, - time1=10, - outa1=False, - outb1=True, - outc1=True, - outd1=False, - oute1=True, - outf1=False, - time2=10, - outa2=False, - outb2=True, - outc2=True, - outd2=False, - oute2=True, - outf2=False, - ), - seq_table_row( - repeats=32, - trigger=SeqTrigger.IMMEDIATE, - position=0, - time1=10, - outa1=True, - outb1=True, - outc1=False, - outd1=True, - oute1=False, - outf1=False, - time2=11, - outa2=True, - outb2=True, - outc2=False, - outd2=True, - oute2=False, - outf2=False, - ), - ] + [ + seq_table_row( + repeats=1, + trigger=SeqTrigger.POSA_GT, + position=3222, + time1=5, + outa1=True, + outb1=False, + outc1=False, + outd1=True, + oute1=True, + outf1=True, + time2=0, + outa2=True, + outb2=False, + outc2=False, + outd2=True, + oute2=True, + outf2=True, + ), + seq_table_row( + repeats=1, + trigger=SeqTrigger.POSA_LT, + position=-565, + time1=0, + outa1=False, + outb1=False, + outc1=True, + outd1=True, + oute1=False, + outf1=False, + time2=10, + outa2=False, + outb2=False, + outc2=True, + outd2=True, + oute2=False, + outf2=False, + ), + seq_table_row( + repeats=1, + trigger=SeqTrigger.IMMEDIATE, + position=0, + time1=10, + outa1=False, + outb1=True, + outc1=True, + outd1=False, + oute1=True, + outf1=False, + time2=10, + outa2=False, + outb2=True, + outc2=True, + outd2=False, + oute2=True, + outf2=False, + ), + seq_table_row( + repeats=32, + trigger=SeqTrigger.IMMEDIATE, + position=0, + time1=10, + outa1=True, + outb1=True, + outc1=False, + outd1=True, + oute1=False, + outf1=False, + time2=11, + outa2=True, + outb2=True, + outc2=False, + outd2=True, + oute2=False, + outf2=False, + ), + ] ) await mock_panda.pulse[1].delay.set(20.0) await mock_panda.seq[1].table.set(table) diff --git a/tests/fastcs/panda/test_panda_utils.py b/tests/fastcs/panda/test_panda_utils.py index 609cbe7aa9..a27c3e8d56 100644 --- a/tests/fastcs/panda/test_panda_utils.py +++ b/tests/fastcs/panda/test_panda_utils.py @@ -1,24 +1,20 @@ -from unittest.mock import patch import numpy as np -import pytest from bluesky import RunEngine -from ophyd_async.core import DEFAULT_TIMEOUT, DeviceCollector, save_device +from ophyd_async.core import DEFAULT_TIMEOUT, DeviceCollector, load_device, save_device from ophyd_async.epics.pvi import fill_pvi_entries from ophyd_async.epics.signal import epics_signal_rw from ophyd_async.fastcs.panda import ( CommonPandaBlocks, DataBlock, - PcompDirectionOptions, SeqTable, - TimeUnits, phase_sorter, + seq_table_row, ) -@pytest.fixture -async def mock_panda(): +async def get_mock_panda(): class Panda(CommonPandaBlocks): data: DataBlock @@ -35,14 +31,14 @@ async def connect(self, mock: bool = False, timeout: float = DEFAULT_TIMEOUT): async with DeviceCollector(mock=True): mock_panda = Panda("PANDA") mock_panda.phase_1_signal_units = epics_signal_rw(int, "") - assert mock_panda.name == "mock_panda" - yield mock_panda + return mock_panda 
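(For orientation, a minimal sketch of the save/load round-trip the rewritten test below drives; all names come from this patch, and the on-disk yaml path is illustrative only:

    mock_panda = await get_mock_panda()
    await mock_panda.seq[1].table.set(SeqTable([seq_table_row(repeats=1)]))
    # Serialise the device to yaml, then restore it from the same file
    RE(save_device(mock_panda, "panda.yaml", sorter=phase_sorter))
    RE(load_device(mock_panda, "panda.yaml"))

)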
-@patch("ophyd_async.core._device_save_loader.save_to_yaml") -async def test_save_panda(mock_save_to_yaml, mock_panda, RE: RunEngine): - RE(save_device(mock_panda, "path", sorter=phase_sorter)) - mock_save_to_yaml.assert_called_once() +async def test_save_load_panda(tmp_path, RE: RunEngine): + mock_panda1 = await get_mock_panda() + await mock_panda1.seq[1].table.set(SeqTable([seq_table_row(repeats=1)])) + + RE(save_device(mock_panda1, str(tmp_path / "panda.yaml"), sorter=phase_sorter)) def check_equal_with_seq_tables(actual, expected): assert set(actual.keys()) == set(expected.keys()) @@ -53,6 +49,17 @@ def check_equal_with_seq_tables(actual, expected): else: assert value1 == value2 + mock_panda2 = await get_mock_panda() + assert np.array_equal( + (await mock_panda2.seq[1].table.get_value()).root, SeqTable([]).root + ) + RE(load_device(mock_panda2, str(tmp_path / "panda.yaml"))) + assert np.array_equal( + (await mock_panda2.seq[1].table.get_value()).root, + SeqTable([seq_table_row(repeats=1)]).root, + ) + + """ assert mock_save_to_yaml.call_args[0][0][0] == { "phase_1_signal_units": 0, "seq.1.prescale_units": TimeUnits("min"), @@ -96,3 +103,4 @@ def check_equal_with_seq_tables(actual, expected): }, ) assert mock_save_to_yaml.call_args[0][1] == "path" + """ diff --git a/tests/fastcs/panda/test_table.py b/tests/fastcs/panda/test_table.py index 67e11c6354..88bacabca4 100644 --- a/tests/fastcs/panda/test_table.py +++ b/tests/fastcs/panda/test_table.py @@ -12,39 +12,37 @@ lambda: None, list, lambda: [seq_table_row(), seq_table_row()], - lambda: np.array([seq_table_row(), seq_table_row()]) - ] + lambda: np.array([seq_table_row(), seq_table_row()]), + ], ) def test_seq_table_initialization_allowed_args(rows_arg_factory): - rows_arg = rows_arg_factory() seq_table = SeqTable() if rows_arg is None else SeqTable(rows_arg) assert isinstance(seq_table.root, np.ndarray) assert len(seq_table.root) == (0 if rows_arg is None else len(rows_arg)) -def test_seq_table_validation_errors(): +def test_seq_table_validation_errors(): with pytest.raises( - ValueError, - match="Cannot construct a SeqTable, some rows are not arrays." + ValueError, match="Cannot construct a SeqTable, some rows are not arrays." ): SeqTable([seq_table_row().tolist()]) with pytest.raises(ValidationError, match="Length 4098 not in range."): SeqTable([seq_table_row() for _ in range(4098)]) with pytest.raises( ValidationError, - match="Cannot construct a SeqTable, some rows have incorrect types." + match="Cannot construct a SeqTable, some rows have incorrect types.", ): - SeqTable([seq_table_row(), np.array([1,2,3]), seq_table_row()]) + SeqTable([seq_table_row(), np.array([1, 2, 3]), seq_table_row()]) with pytest.raises( ValidationError, - match="Cannot construct a SeqTable, some rows have incorrect types." 
+ match="Cannot construct a SeqTable, some rows have incorrect types.", ): SeqTable( [ seq_table_row(), np.array(range(len(seq_table_row().tolist()))), - seq_table_row() + seq_table_row(), ] ) @@ -53,9 +51,7 @@ def test_seq_table_pva_conversion(): expected_pva_dict = { "repeats": np.array([1, 2, 3, 4], dtype=np.int32), "trigger": np.array( - [ - "Immediate","Immediate","BITC=0", "Immediate" - ], dtype="U14" + ["Immediate", "Immediate", "BITC=0", "Immediate"], dtype="U14" ), "position": np.array([1, 2, 3, 4], dtype=np.int32), "time1": np.array([1, 0, 1, 0], dtype=np.int32), @@ -73,12 +69,15 @@ def test_seq_table_pva_conversion(): "oute2": np.array([1, 0, 1, 0], dtype=np.bool_), "outf2": np.array([1, 0, 1, 0], dtype=np.bool_), } - expected_numpy_table = np.array([ - (1, "Immediate", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1), - (2, "Immediate", 2, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0), - (3, "BITC=0", 3, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1), - (4, "Immediate", 4, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0), - ], dtype = SeqTableRowType) + expected_numpy_table = np.array( + [ + (1, "Immediate", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1), + (2, "Immediate", 2, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0), + (3, "BITC=0", 3, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1), + (4, "Immediate", 4, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0), + ], + dtype=SeqTableRowType, + ) # Can convert from PVA table numpy_table_from_pva_dict = SeqTable.convert_from_protocol_datatype( @@ -86,9 +85,9 @@ def test_seq_table_pva_conversion(): ) assert np.array_equal(numpy_table_from_pva_dict.root, expected_numpy_table) assert ( - numpy_table_from_pva_dict.root.dtype == - expected_numpy_table.dtype == - SeqTableRowType + numpy_table_from_pva_dict.root.dtype + == expected_numpy_table.dtype + == SeqTableRowType ) # Can convert to PVA table @@ -96,8 +95,7 @@ def test_seq_table_pva_conversion(): expected_numpy_table ).convert_to_protocol_datatype() for column1, column2 in zip( - pva_dict_from_numpy_table.values(), - expected_pva_dict.values() + pva_dict_from_numpy_table.values(), expected_pva_dict.values() ): assert np.array_equal(column1, column2) assert column1.dtype == column2.dtype @@ -108,20 +106,16 @@ def test_seq_table_pva_conversion(): ) assert np.array_equal(applied_twice_to_numpy_table.root, expected_numpy_table) assert ( - applied_twice_to_numpy_table.root.dtype == - expected_numpy_table.dtype == - SeqTableRowType + applied_twice_to_numpy_table.root.dtype + == expected_numpy_table.dtype + == SeqTableRowType ) applied_twice_to_pva_dict = SeqTable( SeqTable.convert_from_protocol_datatype(expected_pva_dict).root ).convert_to_protocol_datatype() for column1, column2 in zip( - applied_twice_to_pva_dict.values(), - expected_pva_dict.values() + applied_twice_to_pva_dict.values(), expected_pva_dict.values() ): assert np.array_equal(column1, column2) assert column1.dtype == column2.dtype - - - diff --git a/tests/fastcs/panda/test_trigger.py b/tests/fastcs/panda/test_trigger.py index 5450b23634..a12ee32200 100644 --- a/tests/fastcs/panda/test_trigger.py +++ b/tests/fastcs/panda/test_trigger.py @@ -40,10 +40,10 @@ async def test_seq_table_trigger_logic(mock_panda): trigger_logic = StaticSeqTableTriggerLogic(mock_panda.seq[1]) seq_table = SeqTable( [ - seq_table_row(outa1=True, outa2=True), - seq_table_row(outa1=False, outa2=False), - seq_table_row(outa1=True, outa2=False), - seq_table_row(outa1=False, outa2=True), + seq_table_row(outa1=True, outa2=True), + seq_table_row(outa1=False, outa2=False), + 
seq_table_row(outa1=True, outa2=False),
+            seq_table_row(outa1=False, outa2=True),
         ]
     )
     seq_table_info = SeqTableInfo(sequence_table=seq_table, repeats=1)
@@ -92,10 +92,10 @@ async def set_active(value: bool):
             {
                 "sequence_table": SeqTable(
                     [
-                    seq_table_row(outc2=True),
-                    seq_table_row(outc2=False),
-                    seq_table_row(outc2=True),
-                    seq_table_row(outc2=False),
+                        seq_table_row(outc2=True),
+                        seq_table_row(outc2=False),
+                        seq_table_row(outc2=True),
+                        seq_table_row(outc2=False),
                     ]
                 ),
                 "repeats": -1,
@@ -109,7 +109,7 @@ async def set_active(value: bool):
             "repeats": 1,
         },
         "Assertion failed, Rows must be a list or numpy array. "
-        "[type=assertion_error, input_value=1, input_type=int]"
+        "[type=assertion_error, input_value=1, input_type=int]",
     ),
 ],
 )

From 8d129255eb95cdc4b397cbe478b06684ce68d734 Mon Sep 17 00:00:00 2001
From: Eva Lott
Date: Fri, 16 Aug 2024 14:44:35 +0100
Subject: [PATCH 04/11] WIP: deciding on certain features in backend conversion

---
 src/ophyd_async/core/_soft_signal_backend.py |  9 ++-------
 src/ophyd_async/epics/signal/_p4p.py         |  1 -
 tests/core/test_device_save_loader.py        | 17 +++++++++++++++++
 tests/test_data/test_yaml_save.yml           |  1 +
 4 files changed, 20 insertions(+), 8 deletions(-)

diff --git a/src/ophyd_async/core/_soft_signal_backend.py b/src/ophyd_async/core/_soft_signal_backend.py
index 5d64d270cc..619400c8fd 100644
--- a/src/ophyd_async/core/_soft_signal_backend.py
+++ b/src/ophyd_async/core/_soft_signal_backend.py
@@ -128,11 +128,6 @@ def make_initial_value(self, datatype: Optional[Type[T]]) -> T:
 
 
 class SoftProtocolDatatypeAbstractionConverter(SoftConverter):
-    """
-    No conversion is necessary for ProtocolDatatypeAbstraction datatypes in soft
-    signals.
-    """
-
     def __init__(self, datatype: Type[ProtocolDatatypeAbstraction]):
         self.datatype = datatype
 
@@ -141,7 +136,8 @@ def reading(self, value: T, timestamp: float, severity: int) -> Reading:
         return super().reading(value, timestamp, severity)
 
     def value(self, value: Any) -> Any:
-        if not issubclass(type(value), ProtocolDatatypeAbstraction):
+        if not isinstance(value, self.datatype):
+            # For the case where we receive the raw protocol datatype
             value = self.datatype.convert_from_protocol_datatype(value)
         return value
 
@@ -197,7 +193,6 @@ def __init__(
         self.datatype = datatype
         self._initial_value = initial_value
         self._metadata = metadata or {}
-        self.converter_factory = SoftSignalConverterFactory
         self.converter: SoftConverter = SoftSignalConverterFactory.make_converter(
             datatype
         )
diff --git a/src/ophyd_async/epics/signal/_p4p.py b/src/ophyd_async/epics/signal/_p4p.py
index ac220122e5..3f95740522 100644
--- a/src/ophyd_async/epics/signal/_p4p.py
+++ b/src/ophyd_async/epics/signal/_p4p.py
@@ -425,7 +425,6 @@ def __init__(self, datatype: Optional[Type[T]], read_pv: str, write_pv: str):
         self.read_pv = read_pv
         self.write_pv = write_pv
         self.initial_values: Dict[str, Any] = {}
-        self.converter_factory = PvaConverterFactory()
         self.converter: PvaConverter = DisconnectedPvaConverter()
         self.subscription: Optional[Subscription] = None
 
diff --git a/tests/core/test_device_save_loader.py b/tests/core/test_device_save_loader.py
index aa60be9802..16d18696ec 100644
--- a/tests/core/test_device_save_loader.py
+++ b/tests/core/test_device_save_loader.py
@@ -22,6 +22,7 @@
     set_signal_values,
     walk_rw_signals,
 )
+from ophyd_async.core._signal_backend import ProtocolDatatypeAbstraction
 from ophyd_async.epics.signal import epics_signal_r, epics_signal_rw
 
 
@@ -54,6 +55,18 @@ class MyEnum(str, Enum):
     three = "three"
 
 
+class SomeProtocolDatatypeAbstraction(ProtocolDatatypeAbstraction):
+    def
__init__(self, value: int): + self.value = value + + def convert_to_protocol_datatype(self) -> int: + return self.value - 1 + + @classmethod + def convert_from_protocol_datatype(cls, value: int) -> "SomeProtocolDatatypeAbstraction": + return cls(value + 1) + + class DummyDeviceGroupAllTypes(Device): def __init__(self, name: str): self.pv_int: SignalRW = epics_signal_rw(int, "PV1") @@ -73,6 +86,7 @@ def __init__(self, name: str): self.pv_array_float64 = epics_signal_rw(npt.NDArray[np.float64], "PV14") self.pv_array_npstr = epics_signal_rw(npt.NDArray[np.str_], "PV15") self.pv_array_str = epics_signal_rw(Sequence[str], "PV16") + self.pv_protocol_device_abstraction = epics_signal_rw(SomeProtocolDatatypeAbstraction, "PV17") @pytest.fixture @@ -155,6 +169,9 @@ async def test_save_device_all_types(RE: RunEngine, device_all_types, tmp_path): await device_all_types.pv_array_str.set( ["one", "two", "three"], ) + await device_all_types.pv_protocol_device_abstraction.set( + SomeProtocolDatatypeAbstraction(1) + ) # Create save plan from utility functions def save_my_device(): diff --git a/tests/test_data/test_yaml_save.yml b/tests/test_data/test_yaml_save.yml index fc3e1ebd95..349536ecdc 100644 --- a/tests/test_data/test_yaml_save.yml +++ b/tests/test_data/test_yaml_save.yml @@ -20,3 +20,4 @@ pv_float: 1.234 pv_int: 1 pv_str: test_string + pv_protocol_device_abstraction: 0 From e1ed77b91f1d9bdc1f365d53d6dda5aa72c66ea3 Mon Sep 17 00:00:00 2001 From: Eva Lott Date: Tue, 20 Aug 2024 14:30:55 +0100 Subject: [PATCH 05/11] allowed backends to take pydantic models Also finished a `PvaTable` type. --- src/ophyd_async/core/__init__.py | 2 - src/ophyd_async/core/_device_save_loader.py | 14 +- src/ophyd_async/core/_signal_backend.py | 23 -- src/ophyd_async/core/_soft_signal_backend.py | 26 +- src/ophyd_async/epics/signal/__init__.py | 4 +- .../epics/signal/_p4p_table_abstraction.py | 70 ++++++ src/ophyd_async/fastcs/panda/__init__.py | 4 - src/ophyd_async/fastcs/panda/_table.py | 235 +++++++++--------- src/ophyd_async/plan_stubs/_fly.py | 41 ++- tests/core/test_device_save_loader.py | 20 +- tests/fastcs/panda/test_panda_connect.py | 157 ++++++------ tests/fastcs/panda/test_panda_utils.py | 26 +- tests/fastcs/panda/test_table.py | 222 +++++++++++------ tests/fastcs/panda/test_trigger.py | 34 ++- tests/test_data/test_yaml_save.yml | 3 +- 15 files changed, 493 insertions(+), 388 deletions(-) create mode 100644 src/ophyd_async/epics/signal/_p4p_table_abstraction.py diff --git a/src/ophyd_async/core/__init__.py b/src/ophyd_async/core/__init__.py index 1b082f2835..d66cbc5ba3 100644 --- a/src/ophyd_async/core/__init__.py +++ b/src/ophyd_async/core/__init__.py @@ -63,7 +63,6 @@ ) from ._signal_backend import ( BackendConverterFactory, - ProtocolDatatypeAbstraction, RuntimeSubsetEnum, SignalBackend, SubsetEnum, @@ -124,7 +123,6 @@ "NameProvider", "PathInfo", "PathProvider", - "ProtocolDatatypeAbstraction", "ShapeProvider", "StaticFilenameProvider", "StaticPathProvider", diff --git a/src/ophyd_async/core/_device_save_loader.py b/src/ophyd_async/core/_device_save_loader.py index 02ceeeb0af..a40c404d50 100644 --- a/src/ophyd_async/core/_device_save_loader.py +++ b/src/ophyd_async/core/_device_save_loader.py @@ -7,10 +7,10 @@ from bluesky.plan_stubs import abs_set, wait from bluesky.protocols import Location from bluesky.utils import Msg +from pydantic import BaseModel from ._device import Device from ._signal import SignalRW -from ._signal_backend import ProtocolDatatypeAbstraction def ndarray_representer(dumper: 
yaml.Dumper, array: npt.NDArray[Any]) -> yaml.Node: @@ -19,14 +19,12 @@ def ndarray_representer(dumper: yaml.Dumper, array: npt.NDArray[Any]) -> yaml.No ) -def protocol_datatype_abstraction_representer( - dumper: yaml.Dumper, protocol_datatype_abstraction: ProtocolDatatypeAbstraction +def pydantic_model_abstraction_representer( + dumper: yaml.Dumper, model: BaseModel ) -> yaml.Node: """Uses the protocol datatype since it has to be serializable.""" - return dumper.represent_data( - protocol_datatype_abstraction.convert_to_protocol_datatype() - ) + return dumper.represent_data(model.model_dump(mode="python")) class OphydDumper(yaml.Dumper): @@ -146,8 +144,8 @@ def save_to_yaml(phases: Sequence[Dict[str, Any]], save_path: str) -> None: yaml.add_representer(np.ndarray, ndarray_representer, Dumper=yaml.Dumper) yaml.add_multi_representer( - ProtocolDatatypeAbstraction, - protocol_datatype_abstraction_representer, + BaseModel, + pydantic_model_abstraction_representer, Dumper=yaml.Dumper, ) diff --git a/src/ophyd_async/core/_signal_backend.py b/src/ophyd_async/core/_signal_backend.py index 8427ee502c..178ff1edfa 100644 --- a/src/ophyd_async/core/_signal_backend.py +++ b/src/ophyd_async/core/_signal_backend.py @@ -13,29 +13,6 @@ from ._utils import DEFAULT_TIMEOUT, ReadingValueCallback, T -class ProtocolDatatypeAbstraction(ABC, Generic[T]): - @abstractmethod - def __init__(self): - """The abstract datatype must be able to be intialized with no arguments.""" - - @abstractmethod - def convert_to_protocol_datatype(self) -> T: - """ - Convert the abstract datatype to a form which can be sent - over whichever protocol. - - This output will be used when the device is serialized. - """ - - @classmethod - @abstractmethod - def convert_from_protocol_datatype(cls, value: T) -> "ProtocolDatatypeAbstraction": - """ - Convert the datatype received from the protocol to a - higher level abstract datatype. 
- """ - - class BackendConverterFactory(ABC): """Convert between the signal backend and the signal type""" diff --git a/src/ophyd_async/core/_soft_signal_backend.py b/src/ophyd_async/core/_soft_signal_backend.py index 619400c8fd..26b85272a7 100644 --- a/src/ophyd_async/core/_soft_signal_backend.py +++ b/src/ophyd_async/core/_soft_signal_backend.py @@ -8,11 +8,11 @@ import numpy as np from bluesky.protocols import DataKey, Dtype, Reading +from pydantic import BaseModel from typing_extensions import TypedDict from ._signal_backend import ( BackendConverterFactory, - ProtocolDatatypeAbstraction, RuntimeSubsetEnum, SignalBackend, ) @@ -127,8 +127,10 @@ def make_initial_value(self, datatype: Optional[Type[T]]) -> T: return cast(T, self.choices[0]) -class SoftProtocolDatatypeAbstractionConverter(SoftConverter): - def __init__(self, datatype: Type[ProtocolDatatypeAbstraction]): +class SoftPydanticModelConverter(SoftConverter): + """Necessary for serializing soft signals.""" + + def __init__(self, datatype: Type[BaseModel]): self.datatype = datatype def reading(self, value: T, timestamp: float, severity: int) -> Reading: @@ -136,12 +138,16 @@ def reading(self, value: T, timestamp: float, severity: int) -> Reading: return super().reading(value, timestamp, severity) def value(self, value: Any) -> Any: - if not isinstance(value, self.datatype): - # For the case where we - value = self.datatype.convert_from_protocol_datatype(value) + if isinstance(value, dict): + value = self.datatype(**value) return value def write_value(self, value): + if isinstance(value, dict): + # If the device is being deserialized + return self.datatype(**value).model_dump(mode="python") + if isinstance(value, self.datatype): + return value.model_dump(mode="python") return value def make_initial_value(self, datatype: Type | None) -> Any: @@ -162,16 +168,16 @@ def make_converter(cls, datatype): is_enum = inspect.isclass(datatype) and ( issubclass(datatype, Enum) or issubclass(datatype, RuntimeSubsetEnum) ) - is_convertable_abstract_datatype = inspect.isclass(datatype) and issubclass( - datatype, ProtocolDatatypeAbstraction + is_pydantic_model = inspect.isclass(datatype) and issubclass( + datatype, BaseModel ) if is_array or is_sequence: return SoftArrayConverter() if is_enum: return SoftEnumConverter(datatype) - if is_convertable_abstract_datatype: - return SoftProtocolDatatypeAbstractionConverter(datatype) + if is_pydantic_model: + return SoftPydanticModelConverter(datatype) return SoftConverter() diff --git a/src/ophyd_async/epics/signal/__init__.py b/src/ophyd_async/epics/signal/__init__.py index 5da098b59e..f9bd58306f 100644 --- a/src/ophyd_async/epics/signal/__init__.py +++ b/src/ophyd_async/epics/signal/__init__.py @@ -1,5 +1,6 @@ from ._common import LimitPair, Limits, get_supported_values -from ._p4p import PvaSignalBackend, PvaTableAbstraction +from ._p4p import PvaSignalBackend +from ._p4p_table_abstraction import PvaTable from ._signal import ( epics_signal_r, epics_signal_rw, @@ -13,6 +14,7 @@ "LimitPair", "Limits", "PvaSignalBackend", + "PvaTable", "PvaTableAbstraction", "epics_signal_r", "epics_signal_rw", diff --git a/src/ophyd_async/epics/signal/_p4p_table_abstraction.py b/src/ophyd_async/epics/signal/_p4p_table_abstraction.py new file mode 100644 index 0000000000..a6e5ecf566 --- /dev/null +++ b/src/ophyd_async/epics/signal/_p4p_table_abstraction.py @@ -0,0 +1,70 @@ +from typing import Dict + +import numpy as np +from pydantic import BaseModel, ConfigDict, model_validator +from pydantic_numpy.typing import 
NpNDArray
+
+
+class PvaTable(BaseModel):
+    """An abstraction of a PVA table: a mapping of str to numpy array."""
+
+    model_config = ConfigDict(validate_assignment=True, strict=False)
+
+    @classmethod
+    def row(cls, sub_cls, **kwargs) -> "PvaTable":
+        arrayified_kwargs = {
+            field_name: np.concatenate(
+                (
+                    (default_arr := field_value.default_factory()),
+                    np.array([kwargs[field_name]], dtype=default_arr.dtype),
+                )
+            )
+            for field_name, field_value in sub_cls.model_fields.items()
+        }
+        return sub_cls(**arrayified_kwargs)
+
+    def __add__(self, right: "PvaTable") -> "PvaTable":
+        """Concatenate the arrays in field values."""
+
+        assert isinstance(right, type(self)), (
+            f"{right} is not a `PvaTable`, or is not the same "
+            f"type of `PvaTable` as {self}."
+        )
+
+        return type(self)(
+            **{
+                field_name: np.concatenate(
+                    (getattr(self, field_name), getattr(right, field_name))
+                )
+                for field_name in self.model_fields
+            }
+        )
+
+    @model_validator(mode="after")
+    def validate_arrays(self) -> "PvaTable":
+        first_length = len(next(iter(self))[1])
+        assert all(
+            len(field_value) == first_length for _, field_value in self
+        ), "Rows should all be of equal size."
+
+        assert 0 <= first_length < 4096, f"Length {first_length} not in range."
+
+        if not all(
+            np.issubdtype(
+                self.model_fields[field_name].default_factory().dtype, field_value.dtype
+            )
+            for field_name, field_value in self
+        ):
+            raise ValueError(
+                f"Cannot construct a `{type(self).__name__}`, "
+                "some rows have incorrect types."
+            )
+
+        return self
+
+    def convert_to_pva_datatype(self) -> Dict[str, NpNDArray]:
+        return self.model_dump(mode="python")
+
+    @classmethod
+    def convert_from_pva_datatype(cls, pva_table: Dict[str, NpNDArray]):
+        return cls(**pva_table)
diff --git a/src/ophyd_async/fastcs/panda/__init__.py b/src/ophyd_async/fastcs/panda/__init__.py
index 3724397053..0dbe7222b0 100644
--- a/src/ophyd_async/fastcs/panda/__init__.py
+++ b/src/ophyd_async/fastcs/panda/__init__.py
@@ -15,9 +15,7 @@
     DatasetTable,
     PandaHdf5DatasetType,
     SeqTable,
-    SeqTableRowType,
     SeqTrigger,
-    seq_table_row,
 )
 from ._trigger import (
     PcompInfo,
@@ -44,9 +42,7 @@
     "DatasetTable",
     "PandaHdf5DatasetType",
     "SeqTable",
-    "SeqTableRowType",
     "SeqTrigger",
-    "seq_table_row",
     "PcompInfo",
     "SeqTableInfo",
     "StaticPcompTriggerLogic",
diff --git a/src/ophyd_async/fastcs/panda/_table.py b/src/ophyd_async/fastcs/panda/_table.py
index 6dba645cd1..b1ed6d5729 100644
--- a/src/ophyd_async/fastcs/panda/_table.py
+++ b/src/ophyd_async/fastcs/panda/_table.py
@@ -1,13 +1,13 @@
 from enum import Enum
-from typing import Dict, Sequence, Union
+from typing import Annotated, Sequence
 
 import numpy as np
 import numpy.typing as npt
-import pydantic_numpy as pnd
-from pydantic import Field, RootModel, field_validator
+from pydantic import Field
+from pydantic_numpy.helper.annotation import NpArrayPydanticAnnotation
 from typing_extensions import TypedDict
 
-from ophyd_async.epics.signal import PvaTableAbstraction
+from ophyd_async.epics.signal import PvaTable
 
 
 class PandaHdf5DatasetType(str, Enum):
@@ -36,124 +36,119 @@ class SeqTrigger(str, Enum):
     POSC_LT = "POSC<=POSITION"
 
 
-SeqTableRowType = np.dtype(
-    [
-        ("repeats", np.int32),
-        ("trigger", "U14"),  # One of the SeqTrigger values
-        ("position", np.int32),
-        ("time1", np.int32),
-        ("outa1", np.bool_),
-        ("outb1", np.bool_),
-        ("outc1", np.bool_),
-        ("outd1", np.bool_),
-        ("oute1", np.bool_),
-        ("outf1", np.bool_),
-        ("time2", np.int32),
-        ("outa2", np.bool_),
-        ("outb2", np.bool_),
-        ("outc2", np.bool_),
-        ("outd2", np.bool_),
-        ("oute2",
np.bool_), - ("outf2", np.bool_), - ] -) - - -def seq_table_row( - *, - repeats: int = 0, - trigger: str = "", - position: int = 0, - time1: int = 0, - outa1: bool = False, - outb1: bool = False, - outc1: bool = False, - outd1: bool = False, - oute1: bool = False, - outf1: bool = False, - time2: int = 0, - outa2: bool = False, - outb2: bool = False, - outc2: bool = False, - outd2: bool = False, - oute2: bool = False, - outf2: bool = False, -) -> pnd.NpNDArray: - return np.array( - ( - repeats, - trigger, - position, - time1, - outa1, - outb1, - outc1, - outd1, - oute1, - outf1, - time2, - outa2, - outb2, - outc2, - outd2, - oute2, - outf2, - ), - dtype=SeqTableRowType, +PydanticNp1DArrayInt32 = Annotated[ + np.ndarray[tuple[int], np.int32], + NpArrayPydanticAnnotation.factory( + data_type=np.int32, dimensions=1, strict_data_typing=False + ), +] +PydanticNp1DArrayBool = Annotated[ + np.ndarray[tuple[int], np.bool_], + NpArrayPydanticAnnotation.factory( + data_type=np.bool_, dimensions=1, strict_data_typing=False + ), +] + +PydanticNp1DArrayUnicodeString = Annotated[ + np.ndarray[tuple[int], np.unicode_], + NpArrayPydanticAnnotation.factory( + data_type=np.unicode_, dimensions=1, strict_data_typing=False + ), +] + + +class SeqTable(PvaTable): + repeats: PydanticNp1DArrayInt32 = Field( + default_factory=lambda: np.array([], np.int32) ) - - -class SeqTable(RootModel, PvaTableAbstraction): - root: pnd.NpNDArray = Field( - default_factory=lambda: np.array([], dtype=SeqTableRowType), + trigger: PydanticNp1DArrayUnicodeString = Field( + default_factory=lambda: np.array([], dtype=np.dtype(" Dict[str, npt.ArrayLike]: - """Convert root to the column-wise dict representation for backend put""" - - if len(self.root) == 0: - transposed = { # list with empty arrays, each with correct dtype - name: np.array([], dtype=dtype) for name, dtype in SeqTableRowType.descr - } - else: - transposed_list = list(zip(*list(self.root))) - transposed = { - name: np.array(col, dtype=dtype) - for col, (name, dtype) in zip(transposed_list, SeqTableRowType.descr) - } - return transposed @classmethod - def convert_from_protocol_datatype( - cls, pva_table: Dict[str, npt.ArrayLike] + def row( + cls, + *, + repeats: int = 0, + trigger: str = "", + position: int = 0, + time1: int = 0, + outa1: bool = False, + outb1: bool = False, + outc1: bool = False, + outd1: bool = False, + oute1: bool = False, + outf1: bool = False, + time2: int = 0, + outa2: bool = False, + outb2: bool = False, + outc2: bool = False, + outd2: bool = False, + oute2: bool = False, + outf2: bool = False, ) -> "SeqTable": - """Convert a pva table to a row-wise SeqTable.""" - - ordered_columns = [ - np.array(pva_table[name], dtype=dtype) - for name, dtype in SeqTableRowType.descr - ] - - transposed = list(zip(*ordered_columns)) - rows = np.array([tuple(row) for row in transposed], dtype=SeqTableRowType) - return cls(rows) - - @field_validator("root", mode="before") - @classmethod - def check_valid_rows(cls, rows: Union[Sequence, np.ndarray]): - assert isinstance( - rows, (np.ndarray, list) - ), "Rows must be a list or numpy array." - - if not (0 <= len(rows) < 4096): - raise ValueError(f"Length {len(rows)} not in range.") - - if not all(isinstance(row, (np.ndarray, np.void)) for row in rows): - raise ValueError("Cannot construct a SeqTable, some rows are not arrays.") - - if not all(row.dtype is SeqTableRowType for row in rows): - raise ValueError( - "Cannot construct a SeqTable, some rows have incorrect types." 
- ) - - return np.array(rows, dtype=SeqTableRowType) + return PvaTable.row( + cls, + repeats=repeats, + trigger=trigger, + position=position, + time1=time1, + outa1=outa1, + outb1=outb1, + outc1=outc1, + outd1=outd1, + oute1=oute1, + outf1=outf1, + time2=time2, + outa2=outa2, + outb2=outb2, + outc2=outc2, + outd2=outd2, + oute2=oute2, + outf2=outf2, + ) diff --git a/src/ophyd_async/plan_stubs/_fly.py b/src/ophyd_async/plan_stubs/_fly.py index 04adf046a8..daa686b477 100644 --- a/src/ophyd_async/plan_stubs/_fly.py +++ b/src/ophyd_async/plan_stubs/_fly.py @@ -15,7 +15,6 @@ PcompInfo, SeqTable, SeqTableInfo, - seq_table_row, ) @@ -73,26 +72,26 @@ def prepare_static_seq_table_flyer_and_detectors_with_same_trigger( trigger_time = number_of_frames * (exposure + deadtime) pre_delay = max(period - 2 * shutter_time - trigger_time, 0) - table = SeqTable( - [ - # Wait for pre-delay then open shutter - seq_table_row( - time1=in_micros(pre_delay), - time2=in_micros(shutter_time), - outa2=True, - ), - # Keeping shutter open, do N triggers - seq_table_row( - repeats=number_of_frames, - time1=in_micros(exposure), - outa1=True, - outb1=True, - time2=in_micros(deadtime), - outa2=True, - ), - # Add the shutter close - seq_table_row(time2=in_micros(shutter_time)), - ] + table = ( + # Wait for pre-delay then open shutter + SeqTable.row( + time1=in_micros(pre_delay), + time2=in_micros(shutter_time), + outa2=True, + ) + + + # Keeping shutter open, do N triggers + SeqTable.row( + repeats=number_of_frames, + time1=in_micros(exposure), + outa1=True, + outb1=True, + time2=in_micros(deadtime), + outa2=True, + ) + + + # Add the shutter close + SeqTable.row(time2=in_micros(shutter_time)) ) table_info = SeqTableInfo(sequence_table=table, repeats=repeats) diff --git a/tests/core/test_device_save_loader.py b/tests/core/test_device_save_loader.py index 16d18696ec..eb5d54977b 100644 --- a/tests/core/test_device_save_loader.py +++ b/tests/core/test_device_save_loader.py @@ -8,6 +8,7 @@ import pytest import yaml from bluesky.run_engine import RunEngine +from pydantic import BaseModel, Field from ophyd_async.core import ( Device, @@ -22,7 +23,6 @@ set_signal_values, walk_rw_signals, ) -from ophyd_async.core._signal_backend import ProtocolDatatypeAbstraction from ophyd_async.epics.signal import epics_signal_r, epics_signal_rw @@ -55,16 +55,8 @@ class MyEnum(str, Enum): three = "three" -class SomeProtocolDatatypeAbstraction(ProtocolDatatypeAbstraction): - def __init__(self, value: int): - self.value = value - - def convert_to_protocol_datatype(self) -> int: - return self.value - 1 - - @classmethod - def convert_from_protocol_datatype(cls, value: int) -> "SomeProtocolDatatypeAbstraction": - return cls(value + 1) +class SomePvaPydanticModel(BaseModel): + some_field: int = Field(default=1) class DummyDeviceGroupAllTypes(Device): @@ -86,7 +78,9 @@ def __init__(self, name: str): self.pv_array_float64 = epics_signal_rw(npt.NDArray[np.float64], "PV14") self.pv_array_npstr = epics_signal_rw(npt.NDArray[np.str_], "PV15") self.pv_array_str = epics_signal_rw(Sequence[str], "PV16") - self.pv_protocol_device_abstraction = epics_signal_rw(SomeProtocolDatatypeAbstraction, "PV17") + self.pv_protocol_device_abstraction = epics_signal_rw( + SomePvaPydanticModel, "pva://PV17" + ) @pytest.fixture @@ -170,7 +164,7 @@ async def test_save_device_all_types(RE: RunEngine, device_all_types, tmp_path): ["one", "two", "three"], ) await device_all_types.pv_protocol_device_abstraction.set( - SomeProtocolDatatypeAbstraction(1) + 
SomePvaPydanticModel(some_field=1) ) # Create save plan from utility functions diff --git a/tests/fastcs/panda/test_panda_connect.py b/tests/fastcs/panda/test_panda_connect.py index a4bc23d41c..b7ce1b233d 100644 --- a/tests/fastcs/panda/test_panda_connect.py +++ b/tests/fastcs/panda/test_panda_connect.py @@ -20,7 +20,6 @@ SeqBlock, SeqTable, SeqTrigger, - seq_table_row, ) @@ -91,85 +90,83 @@ def test_panda_name_set(panda_t): async def test_panda_children_connected(mock_panda): # try to set and retrieve from simulated values... - table = table = SeqTable( - [ - seq_table_row( - repeats=1, - trigger=SeqTrigger.POSA_GT, - position=3222, - time1=5, - outa1=True, - outb1=False, - outc1=False, - outd1=True, - oute1=True, - outf1=True, - time2=0, - outa2=True, - outb2=False, - outc2=False, - outd2=True, - oute2=True, - outf2=True, - ), - seq_table_row( - repeats=1, - trigger=SeqTrigger.POSA_LT, - position=-565, - time1=0, - outa1=False, - outb1=False, - outc1=True, - outd1=True, - oute1=False, - outf1=False, - time2=10, - outa2=False, - outb2=False, - outc2=True, - outd2=True, - oute2=False, - outf2=False, - ), - seq_table_row( - repeats=1, - trigger=SeqTrigger.IMMEDIATE, - position=0, - time1=10, - outa1=False, - outb1=True, - outc1=True, - outd1=False, - oute1=True, - outf1=False, - time2=10, - outa2=False, - outb2=True, - outc2=True, - outd2=False, - oute2=True, - outf2=False, - ), - seq_table_row( - repeats=32, - trigger=SeqTrigger.IMMEDIATE, - position=0, - time1=10, - outa1=True, - outb1=True, - outc1=False, - outd1=True, - oute1=False, - outf1=False, - time2=11, - outa2=True, - outb2=True, - outc2=False, - outd2=True, - oute2=False, - outf2=False, - ), - ] + table = ( + SeqTable.row( + repeats=1, + trigger=SeqTrigger.POSA_GT, + position=3222, + time1=5, + outa1=True, + outb1=False, + outc1=False, + outd1=True, + oute1=True, + outf1=True, + time2=0, + outa2=True, + outb2=False, + outc2=False, + outd2=True, + oute2=True, + outf2=True, + ) + + SeqTable.row( + repeats=1, + trigger=SeqTrigger.POSA_LT, + position=-565, + time1=0, + outa1=False, + outb1=False, + outc1=True, + outd1=True, + oute1=False, + outf1=False, + time2=10, + outa2=False, + outb2=False, + outc2=True, + outd2=True, + oute2=False, + outf2=False, + ) + + SeqTable.row( + repeats=1, + trigger=SeqTrigger.IMMEDIATE, + position=0, + time1=10, + outa1=False, + outb1=True, + outc1=True, + outd1=False, + oute1=True, + outf1=False, + time2=10, + outa2=False, + outb2=True, + outc2=True, + outd2=False, + oute2=True, + outf2=False, + ) + + SeqTable.row( + repeats=32, + trigger=SeqTrigger.IMMEDIATE, + position=0, + time1=10, + outa1=True, + outb1=True, + outc1=False, + outd1=True, + oute1=False, + outf1=False, + time2=11, + outa2=True, + outb2=True, + outc2=False, + outd2=True, + oute2=False, + outf2=False, + ) ) await mock_panda.pulse[1].delay.set(20.0) await mock_panda.seq[1].table.set(table) diff --git a/tests/fastcs/panda/test_panda_utils.py b/tests/fastcs/panda/test_panda_utils.py index a27c3e8d56..2e45f18bd4 100644 --- a/tests/fastcs/panda/test_panda_utils.py +++ b/tests/fastcs/panda/test_panda_utils.py @@ -1,4 +1,3 @@ - import numpy as np from bluesky import RunEngine @@ -10,7 +9,6 @@ DataBlock, SeqTable, phase_sorter, - seq_table_row, ) @@ -36,27 +34,25 @@ async def connect(self, mock: bool = False, timeout: float = DEFAULT_TIMEOUT): async def test_save_load_panda(tmp_path, RE: RunEngine): mock_panda1 = await get_mock_panda() - await mock_panda1.seq[1].table.set(SeqTable([seq_table_row(repeats=1)])) + await 
mock_panda1.seq[1].table.set(SeqTable.row(repeats=1)) RE(save_device(mock_panda1, str(tmp_path / "panda.yaml"), sorter=phase_sorter)) def check_equal_with_seq_tables(actual, expected): - assert set(actual.keys()) == set(expected.keys()) - for key, value1 in actual.items(): - value2 = expected[key] - if isinstance(value1, SeqTable): - assert np.array_equal(value1.root, value2.root) - else: - assert value1 == value2 + assert actual.model_fields_set == expected.model_fields_set + for field_name, field_value1 in actual: + field_value2 = getattr(expected, field_name) + assert np.array_equal(field_value1, field_value2) mock_panda2 = await get_mock_panda() - assert np.array_equal( - (await mock_panda2.seq[1].table.get_value()).root, SeqTable([]).root + check_equal_with_seq_tables( + (await mock_panda2.seq[1].table.get_value()), SeqTable() ) RE(load_device(mock_panda2, str(tmp_path / "panda.yaml"))) - assert np.array_equal( - (await mock_panda2.seq[1].table.get_value()).root, - SeqTable([seq_table_row(repeats=1)]).root, + + check_equal_with_seq_tables( + await mock_panda2.seq[1].table.get_value(), + SeqTable.row(repeats=1), ) """ diff --git a/tests/fastcs/panda/test_table.py b/tests/fastcs/panda/test_table.py index 88bacabca4..1024283618 100644 --- a/tests/fastcs/panda/test_table.py +++ b/tests/fastcs/panda/test_table.py @@ -1,57 +1,86 @@ +from functools import reduce + import numpy as np import pytest from pydantic import ValidationError -from ophyd_async.fastcs.panda import SeqTable, SeqTableRowType, seq_table_row +from ophyd_async.fastcs.panda import SeqTable -@pytest.mark.parametrize( - # factory so that there aren't global errors if seq_table_row() fails - "rows_arg_factory", - [ - lambda: None, - list, - lambda: [seq_table_row(), seq_table_row()], - lambda: np.array([seq_table_row(), seq_table_row()]), - ], -) -def test_seq_table_initialization_allowed_args(rows_arg_factory): - rows_arg = rows_arg_factory() - seq_table = SeqTable() if rows_arg is None else SeqTable(rows_arg) - assert isinstance(seq_table.root, np.ndarray) - assert len(seq_table.root) == (0 if rows_arg is None else len(rows_arg)) +def test_seq_table_converts_lists(): + seq_table_dict_with_lists = {field_name: [] for field_name, _ in SeqTable()} + # Validation passes + seq_table = SeqTable(**seq_table_dict_with_lists) + assert isinstance(seq_table.trigger, np.ndarray) + assert seq_table.trigger.dtype == np.dtype("U32") def test_seq_table_validation_errors(): - with pytest.raises( - ValueError, match="Cannot construct a SeqTable, some rows are not arrays." 
- ): - SeqTable([seq_table_row().tolist()]) - with pytest.raises(ValidationError, match="Length 4098 not in range."): - SeqTable([seq_table_row() for _ in range(4098)]) + with pytest.raises(ValidationError, match="81 validation errors for SeqTable"): + SeqTable( + repeats=0, + trigger="", + position=0, + time1=0, + outa1=False, + outb1=False, + outc1=False, + outd1=False, + oute1=False, + outf1=False, + time2=0, + outa2=False, + outb2=False, + outc2=False, + outd2=False, + oute2=False, + outf2=False, + ) + + large_seq_table = SeqTable( + repeats=np.zeros(4095, dtype=np.int32), + trigger=np.array([""] * 4095, dtype="U32"), + position=np.zeros(4095, dtype=np.int32), + time1=np.zeros(4095, dtype=np.int32), + outa1=np.zeros(4095, dtype=np.bool_), + outb1=np.zeros(4095, dtype=np.bool_), + outc1=np.zeros(4095, dtype=np.bool_), + outd1=np.zeros(4095, dtype=np.bool_), + oute1=np.zeros(4095, dtype=np.bool_), + outf1=np.zeros(4095, dtype=np.bool_), + time2=np.zeros(4095, dtype=np.int32), + outa2=np.zeros(4095, dtype=np.bool_), + outb2=np.zeros(4095, dtype=np.bool_), + outc2=np.zeros(4095, dtype=np.bool_), + outd2=np.zeros(4095, dtype=np.bool_), + oute2=np.zeros(4095, dtype=np.bool_), + outf2=np.zeros(4095, dtype=np.bool_), + ) with pytest.raises( ValidationError, - match="Cannot construct a SeqTable, some rows have incorrect types.", + match=( + "1 validation error for SeqTable\n " + "Assertion failed, Length 4096 not in range." + ), ): - SeqTable([seq_table_row(), np.array([1, 2, 3]), seq_table_row()]) + large_seq_table + SeqTable.row() with pytest.raises( ValidationError, - match="Cannot construct a SeqTable, some rows have incorrect types.", + match="12 validation errors for SeqTable", ): - SeqTable( - [ - seq_table_row(), - np.array(range(len(seq_table_row().tolist()))), - seq_table_row(), - ] - ) + row_one = SeqTable.row() + wrong_types = { + field_name: field_value.astype(np.unicode_) + for field_name, field_value in row_one + } + SeqTable(**wrong_types) def test_seq_table_pva_conversion(): expected_pva_dict = { "repeats": np.array([1, 2, 3, 4], dtype=np.int32), "trigger": np.array( - ["Immediate", "Immediate", "BITC=0", "Immediate"], dtype="U14" + ["Immediate", "Immediate", "BITC=0", "Immediate"], dtype=np.dtype("U32") ), "position": np.array([1, 2, 3, 4], dtype=np.int32), "time1": np.array([1, 0, 1, 0], dtype=np.int32), @@ -69,51 +98,102 @@ def test_seq_table_pva_conversion(): "oute2": np.array([1, 0, 1, 0], dtype=np.bool_), "outf2": np.array([1, 0, 1, 0], dtype=np.bool_), } - expected_numpy_table = np.array( - [ - (1, "Immediate", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1), - (2, "Immediate", 2, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0), - (3, "BITC=0", 3, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1), - (4, "Immediate", 4, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0), - ], - dtype=SeqTableRowType, - ) + expected_row_wise_dict = [ + { + "repeats": 1, + "trigger": "Immediate", + "position": 1, + "time1": 1, + "outa1": 1, + "outb1": 1, + "outc1": 1, + "outd1": 1, + "oute1": 1, + "outf1": 1, + "time2": 1, + "outa2": 1, + "outb2": 1, + "outc2": 1, + "outd2": 1, + "oute2": 1, + "outf2": 1, + }, + { + "repeats": 2, + "trigger": "Immediate", + "position": 2, + "time1": 0, + "outa1": 0, + "outb1": 0, + "outc1": 0, + "outd1": 0, + "oute1": 0, + "outf1": 0, + "time2": 2, + "outa2": 0, + "outb2": 0, + "outc2": 0, + "outd2": 0, + "oute2": 0, + "outf2": 0, + }, + { + "repeats": 3, + "trigger": "BITC=0", + "position": 3, + "time1": 1, + "outa1": 1, + "outb1": 1, + "outc1": 1, + "outd1": 1, + "oute1": 1, + 
"outf1": 1, + "time2": 3, + "outa2": 1, + "outb2": 1, + "outc2": 1, + "outd2": 1, + "oute2": 1, + "outf2": 1, + }, + { + "repeats": 4, + "trigger": "Immediate", + "position": 4, + "time1": 0, + "outa1": 0, + "outb1": 0, + "outc1": 0, + "outd1": 0, + "oute1": 0, + "outf1": 0, + "time2": 4, + "outa2": 0, + "outb2": 0, + "outc2": 0, + "outd2": 0, + "oute2": 0, + "outf2": 0, + }, + ] - # Can convert from PVA table - numpy_table_from_pva_dict = SeqTable.convert_from_protocol_datatype( - expected_pva_dict - ) - assert np.array_equal(numpy_table_from_pva_dict.root, expected_numpy_table) - assert ( - numpy_table_from_pva_dict.root.dtype - == expected_numpy_table.dtype - == SeqTableRowType - ) - - # Can convert to PVA table - pva_dict_from_numpy_table = SeqTable( - expected_numpy_table - ).convert_to_protocol_datatype() - for column1, column2 in zip( - pva_dict_from_numpy_table.values(), expected_pva_dict.values() + seq_table_from_pva_dict = SeqTable(**expected_pva_dict) + for (_, column1), column2 in zip( + seq_table_from_pva_dict, expected_pva_dict.values() ): assert np.array_equal(column1, column2) assert column1.dtype == column2.dtype - # Idempotency - applied_twice_to_numpy_table = SeqTable.convert_from_protocol_datatype( - SeqTable(expected_numpy_table).convert_to_protocol_datatype() - ) - assert np.array_equal(applied_twice_to_numpy_table.root, expected_numpy_table) - assert ( - applied_twice_to_numpy_table.root.dtype - == expected_numpy_table.dtype - == SeqTableRowType + seq_table_from_rows = reduce( + lambda x, y: x + y, + [SeqTable.row(**row_kwargs) for row_kwargs in expected_row_wise_dict], ) + for (_, column1), column2 in zip(seq_table_from_rows, expected_pva_dict.values()): + assert np.array_equal(column1, column2) + assert column1.dtype == column2.dtype - applied_twice_to_pva_dict = SeqTable( - SeqTable.convert_from_protocol_datatype(expected_pva_dict).root - ).convert_to_protocol_datatype() + # Idempotency + applied_twice_to_pva_dict = SeqTable(**expected_pva_dict).model_dump(mode="python") for column1, column2 in zip( applied_twice_to_pva_dict.values(), expected_pva_dict.values() ): diff --git a/tests/fastcs/panda/test_trigger.py b/tests/fastcs/panda/test_trigger.py index a12ee32200..d7a6e039fb 100644 --- a/tests/fastcs/panda/test_trigger.py +++ b/tests/fastcs/panda/test_trigger.py @@ -12,7 +12,6 @@ SeqTableInfo, StaticPcompTriggerLogic, StaticSeqTableTriggerLogic, - seq_table_row, ) @@ -38,13 +37,11 @@ async def connect(self, mock: bool = False, timeout: float = DEFAULT_TIMEOUT): async def test_seq_table_trigger_logic(mock_panda): trigger_logic = StaticSeqTableTriggerLogic(mock_panda.seq[1]) - seq_table = SeqTable( - [ - seq_table_row(outa1=True, outa2=True), - seq_table_row(outa1=False, outa2=False), - seq_table_row(outa1=True, outa2=False), - seq_table_row(outa1=False, outa2=True), - ] + seq_table = ( + SeqTable.row(outa1=True, outa2=True) + + SeqTable.row(outa1=False, outa2=False) + + SeqTable.row(outa1=True, outa2=False) + + SeqTable.row(outa1=False, outa2=True) ) seq_table_info = SeqTableInfo(sequence_table=seq_table, repeats=1) @@ -81,7 +78,7 @@ async def set_active(value: bool): [ ( { - "sequence_table": SeqTable([seq_table_row(outc2=1)]), + "sequence_table_factory": lambda: SeqTable.row(outc2=1), "repeats": 0, "prescale_as_us": -1, }, @@ -90,13 +87,11 @@ async def set_active(value: bool): ), ( { - "sequence_table": SeqTable( - [ - seq_table_row(outc2=True), - seq_table_row(outc2=False), - seq_table_row(outc2=True), - seq_table_row(outc2=False), - ] + 
"sequence_table_factory": lambda: ( + SeqTable.row(outc2=True) + + SeqTable.row(outc2=False) + + SeqTable.row(outc2=True) + + SeqTable.row(outc2=False) ), "repeats": -1, }, @@ -105,15 +100,16 @@ async def set_active(value: bool): ), ( { - "sequence_table": 1, + "sequence_table_factory": lambda: 1, "repeats": 1, }, - "Assertion failed, Rows must be a list or numpy array. " - "[type=assertion_error, input_value=1, input_type=int]", + "Input should be a valid dictionary or instance of SeqTable " + "[type=model_type, input_value=1, input_type=int]", ), ], ) def test_malformed_seq_table_info(kwargs, error_msg): + kwargs["sequence_table"] = kwargs.pop("sequence_table_factory")() with pytest.raises(ValidationError) as exc: SeqTableInfo(**kwargs) assert error_msg in str(exc.value) diff --git a/tests/test_data/test_yaml_save.yml b/tests/test_data/test_yaml_save.yml index 349536ecdc..b6872436f9 100644 --- a/tests/test_data/test_yaml_save.yml +++ b/tests/test_data/test_yaml_save.yml @@ -19,5 +19,6 @@ pv_enum_str: two pv_float: 1.234 pv_int: 1 + pv_protocol_device_abstraction: + some_field: 1 pv_str: test_string - pv_protocol_device_abstraction: 0 From c4bece5b0480656f6c046579ac46c1e2260c7f2b Mon Sep 17 00:00:00 2001 From: Eva Lott Date: Wed, 21 Aug 2024 14:14:30 +0100 Subject: [PATCH 06/11] cleaned up after rebase --- src/ophyd_async/core/_device_save_loader.py | 2 - src/ophyd_async/core/_soft_signal_backend.py | 5 - src/ophyd_async/epics/signal/__init__.py | 3 +- src/ophyd_async/epics/signal/_aioca.py | 4 +- src/ophyd_async/epics/signal/_p4p.py | 45 +++---- ...ble_abstraction.py => _p4p_table_model.py} | 14 +- tests/core/test_signal.py | 6 +- tests/fastcs/panda/test_panda_utils.py | 123 ++++++++++++------ tests/fastcs/panda/test_table.py | 20 ++- 9 files changed, 119 insertions(+), 103 deletions(-) rename src/ophyd_async/epics/signal/{_p4p_table_abstraction.py => _p4p_table_model.py} (81%) diff --git a/src/ophyd_async/core/_device_save_loader.py b/src/ophyd_async/core/_device_save_loader.py index a40c404d50..d847caff69 100644 --- a/src/ophyd_async/core/_device_save_loader.py +++ b/src/ophyd_async/core/_device_save_loader.py @@ -22,8 +22,6 @@ def ndarray_representer(dumper: yaml.Dumper, array: npt.NDArray[Any]) -> yaml.No def pydantic_model_abstraction_representer( dumper: yaml.Dumper, model: BaseModel ) -> yaml.Node: - """Uses the protocol datatype since it has to be serializable.""" - return dumper.represent_data(model.model_dump(mode="python")) diff --git a/src/ophyd_async/core/_soft_signal_backend.py b/src/ophyd_async/core/_soft_signal_backend.py index 26b85272a7..d35204795b 100644 --- a/src/ophyd_async/core/_soft_signal_backend.py +++ b/src/ophyd_async/core/_soft_signal_backend.py @@ -128,8 +128,6 @@ def make_initial_value(self, datatype: Optional[Type[T]]) -> T: class SoftPydanticModelConverter(SoftConverter): - """Necessary for serializing soft signals.""" - def __init__(self, datatype: Type[BaseModel]): self.datatype = datatype @@ -143,9 +141,6 @@ def value(self, value: Any) -> Any: return value def write_value(self, value): - if isinstance(value, dict): - # If the device is being deserialized - return self.datatype(**value).model_dump(mode="python") if isinstance(value, self.datatype): return value.model_dump(mode="python") return value diff --git a/src/ophyd_async/epics/signal/__init__.py b/src/ophyd_async/epics/signal/__init__.py index f9bd58306f..a249d41cf0 100644 --- a/src/ophyd_async/epics/signal/__init__.py +++ b/src/ophyd_async/epics/signal/__init__.py @@ -1,6 +1,6 @@ from ._common 
import LimitPair, Limits, get_supported_values from ._p4p import PvaSignalBackend -from ._p4p_table_abstraction import PvaTable +from ._p4p_table_model import PvaTable from ._signal import ( epics_signal_r, epics_signal_rw, @@ -15,7 +15,6 @@ "Limits", "PvaSignalBackend", "PvaTable", - "PvaTableAbstraction", "epics_signal_r", "epics_signal_rw", "epics_signal_rw_rbv", diff --git a/src/ophyd_async/epics/signal/_aioca.py b/src/ophyd_async/epics/signal/_aioca.py index aa15d5a570..4a34ff4367 100644 --- a/src/ophyd_async/epics/signal/_aioca.py +++ b/src/ophyd_async/epics/signal/_aioca.py @@ -261,7 +261,9 @@ def make_converter( # Allow int signals to represent float records when prec is 0 is_prec_zero_float = ( isinstance(value, float) - and get_unique({k: v.precision for k, v in values.items()}, "precision") + and get_unique( + {k: v.precision for k, v in values.items()}, "precision" + ) == 0 ) if not (datatype is int and is_prec_zero_float): diff --git a/src/ophyd_async/epics/signal/_p4p.py b/src/ophyd_async/epics/signal/_p4p.py index 3f95740522..cb1502cea0 100644 --- a/src/ophyd_async/epics/signal/_p4p.py +++ b/src/ophyd_async/epics/signal/_p4p.py @@ -67,7 +67,7 @@ def _data_key_from_value( *, shape: Optional[list[int]] = None, choices: Optional[list[str]] = None, - dtype: Optional[str] = None, + dtype: Optional[Dtype] = None, ) -> DataKey: """ Args: @@ -256,6 +256,19 @@ def get_datakey(self, source: str, value) -> DataKey: return _data_key_from_value(source, value, dtype="object") +class PvaPydanticModelConverter(PvaConverter): + def __init__(self, datatype: BaseModel): + self.datatype = datatype + + def value(self, value: Value): + return self.datatype(**value.todict()) + + def write_value(self, value: Union[BaseModel, Dict[str, Any]]): + if isinstance(value, self.datatype): + return value.model_dump(mode="python") + return value + + class PvaDictConverter(PvaConverter): def reading(self, value): ts = time.time() @@ -287,28 +300,6 @@ def __getattribute__(self, __name: str) -> Any: raise NotImplementedError("No PV has been set as connect() has not been called") -class PvaPydanticModelConverter(PvaConverter): - def __init__(self, datatype: BaseModel): - self.datatype = datatype - - def reading(self, value: Value): - ts = time.time() - value = self.value(value) - return {"value": value, "timestamp": ts, "alarm_severity": 0} - - def value(self, value: Value): - return self.datatype(**value.todict()) - - def write_value(self, value: Union[BaseModel, Dict[str, Any]]): - """ - A user can put whichever form to the signal. - This is required for yaml deserialization. 
- """ - if isinstance(value, self.datatype): - return value.model_dump(mode="python") - return value - - class PvaConverterFactory(BackendConverterFactory): _ALLOWED_TYPES = ( bool, @@ -398,17 +389,19 @@ def make_converter( == 0 ) if not (datatype is int and is_prec_zero_float): - raise TypeError(f"{pv} has type {typ.__name__} not {datatype.__name__}") + raise TypeError( + f"{pv} has type {typ.__name__} not {datatype.__name__}" + ) return PvaConverter() elif "NTTable" in typeid: - return PvaTableConverter() - elif "structure" in typeid: if ( datatype and inspect.isclass(datatype) and issubclass(datatype, BaseModel) ): return PvaPydanticModelConverter(datatype) + return PvaTableConverter() + elif "structure" in typeid: return PvaDictConverter() else: raise TypeError(f"{pv}: Unsupported typeid {typeid}") diff --git a/src/ophyd_async/epics/signal/_p4p_table_abstraction.py b/src/ophyd_async/epics/signal/_p4p_table_model.py similarity index 81% rename from src/ophyd_async/epics/signal/_p4p_table_abstraction.py rename to src/ophyd_async/epics/signal/_p4p_table_model.py index a6e5ecf566..49d115903b 100644 --- a/src/ophyd_async/epics/signal/_p4p_table_abstraction.py +++ b/src/ophyd_async/epics/signal/_p4p_table_model.py @@ -1,12 +1,9 @@ -from typing import Dict - import numpy as np from pydantic import BaseModel, ConfigDict, model_validator -from pydantic_numpy.typing import NpNDArray class PvaTable(BaseModel): - """An abstraction of a PVA Table of str to python array.""" + """An abstraction of a PVA Table of str to numpy array.""" model_config = ConfigDict(validate_assignment=True, strict=False) @@ -24,7 +21,7 @@ def row(cls, sub_cls, **kwargs) -> "PvaTable": return sub_cls(**arrayified_kwargs) def __add__(self, right: "PvaTable") -> "PvaTable": - """Concatinate the arrays in field values.""" + """Concatenate the arrays in field values.""" assert isinstance(right, type(self)), ( f"{right} is not a `PvaTable`, or is not the same " @@ -61,10 +58,3 @@ def validate_arrays(self) -> "PvaTable": ) return self - - def convert_to_pva_datatype(self) -> Dict[str, NpNDArray]: - return self.model_dump(mode="python") - - @classmethod - def convert_from_pva_datatype(cls, pva_table: Dict[str, NpNDArray]): - return cls(**pva_table) diff --git a/tests/core/test_signal.py b/tests/core/test_signal.py index 3adafdd9f8..ab5c02cffe 100644 --- a/tests/core/test_signal.py +++ b/tests/core/test_signal.py @@ -413,7 +413,6 @@ def __init__(self): def some_function(self): pass - # with pytest.raises(ValueError, match="Unknown datatype 'SomeClass'"): err_str = ( "Given datatype .SomeClass'>" @@ -425,4 +424,7 @@ def some_function(self): epics_signal_rw(SomeClass, "ca://mock_signal", name="mock_signal") # Any dtype allowed in soft signal - soft_signal_rw(SomeClass, SomeClass(), "soft_signal") + signal = soft_signal_rw(SomeClass, SomeClass(), "soft_signal") + assert isinstance((await signal.get_value()), SomeClass) + await signal.set(1) + assert (await signal.get_value()) == 1 diff --git a/tests/fastcs/panda/test_panda_utils.py b/tests/fastcs/panda/test_panda_utils.py index 2e45f18bd4..f79ac0442c 100644 --- a/tests/fastcs/panda/test_panda_utils.py +++ b/tests/fastcs/panda/test_panda_utils.py @@ -1,4 +1,5 @@ import numpy as np +import yaml from bluesky import RunEngine from ophyd_async.core import DEFAULT_TIMEOUT, DeviceCollector, load_device, save_device @@ -7,7 +8,9 @@ from ophyd_async.fastcs.panda import ( CommonPandaBlocks, DataBlock, + PcompDirectionOptions, SeqTable, + TimeUnits, phase_sorter, ) @@ -55,48 +58,86 @@ def 
check_equal_with_seq_tables(actual, expected): SeqTable.row(repeats=1), ) - """ - assert mock_save_to_yaml.call_args[0][0][0] == { + # Load the YAML content as a string + with open(str(tmp_path / "panda.yaml"), "r") as file: + yaml_content = file.read() + + # Parse the YAML content + parsed_yaml = yaml.safe_load(yaml_content) + + assert parsed_yaml[0] == { "phase_1_signal_units": 0, "seq.1.prescale_units": TimeUnits("min"), "seq.2.prescale_units": TimeUnits("min"), } - check_equal_with_seq_tables(mock_save_to_yaml.call_args[0][0][1], - { - "data.capture": False, - "data.create_directory": 0, - "data.flush_period": 0.0, - "data.hdf_directory": "", - "data.hdf_file_name": "", - "data.num_capture": 0, - "pcap.arm": False, - "pcomp.1.dir": PcompDirectionOptions.positive, - "pcomp.1.enable": "ZERO", - "pcomp.1.pulses": 0, - "pcomp.1.start": 0, - "pcomp.1.step": 0, - "pcomp.1.width": 0, - "pcomp.2.dir": PcompDirectionOptions.positive, - "pcomp.2.enable": "ZERO", - "pcomp.2.pulses": 0, - "pcomp.2.start": 0, - "pcomp.2.step": 0, - "pcomp.2.width": 0, - "pulse.1.delay": 0.0, - "pulse.1.width": 0.0, - "pulse.2.delay": 0.0, - "pulse.2.width": 0.0, - "seq.1.active": False, - "seq.1.table": SeqTable([]), - "seq.1.repeats": 0, - "seq.1.prescale": 0.0, - "seq.1.enable": "ZERO", - "seq.2.table": SeqTable([]), - "seq.2.active": False, - "seq.2.repeats": 0, - "seq.2.prescale": 0.0, - "seq.2.enable": "ZERO", - }, - ) - assert mock_save_to_yaml.call_args[0][1] == "path" - """ + assert parsed_yaml[1] == { + "data.capture": False, + "data.create_directory": 0, + "data.flush_period": 0.0, + "data.hdf_directory": "", + "data.hdf_file_name": "", + "data.num_capture": 0, + "pcap.arm": False, + "pcomp.1.dir": PcompDirectionOptions.positive, + "pcomp.1.enable": "ZERO", + "pcomp.1.pulses": 0, + "pcomp.1.start": 0, + "pcomp.1.step": 0, + "pcomp.1.width": 0, + "pcomp.2.dir": PcompDirectionOptions.positive, + "pcomp.2.enable": "ZERO", + "pcomp.2.pulses": 0, + "pcomp.2.start": 0, + "pcomp.2.step": 0, + "pcomp.2.width": 0, + "pulse.1.delay": 0.0, + "pulse.1.width": 0.0, + "pulse.2.delay": 0.0, + "pulse.2.width": 0.0, + "seq.1.active": False, + "seq.1.table": { + "outa1": [False], + "outa2": [False], + "outb1": [False], + "outb2": [False], + "outc1": [False], + "outc2": [False], + "outd1": [False], + "outd2": [False], + "oute1": [False], + "oute2": [False], + "outf1": [False], + "outf2": [False], + "position": [0], + "repeats": [1], + "time1": [0], + "time2": [0], + "trigger": [""], + }, + "seq.1.repeats": 0, + "seq.1.prescale": 0.0, + "seq.1.enable": "ZERO", + "seq.2.table": { + "outa1": [], + "outa2": [], + "outb1": [], + "outb2": [], + "outc1": [], + "outc2": [], + "outd1": [], + "outd2": [], + "oute1": [], + "oute2": [], + "outf1": [], + "outf2": [], + "position": [], + "repeats": [], + "time1": [], + "time2": [], + "trigger": [], + }, + "seq.2.active": False, + "seq.2.repeats": 0, + "seq.2.prescale": 0.0, + "seq.2.enable": "ZERO", + } diff --git a/tests/fastcs/panda/test_table.py b/tests/fastcs/panda/test_table.py index 1024283618..ba5ad3e2f9 100644 --- a/tests/fastcs/panda/test_table.py +++ b/tests/fastcs/panda/test_table.py @@ -77,7 +77,7 @@ def test_seq_table_validation_errors(): def test_seq_table_pva_conversion(): - expected_pva_dict = { + pva_dict = { "repeats": np.array([1, 2, 3, 4], dtype=np.int32), "trigger": np.array( ["Immediate", "Immediate", "BITC=0", "Immediate"], dtype=np.dtype("U32") @@ -98,7 +98,7 @@ def test_seq_table_pva_conversion(): "oute2": np.array([1, 0, 1, 0], dtype=np.bool_), "outf2": 
np.array([1, 0, 1, 0], dtype=np.bool_), } - expected_row_wise_dict = [ + row_wise_dicts = [ { "repeats": 1, "trigger": "Immediate", @@ -177,25 +177,21 @@ def test_seq_table_pva_conversion(): }, ] - seq_table_from_pva_dict = SeqTable(**expected_pva_dict) - for (_, column1), column2 in zip( - seq_table_from_pva_dict, expected_pva_dict.values() - ): + seq_table_from_pva_dict = SeqTable(**pva_dict) + for (_, column1), column2 in zip(seq_table_from_pva_dict, pva_dict.values()): assert np.array_equal(column1, column2) assert column1.dtype == column2.dtype seq_table_from_rows = reduce( lambda x, y: x + y, - [SeqTable.row(**row_kwargs) for row_kwargs in expected_row_wise_dict], + [SeqTable.row(**row_kwargs) for row_kwargs in row_wise_dicts], ) - for (_, column1), column2 in zip(seq_table_from_rows, expected_pva_dict.values()): + for (_, column1), column2 in zip(seq_table_from_rows, pva_dict.values()): assert np.array_equal(column1, column2) assert column1.dtype == column2.dtype # Idempotency - applied_twice_to_pva_dict = SeqTable(**expected_pva_dict).model_dump(mode="python") - for column1, column2 in zip( - applied_twice_to_pva_dict.values(), expected_pva_dict.values() - ): + applied_twice_to_pva_dict = SeqTable(**pva_dict).model_dump(mode="python") + for column1, column2 in zip(applied_twice_to_pva_dict.values(), pva_dict.values()): assert np.array_equal(column1, column2) assert column1.dtype == column2.dtype From 3af98831a2e31cc7bd243813a7fc4ccd1da5e0d7 Mon Sep 17 00:00:00 2001 From: Eva Lott Date: Tue, 27 Aug 2024 13:35:26 +0100 Subject: [PATCH 07/11] made suggested changes Still need to add datatype to the args of signal --- src/ophyd_async/core/__init__.py | 2 - src/ophyd_async/core/_signal_backend.py | 26 +-- src/ophyd_async/core/_soft_signal_backend.py | 49 +++--- src/ophyd_async/epics/signal/_aioca.py | 165 +++++++++-------- src/ophyd_async/epics/signal/_p4p.py | 175 +++++++++---------- src/ophyd_async/fastcs/panda/_table.py | 71 +++----- tests/core/test_device_save_loader.py | 13 +- tests/core/test_subset_enum.py | 8 +- tests/fastcs/panda/test_panda_connect.py | 102 +++-------- tests/test_data/test_yaml_save.yml | 4 +- 10 files changed, 249 insertions(+), 366 deletions(-) diff --git a/src/ophyd_async/core/__init__.py b/src/ophyd_async/core/__init__.py index d66cbc5ba3..a81f804a6b 100644 --- a/src/ophyd_async/core/__init__.py +++ b/src/ophyd_async/core/__init__.py @@ -62,7 +62,6 @@ wait_for_value, ) from ._signal_backend import ( - BackendConverterFactory, RuntimeSubsetEnum, SignalBackend, SubsetEnum, @@ -108,7 +107,6 @@ "MockSignalBackend", "callback_on_mock_put", "get_mock_put", - "BackendConverterFactory", "mock_puts_blocked", "reset_mock_put_calls", "set_mock_put_proceeds", diff --git a/src/ophyd_async/core/_signal_backend.py b/src/ophyd_async/core/_signal_backend.py index 178ff1edfa..45bdad8f65 100644 --- a/src/ophyd_async/core/_signal_backend.py +++ b/src/ophyd_async/core/_signal_backend.py @@ -1,4 +1,4 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import ( TYPE_CHECKING, ClassVar, @@ -13,28 +13,20 @@ from ._utils import DEFAULT_TIMEOUT, ReadingValueCallback, T -class BackendConverterFactory(ABC): - """Convert between the signal backend and the signal type""" - - _ALLOWED_TYPES: ClassVar[Tuple[Type]] - - @classmethod - @abstractmethod - def datatype_allowed(cls, datatype: Type) -> bool: - """Check if the datatype is allowed.""" - - @classmethod - @abstractmethod - def make_converter(self, datatype: Type): - """Updates the object with callables 
`to_signal` and `from_signal`.""" - - class SignalBackend(Generic[T]): """A read/write/monitor backend for a Signals""" #: Datatype of the signal value datatype: Optional[Type[T]] = None + _ALLOWED_DATATYPES: ClassVar[Tuple[Type]] + + @classmethod + @abstractmethod + def datatype_allowed(cls, dtype: type): + """Check if a given datatype is acceptable for this signal backend.""" + pass + #: Like ca://PV_PREFIX:SIGNAL @abstractmethod def source(self, name: str) -> str: diff --git a/src/ophyd_async/core/_soft_signal_backend.py b/src/ophyd_async/core/_soft_signal_backend.py index d35204795b..1b30c0ae73 100644 --- a/src/ophyd_async/core/_soft_signal_backend.py +++ b/src/ophyd_async/core/_soft_signal_backend.py @@ -12,7 +12,6 @@ from typing_extensions import TypedDict from ._signal_backend import ( - BackendConverterFactory, RuntimeSubsetEnum, SignalBackend, ) @@ -149,32 +148,22 @@ def make_initial_value(self, datatype: Type | None) -> Any: return super().make_initial_value(datatype) -class SoftSignalConverterFactory(BackendConverterFactory): - _ALLOWED_TYPES = (object,) # Any type is allowed +def make_converter(datatype): + is_array = get_dtype(datatype) is not None + is_sequence = get_origin(datatype) == abc.Sequence + is_enum = inspect.isclass(datatype) and ( + issubclass(datatype, Enum) or issubclass(datatype, RuntimeSubsetEnum) + ) + is_pydantic_model = inspect.isclass(datatype) and issubclass(datatype, BaseModel) - @classmethod - def datatype_allowed(cls, datatype: Type) -> bool: - return True # Any value allowed in a soft signal - - @classmethod - def make_converter(cls, datatype): - is_array = get_dtype(datatype) is not None - is_sequence = get_origin(datatype) == abc.Sequence - is_enum = inspect.isclass(datatype) and ( - issubclass(datatype, Enum) or issubclass(datatype, RuntimeSubsetEnum) - ) - is_pydantic_model = inspect.isclass(datatype) and issubclass( - datatype, BaseModel - ) + if is_array or is_sequence: + return SoftArrayConverter() + if is_enum: + return SoftEnumConverter(datatype) + if is_pydantic_model: + return SoftPydanticModelConverter(datatype) - if is_array or is_sequence: - return SoftArrayConverter() - if is_enum: - return SoftEnumConverter(datatype) - if is_pydantic_model: - return SoftPydanticModelConverter(datatype) - - return SoftConverter() + return SoftConverter() class SoftSignalBackend(SignalBackend[T]): @@ -185,6 +174,12 @@ class SoftSignalBackend(SignalBackend[T]): _timestamp: float _severity: int + _ALLOWED_DATATYPES = (object,) # Any type is allowed + + @classmethod + def datatype_allowed(cls, datatype: Type) -> bool: + return True # Any value allowed in a soft signal + def __init__( self, datatype: Optional[Type[T]], @@ -194,9 +189,7 @@ def __init__( self.datatype = datatype self._initial_value = initial_value self._metadata = metadata or {} - self.converter: SoftConverter = SoftSignalConverterFactory.make_converter( - datatype - ) + self.converter: SoftConverter = make_converter(datatype) if self._initial_value is None: self._initial_value = self.converter.make_initial_value(self.datatype) else: diff --git a/src/ophyd_async/epics/signal/_aioca.py b/src/ophyd_async/epics/signal/_aioca.py index 4a34ff4367..ef8a5693e2 100644 --- a/src/ophyd_async/epics/signal/_aioca.py +++ b/src/ophyd_async/epics/signal/_aioca.py @@ -23,7 +23,6 @@ from ophyd_async.core import ( DEFAULT_TIMEOUT, - BackendConverterFactory, NotConnected, ReadingValueCallback, RuntimeSubsetEnum, @@ -186,8 +185,81 @@ def __getattribute__(self, __name: str) -> Any: raise NotImplementedError("No PV 
has been set as connect() has not been called")
 
 
-class CaConverterFactory(BackendConverterFactory):
-    _ALLOWED_TYPES = (
+def make_converter(
+    datatype: Optional[Type], values: Dict[str, AugmentedValue]
+) -> CaConverter:
+    pv = list(values)[0]
+    pv_dbr = get_unique({k: v.datatype for k, v in values.items()}, "datatypes")
+    is_array = bool([v for v in values.values() if v.element_count > 1])
+    if is_array and datatype is str and pv_dbr == dbr.DBR_CHAR:
+        # Override waveform of chars to be treated as string
+        return CaLongStrConverter()
+    elif is_array and pv_dbr == dbr.DBR_STRING:
+        # Waveform of strings, check we wanted this
+        if datatype:
+            datatype_dtype = get_dtype(datatype)
+            if not datatype_dtype or not np.can_cast(datatype_dtype, np.str_):
+                raise TypeError(f"{pv} has type [str] not {datatype.__name__}")
+        return CaArrayConverter(pv_dbr, None)
+    elif is_array:
+        pv_dtype = get_unique({k: v.dtype for k, v in values.items()}, "dtypes")
+        # This is an array
+        if datatype:
+            # Check we wanted an array of this type
+            dtype = get_dtype(datatype)
+            if not dtype:
+                raise TypeError(f"{pv} has type [{pv_dtype}] not {datatype.__name__}")
+            if dtype != pv_dtype:
+                raise TypeError(f"{pv} has type [{pv_dtype}] not [{dtype}]")
+        return CaArrayConverter(pv_dbr, None)
+    elif pv_dbr == dbr.DBR_ENUM and datatype is bool:
+        # Database can't do bools, so are often represented as enums,
+        # CA can do int
+        pv_choices_len = get_unique(
+            {k: len(v.enums) for k, v in values.items()}, "number of choices"
+        )
+        if pv_choices_len != 2:
+            raise TypeError(f"{pv} has {pv_choices_len} choices, can't map to bool")
+        return CaBoolConverter(dbr.DBR_SHORT, dbr.DBR_SHORT)
+    elif pv_dbr == dbr.DBR_ENUM:
+        # This is an Enum
+        pv_choices = get_unique(
+            {k: tuple(v.enums) for k, v in values.items()}, "choices"
+        )
+        supported_values = get_supported_values(pv, datatype, pv_choices)
+        return CaEnumConverter(dbr.DBR_STRING, None, supported_values)
+    else:
+        value = list(values.values())[0]
+        # Done the dbr check, so enough to check one of the values
+        if datatype and not isinstance(value, datatype):
+            # Allow int signals to represent float records when prec is 0
+            is_prec_zero_float = (
+                isinstance(value, float)
+                and get_unique({k: v.precision for k, v in values.items()}, "precision")
+                == 0
+            )
+            if not (datatype is int and is_prec_zero_float):
+                raise TypeError(
+                    f"{pv} has type {type(value).__name__.replace('ca_', '')} "
+                    + f"not {datatype.__name__}"
+                )
+        return CaConverter(pv_dbr, None)
+
+
+_tried_pyepics = False
+
+
+def _use_pyepics_context_if_imported():
+    global _tried_pyepics
+    if not _tried_pyepics:
+        ca = sys.modules.get("epics.ca", None)
+        if ca:
+            ca.use_initial_context()
+        _tried_pyepics = True
+
+
+class CaSignalBackend(SignalBackend[T]):
+    _ALLOWED_DATATYPES = (
         bool,
         int,
         float,
@@ -205,91 +277,12 @@ def datatype_allowed(cls, datatype: Optional[Type]) -> bool:
             return True
 
         return inspect.isclass(stripped_origin) and issubclass(
-            stripped_origin, cls._ALLOWED_TYPES
+            stripped_origin, cls._ALLOWED_DATATYPES
        )
 
-    @classmethod
-    def make_converter(
-        cls, datatype: Optional[Type], values: Dict[str, AugmentedValue]
-    ) -> CaConverter:
-        pv = list(values)[0]
-        pv_dbr = get_unique({k: v.datatype for k, v in values.items()}, "datatypes")
-        is_array = bool([v for v in values.values() if v.element_count > 1])
-        if is_array and datatype is str and pv_dbr == dbr.DBR_CHAR:
-            # Override waveform of chars to be treated as string
-            return CaLongStrConverter()
-        elif is_array and pv_dbr == dbr.DBR_STRING:
-            # 
Waveform of strings, check we wanted this - if datatype: - datatype_dtype = get_dtype(datatype) - if not datatype_dtype or not np.can_cast(datatype_dtype, np.str_): - raise TypeError(f"{pv} has type [str] not {datatype.__name__}") - return CaArrayConverter(pv_dbr, None) - elif is_array: - pv_dtype = get_unique({k: v.dtype for k, v in values.items()}, "dtypes") - # This is an array - if datatype: - # Check we wanted an array of this type - dtype = get_dtype(datatype) - if not dtype: - raise TypeError( - f"{pv} has type [{pv_dtype}] not {datatype.__name__}" - ) - if dtype != pv_dtype: - raise TypeError(f"{pv} has type [{pv_dtype}] not [{dtype}]") - return CaArrayConverter(pv_dbr, None) - elif pv_dbr == dbr.DBR_ENUM and datatype is bool: - # Database can't do bools, so are often representated as enums, - # CA can do int - pv_choices_len = get_unique( - {k: len(v.enums) for k, v in values.items()}, "number of choices" - ) - if pv_choices_len != 2: - raise TypeError(f"{pv} has {pv_choices_len} choices, can't map to bool") - return CaBoolConverter(dbr.DBR_SHORT, dbr.DBR_SHORT) - elif pv_dbr == dbr.DBR_ENUM: - # This is an Enum - pv_choices = get_unique( - {k: tuple(v.enums) for k, v in values.items()}, "choices" - ) - supported_values = get_supported_values(pv, datatype, pv_choices) - return CaEnumConverter(dbr.DBR_STRING, None, supported_values) - else: - value = list(values.values())[0] - # Done the dbr check, so enough to check one of the values - if datatype and not isinstance(value, datatype): - # Allow int signals to represent float records when prec is 0 - is_prec_zero_float = ( - isinstance(value, float) - and get_unique( - {k: v.precision for k, v in values.items()}, "precision" - ) - == 0 - ) - if not (datatype is int and is_prec_zero_float): - raise TypeError( - f"{pv} has type {type(value).__name__.replace('ca_', '')} " - + f"not {datatype.__name__}" - ) - return CaConverter(pv_dbr, None) - - -_tried_pyepics = False - - -def _use_pyepics_context_if_imported(): - global _tried_pyepics - if not _tried_pyepics: - ca = sys.modules.get("epics.ca", None) - if ca: - ca.use_initial_context() - _tried_pyepics = True - - -class CaSignalBackend(SignalBackend[T]): def __init__(self, datatype: Optional[Type[T]], read_pv: str, write_pv: str): self.datatype = datatype - if not CaConverterFactory.datatype_allowed(self.datatype): + if not CaSignalBackend.datatype_allowed(self.datatype): raise TypeError(f"Given datatype {self.datatype} unsupported in CA.") self.read_pv = read_pv self.write_pv = write_pv @@ -320,9 +313,7 @@ async def connect(self, timeout: float = DEFAULT_TIMEOUT): else: # The same, so only need to connect one await self._store_initial_value(self.read_pv, timeout=timeout) - self.converter = CaConverterFactory.make_converter( - self.datatype, self.initial_values - ) + self.converter = make_converter(self.datatype, self.initial_values) async def put(self, value: Optional[T], wait=True, timeout=None): if value is None: diff --git a/src/ophyd_async/epics/signal/_p4p.py b/src/ophyd_async/epics/signal/_p4p.py index cb1502cea0..4e35f4d5b9 100644 --- a/src/ophyd_async/epics/signal/_p4p.py +++ b/src/ophyd_async/epics/signal/_p4p.py @@ -16,7 +16,6 @@ from ophyd_async.core import ( DEFAULT_TIMEOUT, - BackendConverterFactory, NotConnected, ReadingValueCallback, RuntimeSubsetEnum, @@ -300,8 +299,83 @@ def __getattribute__(self, __name: str) -> Any: raise NotImplementedError("No PV has been set as connect() has not been called") -class PvaConverterFactory(BackendConverterFactory): - _ALLOWED_TYPES = 
( +def make_converter(datatype: Optional[Type], values: Dict[str, Any]) -> PvaConverter: + pv = list(values)[0] + typeid = get_unique({k: v.getID() for k, v in values.items()}, "typeids") + typ = get_unique( + {k: type(v.get("value")) for k, v in values.items()}, "value types" + ) + if "NTScalarArray" in typeid and typ is list: + # Waveform of strings, check we wanted this + if datatype and datatype != Sequence[str]: + raise TypeError(f"{pv} has type [str] not {datatype.__name__}") + return PvaArrayConverter() + elif "NTScalarArray" in typeid or "NTNDArray" in typeid: + pv_dtype = get_unique( + {k: v["value"].dtype for k, v in values.items()}, "dtypes" + ) + # This is an array + if datatype: + # Check we wanted an array of this type + dtype = get_dtype(datatype) + if not dtype: + raise TypeError(f"{pv} has type [{pv_dtype}] not {datatype.__name__}") + if dtype != pv_dtype: + raise TypeError(f"{pv} has type [{pv_dtype}] not [{dtype}]") + if "NTNDArray" in typeid: + return PvaNDArrayConverter() + else: + return PvaArrayConverter() + elif "NTEnum" in typeid and datatype is bool: + # Wanted a bool, but database represents as an enum + pv_choices_len = get_unique( + {k: len(v["value"]["choices"]) for k, v in values.items()}, + "number of choices", + ) + if pv_choices_len != 2: + raise TypeError(f"{pv} has {pv_choices_len} choices, can't map to bool") + return PvaEmumBoolConverter() + elif "NTEnum" in typeid: + # This is an Enum + pv_choices = get_unique( + {k: tuple(v["value"]["choices"]) for k, v in values.items()}, "choices" + ) + return PvaEnumConverter(get_supported_values(pv, datatype, pv_choices)) + elif "NTScalar" in typeid: + if ( + typ is str + and inspect.isclass(datatype) + and issubclass(datatype, RuntimeSubsetEnum) + ): + return PvaEnumConverter( + get_supported_values(pv, datatype, datatype.choices) + ) + elif datatype and not issubclass(typ, datatype): + # Allow int signals to represent float records when prec is 0 + is_prec_zero_float = typ is float and ( + get_unique( + {k: v["display"]["precision"] for k, v in values.items()}, + "precision", + ) + == 0 + ) + if not (datatype is int and is_prec_zero_float): + raise TypeError(f"{pv} has type {typ.__name__} not {datatype.__name__}") + return PvaConverter() + elif "NTTable" in typeid: + if datatype and inspect.isclass(datatype) and issubclass(datatype, BaseModel): + return PvaPydanticModelConverter(datatype) + return PvaTableConverter() + elif "structure" in typeid: + return PvaDictConverter() + else: + raise TypeError(f"{pv}: Unsupported typeid {typeid}") + + +class PvaSignalBackend(SignalBackend[T]): + _ctxt: Optional[Context] = None + + _ALLOWED_DATATYPES = ( bool, int, float, @@ -320,99 +394,12 @@ def datatype_allowed(cls, datatype: Optional[Type]) -> bool: if datatype is None: return True return inspect.isclass(stripped_origin) and issubclass( - stripped_origin, cls._ALLOWED_TYPES + stripped_origin, cls._ALLOWED_DATATYPES ) - @classmethod - def make_converter( - cls, datatype: Optional[Type], values: Dict[str, Any] - ) -> PvaConverter: - pv = list(values)[0] - typeid = get_unique({k: v.getID() for k, v in values.items()}, "typeids") - typ = get_unique( - {k: type(v.get("value")) for k, v in values.items()}, "value types" - ) - if "NTScalarArray" in typeid and typ is list: - # Waveform of strings, check we wanted this - if datatype and datatype != Sequence[str]: - raise TypeError(f"{pv} has type [str] not {datatype.__name__}") - return PvaArrayConverter() - elif "NTScalarArray" in typeid or "NTNDArray" in typeid: - pv_dtype = 
get_unique( - {k: v["value"].dtype for k, v in values.items()}, "dtypes" - ) - # This is an array - if datatype: - # Check we wanted an array of this type - dtype = get_dtype(datatype) - if not dtype: - raise TypeError( - f"{pv} has type [{pv_dtype}] not {datatype.__name__}" - ) - if dtype != pv_dtype: - raise TypeError(f"{pv} has type [{pv_dtype}] not [{dtype}]") - if "NTNDArray" in typeid: - return PvaNDArrayConverter() - else: - return PvaArrayConverter() - elif "NTEnum" in typeid and datatype is bool: - # Wanted a bool, but database represents as an enum - pv_choices_len = get_unique( - {k: len(v["value"]["choices"]) for k, v in values.items()}, - "number of choices", - ) - if pv_choices_len != 2: - raise TypeError(f"{pv} has {pv_choices_len} choices, can't map to bool") - return PvaEmumBoolConverter() - elif "NTEnum" in typeid: - # This is an Enum - pv_choices = get_unique( - {k: tuple(v["value"]["choices"]) for k, v in values.items()}, "choices" - ) - return PvaEnumConverter(get_supported_values(pv, datatype, pv_choices)) - elif "NTScalar" in typeid: - if ( - typ is str - and inspect.isclass(datatype) - and issubclass(datatype, RuntimeSubsetEnum) - ): - return PvaEnumConverter( - get_supported_values(pv, datatype, datatype.choices) - ) - elif datatype and not issubclass(typ, datatype): - # Allow int signals to represent float records when prec is 0 - is_prec_zero_float = typ is float and ( - get_unique( - {k: v["display"]["precision"] for k, v in values.items()}, - "precision", - ) - == 0 - ) - if not (datatype is int and is_prec_zero_float): - raise TypeError( - f"{pv} has type {typ.__name__} not {datatype.__name__}" - ) - return PvaConverter() - elif "NTTable" in typeid: - if ( - datatype - and inspect.isclass(datatype) - and issubclass(datatype, BaseModel) - ): - return PvaPydanticModelConverter(datatype) - return PvaTableConverter() - elif "structure" in typeid: - return PvaDictConverter() - else: - raise TypeError(f"{pv}: Unsupported typeid {typeid}") - - -class PvaSignalBackend(SignalBackend[T]): - _ctxt: Optional[Context] = None - def __init__(self, datatype: Optional[Type[T]], read_pv: str, write_pv: str): self.datatype = datatype - if not PvaConverterFactory.datatype_allowed(self.datatype): + if not PvaSignalBackend.datatype_allowed(self.datatype): raise TypeError(f"Given datatype {self.datatype} unsupported in PVA.") self.read_pv = read_pv @@ -457,9 +444,7 @@ async def connect(self, timeout: float = DEFAULT_TIMEOUT): else: # The same, so only need to connect one await self._store_initial_value(self.read_pv, timeout=timeout) - self.converter = PvaConverterFactory.make_converter( - self.datatype, self.initial_values - ) + self.converter = make_converter(self.datatype, self.initial_values) async def put(self, value: Optional[T], wait=True, timeout=None): if value is None: diff --git a/src/ophyd_async/fastcs/panda/_table.py b/src/ophyd_async/fastcs/panda/_table.py index b1ed6d5729..69767c744c 100644 --- a/src/ophyd_async/fastcs/panda/_table.py +++ b/src/ophyd_async/fastcs/panda/_table.py @@ -41,12 +41,14 @@ class SeqTrigger(str, Enum): NpArrayPydanticAnnotation.factory( data_type=np.int32, dimensions=1, strict_data_typing=False ), + Field(default_factory=lambda: np.array([], np.int32)), ] PydanticNp1DArrayBool = Annotated[ np.ndarray[tuple[int], np.bool_], NpArrayPydanticAnnotation.factory( data_type=np.bool_, dimensions=1, strict_data_typing=False ), + Field(default_factory=lambda: np.array([], dtype=np.bool_)), ] PydanticNp1DArrayUnicodeString = Annotated[ @@ -54,61 +56,28 @@ 
class SeqTrigger(str, Enum): NpArrayPydanticAnnotation.factory( data_type=np.unicode_, dimensions=1, strict_data_typing=False ), + Field(default_factory=lambda: np.array([], dtype=np.dtype(" Date: Mon, 9 Sep 2024 09:56:00 +0100 Subject: [PATCH 08/11] fixed issubclass of ABC check in python 3.10 --- src/ophyd_async/core/_soft_signal_backend.py | 7 ++++++- tests/core/test_soft_signal_backend.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/ophyd_async/core/_soft_signal_backend.py b/src/ophyd_async/core/_soft_signal_backend.py index 1b30c0ae73..e45b71b0d9 100644 --- a/src/ophyd_async/core/_soft_signal_backend.py +++ b/src/ophyd_async/core/_soft_signal_backend.py @@ -2,6 +2,7 @@ import inspect import time +from abc import ABCMeta from collections import abc from enum import Enum from typing import Any, Dict, Generic, Optional, Tuple, Type, Union, cast, get_origin @@ -154,7 +155,11 @@ def make_converter(datatype): is_enum = inspect.isclass(datatype) and ( issubclass(datatype, Enum) or issubclass(datatype, RuntimeSubsetEnum) ) - is_pydantic_model = inspect.isclass(datatype) and issubclass(datatype, BaseModel) + is_pydantic_model = ( + inspect.isclass(datatype) + and isinstance(datatype, ABCMeta) + and issubclass(datatype, BaseModel) + ) if is_array or is_sequence: return SoftArrayConverter() diff --git a/tests/core/test_soft_signal_backend.py b/tests/core/test_soft_signal_backend.py index 5e55507626..16bf23567e 100644 --- a/tests/core/test_soft_signal_backend.py +++ b/tests/core/test_soft_signal_backend.py @@ -94,7 +94,7 @@ async def test_soft_signal_backend_get_put_monitor( descriptor: Callable[[Any], dict], dtype_numpy: str, ): - backend = SoftSignalBackend(datatype) + backend = SoftSignalBackend(datatype=datatype) await backend.connect() q = MonitorQueue(backend) From 49a047b84ce3b76f267fddfe4ec0de502f577326 Mon Sep 17 00:00:00 2001 From: Eva Lott Date: Wed, 11 Sep 2024 11:57:27 +0100 Subject: [PATCH 09/11] made suggested changes --- src/ophyd_async/core/__init__.py | 2 + src/ophyd_async/core/_signal_backend.py | 3 - src/ophyd_async/core/_soft_signal_backend.py | 22 ++----- .../_p4p_table_model.py => core/_table.py} | 16 ++--- src/ophyd_async/epics/signal/__init__.py | 2 - src/ophyd_async/epics/signal/_p4p.py | 10 ++- src/ophyd_async/fastcs/panda/_table.py | 66 +++++++++++-------- tests/fastcs/panda/test_panda_utils.py | 7 +- tests/fastcs/panda/test_table.py | 31 ++++++++- tests/test_data/test_yaml_save.yml | 2 +- 10 files changed, 95 insertions(+), 66 deletions(-) rename src/ophyd_async/{epics/signal/_p4p_table_model.py => core/_table.py} (78%) diff --git a/src/ophyd_async/core/__init__.py b/src/ophyd_async/core/__init__.py index bb07372c96..1928c7aba4 100644 --- a/src/ophyd_async/core/__init__.py +++ b/src/ophyd_async/core/__init__.py @@ -68,6 +68,7 @@ ) from ._soft_signal_backend import SignalMetadata, SoftSignalBackend from ._status import AsyncStatus, WatchableAsyncStatus, completed_status +from ._table import Table from ._utils import ( DEFAULT_TIMEOUT, CalculatableTimeout, @@ -156,6 +157,7 @@ "CalculateTimeout", "NotConnected", "ReadingValueCallback", + "Table", "T", "WatcherUpdate", "get_dtype", diff --git a/src/ophyd_async/core/_signal_backend.py b/src/ophyd_async/core/_signal_backend.py index 45bdad8f65..594863ef2a 100644 --- a/src/ophyd_async/core/_signal_backend.py +++ b/src/ophyd_async/core/_signal_backend.py @@ -19,13 +19,10 @@ class SignalBackend(Generic[T]): #: Datatype of the signal value datatype: Optional[Type[T]] = None - _ALLOWED_DATATYPES: 
ClassVar[Tuple[Type]] - @classmethod @abstractmethod def datatype_allowed(cls, dtype: type): """Check if a given datatype is acceptable for this signal backend.""" - pass #: Like ca://PV_PREFIX:SIGNAL @abstractmethod diff --git a/src/ophyd_async/core/_soft_signal_backend.py b/src/ophyd_async/core/_soft_signal_backend.py index e45b71b0d9..1e895e60cc 100644 --- a/src/ophyd_async/core/_soft_signal_backend.py +++ b/src/ophyd_async/core/_soft_signal_backend.py @@ -5,7 +5,7 @@ from abc import ABCMeta from collections import abc from enum import Enum -from typing import Any, Dict, Generic, Optional, Tuple, Type, Union, cast, get_origin +from typing import Dict, Generic, Optional, Tuple, Type, Union, cast, get_origin import numpy as np from bluesky.protocols import DataKey, Dtype, Reading @@ -131,23 +131,11 @@ class SoftPydanticModelConverter(SoftConverter): def __init__(self, datatype: Type[BaseModel]): self.datatype = datatype - def reading(self, value: T, timestamp: float, severity: int) -> Reading: - value = self.value(value) - return super().reading(value, timestamp, severity) - - def value(self, value: Any) -> Any: - if isinstance(value, dict): - value = self.datatype(**value) - return value - def write_value(self, value): - if isinstance(value, self.datatype): - return value.model_dump(mode="python") + if isinstance(value, dict): + return self.datatype(**value) return value - def make_initial_value(self, datatype: Type | None) -> Any: - return super().make_initial_value(datatype) - def make_converter(datatype): is_array = get_dtype(datatype) is not None @@ -155,8 +143,10 @@ def make_converter(datatype): is_enum = inspect.isclass(datatype) and ( issubclass(datatype, Enum) or issubclass(datatype, RuntimeSubsetEnum) ) + is_pydantic_model = ( inspect.isclass(datatype) + # Necessary to avoid weirdness in ABCMeta.__subclasscheck__ and isinstance(datatype, ABCMeta) and issubclass(datatype, BaseModel) ) @@ -179,8 +169,6 @@ class SoftSignalBackend(SignalBackend[T]): _timestamp: float _severity: int - _ALLOWED_DATATYPES = (object,) # Any type is allowed - @classmethod def datatype_allowed(cls, datatype: Type) -> bool: return True # Any value allowed in a soft signal diff --git a/src/ophyd_async/epics/signal/_p4p_table_model.py b/src/ophyd_async/core/_table.py similarity index 78% rename from src/ophyd_async/epics/signal/_p4p_table_model.py rename to src/ophyd_async/core/_table.py index 49d115903b..bdb619a3b9 100644 --- a/src/ophyd_async/epics/signal/_p4p_table_model.py +++ b/src/ophyd_async/core/_table.py @@ -2,13 +2,13 @@ from pydantic import BaseModel, ConfigDict, model_validator -class PvaTable(BaseModel): - """An abstraction of a PVA Table of str to numpy array.""" +class Table(BaseModel): + """An abstraction of a Table of str to numpy array.""" model_config = ConfigDict(validate_assignment=True, strict=False) @classmethod - def row(cls, sub_cls, **kwargs) -> "PvaTable": + def row(cls, sub_cls, **kwargs) -> "Table": arrayified_kwargs = { field_name: np.concatenate( ( @@ -20,12 +20,12 @@ def row(cls, sub_cls, **kwargs) -> "PvaTable": } return sub_cls(**arrayified_kwargs) - def __add__(self, right: "PvaTable") -> "PvaTable": + def __add__(self, right: "Table") -> "Table": """Concatenate the arrays in field values.""" assert isinstance(right, type(self)), ( - f"{right} is not a `PvaTable`, or is not the same " - f"type of `PvaTable` as {self}." + f"{right} is not a `Table`, or is not the same " + f"type of `Table` as {self}." 
) return type(self)( @@ -38,14 +38,12 @@ def __add__(self, right: "PvaTable") -> "PvaTable": ) @model_validator(mode="after") - def validate_arrays(self) -> "PvaTable": + def validate_arrays(self) -> "Table": first_length = len(next(iter(self))[1]) assert all( len(field_value) == first_length for _, field_value in self ), "Rows should all be of equal size." - assert 0 <= first_length < 4096, f"Length {first_length} not in range." - if not all( np.issubdtype( self.model_fields[field_name].default_factory().dtype, field_value.dtype diff --git a/src/ophyd_async/epics/signal/__init__.py b/src/ophyd_async/epics/signal/__init__.py index a249d41cf0..8d7628bf01 100644 --- a/src/ophyd_async/epics/signal/__init__.py +++ b/src/ophyd_async/epics/signal/__init__.py @@ -1,6 +1,5 @@ from ._common import LimitPair, Limits, get_supported_values from ._p4p import PvaSignalBackend -from ._p4p_table_model import PvaTable from ._signal import ( epics_signal_r, epics_signal_rw, @@ -14,7 +13,6 @@ "LimitPair", "Limits", "PvaSignalBackend", - "PvaTable", "epics_signal_r", "epics_signal_rw", "epics_signal_rw_rbv", diff --git a/src/ophyd_async/epics/signal/_p4p.py b/src/ophyd_async/epics/signal/_p4p.py index 4e35f4d5b9..c7d0b5240d 100644 --- a/src/ophyd_async/epics/signal/_p4p.py +++ b/src/ophyd_async/epics/signal/_p4p.py @@ -3,6 +3,7 @@ import inspect import logging import time +from abc import ABCMeta from dataclasses import dataclass from enum import Enum from math import isnan, nan @@ -363,7 +364,14 @@ def make_converter(datatype: Optional[Type], values: Dict[str, Any]) -> PvaConve raise TypeError(f"{pv} has type {typ.__name__} not {datatype.__name__}") return PvaConverter() elif "NTTable" in typeid: - if datatype and inspect.isclass(datatype) and issubclass(datatype, BaseModel): + if ( + datatype + and inspect.isclass(datatype) + and + # Necessary to avoid weirdness in ABCMeta.__subclasscheck__ + isinstance(datatype, ABCMeta) + and issubclass(datatype, BaseModel) + ): return PvaPydanticModelConverter(datatype) return PvaTableConverter() elif "structure" in typeid: diff --git a/src/ophyd_async/fastcs/panda/_table.py b/src/ophyd_async/fastcs/panda/_table.py index 69767c744c..d2a31c53e4 100644 --- a/src/ophyd_async/fastcs/panda/_table.py +++ b/src/ophyd_async/fastcs/panda/_table.py @@ -1,13 +1,14 @@ +import inspect from enum import Enum from typing import Annotated, Sequence import numpy as np import numpy.typing as npt -from pydantic import Field +from pydantic import Field, field_validator, model_validator from pydantic_numpy.helper.annotation import NpArrayPydanticAnnotation from typing_extensions import TypedDict -from ophyd_async.epics.signal import PvaTable +from ophyd_async.core import Table class PandaHdf5DatasetType(str, Enum): @@ -50,8 +51,7 @@ class SeqTrigger(str, Enum): ), Field(default_factory=lambda: np.array([], dtype=np.bool_)), ] - -PydanticNp1DArrayUnicodeString = Annotated[ +TriggerStr = Annotated[ np.ndarray[tuple[int], np.unicode_], NpArrayPydanticAnnotation.factory( data_type=np.unicode_, dimensions=1, strict_data_typing=False @@ -60,9 +60,9 @@ class SeqTrigger(str, Enum): ] -class SeqTable(PvaTable): +class SeqTable(Table): repeats: PydanticNp1DArrayInt32 - trigger: PydanticNp1DArrayUnicodeString + trigger: TriggerStr position: PydanticNp1DArrayInt32 time1: PydanticNp1DArrayInt32 outa1: PydanticNp1DArrayBool @@ -83,8 +83,8 @@ class SeqTable(PvaTable): def row( cls, *, - repeats: int = 0, - trigger: str = "", + repeats: int = 1, + trigger: str = SeqTrigger.IMMEDIATE, position: int = 0, 
time1: int = 0, outa1: bool = False, @@ -101,23 +101,33 @@ def row( oute2: bool = False, outf2: bool = False, ) -> "SeqTable": - return PvaTable.row( - cls, - repeats=repeats, - trigger=trigger, - position=position, - time1=time1, - outa1=outa1, - outb1=outb1, - outc1=outc1, - outd1=outd1, - oute1=oute1, - outf1=outf1, - time2=time2, - outa2=outa2, - outb2=outb2, - outc2=outc2, - outd2=outd2, - oute2=oute2, - outf2=outf2, - ) + sig = inspect.signature(cls.row) + kwargs = {k: v for k, v in locals().items() if k in sig.parameters} + if isinstance(kwargs["trigger"], SeqTrigger): + kwargs["trigger"] = kwargs["trigger"].value + return Table.row(cls, **kwargs) + + @field_validator("trigger", mode="before") + @classmethod + def trigger_to_np_array(cls, trigger_column): + """ + The user can provide a list of SeqTrigger enum elements instead of a numpy str. + """ + if isinstance(trigger_column, Sequence) and all( + isinstance(trigger, SeqTrigger) for trigger in trigger_column + ): + trigger_column = np.array( + [trigger.value for trigger in trigger_column], dtype=np.dtype(" "SeqTable": + """ + Used to check max_length. Unfortunately trying the `max_length` arg in + the pydantic field doesn't work + """ + + first_length = len(next(iter(self))[1]) + assert 0 <= first_length < 4096, f"Length {first_length} not in range." + return self diff --git a/tests/fastcs/panda/test_panda_utils.py b/tests/fastcs/panda/test_panda_utils.py index f79ac0442c..d8b9a01269 100644 --- a/tests/fastcs/panda/test_panda_utils.py +++ b/tests/fastcs/panda/test_panda_utils.py @@ -8,7 +8,6 @@ from ophyd_async.fastcs.panda import ( CommonPandaBlocks, DataBlock, - PcompDirectionOptions, SeqTable, TimeUnits, phase_sorter, @@ -78,13 +77,13 @@ def check_equal_with_seq_tables(actual, expected): "data.hdf_file_name": "", "data.num_capture": 0, "pcap.arm": False, - "pcomp.1.dir": PcompDirectionOptions.positive, + "pcomp.1.dir": "Positive", "pcomp.1.enable": "ZERO", "pcomp.1.pulses": 0, "pcomp.1.start": 0, "pcomp.1.step": 0, "pcomp.1.width": 0, - "pcomp.2.dir": PcompDirectionOptions.positive, + "pcomp.2.dir": "Positive", "pcomp.2.enable": "ZERO", "pcomp.2.pulses": 0, "pcomp.2.start": 0, @@ -112,7 +111,7 @@ def check_equal_with_seq_tables(actual, expected): "repeats": [1], "time1": [0], "time2": [0], - "trigger": [""], + "trigger": ["Immediate"], }, "seq.1.repeats": 0, "seq.1.prescale": 0.0, diff --git a/tests/fastcs/panda/test_table.py b/tests/fastcs/panda/test_table.py index ba5ad3e2f9..b8b54e76e0 100644 --- a/tests/fastcs/panda/test_table.py +++ b/tests/fastcs/panda/test_table.py @@ -5,6 +5,7 @@ from pydantic import ValidationError from ophyd_async.fastcs.panda import SeqTable +from ophyd_async.fastcs.panda._table import SeqTrigger def test_seq_table_converts_lists(): @@ -16,7 +17,7 @@ def test_seq_table_converts_lists(): def test_seq_table_validation_errors(): - with pytest.raises(ValidationError, match="81 validation errors for SeqTable"): + with pytest.raises(ValidationError, match="80 validation errors for SeqTable"): SeqTable( repeats=0, trigger="", @@ -195,3 +196,31 @@ def test_seq_table_pva_conversion(): for column1, column2 in zip(applied_twice_to_pva_dict.values(), pva_dict.values()): assert np.array_equal(column1, column2) assert column1.dtype == column2.dtype + + +def test_seq_table_takes_trigger_enum_row(): + for trigger in (SeqTrigger.BITA_0, "BITA=0"): + table = SeqTable.row(trigger=trigger) + assert table.trigger[0] == "BITA=0" + assert np.issubdtype(table.trigger.dtype, np.dtype(" Date: Wed, 11 Sep 2024 13:45:57 +0100 
Subject: [PATCH 10/11] made trigger validator more reliable and added test --- src/ophyd_async/fastcs/panda/_table.py | 16 +++++++++++ tests/fastcs/panda/test_trigger.py | 37 ++++++++++++++++++++++++-- 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/src/ophyd_async/fastcs/panda/_table.py b/src/ophyd_async/fastcs/panda/_table.py index d2a31c53e4..ee6df7522f 100644 --- a/src/ophyd_async/fastcs/panda/_table.py +++ b/src/ophyd_async/fastcs/panda/_table.py @@ -103,8 +103,12 @@ def row( ) -> "SeqTable": sig = inspect.signature(cls.row) kwargs = {k: v for k, v in locals().items() if k in sig.parameters} + if isinstance(kwargs["trigger"], SeqTrigger): kwargs["trigger"] = kwargs["trigger"].value + elif isinstance(kwargs["trigger"], str): + SeqTrigger(kwargs["trigger"]) + return Table.row(cls, **kwargs) @field_validator("trigger", mode="before") @@ -119,6 +123,18 @@ def trigger_to_np_array(cls, trigger_column): trigger_column = np.array( [trigger.value for trigger in trigger_column], dtype=np.dtype(" Date: Wed, 11 Sep 2024 14:11:04 +0100 Subject: [PATCH 11/11] fixed bug --- tests/fastcs/panda/test_table.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/fastcs/panda/test_table.py b/tests/fastcs/panda/test_table.py index b8b54e76e0..c5f5abb846 100644 --- a/tests/fastcs/panda/test_table.py +++ b/tests/fastcs/panda/test_table.py @@ -17,10 +17,10 @@ def test_seq_table_converts_lists(): def test_seq_table_validation_errors(): - with pytest.raises(ValidationError, match="80 validation errors for SeqTable"): + with pytest.raises(ValidationError, match="81 validation errors for SeqTable"): SeqTable( repeats=0, - trigger="", + trigger="Immediate", position=0, time1=0, outa1=False, @@ -40,7 +40,7 @@ def test_seq_table_validation_errors(): large_seq_table = SeqTable( repeats=np.zeros(4095, dtype=np.int32), - trigger=np.array([""] * 4095, dtype="U32"), + trigger=np.array(["Immediate"] * 4095, dtype="U32"), position=np.zeros(4095, dtype=np.int32), time1=np.zeros(4095, dtype=np.int32), outa1=np.zeros(4095, dtype=np.bool_),
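
A usage sketch of the API this series converges on, for anyone reviewing it
locally. This is not part of any patch: it assumes the `SeqTable` and
`SeqTrigger` definitions from `ophyd_async.fastcs.panda` as they stand after
PATCH 11/11, and mirrors the behaviour exercised by the tests above:

    import numpy as np

    from ophyd_async.fastcs.panda import SeqTable, SeqTrigger

    # Each SeqTable.row() call returns a table whose fields are length-1
    # numpy arrays; `+` concatenates the per-column arrays of two tables.
    table = SeqTable.row(trigger=SeqTrigger.BITA_0, outa1=True) + SeqTable.row(
        repeats=2, time1=100
    )
    assert len(table.repeats) == 2
    assert table.trigger[0] == "BITA=0"  # triggers stored as unicode strings

    # Round-trip through the column-wise dict form used on the wire by the
    # pva backend; the validators convert lists/enums back to numpy arrays.
    columns = table.model_dump(mode="python")
    assert np.array_equal(SeqTable(**columns).repeats, table.repeats)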