From f80f7c9477fbfc675478076e34922ecaabe8ec22 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Mon, 11 Nov 2024 16:31:43 -0800 Subject: [PATCH] Create subdirectory of XLA docs for PJRT, add C++ API overview docs PiperOrigin-RevId: 695521567 --- docs/_toc.yaml | 10 +- docs/pjrt/_toc.yaml | 12 + docs/pjrt/cpp_api_overview.md | 332 ++++++++++++++++++++++ docs/pjrt/examples.md | 38 +++ docs/pjrt/images/pjrt_client.svg | 1 + docs/pjrt/index.md | 30 ++ docs/{ => pjrt}/pjrt_integration.md | 17 +- xla/pjrt/c/README.md | 5 +- xla/pjrt/c/docs/pjrt_integration_guide.md | 2 +- 9 files changed, 429 insertions(+), 18 deletions(-) create mode 100644 docs/pjrt/_toc.yaml create mode 100644 docs/pjrt/cpp_api_overview.md create mode 100644 docs/pjrt/examples.md create mode 100644 docs/pjrt/images/pjrt_client.svg create mode 100644 docs/pjrt/index.md rename docs/{ => pjrt}/pjrt_integration.md (88%) diff --git a/docs/_toc.yaml b/docs/_toc.yaml index 4e340eafc2dbc..43727270dd515 100644 --- a/docs/_toc.yaml +++ b/docs/_toc.yaml @@ -51,10 +51,16 @@ toc: path: /xla/build_from_source - title: Develop a new backend for XLA path: /xla/developing_new_backend - - title: Develop a new PJRT plugin - path: /xla/pjrt_integration - title: Developer guide path: /xla/developer_guide +- title: PJRT Plugins + section: + - title: PJRT Overview + path: /xla/pjrt/overview + - title: Develop a new PJRT plugin + path: /xla/pjrt/pjrt_integration + - title: PJRT Examples + path: /xla/pjrt/examples - title: Using XLA in TensorFlow # These should be in alphabetical order unless otherwise noted. section: diff --git a/docs/pjrt/_toc.yaml b/docs/pjrt/_toc.yaml new file mode 100644 index 0000000000000..7133f9191f281 --- /dev/null +++ b/docs/pjrt/_toc.yaml @@ -0,0 +1,12 @@ +toc: +- heading: PJRT +- title: Getting started + section: + - title: Introduction + path: /xla/pjrt + - title: PJRT C++ API Overview + path: /xla/pjrt/cpp_api_overview + - title: Develop a new PJRT plugin + path: /xla/pjrt/pjrt_integration + - title: PJRT Examples + path: /xla/pjrt/examples diff --git a/docs/pjrt/cpp_api_overview.md b/docs/pjrt/cpp_api_overview.md new file mode 100644 index 0000000000000..0d0c0b7362dea --- /dev/null +++ b/docs/pjrt/cpp_api_overview.md @@ -0,0 +1,332 @@ +# PJRT C++ Device API Overview + +## Background + +[PJRT](https://github.com/openxla/xla/blob/c23fbd601a017be25726fd6d624b22daa6a8a4e5/xla/pjrt/c/pjrt_c_api.h) +is the uniform Device API that we want to add to the ML ecosystem. The long term +vision is that: + +1. Frameworks (JAX, TF, etc.) will call PJRT, which has device-specific + implementations that are opaque to the frameworks; +2. Each device focuses on implementing PJRT APIs, and can be opaque to the + frameworks. + +PJRT offers both a C API and C++ API. Plugging in at either layer is OK, the C++ +API uses classes to abstract away some concepts, but also has stronger ties to +XLA datatypes. This page docuses on the C++ API. + +## PJRT Components + +![PJRT Components](images/pjrt_client.svg) + +Note: Most items in this diagram also have backpointers, memory spaces know +thier device(s) and client, devices know their client, buffers know their memory +space. + +### PjRtClient + +_Full reference at [`pjrt_client.h > PjRtClient`](https://github.com/openxla/xla/blob/924b74d84de3760cc589fd1525c7346691d51df5/xla/pjrt/pjrt_client.h#L486)._ + +Clients manage all communication between the device and framework, and +encapsulate all state used in the communication. They have a generic set of APIs +for interacting with a PJRT plugin, and they own the devices and memory spaces +for a given plugin. + +### PjRtDevice + +_Full references at [`pjrt_client.h > PjRtDevice`](https://github.com/openxla/xla/blob/3e448cf9e86775a37ec5f7d3c69dfb20e0c760df/xla/pjrt/pjrt_client.h#L102), +and [`pjrt_device_description.h`](https://github.com/openxla/xla/blob/main/xla/pjrt/pjrt_device_description.h)_ + +A device class is used to describe a single device. A device has a device +description to help identify its kind (unique hash to identify GPU/CPU/xPU), and +location within a grid of devices both locally and globally. + +Devices also know thier associated memory spaces and the client it is owned by. + +A device does *not* necessarily know the buffers of actual data associated with +it, but it can figure that out by looking through its associated memory spaces. + +### PjRtMemorySpace + +_Full reference at [`pjrt_client.h > PjRtMemorySpace`](https://github.com/openxla/xla/blob/3e448cf9e86775a37ec5f7d3c69dfb20e0c760df/xla/pjrt/pjrt_client.h#L72)._ + +Memory spaces can be used to describe a location of memory. These can either be +unpinned, and are free to live anywhere but be accessible from a device, or they +can be pinned and must live on a specific device. + +Memory spaces know their associated buffers of data, and the devices (plural) +that a memory space is associated with, as well as the client it is a part of. + +### PjRtBuffer + +_Full reference at [`pjrt_client.h > PjRtBuffer`](https://github.com/openxla/xla/blob/3e448cf9e86775a37ec5f7d3c69dfb20e0c760df/xla/pjrt/pjrt_client.h#L1111)._ + +A buffer holds data on a device in some format that will be easy to work with +inside the plugin, such as an MLIR elements attr or a proprietary tensor format. +A framework may try to send data to a device in the form of an `xla::Literal`, +i.e. for an input argument to the module, which must be cloned (or borrowed), to +the devices memory. Once a buffer is no longer needed the `Delete` method is +invoked by the framework to clean up. + +A buffer knows the memory space it is a part of, and transitively can figure out +which devices are able to access it, but buffers to not necessarily know thier +devices. + +For communicating with frameworks, buffers know how to convert to and from an +`xla::Literal` type: + +```cpp +// Literal to Buffer +absl::StatusOr> BufferFromHostBuffer(...) {...} + +// Buffer to Literal +xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...} +``` + +APIs for creating a buffer have [Buffer Semantics](https://github.com/openxla/xla/blob/3e448cf9e86775a37ec5f7d3c69dfb20e0c760df/xla/pjrt/pjrt_client.h#L858) +which help dictate if literal data from the host buffer can be shared or copied +or mutated. + +Lastly, a buffer may need last longer than the scope of its execution, if it is +assigned to a variable in the framework layer `x = jit(foo)(10)`, in these cases +buffers allow building external references which provide a temporarily owned +pointer to the data held by the buffer, along with metadata (dtype / dim sizes) +for interpreting the underlying data. + +### PjRtCompiler + +_Full reference at [`pjrt_compiler.h > PjRtCompiler`](https://github.com/openxla/xla/blob/3e448cf9e86775a37ec5f7d3c69dfb20e0c760df/xla/pjrt/pjrt_compiler.h#L157)._ + +The `PjRtCompiler` class provides useful implementation details for XLA +backends, but is not necessary for a plugin to implement. In theory, the +responsibility of a `PjRtCompiler`, or the `PjRtClient::Compile` method, is to +take an input module and return a `PjRtLoadedExecutable`. + +### PjRtExecutable / PjRtLoadedExecutable + +_Full reference at [`pjrt_executable.h > PjRtExecutable`](https://github.com/openxla/xla/blob/3e448cf9e86775a37ec5f7d3c69dfb20e0c760df/xla/pjrt/pjrt_executable.h#L306), +and [`pjrt_client.h > PjRtLoadedExecutable`](https://github.com/openxla/xla/blob/3e448cf9e86775a37ec5f7d3c69dfb20e0c760df/xla/pjrt/pjrt_client.h#L1506)._ + +A `PjRtExecutable` knows how to take a compiled artifact and execution options +and serialize/deserialize them so an executable can be stored and loaded as +needed. + +The `PjRtLoadedExecutable` is the in-memory compiled executable which is ready +for input arguments to execute, it is a subclass of `PjRtExecutable`. + +Executables are interfaced with via one of the client's `Execute` methods: + +```cpp +// Execute on addressable devices +absl::StatusOr>>> +Execute(absl::Span> argument_handles, ...) {...} + +// Execute assigned replica/partition on the specified device +absl::StatusOr>> +ExecuteSharded(absl::Span argument_handles, + PjRtDevice* device, ...) {...} + +// Execute on specified device, single replica / partition +absl::StatusOr>> +ExecutePortable(absl::Span argument_handles, + PjRtDevice* device, ...) {...} +``` + +Before calling `Execute` the framework will transfer all required data to +`PjRtBuffers` owned by the executing client, but returned for the framework to +reference. These buffers are then provided as arguments to the `Execute` method. + +## PJRT Concepts + +### PjRtFutures & Async Computations + +If any part of a plugin is implemented asynchronously, it _must_ properly +implement futures. + +Consider the following program: + +```py +@jax.jit +def foo(x): return x + 1 + +x = foo(1) +# [...] other logic not using `x` +print(x + 1) +``` + +An async plugin would be able to enqueue the computation `x`, and immediately +return a buffer which isn't ready to be read yet, but execution will populate +it. Execution can continue to enqueue necessary computations after `x`, that +don't require `x`, including execution on other PJRT devices. Once the value of +`x` is needed, execution will block until the buffer declares itself ready via +the future returned by `GetReadyFuture`. + +Futures can be useful to determine when an object becomes available, including +devices and buffers. + +### Advanced concepts + +Extending beyond implementing the base APIs will expand the features of JAX that +can be used by a plugin. These are all opt-in features in the sense that at +typical JIT and execute workflow will work without them, but for a production +quality pipeline some thought should likely be put into the degree of support +for any of these features supported by PJRT APIs: + +- Memory spaces +- Custom layouts +- Communication ops like send/recv +- Host offloading +- Sharding + +## Typical PJRT framework-device communication + +### Example Log + +The following is a log of the methods called to load the PJRT plugin and +execute `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`. In this case +we log JAX interacting with the StableHLO Reference PJRT plugin. + +
+Example log +
+
+```
+//////////////////////////////////
+// Load the plugin
+//////////////////////////////////
+
+I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
+I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
+I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
+I device.cc:123] client(0x23bac4e0)
+I device.cc:123] client(0x23bac4e0)
+I client_cpp_pjrt.cc:71] process_index(0x23bac400)
+I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
+I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
+I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
+I client_cpp_pjrt.cc:86] devices(0x23bac400)
+I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
+I device.cc:168] description(0x23bac4e0)
+I device.cc:168] description(0x23bac4e0)
+I device.cc:86] Attributes(0x23bac4f8)
+I device.cc:128] IsAddressable(0x23bac4e0)
+I device.cc:168] description(0x23bac4e0)
+I device.cc:61] process_index(0x23bac4f8)
+I device.cc:123] client(0x23bac4e0)
+I client_cpp_pjrt.cc:71] process_index(0x23bac400)
+I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
+I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
+I device.cc:128] IsAddressable(0x23bac4e0)
+I device.cc:168] description(0x23bac4e0)
+I device.cc:61] process_index(0x23bac4f8)
+I device.cc:123] client(0x23bac4e0)
+I client_cpp_pjrt.cc:71] process_index(0x23bac400)
+I device.cc:148] memory_spaces(0x23bac4e0)
+Creating PJRT Client from client
+I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
+I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
+I device.cc:57] id(0x23bac4f8)
+I device.cc:70] device_kind(0x23bac4f8)
+I device.cc:70] device_kind(0x23bac4f8)
+I device.cc:80] ToString(0x23bac4f8)
+I device.cc:80] ToString(0x23bac4f8)
+I device.cc:75] DebugString(0x23bac4f8)
+I device.cc:75] DebugString(0x23bac4f8)
+I device.cc:61] process_index(0x23bac4f8)
+I device.cc:128] IsAddressable(0x23bac4e0)
+I device.cc:168] description(0x23bac4e0)
+I device.cc:61] process_index(0x23bac4f8)
+I device.cc:123] client(0x23bac4e0)
+I client_cpp_pjrt.cc:71] process_index(0x23bac400)
+I device.cc:153] default_memory_space(0x23bac4e0)
+I client_cpp_pjrt.cc:71] process_index(0x23bac400)
+
+//////////////////////////////////
+// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
+//////////////////////////////////
+
+I executable.cc:309] num_partitions(0x240bab70)
+I executable.cc:305] num_replicas(0x240bab70)
+I executable.cc:309] num_partitions(0x240bab70)
+I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
+I buffer.cc:285] CreateMlirBufferFromLiteral
+I buffer.cc:98] CreateFromLiteral
+I buffer.cc:99] CreateFromLiteral: s32[] 2
+I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
+I buffer.cc:102] CreateFromLiteral -> 0x240bb050
+I buffer.cc:158] device(0x240bb050)
+I buffer.cc:154] memory_space(0x240bb050)
+I buffer.cc:154] memory_space(0x240bb050)
+I executable.cc:328] GetHloModules(0x240bab70)
+I executable.cc:240] Execute(0x240bab70)
+I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
+I buffer.cc:303] GetAttributeFromBuffer
+I buffer.cc:229] IsDeleted(0x240bb050)
+I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor
+I executable.cc:205] EvalModule:
+module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+  func.func public @main(%arg0: tensor {mhlo.layout_mode = "default"}) -> (tensor {jax.result_info = "", mhlo.layout_mode = "default"}) {
+    // ...
+    return %3 : tensor
+  }
+}
+I executable.cc:206] Inputs: [dense<2> : tensor]
+I executable.cc:213] Results: [dense<2> : tensor]
+I device.cc:153] default_memory_space(0x23bac4e0)
+I buffer.cc:291] CreateMlirBufferFromAttribute
+I buffer.cc:116] CreateFromAttribute
+I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
+I buffer.cc:122] CreateFromAttribute(dense<2> : tensor) -> 0x22cea630
+
+//////////////////////////////////
+// RUN: `print(y)`
+//////////////////////////////////
+
+I buffer.cc:263] GetReadyFuture(0x22cea630)
+I buffer.cc:264] GetReadyFuture(0x22cea630)
+I buffer.cc:154] memory_space(0x22cea630)
+I buffer.cc:154] memory_space(0x22cea630)
+I buffer.cc:158] device(0x22cea630)
+I buffer.cc:158] device(0x22cea630)
+I buffer.cc:154] memory_space(0x22cea630)
+I buffer.cc:154] memory_space(0x22cea630)
+I buffer.cc:229] IsDeleted(0x22cea630)
+I buffer.cc:129] on_device_shape(0x22cea630)
+I buffer.cc:129] on_device_shape(0x22cea630)
+I buffer.cc:129] on_device_shape(0x22cea630)
+I buffer.cc:158] device(0x22cea630)
+I buffer.cc:154] memory_space(0x22cea630)
+I buffer.cc:154] memory_space(0x22cea630)
+I client_cpp_pjrt.cc:71] process_index(0x23bac400)
+I buffer.cc:229] IsDeleted(0x22cea630)
+I buffer.cc:129] on_device_shape(0x22cea630)
+I buffer.cc:129] on_device_shape(0x22cea630)
+I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
+I buffer.cc:129] on_device_shape(0x22cea630)
+I buffer.cc:129] on_device_shape(0x22cea630)
+I buffer.cc:129] on_device_shape(0x22cea630)
+I buffer.cc:129] on_device_shape(0x22cea630)
+I buffer.cc:129] on_device_shape(0x22cea630)
+I buffer.cc:168] AcquireExternalReference(0x22cea630)
+I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
+I buffer.cc:303] GetAttributeFromBuffer
+I buffer.cc:229] IsDeleted(0x22cea630)
+I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor
+I buffer.cc:291] CreateMlirBufferFromAttribute
+I buffer.cc:116] CreateFromAttribute
+I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
+I buffer.cc:122] CreateFromAttribute(dense<2> : tensor) -> 0x240bb050
+I buffer.cc:168] AcquireExternalReference(0x22cea630)
+I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
+I buffer.cc:303] GetAttributeFromBuffer
+I buffer.cc:229] IsDeleted(0x22cea630)
+I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor
+I buffer.cc:291] CreateMlirBufferFromAttribute
+I buffer.cc:116] CreateFromAttribute
+I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
+I buffer.cc:122] CreateFromAttribute(dense<2> : tensor) -> 0x23b2db60
+I buffer.cc:263] GetReadyFuture(0x22cea630)
+I buffer.cc:264] GetReadyFuture(0x22cea630)
+```
+
+
diff --git a/docs/pjrt/examples.md b/docs/pjrt/examples.md new file mode 100644 index 0000000000000..a9ae11f9a4c0f --- /dev/null +++ b/docs/pjrt/examples.md @@ -0,0 +1,38 @@ +# PJRT Examples + +## Example: JAX CUDA plugin + +1. PJRT C API implementation through wrapper ([pjrt\_c\_api\_gpu.h](https://github.com/openxla/xla/blob/c23fbd601a017be25726fd6d624b22daa6a8a4e5/xla/pjrt/c/pjrt_c_api_gpu.h)). +1. Set up the entry point for the package ([setup.py](https://github.com/google/jax/blob/main/jax_plugins/cuda/setup.py)). +1. Implement an initialize() method ([\_\_init\_\_.py](https://github.com/google/jax/blob/a10854786b6d1bc92a65dd314916b151640789af/plugins/cuda/__init__.py#L31-L51)). +1. Can be tested with any jax tests for CUDA. + + +## Frameworks Implementations + +Some references for using PJRT on the framework side, to interface with PJRT +devices: + +- JAX + + [jax-ml/jax](https://github.com/jax-ml/jax/blob/main/jax/_src/compiler.py#L248) + interacts with PJRT APIs via the `xla_client` APIs +- GoMLX + + [gomlx/gopjrt](https://github.com/gomlx/gopjrt) + + [gomlx/gomlx > backends/xla](https://github.com/gomlx/gomlx/tree/main/backends/xla/xla.go) +- ZML + + PJRT API wrapper [pjrt.zig](https://github.com/zml/zml/blob/master/pjrt/pjrt.zig) + + Load PJRT Plugin [context.zig](https://github.com/zml/zml/blob/master/zml/context.zig#L30-L34) + + Interacting with PJRT Buffers [buffer.zig](https://github.com/zml/zml/blob/master/zml/buffer.zig#L36) + + Execute a module via PJRT [module.zig](https://github.com/zml/zml/blob/master/zml/module.zig#L863-L886) + +## Hardware Implementations + +- Full integration plugins (PJRT+MLIR+XLA): + + [XLA CPU Plugin](https://github.com/openxla/xla/tree/main/xla/pjrt/cpu/cpu_client.cc) + + [XLA GPU Plugin](https://github.com/openxla/xla/tree/main/xla/pjrt/gpu/se_gpu_pjrt_client.cc) + + [Intel XLA Plugin](https://github.com/intel/intel-extension-for-openxla) +- Light integration plugins (PJRT+MLIR): + + StableHLO Reference Interpreter plugin + (MLIR-based, C++ plugin, to be linked after devlabs) + + [Tenstorrent-XLA plugin](https://github.com/tenstorrent/tt-xla/blob/main/src/common/api_impl.cc) + (MLIR-based, C plugin) diff --git a/docs/pjrt/images/pjrt_client.svg b/docs/pjrt/images/pjrt_client.svg new file mode 100644 index 0000000000000..35e95bcd755c5 --- /dev/null +++ b/docs/pjrt/images/pjrt_client.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/pjrt/index.md b/docs/pjrt/index.md new file mode 100644 index 0000000000000..92c3bcb455afa --- /dev/null +++ b/docs/pjrt/index.md @@ -0,0 +1,30 @@ +Project: /_project.yaml +Book: /_book.yaml + + + +# PJRT - Uniform Device API + +PJRT C API is the uniform Device API that we want to add to the ML ecosystem. +The long term vision is that: (1) frameworks (TF, JAX, etc.) will call PJRT, +which has device-specific implementations that are opaque to the frameworks; (2) +each device focuses on implementing PJRT APIs as PJRT plugins, which can be +opaque to the frameworks. + +## Communication channels + +* Issues and feature requests can be filed in the [OpenXLA/xla repo](https://github.com/openxla/xla). +* Questions regarding PJRT can be asked on the [OpenXLA Discord][discord]. + +## Resources + +* [PJRT C API header](https://github.com/openxla/xla/blob/main/xla/pjrt/c/pjrt_c_api.h) +* [PJRT C API changelog](https://github.com/openxla/xla/blob/main/xla/pjrt/c/CHANGELOG.md) +* [PJRT integration guide](https://github.com/openxla/xla/blob/main/xla/pjrt/c/docs/pjrt_integration_guide.md) +* [PJRT design docs](https://drive.google.com/drive/folders/18M944-QQPk1E34qRyIjkqDRDnpMa3miN) +* [PJRT API ABI versioning and compatibility](https://docs.google.com/document/d/1TKB5NyGtdzrpgw5mpyFjVAhJjpSNdF31T6pjPl_UT2o/edit) +* [PJRT Plugin Mechanism design doc](https://docs.google.com/document/d/1Qdptisz1tUPGn1qFAVgCV2omnfjN01zoQPwKLdlizas/edit) +* [OpenXLA/IREE PJRT plugin implementation](https://github.com/openxla/openxla-pjrt-plugin) + + +[discord]: https://discord.gg/ZKXq7b3V8A "Join on Discord" \ No newline at end of file diff --git a/docs/pjrt_integration.md b/docs/pjrt/pjrt_integration.md similarity index 88% rename from docs/pjrt_integration.md rename to docs/pjrt/pjrt_integration.md index c8141b158b91a..6c23d7ff17335 100644 --- a/docs/pjrt_integration.md +++ b/docs/pjrt/pjrt_integration.md @@ -1,10 +1,7 @@ # PJRT plugin integration -## Background - -[PJRT](https://github.com/openxla/xla/blob/c23fbd601a017be25726fd6d624b22daa6a8a4e5/xla/pjrt/c/pjrt_c_api.h) is the uniform Device API that we want to add to the ML ecosystem. The long term vision is that: (1) frameworks (JAX, TF, etc.) will call PJRT, which has device-specific implementations that are opaque to the frameworks; (2) each device focuses on implementing PJRT APIs, and can be opaque to the frameworks. - -This doc focuses on the recommendations about how to integrate with PJRT, and how to test PJRT integration with JAX. +This doc focuses on the recommendations about how to integrate with PJRT, and +how to test PJRT integration with JAX. ## How to integrate with PJRT @@ -140,14 +137,8 @@ jax.lax.psum(x, 'i'), axis_name='i')(arr)) # single device: [0] # 4 devices: [6 7 8 9] - ``` -(We'll add instructions for running the jax unit tests against your plugin soon!) -## Example: JAX CUDA plugin +(We'll add instructions for running the jax unit tests against your plugin soon!) -1. PJRT C API implementation through wrapper ([pjrt\_c\_api\_gpu.h](https://github.com/openxla/xla/blob/c23fbd601a017be25726fd6d624b22daa6a8a4e5/xla/pjrt/c/pjrt_c_api_gpu.h)). -1. Set up the entry point for the package ([setup.py](https://github.com/google/jax/blob/main/jax_plugins/cuda/setup.py)). -1. Implement an initialize() method ([\_\_init\_\_.py](https://github.com/google/jax/blob/a10854786b6d1bc92a65dd314916b151640789af/plugins/cuda/__init__.py#L31-L51)). -1. Can be tested with any jax tests for CUDA. -``` +For more examples of PJRT plugins see [PJRT Examples](examples.md). \ No newline at end of file diff --git a/xla/pjrt/c/README.md b/xla/pjrt/c/README.md index c567bbc9a0eb6..acc5923a2735b 100644 --- a/xla/pjrt/c/README.md +++ b/xla/pjrt/c/README.md @@ -15,8 +15,9 @@ opaque to the frameworks. * [PJRT C API header](https://github.com/openxla/xla/blob/main/xla/pjrt/c/pjrt_c_api.h) * [PJRT C API changelog](https://github.com/openxla/xla/blob/main/xla/pjrt/c/CHANGELOG.md) -* [PJRT integration guide](https://github.com/openxla/xla/blob/main/xla/pjrt/c/docs/pjrt_integration_guide.md) -* [PJRT design docs](https://drive.google.com/corp/drive/folders/18M944-QQPk1E34qRyIjkqDRDnpMa3miN) +* [PJRT C++ API Overview](https://github.com/openxla/xla/blob/main/xla/docs/pjrt/cpp_api_overview.md) +* [PJRT integration guide](https://github.com/openxla/xla/blob/main/xla/docs/pjrt/pjrt_integration.md) +* [PJRT design docs](https://drive.google.com/drive/folders/18M944-QQPk1E34qRyIjkqDRDnpMa3miN) * [PJRT API ABI versioning and compatibility](https://docs.google.com/document/d/1TKB5NyGtdzrpgw5mpyFjVAhJjpSNdF31T6pjPl_UT2o/edit) * [PJRT Plugin Mechanism design doc](https://docs.google.com/document/d/1Qdptisz1tUPGn1qFAVgCV2omnfjN01zoQPwKLdlizas/edit) * [OpenXLA/IREE PJRT plugin implementation](https://github.com/openxla/openxla-pjrt-plugin) diff --git a/xla/pjrt/c/docs/pjrt_integration_guide.md b/xla/pjrt/c/docs/pjrt_integration_guide.md index 2dcb5c4f4d807..e822bb5cf05dc 100644 --- a/xla/pjrt/c/docs/pjrt_integration_guide.md +++ b/xla/pjrt/c/docs/pjrt_integration_guide.md @@ -1,4 +1,4 @@ # PJRT integration guide This file has been moved into the root -[XLA documentation directory](https://github.com/openxla/xla/blob/main/docs/pjrt_integration.md). +[XLA documentation directory](https://github.com/openxla/xla/blob/main/docs/pjrt/pjrt_integration.md).