Skip to content

Commit

Permalink
Fix etcd compat with result limits / pagination
Browse files Browse the repository at this point in the history
* Sort list results by name, not revision. List continuation (start key)
  functionality requires that keys be returned in ascending order.
* Only count keys remaining after the start key, not the total number of
  keys in the prefix.
* Return current revision in header along with error when unable to
  range on key.

Signed-off-by: Brad Davidson <[email protected]>
  • Loading branch information
brandond committed May 8, 2024
1 parent 7484a03 commit 4286268
Show file tree
Hide file tree
Showing 10 changed files with 38 additions and 61 deletions.
30 changes: 10 additions & 20 deletions pkg/drivers/generic/generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ const (
var _ server.Dialect = (*Generic)(nil)

var (
columns = "kv.id AS theid, kv.name, kv.created, kv.deleted, kv.create_revision, kv.prev_revision, kv.lease, kv.value, kv.old_value"
columns = "kv.id AS theid, kv.name AS thename, kv.created, kv.deleted, kv.create_revision, kv.prev_revision, kv.lease, kv.value, kv.old_value"
revSQL = `
SELECT MAX(rkv.id) AS id
FROM kine AS rkv`
Expand All @@ -39,16 +39,6 @@ var (
FROM kine AS crkv
WHERE crkv.name = 'compact_rev_key'`

idOfKey = `
AND
mkv.id <= ? AND
mkv.id > (
SELECT MAX(ikv.id) AS id
FROM kine AS ikv
WHERE
ikv.name = ? AND
ikv.id <= ?)`

listSQL = fmt.Sprintf(`
SELECT *
FROM (
Expand All @@ -66,7 +56,7 @@ var (
kv.deleted = 0 OR
?
) AS lkv
ORDER BY lkv.theid ASC
ORDER BY lkv.thename ASC
`, revSQL, compactRevSQL, columns)
)

Expand Down Expand Up @@ -218,19 +208,19 @@ func Open(ctx context.Context, driverName, dataSourceName string, connPoolConfig

GetCurrentSQL: q(fmt.Sprintf(listSQL, ""), paramCharacter, numbered),
ListRevisionStartSQL: q(fmt.Sprintf(listSQL, "AND mkv.id <= ?"), paramCharacter, numbered),
GetRevisionAfterSQL: q(fmt.Sprintf(listSQL, idOfKey), paramCharacter, numbered),
GetRevisionAfterSQL: q(fmt.Sprintf(listSQL, "AND mkv.name > ? AND mkv.id <= ?"), paramCharacter, numbered),

CountCurrentSQL: q(fmt.Sprintf(`
SELECT (%s), COUNT(c.theid)
FROM (
%s
) c`, revSQL, fmt.Sprintf(listSQL, "")), paramCharacter, numbered),
) c`, revSQL, fmt.Sprintf(listSQL, "AND mkv.name > ?")), paramCharacter, numbered),

CountRevisionSQL: q(fmt.Sprintf(`
SELECT (%s), COUNT(c.theid)
FROM (
%s
) c`, revSQL, fmt.Sprintf(listSQL, "AND mkv.id <= ?")), paramCharacter, numbered),
) c`, revSQL, fmt.Sprintf(listSQL, "AND mkv.name > ? AND mkv.id <= ?")), paramCharacter, numbered),

AfterSQL: q(fmt.Sprintf(`
SELECT (%s), (%s), %s
Expand Down Expand Up @@ -364,27 +354,27 @@ func (d *Generic) List(ctx context.Context, prefix, startKey string, limit, revi
if limit > 0 {
sql = fmt.Sprintf("%s LIMIT %d", sql, limit)
}
return d.query(ctx, sql, prefix, revision, startKey, revision, includeDeleted)
return d.query(ctx, sql, prefix, startKey, revision, includeDeleted)
}

func (d *Generic) CountCurrent(ctx context.Context, prefix string) (int64, int64, error) {
func (d *Generic) CountCurrent(ctx context.Context, prefix, startKey string) (int64, int64, error) {
var (
rev sql.NullInt64
id int64
)

row := d.queryRow(ctx, d.CountCurrentSQL, prefix, false)
row := d.queryRow(ctx, d.CountCurrentSQL, prefix, startKey, false)
err := row.Scan(&rev, &id)
return rev.Int64, id, err
}

func (d *Generic) Count(ctx context.Context, prefix string, revision int64) (int64, int64, error) {
func (d *Generic) Count(ctx context.Context, prefix, startKey string, revision int64) (int64, int64, error) {
var (
rev sql.NullInt64
id int64
)

row := d.queryRow(ctx, d.CountRevisionSQL, prefix, revision, false)
row := d.queryRow(ctx, d.CountRevisionSQL, prefix, startKey, revision, false)
err := row.Scan(&rev, &id)
return rev.Int64, id, err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/drivers/nats/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ func (b *Backend) CurrentRevision(ctx context.Context) (int64, error) {
}

// Count returns an exact count of the number of matching keys and the current revision of the database.
func (b *Backend) Count(ctx context.Context, prefix string, revision int64) (int64, int64, error) {
count, err := b.kv.Count(ctx, prefix, revision)
func (b *Backend) Count(ctx context.Context, prefix, startKey string, revision int64) (int64, int64, error) {
count, err := b.kv.Count(ctx, prefix, startKey, revision)
if err != nil {
return 0, 0, err
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/drivers/nats/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,14 @@ func TestBackend_Create(t *testing.T) {

time.Sleep(2 * time.Millisecond)

srev, count, err := b.Count(ctx, "/", 0)
srev, count, err := b.Count(ctx, "/", "", 0)
noErr(t, err)
expEqual(t, 4, srev)
expEqual(t, 4, count)

time.Sleep(time.Second)

srev, count, err = b.Count(ctx, "/", 0)
srev, count, err = b.Count(ctx, "/", "", 0)
noErr(t, err)
expEqual(t, 4, srev)
expEqual(t, 3, count)
Expand All @@ -149,7 +149,7 @@ func TestBackend_Create(t *testing.T) {

time.Sleep(2 * time.Millisecond)

srev, count, err = b.Count(ctx, "/", 0)
srev, count, err = b.Count(ctx, "/", "", 0)
noErr(t, err)
expEqual(t, 6, srev)
expEqual(t, 4, count)
Expand Down
2 changes: 1 addition & 1 deletion pkg/drivers/nats/kv.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ type keySeq struct {
seq uint64
}

func (e *KeyValue) Count(ctx context.Context, prefix string, revision int64) (int64, error) {
func (e *KeyValue) Count(ctx context.Context, prefix, startKey string, revision int64) (int64, error) {
it := e.bt.Iter()

if prefix != "" {
Expand Down
8 changes: 4 additions & 4 deletions pkg/drivers/nats/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,15 @@ func (b *BackendLogger) List(ctx context.Context, prefix, startKey string, limit
}

// Count returns an exact count of the number of matching keys and the current revision of the database
func (b *BackendLogger) Count(ctx context.Context, prefix string, revision int64) (revRet int64, count int64, err error) {
func (b *BackendLogger) Count(ctx context.Context, prefix, startKey string, revision int64) (revRet int64, count int64, err error) {
start := time.Now()
defer func() {
dur := time.Since(start)
fStr := "COUNT %s, rev=%d => rev=%d, count=%d, err=%v, duration=%s"
b.logMethod(dur, fStr, prefix, revision, revRet, count, err, dur)
fStr := "COUNT %s, start=%s, rev=%d => rev=%d, count=%d, err=%v, duration=%s"
b.logMethod(dur, fStr, prefix, startKey, revision, revRet, count, err, dur)
}()

return b.backend.Count(ctx, prefix, revision)
return b.backend.Count(ctx, prefix, startKey, revision)
}

func (b *BackendLogger) Update(ctx context.Context, key string, value []byte, revision, lease int64) (revRet int64, kvRet *server.KeyValue, updateRet bool, errRet error) {
Expand Down
6 changes: 3 additions & 3 deletions pkg/logstructured/logstructured.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ type Log interface {
CompactRevision(ctx context.Context) (int64, error)
CurrentRevision(ctx context.Context) (int64, error)
List(ctx context.Context, prefix, startKey string, limit, revision int64, includeDeletes bool) (int64, []*server.Event, error)
Count(ctx context.Context, prefix, startKey string, revision int64) (int64, int64, error)
After(ctx context.Context, prefix string, revision, limit int64) (int64, []*server.Event, error)
Watch(ctx context.Context, prefix string) <-chan []*server.Event
Count(ctx context.Context, prefix string, revision int64) (int64, int64, error)
Append(ctx context.Context, event *server.Event) (int64, error)
DbSize(ctx context.Context) (int64, error)
}
Expand Down Expand Up @@ -199,11 +199,11 @@ func (l *LogStructured) List(ctx context.Context, prefix, startKey string, limit
return rev, kvs, nil
}

func (l *LogStructured) Count(ctx context.Context, prefix string, revision int64) (revRet int64, count int64, err error) {
func (l *LogStructured) Count(ctx context.Context, prefix, startKey string, revision int64) (revRet int64, count int64, err error) {
defer func() {
logrus.Tracef("COUNT %s, rev=%d => rev=%d, count=%d, err=%v", prefix, revision, revRet, count, err)
}()
rev, count, err := l.log.Count(ctx, prefix, revision)
rev, count, err := l.log.Count(ctx, prefix, startKey, revision)
if err != nil {
return 0, 0, err
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/logstructured/sqllog/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -526,15 +526,15 @@ func canSkipRevision(rev, skip int64, skipTime time.Time) bool {
return rev == skip && time.Since(skipTime) > time.Second
}

func (s *SQLLog) Count(ctx context.Context, prefix string, revision int64) (int64, int64, error) {
func (s *SQLLog) Count(ctx context.Context, prefix, startKey string, revision int64) (int64, int64, error) {
if strings.HasSuffix(prefix, "/") {
prefix += "%"
}

if revision == 0 {
return s.d.CountCurrent(ctx, prefix)
return s.d.CountCurrent(ctx, prefix, startKey)
}
return s.d.Count(ctx, prefix, revision)
return s.d.Count(ctx, prefix, startKey, revision)
}

func (s *SQLLog) Append(ctx context.Context, event *server.Event) (int64, error) {
Expand Down
6 changes: 1 addition & 5 deletions pkg/server/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,12 @@ func (l *LimitedServer) get(ctx context.Context, r *etcdserverpb.RangeRequest) (
}

rev, kv, err := l.backend.Get(ctx, string(r.Key), string(r.RangeEnd), r.Limit, r.Revision)
if err != nil {
return nil, err
}

resp := &RangeResponse{
Header: txnHeader(rev),
}
if kv != nil {
resp.Kvs = []*KeyValue{kv}
resp.Count = 1
}
return resp, nil
return resp, err
}
25 changes: 8 additions & 17 deletions pkg/server/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,13 @@ func (l *LimitedServer) list(ctx context.Context, r *etcdserverpb.RangeRequest)
revision := r.Revision

if r.CountOnly {
rev, count, err := l.backend.Count(ctx, prefix, revision)
if err != nil {
return nil, err
}
logrus.Tracef("LIST COUNT key=%s, end=%s, revision=%d, currentRev=%d count=%d", r.Key, r.RangeEnd, revision, rev, count)
return &RangeResponse{
rev, count, err := l.backend.Count(ctx, prefix, start, revision)
resp := &RangeResponse{
Header: txnHeader(rev),
Count: count,
}, nil
}
logrus.Tracef("LIST COUNT key=%s, end=%s, revision=%d, currentRev=%d count=%d", r.Key, r.RangeEnd, revision, rev, count)
return resp, err
}

limit := r.Limit
Expand All @@ -40,18 +38,14 @@ func (l *LimitedServer) list(ctx context.Context, r *etcdserverpb.RangeRequest)
}

rev, kvs, err := l.backend.List(ctx, prefix, start, limit, revision)
if err != nil {
return nil, err
}

logrus.Tracef("LIST key=%s, end=%s, revision=%d, currentRev=%d count=%d, limit=%d", r.Key, r.RangeEnd, revision, rev, len(kvs), r.Limit)
resp := &RangeResponse{
Header: txnHeader(rev),
Count: int64(len(kvs)),
Kvs: kvs,
}

// count the actual number of results if there are more items in the db.
// if the number of items returned exceeds the limit, count the keys remaining that follow the start key
if limit > 0 && resp.Count > r.Limit {
resp.More = true
resp.Kvs = kvs[0 : limit-1]
Expand All @@ -60,13 +54,10 @@ func (l *LimitedServer) list(ctx context.Context, r *etcdserverpb.RangeRequest)
revision = rev
}

rev, resp.Count, err = l.backend.Count(ctx, prefix, revision)
if err != nil {
return nil, err
}
rev, resp.Count, err = l.backend.Count(ctx, prefix, start, revision)
logrus.Tracef("LIST COUNT key=%s, end=%s, revision=%d, currentRev=%d count=%d", r.Key, r.RangeEnd, revision, rev, resp.Count)
resp.Header = txnHeader(rev)
}

return resp, nil
return resp, err
}
6 changes: 3 additions & 3 deletions pkg/server/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type Backend interface {
Create(ctx context.Context, key string, value []byte, lease int64) (int64, error)
Delete(ctx context.Context, key string, revision int64) (int64, *KeyValue, bool, error)
List(ctx context.Context, prefix, startKey string, limit, revision int64) (int64, []*KeyValue, error)
Count(ctx context.Context, prefix string, revision int64) (int64, int64, error)
Count(ctx context.Context, prefix, startKey string, revision int64) (int64, int64, error)
Update(ctx context.Context, key string, value []byte, revision, lease int64) (int64, *KeyValue, bool, error)
Watch(ctx context.Context, key string, revision int64) WatchResult
DbSize(ctx context.Context) (int64, error)
Expand All @@ -33,8 +33,8 @@ type Backend interface {
type Dialect interface {
ListCurrent(ctx context.Context, prefix string, limit int64, includeDeleted bool) (*sql.Rows, error)
List(ctx context.Context, prefix, startKey string, limit, revision int64, includeDeleted bool) (*sql.Rows, error)
CountCurrent(ctx context.Context, prefix string) (int64, int64, error)
Count(ctx context.Context, prefix string, revision int64) (int64, int64, error)
CountCurrent(ctx context.Context, prefix, startKey string) (int64, int64, error)
Count(ctx context.Context, prefix, startKey string, revision int64) (int64, int64, error)
CurrentRevision(ctx context.Context) (int64, error)
After(ctx context.Context, prefix string, rev, limit int64) (*sql.Rows, error)
Insert(ctx context.Context, key string, create, delete bool, createRevision, previousRevision int64, ttl int64, value, prevValue []byte) (int64, error)
Expand Down

0 comments on commit 4286268

Please sign in to comment.