diff --git a/Project.toml b/Project.toml
index 926981f..b196f37 100644
--- a/Project.toml
+++ b/Project.toml
@@ -17,10 +17,10 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 Distributions = "0.20, 0.21, 0.22, 0.23, 0.25"
 FileIO = "1"
 Flux = "0.10, 0.11, 0.12, 0.14"
-ImageCore = "0.8, 0.9"
-ImageTransformations = "0.8, 0.9"
+ImageCore = "0.8, 0.9, 0.10"
+ImageTransformations = "0.8, 0.9, 0.10"
 Reexport = "0.2, 1"
-StatsBase = "0.30, 0.33"
+StatsBase = "0.30, 0.33, 0.34"
 julia = "1.6"
 
 [extras]
diff --git a/examples/Project.toml b/examples/Project.toml
new file mode 100644
index 0000000..fef5f63
--- /dev/null
+++ b/examples/Project.toml
@@ -0,0 +1,12 @@
+[deps]
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+FourierTools = "b18b359b-aebc-45ac-a139-9c0ccbb2871e"
+ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19"
+IndexFunArrays = "613c443e-d742-454e-bfc6-1d7f8dd76566"
+NDTools = "98581153-e998-4eef-8d0d-5ec2c052313d"
+Noise = "81d43f40-5267-43b7-ae1c-8b967f377efa"
+TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990"
+UNet = "0d73aaa9-994a-4556-95d0-da67cb772a03"
+View5D = "90d841e0-6953-4e90-9f3a-43681da8e949"
+cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
diff --git a/examples/deconvolve.jl b/examples/deconvolve.jl
new file mode 100644
index 0000000..eea7405
--- /dev/null
+++ b/examples/deconvolve.jl
@@ -0,0 +1,43 @@
+# Example using a U-net to deconvolve an image
+
+using UNet, Flux, TestImages, View5D, Noise, NDTools, FourierTools, IndexFunArrays
+
+img = 100f0 .* Float32.(testimage("resolution_test_512"))
+
+u = Unet();
+
+u = gpu(u);
+function loss(u, x, y)
+    return Flux.mse(u(x), y)
+end
+
+opt_state = Flux.setup(Momentum(), u);
+
+# selects a tile at a random (default) or predefined (ctr) position, returning the tile and its center
+function get_tile(img, tile_size=(128,128), ctr=(rand(tile_size[1]÷2:size(img,1)-tile_size[1]÷2), rand(tile_size[2]÷2:size(img,2)-tile_size[2]÷2)))
+    return select_region(img, new_size=tile_size, center=ctr), ctr
+end
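+
+# Optional sanity check of `get_tile` (illustrative only, not needed for training):
+# the tile has the requested size, and the returned center can be reused to cut
+# the matching tile again.
+tile0, ctr0 = get_tile(img, (128, 128))
+@assert size(tile0) == (128, 128)
+tile1, _ = get_tile(img, (128, 128), ctr0)
+@assert tile1 == tile0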
+
+# simulate blurring: a band-limited PSF from a disc aperture of radius R_max,
+# normalized to sum to one, is convolved with the image
+R_max = 70;
+sz = size(img); psf = abs2.(ift(disc(Float32, sz, R_max))); psf ./= sum(psf); conv_img = conv_psf(img, psf);
+
+scale = 0.5f0/maximum(conv_img)
+patch = (128, 128)
+for n in 1:2000
+    println("Iteration: $n")
+    myimg, pos = get_tile(conv_img, patch)
+    # image to denoise
+    # nimg1 = gpu(reshape(scale .* myimg, (size(myimg)...,1,1))); # gpu(scale.*reshape(poisson(myimg),(size(myimg)...,1,1)))
+    nimg1 = gpu(Float32.(scale .* reshape(poisson(Float64.(myimg)), (size(myimg)...,1,1))))
+    # goal image (with noise)
+    pimg, pos = get_tile(img, patch, pos)
+    pimg = gpu(scale .* reshape(pimg, (size(myimg)...,1,1)))
+    rep = [(nimg1, pimg)] # Iterators.repeated((nimg1, pimg), 1);
+    Flux.train!(loss, u, rep, opt_state)
+end
+
+# apply the net to the whole image instead:
+nimg = gpu(scale .* reshape(conv_img, (size(conv_img)...,1,1))); # gpu(scale.*reshape(poisson(conv_img),(size(conv_img)...,1,1)))
+nimg2 = gpu(scale .* reshape(poisson(Float64.(conv_img)), (size(conv_img)...,1,1)))
+# display the images using View5D
+@vt img nimg u(nimg) nimg2 u(nimg2)
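+
+# Optional, rough quality check (illustrative only; `ref` is a helper introduced
+# here, the scaled ground truth brought to the network's input shape):
+ref = gpu(scale .* reshape(img, (size(img)..., 1, 1)));
+println("mse blurred+noisy: ", Flux.mse(nimg2, ref))
+println("mse deconvolved:   ", Flux.mse(u(nimg2), ref))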
diff --git a/examples/noise2noise.jl b/examples/noise2noise.jl
new file mode 100644
index 0000000..bba0c6c
--- /dev/null
+++ b/examples/noise2noise.jl
@@ -0,0 +1,39 @@
+# Example using a U-net for a noise2noise problem
+
+using UNet, Flux, TestImages, View5D, Noise, NDTools, CUDA
+
+img = 10f0 .* Float32.(testimage("resolution_test_512"))
+
+u = Unet();
+
+u = gpu(u);
+function loss(u, x, y)
+    # return mean(abs2.(u(x) .- y))
+    return Flux.mse(u(x), y)
+end
+opt_state = Flux.setup(Momentum(), u);
+
+# selects a tile at a random (default) or predefined (ctr) position, returning the tile and its center
+function get_tile(img, tile_size=(128,128), ctr=(rand(tile_size[1]÷2:size(img,1)-tile_size[1]÷2), rand(tile_size[2]÷2:size(img,2)-tile_size[2]÷2)))
+    return select_region(img, new_size=tile_size, center=ctr), ctr
+end
+
+sz = size(img);
+scale = 0.5f0/maximum(img)
+patch = (128, 128)
+for n in 1:1000
+    println("Iteration: $n")
+    myimg, pos = get_tile(img, patch)
+    # image to denoise
+    nimg1 = gpu(scale .* reshape(Float32.(poisson(Float64.(myimg))), (size(myimg)...,1,1)))
+    # goal image (a second, independent noisy realization of the same tile)
+    nimg2 = gpu(scale .* reshape(Float32.(poisson(Float64.(myimg))), (size(myimg)...,1,1)))
+    rep = [(nimg1, nimg2),] # Iterators.repeated((nimg1, nimg2), 1);
+    # Flux.train!(loss, Flux.params(u), rep, opt_state)
+    Flux.train!(loss, u, rep, opt_state)
+end
+
+# apply the net to the whole image instead:
+nimg = gpu(scale .* reshape(Float32.(poisson(Float64.(img))), (size(img)...,1,1)));
+# display the images using View5D
+@vt img nimg u(nimg)
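+
+# Optional, illustrative check: the denoised output should sit closer to the
+# clean image than the noisy input does. `ref` is a helper introduced here.
+ref = gpu(scale .* reshape(img, (size(img)..., 1, 1)));
+println("mse noisy:    ", Flux.mse(nimg, ref))
+println("mse denoised: ", Flux.mse(u(nimg), ref))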
-""" -struct Unet - conv_down_blocks - conv_blocks - up_blocks +function Downsample(downsample_factor; pooling_func = NNlib.maxpool) + downop = x -> pooling_func(x, downsample_factor, pad=0) + return Downsample(downop, downsample_factor, pooling_func) end -@functor Unet +function (m::Downsample)(x) + for (d, x_s, f_s) in zip(1: length(m.factor), size(x), m.factor) + if (mod(x_s, f_s) !=0) + throw(DimensionMismatch("Can not downsample $(size(x)) with factor $(m.factor), mismatch in spatial dimension $d")) + end + end + return m.op(x) +end -function Unet(channels::Int = 1, labels::Int = channels) - conv_down_blocks = Chain(ConvDown(64,64), - ConvDown(128,128), - ConvDown(256,256), - ConvDown(512,512)) +struct Upsample{T1, T2} + op::T1 + factor::T2 +end - conv_blocks = Chain(UNetConvBlock(channels, 3), - UNetConvBlock(3, 64), - UNetConvBlock(64, 128), - UNetConvBlock(128, 256), - UNetConvBlock(256, 512), - UNetConvBlock(512, 1024), - UNetConvBlock(1024, 1024)) +@functor Upsample - up_blocks = Chain(UNetUpBlock(1024, 512), - UNetUpBlock(1024, 256), - UNetUpBlock(512, 128), - UNetUpBlock(256, 64,p = 0.0f0), - Chain(x->leakyrelu.(x,0.2f0), - Conv((1, 1), 128=>labels;init=_random_normal))) - Unet(conv_down_blocks, conv_blocks, up_blocks) +function Upsample(scale_factor, in_out_pair::Pair) + upop = ConvTranspose(scale_factor, in_out_pair, stride=scale_factor) + return Upsample(upop, scale_factor) end -function (u::Unet)(x::AbstractArray) - op = u.conv_blocks[1:2](x) +function (m::Upsample)(x) + return m.op(x) +end - x1 = u.conv_blocks[3](u.conv_down_blocks[1](op)) - x2 = u.conv_blocks[4](u.conv_down_blocks[2](x1)) - x3 = u.conv_blocks[5](u.conv_down_blocks[3](x2)) - x4 = u.conv_blocks[6](u.conv_down_blocks[4](x3)) +function crop(x, target_size) + if (size(x) == target_size) + return x + else + offset = Tuple((a-b)÷2+1 for (a,b) in zip(size(x), target_size)) + slice = Tuple(o:o+t-1 for (o,t) in zip(offset,target_size)) + return x[slice...,:,:] + end +end - up_x4 = u.conv_blocks[7](x4) +function(m::Upsample)(x, y) + #todo: crop_to_factor + g_up = m(x) + k = size(g_up)[1:length(m.factor)] + f_cropped = crop(y, k) + new_arr = cat(f_cropped, g_up; dims=length(m.factor)+1) + return new_arr +end - up_x1 = u.up_blocks[1](up_x4, x3) - up_x2 = u.up_blocks[2](up_x1, x2) - up_x3 = u.up_blocks[3](up_x2, x1) - up_x5 = u.up_blocks[4](up_x3, op) - tanh.(u.up_blocks[end](up_x5)) +# holds the information on the unet structure +struct Unet{T1, T2, T3, T4, T5, T6} + num_levels::T1 + l_conv_chain::T2 + l_down_chain::T3 + r_up_chain::T4 + r_conv_chain::T5 + final_conv::T6 end -function Base.show(io::IO, u::Unet) - println(io, "UNet:") +@functor Unet + +""" +function Unet(; + in_out_channels_pair = (1 => 1), + num_fmaps = 64, + fmap_inc_factor = 2, + downsample_factors = [(2,2),(2,2),(2,2),(2,2)], + kernel_sizes_down = [[(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)]], + kernel_sizes_up = [[(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)]], + activation = NNlib.relu, + final_activation = NNlib.relu; + padding="same", + pooling_func = NNlib.maxpool + ) + creates a U-net model that can then be used to be trained and to perform predictions. A UNet consists of an initial layer to + create feature maps, controlled via `num_fmaps`. This is followed by downsampling and umsampling steps, + which obtain information from the downsampling side of the net via skip-connections, which are automatically inserted. 
+
+# ConvBlock(in_chs, out_chs, kernel = (3, 3)) =
+#     Chain(Conv(kernel, in_chs=>out_chs, pad = (1, 1); init=_random_normal),
+#         BatchNormWrap(out_chs),
+#         x->leakyrelu.(x,0.2f0))
 
-UNetUpBlock(in_chs::Int, out_chs::Int; kernel = (3, 3), p = 0.5f0) =
-    UNetUpBlock(Chain(x->leakyrelu.(x,0.2f0),
-    ConvTranspose((2, 2), in_chs=>out_chs,
-    stride=(2, 2); init=_random_normal),
-    BatchNormWrap(out_chs),
-    Dropout(p)))
+struct Downsample{T1, T2, T3}
+    op::T1
+    factor::T2
+    pooling_func::T3
+end
 
-function (u::UNetUpBlock)(x, bridge)
-  x = u.upsample(x)
-  return cat(x, bridge, dims = 3)
+function Base.show(io::IO, d::Downsample)
+    print(io, "Downsample($(d.factor), $(d.pooling_func))")
 end
 
-"""
-    Unet(channels::Int = 1, labels::Int = channels)
+@functor Downsample
 
-  Initializes a [UNet](https://arxiv.org/pdf/1505.04597.pdf) instance with the given number of `channels`, typically equal to the number of channels in the input images.
-  `labels`, equal to the number of input channels by default, specifies the number of output channels.
-"""
-struct Unet
-  conv_down_blocks
-  conv_blocks
-  up_blocks
+function Downsample(downsample_factor; pooling_func = NNlib.maxpool)
+    downop = x -> pooling_func(x, downsample_factor, pad=0)
+    return Downsample(downop, downsample_factor, pooling_func)
 end
 
-@functor Unet
+function (m::Downsample)(x)
+    for (d, x_s, f_s) in zip(1:length(m.factor), size(x), m.factor)
+        if (mod(x_s, f_s) != 0)
+            throw(DimensionMismatch("Cannot downsample $(size(x)) with factor $(m.factor), mismatch in spatial dimension $d"))
+        end
+    end
+    return m.op(x)
+end
 
-function Unet(channels::Int = 1, labels::Int = channels)
-  conv_down_blocks = Chain(ConvDown(64,64),
-    ConvDown(128,128),
-    ConvDown(256,256),
-    ConvDown(512,512))
+struct Upsample{T1, T2}
+    op::T1
+    factor::T2
+end
 
-  conv_blocks = Chain(UNetConvBlock(channels, 3),
-    UNetConvBlock(3, 64),
-    UNetConvBlock(64, 128),
-    UNetConvBlock(128, 256),
-    UNetConvBlock(256, 512),
-    UNetConvBlock(512, 1024),
-    UNetConvBlock(1024, 1024))
+@functor Upsample
 
-  up_blocks = Chain(UNetUpBlock(1024, 512),
-    UNetUpBlock(1024, 256),
-    UNetUpBlock(512, 128),
-    UNetUpBlock(256, 64, p = 0.0f0),
-    Chain(x->leakyrelu.(x,0.2f0),
-    Conv((1, 1), 128=>labels; init=_random_normal)))
-  Unet(conv_down_blocks, conv_blocks, up_blocks)
+function Upsample(scale_factor, in_out_pair::Pair)
+    upop = ConvTranspose(scale_factor, in_out_pair, stride=scale_factor)
+    return Upsample(upop, scale_factor)
 end
 
-function (u::Unet)(x::AbstractArray)
-  op = u.conv_blocks[1:2](x)
+function (m::Upsample)(x)
+    return m.op(x)
+end
 
-  x1 = u.conv_blocks[3](u.conv_down_blocks[1](op))
-  x2 = u.conv_blocks[4](u.conv_down_blocks[2](x1))
-  x3 = u.conv_blocks[5](u.conv_down_blocks[3](x2))
-  x4 = u.conv_blocks[6](u.conv_down_blocks[4](x3))
+# center-crop x to the given spatial target_size (channel and batch dimensions are kept)
+function crop(x, target_size)
+    if (size(x) == target_size)
+        return x
+    else
+        offset = Tuple((a-b)÷2 + 1 for (a, b) in zip(size(x), target_size))
+        slice = Tuple(o:o+t-1 for (o, t) in zip(offset, target_size))
+        return x[slice..., :, :]
+    end
+end
 
-  up_x4 = u.conv_blocks[7](x4)
+# upsample x, center-crop the skip connection y to the matching spatial size
+# and concatenate both along the channel dimension
+function (m::Upsample)(x, y)
+    # TODO: crop_to_factor
+    g_up = m(x)
+    k = size(g_up)[1:length(m.factor)]
+    f_cropped = crop(y, k)
+    new_arr = cat(f_cropped, g_up; dims=length(m.factor)+1)
+    return new_arr
+end
 
-  up_x1 = u.up_blocks[1](up_x4, x3)
-  up_x2 = u.up_blocks[2](up_x1, x2)
-  up_x3 = u.up_blocks[3](up_x2, x1)
-  up_x5 = u.up_blocks[4](up_x3, op)
-  tanh.(u.up_blocks[end](up_x5))
+# holds the information on the U-net structure
+struct Unet{T1, T2, T3, T4, T5, T6}
+    num_levels::T1
+    l_conv_chain::T2
+    l_down_chain::T3
+    r_up_chain::T4
+    r_conv_chain::T5
+    final_conv::T6
 end
 
-function Base.show(io::IO, u::Unet)
-  println(io, "UNet:")
+@functor Unet
+
+"""
+    Unet(;
+        in_out_channels_pair = (1 => 1),
+        num_fmaps = 64,
+        fmap_inc_factor = 2,
+        downsample_factors = [(2,2),(2,2),(2,2),(2,2)],
+        kernel_sizes_down = [[(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)]],
+        kernel_sizes_up = [[(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)]],
+        activation = NNlib.relu,
+        final_activation = NNlib.relu,
+        padding = "same",
+        pooling_func = NNlib.maxpool
+    )
+
+creates a U-net model that can then be trained and used for predictions. A U-net consists of an initial layer
+creating feature maps, controlled via `num_fmaps`, followed by downsampling and upsampling steps; the upsampling
+side obtains information from the downsampling side of the net via automatically inserted skip-connections.
+On each level, the down- and upsampling paths contain a number of consecutive convolutions, controlled via the
+arguments `kernel_sizes_down` and `kernel_sizes_up` respectively.
+
+# Parameters
++ `in_out_channels_pair`: channels of the input and of the output of the U-net, given as a Pair of integers
+
++ `num_fmaps`: number of feature maps that the input gets expanded to in the first step
-  for l in u.conv_down_blocks
-    println(io, "  ConvDown($(size(l[1].weight)[end-1]), $(size(l[1].weight)[end]))")
+
++ `fmap_inc_factor`: the factor by which the feature maps get expanded on every level of the U-net
+
++ `downsample_factors`: vector of downsampling factors of the individual U-net levels
+
++ `kernel_sizes_down`: vector of vectors of tuples with the kernel sizes used in the convolutions on the way down,
+  e.g. 5 lists of kernel sizes, one per level (4 applied before the respective downsampling step plus a final one
+  at the bottom), each specifying 2 consecutive 3x3 convolutions
+
++ `kernel_sizes_up`: vector of vectors of tuples with the kernel sizes used in the convolutions on the way up;
+  similar, but applied after each upsampling step, ordered from the top level down, with no initial block before
+  the first upsampling
+
++ `activation`: activation function applied after each convolution layer
+
++ `final_activation`: activation function of the final step
+
++ `padding = "same"`: method of padding during convolution and upsampling
+
++ `pooling_func = NNlib.maxpool`: pooling function used for downsampling
+
+# Example
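+A minimal usage example; with the default `padding = "same"` the output keeps the input's spatial size:
+```jldoctest
+julia> u = Unet();
+
+julia> size(u(rand(Float32, 64, 64, 1, 1)))
+(64, 64, 1, 1)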
+```
+"""
+function Unet(; # all arguments are named and have defaults
+        in_out_channels_pair = (1 => 1),
+        num_fmaps = 64,
+        fmap_inc_factor = 2,
+        downsample_factors = [(2,2),(2,2),(2,2),(2,2)],
+        kernel_sizes_down = [[(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)]],
+        kernel_sizes_up = [[(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)]],
+        activation = NNlib.relu,
+        final_activation = NNlib.relu,
+        padding = "same",
+        pooling_func = NNlib.maxpool
+    )
+    in_channels, out_channels = in_out_channels_pair
+    num_levels = length(downsample_factors) + 1
+    dims = length(downsample_factors[1])
+    l_convs = Any[]
+    for level in 1:num_levels
+        in_ch = (level == 1) ? in_channels : num_fmaps * fmap_inc_factor ^ (level - 2)
+
+        cb = ConvBlock(
+            kernel_sizes_down[level],
+            in_ch => num_fmaps * fmap_inc_factor ^ (level - 1),
+            activation=activation,
+            padding=padding
+        )
+        push!(l_convs, cb)
+    end
+
+    l_downs = Any[]
+    for level in 1:num_levels - 1
+        push!(l_downs,
+              Downsample(
+                  downsample_factors[level];
+                  pooling_func = pooling_func
+              )
+        )
+    end
+
+    r_ups = Any[]
+    for level in 1:num_levels - 1
+        push!(r_ups,
+              Upsample(
+                  downsample_factors[level],
+                  num_fmaps * fmap_inc_factor ^ level => num_fmaps * fmap_inc_factor ^ level
+              )
+        )
+    end
-  println(io, "\n")
-  for l in u.conv_blocks
-    println(io, "  UNetConvBlock($(size(l[1].weight)[end-1]), $(size(l[1].weight)[end]))")
+
+    r_convs = Any[]
+    for level in 1:num_levels - 1
+        push!(r_convs,
+              ConvBlock(
+                  kernel_sizes_up[level],
+                  num_fmaps * fmap_inc_factor ^ (level - 1) + num_fmaps * fmap_inc_factor ^ level => num_fmaps * fmap_inc_factor ^ (level - 1),
+                  activation=activation,
+                  padding=padding
+              )
+        )
+    end
-  println(io, "\n")
-  for l in u.up_blocks
-    l isa UNetUpBlock || continue
-    println(io, "  UNetUpBlock($(size(l.upsample[2].weight)[end]), $(size(l.upsample[2].weight)[end-1]))")
+
+    final_conv = ConvBlock(
+        [ntuple((i) -> 1, dims)],
+        num_fmaps => out_channels,
+        activation=final_activation,
+        padding=padding
+    )
+    return Unet(num_levels, l_convs, l_downs, r_ups, r_convs, final_conv)
+end
+
+
+# recursive forward pass: process this level, descend, then combine the result
+# with the skip connection on the way back up
+function (m::Unet)(x::AbstractArray; level=1)
+    f_left = m.l_conv_chain[level](x)
+    fs_out = let
+        if (level == m.num_levels)
+            f_left
+        else
+            g_in = m.l_down_chain[level](f_left)
+            gs_out = m(g_in; level=level+1)
+            fs_right = m.r_up_chain[level](gs_out, f_left)
+            m.r_conv_chain[level](fs_right)
+        end
+    end
+
+    if (level == 1)
+        return m.final_conv(fs_out)
+    else
+        return fs_out
+    end
+end
+
+function Base.show(io::IO, u::Unet)
+    ws = size(u.l_conv_chain[1].op[1].weight)
+    println(io, "UNet, Input Channels: $(ws[end-1])")
+    lvl = ""
+    for (c, d) in zip(u.l_conv_chain, u.l_down_chain)
+        println(io, "$(lvl)Conv: $c")
+        println(io, "$(lvl)| \\")
+        println(io, "$(lvl)|  \\DownSample: $d")
+        println(io, "$(lvl)|   \\")
+        lvl *= "|    "
+    end
+    println(io, "$(lvl)Conv: $(u.l_conv_chain[end])")
+    for (c, d) in zip(u.r_conv_chain[end:-1:1], u.r_up_chain[end:-1:1])
+        lvl = lvl[1:end-5]
+        println(io, "$(lvl)|   /")
+        println(io, "$(lvl)|  /UpSample: $d")
+        println(io, "$(lvl)| /")
+        println(io, "$(lvl)Concat")
+        println(io, "$(lvl)Conv: $(c)")
+    end
+    println(io, "FinalConv: $(u.final_conv)")
+end
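+
+# Size bookkeeping for padding="valid" (an illustrative sketch; the numbers
+# follow from the layer definitions above): each 3x3 convolution trims 2 pixels
+# per spatial dimension, and every level's input must stay divisible by its
+# downsample factor, e.g.
+#   u = Unet(padding="valid", downsample_factors=[(2,2)],
+#            kernel_sizes_down=[[(3,3), (3,3)], [(3,3), (3,3)]],
+#            kernel_sizes_up=[[(3,3), (3,3)]])
+#   size(u(rand(Float32, 100, 100, 1, 1)))  # (84, 84, 1, 1)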