How to differentiate StableHLO with Enzyme-JAX from C++? #195
Comments
closing as I think I've missed sth, may reopen |
could you reopen please? All good |
It depends on your setup, but my recommendation would be to emit an Enzyme autodiff or forward-diff op; you can then run the Enzyme pass, which will replace the op with a call to the derivative. |
OK thanks. I'll need to do some reading to understand (I'm unfamiliar with "emit" and "pass"). Is this all in the Enzyme-JAX repo, or is some of it in the Enzyme repo? |
Partially. Just as StableHLO has ops for things, there are also Enzyme autodiff ops. In text form you can see one, for example, here: Enzyme-JAX/test/lit_tests/grad_sum1d.mlir Line 11 in fb483c0
You can then run an optimization pass which generates the derivative (creating the code in the comment at the bottom, which is what the test compares against). |
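As a rough illustration of what "emitting" the op amounts to in text form, the sketch below assembles an MLIR module as a string using only the C++ standard library. The helper name `makeAutodiffModule` is hypothetical, and the exact `enzyme.autodiff` syntax is an assumption modelled on the attribute names used elsewhere in this thread (`activity`, `ret_activity`, `enzyme_active`, `enzyme_activenoneed`); check it against the lit test referenced above before relying on it.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper: returns the text of an MLIR module that defines a
// scalar square function @<fn> and a wrapper @d<fn> which calls
// `enzyme.autodiff` on it. The op syntax is an assumption, not taken
// verbatim from Enzyme-JAX.
std::string makeAutodiffModule(const std::string &fn) {
  assert(!fn.empty() && "function name must not be empty");
  const std::string ty = "tensor<f64>";
  std::string m;
  m += "module {\n";
  // The function to differentiate: y = x * x.
  m += "  func.func @" + fn + "(%x: " + ty + ") -> " + ty + " {\n";
  m += "    %y = stablehlo.multiply %x, %x : " + ty + "\n";
  m += "    return %y : " + ty + "\n";
  m += "  }\n";
  // The wrapper: takes the input and the seed for the result's gradient,
  // and returns the gradient with respect to the input.
  m += "  func.func @d" + fn + "(%x: " + ty + ", %dy: " + ty + ") -> " + ty + " {\n";
  m += "    %dx = enzyme.autodiff @" + fn + "(%x, %dy) {\n";
  m += "      activity = [#enzyme<activity enzyme_active>],\n";
  m += "      ret_activity = [#enzyme<activity enzyme_activenoneed>]\n";
  m += "    } : (" + ty + ", " + ty + ") -> " + ty + "\n";
  m += "    return %dx : " + ty + "\n";
  m += "  }\n";
  m += "}\n";
  return m;
}
```

The resulting string could then be parsed into a module or written to a file and handed to a tool that runs the Enzyme pass, which is the step that would replace the `enzyme.autodiff` op with the generated derivative.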
thanks, and are there C++ apis for this? Looking at the Julia implementation "emit" seems to defer to the MLIR C++ API. I'm not sure what running an optimization pass looks like. Is that with btw if you don't have time for all these questions, just say |
Yeah there are. For an example of running passes you can look here
Here you should just need to run the Enzyme pass, and the Enzyme remove unnecessary ops pass |
thanks. I did look at that function, but it's not available in the public API of Enzyme-JAX. I might just copy its intention. I more meant APIs for building This is very helpful. I can get a lot further with this |
Where can I find the relevant dialects and passes? I tried
and
but that's not working. I don't know if that's the problem. Much of the relevant machinery seems to be private, and there are many many different functions for dialects and passes so I'm lost on what's needed. |
BTW once I work out how to do this, I might make a small C++ library for differentiating StableHLO, one that's not specific to any particular frontend |
isn't that just the |
what's that? |
There’s a binary, enzymexlamlir-opt, that takes StableHLO (and other general MLIR files) as input and a list of optimizations (including differentiation) as args, and prints the transformed code to stdout or a file of your choice. Or @mofeing, did you mean the interpreter (which works similarly but assumes the code is composed of constants, and essentially does the transformation and generates the final constant result)? |
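To make the binary route concrete from C++, here is a small sketch that assembles such a command line. The helper name `makeOptCommand` is hypothetical, and the pass flags shown in the usage note (`--enzyme`, `--remove-unnecessary-enzyme-ops`) and the `-o` output option are assumptions based on how Enzyme's MLIR tests and mlir-opt-style tools are usually invoked; confirm them with `enzymexlamlir-opt --help`.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical helper: builds a shell command that runs enzymexlamlir-opt on
// `input`, applies each pass in `passes` in order, and writes to `output`.
// No shell quoting is performed, so paths must not contain spaces.
std::string makeOptCommand(const std::string &input, const std::string &output,
                           const std::vector<std::string> &passes) {
  assert(!input.empty() && !output.empty());
  std::string cmd = "enzymexlamlir-opt " + input;
  for (const std::string &pass : passes)
    cmd += " --" + pass; // each pass is passed as a `--name` flag
  cmd += " -o " + output;
  return cmd;
}
```

For example, `makeOptCommand("in.mlir", "out.mlir", {"enzyme", "remove-unnecessary-enzyme-ops"})` yields a command that could be passed to `std::system`. This shells out rather than linking against Enzyme-JAX, so it sidesteps the dialect registration questions at the cost of a round trip through files.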
i mean that the interpreter does what @joelberkeley wants but in a binary manner (take MLIR files and return result of transformations or perform actual interpretation and return values). like what he wants is sth similar to the interpreter but "librarizing" it |
I'm differentiating a C++ |
In any case can you post your full code and error message and we can try to help see if there’s something missing (likely registering one of the MLIR interfaces). |
I can paste it, though it's not even a first draft yet, and it's a combination of two different languages, so I'm not sure it will help. I'm just trying to differentiate a
computation <- compile xlaBuilder f
stablehlo <- hloModuleProtoToStableHLO !(proto computation)
reg <- mkDialectRegistry
insertEnzymeDialect reg
StableHLO.Dialect.Register.registerAllDialects reg
ctx <- getContext stablehlo
appendDialectRegistry ctx reg
mgr <- mkPassManager ctx
addPass mgr !createDifferentiatePass
enzymeOp <- emitEnzymeADOp stablehlo reg
_ <- run mgr enzymeOp
hloProto <- convertStablehloToHlo stablehlo
computation <- mkXlaComputation hloProto
and
mlir::ModuleOp* emitEnzymeADOp(mlir::ModuleOp& module_op, mlir::DialectRegistry& registry) {
  mlir::enzyme::registerCoreDialectAutodiffInterfaces(registry);
  auto ctx = module_op.getContext();
  auto state = mlir::OperationState(mlir::UnknownLoc::get(ctx), "enzyme.autodiff");
  auto scalarf64 = mlir::RankedTensorType::get({}, mlir::FloatType::getF64(ctx));
  state.addTypes({scalarf64});
  auto operands = module_op.getOperation()->getOperands(); // complete guess
  state.addOperands(mlir::ValueRange(operands));
  auto operation = module_op.getOperation(); // complete guess
  state.addAttribute("fn", operation->getAttr("sym_name"));
  auto activity = mlir::enzyme::ActivityAttr::get(ctx, mlir::enzyme::Activity::enzyme_active);
  state.addAttribute("activity", {activity});
  auto ret_activity = mlir::enzyme::ActivityAttr::get(
    ctx, mlir::enzyme::Activity::enzyme_activenoneed
  );
  state.addAttribute("ret_activity", {ret_activity});
  auto res = mlir::Operation::create(state);
  return new mlir::ModuleOp(res);
}
Error
|
hm yeah I think you're not initializing the Enzyme dialect from the looks of the error message. This function here is extremely overkill but it should definitely add it:
Also considering you're explicitly adding the op itself (and it's not being generated by a different pass or parser) you may need to explicitly load the dialect in the context |
yeah I saw that function, but it's not public. I did try to reproduce it but its contents aren't public either |
Feel free to make a PR to make anything public that you need |
ok thanks. I'm going to head off now (it's late here). I'll come back to this in a few days. Seasons greetings! |
I did quite a bit more digging today, but didn't progress. Here are some notes. I added I noticed the call to
Might that be the cause? I noticed I'm still keen to make a StableHLO autodiff library that's decoupled from JAX, but I might not have time. |
No, that’s unrelated (it’s that we recently added an xla dialect for optimizing kernel calls but haven’t yet added custom attributes, if we do well I comment that). The autodiff code here is independent of JAX (it is in this repo to make building the JAX plugin easier). Perhaps it should be renamed enzymexla, and there’s also an ongoing discussion about moving it into StableHLO proper. This is all setup for MLIR (which is unfortunately not the clearest). My recommendation: take the enzymexlamlir-opt.cpp binary, copy it into a library file, and call the pass manager with your op instead of parsing a new one in from a file. That way you’ll have something with all the setup properly done (and it’s easier to remove excess registration once it’s running, imo). |
To make things more concrete: take this
Alternatively if you can share a repo with your whole setup we can try to take a look and fiddle with MLIR’s setup to make sure things are registered |
ok, I've copy-pasted that function, and put its contents before my code. I still get the same error. I will compose all my stuff into a function that's as self-contained as possible and paste it here |
I've edited this code to be more self-contained; it now produces an executable binary instead of a library. I might get round to making it into a git repo.
#include "stablehlo/dialect/Register.h"
#include "xla/hlo/builder/xla_builder.h"
#include "xla/hlo/translate/stablehlo.h"
#include "xla/hlo/builder/lib/math.h"
#include "xla/mlir_hlo/mhlo/IR/register.h"
#include "Enzyme/MLIR/Dialect/Dialect.h"
#include "Enzyme/MLIR/Dialect/Ops.h"
#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h"
#include "Enzyme/MLIR/Passes/Passes.h"
#include "src/enzyme_ad/jax/Dialect/Dialect.h"
#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h"
#include "src/enzyme_ad/jax/Passes/Passes.h"
#include "src/enzyme_ad/jax/TransformOps/TransformOps.h"
#include "src/enzyme_ad/jax/RegistryUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Target/LLVM/NVVM/Target.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
#include "llvm/Support/TargetSelect.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/tests/CheckOps.h"
class MemRefInsider
    : public mlir::MemRefElementTypeInterface::FallbackModel<MemRefInsider> {};

template <typename T>
struct PtrElementModel
    : public mlir::LLVM::PointerElementTypeInterface::ExternalModel<
          PtrElementModel<T>, T> {};

int main() {
  // create the stablehlo computation
  xla::XlaBuilder builder("root");
  auto xlaScalarf64 = xla::ShapeUtil::MakeScalarShape((xla::PrimitiveType) 12);
  auto arg = xla::Parameter(&builder, 0, xlaScalarf64, "arg");
  auto proto = builder.Build(xla::Square(arg))->proto();

  mlir::MLIRContext ctx;
  mlir::DialectRegistry registry_;
  ctx.appendDialectRegistry(registry_);
  mlir::mhlo::registerAllMhloDialects(registry_);
  mlir::stablehlo::registerAllDialects(registry_);
  auto module_op_ = xla::ConvertHloToStablehlo(ctx, &proto).value().release();

  // stuff copied from enzyme mlir main function
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  registry_.insert<mlir::stablehlo::check::CheckDialect>();
  prepareRegistry(registry_);
  mlir::registerenzymePasses();
  regsiterenzymeXLAPasses();
  mlir::registerCSEPass();
  mlir::registerConvertAffineToStandardPass();
  mlir::registerSCCPPass();
  mlir::registerInlinerPass();
  mlir::registerCanonicalizerPass();
  mlir::registerSymbolDCEPass();
  mlir::registerLoopInvariantCodeMotionPass();
  mlir::registerConvertSCFToOpenMPPass();
  mlir::affine::registerAffinePasses();
  mlir::registerReconcileUnrealizedCasts();
  registry_.addExtension(+[](mlir::MLIRContext *ctx, mlir::LLVM::LLVMDialect *dialect) {
    mlir::LLVM::LLVMFunctionType::attachInterface<MemRefInsider>(*ctx);
    mlir::LLVM::LLVMArrayType::attachInterface<MemRefInsider>(*ctx);
    mlir::LLVM::LLVMPointerType::attachInterface<MemRefInsider>(*ctx);
    mlir::LLVM::LLVMStructType::attachInterface<MemRefInsider>(*ctx);
    mlir::MemRefType::attachInterface<PtrElementModel<mlir::MemRefType>>(*ctx);
    mlir::LLVM::LLVMStructType::attachInterface<
        PtrElementModel<mlir::LLVM::LLVMStructType>>(*ctx);
    mlir::LLVM::LLVMPointerType::attachInterface<
        PtrElementModel<mlir::LLVM::LLVMPointerType>>(*ctx);
    mlir::LLVM::LLVMArrayType::attachInterface<PtrElementModel<mlir::LLVM::LLVMArrayType>>(*ctx);
  });
  mlir::transform::registerInterpreterPass();
  mlir::enzyme::registerGenerateApplyPatternsPass();
  mlir::enzyme::registerRemoveTransformPass();

  // attempt to create an `enzyme.autodiff` op
  auto state = mlir::OperationState(mlir::UnknownLoc::get(&ctx), "enzyme.autodiff");
  auto scalarf64 = mlir::RankedTensorType::get({}, mlir::FloatType::getF64(&ctx));
  state.addTypes({scalarf64});
  auto operands = module_op_.getOperation()->getOperands(); // complete guess
  state.addOperands(mlir::ValueRange(operands));
  auto operation = module_op_.getOperation(); // complete guess
  state.addAttribute("fn", operation->getAttr("sym_name"));
  auto activity = mlir::enzyme::ActivityAttr::get(&ctx, mlir::enzyme::Activity::enzyme_active);
  state.addAttribute("activity", {activity});
  auto ret_activity = mlir::enzyme::ActivityAttr::get(
    &ctx, mlir::enzyme::Activity::enzyme_activenoneed
  );
  state.addAttribute("ret_activity", {ret_activity});
  auto res = mlir::Operation::create(state);
  return 0;
}
Here is the bazel target for the above
where I've added this target to Enzyme-JAX
|
I'm going to count my chickens and say you can test this using this change, which is basically the above but inserted into this repo https://github.com/EnzymeAD/Enzyme-JAX/compare/main...joelberkeley:Enzyme-JAX:example?expand=1 with
|
ok that has linker errors. I may come back to it. The code itself "works", in that it fails at runtime as described. |
I've made progress: I added |
I'm trying to differentiate a StableHLO mlir::ModuleOp, and I'm lost. I'm starting by trying to interface with Enzyme-JAX, but I've noticed that the only source file exported by bazel is enzymexlamlir-opt.cpp. I did try to use registerStableHLODialectAutoDiffInterface but that's not exported. What's the recommended usage of Enzyme-JAX from C++?