Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support switch-case with fall throughs #68

Merged
merged 7 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ RUN(NAME expr2.c LABELS gcc c wasm llvm NOFAST)
RUN(NAME expr3.c FAIL LABELS gcc c wasm llvm NOFAST)

RUN(NAME switch_case_01.cpp LABELS gcc llvm NOFAST)
RUN(NAME switch_case_02.cpp LABELS gcc llvm NOFAST)

# arrays
RUN(NAME array_01.cpp LABELS gcc llvm NOFAST)
Expand Down
61 changes: 61 additions & 0 deletions integration_tests/switch_case_02.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#include <iostream>

int switch_case_with_fall_through(int x) {
int value = x;
switch(x) {
case 1: {
value = 2*value;
break;
}
case 2: {
value = 3*value;
}
case 3: {
value = 4*value;
break;
}
case 4: {
value = 5*value;
}
default: {
value = 6*value;
}
}
return value;
}


int main() {
int value;
value = switch_case_with_fall_through(1);
std::cout<<value<<std::endl;
if( value != 2 ) {
exit(2);
}

value = switch_case_with_fall_through(2);
std::cout<<value<<std::endl;
if( value != 24 ) {
exit(2);
}

value = switch_case_with_fall_through(3);
std::cout<<value<<std::endl;
if( value != 12 ) {
exit(2);
}

value = switch_case_with_fall_through(4);
std::cout<<value<<std::endl;
if( value != 120 ) {
exit(2);
}

value = switch_case_with_fall_through(7);
std::cout<<value<<std::endl;
if( value != 42 ) {
exit(2);
}

return 0;
}
17 changes: 13 additions & 4 deletions src/lc/clang_ast_to_asr.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,16 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
Vec<ASR::case_stmt_t*>* current_switch_case;
Vec<ASR::stmt_t*>* default_stmt;
OneTimeUseBool is_break_stmt_present;
bool enable_fall_through;

explicit ClangASTtoASRVisitor(clang::ASTContext *Context_,
Allocator& al_, ASR::asr_t*& tu_):
Context(Context_), al{al_}, tu{tu_},
current_body{nullptr}, is_stmt_created{true},
assignment_target{nullptr}, print_args{nullptr},
is_all_called{false}, is_range_called{false},
current_switch_case{nullptr}, default_stmt{nullptr} {}
current_switch_case{nullptr}, default_stmt{nullptr},
enable_fall_through{false} {}

template <typename T>
Location Lloc(T *x) {
Expand Down Expand Up @@ -1711,6 +1713,8 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
Vec<ASR::stmt_t*> default_stmt_; default_stmt_.reserve(al, 1);
current_switch_case = &current_switch_case_;
default_stmt = &default_stmt_;
bool enable_fall_through_copy = enable_fall_through;
enable_fall_through = false;

clang::Expr* clang_cond = x->getCond();
TraverseStmt(clang_cond);
Expand All @@ -1720,10 +1724,12 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit

tmp = ASR::make_Select_t(al, Lloc(x), cond,
current_switch_case->p, current_switch_case->size(),
default_stmt->p, default_stmt->size());
default_stmt->p, default_stmt->size(), enable_fall_through);
enable_fall_through = false;
current_switch_case = current_switch_case_copy;
default_stmt = default_stmt_copy;
is_stmt_created = true;
enable_fall_through = enable_fall_through_copy;
return true;
}

Expand Down Expand Up @@ -1751,12 +1757,15 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
is_break_stmt_present.set(false);
TraverseStmt(x->getSubStmt());
current_body = current_body_copy;
bool case_fall_through = true;
if( !is_break_stmt_present.get() ) {
throw std::runtime_error("Case must contain break statement. Fall through not yet supported.");
enable_fall_through = true;
} else {
case_fall_through = false;
}
ASR::case_stmt_t* case_stmt = ASR::down_cast<ASR::case_stmt_t>(
ASR::make_CaseStmt_t(al, Lloc(x), a_test.p, a_test.size(),
body.p, body.size()));
body.p, body.size(), case_fall_through));
current_switch_case->push_back(al, case_stmt);
is_stmt_created = false;
return true;
Expand Down
5 changes: 3 additions & 2 deletions src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ stmt
expr? carriagecontrol, expr? iolength)
| FileWrite(int label, expr? unit, expr? iomsg, expr? iostat, expr? id, expr* values, expr? separator, expr? end)
| Return()
| Select(expr test, case_stmt* body, stmt* default)
| Select(expr test, case_stmt* body, stmt* default, bool enable_fall_through)
| Stop(expr? code)
| Assert(expr test, expr? msg)
| SubroutineCall(symbol name, symbol? original_name, call_arg* args, expr? dt)
Expand Down Expand Up @@ -490,7 +490,8 @@ array_index = (expr? left, expr? right, expr? step)

