Skip to content

Commit

Permalink
Merge pull request #2801 from tanay-man/inheritance
Browse files Browse the repository at this point in the history
Initial implementation of Inheritance and Polymorphic functions
  • Loading branch information
Thirumalai-Shaktivel authored Aug 18, 2024
2 parents a104c30 + f42a561 commit 5a66456
Show file tree
Hide file tree
Showing 9 changed files with 207 additions and 34 deletions.
3 changes: 3 additions & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,9 @@ RUN(NAME class_01 LABELS cpython llvm llvm_jit)
RUN(NAME class_02 LABELS cpython llvm llvm_jit)
RUN(NAME class_03 LABELS cpython llvm llvm_jit)
RUN(NAME class_04 LABELS cpython llvm llvm_jit)
RUN(NAME class_05 LABELS cpython llvm llvm_jit)
RUN(NAME class_06 LABELS cpython llvm llvm_jit)


# callback_04 is to test emulation. So just run with cpython
RUN(NAME callback_04 IMPORT_PATH .. LABELS cpython)
Expand Down
37 changes: 37 additions & 0 deletions integration_tests/class_05.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from lpython import i32

class Animal:
def __init__(self:"Animal"):
self.species: str = "Generic Animal"
self.age: i32 = 0
self.is_domestic: bool = True

class Dog(Animal):
def __init__(self:"Dog", name:str, age:i32):
super().__init__()
self.species: str = "Dog"
self.name: str = name
self.age: i32 = age

class Cat(Animal):
def __init__(self:"Cat", name: str, age: i32):
super().__init__()
self.species: str = "Cat"
self.name:str = name
self.age: i32 = age

def main():
dog: Dog = Dog("Buddy", 5)
cat: Cat = Cat("Whiskers", 3)
op1: str = str(dog.name+" is a "+str(dog.age)+"-year-old "+dog.species+".")
print(op1)
assert op1 == "Buddy is a 5-year-old Dog."
print(dog.is_domestic)
assert dog.is_domestic == True
op2: str = str(cat.name+ " is a "+ str(cat.age)+ "-year-old "+ cat.species+ ".")
print(op2)
assert op2 == "Whiskers is a 3-year-old Cat."
print(cat.is_domestic)
assert cat.is_domestic == True

main()
36 changes: 36 additions & 0 deletions integration_tests/class_06.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from lpython import i32

class Base():
def __init__(self:"Base"):
self.x : i32 = 10

def get_x(self:"Base")->i32:
print(self.x)
return self.x

#Testing polymorphic fn calls
def get_x_static(d: Base)->i32:
print(d.x)
return d.x

class Derived(Base):
def __init__(self: "Derived"):
super().__init__()
self.y : i32 = 20

def get_y(self:"Derived")->i32:
print(self.y)
return self.y


def main():
d : Derived = Derived()
x : i32 = get_x_static(d)
assert x == 10
# Testing parent method call using der obj
x = d.get_x()
assert x == 10
y: i32 = d.get_y()
assert y == 20

main()
2 changes: 1 addition & 1 deletion src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ ttype
| Array(ttype type, dimension* dims, array_physical_type physical_type)
| FunctionType(ttype* arg_types, ttype? return_var_type, abi abi, deftype deftype, string? bindc_name, bool elemental, bool pure, bool module, bool inline, bool static, symbol* restrictions, bool is_restriction)

