Skip to content

Commit

Permalink
Merge pull request #27 from czgdp1807/lc_18
Browse files Browse the repository at this point in the history
Support for `xt::xtensor<..., rank>` analogous to allocatables
  • Loading branch information
czgdp1807 authored Dec 20, 2023
2 parents 9cbf89f + 12d6345 commit 9a60262
Show file tree
Hide file tree
Showing 11 changed files with 299 additions and 80 deletions.
11 changes: 11 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,17 @@ if (NOT (CMAKE_BUILD_TYPE STREQUAL "Debug" OR
message(FATAL_ERROR "CMAKE_BUILD_TYPE must be one of: Debug, Release (current value: '${CMAKE_BUILD_TYPE}')")
endif ()

if (CMAKE_BUILD_TYPE STREQUAL "Debug")
# In Debug mode we enable assertions
set(WITH_LFORTRAN_ASSERT_DEFAULT yes)
else()
set(WITH_LFORTRAN_ASSERT_DEFAULT no)
endif()

# LFORTRAN_ASSERT
set(WITH_LFORTRAN_ASSERT ${WITH_LFORTRAN_ASSERT_DEFAULT}
CACHE BOOL "Enable LFORTRAN_ASSERT macro")

if (NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 17
CACHE STRING "C++ standard" FORCE)
Expand Down
2 changes: 2 additions & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ RUN(NAME array_02.cpp LABELS gcc llvm NOFAST
EXTRA_ARGS --extra-arg=-I${CMAKE_CURRENT_SOURCE_DIR}/../src/runtime/include)
RUN(NAME array_03.cpp LABELS gcc llvm NOFAST
EXTRA_ARGS --extra-arg=-I${CONDA_PREFIX}/include)
RUN(NAME array_04.cpp LABELS gcc llvm NOFAST
EXTRA_ARGS --extra-arg=-I${CONDA_PREFIX}/include)
RUN(NAME array_05.cpp LABELS gcc llvm NOFAST
EXTRA_ARGS --extra-arg=-I${CONDA_PREFIX}/include)
RUN(NAME array_06.cpp LABELS gcc llvm NOFAST
Expand Down
10 changes: 7 additions & 3 deletions integration_tests/array_04.cpp
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
#include <iostream>
#include <xtensor/xfixed.hpp>
#include "xtensor/xtensor.hpp"
#include "xtensor/xio.hpp"
#include "xtensor/xview.hpp"

int main() {

xt::xtensor<double, 2> arr1 = {
xt::xtensor<double, 2> arr1 {
{1.0, 2.0, 3.0},
{2.0, 5.0, 7.0},
{2.0, 5.0, 7.0}};

xt::xtensor<double, 1> arr2 {5.0, 6.0, 7.0};
xt::xtensor<double, 1> res;

// xt::xarray<double> res = xt::view(arr1, 1) + arr2; // TODO: Uncomment this statement
std::cout << arr2;
res = xt::empty<double>({3});
res = xt::view(arr1, 1) + arr2;
std::cout<< arr1 << arr2 <<std::endl;
std::cout << res << std::endl;

return 0;
}
1 change: 1 addition & 0 deletions src/bin/lc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ namespace LCompilers {
} \
compiler_options.po.always_run = true; \
compiler_options.po.run_fun = "f"; \
compiler_options.po.realloc_lhs = true; \
diagnostics.diagnostics.clear(); \

int emit_wat(Allocator &al, std::string &infile, LCompilers::ASR::TranslationUnit_t *asr) {
Expand Down
79 changes: 66 additions & 13 deletions src/lc/clang_ast_to_asr.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
#ifndef CLANG_AST_TO_ASR_H
#define CLANG_AST_TO_ASR_H

#define WITH_LFORTRAN_ASSERT

#include <clang/AST/ASTConsumer.h>
#include <clang/AST/RecursiveASTVisitor.h>
#include <clang/Frontend/CompilerInstance.h>
Expand Down Expand Up @@ -36,13 +34,15 @@ enum SpecialFunc {
Exit,
View,
Shape,
Empty,
};

std::map<std::string, SpecialFunc> special_function_map = {
{"printf", SpecialFunc::Printf},
{"exit", SpecialFunc::Exit},
{"view", SpecialFunc::View},
{"shape", SpecialFunc::Shape},
{"empty", SpecialFunc::Empty}
};

class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisitor> {
Expand Down Expand Up @@ -124,7 +124,6 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit

std::string main_func = "main";
ASR::symbol_t *main_sym = tu->m_symtab->resolve_symbol(main_func);
LCOMPILERS_ASSERT(main_sym);
if (main_sym == nullptr) {
return;
}
Expand Down Expand Up @@ -494,11 +493,15 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
clang::Expr** args = x->getArgs();
TraverseStmt(args[0]);
ASR::expr_t* obj = ASRUtils::EXPR(tmp);
assignment_target = obj;
if( ASRUtils::is_array(ASRUtils::expr_type(obj)) ) {
TraverseStmt(args[1]);
ASR::expr_t* value = ASRUtils::EXPR(tmp);
tmp = ASR::make_Assignment_t(al, Lloc(x), obj, value, nullptr);
is_stmt_created = true;
if( !is_stmt_created ) {
ASR::expr_t* value = ASRUtils::EXPR(tmp);
tmp = ASR::make_Assignment_t(al, Lloc(x), obj, value, nullptr);
is_stmt_created = true;
}
assignment_target = nullptr;
} else {
throw std::runtime_error("operator= is supported only for arrays.");
}
Expand Down Expand Up @@ -543,6 +546,8 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
Vec<ASR::expr_t*> args;
args.reserve(al, 1);
bool skip_format_str = true;
ASR::expr_t* assignment_target_copy = assignment_target;
assignment_target = nullptr;
for (auto *p : x->arguments()) {
TraverseStmt(p);
if (sf == SpecialFunc::Printf && skip_format_str) {
Expand All @@ -551,6 +556,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
}
args.push_back(al, ASRUtils::EXPR(tmp));
}
assignment_target = assignment_target_copy;
if (sf == SpecialFunc::Printf) {
tmp = ASR::make_Print_t(al, Lloc(x), args.p, args.size(), nullptr, nullptr);
is_stmt_created = true;
Expand Down Expand Up @@ -614,6 +620,46 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
LCOMPILERS_ASSERT(args.size() == 1);
tmp = ASR::make_Stop_t(al, Lloc(x), args[0]);
is_stmt_created = true;
} else if (sf == SpecialFunc::Empty) {
if( args.size() != 1 ) {
throw std::runtime_error("xt::empty must be provided with shape.");
}
if( assignment_target == nullptr ) {
throw std::runtime_error("xt::empty should be used only in assignment statement.");
}
if( !ASRUtils::is_allocatable(assignment_target) ) {
throw std::runtime_error("Assignment target must be an alloctable");
}

if( ASR::is_a<ASR::ArrayConstant_t>(*args.p[0]) ) {
ASR::ArrayConstant_t* array_constant = ASR::down_cast<ASR::ArrayConstant_t>(args.p[0]);
size_t target_rank = ASRUtils::extract_n_dims_from_ttype(
ASRUtils::expr_type(assignment_target));
if( array_constant->n_args != target_rank ) {
throw std::runtime_error("Assignment target must be of same rank as the size of the shape array.");
}

Vec<ASR::alloc_arg_t> alloc_args; alloc_args.reserve(al, 1);
ASR::alloc_arg_t alloc_arg; alloc_arg.loc = Lloc(x);
alloc_arg.m_a = assignment_target;
alloc_arg.m_len_expr = nullptr; alloc_arg.m_type = nullptr;
Vec<ASR::dimension_t> alloc_dims; alloc_dims.reserve(al, target_rank);
for( size_t i = 0; i < target_rank; i++ ) {
ASR::dimension_t alloc_dim;
alloc_dim.loc = Lloc(x);
alloc_dim.m_length = array_constant->m_args[i];
alloc_dim.m_start = ASRUtils::EXPR(ASR::make_IntegerConstant_t(
al, Lloc(x), 0, ASRUtils::TYPE(ASR::make_Integer_t(al, Lloc(x), 4))));
alloc_dims.push_back(al, alloc_dim);
}
alloc_arg.m_dims = alloc_dims.p; alloc_arg.n_dims = alloc_dims.size();
alloc_args.push_back(al, alloc_arg);
tmp = ASR::make_Allocate_t(al, Lloc(x), alloc_args.p, alloc_args.size(),
nullptr, nullptr, nullptr);
is_stmt_created = true;
} else {
throw std::runtime_error("Only {...} is allowed for supplying shape to xt::empty.");
}
} else {
throw std::runtime_error("Only printf and exit special functions supported");
}
Expand Down Expand Up @@ -648,12 +694,12 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
tmp = ASRUtils::make_SubroutineCall_t_util(al, Lloc(x), callee_sym,
callee_sym, call_args.p, call_args.size(), nullptr,
nullptr, false);
is_stmt_created = true;
} else {
tmp = ASRUtils::make_FunctionCall_t_util(al, Lloc(x), callee_sym,
callee_sym, call_args.p, call_args.size(), return_type,
nullptr, nullptr);
}
is_stmt_created = true;
return true;
}

Expand Down Expand Up @@ -733,9 +779,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
ASR::dimension_t *expr_dims = nullptr, *target_expr_dims = nullptr;
size_t expr_rank = ASRUtils::extract_dimensions_from_ttype(expr_type, expr_dims);
size_t target_expr_rank = ASRUtils::extract_dimensions_from_ttype(target_expr_type, target_expr_dims);
if( expr_rank == target_expr_rank ||
ASRUtils::extract_physical_type(target_expr_type) ==
ASR::array_physical_typeType::FixedSizeArray ) {
if( expr_rank == target_expr_rank ) {
return ;
}

Expand All @@ -757,8 +801,17 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
new_shape_dims.size(), ASR::array_physical_typeType::FixedSizeArray));
ASR::expr_t* new_shape = ASRUtils::EXPR(ASR::make_ArrayConstant_t(al, loc,
new_shape_.p, new_shape_.size(), new_shape_type, ASR::arraystorageType::RowMajor));
new_shape = ASRUtils::cast_to_descriptor(al, new_shape);

ASR::ttype_t* reshaped_expr_type = target_expr_type;
if( ASRUtils::is_fixed_size_array(expr_type) ) {
reshaped_expr_type = ASRUtils::duplicate_type_with_empty_dims(al,
ASRUtils::type_get_past_allocatable(
ASRUtils::type_get_past_pointer(target_expr_type)),
ASR::array_physical_typeType::FixedSizeArray, true);
}
ASR::expr_t* reshaped_expr = ASRUtils::EXPR(ASR::make_ArrayReshape_t(al, loc, expr,
new_shape, target_expr_type, nullptr));
new_shape, reshaped_expr_type, nullptr));
expr = reshaped_expr;
}

