#13320: Impl uniform operation in ttnn (#13735)
* #13320: add draft code to test dropout op

#13320: rollback add_2_integers_in_compute changes

#13320: add draft code for uniform in programming_examples

#13320: check random numbers frequency using map

#13320: add draft code

#13320: try remove reader kernel and cb1

#13320: add skeleton code for uniform

#13220: support from, to param

#13320: remove reader kernel

#13320: add cb_intermed0

#13320: add unit-test

#13320: refactor writer kernel by exposing an sfpu API

#13320: remove native dir

#13320: refactor

#13320: update unit test

#13320: change compute_kernel_config param to const optional and add license

#13320: add cb id as compile args

#13320: raise exception when from = to

#13320: skip test for grayskull

#13320: revise uniform operation

#13320: update program factory

#13320: update unit-test for bfloat16 and refactor cb var name in kernels

#13320: add feature to benchmark ttnn with torch uniform

#13320: refactor

#13320: update compute_output_shapes func

* #13320: update unit-test

* #13320: update sfpu api

* #13320: remove un-used header in kernel and counter in sfpu API

* #13320: remove TTI_SFPSETSGN instr

* #13320: allow set fp32_dest_acc to false

* #13320: update callback test

* #13320: skip test for grayskull

* #13320: update sfpu to generate random float

* #13320: update sfpu to generate random float between range [from, to)

* #13320: refactor writer kernel

* #13320: update compute kernel

* #13320: update sfpu to use TTI_ only

* #13320: change cb format to float32

* #13320: update var name

* #13320: draft commit to support generate random bfloat16 in sfpu

* #13320: rollback to initial approach to handle bfloat16

* #13320: rebase with main

* #13320: rm unused header
BuiChiTrung authored Oct 24, 2024
1 parent fe9560b commit 002ca3a
Showing 15 changed files with 760 additions and 2 deletions.
157 changes: 157 additions & 0 deletions tests/ttnn/unit_tests/operations/test_uniform.py
@@ -0,0 +1,157 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import time
import pytest
import numpy as np
import ttnn
from collections import Counter
from loguru import logger
from tests.ttnn.unit_tests.operations.test_utils import (
    get_compute_kernel_options,
    compute_kernel_options,
    compute_kernel_ids,
    get_lib_dtype,
)
from models.utility_functions import skip_for_grayskull
from enum import Enum


class TestMode(Enum):
    VALIDATE = 0
    BENCHMARK = 1


def check_torch_uniform_bfloat16():
    input = torch.zeros(10, 10, dtype=torch.bfloat16).uniform_(2.1, 2.11)
    logger.info(input)

# With small tensors, ttnn might be slower than torch
def benchmark_uniform(cpu_input, npu_input, rand_from, rand_to):
    iter_num = 10

    cpu_total_time = 0
    for i in range(iter_num + 1):
        cpu_start_time = time.time_ns()
        cpu_input.uniform_(rand_from, rand_to)
        cpu_end_time = time.time_ns()
        # Skip the first iteration as warm-up
        if i > 0:
            cpu_total_time += cpu_end_time - cpu_start_time
    logger.info(f"CPU avg time: {cpu_total_time / iter_num}ns")

    npu_total_time = 0
    for i in range(iter_num + 1):
        npu_start_time = time.time_ns()
        ttnn.uniform(npu_input, rand_from, rand_to)
        npu_end_time = time.time_ns()
        if i > 0:
            npu_total_time += npu_end_time - npu_start_time
    logger.info(f"NPU avg time: {npu_total_time / iter_num}ns")


def validate_uniform(npu_input, shape, rand_from, rand_to, dtype, compute_kernel_config):
    ttnn.uniform(npu_input, rand_from, rand_to, compute_kernel_config=compute_kernel_config)
    tt_input = ttnn.to_torch(npu_input).reshape(shape)
    elem_cnt = Counter(tt_input.flatten().tolist())

    expected_mean, expected_var = (rand_from + rand_to) / 2, pow(rand_to - rand_from, 2) / 12
    npu_mean, npu_var = torch.mean(tt_input).item(), torch.var(tt_input).item()
    min_val, max_val = torch.min(tt_input).item(), torch.max(tt_input).item()

    logger.info(f"Distinct elements: {len(elem_cnt.keys())}")
    if max_val == rand_to:
        logger.info(f"Count max_val: {elem_cnt[max_val]}")
    logger.info(f"Min val: {min_val}, Max val: {max_val}")
    logger.info(f"Expected mean: {expected_mean}, NPU mean: {npu_mean}")
    logger.info(f"Expected var: {expected_var}, NPU var: {npu_var}")

    assert torch.tensor(rand_from, dtype=get_lib_dtype(torch, dtype)) <= min_val and max_val < torch.tensor(
        rand_to, dtype=get_lib_dtype(torch, dtype)
    )
    assert np.allclose(npu_mean, expected_mean, rtol=0.5)
    assert np.allclose(npu_var, expected_var, rtol=0.5)


def run_uniform(shape, rand_range, dtype, device, compute_kernel_options=None, mode=TestMode.VALIDATE):
    compute_kernel_config = get_compute_kernel_options(compute_kernel_options)
    rand_from, rand_to = rand_range[0], rand_range[1]
    cpu_input = torch.ones(shape, dtype=get_lib_dtype(torch, dtype))
    npu_input = ttnn.from_torch(cpu_input, device=device, dtype=get_lib_dtype(ttnn, dtype), layout=ttnn.TILE_LAYOUT)

    if mode == TestMode.BENCHMARK:
        benchmark_uniform(cpu_input=cpu_input, npu_input=npu_input, rand_from=rand_from, rand_to=rand_to)
    else:
        validate_uniform(
            npu_input=npu_input,
            shape=shape,
            rand_from=rand_from,
            rand_to=rand_to,
            dtype=dtype,
            compute_kernel_config=compute_kernel_config,
        )


# fmt: off
@skip_for_grayskull("Requires wormhole_b0 to run")
@pytest.mark.parametrize("shape",
    [
        [32, 32],
        [64, 64],
        [1, 512, 2, 256],
        [512, 512],
        [1024, 1024],
    ],
)
@pytest.mark.parametrize("rand_range",
    [
        [0, 1],
        [2.1, 9],
        [-5.1, 1.2]
    ]
)
@pytest.mark.parametrize("dtype",
    [
        "bfloat16",
        "float32"
    ]
)
# fmt: on
def test_uniform(shape, rand_range, dtype, device):
    torch.manual_seed(0)
    run_uniform(shape, rand_range, dtype, device)


@skip_for_grayskull("Requires wormhole_b0 to run")
@pytest.mark.parametrize(
    "shape",
    [[2, 32, 32, 16]],
)
@pytest.mark.parametrize("rand_range", [[-3, 4]])
@pytest.mark.parametrize("dtype", ["bfloat16", "float32"])
def test_uniform_callback(shape, rand_range, dtype, device, use_program_cache):
    torch.manual_seed(0)
    num_program_cache_entries_list = []
    for i in range(2):
        run_uniform(shape, rand_range, dtype, device)
        # Allocate a dummy tensor so the tensors created in the two iterations don't share the same address
        tt_dummy_tensor = ttnn.empty([1, 1, 32, 32], ttnn.bfloat16, ttnn.TILE_LAYOUT, device)
        num_program_cache_entries_list.append(device.num_program_cache_entries())
    logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
    assert num_program_cache_entries_list[0] > 0
    assert num_program_cache_entries_list[0] == num_program_cache_entries_list[1]


@skip_for_grayskull("Requires wormhole_b0 to run")
@pytest.mark.parametrize(
    "shape",
    [[512, 512], [5, 2, 4, 70, 40]],
)
@pytest.mark.parametrize("rand_range", [[0, 1]])
@pytest.mark.parametrize("dtype", ["bfloat16", "float32"])
@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids)
def test_uniform_with_compute_kernel_options(shape, rand_range, dtype, device, compute_kernel_options):
    torch.manual_seed(0)
    run_uniform(shape, rand_range, dtype, device, compute_kernel_options)
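
For a standalone reference, the call sequence reduces to the sketch below (a minimal sketch, assuming a wormhole_b0 board is attached at device id 0; the shape and range are illustrative):

import torch
import ttnn

device = ttnn.open_device(device_id=0)  # assumes an attached wormhole_b0 device

# ttnn.uniform overwrites the device tensor in place, so seed it first.
cpu_input = torch.ones(512, 512, dtype=torch.bfloat16)
npu_input = ttnn.from_torch(cpu_input, device=device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)

# Fill with samples drawn uniformly from [-5.1, 1.2).
ttnn.uniform(npu_input, -5.1, 1.2)
result = ttnn.to_torch(npu_input)

ttnn.close_device(device)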
@@ -0,0 +1,53 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0
#pragma once

#include "ckernel.h"
#include "ckernel_defs.h"

using namespace sfpi;

namespace ckernel::sfpu {

template <bool APPROXIMATION_MODE>
inline void rand_init(uint32_t seed) {
    init_prng_seed(seed);
}

template <bool APPROXIMATION_MODE>
inline void rand(uint32_t from, uint32_t scale) {
    // Load the scale param into lreg1
    TT_SFPLOADI(p_sfpu::LREG1, 10, scale & 0xFFFF);
    TT_SFPLOADI(p_sfpu::LREG1, 8, scale >> 16);

    // Load the from param into lreg2
    TT_SFPLOADI(p_sfpu::LREG2, 10, from & 0xFFFF);
    TT_SFPLOADI(p_sfpu::LREG2, 8, from >> 16);

#pragma GCC unroll 0
    for (int d = 0; d < 8; d++) {
        // Generate random bits into lreg0
        TTI_SFPMOV(0, 9, p_sfpu::LREG0, 8);

        // Clear the sign bit and set the exponent to 127 so the float falls in [1, 2):
        // lreg0 = {sign: 0, exponent: 127, mantissa: lreg0.mantissa}
        TTI_SFPSETSGN(0, p_sfpu::LREG0, p_sfpu::LREG0, 1);
        TTI_SFPSETEXP(127, p_sfpu::LREG0, p_sfpu::LREG0, 1);

        // Subtract 1 to shift the float into [0, 1).
        // lreg0 = lreg0 - 1
        TTI_SFPADDI(0xbf80 /*-1*/, p_sfpu::LREG0, 0);
        TTI_SFPNOP;

        // Scale the float from [0, 1) to [from, from + scale)
        // lreg0 = lreg0 * scale + from
        TTI_SFPMAD(p_sfpu::LREG0, p_sfpu::LREG1, p_sfpu::LREG2, p_sfpu::LREG0, 1);
        TTI_SFPNOP;

        // Write lreg0 back to the destination register and advance
        TTI_SFPSTORE(0, 3, 3, 0);
        dst_reg++;
    }
}
}  // namespace ckernel::sfpu
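
The instruction sequence above is the standard bit trick for turning raw PRNG output into a uniform float: force the sign to 0 and the biased exponent to 127 so the random mantissa yields a value in [1, 2), subtract 1 to land in [0, 1), then multiply-add to reach [from, from + scale). A host-side Python sketch of the same arithmetic (illustrative only; struct stands in for the SFPU's register reinterpretation, and bits_to_uniform is a made-up helper name):

import random
import struct

def bits_to_uniform(random_bits: int, rand_from: float, scale: float) -> float:
    # Keep the 23 mantissa bits; force sign = 0 and biased exponent = 127,
    # which reinterprets the bits as a float in [1, 2).
    word = (127 << 23) | (random_bits & 0x007FFFFF)
    as_float = struct.unpack("<f", struct.pack("<I", word))[0]
    # Shift to [0, 1), then scale to [rand_from, rand_from + scale).
    return (as_float - 1.0) * scale + rand_from

sample = bits_to_uniform(random.getrandbits(32), rand_from=-5.1, scale=6.3)
assert -5.1 <= sample < 1.2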
@@ -0,0 +1,27 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel_instr_params.h"
#include "ckernel_sfpu_rand.h"
#include "llk_math_eltwise_unary_sfpu_init.h"
#include "llk_math_eltwise_unary_sfpu_params.h"

namespace ckernel {

// New LLK SFPU APIs

template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_rand_init(uint32_t seed = 0) {
    llk_math_eltwise_unary_sfpu_init<SfpuType::unused, APPROXIMATE>(sfpu::rand_init<APPROXIMATE>, seed);
}

template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_rand(uint32_t dst_index, uint32_t from, uint32_t scale) {
    llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
        ckernel::sfpu::rand<APPROXIMATE>, dst_index, VectorMode::RC, from, scale);
}

} // namespace ckernel
43 changes: 43 additions & 0 deletions tt_metal/include/compute_kernel_api/eltwise_unary/rand.h
@@ -0,0 +1,43 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "compute_kernel_api/common_globals.h"
#ifdef TRISC_MATH
#include "llk_math_eltwise_unary_sfpu_rand.h"
#define MAIN math_main()
#define MATH(x) x
#else
#define MATH(x)
#endif

namespace ckernel {

/**
 * Performs an element-wise rand on each element of a tile in the DST register at index tile_index:
 * each element is overwritten with a randomly generated float.
 * The DST register buffer must be in the acquired state via an *acquire_dst* call.
 * This call is blocking and is only available on the compute engine.
 *
 * Return value: None
 *
 * | Argument   | Description                                                       | Type     | Valid Range                                            | Required |
 * |------------|-------------------------------------------------------------------|----------|--------------------------------------------------------|----------|
 * | tile_index | The index of the tile in the DST register buffer to fill with      | uint32_t | Must be less than the size of the DST register buffer  | True     |
 * |            | random values                                                      |          |                                                        |          |
 * | from       | Lower bound of the random range (inclusive)                        | uint     | Any number                                             | True     |
 * | scale      | Width of the random range: floats fall in [from, from + scale)     | uint     | Must be greater than 0                                 | True     |
 */
ALWI void rand_tile(uint32_t idst, uint32_t from, uint32_t scale) {
    MATH((llk_math_eltwise_unary_sfpu_rand<APPROX>(idst, from, scale)));
}

/**
 * Please refer to documentation for any_init.
 */
ALWI void rand_tile_init(uint32_t seed) { MATH((llk_math_eltwise_unary_sfpu_rand_init<APPROX>(seed))); }

}  // namespace ckernel
5 changes: 5 additions & 0 deletions ttnn/CMakeLists.txt
@@ -369,6 +369,11 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/clone/device/clone_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/clone/device/clone_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/sharding_utilities.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/uniform.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/uniform_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/device/uniform_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/device/uniform_program_factory.cpp

${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam_pybind.cpp
8 changes: 6 additions & 2 deletions ttnn/cpp/pybind11/operations/__init__.hpp
@@ -13,7 +13,6 @@
#include "ttnn/operations/ccl/all_gather/all_gather_pybind.hpp"
#include "ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.hpp"
#include "ttnn/operations/conv/conv2d/conv2d_pybind.hpp"
#include "ttnn/operations/sliding_window/sliding_window_pybind.hpp"
#include "ttnn/operations/data_movement/data_movement_pybind.hpp"
#include "ttnn/operations/eltwise/binary/binary_pybind.hpp"
#include "ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp"
@@ -27,9 +26,9 @@
#include "ttnn/operations/embedding/embedding_pybind.hpp"
#include "ttnn/operations/embedding_backward/embedding_backward_pybind.hpp"
#include "ttnn/operations/examples/examples_pybind.hpp"
#include "ttnn/operations/full_like/full_like_pybind.hpp"
#include "ttnn/operations/experimental/experimental_pybind.hpp"
#include "ttnn/operations/full/full_pybind.hpp"
#include "ttnn/operations/full_like/full_like_pybind.hpp"
#include "ttnn/operations/kv_cache/kv_cache_pybind.hpp"
#include "ttnn/operations/loss/loss_pybind.hpp"
#include "ttnn/operations/matmul/matmul_pybind.hpp"
@@ -40,7 +39,9 @@
#include "ttnn/operations/pool/maxpool/max_pool2d_pybind.hpp"
#include "ttnn/operations/pool/upsample/upsample_pybind.hpp"
#include "ttnn/operations/reduction/reduction_pybind.hpp"
#include "ttnn/operations/sliding_window/sliding_window_pybind.hpp"
#include "ttnn/operations/transformer/transformer_pybind.hpp"
#include "ttnn/operations/uniform/uniform_pybind.hpp"

namespace py = pybind11;

@@ -144,6 +145,9 @@ void py_module(py::module& module) {

auto m_full_like = module.def_submodule("full_like", "full_like operation");
full_like::bind_full_like_operation(m_full_like);

auto m_uniform = module.def_submodule("uniform", "uniform operations");
uniform::bind_uniform_operation(m_uniform);
}
} // namespace operations

@@ -0,0 +1,43 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "compute_kernel_api.h"
#include "compute_kernel_api/eltwise_unary/eltwise_unary.h"
#include "compute_kernel_api/eltwise_unary/rand.h"

namespace NAMESPACE {

void MAIN {
    constexpr uint32_t intermed_cb_id = get_compile_time_arg_val(0);

    const uint32_t seed = get_arg_val<uint32_t>(0);
    // from and to arrive bit-for-bit as uint32 runtime args; reinterpret them as floats to derive the scale.
    union {
        float f;
        uint32_t u;
    } f2u_from, f2u_to, f2u_scale;
    f2u_from.u = get_arg_val<uint32_t>(1);
    f2u_to.u = get_arg_val<uint32_t>(2);
    f2u_scale.f = f2u_to.f - f2u_from.f;
    const uint32_t start_id = get_arg_val<uint32_t>(3);
    const uint32_t num_tiles = get_arg_val<uint32_t>(4);
    const uint32_t end_id = start_id + num_tiles;

    init_sfpu(intermed_cb_id);

    rand_tile_init(seed);
    for (uint32_t i = start_id; i < end_id; ++i) {
        cb_reserve_back(intermed_cb_id, 1);

        tile_regs_acquire();
        rand_tile(0, f2u_from.u, f2u_scale.u);
        tile_regs_commit();

        tile_regs_wait();
        pack_tile(0, intermed_cb_id, 0);
        tile_regs_release();

        cb_push_back(intermed_cb_id, 1);
    }
}
}  // namespace NAMESPACE
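
Since kernel runtime args are raw uint32_t words, the host side must pack the from/to floats bit-for-bit into the arg list, which the union above then reverses. A Python sketch of that round trip (illustrative only; float_to_u32/u32_to_float are made-up helper names, and the arg layout mirrors the get_arg_val indices above):

import struct

def float_to_u32(value: float) -> int:
    # Bit-for-bit reinterpretation: the host-side mirror of the kernel's f2u union.
    return struct.unpack("<I", struct.pack("<f", value))[0]

def u32_to_float(word: int) -> float:
    return struct.unpack("<f", struct.pack("<I", word))[0]

seed, start_id, num_tiles = 42, 0, 16
rand_from, rand_to = -5.1, 1.2
runtime_args = [seed, float_to_u32(rand_from), float_to_u32(rand_to), start_id, num_tiles]

# The kernel recovers the floats and derives scale = to - from.
assert abs(u32_to_float(runtime_args[1]) - rand_from) < 1e-6
assert abs(u32_to_float(runtime_args[2]) - u32_to_float(runtime_args[1]) - 6.3) < 1e-5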