Skip to content

Commit

Permalink
Updated _encode_prompt_with_clip and encode_prompt in train_dreambooth…
Browse files Browse the repository at this point in the history
…_sd3 (#9800)

* updated encode prompt and clip encode prompt


---------

Co-authored-by: Sayak Paul <[email protected]>
  • Loading branch information
SahilCarterr and sayakpaul authored Nov 6, 2024
1 parent e2b3c24 commit 76b7d86
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions examples/dreambooth/train_dreambooth_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,20 +902,26 @@ def _encode_prompt_with_clip(
tokenizer,
prompt: str,
device=None,
text_input_ids=None,
num_images_per_prompt: int = 1,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)

text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=77,
truncation=True,
return_tensors="pt",
)
if tokenizer is not None:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=77,
truncation=True,
return_tensors="pt",
)

text_input_ids = text_inputs.input_ids
else:
if text_input_ids is None:
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

pooled_prompt_embeds = prompt_embeds[0]
Expand All @@ -937,6 +943,7 @@ def encode_prompt(
max_sequence_length,
device=None,
num_images_per_prompt: int = 1,
text_input_ids_list=None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt

Expand All @@ -945,13 +952,14 @@ def encode_prompt(

clip_prompt_embeds_list = []
clip_pooled_prompt_embeds_list = []
for tokenizer, text_encoder in zip(clip_tokenizers, clip_text_encoders):
for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)):
prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
text_encoder=text_encoder,
tokenizer=tokenizer,
prompt=prompt,
device=device if device is not None else text_encoder.device,
num_images_per_prompt=num_images_per_prompt,
text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
)
clip_prompt_embeds_list.append(prompt_embeds)
clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
Expand Down

0 comments on commit 76b7d86

Please sign in to comment.