diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt index 428126512..d6b95f089 100644 --- a/examples/cpp/CMakeLists.txt +++ b/examples/cpp/CMakeLists.txt @@ -31,6 +31,7 @@ else () endif () # Add each set of examples +add_subdirectory(edit) add_subdirectory(io) add_subdirectory(queries) add_subdirectory(planning) diff --git a/examples/cpp/edit/CMakeLists.txt b/examples/cpp/edit/CMakeLists.txt new file mode 100644 index 000000000..3e4369e8a --- /dev/null +++ b/examples/cpp/edit/CMakeLists.txt @@ -0,0 +1,5 @@ +# Binaries +add_executable(transform_map transform_map.cc) +set_wavemap_target_properties(transform_map) +target_link_libraries(transform_map PUBLIC + wavemap::wavemap_core wavemap::wavemap_io) diff --git a/examples/cpp/edit/transform_map.cc b/examples/cpp/edit/transform_map.cc new file mode 100644 index 000000000..68f730acb --- /dev/null +++ b/examples/cpp/edit/transform_map.cc @@ -0,0 +1,36 @@ +#include + +#include +#include +#include + +int main(int, char**) { + // Settings + std::filesystem::path input_map_path = + "/home/victor/data/wavemaps/newer_college_cloister_10cm.wvmp"; + std::filesystem::path output_map_path = + "/home/victor/data/wavemaps/newer_college_cloister_10cm_tranformed.wvmp"; + + // Create a smart pointer that will own the loaded map + wavemap::MapBase::Ptr map_base; + + // Load the map + bool success = wavemap::io::fileToMap(input_map_path, map_base); + CHECK(success); + + // Downcast it to a concrete map type + auto map = std::dynamic_pointer_cast(map_base); + CHECK_NOTNULL(map); + + // Define a transformation that flips the map upside down, for illustration + wavemap::Transformation3D T_AB; + T_AB.getRotation() = wavemap::Rotation3D{0.f, 1.f, 0.f, 0.f}; + + // Transform the map + map = wavemap::edit::transform(*map, T_AB, + std::make_shared()); + + // Save the map + success &= wavemap::io::mapToFile(*map, output_map_path); + CHECK(success); +} diff --git a/library/cpp/include/wavemap/core/utils/edit/sample.h b/library/cpp/include/wavemap/core/utils/edit/sample.h index 7a7c6a646..0de5b9dc5 100644 --- a/library/cpp/include/wavemap/core/utils/edit/sample.h +++ b/library/cpp/include/wavemap/core/utils/edit/sample.h @@ -2,6 +2,7 @@ #define WAVEMAP_CORE_UTILS_EDIT_SAMPLE_H_ #include +#include #include "wavemap/core/common.h" #include "wavemap/core/utils/thread_pool.h" @@ -11,7 +12,7 @@ namespace detail { template void sampleLeavesBatch(typename MapT::Block::OctreeType::NodeRefType node, const OctreeIndex& node_index, FloatingPoint& node_value, - SamplingFn sampling_function, + SamplingFn&& sampling_function, FloatingPoint min_cell_width) { // Decompress child values using Transform = typename MapT::Block::Transform; @@ -24,7 +25,7 @@ void sampleLeavesBatch(typename MapT::Block::OctreeType::NodeRefType node, const OctreeIndex child_index = node_index.computeChildIndex(child_idx); const Point3D t_W_child = convert::nodeIndexToCenterPoint(child_index, min_cell_width); - child_values[child_idx] = std::invoke(sampling_function, t_W_child); + child_values[child_idx] = sampling_function(t_W_child); } // Compress @@ -38,9 +39,9 @@ template void sampleNodeRecursive(typename MapT::Block::OctreeType::NodeRefType node, const OctreeIndex& node_index, FloatingPoint& node_value, - SamplingFn sampling_function, + SamplingFn&& sampling_function, FloatingPoint min_cell_width, - IndexElement termination_height) { + IndexElement termination_height = 0) { using NodeRefType = decltype(node); // Decompress child values @@ -56,11 +57,12 @@ void sampleNodeRecursive(typename MapT::Block::OctreeType::NodeRefType node, auto& child_value = child_values[child_idx]; if (child_index.height <= termination_height + 1) { sampleLeavesBatch(child_node, child_index, child_value, - sampling_function, min_cell_width); + std::forward(sampling_function), + min_cell_width); } else { sampleNodeRecursive(child_node, child_index, child_value, - sampling_function, min_cell_width, - termination_height); + std::forward(sampling_function), + min_cell_width, termination_height); } } @@ -84,6 +86,7 @@ void sample(MapT& map, SamplingFn sampling_function, const Index3D& block_index, auto& block) { // Indicate that the block has changed block.setLastUpdatedStamp(); + block.setNeedsPruning(); // Get pointers to the root value and node, which contain the wavelet // scale and detail coefficients, respectively diff --git a/library/cpp/include/wavemap/core/utils/edit/transform.h b/library/cpp/include/wavemap/core/utils/edit/transform.h index cd3325a0d..32dc8ed2c 100644 --- a/library/cpp/include/wavemap/core/utils/edit/transform.h +++ b/library/cpp/include/wavemap/core/utils/edit/transform.h @@ -15,18 +15,22 @@ #include "wavemap/core/utils/thread_pool.h" namespace wavemap::edit { -template -MapT transform(const MapT& B_map, const Transformation3D& T_AB, - const std::shared_ptr& thread_pool = nullptr) { +template +std::unique_ptr transform( + const MapT& B_map, const Transformation3D& T_AB, + const std::shared_ptr& thread_pool = nullptr) { + using NodePtrType = typename MapT::Block::OctreeType::NodePtrType; const IndexElement tree_height = B_map.getTreeHeight(); const FloatingPoint min_cell_width = B_map.getMinCellWidth(); - const FloatingPoint min_cell_width_inv = 1.f / min_cell_width; + const FloatingPoint block_width = + convert::heightToCellWidth(min_cell_width, tree_height); + const FloatingPoint block_width_inv = 1.f / block_width; // Allocate blocks in the result map - MapT A_map{B_map.getConfig()}; + auto A_map = std::make_unique(B_map.getConfig()); B_map.forEachBlock([&A_map, &T_AB, tree_height, min_cell_width, - min_cell_width_inv](const Index3D& block_index, - const auto& /*block*/) { + block_width_inv](const Index3D& block_index, + const auto& /*block*/) { AABB A_aabb{}; const auto B_block_aabb = convert::nodeIndexToAABB<3>({tree_height, block_index}, min_cell_width); @@ -36,23 +40,61 @@ MapT transform(const MapT& B_map, const Transformation3D& T_AB, A_aabb.includePoint(A_corner); } const Index3D A_block_index_min = - convert::pointToFloorIndex(A_aabb.min, min_cell_width_inv); + convert::pointToFloorIndex(A_aabb.min, block_width_inv); const Index3D A_block_index_max = - convert::pointToCeilIndex(A_aabb.max, min_cell_width_inv); + convert::pointToCeilIndex(A_aabb.max, block_width_inv); for (const auto& A_block_index : - Grid<3>(A_block_index_min, A_block_index_max)) { - A_map.getOrAllocateBlock(A_block_index); + Grid(A_block_index_min, A_block_index_max)) { + A_map->getOrAllocateBlock(A_block_index); } }); // Populate map A by interpolating map B - sample( - A_map, - [&B_map, &T_AB](const Point3D& A_point) { - const auto B_point = T_AB * A_point; - return interpolate::trilinear(B_map, B_point); - }, - std::move(thread_pool)); + const Transformation3D T_BA = T_AB.inverse(); + QueryAccelerator B_accelerator{B_map}; + A_map->forEachBlock( + [&B_accelerator, &T_BA, &thread_pool, tree_height, min_cell_width]( + const Index3D& block_index, auto& block) { + // Indicate that the block has changed + block.setLastUpdatedStamp(); + block.setNeedsPruning(); + + // Get pointers to the root value and node, which contain the wavelet + // scale and detail coefficients, respectively + FloatingPoint* root_value_ptr = &block.getRootScale(); + NodePtrType root_node_ptr = &block.getRootNode(); + const OctreeIndex root_node_index{tree_height, block_index}; + + // Recursively crop all nodes + if (thread_pool) { + thread_pool->add_task([B_accelerator, &T_BA, root_node_ptr, + root_node_index, root_value_ptr, + block_ptr = &block, min_cell_width]() mutable { + detail::sampleNodeRecursive( + *root_node_ptr, root_node_index, *root_value_ptr, + [&B_accelerator, &T_BA](const Point3D& A_point) { + const auto B_point = T_BA * A_point; + return interpolate::trilinear(B_accelerator, B_point); + }, + min_cell_width); + block_ptr->prune(); + }); + } else { + detail::sampleNodeRecursive( + *root_node_ptr, root_node_index, *root_value_ptr, + [&B_accelerator, &T_BA](const Point3D& A_point) { + const auto B_point = T_BA * A_point; + return interpolate::trilinear(B_accelerator, B_point); + }, + min_cell_width); + block.prune(); + } + }); + + // Wait for all parallel jobs to finish + if (thread_pool) { + thread_pool->wait_all(); + } return A_map; }