diff --git a/src/primihub/algorithm/arithmetic.cc b/src/primihub/algorithm/arithmetic.cc index ccd2274b..a13ec1e7 100644 --- a/src/primihub/algorithm/arithmetic.cc +++ b/src/primihub/algorithm/arithmetic.cc @@ -1,15 +1,16 @@ +#include "src/primihub/algorithm/arithmetic.h" + #include #include #include #include -#include "src/primihub/algorithm/arithmetic.h" #include "src/primihub/data_store/csv/csv_driver.h" #include "src/primihub/data_store/factory.h" -#include "src/primihub/util/util.h" -#include "src/primihub/util/network/message_interface.h" #include "src/primihub/util/file_util.h" +#include "src/primihub/util/network/message_interface.h" +#include "src/primihub/util/util.h" using arrow::Array; using arrow::DoubleArray; @@ -42,7 +43,7 @@ int ArithmeticExecutor::loadParams(primihub::rpc::Task &task) { LOG(ERROR) << "no data set found for party name: " << this->party_name(); return -1; } - const auto& dataset = it->second.data(); + const auto &dataset = it->second.data(); auto iter = dataset.find("Data_File"); if (iter == dataset.end()) { LOG(ERROR) << "no dataset found for dataset name Data_File"; @@ -51,7 +52,7 @@ int ArithmeticExecutor::loadParams(primihub::rpc::Task &task) { // File path. if (it->second.dataset_detail()) { this->is_dataset_detail_ = true; - auto& param_map = task.params().param_map(); + auto ¶m_map = task.params().param_map(); auto p_it = param_map.find("Data_File"); if (p_it != param_map.end()) { this->data_file_path_ = p_it->second.value_string(); @@ -68,7 +69,7 @@ int ArithmeticExecutor::loadParams(primihub::rpc::Task &task) { LOG(INFO) << "Data file path is " << data_file_path_ << "."; auto param_map = task.params().param_map(); try { - const auto& task_info = task.task_info(); + const auto &task_info = task.task_info(); task_id_ = task_info.task_id(); job_id_ = task_info.job_id(); @@ -83,7 +84,8 @@ int ArithmeticExecutor::loadParams(primihub::rpc::Task &task) { uint16_t owner; auto ret = party_config_.PartyName2PartyId(party_name, &owner); if (ret != retcode::SUCCESS) { - LOG(ERROR) << "convert party name to party id failed for: " << party_name; + LOG(ERROR) << "convert party name to party id failed for: " + << party_name; return -1; } col_and_owner_.insert(make_pair(col, owner)); @@ -101,11 +103,13 @@ int ArithmeticExecutor::loadParams(primihub::rpc::Task &task) { // LOG(INFO) << col << ":" << dtype; } // LOG(INFO) << col_and_dtype; - + expr_ = param_map["Expr"].value_string(); + int comma_index = expr_.find(","); + cmp_col1 = expr_.substr(4, comma_index - 4); + cmp_col2 = expr_.substr(comma_index + 1, expr_.length() - comma_index - 2); is_cmp = false; - if (expr_.substr(0, 3) == "CMP") - is_cmp = true; + if (expr_.substr(0, 3) == "CMP") is_cmp = true; if (is_cmp) { std::string next_name; std::string prev_name; @@ -120,7 +124,8 @@ int ArithmeticExecutor::loadParams(primihub::rpc::Task &task) { prev_name = "12"; } - mpc_op_exec_ = std::make_unique(party_id_, next_name, prev_name); + mpc_op_exec_ = + std::make_unique(party_id_, next_name, prev_name); } else { mpc_exec_ = std::make_unique>(); } @@ -150,7 +155,8 @@ int ArithmeticExecutor::loadParams(primihub::rpc::Task &task) { return 0; } -template int ArithmeticExecutor::loadDataset() { +template +int ArithmeticExecutor::loadDataset() { int ret = _LoadDatasetFromCSV(this->dataset_id_); // file reading error or file empty if (ret <= 0) { @@ -158,8 +164,85 @@ template int ArithmeticExecutor::loadDataset() { return -1; } - if (is_cmp) + if (is_cmp) { + // When the data types of two columns are different, + // conversion is possible. + auto iter1 = col_and_dtype_.find(cmp_col1); + if (iter1 == col_and_dtype_.end()) { + std::stringstream ss; + ss << "Can't find dtype of column " << cmp_col1; + LOG(ERROR) << ss.str() << "."; + throw std::runtime_error(ss.str()); + } + + auto iter2 = col_and_dtype_.find(cmp_col2); + if (iter2 == col_and_dtype_.end()) { + std::stringstream ss; + ss << "Can't find dtype of column " << cmp_col1; + LOG(ERROR) << ss.str() << "."; + throw std::runtime_error(ss.str()); + } + + if (iter1->second == iter2->second) { + if (iter1->second == 1) + i64_cmp = false; + else + i64_cmp = true; + + return 0; + } + + if (iter1->second == 0) + LOG(INFO) << "Dtype of the compared column don't match, Convert dtype of " + "column " + << cmp_col1 << " to double."; + else + LOG(INFO) << "Dtype of the compared column don't match, Convert dtype of " + "column " + << cmp_col2 << " to double."; + + i64_cmp = false; + + std::string convert_col; + if (iter1->second == 0) + convert_col = cmp_col1; + else + convert_col = cmp_col2; + + auto iter3 = col_and_owner_.find(convert_col); + if (iter3 == col_and_owner_.end()) { + std::stringstream ss; + ss << "Can't find the party that column " << convert_col << " belong to."; + LOG(ERROR) << ss.str(); + throw std::runtime_error(ss.str()); + } + + if (party_id_ != iter3->second) { + LOG(INFO) << "Skip dtype convert because column " << convert_col + << " don't belong to this party."; + return 0; + } + + auto col_iter = col_and_val_int.find(convert_col); + if (col_iter == col_and_val_int.end()) { + std::stringstream ss; + ss << "Can't find column value with column name " << convert_col << "."; + LOG(ERROR) << ss.str(); + throw std::runtime_error(ss.str()); + } + + std::vector &col_src = col_iter->second; + std::vector col_dest; + + for (auto &elem : col_src) col_dest.push_back(elem); + + col_and_val_double.insert(std::make_pair(convert_col, col_dest)); + col_and_val_int.erase(col_iter); + + LOG(INFO) << "Convert value of column " << convert_col + << " to double finish."; return 0; + } mpc_exec_->initColumnConfig(party_id_); for (auto &pair : col_and_owner_) @@ -198,22 +281,55 @@ template int ArithmeticExecutor::execute() { if (is_cmp) { try { + LOG(INFO) << "Run MPC Compare between " << cmp_col1 << " and " << cmp_col2 + << "."; + sbMatrix sh_res; - f64Matrix m; - if (col_and_owner_[expr_.substr(4, 1)] == party_id_) { - m.resize(1, col_and_val_double[expr_.substr(4, 1)].size()); - for (size_t i = 0; i < col_and_val_double[expr_.substr(4, 1)].size(); - i++) - m(i) = col_and_val_double[expr_.substr(4, 1)][i]; - mpc_op_exec_->MPC_Compare(m, sh_res); - } else if (col_and_owner_[expr_.substr(6, 1)] == party_id_) { - m.resize(1, col_and_val_double[expr_.substr(6, 1)].size()); - for (size_t i = 0; i < col_and_val_double[expr_.substr(6, 1)].size(); - i++) - m(i) = col_and_val_double[expr_.substr(6, 1)][i]; - mpc_op_exec_->MPC_Compare(m, sh_res); + + if (i64_cmp) { + i64Matrix m; + if (col_and_owner_[cmp_col1] == party_id_) { + size_t count = col_and_val_int[cmp_col1].size(); + m.resize(1, count); + + std::vector &col = col_and_val_int[cmp_col1]; + for (size_t i = 0; i < count; i++) + m(i) = col[i]; + + mpc_op_exec_->MPC_Compare(m, sh_res); + } else if (col_and_owner_[cmp_col2] == party_id_) { + size_t count = col_and_val_int[cmp_col2].size(); + m.resize(1, count); + + std::vector &col = col_and_val_int[cmp_col2]; + for (size_t i = 0; i < count; i++) + m(i) = col[i]; + + mpc_op_exec_->MPC_Compare(m, sh_res); + } else { + mpc_op_exec_->MPC_Compare(sh_res); + } } else { - mpc_op_exec_->MPC_Compare(sh_res); + f64Matrix m; + if (col_and_owner_[cmp_col1] == party_id_) { + size_t count = col_and_val_double[cmp_col1].size(); + m.resize(1, count); + + std::vector &col = col_and_val_double[cmp_col1]; + for (size_t i = 0; i < count; i++) m(i) = col[i]; + + mpc_op_exec_->MPC_Compare(m, sh_res); + } else if (col_and_owner_[cmp_col2] == party_id_) { + size_t count = col_and_val_double[cmp_col2].size(); + m.resize(1, count); + + std::vector &col = col_and_val_double[cmp_col2]; + for (size_t i = 0; i < count; i++) m(i) = col[i]; + + mpc_op_exec_->MPC_Compare(m, sh_res); + } else { + mpc_op_exec_->MPC_Compare(sh_res); + } } // reveal @@ -227,16 +343,19 @@ int ArithmeticExecutor::execute() { } } } catch (std::exception &e) { - LOG(ERROR) << "In party " << party_id_ << ":\n" << e.what() << "."; + std::stringstream ss; + ss << "Error occurs during MPC Compare, " << e.what(); + LOG(ERROR) << ss.str(); + throw std::runtime_error(ss.str()); } + return 0; } try { std::stringstream ss; ss << "Reveal result to"; - for (auto &party : parties_) - ss << " " << party; + for (auto &party : parties_) ss << " " << party; ss << "."; LOG(INFO) << ss.str(); @@ -247,14 +366,17 @@ int ArithmeticExecutor::execute() { mpc_exec_->revealMPCResult(parties_, final_val_int64_); } } catch (const std::exception &e) { - std::string msg = "In party 0, "; - msg = msg + e.what(); - throw std::runtime_error(msg); + std::stringstream ss; + ss << "Error occurs during MPC run, " << e.what() << "."; + LOG(ERROR) << ss.str(); + throw std::runtime_error(ss.str()); } + return 0; } -template int ArithmeticExecutor::saveModel(void) { +template +int ArithmeticExecutor::saveModel(void) { bool is_reveal = false; for (auto party : parties_) { if (party == party_id_) { @@ -305,7 +427,7 @@ template int ArithmeticExecutor::saveModel(void) { DataDirverFactory::getDriver("CSV", dataset_service_->getNodeletAddr()); // std::shared_ptr csv_driver = // std::dynamic_pointer_cast(driver); - auto& filepath = res_name_; + auto &filepath = res_name_; auto data_cursor = driver->initCursor(filepath); auto dataset = std::make_shared(table, driver); // int ret = 0; @@ -325,8 +447,8 @@ template int ArithmeticExecutor::saveModel(void) { template int ArithmeticExecutor::_LoadDatasetFromCSV(std::string &dataset_id) { - auto driver = this->dataset_service_->getDriver(dataset_id, - this->is_dataset_detail_); + auto driver = + this->dataset_service_->getDriver(dataset_id, this->is_dataset_detail_); if (driver == nullptr) { LOG(ERROR) << "load dataset driver failed"; return -1; @@ -341,7 +463,7 @@ int ArithmeticExecutor::_LoadDatasetFromCSV(std::string &dataset_id) { LOG(ERROR) << "load dataset failed"; return -1; } - auto& table = std::get>(ds->data); + auto &table = std::get>(ds->data); // Label column. std::vector col_names = table->ColumnNames(); @@ -388,7 +510,7 @@ int ArithmeticExecutor::_LoadDatasetFromCSV(std::string &dataset_id) { break; } col_and_val_int.insert(std::make_pair(col_names[i], tmp_data)); - // std::pair>(col_names[i], tmp_data)); + // std::pair>(col_names[i], tmp_data)); // for (auto itr = col_and_val_int.begin(); itr != col_and_val_int.end(); // itr++) { // LOG(INFO) << itr->first; @@ -435,7 +557,7 @@ int ArithmeticExecutor::_LoadDatasetFromCSV(std::string &dataset_id) { } } col_and_val_double.insert(std::make_pair(col_names[i], tmp_data)); - // pair>(col_names[i], tmp_data)); + // pair>(col_names[i], tmp_data)); // for (auto itr = col_and_val_double.begin(); // itr != col_and_val_double.end(); itr++) { // LOG(INFO) << itr->first; @@ -445,8 +567,7 @@ int ArithmeticExecutor::_LoadDatasetFromCSV(std::string &dataset_id) { // } } } - if (errors) - return -1; + if (errors) return -1; return array_len; } @@ -524,16 +645,17 @@ template class ArithmeticExecutor; // auto prev_party_info = this->party_config_.PrevPartyInfo(); // auto base_channel_prev = link_ctx->getChannel(prev_party_info); - // // The 'osuCrypto::Channel' will consider it to be a unique_ptr and will // // reset the unique_ptr, so the 'osuCrypto::Channel' will delete it. -// auto msg_interface_prev = std::make_unique( -// link_ctx->job_id(), link_ctx->task_id(), link_ctx->request_id(), this->party_name(), -// prev_party_name, link_ctx, base_channel_prev); +// auto msg_interface_prev = +// std::make_unique( +// link_ctx->job_id(), link_ctx->task_id(), link_ctx->request_id(), +// this->party_name(), prev_party_name, link_ctx, base_channel_prev); -// auto msg_interface_next = std::make_unique( -// link_ctx->job_id(), link_ctx->task_id(), link_ctx->request_id(), this->party_name(), -// next_party_name, link_ctx, base_channel_next); +// auto msg_interface_next = +// std::make_unique( +// link_ctx->job_id(), link_ctx->task_id(), link_ctx->request_id(), +// this->party_name(), next_party_name, link_ctx, base_channel_next); // osuCrypto::Channel chl_prev(ios_, msg_interface_prev.release()); // osuCrypto::Channel chl_next(ios_, msg_interface_next.release()); @@ -566,7 +688,8 @@ template class ArithmeticExecutor; // int MPCSendRecvExecutor::execute() { -// // Phase 1: simulate the communication in the creation of matrix's arithmetic +// // Phase 1: simulate the communication in the creation of matrix's +// arithmetic // // share. // LOG(INFO) << "Send and recv si64Matrix."; @@ -640,7 +763,8 @@ template class ArithmeticExecutor; // LOG(INFO) << "Finish."; -// // Phase 3: simulate the communicate in the creation of a value's arithmetic +// // Phase 3: simulate the communicate in the creation of a value's +// arithmetic // // share. // LOG(INFO) << "Send and recv si64."; // { @@ -775,4 +899,4 @@ template class ArithmeticExecutor; // } // #endif -} // namespace primihub +} // namespace primihub diff --git a/src/primihub/algorithm/arithmetic.h b/src/primihub/algorithm/arithmetic.h index 32a5c344..21927eed 100644 --- a/src/primihub/algorithm/arithmetic.h +++ b/src/primihub/algorithm/arithmetic.h @@ -48,9 +48,12 @@ class ArithmeticExecutor : public AlgorithmBase { // For MPC compare task. bool is_cmp{false}; std::vector cmp_res_; + bool i64_cmp{true}; // For MPC express task. std::string expr_; + std::string cmp_col1; + std::string cmp_col2; std::map col_and_owner_; std::map col_and_dtype_; std::vector final_val_double_;