Skip to content

Commit

Permalink
fix attribute for array of bools (#279)
Browse files Browse the repository at this point in the history
* fix attribute for array of bools

* test

* remove uses of unsafe pointer calls
  • Loading branch information
Pangoraw authored Nov 15, 2024
1 parent f6dcae1 commit b92a7d8
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 68 deletions.
8 changes: 3 additions & 5 deletions src/mlir/IR/AffineMap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ Creates an affine map with results defined by the given list of affine expressio
The resulting map also has the requested number of input dimensions and symbols, regardless of them being used in the results.
"""
AffineMap(ndims, nsymbols, exprs::Vector{AffineExpr}; context::Context=context()) =
AffineMap(API.mlirAffineMapGet(context, ndims, nsymbols, length(exprs), pointer(exprs)))
AffineMap(API.mlirAffineMapGet(context, ndims, nsymbols, length(exprs), exprs))

"""
ConstantAffineMap(val; context=context())
Expand Down Expand Up @@ -94,9 +94,7 @@ The affine map is owned by the context.
function PermutationAffineMap(permutation; context::Context=context())
@assert Base.isperm(permutation) "$permutation must be a valid permutation"
zero_perm = permutation .- 1
return AffineMap(
API.mlirAffineMapPermutationGet(context, length(zero_perm), pointer(zero_perm))
)
return AffineMap(API.mlirAffineMapPermutationGet(context, length(zero_perm), zero_perm))
end

"""
Expand Down Expand Up @@ -192,7 +190,7 @@ Base.isperm(map::AffineMap) = API.mlirAffineMapIsPermutation(map)
Returns the affine map consisting of the `positions` subset.
"""
submap(map::AffineMap, pos::Vector{Int}) =
AffineMap(API.mlirAffineMapGetSubMap(map, length(pos), pointer(pos)))
AffineMap(API.mlirAffineMapGetSubMap(map, length(pos), pos))

"""
majorsubmap(affineMap, nresults)
Expand Down
66 changes: 26 additions & 40 deletions src/mlir/IR/Attribute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ isarray(attr::Attribute) = API.mlirAttributeIsAArray(attr)
Creates an array attribute containing the given list of elements in the given context.
"""
Attribute(attrs::Vector{Attribute}; context::Context=context()) =
Attribute(API.mlirArrayAttrGet(context, length(attrs), pointer(attrs)))
Attribute(API.mlirArrayAttrGet(context, length(attrs), attrs))

"""
isdict(attr)
Expand All @@ -97,7 +97,7 @@ Creates a dictionary attribute containing the given list of elements in the prov
"""
function Attribute(attrs::Dict; context::Context=context())
attrs = map(splat(NamedAttribute), attrs)
return Attribute(API.mlirDictionaryAttrGet(context, length(attrs), pointer(attrs)))
return Attribute(API.mlirDictionaryAttrGet(context, length(attrs), attrs))
end

"""
Expand Down Expand Up @@ -309,9 +309,7 @@ Each of the references in the list must not be nested.
"""
SymbolRefAttribute(
symbol::String, references::Vector{Attribute}; context::Context=context()
) = Attribute(
API.mlirSymbolRefAttrGet(context, symbol, length(references), pointer(references))
)
) = Attribute(API.mlirSymbolRefAttrGet(context, symbol, length(references), references))

"""
rootref(attr)
Expand Down Expand Up @@ -429,9 +427,7 @@ Creates a dense elements attribute with the given Shaped type and elements in th
"""
function DenseElementsAttribute(shaped_type::Type, elements::AbstractArray)
@assert isshaped(shaped_type) "type $(shaped_type) is not a shaped type"
return Attribute(
API.mlirDenseElementsAttrGet(shaped_type, length(elements), pointer(elements))
)
return Attribute(API.mlirDenseElementsAttrGet(shaped_type, length(elements), elements))
end

