Skip to content

Commit

Permalink
Improve map transform method
Browse files Browse the repository at this point in the history
  • Loading branch information
victorreijgwart committed Dec 7, 2024
1 parent dd42058 commit e87d0e5
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 25 deletions.
1 change: 1 addition & 0 deletions examples/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ else ()
endif ()

# Add each set of examples
add_subdirectory(edit)
add_subdirectory(io)
add_subdirectory(queries)
add_subdirectory(planning)
5 changes: 5 additions & 0 deletions examples/cpp/edit/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Binaries
add_executable(transform_map transform_map.cc)
set_wavemap_target_properties(transform_map)
target_link_libraries(transform_map PUBLIC
wavemap::wavemap_core wavemap::wavemap_io)
36 changes: 36 additions & 0 deletions examples/cpp/edit/transform_map.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#include <memory>

#include <wavemap/core/common.h>
#include <wavemap/core/utils/edit/transform.h>
#include <wavemap/io/file_conversions.h>

int main(int, char**) {
// Settings
std::filesystem::path input_map_path =
"/home/victor/data/wavemaps/newer_college_cloister_10cm.wvmp";
std::filesystem::path output_map_path =
"/home/victor/data/wavemaps/newer_college_cloister_10cm_tranformed.wvmp";

// Create a smart pointer that will own the loaded map
wavemap::MapBase::Ptr map_base;

// Load the map
bool success = wavemap::io::fileToMap(input_map_path, map_base);
CHECK(success);

// Downcast it to a concrete map type
auto map = std::dynamic_pointer_cast<wavemap::HashedWaveletOctree>(map_base);
CHECK_NOTNULL(map);

// Define a transformation that flips the map upside down, for illustration
wavemap::Transformation3D T_AB;
T_AB.getRotation() = wavemap::Rotation3D{0.f, 1.f, 0.f, 0.f};

// Transform the map
map = wavemap::edit::transform(*map, T_AB,
std::make_shared<wavemap::ThreadPool>());

// Save the map
success &= wavemap::io::mapToFile(*map, output_map_path);
CHECK(success);
}
17 changes: 10 additions & 7 deletions library/cpp/include/wavemap/core/utils/edit/sample.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define WAVEMAP_CORE_UTILS_EDIT_SAMPLE_H_

#include <memory>
#include <utility>

