Skip to content

Commit

Permalink
Merge branch 'main' into config
Browse files Browse the repository at this point in the history
  • Loading branch information
rsdmike committed Oct 30, 2024
2 parents e08af2a + 2a95938 commit ad82171
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 12 deletions.
44 changes: 43 additions & 1 deletion internal/controller/http/v1/devices.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ package v1

import (
"net/http"
"time"

"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"

"github.com/open-amt-cloud-toolkit/console/config"
"github.com/open-amt-cloud-toolkit/console/internal/entity/dto/v1"
"github.com/open-amt-cloud-toolkit/console/internal/usecase/devices"
"github.com/open-amt-cloud-toolkit/console/pkg/consoleerrors"
Expand All @@ -21,6 +24,8 @@ var ErrValidationDevices = dto.NotValidError{Console: consoleerrors.CreateConsol
func NewDeviceRoutes(handler *gin.RouterGroup, t devices.Feature, l logger.Interface) {
r := &deviceRoutes{t, l}

handler.GET("authorize/redirection/:id", r.LoginRedirection)

h := handler.Group("/devices")
{
h.GET("", r.get)
Expand Down Expand Up @@ -72,6 +77,43 @@ func (dr *deviceRoutes) getStats(c *gin.Context) {
c.JSON(http.StatusOK, countResponse)
}

// @Summary route for redirection auth
// @Description gets token for use with redirection
// @ID loginRedirection
// @Tags devices
// @Accept json
// @Produce json
// @Success 200 {object} DeviceCountResponse
// @Failure 500 {object} response
// @Router /api/v1/authorize/redirection [get]
func (dr *deviceRoutes) LoginRedirection(c *gin.Context) {
deviceID := c.Param("id")

_, err := dr.t.GetByID(c.Request.Context(), deviceID, "")
if err != nil {
dr.l.Error(err, "http - devices - v1 - LoginRedirection")
ErrorResponse(c, err)

return
}
// Create JWT token
expirationTime := time.Now().Add(config.ConsoleConfig.JWTExpiration)
claims := jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expirationTime),
}

token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)

tokenString, err := token.SignedString([]byte(config.ConsoleConfig.App.JWTKey))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "could not create token"})

return
}

c.JSON(http.StatusOK, gin.H{"token": tokenString})
}

