From ce0f92f686a5b0dfc2bab083ce4d570faa20275c Mon Sep 17 00:00:00 2001 From: Bastian Doetsch Date: Mon, 29 Jul 2024 15:50:37 +0200 Subject: [PATCH] feat: add authentication error messages (oauth expiry, invalid creds) [IDE-459] (#607) --- application/di/test_init.go | 5 +- .../server/authentication_smoke_test.go | 105 ++++++++++++ application/server/configuration_test.go | 8 +- application/server/execute_command_test.go | 2 +- application/server/notification.go | 3 +- application/server/server.go | 2 +- application/server/server_test.go | 14 +- application/server/trust_test.go | 4 +- domain/ide/codelens/codelens_test.go | 2 +- domain/ide/command/command_service_test.go | 2 +- domain/ide/command/copy_auth_link.go | 9 +- domain/ide/command/get_active_user.go | 5 +- domain/ide/command/get_active_user_test.go | 2 +- domain/ide/command/get_feature_flag_status.go | 5 +- .../command/get_feature_flag_status_test.go | 2 +- domain/ide/command/logout_test.go | 7 +- domain/ide/command/report_analytics.go | 5 +- domain/ide/command/report_analytics_test.go | 2 +- domain/ide/command/sast_enabled.go | 5 +- domain/ide/command/sast_enabled_test.go | 2 +- domain/snyk/scanner.go | 5 +- domain/snyk/scanner_test.go | 2 +- go.mod | 4 +- go.sum | 8 +- .../authentication/auth_configuration.go | 50 +++--- infrastructure/authentication/auth_service.go | 6 +- .../authentication/auth_service_impl.go | 159 +++++++++++------- .../authentication/auth_service_impl_test.go | 17 +- infrastructure/authentication/initializer.go | 17 +- .../authentication/initializer_test.go | 22 +-- infrastructure/cli/environment.go | 35 ++-- infrastructure/cli/environment_test.go | 6 + 32 files changed, 320 insertions(+), 202 deletions(-) create mode 100644 application/server/authentication_smoke_test.go diff --git a/application/di/test_init.go b/application/di/test_init.go index c68eaa651..a334dbac0 100644 --- a/application/di/test_init.go +++ b/application/di/test_init.go @@ -17,9 +17,10 @@ package di import ( - "github.com/snyk/snyk-ls/domain/snyk/persistence" "testing" + "github.com/snyk/snyk-ls/domain/snyk/persistence" + "github.com/golang/mock/gomock" "github.com/snyk/snyk-ls/application/codeaction" @@ -60,7 +61,7 @@ func TestInit(t *testing.T) { installer = install.NewFakeInstaller() authProvider := authentication.NewFakeCliAuthenticationProvider(c) snykApiClient = &snyk_api.FakeApiClient{CodeEnabled: true} - authenticationService = authentication.NewAuthenticationService(c, []authentication.AuthenticationProvider{authProvider}, errorReporter, notifier) + authenticationService = authentication.NewAuthenticationService(c, authProvider, errorReporter, notifier) snykCli := cli.NewExecutor(c, errorReporter, notifier) cliInitializer = cli.NewInitializer(errorReporter, installer, notifier, snykCli) authInitializer := authentication.NewInitializer(c, authenticationService, errorReporter, notifier) diff --git a/application/server/authentication_smoke_test.go b/application/server/authentication_smoke_test.go new file mode 100644 index 000000000..0ffc7de7a --- /dev/null +++ b/application/server/authentication_smoke_test.go @@ -0,0 +1,105 @@ +/* + * © 2024 Snyk Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package server + +import ( + "context" + "encoding/json" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" + + "github.com/snyk/snyk-ls/application/di" + "github.com/snyk/snyk-ls/infrastructure/authentication" + "github.com/snyk/snyk-ls/internal/testutil" + "github.com/snyk/snyk-ls/internal/types" + "github.com/snyk/snyk-ls/internal/uri" +) + +func Test_InvalidExpiredCredentialsSendMessageRequest(t *testing.T) { + // how to process the expected callback + token := getDummyOAuth2Token(time.Now().Add(-time.Hour)) + tokenBytes, marshallingErr := json.Marshal(token) + require.NoError(t, marshallingErr) + + checkInvalidCredentialsMessageRequest(t, authentication.ExpirationMsg, string(tokenBytes)) +} + +func Test_InvalidCredentialsNotExpiredSendMessageRequest(t *testing.T) { + token := getDummyOAuth2Token(time.Now().Add(+time.Hour)) + tokenBytes, marshallingErr := json.Marshal(token) + require.NoError(t, marshallingErr) + + checkInvalidCredentialsMessageRequest(t, authentication.InvalidCredsMessage, string(tokenBytes)) +} + +func getDummyOAuth2Token(expiry time.Time) oauth2.Token { + token := oauth2.Token{ + AccessToken: "a", + TokenType: "bearer", + RefreshToken: "c", + Expiry: expiry, + } + return token +} + +func checkInvalidCredentialsMessageRequest(t *testing.T, expected string, tokenString string) { + t.Helper() + srv, jsonRpcRecorder := setupServer(t) + + c := testutil.SmokeTest(t, false) + c.SetSnykIacEnabled(false) + c.SetSnykOssEnabled(true) + // we have to reset the token, as smoketest automatically grab it from env + c.SetToken("") + di.Init() + + clientParams := types.InitializeParams{ + WorkspaceFolders: []types.WorkspaceFolder{{Uri: uri.PathToUri(t.TempDir()), Name: t.Name()}}, + InitializationOptions: types.Settings{ + Token: tokenString, + EnableTrustedFoldersFeature: "false", + FilterSeverity: types.DefaultSeverityFilter(), + AuthenticationMethod: types.OAuthAuthentication, + AutomaticAuthentication: "false", + }, + } + + lspClient := srv.Client + jsonRpcRecorder.ClearCallbacks() + + _, err := lspClient.Call(context.Background(), "initialize", clientParams) + require.NoError(t, err) + _, err = lspClient.Call(context.Background(), "initialized", nil) + require.NoError(t, err) + + assert.Eventuallyf(t, func() bool { + callbacks := jsonRpcRecorder.FindCallbacksByMethod("window/showMessageRequest") + for _, callback := range callbacks { + if strings.Contains(callback.ParamString(), expected) { + return true + } else { + t.Error("wrong callback received", callback.ParamString()) + } + } + return false + }, time.Second*5, time.Millisecond, "callback not received") +} diff --git a/application/server/configuration_test.go b/application/server/configuration_test.go index 47e1ef4e7..12bbe4e36 100644 --- a/application/server/configuration_test.go +++ b/application/server/configuration_test.go @@ -165,13 +165,12 @@ func Test_WorkspaceDidChangeConfiguration_PullNoCapability(t *testing.T) { } func Test_UpdateSettings(t *testing.T) { - di.TestInit(t) - orgUuid, _ := uuid.NewRandom() expectedOrgId := orgUuid.String() t.Run("All settings are updated", func(t *testing.T) { c := testutil.UnitTest(t) + di.TestInit(t) tempDir1 := filepath.Join(t.TempDir(), "tempDir1") tempDir2 := filepath.Join(t.TempDir(), "tempDir2") @@ -196,7 +195,7 @@ func Test_UpdateSettings(t *testing.T) { RuntimeName: "java", RuntimeVersion: "1.8.0_275", ScanningMode: "manual", - AuthenticationMethod: types.OAuthAuthentication, + AuthenticationMethod: types.FakeAuthentication, SnykCodeApi: sampleSettings.SnykCodeApi, EnableSnykOpenBrowserActions: "true", FolderConfigs: []types.FolderConfig{ @@ -226,7 +225,6 @@ func Test_UpdateSettings(t *testing.T) { assert.Equal(t, expectedOrgId, c.Organization()) assert.False(t, c.ManageBinariesAutomatically()) assert.Equal(t, "C:\\Users\\CliPath\\snyk-ls.exe", c.CliSettings().Path()) - assert.Equal(t, "a fancy token", c.Token()) assert.Equal(t, types.DefaultSeverityFilter(), c.FilterSeverity()) assert.Subset(t, []string{"trustedPath1", "trustedPath2"}, c.TrustedFolders()) assert.Equal(t, settings.OsPlatform, c.OsPlatform()) @@ -248,6 +246,8 @@ func Test_UpdateSettings(t *testing.T) { folderConfig2, err := gitconfig.GetOrCreateFolderConfig(tempDir2) assert.NoError(t, err) assert.NotEmpty(t, folderConfig2.BaseBranch) + + assert.Eventually(t, func() bool { return "a fancy token" == c.Token() }, time.Second*5, time.Millisecond) }) t.Run("empty snyk code api is ignored and default is used", func(t *testing.T) { diff --git a/application/server/execute_command_test.go b/application/server/execute_command_test.go index c902a5bbb..0ebd84f79 100644 --- a/application/server/execute_command_test.go +++ b/application/server/execute_command_test.go @@ -138,7 +138,7 @@ func Test_loginCommand_StartsAuthentication(t *testing.T) { if err != nil { t.Fatal(err) } - fakeAuthenticationProvider := di.AuthenticationService().Providers()[0].(*authentication.FakeAuthenticationProvider) + fakeAuthenticationProvider := di.AuthenticationService().Provider().(*authentication.FakeAuthenticationProvider) fakeAuthenticationProvider.IsAuthenticated = false params := lsp.ExecuteCommandParams{Command: types.LoginCommand} diff --git a/application/server/notification.go b/application/server/notification.go index fa26dc3c0..d24d7783b 100644 --- a/application/server/notification.go +++ b/application/server/notification.go @@ -18,6 +18,7 @@ package server import ( "context" + "reflect" "github.com/rs/zerolog" sglsp "github.com/sourcegraph/go-lsp" @@ -30,7 +31,7 @@ import ( ) func notifier(c *config.Config, srv types.Server, method string, params any) { - c.Logger().Debug().Str("method", "notifier").Msgf("Notifying") + c.Logger().Debug().Str("method", "notifier").Str("type", reflect.TypeOf(params).String()).Msgf("Notifying") err := srv.Notify(context.Background(), method, params) logError(c.Logger(), err, "notifier") } diff --git a/application/server/server.go b/application/server/server.go index 7a225ae69..9da86a169 100644 --- a/application/server/server.go +++ b/application/server/server.go @@ -366,7 +366,7 @@ func initializedHandler(srv *jrpc2.Server) handler.Func { err := di.Scanner().Init() if err != nil { logger.Error().Err(err).Msg("Scan initialization error, canceling scan") - return nil, err + return nil, nil } autoScanEnabled := c.IsAutoScanEnabled() diff --git a/application/server/server_test.go b/application/server/server_test.go index f1badd994..4c0f2d5a3 100644 --- a/application/server/server_test.go +++ b/application/server/server_test.go @@ -340,7 +340,7 @@ func Test_TextDocumentCodeLenses_shouldReturnCodeLenses(t *testing.T) { testutil.IntegTest(t) // this needs an authenticated user loc, _ := setupServer(t) didOpenParams, dir := didOpenTextParams(t) - fakeAuthenticationProvider := di.AuthenticationService().Providers()[0].(*authentication.FakeAuthenticationProvider) + fakeAuthenticationProvider := di.AuthenticationService().Provider().(*authentication.FakeAuthenticationProvider) fakeAuthenticationProvider.IsAuthenticated = true clientParams := types.InitializeParams{ @@ -397,7 +397,7 @@ func Test_TextDocumentCodeLenses_dirtyFileShouldFilterCodeLenses(t *testing.T) { testutil.IntegTest(t) // this needs an authenticated user loc, _ := setupServer(t) didOpenParams, dir := didOpenTextParams(t) - fakeAuthenticationProvider := di.AuthenticationService().Providers()[0].(*authentication.FakeAuthenticationProvider) + fakeAuthenticationProvider := di.AuthenticationService().Provider().(*authentication.FakeAuthenticationProvider) fakeAuthenticationProvider.IsAuthenticated = true clientParams := types.InitializeParams{ @@ -662,7 +662,7 @@ func Test_initialize_handlesUntrustedFoldersWhenAuthenticated(t *testing.T) { Token: "token", } - fakeAuthenticationProvider := di.AuthenticationService().Providers()[0].(*authentication.FakeAuthenticationProvider) + fakeAuthenticationProvider := di.AuthenticationService().Provider().(*authentication.FakeAuthenticationProvider) fakeAuthenticationProvider.IsAuthenticated = true params := types.InitializeParams{ @@ -706,7 +706,7 @@ func Test_initialize_doesnotHandleUntrustedFolders(t *testing.T) { func Test_textDocumentDidSaveHandler_shouldAcceptDocumentItemAndPublishDiagnostics(t *testing.T) { loc, jsonRPCRecorder := setupServer(t) config.CurrentConfig().SetSnykCodeEnabled(true) - fakeAuthenticationProvider := di.AuthenticationService().Providers()[0].(*authentication.FakeAuthenticationProvider) + fakeAuthenticationProvider := di.AuthenticationService().Provider().(*authentication.FakeAuthenticationProvider) fakeAuthenticationProvider.IsAuthenticated = true _, err := loc.Client.Call(ctx, "initialize", nil) @@ -758,7 +758,7 @@ func Test_textDocumentDidSaveHandler_shouldTriggerScanForDotSnykFile(t *testing. c.SetAuthenticationMethod(types.FakeAuthentication) di.AuthenticationService().ConfigureProviders(c) - fakeAuthenticationProvider := di.AuthenticationService().Providers()[0] + fakeAuthenticationProvider := di.AuthenticationService().Provider() fakeAuthenticationProvider.(*authentication.FakeAuthenticationProvider).IsAuthenticated = true _, err := loc.Client.Call(ctx, "initialize", nil) @@ -810,7 +810,7 @@ func Test_textDocumentDidOpenHandler_shouldNotPublishIfNotCached(t *testing.T) { func Test_textDocumentDidOpenHandler_shouldPublishIfCached(t *testing.T) { loc, jsonRPCRecorder := setupServer(t) config.CurrentConfig().SetSnykCodeEnabled(true) - fakeAuthenticationProvider := di.AuthenticationService().Providers()[0].(*authentication.FakeAuthenticationProvider) + fakeAuthenticationProvider := di.AuthenticationService().Provider().(*authentication.FakeAuthenticationProvider) fakeAuthenticationProvider.IsAuthenticated = true _, err := loc.Client.Call(ctx, "initialize", nil) if err != nil { @@ -1008,7 +1008,7 @@ func Test_IntegrationHoverResults(t *testing.T) { loc, _ := setupServer(t) c := testutil.IntegTest(t) - fakeAuthenticationProvider := di.AuthenticationService().Providers()[0].(*authentication.FakeAuthenticationProvider) + fakeAuthenticationProvider := di.AuthenticationService().Provider().(*authentication.FakeAuthenticationProvider) fakeAuthenticationProvider.IsAuthenticated = true var cloneTargetDir, err = testutil.SetupCustomTestRepo(t, t.TempDir(), nodejsGoof, "0336589", c.Logger()) diff --git a/application/server/trust_test.go b/application/server/trust_test.go index e41c15816..72b7f8e8a 100644 --- a/application/server/trust_test.go +++ b/application/server/trust_test.go @@ -107,7 +107,7 @@ func Test_handleUntrustedFolders_shouldTriggerTrustRequestAndNotScanAfterNegativ func Test_initializeHandler_shouldCallHandleUntrustedFolders(t *testing.T) { loc, jsonRPCRecorder := setupServer(t) config.CurrentConfig().SetTrustedFolderFeatureEnabled(true) - fakeAuthenticationProvider := di.AuthenticationService().Providers()[0].(*authentication.FakeAuthenticationProvider) + fakeAuthenticationProvider := di.AuthenticationService().Provider().(*authentication.FakeAuthenticationProvider) fakeAuthenticationProvider.IsAuthenticated = true _, err := loc.Client.Call(context.Background(), "initialize", types.InitializeParams{ @@ -150,7 +150,7 @@ func Test_MultipleFoldersInRootDirWithOnlyOneTrusted(t *testing.T) { c.SetTrustedFolderFeatureEnabled(true) c.SetTrustedFolderFeatureEnabled(true) - fakeAuthenticationProvider := di.AuthenticationService().Providers()[0].(*authentication.FakeAuthenticationProvider) + fakeAuthenticationProvider := di.AuthenticationService().Provider().(*authentication.FakeAuthenticationProvider) fakeAuthenticationProvider.IsAuthenticated = true rootDir := t.TempDir() diff --git a/domain/ide/codelens/codelens_test.go b/domain/ide/codelens/codelens_test.go index 04a9b54d7..a3a3c85e2 100644 --- a/domain/ide/codelens/codelens_test.go +++ b/domain/ide/codelens/codelens_test.go @@ -50,7 +50,7 @@ func Test_GetCodeLensForPath(t *testing.T) { // this is using the real progress channel, so we need to listen to it dummyProgressListeners(t) - fakeAuthenticationProvider := di.AuthenticationService().Providers()[0].(*authentication.FakeAuthenticationProvider) + fakeAuthenticationProvider := di.AuthenticationService().Provider().(*authentication.FakeAuthenticationProvider) fakeAuthenticationProvider.IsAuthenticated = true filePath, dir := code.TempWorkdirWithIssues(t) diff --git a/domain/ide/command/command_service_test.go b/domain/ide/command/command_service_test.go index f319ece23..478735b23 100644 --- a/domain/ide/command/command_service_test.go +++ b/domain/ide/command/command_service_test.go @@ -32,7 +32,7 @@ func Test_ExecuteCommand(t *testing.T) { authProvider := &authentication.FakeAuthenticationProvider{ ExpectedAuthURL: "https://auth.url", } - authenticationService := authentication.NewAuthenticationService(c, []authentication.AuthenticationProvider{authProvider}, nil, nil) + authenticationService := authentication.NewAuthenticationService(c, authProvider, nil, nil) service := NewService(authenticationService, nil, nil, nil, nil, nil) cmd := types.CommandData{ CommandId: types.CopyAuthLinkCommand, diff --git a/domain/ide/command/copy_auth_link.go b/domain/ide/command/copy_auth_link.go index b3c92dab0..98483d941 100644 --- a/domain/ide/command/copy_auth_link.go +++ b/domain/ide/command/copy_auth_link.go @@ -39,13 +39,8 @@ func (cmd *copyAuthLinkCommand) Command() types.CommandData { } func (cmd *copyAuthLinkCommand) Execute(ctx context.Context) (any, error) { - var url string - for _, provider := range cmd.authService.Providers() { - url = provider.AuthURL(ctx) - if url != "" { - break - } - } + url := cmd.authService.Provider().AuthURL(ctx) + cmd.logger.Debug().Str("method", "copyAuthLinkCommand.Execute"). Str("url", url). Msgf("copying auth link to clipboard") diff --git a/domain/ide/command/get_active_user.go b/domain/ide/command/get_active_user.go index f6ba82ca5..0cc31d461 100644 --- a/domain/ide/command/get_active_user.go +++ b/domain/ide/command/get_active_user.go @@ -40,10 +40,7 @@ func (cmd *getActiveUser) Command() types.CommandData { func (cmd *getActiveUser) Execute(_ context.Context) (any, error) { logger := config.CurrentConfig().Logger().With().Str("method", "getActiveUser.Execute").Logger() - isAuthenticated, err := cmd.authenticationService.IsAuthenticated() - if err != nil { - logger.Warn().Err(err).Msg("error checking auth status") - } + isAuthenticated := cmd.authenticationService.IsAuthenticated() if !isAuthenticated { logger.Info().Msg("not authenticated, skipping user retrieval") diff --git a/domain/ide/command/get_active_user_test.go b/domain/ide/command/get_active_user_test.go index c61259e9c..8f5c9b3ef 100644 --- a/domain/ide/command/get_active_user_test.go +++ b/domain/ide/command/get_active_user_test.go @@ -63,7 +63,7 @@ func setupCommandWithAuthService(t *testing.T, c *config.Config) *getActiveUser }, authenticationService: authentication.NewAuthenticationService( c, - []authentication.AuthenticationProvider{provider}, + provider, error_reporting.NewTestErrorReporter(), notification.NewNotifier(), ), diff --git a/domain/ide/command/get_feature_flag_status.go b/domain/ide/command/get_feature_flag_status.go index 7719fd147..65e1f5898 100644 --- a/domain/ide/command/get_feature_flag_status.go +++ b/domain/ide/command/get_feature_flag_status.go @@ -39,10 +39,7 @@ func (cmd *featureFlagStatus) Command() types.CommandData { func (cmd *featureFlagStatus) Execute(_ context.Context) (any, error) { logger := config.CurrentConfig().Logger().With().Str("method", "featureFlagStatus.Execute").Logger() - isAuthenticated, err := cmd.authenticationService.IsAuthenticated() - if err != nil { - logger.Warn().Err(err).Msg("error checking auth status") - } + isAuthenticated := cmd.authenticationService.IsAuthenticated() if !isAuthenticated { message := "not authenticated, cannot retrieve feature flags" diff --git a/domain/ide/command/get_feature_flag_status_test.go b/domain/ide/command/get_feature_flag_status_test.go index 79c32ffe0..7829c3379 100644 --- a/domain/ide/command/get_feature_flag_status_test.go +++ b/domain/ide/command/get_feature_flag_status_test.go @@ -63,7 +63,7 @@ func setupFeatureFlagCommand(t *testing.T, c *config.Config, fakeApiClient *snyk command: types.CommandData{Arguments: []interface{}{"snykCodeConsistentIgnores"}}, authenticationService: authentication.NewAuthenticationService( c, - []authentication.AuthenticationProvider{provider}, + provider, error_reporting.NewTestErrorReporter(), notification.NewNotifier(), ), diff --git a/domain/ide/command/logout_test.go b/domain/ide/command/logout_test.go index 9863ae633..2404ca336 100644 --- a/domain/ide/command/logout_test.go +++ b/domain/ide/command/logout_test.go @@ -18,10 +18,11 @@ package command import ( "context" - "github.com/snyk/snyk-ls/domain/snyk/persistence" "path/filepath" "testing" + "github.com/snyk/snyk-ls/domain/snyk/persistence" + "github.com/stretchr/testify/assert" "github.com/snyk/snyk-ls/domain/ide/hover" @@ -43,7 +44,7 @@ func TestLogoutCommand_Execute_ClearsIssues(t *testing.T) { provider.IsAuthenticated = true scanNotifier := snyk.NewMockScanNotifier() scanPersister := persistence.NewNopScanPersister() - authenticationService := authentication.NewAuthenticationService(c, []authentication.AuthenticationProvider{provider}, error_reporting.NewTestErrorReporter(), notifier) + authenticationService := authentication.NewAuthenticationService(c, provider, error_reporting.NewTestErrorReporter(), notifier) cmd := logoutCommand{ command: types.CommandData{CommandId: types.LogoutCommand}, authService: authenticationService, @@ -75,7 +76,7 @@ func TestLogoutCommand_Execute_ClearsIssues(t *testing.T) { _, err := cmd.Execute(ctx) assert.NoError(t, err) - authenticated, err := authenticationService.IsAuthenticated() + authenticated := authenticationService.IsAuthenticated() assert.NoError(t, err) assert.False(t, authenticated) assert.Empty(t, folder.IssuesForFile(t.TempDir())) diff --git a/domain/ide/command/report_analytics.go b/domain/ide/command/report_analytics.go index a661ab902..8c758370e 100644 --- a/domain/ide/command/report_analytics.go +++ b/domain/ide/command/report_analytics.go @@ -39,10 +39,7 @@ func (cmd *reportAnalyticsCommand) Execute(_ context.Context) (any, error) { c := config.CurrentConfig() logger := c.Logger().With().Str("method", "reportAnalyticsCommand.Execute").Logger() - isAuthenticated, err := cmd.authenticationService.IsAuthenticated() - if err != nil { - logger.Warn().Err(err).Msg("error checking auth status") - } + isAuthenticated := cmd.authenticationService.IsAuthenticated() if !isAuthenticated { logger.Info().Msg("not authenticated, skipping analytics reporting") diff --git a/domain/ide/command/report_analytics_test.go b/domain/ide/command/report_analytics_test.go index d0201f402..de3eed342 100644 --- a/domain/ide/command/report_analytics_test.go +++ b/domain/ide/command/report_analytics_test.go @@ -62,7 +62,7 @@ func setupReportAnalyticsCommand(t *testing.T, c *config.Config, testInput strin }, authenticationService: authentication.NewAuthenticationService( c, - []authentication.AuthenticationProvider{provider}, + provider, error_reporting.NewTestErrorReporter(), notification.NewNotifier(), )} diff --git a/domain/ide/command/sast_enabled.go b/domain/ide/command/sast_enabled.go index 59ad81673..c792439e0 100644 --- a/domain/ide/command/sast_enabled.go +++ b/domain/ide/command/sast_enabled.go @@ -38,10 +38,7 @@ func (cmd *sastEnabled) Command() types.CommandData { } func (cmd *sastEnabled) Execute(_ context.Context) (any, error) { - isAuthenticated, err := cmd.authenticationService.IsAuthenticated() - if err != nil { - cmd.logger.Warn().Err(err).Str("method", "sastEnabled.Execute").Msg("error checking auth status") - } + isAuthenticated := cmd.authenticationService.IsAuthenticated() if !isAuthenticated { cmd.logger.Info().Str("method", "sastEnabled.Execute").Msg("not authenticated, skipping sast check") diff --git a/domain/ide/command/sast_enabled_test.go b/domain/ide/command/sast_enabled_test.go index a15892a90..835987f73 100644 --- a/domain/ide/command/sast_enabled_test.go +++ b/domain/ide/command/sast_enabled_test.go @@ -55,7 +55,7 @@ func setupSastEnabledCommand(t *testing.T, c *config.Config, fakeApiClient *snyk apiClient: fakeApiClient, authenticationService: authentication.NewAuthenticationService( c, - []authentication.AuthenticationProvider{provider}, + provider, error_reporting.NewTestErrorReporter(), notification.NewNotifier(), ), diff --git a/domain/snyk/scanner.go b/domain/snyk/scanner.go index bbe86aa79..77b428b8c 100644 --- a/domain/snyk/scanner.go +++ b/domain/snyk/scanner.go @@ -217,10 +217,7 @@ func (sc *DelegatingConcurrentScanner) Scan( c := config.CurrentConfig() logger := c.Logger().With().Str("method", method).Logger() - authenticated, err := sc.authService.IsAuthenticated() - if err != nil { - logger.Err(err).Msg("Error checking authentication status") - } + authenticated := sc.authService.IsAuthenticated() if !authenticated { logger.Info().Msgf("Not authenticated, not scanning.") diff --git a/domain/snyk/scanner_test.go b/domain/snyk/scanner_test.go index 14d30110c..24ef793c6 100644 --- a/domain/snyk/scanner_test.go +++ b/domain/snyk/scanner_test.go @@ -64,7 +64,7 @@ func setupScanner(testProductScanners ...ProductScanner) ( er := error_reporting.NewTestErrorReporter() authenticationProvider := authentication.NewFakeCliAuthenticationProvider(c) authenticationProvider.IsAuthenticated = true - authenticationService := authentication.NewAuthenticationService(c, []authentication.AuthenticationProvider{authenticationProvider}, er, notifier) + authenticationService := authentication.NewAuthenticationService(c, authenticationProvider, er, notifier) scanner = NewDelegatingScanner(c, initialize.NewDelegatingInitializer(), performance.NewInstrumentor(), scanNotifier, apiClient, authenticationService, notifier, testProductScanners...) return scanner, scanNotifier } diff --git a/go.mod b/go.mod index 83ea08b8e..8933b2ced 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,7 @@ require ( github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06 github.com/shirou/gopsutil v3.21.11+incompatible github.com/snyk/code-client-go v1.8.0 - github.com/snyk/go-application-framework v0.0.0-20240627194757-cc0fb551c613 + github.com/snyk/go-application-framework v0.0.0-20240726091718-6ffbf7a2bcd3 github.com/sourcegraph/go-lsp v0.0.0-20240223163137-f80c5dd31dfd github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.9.0 @@ -97,7 +97,7 @@ require ( github.com/rivo/uniseg v0.4.7 // indirect github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect github.com/skeema/knownhosts v1.2.2 // indirect - github.com/snyk/error-catalog-golang-public v0.0.0-20240527112826-2b77438d25f1 // indirect + github.com/snyk/error-catalog-golang-public v0.0.0-20240724122202-c7d3fb545c88 // indirect github.com/snyk/go-httpauth v0.0.0-20231117135515-eb445fea7530 // indirect github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.5.0 // indirect diff --git a/go.sum b/go.sum index 0bcfba469..e921fab95 100644 --- a/go.sum +++ b/go.sum @@ -283,10 +283,10 @@ github.com/skeema/knownhosts v1.2.2 h1:Iug2P4fLmDw9f41PB6thxUkNUkJzB5i+1/exaj40L github.com/skeema/knownhosts v1.2.2/go.mod h1:xYbVRSPxqBZFrdmDyMmsOs+uX1UZC3nTN3ThzgDxUwo= github.com/snyk/code-client-go v1.8.0 h1:6H883KAn7ybiSIxhvL2QR9yEyHgAwA2+9WVHMDNEKa8= github.com/snyk/code-client-go v1.8.0/go.mod h1:orU911flV1kJQOlxxx0InUQkAfpBrcERsb2olfnlI8s= -github.com/snyk/error-catalog-golang-public v0.0.0-20240527112826-2b77438d25f1 h1:49X/bTeiWdi+DrkTbTSw5BePpQ6LiucIt++/Z+MB95U= -github.com/snyk/error-catalog-golang-public v0.0.0-20240527112826-2b77438d25f1/go.mod h1:Ytttq7Pw4vOCu9NtRQaOeDU2dhBYUyNBe6kX4+nIIQ4= -github.com/snyk/go-application-framework v0.0.0-20240627194757-cc0fb551c613 h1:igHAJ85dfn9cR1onRbpe4a9Mex1/Oo4PUxJfNPaWle0= -github.com/snyk/go-application-framework v0.0.0-20240627194757-cc0fb551c613/go.mod h1:gz3PN/OfEBbtB4VxbnV33XipM8MjBcVszPJeOhCu2DU= +github.com/snyk/error-catalog-golang-public v0.0.0-20240724122202-c7d3fb545c88 h1:ZiFV5IDPI2p1wx1D3B2iSC/8nxTGvlKuyUekZlm1ptU= +github.com/snyk/error-catalog-golang-public v0.0.0-20240724122202-c7d3fb545c88/go.mod h1:Ytttq7Pw4vOCu9NtRQaOeDU2dhBYUyNBe6kX4+nIIQ4= +github.com/snyk/go-application-framework v0.0.0-20240726091718-6ffbf7a2bcd3 h1:gisnOktdqZYJLPIpKzNhm9UQIsuSQ0Ojded7XC5zI1E= +github.com/snyk/go-application-framework v0.0.0-20240726091718-6ffbf7a2bcd3/go.mod h1:3DhXDHbBbGWRRZESbYVZyunyIDaet9SOtuOZCK7AC3g= github.com/snyk/go-httpauth v0.0.0-20231117135515-eb445fea7530 h1:s9PHNkL6ueYRiAKNfd8OVxlUOqU3qY0VDbgCD1f6WQY= github.com/snyk/go-httpauth v0.0.0-20231117135515-eb445fea7530/go.mod h1:88KbbvGYlmLgee4OcQ19yr0bNpXpOr2kciOthaSzCAg= github.com/sourcegraph/go-lsp v0.0.0-20240223163137-f80c5dd31dfd h1:Dq5WSzWsP1TbVi10zPWBI5LKEBDg4Y1OhWEph1wr5WQ= diff --git a/infrastructure/authentication/auth_configuration.go b/infrastructure/authentication/auth_configuration.go index cf7c62fc2..0d36981f8 100644 --- a/infrastructure/authentication/auth_configuration.go +++ b/infrastructure/authentication/auth_configuration.go @@ -32,17 +32,20 @@ import ( ) // Token authentication configures token only authentication -func Token(c *config.Config, errorReporter error_reporting.ErrorReporter) []AuthenticationProvider { - c.Engine().GetConfiguration().Set(configuration.FF_OAUTH_AUTH_FLOW_ENABLED, false) - return []AuthenticationProvider{NewCliAuthenticationProvider(c, errorReporter)} +func Token(c *config.Config, errorReporter error_reporting.ErrorReporter) AuthenticationProvider { + conf := c.Engine().GetConfiguration() + conf.Set(configuration.FF_OAUTH_AUTH_FLOW_ENABLED, false) + conf.Unset(configuration.AUTHENTICATION_BEARER_TOKEN) + conf.Unset(auth.CONFIG_KEY_OAUTH_TOKEN) + return NewCliAuthenticationProvider(c, errorReporter) } -// Default authentication configures two authenticators, the first OAuth2, -// the second, as fallback, CLI Token auth +// Default authentication configures an OAuth2 authenticator, // the auth service parameter is needed, as the oauth2 provider needs a callback function -func Default(c *config.Config, errorReporter error_reporting.ErrorReporter, authenticationService AuthenticationService) []AuthenticationProvider { - authProviders := []AuthenticationProvider{} - +func Default(c *config.Config, errorReporter error_reporting.ErrorReporter, authenticationService AuthenticationService) AuthenticationProvider { + conf := c.Engine().GetConfiguration() + conf.Set(configuration.FF_OAUTH_AUTH_FLOW_ENABLED, true) + conf.Unset(configuration.AUTHENTICATION_TOKEN) credentialsUpdateCallback := func(_ string, value any) { newToken, ok := value.(string) if !ok { @@ -54,22 +57,29 @@ func Default(c *config.Config, errorReporter error_reporting.ErrorReporter, auth } openBrowserFunc := func(url string) { - for _, provider := range authenticationService.Providers() { - provider.SetAuthURL(url) - } + authenticationService.Provider().SetAuthURL(url) types.DefaultOpenBrowserFunc(url) } - // add both OAuth2 and CLI, with preference to OAuth2 - authProviders = append(authProviders, - NewOAuthProvider( - c, - auth.RefreshToken, - credentialsUpdateCallback, - openBrowserFunc, - ), + refresherFunc := func(ctx context.Context, oauthConfig *oauth2.Config, token *oauth2.Token) (*oauth2.Token, error) { + logger := c.Logger().With().Str("method", "oauth.refresherFunc").Logger() + logger.Info().Msg("refreshing oauth2 token") + refreshToken, err := auth.RefreshToken(ctx, oauthConfig, token) + if err != nil { + logger.Err(err).Msg("failed to refresh oauth2 token") + // call authservice to handle notifications and such + // we don't need the returned values, as we know it will either return false, nil or false, err + _ = authenticationService.IsAuthenticated() + } + return refreshToken, err + } + authProvider := NewOAuthProvider( + c, + refresherFunc, + credentialsUpdateCallback, + openBrowserFunc, ) - return authProviders + return authProvider } func NewOAuthProvider( diff --git a/infrastructure/authentication/auth_service.go b/infrastructure/authentication/auth_service.go index 1d25bbd2b..27e4ac982 100644 --- a/infrastructure/authentication/auth_service.go +++ b/infrastructure/authentication/auth_service.go @@ -26,7 +26,7 @@ type AuthenticationService interface { // Authenticate attempts to authenticate the user, and sends a notification to the client when successful Authenticate(ctx context.Context) (string, error) - Providers() []AuthenticationProvider + Provider() AuthenticationProvider // UpdateCredentials stores the token in the configuration, and sends a $/snyk.hasAuthenticated notification to the // client if sendNotification is true @@ -35,10 +35,10 @@ type AuthenticationService interface { Logout(ctx context.Context) // IsAuthenticated returns true if the token is verified - IsAuthenticated() (bool, error) + IsAuthenticated() bool // AddProvider sets the authentication provider - AddProvider(provider AuthenticationProvider) + SetProvider(provider AuthenticationProvider) // ConfigureProviders updates the providers based on the stored configuration ConfigureProviders(c *config.Config) diff --git a/infrastructure/authentication/auth_service_impl.go b/infrastructure/authentication/auth_service_impl.go index 1b3d6a351..534eebca3 100644 --- a/infrastructure/authentication/auth_service_impl.go +++ b/infrastructure/authentication/auth_service_impl.go @@ -31,8 +31,11 @@ import ( "github.com/snyk/snyk-ls/internal/types" ) +const ExpirationMsg = "Your authentication failed due to token expiration. Please re-authenticate to continue using Snyk." +const InvalidCredsMessage = "Your authentication credentials cannot be validated. Automatically clearing credentials. You need to re-authenticate to use Snyk." + type AuthenticationServiceImpl struct { - providers []AuthenticationProvider + provider AuthenticationProvider errorReporter error_reporting.ErrorReporter notifier noti.Notifier c *config.Config @@ -41,10 +44,10 @@ type AuthenticationServiceImpl struct { m sync.Mutex } -func NewAuthenticationService(c *config.Config, authProviders []AuthenticationProvider, errorReporter error_reporting.ErrorReporter, notifier noti.Notifier) AuthenticationService { +func NewAuthenticationService(c *config.Config, authProviders AuthenticationProvider, errorReporter error_reporting.ErrorReporter, notifier noti.Notifier) AuthenticationService { cache := imcache.New[string, bool]() return &AuthenticationServiceImpl{ - providers: authProviders, + provider: authProviders, errorReporter: errorReporter, notifier: notifier, c: c, @@ -52,20 +55,17 @@ func NewAuthenticationService(c *config.Config, authProviders []AuthenticationPr } } -func (a *AuthenticationServiceImpl) Providers() []AuthenticationProvider { - return a.providers +func (a *AuthenticationServiceImpl) Provider() AuthenticationProvider { + return a.provider } func (a *AuthenticationServiceImpl) Authenticate(ctx context.Context) (token string, err error) { - for _, provider := range a.providers { - token, err = provider.Authenticate(ctx) - if token == "" || err != nil { - a.c.Logger().Warn().Err(err).Msgf("Failed to authenticate using auth provider %v", reflect.TypeOf(provider)) - continue - } - a.UpdateCredentials(token, true) + token, err = a.provider.Authenticate(ctx) + if token == "" || err != nil { + a.c.Logger().Warn().Err(err).Msgf("Failed to authenticate using auth provider %v", reflect.TypeOf(a.provider)) return token, err } + a.UpdateCredentials(token, true) return token, err } @@ -76,12 +76,15 @@ func (a *AuthenticationServiceImpl) UpdateCredentials(newToken string, sendNotif return } + // unlock when leaving if we locked ourselves + if a.m.TryLock() { + defer a.m.Unlock() + } + // remove old token from cache, but don't add new token, as we want the entry only when // checks are performed - e.g. in IsAuthenticated or Authenticate which call the API to check for real - a.m.Lock() a.authCache.Remove(oldToken) c.SetToken(newToken) - a.m.Unlock() if sendNotification { a.notifier.Send(types.AuthenticationParams{Token: newToken}) @@ -89,98 +92,130 @@ func (a *AuthenticationServiceImpl) UpdateCredentials(newToken string, sendNotif } func (a *AuthenticationServiceImpl) Logout(ctx context.Context) { - for _, provider := range a.providers { - err := provider.ClearAuthentication(ctx) - if err != nil { - a.c.Logger().Warn().Err(err).Str("method", "Logout").Msg("Failed to log out.") - a.errorReporter.CaptureError(err) - } + if a.m.TryLock() { + defer a.m.Unlock() + } + err := a.provider.ClearAuthentication(ctx) + if err != nil { + a.c.Logger().Warn().Err(err).Str("method", "Logout").Msg("Failed to log out.") + a.errorReporter.CaptureError(err) } a.UpdateCredentials("", true) } // IsAuthenticated returns true if the token is verified -// If the token is set, but not valid IsAuthenticated returns false and the reported error -func (a *AuthenticationServiceImpl) IsAuthenticated() (bool, error) { +// If the token is set, but not valid IsAuthenticated returns false +func (a *AuthenticationServiceImpl) IsAuthenticated() bool { logger := a.c.Logger().With().Str("method", "AuthenticationService.IsAuthenticated").Logger() - a.m.Lock() + if a.m.TryLock() { + defer a.m.Unlock() + } _, found := a.authCache.Get(a.c.Token()) if found { a.c.Logger().Debug().Msg("IsAuthenticated (found in cache)") - a.m.Unlock() - return true, nil + return true } noToken := !a.c.NonEmptyToken() if noToken { logger.Info().Str("method", "IsAuthenticated").Msg("no credentials found") - a.m.Unlock() - return false, nil + return false } var user string var err error - for _, provider := range a.providers { - providerType := reflect.TypeOf(provider).String() - - user, err = provider.GetCheckAuthenticationFunction()() - if user == "" || err != nil { - a.c.Logger(). - Err(err). - Str("method", "AuthenticationService.IsAuthenticated"). - Str("authProvider", providerType). - Msg("Failed to get active user") - } else { - break - } - } - if user == "" { - a.m.Unlock() + user, err = a.provider.GetCheckAuthenticationFunction()() + if user == "" || err != nil { + a.c.Logger(). + Err(err). + Str("method", "AuthenticationService.IsAuthenticated"). + Msg("Failed to get active user") + + invalidToken, isLegacyTokenErr := a.c.TokenAsOAuthToken() + + // we always log out logger.Debug().Msg("logging out") a.Logout(context.Background()) - a.HandleInvalidCredentials() - return false, err + + // determine the right error message + if isLegacyTokenErr == nil { + // it is an oauth token + if invalidToken.Expiry.Before(time.Now()) { + a.handleFailedRefresh() + } else { + // access token not expired, but creds still not work + a.HandleInvalidCredentials() + } + } else { + // legacy token does not work + a.HandleInvalidCredentials() + } + return false } // we cache the API auth ok for up to 1 minutes after last access. Afterwards, a new check is performed. a.authCache.Set(a.c.Token(), true, imcache.WithSlidingExpiration(time.Minute)) a.c.Logger().Debug().Msg("IsAuthenticated: " + user + ", adding to cache.") - a.m.Unlock() - return true, nil + return true } -func (a *AuthenticationServiceImpl) AddProvider(provider AuthenticationProvider) { - a.providers = append(a.providers, provider) +func (a *AuthenticationServiceImpl) handleFailedRefresh() { + // access token expired and refresh failed + a.sendAuthenticationRequest(ExpirationMsg, "Re-authenticate") } -func (a *AuthenticationServiceImpl) setProviders(providers []AuthenticationProvider) { - a.providers = providers +func (a *AuthenticationServiceImpl) SetProvider(provider AuthenticationProvider) { + a.provider = provider } func (a *AuthenticationServiceImpl) ConfigureProviders(c *config.Config) { - var as []AuthenticationProvider + if a.m.TryLock() { + defer a.m.Unlock() + } + authProviderChange := false + var p AuthenticationProvider switch c.AuthenticationMethod() { - case types.FakeAuthentication: - a.setProviders([]AuthenticationProvider{NewFakeCliAuthenticationProvider(c)}) + default: + // if err != nil, previous token was legacy. So we had a provider change + _, err := c.TokenAsOAuthToken() + if err != nil && c.NonEmptyToken() { + authProviderChange = true + } + + p = Default(c, a.errorReporter, a) + a.SetProvider(p) case types.TokenAuthentication: - as = Token(c, a.errorReporter) - a.setProviders(as) + // if err == nil, previous token was oauth2. So we had a provider change + _, err := c.TokenAsOAuthToken() + if err == nil && c.NonEmptyToken() { + authProviderChange = true + } + + p = Token(c, a.errorReporter) + a.SetProvider(p) + case types.FakeAuthentication: + a.SetProvider(NewFakeCliAuthenticationProvider(c)) case "": // don't do anything - default: - as = Default(c, a.errorReporter, a) - a.setProviders(as) + } + + if authProviderChange { + a.Logout(context.Background()) + a.sendAuthenticationRequest("Your authentication method has changed. Please re-authenticate to continue using Snyk.", "Re-authenticate") } } func (a *AuthenticationServiceImpl) HandleInvalidCredentials() { - msg := "Your authentication credentials cannot be validated. Automatically clearing credentials. You need to re-authenticate to use Snyk." + msg := InvalidCredsMessage + a.sendAuthenticationRequest(msg, "Authenticate") +} +func (a *AuthenticationServiceImpl) sendAuthenticationRequest(msg string, actionName string) { actions := data_structure.OrderedMap[types.MessageAction, types.CommandData]{} - actions.Add("Authenticate", types.CommandData{ - Title: "Authenticate", + actions.Add(types.MessageAction(actionName), types.CommandData{ + Title: actionName, CommandId: types.LoginCommand, }) actions.Add("Cancel", types.CommandData{}) diff --git a/infrastructure/authentication/auth_service_impl_test.go b/infrastructure/authentication/auth_service_impl_test.go index 87c16c1fe..920acaa1f 100644 --- a/infrastructure/authentication/auth_service_impl_test.go +++ b/infrastructure/authentication/auth_service_impl_test.go @@ -67,25 +67,21 @@ func Test_IsAuthenticated(t *testing.T) { c := testutil.UnitTest(t) provider := FakeAuthenticationProvider{IsAuthenticated: true, C: c} - providers := []AuthenticationProvider{&provider} - service := NewAuthenticationService(c, providers, error_reporting.NewTestErrorReporter(), notification.NewNotifier()) + service := NewAuthenticationService(c, &provider, error_reporting.NewTestErrorReporter(), notification.NewNotifier()) - isAuthenticated, err := service.IsAuthenticated() + isAuthenticated := service.IsAuthenticated() assert.True(t, isAuthenticated) - assert.NoError(t, err) }) t.Run("User is not authenticated", func(t *testing.T) { c := testutil.UnitTest(t) provider := FakeAuthenticationProvider{IsAuthenticated: false, C: c} - providers := []AuthenticationProvider{&provider} - service := NewAuthenticationService(c, providers, error_reporting.NewTestErrorReporter(), notification.NewNotifier()) + service := NewAuthenticationService(c, &provider, error_reporting.NewTestErrorReporter(), notification.NewNotifier()) - isAuthenticated, err := service.IsAuthenticated() + isAuthenticated := service.IsAuthenticated() assert.False(t, isAuthenticated) - assert.Equal(t, err.Error(), "Authentication failed. Please update your token.") }) } @@ -93,7 +89,7 @@ func Test_Logout(t *testing.T) { c := testutil.IntegTest(t) provider := FakeAuthenticationProvider{IsAuthenticated: true} notifier := notification.NewNotifier() - service := NewAuthenticationService(c, []AuthenticationProvider{&provider}, error_reporting.NewTestErrorReporter(), notifier) + service := NewAuthenticationService(c, &provider, error_reporting.NewTestErrorReporter(), notifier) // act service.Logout(context.Background()) @@ -121,9 +117,8 @@ func TestHandleInvalidCredentials(t *testing.T) { notifier := notification.NewNotifier() provider := NewFakeCliAuthenticationProvider(c) provider.IsAuthenticated = false - providers := []AuthenticationProvider{provider} c.SetToken("invalidCreds") - cut := NewAuthenticationService(c, providers, errorReporter, notifier).(*AuthenticationServiceImpl) + cut := NewAuthenticationService(c, provider, errorReporter, notifier).(*AuthenticationServiceImpl) messageRequestReceived := false callback := func(params any) { switch p := params.(type) { diff --git a/infrastructure/authentication/initializer.go b/infrastructure/authentication/initializer.go index d443ec9af..b88f536fc 100644 --- a/infrastructure/authentication/initializer.go +++ b/infrastructure/authentication/initializer.go @@ -51,20 +51,14 @@ func (i *Initializer) Init() error { const errorMessage = "Auth Initializer failed to authenticate." c := config.CurrentConfig() if c.NonEmptyToken() { - authenticated, err := i.authenticationService.IsAuthenticated() + authenticated := i.authenticationService.IsAuthenticated() if authenticated { c.Logger().Info().Str("method", "auth.initializer.init").Msg("Skipping authentication - user is already authenticated") return nil } - return err } - // token is empty from here on if !c.AutomaticAuthentication() { - err := i.handleNotAuthenticatedAndManualAuthActive() - if err != nil { - return err - } return nil } @@ -95,12 +89,3 @@ func (i *Initializer) authenticate(authenticationService AuthenticationService, } return nil } - -func (i *Initializer) handleNotAuthenticatedAndManualAuthActive() error { - msg := "Skipping scan - user is not authenticated and automatic authentication is disabled" - i.c.Logger().Info().Msg(msg) - - // If the user is not authenticated and auto-authentication is disabled, return an error to indicate the user - // could not be authenticated and the scan cannot start - return errors.New(msg) -} diff --git a/infrastructure/authentication/initializer_test.go b/infrastructure/authentication/initializer_test.go index 8f3e50174..6c3d13d72 100644 --- a/infrastructure/authentication/initializer_test.go +++ b/infrastructure/authentication/initializer_test.go @@ -19,7 +19,7 @@ package authentication import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/snyk/snyk-ls/application/config" "github.com/snyk/snyk-ls/internal/notification" @@ -27,33 +27,27 @@ import ( ) func Test_autoAuthenticationDisabled_doesNotAuthenticate(t *testing.T) { - t.Run("Does not authenticate when auto-auth is disabled", getAutoAuthenticationTest(false, true)) - t.Run("Authenticates when auto-auth is enabled", getAutoAuthenticationTest(true, false)) + t.Run("Does not authenticate when auto-auth is disabled", getAutoAuthenticationTest(false)) + t.Run("Authenticates when auto-auth is enabled", getAutoAuthenticationTest(true)) } -func getAutoAuthenticationTest(autoAuthentication bool, expectError bool) func(t *testing.T) { +func getAutoAuthenticationTest(autoAuthentication bool) func(t *testing.T) { return func(t *testing.T) { // Arrange + t.Helper() c := config.CurrentConfig() c.SetToken("") c.SetAutomaticAuthentication(autoAuthentication) provider := NewFakeCliAuthenticationProvider(c) - providers := []AuthenticationProvider{provider} notifier := notification.NewNotifier() - authenticator := NewAuthenticationService(c, providers, errorreporting.NewTestErrorReporter(), notifier) + authenticator := NewAuthenticationService(c, provider, errorreporting.NewTestErrorReporter(), notifier) initializer := NewInitializer(c, authenticator, errorreporting.NewTestErrorReporter(), notifier) // Act err := initializer.Init() + require.NoError(t, err) - // Assert - //assert.Equal(t, expectError, err != nil) - if expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - assert.Equal(t, autoAuthentication, provider.IsAuthenticated) + require.True(t, provider.IsAuthenticated == autoAuthentication) } } diff --git a/infrastructure/cli/environment.go b/infrastructure/cli/environment.go index 7510976ca..400fc3803 100644 --- a/infrastructure/cli/environment.go +++ b/infrastructure/cli/environment.go @@ -23,19 +23,21 @@ import ( "github.com/snyk/go-application-framework/pkg/auth" "github.com/snyk/go-application-framework/pkg/configuration" + "github.com/snyk/snyk-ls/internal/types" "github.com/snyk/snyk-ls/application/config" ) -const ( - ApiEnvVar = "SNYK_API" - TokenEnvVar = "SNYK_TOKEN" - DisableAnalyticsEnvVar = "SNYK_CFG_DISABLE_ANALYTICS" +var ( + ApiEnvVar = strings.ToUpper(configuration.API_URL) + TokenEnvVar = strings.ToUpper(configuration.AUTHENTICATION_TOKEN) + DisableAnalyticsEnvVar = strings.ToUpper(configuration.ANALYTICS_DISABLED) + SnykOauthTokenEnvVar = strings.ToUpper(configuration.AUTHENTICATION_BEARER_TOKEN) + OAuthEnabledEnvVar = strings.ToUpper(configuration.FF_OAUTH_AUTH_FLOW_ENABLED) IntegrationNameEnvVarKey = "SNYK_INTEGRATION_NAME" IntegrationVersionEnvVarKey = "SNYK_INTEGRATION_VERSION" IntegrationEnvironmentEnvVarKey = "SNYK_INTEGRATION_ENVIRONMENT" IntegrationEnvironmentVersionEnvVar = "SNYK_INTEGRATION_ENVIRONMENT_VERSION" - SnykOauthTokenEnvVar = "SNYK_OAUTH_TOKEN" ) // AppendCliEnvironmentVariables Returns the input array with additional variables used in the CLI run in the form of "key=value". @@ -49,12 +51,12 @@ func AppendCliEnvironmentVariables(currentEnv []string, appendToken bool) []stri // remove any existing env vars that we are going to set valuesToRemove := map[string]bool{ - ApiEnvVar: true, - TokenEnvVar: true, - SnykOauthTokenEnvVar: true, - DisableAnalyticsEnvVar: true, - auth.CONFIG_KEY_OAUTH_TOKEN: true, - configuration.FF_OAUTH_AUTH_FLOW_ENABLED: true, + ApiEnvVar: true, + TokenEnvVar: true, + SnykOauthTokenEnvVar: true, + DisableAnalyticsEnvVar: true, + auth.CONFIG_KEY_OAUTH_TOKEN: true, + OAuthEnabledEnvVar: true, } for _, s := range currentEnv { @@ -66,15 +68,18 @@ func AppendCliEnvironmentVariables(currentEnv []string, appendToken bool) []stri } if appendToken && currentConfig.NonEmptyToken() { - // default to authentication, if not there, try to set the api key - oAuthToken, err := currentConfig.TokenAsOAuthToken() - if err == nil && len(oAuthToken.AccessToken) > 0 { + if currentConfig.AuthenticationMethod() == types.OAuthAuthentication { logger.Debug().Msg("using oauth2 authentication") + oAuthToken, err := currentConfig.TokenAsOAuthToken() + if err != nil { + logger.Err(err).Msg("trying to add OAuth2 creds to CLI call and the token cannot be unmarshalled. This should never happen.") + } updatedEnv = append(updatedEnv, SnykOauthTokenEnvVar+"="+oAuthToken.AccessToken) + updatedEnv = append(updatedEnv, OAuthEnabledEnvVar+"=1") } else { - // fallback to token if existent logger.Debug().Msg("falling back to API key authentication") updatedEnv = append(updatedEnv, TokenEnvVar+"="+currentConfig.Token()) + updatedEnv = append(updatedEnv, OAuthEnabledEnvVar+"=0") } } diff --git a/infrastructure/cli/environment_test.go b/infrastructure/cli/environment_test.go index c6e89cbaa..a028148ac 100644 --- a/infrastructure/cli/environment_test.go +++ b/infrastructure/cli/environment_test.go @@ -50,6 +50,7 @@ func TestAddConfigValuesToEnv(t *testing.T) { testutil.UnitTest(t) c := config.CurrentConfig() + c.SetAuthenticationMethod(types.OAuthAuthentication) c.SetOrganization("testOrg") c.UpdateApiEndpoints("https://api.eu.snyk.io") c.SetIntegrationName(expectedIntegrationName) @@ -82,6 +83,7 @@ func TestAddConfigValuesToEnv(t *testing.T) { testutil.UnitTest(t) c := config.CurrentConfig() c.SetToken("{\"access_token\": \"testToken\"}") + c.SetAuthenticationMethod(types.OAuthAuthentication) tokenVar := TokenEnvVar + "={asdf}" inputEnv := []string{tokenVar} @@ -108,15 +110,18 @@ func TestAddConfigValuesToEnv(t *testing.T) { testutil.UnitTest(t) c := config.CurrentConfig() c.SetToken("testToken") + c.SetAuthenticationMethod(types.TokenAuthentication) updatedEnv := AppendCliEnvironmentVariables([]string{}, true) assert.Contains(t, updatedEnv, "SNYK_TOKEN="+c.Token()) + assert.Contains(t, updatedEnv, OAuthEnabledEnvVar+"=0") }) t.Run("Adds OAuth Token to env", func(t *testing.T) { testutil.UnitTest(t) c := config.CurrentConfig() + c.SetAuthenticationMethod(types.OAuthAuthentication) c.SetToken("{\"access_token\": \"testToken\"}") updatedEnv := AppendCliEnvironmentVariables([]string{}, true) @@ -124,5 +129,6 @@ func TestAddConfigValuesToEnv(t *testing.T) { token, err := c.TokenAsOAuthToken() assert.NoError(t, err) assert.Contains(t, updatedEnv, SnykOauthTokenEnvVar+"="+token.AccessToken) + assert.Contains(t, updatedEnv, OAuthEnabledEnvVar+"=1") }) }