diff --git a/cmd/xgql/main.go b/cmd/xgql/main.go index e185951..2406e07 100644 --- a/cmd/xgql/main.go +++ b/cmd/xgql/main.go @@ -70,10 +70,10 @@ import ( "github.com/upbound/xgql/internal/auth" "github.com/upbound/xgql/internal/cache" "github.com/upbound/xgql/internal/clients" - "github.com/upbound/xgql/internal/graph/extensions/live_query" "github.com/upbound/xgql/internal/graph/generated" "github.com/upbound/xgql/internal/graph/present" "github.com/upbound/xgql/internal/graph/resolvers" + "github.com/upbound/xgql/internal/live_query" "github.com/upbound/xgql/internal/opentelemetry" "github.com/upbound/xgql/internal/request" hprobe "github.com/upbound/xgql/internal/server/health" @@ -214,7 +214,7 @@ func main() { //nolint:gocyclo camid = append(camid, cache.WithBBoltCache(*cacheFile)) } // enable live queries - camid = append(camid, clients.WithLiveQueries) + camid = append(camid, cache.WithLiveQueries) caopts := []clients.CacheOption{ clients.WithRESTMapper(rm), @@ -250,7 +250,7 @@ func main() { //nolint:gocyclo srv.Use(opentelemetry.MetricEmitter{}) srv.Use(opentelemetry.Tracer{}) if !*noApolloTracing { - srv.Use(apollotracing.Tracer{}) + srv.Use(apollotracing.Tracer{}) } srv.Use(live_query.LiveQuery{}) diff --git a/internal/cache/live_query.go b/internal/cache/live_query.go new file mode 100644 index 0000000..5d52e65 --- /dev/null +++ b/internal/cache/live_query.go @@ -0,0 +1,319 @@ +// Copyright 2023 Upbound Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "context" + "strings" + "sync" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/rest" + toolscache "k8s.io/client-go/tools/cache" + "sigs.k8s.io/controller-runtime/pkg/cache" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/apiutil" + + "github.com/upbound/xgql/internal/clients" + "github.com/upbound/xgql/internal/live_query" +) + +// WithLiveQueries wraps NewCacheFn with a cache.Cache that tracks objects +// and object lists and notifies the live query in request context of changes. +func WithLiveQueries(fn clients.NewCacheFn) clients.NewCacheFn { + return func(cfg *rest.Config, o cache.Options) (cache.Cache, error) { + c, err := fn(cfg, o) + if err != nil { + return nil, err + } + return &liveQueryCache{ + Cache: c, + scheme: o.Scheme, + queries: make(map[uint64]*liveQueryTracker), + handles: make(set[schema.GroupVersionKind]), + }, nil + } +} + +var _ toolscache.ResourceEventHandler = (*liveQueryCache)(nil) + +// liveQueryCache is a cache.Cache that registers cache.Informer listeners for any +// retrieved object if executed in the context of a live query. When liveQueryCache +// is notified of events, it will trigger any active live queries. +type liveQueryCache struct { + cache.Cache + scheme *runtime.Scheme + + lock sync.Mutex + queries map[uint64]*liveQueryTracker + handles set[schema.GroupVersionKind] +} + +// Get implements cache.Cache. It wraps an underlying cache.Cache and sets up an Informer +// event handler that marks current live query as dirty if the current context has a live query. +func (c *liveQueryCache) Get(ctx context.Context, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { + if err := c.Cache.Get(ctx, key, obj, opts...); err != nil { + return err + } + return c.trackObject(ctx, obj) +} + +// List implements cache.Cache. It wraps an underlying cache.Cache and sets up an Informer +// event handler that marks current live query as dirty if the current context has a live query. +func (c *liveQueryCache) List(ctx context.Context, list client.ObjectList, opts ...client.ListOption) error { + if err := c.Cache.List(ctx, list, opts...); err != nil { + return err + } + return c.trackObject(ctx, list) +} + +// trackObject registers object or object list with a tracker for the live query. +// any updated from cache.Informer is broadcast to all live query trackers, if the +// changed object is tracked by a given liveQueryTracker, the live query associated +// with the tracker is Trigger()'d. +func (c *liveQueryCache) trackObject(ctx context.Context, object runtime.Object) error { + qid, ok := live_query.IsLive(ctx) + // if this isn't a live query context, skip. + if !ok { + return nil + } + gvk, err := apiutil.GVKForObject(object, c.scheme) + if err != nil { + return err + } + if _, ok := object.(client.ObjectList); ok { + // We need the non-list GVK, so chop off the "List" from the end of the kind. + gvk.Kind = strings.TrimSuffix(gvk.Kind, "List") + } + c.lock.Lock() + defer c.lock.Unlock() + // register event handler for the GVK that if we aren't watching it already. + if !c.handles.Contains(gvk) { + i, err := c.getInformer(ctx, object, gvk) + if err != nil { + return err + } + if _, err := i.AddEventHandler(c); err != nil { + return err + } else { + c.handles.Add(gvk) + } + } + // register live query tracker if we're not tracking it already. + q, ok := c.queries[qid] + if !ok { + q = newLiveQueryTracker(ctx) + c.queries[qid] = q + } + // register object or object list with the live query tracker. + switch o := object.(type) { + case client.Object: + q.Track(o.GetUID(), gvk) + case client.ObjectList: + q.TrackList(gvk) + } + return nil +} + +// getInformer gets cache.Informer for object and gvk. +func (c *liveQueryCache) getInformer(ctx context.Context, object runtime.Object, gvk schema.GroupVersionKind) (cache.Informer, error) { + // Handle unstructured.UnstructuredList. + if _, isUnstructured := object.(runtime.Unstructured); isUnstructured { + u := &unstructured.Unstructured{} + u.SetGroupVersionKind(gvk) + return c.Cache.GetInformer(ctx, u) + } + // Handle metav1.PartialObjectMetadataList. + if _, isPartialObjectMetadata := object.(*metav1.PartialObjectMetadataList); isPartialObjectMetadata { + pom := &metav1.PartialObjectMetadata{} + pom.SetGroupVersionKind(gvk) + return c.Cache.GetInformer(ctx, pom) + } + return c.Cache.GetInformerForKind(ctx, gvk) +} + +// OnAdd implements cache.ResourceEventHandler. +// Broadcasts the object change to all live query trackers after the initial sync. +func (c *liveQueryCache) OnAdd(obj interface{}, isInInitialList bool) { + // we don't care about initial sync + if isInInitialList { + return + } + object, ok := obj.(client.Object) + if !ok { + return + } + gvk, err := apiutil.GVKForObject(object, c.scheme) + if err != nil { + return + } + c.lock.Lock() + defer c.lock.Unlock() + for i := range c.queries { + if !c.queries[i].IsLive() { + delete(c.queries, i) + continue + } + c.queries[i].OnCreate(object, gvk) + } +} + +// OnDelete implements cache.ResourceEventHandler. +// Broadcasts the object change to all live query trackers after the initial sync. +func (c *liveQueryCache) OnDelete(obj interface{}) { + object, ok := obj.(client.Object) + if !ok { + return + } + gvk, err := apiutil.GVKForObject(object, c.scheme) + if err != nil { + return + } + c.lock.Lock() + defer c.lock.Unlock() + for i := range c.queries { + if !c.queries[i].IsLive() { + delete(c.queries, i) + continue + } + c.queries[i].OnDelete(object, gvk) + } +} + +// OnUpdate implements cache.ResourceEventHandler. +// Broadcasts the object change to all live query trackers after the initial sync. +func (c *liveQueryCache) OnUpdate(oldObj interface{}, newObj interface{}) { + oldObject, ok := oldObj.(client.Object) + if !ok { + return + } + newObject, ok := newObj.(client.Object) + if !ok { + return + } + gvk, err := apiutil.GVKForObject(oldObject, c.scheme) + if err != nil { + return + } + c.lock.Lock() + defer c.lock.Unlock() + for i := range c.queries { + // cleanup any stale queries. + if !c.queries[i].IsLive() { + delete(c.queries, i) + continue + } + c.queries[i].OnUpdate(oldObject, newObject, gvk) + } +} + +func newLiveQueryTracker(ctx context.Context) *liveQueryTracker { + return &liveQueryTracker{ctx: ctx, oids: make(map[schema.GroupVersionKind]set[types.UID])} +} + +// liveQueryTracker tracks objects of the same GVK for one live query. +// it can track individual objects as in when cache.Cache.Get() is +// called or the entire list when cache.Cache.List() is used. +type liveQueryTracker struct { + ctx context.Context + + lock sync.Mutex + oids map[schema.GroupVersionKind]set[types.UID] +} + +// IsLive returns true if live query is still active. +func (q *liveQueryTracker) IsLive() bool { + if _, ok := live_query.IsLive(q.ctx); ok { + return true + } + return false +} + +// OnCreate will notify the live query if tracking the entire GVK list. +func (q *liveQueryTracker) OnCreate(object client.Object, gvk schema.GroupVersionKind) { + var notify bool + // notify without holding the lock + defer func() { + if notify { + live_query.Trigger(q.ctx) + } + }() + q.lock.Lock() + defer q.lock.Unlock() + oids, ok := q.oids[gvk] + notify = ok && oids == nil +} + +// OnUpdate will notify the live query if tracking either object or the entire GVK list. +func (q *liveQueryTracker) OnUpdate(oldObject, newObject client.Object, gvk schema.GroupVersionKind) { + var notify bool + // notify without holding the lock + defer func() { + if notify { + live_query.Trigger(q.ctx) + } + }() + q.lock.Lock() + defer q.lock.Unlock() + oids, ok := q.oids[gvk] + // notify if tracking gvk list or either of the objects. + notify = ok && (oids == nil || oids.Contains(oldObject.GetUID()) || oids.Contains(newObject.GetUID())) +} + +// OnDelete will notify the live query if tracking the object or the entire GVK list. +func (q *liveQueryTracker) OnDelete(object client.Object, gvk schema.GroupVersionKind) { + var notify bool + // notify without holding the lock + defer func() { + if notify { + live_query.Trigger(q.ctx) + } + }() + q.lock.Lock() + defer q.lock.Unlock() + oids, ok := q.oids[gvk] + // notify if tracking gkv list or object. + notify = ok && (oids == nil || oids.Remove(object.GetUID())) +} + +// Track registers object for tracking. +func (q *liveQueryTracker) Track(oid types.UID, gvk schema.GroupVersionKind) { + q.lock.Lock() + defer q.lock.Unlock() + if uids, ok := q.oids[gvk]; ok { + // already tracking the entire list, skip. + if uids == nil { + return + } + // add object to track. + uids.Add(oid) + return + } + // register event handler for the new GVK. + // track object. + q.oids[gvk] = set[types.UID]{oid: struct{}{}} +} + +// TrackList begins tacking all objects of a given GVK. +func (q *liveQueryTracker) TrackList(gvk schema.GroupVersionKind) { + q.lock.Lock() + defer q.lock.Unlock() + // track list. + q.oids[gvk] = nil +} diff --git a/internal/cache/set.go b/internal/cache/set.go new file mode 100644 index 0000000..9a6e187 --- /dev/null +++ b/internal/cache/set.go @@ -0,0 +1,42 @@ +// Copyright 2023 Upbound Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +// map based set. +type set[T comparable] map[T]struct{} + +// Add returns true if values was not in set. +func (s set[T]) Add(v T) bool { + if _, ok := s[v]; ok { + return false + } + s[v] = struct{}{} + return true +} + +// Remove returns true if value was in set. +func (s set[T]) Remove(v T) bool { + if _, ok := s[v]; ok { + delete(s, v) + return true + } + return false +} + +// Contains returns true if value is in set. +func (s set[T]) Contains(v T) bool { + _, ok := s[v] + return ok +} diff --git a/internal/clients/live_query.go b/internal/clients/live_query.go deleted file mode 100644 index 54e7538..0000000 --- a/internal/clients/live_query.go +++ /dev/null @@ -1,167 +0,0 @@ -// Copyright 2023 Upbound Inc -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package clients - -import ( - "context" - "strings" - - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" - kruntime "k8s.io/apimachinery/pkg/runtime" - "k8s.io/client-go/rest" - toolscache "k8s.io/client-go/tools/cache" - "sigs.k8s.io/controller-runtime/pkg/cache" - "sigs.k8s.io/controller-runtime/pkg/client" - "sigs.k8s.io/controller-runtime/pkg/client/apiutil" - - "github.com/upbound/xgql/internal/graph/extensions/live_query" -) - -// WithLiveQueries wraps NewCacheFn with a cache.Cache that tracks objects and lists -// and notifies the live query in request context of changes. -func WithLiveQueries(fn NewCacheFn) NewCacheFn { - return func(cfg *rest.Config, o cache.Options) (cache.Cache, error) { - c, err := fn(cfg, o) - if err != nil { - return nil, err - } - return &liveQueryCache{ - Cache: c, - scheme: o.Scheme, - }, nil - } -} - -func isSameObject(a, b client.Object) bool { - return a.GetName() == b.GetName() && a.GetNamespace() == b.GetNamespace() -} - -type liveQueryCache struct { - cache.Cache - scheme *kruntime.Scheme -} - -func (c *liveQueryCache) trackObject(ctx context.Context, co client.Object) error { - if !live_query.IsLive(ctx) { - return nil - } - i, err := c.Cache.GetInformer(ctx, co) - if err != nil { - return err - } - var r toolscache.ResourceEventHandlerRegistration - r, err = i.AddEventHandler(toolscache.FilteringResourceEventHandler{ - FilterFunc: func(obj interface{}) bool { - // If the context is done, remove the handler. - if !live_query.IsLive(ctx) { - _ = i.RemoveEventHandler(r) - return false - } - o, ok := obj.(client.Object) - if !ok { - return false - } - return isSameObject(co, o) - }, - Handler: toolscache.ResourceEventHandlerFuncs{ - AddFunc: func(obj interface{}) { - live_query.NotifyChanged(ctx) - }, - UpdateFunc: func(oldObj, newObj interface{}) { - live_query.NotifyChanged(ctx) - }, - DeleteFunc: func(obj interface{}) { - live_query.NotifyChanged(ctx) - }, - }, - }) - return err -} - -func (c *liveQueryCache) getInformerForListObject(ctx context.Context, list client.ObjectList) (cache.Informer, error) { - gvk, err := apiutil.GVKForObject(list, c.scheme) - if err != nil { - return nil, err - } - - // We need the non-list GVK, so chop off the "List" from the end of the kind. - gvk.Kind = strings.TrimSuffix(gvk.Kind, "List") - - // Handle unstructured.UnstructuredList. - if _, isUnstructured := list.(kruntime.Unstructured); isUnstructured { - u := &unstructured.Unstructured{} - u.SetGroupVersionKind(gvk) - return c.Cache.GetInformer(ctx, u) - } - // Handle metav1.PartialObjectMetadataList. - if _, isPartialObjectMetadata := list.(*metav1.PartialObjectMetadataList); isPartialObjectMetadata { - pom := &metav1.PartialObjectMetadata{} - pom.SetGroupVersionKind(gvk) - return c.Cache.GetInformer(ctx, pom) - } - - return c.Cache.GetInformerForKind(ctx, gvk) -} - -func (c *liveQueryCache) trackObjectList(ctx context.Context, list client.ObjectList) error { - if !live_query.IsLive(ctx) { - return nil - } - i, err := c.getInformerForListObject(ctx, list) - if err != nil { - return err - } - var r toolscache.ResourceEventHandlerRegistration - r, err = i.AddEventHandler(toolscache.FilteringResourceEventHandler{ - FilterFunc: func(obj interface{}) bool { - if !live_query.IsLive(ctx) { - _ = i.RemoveEventHandler(r) - return false - } - return true - }, - Handler: toolscache.ResourceEventHandlerFuncs{ - AddFunc: func(_ interface{}) { - live_query.NotifyChanged(ctx) - }, - UpdateFunc: func(_, _ interface{}) { - live_query.NotifyChanged(ctx) - }, - DeleteFunc: func(_ interface{}) { - live_query.NotifyChanged(ctx) - }, - }, - }) - return err -} - -// Get implements cache.Cache. It wraps an underlying cache.Cache and sets up an Informer -// event handler that marks current live query as dirty if the current context has a live query. -func (c *liveQueryCache) Get(ctx context.Context, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { - if err := c.Cache.Get(ctx, key, obj, opts...); err != nil { - return err - } - return c.trackObject(ctx, obj) -} - -// List implements cache.Cache. It wraps an underlying cache.Cache and sets up an Informer -// event handler that marks current live query as dirty if the current context has a live query. -func (c *liveQueryCache) List(ctx context.Context, list client.ObjectList, opts ...client.ListOption) error { - if err := c.Cache.List(ctx, list, opts...); err != nil { - return err - } - return c.trackObjectList(ctx, list) -} diff --git a/internal/graph/extensions/live_query/live_query.go b/internal/codegen/gqlgen/extensions/live_query/codegen.go similarity index 51% rename from internal/graph/extensions/live_query/live_query.go rename to internal/codegen/gqlgen/extensions/live_query/codegen.go index 2a74434..5d3f831 100644 --- a/internal/graph/extensions/live_query/live_query.go +++ b/internal/codegen/gqlgen/extensions/live_query/codegen.go @@ -15,15 +15,10 @@ package live_query import ( - "context" - _ "embed" "fmt" - "strings" "github.com/99designs/gqlgen/codegen" "github.com/99designs/gqlgen/codegen/config" - "github.com/99designs/gqlgen/codegen/templates" - "github.com/99designs/gqlgen/graphql" "github.com/99designs/gqlgen/plugin" "github.com/vektah/gqlparser/v2/ast" ) @@ -42,8 +37,7 @@ const ( ) // LiveQuery is a graphql.HandlerExtension that enables live queries. -type LiveQuery struct { -} +type LiveQuery struct{} var _ interface { plugin.Plugin @@ -52,11 +46,6 @@ var _ interface { plugin.CodeGenerator } = LiveQuery{} -var _ interface { - graphql.HandlerExtension - graphql.OperationInterceptor -} = LiveQuery{} - // Name implements plugin.Plugin. func (LiveQuery) Name() string { return extName @@ -84,34 +73,20 @@ func (LiveQuery) MutateConfig(cfg *config.Config) error { return nil } -//go:embed resolve_live_query.gotpl -var liveQueryTemplate string - // GenerateCode implements plugin.CodeGenerator. func (LiveQuery) GenerateCode(cfg *codegen.Data) error { for _, object := range cfg.Objects { if object.Name != typeSubscription { continue } - for _, f := range object.Fields { - if f.Name != fieldName { + for i := range object.Fields { + if object.Fields[i].Name != fieldName { continue } - f.TypeReference.IsMarshaler = true - f.IsResolver = false - f.GoFieldType = codegen.GoFieldMethod - f.GoReceiverName = "ec" - f.GoFieldName = "__resolve_liveQuery" - f.MethodHasContext = true - - return templates.Render(templates.Options{ - PackageName: cfg.Config.Exec.Package, - Filename: cfg.Config.Exec.Dir() + "/resolve_live_query.gen.go", - Data: f, - GeneratedHeader: true, - Packages: cfg.Config.Packages, - Template: liveQueryTemplate, - }) + // remove field from codegen. need to mark it as marshaller to avoid generating marshalling code. + object.Fields[i].TypeReference.IsMarshaler = true + object.Fields = append(object.Fields[:i], object.Fields[i+1:]...) + break } } return nil @@ -144,89 +119,3 @@ func (LiveQuery) InjectSourceLate(schema *ast.Schema) *ast.Source { Input: subscriptionDefinition, } } - -// ExtensionName implements graphql.HandlerExtension -func (LiveQuery) ExtensionName() string { - return extName -} - -// Validate implements graphql.HandlerExtension -func (l LiveQuery) Validate(s graphql.ExecutableSchema) error { - subscriptionType, ok := s.Schema().Types[typeSubscription] - if !ok { - return fmt.Errorf("%q type not found", typeSubscription) - } - - field := subscriptionType.Fields.ForName(fieldName) - if field == nil { - return fmt.Errorf("%q type is missing %q field", typeSubscription, fieldName) - } - if field.Type.String() != typeQuery { - return fmt.Errorf("%q field on %q is not of type %q", fieldName, typeSubscription, typeQuery) - } - if field.Arguments.ForName(argThrottle) == nil { - return fmt.Errorf("%q field on %q is missing the %q argument", fieldName, typeSubscription, argThrottle) - } - - return nil -} - -type patch struct { - Revision int `json:"revision"` - JSONPatch []Operation `json:"jsonPatch,omitempty"` -} - -// InterceptOperation implements graphql.OperationInterceptor -func (l LiveQuery) InterceptOperation(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler { - oc := graphql.GetOperationContext(ctx) - if oc.Operation.Operation != ast.Subscription { - return next(ctx) - } - fields := graphql.CollectFields(oc, oc.Operation.SelectionSet, []string{typeSubscription}) - if len(fields) != 1 { - return next(ctx) - } - field := fields[0] - if field.Name != fieldName { - return next(ctx) - } - ctx, cancel := context.WithCancel(ctx) - handler := next(ctx) - var ( - prevData strings.Builder - revision int - ) - return func(ctx context.Context) *graphql.Response { - for { - resp := handler(ctx) - if resp == nil { - cancel() - return nil - } - data := resp.Data - // Compare new data with previous response. - if prevData.Len() > 0 { - diff, err := CreateJSONPatch(prevData.String(), string(data)) - if err != nil { - cancel() - panic(err) - } - // response is the same, skip it. - if len(diff) == 0 { - continue - } - // reset data and add patch extension. - resp.Data = nil - resp.Extensions["patch"] = patch{ - Revision: revision, - JSONPatch: diff, - } - } - revision++ - // keep current data as previous response. - prevData.Reset() - _, _ = prevData.Write(data) - return resp - } - } -} diff --git a/internal/graph/generate/gqlgen/main.go b/internal/codegen/gqlgen/main.go similarity index 97% rename from internal/graph/generate/gqlgen/main.go rename to internal/codegen/gqlgen/main.go index 456f503..73b17bf 100644 --- a/internal/graph/generate/gqlgen/main.go +++ b/internal/codegen/gqlgen/main.go @@ -28,7 +28,7 @@ import ( "github.com/99designs/gqlgen/api" "github.com/99designs/gqlgen/codegen/config" "github.com/99designs/gqlgen/plugin/modelgen" - "github.com/upbound/xgql/internal/graph/extensions/live_query" + "github.com/upbound/xgql/internal/codegen/gqlgen/extensions/live_query" "github.com/vektah/gqlparser/v2/ast" ) diff --git a/internal/generate.go b/internal/generate.go index 821d47f..832f9f1 100644 --- a/internal/generate.go +++ b/internal/generate.go @@ -19,7 +19,7 @@ // https://github.com/golang/go/wiki/Modules#how-can-i-track-tool-dependencies-for-a-module // Generate xgql models, bindings, etc per gqlgen.yaml. -//go:generate go run -tags generate ./graph/generate/gqlgen +//go:generate go run -tags generate ./codegen/gqlgen // Add license headers to all files. //go:generate go run -tags generate github.com/google/addlicense -v -c "Upbound Inc" . ../cmd diff --git a/internal/graph/extensions/live_query/resolve_live_query.gotpl b/internal/graph/extensions/live_query/resolve_live_query.gotpl deleted file mode 100644 index aeeb376..0000000 --- a/internal/graph/extensions/live_query/resolve_live_query.gotpl +++ /dev/null @@ -1,37 +0,0 @@ -{{ reserveImport "context" }} -{{ reserveImport "time" }} -{{ reserveImport "github.com/99designs/gqlgen/graphql" }} -{{ reserveImport "github.com/upbound/xgql/internal/graph/extensions/live_query" }} - -func (ec *executionContext) __resolve_liveQuery(ctx context.Context, throttle *int) (<-chan graphql.Marshaler, error) { - out := make(chan graphql.Marshaler) - sel := graphql.GetFieldContext(ctx).Field.Selections - go func() { - defer close(out) - lqx, needsRefresh := live_query.WithLiveQuery(ctx) - // resolve once with live query context. - out <- ec._Query(lqx, sel) - d := 200 * time.Millisecond - if throttle != nil && *throttle > 0 { - d = time.Duration(*throttle) * time.Millisecond - } - throttle := time.NewTicker(d) - defer throttle.Stop() - for { - select { - case <-throttle.C: - if needsRefresh() { - lqx := graphql.WithFreshResponseContext(ctx) - out <- ec._Query(lqx, sel) - for _, err := range graphql.GetErrors(lqx) { - graphql.AddError(ctx, err) - } - } - throttle.Reset(d) - case <-ctx.Done(): - return - } - } - }() - return out, nil -} diff --git a/internal/graph/extensions/live_query/runtime.go b/internal/graph/extensions/live_query/runtime.go deleted file mode 100644 index ab532dc..0000000 --- a/internal/graph/extensions/live_query/runtime.go +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright 2023 Upbound Inc -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package live_query - -import ( - "context" - "sync" - "sync/atomic" -) - -// liveQuery is set in context by the generated liveQuery resolver. -// Resolvers that need to refresh the query can use IsLive() and NotifyChanged() -// to check if the live query is still running and trigger a live query refresh -// accordingly. -type liveQuery struct { - doneCh <-chan struct{} - hasChanges uint32 - - mu sync.Mutex - cond *sync.Cond -} - -// HasChangesFn is a func that can be used to check if live query needs to be -// refreshed. It is used in generated live query resolver. -type HasChangesFn func() bool - -type liveQueryKey struct{} - -var liveQueryCtxKey = liveQueryKey{} - -// WithLiveQuery sets LiveQuery on derived context and returns a callable for -// checking if live query needs to be refreshed. This is used in generated -// live query resolver to set up periodic live query refresh if changes occurred. -func WithLiveQuery(ctx context.Context) (context.Context, HasChangesFn) { - lq := &liveQuery{doneCh: ctx.Done()} - lq.cond = sync.NewCond(&lq.mu) - return context.WithValue(ctx, liveQueryCtxKey, lq), func() bool { - if atomic.CompareAndSwapUint32(&lq.hasChanges, 1, 0) { - return true - } - lq.mu.Lock() - defer lq.mu.Unlock() - for !atomic.CompareAndSwapUint32(&lq.hasChanges, 1, 0) { - lq.cond.Wait() - } - return true - } -} - -// IsLive returns true if this is a live query context and query is active. -// TODO(avalanche123): add tests. -func IsLive(ctx context.Context) bool { - if lq, ok := ctx.Value(liveQueryCtxKey).(*liveQuery); ok { - select { - case <-lq.doneCh: - return false - default: - return true - } - } - return false -} - -// NotifyChanged notifies live query of a change. -// TODO(avalanche123): add tests. -func NotifyChanged(ctx context.Context) { - if lq, ok := ctx.Value(liveQueryCtxKey).(*liveQuery); ok { - atomic.StoreUint32(&lq.hasChanges, 1) - lq.cond.Broadcast() - } -} diff --git a/internal/graph/generated/generated.go b/internal/graph/generated/generated.go index 36795ac..cca95dc 100644 --- a/internal/graph/generated/generated.go +++ b/internal/graph/generated/generated.go @@ -7,7 +7,6 @@ import ( "context" "errors" "fmt" - "io" "strconv" "sync" "sync/atomic" @@ -625,7 +624,6 @@ type ComplexityRoot struct { } Subscription struct { - __resolve_liveQuery func(childComplexity int, throttle *int) int } TypeReference struct { @@ -3307,18 +3305,6 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.SecretReference.Namespace(childComplexity), true - case "Subscription.liveQuery": - if e.complexity.Subscription.__resolve_liveQuery == nil { - break - } - - args, err := ec.field_Subscription_liveQuery_args(context.TODO(), rawArgs) - if err != nil { - return 0, false - } - - return e.complexity.Subscription.__resolve_liveQuery(childComplexity, args["throttle"].(*int)), true - case "TypeReference.apiVersion": if e.complexity.TypeReference.APIVersion == nil { break @@ -7808,21 +7794,6 @@ func (ec *executionContext) field_Secret_fieldPath_args(ctx context.Context, raw return args, nil } -func (ec *executionContext) field_Subscription_liveQuery_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { - var err error - args := map[string]interface{}{} - var arg0 *int - if tmp, ok := rawArgs["throttle"]; ok { - ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("throttle")) - arg0, err = ec.unmarshalOInt2ᚖint(ctx, tmp) - if err != nil { - return nil, err - } - } - args["throttle"] = arg0 - return args, nil -} - func (ec *executionContext) field___Type_enumValues_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -24494,104 +24465,6 @@ func (ec *executionContext) fieldContext_SecretReference_namespace(ctx context.C return fc, nil } -func (ec *executionContext) _Subscription_liveQuery(ctx context.Context, field graphql.CollectedField) (ret func(ctx context.Context) graphql.Marshaler) { - fc, err := ec.fieldContext_Subscription_liveQuery(ctx, field) - if err != nil { - return nil - } - ctx = graphql.WithFieldContext(ctx, fc) - defer func() { - if r := recover(); r != nil { - ec.Error(ctx, ec.Recover(ctx, r)) - ret = nil - } - }() - resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { - ctx = rctx // use context from middleware stack in children - return ec.__resolve_liveQuery(ctx, fc.Args["throttle"].(*int)) - }) - if err != nil { - ec.Error(ctx, err) - return nil - } - if resTmp == nil { - return nil - } - return func(ctx context.Context) graphql.Marshaler { - select { - case res, ok := <-resTmp.(<-chan graphql.Marshaler): - if !ok { - return nil - } - return graphql.WriterFunc(func(w io.Writer) { - w.Write([]byte{'{'}) - graphql.MarshalString(field.Alias).MarshalGQL(w) - w.Write([]byte{':'}) - ec.marshalOQuery2githubᚗcomᚋ99designsᚋgqlgenᚋgraphqlᚐMarshaler(ctx, field.Selections, res).MarshalGQL(w) - w.Write([]byte{'}'}) - }) - case <-ctx.Done(): - return nil - } - } -} - -func (ec *executionContext) fieldContext_Subscription_liveQuery(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { - fc = &graphql.FieldContext{ - Object: "Subscription", - Field: field, - IsMethod: true, - IsResolver: false, - Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { - switch field.Name { - case "kubernetesResource": - return ec.fieldContext_Query_kubernetesResource(ctx, field) - case "kubernetesResources": - return ec.fieldContext_Query_kubernetesResources(ctx, field) - case "events": - return ec.fieldContext_Query_events(ctx, field) - case "secret": - return ec.fieldContext_Query_secret(ctx, field) - case "configMap": - return ec.fieldContext_Query_configMap(ctx, field) - case "providers": - return ec.fieldContext_Query_providers(ctx, field) - case "providerRevisions": - return ec.fieldContext_Query_providerRevisions(ctx, field) - case "customResourceDefinitions": - return ec.fieldContext_Query_customResourceDefinitions(ctx, field) - case "configurations": - return ec.fieldContext_Query_configurations(ctx, field) - case "configurationRevisions": - return ec.fieldContext_Query_configurationRevisions(ctx, field) - case "compositeResourceDefinitions": - return ec.fieldContext_Query_compositeResourceDefinitions(ctx, field) - case "compositions": - return ec.fieldContext_Query_compositions(ctx, field) - case "crossplaneResourceTree": - return ec.fieldContext_Query_crossplaneResourceTree(ctx, field) - case "__schema": - return ec.fieldContext_Query___schema(ctx, field) - case "__type": - return ec.fieldContext_Query___type(ctx, field) - } - return nil, fmt.Errorf("no field named %q was found under type Query", field.Name) - }, - } - defer func() { - if r := recover(); r != nil { - err = ec.Recover(ctx, r) - ec.Error(ctx, err) - } - }() - ctx = graphql.WithFieldContext(ctx, fc) - if fc.Args, err = ec.field_Subscription_liveQuery_args(ctx, field.ArgumentMap(ec.Variables)); err != nil { - ec.Error(ctx, err) - return fc, err - } - return fc, nil -} - func (ec *executionContext) _TypeReference_apiVersion(ctx context.Context, field graphql.CollectedField, obj *model.TypeReference) (ret graphql.Marshaler) { fc, err := ec.fieldContext_TypeReference_apiVersion(ctx, field) if err != nil { @@ -32515,8 +32388,6 @@ func (ec *executionContext) _Subscription(ctx context.Context, sel ast.Selection } switch fields[0].Name { - case "liveQuery": - return ec._Subscription_liveQuery(ctx, fields[0]) default: panic("unknown field " + strconv.Quote(fields[0].Name)) } diff --git a/internal/graph/generated/resolve_live_query.gen.go b/internal/graph/generated/resolve_live_query.gen.go deleted file mode 100644 index 0d3e707..0000000 --- a/internal/graph/generated/resolve_live_query.gen.go +++ /dev/null @@ -1,44 +0,0 @@ -// Code generated by github.com/99designs/gqlgen, DO NOT EDIT. - -package generated - -import ( - "context" - "time" - - "github.com/99designs/gqlgen/graphql" - "github.com/upbound/xgql/internal/graph/extensions/live_query" -) - -func (ec *executionContext) __resolve_liveQuery(ctx context.Context, throttle *int) (<-chan graphql.Marshaler, error) { - out := make(chan graphql.Marshaler) - sel := graphql.GetFieldContext(ctx).Field.Selections - go func() { - defer close(out) - lqx, needsRefresh := live_query.WithLiveQuery(ctx) - // resolve once with live query context. - out <- ec._Query(lqx, sel) - d := 200 * time.Millisecond - if throttle != nil && *throttle > 0 { - d = time.Duration(*throttle) * time.Millisecond - } - throttle := time.NewTicker(d) - defer throttle.Stop() - for { - select { - case <-throttle.C: - if needsRefresh() { - lqx := graphql.WithFreshResponseContext(ctx) - out <- ec._Query(lqx, sel) - for _, err := range graphql.GetErrors(lqx) { - graphql.AddError(ctx, err) - } - } - throttle.Reset(d) - case <-ctx.Done(): - return - } - } - }() - return out, nil -} diff --git a/internal/live_query/extension.go b/internal/live_query/extension.go new file mode 100644 index 0000000..407e223 --- /dev/null +++ b/internal/live_query/extension.go @@ -0,0 +1,180 @@ +// Copyright 2023 Upbound Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package live_query + +import ( + "context" + "fmt" + "time" + + "github.com/99designs/gqlgen/graphql" + jd "github.com/josephburnett/jd/lib" + "github.com/vektah/gqlparser/v2/ast" + "github.com/vektah/gqlparser/v2/gqlerror" +) + +const ( + // extName is the name of the extension + extName = "LiveQuery" + // fieldName is the field name exposing live queries. + fieldName = "liveQuery" + // typeQuery is the schema type that will be made subscribeable. + typeQuery = "Query" + // typeSubscription is the schema type that will expose "liveQuery" field. + typeSubscription = "Subscription" + // argThrottle is the name of the "trottle" argument for the "liveQuery" field. + argThrottle = "throttle" +) + +var _ interface { + graphql.HandlerExtension + graphql.OperationInterceptor + graphql.OperationParameterMutator + graphql.OperationContextMutator +} = LiveQuery{} + +// LiveQuery is a graphql.HandlerExtension that enables live queries. +type LiveQuery struct{} + +// ExtensionName implements graphql.HandlerExtension +func (LiveQuery) ExtensionName() string { + return extName +} + +// Validate implements graphql.HandlerExtension +func (l LiveQuery) Validate(s graphql.ExecutableSchema) error { + subscriptionType, ok := s.Schema().Types[typeSubscription] + if !ok { + return fmt.Errorf("%q type not found", typeSubscription) + } + + field := subscriptionType.Fields.ForName(fieldName) + if field == nil { + return fmt.Errorf("%q type is missing %q field", typeSubscription, fieldName) + } + if field.Type.String() != typeQuery { + return fmt.Errorf("%q field on %q is not of type %q", fieldName, typeSubscription, typeQuery) + } + if field.Arguments.ForName(argThrottle) == nil { + return fmt.Errorf("%q field on %q is missing the %q argument", fieldName, typeSubscription, argThrottle) + } + + return nil +} + +type patch struct { + Revision int `json:"revision"` + JSONPatch []Operation `json:"jsonPatch,omitempty"` +} + +type LiveQueryStats struct { + Revisions map[string]int `json:"revision"` + PrevData map[string]jd.JsonNode `json:"prevData"` + Throttle int `json:"throttle,omitempty"` +} + +func (l LiveQuery) MutateOperationParameters(ctx context.Context, request *graphql.RawParams) *gqlerror.Error { + return nil +} + +// MutateOperationContext implements graphql.OperationContextMutator +func (l LiveQuery) MutateOperationContext(ctx context.Context, rc *graphql.OperationContext) *gqlerror.Error { + // we're only interested in subscriptions + if rc.Operation.Operation != ast.Subscription { + return nil + } + fields := graphql.CollectFields(rc, rc.Operation.SelectionSet, []string{typeSubscription}) + if len(fields) != 1 { + return nil + } + // check that the subscription field is "liveQuery" + field := fields[0] + if field.Name != fieldName { + return nil + } + operationCopy := *rc.Operation + operationCopy.Operation = ast.Query + operationCopy.SelectionSet = field.SelectionSet + rc.Operation = &operationCopy + rc.Stats.SetExtension(extName, &LiveQueryStats{ + Throttle: (int)(field.ArgumentMap(rc.Variables)["throttle"].(int64)), + Revisions: make(map[string]int), + PrevData: make(map[string]jd.JsonNode), + }) + return nil +} + +// InterceptOperation implements graphql.OperationInterceptor +func (l LiveQuery) InterceptOperation(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler { //nolint:gocyclo + oc := graphql.GetOperationContext(ctx) + lqs, ok := oc.Stats.GetExtension(extName).(*LiveQueryStats) + if !ok { + return next(ctx) + } + throttle := time.Duration(lqs.Throttle) * time.Millisecond + handler := next(ctx) + var lq *liveQuery + return func(ctx context.Context) *graphql.Response { + // create live query context if not exists. + if lq == nil { + lq, ctx = withLiveQuery(ctx, throttle) + } + for { + // create the handler when live query is ready. + if handler == nil { + select { + case <-lq.Ready(): + handler = next(ctx) + case <-ctx.Done(): + return nil + } + } + resp := handler(ctx) + // reached the end of the handler, including deferreds. + if resp == nil { + // reset live query and handler for waiting. + handler = nil + lq.Reset() + continue + } + // propagate errors + data, err := jd.ReadJsonString(string(resp.Data)) + if err != nil { + panic(err) + } + if prevData, ok := lqs.PrevData[resp.Path.String()]; ok { + diff, err := CreateJSONPatch(prevData, data) + if err != nil { + panic(err) + } + if len(diff) > 0 { + // reset data and add patch extension. + resp.Data = nil + resp.Extensions["patch"] = patch{ + Revision: lqs.Revisions[resp.Path.String()], + JSONPatch: diff, + } + } else if len(resp.Errors) == 0 { + // nothing changed, wait for next change. + continue + } + } + lqs.Revisions[resp.Path.String()] += 1 + // keep current data as previous response. + lqs.PrevData[resp.Path.String()] = data + return resp + } + } +} diff --git a/internal/graph/extensions/live_query/json_patch.go b/internal/live_query/json_patch.go similarity index 88% rename from internal/graph/extensions/live_query/json_patch.go rename to internal/live_query/json_patch.go index a2adc5e..2139a59 100644 --- a/internal/graph/extensions/live_query/json_patch.go +++ b/internal/live_query/json_patch.go @@ -44,16 +44,8 @@ type Operation struct { // CreateJSONPatch creates a JSON patch between two json values. // TODO(avalanche123): add tests for json patch generation. -func CreateJSONPatch(x, y string) ([]Operation, error) { - xn, err := jd.ReadJsonString(x) - if err != nil { - return nil, err - } - yn, err := jd.ReadJsonString(y) - if err != nil { - return nil, err - } - raw, err := xn.Diff(yn).RenderPatch() +func CreateJSONPatch(x, y jd.JsonNode) ([]Operation, error) { + raw, err := x.Diff(y).RenderPatch() if err != nil { return nil, err } diff --git a/internal/live_query/runtime.go b/internal/live_query/runtime.go new file mode 100644 index 0000000..ef7e180 --- /dev/null +++ b/internal/live_query/runtime.go @@ -0,0 +1,166 @@ +// Copyright 2023 Upbound Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package live_query + +import ( + "context" + "sync/atomic" + "time" + + "k8s.io/utils/clock" +) + +// liveQuery is set in context by LiveQuery extension. +// Resolvers that need to refresh the query can use IsLive() and NotifyChanged() +// to check if the live query is still running and trigger a live query refresh +// accordingly. +type liveQuery struct { + // a unique id to make it easier to differentiate queries from resolvers. + id uint64 + throttle time.Duration + doneCh <-chan struct{} + actionsCh chan liveQueryAction + changesCh chan struct{} + clock clock.Clock +} + +// liveQueryAction is a signal +type liveQueryAction int + +const ( + fire liveQueryAction = iota + rearm +) + +// debounce ensures that live query only triggers Ready() channel +// at most every throttle interval. +func (lq *liveQuery) debounce() { //nolint:gocyclo + var ( + // channel that will trigger after at least throttle interval since the previous trigger. + timer <-chan time.Time + + // the debounce loop can be in armed state, at which point it will debounce fired event + // onto the changes channel after the throttle period. + // if debounce loop is not armed, it means the live query is being resolved. in this case, + // the fact that an event has occurred is marked in the fired bool. then, the next time + // the live query is rearmed, the throttle timer will be set at the same time. + // this way the query becomes ready after the throttle period and no changes are lost. + armed, fired bool + ) + defer close(lq.changesCh) + // Start debouncing + for { + select { + case a := <-lq.actionsCh: + switch a { + case fire: + if armed { + timer = lq.clock.After(lq.throttle) + continue + } + fired = true + continue + case rearm: + if fired { + timer = lq.clock.After(lq.throttle) + continue + } + armed = true + continue + } + case <-timer: + case <-lq.doneCh: + return + } + fired = false + armed = false + timer = nil + select { + case lq.changesCh <- struct{}{}: + case <-lq.doneCh: + return + } + } +} + +// Ready returns a channel that will be notified when a new change is ready. +func (lq *liveQuery) Ready() <-chan struct{} { + select { + // if query is done, return nil channel + case <-lq.doneCh: + return nil + default: + return lq.changesCh + } +} + +// Reset resets the live query throttling mechanism. +func (lq *liveQuery) Reset() { + select { + case lq.actionsCh <- rearm: + case <-lq.doneCh: + } +} + +// Trigger triggers the live query's Fired channel after the throttle period. +func (lq *liveQuery) Trigger() { + select { + case lq.actionsCh <- fire: + case <-lq.doneCh: + } +} + +type liveQueryKey struct{} + +var ( + liveQueryCtxKey = liveQueryKey{} + liveQueryIds = atomic.Uint64{} +) + +// withLiveQuery creates a new liveQuery and returns it with a modified context. +func withLiveQuery(ctx context.Context, throttle time.Duration) (*liveQuery, context.Context) { + lq := &liveQuery{ + id: liveQueryIds.Add(1), + throttle: throttle, + clock: clock.RealClock{}, + doneCh: ctx.Done(), + actionsCh: make(chan liveQueryAction), + changesCh: make(chan struct{}), + } + go lq.debounce() + return lq, context.WithValue(ctx, liveQueryCtxKey, lq) +} + +// IsLive returns query id and true if this is a live query context and query is active. +// TODO(avalanche123): add tests. +func IsLive(ctx context.Context) (uint64, bool) { + if lq, ok := ctx.Value(liveQueryCtxKey).(*liveQuery); ok { + select { + case <-lq.doneCh: + return 0, false + default: + return lq.id, true + } + } + return 0, false +} + +// Trigger notifies live query of a change. +// TODO(avalanche123): add tests. +func Trigger(ctx context.Context) { + if lq, ok := ctx.Value(liveQueryCtxKey).(*liveQuery); ok { + lq.Trigger() + } +} diff --git a/internal/live_query/runtime_test.go b/internal/live_query/runtime_test.go new file mode 100644 index 0000000..814866a --- /dev/null +++ b/internal/live_query/runtime_test.go @@ -0,0 +1,132 @@ +// Copyright 2023 Upbound Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package live_query + +import ( + "testing" + "time" + + "github.com/google/go-cmp/cmp" + ktesting "k8s.io/utils/clock/testing" +) + +func Test_liveQuery_debounce(t *testing.T) { + t.Parallel() + + type step func(clock *ktesting.FakeClock, lq *liveQuery) + type steps []step + fire := func() step { + return func(_ *ktesting.FakeClock, lq *liveQuery) { + lq.Trigger() + } + } + arm := func() step { + return func(_ *ktesting.FakeClock, lq *liveQuery) { + lq.Reset() + } + } + sleep := func() step { + return func(clock *ktesting.FakeClock, lq *liveQuery) { + clock.Sleep(lq.throttle) + // allow the clock timer to fire before applying the next step + time.Sleep(1 * time.Millisecond) + } + } + + tests := map[string]struct { + reason string + throttle time.Duration + steps steps + changes []time.Duration + }{ + "MultipleFirings": { + reason: "coalesces all fired events into one", + throttle: 1 * time.Second, + steps: steps{ + arm(), + fire(), + fire(), + fire(), + sleep(), + }, + changes: []time.Duration{ + 1 * time.Second, + }, + }, + "NotArmed": { + reason: "doesn't fire until armed", + throttle: 1 * time.Second, + steps: steps{ + fire(), + fire(), + fire(), + sleep(), + }, + }, + "Rearmed": { + reason: "fires again when rearmed", + throttle: 1 * time.Second, + steps: steps{ + arm(), + fire(), + sleep(), + arm(), + fire(), + sleep(), + arm(), + fire(), + sleep(), + }, + changes: []time.Duration{ + 1 * time.Second, + 2 * time.Second, + 3 * time.Second, + }, + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + now := time.Now() + clock := ktesting.NewFakeClock(now) + doneCh := make(chan struct{}) + changesCh := make(chan struct{}) + actionsCh := make(chan liveQueryAction) + lq := &liveQuery{ + throttle: tt.throttle, + doneCh: doneCh, + actionsCh: actionsCh, + changesCh: changesCh, + clock: clock, + } + var changes []time.Duration + startCh := make(chan struct{}) + go func() { + defer close(doneCh) + close(startCh) + for _, step := range tt.steps { + step(clock, lq) + } + }() + go lq.debounce() + <-startCh + for range lq.Ready() { + changes = append(changes, clock.Now().Sub(now)) + } + if diff := cmp.Diff(tt.changes, changes); diff != "" { + t.Errorf("debounce(...): -want, +got:\n%s", diff) + } + }) + } +}