-
I took another approach: do the optimisation on the PyTorch side and just wrap the Mitsuba rendering function instead. This seems to work pretty well. I've also extended the test script to include position, rotation and colour now...

```python
from typing import List, Tuple
import logging
import drjit as dr
import matplotlib.pyplot as plt
import mitsuba as mi
import numpy as np
import torch
from kornia.geometry import axis_angle_to_rotation_matrix
from pytorch3d.utils import ico_sphere
from torch import nn
# USE_CUDA was not defined in the original snippet; assume CUDA is wanted whenever it is available
USE_CUDA = torch.cuda.is_available()

if USE_CUDA:
    if 'cuda_ad_rgb' not in mi.variants():
        raise RuntimeError('No CUDA variant found.')
    mi.set_variant('cuda_ad_rgb')
    device = torch.device('cuda')
else:
    mi.set_variant('llvm_ad_rgb')
    device = torch.device('cpu')

from mitsuba import ScalarTransform4f as T
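
# Parameter keys used to address the shape in mi.traverse(scene) further down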
SHAPE_NAME = 'sphere'
VERTEX_KEY = SHAPE_NAME + '.vertex_positions'
FACES_KEY = SHAPE_NAME + '.faces'
BSDF_KEY = SHAPE_NAME + '.bsdf'
COLOUR_KEY = BSDF_KEY + '.reflectance.value'

# Basic logging setup (assumed; the logger was not configured in the original snippet)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def to_numpy(t: torch.Tensor) -> np.ndarray:
    return t.detach().cpu().numpy()


class Sphere(nn.Module):
    def __init__(
            self,
            scale: float = 1.0,
            origin: List[float] = [0, 0, 0],
            rotvec: List[float] = [0, 0, 0],
            colour: List[float] = [1, 1, 1]
    ):
        """
        Create a sphere with the given scale, origin, rotation and colour.
        """
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float32), requires_grad=True)
        self.origin = nn.Parameter(torch.tensor(origin, dtype=torch.float32), requires_grad=True)
        self.rotvec = nn.Parameter(torch.tensor(rotvec, dtype=torch.float32), requires_grad=True)
        self.colour = nn.Parameter(torch.tensor(colour, dtype=torch.float32), requires_grad=True)

    def build_mesh(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Create a sphere mesh with the given scale, origin and rotation.
        """
        # Build basic sphere
        sphere = ico_sphere(level=1, device=self.origin.device)
        vertices = sphere.verts_packed()
        faces = sphere.faces_packed()

        # Apply scaling
        vertices = vertices * self.scale

        # Apply translation
        vertices = vertices + self.origin

        # Apply rotation
        R = axis_angle_to_rotation_matrix(self.rotvec[None, :]).squeeze(0)
        vertices = vertices @ R.T

        return vertices, faces


def build_mitsuba_mesh(shape: Sphere) -> mi.Mesh:
    """
    Convert a Sphere object into a Mitsuba mesh.
    """
    # Build the mesh in pytorch and convert the parameters to mitsuba format
    vertices, faces = shape.build_mesh()
    vertices = mi.TensorXf(vertices)
    faces = mi.TensorXi64(faces)

    # Set up the material properties
    props = mi.Properties()
    props[BSDF_KEY] = mi.load_dict({
        'type': 'diffuse',
        'reflectance': {
            'type': 'rgb',
            'value': shape.colour.tolist()
        }
    })

    # Construct the mitsuba mesh and set the vertex positions and faces
    mesh = mi.Mesh(
        SHAPE_NAME,
        vertex_count=len(vertices),
        face_count=len(faces),
        has_vertex_normals=False,
        has_vertex_texcoords=False,
        props=props
    )
    mesh_params = mi.traverse(mesh)
    mesh_params['vertex_positions'] = dr.ravel(vertices)
    mesh_params['faces'] = dr.ravel(faces)
    return mesh


def create_scene(shape: Sphere, spp=256, res=400) -> mi.Scene:
    """
    Create a Mitsuba scene containing the given shape.
    """
    scene = mi.load_dict({
        'type': 'scene',
        'integrator': {
            'type': 'prb_projective',
            'max_depth': 32,
            'rr_depth': 4,
            'sppi': 0
        },
        'sensor': {
            'type': 'perspective',
            'to_world': T.look_at(
                origin=[0, 0, 10],
                target=[0, 0, 0],
                up=[0, 1, 0]
            ),
            'sampler': {
                'type': 'independent',
                'sample_count': spp
            },
            'film': {
                'type': 'hdrfilm',
                'width': res,
                'height': res,
                'filter': {'type': 'gaussian'},
                'sample_border': True,
            },
        },
        'light': {
            'type': 'rectangle',
            'to_world': T.scale(50) @ T.look_at(
                origin=[0, 0, 10],
                target=[0, 0, 0],
                up=[0, 1, 0]
            ),
            'emitter': {
                'type': 'area',
                'radiance': {
                    'type': 'rgb',
                    'value': np.ones(3) * 50.0
                }
            },
        },
        SHAPE_NAME: build_mitsuba_mesh(shape)
    })
    return scene


def optimise_scene():
    # Parameters
    spp = 32
    n_iterations = 2000
    lr = 0.01
    plot_freq = 5

    # Set up target and initial scenes
    sphere_target = Sphere(
        scale=2,
        origin=[1, 0.5, 0.1],
        rotvec=[0.5, 0.5, 0],
        colour=[0, 1, 0]
    )
    sphere_opt = Sphere(
        scale=1,
        origin=[-1, -0.2, 1],
        rotvec=[0, 5, -2],
        colour=[0, 0, 1]
    )
    sphere_target.to(device)
    sphere_opt.to(device)
    scene_target = create_scene(shape=sphere_target, spp=spp)
    scene = create_scene(shape=sphere_opt, spp=spp)
    params = mi.traverse(scene)
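
    # dr.wrap_ad bridges the two AD frameworks: the wrapped function below is
    # called with torch tensors, sees Dr.Jit arrays inside, and gradients flow
    # back to the torch inputs when loss.backward() is called on the torch side.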
    @dr.wrap_ad(source='torch', target='drjit')
    def render_image(vertices, faces, colour, seed=1):
        params[VERTEX_KEY] = dr.ravel(vertices)
        params[FACES_KEY] = dr.ravel(faces)
        params[COLOUR_KEY] = dr.unravel(mi.Color3f, colour)
        params.update()
        return mi.render(scene, params, seed=seed)

    def plot(img_opt):
        fig, axes = plt.subplots(1, 2, figsize=(10, 5))
        for ax, img in zip(axes, [image_target, img_opt]):
            img = to_numpy(img)
            ax.imshow(img**(1.0 / 2.2))
            ax.axis('off')
        fig.suptitle(f'Iteration {i} Loss: {loss:.4E}')
        fig.tight_layout()
        plt.show()
        plt.close(fig)

    # Set up optimiser
    opt = torch.optim.Adam(sphere_opt.parameters(), lr=lr)

    # Optimise the scene
    losses = []
    s_diffs = []
    o_diffs = []
    r_diffs = []
    c_diffs = []
    for i in range(n_iterations):
        if i > 0:
            # Take a gradient step and update the parameters
            loss.backward()
            # r_err.backward() # Debug to check that the radius is being updated
            opt.step()
            opt.zero_grad()

        # Rebuild the mesh using the new parameters
        vertices, faces = sphere_opt.build_mesh()

        # Render new images - colour needs cloning as Mitsuba doesn't map nn.Parameters
        img_i = render_image(vertices, faces, sphere_opt.colour.clone(), seed=i)
        image_target = mi.render(scene_target, seed=i).torch()

        # Calculate losses
        loss = torch.mean((img_i - image_target)**2)
        s_diff = torch.abs(sphere_opt.scale - sphere_target.scale)
        o_diff = torch.norm(sphere_opt.origin - sphere_target.origin)
        r_diff = torch.norm(sphere_opt.rotvec - sphere_target.rotvec)
        c_diff = torch.norm(sphere_opt.colour - sphere_target.colour)
        losses.append(loss.item())
        s_diffs.append(s_diff.item())
        o_diffs.append(o_diff.item())
        r_diffs.append(r_diff.item())
        c_diffs.append(c_diff.item())
        logger.info(
            f'Iteration {i} ' +
            f'Loss: {loss.item():.3E} ' +
            f's-error: {s_diff.item():.3E} (s={sphere_opt.scale.item():.2f}) ' +
            f'o-error: {o_diff.item():.3E} (o=[' + ','.join([f'{v:.2f}' for v in sphere_opt.origin]) + ']) ' +
            f'r-error: {r_diff.item():.3E} (r=[' + ','.join([f'{v:.2f}' for v in sphere_opt.rotvec]) + ']) ' +
            f'c-error: {c_diff.item():.3E} (c=[' + ','.join([f'{v:.2f}' for v in sphere_opt.colour]) + '])'
        )

        # Plot
        if i % plot_freq == 0:
            plot(img_i)

    # Plot the final scene comparison
    plot(img_i)
    # Plot the losses
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    ax = axes[0]
    ax.plot(losses)
    ax.set_title('Image Loss (L2)')
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Error')
    ax.set_yscale('log')
    ax = axes[1]
    ax.plot(s_diffs, label='Scale')
    ax.plot(o_diffs, label='Origin')
    ax.plot(r_diffs, label='Rotation')
    ax.plot(c_diffs, label='Colour')
    ax.set_title('Parameter Convergence')
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Error')
    ax.set_yscale('log')
    ax.legend()
    fig.tight_layout()
    plt.show()


if __name__ == '__main__':
    optimise_scene()
```

Any feedback or suggestions welcomed! If anyone thinks this is doc-worthy I'm happy to turn it into a notebook.
-
I have a well defined shape whose dimensions I want to fit using the inverse shape optimisation approach. The shape is parametrised by just a few parameters and from these I can build a mesh in pytorch. To get things going I've started with a toy example to just optimise the radius of an `ico_sphere` that is built using pytorch3d. If I keep the radius parameter outside of the pytorch function and then scale the returned (mi-type) vertices by it, this works nicely. However, when I pass this radius parameter in, I get an error at `dr.backward(loss)`.

The full optimisation script I'm using:

The error I get with the failing approach is:

But the error it actually wants to throw is in `torch/autograd/__init__.py` line 88, due to a shape mismatch:

Hopefully I'm not trying to do anything too wild and I'm just not going about this quite the right way. I also considered keeping my `radius` parameter as a torch variable, but I couldn't figure out how to optimise it then. Any help or advice would be much appreciated!

EDIT
I know the radius parameter is not really a radius, and I know I could just build the mesh once and then scale the vertices, but I'm building up to something where I will need to rebuild the mesh every time.
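
To illustrate the working variant mentioned above (the radius kept on the Dr.Jit side and used to scale the mi-typed vertices), here is a condensed sketch rather than my actual script: the scene file, the `'sphere.vertex_positions'` key and the hyper-parameters are placeholders.

```python
import drjit as dr
import mitsuba as mi
from pytorch3d.utils import ico_sphere

mi.set_variant('cuda_ad_rgb')

# Build the unit icosphere once in pytorch3d and convert it to Dr.Jit types
verts = ico_sphere(level=1).verts_packed().numpy()
unit_vertices = dr.unravel(mi.Point3f, mi.Float(verts.ravel()))

scene = mi.load_file('scene.xml')  # placeholder scene containing a mesh named 'sphere'
params = mi.traverse(scene)
image_ref = mi.render(scene, params, spp=256)

# Keep the radius as a latent Dr.Jit variable handled by a Mitsuba optimiser
opt = mi.ad.Adam(lr=0.05)
opt['radius'] = mi.Float(0.5)

for it in range(100):
    # Scale the mi-typed vertices by the radius *outside* any torch code,
    # so dr.backward can trace the dependency of the loss on the radius
    params['sphere.vertex_positions'] = dr.ravel(unit_vertices * opt['radius'])
    params.update()

    img = mi.render(scene, params, spp=16, seed=it)
    loss = dr.mean(dr.sqr(img - image_ref))
    dr.backward(loss)
    opt.step()
```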