Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NNPA] Memory reduction of stickified constant by stickifying at file writing #2917

Open
wants to merge 73 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
73 commits
Select commit Hold shift + click to select a range
a589aa2
Add and use ConstantOpInterface for KrnlGlobalOps.
imaihal Aug 5, 2024
76e9dcb
Add and use ConstantOpInterface in lowering of KrnlGlobalOp to LLVMIR.
imaihal Aug 7, 2024
84c13c0
Initial implementation for NNPA.
imaihal Aug 19, 2024
42cb5a1
Update to handle stickifiedConstantOp initialized with zero
imaihal Aug 21, 2024
5ac7b1f
Update to free memory correctly
imaihal Aug 22, 2024
e8b935d
Clean up
imaihal Aug 23, 2024
82ada4e
Clean up
imaihal Aug 23, 2024
5dff7cf
Merge branch 'main' into mem_reduction_stickified
imaihal Aug 26, 2024
9c5dd88
format
imaihal Aug 26, 2024
2f23ff8
Merge branch 'main' into mem_reduction_stickified
imaihal Aug 26, 2024
7a5fb6d
Fix for lstm and gru.
imaihal Aug 27, 2024
fd82e47
Merge branch 'main' into mem_reduction_stickified
imaihal Aug 27, 2024
748517f
Fix the case totalsize is less than or equal to totalThreshold.
imaihal Aug 28, 2024
4cb46dd
Merge branch 'main' into mem_reduction_stickified
imaihal Aug 29, 2024
434272a
Fix the case without setting --store-constants-to-file option.
imaihal Aug 29, 2024
18b9919
Fix lit tests.
imaihal Aug 29, 2024
b99a334
Fix getBuffersize() for CategoryMapperOp
imaihal Aug 30, 2024
16773e7
Update attributes in ZHigh/ZLowStickfiedConstantOp.
imaihal Aug 30, 2024
9e236c9
Use stickified attribute and zeroconst attribute
imaihal Sep 2, 2024
56bc50d
Attribute name change: zeroconst to allzero
imaihal Sep 3, 2024
4f08bef
Update lit tests.
imaihal Sep 3, 2024
67d0f20
clean up.
imaihal Sep 3, 2024
dbf4c82
Add an option.
imaihal Sep 10, 2024
208020b
Merge branch 'main' into mem_reduction_stickified
imaihal Sep 11, 2024
ac37742
The option is true by default.
imaihal Sep 11, 2024
ad06734
Set the option false by default for testing.
imaihal Sep 12, 2024
5dcc2f7
Revert lit-tests for testing.
imaihal Sep 12, 2024
d18539e
Set false in store-constants-to-file for testing.
imaihal Sep 12, 2024
b5cfaf8
Revert "Set the option false by default for testing."
imaihal Sep 13, 2024
f3f3b68
Revert "Revert lit-tests for testing."
imaihal Sep 13, 2024
6c5569e
Revert "Set false in store-constants-to-file for testing."
imaihal Sep 13, 2024
c62db4b
Fix using OpInterfaceConversionPattern
imaihal Sep 13, 2024
1e4e156
Merge branch 'main' into mem_reduction_stickified
imaihal Sep 17, 2024
53b99c1
Change order of conversion.
imaihal Sep 20, 2024
1d4ed1b
Merge branch 'main' into mem_reduction_stickified
imaihal Sep 20, 2024
5ea61d9
Merge branch 'main' into mem_reduction_stickified
imaihal Sep 24, 2024
695d072
Merge branch 'main' into mem_reduction_stickified
imaihal Sep 25, 2024
fc07ddf
Merge branch 'main' into mem_reduction_stickified
imaihal Sep 27, 2024
e9c6805
Add comments.
imaihal Sep 30, 2024
70abe9c
Merge branch 'main' into mem_reduction_stickified
imaihal Sep 30, 2024
691ec33
Fix layout attribute in zlowStickifiedConstantOp.
imaihal Oct 2, 2024
b06a7b9
Merge branch 'main' into mem_reduction_stickified
imaihal Oct 2, 2024
9a002cc
Update lit tests.
imaihal Oct 2, 2024
232252b
Remove the option.
imaihal Oct 3, 2024
bc604bd
Merge branch 'main' into mem_reduction_stickified
imaihal Oct 3, 2024
5233331
Merge branch 'main' into mem_reduction_stickified
imaihal Oct 4, 2024
c98bdbd
Remove `allzero` attribute.
imaihal Oct 8, 2024
3e5de25
Merge branch 'main' into mem_reduction_stickified
imaihal Oct 8, 2024
f0a187a
Fix lit tests for stickified constant for be.
imaihal Oct 8, 2024
bd065f5
Fix a buf for removing `allzero` attribute.
imaihal Oct 9, 2024
75ad266
Move getRawData() and mlirTypeToZDNNType() to OpHelper.
imaihal Oct 9, 2024
9b9fbc2
Merge branch 'main' into mem_reduction_stickified
imaihal Oct 9, 2024
b19dd84
Fix lit test related to removing all_zero attribute.
imaihal Oct 10, 2024
d6b2100
Format.
imaihal Oct 10, 2024
92e831b
Remove duplicated code for getRawData().
imaihal Oct 10, 2024
1a4cb5d
Fix a bug when removing duplicated code.
imaihal Oct 10, 2024
7af8bdd
Merge branch 'main' into mem_reduction_stickified
imaihal Oct 11, 2024
9af0a2f
Fix list test for ccfd
imaihal Oct 11, 2024
a27fb85
Merge branch 'main' into mem_reduction_stickified
imaihal Oct 11, 2024
8cce1cb
Merge branch 'main' into mem_reduction_stickified
imaihal Oct 15, 2024
ad19bc9
Merge branch 'main' into mem_reduction_stickified
imaihal Oct 21, 2024
d94ef0e
Merge branch 'main' into mem_reduction_stickified
imaihal Oct 24, 2024
5a0edcd
Merge branch 'main' into mem_reduction_stickified
imaihal Oct 29, 2024
89ea256
Change OpInterface name
imaihal Oct 29, 2024
313a785
Fix freeBuffer() for CPU.
imaihal Oct 29, 2024
9b8c662
Format
imaihal Oct 29, 2024
924ab61
Remove redundant code.
imaihal Oct 29, 2024
b5415c3
Remove #pragma
imaihal Oct 29, 2024
f0b92f0
Set stickified attr as mandatory attr.
imaihal Oct 30, 2024
5ee9e77
Update descriptions for the OpInterface.
imaihal Oct 30, 2024
3d167da
Keep original implementation
imaihal Oct 31, 2024
8192993
clean up
imaihal Nov 1, 2024
5147329
Merge branch 'main' into mem_reduction_stickified
imaihal Nov 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 88 additions & 45 deletions src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,27 +190,47 @@ Value insertAllocOrEmitZeroConstant(ArrayRef<IndexExpr> dims,
affine::normalizeMemRefType(mlir::cast<MemRefType>(zMemRefType.value));

// Create a ZHighStickifiedConstantOp.
ZHighStickifiedConstantOp stickifiedConstant =
rewriter.create<ZHighStickifiedConstantOp>(loc, resType,
/*value=*/nullptr,
/*alignment=*/rewriter.getI64IntegerAttr(4096));

// Use an dense resource attribute to store stickified data.
// Attribute type: tensor<sizeInBytes x i8>
int64_t sizeInBytes =
affine::getIntOrFloatMemRefSizeInBytes(resType).value();
char *rawData = static_cast<char *>(malloc(sizeInBytes));
assert(rawData && "failed to allocate memory for stickified data");
memset(rawData, 0, sizeInBytes);
DenseResourceElementsAttr valueAttr = DenseUI8ResourceElementsAttr::get(
RankedTensorType::get({sizeInBytes}, rewriter.getI8Type()),
stickifiedConstant.getOperation()
->getDialect()
->getNamespace(), // use the dialect as the blob "hint"
HeapAsmResourceBlob::allocateAndCopyWithAlign(
llvm::ArrayRef(rawData, sizeInBytes), alignof(char)));
stickifiedConstant.setValueAttr(valueAttr);
free(rawData);

// Keep previous implementation about generating stickified data at
// ZHighConstPropagationPass. To use this, comment in and set directive "
// NNPA_ZHIGH_STICKIFIEDCONST_GEN"
//
// #ifdef NNPA_ZHIGH_STICKIFIEDCONST_GEN
// // Set zero in value attribute as DenseResourceElementsAttribute.
// ZHighStickifiedConstantOp stickifiedConstant =
// rewriter.create<ZHighStickifiedConstantOp>(loc, resType,
// /*stickified=*/rewriter.getBoolAttr(true),
// /*value=*/nullptr,
// /*alignment=*/rewriter.getI64IntegerAttr(4096));
//
// // Use an dense resource attribute to store stickified data.
// // Attribute type: tensor<sizeInBytes x i8>
// int64_t sizeInBytes =
// affine::getIntOrFloatMemRefSizeInBytes(resType).value();
// char *rawData = static_cast<char *>(malloc(sizeInBytes));
// assert(rawData && "failed to allocate memory for stickified data");
// memset(rawData, 0, sizeInBytes);
// DenseResourceElementsAttr valueAttr =
// DenseUI8ResourceElementsAttr::get(
// RankedTensorType::get({sizeInBytes}, rewriter.getI8Type()),
// stickifiedConstant.getOperation()
// ->getDialect()
// ->getNamespace(), // use the dialect as the blob "hint"
// HeapAsmResourceBlob::allocateAndCopyWithAlign(
// llvm::ArrayRef(rawData, sizeInBytes), alignof(char)));
// stickifiedConstant.setValueAttr(valueAttr);
// free(rawData);
// #else

// Set zero in value attribute as SplatElementsAttr.
FloatAttr floatZero = rewriter.getFloatAttr(resType.getElementType(), 0.0);
ZHighStickifiedConstantOp stickifiedConstant = rewriter.create<
ZHighStickifiedConstantOp>(loc, resType,
/*stickified=*/rewriter.getBoolAttr(true),
/*value=*/SplatElementsAttr::get(cast<ShapedType>(resType), floatZero),
/*alignment=*/rewriter.getI64IntegerAttr(4096));

// #endif // NNPA_ZHIGH_STICKIFIEDCONST_GEN

res = stickifiedConstant.getResult();
} else {
Expand Down Expand Up @@ -686,7 +706,7 @@ struct ZHighToZLowUnstickOpLowering : public ConversionPattern {
};

//===----------------------------------------------------------------------===//
// Lower ZHigh Stickified Constant to KrnlGlobal
// Lower ZHigh Stickified Constant to ZLow Stickified Constant
imaihal marked this conversation as resolved.
Show resolved Hide resolved
//===----------------------------------------------------------------------===//

struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
Expand All @@ -699,7 +719,7 @@ struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
Location loc = op->getLoc();
ZHighStickifiedConstantOp stickifiedConstOp =
ZHighStickifiedConstantOp zhighStickifiedConstOp =
llvm::dyn_cast<ZHighStickifiedConstantOp>(op);

// Convert ZTensor type to MemRefType.
Expand All @@ -713,36 +733,59 @@ struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
affine::normalizeMemRefType(mlir::cast<MemRefType>(zMemRefType.value));
ArrayRef<int64_t> normalizedShape = normalizedType.getShape();

// Get dense resource attribute.
auto blob = mlir::cast<DenseResourceElementsAttr>(
stickifiedConstOp.getValue().value())
.getRawHandle()
.getBlob();
assert(blob && "Expecting dense resource with a valid blob");
ArrayRef<char> data = blob->getData();

// Validate the stickified tensor.
int64_t memRefSizeInBytes = getMemRefEltSizeInBytes(normalizedType);
memRefSizeInBytes *= normalizedType.getNumElements();
assert((data.size() == static_cast<uint64_t>(memRefSizeInBytes)) &&
"The stickified tensor's buffer size and MemRef's size mismatched");

// Create a KrnlGlobalOp.
KrnlGlobalOp constantGlobal =
Copy link
Collaborator

@chentong319 chentong319 Oct 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keep the previous implementation with KrnlGlobalOp in comment or if false branch, if you do not want to create an option to control the choice. You can define an option '--disable-krnl-constant-to-file' with default value of 'false'.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. Is this because we may reuse the previous implementation in the future?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I created directive NNPA_ZHIGH_STICKIFIEDCONST_GEN to keep the original implementation. Currently commented out, but I confirmed it works when enabling this code.

rewriter.create<KrnlGlobalOp>(loc, zMemRefType.value,
// Create ZLowStickifiedConstantOp.
StringAttr layout =
getZTensorLayoutAttr(rewriter, *op->result_type_begin());

// Keep previous implementation about generating stickified data at
// ZHighConstPropagationPass. To use this, comment in and set directive "
// NNPA_ZHIGH_STICKIFIEDCONST_GEN"
//
// #ifdef NNPA_ZHIGH_STICKIFIEDCONST_GEN
// // Lower to KrnlGlobalOp
// // Get dense resource attribute.
// auto blob = mlir::cast<DenseResourceElementsAttr>(
// zhighStickifiedConstOp.getValue().value())
// .getRawHandle()
// .getBlob();
// assert(blob && "Expecting dense resource with a valid blob");
// ArrayRef<char> data = blob->getData();
// // Validate the stickified tensor.
// int64_t memRefSizeInBytes = getMemRefEltSizeInBytes(normalizedType);
// memRefSizeInBytes *= normalizedType.getNumElements();
// assert((data.size() == static_cast<uint64_t>(memRefSizeInBytes)) &&
// "The stickified tensor's buffer size and MemRef's size
// mismatched");
// // Create a KrnlGlobalOp.
// KrnlGlobalOp constantOp =
// rewriter.create<KrnlGlobalOp>(loc, zMemRefType.value,
// /*shape=*/
// rewriter.getI64ArrayAttr(normalizedShape),
// /*name=*/
// rewriter.getStringAttr(
// "constant_stickify_" + std::to_string(constantID)),
// /*value=*/zhighStickifiedConstOp.getValueAttr(),
// /*offset=*/nullptr,
// /*alignment=*/zhighStickifiedConstOp.getAlignmentAttr());
// #else
ZLowStickifiedConstantOp constantOp =
rewriter.create<ZLowStickifiedConstantOp>(loc,
mlir::cast<MemRefType>(zMemRefType.value),
/*shape=*/
rewriter.getI64ArrayAttr(normalizedShape),
/*name=*/
rewriter.getStringAttr(
"constant_stickify_" + std::to_string(constantID)),
/*value=*/stickifiedConstOp.getValueAttr(),
/*offset=*/nullptr,
/*alignment=*/stickifiedConstOp.getAlignmentAttr());

/*stickified=*/zhighStickifiedConstOp.getStickifiedAttr(),
/*value=*/zhighStickifiedConstOp.getValueAttr(),
/*layout=*/layout,
/*offset=*/rewriter.getI64IntegerAttr(0),
/*alignment=*/zhighStickifiedConstOp.getAlignmentAttr());
// #endif // NNPA_ZHIGH_STICKIFIEDCONST_GEN
// Increment constant ID:
constantID++;

rewriter.replaceOp(op, constantGlobal.getResult());
rewriter.replaceOp(op, constantOp.getResult());
return success();
}
};
Expand Down
1 change: 1 addition & 0 deletions src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ add_onnx_mlir_library(OMZHighOps
OMONNXOps # Use ONNXShapeHelper
OMLayoutHelper
OMShapeHelperOpInterface
OMStickify
OMNNPACompilerOptions
MLIRIR

Expand Down
5 changes: 4 additions & 1 deletion src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td
Original file line number Diff line number Diff line change
Expand Up @@ -862,11 +862,14 @@ def ZHighStickifiedConstantOp:ZHigh_Op<"StickifiedConstant", [Pure]> {
let summary = "ZHigh Stickified Constant operation";
let description = [{
This operator produces a constant tensor to store stickified data.
`value` attribute has original constant or stickified constant.
`stickified` attribute indicates the `value` is already stickified or not.
Stickified data is opaque and must be 4K-aligned. One who produces
the stickified data must make sure its size in bytes consistent with
the output tensor's size.
}];
let arguments = (ins OptionalAttr<AnyAttr>:$value,
let arguments = (ins BoolAttr:$stickified,
OptionalAttr<AnyAttr>:$value,
DefaultValuedAttr<I64Attr, "4096">:$alignment);
let results = (outs AnyZTensor:$output);
}
Expand Down
51 changes: 50 additions & 1 deletion src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp"
#include "src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp"
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp"
#include "src/Accelerators/NNPA/Support/LayoutHelper.hpp"

#include "src/Dialect/ONNX/DialectBuilder.hpp"
Expand Down Expand Up @@ -482,5 +481,55 @@ IntegerAttr getDefaultSaturation(PatternRewriter &rewriter) {
return IntegerAttr();
}

/// MLIR type to zDNN type.
zdnn_data_types mlirTypeToZDNNType(Type elementType) {
if (mlir::isa<FloatType>(elementType)) {
FloatType floatTy = mlir::cast<FloatType>(elementType);
if (floatTy.getWidth() == 16) {
return FP16;
} else if (floatTy.getWidth() == 32) {
return FP32;
} else
llvm_unreachable("Unsupported data type.");
} else
llvm_unreachable("Unsupported data type.");
}

/// Get stickified data from denseElementAttribute
ArrayRef<char> getStickifiedDataOfDenseElemAttr(
DenseElementsAttr denseAttr, StringAttr layout) {
ArrayRef<int64_t> shape = denseAttr.getType().getShape();
Type elementType = denseAttr.getType().getElementType();
int rank = shape.size();
// Read attributes's raw data.
std::vector<char> attrData;
getRawData(denseAttr, attrData);
// Call stickify.
zdnn_tensor_desc pre_tfrmd_desc, tfrmd_desc;
// pre-transformed desc.
zdnn_data_layouts zDNNLayout =
convertLayoutAttrToZDNNDataLayout(rank, layout);
// If zDNNLayout is NHWC, we stickify directly from NCHW.
if (zDNNLayout == ZDNN_NHWC)
zDNNLayout = ZDNN_NCHW;
zdnn_data_types zDNNType = onnx_mlir::zhigh::mlirTypeToZDNNType(elementType);
set_info_pre_transformed_desc(&pre_tfrmd_desc, zDNNLayout, zDNNType, shape);
// transformed desc.
zdnn_status status = generate_transformed_desc(&pre_tfrmd_desc, &tfrmd_desc);
assert(status == ZDNN_OK);
// Stick data using the software stickify.
zdnn_ztensor ztensor;
init_ztensor(&pre_tfrmd_desc, &tfrmd_desc, &ztensor);
status = allochelper_ztensor_alloc(&ztensor);
assert(status == ZDNN_OK);
status = stickify(&ztensor, attrData.data());
assert(status == ZDNN_OK);
int64_t sizeInBytes = ztensor.buffer_size;
char *rawData = (char *)malloc(sizeInBytes);
memcpy(rawData, ztensor.buffer, sizeInBytes);
allochelper_ztensor_free(&ztensor);
return llvm::ArrayRef(rawData, sizeInBytes);
}

} // namespace zhigh
} // namespace onnx_mlir
8 changes: 8 additions & 0 deletions src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp"
#include "src/Accelerators/NNPA/Support/Stickify/Stickify.hpp"

