From e832aa7ab90a4a74a29b6d160e640ae35712570c Mon Sep 17 00:00:00 2001 From: mchavez Date: Tue, 30 Jul 2024 17:14:25 -0600 Subject: [PATCH 1/2] Adding in-memory cache parameters --- pkg/client/client.go | 12 +++++++++++- pkg/client/internal_integration_test.go | 13 +++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/pkg/client/client.go b/pkg/client/client.go index a0c4bfa..7c8969b 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -107,7 +107,17 @@ func New(ctx context.Context, clientId, clientSecret, orgId, vaultId string) (*V return nil, err } - cli := uhttp.NewBaseHttpClient(httpClient) + // Setting up in-cache memory parameters + ctx = context.WithValue(ctx, uhttp.ContextKey{}, uhttp.CacheConfig{ + LogDebug: true, + CacheTTL: int32(1000), + CacheMaxSize: int(1024), + }) + cli, err := uhttp.NewBaseHttpClient(ctx, httpClient) + if err != nil { + return nil, err + } + req, err := cli.NewRequest(ctx, http.MethodPost, uri, diff --git a/pkg/client/internal_integration_test.go b/pkg/client/internal_integration_test.go index aed57ac..0442b64 100644 --- a/pkg/client/internal_integration_test.go +++ b/pkg/client/internal_integration_test.go @@ -104,6 +104,19 @@ func TestVaultMembers(t *testing.T) { var data any err = json.Unmarshal(res, &data) assert.Nil(t, err) + + // -- force cache response -- + resp1, err := cli.httpClient.Do(req) + assert.Nil(t, err) + + defer resp1.Body.Close() + res1, err := io.ReadAll(resp1.Body) + assert.Nil(t, err) + assert.NotNil(t, res1) + + var data1 any + err = json.Unmarshal(res1, &data1) + assert.Nil(t, err) } func getClientForTesting(ctx context.Context, clientId, clientSecret, orgId, vaultId string) (*VGSClient, error) { From b1729e17dc2847bb05de69c70135acb2ad5af7ca Mon Sep 17 00:00:00 2001 From: mchavez Date: Wed, 31 Jul 2024 14:47:40 -0600 Subject: [PATCH 2/2] Refactoring code --- cmd/baton-vgs/main.go | 38 ++++++------ pkg/client/client.go | 80 ++++++++++++++++++++++--- pkg/client/internal_integration_test.go | 22 +++++-- pkg/connector/connector.go | 25 +++++--- pkg/connector/internal_test.go | 64 +++++++++++++++----- 5 files changed, 174 insertions(+), 55 deletions(-) diff --git a/cmd/baton-vgs/main.go b/cmd/baton-vgs/main.go index 65f0067..5cca445 100644 --- a/cmd/baton-vgs/main.go +++ b/cmd/baton-vgs/main.go @@ -9,6 +9,7 @@ import ( "github.com/conductorone/baton-sdk/pkg/connectorbuilder" "github.com/conductorone/baton-sdk/pkg/field" "github.com/conductorone/baton-sdk/pkg/types" + "github.com/conductorone/baton-vgs/pkg/client" "github.com/conductorone/baton-vgs/pkg/connector" "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" "github.com/spf13/viper" @@ -16,25 +17,31 @@ import ( ) const ( - version = "dev" - connectorName = "baton-vgs" - serviceAccountClientId = "service-account-client-id" - serviceAccountClientSecret = "service-account-client-secret" - organizationId = "organization-id" - vault = "vault" + version = "dev" + connectorName = "baton-vgs" + batonCacheDisable = "cache-disable" + batonCacheTTL = "cache-ttl" + batonCacheMaxSize = "cache-max-size" ) var ( - ServiceAccountClientId = field.StringField(serviceAccountClientId, field.WithRequired(true), field.WithDescription("The VGS client id.")) - ServiceAccountClientSecret = field.StringField(serviceAccountClientSecret, field.WithRequired(true), field.WithDescription("The VGS client secret.")) - OrganizationId = field.StringField(organizationId, field.WithRequired(true), field.WithDescription("The VGS organization id.")) - Vault = field.StringField(vault, field.WithRequired(true), field.WithDescription("The VGS vault id.")) - configurationFields = []field.SchemaField{Vault, ServiceAccountClientId, ServiceAccountClientSecret, OrganizationId} + ServiceAccountClientId = field.StringField(client.ServiceAccountClientIdName, field.WithRequired(true), field.WithDescription("The VGS client id.")) + ServiceAccountClientSecret = field.StringField(client.ServiceAccountClientSecretName, field.WithRequired(true), field.WithDescription("The VGS client secret.")) + OrganizationId = field.StringField(client.OrganizationId, field.WithRequired(true), field.WithDescription("The VGS organization id.")) + Vault = field.StringField(client.VaultId, field.WithRequired(true), field.WithDescription("The VGS vault id.")) + CacheDisabled = field.StringField(batonCacheDisable, field.WithRequired(false), field.WithDescription("Verbose mode shows information about new memory allocation.")) + CacheTTL = field.StringField(batonCacheTTL, field.WithRequired(false), field.WithDescription("Time after which entry can be evicted.")) + CacheMaxSize = field.StringField(batonCacheMaxSize, field.WithRequired(false), field.WithDescription("It is a limit for BytesQueue size in MB.")) + configurationFields = []field.SchemaField{Vault, ServiceAccountClientId, ServiceAccountClientSecret, OrganizationId, CacheDisabled, CacheTTL, CacheMaxSize} ) func main() { ctx := context.Background() - _, cmd, err := configSchema.DefineConfiguration(ctx, connectorName, getConnector, field.NewConfiguration(configurationFields)) + _, cmd, err := configSchema.DefineConfiguration(ctx, + connectorName, + getConnector, + field.NewConfiguration(configurationFields), + ) if err != nil { fmt.Fprintln(os.Stderr, err.Error()) os.Exit(1) @@ -50,12 +57,7 @@ func main() { func getConnector(ctx context.Context, cfg *viper.Viper) (types.ConnectorServer, error) { l := ctxzap.Extract(ctx) - cb, err := connector.New(ctx, - cfg.GetString(serviceAccountClientId), - cfg.GetString(serviceAccountClientSecret), - cfg.GetString(organizationId), - cfg.GetString(vault), - ) + cb, err := connector.New(ctx, cfg) if err != nil { l.Error("error creating connector", zap.Error(err)) return nil, err diff --git a/pkg/client/client.go b/pkg/client/client.go index 7c8969b..7d76c66 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -16,12 +16,68 @@ import ( "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" ) -type VGSClient struct { - httpClient *uhttp.BaseHttpClient - token *JWT - serviceEndpoint string - organizationId string - vaultId string +type ( + VGSClient struct { + httpClient *uhttp.BaseHttpClient + token *JWT + serviceEndpoint string + organizationId string + vaultId string + } + + Config struct { + serviceAccountClientId string + serviceAccountClientSecret string + organizationId string + vaultId string + } +) + +const ( + ServiceAccountClientIdName = "service-account-client-id" + ServiceAccountClientSecretName = "service-account-client-secret" + OrganizationId = "organization-id" + VaultId = "vault" + serviceAccountClient = "serviceAccountClientId" + serviceAccountClientSecret = "serviceAccountClientSecret" + organization = "organizationId" + vault = "vaultId" + empty = "" +) + +func (c *Config) WithServiceAccountClientId(sAccId string) *Config { + c.serviceAccountClientId = sAccId + return c +} + +func (c *Config) WithServiceAccountClientSecret(sAccSec string) *Config { + c.serviceAccountClientSecret = sAccSec + return c +} + +func (c *Config) WithOrganizationId(orgId string) *Config { + c.organizationId = orgId + return c +} + +func (c *Config) WithVaultId(vId string) *Config { + c.vaultId = vId + return c +} + +func (c *Config) getFieldValue(fieldName string) string { + switch fieldName { + case serviceAccountClient: + return c.serviceAccountClientId + case serviceAccountClientSecret: + return c.serviceAccountClientSecret + case organization: + return c.organizationId + case vault: + return c.vaultId + } + + return empty } func WithBody(body string) uhttp.RequestOption { @@ -95,8 +151,14 @@ func WithSetBasicAuthHeader(username, password string) uhttp.RequestOption { return uhttp.WithHeader("Authorization", "Basic "+basicAuth(username, password)) } -func New(ctx context.Context, clientId, clientSecret, orgId, vaultId string) (*VGSClient, error) { - var jwt = &JWT{} +func New(ctx context.Context, cfg Config) (*VGSClient, error) { + var ( + jwt = &JWT{} + clientId = cfg.getFieldValue(serviceAccountClient) + clientSecret = cfg.getFieldValue(serviceAccountClientSecret) + orgId = cfg.getFieldValue(organization) + vaultId = cfg.getFieldValue(vault) + ) uri, err := url.Parse("https://auth.verygoodsecurity.com/auth/realms/vgs/protocol/openid-connect/token") if err != nil { return nil, err @@ -107,7 +169,7 @@ func New(ctx context.Context, clientId, clientSecret, orgId, vaultId string) (*V return nil, err } - // Setting up in-cache memory parameters + // Setting up in-cache memory parameters, otherwise it takes default values ctx = context.WithValue(ctx, uhttp.ContextKey{}, uhttp.CacheConfig{ LogDebug: true, CacheTTL: int32(1000), diff --git a/pkg/client/internal_integration_test.go b/pkg/client/internal_integration_test.go index 0442b64..dfac940 100644 --- a/pkg/client/internal_integration_test.go +++ b/pkg/client/internal_integration_test.go @@ -18,6 +18,12 @@ var ( clientSecret, _ = os.LookupEnv("BATON_SERVICE_ACCOUNT_CLIENT_SECRET") vaultId, _ = os.LookupEnv("BATON_VAULT") orgId, _ = os.LookupEnv("BATON_ORGANIZATION_ID") + cfg = Config{ + serviceAccountClientId: clientId, + serviceAccountClientSecret: clientSecret, + organizationId: orgId, + vaultId: vaultId, + } ) const ( @@ -48,7 +54,13 @@ func TestOrganizationResources(t *testing.T) { }, } - cli, err := getClientForTesting(ctx, clientId, clientSecret, orgId, vaultId) + cfg := Config{ + serviceAccountClientId: clientId, + serviceAccountClientSecret: clientSecret, + organizationId: orgId, + vaultId: vaultId, + } + cli, err := getClientForTesting(ctx, cfg) assert.Nil(t, err) for _, test := range tests { t.Run(test.name, func(t *testing.T) { @@ -81,7 +93,7 @@ func TestVaultMembers(t *testing.T) { t.Skip() } - cli, err := getClientForTesting(ctx, clientId, clientSecret, orgId, vaultId) + cli, err := getClientForTesting(ctx, cfg) assert.Nil(t, err) endpointUrl, err := url.JoinPath(baseUrl, "vaults", vaultId, "members") @@ -119,8 +131,8 @@ func TestVaultMembers(t *testing.T) { assert.Nil(t, err) } -func getClientForTesting(ctx context.Context, clientId, clientSecret, orgId, vaultId string) (*VGSClient, error) { - cli, err := New(ctx, clientId, clientSecret, orgId, vaultId) +func getClientForTesting(ctx context.Context, cfg Config) (*VGSClient, error) { + cli, err := New(ctx, cfg) return cli, err } @@ -129,7 +141,7 @@ func TestVaults(t *testing.T) { t.Skip() } - cli, err := getClientForTesting(ctx, clientId, clientSecret, orgId, vaultId) + cli, err := getClientForTesting(ctx, cfg) assert.Nil(t, err) endpointUrl, err := url.JoinPath(baseUrl, "vaults") diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go index 504ef6e..ddc6fc4 100644 --- a/pkg/connector/connector.go +++ b/pkg/connector/connector.go @@ -8,11 +8,14 @@ import ( "github.com/conductorone/baton-sdk/pkg/annotations" "github.com/conductorone/baton-sdk/pkg/connectorbuilder" "github.com/conductorone/baton-vgs/pkg/client" + "github.com/spf13/viper" ) -type Connector struct { - client *client.VGSClient -} +type ( + Connector struct { + client *client.VGSClient + } +) // ResourceSyncers returns a ResourceSyncer for each resource type that should be synced from the upstream service. func (d *Connector) ResourceSyncers(ctx context.Context) []connectorbuilder.ResourceSyncer { @@ -44,13 +47,21 @@ func (d *Connector) Validate(ctx context.Context) (annotations.Annotations, erro } // New returns a new instance of the connector. -func New(ctx context.Context, clientId, clientSecret, organizationId, vaultId string) (*Connector, error) { +func New(ctx context.Context, cfg *viper.Viper) (*Connector, error) { var ( - vc *client.VGSClient - err error + vc *client.VGSClient + config = client.Config{} + clientId = cfg.GetString(client.ServiceAccountClientIdName) + clientSecret = cfg.GetString(client.ServiceAccountClientSecretName) + organizationId = cfg.GetString(client.OrganizationId) + vaultId = cfg.GetString(client.VaultId) + err error ) + + config.WithServiceAccountClientId(clientId).WithServiceAccountClientSecret(clientSecret) + config.WithOrganizationId(organizationId).WithVaultId(vaultId) if clientId != "" && clientSecret != "" { - vc, err = client.New(ctx, clientId, clientSecret, organizationId, vaultId) + vc, err = client.New(ctx, config) if err != nil { return nil, err } diff --git a/pkg/connector/internal_test.go b/pkg/connector/internal_test.go index d2e01ac..696e2a0 100644 --- a/pkg/connector/internal_test.go +++ b/pkg/connector/internal_test.go @@ -24,9 +24,12 @@ func TestUserResourceTypeList(t *testing.T) { t.Skip() } + cli, err := getClientForTesting(ctx) + assert.Nil(t, err) + user := &userResourceType{ resourceType: &v2.ResourceType{}, - client: getClientForTesting(ctx, clientId, clientSecret, orgId, vaultId), + client: cli, } rs, _, _, err := user.List(ctx, &v2.ResourceId{}, &pagination.Token{}) assert.Nil(t, err) @@ -38,9 +41,12 @@ func TestOrgResourceTypeList(t *testing.T) { t.Skip() } + cli, err := getClientForTesting(ctx) + assert.Nil(t, err) + org := &orgResourceType{ resourceType: &v2.ResourceType{}, - client: getClientForTesting(ctx, clientId, clientSecret, orgId, vaultId), + client: cli, } rs, _, _, err := org.List(ctx, &v2.ResourceId{}, &pagination.Token{}) assert.Nil(t, err) @@ -52,18 +58,30 @@ func TestVaultResourceTypeList(t *testing.T) { t.Skip() } + cli, err := getClientForTesting(ctx) + assert.Nil(t, err) + vault := &vaultResourceType{ resourceType: &v2.ResourceType{}, - client: getClientForTesting(ctx, clientId, clientSecret, orgId, vaultId), + client: cli, } rs, _, _, err := vault.List(ctx, &v2.ResourceId{}, &pagination.Token{}) assert.Nil(t, err) assert.NotNil(t, rs) } -func getClientForTesting(ctx context.Context, clientId, clientSecret, orgId, vaultId string) *client.VGSClient { - cli, _ := client.New(ctx, clientId, clientSecret, orgId, vaultId) - return cli +func getClientForTesting(ctx context.Context) (*client.VGSClient, error) { + cfg := client.Config{} + cfg.WithVaultId(vaultId). + WithOrganizationId(orgId). + WithServiceAccountClientId(clientId). + WithServiceAccountClientSecret(clientSecret) + cli, err := client.New(ctx, cfg) + if err != nil { + return nil, err + } + + return cli, nil } func TestClient(t *testing.T) { @@ -71,7 +89,7 @@ func TestClient(t *testing.T) { t.Skip() } - cli, err := client.New(ctx, clientId, clientSecret, orgId, vaultId) + cli, err := getClientForTesting(ctx) assert.Nil(t, err) assert.NotNil(t, cli) } @@ -81,7 +99,9 @@ func TestListVaults(t *testing.T) { t.Skip() } - cliTest := getClientForTesting(ctx, clientId, clientSecret, orgId, vaultId) + cliTest, err := getClientForTesting(ctx) + assert.Nil(t, err) + lv, err := cliTest.ListVaults(ctx) assert.Nil(t, err) assert.NotNil(t, lv) @@ -92,7 +112,9 @@ func TestListVaultUsers(t *testing.T) { t.Skip() } - cliTest := getClientForTesting(ctx, clientId, clientSecret, orgId, vaultId) + cliTest, err := getClientForTesting(ctx) + assert.Nil(t, err) + lvu, err := cliTest.ListVaultUsers(ctx, vaultId) assert.Nil(t, err) assert.NotNil(t, lvu) @@ -103,7 +125,9 @@ func TestListUsers(t *testing.T) { t.Skip() } - cliTest := getClientForTesting(ctx, clientId, clientSecret, orgId, vaultId) + cliTest, err := getClientForTesting(ctx) + assert.Nil(t, err) + lu, err := cliTest.ListUsers(ctx, orgId, vaultId) assert.Nil(t, err) assert.NotNil(t, lu) @@ -114,7 +138,9 @@ func TestListUserInvites(t *testing.T) { t.Skip() } - cliTest := getClientForTesting(ctx, clientId, clientSecret, orgId, vaultId) + cliTest, err := getClientForTesting(ctx) + assert.Nil(t, err) + lui, err := cliTest.ListUserInvites(ctx, orgId) assert.Nil(t, err) assert.NotNil(t, lui) @@ -125,7 +151,9 @@ func TestListOrganizations(t *testing.T) { t.Skip() } - cliTest := getClientForTesting(ctx, clientId, clientSecret, orgId, vaultId) + cliTest, err := getClientForTesting(ctx) + assert.Nil(t, err) + lo, err := cliTest.ListOrganizations(ctx) assert.Nil(t, err) assert.NotNil(t, lo) @@ -136,8 +164,10 @@ func TestUpdateVault(t *testing.T) { t.Skip() } - cliTest := getClientForTesting(ctx, clientId, clientSecret, orgId, vaultId) - err := cliTest.UpdateUserAccessVault(ctx, vaultId, "ID9hRKLhcc6RWBvaHQ7L1Uan", "write") + cliTest, err := getClientForTesting(ctx) + assert.Nil(t, err) + + err = cliTest.UpdateUserAccessVault(ctx, vaultId, "ID9hRKLhcc6RWBvaHQ7L1Uan", "write") assert.Nil(t, err) } @@ -146,7 +176,9 @@ func TestRevokeVault(t *testing.T) { t.Skip() } - cliTest := getClientForTesting(ctx, clientId, clientSecret, orgId, vaultId) - err := cliTest.RevokeUserAccessVault(ctx, vaultId, "IDjSP9BVbJ3RnPr2FonGxXp5") + cliTest, err := getClientForTesting(ctx) + assert.Nil(t, err) + + err = cliTest.RevokeUserAccessVault(ctx, vaultId, "IDjSP9BVbJ3RnPr2FonGxXp5") assert.Nil(t, err) }