Skip to content

Commit

Permalink
Add support for REST based remote functions.
Browse files Browse the repository at this point in the history
Co-authored-by: Wills Feng <[email protected]>
  • Loading branch information
Joe-Abraham and wills-feng committed Nov 27, 2024
1 parent 0ee4687 commit beccbc6
Show file tree
Hide file tree
Showing 10 changed files with 221 additions and 7 deletions.
5 changes: 3 additions & 2 deletions presto-native-execution/presto_cpp/main/PrestoServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1268,10 +1268,11 @@ void PrestoServer::registerRemoteFunctions() {
} else {
VELOX_FAIL(
"To register remote functions using a json file path you need to "
"specify the remote server location using '{}', '{}' or '{}'.",
"specify the remote server location using '{}', '{}' or '{}' or {}.",
SystemConfig::kRemoteFunctionServerThriftAddress,
SystemConfig::kRemoteFunctionServerThriftPort,
SystemConfig::kRemoteFunctionServerThriftUdsPath);
SystemConfig::kRemoteFunctionServerThriftUdsPath,
SystemConfig::kRemoteFunctionServerRestURL);
}
}
#endif
Expand Down
4 changes: 4 additions & 0 deletions presto-native-execution/presto_cpp/main/common/Configs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,10 @@ std::string SystemConfig::remoteFunctionServerSerde() const {
return optionalProperty(kRemoteFunctionServerSerde).value();
}

/// Returns the value configured under kRemoteFunctionServerRestURL
/// ("remote-function-server.rest.url").
/// NOTE(review): value() is called unconditionally on the optional, so this
/// presumably fails when the property is unset and has no default — confirm.
std::string SystemConfig::remoteFunctionRestUrl() const {
  const auto restUrl = optionalProperty(kRemoteFunctionServerRestURL);
  return restUrl.value();
}

