Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Audio transcription support #781

Merged
merged 24 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
e5a87ae
Updates, docs stubs
andrewfrench May 15, 2024
b5e7997
Audio Loader dependencies
andrewfrench May 15, 2024
2c20378
Remove task, prefer ToolTask with Client?
andrewfrench May 15, 2024
5e73eee
Docs and tests
andrewfrench May 21, 2024
c055935
Fix typo
andrewfrench May 21, 2024
55cbf44
poetry lock --no-update
andrewfrench May 21, 2024
a9162cb
Tasks, tests, docs
andrewfrench May 21, 2024
af4d323
TranscriptionClient tests
andrewfrench May 22, 2024
3d306db
Merge branch 'dev' into french/240514/transcription--transcription
andrewfrench May 22, 2024
6715875
Add audio transcription task tests
andrewfrench May 22, 2024
47ebb6a
Fix docs
andrewfrench May 22, 2024
d719edb
poetry run ruff format
andrewfrench May 22, 2024
ae93421
Update changelog
andrewfrench May 23, 2024
91c97ba
Merge branch 'dev' into french/240514/transcription--transcription
andrewfrench May 23, 2024
211eeae
Evaluate callable input at runtime
andrewfrench May 24, 2024
dd00a89
Merge branch 'dev' into french/240514/transcription--transcription
andrewfrench May 24, 2024
34108eb
Fix docs example
andrewfrench May 24, 2024
2906a9b
Update changelog
andrewfrench May 24, 2024
949792c
Naming
andrewfrench May 29, 2024
0e944d0
Merge branch 'dev' into french/240514/transcription--transcription
andrewfrench Jun 3, 2024
09b9e24
poetry lock --no-update
andrewfrench Jun 3, 2024
2c0e261
attr -> attrs
andrewfrench Jun 3, 2024
40692b6
Merge branch 'dev' into french/240514/transcription--transcription
andrewfrench Jun 3, 2024
83e8408
Fix integration test
andrewfrench Jun 3, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- `AzureOpenAiStructureConfig` for providing Structures with all Azure OpenAI Driver configuration.
- `AzureOpenAiVisionImageQueryDriver` to support queries on images using Azure's OpenAI Vision models.
- `AudioLoader` for loading audio content into an `AudioArtifact`.
- `AudioTranscriptionTask` and `AudioTranscriptionClient` for transcribing audio content in Structures.
- `OpenAiAudioTranscriptionDriver` for integration with OpenAI's speech-to-text models, including Whisper.
- Parameter `env` to `BaseStructureRunDriver` to set environment variables for a Structure Run.

### Changed
Expand Down
21 changes: 21 additions & 0 deletions docs/griptape-framework/data/loaders.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,24 @@ loader.load(EmailLoader.EmailQuery(label="INBOX"))

loader.load_collection([EmailLoader.EmailQuery(label="INBOX"), EmailLoader.EmailQuery(label="SENT")])
```

## Audio Loader

!!! info
This driver requires the `loaders-audio` [extra](../index.md#extras).

The [Audio Loader](../../reference/griptape/loaders/audio_loader.md) is used to load audio content as an [AudioArtifact](./artifacts.md#audioartifact). The Loader operates on audio bytes that can be sourced from files on disk, downloaded audio, or audio in memory.

The Loader will load audio in its native format and populates the resulting Artifact's `format` field by making a best-effort guess of the underlying audio format using the `filetype` package.

```python
from griptape.loaders import AudioLoader
from griptape.utils import load_file

# Load an image from disk
with open("tests/resources/sentences.wav", "rb") as f:
audio_artifact = AudioLoader().load(f.read())

# You can also use the load_file utility function
AudioLoader().load(load_file("tests/resources/sentences.wav"))
```
32 changes: 32 additions & 0 deletions docs/griptape-framework/drivers/audio-transcription-drivers.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
## Overview

[Audio Transcription Drivers](../../reference/griptape/drivers/audio_transcription/index.md) extract text from spoken audio.

This driver acts as a critical bridge between audio transcription Engines and the underlying models, facilitating the construction and execution of API calls that transform speech into editable and searchable text. Utilized predominantly in applications that support the input of verbal communications, the Audio Transcription Driver effectively extracts and interprets speech, rendering it into a textual format that can be easily integrated into data systems and Workflows.

This capability is essential for enhancing accessibility, improving content discoverability, and automating tasks that traditionally relied on manual transcription, thereby streamlining operations and enhancing efficiency across various industries.

### OpenAI

The [OpenAI Audio Transcription Driver](../../reference/griptape/drivers/audio_transcription/openai_audio_transcription_driver.md) utilizes OpenAI's sophisticated `whisper` model to accurately transcribe spoken audio into text. This model supports multiple languages, ensuring precise transcription across a wide range of dialects.

```python
from griptape.drivers import OpenAiAudioTranscriptionDriver
from griptape.engines import AudioTranscriptionEngine
from griptape.tools.audio_transcription_client.tool import AudioTranscriptionClient
from griptape.structures import Agent


