Skip to content

Commit

Permalink
made trigger validator more reliable and added test
Browse files Browse the repository at this point in the history
  • Loading branch information
evalott100 committed Sep 11, 2024
1 parent 49a047b commit 695e1db
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 2 deletions.
16 changes: 16 additions & 0 deletions src/ophyd_async/fastcs/panda/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,12 @@ def row(
) -> "SeqTable":
sig = inspect.signature(cls.row)
kwargs = {k: v for k, v in locals().items() if k in sig.parameters}

if isinstance(kwargs["trigger"], SeqTrigger):
kwargs["trigger"] = kwargs["trigger"].value
elif isinstance(kwargs["trigger"], str):
SeqTrigger(kwargs["trigger"])

return Table.row(cls, **kwargs)

@field_validator("trigger", mode="before")
Expand All @@ -119,6 +123,18 @@ def trigger_to_np_array(cls, trigger_column):
trigger_column = np.array(
[trigger.value for trigger in trigger_column], dtype=np.dtype("<U32")
)
elif isinstance(trigger_column, Sequence) or isinstance(
trigger_column, np.ndarray
):
for trigger in trigger_column:
SeqTrigger(
trigger
) # To check all the given strings are actually `SeqTrigger`s
else:
raise ValueError(
"Expected a numpy array or a sequence of `SeqTrigger`, got "
f"{type(trigger_column)}."
)
return trigger_column

@model_validator(mode="after")
Expand Down
37 changes: 35 additions & 2 deletions tests/fastcs/panda/test_trigger.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio

import numpy as np
import pytest
from pydantic import ValidationError

Expand Down Expand Up @@ -109,7 +110,39 @@ async def set_active(value: bool):
],
)
def test_malformed_seq_table_info(kwargs, error_msg):
kwargs["sequence_table"] = kwargs.pop("sequence_table_factory")()
with pytest.raises(ValidationError) as exc:
SeqTableInfo(**kwargs)
SeqTableInfo(sequence_table=kwargs.pop("sequence_table_factory")(), **kwargs)
assert error_msg in str(exc.value)


def test_malformed_trigger_in_seq_table():
def full_seq_table(trigger):
SeqTable(
repeats=np.array([1], dtype=np.int32),
trigger=trigger,
position=np.array([1], dtype=np.int32),
time1=np.array([1], dtype=np.int32),
outa1=np.array([1], dtype=np.bool_),
outb1=np.array([1], dtype=np.bool_),
outc1=np.array([1], dtype=np.bool_),
outd1=np.array([1], dtype=np.bool_),
oute1=np.array([1], dtype=np.bool_),
outf1=np.array([1], dtype=np.bool_),
time2=np.array([1], dtype=np.int32),
outa2=np.array([1], dtype=np.bool_),
outb2=np.array([1], dtype=np.bool_),
outc2=np.array([1], dtype=np.bool_),
outd2=np.array([1], dtype=np.bool_),
oute2=np.array([1], dtype=np.bool_),
outf2=np.array([1], dtype=np.bool_),
)

with pytest.raises(ValidationError) as exc:
full_seq_table(np.array(["A"], dtype="U32"))
assert "Value error, 'A' is not a valid SeqTrigger" in str(exc)
with pytest.raises(ValidationError) as exc:
full_seq_table(["A"])
assert "Value error, 'A' is not a valid SeqTrigger" in str(exc)
with pytest.raises(ValidationError) as exc:
full_seq_table({"A"})
assert "Expected a numpy array or a sequence of `SeqTrigger`, got" in str(exc)

0 comments on commit 695e1db

Please sign in to comment.