Skip to content

Commit

Permalink
Merge pull request #91 from czgdp1807/nbody_1
Browse files Browse the repository at this point in the history
Add support for compile time evaluation of ``StructInstanceMember``
  • Loading branch information
czgdp1807 authored Feb 14, 2024
2 parents f73cf85 + f6af250 commit 6def44b
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 13 deletions.
22 changes: 11 additions & 11 deletions integration_tests/nbody.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ Source: https://benchmarksgame-team.pages.debian.net/benchmarksgame/program/nbod
#include "xtensor/xio.hpp"
#include "xtensor/xview.hpp"

const int nb = 5;
const double PI = 3.141592653589793;
const double SOLAR_MASS = 4 * PI * PI;
const int N = (nb - 1) * nb/2;
constexpr int nb = 5;
constexpr double PI = 3.141592653589793;
constexpr double SOLAR_MASS = 4 * PI * PI;
constexpr int N = (nb - 1) * nb/2;

void offset_momentum(const int k, xt::xtensor_fixed<double, xt::xshape<3, nb>>& v,
const xt::xtensor_fixed<double, xt::xshape<nb>>& mass) {
Expand Down Expand Up @@ -101,7 +101,7 @@ double energy(const xt::xtensor_fixed<double, xt::xshape<3, nb>>& x,
struct body {
double x, y, z, u, vx, vy, vz, vu, mass;

body(double x_, double y_, double z_, double u_,
constexpr body(double x_, double y_, double z_, double u_,
double vx_, double vy_, double vz_, double vu_,
double mass_) : x{x_}, y{y_}, z{z_}, u{u_}, vx{vx_},
vy{vy_}, vz{vz_}, vu{vu_}, mass{mass_} {
Expand All @@ -112,40 +112,40 @@ struct body {
int main() {

const double tstep = 0.01;
const double DAYS_PER_YEAR = 365.24;
constexpr double DAYS_PER_YEAR = 365.24;

const struct body jupiter = body(
constexpr struct body jupiter = body(
4.84143144246472090, -1.16032004402742839,
-1.03622044471123109e-01, 0.0, 1.66007664274403694e-03 * DAYS_PER_YEAR,
7.69901118419740425e-03 * DAYS_PER_YEAR,
-6.90460016972063023e-05 * DAYS_PER_YEAR, 0.0,
9.54791938424326609e-04 * SOLAR_MASS);

const struct body saturn = body(
constexpr struct body saturn = body(
8.34336671824457987, 4.12479856412430479,
-4.03523417114321381e-01, 0.0,
-2.76742510726862411e-03 * DAYS_PER_YEAR,
4.99852801234917238e-03 * DAYS_PER_YEAR,
2.30417297573763929e-05 * DAYS_PER_YEAR, 0.0,
2.85885980666130812e-04 * SOLAR_MASS);

const struct body uranus = body(
constexpr struct body uranus = body(
1.28943695621391310e+01, -1.51111514016986312e+01,
-2.23307578892655734e-01, 0.0,
2.96460137564761618e-03 * DAYS_PER_YEAR,
2.37847173959480950e-03 * DAYS_PER_YEAR,
-2.96589568540237556e-05 * DAYS_PER_YEAR, 0.0,
4.36624404335156298e-05 * SOLAR_MASS);

const struct body neptune = body(
constexpr struct body neptune = body(
1.53796971148509165e+01, -2.59193146099879641e+01,
1.79258772950371181e-01, 0.0,
2.68067772490389322e-03 * DAYS_PER_YEAR,
1.62824170038242295e-03 * DAYS_PER_YEAR,
-9.51592254519715870e-05 * DAYS_PER_YEAR, 0.0,
5.15138902046611451e-05 * SOLAR_MASS);

const struct body sun = body(0.0, 0.0, 0.0, 0.0, 0.0,
constexpr struct body sun = body(0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, SOLAR_MASS);

xt::xtensor_fixed<double, xt::xshape<nb>> mass = {
Expand Down
98 changes: 96 additions & 2 deletions src/lc/clang_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
Vec<ASR::stmt_t*>* default_stmt;
OneTimeUseBool is_break_stmt_present;
bool enable_fall_through;
std::map<ASR::symbol_t*, std::map<std::string, ASR::expr_t*>> struct2member_inits;

explicit ClangASTtoASRVisitor(clang::ASTContext *Context_,
Allocator& al_, ASR::asr_t*& tu_):
Expand Down Expand Up @@ -481,6 +482,26 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
}

bool TraverseCXXRecordDecl(clang::CXXRecordDecl* x) {
for( auto constructors = x->ctor_begin(); constructors != x->ctor_end(); constructors++ ) {
clang::CXXConstructorDecl* constructor = *constructors;
if( constructor->isTrivial() || constructor->isImplicit() ) {
continue ;
}
for( auto ctor = constructor->init_begin(); ctor != constructor->init_end(); ctor++ ) {
clang::CXXCtorInitializer* ctor_init = *ctor;
clang::Expr* init_expr = ctor_init->getInit();
if( init_expr->getStmtClass() == clang::Stmt::StmtClass::InitListExprClass ) {
init_expr = static_cast<clang::InitListExpr*>(init_expr)->getInit(0);
}
if( init_expr->getStmtClass() != clang::Stmt::StmtClass::ImplicitCastExprClass ||
static_cast<clang::ImplicitCastExpr*>(init_expr)->getSubExpr()->getStmtClass() !=
clang::Stmt::StmtClass::DeclRefExprClass ) {
throw std::runtime_error("Initialisation expression in constructor should "
"only be the argument itself.");
}
}
}

SymbolTable* parent_scope = current_scope;
current_scope = al.make_new<SymbolTable>(parent_scope);
std::string struct_name = x->getNameAsString();
Expand Down Expand Up @@ -576,6 +597,20 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
return true;
}

ASR::expr_t* evaluate_compile_time_value_for_StructInstanceMember(
ASR::expr_t* base, const std::string& member_name) {
if( ASR::is_a<ASR::Var_t>(*base) ) {
ASR::Var_t* var_t = ASR::down_cast<ASR::Var_t>(base);
ASR::symbol_t* v = ASRUtils::symbol_get_past_external(var_t->m_v);
if( struct2member_inits.find(v) == struct2member_inits.end() ) {
return nullptr;
}
return struct2member_inits[v][member_name];
}

return nullptr;
}

bool TraverseMemberExpr(clang::MemberExpr* x) {
TraverseStmt(x->getBase());
ASR::expr_t* base = ASRUtils::EXPR(tmp.get());
Expand All @@ -596,8 +631,8 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
struct_type_t->m_name, nullptr, 0, s2c(al, member_name),
ASR::accessType::Public));
current_scope->add_symbol(mangled_name, member);
tmp = ASR::make_StructInstanceMember_t(al, Lloc(x), base, member,
ASRUtils::symbol_type(member), nullptr);
tmp = ASR::make_StructInstanceMember_t(al, Lloc(x), base, member, ASRUtils::symbol_type(member),
evaluate_compile_time_value_for_StructInstanceMember(base, member_name));
} else if( special_function_map.find(member_name) != special_function_map.end() ) {
member_name_obj.set(member_name);
return clang::RecursiveASTVisitor<ClangASTtoASRVisitor>::TraverseMemberExpr(x);
Expand Down Expand Up @@ -1308,6 +1343,13 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
}

bool TraverseCXXTemporaryObjectExpr(clang::CXXTemporaryObjectExpr *x) {
if( !x->getConstructor()->isConstexpr() ) {
throw std::runtime_error("Constructors for user-define types "
"must be defined with constexpr.");
}
if( static_cast<clang::CompoundStmt*>(x->getConstructor()->getBody())->size() > 0 ) {
throw std::runtime_error("Constructor for user-defined must have empty body.");
}
std::string type_name = x->getConstructor()->getNameAsString();
ASR::symbol_t* s = current_scope->resolve_symbol(type_name);
if( s == nullptr ) {
Expand All @@ -1322,6 +1364,11 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
ASR::call_arg_t call_arg;
call_arg.loc = Lloc(x);
call_arg.m_value = ASRUtils::EXPR(tmp.get());
if( !ASRUtils::is_value_constant(ASRUtils::expr_value(call_arg.m_value)) ) {
throw std::runtime_error("Constructor for user-defined types "
"must be initialised with constant values, " + std::to_string(i) +
"-th argument is not a constant.");
}
ASR::ttype_t* orig_type = ASRUtils::symbol_type(
struct_type_t->m_symtab->resolve_symbol(struct_type_t->m_members[i]));
ASR::ttype_t* arg_type = ASRUtils::expr_type(call_arg.m_value);
Expand All @@ -1338,6 +1385,48 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
return true;
}

void TraverseAPValue(clang::APValue& field) {
Location loc;
loc.first = 1; loc.last = 1;
switch( field.getKind() ) {
case clang::APValue::Int: {
tmp = ASR::make_IntegerConstant_t(al, loc, field.getInt().getLimitedValue(),
ASRUtils::TYPE(ASR::make_Integer_t(al, loc, field.getInt().getBitWidth()/8)));
break;
}
case clang::APValue::Float: {
tmp = ASR::make_RealConstant_t(al, loc, field.getFloat().convertToDouble(),
ASRUtils::TYPE(ASR::make_Real_t(al, loc, 8)));
break;
}
default: {
throw std::runtime_error("APValue not supported for clang::APValue::" +
std::to_string(field.getKind()));
}
}
}

void evaluate_compile_time_value_for_Var(clang::APValue* ap_value, ASR::symbol_t* v) {
switch( ap_value->getKind() ) {
case clang::APValue::Struct: {
ASR::ttype_t* v_type = ASRUtils::type_get_past_const(ASRUtils::symbol_type(v));
if( !ASR::is_a<ASR::Struct_t>(*v_type) ) {
throw std::runtime_error("Expected ASR::Struct_t type found, " +
ASRUtils::type_to_str(v_type));
}
ASR::Struct_t* struct_t = ASR::down_cast<ASR::Struct_t>(v_type);
ASR::StructType_t* struct_type_t = ASR::down_cast<ASR::StructType_t>(
ASRUtils::symbol_get_past_external(struct_t->m_derived_type));
for( size_t i = 0; i < ap_value->getStructNumFields(); i++ ) {
clang::APValue& field = ap_value->getStructField(i);
TraverseAPValue(field);
struct2member_inits[v][struct_type_t->m_members[i]] = ASRUtils::EXPR(tmp.get());
}
break;
}
}
}

bool TraverseVarDecl(clang::VarDecl *x) {
std::string name = x->getName().str();
if( scopes.size() > 0 ) {
Expand Down Expand Up @@ -1385,6 +1474,11 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
is_stmt_created = true;
}
}

if( x->getEvaluatedValue() ) {
clang::APValue* ap_value = x->getEvaluatedValue();
evaluate_compile_time_value_for_Var(ap_value, v);
}
}
return true;
}
Expand Down

0 comments on commit 6def44b

Please sign in to comment.