Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bitmasked-option for tbr analysis #808

Merged
merged 1 commit into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions include/clad/Differentiator/CladConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,28 @@ enum order {
third = 3,
}; // enum order

enum opts {
enum opts : unsigned {
use_enzyme = 1 << ORDER_BITS,
vector_mode = 1 << (ORDER_BITS + 1),

// Storing two bits for tbr analysis.
// 00 - default, 01 - enable, 10 - disable, 11 - not used / invalid
enable_tbr = 1 << (ORDER_BITS + 2),
disable_tbr = 1 << (ORDER_BITS + 3),
}; // enum opts

constexpr unsigned GetDerivativeOrder(unsigned const bitmasked_opts) {
constexpr unsigned GetDerivativeOrder(const unsigned bitmasked_opts) {
return bitmasked_opts & ORDER_MASK;
}

constexpr bool HasOption(unsigned const bitmasked_opts, unsigned const option) {
constexpr bool HasOption(const unsigned bitmasked_opts, const unsigned option) {
return (bitmasked_opts & option) == option;
}

constexpr unsigned GetBitmaskedOpts() { return 0; }
constexpr unsigned GetBitmaskedOpts(unsigned const first) { return first; }
constexpr unsigned GetBitmaskedOpts(const unsigned first) { return first; }
template <typename... Opts>
constexpr unsigned GetBitmaskedOpts(unsigned const first, Opts... opts) {
constexpr unsigned GetBitmaskedOpts(const unsigned first, Opts... opts) {
return first | GetBitmaskedOpts(opts...);
}

Expand Down
10 changes: 9 additions & 1 deletion include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ namespace clad {
using DiffSchedule = llvm::SmallVector<DiffRequest, 16>;
using DiffInterval = std::vector<clang::SourceRange>;

struct RequestOptions {
/// This is a flag to indicate the default behaviour to enable/disable
/// TBR analysis during reverse-mode differentiation.
bool EnableTBRAnalysis = false;
};

class DiffCollector: public clang::RecursiveASTVisitor<DiffCollector> {
/// The source interval where clad was activated.
///
Expand All @@ -101,9 +107,11 @@ namespace clad {
const clang::FunctionDecl* m_TopMostFD = nullptr;
clang::Sema& m_Sema;

RequestOptions& m_Options;

public:
DiffCollector(clang::DeclGroupRef DGR, DiffInterval& Interval,
DiffSchedule& plans, clang::Sema& S);
DiffSchedule& plans, clang::Sema& S, RequestOptions& opts);
bool VisitCallExpr(clang::CallExpr* E);

private:
Expand Down
54 changes: 22 additions & 32 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,9 +358,8 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
/// \param[in] args independent parameters information
/// \returns `CladFunction` object to access the corresponding derived
/// function.
template <unsigned... BitMaskedOpts /*To check for enzyme*/,
typename ArgSpec = const char*, typename F,
typename DerivedFnType = GradientDerivedFnTraits_t<F>,
template <unsigned... BitMaskedOpts, typename ArgSpec = const char*,
typename F, typename DerivedFnType = GradientDerivedFnTraits_t<F>,
typename = typename std::enable_if<
!std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true> __attribute__((
Expand All @@ -376,9 +375,8 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
/// Specialization for differentiating functors.
/// The specialization is needed because objects have to be passed
/// by reference whereas functions have to be passed by value.
template <unsigned... BitMaskedOpts /*To check for enzyme*/,
typename ArgSpec = const char*, typename F,
typename DerivedFnType = GradientDerivedFnTraits_t<F>,
template <unsigned... BitMaskedOpts, typename ArgSpec = const char*,
typename F, typename DerivedFnType = GradientDerivedFnTraits_t<F>,
typename = typename std::enable_if<
std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true> __attribute__((
Expand All @@ -397,38 +395,34 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
/// \param[in] args independent parameters information
/// \returns `CladFunction` object to access the corresponding derived
/// function.
template <typename ArgSpec = const char*, typename F,
typename DerivedFnType = HessianDerivedFnTraits_t<F>,
template <unsigned... BitMaskedOpts, typename ArgSpec = const char*,
typename F, typename DerivedFnType = HessianDerivedFnTraits_t<F>,
typename = typename std::enable_if<
!std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>> __attribute__((
annotate("H")))
hessian(F f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(f && "Must pass in a non-0 argument");
return CladFunction<
DerivedFnType,
ExtractFunctorTraits_t<F>>(derivedFn /* will be replaced by hessian*/,
code);
assert(f && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by hessian*/, code);
}

/// Specialization for differentiating functors.
/// The specialization is needed because objects have to be passed
/// by reference whereas functions have to be passed by value.
template <typename ArgSpec = const char*, typename F,
typename DerivedFnType = HessianDerivedFnTraits_t<F>,
template <unsigned... BitMaskedOpts, typename ArgSpec = const char*,
typename F, typename DerivedFnType = HessianDerivedFnTraits_t<F>,
typename = typename std::enable_if<
std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>> __attribute__((
annotate("H")))
hessian(F&& f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
return CladFunction<
DerivedFnType,
ExtractFunctorTraits_t<F>>(derivedFn /* will be replaced by hessian*/,
code, f);
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by hessian*/, code, f);
}

/// Generates function which computes jacobian matrix of the given function
Expand All @@ -438,38 +432,34 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
/// \param[in] args independent parameters information
/// \returns `CladFunction` object to access the corresponding derived
/// function.
template <typename ArgSpec = const char*, typename F,
typename DerivedFnType = JacobianDerivedFnTraits_t<F>,
template <unsigned... BitMaskedOpts, typename ArgSpec = const char*,
typename F, typename DerivedFnType = JacobianDerivedFnTraits_t<F>,
typename = typename std::enable_if<
!std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>> __attribute__((
annotate("J")))
jacobian(F f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(f && "Must pass in a non-0 argument");
return CladFunction<
DerivedFnType,
ExtractFunctorTraits_t<F>>(derivedFn /* will be replaced by Jacobian*/,
code);
assert(f && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by Jacobian*/, code);
}

/// Specialization for differentiating functors.
/// The specialization is needed because objects have to be passed
/// by reference whereas functions have to be passed by value.
template <typename ArgSpec = const char*, typename F,
typename DerivedFnType = JacobianDerivedFnTraits_t<F>,
template <unsigned... BitMaskedOpts, typename ArgSpec = const char*,
typename F, typename DerivedFnType = JacobianDerivedFnTraits_t<F>,
typename = typename std::enable_if<
std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>> __attribute__((
annotate("J")))
jacobian(F&& f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
return CladFunction<
DerivedFnType,
ExtractFunctorTraits_t<F>>(derivedFn /* will be replaced by Jacobian*/,
code, f);
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by Jacobian*/, code, f);
}

template <typename ArgSpec = const char*, typename F,
Expand Down
56 changes: 36 additions & 20 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,10 @@ namespace clad {
}

DiffCollector::DiffCollector(DeclGroupRef DGR, DiffInterval& Interval,
DiffSchedule& plans, clang::Sema& S)
DiffSchedule& plans, clang::Sema& S,
RequestOptions& opts)
: m_Interval(Interval), m_DiffPlans(plans), m_TopMostFD(nullptr),
m_Sema(S) {
m_Sema(S), m_Options(opts) {

if (Interval.empty())
return;
Expand Down Expand Up @@ -556,27 +557,53 @@ namespace clad {
return true;
DiffRequest request{};

if (A->getAnnotation().equals("D")) {
request.Mode = DiffMode::forward;

// bitmask_opts is a template pack of unsigned integers, so we need to
// do bitwise or of all the values to get the final value.
unsigned bitmasked_opts_value = 0;
// bitmask_opts is a template pack of unsigned integers, so we need to
// do bitwise or of all the values to get the final value.
unsigned bitmasked_opts_value = 0;
bool enable_tbr_in_req = false;
bool disable_tbr_in_req = false;
if (!A->getAnnotation().equals("E") &&
FD->getTemplateSpecializationArgs()) {
const auto template_arg = FD->getTemplateSpecializationArgs()->get(0);
if (template_arg.getKind() == TemplateArgument::Pack)
for (const auto& arg :
FD->getTemplateSpecializationArgs()->get(0).pack_elements())
bitmasked_opts_value |= arg.getAsIntegral().getExtValue();
else
bitmasked_opts_value = template_arg.getAsIntegral().getExtValue();

// Set option for TBR analysis.
enable_tbr_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::enable_tbr);
disable_tbr_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::disable_tbr);
if (enable_tbr_in_req && disable_tbr_in_req) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"Both enable and disable TBR options are specified.");
return true;
}
if (enable_tbr_in_req || disable_tbr_in_req) {
// override the default value of TBR analysis.
request.EnableTBRAnalysis = enable_tbr_in_req && !disable_tbr_in_req;
} else {
request.EnableTBRAnalysis = m_Options.EnableTBRAnalysis;
}
}

if (A->getAnnotation().equals("D")) {
request.Mode = DiffMode::forward;
unsigned derivative_order =
clad::GetDerivativeOrder(bitmasked_opts_value);
if (derivative_order == 0) {
derivative_order = 1; // default to first order derivative.
}
request.RequestedDerivativeOrder = derivative_order;
if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme)) {
if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme))
request.use_enzyme = true;
if (enable_tbr_in_req) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"TBR analysis is not meant for forward mode AD.");
return true;
}
if (clad::HasOption(bitmasked_opts_value, clad::opts::vector_mode)) {
request.Mode = DiffMode::vector_forward_mode;
Expand All @@ -601,17 +628,6 @@ namespace clad {
request.Mode = DiffMode::jacobian;
} else if (A->getAnnotation().equals("G")) {
request.Mode = DiffMode::reverse;

// bitmask_opts is a template pack of unsigned integers, so we need to
// do bitwise or of all the values to get the final value.
unsigned bitmasked_opts_value = 0;
const auto template_arg = FD->getTemplateSpecializationArgs()->get(0);
if (template_arg.getKind() == TemplateArgument::Pack)
for (const auto& arg :
FD->getTemplateSpecializationArgs()->get(0).pack_elements())
bitmasked_opts_value |= arg.getAsIntegral().getExtValue();
else
bitmasked_opts_value = template_arg.getAsIntegral().getExtValue();
if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme))
request.use_enzyme = true;
// reverse vector mode is not yet supported.
Expand Down
4 changes: 2 additions & 2 deletions test/Analyses/TBR.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %cladclang -mllvm -debug-only=clad-tbr -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oReverseLoops.out 2>&1 | FileCheck %s
// RUN: %cladclang -mllvm -debug-only=clad-tbr %s -I%S/../../include -oReverseLoops.out 2>&1 | FileCheck %s
// REQUIRES: asserts
//CHECK-NOT: {{.*error|warning|note:.*}}

Expand All @@ -13,7 +13,7 @@ double f1(double x) {

#define TEST(F, x) { \
result[0] = 0; \
auto F##grad = clad::gradient(F);\
auto F##grad = clad::gradient<clad::opts::enable_tbr>(F);\
F##grad.execute(x, result);\
printf("{%.2f}\n", result[0]); \
}
Expand Down
3 changes: 2 additions & 1 deletion test/FirstDerivative/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ int main () {
clad::differentiate(test_6, "x");
clad::differentiate(test_7, "i");
clad::differentiate(test_8, "x");

clad::differentiate<clad::opts::enable_tbr>(test_8); // expected-error {{TBR analysis is not meant for forward mode AD.}}
clad::differentiate<clad::opts::enable_tbr, clad::opts::disable_tbr>(test_8); // expected-error {{Both enable and disable TBR options are specified.}}
return 0;
}
2 changes: 1 addition & 1 deletion test/Gradient/Pointers.C
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ int main() {
d_structPointer.execute(5, &d_x);
printf("%.2f\n", d_x); // CHECK-EXEC: 1.00

auto d_cStyleMemoryAlloc = clad::gradient(cStyleMemoryAlloc, "x");
auto d_cStyleMemoryAlloc = clad::gradient<clad::opts::disable_tbr>(cStyleMemoryAlloc, "x");
d_x = 0;
d_cStyleMemoryAlloc.execute(5, 7, &d_x);
printf("%.2f\n", d_x); // CHECK-EXEC: 4.00
Expand Down
7 changes: 7 additions & 0 deletions test/Misc/Args.C
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
// CHECK_HELP-NEXT: -fdump-derived-fn
// CHECK_HELP-NEXT: -fdump-derived-fn-ast
// CHECK_HELP-NEXT: -fgenerate-source-file
// CHECK_HELP-NEXT: -fno-validate-clang-version
// CHECK_HELP-NEXT: -enable-tbr
// CHECK_HELP-NEXT: -disable-tbr
// CHECK_HELP-NEXT: -fcustom-estimation-model
// CHECK_HELP-NEXT: -fprint-num-diff-errors
// CHECK_HELP-NEXT: -help
Expand All @@ -23,3 +26,7 @@
// RUN: -Xclang %t.so %S/../../demos/ErrorEstimation/CustomModel/test.cpp \
// RUN: -I%S/../../include 2>&1 | FileCheck --check-prefix=CHECK_SO_INVALID %s
// CHECK_SO_INVALID: Failed to load '{{.*.so}}', {{.*}}. Aborting.

// RUN: clang -fsyntax-only -fplugin=%cladlib -Xclang -plugin-arg-clad -Xclang -enable-tbr \
// RUN: -Xclang -plugin-arg-clad -Xclang -disable-tbr %s 2>&1 | FileCheck --check-prefix=CHECK_TBR %s
// CHECK_TBR: -enable-tbr and -disable-tbr cannot be used together
24 changes: 17 additions & 7 deletions tools/ClangPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,18 +107,15 @@ namespace clad {
if (m_HandleTopLevelDeclInternal)
return true;

RequestOptions opts{};
SetRequestOptions(opts);
DiffSchedule requests{};
DiffCollector collector(DGR, CladEnabledRange, requests, m_CI.getSema());
DiffCollector collector(DGR, CladEnabledRange, requests, m_CI.getSema(),
opts);

if (requests.empty())
return true;

// FIXME: flags have to be set manually since DiffCollector's constructor
// does not have access to m_DO.
if (m_DO.EnableTBRAnalysis)
for (DiffRequest& request : requests)
request.EnableTBRAnalysis = true;

// FIXME: Remove the PerformPendingInstantiations altogether. We should
// somehow make the relevant functions referenced.
// Instantiate all pending for instantiations templates, because we will
Expand Down Expand Up @@ -318,6 +315,19 @@ namespace clad {
m_HasRuntime = !R.empty();
return m_HasRuntime;
}

void CladPlugin::SetRequestOptions(RequestOptions& opts) const {
SetTBRAnalysisOptions(m_DO, opts);
}

void CladPlugin::SetTBRAnalysisOptions(const DifferentiationOptions& DO,
RequestOptions& opts) {
// If user has explicitly specified the mode for TBR analysis, use it.
if (DO.EnableTBRAnalysis || DO.DisableTBRAnalysis)
opts.EnableTBRAnalysis = DO.EnableTBRAnalysis && !DO.DisableTBRAnalysis;
else
opts.EnableTBRAnalysis = false; // Default mode.
}
} // end namespace plugin

clad::CladTimerGroup::CladTimerGroup()
Expand Down
Loading
Loading