diff --git a/tests/ttnn/unit_tests/gtests/CMakeLists.txt b/tests/ttnn/unit_tests/gtests/CMakeLists.txt
index 6bf76117543..865b29daefc 100644
--- a/tests/ttnn/unit_tests/gtests/CMakeLists.txt
+++ b/tests/ttnn/unit_tests/gtests/CMakeLists.txt
@@ -8,7 +8,10 @@ set(TTNN_UNIT_TESTS_SRC
     ${CMAKE_CURRENT_SOURCE_DIR}/test_to_and_from_json.cpp
 )
 
-set(TTNN_CCL_UNIT_TESTS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/ccl/test_erisc_data_mover_with_workers.cpp)
+set(TTNN_CCL_UNIT_TESTS_SRC
+    ${CMAKE_CURRENT_SOURCE_DIR}/ccl/test_erisc_data_mover_with_workers.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp
+)
 
 set(TTNN_TENSOR_UNIT_TESTS_SRC
     ${CMAKE_CURRENT_SOURCE_DIR}/tensor/common_tensor_test_utils.cpp
diff --git a/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_sender_worker_reader.cpp b/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_sender_worker_reader.cpp
index 41d453e2793..66662d02630 100644
--- a/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_sender_worker_reader.cpp
+++ b/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_sender_worker_reader.cpp
@@ -38,7 +38,6 @@ void kernel_main() {
         }
         noc_async_read_barrier();
         cb_push_back(cb_id_in0, pages_to_read);
-        // DPRINT << "SR " << num_pages_read << "\n";
     }
 
     DPRINT << "SR DONE\n";
diff --git a/tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_erisc_datamover_sender_worker_reader.cpp b/tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_erisc_datamover_sender_worker_reader.cpp
new file mode 100644
index 00000000000..3437c819346
--- /dev/null
+++ b/tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_erisc_datamover_sender_worker_reader.cpp
@@ -0,0 +1,46 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include <cstdint>
+
+#include "dataflow_api.h"
+#include "debug/dprint.h"
+#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp"
+
+void kernel_main() {
+    constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1;
+    constexpr uint32_t num_pages_to_read_total = get_compile_time_arg_val(1);
+    constexpr uint32_t page_size = get_compile_time_arg_val(2);
+    constexpr uint32_t pages_per_edm_buffer = 1;
+    constexpr uint32_t cb_id_in0 = tt::CB::c_in0;
+
+    const uint32_t src_addr = get_arg_val<uint32_t>(0);
+
+    const InterleavedAddrGen<src_is_dram> source_address_generator = {
+        .bank_base_address = src_addr, .page_size = page_size};
+
+    DPRINT << "swr: args " <<
+        "\n\tsrc_addr=" << src_addr << "\n";
+
+    for (uint32_t num_pages_read = 0; num_pages_read < num_pages_to_read_total;
+         num_pages_read += pages_per_edm_buffer) {
+        uint32_t pages_to_read = std::min<uint32_t>(pages_per_edm_buffer, num_pages_to_read_total - num_pages_read);
+        cb_reserve_back(cb_id_in0, pages_to_read);
+        uint32_t local_l1_read_addr = get_write_ptr(cb_id_in0);
+        // Leave headroom at the front of the CB slot so the sender kernel can
+        // build the fabric packet header in place, ahead of the payload.
+        local_l1_read_addr += sizeof(tt::fabric::PacketHeader);
+
+        for (uint32_t p = 0; p < pages_to_read; ++p) {
+            uint64_t src_noc_addr = get_noc_addr(num_pages_read + p, source_address_generator);
+            noc_async_read(src_noc_addr, local_l1_read_addr, page_size);
+            local_l1_read_addr += page_size;
+        }
+        noc_async_read_barrier();
+        cb_push_back(cb_id_in0, pages_to_read);
+    }
+}
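The reader above deliberately lands each page one packet-header's worth past the start of its CB slot, so the downstream writer can fill the header in without copying the payload. A minimal standalone sketch of the resulting slot layout (`PACKET_HEADER_BYTES` is an assumed stand-in for `sizeof(tt::fabric::PacketHeader)`, whose real value comes from `fabric_edm_packet_header.hpp`):

```cpp
#include <cstdint>
#include <cstdio>

constexpr uint32_t PACKET_HEADER_BYTES = 32;  // assumption for the sketch
constexpr uint32_t PAGE_SIZE = 2048;

int main() {
    uint32_t cb_slot_base = 0x10000;  // what get_write_ptr(cb_id) would return
    uint32_t payload_base = cb_slot_base + PACKET_HEADER_BYTES;  // where the reader writes pages
    // Slot layout:  [ header (reserved, filled later by sender) | payload (PAGE_SIZE bytes) ]
    std::printf("header @ 0x%X, payload @ 0x%X, slot bytes = %u\n",
        cb_slot_base, payload_base, PACKET_HEADER_BYTES + PAGE_SIZE);
}
```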
diff --git a/tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_erisc_datamover_sender_worker_sender.cpp b/tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_erisc_datamover_sender_worker_sender.cpp
new file mode 100644
index 00000000000..e0cb2f50a17
--- /dev/null
+++ b/tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_erisc_datamover_sender_worker_sender.cpp
@@ -0,0 +1,196 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include <cstdint>
+
+#include "dataflow_api.h"
+#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp"
+#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp"
+
+struct unicast_mode {
+    uint8_t distance;
+};
+struct mcast_mode {
+    uint8_t distance;
+    uint8_t range;
+};
+
+union transmit_config {
+    unicast_mode unicast;
+    mcast_mode mcast;
+};
+
+// Worker core - Data Movement Writer -> sends to the erisc data mover (sender side).
+// Takes input from the local CB and pushes it to erisc L1.
+void kernel_main() {
+    // The test doesn't support multiple pages per send yet: we write to an
+    // interleaved buffer, which never places subsequent pages on the same core,
+    // so pages can't share a packet header.
+    constexpr uint32_t num_pages_per_send = 1;  // get_compile_time_arg_val(0);
+    constexpr uint32_t total_pages_to_send = get_compile_time_arg_val(1);
+    constexpr uint32_t page_size = get_compile_time_arg_val(2);
+    constexpr uint32_t num_buffers_per_channel = get_compile_time_arg_val(3);
+    constexpr bool dest_is_dram = get_compile_time_arg_val(4) != 0;
+    constexpr bool mcast_mode = get_compile_time_arg_val(5) == 1;
+
+    size_t arg_idx = 0;
+    const uint32_t eth_l1_base_addr = get_arg_val<uint32_t>(arg_idx++);
+    // erisc L1 semaphore address
+    const uint32_t eth_sender_l1_sem_addr = get_arg_val<uint32_t>(arg_idx++);
+    volatile uint32_t* const writer_send_sem_addr =
+        reinterpret_cast<volatile uint32_t*>(get_semaphore(get_arg_val<uint32_t>(arg_idx++)));
+    const uint32_t eth_sender_noc_x = get_arg_val<uint32_t>(arg_idx++);
+    const uint32_t eth_sender_noc_y = get_arg_val<uint32_t>(arg_idx++);
+    const uint32_t num_buffers_per_edm_channel = get_arg_val<uint32_t>(arg_idx++);
+
+    size_t edm_connection_handshake_addr = get_semaphore(get_arg_val<uint32_t>(arg_idx++));
+    size_t edm_worker_location_info_addr = get_arg_val<uint32_t>(arg_idx++);
+    size_t edm_buffer_size_bytes = get_arg_val<uint32_t>(arg_idx++);
+    size_t dest_addr = get_arg_val<uint32_t>(arg_idx++);
+    volatile uint32_t* const last_message_semaphore_address =
+        reinterpret_cast<volatile uint32_t*>(get_semaphore(get_arg_val<uint32_t>(arg_idx++)));
+    *last_message_semaphore_address = 0;
+    auto worker_buffer_index_semaphore_addr = get_semaphore(get_arg_val<uint32_t>(arg_idx++));
+    ASSERT(worker_buffer_index_semaphore_addr != reinterpret_cast<size_t>(writer_send_sem_addr));
+    ASSERT(worker_buffer_index_semaphore_addr != reinterpret_cast<size_t>(last_message_semaphore_address));
+
+    transmit_config config;
+    if (mcast_mode) {
+        config.mcast.distance = static_cast<uint8_t>(get_arg_val<uint32_t>(arg_idx++));
+        config.mcast.range = static_cast<uint8_t>(get_arg_val<uint32_t>(arg_idx++));
+    } else {
+        config.unicast.distance = static_cast<uint8_t>(get_arg_val<uint32_t>(arg_idx++));
+    }
+
+    const InterleavedAddrGen<dest_is_dram> dest_addr_gen = {
+        .bank_base_address = dest_addr, .page_size = page_size};
+
+    ASSERT(num_buffers_per_channel > 0);
+    auto sender = tt::fabric::WorkerToFabricEdmSender(
+        eth_sender_noc_x,
+        eth_sender_noc_y,
+        eth_l1_base_addr,
+        num_buffers_per_channel,
+        eth_sender_l1_sem_addr,
+        edm_connection_handshake_addr,
+        edm_worker_location_info_addr,
+        edm_buffer_size_bytes,
+        writer_send_sem_addr,
+        worker_buffer_index_semaphore_addr);
+
+    sender.open();
+
+    constexpr uint32_t cb_id_in0 = tt::CB::c_in0;
+
+    // We need to normalize all NoC addresses to a consistent NoC ID so the
+    // remote sender core can correctly send the packet. In the future we can
+    // decide whether it's better to embed the NoC index in the packet header
+    // (for now we don't).
+    constexpr size_t NORMALIZED_NOC_INDEX = 0;
+
+    cb_wait_front(cb_id_in0, 1);
+    auto a_packet_header_addr = get_read_ptr(cb_id_in0);
+    for (uint32_t p = 0; p < total_pages_to_send; p += num_pages_per_send) {
+        uint32_t pages_to_send = std::min<uint32_t>(num_pages_per_send, total_pages_to_send - p);
+        sender.wait_for_empty_write_slot();
+        cb_wait_front(cb_id_in0, pages_to_send);
+
+        // Bit of a hack to extract X/Y from the interleaved NoC address
+        const auto dest_noc_address = get_noc_addr(p, dest_addr_gen, 0, NORMALIZED_NOC_INDEX);
+        const size_t dest_addr = dest_noc_address & 0xFFFFFFFF;
+        const size_t dest_noc_x = (dest_noc_address >> NOC_ADDR_LOCAL_BITS) & ((1 << NOC_ADDR_NODE_ID_BITS) - 1);
+        const size_t dest_noc_y =
+            (dest_noc_address >> (NOC_ADDR_LOCAL_BITS + NOC_ADDR_NODE_ID_BITS)) & ((1 << NOC_ADDR_NODE_ID_BITS) - 1);
+        const size_t packet_size = page_size + sizeof(tt::fabric::PacketHeader);
+
+        auto packet_addr = get_read_ptr(cb_id_in0);
+        auto& packet_header = *reinterpret_cast<tt::fabric::PacketHeader*>(packet_addr);
+        if constexpr (mcast_mode) {
+            packet_header.to_write()
+                .to_chip_multicast(
+                    tt::fabric::MulticastRoutingCommandHeader{config.mcast.distance, config.mcast.range})
+                .to_noc_unicast(tt::fabric::NocUnicastCommandHeader{
+                    dest_addr,
+                    (pages_to_send * page_size) + sizeof(tt::fabric::PacketHeader),
+                    static_cast<uint8_t>(dest_noc_x),
+                    static_cast<uint8_t>(dest_noc_y)});
+            packet_header.reserved2 = 0x1111;  // debug only
+        } else {
+            packet_header.to_write()
+                .to_chip_unicast(tt::fabric::UnicastRoutingCommandHeader{config.unicast.distance})
+                .to_noc_unicast(tt::fabric::NocUnicastCommandHeader{
+                    dest_addr,
+                    (pages_to_send * page_size) + sizeof(tt::fabric::PacketHeader),
+                    static_cast<uint8_t>(dest_noc_x),
+                    static_cast<uint8_t>(dest_noc_y)});
+            packet_header.reserved2 = 0x1111;  // debug only
+        }
+
+        sender.send_payload_blocking_from_address(packet_addr, packet_size);
+        noc_async_writes_flushed();
+        cb_pop_front(cb_id_in0, pages_to_send);
+    }
+
+    if constexpr (!mcast_mode) {
+        sender.wait_for_empty_write_slot();
+
+        auto& packet_header = *reinterpret_cast<tt::fabric::PacketHeader*>(a_packet_header_addr);
+        ASSERT(*last_message_semaphore_address == 0);
+        packet_header.reserved = 0xE;
+        packet_header.reserved2 = 0xFFFF;
+        packet_header.to_atomic_inc();
+        packet_header.to_chip_unicast(tt::fabric::UnicastRoutingCommandHeader{1});
+        packet_header.to_noc_unicast_atomic_inc(tt::fabric::NocUnicastAtomicIncCommandHeader(
+            reinterpret_cast<size_t>(last_message_semaphore_address),
+            1,
+            32,
+            my_x[0],
+            my_y[0]));
+
+        sender.send_payload_blocking_from_address(
+            a_packet_header_addr, packet_header.get_payload_size_including_header());
+
+        noc_semaphore_wait(last_message_semaphore_address, 1);
+    }
+
+    bool closed = false;
+    size_t num_endpoints_to_terminate = get_arg_val<uint32_t>(arg_idx++);
+    for (size_t i = 0; i < num_endpoints_to_terminate; i++) {
+        size_t edm_noc_x = get_arg_val<uint32_t>(arg_idx++);
+        size_t edm_noc_y = get_arg_val<uint32_t>(arg_idx++);
+        size_t distance = get_arg_val<uint32_t>(arg_idx++);
+        size_t termination_addr = get_arg_val<uint32_t>(arg_idx++);
+
+        if (!closed && distance == 0) {
+            closed = true;
+            sender.close();
+        }
+        if (distance == 0) {
+            noc_inline_dw_write(
+                get_noc_addr(edm_noc_x, edm_noc_y, termination_addr),
+                tt::fabric::TerminationSignal::IMMEDIATELY_TERMINATE);
+        } else {
+            auto& packet_header = *reinterpret_cast<tt::fabric::PacketHeader*>(a_packet_header_addr);
+            reinterpret_cast<volatile uint32_t*>(a_packet_header_addr)[sizeof(tt::fabric::PacketHeader) >> 2] =
+                tt::fabric::TerminationSignal::IMMEDIATELY_TERMINATE;
+            sender.wait_for_empty_write_slot();
+            packet_header.to_write()
+                .to_chip_unicast(tt::fabric::UnicastRoutingCommandHeader{static_cast<uint8_t>(distance - 1)})
+                .to_noc_unicast(tt::fabric::NocUnicastCommandHeader{
+                    termination_addr,
+                    sizeof(tt::fabric::PacketHeader) + sizeof(uint32_t),
+                    static_cast<uint8_t>(edm_noc_x),
+                    static_cast<uint8_t>(edm_noc_y)});
+            sender.send_payload_blocking_from_address(
+                a_packet_header_addr, packet_header.get_payload_size_including_header());
+            noc_async_writes_flushed();
+        }
+    }
+    if (!closed) {
+        sender.close();
+    }
+}
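The "hack to extract X/Y" in the sender kernel relies on how a 64-bit NoC address packs `{y, x, local address}`. A worked, self-contained example of that decomposition (the bit widths here are assumptions for illustration, Wormhole-style 36 local-address bits and 6-bit node IDs; the kernel itself uses the `NOC_ADDR_*` constants from the HAL):

```cpp
#include <cstdint>
#include <cstdio>

constexpr uint64_t NOC_ADDR_LOCAL_BITS = 36;    // assumption for the sketch
constexpr uint64_t NOC_ADDR_NODE_ID_BITS = 6;   // assumption for the sketch

int main() {
    // Build a NoC address the same way the hardware packs it, then split it.
    uint64_t x = 3, y = 5, local = 0x1000;
    uint64_t noc_addr = (y << (NOC_ADDR_LOCAL_BITS + NOC_ADDR_NODE_ID_BITS)) |
                        (x << NOC_ADDR_LOCAL_BITS) | local;

    uint64_t dest_addr = noc_addr & 0xFFFFFFFF;  // low 32 bits: L1/DRAM address
    uint64_t dest_x = (noc_addr >> NOC_ADDR_LOCAL_BITS) & ((1ull << NOC_ADDR_NODE_ID_BITS) - 1);
    uint64_t dest_y = (noc_addr >> (NOC_ADDR_LOCAL_BITS + NOC_ADDR_NODE_ID_BITS)) &
                      ((1ull << NOC_ADDR_NODE_ID_BITS) - 1);
    std::printf("x=%llu y=%llu addr=0x%llx\n", (unsigned long long)dest_x,
        (unsigned long long)dest_y, (unsigned long long)dest_addr);
}
```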
diff --git a/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp b/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp
new file mode 100644
index 00000000000..1cb446d470e
--- /dev/null
+++ b/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp
@@ -0,0 +1,949 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include <algorithm>
+#include <limits>
+#include <map>
+#include <numeric>
+#include <optional>
+#include <thread>
+#include <variant>
+
+#include "device/tt_arch_types.h"
+#include "gtest/gtest.h"
+#include "tt_metal/common/core_coord.hpp"
+#include "tt_metal/common/math.hpp"
+#include "tt_metal/detail/tt_metal.hpp"
+#include "tt_metal/host_api.hpp"
+#include "tt_metal/impl/kernels/kernel.hpp"
+#include "tt_metal/test_utils/comparison.hpp"
+#include "tt_metal/test_utils/df/df.hpp"
+#include "tt_metal/test_utils/env_vars.hpp"
+#include "tt_metal/test_utils/print_helpers.hpp"
+#include "tt_metal/test_utils/stimulus.hpp"
+#include "ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp"
+#include "ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp"
+#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp"
+
+using namespace tt;
+using namespace tt::test_utils;
+using namespace tt::test_utils::df;
+
+class T3000TestDevice {
+   public:
+    T3000TestDevice() : device_open(false) {
+        arch_ = tt::get_arch_from_string(tt::test_utils::get_umd_arch_name());
+
+        num_devices_ = tt::tt_metal::GetNumAvailableDevices();
+        if (arch_ == tt::ARCH::WORMHOLE_B0 and tt::tt_metal::GetNumAvailableDevices() >= 4 and
+            tt::tt_metal::GetNumPCIeDevices() >= 1) {
+            std::vector<chip_id_t> ids(num_devices_, 0);
+            std::iota(ids.begin(), ids.end(), 0);
+            devices_ = tt::tt_metal::detail::CreateDevices(ids);
+        } else {
+            TT_THROW("This suite can only be run on T3000 Wormhole devices");
+        }
+        device_open = true;
+    }
+    ~T3000TestDevice() {
+        if (device_open) {
+            TearDown();
+        }
+    }
+
+    void TearDown() {
+        device_open = false;
+        for (auto [device_id, device_ptr] : devices_) {
+            tt::tt_metal::CloseDevice(device_ptr);
+        }
+    }
+
+    std::map<chip_id_t, Device*> devices_;
+    tt::ARCH arch_;
+    size_t num_devices_;
+
+   private:
+    bool device_open;
+};
+
+struct BankedConfig {
+    size_t num_pages;
+    size_t size_bytes;
+    size_t page_size_bytes;
+    BufferType input_buffer_type;   // = BufferType::L1;
+    BufferType output_buffer_type;  // = BufferType::L1;
+    tt::DataFormat l1_data_format;  // = tt::DataFormat::Float16_b;
+};
+
+struct KernelXY {
+    uint16_t x;
+    uint16_t y;
+
+    uint32_t to_uint32() const { return y << 16 | x; }
+};
+
+struct edm_termination_info_t {
+    uint32_t distance;
+    uint32_t edm_noc_x;
+    uint32_t edm_noc_y;
+    uint32_t termination_addr;
+};
+
+enum Correctness { Correct, Incorrect };
+
+struct EthLinkBuilder {
+    ttnn::ccl::FabricEriscDatamoverBuilder sender_edm_builder;
+    ttnn::ccl::FabricEriscDatamoverBuilder receiver_edm_builder;
+    tt_xy_pair sender_core;
+    tt_xy_pair receiver_core;
+    size_t downstream_edm_buffer_index_semaphore_id;
+};
+
+Correctness run_output_check(
+    std::vector<uint32_t> const& all_zeros,
+    std::vector<uint32_t> const& inputs,
+    std::shared_ptr<Buffer> output_buffer) {
+    constexpr bool debug_mode = true;
+    std::vector<uint32_t> readback_data_vec(all_zeros.size(), 0);  // init to 0 data for easier debug
+
+    tt_metal::detail::ReadFromBuffer(output_buffer, readback_data_vec);
+    log_info(tt::LogTest, "Checking outputs");
+    if (readback_data_vec.size() != inputs.size()) {
+        log_error(tt::LogTest, "Output size mismatch: expected {} got {}", inputs.size(), readback_data_vec.size());
+        return Correctness::Incorrect;
+    }
+    bool pass = (readback_data_vec == inputs);
+    TT_ASSERT(
+        std::any_of(inputs.begin(), inputs.end(), [](uint32_t x) { return x != 0; }),
+        "Input buffer expected to not be all 0");
+    if (not pass) {
+        log_error(tt::LogTest, "Output mismatch");
+        if (debug_mode) {
+            std::size_t num_printed_mismatches = 0;
+            for (size_t i = 0; i < readback_data_vec.size() && num_printed_mismatches < 64; i++) {
+                if (readback_data_vec[i] != inputs[i]) {
+                    log_error(tt::LogTest, "[{}]: expected {} got {}", i, inputs[i], readback_data_vec[i]);
+                    num_printed_mismatches++;
+                }
+            }
+            log_error(tt::LogTest, "... (remaining mismatches omitted)");
+        }
+        return Correctness::Incorrect;
+    }
+    return Correctness::Correct;
+}
+
+void run_programs(std::vector<Program>& programs, std::vector<Device*> const& devices) {
+    EXPECT_EQ(programs.size(), devices.size());
+    const size_t num_programs = programs.size();
+    try {
+        for (size_t i = 0; i < num_programs; i++) {
+            tt::tt_metal::detail::CompileProgram(devices.at(i), programs.at(i));
+        }
+    } catch (std::exception& e) {
+        log_error(tt::LogTest, "Failed compile: {}", e.what());
+        throw;
+    }
+
+    log_info(tt::LogTest, "Running...");
+
+    std::vector<std::thread> threads;
+    threads.reserve(num_programs);
+    if (std::getenv("TT_METAL_SLOW_DISPATCH_MODE")) {
+        for (size_t i = 0; i < num_programs; i++) {
+            // Capture i by value: capturing the loop variable by reference races with the loop.
+            threads.emplace_back(
+                std::thread([&devices, &programs, i] { tt_metal::detail::LaunchProgram(devices.at(i), programs.at(i)); }));
+        }
+
+        std::ranges::for_each(threads, [](std::thread& t) { t.join(); });
+    } else {
+        for (size_t i = 0; i < num_programs; i++) {
+            tt_metal::EnqueueProgram(devices.at(i)->command_queue(), programs.at(i), false);
+        }
+
+        log_debug(tt::LogTest, "Calling Finish");
+        for (size_t i = 0; i < num_programs; i++) {
+            tt_metal::Finish(devices.at(i)->command_queue());
+        }
+    }
+}
+
+std::tuple<std::shared_ptr<Buffer>, std::vector<uint32_t>> build_input_buffer(
+    Device* first_device, size_t tensor_size_bytes, BankedConfig const& test_config) {
+    auto inputs = std::vector<uint32_t>(tensor_size_bytes / sizeof(uint32_t), 0);
+    std::iota(inputs.begin(), inputs.end(), 0);
+
+    // Input buffer
+    auto local_input_buffer = CreateBuffer(InterleavedBufferConfig{
+        first_device, test_config.size_bytes, test_config.page_size_bytes, test_config.input_buffer_type});
+    tt_metal::detail::WriteToBuffer(local_input_buffer, inputs);
+    return {local_input_buffer, inputs};
+}
+
+struct EthLinkHop {
+    CoreCoord hop_src;
+    CoreCoord hop_dest;
+};
+
+struct ChipConnection {
+    std::vector<EthLinkHop> links;
+};
+
+struct unicast_send {
+    size_t distance;
+};
+struct mcast_send {
+    size_t distance;
+    size_t range;
+};
+
+using mode_variant_t = std::variant<mcast_send, unicast_send>;
+
+static constexpr size_t PACKET_HEADER_SIZE_BYTES = sizeof(tt::fabric::PacketHeader);
+
+void generate_sender_worker_kernels(
+    Program& program,
+    Device* device,
+    CoreCoord const& worker_core,
+    CoreCoord const& edm_core,
+    ttnn::ccl::SenderWorkerAdapterSpec const& worker_fabric_connection,
+    mode_variant_t const& mode,
+    std::size_t edm_buffer_size,
+    uint32_t page_plus_header_size,
+    uint32_t num_pages_total,
+    uint32_t num_pages_per_edm_buffer,
+    uint32_t local_worker_fabric_semaphore_id,
+    uint32_t local_worker_last_message_semaphore_id,
+    uint32_t dram_input_buffer_base_addr,
+    bool src_is_dram,
+    uint32_t dram_output_buffer_base_addr,
+    bool dest_is_dram,
+    uint32_t worker_buffer_index_semaphore_id,
+    // Ordered from farthest EDM to closest
+    std::vector<edm_termination_info_t> const& edm_termination_infos) {
+    std::vector<uint32_t> sender_worker_reader_compile_args{
+        src_is_dram,
+        num_pages_total,
+        page_plus_header_size - PACKET_HEADER_SIZE_BYTES,
+        num_pages_per_edm_buffer};
+    std::vector<uint32_t> sender_worker_reader_runtime_args{dram_input_buffer_base_addr};
+
+    log_info(tt::LogTest, "\tSenderReader CT Args");
+    for (auto const& arg : sender_worker_reader_compile_args) {
+        log_info(tt::LogTest, "\t\t{}", arg);
+    }
+    log_info(tt::LogTest, "\tSenderReader RT Args");
+    for (auto const& arg : sender_worker_reader_runtime_args) {
+        log_info(tt::LogTest, "\t\t{}", arg);
+    }
+
+    std::vector<uint32_t> sender_worker_writer_compile_args{
+        num_pages_per_edm_buffer,
+        num_pages_total,
+        page_plus_header_size - PACKET_HEADER_SIZE_BYTES,
+        worker_fabric_connection.num_buffers_per_channel,
+        dest_is_dram,
+        std::holds_alternative<mcast_send>(mode) ? 1 : 0};
+    log_info(tt::LogTest, "worker_fabric_connection.edm_l1_sem_addr: {}", worker_fabric_connection.edm_l1_sem_addr);
+    log_info(tt::LogTest, "worker_buffer_index_semaphore_id: {}", worker_buffer_index_semaphore_id);
+    log_info(tt::LogTest, "last_message_semaphore_address: {}", local_worker_last_message_semaphore_id);
+    log_info(
+        tt::LogTest,
+        "Sender communicating with EDM: x={}, y={}",
+        (uint32_t)device->ethernet_core_from_logical_core(edm_core).x,
+        (uint32_t)device->ethernet_core_from_logical_core(edm_core).y);
+    std::vector<uint32_t> sender_worker_writer_runtime_args{
+        worker_fabric_connection.edm_buffer_base_addr,
+        worker_fabric_connection.edm_l1_sem_addr,
+        local_worker_fabric_semaphore_id,
+        (uint32_t)device->ethernet_core_from_logical_core(edm_core).x,
+        (uint32_t)device->ethernet_core_from_logical_core(edm_core).y,
+        worker_fabric_connection.num_buffers_per_channel,
+
+        worker_fabric_connection.edm_connection_handshake_addr,
+        worker_fabric_connection.edm_worker_location_info_addr,
+        edm_buffer_size,
+        dram_output_buffer_base_addr,
+        local_worker_last_message_semaphore_id,
+        worker_buffer_index_semaphore_id};
+
+    if (std::holds_alternative<mcast_send>(mode)) {
+        sender_worker_writer_runtime_args.push_back(std::get<mcast_send>(mode).distance);
+        sender_worker_writer_runtime_args.push_back(std::get<mcast_send>(mode).range);
+    } else {
+        sender_worker_writer_runtime_args.push_back(std::get<unicast_send>(mode).distance);
+    }
+
+    sender_worker_writer_runtime_args.push_back(edm_termination_infos.size());
+    for (auto const& info : edm_termination_infos) {
+        sender_worker_writer_runtime_args.push_back(info.edm_noc_x);
+        sender_worker_writer_runtime_args.push_back(info.edm_noc_y);
+        sender_worker_writer_runtime_args.push_back(info.distance);
+        sender_worker_writer_runtime_args.push_back(info.termination_addr);
+        log_info(
+            tt::LogTest,
+            "EDM termination info: x={}, y={}, distance={}, termination_addr={}",
+            info.edm_noc_x,
+            info.edm_noc_y,
+            info.distance,
+            info.termination_addr);
+    }
+
+    uint32_t src0_cb_index = CB::c_in0;
+    log_info(tt::LogTest, "\tSenderWriter CT Args");
+    for (auto const& arg : sender_worker_writer_compile_args) {
+        log_info(tt::LogTest, "\t\t{}", arg);
+    }
+    log_info(tt::LogTest, "\tSenderWriter RT Args");
+    for (auto const& arg : sender_worker_writer_runtime_args) {
+        log_info(tt::LogTest, "\t\t{}", arg);
+    }
+
+    // Just want a dummy DF
+    tt::DataFormat df = (page_plus_header_size - PACKET_HEADER_SIZE_BYTES) == 1024   ? tt::DataFormat::Bfp8
+                        : (page_plus_header_size - PACKET_HEADER_SIZE_BYTES) == 2048 ? tt::DataFormat::Float16
+                                                                                     : tt::DataFormat::Float32;
+    tt_metal::CircularBufferConfig cb_src0_config =
+        tt_metal::CircularBufferConfig(2 * num_pages_per_edm_buffer * page_plus_header_size, {{src0_cb_index, df}})
+            .set_page_size(src0_cb_index, page_plus_header_size);
+    CBHandle sender_workers_cb = CreateCircularBuffer(program, worker_core, cb_src0_config);
+    auto sender_worker_reader_kernel = tt_metal::CreateKernel(
+        program,
+        "tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_erisc_datamover_sender_worker_reader.cpp",
+        worker_core,
+        tt_metal::DataMovementConfig{
+            .processor = tt_metal::DataMovementProcessor::RISCV_0,
+            .noc = tt_metal::NOC::RISCV_0_default,
+            .compile_args = sender_worker_reader_compile_args});
+    auto sender_worker_writer_kernel = tt_metal::CreateKernel(
+        program,
+        "tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_erisc_datamover_sender_worker_sender.cpp",
+        worker_core,
+        tt_metal::DataMovementConfig{
+            .processor = tt_metal::DataMovementProcessor::RISCV_1,
+            .noc = tt_metal::NOC::RISCV_1_default,
+            .compile_args = sender_worker_writer_compile_args});
+    tt_metal::SetRuntimeArgs(program, sender_worker_reader_kernel, worker_core, sender_worker_reader_runtime_args);
+    tt_metal::SetRuntimeArgs(program, sender_worker_writer_kernel, worker_core, sender_worker_writer_runtime_args);
+}
+
+bool RunLoopbackTest(
+    tt_metal::Device* sender_device,
+    tt_metal::Device* receiver_device,
+
+    const CoreCoord& eth_sender_core,
+    const CoreCoord& eth_receiver_core,
+
+    const uint32_t page_size,
+    const uint32_t num_pages_total,
+    bool src_is_dram,
+    bool dest_is_dram) {
+    std::size_t page_plus_header_size = page_size + sizeof(tt::fabric::PacketHeader);
+    std::size_t tensor_size_bytes = num_pages_total * page_size;
+
+    std::vector<Program> programs(2);
+    auto& sender_program = programs.at(0);
+    auto& receiver_program = programs.at(1);
+
+    std::vector<CoreCoord> worker_cores = {CoreCoord(0, 0)};
+
+    auto local_worker_fabric_semaphore_id = tt::tt_metal::CreateSemaphore(sender_program, worker_cores.at(0), 0);
+    auto local_worker_last_message_semaphore_id = tt::tt_metal::CreateSemaphore(sender_program, worker_cores.at(0), 0);
+    auto worker_buffer_index_semaphore_id = tt::tt_metal::CreateSemaphore(sender_program, worker_cores.at(0), 0);
+
+    std::optional<size_t> chip0_receiver_channel_downstream_flow_control_semaphore_id = std::nullopt;
+    auto chip0_sender_channel_0_flow_control_semaphore_id =
+        tt::tt_metal::CreateSemaphore(sender_program, eth_sender_core, 0, CoreType::ETH);
+    auto chip0_sender_channel_1_flow_control_semaphore_id =
+        tt::tt_metal::CreateSemaphore(sender_program, eth_sender_core, 0, CoreType::ETH);
+    auto chip0_sender_channel_0_connection_semaphore_id =
+        tt::tt_metal::CreateSemaphore(sender_program, eth_sender_core, 0, CoreType::ETH);
+    auto chip0_sender_channel_1_connection_semaphore_id =
+        tt::tt_metal::CreateSemaphore(sender_program, eth_sender_core, 0, CoreType::ETH);
+
+    std::optional<size_t> chip1_receiver_channel_downstream_flow_control_semaphore_id =
+        tt::tt_metal::CreateSemaphore(receiver_program, eth_receiver_core, 0, CoreType::ETH);
+    auto chip1_sender_channel_0_flow_control_semaphore_id =
+        tt::tt_metal::CreateSemaphore(receiver_program, eth_receiver_core, 0, CoreType::ETH);
+    auto chip1_sender_channel_1_flow_control_semaphore_id =
+        tt::tt_metal::CreateSemaphore(receiver_program, eth_receiver_core, 0, CoreType::ETH);
+    auto chip1_sender_channel_0_connection_semaphore_id =
+        tt::tt_metal::CreateSemaphore(receiver_program, eth_receiver_core, 0, CoreType::ETH);
+    auto chip1_sender_channel_1_connection_semaphore_id =
+        tt::tt_metal::CreateSemaphore(receiver_program, eth_receiver_core, 0, CoreType::ETH);
+    auto chip1_downstream_edm_buffer_index_semaphore_id =
+        tt::tt_metal::CreateSemaphore(receiver_program, eth_receiver_core, 0, CoreType::ETH);
+
+    // Generate inputs
+    ////////////////////////////////////////////////////////////////////////////
+    //   SETUP THE INPUT CB
+    ////////////////////////////////////////////////////////////////////////////
+
+    BankedConfig test_config = BankedConfig{
+        .num_pages = num_pages_total,
+        .size_bytes = tensor_size_bytes,
+        .page_size_bytes = page_size,
+        .input_buffer_type = src_is_dram ? BufferType::DRAM : BufferType::L1,
+        .output_buffer_type = dest_is_dram ? BufferType::DRAM : BufferType::L1,
+        .l1_data_format = tt::DataFormat::Float16_b};
+
+    auto [local_input_buffer, inputs] = build_input_buffer(sender_device, tensor_size_bytes, test_config);
+
+    std::vector<uint32_t> all_zeros(inputs.size(), 0);
+    auto local_output_buffer = CreateBuffer(InterleavedBufferConfig{
+        sender_device, test_config.size_bytes, test_config.page_size_bytes, test_config.output_buffer_type});
+
+    tt_metal::detail::WriteToBuffer(local_output_buffer, all_zeros);
+
+    auto local_input_buffer_address = local_input_buffer->address();
+    auto local_output_buffer_address = local_output_buffer->address();
+
+    ////////////////////////////////////////////////////////////////////////////
+    //   EDM Builder Setup
+    ////////////////////////////////////////////////////////////////////////////
+
+    static constexpr std::size_t edm_buffer_size = 4096 + PACKET_HEADER_SIZE_BYTES;
+    const size_t local_chip_id = 0;
+    const size_t remote_chip_id = 1;
+    auto const& edm_config = ttnn::ccl::FabricEriscDatamoverConfig(edm_buffer_size, 1, 2);
+    auto chip_0_edm_builder = ttnn::ccl::FabricEriscDatamoverBuilder(
+        sender_device->ethernet_core_from_logical_core(eth_sender_core).x,
+        sender_device->ethernet_core_from_logical_core(eth_sender_core).y,
+        local_chip_id,
+        remote_chip_id,
+
+        chip0_receiver_channel_downstream_flow_control_semaphore_id,
+        chip0_sender_channel_0_flow_control_semaphore_id,
+        chip0_sender_channel_1_flow_control_semaphore_id,
+        chip0_sender_channel_0_connection_semaphore_id,
+        chip0_sender_channel_1_connection_semaphore_id,
+
+        edm_config);
+    auto chip0_worker_fabric_connection = chip_0_edm_builder.build_connection_to_worker_channel();
+    auto chip_1_edm_builder = ttnn::ccl::FabricEriscDatamoverBuilder(
+        receiver_device->ethernet_core_from_logical_core(eth_receiver_core).x,
+        receiver_device->ethernet_core_from_logical_core(eth_receiver_core).y,
+        remote_chip_id,
+        local_chip_id,
+
+        chip1_receiver_channel_downstream_flow_control_semaphore_id,  // the receiver channel's local sem for
+                                                                      // flow control with the downstream fabric sender
+        chip1_sender_channel_0_flow_control_semaphore_id,
+        chip1_sender_channel_1_flow_control_semaphore_id,
+        chip1_sender_channel_0_connection_semaphore_id,
+        chip1_sender_channel_1_connection_semaphore_id,
+
+        edm_config);
+    // Create the loopback connection on the second device
+    chip_1_edm_builder.connect_to_downstream_edm(chip_1_edm_builder, chip1_downstream_edm_buffer_index_semaphore_id);
+
+    ////////////////////////////////////////////////////////////////////////////
+    //   Build Workers
+    ////////////////////////////////////////////////////////////////////////////
+    log_info(tt::LogTest, "Generating local_sender -> remote_receiver workers");
+    const std::size_t pages_per_send =
+        (chip0_worker_fabric_connection.buffer_size_bytes - PACKET_HEADER_SIZE_BYTES) / page_size;
+    auto const& worker_core = worker_cores.at(0);
+    log_info(tt::LogTest, "Worker {}. On Core x={},y={}", 0, worker_core.x, worker_core.y);
+
+    std::vector<edm_termination_info_t> const& edm_termination_infos = {
+        {1,
+         sender_device->ethernet_core_from_logical_core(eth_receiver_core).x,
+         sender_device->ethernet_core_from_logical_core(eth_receiver_core).y,
+         ttnn::ccl::FabricEriscDatamoverConfig::termination_signal_address},
+        {0,
+         sender_device->ethernet_core_from_logical_core(eth_sender_core).x,
+         sender_device->ethernet_core_from_logical_core(eth_sender_core).y,
+         ttnn::ccl::FabricEriscDatamoverConfig::termination_signal_address}};
+
+    generate_sender_worker_kernels(
+        sender_program,
+        sender_device,
+        worker_core,
+        eth_sender_core,
+        chip0_worker_fabric_connection,
+        unicast_send{1},
+        edm_buffer_size,
+        page_plus_header_size,
+        num_pages_total,
+        pages_per_send,
+        local_worker_fabric_semaphore_id,
+        local_worker_last_message_semaphore_id,
+        local_input_buffer_address,
+        src_is_dram,
+        local_output_buffer_address,
+        dest_is_dram,
+        worker_buffer_index_semaphore_id,
+        edm_termination_infos);
+
+    ////////////////////////////////////////////////////////////////////////////
+    //   Build EDMs
+    ////////////////////////////////////////////////////////////////////////////
+    auto local_edm_kernel =
+        ttnn::ccl::generate_edm_kernel(sender_program, sender_device, chip_0_edm_builder, eth_sender_core, NOC::NOC_0);
+
+    auto remote_edm_kernel = ttnn::ccl::generate_edm_kernel(
+        receiver_program, receiver_device, chip_1_edm_builder, eth_receiver_core, NOC::NOC_0);
+
+    ////////////////////////////////////////////////////////////////////////////
+    //   Compile and Execute Application
+    ////////////////////////////////////////////////////////////////////////////
+    run_programs(programs, {sender_device, receiver_device});
+    log_info(tt::LogTest, "Reading back outputs");
+
+    bool pass = true;
+    constexpr bool enable_check = true;
+    if constexpr (enable_check) {
+        pass &= run_output_check(all_zeros, inputs, local_output_buffer) == Correctness::Correct;
+    }
+    return pass;
+}
+
+bool RunLineFabricTest(
+    std::vector<Device*> devices,
+    std::vector<EthLinkHop> const& hops,
+
+    const size_t mcast_first_chip,
+    const size_t mcast_last_chip,
+
+    const uint32_t page_size,
+    const uint32_t num_pages_total,
+    bool src_is_dram,
+    bool dest_is_dram) {
+    std::size_t page_plus_header_size = page_size + sizeof(tt::fabric::PacketHeader);
+    std::size_t tensor_size_bytes = num_pages_total * page_size;
+
+    static constexpr std::size_t edm_buffer_size = 4096 + PACKET_HEADER_SIZE_BYTES;
+    const size_t local_chip_id = 0;
+    const size_t remote_chip_id = 1;
+    const size_t num_hops = hops.size();
+    auto programs = std::vector<Program>(devices.size());
+
+    std::vector<CoreCoord> worker_cores = {CoreCoord(0, 0)};
+
+    // Generate inputs
+    ////////////////////////////////////////////////////////////////////////////
+    //   SETUP THE INPUT CB
+    ////////////////////////////////////////////////////////////////////////////
+    BankedConfig test_config = BankedConfig{
+        .num_pages = num_pages_total,
+        .size_bytes = tensor_size_bytes,
+        .page_size_bytes = page_size,
+        .input_buffer_type = src_is_dram ? BufferType::DRAM : BufferType::L1,
+        .output_buffer_type = dest_is_dram ? BufferType::DRAM : BufferType::L1,
+        .l1_data_format = tt::DataFormat::Float16_b};
+
+    // Input buffer
+    auto [local_input_buffer, inputs] = build_input_buffer(devices[0], tensor_size_bytes, test_config);
+    auto local_input_buffer_address = local_input_buffer->address();
+
+    std::vector<uint32_t> all_zeros(inputs.size(), 0);
+    // Output buffers
+    TT_ASSERT(mcast_first_chip <= mcast_last_chip, "mcast_first_chip must be less than or equal to mcast_last_chip");
+    TT_ASSERT(mcast_last_chip < devices.size(), "mcast_last_chip must be less than the number of devices");
+    std::vector<std::shared_ptr<Buffer>> output_buffers;
+    output_buffers.reserve(mcast_last_chip - mcast_first_chip + 1);
+    for (size_t i = mcast_first_chip; i <= mcast_last_chip; i++) {
+        output_buffers.push_back(CreateBuffer(InterleavedBufferConfig{
+            devices.at(i), test_config.size_bytes, test_config.page_size_bytes, test_config.output_buffer_type}));
+        tt_metal::detail::WriteToBuffer(output_buffers.back(), all_zeros);
+    }
+    auto local_output_buffer_address = output_buffers[0]->address();
+    bool all_same_addr = std::ranges::all_of(output_buffers, [local_output_buffer_address](auto const& buffer) {
+        return buffer->address() == local_output_buffer_address;
+    });
+    TT_ASSERT(all_same_addr, "All output buffers must have the same address");
+
+    ////////////////////////////////////////////////////////////////////////////
+    //   Setup Semaphores and Builders
+    ////////////////////////////////////////////////////////////////////////////
+    std::vector<EthLinkBuilder> edm_hop_builders;
+    edm_hop_builders.reserve(num_hops);
+
+    auto local_worker_fabric_semaphore_id = tt::tt_metal::CreateSemaphore(programs[0], worker_cores.at(0), 0);
+    auto local_worker_last_message_semaphore_id = tt::tt_metal::CreateSemaphore(programs[0], worker_cores.at(0), 0);
+    auto worker_buffer_index_semaphore_id = tt::tt_metal::CreateSemaphore(programs[0], worker_cores.at(0), 0);
+    auto const& edm_config = ttnn::ccl::FabricEriscDatamoverConfig(edm_buffer_size, 1, 2);
+
+    for (size_t i = 0; i < num_hops; i++) {
+        const auto sender_device = devices.at(i);
+        const auto receiver_device = devices.at(i + 1);
+        const auto edm_sender_core = hops.at(i).hop_src;
+        const auto edm_receiver_core = hops.at(i).hop_dest;
+
+        const std::optional<size_t> chip0_receiver_channel_downstream_flow_control_semaphore_id = std::nullopt;
+        const auto chip0_sender_channel_0_flow_control_semaphore_id =
+            tt::tt_metal::CreateSemaphore(programs.at(i), hops.at(i).hop_src, 0, CoreType::ETH);
+        const auto chip0_sender_channel_1_flow_control_semaphore_id =
+            tt::tt_metal::CreateSemaphore(programs.at(i), hops.at(i).hop_src, 0, CoreType::ETH);
+        const auto chip0_sender_channel_0_connection_semaphore_id =
+            tt::tt_metal::CreateSemaphore(programs.at(i), hops.at(i).hop_src, 0, CoreType::ETH);
+        const auto chip0_sender_channel_1_connection_semaphore_id =
+            tt::tt_metal::CreateSemaphore(programs.at(i), hops.at(i).hop_src, 0, CoreType::ETH);
+
+        std::optional<size_t> chip1_receiver_channel_downstream_flow_control_semaphore_id =
+            tt::tt_metal::CreateSemaphore(programs.at(i + 1), hops.at(i).hop_dest, 0, CoreType::ETH);
+        const auto chip1_sender_channel_0_flow_control_semaphore_id =
+            tt::tt_metal::CreateSemaphore(programs.at(i + 1), hops.at(i).hop_dest, 0, CoreType::ETH);
+        const auto chip1_sender_channel_1_flow_control_semaphore_id =
+            tt::tt_metal::CreateSemaphore(programs.at(i + 1), hops.at(i).hop_dest, 0, CoreType::ETH);
+        const auto chip1_sender_channel_0_connection_semaphore_id =
+            tt::tt_metal::CreateSemaphore(programs.at(i + 1), hops.at(i).hop_dest, 0, CoreType::ETH);
+        const auto chip1_sender_channel_1_connection_semaphore_id =
+            tt::tt_metal::CreateSemaphore(programs.at(i + 1), hops.at(i).hop_dest, 0, CoreType::ETH);
+        const auto chip1_downstream_edm_buffer_index_semaphore_id =
+            tt::tt_metal::CreateSemaphore(programs.at(i + 1), hops.at(i).hop_dest, 0, CoreType::ETH);
+
+        auto chip_0_edm_builder = ttnn::ccl::FabricEriscDatamoverBuilder(
+            sender_device->ethernet_core_from_logical_core(edm_sender_core).x,
+            sender_device->ethernet_core_from_logical_core(edm_sender_core).y,
+            local_chip_id,
+            remote_chip_id,
+
+            chip0_receiver_channel_downstream_flow_control_semaphore_id,
+            chip0_sender_channel_0_flow_control_semaphore_id,
+            chip0_sender_channel_1_flow_control_semaphore_id,
+            chip0_sender_channel_0_connection_semaphore_id,
+            chip0_sender_channel_1_connection_semaphore_id,
+
+            edm_config);
+        auto chip_1_edm_builder = ttnn::ccl::FabricEriscDatamoverBuilder(
+            receiver_device->ethernet_core_from_logical_core(edm_receiver_core).x,
+            receiver_device->ethernet_core_from_logical_core(edm_receiver_core).y,
+            remote_chip_id,
+            local_chip_id,
+
+            chip1_receiver_channel_downstream_flow_control_semaphore_id,  // the receiver channel's local sem for
+                                                                          // flow control with the downstream
+                                                                          // fabric sender
+            chip1_sender_channel_0_flow_control_semaphore_id,
+            chip1_sender_channel_1_flow_control_semaphore_id,
+            chip1_sender_channel_0_connection_semaphore_id,
+            chip1_sender_channel_1_connection_semaphore_id,
+
+            edm_config);
+
+        edm_hop_builders.push_back(EthLinkBuilder{
+            .sender_edm_builder = std::move(chip_0_edm_builder),
+            .receiver_edm_builder = std::move(chip_1_edm_builder),
+            .sender_core = edm_sender_core,
+            .receiver_core = edm_receiver_core,
+            .downstream_edm_buffer_index_semaphore_id = chip1_downstream_edm_buffer_index_semaphore_id});
+    }
+
+    for (size_t i = 0; i < num_hops - 1; i++) {
+        edm_hop_builders.at(i).receiver_edm_builder.connect_to_downstream_edm(
+            edm_hop_builders.at(i + 1).sender_edm_builder,
+            edm_hop_builders.at(i).downstream_edm_buffer_index_semaphore_id);
+    }
+
+    ////////////////////////////////////////////////////////////////////////////
+    //   Build Workers
+    ////////////////////////////////////////////////////////////////////////////
+    log_info(tt::LogTest, "Generating local_sender -> remote_receiver workers");
+    auto const& worker_core = worker_cores.at(0);
+    log_info(tt::LogTest, "Worker {}. On Core x={},y={}", 0, worker_core.x, worker_core.y);
+
+    std::vector<edm_termination_info_t> edm_termination_infos;
+    edm_termination_infos.reserve(num_hops * 2);
+    for (int i = num_hops - 1; i >= 0; i--) {
+        const std::size_t distance_receiver = i + 1;
+        const auto& receiver_core = hops.at(i).hop_dest;
+        auto receiver_device = devices.at(i + 1);
+        edm_termination_infos.push_back(
+            {distance_receiver,
+             receiver_device->ethernet_core_from_logical_core(receiver_core).x,
+             receiver_device->ethernet_core_from_logical_core(receiver_core).y,
+             ttnn::ccl::FabricEriscDatamoverConfig::termination_signal_address});
+        const std::size_t distance_sender = i;
+        const auto& sender_core = hops.at(i).hop_src;
+        auto sender_device = devices.at(i);
+        edm_termination_infos.push_back(
+            {distance_sender,
+             sender_device->ethernet_core_from_logical_core(sender_core).x,
+             sender_device->ethernet_core_from_logical_core(sender_core).y,
+             ttnn::ccl::FabricEriscDatamoverConfig::termination_signal_address});
+    }
+
+    auto chip0_worker_fabric_connection = edm_hop_builders[0].sender_edm_builder.build_connection_to_worker_channel();
+    const std::size_t pages_per_send =
+        (chip0_worker_fabric_connection.buffer_size_bytes - PACKET_HEADER_SIZE_BYTES) / page_size;
+    generate_sender_worker_kernels(
+        programs[0],
+        devices[0],
+        worker_core,
+        hops[0].hop_src,
+        chip0_worker_fabric_connection,
+        mcast_send{mcast_first_chip - 1, mcast_last_chip - mcast_first_chip},
+        edm_buffer_size,
+        page_plus_header_size,
+        num_pages_total,
+        pages_per_send,
+        local_worker_fabric_semaphore_id,
+        local_worker_last_message_semaphore_id,
+        local_input_buffer_address,
+        src_is_dram,
+        local_output_buffer_address,
+        dest_is_dram,
+        worker_buffer_index_semaphore_id,
+        edm_termination_infos);
+
+    ////////////////////////////////////////////////////////////////////////////
+    //   Build EDMs
+    ////////////////////////////////////////////////////////////////////////////
+    for (std::size_t i = 0; i < num_hops; i++) {
+        auto local_edm_kernel = ttnn::ccl::generate_edm_kernel(
+            programs.at(i),
+            devices.at(i),
+            edm_hop_builders.at(i).sender_edm_builder,
+            edm_hop_builders.at(i).sender_core,
+            NOC::NOC_0);
+
+        auto remote_edm_kernel = ttnn::ccl::generate_edm_kernel(
+            programs.at(i + 1),
+            devices.at(i + 1),
+            edm_hop_builders.at(i).receiver_edm_builder,
+            edm_hop_builders.at(i).receiver_core,
+            NOC::NOC_0);
+    }
+
+    ////////////////////////////////////////////////////////////////////////////
+    //   Compile and Execute Application
+    ////////////////////////////////////////////////////////////////////////////
+
+    run_programs(programs, devices);
+    log_info(tt::LogTest, "Reading back outputs");
+
+    bool pass = true;
+    constexpr bool enable_check = true;
+    if constexpr (enable_check) {
+        for (size_t i = mcast_first_chip; i <= mcast_last_chip; i++) {
+            bool compare_with_input = (mcast_first_chip <= i && i <= mcast_last_chip);
+            auto& golden_tensor = compare_with_input ? inputs : all_zeros;
+            // output_buffers is indexed from 0, offset by the first mcast chip
+            pass &= run_output_check(all_zeros, golden_tensor, output_buffers.at(i - mcast_first_chip)) ==
+                    Correctness::Correct;
+        }
+    }
+
+    return pass;
+}
+
+int TestLineFabricEntrypoint(
+    const size_t mcast_first_chip,
+    const size_t mcast_last_chip,
+    const uint32_t page_size,
+    const uint32_t num_pages_total,
+    const bool src_is_dram,
+    const bool dest_is_dram) {
+    auto arch = tt::get_arch_from_string(tt::test_utils::get_umd_arch_name());
+    auto num_devices = tt::tt_metal::GetNumAvailableDevices();
+    if (num_devices < 4) {
+        log_info(tt::LogTest, "This test can only be run on T3000 devices");
+        return 0;
+    }
+    if (arch == tt::ARCH::GRAYSKULL) {
+        log_info(tt::LogTest, "Test must be run on WH");
+        return 0;
+    }
+
+    T3000TestDevice test_fixture;
+
+    // Build a line of devices
+    static constexpr size_t fabric_line_length = 4;
+    std::vector<Device*> devices = {
+        test_fixture.devices_.at(0),
+        test_fixture.devices_.at(1),
+        test_fixture.devices_.at(2),
+        test_fixture.devices_.at(3)};
+    std::vector<EthLinkHop> fabric_hops;
+    fabric_hops.reserve(fabric_line_length);
+
+    for (size_t hop = 0; hop < fabric_line_length - 1; hop++) {
+        auto src_device = devices[hop];
+        auto dest_device = devices[hop + 1];
+        auto target_dest_device_id = devices[hop + 1]->id();
+        log_info(tt::LogTest, "Finding links between device {} and {}", src_device->id(), dest_device->id());
+        auto const& active_eth_cores = src_device->get_active_ethernet_cores(true);
+        auto eth_sender_core_iter = active_eth_cores.begin();
+        auto eth_sender_core_iter_end = active_eth_cores.end();
+
+        chip_id_t dest_device_id = std::numeric_limits<chip_id_t>::max();
+        tt_xy_pair eth_receiver_core;
+        tt_xy_pair eth_sender_core;
+        do {
+            TT_FATAL(
+                eth_sender_core_iter != eth_sender_core_iter_end,
+                "Did not find an active ethernet link to the target device");
+            std::tie(dest_device_id, eth_receiver_core) =
+                src_device->get_connected_ethernet_core(*eth_sender_core_iter);
+            eth_sender_core = *eth_sender_core_iter;
+            eth_sender_core_iter++;
+        } while (dest_device_id != target_dest_device_id);
+        TT_ASSERT(dest_device_id == target_dest_device_id);
+
+        fabric_hops.push_back({eth_sender_core, eth_receiver_core});
+    }
+
+    bool success = false;
+    try {
+        success = RunLineFabricTest(
+            devices,
+            fabric_hops,
+
+            mcast_first_chip,
+            mcast_last_chip,
+
+            page_size,
+            num_pages_total,
+            src_is_dram,
+            dest_is_dram);
+
+    } catch (std::exception& e) {
+        log_error(tt::LogTest, "Caught exception: {}", e.what());
+        test_fixture.TearDown();
+        return -1;
+    }
+
+    test_fixture.TearDown();
+
+    return success ? 0 : -1;
+}
+
+int TestLoopbackEntrypoint(
+    const uint32_t page_size, const uint32_t num_pages_total, const bool src_is_dram, const bool dest_is_dram) {
+    auto arch = tt::get_arch_from_string(tt::test_utils::get_umd_arch_name());
+    auto num_devices = tt::tt_metal::GetNumAvailableDevices();
+    if (num_devices < 4) {
+        log_info(tt::LogTest, "This test can only be run on T3000 devices");
+        return 0;
+    }
+    if (arch == tt::ARCH::GRAYSKULL) {
+        log_info(tt::LogTest, "Test must be run on WH");
+        return 0;
+    }
+
+    T3000TestDevice test_fixture;
+
+    const auto& device_0 = test_fixture.devices_.at(0);
+
+    auto const& active_eth_cores = device_0->get_active_ethernet_cores(true);
+    auto eth_sender_core_iter = active_eth_cores.begin();
+    auto eth_sender_core_iter_end = active_eth_cores.end();
+    chip_id_t device_id = std::numeric_limits<chip_id_t>::max();
+    tt_xy_pair eth_receiver_core;
+    tt_xy_pair eth_sender_core;
+    do {
+        TT_FATAL(
+            eth_sender_core_iter != eth_sender_core_iter_end,
+            "Did not find an active ethernet link to the target device");
+        std::tie(device_id, eth_receiver_core) = device_0->get_connected_ethernet_core(*eth_sender_core_iter);
+        eth_sender_core = *eth_sender_core_iter;
+        eth_sender_core_iter++;
+    } while (device_id != 1);
+    TT_ASSERT(device_id == 1);
+    const auto& device_1 = test_fixture.devices_.at(device_id);
+
+    bool success = false;
+    try {
+        success = RunLoopbackTest(
+            device_0,
+            device_1,
+
+            eth_sender_core,
+            eth_receiver_core,
+
+            page_size,
+            num_pages_total,
+            src_is_dram,
+            dest_is_dram);
+    } catch (std::exception& e) {
+        log_error(tt::LogTest, "Caught exception: {}", e.what());
+        test_fixture.TearDown();
+        return -1;
+    }
+
+    test_fixture.TearDown();
+
+    return success ? 0 : -1;
+}
+
+////////////////////////////////////////////////////////////////////
+///  MESSAGE COUNT TERMINATION MODE
+////////////////////////////////////////////////////////////////////
+
+TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers_SingleMessage) {
+    const uint32_t page_size = 2048;
+    const uint32_t num_pages_total = 1;
+    const bool src_is_dram = true;
+    const bool dest_is_dram = true;
+
+    auto result = TestLoopbackEntrypoint(page_size, num_pages_total, src_is_dram, dest_is_dram);
+    ASSERT_EQ(result, 0);
+}
+
+// Will wrap the sender buffers but not the receiver buffers
+TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers_2_messages) {
+    const uint32_t page_size = 2048;
+    const uint32_t num_pages_total = 2;
+    const bool src_is_dram = true;
+    const bool dest_is_dram = true;
+
+    auto result = TestLoopbackEntrypoint(page_size, num_pages_total, src_is_dram, dest_is_dram);
+    ASSERT_EQ(result, 0);
+}
+
+// Will wrap the sender buffers but not the receiver buffers
+TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers_10_messages) {
+    const uint32_t page_size = 2048;
+    const uint32_t num_pages_total = 10;
+    const bool src_is_dram = true;
+    const bool dest_is_dram = true;
+
+    auto result = TestLoopbackEntrypoint(page_size, num_pages_total, src_is_dram, dest_is_dram);
+    ASSERT_EQ(result, 0);
+}
+
+// Will wrap both the sender and receiver buffers
+TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers_20_messages) {
+    const uint32_t page_size = 2048;
+    const uint32_t num_pages_total = 20;
+    const bool src_is_dram = true;
+    const bool dest_is_dram = true;
+
+    auto result = TestLoopbackEntrypoint(page_size, num_pages_total, src_is_dram, dest_is_dram);
+    ASSERT_EQ(result, 0);
+}
+
+TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers) {
+    const uint32_t page_size = 2048;
+    const uint32_t num_pages_total = 100000;
+    const bool src_is_dram = true;
+    const bool dest_is_dram = true;
+
+    auto result = TestLoopbackEntrypoint(page_size, num_pages_total, src_is_dram, dest_is_dram);
+    ASSERT_EQ(result, 0);
+}
+
+// Currently disabled until mcast is properly tested and brought up
+TEST(WorkerFabricEdmDatapath, DISABLED_LineFabricMcast) {
+    const uint32_t page_size = 2048;
+    const uint32_t num_pages_total = 1;
+    const bool src_is_dram = true;
+    const bool dest_is_dram = true;
+    const size_t mcast_first_chip = 1;
+    const size_t mcast_last_chip = 3;
+
+    auto result = TestLineFabricEntrypoint(
+        mcast_first_chip, mcast_last_chip, page_size, num_pages_total, src_is_dram, dest_is_dram);
+
+    ASSERT_EQ(result, 0);
+}
diff --git a/tt_metal/hw/inc/ethernet/dataflow_api.h b/tt_metal/hw/inc/ethernet/dataflow_api.h
index 8901021fac5..5b0ddafb995 100644
--- a/tt_metal/hw/inc/ethernet/dataflow_api.h
+++ b/tt_metal/hw/inc/ethernet/dataflow_api.h
@@ -203,6 +203,23 @@ void eth_send_bytes_over_channel_payload_only(
     }
 }
 
+// Calls the unsafe variant of eth_send_packet under the hood, which is
+// guaranteed not to context switch. We want this for code-size reasons.
+FORCE_INLINE
+void eth_send_bytes_over_channel_payload_only_unsafe(
+    uint32_t src_addr,
+    uint32_t dst_addr,
+    uint32_t num_bytes,
+    uint32_t num_bytes_per_send = 16,
+    uint32_t num_bytes_per_send_word_size = 1) {
+    uint32_t num_bytes_sent = 0;
+    while (num_bytes_sent < num_bytes) {
+        internal_::eth_send_packet_unsafe(
+            0, ((num_bytes_sent + src_addr) >> 4), ((num_bytes_sent + dst_addr) >> 4), num_bytes_per_send_word_size);
+        num_bytes_sent += num_bytes_per_send;
+    }
+}
+
 /*
  * Sends the write completion signal to the receiver ethernet core, for transfers where the payload was already sent.
  * The second half of a full ethernet send.
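The `_unsafe` suffix shifts responsibility for TX-queue availability to the caller: unlike the safe variants, nothing inside polls (or context-switches while) waiting for the queue. A minimal call-site sketch under that assumption (kernel-context code, not standalone; queue number 0 and the busy-poll are illustrative, not part of this diff):

```cpp
// Caller must guarantee the TX command queue is idle before the unsafe send;
// the ASSERT inside eth_send_packet_unsafe only checks this in debug builds.
while (eth_txq_reg_read(0, ETH_TXQ_CMD) != 0) {
    // No send in flight from us yet, so it is safe to spin here.
}
eth_send_bytes_over_channel_payload_only_unsafe(src_addr, dst_addr, num_bytes);
```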
diff --git a/tt_metal/hw/inc/ethernet/tunneling.h b/tt_metal/hw/inc/ethernet/tunneling.h
index b6e4cdd0bd5..043a133eeb0 100644
--- a/tt_metal/hw/inc/ethernet/tunneling.h
+++ b/tt_metal/hw/inc/ethernet/tunneling.h
@@ -26,7 +26,11 @@ struct eth_channel_sync_t {
     // First level ack that signals to sender that the payload was received by receiver,
     // indicating that sender can reuse the sender side buffer safely.
     volatile uint32_t receiver_ack;
-    uint32_t reserved_1;
+
+    // Logical channel ID tagged by the sender. Not required when channels
+    // are connected 1:1 (single producer, single consumer).
+    volatile uint32_t src_id;
+    uint32_t reserved_2;
 };
 
@@ -66,6 +70,15 @@ void eth_send_packet(uint32_t q_num, uint32_t src_word_addr, uint32_t dest_word_
     eth_txq_reg_write(q_num, ETH_TXQ_CMD, ETH_TXQ_CMD_START_DATA);
 }
 
+FORCE_INLINE
+void eth_send_packet_unsafe(uint32_t q_num, uint32_t src_word_addr, uint32_t dest_word_addr, uint32_t num_words) {
+    ASSERT(eth_txq_reg_read(q_num, ETH_TXQ_CMD) == 0);
+    eth_txq_reg_write(q_num, ETH_TXQ_TRANSFER_START_ADDR, src_word_addr << 4);
+    eth_txq_reg_write(q_num, ETH_TXQ_DEST_ADDR, dest_word_addr << 4);
+    eth_txq_reg_write(q_num, ETH_TXQ_TRANSFER_SIZE_BYTES, num_words << 4);
+    eth_txq_reg_write(q_num, ETH_TXQ_CMD, ETH_TXQ_CMD_START_DATA);
+}
+
 FORCE_INLINE
 void eth_write_remote_reg(uint32_t q_num, uint32_t reg_addr, uint32_t val) {
     while (eth_txq_reg_read(q_num, ETH_TXQ_CMD) != 0) {
diff --git a/tt_metal/hw/inc/wormhole/noc_nonblocking_api.h b/tt_metal/hw/inc/wormhole/noc_nonblocking_api.h
index 48b6411911d..37c19fae91e 100644
--- a/tt_metal/hw/inc/wormhole/noc_nonblocking_api.h
+++ b/tt_metal/hw/inc/wormhole/noc_nonblocking_api.h
@@ -292,6 +292,7 @@ inline __attribute__((always_inline)) void noc_fast_write_dw_inline(uint32_t noc
     uint32_t be32 = be;
     uint32_t be_shift = (dest_addr & (NOC_WORD_BYTES - 1));
+    // If we're given a misaligned address, don't write to the bytes in the word below the address.
     be32 = (be32 << be_shift);
 
     while (!noc_cmd_buf_ready(noc, cmd_buf));
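A worked example of what that byte-enable shift does (values chosen for illustration; `NOC_WORD_BYTES = 16` is an assumption matching the Wormhole NoC word size):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    constexpr uint32_t NOC_WORD_BYTES = 16;  // assumed NoC word size
    uint32_t dest_addr = 0x1006;             // offset 6 within its 16-byte NoC word
    uint32_t be = 0xF;                       // enable all 4 bytes of the dword
    uint32_t be_shift = dest_addr & (NOC_WORD_BYTES - 1);  // = 6
    uint32_t be32 = be << be_shift;          // 0xF << 6 == 0x3C0
    // Only bytes 6..9 of the word are enabled; bytes 0..5 below the
    // misaligned address are left untouched.
    std::printf("be32 = 0x%X\n", be32);
}
```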
diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt
index bc2b1773cc2..16aa324c520 100644
--- a/ttnn/CMakeLists.txt
+++ b/ttnn/CMakeLists.txt
@@ -10,6 +10,7 @@ set(ALL_TTNN_SRCS
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/graph/graph_processor.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/graph/graph_trace_utils.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/graph/graph_pybind.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/erisc_datamover_builder.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/all_gather.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp
diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp b/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp
index 92e8b46e805..865f1a7e0bd 100644
--- a/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp
+++ b/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp
@@ -198,16 +198,17 @@ void generate_edm_kernels_for_ring_or_linear_topology(
     }
 }
 
-
-KernelHandle generate_edm_kernel(
-    tt::tt_metal::Program& program,
+template <typename EDMBuilder>
+KernelHandle generate_edm_kernel_impl(
+    tt::tt_metal::Program& program,
     Device const* device,
-    ccl::EriscDatamoverBuilder const& edm_builder,
+    EDMBuilder const& edm_builder,
+    std::string const& kernel_path,
     CoreCoord const& eth_core,
     NOC noc_id) {
     edm_builder.dump_to_log();
 
-    std::vector<uint32_t> const& edm_clockwise_kernel_rt_args = edm_builder.emit_runtime_args();
+    std::vector<uint32_t> const& edm_kernel_rt_args = edm_builder.emit_runtime_args();
     // Ethernet Kernels
     std::vector<uint32_t> eth_sender_ct_args = edm_builder.emit_compile_time_args();
     log_trace(tt::LogOp, "EDM core (x={},y={}):", eth_core.x, eth_core.y);
@@ -216,17 +217,17 @@ KernelHandle generate_edm_kernel(
         log_trace(tt::LogOp, "\t{}", s);
     }
 
-    auto eth_sender_kernel =tt::tt_metal::CreateKernel(
+    auto eth_sender_kernel = tt::tt_metal::CreateKernel(
         program,
-        "ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp",
+        kernel_path,
         eth_core,
-        tt::tt_metal::EthernetConfig{.noc = noc_id, .compile_args = eth_sender_ct_args});
+        tt::tt_metal::EthernetConfig{.noc = noc_id, .compile_args = eth_sender_ct_args});
 
-    tt::tt_metal::SetRuntimeArgs(program, eth_sender_kernel, eth_core, edm_clockwise_kernel_rt_args);
+    tt::tt_metal::SetRuntimeArgs(program, eth_sender_kernel, eth_core, edm_kernel_rt_args);
 
     std::stringstream ss;
     ss << "EDM ARGS:\n";
-    for (auto const& s : edm_clockwise_kernel_rt_args) {
+    for (auto const& s : edm_kernel_rt_args) {
         ss << "\t" << s << "\n";
     }
     log_trace(tt::LogOp, "{}", ss.str());
@@ -234,6 +235,31 @@ KernelHandle generate_edm_kernel(
     return eth_sender_kernel;
 }
 
+KernelHandle generate_edm_kernel(
+    tt::tt_metal::Program& program,
+    Device const* device,
+    ccl::FabricEriscDatamoverBuilder const& edm_builder,
+    CoreCoord const& eth_core,
+    NOC noc_id) {
+    return generate_edm_kernel_impl(
+        program,
+        device,
+        edm_builder,
+        "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover.cpp",
+        eth_core,
+        noc_id);
+}
+
+KernelHandle generate_edm_kernel(
+    tt::tt_metal::Program& program,
+    Device const* device,
+    ccl::EriscDatamoverBuilder const& edm_builder,
+    CoreCoord const& eth_core,
+    NOC noc_id) {
+    return generate_edm_kernel_impl(
+        program, device, edm_builder, "ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp", eth_core, noc_id);
+}
+
 ccl::EriscDatamoverBuilder create_erisc_datamover_builder(
     std::size_t num_channels,
     uint32_t page_size,
diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp b/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp
index 51228970005..0ad4d35b3f1 100644
--- a/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp
+++ b/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp
@@ -11,6 +11,7 @@
 #include "ttnn/operations/ccl/ccl_host_datastructures.hpp"
 #include "ttnn/operations/ccl/common/types/ccl_types.hpp"
 #include "ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp"
+#include "ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp"
 #include "tt_metal/host_api.hpp"
 #include "tt_metal/impl/program/program.hpp"
 #include "ttnn/tensor/types.hpp"
@@ -467,6 +468,13 @@ class InterleavedRingAllGatherTensorSlicer : public LegacyCclTensorSlicer {
 };
 
+KernelHandle generate_edm_kernel(
+    tt::tt_metal::Program& program,
+    Device const* device,
+    ccl::FabricEriscDatamoverBuilder const& edm_builder,
+    CoreCoord const& eth_core,
+    NOC noc_id);
+
 KernelHandle generate_edm_kernel(
     tt::tt_metal::Program& program,
     Device const* device,
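With this refactor, both builder types funnel through the shared `generate_edm_kernel_impl`; the overload selected by the builder's static type decides which kernel source gets compiled. A usage sketch (assuming a program, device, builders, and ethernet core set up as in the test earlier in this diff):

```cpp
// FabricEriscDatamoverBuilder -> edm_fabric/fabric_erisc_datamover.cpp is compiled.
auto fabric_kernel = ttnn::ccl::generate_edm_kernel(
    program, device, fabric_builder, eth_core, NOC::NOC_0);

// EriscDatamoverBuilder -> the legacy edm/erisc_datamover.cpp is compiled.
auto legacy_kernel = ttnn::ccl::generate_edm_kernel(
    program, device, legacy_builder, eth_core, NOC::NOC_0);
```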
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp" + +#include "common/math.hpp" +#include "eth_l1_address_map.h" +#include "tt_metal/common/assert.hpp" +#include "ttnn/operations/math.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp" +namespace ttnn::ccl { + + +// The channel structure is as follows: +// &header-> |----------------| channel_base_address +// | header | +// &payload-> |----------------| +// | | +// | payload | +// | | +// &channel_sync-> |----------------| +// | channel_sync | +// ------------------ +// + +FabricEriscDatamoverConfig::FabricEriscDatamoverConfig( + std::size_t channel_buffer_size_bytes, std::size_t sender_ratio_size, std::size_t receiver_ratio_size) { + TT_ASSERT(channel_buffer_size_bytes > sizeof(tt::fabric::PacketHeader) + 2 * FabricEriscDatamoverConfig::eth_channel_sync_size); + const std::size_t channel_buffer_size_with_channel_sync = + channel_buffer_size_bytes + sizeof(tt::fabric::PacketHeader); // + 16 // sizeof(tt::fabric::PacketHeader); + + this->channel_buffer_size_bytes = channel_buffer_size_bytes; + this->channel_buffer_size_bytes_with_channel_sync = channel_buffer_size_with_channel_sync; + const std::size_t total_ratio_count = 2 * sender_ratio_size + receiver_ratio_size; + this->sender_0_channel_size_bytes = tt::round_down( + (available_channel_buffering_space / total_ratio_count) * sender_ratio_size, + channel_buffer_size_with_channel_sync); + this->sender_0_num_buffers = this->sender_0_channel_size_bytes / channel_buffer_size_with_channel_sync; + this->sender_1_channel_size_bytes = tt::round_down( + (available_channel_buffering_space / total_ratio_count) * sender_ratio_size, + channel_buffer_size_with_channel_sync); + this->sender_1_num_buffers = this->sender_1_channel_size_bytes / channel_buffer_size_with_channel_sync; + this->receiver_channel_size_bytes = tt::round_down( + (available_channel_buffering_space / total_ratio_count) * receiver_ratio_size, + channel_buffer_size_with_channel_sync); + this->receiver_num_buffers = this->receiver_channel_size_bytes / channel_buffer_size_with_channel_sync; + + this->sender_0_channel_base_address = buffer_region_start; + this->sender_1_channel_base_address = this->sender_0_channel_base_address + this->sender_0_channel_size_bytes; + this->receiver_channel_base_address = this->sender_1_channel_base_address + this->sender_1_channel_size_bytes; + + log_trace(tt::LogOp, "Sender 0 channel_start: {}", this->sender_0_channel_base_address); + log_trace(tt::LogOp, "Sender 1 channel_start: {}", this->sender_1_channel_base_address); + log_trace(tt::LogOp, "Receiver channel_start: {}", this->receiver_channel_base_address); + + TT_ASSERT( + this->sender_0_channel_size_bytes + this->sender_1_channel_size_bytes + this->receiver_channel_size_bytes <= + this->available_channel_buffering_space); + TT_ASSERT( + this->receiver_channel_base_address + this->receiver_channel_size_bytes < + eth_l1_mem::address_map::MAX_L1_LOADING_SIZE); +} + +FabricEriscDatamoverBuilder::FabricEriscDatamoverBuilder( + size_t my_noc_x, + size_t my_noc_y, + size_t my_chip_id, + size_t peer_chip_id, + + std::optional receiver_channel_downstream_flow_control_semaphore_id, + size_t sender_channel_0_flow_control_semaphore_id, + size_t sender_channel_1_flow_control_semaphore_id, + size_t sender_channel_0_connection_semaphore_id, + size_t sender_channel_1_connection_semaphore_id, + + FabricEriscDatamoverConfig const& config) : + 
my_noc_x(my_noc_x), + my_noc_y(my_noc_y), + config(config), + my_chip_id(my_chip_id), + peer_chip_id(peer_chip_id), + handshake_address(tt::round_up(eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE, FabricEriscDatamoverConfig::eth_channel_sync_size)), + channel_buffer_size(config.channel_buffer_size_bytes), + sender_0_num_buffers(config.sender_0_num_buffers), + sender_1_num_buffers(config.sender_1_num_buffers), + receiver_num_buffers(config.receiver_num_buffers), + + // this is the receiver channel's local sem for flow controlling with downstream fabric sender + receiver_channel_downstream_flow_control_semaphore_id(receiver_channel_downstream_flow_control_semaphore_id), + sender_channel_0_flow_control_semaphore_id(sender_channel_0_flow_control_semaphore_id), + sender_channel_1_flow_control_semaphore_id(sender_channel_1_flow_control_semaphore_id), + sender_channel_0_connection_semaphore_id(sender_channel_0_connection_semaphore_id), + sender_channel_1_connection_semaphore_id(sender_channel_1_connection_semaphore_id), + + local_sender_channel_0_buffer_address(config.sender_0_channel_base_address), + local_sender_channel_0_connection_info_addr( + FabricEriscDatamoverConfig::sender_channel_0_worker_connection_info_address), + local_sender_channel_1_buffer_address(config.sender_1_channel_base_address), + local_sender_channel_1_connection_info_addr( + FabricEriscDatamoverConfig::sender_channel_1_worker_connection_info_address), + local_receiver_channel_buffer_address(config.receiver_channel_base_address), + + termination_signal_ptr(FabricEriscDatamoverConfig::termination_signal_address) {} + +std::vector FabricEriscDatamoverBuilder::emit_compile_time_args() const { + const bool is_handshake_master = this->my_chip_id < this->peer_chip_id; + TT_ASSERT(this->my_chip_id != this->peer_chip_id); + TT_ASSERT( + this->sender_0_num_buffers == this->sender_1_num_buffers); //, "Implementation expects sender_0_num_buffers and + // sender_1_num_buffers to be the same for now"); + return std::vector{ + is_handshake_master, + this->handshake_address, + this->channel_buffer_size, + + this->sender_0_num_buffers, + this->receiver_num_buffers, + + config.sender_0_channel_base_address, + FabricEriscDatamoverConfig::sender_channel_0_buffer_index_address, + FabricEriscDatamoverConfig::sender_channel_0_worker_connection_info_address, + config.sender_1_channel_base_address, + FabricEriscDatamoverConfig::sender_channel_1_buffer_index_address, + FabricEriscDatamoverConfig::sender_channel_1_worker_connection_info_address, + config.receiver_channel_base_address, + config.receiver_channel_base_address, + + config.sender_0_channel_base_address, + config.sender_1_channel_base_address, + + this->termination_signal_ptr}; +} + +std::vector FabricEriscDatamoverBuilder::emit_runtime_args() const { + return std::vector{ + this->sender_channel_0_connection_semaphore_id, + this->sender_channel_1_connection_semaphore_id, + this->downstream_edm_buffer_base_address != std::nullopt, + this->downstream_edm_buffer_base_address.value_or(0), + this->downstream_edm_noc_x.value_or(0), + this->downstream_edm_noc_y.value_or(0), + this->downstream_edm_semaphore_address.value_or(0), + this->downstream_edm_worker_registration_address.value_or(0), + this->downstream_edm_worker_location_info_address.value_or(0), + this->downstream_noc_interface_buffer_index_addr.value_or(0), + // this is the receiver channel's local sem for flow controlling with downstream fabric sender + this->receiver_channel_downstream_flow_control_semaphore_id.value_or(0), + 
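+        // per-sender-channel flow control semaphores; a connected worker increments these to
+        // signal that a new payload has landed in the corresponding sender channel buffer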
this->sender_channel_0_flow_control_semaphore_id, + this->sender_channel_1_flow_control_semaphore_id, + + }; +} + + +SenderWorkerAdapterSpec FabricEriscDatamoverBuilder::build_connection_to_worker_channel() const { + return SenderWorkerAdapterSpec { + this->my_noc_x, + this->my_noc_y, + this->local_sender_channel_0_buffer_address, + this->sender_0_num_buffers, + this->sender_channel_0_flow_control_semaphore_id, + this->sender_channel_0_connection_semaphore_id, + FabricEriscDatamoverConfig::sender_channel_0_worker_connection_info_address, + this->config.channel_buffer_size_bytes + }; +} + + +SenderWorkerAdapterSpec FabricEriscDatamoverBuilder::build_connection_to_fabric_channel() const { + return SenderWorkerAdapterSpec { + this->my_noc_x, + this->my_noc_y, + this->local_sender_channel_1_buffer_address, + this->sender_1_num_buffers, + this->sender_channel_1_flow_control_semaphore_id, + this->sender_channel_1_connection_semaphore_id, + FabricEriscDatamoverConfig::sender_channel_1_worker_connection_info_address, + this->config.channel_buffer_size_bytes + }; +} + +void FabricEriscDatamoverBuilder::connect_to_downstream_edm(FabricEriscDatamoverBuilder const& downstream_edm, uint32_t downstream_edm_buffer_index_semaphore_id) { + auto const& adapter_spec = downstream_edm.build_connection_to_fabric_channel(); + + log_trace(tt::LogTest, "Connecting to downstream EDM at x={}, y={}", adapter_spec.edm_worker_x, adapter_spec.edm_worker_y); + + this->downstream_edm_noc_x = adapter_spec.edm_worker_x; + this->downstream_edm_noc_y = adapter_spec.edm_worker_y; + this->downstream_edm_buffer_base_address = adapter_spec.edm_buffer_base_addr; + this->downstream_edm_semaphore_address = adapter_spec.edm_l1_sem_addr; + this->downstream_edm_worker_registration_address = adapter_spec.edm_connection_handshake_addr; + this->downstream_edm_worker_location_info_address = adapter_spec.edm_worker_location_info_addr; + this->downstream_noc_interface_buffer_index_addr = downstream_edm_buffer_index_semaphore_id; +} + +} // namespace ttnn::ccl diff --git a/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp b/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp new file mode 100644 index 00000000000..889c42405a1 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp @@ -0,0 +1,146 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#include "eth_l1_address_map.h" +#include "tt_metal/third_party/umd/device/tt_cluster_descriptor_types.h" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_types.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" + +namespace ttnn { +namespace ccl { + +struct FabricEriscDatamoverConfig { + static constexpr std::size_t field_size = 16; + static constexpr std::size_t buffer_alignment = 32; + static_assert(((buffer_alignment - 1) & buffer_alignment) == 0); + + // Global + static constexpr std::size_t eth_channel_sync_size = 16; + static constexpr std::size_t handshake_addr = eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE; + static constexpr std::size_t edm_channel_ack_addr = handshake_addr + eth_channel_sync_size; + static constexpr std::size_t termination_signal_address = + edm_channel_ack_addr + (2 * eth_channel_sync_size); // pad extra bytes to match old EDM so handshake logic will still work + + // Sender Channel 0 + static constexpr std::size_t sender_channel_0_buffer_index_address = termination_signal_address + field_size; + static constexpr std::size_t sender_channel_0_worker_connection_info_address = + sender_channel_0_buffer_index_address + field_size; + static_assert(field_size >= sizeof(tt::fabric::EDMChannelWorkerLocationInfo)); + + // Sender Channel 1 + static constexpr std::size_t sender_channel_1_buffer_index_address = + sender_channel_0_worker_connection_info_address + field_size; + static constexpr std::size_t sender_channel_1_worker_connection_info_address = + sender_channel_1_buffer_index_address + field_size; + + // Channel Allocations + static constexpr std::size_t buffer_region_start = + (sender_channel_1_worker_connection_info_address + field_size + buffer_alignment) & ~(buffer_alignment - 1); // Align + static constexpr std::size_t available_channel_buffering_space = + eth_l1_mem::address_map::MAX_L1_LOADING_SIZE - buffer_region_start; + + FabricEriscDatamoverConfig( + std::size_t channel_buffer_size_bytes, std::size_t sender_ratio_size, std::size_t receiver_ratio_size); + + std::size_t channel_buffer_size_bytes; + std::size_t channel_buffer_size_bytes_with_channel_sync; + std::size_t sender_0_channel_size_bytes; + std::size_t sender_0_num_buffers; + std::size_t sender_1_channel_size_bytes; + std::size_t sender_1_num_buffers; + std::size_t receiver_channel_size_bytes; + std::size_t receiver_num_buffers; + + std::size_t sender_0_channel_base_address; + std::size_t sender_1_channel_base_address; + std::size_t receiver_channel_base_address; +}; + +struct SenderWorkerAdapterSpec { + size_t edm_worker_x; + size_t edm_worker_y; + size_t edm_buffer_base_addr; + size_t num_buffers_per_channel; + size_t edm_l1_sem_addr; + size_t edm_connection_handshake_addr; + size_t edm_worker_location_info_addr; // The EDM's location for `EDMChannelWorkerLocationInfo` + size_t buffer_size_bytes; +}; +class FabricEriscDatamoverBuilder { + public: + FabricEriscDatamoverBuilder( + size_t my_noc_x, + size_t my_noc_y, + size_t my_chip_id, + size_t peer_chip_id, + + std::optional receiver_channel_downstream_flow_control_semaphore_id, + size_t sender_channel_0_flow_control_semaphore_id, + size_t sender_channel_1_flow_control_semaphore_id, + size_t sender_channel_0_connection_semaphore_id, + size_t sender_channel_1_connection_semaphore_id, + + FabricEriscDatamoverConfig const& config); + + [[nodiscard]] SenderWorkerAdapterSpec 
build_connection_to_worker_channel() const; + [[nodiscard]] SenderWorkerAdapterSpec build_connection_to_fabric_channel() const; + + [[nodiscard]] std::vector emit_compile_time_args() const; + + [[nodiscard]] std::vector emit_runtime_args() const; + + void connect_to_downstream_edm( + FabricEriscDatamoverBuilder const& downstream_edm, uint32_t downstream_edm_semaphore_id); + + void dump_to_log() const { + // TODO + } + + private: + size_t my_noc_x; + size_t my_noc_y; + FabricEriscDatamoverConfig config; + + size_t my_chip_id; + size_t peer_chip_id; + size_t handshake_address; + size_t channel_buffer_size; + + size_t sender_0_num_buffers; + size_t sender_1_num_buffers; + size_t receiver_num_buffers; + + size_t local_sender_channel_0_buffer_address; + size_t local_sender_channel_0_connection_info_addr; + size_t local_sender_channel_1_buffer_address; + size_t local_sender_channel_1_connection_info_addr; + size_t local_receiver_channel_buffer_address; + + size_t termination_signal_ptr; + + // Semaphore IDs + // this is the receiver channel's local sem for flow controlling with downstream fabric sender + std::optional receiver_channel_downstream_flow_control_semaphore_id; + size_t sender_channel_0_flow_control_semaphore_id; + size_t sender_channel_1_flow_control_semaphore_id; + size_t sender_channel_0_connection_semaphore_id; + size_t sender_channel_1_connection_semaphore_id; + + std::optional downstream_edm_noc_x; + std::optional downstream_edm_noc_y; + std::optional downstream_edm_buffer_base_address; + std::optional downstream_edm_semaphore_address; + std::optional downstream_edm_worker_registration_address; + std::optional downstream_edm_worker_location_info_address; + std::optional downstream_noc_interface_buffer_index_addr; +}; + +}; // namespace ccl +}; // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp new file mode 100644 index 00000000000..720af76ed71 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp @@ -0,0 +1,195 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "dataflow_api.h" + +#include "tt_metal/hw/inc/ethernet/dataflow_api.h" +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp" + +#include "debug/assert.h" + +#include + + +namespace tt::fabric { + +struct WorkerToFabricEdmSender{ + WorkerToFabricEdmSender () : worker_sem_addr(nullptr) {} + + WorkerToFabricEdmSender ( + size_t edm_worker_x, + size_t edm_worker_y, + std::size_t edm_buffer_base_addr, + std::size_t num_buffers_per_channel, + std::size_t edm_l1_sem_id, + std::size_t edm_connection_handshake_addr, + std::size_t edm_worker_location_info_addr, // The EDM's location for `EDMChannelWorkerLocationInfo` + std::size_t buffer_size_bytes, + volatile uint32_t * const worker_sem_addr, + uint32_t buffer_index_addr + ) : + edm_buffer_addr(get_noc_addr(edm_worker_x, edm_worker_y, edm_buffer_base_addr)), + edm_semaphore_addr(get_noc_addr(edm_worker_x, edm_worker_y, get_semaphore(edm_l1_sem_id))), + edm_connection_handshake_addr(edm_connection_handshake_addr), + edm_worker_location_info_addr(edm_worker_location_info_addr), + worker_sem_addr(worker_sem_addr), + edm_buffer_base_addr(edm_buffer_base_addr), + num_buffers_per_channel(num_buffers_per_channel), + last_buffer_index(num_buffers_per_channel - 1), + edm_l1_sem_addr(get_semaphore(edm_l1_sem_id)), + buffer_size_bytes(buffer_size_bytes), + buffer_index_ptr(reinterpret_cast(buffer_index_addr)) + { + ASSERT(buffer_size_bytes > 0); + } + + [[nodiscard]] FORCE_INLINE bool consumer_has_space() const { + return *this->worker_sem_addr == 1; + } + FORCE_INLINE void clear_flow_control_semaphore() const { + noc_semaphore_set(this->worker_sem_addr, 0); + } + FORCE_INLINE void wait_for_empty_write_slot() const { + noc_semaphore_wait(this->worker_sem_addr, 1); + } + + FORCE_INLINE void send_payload_blocking(uint32_t cb_id, uint32_t num_pages, uint32_t page_size) { + send_payload_impl(cb_id, num_pages, page_size); + } + + // Does not wait for CB. Assumes caller handles CB data availability + FORCE_INLINE void send_payload_non_blocking(uint32_t cb_id, uint32_t num_pages, uint32_t page_size) { + send_payload_impl(cb_id, num_pages, page_size); + } + + /* + * No CB + */ + FORCE_INLINE void send_payload_blocking_from_address(uint32_t source_address, size_t size_bytes) { + send_payload_from_address_impl(source_address, size_bytes); + } + + /* + * No CB + */ + // Does not wait for CB. 
Assumes caller handles CB data availability
+    FORCE_INLINE void send_payload_non_blocking_from_address(uint32_t source_address, size_t size_bytes) {
+        send_payload_from_address_impl(source_address, size_bytes);
+    }
+
+    // Layout
+    // |-----------------------|
+    // | EDM Handshake         | 16B
+    // |-----------------------|
+    // | EDM Ack Channel Sync  | 16B
+    // |-----------------------| --\
+    // | Connection Semaphore  | 16B  |
+    // |-----------------------|      |
+    // | Buffer Index          | 16B   >- Per Sender Channel (On EDM)
+    // |-----------------------|      |
+    // | Worker Connection Info| 16B  |
+    // |-----------------------| --/
+    // |-----------------------|
+    //
+    static constexpr size_t edm_sender_channel_field_stride_bytes = 16;
+
+    FORCE_INLINE void open() {
+        auto dest_addr = this->edm_semaphore_addr;
+        static constexpr uint32_t open_connection_value = 1;
+        // May need to force the buffer index to be a semaphore address.
+        // Keep the noc x/y portion of the address and swap in the EDM's buffer index address
+        // (stored one field past the connection handshake address).
+        dest_addr &= ~0x0000000FFFFFFFFFl;
+        uint64_t remote_buffer_index_addr = dest_addr | (edm_connection_handshake_addr + edm_sender_channel_field_stride_bytes);
+        ASSERT(remote_buffer_index_addr > 0);
+        noc_async_read(remote_buffer_index_addr, reinterpret_cast(this->buffer_index_ptr), sizeof(uint32_t));
+
+        ASSERT(edm_worker_location_info_addr == edm_connection_handshake_addr + 2 * edm_sender_channel_field_stride_bytes);
+        dest_addr &= ~0x0000000FFFFFFFFFl;
+        dest_addr |= edm_worker_location_info_addr;
+        // TODO: Need to change byte enable to be word enable
+        noc_inline_dw_write(dest_addr, reinterpret_cast(worker_sem_addr));
+        noc_inline_dw_write(dest_addr + sizeof(uint32_t), ttnn::ccl::WorkerXY(my_x[0], my_y[0]).to_uint32());
+
+        dest_addr &= ~0x0000000FFFFFFFFFl;
+        dest_addr |= edm_connection_handshake_addr;
+        noc_inline_dw_write(dest_addr, open_connection_value);
+        noc_async_read_barrier();
+    }
+
+    FORCE_INLINE void close() {
+        auto dest_addr = this->edm_semaphore_addr;
+        static constexpr uint32_t terminate_connection_value = 0;
+        // remove the address portion and replace it with the connection handshake address
+        dest_addr &= ~0x0000000FFFFFFFFFl;
+        dest_addr |= edm_connection_handshake_addr;
+        noc_inline_dw_write(dest_addr, terminate_connection_value);
+
+        // the buffer index is stored at the location just after the handshake addr
+        dest_addr &= ~0x0000000FFFFFFFFFl;
+        dest_addr |= edm_connection_handshake_addr + edm_sender_channel_field_stride_bytes;
+        noc_inline_dw_write(dest_addr, *this->buffer_index_ptr);
+        noc_async_write_barrier();
+    }
+
+    uint64_t edm_buffer_addr;
+    uint64_t edm_semaphore_addr;
+    size_t edm_connection_handshake_addr;
+    size_t edm_worker_location_info_addr;
+    volatile uint32_t * const worker_sem_addr;
+    std::size_t edm_buffer_base_addr;
+    std::size_t num_buffers_per_channel;
+    std::size_t last_buffer_index;
+    std::size_t edm_l1_sem_addr;
+    std::size_t buffer_size_bytes;
+    std::size_t *buffer_index_ptr;
+
+   private:
+    template
+    FORCE_INLINE void send_payload_from_address_impl(uint32_t source_address, size_t size_bytes) {
+        this->clear_flow_control_semaphore();
+        uint64_t buffer_address = this->edm_buffer_addr + (*this->buffer_index_ptr * (this->buffer_size_bytes + sizeof(eth_channel_sync_t)));
+
+        ASSERT(size_bytes <= this->buffer_size_bytes);
+        ASSERT(static_cast<size_t>(buffer_address & 0x0FFFFFFF) <= 270000);
+
+        /*{ // For debug purposes only. Useful to permanently backup the packet somewhere we can inspect with ttx-status
+            uint32_t dram_noc_x = my_y[0] == 1 ? 0 : 0;
+            uint32_t dram_noc_y = my_y[0] == 1 ?
0 : 5; + // noc_inline_dw_write(get_noc_addr(dram_noc_x, dram_noc_y, storage_offset), 0x0F); + // noc_async_writes_flushed(); + // noc_inline_dw_write(get_noc_addr(dram_noc_x, dram_noc_y, storage_offset + 4), 0); + // auto pkthdr_size_words = sizeof(tt::fabric::PacketHeader) >> 2; + // for (size_t i = 0; i < pkthdr_size_words; i++) { + // reinterpret_cast(source_address)[pkthdr_size_words - i] = + // reinterpret_cast(source_address)[pkthdr_size_words - 1 - i]; + // } + // reinterpret_cast(source_address)[0] = 0xc0ffee; + // DPRINT << "NEXT STORAGE OFF: " << (uint32_t)storage_offset << "\n"; + noc_async_write(source_address, get_noc_addr(dram_noc_x, dram_noc_y, storage_offset), size_bytes); + storage_offset += size_bytes; + storage_offset += 64; + storage_offset = storage_offset & (~0x1F); + }*/ + + send_chunk_from_address(source_address, 1, size_bytes, buffer_address); + noc_semaphore_inc(edm_semaphore_addr, 1); + + *this->buffer_index_ptr = (*this->buffer_index_ptr == this->last_buffer_index) ? 0 : *this->buffer_index_ptr + 1; + } + + template + FORCE_INLINE void send_payload_impl(uint32_t cb_id, uint32_t num_pages, uint32_t page_size) { + this->clear_flow_control_semaphore(); + uint64_t buffer_address = this->edm_buffer_addr + (*this->buffer_index_ptr * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); + ASSERT(num_pages * page_size <= this->buffer_size_bytes); + send_chunk(cb_id, num_pages, page_size, buffer_address); + noc_semaphore_inc(edm_semaphore_addr, 1); + *this->buffer_index_ptr = (*this->buffer_index_ptr == this->last_buffer_index) ? 0 : *this->buffer_index_ptr + 1; + } +}; + + +} // namespace tt::fabric diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp new file mode 100644 index 00000000000..37210c2d012 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp @@ -0,0 +1,214 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+
+namespace tt::fabric {
+
+enum TerminationSignal : uint32_t {
+    KEEP_RUNNING = 0,
+
+    // Wait for messages to drain
+    GRACEFULLY_TERMINATE = 1,
+
+    // Immediately terminate - don't wait for any outstanding messages to arrive or drain out
+    IMMEDIATELY_TERMINATE = 2
+};
+
+// 2 bits
+enum CommandType : uint8_t {
+    WRITE = 0,
+    ATOMIC_INC = 1
+};
+
+// How to send the payload across the cluster
+// 1 bit
+enum ChipSendType : uint8_t {
+    CHIP_UNICAST = 0,
+    CHIP_MULTICAST = 1
+};
+enum NocSendType : uint8_t {
+    NOC_UNICAST = 0,
+    NOC_MULTICAST = 1
+};
+
+
+struct UnicastRoutingCommandHeader {
+    uint8_t distance_in_hops;
+};
+static_assert(sizeof(UnicastRoutingCommandHeader) == 1, "UnicastRoutingCommandHeader size is not 1 byte");
+struct MulticastRoutingCommandHeader {
+    uint8_t start_distance_in_hops: 4;
+    uint8_t range_hops: 4; // 0 implies unicast
+};
+static_assert(sizeof(MulticastRoutingCommandHeader) == 1, "MulticastRoutingCommandHeader size is not 1 byte");
+union RoutingFields {
+    UnicastRoutingCommandHeader chip_unicast;
+    MulticastRoutingCommandHeader chip_mcast;
+};
+static_assert(sizeof(RoutingFields) == sizeof(UnicastRoutingCommandHeader), "RoutingFields size is not 1 byte");
+
+struct NocUnicastCommandHeader {
+    uint32_t address;
+    uint32_t size;
+    uint8_t noc_x;
+    uint8_t noc_y;
+    uint16_t reserved;
+    // ignores header size
+    inline uint32_t get_payload_only_size() const {
+        return size;
+    }
+};
+struct NocUnicastAtomicIncCommandHeader {
+    NocUnicastAtomicIncCommandHeader(uint32_t address, uint16_t val, uint16_t wrap, uint8_t noc_x, uint8_t noc_y)
+        : address(address), val(val), wrap(wrap), noc_x(noc_x), noc_y(noc_y) {}
+
+    uint32_t address;
+    uint16_t val;
+    uint16_t wrap;
+    uint8_t noc_x;
+    uint8_t noc_y;
+
+};
+struct NocMulticastCommandHeader {
+    uint32_t address;
+    uint32_t size;
+    uint8_t noc_x_start;
+    uint8_t noc_y_start;
+    uint8_t mcast_rect_size_x;
+    uint8_t mcast_rect_size_y;
+
+    // ignores header size
+    inline uint32_t get_payload_only_size() const {
+        return size;
+    }
+};
+struct NocMulticastAtomicIncCommandHeader {
+    uint32_t address;
+    uint16_t val;
+    uint16_t wrap;
+    uint8_t noc_x_start;
+    uint8_t noc_y_start;
+    uint8_t size_x;
+    uint8_t size_y;
+};
+static_assert(sizeof(NocUnicastCommandHeader) == 12, "NocUnicastCommandHeader size is not 12 bytes");
+static_assert(sizeof(NocMulticastCommandHeader) == 12, "NocMulticastCommandHeader size is not 12 bytes");
+static_assert(sizeof(NocUnicastAtomicIncCommandHeader) == 12, "NocUnicastAtomicIncCommandHeader size is not 12 bytes");
+static_assert(sizeof(NocMulticastAtomicIncCommandHeader) == 12, "NocMulticastAtomicIncCommandHeader size is not 12 bytes");
+union CommandFields {
+    NocUnicastCommandHeader unicast_write;
+    NocMulticastCommandHeader mcast_write;
+    NocUnicastAtomicIncCommandHeader unicast_seminc;
+    NocMulticastAtomicIncCommandHeader mcast_seminc;
+};
+static_assert(sizeof(CommandFields) <= 15, "CommandFields size exceeds 15 bytes");
+
+// TODO: wrap this in a debug version that holds type info so we can assert on field/command mismatches
+struct PacketHeader {
+    // TODO: trim this down; noc_send_type could be 2 bits (4 values):
+    //   -> unicast_write, mcast_write, unicast_seminc, mcast_seminc
+    // For now they are kept separate so reads could be handled differently,
+    // but for our purposes we shouldn't need reads, so we should be able to omit that support.
+    CommandType command_type : 2;
+    ChipSendType chip_send_type : 1;
+    NocSendType noc_send_type : 1;
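+    // remaining bits of the first byte are reserved for future command/routing flags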
+    uint8_t reserved : 4;
+
+    RoutingFields routing_fields;
+    uint16_t reserved2;
+    CommandFields command_fields;
+
+    // Something of a hack to work around DRAM reads needing to be 32B aligned.
+    // To simplify worker kernel code, for now we pad the packet header up to 32B so the
+    // user can simply shift into their CB chunk by sizeof(tt::fabric::PacketHeader)
+    // and automatically work around the DRAM read alignment bug.
+    //
+    // Future changes will remove this padding and require the worker kernel to be aware of the bug
+    // and conditionally pad its own CBs when reading from DRAM. It'll be up to users to
+    // manage this complexity.
+    uint32_t padding0;
+    uint32_t padding1;
+    uint32_t padding2;
+    uint32_t padding3;
+
+    inline void set_command_type(CommandType &type) { this->command_type = type; }
+    inline void set_chip_send_type(ChipSendType &type) { this->chip_send_type = type; }
+    inline void set_noc_send_type(NocSendType &type) { this->noc_send_type = type; }
+    inline void set_routing_fields(RoutingFields &fields) { this->routing_fields = fields; }
+    inline void set_command_fields(CommandFields &fields) { this->command_fields = fields; }
+
+    size_t get_payload_size_excluding_header() volatile const {
+        switch(this->command_type) {
+            case WRITE: {
+                switch(this->noc_send_type) {
+                    case NOC_UNICAST: {
+                        return this->command_fields.unicast_write.size - sizeof(PacketHeader);
+                    } break;
+                    case NOC_MULTICAST: {
+                        return this->command_fields.mcast_write.size - sizeof(PacketHeader);
+                    } break;
+                    default:
+                        return 0;
+                }
+            } break;
+            case ATOMIC_INC: {
+                return 0;
+            } break;
+            default:
+                return 0;
+        }
+    }
+    inline size_t get_payload_size_including_header() volatile const {
+        return get_payload_size_excluding_header() + sizeof(PacketHeader);
+    }
+
+    inline PacketHeader& to_write() { this->command_type = WRITE; return *this; }
+    inline PacketHeader& to_atomic_inc() { this->command_type = ATOMIC_INC; return *this; }
+
+    inline PacketHeader &to_chip_unicast(UnicastRoutingCommandHeader const &chip_unicast_command_header) {
+        this->chip_send_type = CHIP_UNICAST;
+        this->routing_fields.chip_unicast = chip_unicast_command_header;
+        return *this;
+    }
+    inline PacketHeader &to_chip_multicast(MulticastRoutingCommandHeader const &chip_multicast_command_header) {
+        this->chip_send_type = CHIP_MULTICAST;
+        this->routing_fields.chip_mcast = chip_multicast_command_header;
+        return *this;
+    }
+    inline PacketHeader &to_noc_unicast(NocUnicastCommandHeader const &noc_unicast_command_header) {
+        this->noc_send_type = NOC_UNICAST;
+        this->command_fields.unicast_write = noc_unicast_command_header;
+        return *this;
+    }
+    inline PacketHeader &to_noc_multicast(NocMulticastCommandHeader const &noc_multicast_command_header) {
+        this->noc_send_type = NOC_MULTICAST;
+        this->command_fields.mcast_write = noc_multicast_command_header;
+        return *this;
+    }
+    inline PacketHeader &to_noc_unicast_atomic_inc(
+        NocUnicastAtomicIncCommandHeader const &noc_unicast_atomic_inc_command_header) {
+        this->noc_send_type = NOC_UNICAST;
+        this->command_fields.unicast_seminc = noc_unicast_atomic_inc_command_header;
+        return *this;
+    }
+    inline PacketHeader &to_noc_multicast_atomic_inc(
+        NocMulticastAtomicIncCommandHeader const &noc_multicast_atomic_inc_command_header) {
+        this->noc_send_type = NOC_MULTICAST;
+        this->command_fields.mcast_seminc = noc_multicast_atomic_inc_command_header;
+        return *this;
+    }
+};
+
+
+// TODO: When we remove the 32B padding requirement, reduce to a 16B size check
+static_assert(sizeof(PacketHeader) == 32, "sizeof(PacketHeader) is not equal to 32B");
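+
+// Example (editor's sketch, not part of this change): populating a header for a
+// single-hop unicast write using the chainable setters above. `packet_l1_addr`,
+// `dest_addr`, `dest_noc_x`, `dest_noc_y`, and `payload_bytes` are hypothetical
+// names; note that the `size` field counts the header as well as the payload.
+//
+//     auto *header = reinterpret_cast<tt::fabric::PacketHeader *>(packet_l1_addr);
+//     header->to_write()
+//         .to_chip_unicast(tt::fabric::UnicastRoutingCommandHeader{1})
+//         .to_noc_unicast(tt::fabric::NocUnicastCommandHeader{
+//             dest_addr, payload_bytes + sizeof(tt::fabric::PacketHeader),
+//             dest_noc_x, dest_noc_y, 0});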
+
+static constexpr size_t header_size_bytes = sizeof(PacketHeader);
+
+
+} // namespace tt::fabric
diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header_validate.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header_validate.hpp
new file mode 100644
index 00000000000..22267eb2bdb
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header_validate.hpp
@@ -0,0 +1,18 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp"
+#include "debug/assert.h"
+
+namespace tt::fabric {
+
+FORCE_INLINE void validate(PacketHeader const& packet_header) {
+    ASSERT(packet_header.command_type < 2);
+    ASSERT(packet_header.chip_send_type < 2);
+    ASSERT(packet_header.noc_send_type < 2);
+}
+
+} // namespace tt::fabric
diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp
new file mode 100644
index 00000000000..9e6ba23c4b1
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp
@@ -0,0 +1,228 @@
+
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "tt_metal/hw/inc/dataflow_api.h"
+#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp"
+#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp"
+#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_types.hpp"
+#include <cstdint>
+
+void write_unicast_blocking(uint32_t local_address, uint64_t dest_address, uint32_t size_bytes) {
+    noc_async_write(local_address, dest_address, size_bytes);
+    noc_async_writes_flushed();
+}
+
+void print_pkt_hdr_routing_fields(volatile tt::fabric::PacketHeader *const packet_start) {
+    switch (packet_start->chip_send_type) {
+        case tt::fabric::CHIP_UNICAST: {
+            DPRINT << "C_UNI: dist:" << (uint32_t) packet_start->routing_fields.chip_unicast.distance_in_hops << "\n";
+            break;
+        }
+        case tt::fabric::CHIP_MULTICAST: {
+            DPRINT << "C_MCST: dist:" << (uint32_t) packet_start->routing_fields.chip_mcast.start_distance_in_hops <<
+                ", rng:" << (uint32_t) packet_start->routing_fields.chip_mcast.range_hops << "\n";
+            break;
+        }
+    };
+}
+
+void print_pkt_header_noc_fields(volatile tt::fabric::PacketHeader *const packet_start) {
+    switch (packet_start->noc_send_type) {
+        case tt::fabric::NocSendType::NOC_UNICAST: {
+            switch (packet_start->command_type) {
+                case tt::fabric::CommandType::WRITE: {
+                    DPRINT << "N_WR addr:"<<(uint32_t)packet_start->command_fields.unicast_write.address <<
+                        ", size:" << (uint32_t) packet_start->command_fields.unicast_write.size <<
+                        ", x:" << (uint32_t) packet_start->command_fields.unicast_write.noc_x <<
+                        ", y:" << (uint32_t) packet_start->command_fields.unicast_write.noc_y << "\n";
+                } break;
+                case tt::fabric::CommandType::ATOMIC_INC: {
+                    DPRINT << "N_INC addr:"<<(uint32_t)packet_start->command_fields.unicast_seminc.address <<
+                        ", val:" << (uint32_t) packet_start->command_fields.unicast_seminc.val <<
+                        ", x:" << (uint32_t) packet_start->command_fields.unicast_seminc.noc_x <<
+                        ", y:" << (uint32_t) packet_start->command_fields.unicast_seminc.noc_y << "\n";
+
+                } break;
+            }
+            break;
+        }
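+        // NOC multicast header fields are not printed yet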
+        case tt::fabric::NocSendType::NOC_MULTICAST: {
+            break;
+        }
+    }
+}
+
+void print_pkt_header(volatile tt::fabric::PacketHeader *const packet_start) {
+    auto const& header = *packet_start;
+    DPRINT << "PKT: cmd_t:" << (uint32_t) packet_start->command_type <<
+        ", csnd_t:" << (uint32_t) packet_start->chip_send_type <<
+        ", nsnd_t:" << (uint32_t) packet_start->noc_send_type << "\n";
+    print_pkt_hdr_routing_fields(packet_start);
+    print_pkt_header_noc_fields(packet_start);
+}
+
+
+// Since we are unicasting to the local chip, we must omit the packet header when writing the payload
+void execute_chip_unicast_to_local_chip(volatile tt::fabric::PacketHeader *const packet_start) {
+    auto const& header = *packet_start;
+    uint32_t payload_start_address = reinterpret_cast<size_t>(packet_start) + sizeof(tt::fabric::PacketHeader);
+
+    tt::fabric::CommandType command_type = packet_start->command_type;
+    tt::fabric::NocSendType noc_send_type = packet_start->noc_send_type;
+    switch (command_type) {
+        case tt::fabric::CommandType::WRITE: {
+            switch (noc_send_type) {
+                case tt::fabric::NocSendType::NOC_UNICAST: {
+                    auto const dest_address = get_noc_addr(
+                        header.command_fields.unicast_write.noc_x,
+                        header.command_fields.unicast_write.noc_y,
+                        header.command_fields.unicast_write.address);
+                    auto const size = header.command_fields.unicast_write.size - sizeof(tt::fabric::PacketHeader);
+                    write_unicast_blocking(payload_start_address, dest_address, size);
+
+                } break;
+                case tt::fabric::NocSendType::NOC_MULTICAST: {
+                    // TODO: confirm if we need to adjust the dest core count if we span eth or dram cores
+                    auto const mcast_dest_address = get_noc_multicast_addr(
+                        header.command_fields.mcast_write.noc_x_start,
+                        header.command_fields.mcast_write.noc_y_start,
+                        header.command_fields.mcast_write.noc_x_start + header.command_fields.mcast_write.mcast_rect_size_x,
+                        header.command_fields.mcast_write.noc_y_start + header.command_fields.mcast_write.mcast_rect_size_y,
+                        header.command_fields.mcast_write.address);
+                    auto const num_dests = header.command_fields.mcast_write.mcast_rect_size_x * header.command_fields.mcast_write.mcast_rect_size_y;
+                    auto const size = header.command_fields.mcast_write.size - sizeof(tt::fabric::PacketHeader);
+                    noc_async_write_multicast_one_packet(payload_start_address, mcast_dest_address, size, num_dests);
+                    noc_async_writes_flushed();
+
+                } break;
+                default: {
+                    ASSERT(false);
+                }
+            }
+            break;
+        }
+        case tt::fabric::CommandType::ATOMIC_INC: {
+            switch (noc_send_type) {
+                case tt::fabric::NocSendType::NOC_UNICAST: {
+                    auto const dest_address = get_noc_addr(
+                        header.command_fields.unicast_seminc.noc_x,
+                        header.command_fields.unicast_seminc.noc_y,
+                        header.command_fields.unicast_seminc.address);
+                    auto const increment = header.command_fields.unicast_seminc.val;
+                    noc_semaphore_inc(dest_address, increment);
+
+                } break;
+                case tt::fabric::NocSendType::NOC_MULTICAST: {
+                    ASSERT(false);
+                    // noc_async_write(payload_start_address, header.dest_address, header.size_bytes);
+
+                } break;
+                default: {
+                    ASSERT(false);
+                }
+            }
+            break;
+        }
+
+        default: {
+            ASSERT(false);
+        }
+    };
+}
+
+
+
+void update_packet_header_for_next_hop(volatile tt::fabric::PacketHeader * packet_header) {
+    switch (packet_header->chip_send_type) {
+        case tt::fabric::CHIP_UNICAST: {
+            packet_header->routing_fields.chip_unicast.distance_in_hops--;
+        } break;
+        case tt::fabric::CHIP_MULTICAST: {
+            if (packet_header->routing_fields.chip_mcast.start_distance_in_hops == 0) {
+                packet_header->routing_fields.chip_mcast.range_hops--;
+            } else {
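+                // still approaching the start of the mcast range: consume a hop from the start distance first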
+                packet_header->routing_fields.chip_mcast.start_distance_in_hops--;
+            }
+        } break;
+    }
+}
+
+// This function forwards a packet to the downstream EDM channel for eventual sending
+// to the next chip in the line/ring.
+//
+// Modifies the packet header (decrements hop counts), so the consume-locally-vs-forward
+// routing decision must already have been made before calling it.
+//
+// !!!WARNING!!!
+// !!!WARNING!!! do NOT call before determining if the packet should be consumed locally or forwarded
+// !!!WARNING!!!
+tt::fabric::SendStatus forward_payload_to_downstream_edm(
+    volatile tt::fabric::PacketHeader *packet_header,
+    tt::fabric::WorkerToFabricEdmSender &downstream_edm_interface
+    ) {
+    // SHOULD BE ABLE TO ASSERT ON THIS SINCE WE CHECK FOR THIS IN THE CALLER
+    // TODO: PERF
+    bool safe_to_send = downstream_edm_interface.consumer_has_space();
+    if (!safe_to_send) {
+        return tt::fabric::SendStatus::NOT_SENT;
+    }
+
+    // print_pkt_header(packet_header);
+    ASSERT(const_cast<tt::fabric::PacketHeader *>(packet_header)->get_payload_size_including_header() < 100000);
+    update_packet_header_for_next_hop(packet_header);
+
+    downstream_edm_interface.send_payload_blocking_from_address(
+        reinterpret_cast<uint32_t>(packet_header),
+        const_cast<tt::fabric::PacketHeader *>(packet_header)->get_payload_size_including_header());
+
+    return tt::fabric::SendStatus::SENT_PAYLOAD_AND_SYNC;
+}
+
+void execute_chip_multicast_to_local_chip(volatile tt::fabric::PacketHeader *const packet_start) {
+    ASSERT(false);
+}
+
+bool packet_must_be_consumed_locally(tt::fabric::PacketHeader const& packet_header) {
+    switch (packet_header.chip_send_type) {
+        case tt::fabric::ChipSendType::CHIP_UNICAST: {
+            // TODO: does it make more sense to have 0 as the terminating distance or 1?
+            // depends where we want to do the decrement and what the starting value
+            // is expected to be for the worker.
+            // Maybe at the API level we just always decrement by 1 under the hood
+            // so the user can call `fabric_send_packet(payload_addr, size, n_hops=1)`
+            return packet_header.routing_fields.chip_unicast.distance_in_hops == 0;
+        }
+        case tt::fabric::ChipSendType::CHIP_MULTICAST: {
+            return packet_header.routing_fields.chip_mcast.start_distance_in_hops == 0;
+        }
+        default: {
+            ASSERT(false);
+            return false;
+        }
+    }
+}
+
+
+bool packet_must_be_forwarded_to_next_chip(tt::fabric::PacketHeader const& packet_header) {
+    switch (packet_header.chip_send_type) {
+        case tt::fabric::ChipSendType::CHIP_UNICAST: {
+            // TODO: does it make more sense to have 0 as the terminating distance or 1?
+            // depends where we want to do the decrement and what the starting value
+            // is expected to be for the worker.
+            // Maybe at the API level we just always decrement by 1 under the hood
+            // so the user can call `fabric_send_packet(payload_addr, size, n_hops=1)`
+            return packet_header.routing_fields.chip_unicast.distance_in_hops != 0;
+        }
+        case tt::fabric::ChipSendType::CHIP_MULTICAST: {
+            return packet_header.routing_fields.chip_mcast.range_hops != 0;
+        }
+        default: {
+            ASSERT(false);
+            return false;
+        }
+    }
+}
diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_types.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_types.hpp
new file mode 100644
index 00000000000..2366c8758de
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_types.hpp
@@ -0,0 +1,56 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp"
+#include <cstdint>
+
+namespace tt::fabric {
+enum BlockingMode: uint8_t {
+    // will busy-wait until the send can proceed; does not context switch
+    BUSY_WAIT_BLOCKING,
+
+    // will wait and allow context switching
+    CTX_SWITCH_BLOCKING,
+
+    // function will exit early if not able to send
+    NON_BLOCKING
+};
+
+enum SendStatus : uint8_t {
+    // Indicates that the sender was able to send the payload
+    // but was not able to send the channel_sync_t at the end of the
+    // buffer
+    //
+    // This enum should only ever be returned if we are sending less than
+    // a full packet/buffer of data AND when we are trying to send the
+    // channel_sync_t at the end of the buffer (which must be sent as a separate
+    // command) but the eth_tx_cmd_q is busy for that second message
+    //
+    // Receiving this value indicates we
+    // MUST:
+    //   - Eventually send the channel_sync_t before advancing to the next buffer
+    // MUST NOT:
+    //   - Advance to the next buffer index
+    //   - Forward the other sender channel's data (if it has any)
+    SENT_PAYLOAD_ONLY,
+
+    // Indicates both the payload and the channel sync were sent successfully
+    SENT_PAYLOAD_AND_SYNC,
+
+    // Indicates no data was sent because the eth_tx_cmd_q was busy
+    NOT_SENT,
+
+    ERROR,
+};
+
+struct EDMChannelWorkerLocationInfo {
+    uint32_t worker_semaphore_address;
+    ttnn::ccl::WorkerXY worker_xy;
+};
+
+static_assert(sizeof(EDMChannelWorkerLocationInfo) <= 16);
+
+} // namespace tt::fabric
diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover.cpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover.cpp
new file mode 100644
index 00000000000..d105e2bf6d0
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover.cpp
@@ -0,0 +1,881 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include <array>
+#include <cstddef>
+#include <cstdint>
+
+#include "dataflow_api.h"
+#include "tt_metal/hw/inc/ethernet/dataflow_api.h"
+#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm/edm_handshake.hpp"
+#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp"
+#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp"
+#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header_validate.hpp"
+#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp"
+#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp"
+#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp"
+
+using ttnn::ccl::WorkerXY;
+
+/*
+
+The fabric Erisc Data Mover (EDM) is a component that can be used to build *very* simple linear topology fabrics.
+One of these EDMs can be instantiated on each ethernet link. It is built from 3 "channels" (though the definition
+of "channel" here is a little loose, since two of the three merge traffic, so this setup could also be interpreted
+as a two-channel setup). This EDM implements packet-based transfers only - concepts like sockets are not supported.
+
+## EDM Structure
+
+There are two sender channels and one receiver channel. "Sender" and "receiver" are relative to the Ethernet link,
+not the chip: the sender sends over the link and the receiver receives from the link.
+
+Each sender channel serves a different purpose:
+- Sender channel 0: accepts packets from workers on the local chip
+- Sender channel 1: accepts packets from an upstream EDM (i.e. an upstream
+  EDM receiver channel on the same chip but different core)
+
+The receiver channel accepts packets from the Ethernet link and can do one (or both) of:
+- Write the packet to the local chip if it is the intended destination (unicast or mcast)
+- Forward the packet to the next chip in the line if:
+  - Unicast and not the target chip
+  - Multicast and this chip is in the multicast target range
+
+Sender channels will merge traffic into the remote EDM's receiver channel.
+
+Below is a diagram that shows how EDMs can be connected over an ethernet link. In this case, the two
+EDM kernels are run on separate, but connected ethernet link cores.
+
+   ┌───────────────────────┐      ┌───────────────────────┐
+   │    Sender Channel 0   │      │    Receiver Channel   │
+   │   ┌────────────────┐  │      │   ┌────────────────┐  │
+   │   │                ┼──┼──┬───┼───►                │  │
+   │   │                │  │  │   │   │                │  │
+   │   └────────────────┘  │  │   │   └────────────────┘  │
+   │    Sender Channel 1   │  │   │    Sender Channel 1   │
+   │   ┌────────────────┐  │  │   │   ┌────────────────┐  │
+   │   │                ┼──┼──┘   │   │                │  │
+   │   │                │  │   ┌──┼───┼                │  │
+   │   └────────────────┘  │   │  │   └────────────────┘  │
+   │    Receiver Channel   │   │  │    Sender Channel 0   │
+   │   ┌────────────────┐  │   │  │   ┌────────────────┐  │
+   │   │                │  │   │  │   │                │  │
+   │   │             ◄──┼──┼───┴──┼───┼                │  │
+   │   └────────────────┘  │      │   └────────────────┘  │
+   │                       │      │                       │
+   │                       │      │                       │
+   └───────────────────────┘      └───────────────────────┘
+
+
+## Building a "Fabric"
+
+At present, only linear topologies are supported, with one EDM per ethernet link along a given line.
+Below shows the intended connectivity of EDMs across chips in a hypothetical 3-chip fabric. For longer
+lines, the pattern would be extended.
+
+        CHIP 0                           CHIP 1                           CHIP 2
+     ┌─────────────────┐              ┌─────────────────┐              ┌─────────────────┐
+     │                 │              │                 │              │                 │
+ ┌────┴─────┐  ▲  ┌─────┴────┐    ┌────┴─────┐  ▲  ┌─────┴────┐    ┌────┴─────┐  ▲  ┌─────┴────┐
+ │   EDM    │  │  │   EDM    │    │   EDM    │  │  │   EDM    │    │   EDM    │  │  │   EDM    │
+ │ ┌──────┐ │  │  │ ┌──────┐ │    │ ┌──────┐ │  │  │ ┌──────┐ │    │ ┌──────┐ │  │  │ ┌──────┐ │
+ │ │  Rx  ┼─┼──┴──┼─►  S1  ┼─┼─┬──┼─►  Rx  ┼─┼──┴──┼─►  S1  ┼─┼─┬──┼─►  Rx  ┼─┼─┘│  │ │  S1  │ │
+ │ └──────┘ │     │ └──────┘ │ │  │ └──────┘ │     │ └──────┘ │ │  │ └──────┘ │  │  │ └──────┘ │
+ │ ┌──────┐ │     │ ┌──────┐ │ │  │ ┌──────┐ │     │ ┌──────┐ │ │  │ ┌──────┐ │  │  │ ┌──────┐ │
+ │ │  S0  ◄─┼──┬──┼─►  S0  ┼─┼─┘ ┌┼─┼  S0  ◄─┼──┬──┼─►  S0  ┼─┼─┘ ┌┼─┼  S0  ◄─┼──┬──┼─►  S0  │ │
+ │ └──────┘ │  │  │ └──────┘ │   ││ └──────┘ │  │  │ └──────┘ │   ││ └──────┘ │  │  │ └──────┘ │
+ │ ┌──────┐ │  │  │ ┌──────┐ │   ││ ┌──────┐ │  │  │ ┌──────┐ │   ││ ┌──────┐ │  │  │ ┌──────┐ │
+ │ │  S1  │ │  │ ┌┼─┼  Rx  ◄─┼───┴┼─┼  S1  ◄─┼─┐│ ┌┼─┼  Rx  ◄─┼───┴┼─┼  S1  ◄─┼─┐│ ┌┼─┼  Rx  │ │
+ │ └──────┘ │  │ ││ └──────┘ │    │ └──────┘ │ └┼─┤│ └──────┘ │    │ └──────┘ │ └┼─┤│ └──────┘ │
+ └────┬─────┘  │ │└─────┬────┘    └────┬─────┘  │ │└─────┬────┘    └────┬─────┘  │ │└─────┬────┘
+      │        ▼ │      │              │        ▼ │      │              │        ▼ │      │
+      └──────────┘      └──────────────┘          └──────┘...           └──────────┘      └─...
+
+
+## Connecting Workers to Channels
+
+As mentioned, only one worker can push to a given EDM sender channel at a time. In order to send to an EDM
+sender channel, the worker must establish a connection. The connection protocol is as follows and is started
+by the worker (the EDM is a slave in this protocol).
+
+*NOTE*: If multiple workers try to connect to the same EDM sender channel at the same time, the behavior is undefined.
+*NOTE*: Additionally, if a worker pushes packets to a channel it isn't connected to, behaviour is undefined.
+*NOTE*: Undefined == likely hang
+
+The `WorkerToFabricEdmSender` from `ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp`
+provides an implementation of the connection protocol. `WorkerToFabricEdmSender` also acts as a wrapper around that
+protocol so workers can simply call `open()` to execute the connection protocol without having to manually reimplement
+it for each kernel.
+
+### Protocol
+Worker:
+- Read from the EDM sender channel's buffer_index address
+  - Required so that the worker knows where to write its first packet (since the channel may already contain packets from
+    a previous connection)
+- Write worker core X/Y (NOC 0 based)
+- Write worker flow control semaphore L1 address
+
+EDM Sender Channel:
+- Check the local connection valid semaphore for a newly established connection
+  - When the connection semaphore indicates an active connection, the channel assumes all other relevant fields were
+    correctly populated by the worker:
+    - Worker core_x (on NOC 0)
+    - Worker core_y (on NOC 0)
+    - Worker flow control semaphore L1 address
+
+
+## Tearing Down Connections
+
+Every worker is required to explicitly tear down its connection with the EDM before terminating. To do this, the worker
+must simply write a `0` to the EDM sender channel's connection semaphore address. As long as the worker has sent all
+of its packets to the EDM before this, then the EDM will guarantee to forward the messages correctly.
+
+At this point, it is safe for another kernel to establish a connection.
+
+## Packet Structure
+
+Workers are responsible for populating packet headers before sending to the EDM. The packet header structure is defined
+in `ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp`.
+
+## Channel structure
+
+Each EDM channel is built from one or more buffers. Each buffer is the same size and can hold at most one packet.
+Neighbouring packets occupy neighbouring buffers - with the exception of the last buffer index: the next packet after a write
+into the last buffer index will wrap around to the first buffer index. Even if packets do not occupy the full buffer, subsequent
+packets will always be written into the next logical buffer. A gap will exist in memory, but the EDM will not send that padded
+data (unless it is more performant - which is possible in some special cases).
+
+ Example channel with 8 buffers
+┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
+│       │       │       │       │       │       │       │       │
+│       │       │       │       │       │       │       │       │
+└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘
+ buf 0   buf 1   buf 2   buf 3   buf 4   buf 5   buf 6   buf 7
+
+
+Here we have an example of a channel with 4 buffers, filled with some number of packets. Each packet is a different size.
+Packets 0, 2, and 3 are smaller than the full buffer size, while packet 1 is the full buffer size.
+
+┌───────────────┬───────────────┬───────────────┬───────────────┐
+│H|Payload| / / │H|Payload      │H|Pyld| / / / /│H|Payload  |/ /│
+│ |       |/ / /│ |             │ |    |/ / / / │ |         | / │
+└───────────────┴───────────────┴───────────────┴───────────────┘
+  buf 0           buf 1           buf 2           buf 3
+
+
+A detail of the channel structure is omitted from the above diagram, namely the EDM <-> EDM flow control region for each buffer.
+Each buffer really looks something like this:
+
+
+             &header->  |----------------| channel_base_address
+                        |    header      |
+            &payload->  |----------------|
+                        |                |
+                        |    payload     |
+                        |                |
+       &channel_sync->  |----------------|
+                        |  channel_sync  |  // This is new
+                        ------------------
+
+The "channel_sync" is an `eth_channel_sync_t`. It is internal to the EDM implementation and is used to indicate packet
+transmission state between sender and receiver EDMs.
+
+The protocol for its use is:
+1) Sender updates the field indicating new data:
+   - set `bytes_sent` to a non-zero value indicating new data
+   - clear `receiver_ack` to 0
+   - set `src_id` to the sender channel id so the receiver knows who the sender was (and where the ack should go)
+2) Sender sends this channel sync to the corresponding location in the receiver channel (either in the same transmission
+   as the packet or separately)
+3) Receiver sees that `bytes_sent` is non-zero, indicating a new packet. It sends back an acknowledgement (first level):
+   - set `receiver_ack` to non-zero
+   *NOTE* IMPORTANT: To avoid a race, the receiver must be sure to send this channel_sync_t from a different address than
+   the one it uses for the second level acknowledgement
+   3b) When the sender receives an ack, it understands it can overwrite its local copy of the packet with new data
+4) After the receiver properly writes out its packet, it sends a second level acknowledgement, indicating it can receive new
+   data into this specific buffer index:
+   - clear the bytes_sent and receiver_ack fields and send back the `channel_sync` to the sender
+
+
+
+## Sending Packets
+Sending a packet is done as follows:
+
+1) Worker waits for a flow control semaphore increment from the EDM sender channel
+   - Indicates there is space at the next buffer index for a packet
+2) Worker performs a noc write of its packet to the EDM sender channel at the buffer index
+
+*NOTE*: !!!ALL PACKETS MUST CONTAIN DESTINATION NOC X/Y AS NOC 0 COORDINATES, REGARDLESS OF THE `noc_index` OF THE SENDER!!!
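+
+## Example Worker Flow (editor's sketch)
+
+Putting the above together, a worker's interaction with a sender channel looks roughly like the
+following. This is illustrative only - argument plumbing is elided and the names are hypothetical:
+
+    auto sender = tt::fabric::WorkerToFabricEdmSender(...);  // built from kernel runtime args
+    sender.open();                           // run the connection protocol described above
+    for (uint32_t p = 0; p < num_packets; ++p) {
+        sender.wait_for_empty_write_slot();  // wait for the flow control semaphore
+        sender.send_payload_blocking_from_address(packet_l1_addr, packet_size_bytes);
+    }
+    sender.close();                          // explicit connection teardown, as required above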
+
+*/
+
+////////////////////////////////////////////////
+// Data structures, types, enums, and constants
+////////////////////////////////////////////////
+
+enum SenderState : uint8_t {
+    SENDER_DONE = 0,
+
+    // we are ready to tell the worker(s) that the buffer is available for writing into
+    SENDER_SIGNALING_WORKER,
+
+    // we are waiting for the payload to arrive in L1; we are checking the local semaphore for worker
+    // completion
+    SENDER_WAITING_FOR_WORKER,
+
+    // this state is entered if the sender was able to send the payload but not the channel sync
+    SENDER_SEND_CHANNEL_SYNC,
+
+    // Sender channel is not connected to a worker and is waiting for a new connection
+    SENDER_WAIT_WORKER_HANDSHAKE,
+
+    // means we are waiting for an ack from the receiver that the payload was received
+    SENDER_WAITING_FOR_ETH,
+
+};
+
+enum ReceiverState : uint8_t {
+    RECEIVER_DONE = 0,
+
+    // Receiver is processing the packet, either writing it locally or forwarding to the next EDM
+    // (toward the next chip), or both
+    RECEIVER_SENDING_PAYLOAD,
+
+    // Enter this state after performing writes of the current packet as a sort of soft barrier
+    // (for this channel only) so we can make progress on other channels while waiting for the
+    // writes to flush
+    RECEIVER_WAITING_FOR_WRITE_FLUSH,
+
+    // means we are waiting for a payload from the sender
+    RECEIVER_WAITING_FOR_ETH,
+};
+
+
+enum PacketLocalForwardType : uint8_t {
+    PACKET_FORWARD_INVALID = 0x0,
+    PACKET_FORWARD_LOCAL_ONLY = 0x1,
+    PACKET_FORWARD_REMOTE_ONLY = 0x2,
+    PACKET_FORWARD_LOCAL_AND_REMOTE = 0x3
+};
+
+static constexpr uint32_t SWITCH_INTERVAL = 4000000;
+static constexpr size_t ETH_BYTES_TO_WORDS_SHIFT = 4;
+static constexpr size_t NUM_SENDER_CHANNELS = 2;
+static constexpr size_t num_workers_ctor = 1;
+static constexpr size_t num_messages_to_move_ctor_value = 1;
+// Doesn't REALLY matter but for consistency I picked the next available ID
+static constexpr size_t receiver_channel_id = NUM_SENDER_CHANNELS;
+static constexpr size_t worker_info_offset_past_connection_semaphore = 32;
+
+/////////////////////////////////////////////
+//   SENDER SIDE HELPERS
+/////////////////////////////////////////////
+
+FORCE_INLINE void sender_notify_workers_if_buffer_available_sequence(
+    tt::fabric::EdmChannelWorkerInterface &local_sender_worker_interface) {
+    local_sender_worker_interface.clear_local_semaphore();
+    local_sender_worker_interface.increment_worker_semaphore();
+}
+
+template
+void send_channel_sync(
+    tt::fabric::EthChannelBuffer &sender_buffer_channel,
+    tt::fabric::EthChannelBuffer &receiver_buffer_channel) {
+
+    eth_send_bytes_over_channel_payload_only_unsafe(
+        reinterpret_cast<size_t>(sender_buffer_channel.get_current_bytes_sent_address()),
+        reinterpret_cast<size_t>(receiver_buffer_channel.get_current_bytes_sent_address()),
+        sizeof(eth_channel_sync_t),
+        sizeof(eth_channel_sync_t),
+        sizeof(eth_channel_sync_t) >> ETH_BYTES_TO_WORDS_SHIFT);
+}
+
+template
+tt::fabric::SendStatus send_next_data(
+    tt::fabric::EthChannelBuffer &sender_buffer_channel,
+    tt::fabric::EthChannelBuffer &receiver_buffer_channel) {
+
+    auto status = tt::fabric::SendStatus::NOT_SENT;
+
+    ASSERT(!eth_txq_is_busy());
+
+    status = tt::fabric::SendStatus::SENT_PAYLOAD_AND_SYNC;
+    ASSERT(
+        reinterpret_cast<size_t>(sender_buffer_channel.get_current_bytes_sent_address()) ==
+        (reinterpret_cast<size_t>(sender_buffer_channel.get_current_buffer_address()) +
+         reinterpret_cast<size_t>(sender_buffer_channel.get_current_max_eth_payload_size()) -
+         (uint32_t)sizeof(eth_channel_sync_t)));
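+
+    // Protocol step 1 (see the channel_sync notes in the doc comment above): mark new data as
+    // in-flight by setting bytes_sent, clearing receiver_ack, and stamping our channel id into src_id.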
+    *sender_buffer_channel.get_current_bytes_sent_address() = sender_buffer_channel.get_current_max_eth_payload_size();
+    *sender_buffer_channel.get_current_bytes_acked_address() = 0;
+    *sender_buffer_channel.get_current_src_id_address() = sender_buffer_channel.get_id();
+    ASSERT(*sender_buffer_channel.get_current_src_id_address() < 2);
+
+    // TODO: TUNING - experiment with only conditionally breaking the transfer up into multiple packets if we are
+    //       a certain threshold less than a full packet.
+    //       We can precompute this value even on host and pass it in so we can get away with a single integer compare.
+    // NOTE: if we always send a full packet, then we don't need the second branch below dedicated to the
+    //       channel sync
+    tt::fabric::validate(*const_cast<tt::fabric::PacketHeader const *>(
+        reinterpret_cast<volatile tt::fabric::PacketHeader const *>(receiver_buffer_channel.get_current_buffer_address())));
+    const size_t payload_size = sender_buffer_channel.get_current_payload_plus_channel_sync_size();
+    eth_send_bytes_over_channel_payload_only_unsafe(
+        sender_buffer_channel.get_current_buffer_address(),
+        receiver_buffer_channel.get_current_buffer_address(), // get_remote_eth_buffer_address(),
+        payload_size,
+        payload_size,
+        payload_size >> ETH_BYTES_TO_WORDS_SHIFT);
+
+    bool sent_payload_and_channel_sync_in_one_shot =
+        payload_size == sender_buffer_channel.get_channel_buffer_max_size_in_bytes();
+    if (!sent_payload_and_channel_sync_in_one_shot) {
+        // We weren't able to send the channel_sync_t in one shot with the payload, so we need to send a second
+        // packet
+        // TODO: TUNING - consider busy waiting for a maximum amount of time
+        if (!eth_txq_is_busy()) {
+            send_channel_sync(sender_buffer_channel, receiver_buffer_channel);
+        } else {
+            status = tt::fabric::SendStatus::SENT_PAYLOAD_ONLY;
+        }
+    }
+
+    // Note: We can only advance to the next buffer index if we have fully completed the send (both the payload and sync
+    // messages)
+    if (status == tt::fabric::SendStatus::SENT_PAYLOAD_AND_SYNC) {
+        sender_buffer_channel.advance_buffer_index();
+        receiver_buffer_channel.advance_buffer_index();
+    }
+
+    return status;
+}
+
+template
+FORCE_INLINE bool sender_noc_receive_payload_ack_check_sequence(
+    tt::fabric::EthChannelBuffer &sender_buffer_channel,
+    tt::fabric::EthChannelBuffer &receiver_buffer_channel) {
+    return sender_buffer_channel.is_local_semaphore_full();
+}
+
+template
+FORCE_INLINE void sender_eth_check_receiver_ack_sequence(
+    tt::fabric::EthChannelBuffer &sender_buffer_channel,
+    tt::fabric::EdmChannelWorkerInterface &sender_worker_interface) {
+    sender_buffer_channel.eth_clear_sender_channel_ack();
+
+    sender_notify_workers_if_buffer_available_sequence(sender_worker_interface);
+}
+
+/////////////////////////////////////////////
+//   RECEIVER SIDE HELPERS
+/////////////////////////////////////////////
+
+template
+FORCE_INLINE bool new_unacknowledged_packet_avilable_on_reciever_channel(
+    tt::fabric::EthChannelBuffer &local_receiver_channel) {
+    return local_receiver_channel.eth_bytes_are_available_on_channel();
+}
+
+/*
+ * Acting as the receiver, we are looking at our receiver channel and acking the sender who sent us the latest packet.
+ * Doesn't check to see if indeed a new message is available. It's assumed the caller has handled that separately.
+ */
+// MUST CHECK !is_eth_txq_busy() before calling
+template
+void receiver_send_received_ack(
+    std::array, NUM_SENDER_CHANNELS> &remote_sender_channels,
+    tt::fabric::EthChannelBuffer &local_receiver_buffer_channel) {
+    // Set the acknowledgement bits.
+    // We send this first-level ack from a dedicated L1 address, distinct from the one used for the
+    // second-level (completion) ack, to avoid the race described in the channel_sync protocol notes.
+
+    const auto src_id = *local_receiver_buffer_channel.get_current_src_id_address();
+    ASSERT(src_id < NUM_SENDER_CHANNELS);
+    auto &sender_buffer_channel = remote_sender_channels[src_id];
+    ASSERT(
+        reinterpret_cast<size_t>(sender_buffer_channel.get_current_bytes_sent_address()) ==
+        reinterpret_cast<size_t>(sender_buffer_channel.get_current_buffer_address()) +
+            reinterpret_cast<size_t>(sender_buffer_channel.get_current_max_eth_payload_size()) -
+            sizeof(eth_channel_sync_t));
+
+    const size_t local_ack_channel_sync_src_addr =
+        local_receiver_buffer_channel.get_eth_transaction_ack_word_addr() + (src_id * sizeof(eth_channel_sync_t));
+    reinterpret_cast<volatile eth_channel_sync_t *>(local_ack_channel_sync_src_addr)->bytes_sent =
+        *local_receiver_buffer_channel.get_current_bytes_sent_address();
+    reinterpret_cast<volatile eth_channel_sync_t *>(local_ack_channel_sync_src_addr)->receiver_ack = 1;
+    reinterpret_cast<volatile eth_channel_sync_t *>(local_ack_channel_sync_src_addr)->src_id =
+        *local_receiver_buffer_channel.get_current_src_id_address();
+
+    // Make sure we don't alias the erisc_info eth_channel_sync_t
+    ASSERT(
+        reinterpret_cast<volatile eth_channel_sync_t *>(local_receiver_buffer_channel.get_current_bytes_sent_address())
+            ->bytes_sent != 0);
+    ASSERT(
+        reinterpret_cast<volatile eth_channel_sync_t *>(local_receiver_buffer_channel.get_current_bytes_sent_address())
+            ->receiver_ack == 0);
+
+    ASSERT(!eth_txq_is_busy());
+    internal_::eth_send_packet_unsafe(
+        0,
+        local_ack_channel_sync_src_addr >> 4,
+        ((uint32_t)(sender_buffer_channel.get_current_bytes_sent_address())) >> 4,
+        1);
+}
+
+// MUST CHECK !is_eth_txq_busy() before calling
+template
+FORCE_INLINE void receiver_send_completion_ack(
+    std::array, NUM_SENDER_CHANNELS> &remote_sender_channels,
+    tt::fabric::EthChannelBuffer &local_receiver_buffer_channel) {
+    volatile auto local_bytes_sent_addr = local_receiver_buffer_channel.get_current_bytes_sent_address();
+    volatile auto local_src_id_ptr = local_receiver_buffer_channel.get_current_src_id_address();
+
+    auto src_sender_channel = *local_src_id_ptr;
+    *(local_bytes_sent_addr) = 0;
+    *(local_receiver_buffer_channel.get_current_bytes_acked_address()) = 0;
+    ASSERT(src_sender_channel < NUM_SENDER_CHANNELS);
+
+    ASSERT(!eth_txq_is_busy());
+    internal_::eth_send_packet_unsafe(
+        0,
+        (uint32_t)(local_bytes_sent_addr) >> 4,
+        (uint32_t)(remote_sender_channels[src_sender_channel].get_current_bytes_sent_address()) >> 4,
+        1);
+
+    local_receiver_buffer_channel.advance_buffer_index();
+    remote_sender_channels[src_sender_channel].advance_buffer_index();
+}
+
+
+PacketLocalForwardType get_packet_local_forward_type(const tt::fabric::PacketHeader &packet_header) {
+    const bool local_chip_is_packet_destination = packet_must_be_consumed_locally(packet_header);
+    const bool packet_needs_forwarding = packet_must_be_forwarded_to_next_chip(packet_header);
+    PacketLocalForwardType forward_type =
+        static_cast<PacketLocalForwardType>(packet_needs_forwarding << 1 | local_chip_is_packet_destination);
+    return forward_type;
+}
+
+FORCE_INLINE bool can_forward_packet_completely(
+    const tt::fabric::PacketHeader &packet_header, tt::fabric::WorkerToFabricEdmSender &downstream_edm_interface) {
+    auto forward_status = get_packet_local_forward_type(packet_header);
+    switch (forward_status) {
+        case PACKET_FORWARD_INVALID: return false;
+        case PACKET_FORWARD_LOCAL_ONLY: return true;
+
+        case PACKET_FORWARD_REMOTE_ONLY:
+        case PACKET_FORWARD_LOCAL_AND_REMOTE: return downstream_edm_interface.consumer_has_space();
+        default: ASSERT(false); return false;
+    };
+}
+
+// template
+tt::fabric::SendStatus receiver_forward_packet(
+tt::fabric::SendStatus receiver_forward_packet(
+    volatile tt::fabric::PacketHeader *packet_start, tt::fabric::WorkerToFabricEdmSender &downstream_edm_interface) {
+    // Just cache the packet_header - we don't really expect (or care) if its contents change during this function.
+    tt::fabric::PacketHeader const &packet_header = *const_cast<tt::fabric::PacketHeader const *>(packet_start);
+    tt::fabric::validate(packet_header);
+    auto forward_status = get_packet_local_forward_type(packet_header);
+
+    switch (forward_status) {
+        case PACKET_FORWARD_LOCAL_ONLY: {
+            execute_chip_unicast_to_local_chip(packet_start);
+            return tt::fabric::SendStatus::SENT_PAYLOAD_AND_SYNC;
+        } break;
+
+        case PACKET_FORWARD_REMOTE_ONLY: {
+            return forward_payload_to_downstream_edm(packet_start, downstream_edm_interface);
+        } break;
+
+        case PACKET_FORWARD_LOCAL_AND_REMOTE: {
+            ASSERT(packet_header.chip_send_type == tt::fabric::ChipSendType::CHIP_MULTICAST);
+            // TODO: make the local chip write non-blocking
+            execute_chip_unicast_to_local_chip(packet_start);
+            return forward_payload_to_downstream_edm(packet_start, downstream_edm_interface);
+        } break;
+
+        case PACKET_FORWARD_INVALID:
+        default: ASSERT(false); return tt::fabric::SendStatus::ERROR;
+    };
+}
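+// [Editor's illustrative sketch - not part of this change.] The SendStatus values produced by send_next_data()
+// drive the sender state machine below. Written out as a plain function (names exactly as used in this file):
+//
+//   SenderState next_state_after_send(tt::fabric::SendStatus s) {
+//       switch (s) {
+//           case tt::fabric::SendStatus::NOT_SENT:              return SenderState::SENDER_WAITING_FOR_WORKER;
+//           case tt::fabric::SendStatus::SENT_PAYLOAD_ONLY:     return SenderState::SENDER_SEND_CHANNEL_SYNC;
+//           case tt::fabric::SendStatus::SENT_PAYLOAD_AND_SYNC: return SenderState::SENDER_WAITING_FOR_ETH;
+//           default: ASSERT(false);                             return SenderState::SENDER_WAITING_FOR_WORKER;
+//       }
+//   }
+//
+// The nested ternary in the SENDER_WAITING_FOR_WORKER case below is exactly this mapping.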
+////////////////////////////////////
+////////////////////////////////////
+// Main Control Loop
+////////////////////////////////////
+////////////////////////////////////
+template <size_t SENDER_NUM_BUFFERS, size_t RECEIVER_NUM_BUFFERS>
+bool run_sender_channel_state_machine_step(
+    tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS> &local_sender_channel,
+    tt::fabric::EdmChannelWorkerInterface &local_sender_channel_worker_interface,
+    tt::fabric::EthChannelBuffer<RECEIVER_NUM_BUFFERS> &remote_receiver_channel,
+    SenderState *const sender_state_out) {
+    bool incr_sender_channel_index = true;
+    switch (*sender_state_out) {
+        case SenderState::SENDER_WAITING_FOR_WORKER: {
+            bool able_to_send = local_sender_channel_worker_interface.has_payload() && !eth_txq_is_busy() &&
+                                local_sender_channel.eth_is_receiver_channel_send_done();
+            if (able_to_send) {
+                auto send_status = send_next_data(local_sender_channel, remote_receiver_channel);
+                // TODO: align the enums and state values so I can just do
+                // sender_states[sender_channel_index] += send_status :)
+                ASSERT(send_status != tt::fabric::SendStatus::ERROR);
+                *sender_state_out =
+                    send_status == tt::fabric::SendStatus::NOT_SENT ? SenderState::SENDER_WAITING_FOR_WORKER
+                    : send_status == tt::fabric::SendStatus::SENT_PAYLOAD_ONLY
+                        ? SenderState::SENDER_SEND_CHANNEL_SYNC
+                        : SenderState::SENDER_WAITING_FOR_ETH;
+                // To avoid starvation/bubbles, only switch to the other sender channel once we've sent both the
+                // packet and its channel sync. Otherwise we could start sending another large payload from the
+                // other channel and be unable to send the channel sync for the packet we just sent, which hurts
+                // overall latency.
+                incr_sender_channel_index = send_status != tt::fabric::SendStatus::SENT_PAYLOAD_ONLY;
+            } else {
+                if (local_sender_channel_worker_interface.has_worker_teardown_request()) {
+                    local_sender_channel_worker_interface.teardown_connection();
+                    *sender_state_out = SenderState::SENDER_WAIT_WORKER_HANDSHAKE;
+                }
+            }
+        } break;
+
+        case SenderState::SENDER_WAIT_WORKER_HANDSHAKE:
+            if (local_sender_channel_worker_interface.connection_is_live()) {
+                bool is_safe_to_receive_next_message = local_sender_channel.eth_is_receiver_channel_send_acked() ||
+                                                       local_sender_channel.eth_is_receiver_channel_send_done();
+                if (is_safe_to_receive_next_message) {
+                    sender_notify_workers_if_buffer_available_sequence(local_sender_channel_worker_interface);
+                    *sender_state_out = SenderState::SENDER_WAITING_FOR_WORKER;
+                } else {
+                    *sender_state_out = SenderState::SENDER_WAITING_FOR_ETH;
+                }
+            }
+            break;
+
+        case SenderState::SENDER_SEND_CHANNEL_SYNC: {
+            bool can_send_channel_sync_without_blocking = !eth_txq_is_busy();
+            if (can_send_channel_sync_without_blocking) {
+                send_channel_sync(local_sender_channel, remote_receiver_channel);
+                local_sender_channel.advance_buffer_index();
+                remote_receiver_channel.advance_buffer_index();
+                *sender_state_out = SenderState::SENDER_WAITING_FOR_ETH;
+            }
+        } break;
+
+        case SenderState::SENDER_WAITING_FOR_ETH: {
+            bool is_safe_to_receive_next_message = local_sender_channel.eth_is_receiver_channel_send_acked() ||
+                                                   local_sender_channel.eth_is_receiver_channel_send_done();
+            if (is_safe_to_receive_next_message) {
+                // This also notifies workers in the same call
+                sender_eth_check_receiver_ack_sequence(local_sender_channel, local_sender_channel_worker_interface);
+                *sender_state_out = SenderState::SENDER_WAITING_FOR_WORKER;
+            }
+        } break;
+
+        default: break;
+    };
+
+    return incr_sender_channel_index;
+}
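+// [Editor's illustrative sketch - not part of this change.] The receiver side below is a three-stage pipeline; each
+// buffer slot walks the states in order, and every eth/noc operation is gated on a non-blocking readiness check so
+// the main loop never stalls:
+//
+//   RECEIVER_WAITING_FOR_ETH                 // new packet landed and txq free? -> send first-level ack
+//       -> RECEIVER_SENDING_PAYLOAD          // local write and/or forward once the downstream EDM has space
+//       -> RECEIVER_WAITING_FOR_WRITE_FLUSH  // noc writes flushed and txq free? -> send completion ack
+//       -> RECEIVER_WAITING_FOR_ETH          // slot recycled via advance_buffer_index()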
+template <size_t NUM_SENDER_CHANNELS, size_t SENDER_NUM_BUFFERS, size_t RECEIVER_NUM_BUFFERS>
+void run_receiver_channel_state_machine_step(
+    tt::fabric::EthChannelBuffer<RECEIVER_NUM_BUFFERS> &local_receiver_channel,
+    std::array<tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS>, NUM_SENDER_CHANNELS> &remote_sender_channels,
+    tt::fabric::WorkerToFabricEdmSender &downstream_edm_interface,
+    ReceiverState *const receiver_state_out) {
+    switch (*receiver_state_out) {
+        case ReceiverState::RECEIVER_WAITING_FOR_ETH: {
+            bool got_payload = local_receiver_channel.eth_bytes_are_available_on_channel();
+            if (got_payload) {
+                bool can_ack = !eth_txq_is_busy();
+                if (can_ack) {
+                    tt::fabric::validate(
+                        *const_cast<tt::fabric::PacketHeader *>(local_receiver_channel.get_current_packet_header()));
+                    ASSERT(
+                        local_receiver_channel.get_current_packet_header()->command_fields.unicast_write.size <
+                        100000);
+                    receiver_send_received_ack(remote_sender_channels, local_receiver_channel);
+                    // TODO: PERF - add the ability to perform the local noc write and defer the forward to the
+                    // downstream EDM if we are mcasting to the local chip and neighbours but the downstream EDM
+                    // isn't currently able to accept the packet
+                    // ...
+                    // but as a starting point we can do the dumb thing and just wait for space downstream
+                    // before we do either.
+                    *receiver_state_out = ReceiverState::RECEIVER_SENDING_PAYLOAD;
+                    // TODO: PERF - SHORT CIRCUIT TO THE NEXT STATE IF WE CAN, TO MINIMIZE LATENCY, BUT WE ARE
+                    // CURRENTLY A LITTLE CODE SIZE BOUND
+                }
+            }
+        } break;
+
+        case ReceiverState::RECEIVER_SENDING_PAYLOAD: {
+            auto packet_header =
+                *const_cast<tt::fabric::PacketHeader *>(local_receiver_channel.get_current_packet_header());
+            bool can_send_to_all_local_chip_receivers =
+                can_forward_packet_completely(packet_header, downstream_edm_interface);
+            if (can_send_to_all_local_chip_receivers) {
+                receiver_forward_packet(local_receiver_channel.get_current_packet_header(), downstream_edm_interface);
+                *receiver_state_out = ReceiverState::RECEIVER_WAITING_FOR_WRITE_FLUSH;
+            }
+        } break;
+
+        case ReceiverState::RECEIVER_WAITING_FOR_WRITE_FLUSH: {
+            bool writes_flushed = ncrisc_noc_nonposted_writes_sent(noc_index);
+            if (writes_flushed) {
+                bool can_send_ack_without_blocking = !eth_txq_is_busy();
+                if (can_send_ack_without_blocking) {
+                    receiver_send_completion_ack(remote_sender_channels, local_receiver_channel);
+                    *receiver_state_out = ReceiverState::RECEIVER_WAITING_FOR_ETH;
+                }
+            }
+        } break;
+
+        default: break;
+    };
+}
+
+/* Termination signal handling */
+FORCE_INLINE bool got_immediate_termination_signal(volatile tt::fabric::TerminationSignal *termination_signal_ptr) {
+    return *termination_signal_ptr == tt::fabric::TerminationSignal::IMMEDIATELY_TERMINATE;
+}
+FORCE_INLINE bool got_graceful_termination_signal(volatile tt::fabric::TerminationSignal *termination_signal_ptr) {
+    return *termination_signal_ptr == tt::fabric::TerminationSignal::GRACEFULLY_TERMINATE;
+}
+FORCE_INLINE bool got_termination_signal(volatile tt::fabric::TerminationSignal *termination_signal_ptr) {
+    return got_immediate_termination_signal(termination_signal_ptr) ||
+           got_graceful_termination_signal(termination_signal_ptr);
+}
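+// [Editor's illustrative sketch - not part of this change.] The termination word is plain L1 memory that some other
+// agent (e.g. the host or a control kernel) writes remotely; the EDM only ever polls it via the helpers above.
+// Assuming a device-side writer that knows the EDM core's noc coordinates and the termination address (both names
+// below are hypothetical), a shutdown request could look like:
+//
+//   uint64_t edm_termination_noc_addr = get_noc_addr(edm_noc_x, edm_noc_y, termination_signal_address);
+//   noc_inline_dw_write(edm_termination_noc_addr, (uint32_t)tt::fabric::TerminationSignal::GRACEFULLY_TERMINATE);
+//
+// IMMEDIATELY_TERMINATE exits the loop even with traffic in flight, whereas GRACEFULLY_TERMINATE is intended to let
+// in-flight sends drain first.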
+/*
+ * Main control loop for the fabric EDM. Runs indefinitely until a termination signal is received.
+ *
+ * Every loop iteration visits a sender channel and the receiver channel. We switch between sender
+ * channels every iteration unless it is unsafe/undesirable to do so (e.g. for performance reasons).
+ */
+template <size_t NUM_SENDER_CHANNELS, size_t SENDER_NUM_BUFFERS, size_t RECEIVER_NUM_BUFFERS>
+void run_fabric_edm_main_loop(
+    tt::fabric::EthChannelBuffer<RECEIVER_NUM_BUFFERS> &local_receiver_channel,
+    std::array<tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS>, NUM_SENDER_CHANNELS> &local_sender_channels,
+    std::array<tt::fabric::EdmChannelWorkerInterface, NUM_SENDER_CHANNELS> &local_sender_channel_worker_interfaces,
+    tt::fabric::WorkerToFabricEdmSender &downstream_edm_noc_interface,
+    std::array<tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS>, NUM_SENDER_CHANNELS> &remote_sender_channels,
+    tt::fabric::EthChannelBuffer<RECEIVER_NUM_BUFFERS> &remote_receiver_channel,
+    volatile tt::fabric::TerminationSignal *termination_signal_ptr) {
+    std::array<SenderState, NUM_SENDER_CHANNELS> sender_states = {
+        SenderState::SENDER_WAIT_WORKER_HANDSHAKE, SenderState::SENDER_WAIT_WORKER_HANDSHAKE};
+    ReceiverState receiver_state = ReceiverState::RECEIVER_WAITING_FOR_ETH;
+    size_t sender_channel_index = 0;
+    size_t did_nothing_count = 0;
+    *termination_signal_ptr = tt::fabric::TerminationSignal::KEEP_RUNNING;
+
+    while (!got_termination_signal(termination_signal_ptr)) {
+        auto &local_sender_channel = local_sender_channels[sender_channel_index];
+        auto &local_sender_channel_worker_interface = local_sender_channel_worker_interfaces[sender_channel_index];
+        // There are some cases, mainly for performance, where we don't want to switch between sender channels,
+        // so we introduce this to provide finer-grained control over when we disable the automatic switching
+        bool incr_sender_channel_index = run_sender_channel_state_machine_step(
+            local_sender_channel,
+            local_sender_channel_worker_interface,
+            remote_receiver_channel,
+            &(sender_states[sender_channel_index]));
+        if (incr_sender_channel_index) {
+            // TODO: this can probably be optimized
+            sender_channel_index = 1 - sender_channel_index;
+        }
+
+        run_receiver_channel_state_machine_step(
+            local_receiver_channel, remote_sender_channels, downstream_edm_noc_interface, &receiver_state);
+
+        if (did_nothing_count++ > SWITCH_INTERVAL) {
+            did_nothing_count = 0;
+            run_routing();
+        }
+    }
+}
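+// [Editor's illustrative sketch - not part of this change.] With exactly two sender channels, `1 - index` is a
+// branchless way to alternate between them. The equivalent general wrap-around form (which
+// tt::fabric::wrap_increment in fabric_erisc_datamover_channels.hpp generalizes) would be:
+//
+//   sender_channel_index = (sender_channel_index + 1) % NUM_SENDER_CHANNELS;  // 0 -> 1 -> 0 -> ...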
+
+void kernel_main() {
+    //
+    // COMMON CT ARGS (not specific to sender or receiver)
+    //
+    static constexpr bool is_handshake_sender = get_compile_time_arg_val(0) != 0;
+    static constexpr size_t handshake_addr = get_compile_time_arg_val(1);
+    auto eth_transaction_ack_word_addr = handshake_addr + sizeof(eth_channel_sync_t);
+
+    if constexpr (is_handshake_sender) {
+        erisc::datamover::handshake::sender_side_start(handshake_addr);
+    } else {
+        erisc::datamover::handshake::receiver_side_start(handshake_addr);
+    }
+
+    // The size of one of the buffers within a sender channel. For example, if `channel_buffer_size` = 4k and
+    // `SENDER_NUM_BUFFERS` = 2, then the total amount of buffering for that channel is 8k.
+    static constexpr size_t channel_buffer_size = get_compile_time_arg_val(2);
+
+    static constexpr size_t SENDER_NUM_BUFFERS = get_compile_time_arg_val(3);
+    static constexpr size_t RECEIVER_NUM_BUFFERS = get_compile_time_arg_val(4);
+    static constexpr size_t local_sender_0_channel_address = get_compile_time_arg_val(5);
+    static constexpr size_t local_sender_channel_0_connection_buffer_index_addr = get_compile_time_arg_val(6);
+    static constexpr size_t local_sender_channel_0_connection_info_addr = get_compile_time_arg_val(7);
+    static constexpr size_t local_sender_1_channel_address = get_compile_time_arg_val(8);
+    static constexpr size_t local_sender_channel_1_connection_buffer_index_addr = get_compile_time_arg_val(9);
+    static constexpr size_t local_sender_channel_1_connection_info_addr = get_compile_time_arg_val(10);
+    static constexpr size_t local_receiver_channel_buffer_address = get_compile_time_arg_val(11);
+    static constexpr size_t remote_receiver_channel_buffer_address = get_compile_time_arg_val(12);
+    static constexpr size_t remote_sender_0_channel_address = get_compile_time_arg_val(13);
+    static constexpr size_t remote_sender_1_channel_address = get_compile_time_arg_val(14);
+
+    // TODO: CONVERT TO SEMAPHORE
+    volatile auto termination_signal_ptr =
+        reinterpret_cast<volatile tt::fabric::TerminationSignal *>(get_compile_time_arg_val(15));
+
+    static_assert(SENDER_NUM_BUFFERS > 0, "compile time argument [3]: SENDER_NUM_BUFFERS must be > 0");
+    static_assert(RECEIVER_NUM_BUFFERS > 0, "compile time argument [4]: RECEIVER_NUM_BUFFERS must be > 0");
+
+    *reinterpret_cast<volatile uint32_t *>(local_sender_channel_0_connection_buffer_index_addr) = 0;
+    *reinterpret_cast<volatile uint32_t *>(local_sender_channel_1_connection_buffer_index_addr) = 0;
+
+    size_t arg_idx = 0;
+    ///////////////////////
+    // Common runtime args:
+    ///////////////////////
+
+    const size_t local_sender_channel_0_connection_semaphore_addr =
+        get_semaphore(get_arg_val<uint32_t>(arg_idx++));
+    const size_t local_sender_channel_1_connection_semaphore_addr =
+        get_semaphore(get_arg_val<uint32_t>(arg_idx++));
+    // downstream EDM semaphore location
+    const bool has_downstream_edm_buffer_connection = get_arg_val<uint32_t>(arg_idx++) != 0;
+    const auto downstream_edm_buffer_base_address = get_arg_val<uint32_t>(arg_idx++);
+    const auto downstream_edm_noc_x = get_arg_val<uint32_t>(arg_idx++);
+    const auto downstream_edm_noc_y = get_arg_val<uint32_t>(arg_idx++);
+
+    // remote address for flow control
+    const auto downstream_edm_semaphore_id = get_arg_val<uint32_t>(arg_idx++);  // TODO: Convert to semaphore ID
+    const auto downstream_edm_worker_registration_address =
+        get_semaphore(get_arg_val<uint32_t>(arg_idx++));
+    const auto downstream_edm_worker_location_info_address = get_arg_val<uint32_t>(arg_idx++);
+    const auto downstream_noc_interface_buffer_index_addr =
+        get_semaphore(get_arg_val<uint32_t>(arg_idx++));
+
+    // Receiver channel's local semaphore for managing flow control with the downstream EDM.
+    // The downstream EDM should be sending semaphore updates to this address any time it can
+    // accept a new message
+    const auto edm_forwarding_semaphore_address =
+        get_semaphore(get_arg_val<uint32_t>(arg_idx++));
+
+    ////////////////////////
+    // Sender runtime args
+    ////////////////////////
+    auto sender0_worker_semaphore_ptr = reinterpret_cast<volatile uint32_t *>(
+        get_semaphore(get_arg_val<uint32_t>(arg_idx++)));
+    auto sender1_worker_semaphore_ptr = reinterpret_cast<volatile uint32_t *>(
+        get_semaphore(get_arg_val<uint32_t>(arg_idx++)));
+    *sender0_worker_semaphore_ptr = 0;
+    *sender1_worker_semaphore_ptr = 0;
+
+    //////////////////////////////
+    //////////////////////////////
+    // Object Setup
+    //////////////////////////////
+    //////////////////////////////
+
+    auto const &local_sender_buffer_addresses =
+        std::array<size_t, NUM_SENDER_CHANNELS>{local_sender_0_channel_address, local_sender_1_channel_address};
+    auto const &remote_sender_buffer_addresses =
+        std::array<size_t, NUM_SENDER_CHANNELS>{remote_sender_0_channel_address, remote_sender_1_channel_address};
+    std::array<tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS>, NUM_SENDER_CHANNELS> remote_sender_channels;
+    std::array<tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS>, NUM_SENDER_CHANNELS> local_sender_channels;
+    std::array<tt::fabric::EdmChannelWorkerInterface, NUM_SENDER_CHANNELS> local_sender_channel_worker_interfaces;
+    std::array<size_t, NUM_SENDER_CHANNELS> local_sender_flow_control_semaphores = {
+        reinterpret_cast<size_t>(sender0_worker_semaphore_ptr),
+        reinterpret_cast<size_t>(sender1_worker_semaphore_ptr)};
+    std::array<size_t, NUM_SENDER_CHANNELS> local_sender_connection_live_semaphore_addresses = {
+        local_sender_channel_0_connection_semaphore_addr, local_sender_channel_1_connection_semaphore_addr};
+    std::array<size_t, NUM_SENDER_CHANNELS> local_sender_connection_info_addresses = {
+        local_sender_channel_0_connection_info_addr, local_sender_channel_1_connection_info_addr};
+    auto downstream_edm_noc_interface =
+        has_downstream_edm_buffer_connection
+            ? tt::fabric::WorkerToFabricEdmSender(
+                  downstream_edm_noc_x,
+                  downstream_edm_noc_y,
+                  downstream_edm_buffer_base_address,
+                  SENDER_NUM_BUFFERS,
+                  downstream_edm_semaphore_id,
+                  downstream_edm_worker_registration_address,  // edm_connection_handshake_addr
+                  downstream_edm_worker_location_info_address,
+                  channel_buffer_size,
+                  reinterpret_cast<volatile uint32_t *const>(edm_forwarding_semaphore_address),
+                  downstream_noc_interface_buffer_index_addr)
+            : tt::fabric::WorkerToFabricEdmSender();
+
+    auto local_receiver_channel = tt::fabric::EthChannelBuffer<RECEIVER_NUM_BUFFERS>(
+        local_receiver_channel_buffer_address,
+        channel_buffer_size,
+        tt::fabric::header_size_bytes,
+        eth_transaction_ack_word_addr,  // Assume for the receiver channel that this address points to a chunk of
+                                        // memory that can fit 2 eth_channel_syncs for the ack
+        receiver_channel_id);
+    auto remote_receiver_channel = tt::fabric::EthChannelBuffer<RECEIVER_NUM_BUFFERS>(
+        remote_receiver_channel_buffer_address,
+        channel_buffer_size,
+        tt::fabric::header_size_bytes,
+        eth_transaction_ack_word_addr,  // Assume for the receiver channel that this address points to a chunk of
+                                        // memory that can fit 2 eth_channel_syncs for the ack
+        receiver_channel_id);
+
+    for (uint8_t i = 0; i < NUM_SENDER_CHANNELS; i++) {
+        new (&local_sender_channels[i]) tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS>(
+            local_sender_buffer_addresses[i],
+            channel_buffer_size,
+            tt::fabric::header_size_bytes,
+            0,  // For sender channels there is no eth_transaction_ack_word_addr because they don't send acks
+            i);
+        new (&remote_sender_channels[i]) tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS>(
+            remote_sender_buffer_addresses[i],
+            channel_buffer_size,
+            tt::fabric::header_size_bytes,
+            0,  // For sender channels there is no eth_transaction_ack_word_addr because they don't send acks
+            i);
+
+        auto connection_live_semaphore_ptr =
+            reinterpret_cast<volatile tt_l1_ptr uint32_t *const>(local_sender_connection_live_semaphore_addresses[i]);
+        auto connection_worker_info_ptr = reinterpret_cast<volatile tt::fabric::EDMChannelWorkerLocationInfo *>(
+            local_sender_connection_info_addresses[i]);
+        new (&local_sender_channel_worker_interfaces[i]) tt::fabric::EdmChannelWorkerInterface(
+            connection_worker_info_ptr,  // worker_location_info_ptr
+            reinterpret_cast<volatile tt_l1_ptr uint32_t *const>(
+                local_sender_flow_control_semaphores[i]),  // local_semaphore_address
+            reinterpret_cast<volatile tt_l1_ptr uint32_t *const>(connection_live_semaphore_ptr));
+    }
+
+    if (has_downstream_edm_buffer_connection) {
+        downstream_edm_noc_interface.open();
+    }
+
+    if constexpr (is_handshake_sender) {
+        erisc::datamover::handshake::sender_side_finish(handshake_addr);
+    } else {
+        erisc::datamover::handshake::receiver_side_finish(handshake_addr);
+    }
+
+    //////////////////////////////
+    //////////////////////////////
+    // MAIN LOOP
+    //////////////////////////////
+    //////////////////////////////
+    run_fabric_edm_main_loop(
+        local_receiver_channel,
+        local_sender_channels,
+        local_sender_channel_worker_interfaces,
+        downstream_edm_noc_interface,
+        remote_sender_channels,
+        remote_receiver_channel,
+        termination_signal_ptr);
+
+    if (got_graceful_termination_signal(termination_signal_ptr)) {
+        ASSERT(false);
+    } else {
+        // So long suckers!
+    }
+
+    WAYPOINT("DONE");
+}
diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp
new file mode 100644
index 00000000000..90f2c692aa2
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp
@@ -0,0 +1,225 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include <array>
+#include <cstddef>
+#include <cstdint>
+
+#include "debug/dprint.h"
+#include "tt_metal/hw/inc/dataflow_api.h"
+#include "tt_metal/hw/inc/ethernet/tunneling.h"
+#include "tt_metal/hw/inc/risc_attribs.h"
+#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp"
+#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_types.hpp"
+#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp"
+
+namespace tt::fabric {
+
+// Increments val and wraps to 0 if it reaches LIMIT
+template <typename T, size_t LIMIT>
+auto wrap_increment(T val) -> T {
+    static_assert(LIMIT != 0, "wrap_increment called with limit of 0; it must be greater than 0");
+    if constexpr (LIMIT == 1) {
+        return val;
+    } else if constexpr (LIMIT == 2) {
+        return 1 - val;
+    } else if constexpr ((LIMIT > 0) && (LIMIT & (LIMIT - 1)) == 0) {
+        // LIMIT is a power of two, so the wrap can be done with a mask instead of a compare
+        return (val + 1) & (LIMIT - 1);
+    } else {
+        return (val == LIMIT - 1) ? 0 : val + 1;
+    }
+}
+
+// Runtime-limit overload of the same wrap-around increment
+template <typename T>
+FORCE_INLINE auto wrap_increment(T val, size_t max) {
+    return (val == max - 1) ? 0 : val + 1;
+}
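+// [Editor's illustrative sketch - not part of this change.] Expected values for the wrap-around increment,
+// including the power-of-two fast path (these would be valid static_asserts if the helpers above were marked
+// constexpr):
+//
+//   wrap_increment<uint8_t, 4>(2) == 3   // plain increment
+//   wrap_increment<uint8_t, 4>(3) == 0   // wraps via (val + 1) & 0b11
+//   wrap_increment<uint8_t, 3>(2) == 0   // non-power-of-two limit falls back to the compare
+//   wrap_increment<uint8_t, 2>(1) == 0   // LIMIT == 2 uses 1 - val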
+template <size_t NUM_BUFFERS>
+class EthChannelBuffer final {
+   public:
+    // The channel structure is as follows:
+    //              &header->  |----------------| channel_base_address
+    //                         |    header      |
+    //             &payload->  |----------------|
+    //                         |                |
+    //                         |    payload     |
+    //                         |                |
+    //        &channel_sync->  |----------------|
+    //                         |  channel_sync  |
+    //                         ------------------
+    EthChannelBuffer() :
+        buffer_size_in_bytes(0), eth_transaction_ack_word_addr(0), max_eth_payload_size_in_bytes(0) {}
+
+    /*
+     * Expected that *buffer_index_ptr is initialized outside of this object
+     */
+    EthChannelBuffer(
+        size_t channel_base_address,
+        size_t buffer_size_bytes,
+        size_t header_size_bytes,
+        size_t eth_transaction_ack_word_addr,  // Assume for the receiver channel that this address points to a
+                                               // chunk of memory that can fit 2 eth_channel_syncs for the ack
+        uint8_t channel_id) :
+        buffer_size_in_bytes(buffer_size_bytes),
+        eth_transaction_ack_word_addr(eth_transaction_ack_word_addr),
+        max_eth_payload_size_in_bytes(buffer_size_in_bytes + sizeof(eth_channel_sync_t)),
+        buff_idx(0),
+        channel_id(channel_id) {
+        for (uint8_t i = 0; i < NUM_BUFFERS; i++) {
+            this->buffer_addresses[i] = channel_base_address + i * this->max_eth_payload_size_in_bytes;
+
+            uint32_t channel_sync_addr = this->buffer_addresses[i] + buffer_size_in_bytes;
+            auto channel_sync_ptr = reinterpret_cast<volatile eth_channel_sync_t *>(channel_sync_addr);
+
+            channel_bytes_sent_addresses[i] =
+                reinterpret_cast<volatile tt_l1_ptr size_t *>(&(channel_sync_ptr->bytes_sent));
+            channel_bytes_acked_addresses[i] =
+                reinterpret_cast<volatile tt_l1_ptr size_t *>(&(channel_sync_ptr->receiver_ack));
+            channel_src_id_addresses[i] = reinterpret_cast<volatile tt_l1_ptr size_t *>(&(channel_sync_ptr->src_id));
+
+            ASSERT((uint32_t)channel_bytes_acked_addresses[i] != (uint32_t)(channel_bytes_sent_addresses[i]));
+            *(channel_bytes_sent_addresses[i]) = 0;
+            *(channel_bytes_acked_addresses[i]) = 0;
+            // Note: we don't need to overwrite the `channel_src_id_addresses`, except perhaps for
+            // debug purposes, where we may wish to tag them with a special value
+        }
+    }
+
+    [[nodiscard]] FORCE_INLINE size_t get_current_buffer_address() const {
+        return this->buffer_addresses[this->buffer_index()];
+    }
+
+    [[nodiscard]] FORCE_INLINE volatile PacketHeader *get_current_packet_header() const {
+        return reinterpret_cast<volatile PacketHeader *>(this->buffer_addresses[this->buffer_index()]);
+    }
+
+    [[nodiscard]] FORCE_INLINE size_t get_current_payload_size() const {
+        return get_current_packet_header()->get_payload_size_including_header();
+    }
+
+    [[nodiscard]] FORCE_INLINE size_t get_current_payload_plus_channel_sync_size() const {
+        return get_current_packet_header()->get_payload_size_including_header() + sizeof(eth_channel_sync_t);
+    }
+
+    // TODO: Split off into two separate functions:
+    //   volatile tt_l1_ptr size_t *get_current_bytes_sent_ptr() const
+    //   size_t get_current_bytes_sent_address() const
+    [[nodiscard]] FORCE_INLINE volatile tt_l1_ptr size_t *get_current_bytes_sent_address() const {
+        return this->channel_bytes_sent_addresses[this->buffer_index()];
+    }
+
+    [[nodiscard]] FORCE_INLINE volatile tt_l1_ptr size_t *get_current_bytes_acked_address() const {
+        return this->channel_bytes_acked_addresses[this->buffer_index()];
+    }
+
+    [[nodiscard]] FORCE_INLINE volatile tt_l1_ptr size_t *get_current_src_id_address() const {
+        return this->channel_src_id_addresses[this->buffer_index()];
+    }
+
+    [[nodiscard]] FORCE_INLINE size_t get_channel_buffer_max_size_in_bytes() const {
+        return this->buffer_size_in_bytes;
+    }
+
+    // Doesn't return the message size, only the maximum eth payload size
+    [[nodiscard]] FORCE_INLINE size_t get_current_max_eth_payload_size() const {
+        return this->max_eth_payload_size_in_bytes;
+    }
+
+    [[nodiscard]] FORCE_INLINE size_t get_id() const { return this->channel_id; }
+
+    [[nodiscard]] FORCE_INLINE bool eth_is_receiver_channel_send_done() const {
+        return *(this->get_current_bytes_sent_address()) == 0;
+    }
+    [[nodiscard]] FORCE_INLINE bool eth_bytes_are_available_on_channel() const {
+        return *(this->get_current_bytes_sent_address()) != 0;
+    }
+    [[nodiscard]] FORCE_INLINE bool eth_is_receiver_channel_send_acked() const {
+        return *(this->get_current_bytes_acked_address()) != 0;
+    }
+    FORCE_INLINE void eth_clear_sender_channel_ack() const {
+        *(this->channel_bytes_acked_addresses[this->buffer_index()]) = 0;
+    }
+
+    [[nodiscard]] FORCE_INLINE size_t get_eth_transaction_ack_word_addr() const {
+        return this->eth_transaction_ack_word_addr;
+    }
+
+    FORCE_INLINE void advance_buffer_index() {
+        this->buff_idx = wrap_increment<decltype(this->buff_idx), NUM_BUFFERS>(this->buff_idx);
+    }
+
+   private:
+    FORCE_INLINE auto buffer_index() const {
+        ASSERT(this->buff_idx < NUM_BUFFERS);
+        return buff_idx;
+    }
+
+    std::array<size_t, NUM_BUFFERS> buffer_addresses;
+    std::array<volatile tt_l1_ptr size_t *, NUM_BUFFERS> channel_bytes_sent_addresses;
+    std::array<volatile tt_l1_ptr size_t *, NUM_BUFFERS> channel_bytes_acked_addresses;
+    std::array<volatile tt_l1_ptr size_t *, NUM_BUFFERS> channel_src_id_addresses;
+
+    // header + payload regions only
+    const std::size_t buffer_size_in_bytes;
+    const std::size_t eth_transaction_ack_word_addr;
+    // Includes header + payload + channel_sync
+    const std::size_t max_eth_payload_size_in_bytes;
+    uint8_t buff_idx;
+    uint8_t channel_id;
+};
+
+struct EdmChannelWorkerInterface {
+    EdmChannelWorkerInterface() :
+        worker_location_info_ptr(nullptr), local_semaphore_address(nullptr), connection_live_semaphore(nullptr) {}
+    EdmChannelWorkerInterface(
+        // TODO: PERF: See if we can make this non-volatile and then only
+        // mark it volatile when we know we need to reload it (i.e. after we receive a
+        // "done" message from sender)
+        // Have a volatile update function that only triggers after reading the volatile
+        // completion field so that way we don't have to do a volatile read for every
+        // packet... 
Then we'll also be able to cache the uint64_t addr of the worker + // semaphore directly (saving on regenerating it each time) + volatile EDMChannelWorkerLocationInfo *worker_location_info_ptr, + volatile tt_l1_ptr uint32_t *const local_semaphore_address, + volatile tt_l1_ptr uint32_t *const connection_live_semaphore) : + worker_location_info_ptr(worker_location_info_ptr), + local_semaphore_address(local_semaphore_address), + connection_live_semaphore(connection_live_semaphore) {} + + // Flow control methods + // + [[nodiscard]] FORCE_INLINE auto local_semaphore_value() const { return *local_semaphore_address; } + + [[nodiscard]] FORCE_INLINE bool has_payload() { return *local_semaphore_address != 0; } + + FORCE_INLINE void clear_local_semaphore() { noc_semaphore_set(local_semaphore_address, 0); } + + [[nodiscard]] FORCE_INLINE uint32_t get_worker_semaphore_address() const { + return worker_location_info_ptr->worker_semaphore_address; + } + + void increment_worker_semaphore() const { + auto const &worker_info = *worker_location_info_ptr; + uint64_t worker_semaphore_address = get_noc_addr( + (uint32_t)worker_info.worker_xy.x, (uint32_t)worker_info.worker_xy.y, worker_info.worker_semaphore_address); + + DPRINT << "EDMS notif @ " << (uint64_t)worker_semaphore_address << "\n"; + noc_semaphore_inc(worker_semaphore_address, 1); + } + + // Connection management methods + // + FORCE_INLINE void teardown_connection() const { increment_worker_semaphore(); } + + [[nodiscard]] FORCE_INLINE bool has_worker_teardown_request() const { return *connection_live_semaphore == 0; } + + [[nodiscard]] FORCE_INLINE bool connection_is_live() const { return *connection_live_semaphore == 1; } + + volatile EDMChannelWorkerLocationInfo *worker_location_info_ptr; + volatile tt_l1_ptr uint32_t *const local_semaphore_address; + volatile tt_l1_ptr uint32_t *const connection_live_semaphore; +}; + +} // namespace tt::fabric
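+
+// [Editor's illustrative sketch - not part of this change.] Worker-side view of the connection lifecycle implied by
+// EdmChannelWorkerInterface (the worker-side adapter is tt::fabric::WorkerToFabricEdmSender; the description of its
+// internals here is an assumption):
+//
+//   - open():  the worker writes its EDMChannelWorkerLocationInfo (worker_xy plus worker_semaphore_address) into
+//              the EDM's connection info slot, then sets *connection_live_semaphore to 1; the EDM observes
+//              connection_is_live() and starts treating *local_semaphore_address as the "payloads ready" count.
+//   - send:    the worker bumps the EDM's local semaphore after staging each packet, so has_payload() turns true.
+//   - close(): the worker sets *connection_live_semaphore to 0; the EDM sees has_worker_teardown_request(), calls
+//              teardown_connection(), which increments the worker's semaphore once as a final "disconnected" signal.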