Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Support for Symbolic func attribute #2367

Merged
merged 10 commits into from
Oct 17, 2023
48 changes: 42 additions & 6 deletions integration_tests/symbolics_02.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,71 @@
from sympy import Symbol, pi
from sympy import Symbol, pi, Add, Mul, Pow
from lpython import S

def test_symbolic_operations():
x: S = Symbol('x')
y: S = Symbol('y')
p1: S = pi
p2: S = pi
pi1: S = pi
pi2: S = pi

# Addition
z: S = x + y
z1: bool = z.func == Add
z2: bool = z.func == Mul
assert(z == x + y)
assert(z1 == True)
assert(z2 == False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In addition to this, let's also add tests like this:

if z.func == Add:
    assert True
else:
    assert False

as well as:

assert z.func == Add

To ensure that the z.func == Add idiom works in all the situations:

  • assigning to a bool variable (already tested)
  • in if
  • in assert

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should still be addressed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test assert z.func == Add ? I don't see it, unless I missed it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will remember this and add it in a subsequent pr

if z.func == Add:
assert True
else:
assert False
print(z)

# Subtraction
w: S = x - y
w1: bool = w.func == Add
assert(w == x - y)
assert(w1 == True)
if w.func == Add:
assert True
else:
assert False
print(w)

# Multiplication
u: S = x * y
u1: bool = u.func == Mul
assert(u == x * y)
assert(u1 == True)
if u.func == Mul:
assert True
else:
assert False
print(u)

# Division
v: S = x / y
v1: bool = v.func == Mul
assert(v == x / y)
assert(v1 == True)
if v.func == Mul:
assert True
else:
assert False
print(v)

# Power
p: S = x ** y
p1: bool = p.func == Pow
p2: bool = p.func == Add
p3: bool = p.func == Mul
assert(p == x ** y)
assert(p1 == True)
assert(p2 == False)
assert(p3 == False)
if p.func == Pow:
assert True
else:
assert False
print(p)

# Casting
Expand All @@ -40,13 +76,13 @@ def test_symbolic_operations():
print(c)

# Comparison
b1: bool = p1 == p2
b1: bool = pi1 == pi2
print(b1)
assert(b1 == True)
b2: bool = p1 != pi
b2: bool = pi1 != pi
print(b2)
assert(b2 == False)
b3: bool = p1 != x
b3: bool = pi1 != x
print(b3)
assert(b3 == True)
b4: bool = pi == Symbol("x")
Expand Down
63 changes: 63 additions & 0 deletions src/libasr/pass/intrinsic_function_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ enum class IntrinsicScalarFunctions : int64_t {
SymbolicExp,
SymbolicAbs,
SymbolicHasSymbolQ,
SymbolicAddQ,
SymbolicMulQ,
SymbolicPowQ,
// ...
};

Expand Down Expand Up @@ -140,6 +143,9 @@ inline std::string get_intrinsic_name(int x) {
INTRINSIC_NAME_CASE(SymbolicExp)
INTRINSIC_NAME_CASE(SymbolicAbs)
INTRINSIC_NAME_CASE(SymbolicHasSymbolQ)
INTRINSIC_NAME_CASE(SymbolicAddQ)
INTRINSIC_NAME_CASE(SymbolicMulQ)
INTRINSIC_NAME_CASE(SymbolicPowQ)
default : {
throw LCompilersException("pickle: intrinsic_id not implemented");
}
Expand Down Expand Up @@ -3100,6 +3106,48 @@ namespace SymbolicHasSymbolQ {
}
} // namespace SymbolicHasSymbolQ

#define create_symbolic_query_macro(X) \
namespace X { \
static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, \
diag::Diagnostics& diagnostics) { \
const Location& loc = x.base.base.loc; \
ASRUtils::require_impl(x.n_args == 1, \
#X " must have exactly 1 input argument", loc, diagnostics); \
\
ASR::ttype_t* input_type = ASRUtils::expr_type(x.m_args[0]); \
ASRUtils::require_impl(ASR::is_a<ASR::SymbolicExpression_t>(*input_type), \
#X " expects an argument of type SymbolicExpression", loc, diagnostics); \
} \
\
static inline ASR::expr_t* eval_##X(Allocator &/*al*/, const Location &/*loc*/, \
ASR::ttype_t *, Vec<ASR::expr_t*> &/*args*/) { \
/*TODO*/ \
return nullptr; \
} \
\
static inline ASR::asr_t* create_##X(Allocator& al, const Location& loc, \
Vec<ASR::expr_t*>& args, \
const std::function<void (const std::string &, const Location &)> err) { \
if (args.size() != 1) { \
err("Intrinsic " #X " function accepts exactly 1 argument", loc); \
} \
\
ASR::ttype_t* argtype = ASRUtils::expr_type(args[0]); \
if (!ASR::is_a<ASR::SymbolicExpression_t>(*argtype)) { \
err("Argument of " #X " function must be of type SymbolicExpression", \
args[0]->base.loc); \
} \
\
return UnaryIntrinsicFunction::create_UnaryFunction(al, loc, args, eval_##X, \
static_cast<int64_t>(IntrinsicScalarFunctions::X), 0, logical); \
} \
} // namespace X

create_symbolic_query_macro(SymbolicAddQ)
create_symbolic_query_macro(SymbolicMulQ)
create_symbolic_query_macro(SymbolicPowQ)


#define create_symbolic_unary_macro(X) \
namespace X { \
static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, \
Expand Down Expand Up @@ -3253,6 +3301,12 @@ namespace IntrinsicScalarFunctionRegistry {
{nullptr, &SymbolicAbs::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicHasSymbolQ),
{nullptr, &SymbolicHasSymbolQ::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicAddQ),
{nullptr, &SymbolicAddQ::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicMulQ),
{nullptr, &SymbolicMulQ::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicPowQ),
{nullptr, &SymbolicPowQ::verify_args}},
};

static const std::map<int64_t, std::string>& intrinsic_function_id_to_name = {
Expand Down Expand Up @@ -3357,6 +3411,12 @@ namespace IntrinsicScalarFunctionRegistry {
"SymbolicAbs"},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicHasSymbolQ),
"SymbolicHasSymbolQ"},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicAddQ),
"SymbolicAddQ"},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicMulQ),
"SymbolicMulQ"},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicPowQ),
"SymbolicPowQ"},
};


