From 950c1664bd8c02f17fd20d9e4361a86b96966a57 Mon Sep 17 00:00:00 2001 From: jadevaibhav <25821637+jadevaibhav@users.noreply.github.com> Date: Tue, 23 Jul 2024 16:51:04 -0400 Subject: [PATCH 1/6] Modified camera data module --- configs/dreamfusion-sd-eff.yaml | 115 +++++++ threestudio/data/uncond_eff.py | 543 ++++++++++++++++++++++++++++++++ threestudio/utils/ops.py | 21 ++ 3 files changed, 679 insertions(+) create mode 100644 configs/dreamfusion-sd-eff.yaml create mode 100644 threestudio/data/uncond_eff.py diff --git a/configs/dreamfusion-sd-eff.yaml b/configs/dreamfusion-sd-eff.yaml new file mode 100644 index 00000000..06e7d6a1 --- /dev/null +++ b/configs/dreamfusion-sd-eff.yaml @@ -0,0 +1,115 @@ +name: "dreamfusion-sd" +tag: "${rmspace:${system.prompt_processor.prompt},_}" +exp_root_dir: "outputs" +seed: 0 + +data_type: "eff-random-camera-datamodule" +data: + batch_size: 1 + width: 128 + height: 128 + sample_width: 64 + sample_height: 64 + camera_distance_range: [1.5, 2.0] + fovy_range: [40, 70] + elevation_range: [-10, 45] + light_sample_strategy: "dreamfusion" + eval_camera_distance: 2.0 + eval_fovy_deg: 70. + +system_type: "dreamfusion-system" +system: + geometry_type: "implicit-volume" + geometry: + radius: 2.0 + normal_type: "analytic" + + # the density initialization proposed in the DreamFusion paper + # does not work very well + # density_bias: "blob_dreamfusion" + # density_activation: exp + # density_blob_scale: 5. + # density_blob_std: 0.2 + + # use Magic3D density initialization instead + density_bias: "blob_magic3d" + density_activation: softplus + density_blob_scale: 10. + density_blob_std: 0.5 + + # coarse to fine hash grid encoding + # to ensure smooth analytic normals + pos_encoding_config: + otype: ProgressiveBandHashGrid + n_levels: 16 + n_features_per_level: 2 + log2_hashmap_size: 19 + base_resolution: 16 + per_level_scale: 1.447269237440378 # max resolution 4096 + start_level: 8 # resolution ~200 + start_step: 2000 + update_steps: 500 + + material_type: "diffuse-with-point-light-material" + material: + ambient_only_steps: 2001 + albedo_activation: sigmoid + + background_type: "neural-environment-map-background" + background: + color_activation: sigmoid + + renderer_type: "nerf-volume-renderer" + renderer: + radius: ${system.geometry.radius} + num_samples_per_ray: 512 + + prompt_processor_type: "stable-diffusion-prompt-processor" + prompt_processor: + pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base" + prompt: ??? + + guidance_type: "stable-diffusion-guidance" + guidance: + pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base" + guidance_scale: 100. + weighting_strategy: sds + min_step_percent: 0.02 + max_step_percent: 0.98 + + loggers: + wandb: + enable: false + project: "threestudio" + name: None + + loss: + lambda_sds: 1. + lambda_orient: [0, 10., 1000., 5000] + lambda_sparsity: 1. + lambda_opaque: 0. + + optimizer: + name: Adam + args: + lr: 0.01 + betas: [0.9, 0.99] + eps: 1.e-15 + params: + geometry: + lr: 0.01 + background: + lr: 0.001 + +trainer: + max_steps: 10000 + log_every_n_steps: 1 + num_sanity_val_steps: 0 + val_check_interval: 200 + enable_progress_bar: true + precision: 16-mixed + +checkpoint: + save_last: true # save at each validation time + save_top_k: -1 + every_n_train_steps: ${trainer.max_steps} diff --git a/threestudio/data/uncond_eff.py b/threestudio/data/uncond_eff.py new file mode 100644 index 00000000..18af81c9 --- /dev/null +++ b/threestudio/data/uncond_eff.py @@ -0,0 +1,543 @@ +import bisect +import math +import random +from dataclasses import dataclass, field + +import numpy as np +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader, Dataset, IterableDataset + +import threestudio +from threestudio import register +from threestudio.utils.base import Updateable +from threestudio.utils.config import parse_structured +from threestudio.utils.misc import get_device +from threestudio.utils.ops import ( + get_full_projection_matrix, + get_mvp_matrix, + get_projection_matrix, + get_ray_directions, + get_rays, + mask_ray_directions +) +from threestudio.utils.typing import * + + +@dataclass +class EffRandomCameraDataModuleConfig: + # height, width, and batch_size should be Union[int, List[int]] + # but OmegaConf does not support Union of containers + height: Any = 128 + width: Any = 128 + sample_height: Any = 64 + sample_width: Any = 64 + batch_size: Any = 1 + resolution_milestones: List[int] = field(default_factory=lambda: []) + eval_height: int = 512 + eval_width: int = 512 + eval_batch_size: int = 1 + n_val_views: int = 1 + n_test_views: int = 120 + elevation_range: Tuple[float, float] = (-10, 90) + azimuth_range: Tuple[float, float] = (-180, 180) + camera_distance_range: Tuple[float, float] = (1, 1.5) + fovy_range: Tuple[float, float] = ( + 40, + 70, + ) # in degrees, in vertical direction (along height) + camera_perturb: float = 0.1 + center_perturb: float = 0.2 + up_perturb: float = 0.02 + light_position_perturb: float = 1.0 + light_distance_range: Tuple[float, float] = (0.8, 1.5) + eval_elevation_deg: float = 15.0 + eval_camera_distance: float = 1.5 + eval_fovy_deg: float = 70.0 + light_sample_strategy: str = "dreamfusion" + batch_uniform_azimuth: bool = True + progressive_until: int = 0 # progressive ranges for elevation, azimuth, r, fovy + + rays_d_normalize: bool = True + + +class EffRandomCameraIterableDataset(IterableDataset, Updateable): + def __init__(self, cfg: Any) -> None: + super().__init__() + self.cfg: EffRandomCameraDataModuleConfig = cfg + self.heights: List[int] = ( + [self.cfg.height] if isinstance(self.cfg.height, int) else self.cfg.height + ) + self.widths: List[int] = ( + [self.cfg.width] if isinstance(self.cfg.width, int) else self.cfg.width + ) + self.sample_heights: List[int] = ( + [self.cfg.sample_height] if isinstance(self.cfg.sample_height, int) else self.cfg.sample_height + ) + self.sample_widths: List[int] = ( + [self.cfg.sample_width] if isinstance(self.cfg.sample_width, int) else self.cfg.sample_width + ) + self.batch_sizes: List[int] = ( + [self.cfg.batch_size] + if isinstance(self.cfg.batch_size, int) + else self.cfg.batch_size + ) + assert len(self.heights) == len(self.widths) == len(self.batch_sizes) == len(self.sample_heights) == len(self.sample_widths) + self.resolution_milestones: List[int] + if ( + len(self.heights) == 1 + and len(self.widths) == 1 + and len(self.batch_sizes) == 1 + and len(self.sample_heights) == 1 + and len(self.sample_widths) == 1 + ): + if len(self.cfg.resolution_milestones) > 0: + threestudio.warn( + "Ignoring resolution_milestones since height and width are not changing" + ) + self.resolution_milestones = [-1] + else: + assert len(self.heights) == len(self.cfg.resolution_milestones) + 1 + self.resolution_milestones = [-1] + self.cfg.resolution_milestones + + self.directions_unit_focals = [ + get_ray_directions(H=height, W=width, focal=1.0) + for (height, width) in zip(self.heights, self.widths) + ] + dirs_and_masks = [ + (mask_ray_directions(dir,H,W,s_H,s_W)) for (dir,H,W,s_H,s_W) + in zip(self.directions_unit_focals, self.heights, + self.widths, self.sample_heights, self.sample_widths) + ] + self.directions_unit_focals = [dir for (dir,mask) in dirs_and_masks] + self.efficiency_masks = [mask for (dir,mask)in dirs_and_masks] + self.height: int = self.heights[0] + self.width: int = self.widths[0] + self.sample_height: int = self.sample_heights[0] + self.sample_width: int = self.sample_widths[0] + self.batch_size: int = self.batch_sizes[0] + self.directions_unit_focal = self.directions_unit_focals[0] + self.efficiency_mask = self.efficiency_masks[0] + self.elevation_range = self.cfg.elevation_range + self.azimuth_range = self.cfg.azimuth_range + self.camera_distance_range = self.cfg.camera_distance_range + self.fovy_range = self.cfg.fovy_range + + def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): + size_ind = bisect.bisect_right(self.resolution_milestones, global_step) - 1 + self.height = self.heights[size_ind] + self.width = self.widths[size_ind] + self.sample_height = self.sample_heights[size_ind] + self.sample_width = self.sample_widths[size_ind] + self.batch_size = self.batch_sizes[size_ind] + self.directions_unit_focal = self.directions_unit_focals[size_ind] + self.efficiency_mask = self.efficiency_masks[size_ind] + threestudio.debug( + f"Training height: {self.height}, width: {self.width}, batch_size: {self.batch_size}" + ) + # progressive view + self.progressive_view(global_step) + + def __iter__(self): + while True: + yield {} + + def progressive_view(self, global_step): + r = min(1.0, global_step / (self.cfg.progressive_until + 1)) + self.elevation_range = [ + (1 - r) * self.cfg.eval_elevation_deg + r * self.cfg.elevation_range[0], + (1 - r) * self.cfg.eval_elevation_deg + r * self.cfg.elevation_range[1], + ] + self.azimuth_range = [ + (1 - r) * 0.0 + r * self.cfg.azimuth_range[0], + (1 - r) * 0.0 + r * self.cfg.azimuth_range[1], + ] + # self.camera_distance_range = [ + # (1 - r) * self.cfg.eval_camera_distance + # + r * self.cfg.camera_distance_range[0], + # (1 - r) * self.cfg.eval_camera_distance + # + r * self.cfg.camera_distance_range[1], + # ] + # self.fovy_range = [ + # (1 - r) * self.cfg.eval_fovy_deg + r * self.cfg.fovy_range[0], + # (1 - r) * self.cfg.eval_fovy_deg + r * self.cfg.fovy_range[1], + # ] + + def collate(self, batch) -> Dict[str, Any]: + # sample elevation angles + elevation_deg: Float[Tensor, "B"] + elevation: Float[Tensor, "B"] + if random.random() < 0.5: + # sample elevation angles uniformly with a probability 0.5 (biased towards poles) + elevation_deg = ( + torch.rand(self.batch_size) + * (self.elevation_range[1] - self.elevation_range[0]) + + self.elevation_range[0] + ) + elevation = elevation_deg * math.pi / 180 + else: + # otherwise sample uniformly on sphere + elevation_range_percent = [ + self.elevation_range[0] / 180.0 * math.pi, + self.elevation_range[1] / 180.0 * math.pi, + ] + # inverse transform sampling + elevation = torch.asin( + ( + torch.rand(self.batch_size) + * ( + math.sin(elevation_range_percent[1]) + - math.sin(elevation_range_percent[0]) + ) + + math.sin(elevation_range_percent[0]) + ) + ) + elevation_deg = elevation / math.pi * 180.0 + + # sample azimuth angles from a uniform distribution bounded by azimuth_range + azimuth_deg: Float[Tensor, "B"] + if self.cfg.batch_uniform_azimuth: + # ensures sampled azimuth angles in a batch cover the whole range + azimuth_deg = ( + torch.rand(self.batch_size) + torch.arange(self.batch_size) + ) / self.batch_size * ( + self.azimuth_range[1] - self.azimuth_range[0] + ) + self.azimuth_range[ + 0 + ] + else: + # simple random sampling + azimuth_deg = ( + torch.rand(self.batch_size) + * (self.azimuth_range[1] - self.azimuth_range[0]) + + self.azimuth_range[0] + ) + azimuth = azimuth_deg * math.pi / 180 + + # sample distances from a uniform distribution bounded by distance_range + camera_distances: Float[Tensor, "B"] = ( + torch.rand(self.batch_size) + * (self.camera_distance_range[1] - self.camera_distance_range[0]) + + self.camera_distance_range[0] + ) + + # convert spherical coordinates to cartesian coordinates + # right hand coordinate system, x back, y right, z up + # elevation in (-90, 90), azimuth from +x to +y in (-180, 180) + camera_positions: Float[Tensor, "B 3"] = torch.stack( + [ + camera_distances * torch.cos(elevation) * torch.cos(azimuth), + camera_distances * torch.cos(elevation) * torch.sin(azimuth), + camera_distances * torch.sin(elevation), + ], + dim=-1, + ) + + # default scene center at origin + center: Float[Tensor, "B 3"] = torch.zeros_like(camera_positions) + # default camera up direction as +z + up: Float[Tensor, "B 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[ + None, : + ].repeat(self.batch_size, 1) + + # sample camera perturbations from a uniform distribution [-camera_perturb, camera_perturb] + camera_perturb: Float[Tensor, "B 3"] = ( + torch.rand(self.batch_size, 3) * 2 * self.cfg.camera_perturb + - self.cfg.camera_perturb + ) + camera_positions = camera_positions + camera_perturb + # sample center perturbations from a normal distribution with mean 0 and std center_perturb + center_perturb: Float[Tensor, "B 3"] = ( + torch.randn(self.batch_size, 3) * self.cfg.center_perturb + ) + center = center + center_perturb + # sample up perturbations from a normal distribution with mean 0 and std up_perturb + up_perturb: Float[Tensor, "B 3"] = ( + torch.randn(self.batch_size, 3) * self.cfg.up_perturb + ) + up = up + up_perturb + + # sample fovs from a uniform distribution bounded by fov_range + fovy_deg: Float[Tensor, "B"] = ( + torch.rand(self.batch_size) * (self.fovy_range[1] - self.fovy_range[0]) + + self.fovy_range[0] + ) + fovy = fovy_deg * math.pi / 180 + + # sample light distance from a uniform distribution bounded by light_distance_range + light_distances: Float[Tensor, "B"] = ( + torch.rand(self.batch_size) + * (self.cfg.light_distance_range[1] - self.cfg.light_distance_range[0]) + + self.cfg.light_distance_range[0] + ) + + if self.cfg.light_sample_strategy == "dreamfusion": + # sample light direction from a normal distribution with mean camera_position and std light_position_perturb + light_direction: Float[Tensor, "B 3"] = F.normalize( + camera_positions + + torch.randn(self.batch_size, 3) * self.cfg.light_position_perturb, + dim=-1, + ) + # get light position by scaling light direction by light distance + light_positions: Float[Tensor, "B 3"] = ( + light_direction * light_distances[:, None] + ) + elif self.cfg.light_sample_strategy == "magic3d": + # sample light direction within restricted angle range (pi/3) + local_z = F.normalize(camera_positions, dim=-1) + local_x = F.normalize( + torch.stack( + [local_z[:, 1], -local_z[:, 0], torch.zeros_like(local_z[:, 0])], + dim=-1, + ), + dim=-1, + ) + local_y = F.normalize(torch.cross(local_z, local_x, dim=-1), dim=-1) + rot = torch.stack([local_x, local_y, local_z], dim=-1) + light_azimuth = ( + torch.rand(self.batch_size) * math.pi * 2 - math.pi + ) # [-pi, pi] + light_elevation = ( + torch.rand(self.batch_size) * math.pi / 3 + math.pi / 6 + ) # [pi/6, pi/2] + light_positions_local = torch.stack( + [ + light_distances + * torch.cos(light_elevation) + * torch.cos(light_azimuth), + light_distances + * torch.cos(light_elevation) + * torch.sin(light_azimuth), + light_distances * torch.sin(light_elevation), + ], + dim=-1, + ) + light_positions = (rot @ light_positions_local[:, :, None])[:, :, 0] + else: + raise ValueError( + f"Unknown light sample strategy: {self.cfg.light_sample_strategy}" + ) + + lookat: Float[Tensor, "B 3"] = F.normalize(center - camera_positions, dim=-1) + right: Float[Tensor, "B 3"] = F.normalize(torch.cross(lookat, up), dim=-1) + up = F.normalize(torch.cross(right, lookat), dim=-1) + c2w3x4: Float[Tensor, "B 3 4"] = torch.cat( + [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]], + dim=-1, + ) + c2w: Float[Tensor, "B 4 4"] = torch.cat( + [c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1 + ) + c2w[:, 3, 3] = 1.0 + + # get directions by dividing directions_unit_focal by focal length + focal_length: Float[Tensor, "B"] = 0.5 * self.height / torch.tan(0.5 * fovy) + directions: Float[Tensor, "B H W 3"] = self.directions_unit_focal[ + None, :, :, : + ].repeat(self.batch_size, 1, 1, 1) + directions[:, :, :, :2] = ( + directions[:, :, :, :2] / focal_length[:, None, None, None] + ) + + # Importance note: the returned rays_d MUST be normalized! + rays_o, rays_d = get_rays( + directions, c2w, keepdim=True, normalize=self.cfg.rays_d_normalize + ) + + self.proj_mtx: Float[Tensor, "B 4 4"] = get_projection_matrix( + fovy, self.width / self.height, 0.01, 100.0 + ) # FIXME: hard-coded near and far + mvp_mtx: Float[Tensor, "B 4 4"] = get_mvp_matrix(c2w, self.proj_mtx) + self.fovy = fovy + + return { + "rays_o": rays_o, + "rays_d": rays_d, + "efficiency_mask":self.efficiency_mask, + "mvp_mtx": mvp_mtx, + "camera_positions": camera_positions, + "c2w": c2w, + "light_positions": light_positions, + "elevation": elevation_deg, + "azimuth": azimuth_deg, + "camera_distances": camera_distances, + "height": self.height, + "width": self.width, + "fovy": self.fovy, + "proj_mtx": self.proj_mtx, + } + +### No changes here as this class is used in Validation/test +class RandomCameraDataset(Dataset): + def __init__(self, cfg: Any, split: str) -> None: + super().__init__() + self.cfg: EffRandomCameraDataModuleConfig = cfg + self.split = split + + if split == "val": + self.n_views = self.cfg.n_val_views + else: + self.n_views = self.cfg.n_test_views + + azimuth_deg: Float[Tensor, "B"] + if self.split == "val": + # make sure the first and last view are not the same + azimuth_deg = torch.linspace(0, 360.0, self.n_views + 1)[: self.n_views] + else: + azimuth_deg = torch.linspace(0, 360.0, self.n_views) + elevation_deg: Float[Tensor, "B"] = torch.full_like( + azimuth_deg, self.cfg.eval_elevation_deg + ) + camera_distances: Float[Tensor, "B"] = torch.full_like( + elevation_deg, self.cfg.eval_camera_distance + ) + + elevation = elevation_deg * math.pi / 180 + azimuth = azimuth_deg * math.pi / 180 + + # convert spherical coordinates to cartesian coordinates + # right hand coordinate system, x back, y right, z up + # elevation in (-90, 90), azimuth from +x to +y in (-180, 180) + camera_positions: Float[Tensor, "B 3"] = torch.stack( + [ + camera_distances * torch.cos(elevation) * torch.cos(azimuth), + camera_distances * torch.cos(elevation) * torch.sin(azimuth), + camera_distances * torch.sin(elevation), + ], + dim=-1, + ) + + # default scene center at origin + center: Float[Tensor, "B 3"] = torch.zeros_like(camera_positions) + # default camera up direction as +z + up: Float[Tensor, "B 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[ + None, : + ].repeat(self.cfg.eval_batch_size, 1) + + fovy_deg: Float[Tensor, "B"] = torch.full_like( + elevation_deg, self.cfg.eval_fovy_deg + ) + fovy = fovy_deg * math.pi / 180 + light_positions: Float[Tensor, "B 3"] = camera_positions + + lookat: Float[Tensor, "B 3"] = F.normalize(center - camera_positions, dim=-1) + right: Float[Tensor, "B 3"] = F.normalize(torch.cross(lookat, up), dim=-1) + up = F.normalize(torch.cross(right, lookat), dim=-1) + c2w3x4: Float[Tensor, "B 3 4"] = torch.cat( + [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]], + dim=-1, + ) + c2w: Float[Tensor, "B 4 4"] = torch.cat( + [c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1 + ) + c2w[:, 3, 3] = 1.0 + + # get directions by dividing directions_unit_focal by focal length + focal_length: Float[Tensor, "B"] = ( + 0.5 * self.cfg.eval_height / torch.tan(0.5 * fovy) + ) + directions_unit_focal = get_ray_directions( + H=self.cfg.eval_height, W=self.cfg.eval_width, focal=1.0 + ) + directions: Float[Tensor, "B H W 3"] = directions_unit_focal[ + None, :, :, : + ].repeat(self.n_views, 1, 1, 1) + directions[:, :, :, :2] = ( + directions[:, :, :, :2] / focal_length[:, None, None, None] + ) + + rays_o, rays_d = get_rays( + directions, c2w, keepdim=True, normalize=self.cfg.rays_d_normalize + ) + self.proj_mtx: Float[Tensor, "B 4 4"] = get_projection_matrix( + fovy, self.cfg.eval_width / self.cfg.eval_height, 0.01, 100.0 + ) # FIXME: hard-coded near and far + mvp_mtx: Float[Tensor, "B 4 4"] = get_mvp_matrix(c2w, self.proj_mtx) + + self.rays_o, self.rays_d = rays_o, rays_d + self.mvp_mtx = mvp_mtx + self.c2w = c2w + self.camera_positions = camera_positions + self.light_positions = light_positions + self.elevation, self.azimuth = elevation, azimuth + self.elevation_deg, self.azimuth_deg = elevation_deg, azimuth_deg + self.camera_distances = camera_distances + self.fovy = fovy + + def __len__(self): + return self.n_views + + def __getitem__(self, index): + return { + "index": index, + "rays_o": self.rays_o[index], + "rays_d": self.rays_d[index], + "mvp_mtx": self.mvp_mtx[index], + "c2w": self.c2w[index], + "camera_positions": self.camera_positions[index], + "light_positions": self.light_positions[index], + "elevation": self.elevation_deg[index], + "azimuth": self.azimuth_deg[index], + "camera_distances": self.camera_distances[index], + "height": self.cfg.eval_height, + "width": self.cfg.eval_width, + "fovy": self.fovy[index], + "proj_mtx": self.proj_mtx[index], + } + + def collate(self, batch): + batch = torch.utils.data.default_collate(batch) + batch.update({"height": self.cfg.eval_height, "width": self.cfg.eval_width}) + return batch + + +@register("random-camera-datamodule") +class RandomCameraDataModule(pl.LightningDataModule): + cfg: EffRandomCameraDataModuleConfig + + def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None: + super().__init__() + self.cfg = parse_structured(EffRandomCameraDataModuleConfig, cfg) + + def setup(self, stage=None) -> None: + if stage in [None, "fit"]: + self.train_dataset = EffRandomCameraIterableDataset(self.cfg) + if stage in [None, "fit", "validate"]: + self.val_dataset = RandomCameraDataset(self.cfg, "val") + if stage in [None, "test", "predict"]: + self.test_dataset = RandomCameraDataset(self.cfg, "test") + + def prepare_data(self): + pass + + def general_loader(self, dataset, batch_size, collate_fn=None) -> DataLoader: + return DataLoader( + dataset, + # very important to disable multi-processing if you want to change self attributes at runtime! + # (for example setting self.width and self.height in update_step) + num_workers=0, # type: ignore + batch_size=batch_size, + collate_fn=collate_fn, + ) + + def train_dataloader(self) -> DataLoader: + return self.general_loader( + self.train_dataset, batch_size=None, collate_fn=self.train_dataset.collate + ) + + def val_dataloader(self) -> DataLoader: + return self.general_loader( + self.val_dataset, batch_size=1, collate_fn=self.val_dataset.collate + ) + # return self.general_loader(self.train_dataset, batch_size=None, collate_fn=self.train_dataset.collate) + + def test_dataloader(self) -> DataLoader: + return self.general_loader( + self.test_dataset, batch_size=1, collate_fn=self.test_dataset.collate + ) + + def predict_dataloader(self) -> DataLoader: + return self.general_loader( + self.test_dataset, batch_size=1, collate_fn=self.test_dataset.collate + ) diff --git a/threestudio/utils/ops.py b/threestudio/utils/ops.py index 81d5b599..b03750c8 100644 --- a/threestudio/utils/ops.py +++ b/threestudio/utils/ops.py @@ -216,6 +216,27 @@ def get_ray_directions( return directions +def mask_ray_directions( + directions:Float[Tensor, "H W 3"], + H: int, + W:int, + s_H:int, + s_W:int + ) -> Float[Tensor, "H W 3"]: + """ + Masking the (H,W) image to (s_H,s_W), for efficient training at higher resolution image. + pixels from (s_H,s_W) are sampled more (1-aspect_ratio) than outside pixels. + then apply the mask to ray_directions vector. + """ + mask = torch.zeros(H,W, device= directions.device) + p = (s_H*s_W)/(H*W) + mask += p + mask[(H-s_H)//2 : H - math.ceil((H-s_H)/2),(W-s_W)//2 : W - math.ceil((W-s_W)/2)] = 1 - p + ### mask contains prob of individual pixel, drawing using Bernoulli dist + mask = torch.bernoulli(mask) + directions = directions[mask] + + return directions,mask def get_rays( directions: Float[Tensor, "... 3"], From 7f2f2831caee8754910de79593b2b106ef8f4e69 Mon Sep 17 00:00:00 2001 From: jadevaibhav <25821637+jadevaibhav@users.noreply.github.com> Date: Thu, 25 Jul 2024 10:59:58 -0400 Subject: [PATCH 2/6] modified masking logic and shape adjust in SD input --- configs/dreamfusion-sd-eff.yaml | 2 +- threestudio/data/__init__.py | 2 +- threestudio/data/uncond_eff.py | 150 ++++------------------------- threestudio/systems/dreamfusion.py | 71 ++++++++++++++ threestudio/utils/ops.py | 51 +++++++--- 5 files changed, 130 insertions(+), 146 deletions(-) diff --git a/configs/dreamfusion-sd-eff.yaml b/configs/dreamfusion-sd-eff.yaml index 06e7d6a1..88a23aa6 100644 --- a/configs/dreamfusion-sd-eff.yaml +++ b/configs/dreamfusion-sd-eff.yaml @@ -17,7 +17,7 @@ data: eval_camera_distance: 2.0 eval_fovy_deg: 70. -system_type: "dreamfusion-system" +system_type: "efficient-dreamfusion-system" system: geometry_type: "implicit-volume" geometry: diff --git a/threestudio/data/__init__.py b/threestudio/data/__init__.py index ce2e5cc7..70aaeeb1 100644 --- a/threestudio/data/__init__.py +++ b/threestudio/data/__init__.py @@ -1 +1 @@ -from . import co3d, image, multiview, uncond +from . import co3d, image, multiview, uncond, uncond_eff diff --git a/threestudio/data/uncond_eff.py b/threestudio/data/uncond_eff.py index 18af81c9..274e5e96 100644 --- a/threestudio/data/uncond_eff.py +++ b/threestudio/data/uncond_eff.py @@ -23,7 +23,7 @@ mask_ray_directions ) from threestudio.utils.typing import * - +from threestudio.data.uncond import RandomCameraDataset @dataclass class EffRandomCameraDataModuleConfig: @@ -105,13 +105,18 @@ def __init__(self, cfg: Any) -> None: get_ray_directions(H=height, W=width, focal=1.0) for (height, width) in zip(self.heights, self.widths) ] - dirs_and_masks = [ - (mask_ray_directions(dir,H,W,s_H,s_W)) for (dir,H,W,s_H,s_W) - in zip(self.directions_unit_focals, self.heights, - self.widths, self.sample_heights, self.sample_widths) - ] - self.directions_unit_focals = [dir for (dir,mask) in dirs_and_masks] - self.efficiency_masks = [mask for (dir,mask)in dirs_and_masks] + + self.efficiency_masks = [ + (mask_ray_directions(H,W,s_H,s_W)) for (H,W,s_H,s_W) + in zip( self.heights, self.widths, + self.sample_heights, self.sample_widths)] + self.directions_unit_focals = [ + ( + self.directions_unit_focals[i].reshape(-1,3)[self.efficiency_masks[i]] + ).reshape(self.sample_heights[i],self.sample_widths[i],3) + for i in range(len(self.heights)) + ] + self.height: int = self.heights[0] self.width: int = self.widths[0] self.sample_height: int = self.sample_heights[0] @@ -341,6 +346,7 @@ def collate(self, batch) -> Dict[str, Any]: ) # Importance note: the returned rays_d MUST be normalized! + ### Efficiency masking added here rays_o, rays_d = get_rays( directions, c2w, keepdim=True, normalize=self.cfg.rays_d_normalize ) @@ -364,136 +370,16 @@ def collate(self, batch) -> Dict[str, Any]: "camera_distances": camera_distances, "height": self.height, "width": self.width, + "sample_height": self.sample_height, + "sample_width": self.sample_width, "fovy": self.fovy, "proj_mtx": self.proj_mtx, } -### No changes here as this class is used in Validation/test -class RandomCameraDataset(Dataset): - def __init__(self, cfg: Any, split: str) -> None: - super().__init__() - self.cfg: EffRandomCameraDataModuleConfig = cfg - self.split = split - - if split == "val": - self.n_views = self.cfg.n_val_views - else: - self.n_views = self.cfg.n_test_views - - azimuth_deg: Float[Tensor, "B"] - if self.split == "val": - # make sure the first and last view are not the same - azimuth_deg = torch.linspace(0, 360.0, self.n_views + 1)[: self.n_views] - else: - azimuth_deg = torch.linspace(0, 360.0, self.n_views) - elevation_deg: Float[Tensor, "B"] = torch.full_like( - azimuth_deg, self.cfg.eval_elevation_deg - ) - camera_distances: Float[Tensor, "B"] = torch.full_like( - elevation_deg, self.cfg.eval_camera_distance - ) - - elevation = elevation_deg * math.pi / 180 - azimuth = azimuth_deg * math.pi / 180 - - # convert spherical coordinates to cartesian coordinates - # right hand coordinate system, x back, y right, z up - # elevation in (-90, 90), azimuth from +x to +y in (-180, 180) - camera_positions: Float[Tensor, "B 3"] = torch.stack( - [ - camera_distances * torch.cos(elevation) * torch.cos(azimuth), - camera_distances * torch.cos(elevation) * torch.sin(azimuth), - camera_distances * torch.sin(elevation), - ], - dim=-1, - ) - - # default scene center at origin - center: Float[Tensor, "B 3"] = torch.zeros_like(camera_positions) - # default camera up direction as +z - up: Float[Tensor, "B 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[ - None, : - ].repeat(self.cfg.eval_batch_size, 1) - - fovy_deg: Float[Tensor, "B"] = torch.full_like( - elevation_deg, self.cfg.eval_fovy_deg - ) - fovy = fovy_deg * math.pi / 180 - light_positions: Float[Tensor, "B 3"] = camera_positions - - lookat: Float[Tensor, "B 3"] = F.normalize(center - camera_positions, dim=-1) - right: Float[Tensor, "B 3"] = F.normalize(torch.cross(lookat, up), dim=-1) - up = F.normalize(torch.cross(right, lookat), dim=-1) - c2w3x4: Float[Tensor, "B 3 4"] = torch.cat( - [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]], - dim=-1, - ) - c2w: Float[Tensor, "B 4 4"] = torch.cat( - [c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1 - ) - c2w[:, 3, 3] = 1.0 - - # get directions by dividing directions_unit_focal by focal length - focal_length: Float[Tensor, "B"] = ( - 0.5 * self.cfg.eval_height / torch.tan(0.5 * fovy) - ) - directions_unit_focal = get_ray_directions( - H=self.cfg.eval_height, W=self.cfg.eval_width, focal=1.0 - ) - directions: Float[Tensor, "B H W 3"] = directions_unit_focal[ - None, :, :, : - ].repeat(self.n_views, 1, 1, 1) - directions[:, :, :, :2] = ( - directions[:, :, :, :2] / focal_length[:, None, None, None] - ) - - rays_o, rays_d = get_rays( - directions, c2w, keepdim=True, normalize=self.cfg.rays_d_normalize - ) - self.proj_mtx: Float[Tensor, "B 4 4"] = get_projection_matrix( - fovy, self.cfg.eval_width / self.cfg.eval_height, 0.01, 100.0 - ) # FIXME: hard-coded near and far - mvp_mtx: Float[Tensor, "B 4 4"] = get_mvp_matrix(c2w, self.proj_mtx) - - self.rays_o, self.rays_d = rays_o, rays_d - self.mvp_mtx = mvp_mtx - self.c2w = c2w - self.camera_positions = camera_positions - self.light_positions = light_positions - self.elevation, self.azimuth = elevation, azimuth - self.elevation_deg, self.azimuth_deg = elevation_deg, azimuth_deg - self.camera_distances = camera_distances - self.fovy = fovy - - def __len__(self): - return self.n_views - - def __getitem__(self, index): - return { - "index": index, - "rays_o": self.rays_o[index], - "rays_d": self.rays_d[index], - "mvp_mtx": self.mvp_mtx[index], - "c2w": self.c2w[index], - "camera_positions": self.camera_positions[index], - "light_positions": self.light_positions[index], - "elevation": self.elevation_deg[index], - "azimuth": self.azimuth_deg[index], - "camera_distances": self.camera_distances[index], - "height": self.cfg.eval_height, - "width": self.cfg.eval_width, - "fovy": self.fovy[index], - "proj_mtx": self.proj_mtx[index], - } - - def collate(self, batch): - batch = torch.utils.data.default_collate(batch) - batch.update({"height": self.cfg.eval_height, "width": self.cfg.eval_width}) - return batch -@register("random-camera-datamodule") -class RandomCameraDataModule(pl.LightningDataModule): +@register("eff-random-camera-datamodule") +class EffRandomCameraDataModule(pl.LightningDataModule): cfg: EffRandomCameraDataModuleConfig def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None: diff --git a/threestudio/systems/dreamfusion.py b/threestudio/systems/dreamfusion.py index 4e594b6e..205a6597 100644 --- a/threestudio/systems/dreamfusion.py +++ b/threestudio/systems/dreamfusion.py @@ -160,3 +160,74 @@ def on_test_epoch_end(self): name="test", step=self.true_global_step, ) + + +@threestudio.register("efficient-dreamfusion-system") +class EffDreamFusion(DreamFusion): + @dataclass + class Config(DreamFusion.Config): + pass + + cfg: Config + + def configure(self): + # create geometry, material, background, renderer + super().configure() + + def training_step(self, batch, batch_idx): + out = self(batch) + ### using mask to create image at original resolution during training + (B,s_H,s_W,C) = out["comp_rgb"].shape + mask = batch["efficiency_mask"] + comp_rgb = torch.zeros(B,batch["sample_height"],batch["sample_width"],C,device=mask.device) + comp_rgb[mask] = out["comp_rgb"] + out.update( + { + "comp_rgb": comp_rgb, + } + ) + + prompt_utils = self.prompt_processor() + guidance_out = self.guidance( + out["comp_rgb"], prompt_utils, **batch, rgb_as_latents=False + ) + + loss = 0.0 + + for name, value in guidance_out.items(): + if not (type(value) is torch.Tensor and value.numel() > 1): + self.log(f"train/{name}", value) + if name.startswith("loss_"): + loss += value * self.C(self.cfg.loss[name.replace("loss_", "lambda_")]) + + if self.C(self.cfg.loss.lambda_orient) > 0: + if "normal" not in out: + raise ValueError( + "Normal is required for orientation loss, no normal is found in the output." + ) + loss_orient = ( + out["weights"].detach() + * dot(out["normal"], out["t_dirs"]).clamp_min(0.0) ** 2 + ).sum() / (out["opacity"] > 0).sum() + self.log("train/loss_orient", loss_orient) + loss += loss_orient * self.C(self.cfg.loss.lambda_orient) + + loss_sparsity = (out["opacity"] ** 2 + 0.01).sqrt().mean() + self.log("train/loss_sparsity", loss_sparsity) + loss += loss_sparsity * self.C(self.cfg.loss.lambda_sparsity) + + opacity_clamped = out["opacity"].clamp(1.0e-3, 1.0 - 1.0e-3) + loss_opaque = binary_cross_entropy(opacity_clamped, opacity_clamped) + self.log("train/loss_opaque", loss_opaque) + loss += loss_opaque * self.C(self.cfg.loss.lambda_opaque) + + # z-variance loss proposed in HiFA: https://hifa-team.github.io/HiFA-site/ + if "z_variance" in out and "lambda_z_variance" in self.cfg.loss: + loss_z_variance = out["z_variance"][out["opacity"] > 0.5].mean() + self.log("train/loss_z_variance", loss_z_variance) + loss += loss_z_variance * self.C(self.cfg.loss.lambda_z_variance) + + for name, value in self.cfg.loss.items(): + self.log(f"train_params/{name}", self.C(value)) + + return {"loss": loss} \ No newline at end of file diff --git a/threestudio/utils/ops.py b/threestudio/utils/ops.py index b03750c8..17f2b8b0 100644 --- a/threestudio/utils/ops.py +++ b/threestudio/utils/ops.py @@ -217,26 +217,53 @@ def get_ray_directions( return directions def mask_ray_directions( - directions:Float[Tensor, "H W 3"], H: int, W:int, s_H:int, s_W:int - ) -> Float[Tensor, "H W 3"]: + ) -> Float[Tensor, "s_H s_W"]: """ Masking the (H,W) image to (s_H,s_W), for efficient training at higher resolution image. - pixels from (s_H,s_W) are sampled more (1-aspect_ratio) than outside pixels. - then apply the mask to ray_directions vector. + pixels from (s_H,s_W) are sampled more (1-aspect_ratio) than outside pixels(aspect_ratio). + the masking is deferred to before calling get_rays(). """ - mask = torch.zeros(H,W, device= directions.device) + indices_all = torch.meshgrid( + torch.arange(W, dtype=torch.float32) , + torch.arange(H, dtype=torch.float32) , + indexing="xy", + ) + # indices_inner = torch.meshgrid( + # torch.arange((W-s_W)//2 , W - math.ceil((W-s_W)/2), dtype=torch.float32) , + # torch.arange((H-s_H)//2,H - math.ceil((H-s_H)/2), dtype=torch.float32) , + # indexing="xy", + # ) + mask = torch.zeros(H,W, dtype=torch.bool) + mask[(H-s_H)//2 : H - math.ceil((H-s_H)/2),(W-s_W)//2 : W - math.ceil((W-s_W)/2)] = True + + in_ind_1d = (indices_all[0]+H*indices_all[1])[mask] + out_ind_1d = (indices_all[0]+H*indices_all[1])[torch.logical_not(mask)] p = (s_H*s_W)/(H*W) - mask += p - mask[(H-s_H)//2 : H - math.ceil((H-s_H)/2),(W-s_W)//2 : W - math.ceil((W-s_W)/2)] = 1 - p - ### mask contains prob of individual pixel, drawing using Bernoulli dist - mask = torch.bernoulli(mask) - directions = directions[mask] - - return directions,mask + select_ind = in_ind_1d[ + torch.multinomial( + torch.ones_like(in_ind_1d)*(1-p),int((1-p)*(s_H*s_W)),replacement=False)] + select_ind = torch.concatenate( + [select_ind, out_ind_1d[torch.multinomial( + torch.ones_like(out_ind_1d)*(p),int((p)*(s_H*s_W)),replacement=False)] + ], + dim=0).to(dtype=torch.int).reshape(s_H,s_W) + + ### first attempt at sampling, this produces variable number of rays, + ### so 4D tensor directions cant be sampled + # mask = torch.zeros(H,W, device= directions.device) + # p = (s_H*s_W)/(H*W) + # mask += p + # mask[(H-s_H)//2 : H - math.ceil((H-s_H)/2),(W-s_W)//2 : W - math.ceil((W-s_W)/2)] = 1 - p + # ### mask contains prob of individual pixel, drawing using Bernoulli dist + # mask = torch.bernoulli(mask).to(dtype=torch.bool) + ### postponing masking before get_rays is called + #directions = directions[mask] + + return select_ind def get_rays( directions: Float[Tensor, "... 3"], From ec02948a5da10e5ed78cda433b023497a8819eb3 Mon Sep 17 00:00:00 2001 From: jadevaibhav <25821637+jadevaibhav@users.noreply.github.com> Date: Thu, 25 Jul 2024 13:38:54 -0400 Subject: [PATCH 3/6] (Working)new sampling maskand SD loss edits --- threestudio/data/uncond_eff.py | 4 ++-- threestudio/systems/dreamfusion.py | 6 +++--- threestudio/utils/ops.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/threestudio/data/uncond_eff.py b/threestudio/data/uncond_eff.py index 274e5e96..a69f1bb5 100644 --- a/threestudio/data/uncond_eff.py +++ b/threestudio/data/uncond_eff.py @@ -112,8 +112,8 @@ def __init__(self, cfg: Any) -> None: self.sample_heights, self.sample_widths)] self.directions_unit_focals = [ ( - self.directions_unit_focals[i].reshape(-1,3)[self.efficiency_masks[i]] - ).reshape(self.sample_heights[i],self.sample_widths[i],3) + self.directions_unit_focals[i].view(-1,3)[self.efficiency_masks[i]] + ).view(self.sample_heights[i],self.sample_widths[i],3) for i in range(len(self.heights)) ] diff --git a/threestudio/systems/dreamfusion.py b/threestudio/systems/dreamfusion.py index 205a6597..3bb68d6d 100644 --- a/threestudio/systems/dreamfusion.py +++ b/threestudio/systems/dreamfusion.py @@ -179,11 +179,11 @@ def training_step(self, batch, batch_idx): ### using mask to create image at original resolution during training (B,s_H,s_W,C) = out["comp_rgb"].shape mask = batch["efficiency_mask"] - comp_rgb = torch.zeros(B,batch["sample_height"],batch["sample_width"],C,device=mask.device) - comp_rgb[mask] = out["comp_rgb"] + comp_rgb = torch.zeros(B,batch["height"],batch["width"],C,device=self.device).view(B,-1,C) + comp_rgb[:,mask.view(-1)] = out["comp_rgb"].view(B,-1,C) out.update( { - "comp_rgb": comp_rgb, + "comp_rgb": comp_rgb.view(B,batch["height"],batch["width"],C), } ) diff --git a/threestudio/utils/ops.py b/threestudio/utils/ops.py index 17f2b8b0..3f3bfe19 100644 --- a/threestudio/utils/ops.py +++ b/threestudio/utils/ops.py @@ -250,7 +250,7 @@ def mask_ray_directions( [select_ind, out_ind_1d[torch.multinomial( torch.ones_like(out_ind_1d)*(p),int((p)*(s_H*s_W)),replacement=False)] ], - dim=0).to(dtype=torch.int).reshape(s_H,s_W) + dim=0).to(dtype=torch.int).view(s_H,s_W) ### first attempt at sampling, this produces variable number of rays, ### so 4D tensor directions cant be sampled From bfe69c85be35cc5bc6463e442e0e960512a4f277 Mon Sep 17 00:00:00 2001 From: jadevaibhav <25821637+jadevaibhav@users.noreply.github.com> Date: Thu, 8 Aug 2024 12:07:26 -0400 Subject: [PATCH 4/6] Changing the subsampling a bit, not better results --- threestudio/data/uncond_eff.py | 2 +- threestudio/utils/ops.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/threestudio/data/uncond_eff.py b/threestudio/data/uncond_eff.py index a69f1bb5..6b322caf 100644 --- a/threestudio/data/uncond_eff.py +++ b/threestudio/data/uncond_eff.py @@ -402,7 +402,7 @@ def general_loader(self, dataset, batch_size, collate_fn=None) -> DataLoader: dataset, # very important to disable multi-processing if you want to change self attributes at runtime! # (for example setting self.width and self.height in update_step) - num_workers=0, # type: ignore + num_workers=5, # type: ignore batch_size=batch_size, collate_fn=collate_fn, ) diff --git a/threestudio/utils/ops.py b/threestudio/utils/ops.py index 3f3bfe19..801cddd2 100644 --- a/threestudio/utils/ops.py +++ b/threestudio/utils/ops.py @@ -242,7 +242,10 @@ def mask_ray_directions( in_ind_1d = (indices_all[0]+H*indices_all[1])[mask] out_ind_1d = (indices_all[0]+H*indices_all[1])[torch.logical_not(mask)] - p = (s_H*s_W)/(H*W) + ### tried using 0.5 p ratio of sampling inside vs outside, as smaller area already + ### leads to more samples inside anyways + + p = 0.5#(s_H*s_W)/(H*W) select_ind = in_ind_1d[ torch.multinomial( torch.ones_like(in_ind_1d)*(1-p),int((1-p)*(s_H*s_W)),replacement=False)] From 96cb5b8ecb4c35db4f7709b76e44dccacbea1461 Mon Sep 17 00:00:00 2001 From: jadevaibhav <25821637+jadevaibhav@users.noreply.github.com> Date: Sun, 22 Sep 2024 20:23:21 -0400 Subject: [PATCH 5/6] new exp with upsampling before SDS --- .gitignore | 2 ++ threestudio/systems/dreamfusion.py | 33 ++++++++++++++++-- threestudio/utils/ops.py | 54 +++++++++++++++++------------- 3 files changed, 63 insertions(+), 26 deletions(-) diff --git a/.gitignore b/.gitignore index b774bf79..e3a1470a 100644 --- a/.gitignore +++ b/.gitignore @@ -55,6 +55,8 @@ coverage.xml .pytest_cache/ cover/ +# Slurm logs +slurm* # Translations *.mo *.pot diff --git a/threestudio/systems/dreamfusion.py b/threestudio/systems/dreamfusion.py index 3bb68d6d..135c6415 100644 --- a/threestudio/systems/dreamfusion.py +++ b/threestudio/systems/dreamfusion.py @@ -174,16 +174,43 @@ def configure(self): # create geometry, material, background, renderer super().configure() + def unmask(self,ind,subsampled_tensor,H,W): + """ + ind: B,s_H,s_W + subsampled_tensor: B,C,s_H,s_W + """ + + # Create a grid of coordinates for the original image size + offset = [ind[0,0]%H,ind[0,0]//H] + indices_all = torch.meshgrid( + torch.arange(W, dtype=torch.float32,device=self.device) , + torch.arange(H, dtype=torch.float32,device=self.device) , + indexing="xy" + ) + + grid = torch.stack( + [(indices_all[0] - offset[0])*4/(3*W), + (indices_all[1] - offset[1])*4/(H*3)], + dim=-1) + grid = grid*2 - 1 + grid = grid.repeat(subsampled_tensor.shape[0], 1, 1, 1) + # Use grid_sample to upsample the subsampled tensor (B,C,H,W) + upsampled_tensor = torch.nn.functional.grid_sample(subsampled_tensor, grid, mode='bilinear', align_corners=True) + + return upsampled_tensor.permute(0,2,3,1) + def training_step(self, batch, batch_idx): out = self(batch) ### using mask to create image at original resolution during training (B,s_H,s_W,C) = out["comp_rgb"].shape + comp_rgb = out["comp_rgb"].permute(0,3,1,2) mask = batch["efficiency_mask"] - comp_rgb = torch.zeros(B,batch["height"],batch["width"],C,device=self.device).view(B,-1,C) - comp_rgb[:,mask.view(-1)] = out["comp_rgb"].view(B,-1,C) + comp_rgb = self.unmask(mask,comp_rgb,batch["height"],batch["width"]) + # comp_rgb = torch.zeros(B,batch["height"],batch["width"],C,device=self.device).view(B,-1,C) + # comp_rgb[:,mask.view(-1)] = out["comp_rgb"].view(B,-1,C) out.update( { - "comp_rgb": comp_rgb.view(B,batch["height"],batch["width"],C), + "comp_rgb": comp_rgb, } ) diff --git a/threestudio/utils/ops.py b/threestudio/utils/ops.py index 801cddd2..f99d4459 100644 --- a/threestudio/utils/ops.py +++ b/threestudio/utils/ops.py @@ -227,33 +227,41 @@ def mask_ray_directions( pixels from (s_H,s_W) are sampled more (1-aspect_ratio) than outside pixels(aspect_ratio). the masking is deferred to before calling get_rays(). """ - indices_all = torch.meshgrid( - torch.arange(W, dtype=torch.float32) , - torch.arange(H, dtype=torch.float32) , - indexing="xy", - ) - # indices_inner = torch.meshgrid( - # torch.arange((W-s_W)//2 , W - math.ceil((W-s_W)/2), dtype=torch.float32) , - # torch.arange((H-s_H)//2,H - math.ceil((H-s_H)/2), dtype=torch.float32) , + # indices_all = torch.meshgrid( + # torch.arange(W, dtype=torch.float32) , + # torch.arange(H, dtype=torch.float32) , # indexing="xy", # ) - mask = torch.zeros(H,W, dtype=torch.bool) - mask[(H-s_H)//2 : H - math.ceil((H-s_H)/2),(W-s_W)//2 : W - math.ceil((W-s_W)/2)] = True + + indices_inner = torch.meshgrid( + torch.linspace(0,0.75*W,s_W, dtype=torch.int8) , + torch.linspace(0,0.75*H,s_H, dtype=torch.int8) , + indexing="xy", + ) + offset = [torch.randint(0,W//8 +1,(1,)), + torch.randint(0,H//8 +1,(1,))] + + select_ind = indices_inner[0]+offset[0] + H*(indices_inner[1] + offset[1]) - in_ind_1d = (indices_all[0]+H*indices_all[1])[mask] - out_ind_1d = (indices_all[0]+H*indices_all[1])[torch.logical_not(mask)] - ### tried using 0.5 p ratio of sampling inside vs outside, as smaller area already - ### leads to more samples inside anyways + + ### removing the random sampling approach, we sample in uniform grid + # mask = torch.zeros(H,W, dtype=torch.bool) + # mask[(H-s_H)//2 : H - math.ceil((H-s_H)/2),(W-s_W)//2 : W - math.ceil((W-s_W)/2)] = True + + # in_ind_1d = (indices_all[0]+H*indices_all[1])[mask] + # out_ind_1d = (indices_all[0]+H*indices_all[1])[torch.logical_not(mask)] + # ### tried using 0.5 p ratio of sampling inside vs outside, as smaller area already + # ### leads to more samples inside anyways - p = 0.5#(s_H*s_W)/(H*W) - select_ind = in_ind_1d[ - torch.multinomial( - torch.ones_like(in_ind_1d)*(1-p),int((1-p)*(s_H*s_W)),replacement=False)] - select_ind = torch.concatenate( - [select_ind, out_ind_1d[torch.multinomial( - torch.ones_like(out_ind_1d)*(p),int((p)*(s_H*s_W)),replacement=False)] - ], - dim=0).to(dtype=torch.int).view(s_H,s_W) + # p = 0.5#(s_H*s_W)/(H*W) + # select_ind = in_ind_1d[ + # torch.multinomial( + # torch.ones_like(in_ind_1d)*(1-p),int((1-p)*(s_H*s_W)),replacement=False)] + # select_ind = torch.concatenate( + # [select_ind, out_ind_1d[torch.multinomial( + # torch.ones_like(out_ind_1d)*(p),int((p)*(s_H*s_W)),replacement=False)] + # ], + # dim=0).to(dtype=torch.int).view(s_H,s_W) ### first attempt at sampling, this produces variable number of rays, ### so 4D tensor directions cant be sampled From ee55c665a6cf974ec1e60199181bd4d202fd0a69 Mon Sep 17 00:00:00 2001 From: jadevaibhav <25821637+jadevaibhav@users.noreply.github.com> Date: Tue, 1 Oct 2024 16:48:52 -0400 Subject: [PATCH 6/6] refactoring --- threestudio/systems/dreamfusion.py | 96 ------------------------- threestudio/systems/eff_dreamfusion.py | 98 ++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 96 deletions(-) create mode 100644 threestudio/systems/eff_dreamfusion.py diff --git a/threestudio/systems/dreamfusion.py b/threestudio/systems/dreamfusion.py index 135c6415..700dc614 100644 --- a/threestudio/systems/dreamfusion.py +++ b/threestudio/systems/dreamfusion.py @@ -162,99 +162,3 @@ def on_test_epoch_end(self): ) -@threestudio.register("efficient-dreamfusion-system") -class EffDreamFusion(DreamFusion): - @dataclass - class Config(DreamFusion.Config): - pass - - cfg: Config - - def configure(self): - # create geometry, material, background, renderer - super().configure() - - def unmask(self,ind,subsampled_tensor,H,W): - """ - ind: B,s_H,s_W - subsampled_tensor: B,C,s_H,s_W - """ - - # Create a grid of coordinates for the original image size - offset = [ind[0,0]%H,ind[0,0]//H] - indices_all = torch.meshgrid( - torch.arange(W, dtype=torch.float32,device=self.device) , - torch.arange(H, dtype=torch.float32,device=self.device) , - indexing="xy" - ) - - grid = torch.stack( - [(indices_all[0] - offset[0])*4/(3*W), - (indices_all[1] - offset[1])*4/(H*3)], - dim=-1) - grid = grid*2 - 1 - grid = grid.repeat(subsampled_tensor.shape[0], 1, 1, 1) - # Use grid_sample to upsample the subsampled tensor (B,C,H,W) - upsampled_tensor = torch.nn.functional.grid_sample(subsampled_tensor, grid, mode='bilinear', align_corners=True) - - return upsampled_tensor.permute(0,2,3,1) - - def training_step(self, batch, batch_idx): - out = self(batch) - ### using mask to create image at original resolution during training - (B,s_H,s_W,C) = out["comp_rgb"].shape - comp_rgb = out["comp_rgb"].permute(0,3,1,2) - mask = batch["efficiency_mask"] - comp_rgb = self.unmask(mask,comp_rgb,batch["height"],batch["width"]) - # comp_rgb = torch.zeros(B,batch["height"],batch["width"],C,device=self.device).view(B,-1,C) - # comp_rgb[:,mask.view(-1)] = out["comp_rgb"].view(B,-1,C) - out.update( - { - "comp_rgb": comp_rgb, - } - ) - - prompt_utils = self.prompt_processor() - guidance_out = self.guidance( - out["comp_rgb"], prompt_utils, **batch, rgb_as_latents=False - ) - - loss = 0.0 - - for name, value in guidance_out.items(): - if not (type(value) is torch.Tensor and value.numel() > 1): - self.log(f"train/{name}", value) - if name.startswith("loss_"): - loss += value * self.C(self.cfg.loss[name.replace("loss_", "lambda_")]) - - if self.C(self.cfg.loss.lambda_orient) > 0: - if "normal" not in out: - raise ValueError( - "Normal is required for orientation loss, no normal is found in the output." - ) - loss_orient = ( - out["weights"].detach() - * dot(out["normal"], out["t_dirs"]).clamp_min(0.0) ** 2 - ).sum() / (out["opacity"] > 0).sum() - self.log("train/loss_orient", loss_orient) - loss += loss_orient * self.C(self.cfg.loss.lambda_orient) - - loss_sparsity = (out["opacity"] ** 2 + 0.01).sqrt().mean() - self.log("train/loss_sparsity", loss_sparsity) - loss += loss_sparsity * self.C(self.cfg.loss.lambda_sparsity) - - opacity_clamped = out["opacity"].clamp(1.0e-3, 1.0 - 1.0e-3) - loss_opaque = binary_cross_entropy(opacity_clamped, opacity_clamped) - self.log("train/loss_opaque", loss_opaque) - loss += loss_opaque * self.C(self.cfg.loss.lambda_opaque) - - # z-variance loss proposed in HiFA: https://hifa-team.github.io/HiFA-site/ - if "z_variance" in out and "lambda_z_variance" in self.cfg.loss: - loss_z_variance = out["z_variance"][out["opacity"] > 0.5].mean() - self.log("train/loss_z_variance", loss_z_variance) - loss += loss_z_variance * self.C(self.cfg.loss.lambda_z_variance) - - for name, value in self.cfg.loss.items(): - self.log(f"train_params/{name}", self.C(value)) - - return {"loss": loss} \ No newline at end of file diff --git a/threestudio/systems/eff_dreamfusion.py b/threestudio/systems/eff_dreamfusion.py new file mode 100644 index 00000000..669d49e8 --- /dev/null +++ b/threestudio/systems/eff_dreamfusion.py @@ -0,0 +1,98 @@ +from .dreamfusion import * + +@threestudio.register("efficient-dreamfusion-system") +class EffDreamFusion(DreamFusion): + @dataclass + class Config(DreamFusion.Config): + pass + + cfg: Config + + def configure(self): + # create geometry, material, background, renderer + super().configure() + + def unmask(self,ind,subsampled_tensor,H,W): + """ + ind: B,s_H,s_W + subsampled_tensor: B,C,s_H,s_W + """ + + # Create a grid of coordinates for the original image size + offset = [ind[0,0]%H,ind[0,0]//H] + indices_all = torch.meshgrid( + torch.arange(W, dtype=torch.float32,device=self.device) , + torch.arange(H, dtype=torch.float32,device=self.device) , + indexing="xy" + ) + + grid = torch.stack( + [(indices_all[0] - offset[0])*4/(3*W), + (indices_all[1] - offset[1])*4/(H*3)], + dim=-1) + grid = grid*2 - 1 + grid = grid.repeat(subsampled_tensor.shape[0], 1, 1, 1) + # Use grid_sample to upsample the subsampled tensor (B,C,H,W) + upsampled_tensor = torch.nn.functional.grid_sample(subsampled_tensor, grid, mode='bilinear', align_corners=True) + + return upsampled_tensor.permute(0,2,3,1) + + def training_step(self, batch, batch_idx): + out = self(batch) + ### using mask to create image at original resolution during training + (B,s_H,s_W,C) = out["comp_rgb"].shape + comp_rgb = out["comp_rgb"].permute(0,3,1,2) + mask = batch["efficiency_mask"] + comp_rgb = self.unmask(mask,comp_rgb,batch["height"],batch["width"]) + # comp_rgb = torch.zeros(B,batch["height"],batch["width"],C,device=self.device).view(B,-1,C) + # comp_rgb[:,mask.view(-1)] = out["comp_rgb"].view(B,-1,C) + out.update( + { + "comp_rgb": comp_rgb, + } + ) + + prompt_utils = self.prompt_processor() + guidance_out = self.guidance( + out["comp_rgb"], prompt_utils, **batch, rgb_as_latents=False + ) + + loss = 0.0 + + for name, value in guidance_out.items(): + if not (type(value) is torch.Tensor and value.numel() > 1): + self.log(f"train/{name}", value) + if name.startswith("loss_"): + loss += value * self.C(self.cfg.loss[name.replace("loss_", "lambda_")]) + + if self.C(self.cfg.loss.lambda_orient) > 0: + if "normal" not in out: + raise ValueError( + "Normal is required for orientation loss, no normal is found in the output." + ) + loss_orient = ( + out["weights"].detach() + * dot(out["normal"], out["t_dirs"]).clamp_min(0.0) ** 2 + ).sum() / (out["opacity"] > 0).sum() + self.log("train/loss_orient", loss_orient) + loss += loss_orient * self.C(self.cfg.loss.lambda_orient) + + loss_sparsity = (out["opacity"] ** 2 + 0.01).sqrt().mean() + self.log("train/loss_sparsity", loss_sparsity) + loss += loss_sparsity * self.C(self.cfg.loss.lambda_sparsity) + + opacity_clamped = out["opacity"].clamp(1.0e-3, 1.0 - 1.0e-3) + loss_opaque = binary_cross_entropy(opacity_clamped, opacity_clamped) + self.log("train/loss_opaque", loss_opaque) + loss += loss_opaque * self.C(self.cfg.loss.lambda_opaque) + + # z-variance loss proposed in HiFA: https://hifa-team.github.io/HiFA-site/ + if "z_variance" in out and "lambda_z_variance" in self.cfg.loss: + loss_z_variance = out["z_variance"][out["opacity"] > 0.5].mean() + self.log("train/loss_z_variance", loss_z_variance) + loss += loss_z_variance * self.C(self.cfg.loss.lambda_z_variance) + + for name, value in self.cfg.loss.items(): + self.log(f"train_params/{name}", self.C(value)) + + return {"loss": loss} \ No newline at end of file