Expand Down Expand Up @@ -805,7 +858,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
return ;
}

LCOMPILERS_ASSERT(ASR::is_a<ASR::ArrayConstant_t>(array_constant));
LCOMPILERS_ASSERT(ASR::is_a<ASR::ArrayConstant_t>(*array_constant));
ASR::ArrayConstant_t* array_constant_t = ASR::down_cast<ASR::ArrayConstant_t>(array_constant);
Vec<ASR::expr_t*> new_elements; new_elements.reserve(al, array_constant_t->n_args);
for( size_t i = 0; i < array_constant_t->n_args; i++ ) {
Expand Down Expand Up @@ -1011,7 +1064,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
if( name == "operator<<" || name == "cout" ||
name == "endl" || name == "operator()" ||
name == "operator+" || name == "operator=" ||
name == "view" ) {
name == "view" || name == "empty" ) {
cxx_operator_name = name;
return true;
}
Expand Down
16 changes: 15 additions & 1 deletion src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4752,7 +4752,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
bool is_value_data_only_array = (value_ptype == ASR::array_physical_typeType::PointerToDataArray);
bool is_target_fixed_sized_array = (target_ptype == ASR::array_physical_typeType::FixedSizeArray);
bool is_value_fixed_sized_array = (value_ptype == ASR::array_physical_typeType::FixedSizeArray);
// bool is_target_descriptor_based_array = (target_ptype == ASR::array_physical_typeType::DescriptorArray);
bool is_target_descriptor_based_array = (target_ptype == ASR::array_physical_typeType::DescriptorArray);
bool is_value_descriptor_based_array = (value_ptype == ASR::array_physical_typeType::DescriptorArray);
if( is_value_fixed_sized_array && is_target_fixed_sized_array ) {
value = llvm_utils->create_gep(value, 0);
Expand Down Expand Up @@ -4782,6 +4782,20 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
llvm_size = builder->CreateMul(llvm_size,
llvm::ConstantInt::get(context, llvm::APInt(32, data_size)));
builder->CreateMemCpy(target, llvm::MaybeAlign(), value, llvm::MaybeAlign(), llvm_size);
} else if( is_target_descriptor_based_array && is_value_fixed_sized_array ) {
if( ASRUtils::is_allocatable(target_type) ) {
target = LLVM::CreateLoad(*builder, target);
}
llvm::Value* llvm_size = arr_descr->get_array_size(target, nullptr, 4);
target = LLVM::CreateLoad(*builder, arr_descr->get_pointer_to_data(target));
value = llvm_utils->create_gep(value, 0);
llvm::Type* llvm_data_type = llvm_utils->get_type_from_ttype_t_util(ASRUtils::type_get_past_array(
ASRUtils::type_get_past_allocatable(ASRUtils::type_get_past_pointer(value_type))), module.get());
llvm::DataLayout data_layout(module.get());
uint64_t data_size = data_layout.getTypeAllocSize(llvm_data_type);
llvm_size = builder->CreateMul(llvm_size,
llvm::ConstantInt::get(context, llvm::APInt(32, data_size)));
builder->CreateMemCpy(target, llvm::MaybeAlign(), value, llvm::MaybeAlign(), llvm_size);
} else if( is_target_data_only_array || is_value_data_only_array ) {
if( is_value_fixed_sized_array ) {
value = llvm_utils->create_gep(value, 0);
Expand Down
23 changes: 15 additions & 8 deletions src/libasr/codegen/llvm_array_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,19 +268,22 @@ namespace LCompilers {
builder->CreateStore(llvm::ConstantInt::get(context, llvm::APInt(32, n_dims)), get_rank(arr, true));
builder->CreateStore(dim_des_first, dim_des_val);
dim_des_val = LLVM::CreateLoad(*builder, dim_des_val);
llvm::Value* prod = llvm::ConstantInt::get(context, llvm::APInt(32, 1));
for( int r = 0; r < n_dims; r++ ) {
llvm::Value* dim_val = llvm_utils->create_ptr_gep(dim_des_val, r);
llvm::Value* s_val = llvm_utils->create_gep(dim_val, 0);
llvm::Value* l_val = llvm_utils->create_gep(dim_val, 1);
llvm::Value* dim_size_ptr = llvm_utils->create_gep(dim_val, 2);
builder->CreateStore(prod, s_val);
builder->CreateStore(llvm_dims[r].first, l_val);
llvm::Value* dim_size = llvm_dims[r].second;
prod = builder->CreateMul(prod, dim_size);
builder->CreateStore(dim_size, dim_size_ptr);
}

llvm::Value* prod = llvm::ConstantInt::get(context, llvm::APInt(32, 1));
for( int r = n_dims - 1; r >= 0; r-- ) {
llvm::Value* dim_val = llvm_utils->create_ptr_gep(dim_des_val, r);
llvm::Value* s_val = llvm_utils->create_gep(dim_val, 0);
builder->CreateStore(prod, s_val);
llvm::Value* dim_size = llvm_dims[r].second;
prod = builder->CreateMul(prod, dim_size);
}
if( !reserve_data_memory ) {
return ;
}
Expand Down Expand Up @@ -340,16 +343,20 @@ namespace LCompilers {
builder->CreateStore(llvm::ConstantInt::get(context, llvm::APInt(32, 0)),
offset_val);
llvm::Value* dim_des_val = LLVM::CreateLoad(*builder, llvm_utils->create_gep(arr, 2));
llvm::Value* prod = llvm::ConstantInt::get(context, llvm::APInt(32, 1));
for( int r = 0; r < n_dims; r++ ) {
llvm::Value* dim_val = llvm_utils->create_ptr_gep(dim_des_val, r);
llvm::Value* s_val = llvm_utils->create_gep(dim_val, 0);
llvm::Value* l_val = llvm_utils->create_gep(dim_val, 1);
llvm::Value* dim_size_ptr = llvm_utils->create_gep(dim_val, 2);
builder->CreateStore(prod, s_val);
builder->CreateStore(llvm_dims[r].first, l_val);
llvm::Value* dim_size = llvm_dims[r].second;
builder->CreateStore(dim_size, dim_size_ptr);
}
llvm::Value* prod = llvm::ConstantInt::get(context, llvm::APInt(32, 1));
for( int r = n_dims - 1; r >= 0; r-- ) {
llvm::Value* dim_val = llvm_utils->create_ptr_gep(dim_des_val, r);
llvm::Value* s_val = llvm_utils->create_gep(dim_val, 0);
builder->CreateStore(prod, s_val);
llvm::Value* dim_size = llvm_dims[r].second;
prod = builder->CreateMul(prod, dim_size);
}
llvm::Value* ptr2firstptr = get_pointer_to_data(arr);
Expand Down
2 changes: 1 addition & 1 deletion tests/reference/asr-array_01-9c6ecba.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"outfile": null,
"outfile_hash": null,
"stdout": "asr-array_01-9c6ecba.stdout",
"stdout_hash": "141c00a55703c54d1d113b441b052bd1ac7dfc8cc1fb73b243532bfd",
"stdout_hash": "1e36705289359e234e0543e1437e4c9decfcd771f8d6e7e297a8948a",
"stderr": null,
"stderr_hash": null,
"returncode": 0
Expand Down
Loading

0 comments on commit 9a60262

Please sign in to comment.