Skip to content

Commit

Permalink
Expose leaky wrappers, improve conversion to SymPy, disable conversio…
Browse files Browse the repository at this point in the history
…n from SymPy.
  • Loading branch information
bluescarni committed Nov 5, 2023
1 parent a1af893 commit f031014
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 45 deletions.
2 changes: 0 additions & 2 deletions heyoka/_sympy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,6 @@ def mul_wrapper(*args):
retval[_spy.Function("heyoka_kepE")] = core.kepE
retval[_spy.Function("heyoka_kepF")] = core.kepF
retval[_spy.Function("heyoka_kepDE")] = core.kepDE
retval[_spy.Function("heyoka_relu")] = core.relu
retval[_spy.Function("heyoka_relup")] = core.relup
retval[_spy.Function("heyoka_time")] = lambda: core.time

return retval
Expand Down
15 changes: 15 additions & 0 deletions heyoka/_test_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,18 @@ def test_fix_unfix(self):
self.assertEqual(fix_nn(expression(1.1)), expression(1.1))
self.assertEqual(unfix(fix(x + y)), x + y)
self.assertEqual(unfix(arg=[fix(x + fix(y)), fix(x - fix(y))]), [x + y, x - y])

def test_relu_wrappers(self):
from . import make_vars, leaky_relu, leaky_relup, relu, relup

x, y = make_vars("x", "y")

self.assertEqual(leaky_relu(0.)(x), relu(x))
self.assertEqual(leaky_relup(0.)(x), relup(x))
self.assertEqual(leaky_relu(0.1)(x+y), relu(x+y,0.1))
self.assertEqual(leaky_relup(0.1)(x+y), relup(x+y,0.1))

self.assertEqual(leaky_relu(0.)(x*y), relu(x*y))
self.assertEqual(leaky_relup(0.)(x*y), relup(x*y))
self.assertEqual(leaky_relu(0.1)(x*y+y), relu(x*y+y,0.1))
self.assertEqual(leaky_relup(0.1)(x*y+y), relup(x*y+y,0.1))
20 changes: 6 additions & 14 deletions heyoka/_test_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,27 +287,19 @@ def test_func_conversion(self):
to_sympy(core.kepF(hx, hy, hz)), spy.Function("heyoka_kepF")(x, y, z)
)

self.assertEqual(
core.relu(hx), from_sympy(spy.Function("heyoka_relu")(x))
)
self.assertEqual(
to_sympy(core.relu(hx)), spy.Function("heyoka_relu")(x)
)

self.assertEqual(
core.relup(hx), from_sympy(spy.Function("heyoka_relup")(x))
)
self.assertEqual(
to_sympy(core.relup(hx)), spy.Function("heyoka_relup")(x)
)

self.assertEqual(
core.kepDE(hx, hy, hz), from_sympy(spy.Function("heyoka_kepDE")(x, y, z))
)
self.assertEqual(
to_sympy(core.kepDE(hx, hy, hz)), spy.Function("heyoka_kepDE")(x, y, z)
)

# relu/relup.
self.assertEqual(to_sympy(core.relu(hx)), spy.Piecewise((x, x > 0), (0., True)))
self.assertEqual(to_sympy(core.relup(hx)), spy.Piecewise((1., x > 0), (0., True)))
self.assertEqual(to_sympy(core.relu(hx, 0.1)), spy.Piecewise((x, x > 0), (x*0.1, True)))
self.assertEqual(to_sympy(core.relup(hx, 0.1)), spy.Piecewise((1., x > 0), (0.1, True)))

self.assertEqual(-1.0 * hx, from_sympy(-x))
self.assertEqual(to_sympy(-hx), -x)

Expand Down
47 changes: 28 additions & 19 deletions heyoka/expose_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,25 +278,34 @@ void expose_expression(py::module_ &m)
m.def("prod", &hey::prod, "terms"_a);