#include "wavemap/core/common.h"
#include "wavemap/core/utils/thread_pool.h"
Expand All @@ -11,7 +12,7 @@ namespace detail {
template <typename MapT, typename SamplingFn>
void sampleLeavesBatch(typename MapT::Block::OctreeType::NodeRefType node,
const OctreeIndex& node_index, FloatingPoint& node_value,
SamplingFn sampling_function,
SamplingFn&& sampling_function,
FloatingPoint min_cell_width) {
// Decompress child values
using Transform = typename MapT::Block::Transform;
Expand All @@ -24,7 +25,7 @@ void sampleLeavesBatch(typename MapT::Block::OctreeType::NodeRefType node,
const OctreeIndex child_index = node_index.computeChildIndex(child_idx);
const Point3D t_W_child =
convert::nodeIndexToCenterPoint(child_index, min_cell_width);
child_values[child_idx] = std::invoke(sampling_function, t_W_child);
child_values[child_idx] = sampling_function(t_W_child);
}

// Compress
Expand All @@ -38,9 +39,9 @@ template <typename MapT, typename SamplingFn>
void sampleNodeRecursive(typename MapT::Block::OctreeType::NodeRefType node,
const OctreeIndex& node_index,
FloatingPoint& node_value,
SamplingFn sampling_function,
SamplingFn&& sampling_function,
FloatingPoint min_cell_width,
IndexElement termination_height) {
IndexElement termination_height = 0) {
using NodeRefType = decltype(node);

// Decompress child values
Expand All @@ -56,11 +57,12 @@ void sampleNodeRecursive(typename MapT::Block::OctreeType::NodeRefType node,
auto& child_value = child_values[child_idx];
if (child_index.height <= termination_height + 1) {
sampleLeavesBatch<MapT>(child_node, child_index, child_value,
sampling_function, min_cell_width);
std::forward<SamplingFn>(sampling_function),
min_cell_width);
} else {
sampleNodeRecursive<MapT>(child_node, child_index, child_value,
sampling_function, min_cell_width,
termination_height);
std::forward<SamplingFn>(sampling_function),
min_cell_width, termination_height);
}
}

Expand All @@ -84,6 +86,7 @@ void sample(MapT& map, SamplingFn sampling_function,
const Index3D& block_index, auto& block) {
// Indicate that the block has changed
block.setLastUpdatedStamp();
block.setNeedsPruning();

// Get pointers to the root value and node, which contain the wavelet
// scale and detail coefficients, respectively
Expand Down
78 changes: 60 additions & 18 deletions library/cpp/include/wavemap/core/utils/edit/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,22 @@
#include "wavemap/core/utils/thread_pool.h"

namespace wavemap::edit {
template <typename MapT = HashedWaveletOctree>
MapT transform(const MapT& B_map, const Transformation3D& T_AB,
const std::shared_ptr<ThreadPool>& thread_pool = nullptr) {
template <typename MapT>
std::unique_ptr<MapT> transform(
const MapT& B_map, const Transformation3D& T_AB,
const std::shared_ptr<ThreadPool>& thread_pool = nullptr) {
using NodePtrType = typename MapT::Block::OctreeType::NodePtrType;
const IndexElement tree_height = B_map.getTreeHeight();
const FloatingPoint min_cell_width = B_map.getMinCellWidth();
const FloatingPoint min_cell_width_inv = 1.f / min_cell_width;
const FloatingPoint block_width =
convert::heightToCellWidth(min_cell_width, tree_height);
const FloatingPoint block_width_inv = 1.f / block_width;

// Allocate blocks in the result map
MapT A_map{B_map.getConfig()};
auto A_map = std::make_unique<MapT>(B_map.getConfig());
B_map.forEachBlock([&A_map, &T_AB, tree_height, min_cell_width,
min_cell_width_inv](const Index3D& block_index,
const auto& /*block*/) {
block_width_inv](const Index3D& block_index,
const auto& /*block*/) {
AABB<Point3D> A_aabb{};
const auto B_block_aabb =
convert::nodeIndexToAABB<3>({tree_height, block_index}, min_cell_width);
Expand All @@ -36,23 +40,61 @@ MapT transform(const MapT& B_map, const Transformation3D& T_AB,
A_aabb.includePoint(A_corner);
}
const Index3D A_block_index_min =
convert::pointToFloorIndex(A_aabb.min, min_cell_width_inv);
convert::pointToFloorIndex(A_aabb.min, block_width_inv);
const Index3D A_block_index_max =
convert::pointToCeilIndex(A_aabb.max, min_cell_width_inv);
convert::pointToCeilIndex(A_aabb.max, block_width_inv);
for (const auto& A_block_index :
Grid<3>(A_block_index_min, A_block_index_max)) {
A_map.getOrAllocateBlock(A_block_index);
Grid(A_block_index_min, A_block_index_max)) {
A_map->getOrAllocateBlock(A_block_index);
}
});

// Populate map A by interpolating map B
sample(
A_map,
[&B_map, &T_AB](const Point3D& A_point) {
const auto B_point = T_AB * A_point;
return interpolate::trilinear(B_map, B_point);
},
std::move(thread_pool));
const Transformation3D T_BA = T_AB.inverse();
QueryAccelerator B_accelerator{B_map};
A_map->forEachBlock(
[&B_accelerator, &T_BA, &thread_pool, tree_height, min_cell_width](
const Index3D& block_index, auto& block) {
// Indicate that the block has changed
block.setLastUpdatedStamp();
block.setNeedsPruning();

// Get pointers to the root value and node, which contain the wavelet
// scale and detail coefficients, respectively
FloatingPoint* root_value_ptr = &block.getRootScale();
NodePtrType root_node_ptr = &block.getRootNode();
const OctreeIndex root_node_index{tree_height, block_index};

// Recursively crop all nodes
if (thread_pool) {
thread_pool->add_task([B_accelerator, &T_BA, root_node_ptr,
root_node_index, root_value_ptr,
block_ptr = &block, min_cell_width]() mutable {
detail::sampleNodeRecursive<MapT>(
*root_node_ptr, root_node_index, *root_value_ptr,
[&B_accelerator, &T_BA](const Point3D& A_point) {
const auto B_point = T_BA * A_point;
return interpolate::trilinear(B_accelerator, B_point);
},
min_cell_width);
block_ptr->prune();
});
} else {
detail::sampleNodeRecursive<MapT>(
*root_node_ptr, root_node_index, *root_value_ptr,
[&B_accelerator, &T_BA](const Point3D& A_point) {
const auto B_point = T_BA * A_point;
return interpolate::trilinear(B_accelerator, B_point);
},
min_cell_width);
block.prune();
}
});

// Wait for all parallel jobs to finish
if (thread_pool) {
thread_pool->wait_all();
}

return A_map;
}
Expand Down

0 comments on commit e87d0e5

Please sign in to comment.