How to optimize variables in custom Mitsuba plugins? #1315
-
I followed a tutorial on creating custom Python plugins, and I want to optimize the `tint` value using automatic differentiation. I created a gradient-enabled `Float` and passed it to the scene parameters. However, after rendering, the resulting image does not have gradients enabled. I know that changing `tint` to a bitmap plugin might help, but sometimes I need to optimize custom values, so I would prefer to keep `tint` as a `Float`, `Color3f`, `ScalarColor3f`, or tensor rather than a texture plugin. Here is the reproducer; the first part is copied from the tutorial.
import drjit as dr
import drjit as dr
import mitsuba as mi

# Select the variant before any other Mitsuba type (mi.BSDF, mi.Color3f, ...)
# is used. The types exposed by `mi` depend on the active variant, and an
# `*_ad_*` variant is required for automatic differentiation.
mi.set_variant('cuda_ad_rgb')
class MyBSDF(mi.BSDF):
    """Dielectric BSDF with delta reflection/transmission and a tinted reflection.

    Properties:
        eta  -- relative index of refraction (defaults to 1.33 when absent).
        tint -- RGB tint applied to the reflected lobe (required).

    ``tint`` is exposed through ``traverse`` as a differentiable parameter.
    """

    def __init__(self, props):
        mi.BSDF.__init__(self, props)

        # Read 'eta' and 'tint' properties from `props`
        self.eta = 1.33
        if props.has_property('eta'):
            self.eta = props['eta']

        # `props['tint']` yields a *scalar* array type (e.g. drjit.scalar.Array3f64)
        # that cannot carry gradients. Wrap it in the variant's Color3f (an AD type
        # in `*_ad_*` variants) so that values assigned through `mi.traverse(...)`
        # keep their gradient tracking and `mi.render` output depends on them.
        self.tint = mi.Color3f(props['tint'])

        # Set the BSDF flags: one delta reflection and one delta transmission lobe.
        reflection_flags = (mi.BSDFFlags.DeltaReflection
                            | mi.BSDFFlags.FrontSide | mi.BSDFFlags.BackSide)
        transmission_flags = (mi.BSDFFlags.DeltaTransmission
                              | mi.BSDFFlags.FrontSide | mi.BSDFFlags.BackSide)
        self.m_components = [reflection_flags, transmission_flags]
        self.m_flags = reflection_flags | transmission_flags

    def sample(self, ctx, si, sample1, sample2, active):
        """Sample reflection or transmission according to the Fresnel term.

        Returns a ``(BSDFSample3f, Color3f)`` pair. ``sample1`` chooses the lobe;
        ``sample2`` is unused because both lobes are delta distributions.
        """
        # Compute Fresnel terms
        cos_theta_i = mi.Frame3f.cos_theta(si.wi)
        r_i, cos_theta_t, eta_it, eta_ti = mi.fresnel(cos_theta_i, self.eta)
        t_i = dr.maximum(1.0 - r_i, 0.0)

        # Pick between reflection and transmission
        selected_r = (sample1 <= r_i) & active

        # Fill up the BSDFSample struct
        bs = mi.BSDFSample3f()
        bs.pdf = dr.select(selected_r, r_i, t_i)
        bs.sampled_component = dr.select(selected_r, mi.UInt32(0), mi.UInt32(1))
        bs.sampled_type = dr.select(selected_r,
                                    mi.UInt32(+mi.BSDFFlags.DeltaReflection),
                                    mi.UInt32(+mi.BSDFFlags.DeltaTransmission))
        bs.wo = dr.select(selected_r,
                          mi.reflect(si.wi),
                          mi.refract(si.wi, cos_theta_t, eta_ti))
        bs.eta = dr.select(selected_r, 1.0, eta_it)

        # For reflection, tint based on the incident angle (more tint at grazing
        # angle: cos_theta_i == 0 gives the full tint, normal incidence gives 1).
        value_r = dr.lerp(mi.Color3f(self.tint), mi.Color3f(1.0),
                          dr.clamp(cos_theta_i, 0.0, 1.0))

        # For transmission, radiance must be scaled to account for the solid
        # angle compression.
        # NOTE(review): `dr.sqr` is the Dr.Jit 0.4.x name (Mitsuba 3.5.x); newer
        # Dr.Jit releases call it `dr.square` -- confirm when upgrading.
        value_t = mi.Color3f(1.0) * dr.sqr(eta_ti)

        value = dr.select(selected_r, value_r, value_t)
        return (bs, value)

    def eval(self, ctx, si, wo, active):
        # Both lobes are delta distributions, so the BSDF value is zero for any
        # explicitly provided direction.
        return 0.0

    def pdf(self, ctx, si, wo, active):
        # Delta lobes: zero density for explicitly provided directions.
        return 0.0

    def eval_pdf(self, ctx, si, wo, active):
        # Combined `eval` and `pdf`; both are zero for delta lobes.
        return 0.0, 0.0

    def traverse(self, callback):
        # Expose `tint` so that `mi.traverse(scene)` can read, update and
        # differentiate it (requires `self.tint` to be an AD type, see __init__).
        callback.put_parameter('tint', self.tint, mi.ParamFlags.Differentiable)

    def parameters_changed(self, keys):
        # No derived quantities depend on `tint`, so nothing has to be recomputed.
        print("🏝️ there is nothing to do here 🏝️")

    def to_string(self):
        return ('MyBSDF[\n'
                '    eta=%s,\n'
                '    tint=%s,\n'
                ']' % (self.eta, self.tint))
