diff --git a/integrations/access/accessmonitoring/access_monitoring_rules.go b/integrations/access/accessmonitoring/access_monitoring_rules.go index 72eb921f3f4e8..7fb2c045d57ea 100644 --- a/integrations/access/accessmonitoring/access_monitoring_rules.go +++ b/integrations/access/accessmonitoring/access_monitoring_rules.go @@ -48,6 +48,7 @@ type RuleHandler struct { pluginName string fetchRecipientCallback func(ctx context.Context, recipient string) (*common.Recipient, error) + onCacheUpdateCallback func(Operation types.OpType, name string, rule *accessmonitoringrulesv1.AccessMonitoringRule) error } // RuleMap is a concurrent map for access monitoring rules. @@ -65,6 +66,8 @@ type RuleHandlerConfig struct { // FetchRecipientCallback is a callback that maps recipient strings to plugin Recipients. FetchRecipientCallback func(ctx context.Context, recipient string) (*common.Recipient, error) + // OnCacheUpdateCallback is a callback that is called when a rule in the cache is created or updated. + OnCacheUpdateCallback func(Operation types.OpType, name string, rule *accessmonitoringrulesv1.AccessMonitoringRule) error } // NewRuleHandler returns a new RuleHandler. @@ -77,6 +80,7 @@ func NewRuleHandler(conf RuleHandlerConfig) *RuleHandler { pluginType: conf.PluginType, pluginName: conf.PluginName, fetchRecipientCallback: conf.FetchRecipientCallback, + onCacheUpdateCallback: conf.OnCacheUpdateCallback, } } @@ -93,6 +97,9 @@ func (amrh *RuleHandler) InitAccessMonitoringRulesCache(ctx context.Context) err continue } amrh.accessMonitoringRules.rules[amr.GetMetadata().Name] = amr + if amrh.onCacheUpdateCallback != nil { + amrh.onCacheUpdateCallback(types.OpPut, amr.GetMetadata().Name, amr) + } } return nil } @@ -123,6 +130,9 @@ func (amrh *RuleHandler) HandleAccessMonitoringRule(ctx context.Context, event t return nil } amrh.accessMonitoringRules.rules[req.Metadata.Name] = req + if amrh.onCacheUpdateCallback != nil { + amrh.onCacheUpdateCallback(types.OpPut, req.GetMetadata().Name, req) + } return nil case types.OpDelete: delete(amrh.accessMonitoringRules.rules, event.Resource.GetName()) diff --git a/integrations/access/pagerduty/app.go b/integrations/access/pagerduty/app.go index 972ed1bffce11..5eadcc5147cd0 100644 --- a/integrations/access/pagerduty/app.go +++ b/integrations/access/pagerduty/app.go @@ -78,7 +78,6 @@ func NewApp(conf Config) (*App, error) { teleport: conf.Client, statusSink: conf.StatusSink, } - app.mainJob = lib.NewServiceJob(app.run) return app, nil @@ -173,7 +172,7 @@ func (a *App) init(ctx context.Context) error { } } - a.accessMonitoringRules = accessmonitoring.NewRuleHandler(accessmonitoring.RuleHandlerConfig{ + amrhConf := accessmonitoring.RuleHandlerConfig{ Client: a.teleport, PluginType: types.PluginTypePagerDuty, PluginName: pluginName, @@ -184,7 +183,11 @@ func (a *App) init(ctx context.Context) error { Kind: common.RecipientKindSchedule, }, nil }, - }) + } + if a.conf.OnAccessMonitoringRuleCacheUpdateCallback != nil { + amrhConf.OnCacheUpdateCallback = a.conf.OnAccessMonitoringRuleCacheUpdateCallback + } + a.accessMonitoringRules = accessmonitoring.NewRuleHandler(amrhConf) if pong, err = a.checkTeleportVersion(ctx); err != nil { return trace.Wrap(err) diff --git a/integrations/access/pagerduty/config.go b/integrations/access/pagerduty/config.go index f76e9d2f955f2..8bf7060652b01 100644 --- a/integrations/access/pagerduty/config.go +++ b/integrations/access/pagerduty/config.go @@ -24,6 +24,8 @@ import ( "github.com/gravitational/trace" "github.com/pelletier/go-toml" + accessmonitoringrulesv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessmonitoringrules/v1" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/integrations/access/common" "github.com/gravitational/teleport/integrations/access/common/teleport" "github.com/gravitational/teleport/integrations/lib" @@ -47,6 +49,10 @@ type Config struct { // TeleportUser is the name of the Teleport user that will act // as the access request approver TeleportUser string + + // OnAccessMonitoringRuleCacheUpdateCallback is used for checking when + // the Rule cache is updated in tests + OnAccessMonitoringRuleCacheUpdateCallback func(Operation types.OpType, name string, rule *accessmonitoringrulesv1.AccessMonitoringRule) error } type PagerdutyConfig struct { diff --git a/integrations/access/pagerduty/testlib/suite.go b/integrations/access/pagerduty/testlib/suite.go index c379c85219a5c..b68c4d7e8d706 100644 --- a/integrations/access/pagerduty/testlib/suite.go +++ b/integrations/access/pagerduty/testlib/suite.go @@ -430,6 +430,15 @@ func (s *PagerdutySuiteOSS) TestRecipientsFromAccessMonitoringRule() { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) t.Cleanup(cancel) + const ruleName = "test-pagerduty-amr" + var collectedNames []string + var mu sync.Mutex + s.appConfig.OnAccessMonitoringRuleCacheUpdateCallback = func(_ types.OpType, name string, _ *accessmonitoringrulesv1.AccessMonitoringRule) error { + mu.Lock() + collectedNames = append(collectedNames, name) + mu.Unlock() + return nil + } s.startApp() _, err := s.ClientByName(integration.RulerUserName). @@ -438,7 +447,7 @@ func (s *PagerdutySuiteOSS) TestRecipientsFromAccessMonitoringRule() { Kind: types.KindAccessMonitoringRule, Version: types.V1, Metadata: &v1.Metadata{ - Name: "test-pagerduty-amr", + Name: ruleName, }, Spec: &accessmonitoringrulesv1.AccessMonitoringRuleSpec{ Subjects: []string{types.KindAccessRequest}, @@ -453,6 +462,14 @@ func (s *PagerdutySuiteOSS) TestRecipientsFromAccessMonitoringRule() { }) assert.NoError(t, err) + // Incident creation may happen before plugins Access Monitoring Rule cache + // has been updated with new rule. Retry until the new cache picks up the rule. + require.EventuallyWithT(t, func(t *assert.CollectT) { + mu.Lock() + require.Contains(t, collectedNames, ruleName) + mu.Unlock() + }, 3*time.Second, time.Millisecond*100, "new access monitoring rule did not begin applying") + // Test execution: create an access request req := s.CreateAccessRequest(ctx, integration.RequesterOSSUserName, nil) @@ -463,16 +480,16 @@ func (s *PagerdutySuiteOSS) TestRecipientsFromAccessMonitoringRule() { }) incident, err := s.fakePagerduty.CheckNewIncident(ctx) - require.NoError(t, err, "no new incidents stored") - + assert.NoError(t, err, "no new incidents stored") assert.Equal(t, incident.ID, pluginData.IncidentID) - assert.Equal(t, s.pdNotifyService2.ID, pluginData.ServiceID) assert.Equal(t, pagerduty.PdIncidentKeyPrefix+"/"+req.GetName(), incident.IncidentKey) assert.Equal(t, "triggered", incident.Status) + assert.Equal(t, s.pdNotifyService2.ID, pluginData.ServiceID) + assert.NoError(t, s.ClientByName(integration.RulerUserName). - AccessMonitoringRulesClient().DeleteAccessMonitoringRule(ctx, "test-pagerduty-amr")) + AccessMonitoringRulesClient().DeleteAccessMonitoringRule(ctx, ruleName)) } func (s *PagerdutyBaseSuite) assertNewEvent(ctx context.Context, watcher types.Watcher, opType types.OpType, resourceKind, resourceName string) types.Event {