diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 1133f9c..7d022f8 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -198,3 +198,5 @@ RUN(NAME array_09.cpp LABELS gcc llvm NOFAST EXTRA_ARGS --extra-arg=-I${CONDA_PREFIX}/include) RUN(NAME array_10.cpp LABELS gcc llvm NOFAST EXTRA_ARGS --extra-arg=-I${CONDA_PREFIX}/include) +RUN(NAME array_11.cpp LABELS gcc llvm NOFAST + EXTRA_ARGS --extra-arg=-I${CONDA_PREFIX}/include) diff --git a/integration_tests/array_11.cpp b/integration_tests/array_11.cpp new file mode 100644 index 0000000..916446e --- /dev/null +++ b/integration_tests/array_11.cpp @@ -0,0 +1,32 @@ +#include +#include +#include +#include "xtensor/xio.hpp" +#include "xtensor/xview.hpp" + +int main() { + xt::xtensor_fixed> R; + xt::xtensor_fixed> V; + xt::xtensor_fixed> U; + R.fill(23); + V.fill(9); + U.fill(1); + + xt::view(R, xt::range(1, 4, 1)) = xt::view(V, xt::range(1, 4)) * 1 * xt::view(U, xt::range(1, 4)); + + std::cout<< R << std::endl; + if (R(0) != 23) { + exit(2); + } + if (R(1) != 9) { + exit(2); + } + if (R(2) != 9) { + exit(2); + } + if (R(3) != 9) { + exit(2); + } + + return 0; +} diff --git a/src/lc/clang_ast_to_asr.h b/src/lc/clang_ast_to_asr.h index d391533..565020f 100644 --- a/src/lc/clang_ast_to_asr.h +++ b/src/lc/clang_ast_to_asr.h @@ -44,6 +44,7 @@ enum SpecialFunc { Abs, AMax, Sum, + Range, }; std::map special_function_map = { @@ -60,6 +61,7 @@ std::map special_function_map = { {"abs", SpecialFunc::Abs}, {"amax", SpecialFunc::AMax}, {"sum", SpecialFunc::Sum}, + {"range", SpecialFunc::Range}, }; class OneTimeUseString { @@ -129,14 +131,15 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor* print_args; - bool is_all_called; + bool is_all_called, is_range_called; + OneTimeUseASRNode range_start, range_end, range_step; explicit ClangASTtoASRVisitor(clang::ASTContext *Context_, Allocator& al_, ASR::asr_t*& tu_): Context(Context_), al{al_}, tu{tu_}, current_body{nullptr}, is_stmt_created{true}, assignment_target{nullptr}, print_args{nullptr}, - is_all_called{false} {} + is_all_called{false}, is_range_called{false} {} template Location Lloc(T *x) { @@ -638,49 +641,78 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor args; args.reserve(al, 1); - bool skip_format_str = true; - ASR::expr_t* assignment_target_copy = assignment_target; - assignment_target = nullptr; - for (auto *p : x->arguments()) { - TraverseStmt(p); - if (sf == SpecialFunc::Printf && skip_format_str) { - skip_format_str = false; - continue; - } - if( (tmp == nullptr && is_all_called) || - (tmp == nullptr || p->getStmtClass() == - clang::Stmt::StmtClass::CXXDefaultArgExprClass ) ) { - args.push_back(al, nullptr); - is_all_called = false; - } else { - ASR::asr_t* tmp_ = tmp.get(); - args.push_back(al, ASRUtils::EXPR(tmp_)); + if( sf != SpecialFunc::View ) { + bool skip_format_str = true; + ASR::expr_t* assignment_target_copy = assignment_target; + assignment_target = nullptr; + for (auto *p : x->arguments()) { + TraverseStmt(p); + if (sf == SpecialFunc::Printf && skip_format_str) { + skip_format_str = false; + continue; + } + if( (tmp == nullptr || p->getStmtClass() == + clang::Stmt::StmtClass::CXXDefaultArgExprClass ) ) { + args.push_back(al, nullptr); + } else { + ASR::asr_t* tmp_ = tmp.get(); + args.push_back(al, ASRUtils::EXPR(tmp_)); + } } + assignment_target = assignment_target_copy; } - assignment_target = assignment_target_copy; if (sf == SpecialFunc::Printf) { tmp = ASR::make_Print_t(al, Lloc(x), args.p, args.size(), nullptr, nullptr); is_stmt_created = true; } else if (sf == SpecialFunc::View) { - ASR::expr_t* array = args.p[0]; + clang::Expr** view_args = x->getArgs(); + size_t view_nargs = x->getNumArgs(); + ASR::expr_t* assignment_target_copy = assignment_target; + assignment_target = nullptr; + TraverseStmt(view_args[0]); + assignment_target = assignment_target_copy; + ASR::expr_t* array = ASRUtils::EXPR(tmp.get()); size_t rank = ASRUtils::extract_n_dims_from_ttype(ASRUtils::expr_type(array)); Vec array_section_indices; array_section_indices.reserve(al, rank); size_t i, j, result_dims = 0; - for( i = 0, j = 1; j < args.size(); j++, i++ ) { + for( i = 0, j = 1; j < view_nargs; j++, i++ ) { + ASR::expr_t* assignment_target_copy = assignment_target; + assignment_target = nullptr; + TraverseStmt(view_args[j]); + assignment_target = assignment_target_copy; ASR::array_index_t index; - if( args.p[j] == nullptr ) { + if( (tmp == nullptr && (is_all_called || is_range_called)) || + (tmp == nullptr || view_args[j]->getStmtClass() == + clang::Stmt::StmtClass::CXXDefaultArgExprClass ) ) { index.loc = array->base.loc; - index.m_left = ASRUtils::get_bound(array, i + 1, "lbound", al); - index.m_right = ASRUtils::get_bound(array, i + 1, "ubound", al); - index.m_step = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, index.loc, 1, - ASRUtils::TYPE(ASR::make_Integer_t(al, index.loc, 4)))); - array_section_indices.push_back(al, index); + if( is_range_called ) { + index.m_left = range_start.get(); + ASR::expr_t* range_end_ = range_end.get(); + CreateBinOp(range_end_, + ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, index.loc, + 1, ASRUtils::expr_type(range_end_))), + ASR::binopType::Sub, index.loc); + index.m_right = ASRUtils::EXPR(tmp.get()); + index.m_step = range_step.get(); + array_section_indices.push_back(al, index); + } else if( is_all_called ) { + index.m_left = ASRUtils::get_bound(array, i + 1, "lbound", al); + index.m_right = ASRUtils::get_bound(array, i + 1, "ubound", al); + index.m_step = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, index.loc, 1, + ASRUtils::TYPE(ASR::make_Integer_t(al, index.loc, 4)))); + array_section_indices.push_back(al, index); + } else { + throw std::runtime_error("Neither xt::range nor xt::all is being called for slicing array."); + } + is_range_called = false; + is_all_called = false; result_dims += 1; } else { - index.loc = args.p[j]->base.loc; + ASR::expr_t* arg_ = ASRUtils::EXPR(tmp.get()); + index.loc = arg_->base.loc; index.m_left = nullptr; - index.m_right = args.p[j]; + index.m_right = arg_; index.m_step = nullptr; array_section_indices.push_back(al, index); } @@ -744,6 +776,20 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor 2 ) { + range_step = args.p[2]; + } else { + range_step = ASRUtils::EXPR(ASR::make_IntegerConstant_t( + al, args.p[0]->base.loc, 1, ASRUtils::expr_type(args.p[0]))); + } + is_range_called = true; + tmp = nullptr; } else if( sf == SpecialFunc::Any ) { tmp = ASRUtils::make_IntrinsicArrayFunction_t_util(al, Lloc(x), static_cast(ASRUtils::IntrinsicArrayFunctions::Any), @@ -1259,7 +1305,8 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor" ) { + name == "operator-" || name == "operator/" || name == "operator>" || + name == "range" ) { if( sym != nullptr && ASR::is_a( *ASRUtils::symbol_get_past_external(sym)) ) { throw std::runtime_error("Special function " + name + " cannot be overshadowed yet.");