Skip to content

Commit

Permalink
customize to build component
Browse files Browse the repository at this point in the history
  • Loading branch information
phoenix20162016 committed Jun 19, 2024
1 parent 3d8c4e4 commit 3df7c93
Show file tree
Hide file tree
Showing 8 changed files with 197 additions and 44 deletions.
12 changes: 12 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,18 @@ ifneq ($(disable_py_task), y)
BUILD_FLAG += --define enable_py_task=true
endif

ifneq ($(disable_mpc_task), y)
BUILD_FLAG += --define enable_mpc_task=true
endif

ifneq ($(disable_pir_task), y)
BUILD_FLAG += --define enable_pir_task=true
endif

ifneq ($(disable_psi_task), y)
BUILD_FLAG += --define enable_psi_task=true
endif

ifeq ($(mysql), y)
BUILD_FLAG += --define enable_mysql_driver=true
endif
Expand Down
10 changes: 9 additions & 1 deletion build_local.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,15 @@ build_opt="mysql=y"
if [[ "$PRIMIHUB_MODE" == "FULL" ]]; then
build_opt="${build_opt}"
else
build_opt="${build_opt} disable_py_task=y"
build_opt="${build_opt} \
disable_py_task=y \
"
# build_opt="${build_opt} \
# disable_py_task=y \
# disable_mpc_task=y \
# disable_pir_task=y \
# disable_psi_task=y \
# "
fi

