Skip to content

Commit

Permalink
update stablediffusion_model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
cmgzn committed Sep 12, 2024
1 parent a5f31e2 commit 992695d
Showing 1 changed file with 158 additions and 89 deletions.
247 changes: 158 additions & 89 deletions src/agentscope/models/stablediffusion_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,25 @@
"""Model wrapper for stable diffusion models."""
from abc import ABC
import base64
import json
import time
from typing import Any, Optional, Union, List, Sequence

import requests
from loguru import logger

from . import ModelWrapperBase, ModelResponse
from ..constants import _DEFAULT_MAX_RETRIES
from ..constants import _DEFAULT_RETRY_INTERVAL
from ..message import Msg
from ..manager import FileManager
import requests
from ..utils.common import _convert_to_str


class StableDiffusionWrapperBase(ModelWrapperBase, ABC):
"""The base class for stable-diffusion model wrappers.
To use SD API, please
To use SD-webui API, please
1. First download stable-diffusion-webui from https://github.com/AUTOMATIC1111/stable-diffusion-webui and
install it with 'webui-user.bat'
2. Move your checkpoint to 'models/Stable-diffusion' folder
Expand All @@ -23,77 +29,176 @@ class StableDiffusionWrapperBase(ModelWrapperBase, ABC):
query the available parameters on the http://localhost:7860/docs page
"""

model_type: str
"""The type of the model wrapper, which is to identify the model wrapper
class in model configuration."""

options: dict
"""A dict contains the options for stable-diffusion option API.
Modifications made through this parameter are persistent, meaning they will
remain in effect for subsequent generation requests until explicitly changed or reset.
e.g. {"sd_model_checkpoint": "Anything-V3.0-pruned", "CLIP_stop_at_last_layers": 2}"""
model_type: str = "stable_diffusion"

def __init__(
self,
config_name: str,
options: dict = None,
host: str = "127.0.0.1:7860",
base_url: Optional[Union[str, None]] = None,
use_https: bool = False,
generate_args: dict = None,
url: Optional[Union[str, None]] = None,
headers: dict = None,
options: dict = None,
timeout: int = 30,
max_retries: int = _DEFAULT_MAX_RETRIES,
retry_interval: int = _DEFAULT_RETRY_INTERVAL,
**kwargs: Any,
) -> None:
"""Initialize the model wrapper for SD-webui API.
"""
Initializes the SD-webui API client.
Args:
options (`dict`, default `None`):
config_name (`str`):
The name of the model config.
host (`str`, default `"127.0.0.1:7860"`):
The host port of the stable-diffusion webui server.
base_url (`str`, default `None`):
Base URL for the stable-diffusion webui services. If not provided, it will be generated based on `host` and `use_https`.
use_https (`bool`, default `False`):
Whether to generate the base URL with HTTPS protocol or HTTP.
generate_args (`dict`, default `None`):
The extra keyword arguments used in SD api generation,
e.g. `{"steps": 50}`.
headers (`dict`, default `None`):
HTTP request headers.
options (`dict`, default `None`):
The keyword arguments to change the webui settings
such as model or CLIP skip, this changes will persist across sessions.
e.g. `{"sd_model_checkpoint": "Anything-V3.0-pruned", "CLIP_stop_at_last_layers": 2}`.
generate_args (`dict`, default `None`):
The extra keyword arguments used in SD-webui api generation,
e.g. `steps`, `seed`.
url (`str`, default `None`):
The url of the SD-webui server.
Defaults to `None`, which is http://127.0.0.1:7860.
"""
if url is None:
url = "http://127.0.0.1:7860"
# If base_url is not provided, construct it based on whether HTTPS is used
if base_url is None:
if use_https:
base_url = f"https://{host}"
else:
base_url = f"http://{host}"

self.url = url
self.base_url = base_url
self.options_url = f"{base_url}/sdapi/v1/options"
self.generate_args = generate_args or {}

options_url = f"{self.url}/sdapi/v1/options"
# Get the current default model
default_model_name = (
requests.get(options_url)
.json()["sd_model_checkpoint"]
.split("[")[0]
.strip()
# Initialize the HTTP session and update the request headers
self.session = requests.Session()
if headers:
self.session.headers.update(headers)

# Set options if provided
if options:
self._set_options(options)

# Get the default model name from the web-options
model_name = self._get_options()["sd_model_checkpoint"].split("[")[0].strip()
# Update the model name if override_settings is provided in generate_args
if self.generate_args.get("override_settings"):
model_name = generate_args["override_settings"].get(
"sd_model_checkpoint", model_name
)

super().__init__(config_name=config_name, model_name=model_name)

self.timeout = timeout
self.max_retries = max_retries
self.retry_interval = retry_interval

@property
def url(self):
"""SD-webui API endpoint URL"""
raise NotImplementedError()

def _get_options(self) -> dict:
response = self.session.get(url=self.options_url)
if response.status_code != 200:
logger.error(f"Failed to get options with {response.json()}")
raise RuntimeError(f"Failed to get options with {response.json()}")
return response.json()

def _set_options(self, options) -> None:
response = self.session.post(url=self.options_url, json=options)
if response.status_code != 200:
logger.error(json.dumps(options, indent=4))
raise RuntimeError(f"Failed to set options with {response.json()}")
else:
logger.info("Optionsset successfully")

def _invoke_model(self, payload: dict) -> dict:
"""Invoke SD webui API and record the invocation if needed"""
# step1: prepare post requests
for i in range(1, self.max_retries + 1):
response = self.session.post(url=self.url, json=payload)

if response.status_code == requests.codes.ok:
break

if i < self.max_retries:
logger.warning(
f"Failed to call the model with "
f"requests.codes == {response.status_code}, retry "
f"{i + 1}/{self.max_retries} times",
)
time.sleep(i * self.retry_interval)

# step2: record model invocation
# record the model api invocation, which will be skipped if
# `FileManager.save_api_invocation` is `False`
self._save_model_invocation(
arguments=payload,
response=response.json(),
)

if options is not None:
# Update webui options if needed
requests.post(options_url, json=options)
model_name = options.get("sd_model_checkpoint", default_model_name)
# step3: return the response json
if response.status_code == requests.codes.ok:
return response.json()
else:
model_name = default_model_name
logger.error(json.dumps({"url": self.url, "json": payload}, indent=4))
raise RuntimeError(
f"Failed to call the model with {response.json()}",
)

super().__init__(config_name=config_name, model_name=model_name)
def _parse_response(self, response: dict) -> ModelResponse:
"""Parse the response json data into ModelResponse"""
return ModelResponse(raw=response)

def __call__(self, **kwargs: Any) -> ModelResponse:
payload = {
**self.generate_args,
**kwargs,
}
response = self._invoke_model(payload)
return self._parse_response(response)

def format(
self,
*args: Union[Msg, Sequence[Msg]],
) -> Union[List[dict], str]:
raise RuntimeError(
f"Model Wrapper [{type(self).__name__}] doesn't "
f"need to format the input. Please try to use the "
f"model wrapper directly.",
)


class StableDiffusionTxt2imgWrapper(StableDiffusionWrapperBase):
"""Stable Diffusion txt2img API wrapper"""

model_type: str = "sd_txt2img"

@property
def url(self):
return f"{self.base_url}/sdapi/v1/txt2img"

def _parse_response(self, response: dict) -> ModelResponse:
session_parameters = response["parameters"]
size = f"{session_parameters['width']}*{session_parameters['height']}"
image_count = session_parameters["batch_size"] * session_parameters["n_iter"]

self.monitor.update_image_tokens(
model_name=self.model_name,
image_count=image_count,
resolution=size,
)

# Get image base64code as a list
images = response["images"]
b64_images = [base64.b64decode(image) for image in images]

file_manager = FileManager.get_instance()
# Return local url
image_urls = [file_manager.save_image(_) for _ in b64_images]
text = "Image saved to " + "\n".join(image_urls)
return ModelResponse(text=text, image_urls=image_urls, raw=response)

def __call__(
self,
prompt: str,
Expand All @@ -109,13 +214,11 @@ def __call__(
https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API#api-guide-by-kilvoctu
or http://localhost:7860/docs
for more detailed arguments.
Returns:
`ModelResponse`:
A list of image local urls in image_urls field and the
raw response in raw field.
"""

