diff --git a/include/mitsuba/core/ray.h b/include/mitsuba/core/ray.h index 463c5202e..280f82f99 100644 --- a/include/mitsuba/core/ray.h +++ b/include/mitsuba/core/ray.h @@ -17,15 +17,16 @@ NAMESPACE_BEGIN(mitsuba) template struct Ray { static constexpr size_t Size = dr::size_v; - using Point = Point_; + using Point = Point_; // TODO: need this because dr::value_t> isn't a masked array using Float = std::conditional_t, dr::detail::MaskedArray>, dr::value_t>; - using Vector = mitsuba::Vector; - using Spectrum = Spectrum_; - using Wavelength = wavelength_t; + using ScalarFloat = dr::scalar_t; + using Vector = mitsuba::Vector; + using Spectrum = Spectrum_; + using Wavelength = wavelength_t; /// Ray origin Point o; @@ -34,7 +35,7 @@ template struct Ray { /// Maximum position on the ray segment Float maxt = dr::Largest; /// Time value associated with this ray - Float time = 0.f; + Float time = (ScalarFloat) 0.f; /// Wavelength associated with the ray Wavelength wavelengths; @@ -44,7 +45,7 @@ template struct Ray { : o(o), d(d), time(time), wavelengths(wavelengths) { } /// Construct a new ray (o, d) with time - Ray(const Point &o, const Vector &d, const Float &time=0.f) + Ray(const Point &o, const Vector &d, const Float &time = (ScalarFloat) 0.f) : o(o), d(d), time(time) { } /// Construct a new ray (o, d) with bounds @@ -82,7 +83,7 @@ template struct RayDifferential : Ray { using Base = Ray; - MI_USING_TYPES(Float, Point, Vector, Wavelength) + MI_USING_TYPES(Float, ScalarFloat, Point, Vector, Wavelength) MI_USING_MEMBERS(o, d, maxt, time, wavelengths) Point o_x, o_y; @@ -94,7 +95,7 @@ struct RayDifferential : Ray { : Base(ray), o_x(0), o_y(0), d_x(0), d_y(0), has_differentials(false) {} /// Construct a new ray (o, d) at time 'time' - RayDifferential(const Point &o_, const Vector &d_, Float time_ = 0.f, + RayDifferential(const Point &o_, const Vector &d_, Float time_ = (ScalarFloat) 0.f, const Wavelength &wavelengths_ = Wavelength()) : o_x(0), o_y(0), d_x(0), d_y(0), has_differentials(false) { o = o_; diff --git a/src/core/python/ray_v.cpp b/src/core/python/ray_v.cpp index f39190043..16a5d10c5 100644 --- a/src/core/python/ray_v.cpp +++ b/src/core/python/ray_v.cpp @@ -5,21 +5,26 @@ template void bind_ray(nb::module_ &m, const char *name) { MI_PY_IMPORT_TYPES() - using Vector = typename Ray::Vector; - using Point = typename Ray::Point; + // Re-import this specific `Ray`'s types in cased of mixed precision + // between Float and Spectrum. + using RayFloat = typename Ray::Float; + using RayScalarFloat = typename Ray::ScalarFloat; + using Vector = typename Ray::Vector; + using Point = typename Ray::Point; + using Wavelength = typename Ray::Wavelength; MI_PY_CHECK_ALIAS(Ray, name) { auto ray = nb::class_(m, name, D(Ray)) .def(nb::init<>(), "Create an uninitialized ray") .def(nb::init(), "Copy constructor", "other"_a) - .def(nb::init(), + .def(nb::init(), D(Ray, Ray, 2), - "o"_a, "d"_a, "time"_a=0.0, "wavelengths"_a=Wavelength()) - .def(nb::init(), + "o"_a, "d"_a, "time"_a=(RayScalarFloat) 0.0, "wavelengths"_a=Wavelength()) + .def(nb::init(), D(Ray, Ray, 3), - "o"_a, "d"_a, "maxt"_a, "time"_a, "wavelengths"_a) - .def(nb::init(), - D(Ray, Ray, 4), "other"_a, "maxt"_a) + "o"_a, "d"_a, "maxt"_a, "time"_a, "wavelengths"_a) + .def(nb::init(), + D(Ray, Ray, 4), "other"_a, "maxt"_a) .def("__call__", &Ray::operator(), D(Ray, operator, call), "t"_a) .def_field(Ray, o, D(Ray, o)) .def_field(Ray, d, D(Ray, d)) @@ -45,7 +50,7 @@ MI_PY_EXPORT(Ray) { .def(nb::init(), "ray"_a) .def(nb::init(), "Initialize without differentials.", - "o"_a, "d"_a, "time"_a=0.0, "wavelengths"_a=Wavelength()) + "o"_a, "d"_a, "time"_a=(ScalarFloat) 0.0, "wavelengths"_a=Wavelength()) .def("scale_differential", &RayDifferential3f::scale_differential, "amount"_a, D(RayDifferential, scale_differential)) .def_field(RayDifferential3f, o_x, D(RayDifferential, o_x))