diff --git a/.github/inverse_rendering.gif b/.github/inverse_rendering.gif
new file mode 100644
index 000000000..717283f44
Binary files /dev/null and b/.github/inverse_rendering.gif differ
diff --git a/README.md b/README.md
index 9bf32fe37..be128184d 100644
--- a/README.md
+++ b/README.md
@@ -62,6 +62,10 @@ Get started with PyTorch3D by trying one of the tutorial notebooks.
|:------------------------------------------------------------:|:--------------------------------------------------:|
| [Fit Textured Volume in Implicitron](https://github.com/facebookresearch/pytorch3d/blob/main/docs/tutorials/implicitron_volumes.ipynb)| [Implicitron Config System](https://github.com/facebookresearch/pytorch3d/blob/main/docs/tutorials/implicitron_config_system.ipynb)|
+
+| |
+|:------------------------------------------------------------:|
+| [Inverse Rendering](https://github.com/facebookresearch/pytorch3d/blob/main/docs/tutorials/inverse_rendering.ipynb) |
+
diff --git a/docs/tutorials/inverse_rendering.ipynb b/docs/tutorials/inverse_rendering.ipynb
new file mode 100644
index 000000000..a5c712934
--- /dev/null
+++ b/docs/tutorials/inverse_rendering.ipynb
@@ -0,0 +1,2040 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "_Ip8kp4TfBLZ"
+ },
+ "outputs": [],
+ "source": [
+ "# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "kuXHJv44fBLe"
+ },
+ "source": [
+ "# Fit a mesh and cameras simultaneously\n",
+ "\n",
+ "This tutorial shows how to:\n",
+    "- Load a mesh and textures from an `.obj` file.\n",
+    "- Create a synthetic dataset by rendering a textured mesh from multiple viewpoints.\n",
+    "- Fit a mesh to the observed synthetic images using differentiable silhouette rendering.\n",
+    "- Fit a mesh and its textures using differentiable textured rendering."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Bnj3THhzfBLf"
+ },
+ "source": [
+ "## 0. Install and Import modules"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "okLalbR_g7NS"
+ },
+ "source": [
+ "Ensure `torch` and `torchvision` are installed. If `pytorch3d` is not installed, install it using the following cell:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "id": "musUWTglgxSB",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "b927f4c1-2932-481c-bd98-092142e18f63"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Collecting iopath\n",
+ " Downloading iopath-0.1.10.tar.gz (42 kB)\n",
+ "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/42.2 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m42.2/42.2 kB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from iopath) (4.66.5)\n",
+ "Requirement already satisfied: typing_extensions in /usr/local/lib/python3.10/dist-packages (from iopath) (4.12.2)\n",
+ "Collecting portalocker (from iopath)\n",
+ " Downloading portalocker-2.10.1-py3-none-any.whl.metadata (8.5 kB)\n",
+ "Downloading portalocker-2.10.1-py3-none-any.whl (18 kB)\n",
+ "Building wheels for collected packages: iopath\n",
+ " Building wheel for iopath (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ " Created wheel for iopath: filename=iopath-0.1.10-py3-none-any.whl size=31528 sha256=37694219b6590dc6469222f52c523d5814a23898e0c80654e1af2e397b2611a6\n",
+ " Stored in directory: /root/.cache/pip/wheels/9a/a3/b6/ac0fcd1b4ed5cfeb3db92e6a0e476cfd48ed0df92b91080c1d\n",
+ "Successfully built iopath\n",
+ "Installing collected packages: portalocker, iopath\n",
+ "Successfully installed iopath-0.1.10 portalocker-2.10.1\n",
+ "Trying to install wheel for PyTorch3D\n",
+ "Looking in links: https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt241/download.html\n",
+ "Collecting pytorch3d\n",
+ " Downloading https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt241/pytorch3d-0.7.8-cp310-cp310-linux_x86_64.whl (20.5 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m20.5/20.5 MB\u001b[0m \u001b[31m221.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: iopath in /usr/local/lib/python3.10/dist-packages (from pytorch3d) (0.1.10)\n",
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from iopath->pytorch3d) (4.66.5)\n",
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from iopath->pytorch3d) (4.12.2)\n",
+ "Requirement already satisfied: portalocker in /usr/local/lib/python3.10/dist-packages (from iopath->pytorch3d) (2.10.1)\n",
+ "Installing collected packages: pytorch3d\n",
+ "Successfully installed pytorch3d-0.7.8\n"
+ ]
+ }
+ ],
+ "source": [
+ "import os\n",
+ "import sys\n",
+ "import torch\n",
+ "import subprocess\n",
+ "need_pytorch3d=False\n",
+ "try:\n",
+ " import pytorch3d\n",
+ "except ModuleNotFoundError:\n",
+ " need_pytorch3d=True\n",
+ "if need_pytorch3d:\n",
+ " pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
+ " version_str=\"\".join([\n",
+ " f\"py3{sys.version_info.minor}_cu\",\n",
+ " torch.version.cuda.replace(\".\",\"\"),\n",
+ " f\"_pyt{pyt_version_str}\"\n",
+ " ])\n",
+ " !pip install iopath\n",
+ " if sys.platform.startswith(\"linux\"):\n",
+ " print(\"Trying to install wheel for PyTorch3D\")\n",
+ " !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
+ " pip_list = !pip freeze\n",
+ " need_pytorch3d = not any(i.startswith(\"pytorch3d==\") for i in pip_list)\n",
+ " if need_pytorch3d:\n",
+ " print(f\"failed to find/install wheel for {version_str}\")\n",
+ "if need_pytorch3d:\n",
+ " print(\"Installing PyTorch3D from source\")\n",
+ " !pip install ninja\n",
+ " !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "id": "nX99zdoffBLg"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import torch\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "from pytorch3d.utils import ico_sphere\n",
+ "import numpy as np\n",
+ "from tqdm.notebook import tqdm\n",
+ "\n",
+ "# Util function for loading meshes\n",
+ "from pytorch3d.io import load_objs_as_meshes, save_obj\n",
+ "\n",
+ "from pytorch3d.loss import (\n",
+ " chamfer_distance,\n",
+ " mesh_edge_loss,\n",
+ " mesh_laplacian_smoothing,\n",
+ " mesh_normal_consistency,\n",
+ ")\n",
+ "\n",
+ "# Data structures and functions for rendering\n",
+ "from pytorch3d.structures import Meshes\n",
+ "from pytorch3d.renderer import (\n",
+ " look_at_view_transform,\n",
+ " FoVPerspectiveCameras,\n",
+ " PointLights,\n",
+ " DirectionalLights,\n",
+ " Materials,\n",
+ " RasterizationSettings,\n",
+ " MeshRenderer,\n",
+ " MeshRasterizer,\n",
+ " SoftPhongShader,\n",
+ " SoftSilhouetteShader,\n",
+ " TexturesVertex\n",
+ ")\n",
+ "\n",
+ "# add path for demo utils functions\n",
+ "import sys\n",
+ "import os\n",
+ "sys.path.append(os.path.abspath(''))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Lxmehq6Zhrzv"
+ },
+ "source": [
+    "If using **Google Colab**, fetch the utils files for plotting image grids and visualizing cameras, and install `kaleido` for exporting plotly figures:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "HZozr3Pmho-5",
+ "outputId": "e15d3f0b-2543-47b3-ec64-bca5cfc8f4a3"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "--2024-10-20 01:05:05-- https://raw.githubusercontent.com/facebookresearch/pytorch3d/main/docs/tutorials/utils/plot_image_grid.py\n",
+ "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n",
+ "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
+ "HTTP request sent, awaiting response... 200 OK\n",
+ "Length: 1608 (1.6K) [text/plain]\n",
+ "Saving to: ‘plot_image_grid.py’\n",
+ "\n",
+ "plot_image_grid.py 100%[===================>] 1.57K --.-KB/s in 0s \n",
+ "\n",
+ "2024-10-20 01:05:05 (43.9 MB/s) - ‘plot_image_grid.py’ saved [1608/1608]\n",
+ "\n",
+ "--2024-10-20 01:05:05-- https://raw.githubusercontent.com/facebookresearch/pytorch3d/main/docs/tutorials/utils/camera_visualization.py\n",
+ "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n",
+ "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
+ "HTTP request sent, awaiting response... 200 OK\n",
+ "Length: 2037 (2.0K) [text/plain]\n",
+ "Saving to: ‘camera_visualization.py’\n",
+ "\n",
+ "camera_visualizatio 100%[===================>] 1.99K --.-KB/s in 0s \n",
+ "\n",
+ "2024-10-20 01:05:05 (39.8 MB/s) - ‘camera_visualization.py’ saved [2037/2037]\n",
+ "\n",
+ "Collecting kaleido\n",
+ " Downloading kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl.metadata (15 kB)\n",
+ "Downloading kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl (79.9 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m79.9/79.9 MB\u001b[0m \u001b[31m25.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hInstalling collected packages: kaleido\n",
+ "Successfully installed kaleido-0.2.1\n"
+ ]
+ }
+ ],
+ "source": [
+ "!wget https://raw.githubusercontent.com/facebookresearch/pytorch3d/main/docs/tutorials/utils/plot_image_grid.py\n",
+ "from plot_image_grid import image_grid\n",
+ "!wget https://raw.githubusercontent.com/facebookresearch/pytorch3d/main/docs/tutorials/utils/camera_visualization.py\n",
+ "from camera_visualization import plot_camera_scene\n",
+    "# kaleido is needed to export plotly figures as static images\n",
+ "import locale\n",
+ "locale.getpreferredencoding = lambda: \"UTF-8\"\n",
+ "!pip install -U kaleido"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "g4B62MzYiJUM"
+ },
+ "source": [
+    "OR, if running **locally**, uncomment and run the following cell:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "paJ4Im8ahl7O"
+ },
+ "outputs": [],
+ "source": [
+    "# from utils.plot_image_grid import image_grid\n",
+    "# from utils.camera_visualization import plot_camera_scene"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "collapsed": true,
+ "id": "5jGq772XfBLk"
+ },
+ "source": [
+    "## 1. Load a mesh and texture file\n",
+ "\n",
+ "Load an `.obj` file and its associated `.mtl` file and create a **Textures** and **Meshes** object.\n",
+ "\n",
+    "**Meshes** is a unique data structure provided in PyTorch3D for working with batches of meshes of different sizes.\n",
+    "\n",
+    "**TexturesVertex** is an auxiliary data structure for storing per-vertex RGB texture information for meshes.\n",
+ "\n",
+ "**Meshes** has several class methods which are used throughout the rendering pipeline."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "a8eU4zo5jd_H"
+ },
+ "source": [
+    "If running this notebook using **Google Colab**, run the following cell to fetch the mesh `.obj` and texture files and save them at the path `data/cow_mesh`.\n",
+    "If running locally, the data is already available at the correct path."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "tTm0cVuOjb1W",
+ "outputId": "c3c3b8f7-9d79-4a07-cf04-74e96bf71142"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "--2024-10-20 01:06:57-- https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.obj\n",
+ "Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 108.157.254.121, 108.157.254.15, 108.157.254.102, ...\n",
+ "Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|108.157.254.121|:443... connected.\n",
+ "HTTP request sent, awaiting response... 200 OK\n",
+ "Length: 330659 (323K) [text/plain]\n",
+ "Saving to: ‘data/cow_mesh/cow.obj’\n",
+ "\n",
+ "\rcow.obj 0%[ ] 0 --.-KB/s \rcow.obj 100%[===================>] 322.91K --.-KB/s in 0.01s \n",
+ "\n",
+ "2024-10-20 01:06:57 (28.7 MB/s) - ‘data/cow_mesh/cow.obj’ saved [330659/330659]\n",
+ "\n",
+ "--2024-10-20 01:06:57-- https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.mtl\n",
+ "Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 108.157.254.121, 108.157.254.15, 108.157.254.102, ...\n",
+ "Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|108.157.254.121|:443... connected.\n",
+ "HTTP request sent, awaiting response... 200 OK\n",
+ "Length: 155 [text/plain]\n",
+ "Saving to: ‘data/cow_mesh/cow.mtl’\n",
+ "\n",
+ "cow.mtl 100%[===================>] 155 --.-KB/s in 0s \n",
+ "\n",
+ "2024-10-20 01:06:57 (352 KB/s) - ‘data/cow_mesh/cow.mtl’ saved [155/155]\n",
+ "\n",
+ "--2024-10-20 01:06:57-- https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow_texture.png\n",
+ "Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 108.157.254.121, 108.157.254.15, 108.157.254.102, ...\n",
+ "Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|108.157.254.121|:443... connected.\n",
+ "HTTP request sent, awaiting response... 200 OK\n",
+ "Length: 78699 (77K) [image/png]\n",
+ "Saving to: ‘data/cow_mesh/cow_texture.png’\n",
+ "\n",
+ "cow_texture.png 100%[===================>] 76.85K --.-KB/s in 0.004s \n",
+ "\n",
+ "2024-10-20 01:06:57 (19.5 MB/s) - ‘data/cow_mesh/cow_texture.png’ saved [78699/78699]\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "!mkdir -p data/cow_mesh\n",
+ "!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.obj\n",
+ "!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.mtl\n",
+ "!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow_texture.png"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "id": "gi5Kd0GafBLl"
+ },
+ "outputs": [],
+ "source": [
+ "# Setup\n",
+ "if torch.cuda.is_available():\n",
+ " device = torch.device(\"cuda:0\")\n",
+ " torch.cuda.set_device(device)\n",
+ "else:\n",
+ " device = torch.device(\"cpu\")\n",
+ "\n",
+ "# Set paths\n",
+ "DATA_DIR = \"./data\"\n",
+ "obj_filename = os.path.join(DATA_DIR, \"cow_mesh/cow.obj\")\n",
+ "\n",
+ "# Load obj file\n",
+ "mesh = load_objs_as_meshes([obj_filename], device=device)\n",
+ "\n",
+    "# We scale-normalize and center the target mesh to fit in a sphere of radius 1\n",
+    "# centered at (0,0,0). (scale, center) will be used to bring the predicted mesh\n",
+    "# back to its original center and scale. Note that normalizing the target mesh\n",
+    "# speeds up the optimization but is not necessary!\n",
+ "verts = mesh.verts_packed()\n",
+ "N = verts.shape[0]\n",
+ "center = verts.mean(0)\n",
+ "scale = max((verts - center).abs().max(0)[0])\n",
+ "mesh.offset_verts_(-center)\n",
+ "mesh.scale_verts_((1.0 / float(scale)));"
+ ]
+ },
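+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As an optional sanity check (a minimal sketch, not required for the rest of the tutorial), we can inspect the loaded **Meshes** object using a few of the class methods mentioned above:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional sanity check: inspect the Meshes object loaded above.\n",
+    "# verts_packed()/faces_packed() concatenate all meshes in the batch into single tensors.\n",
+    "print(\"vertices:\", mesh.verts_packed().shape)  # (V, 3)\n",
+    "print(\"faces:\", mesh.faces_packed().shape)  # (F, 3)\n",
+    "print(\"textures:\", type(mesh.textures).__name__)  # TexturesUV, since the .obj has a texture map"
+   ]
+  },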
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "17c4xmtyfBMH"
+ },
+ "source": [
+ "## 2. Dataset Creation\n",
+ "\n",
+    "We sample camera positions that encode multiple viewpoints of the cow, create a renderer with a shader that performs texture-map interpolation, and then render a synthetic dataset of images of the textured cow mesh from those viewpoints.\n",
+    "For simplicity, all cameras point towards the mesh."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "id": "CDQKebNNfBMI"
+ },
+ "outputs": [],
+ "source": [
+ "# Number of different viewpoints\n",
+ "num_views = 8\n",
+ "\n",
+ "# Image size\n",
+ "img_size = 128\n",
+ "\n",
+    "# Construct viewing angles from the vertices of an 8-vertex Platonic solid, i.e. a cube\n",
+ "# Vertices of a cube in Cartesian coordinates\n",
+ "vertices = torch.tensor([\n",
+ " [ 1, 1, 1],\n",
+ " [ 1, 1, -1],\n",
+ " [ 1, -1, 1],\n",
+ " [ 1, -1, -1],\n",
+ " [-1, 1, 1],\n",
+ " [-1, 1, -1],\n",
+ " [-1, -1, 1],\n",
+ " [-1, -1, -1]\n",
+ "], dtype=torch.float32)\n",
+ "\n",
+ "# Calculate the radial distance (r) for normalization\n",
+ "r = torch.norm(vertices, dim=1)\n",
+ "\n",
+ "# Calculate elevation (theta)\n",
+ "elev = torch.atan2(vertices[:, 2], torch.norm(vertices[:, :2], dim=1))\n",
+ "\n",
+ "# Calculate azimuth (phi)\n",
+ "azim = torch.atan2(vertices[:, 1], vertices[:, 0])\n",
+ "\n",
+ "# Point light\n",
+ "lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])\n",
+ "\n",
+ "# Initialize cameras\n",
+ "R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim, degrees=False)\n",
+ "cameras = FoVPerspectiveCameras(device=device, R=R, T=T)\n",
+ "view_idx = 6\n",
+ "camera = FoVPerspectiveCameras(device=device, R=R[None, view_idx, ...], T=T[None, view_idx, ...])\n",
+ "\n",
+ "\n",
+ "# Rasterization and shading settings for standard renderer\n",
+    "# Here we set the output image to be of size img_size x img_size. As we are rendering\n",
+    "# the ground-truth target images here, we set faces_per_pixel=1 and blur_radius=0.0.\n",
+    "# We leave bin_size and max_faces_per_bin at their default of None, which ensures that\n",
+    "# the faster coarse-to-fine rasterization method is used. Refer to rasterize_meshes.py for\n",
+    "# explanations of these parameters, and to docs/notes/renderer.md for an explanation of\n",
+    "# the difference between naive and coarse-to-fine rasterization.\n",
+ "raster_settings = RasterizationSettings(\n",
+ " image_size=img_size,\n",
+ " blur_radius=0.0,\n",
+ " faces_per_pixel=1,\n",
+ ")\n",
+ "\n",
+ "# Rasterization and shading settings for silhouette renderer\n",
+ "raster_settings_silhouette = RasterizationSettings(\n",
+ " image_size=img_size,\n",
+ " blur_radius=0.0,\n",
+ " faces_per_pixel=1,\n",
+ ")\n",
+ "\n",
+ "# Create standard renderer for creating ground truth rgb images\n",
+ "renderer_std = MeshRenderer(\n",
+ " rasterizer=MeshRasterizer(\n",
+ " cameras=cameras,\n",
+ " raster_settings=raster_settings\n",
+ " ),\n",
+ " shader=SoftPhongShader(\n",
+ " device=device,\n",
+ " cameras=cameras,\n",
+ " lights=lights\n",
+ " )\n",
+ ")\n",
+ "\n",
+ "# Create Silhouette renderer for creating ground truth silhouette images\n",
+ "renderer_silhouette = MeshRenderer(\n",
+ " rasterizer=MeshRasterizer(\n",
+ " cameras=camera,\n",
+ " raster_settings=raster_settings_silhouette\n",
+ " ),\n",
+ " shader=SoftSilhouetteShader()\n",
+ ")\n",
+ "\n",
+ "\n",
+ "# Create target images\n",
+ "\n",
+ "meshes = mesh.extend(num_views)\n",
+ "gt_rgb = renderer_std(meshes, cameras=cameras, lights=lights)\n",
+    "# Render silhouette images. The alpha channel (index 3) of the rendering output\n",
+    "# is the silhouette channel.\n",
+ "gt_silhouette = renderer_silhouette(meshes, cameras=cameras, lights=lights)\n",
+ "gt_camera_pts = cameras.get_camera_center()"
+ ]
+ },
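+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Both renderers return RGBA images of shape `(num_views, H, W, 4)`. As an optional check (a minimal sketch, assuming the tensors created above), we can verify the shapes and pull the silhouette masks out of the alpha channel:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional check: both renderers output RGBA images of shape (num_views, H, W, 4).\n",
+    "print(gt_rgb.shape, gt_silhouette.shape)\n",
+    "# The alpha channel (index 3) of the silhouette rendering is the per-view object mask.\n",
+    "silhouette_masks = gt_silhouette[..., 3]  # (num_views, H, W)\n",
+    "print(silhouette_masks.shape)"
+   ]
+  },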
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "TppB4PVmR1Rc"
+ },
+ "source": [
+ "Visualize the dataset:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 520
+ },
+ "id": "HHE0CnbVR1Rd",
+ "collapsed": true,
+ "outputId": "051283eb-989a-4b29-b9e0-45d51ee6fde5"
+ },
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "