Skip to content

Commit

Permalink
fix: test updates from new changes
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 23, 2024
1 parent bd7b8c0 commit e8808fb
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 9 deletions.
2 changes: 1 addition & 1 deletion ext/LuxDynamicExpressionsExt.jl
Original file line number Diff line number Diff line change
Expand Up
@@ -4,7 +4,7 @@ using ChainRulesCore: ChainRulesCore, NoTangent
using DynamicExpressions: DynamicExpressions, Node, OperatorEnum, eval_grad_tree_array
using FastClosures: @closure
using ForwardDiff: ForwardDiff
using Lux: Lux, NAME_TYPE, Chain, Parallel, WrappedFunction, DynamicExpressionsLayer
using Lux: Lux, NAME_TYPE, Chain, Parallel, DynamicExpressionsLayer
using LuxDeviceUtils: LuxCPUDevice

const CRC = ChainRulesCore
Expand Down
9 changes: 3 additions & 6 deletions ext/LuxFluxExt.jl
Original file line number Diff line number Diff line change
Expand Up
@@ -96,8 +96,7 @@ function Lux.__from_flux_adaptor(l::Flux.Conv; preserve_ps_st::Bool=false, kwarg
groups = l.groups
pad = l.pad isa Flux.SamePad ? SamePad() : l.pad
if preserve_ps_st
_bias = l.bias isa Bool ? nothing :
reshape(copy(l.bias), ntuple(_ -> 1, length(k))..., out_chs, 1)
_bias = l.bias isa Bool ? nothing : vec(copy(l.bias))
return Lux.Conv(k, in_chs * groups => out_chs, l.σ; l.stride, pad, l.dilation,
groups, init_weight=Returns(Lux._maybe_flip_conv_weight(l.weight)),
init_bias=Returns(_bias), use_bias=!(l.bias isa Bool))
Expand All
@@ -114,8 +113,7 @@ function Lux.__from_flux_adaptor(
groups = l.groups
pad = l.pad isa Flux.SamePad ? SamePad() : l.pad
if preserve_ps_st
_bias = l.bias isa Bool ? nothing :
reshape(copy(l.bias), ntuple(_ -> 1, length(k))..., out_chs, 1)
_bias = l.bias isa Bool ? nothing : vec(copy(l.bias))
return Lux.ConvTranspose(k, in_chs * groups => out_chs, l.σ; l.stride,
pad, l.dilation, groups, use_bias=!(l.bias isa Bool),
init_weight=Returns(Lux._maybe_flip_conv_weight(l.weight)),
Expand All
@@ -131,8 +129,7 @@ function Lux.__from_flux_adaptor(l::Flux.CrossCor; preserve_ps_st::Bool=false, k
in_chs, out_chs = size(l.weight)[(end - 1):end]
pad = l.pad isa Flux.SamePad ? SamePad() : l.pad
if preserve_ps_st
_bias = l.bias isa Bool ? nothing :
reshape(copy(l.bias), ntuple(_ -> 1, length(k))..., out_chs, 1)
_bias = l.bias isa Bool ? nothing : vec(copy(l.bias))
return Lux.CrossCor(k, in_chs => out_chs, l.σ; l.stride, pad,
l.dilation, init_weight=Returns(copy(l.weight)),
init_bias=Returns(_bias), use_bias=!(l.bias isa Bool))
Expand Down
4 changes: 2 additions & 2 deletions test/layers/conv_tests.jl
Original file line number Diff line number Diff line change
Expand Up
@@ -187,7 +187,7 @@ end
display(layer)
ps, st = Lux.setup(rng, layer)
@test ps.weight isa aType{Float64, 4}
@test ps.bias isa aType{Float16, 4}
@test ps.bias isa aType{Float16, 1}
end

@testset "Depthwise Conv" begin
Expand Down
Expand Up
@@ -447,7 +447,7 @@ end
display(layer)
ps, st = Lux.setup(rng, layer)
@test ps.weight isa aType{Float64, 4}
@test ps.bias isa aType{Float16, 4}
@test ps.bias isa aType{Float16, 1}
end

@testset "CrossCor SamePad kernelsize $k" for k in (
Expand Down

0 comments on commit e8808fb

Please sign in to comment.