diff --git a/cmd/single/start.go b/cmd/single/start.go index a786c0b7e4..a86fc3772b 100644 --- a/cmd/single/start.go +++ b/cmd/single/start.go @@ -122,7 +122,7 @@ func startPropeller(ctx context.Context, cfg Propeller) error { DefaultNamespaces: namespaceConfigs, }, NewCache: executors.NewCache, - NewClient: executors.NewClient, + NewClient: executors.BuildNewClientFunc(propellerScope), Metrics: metricsserver.Options{ // Disable metrics serving BindAddress: "0", diff --git a/flytepropeller/cmd/controller/cmd/root.go b/flytepropeller/cmd/controller/cmd/root.go index 8696f3993a..e1069650ad 100644 --- a/flytepropeller/cmd/controller/cmd/root.go +++ b/flytepropeller/cmd/controller/cmd/root.go @@ -144,7 +144,7 @@ func executeRootCmd(baseCtx context.Context, cfg *config2.Config) error { DefaultNamespaces: namespaceConfigs, }, NewCache: executors.NewCache, - NewClient: executors.NewClient, + NewClient: executors.BuildNewClientFunc(propellerScope), Metrics: metricsserver.Options{ // Disable metrics serving BindAddress: "0", diff --git a/flytepropeller/cmd/controller/cmd/webhook.go b/flytepropeller/cmd/controller/cmd/webhook.go index e3c29ae3d9..ae538385fb 100644 --- a/flytepropeller/cmd/controller/cmd/webhook.go +++ b/flytepropeller/cmd/controller/cmd/webhook.go @@ -109,7 +109,7 @@ func runWebhook(origContext context.Context, propellerCfg *config.Config, cfg *w DefaultNamespaces: namespaceConfigs, }, NewCache: executors.NewCache, - NewClient: executors.NewClient, + NewClient: executors.BuildNewClientFunc(webhookScope), Metrics: metricsserver.Options{ // Disable metrics serving BindAddress: "0", diff --git a/flytepropeller/pkg/controller/executors/kube.go b/flytepropeller/pkg/controller/executors/kube.go index bdab0d91be..d6d89e1711 100644 --- a/flytepropeller/pkg/controller/executors/kube.go +++ b/flytepropeller/pkg/controller/executors/kube.go @@ -33,107 +33,95 @@ var NewCache = func(config *rest.Config, options cache.Options) (cache.Cache, er return otelutils.WrapK8sCache(k8sCache), nil } -var NewClient = func(config *rest.Config, options client.Options) (client.Client, error) { - var reader *fallbackClientReader - if options.Cache != nil && options.Cache.Reader != nil { - // if caching is enabled we create a fallback reader so we can attempt the client if the cache - // reader does not have the object - reader = &fallbackClientReader{ - orderedClients: []client.Reader{options.Cache.Reader}, +func BuildNewClientFunc(scope promutils.Scope) func(config *rest.Config, options client.Options) (client.Client, error) { + return func(config *rest.Config, options client.Options) (client.Client, error) { + var cacheReader client.Reader + cachelessOptions := options + if options.Cache != nil && options.Cache.Reader != nil { + cacheReader = options.Cache.Reader + cachelessOptions.Cache = nil } - options.Cache.Reader = reader - } - - // create the k8s client - k8sClient, err := client.New(config, options) - if err != nil { - return k8sClient, err - } + kubeClient, err := client.New(config, cachelessOptions) + if err != nil { + return nil, err + } - k8sOtelClient := otelutils.WrapK8sClient(k8sClient) - if reader != nil { - // once the k8s client is created we set the fallback reader's client to the k8s client - reader.orderedClients = append(reader.orderedClients, k8sOtelClient) + return newFlyteK8sClient(kubeClient, cacheReader, scope) } - - return k8sOtelClient, nil } -// fallbackClientReader reads from the cache first and if not found then reads from the configured reader, which -// directly reads from the API -type fallbackClientReader struct { - orderedClients []client.Reader +type flyteK8sClient struct { + client.Client + cacheReader client.Reader + writeFilter fastcheck.Filter } -func (c fallbackClientReader) Get(ctx context.Context, key client.ObjectKey, out client.Object, opts ...client.GetOption) (err error) { - for _, k8sClient := range c.orderedClients { - if err = k8sClient.Get(ctx, key, out, opts...); err == nil { +func (f flyteK8sClient) Get(ctx context.Context, key client.ObjectKey, out client.Object, opts ...client.GetOption) (err error) { + if f.cacheReader != nil { + if err = f.cacheReader.Get(ctx, key, out, opts...); err == nil { return nil } } - return + return f.Client.Get(ctx, key, out, opts...) } -func (c fallbackClientReader) List(ctx context.Context, list client.ObjectList, opts ...client.ListOption) (err error) { - for _, k8sClient := range c.orderedClients { - if err = k8sClient.List(ctx, list, opts...); err == nil { +func (f flyteK8sClient) List(ctx context.Context, list client.ObjectList, opts ...client.ListOption) (err error) { + if f.cacheReader != nil { + if err = f.cacheReader.List(ctx, list, opts...); err == nil { return nil } } - return -} - -type writeThroughCachingWriter struct { - client.Client - filter fastcheck.Filter -} - -func IDFromObject(obj client.Object, op string) []byte { - return []byte(fmt.Sprintf("%s:%s:%s:%s", obj.GetObjectKind().GroupVersionKind().String(), obj.GetNamespace(), obj.GetName(), op)) + return f.Client.List(ctx, list, opts...) } // Create first checks the local cache if the object with id was previously successfully saved, if not then // saves the object obj in the Kubernetes cluster -func (w writeThroughCachingWriter) Create(ctx context.Context, obj client.Object, opts ...client.CreateOption) error { +func (f flyteK8sClient) Create(ctx context.Context, obj client.Object, opts ...client.CreateOption) error { // "c" represents create - id := IDFromObject(obj, "c") - if w.filter.Contains(ctx, id) { + id := idFromObject(obj, "c") + if f.writeFilter.Contains(ctx, id) { return nil } - err := w.Client.Create(ctx, obj, opts...) + err := f.Client.Create(ctx, obj, opts...) if err != nil { return err } - w.filter.Add(ctx, id) + f.writeFilter.Add(ctx, id) return nil } // Delete first checks the local cache if the object with id was previously successfully deleted, if not then // deletes the given obj from Kubernetes cluster. -func (w writeThroughCachingWriter) Delete(ctx context.Context, obj client.Object, opts ...client.DeleteOption) error { +func (f flyteK8sClient) Delete(ctx context.Context, obj client.Object, opts ...client.DeleteOption) error { // "d" represents delete - id := IDFromObject(obj, "d") - if w.filter.Contains(ctx, id) { + id := idFromObject(obj, "d") + if f.writeFilter.Contains(ctx, id) { return nil } - err := w.Client.Delete(ctx, obj, opts...) + err := f.Client.Delete(ctx, obj, opts...) if err != nil { return err } - w.filter.Add(ctx, id) + f.writeFilter.Add(ctx, id) return nil } -func newWriteThroughCachingWriter(c client.Client, cacheSize int, scope promutils.Scope) (writeThroughCachingWriter, error) { - filter, err := fastcheck.NewOppoBloomFilter(cacheSize, scope.NewSubScope("kube_filter")) +func idFromObject(obj client.Object, op string) []byte { + return []byte(fmt.Sprintf("%s:%s:%s:%s", obj.GetObjectKind().GroupVersionKind().String(), obj.GetNamespace(), obj.GetName(), op)) +} + +func newFlyteK8sClient(kubeClient client.Client, cacheReader client.Reader, scope promutils.Scope) (flyteK8sClient, error) { + writeFilter, err := fastcheck.NewOppoBloomFilter(50000, scope.NewSubScope("kube_filter")) if err != nil { - return writeThroughCachingWriter{}, err + return flyteK8sClient{}, err } - return writeThroughCachingWriter{ - Client: c, - filter: filter, + + return flyteK8sClient{ + Client: kubeClient, + cacheReader: cacheReader, + writeFilter: writeFilter, }, nil } diff --git a/flytepropeller/pkg/controller/executors/kube_test.go b/flytepropeller/pkg/controller/executors/kube_test.go index bcaa64ff6f..4d84d3fb08 100644 --- a/flytepropeller/pkg/controller/executors/kube_test.go +++ b/flytepropeller/pkg/controller/executors/kube_test.go @@ -2,13 +2,14 @@ package executors import ( "context" - "fmt" "reflect" "testing" "github.com/stretchr/testify/assert" v1 "k8s.io/api/core/v1" + k8serrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/client" "github.com/flyteorg/flyte/flytestdlib/contextutils" @@ -45,42 +46,46 @@ func TestIdFromObject(t *testing.T) { APIVersion: "v1", }, } - if got := IDFromObject(p, tt.args.op); !reflect.DeepEqual(got, []byte(tt.want)) { - t.Errorf("IDFromObject() = %s, want %s", string(got), tt.want) + if got := idFromObject(p, tt.args.op); !reflect.DeepEqual(got, []byte(tt.want)) { + t.Errorf("idFromObject() = %s, want %s", string(got), tt.want) } }) } } -type singleInvokeClient struct { +type mockKubeClient struct { client.Client - createCalled bool - deleteCalled bool + createCalledCount int + deleteCalledCount int + getCalledCount int + getMissCount int } -func (f *singleInvokeClient) Create(ctx context.Context, obj client.Object, opts ...client.CreateOption) error { - if f.createCalled { - return fmt.Errorf("create called more than once") - } - f.createCalled = true +func (m *mockKubeClient) Create(ctx context.Context, obj client.Object, opts ...client.CreateOption) error { + m.createCalledCount++ + return nil +} + +func (m *mockKubeClient) Delete(ctx context.Context, obj client.Object, opts ...client.DeleteOption) error { + m.deleteCalledCount++ return nil } -func (f *singleInvokeClient) Delete(ctx context.Context, obj client.Object, opts ...client.DeleteOption) error { - if f.deleteCalled { - return fmt.Errorf("delete called more than once") +func (m *mockKubeClient) Get(ctx context.Context, objectKey types.NamespacedName, obj client.Object, opts ...client.GetOption) error { + if m.getCalledCount < m.getMissCount { + m.getMissCount-- + return k8serrors.NewNotFound(v1.Resource("pod"), "name") } - f.deleteCalled = true + + m.getCalledCount++ return nil } -func TestWriteThroughCachingWriter_Create(t *testing.T) { +func TestFlyteK8sClient(t *testing.T) { ctx := context.TODO() - c := &singleInvokeClient{} - w, err := newWriteThroughCachingWriter(c, 1000, promutils.NewTestScope()) - assert.NoError(t, err) + scope := promutils.NewTestScope() - p := &v1.Pod{ + pod := &v1.Pod{ ObjectMeta: metav1.ObjectMeta{ Namespace: "ns", Name: "name", @@ -91,39 +96,73 @@ func TestWriteThroughCachingWriter_Create(t *testing.T) { }, } - err = w.Create(ctx, p) - assert.NoError(t, err) + objectKey := types.NamespacedName{ + Namespace: pod.Namespace, + Name: pod.Name, + } - assert.True(t, c.createCalled) + // test cache reader + tests := []struct { + name string + initCacheReader bool + cacheMissCount int + expectedClientGetCount int + }{ + {"no-cache", false, 0, 2}, + {"with-cache-one-miss", true, 1, 1}, + {"with-cache-no-misses", true, 0, 0}, + } - err = w.Create(ctx, p) - assert.NoError(t, err) -} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var cacheReader client.Reader + if tt.initCacheReader { + cacheReader = &mockKubeClient{ + getMissCount: tt.cacheMissCount, + } + } -func TestWriteThroughCachingWriter_Delete(t *testing.T) { - ctx := context.TODO() - c := &singleInvokeClient{} - w, err := newWriteThroughCachingWriter(c, 1000, promutils.NewTestScope()) - assert.NoError(t, err) + kubeClient := &mockKubeClient{} - p := &v1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - Namespace: "ns", - Name: "name", - }, - TypeMeta: metav1.TypeMeta{ - Kind: "pod", - APIVersion: "v1", - }, - } + flyteK8sClient, err := newFlyteK8sClient(kubeClient, cacheReader, scope.NewSubScope(tt.name)) + assert.NoError(t, err) - err = w.Delete(ctx, p) - assert.NoError(t, err) + for i := 0; i < 2; i++ { + err := flyteK8sClient.Get(ctx, objectKey, pod) + assert.NoError(t, err) + } - assert.True(t, c.deleteCalled) + assert.Equal(t, tt.expectedClientGetCount, kubeClient.getCalledCount) + }) + } - err = w.Delete(ctx, p) - assert.NoError(t, err) + // test create + t.Run("create", func(t *testing.T) { + kubeClient := &mockKubeClient{} + flyteK8sClient, err := newFlyteK8sClient(kubeClient, nil, scope.NewSubScope("create")) + assert.NoError(t, err) + + for i := 0; i < 5; i++ { + err = flyteK8sClient.Create(ctx, pod) + assert.NoError(t, err) + } + + assert.Equal(t, 1, kubeClient.createCalledCount) + }) + + // test delete + t.Run("delete", func(t *testing.T) { + kubeClient := &mockKubeClient{} + flyteK8sClient, err := newFlyteK8sClient(kubeClient, nil, scope.NewSubScope("delete")) + assert.NoError(t, err) + + for i := 0; i < 5; i++ { + err = flyteK8sClient.Delete(ctx, pod) + assert.NoError(t, err) + } + + assert.Equal(t, 1, kubeClient.deleteCalledCount) + }) } func init() {