diff --git a/src/functor.jl b/src/functor.jl
index 2c8f3360db..8215b92863 100644
--- a/src/functor.jl
+++ b/src/functor.jl
@@ -403,8 +403,9 @@ function _metal end
 
 """
     gpu(data::DataLoader)
+    cpu(data::DataLoader)
 
-Transforms a given `DataLoader` to apply `gpu` to each batch of data,
-when iterated over. (If no GPU is available, this does nothing.)
+Transforms a given `DataLoader` to apply `gpu` or `cpu` to each batch of data,
+when iterated over. (If no GPU is available, `gpu` does nothing.)
 
 # Example
@@ -456,6 +457,18 @@ function gpu(d::MLUtils.DataLoader)
   )
 end
 
+function cpu(d::MLUtils.DataLoader)
+  MLUtils.DataLoader(MLUtils.mapobs(cpu, d.data),
+    d.batchsize,
+    d.buffer,
+    d.partial,
+    d.shuffle,
+    d.parallel,
+    d.collate,
+    d.rng,
+  )
+end
+
 # Defining device interfaces.
 """
     Flux.AbstractDevice <: Function
diff --git a/test/data.jl b/test/data.jl
index 4e4c485064..b97c4dae80 100644
--- a/test/data.jl
+++ b/test/data.jl
@@ -1,3 +1,4 @@
+using Flux: DataLoader
 using Random
 
 @testset "DataLoader" begin
@@ -14,6 +15,11 @@ using Random
     @test batches[2] == X[:,3:4]
     @test batches[3] == X[:,5:5]
 
+    d_cpu = d |> cpu # does nothing but shouldn't error
+    @test d_cpu isa DataLoader
+    @test first(d_cpu) == X[:,1:2]
+    @test length(d_cpu) == 3
+
     d = DataLoader(X, batchsize=2, partial=false)
     # @inferred first(d)
     batches = collect(d)
diff --git a/test/ext_cuda/cuda.jl b/test/ext_cuda/cuda.jl
index b52fa6c296..bbfd2854ba 100644
--- a/test/ext_cuda/cuda.jl
+++ b/test/ext_cuda/cuda.jl
@@ -182,11 +182,14 @@ end
 
   X = randn(Float64, 3, 33)
   pre1 = Flux.DataLoader(X |> gpu; batchsize=13, shuffle=false)
  post1 = Flux.DataLoader(X; batchsize=13, shuffle=false) |> gpu
+  rev1 = pre1 |> cpu # inverse operation
   for epoch in 1:2
-    for (p, q) in zip(pre1, post1)
+    for (p, q, a) in zip(pre1, post1, rev1)
       @test p isa CuArray{Float32}
       @test q isa CuArray{Float32}
       @test p ≈ q
+      @test a isa Array{Float32}
+      @test a ≈ Array(p)
     end
   end
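
Usage sketch (not part of the patch): the device round trip this change enables, with illustrative variable names. On a CPU-only machine `gpu` is a no-op, so the sketch runs either way; with a working CUDA device, batches come back as plain `Array`s after `cpu`.

```julia
using Flux  # Flux re-exports DataLoader from MLUtils

X = rand(Float32, 3, 10)
loader = Flux.DataLoader(X; batchsize=2)

gpu_loader = loader |> gpu      # existing method: batches land on the GPU when iterated
cpu_loader = gpu_loader |> cpu  # new method from this patch: inverse of the line above

@assert first(cpu_loader) isa Array  # batches are plain CPU Arrays again
```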