Skip to content

Commit

Permalink
fix: auth mutex locking for callbacks (#716)
Browse files Browse the repository at this point in the history
  • Loading branch information
ShawkyZ authored Nov 6, 2024
1 parent a829a9c commit dfa4b0f
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 21 deletions.
4 changes: 2 additions & 2 deletions infrastructure/authentication/auth_configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ func Default(c *config.Config, authenticationService AuthenticationService) Auth
credentialsUpdateCallback := func(_ string, value any) {
// an empty struct marks an empty token, so we stay with empty string if the cast fails
newToken, _ := value.(string)
go authenticationService.UpdateCredentials(newToken, true)
go authenticationService.updateCredentials(newToken, true)
}

openBrowserFunc := func(url string) {
authenticationService.Provider().SetAuthURL(url)
authenticationService.provider().setAuthUrl(url)
types.DefaultOpenBrowserFunc(url)
}

Expand Down
2 changes: 1 addition & 1 deletion infrastructure/authentication/auth_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type AuthenticationProvider interface {
// AuthURL returns the latest provided AuthenticationURL. This can be empty.
AuthURL(ctx context.Context) string
// SetAuthURL sets the latest provided Authentication URL. This is a temporary URL.
SetAuthURL(url string)
setAuthUrl(url string)

GetCheckAuthenticationFunction() AuthenticationFunction
}
Expand Down
12 changes: 11 additions & 1 deletion infrastructure/authentication/auth_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,28 @@ type AuthenticationService interface {
// Authenticate attempts to authenticate the user, and sends a notification to the client when successful
Authenticate(ctx context.Context) (string, error)

// Provider returns current authentication provider.
Provider() AuthenticationProvider

// provider returns current authentication provider.
// doesn't have a mutex lock.
provider() AuthenticationProvider

// UpdateCredentials stores the token in the configuration, and sends a $/snyk.hasAuthenticated notification to the
// client if sendNotification is true
UpdateCredentials(newToken string, sendNotification bool)

// updateCredentials stores the token in the configuration, and sends a $/snyk.hasAuthenticated notification to the
// client if sendNotification is true
// doesn't have a mutex lock
updateCredentials(newToken string, sendNotification bool)

Logout(ctx context.Context)

// IsAuthenticated returns true if the token is verified
IsAuthenticated() bool

// AddProvider sets the authentication provider
// SetProvider sets the authentication provider
SetProvider(provider AuthenticationProvider)

// ConfigureProviders updates the providers based on the stored configuration
Expand Down
30 changes: 19 additions & 11 deletions infrastructure/authentication/auth_service_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ const ExpirationMsg = "Your authentication failed due to token expiration. Pleas
const InvalidCredsMessage = "Your authentication credentials cannot be validated. Automatically clearing credentials. You need to re-authenticate to use Snyk."

type AuthenticationServiceImpl struct {
provider AuthenticationProvider
authProvider AuthenticationProvider
errorReporter error_reporting.ErrorReporter
notifier noti.Notifier
c *config.Config
Expand All @@ -58,7 +58,7 @@ type AuthenticationServiceImpl struct {
func NewAuthenticationService(c *config.Config, authProviders AuthenticationProvider, errorReporter error_reporting.ErrorReporter, notifier noti.Notifier) AuthenticationService {
cache := imcache.New[string, bool]()
return &AuthenticationServiceImpl{
provider: authProviders,
authProvider: authProviders,
errorReporter: errorReporter,
notifier: notifier,
c: c,
Expand All @@ -70,7 +70,11 @@ func (a *AuthenticationServiceImpl) Provider() AuthenticationProvider {
a.m.RLock()
defer a.m.RUnlock()

return a.provider
return a.authProvider
}

func (a *AuthenticationServiceImpl) provider() AuthenticationProvider {
return a.authProvider
}

func (a *AuthenticationServiceImpl) Authenticate(ctx context.Context) (token string, err error) {
Expand All @@ -81,10 +85,10 @@ func (a *AuthenticationServiceImpl) Authenticate(ctx context.Context) (token str
}

func (a *AuthenticationServiceImpl) authenticate(ctx context.Context) (token string, err error) {
token, err = a.provider.Authenticate(ctx)
token, err = a.authProvider.Authenticate(ctx)

if token == "" || err != nil {
a.c.Logger().Warn().Err(err).Msgf("Failed to authenticate using auth provider %v", reflect.TypeOf(a.provider))
a.c.Logger().Warn().Err(err).Msgf("Failed to authenticate using auth provider %v", reflect.TypeOf(a.authProvider))
a.sendAuthenticationAnalytics(analytics.Failure, err)
return token, err
}
Expand Down Expand Up @@ -197,7 +201,7 @@ func (a *AuthenticationServiceImpl) Logout(ctx context.Context) {
}

func (a *AuthenticationServiceImpl) logout(ctx context.Context) {
err := a.provider.ClearAuthentication(ctx)
err := a.authProvider.ClearAuthentication(ctx)
if err != nil {
a.c.Logger().Warn().Err(err).Str("method", "Logout").Msg("Failed to log out.")
a.errorReporter.CaptureError(err)
Expand Down Expand Up @@ -230,18 +234,22 @@ func (a *AuthenticationServiceImpl) isAuthenticated() bool {
return false
}

if a.provider == nil {
if a.authProvider == nil {
a.configureProviders(a.c)
}

var user string
var err error
user, err = a.provider.GetCheckAuthenticationFunction()()
user, err = a.authProvider.GetCheckAuthenticationFunction()()
if user == "" {
if a.c.Offline() || (err != nil && !shouldCauseLogout(err, a.c.Logger())) {
userMsg := fmt.Sprintf("Could not retrieve authentication status. Most likely this is a temporary error "+
"caused by connectivity problems. If this message does not go away, please log out and re-authenticate (%s)", err.Error())
userMsg := "Could not retrieve authentication status. Most likely this is a temporary error " +
"caused by connectivity problems. If this message does not go away, please log out and re-authenticate"
if err != nil {
userMsg += fmt.Sprintf(" (%s)", err.Error())
}
a.notifier.SendShowMessage(sglsp.MTError, userMsg)

logger.Info().Msg("not logging out, as we had an error, but returning not authenticated to caller")
return false
}
Expand Down Expand Up @@ -319,7 +327,7 @@ func (a *AuthenticationServiceImpl) SetProvider(provider AuthenticationProvider)
}

func (a *AuthenticationServiceImpl) setProvider(provider AuthenticationProvider) {
a.provider = provider
a.authProvider = provider
}

func (a *AuthenticationServiceImpl) ConfigureProviders(c *config.Config) {
Expand Down
2 changes: 1 addition & 1 deletion infrastructure/authentication/cli_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func NewCliAuthenticationProvider(c *config.Config, errorReporter error_reportin
return &CliAuthenticationProvider{"", errorReporter, c}
}

func (a *CliAuthenticationProvider) SetAuthURL(url string) {
func (a *CliAuthenticationProvider) setAuthUrl(url string) {
a.authURL = url
}

Expand Down
4 changes: 1 addition & 3 deletions infrastructure/authentication/oauth_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ func (p *OAuth2Provider) Authenticate(_ context.Context) (string, error) {
return p.config.GetString(auth.CONFIG_KEY_OAUTH_TOKEN), err
}

func (p *OAuth2Provider) SetAuthURL(url string) {
p.m.Lock()
defer p.m.Unlock()
func (p *OAuth2Provider) setAuthUrl(url string) {
p.authURL = url
}

Expand Down
2 changes: 1 addition & 1 deletion infrastructure/authentication/oauth_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func TestAuthURL_ShouldReturnURL(t *testing.T) {
config := configuration.New()
authenticator := NewFakeOauthAuthenticator(time.Now().Add(10*time.Second), true, config, true).(*fakeOauthAuthenticator)
provider := newOAuthProvider(config, authenticator, config2.CurrentConfig().Logger())
provider.SetAuthURL("https://auth.fake.snyk.io")
provider.setAuthUrl("https://auth.fake.snyk.io")
url := provider.AuthURL(context.Background())

assert.NotEmpty(t, url)
Expand Down
2 changes: 1 addition & 1 deletion infrastructure/authentication/provider_fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (a *FakeAuthenticationProvider) AuthURL(_ context.Context) string {
return a.ExpectedAuthURL
}

func (a *FakeAuthenticationProvider) SetAuthURL(url string) {
func (a *FakeAuthenticationProvider) setAuthUrl(url string) {
a.authURL = url
}

Expand Down

0 comments on commit dfa4b0f

Please sign in to comment.