do_loop_head = (expr? v, expr? start, expr? end, expr? increment)

case_stmt = CaseStmt(expr* test, stmt* body) | CaseStmt_Range(expr? start, expr? end, stmt* body)
case_stmt = CaseStmt(expr* test, stmt* body, bool fall_through) |
CaseStmt_Range(expr? start, expr? end, stmt* body)

type_stmt
= TypeStmtName(symbol sym, stmt* body)
Expand Down
20 changes: 16 additions & 4 deletions src/libasr/asr_verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,18 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
current_symtab = nullptr;
}

void visit_Select(const Select_t& x) {
bool fall_through = false;
for( int i = 0; i < x.n_body; i++ ) {
ASR::CaseStmt_t* case_stmt_t = ASR::down_cast<ASR::CaseStmt_t>(x.m_body[i]);
fall_through = fall_through || case_stmt_t->m_fall_through;
}
require(fall_through == x.m_enable_fall_through,
"Select_t::m_enable_fall_through should be " +
std::to_string(x.m_enable_fall_through));
Copy link
Contributor

@certik certik Jan 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's an even stronger requirement, which is fine. It's always good to start strong, and relax later if needed, than the other way round.

BaseWalkVisitor<VerifyVisitor>::visit_Select(x);
}

// --------------------------------------------------------
// symbol instances:

Expand Down Expand Up @@ -889,7 +901,7 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
}

SymbolTable* temp_scope = current_symtab;

if (asr_owner_sym && temp_scope->get_counter() != ASRUtils::symbol_parent_symtab(x.m_name)->get_counter() &&
!ASR::is_a<ASR::ExternalSymbol_t>(*x.m_name) && !ASR::is_a<ASR::Variable_t>(*x.m_name)) {
if (ASR::is_a<ASR::AssociateBlock_t>(*asr_owner_sym) || ASR::is_a<ASR::Block_t>(*asr_owner_sym)) {
Expand All @@ -899,7 +911,7 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
}
} else {
function_dependencies.push_back(std::string(ASRUtils::symbol_name(x.m_name)));
}
}
}

if( ASR::is_a<ASR::ExternalSymbol_t>(*x.m_name) ) {
Expand Down Expand Up @@ -1040,7 +1052,7 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
}

SymbolTable* temp_scope = current_symtab;

if (asr_owner_sym && temp_scope->get_counter() != ASRUtils::symbol_parent_symtab(x.m_name)->get_counter() &&
!ASR::is_a<ASR::ExternalSymbol_t>(*x.m_name) && !ASR::is_a<ASR::Variable_t>(*x.m_name)) {
if (ASR::is_a<ASR::AssociateBlock_t>(*asr_owner_sym) || ASR::is_a<ASR::Block_t>(*asr_owner_sym)) {
Expand All @@ -1050,7 +1062,7 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
}
} else {
function_dependencies.push_back(std::string(ASRUtils::symbol_name(x.m_name)));
}
}
}

