diff --git a/include/bbp/sonata/report_reader.h b/include/bbp/sonata/report_reader.h index 0fa7c255..6f0ddea0 100644 --- a/include/bbp/sonata/report_reader.h +++ b/include/bbp/sonata/report_reader.h @@ -98,13 +98,14 @@ class SONATA_API SpikeReader mutable std::map populations_; }; + using Range = std::pair; + using Ranges = std::vector; + using NodePointers = std::map; + template class SONATA_API ReportReader { public: - using Range = std::pair; - using Ranges = std::vector; - class Population { public: @@ -147,8 +148,7 @@ class SONATA_API ReportReader * \param fn lambda applied to all ranges for all node ids */ typename DataFrame::DataType getNodeIdElementIdMapping( - const nonstd::optional& node_ids = nonstd::nullopt, - std::function fn = nullptr) const; + const nonstd::optional& node_ids = nonstd::nullopt) const; /** * \param node_ids limit the report to the given selection. @@ -167,7 +167,7 @@ class SONATA_API ReportReader std::pair getIndex(const nonstd::optional& tstart, const nonstd::optional& tstop) const; - std::map nodes_pointers_; + NodePointers nodes_pointers_; H5::Group pop_group_; std::vector nodes_ids_; double tstart_, tstop_, tstep_; @@ -175,8 +175,10 @@ class SONATA_API ReportReader std::string time_units_; std::string data_units_; bool nodes_ids_sorted_ = false; - Selection::Values node_ids_from_selection( + std::pair node_pointers_from_selection( const nonstd::optional& node_ids = nonstd::nullopt) const; + typename DataFrame::DataType ids_from_node_pointers( + const std::pair& result) const; friend ReportReader; }; diff --git a/python/bindings.cpp b/python/bindings.cpp index 40514f9b..1bb30c3e 100644 --- a/python/bindings.cpp +++ b/python/bindings.cpp @@ -368,7 +368,7 @@ void bindReportReader(py::module& m, const std::string& prefix) { "get_node_id_element_id_mapping", [](const typename ReportType::Population& population, const nonstd::optional& selection) { - return population.getNodeIdElementIdMapping(selection, nullptr); + return population.getNodeIdElementIdMapping(selection); }, DOC_REPORTREADER_POP(getNodeIdElementIdMapping), "selection"_a = nonstd::nullopt) diff --git a/src/report_reader.cpp b/src/report_reader.cpp index 3723c7b2..033d7589 100644 --- a/src/report_reader.cpp +++ b/src/report_reader.cpp @@ -18,6 +18,7 @@ namespace { using bbp::sonata::CompartmentID; using bbp::sonata::ElementID; using bbp::sonata::NodeID; +using bbp::sonata::NodePointers; using bbp::sonata::Selection; using bbp::sonata::Spike; using bbp::sonata::Spikes; @@ -284,24 +285,70 @@ std::vector ReportReader::Population::getNodeIds() const { } template -Selection::Values ReportReader::Population::node_ids_from_selection( +std::pair ReportReader::Population::node_pointers_from_selection( const nonstd::optional& selection) const { - Selection::Values node_ids; - - if (!selection) { // Take all nodes in this case - node_ids.reserve(nodes_pointers_.size()); - std::transform(nodes_pointers_.begin(), - nodes_pointers_.end(), - std::back_inserter(node_ids), - [](const std::pair& node_pointer) { - return node_pointer.first; - }); - } else if (selection->empty()) { - return {}; - } else { - node_ids = selection->flatten(); + NodePointers node_pointers; + Range range = { std::numeric_limits::max(), + std::numeric_limits::min() }; + + const auto& update_range = [&range](const Range range_) { + range.first = std::min(range.first, range_.first); + range.second = std::max(range.second, range_.second); + }; + + // Take all nodes if no selection is provided + if (!selection) { + node_pointers = nodes_pointers_; + + for (const auto& node_pointer : node_pointers) { + update_range(node_pointer.second); + } + } + else if (!selection->empty()) { + const auto& node_ids = selection->flatten(); + + for (const auto& node_id : node_ids) { + const auto it = nodes_pointers_.find(node_id); + if (it != nodes_pointers_.end()) { + node_pointers.emplace(*it); + update_range(it->second); + } + } } - return node_ids; + + return { node_pointers, range }; +} + +template +typename DataFrame::DataType ReportReader::Population::ids_from_node_pointers( + const std::pair& result) const { + typename DataFrame::DataType ids{}; + + const auto& node_pointers = result.first; + const auto& min = result.second.first; + const auto& max = result.second.second; + + // typename DataFrame::DataType ids{}; + // ids.reserve(nids); + + if (!node_pointers.empty()) + { + std::vector element_ids; + auto dataset_elem_ids = pop_group_.getGroup("mapping").getDataSet("element_ids"); + dataset_elem_ids.select({min}, {max - min}).read(element_ids); + + for (const auto& node_pointer : node_pointers) { + const auto& node_id = node_pointer.first; + const auto& range = node_pointer.second; + + for (auto i = (range.first - min); i < (range.second - min); i++) + { + ids.emplace_back(make_key(node_id, element_ids[i])); + } + } + } + + return ids; } template @@ -341,30 +388,9 @@ std::pair ReportReader::Population::getIndex( template typename DataFrame::DataType ReportReader::Population::getNodeIdElementIdMapping( - const nonstd::optional& selection, std::function fn) const { - typename DataFrame::DataType ids{}; - - Selection::Values node_ids = node_ids_from_selection(selection); - - auto dataset_elem_ids = pop_group_.getGroup("mapping").getDataSet("element_ids"); - for (const auto& node_id : node_ids) { - const auto it = nodes_pointers_.find(node_id); - if (it == nodes_pointers_.end()) { - continue; - } - - std::vector element_ids(it->second.second - it->second.first); - dataset_elem_ids.select({it->second.first}, {it->second.second - it->second.first}) - .read(element_ids.data()); - for (const auto& elem : element_ids) { - ids.push_back(make_key(node_id, elem)); - } - - if (fn) { - fn(it->second); - } - } - return ids; + const nonstd::optional& selection) const { + const auto& result = node_pointers_from_selection(selection); + return ids_from_node_pointers(result); } template @@ -390,14 +416,12 @@ DataFrame ReportReader::Population::get(const nonstd::optional& // min and max offsets of the node_ids requested are calculated // to reduce the amount of IO that is brought to memory - Ranges positions; - uint64_t min = std::numeric_limits::max(); - uint64_t max = std::numeric_limits::min(); - data_frame.ids = getNodeIdElementIdMapping(selection, [&](const Range& range) { - min = std::min(range.first, min); - max = std::max(range.second, max); - positions.emplace_back(range.first, range.second); - }); + const auto& result = node_pointers_from_selection(selection); + const auto& node_pointers = result.first; + const auto& min = result.second.first; + const auto& max = result.second.second; + + data_frame.ids = ids_from_node_pointers(result); if (data_frame.ids.empty()) { // At the end no data available (wrong node_ids?) return DataFrame{{}, {}, {}}; } @@ -423,7 +447,8 @@ DataFrame ReportReader::Population::get(const nonstd::optional& off_t offset = 0; off_t data_offset = (timer_index - index_start) / stride; auto data_ptr = &data_frame.data[data_offset * n_ids]; - for (const auto& position : positions) { + for (const auto& node_pointer : node_pointers) { + const auto& position = node_pointer.second; uint64_t elements_per_gid = position.second - position.first; uint64_t gid_start = position.first - min;