Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Varied analysis to the reverse mode #1084

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/clad/Differentiator/CladConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ enum opts : unsigned {
// 00 - default, 01 - enable, 10 - disable, 11 - not used / invalid
enable_tbr = 1 << (ORDER_BITS + 2),
disable_tbr = 1 << (ORDER_BITS + 3),
enable_va = 1 << (ORDER_BITS + 5),
disable_aa = 1 << (ORDER_BITS + 6),

// Specifying whether we only want the diagonal of the hessian.
diagonal_only = 1 << (ORDER_BITS + 4),
Expand Down
11 changes: 11 additions & 0 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include "clad/Differentiator/DynamicGraph.h"
#include "clad/Differentiator/ParseDiffArgsTypes.h"

#include <iterator>
#include <set>
namespace clang {
class CallExpr;
class CompilerInstance;
Expand All @@ -31,6 +33,11 @@ struct DiffRequest {
bool HasAnalysisRun = false;
} m_TbrRunInfo;

mutable struct ActivityRunInfo {
std::set<const clang::VarDecl*> ToBeRecorded;
bool HasAnalysisRun = false;
} m_ActivityRunInfo;

public:
/// Function to be differentiated.
const clang::FunctionDecl* Function = nullptr;
Expand All @@ -55,6 +62,7 @@ struct DiffRequest {
bool VerboseDiags = false;
/// A flag to enable TBR analysis during reverse-mode differentiation.
bool EnableTBRAnalysis = false;
bool EnableVariedAnalysis = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: member variable 'EnableVariedAnalysis' has public visibility [cppcoreguidelines-non-private-member-variables-in-classes]

  bool EnableVariedAnalysis = false;
       ^

/// Puts the derived function and its code in the diff call
void updateCall(clang::FunctionDecl* FD, clang::FunctionDecl* OverloadedFD,
clang::Sema& SemaRef);
Expand Down Expand Up @@ -112,6 +120,7 @@ struct DiffRequest {
RequestedDerivativeOrder == other.RequestedDerivativeOrder &&
CallContext == other.CallContext && Args == other.Args &&
Mode == other.Mode && EnableTBRAnalysis == other.EnableTBRAnalysis &&
EnableVariedAnalysis == other.EnableVariedAnalysis &&
DVI == other.DVI && use_enzyme == other.use_enzyme &&
DeclarationOnly == other.DeclarationOnly;
}
Expand All @@ -129,6 +138,7 @@ struct DiffRequest {
}

bool shouldBeRecorded(clang::Expr* E) const;
bool shouldHaveAdjoint(const clang::VarDecl* VD) const;
};

using DiffInterval = std::vector<clang::SourceRange>;
Expand All @@ -137,6 +147,7 @@ struct DiffRequest {
/// This is a flag to indicate the default behaviour to enable/disable
/// TBR analysis during reverse-mode differentiation.
bool EnableTBRAnalysis = false;
bool EnableVariedAnalysis = false;
};

class DiffCollector: public clang::RecursiveASTVisitor<DiffCollector> {
Expand Down
179 changes: 179 additions & 0 deletions lib/Differentiator/ActivityAnalyzer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
#include "ActivityAnalyzer.h"

using namespace clang;

namespace clad {

void VariedAnalyzer::Analyze(const FunctionDecl* FD) {
// Build the CFG (control-flow graph) of FD.
clang::CFG::BuildOptions Options;
m_CFG = clang::CFG::buildCFG(FD, FD->getBody(), &m_Context, Options);

m_BlockData.resize(m_CFG->size());
// Set current block ID to the ID of entry the block.
CFGBlock* entry = &m_CFG->getEntry();
m_CurBlockID = entry->getBlockID();
m_BlockData[m_CurBlockID] = createNewVarsData({});
for (const VarDecl* i : m_VariedDecls)
m_BlockData[m_CurBlockID]->insert(i);
// Add the entry block to the queue.
m_CFGQueue.insert(m_CurBlockID);

// Visit CFG blocks in the queue until it's empty.
while (!m_CFGQueue.empty()) {
auto IDIter = std::prev(m_CFGQueue.end());
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
m_CurBlockID = *IDIter;
m_CFGQueue.erase(IDIter);
CFGBlock& nextBlock = *getCFGBlockByID(m_CurBlockID);
AnalyzeCFGBlock(nextBlock);
}
}

void mergeVarsData(VarsData* targetData, VarsData* mergeData) {
for (const clang::VarDecl* i : *mergeData)
targetData->insert(i);
for (const clang::VarDecl* i : *targetData)
mergeData->insert(i);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this function will make targetData and mergeData identical. Do we need them both to be modified? Even if we do, we can probably do *mergedata = *targetdata instead of the last two lines.


CFGBlock* VariedAnalyzer::getCFGBlockByID(unsigned ID) {
return *(m_CFG->begin() + ID);
}

void VariedAnalyzer::AnalyzeCFGBlock(const CFGBlock& block) {
// Visit all the statements inside the block.
for (const clang::CFGElement& Element : block) {
if (Element.getKind() == clang::CFGElement::Statement) {
const clang::Stmt* S = Element.castAs<clang::CFGStmt>().getStmt();
// The const_cast is inevitable, since there is no
// ConstRecusiveASTVisitor.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
TraverseStmt(const_cast<clang::Stmt*>(S));
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
}
}

for (const clang::CFGBlock::AdjacentBlock succ : block.succs()) {
if (!succ)
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
continue;
auto& succData = m_BlockData[succ->getBlockID()];

if (!succData)
succData = createNewVarsData(*m_BlockData[block.getBlockID()]);

bool shouldPushSucc = true;
if (succ->getBlockID() > block.getBlockID()) {
if (m_LoopMem == *m_BlockData[block.getBlockID()])
shouldPushSucc = false;

for (const VarDecl* i : *m_BlockData[block.getBlockID()])
m_LoopMem.insert(i);
}

if (shouldPushSucc)
m_CFGQueue.insert(succ->getBlockID());

mergeVarsData(succData.get(), m_BlockData[block.getBlockID()].get());
}
// FIXME: Information about the varied variables is stored in the last block,
// so we should be able to get it form there
for (const VarDecl* i : *m_BlockData[block.getBlockID()])
m_VariedDecls.insert(i);
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
}

bool VariedAnalyzer::isVaried(const VarDecl* VD) const {
const VarsData& curBranch = getCurBlockVarsData();
return curBranch.find(VD) != curBranch.end();
}

void VariedAnalyzer::copyVarToCurBlock(const clang::VarDecl* VD) {
VarsData& curBranch = getCurBlockVarsData();
curBranch.insert(VD);
}

bool VariedAnalyzer::VisitBinaryOperator(BinaryOperator* BinOp) {
Expr* L = BinOp->getLHS();
Expr* R = BinOp->getRHS();
const auto opCode = BinOp->getOpcode();
if (BinOp->isAssignmentOp()) {
m_Varied = false;
TraverseStmt(R);
m_Marking = m_Varied;
TraverseStmt(L);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: cannot initialize object parameter of type 'clang::RecursiveASTVisitorclad::VariedAnalyzer' with an expression of type 'clad::VariedAnalyzer' [clang-diagnostic-error]

    TraverseStmt(L);
    ^

m_Marking = false;
} else if (opCode == BO_Add || opCode == BO_Sub || opCode == BO_Mul ||
opCode == BO_Div) {
for (auto* subexpr : BinOp->children())
if (!isa<BinaryOperator>(subexpr))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we visit binary operators recursively? Doesn't this mean we will ignore a part of x + y + z for example? Since it's treated as a nested addition operation in the AST.

Copy link
Collaborator Author

@ovdiiuv ovdiiuv Oct 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #1106

TraverseStmt(subexpr);
}
return true;
}

// add branching merging
bool VariedAnalyzer::VisitConditionalOperator(ConditionalOperator* CO) {
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
TraverseStmt(CO->getCond());
TraverseStmt(CO->getTrueExpr());
ovdiiuv marked this conversation as resolved.
Show resolved Hide resolved
TraverseStmt(CO->getFalseExpr());
ovdiiuv marked this conversation as resolved.
Show resolved Hide resolved
return true;
}

bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) {
FunctionDecl* FD = CE->getDirectCallee();
bool noHiddenParam = (CE->getNumArgs() == FD->getNumParams());
if (noHiddenParam) {
MutableArrayRef<ParmVarDecl*> FDparam = FD->parameters();
for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) {
clang::Expr* par = CE->getArg(i);
TraverseStmt(par);
m_VariedDecls.insert(FDparam[i]);
}
}
m_Varied = true;
return true;
}

bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) {
for (Decl* D : DS->decls()) {
if (!isa<VarDecl>(D))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A test here could be class C{} c; that produces a DeclStmt whose Stmt is not a VarDecl.

continue;
if (Expr* init = cast<VarDecl>(D)->getInit()) {
m_Varied = false;
TraverseStmt(init);
m_Marking = true;
if (m_Varied)
copyVarToCurBlock(cast<VarDecl>(D));
m_Marking = false;
}
}
return true;
}

bool VariedAnalyzer::VisitUnaryOperator(UnaryOperator* UnOp) {
const auto opCode = UnOp->getOpcode();
Expr* E = UnOp->getSubExpr();
if (opCode == UO_AddrOf || opCode == UO_Deref) {
m_Varied = true;
m_Marking = true;
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to set these modes to true here? And even if so, shouldn't we also set m_Marking to false afterward?

TraverseStmt(E);
return true;
}

bool VariedAnalyzer::VisitInitListExpr(InitListExpr* ILE) { return true; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: method 'VisitInitListExpr' can be made static [readability-convert-member-functions-to-static]

lib/Differentiator/ActivityAnalyzer.h:81:

-   bool VisitInitListExpr(clang::InitListExpr* ILE);
+   static bool VisitInitListExpr(clang::InitListExpr* ILE);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: method 'VisitInitListExpr' can be made static [readability-convert-member-functions-to-static]

lib/Differentiator/ActivityAnalyzer.h:79:

-   bool VisitInitListExpr(clang::InitListExpr* ILE);
+   static bool VisitInitListExpr(clang::InitListExpr* ILE);

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ping.


bool VariedAnalyzer::VisitDeclRefExpr(DeclRefExpr* DRE) {
if (isVaried(dyn_cast<VarDecl>(DRE->getDecl())))
m_Varied = true;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this portion of the code repeat what comes next? Perhaps we should keep one of these.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes


auto* VD = dyn_cast<VarDecl>(DRE->getDecl());
if (!VD)
return true;

if (isVaried(VD))
m_Varied = true;

if (m_Varied && m_Marking)
copyVarToCurBlock(VD);
return true;
}
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
} // namespace clad
83 changes: 83 additions & 0 deletions lib/Differentiator/ActivityAnalyzer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#ifndef CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H
#define CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: header guard does not follow preferred style [llvm-header-guard]

Suggested change
#define CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H
#ifndef GITHUB_WORKSPACE_LIB_DIFFERENTIATOR_ACTIVITYANALYZER_H
#define GITHUB_WORKSPACE_LIB_DIFFERENTIATOR_ACTIVITYANALYZER_H

lib/Differentiator/ActivityAnalyzer.h:84:

- #endif // CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H
+ #endif // GITHUB_WORKSPACE_LIB_DIFFERENTIATOR_ACTIVITYANALYZER_H

vgvassilev marked this conversation as resolved.
Show resolved Hide resolved

#include "clang/AST/RecursiveASTVisitor.h"
ovdiiuv marked this conversation as resolved.
Show resolved Hide resolved
#include "clang/Analysis/CFG.h"

#include "clad/Differentiator/CladUtils.h"
#include "clad/Differentiator/Compatibility.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <set>
#include <utility>
/// Class that implemets Varied part of the Activity analysis.
/// By performing static data-flow analysis, so called Varied variables
/// are determined, meaning variables that depend on input parameters
/// in a differentiable way. That result enables us to remove redundant
/// statements in the reverse mode, improving generated codes efficiency.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The documentation should go to the class definition.

namespace clad {
using VarsData = std::set<const clang::VarDecl*>;
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can go in the class definition.

class VariedAnalyzer : public clang::RecursiveASTVisitor<VariedAnalyzer> {
ovdiiuv marked this conversation as resolved.
Show resolved Hide resolved

bool m_Varied = false;
bool m_Marking = false;

std::set<const clang::VarDecl*>& m_VariedDecls;
ovdiiuv marked this conversation as resolved.
Show resolved Hide resolved
/// A helper method to allocate VarsData
/// \param[in] toAssign - Parameter to initialize new VarsData with.
/// \return Unique pointer to a new object of type Varsdata.
static std::unique_ptr<VarsData> createNewVarsData(VarsData toAssign) {
return std::unique_ptr<VarsData>(new VarsData(std::move(toAssign)));
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
}
VarsData m_LoopMem;

clang::CFGBlock* getCFGBlockByID(unsigned ID);

clang::ASTContext& m_Context;
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
std::unique_ptr<clang::CFG> m_CFG;
std::vector<std::unique_ptr<VarsData>> m_BlockData;
unsigned m_CurBlockID{};
std::set<unsigned> m_CFGQueue;
/// Checks if a variable is on the current branch.
/// \param[in] VD - Variable declaration.
/// @return Whether a variable is on the current branch.
bool isVaried(const clang::VarDecl* VD) const;
/// Adds varied variable to current branch.
/// \param[in] VD - Variable declaration.
void copyVarToCurBlock(const clang::VarDecl* VD);
VarsData& getCurBlockVarsData() { return *m_BlockData[m_CurBlockID]; }
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
[[nodiscard]] const VarsData& getCurBlockVarsData() const {
return const_cast<VariedAnalyzer*>(this)->getCurBlockVarsData();
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
}
void AnalyzeCFGBlock(const clang::CFGBlock& block);

public:
/// Constructor
VariedAnalyzer(clang::ASTContext& Context,
std::set<const clang::VarDecl*>& Decls)
: m_VariedDecls(Decls), m_Context(Context) {}

/// Destructor
~VariedAnalyzer() = default;

/// Delete copy/move operators and constructors.
VariedAnalyzer(const VariedAnalyzer&) = delete;
VariedAnalyzer& operator=(const VariedAnalyzer&) = delete;
VariedAnalyzer(const VariedAnalyzer&&) = delete;
VariedAnalyzer& operator=(const VariedAnalyzer&&) = delete;

/// Runs Varied analysis.
/// \param[in] FD Function to run the analysis on.
void Analyze(const clang::FunctionDecl* FD);
bool VisitBinaryOperator(clang::BinaryOperator* BinOp);
bool VisitCallExpr(clang::CallExpr* CE);
bool VisitConditionalOperator(clang::ConditionalOperator* CO);
bool VisitDeclRefExpr(clang::DeclRefExpr* DRE);
bool VisitDeclStmt(clang::DeclStmt* DS);
bool VisitUnaryOperator(clang::UnaryOperator* UnOp);
bool VisitInitListExpr(clang::InitListExpr* ILE);
};
} // namespace clad
#endif // CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: #endif without #if [clang-diagnostic-error]

#endif // CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H
 ^

1 change: 1 addition & 0 deletions lib/Differentiator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ set_property(SOURCE Version.cpp APPEND PROPERTY
# (Ab)use llvm facilities for adding libraries.
llvm_add_library(cladDifferentiator
STATIC
ActivityAnalyzer.cpp
BaseForwardModeVisitor.cpp
CladUtils.cpp
ConstantFolder.cpp
Expand Down
Loading
Loading