Skip to content

Commit

Permalink
praching fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
anzr299 committed Dec 30, 2024
1 parent 2c39174 commit a898c01
Showing 1 changed file with 24 additions and 6 deletions.
30 changes: 24 additions & 6 deletions notebooks/stable-diffusion-v3/stable-diffusion-v3-torch-fx.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,11 @@
"slideshow": {
"slide_type": ""
},
"tags": []
"tags": [],
"test_replace": {
"height=512,": "",
"width=512": ""
}
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -272,7 +276,11 @@
"slide_type": ""
},
"tags": [],
"test_replace": {}
"test_replace": {
"torch.ones((1, 16, 64, 64))": "torch.ones((1, 16, 128, 128))",
"torch.ones((1, 3, 64, 64))": "torch.ones((1, 3, 128, 128))",
"torch.ones((2, 16, 64, 64))": "torch.ones((2, 16, 128, 128))"
}
},
"outputs": [],
"source": [
Expand All @@ -287,7 +295,7 @@
"vae_decoder_input = torch.ones((1, 16, 64, 64))\n",
"\n",
"unet_kwargs = {}\n",
"unet_kwargs[\"hidden_states\"] = torch.ones((2, 16, 64, 64))\n",
"unet_kwargs[\"hidden_states\"] = orch.ones((2, 16, 64, 64))\n",
"unet_kwargs[\"timestep\"] = torch.from_numpy(np.array([1, 2], dtype=np.float32))\n",
"unet_kwargs[\"encoder_hidden_states\"] = torch.ones((2, 154, 4096))\n",
"unet_kwargs[\"pooled_projections\"] = torch.ones((2, 2048))\n",
Expand Down Expand Up @@ -396,7 +404,8 @@
"tags": [],
"test_replace": {
"calibration_dataset_size = 200": "calibration_dataset_size = 1",
"init_pipeline(models_dict, configs_dict)": "init_pipeline(models_dict, configs_dict, \"katuni4ka/tiny-random-sd3\")"
"init_pipeline(models_dict, configs_dict)": "init_pipeline(models_dict, configs_dict, \"katuni4ka/tiny-random-sd3\")",
"pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512)": "pipe(prompt, num_inference_steps=num_inference_steps)"
}
},
"outputs": [],
Expand Down Expand Up @@ -650,7 +659,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"test_replace": {
"prompt=prompt, negative_prompt=\"\", num_inference_steps=1, generator=generator, height=512, width=512": "prompt=prompt, negative_prompt=\"\", num_inference_steps=1, generator=generator"
}
},
"outputs": [],
"source": [
"%%skip not $to_quantize.value\n",
Expand All @@ -674,7 +687,12 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"test_replace": {
"height=512,": "",
"width=512": ""
}
},
"outputs": [],
"source": [
"%%skip not $to_quantize.value\n",
Expand Down

0 comments on commit a898c01

Please sign in to comment.