Skip to content

Commit

Permalink
Add the yandex art with examples (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
vhaldemar authored Nov 8, 2024
1 parent 48043dc commit 628e166
Show file tree
Hide file tree
Showing 19 changed files with 1,416 additions and 12 deletions.
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ repos:
rev: v4.1.0
hooks:
- id: check-added-large-files
args: ['--maxkb=10000']
- id: check-case-conflict
- id: check-merge-conflict
- id: check-symlinks
Expand Down Expand Up @@ -41,6 +42,7 @@ repos:
rev: v2.1.0
hooks:
- id: codespell
args: ["--skip=*.ipynb,*.json"]

- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.9.0
Expand Down
Empty file.
239 changes: 239 additions & 0 deletions examples/async/image_generation/run_deferred.ipynb

Large diffs are not rendered by default.

51 changes: 51 additions & 0 deletions examples/async/image_generation/run_deferred.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#!/usr/bin/env python3
from __future__ import annotations

import asyncio
import pathlib

from yandex_cloud_ml_sdk import AsyncYCloudML


async def main() -> None:
sdk = AsyncYCloudML(folder_id='b1ghsjum2v37c2un8h64')

model = sdk.models.image_generation('yandex-art')

# configuring model for all of future runs
model = model.configure(width_ratio=1, height_ratio=2, seed=50)

# simple run
operation = await model.run_deferred('a red cat')
result = await operation
print(result)

# run with a several messages and with saving image to file
path = pathlib.Path('image.jpeg')
try:
operation = await model.run_deferred(['a red cat', 'Miyazaki style'])
result = await operation
path.write_bytes(result.image_bytes)
finally:
path.unlink(missing_ok=True)

# example of several messages with a weight
operation = await model.run_deferred([{'text': 'a red cat', 'weight': 5}, 'Miyazaki style'])
result = await operation
print(result)

# example of using yandexgpt and yandex-art models together
gpt = sdk.models.completions('yandexgpt')
messages = await gpt.run([
'you need to create a prompt for a yandexart model',
'of a cat in a Miyazaki style'
])
print(messages)

operation = await model.run_deferred(messages)
result = await operation
print(result)


if __name__ == '__main__':
asyncio.run(main())
Empty file.
235 changes: 235 additions & 0 deletions examples/sync/image_generation/run_deferred.ipynb

Large diffs are not rendered by default.

50 changes: 50 additions & 0 deletions examples/sync/image_generation/run_deferred.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#!/usr/bin/env python3
from __future__ import annotations

import pathlib

from yandex_cloud_ml_sdk import YCloudML


def main() -> None:
sdk = YCloudML(folder_id='b1ghsjum2v37c2un8h64')

model = sdk.models.image_generation('yandex-art')

# configuring model for all of future runs
model = model.configure(width_ratio=1, height_ratio=2, seed=50)

# simple run
operation = model.run_deferred('a red cat')
result = operation.wait()
print(result)

# run with a several messages and with saving image to file
path = pathlib.Path('image.jpeg')
try:
operation = model.run_deferred(['a red cat', 'Miyazaki style'])
result = operation.wait()
path.write_bytes(result.image_bytes)
finally:
path.unlink(missing_ok=True)

# example of several messages with a weight
operation = model.run_deferred([{'text': 'a red cat', 'weight': 5}, 'Miyazaki style'])
result = operation.wait()
print(result)

# example of using yandexgpt and yandex-art models together
gpt = sdk.models.completions('yandexgpt')
messages = gpt.run([
'you need to create a prompt for a yandexart model',
'of a cat in a Miyazaki style'
])
print(messages)

operation = model.run_deferred(messages)
result = operation.wait()
print(result)


if __name__ == '__main__':
main()
3 changes: 3 additions & 0 deletions src/yandex_cloud_ml_sdk/_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from yandex_cloud_ml_sdk._types.function import BaseFunction

from .completions.function import AsyncCompletions, BaseCompletions, Completions
from .image_generation.function import AsyncImageGeneration, ImageGeneration
from .text_classifiers.function import AsyncTextClassifiers, TextClassifiers
from .text_embeddings.function import AsyncTextEmbeddings, TextEmbeddings

Expand All @@ -34,9 +35,11 @@ class AsyncModels(BaseModels):
completions: AsyncCompletions
text_embeddings: AsyncTextEmbeddings
text_classifiers: AsyncTextClassifiers
image_generation: AsyncImageGeneration


class Models(BaseModels):
completions: Completions
text_embeddings: TextEmbeddings
text_classifiers: TextClassifiers
image_generation: ImageGeneration
8 changes: 0 additions & 8 deletions src/yandex_cloud_ml_sdk/_models/completions/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,6 @@ async def _run_deferred(
result_type=self._result_type,
)

@override
def attach_async(self, operation_id: str) -> OperationTypeT:
return self._operation_type(
id=operation_id,
sdk=self._sdk,
result_type=self._result_type,
)

async def _tokenize(
self,
messages: MessageInputType,
Expand Down
Empty file.
13 changes: 13 additions & 0 deletions src/yandex_cloud_ml_sdk/_models/image_generation/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from __future__ import annotations

from dataclasses import dataclass

from yandex_cloud_ml_sdk._types.model_config import BaseModelConfig


@dataclass(frozen=True)
class ImageGenerationModelConfig(BaseModelConfig):
seed: int | None = None
width_ratio: int | None = None
height_ratio: int | None = None
mime_type: str | None = None
35 changes: 35 additions & 0 deletions src/yandex_cloud_ml_sdk/_models/image_generation/function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations

from typing_extensions import override

from yandex_cloud_ml_sdk._types.function import BaseModelFunction, ModelTypeT

from .model import AsyncImageGenerationModel, ImageGenerationModel


class BaseImageGeneration(BaseModelFunction[ModelTypeT]):
@override
def __call__(
self,
model_name: str,
*,
model_version: str = 'latest',
):
if '://' in model_name:
uri = model_name
else:
folder_id = self._sdk._folder_id
uri = f'art://{folder_id}/{model_name}/{model_version}'

return self._model_type(
sdk=self._sdk,
uri=uri,
)


class ImageGeneration(BaseImageGeneration):
_model_type = ImageGenerationModel


class AsyncImageGeneration(BaseImageGeneration):
_model_type = AsyncImageGenerationModel
66 changes: 66 additions & 0 deletions src/yandex_cloud_ml_sdk/_models/image_generation/message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, Protocol, TypedDict, Union, cast, runtime_checkable

from typing_extensions import NotRequired
# pylint: disable-next=no-name-in-module
from yandex.cloud.ai.foundation_models.v1.image_generation.image_generation_pb2 import Message as ProtoMessage


@dataclass(frozen=True)
class ImageMessage:
text: str
weight: float | None = None


class ImageMessageDict(TypedDict):
text: str
weight: NotRequired[float]


# NB: it supports _messages.message.Message and _models.completions.message.TextMessage
@runtime_checkable
class AnyMessage(Protocol):
text: str


ImageMessageType = Union[ImageMessage, ImageMessageDict, AnyMessage, str]
ImageMessageInputType = Union[ImageMessageType, Iterable[ImageMessageType]]


def messages_to_proto(messages: ImageMessageInputType) -> list[ProtoMessage]:
msgs: Iterable[ImageMessageType]
if isinstance(messages, (dict, str, ImageMessage, AnyMessage)):
# NB: dict is also Iterable, so techically messages could be a dict[MessageType, str]
# and we are wrongly will get into this branch.
# At least mypy thinks so.
# In real life, messages could be list, tuple, iterator, generator but nothing else.

msgs = [messages] # type: ignore[list-item]
else:
msgs = messages

result: list[ProtoMessage] = []

for message in msgs:
kwargs: ImageMessageDict
if isinstance(message, str):
kwargs = {'text': message}
elif isinstance(message, AnyMessage):
kwargs = {'text': message.text}
if weight := getattr(message, 'weight', None):
kwargs['weight'] = weight
elif isinstance(message, dict) and 'text' in message:
kwargs = cast(ImageMessageDict, message)
else:
raise TypeError(
f'{message=} should be str, {{["weight": float], "text": str}} dict '
'or object with a `text: str` field'
)

result.append(
ProtoMessage(**kwargs)
)

return result
110 changes: 110 additions & 0 deletions src/yandex_cloud_ml_sdk/_models/image_generation/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# pylint: disable=arguments-renamed,no-name-in-module
from __future__ import annotations

from typing_extensions import Self, override
from yandex.cloud.ai.foundation_models.v1.image_generation.image_generation_pb2 import (
AspectRatio, ImageGenerationOptions
)
from yandex.cloud.ai.foundation_models.v1.image_generation.image_generation_service_pb2 import ImageGenerationRequest
from yandex.cloud.ai.foundation_models.v1.image_generation.image_generation_service_pb2_grpc import (
ImageGenerationAsyncServiceStub
)
from yandex.cloud.operation.operation_pb2 import Operation as ProtoOperation

from yandex_cloud_ml_sdk._types.misc import UNDEFINED, UndefinedOr
from yandex_cloud_ml_sdk._types.model import ModelAsyncMixin, OperationTypeT
from yandex_cloud_ml_sdk._types.operation import AsyncOperation, Operation
from yandex_cloud_ml_sdk._utils.sync import run_sync

from .config import ImageGenerationModelConfig
from .message import ImageMessageInputType, messages_to_proto
from .result import ImageGenerationModelResult


class BaseImageGenerationModel(
ModelAsyncMixin[ImageGenerationModelConfig, ImageGenerationModelResult, OperationTypeT],
):
_config_type = ImageGenerationModelConfig
_result_type = ImageGenerationModelResult
_operation_type: type[OperationTypeT]

# pylint: disable=useless-parent-delegation,arguments-differ
def configure( # type: ignore[override]
self,
*,
seed: UndefinedOr[int] = UNDEFINED,
width_ratio: UndefinedOr[int] = UNDEFINED,
height_ratio: UndefinedOr[int] = UNDEFINED,
mime_type: UndefinedOr[str] = UNDEFINED,
) -> Self:
return super().configure(
seed=seed,
width_ratio=width_ratio,
height_ratio=height_ratio,
mime_type=mime_type,
)

@override
# pylint: disable-next=arguments-differ
async def _run_deferred(
self,
messages: ImageMessageInputType,
*,
timeout=60
) -> OperationTypeT:
request = ImageGenerationRequest(
model_uri=self._uri,
messages=messages_to_proto(messages),
generation_options=ImageGenerationOptions(
mime_type=self.config.mime_type or '',
seed=self.config.seed or 0,
aspect_ratio=AspectRatio(
width_ratio=self.config.width_ratio or 0,
height_ratio=self.config.height_ratio or 0,
)
)
)
async with self._client.get_service_stub(ImageGenerationAsyncServiceStub, timeout=timeout) as stub:
response = await self._client.call_service(
stub.Generate,
request,
timeout=timeout,
expected_type=ProtoOperation,
)
return self._operation_type(
id=response.id,
sdk=self._sdk,
result_type=self._result_type
)


class AsyncImageGenerationModel(BaseImageGenerationModel[AsyncOperation[ImageGenerationModelResult]]):
_operation_type = AsyncOperation[ImageGenerationModelResult]

async def run_deferred(
self,
messages: ImageMessageInputType,
*,
timeout: float = 60,
) -> AsyncOperation[ImageGenerationModelResult]:
return await self._run_deferred(
messages=messages,
timeout=timeout
)


class ImageGenerationModel(BaseImageGenerationModel[Operation[ImageGenerationModelResult]]):
_operation_type = Operation[ImageGenerationModelResult]
__run_deferred = run_sync(BaseImageGenerationModel[Operation[ImageGenerationModelResult]]._run_deferred)

def run_deferred(
self,
messages: ImageMessageInputType,
*,
timeout: float = 60,
) -> Operation[ImageGenerationModelResult]:
# Mypy thinks that self.__run_deferred returns OperationTypeT instead of Operation
return self.__run_deferred( # type: ignore[return-value]
messages=messages,
timeout=timeout
)
Loading

0 comments on commit 628e166

Please sign in to comment.