From 3976c1707595184f366352c3dffb1b43e3f2c85d Mon Sep 17 00:00:00 2001 From: phoenix20162016 Date: Mon, 13 May 2024 18:36:57 +0800 Subject: [PATCH] Pir batch query (#775) *number of queries must be less than table size --- WORKSPACE_CN | 2 +- WORKSPACE_GITHUB | 2 +- config/pir_server_config.json | 2 +- .../keyword_pir_impl/keyword_pir_client.cc | 133 ++++++++++++------ .../keyword_pir_impl/keyword_pir_common.h | 3 + .../keyword_pir_impl/keyword_pir_server.cc | 32 +++-- .../keyword_pir_impl/keyword_pir_server.h | 2 + 7 files changed, 121 insertions(+), 55 deletions(-) diff --git a/WORKSPACE_CN b/WORKSPACE_CN index 4f435bf79..e516e8abb 100644 --- a/WORKSPACE_CN +++ b/WORKSPACE_CN @@ -266,7 +266,7 @@ http_archive( git_repository( name = "mircrosoft_apsi", #branch = "bazel_version", - commit = "44243c1a85435c04ca858279757ca5524dd3c9aa", + commit = "0b192b2df72a136900b8c06fdd0caff8b1d4ce8d", remote = "https://gitee.com/primihub/APSI.git", ) diff --git a/WORKSPACE_GITHUB b/WORKSPACE_GITHUB index 8bcfab72a..90dd43f04 100644 --- a/WORKSPACE_GITHUB +++ b/WORKSPACE_GITHUB @@ -544,7 +544,7 @@ http_archive( git_repository( name = "mircrosoft_apsi", #branch = "bazel_version", - commit = "44243c1a85435c04ca858279757ca5524dd3c9aa", + commit = "0b192b2df72a136900b8c06fdd0caff8b1d4ce8d", remote = "https://github.com/primihub/APSI.git", ) diff --git a/config/pir_server_config.json b/config/pir_server_config.json index 9800ed2d6..017d321bb 100644 --- a/config/pir_server_config.json +++ b/config/pir_server_config.json @@ -1,6 +1,6 @@ { "table_params": { - "hash_func_count": 2, + "hash_func_count": 5, "table_size": 409, "max_items_per_bin": 20 }, diff --git a/src/primihub/kernel/pir/operator/keyword_pir_impl/keyword_pir_client.cc b/src/primihub/kernel/pir/operator/keyword_pir_impl/keyword_pir_client.cc index ecb613581..4a3322230 100644 --- a/src/primihub/kernel/pir/operator/keyword_pir_impl/keyword_pir_client.cc +++ b/src/primihub/kernel/pir/operator/keyword_pir_impl/keyword_pir_client.cc @@ -31,68 +31,113 @@ using PSIParams = apsi::PSIParams; retcode KeywordPirOperatorClient::OnExecute(const PirDataType& input, PirDataType* result) { + auto link_ctx = this->GetLinkContext(); VLOG(5) << "begin to request psi params"; auto ret = RequestPSIParams(); CHECK_RETCODE_WITH_RETVALUE(ret, retcode::FAIL); - std::vector orig_item; + + std::vector orig_item_total; std::vector items_vec; - orig_item.reserve(input.size()); + orig_item_total.reserve(input.size()); items_vec.reserve(input.size()); for (const auto& [key, val_vec] : input) { apsi::Item item = key; items_vec.emplace_back(std::move(item)); - orig_item.emplace_back(key); + orig_item_total.emplace_back(key); } - std::vector oprf_items; - std::vector label_keys; + std::vector oprf_items_total; + std::vector label_keys_total; VLOG(5) << "begin to Receiver::RequestOPRF"; - ret = RequestOprf(items_vec, &oprf_items, &label_keys); + ret = RequestOprf(items_vec, &oprf_items_total, &label_keys_total); CHECK_RETCODE_WITH_RETVALUE(ret, retcode::FAIL); CHECK_TASK_STOPPED(retcode::FAIL); VLOG(5) << "Receiver::RequestOPRF end, begin to receiver.request_query"; - // request query + this->receiver_ = std::make_unique(*psi_params_); - std::vector query_result; - auto query = this->receiver_->create_query(oprf_items); - // chl.send(move(query.first)); - auto request_query_data = std::move(query.first); - std::ostringstream string_ss; - request_query_data->save(string_ss); - std::string query_data_str = string_ss.str(); - auto itt = move(query.second); - VLOG(5) << "query_data_str size: " << query_data_str.size(); - auto link_ctx = this->GetLinkContext(); - ret = link_ctx->Send(this->key_, PeerNode(), query_data_str); - CHECK_RETCODE_WITH_RETVALUE(ret, retcode::FAIL); + size_t query_size = input.size(); + auto table_size = this->psi_params_->table_params().table_size; + table_size = static_cast(table_size * PirConstant::table_size_factor); + size_t block_num = query_size / table_size; + size_t rem_size = query_size % table_size; + std::vector block_item_info; + for (size_t i = 0; i < block_num; i++) { + block_item_info.push_back(table_size); + } + if (rem_size != 0) { + block_item_info.push_back(rem_size); + } + if (VLOG_IS_ON(5)) { + std::string block_item_info_str; + for (const auto& item : block_item_info) { + block_item_info_str.append(std::to_string(item)).append(" "); + } + LOG(INFO) << "block_item_info: " << block_item_info_str; + } - // receive package count - uint32_t package_count = 0; - std::string pkg_count_key = this->PackageCountKey(link_ctx->request_id()); - ret = link_ctx->Recv(pkg_count_key, - this->PeerNode(), - reinterpret_cast(&package_count), - sizeof(package_count)); - CHECK_RETCODE_WITH_RETVALUE(ret, retcode::FAIL); + int64_t start_index{0}; + for (size_t i = 0; i < block_item_info.size(); i++) { + int64_t size_per_query = block_item_info[i]; + if (i > 0) { + start_index += block_item_info[i-1]; + } + LOG(INFO) << "start batch group: " << i << " " + << "start index: " << start_index << " " + << "group size: " << size_per_query; + std::vector orig_item; + std::vector oprf_items; + std::vector label_keys; + orig_item.reserve(size_per_query); + oprf_items.reserve(size_per_query); + label_keys.reserve(size_per_query); + for (size_t j = 0; j < size_per_query; j++) { + size_t index = start_index + j; + orig_item.push_back(orig_item_total[index]); + oprf_items.push_back(oprf_items_total[index]); + label_keys.push_back(label_keys_total[index]); + } + // request query + std::vector query_result; + auto query = this->receiver_->create_query(oprf_items); + // chl.send(move(query.first)); + auto request_query_data = std::move(query.first); + std::ostringstream string_ss; + request_query_data->save(string_ss); + std::string query_data_str = string_ss.str(); + auto itt = move(query.second); + VLOG(5) << "query_data_str size: " << query_data_str.size(); - VLOG(5) << "received package count: " << package_count; - std::vector result_packages; - for (size_t i = 0; i < package_count; i++) { - std::string recv_data; - ret = link_ctx->Recv(this->response_key_, this->PeerNode(), &recv_data); + ret = link_ctx->Send(this->key_, PeerNode(), query_data_str); CHECK_RETCODE_WITH_RETVALUE(ret, retcode::FAIL); - VLOG(5) << "client received data length: " << recv_data.size(); - std::istringstream stream_in(recv_data); - auto result_part = std::make_unique(); - auto seal_context = this->receiver_->get_seal_context(); - result_part->load(stream_in, seal_context); - result_packages.push_back(std::move(result_part)); + + // receive package count + uint32_t package_count = 0; + std::string pkg_count_key = this->PackageCountKey(link_ctx->request_id()); + ret = link_ctx->Recv(pkg_count_key, + this->PeerNode(), + reinterpret_cast(&package_count), + sizeof(package_count)); + CHECK_RETCODE_WITH_RETVALUE(ret, retcode::FAIL); + + VLOG(5) << "received package count: " << package_count; + std::vector result_packages; + for (size_t i = 0; i < package_count; i++) { + std::string recv_data; + ret = link_ctx->Recv(this->response_key_, this->PeerNode(), &recv_data); + CHECK_RETCODE_WITH_RETVALUE(ret, retcode::FAIL); + VLOG(5) << "client received data length: " << recv_data.size(); + std::istringstream stream_in(recv_data); + auto result_part = std::make_unique(); + auto seal_context = this->receiver_->get_seal_context(); + result_part->load(stream_in, seal_context); + result_packages.push_back(std::move(result_part)); + } + query_result = this->receiver_->process_result(label_keys, itt, + result_packages); + VLOG(5) << "query_resultquery_resultquery_resultquery_result: " + << query_result.size(); + ExtractResult(orig_item, query_result, result); } - query_result = this->receiver_->process_result(label_keys, itt, - result_packages); - VLOG(5) << "query_resultquery_resultquery_resultquery_result: " - << query_result.size(); - ExtractResult(orig_item, query_result, result); { std::string task_end{"SUCCESS"}; auto link_ctx = this->GetLinkContext(); @@ -168,7 +213,7 @@ retcode KeywordPirOperatorClient::RequestOprf(const std::vector& items, << PeerNode().to_string() << "] failed"; return retcode::FAIL; } - VLOG(5) << "received oprf response length: " << oprf_response.size() << " "; + VLOG(0) << "received oprf response length: " << oprf_response.size() << " "; oprf_receiver.process_responses(oprf_response, res_items, res_label_keys); return retcode::SUCCESS; } diff --git a/src/primihub/kernel/pir/operator/keyword_pir_impl/keyword_pir_common.h b/src/primihub/kernel/pir/operator/keyword_pir_impl/keyword_pir_common.h index 6137a621c..ef9e84adc 100644 --- a/src/primihub/kernel/pir/operator/keyword_pir_impl/keyword_pir_common.h +++ b/src/primihub/kernel/pir/operator/keyword_pir_impl/keyword_pir_common.h @@ -22,5 +22,8 @@ enum class RequestType : uint8_t { Oprf, Query, }; +struct PirConstant { + inline static double table_size_factor{0.9}; +}; } #endif // SRC_PRIMIHUB_KERNEL_PIR_OPERATOR_KEYWORD_PIR_IMPL_KEYWORD_PIR_COMMON_H_ diff --git a/src/primihub/kernel/pir/operator/keyword_pir_impl/keyword_pir_server.cc b/src/primihub/kernel/pir/operator/keyword_pir_impl/keyword_pir_server.cc index 7f02aacb1..6a8a6becd 100644 --- a/src/primihub/kernel/pir/operator/keyword_pir_impl/keyword_pir_server.cc +++ b/src/primihub/kernel/pir/operator/keyword_pir_impl/keyword_pir_server.cc @@ -43,6 +43,7 @@ retcode KeywordPirOperatorServer::OnExecute(const PirDataType& input, size_t use_core_num = cpu_core_num / 2; LOG(INFO) << "ThreadPoolMgr thread count: " << use_core_num; ThreadPoolMgr::SetThreadCount(use_core_num); + int64_t client_data_size{0}; auto params = SetPsiParams(); CHECK_NULLPOINTER(params, retcode::FAIL); @@ -83,9 +84,22 @@ retcode KeywordPirOperatorServer::OnExecute(const PirDataType& input, } ret = ProcessOprf(); CHECK_RETCODE_WITH_RETVALUE(ret, retcode::FAIL); + auto table_size = static_cast(this->table_size_ * PirConstant::table_size_factor); + auto block_size = this->query_data_size_ / table_size; + auto rem_size = this->query_data_size_ % table_size; + if (rem_size != 0) { + block_size++; + } + LOG(INFO) << "size of loop: " << block_size << " " + << "query_data_size: " << query_data_size_ << " " + << "table size: " << table_size; + for (size_t i = 0 ; i < block_size; i++) { + LOG(INFO) << "current loop: " << i << " total: " << block_size; + ret = ProcessQuery(sender_db); + CHECK_RETCODE_WITH_RETVALUE(ret, retcode::FAIL); + LOG(INFO) << "end of process loop: " << i; + } - ret = ProcessQuery(sender_db); - CHECK_RETCODE_WITH_RETVALUE(ret, retcode::FAIL); { std::string task_end; auto link_ctx = this->GetLinkContext(); @@ -232,10 +246,11 @@ retcode KeywordPirOperatorServer::ProcessOprf() { // // OPRFKey key_oprf; auto oprf_response = OPRFSender::ProcessQueries(oprf_request_str, *(this->oprf_key_)); - + this->query_data_size_ = oprf_request_str.size() / apsi::oprf::oprf_query_size; std::string oprf_response_str{ reinterpret_cast(const_cast(oprf_response.data())), oprf_response.size()}; + VLOG(5) << "qeury size from client: " << query_data_size_; // VLOG(5) << "send data size: " << oprf_response_str.size() << " " // << "data content: " << oprf_response_str; return link_ctx->Send(this->response_key_, ProxyNode(), oprf_response_str); @@ -418,8 +433,8 @@ retcode KeywordPirOperatorServer::ProcessQuery( for (auto& f : futures) { f.get(); } - link_ctx->CheckSendCompleteStatus(this->response_key_, - ProxyNode(), package_count); + // link_ctx->CheckSendCompleteStatus(this->response_key_, + // ProxyNode(), package_count); VLOG(5) << "Finished processing query request"; return retcode::SUCCESS; } @@ -599,15 +614,16 @@ std::unique_ptr KeywordPirOperatorServer::SetPsiParams() { } // std::unique_ptr params{nullptr}; - auto params = std::make_unique(PSIParams::Load(params_json)); + auto psi_params = std::make_unique(PSIParams::Load(params_json)); SCopedTimer timer; std::ostringstream param_ss; - size_t param_size = params->save(param_ss); + size_t param_size = psi_params->save(param_ss); psi_params_str_ = param_ss.str(); auto time_cost = timer.timeElapse(); + this->table_size_ = psi_params->table_params().table_size; VLOG(5) << "param_size: " << param_size << " time cost(ms): " << time_cost; VLOG(5) << "param_content: " << psi_params_str_.size(); - return params; + return psi_params; } bool KeywordPirOperatorServer::DbCacheAvailable(const std::string& db_path) { diff --git a/src/primihub/kernel/pir/operator/keyword_pir_impl/keyword_pir_server.h b/src/primihub/kernel/pir/operator/keyword_pir_impl/keyword_pir_server.h index 885d36845..f6b1232fb 100644 --- a/src/primihub/kernel/pir/operator/keyword_pir_impl/keyword_pir_server.h +++ b/src/primihub/kernel/pir/operator/keyword_pir_impl/keyword_pir_server.h @@ -114,6 +114,8 @@ class KeywordPirOperatorServer : public BasePirOperator { std::shared_ptr; private: + int64_t query_data_size_; + int64_t table_size_; std::string psi_params_str_; std::unique_ptr oprf_key_{nullptr}; std::unique_ptr psi_params_{nullptr};