Skip to content

Commit

Permalink
Expose methods to sum shapes into map through Python API
Browse files Browse the repository at this point in the history
  • Loading branch information
victorreijgwart committed Dec 16, 2024
1 parent 759669b commit 68bf7a9
Show file tree
Hide file tree
Showing 11 changed files with 136 additions and 34 deletions.
9 changes: 9 additions & 0 deletions examples/python/edit/sum_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@
# Merge them together
wave.edit.sum(your_map, your_map_translated)

# Set a box in the map to free
box = wave.AABB(min=np.array([6.0, 6.0, -2.0]),
max=np.array([10.0, 10.0, 2.0]))
wave.edit.sum(your_map, box, -1.0)

# Set a sphere in the map to occupied
sphere = wave.Sphere(center=np.array([8.0, 8.0, 0.0]), radius=1.5)
wave.edit.sum(your_map, sphere, 2.0)

# Save the map
output_map_path = os.path.join(user_home, "your_map_merged.wvmp")
your_map.store(output_map_path)
44 changes: 22 additions & 22 deletions library/cpp/include/wavemap/core/utils/geometry/aabb.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,56 +16,56 @@ struct AABB {
static constexpr int kDim = dim_v<PointT>;
static constexpr int kNumCorners = int_math::exp2(kDim);
using PointType = PointT;
using ScalarType = typename PointType::Scalar;
using ScalarType = typename PointT::Scalar;
using Corners = Eigen::Matrix<ScalarType, kDim, kNumCorners>;

static constexpr auto kInitialMin = std::numeric_limits<ScalarType>::max();
static constexpr auto kInitialMax = std::numeric_limits<ScalarType>::lowest();

PointType min = PointType::Constant(kInitialMin);
PointType max = PointType::Constant(kInitialMax);
PointT min = PointT::Constant(kInitialMin);
PointT max = PointT::Constant(kInitialMax);

AABB() = default;
AABB(const PointT& min, const PointT& max) : min(min), max(max) {}
AABB(PointT&& min, PointT&& max) : min(std::move(min)), max(std::move(max)) {}

void insert(const PointType& point) {
void insert(const PointT& point) {
min = min.cwiseMin(point);
max = max.cwiseMax(point);
}
bool contains(const PointType& point) const {
bool contains(const PointT& point) const {
return (min.array() <= point.array() && point.array() <= max.array()).all();
}

PointType closestPointTo(const PointType& point) const {
PointType closest_point = point.cwiseMax(min).cwiseMin(max);
PointT closestPointTo(const PointT& point) const {
PointT closest_point = point.cwiseMax(min).cwiseMin(max);
return closest_point;
}
PointType furthestPointFrom(const PointType& point) const {
const PointType aabb_center = (min + max) / static_cast<ScalarType>(2);
PointType furthest_point =
PointT furthestPointFrom(const PointT& point) const {
const PointT aabb_center = (min + max) / static_cast<ScalarType>(2);
PointT furthest_point =
(aabb_center.array() < point.array()).select(min, max);
return furthest_point;
}

PointType minOffsetTo(const PointType& point) const {
PointT minOffsetTo(const PointT& point) const {
return point - closestPointTo(point);
}
PointType maxOffsetTo(const PointType& point) const {
PointT maxOffsetTo(const PointT& point) const {
return point - furthestPointFrom(point);
}
// TODO(victorr): Check correctness with unit tests
PointType minOffsetTo(const AABB& other) const {
const PointType greatest_min = min.cwiseMax(other.min);
const PointType smallest_max = max.cwiseMin(other.max);
PointT minOffsetTo(const AABB& other) const {
const PointT greatest_min = min.cwiseMax(other.min);
const PointT smallest_max = max.cwiseMin(other.max);
return (greatest_min - smallest_max).cwiseMax(0);
}
// TODO(victorr): Check correctness with unit tests. Pay particular
// attention to whether the offset signs are correct.
PointType maxOffsetTo(const AABB& other) const {
const PointType diff_1 = min - other.max;
const PointType diff_2 = max - other.min;
PointType offset =
PointT maxOffsetTo(const AABB& other) const {
const PointT diff_1 = min - other.max;
const PointT diff_2 = max - other.min;
PointT offset =
(diff_2.array().abs() < diff_1.array().abs()).select(diff_1, diff_2);
return offset;
}
Expand All @@ -92,7 +92,7 @@ struct AABB {
ScalarType width() const {
return max[dim] - min[dim];
}
PointType widths() const { return max - min; }
PointT widths() const { return max - min; }

Corners corner_matrix() const {
Eigen::Matrix<ScalarType, kDim, kNumCorners> corners;
Expand All @@ -104,8 +104,8 @@ struct AABB {
return corners;
}

PointType corner_point(int corner_idx) const {
PointType corner;
PointT corner_point(int corner_idx) const {
PointT corner;
for (int dim_idx = 0; dim_idx < kDim; ++dim_idx) {
corner[dim_idx] = corner_coordinate(dim_idx, corner_idx);
}
Expand Down
14 changes: 9 additions & 5 deletions library/cpp/include/wavemap/core/utils/geometry/sphere.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ template <typename PointT>
struct Sphere {
static constexpr int kDim = dim_v<PointT>;
using PointType = PointT;
using ScalarType = typename PointType::Scalar;
using ScalarType = typename PointT::Scalar;

PointType center;
ScalarType radius;
PointT center = PointT::Constant(kNaN);
ScalarType radius = static_cast<ScalarType>(0);

Sphere() = default;
Sphere(const PointT& center, ScalarType radius)
Expand All @@ -25,10 +25,14 @@ struct Sphere {
: center(std::move(center)), radius(radius) {}

operator AABB<PointT>() const {
return AABB<PointT>(center.array() - radius, center.array() + radius);
if (std::isnan(center[0])) {
return {};
}
return {center.array() - radius, center.array() + radius};
}

bool contains(const PointType& point) const {
// TODO(victorr): Add tests, incl. behavior after default construction
bool contains(const PointT& point) const {
return (point - center).squaredNorm() <= radius * radius;
}

Expand Down
1 change: 1 addition & 0 deletions library/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ nanobind_add_module(_pywavemap_bindings STABLE_ABI
src/pywavemap.cc
src/convert.cc
src/edit.cc
src/geometry.cc
src/indices.cc
src/logging.cc
src/maps.cc
Expand Down
12 changes: 12 additions & 0 deletions library/python/include/pywavemap/geometry.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#ifndef PYWAVEMAP_GEOMETRY_H_
#define PYWAVEMAP_GEOMETRY_H_

#include <nanobind/nanobind.h>

namespace nb = nanobind;

namespace wavemap {
void add_geometry_bindings(nb::module_& m);
}

#endif // PYWAVEMAP_GEOMETRY_H_
42 changes: 38 additions & 4 deletions library/python/src/edit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
#include <wavemap/core/utils/edit/crop.h>
#include <wavemap/core/utils/edit/multiply.h>
#include <wavemap/core/utils/edit/transform.h>
#include <wavemap/core/utils/geometry/aabb.h>
#include <wavemap/core/utils/geometry/sphere.h>

using namespace nb::literals; // NOLINT

namespace wavemap {
void add_edit_module(nb::module_& m_edit) {
// Map multiply methods
// Multiply a map with a scalar
// NOTE: Among others, this can be used to implement exponential forgetting,
// by multiplying the map with a scalar between 0 and 1.
m_edit.def(
Expand All @@ -30,7 +32,7 @@ void add_edit_module(nb::module_& m_edit) {
},
"map"_a, "multiplier"_a);

// Map sum methods
// Sum two maps together
m_edit.def(
"sum",
[](HashedWaveletOctree& map_A, const HashedWaveletOctree& map_B) {
Expand All @@ -45,7 +47,39 @@ void add_edit_module(nb::module_& m_edit) {
},
"map_A"_a, "map_B"_a);

// Map transformation methods
// Add a scalar value to all cells within an axis aligned bounding box
m_edit.def(
"sum",
[](HashedWaveletOctree& map, const AABB<Point3D>& aabb,
FloatingPoint update) {
edit::sum(map, aabb, update, std::make_shared<ThreadPool>());
},
"map"_a, "aabb"_a, "update"_a);
m_edit.def(
"sum",
[](HashedChunkedWaveletOctree& map, const AABB<Point3D>& aabb,
FloatingPoint update) {
edit::sum(map, aabb, update, std::make_shared<ThreadPool>());
},
"map"_a, "aabb"_a, "update"_a);

// Add a scalar value to all cells within a sphere
m_edit.def(
"sum",
[](HashedWaveletOctree& map, const Sphere<Point3D>& sphere,
FloatingPoint update) {
edit::sum(map, sphere, update, std::make_shared<ThreadPool>());
},
"map"_a, "sphere"_a, "update"_a);
m_edit.def(
"sum",
[](HashedChunkedWaveletOctree& map, const Sphere<Point3D>& sphere,
FloatingPoint update) {
edit::sum(map, sphere, update, std::make_shared<ThreadPool>());
},
"map"_a, "sphere"_a, "update"_a);

// Transform a map into a different coordinate frame
m_edit.def(
"transform",
[](HashedWaveletOctree& B_map, const Transformation3D& T_AB) {
Expand All @@ -59,7 +93,7 @@ void add_edit_module(nb::module_& m_edit) {
},
"map"_a, "transformation"_a);

// Map cropping methods
// Crop a map
m_edit.def(
"crop_to_sphere",
[](HashedWaveletOctree& map, const Point3D& t_W_center,
Expand Down
34 changes: 34 additions & 0 deletions library/python/src/geometry.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#include "pywavemap/geometry.h"

#include <nanobind/eigen/dense.h>
#include <wavemap/core/common.h>
#include <wavemap/core/utils/geometry/aabb.h>
#include <wavemap/core/utils/geometry/sphere.h>

using namespace nb::literals; // NOLINT

namespace wavemap {
void add_geometry_bindings(nb::module_& m) {
// Axis-Aligned Bounding Box
nb::class_<AABB<Point3D>>(
m, "AABB", "A class representing an Axis-Aligned Bounding Box.")
.def(nb::init())
.def(nb::init<Point3D, Point3D>(), "min"_a, "max"_a)
.def_rw("min", &AABB<Point3D>::min)
.def_rw("max", &AABB<Point3D>::max)
.def("insert", &AABB<Point3D>::insert,
"Expand the AABB to tightly fit the new point "
"and its previous self.")
.def("contains", &AABB<Point3D>::contains,
"Test whether the AABB contains the given point.");

// Axis-Aligned Bounding Box
nb::class_<Sphere<Point3D>>(m, "Sphere", "A class representing a sphere.")
.def(nb::init())
.def(nb::init<Point3D, FloatingPoint>(), "center"_a, "radius"_a)
.def_rw("center", &Sphere<Point3D>::center)
.def_rw("radius", &Sphere<Point3D>::radius)
.def("contains", &Sphere<Point3D>::contains,
"Test whether the sphere contains the given point.");
}
} // namespace wavemap
2 changes: 1 addition & 1 deletion library/python/src/indices.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace wavemap {
void add_index_bindings(nb::module_& m) {
nb::class_<OctreeIndex>(m, "OctreeIndex",
"A class representing indices of octree nodes.")
.def(nb::init<>())
.def(nb::init())
.def(nb::init<OctreeIndex::Element, OctreeIndex::Position>(), "height"_a,
"position"_a)
.def_rw("height", &OctreeIndex::height, "height"_a = 0,
Expand Down
7 changes: 5 additions & 2 deletions library/python/src/measurements.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ void add_measurement_bindings(nb::module_& m) {

// Pointclouds
nb::class_<Pointcloud<>>(m, "Pointcloud", "A class to store pointclouds.")
.def(nb::init<Pointcloud<>::Data>(), "point_matrix"_a);
.def(nb::init<Pointcloud<>::Data>(), "point_matrix"_a)
.def_prop_ro("size", &Pointcloud<>::size);
nb::class_<PosedPointcloud<>>(
m, "PosedPointcloud",
"A class to store pointclouds with an associated pose.")
Expand All @@ -36,7 +37,9 @@ void add_measurement_bindings(nb::module_& m) {

// Images
nb::class_<Image<>>(m, "Image", "A class to store depth images.")
.def(nb::init<Image<>::Data>(), "pixel_matrix"_a);
.def(nb::init<Image<>::Data>(), "pixel_matrix"_a)
.def_prop_ro("size", &Image<>::size)
.def_prop_ro("dimensions", &Image<>::getDimensions);
nb::class_<PosedImage<>>(
m, "PosedImage", "A class to store depth images with an associated pose.")
.def(nb::init<Transformation3D, Image<>>(), "pose"_a, "image"_a);
Expand Down
4 changes: 4 additions & 0 deletions library/python/src/pywavemap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "pywavemap/convert.h"
#include "pywavemap/edit.h"
#include "pywavemap/geometry.h"
#include "pywavemap/indices.h"
#include "pywavemap/logging.h"
#include "pywavemap/maps.h"
Expand Down Expand Up @@ -57,6 +58,9 @@ NB_MODULE(_pywavemap_bindings, m) {
// Bindings for measurement types
add_measurement_bindings(m);

// Bindings for geometric types
add_geometry_bindings(m);

// Bindings for map types
add_map_bindings(m);

Expand Down
1 change: 1 addition & 0 deletions library/python/src/pywavemap/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ._pywavemap_bindings import OctreeIndex
from ._pywavemap_bindings import (Rotation, Pose, Pointcloud, PosedPointcloud,
Image, PosedImage)
from ._pywavemap_bindings import AABB, Sphere
from ._pywavemap_bindings import (Map, HashedWaveletOctree,
HashedChunkedWaveletOctree,
InterpolationMode)
Expand Down

0 comments on commit 68bf7a9

Please sign in to comment.