diff --git a/src/emitters/volumelight.cpp b/src/emitters/volumelight.cpp index bca5a0925..2a62339ec 100644 --- a/src/emitters/volumelight.cpp +++ b/src/emitters/volumelight.cpp @@ -74,7 +74,6 @@ class VolumeLight final : public Emitter { m_flags = +EmitterFlags::Medium; dr::set_attr(this, "flags", m_flags); - dr::set_attr(this, "radiance", m_radiance); } void traverse(TraversalCallback *callback) override { diff --git a/src/python/python/ad/integrators/prbvolpath.py b/src/python/python/ad/integrators/prbvolpath.py index b15323447..425e981bf 100644 --- a/src/python/python/ad/integrators/prbvolpath.py +++ b/src/python/python/ad/integrators/prbvolpath.py @@ -268,11 +268,14 @@ def sample(self, null_scatter = (sampler.next_1d(active_medium) < (emission_prob + null_prob)) act_null_scatter = null_scatter & active_medium act_medium_scatter = ~null_scatter & active_medium - L[act_null_scatter] += throughput * radiance * dr.detach(weight * emission_weight) + contrib = throughput * radiance * dr.detach(emission_weight * weight) + L[act_null_scatter] += dr.detach(contrib if is_primal else -contrib) weight[act_null_scatter] *= mei.sigma_n * dr.detach(null_weight) else: scatter_weight = mi.Float(1.0) act_medium_scatter = active_medium + contrib = dr.zeros(mi.UnpolarizedSpectrum) + depth[act_medium_scatter] += 1 last_scatter_event[act_medium_scatter] = dr.detach(mei) @@ -291,6 +294,8 @@ def sample(self, if not is_primal and dr.grad_enabled(weight): Lo = dr.detach(dr.select(active_medium | escaped_medium, L / dr.maximum(1e-8, weight), 0.0)) dr.backward(δL * weight * Lo) + if not is_primal and dr.grad_enabled(contrib): + dr.backward(δL * contrib) phase_ctx = mi.PhaseFunctionContext(sampler) phase = mei.medium.phase_function()