// NOTE: need explicit casts for sqrt and exp due to the presence of overloads for number.
m.def("sqrt", static_cast<hey::expression (*)(hey::expression)>(&hey::sqrt));
m.def("exp", static_cast<hey::expression (*)(hey::expression)>(&hey::exp));
m.def("log", &hey::log);
m.def("sin", &hey::sin);
m.def("cos", &hey::cos);
m.def("tan", &hey::tan);
m.def("asin", &hey::asin);
m.def("acos", &hey::acos);
m.def("atan", &hey::atan);
m.def("sinh", &hey::sinh);
m.def("cosh", &hey::cosh);
m.def("tanh", &hey::tanh);
m.def("asinh", &hey::asinh);
m.def("acosh", &hey::acosh);
m.def("atanh", &hey::atanh);
m.def("sigmoid", &hey::sigmoid);
m.def("erf", &hey::erf);
m.def("relu", &hey::relu);
m.def("relup", &hey::relup);
m.def("sqrt", static_cast<hey::expression (*)(hey::expression)>(&hey::sqrt), "arg"_a);
m.def("exp", static_cast<hey::expression (*)(hey::expression)>(&hey::exp), "arg"_a);
m.def("log", &hey::log, "arg"_a);
m.def("sin", &hey::sin, "arg"_a);
m.def("cos", &hey::cos, "arg"_a);
m.def("tan", &hey::tan, "arg"_a);
m.def("asin", &hey::asin, "arg"_a);
m.def("acos", &hey::acos, "arg"_a);
m.def("atan", &hey::atan, "arg"_a);
m.def("sinh", &hey::sinh, "arg"_a);
m.def("cosh", &hey::cosh, "arg"_a);
m.def("tanh", &hey::tanh, "arg"_a);
m.def("asinh", &hey::asinh, "arg"_a);
m.def("acosh", &hey::acosh, "arg"_a);
m.def("atanh", &hey::atanh, "arg"_a);
m.def("sigmoid", &hey::sigmoid, "arg"_a);
m.def("erf", &hey::erf, "arg"_a);
m.def("relu", &hey::relu, "arg"_a, "slope"_a = 0.);
m.def("relup", &hey::relup, "arg"_a, "slope"_a = 0.);

// Leaky relu wrappers.
py::class_<hey::leaky_relu> lr_class(m, "leaky_relu", py::dynamic_attr{});
lr_class.def(py::init([](double slope) { return hey::leaky_relu(slope); }), "slope"_a);
lr_class.def("__call__", &hey::leaky_relu::operator(), "arg"_a);

py::class_<hey::leaky_relup> lrp_class(m, "leaky_relup", py::dynamic_attr{});
lrp_class.def(py::init([](double slope) { return hey::leaky_relup(slope); }), "slope"_a);
lrp_class.def("__call__", &hey::leaky_relup::operator(), "arg"_a);

// NOTE: when exposing multivariate functions, we want to be able to pass
// in numerical arguments for convenience. Thus, we expose such functions taking
Expand Down
58 changes: 48 additions & 10 deletions heyoka/setup_sympy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,16 +308,54 @@ void setup_sympy(py::module &m)
auto sympy_kepDE = py::object(detail::spy->attr("Function")("heyoka_kepDE"));
detail::fmap[typeid(hy::detail::kepDE_impl)] = sympy_kepDE;

// relu, relup.
// NOTE: these will remain unevaluated functions.
// NOTE: it might be possible to implement relu and relup as
// piecewise functions:
// https://stats.stackexchange.com/questions/540101/how-to-use-sympy-library-to-create-a-leakey-relu-function-for-implementation
auto sympy_relu = py::object(detail::spy->attr("Function")("heyoka_relu"));
detail::fmap[typeid(hy::detail::relu_impl)] = sympy_relu;

auto sympy_relup = py::object(detail::spy->attr("Function")("heyoka_relup"));
detail::fmap[typeid(hy::detail::relup_impl)] = sympy_relup;
// relu, relup and leaky variants.
// NOTE: these are implemented as piecewise functions:
// https://medium.com/@mathcube7/piecewise-functions-in-pythons-sympy-83f857948d3
detail::fmap[typeid(hy::detail::relu_impl)]
= [](std::unordered_map<const void *, py::object> &func_map, const hy::func &f) -> py::object {
assert(f.args().size() == 1u);

// Convert the argument to SymPy.
auto s_arg = detail::to_sympy_impl(func_map, f.args()[0]);

// Fetch the slope value.
const auto slope = f.extract<hy::detail::relu_impl>()->get_slope();

// Create the condition arg > 0.
auto cond = s_arg.attr("__gt__")(0);

// Fetch the piecewise function.
auto pw = detail::spy->attr("Piecewise");

if (slope == 0) {
return pw(py::make_tuple(s_arg, cond), py::make_tuple(0., true));
} else {
return pw(py::make_tuple(s_arg, cond), py::make_tuple(py::cast(slope) * s_arg, true));
}
};

detail::fmap[typeid(hy::detail::relup_impl)]
= [](std::unordered_map<const void *, py::object> &func_map, const hy::func &f) -> py::object {
assert(f.args().size() == 1u);

// Convert the argument to SymPy.
auto s_arg = detail::to_sympy_impl(func_map, f.args()[0]);

// Fetch the slope value.
const auto slope = f.extract<hy::detail::relup_impl>()->get_slope();

// Create the condition arg > 0.
auto cond = s_arg.attr("__gt__")(0);

// Fetch the piecewise function.
auto pw = detail::spy->attr("Piecewise");

if (slope == 0) {
return pw(py::make_tuple(1., cond), py::make_tuple(0., true));
} else {
return pw(py::make_tuple(1., cond), py::make_tuple(slope, true));
}
};

// sigmoid.
detail::fmap[typeid(hy::detail::sigmoid_impl)]
Expand Down

0 comments on commit f031014

Please sign in to comment.