Skip to content

Commit

Permalink
gs: Assert gateway rights on discovery
Browse files Browse the repository at this point in the history
  • Loading branch information
adriansmares committed Dec 1, 2023
1 parent 667ee19 commit da3af9a
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 44 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ For details about compatibility between different releases, see the **Commitment
- Server side events replaced with single socket connection using the native WebSocket API.
- Gateways now disconnect if the Gateway Server address has changed.
- This enables CUPS-enabled gateways to change their LNS before the periodic CUPS lookup occurs.
- The LoRa Basics Station discovery endpoint now verifies the authorization credentials of the caller.
- This enables the gateways to migrate to another instance gracefully while using CUPS.

### Deprecated

Expand Down
16 changes: 13 additions & 3 deletions pkg/gatewayserver/io/ws/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,22 @@ type Formatter interface {
Endpoints() Endpoints
// HandleConnectionInfo handles connection information requests from web socket based protocols.
// This function returns a byte stream that contains connection information (ex: scheme, host, port etc) or an error if applicable.
HandleConnectionInfo(ctx context.Context, raw []byte, server io.Server, serverInfo ServerInfo, receivedAt time.Time) []byte
HandleConnectionInfo(
ctx context.Context,
raw []byte,
server io.Server,
serverInfo ServerInfo,
assertAuth func(context.Context, *ttnpb.GatewayIdentifiers) error,
) []byte
// HandleUp handles upstream messages from web socket based gateways.
// This function optionally returns a byte stream to be sent as response to the upstream message.
HandleUp(ctx context.Context, raw []byte, ids *ttnpb.GatewayIdentifiers, conn *io.Connection, receivedAt time.Time) ([]byte, error)
HandleUp(
ctx context.Context, raw []byte, ids *ttnpb.GatewayIdentifiers, conn *io.Connection, receivedAt time.Time,
) ([]byte, error)
// FromDownlink generates a downlink byte stream that can be sent over the WS connection.
FromDownlink(ctx context.Context, down *ttnpb.DownlinkMessage, bandID string, dlTime time.Time) ([]byte, error)
// TransferTime generates a spurious time transfer message for a particular server time.
TransferTime(ctx context.Context, serverTime time.Time, gpsTime *time.Time, concentratorTime *scheduling.ConcentratorTime) ([]byte, error)
TransferTime(
ctx context.Context, serverTime time.Time, gpsTime *time.Time, concentratorTime *scheduling.ConcentratorTime,
) ([]byte, error)
}
13 changes: 11 additions & 2 deletions pkg/gatewayserver/io/ws/lbslns/discover.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"context"
"encoding/json"
"fmt"
"time"

