diff --git a/src/catalog/abstract_catalog.cpp b/src/catalog/abstract_catalog.cpp index 98879a0d36c..5894941a37e 100644 --- a/src/catalog/abstract_catalog.cpp +++ b/src/catalog/abstract_catalog.cpp @@ -32,6 +32,8 @@ #include "executor/index_scan_executor.h" #include "executor/insert_executor.h" #include "executor/seq_scan_executor.h" +#include "executor/update_executor.h" +#include "executor/plan_executor.h" #include "storage/database.h" #include "storage/storage_manager.h" @@ -272,5 +274,69 @@ void AbstractCatalog::AddIndex(const std::vector &key_attrs, index_name.c_str(), (int)catalog_table_->GetOid()); } +/*@brief Update specific columns using index scan + * @param update_columns Columns to be updated + * @param update_values Values to be updated + * @param scan_values Value to be scaned (used in index scan) + * @param index_offset Offset of index for scan + * @return true if successfully executes + */ +bool AbstractCatalog::UpdateWithIndexScan( + std::vector update_columns, std::vector update_values, + std::vector scan_values, oid_t index_offset, + concurrency::TransactionContext *txn) { + if (txn == nullptr) throw CatalogException("Scan table requires transaction"); + + std::unique_ptr context( + new executor::ExecutorContext(txn)); + // Construct index scan executor + auto index = catalog_table_->GetIndex(index_offset); + std::vector key_column_offsets = + index->GetMetadata()->GetKeySchema()->GetIndexedColumns(); + PELOTON_ASSERT(scan_values.size() == key_column_offsets.size()); + std::vector expr_types(scan_values.size(), + ExpressionType::COMPARE_EQUAL); + std::vector runtime_keys; + + planner::IndexScanPlan::IndexScanDesc index_scan_desc( + index, key_column_offsets, expr_types, scan_values, runtime_keys); + + planner::IndexScanPlan index_scan_node(catalog_table_, nullptr, + update_columns, index_scan_desc); + + executor::IndexScanExecutor index_scan_executor(&index_scan_node, + context.get()); + // Construct update executor + TargetList target_list; + DirectMapList direct_map_list; + + size_t column_count = catalog_table_->GetSchema()->GetColumnCount(); + for (size_t col_itr = 0; col_itr < column_count; col_itr++) { + // Skip any column for update + if (std::find(std::begin(update_columns), std::end(update_columns), + col_itr) == std::end(update_columns)) { + direct_map_list.emplace_back(col_itr, std::make_pair(0, col_itr)); + } + } + + PELOTON_ASSERT(update_columns.size() == update_values.size()); + for (size_t i = 0; i < update_values.size(); i++) { + planner::DerivedAttribute update_attribute{ + new expression::ConstantValueExpression(update_values[i])}; + target_list.emplace_back(update_columns[i], update_attribute); + } + + std::unique_ptr project_info( + new planner::ProjectInfo(std::move(target_list), + std::move(direct_map_list))); + planner::UpdatePlan update_node(catalog_table_, std::move(project_info)); + + executor::UpdateExecutor update_executor(&update_node, context.get()); + update_executor.AddChild(&index_scan_executor); + // Execute + update_executor.Init(); + return update_executor.Execute(); +} + } // namespace catalog } // namespace peloton diff --git a/src/catalog/catalog.cpp b/src/catalog/catalog.cpp index 3ed19e68dc1..856377b9f8d 100644 --- a/src/catalog/catalog.cpp +++ b/src/catalog/catalog.cpp @@ -25,10 +25,12 @@ #include "catalog/table_catalog.h" #include "catalog/table_metrics_catalog.h" #include "catalog/trigger_catalog.h" +#include "catalog/sequence_catalog.h" #include "concurrency/transaction_manager_factory.h" #include "function/date_functions.h" #include "function/decimal_functions.h" #include "function/old_engine_string_functions.h" +#include "function/string_functions.h" #include "function/timestamp_functions.h" #include "index/index_factory.h" #include "settings/settings_manager.h" @@ -148,12 +150,13 @@ void Catalog::Bootstrap() { DatabaseMetricsCatalog::GetInstance(txn); TableMetricsCatalog::GetInstance(txn); IndexMetricsCatalog::GetInstance(txn); - QueryMetricsCatalog::GetInstance(txn); + QueryMetricsCatalog::GetInstance(txn); SettingsCatalog::GetInstance(txn); TriggerCatalog::GetInstance(txn); LanguageCatalog::GetInstance(txn); ProcCatalog::GetInstance(txn); - + SequenceCatalog::GetInstance(txn); + if (settings::SettingsManager::GetBool(settings::SettingId::brain)) { QueryHistoryCatalog::GetInstance(txn); } @@ -1060,6 +1063,20 @@ void Catalog::InitializeFunctions() { function::BuiltInFuncType{OperatorId::Like, function::OldEngineStringFunctions::Like}, txn); + // Sequence + AddBuiltinFunction( + "nextval", {type::TypeId::VARCHAR}, type::TypeId::INTEGER, + internal_lang, "Nextval", + function::BuiltInFuncType{OperatorId::Nextval, + function::OldEngineStringFunctions::Nextval}, + txn); + AddBuiltinFunction( + "currval", {type::TypeId::VARCHAR}, type::TypeId::INTEGER, + internal_lang, "Currval", + function::BuiltInFuncType{OperatorId::Currval, + function::OldEngineStringFunctions::Currval}, + txn); + /** * decimal functions @@ -1106,28 +1123,28 @@ void Catalog::InitializeFunctions() { * integer functions */ AddBuiltinFunction( - "abs", {type::TypeId::TINYINT}, type::TypeId::TINYINT, + "abs", {type::TypeId::TINYINT}, type::TypeId::TINYINT, internal_lang, "Abs", function::BuiltInFuncType{OperatorId::Abs, function::DecimalFunctions::_Abs}, txn); AddBuiltinFunction( - "abs", {type::TypeId::SMALLINT}, type::TypeId::SMALLINT, + "abs", {type::TypeId::SMALLINT}, type::TypeId::SMALLINT, internal_lang, "Abs", function::BuiltInFuncType{OperatorId::Abs, function::DecimalFunctions::_Abs}, txn); AddBuiltinFunction( - "abs", {type::TypeId::INTEGER}, type::TypeId::INTEGER, + "abs", {type::TypeId::INTEGER}, type::TypeId::INTEGER, internal_lang, "Abs", function::BuiltInFuncType{OperatorId::Abs, function::DecimalFunctions::_Abs}, txn); AddBuiltinFunction( - "abs", {type::TypeId::BIGINT}, type::TypeId::BIGINT, + "abs", {type::TypeId::BIGINT}, type::TypeId::BIGINT, internal_lang, "Abs", function::BuiltInFuncType{OperatorId::Abs, function::DecimalFunctions::_Abs}, diff --git a/src/catalog/sequence_catalog.cpp b/src/catalog/sequence_catalog.cpp new file mode 100644 index 00000000000..43ae067c6d6 --- /dev/null +++ b/src/catalog/sequence_catalog.cpp @@ -0,0 +1,300 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// sequence_catalog.h +// +// Identification: src/catalog/sequence_catalog.cpp +// +// Copyright (c) 2015-17, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#include + +#include "catalog/sequence_catalog.h" + +#include "catalog/catalog.h" +#include "catalog/database_catalog.h" +#include "catalog/table_catalog.h" +#include "common/internal_types.h" +#include "storage/data_table.h" +#include "type/value_factory.h" +#include "function/functions.h" +#include "planner/update_plan.h" +#include "executor/update_executor.h" +#include "executor/executor_context.h" +#include "optimizer/optimizer.h" +#include "parser/postgresparser.h" + +namespace peloton { +namespace catalog { + +/* @brief Get the nextval of the sequence + * @return the next value of the sequence. + * @exception throws SequenceException if the sequence exceeds the upper/lower + * limit. + */ +int64_t SequenceCatalogObject::get_next_val() { + int64_t result = seq_curr_val; + if (seq_increment > 0) { + if ((seq_max >= 0 && seq_curr_val > seq_max - seq_increment) || + (seq_max < 0 && seq_curr_val + seq_increment > seq_max)) { + if (!seq_cycle) { + throw SequenceException( + StringUtil::Format("Sequence exceeds upper limit!")); + } + seq_curr_val = seq_min; + } else + seq_curr_val += seq_increment; + } else { + if ((seq_min < 0 && seq_curr_val < seq_min - seq_increment) || + (seq_min >= 0 && seq_curr_val + seq_increment < seq_min)) { + if (!seq_cycle) { + throw SequenceException( + StringUtil::Format("Sequence exceeds lower limit!")); + } + seq_curr_val = seq_max; + } else + seq_curr_val += seq_increment; + } + + // TODO: this will become visible after Mengran's team push the + // AbstractCatalog::UpdateWithIndexScan. + // Link for the function: + // https://github.com/camellyx/peloton/blob/master/src/catalog/abstract_catalog.cpp#L305 + bool status = catalog::SequenceCatalog::GetInstance().UpdateNextVal(seq_oid, seq_curr_val, txn_); + LOG_DEBUG("status of update pg_sequence: %d", status); + + return result; +} + +SequenceCatalog &SequenceCatalog::GetInstance( + concurrency::TransactionContext *txn) { + static SequenceCatalog sequence_catalog{txn}; + return sequence_catalog; +} + +SequenceCatalog::SequenceCatalog(concurrency::TransactionContext *txn) + : AbstractCatalog("CREATE TABLE " SEQUENCE_CATALOG_NAME + " (" + "oid INT NOT NULL PRIMARY KEY, " + "sqdboid INT NOT NULL, " + "sqname VARCHAR NOT NULL, " + "sqinc BIGINT NOT NULL, " + "sqmax BIGINT NOT NULL, " + "sqmin BIGINT NOT NULL, " + "sqstart BIGINT NOT NULL, " + "sqcycle BOOLEAN NOT NULL, " + "sqval BIGINT NOT NULL);", + txn) { + Catalog::GetInstance()->CreateIndex( + CATALOG_DATABASE_NAME, SEQUENCE_CATALOG_NAME, + {ColumnId::DATABSE_OID, ColumnId::SEQUENCE_NAME}, + SEQUENCE_CATALOG_NAME "_skey0", false, IndexType::BWTREE, txn); +} + +SequenceCatalog::~SequenceCatalog() {} + +/* @brief Delete the sequence by name. + * @param database_oid the databse_oid associated with the sequence + * @param sequence_name the name of the sequence + * @param seq_increment the increment per step of the sequence + * @param seq_max the max value of the sequence + * @param seq_min the min value of the sequence + * @param seq_start the start of the sequence + * @param seq_cycle whether the sequence cycles + * @param pool an instance of abstract pool + * @param txn current transaction + * @return ResultType::SUCCESS if the sequence exists, ResultType::FAILURE + * otherwise. + * @exception throws SequenceException if the sequence already exists. + */ +bool SequenceCatalog::InsertSequence(oid_t database_oid, + std::string sequence_name, + int64_t seq_increment, int64_t seq_max, + int64_t seq_min, int64_t seq_start, + bool seq_cycle, type::AbstractPool *pool, + concurrency::TransactionContext *txn) { + LOG_DEBUG("Insert Sequence Database Oid: %u", database_oid); + LOG_DEBUG("Insert Sequence Sequence Name: %s", sequence_name.c_str()); + if (GetSequence(database_oid, sequence_name, txn) != nullptr) { + throw SequenceException( + StringUtil::Format("Insert Sequence with Duplicate Sequence Name: %s", + sequence_name.c_str())); + } + + std::unique_ptr tuple( + new storage::Tuple(catalog_table_->GetSchema(), true)); + + auto val0 = type::ValueFactory::GetIntegerValue(GetNextOid()); + auto val1 = type::ValueFactory::GetIntegerValue(database_oid); + auto val2 = type::ValueFactory::GetVarcharValue(sequence_name); + auto val3 = type::ValueFactory::GetBigIntValue(seq_increment); + auto val4 = type::ValueFactory::GetBigIntValue(seq_max); + auto val5 = type::ValueFactory::GetBigIntValue(seq_min); + auto val6 = type::ValueFactory::GetBigIntValue(seq_start); + auto val7 = type::ValueFactory::GetBooleanValue(seq_cycle); + // When insert value, seqval = seq_start + auto val8 = type::ValueFactory::GetBigIntValue(seq_start); + + tuple->SetValue(ColumnId::SEQUENCE_OID, val0, pool); + tuple->SetValue(ColumnId::DATABSE_OID, val1, pool); + tuple->SetValue(ColumnId::SEQUENCE_NAME, val2, pool); + tuple->SetValue(ColumnId::SEQUENCE_INC, val3, pool); + tuple->SetValue(ColumnId::SEQUENCE_MAX, val4, pool); + tuple->SetValue(ColumnId::SEQUENCE_MIN, val5, pool); + tuple->SetValue(ColumnId::SEQUENCE_START, val6, pool); + tuple->SetValue(ColumnId::SEQUENCE_CYCLE, val7, pool); + tuple->SetValue(ColumnId::SEQUENCE_VALUE, val8, pool); + + // Insert the tuple + return InsertTuple(std::move(tuple), txn); +} + +/* @brief Delete the sequence by name. + * @param database_oid the databse_oid associated with the sequence + * @param sequence_name the name of the sequence + * @param txn current transaction + * @return ResultType::SUCCESS if the sequence exists, ResultType::FAILURE + * otherwise. + */ +ResultType SequenceCatalog::DropSequence(const std::string &database_name, + const std::string &sequence_name, + concurrency::TransactionContext *txn) { + if (txn == nullptr) { + LOG_TRACE("Do not have transaction to drop sequence: %s", + database_name.c_str()); + return ResultType::FAILURE; + } + + auto database_object = + Catalog::GetInstance()->GetDatabaseObject(database_name, txn); + + oid_t sequence_oid = SequenceCatalog::GetInstance().GetSequenceOid( + sequence_name, database_object->GetDatabaseOid(), txn); + if (sequence_oid == INVALID_OID) { + LOG_TRACE("Cannot find sequence %s to drop!", sequence_name.c_str()); + return ResultType::FAILURE; + } + + LOG_INFO("sequence %d will be deleted!", sequence_oid); + + oid_t database_oid = database_object->GetDatabaseOid(); + DeleteSequenceByName(sequence_name, database_oid, txn); + + return ResultType::SUCCESS; +} + +/* @brief Delete the sequence by name. The sequence is guaranteed to exist. + * @param database_oid the databse_oid associated with the sequence + * @param sequence_name the name of the sequence + * @param txn current transaction + * @return The result of DeleteWithIndexScan. + */ +bool SequenceCatalog::DeleteSequenceByName( + const std::string &sequence_name, oid_t database_oid, + concurrency::TransactionContext *txn) { + oid_t index_offset = IndexId::DBOID_SEQNAME_KEY; + std::vector values; + values.push_back(type::ValueFactory::GetIntegerValue(database_oid).Copy()); + values.push_back(type::ValueFactory::GetVarcharValue(sequence_name).Copy()); + + return DeleteWithIndexScan(index_offset, values, txn); +} + +/* @brief get sequence from pg_sequence table + * @param database_oid the databse_oid associated with the sequence + * @param sequence_name the name of the sequence + * @param txn current transaction + * @return a SequenceCatalogObject if the sequence is found, nullptr otherwise + */ +std::shared_ptr SequenceCatalog::GetSequence( + oid_t database_oid, const std::string &sequence_name, + concurrency::TransactionContext *txn) { + std::vector column_ids( + {ColumnId::SEQUENCE_OID, ColumnId::SEQUENCE_NAME, + ColumnId::SEQUENCE_START, ColumnId::SEQUENCE_INC, ColumnId::SEQUENCE_MAX, + ColumnId::SEQUENCE_MIN, ColumnId::SEQUENCE_CYCLE, + ColumnId::SEQUENCE_VALUE}); + oid_t index_offset = IndexId::DBOID_SEQNAME_KEY; + std::vector values; + values.push_back(type::ValueFactory::GetIntegerValue(database_oid).Copy()); + values.push_back(type::ValueFactory::GetVarcharValue(sequence_name).Copy()); + + // the result is a vector of executor::LogicalTile + auto result_tiles = + GetResultWithIndexScan(column_ids, index_offset, values, txn); + // carefull! the result tile could be null! + if (result_tiles == nullptr || result_tiles->size() == 0) { + LOG_INFO("no sequence on database %d and %s", database_oid, + sequence_name.c_str()); + return std::shared_ptr(nullptr); + } else { + LOG_INFO("size of the result tiles = %lu", result_tiles->size()); + } + + PELOTON_ASSERT(result_tiles->size() == 1); + size_t tuple_count = (*result_tiles)[0]->GetTupleCount(); + PELOTON_ASSERT(tuple_count == 1); + auto new_sequence = std::make_shared( + (*result_tiles)[0]->GetValue(0, 0).GetAs(), + (*result_tiles)[0]->GetValue(0, 1).ToString(), + (*result_tiles)[0]->GetValue(0, 2).GetAs(), + (*result_tiles)[0]->GetValue(0, 3).GetAs(), + (*result_tiles)[0]->GetValue(0, 4).GetAs(), + (*result_tiles)[0]->GetValue(0, 5).GetAs(), + (*result_tiles)[0]->GetValue(0, 6).GetAs(), + (*result_tiles)[0]->GetValue(0, 7).GetAs(), txn); + + return new_sequence; +} + +bool SequenceCatalog::UpdateNextVal(oid_t sequence_oid, int64_t nextval, + concurrency::TransactionContext *txn){ + std::vector update_columns({SequenceCatalog::ColumnId::SEQUENCE_VALUE}); + std::vector update_values; + update_values.push_back(type::ValueFactory::GetBigIntValue(nextval).Copy()); + std::vector scan_values; + scan_values.push_back(type::ValueFactory::GetIntegerValue(sequence_oid).Copy()); + oid_t index_offset = SequenceCatalog::IndexId::PRIMARY_KEY; + + return UpdateWithIndexScan(update_columns, update_values, scan_values, index_offset, txn); +} + +/* @brief get sequence oid from pg_sequence table given sequence_name and + * database_oid + * @param database_oid the databse_oid associated with the sequence + * @param sequence_name the name of the sequence + * @param txn current transaction + * @return the oid_t of the sequence if the sequence is found, INVALID_OID + * otherwise + */ +oid_t SequenceCatalog::GetSequenceOid(std::string sequence_name, + oid_t database_oid, + concurrency::TransactionContext *txn) { + std::vector column_ids({ColumnId::SEQUENCE_OID}); + oid_t index_offset = IndexId::DBOID_SEQNAME_KEY; + std::vector values; + values.push_back(type::ValueFactory::GetIntegerValue(database_oid).Copy()); + values.push_back(type::ValueFactory::GetVarcharValue(sequence_name).Copy()); + + // the result is a vector of executor::LogicalTile + auto result_tiles = + GetResultWithIndexScan(column_ids, index_offset, values, txn); + // carefull! the result tile could be null! + if (result_tiles == nullptr || result_tiles->size() == 0) { + LOG_INFO("no sequence on database %d and %s", database_oid, + sequence_name.c_str()); + return INVALID_OID; + } + + PELOTON_ASSERT(result_tiles->size() == 1); + oid_t result; + result = (*result_tiles)[0]->GetValue(0, 0).GetAs(); + + return result; +} + +} // namespace catalog +} // namespace peloton diff --git a/src/codegen/proxy/string_functions_proxy.cpp b/src/codegen/proxy/string_functions_proxy.cpp index 46a356b61dd..d8786cc5b29 100644 --- a/src/codegen/proxy/string_functions_proxy.cpp +++ b/src/codegen/proxy/string_functions_proxy.cpp @@ -33,5 +33,9 @@ DEFINE_METHOD(peloton::function, StringFunctions, Trim); DEFINE_METHOD(peloton::function, StringFunctions, LTrim); DEFINE_METHOD(peloton::function, StringFunctions, RTrim); +// Sequence-related functions +DEFINE_METHOD(peloton::function, StringFunctions, Nextval); +DEFINE_METHOD(peloton::function, StringFunctions, Currval); + } // namespace codegen } // namespace peloton diff --git a/src/codegen/type/varchar_type.cpp b/src/codegen/type/varchar_type.cpp index 0066457e425..52716c60d66 100644 --- a/src/codegen/type/varchar_type.cpp +++ b/src/codegen/type/varchar_type.cpp @@ -187,6 +187,46 @@ struct Trim : public TypeSystem::UnaryOperatorHandleNull { } }; +// Nextval +struct Nextval : public TypeSystem::UnaryOperatorHandleNull { + bool SupportsType(const Type &type) const override { + return type.GetSqlType() == Varchar::Instance(); + } + + Type ResultType(UNUSED_ATTRIBUTE const Type &val_type) const override { + return Integer::Instance(); + } + + Value Impl(CodeGen &codegen, const Value &val, + const TypeSystem::InvocationContext &ctx) const override { + llvm::Value *executor_ctx = ctx.executor_context; + llvm::Value *raw_ret = + codegen.Call(StringFunctionsProxy::Nextval, + {executor_ctx, val.GetValue()}); + return Value{Integer::Instance(), raw_ret}; + } +}; + +// Currval +struct Currval : public TypeSystem::UnaryOperatorHandleNull { + bool SupportsType(const Type &type) const override { + return type.GetSqlType() == Varchar::Instance(); + } + + Type ResultType(UNUSED_ATTRIBUTE const Type &val_type) const override { + return Integer::Instance(); + } + + Value Impl(CodeGen &codegen, const Value &val, + const TypeSystem::InvocationContext &ctx) const override { + llvm::Value *executor_ctx = ctx.executor_context; + llvm::Value *raw_ret = + codegen.Call(StringFunctionsProxy::Currval, + {executor_ctx, val.GetValue()}); + return Value{Integer::Instance(), raw_ret}; + } +}; + //////////////////////////////////////////////////////////////////////////////// /// /// Binary operators @@ -536,10 +576,14 @@ std::vector kComparisonTable = {{kCompareVarchar}}; Ascii kAscii; Length kLength; Trim kTrim; +Nextval kNextval; +Currval kCurrval; std::vector kUnaryOperatorTable = { {OperatorId::Ascii, kAscii}, {OperatorId::Length, kLength}, - {OperatorId::Trim, kTrim}}; + {OperatorId::Trim, kTrim}, + {OperatorId::Nextval, kNextval}, + {OperatorId::Currval, kCurrval}}; // Binary operations Like kLike; diff --git a/src/common/internal_types.cpp b/src/common/internal_types.cpp index 125d719b5d7..df58fead3e7 100644 --- a/src/common/internal_types.cpp +++ b/src/common/internal_types.cpp @@ -329,6 +329,9 @@ std::string CreateTypeToString(CreateType type) { case CreateType::TRIGGER: { return "TRIGGER"; } + case CreateType::SEQUENCE: { + return "SEQUENCE"; + } default: { throw ConversionException( StringUtil::Format("No string conversion for CreateType value '%d'", @@ -680,6 +683,9 @@ QueryType StatementTypeToQueryType(StatementType stmt_type, case parser::CreateStatement::CreateType::kView: query_type = QueryType::QUERY_CREATE_VIEW; break; + case parser::CreateStatement::CreateType::kSequence: + query_type = QueryType::QUERY_CREATE_SEQUENCE; + break; } break; } diff --git a/src/executor/create_executor.cpp b/src/executor/create_executor.cpp index 22c4eb77e2c..748a27921f6 100644 --- a/src/executor/create_executor.cpp +++ b/src/executor/create_executor.cpp @@ -14,6 +14,7 @@ #include "catalog/catalog.h" #include "catalog/foreign_key.h" +#include "catalog/sequence_catalog.h" #include "catalog/trigger_catalog.h" #include "catalog/database_catalog.h" #include "catalog/table_catalog.h" @@ -71,6 +72,12 @@ bool CreateExecutor::DExecute() { break; } + // if query was for creating sequence + case CreateType::SEQUENCE: { + result = CreateSequence(node); + break; + } + default: { std::string create_type = CreateTypeToString(node.GetCreateType()); LOG_ERROR("Not supported create type %s", create_type.c_str()); @@ -87,7 +94,6 @@ bool CreateExecutor::DExecute() { } bool CreateExecutor::CreateDatabase(const planner::CreatePlan &node) { - auto txn = context_->GetTransaction(); auto database_name = node.GetDatabaseName(); ResultType result = catalog::Catalog::GetInstance()->CreateDatabase( @@ -270,6 +276,33 @@ bool CreateExecutor::CreateTrigger(const planner::CreatePlan &node) { return (true); } +bool CreateExecutor::CreateSequence(const planner::CreatePlan &node) { + auto txn = context_->GetTransaction(); + std::string database_name = node.GetDatabaseName(); + std::string table_name = node.GetTableName(); + std::string sequence_name = node.GetSequenceName(); + + auto database_object = catalog::Catalog::GetInstance()->GetDatabaseObject( + database_name, txn); + + catalog::SequenceCatalog::GetInstance().InsertSequence( + database_object->GetDatabaseOid(), sequence_name, + node.GetSequenceIncrement(), node.GetSequenceMaxValue(), + node.GetSequenceMinValue(), node.GetSequenceStart(), + node.GetSequenceCycle(), pool_.get(), txn); + + if (txn->GetResult() == ResultType::SUCCESS) { + LOG_DEBUG("Creating sequence succeeded!"); + } else if (txn->GetResult() == ResultType::FAILURE) { + LOG_DEBUG("Creating sequence failed!"); + } else { + LOG_DEBUG("Result is: %s", + ResultTypeToString(txn->GetResult()).c_str()); + } + + return (true); +} + } // namespace executor } // namespace peloton diff --git a/src/executor/executor_context.cpp b/src/executor/executor_context.cpp index ae9281c13fe..72071dccfa9 100644 --- a/src/executor/executor_context.cpp +++ b/src/executor/executor_context.cpp @@ -16,12 +16,17 @@ #include "executor/executor_context.h" #include "concurrency/transaction_context.h" + namespace peloton { namespace executor { ExecutorContext::ExecutorContext(concurrency::TransactionContext *transaction, - codegen::QueryParameters parameters) - : transaction_(transaction), parameters_(std::move(parameters)) {} + codegen::QueryParameters parameters, + const std::string default_database_name) + : transaction_(transaction), parameters_(std::move(parameters)), + default_database_name_(default_database_name) { + LOG_DEBUG("ExecutorContext default db name: %s", default_database_name.c_str()); +} concurrency::TransactionContext *ExecutorContext::GetTransaction() const { return transaction_; @@ -43,5 +48,9 @@ type::EphemeralPool *ExecutorContext::GetPool() { return pool_.get(); } +std::string ExecutorContext::GetDatabaseName() const { + return default_database_name_; +} + } // namespace executor } // namespace peloton diff --git a/src/executor/plan_executor.cpp b/src/executor/plan_executor.cpp index 104aff1351c..414aa703603 100644 --- a/src/executor/plan_executor.cpp +++ b/src/executor/plan_executor.cpp @@ -36,9 +36,10 @@ static void CompileAndExecutePlan( std::shared_ptr plan, concurrency::TransactionContext *txn, const std::vector ¶ms, - std::function &&)> - on_complete) { + std::function &&)> on_complete, + std::string default_database_name) { LOG_TRACE("Compiling and executing query ..."); + LOG_DEBUG("CompileAndExecutePlan default db name: %s", default_database_name.c_str()); // Perform binding planner::BindingContext context; @@ -51,7 +52,8 @@ static void CompileAndExecutePlan( std::unique_ptr executor_context( new executor::ExecutorContext(txn, - codegen::QueryParameters(*plan, params))); + codegen::QueryParameters(*plan, params), + default_database_name)); // Compile the query codegen::Query *query = codegen::QueryCache::Instance().Find(plan); @@ -87,12 +89,14 @@ static void InterpretPlan( const std::vector ¶ms, const std::vector &result_format, std::function &&)> - on_complete) { + on_complete, + std::string default_database_name) { executor::ExecutionResult result; std::vector values; + LOG_DEBUG("InterpretPlan default db name: %s", default_database_name.c_str()); std::unique_ptr executor_context( - new executor::ExecutorContext(txn, params)); + new executor::ExecutorContext(txn, params, default_database_name)); bool status; std::unique_ptr executor_tree( @@ -142,19 +146,20 @@ void PlanExecutor::ExecutePlan( concurrency::TransactionContext *txn, const std::vector ¶ms, const std::vector &result_format, - std::function &&)> - on_complete) { + std::function &&)> on_complete, + std::string default_database_name) { PELOTON_ASSERT(plan != nullptr && txn != nullptr); LOG_TRACE("PlanExecutor Start (Txn ID=%" PRId64 ")", txn->GetTransactionId()); + LOG_DEBUG("PlanExecutor 1 default db name: %s", default_database_name.c_str()); bool codegen_enabled = settings::SettingsManager::GetBool(settings::SettingId::codegen); try { if (codegen_enabled && codegen::QueryCompiler::IsSupported(*plan)) { - CompileAndExecutePlan(plan, txn, params, on_complete); + CompileAndExecutePlan(plan, txn, params, on_complete, default_database_name); } else { - InterpretPlan(plan, txn, params, result_format, on_complete); + InterpretPlan(plan, txn, params, result_format, on_complete, default_database_name); } } catch (Exception &e) { ExecutionResult result; @@ -181,6 +186,7 @@ int PlanExecutor::ExecutePlan( std::vector> &logical_tile_list) { PELOTON_ASSERT(plan != nullptr); LOG_TRACE("PlanExecutor Start with transaction"); + LOG_DEBUG("PlanExecutor 2"); auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); auto txn = txn_manager.BeginTransaction(); diff --git a/src/expression/function_expression.cpp b/src/expression/function_expression.cpp index 9f74f30e475..1ab6d28af31 100644 --- a/src/expression/function_expression.cpp +++ b/src/expression/function_expression.cpp @@ -42,13 +42,15 @@ expression::FunctionExpression::FunctionExpression( type::Value FunctionExpression::Evaluate( const AbstractTuple *tuple1, const AbstractTuple *tuple2, - UNUSED_ATTRIBUTE executor::ExecutorContext *context) const { + executor::ExecutorContext *context) const { std::vector child_values; PELOTON_ASSERT(func_.impl != nullptr); for (auto &child : children_) { child_values.push_back(child->Evaluate(tuple1, tuple2, context)); } + uint64_t ctx = (uint64_t)context; + child_values.push_back(type::ValueFactory::GetBigIntValue(ctx)); type::Value ret = func_.impl(child_values); diff --git a/src/function/old_engine_string_functions.cpp b/src/function/old_engine_string_functions.cpp index 8add85a1fe1..96fe5a30e35 100644 --- a/src/function/old_engine_string_functions.cpp +++ b/src/function/old_engine_string_functions.cpp @@ -235,5 +235,18 @@ type::Value OldEngineStringFunctions::Lower( throw Exception{"Lower not implemented in old engine"}; } +type::Value OldEngineStringFunctions::Nextval( + UNUSED_ATTRIBUTE const std::vector &args) { + executor::ExecutorContext* ctx=(executor::ExecutorContext*)args[1].GetAs(); + uint32_t ret = StringFunctions::Nextval(*ctx, args[0].GetAs()); + return type::ValueFactory::GetIntegerValue(ret); +} + +type::Value OldEngineStringFunctions::Currval( + UNUSED_ATTRIBUTE const std::vector &args) { + executor::ExecutorContext* ctx=(executor::ExecutorContext*)args[1].GetAs(); + uint32_t ret = StringFunctions::Currval(*ctx, args[0].GetAs()); + return type::ValueFactory::GetIntegerValue(ret); +} } // namespace function } // namespace peloton diff --git a/src/function/string_functions.cpp b/src/function/string_functions.cpp index 841a9ee6e15..5b739bcd40f 100644 --- a/src/function/string_functions.cpp +++ b/src/function/string_functions.cpp @@ -14,6 +14,9 @@ #include "common/macros.h" #include "executor/executor_context.h" +#include "catalog/catalog.h" +#include "catalog/database_catalog.h" +#include "catalog/sequence_catalog.h" namespace peloton { namespace function { @@ -220,5 +223,36 @@ uint32_t StringFunctions::Length( return length; } +uint32_t StringFunctions::Nextval(executor::ExecutorContext &ctx, const char *sequence_name) { + PELOTON_ASSERT(sequence_name != nullptr); + auto database_object = + catalog::Catalog::GetInstance() + ->GetDatabaseObject(ctx.GetDatabaseName(), ctx.GetTransaction()); + catalog::SequenceCatalogObject* sequence_object = catalog::SequenceCatalog::GetInstance(). + GetSequence(database_object->GetDatabaseOid(), sequence_name, ctx.GetTransaction()).get(); + if (sequence_object != nullptr) { + return sequence_object->GetNextVal(); + } else { + throw SequenceException( + StringUtil::Format("Sequence not exists!")); + } +} + + +uint32_t StringFunctions::Currval(executor::ExecutorContext &ctx, const char *sequence_name) { + PELOTON_ASSERT(sequence_name != nullptr); + auto database_object = + catalog::Catalog::GetInstance() + ->GetDatabaseObject(ctx.GetDatabaseName(), ctx.GetTransaction()); + catalog::SequenceCatalogObject* sequence_object = catalog::SequenceCatalog::GetInstance(). + GetSequence(database_object->GetDatabaseOid(), sequence_name, ctx.GetTransaction()).get(); + if (sequence_object != nullptr) { + return sequence_object->GetCurrVal(); + } else { + throw SequenceException( + StringUtil::Format("Sequence not exists!")); + } +} + } // namespace function } // namespace peloton diff --git a/src/include/catalog/abstract_catalog.h b/src/include/catalog/abstract_catalog.h index 9acf67773b9..bf62b62e2d4 100644 --- a/src/include/catalog/abstract_catalog.h +++ b/src/include/catalog/abstract_catalog.h @@ -72,6 +72,11 @@ class AbstractCatalog { expression::AbstractExpression *predicate, concurrency::TransactionContext *txn); + bool UpdateWithIndexScan( + std::vector update_columns, std::vector update_values, + std::vector scan_values, oid_t index_offset, + concurrency::TransactionContext *txn); + void AddIndex(const std::vector &key_attrs, oid_t index_oid, const std::string &index_name, IndexConstraintType index_constraint); diff --git a/src/include/catalog/catalog_defaults.h b/src/include/catalog/catalog_defaults.h index 69834ec769a..e5697632bb2 100644 --- a/src/include/catalog/catalog_defaults.h +++ b/src/include/catalog/catalog_defaults.h @@ -44,6 +44,7 @@ namespace catalog { #define TRIGGER_OID_MASK (static_cast(catalog::CatalogType::TRIGGER)) #define LANGUAGE_OID_MASK (static_cast(catalog::CatalogType::LANGUAGE)) #define PROC_OID_MASK (static_cast(catalog::CatalogType::PROC)) +#define SEQUENCE_OID_MASK (static_cast(catalog::CatalogType::SEQUENCE)) // Reserved pg_catalog database oid #define CATALOG_DATABASE_OID (0 | DATABASE_OID_MASK) @@ -76,6 +77,7 @@ enum class CatalogType : uint32_t { TRIGGER = 5 << CATALOG_TYPE_OFFSET, LANGUAGE = 6 << CATALOG_TYPE_OFFSET, PROC = 7 << CATALOG_TYPE_OFFSET, + SEQUENCE = 8 << CATALOG_TYPE_OFFSET, // To be added }; diff --git a/src/include/catalog/sequence_catalog.h b/src/include/catalog/sequence_catalog.h new file mode 100644 index 00000000000..9b26a647647 --- /dev/null +++ b/src/include/catalog/sequence_catalog.h @@ -0,0 +1,156 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// sequence_catalog.h +// +// Identification: src/include/catalog/sequence_catalog.h +// +// Copyright (c) 2015-17, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// pg_trigger +// +// Schema: (column offset: column_name) +// 0: oid (pkey) +// 1: sqdboid : database_oid +// 2: sqname : sequence_name +// 3: sqinc : seq_increment +// 4: sqmax : seq_max +// 5: sqmin : seq_min +// 6: sqstart : seq_start +// 7: sqcycle : seq_cycle +// 7: sqval : seq_value +// +// Indexes: (index offset: indexed columns) +// 0: oid (primary key) +// 1: (sqdboid, sqname) (secondary key 0) +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include + +#include "catalog/abstract_catalog.h" +#include "catalog/catalog_defaults.h" + +#define SEQUENCE_CATALOG_NAME "pg_sequence" + +namespace peloton { + +namespace concurrency { +class TransactionContext; +} + +namespace catalog { + +class SequenceCatalogObject { + public: + SequenceCatalogObject(oid_t seqoid, const std::string &name, + const int64_t seqstart, const int64_t seqincrement, + const int64_t seqmax, const int64_t seqmin, + const bool seqcycle, const int64_t seqval, + concurrency::TransactionContext *txn) + : seq_oid(seqoid), + seq_name(name), + seq_start(seqstart), + seq_increment(seqincrement), + seq_max(seqmax), + seq_min(seqmin), + seq_cycle(seqcycle), + txn_(txn), + seq_curr_val(seqval){}; + + oid_t seq_oid; + std::string seq_name; + int64_t seq_start; // Start value of the sequence + int64_t seq_increment; // Increment value of the sequence + int64_t seq_max; // Maximum value of the sequence + int64_t seq_min; // Minimum value of the sequence + int64_t seq_cache; // Cache size of the sequence + bool seq_cycle; // Whether the sequence cycles + concurrency::TransactionContext *txn_; + + std::mutex sequence_mutex; // mutex for all operations + int64_t GetNextVal() { + std::lock_guard lock(sequence_mutex); + return get_next_val(); + }; + + int64_t GetCurrVal() { + std::lock_guard lock(sequence_mutex); + return seq_curr_val; + }; + + void SetCurrVal(int64_t curr_val) { + seq_curr_val = curr_val; + }; // only visible for test! + void SetCycle(bool cycle) { seq_cycle = cycle; }; + + private: + int64_t seq_curr_val; + int64_t get_next_val(); +}; + +class SequenceCatalog : public AbstractCatalog { + public: + ~SequenceCatalog(); + + // Global Singleton + static SequenceCatalog &GetInstance( + concurrency::TransactionContext *txn = nullptr); + + //===--------------------------------------------------------------------===// + // write Related API + //===--------------------------------------------------------------------===// + bool InsertSequence(oid_t database_oid, std::string sequence_name, + int64_t seq_increment, int64_t seq_max, int64_t seq_min, + int64_t seq_start, bool seq_cycle, + type::AbstractPool *pool, + concurrency::TransactionContext *txn); + + ResultType DropSequence(const std::string &database_name, + const std::string &sequence_name, + concurrency::TransactionContext *txn); + + bool DeleteSequenceByName(const std::string &sequence_name, + oid_t database_oid, + concurrency::TransactionContext *txn); + + std::shared_ptr GetSequence( + oid_t database_oid, const std::string &sequence_name, + concurrency::TransactionContext *txn); + + oid_t GetSequenceOid(std::string sequence_name, oid_t database_oid, + concurrency::TransactionContext *txn); + + bool UpdateNextVal(oid_t sequence_oid, int64_t nextval, + concurrency::TransactionContext *txn); + + enum ColumnId { + SEQUENCE_OID = 0, + DATABSE_OID = 1, + SEQUENCE_NAME = 2, + SEQUENCE_INC = 3, + SEQUENCE_MAX = 4, + SEQUENCE_MIN = 5, + SEQUENCE_START = 6, + SEQUENCE_CYCLE = 7, + SEQUENCE_VALUE = 8 + }; + + enum IndexId { PRIMARY_KEY = 0, DBOID_SEQNAME_KEY = 1 }; + + private: + SequenceCatalog(concurrency::TransactionContext *txn); + + oid_t GetNextOid() { return oid_++ | SEQUENCE_OID_MASK; } +}; + +} // namespace catalog +} // namespace peloton diff --git a/src/include/codegen/proxy/string_functions_proxy.h b/src/include/codegen/proxy/string_functions_proxy.h index 1862db23db9..5868229e13b 100644 --- a/src/include/codegen/proxy/string_functions_proxy.h +++ b/src/include/codegen/proxy/string_functions_proxy.h @@ -31,6 +31,10 @@ PROXY(StringFunctions) { DECLARE_METHOD(RTrim); DECLARE_METHOD(Substr); DECLARE_METHOD(Repeat); + + // Sequence-related functions + DECLARE_METHOD(Nextval); + DECLARE_METHOD(Currval); }; PROXY(StrWithLen) { diff --git a/src/include/common/exception.h b/src/include/common/exception.h index 4c201891751..d8e82565da1 100644 --- a/src/include/common/exception.h +++ b/src/include/common/exception.h @@ -59,7 +59,8 @@ enum class ExceptionType { SETTINGS = 23, // settings related BINDER = 24, // binder related NETWORK = 25, // network related - OPTIMIZER = 26 // optimizer related + OPTIMIZER = 26, // optimizer related + SEQUENCE = 27 // sequence related }; class Exception : public std::runtime_error { @@ -76,9 +77,7 @@ class Exception : public std::runtime_error { "\nMessage :: " + message; } - std::string GetMessage() { - return exception_message_; - } + std::string GetMessage() { return exception_message_; } std::string ExceptionTypeToString(ExceptionType type) { switch (type) { @@ -132,6 +131,8 @@ class Exception : public std::runtime_error { return "Settings"; case ExceptionType::OPTIMIZER: return "Optimizer"; + case ExceptionType::SEQUENCE: + return "Sequence"; default: return "Unknown"; } @@ -467,4 +468,12 @@ class OptimizerException : public Exception { : Exception(ExceptionType::OPTIMIZER, msg) {} }; +class SequenceException : public Exception { + SequenceException() = delete; + + public: + SequenceException(std::string msg) + : Exception(ExceptionType::SEQUENCE, msg) {} +}; + } // namespace peloton diff --git a/src/include/common/internal_types.h b/src/include/common/internal_types.h index bf4bc20398a..4ffad744001 100644 --- a/src/include/common/internal_types.h +++ b/src/include/common/internal_types.h @@ -611,7 +611,8 @@ enum class CreateType { TABLE = 2, // table create type INDEX = 3, // index create type CONSTRAINT = 4, // constraint create type - TRIGGER = 5 // trigger create type + TRIGGER = 5, // trigger create type + SEQUENCE = 6 }; std::string CreateTypeToString(CreateType type); CreateType StringToCreateType(const std::string &str); @@ -707,7 +708,8 @@ enum class QueryType { QUERY_CREATE_TRIGGER = 21, QUERY_CREATE_SCHEMA = 22, QUERY_CREATE_VIEW = 23, - QUERY_EXPLAIN = 24 + QUERY_EXPLAIN = 24, + QUERY_CREATE_SEQUENCE = 25 }; std::string QueryTypeToString(QueryType query_type); QueryType StringToQueryType(std::string str); @@ -1089,6 +1091,8 @@ enum class OperatorId : uint32_t { DateTrunc, Like, Now, + Nextval, + Currval, // Add more operators here, before the last "Invalid" entry Invalid diff --git a/src/include/executor/create_executor.h b/src/include/executor/create_executor.h index 1802081ac62..fbdfd5ab7a0 100644 --- a/src/include/executor/create_executor.h +++ b/src/include/executor/create_executor.h @@ -52,6 +52,8 @@ class CreateExecutor : public AbstractExecutor { bool CreateTrigger(const planner::CreatePlan &node); + bool CreateSequence(const planner::CreatePlan &node); + private: ExecutorContext *context_; diff --git a/src/include/executor/executor_context.h b/src/include/executor/executor_context.h index 79cfe5cd19b..0e02a3e6d2a 100644 --- a/src/include/executor/executor_context.h +++ b/src/include/executor/executor_context.h @@ -30,7 +30,8 @@ namespace executor { class ExecutorContext { public: explicit ExecutorContext(concurrency::TransactionContext *transaction, - codegen::QueryParameters parameters = {}); + codegen::QueryParameters parameters = {}, + std::string default_database_name = ""); DISALLOW_COPY_AND_MOVE(ExecutorContext); @@ -44,6 +45,8 @@ class ExecutorContext { type::EphemeralPool *GetPool(); + std::string GetDatabaseName() const; + // Number of processed tuples during execution uint32_t num_processed = 0; @@ -54,6 +57,8 @@ class ExecutorContext { codegen::QueryParameters parameters_; // Temporary memory pool for allocations done during execution std::unique_ptr pool_; + // Default database name + std::string default_database_name_; }; } // namespace executor diff --git a/src/include/executor/plan_executor.h b/src/include/executor/plan_executor.h index 49d32b98b71..c02e3e32581 100644 --- a/src/include/executor/plan_executor.h +++ b/src/include/executor/plan_executor.h @@ -61,7 +61,8 @@ class PlanExecutor { const std::vector ¶ms, const std::vector &result_format, std::function &&)> on_complete); + std::vector &&)> on_complete, + std::string default_database_name); /* * @brief When a peloton node recvs a query plan, this function is invoked diff --git a/src/include/function/old_engine_string_functions.h b/src/include/function/old_engine_string_functions.h index 7603ac14fd0..f8748e9ae76 100644 --- a/src/include/function/old_engine_string_functions.h +++ b/src/include/function/old_engine_string_functions.h @@ -68,6 +68,10 @@ class OldEngineStringFunctions { // Upper, Lower static type::Value Upper(const std::vector &args); static type::Value Lower(const std::vector &args); + + // Sequence-related + static type::Value Nextval(const std::vector &args); + static type::Value Currval(const std::vector &args); }; } // namespace function diff --git a/src/include/function/string_functions.h b/src/include/function/string_functions.h index 2a209d0dee6..1d41551cef1 100644 --- a/src/include/function/string_functions.h +++ b/src/include/function/string_functions.h @@ -74,6 +74,12 @@ class StringFunctions { // Length will return the number of characters in the given string static uint32_t Length(executor::ExecutorContext &ctx, const char *str, uint32_t length); + + // Nextval will return the next value of the given sequence + static uint32_t Nextval(executor::ExecutorContext &ctx, const char *sequence_name); + + // Currval will return the current value of the given sequence + static uint32_t Currval(executor::ExecutorContext &ctx, const char *sequence_name); }; } // namespace function diff --git a/src/include/parser/create_statement.h b/src/include/parser/create_statement.h index 1abf45fa968..1b9dce7d88e 100644 --- a/src/include/parser/create_statement.h +++ b/src/include/parser/create_statement.h @@ -13,6 +13,8 @@ #pragma once #include +#include + #include "common/sql_node_visitor.h" #include "expression/abstract_expression.h" #include "parser/sql_statement.h" @@ -215,7 +217,15 @@ struct ColumnDefinition { */ class CreateStatement : public TableRefStatement { public: - enum CreateType { kTable, kDatabase, kIndex, kTrigger, kSchema, kView }; + enum CreateType { + kTable, + kDatabase, + kIndex, + kTrigger, + kSchema, + kView, + kSequence + }; CreateStatement(CreateType type) : TableRefStatement(StatementType::CREATE), @@ -254,6 +264,17 @@ class CreateStatement : public TableRefStatement { std::unique_ptr trigger_when; int16_t trigger_type; // information about row, timing, events, access by // pg_trigger + + // attributes related to sequences + std::string sequence_name; + std::unique_ptr table; // deal with RangeVar + int64_t seq_start = 1; + int64_t seq_increment = 1; + int64_t seq_max_value = LONG_MAX; + int64_t seq_min_value = 1; + int64_t seq_cache; // sequence cache size, probably won't be supported in + // this project + bool seq_cycle = false; }; } // namespace parser diff --git a/src/include/parser/parsenodes.h b/src/include/parser/parsenodes.h index bf818ff6b86..a43c937076c 100644 --- a/src/include/parser/parsenodes.h +++ b/src/include/parser/parsenodes.h @@ -730,6 +730,16 @@ typedef struct CreateSchemaStmt bool if_not_exists; /* just do nothing if schema already exists? */ } CreateSchemaStmt; +typedef struct CreateSeqStmt +{ + NodeTag type; + RangeVar *sequence; /* the sequence to create */ + List *options; + Oid ownerId; /* ID of owner, or InvalidOid for default */ + bool for_identity; + bool if_not_exists; /* just do nothing if it already exists? */ +} CreateSeqStmt; + typedef enum RoleSpecType { ROLESPEC_CSTRING, /* role name is stored as a C string */ diff --git a/src/include/parser/postgresparser.h b/src/include/parser/postgresparser.h index decd43d9ee7..3827dcf52d1 100644 --- a/src/include/parser/postgresparser.h +++ b/src/include/parser/postgresparser.h @@ -118,8 +118,8 @@ class PostgresParser { static parser::TableRef *FromTransform(SelectStmt *root); // transform helper for select targets - static std::vector> - *TargetTransform(List *root); + static std::vector> * + TargetTransform(List *root); // transform helper for all expr nodes static expression::AbstractExpression *ExprTransform(Node *root); @@ -167,7 +167,8 @@ class PostgresParser { static parser::OrderDescription *OrderByTransform(List *order); // transform helper for table column definitions - static void ColumnDefTransform(ColumnDef* root, parser::CreateStatement* stmt); + static void ColumnDefTransform(ColumnDef *root, + parser::CreateStatement *stmt); // transform helper for create statements static parser::SQLStatement *CreateTransform(CreateStmt *root); @@ -195,7 +196,8 @@ class PostgresParser { * @param Postgres CreateDatabaseStmt parsenode * @return a peloton CreateStatement node */ - static parser::SQLStatement *CreateDatabaseTransform(CreateDatabaseStmt *root); + static parser::SQLStatement *CreateDatabaseTransform( + CreateDatabaseStmt *root); // transform helper for create schema statements static parser::SQLStatement *CreateSchemaTransform(CreateSchemaStmt *root); @@ -203,13 +205,15 @@ class PostgresParser { // transform helper for create view statements static parser::SQLStatement *CreateViewTransform(ViewStmt *root); + static parser::SQLStatement *CreateSequenceTransform(CreateSeqStmt *root); + // transform helper for column name (for insert statement) static std::vector *ColumnNameTransform(List *root); // transform helper for ListsTransform (insert multiple rows) static std::vector< - std::vector>> - *ValueListsTransform(List *root); + std::vector>> * + ValueListsTransform(List *root); // transform helper for insert statements static parser::SQLStatement *InsertTransform(InsertStmt *root); @@ -233,8 +237,8 @@ class PostgresParser { static parser::UpdateStatement *UpdateTransform(UpdateStmt *update_stmt); // transform helper for update statement - static std::vector> - *UpdateTargetTransform(List *root); + static std::vector> * + UpdateTargetTransform(List *root); // transform helper for drop statement static parser::DropStatement *DropTransform(DropStmt *root); @@ -282,13 +286,20 @@ class PostgresParser { static parser::CopyStatement *CopyTransform(CopyStmt *root); // transform helper for analyze statement - static parser::AnalyzeStatement *VacuumTransform(VacuumStmt* root); + static parser::AnalyzeStatement *VacuumTransform(VacuumStmt *root); - static parser::VariableSetStatement *VariableSetTransform(VariableSetStmt* root); + static parser::VariableSetStatement *VariableSetTransform( + VariableSetStmt *root); // transform helper for subquery expressions static expression::AbstractExpression *SubqueryExprTransform(SubLink *node); + static void parse_sequence_params(List *options, + parser::CreateStatement *result); + + static int64_t get_long_in_defel(DefElem *defel) { + return (int64_t)((reinterpret_cast(defel->arg))->val.ival); + }; }; } // namespace parser diff --git a/src/include/planner/create_plan.h b/src/include/planner/create_plan.h index e0e9f84add8..20973b4a1d7 100644 --- a/src/include/planner/create_plan.h +++ b/src/include/planner/create_plan.h @@ -50,7 +50,7 @@ struct ForeignKeyInfo { class CreatePlan : public AbstractPlan { public: CreatePlan() = delete; - + // This construnctor is for Create Database Test used only explicit CreatePlan(std::string database_name, CreateType c_type); @@ -86,7 +86,9 @@ class CreatePlan : public AbstractPlan { std::vector GetIndexAttributes() const { return index_attrs; } - inline std::vector GetForeignKeys() const { return foreign_keys; } + inline std::vector GetForeignKeys() const { + return foreign_keys; + } std::vector GetKeyAttrs() const { return key_attrs; } void SetKeyAttrs(std::vector p_key_attrs) { key_attrs = p_key_attrs; } @@ -108,11 +110,19 @@ class CreatePlan : public AbstractPlan { int16_t GetTriggerType() const { return trigger_type; } -protected: - // This is a helper method for extracting foreign key information - // and storing it in an internal struct. - void ProcessForeignKeyConstraint(const std::string &table_name, - const parser::ColumnDefinition *col); + std::string GetSequenceName() const { return sequence_name; } + int64_t GetSequenceStart() const { return seq_start; }; + int64_t GetSequenceIncrement() const { return seq_increment; } + int64_t GetSequenceMaxValue() const { return seq_max_value; } + int64_t GetSequenceMinValue() const { return seq_min_value; } + int64_t GetSequenceCacheSize() const { return seq_cache; } + bool GetSequenceCycle() const { return seq_cycle; } + + protected: + // This is a helper method for extracting foreign key information + // and storing it in an internal struct. + void ProcessForeignKeyConstraint(const std::string &table_name, + const parser::ColumnDefinition *col); private: // Table Name @@ -150,6 +160,15 @@ class CreatePlan : public AbstractPlan { int16_t trigger_type; // information about row, timing, events, access by // pg_trigger + // information for sequences; + std::string sequence_name; + int64_t seq_start; + int64_t seq_increment; + int64_t seq_max_value; + int64_t seq_min_value; + int64_t seq_cache; // sequence cache size, not supported yet + bool seq_cycle; + private: DISALLOW_COPY_AND_MOVE(CreatePlan); }; diff --git a/src/parser/create_statement.cpp b/src/parser/create_statement.cpp index dd52ee70433..c150ed4d855 100644 --- a/src/parser/create_statement.cpp +++ b/src/parser/create_statement.cpp @@ -28,7 +28,8 @@ const std::string CreateStatement::GetInfo(int num_indent) const { << StringUtil::Format("IF NOT EXISTS: %s", (if_not_exists) ? "True" : "False") << std::endl; os << StringUtil::Indent(num_indent + 1) - << StringUtil::Format("Table name: %s", GetTableName().c_str());; + << StringUtil::Format("Table name: %s", GetTableName().c_str()); + ; break; } case CreateStatement::CreateType::kDatabase: { @@ -70,6 +71,12 @@ const std::string CreateStatement::GetInfo(int num_indent) const { << StringUtil::Format("View name: %s", view_name.c_str()); break; } + case CreateStatement::CreateType::kSequence: { + os << "Create type: Sequence" << std::endl; + os << StringUtil::Indent(num_indent + 1) + << StringUtil::Format("Sequence name: %s", sequence_name.c_str()); + break; + } } os << std::endl; @@ -98,7 +105,7 @@ const std::string CreateStatement::GetInfo(int num_indent) const { << col->not_null << " primary : " << col->primary << " unique " << col->unique << " varlen " << col->varlen; } - os << std::endl; + os << std::endl; } } std::string info = os.str(); diff --git a/src/parser/postgresparser.cpp b/src/parser/postgresparser.cpp index 32e7a374e38..c407d68c3c3 100644 --- a/src/parser/postgresparser.cpp +++ b/src/parser/postgresparser.cpp @@ -286,12 +286,10 @@ expression::AbstractExpression *PostgresParser::ColumnRefTransform( ->val.str)); } else { result = new expression::TupleValueExpression( - std::string( - (reinterpret_cast(fields->head->next->data.ptr_value)) - ->val.str), - std::string( - (reinterpret_cast(fields->head->data.ptr_value)) - ->val.str)); + std::string((reinterpret_cast( + fields->head->next->data.ptr_value))->val.str), + std::string((reinterpret_cast( + fields->head->data.ptr_value))->val.str)); } break; } @@ -490,9 +488,8 @@ expression::AbstractExpression *PostgresParser::TypeCastTransform( } TypeName *type_name = root->typeName; - char *name = - (reinterpret_cast(type_name->names->tail->data.ptr_value) - ->val.str); + char *name = (reinterpret_cast( + type_name->names->tail->data.ptr_value)->val.str); type::VarlenType temp(StringToTypeId("INVALID")); result = new expression::ConstantValueExpression( temp.CastAs(source_value, ColumnDefinition::StrToValueType(name))); @@ -558,8 +555,8 @@ expression::AbstractExpression *PostgresParser::FuncCallTransform( // This function takes in the whereClause part of a Postgres SelectStmt // parsenode and transfers it into the select_list of a Peloton SelectStatement. // It checks the type of each target and call the corresponding helpers. -std::vector> - *PostgresParser::TargetTransform(List *root) { +std::vector> * +PostgresParser::TargetTransform(List *root) { // Statement like 'SELECT;' cannot detect by postgres parser and would lead to // null list if (root == nullptr) { @@ -865,9 +862,8 @@ expression::AbstractExpression *PostgresParser::WhenTransform(Node *root) { void PostgresParser::ColumnDefTransform(ColumnDef *root, parser::CreateStatement *stmt) { TypeName *type_name = root->typeName; - char *name = - (reinterpret_cast(type_name->names->tail->data.ptr_value) - ->val.str); + char *name = (reinterpret_cast( + type_name->names->tail->data.ptr_value)->val.str); parser::ColumnDefinition *result = nullptr; parser::ColumnDefinition::DataType data_type = @@ -1055,9 +1051,8 @@ parser::FuncParameter *PostgresParser::FunctionParameterTransform( FunctionParameter *root) { parser::FuncParameter::DataType data_type; TypeName *type_name = root->argType; - char *name = - (reinterpret_cast(type_name->names->tail->data.ptr_value) - ->val.str); + char *name = (reinterpret_cast( + type_name->names->tail->data.ptr_value)->val.str); parser::FuncParameter *result = nullptr; // Transform parameter type @@ -1346,6 +1341,85 @@ parser::SQLStatement *PostgresParser::CreateViewTransform(ViewStmt *root) { return result; } +parser::SQLStatement *PostgresParser::CreateSequenceTransform( + CreateSeqStmt *root) { + parser::CreateStatement *result = + new parser::CreateStatement(CreateStatement::kSequence); + result->sequence_name = std::string(root->sequence->relname); + result->table.reset( + RangeVarTransform(reinterpret_cast(root->sequence))); + parse_sequence_params(root->options, result); + return result; +} + +void PostgresParser::parse_sequence_params(List *options, + parser::CreateStatement *result) { + DefElem *start_value = NULL; + // DefElem *restart_value = NULL; + DefElem *increment_by = NULL; + DefElem *max_value = NULL; + DefElem *min_value = NULL; + DefElem *cache_value = NULL; + DefElem *is_cycled = NULL; + if (!options) return; + + ListCell *option; + for (option = options->head; option != NULL; option = lnext(option)) { + DefElem *defel = (DefElem *)lfirst(option); + + if (strcmp(defel->defname, "increment") == 0) { + if (increment_by) + throw ParserException( + "Redundant definition of increment in defining sequence"); + increment_by = defel; + result->seq_increment = get_long_in_defel(increment_by); + } else if (strcmp(defel->defname, "start") == 0) { + if (start_value) + throw ParserException( + "Redundant definition of start in defining sequence"); + start_value = defel; + result->seq_start = get_long_in_defel(start_value); + } else if (strcmp(defel->defname, "maxvalue") == 0) { + if (max_value) + throw ParserException( + "Redundant definition of max in defining sequence"); + max_value = defel; + result->seq_max_value = get_long_in_defel(max_value); + } else if (strcmp(defel->defname, "minvalue") == 0) { + if (min_value) + throw ParserException( + "Redundant definition of min in defining sequence"); + min_value = defel; + result->seq_min_value = get_long_in_defel(min_value); + } else if (strcmp(defel->defname, "cache") == 0) { + if (cache_value) + throw ParserException( + "Redundant definition of cache in defining sequence"); + cache_value = defel; + result->seq_cache = get_long_in_defel(cache_value); + } else if (strcmp(defel->defname, "cycle") == 0) { + if (is_cycled) + throw ParserException( + "Redundant definition of cycle in defining sequence"); + is_cycled = defel; + result->seq_cycle = (bool)get_long_in_defel(is_cycled); + } + // TODO: support owned_by + // else if (strcmp(defel->defname, "owned_by") == 0) + // { + // // if (*owned_by) + // // ereport(ERROR, + // // (errcode(ERRCODE_SYNTAX_ERROR), + // // errmsg("conflicting or redundant options"), + // // parser_errposition(pstate, defel->location))); + // *owned_by = defGetQualifiedName(defel); + // } + else + throw ParserException( + StringUtil::Format("option \"%s\" not recognized\n", defel->defname)); + } +} + parser::DropStatement *PostgresParser::DropTransform(DropStmt *root) { switch (root->removeType) { case ObjectType::OBJECT_TABLE: @@ -1515,8 +1589,8 @@ std::vector *PostgresParser::ColumnNameTransform(List *root) { // parsenode and transfers it into Peloton AbstractExpression. // This is a vector pointer of vector pointers because one InsertStmt can insert // multiple tuples. -std::vector>> - *PostgresParser::ValueListsTransform(List *root) { +std::vector>> * +PostgresParser::ValueListsTransform(List *root) { auto result = new std::vector< std::vector>>(); @@ -1627,8 +1701,8 @@ parser::SQLStatement *PostgresParser::InsertTransform(InsertStmt *root) { result = new parser::InsertStatement(InsertType::VALUES); PELOTON_ASSERT(select_stmt->valuesLists != NULL); - std::vector>> - *insert_values = nullptr; + std::vector>> * + insert_values = nullptr; try { insert_values = ValueListsTransform(select_stmt->valuesLists); } catch (Exception e) { @@ -1763,6 +1837,9 @@ parser::SQLStatement *PostgresParser::NodeTransform(Node *stmt) { result = CreateSchemaTransform(reinterpret_cast(stmt)); break; + case T_CreateSeqStmt: + result = CreateSequenceTransform(reinterpret_cast(stmt)); + break; case T_ViewStmt: result = CreateViewTransform(reinterpret_cast(stmt)); break; @@ -1843,8 +1920,8 @@ parser::SQLStatementList *PostgresParser::ListTransform(List *root) { return result; } -std::vector> - *PostgresParser::UpdateTargetTransform(List *root) { +std::vector> * +PostgresParser::UpdateTargetTransform(List *root) { auto result = new std::vector>(); for (auto cell = root->head; cell != NULL; cell = cell->next) { auto update_clause = new UpdateClause(); diff --git a/src/planner/create_plan.cpp b/src/planner/create_plan.cpp index b19d8d534d8..bab9553f68f 100644 --- a/src/planner/create_plan.cpp +++ b/src/planner/create_plan.cpp @@ -21,8 +21,7 @@ namespace peloton { namespace planner { CreatePlan::CreatePlan(std::string database_name, CreateType c_type) - : database_name(database_name), - create_type(c_type) {} + : database_name(database_name), create_type(c_type) {} CreatePlan::CreatePlan(std::string table_name, std::string database_name, std::unique_ptr schema, @@ -32,8 +31,7 @@ CreatePlan::CreatePlan(std::string table_name, std::string database_name, table_schema(schema.release()), create_type(c_type) {} -CreatePlan::CreatePlan(parser::CreateStatement *parse_tree) -{ +CreatePlan::CreatePlan(parser::CreateStatement *parse_tree) { switch (parse_tree->type) { case parser::CreateStatement::CreateType::kDatabase: { create_type = CreateType::DB; @@ -56,74 +54,87 @@ CreatePlan::CreatePlan(parser::CreateStatement *parse_tree) for (auto &col : parse_tree->columns) { type::TypeId val = col->GetValueType(col->type); - - LOG_TRACE("Column name: %s.%s; Is primary key: %d", table_name.c_str(), col->name.c_str(), col->primary); - + + LOG_TRACE("Column name: %s.%s; Is primary key: %d", table_name.c_str(), + col->name.c_str(), col->primary); + // Check main constraints if (col->primary) { - catalog::Constraint constraint(ConstraintType::PRIMARY, "con_primary"); + catalog::Constraint constraint(ConstraintType::PRIMARY, + "con_primary"); column_constraints.push_back(constraint); - LOG_TRACE("Added a primary key constraint on column \"%s.%s\"", table_name.c_str(), col->name.c_str()); + LOG_TRACE("Added a primary key constraint on column \"%s.%s\"", + table_name.c_str(), col->name.c_str()); } - + if (col->not_null) { - catalog::Constraint constraint(ConstraintType::NOTNULL, "con_not_null"); + catalog::Constraint constraint(ConstraintType::NOTNULL, + "con_not_null"); column_constraints.push_back(constraint); - LOG_TRACE("Added a not-null constraint on column \"%s.%s\"", table_name.c_str(), col->name.c_str()); + LOG_TRACE("Added a not-null constraint on column \"%s.%s\"", + table_name.c_str(), col->name.c_str()); } - + if (col->unique) { catalog::Constraint constraint(ConstraintType::UNIQUE, "con_unique"); column_constraints.push_back(constraint); - LOG_TRACE("Added a unique constraint on column \"%s.%s\"", table_name.c_str(), col->name.c_str()); + LOG_TRACE("Added a unique constraint on column \"%s.%s\"", + table_name.c_str(), col->name.c_str()); } - + /* **************** */ - + // Add the default value if (col->default_value != nullptr) { // Referenced from insert_plan.cpp - if (col->default_value->GetExpressionType() != ExpressionType::VALUE_PARAMETER) { + if (col->default_value->GetExpressionType() != + ExpressionType::VALUE_PARAMETER) { expression::ConstantValueExpression *const_expr_elem = - dynamic_cast(col->default_value.get()); - - catalog::Constraint constraint(ConstraintType::DEFAULT, "con_default"); + dynamic_cast( + col->default_value.get()); + + catalog::Constraint constraint(ConstraintType::DEFAULT, + "con_default"); type::Value v = const_expr_elem->GetValue(); constraint.addDefaultValue(v); column_constraints.push_back(constraint); LOG_TRACE("Added a default constraint %s on column \"%s.%s\"", - v.ToString().c_str(), table_name.c_str(), col->name.c_str()); + v.ToString().c_str(), table_name.c_str(), + col->name.c_str()); } } - + // Check expression constraint // Currently only supports simple boolean forms like (a > 0) if (col->check_expression != nullptr) { // TODO: more expression types need to be supported if (col->check_expression->GetValueType() == type::TypeId::BOOLEAN) { catalog::Constraint constraint(ConstraintType::CHECK, "con_check"); - + const expression::ConstantValueExpression *const_expr_elem = - dynamic_cast(col->check_expression->GetChild(1)); - + dynamic_cast( + col->check_expression->GetChild(1)); + type::Value tmp_value = const_expr_elem->GetValue(); - constraint.AddCheck(std::move(col->check_expression->GetExpressionType()), std::move(tmp_value)); + constraint.AddCheck( + std::move(col->check_expression->GetExpressionType()), + std::move(tmp_value)); column_constraints.push_back(constraint); LOG_TRACE("Added a check constraint on column \"%s.%s\"", table_name.c_str(), col->name.c_str()); } } - + auto column = catalog::Column(val, type::Type::GetTypeSize(val), std::string(col->name), false); if (!column.IsInlined()) { column.SetLength(col->varlen); } - + for (auto con : column_constraints) { column.AddConstraint(con); } - + column_constraints.clear(); columns.push_back(column); } @@ -141,17 +152,17 @@ CreatePlan::CreatePlan(parser::CreateStatement *parse_tree) // This is a fix for a bug where // The vector* items gets deleted when passed // To the Executor. - + std::vector index_attrs_holder; - + for (auto &attr : parse_tree->index_attrs) { index_attrs_holder.push_back(attr); } - + index_attrs = index_attrs_holder; - + index_type = parse_tree->index_type; - + unique = parse_tree->unique; break; } @@ -160,14 +171,14 @@ CreatePlan::CreatePlan(parser::CreateStatement *parse_tree) trigger_name = std::string(parse_tree->trigger_name); table_name = std::string(parse_tree->GetTableName()); database_name = std::string(parse_tree->GetDatabaseName()); - + if (parse_tree->trigger_when) { trigger_when.reset(parse_tree->trigger_when->Copy()); } else { trigger_when.reset(); } trigger_type = parse_tree->trigger_type; - + for (auto &s : parse_tree->trigger_funcname) { trigger_funcname.push_back(s); } @@ -180,25 +191,38 @@ CreatePlan::CreatePlan(parser::CreateStatement *parse_tree) break; } + case parser::CreateStatement::CreateType::kSequence: { + create_type = CreateType::SEQUENCE; + database_name = std::string(parse_tree->GetDatabaseName()); + + sequence_name = parse_tree->sequence_name; + seq_start = parse_tree->seq_start; + seq_increment = parse_tree->seq_increment; + seq_max_value = parse_tree->seq_max_value; + seq_min_value = parse_tree->seq_min_value; + seq_cache = parse_tree->seq_cache; + seq_cycle = parse_tree->seq_cycle; + + break; + } default: LOG_ERROR("UNKNOWN CREATE TYPE"); - //TODO Should we handle this here? + // TODO Should we handle this here? break; } - + // TODO check type CreateType::kDatabase } -void CreatePlan::ProcessForeignKeyConstraint(const std::string &table_name, - const parser::ColumnDefinition *col) { - +void CreatePlan::ProcessForeignKeyConstraint( + const std::string &table_name, const parser::ColumnDefinition *col) { ForeignKeyInfo fkey_info; // Extract source and sink column names - for (auto& key : col->foreign_key_source) { + for (auto &key : col->foreign_key_source) { fkey_info.foreign_key_sources.push_back(key); } - for (auto& key : col->foreign_key_sink) { + for (auto &key : col->foreign_key_sink) { fkey_info.foreign_key_sinks.push_back(key); } diff --git a/src/traffic_cop/traffic_cop.cpp b/src/traffic_cop/traffic_cop.cpp index fd29c7966b2..a3cbd34b009 100644 --- a/src/traffic_cop/traffic_cop.cpp +++ b/src/traffic_cop/traffic_cop.cpp @@ -191,9 +191,10 @@ executor::ExecutionResult TrafficCop::ExecuteHelper( }; auto &pool = threadpool::MonoQueuePool::GetInstance(); - pool.SubmitTask([plan, txn, ¶ms, &result_format, on_complete] { + std::string default_database_name = default_database_name_; + pool.SubmitTask([plan, txn, ¶ms, &result_format, on_complete, default_database_name] { executor::PlanExecutor::ExecutePlan(plan, txn, params, result_format, - on_complete); + on_complete, default_database_name); }); is_queuing_ = true; diff --git a/test/catalog/sequence_catalog_test.cpp b/test/catalog/sequence_catalog_test.cpp new file mode 100644 index 00000000000..3f9dda1611d --- /dev/null +++ b/test/catalog/sequence_catalog_test.cpp @@ -0,0 +1,210 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// sequence_test.cpp +// +// Identification: test/sequence/sequence_test.cpp +// +// Copyright (c) 2015-17, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#include "catalog/catalog.h" +#include "catalog/sequence_catalog.h" +#include "storage/abstract_table.h" +#include "common/harness.h" +#include "common/exception.h" +#include "executor/executors.h" +#include "parser/postgresparser.h" +#include "planner/create_plan.h" +#include "planner/insert_plan.h" +#include "concurrency/transaction_manager_factory.h" + +namespace peloton { +namespace test { + +class SequenceTests : public PelotonTest { + protected: + void CreateDatabaseHelper() { + auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); + auto txn = txn_manager.BeginTransaction(); + catalog::Catalog::GetInstance()->Bootstrap(); + catalog::Catalog::GetInstance()->CreateDatabase(DEFAULT_DB_NAME, txn); + txn_manager.CommitTransaction(txn); + } + + std::shared_ptr GetSequenceHelper( + std::string sequence_name, concurrency::TransactionContext *txn) { + // Check the effect of creation + oid_t database_oid = catalog::Catalog::GetInstance() + ->GetDatabaseWithName(DEFAULT_DB_NAME, txn) + ->GetOid(); + std::shared_ptr new_sequence = + catalog::SequenceCatalog::GetInstance().GetSequence(database_oid, + sequence_name, txn); + + return new_sequence; + } + + void CreateSequenceHelper(std::string query, + concurrency::TransactionContext *txn) { + auto parser = parser::PostgresParser::GetInstance(); + + std::unique_ptr stmt_list( + parser.BuildParseTree(query).release()); + EXPECT_TRUE(stmt_list->is_valid); + EXPECT_EQ(StatementType::CREATE, stmt_list->GetStatement(0)->GetType()); + auto create_sequence_stmt = + static_cast(stmt_list->GetStatement(0)); + + create_sequence_stmt->TryBindDatabaseName(DEFAULT_DB_NAME); + // Create plans + planner::CreatePlan plan(create_sequence_stmt); + + // plan type + EXPECT_EQ(CreateType::SEQUENCE, plan.GetCreateType()); + + // Execute the create sequence + std::unique_ptr context( + new executor::ExecutorContext(txn)); + executor::CreateExecutor createSequenceExecutor(&plan, context.get()); + createSequenceExecutor.Init(); + createSequenceExecutor.Execute(); + } +}; + +TEST_F(SequenceTests, BasicTest) { + CreateDatabaseHelper(); + auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); + auto txn = txn_manager.BeginTransaction(); + + // Create statement + std::string query = + "CREATE SEQUENCE seq " + "INCREMENT BY 2 " + "MINVALUE 10 MAXVALUE 50 " + "START 10 CYCLE;"; + std::string name = "seq"; + + CreateSequenceHelper(query, txn); + std::shared_ptr new_sequence = + GetSequenceHelper(name, txn); + + EXPECT_EQ(name, new_sequence->seq_name); + EXPECT_EQ(2, new_sequence->seq_increment); + EXPECT_EQ(10, new_sequence->seq_min); + EXPECT_EQ(50, new_sequence->seq_max); + EXPECT_EQ(10, new_sequence->seq_start); + EXPECT_EQ(true, new_sequence->seq_cycle); + EXPECT_EQ(10, new_sequence->GetCurrVal()); + + int64_t nextVal = new_sequence->GetNextVal(); + EXPECT_EQ(10, nextVal); + txn_manager.CommitTransaction(txn); +} + +TEST_F(SequenceTests, NoDuplicateTest) { + auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); + auto txn = txn_manager.BeginTransaction(); + + // Create statement + std::string query = + "CREATE SEQUENCE seq " + "INCREMENT BY 2 " + "MINVALUE 10 MAXVALUE 50 " + "START 10 CYCLE;"; + std::string name = "seq"; + + // Expect exception + try { + CreateSequenceHelper(query, txn); + EXPECT_EQ(0, 1); + } catch (const SequenceException &expected) { + ASSERT_STREQ("Insert Sequence with Duplicate Sequence Name: seq", + expected.what()); + } + txn_manager.CommitTransaction(txn); +} + +TEST_F(SequenceTests, NextValPosIncrementFunctionalityTest) { + auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); + auto txn = txn_manager.BeginTransaction(); + + std::string query = + "CREATE SEQUENCE seq1 " + "INCREMENT BY 1 " + "MINVALUE 10 MAXVALUE 50 " + "START 10 CYCLE;"; + std::string name = "seq1"; + + CreateSequenceHelper(query, txn); + std::shared_ptr new_sequence = + GetSequenceHelper(name, txn); + + int64_t nextVal = new_sequence->GetNextVal(); + EXPECT_EQ(10, nextVal); + nextVal = new_sequence->GetNextVal(); + EXPECT_EQ(11, nextVal); + + // test cycle + new_sequence->SetCurrVal(50); + nextVal = new_sequence->GetNextVal(); + nextVal = new_sequence->GetNextVal(); + EXPECT_EQ(10, nextVal); + + // test no cycle + new_sequence->SetCycle(false); + new_sequence->SetCurrVal(50); + + // Expect exception + try { + nextVal = new_sequence->GetNextVal(); + EXPECT_EQ(0, 1); + } catch (const SequenceException &expected) { + ASSERT_STREQ("Sequence exceeds upper limit!", expected.what()); + } + txn_manager.CommitTransaction(txn); +} + +TEST_F(SequenceTests, NextValNegIncrementFunctionalityTest) { + auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); + auto txn = txn_manager.BeginTransaction(); + + std::string query = + "CREATE SEQUENCE seq2 " + "INCREMENT BY -1 " + "MINVALUE 10 MAXVALUE 50 " + "START 10 CYCLE;"; + std::string name = "seq2"; + + CreateSequenceHelper(query, txn); + std::shared_ptr new_sequence = + GetSequenceHelper(name, txn); + + // test cycle + int64_t nextVal = new_sequence->GetNextVal(); + EXPECT_EQ(10, nextVal); + nextVal = new_sequence->GetNextVal(); + EXPECT_EQ(50, nextVal); + + new_sequence->SetCurrVal(49); + nextVal = new_sequence->GetNextVal(); + nextVal = new_sequence->GetNextVal(); + EXPECT_EQ(48, nextVal); + + // test no cycle + new_sequence->SetCycle(false); + new_sequence->SetCurrVal(10); + + // Expect exception + try { + nextVal = new_sequence->GetNextVal(); + EXPECT_EQ(0, 1); + } catch (const SequenceException &expected) { + ASSERT_STREQ("Sequence exceeds lower limit!", expected.what()); + } + txn_manager.CommitTransaction(txn); +} +} +} diff --git a/test/parser/postgresparser_test.cpp b/test/parser/postgresparser_test.cpp index 97e86adc8da..5e4740cd783 100644 --- a/test/parser/postgresparser_test.cpp +++ b/test/parser/postgresparser_test.cpp @@ -588,8 +588,7 @@ TEST_F(PostgresParserTests, InsertTest) { CmpBool res = five.CompareEquals( ((expression::ConstantValueExpression *)insert_stmt->insert_values.at(1) .at(1) - .get()) - ->GetValue()); + .get())->GetValue()); EXPECT_EQ(CmpBool::CmpTrue, res); // LOG_TRACE("%d : %s", ++ii, stmt_list->GetInfo().c_str()); @@ -1021,15 +1020,13 @@ TEST_F(PostgresParserTests, CreateTriggerTest) { EXPECT_EQ(ExpressionType::VALUE_TUPLE, left->GetExpressionType()); EXPECT_EQ("old", static_cast(left) ->GetTableName()); - EXPECT_EQ("balance", - static_cast(left) - ->GetColumnName()); + EXPECT_EQ("balance", static_cast( + left)->GetColumnName()); EXPECT_EQ(ExpressionType::VALUE_TUPLE, right->GetExpressionType()); EXPECT_EQ("new", static_cast(right) ->GetTableName()); - EXPECT_EQ("balance", - static_cast(right) - ->GetColumnName()); + EXPECT_EQ("balance", static_cast( + right)->GetColumnName()); // level // the level is for each row EXPECT_TRUE(TRIGGER_FOR_ROW(create_trigger_stmt->trigger_type)); @@ -1068,6 +1065,41 @@ TEST_F(PostgresParserTests, DropTriggerTest) { EXPECT_EQ("films", drop_trigger_stmt->GetTriggerTableName()); } +TEST_F(PostgresParserTests, CreateSequenceTest) { + auto parser = parser::PostgresParser::GetInstance(); + + // missing AS, CACHE and OWNED BY. + std::string query = + "CREATE SEQUENCE seq " + "INCREMENT BY 2 " + "MINVALUE 10 " + "MAXVALUE 50 " + "CYCLE " + "START 10;"; + std::unique_ptr stmt_list( + parser.BuildParseTree(query).release()); + EXPECT_TRUE(stmt_list->is_valid); + if (!stmt_list->is_valid) { + LOG_ERROR("Message: %s, line: %d, col: %d", stmt_list->parser_msg, + stmt_list->error_line, stmt_list->error_col); + } + EXPECT_EQ(StatementType::CREATE, stmt_list->GetStatement(0)->GetType()); + auto create_sequence_stmt = + static_cast(stmt_list->GetStatement(0)); + + // The following code checks the arguments in the create statement + // are identical to what is specified in the query. + + // create type + EXPECT_EQ(parser::CreateStatement::CreateType::kSequence, + create_sequence_stmt->type); + EXPECT_EQ(10, create_sequence_stmt->seq_start); + EXPECT_EQ(2, create_sequence_stmt->seq_increment); + EXPECT_EQ(50, create_sequence_stmt->seq_max_value); + EXPECT_EQ(10, create_sequence_stmt->seq_min_value); + EXPECT_EQ(true, create_sequence_stmt->seq_cycle); +} + TEST_F(PostgresParserTests, FuncCallTest) { std::string query = "SELECT add(1,a), chr(99) FROM TEST WHERE FUN(b) > 2";