scatterElements

Description

The scatterElements plugin implements the scatter operation described in pytorch_scatter (https://github.com/rusty1s/pytorch_scatter), in compliance with the ONNX specification for ScatterElements.

Note: ScatterElements with reduction="none" is implemented in TensorRT core, not in this plugin.
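To make the reduction semantics concrete, here is a minimal pure-Python sketch of what ScatterElements computes, restricted to 2-D nested lists for readability. That restriction is an assumption of the sketch; the plugin itself accepts data of any rank r >= 1 and runs on the GPU. The names `scatter_elements_2d` and `REDUCTIONS` are hypothetical and are not part of the plugin's API. Following the ONNX definition, for axis=0 the update at position (i, j) is combined into output[indices[i][j]][j], and for axis=1 into output[i][indices[i][j]].

```python
import copy
import operator

# Hypothetical mapping from the `reduction` parameter to a binary function.
REDUCTIONS = {
    "add": operator.add,
    "mul": operator.mul,
    "max": max,
    "min": min,
}


def scatter_elements_2d(data, indices, updates, axis=0, reduction="add"):
    """Illustrative sketch of ScatterElements on 2-D nested lists (not the plugin)."""
    if axis < 0:
        axis += 2  # this sketch assumes rank(data) == 2
    if axis not in (0, 1):
        raise ValueError("axis must be in [-2, 1] for 2-D data")
    reduce_fn = REDUCTIONS[reduction]
    output = copy.deepcopy(data)  # output has the same shape as data
    size = len(data) if axis == 0 else len(data[0])  # size s of the scatter axis
    for i, row in enumerate(indices):
        for j, k in enumerate(row):
            if not -size <= k <= size - 1:
                raise IndexError(f"index {k} is out of bounds for axis of size {size}")
            if k < 0:
                k += size  # negative indices count from the end of the axis
            # The index replaces the coordinate along `axis`; the other one is kept.
            r, c = (k, j) if axis == 0 else (i, k)
            output[r][c] = reduce_fn(output[r][c], updates[i][j])
    return output
```

For example, `scatter_elements_2d([[1, 2, 3, 4, 5]], [[1, 3]], [[10, 20]], axis=1, reduction="add")` adds 10 and 20 at columns 1 and 3, giving `[[1, 12, 3, 24, 5]]`. When several updates target the same element, this sketch combines them one after another in row-major order; it makes no claim about the order the plugin uses on the GPU.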

Structure

This plugin has the plugin creator class ScatterElementsPluginCreator and the plugin class ScatterElementsPlugin, which extends IPluginV2DynamicExt.

The ScatterElements plugin consumes the following inputs:

  1. data - T: Tensor of rank r >= 1.
  2. indices - Tind: Tensor of int64 indices, of rank r >= 1 (same rank as data). All index values are expected to be within bounds [-s, s-1] along an axis of size s. It is an error if any of the index values are out of bounds.
  3. updates - T: Tensor of rank r >= 1 (same rank and shape as indices).
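The input constraints above lend themselves to an up-front check. The Python sketch below assumes shapes are given as tuples of ints; `check_input_shapes` and `normalize_index` are hypothetical helpers, not part of the plugin. It also shows how an index in [-s, s-1] maps onto [0, s-1]: a negative index i refers to position i + s.

```python
def check_input_shapes(data_shape, indices_shape, updates_shape):
    """Raise ValueError unless the shapes meet the documented input constraints."""
    r = len(data_shape)
    if r < 1:
        raise ValueError("data must have rank r >= 1")
    if len(indices_shape) != r:
        raise ValueError("indices must have the same rank as data")
    if tuple(updates_shape) != tuple(indices_shape):
        raise ValueError("updates must have the same rank and shape as indices")


def normalize_index(idx, s):
    """Map an index in [-s, s-1] onto [0, s-1]; out-of-bounds values are an error."""
    if not -s <= idx <= s - 1:
        raise IndexError(f"index {idx} is out of bounds [{-s}, {s - 1}]")
    return idx + s if idx < 0 else idx
```

For example, along an axis of size 4 the valid indices are -4 through 3, and `normalize_index(-1, 4)` addresses the last element, position 3.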

The ScatterElements plugin produces the following output:

  1. output - T: Tensor, same shape as data.

Parameters

The ScatterElements plugin has the following parameters:

| Type | Parameter | Description |
|------|-----------|-------------|
| int | axis | Which axis to scatter on. Default is 0. A negative value counts dimensions from the back. Accepted range is [-r, r-1], where r = rank(data). |
| char | reduction | Type of reduction to apply: add, mul, max, or min. ‘add’: reduction using the addition operation. ‘mul’: reduction using the multiplication operation. ‘max’: reduction using the maximum operation. ‘min’: reduction using the minimum operation. |
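Before launching, the two parameters might be validated and normalized as in the Python sketch below. The function name `normalize_params` and the `SUPPORTED_REDUCTIONS` tuple are hypothetical. The sketch rejects "none", on the assumption that, as the Description notes, that case is handled by TensorRT core rather than by this plugin.

```python
# Reductions this plugin handles, per the parameter table above.
SUPPORTED_REDUCTIONS = ("add", "mul", "max", "min")


def normalize_params(axis, reduction, r):
    """Return (axis mapped into [0, r-1], reduction), or raise ValueError."""
    if not -r <= axis <= r - 1:
        raise ValueError(f"axis {axis} is outside the accepted range [{-r}, {r - 1}]")
    if reduction not in SUPPORTED_REDUCTIONS:
        raise ValueError(f"unsupported reduction {reduction!r}")
    # A negative axis counts dimensions from the back.
    return (axis + r if axis < 0 else axis), reduction
```

For rank-3 data, `normalize_params(-1, "add", 3)` yields `(2, "add")`, that is, the last dimension.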

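The following resources provide a deeper understanding of the scatterElements plugin:

  • pytorch_scatter (https://github.com/rusty1s/pytorch_scatter)
  • ONNX ScatterElements operator specification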

License

For terms and conditions for use, reproduction, and distribution, see the TensorRT Software License Agreement documentation.

Changelog

Oct 2023: This is the first release of this README.md file.

Known issues

  • Types T=BFLOAT16 and T=INT8 are currently not supported.
  • The ONNX specification also allows Tind=INT32; only Tind=INT64 is supported by this plugin.