"go.thethings.network/lorawan-stack/v3/pkg/errors"
"go.thethings.network/lorawan-stack/v3/pkg/gatewayserver/io"
Expand All @@ -32,7 +31,13 @@ import (
var errEmptyGatewayEUI = errors.DefineFailedPrecondition("empty_gateway_eui", "empty gateway EUI")

// HandleConnectionInfo implements Formatter.
func (f *lbsLNS) HandleConnectionInfo(ctx context.Context, raw []byte, server io.Server, info ws.ServerInfo, receivedAt time.Time) []byte {
func (f *lbsLNS) HandleConnectionInfo(
ctx context.Context,
raw []byte,
server io.Server,
info ws.ServerInfo,
assertAuth func(ctx context.Context, ids *ttnpb.GatewayIdentifiers) error,
) []byte {
var req DiscoverQuery

if err := json.Unmarshal(raw, &req); err != nil {
Expand All @@ -52,6 +57,10 @@ func (f *lbsLNS) HandleConnectionInfo(ctx context.Context, raw []byte, server io
}
ctx = filledCtx

if err := assertAuth(ctx, ids); err != nil {
return logAndWrapDiscoverError(ctx, err, fmt.Sprintf("Unauthorized"))
}

euiWithPrefix := fmt.Sprintf("eui-%s", types.MustEUI64(ids.Eui).OrZero().String())
res := DiscoverResponse{
EUI: req.EUI,
Expand Down
5 changes: 3 additions & 2 deletions pkg/gatewayserver/io/ws/lbslns/discover_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"context"
"encoding/json"
"testing"
"time"

"github.com/smarty/assertions"
"go.thethings.network/lorawan-stack/v3/pkg/gatewayserver/io/ws"
Expand Down Expand Up @@ -69,9 +68,11 @@ func TestDiscover(t *testing.T) {
t.Run(tc.Name, func(t *testing.T) {
msg, err := json.Marshal(tc.Query)
a.So(err, should.BeNil)
resp := lbsLNS.HandleConnectionInfo(ctx, msg, mockServer, info, time.Now())
resp := lbsLNS.HandleConnectionInfo(ctx, msg, mockServer, info, noopAssertRights)
expected, _ := json.Marshal(tc.ExpectedResponse)
a.So(string(resp), should.Equal, string(expected))
})
}
}

func noopAssertRights(context.Context, *ttnpb.GatewayIdentifiers) error { return nil }
72 changes: 35 additions & 37 deletions pkg/gatewayserver/io/ws/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,18 @@ func (s *srv) handleConnectionInfo(w http.ResponseWriter, r *http.Request) {
"remote_addr", r.RemoteAddr,
))
logger := log.FromContext(ctx)

assertAuth := func(ctx context.Context, ids *ttnpb.GatewayIdentifiers) error {
ctx, hasAuth := withForwardedAuth(ctx, ids, r.Header.Get("Authorization"))
if !hasAuth {
if !s.cfg.AllowUnauthenticated {
return errNoAuthProvided.WithAttributes("uid", unique.ID(ctx, ids))
}
return nil
}
return s.server.AssertGatewayRights(ctx, ids, ttnpb.Right_RIGHT_GATEWAY_LINK)
}

ws, err := s.upgrader.Upgrade(w, r, nil)
if err != nil {
logger.WithError(err).Debug("Failed to upgrade request to websocket connection")
Expand Down Expand Up @@ -154,7 +166,7 @@ func (s *srv) handleConnectionInfo(w http.ResponseWriter, r *http.Request) {
Address: net.JoinHostPort(host, port),
}

resp := s.formatter.HandleConnectionInfo(ctx, data, s.server, info, time.Now())
resp := s.formatter.HandleConnectionInfo(ctx, data, s.server, info, assertAuth)
if err := ws.WriteMessage(websocket.TextMessage, resp); err != nil {
logger.WithError(err).Warn("Failed to write connection info response message")
return
Expand Down Expand Up @@ -202,47 +214,12 @@ func (s *srv) handleTraffic(w http.ResponseWriter, r *http.Request) (err error)
uid := unique.ID(ctx, ids)
ctx = log.NewContextWithField(ctx, "gateway_uid", uid)

var md metadata.MD

if auth != "" {
if !strings.HasPrefix(auth, "Bearer ") {
auth = fmt.Sprintf("Bearer %s", auth)
}
md = metadata.New(map[string]string{
"id": ids.GatewayId,
"authorization": auth,
})
}

if ctxMd, ok := metadata.FromIncomingContext(ctx); ok {
md = metadata.Join(ctxMd, md)
}
ctx = metadata.NewIncomingContext(ctx, md)
// If a fallback frequency is defined in the server context, inject it into local the context.
if fallback, ok := frequencyplans.FallbackIDFromContext(s.ctx); ok {
ctx = frequencyplans.WithFallbackID(ctx, fallback)
}

var hasAuth bool
if auth != "" {
if !strings.HasPrefix(auth, "Bearer ") {
auth = fmt.Sprintf("Bearer %s", auth)
}
md = metadata.New(map[string]string{
"authorization": auth,
})
hasAuth = true
}

if ctxMd, ok := metadata.FromIncomingContext(ctx); ok {
md = metadata.Join(ctxMd, md)
}
ctx = metadata.NewIncomingContext(ctx, md)
// If a fallback frequency is defined in the server context, inject it into local the context.
if fallback, ok := frequencyplans.FallbackIDFromContext(s.ctx); ok {
ctx = frequencyplans.WithFallbackID(ctx, fallback)
}

ctx, hasAuth := withForwardedAuth(ctx, ids, auth)
if !hasAuth {
if !s.cfg.AllowUnauthenticated {
// We error here directly as there is no auth.
Expand Down Expand Up @@ -416,3 +393,24 @@ func (s *srv) handleTraffic(w http.ResponseWriter, r *http.Request) (err error)
}
}
}

func withForwardedAuth(ctx context.Context, ids *ttnpb.GatewayIdentifiers, auth string) (context.Context, bool) {
var md metadata.MD
var hasAuth bool
if auth != "" {
if !strings.HasPrefix(auth, "Bearer ") {
auth = fmt.Sprintf("Bearer %s", auth)
}
m := map[string]string{"authorization": auth}
if ids != nil {
m["id"] = ids.GatewayId
}
md = metadata.New(m)
if ctxMd, ok := metadata.FromIncomingContext(ctx); ok {
md = metadata.Join(ctxMd, md)
}
ctx = metadata.NewIncomingContext(ctx, md)
hasAuth = true
}
return ctx, hasAuth
}

0 comments on commit da3af9a

Please sign in to comment.