From 27e05b891e3e3c2a5aa50613bf161eb2936a0bfb Mon Sep 17 00:00:00 2001 From: Trevor Bekolay Date: Fri, 30 Nov 2018 14:43:13 -0500 Subject: [PATCH] Add data fixture --- pytest_plt/__init__.py | 65 +++++++++++++++++++++++++++++++++++ pytest_plt/tests/test_data.py | 48 ++++++++++++++++++++++++++ 2 files changed, 113 insertions(+) create mode 100644 pytest_plt/tests/test_data.py diff --git a/pytest_plt/__init__.py b/pytest_plt/__init__.py index c6dee00..bc4d319 100644 --- a/pytest_plt/__init__.py +++ b/pytest_plt/__init__.py @@ -156,6 +156,50 @@ def __exit__(self, type, value, traceback): self.plt.close('all') +class Analytics(Recorder): + DOC_KEY = 'documentation' + + def __init__(self, dirname, module_name, function_name): + super(Analytics, self).__init__(dirname, module_name, function_name) + + self.data = {} + self.doc = {} + + @staticmethod + def load(path, module, function_name): + modparts = module.split('.') + modparts = modparts[1:] + modparts.remove('tests') + + return np.load(os.path.join(path, "%s.%s.npz" % ( + '.'.join(modparts), function_name))) + + def __enter__(self): + return self + + def add_data(self, name, data, doc=""): + if name == self.DOC_KEY: + raise ValueError("The name '{}' is reserved.".format(self.DOC_KEY)) + + if self.record: + self.data[name] = data + if doc != "": + self.doc[name] = doc + + def save_data(self): + if len(self.data) == 0: + return + + npz_data = dict(self.data) + if len(self.doc) > 0: + npz_data.update({self.DOC_KEY: self.doc}) + np.savez(self.get_filepath(ext='npz'), **npz_data) + + def __exit__(self, type, value, traceback): + if self.record: + self.save_data() + + @pytest.fixture def plt(request): """A pyplot-compatible plotting interface. @@ -175,3 +219,24 @@ def plt(request): plotter = Plotter(dirname, request.node.nodeid) request.addfinalizer(lambda: plotter.__exit__(None, None, None)) return plotter.__enter__() + + +@pytest.fixture +def data(request): + """An object to store data for analysis. + + Please use this if you're concerned that accuracy or speed may regress. + + This will keep saved data organized in a simulator-specific folder, + with an automatically generated name. Raw data (for later processing) + can be saved with ``analytics.add_raw_data``; these will be saved in + separate compressed ``.npz`` files. Summary data can be saved with + ``analytics.add_summary_data``; these will be saved + in a single ``.csv`` file. + """ + dirname = request.config.getvalue("data") + if not is_string(dirname): + dirname = "data" + analytics = Analytics(dirname, request.node.nodeid) + request.addfinalizer(lambda: analytics.__exit__(None, None, None)) + return analytics.__enter__() diff --git a/pytest_plt/tests/test_data.py b/pytest_plt/tests/test_data.py new file mode 100644 index 0000000..f71f5c7 --- /dev/null +++ b/pytest_plt/tests/test_data.py @@ -0,0 +1,48 @@ +import errno +import os + +import pytest + +from pytest_plt import Analytics + + +def test_analytics_empty(): + analytics = Analytics('nengo.simulator.analytics', + 'nengo.utils.tests.test_testing', + 'test_analytics_empty') + with analytics: + pass + path = analytics.get_filepath(ext='npz') + assert not os.path.exists(path) + + +def test_analytics_record(): + analytics = Analytics('nengo.simulator.analytics', + 'nengo.utils.tests.test_testing', + 'test_analytics_record') + with analytics: + analytics.add_data('test', 1, "Test analytics implementation") + assert analytics.data['test'] == 1 + assert analytics.doc['test'] == "Test analytics implementation" + with pytest.raises(ValueError): + analytics.add_data('documentation', '') + path = analytics.get_filepath(ext='npz') + assert os.path.exists(path) + os.remove(path) + # This will remove the analytics directory, only if it's empty + try: + os.rmdir(analytics.dirname) + except OSError as ex: + assert ex.errno == errno.ENOTEMPTY + + +def test_analytics_norecord(): + analytics = Analytics(None, + 'nengo.utils.tests.test_testing', + 'test_analytics_norecord') + with analytics: + analytics.add_data('test', 1, "Test analytics implementation") + assert 'test' not in analytics.data + assert 'test' not in analytics.doc + with pytest.raises(ValueError): + analytics.get_filepath(ext='npz')