Skip to content

Commit

Permalink
mpc-lr read int32 type data as label (primihub#759)
Browse files Browse the repository at this point in the history
  • Loading branch information
phoenix20162016 authored Feb 26, 2024
1 parent e28446b commit ab2581b
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 65 deletions.
123 changes: 62 additions & 61 deletions src/primihub/algorithm/logistic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
using arrow::Array;
using arrow::DoubleArray;
using arrow::Int64Array;
using arrow::Int32Array;
using arrow::Table;
namespace primihub {
eMatrix<double> logistic_main(sf64Matrix<D> &train_data_0_1,
Expand Down Expand Up @@ -189,12 +190,18 @@ int LogisticRegressionExecutor::_LoadDataset(const std::string &dataset_id) {
int64_t array_len{0};
int chunk_size = table->column(num_col - 1)->num_chunks();

std::vector<int64_t> train_data_each_chunk(chunk_size);
auto& chunks = table->column(num_col - 1)->chunks();
for (const auto& array : chunks) {
array_len += array->length();
double training_set_percentage = 0.8;
for (int i = 0; i < chunk_size; i++) {
auto& array = chunks[i];
int64_t len = array->length();
train_data_each_chunk[i] = floor(len * training_set_percentage);
array_len += len;
}
VLOG(3) << "Label column '" << col_names[num_col - 1] << "' has " << array_len
<< " values.";

VLOG(3) << "Label column '" << col_names[num_col - 1]
<< "' has " << array_len << " values.";

// Force the same value count in every column.
for (int i = 0; i < num_col - 1; i++) {
Expand All @@ -215,37 +222,13 @@ int LogisticRegressionExecutor::_LoadDataset(const std::string &dataset_id) {

if (errors)
return -1;
int64_t train_length = floor(array_len * 0.8);
int64_t train_length = std::accumulate(train_data_each_chunk.begin(),
train_data_each_chunk.end(),
0);
int64_t test_length = array_len - train_length;
// LOG(INFO)<<"array_len: "<<array_len;
// LOG(INFO)<<"train_length: "<<train_length;
// LOG(INFO)<<"test_length: "<<test_length;
// train_input_.resize(train_length, num_col);
// test_input_.resize(test_length, num_col);
// // m.resize(array_len, num_col);
// for (int i = 0; i < num_col; i++) {
// if (table->schema()->GetFieldByName(col_names[i])->type()->id() == 9) {
// auto array =
// std::static_pointer_cast<Int64Array>(table->column(i)->chunk(0));
// for (int64_t j = 0; j < array->length(); j++) {
// if (j < train_length)
// train_input_(j, i) = array->Value(j);
// else
// test_input_(j - train_length, i) = array->Value(j);
// // m(j, i) = array->Value(j);
// }
// } else {
// auto array =
// std::static_pointer_cast<DoubleArray>(table->column(i)->chunk(0));
// for (int64_t j = 0; j < array->length(); j++) {
// if (j < train_length)
// train_input_(j, i) = array->Value(j);
// else
// test_input_(j - train_length, i) = array->Value(j);
// // m(j, i) = array->Value(j);
// }
// }
// }
LOG(INFO) << "array_len: "<<array_len;
LOG(INFO) << "train_length: "<<train_length;
LOG(INFO) << "test_length: "<<test_length;
train_input_.resize(train_length, num_col + 1);
test_input_.resize(test_length, num_col + 1);
// m.resize(array_len, num_col);
Expand All @@ -263,38 +246,56 @@ int LogisticRegressionExecutor::_LoadDataset(const std::string &dataset_id) {
auto type_ptr = table->schema()->GetFieldByName(col_names[i - 1])->type();
auto data_type = type_ptr->id();
std::string type_name = type_ptr->name();
if (data_type == arrow::Type::type::INT64) {
auto array = std::static_pointer_cast<Int64Array>(
table->column(i - 1)->chunk(0));

for (int64_t j = 0; j < array->length(); j++) {
if (j < train_length) {
train_input_(j, i) = array->Value(j);
} else {
test_input_(j - train_length, i) = array->Value(j);
auto chunk_array = table->column(i - 1);
int chunk_size = chunk_array->num_chunks();
int64_t train_data_start = 0;
int64_t test_data_start = 0;
for (int chk_i = 0; chk_i < chunk_size; chk_i++) {
auto array = chunk_array->chunk(chk_i);
auto train_length = train_data_each_chunk[chk_i];
auto test_length = array->length() - train_length;
if (data_type == arrow::Type::type::INT64) {
auto ret = FillTrainAndTestData<Int64Array>(array,
train_data_start,
test_data_start,
i,
train_length);
if (ret != retcode::SUCCESS) {
std::stringstream ss;
ss << "Column " << col_names[i-1] << ", "
<< "Convert Data From " << type_name << " To Int64 failed";
RaiseException(ss.str());
}
} else if (data_type == arrow::Type::type::INT32) {
auto ret = FillTrainAndTestData<Int32Array>(array,
train_data_start,
test_data_start,
i,
train_length);
if (ret != retcode::SUCCESS) {
std::stringstream ss;
ss << "Column " << col_names[i-1] << ", "
<< "Convert Data From " << type_name << " To Int32 failed";
RaiseException(ss.str());
}
} else {
auto ret = FillTrainAndTestData<DoubleArray>(array,
train_data_start,
test_data_start,
i,
train_length);
if (ret != retcode::SUCCESS) {
std::stringstream ss;
ss << "Column " << col_names[i-1] << ", "
<< "Convert Data From " << type_name << " To Double failed";
RaiseException(ss.str());
}
// m(j, i) = array->Value(j);
}
} else {
auto chunk_array = table->column(i - 1)->chunk(0);
auto array = std::dynamic_pointer_cast<DoubleArray>(chunk_array);
if (array == nullptr) {
std::stringstream ss;
ss << "Column " << col_names[i-1]
<< ", Convert Data From " << type_name << " To Double failed";
RaiseException(ss.str());
}
for (int64_t j = 0; j < array->length(); j++) {
if (j < train_length)
train_input_(j, i) = array->Value(j);
else
test_input_(j - train_length, i) = array->Value(j);
// m(j, i) = array->Value(j);
}
train_data_start += train_length;
test_data_start += test_length;
}
}
}

return array_len;
}

Expand Down
24 changes: 24 additions & 0 deletions src/primihub/algorithm/logistic.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,30 @@ class LogisticRegressionExecutor : public AlgorithmBase {
uint16_t NextPartyId() {return (local_id_ + 1) % 3;}
uint16_t PrevPartyId() {return (local_id_ + 2) % 3;}

/**
 * Copies one Arrow chunk of a single column into the train/test matrices.
 *
 * The first `train_length` values of the chunk are written to `train_input_`
 * starting at row `train_fill_start_pos`; every remaining value of the chunk
 * is written to `test_input_` starting at row `test_fill_start_pos`. Both
 * writes target column `col_index`.
 *
 * NOTE(review): destination rows/columns are not range-checked here; this
 * presumably relies on the caller having sized both matrices beforehand.
 *
 * @tparam T concrete Arrow array type the chunk is down-cast to.
 * @param chunk_array chunk to read values from.
 * @param train_fill_start_pos first destination row in `train_input_`.
 * @param test_fill_start_pos first destination row in `test_input_`.
 * @param col_index destination column in both matrices.
 * @param train_length number of leading chunk values that go to the train set.
 * @return retcode::SUCCESS when the chunk was copied; retcode::FAIL when the
 *         chunk is not a `T` or `train_length` lies outside
 *         [0, chunk length].
 */
template<typename T>
retcode FillTrainAndTestData(std::shared_ptr<arrow::Array> chunk_array,
                             const int64_t train_fill_start_pos,
                             const int64_t test_fill_start_pos,
                             const int col_index,
                             int64_t train_length) {
  auto array = std::dynamic_pointer_cast<T>(chunk_array);
  if (array == nullptr) {
    return retcode::FAIL;
  }
  const int64_t total_length = array->length();
  // Reject a split point outside the chunk: the element reads below are
  // indexed directly, so a bad train_length would read past the chunk
  // (too large) or start the test copy at a negative index (negative).
  if (train_length < 0 || train_length > total_length) {
    return retcode::FAIL;
  }
  // Leading part of the chunk -> training matrix.
  int64_t train_data_pos = train_fill_start_pos;
  for (int64_t i = 0; i < train_length; ++i, ++train_data_pos) {
    this->train_input_(train_data_pos, col_index) = array->Value(i);
  }
  // Remainder of the chunk -> test matrix.
  int64_t test_data_pos = test_fill_start_pos;
  for (int64_t i = train_length; i < total_length; ++i, ++test_data_pos) {
    this->test_input_(test_data_pos, col_index) = array->Value(i);
  }
  return retcode::SUCCESS;
}

private:
std::string model_file_name_;
std::string model_name_;
uint16_t local_id_;
Expand Down
5 changes: 1 addition & 4 deletions src/primihub/data_store/csv/csv_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,8 @@ std::shared_ptr<arrow::Table> ReadCSVFile(const std::string& file_path,
<< "detail: " << result_ifstream.status();
RaiseException(ss.str());
}
int64_t file_size = FileSize(file_path);
ReadOptions read_opt_ = read_opt;
read_opt_.block_size = file_size;
std::shared_ptr<arrow::io::InputStream> input = result_ifstream.ValueOrDie();
return Read(input, read_opt_, parse_opt, convert_opt);
return Read(input, read_opt, parse_opt, convert_opt);
}

std::string ReadRawData(const std::string& file_path, int64_t line_number) {
Expand Down

0 comments on commit ab2581b

Please sign in to comment.