Skip to content

Commit

Permalink
manage data driver name (primihub#740)
Browse files Browse the repository at this point in the history
  • Loading branch information
phoenix20162016 authored Dec 27, 2023
1 parent 8f85a2e commit f8174d1
Show file tree
Hide file tree
Showing 12 changed files with 163 additions and 126 deletions.
6 changes: 6 additions & 0 deletions src/primihub/data_store/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ cc_library(
"driver.cc",
],
deps = [
":driver_constant",
"//src/primihub/common:common_defination",
"//src/primihub/util:arrow_wrapper_util",
"@com_github_jbeder_yaml_cpp//:yaml-cpp",
Expand All @@ -51,3 +52,8 @@ cc_library(
"//src/primihub/common:data_type_defination",
],
)

cc_library(
name = "driver_constant",
hdrs = ["driver_constant.h"],
)
4 changes: 2 additions & 2 deletions src/primihub/data_store/csv/csv_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ std::string ReadRawData(const std::string& file_path, int64_t line_number) {
std::string CSVAccessInfo::toString() {
std::stringstream ss;
nlohmann::json js;
js["type"] = "csv";
js["type"] = kDriveType[DriverType::CSV];
js["data_path"] = this->file_path_;
js["schema"] = SchemaToJsonString();
ss << js;
Expand Down Expand Up @@ -548,7 +548,7 @@ CSVDriver::CSVDriver(const std::string &nodelet_addr,
}

void CSVDriver::setDriverType() {
driver_type = "CSV";
driver_type = kDriveType[DriverType::CSV];
}

retcode CSVDriver::GetColumnNames(const char delimiter,
Expand Down
1 change: 1 addition & 0 deletions src/primihub/data_store/driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

#include "src/primihub/data_store/dataset.h"
#include "src/primihub/common/common.h"
#include "src/primihub/data_store/driver_constant.h"

namespace primihub {

Expand Down
38 changes: 38 additions & 0 deletions src/primihub/data_store/driver_constant.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright (c) 2023 by PrimiHub
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef SRC_PRIMIHUB_DATA_STORE_DRIVER_CONSTANT_H_
#define SRC_PRIMIHUB_DATA_STORE_DRIVER_CONSTANT_H_
#include <map>
#include <string>

namespace primihub {
enum class DriverType {
CSV = 0,
SQLITE,
HDFS,
MYSQL,
IMAGE,
};

static std::map<DriverType, std::string> kDriveType = {
{DriverType::CSV, "CSV"},
{DriverType::SQLITE, "SQLITE"},
{DriverType::HDFS, "HDFS"},
{DriverType::MYSQL, "MYSQL"},
{DriverType::IMAGE, "IMAGE"},
};
} // namespace primihub
#endif // SRC_PRIMIHUB_DATA_STORE_DRIVER_CONSTANT_H_
190 changes: 89 additions & 101 deletions src/primihub/data_store/factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,124 +28,112 @@
#ifdef ENABLE_MYSQL_DRIVER
#include "src/primihub/data_store/mysql/mysql_driver.h"
#endif
#define CSV_DRIVER_NAME "CSV"
#define SQLITE_DRIVER_NAME "SQLITE"
#define HDFS_DRIVER_NAME "HDFS"
#define MYSQL_DRIVER_NAME "MYSQL"
#define IMAGE_DRIVER_NAME "IMAGE"
#include "src/primihub/common/value_check_util.h"

namespace primihub {
class DataDirverFactory {
public:
using DataSetAccessInfoPtr = std::unique_ptr<DataSetAccessInfo>;
static std::shared_ptr<DataDriver>
getDriver(const std::string &dirverName,
const std::string& nodeletAddr,
DataSetAccessInfoPtr access_info = nullptr) {
if (boost::to_upper_copy(dirverName) == CSV_DRIVER_NAME) {
if (access_info == nullptr) {
return std::make_shared<CSVDriver>(nodeletAddr);
} else {
return std::make_shared<CSVDriver>(nodeletAddr, std::move(access_info));
}
} else if (dirverName == HDFS_DRIVER_NAME) {
// return new HDFSDriver(dirverName);
// TODO not implemented yet
} else if (boost::to_upper_copy(dirverName) == SQLITE_DRIVER_NAME) {
if (access_info == nullptr) {
return std::make_shared<SQLiteDriver>(nodeletAddr);
} else {
return std::make_shared<SQLiteDriver>(nodeletAddr, std::move(access_info));
}
} else if (boost::to_upper_copy(dirverName) == MYSQL_DRIVER_NAME) {
using DataSetAccessInfoPtr = std::unique_ptr<DataSetAccessInfo>;
using DataDriverPtr = std::shared_ptr<DataDriver>;
static DataDriverPtr getDriver(const std::string &dirverName,
const std::string& nodeletAddr,
DataSetAccessInfoPtr access_info = nullptr) {
DataDriverPtr driver_ptr{nullptr};
std::string driver_name = strToUpper(dirverName);
if (driver_name == kDriveType[DriverType::CSV]) {
driver_ptr = std::make_shared<CSVDriver>(nodeletAddr,
std::move(access_info));
} else if (driver_name == kDriveType[DriverType::SQLITE]) {
driver_ptr = std::make_shared<SQLiteDriver>(nodeletAddr,
std::move(access_info));
} else if (driver_name == kDriveType[DriverType::MYSQL]) {
#ifdef ENABLE_MYSQL_DRIVER
if (access_info == nullptr) {
return std::make_shared<MySQLDriver>(nodeletAddr);
} else {
return std::make_shared<MySQLDriver>(nodeletAddr, std::move(access_info));
}
driver_ptr = std::make_shared<MySQLDriver>(nodeletAddr,
std::move(access_info));
#else
std::string err_msg = "MySQL is not enabled";
LOG(ERROR) << err_msg;
throw std::invalid_argument(err_msg);
std::string err_msg = "MySQL is not enabled";
RaiseException(err_msg);
#endif
} else if (boost::to_upper_copy(dirverName) == IMAGE_DRIVER_NAME) {
if (access_info == nullptr) {
return std::make_shared<ImageDriver>(nodeletAddr);
} else {
return std::make_shared<ImageDriver>(nodeletAddr, std::move(access_info));
}
} else {
std::string err_msg = "[DataDriverFactory]Invalid driver name [" + dirverName + "]";
throw std::invalid_argument(err_msg);
}
return nullptr;
} else if (driver_name == kDriveType[DriverType::IMAGE]) {
driver_ptr = std::make_shared<ImageDriver>(nodeletAddr,
std::move(access_info));
} else {
std::string err_msg =
"[DataDriverFactory] Invalid driver name [" + dirverName + "]";
RaiseException(err_msg);
}
// internal
static DataSetAccessInfoPtr createAccessInfoInternal(const std::string& driver_type) {
std::string drive_type_ = strToUpper(driver_type);
DataSetAccessInfoPtr access_info_ptr{nullptr};
if (drive_type_ == CSV_DRIVER_NAME) {
access_info_ptr = std::make_unique<CSVAccessInfo>();
} else if (drive_type_ == SQLITE_DRIVER_NAME) {
access_info_ptr = std::make_unique<SQLiteAccessInfo>();
} else if (drive_type_ == MYSQL_DRIVER_NAME) {
return driver_ptr;
}

// internal
static DataSetAccessInfoPtr createAccessInfoInternal(
const std::string& driver_type) {
std::string drive_type_ = strToUpper(driver_type);
DataSetAccessInfoPtr access_info_ptr{nullptr};
if (drive_type_ == kDriveType[DriverType::CSV]) {
access_info_ptr = std::make_unique<CSVAccessInfo>();
} else if (drive_type_ == kDriveType[DriverType::SQLITE]) {
access_info_ptr = std::make_unique<SQLiteAccessInfo>();
} else if (drive_type_ == kDriveType[DriverType::MYSQL]) {
#ifdef ENABLE_MYSQL_DRIVER
access_info_ptr = std::make_unique<MySQLAccessInfo>();
access_info_ptr = std::make_unique<MySQLAccessInfo>();
#else
LOG(ERROR) << "MySQL is not enabled";
std::string err_msg = "MySQL is not enabled";
RaiseException(err_msg);
#endif
} else if (drive_type_ == IMAGE_DRIVER_NAME) {
access_info_ptr = std::make_unique<ImageAccessInfo>();
} else {
LOG(ERROR) << "unsupported driver type: " << drive_type_;
return access_info_ptr;
}
return access_info_ptr;
} else if (drive_type_ == kDriveType[DriverType::IMAGE]) {
access_info_ptr = std::make_unique<ImageAccessInfo>();
} else {
std::string err_msg = "unsupported driver type: " + drive_type_;
RaiseException(err_msg);
}
return access_info_ptr;
}

static DataSetAccessInfoPtr createAccessInfo(
const std::string& driver_type, const std::string& meta_info) {
auto access_info_ptr = createAccessInfoInternal(driver_type);
if (access_info_ptr == nullptr) {
return nullptr;
}
auto ret = access_info_ptr->fromJsonString(meta_info);
if (ret == retcode::FAIL) {
LOG(ERROR) << "create dataset access info failed";
return nullptr;
}
return access_info_ptr;
static DataSetAccessInfoPtr createAccessInfo(
const std::string& driver_type,
const std::string& meta_info) {
auto access_info_ptr = createAccessInfoInternal(driver_type);
if (access_info_ptr == nullptr) {
return nullptr;
}
auto ret = access_info_ptr->fromJsonString(meta_info);
if (ret == retcode::FAIL) {
std::string err_msg = "create dataset access info failed";
RaiseException(err_msg);
}
return access_info_ptr;
}

static DataSetAccessInfoPtr createAccessInfo(
const std::string& driver_type, const YAML::Node& meta_info) {
auto access_info_ptr = createAccessInfoInternal(driver_type);
if (access_info_ptr == nullptr) {
return nullptr;
}
// init
auto ret = access_info_ptr->fromYamlConfig(meta_info);
if (ret == retcode::FAIL) {
LOG(ERROR) << "create dataset access info failed";
return nullptr;
}
return access_info_ptr;
static DataSetAccessInfoPtr createAccessInfo(
const std::string& driver_type, const YAML::Node& meta_info) {
auto access_info_ptr = createAccessInfoInternal(driver_type);
if (access_info_ptr == nullptr) {
return nullptr;
}
// init
auto ret = access_info_ptr->fromYamlConfig(meta_info);
if (ret == retcode::FAIL) {
std::string err_msg = "create dataset access info failed";
RaiseException(err_msg);
}
return access_info_ptr;
}

static DataSetAccessInfoPtr createAccessInfo(
const std::string& driver_type, const DatasetMetaInfo& meta_info) {
auto access_info_ptr = createAccessInfoInternal(driver_type);
if (access_info_ptr == nullptr) {
return nullptr;
}
// init
auto ret = access_info_ptr->FromMetaInfo(meta_info);
if (ret == retcode::FAIL) {
LOG(ERROR) << "create dataset access info failed";
return nullptr;
}
return access_info_ptr;
static DataSetAccessInfoPtr createAccessInfo(
const std::string& driver_type, const DatasetMetaInfo& meta_info) {
auto access_info_ptr = createAccessInfoInternal(driver_type);
if (access_info_ptr == nullptr) {
return nullptr;
}
// init
auto ret = access_info_ptr->FromMetaInfo(meta_info);
if (ret == retcode::FAIL) {
std::string err_msg = "create dataset access info failed";
RaiseException(err_msg);
}
return access_info_ptr;
}
};

} // namespace primihub
Expand Down
4 changes: 2 additions & 2 deletions src/primihub/data_store/image/image_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ std::string ImageAccessInfo::toString() {
nlohmann::json js;
js["image_dir"] = this->image_dir_;
js["annotations_file"] = this->annotations_file_;
js["type"] = "image";
js["type"] = kDriveType[DriverType::IMAGE];
ss << js;
return ss.str();
}
Expand Down Expand Up @@ -150,7 +150,7 @@ ImageDriver::ImageDriver(const std::string &nodelet_addr,
}

void ImageDriver::setDriverType() {
driver_type = "Image";
driver_type = kDriveType[DriverType::IMAGE];
}

std::unique_ptr<Cursor> ImageDriver::read() {
Expand Down
4 changes: 2 additions & 2 deletions src/primihub/data_store/mysql/mysql_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace primihub {
std::string MySQLAccessInfo::toString() {
std::stringstream ss;
nlohmann::json js;
js["type"] = "mysql";
js["type"] = kDriveType[DriverType::MYSQL];
js["host"] = this->ip;
js["port"] = this->port;
js["username"] = this->user_name;
Expand Down Expand Up @@ -347,7 +347,7 @@ MySQLDriver::~MySQLDriver() {
}

void MySQLDriver::setDriverType() {
driver_type = "MySQL";
driver_type = kDriveType[DriverType::MYSQL];;
}

retcode MySQLDriver::releaseMySqlLib() {
Expand Down
4 changes: 2 additions & 2 deletions src/primihub/data_store/sqlite/sqlite_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ namespace primihub {
std::string SQLiteAccessInfo::toString() {
std::stringstream ss;
nlohmann::json js;
js["type"] = "sqlite";
js["type"] = kDriveType[DriverType::SQLITE];
js["db_path"] = this->db_path_;
js["tableName"] = this->table_name_;
// ss << std::setw(4) << js;
Expand Down Expand Up @@ -461,7 +461,7 @@ SQLiteDriver::SQLiteDriver(const std::string &nodelet_addr,
}

void SQLiteDriver::setDriverType() {
driver_type = "SQLITE";
driver_type = kDriveType[DriverType::SQLITE];
}

std::unique_ptr<Cursor> SQLiteDriver::read() {
Expand Down
6 changes: 4 additions & 2 deletions src/primihub/kernel/psi/operator/factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifdef SGX
#include "src/primihub/kernel/psi/operator/tee_psi.h"
#endif // SGX
#include "src/primihub/common/value_check_util.h"

namespace primihub::psi {
class Factory {
Expand Down Expand Up @@ -42,8 +43,9 @@ class Factory {
auto tee_engine = reinterpret_cast<sgx::TeeEngine*>(executor);
return std::make_unique<TeePsiOperator>(options, tee_engine);
#else
LOG(ERROR) << "sgx is not enabled";
return nullptr;
std::string err_msg{"sgx is not enabled"};
LOG(ERROR) << err_msg;
RaiseException(err_msg);
#endif
}

Expand Down
Loading

0 comments on commit f8174d1

Please sign in to comment.