Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improves IO performance for the report readers #104

Merged
merged 6 commits into from
Oct 20, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions include/bbp/sonata/report_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

#include <highfive/H5File.hpp>

#include <bbp/sonata/population.h>
#include <bbp/sonata/optional.hpp>
#include <bbp/sonata/population.h>

namespace H5 = HighFive;

Expand Down Expand Up @@ -97,6 +97,9 @@ template <typename KeyType>
class SONATA_API ReportReader
{
public:
using Range = std::pair<uint64_t, uint64_t>;
using Ranges = std::vector<Range>;

class Population
{
public:
Expand All @@ -123,19 +126,22 @@ class SONATA_API ReportReader

/**
* \param node_ids limit the report to the given selection.
* \param tstart return spikes occurring on or after tstart. tstart=nonstd::nullopt
* indicates no limit. \param tstop return spikes occurring on or before tstop.
* tstop=nonstd::nullopt indicates no limit.
* \param tstart return voltages occurring on or after tstart. tstart=nonstd::nullopt
* indicates no limit. \param tstop return voltages occurring on or before tstop.
sergiorg-hpc marked this conversation as resolved.
Show resolved Hide resolved
* tstop=nonstd::nullopt indicates no limit. \param tstride read only every tstride-th
* timestep. tstride=nonstd::nullopt indicates that all timesteps are read.
*/
DataFrame<KeyType> get(const nonstd::optional<Selection>& node_ids = nonstd::nullopt,
const nonstd::optional<double>& tstart = nonstd::nullopt,
const nonstd::optional<double>& tstop = nonstd::nullopt) const;
const nonstd::optional<double>& tstop = nonstd::nullopt,
const nonstd::optional<size_t>& tstride = nonstd::nullopt) const;

private:
Population(const H5::File& file, const std::string& populationName);
std::pair<size_t, size_t> getIndex(const nonstd::optional<double>& tstart, const nonstd::optional<double>& tstop) const;
std::pair<size_t, size_t> getIndex(const nonstd::optional<double>& tstart,
const nonstd::optional<double>& tstop) const;

std::vector<std::pair<NodeID, std::pair<uint64_t, uint64_t>>> nodes_pointers_;
std::map<NodeID, Range> nodes_pointers_;
H5::Group pop_group_;
std::vector<NodeID> nodes_ids_;
double tstart_, tstop_, tstep_;
Expand Down
6 changes: 4 additions & 2 deletions python/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,10 +343,12 @@ void bindReportReader(py::module& m, const std::string& prefix) {
"A population inside a ReportReader")
.def("get",
&ReportType::Population::get,
"Return reports with all those node_ids between 'tstart' and 'tstop'",
"Return reports with all those node_ids between 'tstart' and 'tstop' with a stride "
"tstride",
"node_ids"_a = nonstd::nullopt,
"tstart"_a = nonstd::nullopt,
"tstop"_a = nonstd::nullopt)
"tstop"_a = nonstd::nullopt,
"tstride"_a = nonstd::nullopt)
.def("get_node_ids",
&ReportType::Population::getNodeIds,
"Return the list of nodes ids for this population")
Expand Down
12 changes: 8 additions & 4 deletions python/generated/docstrings.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,12 +255,16 @@ R"doc(Parameter ``node_ids``:
limit the report to the given selection.

Parameter ``tstart``:
return spikes occurring on or after tstart. tstart=nonstd::nullopt
indicates no limit.
return voltages occurring on or after tstart.
tstart=nonstd::nullopt indicates no limit.

Parameter ``tstop``:
return spikes occurring on or before tstop. tstop=nonstd::nullopt
indicates no limit.)doc";
return voltages occurring on or before tstop.
tstop=nonstd::nullopt indicates no limit.

Parameter ``tstride``:
indicates every how many timesteps we read data.
tstride=nonstd::nullopt indicates that all timesteps are read.)doc";

static const char *__doc_bbp_sonata_ReportReader_Population_getDataUnits = R"doc(Return the unit of data.)doc";

