Skip to content

Commit

Permalink
fix huber_c random timestep selection
Browse files Browse the repository at this point in the history
  • Loading branch information
feffy380 committed Apr 9, 2024
1 parent 7274b08 commit e6857f1
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5144,30 +5144,27 @@ def generate_timestep_weights(args, num_timesteps):


def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):

# TODO: if a huber loss is selected, it will use constant timesteps for each batch
# as. In the future there may be a smarter way
if args.timestep_bias_strategy == "none":
# Sample a random timestep for each image without bias within [min_timestep, max_timestep)
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
else:
# Sample a random timestep for each image, potentially biased by the timestep weights.
weights = generate_timestep_weights(args, noise_scheduler.config.num_train_timesteps).to(device)
timesteps = torch.multinomial(weights, b_size, replacement=True)

if args.loss_type == "huber" or args.loss_type == "smooth_l1":
timesteps = torch.randint(min_timestep, max_timestep, (1,), device="cpu")
timestep = timesteps.item()

if args.huber_schedule == "exponential":
alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
huber_c = math.exp(-alpha * timestep)
huber_c = torch.exp(np.log(args.huber_c) * timesteps / noise_scheduler.config.num_train_timesteps)
elif args.huber_schedule == "snr":
alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
alphas_cumprod = noise_scheduler.alphas_cumprod.to(device=device)[timesteps]
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
elif args.huber_schedule == "constant":
huber_c = args.huber_c
else:
raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")

timesteps = timesteps.repeat(b_size).to(device)
elif args.loss_type == "l2":
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
huber_c = 1 # may be anything, as it's not used
huber_c = None
else:
raise NotImplementedError(f"Unknown loss type {args.loss_type}")
timesteps = timesteps.long()
Expand Down

0 comments on commit e6857f1

Please sign in to comment.