/// Returns the value configured under kMaxDriversPerTask, read as int32_t.
/// NOTE(review): value() is called unconditionally on the optional, so a
/// registered default is presumably expected — confirm.
int32_t SystemConfig::maxDriversPerTask() const {
  const auto maxDrivers = optionalProperty<int32_t>(kMaxDriversPerTask);
  return maxDrivers.value();
}
Expand Down
6 changes: 6 additions & 0 deletions presto-native-execution/presto_cpp/main/common/Configs.h
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,10 @@ class SystemConfig : public ConfigBase {
static constexpr std::string_view kRemoteFunctionServerThriftUdsPath{
"remote-function-server.thrift.uds-path"};

/// Config key holding the URL of the REST-based remote function server.
/// The value is read through SystemConfig::remoteFunctionRestUrl().
static constexpr std::string_view kRemoteFunctionServerRestURL{
"remote-function-server.rest.url"};

/// Path where json files containing signatures for remote functions can be
/// found.
static constexpr std::string_view
Expand Down Expand Up @@ -702,6 +706,8 @@ class SystemConfig : public ConfigBase {

std::string remoteFunctionServerSerde() const;

std::string remoteFunctionRestUrl() const;

int32_t maxDriversPerTask() const;

folly::Optional<int32_t> taskWriterCount() const;
Expand Down
8 changes: 8 additions & 0 deletions presto-native-execution/presto_cpp/main/types/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,20 @@ add_library(
presto_types OBJECT
PrestoToVeloxQueryPlan.cpp PrestoToVeloxExpr.cpp VeloxPlanValidator.cpp
PrestoToVeloxSplit.cpp PrestoToVeloxConnector.cpp)

add_dependencies(presto_types presto_operators presto_type_converter velox_type
velox_type_fbhive velox_dwio_dwrf_proto)

target_link_libraries(presto_types presto_type_converter velox_type_fbhive
velox_hive_partition_function velox_tpch_gen)

# When remote function support is enabled, presto_types additionally depends
# on and links against the remote function targets.
# NOTE(review): velox_expression is listed as a build-order dependency but is
# not in the link line below — confirm this is intended.
if(PRESTO_ENABLE_REMOTE_FUNCTIONS)
add_dependencies(presto_types velox_expression presto_server_remote_function
velox_functions_remote)
target_link_libraries(presto_types presto_server_remote_function
velox_functions_remote)
endif()

set_property(TARGET presto_types PROPERTY JOB_POOL_LINK presto_link_job_pool)

add_library(presto_function_metadata OBJECT FunctionMetadata.cpp)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,23 @@

#include "presto_cpp/main/types/PrestoToVeloxExpr.h"
#include <boost/algorithm/string/case_conv.hpp>
#include "presto_cpp/main/common/Configs.h"
#include "presto_cpp/presto_protocol/Base64Util.h"
#include "velox/common/base/Exceptions.h"
#include "velox/functions/prestosql/types/JsonType.h"
#include "velox/vector/ComplexVector.h"
#include "velox/vector/ConstantVector.h"
#include "velox/vector/FlatVector.h"
#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS
#include "presto_cpp/main/JsonSignatureParser.h"
#include "velox/expression/FunctionSignature.h"
#include "velox/functions/remote/client/Remote.h"
#endif

using namespace facebook::velox::core;
#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS
using facebook::velox::functions::remote::PageFormat;
#endif
using facebook::velox::TypeKind;

namespace facebook::presto {
Expand Down Expand Up @@ -412,6 +421,18 @@ std::optional<TypedExprPtr> VeloxExprConverter::tryConvertLike(
returnType, args, getFunctionName(signature));
}

#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS
/// Translates a serde name into the remote function PageFormat.
/// Only "presto_page" is accepted; any other name is rejected via VELOX_FAIL.
PageFormat fromSerdeString(const std::string_view& serdeName) {
  // Guard clause: reject everything except the single supported serde.
  if (serdeName != "presto_page") {
    VELOX_FAIL(
        "presto_page serde is expected by remote function server but got : '{}'",
        serdeName);
  }
  return PageFormat::PRESTO_PAGE;
}
#endif

TypedExprPtr VeloxExprConverter::toVeloxExpr(
const protocol::CallExpression& pexpr) const {
if (auto builtin = std::dynamic_pointer_cast<protocol::BuiltInFunctionHandle>(
Expand Down Expand Up @@ -458,10 +479,68 @@ TypedExprPtr VeloxExprConverter::toVeloxExpr(
pexpr.functionHandle)) {
auto args = toVeloxExpr(pexpr.arguments);
auto returnType = typeParser_->parse(pexpr.returnType);

return std::make_shared<CallTypedExpr>(
returnType, args, getFunctionName(sqlFunctionHandle->functionId));
}
#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS
else if (
auto restFunctionHandle =
std::dynamic_pointer_cast<protocol::RestFunctionHandle>(
pexpr.functionHandle)) {

auto args = toVeloxExpr(pexpr.arguments);
auto returnType = typeParser_->parse(pexpr.returnType);

const auto* systemConfig = SystemConfig::instance();

velox::functions::RemoteVectorFunctionMetadata metadata;
metadata.serdeFormat =
fromSerdeString(systemConfig->remoteFunctionServerSerde());
metadata.location = systemConfig->remoteFunctionRestUrl();
metadata.functionId = restFunctionHandle->functionId;
metadata.version = restFunctionHandle->version;

const auto& prestoSignature = restFunctionHandle->signature;
// parseTypeSignature
velox::exec::FunctionSignatureBuilder signatureBuilder;
// Handle type variable constraints
for (const auto& typeVar : prestoSignature.typeVariableConstraints) {
signatureBuilder.typeVariable(typeVar.name);
}

// Handle long variable constraints (for integer variables)
for (const auto& longVar : prestoSignature.longVariableConstraints) {
signatureBuilder.integerVariable(longVar.name);
}

// Handle return type
signatureBuilder.returnType(prestoSignature.returnType);

// Handle argument types
for (const auto& argType : prestoSignature.argumentTypes) {
signatureBuilder.argumentType(argType);
}

// Handle variable arity
if (prestoSignature.variableArity) {
signatureBuilder.variableArity();
}

auto signature = signatureBuilder.build();
std::vector<velox::exec::FunctionSignaturePtr> veloxSignatures = {
signature};

velox::functions::registerRemoteFunction(
getFunctionName(restFunctionHandle->functionId),
veloxSignatures,
metadata,
false);

return std::make_shared<CallTypedExpr>(
returnType, args, getFunctionName(restFunctionHandle->functionId));
}
#endif
VELOX_FAIL("Unsupported function handle: {}", pexpr.functionHandle->_type);
}

Expand Down
10 changes: 10 additions & 0 deletions presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,16 @@ target_link_libraries(
${GFLAGS_LIBRARIES}
pthread)

# When remote function support is enabled, presto_expressions_test also depends
# on and links against the remote function targets, plus GTest::gmock and
# GTest::gmock_main.
if(PRESTO_ENABLE_REMOTE_FUNCTIONS)
add_dependencies(presto_expressions_test presto_server_remote_function
velox_expression velox_functions_remote)

target_link_libraries(
presto_expressions_test GTest::gmock GTest::gmock_main
presto_server_remote_function velox_expression velox_functions_remote)

endif()

set_property(TARGET presto_expressions_test PROPERTY JOB_POOL_LINK
presto_link_job_pool)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ void to_json(json& j, const std::shared_ptr<FunctionHandle>& p) {
j = *std::static_pointer_cast<SqlFunctionHandle>(p);
return;
}
if (type == "rest") {
j = *std::static_pointer_cast<RestFunctionHandle>(p);
return;
}

throw TypeError(type + " no abstract type FunctionHandle ");
}
Expand All @@ -138,6 +142,13 @@ void from_json(const json& j, std::shared_ptr<FunctionHandle>& p) {
p = std::static_pointer_cast<FunctionHandle>(k);
return;
}
if (type == "rest") {
std::shared_ptr<RestFunctionHandle> k =
std::make_shared<RestFunctionHandle>();
j.get_to(*k);
p = std::static_pointer_cast<FunctionHandle>(k);
return;
}

throw TypeError(type + " no abstract type FunctionHandle ");
}
Expand Down Expand Up @@ -5849,6 +5860,20 @@ void to_json(json& j, const JsonBasedUdfFunctionMetadata& p) {
"JsonBasedUdfFunctionMetadata",
"AggregationFunctionMetadata",
"aggregateMetadata");
to_json_key(
j,
"functionId",
p.functionId,
"JsonBasedUdfFunctionMetadata",
"SqlFunctionId",
"functionId");
to_json_key(
j,
"version",
p.version,
"JsonBasedUdfFunctionMetadata",
"String",
"version");
}