Expand Down
15 changes: 12 additions & 3 deletions python/tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,13 +275,21 @@ def test_get_reports_from_population(self):
self.assertEqual(self.test_obj['All'].times, (0., 1., 0.1))
self.assertEqual(self.test_obj['All'].time_units, 'ms')
self.assertEqual(self.test_obj['All'].data_units, 'mV')
self.assertTrue(self.test_obj['All'].sorted)
self.assertFalse(self.test_obj['All'].sorted)
self.assertEqual(len(self.test_obj['All'].get().ids), 20) # Number of nodes
self.assertEqual(len(self.test_obj['All'].get().times), 10) # number of times
self.assertEqual(len(self.test_obj['All'].get().data), 10) # should be the same

sel = self.test_obj['All'].get(node_ids=[13, 14], tstart=0.8, tstop=1.0)
self.assertEqual(len(sel.times), 2) # Number of timestamp (0.8 and 0.9)
self.assertEqual(list(sel.ids), [13, 14])
np.testing.assert_allclose(sel.data, [[13.8, 14.8], [13.9, 14.9]])

sel_all = self.test_obj['All'].get()
jorblancoa marked this conversation as resolved.
Show resolved Hide resolved
self.assertEqual(sel_all.ids, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20])

sel_empty = self.test_obj['All'].get(node_ids=[])
np.testing.assert_allclose(sel_empty.data, np.empty(shape=(0, 0)))

class TestElementReportPopulation(unittest.TestCase):
def setUp(self):
Expand All @@ -308,8 +316,8 @@ def test_get_reports_from_population(self):
self.assertEqual(self.test_obj['All'].time_units, 'ms')
self.assertEqual(self.test_obj['All'].data_units, 'mV')
self.assertTrue(self.test_obj['All'].sorted)
self.assertEqual(len(self.test_obj['All'].get().data), 20) # Number of times in this range
self.assertEqual(len(self.test_obj['All'].get().times), 20) # Should be the same
self.assertEqual(len(self.test_obj['All'].get(tstride=2).data), 10) # Number of times in this range
self.assertEqual(len(self.test_obj['All'].get(tstride=2).times), 10) # Should be the same
self.assertEqual(len(self.test_obj['All'].get().ids), 100)
sel = self.test_obj['All'].get(node_ids=[13, 14], tstart=0.8, tstop=1.2)
keys = list(sel.ids)
Expand All @@ -327,6 +335,7 @@ def test_get_reports_from_population(self):
# check following calls succeed (no memory destroyed)
np.testing.assert_allclose(self.test_obj['All'].get(node_ids=[1, 2], tstart=3., tstop=3.).data[0], [150.0, 150.1, 150.2, 150.3, 150.4, 150.5, 150.6, 150.7, 150.8, 150.9])
np.testing.assert_allclose(self.test_obj['All'].get(node_ids=[3, 4], tstart=0.2, tstop=0.4).data[0], [11.0, 11.1, 11.2, 11.3, 11.4, 11.5, 11.6, 11.7, 11.8, 11.9], 1e-6, 0)
np.testing.assert_allclose(self.test_obj['All'].get(node_ids=[3, 4], tstride=4).data[2], [81.0, 81.1, 81.2, 81.3, 81.4, 81.5, 81.6, 81.7, 81.8, 81.9], 1e-6, 0)

