diff --git a/internal/entities/handlers/message/message.go b/internal/entities/handlers/message/message.go index c7e2743b7..a9636b9ee 100644 --- a/internal/entities/handlers/message/message.go +++ b/internal/entities/handlers/message/message.go @@ -36,6 +36,7 @@ type TypedProps struct { // EntityHint is a hint that is used to help the entity handler find the entity. type EntityHint struct { ProviderImplementsHint string `json:"provider_implements_hint"` + ProviderClassHint string `json:"provider_class_hint"` } // HandleEntityAndDoMessage is a message that is sent to the entity handler to refresh an entity and perform an action. @@ -69,11 +70,19 @@ func (e *HandleEntityAndDoMessage) WithOwner(ownerType v1.Entity, ownerProps *pr } // WithProviderImplementsHint sets the provider hint for the entity that will be used when looking up the entity. +// to the provider implements hint func (e *HandleEntityAndDoMessage) WithProviderImplementsHint(providerHint string) *HandleEntityAndDoMessage { e.Hint.ProviderImplementsHint = providerHint return e } +// WithProviderClassHint sets the provider hint for the entity that will be used when looking up the entity. +// to the provider class +func (e *HandleEntityAndDoMessage) WithProviderClassHint(providerClassHint string) *HandleEntityAndDoMessage { + e.Hint.ProviderClassHint = providerClassHint + return e +} + // ToEntityRefreshAndDo converts a Watermill message to a HandleEntityAndDoMessage struct. func ToEntityRefreshAndDo(msg *message.Message) (*HandleEntityAndDoMessage, error) { entMsg := &HandleEntityAndDoMessage{} diff --git a/internal/entities/handlers/message/message_test.go b/internal/entities/handlers/message/message_test.go index 76cf2a23c..4b851f3c1 100644 --- a/internal/entities/handlers/message/message_test.go +++ b/internal/entities/handlers/message/message_test.go @@ -23,6 +23,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/stacklok/minder/internal/db" "github.com/stacklok/minder/internal/entities/properties" v1 "github.com/stacklok/minder/pkg/api/protobuf/go/minder/v1" ) @@ -31,12 +32,13 @@ func TestEntityRefreshAndDoMessageRoundTrip(t *testing.T) { t.Parallel() scenarios := []struct { - name string - props map[string]any - entType v1.Entity - ownerProps map[string]any - ownerType v1.Entity - providerHint string + name string + props map[string]any + entType v1.Entity + ownerProps map[string]any + ownerType v1.Entity + providerHint string + providerClass string }{ { name: "Valid repository entity", @@ -44,8 +46,9 @@ func TestEntityRefreshAndDoMessageRoundTrip(t *testing.T) { "id": "123", "name": "test-repo", }, - entType: v1.Entity_ENTITY_REPOSITORIES, - providerHint: "github", + entType: v1.Entity_ENTITY_REPOSITORIES, + providerHint: "github", + providerClass: string(db.ProviderClassGithub), }, { name: "Valid artifact entity", @@ -57,8 +60,9 @@ func TestEntityRefreshAndDoMessageRoundTrip(t *testing.T) { ownerProps: map[string]any{ "id": "123", }, - ownerType: v1.Entity_ENTITY_REPOSITORIES, - providerHint: "docker", + ownerType: v1.Entity_ENTITY_REPOSITORIES, + providerHint: "docker", + providerClass: string(db.ProviderClassDockerhub), }, { name: "Valid pull request entity", @@ -69,8 +73,9 @@ func TestEntityRefreshAndDoMessageRoundTrip(t *testing.T) { ownerProps: map[string]any{ "id": "123", }, - ownerType: v1.Entity_ENTITY_REPOSITORIES, - providerHint: "github", + ownerType: v1.Entity_ENTITY_REPOSITORIES, + providerHint: "github", + providerClass: string(db.ProviderClassGithub), }, } @@ -83,7 +88,8 @@ func TestEntityRefreshAndDoMessageRoundTrip(t *testing.T) { original := NewEntityRefreshAndDoMessage(). WithEntity(sc.entType, props). - WithProviderImplementsHint(sc.providerHint) + WithProviderImplementsHint(sc.providerHint). + WithProviderClassHint(sc.providerClass) if sc.ownerProps != nil { ownerProps, err := properties.NewProperties(sc.ownerProps) @@ -100,6 +106,7 @@ func TestEntityRefreshAndDoMessageRoundTrip(t *testing.T) { assert.Equal(t, original.Entity.GetByProps, roundTrip.Entity.GetByProps) assert.Equal(t, original.Entity.Type, roundTrip.Entity.Type) assert.Equal(t, original.Hint.ProviderImplementsHint, roundTrip.Hint.ProviderImplementsHint) + assert.Equal(t, original.Hint.ProviderClassHint, roundTrip.Hint.ProviderClassHint) if original.Originator.Type != v1.Entity_ENTITY_UNSPECIFIED { assert.Equal(t, original.Originator.GetByProps, roundTrip.Originator.GetByProps) assert.Equal(t, original.Originator.Type, roundTrip.Originator.Type) diff --git a/internal/entities/handlers/strategies/entity/common.go b/internal/entities/handlers/strategies/entity/common.go index b9999acd6..8de8da08d 100644 --- a/internal/entities/handlers/strategies/entity/common.go +++ b/internal/entities/handlers/strategies/entity/common.go @@ -41,6 +41,12 @@ func getEntityInner( return nil, fmt.Errorf("error scanning provider type: %w", err) } } + if hint.ProviderClassHint != "" { + svcHint.ProviderClass.Valid = true + if err := svcHint.ProviderClass.ProviderClass.Scan(hint.ProviderClassHint); err != nil { + return nil, fmt.Errorf("error scanning provider class: %w", err) + } + } lookupProperties, err := properties.NewProperties(entPropMap) if err != nil {