diff --git a/go.mod b/go.mod index f0de263..26df635 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.21 toolchain go1.21.3 require ( - github.com/Axway/agent-sdk v1.1.73 + github.com/Axway/agent-sdk v1.1.75-0.20240222230101-1dd0ec205b21 github.com/elastic/beats/v7 v7.17.17 github.com/google/uuid v1.3.1 github.com/kong/go-kong v0.47.0 @@ -118,6 +118,7 @@ require ( github.com/santhosh-tekuri/jsonschema v1.2.4 // indirect github.com/segmentio/asm v1.2.0 // indirect github.com/shirou/gopsutil v3.20.12+incompatible // indirect + github.com/shopspring/decimal v1.3.1 // indirect github.com/snowzach/rotatefilehook v0.0.0-20220211133110-53752135082d // indirect github.com/spf13/afero v1.8.2 // indirect github.com/spf13/cast v1.5.0 // indirect diff --git a/go.sum b/go.sum index 158b202..eee72b5 100644 --- a/go.sum +++ b/go.sum @@ -36,8 +36,8 @@ cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RX cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3fOKtUw0Xmo= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= -github.com/Axway/agent-sdk v1.1.73 h1:aLtVRDeWNz/bvlZrhDyYL7lWelWV0TsMqzHmsG2BUAY= -github.com/Axway/agent-sdk v1.1.73/go.mod h1:CMuNRWCU4UswYhaR65bzLKEDT0xcViQwHJUGFCnH2yk= +github.com/Axway/agent-sdk v1.1.75-0.20240222230101-1dd0ec205b21 h1:EtVW8nyRsM25Uyu9XCsmI7/RZCI3cw0tnMmbyFNDjgA= +github.com/Axway/agent-sdk v1.1.75-0.20240222230101-1dd0ec205b21/go.mod h1:QG5rkOTPFZmbcigGsQ8TgVNyncz7QB+xJMFTJtsFWQ8= github.com/Azure/go-ansiterm v0.0.0-20170929234023-d6e3b3328b78 h1:w+iIsaOQNcT7OZ575w+acHgRric5iCyQh+xv+KJ4HB8= github.com/Azure/go-ansiterm v0.0.0-20170929234023-d6e3b3328b78/go.mod h1:LmzpDX56iTiv29bbRTIsUNlaFfuhWRQBWjQdVyAevI8= github.com/Azure/go-autorest v14.2.0+incompatible/go.mod h1:r+4oMnoxhatjLLJ6zxSWATqVooLgysK6ZNox3g/xq24= @@ -460,6 +460,8 @@ github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= github.com/shirou/gopsutil v3.20.12+incompatible h1:6VEGkOXP/eP4o2Ilk8cSsX0PhOEfX6leqAnD+urrp9M= github.com/shirou/gopsutil v3.20.12+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= +github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= +github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= diff --git a/pkg/discovery/gateway/client.go b/pkg/discovery/agent/agent.go similarity index 54% rename from pkg/discovery/gateway/client.go rename to pkg/discovery/agent/agent.go index aeed476..ab0cff2 100644 --- a/pkg/discovery/gateway/client.go +++ b/pkg/discovery/agent/agent.go @@ -1,10 +1,10 @@ -package gateway +package agent import ( "context" "crypto/sha256" "fmt" - "net/http" + "net/url" "sync" klib "github.com/kong/go-kong/kong" @@ -13,63 +13,106 @@ import ( "github.com/Axway/agent-sdk/pkg/apic" "github.com/Axway/agent-sdk/pkg/apic/provisioning" "github.com/Axway/agent-sdk/pkg/cache" + corecfg "github.com/Axway/agent-sdk/pkg/config" "github.com/Axway/agent-sdk/pkg/filter" "github.com/Axway/agent-sdk/pkg/util" "github.com/Axway/agent-sdk/pkg/util/log" "github.com/Axway/agents-kong/pkg/common" "github.com/Axway/agents-kong/pkg/discovery/config" - kutil "github.com/Axway/agents-kong/pkg/discovery/kong" + "github.com/Axway/agents-kong/pkg/discovery/kong" "github.com/Axway/agents-kong/pkg/discovery/subscription" ) var kongToCRDMapper = map[string]string{ - "basic-auth": provisioning.BasicAuthCRD, - "key-auth": provisioning.APIKeyCRD, - "oauth2": provisioning.OAuthSecretCRD, + kong.BasicAuthPlugin: provisioning.BasicAuthCRD, + kong.KeyAuthPlugin: provisioning.APIKeyCRD, + kong.OAuthPlugin: provisioning.OAuthSecretCRD, } -func NewClient(agentConfig config.AgentConfig) (*Client, error) { - kongGatewayConfig := agentConfig.KongGatewayCfg - clientBase := &http.Client{} - kongClient, err := kutil.NewKongClient(clientBase, kongGatewayConfig) +type kongClient interface { + // Provisioning + CreateConsumer(ctx context.Context, id, name string) (*klib.Consumer, error) + AddConsumerACL(ctx context.Context, id string) error + DeleteConsumer(ctx context.Context, id string) error + // Credential + DeleteOauth2(ctx context.Context, consumerID, clientID string) error + DeleteHttpBasic(ctx context.Context, consumerID, username string) error + DeleteAuthKey(ctx context.Context, consumerID, authKey string) error + CreateHttpBasic(ctx context.Context, consumerID string, basicAuth *klib.BasicAuth) (*klib.BasicAuth, error) + CreateOauth2(ctx context.Context, consumerID string, oauth2 *klib.Oauth2Credential) (*klib.Oauth2Credential, error) + CreateAuthKey(ctx context.Context, consumerID string, keyAuth *klib.KeyAuth) (*klib.KeyAuth, error) + // Access Request + AddRouteACL(ctx context.Context, routeID, allowedID string) error + RemoveRouteACL(ctx context.Context, routeID, revokedID string) error + AddQuota(ctx context.Context, routeID, allowedID, quotaInterval string, quotaLimit int) error + // Discovery + ListServices(ctx context.Context) ([]*klib.Service, error) + ListRoutesForService(ctx context.Context, serviceId string) ([]*klib.Route, error) + GetSpecForService(ctx context.Context, service *klib.Service) ([]byte, error) + GetKongPlugins() *kong.Plugins +} + +type Agent struct { + logger log.FieldLogger + centralCfg corecfg.CentralConfig + kongGatewayCfg *config.KongGatewayConfig + kongClient kongClient + plugins kong.Plugins + cache cache.Cache + filter filter.Filter +} + +func NewAgent(agentConfig config.AgentConfig, agentOpts ...func(a *Agent)) (*Agent, error) { + ka := &Agent{ + logger: log.NewFieldLogger().WithComponent("agent").WithPackage("kongAgent"), + centralCfg: agentConfig.CentralCfg, + kongGatewayCfg: agentConfig.KongGatewayCfg, + cache: cache.New(), + } + for _, o := range agentOpts { + o(ka) + } + + var err error + if ka.kongClient == nil { + ka.kongClient, err = kong.NewKongClient(ka.kongGatewayCfg) + } if err != nil { return nil, err } - daCache := cache.New() - logger := log.NewFieldLogger().WithField("component", "agent") - plugins, err := kongClient.Plugins.ListAll(context.Background()) + pluginLister := ka.kongClient.GetKongPlugins() + if pluginLister == nil { + return nil, fmt.Errorf("could not get kong plugin lister") + } + plugins, err := ka.kongClient.GetKongPlugins().ListAll(context.Background()) if err != nil { return nil, err } - discoveryFilter, err := filter.NewFilter(agentConfig.KongGatewayCfg.Spec.Filter) + ka.filter, err = filter.NewFilter(agentConfig.KongGatewayCfg.Spec.Filter) if err != nil { return nil, err } - if err = hasGlobalACLEnabledInPlugins(logger, plugins, agentConfig.KongGatewayCfg.ACL.Disable); err != nil { - logger.WithError(err).Error("ACL Plugin configured as required, but none found in Kong plugins.") + if err = hasGlobalACLEnabledInPlugins(ka.logger, plugins, agentConfig.KongGatewayCfg.ACL.Disable); err != nil { + ka.logger.WithError(err).Error("ACL Plugin configured as required, but none found in Kong plugins.") return nil, err } - provisionLogger := log.NewFieldLogger().WithComponent("provision").WithPackage("kong") opts := []subscription.ProvisionerOption{} if agentConfig.KongGatewayCfg.ACL.Disable { opts = append(opts, subscription.WithACLDisable()) } - subscription.NewProvisioner(kongClient, provisionLogger, opts...) + subscription.NewProvisioner(ka.kongClient, opts...) + return ka, nil +} - return &Client{ - logger: logger, - centralCfg: agentConfig.CentralCfg, - kongGatewayCfg: kongGatewayConfig, - kongClient: kongClient, - cache: daCache, - mode: common.Marketplace, - filter: discoveryFilter, - }, nil +func withKongClient(kongClient kongClient) func(a *Agent) { + return func(a *Agent) { + a.kongClient = kongClient + } } func pluginIsGlobal(p *klib.Plugin) bool { @@ -90,17 +133,16 @@ func hasGlobalACLEnabledInPlugins(logger log.FieldLogger, plugins []*klib.Plugin return nil } } - return fmt.Errorf("failed to find acl plugin is enabled and installed on the Kong Gateway. " + - "Enable in on the Gateway or change the config to disable this check.") + return fmt.Errorf("acl plugin is not enabled/installed, install and enable or change the config to disable this check") } -func (gc *Client) DiscoverAPIs() error { +func (gc *Agent) DiscoverAPIs() error { gc.logger.Info("execute discovery process") ctx := context.Background() var err error - plugins := kutil.Plugins{PluginLister: gc.kongClient.GetKongPlugins()} + plugins := kong.Plugins{PluginLister: gc.kongClient.GetKongPlugins()} gc.plugins = plugins services, err := gc.kongClient.ListServices(ctx) @@ -113,7 +155,7 @@ func (gc *Client) DiscoverAPIs() error { return nil } -func (gc *Client) processKongServicesList(ctx context.Context, services []*klib.Service) { +func (gc *Agent) processKongServicesList(ctx context.Context, services []*klib.Service) { wg := new(sync.WaitGroup) for _, service := range services { if !gc.filter.Evaluate(toTagsMap(service)) { @@ -141,7 +183,7 @@ func toTagsMap(service *klib.Service) map[string]string { return filters } -func (gc *Client) processSingleKongService(ctx context.Context, service *klib.Service) error { +func (gc *Agent) processSingleKongService(ctx context.Context, service *klib.Service) error { log := gc.logger.WithField(common.AttrServiceName, *service.Name) log.Info("processing service") @@ -166,13 +208,20 @@ func (gc *Client) processSingleKongService(ctx context.Context, service *klib.Se spec := apic.NewSpecResourceParser(kongServiceSpec, "") spec.Parse() - for _, route := range routes { - gc.specPreparation(ctx, route, service, spec.GetSpecProcessor()) + wg := sync.WaitGroup{} + wg.Add(len(routes)) + for _, r := range routes { + func(route *klib.Route) { + defer wg.Done() + gc.specPreparation(ctx, route, service, spec.GetSpecProcessor()) + }(r) } + wg.Wait() + return nil } -func (gc *Client) specPreparation(ctx context.Context, route *klib.Route, service *klib.Service, spec apic.SpecProcessor) { +func (gc *Agent) specPreparation(ctx context.Context, route *klib.Route, service *klib.Service, spec apic.SpecProcessor) { log := gc.logger.WithField(common.AttrRouteID, *route.ID). WithField(common.AttrServiceID, *service.ID) @@ -206,7 +255,7 @@ func (gc *Client) specPreparation(ctx context.Context, route *klib.Route, servic log.Info("Successfully published to central") } -func (gc *Client) processKongRoute(route *klib.Route) []apic.EndpointDefinition { +func (gc *Agent) processKongRoute(route *klib.Route) []apic.EndpointDefinition { if route == nil { return []apic.EndpointDefinition{} } @@ -222,7 +271,7 @@ func (gc *Client) processKongRoute(route *klib.Route) []apic.EndpointDefinition return kRoute.GetEndpoints() } -func (gc *Client) processKongAPI( +func (gc *Agent) processKongAPI( ctx context.Context, route *klib.Route, service *klib.Service, @@ -230,7 +279,7 @@ func (gc *Client) processKongAPI( endpoints []apic.EndpointDefinition, apiPlugins map[string]*klib.Plugin, ) (*apic.ServiceBody, error) { - kongAPI := newKongAPI(route, service, spec, endpoints) + kongAPI := newKongAPI(route, service, spec, endpoints, apiPlugins) isAlreadyPublished, checksum := isPublished(&kongAPI, gc.cache) // If true, then the api is published and there were no changes detected if isAlreadyPublished { @@ -242,14 +291,6 @@ func (gc *Client) processKongAPI( gc.logger.WithError(err).Error("failed to save api to cache") } - kongAPI.ard = provisioning.APIKeyARD - kongAPI.crds = []string{} - for k := range apiPlugins { - if crd, ok := kongToCRDMapper[k]; ok { - kongAPI.crds = append(kongAPI.crds, crd) - } - } - agentDetails := map[string]string{ common.AttrServiceID: *service.ID, common.AttrRouteID: *route.ID, @@ -269,14 +310,10 @@ func newKongAPI( service *klib.Service, spec apic.SpecProcessor, endpoints []apic.EndpointDefinition, + apiPlugins map[string]*klib.Plugin, ) KongAPI { - // strip any security from spec if it is an oas spec resType := spec.GetResourceType() - if resType == apic.Oas2 || resType == apic.Oas3 { - spec.(apic.OasSpecProcessor).StripSpecAuth() - } - - return KongAPI{ + ka := &KongAPI{ id: *service.ID, name: *service.Name, description: spec.GetDescription(), @@ -289,6 +326,99 @@ func newKongAPI( stageName: *route.Name, stage: *route.ID, } + ka.processSpecSecurity(spec, apiPlugins) + return *ka +} + +func (ka *KongAPI) processSpecSecurity(spec apic.SpecProcessor, apiPlugins map[string]*klib.Plugin) { + // strip any security from spec if it is an oas spec + resType := spec.GetResourceType() + if resType != apic.Oas2 && resType != apic.Oas3 { + return + } + oasSpec := spec.(apic.OasSpecProcessor) + oasSpec.StripSpecAuth() + + ka.ard = provisioning.APIKeyARD + ka.crds = []string{} + for k, plugin := range apiPlugins { + if crd, ok := kongToCRDMapper[k]; ok { + ka.crds = append(ka.crds, crd) + } + switch k { + case kong.BasicAuthPlugin: + oasSpec.AddSecuritySchemes(oasSpec.GetSecurityBuilder().HTTPBasic().Build()) + case kong.KeyAuthPlugin: + ka.apiKeySecurity(oasSpec, plugin.Config) + case kong.OAuthPlugin: + ka.oAuthSecurity(oasSpec, plugin.Config) + } + } + + ka.spec = oasSpec.(apic.SpecProcessor).GetSpecBytes() +} + +func (ka *KongAPI) apiKeySecurity(spec apic.OasSpecProcessor, config map[string]interface{}) { + keyAuth, err := kong.NewKeyAuthPluginConfigFromMap(config) + if err != nil { + return + } + + for _, key := range keyAuth.KeyNames { + if keyAuth.KeyInQuery { + spec.AddSecuritySchemes(spec.GetSecurityBuilder().APIKey().SetArgumentName(key).InQueryParam().Build()) + } else { + // forcing header if not in query + spec.AddSecuritySchemes(spec.GetSecurityBuilder().APIKey().SetArgumentName(key).InHeader().Build()) + } + } +} + +func (ka *KongAPI) oAuthSecurity(spec apic.OasSpecProcessor, config map[string]interface{}) { + oAuth, err := kong.NewOAuthPluginConfigFromMap(config) + if err != nil { + return + } + + builder := spec.GetSecurityBuilder().OAuth() + + s := url.URL{} + for _, e := range ka.endpoints { + if e.Protocol == httpsScheme { + s = url.URL{ + Scheme: httpsScheme, + Host: fmt.Sprintf("%v:%v", e.Host, e.Port), + Path: e.BasePath, + } + break + } + } + if s.Scheme == "" { + return + } + tokenURL := fmt.Sprintf("%v/oauth2/token", s.String()) + authURL := fmt.Sprintf("%v/oauth2/authorize", s.String()) + scopes := map[string]string{} + for _, n := range oAuth.Scopes { + scopes[n] = n + } + + if oAuth.EnableImplicitGrant { + builder = builder.AddFlow(apic.NewOAuthFlowBuilder().SetScopes(scopes).SetAuthorizationURL(authURL).Implicit()) + } + + if oAuth.EnableAuthorizationCode { + builder = builder.AddFlow(apic.NewOAuthFlowBuilder().SetScopes(scopes).SetAuthorizationURL(authURL).SetTokenURL(tokenURL).AuthorizationCode()) + } + + if oAuth.EnableClientCredentials { + builder = builder.AddFlow(apic.NewOAuthFlowBuilder().SetScopes(scopes).SetTokenURL(tokenURL).ClientCredentials()) + } + + if oAuth.EnablePasswordGrant { + builder = builder.AddFlow(apic.NewOAuthFlowBuilder().SetScopes(scopes).SetTokenURL(tokenURL).Password()) + } + spec.AddSecuritySchemes(builder.Build()) } func (ka *KongAPI) buildServiceBody() (apic.ServiceBody, error) { diff --git a/pkg/discovery/agent/agent_test.go b/pkg/discovery/agent/agent_test.go new file mode 100644 index 0000000..935659d --- /dev/null +++ b/pkg/discovery/agent/agent_test.go @@ -0,0 +1,142 @@ +package agent + +import ( + "context" + "testing" + + "github.com/Axway/agent-sdk/pkg/agent" + "github.com/Axway/agent-sdk/pkg/apic/mock" + "github.com/Axway/agent-sdk/pkg/cache" + corecfg "github.com/Axway/agent-sdk/pkg/config" + "github.com/Axway/agent-sdk/pkg/filter" + "github.com/Axway/agent-sdk/pkg/util/log" + config "github.com/Axway/agents-kong/pkg/discovery/config" + "github.com/Axway/agents-kong/pkg/discovery/kong" + klib "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +func stringPtr(s string) *string { + return &s +} + +func boolPtr(b bool) *bool { + return &b +} + +func intPtr(i int) *int { + return &i +} + +func TestNewAgent(t *testing.T) { + testCases := map[string]struct { + gatewayConfig *config.KongGatewayConfig + client *mockKongClient + expectErr bool + }{ + "error when plugin lister is not created": { + gatewayConfig: &config.KongGatewayConfig{}, + client: &mockKongClient{}, + expectErr: true, + }, + "error getting kong plugins using lister": { + gatewayConfig: &config.KongGatewayConfig{}, + client: &mockKongClient{ + GetKongPluginsMock: func() *kong.Plugins { + return &kong.Plugins{PluginLister: &mockPluginLister{}} + }, + }, + expectErr: true, + }, + "error hit because ACL was not installed": { + gatewayConfig: &config.KongGatewayConfig{}, + client: &mockKongClient{ + GetKongPluginsMock: func() *kong.Plugins { + return &kong.Plugins{PluginLister: &mockPluginLister{plugins: []*klib.Plugin{}}} + }, + }, + expectErr: true, + }, + } + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + cfg := config.AgentConfig{ + CentralCfg: corecfg.NewCentralConfig(corecfg.DiscoveryAgent), + KongGatewayCfg: tc.gatewayConfig, + } + agent.InitializeForTest(&mock.Client{}, agent.TestWithMarketplace()) + + a, err := NewAgent(cfg, withKongClient(tc.client)) + if tc.expectErr { + assert.Nil(t, a) + assert.NotNil(t, err) + } + }) + } +} + +func TestDiscovery(t *testing.T) { + testCases := map[string]struct { + client *mockKongClient + expectErr bool + }{ + "expect error when services call fails": { + client: &mockKongClient{ + GetKongPluginsMock: func() *kong.Plugins { + return &kong.Plugins{PluginLister: &mockPluginLister{plugins: []*klib.Plugin{}}} + }, + }, + expectErr: true, + }, + "success when no services returned": { + client: &mockKongClient{ + GetKongPluginsMock: func() *kong.Plugins { + return &kong.Plugins{PluginLister: &mockPluginLister{plugins: []*klib.Plugin{}}} + }, + ListServicesMock: func(context.Context) ([]*klib.Service, error) { + return []*klib.Service{}, nil + }, + }, + }, + "success when services returned but no routes": { + client: &mockKongClient{ + GetKongPluginsMock: func() *kong.Plugins { + return &kong.Plugins{PluginLister: &mockPluginLister{plugins: []*klib.Plugin{}}} + }, + ListServicesMock: func(context.Context) ([]*klib.Service, error) { + return []*klib.Service{ + { + Enabled: boolPtr(true), + Host: stringPtr("petstore.com"), + ID: stringPtr("petstore-id"), + Name: stringPtr("PetStore"), + Tags: []*string{}, + }, + }, nil + }, + }, + }, + } + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + f, _ := filter.NewFilter("") + ka := &Agent{ + logger: log.NewFieldLogger().WithComponent("agent").WithPackage("kongAgent"), + centralCfg: corecfg.NewCentralConfig(corecfg.DiscoveryAgent), + kongGatewayCfg: &config.KongGatewayConfig{}, + cache: cache.New(), + kongClient: tc.client, + filter: f, + } + + // agent.InitializeForTest() + + err := ka.DiscoverAPIs() + if tc.expectErr { + assert.NotNil(t, err) + return + } + assert.Nil(t, err) + }) + } +} diff --git a/pkg/discovery/agent/definitions.go b/pkg/discovery/agent/definitions.go new file mode 100644 index 0000000..be5fb6d --- /dev/null +++ b/pkg/discovery/agent/definitions.go @@ -0,0 +1,26 @@ +package agent + +import ( + "github.com/Axway/agent-sdk/pkg/apic" +) + +type KongAPI struct { + spec []byte + id string + name string + description string + version string + url string + documentation []byte + resourceType string + endpoints []apic.EndpointDefinition + image string + imageContentType string + crds []string + apiUpdateSeverity string + agentDetails map[string]string + tags []string + stage string + stageName string + ard string +} diff --git a/pkg/discovery/agent/mockkongclient_test.go b/pkg/discovery/agent/mockkongclient_test.go new file mode 100644 index 0000000..c64d18f --- /dev/null +++ b/pkg/discovery/agent/mockkongclient_test.go @@ -0,0 +1,155 @@ +package agent + +import ( + "context" + "fmt" + + "github.com/Axway/agents-kong/pkg/discovery/kong" + klib "github.com/kong/go-kong/kong" +) + +type mockKongClient struct { + // Provisioning + CreateConsumerMock func(context.Context, string, string) (*klib.Consumer, error) + AddConsumerACLMock func(context.Context, string) error + DeleteConsumerMock func(context.Context, string) error + // Credential + DeleteOauth2Mock func(context.Context, string, string) error + DeleteHttpBasicMock func(context.Context, string, string) error + DeleteAuthKeyMock func(context.Context, string, string) error + CreateHttpBasicMock func(context.Context, string, *klib.BasicAuth) (*klib.BasicAuth, error) + CreateOauth2Mock func(context.Context, string, *klib.Oauth2Credential) (*klib.Oauth2Credential, error) + CreateAuthKeyMock func(context.Context, string, *klib.KeyAuth) (*klib.KeyAuth, error) + // Access Request + AddRouteACLMock func(context.Context, string, string) error + RemoveRouteACLMock func(context.Context, string, string) error + AddQuotaMock func(context.Context, string, string, string, int) error + // Discovery + ListServicesMock func(context.Context) ([]*klib.Service, error) + ListRoutesForServiceMock func(context.Context, string) ([]*klib.Route, error) + GetSpecForServiceMock func(context.Context, *klib.Service) ([]byte, error) + GetKongPluginsMock func() *kong.Plugins +} + +func (m *mockKongClient) CreateConsumer(ctx context.Context, id, name string) (*klib.Consumer, error) { + if m.CreateConsumerMock != nil { + return m.CreateConsumerMock(ctx, id, name) + } + return nil, fmt.Errorf("unimplemented test func") +} + +func (m *mockKongClient) AddConsumerACL(ctx context.Context, id string) error { + if m.AddConsumerACLMock != nil { + return m.AddConsumerACLMock(ctx, id) + } + return fmt.Errorf("unimplemented test func") +} + +func (m *mockKongClient) DeleteConsumer(ctx context.Context, id string) error { + if m.DeleteConsumerMock != nil { + return m.DeleteConsumerMock(ctx, id) + } + return fmt.Errorf("unimplemented test func") +} + +func (m *mockKongClient) DeleteOauth2(ctx context.Context, consumerID, clientID string) error { + if m.DeleteOauth2Mock != nil { + return m.DeleteOauth2Mock(ctx, consumerID, clientID) + } + return fmt.Errorf("unimplemented test func") +} + +func (m *mockKongClient) DeleteHttpBasic(ctx context.Context, consumerID, username string) error { + if m.DeleteHttpBasicMock != nil { + return m.DeleteHttpBasicMock(ctx, consumerID, username) + } + return fmt.Errorf("unimplemented test func") +} + +func (m *mockKongClient) DeleteAuthKey(ctx context.Context, consumerID, authKey string) error { + if m.DeleteAuthKeyMock != nil { + return m.DeleteAuthKeyMock(ctx, consumerID, authKey) + } + return fmt.Errorf("unimplemented test func") +} + +func (m *mockKongClient) CreateHttpBasic(ctx context.Context, consumerID string, basicAuth *klib.BasicAuth) (*klib.BasicAuth, error) { + if m.CreateHttpBasicMock != nil { + return m.CreateHttpBasicMock(ctx, consumerID, basicAuth) + } + return nil, fmt.Errorf("unimplemented test func") +} + +func (m *mockKongClient) CreateOauth2(ctx context.Context, consumerID string, oauth2 *klib.Oauth2Credential) (*klib.Oauth2Credential, error) { + if m.CreateOauth2Mock != nil { + return m.CreateOauth2Mock(ctx, consumerID, oauth2) + } + return nil, fmt.Errorf("unimplemented test func") +} + +func (m *mockKongClient) CreateAuthKey(ctx context.Context, consumerID string, keyAuth *klib.KeyAuth) (*klib.KeyAuth, error) { + if m.CreateAuthKeyMock != nil { + return m.CreateAuthKeyMock(ctx, consumerID, keyAuth) + } + return nil, fmt.Errorf("unimplemented test func") +} + +func (m *mockKongClient) AddRouteACL(ctx context.Context, routeID, allowedID string) error { + if m.AddRouteACLMock != nil { + return m.AddRouteACLMock(ctx, routeID, allowedID) + } + return fmt.Errorf("unimplemented test func") +} + +func (m *mockKongClient) RemoveRouteACL(ctx context.Context, routeID, revokedID string) error { + if m.RemoveRouteACLMock != nil { + return m.RemoveRouteACLMock(ctx, routeID, revokedID) + } + return fmt.Errorf("unimplemented test func") +} + +func (m *mockKongClient) AddQuota(ctx context.Context, routeID, allowedID, quotaInterval string, quotaLimit int) error { + if m.AddQuotaMock != nil { + return m.AddQuotaMock(ctx, routeID, allowedID, quotaInterval, quotaLimit) + } + return fmt.Errorf("unimplemented test func") +} + +func (m *mockKongClient) ListServices(ctx context.Context) ([]*klib.Service, error) { + if m.ListServicesMock != nil { + return m.ListServicesMock(ctx) + } + return nil, fmt.Errorf("unimplemented test func") +} + +func (m *mockKongClient) ListRoutesForService(ctx context.Context, serviceId string) ([]*klib.Route, error) { + if m.ListRoutesForServiceMock != nil { + return m.ListRoutesForServiceMock(ctx, serviceId) + } + return nil, fmt.Errorf("unimplemented test func") +} + +func (m *mockKongClient) GetSpecForService(ctx context.Context, service *klib.Service) ([]byte, error) { + if m.GetSpecForServiceMock != nil { + return m.GetSpecForServiceMock(ctx, service) + } + return nil, fmt.Errorf("unimplemented test func") +} + +func (m *mockKongClient) GetKongPlugins() *kong.Plugins { + if m.GetKongPluginsMock != nil { + return m.GetKongPluginsMock() + } + return nil +} + +type mockPluginLister struct { + plugins []*klib.Plugin +} + +func (m *mockPluginLister) ListAll(ctx context.Context) ([]*klib.Plugin, error) { + if m.plugins == nil { + return nil, fmt.Errorf("not implemented by test") + } + return m.plugins, nil +} diff --git a/pkg/discovery/gateway/route.go b/pkg/discovery/agent/route.go similarity index 87% rename from pkg/discovery/gateway/route.go rename to pkg/discovery/agent/route.go index a7c9d56..270a3e0 100644 --- a/pkg/discovery/gateway/route.go +++ b/pkg/discovery/agent/route.go @@ -1,4 +1,4 @@ -package gateway +package agent import ( "fmt" @@ -7,6 +7,11 @@ import ( klib "github.com/kong/go-kong/kong" ) +const ( + httpScheme = "http" + httpsScheme = "https" +) + type KongRoute struct { *klib.Route defaultHost string @@ -47,19 +52,19 @@ func (r *KongRoute) handlePaths(host, basePath string) []apic.EndpointDefinition func (r *KongRoute) handleProtocols(host, path string) []apic.EndpointDefinition { endpoints := make([]apic.EndpointDefinition, 0) for _, protocol := range r.Protocols { - if *protocol == "http" && r.httpPort != 0 { + if *protocol == httpScheme && r.httpPort != 0 { endpoints = append(endpoints, apic.EndpointDefinition{ Host: host, Port: int32(r.httpPort), - Protocol: "http", + Protocol: httpScheme, BasePath: path, }) } - if *protocol == "https" && r.httpsPort != 0 { + if *protocol == httpsScheme && r.httpsPort != 0 { endpoints = append(endpoints, apic.EndpointDefinition{ Host: host, Port: int32(r.httpsPort), - Protocol: "https", + Protocol: httpsScheme, BasePath: path, }) } diff --git a/pkg/discovery/gateway/route_test.go b/pkg/discovery/agent/route_test.go similarity index 99% rename from pkg/discovery/gateway/route_test.go rename to pkg/discovery/agent/route_test.go index de1de09..d678a7d 100644 --- a/pkg/discovery/gateway/route_test.go +++ b/pkg/discovery/agent/route_test.go @@ -1,4 +1,4 @@ -package gateway +package agent import ( "testing" diff --git a/pkg/discovery/cmd/cmd.go b/pkg/discovery/cmd/cmd.go index 5f39177..ea13655 100644 --- a/pkg/discovery/cmd/cmd.go +++ b/pkg/discovery/cmd/cmd.go @@ -7,8 +7,8 @@ import ( corecfg "github.com/Axway/agent-sdk/pkg/config" "github.com/Axway/agent-sdk/pkg/util/log" + "github.com/Axway/agents-kong/pkg/discovery/agent" "github.com/Axway/agents-kong/pkg/discovery/config" - "github.com/Axway/agents-kong/pkg/discovery/gateway" ) var DiscoveryCmd corecmd.AgentRootCmd @@ -35,14 +35,14 @@ func run() error { var err error stopChan := make(chan struct{}) - gatewayClient, err := gateway.NewClient(agentConfig) + kongAgent, err := agent.NewAgent(agentConfig) if err != nil { return err } go func() { for { - err = gatewayClient.DiscoverAPIs() + err = kongAgent.DiscoverAPIs() if err != nil { log.Errorf("error in processing: %v", err) stopChan <- struct{}{} diff --git a/pkg/discovery/gateway/client_test.go b/pkg/discovery/gateway/client_test.go deleted file mode 100644 index f8fcc00..0000000 --- a/pkg/discovery/gateway/client_test.go +++ /dev/null @@ -1,16 +0,0 @@ -package gateway - -import ( - "testing" - - corecfg "github.com/Axway/agent-sdk/pkg/config" - config "github.com/Axway/agents-kong/pkg/discovery/config" -) - -func TestKongClient(t *testing.T) { - gatewayConfig := &config.KongGatewayConfig{} - _ = config.AgentConfig{ - CentralCfg: corecfg.NewCentralConfig(corecfg.DiscoveryAgent), - KongGatewayCfg: gatewayConfig, - } -} diff --git a/pkg/discovery/gateway/definitions.go b/pkg/discovery/gateway/definitions.go deleted file mode 100644 index c3ffe63..0000000 --- a/pkg/discovery/gateway/definitions.go +++ /dev/null @@ -1,44 +0,0 @@ -package gateway - -import ( - "github.com/Axway/agent-sdk/pkg/apic" - "github.com/Axway/agent-sdk/pkg/cache" - corecfg "github.com/Axway/agent-sdk/pkg/config" - "github.com/Axway/agent-sdk/pkg/filter" - "github.com/Axway/agent-sdk/pkg/util/log" - - config "github.com/Axway/agents-kong/pkg/discovery/config" - "github.com/Axway/agents-kong/pkg/discovery/kong" -) - -type Client struct { - logger log.FieldLogger - centralCfg corecfg.CentralConfig - kongGatewayCfg *config.KongGatewayConfig - kongClient kong.KongAPIClient - plugins kong.Plugins - cache cache.Cache - mode string - filter filter.Filter -} - -type KongAPI struct { - spec []byte - id string - name string - description string - version string - url string - documentation []byte - resourceType string - endpoints []apic.EndpointDefinition - image string - imageContentType string - crds []string - apiUpdateSeverity string - agentDetails map[string]string - tags []string - stage string - stageName string - ard string -} diff --git a/pkg/discovery/kong/authplugins.go b/pkg/discovery/kong/authplugins.go new file mode 100644 index 0000000..af15766 --- /dev/null +++ b/pkg/discovery/kong/authplugins.go @@ -0,0 +1,91 @@ +package kong + +import ( + "encoding/json" +) + +const ( + BasicAuthPlugin = "basic-auth" + KeyAuthPlugin = "key-auth" + OAuthPlugin = "oauth2" +) + +type OAuthPluginConfig struct { + HideCredentials bool `json:"hide_credentials,omitempty"` + PersistentRefreshToken bool `json:"persistent_refresh_token,omitempty"` + ProvisionKey string `json:"provision_key,omitempty"` + RefreshTokenTTL int64 `json:"refresh_token_ttl,omitempty"` + TokenExpiration int64 `json:"token_expiration,omitempty"` + AcceptHTTPIfAlreadyTerminated bool `json:"accept_http_if_already_terminated,omitempty"` + AuthHeaderName string `json:"auth_header_name,omitempty"` + MandatoryScope bool `json:"mandatory_scope,omitempty"` + Scopes []string `json:"scopes,omitempty"` + PKCE string `json:"pkce,omitempty"` + ReuseRefreshToken bool `json:"reuse_refresh_token,omitempty"` + EnablePasswordGrant bool `json:"enable_password_grant,omitempty"` + EnableClientCredentials bool `json:"enable_client_credentials,omitempty"` + GlobalCredentials bool `json:"global_credentials,omitempty"` + Anonymous string `json:"anonymous,omitempty"` + EnableImplicitGrant bool `json:"enable_implicit_grant,omitempty"` + EnableAuthorizationCode bool `json:"enable_authorization_code,omitempty"` +} + +func NewOAuthPluginConfigFromMap(mapData map[string]interface{}) (*OAuthPluginConfig, error) { + // Convert map to json string + jsonStr, err := json.Marshal(mapData) + if err != nil { + return nil, err + } + + config := &OAuthPluginConfig{} + if err := json.Unmarshal(jsonStr, config); err != nil { + return nil, err + } + + return config, nil +} + +type KeyAuthPluginConfig struct { + KeyInQuery bool `json:"key_in_query,omitempty"` + KeyInHeader bool `json:"key_in_header,omitempty"` + KeyNames []string `json:"key_names,omitempty"` + Anonymous string `json:"anonymous,omitempty"` + RunOnPreflight bool `json:"run_on_preflight,omitempty"` + HideCredentials bool `json:"hide_credentials,omitempty"` + KeyInBody bool `json:"key_in_body,omitempty"` +} + +func NewKeyAuthPluginConfigFromMap(mapData map[string]interface{}) (*KeyAuthPluginConfig, error) { + // Convert map to json string + jsonStr, err := json.Marshal(mapData) + if err != nil { + return nil, err + } + + config := &KeyAuthPluginConfig{} + if err := json.Unmarshal(jsonStr, config); err != nil { + return nil, err + } + + return config, nil +} + +type BasicAuthPluginConfig struct { + Anonymous string `json:"anonymous,omitempty"` + HideCredentials bool `json:"hide_credentials,omitempty"` +} + +func NewBasicAuthPluginConfigFromMap(mapData map[string]interface{}) (*BasicAuthPluginConfig, error) { + // Convert map to json string + jsonStr, err := json.Marshal(mapData) + if err != nil { + return nil, err + } + + config := &BasicAuthPluginConfig{} + if err := json.Unmarshal(jsonStr, config); err != nil { + return nil, err + } + + return config, nil +} diff --git a/pkg/discovery/kong/kongclient.go b/pkg/discovery/kong/kongclient.go index 57b3074..f8ddbdd 100644 --- a/pkg/discovery/kong/kongclient.go +++ b/pkg/discovery/kong/kongclient.go @@ -67,12 +67,14 @@ type KongClient struct { clientTimeout time.Duration } -func NewKongClient(baseClient *http.Client, kongConfig *config.KongGatewayConfig) (*KongClient, error) { +func NewKongClient(kongConfig *config.KongGatewayConfig) (*KongClient, error) { headers := make(http.Header) var kongEndpoint string kongTransport := http.DefaultTransport.(*http.Transport) kongTransport.TLSClientConfig = kongConfig.Admin.TLS.BuildTLSConfig() - baseClient.Transport = kongTransport + baseClient := &http.Client{ + Transport: kongTransport, + } kongEndpoint = kongConfig.Admin.Url if kongConfig.Admin.Auth.APIKey.Value != "" { diff --git a/pkg/discovery/kong/provisioning_test.go b/pkg/discovery/kong/provisioning_test.go index d3bd7e0..f23ee81 100644 --- a/pkg/discovery/kong/provisioning_test.go +++ b/pkg/discovery/kong/provisioning_test.go @@ -112,7 +112,7 @@ func createClient(responses map[string]response) KongAPIClient { if err := cfg.ValidateCfg(); err != nil { panic(err) } - client, _ := NewKongClient(&http.Client{}, cfg) + client, _ := NewKongClient(cfg) return client } diff --git a/pkg/discovery/subscription/provision.go b/pkg/discovery/subscription/provision.go index c340455..2d09055 100644 --- a/pkg/discovery/subscription/provision.go +++ b/pkg/discovery/subscription/provision.go @@ -3,6 +3,8 @@ package subscription import ( "context" + klib "github.com/kong/go-kong/kong" + "github.com/Axway/agent-sdk/pkg/agent" "github.com/Axway/agent-sdk/pkg/apic/provisioning" "github.com/Axway/agent-sdk/pkg/util/log" @@ -15,14 +17,38 @@ import ( type ProvisionerOption func(*provisioner) +type kongClient interface { + // Provisioning + CreateConsumer(ctx context.Context, id, name string) (*klib.Consumer, error) + AddConsumerACL(ctx context.Context, id string) error + DeleteConsumer(ctx context.Context, id string) error + // Credential + DeleteOauth2(ctx context.Context, consumerID, clientID string) error + DeleteHttpBasic(ctx context.Context, consumerID, username string) error + DeleteAuthKey(ctx context.Context, consumerID, authKey string) error + CreateHttpBasic(ctx context.Context, consumerID string, basicAuth *klib.BasicAuth) (*klib.BasicAuth, error) + CreateOauth2(ctx context.Context, consumerID string, oauth2 *klib.Oauth2Credential) (*klib.Oauth2Credential, error) + CreateAuthKey(ctx context.Context, consumerID string, keyAuth *klib.KeyAuth) (*klib.KeyAuth, error) + // Access Request + AddRouteACL(ctx context.Context, routeID, allowedID string) error + RemoveRouteACL(ctx context.Context, routeID, revokedID string) error + AddQuota(ctx context.Context, routeID, allowedID, quotaInterval string, quotaLimit int) error + // Discovery + ListServices(ctx context.Context) ([]*klib.Service, error) + ListRoutesForService(ctx context.Context, serviceId string) ([]*klib.Route, error) + GetSpecForService(ctx context.Context, service *klib.Service) ([]byte, error) + GetKongPlugins() *kong.Plugins +} + type provisioner struct { logger log.FieldLogger - client kong.KongAPIClient + client kongClient aclDisable bool } // NewProvisioner creates a type to implement the SDK Provisioning methods for handling subscriptions -func NewProvisioner(client kong.KongAPIClient, logger log.FieldLogger, opts ...ProvisionerOption) { +func NewProvisioner(client kongClient, opts ...ProvisionerOption) { + logger := log.NewFieldLogger().WithComponent("provision").WithPackage("subscription") logger.Info("Registering provisioning callbacks") provisioner := &provisioner{ client: client, diff --git a/pkg/traceability/processor/processor_test.go b/pkg/traceability/processor/processor_test.go index 2c5473f..81c38ef 100644 --- a/pkg/traceability/processor/processor_test.go +++ b/pkg/traceability/processor/processor_test.go @@ -4,11 +4,16 @@ import ( "context" "testing" + "github.com/stretchr/testify/assert" + + "github.com/Axway/agent-sdk/pkg/agent" + "github.com/Axway/agent-sdk/pkg/apic/mock" + "github.com/Axway/agent-sdk/pkg/config" "github.com/Axway/agent-sdk/pkg/traceability/redaction" "github.com/Axway/agent-sdk/pkg/traceability/sampling" "github.com/Axway/agent-sdk/pkg/transaction/metric" - "github.com/Axway/agents-kong/pkg/traceability/processor/mock" - "github.com/stretchr/testify/assert" + + collectorMock "github.com/Axway/agents-kong/pkg/traceability/processor/mock" ) var testData = []byte(`[{ @@ -95,6 +100,8 @@ func TestNewHandler(t *testing.T) { } for name, tc := range cases { t.Run(name, func(t *testing.T) { + + agent.InitializeForTest(&mock.Client{}, agent.TestWithAgentType(config.TraceabilityAgent)) ctx := context.WithValue(context.Background(), "test", name) redaction.SetupGlobalRedaction(redaction.DefaultConfig()) @@ -111,14 +118,14 @@ func TestNewHandler(t *testing.T) { assert.NotNil(t, h) // setup collector - collector := &mock.CollectorMock{Details: make([]metric.Detail, 0), Expected: tc.expectedMetricDetails} - mock.SetMockCollector(collector) + collector := &collectorMock.CollectorMock{Details: make([]metric.Detail, 0), Expected: tc.expectedMetricDetails} + collectorMock.SetMockCollector(collector) h.collectorGetter = func() metricCollector { - return mock.GetMockCollector() + return collectorMock.GetMockCollector() } // setup event generator - h.eventGenerator = mock.NewEventGeneratorMock + h.eventGenerator = collectorMock.NewEventGeneratorMock // if metric details are expected if tc.expectedMetricDetails >= 1 { @@ -130,7 +137,7 @@ func TestNewHandler(t *testing.T) { collector.Wait() assert.Nil(t, err) assert.Len(t, events, tc.expectedEvents) - assert.Equal(t, tc.expectedMetricDetails, len(mock.GetMockCollector().Details)) + assert.Equal(t, tc.expectedMetricDetails, len(collectorMock.GetMockCollector().Details)) }) } }