Skip to content

Commit

Permalink
PSI Python wrapper (primihub#726)
Browse files Browse the repository at this point in the history
  • Loading branch information
phoenix20162016 authored Dec 7, 2023
1 parent 6f782f5 commit 06b253e
Show file tree
Hide file tree
Showing 17 changed files with 373 additions and 28 deletions.
3 changes: 2 additions & 1 deletion Dockerfile.local
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ WORKDIR /app

ADD bazel-bin.tar.gz ./
COPY src/primihub/protos/ src/primihub/protos/

RUN ln -s -f bazel-bin/cli primihub-cli
RUN ln -s -f bazel-bin/node primihub-node
RUN mkdir log \
&& cd python \
&& python3 setup.py develop
Expand Down
3 changes: 2 additions & 1 deletion Dockerfile.release
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ COPY src/primihub/protos/ src/primihub/protos/

RUN tar zxf /opt/primihub-linux-$(dpkg --print-architecture).tar.gz \
&& mkdir log

RUN ln -s -f bazel-bin/cli primihub-cli
RUN ln -s -f bazel-bin/node primihub-node
WORKDIR /app/python

RUN python3 -m pip install --upgrade pip \
Expand Down
2 changes: 1 addition & 1 deletion WORKSPACE_CN
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ http_archive(

git_repository(
name = "com_github_bazel_rules_3rdparty",
commit = "537bac41153d21467d31f0a22eb67f02d5104e23",
commit = "67ac7942969ee2224732a9b50b17180390c4fc97",
remote = "https://gitee.com/primihub/bazel-rules-thirdparty.git",
)

Expand Down
2 changes: 1 addition & 1 deletion WORKSPACE_GITHUB
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository", "new_git_r

git_repository(
name = "com_github_bazel_rules_3rdparty",
commit = "537bac41153d21467d31f0a22eb67f02d5104e23",
commit = "67ac7942969ee2224732a9b50b17180390c4fc97",
remote = "https://github.com/primihub/bazel-rules-thirdparty.git",
)

Expand Down
5 changes: 2 additions & 3 deletions build_local.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@ fi
git rev-parse --abbrev-ref HEAD >> commit.txt
git rev-parse HEAD >> commit.txt

tar zcf bazel-bin.tar.gz bazel-bin/cli \
tar zcfh bazel-bin.tar.gz bazel-bin/cli \
bazel-bin/node \
primihub-cli \
primihub-node \
bazel-bin/_solib* \
bazel-bin/task_main \
bazel-bin/src/primihub/pybind_warpper/opt_paillier_c2py.so \
bazel-bin/src/primihub/pybind_warpper/linkcontext.so \
Expand Down
32 changes: 32 additions & 0 deletions example/code/psi_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from primihub.utils.logger_util import logger
from primihub.MPC.psi import TwoPartyPsi, PsiType, DataType
import random
import time

class Example:
def run(self):
psi_executor = TwoPartyPsi()
data_len = 10
#input_data = [str(i) for i in range(data_len)]
input_data = [i for i in range(data_len)]
# input_data = [f"中国{i}" for i in range(data_len)]
input_data_str = [str(i) for i in range(data_len)]
for i in range(1):
result = psi_executor.run(input_data,
["Alice", "Bob"], "Alice", True,
PsiType.ECDH, data_type = DataType.Interger)
logger.info(f"xxxxxxxxxxx loop: {i} psi int result: {result}")
result = psi_executor.run(input_data,
["Alice", "Bob"], "Alice", True,
PsiType.KKRT, data_type = DataType.Interger)
logger.info(f"xxxxxxxxxxx loop: {i} psi int result: {result}")
result = psi_executor.run(input_data_str,
["Alice", "Bob"], "Alice", True, PsiType.ECDH)
logger.info(f"xxxxxxxxxxx loop: {i} psi string result: {result}")
result = psi_executor.run(input_data_str,
["Alice", "Bob"], "Alice", True, PsiType.KKRT)
logger.info(f"xxxxxxxxxxx loop: {i} psi string result: {result}")


example = Example()
example.run()
33 changes: 33 additions & 0 deletions example/python_code_psi.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
{
"task_name": "psi_task",
"task_lang": "python",
"task_code": {
"code_file_path": "example/code/psi_example.py",
"code": ""
},
"params": {
"auxiliary_server": {
"type": "OBJECT",
"value": {
"is_dataset": true,
"dataset_id": "FAKE_DATA_PARTY_2"
}
}
},
"component_params": {
"roles": {
"executor": ["Alice", "Bob"]
},
"common_params": {
"model": "PythonEngine"
},
"role_params": {
"Alice": {
"data_set": "FAKE_DATA_PARTY_0"
},
"Bob": {
"data_set": "FAKE_DATA_PARTY_1"
}
}
}
}
50 changes: 50 additions & 0 deletions python/primihub/MPC/psi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""
* Copyright (c) 2023 by PrimiHub
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
"""
import ph_secure_lib as ph_slib
from primihub.context import Context
from enum import Enum
import numpy as np
from primihub.utils.logger_util import logger

class PsiType(Enum):
KKRT = "KKRT"
ECDH = "ECDH"

class DataType(Enum):
Interger = 0
String = 1

class TwoPartyPsi:
def __init__(self):
self.psi_executor = ph_slib.PSIExecutor(Context.message)

def run(self,
input: np.ndarray,
parties: list,
receiver: str,
broadcast: bool,
protocol: PsiType = PsiType.KKRT,
data_type: DataType = DataType.String):
if len(input) == 0:
return list()
if data_type == DataType.Interger:
logger.info(f"dtype is int")
return self.psi_executor.run_as_integer(
input, parties, receiver, broadcast, protocol.value)
elif data_type == DataType.String:
logger.info(f"dtype is str")
return self.psi_executor.run_as_string(
input, parties, receiver, broadcast, protocol.value)
11 changes: 11 additions & 0 deletions src/primihub/task/pybind_wrapper/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pybind_extension(
],
deps = [
":mpc_task_wrapper",
":psi_task_wrapper",
"//:python3_lib",
"//src/primihub/util:util_lib",
],
Expand All @@ -29,6 +30,16 @@ cc_library(
],
)

cc_library(
name = "psi_task_wrapper",
srcs = ["psi_wrapper.cc"],
hdrs = ["psi_wrapper.h"],
deps = [
"//src/primihub/task/semantic:psi_task",
"@com_github_stduuid//:stduuid_lib",
],
)

cc_library(
name = "util_lib",
hdrs = ["util.h"],
Expand Down
103 changes: 94 additions & 9 deletions src/primihub/task/pybind_wrapper/psi_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,113 @@
#include <glog/logging.h>

#include "src/primihub/common/common.h"
#include "src/primihub/common/value_check_util.h"
#include "src/primihub/protos/common.pb.h"
#include "src/primihub/kernel/psi/util.h"

namespace primihub::task {
PsiExecutor::PsiExecutor(const std::string& task_req_str,
const std::string& protocol) {
//
task_req_ptr_ = std::make_unique<primihub::rpc::PushTaskRequest>();
namespace rpc = primihub::rpc;
PsiExecutor::PsiExecutor(const std::string& task_req_str) {
task_req_ptr_ = std::make_unique<rpc::PushTaskRequest>();
bool succ_flag = task_req_ptr_->ParseFromString(task_req_str);
if (!succ_flag) {
throw std::runtime_error("parser task request encountes error");
RaiseException("parser task request encountes error");
}
auto& task_config = task_req_ptr_->task();

task_ptr_ = std::make_unique<PsiTask>(&task_config);
}

auto PsiExecutor::RunPsi(const std::vector<std::string>& input) ->
auto PsiExecutor::RunPsi(const std::vector<std::string>& input,
const std::vector<std::string>& parties,
const std::string& receiver,
bool broadcast_result,
const std::string& protocol) ->
std::vector<std::string> {
if (parties.size() != 2) {
std::stringstream ss;
ss << "Need 2 parties, but get " << parties.size();
RaiseException(ss.str());
}
std::vector<std::string> result;
auto ret = task_ptr_->ExecuteTask(input, &result);
auto req_ptr = BuildPsiTaskRequest(parties, receiver,
broadcast_result, protocol, *task_req_ptr_);
auto task_config_ptr = req_ptr->mutable_task();
auto task_ptr = std::make_unique<PsiTask>(task_config_ptr);
auto ret = task_ptr->ExecuteTask(input, &result);
if (ret != retcode::SUCCESS) {
LOG(ERROR) << "run psi task failed";
throw std::runtime_error("run psi task failed");
RaiseException("run psi task failed");
}
LOG(INFO) << "PsiExecutor::RunPsi: " << result.size();
return result;
}
// to support multithread mode in future, first Negotiate a unique key
retcode PsiExecutor::NegotiateSubTaskId(std::string* sub_task_id, Role role) {
return retcode::SUCCESS;
}
auto PsiExecutor::BuildPsiTaskRequest(const std::vector<std::string>& parties,
const std::string& receiver,
bool broadcast_result,
const std::string& protocol,
const primihub::rpc::PushTaskRequest& request) ->
std::unique_ptr<primihub::rpc::PushTaskRequest> {
auto req_ptr = std::make_unique<primihub::rpc::PushTaskRequest>(request);
auto task_config_ptr = req_ptr->mutable_task();
// build access info
const auto& party_access_info = request.task().party_access_info();
auto it = party_access_info.find(receiver);
if (it == party_access_info.end()) {
std::stringstream ss;
ss << "No access info found for result receiver: " << receiver;
RaiseException(ss.str());
}
task_config_ptr->clear_party_access_info();
auto ptr = task_config_ptr->mutable_party_access_info();
(*ptr)[PARTY_CLIENT].CopyFrom(it->second);
for (const auto& party_name : parties) {
if (party_name == receiver) {
continue;
}
auto it = party_access_info.find(party_name);
if (it == party_access_info.end()) {
std::stringstream ss;
ss << "No access info found for party: " << party_name;
RaiseException(ss.str());
}
(*ptr)[PARTY_SERVER].CopyFrom(it->second);
}
// set party name
auto& party_name = task_config_ptr->party_name();
if (party_name == receiver) {
task_config_ptr->set_party_name(PARTY_CLIENT);
} else {
task_config_ptr->set_party_name(PARTY_SERVER);
}
// broadcast result
auto param_map = task_config_ptr->mutable_params()->mutable_param_map();
rpc::ParamValue pv;
pv.set_is_array(false);
pv.set_var_type(rpc::INT32);
if (broadcast_result) {
pv.set_value_int32(1);
} else {
pv.set_value_int32(0);
}
(*param_map)["sync_result_to_server"] = std::move(pv);
// configure psi protocol
rpc::ParamValue psi_type;
psi_type.set_is_array(false);
psi_type.set_var_type(rpc::INT32);
if (protocol == std::string("KKRT")) {
psi_type.set_value_int32(static_cast<int>(rpc::KKRT));
} else if (protocol == std::string("ECDH")) {
psi_type.set_value_int32(static_cast<int>(rpc::ECDH));
} else {
std::stringstream ss;
ss << "Unknown PSI protocol: " << protocol;
RaiseException(ss.str());
}
(*param_map)["psiTag"] = std::move(psi_type);
return req_ptr;
}
} // namespace primihub::task
21 changes: 18 additions & 3 deletions src/primihub/task/pybind_wrapper/psi_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,24 @@
namespace primihub::task {
class PsiExecutor {
public:
explicit PsiExecutor(const std::string& task_req, const std::string& protocol);
std::vector<std::string> RunPsi(const std::vector<std::string>& input);

explicit PsiExecutor(const std::string& task_req);
std::vector<std::string> RunPsi(const std::vector<std::string>& input,
const std::vector<std::string>& parties,
const std::string& receiver,
bool broadcast_result,
const std::string& protocol);
protected:
enum class Role : uint8_t {
kClient,
kServer,
};
retcode NegotiateSubTaskId(std::string* sub_task_id, Role role);
auto BuildPsiTaskRequest(const std::vector<std::string>& parties,
const std::string& receiver,
bool broadcast_result,
const std::string& protocol,
const primihub::rpc::PushTaskRequest& request) ->
std::unique_ptr<primihub::rpc::PushTaskRequest>;
private:
std::unique_ptr<PsiTask> task_ptr_{nullptr};
std::unique_ptr<primihub::rpc::PushTaskRequest> task_req_ptr_{nullptr};
Expand Down
Loading

0 comments on commit 06b253e

Please sign in to comment.