From bf11707168833374422847adf1ed5563493dfef2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Sep 2024 15:40:20 -0400 Subject: [PATCH] fix: DDIM updates and fix argument ordering --- examples/DDIM/Project.toml | 4 --- examples/DDIM/main.jl | 68 +++++++++++++++++++------------------- src/helpers/stateful.jl | 10 +++--- 3 files changed, 40 insertions(+), 42 deletions(-) diff --git a/examples/DDIM/Project.toml b/examples/DDIM/Project.toml index 42a76263b..461bf2222 100644 --- a/examples/DDIM/Project.toml +++ b/examples/DDIM/Project.toml @@ -1,8 +1,6 @@ [deps] -AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e" @@ -25,10 +23,8 @@ TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -AMDGPU = "0.9.6, 1" ArgCheck = "2.3.0" CairoMakie = "0.12" -ChainRulesCore = "1.23" Comonicon = "1" ConcreteStructs = "0.2.3" DataAugmentation = "0.3" diff --git a/examples/DDIM/main.jl b/examples/DDIM/main.jl index 6e81b88f8..1a0039541 100644 --- a/examples/DDIM/main.jl +++ b/examples/DDIM/main.jl @@ -6,11 +6,10 @@ # ## Package Imports -using ArgCheck, CairoMakie, ChainRulesCore, ConcreteStructs, Comonicon, DataAugmentation, - DataDeps, FileIO, ImageCore, JLD2, Lux, LuxCUDA, MLUtils, Optimisers, - ParameterSchedulers, ProgressBars, Random, Setfield, StableRNGs, Statistics, Zygote +using ArgCheck, CairoMakie, ConcreteStructs, Comonicon, DataAugmentation, DataDeps, FileIO, + ImageCore, JLD2, Lux, LuxCUDA, MLUtils, Optimisers, ParameterSchedulers, ProgressBars, + Random, Setfield, StableRNGs, Statistics, Zygote using TensorBoardLogger: TBLogger, log_value, log_images -const CRC = ChainRulesCore CUDA.allowscalar(false) @@ -130,24 +129,22 @@ function ddim(rng::AbstractRNG, args...; min_signal_rate=0.02f0, max_signal_rate, dispatch=:DDIM) do x::AbstractArray{<:Real, 4} images = bn(x) rng = Lux.replicate(rng) - T = eltype(x) - noises = CRC.@ignore_derivatives randn!(rng, similar(images, T, size(images)...)) - diffusion_times = CRC.@ignore_derivatives rand!( - rng, similar(images, T, 1, 1, 1, size(images, 4))) + noises = rand_like(rng, images) + diffusion_times = rand_like(rng, images, (1, 1, 1, size(images, 4))) - noise_rates, signal_rates = __diffusion_schedules( + noise_rates, signal_rates = diffusion_schedules( diffusion_times, min_signal_rate, max_signal_rate) noisy_images = @. signal_rates * images + noise_rates * noises - pred_noises, pred_images = __denoise(unet, noisy_images, noise_rates, signal_rates) + pred_noises, pred_images = denoise(unet, noisy_images, noise_rates, signal_rates) @return noises, images, pred_noises, pred_images end end -function __diffusion_schedules(diffusion_times::AbstractArray{T, 4}, min_signal_rate::T, +function diffusion_schedules(diffusion_times::AbstractArray{T, 4}, min_signal_rate::T, max_signal_rate::T) where {T <: Real} start_angle = acos(max_signal_rate) end_angle = acos(min_signal_rate) @@ -160,8 +157,7 @@ function __diffusion_schedules(diffusion_times::AbstractArray{T, 4}, min_signal_ return noise_rates, signal_rates end -function __denoise( - unet, noisy_images::AbstractArray{T, 4}, noise_rates::AbstractArray{T, 4}, +function denoise(unet, noisy_images::AbstractArray{T, 4}, noise_rates::AbstractArray{T, 4}, signal_rates::AbstractArray{T, 4}) where {T <: Real} pred_noises = unet((noisy_images, noise_rates .^ 2)) pred_images = @. (noisy_images - pred_noises * noise_rates) / signal_rates @@ -170,7 +166,7 @@ end # ## Helper Functions for Image Generation -function __reverse_diffusion( +function reverse_diffusion( model, initial_noise::AbstractArray{T, 4}, diffusion_steps::Int) where {T <: Real} num_images = size(initial_noise, 4) step_size = one(T) / diffusion_steps @@ -188,15 +184,15 @@ function __reverse_diffusion( # We start t = 1, and gradually decreases to t=0 diffusion_times = (ones(T, 1, 1, 1, num_images) .- step_size * step) |> dev - noise_rates, signal_rates = __diffusion_schedules( + noise_rates, signal_rates = diffusion_schedules( diffusion_times, min_signal_rate, max_signal_rate) - pred_noises, pred_images = __denoise( + pred_noises, pred_images = denoise( StatefulLuxLayer{true}(model.model.layers.unet, model.ps.unet, model.st.unet), noisy_images, noise_rates, signal_rates) next_diffusion_times = diffusion_times .- step_size - next_noisy_rates, next_signal_rates = __diffusion_schedules( + next_noisy_rates, next_signal_rates = diffusion_schedules( next_diffusion_times, min_signal_rate, max_signal_rate) next_noisy_images = next_signal_rates .* pred_images .+ @@ -206,14 +202,14 @@ function __reverse_diffusion( return pred_images end -function __denormalize(model::StatefulLuxLayer{true}, x::AbstractArray{<:Real, 4}) +function denormalize(model::StatefulLuxLayer, x::AbstractArray{<:Real, 4}) mean = reshape(model.st.bn.running_mean, 1, 1, 3, 1) var = reshape(model.st.bn.running_var, 1, 1, 3, 1) std = sqrt.(var .+ model.model.layers.bn.epsilon) return std .* x .+ mean end -function __save_images(output_dir, images::AbstractArray{<:Real, 4}) +function save_images(output_dir, images::AbstractArray{<:Real, 4}) imgs = Vector{Array{RGB, 2}}(undef, size(images, 4)) for i in axes(images, 4) img = @view images[:, :, :, i] @@ -224,7 +220,7 @@ function __save_images(output_dir, images::AbstractArray{<:Real, 4}) return imgs end -function __generate_and_save_image_grid(output_dir, imgs::Vector{<:AbstractArray{<:RGB, 2}}) +function generate_and_save_image_grid(output_dir, imgs::Vector{<:AbstractArray{<:RGB, 2}}) fig = Figure() nrows, ncols = 3, 4 for r in 1:nrows, c in 1:ncols @@ -238,11 +234,11 @@ function __generate_and_save_image_grid(output_dir, imgs::Vector{<:AbstractArray return end -function __generate( +function generate( model::StatefulLuxLayer, rng, image_size::NTuple{4, Int}, diffusion_steps::Int, dev) initial_noise = randn(rng, Float32, image_size...) |> dev - generated_images = __reverse_diffusion(model, initial_noise, diffusion_steps) - generated_images = __denormalize(model, generated_images) + generated_images = reverse_diffusion(model, initial_noise, diffusion_steps) + generated_images = denormalize(model, generated_images) return clamp01.(generated_images) end @@ -287,21 +283,23 @@ function Base.getindex(ds::FlowersDataset, i::Int) end function preprocess_image(image::Matrix{<:RGB}, image_size::Int) - return apply(CenterResizeCrop((image_size, image_size)), Image(image)) |> itemdata + return apply( + CenterResizeCrop((image_size, image_size)), DataAugmentation.Image(image)) |> + itemdata end const maeloss = MAELoss() function loss_function(model, ps, st, data) (noises, images, pred_noises, pred_images), st = Lux.apply(model, data, ps, st) - noise_loss = maeloss(noises, pred_noises) - image_loss = maeloss(images, pred_images) + noise_loss = maeloss(pred_noises, noises) + image_loss = maeloss(pred_images, images) return noise_loss, st, (; image_loss, noise_loss) end # ## Entry Point for our code -@main function main(; epochs::Int=100, image_size::Int=128, +Comonicon.@main function main(; epochs::Int=100, image_size::Int=128, batchsize::Int=128, learning_rate_start::Float32=1.0f-3, learning_rate_end::Float32=1.0f-5, weight_decay::Float32=1.0f-6, checkpoint_interval::Int=25, expt_dir=tempname(@__DIR__), @@ -316,7 +314,8 @@ end @info "Experiment directory: $(expt_dir)" - rng = StableRNG(1234) + rng = Random.default_rng() + Random.seed!(rng, 1234) image_dir = joinpath(expt_dir, "images") isdir(image_dir) || mkpath(image_dir) @@ -339,19 +338,20 @@ end states = states |> gdev model = StatefulLuxLayer{true}(model, parameters, Lux.testmode(states)) - generated_images = __generate(model, StableRNG(generate_image_seed), + generated_images = generate(model, StableRNG(generate_image_seed), (image_size, image_size, 3, generate_n_images), diffusion_steps, gdev) |> cpu_device() path = joinpath(image_dir, "inference") @info "Saving generated images to $(path)" - imgs = __save_images(path, generated_images) - __generate_and_save_image_grid(path, imgs) + imgs = save_images(path, generated_images) + generate_and_save_image_grid(path, imgs) return end tb_dir = joinpath(expt_dir, "tb_logs") - @info "Logging Tensorboard logs to $(tb_dir). Run tensorboard with `tensorboard --logdir $(dirname(tb_dir))`" + @info "Tensorboard logs being saved to $(tb_dir). Run tensorboard with \ + `tensorboard --logdir $(dirname(tb_dir))`" tb_logger = TBLogger(tb_dir) tstate = Training.TrainState( @@ -393,13 +393,13 @@ end if epoch % generate_image_interval == 0 || epoch == epochs model_test = StatefulLuxLayer{true}( tstate.model, tstate.parameters, Lux.testmode(tstate.states)) - generated_images = __generate(model_test, StableRNG(generate_image_seed), + generated_images = generate(model_test, StableRNG(generate_image_seed), (image_size, image_size, 3, generate_n_images), diffusion_steps, gdev) |> cpu_device() path = joinpath(image_dir, "epoch_$(epoch)") @info "Saving generated images to $(path)" - imgs = __save_images(path, generated_images) + imgs = save_images(path, generated_images) log_images(tb_logger, "Generated Images", imgs; step) end diff --git a/src/helpers/stateful.jl b/src/helpers/stateful.jl index 02c57eeaf..0fdf475ee 100644 --- a/src/helpers/stateful.jl +++ b/src/helpers/stateful.jl @@ -121,10 +121,12 @@ function (s::StatefulLuxLayer)(x, p=s.ps) return y end -function CRC.rrule( - ::Type{<:StatefulLuxLayer{FT}}, model::AbstractLuxLayer, ps, st, st_any) where {FT} - slayer = StatefulLuxLayer{FT}(model, ps, st, st_any) - ∇StatefulLuxLayer(Δ) = NoTangent(), NoTangent(), Δ.ps, NoTangent(), NoTangent() +function CRC.rrule(::Type{<:StatefulLuxLayer}, model::AbstractLuxLayer, + ps, st, st_any, fixed_state_type) + slayer = StatefulLuxLayer(model, ps, st, st_any, fixed_state_type) + function ∇StatefulLuxLayer(Δ) + return NoTangent(), NoTangent(), Δ.ps, NoTangent(), NoTangent(), NoTangent() + end return slayer, ∇StatefulLuxLayer end