Skip to content

Commit

Permalink
Stable diffusion pipeline updates to use int32 and float32.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 681267676
  • Loading branch information
ai-edge-bot authored and copybara-github committed Oct 2, 2024
1 parent d4c8646 commit b076be8
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions ai_edge_torch/generative/examples/stable_diffusion/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,13 @@ def run_tflite_pipeline(

# Text embedding.
cond_tokens = model.tokenizer.encode(prompt)
cond_context = model.clip(np.array(cond_tokens), signature_name='encode')
cond_context = model.clip(
np.array(cond_tokens).astype(np.int32), signature_name='encode'
)
uncond_tokens = model.tokenizer.encode(uncond_prompt)
uncond_context = model.clip(np.array(uncond_tokens), signature_name='encode')
uncond_context = model.clip(
np.array(uncond_tokens).astype(np.int32), signature_name='encode'
)
context = np.concatenate([cond_context, uncond_context], axis=0)
noise_shape = (1, 4, height // 8, width // 8)

Expand All @@ -198,7 +202,7 @@ def run_tflite_pipeline(
input_image_np = util.rescale(input_image, (0, 255), (-1, 1))
input_image_np = util.move_channel(input_image_np, to='first')
encoder_noise = np.random.normal(size=noise_shape).astype(np.float32)
latents = model.encoder(input_image_np, encoder_noise)
latents = model.encoder(input_image_np.astype(np.float32), encoder_noise)
latents_noise = np.random.normal(size=noise_shape).astype(np.float32)
sampler.set_strength(strength=strength)
latents += latents_noise * sampler.initial_scale
Expand All @@ -214,15 +218,18 @@ def run_tflite_pipeline(
input_latents = latents * sampler.get_input_scale()
input_latents = input_latents.repeat(2, axis=0)
output = model.diffusion(
input_latents, context, time_embedding, signature_name='diffusion'
input_latents.astype(np.float32),
context.astype(np.float32),
time_embedding,
signature_name='diffusion',
)
output_cond, output_uncond = np.split(output, 2, axis=0)
output = cfg_scale * (output_cond - output_uncond) + output_uncond

latents = sampler.step(latents, output)

# Image decoding.
images = model.decoder(latents, signature_name='decode')
images = model.decoder(latents.astype(np.float32), signature_name='decode')
images = util.rescale(images, (-1, 1), (0, 255), clamp=True)
images = util.move_channel(images, to='last')
if not os.path.exists(output_path):
Expand Down

0 comments on commit b076be8

Please sign in to comment.