cast_kind = RealToInteger | IntegerToReal | LogicalToReal | RealToReal | IntegerToInteger | RealToComplex | IntegerToComplex | IntegerToLogical | RealToLogical | CharacterToLogical | CharacterToInteger | CharacterToList | ComplexToLogical | ComplexToComplex | ComplexToReal | ComplexToInteger | LogicalToInteger | RealToCharacter | IntegerToCharacter | LogicalToCharacter | UnsignedIntegerToInteger | UnsignedIntegerToUnsignedInteger | UnsignedIntegerToReal | UnsignedIntegerToLogical | IntegerToUnsignedInteger | RealToUnsignedInteger | CPtrToUnsignedInteger | UnsignedIntegerToCPtr | IntegerToSymbolicExpression | ListToArray
cast_kind = RealToInteger | IntegerToReal | LogicalToReal | RealToReal | IntegerToInteger | RealToComplex | IntegerToComplex | IntegerToLogical | RealToLogical | CharacterToLogical | CharacterToInteger | CharacterToList | ComplexToLogical | ComplexToComplex | ComplexToReal | ComplexToInteger | LogicalToInteger | RealToCharacter | IntegerToCharacter | LogicalToCharacter | UnsignedIntegerToInteger | UnsignedIntegerToUnsignedInteger | UnsignedIntegerToReal | UnsignedIntegerToLogical | IntegerToUnsignedInteger | RealToUnsignedInteger | CPtrToUnsignedInteger | UnsignedIntegerToCPtr | IntegerToSymbolicExpression | ListToArray | DerivedToBase
storage_type = Default | Save | Parameter
access = Public | Private
intent = Local | In | Out | InOut | ReturnVar | Unspecified
Expand Down
3 changes: 2 additions & 1 deletion src/libasr/casting_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ namespace LCompilers::CastingUtil {
{ASR::ttypeType::Complex, ASR::cast_kindType::ComplexToComplex},
{ASR::ttypeType::Real, ASR::cast_kindType::RealToReal},
{ASR::ttypeType::Integer, ASR::cast_kindType::IntegerToInteger},
{ASR::ttypeType::UnsignedInteger, ASR::cast_kindType::UnsignedIntegerToUnsignedInteger}
{ASR::ttypeType::UnsignedInteger, ASR::cast_kindType::UnsignedIntegerToUnsignedInteger},
{ASR::ttypeType::StructType, ASR::cast_kindType::DerivedToBase}
};

