Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use variant-specific registry domains #1362

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions include/mitsuba/core/class.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,5 +242,55 @@ template <typename T, std::enable_if_t<is_constructible_v<T, Stream*>, int> = 0>
Class::UnserializeFunctor get_unserialize_functor() { return [](Stream* s) -> Object* { return new T(s); }; }
template <typename T, std::enable_if_t<!is_constructible_v<T, Stream*>, int> = 0>
Class::UnserializeFunctor get_unserialize_functor() { return {}; }

/// Adapted from: https://stackoverflow.com/a/65440575
template <size_t... Len>
constexpr auto constexpr_array_concatenation(const std::array<char, Len> &...arrays) {
constexpr size_t N = (... + Len) - sizeof...(Len);
std::array<char, N + 1> result = {};
result[N] = '\0';

char *dst = result.data();
for (const char *src : { arrays.data()... }) {
for (; *src != '\0'; src++, dst++) {
*dst = *src;
}
}
return result;
}

NAMESPACE_END(detail)

#define MI_REGISTRY_PUT(name, ptr) \
if constexpr (dr::is_jit_v<Float>) { \
static constexpr auto domain_name \
= detail::constexpr_array_concatenation( \
std::array<char, 10>{"mitsuba::"}, \
detail::get_variant_padded<Float, Spectrum>(), \
std::array<char, 3>{"__"}, \
std::array<char, sizeof(name) / sizeof(char)>{name} \
); \
jit_registry_put(dr::backend_v<Float>, domain_name.data(), ptr); \
}

#define MI_CALL_TEMPLATE_BEGIN(Name) \
DRJIT_CALL_TEMPLATE_BEGIN(mitsuba::Name)