driver = OpenAiAudioTranscriptionDriver(
model="whisper-1"
)

tool = AudioTranscriptionClient(
off_prompt=False,
engine=AudioTranscriptionEngine(
audio_transcription_driver=driver,
),
)

Agent(tools=[tool]).run("Transcribe the following audio file: tests/resources/sentences.wav")
```
23 changes: 23 additions & 0 deletions docs/griptape-framework/engines/audio-engines.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,26 @@ engine.run(
prompts=["Hello, world!"],
)
```

### Audio Transcription Engine

The [Audio Transcription Engine](../../reference/griptape/engines/audio/audio_transcription_engine.md) facilitates transcribing speech from audio inputs.

```python
from griptape.drivers import OpenAiAudioTranscriptionDriver
from griptape.engines import AudioTranscriptionEngine
from griptape.loaders import AudioLoader
from griptape.utils import load_file


driver = OpenAiAudioTranscriptionDriver(
model="whisper-1"
)

engine = AudioTranscriptionEngine(
audio_transcription_driver=driver,
)

audio_artifact = AudioLoader().load(load_file("tests/resources/sentences.wav"))
engine.run(audio_artifact)
```
54 changes: 54 additions & 0 deletions docs/griptape-framework/structures/tasks.md
Original file line number Diff line number Diff line change
Expand Up @@ -805,3 +805,57 @@ team = Pipeline(

team.run()
```

## Text to Speech Task

This Task enables Structures to synthesize speech from text using [Text to Speech Engines](../../reference/griptape/engines/audio/text_to_speech_engine.md) and [Text to Speech Drivers](../../reference/griptape/drivers/text_to_speech/index.md).

```python
import os

from griptape.drivers import ElevenLabsTextToSpeechDriver
from griptape.engines import TextToSpeechEngine
from griptape.tasks import TextToSpeechTask
from griptape.structures import Pipeline


driver = ElevenLabsTextToSpeechDriver(
api_key=os.getenv("ELEVEN_LABS_API_KEY"),
model="eleven_multilingual_v2",
voice="Matilda",
)

task = TextToSpeechTask(
text_to_speech_engine=TextToSpeechEngine(
text_to_speech_driver=driver,
),
)

Pipeline(tasks=[task]).run("Generate audio from this text: 'Hello, world!'")
```

## Audio Transcription Task

This Task enables Structures to transcribe speech from text using [Audio Transcription Engines](../../reference/griptape/engines/audio/audio_transcription_engine.md) and [Audio Transcription Drivers](../../reference/griptape/drivers/audio_transcription/index.md).

```python
from griptape.drivers import OpenAiAudioTranscriptionDriver
from griptape.engines import AudioTranscriptionEngine
from griptape.loaders import AudioLoader
from griptape.tasks import AudioTranscriptionTask
from griptape.structures import Pipeline
from griptape.utils import load_file

driver = OpenAiAudioTranscriptionDriver(
model="whisper-1"
)

task = AudioTranscriptionTask(
input=lambda _: AudioLoader().load(load_file("tests/resources/sentences2.wav")),
audio_transcription_engine=AudioTranscriptionEngine(
audio_transcription_driver=driver,
),
)

Pipeline(tasks=[task]).run()
```
24 changes: 24 additions & 0 deletions docs/griptape-tools/official-tools/audio-transcription-client.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# AudioTranscriptionClient

This Tool enables [Agents](../../griptape-framework/structures/agents.md) to transcribe speech from text using [Audio Transcription Engines](../../reference/griptape/engines/audio/audio_transcription_engine.md) and [Audio Transcription Drivers](../../reference/griptape/drivers/audio_transcription/index.md).

```python
from griptape.drivers import OpenAiAudioTranscriptionDriver
from griptape.engines import AudioTranscriptionEngine
from griptape.tools.audio_transcription_client.tool import AudioTranscriptionClient
from griptape.structures import Agent


driver = OpenAiAudioTranscriptionDriver(
model="whisper-1"
)

tool = AudioTranscriptionClient(
off_prompt=False,
engine=AudioTranscriptionEngine(
audio_transcription_driver=driver,
),
)

Agent(tools=[tool]).run("Transcribe the following audio file: /Users/andrew/code/griptape/tests/resources/sentences2.wav")
```
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# TextToSpeechClient

This tool enables LLMs to synthesize speech from text using [Text to Speech Engines](../../reference/griptape/engines/audio/text_to_speech_engine.md) and [Text to Speech Drivers](../../reference/griptape/drivers/text_to_speech/index.md).
This Tool enables LLMs to synthesize speech from text using [Text to Speech Engines](../../reference/griptape/engines/audio/text_to_speech_engine.md) and [Text to Speech Drivers](../../reference/griptape/drivers/text_to_speech/index.md).

```python
import os
Expand Down
2 changes: 2 additions & 0 deletions griptape/config/base_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
BasePromptDriver,
BaseVectorStoreDriver,
BaseTextToSpeechDriver,
BaseAudioTranscriptionDriver,
)
from griptape.utils import dict_merge

