Skip to content

Commit

Permalink
Fix pagerduty AMR test to prevent flakiness (#46390)
Browse files Browse the repository at this point in the history
* Fix pagerduty AMR test to prevent flakiness

* Update integrations/access/pagerduty/testlib/suite.go

Co-authored-by: Zac Bergquist <[email protected]>

* Swap pagerduty test to use EventuallyWith

* Update pagerduty tests to not create several access requests

* Make more information available to AMR cache update callback

* Update integrations/access/pagerduty/testlib/suite.go

Co-authored-by: Tiago Silva <[email protected]>

* Update integrations/access/pagerduty/testlib/suite.go

Co-authored-by: Tiago Silva <[email protected]>

* Fix formatting

* Revert rename of pluginData in pagerduty tests

* Remove duplicated ruleHandler init

---------

Co-authored-by: Zac Bergquist <[email protected]>
Co-authored-by: Tiago Silva <[email protected]>
  • Loading branch information
3 people authored Oct 22, 2024
1 parent 82e120d commit 3672bc6
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 8 deletions.
10 changes: 10 additions & 0 deletions integrations/access/accessmonitoring/access_monitoring_rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ type RuleHandler struct {
pluginName string

fetchRecipientCallback func(ctx context.Context, recipient string) (*common.Recipient, error)
onCacheUpdateCallback func(Operation types.OpType, name string, rule *accessmonitoringrulesv1.AccessMonitoringRule) error
}

// RuleMap is a concurrent map for access monitoring rules.
Expand All @@ -65,6 +66,8 @@ type RuleHandlerConfig struct {

// FetchRecipientCallback is a callback that maps recipient strings to plugin Recipients.
FetchRecipientCallback func(ctx context.Context, recipient string) (*common.Recipient, error)
// OnCacheUpdateCallback is called whenever a rule is added to or updated in the cache.
OnCacheUpdateCallback func(Operation types.OpType, name string, rule *accessmonitoringrulesv1.AccessMonitoringRule) error
}

// NewRuleHandler returns a new RuleHandler.
Expand All @@ -77,6 +80,7 @@ func NewRuleHandler(conf RuleHandlerConfig) *RuleHandler {
pluginType: conf.PluginType,
pluginName: conf.PluginName,
fetchRecipientCallback: conf.FetchRecipientCallback,
onCacheUpdateCallback: conf.OnCacheUpdateCallback,
}
}

Expand All @@ -93,6 +97,9 @@ func (amrh *RuleHandler) InitAccessMonitoringRulesCache(ctx context.Context) err
continue
}
amrh.accessMonitoringRules.rules[amr.GetMetadata().Name] = amr
if amrh.onCacheUpdateCallback != nil {
amrh.onCacheUpdateCallback(types.OpPut, amr.GetMetadata().Name, amr)
}
}
return nil
}
Expand Down Expand Up @@ -123,6 +130,9 @@ func (amrh *RuleHandler) HandleAccessMonitoringRule(ctx context.Context, event t
return nil
}
amrh.accessMonitoringRules.rules[req.Metadata.Name] = req
if amrh.onCacheUpdateCallback != nil {
amrh.onCacheUpdateCallback(types.OpPut, req.GetMetadata().Name, req)
}
return nil
case types.OpDelete:
delete(amrh.accessMonitoringRules.rules, event.Resource.GetName())
Expand Down
9 changes: 6 additions & 3 deletions integrations/access/pagerduty/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ func NewApp(conf Config) (*App, error) {
teleport: conf.Client,
statusSink: conf.StatusSink,
}

app.mainJob = lib.NewServiceJob(app.run)

return app, nil
Expand Down Expand Up @@ -173,7 +172,7 @@ func (a *App) init(ctx context.Context) error {
}
}

a.accessMonitoringRules = accessmonitoring.NewRuleHandler(accessmonitoring.RuleHandlerConfig{
amrhConf := accessmonitoring.RuleHandlerConfig{
Client: a.teleport,
PluginType: types.PluginTypePagerDuty,
PluginName: pluginName,
Expand All @@ -184,7 +183,11 @@ func (a *App) init(ctx context.Context) error {
Kind: common.RecipientKindSchedule,
}, nil
},
})
}
if a.conf.OnAccessMonitoringRuleCacheUpdateCallback != nil {
amrhConf.OnCacheUpdateCallback = a.conf.OnAccessMonitoringRuleCacheUpdateCallback
}
a.accessMonitoringRules = accessmonitoring.NewRuleHandler(amrhConf)

