diff --git a/tests/ttnn/unit_tests/operations/test_moreh_matmul.py b/tests/ttnn/unit_tests/operations/test_moreh_matmul.py new file mode 100644 index 00000000000..3dcff16cfd4 --- /dev/null +++ b/tests/ttnn/unit_tests/operations/test_moreh_matmul.py @@ -0,0 +1,308 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +from loguru import logger + +import ttnn +import ttnn.operations +from models.utility_functions import comp_allclose_and_pcc, skip_for_grayskull +from tests.tt_eager.python_api_testing.unit_testing.misc.test_utils import ( + get_compute_kernel_options, + compute_kernel_options, + compute_kernel_ids, +) + + +def create_tt_tensor(tensor: torch.Tensor, device): + return ttnn.from_torch(tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + + +def get_tensors( + input_shape, other_shape, output_shape, require_input_grad, require_other_grad, is_1d, device, use_randint=True +): + npu_dtype = ttnn.bfloat16 + cpu_dtype = torch.bfloat16 + npu_layout = ttnn.TILE_LAYOUT + cpu_layout = ttnn.ROW_MAJOR_LAYOUT + + # create tensors for forward + if use_randint: + input = torch.randint(-2, 3, input_shape, dtype=cpu_dtype) + other = torch.randint(-2, 3, other_shape, dtype=cpu_dtype) + output = torch.randint(-2, 3, output_shape, dtype=cpu_dtype) + else: + input = torch.rand(input_shape, dtype=cpu_dtype) + other = torch.rand(other_shape, dtype=cpu_dtype) + output = torch.rand(output_shape, dtype=cpu_dtype) + + tt_input = create_tt_tensor(input, device) + tt_other = create_tt_tensor(other, device) + tt_output = create_tt_tensor(output, device) + + torch_input = input.reshape(-1) if is_1d else input + torch_other = other.reshape(-1) if is_1d else other + + # tensors for backward + output_grad = tt_output_grad = torch_output_grad = tt_input_grad = tt_other_grad = None + if require_input_grad or require_other_grad: + output_grad = torch.randint(-2, 3, output_shape, dtype=cpu_dtype) + tt_output_grad = ttnn.Tensor(output_grad, npu_dtype).pad_to_tile(float(-1)).to(npu_layout).to(device) + torch_output_grad = output_grad[0][0][0][0] if is_1d else output_grad + + if require_input_grad: + input_grad = torch.full(input_shape, float("nan"), dtype=cpu_dtype) + tt_input_grad = ttnn.Tensor(input_grad, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device) + + if require_other_grad: + other_grad = torch.full(other_shape, float("nan"), dtype=cpu_dtype) + tt_other_grad = ( + ttnn.Tensor( + other_grad, + npu_dtype, + ) + .pad_to_tile(float("nan")) + .to(npu_layout) + .to(device) + ) + + return ( + tt_input, + tt_other, + tt_output, + tt_output_grad, + tt_input_grad, + tt_other_grad, + torch_input, + torch_other, + torch_output_grad, + ) + + +def moreh_matmul(params, has_output, compute_kernel_config, device): + torch.manual_seed(3072) + input_shape, other_shape, output_shape, transpose_input, transpose_other = params + tt_input, tt_other, tt_output, _, _, _, torch_input, torch_other, _ = get_tensors( + input_shape, other_shape, output_shape, False, False, False, device + ) + if not has_output: + tt_output = None + + torch_input = torch_input.transpose(-1, -2) if transpose_input else torch_input + torch_other = torch_other.transpose(-1, -2) if transpose_other else torch_other + + # tt matmul + cpu_layout = ttnn.ROW_MAJOR_LAYOUT + tt_output = ttnn.operations.moreh.matmul( + tt_input, + tt_other, + transpose_input=transpose_input, + transpose_other=transpose_other, + output=tt_output, + compute_kernel_config=compute_kernel_config, + ) + tt_output_cpu = tt_output.cpu().to(cpu_layout).unpad_from_tile(output_shape).to_torch() + + # torch matmul + torch_out = torch.matmul(torch_input, torch_other) + + # test for equivalance + rtol = atol = 0.1 + passing, output_pcc = comp_allclose_and_pcc(torch_out, tt_output_cpu, pcc=0.999, rtol=rtol, atol=atol) + logger.debug(f"Out passing={passing}") + logger.debug(f"Output pcc={output_pcc}") + + return passing + + +@pytest.mark.parametrize( + "params", + ( + # input, other, output shape, transpose input, other + ([32, 32], [32, 32], [32, 32], False, False), # single-core + ([1024, 128], [128, 1024], [1024, 1024], False, False), # multi-core + ([128, 1024], [128, 1024], [1024, 1024], True, False), # input transpose + ([1024, 128], [1024, 128], [1024, 1024], False, True), # other transpose + ([128, 1024], [1024, 128], [1024, 1024], True, True), # input, other transpose + ([1020, 128], [128, 1024], [1020, 1024], False, False), # input mask + ([1024, 128], [128, 1020], [1024, 1020], False, False), # other mask + ([1020, 310], [310, 1020], [1020, 1020], False, False), # input, other mask + ([128, 1020], [128, 1024], [1020, 1024], True, False), # input mask, transpose + ([1024, 128], [1020, 128], [1024, 1020], False, True), # other mask, transpose + ([310, 1020], [1020, 310], [1020, 1020], True, True), # input, other mask, transpose + ([3, 1, 2, 1, 4, 1, 319, 95], [4, 2, 95, 470], [3, 1, 2, 1, 4, 2, 319, 470], False, False), # batched matmul + ([2, 319, 95], [2, 1, 3, 4, 1, 95, 470], [2, 1, 3, 4, 2, 319, 470], False, False), # batched matmul + ([3, 1, 2, 1, 4, 1, 95, 319], [4, 2, 95, 470], [3, 1, 2, 1, 4, 2, 319, 470], True, False), # batched matmul + ([2, 319, 95], [2, 1, 3, 4, 1, 470, 95], [2, 1, 3, 4, 2, 319, 470], False, True), # batched matmul + ( + [2, 3, 1, 2, 3, 2, 64, 64], + [2, 1, 4, 2, 1, 2, 64, 64], + [2, 3, 4, 2, 3, 2, 64, 64], + False, + False, + ), # batched matmul + ), +) +@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) +def test_moreh_matmul(params, compute_kernel_options, device): + compute_kernel_config = get_compute_kernel_options(compute_kernel_options) + passing = moreh_matmul(params, True, compute_kernel_config, device) + assert passing + + +@pytest.mark.parametrize( + "params", + ( + # input, other, output shape, transpose input, other + ([32, 32], [32, 32], [32, 32], False, False), # single-core + ([3, 1, 2, 1, 4, 1, 95, 319], [4, 2, 95, 470], [3, 1, 2, 1, 4, 2, 319, 470], True, False), # batched matmul + ([2, 319, 95], [2, 1, 3, 4, 1, 470, 95], [2, 1, 3, 4, 2, 319, 470], False, True), # batched matmul + ( + [2, 3, 1, 2, 3, 2, 64, 64], + [2, 1, 4, 2, 1, 2, 64, 64], + [2, 3, 4, 2, 3, 2, 64, 64], + False, + False, + ), # batched matmul + ), +) +def test_moreh_matmul_wo_output(params, device): + passing = moreh_matmul(params, False, None, device) + assert passing + + +@pytest.mark.parametrize( + "params", + ( + # input, weight, bias(1d or scalar), output + ([32, 32], [32, 32], [32, 32], False, False), # single-core + ( + [2, 3, 1, 2, 3, 2, 64, 64], + [2, 1, 4, 2, 1, 2, 64, 64], + [2, 3, 4, 2, 3, 2, 64, 64], + False, + False, + ), # batched matmul + ), +) +def test_moreh_matmul_enable_cache(params, device, use_program_cache): + torch.manual_seed(3072) + for i in range(4): + # change input's transpose option + if i % 2 == 1: + param_list = list(params) + param_list[3] = False if param_list[3] else True + params = tuple(param_list) + passing = moreh_matmul(params, False, None, device) + assert passing + assert device.num_program_cache_entries() == 2 + + +@skip_for_grayskull("Doesn't seem to work properly on Grayskull devices. Wormhole_b0 devices work fine.") +@pytest.mark.parametrize( + "params", + ( + # input, other, output shape, transpose input, other + ([32, 3200], [3200, 32], [32, 32], False, False), + ([3100, 31], [3100, 31], [31, 31], True, False), + ), +) +def test_moreh_matmul_fp32_dest_acc(params, device): + torch.manual_seed(3072) + input_shape, other_shape, output_shape, transpose_input, transpose_other = params + tt_input, tt_other, tt_output_fp32, _, _, _, torch_input, torch_other, _ = get_tensors( + input_shape, other_shape, output_shape, False, False, False, device, use_randint=False + ) + + compute_kernel_config_fp32_dest_acc = get_compute_kernel_options(True) + compute_kernel_config_bf16_dest_acc = get_compute_kernel_options(False) + + torch_input = torch_input.transpose(-1, -2) if transpose_input else torch_input + torch_other = torch_other.transpose(-1, -2) if transpose_other else torch_other + + # tt matmul + cpu_layout = ttnn.ROW_MAJOR_LAYOUT + tt_output_fp32 = ttnn.operations.moreh.matmul( + tt_input, + tt_other, + transpose_input=transpose_input, + transpose_other=transpose_other, + output=tt_output_fp32, + compute_kernel_config=compute_kernel_config_fp32_dest_acc, + ) + + tt_output_fp16 = ttnn.operations.moreh.matmul( + tt_input, + tt_other, + transpose_input=transpose_input, + transpose_other=transpose_other, + compute_kernel_config=compute_kernel_config_bf16_dest_acc, + ) + + tt_output_cpu_fp32 = tt_output_fp32.cpu().to(cpu_layout).unpad_from_tile(output_shape).to_torch() + tt_output_cpu_bf16 = tt_output_fp16.cpu().to(cpu_layout).unpad_from_tile(output_shape).to_torch() + + # torch matmul (float) + torch_out = torch.matmul(torch_input.float(), torch_other.float()) + + # test for equivalance + rtol = atol = 0.1 + passing, output_pcc = comp_allclose_and_pcc(torch_out, tt_output_cpu_fp32, pcc=0.99, rtol=rtol, atol=atol) + logger.debug(f"Out passing={passing}") + logger.debug(f"Output pcc={output_pcc}") + diff = torch.abs(torch_out - tt_output_cpu_fp32) + logger.debug(f"std={torch.std(diff)}") + logger.debug(f"mean={diff.mean()}") + logger.debug(f"topk(5) {torch.topk(diff.reshape(-1), 5)}") + + assert passing + + torch_out = torch.matmul(torch_input.bfloat16(), torch_other.bfloat16()) + passing, output_pcc = comp_allclose_and_pcc(torch_out, tt_output_cpu_bf16, pcc=0.99, rtol=rtol, atol=atol) + logger.debug(f"Out passing={passing}") + logger.debug(f"Output pcc={output_pcc}") + diff_fp16 = torch.abs(torch_out - tt_output_cpu_bf16) + logger.debug(f"std={torch.std(diff_fp16)}") + logger.debug(f"mean={diff_fp16.mean()}") + logger.debug(f"topk(5) {torch.topk(diff_fp16.reshape(-1), 5)}") + + assert diff.mean() < diff_fp16.mean() + + +@pytest.mark.parametrize( + "input_shape", + ( + [1, 1, 1, 10], # test not mutiple of 32 case + [1, 1, 1, 32], # test single tile + [1, 1, 1, 352], # test multiple tiles + [1, 1, 1, 323], # test multiple tiles, not a multiple of 32 + ), +) +def test_moreh_matmul_1d(input_shape, device): + torch.manual_seed(3072) + # get tensors + output_shape = [1, 1, 1, 1] + tt_input, tt_other, _, _, _, _, torch_input, torch_other, _ = get_tensors( + input_shape, input_shape, output_shape, False, False, True, device + ) + + # tt matmul + cpu_layout = ttnn.ROW_MAJOR_LAYOUT + tt_out = ( + ttnn.operations.moreh.matmul(tt_input, tt_other).cpu().to(cpu_layout).unpad_from_tile(output_shape).to_torch() + ) + + # torch matmul + torch_input = torch.reshape(torch_input, (torch_input.shape[-1],)) + torch_other = torch.reshape(torch_other, (torch_other.shape[-1],)) + torch_out = torch.matmul(torch_input, torch_other) + + # test for equivalance + rtol = atol = 0.1 + passing, output_pcc = comp_allclose_and_pcc(torch_out, tt_out[0][0][0][0], pcc=0.999, rtol=rtol, atol=atol) + logger.debug(f"Out passing={passing}") + logger.debug(f"Output pcc={output_pcc}") + + assert passing diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 8d01e3741fc..ef290689dde 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -390,6 +390,11 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_nll_loss_backward/moreh_nll_loss_backward_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_nll_loss_backward/device/moreh_nll_loss_backward_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_nll_loss_backward/device/moreh_nll_loss_backward_device_operation.cpp + + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul/moreh_matmul.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul/moreh_matmul_pybind.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_program_factory.cpp ) # Split src and python bindings diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/kernels/moreh_matmul.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/kernels/moreh_matmul.cpp new file mode 100644 index 00000000000..11a16b9a8d7 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/kernels/moreh_matmul.cpp @@ -0,0 +1,360 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +// Implemented based on bmm.cpp +#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp" + +#include "compute_kernel_api/matmul.h" +#include "compute_kernel_api/transpose_wh.h" + +namespace NAMESPACE { + +//////////////////// +// global variables +//////////////////// +constexpr int32_t MAX_NUM_DIMENSIONS = 8; +constexpr uint32_t onetile = 1; +constexpr uint32_t num_mask_tiles = 3; +constexpr uint32_t MASK_TILE_H_IDX = 0; +constexpr uint32_t MASK_TILE_W_IDX = 1; +constexpr uint32_t MASK_TILE_HW_IDX = 2; +constexpr uint32_t cb_in0 = tt::CB::c_in0; +constexpr uint32_t cb_in1 = tt::CB::c_in1; +constexpr uint32_t cb_in2 = tt::CB::c_in2; +constexpr uint32_t cb_in3 = tt::CB::c_in3; +constexpr uint32_t bias_cb_id = tt::CB::c_in4; +constexpr uint32_t cb_out0 = tt::CB::c_out0; +constexpr uint32_t cb_intermed0 = tt::CB::c_intermed0; +constexpr uint32_t cb_intermed1 = tt::CB::c_intermed1; +constexpr uint32_t cb_intermed2 = tt::CB::c_intermed2; + +//////////////////// +// inline functions +//////////////////// +FORCE_INLINE void unravel_output_tidx(uint32_t output_tidx, uint32_t* output_idxes, uint32_t* output_stride) { + for (int32_t i = MAX_NUM_DIMENSIONS - 1; i >= 0;--i) { + uint32_t dim = output_tidx / output_stride[i]; + output_idxes[i] = dim; + output_tidx -= (output_idxes[i] * output_stride[i]); + } +} + +// TODO: move it to moreh_common.hpp if more use cases. +FORCE_INLINE void transpose_wh_tile_to_cb(uint32_t icb, uint32_t ocb, uint32_t itile = 0, uint32_t idst = 0) +{ + #if defined FP32_DEST_ACC_EN + unpack_reconfig_data_format_srca(icb); + #endif + transpose_wh_init_short(icb); + tile_regs_acquire(); + transpose_wh_tile(icb, itile, idst); + tile_regs_commit(); + cb_reserve_back(ocb, onetile); + tile_regs_wait(); + #if defined FP32_DEST_ACC_EN + pack_reconfig_data_format(ocb); + #endif + pack_tile(idst, ocb); + tile_regs_release(); + cb_push_back(ocb, onetile); +} + +FORCE_INLINE void transpose_tile(uint32_t &mm_src, bool transpose, bool need_mask, bool is_input) { + if (!transpose) { + return; + } + + if (need_mask) { + cb_wait_front(mm_src, onetile); + transpose_wh_tile_to_cb(mm_src, mm_src); + cb_pop_front(mm_src, onetile); + } + else { + uint32_t trans_src = (is_input) ? (cb_in0) : (cb_in1); + mm_src = (is_input) ? (cb_intermed1) : (cb_intermed2); + transpose_wh_tile_to_cb(trans_src, mm_src); + } +} + +FORCE_INLINE void pack_onetile_to_cb(uint32_t ocb = 16, uint32_t idst = 0) +{ + cb_reserve_back(ocb, onetile); + tile_regs_wait(); + #if defined FP32_DEST_ACC_EN + pack_reconfig_data_format(ocb); + #endif + pack_tile(idst, ocb); + tile_regs_release(); + cb_push_back(ocb, onetile); +} + +FORCE_INLINE void mask_tile_to_cb(uint32_t& mm_src, bool& need_mask, bool need_mask_h, bool need_mask_w, bool last_out, bool last_line, bool transpose, bool is_input) { + bool need_mask_last_line_and_out = (last_line && last_out); + bool need_mask_last_line = false; + bool need_mask_last_out = false; + + if (!(need_mask_w || need_mask_h)) { + return; + } + + if (is_input) { + need_mask_last_line = last_line && ((transpose) ? (need_mask_w) : (need_mask_h)); + need_mask_last_out = last_out && ((transpose) ? (need_mask_h) : (need_mask_w)); + } + else { + need_mask_last_line = last_line && ((transpose) ? (need_mask_h) : (need_mask_w)); + need_mask_last_out = last_out && ((transpose) ? (need_mask_w) : (need_mask_h)); + } + + if (need_mask_last_line_and_out || need_mask_last_line || need_mask_last_out) { + uint32_t cb_in = (is_input) ? (cb_in0) : (cb_in1); + uint32_t cb_mask = (is_input) ? (cb_in2) : (cb_in3); + uint32_t cb_intermed = (is_input) ? (cb_intermed1) : (cb_intermed2); + uint32_t mask_tidx = MASK_TILE_H_IDX; + if (need_mask_last_line_and_out) { + mask_tidx = MASK_TILE_HW_IDX; + } + else if (need_mask_last_line) { + if (is_input) { + mask_tidx = (transpose) ? (MASK_TILE_W_IDX) : (MASK_TILE_H_IDX); + } + else { + mask_tidx = (transpose) ? (MASK_TILE_H_IDX) : (MASK_TILE_W_IDX); + } + } + else { + if (is_input) { + mask_tidx = (transpose) ? (MASK_TILE_H_IDX) : (MASK_TILE_W_IDX); + } + else { + mask_tidx = (transpose) ? (MASK_TILE_W_IDX) : (MASK_TILE_H_IDX); + } + } + + // mul input tile with mask tile + tile_regs_acquire(); + #if defined FP32_DEST_ACC_EN + unpack_reconfig_data_format(cb_in0, cb_mask); + #endif + mul_tiles_init(cb_in, cb_mask); + mul_tiles(cb_in, cb_mask, 0, mask_tidx, 0); + tile_regs_commit(); + + pack_onetile_to_cb(cb_intermed); + mm_src = cb_intermed; + need_mask = true; + } +} + +#ifdef FUSE_BIAS +FORCE_INLINE void bias_add(bool is_scalar_bias) +{ + static bool scalar_bias_loaded = false; + pack_onetile_to_cb(cb_intermed0); + cb_wait_front(cb_intermed0, onetile); + + if (is_scalar_bias && !scalar_bias_loaded) { + cb_wait_front(bias_cb_id, onetile); + scalar_bias_loaded = true; + } else { + cb_wait_front(bias_cb_id, onetile); + } + + tile_regs_acquire(); + if (is_scalar_bias) { + #if defined FP32_DEST_ACC_EN + unpack_reconfig_data_format(cb_intermed0, bias_cb_id); + #endif + add_bcast_scalar_init_short(cb_intermed0, bias_cb_id); + add_tiles_bcast_scalar(cb_intermed0, bias_cb_id, 0, 0, 0); + } + else { + #if defined FP32_DEST_ACC_EN + unpack_reconfig_data_format(cb_intermed0, bias_cb_id); + #endif + add_bcast_rows_init_short(cb_intermed0, bias_cb_id); + add_tiles_bcast_rows(cb_intermed0, bias_cb_id, 0, 0, 0); + } + tile_regs_commit(); + + cb_pop_front(cb_intermed0, onetile); + if (!is_scalar_bias) { + cb_pop_front(bias_cb_id, onetile); + } +} +#endif + +FORCE_INLINE void matmul_with_transpose_and_mask(uint32_t output_tidx, uint32_t num_output_tiles, uint32_t Kt, bool transpose_input, bool transpose_other, + bool need_input_mask_h, bool need_input_mask_w, uint32_t *output_stride, uint32_t Mt, uint32_t Nt, + bool need_other_mask_h, bool need_other_mask_w, bool is_scalar_bias) +{ + // TODO: checking required when the input cb format and intermediate cb format are different. + mm_init(cb_in0, cb_in1, cb_out0); + if (transpose_input || transpose_other) { + transpose_wh_init(cb_in0); + } + + if (need_input_mask_h || need_input_mask_w) { + cb_wait_front(cb_in2, num_mask_tiles); + } + + if (need_other_mask_h || need_other_mask_w) { + cb_wait_front(cb_in3, num_mask_tiles); + } + + #pragma GCC unroll 0 + for (uint32_t i = 0; i < num_output_tiles; ++i) { + bool spill = Kt > 1; + bool enable_reload = false; + + // get row and column positions of input and other based on output tile indexes. + uint32_t output_idxes[MAX_NUM_DIMENSIONS]; + unravel_output_tidx(output_tidx, output_idxes, output_stride); + bool input_last_row = (output_idxes[1] == Mt - 1) ? (true) : (false); + bool other_last_col = (output_idxes[0] == Nt - 1) ? (true) : (false); + + #pragma GCC unroll 0 + for (uint32_t kt = 0; kt < Kt; kt++) { + bool last_out = kt == (Kt - 1); + bool need_input_mask = false; + bool need_other_mask = false; + + uint32_t mm_src0 = cb_in0; + uint32_t mm_src1 = cb_in0; + + cb_wait_front(cb_in0, onetile); + cb_wait_front(cb_in1, onetile); + + mm_src0 = cb_in0; + mm_src1 = cb_in1; + + //////////////////// + // mask: the first two arguments (mm_src0, need_input_mask) are passed by reference. + // transpose: the first argument (mm_src0) is passed by reference. + //////////////////// + mask_tile_to_cb(mm_src0, need_input_mask, need_input_mask_h, need_input_mask_w, last_out, input_last_row, transpose_input, true); + transpose_tile(mm_src0, transpose_input, need_input_mask, true); + + mask_tile_to_cb(mm_src1, need_other_mask, need_other_mask_h, need_other_mask_w, last_out, other_last_col, transpose_other, false); + transpose_tile(mm_src1, transpose_other, need_other_mask, false); + + //////////////////// + // matmul + //////////////////// + tile_regs_acquire(); + if (enable_reload) { + cb_wait_front(cb_intermed0, onetile); + #if defined FP32_DEST_ACC_EN + unpack_reconfig_data_format_srca(cb_intermed0); + #endif + copy_tile_to_dst_init_short(cb_intermed0); + copy_tile(cb_intermed0, 0, 0); + cb_pop_front(cb_intermed0, onetile); + } + + if (transpose_input || need_input_mask) { + cb_wait_front(mm_src0, onetile); + } + + if (transpose_other || need_other_mask) { + cb_wait_front(mm_src1, onetile); + } + + mm_init_short(mm_src0, mm_src1); + #if defined FP32_DEST_ACC_EN + unpack_reconfig_data_format(mm_src0, mm_src1); + #endif + matmul_tiles(mm_src0, mm_src1, 0, 0, 0, false); + tile_regs_commit(); + + cb_pop_front(cb_in0, onetile); + cb_pop_front(cb_in1, onetile); + + if (transpose_input || need_input_mask) { + cb_pop_front(mm_src0, onetile); + } + if (transpose_other || need_other_mask) { + cb_pop_front(mm_src1, onetile); + } + + if (last_out) { + //////////////////// + // bias add + //////////////////// + #ifdef FUSE_BIAS + bias_add(is_scalar_bias); + #endif + pack_onetile_to_cb(cb_out0); + } + else { + pack_onetile_to_cb(cb_intermed0); + } + + if (spill) { + enable_reload = true; + } + } + output_tidx++; + } +} + +FORCE_INLINE void matmul(uint32_t num_output_tiles, uint32_t Kt) { + mm_init(cb_in0, cb_in1, cb_out0); + for (uint32_t i = 0; i < num_output_tiles; ++i) { + tile_regs_acquire(); + for (uint32_t kt = 0; kt < Kt; kt++) { + cb_wait_front(cb_in0, onetile); + cb_wait_front(cb_in1, onetile); + matmul_tiles(cb_in0, cb_in1, 0, 0, 0, false); + cb_pop_front(cb_in0, onetile); + cb_pop_front(cb_in1, onetile); + } + tile_regs_commit(); + pack_onetile_to_cb(cb_out0); + } +} + +void MAIN { + // compile-time args + constexpr uint32_t num_output_tiles = get_compile_time_arg_val(0); + constexpr uint32_t Mt = get_compile_time_arg_val(1); + constexpr uint32_t Nt = get_compile_time_arg_val(2); + constexpr uint32_t Kt = get_compile_time_arg_val(3); + constexpr bool transpose_input = (get_compile_time_arg_val(4) == 1); + constexpr bool transpose_other = (get_compile_time_arg_val(5) == 1); + constexpr uint32_t input_mask_h = get_compile_time_arg_val(6); + constexpr uint32_t input_mask_w = get_compile_time_arg_val(7); + constexpr uint32_t other_mask_h = get_compile_time_arg_val(8); + constexpr uint32_t other_mask_w = get_compile_time_arg_val(9); + #ifdef FUSE_BIAS + constexpr bool is_scalar_bias = (get_compile_time_arg_val(10) == 1); + constexpr bool need_bias_add = true; + #else + constexpr bool is_scalar_bias = false; + constexpr bool need_bias_add = false; + #endif + constexpr bool need_input_mask_h = (input_mask_h != 32); + constexpr bool need_input_mask_w = (input_mask_w != 32); + constexpr bool need_other_mask_h = (other_mask_h != 32); + constexpr bool need_other_mask_w = (other_mask_w != 32); + constexpr bool need_mask = (need_input_mask_h || need_input_mask_w || need_other_mask_h || need_other_mask_w); + constexpr bool need_transpose = (transpose_input || transpose_other); + + // runtime args + ArgFetcher arg_fetcher; + uint32_t output_tile_start_idx = arg_fetcher.get_next_arg_val(); + uint32_t output_stride[MAX_NUM_DIMENSIONS]; + for (int32_t i = 0; i < MAX_NUM_DIMENSIONS;++i) { + output_stride[i] = arg_fetcher.get_next_arg_val(); + } + + if (need_transpose || need_mask || need_bias_add) { + matmul_with_transpose_and_mask(output_tile_start_idx, num_output_tiles, Kt, transpose_input, transpose_other, + need_input_mask_h, need_input_mask_w, output_stride, Mt, Nt, need_other_mask_h, need_other_mask_w, is_scalar_bias); + } + else { + matmul(num_output_tiles, Kt); + } +} +} // namespace NAMESPACE diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/kernels/reader_moreh_matmul.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/kernels/reader_moreh_matmul.cpp new file mode 100644 index 00000000000..5652d64d678 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/kernels/reader_moreh_matmul.cpp @@ -0,0 +1,170 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp" + +static constexpr int32_t MAX_NUM_DIMENSIONS = 8; + +inline uint32_t get_tidx(uint32_t* output_idxes, uint32_t* stride, uint32_t* not_bcast, bool transpose, bool use_h_dim) { + uint32_t tidx = 0; + // batch dim + for (int32_t i = MAX_NUM_DIMENSIONS - 1; i >= 2; --i) { + tidx += not_bcast[i] * stride[i] * output_idxes[i]; + } + + // last 2-dim + int32_t i = transpose ? (use_h_dim ? 0 : 1) : (use_h_dim ? 1 : 0); + tidx += not_bcast[i] * stride[i] * output_idxes[use_h_dim ? 1 : 0]; + return tidx; +} + +inline void unravel_output_tidx(uint32_t output_tidx, uint32_t* output_idxes, uint32_t* output_stride) { + for (int32_t i = MAX_NUM_DIMENSIONS - 1; i >= 0;--i) { + uint32_t dim = output_tidx / output_stride[i]; + output_idxes[i] = dim; + output_tidx -= (output_idxes[i] * output_stride[i]); + } +} + +void kernel_main() { + // compile-time args + constexpr bool input_is_dram = get_compile_time_arg_val(0) == 1; + constexpr bool other_is_dram = get_compile_time_arg_val(1) == 1; + constexpr uint32_t Kt = get_compile_time_arg_val(2); + bool transpose_input = (get_compile_time_arg_val(3) == 1); + bool transpose_other = (get_compile_time_arg_val(4) == 1); + uint32_t input_mask_h = get_compile_time_arg_val(5); + uint32_t input_mask_w = get_compile_time_arg_val(6); + uint32_t other_mask_h = get_compile_time_arg_val(7); + uint32_t other_mask_w = get_compile_time_arg_val(8); + #ifdef FUSE_BIAS + constexpr bool bias_is_dram = (get_compile_time_arg_val(9) == 1); + bool is_scalar_bias = (get_compile_time_arg_val(10) == 1); + bool scalar_bias_loaded = false; + #endif + + // runtime args + ArgFetcher arg_fetcher; + uint32_t input_addr = arg_fetcher.get_next_arg_val(); + uint32_t other_addr = arg_fetcher.get_next_arg_val(); + uint32_t output_tile_start_idx = arg_fetcher.get_next_arg_val(); + uint32_t num_output_tiles = arg_fetcher.get_next_arg_val(); + + uint32_t input_stride[MAX_NUM_DIMENSIONS]; + uint32_t other_stride[MAX_NUM_DIMENSIONS]; + uint32_t output_stride[MAX_NUM_DIMENSIONS]; + uint32_t input_not_bcast[MAX_NUM_DIMENSIONS]; + uint32_t other_not_bcast[MAX_NUM_DIMENSIONS]; + + for (int32_t i = 0; i < MAX_NUM_DIMENSIONS;++i) { + input_stride[i] = arg_fetcher.get_next_arg_val(); + } + for (int32_t i = 0; i < MAX_NUM_DIMENSIONS;++i) { + other_stride[i] = arg_fetcher.get_next_arg_val(); + } + for (int32_t i = 0; i < MAX_NUM_DIMENSIONS;++i) { + output_stride[i] = arg_fetcher.get_next_arg_val(); + } + for (int32_t i = 0; i < MAX_NUM_DIMENSIONS;++i) { + input_not_bcast[i] = arg_fetcher.get_next_arg_val(); + } + for (int32_t i = 0; i < MAX_NUM_DIMENSIONS;++i) { + other_not_bcast[i] = arg_fetcher.get_next_arg_val(); + } + + #ifdef FUSE_BIAS + uint32_t bias_addr = arg_fetcher.get_next_arg_val(); + #endif + + constexpr uint32_t cb_id_in0 = 0; + constexpr uint32_t cb_id_in1 = 1; + constexpr uint32_t cb_id_in2 = 2; + constexpr uint32_t cb_id_in3 = 3; + constexpr uint32_t cb_id_in4 = 4; + constexpr uint32_t onetile = 1; + + const uint32_t in0_tile_bytes = get_tile_size(cb_id_in0); + const DataFormat in0_data_format = get_dataformat(cb_id_in0); + const uint32_t in1_tile_bytes = get_tile_size(cb_id_in1); + const DataFormat in1_data_format = get_dataformat(cb_id_in1); + + const InterleavedAddrGenFast s0 = { + .bank_base_address = input_addr, .page_size = in0_tile_bytes, .data_format = in0_data_format}; + + const InterleavedAddrGenFast s1 = { + .bank_base_address = other_addr, .page_size = in1_tile_bytes, .data_format = in1_data_format}; + + #ifdef FUSE_BIAS + const uint32_t in4_tile_bytes = get_tile_size(cb_id_in4); + const DataFormat in4_data_format = get_dataformat(cb_id_in4); + const InterleavedAddrGenFast s_bias = { + .bank_base_address = bias_addr, .page_size = in4_tile_bytes, .data_format = in4_data_format}; + #endif + + // mask + bool need_input_mask_h = (input_mask_h != 32); + bool need_input_mask_w = (input_mask_w != 32); + + if (need_input_mask_h || need_input_mask_w) { + generate_mask_tiles(cb_id_in2, input_mask_h, input_mask_w); + } + + bool need_other_mask_h = (other_mask_h != 32); + bool need_other_mask_w = (other_mask_w != 32); + if (need_other_mask_h || need_other_mask_w) { + generate_mask_tiles(cb_id_in3, other_mask_h, other_mask_w); + } + + uint32_t output_tidx = output_tile_start_idx; + uint32_t input_step_count = (transpose_input) ? (input_stride[1]) : (input_stride[0]); + uint32_t other_step_count = (transpose_other) ? (other_stride[0]) : (other_stride[1]); + + for (uint32_t n = 0; n < num_output_tiles; n++) { + uint32_t output_idxes[MAX_NUM_DIMENSIONS]; + unravel_output_tidx(output_tidx, output_idxes, output_stride); + uint32_t input_tidx = get_tidx(output_idxes, input_stride, input_not_bcast, transpose_input, true); + uint32_t other_tidx = get_tidx(output_idxes, other_stride, other_not_bcast, transpose_other, false); + + for (uint32_t kt = 0; kt < Kt; kt++) { + // read input, other tile + cb_reserve_back(cb_id_in0, onetile); + cb_reserve_back(cb_id_in1, onetile); + + uint32_t l1_write_addr_in0 = get_write_ptr(cb_id_in0); + noc_async_read_tile(input_tidx, s0, l1_write_addr_in0); + + uint32_t l1_write_addr_in1 = get_write_ptr(cb_id_in1); + noc_async_read_tile(other_tidx, s1, l1_write_addr_in1); + noc_async_read_barrier(); + + cb_push_back(cb_id_in0, onetile); + cb_push_back(cb_id_in1, onetile); + + input_tidx += input_step_count; + other_tidx += other_step_count; + } + #ifdef FUSE_BIAS + if (!is_scalar_bias) { + uint32_t bias_tidx = output_idxes[0]; + cb_reserve_back(cb_id_in4, onetile); + uint32_t l1_write_addr_in4 = get_write_ptr(cb_id_in4); + noc_async_read_tile(bias_tidx, s_bias, l1_write_addr_in4); + noc_async_read_barrier(); + cb_push_back(cb_id_in4, onetile); + } else { + if (!scalar_bias_loaded) { + cb_reserve_back(cb_id_in4, onetile); + uint32_t l1_write_addr_in4 = get_write_ptr(cb_id_in4); + noc_async_read_tile(0, s_bias, l1_write_addr_in4); + noc_async_read_barrier(); + cb_push_back(cb_id_in4, onetile); + scalar_bias_loaded = true; + } + } + #endif + + + output_tidx++; + } +} diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/kernels/writer_moreh_matmul.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/kernels/writer_moreh_matmul.cpp new file mode 100644 index 00000000000..ef4963f0483 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/kernels/writer_moreh_matmul.cpp @@ -0,0 +1,33 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp" + +void kernel_main() { + // compile-time args + constexpr bool output_is_dram = (get_compile_time_arg_val(0) == 1); + + // runtime args + ArgFetcher arg_fetcher; + uint32_t output_addr = arg_fetcher.get_next_arg_val(); + uint32_t start_id = arg_fetcher.get_next_arg_val(); + uint32_t num_output_tiles = arg_fetcher.get_next_arg_val(); + + constexpr uint32_t onetile = 1; + constexpr uint32_t cb_id_out = 16; + const uint32_t output_tile_bytes = get_tile_size(cb_id_out); + const DataFormat output_data_format = get_dataformat(cb_id_out); + + const InterleavedAddrGenFast s = { + .bank_base_address = output_addr, .page_size = output_tile_bytes, .data_format = output_data_format }; + + uint32_t end_id = start_id + num_output_tiles; + for (uint32_t i = start_id; i < end_id; i++) { + cb_wait_front(cb_id_out, onetile); + uint32_t l1_read_addr = get_read_ptr(cb_id_out); + noc_async_write_tile(i, s, l1_read_addr); + noc_async_write_barrier(); + cb_pop_front(cb_id_out, onetile); + } +} diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.cpp new file mode 100644 index 00000000000..852bbe9eed5 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.cpp @@ -0,0 +1,243 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "moreh_matmul_device_operation.hpp" + +#include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp" +#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" +#include "ttnn/tensor/tensor.hpp" + +namespace ttnn::operations::moreh::moreh_matmul { + +void MorehMatmulOperation::validate_inputs( + const operation_attributes_t &operation_attributes, const tensor_args_t &tensor_args) { + const bool transpose_input = operation_attributes.transpose_input; + const bool transpose_other = operation_attributes.transpose_other; + + log_debug(tt::LogOp, "{}:{}", __func__, __LINE__); + + const auto &input = tensor_args.input; + const auto &other = tensor_args.other; + const auto &bias = tensor_args.bias; + const auto &output = tensor_args.output; + + // validate tensor + tt::operations::primary::check_tensor(input, "moreh_matmul", "input"); + tt::operations::primary::check_tensor(other, "moreh_matmul", "other"); + tt::operations::primary::check_tensor(output, "moreh_matmul", "output"); + tt::operations::primary::check_tensor(bias, "moreh_matmul", "bias"); + + // check matrix dims + const auto &input_shape = input.get_shape().value.without_padding(); + const auto &other_shape = other.get_shape().value.without_padding(); + const auto &input_wo_shape = input_shape.without_padding(); + const auto &other_wo_shape = other_shape.without_padding(); + uint32_t input_m = (transpose_input) ? (input_wo_shape[-1]) : (input_wo_shape[-2]); + uint32_t input_k = (transpose_input) ? (input_wo_shape[-2]) : (input_wo_shape[-1]); + uint32_t other_k = (transpose_other) ? (other_wo_shape[-1]) : (other_wo_shape[-2]); + uint32_t other_n = (transpose_other) ? (other_wo_shape[-2]) : (other_wo_shape[-1]); + + TT_FATAL(input_k == other_k, "k must be the same. input_k {}, other_k {}", input_k, other_k); + + // check batch dims + std::vector input_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + std::vector other_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + get_tensor_dim(input_dim, input_shape); + get_tensor_dim(other_dim, other_shape); + for (auto i = 2; i < tt::tt_metal::MAX_NUM_DIMENSIONS; ++i) { + if (input_dim[i] != other_dim[i]) { + TT_FATAL( + input_dim[i] == 1 || other_dim[i] == 1, + "one of dim must be one. {}th dim input_dim {}, other_dim {}", + i, + input_dim[i], + other_dim[i]); + } + } + + // check output dims + if (output.has_value()) { + const auto &output_shape = output.value().get_legacy_shape().without_padding(); + const auto &output_wo_shape = output_shape.without_padding(); + uint32_t output_m = output_wo_shape[-2]; + uint32_t output_n = output_wo_shape[-1]; + TT_FATAL(input_m == output_m, "m must be the same. input_m {}, output_m {}", input_m, output_m); + TT_FATAL(other_n == output_n, "n must be the same. other_n {}, output_n {}", other_n, output_n); + + std::vector output_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + get_tensor_dim(output_dim, output_shape); + + for (auto i = 2; i < tt::tt_metal::MAX_NUM_DIMENSIONS; ++i) { + TT_FATAL( + std::max(input_dim[i], other_dim[i]) == output_dim[i], + "{}th max(input_dim[i], other_dim[i]) {} must be the same as output_dim[i] {}", + i, + std::max(input_dim[i], other_dim[i]), + output_dim[i]); + } + } + + // check bias size + if (bias.has_value()) { + const auto &bias_wo_shape = bias.value().get_legacy_shape().without_padding(); + uint32_t bias_rank = bias_wo_shape.rank(); + uint32_t bias_w = bias_wo_shape[-1]; + TT_FATAL(bias_rank == 2, "bias rank {} must be 2 (tilized).", bias_rank); + TT_FATAL( + bias_w == 1 || bias_w == other_n, + "bias_w must be one or the same as other_n. bias_w {}, other_n {}", + bias_w, + other_n); + } +} + +void MorehMatmulOperation::validate_on_program_cache_hit( + const operation_attributes_t &operation_attributes, const tensor_args_t &tensor_args) { + validate_inputs(operation_attributes, tensor_args); +} + +void MorehMatmulOperation::validate_on_program_cache_miss( + const operation_attributes_t &operation_attributes, const tensor_args_t &tensor_args) { + validate_inputs(operation_attributes, tensor_args); +} + +MorehMatmulOperation::shape_return_value_t compute_output_shapes( + const MorehMatmulOperation::operation_attributes_t &operation_attributes, + const MorehMatmulOperation::tensor_args_t &tensor_args) { + auto input_shape = tensor_args.input.get_shape().value; + auto other_shape = tensor_args.other.get_shape().value; + + auto transpose_input = operation_attributes.transpose_input; + auto transpose_other = operation_attributes.transpose_other; + + const auto &input_shape_wo_padding = input_shape.without_padding(); + const auto &other_shape_wo_padding = other_shape.without_padding(); + + auto h = (transpose_input) ? (input_shape[-1]) : (input_shape[-2]); + auto w = (transpose_other) ? (other_shape[-2]) : (other_shape[-1]); + auto h_wo_padding = (transpose_input) ? (input_shape_wo_padding[-1]) : (input_shape_wo_padding[-2]); + auto w_wo_padding = (transpose_other) ? (other_shape_wo_padding[-2]) : (other_shape_wo_padding[-1]); + + std::vector input_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + std::vector other_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + get_tensor_dim(input_dim, input_shape); + get_tensor_dim(other_dim, other_shape); + + int32_t output_rank = std::max(input_shape.rank(), other_shape.rank()); + log_debug( + tt::LogOp, + "{}:{} input, other, output rank {}, {}, {}", + __func__, + __LINE__, + input_shape.rank(), + other_shape.rank(), + output_rank); + + std::vector output_dim(output_rank); + // batch dims + for (int i = 0; i < output_rank - 2; ++i) { + int idx = output_rank - 1 - i; + TT_ASSERT(idx >= 0); + uint32_t max_dim = std::max(input_dim[idx], other_dim[idx]); + output_dim[i] = max_dim; + } + // matrix dims + output_dim[output_rank - 2] = h; + output_dim[output_rank - 1] = w; + + tt::tt_metal::Shape output_shape{output_dim}; + auto padding = output_shape.padding(); + // padding for t logmatrix dims + padding[output_rank - 2] = Padding::PadDimension{0, h - h_wo_padding}; + padding[output_rank - 1] = Padding::PadDimension{0, w - w_wo_padding}; + return Shape({tt::tt_metal::Shape(output_shape, padding)}); +} +MorehMatmulOperation::tensor_return_value_t MorehMatmulOperation::create_output_tensors( + const MorehMatmulOperation::operation_attributes_t &operation_attributes, + const MorehMatmulOperation::tensor_args_t &tensor_args) { + if (tensor_args.output.has_value()) { + return tensor_args.output.value(); + } + + return create_device_tensor( + compute_output_shapes(operation_attributes, tensor_args), + tensor_args.input.get_dtype(), + Layout::TILE, + tensor_args.input.device(), + operation_attributes.output_memory_config); +}; + +MorehMatmulOperation::program_factory_t MorehMatmulOperation::select_program_factory( + const operation_attributes_t &, const tensor_args_t &) { + return MultiCoreProgramFactory{}; +} + +MorehMatmulOperation::shape_return_value_t MorehMatmulOperation::compute_output_shapes( + const operation_attributes_t &operation_attributes, const tensor_args_t &tensor_args) { + const auto &input_shape = tensor_args.input.get_shape().value; + const auto &other_shape = tensor_args.other.get_shape().value; + bool transpose_input = operation_attributes.transpose_input; + bool transpose_other = operation_attributes.transpose_other; + const auto &input_shape_wo_padding = input_shape.without_padding(); + const auto &other_shape_wo_padding = other_shape.without_padding(); + + auto h = (transpose_input) ? (input_shape[-1]) : (input_shape[-2]); + auto w = (transpose_other) ? (other_shape[-2]) : (other_shape[-1]); + auto h_wo_padding = (transpose_input) ? (input_shape_wo_padding[-1]) : (input_shape_wo_padding[-2]); + auto w_wo_padding = (transpose_other) ? (other_shape_wo_padding[-2]) : (other_shape_wo_padding[-1]); + + std::vector input_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + std::vector other_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + get_tensor_dim(input_dim, input_shape); + get_tensor_dim(other_dim, other_shape); + + int32_t output_rank = std::max(input_shape.rank(), other_shape.rank()); + log_debug( + tt::LogOp, + "{}:{} input, other, output rank {}, {}, {}", + __func__, + __LINE__, + input_shape.rank(), + other_shape.rank(), + output_rank); + + std::vector output_dim(output_rank); + // batch dims + for (int i = 0; i < output_rank - 2; ++i) { + int idx = output_rank - 1 - i; + TT_ASSERT(idx >= 0); + const uint32_t max_dim = std::max(input_dim[idx], other_dim[idx]); + output_dim[i] = max_dim; + } + // matrix dims + output_dim[output_rank - 2] = h; + output_dim[output_rank - 1] = w; + + tt::tt_metal::Shape output_shape{output_dim}; + auto padding = output_shape.padding(); + // padding for t logmatrix dims + padding[output_rank - 2] = Padding::PadDimension{0, h - h_wo_padding}; + padding[output_rank - 1] = Padding::PadDimension{0, w - w_wo_padding}; + return Shape({tt::tt_metal::Shape(output_shape, padding)}); +} + +std::tuple +MorehMatmulOperation::invoke( + const Tensor &input, + const Tensor &other, + bool transpose_input, + bool transpose_other, + const std::optional &output, + const std::optional &bias, + const std::optional &output_memory_config, + const std::optional &compute_kernel_config) { + return { + MorehMatmulOperation::operation_attributes_t{ + transpose_input, + transpose_other, + output_memory_config.value_or(input.memory_config()), + compute_kernel_config}, + MorehMatmulOperation::tensor_args_t{input, other, output, bias}}; +} +} // namespace ttnn::operations::moreh::moreh_matmul diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.hpp new file mode 100644 index 00000000000..340befe8379 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.hpp @@ -0,0 +1,84 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "ttnn/decorators.hpp" +#include "ttnn/device_operation.hpp" +#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" +#include "ttnn/tensor/tensor.hpp" +#include "ttnn/tensor/types.hpp" + +namespace ttnn::operations::moreh::moreh_matmul { +struct MorehMatmulOperation { + struct operation_attributes_t { + bool transpose_input; + bool transpose_other; + + const MemoryConfig output_memory_config; + const std::optional compute_kernel_config; + }; + + struct tensor_args_t { + const Tensor& input; + const Tensor& other; + + const std::optional& output; + const std::optional& bias; + }; + + using shape_return_value_t = Shape; + using tensor_return_value_t = Tensor; + + struct MultiCoreProgramFactory { + struct shared_variable_t { + KernelHandle reader_kernel_id; + KernelHandle writer_kernel_id; + std::size_t num_cores; + std::size_t num_cores_y; + }; + + using cached_program_t = ttnn::device_operation::CachedProgram; + + static cached_program_t create( + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value); + + static void override_runtime_arguments( + cached_program_t& cached_program, + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value); + }; + + using program_factory_t = std::variant; + + static void validate_inputs(const operation_attributes_t&, const tensor_args_t&); + static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&); + static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&); + static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&); + static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&); + static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&); + static std::tuple invoke( + const Tensor& input, + const Tensor& other, + bool transpose_input, + bool transpose_other, + const std::optional& output, + const std::optional& bias, + const std::optional& output_memory_config, + const std::optional& compute_kernel_config); +}; + +void get_tensor_dim(std::vector& dim, const tt::tt_metal::Shape& shape); +std::vector find_reduce_dim(const tt::tt_metal::Shape& a_shape, const tt::tt_metal::Shape& b_shape); +bool is_same_batch_dim(const Tensor& tensor_a, const Tensor& tensor_b); + +} // namespace ttnn::operations::moreh::moreh_matmul + +namespace ttnn::prim { +constexpr auto moreh_matmul = + ttnn::register_operation<"ttnn::prim::moreh_matmul", ttnn::operations::moreh::moreh_matmul::MorehMatmulOperation>(); +} diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_program_factory.cpp new file mode 100644 index 00000000000..8a17bb43b28 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_program_factory.cpp @@ -0,0 +1,523 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "moreh_matmul_device_operation.hpp" +#include "tt_metal/common/work_split.hpp" +#include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp" +#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" + +namespace ttnn::operations::moreh::moreh_matmul { + +void get_tensor_dim(std::vector &dim, const tt::tt_metal::Shape &shape) { + const auto rank = shape.rank(); + for (auto i = 0; i < rank; ++i) { + auto idx = rank - 1 - i; + + // last 2-dim + if (idx == rank - 1 || idx == rank - 2) { + dim[i] = shape[idx] / tt::constants::TILE_HEIGHT; + } else { + dim[i] = shape[idx]; + } + } + + log_debug(tt::LogOp, "rank {}", rank); + for (auto i = 0; i < tt::tt_metal::MAX_NUM_DIMENSIONS; ++i) { + log_debug(tt::LogOp, "dim[{}] = {}", i, dim[i]); + } +} + +std::vector find_reduce_dim(const tt::tt_metal::Shape &a_shape, const tt::tt_metal::Shape &b_shape) { + std::vector a_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + std::vector b_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + get_tensor_dim(a_dim, a_shape); + get_tensor_dim(b_dim, b_shape); + int32_t rank = std::max(a_shape.rank(), b_shape.rank()); + log_debug(tt::LogOp, "find_reduce_dim :{} rank {} a {} b {}", __LINE__, rank, a_shape.rank(), b_shape.rank()); + std::vector dims; + // batch dims + for (int i = 0; i < rank - 2; ++i) { + int idx = rank - 1 - i; + TT_ASSERT(idx >= 0); + if (a_dim[idx] != b_dim[idx]) { + dims.push_back(i); + log_debug(tt::LogOp, "find_reduce_dim :{} push {} dim", __LINE__, i); + } + } + return dims; +} + +bool is_same_batch_dim(const Tensor &tensor_a, const Tensor &tensor_b) { + // check batch dims + const auto &a_shape = tensor_a.get_shape().value; + const auto &b_shape = tensor_b.get_shape().value; + std::vector a_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + std::vector b_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + get_tensor_dim(a_dim, a_shape); + get_tensor_dim(b_dim, b_shape); + for (auto i = 2; i < tt::tt_metal::MAX_NUM_DIMENSIONS; ++i) { + if (a_dim[i] != b_dim[i]) { + log_debug(tt::LogOp, "{}:{} {} a_dim {} - b_dim {}", __func__, __LINE__, i, a_dim[i], b_dim[i]); + return false; + } + } + log_debug(tt::LogOp, "{}:{} batch dims are the same.", __func__, __LINE__); + return true; +} + +void get_tensor_stride(std::vector &stride, std::vector &dim) { + stride[0] = 1; + for (auto i = 1; i < tt::tt_metal::MAX_NUM_DIMENSIONS; ++i) { + stride[i] = stride[i - 1] * dim[i - 1]; + } + + for (auto i = 0; i < tt::tt_metal::MAX_NUM_DIMENSIONS; ++i) { + log_debug(tt::LogOp, "stride[{}] = {}", i, stride[i]); + } +} + +void get_not_bcast( + std::vector &input_not_bcast, + std::vector &input_dim, + std::vector &other_not_bcast, + std::vector &other_dim) { + // first 2-dims are M,K and K,N + // TODO: refaactoring + for (auto i = 2; i < tt::tt_metal::MAX_NUM_DIMENSIONS; ++i) { + if (input_dim[i] == other_dim[i]) { + input_not_bcast[i] = 1; + other_not_bcast[i] = 1; + } else { + if (input_dim[i] == 1) { + input_not_bcast[i] = 0; + other_not_bcast[i] = 1; + } else { + input_not_bcast[i] = 1; + other_not_bcast[i] = 0; + } + } + } + + for (auto i = 0; i < tt::tt_metal::MAX_NUM_DIMENSIONS; ++i) { + log_debug(tt::LogOp, "not bcast [{}] input {} other {}", i, input_not_bcast[i], other_not_bcast[i]); + } +} + +MorehMatmulOperation::MultiCoreProgramFactory::cached_program_t MorehMatmulOperation::MultiCoreProgramFactory::create( + const operation_attributes_t &operation_attributes, + const tensor_args_t &tensor_args, + tensor_return_value_t &tensor_return_value) { + const Tensor &input = tensor_args.input; + const Tensor &other = tensor_args.other; + const Tensor &output = tensor_return_value; + + const std::optional &bias = tensor_args.bias; + + bool transpose_input = operation_attributes.transpose_input; + bool transpose_other = operation_attributes.transpose_other; + + const DeviceComputeKernelConfig &compute_kernel_config = init_device_compute_kernel_config( + input.device()->arch(), operation_attributes.compute_kernel_config, MathFidelity::HiFi4); + ; + + //////////////////////////////////////////////////////////////////////////// + // Device Setup + //////////////////////////////////////////////////////////////////////////// + tt::tt_metal::Program program{}; + tt::tt_metal::Device *device{input.device()}; + + //////////////////////////////////////////////////////////////////////////// + // Parameters Setup + //////////////////////////////////////////////////////////////////////////// + tt::DataFormat cb_data_format{datatype_to_dataformat_converter(output.get_dtype())}; + const auto single_tile_size{tt::tt_metal::detail::TileSize(cb_data_format)}; + const auto num_output_tiles{output.volume() / tt::constants::TILE_HW}; + + // input tensor + const auto &input_shape = input.get_shape().value; + const auto &input_shape_wo_padding = input_shape.without_padding(); + const auto input_rank = input_shape.rank(); + log_debug(tt::LogOp, "input dim"); + std::vector input_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + get_tensor_dim(input_dim, input_shape); + + log_debug(tt::LogOp, "input stride"); + std::vector input_stride(tt::tt_metal::MAX_NUM_DIMENSIONS); + get_tensor_stride(input_stride, input_dim); + + // other tensor + const auto &other_shape = other.get_shape().value; + const auto &other_shape_wo_padding = other_shape.without_padding(); + const auto other_rank = other_shape.rank(); + log_debug(tt::LogOp, "other dim"); + std::vector other_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + get_tensor_dim(other_dim, other_shape); + + log_debug(tt::LogOp, "other stride"); + std::vector other_stride(tt::tt_metal::MAX_NUM_DIMENSIONS); + get_tensor_stride(other_stride, other_dim); + + log_debug(tt::LogOp, "not bcast"); + std::vector input_not_bcast(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + std::vector other_not_bcast(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + get_not_bcast(input_not_bcast, input_dim, other_not_bcast, other_dim); + + // output tensor + const auto &output_shape = output.get_shape().value; + const auto &output_shape_wo_padding = output_shape.without_padding(); + const auto output_rank = output_shape.rank(); + log_debug(tt::LogOp, "output dim"); + std::vector output_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + get_tensor_dim(output_dim, output_shape); + + log_debug(tt::LogOp, "output stride"); + std::vector output_stride(tt::tt_metal::MAX_NUM_DIMENSIONS); + get_tensor_stride(output_stride, output_dim); + + // matrix shape + uint32_t Kt = (transpose_input) ? (input_shape[-2] / tt::constants::TILE_HEIGHT) + : (input_shape[-1] / tt::constants::TILE_WIDTH); + uint32_t Mt = (transpose_input) ? (input_shape[-1] / tt::constants::TILE_WIDTH) + : (input_shape[-2] / tt::constants::TILE_HEIGHT); + uint32_t Nt = (transpose_other) ? (other_shape[-2] / tt::constants::TILE_HEIGHT) + : (other_shape[-1] / tt::constants::TILE_WIDTH); + log_debug(tt::LogOp, "{}:{} Mt {} Nt {} Kt {}", __func__, __LINE__, Mt, Nt, Kt); + + // bias tensor + bool is_scalar_bias = false; + if (bias.has_value()) { + const auto &bias_tensor = bias.value(); + const auto &bias_shape_wo_padding = bias_tensor.get_shape().value.without_padding(); + is_scalar_bias = (bias_shape_wo_padding[-1] == 1) ? (true) : (false); + log_debug(tt::LogOp, "{}:{} bias tensor. is_scalar_bias {}", __func__, __LINE__, is_scalar_bias); + } + + // mask + uint32_t input_mask_h = input_shape_wo_padding[-2] % tt::constants::TILE_HEIGHT; + uint32_t input_mask_w = input_shape_wo_padding[-1] % tt::constants::TILE_WIDTH; + uint32_t other_mask_h = other_shape_wo_padding[-2] % tt::constants::TILE_HEIGHT; + uint32_t other_mask_w = other_shape_wo_padding[-1] % tt::constants::TILE_WIDTH; + + bool need_input_mask_h = (input_mask_h) ? (true) : (false); + bool need_input_mask_w = (input_mask_w) ? (true) : (false); + + bool need_other_mask_h = (other_mask_h) ? (true) : (false); + bool need_other_mask_w = (other_mask_w) ? (true) : (false); + + if (input_mask_h == 0) { + input_mask_h = tt::constants::TILE_HEIGHT; + } + if (input_mask_w == 0) { + input_mask_w = tt::constants::TILE_WIDTH; + } + if (other_mask_h == 0) { + other_mask_h = tt::constants::TILE_HEIGHT; + } + if (other_mask_w == 0) { + other_mask_w = tt::constants::TILE_WIDTH; + } + + log_debug( + tt::LogOp, + "{}:{} {} {} mask_h {} mask_w {}", + __func__, + __LINE__, + need_input_mask_h, + need_input_mask_w, + input_mask_h, + input_mask_w); + log_debug( + tt::LogOp, + "{}:{} {} {} mask_h {} mask_w {}", + __func__, + __LINE__, + need_other_mask_h, + need_other_mask_w, + other_mask_h, + other_mask_w); + + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + get_compute_kernel_config_args(device->arch(), compute_kernel_config); + log_debug( + tt::LogOp, + "math_fidelity {} math_approx_mode {} fp32_dest_acc_en {} packer_l1_acc {}", + math_fidelity, + math_approx_mode, + fp32_dest_acc_en, + packer_l1_acc); + //////////////////////////////////////////////////////////////////////////// + // Core Grid Configuration For Workload + //////////////////////////////////////////////////////////////////////////// + auto grid = device->compute_with_storage_grid_size(); + const auto num_cores_y = grid.y; + + const auto + [num_cores, + all_cores, + core_group_1, + core_group_2, + num_output_tiles_per_core_group_1, + num_output_tiles_per_core_group_2] = tt::tt_metal::split_work_to_cores(grid, num_output_tiles); + + log_debug(tt::LogOp, "{}:{} num_output_tiles: {}", __func__, __LINE__, num_output_tiles); + log_debug( + tt::LogOp, + "{}:{} num_output_tiles_per_core_group1: {}, 2: {} ", + __func__, + __LINE__, + num_output_tiles_per_core_group_1, + num_output_tiles_per_core_group_2); + //////////////////////////////////////////////////////////////////////////// + // CircularBuffer Setup + //////////////////////////////////////////////////////////////////////////// + const uint32_t in0_t{2}; // input + const uint32_t in1_t{2}; // other + const uint32_t in2_t{3}; // mask for input + const uint32_t in3_t{3}; // mask for other + const uint32_t in4_t{2}; // bias + const uint32_t im0_t{1}; // temp + const uint32_t im1_t{2}; // transpose for input + const uint32_t im2_t{2}; // transpose for other + const uint32_t out0_t{2}; // output + + tt::operations::primary::CreateCircularBuffer( + program, + all_cores, + cb_data_format, + { + {tt::CB::c_in0, in0_t}, + {tt::CB::c_in1, in1_t}, + {tt::CB::c_in2, in2_t}, + {tt::CB::c_in3, in3_t}, + {tt::CB::c_in4, in4_t}, + {tt::CB::c_intermed0, im0_t, (fp32_dest_acc_en) ? tt::DataFormat::Float32 : cb_data_format}, + {tt::CB::c_intermed1, im1_t}, + {tt::CB::c_intermed2, im2_t}, + {tt::CB::c_out0, out0_t}, + }); + + //////////////////////////////////////////////////////////////////////////// + // DataMovementKernel SetUp + //////////////////////////////////////////////////////////////////////////// + std::map reader_defines; + std::vector reader_compile_time_args = { + static_cast(tt::operations::primary::is_dram(input)), + static_cast(tt::operations::primary::is_dram(other)), + Kt, + static_cast(transpose_input), + static_cast(transpose_other), + input_mask_h, + input_mask_w, + other_mask_h, + other_mask_w, + }; + + if (bias.has_value()) { + reader_defines["FUSE_BIAS"] = "1"; + reader_compile_time_args.push_back(static_cast(tt::operations::primary::is_dram(bias))); + reader_compile_time_args.push_back(static_cast(is_scalar_bias)); + log_debug( + tt::LogOp, + "{}:{} bias tensor. is bias dram {}", + __func__, + __LINE__, + tt::operations::primary::is_dram(bias)); + } + + const std::vector writer_compile_time_args = { + static_cast(tt::operations::primary::is_dram(output))}; + + const auto reader_kernel_file = + "ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/kernels/reader_moreh_matmul.cpp"; + const auto writer_kernel_file = + "ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/kernels/writer_moreh_matmul.cpp"; + + const auto reader_kernel_id = tt::operations::primary::CreateReadKernel( + program, reader_kernel_file, all_cores, reader_compile_time_args, reader_defines); + const auto writer_kernel_id = + tt::operations::primary::CreateWriteKernel(program, writer_kernel_file, all_cores, writer_compile_time_args); + log_debug( + tt::LogOp, + "{}:{} DMVK is_dram(input): {}, is_dram(other): {}, is_dram(output): {}", + __func__, + __LINE__, + tt::operations::primary::is_dram(input), + tt::operations::primary::is_dram(other), + tt::operations::primary::is_dram(output)); + + //////////////////////////////////////////////////////////////////////////// + // ComputeKernel SetUp + //////////////////////////////////////////////////////////////////////////// + std::map compute_defines; + + const auto compute_kernel_file = "ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/kernels/moreh_matmul.cpp"; + std::vector compute_args_group_1 = { + num_output_tiles_per_core_group_1, // num_output_tiles + Mt, + Nt, + Kt, + static_cast(transpose_input), + static_cast(transpose_other), + input_mask_h, + input_mask_w, + other_mask_h, + other_mask_w}; + + if (bias.has_value()) { + compute_defines["FUSE_BIAS"] = "1"; + compute_args_group_1.push_back(static_cast(is_scalar_bias)); + } + + bool preserve_fp32_precision = false; + if (fp32_dest_acc_en) { + compute_defines["FP32_DEST_ACC_EN"] = "1"; + preserve_fp32_precision = true; + } + + const auto compute_kernel_1_id = tt::operations::primary::CreateComputeKernel( + program, + compute_kernel_file, + {core_group_1, num_output_tiles_per_core_group_1, compute_args_group_1}, + compute_defines, + math_fidelity, + fp32_dest_acc_en, + math_approx_mode, + preserve_fp32_precision); + + std::optional compute_kernel_2_id = std::nullopt; + if (!core_group_2.ranges().empty()) { + std::vector compute_args_group_2 = { + num_output_tiles_per_core_group_2, // num_output_tiles + Mt, + Nt, + Kt, + static_cast(transpose_input), + static_cast(transpose_other), + input_mask_h, + input_mask_w, + other_mask_h, + other_mask_w}; + + if (bias.has_value()) { + compute_args_group_2.push_back(static_cast(is_scalar_bias)); + } + + compute_kernel_2_id = tt::operations::primary::CreateComputeKernel( + program, + compute_kernel_file, + {core_group_2, num_output_tiles_per_core_group_2, compute_args_group_2}, + compute_defines, + math_fidelity, + fp32_dest_acc_en, + math_approx_mode, + preserve_fp32_precision); + } + log_debug( + tt::LogOp, + "{}:{} Compute ", + __func__, + __LINE__, + static_cast(transpose_input), + static_cast(transpose_other)); + + //////////////////////////////////////////////////////////////////////////// + // RuntimeArgs SetUp + //////////////////////////////////////////////////////////////////////////// + for (uint32_t i = 0, num_tiles_written = 0; i < num_cores; i++) { + CoreCoord core = {i / num_cores_y, i % num_cores_y}; + uint32_t num_output_tiles_per_core; + if (core_group_1.core_coord_in_core_ranges(core)) { + num_output_tiles_per_core = num_output_tiles_per_core_group_1; + std::vector compute_rt_args; + compute_rt_args.push_back(num_tiles_written); + compute_rt_args.insert(compute_rt_args.end(), output_stride.begin(), output_stride.end()); + tt::tt_metal::SetRuntimeArgs(program, compute_kernel_1_id, core, compute_rt_args); + } else if (core_group_2.core_coord_in_core_ranges(core)) { + TT_FATAL(compute_kernel_2_id.has_value(), "Core not in specified core ranges"); + num_output_tiles_per_core = num_output_tiles_per_core_group_2; + std::vector compute_rt_args; + compute_rt_args.push_back(num_tiles_written); + compute_rt_args.insert(compute_rt_args.end(), output_stride.begin(), output_stride.end()); + tt::tt_metal::SetRuntimeArgs(program, compute_kernel_2_id.value(), core, compute_rt_args); + } else { + TT_THROW("Core not in specified core ranges"); + } + + std::vector reader_rt_args; + reader_rt_args.push_back(input.buffer()->address()); + reader_rt_args.push_back(other.buffer()->address()); + reader_rt_args.push_back(num_tiles_written); + reader_rt_args.push_back(num_output_tiles_per_core); + + // TODO: move some to compile args + reader_rt_args.insert(reader_rt_args.end(), input_stride.begin(), input_stride.end()); + reader_rt_args.insert(reader_rt_args.end(), other_stride.begin(), other_stride.end()); + reader_rt_args.insert(reader_rt_args.end(), output_stride.begin(), output_stride.end()); + reader_rt_args.insert(reader_rt_args.end(), input_not_bcast.begin(), input_not_bcast.end()); + reader_rt_args.insert(reader_rt_args.end(), other_not_bcast.begin(), other_not_bcast.end()); + + if (bias.has_value()) { + reader_rt_args.push_back(bias.value().buffer()->address()); + } + + tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_rt_args); + + tt::tt_metal::SetRuntimeArgs( + program, + writer_kernel_id, + core, + {output.buffer()->address(), num_tiles_written, num_output_tiles_per_core}); + num_tiles_written += num_output_tiles_per_core; + } + + auto override_runtime_args_callback = [reader_kernel_id, writer_kernel_id, num_cores, num_cores_y]( + const void *operation, + Program &program, + const std::vector &input_tensors, + const std::vector> &optional_input_tensors, + const std::vector &output_tensors) { + + }; + + return {std::move(program), {reader_kernel_id, writer_kernel_id, num_cores, num_cores_y}}; +} + +void MorehMatmulOperation::MultiCoreProgramFactory::override_runtime_arguments( + cached_program_t &cached_program, + const operation_attributes_t &operation_attributes, + const tensor_args_t &tensor_args, + tensor_return_value_t &tensor_return_value) { + auto &program = cached_program.program; + auto &reader_kernel_id = cached_program.shared_variables.reader_kernel_id; + auto &writer_kernel_id = cached_program.shared_variables.writer_kernel_id; + auto num_cores = cached_program.shared_variables.num_cores; + auto num_cores_y = cached_program.shared_variables.num_cores_y; + + auto bias = tensor_args.bias; + + log_debug(tt::LogOp, "{}:{} args_callback ", __func__, __LINE__); + const auto input_address = tensor_args.input.buffer()->address(); + const auto other_address = tensor_args.other.buffer()->address(); + const auto output_address = tensor_return_value.buffer()->address(); + + for (uint32_t i = 0, num_tiles_written = 0; i < num_cores; i++) { + CoreCoord core = {i / num_cores_y, i % num_cores_y}; + { + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + runtime_args[0] = input_address; + runtime_args[1] = other_address; + + if (bias.has_value()) { + const auto bias_address = bias.value().buffer()->address(); + runtime_args[runtime_args.size() - 1] = bias_address; + } + } + + { + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + runtime_args[0] = output_address; + } + } +} +} // namespace ttnn::operations::moreh::moreh_matmul diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/moreh_matmul.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/moreh_matmul.cpp new file mode 100644 index 00000000000..3b24592ac4e --- /dev/null +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/moreh_matmul.cpp @@ -0,0 +1,42 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "moreh_matmul.hpp" + +#include "tt_dnn/op_library/moreh_helper_functions.hpp" +#include "ttnn/operations/moreh/moreh_dot_op/moreh_dot.hpp" +#include "ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.hpp" + +namespace ttnn::operations::moreh::moreh_matmul { + +inline bool is_dot_forward(const Tensor& input, const Tensor& other, bool transpose_input, bool transpose_other) { + // TODO: non-4d support for dot. + if (input.get_legacy_shape().rank() != 4 || other.get_legacy_shape().rank() != 4) { + return false; + } + + if (transpose_input || transpose_other) { + return false; + } + + return tt::operations::primary::is_1d_tensor(input) && tt::operations::primary::is_1d_tensor(other) && + tt::operations::primary::is_same_shape(input, other); +} + +Tensor MorehMatmul::invoke( + const Tensor& input, + const Tensor& other, + bool transpose_input, + bool transpose_other, + const std::optional& output, + const std::optional bias, + const std::optional& output_mem_config, + const std::optional compute_kernel_config) { + if (is_dot_forward(input, other, transpose_input, transpose_other)) { + return ttnn::moreh_dot(input, other, input.get_dtype(), output_mem_config); + } + return ttnn::prim::moreh_matmul( + input, other, transpose_input, transpose_other, output, bias, output_mem_config, compute_kernel_config); +} +} // namespace ttnn::operations::moreh::moreh_matmul diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/moreh_matmul.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/moreh_matmul.hpp new file mode 100644 index 00000000000..5d1ecb5c23b --- /dev/null +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/moreh_matmul.hpp @@ -0,0 +1,25 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include "ttnn/decorators.hpp" +#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" +namespace ttnn::operations::moreh::moreh_matmul { +struct MorehMatmul { + static Tensor invoke( + const Tensor &input, + const Tensor &other, + bool transpose_input, + bool transpose_other, + const std::optional &output, + const std::optional bias, + const std::optional &output_mem_config, + const std::optional compute_kernel_config); +}; +} // namespace ttnn::operations::moreh::moreh_matmul + +namespace ttnn { +constexpr auto moreh_matmul = ttnn:: + register_operation_with_auto_launch_op<"ttnn::moreh_matmul", ttnn::operations::moreh::moreh_matmul::MorehMatmul>(); +} diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/moreh_matmul_pybind.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/moreh_matmul_pybind.cpp new file mode 100644 index 00000000000..8d77ec5096a --- /dev/null +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/moreh_matmul_pybind.cpp @@ -0,0 +1,29 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "moreh_matmul_pybind.hpp" + +#include "moreh_matmul.hpp" +#include "pybind11/cast.h" +#include "pybind11/decorators.hpp" +#include "ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.hpp" + +namespace ttnn::operations::moreh::moreh_matmul { +void bind_moreh_matmul_operation(py::module& module) { + bind_registered_operation( + module, + ttnn::moreh_matmul, + "Moreh moreh_matmul Operation", + ttnn::pybind_arguments_t{ + py::arg("input").noconvert(), + py::arg("other").noconvert(), + py::kw_only(), + py::arg("transpose_input").noconvert() = false, + py::arg("transpose_other").noconvert() = false, + py::arg("output").noconvert() = std::nullopt, + py::arg("bias").noconvert() = std::nullopt, + py::arg("memory_config").noconvert() = std::nullopt, + py::arg("compute_kernel_config").noconvert() = std::nullopt}); +} +} // namespace ttnn::operations::moreh::moreh_matmul diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/moreh_matmul_pybind.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/moreh_matmul_pybind.hpp new file mode 100644 index 00000000000..2a5a13de3fa --- /dev/null +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/moreh_matmul_pybind.hpp @@ -0,0 +1,14 @@ + +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "pybind11/pybind_fwd.hpp" + +namespace py = pybind11; + +namespace ttnn::operations::moreh::moreh_matmul { +void bind_moreh_matmul_operation(py::module& module); +} // namespace ttnn::operations::moreh::moreh_matmul diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_pybind.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_pybind.cpp index 01612e373f7..24265cde7a9 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_pybind.cpp @@ -9,6 +9,7 @@ #include "ttnn/operations/moreh/moreh_dot_op/moreh_dot_pybind.hpp" #include "ttnn/operations/moreh/moreh_dot_op_backward/moreh_dot_backward_pybind.hpp" #include "ttnn/operations/moreh/moreh_getitem/moreh_getitem_pybind.hpp" +#include "ttnn/operations/moreh/moreh_matmul/moreh_matmul_pybind.hpp" #include "ttnn/operations/moreh/moreh_mean/moreh_mean_pybind.hpp" #include "ttnn/operations/moreh/moreh_mean_backward/moreh_mean_backward_pybind.hpp" #include "ttnn/operations/moreh/moreh_nll_loss_backward/moreh_nll_loss_backward_pybind.hpp" @@ -27,5 +28,6 @@ void bind_moreh_operations(py::module &module) { moreh_dot_backward::bind_moreh_dot_backward_operation(module); moreh_nll_loss_unreduced_backward::bind_moreh_nll_loss_unreduced_backward_operation(module); moreh_nll_loss_backward::bind_moreh_nll_loss_backward_operation(module); + moreh_matmul::bind_moreh_matmul_operation(module); } } // namespace ttnn::operations::moreh diff --git a/ttnn/ttnn/operations/moreh.py b/ttnn/ttnn/operations/moreh.py index 7c8161ab9ac..4d5790af323 100644 --- a/ttnn/ttnn/operations/moreh.py +++ b/ttnn/ttnn/operations/moreh.py @@ -10,3 +10,4 @@ sum = ttnn._ttnn.operations.moreh.moreh_sum mean = ttnn._ttnn.operations.moreh.moreh_mean mean_backward = ttnn._ttnn.operations.moreh.moreh_mean_backward +matmul = ttnn._ttnn.operations.moreh.moreh_matmul