Expand Down Expand Up @@ -3412,6 +3472,9 @@ namespace IntrinsicScalarFunctionRegistry {
{"SymbolicExp", {&SymbolicExp::create_SymbolicExp, &SymbolicExp::eval_SymbolicExp}},
{"SymbolicAbs", {&SymbolicAbs::create_SymbolicAbs, &SymbolicAbs::eval_SymbolicAbs}},
{"has", {&SymbolicHasSymbolQ::create_SymbolicHasSymbolQ, &SymbolicHasSymbolQ::eval_SymbolicHasSymbolQ}},
{"AddQ", {&SymbolicAddQ::create_SymbolicAddQ, &SymbolicAddQ::eval_SymbolicAddQ}},
{"MulQ", {&SymbolicMulQ::create_SymbolicMulQ, &SymbolicMulQ::eval_SymbolicMulQ}},
{"PowQ", {&SymbolicPowQ::create_SymbolicPowQ, &SymbolicPowQ::eval_SymbolicPowQ}},
};

static inline bool is_intrinsic_function(const std::string& name) {
Expand Down
109 changes: 108 additions & 1 deletion src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,45 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
return module_scope->get_symbol(name);
}

ASR::symbol_t* declare_basic_get_type_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
std::string name = "basic_get_type";
symbolic_dependencies.push_back(name);
if (!module_scope->get_symbol(name)) {
std::string header = "symengine/cwrapper.h";
SymbolTable* fn_symtab = al.make_new<SymbolTable>(module_scope);

Vec<ASR::expr_t*> args;
args.reserve(al, 1);
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false));
fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1);
ASR::symbol_t* arg2 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg2);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2)));

Vec<ASR::stmt_t*> body;
body.reserve(al, 1);

Vec<char*> dep;
dep.reserve(al, 1);

ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable")));
ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc,
fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n,
return_var, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header));
ASR::symbol_t* symbol = ASR::down_cast<ASR::symbol_t>(subrout);
module_scope->add_symbol(s2c(al, name), symbol);
}
return module_scope->get_symbol(name);
}

ASR::symbol_t* declare_basic_eq_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
std::string name = "basic_eq";
symbolic_dependencies.push_back(name);
Expand Down Expand Up @@ -828,6 +867,60 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr, nullptr));
break;
}
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicAddQ: {
ASR::symbol_t* basic_get_type_sym = declare_basic_get_type_function(al, loc, module_scope);
ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]);
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = loc;
call_arg.m_value = value1;
call_args.push_back(al, call_arg);
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc,
basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n,
ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr));
// Using 16 as the right value of the IntegerCompare node as it represents SYMENGINE_ADD through SYMENGINE_ENUM
return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 16, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))),
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr));
break;
}
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicMulQ: {
ASR::symbol_t* basic_get_type_sym = declare_basic_get_type_function(al, loc, module_scope);
ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]);
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = loc;
call_arg.m_value = value1;
call_args.push_back(al, call_arg);
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc,
basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n,
ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr));
// Using 15 as the right value of the IntegerCompare node as it represents SYMENGINE_MUL through SYMENGINE_ENUM
return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 15, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))),
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr));
break;
}
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicPowQ: {
ASR::symbol_t* basic_get_type_sym = declare_basic_get_type_function(al, loc, module_scope);
ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]);
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = loc;
call_arg.m_value = value1;
call_args.push_back(al, call_arg);
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc,
basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n,
ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr));
// Using 17 as the right value of the IntegerCompare node as it represents SYMENGINE_POW through SYMENGINE_ENUM
return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 17, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))),
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr));
break;
}
default: {
throw LCompilersException("IntrinsicFunction: `"
+ ASRUtils::get_intrinsic_name(intrinsic_id)
Expand Down Expand Up @@ -998,6 +1091,20 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
}
}

void visit_If(const ASR::If_t& x) {
ASR::If_t& xx = const_cast<ASR::If_t&>(x);
transform_stmts(xx.m_body, xx.n_body);
transform_stmts(xx.m_orelse, xx.n_orelse);
SymbolTable* module_scope = current_scope->parent;
if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*xx.m_test)) {
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(xx.m_test);
if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) {
ASR::expr_t* function_call = process_attributes(al, xx.base.base.loc, xx.m_test, module_scope);
xx.m_test = function_call;
}
}
}

void visit_SubroutineCall(const ASR::SubroutineCall_t &x) {
SymbolTable* module_scope = current_scope->parent;
Vec<ASR::call_arg_t> call_args;
Expand Down Expand Up @@ -1298,7 +1405,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi

ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg));
pass_result.push_back(al, assert_stmt);
} else if(ASR::is_a<ASR::SymbolicCompare_t>(*x.m_test)) {
} else if (ASR::is_a<ASR::SymbolicCompare_t>(*x.m_test)) {
ASR::SymbolicCompare_t *s = ASR::down_cast<ASR::SymbolicCompare_t>(x.m_test);
SymbolTable* module_scope = current_scope->parent;
ASR::expr_t* left_tmp = nullptr;
Expand Down
Loading