void from_json(const json& j, JsonBasedUdfFunctionMetadata& p) {
Expand Down Expand Up @@ -5901,6 +5926,20 @@ void from_json(const json& j, JsonBasedUdfFunctionMetadata& p) {
"JsonBasedUdfFunctionMetadata",
"AggregationFunctionMetadata",
"aggregateMetadata");
from_json_key(
j,
"functionId",
p.functionId,
"JsonBasedUdfFunctionMetadata",
"SqlFunctionId",
"functionId");
from_json_key(
j,
"version",
p.version,
"JsonBasedUdfFunctionMetadata",
"String",
"version");
}
} // namespace facebook::presto::protocol
/*
Expand Down Expand Up @@ -8156,6 +8195,52 @@ void from_json(const json& j, RemoteTransactionHandle& p) {
}
} // namespace facebook::presto::protocol
namespace facebook::presto::protocol {
// Default constructor: sets _type to "rest", the same string written to and
// read from the "@type" JSON field for this handle.
RestFunctionHandle::RestFunctionHandle() noexcept {
_type = "rest";
}

// Serializes a RestFunctionHandle into a JSON object: writes "@type" = "rest"
// followed by the functionId, version and signature fields.
void to_json(json& j, const RestFunctionHandle& p) {
  // Class name passed to to_json_key alongside each field.
  static constexpr const char* kClassName = "RestFunctionHandle";

  j = json::object();
  j["@type"] = "rest";

  to_json_key(
      j, "functionId", p.functionId, kClassName, "SqlFunctionId", "functionId");
  to_json_key(j, "version", p.version, kClassName, "String", "version");
  to_json_key(
      j, "signature", p.signature, kClassName, "Signature", "signature");
}

// Populates a RestFunctionHandle from JSON: copies "@type" into _type, then
// reads the functionId, version and signature fields.
void from_json(const json& j, RestFunctionHandle& p) {
  // Class name passed to from_json_key alongside each field.
  static constexpr const char* kClassName = "RestFunctionHandle";

  p._type = j["@type"];

  from_json_key(
      j, "functionId", p.functionId, kClassName, "SqlFunctionId", "functionId");
  from_json_key(j, "version", p.version, kClassName, "String", "version");
  from_json_key(
      j, "signature", p.signature, kClassName, "Signature", "signature");
}
} // namespace facebook::presto::protocol
namespace facebook::presto::protocol {
// Default constructor: sets _type to the fully qualified Java class name
// string for RowNumberNode.
RowNumberNode::RowNumberNode() noexcept {
_type = "com.facebook.presto.sql.planner.plan.RowNumberNode";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,21 +67,21 @@ extern const char* const PRESTO_ABORT_TASK_URL_PARAM;
/// Base exception type for this header's error hierarchy; carries a message
/// that is forwarded to std::runtime_error.
class Exception : public std::runtime_error {
 public:
  /// @param message Text returned by what().
  explicit Exception(const std::string& message)
      : std::runtime_error(message) {}
};

/// Exception subtype; thrown by the FunctionHandle to_json/from_json
/// dispatchers when the "@type" value matches no known subclass.
class TypeError : public Exception {
 public:
  /// @param message Text returned by what().
  explicit TypeError(const std::string& message) : Exception(message) {}
};

/// Exception subtype named for out-of-range errors.
class OutOfRange : public Exception {
 public:
  /// @param message Text returned by what().
  explicit OutOfRange(const std::string& message) : Exception(message) {}
};

/// Exception subtype named for parse errors.
class ParseError : public Exception {
 public:
  /// @param message Text returned by what().
  explicit ParseError(const std::string& message) : Exception(message) {}
};

using String = std::string;
Expand Down Expand Up @@ -1508,6 +1508,8 @@ struct JsonBasedUdfFunctionMetadata {
String schema = {};
RoutineCharacteristics routineCharacteristics = {};
std::shared_ptr<AggregationFunctionMetadata> aggregateMetadata = {};
std::shared_ptr<SqlFunctionId> functionId = {};
std::shared_ptr<String> version = {};
};
void to_json(json& j, const JsonBasedUdfFunctionMetadata& p);
void from_json(const json& j, JsonBasedUdfFunctionMetadata& p);
Expand Down Expand Up @@ -1922,6 +1924,17 @@ void to_json(json& j, const RemoteTransactionHandle& p);
void from_json(const json& j, RemoteTransactionHandle& p);
} // namespace facebook::presto::protocol
namespace facebook::presto::protocol {
// FunctionHandle subtype for REST-based remote functions; it is serialized
// with "@type" = "rest" by the matching to_json/from_json overloads.
struct RestFunctionHandle : public FunctionHandle {
// Identifier of the function; (de)serialized under the "functionId" key.
SqlFunctionId functionId = {};
// Function version string; (de)serialized under the "version" key.
String version = {};
// Signature of the function; (de)serialized under the "signature" key.
Signature signature = {};

// Sets _type to "rest".
RestFunctionHandle() noexcept;
};
void to_json(json& j, const RestFunctionHandle& p);
void from_json(const json& j, RestFunctionHandle& p);
} // namespace facebook::presto::protocol
namespace facebook::presto::protocol {
struct RowNumberNode : public PlanNode {
std::shared_ptr<PlanNode> source = {};
List<VariableReferenceExpression> partitionBy = {};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ AbstractClasses:
subclasses:
- { name: BuiltInFunctionHandle, key: $static }
- { name: SqlFunctionHandle, key: json_file }

- { name: RestFunctionHandle, key: rest }

JavaClasses:
- presto-spi/src/main/java/com/facebook/presto/spi/ErrorCause.java
Expand All @@ -191,6 +191,7 @@ JavaClasses:
- presto-main/src/main/java/com/facebook/presto/execution/buffer/BufferState.java
- presto-main/src/main/java/com/facebook/presto/metadata/BuiltInFunctionHandle.java
- presto-spi/src/main/java/com/facebook/presto/spi/function/SqlFunctionHandle.java
- presto-spi/src/main/java/com/facebook/presto/spi/function/RestFunctionHandle.java
- presto-hdfs-core/src/main/java/com/facebook/presto/hive/CacheQuotaRequirement.java
- presto-hdfs-core/src/main/java/com/facebook/presto/hive/CacheQuotaScope.java
- presto-spi/src/main/java/com/facebook/presto/spi/relation/CallExpression.java
Expand Down
Loading

0 comments on commit beccbc6

Please sign in to comment.