Skip to content

Commit

Permalink
Merge pull request #740 from LuxDL/ap/revert_brokem
Browse files Browse the repository at this point in the history
Revert bee2de7-1188db7
  • Loading branch information
avik-pal committed Jun 28, 2024
2 parents 928e3b3 + 0bfd4a9 commit dcab297
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 37 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/)

[![CI](https://github.com/LuxDL/Lux.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/LuxDL/Lux.jl/actions/workflows/CI.yml)
[![CI (pre-release)](https://github.com/LuxDL/Lux.jl/actions/workflows/CIPreRelease.yml/badge.svg?branch=main)](https://github.com/LuxDL/Lux.jl/actions/workflows/CIPreRelease.yml)
[![Build status](https://img.shields.io/buildkite/ba1f9622add5978c2d7b194563fd9327113c9c21e5734be20e/main.svg?label=gpu&branch=main)](https://buildkite.com/julialang/lux-dot-jl)
[![CI (pre-release)](https://img.shields.io/github/actions/workflow/status/LuxDL/Lux.jl/CIPreRelease.yml?branch=main&label=CI%20(pre-release)&logo=github)](https://github.com/LuxDL/Lux.jl/actions/workflows/CIPreRelease.yml)
[![Build status](https://img.shields.io/buildkite/ba1f9622add5978c2d7b194563fd9327113c9c21e5734be20e/main.svg?label=gpu&branch=main&logo=buildkite)](https://buildkite.com/julialang/lux-dot-jl)

[![codecov](https://codecov.io/gh/LuxDL/Lux.jl/branch/main/graph/badge.svg?token=IMqBM1e3hz)](https://codecov.io/gh/LuxDL/Lux.jl)
[![Benchmarks](https://github.com/LuxDL/Lux.jl/actions/workflows/Benchmark.yml/badge.svg?branch=main)](https://lux.csail.mit.edu/benchmarks/)

Expand Down
53 changes: 18 additions & 35 deletions test/layers/basic_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
layer2 = ReverseSequence(2)
layer3 = ReverseSequence(1)
display(layer)
ps, st = Lux.setup(rng, layer) |> device
ps2, st2 = Lux.setup(rng, layer2) |> device
ps3, st3 = Lux.setup(rng, layer3) |> device
ps, st = Lux.setup(rng, layer) .|> device
ps2, st2 = Lux.setup(rng, layer2) .|> device
ps3, st3 = Lux.setup(rng, layer3) .|> device

x = randn(rng, 3) |> aType
xr = reverse(x)
Expand Down Expand Up @@ -440,44 +440,27 @@ end
@jet layer(x, ps, st)

x = (rand(1:vocab_size[1], 3), rand(1:vocab_size[2], 3)) .|> aType
if mode == "cuda"
@test_broken begin
y, st_ = layer(x, ps, st)
@test y isa aType{Float32}
@test y == ps.weight[:, CartesianIndex.(x...)]
end
else
y, st_ = layer(x, ps, st)
@test y isa aType{Float32}
@test y == ps.weight[:, CartesianIndex.(x...)]
end
y, st_ = layer(x, ps, st)
@test y isa aType{Float32}
@test y == ps.weight[:, CartesianIndex.(x...)]

@jet layer(x, ps, st)

if mode == "cuda"
@test_broken begin
x = (rand(1:vocab_size[1], 3, 4), rand(1:vocab_size[2], 3, 4)) .|> aType
y, st_ = layer(x, ps, st)
@test y isa aType{Float32, 3}
@test size(y) == (embed_size, 3, 4)
end
else
x = (rand(1:vocab_size[1], 3, 4), rand(1:vocab_size[2], 3, 4)) .|> aType
y, st_ = layer(x, ps, st)
@test y isa aType{Float32, 3}
@test size(y) == (embed_size, 3, 4)
end
x = (rand(1:vocab_size[1], 3, 4), rand(1:vocab_size[2], 3, 4)) .|> aType
y, st_ = layer(x, ps, st)
@test y isa aType{Float32, 3}
@test size(y) == (embed_size, 3, 4)

@jet layer(x, ps, st)

if mode != "cuda"
x = (rand(1:vocab_size[1], 3), rand(1:vocab_size[2], 4)) .|> aType
@test_throws DimensionMismatch layer(x, ps, st)
x = (rand(1:vocab_size[1], 3), rand(1:vocab_size[2], 4)) .|> aType
@test_throws DimensionMismatch layer(x, ps, st)

x = (rand(1:vocab_size[1], 3, 4), rand(1:vocab_size[2], 4, 5)) .|> aType
@test_throws DimensionMismatch layer(x, ps, st)
x = (rand(1:vocab_size[1], 3, 4), rand(1:vocab_size[2], 4, 5)) .|> aType
@test_throws DimensionMismatch layer(x, ps, st)

x = ()
@test_throws ArgumentError layer(x, ps, st)
end
x = ()
@test_throws ArgumentError layer(x, ps, st)
end
end
end

0 comments on commit dcab297

Please sign in to comment.