From 002ca3a127357c9d74176463d9207a6678a5339d Mon Sep 17 00:00:00 2001 From: Bui Chi Trung <52347285+BuiChiTrung@users.noreply.github.com> Date: Thu, 24 Oct 2024 15:19:43 +0700 Subject: [PATCH] #13320: Impl uniform operation in ttnn (#13735) * #13320: add draft code to test dropout op #13320: rollback add_2_integers_in_compute changes #13320: add draft code for uniform in programming_examples #13320: check random numbers frequency using map #13320: add draft code #13320: try remove reader kernel and cb1 #13320: add skeleton code for uniform #13220: support from, to param #13320: remove reader kernel #13320: add cb_intermed0 #13320: add unit-test #13320: refactor writer kernel by exposing an sfpu API #13320: remove native dir #13320: refactor #13320: update unit test #13320: change compute_kernel_config param to const optional and add license #13320: add cb id as compile args #13320: raise exception when from = to #13320: skip test for grayskull #13320: revise uniform operation #13320: update program factory #13320: update unit-test for bfloat16 and refactor cb var name in kernels #13320: add feature to benchmark ttnn with torch uniform #13320: refactor #13320: update compute_output_shapes func * #13320: update unit-test * #13320: update sfpu api * #13320: remove un-used header in kernel and counter in sfpu API * #13320: remove TTI_SFPSETSGN instr * #13320: allow set fp32_dest_acc to false * #13320: update callback test * #13320: skip test for grayskull * #13320: update sfpu to generate random float * #13320: update sfpu to generate random float between range [from, to) * #13320: refactor writer kernel * #13320: update compute kernel * #13320: update sfpu to use TTI_ only * #13320: change cb format to float32 * #13320: update var name * #13320: draft commit to support generate random bfloat16 in sfpu * #13320: rollback to inital approach to handle bfloat16 * #13320: rebase with main * #13320: rm unused header --- .../unit_tests/operations/test_uniform.py | 157 ++++++++++++++++++ .../llk_api/llk_sfpu/ckernel_sfpu_rand.h | 53 ++++++ .../llk_math_eltwise_unary_sfpu_rand.h | 27 +++ .../compute_kernel_api/eltwise_unary/rand.h | 43 +++++ ttnn/CMakeLists.txt | 5 + ttnn/cpp/pybind11/operations/__init__.hpp | 8 +- .../device/kernels/compute_uniform.cpp | 43 +++++ .../uniform/device/kernels/writer_uniform.cpp | 56 +++++++ .../device/uniform_device_operation.cpp | 62 +++++++ .../device/uniform_device_operation.hpp | 68 ++++++++ .../device/uniform_program_factory.cpp | 145 ++++++++++++++++ ttnn/cpp/ttnn/operations/uniform/uniform.cpp | 19 +++ ttnn/cpp/ttnn/operations/uniform/uniform.hpp | 23 +++ .../operations/uniform/uniform_pybind.cpp | 40 +++++ .../operations/uniform/uniform_pybind.hpp | 13 ++ 15 files changed, 760 insertions(+), 2 deletions(-) create mode 100644 tests/ttnn/unit_tests/operations/test_uniform.py create mode 100644 tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_rand.h create mode 100644 tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_rand.h create mode 100644 tt_metal/include/compute_kernel_api/eltwise_unary/rand.h create mode 100644 ttnn/cpp/ttnn/operations/uniform/device/kernels/compute_uniform.cpp create mode 100644 ttnn/cpp/ttnn/operations/uniform/device/kernels/writer_uniform.cpp create mode 100644 ttnn/cpp/ttnn/operations/uniform/device/uniform_device_operation.cpp create mode 100644 ttnn/cpp/ttnn/operations/uniform/device/uniform_device_operation.hpp create mode 100644 ttnn/cpp/ttnn/operations/uniform/device/uniform_program_factory.cpp create mode 100644 ttnn/cpp/ttnn/operations/uniform/uniform.cpp create mode 100644 ttnn/cpp/ttnn/operations/uniform/uniform.hpp create mode 100644 ttnn/cpp/ttnn/operations/uniform/uniform_pybind.cpp create mode 100644 ttnn/cpp/ttnn/operations/uniform/uniform_pybind.hpp diff --git a/tests/ttnn/unit_tests/operations/test_uniform.py b/tests/ttnn/unit_tests/operations/test_uniform.py new file mode 100644 index 00000000000..0ee59766878 --- /dev/null +++ b/tests/ttnn/unit_tests/operations/test_uniform.py @@ -0,0 +1,157 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import time +import pytest +import numpy as np +import ttnn +from collections import Counter +from loguru import logger +from tests.ttnn.unit_tests.operations.test_utils import ( + get_compute_kernel_options, + compute_kernel_options, + compute_kernel_ids, + get_lib_dtype, +) +from models.utility_functions import skip_for_grayskull +from enum import Enum + + +class TestMode(Enum): + VALIDATE = 0 + BENCHMARK = 1 + + +def check_torch_uniform_bfloat16(): + input = torch.zeros(10, 10, dtype=torch.bfloat16).uniform_(2.1, 2.11) + logger.info(input) + + +# With small tensor ttnn might be slower than torch +def benchmark_uniform(cpu_input, npu_input, rand_from, rand_to): + iter_num = 10 + + cpu_total_time = 0 + for i in range(iter_num + 1): + cpu_start_time = time.time_ns() + cpu_input.uniform_(rand_from, rand_to) + cpu_end_time = time.time_ns() + if i > 0: + cpu_total_time += cpu_end_time - cpu_start_time + logger.info(f"CPU avg time: {cpu_total_time / iter_num}ns") + + npu_total_time = 0 + for i in range(iter_num + 1): + npu_start_time = time.time_ns() + ttnn.uniform(npu_input, rand_from, rand_to) + npu_end_time = time.time_ns() + if i > 0: + npu_total_time += npu_end_time - npu_start_time + logger.info(f"NPU avg time: {npu_total_time / iter_num}ns") + + +def validate_uniform(npu_input, shape, rand_from, rand_to, dtype, compute_kernel_config): + ttnn.uniform(npu_input, rand_from, rand_to, compute_kernel_config=compute_kernel_config) + tt_input = ttnn.to_torch(npu_input).reshape(shape) + elem_cnt = Counter(tt_input.flatten().tolist()) + + expected_mean, expected_var = (rand_from + rand_to) / 2, pow(rand_to - rand_from, 2) / 12 + npu_mean, npu_var = torch.mean(tt_input).item(), torch.var(tt_input).item() + min_val, max_val = torch.min(tt_input).item(), torch.max(tt_input).item() + + logger.info(f"Distinct elements: {len(elem_cnt.keys())}") + if max_val == rand_to: + logger.info(f"Count max_val: {elem_cnt[max_val]}") + logger.info(f"Min val: {min_val}, Max val: {max_val}") + logger.info(f"Expected mean: {expected_mean}, NPU mean: {npu_mean}") + logger.info(f"Expected var: {expected_var}, NPU var: {npu_var}") + + assert torch.tensor(rand_from, dtype=get_lib_dtype(torch, dtype)) <= min_val and max_val < torch.tensor( + rand_to, dtype=get_lib_dtype(torch, dtype) + ) + assert np.allclose(npu_mean, expected_mean, rtol=0.5) + assert np.allclose(npu_var, expected_var, rtol=0.5) + + +def run_uniform(shape, rand_range, dtype, device, compute_kernel_options=None, mode=TestMode.VALIDATE): + compute_kernel_config = get_compute_kernel_options(compute_kernel_options) + rand_from, rand_to = rand_range[0], rand_range[1] + cpu_input = torch.ones(shape, dtype=get_lib_dtype(torch, dtype)) + npu_input = ttnn.from_torch(cpu_input, device=device, dtype=get_lib_dtype(ttnn, dtype), layout=ttnn.TILE_LAYOUT) + + if mode == TestMode.BENCHMARK: + benchmark_uniform(cpu_input=cpu_input, npu_input=npu_input, rand_from=rand_from, rand_to=rand_to) + else: + validate_uniform( + npu_input=npu_input, + shape=shape, + rand_from=rand_from, + rand_to=rand_to, + dtype=dtype, + compute_kernel_config=compute_kernel_config, + ) + + +# fmt: off +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize("shape", + [ + [32, 32], + [64, 64], + [1, 512, 2, 256], + [512, 512], + [1024, 1024], + ], +) +@pytest.mark.parametrize("rand_range", + [ + [0, 1], + [2.1, 9], + [-5.1, 1.2] + ] +) +@pytest.mark.parametrize("dtype", + [ + "bfloat16", + "float32" + ] +) +# fmt: on +def test_uniform(shape, rand_range, dtype, device): + torch.manual_seed(0) + run_uniform(shape, rand_range, dtype, device) + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "shape", + [[2, 32, 32, 16]], +) +@pytest.mark.parametrize("rand_range", [[-3, 4]]) +@pytest.mark.parametrize("dtype", ["bfloat16", "float32"]) +def test_uniform_callback(shape, rand_range, dtype, device, use_program_cache): + torch.manual_seed(0) + num_program_cache_entries_list = [] + for i in range(2): + run_uniform(shape, rand_range, dtype, device) + # Add dummy tensor to make sure that created tensor in 2 iteration don't share the same addr + tt_dummy_tensor = ttnn.empty([1, 1, 32, 32], ttnn.bfloat16, ttnn.TILE_LAYOUT, device) + num_program_cache_entries_list.append(device.num_program_cache_entries()) + logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}") + assert num_program_cache_entries_list[0] > 0 + assert num_program_cache_entries_list[0] == num_program_cache_entries_list[1] + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "shape", + [[512, 512], [5, 2, 4, 70, 40]], +) +@pytest.mark.parametrize("rand_range", [[0, 1]]) +@pytest.mark.parametrize("dtype", ["bfloat16", "float32"]) +@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) +def test_uniform_with_compute_kernel_options(shape, rand_range, dtype, device, compute_kernel_options): + torch.manual_seed(0) + run_uniform(shape, rand_range, dtype, device, compute_kernel_options) diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_rand.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_rand.h new file mode 100644 index 00000000000..fa93251e2f2 --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_rand.h @@ -0,0 +1,53 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 +#pragma once + +#include "ckernel.h" +#include "ckernel_defs.h" + +using namespace sfpi; + +namespace ckernel::sfpu { + +template +inline void rand_init(uint32_t seed) { + init_prng_seed(seed); +} + +template +inline void rand(uint32_t from, uint32_t scale) { + // Load scale param to lreg1 + TT_SFPLOADI(p_sfpu::LREG1, 10, scale & 0xFFFF); + TT_SFPLOADI(p_sfpu::LREG1, 8, scale >> 16); + + // Load from param to lreg2 + TT_SFPLOADI(p_sfpu::LREG2, 10, from & 0xFFFF); + TT_SFPLOADI(p_sfpu::LREG2, 8, from >> 16); + +#pragma GCC unroll 0 + for (int d = 0; d < 8; d++) { + // Generate random float + TTI_SFPMOV(0, 9, p_sfpu::LREG0, 8); + + // Unset sign bit and Set exponent to 127 to ensure the float is within the range [1, 2). + // lreg0.sign = 0 + // lreg0 = {sign: 0, exponent: 127, mantissa: lreg0.mantissa} + TTI_SFPSETSGN(0, p_sfpu::LREG0, p_sfpu::LREG0, 1); + TTI_SFPSETEXP(127, p_sfpu::LREG0, p_sfpu::LREG0, 1); + + // -1 to ensure the float is within the range [0, 1). + // lreg0 = lreg0 - 1 + TTI_SFPADDI(0xbf80 /*-1*/, p_sfpu::LREG0, 0); + TTI_SFPNOP; + + // Scale the float from [0, 1) to [from, from + scale) + // lreg0 = lreg0 * scale + from + TTI_SFPMAD(p_sfpu::LREG0, p_sfpu::LREG1, p_sfpu::LREG2, p_sfpu::LREG0, 1); + TTI_SFPNOP; + + TTI_SFPSTORE(0, 3, 3, 0); + dst_reg++; + } +} +} // namespace ckernel::sfpu diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_rand.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_rand.h new file mode 100644 index 00000000000..31a66fc820f --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_rand.h @@ -0,0 +1,27 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ckernel_instr_params.h" +#include "ckernel_sfpu_rand.h" +#include "llk_math_eltwise_unary_sfpu_init.h" +#include "llk_math_eltwise_unary_sfpu_params.h" + +namespace ckernel { + +// New LLK SFPU APIs + +template +inline void llk_math_eltwise_unary_sfpu_rand_init(uint32_t seed = 0) { + llk_math_eltwise_unary_sfpu_init(sfpu::rand_init, seed); +} + +template +inline void llk_math_eltwise_unary_sfpu_rand(uint32_t dst_index, uint32_t from, uint32_t scale) { + llk_math_eltwise_unary_sfpu_params( + ckernel::sfpu::rand, dst_index, VectorMode::RC, from, scale); +} + +} // namespace ckernel diff --git a/tt_metal/include/compute_kernel_api/eltwise_unary/rand.h b/tt_metal/include/compute_kernel_api/eltwise_unary/rand.h new file mode 100644 index 00000000000..dd20c6b2c2a --- /dev/null +++ b/tt_metal/include/compute_kernel_api/eltwise_unary/rand.h @@ -0,0 +1,43 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "compute_kernel_api/common_globals.h" +#ifdef TRISC_MATH +#include "llk_math_eltwise_unary_sfpu_rand.h" +#define MAIN math_main() +#define MATH(x) x +#else +#define MATH(x) +#endif + +namespace ckernel { + +/** + * Performs element-wise rand on each element of a of a tile in DST register at index tile_index. + * That is each element is overwritten with a randomly generated float. + * The DST register buffer must be in acquired state via *acquire_dst* call. + * This call is blocking and is only available on the compute engine. + * + * Return value: None + * + * | Argument | Description | Type | Valid + * Range | Required | + * |----------------|----------------------------------------------------------------------------|----------|-------------------------------------------------------|-----------| + * | tile_index | The index of the tile in DST register buffer to perform typecast operation | uint32_t | Must be + * less than the size of the DST register buffer | True | | from | Random range lowerbound(inclusive) | + * uint | Any number | True | | scale | Random scale rand + * float in range [from, from + scale] | uint | Must be greater than 0 | True | + */ +ALWI void rand_tile(uint32_t idst, uint32_t from, uint32_t scale) { + MATH((llk_math_eltwise_unary_sfpu_rand(idst, from, scale))); +} + +/** + * Please refer to documentation for any_init. + */ +ALWI void rand_tile_init(uint32_t seed) { MATH((llk_math_eltwise_unary_sfpu_rand_init(seed))); } + +} // namespace ckernel diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 38b2bd414b2..5f1d7d0ff1d 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -369,6 +369,11 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/clone/device/clone_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/clone/device/clone_device_operation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/sharding_utilities.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/uniform.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/uniform_pybind.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/device/uniform_device_operation.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/device/uniform_program_factory.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_device_operation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam_pybind.cpp diff --git a/ttnn/cpp/pybind11/operations/__init__.hpp b/ttnn/cpp/pybind11/operations/__init__.hpp index 361a6c71867..0083e077e7f 100644 --- a/ttnn/cpp/pybind11/operations/__init__.hpp +++ b/ttnn/cpp/pybind11/operations/__init__.hpp @@ -13,7 +13,6 @@ #include "ttnn/operations/ccl/all_gather/all_gather_pybind.hpp" #include "ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.hpp" #include "ttnn/operations/conv/conv2d/conv2d_pybind.hpp" -#include "ttnn/operations/sliding_window/sliding_window_pybind.hpp" #include "ttnn/operations/data_movement/data_movement_pybind.hpp" #include "ttnn/operations/eltwise/binary/binary_pybind.hpp" #include "ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp" @@ -27,9 +26,9 @@ #include "ttnn/operations/embedding/embedding_pybind.hpp" #include "ttnn/operations/embedding_backward/embedding_backward_pybind.hpp" #include "ttnn/operations/examples/examples_pybind.hpp" -#include "ttnn/operations/full_like/full_like_pybind.hpp" #include "ttnn/operations/experimental/experimental_pybind.hpp" #include "ttnn/operations/full/full_pybind.hpp" +#include "ttnn/operations/full_like/full_like_pybind.hpp" #include "ttnn/operations/kv_cache/kv_cache_pybind.hpp" #include "ttnn/operations/loss/loss_pybind.hpp" #include "ttnn/operations/matmul/matmul_pybind.hpp" @@ -40,7 +39,9 @@ #include "ttnn/operations/pool/maxpool/max_pool2d_pybind.hpp" #include "ttnn/operations/pool/upsample/upsample_pybind.hpp" #include "ttnn/operations/reduction/reduction_pybind.hpp" +#include "ttnn/operations/sliding_window/sliding_window_pybind.hpp" #include "ttnn/operations/transformer/transformer_pybind.hpp" +#include "ttnn/operations/uniform/uniform_pybind.hpp" namespace py = pybind11; @@ -144,6 +145,9 @@ void py_module(py::module& module) { auto m_full_like = module.def_submodule("full_like", "full_like operation"); full_like::bind_full_like_operation(m_full_like); + + auto m_uniform = module.def_submodule("uniform", "uniform operations"); + uniform::bind_uniform_operation(m_uniform); } } // namespace operations diff --git a/ttnn/cpp/ttnn/operations/uniform/device/kernels/compute_uniform.cpp b/ttnn/cpp/ttnn/operations/uniform/device/kernels/compute_uniform.cpp new file mode 100644 index 00000000000..fa2220e8edc --- /dev/null +++ b/ttnn/cpp/ttnn/operations/uniform/device/kernels/compute_uniform.cpp @@ -0,0 +1,43 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "compute_kernel_api.h" +#include "compute_kernel_api/eltwise_unary/eltwise_unary.h" +#include "compute_kernel_api/eltwise_unary/rand.h" + +namespace NAMESPACE { + +void MAIN { + constexpr uint32_t intermed_cb_id = get_compile_time_arg_val(0); + + const uint32_t seed = get_arg_val(0); + union { + float f; + uint32_t u; + } f2u_from, f2u_to, f2u_scale; + f2u_from.u = get_arg_val(1); + f2u_to.u = get_arg_val(2); + f2u_scale.f = f2u_to.f - f2u_from.f; + const uint32_t start_id = get_arg_val(3); + const uint32_t num_tiles = get_arg_val(4); + const uint32_t end_id = start_id + num_tiles; + + init_sfpu(intermed_cb_id); + + rand_tile_init(seed); + for (uint32_t i = start_id; i < end_id; ++i) { + cb_reserve_back(intermed_cb_id, 1); + + tile_regs_acquire(); + rand_tile(0, f2u_from.u, f2u_scale.u); + tile_regs_commit(); + + tile_regs_wait(); + pack_tile(0, intermed_cb_id, 0); + tile_regs_release(); + + cb_push_back(intermed_cb_id, 1); + } +} +} // namespace NAMESPACE diff --git a/ttnn/cpp/ttnn/operations/uniform/device/kernels/writer_uniform.cpp b/ttnn/cpp/ttnn/operations/uniform/device/kernels/writer_uniform.cpp new file mode 100644 index 00000000000..c13e90f96d1 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/uniform/device/kernels/writer_uniform.cpp @@ -0,0 +1,56 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "common/constants.hpp" +#include "dataflow_api.h" + +using namespace tt; + +void kernel_main() { + constexpr uint32_t intermed_cb_id = get_compile_time_arg_val(0); + constexpr uint32_t dst_cb_id = get_compile_time_arg_val(1); + constexpr bool output_is_dram = get_compile_time_arg_val(2) == 1; + + uint32_t dst_addr = get_arg_val(0); + uint32_t start_id = get_arg_val(1); + uint32_t num_tiles = get_arg_val(2); + uint32_t end_id = start_id + num_tiles; + + const InterleavedAddrGenFast output_addrg = { + .bank_base_address = dst_addr, .page_size = get_tile_size(dst_cb_id), .data_format = get_dataformat(dst_cb_id)}; + + cb_reserve_back(dst_cb_id, 1); + uint32_t dst_cb_write_ptr = get_write_ptr(dst_cb_id); + + for (uint32_t i = start_id; i < end_id; ++i) { + cb_wait_front(intermed_cb_id, 1); + + uint32_t intermed_cb_read_ptr = get_read_ptr(intermed_cb_id); + auto intermed_cb_addr = reinterpret_cast(intermed_cb_read_ptr); + +#ifdef OUTPUT_DTYPE_FLOAT32 + noc_async_write_tile(i, output_addrg, intermed_cb_read_ptr); + noc_async_write_barrier(); + cb_pop_front(intermed_cb_id, 1); +#endif + +#ifdef OUTPUT_DTYPE_BFLOAT16 + auto dst_cb_addr = reinterpret_cast(dst_cb_write_ptr); + for (uint32_t k = 0; k < constants::TILE_WIDTH; k++) { + for (uint32_t j = 0; j < constants::TILE_HEIGHT; j++) { + float rand_float = *intermed_cb_addr; + + uint16_t *uint16_ptr = reinterpret_cast(&rand_float) + 1; + *(uint16_t *)dst_cb_addr = *uint16_ptr; + dst_cb_addr += 2; + intermed_cb_addr += 1; + } + } + cb_pop_front(intermed_cb_id, 1); + + noc_async_write_tile(i, output_addrg, dst_cb_write_ptr); + noc_async_write_barrier(); +#endif + } +} diff --git a/ttnn/cpp/ttnn/operations/uniform/device/uniform_device_operation.cpp b/ttnn/cpp/ttnn/operations/uniform/device/uniform_device_operation.cpp new file mode 100644 index 00000000000..031af15dbb1 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/uniform/device/uniform_device_operation.cpp @@ -0,0 +1,62 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "uniform_device_operation.hpp" + +namespace ttnn::operations::uniform { + +UniformDeviceOperation::program_factory_t UniformDeviceOperation::select_program_factory( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + return ProgramFactory{}; +} + +void UniformDeviceOperation::validate_inputs( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + TT_FATAL(tensor_args.input.storage_type() == StorageType::DEVICE, "Uniform: Input tensor need to be on device"); + TT_FATAL(tensor_args.input.buffer() != nullptr, "Uniform: Input tensor need to be allocated in buffers on device"); + TT_FATAL((tensor_args.input.get_layout() == Layout::TILE), "Uniform: Input tensor must be tilized"); + TT_FATAL( + tensor_args.input.get_dtype() == DataType::BFLOAT16 || tensor_args.input.get_dtype() == DataType::FLOAT32, + "Uniform: Input tensor must be Float32 or Bfloat16"); + TT_FATAL(operation_attributes.from < operation_attributes.to, "Uniform: from param must be < to"); +} + +void UniformDeviceOperation::validate_on_program_cache_miss( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + validate_inputs(operation_attributes, tensor_args); +} + +void UniformDeviceOperation::validate_on_program_cache_hit( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + validate_inputs(operation_attributes, tensor_args); +} + +UniformDeviceOperation::shape_return_value_t UniformDeviceOperation::compute_output_shapes( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + return tensor_args.input.get_logical_shape(); +} + +UniformDeviceOperation::tensor_return_value_t UniformDeviceOperation::create_output_tensors( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + // Since this is an in-place operation, return the input tensor to be updated directly + return tensor_args.input; +} + +std::tuple +UniformDeviceOperation::invoke( + const Tensor& input, + const float from, + const float to, + const std::optional& memory_config, + const std::optional& compute_kernel_config) { + return { + operation_attributes_t{ + from, + to, + memory_config.value_or(input.memory_config()), + init_device_compute_kernel_config(input.device()->arch(), compute_kernel_config, MathFidelity::HiFi4)}, + tensor_args_t{input}}; +} + +} // namespace ttnn::operations::uniform diff --git a/ttnn/cpp/ttnn/operations/uniform/device/uniform_device_operation.hpp b/ttnn/cpp/ttnn/operations/uniform/device/uniform_device_operation.hpp new file mode 100644 index 00000000000..372c9dcc6ae --- /dev/null +++ b/ttnn/cpp/ttnn/operations/uniform/device/uniform_device_operation.hpp @@ -0,0 +1,68 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/decorators.hpp" +#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" + +namespace ttnn::operations::uniform { + +struct UniformDeviceOperation { + struct operation_attributes_t { + const float from; + const float to; + const MemoryConfig memory_config; + const DeviceComputeKernelConfig compute_kernel_config; + }; + + struct tensor_args_t { + const Tensor& input; + }; + + using shape_return_value_t = SimpleShape; + using tensor_return_value_t = Tensor; + + struct ProgramFactory { + struct shared_variables_t { + KernelHandle compute_kernel_id; + KernelHandle writer_kernel_id; + std::vector cores; + }; + + using cached_program_t = ttnn::device_operation::CachedProgram; + + static cached_program_t create( + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& output); + + static void override_runtime_arguments( + cached_program_t& cached_program, + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& output); + }; + + using program_factory_t = std::variant; + + static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&); + static void validate_inputs(const operation_attributes_t& attributes, const tensor_args_t& tensor_args); + static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&); + static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&); + static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&); + static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&); + + static std::tuple invoke( + const Tensor& input, + const float from, + const float to, + const std::optional& memory_config, + const std::optional& compute_kernel_config); +}; + +} // namespace ttnn::operations::uniform + +namespace ttnn::prim { +constexpr auto uniform = + ttnn::register_operation<"ttnn::prim::uniform", ttnn::operations::uniform::UniformDeviceOperation>(); +} // namespace ttnn::prim diff --git a/ttnn/cpp/ttnn/operations/uniform/device/uniform_program_factory.cpp b/ttnn/cpp/ttnn/operations/uniform/device/uniform_program_factory.cpp new file mode 100644 index 00000000000..3afc57437e3 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/uniform/device/uniform_program_factory.cpp @@ -0,0 +1,145 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 +#include "common/constants.hpp" +#include "tt_metal/common/work_split.hpp" +#include "ttnn/tensor/types.hpp" +#include "uniform_device_operation.hpp" + +namespace ttnn::operations::uniform { + +using namespace tt; +using namespace tt::tt_metal; + +std::mt19937 rng(std::time(nullptr)); +std::uniform_int_distribution distribution(1, std::numeric_limits::max()); + +auto get_random_seed() -> uint32_t { return distribution(rng); } + +UniformDeviceOperation::ProgramFactory::cached_program_t UniformDeviceOperation::ProgramFactory::create( + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& output) { + Device* device = output.device(); + auto grid = device->compute_with_storage_grid_size(); + auto core_h = grid.y; + + uint32_t units_to_divide = output.volume() / constants::TILE_HW; + auto [num_cores, all_cores, core_group_1, core_group_2, units_per_core_group_1, units_per_core_group_2] = + split_work_to_cores(grid, units_to_divide); + + uint32_t num_cores_x = grid.x; + uint32_t num_cores_y = grid.y; + auto cores = grid_to_cores(num_cores, num_cores_x, num_cores_y); + + Program program = Program(); + + DataType output_dtype = output.dtype(); + auto out_data_format = datatype_to_dataformat_converter(output_dtype); + const uint32_t dtype_tile_size = tile_size(out_data_format); + const uint32_t intermed_tile_size = tile_size(tt::DataFormat::Float32); + + constexpr uint32_t in_out_num_tiles = 1; + constexpr uint32_t intermed_num_tiles = 2; + + constexpr uint32_t intermed_cb_id = CB::c_intermed0; + CircularBufferConfig cb_intermed_config = + CircularBufferConfig(intermed_num_tiles * intermed_tile_size, {{intermed_cb_id, tt::DataFormat::Float32}}) + .set_page_size(intermed_cb_id, intermed_tile_size); + CBHandle cb_intermed = tt_metal::CreateCircularBuffer(program, all_cores, cb_intermed_config); + + constexpr uint32_t dst_cb_id = CB::c_in0; + CircularBufferConfig cb_output_config = + CircularBufferConfig(in_out_num_tiles * dtype_tile_size, {{dst_cb_id, out_data_format}}) + .set_page_size(dst_cb_id, dtype_tile_size); + CBHandle cb_output = tt_metal::CreateCircularBuffer(program, all_cores, cb_output_config); + + const std::string kernels_dir_path = "ttnn/cpp/ttnn/operations/uniform/device/kernels/"; + const uint32_t output_is_dram = output.buffer()->buffer_type() == BufferType::DRAM ? 1 : 0; + const std::vector writer_compile_time_args{intermed_cb_id, dst_cb_id, output_is_dram}; + const std::string writer_file_path = kernels_dir_path + "writer_uniform.cpp"; + const std::vector compute_compile_time_args{intermed_cb_id}; + const std::string compute_file_path = kernels_dir_path + "compute_uniform.cpp"; + + std::map writer_defines; + switch (output_dtype) { + case DataType::BFLOAT16: writer_defines["OUTPUT_DTYPE_BFLOAT16"] = "1"; break; + case DataType::FLOAT32: writer_defines["OUTPUT_DTYPE_FLOAT32"] = "1"; break; + default: break; + } + + KernelHandle writer_kernel_id = tt_metal::CreateKernel( + program, writer_file_path, all_cores, WriterDataMovementConfig(writer_compile_time_args, writer_defines)); + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = + get_compute_kernel_config_args(device->arch(), operation_attributes.compute_kernel_config); + KernelHandle compute_kernel_id = CreateKernel( + program, + compute_file_path, + all_cores, + ComputeConfig{ + .math_fidelity = math_fidelity, + .fp32_dest_acc_en = true, // if fp32_dest_acc_en set to false a precision error may occur which makes + // generated number out of range [from, to) + .dst_full_sync_en = dst_full_sync_en, + .math_approx_mode = math_approx_mode, + .compile_args = compute_compile_time_args, + }); + + uint32_t tile_offset = 0; + for (const auto& core : cores) { + uint32_t units_per_core; + if (core_group_1.core_coord_in_core_ranges(core)) { + units_per_core = units_per_core_group_1; + } else if (core_group_2.core_coord_in_core_ranges(core)) { + units_per_core = units_per_core_group_2; + } else { + TT_THROW("Core not in specified core ranges"); + } + + const float eps = 1e-6; + union { + float f; + uint32_t u; + } f2u_from, f2u_to; + f2u_from.f = operation_attributes.from; + f2u_to.f = operation_attributes.to - eps; // -eps make sure that generated number is < operation_attributes.to + std::vector compute_runtime_args = { + get_random_seed(), f2u_from.u, f2u_to.u, tile_offset, units_per_core}; + SetRuntimeArgs(program, compute_kernel_id, core, compute_runtime_args); + + std::vector writer_runtime_args = {output.buffer()->address(), tile_offset, units_per_core}; + SetRuntimeArgs(program, writer_kernel_id, core, writer_runtime_args); + + tile_offset += units_per_core; + } + + return { + std::move(program), + {.compute_kernel_id = compute_kernel_id, .writer_kernel_id = writer_kernel_id, .cores = cores}}; +} + +void UniformDeviceOperation::ProgramFactory::override_runtime_arguments( + cached_program_t& cached_program, + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& output) { + auto& program = cached_program.program; + auto& writer_kernel_id = cached_program.shared_variables.writer_kernel_id; + auto& compute_kernel_id = cached_program.shared_variables.compute_kernel_id; + auto& cores = cached_program.shared_variables.cores; + + const uint32_t output_addr = output.buffer()->address(); + + for (const auto& core : cores) { + { + auto& runtime_args = GetRuntimeArgs(program, compute_kernel_id, core); + runtime_args[0] = get_random_seed(); + } + { + auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + runtime_args[0] = output_addr; + } + } +} + +} // namespace ttnn::operations::uniform diff --git a/ttnn/cpp/ttnn/operations/uniform/uniform.cpp b/ttnn/cpp/ttnn/operations/uniform/uniform.cpp new file mode 100644 index 00000000000..755cac6a74c --- /dev/null +++ b/ttnn/cpp/ttnn/operations/uniform/uniform.cpp @@ -0,0 +1,19 @@ + +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "uniform.hpp" + +#include "device/uniform_device_operation.hpp" + +namespace ttnn::operations::uniform { +Tensor Uniform::invoke( + const Tensor &input, + const float from, + const float to, + const std::optional &memory_config, + const std::optional &compute_kernel_config) { + return ttnn::prim::uniform(input, from, to, memory_config, compute_kernel_config); +} +} // namespace ttnn::operations::uniform diff --git a/ttnn/cpp/ttnn/operations/uniform/uniform.hpp b/ttnn/cpp/ttnn/operations/uniform/uniform.hpp new file mode 100644 index 00000000000..6fbbb44059f --- /dev/null +++ b/ttnn/cpp/ttnn/operations/uniform/uniform.hpp @@ -0,0 +1,23 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include "ttnn/decorators.hpp" +#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" + +namespace ttnn::operations::uniform { +struct Uniform { + static Tensor invoke( + const Tensor& input, + const float from, + const float to, + const std::optional& memory_config, + const std::optional& compute_kernel_config); +}; +} // namespace ttnn::operations::uniform + +namespace ttnn { +constexpr auto uniform = + ttnn::register_operation_with_auto_launch_op<"ttnn::uniform", ttnn::operations::uniform::Uniform>(); +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/uniform/uniform_pybind.cpp b/ttnn/cpp/ttnn/operations/uniform/uniform_pybind.cpp new file mode 100644 index 00000000000..dc31fff1b31 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/uniform/uniform_pybind.cpp @@ -0,0 +1,40 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "uniform_pybind.hpp" + +#include "pybind11/decorators.hpp" +#include "uniform.hpp" + +namespace ttnn::operations::uniform { +void bind_uniform_operation(py::module &module) { + auto doc = + R"doc(uniform(input: Tensor, from: float = 0, to: float = 1, memory_config: Optional[MemoryConfig] = None, compute_kernel_config: Optional[ComputeKernelConfig] = None) -> Tensor + Generates a tensor with values drawn from a uniform distribution [`from`, `to`). The input tensor provides the shape for the output tensor, while the data type remains unchanged. + This operation allows configuration of memory allocation using `memory_config` and computation settings via `compute_kernel_config`. + + Args: + * :attr:`input`: The tensor that provides the shape for the generated uniform tensor. + * :attr:`from`: The lower bound of the uniform distribution. Defaults to 0. + * :attr:`to`: The upper bound of the uniform distribution. Defaults to 1. + * :attr:`memory_config`: The memory configuration for the generated tensor. + * :attr:`compute_kernel_config`: Optional configuration for the compute kernel used during generation. + + Returns: + Tensor: A new tensor with the same shape as `input` and values drawn from the specified uniform distribution. + )doc"; + + bind_registered_operation( + module, + ttnn::uniform, + doc, + ttnn::pybind_arguments_t{ + py::arg("input"), + py::arg("from") = 0, + py::arg("to") = 1, + py::kw_only(), + py::arg("memory_config") = std::nullopt, + py::arg("compute_kernel_config") = std::nullopt}); +} +} // namespace ttnn::operations::uniform diff --git a/ttnn/cpp/ttnn/operations/uniform/uniform_pybind.hpp b/ttnn/cpp/ttnn/operations/uniform/uniform_pybind.hpp new file mode 100644 index 00000000000..cf3ebe50614 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/uniform/uniform_pybind.hpp @@ -0,0 +1,13 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "pybind11/pybind_fwd.hpp" + +namespace py = pybind11; + +namespace ttnn::operations::uniform { +void bind_uniform_operation(py::module &module); +} // namespace ttnn::operations::uniform