Skip to content

Commit

Permalink
feat: add authentication error messages (oauth expiry, invalid creds)…
Browse files Browse the repository at this point in the history
… [IDE-459] (#607)
  • Loading branch information
bastiandoetsch authored Jul 29, 2024
1 parent 6d70ea0 commit ce0f92f
Show file tree
Hide file tree
Showing 32 changed files with 320 additions and 202 deletions.
5 changes: 3 additions & 2 deletions application/di/test_init.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
package di

import (
"github.com/snyk/snyk-ls/domain/snyk/persistence"
"testing"

"github.com/snyk/snyk-ls/domain/snyk/persistence"

"github.com/golang/mock/gomock"

"github.com/snyk/snyk-ls/application/codeaction"
Expand Down Expand Up @@ -60,7 +61,7 @@ func TestInit(t *testing.T) {
installer = install.NewFakeInstaller()
authProvider := authentication.NewFakeCliAuthenticationProvider(c)
snykApiClient = &snyk_api.FakeApiClient{CodeEnabled: true}
authenticationService = authentication.NewAuthenticationService(c, []authentication.AuthenticationProvider{authProvider}, errorReporter, notifier)
authenticationService = authentication.NewAuthenticationService(c, authProvider, errorReporter, notifier)
snykCli := cli.NewExecutor(c, errorReporter, notifier)
cliInitializer = cli.NewInitializer(errorReporter, installer, notifier, snykCli)
authInitializer := authentication.NewInitializer(c, authenticationService, errorReporter, notifier)
Expand Down
105 changes: 105 additions & 0 deletions application/server/authentication_smoke_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* © 2024 Snyk Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package server

import (
"context"
"encoding/json"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"

"github.com/snyk/snyk-ls/application/di"
"github.com/snyk/snyk-ls/infrastructure/authentication"
"github.com/snyk/snyk-ls/internal/testutil"
"github.com/snyk/snyk-ls/internal/types"
"github.com/snyk/snyk-ls/internal/uri"
)

func Test_InvalidExpiredCredentialsSendMessageRequest(t *testing.T) {
// how to process the expected callback
token := getDummyOAuth2Token(time.Now().Add(-time.Hour))
tokenBytes, marshallingErr := json.Marshal(token)
require.NoError(t, marshallingErr)

checkInvalidCredentialsMessageRequest(t, authentication.ExpirationMsg, string(tokenBytes))
}

func Test_InvalidCredentialsNotExpiredSendMessageRequest(t *testing.T) {
token := getDummyOAuth2Token(time.Now().Add(+time.Hour))
tokenBytes, marshallingErr := json.Marshal(token)
require.NoError(t, marshallingErr)

checkInvalidCredentialsMessageRequest(t, authentication.InvalidCredsMessage, string(tokenBytes))
}

func getDummyOAuth2Token(expiry time.Time) oauth2.Token {
token := oauth2.Token{
AccessToken: "a",
TokenType: "bearer",
RefreshToken: "c",
Expiry: expiry,
}
return token
}

func checkInvalidCredentialsMessageRequest(t *testing.T, expected string, tokenString string) {
t.Helper()
srv, jsonRpcRecorder := setupServer(t)

c := testutil.SmokeTest(t, false)
c.SetSnykIacEnabled(false)
c.SetSnykOssEnabled(true)
// we have to reset the token, as smoketest automatically grab it from env
c.SetToken("")
di.Init()

clientParams := types.InitializeParams{
WorkspaceFolders: []types.WorkspaceFolder{{Uri: uri.PathToUri(t.TempDir()), Name: t.Name()}},
InitializationOptions: types.Settings{
Token: tokenString,
EnableTrustedFoldersFeature: "false",
FilterSeverity: types.DefaultSeverityFilter(),
AuthenticationMethod: types.OAuthAuthentication,
AutomaticAuthentication: "false",
},
}

lspClient := srv.Client
jsonRpcRecorder.ClearCallbacks()

_, err := lspClient.Call(context.Background(), "initialize", clientParams)
require.NoError(t, err)
_, err = lspClient.Call(context.Background(), "initialized", nil)
require.NoError(t, err)

assert.Eventuallyf(t, func() bool {
callbacks := jsonRpcRecorder.FindCallbacksByMethod("window/showMessageRequest")
for _, callback := range callbacks {
if strings.Contains(callback.ParamString(), expected) {
return true
} else {
t.Error("wrong callback received", callback.ParamString())
}
}
return false
}, time.Second*5, time.Millisecond, "callback not received")
}
8 changes: 4 additions & 4 deletions application/server/configuration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,12 @@ func Test_WorkspaceDidChangeConfiguration_PullNoCapability(t *testing.T) {
}

func Test_UpdateSettings(t *testing.T) {
di.TestInit(t)

orgUuid, _ := uuid.NewRandom()
expectedOrgId := orgUuid.String()

t.Run("All settings are updated", func(t *testing.T) {
c := testutil.UnitTest(t)
di.TestInit(t)

tempDir1 := filepath.Join(t.TempDir(), "tempDir1")
tempDir2 := filepath.Join(t.TempDir(), "tempDir2")
Expand All @@ -196,7 +195,7 @@ func Test_UpdateSettings(t *testing.T) {
RuntimeName: "java",
RuntimeVersion: "1.8.0_275",
ScanningMode: "manual",
AuthenticationMethod: types.OAuthAuthentication,
AuthenticationMethod: types.FakeAuthentication,
SnykCodeApi: sampleSettings.SnykCodeApi,
EnableSnykOpenBrowserActions: "true",
FolderConfigs: []types.FolderConfig{
Expand Down Expand Up @@ -226,7 +225,6 @@ func Test_UpdateSettings(t *testing.T) {
assert.Equal(t, expectedOrgId, c.Organization())
assert.False(t, c.ManageBinariesAutomatically())
assert.Equal(t, "C:\\Users\\CliPath\\snyk-ls.exe", c.CliSettings().Path())
assert.Equal(t, "a fancy token", c.Token())
assert.Equal(t, types.DefaultSeverityFilter(), c.FilterSeverity())
assert.Subset(t, []string{"trustedPath1", "trustedPath2"}, c.TrustedFolders())
assert.Equal(t, settings.OsPlatform, c.OsPlatform())
Expand All @@ -248,6 +246,8 @@ func Test_UpdateSettings(t *testing.T) {
folderConfig2, err := gitconfig.GetOrCreateFolderConfig(tempDir2)
assert.NoError(t, err)
assert.NotEmpty(t, folderConfig2.BaseBranch)

assert.Eventually(t, func() bool { return "a fancy token" == c.Token() }, time.Second*5, time.Millisecond)
})

t.Run("empty snyk code api is ignored and default is used", func(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion application/server/execute_command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func Test_loginCommand_StartsAuthentication(t *testing.T) {
if err != nil {
t.Fatal(err)
}
fakeAuthenticationProvider := di.AuthenticationService().Providers()[0].(*authentication.FakeAuthenticationProvider)
fakeAuthenticationProvider := di.AuthenticationService().Provider().(*authentication.FakeAuthenticationProvider)
fakeAuthenticationProvider.IsAuthenticated = false
params := lsp.ExecuteCommandParams{Command: types.LoginCommand}

Expand Down
3 changes: 2 additions & 1 deletion application/server/notification.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package server

import (
"context"
"reflect"

"github.com/rs/zerolog"
sglsp "github.com/sourcegraph/go-lsp"
Expand All @@ -30,7 +31,7 @@ import (
)

func notifier(c *config.Config, srv types.Server, method string, params any) {
c.Logger().Debug().Str("method", "notifier").Msgf("Notifying")
c.Logger().Debug().Str("method", "notifier").Str("type", reflect.TypeOf(params).String()).Msgf("Notifying")
err := srv.Notify(context.Background(), method, params)
logError(c.Logger(), err, "notifier")
}
Expand Down
2 changes: 1 addition & 1 deletion application/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ func initializedHandler(srv *jrpc2.Server) handler.Func {
err := di.Scanner().Init()
if err != nil {
logger.Error().Err(err).Msg("Scan initialization error, canceling scan")
return nil, err
return nil, nil
}

autoScanEnabled := c.IsAutoScanEnabled()
Expand Down
14 changes: 7 additions & 7 deletions application/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ func Test_TextDocumentCodeLenses_shouldReturnCodeLenses(t *testing.T) {
testutil.IntegTest(t) // this needs an authenticated user
loc, _ := setupServer(t)
didOpenParams, dir := didOpenTextParams(t)
fakeAuthenticationProvider := di.AuthenticationService().Providers()[0].(*authentication.FakeAuthenticationProvider)
fakeAuthenticationProvider := di.AuthenticationService().Provider().(*authentication.FakeAuthenticationProvider)
fakeAuthenticationProvider.IsAuthenticated = true

clientParams := types.InitializeParams{
Expand Down Expand Up @@ -397,7 +397,7 @@ func Test_TextDocumentCodeLenses_dirtyFileShouldFilterCodeLenses(t *testing.T) {
testutil.IntegTest(t) // this needs an authenticated user
loc, _ := setupServer(t)
didOpenParams, dir := didOpenTextParams(t)
fakeAuthenticationProvider := di.AuthenticationService().Providers()[0].(*authentication.FakeAuthenticationProvider)
fakeAuthenticationProvider := di.AuthenticationService().Provider().(*authentication.FakeAuthenticationProvider)
fakeAuthenticationProvider.IsAuthenticated = true

clientParams := types.InitializeParams{
Expand Down Expand Up @@ -662,7 +662,7 @@ func Test_initialize_handlesUntrustedFoldersWhenAuthenticated(t *testing.T) {
Token: "token",
}

fakeAuthenticationProvider := di.AuthenticationService().Providers()[0].(*authentication.FakeAuthenticationProvider)
fakeAuthenticationProvider := di.AuthenticationService().Provider().(*authentication.FakeAuthenticationProvider)
fakeAuthenticationProvider.IsAuthenticated = true

params := types.InitializeParams{
Expand Down Expand Up @@ -706,7 +706,7 @@ func Test_initialize_doesnotHandleUntrustedFolders(t *testing.T) {
func Test_textDocumentDidSaveHandler_shouldAcceptDocumentItemAndPublishDiagnostics(t *testing.T) {
loc, jsonRPCRecorder := setupServer(t)
config.CurrentConfig().SetSnykCodeEnabled(true)
fakeAuthenticationProvider := di.AuthenticationService().Providers()[0].(*authentication.FakeAuthenticationProvider)
fakeAuthenticationProvider := di.AuthenticationService().Provider().(*authentication.FakeAuthenticationProvider)
fakeAuthenticationProvider.IsAuthenticated = true

_, err := loc.Client.Call(ctx, "initialize", nil)
Expand Down Expand Up @@ -758,7 +758,7 @@ func Test_textDocumentDidSaveHandler_shouldTriggerScanForDotSnykFile(t *testing.
c.SetAuthenticationMethod(types.FakeAuthentication)
di.AuthenticationService().ConfigureProviders(c)

fakeAuthenticationProvider := di.AuthenticationService().Providers()[0]
fakeAuthenticationProvider := di.AuthenticationService().Provider()
fakeAuthenticationProvider.(*authentication.FakeAuthenticationProvider).IsAuthenticated = true

_, err := loc.Client.Call(ctx, "initialize", nil)
Expand Down Expand Up @@ -810,7 +810,7 @@ func Test_textDocumentDidOpenHandler_shouldNotPublishIfNotCached(t *testing.T) {
func Test_textDocumentDidOpenHandler_shouldPublishIfCached(t *testing.T) {
loc, jsonRPCRecorder := setupServer(t)
config.CurrentConfig().SetSnykCodeEnabled(true)
fakeAuthenticationProvider := di.AuthenticationService().Providers()[0].(*authentication.FakeAuthenticationProvider)
fakeAuthenticationProvider := di.AuthenticationService().Provider().(*authentication.FakeAuthenticationProvider)
fakeAuthenticationProvider.IsAuthenticated = true
_, err := loc.Client.Call(ctx, "initialize", nil)
if err != nil {
Expand Down Expand Up @@ -1008,7 +1008,7 @@ func Test_IntegrationHoverResults(t *testing.T) {
loc, _ := setupServer(t)
c := testutil.IntegTest(t)

fakeAuthenticationProvider := di.AuthenticationService().Providers()[0].(*authentication.FakeAuthenticationProvider)
fakeAuthenticationProvider := di.AuthenticationService().Provider().(*authentication.FakeAuthenticationProvider)
fakeAuthenticationProvider.IsAuthenticated = true

var cloneTargetDir, err = testutil.SetupCustomTestRepo(t, t.TempDir(), nodejsGoof, "0336589", c.Logger())
Expand Down
4 changes: 2 additions & 2 deletions application/server/trust_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func Test_handleUntrustedFolders_shouldTriggerTrustRequestAndNotScanAfterNegativ
func Test_initializeHandler_shouldCallHandleUntrustedFolders(t *testing.T) {
loc, jsonRPCRecorder := setupServer(t)
config.CurrentConfig().SetTrustedFolderFeatureEnabled(true)
fakeAuthenticationProvider := di.AuthenticationService().Providers()[0].(*authentication.FakeAuthenticationProvider)
fakeAuthenticationProvider := di.AuthenticationService().Provider().(*authentication.FakeAuthenticationProvider)
fakeAuthenticationProvider.IsAuthenticated = true

_, err := loc.Client.Call(context.Background(), "initialize", types.InitializeParams{
Expand Down Expand Up @@ -150,7 +150,7 @@ func Test_MultipleFoldersInRootDirWithOnlyOneTrusted(t *testing.T) {
c.SetTrustedFolderFeatureEnabled(true)
c.SetTrustedFolderFeatureEnabled(true)

fakeAuthenticationProvider := di.AuthenticationService().Providers()[0].(*authentication.FakeAuthenticationProvider)
fakeAuthenticationProvider := di.AuthenticationService().Provider().(*authentication.FakeAuthenticationProvider)
fakeAuthenticationProvider.IsAuthenticated = true

rootDir := t.TempDir()
Expand Down
2 changes: 1 addition & 1 deletion domain/ide/codelens/codelens_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func Test_GetCodeLensForPath(t *testing.T) {
// this is using the real progress channel, so we need to listen to it
dummyProgressListeners(t)

fakeAuthenticationProvider := di.AuthenticationService().Providers()[0].(*authentication.FakeAuthenticationProvider)
fakeAuthenticationProvider := di.AuthenticationService().Provider().(*authentication.FakeAuthenticationProvider)
fakeAuthenticationProvider.IsAuthenticated = true

filePath, dir := code.TempWorkdirWithIssues(t)
Expand Down
2 changes: 1 addition & 1 deletion domain/ide/command/command_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func Test_ExecuteCommand(t *testing.T) {
authProvider := &authentication.FakeAuthenticationProvider{
ExpectedAuthURL: "https://auth.url",
}
authenticationService := authentication.NewAuthenticationService(c, []authentication.AuthenticationProvider{authProvider}, nil, nil)
authenticationService := authentication.NewAuthenticationService(c, authProvider, nil, nil)
service := NewService(authenticationService, nil, nil, nil, nil, nil)
cmd := types.CommandData{
CommandId: types.CopyAuthLinkCommand,
Expand Down
9 changes: 2 additions & 7 deletions domain/ide/command/copy_auth_link.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,8 @@ func (cmd *copyAuthLinkCommand) Command() types.CommandData {
}

func (cmd *copyAuthLinkCommand) Execute(ctx context.Context) (any, error) {
var url string
for _, provider := range cmd.authService.Providers() {
url = provider.AuthURL(ctx)
if url != "" {
break
}
}
url := cmd.authService.Provider().AuthURL(ctx)

cmd.logger.Debug().Str("method", "copyAuthLinkCommand.Execute").
Str("url", url).
Msgf("copying auth link to clipboard")
Expand Down
5 changes: 1 addition & 4 deletions domain/ide/command/get_active_user.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,7 @@ func (cmd *getActiveUser) Command() types.CommandData {

func (cmd *getActiveUser) Execute(_ context.Context) (any, error) {
logger := config.CurrentConfig().Logger().With().Str("method", "getActiveUser.Execute").Logger()
isAuthenticated, err := cmd.authenticationService.IsAuthenticated()
if err != nil {
logger.Warn().Err(err).Msg("error checking auth status")
}
isAuthenticated := cmd.authenticationService.IsAuthenticated()

if !isAuthenticated {
logger.Info().Msg("not authenticated, skipping user retrieval")
Expand Down
2 changes: 1 addition & 1 deletion domain/ide/command/get_active_user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func setupCommandWithAuthService(t *testing.T, c *config.Config) *getActiveUser
},
authenticationService: authentication.NewAuthenticationService(
c,
[]authentication.AuthenticationProvider{provider},
provider,
error_reporting.NewTestErrorReporter(),
notification.NewNotifier(),
),
Expand Down
5 changes: 1 addition & 4 deletions domain/ide/command/get_feature_flag_status.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,7 @@ func (cmd *featureFlagStatus) Command() types.CommandData {

func (cmd *featureFlagStatus) Execute(_ context.Context) (any, error) {
logger := config.CurrentConfig().Logger().With().Str("method", "featureFlagStatus.Execute").Logger()
isAuthenticated, err := cmd.authenticationService.IsAuthenticated()
if err != nil {
logger.Warn().Err(err).Msg("error checking auth status")
}
isAuthenticated := cmd.authenticationService.IsAuthenticated()

if !isAuthenticated {
message := "not authenticated, cannot retrieve feature flags"
Expand Down
2 changes: 1 addition & 1 deletion domain/ide/command/get_feature_flag_status_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func setupFeatureFlagCommand(t *testing.T, c *config.Config, fakeApiClient *snyk
command: types.CommandData{Arguments: []interface{}{"snykCodeConsistentIgnores"}},
authenticationService: authentication.NewAuthenticationService(
c,
[]authentication.AuthenticationProvider{provider},
provider,
error_reporting.NewTestErrorReporter(),
notification.NewNotifier(),
),
Expand Down
7 changes: 4 additions & 3 deletions domain/ide/command/logout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@ package command

import (
"context"
"github.com/snyk/snyk-ls/domain/snyk/persistence"
"path/filepath"
"testing"

"github.com/snyk/snyk-ls/domain/snyk/persistence"

"github.com/stretchr/testify/assert"

"github.com/snyk/snyk-ls/domain/ide/hover"
Expand All @@ -43,7 +44,7 @@ func TestLogoutCommand_Execute_ClearsIssues(t *testing.T) {
provider.IsAuthenticated = true
scanNotifier := snyk.NewMockScanNotifier()
scanPersister := persistence.NewNopScanPersister()
authenticationService := authentication.NewAuthenticationService(c, []authentication.AuthenticationProvider{provider}, error_reporting.NewTestErrorReporter(), notifier)
authenticationService := authentication.NewAuthenticationService(c, provider, error_reporting.NewTestErrorReporter(), notifier)
cmd := logoutCommand{
command: types.CommandData{CommandId: types.LogoutCommand},
authService: authenticationService,
Expand Down Expand Up @@ -75,7 +76,7 @@ func TestLogoutCommand_Execute_ClearsIssues(t *testing.T) {
_, err := cmd.Execute(ctx)

assert.NoError(t, err)
authenticated, err := authenticationService.IsAuthenticated()
authenticated := authenticationService.IsAuthenticated()
assert.NoError(t, err)
assert.False(t, authenticated)
assert.Empty(t, folder.IssuesForFile(t.TempDir()))
Expand Down
Loading

0 comments on commit ce0f92f

Please sign in to comment.