Add Replicate demo and API #6

Open · wants to merge 1 commit into main
README.md (3 changes: 2 additions & 1 deletion)
@@ -1,11 +1,12 @@
# HART: Efficient Visual Generation with Hybrid Autoregressive Transformer

\[[Paper](https://arxiv.org/abs/2410.10812)\] \[[Demo](https://hart.mit.edu)\] \[[Project](https://hanlab.mit.edu/projects/hart)\]
\[[Paper](https://arxiv.org/abs/2410.10812)\] \[[Demo](https://hart.mit.edu)\] \[[Project](https://hanlab.mit.edu/projects/hart)\] \[[Replicate Demo & API](https://replicate.com/chenxwh/hart)\]

![teaser_Page1](assets/teaser.jpg)

## News

- \[2024/10\] 🔥 Added the Replicate Demo and API [![Replicate](https://replicate.com/chenxwh/hart/badge)](https://replicate.com/chenxwh/hart)!
- \[2024/10\] 🔥 We open source the inference code and [Gradio demo](https://hart.mit.edu) for HART!

## Abstract
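As an aside (not part of the diff), a minimal sketch of calling the hosted model from Python once it is live on Replicate. The `replicate` client package and a `REPLICATE_API_TOKEN` environment variable are assumptions, and the input names simply mirror the `predict()` signature added in `predict.py` below.

```python
# Hypothetical client-side usage of the Replicate API (not part of this PR).
# Assumes `pip install replicate` and REPLICATE_API_TOKEN set in the environment.
import replicate

output = replicate.run(
    "chenxwh/hart",
    input={
        "prompt": "An astronaut riding a horse on the moon, oil painting by Van Gogh.",
        "guidance_scale": 4.5,
        "more_smooth": True,
        "seed": 42,
    },
)
print(output)  # URL (or file handle) pointing at the generated PNG
```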
cog.yaml (new file: 43 additions & 0 deletions)
@@ -0,0 +1,43 @@
# Configuration for Cog ⚙️
# Reference: https://cog.run/yaml

build:
  # set to true if your model requires a GPU
  gpu: true

  # a list of ubuntu apt packages to install
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"

  # python version in the form '3.11' or '3.11.4'
  python_version: "3.11"

  # a list of packages in the format <package-name>==<version>
  python_packages:
    - "torch==2.3.0"
    - "torchvision==0.18.0"
    - "transformers==4.42.2"
    - "tokenizers>=0.15.2"
    - "sentencepiece==0.2.0"
    - "shortuuid"
    - "accelerate==0.27.2"
    - "numpy==1.26.4"
    - "scikit-learn==1.2.2"
    - "einops==0.6.1"
    - "einops-exts==0.0.4"
    - "timm==0.9.12"
    - "openpyxl==3.1.2"
    - "nltk==3.3"
    - "opencv-python==4.8.0.74"
    - "omegaconf==2.3.0"
    - "diffusers==0.28.2"
    - "xformers==0.0.26.post1"
    - "pydantic==1.10.7"

  # commands run after the environment is setup
  run:
    - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget

# predict.py defines how predictions are run on your model
predict: "predict.py:Predictor"
predict.py (new file: 170 additions & 0 deletions)
@@ -0,0 +1,170 @@
# Prediction interface for Cog ⚙️
# https://cog.run/python

import os
import subprocess
import time
import sys

os.environ["PYTHONDONTWRITEBYTECODE"] = "0"

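# Build the kernels shipped in hart/kernels and make the resulting hart_backend
# package importable before the HART modules below are imported.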
original_dir = os.getcwd()
try:
    subprocess.run(["python", "setup.py", "install"], cwd="hart/kernels")
    import pkg_resources

    dist = pkg_resources.get_distribution("hart_backend")
    egg_path = dist.location
    sys.path.append(egg_path)
finally:
    os.chdir(original_dir)

import numpy as np
import torch
from PIL import Image
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
)
from cog import BasePredictor, Input, Path

from hart.modules.models.transformer.hart_transformer_t2i import (
    HARTForT2IConfig,
    HARTForT2I,
)
from hart.utils import encode_prompts, llm_system_prompt, safety_check

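# Register the HART text-to-image transformer with the transformers Auto classes so
# AutoModel.from_pretrained below can instantiate it from its cached config.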
AutoConfig.register("hart_transformer_t2i", HARTForT2IConfig)
AutoModel.register(HARTForT2IConfig, HARTForT2I)


# cache files from mit-han-lab/Qwen2-VL-1.5B-Instruct, mit-han-lab/hart-0.7b-1024px, and google/shieldgemma-2b
MODEL_CACHE = "model_cache"
MODEL_URL = (
f"https://weights.replicate.delivery/default/mit-han-lab/hart/{MODEL_CACHE}.tar"
)

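# Point the Hugging Face and Torch caches at the local weights bundle and force
# offline mode so no network calls are made at prediction time.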
os.environ.update(
    {
        "HF_DATASETS_OFFLINE": "1",
        "TRANSFORMERS_OFFLINE": "1",
        "HF_HOME": MODEL_CACHE,
        "TORCH_HOME": MODEL_CACHE,
        "HF_DATASETS_CACHE": MODEL_CACHE,
        "TRANSFORMERS_CACHE": MODEL_CACHE,
        "HUGGINGFACE_HUB_CACHE": MODEL_CACHE,
    }
)


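# Download the weights tarball with pget and extract it into `dest` (-x untars the archive).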
def download_weights(url, dest):
    start = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
    print("downloading took: ", time.time() - start)


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""

        if not os.path.exists(MODEL_CACHE):
            download_weights(MODEL_URL, MODEL_CACHE)

        model_path = f"{MODEL_CACHE}/mit-han-lab/hart-0.7b-1024px/llm"
        self.model = AutoModel.from_pretrained(
            model_path, torch_dtype=torch.float16
        ).to("cuda")
        self.model.eval()
        # use_ema by default
        self.model.load_state_dict(
            torch.load(os.path.join(model_path, "ema_model.bin"))
        )

        text_model_path = f"{MODEL_CACHE}/mit-han-lab/Qwen2-VL-1.5B-Instruct"
        self.text_tokenizer = AutoTokenizer.from_pretrained(text_model_path)
        self.text_model = AutoModel.from_pretrained(
            text_model_path, torch_dtype=torch.float16
        ).to("cuda")
        self.text_model.eval()

        shield_model_path = f"{MODEL_CACHE}/google/shieldgemma-2b"
        self.safety_checker_tokenizer = AutoTokenizer.from_pretrained(shield_model_path)
        self.safety_checker_model = AutoModelForCausalLM.from_pretrained(
            shield_model_path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        ).to("cuda")
        self.safety_checker_model.eval()

    def predict(
        self,
        prompt: str = Input(
            description="Input prompt",
            default="An astronaut riding a horse on the moon, oil painting by Van Gogh.",
        ),
        max_token_length: int = Input(default=300),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance", ge=1, le=20, default=4.5
        ),
        more_smooth: bool = Input(
            description="Turn on for more visually smooth samples", default=True
        ),
        use_llm_system_prompt: bool = Input(default=True),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> Path:
        """Run a single prediction on the model"""

        assert not safety_check.is_dangerous(
            self.safety_checker_tokenizer, self.safety_checker_model, prompt
), f"The prompt id not pass the safety checker, please use a different prompt."

        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")
        generator = torch.Generator().manual_seed(seed)

        with torch.inference_mode():
            with torch.autocast(
                "cuda", enabled=True, dtype=torch.float16, cache_enabled=True
            ):

                (
                    context_tokens,
                    context_mask,
                    context_position_ids,
                    context_tensor,
                ) = encode_prompts(
                    [prompt],
                    self.text_model,
                    self.text_tokenizer,
                    max_token_length,
                    llm_system_prompt,
                    use_llm_system_prompt,
                )

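                # Generate the image autoregressively from the encoded text context,
                # applying classifier-free guidance at the requested scale.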
                infer_func = self.model.autoregressive_infer_cfg

                output_imgs = infer_func(
                    B=context_tensor.size(0),
                    label_B=context_tensor,
                    cfg=guidance_scale,
                    g_seed=seed,
                    more_smooth=more_smooth,
                    context_position_ids=context_position_ids,
                    context_mask=context_mask,
                )

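        # Scale the output tensor (assumed to be in [0, 1]) to 8-bit values and
        # reorder CHW -> HWC before handing it to PIL.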
        sample_imgs_np = output_imgs.clone().mul_(255).cpu().numpy()
        cur_img = sample_imgs_np[0]
        cur_img = cur_img.transpose(1, 2, 0).astype(np.uint8)
        cur_img_store = Image.fromarray(cur_img)

        out_path = "/tmp/out.png"
        cur_img_store.save(out_path)
        return Path(out_path)
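For completeness, a hypothetical local smoke test of this predictor (not part of the PR). It assumes a CUDA GPU, a checkout of this repository with the files above in place, and network access so the weights bundle can be fetched on first run.

```python
# Hypothetical local check, mirroring what Cog does when serving predictions.
from predict import Predictor

predictor = Predictor()
predictor.setup()  # downloads/extracts model_cache on first run, then loads the models
image_path = predictor.predict(
    prompt="A watercolor painting of a lighthouse at dawn.",
    max_token_length=300,
    guidance_scale=4.5,
    more_smooth=True,
    use_llm_system_prompt=True,
    seed=42,
)
print(image_path)  # /tmp/out.png
```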