diff --git a/internal/creds.go b/internal/creds.go index 4ebeb61c1a..4792b00c8b 100644 --- a/internal/creds.go +++ b/internal/creds.go @@ -15,6 +15,7 @@ import ( "os" "time" + "cloud.google.com/go/auth" "cloud.google.com/go/auth/credentials" "cloud.google.com/go/auth/oauth2adapt" "golang.org/x/oauth2" @@ -30,7 +31,7 @@ const quotaProjectEnvVar = "GOOGLE_CLOUD_QUOTA_PROJECT" // it returns default credential information. func Creds(ctx context.Context, ds *DialSettings) (*google.Credentials, error) { if ds.IsNewAuthLibraryEnabled() { - return credsNewAuth(ctx, ds) + return credsNewAuth(ds) } creds, err := baseCreds(ctx, ds) if err != nil { @@ -42,6 +43,30 @@ func Creds(ctx context.Context, ds *DialSettings) (*google.Credentials, error) { return creds, nil } +// AuthCreds returns [cloud.google.com/go/auth.Credentials] based on credentials +// options provided via [option.ClientOption], including legacy oauth2/google +// options. If there are no applicable options, then it returns the result of +// [cloud.google.com/go/auth/credentials.DetectDefault]. +func AuthCreds(ctx context.Context, settings *DialSettings) (*auth.Credentials, error) { + if settings.AuthCredentials != nil { + return settings.AuthCredentials, nil + } + // Support oauth2/google options + var oauth2Creds *google.Credentials + if settings.InternalCredentials != nil { + oauth2Creds = settings.InternalCredentials + } else if settings.Credentials != nil { + oauth2Creds = settings.Credentials + } else if settings.TokenSource != nil { + oauth2Creds = &google.Credentials{TokenSource: settings.TokenSource} + } + if oauth2Creds != nil { + return oauth2adapt.AuthCredentialsFromOauth2Credentials(oauth2Creds), nil + } + + return detectDefaultFromDialSettings(settings) +} + // GetOAuth2Configuration determines configurations for the OAuth2 transport, which is separate from the API transport. // The OAuth2 transport and endpoint will be configured for mTLS if applicable. func GetOAuth2Configuration(ctx context.Context, settings *DialSettings) (string, *http.Client, error) { @@ -62,7 +87,7 @@ func GetOAuth2Configuration(ctx context.Context, settings *DialSettings) (string return tokenURL, oauth2Client, nil } -func credsNewAuth(ctx context.Context, settings *DialSettings) (*google.Credentials, error) { +func credsNewAuth(settings *DialSettings) (*google.Credentials, error) { // Preserve old options behavior if settings.InternalCredentials != nil { return settings.InternalCredentials, nil @@ -76,6 +101,14 @@ func credsNewAuth(ctx context.Context, settings *DialSettings) (*google.Credenti return oauth2adapt.Oauth2CredentialsFromAuthCredentials(settings.AuthCredentials), nil } + creds, err := detectDefaultFromDialSettings(settings) + if err != nil { + return nil, err + } + return oauth2adapt.Oauth2CredentialsFromAuthCredentials(creds), nil +} + +func detectDefaultFromDialSettings(settings *DialSettings) (*auth.Credentials, error) { var useSelfSignedJWT bool var aud string var scopes []string @@ -100,18 +133,13 @@ func credsNewAuth(ctx context.Context, settings *DialSettings) (*google.Credenti aud = settings.DefaultAudience } - creds, err := credentials.DetectDefault(&credentials.DetectOptions{ + return credentials.DetectDefault(&credentials.DetectOptions{ Scopes: scopes, Audience: aud, CredentialsFile: settings.CredentialsFile, CredentialsJSON: settings.CredentialsJSON, UseSelfSignedJWT: useSelfSignedJWT, }) - if err != nil { - return nil, err - } - - return oauth2adapt.Oauth2CredentialsFromAuthCredentials(creds), nil } func baseCreds(ctx context.Context, ds *DialSettings) (*google.Credentials, error) { diff --git a/internal/creds_test.go b/internal/creds_test.go index d57ccca68e..69f0a1f505 100644 --- a/internal/creds_test.go +++ b/internal/creds_test.go @@ -9,11 +9,12 @@ import ( "os" "testing" + "cloud.google.com/go/auth" "golang.org/x/oauth2" "golang.org/x/oauth2/google" ) -func TestDefaultServiceAccount(t *testing.T) { +func TestCreds_DefaultServiceAccount(t *testing.T) { ctx := context.Background() // Load a valid JSON file. No way to really test the contents; we just @@ -37,7 +38,31 @@ func TestDefaultServiceAccount(t *testing.T) { } } -func TestJWTWithAudience(t *testing.T) { +func TestAuthCreds_DefaultServiceAccount(t *testing.T) { + ctx := context.Background() + + // Load a valid JSON file. No way to really test the contents; we just + // verify that there is no error. + ds := &DialSettings{ + CredentialsFile: "testdata/service-account.json", + DefaultScopes: []string{"foo"}, + } + if _, err := AuthCreds(ctx, ds); err != nil { + t.Errorf("got %v, wanted no error", err) + } + + // Load valid JSON. No way to really test the contents; we just + // verify that there is no error. + ds = &DialSettings{ + CredentialsJSON: []byte(validServiceAccountJSON), + DefaultScopes: []string{"foo"}, + } + if _, err := AuthCreds(ctx, ds); err != nil { + t.Errorf("got %v, wanted no error", err) + } +} + +func TestCreds_JWTWithAudience(t *testing.T) { ctx := context.Background() // Load a valid JSON file. No way to really test the contents; we just @@ -55,7 +80,25 @@ func TestJWTWithAudience(t *testing.T) { } } -func TestJWTWithScope(t *testing.T) { +func TestAuthCreds_JWTWithAudience(t *testing.T) { + ctx := context.Background() + + // Load a valid JSON file. No way to really test the contents; we just + // verify that there is no error. + ds := &DialSettings{CredentialsFile: "testdata/service-account.json", Audiences: []string{"foo"}} + if _, err := AuthCreds(ctx, ds); err != nil { + t.Errorf("got %v, wanted no error", err) + } + + // Load valid JSON. No way to really test the contents; we just + // verify that there is no error. + ds = &DialSettings{CredentialsJSON: []byte(validServiceAccountJSON), Audiences: []string{"foo"}} + if _, err := AuthCreds(ctx, ds); err != nil { + t.Errorf("got %v, wanted no error", err) + } +} + +func TestCreds_JWTWithScope(t *testing.T) { ctx := context.Background() // Load a valid JSON file. No way to really test the contents; we just @@ -81,7 +124,33 @@ func TestJWTWithScope(t *testing.T) { } } -func TestJWTWithScopeAndUniverseDomain(t *testing.T) { +func TestAuthCreds_JWTWithScope(t *testing.T) { + ctx := context.Background() + + // Load a valid JSON file. No way to really test the contents; we just + // verify that there is no error. + ds := &DialSettings{ + CredentialsFile: "testdata/service-account.json", + Scopes: []string{"foo"}, + EnableJwtWithScope: true, + } + if _, err := AuthCreds(ctx, ds); err != nil { + t.Errorf("got %v, wanted no error", err) + } + + // Load valid JSON. No way to really test the contents; we just + // verify that there is no error. + ds = &DialSettings{ + CredentialsJSON: []byte(validServiceAccountJSON), + Scopes: []string{"foo"}, + EnableJwtWithScope: true, + } + if _, err := AuthCreds(ctx, ds); err != nil { + t.Errorf("got %v, wanted no error", err) + } +} + +func TestCreds_JWTWithScopeAndUniverseDomain(t *testing.T) { ctx := context.Background() // Load a valid JSON file. No way to really test the contents; we just @@ -109,7 +178,35 @@ func TestJWTWithScopeAndUniverseDomain(t *testing.T) { } } -func TestJWTWithDefaultScopes(t *testing.T) { +func TestAuthCreds_JWTWithScopeAndUniverseDomain(t *testing.T) { + ctx := context.Background() + + // Load a valid JSON file. No way to really test the contents; we just + // verify that there is no error. + ds := &DialSettings{ + CredentialsFile: "testdata/service-account.json", + Scopes: []string{"foo"}, + EnableJwtWithScope: true, + UniverseDomain: "example.com", + } + if _, err := AuthCreds(ctx, ds); err != nil { + t.Errorf("got %v, wanted no error", err) + } + + // Load valid JSON. No way to really test the contents; we just + // verify that there is no error. + ds = &DialSettings{ + CredentialsJSON: []byte(validServiceAccountJSON), + Scopes: []string{"foo"}, + EnableJwtWithScope: true, + UniverseDomain: "example.com", + } + if _, err := AuthCreds(ctx, ds); err != nil { + t.Errorf("got %v, wanted no error", err) + } +} + +func TestCreds_JWTWithDefaultScopes(t *testing.T) { ctx := context.Background() // Load a valid JSON file. No way to really test the contents; we just @@ -135,7 +232,33 @@ func TestJWTWithDefaultScopes(t *testing.T) { } } -func TestJWTWithDefaultAudience(t *testing.T) { +func TestAuthCreds_JWTWithDefaultScopes(t *testing.T) { + ctx := context.Background() + + // Load a valid JSON file. No way to really test the contents; we just + // verify that there is no error. + ds := &DialSettings{ + CredentialsFile: "testdata/service-account.json", + DefaultScopes: []string{"foo"}, + EnableJwtWithScope: true, + } + if _, err := AuthCreds(ctx, ds); err != nil { + t.Errorf("got %v, wanted no error", err) + } + + // Load valid JSON. No way to really test the contents; we just + // verify that there is no error. + ds = &DialSettings{ + CredentialsJSON: []byte(validServiceAccountJSON), + DefaultScopes: []string{"foo"}, + EnableJwtWithScope: true, + } + if _, err := AuthCreds(ctx, ds); err != nil { + t.Errorf("got %v, wanted no error", err) + } +} + +func TestCreds_JWTWithDefaultAudience(t *testing.T) { ctx := context.Background() // Load a valid JSON file. No way to really test the contents; we just @@ -159,7 +282,31 @@ func TestJWTWithDefaultAudience(t *testing.T) { } } -func TestOAuth(t *testing.T) { +func TestAuthCreds_JWTWithDefaultAudience(t *testing.T) { + ctx := context.Background() + + // Load a valid JSON file. No way to really test the contents; we just + // verify that there is no error. + ds := &DialSettings{ + CredentialsFile: "testdata/service-account.json", + DefaultAudience: "foo", + } + if _, err := AuthCreds(ctx, ds); err != nil { + t.Errorf("got %v, wanted no error", err) + } + + // Load valid JSON. No way to really test the contents; we just + // verify that there is no error. + ds = &DialSettings{ + CredentialsJSON: []byte(validServiceAccountJSON), + DefaultAudience: "foo", + } + if _, err := AuthCreds(ctx, ds); err != nil { + t.Errorf("got %v, wanted no error", err) + } +} + +func TestCreds_CredentialsFile_CredentialsJSON(t *testing.T) { ctx := context.Background() // Load a valid JSON file. No way to really test the contents; we just @@ -177,6 +324,24 @@ func TestOAuth(t *testing.T) { } } +func TestAuthCreds_CredentialsFile_CredentialsJSON(t *testing.T) { + ctx := context.Background() + + // Load a valid JSON file. No way to really test the contents; we just + // verify that there is no error. + ds := &DialSettings{CredentialsFile: "testdata/service-account.json", Scopes: []string{"foo"}} + if _, err := AuthCreds(ctx, ds); err != nil { + t.Errorf("got %v, wanted no error", err) + } + + // Load valid JSON. No way to really test the contents; we just + // verify that there is no error. + ds = &DialSettings{CredentialsJSON: []byte(validServiceAccountJSON), Scopes: []string{"foo"}} + if _, err := AuthCreds(ctx, ds); err != nil { + t.Errorf("got %v, wanted no error", err) + } +} + const validServiceAccountJSON = `{ "type": "service_account", "project_id": "dumba-504", @@ -268,7 +433,7 @@ func TestGetQuotaProject(t *testing.T) { } } -func TestCredsWithCredentials(t *testing.T) { +func TestCreds(t *testing.T) { tests := []struct { name string ds *DialSettings @@ -327,6 +492,98 @@ func TestCredsWithCredentials(t *testing.T) { } } +type staticTokenProvider string + +func (s staticTokenProvider) Token(context.Context) (*auth.Token, error) { + return &auth.Token{Value: string(s)}, nil +} + +func TestAuthCreds(t *testing.T) { + tests := []struct { + name string + ds *DialSettings + want string + }{ + { + name: "only token source opt", + ds: &DialSettings{ + TokenSource: oauth2.StaticTokenSource(&oauth2.Token{ + AccessToken: "token", + }), + }, + want: "token", + }, + { + name: "credentials and token source creds opt", + ds: &DialSettings{ + TokenSource: oauth2.StaticTokenSource(&oauth2.Token{ + AccessToken: "token", + }), + Credentials: &google.Credentials{ + TokenSource: oauth2.StaticTokenSource(&oauth2.Token{ + AccessToken: "credentials", + }), + }, + }, + want: "credentials", + }, + { + name: "internal, credentials and token source creds opt", + ds: &DialSettings{ + TokenSource: oauth2.StaticTokenSource(&oauth2.Token{ + AccessToken: "token", + }), + Credentials: &google.Credentials{ + TokenSource: oauth2.StaticTokenSource(&oauth2.Token{ + AccessToken: "credentials", + }), + }, + InternalCredentials: &google.Credentials{ + TokenSource: oauth2.StaticTokenSource(&oauth2.Token{ + AccessToken: "internal", + }), + }, + }, + want: "internal", + }, + { + name: "auth credentials, internal, credentials, token source creds opt", + ds: &DialSettings{ + TokenSource: oauth2.StaticTokenSource(&oauth2.Token{ + AccessToken: "token", + }), + Credentials: &google.Credentials{ + TokenSource: oauth2.StaticTokenSource(&oauth2.Token{ + AccessToken: "credentials", + }), + }, + InternalCredentials: &google.Credentials{ + TokenSource: oauth2.StaticTokenSource(&oauth2.Token{ + AccessToken: "internal", + }), + }, + AuthCredentials: &auth.Credentials{ + TokenProvider: staticTokenProvider("auth credentials"), + }, + }, + want: "auth credentials", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + creds, err := AuthCreds(ctx, tc.ds) + if err != nil { + t.Fatalf("got %v, want nil error", err) + } + if tok, _ := creds.TokenProvider.Token(ctx); tok.Value != tc.want { + t.Fatalf("tok.AccessToken = %q, want %q", tok.Value, tc.want) + } + }) + } +} + func TestIsSelfSignedJWTFlow(t *testing.T) { tests := []struct { name string diff --git a/transport/dial.go b/transport/dial.go index 652b8eba51..885ebcf997 100644 --- a/transport/dial.go +++ b/transport/dial.go @@ -8,6 +8,7 @@ import ( "context" "net/http" + "cloud.google.com/go/auth" "golang.org/x/oauth2/google" "google.golang.org/grpc" @@ -46,3 +47,30 @@ func Creds(ctx context.Context, opts ...option.ClientOption) (*google.Credential } return internal.Creds(ctx, &ds) } + +// AuthCreds returns [cloud.google.com/go/auth.Credentials] using the following +// options provided via [option.ClientOption], including legacy oauth2/google +// options, in this order: +// +// * [option.WithAuthCredentials] +// * [option/internaloption.WithCredentials] (internal use only) +// * [option.WithCredentials] +// * [option.WithTokenSource] +// +// If there are no applicable credentials options, then it passes the +// following options to [cloud.google.com/go/auth/credentials.DetectDefault] and +// returns the result: +// +// * [option.WithAudiences] +// * [option.WithCredentialsFile] +// * [option.WithCredentialsJSON] +// * [option.WithScopes] +// * [option/internaloption.WithDefaultScopes] (internal use only) +// * [option/internaloption.EnableJwtWithScope] (internal use only) +func AuthCreds(ctx context.Context, opts ...option.ClientOption) (*auth.Credentials, error) { + var ds internal.DialSettings + for _, opt := range opts { + opt.Apply(&ds) + } + return internal.AuthCreds(ctx, &ds) +}