-
Notifications
You must be signed in to change notification settings - Fork 48
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
onnx.Unique #647
Comments
Hello @renxida Could you please assign this issue to me? Also, I have some questions to ask you
|
That sounds like a plan! Feel free to reach me on Discord (@xida_ren) if you have questions and just want to chat. Also ping me on Discord if you need e.g. a code review or a CI approval |
Thank you for your kind reply. ❤️ |
Hi @Peefy, are you still working on this op? |
Hello @vivekkhandelwal1. Yes, here's my code. I have almost completed the case where onnx.Unique's sorted attribute is true, but I encountered some other difficulties. When onnx.Unique's sorted attribute is false, it seems that torch has no order-preserving (non-sorted) deduplication function. Therefore, I am thinking about how to use patterns.onOp(
"Unique", 11,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
  // Lowers onnx.Unique by composing torch.sort and
  // torch.unique_consecutive, following this Python sketch:
  //
  //   def onnx_unique(x, dim=0):
  //       unique, inverse, counts = torch_unique(x, dim)
  //       _, ind_sorted = torch.sort(inverse, stable=True)
  //       cum_sum = counts.cumsum(0)
  //       cum_sum = torch.cat((torch.tensor([0]), cum_sum[:-1]))
  //       indices = ind_sorted[cum_sum]
  //       return unique, indices, inverse, counts
  //
  //   def torch_unique(tensor, dim):
  //       sorted_tensor, _ = torch.sort(tensor)
  //       return torch.unique_consecutive(sorted_tensor,
  //                                       return_inverse=True,
  //                                       return_counts=True, dim=dim)
  //
  // Only sorted=1 (the ONNX default) is supported: torch has no
  // order-preserving (non-sorted) deduplication primitive, so the
  // sorted=0 case is rejected below with a match failure.
  //
  // References:
  //   onnx.Unique: https://onnx.ai/onnx/operators/onnx__Unique.html
  //   torch.unique: https://pytorch.org/docs/stable/generated/torch.unique.html
  //   torch.unique_consecutive:
  //     https://pytorch.org/docs/stable/generated/torch.unique_consecutive.html
  Torch::ValueTensorType outputType, indicesType, inverseIndicesType,
      countsType;
  Value input;
  // The axis attribute may be negative; the accepted range is
  // [-r, r-1] where r = rank(input).
  int64_t axis;
  // The default value of the sorted attribute is 1.
  bool sorted;
  if (binder.tensorOperand(input) ||
      binder.s64BoolAttr(sorted, "sorted", true) ||
      binder.s64IntegerAttr(axis, "axis", 0) ||
      binder.tensorResultTypeAtIndex(outputType, 0) ||
      binder.tensorResultTypeAtIndex(indicesType, 1) ||
      binder.tensorResultTypeAtIndex(inverseIndicesType, 2) ||
      binder.tensorResultTypeAtIndex(countsType, 3))
    return failure();
  std::optional<unsigned> maybeRank = Torch::getTensorRank(input);
  if (!maybeRank)
    return rewriter.notifyMatchFailure(binder.op,
                                       "Unimplemented: unranked tensor");
  if (!sorted)
    return rewriter.notifyMatchFailure(
        binder.op, "Unimplemented: torch.unique is not yet supported for "
                   "situations where sorted is false");
  unsigned rank = *maybeRank;
  axis = Torch::toPositiveDim(axis, rank);
  auto loc = binder.getLoc();
  // torch_unique: ascending sort followed by unique_consecutive.
  // Returns (unique values, inverse indices, counts).
  auto torchUnique =
      [&](Value tensor, Value dim) -> std::tuple<Value, Value, Value> {
    // descending=false gives an ascending sort, as required by
    // onnx.Unique with sorted=1. (The original draft passed `true`
    // into a constant named cstFalse, producing a descending sort.)
    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
    // NOTE(review): the sort is performed along dim 0; for axis != 0
    // this probably needs to use `dim` instead — confirm against the
    // ONNX spec before enabling non-zero axes.
    Value zero = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getType<Torch::IntType>(),
        rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
    auto sortResult = rewriter.create<Torch::AtenSortOp>(
        loc, tensor.getType(), indicesType, tensor, zero, cstFalse);
    // Result 0 is the sorted tensor; result 1 is the sort indices.
    Value sortedTensor = sortResult->getResult(0);
    Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
    auto uniqueResult = rewriter.create<Torch::AtenUniqueConsecutiveOp>(
        loc, outputType, indicesType, countsType, sortedTensor,
        /*return_inverse=*/cstTrue, /*return_counts=*/cstTrue, dim);
    return std::make_tuple(uniqueResult->getResult(0),
                           uniqueResult->getResult(1),
                           uniqueResult->getResult(2));
  };
  // onnx_unique: recover the first-occurrence indices from the inverse
  // mapping and the counts, then assemble the four onnx.Unique results
  // (Y, indices, inverse_indices, counts). Returns an owning vector:
  // the original draft returned a ValueRange over a temporary, which
  // dangles as soon as the lambda returns.
  auto onnxUnique = [&](Value unique, Value inverse,
                        Value counts) -> SmallVector<Value, 4> {
    // descending=false: ascending (stable) sort of the inverse mapping.
    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
    Value zero = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getType<Torch::IntType>(),
        rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
    Value one = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(1));
    Value minusOne = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getType<Torch::IntType>(),
        rewriter.getIntegerAttr(rewriter.getIntegerType(64), -1));
    // _, ind_sorted = torch.sort(inverse, stable=True)
    auto sortResult = rewriter.create<Torch::AtenSortOp>(
        loc, inverse.getType(), indicesType, inverse, zero, cstFalse);
    Value indSorted = sortResult->getResult(1);
    Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
    // cum_sum = counts.cumsum(0)
    Value cumSum = rewriter.create<Torch::AtenCumsumOp>(
        loc, counts.getType(), counts, zero, none);
    // cum_sum[:-1]
    Value cumSumSlice = rewriter.create<Torch::AtenSliceTensorOp>(
        loc, indicesType, cumSum, zero, zero, minusOne, one);
    // torch.tensor([0]) via aten.zeros with size list [1].
    Value sizeOne = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getType<Torch::IntType>(),
        rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));
    Value sizeList = rewriter.create<Torch::PrimListConstructOp>(
        loc,
        Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
        SmallVector<Value>{sizeOne});
    Value zeros = rewriter.create<Torch::AtenZerosOp>(
        loc, counts.getType(), sizeList, none, none, none, none);
    // cum_sum = torch.cat((torch.tensor([0]), cum_sum[:-1]))
    Type listElemType =
        zeros.getType()
            .cast<Torch::BaseTensorType>()
            .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
                                  /*optionalDtype=*/nullptr);
    Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
        loc, Torch::ListType::get(listElemType),
        SmallVector<Value>{zeros, cumSumSlice});
    Value catted = rewriter.create<Torch::AtenCatOp>(loc, indicesType,
                                                     tensorList, zero);
    // indices = ind_sorted[cum_sum]: gather the first occurrence of
    // each unique value. onnx.Unique's second output is a tensor, so
    // the index_select result is used directly (the original draft
    // collapsed it to a scalar with aten.item, which cannot satisfy
    // the tensor result type).
    Value indices = rewriter.create<Torch::AtenIndexSelectOp>(
        loc, indicesType, indSorted, zero, catted);
    return SmallVector<Value, 4>{unique, indices, inverse, counts};
  };
  rewriter.replaceOp(
      binder.op,
      std::apply(onnxUnique,
                 torchUnique(input, rewriter.create<Torch::ConstantIntOp>(
                                        loc, rewriter.getI64IntegerAttr(
                                                 axis)))));
  return success();
});
Hi @Peefy, as of now you can add a limited support for the op, and then extend it later. Also, it will be better if you create a WIP pr for this to be reviewed. |
Hello @vivekkhandelwal1 Sorry, I may not have much time to complete this recently. Please un-assign me. |
Assigning this to @vinayakdsci |
Tracking Issue: #215
The text was updated successfully, but these errors were encountered: