Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Support for Symbolic func attribute #2367

Merged
merged 10 commits into from
Oct 17, 2023
28 changes: 22 additions & 6 deletions integration_tests/symbolics_02.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,51 @@
from sympy import Symbol, pi
from sympy import Symbol, pi, Add, Mul, Pow
from lpython import S

def test_symbolic_operations():
x: S = Symbol('x')
y: S = Symbol('y')
p1: S = pi
p2: S = pi
pi1: S = pi
pi2: S = pi

# Addition
z: S = x + y
z1: bool = z.func == Add
z2: bool = z.func == Mul
assert(z == x + y)
assert(z1 == True)
assert(z2 == False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In addition to this, let's also add tests like this:

if z.func == Add:
    assert True
else:
    assert False

as well as:

assert z.func == Add

To ensure that the z.func == Add idiom works in all the situations:

  • assigning to a bool variable (already tested)
  • in if
  • in assert

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should still be addressed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test assert z.func == Add ? I don't see it, unless I missed it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will remember this and add it in a subsequent pr

print(z)

# Subtraction
w: S = x - y
w1: bool = w.func == Add
assert(w == x - y)
assert(w1 == True)
print(w)

# Multiplication
u: S = x * y
u1: bool = u.func == Mul
assert(u == x * y)
assert(u1 == True)
print(u)

# Division
v: S = x / y
v1: bool = v.func == Mul
assert(v == x / y)
assert(v1 == True)
print(v)

# Power
p: S = x ** y
p1: bool = p.func == Pow
p2: bool = p.func == Add
p3: bool = p.func == Mul
assert(p == x ** y)
assert(p1 == True)
assert(p2 == False)
assert(p3 == False)
print(p)

# Casting
Expand All @@ -40,13 +56,13 @@ def test_symbolic_operations():
print(c)

# Comparison
b1: bool = p1 == p2
b1: bool = pi1 == pi2
print(b1)
assert(b1 == True)
b2: bool = p1 != pi
b2: bool = pi1 != pi
print(b2)
assert(b2 == False)
b3: bool = p1 != x
b3: bool = pi1 != x
print(b3)
assert(b3 == True)
b4: bool = pi == Symbol("x")
Expand Down
63 changes: 63 additions & 0 deletions src/libasr/pass/intrinsic_function_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ enum class IntrinsicScalarFunctions : int64_t {
SymbolicExp,
SymbolicAbs,
SymbolicHasSymbolQ,
SymbolicAddQ,
SymbolicMulQ,
SymbolicPowQ,
// ...
};

Expand Down Expand Up @@ -140,6 +143,9 @@ inline std::string get_intrinsic_name(int x) {
INTRINSIC_NAME_CASE(SymbolicExp)
INTRINSIC_NAME_CASE(SymbolicAbs)
INTRINSIC_NAME_CASE(SymbolicHasSymbolQ)
INTRINSIC_NAME_CASE(SymbolicAddQ)
INTRINSIC_NAME_CASE(SymbolicMulQ)
INTRINSIC_NAME_CASE(SymbolicPowQ)
default : {
throw LCompilersException("pickle: intrinsic_id not implemented");
}
Expand Down Expand Up @@ -3100,6 +3106,48 @@ namespace SymbolicHasSymbolQ {
}
} // namespace SymbolicHasSymbolQ

#define create_symbolic_query_macro(X) \
namespace X { \
static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, \
diag::Diagnostics& diagnostics) { \
const Location& loc = x.base.base.loc; \
ASRUtils::require_impl(x.n_args == 1, \
#X " must have exactly 1 input argument", loc, diagnostics); \
\
ASR::ttype_t* input_type = ASRUtils::expr_type(x.m_args[0]); \
ASRUtils::require_impl(ASR::is_a<ASR::SymbolicExpression_t>(*input_type), \
#X " expects an argument of type SymbolicExpression", loc, diagnostics); \
} \
\
static inline ASR::expr_t* eval_##X(Allocator &/*al*/, const Location &/*loc*/, \
ASR::ttype_t *, Vec<ASR::expr_t*> &/*args*/) { \
/*TODO*/ \
return nullptr; \
} \
\
static inline ASR::asr_t* create_##X(Allocator& al, const Location& loc, \
Vec<ASR::expr_t*>& args, \
const std::function<void (const std::string &, const Location &)> err) { \
if (args.size() != 1) { \
err("Intrinsic " #X " function accepts exactly 1 argument", loc); \
} \
\
ASR::ttype_t* argtype = ASRUtils::expr_type(args[0]); \
if (!ASR::is_a<ASR::SymbolicExpression_t>(*argtype)) { \
err("Argument of " #X " function must be of type SymbolicExpression", \
args[0]->base.loc); \
} \
\
return UnaryIntrinsicFunction::create_UnaryFunction(al, loc, args, eval_##X, \
static_cast<int64_t>(IntrinsicScalarFunctions::X), 0, logical); \
} \
} // namespace X

create_symbolic_query_macro(SymbolicAddQ)
create_symbolic_query_macro(SymbolicMulQ)
create_symbolic_query_macro(SymbolicPowQ)


#define create_symbolic_unary_macro(X) \
namespace X { \
static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, \
Expand Down Expand Up @@ -3253,6 +3301,12 @@ namespace IntrinsicScalarFunctionRegistry {
{nullptr, &SymbolicAbs::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicHasSymbolQ),
{nullptr, &SymbolicHasSymbolQ::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicAddQ),
{nullptr, &SymbolicAddQ::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicMulQ),
{nullptr, &SymbolicMulQ::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicPowQ),
{nullptr, &SymbolicPowQ::verify_args}},
};

static const std::map<int64_t, std::string>& intrinsic_function_id_to_name = {
Expand Down Expand Up @@ -3357,6 +3411,12 @@ namespace IntrinsicScalarFunctionRegistry {
"SymbolicAbs"},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicHasSymbolQ),
"SymbolicHasSymbolQ"},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicAddQ),
"SymbolicAddQ"},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicMulQ),
"SymbolicMulQ"},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicPowQ),
"SymbolicPowQ"},
};


Expand Down Expand Up @@ -3412,6 +3472,9 @@ namespace IntrinsicScalarFunctionRegistry {
{"SymbolicExp", {&SymbolicExp::create_SymbolicExp, &SymbolicExp::eval_SymbolicExp}},
{"SymbolicAbs", {&SymbolicAbs::create_SymbolicAbs, &SymbolicAbs::eval_SymbolicAbs}},
{"has", {&SymbolicHasSymbolQ::create_SymbolicHasSymbolQ, &SymbolicHasSymbolQ::eval_SymbolicHasSymbolQ}},
{"is_Add", {&SymbolicAddQ::create_SymbolicAddQ, &SymbolicAddQ::eval_SymbolicAddQ}},
{"is_Mul", {&SymbolicMulQ::create_SymbolicMulQ, &SymbolicMulQ::eval_SymbolicMulQ}},
{"is_Pow", {&SymbolicPowQ::create_SymbolicPowQ, &SymbolicPowQ::eval_SymbolicPowQ}},
anutosh491 marked this conversation as resolved.
Show resolved Hide resolved
};

static inline bool is_intrinsic_function(const std::string& name) {
Expand Down
95 changes: 94 additions & 1 deletion src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,45 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
return module_scope->get_symbol(name);
}

ASR::symbol_t* declare_basic_get_type_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
std::string name = "basic_get_type";
symbolic_dependencies.push_back(name);
if (!module_scope->get_symbol(name)) {
std::string header = "symengine/cwrapper.h";
SymbolTable* fn_symtab = al.make_new<SymbolTable>(module_scope);

Vec<ASR::expr_t*> args;
args.reserve(al, 1);
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false));
fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1);
ASR::symbol_t* arg2 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg2);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2)));

Vec<ASR::stmt_t*> body;
body.reserve(al, 1);

Vec<char*> dep;
dep.reserve(al, 1);

ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable")));
ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc,
fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n,
return_var, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header));
ASR::symbol_t* symbol = ASR::down_cast<ASR::symbol_t>(subrout);
module_scope->add_symbol(s2c(al, name), symbol);
}
return module_scope->get_symbol(name);
}

ASR::symbol_t* declare_basic_eq_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
std::string name = "basic_eq";
symbolic_dependencies.push_back(name);
Expand Down Expand Up @@ -828,6 +867,60 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr, nullptr));
break;
}
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicAddQ: {
ASR::symbol_t* basic_get_type_sym = declare_basic_get_type_function(al, loc, module_scope);
ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]);
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = loc;
call_arg.m_value = value1;
call_args.push_back(al, call_arg);
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc,
basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n,
ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr));
// Using 16 as the right value of the IntegerCompare node as it represents SYMENGINE_ADD through SYMENGINE_ENUM
return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 16, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))),
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr));
break;
}
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicMulQ: {
ASR::symbol_t* basic_get_type_sym = declare_basic_get_type_function(al, loc, module_scope);
ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]);
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = loc;
call_arg.m_value = value1;
call_args.push_back(al, call_arg);
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc,
basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n,
ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr));
// Using 15 as the right value of the IntegerCompare node as it represents SYMENGINE_MUL through SYMENGINE_ENUM
return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 15, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))),
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr));
break;
}
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicPowQ: {
ASR::symbol_t* basic_get_type_sym = declare_basic_get_type_function(al, loc, module_scope);
ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]);
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = loc;
call_arg.m_value = value1;
call_args.push_back(al, call_arg);
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc,
basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n,
ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr));
// Using 17 as the right value of the IntegerCompare node as it represents SYMENGINE_POW through SYMENGINE_ENUM
return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 17, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))),
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr));
break;
}
default: {
throw LCompilersException("IntrinsicFunction: `"
+ ASRUtils::get_intrinsic_name(intrinsic_id)
Expand Down Expand Up @@ -1298,7 +1391,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi

ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg));
pass_result.push_back(al, assert_stmt);
} else if(ASR::is_a<ASR::SymbolicCompare_t>(*x.m_test)) {
} else if (ASR::is_a<ASR::SymbolicCompare_t>(*x.m_test)) {
ASR::SymbolicCompare_t *s = ASR::down_cast<ASR::SymbolicCompare_t>(x.m_test);
SymbolTable* module_scope = current_scope->parent;
ASR::expr_t* left_tmp = nullptr;
Expand Down
43 changes: 41 additions & 2 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4605,6 +4605,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
public:
ASR::asr_t *asr;
std::vector<ASR::symbol_t*> do_loop_variables;
bool using_func_attr = false;

BodyVisitor(Allocator &al, LocationManager &lm, ASR::asr_t *unit, diag::Diagnostics &diagnostics,
bool main_module, std::string module_name, std::map<int, ASR::symbol_t*> &ast_overload,
Expand Down Expand Up @@ -5803,6 +5804,16 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
} else if(ASR::is_a<ASR::Pointer_t>(*type)) {
ASR::Pointer_t* p = ASR::down_cast<ASR::Pointer_t>(type);
visit_AttributeUtil(p->m_type, attr_char, t, loc);
} else if(ASR::is_a<ASR::SymbolicExpression_t>(*type)) {
std::string attr = attr_char;
if (attr == "func") {
using_func_attr = true;
return;
}
ASR::expr_t *se = ASR::down_cast<ASR::expr_t>(ASR::make_Var_t(al, loc, t));
Vec<ASR::expr_t*> args;
args.reserve(al, 0);
handle_symbolic_attribute(se, attr, loc, args);
} else {
throw SemanticError(ASRUtils::type_to_str_python(type) + " not supported yet in Attribute.",
loc);
Expand Down Expand Up @@ -5996,8 +6007,6 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
}

void visit_Compare(const AST::Compare_t &x) {
this->visit_expr(*x.m_left);
ASR::expr_t *left = ASRUtils::EXPR(tmp);
if (x.n_comparators > 1) {
diag.add(diag::Diagnostic(
"Only one comparison operator is supported for now",
Expand All @@ -6008,6 +6017,36 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
);
throw SemanticAbort();
}
this->visit_expr(*x.m_left);
if (using_func_attr) {
if (AST::is_a<AST::Attribute_t>(*x.m_left) && AST::is_a<AST::Name_t>(*x.m_comparators[0])) {
AST::Attribute_t *attr = AST::down_cast<AST::Attribute_t>(x.m_left);
AST::Name_t *type_name = AST::down_cast<AST::Name_t>(x.m_comparators[0]);
std::string symbolic_type = type_name->m_id;
if (AST::is_a<AST::Name_t>(*attr->m_value)) {
AST::Name_t *var_name = AST::down_cast<AST::Name_t>(attr->m_value);
std::string var = var_name->m_id;
ASR::symbol_t *st = current_scope->resolve_symbol(var);
ASR::expr_t *se = ASR::down_cast<ASR::expr_t>(
ASR::make_Var_t(al, x.base.base.loc, st));
Vec<ASR::expr_t*> args;
args.reserve(al, 0);
if (symbolic_type == "Add") {
handle_symbolic_attribute(se, "is_Add", x.base.base.loc, args);
return;
} else if (symbolic_type == "Mul") {
handle_symbolic_attribute(se, "is_Mul", x.base.base.loc, args);
return;
} else if (symbolic_type == "Pow") {
handle_symbolic_attribute(se, "is_Pow", x.base.base.loc, args);
return;
} else {
throw SemanticError(symbolic_type + " symbolic type not supported yet", x.base.base.loc);
}
}
}
}
ASR::expr_t *left = ASRUtils::EXPR(tmp);
this->visit_expr(*x.m_comparators[0]);
ASR::expr_t *right = ASRUtils::EXPR(tmp);

Expand Down
Loading