Skip to content

Commit

Permalink
refactor(hass): 🚚 restructure hass package
Browse files Browse the repository at this point in the history
- create separate api package for sending requests to hass api
- create new request types to send registration and event/sensor requests to hass
- validation back in api package
  • Loading branch information
joshuar committed Oct 26, 2024
1 parent 2da3563 commit 124175f
Show file tree
Hide file tree
Showing 19 changed files with 370 additions and 390 deletions.
4 changes: 1 addition & 3 deletions internal/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,13 @@ func Run(ctx context.Context) error {
return
}

client, err := hass.NewClient(ctx)
client, err := hass.NewClient(ctx, prefs.RestAPIURL())
if err != nil {
logging.FromContext(ctx).Error("Cannot connect to Home Assistant.",
slog.Any("error", err))
return
}

client.Endpoint(prefs.RestAPIURL(), hass.DefaultTimeout)

// Initialize and gather OS sensor and event workers.
sensorWorkers, eventWorkers := setupOSWorkers(runCtx)
// Initialize and add connection latency sensor worker.
Expand Down
4 changes: 1 addition & 3 deletions internal/agent/ui/fyneUI/fyneUI.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,12 @@ func (i *FyneUI) aboutWindow(ctx context.Context) fyne.Window {
return nil
}

hassclient, err := hass.NewClient(ctx)
hassclient, err := hass.NewClient(ctx, prefs.RestAPIURL())
if err != nil {
logging.FromContext(ctx).Debug("Cannot create Home Assistant client.", slog.Any("error", err))
return nil
}

hassclient.Endpoint(prefs.RestAPIURL(), hass.DefaultTimeout)

icon := canvas.NewImageFromResource(&ui.TrayIcon{})
icon.FillMode = canvas.ImageFillOriginal

Expand Down
155 changes: 155 additions & 0 deletions internal/hass/api/api.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
// Copyright 2024 Joshua Rich <[email protected]>.
// SPDX-License-Identifier: MIT

//go:generate go run github.com/matryer/moq -out api_mocks_test.go . PostRequest
package api

import (
"context"
"fmt"
"log/slog"
"net/http"
"strings"
"time"

"github.com/go-resty/resty/v2"

"github.com/joshuar/go-hass-agent/internal/logging"
)

const (
defaultTimeout = 30 * time.Second
)

var (
client *resty.Client

defaultRetryFunc = func(r *resty.Response, _ error) bool {
return r.StatusCode() == http.StatusTooManyRequests
}
)

func init() {
client = resty.New().
SetTimeout(defaultTimeout).
AddRetryCondition(defaultRetryFunc)
}

type RawRequest interface {
RequestBody() any
}

// Request is a HTTP POST request with the request body provided by Body().
type Request interface {
RequestType() string
RequestData() any
}

// Authenticated represents a request that requires passing an authentication
// header with the value returned by Auth().
type Authenticated interface {
Auth() string
}

// Encrypted represents a request that should be encrypted with the secret
// provided by Secret().
type Encrypted interface {
Secret() string
}

type Validator interface {
Validate() error
}

type requestBody struct {
Data any `json:"data"`
RequestType string `json:"type"`
}