if( ASR::is_a<ASR::ExternalSymbol_t>(*x.m_name) ) {
Expand Down
69 changes: 68 additions & 1 deletion src/libasr/pass/select_case.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,69 @@ Vec<ASR::stmt_t*> replace_selectcase(Allocator &al, const ASR::Select_t &select_
return body;
}

void case_to_if_with_fall_through(Allocator& al, const ASR::Select_t& x,
ASR::expr_t* a_test, Vec<ASR::stmt_t*>& body, SymbolTable* scope) {
body.reserve(al, x.n_body + 1);
const Location& loc = x.base.base.loc;
ASR::symbol_t* case_found_sym = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, scope, s2c(al, scope->get_unique_name("case_found")), nullptr, 0,
ASR::intentType::Local, nullptr, nullptr, ASR::storage_typeType::Default,
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr, ASR::abiType::Source,
ASR::accessType::Public, ASR::presenceType::Required, false));
scope->add_symbol(scope->get_unique_name("case_found"), case_found_sym);
ASR::expr_t* true_asr = ASRUtils::EXPR(ASR::make_LogicalConstant_t(al, loc, true,
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4))));
ASR::expr_t* false_asr = ASRUtils::EXPR(ASR::make_LogicalConstant_t(al, loc, false,
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4))));
ASR::expr_t* case_found = ASRUtils::EXPR(ASR::make_Var_t(al, loc, case_found_sym));
body.push_back(al, ASRUtils::STMT(ASR::make_Assignment_t(al, loc, case_found, false_asr, nullptr)));
int label_id = ASRUtils::LabelGenerator::get_instance()->get_unique_label();
for( int idx = 0; idx < x.n_body; idx++ ) {
ASR::case_stmt_t* case_body = x.m_body[idx];
switch(case_body->type) {
case ASR::case_stmtType::CaseStmt: {
ASR::CaseStmt_t* Case_Stmt = ASR::down_cast<ASR::CaseStmt_t>(case_body);
ASR::expr_t* test_expr = gen_test_expr_CaseStmt(al, loc, Case_Stmt, a_test);
test_expr = ASRUtils::EXPR(ASR::make_LogicalBinOp_t(al, loc, test_expr,
ASR::logicalbinopType::Or, case_found, ASRUtils::expr_type(case_found), nullptr));
Vec<ASR::stmt_t*> case_body; case_body.reserve(al, Case_Stmt->n_body);
case_body.from_pointer_n(Case_Stmt->m_body, Case_Stmt->n_body);
case_body.push_back(al, ASRUtils::STMT(ASR::make_Assignment_t(
al, loc, case_found, true_asr, nullptr)));
if( !Case_Stmt->m_fall_through ) {
case_body.push_back(al, ASRUtils::STMT(ASR::make_GoTo_t(al, loc,
label_id, s2c(al, scope->get_unique_name("switch_case_label")))));
}
body.push_back(al, ASRUtils::STMT(ASR::make_If_t(al, loc,
test_expr, case_body.p, case_body.size(), nullptr, 0)));
break;
}
case ASR::case_stmtType::CaseStmt_Range: {
LCOMPILERS_ASSERT(false);
break;
}
}
}
for( int id = 0; id < x.n_default; id++ ) {
body.push_back(al, x.m_default[id]);
}
SymbolTable* block_symbol_table = al.make_new<SymbolTable>(scope);
ASR::symbol_t* empty_block = ASR::down_cast<ASR::symbol_t>(ASR::make_Block_t(
al, loc, block_symbol_table, s2c(al, scope->get_unique_name("switch_case_label")),
nullptr, 0));
scope->add_symbol(scope->get_unique_name("switch_case_label"), empty_block);
body.push_back(al, ASRUtils::STMT(ASR::make_BlockCall_t(al, loc, label_id, empty_block)));
}

Vec<ASR::stmt_t*> replace_selectcase_with_fall_through(
Allocator &al, const ASR::Select_t &select_case,
SymbolTable* scope) {
ASR::expr_t *a = select_case.m_test;
Vec<ASR::stmt_t*> body;
case_to_if_with_fall_through(al, select_case, a, body, scope);
return body;
}

class SelectCaseVisitor : public PassUtils::PassVisitor<SelectCaseVisitor>
{

Expand All @@ -165,7 +228,11 @@ class SelectCaseVisitor : public PassUtils::PassVisitor<SelectCaseVisitor>
}

void visit_Select(const ASR::Select_t &x) {
pass_result = replace_selectcase(al, x);
if( x.m_enable_fall_through ) {
pass_result = replace_selectcase_with_fall_through(al, x, current_scope);
} else {
pass_result = replace_selectcase(al, x);
}
}
};

Expand Down
Loading