feat: add support for the remaining wrapper types #369

Merged
merged 15 commits into from
Dec 29, 2024


avik-pal
Collaborator

@avik-pal avik-pal commented Dec 12, 2024

We need this before the linear algebra work, because most of those functions end up using these wrappers as return types.

fixes #345 #242

Added Wrapper Types

  • Symmetric
  • UnitLowerTriangular
  • LowerTriangular
  • UnitUpperTriangular
  • UpperTriangular
  • Tridiagonal

Updated Linear Algebra Support

  • `diagm` now supports the key-value pair form (`diagm(k => v, ...)`) and is implemented with a scatter
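For illustration, here is a plain-Julia model of building `diagm` with a scatter: every `k => v` pair contributes the coordinates `(i, i + k)` for `k >= 0` or `(i - k, i)` for `k < 0`, and the values are scattered into a zero matrix with an additive update. This is a sketch on ordinary `Array`s under that assumption, not Reactant's traced implementation, and `diagm_by_scatter` is a hypothetical name.

```julia
using LinearAlgebra

# Hypothetical helper (not Reactant's API): build diagm(k1 => v1, k2 => v2, ...)
# by computing (row, col) coordinates for each pair and scattering the values
# into a zero-initialized matrix.
function diagm_by_scatter(pairs::Pair{<:Integer,<:AbstractVector}...)
    # Square result sized to fit the longest diagonal, as LinearAlgebra.diagm does
    n = maximum(length(v) + abs(k) for (k, v) in pairs)
    T = promote_type(map(p -> eltype(p.second), pairs)...)
    out = zeros(T, n, n)
    for (k, v) in pairs
        for (i, val) in enumerate(v)
            # Superdiagonal k >= 0 starts at (1, 1 + k); subdiagonal k < 0 at (1 - k, 1)
            row, col = k >= 0 ? (i, i + k) : (i - k, i)
            out[row, col] += val   # scatter with an additive combiner
        end
    end
    return out
end
```

The additive update is what makes repeated diagonals such as `diagm(0 => a, 0 => b)` behave like `LinearAlgebra.diagm`, which documents that values for repeated diagonal indices are added.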

New Additions to Ops

  1. scatter_setindex
  2. gather_getindex

These are not called scatter / gather because they cover a significantly more restricted subset of the corresponding StableHLO ops, but that subset is a very common pattern.
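To make the restricted subset concrete, here is a plain-Julia model of the semantics, under the assumption that both ops take a list of full N-dimensional coordinates (one `CartesianIndex` per element) rather than StableHLO's general window and dimension-number configuration. `gather_getindex_ref` and `scatter_setindex_ref` are hypothetical reference versions on ordinary `Array`s, not the actual `Ops` signatures.

```julia
# Hypothetical reference semantics on plain Arrays (not Reactant's Ops API).
# gather: read one element per coordinate; the result is a flat vector.
function gather_getindex_ref(src::AbstractArray{T,N},
                             idxs::AbstractVector{CartesianIndex{N}}) where {T,N}
    return T[src[I] for I in idxs]
end

# scatter: write updates[j] to position idxs[j]; returns a new array, mirroring
# how a StableHLO scatter produces a new tensor instead of mutating its input.
function scatter_setindex_ref(dest::AbstractArray{T,N},
                              idxs::AbstractVector{CartesianIndex{N}},
                              updates::AbstractVector) where {T,N}
    length(idxs) == length(updates) ||
        throw(DimensionMismatch("one update per index is required"))
    out = copy(dest)
    for (I, u) in zip(idxs, updates)
        out[I] = u   # with duplicate indices, this reference lets the last write win
    end
    return out
end
```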

Fixes / Refactor

  1. Non-contiguous indexing for both setindex and getindex
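A non-contiguous index such as `x[[1, 3], [2, 4]]` cannot be expressed as a single slice, but it can be lowered to a gather by enumerating the Cartesian product of the per-dimension indices. A minimal plain-Julia sketch of that lowering; the helper names are hypothetical and this is not Reactant's code:

```julia
# Hypothetical lowering: turn per-dimension index vectors into the explicit list
# of coordinates a gather would read, plus the shape to reshape the result into.
function indices_to_gather_coords(inds::Vararg{AbstractVector{Int},N}) where {N}
    # Walk positions within each index vector in column-major order
    coords = [CartesianIndex(ntuple(d -> inds[d][J[d]], N))
              for J in CartesianIndices(map(length, inds))]
    return vec(coords), map(length, inds)
end

function noncontiguous_getindex(x::AbstractArray{T,N},
                                inds::Vararg{AbstractVector{Int},N}) where {T,N}
    coords, shape = indices_to_gather_coords(inds...)
    flat = T[x[I] for I in coords]   # the "gather" step
    return reshape(flat, shape)      # restore the shape of the indexed result
end
```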

@avik-pal avik-pal force-pushed the ap/more_wrappers branch 3 times, most recently from b2c7e54 to 995d957 on December 12, 2024 15:46
@avik-pal avik-pal linked an issue Dec 12, 2024 that may be closed by this pull request
@mofeing
Collaborator

mofeing commented Dec 12, 2024

What do you think about adding a section to the developer reference docs? That way wrapper array authors can add support for Reactant, and we also remember how it should be done 😅

I actually want to add support for Reactant to my DeltaArrays package.

@avik-pal
Collaborator Author

> what do you think about adding a section in the developer reference docs? so that wrapper array authors can add support for Reactant and we also remember how it should be done 😅

> i actually want to add support for Reactant in my DeltaArrays package

That is a good idea; I will add that. One caveat of Base Julia not having an "AbstractWrappedArray" type is that these new wrapper types won't go through our AnyTracedRArray dispatches. This problem is not unique to us: CUDA has been struggling with it for years.

xref: JuliaLang/julia#51910

Comment on lines +182 to +131
if typeof(contiguous) <: Bool && !contiguous
    non_contiguous_setindex = true
    break
end
Collaborator Author


How can we introduce a runtime error in the generated MLIR?

Collaborator

@mofeing mofeing Dec 13, 2024


I think you can't? At least not in the MLIR itself, but the verifier will error when verifying the ops.

Collaborator Author


I was thinking of some sort of custom_call, like the one JAX uses here: https://github.com/openxla/stablehlo/blob/main/docs/tutorials/jax-export.ipynb

stablehlo.custom_call @shape_assertion(%1, %0) {api_version = 2 : i32, error_message = "Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable 'a'. Using the following polymorphic shapes specifications: args[0].shape = (a, 3, 224, 224). Obtained dimension variables: 'a' = {0} from specification 'a' for dimension args[0].shape[0] (= {0}), . Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details.", has_side_effect = true} : (tensor<i1>, tensor<i32>) -> ()

For TracedRArray indices we should probably always emit a (dynamic_)gather, and we might be able to write an optimization later that transforms it into a slice when the indices are contiguous.
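When index values are known at trace time, deciding whether a gather could become a slice is a cheap check. A hypothetical sketch over plain Julia index types (`is_contiguous_index` and `can_use_slice` are illustrative names, not Reactant functions); traced index values cannot be inspected while tracing, which is consistent with the reviewed snippet only acting when `contiguous` is a `Bool`:

```julia
# Hypothetical helper: can this index be lowered to a plain slice?
is_contiguous_index(::Integer) = true
is_contiguous_index(::Colon) = true
is_contiguous_index(::AbstractUnitRange{<:Integer}) = true
# Other integer vectors (including step ranges) are contiguous only if each
# entry is exactly one more than the previous one
function is_contiguous_index(v::AbstractVector{<:Integer})
    return all(i -> v[i + 1] == v[i] + 1, firstindex(v):(lastindex(v) - 1))
end
# Logical masks are left to the generic (gather) path in this sketch
is_contiguous_index(::AbstractVector{Bool}) = false
# Anything else (e.g. CartesianIndex arrays) is treated as non-contiguous
is_contiguous_index(::Any) = false

# An indexing expression maps to one slice only if every dimension's index is contiguous
can_use_slice(inds...) = all(is_contiguous_index, inds)
```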

@avik-pal
Collaborator Author

Also, we should revisit this at some point: with the current approach, the generic dispatches convert a structured array into a dense array, so we definitely use worse algorithms than the structure warrants.

@avik-pal avik-pal marked this pull request as ready for review December 13, 2024 05:31
@avik-pal
Collaborator Author

Need to add tests.

@mofeing
Collaborator

mofeing commented Dec 13, 2024

> Also we should revisit this at some point, but the current approach means that the generic dispatches convert a structured array into a dense array so we definitely use worse algorithms than are warranted

Yeah, can you please leave a note about this in the docs? IMHO this is OK to leave as the default, and we can specialize for performance once we know how to handle a given combination of array type and op.

> That is a good idea, I will add that. 1 caveat of not having "AbstractWrappedArray" in Base Julia is that these new wrapper types won't go through our AnyTracedRArray dispatches. Technically, it is not a unique problem for us, and CUDA has been struggling with this for years now.

The way I've solved things like this in Tenet is with Holy traits and a default method that checks the trait (e.g. for LinearAlgebra.diag), but that is going to be hard here because we can't override the default method... Maybe now that @wsmoses is re-adding the method overlay mechanism, we can overlay the default method and do something like this?

iswrappedtraced(::TracedRArray) = true
iswrappedtraced(::Array) = false
iswrappedtraced(x::Diagonal) = iswrappedtraced(x.diag)  # Diagonal stores its vector in the `diag` field
iswrappedtraced(::Type{<:TracedRArray}) = true
iswrappedtraced(::Type{<:Array}) = false
# Diagonal has two type parameters: the eltype T and the storage vector type V
iswrappedtraced(::Type{<:Diagonal{T,V}}) where {T,V} = iswrappedtraced(V)

@reactant_override function LinearAlgebra.diag(x::AbstractArray{T,2}, k::Integer=0) where {T}
    if iswrappedtraced(x)
        y = materialize_traced_array(x)
        return diag(y, k)
    else
        # invoke diag(x) on NativeInterpreter
    end
end

function LinearAlgebra.diag(x::TracedRArray{T,2}, k::Integer=0) where {T}
    # here the code we currently have
end

The wrapped-array library developer who wants to add support for Reactant would then just need to correctly implement iswrappedtraced (suggestions welcome for a better name) and materialize_traced_array. Furthermore, the developer could implement better methods when they know that a given combination of method and array type doesn't need to be materialized (e.g. a Diagonal * Array could be implemented using dot_general with batching dims and no contracting dims).
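A self-contained sketch of the proposed scheme in plain Julia. Because `TracedRArray` and the overlay mechanism are not available outside Reactant, it assumes a hypothetical stand-in type `FakeTraced`, a Holy-trait function `tracestyle` playing the role of `iswrappedtraced`, and a hypothetical `materialize_fake` in place of `materialize_traced_array`. It also shows the kind of specialized method mentioned above: `Diagonal(d) * A` needs no dense materialization, since it only scales row `i` of `A` by `d[i]` (a batched product with no contraction).

```julia
using LinearAlgebra

# Hypothetical stand-in for a traced array: wraps a plain Array.
struct FakeTraced{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(a::FakeTraced) = size(a.data)
Base.getindex(a::FakeTraced{T,N}, I::Vararg{Int,N}) where {T,N} = a.data[I...]

# Holy traits: a type-level answer to "does this (possibly wrapped) array hold traced data?"
abstract type TraceStyle end
struct IsTraced <: TraceStyle end
struct NotTraced <: TraceStyle end

tracestyle(::Type{<:FakeTraced}) = IsTraced()
tracestyle(::Type{<:AbstractArray}) = NotTraced()   # conservative default
tracestyle(::Type{<:Diagonal{T,V}}) where {T,V} = tracestyle(V)
tracestyle(::Type{<:UpperTriangular{T,S}}) where {T,S} = tracestyle(S)

# Materialize any wrapper into a dense "traced" array via an explicit element-wise copy.
materialize_fake(x::AbstractMatrix) =
    FakeTraced([x[i, j] for i in axes(x, 1), j in axes(x, 2)])

# The generic entry point checks the trait once and picks a path.
mydiag(x::AbstractMatrix, k::Integer=0) = mydiag(tracestyle(typeof(x)), x, k)
mydiag(::NotTraced, x, k) = diag(x, k)                                 # native path
mydiag(::IsTraced, x::FakeTraced, k) = FakeTraced(diag(x.data, k))     # "traced" kernel
mydiag(::IsTraced, x, k) = mydiag(IsTraced(), materialize_fake(x), k)  # unwrap, then retry

# Specialized method that skips materialization: (Diagonal(d) * A)[i, j] == d[i] * A[i, j]
diag_times(D::Diagonal{T,FakeTraced{T,1}}, A::FakeTraced{T,2}) where {T} =
    FakeTraced(D.diag.data .* A.data)
```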

The only problem I see is whether we can invoke methods in the NativeInterpreter from within a custom interpreter. @wsmoses @vchuravy @gbaraldi is that doable?


Getting out of scope: I think something Julia could have, and that would fix a loooot of problems like this, is extendable unions. The problem with abstract types is that you cannot take an external type and tell it "hey, you're now a 'son' of this abstract type because you fulfill this new interface I've defined". Even with multiple inheritance there would still be problems. Abstract types block composability.

An extendable union would be similar to abc.ABCMeta in Python, where you can register new subclasses dynamically: https://docs.python.org/3.13/library/abc.html#abc.ABCMeta
Check out this example, where the user creates a new class MyABC and makes the builtin tuple a subclass of it!

from abc import ABC

class MyABC(ABC):
    pass

MyABC.register(tuple)

assert issubclass(tuple, MyABC)
assert isinstance((), MyABC)

The problem with this approach is that, if not handled with care, I think it could lead to a lot of invalidation (or maybe not; I don't know much about the internals that deal with this).

@avik-pal
Collaborator Author

avik-pal commented Dec 13, 2024

We can push an is_wrapped_array to ArrayInterface (easier for other packages to use than depending on Reactant). Then is_wrapped_tracked_array (EDIT: is_wrapped_traced_array) is simply ancestor(x) isa TracedRArray.
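A minimal sketch of the `ancestor` idea, assuming it is defined by repeatedly calling `Base.parent` until it returns its argument unchanged (Base's fallback `parent(A::AbstractArray) = A` makes an unwrapped array its own parent). `my_ancestor`, `is_wrapped_traced_sketch`, and the `Traced` stand-in are hypothetical names; whatever ArrayInterface ends up providing may differ.

```julia
using LinearAlgebra

# Hypothetical stand-in for a traced array type.
struct Traced{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(a::Traced) = size(a.data)
Base.getindex(a::Traced{T,N}, I::Vararg{Int,N}) where {T,N} = a.data[I...]

# Peel wrappers via Base.parent until it returns its argument unchanged.
# Unwrapped arrays, and wrappers that never define `parent`, end the recursion,
# so a wrapper type has to implement `parent` to be seen through.
function my_ancestor(x::AbstractArray)
    p = parent(x)
    return p === x ? x : my_ancestor(p)
end

is_wrapped_traced_sketch(x::AbstractArray) = my_ancestor(x) isa Traced
```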

@mofeing
Collaborator

mofeing commented Dec 13, 2024

> We can push an is_wrapped_array to ArrayInterface (easier for other packages to use rather than depending on Reactant).

Super, I like it.

> Then is_wrapped_tracked_array is simply ancestor(x) isa TracedRArray

..._traced_...?

The only remaining problem is then how to call the original method in the native interpreter.

@avik-pal
Collaborator Author

avik-pal commented Dec 15, 2024

> the only remaining problem is then how to call the original method in the native interpreter

Looking at the current implementation, we could possibly do:

struct UseNativeInterpreter end # Open to other names as well
--- a/src/Interpreter.jl
+++ b/src/Interpreter.jl
@@ -43,7 +43,8 @@ function set_reactant_abi(
     # the original call
     if f === Reactant.call_with_reactant
         arginfo2 = ArgInfo(fargs isa Nothing ? nothing : fargs[2:end], argtypes[2:end])
-        return abstract_call(interp, arginfo2::ArgInfo, si, sv, max_methods)
+        ret = abstract_call(interp, arginfo2::ArgInfo, si, sv, max_methods)
+        ret isa UseNativeInterpreter || return ret
     end
 
     return Base.@invoke abstract_call_known(
@reactant_override function LinearAlgebra.diag(x::AbstractArray{T,2}, k::Integer=0) where {T}
    if iswrappedtraced(x)
        y = materialize_traced_array(x)
        return diag(y, k)
    end
    return UseNativeInterpreter()
end
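The control flow of the `UseNativeInterpreter` idea can be illustrated without any compiler machinery: an override either returns a result or a sentinel, and the caller falls back to the ordinary method when it sees the sentinel. This sketch uses plain multiple dispatch in place of Reactant's interpreter, and `UseNative`, `Tr`, `override`, and `call_with_fallback` are hypothetical names.

```julia
using LinearAlgebra

# Sentinel meaning "I don't handle this call; run the ordinary method instead".
struct UseNative end

# Hypothetical stand-in for a traced array.
struct Tr{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(a::Tr) = size(a.data)
Base.getindex(a::Tr{T,N}, I::Vararg{Int,N}) where {T,N} = a.data[I...]

# Default: every override declines.
override(f, args...) = UseNative()
# An override that only claims `diag` calls on the stand-in type.
override(::typeof(diag), x::Tr{<:Any,2}, k::Integer=0) = Tr(diag(x.data, k))

# Same shape as `ret isa UseNativeInterpreter || return ret` in the diff above.
function call_with_fallback(f, args...)
    ret = override(f, args...)
    ret isa UseNative || return ret
    return f(args...)   # fall back to the ordinary (native) method
end
```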

@mofeing
Collaborator

mofeing commented Dec 15, 2024

Uuu, I like it. I'm not very familiar with set_reactant_abi, but is that what @wsmoses added to solve dynamic dispatch during abstract interpretation?

@wsmoses
Member

wsmoses commented Dec 15, 2024

If you define the Reactant override function as noinline, then by default all methods within the override function will be executed in the native interpreter.

@wsmoses
Member

wsmoses commented Dec 15, 2024

To re-enter the abstract interpreter override, call call_with_reactant(f, args...). The same behavior applies to anything in Ops, TracedUtils, or otherwise in this whitelist:

function should_rewrite_ft(@nospecialize(ft))

@mofeing
Collaborator

mofeing commented Dec 15, 2024

> to re-enter the abstract interp override, do call_with_reactant(f, args...). This same behavior is true of anything in Ops, TracedUtils, or is otherwise in this whitelist:

Aha, now I get it.

Review thread on src/Ops.jl (outdated, resolved)
@avik-pal avik-pal force-pushed the ap/more_wrappers branch 4 times, most recently from f39adeb to 9546d1f on December 24, 2024 06:53
Member

@wsmoses wsmoses left a comment


@avik-pal is there anything more pending here or good to merge?

@avik-pal
Collaborator Author

> @avik-pal is there anything more pending here or good to merge?

I will add some tests, and then we should be good to merge.

@avik-pal
Collaborator Author

avik-pal commented Dec 28, 2024

Let's merge #426 first. We need that to update the dispatches correctly.

@avik-pal avik-pal force-pushed the ap/more_wrappers branch 2 times, most recently from 8d2794e to 2bdcfbd on December 29, 2024 03:18
@wsmoses
Member

wsmoses commented Dec 29, 2024

@avik-pal after this (and perhaps the SpecialFunctions PR), let's cut a release.

Review thread on src/Overlay.jl (outdated, resolved)
@avik-pal
Collaborator Author

Tests that should be passing are passing.

https://github.com/EnzymeAD/Reactant.jl/actions/runs/12537476469/job/34961608153?pr=369#step:9:823 is probably unrelated to this PR.

@wsmoses
Member

wsmoses commented Dec 29, 2024

Yeah, that one looks like a JAX error; let's merge.

@wsmoses wsmoses merged commit 8e4c095 into main Dec 29, 2024
25 of 39 checks passed
@wsmoses wsmoses deleted the ap/more_wrappers branch December 29, 2024 18:48
jumerckx added a commit to jumerckx/Reactant.jl that referenced this pull request Jan 1, 2025
commit 6556944
Author: Avik Pal <[email protected]>
Date:   Tue Dec 31 08:53:39 2024 -0500

    feat: support Base.stack (EnzymeAD#433)

    * refactor: use scatter for generating diagm

    * refactor: directly generate the region for simple_scatter_op

    * feat: generalize diagm

    * feat: support Base.stack

    * fix: incorrect rebase

    * test: stack tests

commit 9375f57
Author: Sergio Sánchez Ramírez <[email protected]>
Date:   Mon Dec 30 22:56:14 2024 +0100

    Modularize Bazel build (EnzymeAD#421)

    * organize platforms and toolchains

    * hardcode libcxxwrap_julia path

    * format code

    * remove outdated hardcoded symbolic links

    * add third party bazel wrapper to libcxxwrap_julia

    * readd platforms

    * some small fixes

    * first step on moving externals to modular organization

    * refactor libcxxwrap_julia on top of `cc_import`

    * use modular workspaces

    * add `libcxxwrap_julia` as dependency

    * hardcode julia dep

    * export `reactant_*` functions

    * downgrade libcxxwrap_julia to v0.13.3

    * fix major version when linking to libcxxwrap_julia

    * remove legacy export

    * move `API.cpp` to new `src/` folder to start modularizing code

    * export `register_julia_module` from libcxxwrap_julia

    * fix symbol visibility

    * clean code

    remove libcxxwrap and julia deps

    * format code

    * import hedron compile commands from Enzyme-JAX

    * move deps commits to `workspace.bzl`

commit 3244204
Author: Avik Pal <[email protected]>
Date:   Mon Dec 30 16:33:52 2024 -0500

    chore: bump jll (EnzymeAD#437)

commit 25abfe4
Author: Avik Pal <[email protected]>
Date:   Mon Dec 30 12:03:20 2024 -0500

    fix: try building with cudnn 9.4 (EnzymeAD#436)

commit 241fd14
Author: Avik Pal <[email protected]>
Date:   Mon Dec 30 04:54:50 2024 -0500

    feat: indexing using traced values (EnzymeAD#434)

    * feat: indexing using traced values

    * feat: implement repeat inner

    * feat: support scalar linear indexing + tests

    * fix: regression in cartesian index support

    * Update src/TracedRArray.jl

commit 7d2b898
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Date:   Sun Dec 29 20:28:23 2024 -0500

    Regenerate MLIR Bindings (EnzymeAD#435)

    Co-authored-by: mofeing <[email protected]>

commit 8e4c095
Author: Avik Pal <[email protected]>
Date:   Sun Dec 29 13:48:27 2024 -0500

    feat: add support for the remaining wrapper types (EnzymeAD#369)

    * feat: add materialize_traced_array for all other wrappers

    * refactor: use scatter for generating diagm

    * refactor: directly generate the region for simple_scatter_op

    * feat: generalize diagm

    * feat: efficient non-contiguous setindex

    * fix: non-contiguous indexing is now supported

    * feat: implement set_mlir_data for the remaining types

    * refactor: use `Ops.gather_getindex` to implement diag

    * fix: noinline ops

    * fix: incorrect rebase

    * fix: dispatches

    * fix: diagm for repeated indices and initial tests

    * fix: higher dimensional indexing + tests

    * fix: matrix multiplication of wrapper types

    * fix: de-specialize 3 arg mul!

commit d4e7c76
Author: William Moses <[email protected]>
Date:   Sat Dec 28 23:05:51 2024 -0500

    CUDA kernels take 3 (EnzymeAD#427)

    * CUDA take 3

    * conditional run cuda

    * Update test/integration/cuda.jl

    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

    * bump enzymexla

    * fix

    * fix gpu reg

    * Update BUILD

    * Update BUILD

    * Update Project.toml

    * Update ReactantCUDAExt.jl

    * Apply suggestions from code review

    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

    * fix reactant method blocker

    * Update ReactantCUDAExt.jl

    * only do compile

    * use names in cache

    * Apply suggestions from code review

    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

    * cleanup further gc issues

    * Apply suggestions from code review

    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

    * fix

    ---------

    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

commit b0a58bd
Author: Avik Pal <[email protected]>
Date:   Sat Dec 28 21:53:29 2024 -0500

    fix: make eltype of Traced/Concrete Arrays to be respective RNumbers (EnzymeAD#426)

    * feat: overlay eltype conversion

    * fix: overload the main methods

    * fix: make eltype of Traced/Concrete Arrays to be respective RNumbers

    * fix: handle more cases

    * fix: tracing of wrapped types

    * fix: arrayinterface overload

    * fix: python call

commit f079a9d
Author: mofeing <[email protected]>
Date:   Sun Dec 29 00:15:51 2024 +0000

    Format code

commit eeaf86c
Author: glounes <[email protected]>
Date:   Sat Dec 28 04:24:38 2024 +0100

    `stablehlo.sort` Ops (EnzymeAD#374)

    * `stablehlo.sort` Ops

    * review

    * use `return_dialect`

    * feedback

    * fix test GPU

commit 925544f
Author: William Moses <[email protected]>
Date:   Tue Dec 24 15:34:52 2024 -0500

    Cuv2 (EnzymeAD#423)

    * Kernel-supporting jll

    * fix rulescc

    * adapt to hedron dep

    * init target

    * fixup

    * additional fixups

    * fixup

    * fix

    * registry utils

    * callname

    * reg

    * fix

    * fix bld

    * cleanup

    * no pip

    * fix

    * force rules python to older version before bug

    * fixup jll

    * with proto

    * fix

    * fix

    * Update WORKSPACE

    * more deps for apple

    * bump

    * fix

    * workspace bump

    * workspace

    * Update Compiler.jl

    * Update ReactantCUDAExt.jl

    * Update ReactantCUDAExt.jl

    * Update ReactantCUDAExt.jl

    * Update ReactantCUDAExt.jl

    * Update ReactantCUDAExt.jl

    * Update Project.toml

    * Update ReactantCUDAExt.jl

    * Update Project.toml

    * Update Project.toml

    * fix

    * Update ReactantCUDAExt.jl

    * Update ReactantCUDAExt.jl

    * Update ReactantCUDAExt.jl

    * Update cuda.jl

    * Update cuda.jl

    * Update cuda.jl

    * Cuda kernel v2

    * Update Project.toml

    * Update API.cpp

    * Apply suggestions from code review

    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

    ---------

    Co-authored-by: William Moses <[email protected]>
    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

commit 057e6b8
Author: Avik Pal <[email protected]>
Date:   Tue Dec 24 21:46:23 2024 +0530

    fix: handle traced array returns inside objects (EnzymeAD#417)

    * fix: handle traced array returns inside objects

    * test: add EnzymeAD#416 as a test

    * fix: propagate track_numbers correctly

    * fix: aliasing and add a test

    * test: use updated API for the tests

    * feat: cache new arrays

    * fix: traced_getfield

commit 0b6dafc
Author: Sergio Sánchez Ramírez <[email protected]>
Date:   Tue Dec 24 10:43:18 2024 +0100

    Bump Reactant_jll to v0.0.32

commit a02fd5b
Author: William Moses <[email protected]>
Date:   Mon Dec 23 21:58:40 2024 -0500

    Update WORKSPACE

commit 6e1710d
Author: William Moses <[email protected]>
Date:   Mon Dec 23 19:48:12 2024 -0500

    disable absint of absint (EnzymeAD#424)

    * disable absint of absint

    * no typeinf ext

    * Apply suggestions from code review

    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

    ---------

    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

commit 228732f
Author: William Moses <[email protected]>
Date:   Mon Dec 23 18:50:34 2024 -0500

    Fix error on global (EnzymeAD#422)

commit 695cc80
Author: William Moses <[email protected]>
Date:   Mon Dec 23 13:59:31 2024 -0500

    Update Project.toml

commit 38916f5
Author: Avik Pal <[email protected]>
Date:   Mon Dec 23 21:46:22 2024 +0530

    feat: add zero and fill! for ConcreteRArray (EnzymeAD#420)

    * feat: add zero and fill! for ConcreteRArray

    * test: add tests

commit 6571d54
Author: William S. Moses <[email protected]>
Date:   Sun Dec 22 23:49:03 2024 -0500

    bump enzymexla commit

commit 5b89b56
Author: William Moses <[email protected]>
Date:   Sun Dec 22 21:23:33 2024 -0500

    Fix ReactantPythonCallExt.jl (EnzymeAD#419)

    * Fix ReactantPythonCallExt.jl

    * Apply suggestions from code review

    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

    ---------

    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

commit 4cc000c
Author: William Moses <[email protected]>
Date:   Sun Dec 22 19:55:58 2024 -0500

    Improve reactant error messages (EnzymeAD#418)

    * Improve reactant error messages

    * More exported symbols

    * Update XLA.jl

    * Apply suggestions from code review

    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

    * Apply suggestions from code review

    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

    ---------

    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

commit 2759c3c
Author: jumerckx <[email protected]>
Date:   Mon Dec 23 01:35:17 2024 +0100

    Inference cache (EnzymeAD#405)

    * add inference cache

    * start from `typeinf_ircode`

    * julia 1.10

    * Apply formatting suggestions

    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

    * remove debug logging

    * vendor in type inference code for v1.10

    To avoid having to build a MethodInstance twice (performance hazard)

    ---------

    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
    Co-authored-by: Jules Merckx <[email protected]>

commit f9c43ad
Author: William S. Moses <[email protected]>
Date:   Sun Dec 22 19:33:25 2024 -0500

    Bump enzymexla
Development

Successfully merging this pull request may close these issues.

  • Handling Wrapped Arrays Correctly
  • getindex assumes contiguous indexing
3 participants