Skip to content

Commit

Permalink
fix: DDIM updates and fix argument ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 6, 2024
1 parent 3ccb1cc commit bf11707
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 42 deletions.
4 changes: 0 additions & 4 deletions examples/DDIM/Project.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
[deps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
Expand All @@ -25,10 +23,8 @@ TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AMDGPU = "0.9.6, 1"
ArgCheck = "2.3.0"
CairoMakie = "0.12"
ChainRulesCore = "1.23"
Comonicon = "1"
ConcreteStructs = "0.2.3"
DataAugmentation = "0.3"
Expand Down
68 changes: 34 additions & 34 deletions examples/DDIM/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@

# ## Package Imports

using ArgCheck, CairoMakie, ChainRulesCore, ConcreteStructs, Comonicon, DataAugmentation,
DataDeps, FileIO, ImageCore, JLD2, Lux, LuxCUDA, MLUtils, Optimisers,
ParameterSchedulers, ProgressBars, Random, Setfield, StableRNGs, Statistics, Zygote
using ArgCheck, CairoMakie, ConcreteStructs, Comonicon, DataAugmentation, DataDeps, FileIO,
ImageCore, JLD2, Lux, LuxCUDA, MLUtils, Optimisers, ParameterSchedulers, ProgressBars,
Random, Setfield, StableRNGs, Statistics, Zygote
using TensorBoardLogger: TBLogger, log_value, log_images
const CRC = ChainRulesCore

CUDA.allowscalar(false)

Expand Down Expand Up @@ -130,24 +129,22 @@ function ddim(rng::AbstractRNG, args...; min_signal_rate=0.02f0,
max_signal_rate, dispatch=:DDIM) do x::AbstractArray{<:Real, 4}
images = bn(x)
rng = Lux.replicate(rng)
T = eltype(x)

noises = CRC.@ignore_derivatives randn!(rng, similar(images, T, size(images)...))
diffusion_times = CRC.@ignore_derivatives rand!(
rng, similar(images, T, 1, 1, 1, size(images, 4)))
noises = rand_like(rng, images)
diffusion_times = rand_like(rng, images, (1, 1, 1, size(images, 4)))

noise_rates, signal_rates = __diffusion_schedules(
noise_rates, signal_rates = diffusion_schedules(
diffusion_times, min_signal_rate, max_signal_rate)

noisy_images = @. signal_rates * images + noise_rates * noises

pred_noises, pred_images = __denoise(unet, noisy_images, noise_rates, signal_rates)
pred_noises, pred_images = denoise(unet, noisy_images, noise_rates, signal_rates)

@return noises, images, pred_noises, pred_images
end
end

function __diffusion_schedules(diffusion_times::AbstractArray{T, 4}, min_signal_rate::T,
function diffusion_schedules(diffusion_times::AbstractArray{T, 4}, min_signal_rate::T,
max_signal_rate::T) where {T <: Real}
start_angle = acos(max_signal_rate)
end_angle = acos(min_signal_rate)
Expand All @@ -160,8 +157,7 @@ function __diffusion_schedules(diffusion_times::AbstractArray{T, 4}, min_signal_
return noise_rates, signal_rates
end

function __denoise(
unet, noisy_images::AbstractArray{T, 4}, noise_rates::AbstractArray{T, 4},
function denoise(unet, noisy_images::AbstractArray{T, 4}, noise_rates::AbstractArray{T, 4},
signal_rates::AbstractArray{T, 4}) where {T <: Real}
pred_noises = unet((noisy_images, noise_rates .^ 2))
pred_images = @. (noisy_images - pred_noises * noise_rates) / signal_rates
Expand All @@ -170,7 +166,7 @@ end

# ## Helper Functions for Image Generation

function __reverse_diffusion(
function reverse_diffusion(
model, initial_noise::AbstractArray{T, 4}, diffusion_steps::Int) where {T <: Real}
num_images = size(initial_noise, 4)
step_size = one(T) / diffusion_steps
Expand All @@ -188,15 +184,15 @@ function __reverse_diffusion(
# We start at t = 1 and gradually decrease to t = 0
diffusion_times = (ones(T, 1, 1, 1, num_images) .- step_size * step) |> dev

noise_rates, signal_rates = __diffusion_schedules(
noise_rates, signal_rates = diffusion_schedules(
diffusion_times, min_signal_rate, max_signal_rate)

pred_noises, pred_images = __denoise(
pred_noises, pred_images = denoise(
StatefulLuxLayer{true}(model.model.layers.unet, model.ps.unet, model.st.unet),
noisy_images, noise_rates, signal_rates)

next_diffusion_times = diffusion_times .- step_size
next_noisy_rates, next_signal_rates = __diffusion_schedules(
next_noisy_rates, next_signal_rates = diffusion_schedules(
next_diffusion_times, min_signal_rate, max_signal_rate)

next_noisy_images = next_signal_rates .* pred_images .+
Expand All @@ -206,14 +202,14 @@ function __reverse_diffusion(
return pred_images
end

function __denormalize(model::StatefulLuxLayer{true}, x::AbstractArray{<:Real, 4})
function denormalize(model::StatefulLuxLayer, x::AbstractArray{<:Real, 4})
mean = reshape(model.st.bn.running_mean, 1, 1, 3, 1)
var = reshape(model.st.bn.running_var, 1, 1, 3, 1)
std = sqrt.(var .+ model.model.layers.bn.epsilon)
return std .* x .+ mean
end

function __save_images(output_dir, images::AbstractArray{<:Real, 4})
function save_images(output_dir, images::AbstractArray{<:Real, 4})
imgs = Vector{Array{RGB, 2}}(undef, size(images, 4))
for i in axes(images, 4)
img = @view images[:, :, :, i]
Expand All @@ -224,7 +220,7 @@ function __save_images(output_dir, images::AbstractArray{<:Real, 4})
return imgs
end

function __generate_and_save_image_grid(output_dir, imgs::Vector{<:AbstractArray{<:RGB, 2}})
function generate_and_save_image_grid(output_dir, imgs::Vector{<:AbstractArray{<:RGB, 2}})
fig = Figure()
nrows, ncols = 3, 4
for r in 1:nrows, c in 1:ncols
Expand All @@ -238,11 +234,11 @@ function __generate_and_save_image_grid(output_dir, imgs::Vector{<:AbstractArray
return
end

function __generate(
function generate(
model::StatefulLuxLayer, rng, image_size::NTuple{4, Int}, diffusion_steps::Int, dev)
initial_noise = randn(rng, Float32, image_size...) |> dev
generated_images = __reverse_diffusion(model, initial_noise, diffusion_steps)
generated_images = __denormalize(model, generated_images)
generated_images = reverse_diffusion(model, initial_noise, diffusion_steps)
generated_images = denormalize(model, generated_images)
return clamp01.(generated_images)
end

Expand Down Expand Up @@ -287,21 +283,23 @@ function Base.getindex(ds::FlowersDataset, i::Int)
end

function preprocess_image(image::Matrix{<:RGB}, image_size::Int)
return apply(CenterResizeCrop((image_size, image_size)), Image(image)) |> itemdata
return apply(
CenterResizeCrop((image_size, image_size)), DataAugmentation.Image(image)) |>
itemdata
end

const maeloss = MAELoss()

function loss_function(model, ps, st, data)
(noises, images, pred_noises, pred_images), st = Lux.apply(model, data, ps, st)
noise_loss = maeloss(noises, pred_noises)
image_loss = maeloss(images, pred_images)
noise_loss = maeloss(pred_noises, noises)
image_loss = maeloss(pred_images, images)
return noise_loss, st, (; image_loss, noise_loss)
end

# ## Entry Point for our code

@main function main(; epochs::Int=100, image_size::Int=128,
Comonicon.@main function main(; epochs::Int=100, image_size::Int=128,
batchsize::Int=128, learning_rate_start::Float32=1.0f-3,
learning_rate_end::Float32=1.0f-5, weight_decay::Float32=1.0f-6,
checkpoint_interval::Int=25, expt_dir=tempname(@__DIR__),
Expand All @@ -316,7 +314,8 @@ end

@info "Experiment directory: $(expt_dir)"

rng = StableRNG(1234)
rng = Random.default_rng()
Random.seed!(rng, 1234)

image_dir = joinpath(expt_dir, "images")
isdir(image_dir) || mkpath(image_dir)
Expand All @@ -339,19 +338,20 @@ end
states = states |> gdev
model = StatefulLuxLayer{true}(model, parameters, Lux.testmode(states))

generated_images = __generate(model, StableRNG(generate_image_seed),
generated_images = generate(model, StableRNG(generate_image_seed),
(image_size, image_size, 3, generate_n_images), diffusion_steps, gdev) |>
cpu_device()

path = joinpath(image_dir, "inference")
@info "Saving generated images to $(path)"
imgs = __save_images(path, generated_images)
__generate_and_save_image_grid(path, imgs)
imgs = save_images(path, generated_images)
generate_and_save_image_grid(path, imgs)
return
end

tb_dir = joinpath(expt_dir, "tb_logs")
@info "Logging Tensorboard logs to $(tb_dir). Run tensorboard with `tensorboard --logdir $(dirname(tb_dir))`"
@info "Tensorboard logs being saved to $(tb_dir). Run tensorboard with \
`tensorboard --logdir $(dirname(tb_dir))`"
tb_logger = TBLogger(tb_dir)

tstate = Training.TrainState(
Expand Down Expand Up @@ -393,13 +393,13 @@ end
if epoch % generate_image_interval == 0 || epoch == epochs
model_test = StatefulLuxLayer{true}(
tstate.model, tstate.parameters, Lux.testmode(tstate.states))
generated_images = __generate(model_test, StableRNG(generate_image_seed),
generated_images = generate(model_test, StableRNG(generate_image_seed),
(image_size, image_size, 3, generate_n_images), diffusion_steps, gdev) |>
cpu_device()

path = joinpath(image_dir, "epoch_$(epoch)")
@info "Saving generated images to $(path)"
imgs = __save_images(path, generated_images)
imgs = save_images(path, generated_images)
log_images(tb_logger, "Generated Images", imgs; step)
end

Expand Down
10 changes: 6 additions & 4 deletions src/helpers/stateful.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,12 @@ function (s::StatefulLuxLayer)(x, p=s.ps)
return y
end

function CRC.rrule(
::Type{<:StatefulLuxLayer{FT}}, model::AbstractLuxLayer, ps, st, st_any) where {FT}
slayer = StatefulLuxLayer{FT}(model, ps, st, st_any)
∇StatefulLuxLayer(Δ) = NoTangent(), NoTangent(), Δ.ps, NoTangent(), NoTangent()
function CRC.rrule(::Type{<:StatefulLuxLayer}, model::AbstractLuxLayer,
ps, st, st_any, fixed_state_type)
slayer = StatefulLuxLayer(model, ps, st, st_any, fixed_state_type)
function ∇StatefulLuxLayer(Δ)
return NoTangent(), NoTangent(), Δ.ps, NoTangent(), NoTangent(), NoTangent()
end
return slayer, ∇StatefulLuxLayer
end

Expand Down

0 comments on commit bf11707

Please sign in to comment.