Skip to content

Commit

Permalink
Adds preparebatch
Browse files Browse the repository at this point in the history
Signed-off-by: Mark Graham <[email protected]>
  • Loading branch information
marksgraham committed Jan 30, 2024
1 parent 33dce3a commit 023b6bb
Showing 1 changed file with 78 additions and 1 deletion.
79 changes: 78 additions & 1 deletion monai/engines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@

from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union, cast

import torch
import torch.nn as nn

from monai.config import IgniteInfo
from monai.transforms import apply_transform
Expand All @@ -36,6 +37,8 @@
"PrepareBatch",
"PrepareBatchDefault",
"PrepareBatchExtraInput",
"DiffusionPrepareBatch",
"VPredictionPrepareBatch",
"default_make_latent",
"engine_apply_transform",
"default_metric_cmp_fn",
Expand Down Expand Up @@ -238,6 +241,80 @@ def _get_data(key: str) -> torch.Tensor:
return cast(torch.Tensor, image), cast(torch.Tensor, label), tuple(args_), kwargs_


class DiffusionPrepareBatch(PrepareBatch):
"""
This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training.
Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and
return the image and noise field as the image/target pair plus the noise field the kwargs under the key "noise".
This assumes the inferer being used in conjunction with this class expects a "noise" parameter to be provided.
If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition
field to be passed to the inferer. This will appear in the keyword arguments under the key "condition".
"""

def __init__(self, num_train_timesteps: int, condition_name: Optional[str] = None) -> None:
self.condition_name = condition_name
self.num_train_timesteps = num_train_timesteps

def get_noise(self, images: torch.Tensor) -> torch.Tensor:
"""Returns the noise tensor for input tensor `images`, override this for different noise distributions."""
return torch.randn_like(images)

def get_timesteps(self, images: torch.Tensor) -> torch.Tensor:
"""Get a timestep, by default this is a random integer between 0 and `self.num_train_timesteps`."""
return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long()

def get_target(self, images: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
"""Return the target for the loss function, this is the `noise` value by default."""
return noise

def __call__(
self,
batchdata: Dict[str, torch.Tensor],
device: Optional[Union[str, torch.device]] = None,
non_blocking: bool = False,
**kwargs: Any,
) -> tuple[torch.Tensor, torch.Tensor, tuple, dict]:
images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs)
noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs)
timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs)

target = self.get_target(images, noise, timesteps).to(device, non_blocking=non_blocking, **kwargs)
infer_kwargs = {"noise": noise, "timesteps": timesteps}

if self.condition_name is not None and isinstance(batchdata, Mapping):
infer_kwargs["conditioning"] = batchdata[self.condition_name].to(
device, non_blocking=non_blocking, **kwargs
)

# return input, target, arguments, and keyword arguments where noise is the target and also a keyword value
return images, target, (), infer_kwargs


class VPredictionPrepareBatch(DiffusionPrepareBatch):
"""
This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training.
Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and
from this compute the velocity using the provided scheduler. This value is used as the target in place of the
noise field itself although the noise is field is in the kwargs under the key "noise". This assumes the inferer
being used in conjunction with this class expects a "noise" parameter to be provided.
If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition
field to be passed to the inferer. This will appear in the keyword arguments under the key "condition".
"""

def __init__(self, scheduler: nn.Module, num_train_timesteps: int, condition_name: Optional[str] = None) -> None:
super().__init__(num_train_timesteps=num_train_timesteps, condition_name=condition_name)
self.scheduler = scheduler

def get_target(self, images, noise, timesteps):
return self.scheduler.get_velocity(images, noise, timesteps)


def default_make_latent(
num_latents: int,
latent_size: int,
Expand Down

0 comments on commit 023b6bb

Please sign in to comment.