Skip to content

Commit

Permalink
fix stablediffusion_model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
cmgzn committed Sep 19, 2024
1 parent b40f41f commit 30fe6f5
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 42 deletions.
6 changes: 3 additions & 3 deletions src/agentscope/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
from .yi_model import (
YiChatWrapper,
)
from .stablediffusion_model import(
StableDiffusionTxt2imgWrapper
from .stablediffusion_model import (
StableDiffusionImageSynthesisWrapper,
)

__all__ = [
Expand All @@ -67,7 +67,7 @@
"ZhipuAIEmbeddingWrapper",
"LiteLLMChatWrapper",
"YiChatWrapper",
"StableDiffusionTxt2imgWrapper",
"StableDiffusionImageSynthesisWrapper",
]


Expand Down
77 changes: 38 additions & 39 deletions src/agentscope/models/stablediffusion_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,21 @@ class StableDiffusionWrapperBase(ModelWrapperBase, ABC):
"""The base class for stable-diffusion model wrappers.
To use SD-webui API, please
1. First download stable-diffusion-webui from https://github.com/AUTOMATIC1111/stable-diffusion-webui and
1. First download stable-diffusion-webui from
https://github.com/AUTOMATIC1111/stable-diffusion-webui and
install it with 'webui-user.bat'
2. Move your checkpoint to 'models/Stable-diffusion' folder
3. Start launch.py with the '--api' parameter to start the server
After that, you can use the SD-webui API and
query the available parameters on the http://localhost:7860/docs page
query the available parameters on the http://localhost:7862/docs page
"""

model_type: str = "stable_diffusion"

def __init__(
self,
config_name: str,
host: str = "127.0.0.1:7860",
host: str = "127.0.0.1:7862",
base_url: Optional[Union[str, None]] = None,
use_https: bool = False,
generate_args: dict = None,
Expand All @@ -51,23 +52,24 @@ def __init__(
Args:
config_name (`str`):
The name of the model config.
host (`str`, default `"127.0.0.1:7860"`):
host (`str`, default `"127.0.0.1:7862"`):
The host port of the stable-diffusion webui server.
base_url (`str`, default `None`):
Base URL for the stable-diffusion webui services. If not provided, it will be generated based on `host` and `use_https`.
Base URL for the stable-diffusion webui services.
Generated from host and use_https if not provided.
use_https (`bool`, default `False`):
Whether to generate the base URL with HTTPS protocol or HTTP.
generate_args (`dict`, default `None`):
generate_args (`dict`, default `None`):
The extra keyword arguments used in SD api generation,
e.g. `{"steps": 50}`.
headers (`dict`, default `None`):
HTTP request headers.
options (`dict`, default `None`):
headers (`dict`, default `None`):
HTTP request headers.
options (`dict`, default `None`):
The keyword arguments to change the webui settings
such as model or CLIP skip, this changes will persist across sessions.
e.g. `{"sd_model_checkpoint": "Anything-V3.0-pruned", "CLIP_stop_at_last_layers": 2}`.
such as model or CLIP skip, this changes will persist.
e.g. `{"sd_model_checkpoint": "Anything-V3.0-pruned"}`.
"""
# If base_url is not provided, construct it based on whether HTTPS is used
# Construct base_url based on HTTPS usage if not provided
if base_url is None:
if use_https:
base_url = f"https://{host}"
Expand All @@ -88,21 +90,24 @@ def __init__(
self._set_options(options)

# Get the default model name from the web-options
model_name = self._get_options()["sd_model_checkpoint"].split("[")[0].strip()
# Update the model name if override_settings is provided in generate_args
model_name = (
self._get_options()["sd_model_checkpoint"].split("[")[0].strip()
)
# Update the model name
if self.generate_args.get("override_settings"):
model_name = generate_args["override_settings"].get(
"sd_model_checkpoint", model_name
"sd_model_checkpoint",
model_name,
)

super().__init__(config_name=config_name, model_name=model_name)

self.timeout = timeout
self.max_retries = max_retries
self.retry_interval = retry_interval

@property
def url(self):
def url(self) -> str:
"""SD-webui API endpoint URL"""
raise NotImplementedError()

Expand All @@ -113,13 +118,12 @@ def _get_options(self) -> dict:
raise RuntimeError(f"Failed to get options with {response.json()}")
return response.json()

def _set_options(self, options) -> None:
def _set_options(self, options: dict) -> None:
response = self.session.post(url=self.options_url, json=options)
if response.status_code != 200:
logger.error(json.dumps(options, indent=4))
raise RuntimeError(f"Failed to set options with {response.json()}")
else:
logger.info("Optionsset successfully")
logger.info("Optionsset successfully")

def _invoke_model(self, payload: dict) -> dict:
"""Invoke SD webui API and record the invocation if needed"""
Expand Down Expand Up @@ -150,7 +154,9 @@ def _invoke_model(self, payload: dict) -> dict:
if response.status_code == requests.codes.ok:
return response.json()
else:
logger.error(json.dumps({"url": self.url, "json": payload}, indent=4))
logger.error(
json.dumps({"url": self.url, "json": payload}, indent=4),
)
raise RuntimeError(
f"Failed to call the model with {response.json()}",
)
Expand All @@ -159,29 +165,22 @@ def _parse_response(self, response: dict) -> ModelResponse:
"""Parse the response json data into ModelResponse"""
return ModelResponse(raw=response)

def __call__(self, **kwargs: Any) -> ModelResponse:
payload = {
**self.generate_args,
**kwargs,
}
response = self._invoke_model(payload)
return self._parse_response(response)



class StableDiffusionTxt2imgWrapper(StableDiffusionWrapperBase):
"""Stable Diffusion txt2img API wrapper"""
class StableDiffusionImageSynthesisWrapper(StableDiffusionWrapperBase):
"""Stable Diffusion Text-to-Image (txt2img) API Wrapper"""

model_type: str = "sd_txt2img"

@property
def url(self):
def url(self) -> str:
return f"{self.base_url}/sdapi/v1/txt2img"

def _parse_response(self, response: dict) -> ModelResponse:
session_parameters = response["parameters"]
size = f"{session_parameters['width']}*{session_parameters['height']}"
image_count = session_parameters["batch_size"] * session_parameters["n_iter"]
image_count = (
session_parameters["batch_size"] * session_parameters["n_iter"]
)

self.monitor.update_image_tokens(
model_name=self.model_name,
Expand Down Expand Up @@ -211,7 +210,7 @@ def __call__(
**kwargs (`Any`):
The keyword arguments to SD-webui txt2img API, e.g.
`n_iter`, `steps`, `seed`, `width`, etc. Please refer to
https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API#api-guide-by-kilvoctu
https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API
or http://localhost:7860/docs
for more detailed arguments.
Returns:
Expand All @@ -234,9 +233,9 @@ def __call__(

def format(self, *args: Msg | Sequence[Msg]) -> List[dict] | str:
# This is a temporary implementation to focus on the prompt
# on single-turn image generation by preserving only the system prompt and
# the last user message. This logic might change in the future to support
# more complex conversational scenarios
# on single-turn image generation by preserving only the system prompt
# and the last user message. This logic might change in the future
# to support more complex conversational scenarios
if len(args) == 0:
raise ValueError(
"At least one message should be provided. An empty message "
Expand Down

0 comments on commit 30fe6f5

Please sign in to comment.