Expand All @@ -29,6 +30,7 @@ class BaseStructureConfig(BaseConfig, ABC):
default=None, kw_only=True, metadata={"serializable": True}
)
text_to_speech_driver: BaseTextToSpeechDriver = field(kw_only=True, metadata={"serializable": True})
audio_transcription_driver: BaseAudioTranscriptionDriver = field(kw_only=True, metadata={"serializable": True})

def merge_config(self, config: dict) -> BaseStructureConfig:
base_config = self.to_dict()
Expand Down
12 changes: 12 additions & 0 deletions griptape/config/openai_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
OpenAiChatPromptDriver,
OpenAiEmbeddingDriver,
OpenAiImageGenerationDriver,
BaseTextToSpeechDriver,
OpenAiTextToSpeechDriver,
BaseAudioTranscriptionDriver,
OpenAiAudioTranscriptionDriver,
OpenAiImageQueryDriver,
)

Expand Down Expand Up @@ -40,3 +44,11 @@ class OpenAiStructureConfig(StructureConfig):
kw_only=True,
metadata={"serializable": True},
)
text_to_speech_driver: BaseTextToSpeechDriver = field(
default=Factory(lambda: OpenAiTextToSpeechDriver(model="tts")), kw_only=True, metadata={"serializable": True}
)
audio_transcription_driver: BaseAudioTranscriptionDriver = field(
default=Factory(lambda: OpenAiAudioTranscriptionDriver(model="whisper-1")),
kw_only=True,
metadata={"serializable": True},
)
5 changes: 5 additions & 0 deletions griptape/config/structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
BaseImageQueryDriver,
BaseTextToSpeechDriver,
DummyTextToSpeechDriver,
BaseAudioTranscriptionDriver,
DummyAudioTranscriptionDriver,
)


Expand All @@ -43,3 +45,6 @@ class StructureConfig(BaseStructureConfig):
text_to_speech_driver: BaseTextToSpeechDriver = field(
default=Factory(lambda: DummyTextToSpeechDriver()), kw_only=True, metadata={"serializable": True}
)
audio_transcription_driver: BaseAudioTranscriptionDriver = field(
default=Factory(lambda: DummyAudioTranscriptionDriver()), kw_only=True, metadata={"serializable": True}
)
7 changes: 7 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@
from .structure_run.griptape_cloud_structure_run_driver import GriptapeCloudStructureRunDriver
from .structure_run.local_structure_run_driver import LocalStructureRunDriver

from .audio_transcription.base_audio_transcription_driver import BaseAudioTranscriptionDriver
from .audio_transcription.dummy_audio_transcription_driver import DummyAudioTranscriptionDriver
from .audio_transcription.openai_audio_transcription_driver import OpenAiAudioTranscriptionDriver

__all__ = [
"BasePromptDriver",
"OpenAiChatPromptDriver",
Expand Down Expand Up @@ -199,4 +203,7 @@
"BaseStructureRunDriver",
"GriptapeCloudStructureRunDriver",
"LocalStructureRunDriver",
"BaseAudioTranscriptionDriver",
"DummyAudioTranscriptionDriver",
"OpenAiAudioTranscriptionDriver",
]
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional

from attrs import define, field

from griptape.artifacts import TextArtifact, AudioArtifact
from griptape.events import StartAudioTranscriptionEvent, FinishAudioTranscriptionEvent
from griptape.mixins import ExponentialBackoffMixin, SerializableMixin

if TYPE_CHECKING:
from griptape.structures import Structure

Check warning on line 13 in griptape/drivers/audio_transcription/base_audio_transcription_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/audio_transcription/base_audio_transcription_driver.py#L13

Added line #L13 was not covered by tests


@define
class BaseAudioTranscriptionDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we name these Drivers BaseSpeechToTextDriver for consistency with the inverse Drivers?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I've steered a bit too far in the direction of naming drivers based on their artifact interfaces. I think the specificity of this name is helpful, what do you think about adding a similarly specific name to the BaseTextToSpeechDriver? BaseSpeechGenerationDriver?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah thinking about it more I think this current convention is the more "correct" one.

Down to do a rename of BaseTextToSpeechDriver (though maybe in a separate PR), what do you think about BaseAudioGenerationDriver? This may extend beyond speech in the future?