# TODO mlirDenseElementsAttrRawBufferGet
Expand Down Expand Up @@ -504,77 +500,67 @@ Creates a dense elements attribute with the given shaped type from elements of a
function DenseElementsAttribute(values::AbstractArray{Bool})
shaped_type = TensorType(size(values), Type(Bool))
return Attribute(
API.mlirDenseElementsAttrBoolGet(shaped_type, length(values), pointer(values))
API.mlirDenseElementsAttrBoolGet(
shaped_type, length(values), AbstractArray{Cint}(values)
),
)
end

function DenseElementsAttribute(values::AbstractArray{UInt8})
shaped_type = TensorType(size(values), Type(UInt8))
return Attribute(
API.mlirDenseElementsAttrUInt8Get(shaped_type, length(values), pointer(values))
)
return Attribute(API.mlirDenseElementsAttrUInt8Get(shaped_type, length(values), values))
end

function DenseElementsAttribute(values::AbstractArray{Int8})
shaped_type = TensorType(size(values), Type(Int8))
return Attribute(
API.mlirDenseElementsAttrInt8Get(shaped_type, length(values), pointer(values))
)
return Attribute(API.mlirDenseElementsAttrInt8Get(shaped_type, length(values), values))
end

function DenseElementsAttribute(values::AbstractArray{UInt16})
shaped_type = TensorType(size(values), Type(UInt16))
return Attribute(
API.mlirDenseElementsAttrUInt16Get(shaped_type, length(values), pointer(values))
API.mlirDenseElementsAttrUInt16Get(shaped_type, length(values), values)
)
end

function DenseElementsAttribute(values::AbstractArray{Int16})
shaped_type = TensorType(size(values), Type(Int16))
return Attribute(
API.mlirDenseElementsAttrInt16Get(shaped_type, length(values), pointer(values))
)
return Attribute(API.mlirDenseElementsAttrInt16Get(shaped_type, length(values), values))
end

function DenseElementsAttribute(values::AbstractArray{UInt32})
shaped_type = TensorType(size(values), Type(UInt32))
return Attribute(
API.mlirDenseElementsAttrUInt32Get(shaped_type, length(values), pointer(values))
API.mlirDenseElementsAttrUInt32Get(shaped_type, length(values), values)
)
end

function DenseElementsAttribute(values::AbstractArray{Int32})
shaped_type = TensorType(size(values), Type(Int32))
return Attribute(
API.mlirDenseElementsAttrInt32Get(shaped_type, length(values), pointer(values))
)
return Attribute(API.mlirDenseElementsAttrInt32Get(shaped_type, length(values), values))
end

function DenseElementsAttribute(values::AbstractArray{UInt64})
shaped_type = TensorType(size(values), Type(UInt64))
return Attribute(
API.mlirDenseElementsAttrUInt64Get(shaped_type, length(values), pointer(values))
API.mlirDenseElementsAttrUInt64Get(shaped_type, length(values), values)
)
end

function DenseElementsAttribute(values::AbstractArray{Int64})
shaped_type = TensorType(size(values), Type(Int64))
return Attribute(
API.mlirDenseElementsAttrInt64Get(shaped_type, length(values), pointer(values))
)
return Attribute(API.mlirDenseElementsAttrInt64Get(shaped_type, length(values), values))
end

function DenseElementsAttribute(values::AbstractArray{Float32})
shaped_type = TensorType(size(values), Type(Float32))
return Attribute(
API.mlirDenseElementsAttrFloatGet(shaped_type, length(values), pointer(values))
)
return Attribute(API.mlirDenseElementsAttrFloatGet(shaped_type, length(values), values))
end

function DenseElementsAttribute(values::AbstractArray{Float64})
shaped_type = TensorType(size(values), Type(Float64))
return Attribute(
API.mlirDenseElementsAttrDoubleGet(shaped_type, length(values), pointer(values))
API.mlirDenseElementsAttrDoubleGet(shaped_type, length(values), values)
)
end

Expand All @@ -583,7 +569,7 @@ end
function DenseElementsAttribute(values::AbstractArray{Float16})
shaped_type = TensorType(size(values), Type(Float16))
return Attribute(
API.mlirDenseElementsAttrFloat16Get(shaped_type, length(values), pointer(values))
API.mlirDenseElementsAttrFloat16Get(shaped_type, length(values), values)
)
end