#define MI_CALL_TEMPLATE_END(Name) \
private: \
static constexpr auto MIDomain_ \
= ::mitsuba::detail::constexpr_array_concatenation( \
std::array<char, 10>{"mitsuba::"}, \
::mitsuba::detail::get_variant_padded<Ts...>(), \
std::array<char, 3>{"__"}, \
std::array<char, sizeof(#Name) / sizeof(char)>{#Name} \
); \
public: \
static constexpr const char *domain_() { \
return MIDomain_.data(); \
} \
static_assert(is_detected_v<detail::has_domain_override, CallSupport_>); \
DRJIT_CALL_END(mitsuba::Name)


NAMESPACE_END(mitsuba)
4 changes: 2 additions & 2 deletions include/mitsuba/render/bsdf.h
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ NAMESPACE_END(mitsuba)
//! @{ \name Dr.Jit support for vectorized function calls
// -----------------------------------------------------------------------

DRJIT_CALL_TEMPLATE_BEGIN(mitsuba::BSDF)
MI_CALL_TEMPLATE_BEGIN(BSDF)
DRJIT_CALL_METHOD(sample)
DRJIT_CALL_METHOD(eval)
DRJIT_CALL_METHOD(eval_null_transmission)
Expand All @@ -670,7 +670,7 @@ DRJIT_CALL_TEMPLATE_BEGIN(mitsuba::BSDF)
auto needs_differentials() const {
return has_flag(flags(), mitsuba::BSDFFlags::NeedsDifferentials);
}
DRJIT_CALL_END(mitsuba::BSDF)
MI_CALL_TEMPLATE_END(BSDF)

//! @}
// -----------------------------------------------------------------------
4 changes: 2 additions & 2 deletions include/mitsuba/render/emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ NAMESPACE_END(mitsuba)
//! @{ \name Dr.Jit support for vectorized function calls
// -----------------------------------------------------------------------

DRJIT_CALL_TEMPLATE_BEGIN(mitsuba::Emitter)
MI_CALL_TEMPLATE_BEGIN(Emitter)
DRJIT_CALL_METHOD(sample_ray)
DRJIT_CALL_METHOD(sample_direction)
DRJIT_CALL_METHOD(pdf_direction)
Expand All @@ -117,7 +117,7 @@ DRJIT_CALL_TEMPLATE_BEGIN(mitsuba::Emitter)
DRJIT_CALL_GETTER(shape)
DRJIT_CALL_GETTER(medium)
DRJIT_CALL_GETTER(sampling_weight)
DRJIT_CALL_END(mitsuba::Emitter)
MI_CALL_TEMPLATE_END(Emitter)

//! @}
// -----------------------------------------------------------------------
4 changes: 2 additions & 2 deletions include/mitsuba/render/medium.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ NAMESPACE_END(mitsuba)
//! @{ \name Dr.Jit support for packets of Medium pointers
// -----------------------------------------------------------------------

DRJIT_CALL_TEMPLATE_BEGIN(mitsuba::Medium)
MI_CALL_TEMPLATE_BEGIN(Medium)
DRJIT_CALL_GETTER(phase_function)
DRJIT_CALL_GETTER(use_emitter_sampling)
DRJIT_CALL_GETTER(is_homogeneous)
Expand All @@ -131,7 +131,7 @@ DRJIT_CALL_TEMPLATE_BEGIN(mitsuba::Medium)
DRJIT_CALL_METHOD(sample_interaction)
DRJIT_CALL_METHOD(transmittance_eval_pdf)
DRJIT_CALL_METHOD(get_scattering_coefficients)
DRJIT_CALL_END(mitsuba::Medium)
MI_CALL_TEMPLATE_END(Medium)

//! @}
// -----------------------------------------------------------------------
4 changes: 2 additions & 2 deletions include/mitsuba/render/phase.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,14 +255,14 @@ NAMESPACE_END(mitsuba)
//! @{ \name Dr.Jit support for vectorized function calls
// -----------------------------------------------------------------------

DRJIT_CALL_TEMPLATE_BEGIN(mitsuba::PhaseFunction)
MI_CALL_TEMPLATE_BEGIN(PhaseFunction)
DRJIT_CALL_METHOD(sample)
DRJIT_CALL_METHOD(eval_pdf)
DRJIT_CALL_METHOD(projected_area)
DRJIT_CALL_METHOD(max_projected_area)
DRJIT_CALL_GETTER(flags)
DRJIT_CALL_GETTER(component_count)
DRJIT_CALL_END(mitsuba::PhaseFunction)
MI_CALL_TEMPLATE_END(PhaseFunction)

//! @}
// -----------------------------------------------------------------------
4 changes: 2 additions & 2 deletions include/mitsuba/render/sensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ NAMESPACE_END(mitsuba)
//! @{ \name Dr.Jit support for vectorized function calls
// -----------------------------------------------------------------------

DRJIT_CALL_TEMPLATE_BEGIN(mitsuba::Sensor)
MI_CALL_TEMPLATE_BEGIN(Sensor)
DRJIT_CALL_METHOD(sample_ray)
DRJIT_CALL_METHOD(sample_ray_differential)
DRJIT_CALL_METHOD(sample_direction)
Expand All @@ -326,4 +326,4 @@ DRJIT_CALL_TEMPLATE_BEGIN(mitsuba::Sensor)
DRJIT_CALL_GETTER(flags)
DRJIT_CALL_GETTER(shape)
DRJIT_CALL_GETTER(medium)
DRJIT_CALL_END(mitsuba::Sensor)
MI_CALL_TEMPLATE_END(Sensor)
4 changes: 2 additions & 2 deletions include/mitsuba/render/shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -1079,7 +1079,7 @@ NAMESPACE_END(mitsuba)
//! @{ \name Dr.Jit support for vectorized function calls
// -----------------------------------------------------------------------

DRJIT_CALL_TEMPLATE_BEGIN(mitsuba::Shape)
MI_CALL_TEMPLATE_BEGIN(Shape)
DRJIT_CALL_METHOD(compute_surface_interaction)
DRJIT_CALL_METHOD(has_attribute)
DRJIT_CALL_METHOD(eval_attribute)
Expand Down Expand Up @@ -1112,7 +1112,7 @@ DRJIT_CALL_TEMPLATE_BEGIN(mitsuba::Shape)
auto is_mesh() const { return shape_type() == (uint32_t) mitsuba::ShapeType::Mesh; }
auto is_medium_transition() const { return interior_medium() != nullptr ||
exterior_medium() != nullptr; }
DRJIT_CALL_END(mitsuba::Shape)
MI_CALL_TEMPLATE_END(Shape)

//! @}
// -----------------------------------------------------------------------
14 changes: 14 additions & 0 deletions resources/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def w(s):
f.write(' helper various macros to instantiate multiple variants of Mitsuba. */\n\n')

f.write('#pragma once\n\n')
f.write('#include <array>\n')
f.write('#include <mitsuba/core/fwd.h>\n')

enable_jit = False
Expand Down Expand Up @@ -89,6 +90,7 @@ def w(s):

f.write('NAMESPACE_BEGIN(mitsuba)\n')
f.write('NAMESPACE_BEGIN(detail)\n')

f.write('/// Convert a <Float, Spectrum> type pair into one of the strings in MI_VARIANT\n')
f.write('template <typename Float_, typename Spectrum_> constexpr const char *get_variant() {\n')
for index, (name, float_, spectrum) in enumerate(enabled):
Expand All @@ -98,6 +100,18 @@ def w(s):
f.write(' else\n')
f.write(' return "";\n')
f.write('}\n')

max_len = max(len(e[0]) for e in enabled)
f.write('/// Convert a <Float, Spectrum> type pair into a fixed-length string\n')
f.write('template <typename Float_, typename Spectrum_> constexpr std::array<char, %d> get_variant_padded() {\n' % (max_len+1))
for index, (name, float_, spectrum) in enumerate(enabled):
f.write(' %sif constexpr (std::is_same_v<Float_, %s> &&\n' % ('else ' if index > 0 else '', float_))
f.write(' %s std::is_same_v<Spectrum_, %s>)\n' % (' ' if index > 0 else '', spectrum))
f.write(' return std::array<char, %d>{"%s"};\n' % (max_len+1, name.ljust(max_len, '_')))
f.write(' else\n')
f.write(' return std::array<char, %d>{"%s"};\n' % (max_len+1, '_' * max_len))
f.write('}\n')

f.write('NAMESPACE_END(detail)\n')
f.write('NAMESPACE_END(mitsuba)\n')

Expand Down
3 changes: 1 addition & 2 deletions src/render/bsdf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ NAMESPACE_BEGIN(mitsuba)

MI_VARIANT BSDF<Float, Spectrum>::BSDF(const Properties &props)
: m_flags(+BSDFFlags::Empty), m_id(props.id()) {
if constexpr (dr::is_jit_v<Float>)
jit_registry_put(dr::backend_v<Float>, "mitsuba::BSDF", this);
MI_REGISTRY_PUT("BSDF", this);
}

MI_VARIANT BSDF<Float, Spectrum>::~BSDF() {
Expand Down
5 changes: 2 additions & 3 deletions src/render/emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@ MI_VARIANT Emitter<Float, Spectrum>::Emitter(const Properties &props)
: Base(props) {
m_sampling_weight = props.get<ScalarFloat>("sampling_weight", 1.0f);

if constexpr (dr::is_jit_v<Float>)
jit_registry_put(dr::backend_v<Float>, "mitsuba::Emitter", this);
MI_REGISTRY_PUT("Emitter", this);
}

MI_VARIANT Emitter<Float, Spectrum>::~Emitter() {
MI_VARIANT Emitter<Float, Spectrum>::~Emitter() {
if constexpr (dr::is_jit_v<Float>)
jit_registry_remove(this);
}
Expand Down
10 changes: 4 additions & 6 deletions src/render/medium.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@

NAMESPACE_BEGIN(mitsuba)

MI_VARIANT Medium<Float, Spectrum>::Medium() :
m_is_homogeneous(false),
MI_VARIANT Medium<Float, Spectrum>::Medium() :
m_is_homogeneous(false),
m_has_spectral_extinction(true) {

if constexpr (dr::is_jit_v<Float>)
jit_registry_put(dr::backend_v<Float>, "mitsuba::Medium", this);
MI_REGISTRY_PUT("Medium", this);
}

MI_VARIANT Medium<Float, Spectrum>::Medium(const Properties &props) : m_id(props.id()) {
Expand All @@ -33,8 +32,7 @@ MI_VARIANT Medium<Float, Spectrum>::Medium(const Properties &props) : m_id(props

m_sample_emitters = props.get<bool>("sample_emitters", true);

if constexpr (dr::is_jit_v<Float>)
jit_registry_put(dr::backend_v<Float>, "mitsuba::Medium", this);
MI_REGISTRY_PUT("Medium", this);
}

MI_VARIANT Medium<Float, Spectrum>::~Medium() {
Expand Down
3 changes: 1 addition & 2 deletions src/render/phase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ NAMESPACE_BEGIN(mitsuba)
MI_VARIANT
PhaseFunction<Float, Spectrum>::PhaseFunction(const Properties &props)
: m_flags(+PhaseFunctionFlags::Empty), m_id(props.id()) {
if constexpr (dr::is_jit_v<Float>)
jit_registry_put(dr::backend_v<Float>, "mitsuba::PhaseFunction", this);
MI_REGISTRY_PUT("PhaseFunction", this);
}

MI_VARIANT PhaseFunction<Float, Spectrum>::~PhaseFunction() {
Expand Down
3 changes: 1 addition & 2 deletions src/render/sensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ MI_VARIANT Sensor<Float, Spectrum>::Sensor(const Properties &props) : Base(props
}
}

if constexpr (dr::is_jit_v<Float>)
jit_registry_put(dr::backend_v<Float>, "mitsuba::Sensor", this);
MI_REGISTRY_PUT("Sensor", this);
}

MI_VARIANT Sensor<Float, Spectrum>::~Sensor() {
Expand Down
3 changes: 1 addition & 2 deletions src/render/shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ MI_VARIANT Shape<Float, Spectrum>::Shape(const Properties &props) : m_id(props.i

m_silhouette_sampling_weight = props.get<ScalarFloat>("silhouette_sampling_weight", 1.0f);

if constexpr (dr::is_jit_v<Float>)
jit_registry_put(dr::backend_v<Float>, "mitsuba::Shape", this);
MI_REGISTRY_PUT("Shape", this);
}

MI_VARIANT Shape<Float, Spectrum>::~Shape() {
Expand Down
3 changes: 1 addition & 2 deletions src/render/shapegroup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,7 @@ MI_VARIANT ShapeGroup<Float, Spectrum>::ShapeGroup(const Properties &props) {
}
#endif

if constexpr (dr::is_jit_v<Float>)
jit_registry_put(dr::backend_v<Float>, "mitsuba::ShapeGroup", this);
MI_REGISTRY_PUT("ShapeGroup", this);
}

MI_VARIANT ShapeGroup<Float, Spectrum>::~ShapeGroup() {
Expand Down
3 changes: 1 addition & 2 deletions src/shapes/merge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ class MergeShape final : public Shape<Float, Spectrum> {
Log(Info, "Collapsed %zu into %zu meshes. (took %s, %zu objects ignored)",
visited, tbl.size(), util::time_string((float) timer.value()), ignored);

if constexpr (dr::is_jit_v<Float>)
jit_registry_put(dr::backend_v<Float>, "mitsuba::Shape", this);
MI_REGISTRY_PUT("Shape", this);
}

std::vector<ref<Object>> expand() const override {
Expand Down