Skip to content

Commit

Permalink
Working gradients
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed May 1, 2024
1 parent 38f1c9c commit 81a889d
Show file tree
Hide file tree
Showing 11 changed files with 531 additions and 119 deletions.
22 changes: 18 additions & 4 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"

#include "absl/log/initialize.h"

#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/mlir/utils/type_util.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
Expand All @@ -53,6 +55,16 @@ using namespace xla;
// int google::protobuf::io::CodedInputStream::default_recursion_limit_ = 100;
// int xla::_LayoutProto_default_instance_;

extern "C" void InitializeLogs() {
absl::InitializeLog();
}

extern "C"
MLIR_CAPI_EXPORTED MlirAttribute enzymeActivityAttrGet(
MlirContext ctx, int32_t val) {
return wrap(mlir::enzyme::ActivityAttr::get(unwrap(ctx), (mlir::enzyme::Activity)val));
}

extern "C" PjRtClient* MakeCPUClient(uint8_t asynchronous, int node_id, int num_nodes) {
CpuClientOptions options;
// options.kv_store = "etcd";
Expand Down Expand Up @@ -245,13 +257,16 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable* exec, int num_args, PjRtBu
if (!is_arg_donatable[i])
options.non_donatable_input_indices.insert((int)i);
}
options.untuple_result = true;
std::optional<std::vector<FutureType>> returned_futures;
auto results = xla::ValueOrThrow(exec->Execute(static_cast<absl::Span<const std::vector<PjRtBuffer*>>>(argument_handles), options, returned_futures));

if (results.size() != num_results) {
assert(results.size() == 1);

if (results[0].size() != num_results) {
llvm::errs() <<" results.size()=" << results.size() << " num_results=" << num_results << "\n";
}
assert(results.size() == num_results);
assert(results[0].size() == num_results);
if (returned_futures) {
*futures = true;
assert(returned_futures->size() == num_results);
Expand All @@ -263,8 +278,7 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable* exec, int num_args, PjRtBu
}

for (size_t i=0; i<num_results; i++) {
assert(results[i].size() == 1);
op_results[i] = results[i][0].release();
op_results[i] = results[0][i].release();
}
}

Expand Down
1 change: 1 addition & 0 deletions deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ cc_library(
"@stablehlo//:stablehlo_capi_objects",
"@stablehlo//:chlo_capi_objects",
"@com_google_absl//absl/hash:hash",
"@com_google_absl//absl/log:initialize",
"@llvm-project//mlir:CAPIIRObjects",
],
)
Expand Down
2 changes: 1 addition & 1 deletion deps/ReactantExtra/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ http_archive(
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
)

ENZYMEXLA_COMMIT = "23e85c3143a0b1f25a68b63e043bd1b5ae966061"
ENZYMEXLA_COMMIT = "cf0461a3bc430779721a2709a639024e33df7637"
ENZYMEXLA_SHA256 = ""

http_archive(
Expand Down
Loading

0 comments on commit 81a889d

Please sign in to comment.