Expand All @@ -606,7 +592,7 @@ function DenseElementsAttribute(values::AbstractArray{String})
# TODO may fail because `Type(String)` is not defined
shaped_type = TensorType(size(values), Type(String))
return Attribute(
API.mlirDenseElementsAttrStringGet(shaped_type, length(values), pointer(values))
API.mlirDenseElementsAttrStringGet(shaped_type, length(values), values)
)
end

Expand Down Expand Up @@ -677,25 +663,25 @@ function DenseArrayAttribute end

@llvmversioned min = v"16" DenseArrayAttribute(
values::AbstractArray{Bool}; context::Context=context()
) = Attribute(API.mlirDenseBoolArrayGet(context, length(values), pointer(values)))
) = Attribute(API.mlirDenseBoolArrayGet(context, length(values), values))
@llvmversioned min = v"16" DenseArrayAttribute(
values::AbstractArray{Int8}; context::Context=context()
) = Attribute(API.mlirDenseI8ArrayGet(context, length(values), pointer(values)))
) = Attribute(API.mlirDenseI8ArrayGet(context, length(values), values))
@llvmversioned min = v"16" DenseArrayAttribute(
values::AbstractArray{Int16}; context::Context=context()
) = Attribute(API.mlirDenseI16ArrayGet(context, length(values), pointer(values)))
) = Attribute(API.mlirDenseI16ArrayGet(context, length(values), values))
@llvmversioned min = v"16" DenseArrayAttribute(
values::AbstractArray{Int32}; context::Context=context()
) = Attribute(API.mlirDenseI32ArrayGet(context, length(values), pointer(values)))
) = Attribute(API.mlirDenseI32ArrayGet(context, length(values), values))
@llvmversioned min = v"16" DenseArrayAttribute(
values::AbstractArray{Int64}; context::Context=context()
) = Attribute(API.mlirDenseI64ArrayGet(context, length(values), pointer(values)))
) = Attribute(API.mlirDenseI64ArrayGet(context, length(values), values))
@llvmversioned min = v"16" DenseArrayAttribute(
values::AbstractArray{Float32}; context::Context=context()
) = Attribute(API.mlirDenseF32ArrayGet(context, length(values), pointer(values)))
) = Attribute(API.mlirDenseF32ArrayGet(context, length(values), values))
@llvmversioned min = v"16" DenseArrayAttribute(
values::AbstractArray{Float64}; context::Context=context()
) = Attribute(API.mlirDenseF64ArrayGet(context, length(values), pointer(values)))
) = Attribute(API.mlirDenseF64ArrayGet(context, length(values), values))

@llvmversioned min = v"16" Attribute(values::AbstractArray) = DenseArrayAttribute(values)

