-
Notifications
You must be signed in to change notification settings - Fork 0
/
nerad.py
456 lines (364 loc) · 15.6 KB
/
nerad.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
import drjit as dr
import mitsuba as mi
if __name__ == "__main__":
mi.set_variant("cuda_ad_rgb")
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from mitsuba.python.ad.integrators.common import mis_weight
from tqdm import tqdm
from tinycudann import Encoding as NGPEncoding
class NRField(nn.Module):
def __init__(self, scene: mi.Scene, width=256, n_hidden=8) -> None:
"""Initialize an instance of NRField.
Args:
bb_min (mi.ScalarBoundingBox3f): minimum point of the bounding box
bb_max (mi.ScalarBoundingBox3f): maximum point of the bounding box
"""
super().__init__()
self.bbox = scene.bbox()
enc_config = {
"otype": "HashGrid",
# "type": "Hash",
"base_resolution": 16,
"n_levels": 8,
"n_features_per_level": 4,
"log2_hashmap_size": 22,
}
self.pos_enc = NGPEncoding(3, enc_config)
in_size = 3 * 4 + self.pos_enc.n_output_dims
hidden_layers = []
for _ in range(n_hidden):
hidden_layers.append(nn.Linear(width, width))
hidden_layers.append(nn.ReLU(inplace=True))
self.network = nn.Sequential(
nn.Linear(in_size, width),
nn.ReLU(inplace=True),
*hidden_layers,
nn.Linear(width, 3),
).to("cuda")
def forward(self, si: mi.SurfaceInteraction3f):
"""Forward pass for NRField.
Args:
si (mitsuba.SurfaceInteraction3f): surface interaction
bsdf (mitsuba.BSDF): bidirectional scattering distribution function
Returns:
torch.Tensor
"""
with dr.suspend_grad():
x = ((si.p - self.bbox.min) / (self.bbox.max - self.bbox.min)).torch()
wi = si.to_world(si.wi).torch()
n = si.sh_frame.n.torch()
f_d = si.bsdf().eval_diffuse_reflectance(si).torch()
z_x = self.pos_enc(x)
inp = torch.concat([x, wi, n, f_d, z_x], dim=1)
out = self.network(inp)
out = torch.abs(out)
return out.to(torch.float32)
class NeradIntegrator(mi.SamplingIntegrator):
def __init__(
self,
model,
batch_size=2**14,
M=32,
lr=5e-4,
seed=42,
sample_mode="lhs",
l_sampler: mi.Sampler = mi.load_dict(
{"type": "independent", "sample_count": 1}
),
r_sampler: mi.Sampler = mi.load_dict(
{"type": "independent", "sample_count": 1}
),
) -> None:
super().__init__(mi.Properties())
self.model = model
self.l_sampler = l_sampler
self.r_sampler = r_sampler
self.M = M
self.batch_size = batch_size
self.lr = lr
self.seed = seed
self.sample_mode = sample_mode
self.losses = []
def sample_si(
self,
scene: mi.Scene,
shape_sampler: mi.DiscreteDistribution,
sample1,
sample2,
sample3,
active=True,
) -> mi.SurfaceInteraction3f:
"""Sample a batch of surface interactions with bsdfs.
Args:
scene (mitsuba.Scene): the underlying scene
shape_sampler (mi.DiscreteDistribution): a source of random numbers for shape sampling
sample1 (drjit.llvm.ad.Float): determines mesh surfaces
sample2 (mitsuba.Point2f): determines positions on the meshes
sample3 (mitsuba.Point2f): determines directions at the positions
active (bool, optional): mask to specify active lanes. Defaults to True.
Returns:
tuple of (mitsuba.SurfaceInteraction3f, mitsuba.BSDF)
"""
shape_index = shape_sampler.sample(sample1, active)
shape: mi.Shape = dr.gather(mi.ShapePtr, scene.shapes_dr(), shape_index, active)
ps = shape.sample_position(0.5, sample2, active)
si = mi.SurfaceInteraction3f(ps, dr.zeros(mi.Color0f))
si.shape = shape
bsdf = shape.bsdf()
active_two_sided = mi.has_flag(bsdf.flags(), mi.BSDFFlags.BackSide)
si.wi = dr.select(
active_two_sided,
mi.warp.square_to_uniform_sphere(sample3),
mi.warp.square_to_uniform_hemisphere(sample3),
)
si.shape = shape
return si
def first_non_specular_or_null_si(
self, scene: mi.Scene, si: mi.SurfaceInteraction3f, sampler: mi.Sampler
):
"""Find the first non-specular or null surface interaction.
Args:
scene (mi.Scene): Scene object.
si (mi.SurfaceInteraction3f): Surface interaction.
sampler (mi.Sampler): Sampler object.
Returns:
tuple: A tuple containing four values:
- si (mi.SurfaceInteraction3f): First non-specular or null surface interaction.
- β (mi.Spectrum): The product of the weights of all previous BSDFs.
- null_face (bool): A boolean mask indicating whether the surface is a null face or not.
"""
max_iterations = 6
# Instead of `bsdf.flags()`, based on `bsdf_sample.sampled_type`.
with dr.suspend_grad():
bsdf_ctx = mi.BSDFContext()
depth = mi.UInt32(0)
β = mi.Spectrum(1)
null_face = mi.Bool(True)
active = mi.Bool(True)
loop = mi.Loop(
name="first_non_specular_or_null_si",
state=lambda: (sampler, depth, β, active, null_face, si),
)
loop.set_max_iterations(max_iterations)
while loop(active):
# loop invariant: si is located at non-null and Delta surface
# if si is located at null or Smooth surface, end loop
bsdf: mi.BSDF = si.bsdf()
bsdf_sample, bsdf_weight = bsdf.sample(
bsdf_ctx, si, sampler.next_1d(), sampler.next_2d(), active
)
null_face &= ~mi.has_flag(
bsdf_sample.sampled_type, mi.BSDFFlags.BackSide
) & (si.wi.z < 0)
active &= si.is_valid() & ~null_face & (depth < max_iterations)
active &= ~mi.has_flag(bsdf_sample.sampled_type, mi.BSDFFlags.Smooth)
ray = si.spawn_ray(si.to_world(bsdf_sample.wo))
si[active] = scene.ray_intersect(
ray,
ray_flags=mi.RayFlags.All,
coherent=False,
active=active,
)
β[active] *= bsdf_weight
depth += 1
# return si at the first non-specular bounce or null face
return si, β, null_face
def render_lhs(
self, scene: mi.Scene, si: mi.SurfaceInteraction3f, mode: str = "drjit"
) -> mi.Color3f | torch.Tensor:
"""
Renders the left-hand side of the rendering equation by calculating the emitter's radiance and
the neural network output at the given surface interaction (si) position and direction (bsdf).
Args:
scene (mi.Scene): A Mitsuba scene object.
model (torch.nn.Module): A neural network model that takes si and bsdf as input and returns
a predicted radiance value.
si (mi.SurfaceInteraction3f): A Mitsuba surface interaction object.
bsdf (mi.BSDF): A Mitsuba BSDF object.
Returns:
tuple: A tuple containing four values:
- L (mi.Spectrum): The total outgoing radiance value.
- Le (mi.Spectrum): The emitter's radiance.
- out (torch.Tensor): The neural network's predicted radiance value.
- mask (torch.Tensor): A boolean tensor indicating which surface interactions are valid.
"""
with dr.suspend_grad():
Le = si.emitter(scene).eval(si)
# discard the null bsdf backside
null_face = ~mi.has_flag(si.bsdf().flags(), mi.BSDFFlags.BackSide) & (
si.wi.z < 0
)
mask = si.is_valid() & ~null_face
out = self.model(si)
if mode == "drjit":
return Le + dr.select(mask, mi.Spectrum(out), 0)
elif mode == "torch":
# In torch mode returns only the network output
# i.e. N(x, ω_o)
return out * mask.torch().reshape(-1, 1)
def render_rhs(
self, scene: mi.Scene, si: mi.SurfaceInteraction3f, sampler, mode="drjit"
) -> mi.Color3d | torch.Tensor:
with dr.suspend_grad():
bsdf_ctx = mi.BSDFContext()
depth = mi.UInt32(0)
L = mi.Spectrum(0)
β = mi.Spectrum(1)
η = mi.Float(1)
prev_si = dr.zeros(mi.SurfaceInteraction3f)
prev_bsdf_pdf = mi.Float(1.0)
prev_bsdf_delta = mi.Bool(True)
# Do not use Le1 only render Transmittance
# L = E + T = E(x, ω_o) + T(E(x', -ω_i) + N(x', -ω_i))
# L(x, ω_o) = E(x, ω_o) + ∫ L(x'(x, ω_o), ω_i) f(x, ω_o, ω_i) dω_i^⊥ = E(x, ω_o) + N(x', ω_o)
Le1 = β * si.emitter(scene).eval(si)
bsdf = si.bsdf()
# emitter sampling
active_next = si.is_valid()
active_em = active_next & mi.has_flag(bsdf.flags(), mi.BSDFFlags.Smooth)
ds, em_weight = scene.sample_emitter_direction(
si, sampler.next_2d(), True, active_em
)
active_em &= dr.neq(ds.pdf, 0.0)
wo = si.to_local(ds.d)
bsdf_value_em, bsdf_pdf_em = bsdf.eval_pdf(bsdf_ctx, si, wo, active_em)
mis_em = dr.select(ds.delta, 1, mis_weight(ds.pdf, bsdf_pdf_em))
Lr_dir = β * mis_em * bsdf_value_em * em_weight
# bsdf sampling
bsdf_sample, bsdf_weight = bsdf.sample(
bsdf_ctx, si, sampler.next_1d(), sampler.next_2d(), active_next
)
# update
# L = L + Le + Lr_dir
ray = si.spawn_ray(si.to_world(bsdf_sample.wo))
η *= bsdf_sample.eta
β *= bsdf_weight
prev_si = dr.detach(si, True)
prev_bsdf_pdf = bsdf_sample.pdf
prev_bsdf_delta = mi.has_flag(bsdf_sample.sampled_type, mi.BSDFFlags.Delta)
si = scene.ray_intersect(ray, ray_flags=mi.RayFlags.All, coherent=True)
bsdf = si.bsdf(ray)
ds = mi.DirectionSample3f(scene, si=si, ref=prev_si)
mis_bsdf = mis_weight(
prev_bsdf_pdf,
scene.pdf_emitter_direction(prev_si, ds, ~prev_bsdf_delta),
)
si, β2, null_face = self.first_non_specular_or_null_si(scene, si, sampler)
β *= β2
# T(E)
L = β * mis_bsdf * si.emitter(scene).eval(si) + Lr_dir
out = self.model(si)
active_nr = (
si.is_valid()
& ~null_face
& dr.eq(si.emitter(scene).eval(si), mi.Spectrum(0))
)
w_nr = β * mis_bsdf
if mode == "drjit":
# E(x) + T(E(x')) + T(N(x'))
return Le1 + L + dr.select(active_nr, w_nr * mi.Spectrum(out), 0)
elif mode == "torch":
# T(E) + T(N)
return L.torch() + out * dr.select(active_nr, w_nr, 0).torch()
def sample(
self,
scene: mi.Scene,
sampler: mi.Sampler,
ray: mi.Ray3f,
medium: mi.Medium = None,
active: bool = True,
) -> tuple[mi.Color3f, mi.Bool, list[mi.Color3f]]:
self.model.eval()
with torch.no_grad():
w, h = list(scene.sensors()[0].film().size())
L = mi.Spectrum(0)
ray = mi.Ray3f(dr.detach(ray))
si = scene.ray_intersect(ray, ray_flags=mi.RayFlags.All, coherent=True)
# update si and bsdf with the first non-specular ones
si, β, _ = self.first_non_specular_or_null_si(scene, si, sampler)
if self.sample_mode == "rhs":
L = self.render_rhs(scene, si, sampler, mode="drjit")
elif self.sample_mode == "lhs":
L = self.render_lhs(scene, si, mode="drjit")
self.model.train()
torch.cuda.empty_cache()
return β * L, si.is_valid(), []
def train(self, scene: mi.Scene, steps: int):
m_area = []
for shape in scene.shapes():
if not shape.is_emitter() and mi.has_flag(
shape.bsdf().flags(), mi.BSDFFlags.Smooth
):
m_area.append(shape.surface_area())
else:
m_area.append([0.0])
m_area = np.array(m_area)[:, 0]
if len(m_area):
m_area /= m_area.sum()
else:
raise Warning("No smooth shape. No need of neural network training.")
shape_sampler = mi.DiscreteDistribution(m_area)
optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
tqdm_iterator = tqdm(range(steps))
self.model.train()
for step in tqdm_iterator:
optimizer.zero_grad()
# detach the computation graph of samplers to avoid lengthy graph of dr.jit
r_sampler = self.r_sampler.clone()
l_sampler = self.l_sampler.clone()
r_sampler.seed(step, self.batch_size * self.M)
l_sampler.seed(step, self.batch_size)
si_lhs = self.sample_si(
scene,
shape_sampler,
l_sampler.next_1d(),
l_sampler.next_2d(),
l_sampler.next_2d(),
)
# copy `si_lhs` M times for RHS evaluation
indices = dr.arange(mi.UInt, 0, self.batch_size)
indices = dr.repeat(indices, self.M)
si_rhs = dr.gather(type(si_lhs), si_lhs, indices)
# bsdf_rhs = dr.gather(type(bsdf_lhs), bsdf_lhs, indices)
# LHS and RHS evaluation
lhs = self.render_lhs(scene, si_lhs, mode="torch")
rhs = self.render_rhs(scene, si_rhs, r_sampler, mode="torch")
rhs = rhs.reshape(self.batch_size, self.M, 3).mean(dim=1)
norm = 1
# in our experiment, normalization makes rendering biased (dimmer)
# norm = (lhs + rhs).detach() / 2 + 1e-2
loss = torch.nn.MSELoss()(lhs / norm, rhs / norm)
loss.backward()
optimizer.step()
tqdm_iterator.set_description("Loss %.04f" % (loss.item()))
self.losses.append(loss.item())
self.model.eval()
# optimizer = torch.optim.Adam(field.parameters(), lr=lr)
# train_losses = []
# tqdm_iterator = tqdm(range(total_steps))
if __name__ == "__main__":
torch.manual_seed(0)
scene = mi.load_file("./data/scenes/cornell-box-specular/scene.xml")
# scene = mi.load_file("./data/scenes/caustics/scene.xml")
# scene = mi.load_file("./data/scenes/veach-door/scene.xml")
field = NRField(scene, n_hidden=3, width=256)
integrator = NeradIntegrator(field)
integrator.train(scene, 1000)
image = mi.render(scene, spp=1, integrator=integrator)
losses = integrator.losses
ref_image = mi.render(scene, spp=16)
pt_image = mi.render(scene, spp=16, integrator=mi.load_dict({"type": "ptracer"}))
fig, ax = plt.subplots(2, 2, figsize=(10, 10))
fig.patch.set_visible(False) # Hide the figure's background
ax[0][0].axis("off") # Remove the axes from the image
ax[0][0].imshow(mi.util.convert_to_bitmap(image))
ax[1][0].axis("off")
ax[1][0].imshow(mi.util.convert_to_bitmap(ref_image))
ax[0][1].axis("off")
ax[0][1].imshow(mi.util.convert_to_bitmap(pt_image))
ax[1][1].plot(losses, color="red")
fig.tight_layout() # Remove any extra white spaces around the image
plt.show()