namespace onnx_mlir {
namespace zhigh {
Expand Down Expand Up @@ -88,6 +89,13 @@ bool hasNNPAUse(mlir::Value v);
/// Get saturation settings.
mlir::IntegerAttr getDefaultSaturation(mlir::PatternRewriter &rewriter);

/// MLIR type to zDNN type.
zdnn_data_types mlirTypeToZDNNType(mlir::Type elementType);

/// Get stickified data from denseElementAttribute
mlir::ArrayRef<char> getStickifiedDataOfDenseElemAttr(
mlir::DenseElementsAttr denseAttr, mlir::StringAttr layout);

} // namespace zhigh
} // namespace onnx_mlir
#endif
5 changes: 5 additions & 0 deletions src/Accelerators/NNPA/Dialect/ZLow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,13 @@ add_onnx_mlir_library(OMZLowOps
DEPENDS
OMZLowIncGen
OMONNXZLowCombineIncGen
OMKrnlGlobalOpInterface

LINK_LIBS PUBLIC
MLIRIR
OMMlirDialects
OMZHighOps

ACCEL_INCLUDE_DIRS PRIVATE
${NNPA_INCLUDE_PATH}
)
17 changes: 17 additions & 0 deletions src/Accelerators/NNPA/Dialect/ZLow/ZLow.td
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def ZMemRef : MemRefOf<[DLF16]>;
//===----------------------------------------------------------------------===//

include "mlir/Interfaces/SideEffectInterfaces.td"
include "src/Interface/KrnlGlobalOpInterface.td"

def ZLowAddOp:ZLow_Op<"add", [MemRefsNormalizable,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
Expand Down Expand Up @@ -547,4 +548,20 @@ def ZLowConvertF32ToDLF16VectorOp:ZLow_Op<"vec_f32_to_dlf16", [Pure]> {
];
}

def ZLowStickifiedConstantOp:ZLow_Op<"stickifiedConstant", [MemRefsNormalizable,
DeclareOpInterfaceMethods<KrnlGlobalOpInterface>]> {
let summary = "ZLow Stickified Constant operation.";
let description = [{

}];
let arguments = (ins AnyAttr:$shape,
StrAttr:$name,
BoolAttr:$stickified,
OptionalAttr<AnyAttr>:$value,
OptionalAttr<StrAttr>:$layout,
OptionalAttr<I64Attr>:$offset,
DefaultValuedAttr<I64Attr, "4096">:$alignment);
let results = (outs ZMemRef:$output);
imaihal marked this conversation as resolved.
Show resolved Hide resolved
}

#endif // ZLOW_OPS
Loading
Loading