Skip to content

Commit

Permalink
[ASR Pass] Replace the Class obj Assignment with SubroutineCall (#2795)
Browse files Browse the repository at this point in the history
Move the code modification to the ASR Pass
  • Loading branch information
tanay-man authored Aug 16, 2024
1 parent f3389e6 commit 61a27d5
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 42 deletions.
34 changes: 34 additions & 0 deletions src/libasr/pass/pass_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1247,6 +1247,40 @@ namespace LCompilers {
}
}
}
ASR::symbol_t* get_struct_member(Allocator& al, ASR::symbol_t* struct_type_sym, std::string &call_name,
const Location &loc, SymbolTable* current_scope) {
ASR::Struct_t* struct_type = ASR::down_cast<ASR::Struct_t>(struct_type_sym);
std::string struct_var_name = struct_type->m_name;
std::string struct_member_name = call_name;
ASR::symbol_t* struct_member = struct_type->m_symtab->resolve_symbol(struct_member_name);
ASR::symbol_t* struct_mem_asr_owner = ASRUtils::get_asr_owner(struct_member);
if( !struct_member || !struct_mem_asr_owner ||
!ASR::is_a<ASR::Struct_t>(*struct_mem_asr_owner) ) {
throw LCompilersException(struct_member_name + " not present in " +
struct_var_name + " dataclass");
}
std::string import_name = struct_var_name + "_" + struct_member_name;
ASR::symbol_t* import_struct_member = current_scope->resolve_symbol(import_name);
bool import_from_struct = true;
if( import_struct_member ) {
if( ASR::is_a<ASR::ExternalSymbol_t>(*import_struct_member) ) {
ASR::ExternalSymbol_t* ext_sym = ASR::down_cast<ASR::ExternalSymbol_t>(import_struct_member);
if( ext_sym->m_external == struct_member &&
std::string(ext_sym->m_module_name) == struct_var_name ) {
import_from_struct = false;
}
}
}
if( import_from_struct ) {
import_name = current_scope->get_unique_name(import_name, false);
import_struct_member = ASR::down_cast<ASR::symbol_t>(ASR::make_ExternalSymbol_t(al,
loc, current_scope, s2c(al, import_name),
struct_member, s2c(al, struct_var_name), nullptr, 0,
s2c(al, struct_member_name), ASR::accessType::Public));
current_scope->add_symbol(import_name, import_struct_member);
}
return import_struct_member;
}

} // namespace PassUtils

Expand Down
52 changes: 41 additions & 11 deletions src/libasr/pass/pass_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,9 @@ namespace LCompilers {
*/
};

ASR::symbol_t* get_struct_member(Allocator& al, ASR::symbol_t* struct_type_sym, std::string &call_name,
const Location &loc, SymbolTable* current_scope);