if __name__ == '__main__':
unittest.main()
99 changes: 48 additions & 51 deletions src/report_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,8 @@ ReportReader<T>::Population::Population(const H5::File& file, const std::string&
mapping_group.getDataSet("index_pointers").read(index_pointers);

for (size_t i = 0; i < nodes_ids_.size(); ++i) {
nodes_pointers_.emplace_back(nodes_ids_[i],
std::make_pair(index_pointers[i], index_pointers[i + 1]));
nodes_pointers_.emplace(nodes_ids_[i],
std::make_pair(index_pointers[i], index_pointers[i + 1]));
}

{ // Get times
Expand Down Expand Up @@ -315,18 +315,21 @@ std::pair<size_t, size_t> ReportReader<T>::Population::getIndex(
template <typename T>
DataFrame<T> ReportReader<T>::Population::get(const nonstd::optional<Selection>& selection,
const nonstd::optional<double>& tstart,
const nonstd::optional<double>& tstop) const {
const nonstd::optional<double>& tstop,
const nonstd::optional<size_t>& tstride) const {
DataFrame<T> data_frame;

size_t index_start = 0;
size_t index_stop = 0;
std::tie(index_start, index_stop) = getIndex(tstart, tstop);

const size_t stride = tstride.value_or(1);
if (stride == 0) {
throw SonataError("tstride should be > 0");
}
if (index_start > index_stop) {
throw SonataError("tstart should be <= to tstop");
}

for (size_t i = index_start; i <= index_stop; ++i) {
for (size_t i = index_start; i <= index_stop; i += stride) {
data_frame.times.push_back(times_index_[i].second);
}

Expand All @@ -337,10 +340,11 @@ DataFrame<T> ReportReader<T>::Population::get(const nonstd::optional<Selection>&
Selection::Values node_ids;

if (!selection) { // Take all nodes in this case
node_ids.reserve(nodes_pointers_.size());
std::transform(nodes_pointers_.begin(),
nodes_pointers_.end(),
std::back_inserter(node_ids),
[](const std::pair<NodeID, std::pair<uint64_t, uint64_t>>& node_pointer) {
[](const std::pair<NodeID, Range>& node_pointer) {
return node_pointer.first;
});
} else if (selection->empty()) {
Expand All @@ -349,22 +353,24 @@ DataFrame<T> ReportReader<T>::Population::get(const nonstd::optional<Selection>&
node_ids = selection->flatten();
sergiorg-hpc marked this conversation as resolved.
Show resolved Hide resolved
sergiorg-hpc marked this conversation as resolved.
Show resolved Hide resolved
}

Ranges positions;
// Track the smallest start offset and the largest end offset among the requested
// node_ids, so that only the [min, max) column span is read from the file per timestep
uint64_t min = std::numeric_limits<uint64_t>::max();
uint64_t max = std::numeric_limits<uint64_t>::min();
auto dataset_elem_ids = pop_group_.getGroup("mapping").getDataSet("element_ids");
for (const auto& node_id : node_ids) {
const auto it = std::find_if(
nodes_pointers_.begin(),
nodes_pointers_.end(),
[&node_id](const std::pair<NodeID, std::pair<NodeID, uint64_t>>& node_pointer) {
return node_pointer.first == node_id;
});
const auto it = nodes_pointers_.find(node_id);
mgeplf marked this conversation as resolved.
Show resolved Hide resolved
if (it == nodes_pointers_.end()) {
continue;
}
min = std::min(it->second.first, min);
max = std::max(it->second.second, max);
positions.emplace_back(it->second.first, it->second.second);

std::vector<ElementID> element_ids;
pop_group_.getGroup("mapping")
.getDataSet("element_ids")
.select({it->second.first}, {it->second.second - it->second.first})
.read(element_ids);
std::vector<ElementID> element_ids(it->second.second - it->second.first);
dataset_elem_ids.select({it->second.first}, {it->second.second - it->second.first})
.read(element_ids.data());
for (const auto& elem : element_ids) {
data_frame.ids.push_back(make_key<T>(node_id, elem));
}
Expand All @@ -374,43 +380,34 @@ DataFrame<T> ReportReader<T>::Population::get(const nonstd::optional<Selection>&
}

// Fill .data member

auto n_time_entries = index_stop - index_start + 1;
auto n_ids = data_frame.ids.size();
size_t n_time_entries = ((index_stop - index_start) / stride) + 1;
size_t n_ids = data_frame.ids.size();
data_frame.data.resize(n_time_entries * n_ids);

// FIXME: It will be good to do it for ranges but if node_ids are not sorted it is not easy
// TODO: specialized this function for sorted node_ids?
int ids_index = 0;
for (const auto& node_id : node_ids) {
const auto it = std::find_if(
nodes_pointers_.begin(),
nodes_pointers_.end(),
[&node_id](const std::pair<NodeID, std::pair<uint64_t, uint64_t>>& node_pointer) {
return node_pointer.first == node_id;
});
if (it == nodes_pointers_.end()) {
continue;
}

// elems are by timestamp and by Nodes_id
std::vector<std::vector<float>> data;
pop_group_.getDataSet("data")
.select({index_start, it->second.first},
{index_stop - index_start + 1, it->second.second - it->second.first})
.read(data);

int timer_index = 0;

for (const std::vector<float>& datum : data) {
std::copy(datum.data(),
datum.data() + datum.size(),
&data_frame.data[timer_index * n_ids + ids_index]);
++timer_index;
std::vector<float> buffer(max - min);
auto dataset = pop_group_.getDataSet("data");
sergiorg-hpc marked this conversation as resolved.
Show resolved Hide resolved
mgeplf marked this conversation as resolved.
Show resolved Hide resolved
for (size_t timer_index = index_start; timer_index <= index_stop; timer_index += stride) {
// Note: The code assumes that the file is chunked by rows and not by columns
// (i.e., if the chunking changes in the future, the reading method must also be adapted)
dataset.select({timer_index, min}, {1, max - min}).read(buffer.data());
sergiorg-hpc marked this conversation as resolved.
Show resolved Hide resolved

off_t offset = 0;
off_t data_offset = (timer_index - index_start) / stride;
auto data_ptr = &data_frame.data[data_offset * n_ids];
for (const auto& position : positions) {
uint64_t elements_per_gid = position.second - position.first;
uint64_t gid_start = position.first - min;

// Soma report
if (elements_per_gid == 1) {
data_ptr[offset] = buffer[gid_start];
} else { // Elements report
uint64_t gid_end = position.second - min;
std::copy(&buffer[gid_start], &buffer[gid_end], &data_ptr[offset]);
}
offset += elements_per_gid;
}
ids_index += data[0].size();
}

return data_frame;
}

Expand Down
3 changes: 1 addition & 2 deletions tests/data/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def write_edges(filepath):

def write_soma_report(filepath):
population_names = ['All', 'soma1', 'soma2']
node_ids = np.arange(1, 21)
node_ids = np.concatenate((np.arange(10, 21), np.arange(1, 10)), axis=None)
index_pointers = np.arange(0, 21)
element_ids = np.zeros(20)
times = (0.0, 1.0, 0.1)
Expand All @@ -148,7 +148,6 @@ def write_soma_report(filepath):
gmapping = h5f.create_group('/report/' + population_names[0] + '/mapping')

dnodes = gmapping.create_dataset('node_ids', data=node_ids, dtype=np.uint64)
dnodes.attrs.create('sorted', data=True, dtype=np.uint8)
gmapping.create_dataset('index_pointers', data=index_pointers, dtype=np.uint64)
gmapping.create_dataset('element_ids', data=element_ids, dtype=np.uint32)
dtimes = gmapping.create_dataset('time', data=times, dtype=np.double)
Expand Down
Binary file modified tests/data/somas.h5
Binary file not shown.
16 changes: 13 additions & 3 deletions tests/test_report_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,22 @@ TEST_CASE("SomaReportReader", "[base]") {

REQUIRE(pop.getDataUnits() == "mV");

REQUIRE(pop.getSorted());
REQUIRE(pop.getSorted() == false);

REQUIRE(pop.getNodeIds() == std::vector<NodeID>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
11, 12, 13, 14, 15, 16, 17, 18, 19, 20});
REQUIRE(pop.getNodeIds() == std::vector<NodeID>{10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
20, 1, 2, 3, 4, 5, 6, 7, 8, 9});

auto data = pop.get(Selection({{3, 5}}), 0.2, 0.5);
REQUIRE(data.ids == DataFrame<NodeID>::DataType{{3, 4}});
testTimes(data.times, 0.2, 0.1, 4);
REQUIRE(data.data == std::vector<float>{3.2f, 4.2f, 3.3f, 4.3f, 3.4f, 4.4f, 3.5f, 4.5f});

auto data_all = pop.get();
REQUIRE(data_all.ids == DataFrame<NodeID>::DataType{{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
11, 12, 13, 14, 15, 16, 17, 18, 19, 20}});

auto data_empty = pop.get(Selection({}));
REQUIRE(data_empty.data == std::vector<float>{});
}

TEST_CASE("ElementReportReader limits", "[base]") {
Expand All @@ -106,6 +113,9 @@ TEST_CASE("ElementReportReader limits", "[base]") {

// Negatives times
REQUIRE_THROWS(pop.get(Selection({{1, 2}}), -1., -2.));

// Stride = 0
REQUIRE_THROWS(pop.get(Selection({{1, 2}}), 0.1, 0.2, 0));
}

TEST_CASE("ElementReportReader", "[base]") {
Expand Down