Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added static methods for classes #2721

Closed
wants to merge 14 commits into from
Closed
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,7 @@ RUN(NAME callback_03 LABELS cpython llvm llvm_jit c)
RUN(NAME lambda_01 LABELS cpython llvm llvm_jit)

RUN(NAME c_mangling LABELS cpython llvm llvm_jit c)
RUN(NAME class_01 LABELS cpython llvm llvm_jit)

# callback_04 is to test emulation. So just run with cpython
RUN(NAME callback_04 IMPORT_PATH .. LABELS cpython)
Expand Down
28 changes: 28 additions & 0 deletions integration_tests/class_01.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from lpython import i32,i64

def fn_2():
print("Inside fn_2")
return

class Test:
mem : i64 = i64(5)
s : str = "abc"
def fn_1():
Copy link
Collaborator

@Thirumalai-Shaktivel Thirumalai-Shaktivel Jun 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Print the values or texts are not being test, so I suggest to pass some arguments to the functions and check the computed return values.

print("Inside fn_1")
print("fn_2 called")
fn_2()
print("fn_3 called")
Test.fn_3()
return
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return
return

No extra spaces as well

def fn_3():
print("Inside fn_3")

def main():
t: Test = Test()
print(t.mem)
assert t.mem == i64(5)
print(t.s)
assert t.s == "abc"
Test.fn_1()

main()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: make sure to add a newline at the end of the file.

4 changes: 3 additions & 1 deletion src/libasr/asr_verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -529,11 +529,13 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
ASR::is_a<ASR::StructType_t>(*a.second) ||
ASR::is_a<ASR::UnionType_t>(*a.second) ||
ASR::is_a<ASR::ExternalSymbol_t>(*a.second) ||
ASR::is_a<ASR::CustomOperator_t>(*a.second) ) {
ASR::is_a<ASR::CustomOperator_t>(*a.second) ||
ASR::is_a<ASR::Function_t>(*a.second)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What was the issue here? Why is Function now skipped?

continue ;
}
// TODO: Uncomment the following line
// ASR::ttype_t* var_type = ASRUtils::extract_type(ASRUtils::symbol_type(a.second));

ASR::ttype_t* var_type = ASRUtils::type_get_past_pointer(ASRUtils::symbol_type(a.second));
char* aggregate_type_name = nullptr;
ASR::symbol_t* sym = nullptr;
Expand Down
32 changes: 32 additions & 0 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3014,6 +3014,30 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}
}

void visit_StructType(const ASR::StructType_t &x){
SymbolTable *current_scope_copy = current_scope;
current_scope = x.m_symtab;
for (auto &item : x.m_symtab->get_scope()){
if (is_a<ASR::Function_t>(*item.second)){
ASR::Function_t *v = down_cast<ASR::Function_t>(item.second);
instantiate_function(*v);
}
}
current_scope = current_scope_copy;
}

void make_struct_f_def(const ASR::StructType_t &x){
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have time, try simplifying this function: visit_StructType and make_struct_f_def.
They seem duplicated

SymbolTable *current_scope_copy = current_scope;
current_scope = x.m_symtab;
for (auto &item : x.m_symtab->get_scope()){
if (is_a<ASR::Function_t>(*item.second)){
ASR::Function_t *v = down_cast<ASR::Function_t>(item.second);
visit_Function(*v);
}
}
current_scope = current_scope_copy;
}

void start_module_init_function_prototype(const ASR::Module_t &x) {
uint32_t h = get_hash((ASR::asr_t*)&x);
llvm::FunctionType *function_type = llvm::FunctionType::get(
Expand Down Expand Up @@ -3055,6 +3079,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
} else if (is_a<ASR::EnumType_t>(*item.second)) {
ASR::EnumType_t *et = down_cast<ASR::EnumType_t>(item.second);
visit_EnumType(*et);
} else if (is_a<ASR::StructType_t>(*item.second)) {
mangle_prefix = "";
ASR::StructType_t *st = down_cast<ASR::StructType_t>(item.second);
visit_StructType(*st);
mangle_prefix = "__module_" + std::string(x.m_name) + "_";
}
}
finish_module_init_function_prototype(x);
Expand Down Expand Up @@ -4048,6 +4077,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
if (is_a<ASR::Function_t>(*item.second)) {
ASR::Function_t *s = ASR::down_cast<ASR::Function_t>(item.second);
visit_Function(*s);
}else if(is_a<ASR::StructType_t>(*item.second)) {
ASR::StructType_t *st = down_cast<ASR::StructType_t>(item.second);
make_struct_f_def(*st);
}
}
}
Expand Down
188 changes: 108 additions & 80 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2493,13 +2493,11 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
return false;
}

