Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add data fixture #6

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions pytest_plt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,50 @@ def __exit__(self, type, value, traceback):
self.plt.close('all')


class Analytics(Recorder):
DOC_KEY = 'documentation'

def __init__(self, dirname, module_name, function_name):
super(Analytics, self).__init__(dirname, module_name, function_name)

self.data = {}
self.doc = {}

@staticmethod
def load(path, module, function_name):
modparts = module.split('.')
modparts = modparts[1:]
modparts.remove('tests')

return np.load(os.path.join(path, "%s.%s.npz" % (
'.'.join(modparts), function_name)))

def __enter__(self):
return self

def add_data(self, name, data, doc=""):
if name == self.DOC_KEY:
raise ValueError("The name '{}' is reserved.".format(self.DOC_KEY))

if self.record:
self.data[name] = data
if doc != "":
self.doc[name] = doc

def save_data(self):
if len(self.data) == 0:
return

npz_data = dict(self.data)
if len(self.doc) > 0:
npz_data.update({self.DOC_KEY: self.doc})
np.savez(self.get_filepath(ext='npz'), **npz_data)

def __exit__(self, type, value, traceback):
if self.record:
self.save_data()


@pytest.fixture
def plt(request):
"""A pyplot-compatible plotting interface.
Expand All @@ -175,3 +219,24 @@ def plt(request):
plotter = Plotter(dirname, request.node.nodeid)
request.addfinalizer(lambda: plotter.__exit__(None, None, None))
return plotter.__enter__()


@pytest.fixture
def data(request):
"""An object to store data for analysis.

Please use this if you're concerned that accuracy or speed may regress.

This will keep saved data organized in a simulator-specific folder,
with an automatically generated name. Raw data (for later processing)
can be saved with ``analytics.add_raw_data``; these will be saved in
separate compressed ``.npz`` files. Summary data can be saved with
``analytics.add_summary_data``; these will be saved
in a single ``.csv`` file.
"""
dirname = request.config.getvalue("data")
if not is_string(dirname):
dirname = "data"
analytics = Analytics(dirname, request.node.nodeid)
request.addfinalizer(lambda: analytics.__exit__(None, None, None))
return analytics.__enter__()
48 changes: 48 additions & 0 deletions pytest_plt/tests/test_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import errno
import os

import pytest

from pytest_plt import Analytics


def test_analytics_empty():
analytics = Analytics('nengo.simulator.analytics',
'nengo.utils.tests.test_testing',
'test_analytics_empty')
with analytics:
pass
path = analytics.get_filepath(ext='npz')
assert not os.path.exists(path)


def test_analytics_record():
analytics = Analytics('nengo.simulator.analytics',
'nengo.utils.tests.test_testing',
'test_analytics_record')
with analytics:
analytics.add_data('test', 1, "Test analytics implementation")
assert analytics.data['test'] == 1
assert analytics.doc['test'] == "Test analytics implementation"
with pytest.raises(ValueError):
analytics.add_data('documentation', '')
path = analytics.get_filepath(ext='npz')
assert os.path.exists(path)
os.remove(path)
# This will remove the analytics directory, only if it's empty
try:
os.rmdir(analytics.dirname)
except OSError as ex:
assert ex.errno == errno.ENOTEMPTY


def test_analytics_norecord():
analytics = Analytics(None,
'nengo.utils.tests.test_testing',
'test_analytics_norecord')
with analytics:
analytics.add_data('test', 1, "Test analytics implementation")
assert 'test' not in analytics.data
assert 'test' not in analytics.doc
with pytest.raises(ValueError):
analytics.get_filepath(ext='npz')