Skip to content

Commit

Permalink
Check if option set before returning option to python
Browse files Browse the repository at this point in the history
  • Loading branch information
Popov-Dmitriy-Ivanovich committed Dec 6, 2023
1 parent 1a14c0e commit 04952a6
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 54 deletions.
8 changes: 5 additions & 3 deletions src/core/algorithms/algorithm.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,12 @@ class Algorithm {
[[nodiscard]] std::unordered_set<std::string_view> GetPossibleOptions() const;
[[nodiscard]] std::string_view GetDescription(std::string_view option_name) const;

std::unordered_map<std::string_view, std::unique_ptr<config::OptValue> > GetOptValues() {
std::unordered_map<std::string_view, std::unique_ptr<config::OptValue> > opt_values;
std::unordered_map<std::string_view, std::unique_ptr<config::OptValue>> GetOptValues() {
std::unordered_map<std::string_view, std::unique_ptr<config::OptValue>> opt_values;
for (auto& i : possible_options_) {
opt_values[i.second->GetName()] = i.second->GetOptValue();
if (i.second->IsSet()) {
opt_values[i.second->GetName()] = i.second->GetOptValue();
}
}
return opt_values;
}
Expand Down
2 changes: 1 addition & 1 deletion src/core/config/ioption.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include "boost/any.hpp"

namespace config {
struct OptValue{
struct OptValue {
std::type_index type;
boost::any value;
};
Expand Down
3 changes: 2 additions & 1 deletion src/core/config/option.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ class Option : public IOption {
}

std::unique_ptr<OptValue> GetOptValue() override {
return std::make_unique<OptValue>(OptValue{std::type_index(typeid(T)),boost::any(*value_ptr_)});
return std::make_unique<OptValue>(
OptValue{std::type_index(typeid(T)), boost::any(*value_ptr_)});
}

private:
Expand Down
75 changes: 33 additions & 42 deletions src/python_bindings/opt_to_str.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,57 +5,48 @@ using ConvFunction = std::function<py::object(boost::any)>;

template <typename T>
std::pair<std::type_index, ConvFunction> normal_conv_pair{
std::type_index(typeid(T)),
[](boost::any value) { return py::cast(boost::any_cast<T>(value)); }
};
std::type_index(typeid(T)),
[](boost::any value) { return py::cast(boost::any_cast<T>(value)); }};

template <>
std::pair<std::type_index, ConvFunction> normal_conv_pair<bool>{
std::type_index(typeid(bool)),
[](boost::any value) { return py::cast(boost::any_cast<bool>(value)); }
};
std::type_index(typeid(bool)),
[](boost::any value) { return py::cast(boost::any_cast<bool>(value)); }};

template <>
std::pair<std::type_index, ConvFunction> normal_conv_pair<config::IndicesType>{
std::type_index(typeid(config::IndicesType)),
[](boost::any value) {
auto opt_value = (boost::any_cast<config::IndicesType>(value));
return py::cast(opt_value);
}
};
std::type_index(typeid(config::IndicesType)), [](boost::any value) {
auto opt_value = (boost::any_cast<config::IndicesType>(value));
return py::cast(opt_value);
}};
template <>
std::pair<std::type_index, ConvFunction> normal_conv_pair<algos::metric::MetricAlgo>{
std::type_index(typeid(algos::metric::MetricAlgo)),
[](boost::any value){
auto opt_value = boost::any_cast<algos::metric::MetricAlgo>(value);
return py::cast(opt_value._to_string());
}
};
std::type_index(typeid(algos::metric::MetricAlgo)), [](boost::any value) {
auto opt_value = boost::any_cast<algos::metric::MetricAlgo>(value);
return py::cast(opt_value._to_string());
}};
template <>
std::pair<std::type_index, ConvFunction> normal_conv_pair<algos::metric::Metric>{
std::type_index(typeid(algos::metric::Metric)),
[](boost::any value){
auto opt_value = boost::any_cast<algos::metric::Metric>(value);
return py::cast(opt_value._to_string());
}
};
std::type_index(typeid(algos::metric::Metric)), [](boost::any value) {
auto opt_value = boost::any_cast<algos::metric::Metric>(value);
return py::cast(opt_value._to_string());
}};
const std::unordered_map<std::type_index, ConvFunction> converters{
normal_conv_pair<int>,
normal_conv_pair<double>,
normal_conv_pair<long double>,
normal_conv_pair<unsigned int>,
normal_conv_pair<bool>,
normal_conv_pair<config::ThreadNumType>,
normal_conv_pair<config::MaxLhsType>,
normal_conv_pair<config::ErrorType>,
normal_conv_pair<config::IndicesType>,
normal_conv_pair<algos::metric::MetricAlgo>,
normal_conv_pair<algos::metric::Metric>
};
} // namespace
normal_conv_pair<int>,
normal_conv_pair<double>,
normal_conv_pair<long double>,
normal_conv_pair<unsigned int>,
normal_conv_pair<bool>,
normal_conv_pair<config::ThreadNumType>,
normal_conv_pair<config::MaxLhsType>,
normal_conv_pair<config::ErrorType>,
normal_conv_pair<config::IndicesType>,
normal_conv_pair<algos::metric::MetricAlgo>,
normal_conv_pair<algos::metric::Metric>};
} // namespace

namespace python_bindings{
py::object opt_to_str(std::type_index type, boost::any val) {
return converters.at(type)(val);
}
} // namespace python_bindings
namespace python_bindings {
py::object opt_to_str(std::type_index type, boost::any val) {
return converters.at(type)(val);
}
} // namespace python_bindings
10 changes: 5 additions & 5 deletions src/python_bindings/opt_to_str.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
#include "config/error/type.h"
#include "config/indices/type.h"
#include "config/max_lhs/type.h"
#include "config/names.h"
#include "config/tabular_data/input_table_type.h"
#include "config/thread_number/type.h"
#include "config/names.h"
#include "model/types/types.h"

namespace python_bindings{
namespace py = pybind11;
py::object opt_to_str(std::type_index type, boost::any val);
} // namespace python_bindings
namespace python_bindings {
namespace py = pybind11;
py::object opt_to_str(std::type_index type, boost::any val);
} // namespace python_bindings
4 changes: 2 additions & 2 deletions src/python_bindings/py_algorithm.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
#include "config/error/type.h"
#include "config/indices/type.h"
#include "config/max_lhs/type.h"
#include "config/names.h"
#include "config/tabular_data/input_table_type.h"
#include "config/thread_number/type.h"
#include "config/names.h"
#include "model/types/types.h"
#include "opt_to_str.h"

Expand Down Expand Up @@ -61,7 +61,7 @@ class PyAlgorithmBase {
std::unordered_map<std::string, py::object> GetOpts() {
auto opt_ = algorithm_->GetOptValues();
std::unordered_map<std::string, py::object> res;
for (const auto& [name,value] : opt_) {
for (const auto& [name, value] : opt_) {
if (name == config::names::kTable || name == config::names::kInputFormat) {
continue;
}
Expand Down

0 comments on commit 04952a6

Please sign in to comment.