Skip to content

Commit

Permalink
refactor: remove more packages
Browse files Browse the repository at this point in the history
  • Loading branch information
james-d-elliott committed Dec 20, 2023
1 parent c22e343 commit 23cc8f6
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 29 deletions.
19 changes: 10 additions & 9 deletions equalKeys_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ package oauth2_test
import (
"testing"

"github.com/oleiade/reflections"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"authelia.com/provider/oauth2/internal/reflection"
)

func TestAssertObjectsAreEqualByKeys(t *testing.T) {
Expand All @@ -28,9 +29,9 @@ func TestAssertObjectsAreEqualByKeys(t *testing.T) {
func AssertObjectKeysEqual(t *testing.T, a, b any, keys ...string) {
assert.True(t, len(keys) > 0, "No key provided.")
for _, k := range keys {
c, err := reflections.GetField(a, k)
c, err := reflection.GetField(a, k)
assert.NoError(t, err)
d, err := reflections.GetField(b, k)
d, err := reflection.GetField(b, k)
assert.NoError(t, err)
assert.Equal(t, c, d, "field: %s", k)
}
Expand All @@ -39,9 +40,9 @@ func AssertObjectKeysEqual(t *testing.T, a, b any, keys ...string) {
func AssertObjectKeysNotEqual(t *testing.T, a, b any, keys ...string) {
assert.True(t, len(keys) > 0, "No key provided.")
for _, k := range keys {
c, err := reflections.GetField(a, k)
c, err := reflection.GetField(a, k)
assert.NoError(t, err)
d, err := reflections.GetField(b, k)
d, err := reflection.GetField(b, k)
assert.NoError(t, err)
assert.NotEqual(t, c, d, "%s", k)
}
Expand All @@ -50,9 +51,9 @@ func AssertObjectKeysNotEqual(t *testing.T, a, b any, keys ...string) {
func RequireObjectKeysEqual(t *testing.T, a, b any, keys ...string) {
assert.True(t, len(keys) > 0, "No key provided.")
for _, k := range keys {
c, err := reflections.GetField(a, k)
c, err := reflection.GetField(a, k)
assert.NoError(t, err)
d, err := reflections.GetField(b, k)
d, err := reflection.GetField(b, k)
assert.NoError(t, err)
require.Equal(t, c, d, "%s", k)
}
Expand All @@ -61,9 +62,9 @@ func RequireObjectKeysEqual(t *testing.T, a, b any, keys ...string) {
func RequireObjectKeysNotEqual(t *testing.T, a, b any, keys ...string) {
assert.True(t, len(keys) > 0, "No key provided.")
for _, k := range keys {
c, err := reflections.GetField(a, k)
c, err := reflection.GetField(a, k)
assert.NoError(t, err)
d, err := reflections.GetField(b, k)
d, err := reflection.GetField(b, k)
assert.NoError(t, err)
require.NotEqual(t, c, d, "%s", k)
}
Expand Down
7 changes: 3 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,19 @@ module authelia.com/provider/oauth2
go 1.21

require (
github.com/cristalhq/jwt/v4 v4.0.2
github.com/dgraph-io/ristretto v0.1.1
github.com/go-jose/go-jose/v3 v3.0.1
github.com/google/uuid v1.4.0
github.com/golang-jwt/jwt/v5 v5.2.0
github.com/google/uuid v1.5.0
github.com/gorilla/mux v1.8.1
github.com/hashicorp/go-retryablehttp v0.7.5
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826
github.com/oleiade/reflections v1.0.1
github.com/parnurzeal/gorequest v0.2.16
github.com/pkg/errors v0.9.1
github.com/stretchr/testify v1.8.4
github.com/tidwall/gjson v1.17.0
go.uber.org/mock v0.3.0
golang.org/x/crypto v0.16.0
golang.org/x/crypto v0.17.0
golang.org/x/net v0.19.0
golang.org/x/oauth2 v0.15.0
golang.org/x/text v0.14.0
Expand Down
12 changes: 6 additions & 6 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/cristalhq/jwt/v4 v4.0.2 h1:g/AD3h0VicDamtlM70GWGElp8kssQEv+5wYd7L9WOhU=
github.com/cristalhq/jwt/v4 v4.0.2/go.mod h1:HnYraSNKDRag1DZP92rYHyrjyQHnVEHPNqesmzs+miQ=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand All @@ -22,6 +20,8 @@ github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w=
github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk=
github.com/go-jose/go-jose/v3 v3.0.1 h1:pWmKFVtt+Jl0vBZTIpz/eAKwsm6LkIxDVVbFHKkchhA=
github.com/go-jose/go-jose/v3 v3.0.1/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8=
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/glog v1.1.1 h1:jxpi2eWoU84wbX9iIEyAeeoac3FLuifZpY9tcNUD9kw=
github.com/golang/glog v1.1.1/go.mod h1:zR+okUeTbrL6EL3xHUDxZuEtGv04p5shwip1+mL/rLQ=
Expand All @@ -33,8 +33,8 @@ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4=
github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU=
github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8=
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
Expand Down Expand Up @@ -100,8 +100,8 @@ go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo=
go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY=
golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
Expand Down
13 changes: 8 additions & 5 deletions handler/openid/flow_hybrid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@ package openid

import (
"context"
"encoding/json"
"fmt"
"net/url"
"testing"
"time"

cristaljwt "github.com/cristalhq/jwt/v4"
xjwt "github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
Expand Down Expand Up @@ -309,10 +308,14 @@ func TestHybrid_HandleAuthorizeEndpointRequest(t *testing.T) {
idToken := aresp.GetParameters().Get("id_token")
assert.NotEmpty(t, idToken)
assert.True(t, areq.GetSession().GetExpiresAt(oauth2.IDToken).IsZero())
jwt, err := cristaljwt.ParseNoVerify([]byte(idToken))

parser := xjwt.NewParser()

claims := &xjwt.RegisteredClaims{}

_, _, err := parser.ParseUnverified(idToken, claims)

require.NoError(t, err)
claims := &cristaljwt.RegisteredClaims{}
require.NoError(t, json.Unmarshal(jwt.Claims(), claims))
internal.RequireEqualTime(t, time.Now().Add(*internal.TestLifespans.ImplicitGrantIDTokenLifespan), claims.ExpiresAt.Time, time.Minute)

assert.NotEmpty(t, aresp.GetParameters().Get("access_token"))
Expand Down
43 changes: 43 additions & 0 deletions internal/reflection/field.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package reflection

import (
"errors"
"fmt"
"reflect"
)

func GetField(obj interface{}, name string) (interface{}, error) {
if !hasValidType(obj, []reflect.Kind{reflect.Struct, reflect.Ptr}) {
return nil, errors.New("Cannot use GetField on a non-struct interface")
}

objValue := reflectValue(obj)
field := objValue.FieldByName(name)
if !field.IsValid() {
return nil, fmt.Errorf("No such field: %s in obj", name)
}

return field.Interface(), nil
}

func hasValidType(obj interface{}, types []reflect.Kind) bool {
for _, t := range types {
if reflect.TypeOf(obj).Kind() == t {
return true
}
}

return false
}

func reflectValue(obj interface{}) reflect.Value {
var val reflect.Value

if reflect.TypeOf(obj).Kind() == reflect.Ptr {
val = reflect.ValueOf(obj).Elem()
} else {
val = reflect.ValueOf(obj)
}

return val
}
13 changes: 8 additions & 5 deletions internal/test_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package internal

import (
"encoding/json"
"errors"
"fmt"
"io"
Expand All @@ -13,7 +12,7 @@ import (
"testing"
"time"

cristaljwt "github.com/cristalhq/jwt/v4"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
"golang.org/x/net/html"
xoauth2 "golang.org/x/oauth2"
Expand Down Expand Up @@ -61,13 +60,17 @@ func RequireEqualTime(t *testing.T, expected time.Time, actual time.Time, precis
}

func ExtractJwtExpClaim(t *testing.T, token string) *time.Time {
jwt, err := cristaljwt.ParseNoVerify([]byte(token))
parser := jwt.NewParser(jwt.WithoutClaimsValidation())

claims := &jwt.RegisteredClaims{}

_, _, err := parser.ParseUnverified(token, claims)
require.NoError(t, err)
claims := &cristaljwt.RegisteredClaims{}
require.NoError(t, json.Unmarshal(jwt.Claims(), claims))

if claims.ExpiresAt == nil {
return nil
}

return &claims.ExpiresAt.Time
}

Expand Down

0 comments on commit 23cc8f6

Please sign in to comment.