Skip to content

Commit

Permalink
Merge branch 'feffy380' into orpo
Browse files Browse the repository at this point in the history
  • Loading branch information
feffy380 committed May 26, 2024
2 parents bf37587 + c88a04c commit 37334fb
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
10 changes: 7 additions & 3 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2479,15 +2479,14 @@ def load_image(image_path):

# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom)
def trim_and_resize_if_required(
random_crop: bool, image: Image.Image, reso, resized_size: Tuple[int, int]
random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int]
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]:
image_height, image_width = image.shape[0:2]
original_size = (image_width, image_height) # size before resize

if image_width != resized_size[0] or image_height != resized_size[1]:
# リサイズする
image = image.resize(resized_size, Image.Resampling.LANCZOS)
# image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ

image_height, image_width = image.shape[0:2]

Expand Down Expand Up @@ -3430,6 +3429,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
type=float,
help="Disable ToDo after this many steps. Value less than 1.0 is fraction of total step count",
)
# CLI flag (store_true, so off by default): when set, swap the UNet's attention
# modules to the memory-efficient implementation once ToDo (token downsampling)
# has been disabled mid-training — checked via `args.todo_mem_eff_attn` in the
# train_network.py training loop shown in this same commit.
# NOTE(review): presumably lives inside add_training_arguments() per the hunk
# header; `parser` is supplied by that enclosing function — confirm in full file.
parser.add_argument(
    "--todo_mem_eff_attn",
    action="store_true",
    help="enable memory-efficient attention after disabling ToDo",
)
parser.add_argument(
"--loss_type",
type=str,
Expand Down
16 changes: 15 additions & 1 deletion train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -1057,7 +1057,21 @@ def update_embeddings_map(accelerator, text_encoders, embeddings_map, embedding_
disable_step = int(disable_step) + 1
if global_step == disable_step:
token_downsampling.remove_patch(unet)
train_util.replace_unet_modules(unet, mem_eff_attn=True, xformers=False, sdpa=False)

# according to TI example in Diffusers, train is required
unet.enable_gradient_checkpointing()
unet.train()
for t_enc in text_encoders:
t_enc.gradient_checkpointing_enable()
t_enc.train()
# set top parameter requires_grad = True for gradient checkpointing works
if train_text_encoder or args.continue_inversion:
t_enc.text_model.embeddings.requires_grad_(True)
del t_enc
network.enable_gradient_checkpointing() # may have no effect

if args.todo_mem_eff_attn:
train_util.replace_unet_modules(unet, mem_eff_attn=True, xformers=False, sdpa=False)

if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(device=accelerator.device, dtype=weight_dtype, non_blocking=True)
Expand Down

0 comments on commit 37334fb

Please sign in to comment.