namespace ReplacerUtils {
template <typename T>
void replace_StructConstructor(ASR::StructConstructor_t* x,
Expand All @@ -578,6 +581,33 @@ namespace LCompilers {
bool perform_cast=false,
ASR::cast_kindType cast_kind=ASR::cast_kindType::IntegerToInteger,
ASR::ttype_t* casted_type=nullptr) {
if ( ASR::is_a<ASR::Struct_t>(*(x->m_dt_sym)) ) {
ASR::Struct_t* st = ASR::down_cast<ASR::Struct_t>(x->m_dt_sym);
if ( st->n_member_functions > 0 ) {
remove_original_statement = true;
if ( !ASR::is_a<ASR::Var_t>(*(replacer->result_var)) ) {
throw LCompilersException("Expected a var here");
}
ASR::Var_t* target = ASR::down_cast<ASR::Var_t>(replacer->result_var);
ASR::call_arg_t first_arg;
first_arg.loc = x->base.base.loc; first_arg.m_value = replacer->result_var;
Vec<ASR::call_arg_t> new_args; new_args.reserve(replacer->al,x->n_args+1);
new_args.push_back(replacer->al, first_arg);
for( size_t i = 0; i < x->n_args; i++ ) {
new_args.push_back(replacer->al, x->m_args[i]);
}
ASR::StructType_t* type = ASR::down_cast<ASR::StructType_t>(
(ASR::down_cast<ASR::Variable_t>(target->m_v))->m_type);
std::string call_name = "__init__";
ASR::symbol_t* call_sym = get_struct_member(replacer->al,type->m_derived_type, call_name,
x->base.base.loc, replacer->current_scope);
result_vec->push_back(replacer->al, ASRUtils::STMT(
ASRUtils::make_SubroutineCall_t_util(replacer->al,
x->base.base.loc, call_sym, nullptr, new_args.p, new_args.size(),
nullptr, nullptr, false, false)));
return;
}
}
if( x->n_args == 0 ) {
if( !inside_symtab ) {
remove_original_statement = true;
Expand All @@ -598,22 +628,22 @@ namespace LCompilers {
}

std::deque<ASR::symbol_t*> constructor_arg_syms;
ASR::StructType_t* dt_der = ASR::down_cast<ASR::StructType_t>(x->m_type);
ASR::Struct_t* dt_dertype = ASR::down_cast<ASR::Struct_t>(
ASRUtils::symbol_get_past_external(dt_der->m_derived_type));
while( dt_dertype ) {
for( int i = (int) dt_dertype->n_members - 1; i >= 0; i-- ) {
ASR::StructType_t* dt_dertype = ASR::down_cast<ASR::StructType_t>(x->m_type);
ASR::Struct_t* dt_der = ASR::down_cast<ASR::Struct_t>(
ASRUtils::symbol_get_past_external(dt_dertype->m_derived_type));
while( dt_der ) {
for( int i = (int) dt_der->n_members - 1; i >= 0; i-- ) {
constructor_arg_syms.push_front(
dt_dertype->m_symtab->get_symbol(
dt_dertype->m_members[i]));
dt_der->m_symtab->get_symbol(
dt_der->m_members[i]));
}
if( dt_dertype->m_parent != nullptr ) {
if( dt_der->m_parent != nullptr ) {
ASR::symbol_t* dt_der_sym = ASRUtils::symbol_get_past_external(
dt_dertype->m_parent);
dt_der->m_parent);
LCOMPILERS_ASSERT(ASR::is_a<ASR::Struct_t>(*dt_der_sym));
dt_dertype = ASR::down_cast<ASR::Struct_t>(dt_der_sym);
dt_der = ASR::down_cast<ASR::Struct_t>(dt_der_sym);
} else {
dt_dertype = nullptr;
dt_der = nullptr;
}
}
LCOMPILERS_ASSERT(constructor_arg_syms.size() == x->n_args);
Expand Down
35 changes: 4 additions & 31 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1282,16 +1282,13 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
args, st, loc);
}
if ( st->n_member_functions > 0 ) {
// Empty struct constructor
// Initializers handled in init proc call
Vec<ASR::call_arg_t>empty_args;
empty_args.reserve(al, 1);
for (size_t i = 0; i < st->n_members; i++) {
empty_args.push_back(al, st->m_initializers[i]);
if ( n_kwargs>0 ) {
throw SemanticError("Keyword args are not supported", loc);
}
ASR::ttype_t* der_type = ASRUtils::TYPE(ASRUtils::make_StructType_t_util(al, loc, stemp));
return ASR::make_StructConstructor_t(al, loc, stemp, empty_args.p,
empty_args.size(), der_type, nullptr);
return ASR::make_StructConstructor_t(al, loc, stemp, args.p,
args.size(), der_type, nullptr);
}

if ( args.size() > 0 && args.size() > st->n_members ) {
Expand Down Expand Up @@ -5316,17 +5313,6 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
if ( call->n_keywords>0 ) {
throw SemanticError("Kwargs not implemented yet", x.base.base.loc);
}
Vec<ASR::call_arg_t> args;
args.reserve(al, call->n_args + 1);
ASR::call_arg_t self_arg;
self_arg.loc = x.base.base.loc;
self_arg.m_value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, sym));
args.push_back(al, self_arg);
visit_expr_list(call->m_args, call->n_args, args);
ASR::symbol_t* der = ASR::down_cast<ASR::StructType_t>((var->m_type))->m_derived_type;
std::string call_name = "__init__";
ASR::symbol_t* call_sym = get_struct_member(der, call_name, x.base.base.loc);
tmp = make_call_helper(al, call_sym, current_scope, args, call_name, x.base.base.loc);
}
}
}
Expand Down Expand Up @@ -5611,23 +5597,10 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
overloaded));
if ( target->type == ASR::exprType::Var &&
tmp_value->type == ASR::exprType::StructConstructor ) {
Vec<ASR::call_arg_t> new_args; new_args.reserve(al, 1);
ASR::call_arg_t self_arg;
self_arg.loc = x.base.base.loc;
ASR::symbol_t* st = ASR::down_cast<ASR::Var_t>(target)->m_v;
self_arg.m_value = target;
new_args.push_back(al,self_arg);
AST::Call_t* call = AST::down_cast<AST::Call_t>(x.m_value);
if ( call->n_keywords>0 ) {
throw SemanticError("Kwargs not implemented yet", x.base.base.loc);
}
visit_expr_list(call->m_args, call->n_args, new_args);
ASR::symbol_t* der = ASR::down_cast<ASR::StructType_t>(
ASR::down_cast<ASR::Variable_t>(st)->m_type)->m_derived_type;
std::string call_name = "__init__";
ASR::symbol_t* call_sym = get_struct_member(der, call_name, x.base.base.loc);
tmp_vec.push_back(make_call_helper(al, call_sym,
current_scope, new_args, call_name, x.base.base.loc));
}
}
// to make sure that we add only those statements in tmp_vec
Expand Down

0 comments on commit 61a27d5

Please sign in to comment.