Skip to content

Commit

Permalink
Added support for Out[S] (#2419)
Browse files Browse the repository at this point in the history
  • Loading branch information
anutosh491 authored Nov 13, 2023
1 parent b34bcc1 commit b19dc21
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 5 deletions.
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,7 @@ RUN(NAME symbolics_09 LABELS cpython_sym c_sym llvm_sym NOFAST)
RUN(NAME symbolics_10 LABELS cpython_sym c_sym llvm_sym NOFAST)
RUN(NAME symbolics_11 LABELS cpython_sym c_sym NOFAST)
RUN(NAME symbolics_12 LABELS cpython_sym c_sym llvm_sym NOFAST)
RUN(NAME symbolics_13 LABELS cpython_sym c_sym NOFAST)

RUN(NAME sizeof_01 LABELS llvm c
EXTRAFILES sizeof_01b.c)
Expand Down
12 changes: 12 additions & 0 deletions integration_tests/symbolics_13.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from lpython import S
from sympy import pi, Symbol

def func() -> S:
return pi

def test_func():
z: S = func()
print(z)
assert z == pi

test_func()
9 changes: 5 additions & 4 deletions src/libasr/pass/pass_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ namespace LCompilers {

static inline bool is_aggregate_or_array_type(ASR::expr_t* var) {
return (ASR::is_a<ASR::Struct_t>(*ASRUtils::expr_type(var)) ||
ASRUtils::is_array(ASRUtils::expr_type(var)));
ASRUtils::is_array(ASRUtils::expr_type(var)) ||
ASR::is_a<ASR::SymbolicExpression_t>(*ASRUtils::expr_type(var)));
}

template <class Struct>
Expand Down Expand Up @@ -775,7 +776,7 @@ namespace LCompilers {
}

static inline void handle_fn_return_var(Allocator &al, ASR::Function_t *x,
bool (*is_array_or_struct)(ASR::expr_t*)) {
bool (*is_array_or_struct_or_symbolic)(ASR::expr_t*)) {
if (ASRUtils::get_FunctionType(x)->m_abi == ASR::abiType::BindPython) {
return;
}
Expand All @@ -787,7 +788,7 @@ namespace LCompilers {
* in avoiding deep copies and the destination memory directly gets
* filled inside the function.
*/
if( is_array_or_struct(x->m_return_var)) {
if( is_array_or_struct_or_symbolic(x->m_return_var)) {
for( auto& s_item: x->m_symtab->get_scope() ) {
ASR::symbol_t* curr_sym = s_item.second;
if( curr_sym->type == ASR::symbolType::Variable ) {
Expand Down Expand Up @@ -834,7 +835,7 @@ namespace LCompilers {
for (auto &item : x->m_symtab->get_scope()) {
if (ASR::is_a<ASR::Function_t>(*item.second)) {
handle_fn_return_var(al, ASR::down_cast<ASR::Function_t>(
item.second), is_array_or_struct);
item.second), is_array_or_struct_or_symbolic);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi

ASR::ttype_t *type1 = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc));
xx.m_type = type1;
if (var_name != "_lpython_return_variable") {
if (var_name != "_lpython_return_variable" && xx.m_intent != ASR::intentType::Out) {
symbolic_vars_to_free.insert(ASR::down_cast<ASR::symbol_t>((ASR::asr_t*)&xx));
}
if(xx.m_intent == ASR::intentType::In){
Expand Down

0 comments on commit b19dc21

Please sign in to comment.