StableCascade image generation
aleksandr-mokrov committed Apr 29, 2024
1 parent 98ba408 commit 8ec87ca
Showing 1 changed file with 10 additions and 16 deletions.
@@ -232,7 +232,7 @@
" converted_model = nncf.compress_weights(converted_model)\n",
" ov.save_model(converted_model, xml_path)\n",
" del converted_model\n",
" \n",
"\n",
" # cleanup memory\n",
" torch._C._jit_clear_class_registry()\n",
" torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()\n",
@@ -271,7 +271,7 @@
"\n",
" def forward(self, input_ids, attention_mask):\n",
" outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)\n",
" return outputs['text_embeds'], outputs[\"last_hidden_state\"], outputs['hidden_states']\n",
" return outputs[\"text_embeds\"], outputs[\"last_hidden_state\"], outputs[\"hidden_states\"]\n",
"\n",
"\n",
"convert(\n",
@@ -362,7 +362,7 @@
" \"sample\": torch.zeros(1, 4, 256, 256),\n",
" \"timestep_ratio\": torch.ones(1),\n",
" \"clip_text_pooled\": torch.zeros(1, 1, 1280),\n",
" \"effnet\": torch.zeros(1, 16, 24, 24)\n",
" \"effnet\": torch.zeros(1, 16, 24, 24),\n",
" },\n",
" input_shape=[((1, 4, 256, 256),), ((1),), ((1, 1, 1280),), ((1, 16, 24, 24),)],\n",
")\n",
@@ -388,6 +388,7 @@
" def forward(self, h):\n",
" return self.vqgan.decode(h)\n",
"\n",
"\n",
"convert(\n",
" VqganDecoderWrapper(decoder.vqgan),\n",
" VQGAN_PATH,\n",
@@ -517,17 +518,12 @@
"source": [
"class DecoderWrapper:\n",
" dtype = torch.float32 # accessed in the original workflow\n",
" \n",
"\n",
" def __init__(self, decoder_path):\n",
" self.decoder = core.compile_model(decoder_path, DEVICE.value)\n",
"\n",
" def __call__(self, sample, timestep_ratio, clip_text_pooled, effnet, **kwargs):\n",
" inputs = {\n",
" \"sample\": sample,\n",
" \"timestep_ratio\": timestep_ratio,\n",
" \"clip_text_pooled\": clip_text_pooled,\n",
" \"effnet\": effnet\n",
" }\n",
" inputs = {\"sample\": sample, \"timestep_ratio\": timestep_ratio, \"clip_text_pooled\": clip_text_pooled, \"effnet\": effnet}\n",
" output = self.decoder(inputs)\n",
" return [torch.from_numpy(output[0])]"
]
@@ -541,6 +537,7 @@
"source": [
"VqganOutput = namedtuple(\"VqganOutput\", \"sample\")\n",
"\n",
"\n",
"class VqganWrapper:\n",
" config = namedtuple(\"VqganWrapperConfig\", \"scale_factor\")(0.3764) # accessed in the original workflow\n",
"\n",
@@ -680,7 +677,7 @@
" num_inference_steps=20,\n",
" generator=generator,\n",
" )\n",
" \n",
"\n",
" decoder_output = decoder(\n",
" image_embeddings=prior_output.image_embeddings,\n",
" prompt=caption,\n",
@@ -690,7 +687,7 @@
" num_inference_steps=10,\n",
" generator=generator,\n",
" ).images[0]\n",
" \n",
"\n",
" return decoder_output"
]
},
@@ -714,10 +711,7 @@
" gr.Slider(0, np.iinfo(np.int32).max, label=\"Seed\", step=1),\n",
" ],\n",
" \"image\",\n",
" examples=[\n",
" [\"An image of a shiba inu, donning a spacesuit and helmet\", \"\", 4, 0],\n",
" [\"An armchair in the shape of an avocado\", \"\", 4, 0]\n",
" ],\n",
" examples=[[\"An image of a shiba inu, donning a spacesuit and helmet\", \"\", 4, 0], [\"An armchair in the shape of an avocado\", \"\", 4, 0]],\n",
" allow_flagging=\"never\",\n",
")\n",
"try:\n",