type ResponseError struct {
Code any `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}

func (e *ResponseError) Error() string {
var msg []string
if e.Code != nil {
msg = append(msg, fmt.Sprintf("code %v", e.Code))
}

if e.Message != "" {
msg = append(msg, e.Message)
}

if len(msg) == 0 {
msg = append(msg, "unknown error")
}

return strings.Join(msg, ": ")
}

func Send[T any](ctx context.Context, url string, details any) (T, error) {
var (
response T
responseErr ResponseError
responseObj *resty.Response
)

requestClient := client.R().SetContext(ctx)
requestClient = requestClient.SetError(&responseErr)
requestClient = requestClient.SetResult(&response)

// If the request is authenticated, set the auth header with the token.
if a, ok := details.(Authenticated); ok {
requestClient = requestClient.SetAuthToken(a.Auth())
}

// If the request can be validated, validate it.
if v, ok := details.(Validator); ok {
if err := v.Validate(); err != nil {
return response, fmt.Errorf("invalid request: %w", err)
}
}

switch request := details.(type) {
case Request:
body := &requestBody{
RequestType: request.RequestType(),
Data: request.RequestData(),
}
logging.FromContext(ctx).
LogAttrs(ctx, logging.LevelTrace,
"Sending request.",
slog.String("method", "POST"),
slog.String("url", url),
slog.Any("body", body),
slog.Time("sent_at", time.Now()))

responseObj, _ = requestClient.SetBody(body).Post(url) //nolint:errcheck // error is checked with responseObj.IsError()
case RawRequest:
logging.FromContext(ctx).
LogAttrs(ctx, logging.LevelTrace,
"Sending request.",
slog.String("method", "POST"),
slog.String("url", url),
slog.Any("body", request),
slog.Time("sent_at", time.Now()))

responseObj, _ = requestClient.SetBody(request).Post(url) //nolint:errcheck // error is checked with responseObj.IsError()
}

logging.FromContext(ctx).
LogAttrs(ctx, logging.LevelTrace,
"Received response.",
slog.Int("statuscode", responseObj.StatusCode()),
slog.String("status", responseObj.Status()),
slog.String("protocol", responseObj.Proto()),
slog.Duration("time", responseObj.Time()),
slog.String("body", string(responseObj.Body())))

if responseObj.IsError() {
return response, &ResponseError{Code: responseObj.StatusCode(), Message: responseObj.Status()}
}

return response, nil
}
12 changes: 5 additions & 7 deletions internal/hass/response_test.go → internal/hass/api/api_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
// Copyright (c) 2024 Joshua Rich <[email protected]>
//
// This software is released under the MIT License.
// https://opensource.org/licenses/MIT
// Copyright 2024 Joshua Rich <[email protected]>.
// SPDX-License-Identifier: MIT

package hass
package api

import "testing"

func Test_apiError_Error(t *testing.T) {
func Test_ResponseError_Error(t *testing.T) {
type fields struct {
Code any
Message string
Expand Down Expand Up @@ -39,7 +37,7 @@ func Test_apiError_Error(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := &apiError{
e := &ResponseError{
Code: tt.fields.Code,
Message: tt.fields.Message,
}
Expand Down
97 changes: 32 additions & 65 deletions internal/hass/client.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
// Copyright (c) 2024 Joshua Rich <[email protected]>
//
// This software is released under the MIT License.
// https://opensource.org/licenses/MIT
// Copyright 2024 Joshua Rich <[email protected]>.
// SPDX-License-Identifier: MIT

package hass

Expand All @@ -10,11 +8,9 @@ import (
"errors"
"fmt"
"log/slog"
"net/http"
"time"

"github.com/go-resty/resty/v2"

"github.com/joshuar/go-hass-agent/internal/hass/api"
"github.com/joshuar/go-hass-agent/internal/hass/event"
"github.com/joshuar/go-hass-agent/internal/hass/sensor"
"github.com/joshuar/go-hass-agent/internal/hass/sensor/registry"
Expand Down Expand Up @@ -48,10 +44,6 @@ var (
ErrUnknown = errors.New("unknown error occurred")

ErrInvalidSensor = errors.New("invalid sensor")

defaultRetry = func(r *resty.Response, _ error) bool {
return r.StatusCode() == http.StatusTooManyRequests
}
)

type Registry interface {
Expand All @@ -62,10 +54,10 @@ type Registry interface {
}

type Client struct {
endpoint *resty.Client
url string
}

func NewClient(ctx context.Context) (*Client, error) {
func NewClient(ctx context.Context, url string) (*Client, error) {
var err error

sensorTracker = sensor.NewTracker()
Expand All @@ -75,22 +67,11 @@ func NewClient(ctx context.Context) (*Client, error) {
return nil, fmt.Errorf("could not start registry: %w", err)
}

return &Client{}, nil
}

func (c *Client) Endpoint(url string, timeout time.Duration) {
if timeout == 0 {
timeout = DefaultTimeout
}

c.endpoint = resty.New().
SetTimeout(timeout).
AddRetryCondition(defaultRetry).
SetBaseURL(url)
return &Client{url: url}, nil
}

func (c *Client) HassVersion(ctx context.Context) string {
config, err := send[Config](ctx, c, &configRequest{})
config, err := api.Send[Config](ctx, c.url, &configRequest{})
if err != nil {
logging.FromContext(ctx).
Debug("Could not fetch Home Assistant config.",
Expand All @@ -103,13 +84,7 @@ func (c *Client) HassVersion(ctx context.Context) string {
}

func (c *Client) ProcessEvent(ctx context.Context, details event.Event) error {
req := &request{Data: details, RequestType: requestTypeEvent}

if err := req.Validate(); err != nil {
return fmt.Errorf("invalid event request: %w", err)
}

resp, err := send[response](ctx, c, req)
resp, err := api.Send[response](ctx, c.url, details)
if err != nil {
return fmt.Errorf("failed to send event request: %w", err)
}
Expand All @@ -122,12 +97,21 @@ func (c *Client) ProcessEvent(ctx context.Context, details event.Event) error {
}

func (c *Client) ProcessSensor(ctx context.Context, details sensor.Entity) error {
req := &request{}
// Location request.
if req, ok := details.Value.(*sensor.Location); ok {
resp, err := api.Send[response](ctx, c.url, req)
if err != nil {
return fmt.Errorf("failed to send location update: %w", err)
}

if _, ok := details.Value.(*LocationRequest); ok {
req = &request{Data: details.Value, RequestType: requestTypeLocation}
if _, err := resp.Status(); err != nil {
return err
}

return nil
}

// Sensor update.
if sensorRegistry.IsRegistered(details.ID) {
// For sensor updates, if the sensor is disabled, don't continue.
if c.isDisabled(ctx, details) {
Expand All @@ -138,41 +122,24 @@ func (c *Client) ProcessSensor(ctx context.Context, details sensor.Entity) error
return nil
}

req = &request{Data: details.State, RequestType: requestTypeUpdate}
} else {
req = &request{Data: details, RequestType: requestTypeRegister}
}

if err := req.Validate(); err != nil {
return fmt.Errorf("invalid sensor request: %w", err)
}

switch req.RequestType {
case requestTypeLocation:
resp, err := send[response](ctx, c, req)
if err != nil {
return fmt.Errorf("failed to send location update: %w", err)
}

if _, err := resp.Status(); err != nil {
return err
}
case requestTypeUpdate:
resp, err := send[bulkSensorUpdateResponse](ctx, c, req)
resp, err := api.Send[bulkSensorUpdateResponse](ctx, c.url, details.State)
if err != nil {
return fmt.Errorf("failed to send location update: %w", err)
return fmt.Errorf("failed to send sensor update: %w", err)
}

go resp.Process(ctx, details)
case requestTypeRegister:
resp, err := send[registrationResponse](ctx, c, req)
if err != nil {
return fmt.Errorf("failed to send location update: %w", err)
}

go resp.Process(ctx, details)
return nil
}

// Sensor registration.
resp, err := api.Send[registrationResponse](ctx, c.url, details)
if err != nil {
return fmt.Errorf("failed to send sensor registration: %w", err)
}

go resp.Process(ctx, details)

return nil
}

Expand Down Expand Up @@ -219,7 +186,7 @@ func (c *Client) isDisabledInReg(id string) bool {

// isDisabledInHA returns the disabled state of the sensor from Home Assistant.
func (c *Client) isDisabledInHA(ctx context.Context, details sensor.Entity) bool {
config, err := send[Config](ctx, c, &configRequest{})
config, err := api.Send[Config](ctx, c.url, &configRequest{})
if err != nil {
logging.FromContext(ctx).
Debug("Could not fetch Home Assistant config. Assuming sensor is still disabled.",
Expand Down
Loading

0 comments on commit 124175f

Please sign in to comment.