bool is_dataclass(AST::expr_t** decorators, size_t n,
void get_alignment(AST::expr_t** decorators, size_t n,
ASR::expr_t*& aligned_expr, bool& is_packed) {
bool is_dataclass_ = false;
for( size_t i = 0; i < n; i++ ) {
if( AST::is_a<AST::Name_t>(*decorators[i]) ) {
AST::Name_t* dc_name = AST::down_cast<AST::Name_t>(decorators[i]);
is_dataclass_ = std::string(dc_name->m_id) == "dataclass";
is_packed = is_packed || std::string(dc_name->m_id) == "packed";
} else if( AST::is_a<AST::Call_t>(*decorators[i]) ) {
AST::Call_t* dc_call = AST::down_cast<AST::Call_t>(decorators[i]);
Expand Down Expand Up @@ -2532,7 +2530,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
}
}

return is_dataclass_;
return;
}

bool is_enum(AST::expr_t** bases, size_t n) {
Expand Down Expand Up @@ -2870,9 +2868,10 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
if (is_const) {
storage_type = ASR::storage_typeType::Parameter;
}

create_add_variable_to_scope(var_name, type,
x.base.base.loc, abi, storage_type);
if( !(inside_struct && current_scope->resolve_symbol(var_name)) ){
create_add_variable_to_scope(var_name, type,
x.base.base.loc, abi, storage_type);
}

ASR::expr_t* assign_asr_target_copy = assign_asr_target;
this->visit_expr(*x.m_target);
Expand Down Expand Up @@ -2926,7 +2925,29 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
void visit_ClassMembers(const AST::ClassDef_t& x,
Vec<char*>& member_names, SetChar& struct_dependencies,
Vec<ASR::call_arg_t> &member_init,
bool is_enum_scope=false, ASR::abiType abi=ASR::abiType::Source) {
bool is_enum_scope=false, ASR::abiType abi=ASR::abiType::Source,
bool is_generating_body = false) {
if(is_generating_body && !is_enum_scope){
for( size_t i = 0; i < x.n_body; i++ ){
if ( AST::is_a<AST::FunctionDef_t>(*x.m_body[i]) )
//generating function body
this->visit_stmt(*x.m_body[i]);
if (AST::is_a<AST::AnnAssign_t>(*x.m_body[i])) {
//Add initializers to the AnnAssign
AST::AnnAssign_t* ann_assign = AST::down_cast<AST::AnnAssign_t>(x.m_body[i]);
AST::Name_t *n = AST::down_cast<AST::Name_t>(ann_assign->m_target);
std::string var_name = n->m_id;
ASR::expr_t* init_expr = nullptr;
visit_AnnAssignUtil(*ann_assign, var_name, init_expr, false, abi, true);
ASR::symbol_t* var_sym = current_scope->resolve_symbol(var_name);
ASR::call_arg_t c_arg;
c_arg.loc = var_sym->base.loc;
c_arg.m_value = init_expr;
member_init.push_back(al, c_arg);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems repeated as well.
I will try to look into them

}
return;
}
int64_t prev_value = 1;
for( size_t i = 0; i < x.n_body; i++ ) {
if (AST::is_a<AST::Expr_t>(*x.m_body[i])) {
Expand All @@ -2939,77 +2960,80 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
}
throw SemanticError("Only doc strings and const ellipsis allowed as expressions inside class", expr->base.base.loc);
} else if( AST::is_a<AST::ClassDef_t>(*x.m_body[i]) ) {
visit_ClassDef(*AST::down_cast<AST::ClassDef_t>(x.m_body[i]));
visit_ClassDef(*AST::down_cast<AST::ClassDef_t>(x.m_body[i]));
continue;
} else if ( AST::is_a<AST::FunctionDef_t>(*x.m_body[i]) ) {
throw SemanticError("Struct member functions are not supported", x.m_body[i]->base.loc);
this->visit_stmt(*x.m_body[i]);
continue;
} else if (AST::is_a<AST::Pass_t>(*x.m_body[i])) {
continue;
} else if (!AST::is_a<AST::AnnAssign_t>(*x.m_body[i])) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this changes, it seems very confusing, we will it in another PR

throw SemanticError("AnnAssign expected inside struct", x.m_body[i]->base.loc);
}
AST::AnnAssign_t* ann_assign = AST::down_cast<AST::AnnAssign_t>(x.m_body[i]);
if (!AST::is_a<AST::Name_t>(*ann_assign->m_target)) {
throw SemanticError("Only Name supported as target in AnnAssign inside struct", x.m_body[i]->base.loc);
}
AST::Name_t *n = AST::down_cast<AST::Name_t>(ann_assign->m_target);
std::string var_name = n->m_id;
ASR::expr_t* init_expr = nullptr;
if( is_enum_scope ) {
ASR::ttype_t* i64_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 8));
init_expr = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, -1, i64_type));
}
visit_AnnAssignUtil(*ann_assign, var_name, init_expr, false, abi, true);
ASR::symbol_t* var_sym = current_scope->resolve_symbol(var_name);
ASR::call_arg_t c_arg;
c_arg.loc = var_sym->base.loc;
c_arg.m_value = init_expr;
member_init.push_back(al, c_arg);
if( is_enum_scope ) {
if( AST::is_a<AST::Call_t>(*ann_assign->m_value) ) {
AST::Call_t* auto_call_cand = AST::down_cast<AST::Call_t>(ann_assign->m_value);
if( AST::is_a<AST::Name_t>(*auto_call_cand->m_func) ) {
AST::Name_t* func = AST::down_cast<AST::Name_t>(auto_call_cand->m_func);
std::string func_name = func->m_id;
if( func_name == "auto" ) {
ASR::ttype_t* int_type = ASRUtils::symbol_type(var_sym);
init_expr = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al,
auto_call_cand->base.base.loc, prev_value, int_type));
prev_value += 1;
} else if (AST::is_a<AST::AnnAssign_t>(*x.m_body[i])) {
AST::AnnAssign_t* ann_assign = AST::down_cast<AST::AnnAssign_t>(x.m_body[i]);
if (!AST::is_a<AST::Name_t>(*ann_assign->m_target)) {
throw SemanticError("Only Name supported as target in AnnAssign inside struct", x.m_body[i]->base.loc);
}
AST::Name_t *n = AST::down_cast<AST::Name_t>(ann_assign->m_target);
std::string var_name = n->m_id;
ASR::expr_t* init_expr = nullptr;
if( is_enum_scope ) {
ASR::ttype_t* i64_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 8));
init_expr = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, -1, i64_type));
}
visit_AnnAssignUtil(*ann_assign, var_name, init_expr, false, abi, true);
ASR::symbol_t* var_sym = current_scope->resolve_symbol(var_name);
ASR::call_arg_t c_arg;
c_arg.loc = var_sym->base.loc;
c_arg.m_value = init_expr;
member_init.push_back(al, c_arg);
if( is_enum_scope ) {
if( AST::is_a<AST::Call_t>(*ann_assign->m_value) ) {
AST::Call_t* auto_call_cand = AST::down_cast<AST::Call_t>(ann_assign->m_value);
if( AST::is_a<AST::Name_t>(*auto_call_cand->m_func) ) {
AST::Name_t* func = AST::down_cast<AST::Name_t>(auto_call_cand->m_func);
std::string func_name = func->m_id;
if( func_name == "auto" ) {
ASR::ttype_t* int_type = ASRUtils::symbol_type(var_sym);
init_expr = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al,
auto_call_cand->base.base.loc, prev_value, int_type));
prev_value += 1;
}
}
} else {
this->visit_expr(*ann_assign->m_value);
ASR::expr_t* enum_value = ASRUtils::expr_value(ASRUtils::EXPR(tmp));
LCOMPILERS_ASSERT(ASRUtils::is_value_constant(enum_value));
ASRUtils::extract_value(enum_value, prev_value);
prev_value += 1;
init_expr = enum_value;
}
} else {
this->visit_expr(*ann_assign->m_value);
ASR::expr_t* enum_value = ASRUtils::expr_value(ASRUtils::EXPR(tmp));
LCOMPILERS_ASSERT(ASRUtils::is_value_constant(enum_value));
ASRUtils::extract_value(enum_value, prev_value);
prev_value += 1;
init_expr = enum_value;
}
} else {
init_expr = nullptr;
}
if( ASR::is_a<ASR::Variable_t>(*var_sym) ) {
ASR::Variable_t* variable = ASR::down_cast<ASR::Variable_t>(var_sym);
variable->m_symbolic_value = init_expr;
}
ASR::ttype_t* var_type = ASRUtils::type_get_past_pointer(ASRUtils::symbol_type(var_sym));
char* aggregate_type_name = nullptr;
if( ASR::is_a<ASR::Struct_t>(*var_type) ) {
aggregate_type_name = ASRUtils::symbol_name(
ASR::down_cast<ASR::Struct_t>(var_type)->m_derived_type);
} else if( ASR::is_a<ASR::Enum_t>(*var_type) ) {
aggregate_type_name = ASRUtils::symbol_name(
ASR::down_cast<ASR::Enum_t>(var_type)->m_enum_type);
} else if( ASR::is_a<ASR::Union_t>(*var_type) ) {
aggregate_type_name = ASRUtils::symbol_name(
ASR::down_cast<ASR::Union_t>(var_type)->m_union_type);
init_expr = nullptr;
}
if( ASR::is_a<ASR::Variable_t>(*var_sym) ) {
ASR::Variable_t* variable = ASR::down_cast<ASR::Variable_t>(var_sym);
variable->m_symbolic_value = init_expr;
}
ASR::ttype_t* var_type = ASRUtils::type_get_past_pointer(ASRUtils::symbol_type(var_sym));
char* aggregate_type_name = nullptr;
if( ASR::is_a<ASR::Struct_t>(*var_type) ) {
aggregate_type_name = ASRUtils::symbol_name(
ASR::down_cast<ASR::Struct_t>(var_type)->m_derived_type);
} else if( ASR::is_a<ASR::Enum_t>(*var_type) ) {
aggregate_type_name = ASRUtils::symbol_name(
ASR::down_cast<ASR::Enum_t>(var_type)->m_enum_type);
} else if( ASR::is_a<ASR::Union_t>(*var_type) ) {
aggregate_type_name = ASRUtils::symbol_name(
ASR::down_cast<ASR::Union_t>(var_type)->m_union_type);
}
if( aggregate_type_name &&
!current_scope->get_symbol(std::string(aggregate_type_name)) ) {
struct_dependencies.push_back(al, aggregate_type_name);
}
member_names.push_back(al, n->m_id);
}
if( aggregate_type_name &&
!current_scope->get_symbol(std::string(aggregate_type_name)) ) {
struct_dependencies.push_back(al, aggregate_type_name);
else {
throw SemanticError("AnnAssign or Function def expected inside struct", x.m_body[i]->base.loc);
}
member_names.push_back(al, n->m_id);
}
}

