Skip to content

Commit

Permalink
Add AzureOpenAiTextToSpeechDriver (#1150)
Browse files Browse the repository at this point in the history
  • Loading branch information
vachillo authored Sep 9, 2024
1 parent 9735d88 commit 5b56867
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/docs-integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ jobs:
AZURE_OPENAI_API_KEY_2: ${{ secrets.INTEG_AZURE_OPENAI_API_KEY_2 }}
AZURE_OPENAI_ENDPOINT_3: ${{ secrets.INTEG_AZURE_OPENAI_ENDPOINT_3 }}
AZURE_OPENAI_API_KEY_3: ${{ secrets.INTEG_AZURE_OPENAI_API_KEY_3 }}
AZURE_OPENAI_ENDPOINT_4: ${{ secrets.INTEG_AZURE_OPENAI_ENDPOINT_4 }}
AZURE_OPENAI_API_KEY_4: ${{ secrets.INTEG_AZURE_OPENAI_API_KEY_4 }}
AZURE_OPENAI_35_TURBO_16K_DEPLOYMENT_ID: ${{ secrets.INTEG_OPENAI_35_TURBO_16K_DEPLOYMENT_ID }}
AZURE_OPENAI_35_TURBO_DEPLOYMENT_ID: ${{ secrets.INTEG_OPENAI_35_TURBO_DEPLOYMENT_ID }}
AZURE_OPENAI_DAVINCI_DEPLOYMENT_ID: ${{ secrets.INTEG_OPENAI_DAVINCI_DEPLOYMENT_ID }}
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Parameter `meta: dict` on `BaseEvent`.
- `AzureOpenAiTextToSpeechDriver`.

### Changed
- **BREAKING**: Drivers, Loaders, and Engines now raise exceptions rather than returning `ErrorArtifacts`.
Expand Down
20 changes: 20 additions & 0 deletions docs/griptape-framework/drivers/src/text_to_speech_drivers_3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import os

from griptape.drivers import AzureOpenAiTextToSpeechDriver
from griptape.engines import TextToSpeechEngine
from griptape.structures import Agent
from griptape.tools.text_to_speech.tool import TextToSpeechTool

# Driver for a text-to-speech model hosted in Azure OpenAI. The API key and
# endpoint are taken from the environment; a missing variable raises KeyError.
driver = AzureOpenAiTextToSpeechDriver(
    api_key=os.environ["AZURE_OPENAI_API_KEY_4"],
    model="tts",
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_4"],
)

# The tool does not take the driver directly: it goes through an engine.
engine = TextToSpeechEngine(text_to_speech_driver=driver)
tool = TextToSpeechTool(engine=engine)

# Give the agent the tool and ask it to produce speech for a short phrase.
agent = Agent(tools=[tool])
agent.run("Generate audio from this text: 'Hello, world!'")
8 changes: 8 additions & 0 deletions docs/griptape-framework/drivers/text-to-speech-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,11 @@ The [OpenAI Text to Speech Driver](../../reference/griptape/drivers/text_to_spee
```python
--8<-- "docs/griptape-framework/drivers/src/text_to_speech_drivers_2.py"
```

## Azure OpenAI

The [Azure OpenAI Text to Speech Driver](../../reference/griptape/drivers/text_to_speech/azure_openai_text_to_speech_driver.md) provides support for text-to-speech models hosted in your Azure OpenAI instance. This Driver supports configurations specific to OpenAI, like voice selection and output format.

```python
--8<-- "docs/griptape-framework/drivers/src/text_to_speech_drivers_3.py"
```
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@
from .text_to_speech.dummy_text_to_speech_driver import DummyTextToSpeechDriver
from .text_to_speech.elevenlabs_text_to_speech_driver import ElevenLabsTextToSpeechDriver
from .text_to_speech.openai_text_to_speech_driver import OpenAiTextToSpeechDriver
from .text_to_speech.azure_openai_text_to_speech_driver import AzureOpenAiTextToSpeechDriver

from .structure_run.base_structure_run_driver import BaseStructureRunDriver
from .structure_run.griptape_cloud_structure_run_driver import GriptapeCloudStructureRunDriver
Expand Down Expand Up @@ -227,6 +228,7 @@
"DummyTextToSpeechDriver",
"ElevenLabsTextToSpeechDriver",
"OpenAiTextToSpeechDriver",
"AzureOpenAiTextToSpeechDriver",
"BaseStructureRunDriver",
"GriptapeCloudStructureRunDriver",
"LocalStructureRunDriver",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from __future__ import annotations

from typing import Callable, Optional

import openai
from attrs import Factory, define, field

from griptape.drivers import OpenAiTextToSpeechDriver


@define
class AzureOpenAiTextToSpeechDriver(OpenAiTextToSpeechDriver):
    """Text to Speech Driver backed by an Azure OpenAi deployment.

    Subclass of `OpenAiTextToSpeechDriver` whose `client` defaults to an
    `openai.AzureOpenAI` instance assembled from the Azure settings below.

    Attributes:
        model: Model name. Defaults to `"tts"`.
        azure_deployment: Azure OpenAi deployment id. Falls back to `model` when not given.
        azure_endpoint: Azure OpenAi endpoint. Required.
        azure_ad_token: Optional Azure Active Directory token. Marked non-serializable.
        azure_ad_token_provider: Optional callable returning an Azure Active Directory token.
            Marked non-serializable.
        api_version: Azure OpenAi API version. Defaults to `"2024-07-01-preview"`.
        client: An `openai.AzureOpenAI` client. When not supplied, it is built from
            `organization`, `api_key` and the Azure fields above.
    """

    model: str = field(kw_only=True, default="tts", metadata={"serializable": True})
    # Declared after `model` so the factory can read the already-initialised value.
    azure_deployment: str = field(
        default=Factory(lambda driver: driver.model, takes_self=True),
        kw_only=True,
        metadata={"serializable": True},
    )
    azure_endpoint: str = field(metadata={"serializable": True}, kw_only=True)
    azure_ad_token: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    azure_ad_token_provider: Optional[Callable[[], str]] = field(
        default=None,
        kw_only=True,
        metadata={"serializable": False},
    )
    api_version: str = field(kw_only=True, default="2024-07-01-preview", metadata={"serializable": True})
    # Declared last: the factory reads every setting defined above.
    client: openai.AzureOpenAI = field(
        default=Factory(
            lambda driver: openai.AzureOpenAI(
                organization=driver.organization,
                api_key=driver.api_key,
                api_version=driver.api_version,
                azure_endpoint=driver.azure_endpoint,
                azure_deployment=driver.azure_deployment,
                azure_ad_token=driver.azure_ad_token,
                azure_ad_token_provider=driver.azure_ad_token_provider,
            ),
            takes_self=True,
        ),
    )
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from unittest.mock import Mock

import pytest

from griptape.drivers import AzureOpenAiTextToSpeechDriver


class TestAzureOpenAiTextToSpeechDriver:
    """Unit tests for `AzureOpenAiTextToSpeechDriver`."""

    @pytest.fixture()
    def mock_speech_create(self, mocker):
        """Patch `openai.AzureOpenAI` and return its mocked `audio.speech.create`.

        The mocked call returns an object whose `content` is `b"speech"`.
        """
        mock_speech_create = mocker.patch("openai.AzureOpenAI").return_value.audio.speech.create
        mock_speech_create.return_value = Mock(content=b"speech")

        return mock_speech_create

    def test_init(self):
        """An explicit deployment id is kept; otherwise it defaults to the model name."""
        driver = AzureOpenAiTextToSpeechDriver(azure_endpoint="foobar", azure_deployment="foobar")
        assert driver.azure_deployment == "foobar"
        assert AzureOpenAiTextToSpeechDriver(azure_endpoint="foobar").azure_deployment == "tts"

    def test_run_text_to_audio(self, mock_speech_create):
        """Prompts are joined with ". " and sent once with the driver's model, format and voice."""
        driver = AzureOpenAiTextToSpeechDriver(azure_endpoint="foobar")
        output = driver.run_text_to_audio(["foo", "bar"])
        mock_speech_create.assert_called_once_with(
            input="foo. bar",
            model=driver.model,
            response_format=driver.format,
            voice=driver.voice,
        )
        assert output.value == b"speech"

0 comments on commit 5b56867

Please sign in to comment.