From ab2581b76213b571379c4e3cc832435ef65efc47 Mon Sep 17 00:00:00 2001 From: phoenix20162016 Date: Mon, 26 Feb 2024 18:13:16 +0800 Subject: [PATCH] mpc-lr read int32 type data as label (#759) --- src/primihub/algorithm/logistic.cc | 123 +++++++++++----------- src/primihub/algorithm/logistic.h | 24 +++++ src/primihub/data_store/csv/csv_driver.cc | 5 +- 3 files changed, 87 insertions(+), 65 deletions(-) diff --git a/src/primihub/algorithm/logistic.cc b/src/primihub/algorithm/logistic.cc index dc4ec472..c0586768 100755 --- a/src/primihub/algorithm/logistic.cc +++ b/src/primihub/algorithm/logistic.cc @@ -32,6 +32,7 @@ using arrow::Array; using arrow::DoubleArray; using arrow::Int64Array; +using arrow::Int32Array; using arrow::Table; namespace primihub { eMatrix logistic_main(sf64Matrix &train_data_0_1, @@ -189,12 +190,18 @@ int LogisticRegressionExecutor::_LoadDataset(const std::string &dataset_id) { int64_t array_len{0}; int chunk_size = table->column(num_col - 1)->num_chunks(); + std::vector train_data_each_chunk(chunk_size); auto& chunks = table->column(num_col - 1)->chunks(); - for (const auto& array : chunks) { - array_len += array->length(); + double training_set_percentage = 0.8; + for (int i = 0; i < chunk_size; i++) { + auto& array = chunks[i]; + int64_t len = array->length(); + train_data_each_chunk[i] = floor(len * training_set_percentage); + array_len += len; } - VLOG(3) << "Label column '" << col_names[num_col - 1] << "' has " << array_len - << " values."; + + VLOG(3) << "Label column '" << col_names[num_col - 1] + << "' has " << array_len << " values."; // Force the same value count in every column. for (int i = 0; i < num_col - 1; i++) { @@ -215,37 +222,13 @@ int LogisticRegressionExecutor::_LoadDataset(const std::string &dataset_id) { if (errors) return -1; - int64_t train_length = floor(array_len * 0.8); + int64_t train_length = std::accumulate(train_data_each_chunk.begin(), + train_data_each_chunk.end(), + 0); int64_t test_length = array_len - train_length; - // LOG(INFO)<<"array_len: "<schema()->GetFieldByName(col_names[i])->type()->id() == 9) { - // auto array = - // std::static_pointer_cast(table->column(i)->chunk(0)); - // for (int64_t j = 0; j < array->length(); j++) { - // if (j < train_length) - // train_input_(j, i) = array->Value(j); - // else - // test_input_(j - train_length, i) = array->Value(j); - // // m(j, i) = array->Value(j); - // } - // } else { - // auto array = - // std::static_pointer_cast(table->column(i)->chunk(0)); - // for (int64_t j = 0; j < array->length(); j++) { - // if (j < train_length) - // train_input_(j, i) = array->Value(j); - // else - // test_input_(j - train_length, i) = array->Value(j); - // // m(j, i) = array->Value(j); - // } - // } - // } + LOG(INFO) << "array_len: "<schema()->GetFieldByName(col_names[i - 1])->type(); auto data_type = type_ptr->id(); std::string type_name = type_ptr->name(); - if (data_type == arrow::Type::type::INT64) { - auto array = std::static_pointer_cast( - table->column(i - 1)->chunk(0)); - - for (int64_t j = 0; j < array->length(); j++) { - if (j < train_length) { - train_input_(j, i) = array->Value(j); - } else { - test_input_(j - train_length, i) = array->Value(j); + auto chunk_array = table->column(i - 1); + int chunk_size = chunk_array->num_chunks(); + int64_t train_data_start = 0; + int64_t test_data_start = 0; + for (int chk_i = 0; chk_i < chunk_size; chk_i++) { + auto array = chunk_array->chunk(chk_i); + auto train_length = train_data_each_chunk[chk_i]; + auto test_length = array->length() - train_length; + if (data_type == arrow::Type::type::INT64) { + auto ret = FillTrainAndTestData(array, + train_data_start, + test_data_start, + i, + train_length); + if (ret != retcode::SUCCESS) { + std::stringstream ss; + ss << "Column " << col_names[i-1] << ", " + << "Convert Data From " << type_name << " To Int64 failed"; + RaiseException(ss.str()); + } + } else if (data_type == arrow::Type::type::INT32) { + auto ret = FillTrainAndTestData(array, + train_data_start, + test_data_start, + i, + train_length); + if (ret != retcode::SUCCESS) { + std::stringstream ss; + ss << "Column " << col_names[i-1] << ", " + << "Convert Data From " << type_name << " To Int32 failed"; + RaiseException(ss.str()); + } + } else { + auto ret = FillTrainAndTestData(array, + train_data_start, + test_data_start, + i, + train_length); + if (ret != retcode::SUCCESS) { + std::stringstream ss; + ss << "Column " << col_names[i-1] << ", " + << "Convert Data From " << type_name << " To Double failed"; + RaiseException(ss.str()); } - // m(j, i) = array->Value(j); - } - } else { - auto chunk_array = table->column(i - 1)->chunk(0); - auto array = std::dynamic_pointer_cast(chunk_array); - if (array == nullptr) { - std::stringstream ss; - ss << "Column " << col_names[i-1] - << ", Convert Data From " << type_name << " To Double failed"; - RaiseException(ss.str()); - } - for (int64_t j = 0; j < array->length(); j++) { - if (j < train_length) - train_input_(j, i) = array->Value(j); - else - test_input_(j - train_length, i) = array->Value(j); - // m(j, i) = array->Value(j); } + train_data_start += train_length; + test_data_start += test_length; } } } - return array_len; } diff --git a/src/primihub/algorithm/logistic.h b/src/primihub/algorithm/logistic.h index 4abc5d51..063795bd 100755 --- a/src/primihub/algorithm/logistic.h +++ b/src/primihub/algorithm/logistic.h @@ -79,6 +79,30 @@ class LogisticRegressionExecutor : public AlgorithmBase { uint16_t NextPartyId() {return (local_id_ + 1) % 3;} uint16_t PrevPartyId() {return (local_id_ + 2) % 3;} + template + retcode FillTrainAndTestData(std::shared_ptr chunk_array, + const int64_t train_fill_start_pos, + const int64_t test_fill_start_pos, + const int col_index, + int64_t train_length) { + auto array = std::dynamic_pointer_cast(chunk_array); + if (array == nullptr) { + return retcode::FAIL; + } + int64_t train_data_pos = train_fill_start_pos; + int64_t test_data_pos = test_fill_start_pos; + for (int64_t i = 0; i < train_length; i++) { + this->train_input_(train_data_pos, col_index) = array->Value(i); + train_data_pos++; + } + for (int64_t i = train_length; i < array->length(); i++) { + this->test_input_(test_data_pos, col_index) = array->Value(i); + test_data_pos++; + } + return retcode::SUCCESS; + } + + private: std::string model_file_name_; std::string model_name_; uint16_t local_id_; diff --git a/src/primihub/data_store/csv/csv_driver.cc b/src/primihub/data_store/csv/csv_driver.cc index 3ba6b1c4..d2d765f3 100644 --- a/src/primihub/data_store/csv/csv_driver.cc +++ b/src/primihub/data_store/csv/csv_driver.cc @@ -221,11 +221,8 @@ std::shared_ptr ReadCSVFile(const std::string& file_path, << "detail: " << result_ifstream.status(); RaiseException(ss.str()); } - int64_t file_size = FileSize(file_path); - ReadOptions read_opt_ = read_opt; - read_opt_.block_size = file_size; std::shared_ptr input = result_ifstream.ValueOrDie(); - return Read(input, read_opt_, parse_opt, convert_opt); + return Read(input, read_opt, parse_opt, convert_opt); } std::string ReadRawData(const std::string& file_path, int64_t line_number) {