Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Allow compute to return a generator instead of chunks #751

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions strax/plugins/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
from copy import copy, deepcopy
import strax
import types

export, __all__ = strax.exporter()

Expand Down Expand Up @@ -484,10 +485,14 @@ class IterDone(Exception):
inputs_merged = {
kind: strax.Chunk.merge([inputs[d] for d in deps_of_kind])
for kind, deps_of_kind in self.dependencies_by_kind().items()}

# Submit the computation
# print(f"{self} calling with {inputs_merged}")
if self.parallel and executor is not None:
if inspect.isgeneratorfunction(self.compute):
raise NotImplementedError(
f'Plugin "{self.__class__.__name__}" uses an iterator as compute method. '
'This is not supported in multi-threading/processing.')
new_future = executor.submit(
self.do_compute,
chunk_i=chunk_i,
Expand All @@ -496,7 +501,11 @@ class IterDone(Exception):
pending_futures = [f for f in pending_futures if not f.done()]
yield new_future
else:
yield self.do_compute(chunk_i=chunk_i, **inputs_merged)
chunk = self.do_compute(chunk_i=chunk_i, **inputs_merged)
if isinstance(chunk, types.GeneratorType):
yield from chunk
else:
yield chunk

except IterDone:
# Check all sources are exhausted.
Expand Down Expand Up @@ -605,9 +614,11 @@ def do_compute(self, chunk_i=None, **kwargs):
if self.compute_takes_start_end:
kwargs['start'] = start
kwargs['end'] = end

result = self.compute(**kwargs)

return self._fix_output(result, start, end)
if isinstance(result, types.GeneratorType):
return result
return self._fix_output(result, start, end)

def _fix_output(self, result, start, end, _dtype=None):
if self.multi_output and _dtype is None:
Expand Down
70 changes: 70 additions & 0 deletions strax/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,76 @@ def compute(self, peaks):
return dict(peak_classification=p,
lone_hits=lh)



# Plugins with time structure within chunks,
# used to test down chunking within plugin compute.
@strax.takes_config(
    strax.Option('n_chunks', type=int, default=10, track=False),
    strax.Option('recs_per_chunk', type=int, default=10, track=False),
)
class RecordsWithTimeStructure(strax.Plugin):
    """Mock source plugin whose records have time structure inside each
    chunk: records start 5 ns after the chunk start and the chunk end
    is padded 10 ns past the last record's start.

    Used to test down-chunking within a plugin's compute.
    """
    provides = 'records'
    parallel = 'process'
    depends_on = tuple()
    dtype = strax.record_dtype()

    rechunk_on_save = False

    def source_finished(self):
        # The mock source always reports itself finished.
        return True

    def is_ready(self, chunk_i):
        # Emit exactly n_chunks chunks, then stop.
        return chunk_i < self.config['n_chunks']

    def setup(self):
        # End time of the previously emitted chunk; chunks are contiguous.
        self.last_end = 0

    def compute(self, chunk_i):
        n_recs = self.config['recs_per_chunk']

        records = np.zeros(n_recs, self.dtype)
        # Offset record times 5 ns into the chunk.
        records['time'] = self.last_end + 5 + np.arange(n_recs)
        records['length'] = records['dt'] = 1
        records['channel'] = np.arange(n_recs)

        chunk_end = self.last_end + n_recs + 10
        result = self.chunk(start=self.last_end, end=chunk_end, data=records)
        self.last_end = chunk_end
        return result


class DownSampleRecords(strax.Plugin):
    """Plugin to test the down-chunking of chunks during compute.
    Needed for simulations.

    Splits each incoming chunk at the sixth record and yields the two
    pieces as separate, contiguous chunks via a generator compute.
    """

    provides = 'records_down_chunked'
    depends_on = 'records'
    dtype = strax.record_dtype()
    rechunk_on_save = False
    parallel = 'process'

    def compute(self, records, start, end):
        split_at = 0
        boundary = start

        index = 0
        for index, _ in enumerate(records):
            if index == 5:
                # First five records become their own chunk, ending at
                # the latest endtime among them.
                head = records[split_at:index]
                head_end = np.max(strax.endtime(head))
                split_at = index
                yield self.chunk(start=boundary, end=head_end, data=head)
                boundary = head_end

        # Remaining records fill the rest of the original chunk span.
        tail = records[split_at:index + 1]
        yield self.chunk(start=boundary, end=end, data=tail)



# Used in test_core.py
run_id = '0'

Expand Down
34 changes: 31 additions & 3 deletions tests/test_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import strax
from strax.testutils import Records, Peaks, PeaksWoPerRunDefault, PeakClassification, run_id
from strax.testutils import Records, Peaks, PeaksWoPerRunDefault, PeakClassification, RecordsWithTimeStructure, DownSampleRecords, run_id
import tempfile
import numpy as np
from hypothesis import given, settings
Expand Down Expand Up @@ -215,6 +215,33 @@ def tearDown(self):
if os.path.exists(self.tempdir):
shutil.rmtree(self.tempdir)

def test_down_chunking(self):
    """Chunks produced by DownSampleRecords must be twice as numerous
    as the input chunks and cover the run without gaps or overlaps.
    """
    st = self.get_context(False)
    st.register(RecordsWithTimeStructure)
    st.register(DownSampleRecords)

    st.make(run_id, 'records')
    st.make(run_id, 'records_down_chunked')

    chunks_records = st.get_meta(run_id, 'records')['chunks']
    chunks_down = st.get_meta(run_id, 'records_down_chunked')['chunks']

    # DownSampleRecords splits every input chunk into two output chunks.
    _chunks_are_downsampled = len(chunks_records) * 2 == len(chunks_down)
    assert _chunks_are_downsampled

    # Consecutive output chunks must touch: end of one == start of next.
    # (Fixed: was misspelled "_chunks_are_continues"; use builtin all()
    # on a generator instead of np.all over a materialized list.)
    _chunks_are_contiguous = all(
        chunks_down[i]['end'] == chunks_down[i + 1]['start']
        for i in range(len(chunks_down) - 1))
    assert _chunks_are_contiguous

def test_down_chunking_multi_processing(self):
    """A generator-based compute (down chunking) is not supported with
    multiprocessing and must raise NotImplementedError.
    """
    # Note: get_context(False) already sets use_per_run_defaults=False,
    # so the previously duplicated set_context_config call is removed.
    st = self.get_context(False, allow_multiprocess=True)
    st.register(RecordsWithTimeStructure)
    st.register(DownSampleRecords)

    # A single worker does not parallelize, so this must still work.
    st.make(run_id, 'records', max_workers=1)
    # With >1 worker the generator compute must be rejected.
    with self.assertRaises(NotImplementedError):
        st.make(run_id, 'records_down_chunked', max_workers=2)

def test_get_plugins_with_cache(self):
st = self.get_context(False)
st.register(Records)
Expand Down Expand Up @@ -283,11 +310,12 @@ def test_deregister(self):
st.deregister_plugins_with_missing_dependencies()
assert st._plugin_class_registry.pop('peaks', None) is None

def get_context(self, use_defaults):
def get_context(self, use_defaults, **kwargs):
"""Get simple context where we have one mock run in the only storage frontend"""
assert isinstance(use_defaults, bool)
st = strax.Context(storage=self.get_mock_sf(),
check_available=('records',)
check_available=('records',),
**kwargs
)
st.set_context_config({'use_per_run_defaults': use_defaults})
return st
Expand Down