Add Flamingo example
bddppq committed Sep 13, 2023
1 parent 7a5102c commit 42ce69e
Showing 2 changed files with 189 additions and 0 deletions.
66 changes: 66 additions & 0 deletions advanced/flamingo/README.md
# Flamingo

[Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model) is a family of general-purpose visual language models that can be applied to image and video understanding tasks with only a handful of task-specific examples. In this example we run Flamingo with the open-source [open-flamingo](https://github.com/mlfoundations/open_flamingo) implementation on Lepton.

## Install the Lepton SDK
```shell
pip install leptonai
```

## Launch Flamingo inference service locally

Run:
```shell
lep photon run -n flamingo -m photon.py
```
Although the model can run on CPU, we recommend running it on a GPU for noticeably better performance.
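
Once the service is running, the `run` handler is exposed as a plain HTTP endpoint, so you can sanity-check it without the Python client. A minimal sketch, assuming the default local port 8080 and the `/run` path (use whatever address `lep photon run` prints if yours differs):

```python
import requests

# Assumption: the local photon server listens on http://localhost:8080 and the
# `run` handler is served at /run; adjust to the address printed at startup.
resp = requests.post(
    "http://localhost:8080/run",
    json={
        "demo_images": ["http://images.cocodataset.org/val2017/000000039769.jpg"],
        "demo_texts": ["An image of two cats."],
        "query_image": "http://images.cocodataset.org/test-stuff2017/000000028352.jpg",
        "query_text": "An image of",
    },
)
print(resp.json())
```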

## Launch Flamingo inference service in the cloud

Similar to other examples, you can deploy Flamingo to the cloud with the following commands.

```shell
lep photon create -n flamingo -m photon.py
lep photon push -n flamingo
lep photon run \
-n flamingo \
--resource-shape gpu.a10
```

And visit [dashboard.lepton.ai](https://dashboard.lepton.ai/) to try out the model.

Note: by default, the server is protected by a token, so you won't be able to access the Gradio UI directly. This is by design to provide adequate security. If you want to make the UI public, you can either add the `--public` argument to `lep photon run`, or update the deployment with:

```shell
lep deployment update -n flamingo --public
```
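
Alternatively, you can keep the deployment token-protected and pass the token when constructing the client (see the Client section below). A minimal sketch, assuming placeholder values for the workspace id, deployment name, and token, which you can find on the dashboard:

```python
from leptonai.client import Client

# Placeholders: substitute your own workspace id, deployment name, and token.
client = Client("my-workspace", "flamingo", token="MY_TOKEN")
```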

### Client

Once the inference service is up (either locally or in the cloud), you can use the client to access it programmatically:

```python
from leptonai.client import Client

client = Client(...)

inputs = {
    "demo_images": [
        "http://images.cocodataset.org/val2017/000000039769.jpg",
        "http://images.cocodataset.org/test-stuff2017/000000028137.jpg",
    ],
    "demo_texts": [
        "An image of two cats.",
        "An image of a bathroom sink.",
    ],
    "query_image": "http://images.cocodataset.org/test-stuff2017/000000028352.jpg",
    "query_text": "An image of",
}
res = client.run(**inputs)

print(inputs["query_text"] + res)
```

```
An image of a buffet table.
```
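
Besides URLs, the handler also accepts images passed as base64-encoded strings (see `_img_param_to_img` in `photon.py`). A sketch under the assumption that a local file named `cat.jpg` exists and that the client is constructed with placeholder credentials as above:

```python
import base64

from leptonai.client import Client

# Placeholders: substitute your own workspace id, deployment name, and token.
client = Client("my-workspace", "flamingo", token="MY_TOKEN")

# Read a local image and base64-encode it; the photon decodes it back to bytes.
with open("cat.jpg", "rb") as f:
    cat_b64 = base64.b64encode(f.read()).decode("utf-8")

res = client.run(
    demo_images=[cat_b64],
    demo_texts=["An image of a cat."],
    query_image="http://images.cocodataset.org/test-stuff2017/000000028352.jpg",
    query_text="An image of",
)
print(res)
```
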
123 changes: 123 additions & 0 deletions advanced/flamingo/photon.py
import base64
from io import BytesIO

from typing import List, Union

from leptonai.photon import Photon, FileParam, HTTPException


class Flamingo(Photon):
    # pip packages installed into the photon's environment before init() runs.
    requirement_dependency = ["open-flamingo", "huggingface-hub", "Pillow", "requests"]

    IMAGE_TOKEN = "<image>"
    END_OF_TEXT_TOKEN = "<|endofchunk|>"

    def init(self):
        from open_flamingo import create_model_and_transforms
        from huggingface_hub import hf_hub_download
        import torch

        # Run on GPU when available; CPU works but is much slower.
        if torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"

        model, image_processor, tokenizer = create_model_and_transforms(
            clip_vision_encoder_path="ViT-L-14",
            clip_vision_encoder_pretrained="openai",
            lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",
            tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
            cross_attn_every_n_layers=1,
        )

        # Download the OpenFlamingo-3B checkpoint from the Hugging Face Hub
        # and load the weights.
        checkpoint_path = hf_hub_download(
            "openflamingo/OpenFlamingo-3B-vitl-mpt1b", "checkpoint.pt"
        )
        model.load_state_dict(torch.load(checkpoint_path), strict=False)
        model = model.to(self.device)

        tokenizer.padding_side = "left"

        self.model = model
        self.image_processor = image_processor
        self.tokenizer = tokenizer

    def _img_param_to_img(self, param):
        from PIL import Image
        import requests

        if isinstance(param, FileParam):
            content = param.file.read()
        elif isinstance(param, str):
            if param.startswith("http://") or param.startswith("https://"):
                content = requests.get(param).content
            else:
                # Treat other strings as base64-encoded image bytes; keep the
                # decoded result as bytes for BytesIO below.
                content = base64.b64decode(param)
        else:
            raise TypeError(f"Invalid image type: {type(param)}")

        return Image.open(BytesIO(content))

    @Photon.handler(
        example={
            "demo_images": [
                "http://images.cocodataset.org/val2017/000000039769.jpg",
                "http://images.cocodataset.org/test-stuff2017/000000028137.jpg",
            ],
            "demo_texts": ["An image of two cats.", "An image of a bathroom sink."],
            "query_image": (
                "http://images.cocodataset.org/test-stuff2017/000000028352.jpg"
            ),
            "query_text": "An image of",
        },
    )
    def run(
        self,
        demo_images: List[Union[FileParam, str]],
        demo_texts: List[str],
        query_image: Union[FileParam, str],
        query_text: str,
        max_new_tokens: int = 32,
        num_beams: int = 3,
    ) -> str:
        import torch

        if len(demo_images) != len(demo_texts):
            raise HTTPException(
                status_code=400,
                detail="The number of demo images and demo texts must be the same.",
            )

        demo_images = [self._img_param_to_img(img) for img in demo_images]
        query_image = self._img_param_to_img(query_image)

        # Stack the demo images and the query image into a tensor of shape
        # (batch=1, num_images, frames=1, channels, height, width).
        vision_x = [
            self.image_processor(img).unsqueeze(0).to(self.device)
            for img in (demo_images + [query_image])
        ]
        vision_x = torch.cat(vision_x, dim=0)
        vision_x = vision_x.unsqueeze(1).unsqueeze(0)

        # Interleave one <image> token with each caption, separated by the
        # end-of-chunk token, ending with the open-ended query text.
        lang_x_text = self.END_OF_TEXT_TOKEN.join(
            f"{self.IMAGE_TOKEN}{text}" for text in (demo_texts + [query_text])
        )
        lang_x = self.tokenizer(
            lang_x_text,
            return_tensors="pt",
        )

        generated_text = self.model.generate(
            vision_x=vision_x,
            lang_x=lang_x["input_ids"].to(self.device),
            attention_mask=lang_x["attention_mask"].to(self.device),
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
        )
        generated_text = self.tokenizer.decode(generated_text[0])

        # Return only the newly generated continuation: strip the prompt
        # prefix and a trailing end-of-chunk token if present.
        if generated_text.startswith(lang_x_text):
            generated_text = generated_text[len(lang_x_text) :]
        if generated_text.endswith(self.END_OF_TEXT_TOKEN):
            generated_text = generated_text[: -len(self.END_OF_TEXT_TOKEN)]

        return generated_text