# Make the Python class available to the scene loader under the plugin name
# "mybsdf", then build one instance directly to inspect its string form.
mi.register_bsdf("mybsdf", lambda props: MyBSDF(props))

my_bsdf = mi.load_dict(dict(
    type='mybsdf',
    tint=[0.2, 0.9, 0.2],
    eta=1.33,
))
print(my_bsdf)
# Minimal test scene: a sphere using the custom BSDF, lit by a constant
# environment, rendered with a path tracer from a camera looking at the origin.
camera_to_world = mi.ScalarTransform4f.look_at(origin=[0, -5, 5],
                                               target=[0, 0, 0],
                                               up=[0, 0, 1])
scene = mi.load_dict(dict(
    type='scene',
    integrator=dict(type='path'),
    light=dict(type='constant', radiance=0.99),
    sphere=dict(type='sphere', bsdf=my_bsdf),
    sensor=dict(type='perspective', to_world=camera_to_world),
))
# ------------------------------------------------------------
# Differentiate the rendered image with respect to the custom BSDF's tint
# (reproducer written against Mitsuba 3.5.2).
params = mi.traverse(scene)
print(params)

key = 'sphere.bsdf.tint'

# The new value must match the parameter's type, i.e. an RGB color.
# `mi.Float(0.9, 0.2, 0.2)` would be a single-channel array of width 3,
# not an RGB triple.
tint = mi.Color3f(0.9, 0.2, 0.2)
params[key] = tint

# Enable gradients on the value held by the scene parameters (the variable the
# BSDF reads while rendering), then propagate the change to the scene.
dr.enable_grad(params[key])
params.update()
print('tint grad: ', dr.grad_enabled(params[key]))
# True

image = mi.render(scene=scene, params=params, spp=64)
print('image grad: ', dr.grad_enabled(image))
# Expected True, provided MyBSDF stores `tint` as an AD type (mi.Color3f);
# with the scalar type from `props` it is False and `dr.backward` below raises
# "the argument does not depend on the input variable(s) being differentiated".

dr.backward(dr.sum(image))
print('tint gradient: ', dr.grad(params[key]))
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 2 replies
-
Hi @XiaoXinyyx I think what's missing here is that you need to call: When it will be differentiated, the |
Beta Was this translation helpful? Give feedback.
Oh, in the constructor we should have:
and in the initialization of the tint:
Otherwise, internally, the system will keep thinking that `self.tint` is a `drjit.scalar.Array3f64` (you can print `type(self.tint)` to check that it is a `drjit.*.ad` type). I'll change this in the tutorial so that it is less confusing, even though we don't need it to be differentiated there.