model: str = field(kw_only=True, metadata={"serializable": True})
structure: Optional[Structure] = field(default=None, kw_only=True)

def before_run(self) -> None:
if self.structure:
self.structure.publish_event(StartAudioTranscriptionEvent())

Check warning on line 23 in griptape/drivers/audio_transcription/base_audio_transcription_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/audio_transcription/base_audio_transcription_driver.py#L23

Added line #L23 was not covered by tests

def after_run(self) -> None:
if self.structure:
self.structure.publish_event(FinishAudioTranscriptionEvent())

Check warning on line 27 in griptape/drivers/audio_transcription/base_audio_transcription_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/audio_transcription/base_audio_transcription_driver.py#L27

Added line #L27 was not covered by tests

def run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact:
for attempt in self.retrying():
with attempt:
self.before_run()
result = self.try_run(audio, prompts)
self.after_run()

Check warning on line 34 in griptape/drivers/audio_transcription/base_audio_transcription_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/audio_transcription/base_audio_transcription_driver.py#L32-L34

Added lines #L32 - L34 were not covered by tests

return result

Check warning on line 36 in griptape/drivers/audio_transcription/base_audio_transcription_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/audio_transcription/base_audio_transcription_driver.py#L36

Added line #L36 was not covered by tests

else:
raise Exception("Failed to run audio transcription")

Check warning on line 39 in griptape/drivers/audio_transcription/base_audio_transcription_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/audio_transcription/base_audio_transcription_driver.py#L39

Added line #L39 was not covered by tests

@abstractmethod
def try_run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact: ...
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Optional

from attrs import define, field
from griptape.artifacts import AudioArtifact, TextArtifact
from griptape.drivers import BaseAudioTranscriptionDriver
from griptape.exceptions import DummyException


@define
class DummyAudioTranscriptionDriver(BaseAudioTranscriptionDriver):
model: str = field(init=False)

def try_run(self, audio: AudioArtifact, prompts: Optional[list] = None) -> TextArtifact:
raise DummyException(__class__.__name__, "try_transcription")

Check warning on line 14 in griptape/drivers/audio_transcription/dummy_audio_transcription_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/audio_transcription/dummy_audio_transcription_driver.py#L14

Added line #L14 was not covered by tests
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations

import io
from typing import Optional

import openai
from attrs import field, Factory, define

from griptape.artifacts import AudioArtifact, TextArtifact
from griptape.drivers import BaseAudioTranscriptionDriver


@define
class OpenAiAudioTranscriptionDriver(BaseAudioTranscriptionDriver):
api_type: str = field(default=openai.api_type, kw_only=True)
api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True})
base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True})
client: openai.OpenAI = field(
default=Factory(
lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization),
takes_self=True,
)
)

def try_run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact:
additional_params = {}

if prompts is not None:
additional_params["prompt"] = ", ".join(prompts)

Check warning on line 31 in griptape/drivers/audio_transcription/openai_audio_transcription_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/audio_transcription/openai_audio_transcription_driver.py#L31

Added line #L31 was not covered by tests

transcription = self.client.audio.transcriptions.create(
# Even though we're not actually providing a file to the client, the API still requires that we send a file
# name. We set the file name to use the same format as the audio file so that the API can reject
# it if the format is unsupported.
model=self.model,
file=(f"a.{audio.format}", io.BytesIO(audio.value)),
response_format="json",
**additional_params,
)

return TextArtifact(value=transcription.text)
2 changes: 2 additions & 0 deletions griptape/engines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .image.outpainting_image_generation_engine import OutpaintingImageGenerationEngine
from .image_query.image_query_engine import ImageQueryEngine
from .audio.text_to_speech_engine import TextToSpeechEngine
from .audio.audio_transcription_engine import AudioTranscriptionEngine

__all__ = [
"BaseQueryEngine",
Expand All @@ -28,4 +29,5 @@
"OutpaintingImageGenerationEngine",
"ImageQueryEngine",
"TextToSpeechEngine",
"AudioTranscriptionEngine",
]
12 changes: 12 additions & 0 deletions griptape/engines/audio/audio_transcription_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from attrs import define, field

from griptape.artifacts import AudioArtifact, TextArtifact
from griptape.drivers import BaseAudioTranscriptionDriver


@define
class AudioTranscriptionEngine:
audio_transcription_driver: BaseAudioTranscriptionDriver = field(kw_only=True)

def run(self, audio: AudioArtifact, *args, **kwargs) -> TextArtifact:
return self.audio_transcription_driver.try_run(audio)

Check warning on line 12 in griptape/engines/audio/audio_transcription_engine.py

View check run for this annotation

Codecov / codecov/patch

griptape/engines/audio/audio_transcription_engine.py#L12

Added line #L12 was not covered by tests
Loading
Loading