Skip to content

Commit

Permalink
Refactor Inlet to respond to settings switch on-the-fly.
Browse files Browse the repository at this point in the history
  • Loading branch information
cboulay committed Sep 30, 2024
1 parent 6b807c3 commit f6de923
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 89 deletions.
159 changes: 86 additions & 73 deletions src/ezmsg/lsl/inlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ class LSLInletUnit(ez.Unit):
SETTINGS = LSLInletSettings
STATE = LSLInletState

INPUT_SETTINGS = ez.InputStream(LSLInletSettings)
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

# Share clock correction across all instances
Expand All @@ -137,9 +138,15 @@ def __init__(self, *args, **kwargs) -> None:
"""
kwargs = _sanitize_kwargs(kwargs)
super().__init__(*args, **kwargs)
self._msg_template: typing.Optional[AxisArray] = None
self._fetch_buffer: typing.Optional[npt.NDArray] = None

def _reset_resolver(self) -> None:
self.STATE.resolver = pylsl.ContinuousResolver(pred=None, forget_after=30.0)

def _reset_inlet(self) -> None:
self._fetch_buffer: npt.NDArray | None = None
self._msg_template: typing.Optional[AxisArray] = None
self._fetch_buffer: typing.Optional[npt.NDArray] = None
if self.STATE.inlet is not None:
self.STATE.inlet.close_stream()
del self.STATE.inlet
Expand All @@ -166,90 +173,96 @@ def _reset_inlet(self) -> None:
self.STATE.inlet = pylsl.StreamInlet(
info, max_chunklen=1, processing_flags=self.SETTINGS.processing_flags
)
else:
results: list[pylsl.StreamInfo] = self.STATE.resolver.results()
for strm_info in results:
b_match = True
b_match = b_match and ((not self.SETTINGS.info.name) or strm_info.name() == self.SETTINGS.info.name)
b_match = b_match and ((not self.SETTINGS.info.type) or strm_info.type() == self.SETTINGS.info.type)
if b_match:
self.STATE.inlet = pylsl.StreamInlet(
strm_info, max_chunklen=1, processing_flags=self.SETTINGS.processing_flags
)
break

def _reset_resolver(self) -> None:
# Build the predicate string. This uses XPATH syntax and can filter on anything in the stream info. e.g.,
# `"name='BioSemi'" or "type='EEG' and starts-with(name,'BioSemi') and count(info/desc/channel)=32"`
pred = ""
if self.SETTINGS.info.name:
pred += f"name='{self.SETTINGS.info.name}'"
if self.SETTINGS.info.type:
if len(pred):
pred += " and "
pred += f"type='{self.SETTINGS.info.type}'"
if not len(pred):
pred = None
self.STATE.resolver = pylsl.ContinuousResolver(pred=pred)
if self.STATE.inlet is not None:
self.STATE.inlet.open_stream()
inlet_info = self.STATE.inlet.info()
self.SETTINGS.info.nominal_srate = inlet_info.nominal_srate()
# If possible, create a destination buffer for faster pulls
fmt = inlet_info.channel_format()
n_ch = inlet_info.channel_count()
if fmt in fmt2npdtype:
dtype = fmt2npdtype[fmt]
n_buff = (
int(self.SETTINGS.local_buffer_dur * inlet_info.nominal_srate()) or 1000
)
self._fetch_buffer = np.zeros((n_buff, n_ch), dtype=dtype)
ch_labels = []
chans = inlet_info.desc().child("channels")
if not chans.empty():
ch = chans.first_child()
while not ch.empty():
ch_labels.append(ch.child_value("label"))
ch = ch.next_sibling()
while len(ch_labels) < n_ch:
ch_labels.append(str(len(ch_labels) + 1))
# Pre-allocate a message template.
fs = inlet_info.nominal_srate()
self._msg_template = AxisArray(
data=np.empty((0, n_ch)),
dims=["time", "ch"],
axes={
"time": AxisArray.Axis.TimeAxis(
fs=fs if fs else 1.0
), # HACK: Use 1.0 for irregular rate.
"ch": AxisArray.Axis.SpaceAxis(labels=ch_labels),
},
key=inlet_info.name(),
)

async def initialize(self) -> None:
self._reset_inlet()
self._reset_resolver()
self._reset_inlet()
# TODO: Let the clock_sync task do its job at the beginning.

def shutdown(self) -> None:
if self.STATE.resolver is not None:
del self.STATE.resolver
self.STATE.resolver = None
if self.STATE.inlet is not None:
self.STATE.inlet.close_stream()
del self.STATE.inlet
self.STATE.inlet = None
if self.STATE.resolver is not None:
del self.STATE.resolver
self.STATE.resolver = None

@ez.task
async def clock_sync_task(self) -> None:
while True:
force = self.clock_sync.count < 1000
await self.clock_sync.update(force=force, burst=1000 if force else 4)

@ez.subscriber(INPUT_SETTINGS)
async def on_settings(self, msg: LSLInletSettings) -> None:
# The message may be full LSLInletSettings, a dict of settings, just the info, or dict of just info.
if isinstance(msg, dict):
# First make sure the info is in the right place.
msg = _sanitize_kwargs(msg)
# Next, convert to LSLInletSettings object.
msg = LSLInletSettings(**msg)
if msg != self.SETTINGS:
self.apply_settings(msg)
self._reset_resolver()
self._reset_inlet()

@ez.publisher(OUTPUT_SIGNAL)
async def lsl_pull(self) -> typing.AsyncGenerator:
while self.STATE.inlet is None:
results: list[pylsl.StreamInfo] = self.STATE.resolver.results()
if len(results):
self.STATE.inlet = pylsl.StreamInlet(
results[0], max_chunklen=1, processing_flags=pylsl.proc_ALL
)
else:
await asyncio.sleep(0.5)

self.STATE.inlet.open_stream()
inlet_info = self.STATE.inlet.info()
# If possible, create a destination buffer for faster pulls
fmt = inlet_info.channel_format()
n_ch = inlet_info.channel_count()
if fmt in fmt2npdtype:
dtype = fmt2npdtype[fmt]
n_buff = (
int(self.SETTINGS.local_buffer_dur * inlet_info.nominal_srate()) or 1000
)
self._fetch_buffer = np.zeros((n_buff, n_ch), dtype=dtype)
ch_labels = []
chans = inlet_info.desc().child("channels")
if not chans.empty():
ch = chans.first_child()
while not ch.empty():
ch_labels.append(ch.child_value("label"))
ch = ch.next_sibling()
while len(ch_labels) < n_ch:
ch_labels.append(str(len(ch_labels) + 1))
# Pre-allocate a message template.
fs = inlet_info.nominal_srate()
msg_template = AxisArray(
data=np.empty((0, n_ch)),
dims=["time", "ch"],
axes={
"time": AxisArray.Axis.TimeAxis(
fs=fs if fs else 1.0
), # HACK: Use 1.0 for irregular rate.
"ch": AxisArray.Axis.SpaceAxis(labels=ch_labels),
},
key=inlet_info.name(),
)

while self.clock_sync.count < 1000:
# Let the clock_sync task do its job at the beginning.
await asyncio.sleep(0.001)

while self.STATE.inlet is not None:
while True:
if self.STATE.inlet is None:
# Inlet not yet created, or recently destroyed because settings changed.
self._reset_inlet()
await asyncio.sleep(0.1)
continue

if self._fetch_buffer is not None:
samples, timestamps = self.STATE.inlet.pull_chunk(
max_samples=self._fetch_buffer.shape[0], dest_obj=self._fetch_buffer
Expand All @@ -270,16 +283,16 @@ async def lsl_pull(self) -> typing.AsyncGenerator:
t0 = time.time() - (timestamps[-1] - timestamps[0])
else:
t0 = self.clock_sync.convert_timestamp(timestamps[0])
if fs <= 0.0:
if self.SETTINGS.info.nominal_srate <= 0.0:
# Irregular rate streams need to be streamed sample-by-sample
for ts, samp in zip(timestamps, data):
out_msg = replace(
msg_template,
self._msg_template,
data=samp[None, ...],
axes={
**msg_template.axes,
**self._msg_template.axes,
"time": replace(
msg_template.axes["time"],
self._msg_template.axes["time"],
offset=t0 + (ts - timestamps[0]),
),
},
Expand All @@ -288,11 +301,11 @@ async def lsl_pull(self) -> typing.AsyncGenerator:
else:
# Regular-rate streams can go in a chunk
out_msg = replace(
msg_template,
self._msg_template,
data=data,
axes={
**msg_template.axes,
"time": replace(msg_template.axes["time"], offset=t0),
**self._msg_template.axes,
"time": replace(self._msg_template.axes["time"], offset=t0),
},
)
yield self.OUTPUT_SIGNAL, out_msg
Expand Down
67 changes: 51 additions & 16 deletions tests/test_inlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
These unit tests aren't really testable in a runner without a complicated setup with inlets and outlets.
This code exists mostly to use during development and debugging.
"""
import os
import asyncio
import json
import os
from pathlib import Path
import tempfile
import typing

