Add Flamingo example #32

Merged 1 commit on Sep 13, 2023

advanced/flamingo/README.md (new file, +66 lines)

# Flamingo

[Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model) is an effective and efficient general-purpose family of models that can be applied to image and video understanding tasks with minimal task-specific examples. In this example, we run Flamingo on Lepton using [open-flamingo](https://github.com/mlfoundations/open_flamingo).

## Install the Lepton SDK
```shell
pip install leptonai
```

## Launch Flamingo inference service locally

Run:
```shell
lep photon run -n flamingo -m photon.py
```
Although the model can run on a CPU, we recommend using a GPU for vision models to get satisfactory performance.
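
Once the local service is up, you can send a quick test request to it. The sketch below is a minimal example, assuming the photon is served on the default local address `http://localhost:8080` and that `Client` accepts a plain URL; adjust the address if your setup differs.

```python
from leptonai.client import Client

# Assumption: `lep photon run` serves the photon locally on port 8080.
client = Client("http://localhost:8080")

res = client.run(
    demo_images=["http://images.cocodataset.org/val2017/000000039769.jpg"],
    demo_texts=["An image of two cats."],
    query_image="http://images.cocodataset.org/test-stuff2017/000000028352.jpg",
    query_text="An image of",
)
print(res)
```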

## Launch Flamingo inference service in the cloud

Similar to the other examples, you can run Flamingo in the cloud with the following commands:

```shell
lep photon create -n flamingo -m photon.py
lep photon push -n flamingo
lep photon run \
  -n flamingo \
  --resource-shape gpu.a10
```

Then visit [dashboard.lepton.ai](https://dashboard.lepton.ai/) to try out the model.

Note: by default, the server is protected by a token, so you won't be able to access the Gradio UI directly. This is by design to provide adequate security. If you want to make the UI public, you can either add the `--public` argument to `lep photon run`, or update the deployment with:

```shell
lep deployment update -n flamingo --public
```
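
If you keep the deployment token-protected, requests must carry the token. Below is a rough sketch using plain HTTP; the deployment URL is a hypothetical placeholder, and the `/run` path and bearer-token header reflect common Photon conventions rather than values verified against your deployment (the Python client shown in the next section can typically take the token directly when it is constructed).

```python
import requests

# Assumptions: the deployment URL and token come from the Lepton dashboard,
# the `run` handler is exposed at POST /run, and auth uses a bearer token.
DEPLOYMENT_URL = "https://flamingo.example.lepton.run"  # hypothetical URL
TOKEN = "YOUR_DEPLOYMENT_TOKEN"

resp = requests.post(
    f"{DEPLOYMENT_URL}/run",
    headers={"Authorization": f"Bearer {TOKEN}"},
    json={
        "demo_images": ["http://images.cocodataset.org/val2017/000000039769.jpg"],
        "demo_texts": ["An image of two cats."],
        "query_image": "http://images.cocodataset.org/test-stuff2017/000000028352.jpg",
        "query_text": "An image of",
    },
)
print(resp.json())
```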

### Client

Once the inference service is up (either locally or in the cloud), you can use the client to access it programmatically:

```python
from leptonai.client import Client

client = Client(...)

inputs = {
    "demo_images": [
        "http://images.cocodataset.org/val2017/000000039769.jpg",
        "http://images.cocodataset.org/test-stuff2017/000000028137.jpg",
    ],
    "demo_texts": [
        "An image of two cats.",
        "An image of a bathroom sink.",
    ],
    "query_image": "http://images.cocodataset.org/test-stuff2017/000000028352.jpg",
    "query_text": "An image of",
}
res = client.run(**inputs)

print(inputs["query_text"] + res)
```

```
An image of a buffet table.
```
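
Besides URLs, the photon also accepts images as base64-encoded bytes passed as plain strings (see `_img_param_to_img` in `photon.py`), so you can send local files without hosting them anywhere. A minimal sketch, with placeholder file paths and the same local-URL client assumption as above:

```python
import base64

from leptonai.client import Client

def encode_image(path: str) -> str:
    # The photon treats non-URL strings as base64-encoded image bytes.
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

# Assumption: local service on port 8080; point this at your deployment instead.
client = Client("http://localhost:8080")

res = client.run(
    demo_images=[encode_image("cats.jpg"), encode_image("sink.jpg")],  # placeholder paths
    demo_texts=["An image of two cats.", "An image of a bathroom sink."],
    query_image=encode_image("query.jpg"),
    query_text="An image of",
)
print(res)
```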

advanced/flamingo/photon.py (new file, +123 lines)

import base64
from io import BytesIO

from typing import List, Union

from leptonai.photon import Photon, FileParam, HTTPException


class Flamingo(Photon):
    # Pip packages to install alongside the photon.
    requirement_dependency = ["open-flamingo", "huggingface-hub", "Pillow", "requests"]

    IMAGE_TOKEN = "<image>"
    END_OF_TEXT_TOKEN = "<|endofchunk|>"

    def init(self):
        from open_flamingo import create_model_and_transforms
        from huggingface_hub import hf_hub_download
        import torch

        if torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"

        model, image_processor, tokenizer = create_model_and_transforms(
            clip_vision_encoder_path="ViT-L-14",
            clip_vision_encoder_pretrained="openai",
            lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",
            tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
            cross_attn_every_n_layers=1,
        )

        # Download the OpenFlamingo-3B checkpoint and load its weights on top of
        # the CLIP vision encoder and MPT language backbone.
        checkpoint_path = hf_hub_download(
            "openflamingo/OpenFlamingo-3B-vitl-mpt1b", "checkpoint.pt"
        )
        model.load_state_dict(torch.load(checkpoint_path), strict=False)
        model = model.to(self.device)

        tokenizer.padding_side = "left"

        self.model = model
        self.image_processor = image_processor
        self.tokenizer = tokenizer

    def _img_param_to_img(self, param):
        from PIL import Image
        import requests

        if isinstance(param, FileParam):
            content = param.file.read()
        elif isinstance(param, str):
            if param.startswith("http://") or param.startswith("https://"):
                content = requests.get(param).content
            else:
                # Treat non-URL strings as base64-encoded image bytes.
                content = base64.b64decode(param)
        else:
            raise TypeError(f"Invalid image type: {type(param)}")

        return Image.open(BytesIO(content))

    @Photon.handler(
        example={
            "demo_images": [
                "http://images.cocodataset.org/val2017/000000039769.jpg",
                "http://images.cocodataset.org/test-stuff2017/000000028137.jpg",
            ],
            "demo_texts": ["An image of two cats.", "An image of a bathroom sink."],
            "query_image": (
                "http://images.cocodataset.org/test-stuff2017/000000028352.jpg"
            ),
            "query_text": "An image of",
        },
    )
    def run(
        self,
        demo_images: List[Union[FileParam, str]],
        demo_texts: List[str],
        query_image: Union[FileParam, str],
        query_text: str,
        max_new_tokens: int = 32,
        num_beams: int = 3,
    ) -> str:
        import torch

        if len(demo_images) != len(demo_texts):
            raise HTTPException(
                status_code=400,
                detail="The number of demo images and demo texts must be the same.",
            )

        demo_images = [self._img_param_to_img(img) for img in demo_images]
        query_image = self._img_param_to_img(query_image)

        # Stack the demo images and the query image into a single vision tensor of
        # shape (batch, num_images, num_frames, channels, height, width).
        vision_x = [
            self.image_processor(img).unsqueeze(0).to(self.device)
            for img in (demo_images + [query_image])
        ]
        vision_x = torch.cat(vision_x, dim=0)
        vision_x = vision_x.unsqueeze(1).unsqueeze(0)

        # Interleave one <image> token per image with its caption, separating the
        # examples with the end-of-chunk token; the query text is left open-ended.
        lang_x_text = self.END_OF_TEXT_TOKEN.join(
            f"{self.IMAGE_TOKEN}{text}" for text in (demo_texts + [query_text])
        )
        lang_x = self.tokenizer(
            lang_x_text,
            return_tensors="pt",
        )

        generated_text = self.model.generate(
            vision_x=vision_x,
            lang_x=lang_x["input_ids"].to(self.device),
            attention_mask=lang_x["attention_mask"].to(self.device),
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
        )
        generated_text = self.tokenizer.decode(generated_text[0])

        # Strip the echoed prompt and the trailing end-of-chunk token so only the
        # newly generated completion is returned.
        if generated_text.startswith(lang_x_text):
            generated_text = generated_text[len(lang_x_text) :]
        if generated_text.endswith(self.END_OF_TEXT_TOKEN):
            generated_text = generated_text[: -len(self.END_OF_TEXT_TOKEN)]

        return generated_text
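
For reference, the prompt that `run` assembles for the README example looks like the output below; this snippet just re-derives it outside the photon to make the interleaving of image tokens and captions explicit.

```python
IMAGE_TOKEN = "<image>"
END_OF_TEXT_TOKEN = "<|endofchunk|>"

demo_texts = ["An image of two cats.", "An image of a bathroom sink."]
query_text = "An image of"

# One <image> token per image, each demo caption closed by <|endofchunk|>,
# and the query text left open for the model to complete.
prompt = END_OF_TEXT_TOKEN.join(
    f"{IMAGE_TOKEN}{text}" for text in (demo_texts + [query_text])
)
print(prompt)
# <image>An image of two cats.<|endofchunk|><image>An image of a bathroom sink.<|endofchunk|><image>An image of
```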