Expand All @@ -3030,6 +3054,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {

void visit_ClassDef(const AST::ClassDef_t& x) {
std::string x_m_name = x.m_name;
bool is_generating_body = false;
if( is_enum(x.m_bases, x.n_bases) ) {
if( current_scope->resolve_symbol(x_m_name) ) {
return ;
Expand Down Expand Up @@ -3150,19 +3175,21 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
}
ASR::expr_t* algined_expr = nullptr;
bool is_packed = false;
if( !is_dataclass(x.m_decorator_list, x.n_decorator_list,
algined_expr, is_packed) ) {
throw SemanticError("Only dataclass-decorated classes and Enum subclasses are supported.",
x.base.base.loc);
}

get_alignment(x.m_decorator_list, x.n_decorator_list, algined_expr, is_packed);
if( x.n_bases > 0 ) {
throw SemanticError("Inheritance in classes isn't supported yet.",
x.base.base.loc);
}

SymbolTable *parent_scope = current_scope;
current_scope = al.make_new<SymbolTable>(parent_scope);
ASR::symbol_t* clss_sym = current_scope->get_symbol(x_m_name);
ASR::StructType_t* clss = nullptr;
if( clss_sym != nullptr ){
clss = ASR::down_cast<ASR::StructType_t>(clss_sym);
current_scope = clss->m_symtab;
is_generating_body = true;
}else{
current_scope = al.make_new<SymbolTable>(parent_scope);
}
Vec<char*> member_names;
Vec<ASR::call_arg_t> member_init;
member_names.reserve(al, x.n_body);
Expand All @@ -3173,8 +3200,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
if( is_bindc_class(x.m_decorator_list, x.n_decorator_list) ) {
class_abi = ASR::abiType::BindC;
}
visit_ClassMembers(x, member_names, struct_dependencies, member_init, false, class_abi);
LCOMPILERS_ASSERT(member_init.size() == member_names.size());
visit_ClassMembers(x, member_names, struct_dependencies, member_init, false, class_abi,is_generating_body);
ASR::symbol_t* class_type = ASR::down_cast<ASR::symbol_t>(ASR::make_StructType_t(al,
x.base.base.loc, current_scope, x.m_name,
struct_dependencies.p, struct_dependencies.size(),
Expand All @@ -3188,6 +3214,8 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
ASR::StructType_t *st = ASR::down_cast<ASR::StructType_t>(sym);
st->m_initializers = member_init.p;
st->n_initializers = member_init.size();
clss->m_symtab->asr_owner = &clss_sym->base;
return;
} else {
current_scope->add_symbol(x_m_name, class_type);
}
Expand Down
2 changes: 1 addition & 1 deletion tests/reference/asr-intent_01-66824bc.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"outfile": null,
"outfile_hash": null,
"stdout": "asr-intent_01-66824bc.stdout",
"stdout_hash": "415fb57ee7c986fc49e7c0801edae4e37d6ea06143d27a998c50ab5c",
"stdout_hash": "b4d6158df60118b1a4c88f675262373b18c3ebe073ac2d0761ece543",
"stderr": null,
"stderr_hash": null,
"returncode": 0
Expand Down
2 changes: 1 addition & 1 deletion tests/reference/asr-intent_01-66824bc.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@
main_program:
(Program
(SymbolTable
6
5
{

})
Expand Down
Loading
Loading