Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add connection test feature to assist_satellite #126256

Merged
merged 14 commits into from
Sep 22, 2024
Merged
10 changes: 9 additions & 1 deletion homeassistant/components/assist_satellite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.typing import ConfigType

from .const import DOMAIN, DOMAIN_DATA, AssistSatelliteEntityFeature
from .connection_test import ConnectionTestView
from .const import (
CONNECTION_TEST_DATA,
DOMAIN,
DOMAIN_DATA,
AssistSatelliteEntityFeature,
)
from .entity import (
AssistSatelliteAnnouncement,
AssistSatelliteConfiguration,
Expand Down Expand Up @@ -57,7 +63,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"async_internal_announce",
[AssistSatelliteEntityFeature.ANNOUNCE],
)
hass.data[CONNECTION_TEST_DATA] = {}
async_register_websocket_api(hass)
hass.http.register_view(ConnectionTestView())

return True

Expand Down
Binary file not shown.
43 changes: 43 additions & 0 deletions homeassistant/components/assist_satellite/connection_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Assist satellite connection test."""

import logging
from pathlib import Path

from aiohttp import web

from homeassistant.components.http import KEY_HASS, HomeAssistantView

from .const import CONNECTION_TEST_DATA

_LOGGER = logging.getLogger(__name__)

CONNECTION_TEST_CONTENT_TYPE = "audio/mpeg"
CONNECTION_TEST_FILENAME = "connection_test.mp3"
CONNECTION_TEST_URL_BASE = "/api/assist_satellite/connection_test"


class ConnectionTestView(HomeAssistantView):
"""View to serve an audio sample for connection test."""

requires_auth = False
url = f"{CONNECTION_TEST_URL_BASE}/{{connection_id}}"
name = "api:assist_satellite_connection_test"

async def get(self, request: web.Request, connection_id: str) -> web.Response:
"""Start a get request."""
_LOGGER.debug("Request for connection test with id %s", connection_id)

hass = request.app[KEY_HASS]
connection_test_data = hass.data[CONNECTION_TEST_DATA]

connection_test_event = connection_test_data.pop(connection_id, None)

if connection_test_event is None:
return web.Response(status=404)

Check warning on line 36 in homeassistant/components/assist_satellite/connection_test.py

View check run for this annotation

Codecov / codecov/patch

homeassistant/components/assist_satellite/connection_test.py#L36

Added line #L36 was not covered by tests

connection_test_event.set()

audio_path = Path(__file__).parent / CONNECTION_TEST_FILENAME
audio_data = await hass.async_add_executor_job(audio_path.read_bytes)

return web.Response(body=audio_data, content_type=CONNECTION_TEST_CONTENT_TYPE)
4 changes: 4 additions & 0 deletions homeassistant/components/assist_satellite/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import asyncio
from enum import IntFlag
from typing import TYPE_CHECKING

Expand All @@ -15,6 +16,9 @@
DOMAIN = "assist_satellite"

DOMAIN_DATA: HassKey[EntityComponent[AssistSatelliteEntity]] = HassKey(DOMAIN)
CONNECTION_TEST_DATA: HassKey[dict[str, asyncio.Event]] = HassKey(
f"{DOMAIN}_connection_tests"
)


class AssistSatelliteEntityFeature(IntFlag):
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/assist_satellite/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"domain": "assist_satellite",
"name": "Assist Satellite",
"codeowners": ["@home-assistant/core", "@synesthesiam"],
"dependencies": ["assist_pipeline", "stt", "tts"],
"dependencies": ["assist_pipeline", "http", "stt", "tts"],
"documentation": "https://www.home-assistant.io/integrations/assist_satellite",
"integration_type": "entity",
"quality_scale": "internal"
Expand Down
69 changes: 68 additions & 1 deletion homeassistant/components/assist_satellite/websocket_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Assist satellite Websocket API."""

import asyncio
from dataclasses import asdict, replace
from typing import Any

Expand All @@ -9,8 +10,19 @@
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.util import uuid as uuid_util

from .connection_test import CONNECTION_TEST_URL_BASE
from .const import (
CONNECTION_TEST_DATA,
DOMAIN,
DOMAIN_DATA,
AssistSatelliteEntityFeature,
)
from .entity import AssistSatelliteEntity

from .const import DOMAIN, DOMAIN_DATA
CONNECTION_TEST_TIMEOUT = 30


@callback
Expand All @@ -19,6 +31,7 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
websocket_api.async_register_command(hass, websocket_intercept_wake_word)
websocket_api.async_register_command(hass, websocket_get_configuration)
websocket_api.async_register_command(hass, websocket_set_wake_words)
websocket_api.async_register_command(hass, websocket_test_connection)


@callback
Expand Down Expand Up @@ -138,3 +151,57 @@ async def websocket_set_wake_words(
replace(config, active_wake_words=actual_ids)
)
connection.send_result(msg["id"])


