How to differentiate StableHLO with Enzyme-JAX from C++? #195

Open
joelberkeley opened this issue Dec 16, 2024 · 29 comments

@joelberkeley

I'm trying to differentiate a StableHLO mlir::ModuleOp, and I'm lost. I'm starting by trying to interface with Enzyme-JAX, but I've noticed that the only source file exported by bazel is enzymexlamlir-opt.cpp. I did try to use registerStableHLODialectAutoDiffInterface but that's not exported. What's the recommended usage of Enzyme-JAX from C++?

@joelberkeley joelberkeley changed the title How to use Enzyme-JAX from C++? How to differentiate StableHLO with Enzyme-JAX from C++? Dec 16, 2024
@joelberkeley
Author

closing as I think I've missed something, may reopen

@joelberkeley
Author

could you reopen please? All good

@wsmoses wsmoses reopened this Dec 16, 2024
@wsmoses
Member

wsmoses commented Dec 16, 2024

It depends on the setup, but my recommendation would be to emit an Enzyme autodiff (reverse mode) or fwddiff (forward mode) op. You can then run the Enzyme pass, which will replace the op with a call to the derivative.

@joelberkeley
Author

joelberkeley commented Dec 16, 2024

OK thanks. I'll need to do some reading to understand (I'm unfamiliar with "emit" and "pass").

Is this all in Enzyme-JAX repo or is some in Enzyme repo?

@wsmoses
Member

wsmoses commented Dec 17, 2024

Partially.

Just as StableHLO has ops for things, there are also Enzyme autodiff ops.

In text form you can see it, for example, here:

%0 = enzyme.autodiff @"Const{typeof(simple_reduce)}(simple_reduce)_autodiff"(%arg0, %cst) {activity = [#enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_activenoneed>]} : (tensor<5x3xf32>, tensor<3xf32>) -> (tensor<5x3xf32>)

You can then run an optimization pass, which generates the derivative (producing the code shown in the comment at the bottom, which is what the test compares against).
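One way to see how the op fits together, without going through the C++ builder APIs, is to emit the wrapper as MLIR text, which can then be parsed (for example with mlir::parseSourceString) or fed to enzymexlamlir-opt. The sketch below is a hypothetical helper, gradWrapper, that produces a func.func calling enzyme.autodiff on a single-argument function. It assumes the same activity pattern as the example above: one enzyme_active argument and a return marked enzyme_activenoneed, so the caller passes a seed for the output and gets back only the gradient. The wrapper and argument names are illustrative.

```cpp
#include <string>

// Hypothetical helper: returns the MLIR text of a gradient wrapper
// @grad_<fn>(%x, %seed) for an existing function @<fn> of type (T) -> T.
// Activities mirror the thread's example: the argument is enzyme_active and
// the return is enzyme_activenoneed, so the op takes (x, seed) and yields
// only the gradient with respect to x, scaled by seed.
std::string gradWrapper(const std::string& fn, const std::string& type) {
  const std::string opType = "(" + type + ", " + type + ") -> " + type;
  return "func.func @grad_" + fn + "(%x: " + type + ", %seed: " + type +
         ") -> " + type + " {\n"
         "  %0 = enzyme.autodiff @" + fn + "(%x, %seed) {"
         "activity = [#enzyme<activity enzyme_active>], "
         "ret_activity = [#enzyme<activity enzyme_activenoneed>]} : " +
         opType + "\n"
         "  return %0 : " + type + "\n"
         "}\n";
}
```

For gradWrapper("square", "tensor<f64>") the emitted op has type (tensor<f64>, tensor<f64>) -> tensor<f64>. The referenced @square must be defined in the same module, and the Enzyme dialect must be available in the context before the text is parsed.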

@joelberkeley
Author

Thanks. Are there C++ APIs for this? Looking at the Julia implementation, "emit" seems to defer to the MLIR C++ API.

I'm not sure what running an optimization pass looks like. Is that with PassManager? Which passes are needed for enzyme.autodiff?

btw if you don't have time for all these questions, just say

@wsmoses
Member

wsmoses commented Dec 17, 2024

Yeah there are. For an example of running passes you can look here

void run_pass_pipeline(mlir::Operation *mod, const std::string &pass_pipeline) {
which takes a module op and a string containing which passes to run and constructs a pass pipeline and runs it. You can also just add the passes using the C++ API.

Here you should only need to run the Enzyme pass and the Enzyme remove-unnecessary-ops pass.
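If you reimplement the idea behind run_pass_pipeline rather than linking it, the string it takes is MLIR's textual pass-pipeline syntax, which mlir::parsePassPipeline turns into passes on a pass manager. The helper below (makePipeline, a hypothetical name) only assembles that string. The pass names enzyme and remove-unnecessary-enzyme-ops are assumptions based on the two passes mentioned above; check them against enzymexlamlir-opt --help for your version.

```cpp
#include <string>
#include <vector>

// Builds a textual MLIR pass pipeline such as
// "builtin.module(enzyme,remove-unnecessary-enzyme-ops)".
// `anchor` is the op the pipeline is nested under; passes run in list order.
std::string makePipeline(const std::vector<std::string>& passes,
                         const std::string& anchor = "builtin.module") {
  std::string joined;
  for (const std::string& pass : passes) {
    if (!joined.empty()) joined += ",";
    joined += pass;
  }
  return anchor + "(" + joined + ")";
}
```

The resulting string is what you would pass as the pipeline argument, or parse into your own PassManager before running it on the module.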

@joelberkeley
Author

joelberkeley commented Dec 17, 2024

thanks. I did look at that function, but it's not available in the public API of Enzyme-JAX. I might just copy its intention.

I meant more the APIs for building an enzyme.autodiff op.
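Whichever API ends up building the op, its operand and result lists have to agree with the activity and ret_activity attributes. The sketch below encodes that convention as read from the enzyme.autodiff example quoted earlier in the thread (one active argument plus one seed in, one gradient out). The general rule is an assumption, not taken from Enzyme's source: primal arguments come first, then one seed per active return; results are the primal values for returns not marked noneed, then one gradient per active argument. The dup activities are not modelled, and the enum and function names are hypothetical.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Assumed subset of Enzyme activity kinds (printed as #enzyme<activity ...>).
enum class Activity { Const, ConstNoNeed, Active, ActiveNoNeed };

// Describes the expected operands and results of a reverse-mode
// enzyme.autodiff op, in order. Purely illustrative bookkeeping.
struct AutodiffSignature {
  std::vector<std::string> operands;
  std::vector<std::string> results;
};

AutodiffSignature reverseSignature(const std::vector<Activity>& args,
                                   const std::vector<Activity>& rets) {
  AutodiffSignature sig;
  // Every primal argument is passed, whatever its activity.
  for (std::size_t i = 0; i < args.size(); ++i)
    sig.operands.push_back("arg" + std::to_string(i));
  // Each active return needs a seed (the incoming gradient of that output).
  for (std::size_t i = 0; i < rets.size(); ++i)
    if (rets[i] == Activity::Active || rets[i] == Activity::ActiveNoNeed)
      sig.operands.push_back("seed" + std::to_string(i));
  // Primal results are returned unless the return is marked "noneed".
  for (std::size_t i = 0; i < rets.size(); ++i)
    if (rets[i] == Activity::Active || rets[i] == Activity::Const)
      sig.results.push_back("primal" + std::to_string(i));
  // One gradient per active argument.
  for (std::size_t i = 0; i < args.size(); ++i)
    if (args[i] == Activity::Active)
      sig.results.push_back("grad" + std::to_string(i));
  return sig;
}
```

For the scalar case in this thread, reverseSignature({Activity::Active}, {Activity::ActiveNoNeed}) gives two operands (arg0, seed0) and one result (grad0), matching the quoted example's shape.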

This is very helpful. I can get a lot further with this

@joelberkeley
Author

joelberkeley commented Dec 23, 2024

Where can I find the relevant dialects and passes? I tried

registry.insert<mlir::enzyme::EnzymeDialect>();
mlir::stablehlo::registerAllDialects(registry);

and

mlir::enzyme::createDifferentiatePass();

but that's not working. I don't know if that's the problem.

Much of the relevant machinery seems to be private, and there are many different functions for dialects and passes, so I'm lost on what's needed.

@joelberkeley
Author

joelberkeley commented Dec 23, 2024

BTW once I work out how to do this, I might make a small C++ library for differentiating StableHLO, one that's not specific to any particular frontend

@mofeing
Collaborator

mofeing commented Dec 23, 2024

BTW once I work out how to do this, I might make a small C++ library for differentiating StableHLO, one that's not specific to any particular frontend

isn't that just the enzymexla-interpreter?

@joelberkeley
Author

isn't that just the enzymexla-interpreter?

what's that?

@wsmoses
Member

wsmoses commented Dec 23, 2024

There’s a binary enzymexlamlir-opt that takes stablehlo (and other general MLIR files) as inputs, a list of optimizations (including differentiation) as args and prints out the transformed code to stdout or a file of choice.

Or, @mofeing, did you mean the interpreter (which works similarly but assumes the code is composed of constants, and essentially does the transformation and generates the final constant result)?

@mofeing
Collaborator

mofeing commented Dec 23, 2024

I mean that the interpreter does what @joelberkeley wants, but as a binary (it takes MLIR files and returns the result of the transformations, or performs actual interpretation and returns values).

What he wants is something similar to the interpreter, but turned into a library.

@joelberkeley
Author

joelberkeley commented Dec 23, 2024

I'm differentiating a C++ mlir::ModuleOp. The StableHLO is generated at runtime

@wsmoses
Member

wsmoses commented Dec 23, 2024

In any case can you post your full code and error message and we can try to help see if there’s something missing (likely registering one of the MLIR interfaces).

@joelberkeley
Author

joelberkeley commented Dec 23, 2024

I can paste it, though it's not even first draft yet, and is a combination of two different languages, so I'm not sure it will help. I'm just trying to differentiate a tensor<f64> -> tensor<f64> for now

    computation <- compile xlaBuilder f
    stablehlo <- hloModuleProtoToStableHLO !(proto computation)
    reg <- mkDialectRegistry
    insertEnzymeDialect reg
    StableHLO.Dialect.Register.registerAllDialects reg
    ctx <- getContext stablehlo
    appendDialectRegistry ctx reg
    mgr <- mkPassManager ctx
    addPass mgr !createDifferentiatePass
    enzymeOp <- emitEnzymeADOp stablehlo reg
    _ <- run mgr enzymeOp
    hloProto <- convertStablehloToHlo stablehlo
    computation <- mkXlaComputation hloProto

and

    mlir::ModuleOp* emitEnzymeADOp(mlir::ModuleOp& module_op, mlir::DialectRegistry& registry) {
        mlir::enzyme::registerCoreDialectAutodiffInterfaces(registry);

        auto ctx = module_op.getContext();
        auto state = mlir::OperationState(mlir::UnknownLoc::get(ctx), "enzyme.autodiff");

        auto scalarf64 = mlir::RankedTensorType::get({}, mlir::FloatType::getF64(ctx));
        state.addTypes({scalarf64});

        auto operands = module_op.getOperation()->getOperands();  // complete guess
        state.addOperands(mlir::ValueRange(operands));

        auto operation = module_op.getOperation();  // complete guess
        state.addAttribute("fn", operation->getAttr("sym_name"));
        auto activity = mlir::enzyme::ActivityAttr::get(ctx, mlir::enzyme::Activity::enzyme_active);
        state.addAttribute("activity", {activity});
        auto ret_activity = mlir::enzyme::ActivityAttr::get(
            ctx, mlir::enzyme::Activity::enzyme_activenoneed
        );
        state.addAttribute("ret_activity", {ret_activity});

        auto res = mlir::Operation::create(state);

        return new mlir::ModuleOp(res);
    }

Error

LLVM ERROR: can't create Attribute 'mlir::enzyme::ActivityAttr' because storage uniquer isn't initialized: the dialect was likely not loaded, or the attribute wasn't added with addAttributes<...>() in the Dialect::initialize() method.

@wsmoses
Member

wsmoses commented Dec 24, 2024

hm yeah I think you're not initializing the Enzyme dialect from the looks of the error message.

This function here is extremely overkill but it should definitely add it:

void prepareRegistry(mlir::DialectRegistry &registry) {
.

Also considering you're explicitly adding the op itself (and it's not being generated by a different pass or parser) you may need to explicitly load the dialect in the context

@joelberkeley
Author

yeah I saw that function, but it's not public. I did try to reproduce it but its contents aren't public either

@wsmoses
Member

wsmoses commented Dec 24, 2024

Feel free to make a PR to make anything public that you need

@joelberkeley
Author

ok thanks. I'm going to head off now (it's late here). I'll come back to this in a few days. Seasons greetings!

@joelberkeley
Author

joelberkeley commented Dec 25, 2024

I did quite a bit more digging today, but didn't progress. Here are some notes. I added prepareRegistry. I also added calls to registerenzymePasses and regsiterenzymeXLAPasses. Same error persists.

I noticed the call to addAttributes is commented out for EnzymeXLADialect

void EnzymeXLADialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "src/enzyme_ad/jax/Dialect/EnzymeXLAOps.cpp.inc"
      >();
  //  addAttributes<
  // #define GET_ATTRDEF_LIST
  // #include "src/enzyme_ad/jax/Dialect/EnzymeXLAAttributes.cpp.inc"
  //      >();
  //  addTypes<
  // #define GET_TYPEDEF_LIST
  // #include "src/enzyme_ad/jax/Dialect/EnzymeXLAOpsTypes.cpp.inc"
  //      >();
}

Might that be the cause? I noticed mlir::enzyme::ActivityAttr is in the Enzyme repo not Enzyme-JAX. I uncommented those lines but they refer to files that no longer exist.

I'm still keen to make a StableHLO autodiff library that's decoupled from JAX, but I might not have time.

@wsmoses
Member

wsmoses commented Dec 25, 2024

No, that’s unrelated. We recently added an XLA dialect for optimizing kernel calls but haven’t yet added custom attributes; if we do, we’ll uncomment that.

The autodiff code here is independent of JAX (it lives in this repo to make building the JAX plugin easier). Perhaps it should be renamed enzymexla; there’s also an ongoing discussion about moving it into StableHLO proper.

This is all setup for MLIR (which is unfortunately not the clearest).

My recommendation: take the enzymexlamlir-opt.cpp binary, copy it into a library file, and call the pass manager on your op instead of parsing a new one from a file. That way you’ll have something with all the setup done properly (and, imo, it’s easier to remove excess registration once it’s running).

@wsmoses
Member

wsmoses commented Dec 25, 2024

To make things more concrete: take this

return mlir::asMainReturnCode(mlir::MlirOptMain(
line out and run the pass on your module, first importing the dialect registry into the module’s context.

Alternatively if you can share a repo with your whole setup we can try to take a look and fiddle with MLIR’s setup to make sure things are registered

@joelberkeley
Author

joelberkeley commented Dec 25, 2024

ok, I've copy-pasted that function, and put its contents before my code. I still get the same error. I will compose all my stuff into a function that's as self-contained as possible and paste it here

@joelberkeley
Author

joelberkeley commented Dec 27, 2024

I've edited this code to be more self-contained, it now produces an executable binary instead of a library. I might get round to making it into a git repo

#include "stablehlo/dialect/Register.h"
#include "xla/hlo/builder/xla_builder.h"
#include "xla/hlo/translate/stablehlo.h"
#include "xla/hlo/builder/lib/math.h"
#include "xla/mlir_hlo/mhlo/IR/register.h"

#include "Enzyme/MLIR/Dialect/Dialect.h"
#include "Enzyme/MLIR/Dialect/Ops.h"
#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h"
#include "Enzyme/MLIR/Passes/Passes.h"

#include "src/enzyme_ad/jax/Dialect/Dialect.h"
#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h"
#include "src/enzyme_ad/jax/Passes/Passes.h"
#include "src/enzyme_ad/jax/TransformOps/TransformOps.h"
#include "src/enzyme_ad/jax/RegistryUtils.h"

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Target/LLVM/NVVM/Target.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"

#include "llvm/Support/TargetSelect.h"

#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/tests/CheckOps.h"

class MemRefInsider
    : public mlir::MemRefElementTypeInterface::FallbackModel<MemRefInsider> {};

template <typename T>
struct PtrElementModel
    : public mlir::LLVM::PointerElementTypeInterface::ExternalModel<
          PtrElementModel<T>, T> {};

int main() {
    // create the stablehlo computation
    xla::XlaBuilder builder("root");
    auto xlaScalarf64 = xla::ShapeUtil::MakeScalarShape(xla::F64);  // F64 == 12
    auto arg = xla::Parameter(&builder, 0, xlaScalarf64, "arg");
    auto proto = builder.Build(xla::Square(arg))->proto();

    mlir::MLIRContext ctx;
    mlir::DialectRegistry registry_;
    ctx.appendDialectRegistry(registry_);
    mlir::mhlo::registerAllMhloDialects(registry_);
    mlir::stablehlo::registerAllDialects(registry_);

    auto module_op_ = xla::ConvertHloToStablehlo(ctx, &proto).value().release();

    // stuff copied from enzyme mlir main function
    llvm::InitializeNativeTarget();
    llvm::InitializeNativeTargetAsmPrinter();

    registry_.insert<mlir::stablehlo::check::CheckDialect>();
    prepareRegistry(registry_);

    mlir::registerenzymePasses();
    regsiterenzymeXLAPasses();

    mlir::registerCSEPass();
    mlir::registerConvertAffineToStandardPass();
    mlir::registerSCCPPass();
    mlir::registerInlinerPass();
    mlir::registerCanonicalizerPass();
    mlir::registerSymbolDCEPass();
    mlir::registerLoopInvariantCodeMotionPass();
    mlir::registerConvertSCFToOpenMPPass();
    mlir::affine::registerAffinePasses();
    mlir::registerReconcileUnrealizedCasts();

    registry_.addExtension(+[](mlir::MLIRContext *ctx, mlir::LLVM::LLVMDialect *dialect) {
        mlir::LLVM::LLVMFunctionType::attachInterface<MemRefInsider>(*ctx);
        mlir::LLVM::LLVMArrayType::attachInterface<MemRefInsider>(*ctx);
        mlir::LLVM::LLVMPointerType::attachInterface<MemRefInsider>(*ctx);
        mlir::LLVM::LLVMStructType::attachInterface<MemRefInsider>(*ctx);
        mlir::MemRefType::attachInterface<PtrElementModel<mlir::MemRefType>>(*ctx);
        mlir::LLVM::LLVMStructType::attachInterface<
            PtrElementModel<mlir::LLVM::LLVMStructType>>(*ctx);
        mlir::LLVM::LLVMPointerType::attachInterface<
            PtrElementModel<mlir::LLVM::LLVMPointerType>>(*ctx);
        mlir::LLVM::LLVMArrayType::attachInterface<PtrElementModel<mlir::LLVM::LLVMArrayType>>(*ctx);
    });

    mlir::transform::registerInterpreterPass();
    mlir::enzyme::registerGenerateApplyPatternsPass();
    mlir::enzyme::registerRemoveTransformPass();

    // attempt to create an `enzyme.autodiff` op
    auto state = mlir::OperationState(mlir::UnknownLoc::get(&ctx), "enzyme.autodiff");

    auto scalarf64 = mlir::RankedTensorType::get({}, mlir::FloatType::getF64(&ctx));
    state.addTypes({scalarf64});

    auto operands = module_op_.getOperation()->getOperands();  // complete guess
    state.addOperands(mlir::ValueRange(operands));

    auto operation = module_op_.getOperation();  // complete guess
    state.addAttribute("fn", operation->getAttr("sym_name"));
    auto activity = mlir::enzyme::ActivityAttr::get(&ctx, mlir::enzyme::Activity::enzyme_active);
    state.addAttribute("activity", {activity});
    auto ret_activity = mlir::enzyme::ActivityAttr::get(
        &ctx, mlir::enzyme::Activity::enzyme_activenoneed
    );
    state.addAttribute("ret_activity", {ret_activity});

    auto res = mlir::Operation::create(state);

    return 0;
}

Here is the Bazel target for the above:

cc_binary(
    name = "example",
    linkstatic = True,
    srcs = glob(["*.cpp"]),
    deps = [
        "@xla//xla/hlo/builder:xla_builder",
        "@xla//xla/hlo/translate:stablehlo",
        "@xla//xla/hlo/builder/lib:math",
        "@xla//xla/mlir_hlo:hlo_dialect_registration",
        "@enzyme-jax//:everything",
    ],
    visibility = ["//visibility:public"],
)

where I've added this target to Enzyme-JAX

cc_library(
    name = "everything",
    srcs = [
        "//src/enzyme_ad/jax:TransformOps",
        "//src/enzyme_ad/jax:XLADerivatives",
        "//src/enzyme_ad/jax:RegistryUtils.cpp",
    ],
    hdrs = [
        "//src/enzyme_ad/jax:TransformOps",
        "//src/enzyme_ad/jax:XLADerivatives",
        "//src/enzyme_ad/jax:RegistryUtils.h",
    ],
    visibility = ["//visibility:public"],
    deps = [
        "@enzyme//:EnzymeMLIR",
        "@llvm-project//mlir:AffineDialect",
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:AsyncDialect",
        "@llvm-project//mlir:ComplexDialect",
        "@llvm-project//mlir:ControlFlowDialect",
        "@llvm-project//mlir:ConversionPasses",
        "@llvm-project//mlir:DLTIDialect",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:GPUDialect",
        "@llvm-project//mlir:LinalgDialect",
        "@llvm-project//mlir:LLVMDialect",
        "@llvm-project//mlir:MathDialect",
        "@llvm-project//mlir:MemRefDialect",
        "@llvm-project//mlir:MlirOptLib",
        "@llvm-project//mlir:NVVMDialect",
        "@llvm-project//mlir:NVGPUDialect",
        "@llvm-project//mlir:OpenMPDialect",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:TransformDialect",
        "@llvm-project//mlir:Transforms",
        "//src/enzyme_ad/jax:TransformOps",
        "//src/enzyme_ad/jax:XLADerivatives",
        "@stablehlo//:chlo_ops",
        "@stablehlo//stablehlo/tests:check_ops",
        "@llvm-project//mlir:ArithToLLVM",
        "@llvm-project//mlir:BuiltinToLLVMIRTranslation",
        "@llvm-project//mlir:ComplexToLLVM",
        "@llvm-project//mlir:ControlFlowToLLVM",
        "@llvm-project//mlir:GPUToLLVMIRTranslation",
        "@llvm-project//mlir:LLVMToLLVMIRTranslation",
        "@llvm-project//mlir:NVVMToLLVMIRTranslation",

        "@llvm-project//llvm:X86AsmParser",
        "@llvm-project//llvm:X86CodeGen",
    ],
)

@joelberkeley
Author

joelberkeley commented Dec 27, 2024

I'm going to count my chickens and say you can test this using this change, which is basically the above but inserted into this repo

https://github.com/EnzymeAD/Enzyme-JAX/compare/main...joelberkeley:Enzyme-JAX:example?expand=1

with

bazel build //:example
./bazel-bin/example

@joelberkeley
Author

joelberkeley commented Dec 27, 2024

ok that has linker errors. I may come back to it. The code itself "works", in that it fails at runtime as described.

@joelberkeley
Author

I've made progress: I added ctx->loadDialect<mlir::enzyme::EnzymeDialect>(), which, curiously, I can't find any mention of in either the Enzyme or Enzyme-JAX repos. I got it from the MLIR tutorials.
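Once the pass runs, a cheap way to sanity-check the differentiated module for the scalar tensor<f64> -> tensor<f64> case is to compare its output against a central finite difference of the same function. The helpers below are a generic numerical check, not part of Enzyme; they assume the AD gradient being checked was produced in reverse mode, where the result is the seed times f'(x), so the seed is passed in explicitly.

```cpp
#include <cmath>
#include <functional>

// Central finite difference: (f(x + h) - f(x - h)) / (2h), error O(h^2).
double centralDifference(const std::function<double(double)>& f, double x,
                         double h = 1e-6) {
  return (f(x + h) - f(x - h)) / (2.0 * h);
}

// True if an AD-computed gradient (seeded with `seed`) matches
// seed * f'(x) estimated numerically, within a mixed abs/rel tolerance.
bool gradientMatches(const std::function<double(double)>& f, double x,
                     double adGradient, double seed = 1.0,
                     double tol = 1e-5) {
  const double expected = seed * centralDifference(f, x);
  return std::fabs(adGradient - expected) <= tol * (1.0 + std::fabs(expected));
}
```

For f(x) = x * x at x = 3 with seed 1.0, the value to expect from the derivative is 6.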