make $build_opt
Expand Down
7 changes: 5 additions & 2 deletions src/primihub/node/worker/worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,15 @@ retcode Worker::ExecuteTaskByThread(const PushTaskRequest* task_request) {
auto& executor = nodelet->GetTeeExecutor();
auto ra_service_ptr = reinterpret_cast<void*>(ra_service.get());
auto tee_executor_ptr = reinterpret_cast<void*>(executor.get());
task_ptr = TaskFactory::Create(this->node_id,
auto result = TaskFactory::Create(this->node_id,
*task_request, dataset_service, ra_service_ptr, tee_executor_ptr);
#else
task_ptr = TaskFactory::Create(this->node_id, *task_request,
auto result = TaskFactory::Create(this->node_id, *task_request,
dataset_service);

#endif
task_ptr = std::move(result.first);
std::string info = std::move(result.second);
if (task_ptr == nullptr) {
LOG(ERROR) << TASK_INFO_STR << "Woker create task failed.";
task_ready_promise_.set_value(false);
Expand Down
56 changes: 44 additions & 12 deletions src/primihub/task/semantic/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,55 @@ config_setting(
name = "py_task_enabled",
values = {"define" : "enable_py_task=true"},
)
config_setting(
name = "mpc_task_enabled",
values = {"define" : "enable_mpc_task=true"},
)
config_setting(
name = "pir_task_enabled",
values = {"define" : "enable_pir_task=true"},
)
config_setting(
name = "psi_task_enabled",
values = {"define" : "enable_psi_task=true"},
)

TASK_DEFINES_OPTS = select({
"py_task_enabled": ["PY_TASK_ENABLED"],
"//conditions:default": []
}) + select({
"mpc_task_enabled": ["MPC_TASK_ENABLED"],
"//conditions:default": []
}) + select({
"pir_task_enabled": ["PIR_TASK_ENABLED"],
"//conditions:default": []
}) + select({
"psi_task_enabled": ["PSI_TASK_ENABLED"],
"//conditions:default": []
})

TASK_DEPS_OPTS = select({
"py_task_enabled": [":fl_task"],
"//conditions:default": []
}) + select({
"mpc_task_enabled": [":mpc_task"],
"//conditions:default": []
}) + select({
"pir_task_enabled": [":pir_task"],
"//conditions:default": []
}) + select({
"psi_task_enabled": [":psi_task"],
"//conditions:default": []
})

#factory for create task
cc_library(
name = "task_factory",
hdrs = ["factory.h"],
defines = select({
"py_task_enabled": ["PY_TASK_ENABLED"],
"//conditions:default": []
}),
deps = [
":mpc_task",
":pir_task",
":psi_task",
] + select({
"py_task_enabled": [":fl_task"],
"//conditions:default": []
}),
defines = TASK_DEFINES_OPTS,
deps = TASK_DEPS_OPTS + [
":task_interface",
],
)

cc_library(
Expand Down
55 changes: 44 additions & 11 deletions src/primihub/task/semantic/factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,21 @@
#include <memory>

#include "src/primihub/task/semantic/task.h"
#ifdef MPC_TASK_ENABLED
#include "src/primihub/task/semantic/mpc_task.h"
#endif // MPC_TASK_ENABLED

#ifdef PY_TASK_ENABLED
#include "src/primihub/task/semantic/fl_task.h"
#endif // PY_TASK_ENABLED

#ifdef PSI_TASK_ENABLED
#include "src/primihub/task/semantic/psi_task.h"
#endif // PSI_TASK_ENABLED

#ifdef PIR_TASK_ENABLED
#include "src/primihub/task/semantic/pir_task.h"
#endif // PIR_TASK_ENABLED
#include "src/primihub/service/dataset/service.h"
#include "src/primihub/util/log.h"
#include "src/primihub/util/proto_log_helper.h"
Expand All @@ -40,7 +49,8 @@ namespace primihub::task {

class TaskFactory {
public:
static std::shared_ptr<TaskBase> Create(const std::string& node_id,
using RetType = std::pair<std::shared_ptr<TaskBase>, std::string>;
static RetType Create(const std::string& node_id,
const PushTaskRequest& request,
std::shared_ptr<DatasetService> dataset_service,
void* ra_service = nullptr,
Expand All @@ -49,38 +59,54 @@ class TaskFactory {
const auto& task_info = request.task().task_info();
std::string task_inof_str = pb_util::TaskInfoToString(task_info);
auto task_type = request.task().type();
std::shared_ptr<TaskBase> task_ptr{nullptr};
RetType result_;
switch (task_language) {
#ifdef PY_TASK_ENABLED
case Language::PYTHON:
task_ptr = TaskFactory::CreateFLTask(node_id, request, dataset_service);
result_ = std::make_pair(
TaskFactory::CreateFLTask(node_id, request, dataset_service), "SUCCESS");
break;
#endif
case Language::PROTO: {
switch (task_type) {
#ifdef MPC_TASK_ENABLED
case rpc::TaskType::ACTOR_TASK:
task_ptr = TaskFactory::CreateMPCTask(node_id, request, dataset_service);
result_ = std::make_pair(
TaskFactory::CreateMPCTask(node_id, request, dataset_service), "SUCCESS");
break;
#endif // MPC_TASK_ENABLED
#ifdef PSI_TASK_ENABLED
case rpc::TaskType::PSI_TASK:
task_ptr =
result_ = std::make_pair(
TaskFactory::CreatePSITask(node_id, request, dataset_service,
ra_service, executor);
ra_service, executor), "SUCCESS");
break;
#endif // PSI_TASK_ENABLED
#ifdef PIR_TASK_ENABLED
case rpc::TaskType::PIR_TASK:
task_ptr = TaskFactory::CreatePIRTask(node_id, request, dataset_service);
result_ = std::make_pair(
TaskFactory::CreatePIRTask(node_id, request, dataset_service), "SUCCESS");;
break;
default:
LOG(ERROR) << task_inof_str << "unsupported task type: " << task_type;
#endif // PIR_TASK_ENABLED
default: {
std::string err_msg = "Unsupported task type: " + rpc::TaskType_Name(task_type);
LOG(ERROR) << task_inof_str << err_msg;
result_ = std::make_pair(nullptr, err_msg);
break;
}
}
break;
}
default:
default: {
std::string err_msg = "Unsupported task type: " + rpc::Language_Name(task_language);
LOG(ERROR) << task_inof_str << "unsupported language: " << task_language;
result_ = std::make_pair(nullptr, err_msg);
break;
}
return task_ptr;
}
return result_;
}

#ifdef PY_TASK_ENABLED
static std::shared_ptr<TaskBase> CreateFLTask(const std::string& node_id,
const PushTaskRequest& request,
Expand All @@ -90,6 +116,8 @@ class TaskFactory {
request, dataset_service);
}
#endif // PY_TASK_ENABLED

#ifdef MPC_TASK_ENABLED
static std::shared_ptr<TaskBase> CreateMPCTask(const std::string& node_id,
const PushTaskRequest& request,
std::shared_ptr<DatasetService> dataset_service) {
Expand All @@ -98,7 +126,9 @@ class TaskFactory {
return std::make_shared<MPCTask>(node_id, _function_name,
&task_param, dataset_service);
}
#endif // MPC_TASK_ENABLED

#ifdef PSI_TASK_ENABLED
static std::shared_ptr<TaskBase> CreatePSITask(const std::string& node_id,
const PushTaskRequest& request,
std::shared_ptr<DatasetService> dataset_service,
Expand All @@ -110,13 +140,16 @@ class TaskFactory {
ra_server, tee_engine);
return task_ptr;
}
#endif // PSI_TASK_ENABLED

#ifdef PIR_TASK_ENABLED
static std::shared_ptr<TaskBase> CreatePIRTask(const std::string& node_id,
const PushTaskRequest& request,
std::shared_ptr<DatasetService> dataset_service) {
const auto& task_config = request.task();
return std::make_shared<PirTask>(&task_config, dataset_service);
}
#endif // PIR_TASK_ENABLED

};
} // namespace primihub::task
Expand Down
60 changes: 52 additions & 8 deletions src/primihub/task/semantic/scheduler/BUILD
Original file line number Diff line number Diff line change
@@ -1,8 +1,59 @@
package(default_visibility = ["//visibility:public"])
config_setting(
name = "py_task_enabled",
values = {"define" : "enable_py_task=true"},
)
config_setting(
name = "mpc_task_enabled",
values = {"define" : "enable_mpc_task=true"},
)
config_setting(
name = "pir_task_enabled",
values = {"define" : "enable_pir_task=true"},
)
config_setting(
name = "psi_task_enabled",
values = {"define" : "enable_psi_task=true"},
)

SCHEDULER_DEFINES_OPTS = [
] + select({
"py_task_enabled": ["PY_TASK_ENABLED"],
"//conditions:default": []
}) + select({
"mpc_task_enabled": ["MPC_TASK_ENABLED"],
"//conditions:default": []
}) + select({
"pir_task_enabled": ["PIR_TASK_ENABLED"],
"//conditions:default": []
}) + select({
"psi_task_enabled": ["PSI_TASK_ENABLED"],
"//conditions:default": []
})

SCHEDULER_DEPS_OPTS = [
":tee_scheduler",
] + select({
"py_task_enabled": [":fl_scheduler"],
"//conditions:default": []
}) + select({
"mpc_task_enabled": [
":mpc_scheduler",
":aby3_scheduler",
],
"//conditions:default": []
}) + select({
"pir_task_enabled": [":pir_scheduler"],
"//conditions:default": []
}) + select({
"psi_task_enabled": [":scheduler_interface"],
"//conditions:default": []
})

cc_library(
name = "scheduler_factory",
srcs = ["factory.h"],
defines = SCHEDULER_DEFINES_OPTS,
deps = [
":scheduler_lib",
],
Expand All @@ -28,14 +79,7 @@ cc_library(

cc_library(
name = "scheduler_lib",
deps = [
":scheduler_interface",
":mpc_scheduler",
":aby3_scheduler",
":fl_scheduler",
":tee_scheduler",
":pir_scheduler",
],
deps = SCHEDULER_DEPS_OPTS,
)

SCHEDULER_DEPS = [
Expand Down
Loading

0 comments on commit 3df7c93

Please sign in to comment.