
MONAI Preprocessors and Transforms Design Discussion

Ben Murray edited this page Apr 3, 2020 · 1 revision

Introduction

This page discusses the design of MONAI's preprocessing capabilities. As of this revision, the design is entirely under discussion, and people are encouraged to contribute to and debate the current design thinking.


Key functionality

  1. Preprocessing / augmentation
    1. Normalisation techniques
    2. Spatial augmentation
      1. Affine deformations
      2. Non-affine deformations
    3. Intensity augmentation
  2. Patches
    1. Patch sampling
    2. Grid sampling
  3. Data formats
    1. First-class nifti support
    2. First-class npz support
  4. Multiple workers

Preprocessor Design Challenges

The functionality for preprocessing is spread across multiple parts of PyTorch and related libraries.

PyTorch

DataLoader is the PyTorch-provided class that combines dataset and sampler functionality and provides iteration over the dataset. MONAI functionality should be compatible with DataLoader, although we are free to subclass it if there is key functionality that cannot be realised with the standard DataLoader class.

Dataset and its subclass IterableDataset (for map-style and iterable-style datasets respectively) are similarly core PyTorch functionality that all preprocessors should be compatible with. It is quite normal to subclass these classes further for specific applications.

Torchvision

torchvision does a lot of the heavy lifting for data augmentation in the non-medical imaging world, and torch.utils.data handles sampling approaches:

  • Torchvision doesn't have an elaborate base class for image manipulation; it only requires that __call__ is implemented
  • Sampler, the base class for all of the torch.utils.data samplers, similarly descends directly from object and simply requires that __iter__ and __len__ are implemented

Torchvision preprocessors are designed for generic images rather than medical data, so we don't propose making the library a prerequisite for MONAI, but our layers should be compatible with it unless a show-stopping reason emerges during the design process. People with existing codebases who adopt MONAI are likely to do so for its medical-data-specific preprocessors, so making those preprocessors compatible with torchvision helpers such as Compose is desirable.

Our design options are therefore as follows:

  1. Implement preprocessing layers on top of torchvision's core
  2. Implement torchvision-like functionality
  3. Find an alternative approach to preprocessing that is better suited to medical data

Design questions

What is 'vanilla' in terms of preprocessors?

A vanilla PyTorch preprocessing pipeline looks like the following:

  • A subclass of torch.utils.data.Dataset is created to hold the state needed for the set of transforms to be performed
    • Standard subclasses of Dataset, such as IterableDataset or TensorDataset, can be subclassed instead of the base Dataset class
  • This is passed to an instance of torch.utils.data.DataLoader, which iterates over it
  • torchvision.transforms.Compose is typically used to build a pipeline of preprocessing stages; the resulting object is what gets passed to the transform argument of the Dataset subclass
    • For more complex pipelines, torchvision.transforms.functional can be used (TODO: example here)
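The pattern in the list above can be sketched without any PyTorch dependency. In this minimal sketch, `ImageListDataset` is a hypothetical class that only implements the map-style protocol (`__len__` and `__getitem__`) that a `torch.utils.data.Dataset` subclass would provide, `Compose` is a local stand-in mirroring `torchvision.transforms.Compose`, and `scale` and `flip` are toy transforms on 1D lists.

```python
class Compose:
    """Local stand-in for torchvision.transforms.Compose: chains single-argument callables."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img


class ImageListDataset:
    """Hypothetical map-style dataset; real code would subclass torch.utils.data.Dataset."""

    def __init__(self, images, transform=None):
        self.images = images        # toy in-memory "images" (lists of numbers)
        self.transform = transform  # typically a Compose instance

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img = self.images[index]
        if self.transform is not None:
            img = self.transform(img)
        return img


def scale(img):
    # toy intensity rescaling to [0, 1]
    lo, hi = min(img), max(img)
    return [(v - lo) / (hi - lo) for v in img]


def flip(img):
    # toy 1D "flip"
    return img[::-1]


ds = ImageListDataset([[0, 5, 10], [2, 4, 6]], transform=Compose([scale, flip]))
sample = ds[0]
```

Passing `ds` to `torch.utils.data.DataLoader` would then batch these samples; the dataset decides which transforms run, while the DataLoader decides how and where they are executed.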

Without Compose

Returning tuples

# 1. load returns tuple; normalise takes single arg; deform takes image and matrix
t = load(f)
i = normalise(t[0])
t = deform(i, t[1])

# 2. load returns tuple; normalise takes *args, deform takes *args
t = load(f)
t = normalise(*t)
t = deform(*t)

# 3. load returns tuple; normalise takes single arg; deform takes image and matrix
i, m = load(f)
i = normalise(i)
i, m = deform(i, m)

Return dicts

# 4. load returns dict, normalise takes single arg; deform takes image and matrix and returns dict
d = load(f)
d['image'] = normalise(d['image'])
d = deform(d['image'], d['matrix'])

# 5. load returns dict, normalise takes **kwargs, deform takes **kwargs and returns dict
d = load(f)
d = normalise(**d)
d = deform(**d)

Take dict as parameter and return dict

# 6. load returns dict, all methods take and return dicts
d = load(f)
d = normalise(d)
d = deform(d)

Without 'Compose', the cleanest syntax is achieved through implementing all transforms as vanilla function calls, but this is not typical.
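To make the trade-off concrete, the sketch below implements options 3 and 6 with hypothetical toy versions of `load`, `normalise` and `deform`: a transpose stands in for a real deformation, and the filename is a placeholder that is never opened.

```python
import numpy as np


# Hypothetical stand-ins for the load / normalise / deform calls above.
def load_tuple(f):
    # returns (image, matrix); f is only a placeholder name here
    return np.arange(4.0).reshape(2, 2), np.eye(3)


def normalise(image):
    return (image - image.mean()) / image.std()


def deform(image, matrix):
    # placeholder deformation: transpose the image, leave the matrix unchanged
    return image.T, matrix


# Option 3: tuples, unpacked at each call site
i, m = load_tuple("scan.nii")
i = normalise(i)
i, m = deform(i, m)


# Option 6: every stage takes and returns a dict
def load_dict(f):
    image, matrix = load_tuple(f)
    return {"image": image, "matrix": matrix, "filename": f}


def normalise_d(d):
    return {**d, "image": normalise(d["image"])}


def deform_d(d):
    image, matrix = deform(d["image"], d["matrix"])
    return {**d, "image": image, "matrix": matrix}


d = load_dict("scan.nii")
d = normalise_d(d)
d = deform_d(d)
```

The dict version carries the unrelated `filename` entry through untouched, at the cost of every stage knowing the key names it reads and writes.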

Use of Compose

As mentioned, transform libraries tend to have Compose functionality that places strong design constraints on the transforms in those libraries. Because different libraries take different approaches, it is hard to find a single approach that avoids placing any cognitive burden on the transform writer.

torchvision transforms are all implemented so that __call__ takes a single argument and returns a single argument. torchvision.transforms.Compose is implemented with this assumption in mind.

def __call__(self, img):
    for t in self.transforms:
        img = t(img)
    return img

Other libraries do not restrict themselves to this convention, and indeed, it isn't necessarily desirable to do so.

Within the medical deep learning space, MIC-DKFZ/batchgenerators takes an approach where each transform takes a dictionary and returns a dictionary. It therefore has to provide its own Compose class, which passes **kwargs to the __call__ methods of the transforms being composed.

def __call__(self, **data_dict):
    for t in self.transforms:
        data_dict = t(**data_dict)
    return data_dict

When transforms are written to take a dictionary explicitly as their single argument, we can use torchvision's Compose again, provided that all transforms take and return dictionaries.

In fact, without adaptor wrappers for transforms, there isn't a way to universally mix and match transforms with different calling conventions.
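The two conventions can be contrasted directly. This sketch defines a single-argument `Compose` (mirroring the torchvision version) and a `KwargsCompose` in the style of the batchgenerators Compose; both classes and the toy transforms `double_image` and `double_image_kw` are illustrative, not library code.

```python
class Compose:
    """Single-argument chaining, as in torchvision.transforms.Compose."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img


class KwargsCompose:
    """**kwargs chaining in the style of batchgenerators' Compose."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, **data_dict):
        for t in self.transforms:
            data_dict = t(**data_dict)
        return data_dict


# A transform that takes and returns a dict works with the single-argument Compose...
def double_image(d):
    return {**d, "image": [2 * v for v in d["image"]]}


# ...whereas a **kwargs transform needs the kwargs-aware composer.
def double_image_kw(image, **rest):
    return {**rest, "image": [2 * v for v in image]}


sample = {"image": [1, 2], "seg": [0, 1]}
out_a = Compose([double_image, double_image])(sample)
out_b = KwargsCompose([double_image_kw, double_image_kw])(**sample)
```

Each transform only works with its matching composer: passing `double_image` to `KwargsCompose` raises a `TypeError`, because it is called with `image=` and `seg=` keyword arguments it does not accept.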

Canonical preprocessor formats

Given a dictionary approach, a canonical form for preprocessor functions that pass through other arguments unaltered will look something like the following:

def deform(image, tx, **okwargs):
    image = ...  # work on image
    tx = ...  # work on tx

    # pass every other argument through unaltered
    args = dict(okwargs)
    args.update({'image': image, 'tx': tx})
    return args
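As a concrete instance of this canonical form, the sketch below gives `deform` a real body: it flips a 2D image along its first axis and records that flip in `tx`, assumed here to be a 3x3 homogeneous matrix acting on (row, col, 1) index coordinates. Any other entries, such as the placeholder `seg` value, pass through unaltered.

```python
import numpy as np


def deform(image, tx, **okwargs):
    """Canonical dict-style transform: consumes 'image' and 'tx', passes everything else through."""
    # flip the image along its first axis (row r moves to row n - 1 - r)...
    n = image.shape[0]
    image = image[::-1]
    # ...and append the same flip, in index coordinates, to the homogeneous transform
    flip = np.array([[-1.0, 0.0, n - 1], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
    tx = flip @ tx

    args = dict(okwargs)
    args.update({"image": image, "tx": tx})
    return args


sample = {"image": np.arange(6).reshape(3, 2), "tx": np.eye(3), "seg": "untouched"}
out = deform(**sample)
```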

The dictionary approach also makes it harder to work with non-dictionary based transforms that don't pass unused arguments through. For these, an adaptor of some kind might be required:

def adaptor(fn, in_names, out_names):
  # in_names: (dict key, fn parameter name) pairs
  # out_names: a key, a sequence of keys, or a {fn result key: dict key} mapping
  def wrapped(args):
    to = {}
    for n in in_names:
      to[n[1]] = args[n[0]]

    # call a function that is not set up to handle *args/**kwargs
    res = fn(**to)

    # copy the args so we aren't destructively changing collections
    args = dict(args)

    # handle the different return types - this might need more work
    if isinstance(res, (list, tuple)):
      for k, v in zip(out_names, res):
        args[k] = v
    elif isinstance(res, dict):
      mod_res = {out_names[k] if k in out_names else k: v for k, v in res.items()}
      args.update(mod_res)
    else:
      args[out_names] = res

    return args
  return wrapped

A more radical alternative to Compose

Looking at the different design trade-offs involved in writing transforms, it becomes apparent that Compose pushes most of the complexity onto the transforms themselves. Consider a situation where some transforms are applied to image data only and some are applied to both image data and segmentation:

  Compose([
    MRISegDataset(...), # t is a dictionary that contains 'image' and 'seg' keys
    Normalise(...), # applied to image only
    Rotate(...), # applied to image and seg
    Flip(...) # applied to image and seg
  ])

The complexity here arises because we force a linear dataflow onto a non-linear set of function calls. We do this because we want those calls to be invoked by the DataLoader, potentially in multiple worker processes and in a way that takes advantage of IO completion.

Instead of using Compose, we could create a class that performs the same functionality and is invoked as a function call at data loading time.

import numpy as np

class Preprocessor:
  def __init__(self):
    self.norm = Normalise(...)
    self.rotn = Rotate(...)
    self.flip = Flip(...)

  def __call__(self, inputs):
    image = self.norm(inputs['image'])  # normalise the image only
    # stack image and seg so that the spatial transforms act on both together
    inputs = np.stack([image, inputs['seg']])
    inputs = self.rotn(inputs)
    inputs = self.flip(inputs)
    return inputs

d = Dataset(..., transform=Preprocessor())

Consider this a sketch for now. There is a lot of momentum in the community behind some form of Compose, but we argue that it should be considered an anti-pattern, for the following reasons:

  1. Most transform pipelines for supervised applications contain layers that don't act on all 'channels'
  2. Compose doesn't make this distinction and forces a linear call chain on the pipeline
  3. Transforms must then be polluted with knowledge of how they are being called, which shouldn't be necessary

Augmenting Compose functionality as another approach

It may be possible to resolve this problem by adding a small amount of additional functionality to Compose, so that a non-linear pipeline can be composed out of linear, Compose-able elements.
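One possible shape for this extra functionality is a wrapper that applies an ordinary single-array transform to selected dictionary keys, turning the image/seg pipeline from the earlier Compose example back into a linear chain. `ApplyToKeys` is a hypothetical name, and `np.rot90` and `np.fliplr` are deterministic stand-ins for `Rotate` and `Flip`; randomized versions would additionally need the shared randomness discussed in the next section.

```python
import numpy as np


class ApplyToKeys:
    """Hypothetical wrapper: applies an array transform to selected dict entries only."""

    def __init__(self, transform, keys):
        self.transform = transform
        self.keys = tuple(keys)

    def __call__(self, data):
        out = dict(data)  # shallow copy so the caller's dict is not mutated
        for k in self.keys:
            out[k] = self.transform(out[k])
        return out


class Compose:
    """Single-argument chaining, as in torchvision.transforms.Compose."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x):
        for t in self.transforms:
            x = t(x)
        return x


def normalise(a):
    return (a - a.mean()) / a.std()


pipeline = Compose([
    ApplyToKeys(normalise, keys=["image"]),         # image only
    ApplyToKeys(np.rot90, keys=["image", "seg"]),   # image and seg
    ApplyToKeys(np.fliplr, keys=["image", "seg"]),  # image and seg
])

sample = {"image": np.arange(4.0).reshape(2, 2), "seg": np.array([[0, 1], [0, 0]])}
out = pipeline(sample)
```

The transforms themselves stay unaware of how they are called; only the wrapper knows about keys.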

Randomized Transforms

Randomized transforms are transforms whose behaviour is driven by a random number generator. They present a design challenge because, when multiple inputs live in the same spatial frame of reference, the same random values must be applied to all of them: any rotation, flip, scaling, deformation, etc. must be applied identically across all related inputs.

Transforms using the implicit np.random instance

Many randomized transforms use np.random's implicit global generator. This generator is therefore shared across all transforms, which creates issues that are not easily resolved, especially when data is loaded by multiple threads or worker processes.

The following solution is recommended by torchvision (here) and is currently used in MONAI as of this commit.

seed = np.random.randint(2147483647)

if self.transform is not None:
    np.random.seed(seed)
    img = self.transform(img)
    random_sync_test = np.random.randint(2147483647)

if self.seg_transform is not None:
    np.random.seed(seed)  # ensure randomized transforms roll the same values for segmentations as images
    seg = self.seg_transform(seg)
    seg_seed = np.random.randint(2147483647)
    assert(random_sync_test == seg_seed)

In a single-threaded scenario, this can be used to ensure that the same values are used for both the image and seg transform paths, but only if the following conditions hold:

  1. The transform pipelines are identical
  2. The data loader is run in a single threaded mode

These conditions are too restrictive for us to consider this approach a suitable solution.
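The first condition is easy to violate. In this sketch, the image path contains an extra randomized intensity transform (`add_noise`), so after reseeding it consumes values from the global generator before `random_flip` draws, while the seg path does not. The functions and the `draws` log are illustrative only.

```python
import numpy as np

draws = []  # records the value each random_flip call draws, for inspection


def add_noise(x):
    # intensity transform drawing from the implicit global generator
    return x + np.random.normal(scale=0.1, size=x.shape)


def random_flip(x):
    # spatial transform drawing from the same implicit global generator
    u = np.random.uniform()
    draws.append(u)
    return x[::-1] if u < 0.5 else x


def run(pipeline, x, seed):
    np.random.seed(seed)  # the reseeding trick: reset the global generator first
    for t in pipeline:
        x = t(x)
    return x


image, seg = np.zeros(10), np.arange(10)

# Identical pipelines: both paths see the same draw.
run([random_flip], image, seed=1234)
run([random_flip], seg, seed=1234)
identical_in_sync = draws[0] == draws[1]

# Different pipelines: add_noise consumes values first, so the draws diverge.
draws.clear()
run([add_noise, random_flip], image, seed=1234)
run([random_flip], seg, seed=1234)
different_in_sync = draws[0] == draws[1]
```

With identical pipelines the recorded draws match; once the pipelines differ, the image and segmentation receive different random values, and therefore potentially different flip decisions, even though both paths were seeded with 1234.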

Solution components

Our potential solutions rely on a number of elements that can be combined in different ways.

Make all native MONAI transforms use RandomState instances

'RandomState' instances can be created and used instead of relying on the implicit random instance. We can construct transforms with either a random seed or with a RandomState instance:

  • If constructed with a seed, a RandomState instance local to that transform is created from that seed
  • If constructed with a RandomState instance, that instance is used by the transform.

Doing so allows us to have pipelines containing different preprocessing stages for images and segmentations, provided that the corresponding transform instances are initialised with the same seed.

imtrans=transforms.Compose([
    Rescale(),
    AddChannel(),
    AddGaussianNoise(sigma=.5, seed=1234),
    UniformRandomPatch((64, 64, 64), seed=5678),
    ToTensor()
])

segtrans=transforms.Compose([
    AddChannel(),
    UniformRandomPatch((64, 64, 64), seed=5678),
    ToTensor()
])
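A sketch of this pattern, assuming a hypothetical `UniformRandomPatch(patch_size, seed=None, rng=None)` signature: the name comes from the pipelines above, but the parameters are illustrative and not an existing MONAI API. Each instance owns a `RandomState`, so two instances built with the same seed pick identical patch corners.

```python
import numpy as np


class UniformRandomPatch:
    """Sketch of a patch sampler that owns its random source (hypothetical signature)."""

    def __init__(self, patch_size, seed=None, rng=None):
        self.patch_size = tuple(patch_size)
        # use the supplied RandomState, otherwise create one local to this transform
        self.rng = rng if rng is not None else np.random.RandomState(seed)

    def __call__(self, img):
        # choose a random corner such that the patch fits inside the image
        corner = [self.rng.randint(0, s - p + 1) for s, p in zip(img.shape, self.patch_size)]
        slices = tuple(slice(c, c + p) for c, p in zip(corner, self.patch_size))
        return img[slices]


image = np.random.rand(32, 32, 32)
seg = (image > 0.5).astype(np.uint8)

# Separate pipelines, same seed: both transforms draw identical corners.
img_patch = UniformRandomPatch((16, 16, 16), seed=5678)(image)
seg_patch = UniformRandomPatch((16, 16, 16), seed=5678)(seg)
```

The two instances stay in sync only while they are called the same number of times, in the same order. Passing one shared `RandomState` as `rng` to both would not have the same effect, since each call advances the shared state. When a `DataLoader` uses several worker processes, each worker may also receive an identical copy of the generator, so per-worker reseeding may be needed.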

Make MONAI transforms accept multiple inputs

MONAI transforms can be written to accept multiple inputs and apply the same random values to all of them. This approach can be made to work with vanilla Compose through the adaptor function.

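class ATransform:
  def __init__(self, seed=None):
    # initialise things, including a random source local to this transform
    self.rng = np.random.RandomState(seed)

  def __call__(self, *data):
    r = self.rng.randint(2147483647)  # one draw per call...
    result = []
    for d in data:
      # ...applied to every input; apply_r_to_d stands for the transform's actual work
      result.append(apply_r_to_d(r, d))
    return result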
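A runnable version of this idea, assuming a hypothetical `RandomFlip` transform that owns a `RandomState`, draws once per call, and applies the same flip decision to every positional input.

```python
import numpy as np


class RandomFlip:
    """Hypothetical multi-input transform: one random draw, applied to every input."""

    def __init__(self, axis=0, prob=0.5, seed=None):
        self.axis = axis
        self.prob = prob
        self.rng = np.random.RandomState(seed)

    def __call__(self, *data):
        do_flip = self.rng.uniform() < self.prob  # drawn once per call
        return [np.flip(d, axis=self.axis) if do_flip else d for d in data]


flip = RandomFlip(axis=1, seed=0)
image = np.arange(12.0).reshape(3, 4)
seg = (image > 5).astype(np.uint8)
for _ in range(5):
    image, seg = flip(image, seg)
```

Because the segmentation is flipped whenever the image is, the two stay aligned. To place the transform inside a single-argument Compose, a small wrapper such as `lambda pair: tuple(flip(*pair))` is sufficient.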

Share random sources between Transforms

This approach requires a more sophisticated random source, wrapping an underlying random source and being updated explicitly as part of a transform step.
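One way to realise such a source is sketched here, assuming hypothetical `SharedRandomSource` and `RandomFlipFromSource` classes. Transforms only read the current value; the source is advanced explicitly with `step()` once per sample, outside the pipelines, so the image and segmentation pipelines can differ in length and still agree on every random decision.

```python
import numpy as np


class SharedRandomSource:
    """Hypothetical wrapper around a RandomState whose value changes only on step()."""

    def __init__(self, seed=None):
        self._rng = np.random.RandomState(seed)
        self.value = None
        self.step()

    def step(self):
        # draw a fresh uniform value; every reader sees it until the next step
        self.value = self._rng.uniform()
        return self


class RandomFlipFromSource:
    """Flips when the shared value is below `prob`; never advances the source itself."""

    def __init__(self, source, prob=0.5):
        self.source = source
        self.prob = prob

    def __call__(self, x):
        return x[::-1] if self.source.value < self.prob else x


def run(pipeline, x):
    for t in pipeline:
        x = t(x)
    return x


source = SharedRandomSource(seed=1)
image_pipeline = [lambda x: x * 2.0, RandomFlipFromSource(source)]  # extra stage on images
seg_pipeline = [RandomFlipFromSource(source)]

results = []
for _ in range(4):
    source.step()  # explicit update, once per sample, outside either pipeline
    img = run(image_pipeline, np.arange(5.0))
    seg = run(seg_pipeline, np.arange(5))
    results.append((img, seg))
```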

Composable affine transforms and deformations

Graphics engines generally rely on the idea that a series of affine transforms can be represented as a single homogeneous matrix that accumulates the operations applied in a specific order. This is typically known as a transform stack: operations can be pushed onto and popped off the stack, and the result is the compound matrix.

Applying individual transforms (flips, rotations, etc.) as discrete operations results in unnecessary computation, and every step that interpolates can also degrade the image. Instead, the compound transform could be built first and applied once, in its final form, to the original image.
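The transform stack can be illustrated with 2D homogeneous coordinates in NumPy. This sketch composes a flip, a rotation and a translation into one compound matrix and checks that applying it once gives the same coordinates as applying the three operations in sequence; the helper names are illustrative.

```python
import numpy as np


def rotation(theta):
    # 2D rotation about the origin, in homogeneous coordinates
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def translation(tx, ty):
    return np.array([[1.0, 0.0, tx], [0.0, 1.0, ty], [0.0, 0.0, 1.0]])


def flip_x():
    return np.diag([-1.0, 1.0, 1.0])


# Operations pushed in the order they are applied: flip, then rotate, then translate.
stack = [flip_x(), rotation(np.pi / 6), translation(3.0, -2.0)]

compound = np.eye(3)
for m in stack:
    compound = m @ compound  # later operations multiply on the left

points = np.array([[1.0, 2.0, 1.0], [0.0, -1.0, 1.0]]).T  # columns are (x, y, 1)

# Applying the operations one by one...
stepwise = points
for m in stack:
    stepwise = m @ stepwise

# ...gives the same coordinates as applying the single compound matrix once.
once = compound @ points
```

When resampling an image, the compound matrix (or, for pull-based resampling, its inverse, which maps output voxels back into the input) would be used once, so only a single interpolation is performed.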

Deformations are different; they cannot be applied via a single matrix. Would such a design extend to deformations?

Data sources

The main data format used in medical imaging is NIfTI. Key support should be provided for loading NIfTI files and, in particular, for all of the complications that arise around the q-form and s-form fields, and NIfTI's output through VTK. NiftyNet's reader handles some of these issues, but as far as the author of this page is aware, no reader exists that handles all of them perfectly.
