Skip to content

Commit

Permalink
MG-1887 - Add support for OAuth2.0 (absmach#2103)
Browse files Browse the repository at this point in the history
Signed-off-by: Rodney Osodo <[email protected]>
  • Loading branch information
rodneyosodo authored Mar 1, 2024
1 parent adb03b4 commit 0f05c10
Show file tree
Hide file tree
Showing 43 changed files with 1,672 additions and 386 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/check-generated-files.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ jobs:
- name: Set up protoc
if: steps.changes.outputs.proto == 'true'
run: |
PROTOC_VERSION=25.1
PROTOC_GEN_VERSION=v1.31.0
PROTOC_VERSION=25.3
PROTOC_GEN_VERSION=v1.32.0
PROTOC_GRPC_VERSION=v1.3.0
# Download and install protoc
Expand Down
611 changes: 322 additions & 289 deletions auth.pb.go

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions auth.proto
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ message IssueReq {
string user_id = 1;
optional string domain_id = 2;
uint32 type = 3;
string oauth_provider = 4;
string oauth_access_token = 5;
string oauth_refresh_token = 6;
}

message RefreshReq {
Expand Down
20 changes: 18 additions & 2 deletions auth/api/grpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,16 @@ func (client grpcClient) Issue(ctx context.Context, req *magistrala.IssueReq, _
ctx, cancel := context.WithTimeout(ctx, client.timeout)
defer cancel()

res, err := client.issue(ctx, issueReq{userID: req.GetUserId(), domainID: req.GetDomainId(), keyType: auth.KeyType(req.Type)})
res, err := client.issue(ctx, issueReq{
userID: req.GetUserId(),
domainID: req.GetDomainId(),
keyType: auth.KeyType(req.GetType()),
oauthToken: auth.OAuthToken{
Provider: req.GetOauthProvider(),
AccessToken: req.GetOauthAccessToken(),
RefreshToken: req.GetOauthRefreshToken(),
},
})
if err != nil {
return &magistrala.Token{}, decodeError(err)
}
Expand All @@ -183,7 +192,14 @@ func (client grpcClient) Issue(ctx context.Context, req *magistrala.IssueReq, _

func encodeIssueRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
req := grpcReq.(issueReq)
return &magistrala.IssueReq{UserId: req.userID, DomainId: &req.domainID, Type: uint32(req.keyType)}, nil
return &magistrala.IssueReq{
UserId: req.userID,
DomainId: &req.domainID,
Type: uint32(req.keyType),
OauthProvider: req.oauthToken.Provider,
OauthAccessToken: req.oauthToken.AccessToken,
OauthRefreshToken: req.oauthToken.RefreshToken,
}, nil
}

func decodeIssueResponse(_ context.Context, grpcRes interface{}) (interface{}, error) {
Expand Down
1 change: 1 addition & 0 deletions auth/api/grpc/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ func issueEndpoint(svc auth.Service) endpoint.Endpoint {
Type: req.keyType,
User: req.userID,
Domain: req.domainID,
OAuth: req.oauthToken,
}
tkn, err := svc.Issue(ctx, "", key)
if err != nil {
Expand Down
7 changes: 4 additions & 3 deletions auth/api/grpc/requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ func (req identityReq) validate() error {
}

type issueReq struct {
userID string
domainID string // optional
keyType auth.KeyType
userID string
domainID string // optional
keyType auth.KeyType
oauthToken auth.OAuthToken
}

func (req issueReq) validate() error {
Expand Down
11 changes: 10 additions & 1 deletion auth/api/grpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,16 @@ func (s *grpcServer) ListPermissions(ctx context.Context, req *magistrala.ListPe

func decodeIssueRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
req := grpcReq.(*magistrala.IssueReq)
return issueReq{userID: req.GetUserId(), domainID: req.GetDomainId(), keyType: auth.KeyType(req.GetType())}, nil
return issueReq{
userID: req.GetUserId(),
domainID: req.GetDomainId(),
keyType: auth.KeyType(req.GetType()),
oauthToken: auth.OAuthToken{
Provider: req.GetOauthProvider(),
AccessToken: req.GetOauthAccessToken(),
RefreshToken: req.GetOauthRefreshToken(),
},
}, nil
}

func decodeRefreshRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
Expand Down
5 changes: 4 additions & 1 deletion auth/api/http/keys/endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/absmach/magistrala/internal/apiutil"
mglog "github.com/absmach/magistrala/logger"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
oauth2mocks "github.com/absmach/magistrala/pkg/oauth2/mocks"
"github.com/absmach/magistrala/pkg/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
Expand Down Expand Up @@ -72,7 +73,9 @@ func newService() (auth.Service, *mocks.KeyRepository) {
drepo := new(mocks.DomainsRepository)
idProvider := uuid.NewMock()

t := jwt.New([]byte(secret))
provider := new(oauth2mocks.Provider)
provider.On("Name").Return("test")
t := jwt.New([]byte(secret), provider)

return auth.New(krepo, drepo, idProvider, t, prepo, loginDuration, refreshDuration, invalidDuration), krepo
}
Expand Down
196 changes: 182 additions & 14 deletions auth/jwt/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,24 @@
package jwt_test

import (
"context"
"fmt"
"strings"
"testing"
"time"

"github.com/absmach/magistrala/auth"
authjwt "github.com/absmach/magistrala/auth/jwt"
"github.com/absmach/magistrala/internal/testsutil"
"github.com/absmach/magistrala/pkg/errors"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
oauth2mocks "github.com/absmach/magistrala/pkg/oauth2/mocks"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
)

const (
Expand All @@ -31,18 +37,6 @@ var (
reposecret = []byte("test")
)

func key() auth.Key {
exp := time.Now().UTC().Add(10 * time.Minute).Round(time.Second)
return auth.Key{
ID: "66af4a67-3823-438a-abd7-efdb613eaef6",
Type: auth.AccessKey,
Issuer: "magistrala.auth",
Subject: "66af4a67-3823-438a-abd7-efdb613eaef6",
IssuedAt: time.Now().UTC().Add(-10 * time.Second).Round(time.Second),
ExpiresAt: exp,
}
}

func newToken(issuerName string, key auth.Key) string {
builder := jwt.NewBuilder()
builder.
Expand All @@ -62,7 +56,9 @@ func newToken(issuerName string, key auth.Key) string {
}

func TestIssue(t *testing.T) {
tokenizer := authjwt.New([]byte(secret))
provider := new(oauth2mocks.Provider)
provider.On("Name").Return("test")
tokenizer := authjwt.New([]byte(secret), provider)

cases := []struct {
desc string
Expand All @@ -74,6 +70,24 @@ func TestIssue(t *testing.T) {
key: key(),
err: nil,
},
{
desc: "issue token with OAuth token",
key: auth.Key{
ID: testsutil.GenerateUUID(t),
Type: auth.AccessKey,
Subject: testsutil.GenerateUUID(t),
User: testsutil.GenerateUUID(t),
Domain: testsutil.GenerateUUID(t),
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
ExpiresAt: time.Now().Add(10 * time.Minute).Round(time.Second),
OAuth: auth.OAuthToken{
Provider: "test",
AccessToken: strings.Repeat("a", 1024),
RefreshToken: strings.Repeat("b", 1024),
},
},
err: nil,
},
}

for _, tc := range cases {
Expand All @@ -86,7 +100,9 @@ func TestIssue(t *testing.T) {
}

func TestParse(t *testing.T) {
tokenizer := authjwt.New([]byte(secret))
provider := new(oauth2mocks.Provider)
provider.On("Name").Return("test")
tokenizer := authjwt.New([]byte(secret), provider)

token, err := tokenizer.Issue(key())
require.Nil(t, err, fmt.Sprintf("issuing key expected to succeed: %s", err))
Expand Down Expand Up @@ -156,3 +172,155 @@ func TestParse(t *testing.T) {
}
}
}

func TestParseOAuthToken(t *testing.T) {
provider := new(oauth2mocks.Provider)
provider.On("Name").Return("test")
tokenizer := authjwt.New([]byte(secret), provider)

validKey := oauthKey(t)
invalidKey := oauthKey(t)
invalidKey.OAuth.Provider = "invalid"

cases := []struct {
desc string
token auth.Key
issuedToken string
key auth.Key
validateErr error
refreshToken oauth2.Token
refreshErr error
err error
}{
{
desc: "parse valid key",
token: validKey,
issuedToken: "",
key: validKey,
validateErr: nil,
refreshErr: nil,
err: nil,
},
{
desc: "parse invalid key but refreshed",
token: validKey,
issuedToken: "",
key: validKey,
validateErr: svcerr.ErrAuthentication,
refreshToken: oauth2.Token{
AccessToken: strings.Repeat("a", 10),
RefreshToken: strings.Repeat("b", 10),
},
refreshErr: nil,
err: nil,
},
{
desc: "parse invalid key but not refreshed",
token: validKey,
issuedToken: "",
key: validKey,
validateErr: svcerr.ErrAuthentication,
refreshToken: oauth2.Token{},
refreshErr: svcerr.ErrAuthentication,
err: svcerr.ErrAuthentication,
},
{
desc: "parse invalid key with different provider",
issuedToken: invalidOauthToken(t, invalidKey, "invalid", "a", "b"),
err: svcerr.ErrAuthentication,
},
{
desc: "parse invalid key with invalid access token",
issuedToken: invalidOauthToken(t, invalidKey, "invalid", 123, "b"),
err: svcerr.ErrAuthentication,
},
{
desc: "parse invalid key with invalid refresh token",
issuedToken: invalidOauthToken(t, invalidKey, "invalid", "a", 123),
err: svcerr.ErrAuthentication,
},
{
desc: "parse invalid key with invalid provider",
issuedToken: invalidOauthToken(t, invalidKey, "test", "a", "b"),
err: svcerr.ErrAuthentication,
},
}

for _, tc := range cases {
tokenCall := provider.On("Name").Return("test")
tokenCall1 := provider.On("Validate", context.Background(), mock.Anything).Return(tc.validateErr)
tokenCall2 := provider.On("Refresh", context.Background(), mock.Anything).Return(tc.refreshToken, tc.refreshErr)
if tc.issuedToken == "" {
var err error
tc.issuedToken, err = tokenizer.Issue(tc.token)
require.Nil(t, err, fmt.Sprintf("issuing key expected to succeed: %s", err))
}
key, err := tokenizer.Parse(tc.issuedToken)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err))
if err == nil {
assert.Equal(t, tc.key, key, fmt.Sprintf("%s expected %v, got %v", tc.desc, tc.key, key))
}
tokenCall.Unset()
tokenCall1.Unset()
tokenCall2.Unset()
}
}

func key() auth.Key {
exp := time.Now().UTC().Add(10 * time.Minute).Round(time.Second)
return auth.Key{
ID: "66af4a67-3823-438a-abd7-efdb613eaef6",
Type: auth.AccessKey,
Issuer: "magistrala.auth",
Subject: "66af4a67-3823-438a-abd7-efdb613eaef6",
IssuedAt: time.Now().UTC().Add(-10 * time.Second).Round(time.Second),
ExpiresAt: exp,
}
}

func oauthKey(t *testing.T) auth.Key {
return auth.Key{
ID: testsutil.GenerateUUID(t),
Type: auth.AccessKey,
Issuer: "magistrala.auth",
Subject: testsutil.GenerateUUID(t),
User: testsutil.GenerateUUID(t),
Domain: testsutil.GenerateUUID(t),
IssuedAt: time.Now().UTC().Add(-10 * time.Second).Round(time.Second),
ExpiresAt: time.Now().UTC().Add(10 * time.Minute).Round(time.Second),
OAuth: auth.OAuthToken{
Provider: "test",
AccessToken: strings.Repeat("a", 10),
RefreshToken: strings.Repeat("b", 10),
},
}
}

func invalidOauthToken(t *testing.T, key auth.Key, provider, accessToken, refreshToken interface{}) string {
builder := jwt.NewBuilder()
builder.
Issuer(issuerName).
IssuedAt(key.IssuedAt).
Subject(key.Subject).
Claim(tokenType, key.Type).
Expiration(key.ExpiresAt)
builder.Claim(userField, key.User)
builder.Claim(domainField, key.Domain)
if provider != nil {
builder.Claim("oauth_provider", provider)
if accessToken != nil {
builder.Claim(provider.(string), map[string]interface{}{"access_token": accessToken})
}
if refreshToken != nil {
builder.Claim(provider.(string), map[string]interface{}{"refresh_token": refreshToken})
}
}
if key.ID != "" {
builder.JwtID(key.ID)
}
tkn, err := builder.Build()
require.Nil(t, err, fmt.Sprintf("building token expected to succeed: %s", err))
signedTkn, err := jwt.Sign(tkn, jwt.WithKey(jwa.HS512, reposecret))
require.Nil(t, err, fmt.Sprintf("signing token expected to succeed: %s", err))
return string(signedTkn)
}
Loading

0 comments on commit 0f05c10

Please sign in to comment.