-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add the yandex art with examples (#24)
- Loading branch information
Showing
19 changed files
with
1,416 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
#!/usr/bin/env python3 | ||
from __future__ import annotations | ||
|
||
import asyncio | ||
import pathlib | ||
|
||
from yandex_cloud_ml_sdk import AsyncYCloudML | ||
|
||
|
||
async def main() -> None: | ||
sdk = AsyncYCloudML(folder_id='b1ghsjum2v37c2un8h64') | ||
|
||
model = sdk.models.image_generation('yandex-art') | ||
|
||
# configuring model for all of future runs | ||
model = model.configure(width_ratio=1, height_ratio=2, seed=50) | ||
|
||
# simple run | ||
operation = await model.run_deferred('a red cat') | ||
result = await operation | ||
print(result) | ||
|
||
# run with a several messages and with saving image to file | ||
path = pathlib.Path('image.jpeg') | ||
try: | ||
operation = await model.run_deferred(['a red cat', 'Miyazaki style']) | ||
result = await operation | ||
path.write_bytes(result.image_bytes) | ||
finally: | ||
path.unlink(missing_ok=True) | ||
|
||
# example of several messages with a weight | ||
operation = await model.run_deferred([{'text': 'a red cat', 'weight': 5}, 'Miyazaki style']) | ||
result = await operation | ||
print(result) | ||
|
||
# example of using yandexgpt and yandex-art models together | ||
gpt = sdk.models.completions('yandexgpt') | ||
messages = await gpt.run([ | ||
'you need to create a prompt for a yandexart model', | ||
'of a cat in a Miyazaki style' | ||
]) | ||
print(messages) | ||
|
||
operation = await model.run_deferred(messages) | ||
result = await operation | ||
print(result) | ||
|
||
|
||
if __name__ == '__main__': | ||
asyncio.run(main()) |
Empty file.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
#!/usr/bin/env python3 | ||
from __future__ import annotations | ||
|
||
import pathlib | ||
|
||
from yandex_cloud_ml_sdk import YCloudML | ||
|
||
|
||
def main() -> None: | ||
sdk = YCloudML(folder_id='b1ghsjum2v37c2un8h64') | ||
|
||
model = sdk.models.image_generation('yandex-art') | ||
|
||
# configuring model for all of future runs | ||
model = model.configure(width_ratio=1, height_ratio=2, seed=50) | ||
|
||
# simple run | ||
operation = model.run_deferred('a red cat') | ||
result = operation.wait() | ||
print(result) | ||
|
||
# run with a several messages and with saving image to file | ||
path = pathlib.Path('image.jpeg') | ||
try: | ||
operation = model.run_deferred(['a red cat', 'Miyazaki style']) | ||
result = operation.wait() | ||
path.write_bytes(result.image_bytes) | ||
finally: | ||
path.unlink(missing_ok=True) | ||
|
||
# example of several messages with a weight | ||
operation = model.run_deferred([{'text': 'a red cat', 'weight': 5}, 'Miyazaki style']) | ||
result = operation.wait() | ||
print(result) | ||
|
||
# example of using yandexgpt and yandex-art models together | ||
gpt = sdk.models.completions('yandexgpt') | ||
messages = gpt.run([ | ||
'you need to create a prompt for a yandexart model', | ||
'of a cat in a Miyazaki style' | ||
]) | ||
print(messages) | ||
|
||
operation = model.run_deferred(messages) | ||
result = operation.wait() | ||
print(result) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
13 changes: 13 additions & 0 deletions
13
src/yandex_cloud_ml_sdk/_models/image_generation/config.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
|
||
from yandex_cloud_ml_sdk._types.model_config import BaseModelConfig | ||
|
||
|
||
@dataclass(frozen=True) | ||
class ImageGenerationModelConfig(BaseModelConfig): | ||
seed: int | None = None | ||
width_ratio: int | None = None | ||
height_ratio: int | None = None | ||
mime_type: str | None = None |
35 changes: 35 additions & 0 deletions
35
src/yandex_cloud_ml_sdk/_models/image_generation/function.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from __future__ import annotations | ||
|
||
from typing_extensions import override | ||
|
||
from yandex_cloud_ml_sdk._types.function import BaseModelFunction, ModelTypeT | ||
|
||
from .model import AsyncImageGenerationModel, ImageGenerationModel | ||
|
||
|
||
class BaseImageGeneration(BaseModelFunction[ModelTypeT]): | ||
@override | ||
def __call__( | ||
self, | ||
model_name: str, | ||
*, | ||
model_version: str = 'latest', | ||
): | ||
if '://' in model_name: | ||
uri = model_name | ||
else: | ||
folder_id = self._sdk._folder_id | ||
uri = f'art://{folder_id}/{model_name}/{model_version}' | ||
|
||
return self._model_type( | ||
sdk=self._sdk, | ||
uri=uri, | ||
) | ||
|
||
|
||
class ImageGeneration(BaseImageGeneration): | ||
_model_type = ImageGenerationModel | ||
|
||
|
||
class AsyncImageGeneration(BaseImageGeneration): | ||
_model_type = AsyncImageGenerationModel |
66 changes: 66 additions & 0 deletions
66
src/yandex_cloud_ml_sdk/_models/image_generation/message.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
from typing import Iterable, Protocol, TypedDict, Union, cast, runtime_checkable | ||
|
||
from typing_extensions import NotRequired | ||
# pylint: disable-next=no-name-in-module | ||
from yandex.cloud.ai.foundation_models.v1.image_generation.image_generation_pb2 import Message as ProtoMessage | ||
|
||
|
||
@dataclass(frozen=True) | ||
class ImageMessage: | ||
text: str | ||
weight: float | None = None | ||
|
||
|
||
class ImageMessageDict(TypedDict): | ||
text: str | ||
weight: NotRequired[float] | ||
|
||
|
||
# NB: it supports _messages.message.Message and _models.completions.message.TextMessage | ||
@runtime_checkable | ||
class AnyMessage(Protocol): | ||
text: str | ||
|
||
|
||
ImageMessageType = Union[ImageMessage, ImageMessageDict, AnyMessage, str] | ||
ImageMessageInputType = Union[ImageMessageType, Iterable[ImageMessageType]] | ||
|
||
|
||
def messages_to_proto(messages: ImageMessageInputType) -> list[ProtoMessage]: | ||
msgs: Iterable[ImageMessageType] | ||
if isinstance(messages, (dict, str, ImageMessage, AnyMessage)): | ||
# NB: dict is also Iterable, so techically messages could be a dict[MessageType, str] | ||
# and we are wrongly will get into this branch. | ||
# At least mypy thinks so. | ||
# In real life, messages could be list, tuple, iterator, generator but nothing else. | ||
|
||
msgs = [messages] # type: ignore[list-item] | ||
else: | ||
msgs = messages | ||
|
||
result: list[ProtoMessage] = [] | ||
|
||
for message in msgs: | ||
kwargs: ImageMessageDict | ||
if isinstance(message, str): | ||
kwargs = {'text': message} | ||
elif isinstance(message, AnyMessage): | ||
kwargs = {'text': message.text} | ||
if weight := getattr(message, 'weight', None): | ||
kwargs['weight'] = weight | ||
elif isinstance(message, dict) and 'text' in message: | ||
kwargs = cast(ImageMessageDict, message) | ||
else: | ||
raise TypeError( | ||
f'{message=} should be str, {{["weight": float], "text": str}} dict ' | ||
'or object with a `text: str` field' | ||
) | ||
|
||
result.append( | ||
ProtoMessage(**kwargs) | ||
) | ||
|
||
return result |
110 changes: 110 additions & 0 deletions
110
src/yandex_cloud_ml_sdk/_models/image_generation/model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# pylint: disable=arguments-renamed,no-name-in-module | ||
from __future__ import annotations | ||
|
||
from typing_extensions import Self, override | ||
from yandex.cloud.ai.foundation_models.v1.image_generation.image_generation_pb2 import ( | ||
AspectRatio, ImageGenerationOptions | ||
) | ||
from yandex.cloud.ai.foundation_models.v1.image_generation.image_generation_service_pb2 import ImageGenerationRequest | ||
from yandex.cloud.ai.foundation_models.v1.image_generation.image_generation_service_pb2_grpc import ( | ||
ImageGenerationAsyncServiceStub | ||
) | ||
from yandex.cloud.operation.operation_pb2 import Operation as ProtoOperation | ||
|
||
from yandex_cloud_ml_sdk._types.misc import UNDEFINED, UndefinedOr | ||
from yandex_cloud_ml_sdk._types.model import ModelAsyncMixin, OperationTypeT | ||
from yandex_cloud_ml_sdk._types.operation import AsyncOperation, Operation | ||
from yandex_cloud_ml_sdk._utils.sync import run_sync | ||
|
||
from .config import ImageGenerationModelConfig | ||
from .message import ImageMessageInputType, messages_to_proto | ||
from .result import ImageGenerationModelResult | ||
|
||
|
||
class BaseImageGenerationModel( | ||
ModelAsyncMixin[ImageGenerationModelConfig, ImageGenerationModelResult, OperationTypeT], | ||
): | ||
_config_type = ImageGenerationModelConfig | ||
_result_type = ImageGenerationModelResult | ||
_operation_type: type[OperationTypeT] | ||
|
||
# pylint: disable=useless-parent-delegation,arguments-differ | ||
def configure( # type: ignore[override] | ||
self, | ||
*, | ||
seed: UndefinedOr[int] = UNDEFINED, | ||
width_ratio: UndefinedOr[int] = UNDEFINED, | ||
height_ratio: UndefinedOr[int] = UNDEFINED, | ||
mime_type: UndefinedOr[str] = UNDEFINED, | ||
) -> Self: | ||
return super().configure( | ||
seed=seed, | ||
width_ratio=width_ratio, | ||
height_ratio=height_ratio, | ||
mime_type=mime_type, | ||
) | ||
|
||
@override | ||
# pylint: disable-next=arguments-differ | ||
async def _run_deferred( | ||
self, | ||
messages: ImageMessageInputType, | ||
*, | ||
timeout=60 | ||
) -> OperationTypeT: | ||
request = ImageGenerationRequest( | ||
model_uri=self._uri, | ||
messages=messages_to_proto(messages), | ||
generation_options=ImageGenerationOptions( | ||
mime_type=self.config.mime_type or '', | ||
seed=self.config.seed or 0, | ||
aspect_ratio=AspectRatio( | ||
width_ratio=self.config.width_ratio or 0, | ||
height_ratio=self.config.height_ratio or 0, | ||
) | ||
) | ||
) | ||
async with self._client.get_service_stub(ImageGenerationAsyncServiceStub, timeout=timeout) as stub: | ||
response = await self._client.call_service( | ||
stub.Generate, | ||
request, | ||
timeout=timeout, | ||
expected_type=ProtoOperation, | ||
) | ||
return self._operation_type( | ||
id=response.id, | ||
sdk=self._sdk, | ||
result_type=self._result_type | ||
) | ||
|
||
|
||
class AsyncImageGenerationModel(BaseImageGenerationModel[AsyncOperation[ImageGenerationModelResult]]): | ||
_operation_type = AsyncOperation[ImageGenerationModelResult] | ||
|
||
async def run_deferred( | ||
self, | ||
messages: ImageMessageInputType, | ||
*, | ||
timeout: float = 60, | ||
) -> AsyncOperation[ImageGenerationModelResult]: | ||
return await self._run_deferred( | ||
messages=messages, | ||
timeout=timeout | ||
) | ||
|
||
|
||
class ImageGenerationModel(BaseImageGenerationModel[Operation[ImageGenerationModelResult]]): | ||
_operation_type = Operation[ImageGenerationModelResult] | ||
__run_deferred = run_sync(BaseImageGenerationModel[Operation[ImageGenerationModelResult]]._run_deferred) | ||
|
||
def run_deferred( | ||
self, | ||
messages: ImageMessageInputType, | ||
*, | ||
timeout: float = 60, | ||
) -> Operation[ImageGenerationModelResult]: | ||
# Mypy thinks that self.__run_deferred returns OperationTypeT instead of Operation | ||
return self.__run_deferred( # type: ignore[return-value] | ||
messages=messages, | ||
timeout=timeout | ||
) |
Oops, something went wrong.