Skip to content

Commit

Permalink
Merge pull request #2118 from czgdp1807/array
Browse files Browse the repository at this point in the history
Accept ``dtype`` argument in ``numpy.array``
  • Loading branch information
czgdp1807 authored Aug 2, 2023
2 parents c3314f7 + 4c80fc6 commit 9aefff4
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 3 deletions.
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ RUN(NAME variable_decl_02 LABELS cpython llvm c)
RUN(NAME variable_decl_03 LABELS cpython llvm c)
RUN(NAME array_expr_01 LABELS cpython llvm c)
RUN(NAME array_expr_02 LABELS cpython llvm c NOFAST)
RUN(NAME array_expr_03 LABELS cpython llvm c)
RUN(NAME array_size_01 LABELS cpython llvm c)
RUN(NAME array_size_02 LABELS cpython llvm c)
RUN(NAME array_01 LABELS cpython llvm wasm c)
Expand Down
24 changes: 24 additions & 0 deletions integration_tests/array_expr_03.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from lpython import i8, i32, dataclass
from numpy import empty, int8, array


@dataclass
class LPBHV_small:
dim: i32 = 4
a: i8[4] = empty(4, dtype=int8)


def g():
l2: LPBHV_small = LPBHV_small(4, array([127, -127, 3, 111], dtype=int8))

print(l2.dim)
assert l2.dim == 4

print(l2.a[0], l2.a[1], l2.a[2], l2.a[3])
assert l2.a[0] == i8(127)
assert l2.a[1] == i8(-127)
assert l2.a[2] == i8(3)
assert l2.a[3] == i8(111)


g()
46 changes: 43 additions & 3 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@ namespace CastingUtil {
}
cast_kind = type_rules.at(cast_key);
}
if( ASRUtils::check_equal_type(src, dest, true) ) {
return expr;
}
// TODO: Fix loc
return ASRUtils::EXPR(ASRUtils::make_Cast_t_value(al, loc, expr,
cast_kind, dest));
Expand Down Expand Up @@ -505,6 +508,10 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
// Stores the name of imported functions and the modules they are imported from
std::map<std::string, std::string> imported_functions;

std::map<std::string, std::string> numpy2lpythontypes = {
{"int8", "i8"},
};

CommonVisitor(Allocator &al, LocationManager &lm, SymbolTable *symbol_table,
diag::Diagnostics &diagnostics, bool main_module, std::string module_name,
std::map<int, ASR::symbol_t*> &ast_overload, std::string parent_dir,
Expand Down Expand Up @@ -7520,16 +7527,45 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
tmp = ASR::make_UnsignedIntegerBitNot_t(al, x.base.base.loc, operand, operand_type, value);
return;
} else if( call_name == "array" ) {
parse_args(x, args);
ASR::ttype_t* type = nullptr;
if( x.n_keywords == 0 ) {
parse_args(x, args);
} else {
args.reserve(al, 1);
visit_expr_list(x.m_args, x.n_args, args);
if( x.n_keywords > 1 ) {
throw SemanticError("More than one keyword "
"arguments aren't recognised by array",
x.base.base.loc);
}
if( std::string(x.m_keywords[0].m_arg) != "dtype" ) {
throw SemanticError("Unrecognised keyword argument, " +
std::string(x.m_keywords[0].m_arg), x.base.base.loc);
}
std::string dtype_np = "";
if( AST::is_a<AST::Name_t>(*x.m_keywords[0].m_value) ) {
AST::Name_t* name_t = AST::down_cast<AST::Name_t>(x.m_keywords[0].m_value);
dtype_np = name_t->m_id;
} else {
LCOMPILERS_ASSERT(false);
}
LCOMPILERS_ASSERT(numpy2lpythontypes.find(dtype_np) != numpy2lpythontypes.end());
Vec<ASR::dimension_t> dims;
dims.n = 0;
type = get_type_from_var_annotation(
numpy2lpythontypes[dtype_np], x.base.base.loc, dims);
}
if( args.size() != 1 ) {
throw SemanticError("array accepts only 1 argument for now, got " +
std::to_string(args.size()) + " arguments instead.",
x.base.base.loc);
}
ASR::expr_t *arg = args[0].m_value;
ASR::ttype_t *type = ASRUtils::expr_type(arg);
if( type == nullptr ) {
type = ASRUtils::expr_type(arg);
}
if(ASR::is_a<ASR::ListConstant_t>(*arg)) {
type = ASR::down_cast<ASR::List_t>(type)->m_type;
type = ASRUtils::get_contained_type(type);
ASR::ListConstant_t* list = ASR::down_cast<ASR::ListConstant_t>(arg);
ASR::expr_t **m_args = list->m_args;
size_t n_args = list->n_args;
Expand All @@ -7544,6 +7580,10 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
dims.push_back(al, dim);
type = ASRUtils::make_Array_t_util(al, x.base.base.loc, type, dims.p, dims.size(),
ASR::abiType::Source, false, ASR::array_physical_typeType::PointerToDataArray, true);
for( size_t i = 0; i < n_args; i++ ) {
m_args[i] = CastingUtil::perform_casting(m_args[i], ASRUtils::expr_type(m_args[i]),
ASRUtils::type_get_past_array(type), al, x.base.base.loc);
}
tmp = ASR::make_ArrayConstant_t(al, x.base.base.loc, m_args, n_args, type, ASR::arraystorageType::RowMajor);
} else {
throw SemanticError("array accepts only list for now, got " +
Expand Down

0 comments on commit 9aefff4

Please sign in to comment.