diff --git a/docs/pages/tutorials/cpp.rst b/docs/pages/tutorials/cpp.rst index 43ab08f12..709f916dc 100644 --- a/docs/pages/tutorials/cpp.rst +++ b/docs/pages/tutorials/cpp.rst @@ -40,6 +40,21 @@ Code examples ************* In the following sections, you'll find sample code for common tasks. If you'd like to request examples for additional tasks or contribute new examples, please don't hesitate to `contact us `_. +Mapping +======= +The only requirements to build wavemap maps are that you have a set of + +1. depth measurements, +2. sensor pose (estimates) for each measurement. + +We usually use depth measurements from depth cameras or 3D LiDARs, but any source would work as long as a corresponding :ref:`projection ` and :ref:`measurement ` model is available. To help you get started quickly, we provide example configs for various sensor setups :gh_file:`here `. An overview of all the available settings is provided on the :doc:`parameters page <../parameters/index>`. + +Example pipeline +---------------- + +.. literalinclude:: ../../../examples/cpp/mapping/example_pipeline.cc + :language: cpp + Serializing maps ================ In this section, we'll demonstrate how to serialize and deserialize maps using wavemap's lightweight and efficient binary format. This format is consistent across wavemap's C++, Python, and ROS interfaces. For instance, you can create maps on a robot with ROS and later load them into a rendering engine plugin that only depends on wavemap's C++ library. diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt index 428126512..cac35a6e2 100644 --- a/examples/cpp/CMakeLists.txt +++ b/examples/cpp/CMakeLists.txt @@ -32,5 +32,6 @@ endif () # Add each set of examples add_subdirectory(io) +add_subdirectory(mapping) add_subdirectory(queries) add_subdirectory(planning) diff --git a/examples/cpp/mapping/.gitignore b/examples/cpp/mapping/.gitignore new file mode 100644 index 000000000..668b7e19a --- /dev/null +++ b/examples/cpp/mapping/.gitignore @@ -0,0 +1 @@ +example_map.wvmp diff --git a/examples/cpp/mapping/CMakeLists.txt b/examples/cpp/mapping/CMakeLists.txt new file mode 100644 index 000000000..4d65c6b40 --- /dev/null +++ b/examples/cpp/mapping/CMakeLists.txt @@ -0,0 +1,7 @@ +# Binaries +add_executable(example_pipeline example_pipeline.cc) +set_wavemap_target_properties(example_pipeline) +target_link_libraries(example_pipeline PUBLIC + wavemap::wavemap_core + wavemap::wavemap_io + wavemap::wavemap_pipeline) diff --git a/examples/cpp/mapping/example_config.yaml b/examples/cpp/mapping/example_config.yaml new file mode 100644 index 000000000..84b8a3700 --- /dev/null +++ b/examples/cpp/mapping/example_config.yaml @@ -0,0 +1,33 @@ +# NOTE: More examples can be found in the `interfaces/ros1/wavemap_ros/config` +# directory, and all available params are documented at: +# https://ethz-asl.github.io/wavemap/pages/parameters + +map: + type: hashed_chunked_wavelet_octree + min_cell_width: { meters: 0.02 } + +map_operations: + - type: threshold_map + once_every: { seconds: 2.0 } + - type: prune_map + once_every: { seconds: 10.0 } + +measurement_integrators: + your_camera: + projection_model: + type: pinhole_camera_projector + width: 640 + height: 480 + fx: 320.0 + fy: 320.0 + cx: 320.0 + cy: 240.0 + measurement_model: + type: continuous_ray + range_sigma: { meters: 0.01 } + scaling_free: 0.2 + scaling_occupied: 0.4 + integration_method: + type: hashed_chunked_wavelet_integrator + min_range: { meters: 0.1 } + max_range: { meters: 5.0 } diff --git a/examples/cpp/mapping/example_pipeline.cc b/examples/cpp/mapping/example_pipeline.cc new file mode 100644 index 000000000..5325786af --- /dev/null +++ b/examples/cpp/mapping/example_pipeline.cc @@ -0,0 +1,76 @@ +#include +#include + +#include +#include +#include +#include + +using namespace wavemap; +int main(int, char** argv) { + // Settings + const std::string config_name = "example_config.yaml"; + const std::string output_map_name = "example_map.wvmp"; + const std::filesystem::path current_dir = + std::filesystem::canonical(__FILE__).parent_path(); + const std::filesystem::path config_path = current_dir / config_name; + const std::filesystem::path output_map_path = current_dir / output_map_name; + + // Load the config + std::cout << "Loading config file: " << config_path << std::endl; + const auto params = io::yamlFileToParams(config_path); + CHECK(params.has_value()); + + // Create the map + const auto map_config = params->getChild("map"); + CHECK(map_config.has_value()); + MapBase::Ptr map = MapFactory::create(map_config.value()); + CHECK_NOTNULL(map); + + // Create measurement integration pipeline + Pipeline pipeline{map}; + + // Add map operations to pipeline + const auto map_operations = + params->getChildAs("map_operations"); + CHECK(map_operations); + for (const auto& operation_params : map_operations.value()) { + pipeline.addOperation(operation_params); + } + + // Add measurement integrators to pipeline + const auto measurement_integrators = + params->getChildAs("measurement_integrators"); + CHECK(measurement_integrators); + for (const auto& [integrator_name, integrator_params] : + measurement_integrators.value()) { + pipeline.addIntegrator(integrator_name, integrator_params); + } + + // Define a depth camera image for illustration + const int width = 640; + const int height = 480; + Image<> depth_image{width, height}; + depth_image.setToConstant(2.f); + + // Define a depth camera pose for illustration + Transformation3D T_W_C{}; + // Set the camera's [x, y, z] position + T_W_C.getPosition() = {0.f, 0.f, 0.f}; + // Set the camera's orientation + // For example as a quaternion's [w, x, y, z] coefficients + T_W_C.getRotation() = Rotation3D{0.5f, -0.5f, -0.5f, 0.5f}; + // NOTE: Alternatively, the rotation can also be loaded from a rotation + // matrix, or T_W_C can be initialized from a transformation matrix. + + // Integrate the measurement + pipeline.runPipeline({"your_camera"}, PosedImage<>{T_W_C, depth_image}); + + // Measure the map's size + const size_t map_size_KB = map->getMemoryUsage() / 1024; + std::cout << "Created map of size: " << map_size_KB << " KB" << std::endl; + + // Save the map to disk + std::cout << "Saving it to: " << output_map_path << std::endl; + io::mapToFile(*map, output_map_path); +} diff --git a/interfaces/ros1/wavemap_ros/CMakeLists.txt b/interfaces/ros1/wavemap_ros/CMakeLists.txt index e90bc7449..8987c6e1e 100644 --- a/interfaces/ros1/wavemap_ros/CMakeLists.txt +++ b/interfaces/ros1/wavemap_ros/CMakeLists.txt @@ -10,6 +10,9 @@ find_package(catkin REQUIRED COMPONENTS roscpp rosbag cv_bridge image_transport tf2_ros std_srvs sensor_msgs visualization_msgs) +# Optional dependencies +find_package(livox_ros_driver2 QUIET) + # Register catkin package catkin_package( INCLUDE_DIRS include @@ -19,14 +22,6 @@ catkin_package( roscpp rosbag cv_bridge image_transport tf2_ros std_srvs sensor_msgs visualization_msgs) - -# Optional dependencies -find_package(livox_ros_driver2 QUIET) -if (livox_ros_driver2_FOUND) - include_directories(${livox_ros_driver2_INCLUDE_DIRS}) - add_compile_definitions(LIVOX_AVAILABLE) -endif () - # Libraries add_library(${PROJECT_NAME} src/inputs/depth_image_topic_input.cc @@ -48,6 +43,13 @@ target_include_directories(${PROJECT_NAME} target_link_libraries(${PROJECT_NAME} PUBLIC ${catkin_LIBRARIES} ${OpenCV_LIBRARIES}) +# Optional Livox support +if (livox_ros_driver2_FOUND) + target_include_directories(${PROJECT_NAME} PUBLIC + ${livox_ros_driver2_INCLUDE_DIRS}) + target_compile_definitions(${PROJECT_NAME} PUBLIC LIVOX_AVAILABLE) +endif () + # Binaries add_executable(ros_server app/ros_server.cc) set_wavemap_target_properties(ros_server) diff --git a/interfaces/ros1/wavemap_ros/src/ros_server.cc b/interfaces/ros1/wavemap_ros/src/ros_server.cc index 3d94171a4..054d285c2 100644 --- a/interfaces/ros1/wavemap_ros/src/ros_server.cc +++ b/interfaces/ros1/wavemap_ros/src/ros_server.cc @@ -5,7 +5,7 @@ #include #include -#include +#include #include #include #include @@ -33,10 +33,10 @@ bool RosServerConfig::isValid(bool verbose) const { // NOTE: If RosServerConfig::from(...) fails, accessing its value will throw // an exception and end the program. RosServer::RosServer(ros::NodeHandle nh, ros::NodeHandle nh_private) - : RosServer(nh, nh_private, - RosServerConfig::from( - param::convert::toParamValue(nh_private, "general")) - .value()) {} + : RosServer( + nh, nh_private, + RosServerConfig::from(convert::rosToParams(nh_private, "general")) + .value()) {} RosServer::RosServer(ros::NodeHandle nh, ros::NodeHandle nh_private, const RosServerConfig& config) @@ -47,8 +47,7 @@ RosServer::RosServer(ros::NodeHandle nh, ros::NodeHandle nh_private, config_.logging_level.applyToRosConsole(); // Setup data structure - const auto data_structure_params = - param::convert::toParamValue(nh_private, "map"); + const auto data_structure_params = convert::rosToParams(nh_private, "map"); occupancy_map_ = MapFactory::create(data_structure_params, MapType::kHashedBlocks); CHECK_NOTNULL(occupancy_map_); @@ -65,14 +64,14 @@ RosServer::RosServer(ros::NodeHandle nh, ros::NodeHandle nh_private, // Add map operations to pipeline const param::Array map_operation_param_array = - param::convert::toParamArray(nh_private, "map_operations"); + convert::rosToParamArray(nh_private, "map_operations"); for (const auto& operation_params : map_operation_param_array) { addOperation(operation_params, nh_private); } // Add measurement integrators to pipeline const param::Map measurement_integrator_param_map = - param::convert::toParamMap(nh_private, "measurement_integrators"); + convert::rosToParamMap(nh_private, "measurement_integrators"); for (const auto& [integrator_name, integrator_params] : measurement_integrator_param_map) { pipeline_->addIntegrator(integrator_name, integrator_params); @@ -80,7 +79,7 @@ RosServer::RosServer(ros::NodeHandle nh, ros::NodeHandle nh_private, // Setup measurement inputs const param::Array input_param_array = - param::convert::toParamArray(nh_private, "inputs"); + convert::rosToParamArray(nh_private, "inputs"); for (const auto& integrator_params : input_param_array) { addInput(integrator_params, nh, nh_private); } diff --git a/interfaces/ros1/wavemap_ros_conversions/include/wavemap_ros_conversions/config_conversions.h b/interfaces/ros1/wavemap_ros_conversions/include/wavemap_ros_conversions/config_conversions.h index 42204b47d..0031cb376 100644 --- a/interfaces/ros1/wavemap_ros_conversions/include/wavemap_ros_conversions/config_conversions.h +++ b/interfaces/ros1/wavemap_ros_conversions/include/wavemap_ros_conversions/config_conversions.h @@ -7,14 +7,14 @@ #include #include -namespace wavemap::param::convert { -param::Map toParamMap(const ros::NodeHandle& nh, const std::string& ns); -param::Array toParamArray(const ros::NodeHandle& nh, const std::string& ns); -param::Value toParamValue(const ros::NodeHandle& nh, const std::string& ns); +namespace wavemap::convert { +param::Map xmlRpcToParamMap(const XmlRpc::XmlRpcValue& xml_rpc_value); +param::Array xmlRpcToParamArray(const XmlRpc::XmlRpcValue& xml_rpc_value); +param::Value xmlRpcToParams(const XmlRpc::XmlRpcValue& xml_rpc_value); -param::Map toParamMap(const XmlRpc::XmlRpcValue& xml_rpc_value); -param::Array toParamArray(const XmlRpc::XmlRpcValue& xml_rpc_value); -param::Value toParamValue(const XmlRpc::XmlRpcValue& xml_rpc_value); -} // namespace wavemap::param::convert +param::Map rosToParamMap(const ros::NodeHandle& nh, const std::string& ns); +param::Array rosToParamArray(const ros::NodeHandle& nh, const std::string& ns); +param::Value rosToParams(const ros::NodeHandle& nh, const std::string& ns); +} // namespace wavemap::convert #endif // WAVEMAP_ROS_CONVERSIONS_CONFIG_CONVERSIONS_H_ diff --git a/interfaces/ros1/wavemap_ros_conversions/src/config_conversions.cc b/interfaces/ros1/wavemap_ros_conversions/src/config_conversions.cc index b321cb02e..256b74211 100644 --- a/interfaces/ros1/wavemap_ros_conversions/src/config_conversions.cc +++ b/interfaces/ros1/wavemap_ros_conversions/src/config_conversions.cc @@ -2,84 +2,51 @@ #include -namespace wavemap::param::convert { -param::Map toParamMap(const ros::NodeHandle& nh, const std::string& ns) { - XmlRpc::XmlRpcValue xml_rpc_value; - if (nh.getParam(ns, xml_rpc_value)) { - return toParamMap(xml_rpc_value); - } - - ROS_WARN_STREAM("Could not load ROS params under namespace " - << nh.resolveName(ns)); - return {}; -} - -param::Array toParamArray(const ros::NodeHandle& nh, const std::string& ns) { - XmlRpc::XmlRpcValue xml_rpc_value; - if (nh.getParam(ns, xml_rpc_value)) { - return toParamArray(xml_rpc_value); - } - - ROS_WARN_STREAM("Could not load ROS params under namespace " - << nh.resolveName(ns)); - return {}; -} - -param::Value toParamValue(const ros::NodeHandle& nh, const std::string& ns) { - XmlRpc::XmlRpcValue xml_rpc_value; - if (nh.getParam(ns, xml_rpc_value)) { - return toParamValue(xml_rpc_value); - } - - ROS_WARN_STREAM("Could not load ROS params under namespace " - << nh.resolveName(ns)); - return param::Value{param::Map{}}; // Return an empty map -} - -param::Map toParamMap( // NOLINT +namespace wavemap::convert { +param::Map xmlRpcToParamMap( // NOLINT const XmlRpc::XmlRpcValue& xml_rpc_value) { if (xml_rpc_value.getType() != XmlRpc::XmlRpcValue::TypeStruct) { - ROS_WARN("Expected param map."); + ROS_WARN("Expected ROS param map."); return {}; } param::Map param_map; for (const auto& kv : xml_rpc_value) { - param_map.emplace(kv.first, toParamValue(kv.second)); + param_map.emplace(kv.first, xmlRpcToParams(kv.second)); } return param_map; } -param::Array toParamArray( // NOLINT +param::Array xmlRpcToParamArray( // NOLINT const XmlRpc::XmlRpcValue& xml_rpc_value) { if (xml_rpc_value.getType() != XmlRpc::XmlRpcValue::TypeArray) { - ROS_WARN("Expected param array."); + ROS_WARN("Expected ROS param array."); return {}; } param::Array array; array.reserve(xml_rpc_value.size()); for (int idx = 0; idx < xml_rpc_value.size(); ++idx) { // NOLINT - array.template emplace_back(toParamValue(xml_rpc_value[idx])); + array.emplace_back(xmlRpcToParams(xml_rpc_value[idx])); } return array; } -param::Value toParamValue( // NOLINT +param::Value xmlRpcToParams( // NOLINT const XmlRpc::XmlRpcValue& xml_rpc_value) { switch (xml_rpc_value.getType()) { + case XmlRpc::XmlRpcValue::TypeStruct: + return param::Value{xmlRpcToParamMap(xml_rpc_value)}; + case XmlRpc::XmlRpcValue::TypeArray: + return param::Value{xmlRpcToParamArray(xml_rpc_value)}; case XmlRpc::XmlRpcValue::TypeBoolean: - return param::Value(static_cast(xml_rpc_value)); + return param::Value{static_cast(xml_rpc_value)}; case XmlRpc::XmlRpcValue::TypeInt: - return param::Value(static_cast(xml_rpc_value)); + return param::Value{static_cast(xml_rpc_value)}; case XmlRpc::XmlRpcValue::TypeDouble: - return param::Value(static_cast(xml_rpc_value)); + return param::Value{static_cast(xml_rpc_value)}; case XmlRpc::XmlRpcValue::TypeString: - return param::Value(static_cast(xml_rpc_value)); - case XmlRpc::XmlRpcValue::TypeArray: - return param::Value(toParamArray(xml_rpc_value)); - case XmlRpc::XmlRpcValue::TypeStruct: - return param::Value(toParamMap(xml_rpc_value)); + return param::Value{static_cast(xml_rpc_value)}; case XmlRpc::XmlRpcValue::TypeInvalid: ROS_ERROR("Encountered invalid type while parsing ROS params."); break; @@ -99,6 +66,39 @@ param::Value toParamValue( // NOLINT } // On error, return an empty array - return param::Value(param::Array{}); + return param::Value{param::Array{}}; +} + +param::Map rosToParamMap(const ros::NodeHandle& nh, const std::string& ns) { + XmlRpc::XmlRpcValue xml_rpc_value; + if (nh.getParam(ns, xml_rpc_value)) { + return xmlRpcToParamMap(xml_rpc_value); + } + + ROS_WARN_STREAM("Could not load ROS params under namespace " + << nh.resolveName(ns)); + return {}; +} + +param::Array rosToParamArray(const ros::NodeHandle& nh, const std::string& ns) { + XmlRpc::XmlRpcValue xml_rpc_value; + if (nh.getParam(ns, xml_rpc_value)) { + return xmlRpcToParamArray(xml_rpc_value); + } + + ROS_WARN_STREAM("Could not load ROS params under namespace " + << nh.resolveName(ns)); + return {}; +} + +param::Value rosToParams(const ros::NodeHandle& nh, const std::string& ns) { + XmlRpc::XmlRpcValue xml_rpc_value; + if (nh.getParam(ns, xml_rpc_value)) { + return xmlRpcToParams(xml_rpc_value); + } + + ROS_WARN_STREAM("Could not load ROS params under namespace " + << nh.resolveName(ns)); + return param::Value{param::Map{}}; // Return an empty map } -} // namespace wavemap::param::convert +} // namespace wavemap::convert diff --git a/interfaces/ros1/wavemap_rviz_plugin/src/wavemap_map_display.cc b/interfaces/ros1/wavemap_rviz_plugin/src/wavemap_map_display.cc index a9452c92e..bcf61337c 100644 --- a/interfaces/ros1/wavemap_rviz_plugin/src/wavemap_map_display.cc +++ b/interfaces/ros1/wavemap_rviz_plugin/src/wavemap_map_display.cc @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include #include "wavemap_rviz_plugin/utils/alert_dialog.h" diff --git a/library/cpp/cmake/find-wavemap-deps.cmake b/library/cpp/cmake/find-wavemap-deps.cmake index 695f0c98e..367f9068e 100644 --- a/library/cpp/cmake/find-wavemap-deps.cmake +++ b/library/cpp/cmake/find-wavemap-deps.cmake @@ -56,3 +56,4 @@ endif () # Optional dependencies find_package(tracy QUIET) +find_package(yaml-cpp QUIET) diff --git a/library/cpp/include/wavemap/io/config/file_conversions.h b/library/cpp/include/wavemap/io/config/file_conversions.h new file mode 100644 index 000000000..6c0a8f89e --- /dev/null +++ b/library/cpp/include/wavemap/io/config/file_conversions.h @@ -0,0 +1,13 @@ +#ifndef WAVEMAP_IO_CONFIG_FILE_CONVERSIONS_H_ +#define WAVEMAP_IO_CONFIG_FILE_CONVERSIONS_H_ + +#include + +#include + +namespace wavemap::io { +std::optional yamlFileToParams( + const std::filesystem::path& file_path); +} // namespace wavemap::io + +#endif // WAVEMAP_IO_CONFIG_FILE_CONVERSIONS_H_ diff --git a/library/cpp/include/wavemap/io/config/stream_conversions.h b/library/cpp/include/wavemap/io/config/stream_conversions.h new file mode 100644 index 000000000..eef5417db --- /dev/null +++ b/library/cpp/include/wavemap/io/config/stream_conversions.h @@ -0,0 +1,12 @@ +#ifndef WAVEMAP_IO_CONFIG_STREAM_CONVERSIONS_H_ +#define WAVEMAP_IO_CONFIG_STREAM_CONVERSIONS_H_ + +#include + +#include + +namespace wavemap::io { +std::optional yamlStreamToParams(std::istream& istream); +} + +#endif // WAVEMAP_IO_CONFIG_STREAM_CONVERSIONS_H_ diff --git a/library/cpp/include/wavemap/io/config/yaml_cpp_conversions.h b/library/cpp/include/wavemap/io/config/yaml_cpp_conversions.h new file mode 100644 index 000000000..3f26c50d7 --- /dev/null +++ b/library/cpp/include/wavemap/io/config/yaml_cpp_conversions.h @@ -0,0 +1,13 @@ +#ifndef WAVEMAP_IO_CONFIG_YAML_CPP_CONVERSIONS_H_ +#define WAVEMAP_IO_CONFIG_YAML_CPP_CONVERSIONS_H_ + +#include +#include + +namespace wavemap::convert { +param::Map yamlToParamMap(const YAML::Node& yaml_cpp_value); +param::Array yamlToParamArray(const YAML::Node& yaml_cpp_value); +param::Value yamlToParams(const YAML::Node& yaml_cpp_value); +} // namespace wavemap::convert + +#endif // WAVEMAP_IO_CONFIG_YAML_CPP_CONVERSIONS_H_ diff --git a/library/cpp/include/wavemap/io/file_conversions.h b/library/cpp/include/wavemap/io/map/file_conversions.h similarity index 62% rename from library/cpp/include/wavemap/io/file_conversions.h rename to library/cpp/include/wavemap/io/map/file_conversions.h index 6e0cb70f2..c34b5ab72 100644 --- a/library/cpp/include/wavemap/io/file_conversions.h +++ b/library/cpp/include/wavemap/io/map/file_conversions.h @@ -1,14 +1,13 @@ -#ifndef WAVEMAP_IO_FILE_CONVERSIONS_H_ -#define WAVEMAP_IO_FILE_CONVERSIONS_H_ +#ifndef WAVEMAP_IO_MAP_FILE_CONVERSIONS_H_ +#define WAVEMAP_IO_MAP_FILE_CONVERSIONS_H_ #include #include "wavemap/core/map/map_base.h" -#include "wavemap/io/stream_conversions.h" namespace wavemap::io { bool mapToFile(const MapBase& map, const std::filesystem::path& file_path); bool fileToMap(const std::filesystem::path& file_path, MapBase::Ptr& map); } // namespace wavemap::io -#endif // WAVEMAP_IO_FILE_CONVERSIONS_H_ +#endif // WAVEMAP_IO_MAP_FILE_CONVERSIONS_H_ diff --git a/library/cpp/include/wavemap/io/impl/streamable_types_impl.h b/library/cpp/include/wavemap/io/map/impl/streamable_types_impl.h similarity index 97% rename from library/cpp/include/wavemap/io/impl/streamable_types_impl.h rename to library/cpp/include/wavemap/io/map/impl/streamable_types_impl.h index 249437a04..6722a3c66 100644 --- a/library/cpp/include/wavemap/io/impl/streamable_types_impl.h +++ b/library/cpp/include/wavemap/io/map/impl/streamable_types_impl.h @@ -1,5 +1,5 @@ -#ifndef WAVEMAP_IO_IMPL_STREAMABLE_TYPES_IMPL_H_ -#define WAVEMAP_IO_IMPL_STREAMABLE_TYPES_IMPL_H_ +#ifndef WAVEMAP_IO_MAP_IMPL_STREAMABLE_TYPES_IMPL_H_ +#define WAVEMAP_IO_MAP_IMPL_STREAMABLE_TYPES_IMPL_H_ namespace wavemap::io::streamable { void Index3D::write(std::ostream& ostream) const { @@ -157,4 +157,4 @@ StorageFormat StorageFormat::peek(std::istream& istream) { } } // namespace wavemap::io::streamable -#endif // WAVEMAP_IO_IMPL_STREAMABLE_TYPES_IMPL_H_ +#endif // WAVEMAP_IO_MAP_IMPL_STREAMABLE_TYPES_IMPL_H_ diff --git a/library/cpp/include/wavemap/io/stream_conversions.h b/library/cpp/include/wavemap/io/map/stream_conversions.h similarity index 80% rename from library/cpp/include/wavemap/io/stream_conversions.h rename to library/cpp/include/wavemap/io/map/stream_conversions.h index f3341999b..3c02b226f 100644 --- a/library/cpp/include/wavemap/io/stream_conversions.h +++ b/library/cpp/include/wavemap/io/map/stream_conversions.h @@ -1,16 +1,14 @@ -#ifndef WAVEMAP_IO_STREAM_CONVERSIONS_H_ -#define WAVEMAP_IO_STREAM_CONVERSIONS_H_ +#ifndef WAVEMAP_IO_MAP_STREAM_CONVERSIONS_H_ +#define WAVEMAP_IO_MAP_STREAM_CONVERSIONS_H_ #include #include #include "wavemap/core/common.h" -#include "wavemap/core/map/cell_types/haar_coefficients.h" #include "wavemap/core/map/hashed_blocks.h" #include "wavemap/core/map/hashed_chunked_wavelet_octree.h" #include "wavemap/core/map/hashed_wavelet_octree.h" #include "wavemap/core/map/wavelet_octree.h" -#include "wavemap/io/streamable_types.h" namespace wavemap::io { bool mapToStream(const MapBase& map, std::ostream& ostream); @@ -28,4 +26,4 @@ bool streamToMap(std::istream& istream, HashedWaveletOctree::Ptr& map); bool mapToStream(const HashedChunkedWaveletOctree& map, std::ostream& ostream); } // namespace wavemap::io -#endif // WAVEMAP_IO_STREAM_CONVERSIONS_H_ +#endif // WAVEMAP_IO_MAP_STREAM_CONVERSIONS_H_ diff --git a/library/cpp/include/wavemap/io/streamable_types.h b/library/cpp/include/wavemap/io/map/streamable_types.h similarity index 93% rename from library/cpp/include/wavemap/io/streamable_types.h rename to library/cpp/include/wavemap/io/map/streamable_types.h index 3a82d09c5..f27eb54a7 100644 --- a/library/cpp/include/wavemap/io/streamable_types.h +++ b/library/cpp/include/wavemap/io/map/streamable_types.h @@ -1,5 +1,5 @@ -#ifndef WAVEMAP_IO_STREAMABLE_TYPES_H_ -#define WAVEMAP_IO_STREAMABLE_TYPES_H_ +#ifndef WAVEMAP_IO_MAP_STREAMABLE_TYPES_H_ +#define WAVEMAP_IO_MAP_STREAMABLE_TYPES_H_ #include #include @@ -99,6 +99,6 @@ struct StorageFormat : TypeSelector { }; } // namespace wavemap::io::streamable -#include "wavemap/io/impl/streamable_types_impl.h" +#include "wavemap/io/map/impl/streamable_types_impl.h" -#endif // WAVEMAP_IO_STREAMABLE_TYPES_H_ +#endif // WAVEMAP_IO_MAP_STREAMABLE_TYPES_H_ diff --git a/library/cpp/src/core/integrator/integrator_factory.cc b/library/cpp/src/core/integrator/integrator_factory.cc index 524a54dac..be9b21b96 100644 --- a/library/cpp/src/core/integrator/integrator_factory.cc +++ b/library/cpp/src/core/integrator/integrator_factory.cc @@ -60,10 +60,16 @@ std::unique_ptr IntegratorFactory::create( } // Create the projection model + const auto projector_params = params.getChild("projection_model"); + if (!projector_params) { + LOG(ERROR) << "No params with key \"projection_model\". " + "Projection model could not be created."; + return nullptr; + } std::shared_ptr projection_model = - ProjectorFactory::create(params); + ProjectorFactory::create(projector_params.value()); if (!projection_model) { - LOG(ERROR) << "Projection model could not be created."; + LOG(ERROR) << "Creating projection model failed."; return nullptr; } @@ -75,11 +81,18 @@ std::unique_ptr IntegratorFactory::create( std::make_shared>(projection_model->getDimensions()); // Create the measurement model + const auto measurement_params = params.getChild("measurement_model"); + if (!measurement_params) { + LOG(ERROR) << "No params with key \"measurement_model\". " + "Measurement model could not be created."; + return nullptr; + } std::shared_ptr measurement_model = - MeasurementModelFactory::create(params, projection_model, - posed_range_image, beam_offset_image); + MeasurementModelFactory::create(measurement_params.value(), + projection_model, posed_range_image, + beam_offset_image); if (!measurement_model) { - LOG(ERROR) << "Measurement model could not be created."; + LOG(ERROR) << "Creating measurement model failed."; return nullptr; } diff --git a/library/cpp/src/core/integrator/measurement_model/measurement_model_factory.cc b/library/cpp/src/core/integrator/measurement_model/measurement_model_factory.cc index ac5b32b42..a057d1f59 100644 --- a/library/cpp/src/core/integrator/measurement_model/measurement_model_factory.cc +++ b/library/cpp/src/core/integrator/measurement_model/measurement_model_factory.cc @@ -11,8 +11,7 @@ std::unique_ptr wavemap::MeasurementModelFactory::create( const param::Value& params, ProjectorBase::ConstPtr projection_model, Image<>::ConstPtr range_image, Image::ConstPtr beam_offset_image, std::optional default_measurement_model_type) { - if (const auto type = MeasurementModelType::from(params, "measurement_model"); - type) { + if (const auto type = MeasurementModelType::from(params); type) { return create(type.value(), std::move(projection_model), std::move(range_image), std::move(beam_offset_image), params); } @@ -36,9 +35,7 @@ std::unique_ptr wavemap::MeasurementModelFactory::create( Image::ConstPtr beam_offset_image, const param::Value& params) { switch (measurement_model_type) { case MeasurementModelType::kContinuousRay: { - if (const auto config = - ContinuousRayConfig::from(params, "measurement_model"); - config) { + if (const auto config = ContinuousRayConfig::from(params); config) { return std::make_unique(config.value(), std::move(projection_model), std::move(range_image)); @@ -49,9 +46,7 @@ std::unique_ptr wavemap::MeasurementModelFactory::create( } } case MeasurementModelType::kContinuousBeam: { - if (const auto config = - ContinuousBeamConfig::from(params, "measurement_model"); - config) { + if (const auto config = ContinuousBeamConfig::from(params); config) { return std::make_unique( config.value(), std::move(projection_model), std::move(range_image), std::move(beam_offset_image)); diff --git a/library/cpp/src/core/integrator/projection_model/projector_factory.cc b/library/cpp/src/core/integrator/projection_model/projector_factory.cc index 1a132e5ba..b41d8ba28 100644 --- a/library/cpp/src/core/integrator/projection_model/projector_factory.cc +++ b/library/cpp/src/core/integrator/projection_model/projector_factory.cc @@ -10,7 +10,7 @@ namespace wavemap { std::unique_ptr wavemap::ProjectorFactory::create( const param::Value& params, std::optional default_projector_type) { - if (const auto type = ProjectorType::from(params, "projection_model"); type) { + if (const auto type = ProjectorType::from(params); type) { return create(type.value(), params); } @@ -28,9 +28,7 @@ std::unique_ptr wavemap::ProjectorFactory::create( ProjectorType projector_type, const param::Value& params) { switch (projector_type) { case ProjectorType::kSphericalProjector: { - if (const auto config = - SphericalProjectorConfig::from(params, "projection_model"); - config) { + if (const auto config = SphericalProjectorConfig::from(params); config) { return std::make_unique(config.value()); } else { LOG(ERROR) << "Spherical projector config could not be loaded."; @@ -38,9 +36,7 @@ std::unique_ptr wavemap::ProjectorFactory::create( } } case ProjectorType::kOusterProjector: { - if (const auto config = - OusterProjectorConfig::from(params, "projection_model"); - config) { + if (const auto config = OusterProjectorConfig::from(params); config) { return std::make_unique(config.value()); } else { LOG(ERROR) << "Ouster projector config could not be loaded."; @@ -48,8 +44,7 @@ std::unique_ptr wavemap::ProjectorFactory::create( } } case ProjectorType::kPinholeCameraProjector: { - if (const auto config = - PinholeCameraProjectorConfig::from(params, "projection_model"); + if (const auto config = PinholeCameraProjectorConfig::from(params); config) { return std::make_unique(config.value()); } else { diff --git a/library/cpp/src/io/CMakeLists.txt b/library/cpp/src/io/CMakeLists.txt index 8a8fb375c..e903c7ba6 100644 --- a/library/cpp/src/io/CMakeLists.txt +++ b/library/cpp/src/io/CMakeLists.txt @@ -8,7 +8,19 @@ add_wavemap_include_directories(wavemap_io) target_link_libraries(wavemap_io PUBLIC Eigen3::Eigen glog wavemap_core) # Set sources -target_sources(wavemap_io PRIVATE file_conversions.cc stream_conversions.cc) +target_sources(wavemap_io PRIVATE + config/file_conversions.cc + config/stream_conversions.cc + map/file_conversions.cc + map/stream_conversions.cc) + +# Optional YAML support +if (yaml-cpp_FOUND) + target_compile_definitions(wavemap_io PUBLIC YAML_CPP_AVAILABLE) + target_link_libraries(wavemap_io PUBLIC ${YAML_CPP_LIBRARIES}) + target_sources(wavemap_io PRIVATE + config/yaml_cpp_conversions.cc) +endif () # Support installs if (GENERATE_WAVEMAP_INSTALL_RULES) diff --git a/library/cpp/src/io/config/file_conversions.cc b/library/cpp/src/io/config/file_conversions.cc new file mode 100644 index 000000000..b651af30d --- /dev/null +++ b/library/cpp/src/io/config/file_conversions.cc @@ -0,0 +1,32 @@ +#include "wavemap/io/config/file_conversions.h" + +#include + +#include "wavemap/io/config/stream_conversions.h" + +namespace wavemap::io { +std::optional yamlFileToParams( + const std::filesystem::path& file_path) { + if (file_path.empty()) { + LOG(WARNING) + << "Could not open file for reading. Specified file path is empty."; + return std::nullopt; + } + + // Open the file for reading + std::ifstream file_istream(file_path, + std::ifstream::in | std::ifstream::binary); + if (!file_istream.is_open()) { + LOG(WARNING) << "Could not open file " << file_path + << " for reading. Error: " << strerror(errno); + return std::nullopt; + } + + // Deserialize from bytestream + if (auto params = yamlStreamToParams(file_istream); params) { + return params; + } + LOG(WARNING) << "Failed to parse map from file " << file_path << "."; + return std::nullopt; +} +} // namespace wavemap::io diff --git a/library/cpp/src/io/config/stream_conversions.cc b/library/cpp/src/io/config/stream_conversions.cc new file mode 100644 index 000000000..54df5bec5 --- /dev/null +++ b/library/cpp/src/io/config/stream_conversions.cc @@ -0,0 +1,25 @@ +#include "wavemap/io/config/stream_conversions.h" + +#ifdef YAML_CPP_AVAILABLE +#include + +#include "wavemap/io/config/yaml_cpp_conversions.h" +#endif + +namespace wavemap::io { +std::optional yamlStreamToParams( + [[maybe_unused]] std::istream& istream) { +#ifdef YAML_CPP_AVAILABLE + try { + YAML::Node yaml = YAML::Load(istream); + return convert::yamlToParams(yaml); + } catch (YAML::ParserException&) { + LOG(WARNING) << "Failed to parse bytestream using yaml-cpp."; + return std::nullopt; + } +#endif + LOG(ERROR) << "No YAML parser is available. Install yaml-cpp or add an " + "interface to your preferred parser in wavemap/io/config."; + return std::nullopt; +} +} // namespace wavemap::io diff --git a/library/cpp/src/io/config/yaml_cpp_conversions.cc b/library/cpp/src/io/config/yaml_cpp_conversions.cc new file mode 100644 index 000000000..51ebe5355 --- /dev/null +++ b/library/cpp/src/io/config/yaml_cpp_conversions.cc @@ -0,0 +1,76 @@ +#include "wavemap/io/config/yaml_cpp_conversions.h" + +#include + +namespace wavemap::convert { +param::Map yamlToParamMap(const YAML::Node& yaml_cpp_value) { // NOLINT + if (!yaml_cpp_value.IsMap()) { + LOG(WARNING) << "Expected YAML param map."; + return {}; + } + + param::Map param_map; + for (const auto& kv : yaml_cpp_value) { + if (std::string key; YAML::convert::decode(kv.first, key)) { + param_map.emplace(key, yamlToParams(kv.second)); + } else { + LOG(WARNING) << "Ignoring YAML map entry. Key not convertible to string."; + } + } + return param_map; +} + +param::Array yamlToParamArray(const YAML::Node& yaml_cpp_value) { // NOLINT + if (!yaml_cpp_value.IsSequence()) { + LOG(WARNING) << "Expected YAML param sequence."; + return {}; + } + + param::Array array; + for (const auto& kv : yaml_cpp_value) { + array.emplace_back(yamlToParams(kv)); + } + return array; +} + +param::Value yamlToParams(const YAML::Node& yaml_cpp_value) { // NOLINT + if (yaml_cpp_value.IsDefined()) { + switch (yaml_cpp_value.Type()) { + case YAML::NodeType::Map: + return param::Value{yamlToParamMap(yaml_cpp_value)}; + case YAML::NodeType::Sequence: + return param::Value{yamlToParamArray(yaml_cpp_value)}; + case YAML::NodeType::Scalar: + if (bool value; YAML::convert::decode(yaml_cpp_value, value)) { + return param::Value{value}; + } + if (int value; YAML::convert::decode(yaml_cpp_value, value)) { + return param::Value{value}; + } + if (double value; + YAML::convert::decode(yaml_cpp_value, value)) { + return param::Value{value}; + } + if (std::string value; + YAML::convert::decode(yaml_cpp_value, value)) { + return param::Value{value}; + } + LOG(ERROR) << "Encountered unknown type while parsing YAML params."; + break; + case YAML::NodeType::Undefined: + LOG(ERROR) << "Encountered undefined type while parsing YAML params."; + break; + case YAML::NodeType::Null: + LOG(ERROR) << "Encountered null type while parsing YAML params."; + break; + default: + break; + } + } else { + LOG(ERROR) << "Encountered undefined node while parsing YAML params."; + } + + // On error, return an empty array + return param::Value{param::Array{}}; +} +} // namespace wavemap::convert diff --git a/library/cpp/src/io/file_conversions.cc b/library/cpp/src/io/map/file_conversions.cc similarity index 94% rename from library/cpp/src/io/file_conversions.cc rename to library/cpp/src/io/map/file_conversions.cc index 0c8fa9aba..16508a156 100644 --- a/library/cpp/src/io/file_conversions.cc +++ b/library/cpp/src/io/map/file_conversions.cc @@ -1,7 +1,9 @@ -#include "wavemap/io/file_conversions.h" +#include "wavemap/io/map/file_conversions.h" #include +#include "wavemap/io/map/stream_conversions.h" + namespace wavemap::io { bool mapToFile(const MapBase& map, const std::filesystem::path& file_path) { if (file_path.empty()) { diff --git a/library/cpp/src/io/stream_conversions.cc b/library/cpp/src/io/map/stream_conversions.cc similarity index 99% rename from library/cpp/src/io/stream_conversions.cc rename to library/cpp/src/io/map/stream_conversions.cc index 364260678..af2cbaa80 100644 --- a/library/cpp/src/io/stream_conversions.cc +++ b/library/cpp/src/io/map/stream_conversions.cc @@ -1,9 +1,11 @@ -#include "wavemap/io/stream_conversions.h" +#include "wavemap/io/map/stream_conversions.h" #include #include #include +#include "wavemap/io/map/streamable_types.h" + namespace wavemap::io { bool mapToStream(const MapBase& map, std::ostream& ostream) { // Call the appropriate mapToStream converter based on the map's derived type diff --git a/library/cpp/test/data/config_file.yaml b/library/cpp/test/data/config_file.yaml new file mode 100644 index 000000000..3f9d71236 --- /dev/null +++ b/library/cpp/test/data/config_file.yaml @@ -0,0 +1,52 @@ +general: + world_frame: "odom" + logging_level: debug + allow_reset_map_service: true + +map: + type: hashed_chunked_wavelet_octree + min_cell_width: { meters: 0.1 } + +map_operations: + - type: threshold_map + once_every: { seconds: 2.0 } + - type: prune_map + once_every: { seconds: 10.0 } + - type: publish_map + once_every: { seconds: 2.0 } + +measurement_integrators: + ouster_os0: + projection_model: + type: ouster_projector + lidar_origin_to_beam_origin: { millimeters: 27.67 } + lidar_origin_to_sensor_origin_z_offset: { millimeters: 36.18 } + elevation: + num_cells: 128 + min_angle: { degrees: -45.73 } + max_angle: { degrees: 46.27 } + azimuth: + num_cells: 1024 + min_angle: { degrees: -180.0 } + max_angle: { degrees: 180.0 } + measurement_model: + type: continuous_beam + angle_sigma: { degrees: 0.035 } + # NOTE: For best results, we recommend setting range_sigma to + # max(min_cell_width / 2, ouster_noise) where ouster_noise = 0.05 + range_sigma: { meters: 0.05 } + scaling_free: 0.2 + scaling_occupied: 0.4 + integration_method: + type: hashed_chunked_wavelet_integrator + min_range: { meters: 1.0 } + max_range: { meters: 15.0 } + +inputs: + - type: pointcloud_topic + topic_name: "/os_cloud_node/points" + topic_type: ouster + measurement_integrator_names: ouster_os0 + undistort_motion: true + topic_queue_length: 10 + max_wait_for_pose: { seconds: 1.0 } diff --git a/library/cpp/test/src/io/CMakeLists.txt b/library/cpp/test/src/io/CMakeLists.txt index 5cf212161..8a28eec9c 100644 --- a/library/cpp/test/src/io/CMakeLists.txt +++ b/library/cpp/test/src/io/CMakeLists.txt @@ -2,7 +2,11 @@ add_executable(test_wavemap_io) target_include_directories(test_wavemap_io PRIVATE ${PROJECT_SOURCE_DIR}/test/include) -target_sources(test_wavemap_io PRIVATE test_file_conversions.cc) +target_sources(test_wavemap_io PRIVATE + test_config_file_conversions.cc + test_map_file_conversions.cc) +target_compile_definitions(test_wavemap_io PRIVATE + DATADIR="${CMAKE_CURRENT_SOURCE_DIR}/../../data") set_wavemap_target_properties(test_wavemap_io) target_link_libraries(test_wavemap_io wavemap_core wavemap_io GTest::gtest_main) diff --git a/library/cpp/test/src/io/test_config_file_conversions.cc b/library/cpp/test/src/io/test_config_file_conversions.cc new file mode 100644 index 000000000..82a30eb4e --- /dev/null +++ b/library/cpp/test/src/io/test_config_file_conversions.cc @@ -0,0 +1,66 @@ +#include + +#include + +#include "wavemap/io/config/file_conversions.h" +#include "wavemap/test/fixture_base.h" + +namespace wavemap { +using ConfigFileConversionTest = FixtureBase; + +TEST_F(ConfigFileConversionTest, Reading) { + std::filesystem::path data_dir = DATADIR; + std::filesystem::path config_file_path = data_dir / "config_file.yaml"; + + const auto params = io::yamlFileToParams(config_file_path); + +#ifndef YAML_CPP_AVAILABLE + EXPECT_EQ(params, std::nullopt); + return; +#endif + + ASSERT_TRUE(params.has_value()); + EXPECT_TRUE(params->holds()); + EXPECT_TRUE(params->hasChild("general")); + EXPECT_TRUE(params->hasChild("map")); + EXPECT_TRUE(params->hasChild("map_operations")); + EXPECT_TRUE(params->hasChild("measurement_integrators")); + EXPECT_TRUE(params->hasChild("inputs")); + + // Test map loading + const auto map = params->getChildAs("map"); + ASSERT_TRUE(map.has_value()); + EXPECT_EQ(map->size(), 2); + + // Test array loading + const auto map_operations = + params->getChildAs("map_operations"); + ASSERT_TRUE(map_operations.has_value()) + << "Could not convert value for key \"map_operations\" to a " + "param::Array."; + EXPECT_EQ(map_operations->size(), 3); + + // Test loading primitive types + const auto inputs = params->getChildAs("inputs"); + ASSERT_TRUE(inputs.has_value()); + const auto& input = inputs->operator[](0); + // Booleans + const auto undistort_motion = input.getChildAs("undistort_motion"); + ASSERT_TRUE(undistort_motion.has_value()); + EXPECT_EQ(undistort_motion.value(), true); + // Integers + const auto topic_queue_length = input.getChildAs("topic_queue_length"); + ASSERT_TRUE(topic_queue_length.has_value()); + EXPECT_EQ(topic_queue_length.value(), 10); + // Floating points + const auto max_wait_for_pose = input.getChild("max_wait_for_pose"); + ASSERT_TRUE(max_wait_for_pose.has_value()); + const auto seconds = max_wait_for_pose->getChildAs("seconds"); + ASSERT_TRUE(seconds.has_value()); + EXPECT_EQ(seconds.value(), 1.f); + // Strings + const auto topic_name = input.getChildAs("topic_name"); + ASSERT_TRUE(topic_name.has_value()); + EXPECT_EQ(topic_name.value(), "/os_cloud_node/points"); +} +} // namespace wavemap diff --git a/library/cpp/test/src/io/test_file_conversions.cc b/library/cpp/test/src/io/test_map_file_conversions.cc similarity index 98% rename from library/cpp/test/src/io/test_file_conversions.cc rename to library/cpp/test/src/io/test_map_file_conversions.cc index 44cd077ff..6e7753c12 100644 --- a/library/cpp/test/src/io/test_file_conversions.cc +++ b/library/cpp/test/src/io/test_map_file_conversions.cc @@ -4,11 +4,12 @@ #include #include "wavemap/core/common.h" +#include "wavemap/core/map/hashed_blocks.h" #include "wavemap/core/map/hashed_chunked_wavelet_octree.h" #include "wavemap/core/map/hashed_wavelet_octree.h" #include "wavemap/core/map/map_base.h" #include "wavemap/core/map/wavelet_octree.h" -#include "wavemap/io/file_conversions.h" +#include "wavemap/io/map/file_conversions.h" #include "wavemap/test/config_generator.h" #include "wavemap/test/fixture_base.h" #include "wavemap/test/geometry_generator.h" diff --git a/library/python/include/pywavemap/param.h b/library/python/include/pywavemap/param.h index 27ff9f56f..e103519c6 100644 --- a/library/python/include/pywavemap/param.h +++ b/library/python/include/pywavemap/param.h @@ -8,9 +8,9 @@ namespace nb = nanobind; namespace wavemap { namespace convert { -param::Map toParamMap(const nb::handle& py_value); -param::Array toParamArray(const nb::handle& py_value); -param::Value toParamValue(const nb::handle& py_value); +param::Map pyToParamMap(const nb::handle& py_value); +param::Array pyToParamArray(const nb::handle& py_value); +param::Value pyToParams(const nb::handle& py_value); } // namespace convert void add_param_module(nb::module_& m_param); diff --git a/library/python/src/maps.cc b/library/python/src/maps.cc index 6aa29c3bd..b224d077e 100644 --- a/library/python/src/maps.cc +++ b/library/python/src/maps.cc @@ -11,7 +11,7 @@ #include #include #include -#include +#include using namespace nb::literals; // NOLINT diff --git a/library/python/src/param.cc b/library/python/src/param.cc index 734efeb6f..fe7c79afa 100644 --- a/library/python/src/param.cc +++ b/library/python/src/param.cc @@ -4,7 +4,7 @@ namespace wavemap { namespace convert { -param::Map toParamMap(const nb::handle& py_value) { // NOLINT +param::Map pyToParamMap(const nb::handle& py_value) { // NOLINT nb::dict py_dict; if (!nb::try_cast(py_value, py_dict)) { LOG(WARNING) << "Expected python dict, but got " @@ -14,9 +14,8 @@ param::Map toParamMap(const nb::handle& py_value) { // NOLINT param::Map param_map; for (const auto& [py_key, py_dict_value] : py_dict) { - nb::str py_key_str; - if (nb::try_cast(py_key, py_key_str)) { - param_map.emplace(py_key_str.c_str(), toParamValue(py_dict_value)); + if (nb::str py_key_str; nb::try_cast(py_key, py_key_str)) { + param_map.emplace(py_key_str.c_str(), pyToParams(py_dict_value)); } else { LOG(WARNING) << "Ignoring dict entry. Key not convertible to string for " "element with key " @@ -27,7 +26,7 @@ param::Map toParamMap(const nb::handle& py_value) { // NOLINT return param_map; } -param::Array toParamArray(const nb::handle& py_value) { // NOLINT +param::Array pyToParamArray(const nb::handle& py_value) { // NOLINT nb::list py_list; if (!nb::try_cast(py_value, py_list)) { LOG(WARNING) << "Expected python list, but got " @@ -37,13 +36,19 @@ param::Array toParamArray(const nb::handle& py_value) { // NOLINT param::Array array; array.reserve(nb::len(py_list)); - for (const auto& py_element : py_list) { // NOLINT - array.emplace_back(toParamValue(py_element)); + for (const auto& py_element : py_list) { + array.emplace_back(pyToParams(py_element)); } return array; } -param::Value toParamValue(const nb::handle& py_value) { // NOLINT +param::Value pyToParams(const nb::handle& py_value) { // NOLINT + if (nb::isinstance(py_value)) { + return param::Value(pyToParamMap(py_value)); + } + if (nb::isinstance(py_value)) { + return param::Value(pyToParamArray(py_value)); + } if (nb::bool_ py_bool; nb::try_cast(py_value, py_bool)) { return param::Value{static_cast(py_bool)}; } @@ -56,12 +61,6 @@ param::Value toParamValue(const nb::handle& py_value) { // NOLINT if (nb::str py_str; nb::try_cast(py_value, py_str)) { return param::Value{std::string{py_str.c_str()}}; } - if (nb::isinstance(py_value)) { - return param::Value(toParamArray(py_value)); - } - if (nb::isinstance(py_value)) { - return param::Value(toParamMap(py_value)); - } // On error, return an empty array LOG(ERROR) << "Encountered unsupported type while parsing python param " @@ -78,7 +77,7 @@ void add_param_module(nb::module_& m_param) { "can therefore hold the information needed to initialize an entire " "config, or even a hierarchy of nested configs.") .def("__init__", [](param::Value* t, nb::handle py_value) { - new (t) param::Value{convert::toParamValue(py_value)}; + new (t) param::Value{convert::pyToParams(py_value)}; }); nb::implicitly_convertible();