import numpy as np

Expand All @@ -21,6 +23,24 @@ def test_inlet_init_defaults():
assert True


class StreamSwitcher(ez.Unit):
STATE = ez.State
SETTINGS = ez.Settings
OUTPUT_SETTINGS = ez.OutputStream(LSLInletSettings)

@ez.publisher(OUTPUT_SETTINGS)
async def switch_stream(self) -> typing.AsyncGenerator:
switch_counter = 0

while True:
if switch_counter % 2 == 0:
yield self.OUTPUT_SETTINGS, LSLInletSettings(info=LSLInfo(type="ECoG"))
else:
yield self.OUTPUT_SETTINGS, LSLInletSettings(info=LSLInfo(type="Markers"))
switch_counter += 1
await asyncio.sleep(2)


class MessageReceiverSettings(ez.Settings):
num_msgs: int
output_fn: str
Expand All @@ -34,38 +54,53 @@ class AxarrReceiver(ez.Unit):
STATE = MessageReceiverState
SETTINGS = MessageReceiverSettings
INPUT_SIGNAL = ez.InputStream(AxisArray)
OUTPUT_SETTINGS = ez.OutputStream(LSLInletSettings)

@ez.subscriber(INPUT_SIGNAL)
async def on_message(self, msg: AxisArray) -> None:
self.STATE.num_received += 1
t_ax = msg.axes["time"]
tvec = np.arange(msg.data.shape[0]) * t_ax.gain + t_ax.offset
payload = {self.STATE.num_received: tvec.tolist()}
with open(self.SETTINGS.output_fn, "a") as output_file:
output_file.write(json.dumps(payload) + "\n")
try:
t_ax = msg.axes["time"]
tvec = np.arange(msg.data.shape[0]) * t_ax.gain + t_ax.offset
payload = {self.STATE.num_received: tvec.tolist()}
with open(self.SETTINGS.output_fn, "a") as output_file:
output_file.write(json.dumps(payload) + "\n")
except Exception as e:
print(f"Debug {e}")
if self.STATE.num_received == self.SETTINGS.num_msgs:
raise ez.NormalTermination


