-
Notifications
You must be signed in to change notification settings - Fork 74
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
* #13320: add draft code to test dropout op #13320: rollback add_2_integers_in_compute changes #13320: add draft code for uniform in programming_examples #13320: check random numbers frequency using map #13320: add draft code #13320: try remove reader kernel and cb1 #13320: add skeleton code for uniform #13220: support from, to param #13320: remove reader kernel #13320: add cb_intermed0 #13320: add unit-test #13320: refactor writer kernel by exposing an sfpu API #13320: remove native dir #13320: refactor #13320: update unit test #13320: change compute_kernel_config param to const optional and add license #13320: add cb id as compile args #13320: raise exception when from = to #13320: skip test for grayskull #13320: revise uniform operation #13320: update program factory #13320: update unit-test for bfloat16 and refactor cb var name in kernels #13320: add feature to benchmark ttnn with torch uniform #13320: refactor #13320: update compute_output_shapes func * #13320: update unit-test * #13320: update sfpu api * #13320: remove un-used header in kernel and counter in sfpu API * #13320: remove TTI_SFPSETSGN instr * #13320: allow set fp32_dest_acc to false * #13320: update callback test * #13320: skip test for grayskull * #13320: update sfpu to generate random float * #13320: update sfpu to generate random float between range [from, to) * #13320: refactor writer kernel * #13320: update compute kernel * #13320: update sfpu to use TTI_ only * #13320: change cb format to float32 * #13320: update var name * #13320: draft commit to support generate random bfloat16 in sfpu * #13320: rollback to inital approach to handle bfloat16 * #13320: rebase with main * #13320: rm unused header
- Loading branch information
1 parent
fe9560b
commit 002ca3a
Showing
15 changed files
with
760 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import torch | ||
import time | ||
import pytest | ||
import numpy as np | ||
import ttnn | ||
from collections import Counter | ||
from loguru import logger | ||
from tests.ttnn.unit_tests.operations.test_utils import ( | ||
get_compute_kernel_options, | ||
compute_kernel_options, | ||
compute_kernel_ids, | ||
get_lib_dtype, | ||
) | ||
from models.utility_functions import skip_for_grayskull | ||
from enum import Enum | ||
|
||
|
||
class TestMode(Enum): | ||
VALIDATE = 0 | ||
BENCHMARK = 1 | ||
|
||
|
||
def check_torch_uniform_bfloat16(): | ||
input = torch.zeros(10, 10, dtype=torch.bfloat16).uniform_(2.1, 2.11) | ||
logger.info(input) | ||
|
||
|
||
# With small tensor ttnn might be slower than torch | ||
def benchmark_uniform(cpu_input, npu_input, rand_from, rand_to): | ||
iter_num = 10 | ||
|
||
cpu_total_time = 0 | ||
for i in range(iter_num + 1): | ||
cpu_start_time = time.time_ns() | ||
cpu_input.uniform_(rand_from, rand_to) | ||
cpu_end_time = time.time_ns() | ||
if i > 0: | ||
cpu_total_time += cpu_end_time - cpu_start_time | ||
logger.info(f"CPU avg time: {cpu_total_time / iter_num}ns") | ||
|
||
npu_total_time = 0 | ||
for i in range(iter_num + 1): | ||
npu_start_time = time.time_ns() | ||
ttnn.uniform(npu_input, rand_from, rand_to) | ||
npu_end_time = time.time_ns() | ||
if i > 0: | ||
npu_total_time += npu_end_time - npu_start_time | ||
logger.info(f"NPU avg time: {npu_total_time / iter_num}ns") | ||
|
||
|
||
def validate_uniform(npu_input, shape, rand_from, rand_to, dtype, compute_kernel_config): | ||
ttnn.uniform(npu_input, rand_from, rand_to, compute_kernel_config=compute_kernel_config) | ||
tt_input = ttnn.to_torch(npu_input).reshape(shape) | ||
elem_cnt = Counter(tt_input.flatten().tolist()) | ||
|
||
expected_mean, expected_var = (rand_from + rand_to) / 2, pow(rand_to - rand_from, 2) / 12 | ||
npu_mean, npu_var = torch.mean(tt_input).item(), torch.var(tt_input).item() | ||
min_val, max_val = torch.min(tt_input).item(), torch.max(tt_input).item() | ||
|
||
logger.info(f"Distinct elements: {len(elem_cnt.keys())}") | ||
if max_val == rand_to: | ||
logger.info(f"Count max_val: {elem_cnt[max_val]}") | ||
logger.info(f"Min val: {min_val}, Max val: {max_val}") | ||
logger.info(f"Expected mean: {expected_mean}, NPU mean: {npu_mean}") | ||
logger.info(f"Expected var: {expected_var}, NPU var: {npu_var}") | ||
|
||
assert torch.tensor(rand_from, dtype=get_lib_dtype(torch, dtype)) <= min_val and max_val < torch.tensor( | ||
rand_to, dtype=get_lib_dtype(torch, dtype) | ||
) | ||
assert np.allclose(npu_mean, expected_mean, rtol=0.5) | ||
assert np.allclose(npu_var, expected_var, rtol=0.5) | ||
|
||
|
||
def run_uniform(shape, rand_range, dtype, device, compute_kernel_options=None, mode=TestMode.VALIDATE): | ||
compute_kernel_config = get_compute_kernel_options(compute_kernel_options) | ||
rand_from, rand_to = rand_range[0], rand_range[1] | ||
cpu_input = torch.ones(shape, dtype=get_lib_dtype(torch, dtype)) | ||
npu_input = ttnn.from_torch(cpu_input, device=device, dtype=get_lib_dtype(ttnn, dtype), layout=ttnn.TILE_LAYOUT) | ||
|
||
if mode == TestMode.BENCHMARK: | ||
benchmark_uniform(cpu_input=cpu_input, npu_input=npu_input, rand_from=rand_from, rand_to=rand_to) | ||
else: | ||
validate_uniform( | ||
npu_input=npu_input, | ||
shape=shape, | ||
rand_from=rand_from, | ||
rand_to=rand_to, | ||
dtype=dtype, | ||
compute_kernel_config=compute_kernel_config, | ||
) | ||
|
||
|
||
# fmt: off | ||
@skip_for_grayskull("Requires wormhole_b0 to run") | ||
@pytest.mark.parametrize("shape", | ||
[ | ||
[32, 32], | ||
[64, 64], | ||
[1, 512, 2, 256], | ||
[512, 512], | ||
[1024, 1024], | ||
], | ||
) | ||
@pytest.mark.parametrize("rand_range", | ||
[ | ||
[0, 1], | ||
[2.1, 9], | ||
[-5.1, 1.2] | ||
] | ||
) | ||
@pytest.mark.parametrize("dtype", | ||
[ | ||
"bfloat16", | ||
"float32" | ||
] | ||
) | ||
# fmt: on | ||
def test_uniform(shape, rand_range, dtype, device): | ||
torch.manual_seed(0) | ||
run_uniform(shape, rand_range, dtype, device) | ||
|
||
|
||
@skip_for_grayskull("Requires wormhole_b0 to run") | ||
@pytest.mark.parametrize( | ||
"shape", | ||
[[2, 32, 32, 16]], | ||
) | ||
@pytest.mark.parametrize("rand_range", [[-3, 4]]) | ||
@pytest.mark.parametrize("dtype", ["bfloat16", "float32"]) | ||
def test_uniform_callback(shape, rand_range, dtype, device, use_program_cache): | ||
torch.manual_seed(0) | ||
num_program_cache_entries_list = [] | ||
for i in range(2): | ||
run_uniform(shape, rand_range, dtype, device) | ||
# Add dummy tensor to make sure that created tensor in 2 iteration don't share the same addr | ||
tt_dummy_tensor = ttnn.empty([1, 1, 32, 32], ttnn.bfloat16, ttnn.TILE_LAYOUT, device) | ||
num_program_cache_entries_list.append(device.num_program_cache_entries()) | ||
logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}") | ||
assert num_program_cache_entries_list[0] > 0 | ||
assert num_program_cache_entries_list[0] == num_program_cache_entries_list[1] | ||
|
||
|
||
@skip_for_grayskull("Requires wormhole_b0 to run") | ||
@pytest.mark.parametrize( | ||
"shape", | ||
[[512, 512], [5, 2, 4, 70, 40]], | ||
) | ||
@pytest.mark.parametrize("rand_range", [[0, 1]]) | ||
@pytest.mark.parametrize("dtype", ["bfloat16", "float32"]) | ||
@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) | ||
def test_uniform_with_compute_kernel_options(shape, rand_range, dtype, device, compute_kernel_options): | ||
torch.manual_seed(0) | ||
run_uniform(shape, rand_range, dtype, device, compute_kernel_options) |
53 changes: 53 additions & 0 deletions
53
tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_rand.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
#pragma once | ||
|
||
#include "ckernel.h" | ||
#include "ckernel_defs.h" | ||
|
||
using namespace sfpi; | ||
|
||
namespace ckernel::sfpu { | ||
|
||
template <bool APPROXIMATION_MODE> | ||
inline void rand_init(uint32_t seed) { | ||
init_prng_seed(seed); | ||
} | ||
|
||
template <bool APPROXIMATION_MODE> | ||
inline void rand(uint32_t from, uint32_t scale) { | ||
// Load scale param to lreg1 | ||
TT_SFPLOADI(p_sfpu::LREG1, 10, scale & 0xFFFF); | ||
TT_SFPLOADI(p_sfpu::LREG1, 8, scale >> 16); | ||
|
||
// Load from param to lreg2 | ||
TT_SFPLOADI(p_sfpu::LREG2, 10, from & 0xFFFF); | ||
TT_SFPLOADI(p_sfpu::LREG2, 8, from >> 16); | ||
|
||
#pragma GCC unroll 0 | ||
for (int d = 0; d < 8; d++) { | ||
// Generate random float | ||
TTI_SFPMOV(0, 9, p_sfpu::LREG0, 8); | ||
|
||
// Unset sign bit and Set exponent to 127 to ensure the float is within the range [1, 2). | ||
// lreg0.sign = 0 | ||
// lreg0 = {sign: 0, exponent: 127, mantissa: lreg0.mantissa} | ||
TTI_SFPSETSGN(0, p_sfpu::LREG0, p_sfpu::LREG0, 1); | ||
TTI_SFPSETEXP(127, p_sfpu::LREG0, p_sfpu::LREG0, 1); | ||
|
||
// -1 to ensure the float is within the range [0, 1). | ||
// lreg0 = lreg0 - 1 | ||
TTI_SFPADDI(0xbf80 /*-1*/, p_sfpu::LREG0, 0); | ||
TTI_SFPNOP; | ||
|
||
// Scale the float from [0, 1) to [from, from + scale) | ||
// lreg0 = lreg0 * scale + from | ||
TTI_SFPMAD(p_sfpu::LREG0, p_sfpu::LREG1, p_sfpu::LREG2, p_sfpu::LREG0, 1); | ||
TTI_SFPNOP; | ||
|
||
TTI_SFPSTORE(0, 3, 3, 0); | ||
dst_reg++; | ||
} | ||
} | ||
} // namespace ckernel::sfpu |
27 changes: 27 additions & 0 deletions
27
tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_rand.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include "ckernel_instr_params.h" | ||
#include "ckernel_sfpu_rand.h" | ||
#include "llk_math_eltwise_unary_sfpu_init.h" | ||
#include "llk_math_eltwise_unary_sfpu_params.h" | ||
|
||
namespace ckernel { | ||
|
||
// New LLK SFPU APIs | ||
|
||
template <bool APPROXIMATE> | ||
inline void llk_math_eltwise_unary_sfpu_rand_init(uint32_t seed = 0) { | ||
llk_math_eltwise_unary_sfpu_init<SfpuType::unused, APPROXIMATE>(sfpu::rand_init<APPROXIMATE>, seed); | ||
} | ||
|
||
template <bool APPROXIMATE> | ||
inline void llk_math_eltwise_unary_sfpu_rand(uint32_t dst_index, uint32_t from, uint32_t scale) { | ||
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>( | ||
ckernel::sfpu::rand<APPROXIMATE>, dst_index, VectorMode::RC, from, scale); | ||
} | ||
|
||
} // namespace ckernel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include "compute_kernel_api/common_globals.h" | ||
#ifdef TRISC_MATH | ||
#include "llk_math_eltwise_unary_sfpu_rand.h" | ||
#define MAIN math_main() | ||
#define MATH(x) x | ||
#else | ||
#define MATH(x) | ||
#endif | ||
|
||
namespace ckernel { | ||
|
||
/** | ||
* Performs element-wise rand on each element of a of a tile in DST register at index tile_index. | ||
* That is each element is overwritten with a randomly generated float. | ||
* The DST register buffer must be in acquired state via *acquire_dst* call. | ||
* This call is blocking and is only available on the compute engine. | ||
* | ||
* Return value: None | ||
* | ||
* | Argument | Description | Type | Valid | ||
* Range | Required | | ||
* |----------------|----------------------------------------------------------------------------|----------|-------------------------------------------------------|-----------| | ||
* | tile_index | The index of the tile in DST register buffer to perform typecast operation | uint32_t | Must be | ||
* less than the size of the DST register buffer | True | | from | Random range lowerbound(inclusive) | | ||
* uint | Any number | True | | scale | Random scale rand | ||
* float in range [from, from + scale] | uint | Must be greater than 0 | True | | ||
*/ | ||
ALWI void rand_tile(uint32_t idst, uint32_t from, uint32_t scale) { | ||
MATH((llk_math_eltwise_unary_sfpu_rand<APPROX>(idst, from, scale))); | ||
} | ||
|
||
/** | ||
* Please refer to documentation for any_init. | ||
*/ | ||
ALWI void rand_tile_init(uint32_t seed) { MATH((llk_math_eltwise_unary_sfpu_rand_init<APPROX>(seed))); } | ||
|
||
} // namespace ckernel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
43 changes: 43 additions & 0 deletions
43
ttnn/cpp/ttnn/operations/uniform/device/kernels/compute_uniform.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "compute_kernel_api.h" | ||
#include "compute_kernel_api/eltwise_unary/eltwise_unary.h" | ||
#include "compute_kernel_api/eltwise_unary/rand.h" | ||
|
||
namespace NAMESPACE { | ||
|
||
void MAIN { | ||
constexpr uint32_t intermed_cb_id = get_compile_time_arg_val(0); | ||
|
||
const uint32_t seed = get_arg_val<uint32_t>(0); | ||
union { | ||
float f; | ||
uint32_t u; | ||
} f2u_from, f2u_to, f2u_scale; | ||
f2u_from.u = get_arg_val<uint32_t>(1); | ||
f2u_to.u = get_arg_val<uint32_t>(2); | ||
f2u_scale.f = f2u_to.f - f2u_from.f; | ||
const uint32_t start_id = get_arg_val<uint32_t>(3); | ||
const uint32_t num_tiles = get_arg_val<uint32_t>(4); | ||
const uint32_t end_id = start_id + num_tiles; | ||
|
||
init_sfpu(intermed_cb_id); | ||
|
||
rand_tile_init(seed); | ||
for (uint32_t i = start_id; i < end_id; ++i) { | ||
cb_reserve_back(intermed_cb_id, 1); | ||
|
||
tile_regs_acquire(); | ||
rand_tile(0, f2u_from.u, f2u_scale.u); | ||
tile_regs_commit(); | ||
|
||
tile_regs_wait(); | ||
pack_tile(0, intermed_cb_id, 0); | ||
tile_regs_release(); | ||
|
||
cb_push_back(intermed_cb_id, 1); | ||
} | ||
} | ||
} // namespace NAMESPACE |
Oops, something went wrong.