From 42ce69ee851a8956afc6c1e12631db7e1cf3764b Mon Sep 17 00:00:00 2001
From: Junjie Bai
Date: Wed, 13 Sep 2023 21:52:12 +0000
Subject: [PATCH] Add Flamingo example

---
 advanced/flamingo/README.md |  66 +++++++++++++++++++
 advanced/flamingo/photon.py | 123 ++++++++++++++++++++++++++++++++++++
 2 files changed, 189 insertions(+)
 create mode 100644 advanced/flamingo/README.md
 create mode 100644 advanced/flamingo/photon.py

diff --git a/advanced/flamingo/README.md b/advanced/flamingo/README.md
new file mode 100644
index 0000000..d4bbe5b
--- /dev/null
+++ b/advanced/flamingo/README.md
@@ -0,0 +1,66 @@
+# Flamingo
+
+[Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model) is an effective and efficient general-purpose family of models that can be applied to image and video understanding tasks with minimal task-specific examples. In this example we are going to run Flamingo with [open-flamingo](https://github.com/mlfoundations/open_flamingo) on Lepton.
+
+## Install the Lepton SDK
+```shell
+pip install leptonai
+```
+
+## Launch Flamingo inference service locally
+
+Run:
+```shell
+lep photon run -n flamingo -m photon.py
+```
+Although the model is runnable on CPU, we recommend using a GPU for this vision-language model to get satisfactory performance.
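+
+Once the local server is up, you can send it a quick test request over plain HTTP. This is a minimal sketch that assumes the photon listens on the default local port (8080) and serves the handler at the `/run` path; adjust the URL if your setup differs:
+
+```python
+import requests
+
+# Assumption: the local photon listens on port 8080 and exposes run() at /run.
+resp = requests.post(
+    "http://localhost:8080/run",
+    json={
+        "demo_images": ["http://images.cocodataset.org/val2017/000000039769.jpg"],
+        "demo_texts": ["An image of two cats."],
+        "query_image": "http://images.cocodataset.org/test-stuff2017/000000028352.jpg",
+        "query_text": "An image of",
+    },
+)
+print(resp.json())
+```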
+
+## Launch Flamingo inference service in the cloud
+
+Similar to other examples, you can run Flamingo in the cloud with the following commands:
+
+```shell
+lep photon create -n flamingo -m photon.py
+lep photon push -n flamingo
+lep photon run \
+  -n flamingo \
+  --resource-shape gpu.a10
+```
+
+Then visit [dashboard.lepton.ai](https://dashboard.lepton.ai/) to try out the model.
+
+Note: by default, the server is protected by a token, so you won't be able to access the gradio UI. This is by design to provide adequate security. If you want to make the UI public, you can either add the `--public` argument to `lep photon run`, or update the deployment with:
+
+```shell
+lep deployment update -n flamingo --public
+```
+
+### Client
+
+Once the inference service is up (either locally or in the cloud), you can use the client to access it programmatically:
+
+```python
+from leptonai.client import Client
+
+client = Client(...)
+
+inputs = {
+    "demo_images": [
+        "http://images.cocodataset.org/val2017/000000039769.jpg",
+        "http://images.cocodataset.org/test-stuff2017/000000028137.jpg"
+    ],
+    "demo_texts": [
+        "An image of two cats.",
+        "An image of a bathroom sink."
+    ],
+    "query_image": "http://images.cocodataset.org/test-stuff2017/000000028352.jpg",
+    "query_text": "An image of"
+}
+res = client.run(**inputs)
+
+print(inputs["query_text"] + res)
+```
+
+```
+An image of a buffet table.
+```
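+
+To fill in the `Client(...)` call above, you typically point the client either at the local endpoint or at your cloud deployment. The snippet below is a sketch only; the workspace id, token, and local port are placeholders to replace with your own values, and you should check the leptonai client documentation if the constructor in your SDK version differs:
+
+```python
+from leptonai.client import Client
+
+# Local photon (placeholder port -- use wherever `lep photon run` is serving):
+client = Client("http://localhost:8080")
+
+# Cloud deployment (placeholder workspace id and API token):
+client = Client("my-workspace-id", "flamingo", token="MY_LEPTON_API_TOKEN")
+```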
diff --git a/advanced/flamingo/photon.py b/advanced/flamingo/photon.py
new file mode 100644
index 0000000..86dc30a
--- /dev/null
+++ b/advanced/flamingo/photon.py
@@ -0,0 +1,123 @@
+import base64
+from io import BytesIO
+
+from typing import List, Union
+
+from leptonai.photon import Photon, FileParam, HTTPException
+
+
+class Flamingo(Photon):
+    requirement_dependency = ["open-flamingo", "huggingface-hub", "Pillow", "requests"]
+
+    IMAGE_TOKEN = "<image>"
+    END_OF_TEXT_TOKEN = "<|endofchunk|>"
+
+    def init(self):
+        from open_flamingo import create_model_and_transforms
+        from huggingface_hub import hf_hub_download
+        import torch
+
+        if torch.cuda.is_available():
+            self.device = "cuda"
+        else:
+            self.device = "cpu"
+
+        # Build the OpenFlamingo model from a CLIP ViT-L/14 vision encoder and
+        # an MPT-1B language model, then load the OpenFlamingo-3B checkpoint.
+        model, image_processor, tokenizer = create_model_and_transforms(
+            clip_vision_encoder_path="ViT-L-14",
+            clip_vision_encoder_pretrained="openai",
+            lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",
+            tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
+            cross_attn_every_n_layers=1,
+        )
+
+        checkpoint_path = hf_hub_download(
+            "openflamingo/OpenFlamingo-3B-vitl-mpt1b", "checkpoint.pt"
+        )
+        model.load_state_dict(torch.load(checkpoint_path), strict=False)
+        model = model.to(self.device)
+
+        tokenizer.padding_side = "left"
+
+        self.model = model
+        self.image_processor = image_processor
+        self.tokenizer = tokenizer
+
+    def _img_param_to_img(self, param):
+        from PIL import Image
+        import requests
+
+        if isinstance(param, FileParam):
+            content = param.file.read()
+        elif isinstance(param, str):
+            if param.startswith("http://") or param.startswith("https://"):
+                content = requests.get(param).content
+            else:
+                # Otherwise treat the string as a base64-encoded image.
+                content = base64.b64decode(param)
+        else:
+            raise TypeError(f"Invalid image type: {type(param)}")
+
+        return Image.open(BytesIO(content))
+
+    @Photon.handler(
+        example={
+            "demo_images": [
+                "http://images.cocodataset.org/val2017/000000039769.jpg",
+                "http://images.cocodataset.org/test-stuff2017/000000028137.jpg",
+            ],
+            "demo_texts": ["An image of two cats.", "An image of a bathroom sink."],
+            "query_image": (
+                "http://images.cocodataset.org/test-stuff2017/000000028352.jpg"
+            ),
+            "query_text": "An image of",
+        },
+    )
+    def run(
+        self,
+        demo_images: List[Union[FileParam, str]],
+        demo_texts: List[str],
+        query_image: Union[FileParam, str],
+        query_text: str,
+        max_new_tokens: int = 32,
+        num_beams: int = 3,
+    ) -> str:
+        import torch
+
+        if len(demo_images) != len(demo_texts):
+            raise HTTPException(
+                status_code=400,
+                detail="The number of demo images and demo texts must be the same.",
+            )
+
+        demo_images = [self._img_param_to_img(img) for img in demo_images]
+        query_image = self._img_param_to_img(query_image)
+
+        # Stack the demo images and the query image into the
+        # (batch, num_images, num_frames, channels, height, width) layout
+        # that OpenFlamingo expects.
+        vision_x = [
+            self.image_processor(img).unsqueeze(0).to(self.device)
+            for img in (demo_images + [query_image])
+        ]
+        vision_x = torch.cat(vision_x, dim=0)
+        vision_x = vision_x.unsqueeze(1).unsqueeze(0)
+
+        # Interleave an <image> token before each caption, separating the
+        # few-shot examples (and the query) with the end-of-chunk token.
+        lang_x_text = self.END_OF_TEXT_TOKEN.join(
+            f"{self.IMAGE_TOKEN}{text}" for text in (demo_texts + [query_text])
+        )
+        lang_x = self.tokenizer(
+            lang_x_text,
+            return_tensors="pt",
+        )
+
+        generated_text = self.model.generate(
+            vision_x=vision_x,
+            lang_x=lang_x["input_ids"].to(self.device),
+            attention_mask=lang_x["attention_mask"].to(self.device),
+            max_new_tokens=max_new_tokens,
+            num_beams=num_beams,
+        )
+        generated_text = self.tokenizer.decode(generated_text[0])
+
+        # Return only the newly generated continuation: drop the prompt prefix
+        # and any trailing end-of-chunk token.
+        if generated_text.startswith(lang_x_text):
+            generated_text = generated_text[len(lang_x_text) :]
+        if generated_text.endswith(self.END_OF_TEXT_TOKEN):
+            generated_text = generated_text[: -len(self.END_OF_TEXT_TOKEN)]
+
+        return generated_text
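
As a quick illustration of the few-shot prompt that `run()` assembles, the standalone sketch below reproduces the string construction for the README example (pure string manipulation, no model or GPU required):

```python
# Mirrors the prompt construction inside Flamingo.run() for the README example.
IMAGE_TOKEN = "<image>"
END_OF_TEXT_TOKEN = "<|endofchunk|>"

demo_texts = ["An image of two cats.", "An image of a bathroom sink."]
query_text = "An image of"

prompt = END_OF_TEXT_TOKEN.join(
    f"{IMAGE_TOKEN}{text}" for text in (demo_texts + [query_text])
)
print(prompt)
# <image>An image of two cats.<|endofchunk|><image>An image of a bathroom sink.<|endofchunk|><image>An image of
```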