def test_inlet_init_with_settings():
test_name = os.environ.get("PYTEST_CURRENT_TEST")
if test_name is None:
test_name = "test_inlet:test_inlet_init_with_settings na"
test_name = test_name.split(":")[-1].split(" ")[0]
file_path = Path(tempfile.gettempdir())
file_path = file_path / Path(f"{test_name}.json")

comps = {
"SRC": LSLInletUnit(info=LSLInfo(name="BrainVision RDA", type="EEG")),
"SINK": AxarrReceiver(num_msgs=10_000, output_fn=file_path),
"SINK": AxarrReceiver(num_msgs=500, output_fn=file_path),
"FLIPFLOP": StreamSwitcher(),
}
conns = ((comps["SRC"].OUTPUT_SIGNAL, comps["SINK"].INPUT_SIGNAL),)
conns = (
(comps["FLIPFLOP"].OUTPUT_SETTINGS, comps["SRC"].INPUT_SETTINGS),
(comps["SRC"].OUTPUT_SIGNAL, comps["SINK"].INPUT_SIGNAL),
)
ez.run(components=comps, connections=conns)

tvecs = []
with open(file_path, "r") as file:
for ix, line in enumerate(file.readlines()):
tvecs.append(json.loads(line)[str(ix + 1)])
os.remove(str(file_path))
tvec = np.hstack(tvecs)
# tvecs = []
# with open(file_path, "r") as file:
# for ix, line in enumerate(file.readlines()):
# tmp = json.loads(line)
# tvecs.append(tmp[str(ix + 1)])
# os.remove(str(file_path))
# tvec = np.hstack(tvecs)
#
# # counts, bins = np.histogram(np.diff(tvec), 20)
# assert np.max(np.diff(tvec)) < 0.003


# counts, bins = np.histogram(np.diff(tvec), 20)
assert np.max(np.diff(tvec)) < 0.003
if __name__ == "__main__":
test_inlet_init_with_settings()

0 comments on commit f6de923

Please sign in to comment.