Skip to content

Commit

Permalink
GPU client throw nicer error
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed May 15, 2024
1 parent 811eb3f commit b74a004
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ extern "C" PjRtClient* MakeCPUClient(uint8_t asynchronous, int node_id, int num_
}

// xla/python/xla.cc 390
extern "C" PjRtClient* MakeGPUClient(int node_id, int num_nodes, int* allowed_devices, int num_allowed_devices, const char* platform_name) {
extern "C" PjRtClient* MakeGPUClient(int node_id, int num_nodes, int* allowed_devices, int num_allowed_devices, const char* platform_name, const char** error) {
GpuClientOptions options;
// options.kv_store = "etcd";
// options.allocator_config =
Expand All @@ -86,8 +86,18 @@ extern "C" PjRtClient* MakeGPUClient(int node_id, int num_nodes, int* allowed_de
options.allowed_devices = allowed_devices ? std::set<int>(allowed_devices, allowed_devices + num_allowed_devices) : std::optional<std::set<int>>();
options.platform_name = platform_name ? std::string(platform_name) : std::optional<std::string>();
// options.collectives = num_nodes;
auto client = xla::ValueOrThrow(GetStreamExecutorGpuClient(options));
return client.release();
auto clientErr = GetStreamExecutorGpuClient(options);

if (!v.ok()) {
auto str = x.status();
const char* err = malloc(strlen(str)+1);
memcpy(err, str, strlen(str)+1);
*error = err;
return nullptr;
} else {
auto client = std::move(v).value();
return client.release();
}
}

extern "C" int ClientNumDevices(PjRtClient* client) {
Expand Down

0 comments on commit b74a004

Please sign in to comment.