From c8c8fc4a58fba25e304267c40773221ea5ed2979 Mon Sep 17 00:00:00 2001
From: rheintzmann
Date: Wed, 15 Dec 2021 10:56:31 +0100
Subject: [PATCH 01/21] Added version compatibility

---
 Project.toml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/Project.toml b/Project.toml
index c7a35d0..ac600a7 100644
--- a/Project.toml
+++ b/Project.toml
@@ -20,8 +20,8 @@ Flux = "0.10, 0.11, 0.12"
 ImageCore = "0.8, 0.9"
 ImageTransformations = "0.8"
 Reexport = "0.2, 1"
-StatsBase = "0.30"
-julia = "1.3"
+StatsBase = "0.3, 0.24, 0.33"
+julia = "1.3, 1.6"
 
 [extras]
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

From 5b5c6e1d55868b848808ab87e3da01fe8a2f3534 Mon Sep 17 00:00:00 2001
From: Larissa
Date: Thu, 16 Dec 2021 10:55:07 -0500
Subject: [PATCH 02/21] WIP configurable UNet

---
 src/model.jl | 270 ++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 190 insertions(+), 80 deletions(-)

diff --git a/src/model.jl b/src/model.jl
index 76c7de1..d07e79a 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -1,106 +1,216 @@
-function BatchNormWrap(out_ch)
-    Chain(x->expand_dims(x,2),
-        BatchNorm(out_ch),
-        x->squeeze(x))
+using Flux
+using Flux: @functor
+struct ConvBlock
+    op
 end
 
-UNetConvBlock(in_chs, out_chs, kernel = (3, 3)) =
-    Chain(Conv(kernel, in_chs=>out_chs,pad = (1, 1);init=_random_normal),
-        BatchNormWrap(out_chs),
-        x->leakyrelu.(x,0.2f0))
-
-ConvDown(in_chs,out_chs,kernel = (4,4)) =
-    Chain(Conv(kernel,in_chs=>out_chs,pad=(1,1),stride=(2,2);init=_random_normal),
-        BatchNormWrap(out_chs),
-        x->leakyrelu.(x,0.2f0))
-
-struct UNetUpBlock
-    upsample
+@functor ConvBlock
+
+function ConvBlock(in_channels, out_channels, kernel_sizes = [(3,3), (3,3)];
+                   activation = NNlib.relu, padding = "valid")
+    pad_arg = padding == "same" ? SamePad() : 0
+    conv_layers = Any[]
+    in_channels_it = in_channels
+    for kernel_size in kernel_sizes
+        push!(conv_layers,
+            Conv(kernel_size, in_channels_it => out_channels, activation; pad=pad_arg)
+        )
+        in_channels_it = out_channels
+    end
+    return ConvBlock(Chain(conv_layers...))
 end
 
-@functor UNetUpBlock
+function (m::ConvBlock)(x)
+    println(size(x))
+    return m.op(x)
+end
+
+    # ConvBlock(in_chs, out_chs, kernel = (3, 3)) =
+    #     Chain(Conv(kernel, in_chs=>out_chs,pad = (1, 1);init=_random_normal),
+    #         BatchNormWrap(out_chs),
+    #         x->leakyrelu.(x,0.2f0))
+
+struct Downsample
+    op
+    factor
+end
 
-UNetUpBlock(in_chs::Int, out_chs::Int; kernel = (3, 3), p = 0.5f0) =
-    UNetUpBlock(Chain(x->leakyrelu.(x,0.2f0),
-        ConvTranspose((2, 2), in_chs=>out_chs,
-            stride=(2, 2);init=_random_normal),
-        BatchNormWrap(out_chs),
-        Dropout(p)))
+@functor Downsample
 
-function (u::UNetUpBlock)(x, bridge)
-    x = u.upsample(x)
-    return cat(x, bridge, dims = 3)
+function Downsample(downsample_factor; pooling_type="max")
+    if (pooling_type == "max")
+        downop = x -> NNlib.maxpool(x, downsample_factor, pad=0)
+    else
+        downop = x -> NNlib.meanpool(x, downsample_factor, pad=0)
+    end
+    return Downsample(downop, downsample_factor)
 end
 
-"""
-    Unet(channels::Int = 1, labels::Int = channels)
+struct RuntimeError <: Exception end
 
-    Initializes a [UNet](https://arxiv.org/pdf/1505.04597.pdf) instance with the given number of `channels`, typically equal to the number of channels in the input images.
-    `labels`, equal to the number of input channels by default, specifies the number of output channels.
-"""
-struct Unet
-  conv_down_blocks
-  conv_blocks
-  up_blocks
+function (m::Downsample)(x)
+    for (d, x_s, f_s) in zip(1: length(m.factor), size(x), m.factor)
+        if (mod(x_s, f_s) !=0)
+            throw(RuntimeError("Can not downsample $(size(x)) with factor $(m.factor), mismatch in spatial dimension $d"))
+        end
+    end
+    return m.op(x)
 end
 
-@functor Unet
+struct Upsample
+    op
+    factor
+end
+
+@functor Upsample
 
-function Unet(channels::Int = 1, labels::Int = channels)
-  conv_down_blocks = Chain(ConvDown(64,64),
-        ConvDown(128,128),
-        ConvDown(256,256),
-        ConvDown(512,512))
-
-  conv_blocks = Chain(UNetConvBlock(channels, 3),
-        UNetConvBlock(3, 64),
-        UNetConvBlock(64, 128),
-        UNetConvBlock(128, 256),
-        UNetConvBlock(256, 512),
-        UNetConvBlock(512, 1024),
-        UNetConvBlock(1024, 1024))
-
-  up_blocks = Chain(UNetUpBlock(1024, 512),
-        UNetUpBlock(1024, 256),
-        UNetUpBlock(512, 128),
-        UNetUpBlock(256, 64,p = 0.0f0),
-        Chain(x->leakyrelu.(x,0.2f0),
-        Conv((1, 1), 128=>labels;init=_random_normal)))
-  Unet(conv_down_blocks, conv_blocks, up_blocks)
+function Upsample(scale_factor, in_channels, out_channels)
+    upop = ConvTranspose(scale_factor, in_channels=>out_channels, stride=scale_factor)
+    return Upsample(upop, scale_factor)
 end
 
-function (u::Unet)(x::AbstractArray)
-  op = u.conv_blocks[1:2](x)
+function (m::Upsample)(x)
+    return m.op(x)
+end
 
-  x1 = u.conv_blocks[3](u.conv_down_blocks[1](op))
-  x2 = u.conv_blocks[4](u.conv_down_blocks[2](x1))
-  x3 = u.conv_blocks[5](u.conv_down_blocks[3](x2))
-  x4 = u.conv_blocks[6](u.conv_down_blocks[4](x3))
+function crop(x, size)
+    target_size = size(x)
+    offset = Tuple((a-b)÷2 for (a,b) in zip(size(x), target_size))
+end
 
-  up_x4 = u.conv_blocks[7](x4)
+function(m::Upsample)(x, y)
+    #todo: crop_to_factor
+    g_cropped = y
+    f_cropped = crop(m(x), size(g_cropped)[:length(m.factor)])
+    new_arr = cat(f_cropped, g_cropped; dims=length(m.factor)+1)
+    println("CONCAT", size(new_arr))
+    return new_arr
+end
 
-  up_x1 = u.up_blocks[1](up_x4, x3)
-  up_x2 = u.up_blocks[2](up_x1, x2)
-  up_x3 = u.up_blocks[3](up_x2, x1)
-  up_x5 = u.up_blocks[4](up_x3, op)
-  tanh.(u.up_blocks[end](up_x5))
+struct Unet
+    num_levels
+    l_conv_chain
+    l_down_chain
+    r_up_chain
+    r_conv_chain
+    final_conv
 end
 
-function Base.show(io::IO, u::Unet)
-  println(io, "UNet:")
+@functor Unet
 
-  for l in u.conv_down_blocks
-    println(io, "  ConvDown($(size(l[1].weight)[end-1]), $(size(l[1].weight)[end]))")
+function Unet(
+    in_channels,
+    out_channels,
+    num_fmaps,
+    fmap_inc_factor,
+    downsample_factors,
+    kernel_size_down,
+    kernel_size_up,
+    activation,
+    final_activation;
+    padding="valid",
+    pooling_type="max"
+    )
+    num_levels = length(downsample_factors) + 1
+    dims = length(downsample_factors[1])
+    l_convs = Any[]
+    for level in 1:num_levels
+        in_ch = (level == 1) ? in_channels : num_fmaps * fmap_inc_factor ^ (level - 2)
+        push!(l_convs,
+            ConvBlock(in_ch,
+                num_fmaps * fmap_inc_factor ^ (level - 1),
+                kernel_size_down[level],
+                activation=activation,
+                padding=padding
+            )
+        )
+    end
+
+
+    l_downs = Any[]
+    for level in 1:num_levels - 1
+        push!(l_downs,
+            Downsample(
+                downsample_factors[level];
+                pooling_type=pooling_type
+            )
+        )
+    end
+
+
+    r_ups = Any[]
+    for level in 1:num_levels - 1
+        push!(r_ups,
+            Upsample(
+                downsample_factors[level],
+                num_fmaps * fmap_inc_factor ^ level,
+                num_fmaps * fmap_inc_factor ^ level
+            )
+        )
     end
 
-  println(io, "\n")
-  for l in u.conv_blocks
-    println(io, "  UNetConvBlock($(size(l[1].weight)[end-1]), $(size(l[1].weight)[end]))")
+
+    r_convs = Any[]
+    for level in 1:num_levels - 1
+        push!(r_convs,
+            ConvBlock(
+                num_fmaps * fmap_inc_factor ^ (level - 1) +
+                num_fmaps * fmap_inc_factor ^ level,
+                num_fmaps * fmap_inc_factor ^ (level - 1),
+                kernel_size_up[level],
+                activation=activation,
+                padding=padding
+            )
+        )
     end
 
-  println(io, "\n")
-  for l in u.up_blocks
-    l isa UNetUpBlock || continue
-    println(io, "  UNetUpBlock($(size(l.upsample[2].weight)[end]), $(size(l.upsample[2].weight)[end-1]))")
+
+    final_conv = ConvBlock(
+        num_fmaps,
+        out_channels,
+        [Tuple(1 for i in 1:dims)],
+        activation=final_activation,
+        padding=padding
+    )
+    return Unet(num_levels, l_convs, l_downs, r_ups, r_convs, final_conv)
+end
+
+
+function (m::Unet)(x::AbstractArray; level=1)
+    f_left = m.l_conv_chain[level](x)
+    fs_out = let
+        if (level == m.num_levels)
+            f_left
+        else
+            g_in = m.l_down_chain[level](f_left)
+            gs_out = m(g_in; level=level+1)
+            fs_right = m.r_up_chain[level](gs_out, f_left)
+            m.r_conv_chain[level](fs_right)
+        end
+    end
+
+    if (level == 1)
+        return m.final_conv(fs_out)
+    else
+        return fs_out
     end
 end
+
+# function Base.show(io::IO, u::Unet)
+#     println(io, "UNet:")
+
+#     for l in u.conv_down_blocks
+#         println(io, "  ConvDown($(size(l[1].weight)[end-1]), $(size(l[1].weight)[end]))")
+#     end
+
+#     println(io, "\n")
+#     for l in u.conv_blocks
+#         println(io, "  UNetConvBlock($(size(l[1].weight)[end-1]), $(size(l[1].weight)[end]))")
+#     end
+
+#     println(io, "\n")
+#     for l in u.up_blocks
+#         l isa UNetUpBlock || continue
+#         println(io, "  UNetUpBlock($(size(l.upsample[2].weight)[end]), $(size(l.upsample[2].weight)[end-1]))")
+#     end
+# end

From ffd9a8220493b57ffa14ea82da4326da34e1cce2 Mon Sep 17 00:00:00 2001
From: Larissa
Date: Thu, 16 Dec 2021 10:56:17 -0500
Subject: [PATCH 03/21] add example with configured UNet

---
 Project.toml          |  7 ++++++
 examples/flux_test.jl | 53 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 60 insertions(+)
 create mode 100644 examples/flux_test.jl

diff --git a/Project.toml b/Project.toml
index ac600a7..79c5e7a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -7,11 +7,18 @@ version = "0.2.0"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+FourierTools = "b18b359b-aebc-45ac-a139-9c0ccbb2871e"
 ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
+ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19"
 ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
+IndexFunArrays = "613c443e-d742-454e-bfc6-1d7f8dd76566"
+NDTools = "98581153-e998-4eef-8d0d-5ec2c052313d"
+Noise = "81d43f40-5267-43b7-ae1c-8b967f377efa"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990"
+View5D = "90d841e0-6953-4e90-9f3a-43681da8e949"
 
 [compat]
 Distributions = "0.20, 0.21, 0.22, 0.23, 0.25"

diff --git a/examples/flux_test.jl b/examples/flux_test.jl
new file mode 100644
index 0000000..02544d2
--- /dev/null
+++ b/examples/flux_test.jl
@@ -0,0 +1,53 @@
+# Test Flux.jl and a U-net architecture
+
+using Flux
+using UNet
+using TestImages
+using View5D
+using Noise
+using NDTools
+using FourierTools
+using IndexFunArrays
+
+img = 10.0 .* Float32.(testimage("resolution_test_512"))
+
+
+# @ve nimg1 nimg2
+
+u = Unet(1, 1, 64, 2, [(2,2),(2,2),(2,2),(2,2)], [[(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)]],
+[[(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)]], NNlib.relu, NNlib.relu; padding="same");
+u = gpu(u);
+function loss(x, y)
+    # op = clamp.(u(x), 0.001f0, 1.f0)
+    mean(abs2.(u(x) .-y))
+end
+opt = Momentum()
+
+function get_random_tile(img, tile_size=(128,128), ctr = (rand(tile_size[1]÷2:size(img,1)-tile_size[1]÷2),rand(tile_size[2]÷2:size(img,2)-tile_size[2]÷2)) )
+    return select_region(img,new_size=tile_size, center=ctr), ctr
+end
+
+sz = size(img)
+psf = abs2.(ift(disc(sz, 40))); psf ./= sum(psf)
+conv_img = conv_psf(img,psf)
+
+scale = 0.5/maximum(conv_img)
+patch = (128,128)
+for n in 1:100
+    println("Iteration: $n")
+    myimg, pos = get_random_tile(conv_img,patch)
+    nimg1 = gpu(scale.*reshape(poisson(myimg),(size(myimg)...,1,1)))
+    # nimg2 = gpu(scale.*reshape(poisson(myimg),(size(myimg)...,1,1)))
+    pimg, pos = get_random_tile(img,patch,pos)
+    pimg = gpu(scale.*reshape(pimg,(size(myimg)...,1,1)))
+    rep = Iterators.repeated((nimg1, pimg), 1);
+    Flux.train!(loss, Flux.params(u), rep, opt)
+end
+
+# myimg, pos = get_random_tile(conv_img,patch)
+# nimg3 = gpu(scale.*reshape(poisson(myimg),(size(myimg)...,1,1)))
+# @ve myimg nimg3 u(nimg3)
+
+# apply the net to the whole image instead:
+nimg = gpu(scale.*reshape(poisson(conv_img),(size(conv_img)...,1,1)))
+@ve img nimg u(nimg)

From bd60834d614156373c7547c9e88d412c6ba004dc Mon Sep 17 00:00:00 2001
From: rheintzmann
Date: Thu, 16 Dec 2021 18:13:01 +0100
Subject: [PATCH 04/21] bug fixes

---
 examples/flux_test.jl | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/examples/flux_test.jl b/examples/flux_test.jl
index 02544d2..c79fa95 100644
--- a/examples/flux_test.jl
+++ b/examples/flux_test.jl
@@ -14,8 +14,14 @@ img = 10.0 .* Float32.(testimage("resolution_test_512"))
 
 # @ve nimg1 nimg2
 
-u = Unet(1, 1, 64, 2, [(2,2),(2,2),(2,2),(2,2)], [[(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)]],
-[[(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)]], NNlib.relu, NNlib.relu; padding="same");
+# u = Unet();
+u = Unet(1, 1, 64, 2, [(2,2),(2,2),(2,2),(2,2)],
+[[(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)]],
+[[(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)]],
+NNlib.relu,
+NNlib.relu;
+padding="same");
+
 u = gpu(u);
 function loss(x, y)
     # op = clamp.(u(x), 0.001f0, 1.f0)
@@ -50,4 +56,4 @@ end
 
 # apply the net to the whole image instead:
 nimg = gpu(scale.*reshape(poisson(conv_img),(size(conv_img)...,1,1)))
-@ve img nimg u(nimg)
+@ve img nimg u(nimg) conv_img

From a4cfa131fa72b6aab1cbcc3004d48b30f582c17e Mon Sep 17 00:00:00 2001
From: rheintzmann
Date: Thu, 16 Dec 2021 18:13:06 +0100
Subject: [PATCH 05/21] bug fixes

---
 src/model.jl | 89 ++++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 66 insertions(+), 23 deletions(-)

diff --git a/src/model.jl b/src/model.jl
index d07e79a..2082185 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -81,13 +81,14 @@ end
 function(m::Upsample)(x, y)
     #todo: crop_to_factor
     g_cropped = y
-    f_cropped = crop(m(x), size(g_cropped)[:length(m.factor)])
+    # f_cropped = crop(m(x), size(g_cropped)[:length(m.factor)])
+    f_cropped = m(x)
     new_arr = cat(f_cropped, g_cropped; dims=length(m.factor)+1)
     println("CONCAT", size(new_arr))
     return new_arr
 end
 
-struct Unet
+struct Unet 
     num_levels
     l_conv_chain
     l_down_chain
     r_up_chain
     r_conv_chain
     final_conv
 end
 
 @functor Unet
 
+"""
 function Unet(
-    in_channels,
-    out_channels,
-    num_fmaps,
-    fmap_inc_factor,
-    downsample_factors,
-    kernel_size_down,
-    kernel_size_up,
-    activation,
-    final_activation;
-    padding="valid",
+    in_channels = 1,
+    out_channels = 1,
+    num_fmaps = 64,
+    fmap_inc_factor = 2,
+    downsample_factors = [(2,2),(2,2),(2,2),(2,2)],
+    kernel_sizes_down = [[(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)]],
+    kernel_sizes_up = [[(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)]],
+    activation = NNlib.relu,
+    final_activation = NNlib.relu;
+    padding="same",
     pooling_type="max"
     )
+    creates a U-net model that can then be trained and used for predictions.
+
+# Parameters
++ `in_channels`: channels of the input to the U-net
+
++ `out_channels`: channels of the output of the U-net
+
++ `num_fmaps`: number of feature maps that the input gets expanded to in the first step
+
++ `fmap_inc_factor`: the factor that the feature maps get expanded by in every level of the U-net
+
++ `downsample_factors`: vector of downsampling factors of individual U-net levels
+
++ `kernel_sizes_down`: vector of vectors of tuples of individual kernel sizes used in the convolutions on the way down,
+  e.g. 5 lists of convolutions: 4 before the downsampling steps and one final block after the last downsample, each with 2 consecutive 3x3 convolutions.
+
++ `kernel_sizes_up`: vector of vectors of tuples of individual kernel sizes used in the convolutions on the way up (backwards),
+  similar, but applied after each upsampling step, starting from the top; there is no initial block before the first upsampling.
+
++ `activation`: activation function after each convolution layer
+
++ `final_activation`: activation function for the final step
+
++ `padding="same"`: method of padding during convolution and upsampling
+
++ `pooling_type="max"`: type of pooling
+
+"""
+function Unet(
+    in_channels,# = 1,
+    out_channels,# = 1,
+    num_fmaps,# = 64,
+    fmap_inc_factor,# = 2,
+    downsample_factors,# = [(2,2),(2,2),(2,2),(2,2)],
+    kernel_sizes_down,# = [[(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)]],
+    kernel_sizes_up,# = [[(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)]],
+    activation,# = NNlib.relu,
+    final_activation ; # = NNlib.relu;
+    padding ="same",
+    pooling_type ="max"
+    )
+    @show downsample_factors
     num_levels = length(downsample_factors) + 1
     dims = length(downsample_factors[1])
     l_convs = Any[]
     for level in 1:num_levels
+        @show l_convs
         in_ch = (level == 1) ? in_channels : num_fmaps * fmap_inc_factor ^ (level - 2)
-        push!(l_convs,
-            ConvBlock(in_ch,
-                num_fmaps * fmap_inc_factor ^ (level - 1),
-                kernel_size_down[level],
-                activation=activation,
-                padding=padding
-            )
-        )
+
+        cb = ConvBlock(in_ch,
+            num_fmaps * fmap_inc_factor ^ (level - 1),
+            kernel_sizes_down[level],
+            activation=activation,
+            padding=padding
+        )
+        push!(l_convs, cb)
     end
 
-
     l_downs = Any[]
     for level in 1:num_levels - 1
         push!(l_downs,
@@ -157,14 +201,13 @@ function Unet(
                 num_fmaps * fmap_inc_factor ^ (level - 1) +
                 num_fmaps * fmap_inc_factor ^ level,
                 num_fmaps * fmap_inc_factor ^ (level - 1),
-                kernel_size_up[level],
+                kernel_sizes_up[level],
                 activation=activation,
                 padding=padding
             )
         )
     end
 
-
     final_conv = ConvBlock(
         num_fmaps,
         out_channels,

From 8b91545b66e44242ea2f732da295676e9f09463c Mon Sep 17 00:00:00 2001
From: rheintzmann
Date: Fri, 17 Dec 2021 09:49:10 +0100
Subject: [PATCH 06/21] added pretty print and removed debug prints.

---
 examples/deconvolve.jl  | 39 ++++++++++++++++++++++
 examples/flux_test.jl   | 59 ---------------------------------
 examples/noise2noise.jl | 37 +++++++++++++++++++++
 src/model.jl            | 72 ++++++++++++++++++++++-------------------
 4 files changed, 114 insertions(+), 93 deletions(-)
 create mode 100644 examples/deconvolve.jl
 delete mode 100644 examples/flux_test.jl
 create mode 100644 examples/noise2noise.jl

diff --git a/examples/deconvolve.jl b/examples/deconvolve.jl
new file mode 100644
index 0000000..1ab5ee0
--- /dev/null
+++ b/examples/deconvolve.jl
@@ -0,0 +1,39 @@
+# Example using U-net to deconvolve an image
+
+using UNet, Flux, TestImages, View5D, Noise, NDTools, FourierTools, IndexFunArrays
+
+img = 10.0 .* Float32.(testimage("resolution_test_512"))
+
+u = Unet();
+
+u = gpu(u);
+function loss(x, y)
+    return Flux.mse(u(x),y)
+end
+opt = Momentum()
+
+# selects a tile at a random (default) or predefined (ctr) position returning tile and center.
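+# e.g. `tile, ctr = get_tile(img)` draws a random 128x128 tile, and `get_tile(other, (128,128), ctr)`
+# (as used below) reuses `ctr` so that paired tiles from two images stay aligned.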
+function get_tile(img, tile_size=(128,128), ctr = (rand(tile_size[1]÷2:size(img,1)-tile_size[1]÷2),rand(tile_size[2]÷2:size(img,2)-tile_size[2]÷2)) )
+    return select_region(img,new_size=tile_size, center=ctr), ctr
+end
+
+sz = size(img); psf = abs2.(ift(disc(sz, 40))); psf ./= sum(psf); conv_img = conv_psf(img,psf);
+
+scale = 0.5/maximum(conv_img)
+patch = (128,128)
+for n in 1:1000
+    println("Iteration: $n")
+    myimg, pos = get_tile(conv_img,patch)
+    # image to denoise
+    nimg1 = gpu(scale.*reshape(poisson(myimg),(size(myimg)...,1,1)))
+    # goal image (with noise)
+    nimg2 = gpu(scale.*reshape(poisson(myimg),(size(myimg)...,1,1)))
+    pimg, pos = get_tile(img,patch,pos)
+    pimg = gpu(scale.*reshape(pimg,(size(myimg)...,1,1)))
+    rep = Iterators.repeated((nimg1, pimg), 1);
+    Flux.train!(loss, Flux.params(u), rep, opt)
+end
+
+# apply the net to the whole image instead:
+nimg = gpu(scale.*reshape(poisson(conv_img),(size(conv_img)...,1,1)))
+@ve img nimg u(nimg) conv_img

diff --git a/examples/flux_test.jl b/examples/flux_test.jl
deleted file mode 100644
index c79fa95..0000000
--- a/examples/flux_test.jl
+++ /dev/null
@@ -1,59 +0,0 @@
-# Test Flux.jl and a U-net architecture
-
-using Flux
-using UNet
-using TestImages
-using View5D
-using Noise
-using NDTools
-using FourierTools
-using IndexFunArrays
-
-img = 10.0 .* Float32.(testimage("resolution_test_512"))
-
-
-# @ve nimg1 nimg2
-
-# u = Unet();
-u = Unet(1, 1, 64, 2, [(2,2),(2,2),(2,2),(2,2)],
-[[(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)]],
-[[(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)]],
-NNlib.relu,
-NNlib.relu;
-padding="same");
-
-u = gpu(u);
-function loss(x, y)
-    # op = clamp.(u(x), 0.001f0, 1.f0)
-    mean(abs2.(u(x) .-y))
-end
-opt = Momentum()
-
-function get_random_tile(img, tile_size=(128,128), ctr = (rand(tile_size[1]÷2:size(img,1)-tile_size[1]÷2),rand(tile_size[2]÷2:size(img,2)-tile_size[2]÷2)) )
-    return select_region(img,new_size=tile_size, center=ctr), ctr
-end
-
-sz = size(img)
-psf = abs2.(ift(disc(sz, 40))); psf ./= sum(psf)
-conv_img = conv_psf(img,psf)
-
-scale = 0.5/maximum(conv_img)
-patch = (128,128)
-for n in 1:100
-    println("Iteration: $n")
-    myimg, pos = get_random_tile(conv_img,patch)
-    nimg1 = gpu(scale.*reshape(poisson(myimg),(size(myimg)...,1,1)))
-    # nimg2 = gpu(scale.*reshape(poisson(myimg),(size(myimg)...,1,1)))
-    pimg, pos = get_random_tile(img,patch,pos)
-    pimg = gpu(scale.*reshape(pimg,(size(myimg)...,1,1)))
-    rep = Iterators.repeated((nimg1, pimg), 1);
-    Flux.train!(loss, Flux.params(u), rep, opt)
-end
-
-# myimg, pos = get_random_tile(conv_img,patch)
-# nimg3 = gpu(scale.*reshape(poisson(myimg),(size(myimg)...,1,1)))
-# @ve myimg nimg3 u(nimg3)
-
-# apply the net to the whole image instead:
-nimg = gpu(scale.*reshape(poisson(conv_img),(size(conv_img)...,1,1)))
-@ve img nimg u(nimg) conv_img

diff --git a/examples/noise2noise.jl b/examples/noise2noise.jl
new file mode 100644
index 0000000..96dda92
--- /dev/null
+++ b/examples/noise2noise.jl
@@ -0,0 +1,37 @@
+# Example using U-net for a noise2noise problem
+
+using UNet, Flux, TestImages, View5D, Noise, NDTools
+
+img = 10.0 .* Float32.(testimage("resolution_test_512"))
+
+u = Unet();
+
+u = gpu(u);
+function loss(x, y)
+    # return mean(abs2.(u(x) .-y))
+    return Flux.mse(u(x),y)
+end
+opt = Momentum()
+
+# selects a tile at a random (default) or predefined (ctr) position returning tile and center.
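+# note: the training pairs below are two independent Poisson samples of the same clean tile, so with
+# an MSE loss the optimum is their noise-free mean (the noise2noise training principle).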
+function get_tile(img, tile_size=(128,128), ctr = (rand(tile_size[1]÷2:size(img,1)-tile_size[1]÷2),rand(tile_size[2]÷2:size(img,2)-tile_size[2]÷2)) )
+    return select_region(img,new_size=tile_size, center=ctr), ctr
+end
+
+sz = size(img);
+scale = 0.5/maximum(img)
+patch = (128,128)
+for n in 1:1000
+    println("Iteration: $n")
+    myimg, pos = get_tile(img,patch)
+    # image to denoise
+    nimg1 = gpu(scale.*reshape(poisson(myimg),(size(myimg)...,1,1)))
+    # goal image (with noise)
+    nimg2 = gpu(scale.*reshape(poisson(myimg),(size(myimg)...,1,1)))
+    rep = Iterators.repeated((nimg1, nimg2), 1);
+    Flux.train!(loss, Flux.params(u), rep, opt)
+end
+
+# apply the net to the whole image instead:
+nimg = gpu(scale.*reshape(poisson(conv_img),(size(conv_img)...,1,1)))
+@ve img nimg u(nimg)

diff --git a/src/model.jl b/src/model.jl
index 2082185..4ecce05 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -21,7 +21,7 @@ function ConvBlock(in_channels, out_channels, kernel_sizes = [(3,3), (3,3)];
 end
 
 function (m::ConvBlock)(x)
-    println(size(x))
+    # println(size(x))
     return m.op(x)
 end
 
@@ -84,10 +84,11 @@ function(m::Upsample)(x, y)
     # f_cropped = crop(m(x), size(g_cropped)[:length(m.factor)])
     f_cropped = m(x)
     new_arr = cat(f_cropped, g_cropped; dims=length(m.factor)+1)
-    println("CONCAT", size(new_arr))
+    # println("CONCAT", size(new_arr))
     return new_arr
 end
 
+# holds the information on the unet structure
 struct Unet 
     num_levels
     l_conv_chain
@@ -100,7 +101,7 @@ end
 @functor Unet
 
 """
-function Unet(
+function Unet(;
     in_channels = 1,
     out_channels = 1,
     num_fmaps = 64,
@@ -113,7 +114,11 @@ function Unet(
     padding="same",
     pooling_type="max"
     )
-    creates a U-net model that can then be trained and used for predictions.
+    creates a U-net model that can then be trained and used for predictions. A UNet consists of an initial layer to
+    create feature maps, controlled via `num_fmaps`. This is followed by downsampling and upsampling steps,
+    which obtain information from the downsampling side of the net via skip-connections, which are automatically inserted.
+    The down- and upsampling steps contain on each level a number of consecutive convolutions controlled via the arguments `kernel_sizes_down`
+    and `kernel_sizes_up` respectively.
 
 # Parameters
 + `in_channels`: channels of the input to the U-net
@@ -140,26 +145,27 @@ function Unet(
 + `pooling_type="max"`: type of pooling
 
+# Example
+```jldoctest
+```
 """
-function Unet(
-    in_channels,# = 1,
-    out_channels,# = 1,
-    num_fmaps,# = 64,
-    fmap_inc_factor,# = 2,
-    downsample_factors,# = [(2,2),(2,2),(2,2),(2,2)],
-    kernel_sizes_down,# = [[(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)]],
-    kernel_sizes_up,# = [[(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)]],
-    activation,# = NNlib.relu,
-    final_activation ; # = NNlib.relu;
+function Unet(; # all arguments are named and have defaults
+    in_channels = 1,
+    out_channels = 1,
+    num_fmaps = 64,
+    fmap_inc_factor = 2,
+    downsample_factors = [(2,2),(2,2),(2,2),(2,2)],
+    kernel_sizes_down = [[(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)]],
+    kernel_sizes_up = [[(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)]],
+    activation = NNlib.relu,
+    final_activation = NNlib.relu,
     padding ="same",
     pooling_type ="max"
     )
-    @show downsample_factors
     num_levels = length(downsample_factors) + 1
     dims = length(downsample_factors[1])
     l_convs = Any[]
     for level in 1:num_levels
-        @show l_convs
         in_ch = (level == 1) ? in_channels : num_fmaps * fmap_inc_factor ^ (level - 2)
 
         cb = ConvBlock(in_ch,
             num_fmaps * fmap_inc_factor ^ (level - 1),
             kernel_sizes_down[level],
             activation=activation,
             padding=padding
         )
         push!(l_convs, cb)
     end
 
-
     l_downs = Any[]
     for level in 1:num_levels - 1
         push!(l_downs,
@@ -181,7 +187,6 @@ function Unet(
         )
     end
 
-
     r_ups = Any[]
     for level in 1:num_levels - 1
         push!(r_ups,
@@ -193,7 +198,6 @@ function Unet(
         )
     end
 
-
     r_convs = Any[]
     for level in 1:num_levels - 1
         push!(r_convs,
@@ -239,21 +243,21 @@ function (m::Unet)(x::AbstractArray; level=1)
     end
 end
 
-# function Base.show(io::IO, u::Unet)
-#     println(io, "UNet:")
-
-#     for l in u.conv_down_blocks
-#         println(io, "  ConvDown($(size(l[1].weight)[end-1]), $(size(l[1].weight)[end]))")
-#     end
+function Base.show(io::IO, u::Unet)
+    ws = size(u.l_conv_chain[1].op[1].weight)
+    println(io, "UNet, Input Channels: $(ws[end-1]):")
+    lvl = ""
+    for (c,d) in zip(u.l_conv_chain,u.l_down_chain)
+        println(io, "$(lvl)Conv($(c))")
+        println(io, "$(lvl)DownSample($(d.factor))")
+        lvl *= "|    "
+    end
 
-#     println(io, "\n")
-#     for l in u.conv_blocks
-#         println(io, "  UNetConvBlock($(size(l[1].weight)[end-1]), $(size(l[1].weight)[end]))")
-#     end
+    for (c,d) in zip(u.l_conv_chain,u.l_down_chain)
+        println(io, "$(lvl)Conv($(c))")
+        println(io, "$(lvl)DownSample($(d.factor))")
+        lvl = lvl[1:end-4]
+    end
+    println(io, "FinalConv($(u.final_conv))")
 
-#     println(io, "\n")
-#     for l in u.up_blocks
-#         l isa UNetUpBlock || continue
-#         println(io, "  UNetUpBlock($(size(l.upsample[2].weight)[end]), $(size(l.upsample[2].weight)[end-1]))")
-#     end
-# end
+end

From 91cfb2dbe3a53a0bf30ca1a884c76eccf41668ae Mon Sep 17 00:00:00 2001
From: rheintzmann
Date: Fri, 17 Dec 2021 10:01:14 +0100
Subject: [PATCH 07/21] bug fixes in show(). removed flux_test.jl

---
 examples/deconvolve.jl  |  1 +
 examples/noise2noise.jl |  1 +
 src/model.jl            | 12 ++++++------
 3 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/examples/deconvolve.jl b/examples/deconvolve.jl
index 1ab5ee0..eeed910 100644
--- a/examples/deconvolve.jl
+++ b/examples/deconvolve.jl
@@ -36,4 +36,5 @@ end
 
 # apply the net to the whole image instead:
 nimg = gpu(scale.*reshape(poisson(conv_img),(size(conv_img)...,1,1)))
+# display the images using View5D
 @ve img nimg u(nimg) conv_img

diff --git a/examples/noise2noise.jl b/examples/noise2noise.jl
index 96dda92..1602272 100644
--- a/examples/noise2noise.jl
+++ b/examples/noise2noise.jl
@@ -34,4 +34,5 @@ end
 
 # apply the net to the whole image instead:
 nimg = gpu(scale.*reshape(poisson(conv_img),(size(conv_img)...,1,1)))
+# display the images using View5D
 @ve img nimg u(nimg)

diff --git a/src/model.jl b/src/model.jl
index 4ecce05..acf253b 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -247,17 +247,17 @@ function Base.show(io::IO, u::Unet)
     ws = size(u.l_conv_chain[1].op[1].weight)
     println(io, "UNet, Input Channels: $(ws[end-1]):")
     lvl = ""
-    for (c,d) in zip(u.l_conv_chain,u.l_down_chain)
+    for (c,d) in zip(u.l_conv_chain, u.l_down_chain)
         println(io, "$(lvl)Conv($(c))")
         println(io, "$(lvl)DownSample($(d.factor))")
         lvl *= "|    "
     end
-
-    for (c,d) in zip(u.l_conv_chain,u.l_down_chain)
+    println(io, "$(lvl)Conv($(u.l_conv_chain[end]))")
+    lvl = lvl[1:end-5]
+    for (c,d) in zip(u.r_conv_chain[end:-1:1], u.r_up_chain)
+        println(io, "$(lvl)UpSample($(d.factor))")
         println(io, "$(lvl)Conv($(c))")
-        println(io, "$(lvl)DownSample($(d.factor))")
-        lvl = lvl[1:end-4]
+        lvl = lvl[1:end-5]
     end
     println(io, "FinalConv($(u.final_conv))")
-
 end

From 74ffb5bf699848263e8ba501cf68396756187054 Mon Sep 17 00:00:00 2001
From: rheintzmann
Date: Mon, 20 Dec 2021 22:17:07 +0100
Subject: [PATCH 08/21] bug fixes in noise2noise and deconvolve

---
 examples/deconvolve.jl  | 20 +++++++++++---------
 examples/noise2noise.jl |  2 +-
 2 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/examples/deconvolve.jl b/examples/deconvolve.jl
index eeed910..7b91dec 100644
--- a/examples/deconvolve.jl
+++ b/examples/deconvolve.jl
@@ -2,7 +2,7 @@
 
 using UNet, Flux, TestImages, View5D, Noise, NDTools, FourierTools, IndexFunArrays
 
-img = 10.0 .* Float32.(testimage("resolution_test_512"))
+img = 100f0 .* Float32.(testimage("resolution_test_512"))
 
 u = Unet();
 
@@ -17,24 +17,26 @@ function get_tile(img, tile_size=(128,128), ctr = (rand(tile_size[1]÷2:size(img
     return select_region(img,new_size=tile_size, center=ctr), ctr
 end
 
-sz = size(img); psf = abs2.(ift(disc(sz, 40))); psf ./= sum(psf); conv_img = conv_psf(img,psf);
+R_max = 70;
+sz = size(img); psf = abs2.(ift(disc(Float32, sz, R_max))); psf ./= sum(psf); conv_img = conv_psf(img,psf);
 
-scale = 0.5/maximum(conv_img)
+scale = 0.5f0/maximum(conv_img)
 patch = (128,128)
-for n in 1:1000
+for n in 1:2000
     println("Iteration: $n")
     myimg, pos = get_tile(conv_img,patch)
     # image to denoise
-    nimg1 = gpu(scale.*reshape(poisson(myimg),(size(myimg)...,1,1)))
+    # nimg1 = gpu(reshape(scale .* myimg,(size(myimg)...,1,1))); # gpu(scale.*reshape(poisson(myimg),(size(myimg)...,1,1)))
+    nimg1 = gpu(scale.*reshape(poisson(Float64.(myimg)),(size(myimg)...,1,1)))
     # goal image (with noise)
-    nimg2 = gpu(scale.*reshape(poisson(myimg),(size(myimg)...,1,1)))
     pimg, pos = get_tile(img,patch,pos)
     pimg = gpu(scale.*reshape(pimg,(size(myimg)...,1,1)))
-    rep = Iterators.repeated((nimg1, pimg), 1);
+    rep = Iterators.repeated((nimg1, pimg), 4);
     Flux.train!(loss, Flux.params(u), rep, opt)
 end
 
 # apply the net to the whole image instead:
-nimg = gpu(scale.*reshape(poisson(conv_img),(size(conv_img)...,1,1)))
+nimg = gpu(scale .* reshape(conv_img,(size(conv_img)...,1,1))); # gpu(scale.*reshape(poisson(conv_img),(size(conv_img)...,1,1)))
+nimg2 = gpu(scale.*reshape(poisson(Float64.(conv_img)),(size(conv_img)...,1,1)))
 # display the images using View5D
-@ve img nimg u(nimg) conv_img
+@ve img nimg u(nimg) nimg2 u(nimg2)

diff --git a/examples/noise2noise.jl b/examples/noise2noise.jl
index 1602272..4bdee01 100644
--- a/examples/noise2noise.jl
+++ b/examples/noise2noise.jl
@@ -33,6 +33,6 @@ end
 
 # apply the net to the whole image instead:
-nimg = gpu(scale.*reshape(poisson(conv_img),(size(conv_img)...,1,1)))
+nimg = gpu(scale.*reshape(poisson(img),(size(img)...,1,1)));
 # display the images using View5D
 @ve img nimg u(nimg)

From 82a736b873ade803e0d30edf6bc409b0b10198c6 Mon Sep 17 00:00:00 2001
From: rheintzmann
Date: Wed, 22 Dec 2021 11:00:15 +0100
Subject: [PATCH 09/21] reduced iterations in noise2noise example

---
 examples/noise2noise.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/noise2noise.jl b/examples/noise2noise.jl
index 4bdee01..682baab 100644
--- a/examples/noise2noise.jl
+++ b/examples/noise2noise.jl
@@ -21,7 +21,7 @@ end
 sz = size(img);
 scale = 0.5/maximum(img)
 patch = (128,128)
-for n in 1:1000
+for n in 1:100
     println("Iteration: $n")
     myimg, pos = get_tile(img,patch)
     # image to denoise

From 338797fa7459f30955780850723f7636730056c6 Mon Sep 17 00:00:00 2001
From: Larissa Heinrich
Date: Tue, 4 Jan 2022 18:55:02 -0500
Subject: [PATCH 10/21] clarify UNet pretty print

---
 src/model.jl | 33 ++++++++++++++++++++------------
 1 file changed, 20 insertions(+), 13 deletions(-)

diff --git a/src/model.jl b/src/model.jl
index acf253b..a54f3f5 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -33,6 +33,11 @@ end
 struct Downsample
     op
     factor
+    pooling_type
+end
+
+function Base.show(io::IO, d::Downsample)
+    print(io, "Unet.Downsample($(d.factor), $(d.pooling_type))")
 end
 
 @functor Downsample
@@ -42,8 +47,9 @@ function Downsample(downsample_factor; pooling_type="max")
         downop = x -> NNlib.maxpool(x, downsample_factor, pad=0)
     else
         downop = x -> NNlib.meanpool(x, downsample_factor, pad=0)
+        pooling_type = "mean"
     end
-    return Downsample(downop, downsample_factor)
+    return Downsample(downop, downsample_factor, pooling_type)
 end
 
 struct RuntimeError <: Exception end
@@ -245,19 +251,20 @@ end
 function Base.show(io::IO, u::Unet)
     ws = size(u.l_conv_chain[1].op[1].weight)
-    println(io, "UNet, Input Channels: $(ws[end-1]):")
+    println(io, "UNet, Input Channels: $(ws[end-1])")
     lvl = ""
-    for (c,d) in zip(u.l_conv_chain, u.l_down_chain)
-        println(io, "$(lvl)Conv($(c))")
-        println(io, "$(lvl)DownSample($(d.factor))")
-        lvl *= "|    "
+    for (c, d) in zip(u.l_conv_chain, u.l_down_chain)
+        println(io, "$(lvl)Conv: $c")
+        lvl *= "|    "
+        println(io, "$(lvl)DownSample: $d")
     end
-    println(io, "$(lvl)Conv($(u.l_conv_chain[end]))")
-    lvl = lvl[1:end-5]
-    for (c,d) in zip(u.r_conv_chain[end:-1:1], u.r_up_chain)
-        println(io, "$(lvl)UpSample($(d.factor))")
-        println(io, "$(lvl)Conv($(c))")
-        lvl = lvl[1:end-5]
+    println(io, "$(lvl)Conv: $(u.l_conv_chain[end])")
+    for (c, d) in zip(u.r_conv_chain[end:-1:1], u.r_up_chain[end:-1:1])
+        println(io, "$(lvl)UpSample: $d ")
+        println(io, "$(lvl[1:end-4]) /")
+        lvl = lvl[1:end-5]
+        println(io, "$(lvl)Concat")
+        println(io, "$(lvl)Conv: $(c)")
     end
-    println(io, "FinalConv($(u.final_conv))")
+    println(io, "FinalConv: $(u.final_conv)")
 end

From 3c65ea196ea77568bdcbb5fda575215a059f8705 Mon Sep 17 00:00:00 2001
From: Larissa Heinrich
Date: Tue, 4 Jan 2022 19:01:56 -0500
Subject: [PATCH 11/21] emphasize U-structure in pretty print

---
 src/model.jl | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/model.jl b/src/model.jl
index a54f3f5..b143ff4 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -255,14 +255,17 @@ function Base.show(io::IO, u::Unet)
     lvl = ""
     for (c, d) in zip(u.l_conv_chain, u.l_down_chain)
         println(io, "$(lvl)Conv: $c")
+        println(io, "$(lvl)| \\")
+        println(io, "$(lvl)|  \\DownSample: $d")
+        println(io, "$(lvl)|   \\")
         lvl *= "|    "
-        println(io, "$(lvl)DownSample: $d")
     end
     println(io, "$(lvl)Conv: $(u.l_conv_chain[end])")
     for (c, d) in zip(u.r_conv_chain[end:-1:1], u.r_up_chain[end:-1:1])
-        println(io, "$(lvl)UpSample: $d ")
-        println(io, "$(lvl[1:end-4]) /")
         lvl = lvl[1:end-5]
+        println(io, "$(lvl)|   /")
+        println(io, "$(lvl)|  /UpSample: $d ")
+        println(io, "$(lvl)| /")
         println(io, "$(lvl)Concat")
         println(io, "$(lvl)Conv: $(c)")
     end

From f15f8c8140bad16da4b3cd771e29cba532af9510 Mon Sep 17 00:00:00 2001
From: Larissa Heinrich
Date: Tue, 15 Mar 2022 15:26:48 -0400
Subject: [PATCH 12/21] feat: add implementation for valid padding

---
 src/model.jl | 26 ++++++++++++++------------
 1 file changed, 14 insertions(+), 12 deletions(-)

diff --git a/src/model.jl b/src/model.jl
index b143ff4..166d3c6 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -1,5 +1,6 @@
 using Flux
 using Flux: @functor
+
 struct ConvBlock
     op
 end
@@ -21,7 +22,6 @@ function ConvBlock(in_channels, out_channels, kernel_sizes = [(3,3), (3,3)];
 end
 
 function (m::ConvBlock)(x)
-    # println(size(x))
     return m.op(x)
 end
 
@@ -52,12 +52,10 @@ function Downsample(downsample_factor; pooling_type="max")
     return Downsample(downop, downsample_factor, pooling_type)
 end
 
-struct RuntimeError <: Exception end
-
 function (m::Downsample)(x)
     for (d, x_s, f_s) in zip(1: length(m.factor), size(x), m.factor)
         if (mod(x_s, f_s) !=0)
-            throw(RuntimeError("Can not downsample $(size(x)) with factor $(m.factor), mismatch in spatial dimension $d"))
+            throw(DimensionMismatch("Can not downsample $(size(x)) with factor $(m.factor), mismatch in spatial dimension $d"))
         end
     end
     return m.op(x)
@@ -79,18 +77,22 @@ function (m::Upsample)(x)
     return m.op(x)
 end
 
-function crop(x, size)
-    target_size = size(x)
-    offset = Tuple((a-b)÷2 for (a,b) in zip(size(x), target_size))
+function crop(x, target_size)
+    if (size(x) == target_size)
+        return x
+    else
+        offset = Tuple((a-b)÷2+1 for (a,b) in zip(size(x), target_size))
+        slice = Tuple(o:o+t-1 for (o,t) in zip(offset,target_size))
+        return x[slice...,:,:]
+    end
 end
 
 function(m::Upsample)(x, y)
     #todo: crop_to_factor
-    g_cropped = y
-    # f_cropped = crop(m(x), size(g_cropped)[:length(m.factor)])
-    f_cropped = m(x)
-    new_arr = cat(f_cropped, g_cropped; dims=length(m.factor)+1)
+    g_up = m(x)
+    k = size(g_up)[1:length(m.factor)]
+    f_cropped = crop(y, k)
+    new_arr = cat(f_cropped, g_up; dims=length(m.factor)+1)
     return new_arr
 end

From d5d4c8ee2ebbbddc56549fdb66f6145f18ee5b15 Mon Sep 17 00:00:00 2001
From: Larissa Heinrich
Date: Wed, 13 Apr 2022 15:21:48 -0400
Subject: [PATCH 13/21] keep examples dependencies separate

---
 Project.toml          | 8 --------
 examples/Project.toml | 9 +++++++++
 2 files changed, 9 insertions(+), 8 deletions(-)
 create mode 100644 examples/Project.toml

diff --git a/Project.toml b/Project.toml
index 5cb2b0e..a8fc41f 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,19 +6,11 @@ version = "0.2.1"
 [deps]
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
-Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-FourierTools = "b18b359b-aebc-45ac-a139-9c0ccbb2871e"
 ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
-ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19"
 ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
-IndexFunArrays = "613c443e-d742-454e-bfc6-1d7f8dd76566"
-NDTools = "98581153-e998-4eef-8d0d-5ec2c052313d"
-Noise = "81d43f40-5267-43b7-ae1c-8b967f377efa"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 
 [compat]
 Distributions = "0.20, 0.21, 0.22, 0.23, 0.25"

diff --git a/examples/Project.toml b/examples/Project.toml
new file mode 100644
index 0000000..7ff5f85
--- /dev/null
+++ b/examples/Project.toml
@@ -0,0 +1,9 @@
+[deps]
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+FourierTools = "b18b359b-aebc-45ac-a139-9c0ccbb2871e"
+ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19"
+IndexFunArrays = "613c443e-d742-454e-bfc6-1d7f8dd76566"
+Noise = "81d43f40-5267-43b7-ae1c-8b967f377efa"
+NDTools = "98581153-e998-4eef-8d0d-5ec2c052313d"
+TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990"
+View5D = "90d841e0-6953-4e90-9f3a-43681da8e949"

From 467d5bf262afbaf2f60ad827e8426f221c0d4783 Mon Sep 17 00:00:00 2001
From: Larissa Heinrich
Date: Wed, 13 Apr 2022 15:26:51 -0400
Subject: [PATCH 14/21] add Flux back

---
 Project.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/Project.toml b/Project.toml
index a8fc41f..1804d27 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,6 +6,7 @@ version = "0.2.1"
 [deps]
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
 ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"

From 6179fcf8ac9b1e56a1ed4f2243d4472d5e394647 Mon Sep 17 00:00:00 2001
From: Larissa Heinrich
Date: Thu, 21 Apr 2022 09:35:46 -0400
Subject: [PATCH 15/21] add parametrization

---
 examples/Project.toml |  2 +-
 src/model.jl          | 32 ++++++++++++++++-----------------
 2 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/examples/Project.toml b/examples/Project.toml
index 7ff5f85..669f1b5 100644
--- a/examples/Project.toml
+++ b/examples/Project.toml
@@ -3,7 +3,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 FourierTools = "b18b359b-aebc-45ac-a139-9c0ccbb2871e"
 ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19"
 IndexFunArrays = "613c443e-d742-454e-bfc6-1d7f8dd76566"
-Noise = "81d43f40-5267-43b7-ae1c-8b967f377efa"
 NDTools = "98581153-e998-4eef-8d0d-5ec2c052313d"
+Noise = "81d43f40-5267-43b7-ae1c-8b967f377efa"
 TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990"
 View5D = "90d841e0-6953-4e90-9f3a-43681da8e949"

diff --git a/src/model.jl b/src/model.jl
index 166d3c6..ff0bdea 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -1,8 +1,8 @@
 using Flux
 using Flux: @functor
 
-struct ConvBlock
-    op
+struct ConvBlock{T}
+    op::T
 end
 
 @functor ConvBlock
@@ -30,10 +30,10 @@ end
     #         BatchNormWrap(out_chs),
    #         x->leakyrelu.(x,0.2f0))
 
-struct Downsample
-    op
-    factor
-    pooling_type
+struct Downsample{T1, T2, T3}
+    op::T1
+    factor::T2
+    pooling_type::T3
 end
 
 function Base.show(io::IO, d::Downsample)
@@ -61,9 +61,9 @@ function (m::Downsample)(x)
     return m.op(x)
 end
 
-struct Upsample
-    op
-    factor
+struct Upsample{T1, T2}
+    op::T1
+    factor::T2
 end
 
 @functor Upsample
@@ -97,13 +97,13 @@ end
 
 # holds the information on the unet structure
-struct Unet
-    num_levels
-    l_conv_chain
-    l_down_chain
-    r_up_chain
-    r_conv_chain
-    final_conv
+struct Unet{T1, T2, T3}
+    num_levels::T1
+    l_conv_chain::T2
+    l_down_chain::T2
+    r_up_chain::T2
+    r_conv_chain::T2
+    final_conv::T3
 end
 
 @functor Unet

From fc194a5199c67dbe3c50d5c01e9af18c7c394b2c Mon Sep 17 00:00:00 2001
From: Larissa Heinrich <7736327+neptunes5thmoon@users.noreply.github.com>
Date: Thu, 21 Apr 2022 09:40:23 -0400
Subject: [PATCH 16/21] declutter print

Co-authored-by: Dhairya Gandhi
---
 src/model.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/model.jl b/src/model.jl
index ff0bdea..adf456d 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -37,7 +37,7 @@ struct Downsample{T1, T2, T3}
 end
 
 function Base.show(io::IO, d::Downsample)
-    print(io, "Unet.Downsample($(d.factor), $(d.pooling_type))")
+    print(io, "Downsample($(d.factor), $(d.pooling_type))")
 end
 
 @functor Downsample

From 1bc3d0449f5f8cbc3296ec9fc26a1882da5ddff6 Mon Sep 17 00:00:00 2001
From: Larissa Heinrich
Date: Thu, 21 Apr 2022 16:53:16 -0400
Subject: [PATCH 17/21] only conv chains should share type

---
 src/model.jl | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/model.jl b/src/model.jl
index adf456d..69bd00f 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -97,13 +97,13 @@ end
 
 # holds the information on the unet structure
-struct Unet{T1, T2, T3}
+struct Unet{T1, T2, T3, T4, T5}
     num_levels::T1
     l_conv_chain::T2
-    l_down_chain::T2
-    r_up_chain::T2
+    l_down_chain::T3
+    r_up_chain::T4
     r_conv_chain::T2
-    final_conv::T3
+    final_conv::T5
 end
 
 @functor Unet

From e547dfcc6d352cbbcc02d244eaaded4506d7076e Mon Sep 17 00:00:00 2001
From: Larissa Heinrich
Date: Thu, 21 Apr 2022 17:10:13 -0400
Subject: [PATCH 18/21] separate conv chain types

Even for conv chains the type can't be shared since different numbers of
convolutions are permissible
---
 src/model.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/model.jl b/src/model.jl
index 69bd00f..47faa6e 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -97,13 +97,13 @@ end
 
 # holds the information on the unet structure
-struct Unet{T1, T2, T3, T4, T5}
+struct Unet{T1, T2, T3, T4, T5, T6}
     num_levels::T1
     l_conv_chain::T2
     l_down_chain::T3
     r_up_chain::T4
-    r_conv_chain::T2
-    final_conv::T5
+    r_conv_chain::T5
+    final_conv::T6
 end
 
 @functor Unet

From c74ec869a71479da4fdf44d0ed45da1b05bc9741 Mon Sep 17 00:00:00 2001
From: RainerHeintzmann
Date: Thu, 13 Jun 2024 16:27:05 +0200
Subject: [PATCH 19/21] updated version numbers of packages. Adapted
 noise2noise example

---
 Project.toml            |  6 +++---
 examples/Project.toml   |  3 +++
 examples/noise2noise.jl | 24 +++++++++++++-----------
 3 files changed, 19 insertions(+), 14 deletions(-)

diff --git a/Project.toml b/Project.toml
index 926981f..b196f37 100644
--- a/Project.toml
+++ b/Project.toml
@@ -17,10 +17,10 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 Distributions = "0.20, 0.21, 0.22, 0.23, 0.25"
 FileIO = "1"
 Flux = "0.10, 0.11, 0.12, 0.14"
-ImageCore = "0.8, 0.9"
-ImageTransformations = "0.8, 0.9"
+ImageCore = "0.8, 0.9, 0.10"
+ImageTransformations = "0.8, 0.9, 0.10"
 Reexport = "0.2, 1"
-StatsBase = "0.30, 0.33"
+StatsBase = "0.30, 0.33, 0.34"
 julia = "1.6"
 
 [extras]

diff --git a/examples/Project.toml b/examples/Project.toml
index 669f1b5..fef5f63 100644
--- a/examples/Project.toml
+++ b/examples/Project.toml
@@ -1,4 +1,5 @@
 [deps]
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 FourierTools = "b18b359b-aebc-45ac-a139-9c0ccbb2871e"
 ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19"
@@ -6,4 +7,6 @@ IndexFunArrays = "613c443e-d742-454e-bfc6-1d7f8dd76566"
 NDTools = "98581153-e998-4eef-8d0d-5ec2c052313d"
 Noise = "81d43f40-5267-43b7-ae1c-8b967f377efa"
 TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990"
+UNet = "0d73aaa9-994a-4556-95d0-da67cb772a03"
 View5D = "90d841e0-6953-4e90-9f3a-43681da8e949"
+cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

diff --git a/examples/noise2noise.jl b/examples/noise2noise.jl
index 682baab..b0a524a 100644
--- a/examples/noise2noise.jl
+++ b/examples/noise2noise.jl
@@ -1,17 +1,18 @@
 # Example using U-net for a noise2noise problem
 
-using UNet, Flux, TestImages, View5D, Noise, NDTools
+using UNet, Flux, TestImages, View5D, Noise, NDTools, CUDA
 
 img = 10.0 .* Float32.(testimage("resolution_test_512"))
 
 u = Unet();
 
 u = gpu(u);
-function loss(x, y)
+function loss(u, x, y)
     # return mean(abs2.(u(x) .-y))
-    return Flux.mse(u(x),y)
+    return Flux.mse(u(x), y)
 end
-opt = Momentum()
+# opt = Momentum()
+opt_state = Flux.setup(Momentum(), u);
 
 # selects a tile at a random (default) or predefined (ctr) position returning tile and center.
 function get_tile(img, tile_size=(128,128), ctr = (rand(tile_size[1]÷2:size(img,1)-tile_size[1]÷2),rand(tile_size[2]÷2:size(img,2)-tile_size[2]÷2)) )
@@ -21,16 +22,17 @@ end
 sz = size(img);
 scale = 0.5/maximum(img)
 patch = (128,128)
-for n in 1:100
+for n in 1:1000
     println("Iteration: $n")
-    myimg, pos = get_tile(img,patch)
+    myimg, pos = get_tile(img, patch)
     # image to denoise
-    nimg1 = gpu(scale.*reshape(poisson(myimg),(size(myimg)...,1,1)))
+    nimg1 = gpu(scale.*reshape(poisson(myimg), (size(myimg)...,1,1)))
     # goal image (with noise)
-    nimg2 = gpu(scale.*reshape(poisson(myimg),(size(myimg)...,1,1)))
-    rep = Iterators.repeated((nimg1, nimg2), 1);
-    Flux.train!(loss, Flux.params(u), rep, opt)
+    nimg2 = gpu(scale.*reshape(poisson(myimg), (size(myimg)...,1,1)))
+    rep = [(nimg1, nimg2),] # Iterators.repeated((nimg1, nimg2), 1);
+    # Flux.train!(loss, Flux.params(u), rep, opt_state)
+    Flux.train!(loss, u, rep, opt_state)
 end
 
 # apply the net to the whole image instead:
 nimg = gpu(scale.*reshape(poisson(img),(size(img)...,1,1)));
 # display the images using View5D
-@ve img nimg u(nimg)
+@vt img nimg u(nimg)

From 4fdd9baea43b31647cd83110afa0dd83510b776e Mon Sep 17 00:00:00 2001
From: RainerHeintzmann
Date: Thu, 13 Jun 2024 16:47:40 +0200
Subject: [PATCH 20/21] updated deconvolve.jl

---
 examples/deconvolve.jl  | 19 ++++++++++---------
 examples/noise2noise.jl |  7 +++----
 2 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/examples/deconvolve.jl b/examples/deconvolve.jl
index 7b91dec..eea7405 100644
--- a/examples/deconvolve.jl
+++ b/examples/deconvolve.jl
@@ -7,10 +7,11 @@ img = 100f0 .* Float32.(testimage("resolution_test_512"))
 u = Unet();
 
 u = gpu(u);
-function loss(x, y)
+function loss(u, x, y)
     return Flux.mse(u(x),y)
 end
-opt = Momentum()
+
+opt_state = Flux.setup(Momentum(), u);
 
 # selects a tile at a random (default) or predefined (ctr) position returning tile and center.
 function get_tile(img, tile_size=(128,128), ctr = (rand(tile_size[1]÷2:size(img,1)-tile_size[1]÷2),rand(tile_size[2]÷2:size(img,2)-tile_size[2]÷2)) )
@@ -21,22 +22,22 @@ R_max = 70;
 sz = size(img); psf = abs2.(ift(disc(Float32, sz, R_max))); psf ./= sum(psf); conv_img = conv_psf(img,psf);
 
 scale = 0.5f0/maximum(conv_img)
-patch = (128,128)
+patch = (128, 128)
 for n in 1:2000
     println("Iteration: $n")
-    myimg, pos = get_tile(conv_img,patch)
+    myimg, pos = get_tile(conv_img, patch)
     # image to denoise
     # nimg1 = gpu(reshape(scale .* myimg,(size(myimg)...,1,1))); # gpu(scale.*reshape(poisson(myimg),(size(myimg)...,1,1)))
-    nimg1 = gpu(scale.*reshape(poisson(Float64.(myimg)),(size(myimg)...,1,1)))
+    nimg1 = gpu(Float32.(scale.*reshape(poisson(Float64.(myimg)), (size(myimg)...,1,1))))
     # goal image (with noise)
-    pimg, pos = get_tile(img,patch,pos)
+    pimg, pos = get_tile(img, patch, pos)
     pimg = gpu(scale.*reshape(pimg,(size(myimg)...,1,1)))
-    rep = Iterators.repeated((nimg1, pimg), 4);
-    Flux.train!(loss, Flux.params(u), rep, opt)
+    rep = [(nimg1, pimg)] # Iterators.repeated((nimg1, pimg), 1);
+    Flux.train!(loss, u, rep, opt_state)
 end
 
 # apply the net to the whole image instead:
 nimg = gpu(scale .* reshape(conv_img,(size(conv_img)...,1,1))); # gpu(scale.*reshape(poisson(conv_img),(size(conv_img)...,1,1)))
 nimg2 = gpu(scale.*reshape(poisson(Float64.(conv_img)),(size(conv_img)...,1,1)))
 # display the images using View5D
-@ve img nimg u(nimg) nimg2 u(nimg2)
+@vt img nimg u(nimg) nimg2 u(nimg2)

diff --git a/examples/noise2noise.jl b/examples/noise2noise.jl
index b0a524a..a717f50 100644
--- a/examples/noise2noise.jl
+++ b/examples/noise2noise.jl
@@ -2,7 +2,7 @@
 
 using UNet, Flux, TestImages, View5D, Noise, NDTools, CUDA
 
-img = 10.0 .* Float32.(testimage("resolution_test_512"))
+img = 10f0 .* Float32.(testimage("resolution_test_512"))
 
 u = Unet();
 
@@ -11,7 +11,6 @@ function loss(u, x, y)
     # return mean(abs2.(u(x) .-y))
     return Flux.mse(u(x), y)
 end
-# opt = Momentum()
 opt_state = Flux.setup(Momentum(), u);
 
 # selects a tile at a random (default) or predefined (ctr) position returning tile and center.
 function get_tile(img, tile_size=(128,128), ctr = (rand(tile_size[1]÷2:size(img,1)-tile_size[1]÷2),rand(tile_size[2]÷2:size(img,2)-tile_size[2]÷2)) )
@@ -20,8 +19,8 @@ end
 
 sz = size(img);
-scale = 0.5/maximum(img)
-patch = (128,128)
+scale = 0.5f0/maximum(img)
+patch = (128, 128)
 for n in 1:1000
     println("Iteration: $n")
     myimg, pos = get_tile(img, patch)

From 597cb4fa98e4c3a80cf0305e18b421e824e4c3e5 Mon Sep 17 00:00:00 2001
From: RainerHeintzmann
Date: Thu, 13 Jun 2024 18:00:53 +0200
Subject: [PATCH 21/21] bug fixes and changes according to comments

---
 examples/noise2noise.jl |  6 ++--
 src/model.jl            | 74 ++++++++++++++++++++++-------------------
 2 files changed, 42 insertions(+), 38 deletions(-)

diff --git a/examples/noise2noise.jl b/examples/noise2noise.jl
index a717f50..bba0c6c 100644
--- a/examples/noise2noise.jl
+++ b/examples/noise2noise.jl
@@ -25,15 +25,15 @@ for n in 1:1000
     println("Iteration: $n")
     myimg, pos = get_tile(img, patch)
     # image to denoise
-    nimg1 = gpu(scale.*reshape(poisson(myimg), (size(myimg)...,1,1)))
+    nimg1 = gpu(scale.*reshape(Float32.(poisson(Float64.(myimg))), (size(myimg)...,1,1)))
     # goal image (with noise)
-    nimg2 = gpu(scale.*reshape(poisson(myimg), (size(myimg)...,1,1)))
+    nimg2 = gpu(scale.*reshape(Float32.(poisson(Float64.(myimg))), (size(myimg)...,1,1)))
     rep = [(nimg1, nimg2),] # Iterators.repeated((nimg1, nimg2), 1);
     # Flux.train!(loss, Flux.params(u), rep, opt_state)
     Flux.train!(loss, u, rep, opt_state)
 end
 
 # apply the net to the whole image instead:
-nimg = gpu(scale.*reshape(poisson(img),(size(img)...,1,1)));
+nimg = gpu(scale.*reshape(Float32.(poisson(Float64.(img))),(size(img)...,1,1)));
 # display the images using View5D
 @vt img nimg u(nimg)

diff --git a/src/model.jl b/src/model.jl
index 47faa6e..32c06ec 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -7,8 +7,22 @@ end
 
 @functor ConvBlock
 
+function ConvBlock(in_out_pair::Pair; kwargs...)
+    return ConvBlock([(3, 3), (3, 3)], in_out_pair; kwargs...)
+end
+
+"""
+    ConvBlock([kernel_sizes,] in_channels => out_channels; activation = NNlib.relu, padding = "valid")
+
+creates a convolution block for use within the UNet. The optional kernel sizes are given as a vector of tuples,
+where each tuple represents the kernel size of a convolution layer. Default is [(3, 3), (3, 3)].
+The input and output channel counts of the block are given as a Pair of integers (in_channels => out_channels).
+The activation function is applied after each convolution layer.
+"""
-function ConvBlock(in_channels, out_channels, kernel_sizes = [(3,3), (3,3)];
-                   activation = NNlib.relu, padding = "valid")
+function ConvBlock(kernel_sizes, in_out_pair::Pair; activation = NNlib.relu, padding = "valid")
+    in_channels, out_channels = in_out_pair
     pad_arg = padding == "same" ? SamePad() : 0
     conv_layers = Any[]
     in_channels_it = in_channels
@@ -33,23 +47,18 @@ end
 struct Downsample{T1, T2, T3}
     op::T1
     factor::T2
-    pooling_type::T3
+    pooling_func::T3
 end
 
 function Base.show(io::IO, d::Downsample)
-    print(io, "Downsample($(d.factor), $(d.pooling_type))")
+    print(io, "Downsample($(d.factor), $(d.pooling_func))")
 end
 
 @functor Downsample
 
-function Downsample(downsample_factor; pooling_type="max")
-    if (pooling_type == "max")
-        downop = x -> NNlib.maxpool(x, downsample_factor, pad=0)
-    else
-        downop = x -> NNlib.meanpool(x, downsample_factor, pad=0)
-        pooling_type = "mean"
-    end
-    return Downsample(downop, downsample_factor, pooling_type)
+function Downsample(downsample_factor; pooling_func = NNlib.maxpool)
+    downop = x -> pooling_func(x, downsample_factor, pad=0)
+    return Downsample(downop, downsample_factor, pooling_func)
 end
 
@@ -68,8 +77,8 @@ end
 
 @functor Upsample
 
-function Upsample(scale_factor, in_channels, out_channels)
-    upop = ConvTranspose(scale_factor, in_channels=>out_channels, stride=scale_factor)
+function Upsample(scale_factor, in_out_pair::Pair)
+    upop = ConvTranspose(scale_factor, in_out_pair, stride=scale_factor)
     return Upsample(upop, scale_factor)
 end
 
@@ -110,8 +119,7 @@ end
 """
 function Unet(;
-    in_channels = 1,
-    out_channels = 1,
+    in_out_channels_pair = (1 => 1),
     num_fmaps = 64,
     fmap_inc_factor = 2,
     downsample_factors = [(2,2),(2,2),(2,2),(2,2)],
@@ -129,9 +137,7 @@ function Unet(;
     padding="same",
-    pooling_type="max"
+    pooling_func = NNlib.maxpool
     )
     creates a U-net model that can then be trained and used for predictions. A UNet consists of an initial layer to
     create feature maps, controlled via `num_fmaps`. This is followed by downsampling and upsampling steps,
     which obtain information from the downsampling side of the net via skip-connections, which are automatically inserted.
     The down- and upsampling steps contain on each level a number of consecutive convolutions controlled via the arguments `kernel_sizes_down`
     and `kernel_sizes_up` respectively.
 
 # Parameters
-+ `in_channels`: channels of the input to the U-net
-
-+ `out_channels`: channels of the output of the U-net
++ `in_out_channels_pair`: channels of the input to the U-net and channels of the output of the U-net as a Pair of integers
 
 + `num_fmaps`: number of feature maps that the input gets expanded to in the first step
@@ -151,15 +157,14 @@ function Unet(;
 + `padding="same"`: method of padding during convolution and upsampling
 
-+ `pooling_type="max"`: type of pooling
++ `pooling_func = NNlib.maxpool`: pooling function used for downsampling
 
 # Example
 ```jldoctest
 ```
 """
 function Unet(; # all arguments are named and have defaults
-    in_channels = 1,
-    out_channels = 1,
+    in_out_channels_pair = (1 => 1),
     num_fmaps = 64,
     fmap_inc_factor = 2,
     downsample_factors = [(2,2),(2,2),(2,2),(2,2)],
     kernel_sizes_down = [[(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)]],
     kernel_sizes_up = [[(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)], [(3,3), (3,3)]],
     activation = NNlib.relu,
     final_activation = NNlib.relu,
     padding ="same",
-    pooling_type ="max"
+    pooling_func = NNlib.maxpool
     )
+    in_channels, out_channels = in_out_channels_pair
     num_levels = length(downsample_factors) + 1
     dims = length(downsample_factors[1])
     l_convs = Any[]
     for level in 1:num_levels
         in_ch = (level == 1) ? in_channels : num_fmaps * fmap_inc_factor ^ (level - 2)
 
-        cb = ConvBlock(in_ch,
-            num_fmaps * fmap_inc_factor ^ (level - 1),
-            kernel_sizes_down[level],
+        cb = ConvBlock(
+            kernel_sizes_down[level],
+            in_ch =>
+            num_fmaps * fmap_inc_factor ^ (level - 1),
             activation=activation,
             padding=padding
         )
@@ -190,7 +197,7 @@ function Unet(; # all arguments are named and have defaults
         push!(l_downs,
             Downsample(
                 downsample_factors[level];
-                pooling_type=pooling_type
+                pooling_func = pooling_func
             )
         )
     end
@@ -200,8 +207,7 @@ function Unet(; # all arguments are named and have defaults
         push!(r_ups,
             Upsample(
                 downsample_factors[level],
-                num_fmaps * fmap_inc_factor ^ level,
-                num_fmaps * fmap_inc_factor ^ level
+                num_fmaps * fmap_inc_factor ^ level => num_fmaps * fmap_inc_factor ^ level
             )
         )
     end
@@ -210,10 +216,9 @@ function Unet(; # all arguments are named and have defaults
         push!(r_convs,
             ConvBlock(
-                num_fmaps * fmap_inc_factor ^ (level - 1) +
-                num_fmaps * fmap_inc_factor ^ level,
-                num_fmaps * fmap_inc_factor ^ (level - 1),
                 kernel_sizes_up[level],
+                num_fmaps * fmap_inc_factor ^ (level - 1) +
+                num_fmaps * fmap_inc_factor ^ level => num_fmaps * fmap_inc_factor ^ (level - 1),
                 activation=activation,
                 padding=padding
             )
@@ -221,9 +226,8 @@ function Unet(; # all arguments are named and have defaults
     end
 
     final_conv = ConvBlock(
-        num_fmaps,
-        out_channels,
-        [Tuple(1 for i in 1:dims)],
+        [ntuple((i) -> 1, dims)],
+        num_fmaps => out_channels,
         activation=final_activation,
         padding=padding
     )
@@ -247,7 +251,7 @@ function (m::Unet)(x::AbstractArray; level=1)
             m.r_conv_chain[level](fs_right)
         end
     end
-    
+
     if (level == 1)
         return m.final_conv(fs_out)
     else