# step1: prepare keyword arguments
payload = {
"prompt": prompt,
Expand All @@ -124,49 +227,15 @@ def __call__(
}

# step2: forward to generate response
txt2img_url = f"{self.url}/sdapi/v1/txt2img"
response = requests.post(url=txt2img_url, json=payload)

if response.status_code != requests.codes.ok:
error_msg = f" Status code: {response.status_code},"
raise RuntimeError(error_msg)

# step3: record the model api invocation if needed
output = response.json()
self._save_model_invocation(
arguments={
"model": self.model_name,
**payload,
},
response=output,
)

# step4: update monitor accordingly
session_parameters = output["parameters"]
size = f"{session_parameters['width']}*{session_parameters['height']}"
image_count = session_parameters["batch_size"] * session_parameters["n_iter"]
response = self._invoke_model(payload)

self.monitor.update_image_tokens(
model_name=self.model_name,
image_count=image_count,
resolution=size,
)

# step5: return response
# Get image base64code as a list
images = output["images"]
b64_images = [base64.b64decode(image) for image in images]

file_manager = FileManager.get_instance()
# Return local url
urls = [file_manager.save_image(_) for _ in b64_images]
text = "Image saved to " + "\n".join(urls)
return ModelResponse(text=text, image_urls=urls, raw=response)
# step3: parse the response
return self._parse_response(response)

def format(self, *args: Msg | Sequence[Msg]) -> List[dict] | str:
# This is a temporary implementation to focus on the prompt
# on single-turn image generation by preserving only the system prompt and
# the last user message. This logic might change in the future to support
# This is a temporary implementation to focus on the prompt
# on single-turn image generation by preserving only the system prompt and
# the last user message. This logic might change in the future to support
# more complex conversational scenarios
if len(args) == 0:
raise ValueError(
Expand Down Expand Up @@ -204,7 +273,7 @@ def format(self, *args: Msg | Sequence[Msg]) -> List[dict] | str:

content_components = []
# Add system prompt at the beginning if provided
if sys_prompt is not None:
if sys_prompt:
content_components.append(sys_prompt)
# Add the last user message if the user messages is not empty
if len(user_messages) > 0:
Expand Down

0 comments on commit 992695d

Please sign in to comment.