diff --git a/ci/docker/conda-cpp.dockerfile b/ci/docker/conda-cpp.dockerfile index dff1f2224809a..c2bcc8ccffe9c 100644 --- a/ci/docker/conda-cpp.dockerfile +++ b/ci/docker/conda-cpp.dockerfile @@ -37,8 +37,6 @@ RUN mamba install -q -y \ doxygen \ libnuma \ python=${python} \ - ucx \ - ucx-proc=*=cpu \ valgrind && \ mamba clean --all diff --git a/ci/scripts/cpp_build.sh b/ci/scripts/cpp_build.sh index bc2bba915f73a..661414539358f 100755 --- a/ci/scripts/cpp_build.sh +++ b/ci/scripts/cpp_build.sh @@ -184,7 +184,6 @@ else -DARROW_WITH_OPENTELEMETRY=${ARROW_WITH_OPENTELEMETRY:-OFF} \ -DARROW_WITH_MUSL=${ARROW_WITH_MUSL:-OFF} \ -DARROW_WITH_SNAPPY=${ARROW_WITH_SNAPPY:-OFF} \ - -DARROW_WITH_UCX=${ARROW_WITH_UCX:-OFF} \ -DARROW_WITH_UTF8PROC=${ARROW_WITH_UTF8PROC:-ON} \ -DARROW_WITH_ZLIB=${ARROW_WITH_ZLIB:-OFF} \ -DARROW_WITH_ZSTD=${ARROW_WITH_ZSTD:-OFF} \ diff --git a/cpp/cmake_modules/DefineOptions.cmake b/cpp/cmake_modules/DefineOptions.cmake index 41466a1c22404..16a410bed6fa5 100644 --- a/cpp/cmake_modules/DefineOptions.cmake +++ b/cpp/cmake_modules/DefineOptions.cmake @@ -542,10 +542,6 @@ takes precedence over ccache if a storage backend is configured" ON) define_option(ARROW_WITH_ZLIB "Build with zlib compression" OFF) define_option(ARROW_WITH_ZSTD "Build with zstd compression" OFF) - define_option(ARROW_WITH_UCX - "Build with UCX transport for Arrow Flight;(only used if ARROW_FLIGHT is ON)" - OFF) - define_option(ARROW_WITH_UTF8PROC "Build with support for Unicode properties using the utf8proc library;(only used if ARROW_COMPUTE is ON or ARROW_GANDIVA is ON)" ON) diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 5b89a831ff7fe..2c5eefea03afb 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -67,7 +67,6 @@ set(ARROW_THIRDPARTY_DEPENDENCIES Snappy Substrait Thrift - ucx utf8proc xsimd ZLIB @@ -218,8 +217,6 @@ macro(build_dependency DEPENDENCY_NAME) build_substrait() elseif("${DEPENDENCY_NAME}" STREQUAL "Thrift") build_thrift() - elseif("${DEPENDENCY_NAME}" STREQUAL "ucx") - build_ucx() elseif("${DEPENDENCY_NAME}" STREQUAL "utf8proc") build_utf8proc() elseif("${DEPENDENCY_NAME}" STREQUAL "xsimd") @@ -830,13 +827,6 @@ else() "${THIRDPARTY_MIRROR_URL}/thrift-${ARROW_THRIFT_BUILD_VERSION}.tar.gz") endif() -if(DEFINED ENV{ARROW_UCX_URL}) - set(ARROW_UCX_SOURCE_URL "$ENV{ARROW_UCX_URL}") -else() - set_urls(ARROW_UCX_SOURCE_URL - "https://github.com/openucx/ucx/archive/v${ARROW_UCX_BUILD_VERSION}.tar.gz") -endif() - if(DEFINED ENV{ARROW_UTF8PROC_URL}) set(ARROW_UTF8PROC_SOURCE_URL "$ENV{ARROW_UTF8PROC_URL}") else() @@ -5404,85 +5394,5 @@ if(ARROW_WITH_AZURE_SDK) set(AZURE_SDK_LINK_LIBRARIES Azure::azure-storage-files-datalake Azure::azure-storage-blobs Azure::azure-identity) endif() -# ---------------------------------------------------------------------- -# ucx - communication framework for modern, high-bandwidth and low-latency networks - -macro(build_ucx) - message(STATUS "Building UCX from source") - - set(UCX_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/ucx_ep-install") - - # link with static ucx libraries leads to test failures, use shared libs instead - set(UCX_SHARED_LIB_UCP "${UCX_PREFIX}/lib/libucp${CMAKE_SHARED_LIBRARY_SUFFIX}") - set(UCX_SHARED_LIB_UCT "${UCX_PREFIX}/lib/libuct${CMAKE_SHARED_LIBRARY_SUFFIX}") - set(UCX_SHARED_LIB_UCS "${UCX_PREFIX}/lib/libucs${CMAKE_SHARED_LIBRARY_SUFFIX}") - set(UCX_SHARED_LIB_UCM "${UCX_PREFIX}/lib/libucm${CMAKE_SHARED_LIBRARY_SUFFIX}") - - set(UCX_CONFIGURE_COMMAND ./autogen.sh COMMAND ./configure) - list(APPEND - UCX_CONFIGURE_COMMAND - "CC=${CMAKE_C_COMPILER}" - "CXX=${CMAKE_CXX_COMPILER}" - "CFLAGS=${EP_C_FLAGS}" - "CXXFLAGS=${EP_CXX_FLAGS}" - "--prefix=${UCX_PREFIX}" - "--enable-mt" - "--enable-shared") - if(${UPPERCASE_BUILD_TYPE} STREQUAL "DEBUG") - list(APPEND - UCX_CONFIGURE_COMMAND - "--enable-profiling" - "--enable-frame-pointer" - "--enable-stats" - "--enable-fault-injection" - "--enable-debug-data") - else() - list(APPEND - UCX_CONFIGURE_COMMAND - "--disable-logging" - "--disable-debug" - "--disable-assertions" - "--disable-params-check") - endif() - set(UCX_BUILD_COMMAND ${MAKE} ${MAKE_BUILD_ARGS}) - externalproject_add(ucx_ep - ${EP_COMMON_OPTIONS} - URL ${ARROW_UCX_SOURCE_URL} - URL_HASH "SHA256=${ARROW_UCX_BUILD_SHA256_CHECKSUM}" - CONFIGURE_COMMAND ${UCX_CONFIGURE_COMMAND} - BUILD_IN_SOURCE 1 - BUILD_COMMAND ${UCX_BUILD_COMMAND} - BUILD_BYPRODUCTS "${UCX_SHARED_LIB_UCP}" "${UCX_SHARED_LIB_UCT}" - "${UCX_SHARED_LIB_UCS}" "${UCX_SHARED_LIB_UCM}" - INSTALL_COMMAND ${MAKE} install) - - # ucx cmake module sets UCX_INCLUDE_DIRS - set(UCX_INCLUDE_DIRS "${UCX_PREFIX}/include") - file(MAKE_DIRECTORY "${UCX_INCLUDE_DIRS}") - - add_library(ucx::ucp SHARED IMPORTED) - set_target_properties(ucx::ucp PROPERTIES IMPORTED_LOCATION "${UCX_SHARED_LIB_UCP}") - add_library(ucx::uct SHARED IMPORTED) - set_target_properties(ucx::uct PROPERTIES IMPORTED_LOCATION "${UCX_SHARED_LIB_UCT}") - add_library(ucx::ucs SHARED IMPORTED) - set_target_properties(ucx::ucs PROPERTIES IMPORTED_LOCATION "${UCX_SHARED_LIB_UCS}") - - add_dependencies(ucx::ucp ucx_ep) - add_dependencies(ucx::uct ucx_ep) - add_dependencies(ucx::ucs ucx_ep) -endmacro() - -if(ARROW_WITH_UCX) - resolve_dependency(ucx - ARROW_CMAKE_PACKAGE_NAME - ArrowFlight - ARROW_PC_PACKAGE_NAME - arrow-flight - PC_PACKAGE_NAMES - ucx) - add_library(ucx::ucx INTERFACE IMPORTED) - target_include_directories(ucx::ucx INTERFACE "${UCX_INCLUDE_DIRS}") - target_link_libraries(ucx::ucx INTERFACE ucx::ucp ucx::uct ucx::ucs) -endif() message(STATUS "All bundled static libraries: ${ARROW_BUNDLED_STATIC_LIBS}") diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index 43ac48b87678e..ae0fb52dac411 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -329,21 +329,8 @@ if(ARROW_BUILD_BENCHMARKS) add_dependencies(arrow_flight arrow-flight-benchmark) - if(ARROW_WITH_UCX) - if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static") - target_link_libraries(arrow-flight-benchmark arrow_flight_transport_ucx_static) - target_link_libraries(arrow-flight-perf-server arrow_flight_transport_ucx_static) - else() - target_link_libraries(arrow-flight-benchmark arrow_flight_transport_ucx_shared) - target_link_libraries(arrow-flight-perf-server arrow_flight_transport_ucx_shared) - endif() - endif() endif(ARROW_BUILD_BENCHMARKS) -if(ARROW_WITH_UCX) - add_subdirectory(transport/ucx) -endif() - if(ARROW_FLIGHT_SQL) add_subdirectory(sql) diff --git a/cpp/src/arrow/flight/flight_benchmark.cc b/cpp/src/arrow/flight/flight_benchmark.cc index 057ef15c3c7ae..bd497c2418c6e 100644 --- a/cpp/src/arrow/flight/flight_benchmark.cc +++ b/cpp/src/arrow/flight/flight_benchmark.cc @@ -43,17 +43,10 @@ #include #include "arrow/gpu/cuda_api.h" #endif -#ifdef ARROW_WITH_UCX -#include "arrow/flight/transport/ucx/ucx.h" -#endif DEFINE_bool(cuda, false, "Allocate results in CUDA memory"); DEFINE_string(transport, "grpc", - "The network transport to use. Supported: \"grpc\" (default)" -#ifdef ARROW_WITH_UCX - ", \"ucx\"" -#endif // ARROW_WITH_UCX - "."); + "The network transport to use. Supported: \"grpc\" (default)."); DEFINE_string(server_host, "", "An existing performance server to benchmark against (leave blank to spawn " "one automatically)"); @@ -506,21 +499,6 @@ int main(int argc, char** argv) { options.disable_server_verification = true; } } - } else if (FLAGS_transport == "ucx") { -#ifdef ARROW_WITH_UCX - arrow::flight::transport::ucx::InitializeFlightUcx(); - if (FLAGS_test_unix || !FLAGS_server_unix.empty()) { - std::cerr << "Transport does not support domain sockets: " << FLAGS_transport - << std::endl; - return EXIT_FAILURE; - } - ARROW_CHECK_OK(arrow::flight::Location::Parse("ucx://" + FLAGS_server_host + ":" + - std::to_string(FLAGS_server_port)) - .Value(&location)); -#else - std::cerr << "Not built with transport: " << FLAGS_transport << std::endl; - return EXIT_FAILURE; -#endif } else { std::cerr << "Unknown transport: " << FLAGS_transport << std::endl; return EXIT_FAILURE; diff --git a/cpp/src/arrow/flight/perf_server.cc b/cpp/src/arrow/flight/perf_server.cc index 87676da11213d..697b3bcfe7b9f 100644 --- a/cpp/src/arrow/flight/perf_server.cc +++ b/cpp/src/arrow/flight/perf_server.cc @@ -44,17 +44,10 @@ #ifdef ARROW_CUDA #include "arrow/gpu/cuda_api.h" #endif -#ifdef ARROW_WITH_UCX -#include "arrow/flight/transport/ucx/ucx.h" -#endif DEFINE_bool(cuda, false, "Allocate results in CUDA memory"); DEFINE_string(transport, "grpc", - "The network transport to use. Supported: \"grpc\" (default)" -#ifdef ARROW_WITH_UCX - ", \"ucx\"" -#endif // ARROW_WITH_UCX - "."); + "The network transport to use. Supported: \"grpc\" (default)."); DEFINE_string(server_host, "localhost", "Host where the server is running on"); DEFINE_int32(port, 31337, "Server port to listen on"); DEFINE_string(server_unix, "", "Unix socket path where the server is running on"); @@ -280,29 +273,6 @@ int main(int argc, char** argv) { ARROW_CHECK_OK(arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix) .Value(&connect_location)); } - } else if (FLAGS_transport == "ucx") { -#ifdef ARROW_WITH_UCX - arrow::flight::transport::ucx::InitializeFlightUcx(); - if (FLAGS_server_unix.empty()) { - if (!FLAGS_cert_file.empty() || !FLAGS_key_file.empty()) { - std::cerr << "Transport does not support TLS: " << FLAGS_transport << std::endl; - return EXIT_FAILURE; - } - ARROW_CHECK_OK(arrow::flight::Location::Parse("ucx://" + FLAGS_server_host + ":" + - std::to_string(FLAGS_port)) - .Value(&bind_location)); - ARROW_CHECK_OK(arrow::flight::Location::Parse("ucx://" + FLAGS_server_host + ":" + - std::to_string(FLAGS_port)) - .Value(&connect_location)); - } else { - std::cerr << "Transport does not support domain sockets: " << FLAGS_transport - << std::endl; - return EXIT_FAILURE; - } -#else - std::cerr << "Not built with transport: " << FLAGS_transport << std::endl; - return EXIT_FAILURE; -#endif } else { std::cerr << "Unknown transport: " << FLAGS_transport << std::endl; return EXIT_FAILURE; diff --git a/cpp/src/arrow/flight/transport/ucx/CMakeLists.txt b/cpp/src/arrow/flight/transport/ucx/CMakeLists.txt deleted file mode 100644 index 23f8850c3d460..0000000000000 --- a/cpp/src/arrow/flight/transport/ucx/CMakeLists.txt +++ /dev/null @@ -1,70 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -add_custom_target(arrow_flight_transport_ucx) -arrow_install_all_headers("arrow/flight/transport/ucx") - -set(ARROW_FLIGHT_TRANSPORT_UCX_SRCS - ucx_client.cc - ucx_server.cc - ucx.cc - ucx_internal.cc - util_internal.cc) - -add_arrow_lib(arrow_flight_transport_ucx - # CMAKE_PACKAGE_NAME - # ArrowFlightTransportUcx - # PKG_CONFIG_NAME - # arrow-flight-transport-ucx - SOURCES - ${ARROW_FLIGHT_TRANSPORT_UCX_SRCS} - PRECOMPILED_HEADERS - "$<$:arrow/flight/pch.h>" - DEPENDENCIES - SHARED_LINK_FLAGS - ${ARROW_VERSION_SCRIPT_FLAGS} # Defined in cpp/arrow/CMakeLists.txt - SHARED_LINK_LIBS - arrow_flight_shared - ucx::ucx - STATIC_LINK_LIBS - arrow_flight_static - ucx::ucx) - -if(ARROW_BUILD_TESTS) - if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static") - set(ARROW_FLIGHT_UCX_TEST_LINK_LIBS - arrow_static - arrow_flight_static - arrow_flight_testing_static - arrow_flight_transport_ucx_static - ucx::ucx - ${ARROW_TEST_LINK_LIBS}) - else() - set(ARROW_FLIGHT_UCX_TEST_LINK_LIBS - arrow_shared - arrow_flight_shared - arrow_flight_testing_shared - arrow_flight_transport_ucx_shared - ucx::ucx - ${ARROW_TEST_LINK_LIBS}) - endif() - add_arrow_test(flight_transport_ucx_test - STATIC_LINK_LIBS - ${ARROW_FLIGHT_UCX_TEST_LINK_LIBS} - LABELS - "arrow_flight") -endif() diff --git a/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc b/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc deleted file mode 100644 index c3481d834f6ea..0000000000000 --- a/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc +++ /dev/null @@ -1,378 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include -#include - -#include "arrow/array/array_base.h" -#include "arrow/flight/test_definitions.h" -#include "arrow/flight/test_util.h" -#include "arrow/flight/transport/ucx/ucx.h" -#include "arrow/table.h" -#include "arrow/testing/gtest_util.h" -#include "arrow/util/config.h" - -#ifdef UCP_API_VERSION -#error "UCX headers should not be in public API" -#endif - -#include "arrow/flight/transport/ucx/ucx_internal.h" - -#ifdef ARROW_CUDA -#include "arrow/gpu/cuda_api.h" -#endif - -namespace arrow { -namespace flight { - -class UcxEnvironment : public ::testing::Environment { - public: - void SetUp() override { transport::ucx::InitializeFlightUcx(); } -}; - -testing::Environment* const kUcxEnvironment = - testing::AddGlobalTestEnvironment(new UcxEnvironment()); - -//------------------------------------------------------------ -// Common transport tests - -class UcxConnectivityTest : public ConnectivityTest, public ::testing::Test { - protected: - std::string transport() const override { return "ucx"; } - void SetUp() override { SetUpTest(); } - void TearDown() override { TearDownTest(); } -}; -ARROW_FLIGHT_TEST_CONNECTIVITY(UcxConnectivityTest); - -class UcxDataTest : public DataTest, public ::testing::Test { - protected: - std::string transport() const override { return "ucx"; } - void SetUp() override { SetUpTest(); } - void TearDown() override { TearDownTest(); } -}; -ARROW_FLIGHT_TEST_DATA(UcxDataTest); - -class UcxDoPutTest : public DoPutTest, public ::testing::Test { - protected: - std::string transport() const override { return "ucx"; } - void SetUp() override { SetUpTest(); } - void TearDown() override { TearDownTest(); } -}; -ARROW_FLIGHT_TEST_DO_PUT(UcxDoPutTest); - -class UcxAppMetadataTest : public AppMetadataTest, public ::testing::Test { - protected: - std::string transport() const override { return "ucx"; } - void SetUp() override { SetUpTest(); } - void TearDown() override { TearDownTest(); } -}; -ARROW_FLIGHT_TEST_APP_METADATA(UcxAppMetadataTest); - -class UcxIpcOptionsTest : public IpcOptionsTest, public ::testing::Test { - protected: - std::string transport() const override { return "ucx"; } - void SetUp() override { SetUpTest(); } - void TearDown() override { TearDownTest(); } -}; -ARROW_FLIGHT_TEST_IPC_OPTIONS(UcxIpcOptionsTest); - -class UcxCudaDataTest : public CudaDataTest, public ::testing::Test { - protected: - std::string transport() const override { return "ucx"; } - void SetUp() override { SetUpTest(); } - void TearDown() override { TearDownTest(); } -}; -ARROW_FLIGHT_TEST_CUDA_DATA(UcxCudaDataTest); - -class UcxErrorHandlingTest : public ErrorHandlingTest, public ::testing::Test { - protected: - std::string transport() const override { return "ucx"; } - void SetUp() override { SetUpTest(); } - void TearDown() override { TearDownTest(); } - - void TestGetFlightInfoMetadata() { GTEST_SKIP() << "Middleware not implemented"; } -}; -ARROW_FLIGHT_TEST_ERROR_HANDLING(UcxErrorHandlingTest); - -//------------------------------------------------------------ -// UCX internals tests - -constexpr std::initializer_list kStatusCodes = { - StatusCode::OK, - StatusCode::OutOfMemory, - StatusCode::KeyError, - StatusCode::TypeError, - StatusCode::Invalid, - StatusCode::IOError, - StatusCode::CapacityError, - StatusCode::IndexError, - StatusCode::Cancelled, - StatusCode::UnknownError, - StatusCode::NotImplemented, - StatusCode::SerializationError, - StatusCode::RError, - StatusCode::CodeGenError, - StatusCode::ExpressionValidationError, - StatusCode::ExecutionError, - StatusCode::AlreadyExists, -}; - -constexpr std::initializer_list kFlightStatusCodes = { - FlightStatusCode::Internal, FlightStatusCode::TimedOut, - FlightStatusCode::Cancelled, FlightStatusCode::Unauthenticated, - FlightStatusCode::Unauthorized, FlightStatusCode::Unavailable, - FlightStatusCode::Failed, -}; - -class TestStatusDetail : public StatusDetail { - public: - const char* type_id() const override { return "test-status-detail"; } - std::string ToString() const override { return "Custom status detail"; } -}; - -namespace transport { -namespace ucx { - -static constexpr std::initializer_list kFrameTypes = { - FrameType::kHeaders, FrameType::kBuffer, FrameType::kPayloadHeader, - FrameType::kPayloadBody, FrameType::kDisconnect, -}; - -TEST(FrameHeader, Basics) { - for (const auto frame_type : kFrameTypes) { - FrameHeader header; - ASSERT_OK(header.Set(frame_type, /*counter=*/42, /*body_size=*/65535)); - if (frame_type == FrameType::kDisconnect) { - ASSERT_RAISES(Cancelled, Frame::ParseHeader(header.data(), header.size())); - } else { - ASSERT_OK_AND_ASSIGN(auto frame, Frame::ParseHeader(header.data(), header.size())); - ASSERT_EQ(frame->type, frame_type); - ASSERT_EQ(frame->counter, 42); - ASSERT_EQ(frame->size, 65535); - } - } -} - -TEST(FrameHeader, FrameType) { - for (const auto frame_type : kFrameTypes) { - ASSERT_LE(static_cast(frame_type), static_cast(FrameType::kMaxFrameType)); - } -} - -TEST(HeadersFrame, Parse) { - const char* data = - ("\x00\x00\x00\x02\x00\x00\x00\x05\x00\x00\x00\x03x-foobar" - "\x00\x00\x00\x05\x00\x00\x00\x01x-bin\x01"); - constexpr int64_t size = 34; - - { - std::unique_ptr buffer( - new Buffer(reinterpret_cast(data), size)); - ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Parse(std::move(buffer))); - ASSERT_OK_AND_ASSIGN(auto foo, headers.Get("x-foo")); - ASSERT_EQ(foo, "bar"); - ASSERT_OK_AND_ASSIGN(auto bin, headers.Get("x-bin")); - ASSERT_EQ(bin, "\x01"); - } - { - std::unique_ptr buffer(new Buffer(reinterpret_cast(data), 3)); - EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, - ::testing::HasSubstr("expected number of headers"), - HeadersFrame::Parse(std::move(buffer))); - } - { - std::unique_ptr buffer(new Buffer(reinterpret_cast(data), 7)); - EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, - ::testing::HasSubstr("expected length of key 1"), - HeadersFrame::Parse(std::move(buffer))); - } - { - std::unique_ptr buffer( - new Buffer(reinterpret_cast(data), 10)); - EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, - ::testing::HasSubstr("expected length of value 1"), - HeadersFrame::Parse(std::move(buffer))); - } - { - std::unique_ptr buffer( - new Buffer(reinterpret_cast(data), 12)); - EXPECT_RAISES_WITH_MESSAGE_THAT( - Invalid, - ::testing::HasSubstr("expected key 1 to have length 5, but only 0 bytes remain"), - HeadersFrame::Parse(std::move(buffer))); - } - { - std::unique_ptr buffer( - new Buffer(reinterpret_cast(data), 17)); - EXPECT_RAISES_WITH_MESSAGE_THAT( - Invalid, - ::testing::HasSubstr( - "expected value 1 to have length 3, but only 0 bytes remain"), - HeadersFrame::Parse(std::move(buffer))); - } -} -} // namespace ucx -} // namespace transport - -//------------------------------------------------------------ -// Ad-hoc UCX-specific tests - -class SimpleTestServer : public FlightServerBase { - public: - Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request, - std::unique_ptr* info) override { - if (request.path.size() > 0 && request.path[0] == "error") { - return status_; - } - auto examples = ExampleFlightInfo(); - info->reset(new FlightInfo(examples[0])); - return Status::OK(); - } - - Status DoGet(const ServerCallContext& context, const Ticket& request, - std::unique_ptr* data_stream) override { - RecordBatchVector batches; - RETURN_NOT_OK(ExampleIntBatches(&batches)); - auto batch_reader = std::make_shared(batches[0]->schema(), batches); - *data_stream = std::make_unique(batch_reader); - return Status::OK(); - } - - void set_error_status(Status st) { status_ = std::move(st); } - - private: - Status status_; -}; - -class TestUcx : public ::testing::Test { - public: - void SetUp() { - ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme("ucx", "127.0.0.1", 0)); - ASSERT_OK(MakeServer( - location, &server_, &client_, - [](FlightServerOptions* options) { return Status::OK(); }, - [](FlightClientOptions* options) { return Status::OK(); })); - } - - void TearDown() { - ASSERT_OK(client_->Close()); - ASSERT_OK(server_->Shutdown()); - } - - protected: - std::unique_ptr client_; - std::unique_ptr server_; -}; - -TEST_F(TestUcx, GetFlightInfo) { - auto descriptor = FlightDescriptor::Path({"foo", "bar"}); - std::unique_ptr info; - ASSERT_OK_AND_ASSIGN(info, client_->GetFlightInfo(descriptor)); - // Test that we can reuse the connection - ASSERT_OK_AND_ASSIGN(info, client_->GetFlightInfo(descriptor)); -} - -TEST_F(TestUcx, SequentialClients) { - ASSERT_OK_AND_ASSIGN( - auto client2, - FlightClient::Connect(server_->location(), FlightClientOptions::Defaults())); - - Ticket ticket{"a"}; - - ASSERT_OK_AND_ASSIGN(auto stream1, client_->DoGet(ticket)); - ASSERT_OK_AND_ASSIGN(auto table1, stream1->ToTable()); - - ASSERT_OK_AND_ASSIGN(auto stream2, client2->DoGet(ticket)); - ASSERT_OK_AND_ASSIGN(auto table2, stream2->ToTable()); - - AssertTablesEqual(*table1, *table2); -} - -TEST_F(TestUcx, ConcurrentClients) { - ASSERT_OK_AND_ASSIGN( - auto client2, - FlightClient::Connect(server_->location(), FlightClientOptions::Defaults())); - - Ticket ticket{"a"}; - - ASSERT_OK_AND_ASSIGN(auto stream1, client_->DoGet(ticket)); - ASSERT_OK_AND_ASSIGN(auto stream2, client2->DoGet(ticket)); - - ASSERT_OK_AND_ASSIGN(auto table1, stream1->ToTable()); - ASSERT_OK_AND_ASSIGN(auto table2, stream2->ToTable()); - - AssertTablesEqual(*table1, *table2); -} - -TEST_F(TestUcx, Errors) { - auto descriptor = FlightDescriptor::Path({"error", "bar"}); - auto* server = reinterpret_cast(server_.get()); - for (const auto code : kStatusCodes) { - if (code == StatusCode::OK) continue; - - Status expected(code, "Error message"); - server->set_error_status(expected); - Status actual = client_->GetFlightInfo(descriptor).status(); - ASSERT_EQ(actual.code(), expected.code()) << actual.ToString(); - ASSERT_THAT(actual.message(), ::testing::HasSubstr("Error message")) - << actual.ToString(); - - // Attach a generic status detail - { - auto detail = std::make_shared(); - server->set_error_status(Status(code, "foo", detail)); - Status expected(code, "foo", - std::make_shared(FlightStatusCode::Internal, - detail->ToString())); - Status actual = client_->GetFlightInfo(descriptor).status(); - ASSERT_EQ(actual.code(), expected.code()) << actual.ToString(); - ASSERT_THAT(actual.message(), ::testing::HasSubstr("foo")) << actual.ToString(); - ASSERT_THAT(actual.message(), ::testing::HasSubstr("Custom status detail")) - << actual.ToString(); - } - - // Attach a Flight status detail - for (const auto flight_code : kFlightStatusCodes) { - Status expected(code, "Error message", - std::make_shared(flight_code, "extra")); - server->set_error_status(expected); - Status actual = client_->GetFlightInfo(descriptor).status(); - ASSERT_EQ(actual.code(), expected.code()) << actual.ToString(); - ASSERT_THAT(actual.message(), ::testing::HasSubstr("Error message")) - << actual.ToString(); - } - } -} - -TEST(TestUcxIpV6, DISABLED_IpV6Port) { - // Also, disabled in CI as machines lack an IPv6 interface - ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme("ucx", "[::1]", 0)); - - std::unique_ptr server(new SimpleTestServer()); - FlightServerOptions server_options(location); - ASSERT_OK(server->Init(server_options)); - - FlightClientOptions client_options = FlightClientOptions::Defaults(); - ASSERT_OK_AND_ASSIGN(auto client, - FlightClient::Connect(server->location(), client_options)); - - auto descriptor = FlightDescriptor::Path({"foo", "bar"}); - ASSERT_OK_AND_ASSIGN(auto info, client->GetFlightInfo(descriptor)); -} - -} // namespace flight -} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/ucx/ucx.cc b/cpp/src/arrow/flight/transport/ucx/ucx.cc deleted file mode 100644 index 0e3daf6021348..0000000000000 --- a/cpp/src/arrow/flight/transport/ucx/ucx.cc +++ /dev/null @@ -1,45 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/flight/transport/ucx/ucx.h" - -#include - -#include "arrow/flight/transport.h" -#include "arrow/flight/transport/ucx/ucx_internal.h" -#include "arrow/flight/transport_server.h" -#include "arrow/util/logging.h" - -namespace arrow { -namespace flight { -namespace transport { -namespace ucx { - -namespace { -std::once_flag kInitializeOnce; -} -void InitializeFlightUcx() { - std::call_once(kInitializeOnce, []() { - auto* registry = flight::internal::GetDefaultTransportRegistry(); - DCHECK_OK(registry->RegisterClient("ucx", MakeUcxClientImpl)); - DCHECK_OK(registry->RegisterServer("ucx", MakeUcxServerImpl)); - }); -} -} // namespace ucx -} // namespace transport -} // namespace flight -} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/ucx/ucx.h b/cpp/src/arrow/flight/transport/ucx/ucx.h deleted file mode 100644 index dda2c83035c6d..0000000000000 --- a/cpp/src/arrow/flight/transport/ucx/ucx.h +++ /dev/null @@ -1,35 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// Experimental UCX-based transport for Flight. - -#pragma once - -#include "arrow/flight/visibility.h" - -namespace arrow { -namespace flight { -namespace transport { -namespace ucx { - -ARROW_FLIGHT_EXPORT -void InitializeFlightUcx(); - -} // namespace ucx -} // namespace transport -} // namespace flight -} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc deleted file mode 100644 index 946ac2d176203..0000000000000 --- a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc +++ /dev/null @@ -1,769 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -/// The client-side implementation of a UCX-based transport for -/// Flight. -/// -/// Each UCX driver is used to support one call at a time. This gives -/// the greatest throughput for data plane methods, but is relatively -/// expensive in terms of other resources, both for the server and the -/// client. (UCX drivers have multiple threading modes: single-thread -/// access, serialized access, and multi-thread access. Testing found -/// that multi-thread access incurred high synchronization costs.) -/// Hence, for concurrent calls in a single client, we must maintain -/// multiple drivers, and so unlike gRPC, there is no real difference -/// between using one client concurrently and using multiple -/// independent clients. - -#include "arrow/flight/transport/ucx/ucx_internal.h" - -#include -#include -#include -#include -#include - -#include -#include - -#include "arrow/buffer.h" -#include "arrow/flight/client.h" -#include "arrow/flight/transport.h" -#include "arrow/flight/transport/ucx/util_internal.h" -#include "arrow/result.h" -#include "arrow/status.h" -#include "arrow/util/logging.h" -#include "arrow/util/uri.h" - -namespace arrow { -namespace flight { -namespace transport { -namespace ucx { - -namespace { -class UcxClientImpl; - -Status MergeStatuses(Status server_status, Status transport_status) { - if (server_status.ok()) { - if (transport_status.ok()) return server_status; - return transport_status; - } else if (transport_status.ok()) { - return server_status; - } - return Status::FromDetailAndArgs(server_status.code(), server_status.detail(), - server_status.message(), - ". Transport context: ", transport_status.ToString()); -} - -/// \brief An individual connection to the server. -class ClientConnection { - public: - ClientConnection() = default; - ARROW_DISALLOW_COPY_AND_ASSIGN(ClientConnection); - ARROW_DEFAULT_MOVE_AND_ASSIGN(ClientConnection); - ~ClientConnection() { DCHECK(!driver_) << "Connection was not closed!"; } - - Status Init(std::shared_ptr ucp_context, const arrow::util::Uri& uri) { - auto status = InitImpl(std::move(ucp_context), uri); - // Clean up after-the-fact if we fail to initialize - if (!status.ok()) { - if (driver_) { - status = MergeStatuses(std::move(status), driver_->Close()); - driver_.reset(); - remote_endpoint_ = nullptr; - } - if (ucp_worker_) ucp_worker_.reset(); - } - return status; - } - - Status InitImpl(std::shared_ptr ucp_context, const arrow::util::Uri& uri) { - { - ucs_status_t status; - ucp_worker_params_t worker_params; - std::memset(&worker_params, 0, sizeof(worker_params)); - worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; - worker_params.thread_mode = UCS_THREAD_MODE_MULTI; - - ucp_worker_h ucp_worker; - status = ucp_worker_create(ucp_context->get(), &worker_params, &ucp_worker); - RETURN_NOT_OK(FromUcsStatus("ucp_worker_create", status)); - ucp_worker_.reset(new UcpWorker(std::move(ucp_context), ucp_worker)); - } - { - // Create endpoint for remote worker - struct sockaddr_storage connect_addr; - ARROW_ASSIGN_OR_RAISE(auto addrlen, UriToSockaddr(uri, &connect_addr)); - std::string peer; - ARROW_UNUSED(SockaddrToString(connect_addr).Value(&peer)); - ARROW_LOG(DEBUG) << "Connecting to " << peer; - - ucp_ep_params_t params; - params.field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_NAME | - UCP_EP_PARAM_FIELD_SOCK_ADDR; - params.flags = UCP_EP_PARAMS_FLAGS_CLIENT_SERVER; - params.name = "UcxClientImpl"; - params.sockaddr.addr = reinterpret_cast(&connect_addr); - params.sockaddr.addrlen = addrlen; - - auto status = ucp_ep_create(ucp_worker_->get(), ¶ms, &remote_endpoint_); - RETURN_NOT_OK(FromUcsStatus("ucp_ep_create", status)); - } - - driver_ = std::make_unique(ucp_worker_, remote_endpoint_); - ARROW_LOG(DEBUG) << "Connected to " << driver_->peer(); - - { - // Set up Active Message (AM) handler - ucp_am_handler_param_t handler_params; - handler_params.field_mask = UCP_AM_HANDLER_PARAM_FIELD_ID | - UCP_AM_HANDLER_PARAM_FIELD_CB | - UCP_AM_HANDLER_PARAM_FIELD_ARG; - handler_params.id = kUcpAmHandlerId; - handler_params.cb = HandleIncomingActiveMessage; - handler_params.arg = driver_.get(); - ucs_status_t status = - ucp_worker_set_am_recv_handler(ucp_worker_->get(), &handler_params); - RETURN_NOT_OK(FromUcsStatus("ucp_worker_set_am_recv_handler", status)); - } - - return Status::OK(); - } - - Status Close() { - if (!driver_) return Status::OK(); - - auto status = driver_->SendFrame(FrameType::kDisconnect, nullptr, 0); - const auto ucs_status = FlightUcxStatusDetail::Unwrap(status); - if (IsIgnorableDisconnectError(ucs_status)) { - status = Status::OK(); - } - status = MergeStatuses(std::move(status), driver_->Close()); - - driver_.reset(); - remote_endpoint_ = nullptr; - ucp_worker_.reset(); - return status; - } - - UcpCallDriver* driver() { - DCHECK(driver_); - return driver_.get(); - } - - private: - static ucs_status_t HandleIncomingActiveMessage(void* self, const void* header, - size_t header_length, void* data, - size_t data_length, - const ucp_am_recv_param_t* param) { - auto* driver = reinterpret_cast(self); - return driver->RecvActiveMessage(header, header_length, data, data_length, param); - } - - std::shared_ptr ucp_worker_; - ucp_ep_h remote_endpoint_; - std::unique_ptr driver_; -}; - -class UcxClientStream : public internal::ClientDataStream { - public: - UcxClientStream(UcxClientImpl* impl, ClientConnection conn) - : impl_(impl), - conn_(std::move(conn)), - driver_(conn_.driver()), - writes_done_(false), - finished_(false) { - DCHECK_NE(impl, nullptr); - DCHECK_NE(conn_.driver(), nullptr); - } - - protected: - Status DoFinish() override; - - std::mutex finish_mutex_; - UcxClientImpl* impl_; - ClientConnection conn_; - UcpCallDriver* driver_; - bool writes_done_; - bool finished_; - Status io_status_; - Status server_status_; -}; - -class GetClientStream : public UcxClientStream { - public: - GetClientStream(UcxClientImpl* impl, ClientConnection conn) - : UcxClientStream(impl, std::move(conn)) { - writes_done_ = true; - } - - bool ReadData(internal::FlightData* data) override { - if (finished_) return false; - - bool success = true; - io_status_ = ReadImpl(data).Value(&success); - - if (!io_status_.ok() || !success) { - finished_ = true; - } - return success; - } - - private: - ::arrow::Result ReadImpl(internal::FlightData* data) { - ARROW_ASSIGN_OR_RAISE(auto frame, driver_->ReadNextFrame()); - - if (frame->type == FrameType::kHeaders) { - // Trailers, stream is over - ARROW_ASSIGN_OR_RAISE(auto headers, HeadersFrame::Parse(std::move(frame->buffer))); - RETURN_NOT_OK(headers.GetStatus(&server_status_)); - return false; - } - - RETURN_NOT_OK(driver_->ExpectFrameType(*frame, FrameType::kPayloadHeader)); - PayloadHeaderFrame payload_header(std::move(frame->buffer)); - RETURN_NOT_OK(payload_header.ToFlightData(data)); - - // DoGet does not support metadata-only messages, so we can always - // assume we have an IPC payload - ARROW_ASSIGN_OR_RAISE(auto message, ipc::Message::Open(data->metadata, nullptr)); - - if (ipc::Message::HasBody(message->type())) { - ARROW_ASSIGN_OR_RAISE(frame, driver_->ReadNextFrame()); - RETURN_NOT_OK(driver_->ExpectFrameType(*frame, FrameType::kPayloadBody)); - data->body = std::move(frame->buffer); - } - return true; - } -}; - -class WriteClientStream : public UcxClientStream { - public: - WriteClientStream(UcxClientImpl* impl, ClientConnection conn) - : UcxClientStream(impl, std::move(conn)) { - std::thread t(&WriteClientStream::DriveWorker, this); - driver_thread_.swap(t); - } - arrow::Result WriteData(const FlightPayload& payload) override { - std::unique_lock guard(driver_mutex_); - if (finished_ || writes_done_) return false; - outgoing_ = driver_->SendFlightPayload(payload); - working_cv_.notify_all(); - completed_cv_.wait(guard, [this] { return outgoing_.is_finished(); }); - - auto status = outgoing_.status(); - outgoing_ = Future<>(); - RETURN_NOT_OK(status); - return true; - } - Status WritesDone() override { - std::unique_lock guard(driver_mutex_); - if (!writes_done_) { - ARROW_ASSIGN_OR_RAISE(auto headers, HeadersFrame::Make({})); - outgoing_ = - driver_->SendFrameAsync(FrameType::kHeaders, std::move(headers).GetBuffer()); - working_cv_.notify_all(); - completed_cv_.wait(guard, [this] { return outgoing_.is_finished(); }); - - writes_done_ = true; - auto status = outgoing_.status(); - outgoing_ = Future<>(); - RETURN_NOT_OK(status); - } - return Status::OK(); - } - - protected: - void JoinThread() { - try { - driver_thread_.join(); - } catch (const std::system_error&) { - // Ignore - } - } - // Flight's API allows concurrent reads/writes, but the UCX driver - // here is single-threaded, so push all UCX work onto a single - // worker thread - void DriveWorker() { - while (true) { - { - std::unique_lock guard(driver_mutex_); - working_cv_.wait(guard, - [this] { return incoming_.is_valid() || outgoing_.is_valid(); }); - } - - while (true) { - std::unique_lock guard(driver_mutex_); - if (!incoming_.is_valid() && !outgoing_.is_valid()) break; - if (incoming_.is_valid() && incoming_.is_finished()) { - if (!incoming_.status().ok()) { - io_status_ = incoming_.status(); - finished_ = true; - } else { - HandleIncomingMessage(*incoming_.result()); - } - incoming_ = Future>(); - completed_cv_.notify_all(); - break; - } - if (outgoing_.is_valid() && outgoing_.is_finished()) { - completed_cv_.notify_all(); - break; - } - driver_->MakeProgress(); - } - if (finished_) return; - } - } - - virtual void HandleIncomingMessage(const std::shared_ptr& frame) {} - - std::mutex driver_mutex_; - std::thread driver_thread_; - std::condition_variable completed_cv_; - std::condition_variable working_cv_; - Future> incoming_; - Future<> outgoing_; -}; - -class PutClientStream : public WriteClientStream { - public: - using WriteClientStream::WriteClientStream; - bool ReadPutMetadata(std::shared_ptr* out) override { - std::unique_lock guard(driver_mutex_); - if (finished_) { - *out = nullptr; - guard.unlock(); - JoinThread(); - return false; - } - next_metadata_ = nullptr; - incoming_ = driver_->ReadFrameAsync(); - working_cv_.notify_all(); - completed_cv_.wait(guard, [this] { return next_metadata_ != nullptr || finished_; }); - - if (finished_) { - *out = nullptr; - guard.unlock(); - JoinThread(); - return false; - } - *out = std::move(next_metadata_); - return true; - } - - private: - void HandleIncomingMessage(const std::shared_ptr& frame) override { - // No lock here, since this is called from DriveWorker() which is - // holding the lock - if (frame->type == FrameType::kBuffer) { - next_metadata_ = std::move(frame->buffer); - } else if (frame->type == FrameType::kHeaders) { - // Trailers, stream is over - finished_ = true; - HeadersFrame headers; - io_status_ = HeadersFrame::Parse(std::move(frame->buffer)).Value(&headers); - if (!io_status_.ok()) { - finished_ = true; - return; - } - io_status_ = headers.GetStatus(&server_status_); - if (!io_status_.ok()) { - finished_ = true; - return; - } - } else { - finished_ = true; - io_status_ = - Status::IOError("Unexpected frame type ", static_cast(frame->type)); - } - } - std::shared_ptr next_metadata_; -}; - -class ExchangeClientStream : public WriteClientStream { - public: - ExchangeClientStream(UcxClientImpl* impl, ClientConnection conn) - : WriteClientStream(impl, std::move(conn)), read_state_(ReadState::kFinished) {} - - bool ReadData(internal::FlightData* data) override { - std::unique_lock guard(driver_mutex_); - if (finished_) { - guard.unlock(); - JoinThread(); - return false; - } - - // Drive the read loop here. (We can't recursively call - // ReadFrameAsync below since the internal mutex is not - // recursive.) - read_state_ = ReadState::kExpectHeader; - incoming_ = driver_->ReadFrameAsync(); - working_cv_.notify_all(); - completed_cv_.wait(guard, [this] { return read_state_ != ReadState::kExpectHeader; }); - if (read_state_ != ReadState::kFinished) { - incoming_ = driver_->ReadFrameAsync(); - working_cv_.notify_all(); - completed_cv_.wait(guard, [this] { return read_state_ == ReadState::kFinished; }); - } - - if (finished_) { - guard.unlock(); - JoinThread(); - return false; - } - *data = std::move(next_data_); - return true; - } - - private: - enum class ReadState { - kFinished, - kExpectHeader, - kExpectBody, - }; - - std::string DebugExpectingString() { - switch (read_state_) { - case ReadState::kFinished: - return "(not expecting a frame)"; - case ReadState::kExpectHeader: - return "payload header frame"; - case ReadState::kExpectBody: - return "payload body frame"; - } - return "(unknown or invalid state)"; - } - - void HandleIncomingMessage(const std::shared_ptr& frame) override { - // No lock here, since this is called from MakeProgress() - // which is called under the lock already - if (frame->type == FrameType::kPayloadHeader) { - if (read_state_ != ReadState::kExpectHeader) { - finished_ = true; - io_status_ = Status::IOError("Got unexpected payload header frame, expected: ", - DebugExpectingString()); - return; - } - - PayloadHeaderFrame payload_header(std::move(frame->buffer)); - io_status_ = payload_header.ToFlightData(&next_data_); - if (!io_status_.ok()) { - finished_ = true; - return; - } - - if (next_data_.metadata) { - std::unique_ptr message; - io_status_ = ipc::Message::Open(next_data_.metadata, nullptr).Value(&message); - if (!io_status_.ok()) { - finished_ = true; - return; - } - if (ipc::Message::HasBody(message->type())) { - read_state_ = ReadState::kExpectBody; - return; - } - } - read_state_ = ReadState::kFinished; - } else if (frame->type == FrameType::kPayloadBody) { - next_data_.body = std::move(frame->buffer); - read_state_ = ReadState::kFinished; - } else if (frame->type == FrameType::kHeaders) { - // Trailers, stream is over - finished_ = true; - read_state_ = ReadState::kFinished; - HeadersFrame headers; - io_status_ = HeadersFrame::Parse(std::move(frame->buffer)).Value(&headers); - if (!io_status_.ok()) { - finished_ = true; - return; - } - io_status_ = headers.GetStatus(&server_status_); - if (!io_status_.ok()) { - finished_ = true; - return; - } - } else { - finished_ = true; - io_status_ = - Status::IOError("Unexpected frame type ", static_cast(frame->type)); - read_state_ = ReadState::kFinished; - } - } - - internal::FlightData next_data_; - ReadState read_state_; -}; - -class UcxClientImpl : public arrow::flight::internal::ClientTransport { - public: - UcxClientImpl() = default; - - ~UcxClientImpl() override { - if (!ucp_context_) return; - ARROW_WARN_NOT_OK(Close(), "UcxClientImpl errored in Close() in destructor"); - } - - Status Init(const FlightClientOptions& options, const Location& location, - const arrow::util::Uri& uri) override { - RETURN_NOT_OK(uri_.Parse(uri.ToString())); - { - ucp_config_t* ucp_config; - ucp_params_t ucp_params; - ucs_status_t status; - - status = ucp_config_read(nullptr, nullptr, &ucp_config); - RETURN_NOT_OK(FromUcsStatus("ucp_config_read", status)); - - // If location is IPv6, must adjust UCX config - // XXX: we assume locations always resolve to IPv6 or IPv4 but - // that is not necessarily true. - { - struct sockaddr_storage connect_addr; - RETURN_NOT_OK(UriToSockaddr(uri, &connect_addr)); - if (connect_addr.ss_family == AF_INET6) { - status = ucp_config_modify(ucp_config, "AF_PRIO", "inet6"); - RETURN_NOT_OK(FromUcsStatus("ucp_config_modify", status)); - } - } - - std::memset(&ucp_params, 0, sizeof(ucp_params)); - ucp_params.field_mask = UCP_PARAM_FIELD_FEATURES; - ucp_params.features = UCP_FEATURE_AM | UCP_FEATURE_WAKEUP; - - ucp_context_h ucp_context; - status = ucp_init(&ucp_params, ucp_config, &ucp_context); - ucp_config_release(ucp_config); - RETURN_NOT_OK(FromUcsStatus("ucp_init", status)); - ucp_context_.reset(new UcpContext(ucp_context)); - } - - RETURN_NOT_OK(MakeConnection()); - return Status::OK(); - } - - Status Close() override { - std::unique_lock connections_mutex_; - while (!connections_.empty()) { - ClientConnection conn = std::move(connections_.front()); - connections_.pop_front(); - RETURN_NOT_OK(conn.Close()); - } - return Status::OK(); - } - - Status GetFlightInfo(const FlightCallOptions& options, - const FlightDescriptor& descriptor, - std::unique_ptr* info) override { - ARROW_ASSIGN_OR_RAISE(auto connection, CheckoutConnection(options)); - UcpCallDriver* driver = connection.driver(); - - auto impl = [&]() { - RETURN_NOT_OK(driver->StartCall(kMethodGetFlightInfo)); - - ARROW_ASSIGN_OR_RAISE(std::string payload, descriptor.SerializeToString()); - - RETURN_NOT_OK(driver->SendFrame(FrameType::kBuffer, - reinterpret_cast(payload.data()), - static_cast(payload.size()))); - - ARROW_ASSIGN_OR_RAISE(auto incoming_message, driver->ReadNextFrame()); - if (incoming_message->type == FrameType::kBuffer) { - ARROW_ASSIGN_OR_RAISE( - *info, FlightInfo::Deserialize(std::string_view(*incoming_message->buffer))); - ARROW_ASSIGN_OR_RAISE(incoming_message, driver->ReadNextFrame()); - } - RETURN_NOT_OK(driver->ExpectFrameType(*incoming_message, FrameType::kHeaders)); - ARROW_ASSIGN_OR_RAISE(auto headers, - HeadersFrame::Parse(std::move(incoming_message->buffer))); - Status status; - RETURN_NOT_OK(headers.GetStatus(&status)); - return status; - }; - auto status = impl(); - return MergeStatuses(std::move(status), ReturnConnection(std::move(connection))); - } - - Status PollFlightInfo(const FlightCallOptions& options, - const FlightDescriptor& descriptor, - std::unique_ptr* info) override { - ARROW_ASSIGN_OR_RAISE(auto connection, CheckoutConnection(options)); - UcpCallDriver* driver = connection.driver(); - - auto impl = [&]() { - RETURN_NOT_OK(driver->StartCall(kMethodPollFlightInfo)); - - ARROW_ASSIGN_OR_RAISE(std::string payload, descriptor.SerializeToString()); - - RETURN_NOT_OK(driver->SendFrame(FrameType::kBuffer, - reinterpret_cast(payload.data()), - static_cast(payload.size()))); - - ARROW_ASSIGN_OR_RAISE(auto incoming_message, driver->ReadNextFrame()); - if (incoming_message->type == FrameType::kBuffer) { - ARROW_ASSIGN_OR_RAISE( - *info, PollInfo::Deserialize(std::string_view(*incoming_message->buffer))); - ARROW_ASSIGN_OR_RAISE(incoming_message, driver->ReadNextFrame()); - } - RETURN_NOT_OK(driver->ExpectFrameType(*incoming_message, FrameType::kHeaders)); - ARROW_ASSIGN_OR_RAISE(auto headers, - HeadersFrame::Parse(std::move(incoming_message->buffer))); - Status status; - RETURN_NOT_OK(headers.GetStatus(&status)); - return status; - }; - auto status = impl(); - return MergeStatuses(std::move(status), ReturnConnection(std::move(connection))); - } - - Status DoExchange(const FlightCallOptions& options, - std::unique_ptr* out) override { - ARROW_ASSIGN_OR_RAISE(auto connection, CheckoutConnection(options)); - UcpCallDriver* driver = connection.driver(); - - auto status = driver->StartCall(kMethodDoExchange); - if (ARROW_PREDICT_TRUE(status.ok())) { - *out = std::make_unique(this, std::move(connection)); - return Status::OK(); - } - return MergeStatuses(std::move(status), ReturnConnection(std::move(connection))); - } - - Status DoGet(const FlightCallOptions& options, const Ticket& ticket, - std::unique_ptr* stream) override { - ARROW_ASSIGN_OR_RAISE(auto connection, CheckoutConnection(options)); - UcpCallDriver* driver = connection.driver(); - - auto impl = [&]() { - RETURN_NOT_OK(driver->StartCall(kMethodDoGet)); - ARROW_ASSIGN_OR_RAISE(std::string payload, ticket.SerializeToString()); - RETURN_NOT_OK(driver->SendFrame(FrameType::kBuffer, - reinterpret_cast(payload.data()), - static_cast(payload.size()))); - *stream = std::make_unique(this, std::move(connection)); - return Status::OK(); - }; - - auto status = impl(); - if (ARROW_PREDICT_TRUE(status.ok())) return status; - return MergeStatuses(std::move(status), ReturnConnection(std::move(connection))); - } - - Status DoPut(const FlightCallOptions& options, - std::unique_ptr* out) override { - ARROW_ASSIGN_OR_RAISE(auto connection, CheckoutConnection(options)); - UcpCallDriver* driver = connection.driver(); - - auto status = driver->StartCall(kMethodDoPut); - if (ARROW_PREDICT_TRUE(status.ok())) { - *out = std::make_unique(this, std::move(connection)); - return Status::OK(); - } - return MergeStatuses(std::move(status), ReturnConnection(std::move(connection))); - } - - Status DoAction(const FlightCallOptions& options, const Action& action, - std::unique_ptr* results) override { - // XXX: fake this for now to get the perf test to work - return Status::OK(); - } - - Status MakeConnection() { - ClientConnection conn; - RETURN_NOT_OK(conn.Init(ucp_context_, uri_)); - std::unique_lock connections_mutex_; - connections_.push_back(std::move(conn)); - return Status::OK(); - } - - arrow::Result CheckoutConnection(const FlightCallOptions& options) { - std::unique_lock connections_mutex_; - if (connections_.empty()) RETURN_NOT_OK(MakeConnection()); - ClientConnection conn = std::move(connections_.front()); - connections_.pop_front(); - conn.driver()->set_memory_manager(options.memory_manager); - conn.driver()->set_read_memory_pool(options.read_options.memory_pool); - conn.driver()->set_write_memory_pool(options.write_options.memory_pool); - return conn; - } - - Status ReturnConnection(ClientConnection conn) { - std::unique_lock connections_mutex_; - // TODO(ARROW-16127): for future improvement: reclaim clients - // asynchronously in the background (try to avoid issues like - // constantly opening/closing clients because the application is - // just barely over the limit of open connections) - if (connections_.size() >= kMaxOpenConnections) { - RETURN_NOT_OK(conn.Close()); - return Status::OK(); - } - DCHECK_NE(conn.driver(), nullptr); - connections_.push_back(std::move(conn)); - return Status::OK(); - } - - private: - static constexpr size_t kMaxOpenConnections = 3; - - arrow::util::Uri uri_; - std::shared_ptr ucp_context_; - std::mutex connections_mutex_; - std::deque connections_; -}; - -Status UcxClientStream::DoFinish() { - RETURN_NOT_OK(WritesDone()); - // Both reader and writer may be used concurrently, and both may - // call Finish() - prevent concurrent state mutation - std::lock_guard guard(finish_mutex_); - if (!finished_) { - internal::FlightData message; - std::shared_ptr metadata; - while (ReadData(&message)) { - } - while (ReadPutMetadata(&metadata)) { - } - finished_ = true; - } - if (impl_) { - DCHECK_NE(conn_.driver(), nullptr); - auto status = impl_->ReturnConnection(std::move(conn_)); - impl_ = nullptr; - driver_ = nullptr; - if (!status.ok()) { - if (io_status_.ok()) { - io_status_ = std::move(status); - } else { - io_status_ = Status::FromDetailAndArgs( - io_status_.code(), io_status_.detail(), io_status_.message(), - ". Transport context: ", status.ToString()); - } - } - } - return MergeStatuses(server_status_, io_status_); -} -} // namespace - -std::unique_ptr MakeUcxClientImpl() { - return std::make_unique(); -} - -} // namespace ucx -} // namespace transport -} // namespace flight -} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc deleted file mode 100644 index 767a877ece331..0000000000000 --- a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc +++ /dev/null @@ -1,1110 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/flight/transport/ucx/ucx_internal.h" - -#include -#include -#include -#include -#include - -#include "arrow/buffer.h" -#include "arrow/flight/transport/ucx/util_internal.h" -#include "arrow/flight/types.h" -#include "arrow/util/base64.h" -#include "arrow/util/bit_util.h" -#include "arrow/util/logging.h" -#include "arrow/util/string.h" -#include "arrow/util/uri.h" - -namespace arrow { - -using internal::ToChars; - -namespace flight { -namespace transport { -namespace ucx { - -using internal::TransportStatus; - -// Defines to test different implementation strategies -// Enable the CONTIG path for CPU-only data -// #define ARROW_FLIGHT_UCX_SEND_CONTIG -// Enable ucp_mem_map in IOV path -// #define ARROW_FLIGHT_UCX_SEND_IOV_MAP - -constexpr char kHeaderMethod[] = ":method:"; - -namespace { -Status SizeToUInt32BytesBe(const int64_t in, uint8_t* out) { - if (ARROW_PREDICT_FALSE(in < 0)) { - return Status::Invalid("Length cannot be negative"); - } else if (ARROW_PREDICT_FALSE( - in > static_cast(std::numeric_limits::max()))) { - return Status::Invalid("Length cannot exceed uint32_t"); - } - UInt32ToBytesBe(static_cast(in), out); - return Status::OK(); -} -ucs_memory_type InferMemoryType(const Buffer& buffer) { - if (!buffer.is_cpu()) { - return UCS_MEMORY_TYPE_CUDA; - } - return UCS_MEMORY_TYPE_UNKNOWN; -} -void TryMapBuffer(ucp_context_h context, const void* buffer, const size_t size, - ucs_memory_type memory_type, ucp_mem_h* memh_p) { - ucp_mem_map_params_t map_param; - std::memset(&map_param, 0, sizeof(map_param)); - map_param.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS | - UCP_MEM_MAP_PARAM_FIELD_LENGTH | - UCP_MEM_MAP_PARAM_FIELD_MEMORY_TYPE; - map_param.address = const_cast(buffer); - map_param.length = size; - map_param.memory_type = memory_type; - auto ucs_status = ucp_mem_map(context, &map_param, memh_p); - if (ucs_status != UCS_OK) { - *memh_p = nullptr; - ARROW_LOG(WARNING) << "Could not map memory: " - << FromUcsStatus("ucp_mem_map", ucs_status); - } -} -void TryMapBuffer(ucp_context_h context, const Buffer& buffer, ucp_mem_h* memh_p) { - TryMapBuffer(context, reinterpret_cast(buffer.address()), - static_cast(buffer.size()), InferMemoryType(buffer), memh_p); -} -void TryUnmapBuffer(ucp_context_h context, ucp_mem_h memh_p) { - if (memh_p) { - auto ucs_status = ucp_mem_unmap(context, memh_p); - if (ucs_status != UCS_OK) { - ARROW_LOG(WARNING) << "Could not unmap memory: " - << FromUcsStatus("ucp_mem_unmap", ucs_status); - } - } -} - -/// \brief Wrapper around a UCX zero copy buffer (a host memory DATA -/// buffer). -/// -/// Owns a reference to the associated worker to avoid undefined -/// behavior. -class UcxDataBuffer : public Buffer { - public: - explicit UcxDataBuffer(std::shared_ptr worker, void* data, size_t size) - : Buffer(reinterpret_cast(data), static_cast(size)), - worker_(std::move(worker)) {} - - ~UcxDataBuffer() { - ucp_am_data_release(worker_->get(), - const_cast(reinterpret_cast(data()))); - } - - private: - std::shared_ptr worker_; -}; -}; // namespace - -constexpr size_t FrameHeader::kFrameHeaderBytes; -constexpr uint8_t FrameHeader::kFrameVersion; - -Status FrameHeader::Set(FrameType frame_type, uint32_t counter, int64_t body_size) { - header[0] = kFrameVersion; - header[1] = static_cast(frame_type); - UInt32ToBytesBe(counter, header.data() + 4); - RETURN_NOT_OK(SizeToUInt32BytesBe(body_size, header.data() + 8)); - return Status::OK(); -} - -arrow::Result> Frame::ParseHeader(const void* header, - size_t header_length) { - if (header_length < FrameHeader::kFrameHeaderBytes) { - return Status::IOError("Header is too short, must be at least ", - FrameHeader::kFrameHeaderBytes, " bytes, got ", header_length); - } - - const uint8_t* frame_header = reinterpret_cast(header); - if (frame_header[0] != FrameHeader::kFrameVersion) { - return Status::IOError("Expected frame version ", - static_cast(FrameHeader::kFrameVersion), " but got ", - static_cast(frame_header[0])); - } else if (frame_header[1] > static_cast(FrameType::kMaxFrameType)) { - return Status::IOError("Unknown frame type ", static_cast(frame_header[1])); - } - - const FrameType frame_type = static_cast(frame_header[1]); - const uint32_t frame_counter = BytesToUInt32Be(frame_header + 4); - const uint32_t frame_size = BytesToUInt32Be(frame_header + 8); - - if (frame_type == FrameType::kDisconnect) { - return Status::Cancelled("Client initiated disconnect"); - } - - return std::make_shared(frame_type, frame_size, frame_counter, nullptr); -} - -arrow::Result HeadersFrame::Parse(std::unique_ptr buffer) { - HeadersFrame result; - const uint8_t* payload = buffer->data(); - const uint8_t* end = payload + buffer->size(); - if (ARROW_PREDICT_FALSE((end - payload) < 4)) { - return Status::Invalid("Buffer underflow, expected number of headers"); - } - const uint32_t num_headers = BytesToUInt32Be(payload); - payload += 4; - for (uint32_t i = 0; i < num_headers; i++) { - if (ARROW_PREDICT_FALSE((end - payload) < 4)) { - return Status::Invalid("Buffer underflow, expected length of key ", i + 1); - } - const uint32_t key_length = BytesToUInt32Be(payload); - payload += 4; - - if (ARROW_PREDICT_FALSE((end - payload) < 4)) { - return Status::Invalid("Buffer underflow, expected length of value ", i + 1); - } - const uint32_t value_length = BytesToUInt32Be(payload); - payload += 4; - - if (ARROW_PREDICT_FALSE((end - payload) < key_length)) { - return Status::Invalid("Buffer underflow, expected key ", i + 1, " to have length ", - key_length, ", but only ", (end - payload), " bytes remain"); - } - const std::string_view key(reinterpret_cast(payload), key_length); - payload += key_length; - - if (ARROW_PREDICT_FALSE((end - payload) < value_length)) { - return Status::Invalid("Buffer underflow, expected value ", i + 1, - " to have length ", value_length, ", but only ", - (end - payload), " bytes remain"); - } - const std::string_view value(reinterpret_cast(payload), value_length); - payload += value_length; - result.headers_.emplace_back(key, value); - } - - result.buffer_ = std::move(buffer); - return result; -} -arrow::Result HeadersFrame::Make( - const std::vector>& headers) { - int32_t total_length = 4 /* # of headers */; - for (const auto& header : headers) { - total_length += 4 /* key length */ + 4 /* value length */ + - header.first.size() /* key */ + header.second.size(); - } - - ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateBuffer(total_length)); - uint8_t* payload = buffer->mutable_data(); - - RETURN_NOT_OK(SizeToUInt32BytesBe(headers.size(), payload)); - payload += 4; - for (const auto& header : headers) { - RETURN_NOT_OK(SizeToUInt32BytesBe(header.first.size(), payload)); - payload += 4; - RETURN_NOT_OK(SizeToUInt32BytesBe(header.second.size(), payload)); - payload += 4; - std::memcpy(payload, header.first.data(), header.first.size()); - payload += header.first.size(); - std::memcpy(payload, header.second.data(), header.second.size()); - payload += header.second.size(); - } - return Parse(std::move(buffer)); -} -arrow::Result HeadersFrame::Make( - const Status& status, - const std::vector>& headers) { - auto all_headers = headers; - - TransportStatus transport_status = TransportStatus::FromStatus(status); - all_headers.emplace_back(kHeaderStatus, - ToChars(static_cast(transport_status.code))); - all_headers.emplace_back(kHeaderMessage, std::move(transport_status.message)); - all_headers.emplace_back(kHeaderStatusCode, - ToChars(static_cast(status.code()))); - all_headers.emplace_back(kHeaderStatusMessage, status.message()); - if (status.detail()) { - all_headers.emplace_back(kHeaderStatusDetail, status.detail()->ToString()); - auto fsd = FlightStatusDetail::UnwrapStatus(status); - if (fsd && !fsd->extra_info().empty()) { - all_headers.emplace_back(kHeaderStatusDetailBin, fsd->extra_info()); - } - } - return Make(all_headers); -} - -arrow::Result HeadersFrame::Get(const std::string& key) { - for (const auto& pair : headers_) { - if (pair.first == key) return pair.second; - } - return Status::KeyError(key); -} - -Status HeadersFrame::GetStatus(Status* out) { - static const std::string kUnknownMessage = "Server did not send status message header"; - std::string_view code_str, message_str; - auto status = Get(kHeaderStatus).Value(&code_str); - if (!status.ok()) { - return Status::KeyError("Server did not send status code header ", kHeaderStatusCode); - } - if (code_str == "0") { // == ToChars(TransportStatusCode::kOk) - *out = Status::OK(); - return Status::OK(); - } - - status = Get(kHeaderMessage).Value(&message_str); - if (!status.ok()) message_str = kUnknownMessage; - - TransportStatus transport_status = TransportStatus::FromCodeStringAndMessage( - std::string(code_str), std::string(message_str)); - if (transport_status.code == TransportStatusCode::kOk) { - *out = Status::OK(); - return Status::OK(); - } - *out = transport_status.ToStatus(); - - std::string_view detail_str, bin_str; - std::optional message, detail_message, detail_bin; - if (!Get(kHeaderStatusCode).Value(&code_str).ok()) { - // No Arrow status sent, go with the transport status - return Status::OK(); - } - if (Get(kHeaderStatusMessage).Value(&message_str).ok()) { - message = std::string(message_str); - } - if (Get(kHeaderStatusDetail).Value(&detail_str).ok()) { - detail_message = std::string(detail_str); - } - if (Get(kHeaderStatusDetailBin).Value(&bin_str).ok()) { - detail_bin = std::string(bin_str); - } - *out = internal::ReconstructStatus(std::string(code_str), *out, std::move(message), - std::move(detail_message), std::move(detail_bin), - FlightStatusDetail::UnwrapStatus(*out)); - return Status::OK(); -} - -namespace { -static constexpr uint32_t kMissingFieldSentinel = std::numeric_limits::max(); -static constexpr uint32_t kInt32Max = - static_cast(std::numeric_limits::max()); -arrow::Result PayloadHeaderFieldSize(const std::string& field, - const std::shared_ptr& data, - uint32_t* total_size) { - if (!data) return kMissingFieldSentinel; - if (data->size() > kInt32Max) { - return Status::Invalid(field, " must be less than 2 GiB, was: ", data->size()); - } - *total_size += static_cast(data->size()); - // Check for underflow - if (*total_size < 0) return Status::Invalid("Payload header must fit in a uint32_t"); - return static_cast(data->size()); -} -uint8_t* PackField(uint32_t size, const std::shared_ptr& data, uint8_t* out) { - UInt32ToBytesBe(size, out); - if (size != kMissingFieldSentinel) { - std::memcpy(out + 4, data->data(), size); - return out + 4 + size; - } else { - return out + 4; - } -} -} // namespace - -arrow::Result PayloadHeaderFrame::Make(const FlightPayload& payload, - MemoryPool* memory_pool) { - // Assemble all non-data fields here. Presumably this is much less - // than data size so we will pay the copy. - - // Structure per field: [4 byte length][data]. If a field is not - // present, UINT32_MAX is used as the sentinel (since 0-sized fields - // are acceptable) - uint32_t header_size = 12; - ARROW_ASSIGN_OR_RAISE( - const uint32_t descriptor_size, - PayloadHeaderFieldSize("descriptor", payload.descriptor, &header_size)); - ARROW_ASSIGN_OR_RAISE( - const uint32_t app_metadata_size, - PayloadHeaderFieldSize("app_metadata", payload.app_metadata, &header_size)); - ARROW_ASSIGN_OR_RAISE( - const uint32_t ipc_metadata_size, - PayloadHeaderFieldSize("ipc_message.metadata", payload.ipc_message.metadata, - &header_size)); - - ARROW_ASSIGN_OR_RAISE(auto header_buffer, AllocateBuffer(header_size, memory_pool)); - uint8_t* payload_header = header_buffer->mutable_data(); - - payload_header = PackField(descriptor_size, payload.descriptor, payload_header); - payload_header = PackField(app_metadata_size, payload.app_metadata, payload_header); - payload_header = - PackField(ipc_metadata_size, payload.ipc_message.metadata, payload_header); - - return PayloadHeaderFrame(std::move(header_buffer)); -} -Status PayloadHeaderFrame::ToFlightData(internal::FlightData* data) { - std::shared_ptr buffer = std::move(buffer_); - - // Unpack the descriptor - uint32_t offset = 0; - uint32_t size = BytesToUInt32Be(buffer->data()); - offset += 4; - if (size != kMissingFieldSentinel) { - if (static_cast(offset + size) > buffer->size()) { - return Status::Invalid("Buffer is too small: expected ", offset + size, - " bytes but have ", buffer->size()); - } - std::string_view desc(reinterpret_cast(buffer->data() + offset), size); - data->descriptor.reset(new FlightDescriptor()); - ARROW_ASSIGN_OR_RAISE(*data->descriptor, FlightDescriptor::Deserialize(desc)); - offset += size; - } else { - data->descriptor = nullptr; - } - - // Unpack app_metadata - size = BytesToUInt32Be(buffer->data() + offset); - offset += 4; - // While we properly handle zero-size vs nullptr metadata here, gRPC - // doesn't (Protobuf doesn't differentiate between the two) - if (size != kMissingFieldSentinel) { - if (static_cast(offset + size) > buffer->size()) { - return Status::Invalid("Buffer is too small: expected ", offset + size, - " bytes but have ", buffer->size()); - } - data->app_metadata = SliceBuffer(buffer, offset, size); - offset += size; - } else { - data->app_metadata = nullptr; - } - - // Unpack the IPC header - size = BytesToUInt32Be(buffer->data() + offset); - offset += 4; - if (size != kMissingFieldSentinel) { - if (static_cast(offset + size) > buffer->size()) { - return Status::Invalid("Buffer is too small: expected ", offset + size, - " bytes but have ", buffer->size()); - } - data->metadata = SliceBuffer(std::move(buffer), offset, size); - } else { - data->metadata = nullptr; - } - data->body = nullptr; - return Status::OK(); -} - -// pImpl the driver since async methods require a stable address -class UcpCallDriver::Impl { - public: -#if defined(ARROW_FLIGHT_UCX_SEND_CONTIG) - constexpr static bool kEnableContigSend = true; -#else - constexpr static bool kEnableContigSend = false; -#endif - - Impl(std::shared_ptr worker, ucp_ep_h endpoint) - : padding_bytes_({0, 0, 0, 0, 0, 0, 0, 0}), - worker_(std::move(worker)), - endpoint_(endpoint), - read_memory_pool_(default_memory_pool()), - write_memory_pool_(default_memory_pool()), - memory_manager_(CPUDevice::Instance()->default_memory_manager()), - name_("(unknown remote)"), - counter_(0) { -#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP) - TryMapBuffer(worker_->context().get(), padding_bytes_.data(), padding_bytes_.size(), - UCS_MEMORY_TYPE_HOST, &padding_memh_p_); -#endif - - ucp_ep_attr_t attrs; - std::memset(&attrs, 0, sizeof(attrs)); - attrs.field_mask = - UCP_EP_ATTR_FIELD_LOCAL_SOCKADDR | UCP_EP_ATTR_FIELD_REMOTE_SOCKADDR; - if (ucp_ep_query(endpoint_, &attrs) == UCS_OK) { - std::string local_addr, remote_addr; - ARROW_UNUSED(SockaddrToString(attrs.local_sockaddr).Value(&local_addr)); - ARROW_UNUSED(SockaddrToString(attrs.remote_sockaddr).Value(&remote_addr)); - name_ = "local:" + local_addr + ";remote:" + remote_addr; - } - } - - ~Impl() { -#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP) - TryUnmapBuffer(worker_->context().get(), padding_memh_p_); -#endif - } - - arrow::Result> ReadNextFrame() { - auto fut = ReadFrameAsync(); - while (!fut.is_finished()) MakeProgress(); - RETURN_NOT_OK(fut.status()); - return fut.MoveResult(); - } - - Future> ReadFrameAsync() { - RETURN_NOT_OK(CheckClosed()); - - std::unique_lock guard(frame_mutex_); - if (ARROW_PREDICT_FALSE(!status_.ok())) return status_; - - // Expected value of "counter" field in the frame header - const uint32_t counter_value = next_counter_++; - auto it = frames_.find(counter_value); - if (it != frames_.end()) { - // Message already delivered, return it - Future> fut = it->second; - frames_.erase(it); - return fut; - } - // Message not yet delivered, insert a future and wait - auto pair = frames_.insert({counter_value, Future>::Make()}); - DCHECK(pair.second); - return pair.first->second; - } - - Status SendFrame(FrameType frame_type, const uint8_t* data, const int64_t size) { - RETURN_NOT_OK(CheckClosed()); - - void* request = nullptr; - ucp_request_param_t request_param; - request_param.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS; - request_param.flags = UCP_AM_SEND_FLAG_REPLY; - - // Send frame header - FrameHeader header; - RETURN_NOT_OK(header.Set(frame_type, counter_++, size)); - if (size == 0) { - // UCX appears to crash on zero-byte payloads - request = ucp_am_send_nbx(endpoint_, kUcpAmHandlerId, header.data(), header.size(), - padding_bytes_.data(), - /*size=*/1, &request_param); - } else { - request = ucp_am_send_nbx(endpoint_, kUcpAmHandlerId, header.data(), header.size(), - data, size, &request_param); - } - RETURN_NOT_OK(CompleteRequestBlocking("ucp_am_send_nbx", request)); - - return Status::OK(); - } - - Future<> SendFrameAsync(FrameType frame_type, std::unique_ptr buffer) { - RETURN_NOT_OK(CheckClosed()); - - ucp_request_param_t request_param; - std::memset(&request_param, 0, sizeof(request_param)); - request_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_DATATYPE | - UCP_OP_ATTR_FIELD_FLAGS | UCP_OP_ATTR_FIELD_USER_DATA; - request_param.cb.send = AmSendCallback; - request_param.datatype = ucp_dt_make_contig(1); - request_param.flags = UCP_AM_SEND_FLAG_REPLY; - - const int64_t size = buffer->size(); - if (size == 0) { - // UCX appears to crash on zero-byte payloads - ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateBuffer(1, write_memory_pool_)); - } - - std::unique_ptr pending_send(new PendingContigSend()); - RETURN_NOT_OK(pending_send->header.Set(frame_type, counter_++, size)); - pending_send->ipc_message = std::move(buffer); - pending_send->driver = this; - pending_send->completed = Future<>::Make(); - pending_send->memh_p = nullptr; - - request_param.user_data = pending_send.release(); - { - auto* pending_send = reinterpret_cast(request_param.user_data); - - void* request = ucp_am_send_nbx( - endpoint_, kUcpAmHandlerId, pending_send->header.data(), - pending_send->header.size(), - reinterpret_cast(pending_send->ipc_message->mutable_data()), - static_cast(pending_send->ipc_message->size()), &request_param); - if (!request) { - // Request completed immediately - delete pending_send; - return Status::OK(); - } else if (UCS_PTR_IS_ERR(request)) { - delete pending_send; - return FromUcsStatus("ucp_am_send_nbx", UCS_PTR_STATUS(request)); - } - return pending_send->completed; - } - } - - Future<> SendFlightPayload(const FlightPayload& payload) { - static const int64_t kMaxBatchSize = std::numeric_limits::max(); - RETURN_NOT_OK(CheckClosed()); - - if (payload.ipc_message.body_length > kMaxBatchSize) { - return Status::Invalid("Cannot send record batches exceeding 2GiB yet"); - } - - { - ARROW_ASSIGN_OR_RAISE(auto frame, - PayloadHeaderFrame::Make(payload, write_memory_pool_)); - RETURN_NOT_OK(SendFrame(FrameType::kPayloadHeader, frame.data(), frame.size())); - } - - if (!ipc::Message::HasBody(payload.ipc_message.type)) { - return Status::OK(); - } - - // While IOV (scatter-gather) might seem like it avoids a memcpy, - // profiling shows that at least for the TCP/SHM/RDMA transports, - // UCX just does a memcpy internally. Furthermore, on the receiver - // side, a sender-side IOV send prevents optimizations based on - // mapped buffers (UCX will memcpy to the destination buffer - // regardless of whether it's mapped or not). - - // If all buffers are on the CPU, concatenate them ourselves and - // do a regular send to avoid this. Else, use IOV and let UCX - // figure out what to do. - - // Weirdness: UCX prefers TCP over shared memory for CONTIG? We - // can avoid this by setting UCX_RNDV_THRESH=inf, this will make - // UCX prefer shared memory again. However, we still want to avoid - // the CONTIG path when shared memory is available, because the - // total amount of time spent in memcpy is greater than using IOV - // and letting UCX handle it. - - // Consider: if we can figure out how to make IOV always as fast - // as CONTIG, we can just send the metadata fields as part of the - // IOV payload and avoid having to send two distinct messages. - - bool all_cpu = true; - int32_t total_buffers = 0; - for (const auto& buffer : payload.ipc_message.body_buffers) { - if (!buffer || buffer->size() == 0) continue; - all_cpu = all_cpu && buffer->is_cpu(); - total_buffers++; - - // Arrow IPC requires that we align buffers to 8 byte boundary - const auto remainder = static_cast( - bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size()); - if (remainder) total_buffers++; - } - - ucp_request_param_t request_param; - std::memset(&request_param, 0, sizeof(request_param)); - request_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_DATATYPE | - UCP_OP_ATTR_FIELD_FLAGS | UCP_OP_ATTR_FIELD_USER_DATA; - request_param.cb.send = AmSendCallback; - request_param.flags = UCP_AM_SEND_FLAG_REPLY; - - std::unique_ptr pending_send; - void* send_data = nullptr; - size_t send_size = 0; - - if (!all_cpu) { - request_param.op_attr_mask = - request_param.op_attr_mask | UCP_OP_ATTR_FIELD_MEMORY_TYPE; - // XXX: UCX doesn't appear to autodetect this correctly if we - // use UNKNOWN - request_param.memory_type = UCS_MEMORY_TYPE_CUDA; - } - - if (kEnableContigSend && all_cpu) { - // CONTIG - concatenate buffers into one before sending - - // TODO(ARROW-16126): this needs to be pipelined since it can be expensive. - // Preliminary profiling shows ~5% overhead just from mapping the buffer - // alone (on Infiniband; it seems to be trivial for shared memory) - request_param.datatype = ucp_dt_make_contig(1); - pending_send = std::make_unique(); - auto* pending_contig = reinterpret_cast(pending_send.get()); - - const int64_t body_length = std::max(payload.ipc_message.body_length, 1); - ARROW_ASSIGN_OR_RAISE(pending_contig->ipc_message, - AllocateBuffer(body_length, write_memory_pool_)); - TryMapBuffer(worker_->context().get(), *pending_contig->ipc_message, - &pending_contig->memh_p); - - uint8_t* ipc_message = pending_contig->ipc_message->mutable_data(); - if (payload.ipc_message.body_length == 0) { - std::memset(ipc_message, '\0', 1); - } - - for (const auto& buffer : payload.ipc_message.body_buffers) { - if (!buffer || buffer->size() == 0) continue; - - std::memcpy(ipc_message, buffer->data(), buffer->size()); - ipc_message += buffer->size(); - - const auto remainder = static_cast( - bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size()); - if (remainder) { - std::memset(ipc_message, 0, remainder); - ipc_message += remainder; - } - } - - send_data = reinterpret_cast(pending_contig->ipc_message->mutable_data()); - send_size = static_cast(pending_contig->ipc_message->size()); - } else { - // IOV - let UCX use scatter-gather path - request_param.datatype = UCP_DATATYPE_IOV; - pending_send = std::make_unique(); - auto* pending_iov = reinterpret_cast(pending_send.get()); - - pending_iov->payload = payload; - pending_iov->iovs.resize(total_buffers); - ucp_dt_iov_t* iov = pending_iov->iovs.data(); -#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP) - // XXX: this seems to have no benefits in tests so far - pending_iov->memh_ps.resize(total_buffers); - ucp_mem_h* memh_p = pending_iov->memh_ps.data(); -#endif - for (const auto& buffer : payload.ipc_message.body_buffers) { - if (!buffer || buffer->size() == 0) continue; - - iov->buffer = const_cast(reinterpret_cast(buffer->address())); - iov->length = buffer->size(); - ++iov; - -#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP) - TryMapBuffer(worker_->context().get(), *buffer, memh_p); - memh_p++; -#endif - - const auto remainder = static_cast( - bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size()); - if (remainder) { - iov->buffer = - const_cast(reinterpret_cast(padding_bytes_.data())); - iov->length = remainder; - ++iov; - } - } - - if (total_buffers == 0) { - // UCX cannot handle zero-byte payloads - pending_iov->iovs.resize(1); - pending_iov->iovs[0].buffer = - const_cast(reinterpret_cast(padding_bytes_.data())); - pending_iov->iovs[0].length = 1; - } - - send_data = pending_iov->iovs.data(); - send_size = pending_iov->iovs.size(); - } - - DCHECK(send_data) << "Payload cannot be nullptr"; - DCHECK_GT(send_size, 0) << "Payload cannot be empty"; - - RETURN_NOT_OK(pending_send->header.Set(FrameType::kPayloadBody, counter_++, - payload.ipc_message.body_length)); - pending_send->driver = this; - pending_send->completed = Future<>::Make(); - - request_param.user_data = pending_send.release(); - { - auto* pending_send = reinterpret_cast(request_param.user_data); - - void* request = ucp_am_send_nbx( - endpoint_, kUcpAmHandlerId, pending_send->header.data(), - pending_send->header.size(), send_data, send_size, &request_param); - if (!request) { - // Request completed immediately - delete pending_send; - return Status::OK(); - } else if (UCS_PTR_IS_ERR(request)) { - delete pending_send; - return FromUcsStatus("ucp_am_send_nbx", UCS_PTR_STATUS(request)); - } - return pending_send->completed; - } - } - - Status Close() { - std::unique_lock guard(frame_mutex_); - if (!endpoint_) return Status::OK(); - - for (auto& item : frames_) { - item.second.MarkFinished(Status::Cancelled("UcpCallDriver is being closed")); - } - frames_.clear(); - - void* request = ucp_ep_close_nb(endpoint_, UCP_EP_CLOSE_MODE_FLUSH); - ucs_status_t status = UCS_OK; - std::string origin = "ucp_ep_close_nb"; - if (UCS_PTR_IS_ERR(request)) { - status = UCS_PTR_STATUS(request); - } else if (UCS_PTR_IS_PTR(request)) { - origin = "ucp_request_check_status"; - while ((status = ucp_request_check_status(request)) == UCS_INPROGRESS) { - MakeProgress(); - } - ucp_request_free(request); - } else { - DCHECK(!request); - } - - endpoint_ = nullptr; - if (status != UCS_OK && !IsIgnorableDisconnectError(status)) { - return FromUcsStatus(origin, status); - } - return Status::OK(); - } - - void MakeProgress() { ucp_worker_progress(worker_->get()); } - - void Push(std::shared_ptr frame) { - std::unique_lock guard(frame_mutex_); - if (ARROW_PREDICT_FALSE(!status_.ok())) return; - auto pair = frames_.insert({frame->counter, frame}); - if (!pair.second) { - // Not inserted, because ReadFrameAsync was called for this - // frame counter value and the client is already waiting on - // it. Complete the existing future. - pair.first->second.MarkFinished(std::move(frame)); - frames_.erase(pair.first); - } - // Otherwise, we inserted the frame, meaning the client was not - // currently waiting for that frame counter value - } - - void Push(Status status) { - std::unique_lock guard(frame_mutex_); - status_ = std::move(status); - for (auto& item : frames_) { - // Push(Frame) may push a complete frame, in which case the - // future is already complete - just skip it - if (item.second.is_finished()) continue; - item.second.MarkFinished(status_); - } - frames_.clear(); - } - - ucs_status_t RecvActiveMessage(const void* header, size_t header_length, void* data, - const size_t data_length, - const ucp_am_recv_param_t* param) { - auto maybe_status = - RecvActiveMessageImpl(header, header_length, data, data_length, param); - if (!maybe_status.ok()) { - Push(maybe_status.status()); - return UCS_OK; - } - return maybe_status.MoveValueUnsafe(); - } - - const std::shared_ptr& memory_manager() const { return memory_manager_; } - void set_memory_manager(std::shared_ptr memory_manager) { - if (memory_manager) { - memory_manager_ = std::move(memory_manager); - } else { - memory_manager_ = CPUDevice::Instance()->default_memory_manager(); - } - } - void set_read_memory_pool(MemoryPool* pool) { - read_memory_pool_ = pool ? pool : default_memory_pool(); - } - void set_write_memory_pool(MemoryPool* pool) { - write_memory_pool_ = pool ? pool : default_memory_pool(); - } - const std::string& peer() const { return name_; } - - private: - class PendingAmSend { - public: - virtual ~PendingAmSend() = default; - UcpCallDriver::Impl* driver; - Future<> completed; - FrameHeader header; - }; - - class PendingContigSend : public PendingAmSend { - public: - std::unique_ptr ipc_message; - ucp_mem_h memh_p; - - virtual ~PendingContigSend() { - TryUnmapBuffer(driver->worker_->context().get(), memh_p); - } - }; - - class PendingIovSend : public PendingAmSend { - public: - FlightPayload payload; - std::vector iovs; -#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP) - std::vector memh_ps; - - virtual ~PendingIovSend() { - for (ucp_mem_h memh_p : memh_ps) { - TryUnmapBuffer(driver->worker_->context().get(), memh_p); - } - } -#endif - }; - - struct PendingAmRecv { - UcpCallDriver::Impl* driver; - std::shared_ptr frame; - ucp_mem_h memh_p; - - PendingAmRecv(UcpCallDriver::Impl* driver_, std::shared_ptr frame_) - : driver(driver_), frame(std::move(frame_)) {} - - ~PendingAmRecv() { TryUnmapBuffer(driver->worker_->context().get(), memh_p); } - }; - - static void AmSendCallback(void* request, ucs_status_t status, void* user_data) { - auto* pending_send = reinterpret_cast(user_data); - if (status == UCS_OK) { - pending_send->completed.MarkFinished(); - } else { - pending_send->completed.MarkFinished(FromUcsStatus("ucp_am_send_nbx", status)); - } - // TODO(ARROW-16126): delete should occur on a background thread if there's - // mapped buffers, since unmapping can be nontrivial and we don't want to block - // the thread doing UCX work. (Borrow the Rust transfer-and-drop pattern.) - delete pending_send; - ucp_request_free(request); - } - - static void AmRecvCallback(void* request, ucs_status_t status, size_t length, - void* user_data) { - auto* pending_recv = reinterpret_cast(user_data); - ucp_request_free(request); - if (status != UCS_OK) { - pending_recv->driver->Push( - FromUcsStatus("ucp_am_recv_data_nbx (callback)", status)); - } else { - pending_recv->driver->Push(std::move(pending_recv->frame)); - } - delete pending_recv; - } - - arrow::Result RecvActiveMessageImpl(const void* header, - size_t header_length, void* data, - const size_t data_length, - const ucp_am_recv_param_t* param) { - DCHECK(param->recv_attr & UCP_AM_RECV_ATTR_FIELD_REPLY_EP); - - if (data_length > static_cast(std::numeric_limits::max())) { - return Status::Invalid("Cannot allocate buffer greater than 2 GiB, requested: ", - data_length); - } - - ARROW_ASSIGN_OR_RAISE(auto frame, Frame::ParseHeader(header, header_length)); - if (data_length < frame->size) { - return Status::IOError("Expected frame of ", frame->size, " bytes, but got only ", - data_length); - } - - if ((param->recv_attr & UCP_AM_RECV_ATTR_FLAG_DATA) && - (memory_manager_->is_cpu() || frame->type != FrameType::kPayloadBody)) { - // Zero-copy path. UCX-allocated buffer must be freed later. - - // XXX: this buffer can NOT be freed until AFTER we return from - // this handler. Otherwise, UCX won't have fully set up its - // internal data structures (allocated just before the buffer) - // and we'll crash when we free the buffer. Effectively: we can - // never use Then/AddCallback on a Future<> from ReadFrameAsync, - // because we might run the callback synchronously (which might - // free the buffer) when we call Push here. - frame->buffer = std::make_unique(worker_, data, data_length); - Push(std::move(frame)); - return UCS_INPROGRESS; - } - - if ((param->recv_attr & UCP_AM_RECV_ATTR_FLAG_DATA) || - (param->recv_attr & UCP_AM_RECV_ATTR_FLAG_RNDV)) { - // Rendezvous protocol (RNDV), or unpack to destination (DATA). - - // We want to map/pin/register the buffer for faster transfer - // where possible. (It gets unmapped in ~PendingAmRecv.) - // TODO(ARROW-16126): This takes non-trivial time, so return - // UCS_INPROGRESS, kick off the allocation in the background, - // and recv the data later (is it allowed to call - // ucp_am_recv_data_nbx asynchronously?). - if (frame->type == FrameType::kPayloadBody) { - ARROW_ASSIGN_OR_RAISE(frame->buffer, - memory_manager_->AllocateBuffer(data_length)); - } else { - ARROW_ASSIGN_OR_RAISE(frame->buffer, - AllocateBuffer(data_length, read_memory_pool_)); - } - - PendingAmRecv* pending_recv = new PendingAmRecv(this, std::move(frame)); - TryMapBuffer(worker_->context().get(), *pending_recv->frame->buffer, - &pending_recv->memh_p); - - ucp_request_param_t recv_param; - recv_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | - UCP_OP_ATTR_FIELD_MEMORY_TYPE | - UCP_OP_ATTR_FIELD_USER_DATA; - recv_param.cb.recv_am = AmRecvCallback; - recv_param.user_data = pending_recv; - recv_param.memory_type = InferMemoryType(*pending_recv->frame->buffer); - - void* dest = - reinterpret_cast(pending_recv->frame->buffer->mutable_address()); - void* request = - ucp_am_recv_data_nbx(worker_->get(), data, dest, data_length, &recv_param); - if (UCS_PTR_IS_ERR(request)) { - delete pending_recv; - return FromUcsStatus("ucp_am_recv_data_nbx", UCS_PTR_STATUS(request)); - } else if (!request) { - // Request completed instantly - Push(std::move(pending_recv->frame)); - delete pending_recv; - } - return UCS_OK; - } else { - // Data will be freed after callback returns - copy to buffer - if (memory_manager_->is_cpu() || frame->type != FrameType::kPayloadBody) { - ARROW_ASSIGN_OR_RAISE(frame->buffer, - AllocateBuffer(data_length, read_memory_pool_)); - std::memcpy(frame->buffer->mutable_data(), data, data_length); - } else { - ARROW_ASSIGN_OR_RAISE( - frame->buffer, - MemoryManager::CopyNonOwned(Buffer(reinterpret_cast(data), - static_cast(data_length)), - memory_manager_)); - } - Push(std::move(frame)); - return UCS_OK; - } - } - - Status CompleteRequestBlocking(const std::string& context, void* request) { - if (UCS_PTR_IS_ERR(request)) { - return FromUcsStatus(context, UCS_PTR_STATUS(request)); - } else if (UCS_PTR_IS_PTR(request)) { - while (true) { - auto status = ucp_request_check_status(request); - if (status == UCS_OK) { - break; - } else if (status != UCS_INPROGRESS) { - ucp_request_release(request); - return FromUcsStatus("ucp_request_check_status", status); - } - MakeProgress(); - } - ucp_request_free(request); - } else { - // Send was completed instantly - DCHECK(!request); - } - return Status::OK(); - } - - Status CheckClosed() { - if (!endpoint_) { - return Status::Invalid("UcpCallDriver is closed"); - } - return Status::OK(); - } - - const std::array padding_bytes_; -#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP) - ucp_mem_h padding_memh_p_; -#endif - - std::shared_ptr worker_; - ucp_ep_h endpoint_; - MemoryPool* read_memory_pool_; - MemoryPool* write_memory_pool_; - std::shared_ptr memory_manager_; - - // Internal name for logging/tracing - std::string name_; - // Counter used to reorder messages - uint32_t counter_ = 0; - - std::mutex frame_mutex_; - Status status_; - std::unordered_map>> frames_; - uint32_t next_counter_ = 0; -}; - -UcpCallDriver::UcpCallDriver(std::shared_ptr worker, ucp_ep_h endpoint) - : impl_(new Impl(std::move(worker), endpoint)) {} -UcpCallDriver::UcpCallDriver(UcpCallDriver&&) = default; -UcpCallDriver& UcpCallDriver::operator=(UcpCallDriver&&) = default; -UcpCallDriver::~UcpCallDriver() = default; - -arrow::Result> UcpCallDriver::ReadNextFrame() { - return impl_->ReadNextFrame(); -} - -Future> UcpCallDriver::ReadFrameAsync() { - return impl_->ReadFrameAsync(); -} - -Status UcpCallDriver::ExpectFrameType(const Frame& frame, FrameType type) { - if (frame.type != type) { - return Status::IOError("Expected frame type ", static_cast(type), - ", but got frame type ", static_cast(frame.type)); - } - return Status::OK(); -} - -Status UcpCallDriver::StartCall(const std::string& method) { - std::vector> headers; - headers.emplace_back(kHeaderMethod, method); - ARROW_ASSIGN_OR_RAISE(auto frame, HeadersFrame::Make(headers)); - auto buffer = std::move(frame).GetBuffer(); - RETURN_NOT_OK(impl_->SendFrame(FrameType::kHeaders, buffer->data(), buffer->size())); - return Status::OK(); -} - -Future<> UcpCallDriver::SendFlightPayload(const FlightPayload& payload) { - return impl_->SendFlightPayload(payload); -} - -Status UcpCallDriver::SendFrame(FrameType frame_type, const uint8_t* data, - const int64_t size) { - return impl_->SendFrame(frame_type, data, size); -} - -Future<> UcpCallDriver::SendFrameAsync(FrameType frame_type, - std::unique_ptr buffer) { - return impl_->SendFrameAsync(frame_type, std::move(buffer)); -} - -Status UcpCallDriver::Close() { return impl_->Close(); } - -void UcpCallDriver::MakeProgress() { impl_->MakeProgress(); } - -ucs_status_t UcpCallDriver::RecvActiveMessage(const void* header, size_t header_length, - void* data, const size_t data_length, - const ucp_am_recv_param_t* param) { - return impl_->RecvActiveMessage(header, header_length, data, data_length, param); -} - -const std::shared_ptr& UcpCallDriver::memory_manager() const { - return impl_->memory_manager(); -} - -void UcpCallDriver::set_memory_manager(std::shared_ptr memory_manager) { - impl_->set_memory_manager(std::move(memory_manager)); -} -void UcpCallDriver::set_read_memory_pool(MemoryPool* pool) { - impl_->set_read_memory_pool(pool); -} -void UcpCallDriver::set_write_memory_pool(MemoryPool* pool) { - impl_->set_write_memory_pool(pool); -} -const std::string& UcpCallDriver::peer() const { return impl_->peer(); } - -} // namespace ucx -} // namespace transport -} // namespace flight -} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_internal.h b/cpp/src/arrow/flight/transport/ucx/ucx_internal.h deleted file mode 100644 index c46f81eb7498b..0000000000000 --- a/cpp/src/arrow/flight/transport/ucx/ucx_internal.h +++ /dev/null @@ -1,363 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// Common implementation of UCX communication primitives. - -#pragma once - -#include -#include -#include -#include -#include - -#include - -#include "arrow/buffer.h" -#include "arrow/flight/server.h" -#include "arrow/flight/transport.h" -#include "arrow/flight/transport/ucx/util_internal.h" -#include "arrow/flight/visibility.h" -#include "arrow/type_fwd.h" -#include "arrow/util/future.h" -#include "arrow/util/logging.h" -#include "arrow/util/macros.h" - -namespace arrow { -namespace flight { -namespace transport { -namespace ucx { - -//------------------------------------------------------------ -// Protocol Constants - -static constexpr char kMethodDoExchange[] = "DoExchange"; -static constexpr char kMethodDoGet[] = "DoGet"; -static constexpr char kMethodDoPut[] = "DoPut"; -static constexpr char kMethodGetFlightInfo[] = "GetFlightInfo"; -static constexpr char kMethodPollFlightInfo[] = "PollFlightInfo"; - -/// The header encoding the transport status. -static constexpr char kHeaderStatus[] = "flight-status"; -/// The header encoding the transport status. -static constexpr char kHeaderMessage[] = "flight-message"; -/// The header encoding the C++ status. -static constexpr char kHeaderStatusCode[] = "flight-status-code"; -/// The header encoding the C++ status message. -static constexpr char kHeaderStatusMessage[] = "flight-status-message"; -/// The header encoding the C++ status detail message. -static constexpr char kHeaderStatusDetail[] = "flight-status-detail"; -/// The header encoding the C++ status detail binary data. -static constexpr char kHeaderStatusDetailBin[] = "flight-status-detail-bin"; - -//------------------------------------------------------------ -// UCX Helpers - -/// \brief A wrapper around a ucp_context_h. -/// -/// Used so that multiple resources can share ownership of the -/// context. UCX has zero-copy optimizations where an application can -/// directly use a UCX buffer, but the lifetime of such buffers is -/// tied to the UCX context and worker, so ownership needs to be -/// preserved. -class UcpContext final { - public: - UcpContext() : ucp_context_(nullptr) {} - explicit UcpContext(ucp_context_h context) : ucp_context_(context) {} - ~UcpContext() { - if (ucp_context_) ucp_cleanup(ucp_context_); - ucp_context_ = nullptr; - } - ucp_context_h get() const { - DCHECK(ucp_context_); - return ucp_context_; - } - - private: - ucp_context_h ucp_context_; -}; - -/// \brief A wrapper around a ucp_worker_h. -class UcpWorker final { - public: - UcpWorker() : ucp_worker_(nullptr) {} - UcpWorker(std::shared_ptr context, ucp_worker_h worker) - : ucp_context_(std::move(context)), ucp_worker_(worker) {} - ~UcpWorker() { - if (ucp_worker_) ucp_worker_destroy(ucp_worker_); - ucp_worker_ = nullptr; - } - ucp_worker_h get() const { - DCHECK(ucp_worker_); - return ucp_worker_; - } - const UcpContext& context() const { return *ucp_context_; } - - private: - std::shared_ptr ucp_context_; - ucp_worker_h ucp_worker_; -}; - -//------------------------------------------------------------ -// Message Framing - -/// \brief The message type. -enum class FrameType : uint8_t { - /// Key-value headers. Sent at the beginning (client->server) and - /// end (server->client) of a call. Also, for client-streaming calls - /// (e.g. DoPut), the client should send a headers frame to signal - /// end-of-stream. - kHeaders = 0, - /// Binary blob, does not contain Arrow data. - kBuffer, - /// Binary blob. Contains IPC metadata, app metadata. - kPayloadHeader, - /// Binary blob. Contains IPC body. Body is sent separately since it - /// may use a different memory type. - kPayloadBody, - /// Ask server to disconnect (to avoid client/server waiting on each - /// other and getting stuck). - kDisconnect, - /// Keep at end. - kMaxFrameType = kDisconnect, -}; - -/// \brief The header of a message frame. Used when sending only. -/// -/// A frame is expected to be sent over UCP Active Messages and -/// consists of a header (of kFrameHeaderBytes bytes) and a body. -/// -/// The header is as follows: -/// +-------+---------------------------------+ -/// | Bytes | Function | -/// +=======+=================================+ -/// | 0 | Version tag (see kFrameVersion) | -/// | 1 | Frame type (see FrameType) | -/// | 2-3 | Unused, reserved | -/// | 4-7 | Frame counter (big-endian) | -/// | 8-11 | Body size (big-endian) | -/// +-------+---------------------------------+ -/// -/// The frame counter lets the receiver ensure messages are processed -/// in-order. (The message receive callback may use -/// ucp_am_recv_data_nbx which is asynchronous.) -/// -/// The body size reports the expected message size (UCX chokes on -/// zero-size payloads which we occasionally want to send, so the size -/// field in the header lets us know when a payload was meant to be -/// empty). -struct FrameHeader { - /// \brief The size of a frame header. - static constexpr size_t kFrameHeaderBytes = 12; - /// \brief The expected version tag in the header. - static constexpr uint8_t kFrameVersion = 0x01; - - FrameHeader() = default; - /// \brief Initialize the frame header. - Status Set(FrameType frame_type, uint32_t counter, int64_t body_size); - void* data() const { return header.data(); } - size_t size() const { return kFrameHeaderBytes; } - - // mutable since UCX expects void* not const void* - mutable std::array header = {0}; -}; - -/// \brief A single message received via UCX. Used when receiving only. -struct Frame { - /// \brief The message type. - FrameType type; - /// \brief The message length. - uint32_t size; - /// \brief An incrementing message counter (may wrap over). - uint32_t counter; - /// \brief The message contents. - std::unique_ptr buffer; - - Frame() = default; - Frame(FrameType type_, uint32_t size_, uint32_t counter_, - std::unique_ptr buffer_) - : type(type_), size(size_), counter(counter_), buffer(std::move(buffer_)) {} - - std::string_view view() const { - return std::string_view(reinterpret_cast(buffer->data()), size); - } - - /// \brief Parse a UCX active message header. This will not - /// initialize the buffer field. - static arrow::Result> ParseHeader(const void* header, - size_t header_length); -}; - -/// \brief The active message handler callback ID. -static constexpr uint32_t kUcpAmHandlerId = 0x1024; - -/// \brief A collection of key-value headers. -/// -/// This should be stored in a frame of type kHeaders. -/// -/// Format: -/// +-------+----------------------------------+ -/// | Bytes | Contents | -/// +=======+==================================+ -/// | 0-4 | # of headers (big-endian) | -/// | 4-8 | Header key length (big-endian) | -/// | 2-3 | Header value length (big-endian) | -/// | (...) | Header key | -/// | (...) | Header value | -/// | (...) | (repeat from row 2) | -/// +-------+----------------------------------+ -class HeadersFrame { - public: - /// \brief Get a header value (or an error if it was not found) - arrow::Result Get(const std::string& key); - /// \brief Extract the server-sent status. - Status GetStatus(Status* out); - /// \brief Parse the headers from the buffer. - static arrow::Result Parse(std::unique_ptr buffer); - /// \brief Create a new frame with the given headers. - static arrow::Result Make( - const std::vector>& headers); - /// \brief Create a new frame with the given headers and the given status. - static arrow::Result Make( - const Status& status, - const std::vector>& headers); - - /// \brief Take ownership of the underlying buffer. - std::unique_ptr GetBuffer() && { return std::move(buffer_); } - - private: - std::unique_ptr buffer_; - std::vector> headers_; -}; - -/// \brief A representation of a kPayloadHeader frame (i.e. all of the -/// metadata in a FlightPayload/FlightData). -/// -/// Data messages are sent in two parts: one containing all metadata -/// (the Flatbuffers header, FlightDescriptor, and app_metadata -/// fields) and one containing the actual data. This was done to avoid -/// having to concatenate these fields with the data itself (in the -/// cases where we are not using IOV). -/// -/// Format: -/// +--------+----------------------------------+ -/// | Bytes | Contents | -/// +========+==================================+ -/// | 0-4 | Descriptor length (big-endian) | -/// | 4..a | Descriptor bytes | -/// | a-a+4 | app_metadata length (big-endian) | -/// | a+4..b | app_metadata bytes | -/// | b-b+4 | ipc_metadata length (big-endian) | -/// | b+4..c | ipc_metadata bytes | -/// +--------+----------------------------------+ -/// -/// If a field is not present, its length is still there, but is set -/// to UINT32_MAX. -class PayloadHeaderFrame { - public: - explicit PayloadHeaderFrame(std::unique_ptr buffer) - : buffer_(std::move(buffer)) {} - /// \brief Unpack the internal buffer into a FlightData. - Status ToFlightData(internal::FlightData* data); - /// \brief Pack a payload into the internal buffer. - static arrow::Result Make(const FlightPayload& payload, - MemoryPool* memory_pool); - const uint8_t* data() const { return buffer_->data(); } - int64_t size() const { return buffer_->size(); } - - private: - std::unique_ptr buffer_; -}; - -/// \brief Manage the state of a UCX connection. -class UcpCallDriver { - public: - UcpCallDriver(std::shared_ptr worker, ucp_ep_h endpoint); - - UcpCallDriver(const UcpCallDriver&) = delete; - UcpCallDriver(UcpCallDriver&&); - void operator=(const UcpCallDriver&) = delete; - UcpCallDriver& operator=(UcpCallDriver&&); - - ~UcpCallDriver(); - - /// \brief Start a call by sending a headers frame. Client side only. - /// - /// \param[in] method The RPC method. - Status StartCall(const std::string& method); - - /// \brief Synchronously send a generic message with binary payload. - Status SendFrame(FrameType frame_type, const uint8_t* data, const int64_t size); - /// \brief Asynchronously send a generic message with binary payload. - /// - /// The UCP driver must be manually polled (call MakeProgress()). - Future<> SendFrameAsync(FrameType frame_type, std::unique_ptr buffer); - /// \brief Asynchronously send a data message. - /// - /// The UCP driver must be manually polled (call MakeProgress()). - Future<> SendFlightPayload(const FlightPayload& payload); - - /// \brief Synchronously read the next frame. - arrow::Result> ReadNextFrame(); - /// \brief Asynchronously read the next frame. - /// - /// The UCP driver must be manually polled (call MakeProgress()). - Future> ReadFrameAsync(); - - /// \brief Validate that the frame is of the given type. - Status ExpectFrameType(const Frame& frame, FrameType type); - - /// \brief Disconnect the other side of the connection. Note, this - /// can cause deadlock. - Status Close(); - - /// \brief Synchronously make progress (to adapt async to sync APIs) - void MakeProgress(); - - /// \brief Get the associated memory manager. - const std::shared_ptr& memory_manager() const; - /// \brief Set the associated memory manager. - void set_memory_manager(std::shared_ptr memory_manager); - /// \brief Set memory pool for scratch space used during reading. - void set_read_memory_pool(MemoryPool* memory_pool); - /// \brief Set memory pool for scratch space used during writing. - void set_write_memory_pool(MemoryPool* memory_pool); - /// \brief Get a debug string naming the peer. - const std::string& peer() const; - - /// \brief Process an incoming active message. This will unblock the - /// corresponding call to ReadFrameAsync/ReadNextFrame. - ucs_status_t RecvActiveMessage(const void* header, size_t header_length, void* data, - const size_t data_length, - const ucp_am_recv_param_t* param); - - private: - class Impl; - std::unique_ptr impl_; -}; - -ARROW_FLIGHT_EXPORT -std::unique_ptr MakeUcxClientImpl(); - -ARROW_FLIGHT_EXPORT -std::unique_ptr MakeUcxServerImpl( - FlightServerBase* base, std::shared_ptr memory_manager); - -} // namespace ucx -} // namespace transport -} // namespace flight -} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc deleted file mode 100644 index 55ff138348812..0000000000000 --- a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc +++ /dev/null @@ -1,653 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/flight/transport/ucx/ucx_internal.h" - -#include -#include -#include -#include -#include -#include - -#include -#include - -#include "arrow/buffer.h" -#include "arrow/flight/server.h" -#include "arrow/flight/transport.h" -#include "arrow/flight/transport/ucx/util_internal.h" -#include "arrow/flight/transport_server.h" -#include "arrow/result.h" -#include "arrow/status.h" -#include "arrow/util/io_util.h" -#include "arrow/util/logging.h" -#include "arrow/util/string.h" -#include "arrow/util/thread_pool.h" -#include "arrow/util/uri.h" - -namespace arrow { - -using internal::ToChars; - -namespace flight { -namespace transport { -namespace ucx { - -// Send an error to the client and return OK. -// Statuses returned up to the main server loop trigger a kReset instead. -#define SERVER_RETURN_NOT_OK(driver, status) \ - do { \ - ::arrow::Status s = (status); \ - if (!s.ok()) { \ - ARROW_ASSIGN_OR_RAISE(auto headers, HeadersFrame::Make(s, {})); \ - auto payload = std::move(headers).GetBuffer(); \ - RETURN_NOT_OK( \ - driver->SendFrame(FrameType::kHeaders, payload->data(), payload->size())); \ - return ::arrow::Status::OK(); \ - } \ - } while (false) - -#define FLIGHT_LOG(LEVEL) (ARROW_LOG(LEVEL) << "[server] ") -#define FLIGHT_LOG_PEER(LEVEL, PEER) \ - (ARROW_LOG(LEVEL) << "[server]" \ - << "[peer=" << (PEER) << "] ") - -namespace { -class UcxServerCallContext : public flight::ServerCallContext { - public: - const std::string& peer_identity() const override { return peer_; } - const std::string& peer() const override { return peer_; } - // Not supported - void AddHeader(const std::string& key, const std::string& value) const override {} - void AddTrailer(const std::string& key, const std::string& value) const override {} - ServerMiddleware* GetMiddleware(const std::string& key) const override { - return nullptr; - } - bool is_cancelled() const override { return false; } - const CallHeaders& incoming_headers() const override { return incoming_headers_; } - - private: - std::string peer_; - CallHeaders incoming_headers_; -}; - -class UcxServerStream : public internal::ServerDataStream { - public: - explicit UcxServerStream(UcpCallDriver* driver) - : peer_(driver->peer()), driver_(driver), writes_done_(false) {} - - Status WritesDone() override { - writes_done_ = true; - return Status::OK(); - } - - protected: - std::string peer_; - UcpCallDriver* driver_; - bool writes_done_; -}; - -class GetServerStream : public UcxServerStream { - public: - using UcxServerStream::UcxServerStream; - - arrow::Result WriteData(const FlightPayload& payload) override { - if (writes_done_) return false; - Future<> pending_send = driver_->SendFlightPayload(payload); - while (!pending_send.is_finished()) { - driver_->MakeProgress(); - } - RETURN_NOT_OK(pending_send.status()); - return true; - } -}; - -class PutServerStream : public UcxServerStream { - public: - explicit PutServerStream(UcpCallDriver* driver) - : UcxServerStream(driver), finished_(false) {} - - bool ReadData(internal::FlightData* data) override { - if (finished_) return false; - - bool success = true; - auto status = ReadImpl(data).Value(&success); - - if (!status.ok() || !success) { - finished_ = true; - if (!status.ok()) { - FLIGHT_LOG_PEER(WARNING, peer_) << "I/O error in DoPut: " << status.ToString(); - return false; - } - } - return success; - } - - Status WritePutMetadata(const Buffer& payload) override { - if (finished_) return Status::OK(); - // Send synchronously (we don't control payload lifetime) - return driver_->SendFrame(FrameType::kBuffer, payload.data(), payload.size()); - } - - private: - ::arrow::Result ReadImpl(internal::FlightData* data) { - ARROW_ASSIGN_OR_RAISE(auto frame, driver_->ReadNextFrame()); - if (frame->type == FrameType::kHeaders) { - // Trailers, client is done writing - return false; - } - RETURN_NOT_OK(driver_->ExpectFrameType(*frame, FrameType::kPayloadHeader)); - PayloadHeaderFrame payload_header(std::move(frame->buffer)); - RETURN_NOT_OK(payload_header.ToFlightData(data)); - - if (data->metadata) { - ARROW_ASSIGN_OR_RAISE(auto message, ipc::Message::Open(data->metadata, nullptr)); - - if (ipc::Message::HasBody(message->type())) { - ARROW_ASSIGN_OR_RAISE(frame, driver_->ReadNextFrame()); - RETURN_NOT_OK(driver_->ExpectFrameType(*frame, FrameType::kPayloadBody)); - data->body = std::move(frame->buffer); - } - } - return true; - } - - bool finished_; -}; - -class ExchangeServerStream : public PutServerStream { - public: - using PutServerStream::PutServerStream; - - arrow::Result WriteData(const FlightPayload& payload) override { - if (writes_done_) return false; - Future<> pending_send = driver_->SendFlightPayload(payload); - while (!pending_send.is_finished()) { - driver_->MakeProgress(); - } - RETURN_NOT_OK(pending_send.status()); - return true; - } - Status WritePutMetadata(const Buffer& payload) override { - return Status::NotImplemented("Not supported on this stream"); - } -}; - -class UcxServerImpl : public arrow::flight::internal::ServerTransport { - public: - using arrow::flight::internal::ServerTransport::ServerTransport; - - virtual ~UcxServerImpl() { - if (listening_.load()) { - ARROW_WARN_NOT_OK(Shutdown(), "Server did not shut down properly"); - } - } - - Status Init(const FlightServerOptions& options, const arrow::util::Uri& uri) override { - const auto max_threads = std::max(8, std::thread::hardware_concurrency()); - ARROW_ASSIGN_OR_RAISE(rpc_pool_, arrow::internal::ThreadPool::Make(max_threads)); - - struct sockaddr_storage listen_addr; - ARROW_ASSIGN_OR_RAISE(auto addrlen, UriToSockaddr(uri, &listen_addr)); - - // Init UCX - { - ucp_config_t* ucp_config; - ucp_params_t ucp_params; - ucs_status_t status; - - status = ucp_config_read(nullptr, nullptr, &ucp_config); - RETURN_NOT_OK(FromUcsStatus("ucp_config_read", status)); - - // If location is IPv6, must adjust UCX config - if (listen_addr.ss_family == AF_INET6) { - status = ucp_config_modify(ucp_config, "AF_PRIO", "inet6"); - RETURN_NOT_OK(FromUcsStatus("ucp_config_modify", status)); - } - - // Allow application to override UCP config - if (options.builder_hook) options.builder_hook(ucp_config); - - std::memset(&ucp_params, 0, sizeof(ucp_params)); - ucp_params.field_mask = - UCP_PARAM_FIELD_FEATURES | UCP_PARAM_FIELD_MT_WORKERS_SHARED; - ucp_params.features = UCP_FEATURE_AM | UCP_FEATURE_WAKEUP; - ucp_params.mt_workers_shared = UCS_THREAD_MODE_MULTI; - - ucp_context_h ucp_context; - status = ucp_init(&ucp_params, ucp_config, &ucp_context); - ucp_config_release(ucp_config); - RETURN_NOT_OK(FromUcsStatus("ucp_init", status)); - ucp_context_.reset(new UcpContext(ucp_context)); - } - - { - // Create one worker to listen for incoming connections. - ucp_worker_params_t worker_params; - ucs_status_t status; - - std::memset(&worker_params, 0, sizeof(worker_params)); - worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; - worker_params.thread_mode = UCS_THREAD_MODE_MULTI; - ucp_worker_h worker; - status = ucp_worker_create(ucp_context_->get(), &worker_params, &worker); - RETURN_NOT_OK(FromUcsStatus("ucp_worker_create", status)); - worker_conn_.reset(new UcpWorker(ucp_context_, worker)); - } - - // Start listening for connections. - { - ucp_listener_params_t params; - ucs_status_t status; - - params.field_mask = - UCP_LISTENER_PARAM_FIELD_SOCK_ADDR | UCP_LISTENER_PARAM_FIELD_CONN_HANDLER; - params.sockaddr.addr = reinterpret_cast(&listen_addr); - params.sockaddr.addrlen = addrlen; - params.conn_handler.cb = HandleIncomingConnection; - params.conn_handler.arg = this; - - status = ucp_listener_create(worker_conn_->get(), ¶ms, &listener_); - RETURN_NOT_OK(FromUcsStatus("ucp_listener_create", status)); - - // Get the real address/port - ucp_listener_attr_t attr; - attr.field_mask = UCP_LISTENER_ATTR_FIELD_SOCKADDR; - status = ucp_listener_query(listener_, &attr); - RETURN_NOT_OK(FromUcsStatus("ucp_listener_query", status)); - - std::string raw_uri = "ucx://"; - if (uri.host().find(':') != std::string::npos) { - // IPv6 host - raw_uri += '['; - raw_uri += uri.host(); - raw_uri += ']'; - } else { - raw_uri += uri.host(); - } - raw_uri += ":"; - raw_uri += - ToChars(ntohs(reinterpret_cast(&attr.sockaddr)->sin_port)); - std::string listen_str; - ARROW_UNUSED(SockaddrToString(attr.sockaddr).Value(&listen_str)); - FLIGHT_LOG(DEBUG) << "Listening on " << listen_str; - ARROW_ASSIGN_OR_RAISE(location_, Location::Parse(raw_uri)); - } - - { - listening_.store(true); - std::thread listener_thread(&UcxServerImpl::DriveConnections, this); - listener_thread_.swap(listener_thread); - } - - return Status::OK(); - } - - Status Shutdown() override { - if (!listening_.load()) return Status::OK(); - Status status; - - // Wait for current RPCs to finish - listening_.store(false); - // Unstick the listener thread from ucp_worker_wait - RETURN_NOT_OK( - FromUcsStatus("ucp_worker_signal", ucp_worker_signal(worker_conn_->get()))); - status &= Wait(); - - { - // Reject all pending connections - std::unique_lock guard(pending_connections_mutex_); - while (!pending_connections_.empty()) { - status &= - FromUcsStatus("ucp_listener_reject", - ucp_listener_reject(listener_, pending_connections_.front())); - pending_connections_.pop(); - } - ucp_listener_destroy(listener_); - worker_conn_.reset(); - } - - status &= rpc_pool_->Shutdown(); - rpc_pool_.reset(); - - ucp_context_.reset(); - return status; - } - - Status Shutdown(const std::chrono::system_clock::time_point& deadline) override { - // TODO(ARROW-16125): implement shutdown with deadline - return Shutdown(); - } - - Status Wait() override { - std::lock_guard guard(join_mutex_); - try { - listener_thread_.join(); - } catch (const std::system_error& e) { - if (e.code() != std::errc::invalid_argument) { - return Status::UnknownError("Could not Wait(): ", e.what()); - } - // Else, server wasn't running anyways - } - return Status::OK(); - } - - Location location() const override { return location_; } - - private: - struct ClientWorker { - std::shared_ptr worker; - std::unique_ptr driver; - }; - - Status SendStatus(UcpCallDriver* driver, const Status& status) { - ARROW_ASSIGN_OR_RAISE(auto headers, HeadersFrame::Make(status, {})); - auto payload = std::move(headers).GetBuffer(); - RETURN_NOT_OK( - driver->SendFrame(FrameType::kHeaders, payload->data(), payload->size())); - return Status::OK(); - } - - Status HandleGetFlightInfo(UcpCallDriver* driver) { - UcxServerCallContext context; - - ARROW_ASSIGN_OR_RAISE(auto frame, driver->ReadNextFrame()); - SERVER_RETURN_NOT_OK(driver, driver->ExpectFrameType(*frame, FrameType::kBuffer)); - FlightDescriptor descriptor; - SERVER_RETURN_NOT_OK(driver, - FlightDescriptor::Deserialize(std::string_view(*frame->buffer)) - .Value(&descriptor)); - - std::unique_ptr info; - std::string response; - SERVER_RETURN_NOT_OK(driver, base_->GetFlightInfo(context, descriptor, &info)); - SERVER_RETURN_NOT_OK(driver, info->DoSerializeToString(&response)); - RETURN_NOT_OK(driver->SendFrame(FrameType::kBuffer, - reinterpret_cast(response.data()), - static_cast(response.size()))); - RETURN_NOT_OK(SendStatus(driver, Status::OK())); - return Status::OK(); - } - - Status HandlePollFlightInfo(UcpCallDriver* driver) { - UcxServerCallContext context; - - ARROW_ASSIGN_OR_RAISE(auto frame, driver->ReadNextFrame()); - SERVER_RETURN_NOT_OK(driver, driver->ExpectFrameType(*frame, FrameType::kBuffer)); - FlightDescriptor descriptor; - SERVER_RETURN_NOT_OK(driver, - FlightDescriptor::Deserialize(std::string_view(*frame->buffer)) - .Value(&descriptor)); - - std::unique_ptr info; - std::string response; - SERVER_RETURN_NOT_OK(driver, base_->PollFlightInfo(context, descriptor, &info)); - SERVER_RETURN_NOT_OK(driver, info->DoSerializeToString(&response)); - RETURN_NOT_OK(driver->SendFrame(FrameType::kBuffer, - reinterpret_cast(response.data()), - static_cast(response.size()))); - RETURN_NOT_OK(SendStatus(driver, Status::OK())); - return Status::OK(); - } - - Status HandleDoGet(UcpCallDriver* driver) { - UcxServerCallContext context; - - ARROW_ASSIGN_OR_RAISE(auto frame, driver->ReadNextFrame()); - SERVER_RETURN_NOT_OK(driver, driver->ExpectFrameType(*frame, FrameType::kBuffer)); - Ticket ticket; - SERVER_RETURN_NOT_OK(driver, Ticket::Deserialize(frame->view()).Value(&ticket)); - - GetServerStream stream(driver); - auto status = DoGet(context, std::move(ticket), &stream); - RETURN_NOT_OK(SendStatus(driver, status)); - return Status::OK(); - } - - Status HandleDoPut(UcpCallDriver* driver) { - UcxServerCallContext context; - - PutServerStream stream(driver); - auto status = DoPut(context, &stream); - RETURN_NOT_OK(SendStatus(driver, status)); - // Must drain any unread messages, or the next call will get confused - internal::FlightData ignored; - while (stream.ReadData(&ignored)) { - } - return Status::OK(); - } - - Status HandleDoExchange(UcpCallDriver* driver) { - UcxServerCallContext context; - - ExchangeServerStream stream(driver); - auto status = DoExchange(context, &stream); - RETURN_NOT_OK(SendStatus(driver, status)); - // Must drain any unread messages, or the next call will get confused - internal::FlightData ignored; - while (stream.ReadData(&ignored)) { - } - return Status::OK(); - } - - Status HandleOneCall(UcpCallDriver* driver, Frame* frame) { - SERVER_RETURN_NOT_OK(driver, driver->ExpectFrameType(*frame, FrameType::kHeaders)); - ARROW_ASSIGN_OR_RAISE(auto headers, HeadersFrame::Parse(std::move(frame->buffer))); - ARROW_ASSIGN_OR_RAISE(auto method, headers.Get(":method:")); - if (method == kMethodGetFlightInfo) { - return HandleGetFlightInfo(driver); - } else if (method == kMethodPollFlightInfo) { - return HandlePollFlightInfo(driver); - } else if (method == kMethodDoExchange) { - return HandleDoExchange(driver); - } else if (method == kMethodDoGet) { - return HandleDoGet(driver); - } else if (method == kMethodDoPut) { - return HandleDoPut(driver); - } - RETURN_NOT_OK(SendStatus(driver, Status::NotImplemented(method))); - return Status::OK(); - } - - void WorkerLoop(ucp_conn_request_h request) { - std::string peer = "unknown:" + ToChars(counter_++); - { - ucp_conn_request_attr_t request_attr; - std::memset(&request_attr, 0, sizeof(request_attr)); - request_attr.field_mask = UCP_CONN_REQUEST_ATTR_FIELD_CLIENT_ADDR; - if (ucp_conn_request_query(request, &request_attr) == UCS_OK) { - ARROW_UNUSED(SockaddrToString(request_attr.client_address).Value(&peer)); - } - } - FLIGHT_LOG_PEER(DEBUG, peer) << "Received connection request"; - - auto maybe_worker = CreateWorker(); - if (!maybe_worker.ok()) { - FLIGHT_LOG_PEER(WARNING, peer) - << "Failed to create worker" << maybe_worker.status().ToString(); - auto status = ucp_listener_reject(listener_, request); - if (status != UCS_OK) { - FLIGHT_LOG_PEER(WARNING, peer) - << FromUcsStatus("ucp_listener_reject", status).ToString(); - } - return; - } - auto worker = maybe_worker.MoveValueUnsafe(); - - // Create an endpoint to the client, using the data worker - { - ucs_status_t status; - ucp_ep_params_t params; - std::memset(¶ms, 0, sizeof(params)); - params.field_mask = UCP_EP_PARAM_FIELD_CONN_REQUEST; - params.conn_request = request; - - ucp_ep_h client_endpoint; - - status = ucp_ep_create(worker->worker->get(), ¶ms, &client_endpoint); - if (status != UCS_OK) { - FLIGHT_LOG_PEER(WARNING, peer) - << "Failed to create endpoint: " - << FromUcsStatus("ucp_ep_create", status).ToString(); - return; - } - worker->driver.reset(new UcpCallDriver(worker->worker, client_endpoint)); - worker->driver->set_memory_manager(memory_manager_); - peer = worker->driver->peer(); - } - - while (listening_.load()) { - auto maybe_frame = worker->driver->ReadNextFrame(); - if (!maybe_frame.ok()) { - if (!maybe_frame.status().IsCancelled()) { - FLIGHT_LOG_PEER(WARNING, peer) - << "Failed to read next message: " << maybe_frame.status().ToString(); - } - break; - } - - auto status = HandleOneCall(worker->driver.get(), maybe_frame->get()); - if (!status.ok()) { - FLIGHT_LOG_PEER(WARNING, peer) << "Call failed: " << status.ToString(); - break; - } - } - - // Clean up - auto status = worker->driver->Close(); - if (!status.ok()) { - FLIGHT_LOG_PEER(WARNING, peer) << "Failed to close worker: " << status.ToString(); - } - worker->worker.reset(); - FLIGHT_LOG_PEER(DEBUG, peer) << "Disconnected"; - } - - void DriveConnections() { - while (listening_.load()) { - while (ucp_worker_progress(worker_conn_->get())) { - } - { - // Check for connect requests in queue - std::unique_lock guard(pending_connections_mutex_); - while (!pending_connections_.empty()) { - ucp_conn_request_h request = pending_connections_.front(); - pending_connections_.pop(); - - auto submitted = rpc_pool_->Submit([this, request]() { WorkerLoop(request); }); - ARROW_WARN_NOT_OK(submitted.status(), "Failed to submit task to handle client"); - } - } - - // Check listening_ in case we're shutting down. It is possible - // that Shutdown() was called while we were in - // ucp_worker_progress above, in which case if we don't check - // listening_ here, we'll enter ucp_worker_wait and get stuck. - if (!listening_.load()) break; - auto status = ucp_worker_wait(worker_conn_->get()); - if (status != UCS_OK) { - FLIGHT_LOG(WARNING) << FromUcsStatus("ucp_worker_wait", status).ToString(); - } - } - } - - void EnqueueClient(ucp_conn_request_h connection_request) { - std::unique_lock guard(pending_connections_mutex_); - pending_connections_.push(connection_request); - guard.unlock(); - } - - arrow::Result> CreateWorker() { - auto worker = std::make_shared(); - - ucp_worker_params_t worker_params; - std::memset(&worker_params, 0, sizeof(worker_params)); - worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; - worker_params.thread_mode = UCS_THREAD_MODE_SINGLE; - - ucp_worker_h ucp_worker; - auto status = ucp_worker_create(ucp_context_->get(), &worker_params, &ucp_worker); - RETURN_NOT_OK(FromUcsStatus("ucp_worker_create", status)); - worker->worker.reset(new UcpWorker(ucp_context_, ucp_worker)); - - // Set up Active Message (AM) handler - ucp_am_handler_param_t handler_params; - std::memset(&handler_params, 0, sizeof(handler_params)); - handler_params.field_mask = UCP_AM_HANDLER_PARAM_FIELD_ID | - UCP_AM_HANDLER_PARAM_FIELD_CB | - UCP_AM_HANDLER_PARAM_FIELD_ARG; - handler_params.id = kUcpAmHandlerId; - handler_params.cb = HandleIncomingActiveMessage; - handler_params.arg = worker.get(); - - status = ucp_worker_set_am_recv_handler(worker->worker->get(), &handler_params); - RETURN_NOT_OK(FromUcsStatus("ucp_worker_set_am_recv_handler", status)); - return worker; - } - - // Callback handler. A new client has connected to the server. - static void HandleIncomingConnection(ucp_conn_request_h connection_request, - void* data) { - UcxServerImpl* server = reinterpret_cast(data); - // TODO(ARROW-16124): enable shedding load above some threshold - // (which is a pitfall with gRPC/Java) - server->EnqueueClient(connection_request); - } - - static ucs_status_t HandleIncomingActiveMessage(void* self, const void* header, - size_t header_length, void* data, - size_t data_length, - const ucp_am_recv_param_t* param) { - ClientWorker* worker = reinterpret_cast(self); - DCHECK(worker->driver); - return worker->driver->RecvActiveMessage(header, header_length, data, data_length, - param); - } - - std::shared_ptr ucp_context_; - // Listen for and handle incoming connections - std::shared_ptr worker_conn_; - ucp_listener_h listener_; - Location location_; - - // Counter for identifying peers when UCX doesn't give us a way - std::atomic counter_; - - std::shared_ptr rpc_pool_; - std::atomic listening_; - std::thread listener_thread_; - // std::thread::join cannot be called concurrently - std::mutex join_mutex_; - - std::mutex pending_connections_mutex_; - std::queue pending_connections_; -}; -} // namespace - -std::unique_ptr MakeUcxServerImpl( - FlightServerBase* base, std::shared_ptr memory_manager) { - return std::make_unique(base, memory_manager); -} - -#undef SERVER_RETURN_NOT_OK -#undef FLIGHT_LOG -#undef FLIGHT_LOG_PEER - -} // namespace ucx -} // namespace transport -} // namespace flight -} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/ucx/util_internal.cc b/cpp/src/arrow/flight/transport/ucx/util_internal.cc deleted file mode 100644 index 2db7d4e2630ff..0000000000000 --- a/cpp/src/arrow/flight/transport/ucx/util_internal.cc +++ /dev/null @@ -1,293 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/flight/transport/ucx/util_internal.h" - -#include -#include -#include - -#include -#include -#include -#include - -#include "arrow/buffer.h" -#include "arrow/flight/types.h" -#include "arrow/util/base64.h" -#include "arrow/util/bit_util.h" -#include "arrow/util/io_util.h" -#include "arrow/util/logging.h" -#include "arrow/util/string.h" -#include "arrow/util/uri.h" - -namespace arrow { - -using internal::ToChars; - -namespace flight { -namespace transport { -namespace ucx { - -constexpr char FlightUcxStatusDetail::kTypeId[]; -std::string FlightUcxStatusDetail::ToString() const { return ucs_status_string(status_); } -ucs_status_t FlightUcxStatusDetail::Unwrap(const Status& status) { - if (!status.detail() || status.detail()->type_id() != kTypeId) return UCS_OK; - return dynamic_cast(status.detail().get())->status_; -} - -arrow::Result UriToSockaddr(const arrow::util::Uri& uri, - struct sockaddr_storage* addr) { - std::string host = uri.host(); - if (host.empty()) { - return Status::Invalid("Must provide a host"); - } else if (uri.port() < 0) { - return Status::Invalid("Must provide a port"); - } - - std::memset(addr, 0, sizeof(*addr)); - - struct addrinfo* info = nullptr; - int err = getaddrinfo(host.c_str(), /*service=*/nullptr, /*hints=*/nullptr, &info); - if (err != 0) { - if (err == EAI_SYSTEM) { - return arrow::internal::IOErrorFromErrno(errno, "[getaddrinfo] Failure resolving ", - host); - } else { - return Status::IOError("[getaddrinfo] Failure resolving ", host, ": ", - gai_strerror(err)); - } - } - - struct addrinfo* cur_info = info; - while (cur_info) { - if (cur_info->ai_family != AF_INET && cur_info->ai_family != AF_INET6) { - cur_info = cur_info->ai_next; - continue; - } - - std::memcpy(addr, cur_info->ai_addr, cur_info->ai_addrlen); - if (cur_info->ai_family == AF_INET) { - reinterpret_cast(addr)->sin_port = htons(uri.port()); - } else if (cur_info->ai_family == AF_INET6) { - reinterpret_cast(addr)->sin6_port = htons(uri.port()); - } - size_t addrlen = cur_info->ai_addrlen; - freeaddrinfo(info); - return addrlen; - } - - if (info) freeaddrinfo(info); - return Status::IOError("[getaddrinfo] Failure resolving ", host, - ": no results of a supported family returned"); -} - -arrow::Result SockaddrToString(const struct sockaddr_storage& address) { - std::string result = ""; - if (address.ss_family != AF_INET && address.ss_family != AF_INET6) { - return Status::NotImplemented("Unknown address family"); - } - - uint16_t port = 0; - if (address.ss_family == AF_INET) { - result.resize(INET_ADDRSTRLEN + 1); - const auto* in_addr = reinterpret_cast(&address); - if (!inet_ntop(address.ss_family, &in_addr->sin_addr, &result[0], INET_ADDRSTRLEN)) { - return arrow::internal::IOErrorFromErrno(errno, - "Could not convert address to string"); - } - port = ntohs(in_addr->sin_port); - } else { - result.resize(INET6_ADDRSTRLEN + 1); - const auto* in6_addr = reinterpret_cast(&address); - if (!inet_ntop(address.ss_family, &in6_addr->sin6_addr, &result[0], - INET6_ADDRSTRLEN)) { - return arrow::internal::IOErrorFromErrno(errno, - "Could not convert address to string"); - } - port = ntohs(in6_addr->sin6_port); - } - - const size_t pos = result.find('\0'); - DCHECK_NE(pos, std::string::npos); - result[pos] = ':'; - result.resize(pos + 1); - result += ToChars(port); - return result; -} - -Status FromUcsStatus(const std::string& context, ucs_status_t ucs_status) { - switch (ucs_status) { - case UCS_OK: - return Status::OK(); - case UCS_INPROGRESS: - return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_INPROGRESS ", ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_NO_MESSAGE: - return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_NO_MESSAGE ", ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_NO_RESOURCE: - return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_NO_RESOURCE ", ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_IO_ERROR: - return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_IO_ERROR ", ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_NO_MEMORY: - return Status::OutOfMemory(context, ": UCX error ", - static_cast(ucs_status), ": ", - "UCS_ERR_NO_MEMORY ", ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_INVALID_PARAM: - return Status::Invalid(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_INVALID_PARAM ", - ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_UNREACHABLE: - return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_UNREACHABLE ", ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_INVALID_ADDR: - return Status::Invalid(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_INVALID_ADDR ", ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_NOT_IMPLEMENTED: - return Status::NotImplemented( - context, ": UCX error ", static_cast(ucs_status), ": ", - "UCS_ERR_NOT_IMPLEMENTED ", ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_MESSAGE_TRUNCATED: - return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_MESSAGE_TRUNCATED ", - ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_NO_PROGRESS: - return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_NO_PROGRESS ", ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_BUFFER_TOO_SMALL: - return Status::Invalid(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_BUFFER_TOO_SMALL ", - ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_NO_ELEM: - return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_NO_ELEM ", ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_SOME_CONNECTS_FAILED: - return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_SOME_CONNECTS_FAILED ", - ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_NO_DEVICE: - return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_NO_DEVICE ", ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_BUSY: - return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_BUSY ", ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_CANCELED: - return Status::Cancelled(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_CANCELED ", ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_SHMEM_SEGMENT: - return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_SHMEM_SEGMENT ", - ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_ALREADY_EXISTS: - return Status::AlreadyExists( - context, ": UCX error ", static_cast(ucs_status), ": ", - "UCS_ERR_ALREADY_EXISTS ", ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_OUT_OF_RANGE: - return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_OUT_OF_RANGE ", ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_TIMED_OUT: - return Status::Cancelled(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_TIMED_OUT ", ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_EXCEEDS_LIMIT: - return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_EXCEEDS_LIMIT ", - ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_UNSUPPORTED: - return Status::NotImplemented(context, ": UCX error ", - static_cast(ucs_status), ": ", - "UCS_ERR_UNSUPPORTED ", ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_REJECTED: - return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_REJECTED ", ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_NOT_CONNECTED: - return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_NOT_CONNECTED ", - ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_CONNECTION_RESET: - return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_CONNECTION_RESET ", - ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_FIRST_LINK_FAILURE: - return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_FIRST_LINK_FAILURE ", - ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_LAST_LINK_FAILURE: - return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_LAST_LINK_FAILURE ", - ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_FIRST_ENDPOINT_FAILURE: - return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_FIRST_ENDPOINT_FAILURE ", - ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_LAST_ENDPOINT_FAILURE: - return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_LAST_ENDPOINT_FAILURE ", - ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_ENDPOINT_TIMEOUT: - return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_ENDPOINT_TIMEOUT ", - ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - case UCS_ERR_LAST: - return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_LAST ", ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - default: - return Status::UnknownError( - context, ": Unknown UCX error: ", static_cast(ucs_status), " ", - ucs_status_string(ucs_status)) - .WithDetail(std::make_shared(ucs_status)); - } -} - -} // namespace ucx -} // namespace transport -} // namespace flight -} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/ucx/util_internal.h b/cpp/src/arrow/flight/transport/ucx/util_internal.h deleted file mode 100644 index 958868d59d4f5..0000000000000 --- a/cpp/src/arrow/flight/transport/ucx/util_internal.h +++ /dev/null @@ -1,83 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include -#include -#include - -#include "arrow/flight/visibility.h" -#include "arrow/status.h" -#include "arrow/util/endian.h" -#include "arrow/util/ubsan.h" -#include "arrow/util/uri.h" - -namespace arrow { -namespace flight { -namespace transport { -namespace ucx { - -static inline void UInt32ToBytesBe(const uint32_t in, uint8_t* out) { - util::SafeStore(out, bit_util::ToBigEndian(in)); -} - -static inline uint32_t BytesToUInt32Be(const uint8_t* in) { - return bit_util::FromBigEndian(util::SafeLoadAs(in)); -} - -class ARROW_FLIGHT_EXPORT FlightUcxStatusDetail : public StatusDetail { - public: - explicit FlightUcxStatusDetail(ucs_status_t status) : status_(status) {} - static constexpr char const kTypeId[] = "flight::transport::ucx::FlightUcxStatusDetail"; - - const char* type_id() const override { return kTypeId; } - std::string ToString() const override; - static ucs_status_t Unwrap(const Status& status); - - private: - ucs_status_t status_; -}; - -/// \brief Convert a UCS status to an Arrow Status. -ARROW_FLIGHT_EXPORT -Status FromUcsStatus(const std::string& context, ucs_status_t ucs_status); - -/// \brief Check if a UCS error code can be ignored in the context of -/// a disconnect. -static inline bool IsIgnorableDisconnectError(ucs_status_t ucs_status) { - // Not connected, connection reset: we're already disconnected - // Timeout: most likely disconnected, but we can't tell from our end - return ucs_status == UCS_OK || ucs_status == UCS_ERR_ENDPOINT_TIMEOUT || - ucs_status == UCS_ERR_NOT_CONNECTED || ucs_status == UCS_ERR_CONNECTION_RESET; -} - -/// \brief Helper to convert a Uri to a struct sockaddr (used in -/// ucp_listener_params_t) -/// -/// \return The length of the sockaddr -ARROW_FLIGHT_EXPORT -arrow::Result UriToSockaddr(const arrow::util::Uri& uri, - struct sockaddr_storage* addr); - -ARROW_FLIGHT_EXPORT -arrow::Result SockaddrToString(const struct sockaddr_storage& address); - -} // namespace ucx -} // namespace transport -} // namespace flight -} // namespace arrow diff --git a/cpp/src/arrow/util/config.h.cmake b/cpp/src/arrow/util/config.h.cmake index 08c2ae173601b..ddff1379b1d3b 100644 --- a/cpp/src/arrow/util/config.h.cmake +++ b/cpp/src/arrow/util/config.h.cmake @@ -63,7 +63,6 @@ #cmakedefine ARROW_WITH_OPENTELEMETRY #cmakedefine ARROW_WITH_RE2 #cmakedefine ARROW_WITH_SNAPPY -#cmakedefine ARROW_WITH_UCX #cmakedefine ARROW_WITH_UTF8PROC #cmakedefine ARROW_WITH_ZLIB #cmakedefine ARROW_WITH_ZSTD diff --git a/cpp/thirdparty/versions.txt b/cpp/thirdparty/versions.txt index ab988badec145..42a6f8788c2fb 100644 --- a/cpp/thirdparty/versions.txt +++ b/cpp/thirdparty/versions.txt @@ -109,8 +109,6 @@ ARROW_S2N_TLS_BUILD_VERSION=v1.3.35 ARROW_S2N_TLS_BUILD_SHA256_CHECKSUM=9d32b26e6bfcc058d98248bf8fc231537e347395dd89cf62bb432b55c5da990d ARROW_THRIFT_BUILD_VERSION=0.16.0 ARROW_THRIFT_BUILD_SHA256_CHECKSUM=f460b5c1ca30d8918ff95ea3eb6291b3951cf518553566088f3f2be8981f6209 -ARROW_UCX_BUILD_VERSION=1.12.1 -ARROW_UCX_BUILD_SHA256_CHECKSUM=9bef31aed0e28bf1973d28d74d9ac4f8926c43ca3b7010bd22a084e164e31b71 ARROW_UTF8PROC_BUILD_VERSION=v2.7.0 ARROW_UTF8PROC_BUILD_SHA256_CHECKSUM=4bb121e297293c0fd55f08f83afab6d35d48f0af4ecc07523ad8ec99aa2b12a1 ARROW_XSIMD_BUILD_VERSION=13.0.0 @@ -165,7 +163,6 @@ DEPENDENCIES=( "ARROW_S2N_TLS_URL s2n-${ARROW_S2N_TLS_BUILD_VERSION}.tar.gz https://github.com/aws/s2n-tls/archive/${ARROW_S2N_TLS_BUILD_VERSION}.tar.gz" "ARROW_SNAPPY_URL snappy-${ARROW_SNAPPY_BUILD_VERSION}.tar.gz https://github.com/google/snappy/archive/${ARROW_SNAPPY_BUILD_VERSION}.tar.gz" "ARROW_THRIFT_URL thrift-${ARROW_THRIFT_BUILD_VERSION}.tar.gz https://archive.apache.org/dist/thrift/${ARROW_THRIFT_BUILD_VERSION}/thrift-${ARROW_THRIFT_BUILD_VERSION}.tar.gz" - "ARROW_UCX_URL ucx-${ARROW_UCX_BUILD_VERSION}.tar.gz https://github.com/openucx/ucx/archive/v${ARROW_UCX_BUILD_VERSION}.tar.gz" "ARROW_UTF8PROC_URL utf8proc-${ARROW_UTF8PROC_BUILD_VERSION}.tar.gz https://github.com/JuliaStrings/utf8proc/archive/${ARROW_UTF8PROC_BUILD_VERSION}.tar.gz" "ARROW_XSIMD_URL xsimd-${ARROW_XSIMD_BUILD_VERSION}.tar.gz https://github.com/xtensor-stack/xsimd/archive/${ARROW_XSIMD_BUILD_VERSION}.tar.gz" "ARROW_ZLIB_URL zlib-${ARROW_ZLIB_BUILD_VERSION}.tar.gz https://zlib.net/fossils/zlib-${ARROW_ZLIB_BUILD_VERSION}.tar.gz" diff --git a/docs/source/cpp/flight.rst b/docs/source/cpp/flight.rst index a1e9420bfd34e..d5322fdead8b9 100644 --- a/docs/source/cpp/flight.rst +++ b/docs/source/cpp/flight.rst @@ -363,37 +363,4 @@ Closing unresponsive connections .. _ARROW-6062: https://issues.apache.org/jira/browse/ARROW-6062 -Alternative Transports -====================== - -The standard transport for Arrow Flight is gRPC_. The C++ -implementation also experimentally supports a transport based on -UCX_. To use it, use the protocol scheme ``ucx:`` when starting a -server or creating a client. - -UCX Transport -------------- - -Not all features of the gRPC transport are supported. See -:ref:`status-flight-rpc` for details. Also note these specific -caveats: - -- The server creates an independent UCP worker for each client. This - consumes more resources but provides better throughput. -- The client creates an independent UCP worker for each RPC - call. Again, this trades off resource consumption for - performance. This also means that unlike with gRPC, it is - essentially equivalent to make all calls with a single client or - with multiple clients. -- The UCX transport attempts to avoid copies where possible. In some - cases, it can directly reuse UCX-allocated buffers to back - :class:`arrow::Buffer` objects, however, this will also extend the - lifetime of associated UCX resources beyond the lifetime of the - Flight client or server object. -- Depending on the transport that UCX itself selects, you may find - that increasing ``UCX_MM_SEG_SIZE`` from the default (around 8KB) to - around 60KB improves performance (UCX will copy more data in a - single call). - .. _gRPC: https://grpc.io/ -.. _UCX: https://openucx.org/ diff --git a/docs/source/format/Flight.rst b/docs/source/format/Flight.rst index 2c5487d857ea4..aac979cf7590b 100644 --- a/docs/source/format/Flight.rst +++ b/docs/source/format/Flight.rst @@ -333,10 +333,6 @@ schemes for the given transports: +----------------------------+--------------------------------+ | (reuse connection) | arrow-flight-reuse-connection: | +----------------------------+--------------------------------+ -| UCX_ (plaintext) | ucx: | -+----------------------------+--------------------------------+ - -.. _UCX: https://openucx.org/ Connection Reuse ---------------- diff --git a/docs/source/status.rst b/docs/source/status.rst index c232aa280befb..ec88306ade635 100644 --- a/docs/source/status.rst +++ b/docs/source/status.rst @@ -190,8 +190,6 @@ Flight RPC +--------------------------------------------+-------+-------+-------+----+-------+-------+-------+-------+ | gRPC + TLS transport (grpc+tls:) | ✓ | ✓ | ✓ | | ✓ | ✓ | | | +--------------------------------------------+-------+-------+-------+----+-------+-------+-------+-------+ -| UCX_ transport (ucx:) | ✓ | | | | | | | | -+--------------------------------------------+-------+-------+-------+----+-------+-------+-------+-------+ Supported features in the gRPC transport: @@ -213,42 +211,16 @@ Supported features in the gRPC transport: | RPC error codes | ✓ | ✓ | ✓ | | ✓ | ✓ | | | +--------------------------------------------+-------+-------+-------+----+-------+-------+-------+-------+ -Supported features in the UCX transport: - -+--------------------------------------------+-------+-------+-------+----+-------+-------+-------+-------+ -| Flight RPC Feature | C++ | Java | Go | JS | C# | Rust | Julia | Swift | -+============================================+=======+=======+=======+====+=======+=======+=======+=======+ -| All RPC methods | ✓ (4) | | | | | | | | -+--------------------------------------------+-------+-------+-------+----+-------+-------+-------+-------+ -| Authentication handlers | | | | | | | | | -+--------------------------------------------+-------+-------+-------+----+-------+-------+-------+-------+ -| Call timeouts | | | | | | | | | -+--------------------------------------------+-------+-------+-------+----+-------+-------+-------+-------+ -| Call cancellation | | | | | | | | | -+--------------------------------------------+-------+-------+-------+----+-------+-------+-------+-------+ -| Concurrent client calls | ✓ (5) | | | | | | | | -+--------------------------------------------+-------+-------+-------+----+-------+-------+-------+-------+ -| Custom middleware | | | | | | | | | -+--------------------------------------------+-------+-------+-------+----+-------+-------+-------+-------+ -| RPC error codes | ✓ | | | | | | | | -+--------------------------------------------+-------+-------+-------+----+-------+-------+-------+-------+ - Notes: * \(1) No support for Handshake or DoExchange. * \(2) Support using AspNetCore authentication handlers. * \(3) Whether a single client can support multiple concurrent calls. -* \(4) Only support for DoExchange, DoGet, DoPut, and GetFlightInfo. -* \(5) Each concurrent call is a separate connection to the server - (unlike gRPC where concurrent calls are multiplexed over a single - connection). This will generally provide better throughput but - consumes more resources both on the server and the client. .. seealso:: The :ref:`flight-rpc` specification. .. _gRPC: https://grpc.io/ -.. _UCX: https://openucx.org/ Flight SQL ==========