Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: handle empty or invalid secrets nicely #4801

Merged
merged 5 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions flyteadmin/auth/authzserver/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ func NewProvider(ctx context.Context, cfg config.AuthorizationServer, sm core.Se
if err != nil {
return Provider{}, fmt.Errorf("failed to read secretTokenHash file. Error: %w", err)
}
if tokenHashBase64 == "" {
return Provider{}, fmt.Errorf("failed to read secretTokenHash. Error: empty value")
}

secret, err := base64.RawStdEncoding.DecodeString(tokenHashBase64)
if err != nil {
Expand All @@ -158,8 +161,14 @@ func NewProvider(ctx context.Context, cfg config.AuthorizationServer, sm core.Se
if err != nil {
return Provider{}, fmt.Errorf("failed to read token signing RSA Key. Error: %w", err)
}
if privateKeyPEM == "" {
return Provider{}, fmt.Errorf("failed to read token signing RSA Key. Error: empty value")
}

block, _ := pem.Decode([]byte(privateKeyPEM))
if block == nil {
return Provider{}, fmt.Errorf("failed to decode token signing RSA Key. Error: no PEM data found")
}
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return Provider{}, fmt.Errorf("failed to parse PKCS1PrivateKey. Error: %w", err)
Expand Down Expand Up @@ -197,7 +206,13 @@ func NewProvider(ctx context.Context, cfg config.AuthorizationServer, sm core.Se
// Try to load old key to validate tokens using it to support key rotation.
privateKeyPEM, err = sm.Get(ctx, cfg.OldTokenSigningRSAKeySecretName)
if err == nil {
if privateKeyPEM == "" {
return Provider{}, fmt.Errorf("failed to read PKCS1PrivateKey. Error: empty value")
}
block, _ = pem.Decode([]byte(privateKeyPEM))
if block == nil {
return Provider{}, fmt.Errorf("failed to decode PKCS1PrivateKey. Error: no PEM data found")
}
oldPrivateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return Provider{}, fmt.Errorf("failed to parse PKCS1PrivateKey. Error: %w", err)
Expand Down
125 changes: 121 additions & 4 deletions flyteadmin/auth/authzserver/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func newMockProvider(t testing.TB) (Provider, auth.SecretsSet) {
var buf bytes.Buffer
assert.NoError(t, pem.Encode(&buf, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: privBytes}))
sm.OnGet(ctx, config.SecretNameTokenSigningRSAKey).Return(buf.String(), nil)
sm.OnGet(ctx, config.SecretNameOldTokenSigningRSAKey).Return("", fmt.Errorf("not found"))
sm.OnGet(ctx, config.SecretNameOldTokenSigningRSAKey).Return(buf.String(), nil)

p, err := NewProvider(ctx, config.DefaultConfig.AppAuth.SelfAuthServer, sm)
assert.NoError(t, err)
Expand All @@ -45,9 +45,126 @@ func TestNewProvider(t *testing.T) {
newMockProvider(t)
}

func newInvalidMockProvider(ctx context.Context, t *testing.T, secrets auth.SecretsSet, sm *mocks.SecretManager, invalidFunc func() *mocks.SecretManager_Get, errorContains string) {

sm.OnGet(ctx, config.SecretNameClaimSymmetricKey).Return(base64.RawStdEncoding.EncodeToString(secrets.TokenHashKey), nil)
sm.OnGet(ctx, config.SecretNameCookieBlockKey).Return(base64.RawStdEncoding.EncodeToString(secrets.CookieBlockKey), nil)
sm.OnGet(ctx, config.SecretNameCookieHashKey).Return(base64.RawStdEncoding.EncodeToString(secrets.CookieHashKey), nil)

privBytes := x509.MarshalPKCS1PrivateKey(secrets.TokenSigningRSAPrivateKey)
var buf bytes.Buffer
assert.NoError(t, pem.Encode(&buf, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: privBytes}))
sm.OnGet(ctx, config.SecretNameTokenSigningRSAKey).Return(buf.String(), nil)
sm.OnGet(ctx, config.SecretNameOldTokenSigningRSAKey).Return(buf.String(), nil)

invalidFunc()
p, err := NewProvider(ctx, config.DefaultConfig.AppAuth.SelfAuthServer, sm)
assert.Error(t, err)
assert.ErrorContains(t, err, errorContains)
assert.Equal(t, Provider{}, p)
}

func TestNewInvalidProviderSecretTokenHashBad(t *testing.T) {
secrets, err := auth.NewSecrets()
assert.NoError(t, err)

ctx := context.Background()
sm := &mocks.SecretManager{}

invalidFunc := func() *mocks.SecretManager_Get {
sm.OnGet(ctx, config.SecretNameClaimSymmetricKey).Unset()
return sm.OnGet(ctx, config.SecretNameClaimSymmetricKey).Return("", fmt.Errorf("test error"))
}
newInvalidMockProvider(ctx, t, secrets, sm, invalidFunc, "failed to read secretTokenHash file. Error: test error")
}

func TestNewInvalidProviderSecretTokenHashEmpty(t *testing.T) {
secrets, err := auth.NewSecrets()
assert.NoError(t, err)

ctx := context.Background()
sm := &mocks.SecretManager{}

invalidFunc := func() *mocks.SecretManager_Get {
sm.OnGet(ctx, config.SecretNameClaimSymmetricKey).Unset()
return sm.OnGet(ctx, config.SecretNameClaimSymmetricKey).Return("", nil)
}
newInvalidMockProvider(ctx, t, secrets, sm, invalidFunc, "failed to read secretTokenHash. Error: empty value")
}

func TestNewInvalidProviderTokenSigningRSAKeyBad(t *testing.T) {
secrets, err := auth.NewSecrets()
assert.NoError(t, err)

ctx := context.Background()
sm := &mocks.SecretManager{}

invalidFunc := func() *mocks.SecretManager_Get {
sm.OnGet(ctx, config.SecretNameTokenSigningRSAKey).Unset()
return sm.OnGet(ctx, config.SecretNameTokenSigningRSAKey).Return("", fmt.Errorf("test error"))
}
newInvalidMockProvider(ctx, t, secrets, sm, invalidFunc, "failed to read token signing RSA Key. Error: test error")
}

func TestNewInvalidProviderTokenSigningRSAKeyEmpty(t *testing.T) {
secrets, err := auth.NewSecrets()
assert.NoError(t, err)

ctx := context.Background()
sm := &mocks.SecretManager{}

invalidFunc := func() *mocks.SecretManager_Get {
sm.OnGet(ctx, config.SecretNameTokenSigningRSAKey).Unset()
return sm.OnGet(ctx, config.SecretNameTokenSigningRSAKey).Return("", nil)
}
newInvalidMockProvider(ctx, t, secrets, sm, invalidFunc, "failed to read token signing RSA Key. Error: empty value")
}

func TestNewInvalidProviderTokenSigningRSAKeyNoPEMData(t *testing.T) {
secrets, err := auth.NewSecrets()
assert.NoError(t, err)

ctx := context.Background()
sm := &mocks.SecretManager{}

invalidFunc := func() *mocks.SecretManager_Get {
sm.OnGet(ctx, config.SecretNameTokenSigningRSAKey).Unset()
return sm.OnGet(ctx, config.SecretNameTokenSigningRSAKey).Return("this is no PEM data", nil)
}
newInvalidMockProvider(ctx, t, secrets, sm, invalidFunc, "failed to decode token signing RSA Key. Error: no PEM data found")
}

func TestNewInvalidProviderOldTokenSigningRSAKeyEmpty(t *testing.T) {
secrets, err := auth.NewSecrets()
assert.NoError(t, err)

ctx := context.Background()
sm := &mocks.SecretManager{}

invalidFunc := func() *mocks.SecretManager_Get {
sm.OnGet(ctx, config.SecretNameOldTokenSigningRSAKey).Unset()
return sm.OnGet(ctx, config.SecretNameOldTokenSigningRSAKey).Return("", nil)
}
newInvalidMockProvider(ctx, t, secrets, sm, invalidFunc, "failed to read PKCS1PrivateKey. Error: empty value")
}

func TestNewInvalidProviderOldTokenSigningRSAKeyNoPEMData(t *testing.T) {
secrets, err := auth.NewSecrets()
assert.NoError(t, err)

ctx := context.Background()
sm := &mocks.SecretManager{}

invalidFunc := func() *mocks.SecretManager_Get {
sm.OnGet(ctx, config.SecretNameOldTokenSigningRSAKey).Unset()
return sm.OnGet(ctx, config.SecretNameOldTokenSigningRSAKey).Return("this is no PEM data", nil)
}
newInvalidMockProvider(ctx, t, secrets, sm, invalidFunc, "failed to decode PKCS1PrivateKey. Error: no PEM data found")
}

func TestProvider_KeySet(t *testing.T) {
p, _ := newMockProvider(t)
assert.Equal(t, 1, p.KeySet().Len())
assert.Equal(t, 2, p.KeySet().Len())
}

func TestProvider_NewJWTSessionToken(t *testing.T) {
Expand All @@ -64,7 +181,7 @@ func TestProvider_NewJWTSessionToken(t *testing.T) {

func TestProvider_PublicKeys(t *testing.T) {
p, _ := newMockProvider(t)
assert.Len(t, p.PublicKeys(), 1)
assert.Len(t, p.PublicKeys(), 2)
}

type CustomClaimsExample struct {
Expand Down Expand Up @@ -175,7 +292,7 @@ func TestProvider_ValidateAccessToken(t *testing.T) {
var buf bytes.Buffer
assert.NoError(t, pem.Encode(&buf, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: privBytes}))
sm.OnGet(ctx, config.SecretNameTokenSigningRSAKey).Return(buf.String(), nil)
sm.OnGet(ctx, config.SecretNameOldTokenSigningRSAKey).Return("", fmt.Errorf("not found"))
sm.OnGet(ctx, config.SecretNameOldTokenSigningRSAKey).Return(buf.String(), nil)

p, err := NewProvider(ctx, config.DefaultConfig.AppAuth.SelfAuthServer, sm)
assert.NoError(t, err)
Expand Down
Loading