Expand Down
7 changes: 1 addition & 6 deletions src/mlir/IR/IntegerSet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,7 @@ Both `constraints` and `eqflags` need to be arrays of the same length.
"""
IntegerSet(ndims, nsymbols, constraints, eqflags; context::Context=context()) = IntegerSet(
API.mlirIntegerSetGet(
context,
ndims,
nsymbols,
length(constraints),
pointer(constraints),
pointer(eqflags),
context, ndims, nsymbols, length(constraints), constraints, eqflags
),
)

Expand Down
2 changes: 1 addition & 1 deletion src/mlir/IR/Location.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ end
# TODO rename to merge?
function fuse(locations::Vector{Location}, metadata; context::Context=context())
return Location(
API.mlirLocationFusedGet(context, length(locations), pointer(locations), metadata)
API.mlirLocationFusedGet(context, length(locations), locations, metadata)
)
end

Expand Down
22 changes: 6 additions & 16 deletions src/mlir/IR/Type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -449,15 +449,11 @@ function MemRefType(
if check
Type(
API.mlirMemRefTypeGetChecked(
location, elem_type, length(shape), pointer(shape), layout, memspace
location, elem_type, length(shape), shape, layout, memspace
),
)
else
Type(
API.mlirMemRefTypeGet(
elem_type, length(shape), pointer(shape), layout, memspace
),
)
Type(API.mlirMemRefTypeGet(elem_type, length(shape), shape, layout, memspace))
end
end

Expand All @@ -474,15 +470,11 @@ function MemRefType(
if check
Type(
API.mlirMemRefTypeContiguousGetChecked(
location, elem_type, length(shape), pointer(shape), memspace
location, elem_type, length(shape), shape, memspace
),
)
else
Type(
API.mlirMemRefTypeContiguousGet(
elem_type, length(shape), pointer(shape), memspace
),
)
Type(API.mlirMemRefTypeContiguousGet(elem_type, length(shape), shape, memspace))
end
end

Expand Down Expand Up @@ -560,7 +552,7 @@ end
Creates a tuple type that consists of the given list of elemental types. The type is owned by the context.
"""
Type(elements::Vector{Type}; context::Context=context()) =
Type(API.mlirTupleTypeGet(context, length(elements), pointer(elements)))
Type(API.mlirTupleTypeGet(context, length(elements), elements))
function Type(@nospecialize(elements::NTuple{N,Type}); context::Context=context()) where {N}
return Type(collect(elements); context)
end
Expand Down Expand Up @@ -590,9 +582,7 @@ Creates a function type, mapping a list of input types to result types.
"""
function FunctionType(inputs, results; context::Context=context())
return Type(
API.mlirFunctionTypeGet(
context, length(inputs), pointer(inputs), length(results), pointer(results)
),
API.mlirFunctionTypeGet(context, length(inputs), inputs, length(results), results)
)
end

Expand Down
7 changes: 7 additions & 0 deletions test/compile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,10 @@ Base.sum(x::NamedTuple{(:a,),Tuple{T}}) where {T<:Reactant.TracedRArray} = (; a=
# end
# end
end

@testset "Bool attributes" begin
x_ra = Reactant.to_rarray(false; track_numbers=(Number,))
@test @jit(iszero(x_ra)) == true
x_ra = Reactant.to_rarray(true; track_numbers=(Number,))
@test @jit(iszero(x_ra)) == false
end

1 comment on commit b92a7d8

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reactant.jl Benchmarks

Benchmark suite Current: b92a7d8 Previous: f2a91bf Ratio
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) 1350587665 ns 1356899596 ns 1.00
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1383284659 ns 1387243563 ns 1.00
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) 1318521515 ns 1307710781 ns 1.01
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) 3203498216 ns 3391199395 ns 0.94
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Lux 211895034 ns 288422658 ns 0.73
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 5192693081 ns 7038978698 ns 0.74
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant 5105568219 ns 4922184280 ns 1.04
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 6018217981 ns 5148694966 ns 1.17
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 11767333133 ns 7443923750 ns 1.58
ViT base (256 x 256 x 3 x 32)/forward/CPU/Lux 37329667055 ns 35583150632 ns 1.05
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) 1352978597 ns 1326518186 ns 1.02
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1301145716 ns 1335465488 ns 0.97
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) 1251932856 ns 1335531663 ns 0.94
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) 3185200110 ns 3467076041 ns 0.92
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Lux 8590755 ns 9067176.5 ns 0.95
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1620067599 ns 1555452102 ns 1.04
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant 1562833541 ns 1525737118 ns 1.02
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1659962028 ns 1533932694 ns 1.08
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3277801510 ns 3300553638 ns 0.99
ViT small (256 x 256 x 3 x 4)/forward/CPU/Lux 6065735254.5 ns 2534998626 ns 2.39
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) 1274332633 ns 1328057662 ns 0.96
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1312980865 ns 1259249953 ns 1.04
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) 1297238205 ns 1529248493 ns 0.85
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) 3128524863 ns 3214165683 ns 0.97
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Lux 21248949 ns 25864507 ns 0.82
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 2146461465 ns 2138887589 ns 1.00
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant 2127880295 ns 2148375491 ns 0.99
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 2168496609 ns 2178317010 ns 1.00
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 3916001186 ns 3916690304 ns 1.00
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Lux 6604223402 ns 5911275014.5 ns 1.12
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) 1287472160 ns 1298681699 ns 0.99
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1443583640.5 ns 1457354852 ns 0.99
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) 1313342058.5 ns 1321007313 ns 0.99
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) 3220364383 ns 3192117311 ns 1.01
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Lux 7502452 ns 7662836.5 ns 0.98
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1416410227 ns 1457203535 ns 0.97
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant 1419844842 ns 1435056828 ns 0.99
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1409644754 ns 1438646883 ns 0.98
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3159215564 ns 3204423699 ns 0.99
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Lux 3073518446 ns 1204850367 ns 2.55
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) 1415141711 ns 1609570351 ns 0.88
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1303150301 ns 1288271363 ns 1.01
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) 1319157274 ns 1305692496 ns 1.01
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) 3404838803 ns 3210305896 ns 1.06
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Lux 11353488 ns 15484395 ns 0.73
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 1723706202 ns 1721890751 ns 1.00
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant 1709602541 ns 1722159183 ns 0.99
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 1695411960 ns 1701220224 ns 1.00
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 3432537936 ns 3501882082 ns 0.98
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Lux 5387012592 ns 2892153610 ns 1.86
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) 1329618389 ns 1304479984 ns 1.02
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1343532174 ns 1315566933 ns 1.02
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) 1352869420 ns 1537931496 ns 0.88
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) 3185476720 ns 3352632273 ns 0.95
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Lux 25673267 ns 25923540 ns 0.99
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 2152691101 ns 2184873795 ns 0.99
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant 2143865762 ns 2206178636 ns 0.97
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 2141145130 ns 2173015030 ns 0.99
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 4249708829 ns 3992993680 ns 1.06
ViT small (256 x 256 x 3 x 16)/forward/CPU/Lux 16244609813 ns 5901204125.5 ns 2.75
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) 1310104193 ns 1337918756 ns 0.98
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1506787737 ns 1292347450 ns 1.17
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) 1269748486 ns 1325281512 ns 0.96
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) 3262767860 ns 3307275999 ns 0.99
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Lux 50691269 ns 56296298.5 ns 0.90
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 6933624742 ns 2942888604 ns 2.36
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant 4060945281 ns 2995273898 ns 1.36
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 2967657517 ns 3018336488 ns 0.98
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 4892229925 ns 4949580089 ns 0.99
ViT small (256 x 256 x 3 x 32)/forward/CPU/Lux 11162799884 ns 14511536613 ns 0.77
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) 1351893597 ns 1309758052 ns 1.03
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1298214491 ns 1309349784 ns 0.99
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) 1321267516 ns 1528578322 ns 0.86
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) 3172547882 ns 3242193874 ns 0.98
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Lux 68134904 ns 72533617.5 ns 0.94
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 3112388445 ns 3137681593 ns 0.99
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant 3311374032 ns 3252812510 ns 1.02
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 3408895903 ns 3113127307 ns 1.10
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 5053122323 ns 5045606499 ns 1.00
ViT base (256 x 256 x 3 x 16)/forward/CPU/Lux 26079756832 ns 15819654237 ns 1.65
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) 1300428116 ns 1316381206 ns 0.99
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1329236745 ns 1321440422 ns 1.01
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) 1335241615 ns 1321148926 ns 1.01
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) 3163422535 ns 3248024917 ns 0.97
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Lux 19502662 ns 20424385 ns 0.95
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1818841619 ns 1868979954 ns 0.97
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant 1833108489 ns 1974840414 ns 0.93
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1850734843 ns 1995499997 ns 0.93
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3648898677 ns 3753168538 ns 0.97
ViT base (256 x 256 x 3 x 4)/forward/CPU/Lux 8063384938 ns 3215953877.5 ns 2.51

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.