Add rudimentary support for "arbitrary" dimensions in MultiThreshold
This allows node execution of MultiThreshold operators with an arbitrary
number of dimensions, as long as the channel dimension is last. This is
necessary to run some verification steps of attention operators which,
at least for some intermediate steps, use 3-dimensional data layouts.

This does not change the behavior of execution on the already existing
2d and 4d data layouts.
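For illustration, here is a minimal standalone NumPy sketch of the reshape-and-restore idea used in this commit. naive_multithreshold is a hypothetical, simplified stand-in for qonnx's multithreshold kernel, and all shapes are made up for the example:

import numpy as np

def naive_multithreshold(v, thresholds):
    # Hypothetical stand-in for the real multithreshold kernel: for each value
    # in NC layout, count how many of its channel's thresholds it reaches.
    # v: (N, C), thresholds: (C, num_thresholds)
    return np.sum(v[:, :, np.newaxis] >= thresholds[np.newaxis, :, :], axis=-1).astype(np.float32)

# 3d activation as it appears in attention intermediates: (batch, seq_len, channels)
x = np.random.randn(2, 5, 4).astype(np.float32)
thresholds = np.sort(np.random.randn(4, 3), axis=-1).astype(np.float32)

# Flatten all leading dimensions into N, keeping the (assumed last) channel dim C
orig_shape = x.shape
x_nc = x.reshape((-1, x.shape[-1]))  # (2 * 5, 4)

out = naive_multithreshold(x_nc, thresholds)

# Restore the original layout afterwards
out = out.reshape(orig_shape)
assert out.shape == (2, 5, 4)

The actual node execution below additionally applies out_scale and out_bias when calling multithreshold.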
iksnagreb committed Dec 13, 2023
1 parent cadd6b2 commit 81a0744
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions src/qonnx/custom_op/general/multithreshold.py
@@ -133,6 +133,23 @@ def execute_node(self, context, graph):
                pass
            else:
                raise Exception("Unknown data_layout and input ndim" " combination for MultiThreshold.")

        # Remember whether the shape has been modified to handle 1d or 3d data
        # layouts
        orig_shape = None
        # If the input tensor has dimensions not covered by the NC or NCHW data
        # layouts, the shape needs to be adapted such that it can be handled by
        # multithreshold.
        # TODO: Seems like a rather sketchy solution to support arbitrary data
        # layouts. This does not even validate the assumption of channel-last
        # layout.
        if v.ndim not in {2, 4}:
            # Remember the original shape to be restored later
            orig_shape = v.shape
            # Assume last dimension to be the channel dimension C and reshape
            # into NC layout which is supported by multithreshold
            v = v.reshape((-1, v.shape[-1]))

        # calculate output
        output = multithreshold(v, thresholds, out_scale, out_bias)
        # setting context according to output
@@ -145,6 +162,13 @@ def execute_node(self, context, graph):
                pass
            else:
                raise Exception("Unknown data_layout and output ndim" " combination for MultiThreshold.")

        # If the shape has been modified to support arbitrary layouts, restore
        # the original shape
        # TODO: Part of the rather sketchy solution above.
        if orig_shape is not None:
            output = output.reshape(orig_shape)

        context[node.output[0]] = output

def verify_node(self):