Skip to content

Commit

Permalink
Makes ACR cloudauth messages more informative (#62)
Browse files Browse the repository at this point in the history
* makes acr cloudauth messages more informative

- passes ctx when registering providers
- checks for the presence of envvars to decide if ACR should be enabled;
  errors out if credentials are incomplete or incorrect

* fixes linting false-positive
  • Loading branch information
sonnysideup authored Sep 16, 2022
1 parent c9132bb commit aea3822
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 48 deletions.
4 changes: 4 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,7 @@ issues:
text: "waitForIstioSidecar"
linters:
- contextcheck
- path: pkg/controller/support/credentials/cloudauth/ecr/ecr.go
text: "NewFromConfig"
linters:
- contextcheck
2 changes: 1 addition & 1 deletion pkg/controller/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func Start(cfg config.Controller) error {
}

log.Info("Registering cloud auth providers")
if err = credentials.LoadCloudProviders(log); err != nil {
if err = credentials.LoadCloudProviders(ctx, log); err != nil {
return err
}

Expand Down
68 changes: 34 additions & 34 deletions pkg/controller/support/credentials/cloudauth/acr/acr.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package acr

import (
"context"
"errors"
"fmt"
"os"
"regexp"
"time"

Expand All @@ -15,93 +15,93 @@ import (
"github.com/docker/docker/api/types"
"github.com/go-logr/logr"

cloudauth2 "github.com/dominodatalab/hephaestus/pkg/controller/support/credentials/cloudauth"
"github.com/dominodatalab/hephaestus/pkg/controller/support/credentials/cloudauth"
)

// https://github.com/Azure/acr/blob/main/docs/AAD-OAuth.md

const acrUserForRefreshToken = "00000000-0000-0000-0000-000000000000"

var (
acrRegex = regexp.MustCompile(`.*\.azurecr\.io|.*\.azurecr\.cn|.*\.azurecr\.de|.*\.azurecr\.us`)

ErrNoCredentials = errors.New("no Azure Credentials")
)
var acrRegex = regexp.MustCompile(`.*\.azurecr\.io|.*\.azurecr\.cn|.*\.azurecr\.de|.*\.azurecr\.us`)

type acrProvider struct {
logger logr.Logger
tenantID string
servicePrincipalToken *adal.ServicePrincipalToken
}

func Register(logger logr.Logger, registry *cloudauth2.Registry) error {
provider, err := newProvider(logger)
// Register will instantiate a new authentication provider whenever the AZURE_TENANT_ID or AZURE_CLIENT_ID envvars are
// present, otherwise it will result in a no-op. An error will be returned whenever the envvar settings are invalid.
func Register(ctx context.Context, logger logr.Logger, registry *cloudauth.Registry) error {
_, tenantIDDefined := os.LookupEnv(auth.TenantID)
_, clientIDDefined := os.LookupEnv(auth.ClientID)
if !(tenantIDDefined && clientIDDefined) {
logger.Info(fmt.Sprintf(
"ACR authentication provider not registered, %s or %s is absent", auth.TenantID, auth.ClientID,
))

return nil
}

provider, err := newProvider(ctx, logger)
if err != nil {
logger.Info("ACR not registered", "error", err)
if err == ErrNoCredentials {
return nil
}
return err
return fmt.Errorf("failed to create authentication provider: %w", err)
}

registry.Register(acrRegex, provider.authenticate)
logger.Info("ACR registered")
logger.Info("ACR authentication provider registered")

return nil
}

func newProvider(logger logr.Logger) (*acrProvider, error) {
func newProvider(ctx context.Context, logger logr.Logger) (*acrProvider, error) {
settings, err := auth.GetSettingsFromEnvironment()
if err != nil {
return nil, err
return nil, fmt.Errorf("cannot get settings from env: %w", err)
}

// the minimum set of required values
if settings.Values[auth.TenantID] == "" || settings.Values[auth.ClientID] == "" {
return nil, ErrNoCredentials
}
var token *adal.ServicePrincipalToken

var spt *adal.ServicePrincipalToken
if cc, err := settings.GetClientCredentials(); err == nil {
if spt, err = cc.ServicePrincipalToken(); err != nil {
return nil, err
if token, err = cc.ServicePrincipalToken(); err != nil {
return nil, fmt.Errorf("retrieving service principal token failed: %w", err)
}
} else {
ctx := context.Background()
err = retry(ctx, logger, 3, func() error {
spt, err = settings.GetMSI().ServicePrincipalToken()
token, err = settings.GetMSI().ServicePrincipalToken()
return err
})

if err != nil {
// IMDS can take some time to setup, restart the process
return nil, fmt.Errorf("retreiving Service Principal Token from MSI failed: %w", err)
// IMDS can take some time to set up, restart the process
return nil, fmt.Errorf("retreiving service principal token from MSI failed: %w", err)
}
}

return &acrProvider{
logger: logger.WithName("acrProvider"),
logger: logger.WithName("acr-auth-provider"),
tenantID: settings.Values[auth.TenantID],
servicePrincipalToken: spt,
servicePrincipalToken: token,
}, nil
}

func (a *acrProvider) authenticate(ctx context.Context, server string) (*types.AuthConfig, error) {
match := acrRegex.FindAllString(server, -1)
if len(match) != 1 {
return nil, fmt.Errorf("invalid acr url: %q should match %v", server, acrRegex)
return nil, fmt.Errorf("invalid ACR url: %q should match %v", server, acrRegex)
}

loginServer := match[0]
err := retry(ctx, a.logger, 3, func() error {
return a.servicePrincipalToken.EnsureFreshWithContext(ctx)
})
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to refresh AAD token: %w", err)
}

armAccessToken := a.servicePrincipalToken.OAuthToken()
loginServerURL := "https://" + loginServer
directive, err := cloudauth2.ChallengeLoginServer(ctx, loginServerURL)
directive, err := cloudauth.ChallengeLoginServer(ctx, loginServerURL)
if err != nil {
return nil, err
}
Expand All @@ -116,7 +116,7 @@ func (a *acrProvider) authenticate(ctx context.Context, server string) (*types.A
armAccessToken,
)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to generate ACR refresh token: %w", err)
}

return &types.AuthConfig{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
"github.com/dominodatalab/hephaestus/pkg/controller/support/credentials/cloudauth"
)

func TestRegister(t *testing.T) {
func TestRegisterIntegration(t *testing.T) {
if os.Getenv(auth.ClientSecret) == "" {
t.Skip("Skipping, azure not setup")
}
Expand All @@ -30,7 +30,7 @@ func TestRegister(t *testing.T) {
}
}

func TestRegisterNoSecret(t *testing.T) {
func TestRegisterNoSecretIntegration(t *testing.T) {
secret := os.Getenv(auth.ClientSecret)
os.Unsetenv(auth.ClientSecret)
t.Cleanup(func() {
Expand All @@ -46,7 +46,7 @@ func TestRegisterNoSecret(t *testing.T) {
}
}

func TestAuthenticate(t *testing.T) {
func TestAuthenticateIntegration(t *testing.T) {
if os.Getenv(auth.ClientSecret) == "" {
t.Skip("Skipping, azure not setup")
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/controller/support/credentials/cloudauth/ecr/ecr.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ var (
)
)

func Register(logger logr.Logger, registry *cloudauth.Registry) error {
config, err := config.LoadDefaultConfig(context.Background(), config.WithEC2IMDSRegion())
func Register(ctx context.Context, logger logr.Logger, registry *cloudauth.Registry) error {
config, err := config.LoadDefaultConfig(ctx, config.WithEC2IMDSRegion())
if err != nil {
logger.Info("ECR not registered", "error", err)
return nil
Expand Down
6 changes: 3 additions & 3 deletions pkg/controller/support/credentials/cloudauth/gcr/gcr.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ type gcrProvider struct {
tokenSource oauth2.TokenSource
}

func Register(logger logr.Logger, registry *cloudauth2.Registry) error {
provider, err := newProvider(context.Background(), logger)
func Register(ctx context.Context, logger logr.Logger, registry *cloudauth2.Registry) error {
provider, err := newProvider(ctx, logger)
if err != nil {
logger.Info("GCR not registered", "error", err)
if strings.Contains(err.Error(), "could not find default credentials") {
Expand All @@ -66,7 +66,7 @@ func newProvider(ctx context.Context, logger logr.Logger) (*gcrProvider, error)
return nil, err
}

return &gcrProvider{logger: logger.WithName("gcrProvider"), tokenSource: creds.TokenSource}, nil
return &gcrProvider{logger: logger.WithName("gcr-auth-provider"), tokenSource: creds.TokenSource}, nil
}

func (g *gcrProvider) authenticate(ctx context.Context, server string) (*types.AuthConfig, error) {
Expand Down
10 changes: 5 additions & 5 deletions pkg/controller/support/credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func Persist(ctx context.Context, cfg *rest.Config, credentials []hephv1.Registr
case pointer.BoolDeref(cred.CloudProvided, false):
pac, err := CloudAuthRegistry.RetrieveAuthorization(ctx, cred.Server)
if err != nil {
return "", err
return "", fmt.Errorf("cloud registry authorization failed: %w", err)
}

ac = *pac
Expand Down Expand Up @@ -142,14 +142,14 @@ func Verify(ctx context.Context, configDir string, insecureRegistries []string)
}

// LoadCloudProviders adds all cloud authentication providers to the CloudAuthRegistry.
func LoadCloudProviders(log logr.Logger) error {
if err := acr.Register(log, CloudAuthRegistry); err != nil {
func LoadCloudProviders(ctx context.Context, log logr.Logger) error {
if err := acr.Register(ctx, log, CloudAuthRegistry); err != nil {
return fmt.Errorf("ACR registration failed: %w", err)
}
if err := ecr.Register(log, CloudAuthRegistry); err != nil {
if err := ecr.Register(ctx, log, CloudAuthRegistry); err != nil {
return fmt.Errorf("ECR registration failed: %w", err)
}
if err := gcr.Register(log, CloudAuthRegistry); err != nil {
if err := gcr.Register(ctx, log, CloudAuthRegistry); err != nil {
return fmt.Errorf("GCR registration failed: %w", err)
}

Expand Down

0 comments on commit aea3822

Please sign in to comment.