// @Summary Show Devices
// @Description Show all devices
// @ID getDevices
Expand All @@ -80,7 +122,7 @@ func (dr *deviceRoutes) getStats(c *gin.Context) {
// @Produce json
// @Success 200 {object} DeviceCountResponse
// @Failure 500 {object} response
// @Router /api/v1/devices [get]
// @Router /api/v1/devices/:id [get]
func (dr *deviceRoutes) get(c *gin.Context) {
var odata OData
if err := c.ShouldBindQuery(&odata); err != nil {
Expand Down
35 changes: 33 additions & 2 deletions internal/controller/ws/v1/redirect.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ import (
"net/http"

"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/gorilla/websocket"

"github.com/open-amt-cloud-toolkit/console/config"
"github.com/open-amt-cloud-toolkit/console/internal/usecase/devices"
"github.com/open-amt-cloud-toolkit/console/pkg/logger"
)
Expand All @@ -25,6 +28,36 @@ func RegisterRoutes(r *gin.Engine, l logger.Interface, t devices.Feature, u Upgr
}

func (r *RedirectRoutes) websocketHandler(c *gin.Context) {
tokenString := c.GetHeader("Sec-Websocket-Protocol")

// validate jwt token in the Sec-Websocket-protocol header
if !config.ConsoleConfig.AuthDisabled {
if tokenString == "" {
http.Error(c.Writer, "request does not contain an access token", http.StatusUnauthorized)

return
}

claims := &jwt.MapClaims{}

token, err := jwt.ParseWithClaims(tokenString, claims, func(_ *jwt.Token) (interface{}, error) {
return []byte(config.ConsoleConfig.App.JWTKey), nil
})

if err != nil || !token.Valid {
http.Error(c.Writer, "invalid access token", http.StatusUnauthorized)

return
}
}

upgrader, ok := r.u.(*websocket.Upgrader)
if !ok {
r.l.Debug("failed to cast Upgrader to *websocket.Upgrader")
} else {
upgrader.Subprotocols = []string{tokenString}
}

conn, err := r.u.Upgrade(c.Writer, c.Request, nil)
if err != nil {
http.Error(c.Writer, "Could not open websocket connection", http.StatusInternalServerError)
Expand All @@ -39,6 +72,4 @@ func (r *RedirectRoutes) websocketHandler(c *gin.Context) {
r.l.Error(err, "http - devices - v1 - redirect")
errorResponse(c, http.StatusInternalServerError, "redirect failed")
}

c.Status(http.StatusSwitchingProtocols)
}
8 changes: 7 additions & 1 deletion internal/controller/ws/v1/redirect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"

"github.com/open-amt-cloud-toolkit/console/config"
"github.com/open-amt-cloud-toolkit/console/internal/mocks"
)

Expand All @@ -23,6 +24,9 @@ func TestWebSocketHandler(t *testing.T) { //nolint:paralleltest // logging libra
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)

_, _ = config.NewConfig()

config.ConsoleConfig.AuthDisabled = true
mockFeature := mocks.NewMockFeature(ctrl)
mockUpgrader := mocks.NewMockUpgrader(ctrl)
mockLogger := mocks.NewMockLogger(ctrl)
Expand All @@ -37,7 +41,7 @@ func TestWebSocketHandler(t *testing.T) { //nolint:paralleltest // logging libra
name: "Success case",
upgraderError: nil,
redirectError: nil,
expectedStatus: http.StatusSwitchingProtocols,
expectedStatus: http.StatusOK,
},
{
name: "Upgrade error",
Expand All @@ -60,11 +64,13 @@ func TestWebSocketHandler(t *testing.T) { //nolint:paralleltest // logging libra
mockUpgrader.EXPECT().
Upgrade(gomock.Any(), gomock.Any(), nil).
Return(nil, tc.upgraderError)
mockLogger.EXPECT().Debug("failed to cast Upgrader to *websocket.Upgrader")
} else {
mockUpgrader.EXPECT().
Upgrade(gomock.Any(), gomock.Any(), nil).
Return(&websocket.Conn{}, nil)

mockLogger.EXPECT().Debug("failed to cast Upgrader to *websocket.Upgrader")
mockLogger.EXPECT().Info("Websocket connection opened")

if tc.redirectError != nil {
Expand Down
43 changes: 39 additions & 4 deletions internal/usecase/profiles/usecase.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package profiles
import (
"context"
"errors"
"strconv"
"strings"

"github.com/open-amt-cloud-toolkit/go-wsman-messages/v2/pkg/config"
Expand Down Expand Up @@ -53,6 +52,43 @@ func New(r Repository, wifiConfig wificonfigs.Repository, w profilewificonfigs.F
}
}

type (
AuthMethod int
EncryptMethod int
)

const (
WPAPSK AuthMethod = 4
WPAIEEE8021x AuthMethod = 5
WPA2PSK AuthMethod = 6
WPA2IEEE8021x AuthMethod = 7
)

const (
TKIP EncryptMethod = 3
CCMP EncryptMethod = 4
)

var authenticationMethod = map[AuthMethod]string{
WPAPSK: "WPAPSK",
WPAIEEE8021x: "WPAIEEE8021x",
WPA2PSK: "WPA2PSK",
WPA2IEEE8021x: "WPA2IEEE8021x",
}

var encryptionMethod = map[EncryptMethod]string{
TKIP: "TKIP",
CCMP: "CCMP",
}

func (uc *UseCase) getAuthMethodName(method AuthMethod) string {
return authenticationMethod[method]
}

func (uc *UseCase) getEncryptMethodName(method EncryptMethod) string {
return encryptionMethod[method]
}

// History - getting translate history from store.
func (uc *UseCase) GetCount(ctx context.Context, tenantID string) (int, error) {
count, err := uc.repo.GetCount(ctx, tenantID)
Expand Down Expand Up @@ -121,7 +157,6 @@ func (uc *UseCase) HandleIEEE8021xSettings(ctx context.Context, data *entity.Pro
AuthenticationProtocol: ieee8021xconfig.AuthenticationProtocol,
PXETimeout: *ieee8021xconfig.PXETimeout,
}

}

return nil
Expand Down Expand Up @@ -200,8 +235,8 @@ func (uc *UseCase) BuildWirelessProfiles(ctx context.Context, wifiConfigs []dto.
SSID: wifi.SSID,
Priority: wifiConfig.Priority,
Password: wifi.PSKPassphrase,
AuthenticationMethod: strconv.Itoa(wifi.AuthenticationMethod),
EncryptionMethod: strconv.Itoa(wifi.EncryptionMethod),
AuthenticationMethod: uc.getAuthMethodName(AuthMethod(wifi.AuthenticationMethod)),
EncryptionMethod: uc.getEncryptMethodName(EncryptMethod(wifi.EncryptionMethod)),
}

if wifi.IEEE8021xProfileName != nil {
Expand Down
14 changes: 10 additions & 4 deletions internal/usecase/profiles/usecase_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -773,8 +773,8 @@ func TestBuildWirelessProfiles(t *testing.T) {
Return(&entity.WirelessConfig{
ProfileName: "wifi-profile-1",
SSID: "wifi-ssid",
AuthenticationMethod: 1,
EncryptionMethod: 2,
AuthenticationMethod: 4,
EncryptionMethod: 4,
PSKPassphrase: "encryptedPassphrase",
}, nil)
},
Expand All @@ -784,8 +784,8 @@ func TestBuildWirelessProfiles(t *testing.T) {
SSID: "wifi-ssid",
Priority: 1,
Password: "decrypted",
AuthenticationMethod: "1",
EncryptionMethod: "2",
AuthenticationMethod: "WPAPSK",
EncryptionMethod: "CCMP",
},
},
err: nil,
Expand Down Expand Up @@ -822,6 +822,7 @@ func TestBuildConfigurationObject(t *testing.T) {
t.Parallel()

originalConfig := local.ConsoleConfig

t.Cleanup(func() {
local.ConsoleConfig = originalConfig
})
Expand Down Expand Up @@ -902,6 +903,11 @@ func TestBuildConfigurationObject(t *testing.T) {
Enabled: true,
AllowNonTLS: true,
},
EnterpriseAssistant: config.EnterpriseAssistant{
URL: "http://test.com:8080",
Username: "username",
Password: "password",
},
AMTSpecific: config.AMTSpecific{
ControlMode: "acmactivate",
AdminPassword: "testAMTPassword",
Expand Down

0 comments on commit ad82171

Please sign in to comment.