
PyTorch models

Last edited by esgomezm on Jun 20, 2023.

PyTorch models in deepImageJ

The easiest way to get a deepImageJ-compatible PyTorch model is to use the bioimage.io core library from the BioImage Model Zoo. Follow the example notebook to see how.

More technically, deepImageJ loads PyTorch models through a third-party library, the Java Deep Learning Library (JDLL). JDLL runs PyTorch models through the PyTorch C++ API (LibTorch), so the models must be saved in TorchScript format. This adds little complexity to the Python code: once the model is in evaluation mode, exporting it takes only two extra lines, one to trace the model and one to save it:

import torch
import torchvision

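# An instance of your model (here a pretrained ResNet-18; the `weights` argument
# requires torchvision >= 0.13, older versions use `pretrained=True` instead)
model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)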

# Switch the model to evaluation mode (disables dropout, uses stored batch-norm statistics)
model.eval()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

# Save the TorchScript model
traced_script_module.save("traced_resnet_model.pt")
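Before sharing a traced model, it can be worth checking that the saved TorchScript file reproduces the original model's outputs. The sketch below is a suggestion, not a deepImageJ or JDLL requirement: `TinyNet` and `export_and_check` are hypothetical names, and a small untrained network replaces ResNet-18 so nothing has to be downloaded. It traces the model, saves it, reloads it on the CPU with `torch.jit.load`, and compares outputs on a fresh input.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for your own model (no pretrained weights needed)."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu(self.conv(x))


def export_and_check(model: nn.Module, example: torch.Tensor, path: str) -> bool:
    """Trace `model`, save it as TorchScript, reload it and compare outputs."""
    model.eval()
    with torch.no_grad():
        traced = torch.jit.trace(model, example)
        traced.save(path)
        # Reload the file on the CPU, as a consumer of the .pt file would
        reloaded = torch.jit.load(path, map_location="cpu")
        # Use a different input than the tracing example to check the export
        test_input = torch.rand_like(example)
        expected = model(test_input)
        actual = reloaded(test_input)
    return torch.allclose(expected, actual, atol=1e-5)


if __name__ == "__main__":
    torch.manual_seed(0)
    with tempfile.TemporaryDirectory() as tmp:
        ok = export_and_check(TinyNet(), torch.rand(1, 3, 64, 64), os.path.join(tmp, "tiny.pt"))
    print("TorchScript output matches eager model:", ok)
```

Keep in mind that tracing only records the operations executed for the example input, so data-dependent control flow (for example an `if` on tensor values) is not captured; for such models `torch.jit.script` is the usual alternative way to produce a TorchScript module.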

The PyTorch versions supported by deepImageJ are determined by JDLL.
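Because version support depends on JDLL, it helps to know exactly which PyTorch version a model was exported with, so it can be compared against the versions listed in the JDLL documentation (that list is not reproduced here). The helper below, `torch_major_minor`, is a hypothetical convenience function that reads `torch.__version__` and drops local build suffixes such as `+cu118`.

```python
from typing import Tuple

import torch


def torch_major_minor(version: str = torch.__version__) -> Tuple[int, int]:
    """Return (major, minor) from a PyTorch version string such as "2.0.1+cu118"."""
    # Drop any local build suffix (e.g. "+cu118" or "+cpu"), then keep the first two fields
    major, minor = version.split("+")[0].split(".")[:2]
    return int(major), int(minor)


if __name__ == "__main__":
    print("Exporting with PyTorch", torch.__version__, "->", torch_major_minor())
```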

On Windows, JDLL requires the Microsoft Visual C++ Redistributable for Visual Studio 2019 to be installed.