diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 1962971e1..0038a6bbd 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1722,21 +1722,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, for (auto* argDerivative : CallArgDx) { Expr* gradArgExpr = nullptr; const Expr* arg = CE->getArg(idx); - if (utils::IsReferenceOrPointerArg(arg)) { - if (argDerivative) { - if (utils::isArrayOrPointerType(argDerivative->getType()) || - isCladArrayType(argDerivative->getType()) || - !argDerivative->isLValue()) - gradArgExpr = argDerivative; - else - gradArgExpr = - BuildOp(UnaryOperatorKind::UO_AddrOf, argDerivative); - } - } else { - Expr* gradVarExpr = CallArgDx[idx]; + if (utils::isArrayOrPointerType(arg->getType()) || + isCladArrayType(argDerivative->getType())) + gradArgExpr = argDerivative; + else gradArgExpr = - BuildOp(UO_AddrOf, gradVarExpr, m_Function->getLocation()); - } + BuildOp(UO_AddrOf, argDerivative, m_Function->getLocation()); DerivedCallOutputArgs.push_back(gradArgExpr); idx++; } @@ -1745,13 +1736,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if ((pullback == nullptr) && FD->getReturnType()->isLValueReferenceType()) pullback = getZeroInit(FD->getReturnType().getNonReferenceType()); - // FIXME: Remove this restriction. - if (!FD->getReturnType()->isVoidType()) { - assert((pullback && !FD->getReturnType()->isVoidType()) && - "Call to function returning non-void type with no dfdx() is not " - "supported!"); - } - if (FD->getReturnType()->isVoidType()) { assert(pullback == nullptr && FD->getReturnType()->isVoidType() && "Call to function returning void type should not have any "