
Merge branch 'master' into activity
ovdiiuv authored Oct 4, 2024
2 parents 431c490 + 821683f commit 94d5ccd
Showing 13 changed files with 1,225 additions and 76 deletions.
27 changes: 27 additions & 0 deletions include/clad/Differentiator/BuiltinDerivatives.h
@@ -30,6 +30,11 @@ template <typename T, typename U> struct ValueAndPushforward {
}
};

template <typename T, typename U>
ValueAndPushforward<T, U> make_value_and_pushforward(T value, U pushforward) {
return {value, pushforward};
}

template <typename T, typename U> struct ValueAndAdjoint {
T value;
U adjoint;
@@ -178,6 +183,25 @@ CUDA_HOST_DEVICE ValueAndPushforward<T, T> floor_pushforward(T x, T /*d_x*/) {
return {::std::floor(x), (T)0};
}

template <typename T>
CUDA_HOST_DEVICE ValueAndPushforward<T, T> atan2_pushforward(T y, T x, T d_y,
T d_x) {
return {::std::atan2(y, x),
-(y / ((x * x) + (y * y))) * d_x + x / ((x * x) + (y * y)) * d_y};
}

template <typename T, typename U>
CUDA_HOST_DEVICE void atan2_pullback(T y, T x, U d_z, T* d_y, T* d_x) {
*d_y += x / ((x * x) + (y * y)) * d_z;

*d_x += -(y / ((x * x) + (y * y))) * d_z;
}

template <typename T>
CUDA_HOST_DEVICE ValueAndPushforward<T, T> acos_pushforward(T x, T d_x) {
return {::std::acos(x), ((-1) / (::std::sqrt(1 - x * x))) * d_x};
}

template <typename T>
CUDA_HOST_DEVICE ValueAndPushforward<T, T> ceil_pushforward(T x, T /*d_x*/) {
return {::std::ceil(x), (T)0};
@@ -316,6 +340,9 @@ inline void free_pushforward(void* ptr, void* d_ptr) {
// These are required because C variants of mathematical functions are
// defined in global namespace.
using std::abs_pushforward;
using std::acos_pushforward;
using std::atan2_pullback;
using std::atan2_pushforward;
using std::ceil_pushforward;
using std::cos_pushforward;
using std::exp_pushforward;
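The atan2 and acos rules added above follow the standard closed forms: ∂atan2(y, x)/∂y = x/(x² + y²), ∂atan2(y, x)/∂x = −y/(x² + y²), and d(acos x)/dx = −1/√(1 − x²). A minimal standalone sketch that re-implements these formulas and the new make_value_and_pushforward helper (local stand-ins for illustration, not the clad headers themselves) and checks them against central finite differences:

// Standalone check of the atan2/acos pushforward formulas added in this
// commit. Local re-implementations; does not include the clad headers.
#include <cmath>
#include <cstdio>

struct ValueAndPushforward { double value, pushforward; };

ValueAndPushforward make_value_and_pushforward(double v, double p) {
  return {v, p};
}

ValueAndPushforward atan2_pushforward(double y, double x, double d_y, double d_x) {
  double r2 = x * x + y * y;
  return make_value_and_pushforward(std::atan2(y, x),
                                    -(y / r2) * d_x + (x / r2) * d_y);
}

ValueAndPushforward acos_pushforward(double x, double d_x) {
  return make_value_and_pushforward(std::acos(x),
                                    (-1.0 / std::sqrt(1.0 - x * x)) * d_x);
}

int main() {
  const double h = 1e-6, y = 0.7, x = -1.3;
  // d/dy atan2(y, x): pushforward with (d_y, d_x) = (1, 0) vs. central difference.
  double fd_y = (std::atan2(y + h, x) - std::atan2(y - h, x)) / (2 * h);
  std::printf("atan2 d/dy: %.9f vs %.9f\n",
              atan2_pushforward(y, x, 1.0, 0.0).pushforward, fd_y);
  // d/dx acos(x) at x = 0.3.
  double fd_x = (std::acos(0.3 + h) - std::acos(0.3 - h)) / (2 * h);
  std::printf("acos d/dx:  %.9f vs %.9f\n",
              acos_pushforward(0.3, 1.0).pushforward, fd_x);
}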
67 changes: 31 additions & 36 deletions include/clad/Differentiator/Differentiator.h
@@ -125,8 +125,18 @@ CUDA_HOST_DEVICE T push(tape<T>& to, ArgsT... val) {
CUDA_ARGS CUDA_REST_ARGS Args&&... args) {
#if defined(__CUDACC__) && !defined(__CUDA_ARCH__)
if (CUDAkernel) {
void* argPtrs[] = {(void*)&args..., (void*)static_cast<Rest>(nullptr)...};
cudaLaunchKernel((void*)f, grid, block, argPtrs, shared_mem, stream);
constexpr size_t totalArgs = sizeof...(args) + sizeof...(Rest);
std::vector<void*> argPtrs;
argPtrs.reserve(totalArgs);
(argPtrs.push_back(static_cast<void*>(&args)), ...);

void* null_param = nullptr;
for (size_t i = sizeof...(args); i < totalArgs; ++i)
argPtrs[i] = &null_param;

cudaLaunchKernel((void*)f, grid, block, argPtrs.data(), shared_mem,
stream);
return return_type_t<F>();
} else {
return f(static_cast<Args>(args)..., static_cast<Rest>(nullptr)...);
}
@@ -198,23 +208,17 @@ CUDA_HOST_DEVICE T push(tape<T>& to, ArgsT... val) {
CUDA_HOST_DEVICE CladFunction(CladFunctionType f, const char* code,
FunctorType* functor = nullptr,
bool CUDAkernel = false)
: m_Functor(functor), m_CUDAkernel(CUDAkernel) {
assert(f && "Must pass a non-0 argument.");
if (size_t length = GetLength(code)) {
m_Function = f;
char* temp = (char*)malloc(length + 1);
m_Code = temp;
while ((*temp++ = *code++));
} else {
// clad did not place the derivative in this object. This can happen
// upon error of if clad was disabled. Diagnose.
printf("clad failed to place the generated derivative in the object\n");
printf("Make sure calls to clad are within a #pragma clad ON region\n");

// Invalidate the placeholders.
m_Function = nullptr;
m_Code = nullptr;
}
: m_Function(f), m_Functor(functor), m_CUDAkernel(CUDAkernel) {
#ifndef __CLAD_SO_LOADED
static_assert(false, "clad doesn't appear to be loaded; make sure that "
"you pass clad.so to clang.");
#endif

size_t length = GetLength(code);
char* temp = (char*)malloc(length + 1);
m_Code = temp;
while ((*temp++ = *code++))
;
}
/// Constructor overload for initializing `m_Functor` when functor
/// is passed by reference.
@@ -371,9 +375,6 @@ CUDA_HOST_DEVICE T push(tape<T>& to, ArgsT... val) {
template <class ReturnType, class C, class... Args>
return_type_t<CladFunctionType> execute_helper(ReturnType C::*f,
Args&&... args) {
assert(m_Functor &&
"No default object set, explicitly pass an object to "
"CladFunction::execute");
// `static_cast` is required here for perfect forwarding.
return execute_with_default_args<EnablePadding>(
DropArgs_t<sizeof...(Args), decltype(f)>{}, f, *m_Functor,
@@ -411,9 +412,8 @@ CUDA_HOST_DEVICE T push(tape<T>& to, ArgsT... val) {
differentiate(F fn, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(fn && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(derivedFn,
code);
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(derivedFn,
code);
}

/// Specialization for differentiating functors.
@@ -454,9 +454,8 @@ CUDA_HOST_DEVICE T push(tape<T>& to, ArgsT... val) {
differentiate(F fn, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(fn && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn, code);
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn, code);
}

/// Generates function which computes gradient of the given function wrt the
@@ -475,7 +474,6 @@ CUDA_HOST_DEVICE T push(tape<T>& to, ArgsT... val) {
gradient(F f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "", bool CUDAkernel = false) {
assert(f && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn /* will be replaced by gradient*/, code, nullptr, CUDAkernel);
}
@@ -512,9 +510,8 @@ CUDA_HOST_DEVICE T push(tape<T>& to, ArgsT... val) {
hessian(F f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(f && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by hessian*/, code);
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by hessian*/, code);
}

/// Specialization for differentiating functors.
@@ -549,9 +546,8 @@ CUDA_HOST_DEVICE T push(tape<T>& to, ArgsT... val) {
jacobian(F f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(f && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by Jacobian*/, code);
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by Jacobian*/, code);
}

/// Specialization for differentiating functors.
@@ -576,7 +572,6 @@ CUDA_HOST_DEVICE T push(tape<T>& to, ArgsT... val) {
estimate_error(F f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(f && "Must pass in a non-0 argument");
return CladFunction<
DerivedFnType>(derivedFn /* will be replaced by estimation code*/,
code);
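The launch path above swaps the old fixed-size argPtrs array for a vector filled by a C++17 fold expression, padded with pointers to a null parameter for the trailing Rest arguments that the overload accepts but the kernel never reads. A host-only sketch of the same packing idea (pack_kernel_args is an illustrative name, not a clad API; this sketch uses resize rather than reserve so the padding slots are written within the vector's size):

// Host-only illustration of the kernel-argument packing scheme used in
// execute_with_default_args above. Illustrative names; no CUDA required.
#include <cstddef>
#include <cstdio>
#include <vector>

template <typename... Rest, typename... Args>
std::vector<void*> pack_kernel_args(Args&... args) {
  constexpr std::size_t totalArgs = sizeof...(Args) + sizeof...(Rest);
  std::vector<void*> argPtrs;
  argPtrs.reserve(totalArgs);
  // Fold expression: record the address of every real argument, in order.
  (argPtrs.push_back(static_cast<void*>(&args)), ...);
  // Pad the remaining Rest slots with the address of a null parameter
  // (static so the pointer stays valid after this function returns).
  static void* null_param = nullptr;
  argPtrs.resize(totalArgs, &null_param);
  return argPtrs;
}

int main() {
  int n = 4;
  double scale = 2.0;
  // Two real arguments plus two padded slots for unused Rest parameters.
  auto ptrs = pack_kernel_args<float*, float*>(n, scale);
  std::printf("packed %zu argument slots (%zu padded)\n", ptrs.size(),
              ptrs.size() - 2);
}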
1 change: 1 addition & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
@@ -396,6 +396,7 @@ namespace clad {
StmtDiff
VisitImplicitValueInitExpr(const clang::ImplicitValueInitExpr* IVIE);
StmtDiff VisitCStyleCastExpr(const clang::CStyleCastExpr* CSCE);
StmtDiff VisitPseudoObjectExpr(const clang::PseudoObjectExpr* POE);
StmtDiff VisitInitListExpr(const clang::InitListExpr* ILE);
StmtDiff VisitIntegerLiteral(const clang::IntegerLiteral* IL);
StmtDiff VisitMemberExpr(const clang::MemberExpr* ME);
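The new VisitPseudoObjectExpr hook is relevant for CUDA support (an inference from Clang's AST modeling, not something the diff itself states): Clang exposes builtin variables such as threadIdx.x through property-style getters, so each access appears in the AST as a PseudoObjectExpr, and reverse mode needs a visitor for that node before it can differentiate kernels like this hypothetical one:

// Hypothetical CUDA kernel: in Clang's AST, each threadIdx.x / blockIdx.x /
// blockDim.x access below is a PseudoObjectExpr wrapping the builtin getter.
__global__ void square_kernel(double* out, const double* in) {
  unsigned i = blockIdx.x * blockDim.x + threadIdx.x;
  out[i] = in[i] * in[i];
}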