Skip to content

Commit

Permalink
Pir batch query (primihub#775)
Browse files Browse the repository at this point in the history
*number of queries must be less than table size
  • Loading branch information
phoenix20162016 authored May 13, 2024
1 parent 36cf363 commit 3976c17
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 55 deletions.
2 changes: 1 addition & 1 deletion WORKSPACE_CN
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ http_archive(
git_repository(
name = "mircrosoft_apsi",
#branch = "bazel_version",
commit = "44243c1a85435c04ca858279757ca5524dd3c9aa",
commit = "0b192b2df72a136900b8c06fdd0caff8b1d4ce8d",
remote = "https://gitee.com/primihub/APSI.git",
)

Expand Down
2 changes: 1 addition & 1 deletion WORKSPACE_GITHUB
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ http_archive(
git_repository(
name = "mircrosoft_apsi",
#branch = "bazel_version",
commit = "44243c1a85435c04ca858279757ca5524dd3c9aa",
commit = "0b192b2df72a136900b8c06fdd0caff8b1d4ce8d",
remote = "https://github.com/primihub/APSI.git",
)

Expand Down
2 changes: 1 addition & 1 deletion config/pir_server_config.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"table_params": {
"hash_func_count": 2,
"hash_func_count": 5,
"table_size": 409,
"max_items_per_bin": 20
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,68 +31,113 @@ using PSIParams = apsi::PSIParams;

retcode KeywordPirOperatorClient::OnExecute(const PirDataType& input,
PirDataType* result) {
auto link_ctx = this->GetLinkContext();
VLOG(5) << "begin to request psi params";
auto ret = RequestPSIParams();
CHECK_RETCODE_WITH_RETVALUE(ret, retcode::FAIL);
std::vector<std::string> orig_item;

std::vector<std::string> orig_item_total;
std::vector<apsi::Item> items_vec;
orig_item.reserve(input.size());
orig_item_total.reserve(input.size());
items_vec.reserve(input.size());
for (const auto& [key, val_vec] : input) {
apsi::Item item = key;
items_vec.emplace_back(std::move(item));
orig_item.emplace_back(key);
orig_item_total.emplace_back(key);
}
std::vector<HashedItem> oprf_items;
std::vector<LabelKey> label_keys;
std::vector<HashedItem> oprf_items_total;
std::vector<LabelKey> label_keys_total;
VLOG(5) << "begin to Receiver::RequestOPRF";
ret = RequestOprf(items_vec, &oprf_items, &label_keys);
ret = RequestOprf(items_vec, &oprf_items_total, &label_keys_total);
CHECK_RETCODE_WITH_RETVALUE(ret, retcode::FAIL);

CHECK_TASK_STOPPED(retcode::FAIL);
VLOG(5) << "Receiver::RequestOPRF end, begin to receiver.request_query";
// request query

this->receiver_ = std::make_unique<Receiver>(*psi_params_);
std::vector<MatchRecord> query_result;
auto query = this->receiver_->create_query(oprf_items);
// chl.send(move(query.first));
auto request_query_data = std::move(query.first);
std::ostringstream string_ss;
request_query_data->save(string_ss);
std::string query_data_str = string_ss.str();
auto itt = move(query.second);
VLOG(5) << "query_data_str size: " << query_data_str.size();
auto link_ctx = this->GetLinkContext();
ret = link_ctx->Send(this->key_, PeerNode(), query_data_str);
CHECK_RETCODE_WITH_RETVALUE(ret, retcode::FAIL);
size_t query_size = input.size();
auto table_size = this->psi_params_->table_params().table_size;
table_size = static_cast<size_t>(table_size * PirConstant::table_size_factor);
size_t block_num = query_size / table_size;
size_t rem_size = query_size % table_size;
std::vector<int64_t> block_item_info;
for (size_t i = 0; i < block_num; i++) {
block_item_info.push_back(table_size);
}
if (rem_size != 0) {
block_item_info.push_back(rem_size);
}
if (VLOG_IS_ON(5)) {
std::string block_item_info_str;
for (const auto& item : block_item_info) {
block_item_info_str.append(std::to_string(item)).append(" ");
}
LOG(INFO) << "block_item_info: " << block_item_info_str;
}

// receive package count
uint32_t package_count = 0;
std::string pkg_count_key = this->PackageCountKey(link_ctx->request_id());
ret = link_ctx->Recv(pkg_count_key,
this->PeerNode(),
reinterpret_cast<char*>(&package_count),
sizeof(package_count));
CHECK_RETCODE_WITH_RETVALUE(ret, retcode::FAIL);
int64_t start_index{0};
for (size_t i = 0; i < block_item_info.size(); i++) {
int64_t size_per_query = block_item_info[i];
if (i > 0) {
start_index += block_item_info[i-1];
}
LOG(INFO) << "start batch group: " << i << " "
<< "start index: " << start_index << " "
<< "group size: " << size_per_query;
std::vector<std::string> orig_item;
std::vector<HashedItem> oprf_items;
std::vector<LabelKey> label_keys;
orig_item.reserve(size_per_query);
oprf_items.reserve(size_per_query);
label_keys.reserve(size_per_query);
for (size_t j = 0; j < size_per_query; j++) {
size_t index = start_index + j;
orig_item.push_back(orig_item_total[index]);
oprf_items.push_back(oprf_items_total[index]);
label_keys.push_back(label_keys_total[index]);
}
// request query
std::vector<MatchRecord> query_result;
auto query = this->receiver_->create_query(oprf_items);
// chl.send(move(query.first));
auto request_query_data = std::move(query.first);
std::ostringstream string_ss;
request_query_data->save(string_ss);
std::string query_data_str = string_ss.str();
auto itt = move(query.second);
VLOG(5) << "query_data_str size: " << query_data_str.size();

VLOG(5) << "received package count: " << package_count;
std::vector<apsi::ResultPart> result_packages;
for (size_t i = 0; i < package_count; i++) {
std::string recv_data;
ret = link_ctx->Recv(this->response_key_, this->PeerNode(), &recv_data);
ret = link_ctx->Send(this->key_, PeerNode(), query_data_str);
CHECK_RETCODE_WITH_RETVALUE(ret, retcode::FAIL);
VLOG(5) << "client received data length: " << recv_data.size();
std::istringstream stream_in(recv_data);
auto result_part = std::make_unique<apsi::network::ResultPackage>();
auto seal_context = this->receiver_->get_seal_context();
result_part->load(stream_in, seal_context);
result_packages.push_back(std::move(result_part));

// receive package count
uint32_t package_count = 0;
std::string pkg_count_key = this->PackageCountKey(link_ctx->request_id());
ret = link_ctx->Recv(pkg_count_key,
this->PeerNode(),
reinterpret_cast<char*>(&package_count),
sizeof(package_count));
CHECK_RETCODE_WITH_RETVALUE(ret, retcode::FAIL);

VLOG(5) << "received package count: " << package_count;
std::vector<apsi::ResultPart> result_packages;
for (size_t i = 0; i < package_count; i++) {
std::string recv_data;
ret = link_ctx->Recv(this->response_key_, this->PeerNode(), &recv_data);
CHECK_RETCODE_WITH_RETVALUE(ret, retcode::FAIL);
VLOG(5) << "client received data length: " << recv_data.size();
std::istringstream stream_in(recv_data);
auto result_part = std::make_unique<apsi::network::ResultPackage>();
auto seal_context = this->receiver_->get_seal_context();
result_part->load(stream_in, seal_context);
result_packages.push_back(std::move(result_part));
}
query_result = this->receiver_->process_result(label_keys, itt,
result_packages);
VLOG(5) << "query_resultquery_resultquery_resultquery_result: "
<< query_result.size();
ExtractResult(orig_item, query_result, result);
}
query_result = this->receiver_->process_result(label_keys, itt,
result_packages);
VLOG(5) << "query_resultquery_resultquery_resultquery_result: "
<< query_result.size();
ExtractResult(orig_item, query_result, result);
{
std::string task_end{"SUCCESS"};
auto link_ctx = this->GetLinkContext();
Expand Down Expand Up @@ -168,7 +213,7 @@ retcode KeywordPirOperatorClient::RequestOprf(const std::vector<Item>& items,
<< PeerNode().to_string() << "] failed";
return retcode::FAIL;
}
VLOG(5) << "received oprf response length: " << oprf_response.size() << " ";
VLOG(0) << "received oprf response length: " << oprf_response.size() << " ";
oprf_receiver.process_responses(oprf_response, res_items, res_label_keys);
return retcode::SUCCESS;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,8 @@ enum class RequestType : uint8_t {
Oprf,
Query,
};
struct PirConstant {
inline static double table_size_factor{0.9};
};
}
#endif // SRC_PRIMIHUB_KERNEL_PIR_OPERATOR_KEYWORD_PIR_IMPL_KEYWORD_PIR_COMMON_H_
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ retcode KeywordPirOperatorServer::OnExecute(const PirDataType& input,
size_t use_core_num = cpu_core_num / 2;
LOG(INFO) << "ThreadPoolMgr thread count: " << use_core_num;
ThreadPoolMgr::SetThreadCount(use_core_num);
int64_t client_data_size{0};

auto params = SetPsiParams();
CHECK_NULLPOINTER(params, retcode::FAIL);
Expand Down Expand Up @@ -83,9 +84,22 @@ retcode KeywordPirOperatorServer::OnExecute(const PirDataType& input,
}
ret = ProcessOprf();
CHECK_RETCODE_WITH_RETVALUE(ret, retcode::FAIL);
auto table_size = static_cast<size_t>(this->table_size_ * PirConstant::table_size_factor);
auto block_size = this->query_data_size_ / table_size;
auto rem_size = this->query_data_size_ % table_size;
if (rem_size != 0) {
block_size++;
}
LOG(INFO) << "size of loop: " << block_size << " "
<< "query_data_size: " << query_data_size_ << " "
<< "table size: " << table_size;
for (size_t i = 0 ; i < block_size; i++) {
LOG(INFO) << "current loop: " << i << " total: " << block_size;
ret = ProcessQuery(sender_db);
CHECK_RETCODE_WITH_RETVALUE(ret, retcode::FAIL);
LOG(INFO) << "end of process loop: " << i;
}

ret = ProcessQuery(sender_db);
CHECK_RETCODE_WITH_RETVALUE(ret, retcode::FAIL);
{
std::string task_end;
auto link_ctx = this->GetLinkContext();
Expand Down Expand Up @@ -232,10 +246,11 @@ retcode KeywordPirOperatorServer::ProcessOprf() {
// // OPRFKey key_oprf;
auto oprf_response =
OPRFSender::ProcessQueries(oprf_request_str, *(this->oprf_key_));

this->query_data_size_ = oprf_request_str.size() / apsi::oprf::oprf_query_size;
std::string oprf_response_str{
reinterpret_cast<char*>(const_cast<unsigned char*>(oprf_response.data())),
oprf_response.size()};
VLOG(5) << "qeury size from client: " << query_data_size_;
// VLOG(5) << "send data size: " << oprf_response_str.size() << " "
// << "data content: " << oprf_response_str;
return link_ctx->Send(this->response_key_, ProxyNode(), oprf_response_str);
Expand Down Expand Up @@ -418,8 +433,8 @@ retcode KeywordPirOperatorServer::ProcessQuery(
for (auto& f : futures) {
f.get();
}
link_ctx->CheckSendCompleteStatus(this->response_key_,
ProxyNode(), package_count);
// link_ctx->CheckSendCompleteStatus(this->response_key_,
// ProxyNode(), package_count);
VLOG(5) << "Finished processing query request";
return retcode::SUCCESS;
}
Expand Down Expand Up @@ -599,15 +614,16 @@ std::unique_ptr<apsi::PSIParams> KeywordPirOperatorServer::SetPsiParams() {
}

// std::unique_ptr<PSIParams> params{nullptr};
auto params = std::make_unique<PSIParams>(PSIParams::Load(params_json));
auto psi_params = std::make_unique<PSIParams>(PSIParams::Load(params_json));
SCopedTimer timer;
std::ostringstream param_ss;
size_t param_size = params->save(param_ss);
size_t param_size = psi_params->save(param_ss);
psi_params_str_ = param_ss.str();
auto time_cost = timer.timeElapse();
this->table_size_ = psi_params->table_params().table_size;
VLOG(5) << "param_size: " << param_size << " time cost(ms): " << time_cost;
VLOG(5) << "param_content: " << psi_params_str_.size();
return params;
return psi_params;
}

bool KeywordPirOperatorServer::DbCacheAvailable(const std::string& db_path) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ class KeywordPirOperatorServer : public BasePirOperator {
std::shared_ptr<apsi::sender::SenderDB>;

private:
int64_t query_data_size_;
int64_t table_size_;
std::string psi_params_str_;
std::unique_ptr<apsi::oprf::OPRFKey> oprf_key_{nullptr};
std::unique_ptr<apsi::PSIParams> psi_params_{nullptr};
Expand Down

0 comments on commit 3976c17

Please sign in to comment.