Update/Fix Pipeline Mixins and ORT Pipelines #2021

Open · wants to merge 66 commits into base: main
Conversation

@IlyasMoutawwakil (Member) commented Sep 10, 2024

What does this PR do?

This PR allows ORT diffusion pipelines to use the same modeling code as diffusers, without maintaining custom mixins.
It also fixes the output-reproducibility and numerical-consistency issues vs diffusers observed in #1960.
Breaking changes:

  • we export the VAE encoder so that it outputs its latent distribution parameters instead of a sampled latent; sampling is then done at runtime, which solves the above-mentioned issues (see the sketch below).
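
For illustration, a minimal sketch of what the export-side change amounts to (wrapper names are hypothetical, not the actual export code): before, the random sampling was traced into the ONNX graph; after, only the deterministic distribution parameters are exported and sampling is left to the runtime.

import torch
from diffusers import AutoencoderKL


class VaeEncoderSampled(torch.nn.Module):
    """Before: the sampling step is traced into the exported graph."""

    def __init__(self, vae: AutoencoderKL):
        super().__init__()
        self.vae = vae

    def forward(self, sample):
        return {"latent_sample": self.vae.encode(sample).latent_dist.sample()}


class VaeEncoderParameters(torch.nn.Module):
    """After: only the mean/logvar parameters are exported; sampling happens at runtime."""

    def __init__(self, vae: AutoencoderKL):
        super().__init__()
        self.vae = vae

    def forward(self, sample):
        return {"latent_parameters": self.vae.encode(sample).latent_dist.parameters}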

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

…eproducibility and comparison tests (7 failed, 35 passed)
optimum/exporters/onnx/model_configs.py (outdated)
"latent_sample": {0: "batch_size", 2: "height_latent", 3: "width_latent"},
"latent_parameters": {0: "batch_size", 2: "height_latent", 3: "width_latent"},
Collaborator

This will result in a breaking change, so I would keep it if possible.

Member Author

But it would be wrong, because we no longer sample from the latent distribution during export.

)

init_latents = [
    retrieve_latents(self.vae_encoder(image[i : i + 1]), generator=generator[i])
@echarlaix (Collaborator) commented Sep 12, 2024

We would like to keep as much numpy as possible (no need to go np -> torch -> np), so I'm not sure why we need retrieve_latents + DiagonalGaussianDistribution. Could you explain why?

@IlyasMoutawwakil (Member Author)

retrieve_latents samples from the DiagonalGaussianDistribution that is parametrized by the outputs of the VAE encoder, using the generator. Exporting the encoder with its final sampled output leads to inconsistency with the diffusers output, and makes it impossible to get deterministic outputs (even with the same generator/seed and inputs) in the ORT Img2Img and Inpainting pipelines.
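
For illustration, a minimal sketch (not the PR's exact code, and the diffusers import path may differ between versions) of how the exported latent parameters can be turned back into a reproducible sample on the Python side:

import numpy as np
import torch
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution


def sample_latents(latent_parameters: np.ndarray, generator: torch.Generator) -> np.ndarray:
    # rebuild the torch distribution from the exported mean/logvar parameters
    latent_dist = DiagonalGaussianDistribution(torch.from_numpy(latent_parameters))
    # sampling is done in torch with the user's generator, matching diffusers
    # and making the result reproducible for a fixed seed
    return latent_dist.sample(generator=generator).numpy()

retrieve_latents performs essentially this sampling step when the encoder output exposes a latent_dist attribute.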

@IlyasMoutawwakil (Member Author) commented Sep 12, 2024

I think the np -> torch -> np round trip is inevitable if we want drop-in replacement compatibility with diffusers: the whole logic in diffusers depends on torch functionality (like distributions and generators) that can't be translated perfectly to numpy, since the underlying implementations and random number generators are not the same.
I would also argue that this doesn't add much overhead in this case, especially on CPU: torch.from_numpy doesn't create a new tensor but rather points to the same block of memory, and Tensor.numpy() works similarly when force is False (the default), so there is almost no copying or allocation of tensors/ndarrays (if I understand correctly).
On GPU it actually allows for better performance, because with torch we can easily implement IO binding for pipelines on the CUDA EP 😁.
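
A quick check of the zero-copy claim (CPU tensors only, nothing beyond stock torch/numpy assumed):

import numpy as np
import torch

arr = np.zeros((2, 2), dtype=np.float32)
tensor = torch.from_numpy(arr)  # no copy: the tensor views the same memory as the ndarray
tensor[0, 0] = 1.0
print(arr[0, 0])  # 1.0 -> the ndarray sees the in-place change

back = tensor.numpy()  # no copy either for a CPU tensor (force=False)
print(back.ctypes.data == arr.ctypes.data)  # True -> same underlying buffer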

@IlyasMoutawwakil (Member Author)

I just noticed that there are already torch.from_numpy() / .numpy() calls in the unet loop:

# compute the previous noisy sample x_t -> x_t-1
latents, denoised = self.scheduler.step(
    torch.from_numpy(noise_pred), t, torch.from_numpy(latents), return_dict=False
)
latents, denoised = latents.numpy(), denoised.numpy()

@IlyasMoutawwakil (Member Author)

@echarlaix, as discussed, one reason we need to export the distribution parameters instead of sampling at export time is that we can't control the randomness otherwise, which causes several issues, one of which is not being able to reproduce the same output even with the same inputs and seeds. Something as simple as the following snippet does not reproduce the same output with ONNX Runtime:

import random

import numpy as np
import onnxruntime as ort
import torch


def set_all_seeds(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    ort.set_seed(seed)


class RandomModule(torch.nn.Module):
    def forward(self, x):
        # adds uniform noise, so the output depends on the random state
        return x + torch.rand(x.size())


model = RandomModule()
x = torch.randn(1, 3, 224, 224)

# with torch, re-seeding reproduces the exact same output
set_all_seeds(0)
torch_output_1 = model(x)
set_all_seeds(0)
torch_output_2 = model(x)
np.testing.assert_allclose(torch_output_1, torch_output_2)  # passes

# with onnxruntime, the same seeds do not reproduce the output
set_all_seeds(0)
torch.onnx.export(model, x, "random.onnx")
ort_session = ort.InferenceSession("random.onnx")
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name

set_all_seeds(0)
ort_output_1 = ort_session.run([output_name], {input_name: x.numpy()})[0]
set_all_seeds(0)
ort_output_2 = ort_session.run([output_name], {input_name: x.numpy()})[0]
np.testing.assert_allclose(ort_output_1, ort_output_2)  # fails: the two ORT outputs differ

PS: I just found out about ort.set_seed, and it doesn't seem to make any difference.

@IlyasMoutawwakil (Member Author)

The overhead from using torch.from_numpy() and Tensor.numpy() is minimal because, on CPU, these methods only copy the pointer to the block of memory that contains the array. An end-to-end benchmark of the tiny model (where any overhead should be most apparent) shows that the performance is almost the same:

  • on main:
 [MAIN-PROCESS][2024-09-13 13:55:22,816][process][INFO] -        + Received report from isolated process
[MAIN-PROCESS][2024-09-13 13:55:22,817][memory][INFO] -                 + load memory:
[MAIN-PROCESS][2024-09-13 13:55:22,818][memory][INFO] -                         - max RAM: 878.497792 (MB)
[MAIN-PROCESS][2024-09-13 13:55:22,818][latency][INFO] -                + load latency:
[MAIN-PROCESS][2024-09-13 13:55:22,818][latency][INFO] -                        - count: 1
[MAIN-PROCESS][2024-09-13 13:55:22,818][latency][INFO] -                        - total: 7.848160 s
[MAIN-PROCESS][2024-09-13 13:55:22,818][latency][INFO] -                        - mean: 7.848160 s
[MAIN-PROCESS][2024-09-13 13:55:22,818][latency][INFO] -                        - stdev: 0.000000 s (0.00%)
[MAIN-PROCESS][2024-09-13 13:55:22,818][latency][INFO] -                        - p50: 7.848160 s
[MAIN-PROCESS][2024-09-13 13:55:22,818][latency][INFO] -                        - p90: 7.848160 s
[MAIN-PROCESS][2024-09-13 13:55:22,818][latency][INFO] -                        - p95: 7.848160 s
[MAIN-PROCESS][2024-09-13 13:55:22,819][latency][INFO] -                        - p99: 7.848160 s
[MAIN-PROCESS][2024-09-13 13:55:22,819][memory][INFO] -                 + call memory:
[MAIN-PROCESS][2024-09-13 13:55:22,819][memory][INFO] -                         - max RAM: 1038.311424 (MB)
[MAIN-PROCESS][2024-09-13 13:55:22,819][latency][INFO] -                + call latency:
[MAIN-PROCESS][2024-09-13 13:55:22,819][latency][INFO] -                        - count: 10
[MAIN-PROCESS][2024-09-13 13:55:22,819][latency][INFO] -                        - total: 39.163408 s
[MAIN-PROCESS][2024-09-13 13:55:22,819][latency][INFO] -                        - mean: 3.916341 s
[MAIN-PROCESS][2024-09-13 13:55:22,819][latency][INFO] -                        - stdev: 0.175419 s (4.48%)
[MAIN-PROCESS][2024-09-13 13:55:22,819][latency][INFO] -                        - p50: 3.862534 s
[MAIN-PROCESS][2024-09-13 13:55:22,820][latency][INFO] -                        - p90: 4.104698 s
[MAIN-PROCESS][2024-09-13 13:55:22,820][latency][INFO] -                        - p95: 4.234256 s
[MAIN-PROCESS][2024-09-13 13:55:22,820][latency][INFO] -                        - p99: 4.337903 s
[MAIN-PROCESS][2024-09-13 13:55:22,820][latency][INFO] -                + call throughput: 0.255340 images/s
  • on pr:
[MAIN-PROCESS][2024-09-13 13:57:36,918][process][INFO] -        + Received report from isolated process
[MAIN-PROCESS][2024-09-13 13:57:36,919][memory][INFO] -                 + load memory:
[MAIN-PROCESS][2024-09-13 13:57:36,919][memory][INFO] -                         - max RAM: 612.913152 (MB)
[MAIN-PROCESS][2024-09-13 13:57:36,920][latency][INFO] -                + load latency:
[MAIN-PROCESS][2024-09-13 13:57:36,920][latency][INFO] -                        - count: 1
[MAIN-PROCESS][2024-09-13 13:57:36,920][latency][INFO] -                        - total: 7.528061 s
[MAIN-PROCESS][2024-09-13 13:57:36,920][latency][INFO] -                        - mean: 7.528061 s
[MAIN-PROCESS][2024-09-13 13:57:36,920][latency][INFO] -                        - stdev: 0.000000 s (0.00%)
[MAIN-PROCESS][2024-09-13 13:57:36,920][latency][INFO] -                        - p50: 7.528061 s
[MAIN-PROCESS][2024-09-13 13:57:36,920][latency][INFO] -                        - p90: 7.528061 s
[MAIN-PROCESS][2024-09-13 13:57:36,920][latency][INFO] -                        - p95: 7.528061 s
[MAIN-PROCESS][2024-09-13 13:57:36,920][latency][INFO] -                        - p99: 7.528061 s
[MAIN-PROCESS][2024-09-13 13:57:36,921][memory][INFO] -                 + call memory:
[MAIN-PROCESS][2024-09-13 13:57:36,921][memory][INFO] -                         - max RAM: 1033.846784 (MB)
[MAIN-PROCESS][2024-09-13 13:57:36,921][latency][INFO] -                + call latency:
[MAIN-PROCESS][2024-09-13 13:57:36,921][latency][INFO] -                        - count: 10
[MAIN-PROCESS][2024-09-13 13:57:36,921][latency][INFO] -                        - total: 38.591033 s
[MAIN-PROCESS][2024-09-13 13:57:36,921][latency][INFO] -                        - mean: 3.859103 s
[MAIN-PROCESS][2024-09-13 13:57:36,921][latency][INFO] -                        - stdev: 0.174273 s (4.52%)
[MAIN-PROCESS][2024-09-13 13:57:36,921][latency][INFO] -                        - p50: 3.814653 s
[MAIN-PROCESS][2024-09-13 13:57:36,921][latency][INFO] -                        - p90: 4.081044 s
[MAIN-PROCESS][2024-09-13 13:57:36,921][latency][INFO] -                        - p95: 4.187505 s
[MAIN-PROCESS][2024-09-13 13:57:36,922][latency][INFO] -                        - p99: 4.272674 s
[MAIN-PROCESS][2024-09-13 13:57:36,922][latency][INFO] -                + call throughput: 0.259128 images/s

On actual/bigger pipelines the overhead is negligible: one inference step takes a couple of seconds, while a torch -> numpy -> torch conversion takes microseconds:

import time

import torch

with torch.no_grad():
    torch_tensor = torch.randn(1000, 1000)

    start = time.time()
    for _ in range(30):
        numpy_array = torch_tensor.numpy()
        torch_tensor = torch.from_numpy(numpy_array)
    end = time.time()

    print("Time taken for conversion: ", (end - start) / 30)
$ python bench.py 
Time taken for conversion:  7.883707682291667e-06

return ModelOutput(**model_outputs)


class ORTVaeWrapper(ORTPipelinePart):
Collaborator

Love it!

@@ -320,6 +318,7 @@ def test_load_stable_diffusion_model_from_hub(self):
        self.assertIsInstance(model.vae_encoder, ORTModelVaeEncoder)
        self.assertIsInstance(model.unet, ORTModelUnet)
        self.assertIsInstance(model.config, Dict)
        model(prompt="cat", num_inference_steps=2)
@IlyasMoutawwakil (Member Author) commented Sep 23, 2024

This is to check inference and not only loading; it can be done better.

self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID = "hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline"
self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID = "IlyasMoutawwakil/tiny-stable-diffusion-onnx"
@IlyasMoutawwakil (Member Author) commented Sep 23, 2024

This is because that pipeline doesn't even have configs. It used to load (and can still load), but since inference with it was never tested, this never raised any issues. There is an issue though: the configs are needed for inference.
