Skip to content

Commit

Permalink
Optimize pullback calls
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Mar 6, 2024
1 parent d7e5434 commit 3aa40e5
Show file tree
Hide file tree
Showing 31 changed files with 293 additions and 885 deletions.
6 changes: 6 additions & 0 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,12 @@ namespace clad {

bool IsMemoryFunction(const clang::FunctionDecl* FD);
bool IsMemoryDeallocationFunction(const clang::FunctionDecl* FD);

/// Checks if the argument arg involves type conversion
/// (e.g. float to double). This is used to find out
/// a similar cast can be performed for the derived
/// arg (for corresponding pointers).
bool isTriviallyCastableArg(const clang::Expr* arg);
} // namespace utils
} // namespace clad

Expand Down
15 changes: 7 additions & 8 deletions include/clad/Differentiator/ErrorEstimator.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,11 @@ class ErrorEstimationHandler : public ExternalRMVSource {
/// \param[in] CallArgs The orignal call arguments of the function call.
/// \param[in] ArgResultDecls The differentiated call arguments.
/// \param[in] numArgs The number of call args.
void EmitNestedFunctionParamError(
clang::FunctionDecl* fnDecl,
llvm::SmallVectorImpl<clang::Expr*>& CallArgs,
llvm::SmallVectorImpl<clang::VarDecl*>& ArgResultDecls, size_t numArgs);
void
EmitNestedFunctionParamError(clang::FunctionDecl* fnDecl,
llvm::SmallVectorImpl<clang::Expr*>& CallArgs,
llvm::SmallVectorImpl<clang::Expr*>& ArgResult,
size_t numArgs);

/// Save values of registered variables so that they can be replaced
/// properly in case of re-assignments.
Expand Down Expand Up @@ -265,8 +266,7 @@ class ErrorEstimationHandler : public ExternalRMVSource {
void ActBeforeFinalizingVisitCallExpr(
const clang::CallExpr*& CE, clang::Expr*& fnDecl,
llvm::SmallVectorImpl<clang::Expr*>& derivedCallArgs,
llvm::SmallVectorImpl<clang::VarDecl*>& ArgResultDecls,
bool asGrad) override;
llvm::SmallVectorImpl<clang::Expr*>& ArgResult, bool asGrad) override;
void
ActAfterCloningLHSOfAssignOp(clang::Expr*& LCloned, clang::Expr*& R,
clang::BinaryOperator::Opcode& opCode) override;
Expand All @@ -275,8 +275,7 @@ class ErrorEstimationHandler : public ExternalRMVSource {
void ActBeforeFinalizingDifferentiateSingleExpr(const direction& d) override;
void ActBeforeDifferentiatingCallExpr(
llvm::SmallVectorImpl<clang::Expr*>& pullbackArgs,
llvm::SmallVectorImpl<clang::DeclStmt*>& ArgDecls,
bool hasAssignee) override;
llvm::SmallVectorImpl<clang::Stmt*>& ArgDecls, bool hasAssignee) override;
void ActBeforeFinalizingVisitDeclStmt(
llvm::SmallVectorImpl<clang::Decl*>& decls,
llvm::SmallVectorImpl<clang::Decl*>& declsDiff) override;
Expand Down
4 changes: 2 additions & 2 deletions include/clad/Differentiator/ExternalRMVSource.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class ExternalRMVSource {
virtual void ActBeforeFinalizingVisitCallExpr(
const clang::CallExpr*& CE, clang::Expr*& OverloadedDerivedFn,
llvm::SmallVectorImpl<clang::Expr*>& derivedCallArgs,
llvm::SmallVectorImpl<clang::VarDecl*>& ArgResultDecls, bool asGrad) {}
llvm::SmallVectorImpl<clang::Expr*>& ArgResult, bool asGrad) {}

/// This is called just before finalising processing of post and pre
/// increment and decrement operations.
Expand Down Expand Up @@ -155,7 +155,7 @@ class ExternalRMVSource {

virtual void ActBeforeDifferentiatingCallExpr(
llvm::SmallVectorImpl<clang::Expr*>& pullbackArgs,
llvm::SmallVectorImpl<clang::DeclStmt*>& ArgDecls, bool hasAssignee) {}
llvm::SmallVectorImpl<clang::Stmt*>& ArgDecls, bool hasAssignee) {}

virtual void ActBeforeFinalizingVisitDeclStmt(
llvm::SmallVectorImpl<clang::Decl*>& decls,
Expand Down
6 changes: 2 additions & 4 deletions include/clad/Differentiator/MultiplexExternalRMVSource.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ class MultiplexExternalRMVSource : public ExternalRMVSource {
void ActBeforeFinalizingVisitCallExpr(
const clang::CallExpr*& CE, clang::Expr*& OverloadedDerivedFn,
llvm::SmallVectorImpl<clang::Expr*>& derivedCallArgs,
llvm::SmallVectorImpl<clang::VarDecl*>& ArgResultDecls,
bool asGrad) override;
llvm::SmallVectorImpl<clang::Expr*>& ArgResult, bool asGrad) override;
void ActBeforeFinalisingPostIncDecOp(StmtDiff& diff) override;
void ActAfterCloningLHSOfAssignOp(clang::Expr*&, clang::Expr*&,
clang::BinaryOperatorKind& opCode) override;
Expand All @@ -59,8 +58,7 @@ class MultiplexExternalRMVSource : public ExternalRMVSource {
void ActBeforeFinalizingDifferentiateSingleExpr(const direction& d) override;
void ActBeforeDifferentiatingCallExpr(
llvm::SmallVectorImpl<clang::Expr*>& pullbackArgs,
llvm::SmallVectorImpl<clang::DeclStmt*>& ArgDecls,
bool hasAssignee) override;
llvm::SmallVectorImpl<clang::Stmt*>& ArgDecls, bool hasAssignee) override;
void ActBeforeFinalizingVisitDeclStmt(
llvm::SmallVectorImpl<clang::Decl*>& decls,
llvm::SmallVectorImpl<clang::Decl*>& declsDiff) override;
Expand Down
10 changes: 7 additions & 3 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -583,16 +583,20 @@ namespace clad {
///
/// \param[in] targetFuncCall The function to get the derivative for.
/// \param[in] retType The return type of the target call expression.
/// \param[in] dfdx The dfdx corresponding to this call expression.
/// \param[in] numArgs The total number of 'args'.
/// \param[in] NumericalDiffMultiArg The built statements to add to block
/// later.
/// \param[in] PreCallStmts The built statements to add to block
/// before the call to the derived function.
/// \param[in] PostCallStmts The built statements to add to block
/// after the call to the derived function.
/// \param[in] args All the arguments to the target function.
/// \param[in] outputArgs The output gradient arguments.
///
/// \returns The derivative function call.
clang::Expr* GetMultiArgCentralDiffCall(
clang::Expr* targetFuncCall, clang::QualType retType, unsigned numArgs,
llvm::SmallVectorImpl<clang::Stmt*>& NumericalDiffMultiArg,
clang::Expr* dfdx, llvm::SmallVectorImpl<clang::Stmt*>& PreCallStmts,
llvm::SmallVectorImpl<clang::Stmt*>& PostCallStmts,
llvm::SmallVectorImpl<clang::Expr*>& args,
llvm::SmallVectorImpl<clang::Expr*>& outputArgs);
/// Emits diagnostic messages on differentiation (or lack thereof) for
Expand Down
11 changes: 11 additions & 0 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -680,5 +680,16 @@ namespace clad {
return FD->getNameAsString() == "free";
#endif
}

bool isTriviallyCastableArg(const clang::Expr* arg) {
// `arg` is the of the type that the function expects.
const QualType expectedType = arg->getType().getCanonicalType();
// By removing all implicit casts, we can get the type of the original
// argument.
const QualType argType =
arg->IgnoreImplicit()->getType().getCanonicalType();
// Compare the unqualified types.
return expectedType.getTypePtr() == argType.getTypePtr();
}
} // namespace utils
} // namespace clad
10 changes: 5 additions & 5 deletions lib/Differentiator/ErrorEstimator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ void ErrorEstimationHandler::SaveReturnExpr(Expr* retExpr) {

void ErrorEstimationHandler::EmitNestedFunctionParamError(
FunctionDecl* fnDecl, llvm::SmallVectorImpl<Expr*>& derivedCallArgs,
llvm::SmallVectorImpl<VarDecl*>& ArgResultDecls, size_t numArgs) {
llvm::SmallVectorImpl<Expr*>& ArgResult, size_t numArgs) {
assert(fnDecl && "Must have a value");
for (size_t i = 0; i < numArgs; i++) {
if (!fnDecl->getParamDecl(0)->getType()->isLValueReferenceType())
Expand All @@ -145,7 +145,7 @@ void ErrorEstimationHandler::EmitNestedFunctionParamError(
// if (utils::IsReferenceOrPointerType(fnDecl->getParamDecl(i)->getType()))
// continue;
Expr* errorExpr = m_EstModel->AssignError(
{derivedCallArgs[i], m_RMV->BuildDeclRef(ArgResultDecls[i])},
{derivedCallArgs[i], m_RMV->Clone(ArgResult[i])},
fnDecl->getNameInfo().getAsString() + "_param_" + std::to_string(i));
Expr* errorStmt = m_RMV->BuildOp(BO_AddAssign, m_FinalError, errorExpr);
m_ReverseErrorStmts.push_back(errorStmt);
Expand Down Expand Up @@ -605,7 +605,7 @@ void ErrorEstimationHandler::ActBeforeFinalisingPostIncDecOp(StmtDiff& diff) {
void ErrorEstimationHandler::ActBeforeFinalizingVisitCallExpr(
const clang::CallExpr*& CE, clang::Expr*& OverloadedDerivedFn,
llvm::SmallVectorImpl<Expr*>& derivedCallArgs,
llvm::SmallVectorImpl<VarDecl*>& ArgResultDecls, bool asGrad) {
llvm::SmallVectorImpl<Expr*>& ArgResult, bool asGrad) {
if (OverloadedDerivedFn && asGrad) {
// Derivative was found.
FunctionDecl* fnDecl =
Expand All @@ -615,7 +615,7 @@ void ErrorEstimationHandler::ActBeforeFinalizingVisitCallExpr(
// in the input prameters (if of reference type) to call and save to
// emit them later.

EmitNestedFunctionParamError(fnDecl, derivedCallArgs, ArgResultDecls,
EmitNestedFunctionParamError(fnDecl, derivedCallArgs, ArgResult,
CE->getNumArgs());
}
}
Expand Down Expand Up @@ -650,7 +650,7 @@ void ErrorEstimationHandler::ActBeforeFinalizingDifferentiateSingleExpr(

void ErrorEstimationHandler::ActBeforeDifferentiatingCallExpr(
llvm::SmallVectorImpl<clang::Expr*>& pullbackArgs,
llvm::SmallVectorImpl<DeclStmt*>& ArgDecls, bool hasAssignee) {
llvm::SmallVectorImpl<Stmt*>& ArgDecls, bool hasAssignee) {
auto errorRef =
m_RMV->BuildVarDecl(m_RMV->m_Context.DoubleTy, "_t",
m_RMV->getZeroInit(m_RMV->m_Context.DoubleTy));
Expand Down
8 changes: 4 additions & 4 deletions lib/Differentiator/MultiplexExternalRMVSource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,10 @@ void MultiplexExternalRMVSource::ActBeforeFinalisingVisitReturnStmt(
void MultiplexExternalRMVSource::ActBeforeFinalizingVisitCallExpr(
const clang::CallExpr*& CE, clang::Expr*& OverloadedDerivedFn,
llvm::SmallVectorImpl<clang::Expr*>& derivedCallArgs,
llvm::SmallVectorImpl<clang::VarDecl*>& ArgResultDecls, bool asGrad) {
llvm::SmallVectorImpl<clang::Expr*>& ArgResult, bool asGrad) {
for (auto source : m_Sources) {
source->ActBeforeFinalizingVisitCallExpr(CE, OverloadedDerivedFn, derivedCallArgs,
ArgResultDecls, asGrad);
source->ActBeforeFinalizingVisitCallExpr(
CE, OverloadedDerivedFn, derivedCallArgs, ArgResult, asGrad);
}
}

Expand Down Expand Up @@ -198,7 +198,7 @@ void MultiplexExternalRMVSource::ActBeforeFinalizingDifferentiateSingleExpr(

void MultiplexExternalRMVSource::ActBeforeDifferentiatingCallExpr(
llvm::SmallVectorImpl<clang::Expr*>& pullbackArgs,
llvm::SmallVectorImpl<clang::DeclStmt*>& ArgDecls, bool hasAssignee) {
llvm::SmallVectorImpl<clang::Stmt*>& ArgDecls, bool hasAssignee) {
for (auto source : m_Sources)
source->ActBeforeDifferentiatingCallExpr(pullbackArgs, ArgDecls,
hasAssignee);
Expand Down
Loading

0 comments on commit 3aa40e5

Please sign in to comment.