Skip to content

Commit

Permalink
Set stickified attr as mandatory attr.
Browse files Browse the repository at this point in the history
Signed-off-by: Haruki Imai <[email protected]>
  • Loading branch information
imaihal committed Oct 30, 2024
1 parent b5415c3 commit f0b92f0
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 11 deletions.
4 changes: 2 additions & 2 deletions src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ Value insertAllocOrEmitZeroConstant(ArrayRef<IndexExpr> dims,
FloatAttr floatZero = rewriter.getFloatAttr(resType.getElementType(), 0.0);
ZHighStickifiedConstantOp stickifiedConstant = rewriter.create<
ZHighStickifiedConstantOp>(loc, resType,
/*value=*/SplatElementsAttr::get(cast<ShapedType>(resType), floatZero),
/*stickified=*/rewriter.getBoolAttr(true),
/*value=*/SplatElementsAttr::get(cast<ShapedType>(resType), floatZero),
/*alignment=*/rewriter.getI64IntegerAttr(4096));
res = stickifiedConstant.getResult();
} else {
Expand Down Expand Up @@ -708,10 +708,10 @@ struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
/*name=*/
rewriter.getStringAttr(
"constant_stickify_" + std::to_string(constantID)),
/*stickified=*/zhighStickifiedConstOp.getStickifiedAttr(),
/*value=*/zhighStickifiedConstOp.getValueAttr(),
/*layout=*/layout,
/*offset=*/rewriter.getI64IntegerAttr(0),
/*stickified=*/zhighStickifiedConstOp.getStickifiedAttr(),
/*alignment=*/zhighStickifiedConstOp.getAlignmentAttr());

// Increment constant ID:
Expand Down
4 changes: 2 additions & 2 deletions src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td
Original file line number Diff line number Diff line change
Expand Up @@ -868,8 +868,8 @@ def ZHighStickifiedConstantOp:ZHigh_Op<"StickifiedConstant", [Pure]> {
the stickified data must make sure its size in bytes consistent with
the output tensor's size.
}];
let arguments = (ins OptionalAttr<AnyAttr>:$value,
OptionalAttr<BoolAttr>:$stickified,
let arguments = (ins BoolAttr:$stickified,
OptionalAttr<AnyAttr>:$value,
DefaultValuedAttr<I64Attr, "4096">:$alignment);
let results = (outs AnyZTensor:$output);
}
Expand Down
2 changes: 1 addition & 1 deletion src/Accelerators/NNPA/Dialect/ZLow/ZLow.td
Original file line number Diff line number Diff line change
Expand Up @@ -556,10 +556,10 @@ def ZLowStickifiedConstantOp:ZLow_Op<"stickifiedConstant", [MemRefsNormalizable,
}];
let arguments = (ins AnyAttr:$shape,
StrAttr:$name,
BoolAttr:$stickified,
OptionalAttr<AnyAttr>:$value,
OptionalAttr<StrAttr>:$layout,
OptionalAttr<I64Attr>:$offset,
OptionalAttr<BoolAttr>:$stickified,
DefaultValuedAttr<I64Attr, "4096">:$alignment);
let results = (outs ZMemRef:$output);
}
Expand Down
8 changes: 4 additions & 4 deletions src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,10 +372,10 @@ ArrayRef<char> ZLowStickifiedConstantOp::getBuffer() {
MLIRContext *context = getOperation()->getContext();
PatternRewriter rewriter(context);
ArrayRef<char> ret;
if (getValueAttr() && getStickifiedAttr()) {
if (getValueAttr()) {
StringAttr layout = getLayoutAttr();
auto dataAttr = getValue().value();
if (!getStickified().value()) {
if (!getStickified()) {
// The case which the data in value attribute is still not stickified.
DenseElementsAttr denseAttr = mlir::cast<DenseElementsAttr>(dataAttr);
ArrayRef<int64_t> shape = denseAttr.getType().getShape();
Expand Down Expand Up @@ -445,8 +445,8 @@ void ZLowStickifiedConstantOp::updateValueAttr() {
PatternRewriter rewriter(context);
// Set buffer when the value attribute is still not stickified or is splat
// with dense element attribute.
if (getValueAttr() && getStickifiedAttr()) {
bool isStickified = getStickified().value();
if (getValueAttr()) {
bool isStickified = getStickified();
bool isSplat = false;
if (auto denseAttr = mlir::dyn_cast<DenseElementsAttr>(getValue().value()))
isSplat = denseAttr.isSplat();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ ZHighStickifiedConstantOp emitZHighStickifiedConstant(PatternRewriter &rewriter,
// Create a ZHighStickifiedConstantOp.
ZHighStickifiedConstantOp stickifiedConstant =
rewriter.create<ZHighStickifiedConstantOp>(loc, outputType,
/*value=*/nullptr,
/*stickified=*/rewriter.getBoolAttr(true),
/*value=*/nullptr,
/*alignment=*/rewriter.getI64IntegerAttr(4096));

// Use an dense resource attribute to store stickified data.
Expand Down Expand Up @@ -74,8 +74,8 @@ ZHighStickifiedConstantOp createConstantForStick(PatternRewriter &rewriter,

ZHighStickifiedConstantOp constantOp =
rewriter.create<ZHighStickifiedConstantOp>(loc, replacingValue.getType(),
/*value=*/dataAttr,
/*stickified=*/rewriter.getBoolAttr(false),
/*value=*/dataAttr,
/*alignment=*/rewriter.getI64IntegerAttr(4096));

return constantOp;
Expand Down

0 comments on commit f0b92f0

Please sign in to comment.