if pong, err = a.checkTeleportVersion(ctx); err != nil {
return trace.Wrap(err)
Expand Down
6 changes: 6 additions & 0 deletions integrations/access/pagerduty/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"github.com/gravitational/trace"
"github.com/pelletier/go-toml"

accessmonitoringrulesv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessmonitoringrules/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/integrations/access/common"
"github.com/gravitational/teleport/integrations/access/common/teleport"
"github.com/gravitational/teleport/integrations/lib"
Expand All @@ -47,6 +49,10 @@ type Config struct {
// TeleportUser is the name of the Teleport user that will act
// as the access request approver
TeleportUser string

// OnAccessMonitoringRuleCacheUpdateCallback is used in tests to detect when
// the Access Monitoring Rule cache is updated
OnAccessMonitoringRuleCacheUpdateCallback func(Operation types.OpType, name string, rule *accessmonitoringrulesv1.AccessMonitoringRule) error
}

type PagerdutyConfig struct {
Expand Down
27 changes: 22 additions & 5 deletions integrations/access/pagerduty/testlib/suite.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,15 @@ func (s *PagerdutySuiteOSS) TestRecipientsFromAccessMonitoringRule() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
t.Cleanup(cancel)

const ruleName = "test-pagerduty-amr"
var collectedNames []string
var mu sync.Mutex
s.appConfig.OnAccessMonitoringRuleCacheUpdateCallback = func(_ types.OpType, name string, _ *accessmonitoringrulesv1.AccessMonitoringRule) error {
mu.Lock()
collectedNames = append(collectedNames, name)
mu.Unlock()
return nil
}
s.startApp()

_, err := s.ClientByName(integration.RulerUserName).
Expand All @@ -438,7 +447,7 @@ func (s *PagerdutySuiteOSS) TestRecipientsFromAccessMonitoringRule() {
Kind: types.KindAccessMonitoringRule,
Version: types.V1,
Metadata: &v1.Metadata{
Name: "test-pagerduty-amr",
Name: ruleName,
},
Spec: &accessmonitoringrulesv1.AccessMonitoringRuleSpec{
Subjects: []string{types.KindAccessRequest},
Expand All @@ -453,6 +462,14 @@ func (s *PagerdutySuiteOSS) TestRecipientsFromAccessMonitoringRule() {
})
assert.NoError(t, err)

// Incident creation may happen before the plugin's Access Monitoring Rule cache
// has been updated with the new rule. Retry until the cache picks up the new rule.
require.EventuallyWithT(t, func(t *assert.CollectT) {
mu.Lock()
require.Contains(t, collectedNames, ruleName)
mu.Unlock()
}, 3*time.Second, time.Millisecond*100, "new access monitoring rule did not begin applying")

// Test execution: create an access request
req := s.CreateAccessRequest(ctx, integration.RequesterOSSUserName, nil)

Expand All @@ -463,16 +480,16 @@ func (s *PagerdutySuiteOSS) TestRecipientsFromAccessMonitoringRule() {
})

incident, err := s.fakePagerduty.CheckNewIncident(ctx)
require.NoError(t, err, "no new incidents stored")

assert.NoError(t, err, "no new incidents stored")
assert.Equal(t, incident.ID, pluginData.IncidentID)
assert.Equal(t, s.pdNotifyService2.ID, pluginData.ServiceID)

assert.Equal(t, pagerduty.PdIncidentKeyPrefix+"/"+req.GetName(), incident.IncidentKey)
assert.Equal(t, "triggered", incident.Status)

assert.Equal(t, s.pdNotifyService2.ID, pluginData.ServiceID)

assert.NoError(t, s.ClientByName(integration.RulerUserName).
AccessMonitoringRulesClient().DeleteAccessMonitoringRule(ctx, "test-pagerduty-amr"))
AccessMonitoringRulesClient().DeleteAccessMonitoringRule(ctx, ruleName))
}

func (s *PagerdutyBaseSuite) assertNewEvent(ctx context.Context, watcher types.Watcher, opType types.OpType, resourceKind, resourceName string) types.Event {
Expand Down

0 comments on commit 3672bc6

Please sign in to comment.