diff --git a/src/binder/bind_node_visitor.cpp b/src/binder/bind_node_visitor.cpp index 2ccd85dcdb2..183c73e817d 100644 --- a/src/binder/bind_node_visitor.cpp +++ b/src/binder/bind_node_visitor.cpp @@ -12,6 +12,8 @@ #include "binder/bind_node_visitor.h" #include "catalog/catalog.h" +#include "catalog/table_catalog.h" +#include "catalog/column_catalog.h" #include "expression/expression_util.h" #include "expression/star_expression.h" #include "type/type_id.h" @@ -250,6 +252,21 @@ void BindNodeVisitor::Visit(expression::TupleValueExpression *expr) { expr->SetColName(col_name); expr->SetValueType(value_type); expr->SetBoundOid(col_pos_tuple); + + // TODO(esargent): Uncommenting the following code makes AddressSanitizer get mad at me with a + // heap buffer overflow whenever I try a query that references the same non-null attribute multiple + // times (e.g. 'SELECT id FROM t WHERE id < 3 AND id > 1'). Leaving it commented out prevents the + // memory error, but then this prevents the is_not_null flag of a tuple expression from being + // populated in some cases (specifically, when the expression's table name is initially empty). 
+ + //if (table_obj == nullptr) { + // LOG_DEBUG("Extracting regular table object"); + // BinderContext::GetRegularTableObj(context_, table_name, table_obj, depth); + //} + + if (table_obj != nullptr) { + expr->SetIsNotNull(table_obj->GetColumnCatalogEntry(std::get<2>(col_pos_tuple), false)->IsNotNull()); + } } } diff --git a/src/include/common/internal_types.h b/src/include/common/internal_types.h index 39c9647b2ef..e81ec101b02 100644 --- a/src/include/common/internal_types.h +++ b/src/include/common/internal_types.h @@ -1393,6 +1393,14 @@ enum class RuleType : uint32_t { TV_EQUALITY_WITH_TWO_CV, // (A.B = x) AND (A.B = y) where x/y are constant TRANSITIVE_CLOSURE_CONSTANT, // (A.B = x) AND (A.B = C.D) + // Boolean short-circuit rules + AND_SHORT_CIRCUIT, // (FALSE AND B) + OR_SHORT_CIRCUIT, // (TRUE OR B) + + // Catalog-based NULL/NON-NULL rules + NULL_LOOKUP_ON_NOT_NULL_COLUMN, + NOT_NULL_LOOKUP_ON_NOT_NULL_COLUMN, + // Place holder to generate number of rules compile time NUM_RULES diff --git a/src/include/expression/tuple_value_expression.h b/src/include/expression/tuple_value_expression.h index dab5d1e4ddd..16f37ae8645 100644 --- a/src/include/expression/tuple_value_expression.h +++ b/src/include/expression/tuple_value_expression.h @@ -79,6 +79,10 @@ class TupleValueExpression : public AbstractExpression { tuple_idx_ = tuple_idx; } + inline void SetIsNotNull(bool is_not_null) { + is_not_null_ = is_not_null; + } + /** * @brief Attribute binding * @param binding_contexts @@ -116,6 +120,8 @@ class TupleValueExpression : public AbstractExpression { if ((table_name_.empty() xor other.table_name_.empty()) || col_name_.empty() xor other.col_name_.empty()) return false; + if (GetIsNotNull() != other.GetIsNotNull()) + return false; bool res = bound_obj_id_ == other.bound_obj_id_; if (!table_name_.empty() && !other.table_name_.empty()) res = table_name_ == other.table_name_ && res; @@ -151,6 +157,8 @@ class TupleValueExpression : public AbstractExpression { bool 
GetIsBound() const { return is_bound_; } + bool GetIsNotNull() const { return is_not_null_; } + const std::tuple &GetBoundOid() const { return bound_obj_id_; } @@ -185,7 +193,8 @@ class TupleValueExpression : public AbstractExpression { value_idx_(other.value_idx_), tuple_idx_(other.tuple_idx_), table_name_(other.table_name_), - col_name_(other.col_name_) {} + col_name_(other.col_name_), + is_not_null_(other.is_not_null_) {} // Bound flag bool is_bound_ = false; @@ -196,6 +205,7 @@ class TupleValueExpression : public AbstractExpression { int tuple_idx_; std::string table_name_; std::string col_name_; + bool is_not_null_ = false; const planner::AttributeInfo *ai_; }; diff --git a/src/include/optimizer/rule_rewrite.h b/src/include/optimizer/rule_rewrite.h index 8df83556626..ab739aa177d 100644 --- a/src/include/optimizer/rule_rewrite.h +++ b/src/include/optimizer/rule_rewrite.h @@ -20,6 +20,9 @@ namespace peloton { namespace optimizer { +using GroupExprTemplate = GroupExpression; +using OptimizeContext = OptimizeContext; + /* Rules are applied from high to low priority */ enum class RulePriority : int { HIGH = 3, @@ -71,5 +74,49 @@ class TransitiveClosureConstantTransform: public Rule { OptimizeContext *context) const override; }; +class AndShortCircuit: public Rule { + public: + AndShortCircuit(); + + int Promise(GroupExprTemplate *group_expr, OptimizeContext *context) const override; + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; + void Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const override; +}; + +class OrShortCircuit: public Rule { + public: + OrShortCircuit(); + + int Promise(GroupExprTemplate *group_expr, OptimizeContext *context) const override; + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; + void Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const override; +}; + +class NullLookupOnNotNullColumn: public Rule 
{ + public: + NullLookupOnNotNullColumn(); + + int Promise(GroupExprTemplate *group_expr, OptimizeContext *context) const override; + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; + void Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const override; +}; + +class NotNullLookupOnNotNullColumn: public Rule { + public: + NotNullLookupOnNotNullColumn(); + + int Promise(GroupExprTemplate *group_expr, OptimizeContext *context) const override; + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; + void Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const override; +}; + } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/binding.cpp b/src/optimizer/binding.cpp index 807c4c42f94..986710ab0ab 100644 --- a/src/optimizer/binding.cpp +++ b/src/optimizer/binding.cpp @@ -12,11 +12,15 @@ #include "optimizer/binding.h" +#include + #include "common/logger.h" #include "optimizer/operator_visitor.h" #include "optimizer/optimizer.h" #include "optimizer/absexpr_expression.h" #include "expression/group_marker_expression.h" +#include "expression/abstract_expression.h" +#include "expression/tuple_value_expression.h" namespace peloton { namespace optimizer { diff --git a/src/optimizer/rule.cpp b/src/optimizer/rule.cpp index 3baec8dba86..fc6e814e58c 100644 --- a/src/optimizer/rule.cpp +++ b/src/optimizer/rule.cpp @@ -76,6 +76,15 @@ RuleSet::RuleSet() { AddRewriteRule(RewriteRuleSetName::GENERIC_RULES, new TVEqualityWithTwoCVTransform()); AddRewriteRule(RewriteRuleSetName::GENERIC_RULES, new TransitiveClosureConstantTransform()); + AddRewriteRule(RewriteRuleSetName::GENERIC_RULES, new AndShortCircuit()); + AddRewriteRule(RewriteRuleSetName::GENERIC_RULES, new OrShortCircuit()); + + AddRewriteRule(RewriteRuleSetName::GENERIC_RULES, new NullLookupOnNotNullColumn()); + AddRewriteRule(RewriteRuleSetName::GENERIC_RULES, new 
NotNullLookupOnNotNullColumn()); + + AddRewriteRule(RewriteRuleSetName::GENERIC_RULES, new TVEqualityWithTwoCVTransform()); + AddRewriteRule(RewriteRuleSetName::GENERIC_RULES, new TransitiveClosureConstantTransform()); + // Define transformation/implementation rules AddTransformationRule(new InnerJoinCommutativity()); AddTransformationRule(new InnerJoinAssociativity()); diff --git a/src/optimizer/rule_rewrite.cpp b/src/optimizer/rule_rewrite.cpp index b804c08e488..7bcb3aaadf1 100644 --- a/src/optimizer/rule_rewrite.cpp +++ b/src/optimizer/rule_rewrite.cpp @@ -399,5 +399,228 @@ void TransitiveClosureConstantTransform::Transform(std::shared_ptr) + match_pattern = std::make_shared(ExpressionType::CONJUNCTION_AND); + auto left_child = std::make_shared(ExpressionType::VALUE_CONSTANT); + auto right_child = std::make_shared(ExpressionType::GROUP_MARKER); + + match_pattern->AddChild(left_child); + match_pattern->AddChild(right_child); +} + +int AndShortCircuit::Promise(GroupExpression *group_expr, OptimizeContext *context) const { + (void)group_expr; + (void)context; + return static_cast(RulePriority::HIGH); +} + +bool AndShortCircuit::Check(std::shared_ptr plan, OptimizeContext *context) const { + (void)plan; + (void)context; + return true; +} + +void AndShortCircuit::Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const { + (void)context; + (void)transformed; + + // Asserting guarantees provided by the GroupExprBindingIterator + // Structure: (FALSE AND <any expression>) + PELOTON_ASSERT(input->Children().size() == 2); + PELOTON_ASSERT(input->Node()->GetExpType() == ExpressionType::CONJUNCTION_AND); + + std::shared_ptr left = input->Children()[0]; + PELOTON_ASSERT(left->Children().size() == 0); + PELOTON_ASSERT(left->Node()->GetExpType() == ExpressionType::VALUE_CONSTANT); + + std::shared_ptr left_c = std::dynamic_pointer_cast(left->Node()); + PELOTON_ASSERT(left_c != nullptr); + + auto left_cv_expr =
std::dynamic_pointer_cast(left_c->GetExpr()); + type::Value left_value = left_cv_expr->GetValue(); + + LOG_DEBUG("fjdsklafjksdjflkadsjf"); + + // Only transform the expression if we're ANDing a FALSE boolean value + if (left_value.GetTypeId() == type::TypeId::BOOLEAN && left_value.IsFalse()) { + type::Value val_false = type::ValueFactory::GetBooleanValue(false); + std::shared_ptr false_expr = std::make_shared(val_false); + std::shared_ptr false_cnt = std::make_shared(false_expr); + std::shared_ptr false_container = std::make_shared(false_cnt); + transformed.push_back(false_container); + } +} + + +OrShortCircuit::OrShortCircuit() { + type_ = RuleType::OR_SHORT_CIRCUIT; + + // (TRUE OR <any expression>) + match_pattern = std::make_shared(ExpressionType::CONJUNCTION_OR); + auto left_child = std::make_shared(ExpressionType::VALUE_CONSTANT); + auto right_child = std::make_shared(ExpressionType::GROUP_MARKER); + + match_pattern->AddChild(left_child); + match_pattern->AddChild(right_child); +} + +int OrShortCircuit::Promise(GroupExpression *group_expr, OptimizeContext *context) const { + (void)group_expr; + (void)context; + return static_cast(RulePriority::HIGH); +} + +bool OrShortCircuit::Check(std::shared_ptr plan, OptimizeContext *context) const { + (void)plan; + (void)context; + return true; +} + +void OrShortCircuit::Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const { + (void)context; + (void)transformed; + + // Asserting guarantees provided by the GroupExprBindingIterator + // Structure: (TRUE OR <any expression>) + PELOTON_ASSERT(input->Children().size() == 2); + PELOTON_ASSERT(input->Node()->GetExpType() == ExpressionType::CONJUNCTION_OR); + + std::shared_ptr left = input->Children()[0]; + PELOTON_ASSERT(left->Children().size() == 0); + PELOTON_ASSERT(left->Node()->GetExpType() == ExpressionType::VALUE_CONSTANT); + + std::shared_ptr left_c = std::dynamic_pointer_cast(left->Node()); + PELOTON_ASSERT(left_c != nullptr); + + auto left_cv_expr =
std::dynamic_pointer_cast(left_c->GetExpr()); + type::Value left_value = left_cv_expr->GetValue(); + + // Only transform the expression if we're ORing a TRUE boolean value + if (left_value.GetTypeId() == type::TypeId::BOOLEAN && left_value.IsTrue()) { + type::Value val_true = type::ValueFactory::GetBooleanValue(true); + std::shared_ptr true_expr = std::make_shared(val_true); + std::shared_ptr true_cnt = std::make_shared(true_expr); + std::shared_ptr true_container = std::make_shared(true_cnt); + transformed.push_back(true_container); + } +} + + +NullLookupOnNotNullColumn::NullLookupOnNotNullColumn() { + type_ = RuleType::NULL_LOOKUP_ON_NOT_NULL_COLUMN; + + // Structure: [T.X IS NULL] + match_pattern = std::make_shared(ExpressionType::OPERATOR_IS_NULL); + auto child = std::make_shared(ExpressionType::VALUE_TUPLE); + + match_pattern->AddChild(child); +} + +int NullLookupOnNotNullColumn::Promise(GroupExpression *group_expr, OptimizeContext *context) const { + (void)group_expr; + (void)context; + return static_cast(RulePriority::LOW); +} + +bool NullLookupOnNotNullColumn::Check(std::shared_ptr plan, OptimizeContext *context) const { + (void)plan; + (void)context; + return true; +} + +void NullLookupOnNotNullColumn::Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const { + (void)context; + (void)transformed; + + // Asserting guarantees provided by the GroupExprBindingIterator + // Structure: [T.X IS NULL] + PELOTON_ASSERT(input->Children().size() == 1); + PELOTON_ASSERT(input->Node()->GetExpType() == ExpressionType::OPERATOR_IS_NULL); + + std::shared_ptr child = input->Children()[0]; + PELOTON_ASSERT(child->Children().size() == 0); + PELOTON_ASSERT(child->Node()->GetExpType() == ExpressionType::VALUE_TUPLE); + + std::shared_ptr child_c = std::dynamic_pointer_cast(child->Node()); + PELOTON_ASSERT(child_c != nullptr); + + auto tuple_expr = std::dynamic_pointer_cast(child_c->GetExpr()); + + // Only transform into [FALSE] if the tuple
value expression is specifically non-NULL, + // otherwise do nothing + if (tuple_expr->GetIsNotNull()) { + type::Value val_false = type::ValueFactory::GetBooleanValue(false); + std::shared_ptr false_expr = std::make_shared(val_false); + std::shared_ptr false_cnt = std::make_shared(false_expr); + std::shared_ptr false_container = std::make_shared(false_cnt); + transformed.push_back(false_container); + } +} + +NotNullLookupOnNotNullColumn::NotNullLookupOnNotNullColumn() { + type_ = RuleType::NOT_NULL_LOOKUP_ON_NOT_NULL_COLUMN; + + // Structure: [T.X IS NOT NULL] + match_pattern = std::make_shared(ExpressionType::OPERATOR_IS_NOT_NULL); + auto child = std::make_shared(ExpressionType::VALUE_TUPLE); + + match_pattern->AddChild(child); +} + +int NotNullLookupOnNotNullColumn::Promise(GroupExpression *group_expr, OptimizeContext *context) const { + (void)group_expr; + (void)context; + return static_cast(RulePriority::LOW); +} + +bool NotNullLookupOnNotNullColumn::Check(std::shared_ptr plan, OptimizeContext *context) const { + (void)plan; + (void)context; + return true; +} + +void NotNullLookupOnNotNullColumn::Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const { + (void)context; + (void)transformed; + + // Asserting guarantees provided by the GroupExprBindingIterator + // Structure: [T.X IS NOT NULL] + PELOTON_ASSERT(input->Children().size() == 1); + PELOTON_ASSERT(input->Node()->GetExpType() == ExpressionType::OPERATOR_IS_NOT_NULL); + + std::shared_ptr child = input->Children()[0]; + PELOTON_ASSERT(child->Children().size() == 0); + PELOTON_ASSERT(child->Node()->GetExpType() == ExpressionType::VALUE_TUPLE); + + std::shared_ptr child_c = std::dynamic_pointer_cast(child->Node()); + auto tuple_expr = std::dynamic_pointer_cast(child_c->GetExpr()); + + // Only transform into [TRUE] if the tuple value expression is specifically non-NULL, + // otherwise do nothing + if (tuple_expr->GetIsNotNull()) { + type::Value val_true =
type::ValueFactory::GetBooleanValue(true); + std::shared_ptr true_expr = std::make_shared(val_true); + std::shared_ptr true_cnt = std::make_shared(true_expr); + std::shared_ptr true_container = std::make_shared(true_cnt); + transformed.push_back(true_container); + } +} + } // namespace optimizer } // namespace peloton diff --git a/test/optimizer/rewriter_test.cpp b/test/optimizer/rewriter_test.cpp index 4ce84267d61..c9faacca9c0 100644 --- a/test/optimizer/rewriter_test.cpp +++ b/test/optimizer/rewriter_test.cpp @@ -18,6 +18,7 @@ #include "expression/constant_value_expression.h" #include "expression/comparison_expression.h" #include "expression/tuple_value_expression.h" +#include "expression/operator_expression.h" #include "type/value_factory.h" #include "type/value_peeker.h" #include "optimizer/rule_rewrite.h" @@ -142,5 +143,283 @@ TEST_F(RewriterTests, ComparativeOperatorTest) { delete rewrote; } +TEST_F(RewriterTests, BasicAndShortCircuitTest) { + + // First, build the rewriter and the values that will be used in test cases + Rewriter *rewriter = new Rewriter(); + + type::Value val_false = type::ValueFactory::GetBooleanValue(false); + type::Value val_true = type::ValueFactory::GetBooleanValue(true); + type::Value val3 = type::ValueFactory::GetIntegerValue(3); + + // + // [AND] + // [FALSE] [=] + // [X] [3] + // + // Intended output: [FALSE] + // + + expression::ConstantValueExpression *lh = new expression::ConstantValueExpression(val_false); + expression::ConstantValueExpression *rh_right_child = new expression::ConstantValueExpression(val3); + expression::TupleValueExpression *rh_left_child = new expression::TupleValueExpression("t","x"); + + expression::ComparisonExpression *rh = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, rh_left_child, rh_right_child); + expression::ConjunctionExpression *root = new expression::ConjunctionExpression(ExpressionType::CONJUNCTION_AND, lh, rh); + + expression::AbstractExpression *rewrote = 
rewriter->RewriteExpression(root); + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_EQ(rewrote->GetChildrenSize(), 0); + EXPECT_EQ(rewrote->GetExpressionType(), ExpressionType::VALUE_CONSTANT); + + delete rewrote; + delete root; + + // + // [AND] + // [TRUE] [=] + // [X] [3] + // + // Intended output: same as input + // + + lh = new expression::ConstantValueExpression(val_true); + rh_right_child = new expression::ConstantValueExpression(val3); + rh_left_child = new expression::TupleValueExpression("t","x"); + + rh = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, rh_left_child, rh_right_child); + root = new expression::ConjunctionExpression(ExpressionType::CONJUNCTION_AND, lh, rh); + + rewrote = rewriter->RewriteExpression(root); + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_EQ(rewrote->GetChildrenSize(), 2); + EXPECT_EQ(rewrote->GetExpressionType(), ExpressionType::CONJUNCTION_AND); + + delete rewrote; + delete root; + + delete rewriter; +} + + +TEST_F(RewriterTests, BasicOrShortCircuitTest) { + // First, build the rewriter and the values that will be used in test cases + Rewriter *rewriter = new Rewriter(); + + type::Value val_false = type::ValueFactory::GetBooleanValue(false); + type::Value val_true = type::ValueFactory::GetBooleanValue(true); + type::Value val3 = type::ValueFactory::GetIntegerValue(3); + + // + // [OR] + // [TRUE] [=] + // [X] [3] + // + // Intended output: [TRUE] + // + + expression::ConstantValueExpression *lh = new expression::ConstantValueExpression(val_true); + expression::ConstantValueExpression *rh_right_child = new expression::ConstantValueExpression(val3); + expression::TupleValueExpression *rh_left_child = new expression::TupleValueExpression("t","x"); + + expression::ComparisonExpression *rh = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, rh_left_child, rh_right_child); + expression::ConjunctionExpression *root = new expression::ConjunctionExpression(ExpressionType::CONJUNCTION_OR, lh, rh); + + 
expression::AbstractExpression *rewrote = rewriter->RewriteExpression(root); + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_EQ(rewrote->GetChildrenSize(), 0); + EXPECT_EQ(rewrote->GetExpressionType(), ExpressionType::VALUE_CONSTANT); + + delete rewrote; + delete root; + + // + // [OR] + // [FALSE] [=] + // [X] [3] + // + // Intended output: same as input + // + + lh = new expression::ConstantValueExpression(val_false); + rh_right_child = new expression::ConstantValueExpression(val3); + rh_left_child = new expression::TupleValueExpression("t","x"); + + rh = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, rh_left_child, rh_right_child); + root = new expression::ConjunctionExpression(ExpressionType::CONJUNCTION_OR, lh, rh); + + rewrote = rewriter->RewriteExpression(root); + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_EQ(rewrote->GetChildrenSize(), 2); + EXPECT_EQ(rewrote->GetExpressionType(), ExpressionType::CONJUNCTION_OR); + + delete rewrote; + delete root; + + delete rewriter; +} + + +TEST_F(RewriterTests, AndShortCircuitComparatorEliminationMixTest) { + // [AND] + // [<=] [=] + // [4] [4] [5] [3] + // Intended Output: FALSE + // + type::Value val4 = type::ValueFactory::GetIntegerValue(4); + type::Value val5 = type::ValueFactory::GetIntegerValue(5); + type::Value val3 = type::ValueFactory::GetIntegerValue(3); + + auto lb_left_child = new expression::ConstantValueExpression(val4); + auto lb_right_child = new expression::ConstantValueExpression(val4); + auto rb_left_child = new expression::ConstantValueExpression(val5); + auto rb_right_child = new expression::ConstantValueExpression(val3); + + auto lb = new expression::ComparisonExpression(ExpressionType::COMPARE_LESSTHANOREQUALTO, + lb_left_child, lb_right_child); + auto rb = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, + rb_left_child, rb_right_child); + auto top = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, lb, rb); + + Rewriter *rewriter = new 
Rewriter(); + auto rewrote = rewriter->RewriteExpression(top); + + delete rewriter; + delete top; + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_TRUE(rewrote->GetChildrenSize() == 0); + EXPECT_TRUE(rewrote->GetExpressionType() == ExpressionType::VALUE_CONSTANT); + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted->GetValueType() == type::TypeId::BOOLEAN); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == false); + + delete rewrote; +} + + +TEST_F(RewriterTests, OrShortCircuitComparatorEliminationMixTest) { + // [OR] + // [<=] [=] + // [4] [4] [5] [3] + // Intended Output: TRUE + // + type::Value val4 = type::ValueFactory::GetIntegerValue(4); + type::Value val5 = type::ValueFactory::GetIntegerValue(5); + type::Value val3 = type::ValueFactory::GetIntegerValue(3); + + auto lb_left_child = new expression::ConstantValueExpression(val4); + auto lb_right_child = new expression::ConstantValueExpression(val4); + auto rb_left_child = new expression::ConstantValueExpression(val5); + auto rb_right_child = new expression::ConstantValueExpression(val3); + + auto lb = new expression::ComparisonExpression(ExpressionType::COMPARE_LESSTHANOREQUALTO, + lb_left_child, lb_right_child); + auto rb = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, + rb_left_child, rb_right_child); + auto top = new expression::ConjunctionExpression(ExpressionType::CONJUNCTION_OR, lb, rb); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(top); + + delete rewriter; + delete top; + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_TRUE(rewrote->GetChildrenSize() == 0); + EXPECT_TRUE(rewrote->GetExpressionType() == ExpressionType::VALUE_CONSTANT); + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted->GetValueType() == type::TypeId::BOOLEAN); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == true); + + delete rewrote; +} + + +TEST_F(RewriterTests, NotNullColumnsTest) { + + // First, build rewriter to be used 
in all test cases + Rewriter *rewriter = new Rewriter(); + + // [T.X IS NULL], where X is a non-NULL column in table T + // Intended output: FALSE + + auto child = new expression::TupleValueExpression("t","x"); + child->SetIsNotNull(true); + auto root = new expression::OperatorExpression(ExpressionType::OPERATOR_IS_NULL, type::TypeId::BOOLEAN, child, nullptr); + + auto rewrote = rewriter->RewriteExpression(root); + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_EQ(rewrote->GetChildrenSize(), 0); + EXPECT_EQ(rewrote->GetExpressionType(), ExpressionType::VALUE_CONSTANT); + + auto casted = dynamic_cast(rewrote); + EXPECT_EQ(casted->GetValueType(), type::TypeId::BOOLEAN); + EXPECT_EQ(type::ValuePeeker::PeekBoolean(casted->GetValue()), false); + + delete root; + delete rewrote; + + // [T.X IS NOT NULL], where X is a non-NULL column in table T + // Intended output: TRUE + + child = new expression::TupleValueExpression("t","x"); + child->SetIsNotNull(true); + root = new expression::OperatorExpression(ExpressionType::OPERATOR_IS_NOT_NULL, type::TypeId::BOOLEAN, child, nullptr); + + rewrote = rewriter->RewriteExpression(root); + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_EQ(rewrote->GetChildrenSize(), 0); + EXPECT_EQ(rewrote->GetExpressionType(), ExpressionType::VALUE_CONSTANT); + + casted = dynamic_cast(rewrote); + EXPECT_EQ(casted->GetValueType(), type::TypeId::BOOLEAN); + EXPECT_EQ(type::ValuePeeker::PeekBoolean(casted->GetValue()), true); + + delete root; + delete rewrote; + + // [T.Y IS NULL], where Y is a possibly NULL column in table T + // Intended output: same as input + + child = new expression::TupleValueExpression("t","y"); + child->SetIsNotNull(false); // is_not_null is false by default, but explicitly setting it is for readability's sake + root = new expression::OperatorExpression(ExpressionType::OPERATOR_IS_NULL, type::TypeId::BOOLEAN, child, nullptr); + + rewrote = rewriter->RewriteExpression(root); + + EXPECT_EQ(rewrote->GetChildrenSize(), 1); + 
EXPECT_EQ(rewrote->GetExpressionType(), ExpressionType::OPERATOR_IS_NULL); + + delete root; + delete rewrote; + + // [T.Y IS NOT NULL], where Y is a possibly NULL column in table T + // Intended output: same as input + + child = new expression::TupleValueExpression("t","y"); + child->SetIsNotNull(false); // is_not_null is false by default, but explicitly setting it is for readability's sake + root = new expression::OperatorExpression(ExpressionType::OPERATOR_IS_NOT_NULL, type::TypeId::BOOLEAN, child, nullptr); + + rewrote = rewriter->RewriteExpression(root); + + EXPECT_EQ(rewrote->GetChildrenSize(), 1); + EXPECT_EQ(rewrote->GetExpressionType(), ExpressionType::OPERATOR_IS_NOT_NULL); + + delete root; + delete rewrote; + + delete rewriter; +} + + } // namespace test } // namespace peloton