The PyTorch-Neuron trace Python API provides a method to generate PyTorch models for execution on Inferentia, which can be serialized as TorchScript. It is analogous to the torch.jit.trace function in PyTorch.
```python
import torch
import torch_neuron  # importing torch_neuron makes torch.neuron available

torch.neuron.trace(model, example_inputs, compiler_args=compiler_args)
```
The torch.neuron.trace method sends Neuron-supported operations to the Neuron Compiler (neuron-cc) for compilation and embeds the compiled artifacts in a TorchScript graph.
Compilation can be done on any EC2 instance with sufficient memory and compute resources; c5.4xlarge or larger is recommended.
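A quick preflight check of the build host can be sketched with the standard library. This is a hypothetical helper and is not part of torch_neuron. It assumes the c5.4xlarge baseline of 16 vCPUs and 32 GiB of memory. It reads physical memory through os.sysconf, which works on Linux; on other platforms the memory is reported as unknown.

```python
import os

# Assumed baseline: a c5.4xlarge instance has 16 vCPUs and 32 GiB of memory.
RECOMMENDED_VCPUS = 16
RECOMMENDED_MEM_GIB = 32
# The OS usually reports slightly less than the nominal memory size,
# so accept anything within 10% of the baseline.
MEM_SLACK = 0.9


def host_resources():
    """Return (vcpus, mem_gib) for this host; mem_gib is None if unknown."""
    vcpus = os.cpu_count() or 0
    try:
        # Physical memory = page size * number of physical pages (Linux).
        mem_bytes = os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES")
        mem_gib = mem_bytes / 2**30
    except (AttributeError, ValueError, OSError):
        # os.sysconf or these names are unavailable on this platform.
        mem_gib = None
    return vcpus, mem_gib


def meets_compile_recommendation(vcpus, mem_gib):
    """True if the host is at least as large as the assumed baseline."""
    if mem_gib is None:
        return False  # unknown memory: treat as not meeting the baseline
    return vcpus >= RECOMMENDED_VCPUS and mem_gib >= RECOMMENDED_MEM_GIB * MEM_SLACK


vcpus, mem_gib = host_resources()
if not meets_compile_recommendation(vcpus, mem_gib):
    print(f"Warning: this host ({vcpus} vCPUs) may be too small for compilation")
```

The check only prints a warning because a smaller host may still compile small models. It does not abort.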
The compiled graph can be saved with the torch.jit.save function and restored with the torch.jit.load function for inference on Inf1 instances. During inference, the previously compiled artifacts are loaded into the Neuron Runtime for execution.
Options can be passed to the Neuron Compiler through the compiler_args argument of torch.neuron.trace. See Neuron Compiler CLI for more information about compiler options.
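compiler_args is a flat list of strings. Options kept in a dictionary therefore have to be flattened into flag/value pairs first. The helper below is a hypothetical convenience function and is not part of torch_neuron. The only real neuron-cc flag it is shown with is --num-neuroncores, which is the one used in this section.

```python
def to_compiler_args(options):
    """Flatten {"flag-name": value} into ["--flag-name", "value", ...].

    A value of True emits the bare flag; None or False omits the flag.
    """
    args = []
    for name, value in options.items():
        if value is None or value is False:
            continue
        args.append(f"--{name}")
        if value is not True:
            args.append(str(value))
    return args


compiler_args = to_compiler_args({"num-neuroncores": 4})
# compiler_args == ['--num-neuroncores', '4']
```

The resulting list is then passed as compiler_args=compiler_args in the torch.neuron.trace call. Every value is converted to a string because neuron-cc receives these arguments as command-line tokens.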
- model: A Python function or torch.nn.Module that will be run with the example_inputs arguments. Its arguments and return values must be tensors or (possibly nested) tuples that contain tensors. When a module is passed to torch.neuron.trace, only the forward method is run and traced.
- example_inputs: A tuple of example inputs that will be passed to the function while tracing. The resulting trace can be run with inputs of different types and shapes, assuming the traced operations support those types and shapes. example_inputs may also be a single Tensor, in which case it is automatically wrapped in a tuple.
- compiler_args (list of str, optional): neuron-cc compiler arguments. Note that these arguments apply to all subgraphs generated by whitelist partitioning. For example, use compiler_args=['--num-neuroncores', '4'] to set the number of NeuronCores per subgraph to 4. See Neuron Compiler CLI for more information about compiler options.
- compiler_timeout (int, optional): Timeout in seconds to wait for neuron-cc to complete. Exceeding the timeout raises subprocess.TimeoutExpired.
- compiler_workdir (path-like, optional): Work directory used by neuron-cc. Useful for debugging and/or inspecting neuron-cc logs and IRs.
- check_trace (bool, optional): Check whether the same inputs run through the traced code produce the same outputs. Default: True. You might want to disable this if, for example, your network contains non-deterministic ops or if you are sure that the network is correct despite a checker failure.
- check_inputs (list of tuples, optional): A list of tuples of input arguments that should be used to check the trace against what is expected. Each tuple is equivalent to a set of input arguments that would be specified in example_inputs. For best results, pass in a set of checking inputs representative of the space of shapes and types of inputs you expect the network to see. If not specified, the original example_inputs are used for checking.
- check_tolerance (float, optional): Floating-point comparison tolerance to use in the checker procedure. This can be used to relax the checker strictness in the event that results diverge numerically for a known reason, such as operator fusion.
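The check_trace, check_inputs, and check_tolerance parameters work together. Each set of check inputs is fed to both the original model and the traced result, and the two outputs are compared within a tolerance. The sketch below models that idea with plain Python floats standing in for tensors. run_trace_check and as_input_tuple are hypothetical names. The real checker compares tensors and differs in detail, so treat this only as a simplified illustration.

```python
import math


def as_input_tuple(inputs):
    """Wrap a single non-tuple input in a 1-tuple, as described for example_inputs."""
    return inputs if isinstance(inputs, tuple) else (inputs,)


def run_trace_check(original, traced, check_inputs, check_tolerance=1e-5):
    """Return a list of (args, expected, actual) for every mismatching case."""
    failures = []
    for inputs in check_inputs:
        args = as_input_tuple(inputs)
        expected = original(*args)
        actual = traced(*args)
        # Relative and absolute tolerance, in the spirit of check_tolerance.
        if not math.isclose(expected, actual,
                            rel_tol=check_tolerance, abs_tol=check_tolerance):
            failures.append((args, expected, actual))
    return failures


def foo(x, y):
    return 2 * x + y


def traced_foo(x, y):
    # Stand-in for a compiled graph whose arithmetic differs slightly.
    return 2 * x + y + 1e-4


check_inputs = [(1.0, 2.0), (3.5, -1.0)]
strict = run_trace_check(foo, traced_foo, check_inputs, check_tolerance=1e-6)
relaxed = run_trace_check(foo, traced_foo, check_inputs, check_tolerance=1e-3)
```

Here strict reports both cases as mismatches because the stand-in output is off by 1e-4, while relaxed is empty. A known numerical divergence like this is the situation check_tolerance is meant to accommodate.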
The return value depends on model:
- If model is an nn.Module or is the forward method of an nn.Module:
  - If the module is in evaluation mode (has property training==False), trace returns a ScriptModule object with a single forward method containing the traced code.
  - Otherwise, trace returns the input argument model as-is.
- If model is a standalone function, trace returns a torch._C.Function.
In the traced cases, the returned object is the model with the compiled artifacts embedded.
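The return rules can be restated as a small decision function. This is a hypothetical, duck-typed illustration. It treats any object with training and forward attributes as a stand-in for nn.Module, and it returns a descriptive label instead of calling torch.neuron.trace.

```python
import inspect


def describe_trace_return(model):
    """Label what torch.neuron.trace returns for `model`, per the return rules."""
    owner = None
    if inspect.ismethod(model) and model.__name__ == "forward":
        owner = model.__self__  # the forward method of a module
    elif hasattr(model, "training") and hasattr(model, "forward"):
        owner = model  # the module itself
    if owner is not None:
        if not owner.training:
            return "ScriptModule with a single forward method"
        return "input returned as-is"
    if inspect.isfunction(model):
        return "torch._C.Function"
    raise TypeError("expected a module, its forward method, or a function")


class FakeModule:
    """Stand-in for nn.Module exposing only what the rules inspect."""

    def __init__(self, training):
        self.training = training

    def forward(self, x):
        return x


print(describe_trace_return(FakeModule(training=False)))
# prints: ScriptModule with a single forward method
```

A module left in training mode falls through to the as-is case. Forgetting to call model.eval() before tracing is therefore an easy way to end up with an uncompiled model.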
Example (tracing a function):

```python
import torch
import torch_neuron

def foo(x, y):
    return 2 * x + y

# Run `foo` with the provided inputs and record the tensor operations
traced_foo = torch.neuron.trace(foo, (torch.rand(3), torch.rand(3)))

# `traced_foo` can now be run with the TorchScript interpreter or saved
# and loaded in a Python-free environment
```
Example (tracing an existing module):

```python
import torch
import torch_neuron
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

n = Net()
n.eval()  # evaluation mode, so trace returns a ScriptModule
example_forward_input = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.neuron.trace(n.forward, example_forward_input)

# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.neuron.trace(n, example_forward_input)
```
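The following is an example usage of the compilation Python API, with default compilation arguments, using a pretrained torch.nn.Module (in this case, ResNet-50 from torchvision):

```python
import torch
import torch_neuron
from torchvision import models

# Load a pretrained ResNet-50 and switch it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()

# Example input: a batch of one 3-channel 224x224 image
image = torch.zeros([1, 3, 224, 224], dtype=torch.float32)

# Compile with default arguments, then save the TorchScript model
model_neuron = torch.neuron.trace(model, example_inputs=[image])
model_neuron.save("resnet50_neuron.pt")
```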