Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change assist satellite announce method signature #126299

Merged
merged 1 commit into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions homeassistant/components/assist_satellite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from .const import DOMAIN, AssistSatelliteEntityFeature
from .entity import (
AssistSatelliteAnnouncement,
AssistSatelliteConfiguration,
AssistSatelliteEntity,
AssistSatelliteEntityDescription,
Expand All @@ -22,6 +23,7 @@

__all__ = [
"DOMAIN",
"AssistSatelliteAnnouncement",
"AssistSatelliteEntity",
"AssistSatelliteConfiguration",
"AssistSatelliteEntityDescription",
Expand Down
29 changes: 26 additions & 3 deletions homeassistant/components/assist_satellite/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from enum import StrEnum
import logging
import time
from typing import Any, Final, final
from typing import Any, Final, Literal, final

from homeassistant.components import media_source, stt, tts
from homeassistant.components.assist_pipeline import (
Expand Down Expand Up @@ -86,6 +86,19 @@ class AssistSatelliteConfiguration:
"""Maximum number of simultaneous wake words allowed (0 for no limit)."""


@dataclass
class AssistSatelliteAnnouncement:
"""Announcement to be made."""

message: str
"""Message to be spoken."""

media_id: str
"""Media ID to be played."""

media_id_source: Literal["url", "media_id", "tts"]


class AssistSatelliteEntity(entity.Entity):
"""Entity encapsulating the state and functionality of an Assist satellite."""

Expand Down Expand Up @@ -174,10 +187,13 @@ async def async_internal_announce(
"""
await self._cancel_running_pipeline()

media_id_source: Literal["url", "media_id", "tts"] | None = None

if message is None:
message = ""

if not media_id:
media_id_source = "tts"
# Synthesize audio and get URL
pipeline_id = self._resolve_pipeline()
pipeline = async_get_pipeline(self.hass, pipeline_id)
Expand All @@ -198,13 +214,18 @@ async def async_internal_announce(
)

if media_source.is_media_source_id(media_id):
if not media_id_source:
media_id_source = "media_id"
media = await media_source.async_resolve_media(
self.hass,
media_id,
None,
)
media_id = media.url

if not media_id_source:
media_id_source = "url"

# Resolve to full URL
media_id = async_process_play_media_url(self.hass, media_id)

Expand All @@ -216,12 +237,14 @@ async def async_internal_announce(

try:
# Block until announcement is finished
await self.async_announce(message, media_id)
await self.async_announce(
AssistSatelliteAnnouncement(message, media_id, media_id_source)
)
finally:
self._is_announcing = False
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)

async def async_announce(self, message: str, media_id: str) -> None:
async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None:
"""Announce media on the satellite.

Should block until the announcement is done playing.
Expand Down
10 changes: 6 additions & 4 deletions homeassistant/components/esphome/assist_satellite.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,18 +313,20 @@ def on_pipeline_event(self, event: PipelineEvent) -> None:

self.cli.send_voice_assistant_event(event_type, data_to_send)

async def async_announce(self, message: str, media_id: str) -> None:
async def async_announce(
self, announcement: assist_satellite.AssistSatelliteAnnouncement
) -> None:
"""Announce media on the satellite.

Should block until the announcement is done playing.
"""
_LOGGER.debug(
"Waiting for announcement to finished (message=%s, media_id=%s)",
message,
media_id,
announcement.message,
announcement.media_id,
)
await self.cli.send_voice_assistant_announcement_await_response(
media_id, _ANNOUNCEMENT_TIMEOUT_SEC, message
announcement.media_id, _ANNOUNCEMENT_TIMEOUT_SEC, announcement.message
)

async def handle_pipeline_start(
Expand Down
5 changes: 3 additions & 2 deletions tests/components/assist_satellite/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from homeassistant.components.assist_pipeline import PipelineEvent
from homeassistant.components.assist_satellite import (
DOMAIN as AS_DOMAIN,
AssistSatelliteAnnouncement,
AssistSatelliteConfiguration,
AssistSatelliteEntity,
AssistSatelliteEntityFeature,
Expand Down Expand Up @@ -63,9 +64,9 @@ def on_pipeline_event(self, event: PipelineEvent) -> None:
"""Handle pipeline events."""
self.events.append(event)

async def async_announce(self, message: str, media_id: str) -> None:
async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None:
"""Announce media on a device."""
self.announcements.append((message, media_id))
self.announcements.append(announcement)

@callback
def async_get_configuration(self) -> AssistSatelliteConfiguration:
Expand Down
23 changes: 15 additions & 8 deletions tests/components/assist_satellite/test_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
async_update_pipeline,
vad,
)
from homeassistant.components.assist_satellite import SatelliteBusyError
from homeassistant.components.assist_satellite import (
AssistSatelliteAnnouncement,
SatelliteBusyError,
)
from homeassistant.components.assist_satellite.entity import AssistSatelliteState
from homeassistant.components.media_source import PlayMedia
from homeassistant.config_entries import ConfigEntry
Expand Down Expand Up @@ -159,18 +162,22 @@ async def async_pipeline_from_audio_stream(*args, **kwargs):
[
(
{"message": "Hello"},
("Hello", "https://www.home-assistant.io/resolved.mp3"),
AssistSatelliteAnnouncement(
"Hello", "https://www.home-assistant.io/resolved.mp3", "tts"
),
),
(
{
"message": "Hello",
"media_id": "http://example.com/bla.mp3",
"media_id": "media-source://bla",
},
("Hello", "http://example.com/bla.mp3"),
AssistSatelliteAnnouncement(
"Hello", "https://www.home-assistant.io/resolved.mp3", "media_id"
),
),
(
{"media_id": "http://example.com/bla.mp3"},
("", "http://example.com/bla.mp3"),
AssistSatelliteAnnouncement("", "http://example.com/bla.mp3", "url"),
),
],
)
Expand All @@ -195,10 +202,10 @@ async def test_announce(
original_announce = entity.async_announce
announce_started = asyncio.Event()

async def async_announce(message, media_id):
async def async_announce(announcement):
# Verify state change
assert entity.state == AssistSatelliteState.RESPONDING
await original_announce(message, media_id)
await original_announce(announcement)
announce_started.set()

def tts_generate_media_source_id(
Expand Down Expand Up @@ -249,7 +256,7 @@ async def test_announce_busy(
announce_started = asyncio.Event()
got_error = asyncio.Event()

async def async_announce(message, media_id):
async def async_announce(announcement):
announce_started.set()

# Block so we can do another announcement
Expand Down