From 42862689d35e2ead7ce052d2bc2c10cda30c48ff Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Tue, 7 May 2024 20:11:18 +0000 Subject: [PATCH] Fix etcd compat with result limits / pagination * Sort list results by name, not revision. List continuation (start key) functionality requires that keys be returned in ascending order. * Only count keys remaining after the start key, not the total number of keys in the prefix. * Return current revision in header along with error when unable to range on key. Signed-off-by: Brad Davidson --- pkg/drivers/generic/generic.go | 30 ++++++++++-------------------- pkg/drivers/nats/backend.go | 4 ++-- pkg/drivers/nats/backend_test.go | 6 +++--- pkg/drivers/nats/kv.go | 2 +- pkg/drivers/nats/logger.go | 8 ++++---- pkg/logstructured/logstructured.go | 6 +++--- pkg/logstructured/sqllog/sql.go | 6 +++--- pkg/server/get.go | 6 +----- pkg/server/list.go | 25 ++++++++----------------- pkg/server/types.go | 6 +++--- 10 files changed, 38 insertions(+), 61 deletions(-) diff --git a/pkg/drivers/generic/generic.go b/pkg/drivers/generic/generic.go index 42e99a51..4ef53f2a 100644 --- a/pkg/drivers/generic/generic.go +++ b/pkg/drivers/generic/generic.go @@ -29,7 +29,7 @@ const ( var _ server.Dialect = (*Generic)(nil) var ( - columns = "kv.id AS theid, kv.name, kv.created, kv.deleted, kv.create_revision, kv.prev_revision, kv.lease, kv.value, kv.old_value" + columns = "kv.id AS theid, kv.name AS thename, kv.created, kv.deleted, kv.create_revision, kv.prev_revision, kv.lease, kv.value, kv.old_value" revSQL = ` SELECT MAX(rkv.id) AS id FROM kine AS rkv` @@ -39,16 +39,6 @@ var ( FROM kine AS crkv WHERE crkv.name = 'compact_rev_key'` - idOfKey = ` - AND - mkv.id <= ? AND - mkv.id > ( - SELECT MAX(ikv.id) AS id - FROM kine AS ikv - WHERE - ikv.name = ? AND - ikv.id <= ?)` - listSQL = fmt.Sprintf(` SELECT * FROM ( @@ -66,7 +56,7 @@ var ( kv.deleted = 0 OR ? ) AS lkv - ORDER BY lkv.theid ASC + ORDER BY lkv.thename ASC `, revSQL, compactRevSQL, columns) ) @@ -218,19 +208,19 @@ func Open(ctx context.Context, driverName, dataSourceName string, connPoolConfig GetCurrentSQL: q(fmt.Sprintf(listSQL, ""), paramCharacter, numbered), ListRevisionStartSQL: q(fmt.Sprintf(listSQL, "AND mkv.id <= ?"), paramCharacter, numbered), - GetRevisionAfterSQL: q(fmt.Sprintf(listSQL, idOfKey), paramCharacter, numbered), + GetRevisionAfterSQL: q(fmt.Sprintf(listSQL, "AND mkv.name > ? AND mkv.id <= ?"), paramCharacter, numbered), CountCurrentSQL: q(fmt.Sprintf(` SELECT (%s), COUNT(c.theid) FROM ( %s - ) c`, revSQL, fmt.Sprintf(listSQL, "")), paramCharacter, numbered), + ) c`, revSQL, fmt.Sprintf(listSQL, "AND mkv.name > ?")), paramCharacter, numbered), CountRevisionSQL: q(fmt.Sprintf(` SELECT (%s), COUNT(c.theid) FROM ( %s - ) c`, revSQL, fmt.Sprintf(listSQL, "AND mkv.id <= ?")), paramCharacter, numbered), + ) c`, revSQL, fmt.Sprintf(listSQL, "AND mkv.name > ? AND mkv.id <= ?")), paramCharacter, numbered), AfterSQL: q(fmt.Sprintf(` SELECT (%s), (%s), %s @@ -364,27 +354,27 @@ func (d *Generic) List(ctx context.Context, prefix, startKey string, limit, revi if limit > 0 { sql = fmt.Sprintf("%s LIMIT %d", sql, limit) } - return d.query(ctx, sql, prefix, revision, startKey, revision, includeDeleted) + return d.query(ctx, sql, prefix, startKey, revision, includeDeleted) } -func (d *Generic) CountCurrent(ctx context.Context, prefix string) (int64, int64, error) { +func (d *Generic) CountCurrent(ctx context.Context, prefix, startKey string) (int64, int64, error) { var ( rev sql.NullInt64 id int64 ) - row := d.queryRow(ctx, d.CountCurrentSQL, prefix, false) + row := d.queryRow(ctx, d.CountCurrentSQL, prefix, startKey, false) err := row.Scan(&rev, &id) return rev.Int64, id, err } -func (d *Generic) Count(ctx context.Context, prefix string, revision int64) (int64, int64, error) { +func (d *Generic) Count(ctx context.Context, prefix, startKey string, revision int64) (int64, int64, error) { var ( rev sql.NullInt64 id int64 ) - row := d.queryRow(ctx, d.CountRevisionSQL, prefix, revision, false) + row := d.queryRow(ctx, d.CountRevisionSQL, prefix, startKey, revision, false) err := row.Scan(&rev, &id) return rev.Int64, id, err } diff --git a/pkg/drivers/nats/backend.go b/pkg/drivers/nats/backend.go index 5cbbcdb0..d094e200 100644 --- a/pkg/drivers/nats/backend.go +++ b/pkg/drivers/nats/backend.go @@ -135,8 +135,8 @@ func (b *Backend) CurrentRevision(ctx context.Context) (int64, error) { } // Count returns an exact count of the number of matching keys and the current revision of the database. -func (b *Backend) Count(ctx context.Context, prefix string, revision int64) (int64, int64, error) { - count, err := b.kv.Count(ctx, prefix, revision) +func (b *Backend) Count(ctx context.Context, prefix, startKey string, revision int64) (int64, int64, error) { + count, err := b.kv.Count(ctx, prefix, startKey, revision) if err != nil { return 0, 0, err } diff --git a/pkg/drivers/nats/backend_test.go b/pkg/drivers/nats/backend_test.go index e5ba3f33..64e426ef 100644 --- a/pkg/drivers/nats/backend_test.go +++ b/pkg/drivers/nats/backend_test.go @@ -129,14 +129,14 @@ func TestBackend_Create(t *testing.T) { time.Sleep(2 * time.Millisecond) - srev, count, err := b.Count(ctx, "/", 0) + srev, count, err := b.Count(ctx, "/", "", 0) noErr(t, err) expEqual(t, 4, srev) expEqual(t, 4, count) time.Sleep(time.Second) - srev, count, err = b.Count(ctx, "/", 0) + srev, count, err = b.Count(ctx, "/", "", 0) noErr(t, err) expEqual(t, 4, srev) expEqual(t, 3, count) @@ -149,7 +149,7 @@ func TestBackend_Create(t *testing.T) { time.Sleep(2 * time.Millisecond) - srev, count, err = b.Count(ctx, "/", 0) + srev, count, err = b.Count(ctx, "/", "", 0) noErr(t, err) expEqual(t, 6, srev) expEqual(t, 4, count) diff --git a/pkg/drivers/nats/kv.go b/pkg/drivers/nats/kv.go index 3949f8fe..b2d95b79 100644 --- a/pkg/drivers/nats/kv.go +++ b/pkg/drivers/nats/kv.go @@ -376,7 +376,7 @@ type keySeq struct { seq uint64 } -func (e *KeyValue) Count(ctx context.Context, prefix string, revision int64) (int64, error) { +func (e *KeyValue) Count(ctx context.Context, prefix, startKey string, revision int64) (int64, error) { it := e.bt.Iter() if prefix != "" { diff --git a/pkg/drivers/nats/logger.go b/pkg/drivers/nats/logger.go index 913cf7d0..fbfb0d0f 100644 --- a/pkg/drivers/nats/logger.go +++ b/pkg/drivers/nats/logger.go @@ -81,15 +81,15 @@ func (b *BackendLogger) List(ctx context.Context, prefix, startKey string, limit } // Count returns an exact count of the number of matching keys and the current revision of the database -func (b *BackendLogger) Count(ctx context.Context, prefix string, revision int64) (revRet int64, count int64, err error) { +func (b *BackendLogger) Count(ctx context.Context, prefix, startKey string, revision int64) (revRet int64, count int64, err error) { start := time.Now() defer func() { dur := time.Since(start) - fStr := "COUNT %s, rev=%d => rev=%d, count=%d, err=%v, duration=%s" - b.logMethod(dur, fStr, prefix, revision, revRet, count, err, dur) + fStr := "COUNT %s, start=%s, rev=%d => rev=%d, count=%d, err=%v, duration=%s" + b.logMethod(dur, fStr, prefix, startKey, revision, revRet, count, err, dur) }() - return b.backend.Count(ctx, prefix, revision) + return b.backend.Count(ctx, prefix, startKey, revision) } func (b *BackendLogger) Update(ctx context.Context, key string, value []byte, revision, lease int64) (revRet int64, kvRet *server.KeyValue, updateRet bool, errRet error) { diff --git a/pkg/logstructured/logstructured.go b/pkg/logstructured/logstructured.go index f93beb0b..5db392d0 100644 --- a/pkg/logstructured/logstructured.go +++ b/pkg/logstructured/logstructured.go @@ -21,9 +21,9 @@ type Log interface { CompactRevision(ctx context.Context) (int64, error) CurrentRevision(ctx context.Context) (int64, error) List(ctx context.Context, prefix, startKey string, limit, revision int64, includeDeletes bool) (int64, []*server.Event, error) + Count(ctx context.Context, prefix, startKey string, revision int64) (int64, int64, error) After(ctx context.Context, prefix string, revision, limit int64) (int64, []*server.Event, error) Watch(ctx context.Context, prefix string) <-chan []*server.Event - Count(ctx context.Context, prefix string, revision int64) (int64, int64, error) Append(ctx context.Context, event *server.Event) (int64, error) DbSize(ctx context.Context) (int64, error) } @@ -199,11 +199,11 @@ func (l *LogStructured) List(ctx context.Context, prefix, startKey string, limit return rev, kvs, nil } -func (l *LogStructured) Count(ctx context.Context, prefix string, revision int64) (revRet int64, count int64, err error) { +func (l *LogStructured) Count(ctx context.Context, prefix, startKey string, revision int64) (revRet int64, count int64, err error) { defer func() { logrus.Tracef("COUNT %s, rev=%d => rev=%d, count=%d, err=%v", prefix, revision, revRet, count, err) }() - rev, count, err := l.log.Count(ctx, prefix, revision) + rev, count, err := l.log.Count(ctx, prefix, startKey, revision) if err != nil { return 0, 0, err } diff --git a/pkg/logstructured/sqllog/sql.go b/pkg/logstructured/sqllog/sql.go index a6105507..06c4486f 100644 --- a/pkg/logstructured/sqllog/sql.go +++ b/pkg/logstructured/sqllog/sql.go @@ -526,15 +526,15 @@ func canSkipRevision(rev, skip int64, skipTime time.Time) bool { return rev == skip && time.Since(skipTime) > time.Second } -func (s *SQLLog) Count(ctx context.Context, prefix string, revision int64) (int64, int64, error) { +func (s *SQLLog) Count(ctx context.Context, prefix, startKey string, revision int64) (int64, int64, error) { if strings.HasSuffix(prefix, "/") { prefix += "%" } if revision == 0 { - return s.d.CountCurrent(ctx, prefix) + return s.d.CountCurrent(ctx, prefix, startKey) } - return s.d.Count(ctx, prefix, revision) + return s.d.Count(ctx, prefix, startKey, revision) } func (s *SQLLog) Append(ctx context.Context, event *server.Event) (int64, error) { diff --git a/pkg/server/get.go b/pkg/server/get.go index dce2c32b..b86153ac 100644 --- a/pkg/server/get.go +++ b/pkg/server/get.go @@ -13,10 +13,6 @@ func (l *LimitedServer) get(ctx context.Context, r *etcdserverpb.RangeRequest) ( } rev, kv, err := l.backend.Get(ctx, string(r.Key), string(r.RangeEnd), r.Limit, r.Revision) - if err != nil { - return nil, err - } - resp := &RangeResponse{ Header: txnHeader(rev), } @@ -24,5 +20,5 @@ func (l *LimitedServer) get(ctx context.Context, r *etcdserverpb.RangeRequest) ( resp.Kvs = []*KeyValue{kv} resp.Count = 1 } - return resp, nil + return resp, err } diff --git a/pkg/server/list.go b/pkg/server/list.go index 9c66b528..31ffc06f 100644 --- a/pkg/server/list.go +++ b/pkg/server/list.go @@ -23,15 +23,13 @@ func (l *LimitedServer) list(ctx context.Context, r *etcdserverpb.RangeRequest) revision := r.Revision if r.CountOnly { - rev, count, err := l.backend.Count(ctx, prefix, revision) - if err != nil { - return nil, err - } - logrus.Tracef("LIST COUNT key=%s, end=%s, revision=%d, currentRev=%d count=%d", r.Key, r.RangeEnd, revision, rev, count) - return &RangeResponse{ + rev, count, err := l.backend.Count(ctx, prefix, start, revision) + resp := &RangeResponse{ Header: txnHeader(rev), Count: count, - }, nil + } + logrus.Tracef("LIST COUNT key=%s, end=%s, revision=%d, currentRev=%d count=%d", r.Key, r.RangeEnd, revision, rev, count) + return resp, err } limit := r.Limit @@ -40,10 +38,6 @@ func (l *LimitedServer) list(ctx context.Context, r *etcdserverpb.RangeRequest) } rev, kvs, err := l.backend.List(ctx, prefix, start, limit, revision) - if err != nil { - return nil, err - } - logrus.Tracef("LIST key=%s, end=%s, revision=%d, currentRev=%d count=%d, limit=%d", r.Key, r.RangeEnd, revision, rev, len(kvs), r.Limit) resp := &RangeResponse{ Header: txnHeader(rev), @@ -51,7 +45,7 @@ func (l *LimitedServer) list(ctx context.Context, r *etcdserverpb.RangeRequest) Kvs: kvs, } - // count the actual number of results if there are more items in the db. + // if the number of items returned exceeds the limit, count the keys remaining that follow the start key if limit > 0 && resp.Count > r.Limit { resp.More = true resp.Kvs = kvs[0 : limit-1] @@ -60,13 +54,10 @@ func (l *LimitedServer) list(ctx context.Context, r *etcdserverpb.RangeRequest) revision = rev } - rev, resp.Count, err = l.backend.Count(ctx, prefix, revision) - if err != nil { - return nil, err - } + rev, resp.Count, err = l.backend.Count(ctx, prefix, start, revision) logrus.Tracef("LIST COUNT key=%s, end=%s, revision=%d, currentRev=%d count=%d", r.Key, r.RangeEnd, revision, rev, resp.Count) resp.Header = txnHeader(rev) } - return resp, nil + return resp, err } diff --git a/pkg/server/types.go b/pkg/server/types.go index 8b040829..440338cd 100644 --- a/pkg/server/types.go +++ b/pkg/server/types.go @@ -23,7 +23,7 @@ type Backend interface { Create(ctx context.Context, key string, value []byte, lease int64) (int64, error) Delete(ctx context.Context, key string, revision int64) (int64, *KeyValue, bool, error) List(ctx context.Context, prefix, startKey string, limit, revision int64) (int64, []*KeyValue, error) - Count(ctx context.Context, prefix string, revision int64) (int64, int64, error) + Count(ctx context.Context, prefix, startKey string, revision int64) (int64, int64, error) Update(ctx context.Context, key string, value []byte, revision, lease int64) (int64, *KeyValue, bool, error) Watch(ctx context.Context, key string, revision int64) WatchResult DbSize(ctx context.Context) (int64, error) @@ -33,8 +33,8 @@ type Backend interface { type Dialect interface { ListCurrent(ctx context.Context, prefix string, limit int64, includeDeleted bool) (*sql.Rows, error) List(ctx context.Context, prefix, startKey string, limit, revision int64, includeDeleted bool) (*sql.Rows, error) - CountCurrent(ctx context.Context, prefix string) (int64, int64, error) - Count(ctx context.Context, prefix string, revision int64) (int64, int64, error) + CountCurrent(ctx context.Context, prefix, startKey string) (int64, int64, error) + Count(ctx context.Context, prefix, startKey string, revision int64) (int64, int64, error) CurrentRevision(ctx context.Context) (int64, error) After(ctx context.Context, prefix string, rev, limit int64) (*sql.Rows, error) Insert(ctx context.Context, key string, create, delete bool, createRevision, previousRevision int64, ttl int64, value, prevValue []byte) (int64, error)