diff --git a/include/ttmlir-c/TTAttrs.h b/include/ttmlir-c/TTAttrs.h index 9b31b36ed..cdaa67c18 100644 --- a/include/ttmlir-c/TTAttrs.h +++ b/include/ttmlir-c/TTAttrs.h @@ -75,6 +75,14 @@ MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTOperandConstraintArrayAttrGet( MlirContext ctx, uint32_t *OperandConstraints, size_t OperandConstraintsSize); +MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTTileSizeAttrGet(MlirContext ctx, + int64_t y, int64_t x); + +MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTChipPhysicalCoresAttrGet( + MlirContext ctx, MlirAttribute *worker, size_t workerSize, + MlirAttribute *dram, size_t dramSize, MlirAttribute *eth, size_t ethSize, + MlirAttribute *eth_inactive, size_t eth_inactiveSize); + #ifdef __cplusplus } #endif diff --git a/include/ttmlir/Bindings/Python/TTMLIRModule.h b/include/ttmlir/Bindings/Python/TTMLIRModule.h index e0d089edd..5f2d4e134 100644 --- a/include/ttmlir/Bindings/Python/TTMLIRModule.h +++ b/include/ttmlir/Bindings/Python/TTMLIRModule.h @@ -21,9 +21,40 @@ #include "ttmlir/RegisterAll.h" #include "llvm/Support/CommandLine.h" +#include + namespace py = pybind11; namespace mlir::ttmlir::python { + +template +py::class_ tt_attribute_class(py::module &m, const char *class_name) { + py::class_ cls(m, class_name); + cls.def_static("maybe_downcast", + [](MlirAttribute attr) -> std::variant { + auto res = mlir::dyn_cast(unwrap(attr)); + if (res) { + return res; + } + return py::none(); + }); + return cls; +} + +template +py::class_ tt_type_class(py::module &m, const char *class_name) { + py::class_ cls(m, class_name); + cls.def_static("maybe_downcast", + [](MlirType type) -> std::variant { + auto res = mlir::dyn_cast(unwrap(type)); + if (res) { + return res; + } + return py::none(); + }); + return cls; +} + void populateTTModule(py::module &m); void populateTTIRModule(py::module &m); void populateTTKernelModule(py::module &m); diff --git a/lib/CAPI/TTAttrs.cpp b/lib/CAPI/TTAttrs.cpp index 8f2949852..b4bdc0e31 100644 --- a/lib/CAPI/TTAttrs.cpp +++ b/lib/CAPI/TTAttrs.cpp @@ -182,4 +182,34 @@ ttmlirTTOperandConstraintArrayAttrGet(MlirContext ctx, return wrap(ArrayAttr::get(unwrap(ctx), operandConstraintsArray)); } +MlirAttribute ttmlirTTTileSizeAttrGet(MlirContext ctx, int64_t y, int64_t x) { + return wrap(TileSizeAttr::get(unwrap(ctx), y, x)); +} + +MlirAttribute ttmlirTTChipPhysicalCoresAttrGet( + MlirContext ctx, MlirAttribute *worker, size_t workerSize, + MlirAttribute *dram, size_t dramSize, MlirAttribute *eth, size_t ethSize, + MlirAttribute *eth_inactive, size_t eth_inactiveSize) { + std::vector workerVec, dramVec, ethVec, ethInactiveVec; + for (size_t i = 0; i < workerSize; i++) { + workerVec.push_back(mlir::cast(unwrap(worker[i]))); + } + + for (size_t i = 0; i < dramSize; i++) { + dramVec.push_back(mlir::cast(unwrap(dram[i]))); + } + + for (size_t i = 0; i < ethSize; i++) { + ethVec.push_back(mlir::cast(unwrap(eth[i]))); + } + + for (size_t i = 0; i < eth_inactiveSize; i++) { + ethInactiveVec.push_back( + mlir::cast(unwrap(eth_inactive[i]))); + } + + return wrap(ChipPhysicalCoresAttr::get(unwrap(ctx), workerVec, dramVec, + ethVec, ethInactiveVec)); +} + } // namespace mlir::tt diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 730946362..e43cb858d 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -102,6 +102,7 @@ declare_mlir_python_extension(TTMLIRPythonExtensions.Main set(TTMLIR_PYTHON_SOURCES MLIRPythonSources.Core + MLIRPythonSources.Dialects.affine MLIRPythonSources.Dialects.arith MLIRPythonSources.Dialects.func MLIRPythonSources.Dialects.tensor diff --git a/python/TTModule.cpp b/python/TTModule.cpp index 9cdc49f27..687626862 100644 --- a/python/TTModule.cpp +++ b/python/TTModule.cpp @@ -16,7 +16,7 @@ namespace mlir::ttmlir::python { void populateTTModule(py::module &m) { - py::class_(m, "LayoutAttr") + tt_attribute_class(m, "LayoutAttr") .def_static("get", [](MlirContext ctx, MlirType rankedTensorType, uint32_t memorySpaceValue, MlirAttribute grid, @@ -103,34 +103,48 @@ void populateTTModule(py::module &m) { return static_cast(la.getMemLayout()); }); - py::class_(m, "GridAttr") + tt_attribute_class(m, "GridAttr") .def_static("get", [](MlirContext ctx, std::vector shape) { return wrap(tt::GridAttr::get(unwrap(ctx), shape)); }) - .def_property_readonly("shape", [](tt::GridAttr const &ga) { - return std::vector(ga.getShape().begin(), ga.getShape().end()); - }); + .def_property_readonly( + "shape", [](tt::GridAttr const &ga) { return ga.getShape().vec(); }); - py::class_(m, "ChipCapabilityAttr") - .def_static("get", [](MlirContext ctx, uint32_t chipCapability) { - return wrap(tt::ChipCapabilityAttr::get( - unwrap(ctx), static_cast(chipCapability))); - }); + tt_attribute_class(m, "ChipCapabilityAttr") + .def_static( + "get", + [](MlirContext ctx, uint32_t chipCapability) { + return wrap(tt::ChipCapabilityAttr::get( + unwrap(ctx), static_cast(chipCapability))); + }) + .def_property_readonly("capability_as_int", + [](tt::ChipCapabilityAttr self) { + return static_cast(self.getValue()); + }); - py::class_(m, "ArchAttr") - .def_static("get", [](MlirContext ctx, uint32_t arch) { - return wrap( - tt::ArchAttr::get(unwrap(ctx), static_cast(arch))); + tt_attribute_class(m, "ArchAttr") + .def_static("get", + [](MlirContext ctx, uint32_t arch) { + return wrap(tt::ArchAttr::get(unwrap(ctx), + static_cast(arch))); + }) + .def_property_readonly("arch_as_int", [](tt::ArchAttr self) { + return static_cast(self.getValue()); }); - py::class_(m, "DataTypeAttr") - .def_static("get", [](MlirContext ctx, uint16_t *supportedDataTypes) { - return wrap(tt::DataTypeAttr::get( - unwrap(ctx), static_cast(*supportedDataTypes))); + tt_attribute_class(m, "DataTypeAttr") + .def_static( + "get", + [](MlirContext ctx, uint16_t *supportedDataTypes) { + return wrap(tt::DataTypeAttr::get( + unwrap(ctx), static_cast(*supportedDataTypes))); + }) + .def_property_readonly("data_type_as_int", [](tt::DataTypeAttr self) { + return static_cast(self.getValue()); }); - py::class_(m, "ChipDescAttr") + tt_attribute_class(m, "ChipDescAttr") .def_static( "get", [](MlirContext ctx, MlirAttribute arch, std::vector grid, @@ -152,105 +166,240 @@ void populateTTModule(py::module &m) { mlir::cast(unwrap(supportedDataTypes)), mlir::cast(unwrap(supportedTileSizes)), numCBs)); - }); + }) + .def_property_readonly("usable_l1_size", + &tt::ChipDescAttr::getUsableL1Size) + .def_property_readonly("usable_dram_channel_size", + &tt::ChipDescAttr::getUsableDramChannelSize) + .def_property_readonly("arch", &tt::ChipDescAttr::getArch) + .def_property_readonly( + "grid", [](tt::ChipDescAttr self) { return self.getGrid().vec(); }) + .def_property_readonly("l1_size", &tt::ChipDescAttr::getL1Size) + .def_property_readonly("num_dram_channels", + &tt::ChipDescAttr::getNumDramChannels) + .def_property_readonly("dram_channel_size", + &tt::ChipDescAttr::getDramChannelSize) + .def_property_readonly("noc_l1_address_align_bytes", + &tt::ChipDescAttr::getNocL1AddressAlignBytes) + .def_property_readonly("pcie_address_align_bytes", + &tt::ChipDescAttr::getPcieAddressAlignBytes) + .def_property_readonly("noc_dram_address_align_bytes", + &tt::ChipDescAttr::getNocDRAMAddressAlignBytes) + .def_property_readonly("l1_unreserved_base", + &tt::ChipDescAttr::getL1UnreservedBase) + .def_property_readonly("erisc_l1_unreserved_base", + &tt::ChipDescAttr::getEriscL1UnreservedBase) + .def_property_readonly("dram_unreserved_base", + &tt::ChipDescAttr::getDramUnreservedBase) + .def_property_readonly("dram_unreserved_end", + &tt::ChipDescAttr::getDramUnreservedEnd) + .def_property_readonly("chip_physical_cores", + &tt::ChipDescAttr::getChipPhysicalCores) + .def_property_readonly("supported_data_types", + [](tt::ChipDescAttr self) { + return self.getSupportedDataTypes().vec(); + }) + .def_property_readonly("supported_tile_sizes", + [](tt::ChipDescAttr self) { + return self.getSupportedTileSizes().vec(); + }) + .def_property_readonly("num_cbs", &tt::ChipDescAttr::getNumCBs); - py::class_(m, "ChipCoordAttr") - .def_static("get", [](MlirContext ctx, unsigned rack, unsigned shelf, - unsigned y, unsigned x) { - return wrap(tt::ChipCoordAttr::get(unwrap(ctx), rack, shelf, y, x)); - }); + tt_attribute_class(m, "TileSizeAttr") + .def_static("get", + [](MlirContext ctx, int64_t y, int64_t x) { + return wrap(tt::TileSizeAttr::get(unwrap(ctx), y, x)); + }) + .def_property_readonly("y", &tt::TileSizeAttr::getY) + .def_property_readonly("x", &tt::TileSizeAttr::getX); - py::class_(m, "ChipChannelAttr") - .def_static("get", [](MlirContext ctx, unsigned deviceId0, - std::vector ethernetCoreCoord0, - unsigned deviceId1, - std::vector ethernetCoreCoord1) { - return wrap(tt::ChipChannelAttr::get(unwrap(ctx), deviceId0, - ethernetCoreCoord0, deviceId1, - ethernetCoreCoord1)); - }); + tt_attribute_class(m, "ChipPhysicalCoresAttr") + .def_static("get", + [](MlirContext ctx, std::vector worker, + std::vector dram, + std::vector eth, + std::vector eth_inactive) { + return wrap(tt::ChipPhysicalCoresAttr::get( + unwrap(ctx), worker, dram, eth, eth_inactive)); + }) + .def_property_readonly( + "worker", + [](tt::ChipPhysicalCoresAttr self) { return self.getWorker().vec(); }) + .def_property_readonly( + "dram", + [](tt::ChipPhysicalCoresAttr self) { return self.getDram().vec(); }) + .def_property_readonly( + "eth", + [](tt::ChipPhysicalCoresAttr self) { return self.getEth().vec(); }) + .def_property_readonly("eth_inactive", + [](tt::ChipPhysicalCoresAttr self) { + return self.getEthInactive().vec(); + }); + + tt_attribute_class(m, "ChipCoordAttr") + .def_static("get", + [](MlirContext ctx, unsigned rack, unsigned shelf, unsigned y, + unsigned x) { + return wrap( + tt::ChipCoordAttr::get(unwrap(ctx), rack, shelf, y, x)); + }) + .def_property_readonly("rack", &tt::ChipCoordAttr::getRack) + .def_property_readonly("shelf", &tt::ChipCoordAttr::getShelf) + .def_property_readonly("y", &tt::ChipCoordAttr::getY) + .def_property_readonly("x", &tt::ChipCoordAttr::getX); + + tt_attribute_class(m, "ChipChannelAttr") + .def_static( + "get", + [](MlirContext ctx, unsigned deviceId0, + std::vector ethernetCoreCoord0, unsigned deviceId1, + std::vector ethernetCoreCoord1) { + return wrap(tt::ChipChannelAttr::get(unwrap(ctx), deviceId0, + ethernetCoreCoord0, deviceId1, + ethernetCoreCoord1)); + }) + .def_property_readonly("device_id0", &tt::ChipChannelAttr::getDeviceId0) + .def_property_readonly("ethernet_core_coord0", + [](tt::ChipChannelAttr self) { + return self.getEthernetCoreCoord0().vec(); + }) + .def_property_readonly("device_id1", &tt::ChipChannelAttr::getDeviceId1) + .def_property_readonly("ethernet_core_coord1", + [](tt::ChipChannelAttr self) { + return self.getEthernetCoreCoord1().vec(); + }); - py::class_(m, "SystemDescAttr") + tt_attribute_class(m, "SystemDescAttr") .def_static("get_default", [](MlirContext ctx) { return wrap(tt::SystemDescAttr::getDefault(unwrap(ctx))); }) - .def_static("get", [](MlirContext ctx, - std::vector chipDescs, - std::vector chipDescIndices, - std::vector chipCapabilities, - std::vector chipCoords, - std::vector chipChannels) { - std::vector chipDescsUnwrapped; - for (auto chipDesc : chipDescs) { - chipDescsUnwrapped.push_back( - mlir::cast(unwrap(chipDesc))); - } - std::vector chipCapabilitiesUnwrapped; - for (auto chipCapability : chipCapabilities) { - chipCapabilitiesUnwrapped.push_back( - mlir::cast(unwrap(chipCapability))); - } - std::vector chipCoordsUnwrapped; - for (auto chipCoord : chipCoords) { - chipCoordsUnwrapped.push_back( - mlir::cast(unwrap(chipCoord))); - } - std::vector chipChannelsUnwrapped; - for (auto chipChannel : chipChannels) { - chipChannelsUnwrapped.push_back( - mlir::cast(unwrap(chipChannel))); - } - return wrap(tt::SystemDescAttr::get( - unwrap(ctx), chipDescsUnwrapped, chipDescIndices, - chipCapabilitiesUnwrapped, chipCoordsUnwrapped, - chipChannelsUnwrapped)); + .def_static( + "get", + [](MlirContext ctx, std::vector chipDescs, + std::vector chipDescIndices, + std::vector chipCapabilities, + std::vector chipCoords, + std::vector chipChannels) { + std::vector chipDescsUnwrapped; + for (auto chipDesc : chipDescs) { + chipDescsUnwrapped.push_back( + mlir::cast(unwrap(chipDesc))); + } + std::vector chipCapabilitiesUnwrapped; + for (auto chipCapability : chipCapabilities) { + chipCapabilitiesUnwrapped.push_back( + mlir::cast(unwrap(chipCapability))); + } + std::vector chipCoordsUnwrapped; + for (auto chipCoord : chipCoords) { + chipCoordsUnwrapped.push_back( + mlir::cast(unwrap(chipCoord))); + } + std::vector chipChannelsUnwrapped; + for (auto chipChannel : chipChannels) { + chipChannelsUnwrapped.push_back( + mlir::cast(unwrap(chipChannel))); + } + return wrap(tt::SystemDescAttr::get( + unwrap(ctx), chipDescsUnwrapped, chipDescIndices, + chipCapabilitiesUnwrapped, chipCoordsUnwrapped, + chipChannelsUnwrapped)); + }) + .def_property_readonly( + "chip_descs", + [](tt::SystemDescAttr self) { return self.getChipDescs().vec(); }) + .def_property_readonly("chip_desc_indices", + [](tt::SystemDescAttr self) { + return self.getChipDescIndices().vec(); + }) + .def_property_readonly("chip_capabilities", + [](tt::SystemDescAttr self) { + return self.getChipCapabilities().vec(); + }) + .def_property_readonly( + "chip_coords", + [](tt::SystemDescAttr self) { return self.getChipCoords().vec(); }) + .def_property_readonly("chip_channels", [](tt::SystemDescAttr self) { + return self.getChipChannels().vec(); }); - py::class_(m, "MemorySpaceAttr") - .def_static("get", [](MlirContext ctx, uint32_t memorySpace) { - return wrap(tt::MemorySpaceAttr::get( - unwrap(ctx), static_cast(memorySpace))); - }); + tt_attribute_class(m, "MemorySpaceAttr") + .def_static( + "get", + [](MlirContext ctx, uint32_t memorySpace) { + return wrap(tt::MemorySpaceAttr::get( + unwrap(ctx), static_cast(memorySpace))); + }) + .def_property_readonly("memory_space_as_int", + [](tt::MemorySpaceAttr self) { + return static_cast(self.getValue()); + }); - py::class_(m, "OOBValAttr") - .def_static("get", [](MlirContext ctx, uint32_t oobVal) { - return wrap( - tt::OOBValAttr::get(unwrap(ctx), static_cast(oobVal))); + tt_attribute_class(m, "OOBValAttr") + .def_static("get", + [](MlirContext ctx, uint32_t oobVal) { + return wrap(tt::OOBValAttr::get( + unwrap(ctx), static_cast(oobVal))); + }) + .def_property_readonly("oob_val_as_int", [](tt::OOBValAttr self) { + return static_cast(self.getValue()); }); - py::class_(m, "TensorMemoryLayoutAttr") - .def_static("get", [](MlirContext ctx, uint32_t memLayout) { - return wrap(tt::TensorMemoryLayoutAttr::get( - unwrap(ctx), static_cast(memLayout))); - }); + tt_attribute_class(m, "TensorMemoryLayoutAttr") + .def_static( + "get", + [](MlirContext ctx, uint32_t memLayout) { + return wrap(tt::TensorMemoryLayoutAttr::get( + unwrap(ctx), static_cast(memLayout))); + }) + .def_property_readonly("mem_layout_as_int", + [](tt::TensorMemoryLayoutAttr self) { + return static_cast(self.getValue()); + }); - py::class_(m, "IteratorTypeAttr") - .def_static("get", [](MlirContext ctx, uint32_t iteratorType) { - return wrap(tt::IteratorTypeAttr::get( - unwrap(ctx), static_cast(iteratorType))); - }); + tt_attribute_class(m, "IteratorTypeAttr") + .def_static( + "get", + [](MlirContext ctx, uint32_t iteratorType) { + return wrap(tt::IteratorTypeAttr::get( + unwrap(ctx), static_cast(iteratorType))); + }) + .def_property_readonly("iterator_type_as_int", + [](tt::IteratorTypeAttr self) { + return static_cast(self.getValue()); + }); - py::class_(m, "OperandConstraintAttr") + tt_attribute_class(m, "OperandConstraintAttr") .def_static("get", [](MlirContext ctx, uint32_t operandConstraint) { return wrap(tt::OperandConstraintAttr::get( unwrap(ctx), static_cast(operandConstraint))); }) - .def_static("get", [](MlirContext ctx, - std::vector attributesArray) { - return ::ttmlir::utils::wrapArrayOfMlirAttributesAsAttribute( - ctx, attributesArray); - }); + .def_static( + "get", + [](MlirContext ctx, std::vector attributesArray) { + return ::ttmlir::utils::wrapArrayOfMlirAttributesAsAttribute( + ctx, attributesArray); + }) + .def_property_readonly("operand_constraint_as_int", + [](tt::OperandConstraintAttr self) { + return static_cast(self.getValue()); + }); - py::class_(m, "DeviceType") - .def_static("get", [](MlirContext ctx, MlirAttribute deviceAttr) { - return wrap(tt::DeviceType::get( - unwrap(ctx), mlir::cast(unwrap(deviceAttr)))); + tt_type_class(m, "DeviceType") + .def_static( + "get", + [](MlirContext ctx, MlirAttribute deviceAttr) { + return wrap(tt::DeviceType::get( + unwrap(ctx), mlir::cast(unwrap(deviceAttr)))); + }) + .def_property_readonly("device_attr", [](tt::DeviceType const &self) { + return self.getDesc(); }); - py::class_(m, "DeviceAttr") + tt_attribute_class(m, "DeviceAttr") .def_static("from_system_desc", [](MlirContext ctx, MlirAttribute systemDesc, std::vector meshShape) { @@ -270,11 +419,21 @@ void populateTTModule(py::module &m) { unwrap(workerGridMapping)), unwrap(l1Map), unwrap(dramMap), meshShape, chipIds)); }) - .def("unwrap", [](MlirAttribute const &self) { - return mlir::cast(unwrap(self)); + .def("unwrap", + [](MlirAttribute const &self) { + return mlir::cast(unwrap(self)); + }) + .def_property_readonly("grid_attr", &tt::DeviceAttr::getWorkerGrid) + .def_property_readonly("l1_map", &tt::DeviceAttr::getL1Map) + .def_property_readonly("dram_map", &tt::DeviceAttr::getDramMap) + .def_property_readonly( + "mesh_shape", + [](tt::DeviceAttr const &self) { return self.getMeshShape().vec(); }) + .def_property_readonly("chip_ids", [](tt::DeviceAttr const &self) { + return self.getChipIds().vec(); }); - py::class_(m, "TileType") + tt_type_class(m, "TileType") .def_static("get", [](MlirContext ctx, std::int64_t height, std::int64_t width, uint32_t dataType) { diff --git a/python/TTNNModule.cpp b/python/TTNNModule.cpp index b8b408cb9..24bd05c8f 100644 --- a/python/TTNNModule.cpp +++ b/python/TTNNModule.cpp @@ -7,7 +7,7 @@ namespace mlir::ttmlir::python { void populateTTNNModule(py::module &m) { - py::class_(m, "CoreRangeAttr") + tt_attribute_class(m, "CoreRangeAttr") .def_static("get", [](MlirContext ctx, std::vector offset, std::vector size) { @@ -27,8 +27,15 @@ void populateTTNNModule(py::module &m) { offsetVec)); }, py::arg("ctx"), py::arg("grid"), - py::arg("offset") = std::vector{0, 0}); - py::class_(m, "LayoutAttr") + py::arg("offset") = std::vector{0, 0}) + .def_property_readonly( + "offset", + [](tt::ttnn::CoreRangeAttr self) { return self.getOffset().vec(); }) + .def_property_readonly("size", [](tt::ttnn::CoreRangeAttr self) { + return self.getSize().vec(); + }); + + tt_attribute_class(m, "LayoutAttr") .def_static("get", [](MlirContext ctx, uint32_t layout) { return wrap(tt::ttnn::LayoutAttr::get( @@ -37,7 +44,9 @@ void populateTTNNModule(py::module &m) { .def_property_readonly("value", [](tt::ttnn::LayoutAttr self) { return static_cast(self.getValue()); }); - py::class_(m, "TensorMemoryLayoutAttr") + + tt_attribute_class(m, + "TensorMemoryLayoutAttr") .def_static("get", [](MlirContext ctx, uint32_t tensorMemoryLayout) { return wrap(tt::ttnn::TensorMemoryLayoutAttr::get( @@ -48,7 +57,7 @@ void populateTTNNModule(py::module &m) { [](tt::ttnn::TensorMemoryLayoutAttr self) { return static_cast(self.getValue()); }); - py::class_(m, "BufferTypeAttr") + tt_attribute_class(m, "BufferTypeAttr") .def_static( "get", [](MlirContext ctx, uint32_t bufferType) { @@ -58,7 +67,8 @@ void populateTTNNModule(py::module &m) { .def_property_readonly("value", [](tt::ttnn::BufferTypeAttr self) { return static_cast(self.getValue()); }); - py::class_(m, "ShardSpecAttr") + + tt_attribute_class(m, "ShardSpecAttr") .def_static("get", [](MlirContext ctx, tt::ttnn::ShapeAttr shardShape) { return wrap( @@ -66,7 +76,8 @@ void populateTTNNModule(py::module &m) { }) .def_property_readonly("shard_shape", &tt::ttnn::ShardSpecAttr::getShardShape); - py::class_(m, "MemoryConfigAttr") + + tt_attribute_class(m, "MemoryConfigAttr") .def_static("get", [](MlirContext ctx, tt::ttnn::TensorMemoryLayoutAttr tensorMemoryLayoutAttr, @@ -97,7 +108,8 @@ void populateTTNNModule(py::module &m) { &tt::ttnn::MemoryConfigAttr::getBufferType) .def_property_readonly("shard_spec", &tt::ttnn::MemoryConfigAttr::getShardSpec); - py::class_(m, "ShapeAttr") + + tt_attribute_class(m, "ShapeAttr") .def_static("get", [](MlirContext ctx, std::vector shape) { return wrap(tt::ttnn::ShapeAttr::get(unwrap(ctx), shape)); @@ -106,7 +118,8 @@ void populateTTNNModule(py::module &m) { return std::vector(self.getShape().begin(), self.getShape().end()); }); - py::class_(m, "MeshShapeAttr") + + tt_attribute_class(m, "MeshShapeAttr") .def_static("get", [](MlirContext ctx, int64_t y, int64_t x) { return wrap(