@websocket_api.websocket_command(
{
vol.Required("type"): "assist_satellite/test_connection",
vol.Required("entity_id"): cv.entity_domain(DOMAIN),
}
)
@websocket_api.async_response
async def websocket_test_connection(
hass: HomeAssistant,
connection: websocket_api.connection.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Test the connection between the device and Home Assistant.

Send an announcement to the device with a special media id.
"""
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
satellite = component.get_entity(msg["entity_id"])
if satellite is None:
connection.send_error(
msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found"
)
return
if not (satellite.supported_features or 0) & AssistSatelliteEntityFeature.ANNOUNCE:
connection.send_error(
msg["id"],
websocket_api.ERR_NOT_SUPPORTED,
"Entity does not support announce",
)
return

# Announce and wait for event
connection_test_data = hass.data[CONNECTION_TEST_DATA]
connection_id = uuid_util.random_uuid_hex()
connection_test_event = asyncio.Event()
connection_test_data[connection_id] = connection_test_event

hass.async_create_background_task(
satellite.async_internal_announce(
media_id=f"{CONNECTION_TEST_URL_BASE}/{connection_id}"
),
f"assist_satellite_connection_test_{msg['entity_id']}",
)

try:
async with asyncio.timeout(CONNECTION_TEST_TIMEOUT):
await connection_test_event.wait()
connection.send_result(msg["id"], {"status": "success"})
except TimeoutError:
connection.send_result(msg["id"], {"status": "timeout"})
finally:
connection_test_data.pop(connection_id, None)
2 changes: 1 addition & 1 deletion tests/components/assist_satellite/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class MockAssistSatellite(AssistSatelliteEntity):
def __init__(self) -> None:
"""Initialize the mock entity."""
self.events = []
self.announcements = []
self.announcements: list[AssistSatelliteAnnouncement] = []
self.config = AssistSatelliteConfiguration(
available_wake_words=[
AssistSatelliteWakeWord(
Expand Down
133 changes: 132 additions & 1 deletion tests/components/assist_satellite/test_websocket_api.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
"""Test WebSocket API."""

import asyncio
from http import HTTPStatus
from unittest.mock import patch

from freezegun.api import FrozenDateTimeFactory
import pytest

from homeassistant.components.assist_pipeline import PipelineStage
from homeassistant.components.assist_satellite.websocket_api import (
CONNECTION_TEST_TIMEOUT,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant

from . import ENTITY_ID
from .conftest import MockAssistSatellite

from tests.common import MockUser
from tests.typing import WebSocketGenerator
from tests.typing import ClientSessionGenerator, WebSocketGenerator


async def test_intercept_wake_word(
Expand Down Expand Up @@ -385,3 +390,129 @@ async def test_set_wake_words_bad_id(
"code": "not_supported",
"message": "Wake word id is not supported: abcd",
}


async def test_connection_test(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
hass_client: ClientSessionGenerator,
) -> None:
"""Test connection test."""
ws_client = await hass_ws_client(hass)

await ws_client.send_json_auto_id(
{
"type": "assist_satellite/test_connection",
"entity_id": ENTITY_ID,
}
)

for _ in range(3):
await asyncio.sleep(0)

assert len(entity.announcements) == 1
assert entity.announcements[0].message == ""
announcement_media_id = entity.announcements[0].media_id
hass_url = "http://10.10.10.10:8123"
assert announcement_media_id.startswith(
f"{hass_url}/api/assist_satellite/connection_test/"
)

# Fake satellite fetches the URL
client = await hass_client()
resp = await client.get(announcement_media_id[len(hass_url) :])
assert resp.status == HTTPStatus.OK

response = await ws_client.receive_json()
assert response["success"]
assert response["result"] == {"status": "success"}


async def test_connection_test_timeout(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
hass_client: ClientSessionGenerator,
freezer: FrozenDateTimeFactory,
) -> None:
"""Test connection test timeout."""
ws_client = await hass_ws_client(hass)

await ws_client.send_json_auto_id(
{
"type": "assist_satellite/test_connection",
"entity_id": ENTITY_ID,
}
)

for _ in range(3):
await asyncio.sleep(0)

assert len(entity.announcements) == 1
assert entity.announcements[0].message == ""
announcement_media_id = entity.announcements[0].media_id
hass_url = "http://10.10.10.10:8123"
assert announcement_media_id.startswith(
f"{hass_url}/api/assist_satellite/connection_test/"
)

freezer.tick(CONNECTION_TEST_TIMEOUT + 1)

# Timeout
response = await ws_client.receive_json()
assert response["success"]
assert response["result"] == {"status": "timeout"}


async def test_connection_test_invalid_satellite(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test connection test with unknown entity id."""
ws_client = await hass_ws_client(hass)

await ws_client.send_json_auto_id(
{
"type": "assist_satellite/test_connection",
"entity_id": "assist_satellite.invalid",
}
)
response = await ws_client.receive_json()

assert not response["success"]
assert response["error"] == {
"code": "not_found",
"message": "Entity not found",
}


async def test_connection_test_timeout_announcement_unsupported(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test connection test entity which does not support announce."""
ws_client = await hass_ws_client(hass)

# Disable announce support
entity.supported_features = 0

await ws_client.send_json_auto_id(
{
"type": "assist_satellite/test_connection",
"entity_id": ENTITY_ID,
}
)
response = await ws_client.receive_json()

assert not response["success"]
assert response["error"] == {
"code": "not_supported",
"message": "Entity does not support announce",
}