int get_type_priority(ASR::ttypeType type) {
Expand Down
5 changes: 5 additions & 0 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7725,6 +7725,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
tmp = LLVM::CreateLoad(*builder, list_api->get_pointer_to_list_data(tmp));
break;
}
case (ASR::cast_kindType::DerivedToBase) : {
this->visit_expr(*x.m_arg);
tmp = llvm_utils->create_gep(tmp, 0);
break;
}
default : throw CodeGenError("Cast kind not implemented");
}
}
Expand Down
151 changes: 121 additions & 30 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -784,9 +784,26 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
ASR::call_arg_t c_arg;
c_arg.loc = args[i].loc;
c_arg.m_value = args[i].m_value;
cast_helper(m_args[i], c_arg.m_value, true);
ASR::ttype_t* left_type = ASRUtils::expr_type(m_args[i]);
ASR::ttype_t* right_type = ASRUtils::expr_type(c_arg.m_value);
if ( ASR::is_a<ASR::StructType_t>(*left_type) && ASR::is_a<ASR::StructType_t>(*right_type) ) {
ASR::StructType_t *l_type = ASR::down_cast<ASR::StructType_t>(left_type);
ASR::StructType_t *r_type = ASR::down_cast<ASR::StructType_t>(right_type);
ASR::Struct_t *l2_type = ASR::down_cast<ASR::Struct_t>(
ASRUtils::symbol_get_past_external(
l_type->m_derived_type));
ASR::Struct_t *r2_type = ASR::down_cast<ASR::Struct_t>(
ASRUtils::symbol_get_past_external(
r_type->m_derived_type));
if ( ASRUtils::is_derived_type_similar(l2_type, r2_type) ) {
cast_helper(m_args[i], c_arg.m_value, true, true);
check_type_equality = false;
} else {
cast_helper(m_args[i], c_arg.m_value, true);
}
} else {
cast_helper(m_args[i], c_arg.m_value, true);
}
if( check_type_equality && !ASRUtils::check_equal_type(left_type, right_type) ) {
std::string ltype = ASRUtils::type_to_str_python(left_type);
std::string rtype = ASRUtils::type_to_str_python(right_type);
Expand Down Expand Up @@ -2962,9 +2979,8 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
std::string obj_name = x.m_args.m_args->m_arg;
for(size_t i = 0; i < x.n_body; i++) {
std::string var_name;
if (! AST::is_a<AST::AnnAssign_t>(*x.m_body[i]) ){
throw SemanticError("Only AnnAssign implemented in __init__ ",
x.m_body[i]->base.loc);
if ( !AST::is_a<AST::AnnAssign_t>(*x.m_body[i]) ){
continue;
}
AST::AnnAssign_t ann_assign = *AST::down_cast<AST::AnnAssign_t>(x.m_body[i]);
if(AST::is_a<AST::Attribute_t>(*ann_assign.m_target)){
Expand Down Expand Up @@ -3301,10 +3317,21 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
current_scope->add_symbol(x_m_name, class_type);
}
} else {
if( x.n_bases > 0 ) {
throw SemanticError("Inheritance in classes isn't supported yet.",
ASR::symbol_t* parent = nullptr;
if( x.n_bases > 1 ) {
throw SemanticError("Multiple inheritance in classes isn't supported yet.",
x.base.base.loc);
}
else if (x.n_bases == 1) {
std::string b_name = "";
if ( AST::is_a<AST::Name_t>(*x.m_bases[0]) ) {
b_name = AST::down_cast<AST::Name_t>(x.m_bases[0])->m_id;
} else {
throw SemanticError("Expected a Name here", x.base.base.loc);
}
parent = current_scope->resolve_symbol(b_name);
LCOMPILERS_ASSERT(ASR::is_a<ASR::Struct_t>(*parent));
}
SymbolTable *parent_scope = current_scope;
if( ASR::symbol_t* sym = current_scope->resolve_symbol(x_m_name) ) {
LCOMPILERS_ASSERT(ASR::is_a<ASR::Struct_t>(*sym));
Expand All @@ -3316,7 +3343,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
f = AST::down_cast<AST::FunctionDef_t>(x.m_body[i]);
init_self_type(*f, sym, x.base.base.loc);
if ( std::string(f->m_name) == std::string("__init__") ) {
this->visit_init_body(*f);
this->visit_init_body(*f, st->m_parent, x.m_body[i]->base.loc);
} else {
this->visit_stmt(*x.m_body[i]);
}
Expand Down Expand Up @@ -3344,7 +3371,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
member_names.p, member_names.size(), member_fn_names.p,
member_fn_names.size(), class_abi, ASR::accessType::Public,
false, false, member_init.p, member_init.size(),
nullptr, nullptr));
nullptr, parent));
parent_scope->add_symbol(x.m_name, class_sym);
visit_ClassMembers(x, member_names, member_fn_names,
struct_dependencies, member_init, false, class_abi, true);
Expand Down Expand Up @@ -3387,7 +3414,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
current_scope = parent_scope;
}

virtual void visit_init_body (const AST::FunctionDef_t &/*x*/) = 0;
virtual void visit_init_body (const AST::FunctionDef_t &/*x*/, ASR::symbol_t* /*parent_sym*/, const Location /*loc*/) = 0;

void add_name(const Location &loc) {
std::string var_name = "__name__";
Expand Down Expand Up @@ -4421,7 +4448,7 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
// Implement visit_Global for Symbol Table visitor.
void visit_Global(const AST::Global_t &/*x*/) {}

void visit_init_body (const AST::FunctionDef_t &/*x*/) {
void visit_init_body (const AST::FunctionDef_t &/*x*/, ASR::symbol_t* /*parent_sym*/, const Location /*loc*/) {
//Implemented in BodyVisitor
}

Expand Down Expand Up @@ -5153,7 +5180,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
tmp = asr;
}

void visit_init_body (const AST::FunctionDef_t &x) {
void visit_init_body (const AST::FunctionDef_t &x, ASR::symbol_t* parent_sym, const Location loc) {
SymbolTable *old_scope = current_scope;
ASR::symbol_t *t = current_scope->get_symbol("__init__");
if ( t==nullptr ) {
Expand All @@ -5163,31 +5190,77 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
throw SemanticError("__init__ is not a function", x.base.base.loc);
}
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(t);
current_scope = f->m_symtab;
//Transform statements into correct format
Vec<AST::stmt_t*> new_body;
new_body.reserve(al, 1);
Vec<AST::stmt_t*> body;
body.reserve(al, 1);
ASR::stmt_t* super_call_stmt = nullptr;
for (size_t i=0; i<x.n_body; i++) {
AST::AnnAssign_t ann_assign = *AST::down_cast<AST::AnnAssign_t>(x.m_body[i]);
if ( ann_assign.m_value != nullptr ) {
Vec<AST::expr_t*>target;
target.reserve(al, 1);
target.push_back(al, ann_assign.m_target);
AST::ast_t* assgn_ast = AST::make_Assign_t(al, ann_assign.base.base.loc,
target.p, 1, ann_assign.m_value, nullptr);
AST::stmt_t* assgn = AST::down_cast<AST::stmt_t>(assgn_ast);
new_body.push_back(al, assgn);
if (AST::is_a<AST::AnnAssign_t>(*x.m_body[i])) {
AST::AnnAssign_t ann_assign = *AST::down_cast<AST::AnnAssign_t>(x.m_body[i]);
if ( ann_assign.m_value != nullptr ) {
Vec<AST::expr_t*>target;
target.reserve(al, 1);
target.push_back(al, ann_assign.m_target);
AST::ast_t* assgn_ast = AST::make_Assign_t(al, ann_assign.base.base.loc,
target.p, 1, ann_assign.m_value, nullptr);
AST::stmt_t* assgn = AST::down_cast<AST::stmt_t>(assgn_ast);
body.push_back(al, assgn);
}
} else if (AST::is_a<AST::Expr_t>(*x.m_body[i]) &&
AST::is_a<AST::Call_t>(*(AST::down_cast<AST::Expr_t>(x.m_body[i])->m_value))) {
AST::Call_t* c = AST::down_cast<AST::Call_t>(AST::down_cast<AST::Expr_t>(x.m_body[i])->m_value);

if ( !AST::is_a<AST::Attribute_t>(*(c->m_func))
|| !AST::is_a<AST::Call_t>(*(AST::down_cast<AST::Attribute_t>(c->m_func)->m_value)) ) {
body.push_back(al, x.m_body[i]);
continue;
}
AST::Call_t* super_call = AST::down_cast<AST::Call_t>(AST::down_cast<AST::Attribute_t>(c->m_func)->m_value);
std::string attr = AST::down_cast<AST::Attribute_t>(c->m_func)->m_attr;
if ( AST::is_a<AST::Name_t>(*(super_call->m_func)) &&
std::string(AST::down_cast<AST::Name_t>(super_call->m_func)->m_id)=="super" &&
attr == "__init__") {
if (parent_sym == nullptr) {
throw SemanticError("The class doesn't have a base class",loc);
}
Vec<ASR::call_arg_t> args;
args.reserve(al, 1);
parse_args(*super_call,args);
ASR::call_arg_t first_arg;
first_arg.loc = loc;
ASR::symbol_t* self_sym = current_scope->get_symbol("self");
first_arg.m_value = ASRUtils::EXPR(ASR::make_Var_t(al,loc,self_sym));
ASR::ttype_t* target_type = ASRUtils::TYPE(ASRUtils::make_StructType_t_util(al,loc,parent_sym));
cast_helper(target_type, first_arg.m_value, x.base.base.loc, true);
Vec<ASR::call_arg_t> args_w_first; args_w_first.reserve(al,1);
args_w_first.push_back(al, first_arg);
for( size_t i = 0; i < args.size(); i++ ) {
args_w_first.push_back(al,args[i]);
}
std::string call_name = "__init__";
ASR::symbol_t* call_sym = get_struct_member(parent_sym,call_name,loc);
super_call_stmt = ASRUtils::STMT(
ASR::make_SubroutineCall_t(al, loc, call_sym, call_sym, args_w_first.p,
args_w_first.size(), nullptr));
}
} else {
body.push_back(al, x.m_body[i]);
}
}
current_scope = f->m_symtab;
Vec<ASR::stmt_t*> body;
body.reserve(al, x.n_body);
Vec<ASR::stmt_t*> body_asr;
body_asr.reserve(al, x.n_body);
if ( super_call_stmt ) {
body_asr.push_back(al, super_call_stmt);
}
Vec<ASR::symbol_t*> rts;
rts.reserve(al, 4);
dependencies.clear(al);
transform_stmts(body, new_body.n, new_body.p);
transform_stmts(body_asr, body.n, body.p);
for (const auto &rt: rt_vec) { rts.push_back(al, rt); }
f->m_body = body.p;
f->n_body = body.size();
f->m_body = body_asr.p;
f->n_body = body_asr.size();
ASR::FunctionType_t* func_type = ASR::down_cast<ASR::FunctionType_t>(
f->m_function_signature);
func_type->m_restrictions = rts.p;
Expand Down Expand Up @@ -6239,10 +6312,14 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
for( size_t i = 0; i < der_type->n_members && !member_found; i++ ) {
member_found = std::string(der_type->m_members[i]) == member_name;
}
if( !member_found ) {
if( !member_found && !der_type->m_parent ) {
throw SemanticError("No member " + member_name +
" found in " + std::string(der_type->m_name),
loc);
} else if ( !member_found && der_type->m_parent ) {
ASR::ttype_t* parent_type = ASRUtils::TYPE(ASRUtils::make_StructType_t_util(al, loc,der_type->m_parent));
visit_AttributeUtil(parent_type,attr_char,t,loc);
return;
}
ASR::expr_t *val = ASR::down_cast<ASR::expr_t>(ASR::make_Var_t(al, loc, t));
ASR::symbol_t* member_sym = der_type->m_symtab->resolve_symbol(member_name);
Expand Down Expand Up @@ -8064,7 +8141,8 @@ we will have to use something else.
//TODO: Correct Class and ClassType
// call to struct member function
// modifying args to pass the object as self
ASR::symbol_t* der = ASR::down_cast<ASR::StructType_t>(var->m_type)->m_derived_type;
ASR::symbol_t* der_sym = ASR::down_cast<ASR::StructType_t>(var->m_type)->m_derived_type;
ASR::Struct_t* der = ASR::down_cast<ASR::Struct_t>(der_sym);
Vec<ASR::call_arg_t> new_args; new_args.reserve(al, args.n + 1);
ASR::call_arg_t self_arg;
self_arg.loc = args[0].loc;
Expand All @@ -8073,7 +8151,20 @@ we will have to use something else.
for (size_t i=0; i<args.n; i++) {
new_args.push_back(al, args[i]);
}
st = get_struct_member(der, call_name, loc);
if ( der->m_symtab->get_symbol(call_name) ) {
st = get_struct_member(der_sym, call_name, loc);
} else if ( der->m_parent ) {
ASR::Struct_t* parent = ASR::down_cast<ASR::Struct_t>(der->m_parent);
if ( !parent->m_symtab->get_symbol(call_name) ) {
throw SemanticError("Method not found in the class "+ std::string(der->m_name) +
" or it's parents",loc);
} else {
st = get_struct_member(der->m_parent, call_name, loc);
}
} else {
throw SemanticError("Method not found in the class "+std::string(der->m_name)+
" or it's parents",loc);
}
tmp = make_call_helper(al, st, current_scope, new_args, call_name, loc);
return;
} else {
Expand Down
2 changes: 1 addition & 1 deletion tests/reference/asr-structs_09-f3ffe08.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@
"stdout": null,
"stdout_hash": null,
"stderr": "asr-structs_09-f3ffe08.stderr",
"stderr_hash": "f59ab2d213f6423e0a891e43d5a19e83d4405391b1c7bf481b4b939e",
"stderr_hash": "14119a0bc6420ad242b99395d457f2092014d96d2a1ac81d376c649d",
"returncode": 2
}
Loading

0 comments on commit 5a66456

Please sign in to comment.