
Make Unet configurable #26

Open · wants to merge 23 commits into base: master

Conversation

@neptunes5thmoon

Hi @DhairyaLGandhi

I often need flexibility in the UNet's hyperparameters, e.g. different numbers of downsampling steps and convolutions, kernel sizes, etc., so we started implementing this more configurable version.

This currently allows for a UNet with the following configurable options (a usage sketch follows the list):

  • number of input and output channels
  • initial number of feature maps
  • factor by which the number of feature maps is multiplied at each downsampling
  • downsampling method, factors and number of downsampling steps
  • kernel sizes and number of convolutions
  • activation functions
  • "same" or "valid" padding

Let me know if this is something you would be interested in including here.

@DhairyaLGandhi (Owner)

This sounds like a good idea to support! I will take a look at the changes in the PR as well!

src/model.jl Outdated
Chain(Conv(kernel, in_chs=>out_chs,pad = (1, 1);init=_random_normal),
BatchNormWrap(out_chs),
x->leakyrelu.(x,0.2f0))
struct ConvBlock
@DhairyaLGandhi (Owner)

This should be parameterised

Suggested change
struct ConvBlock
struct ConvBlock{T}

@neptunes5thmoon (Author)

done
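
For reference, a minimal sketch of the parameterised version, assuming `ConvBlock` wraps a single callable field (names are illustrative, not the PR's exact code):

```julia
using Flux

# Parametric on the wrapped chain's concrete type, so the field is not
# stored as `Any` and Julia can specialize the forward pass.
struct ConvBlock{T}
    chain::T
end

Flux.@functor ConvBlock          # expose trainable parameters to Flux

(c::ConvBlock)(x) = c.chain(x)   # forward pass delegates to the wrapped chain
```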

src/model.jl Outdated

struct UNetUpBlock
upsample
function ConvBlock(in_channels, out_channels, kernel_sizes = [(3,3), (3,3)];
@DhairyaLGandhi (Owner)

We can mirror the Flux Conv API here. That would mean switching the positions of the kernels and making the channels a Pair. This does lose the default kernel pairs, though. Are we expecting not to need to define the kernels much? I suppose we wouldn't, but maintaining consistency with that API would be a plus. What do you think?

@neptunes5thmoon (Author)

Yes, good idea. Julia lets us make the leading argument optional by adding a second method, since the number of positional arguments differs. The kernel sizes are now an optional first argument, followed by the channels as the mentioned Pair.
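
A rough sketch of that dispatch pattern, building on the parametric `ConvBlock{T}` above (signatures and defaults are illustrative, not the PR's exact code):

```julia
using Flux

# A second method supplies the default kernel sizes, so the leading
# positional argument is effectively optional while the channels stay an
# `in => out` Pair, mirroring Flux's Conv API.
ConvBlock(chs::Pair; kwargs...) = ConvBlock([(3, 3), (3, 3)], chs; kwargs...)

function ConvBlock(kernel_sizes, chs::Pair; activation = relu)
    in_chs, out_chs = chs
    layers = [Conv(k, (i == 1 ? in_chs : out_chs) => out_chs, activation; pad = SamePad())
              for (i, k) in enumerate(kernel_sizes)]
    return ConvBlock(Chain(layers...))
end

ConvBlock(1 => 16)             # default 3×3 kernel pair
ConvBlock([(5, 5)], 1 => 16)   # explicit single 5×5 convolution
```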

src/model.jl Outdated (comment resolved)
src/model.jl Outdated
stride=(2, 2);init=_random_normal),
BatchNormWrap(out_chs),
Dropout(p)))
struct Downsample
@DhairyaLGandhi (Owner)

Same comment about parameterization as earlier

@neptunes5thmoon (Author)

done

src/model.jl Outdated
conv_down_blocks
conv_blocks
up_blocks
function Downsample(downsample_factor; pooling_type="max")
@DhairyaLGandhi (Owner)

If we took the function maxpool or meanpool directly, we can get rid of the conditional.

@neptunes5thmoon (Author)

Good idea. Changed accordingly.
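
A minimal sketch of the change; the PR keeps this inside the `Downsample` struct, while a plain closure is used here to show the idea (`maxpool`/`meanpool` are NNlib's):

```julia
using NNlib: maxpool, meanpool

# Accepting the pooling function directly removes the string comparison;
# any f(x, window) with pooling semantics drops in.
Downsample(downsample_factor; pooling = maxpool) =
    x -> pooling(x, downsample_factor)

down = Downsample((2, 2))                 # max pooling by default
down(rand(Float32, 8, 8, 1, 1))           # 8×8 -> 4×4
Downsample((2, 2); pooling = meanpool)    # user-supplied alternative
```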

```jldoctest
```
"""
function Unet(; # all arguments are named and have defaults
@DhairyaLGandhi (Owner)

What do you think about retaining some of the positional argument versions as well?

@neptunes5thmoon (Author)

Do you mean as above, to mirror the Flux API? Or are there specific arguments you think should be positional?

@DhairyaLGandhi (Owner)

I feel the channels can be made a Pair in => out

@neptunes5thmoon (Author)

done
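
Sketch of the resulting signature (body elided; the default channels and keyword names are illustrative):

```julia
# Channels become a positional Pair with a default, mirroring Flux's Conv;
# the remaining hyperparameters stay keywords.
function Unet(channels::Pair = 1 => 1; initial_features = 64, kwargs...)
    in_chs, out_chs = channels
    # ... assemble the conv/down/up blocks as before ...
end

Unet(3 => 2; initial_features = 32)   # channels as `in => out`
```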

neptunes5thmoon and others added 4 commits April 21, 2022 09:35
Co-authored-by: Dhairya Gandhi <[email protected]>
Even for conv chains the type can't be shared since different numbers of convolutions are permissible
@mkitti

mkitti commented Feb 7, 2023

Following. Where did we end up here?

@neptunes5thmoon (Author)

Hi @mkitti
I was waiting for some feedback to finish my updates, then got sidetracked by other projects. Feel free to add any edits/ideas/suggestions you might have!

@DhairyaLGandhi (Owner) left a comment

Can confirm, got sidetracked 😄

I've added a couple comments, but it looks mostly good! I would love to add this in!

src/model.jl Outdated
activation = NNlib.relu,
final_activation = NNlib.relu,
padding = "same",
pooling_type = "max"
@DhairyaLGandhi (Owner)

I would prefer to take in the function directly if possible; that way we simply call the input function, and users can specify their own pooling if preferable.
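
That would make the call site look something like this (the `pooling` keyword name is an assumption):

```julia
using NNlib: meanpool

# Pass the pooling function itself instead of a string tag; the model then
# simply calls it, so users can supply their own.
u = Unet(3 => 2; pooling = meanpool)
```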

@neptunes5thmoon (Author)

I think we addressed all your suggestions. Could you have another look, @DhairyaLGandhi?
