diff --git a/orm/auto_uint64.go b/orm/auto_uint64.go
new file mode 100644
index 000000000000..c0999bfd57b0
--- /dev/null
+++ b/orm/auto_uint64.go
@@ -0,0 +1,130 @@
+package orm
+
+import (
+	"github.com/cosmos/cosmos-sdk/codec"
+	sdk "github.com/cosmos/cosmos-sdk/types"
+)
+
+var _ Indexable = &AutoUInt64TableBuilder{}
+
+// NewAutoUInt64TableBuilder creates a builder to set up an AutoUInt64Table object.
+func NewAutoUInt64TableBuilder(prefixData byte, prefixSeq byte, storeKey sdk.StoreKey, model codec.ProtoMarshaler, cdc codec.Codec) *AutoUInt64TableBuilder {
+	if prefixData == prefixSeq {
+		panic("prefixData and prefixSeq must be unique")
+	}
+
+	uInt64KeyCodec := FixLengthIndexKeys(EncodedSeqLength)
+	return &AutoUInt64TableBuilder{
+		TableBuilder: NewTableBuilder(prefixData, storeKey, model, uInt64KeyCodec, cdc),
+		seq:          NewSequence(storeKey, prefixSeq),
+	}
+}
+
+type AutoUInt64TableBuilder struct {
+	*TableBuilder
+	seq Sequence
+}
+
+// Build creates the AutoUInt64Table object.
+func (a AutoUInt64TableBuilder) Build() AutoUInt64Table {
+	return AutoUInt64Table{
+		table: a.TableBuilder.Build(),
+		seq:   a.seq,
+	}
+}
+
+var _ SequenceExportable = &AutoUInt64Table{}
+var _ TableExportable = &AutoUInt64Table{}
+
+// AutoUInt64Table is the table type with an auto-incrementing uint64 ID.
+type AutoUInt64Table struct {
+	table Table
+	seq   Sequence
+}
+
+// Create stores a new persistent object with an auto-generated uint64 primary key. The key is returned.
+// Create iterates through the registered callbacks that may add secondary index keys.
+func (a AutoUInt64Table) Create(ctx HasKVStore, obj codec.ProtoMarshaler) (uint64, error) {
+	autoIncID := a.seq.NextVal(ctx)
+	err := a.table.Create(ctx, EncodeSequence(autoIncID), obj)
+	if err != nil {
+		return 0, err
+	}
+	return autoIncID, nil
+}
+
+// Save updates the given object under the rowID key. It expects the key to exist already
+// and fails with an `ErrNotFound` otherwise. Any caller must therefore make sure that this contract
+// is fulfilled. Parameters must not be nil.
+//
+// Save iterates through the registered callbacks that may add or remove secondary index keys.
+func (a AutoUInt64Table) Save(ctx HasKVStore, rowID uint64, newValue codec.ProtoMarshaler) error {
+	return a.table.Save(ctx, EncodeSequence(rowID), newValue)
+}
+
+// Delete removes the object under the rowID key. It expects the key to exist already
+// and fails with an `ErrNotFound` otherwise. Any caller must therefore make sure that this contract
+// is fulfilled.
+//
+// Delete iterates through the registered callbacks that remove secondary index keys.
+func (a AutoUInt64Table) Delete(ctx HasKVStore, rowID uint64) error {
+	return a.table.Delete(ctx, EncodeSequence(rowID))
+}
+
+// Has checks if a rowID exists.
+func (a AutoUInt64Table) Has(ctx HasKVStore, rowID uint64) bool {
+	return a.table.Has(ctx, EncodeSequence(rowID))
+}
+
+// GetOne loads the object persisted for the given RowID into the dest parameter.
+// If none exists, `ErrNotFound` is returned instead. Parameters must not be nil.
+func (a AutoUInt64Table) GetOne(ctx HasKVStore, rowID uint64, dest codec.ProtoMarshaler) (RowID, error) {
+	rawRowID := EncodeSequence(rowID)
+	if err := a.table.GetOne(ctx, rawRowID, dest); err != nil {
+		return nil, err
+	}
+	return rawRowID, nil
+}
+
+// PrefixScan returns an Iterator over a domain of keys in ascending order. End is exclusive.
+// Start and end are rowIDs; start must be less than end, or the Iterator is invalid and an error is returned.
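+//
+// A minimal sketch (tb and ctx are the table and mock context from
+// auto_uint64_test.go; end is exclusive):
+//
+//	it, err := tb.PrefixScan(ctx, 1, 4) // visits rowIDs 1, 2 and 3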
+// The Iterator must be closed by the caller.
+// To iterate over the entire domain, use PrefixScan(0, math.MaxUint64).
+//
+// WARNING: The use of a PrefixScan can be very expensive in terms of Gas. Please make sure you do not expose
+// this as an endpoint to the public without further limits.
+// Example:
+//	it, err := idx.PrefixScan(ctx, start, end)
+//	if err != nil {
+//		return err
+//	}
+//	const defaultLimit = 20
+//	it = LimitIterator(it, defaultLimit)
+//
+// CONTRACT: No writes may happen within a domain while an iterator exists over it.
+func (a AutoUInt64Table) PrefixScan(ctx HasKVStore, start, end uint64) (Iterator, error) {
+	return a.table.PrefixScan(ctx, EncodeSequence(start), EncodeSequence(end))
+}
+
+// ReversePrefixScan returns an Iterator over a domain of keys in descending order. End is exclusive.
+// Start and end are rowIDs; start must be less than end, or the Iterator is invalid and an error is returned.
+// The Iterator must be closed by the caller.
+// To iterate over the entire domain, use ReversePrefixScan(0, math.MaxUint64).
+//
+// WARNING: The use of a ReversePrefixScan can be very expensive in terms of Gas. Please make sure you do not expose
+// this as an endpoint to the public without further limits. See `LimitIterator`.
+//
+// CONTRACT: No writes may happen within a domain while an iterator exists over it.
+func (a AutoUInt64Table) ReversePrefixScan(ctx HasKVStore, start uint64, end uint64) (Iterator, error) {
+	return a.table.ReversePrefixScan(ctx, EncodeSequence(start), EncodeSequence(end))
+}
+
+// Sequence returns the sequence used by this table.
+func (a AutoUInt64Table) Sequence() Sequence {
+	return a.seq
+}
+
+// Table satisfies the TableExportable interface and must not be used otherwise.
+func (a AutoUInt64Table) Table() Table {
+	return a.table
+}
diff --git a/orm/auto_uint64_test.go b/orm/auto_uint64_test.go
new file mode 100644
index 000000000000..3b55cccfae33
--- /dev/null
+++ b/orm/auto_uint64_test.go
@@ -0,0 +1,163 @@
+package orm_test
+
+import (
+	"math"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/cosmos/cosmos-sdk/codec"
+	"github.com/cosmos/cosmos-sdk/codec/types"
+	"github.com/cosmos/cosmos-sdk/orm"
+	"github.com/cosmos/cosmos-sdk/orm/testdata"
+	sdk "github.com/cosmos/cosmos-sdk/types"
+	"github.com/cosmos/cosmos-sdk/types/errors"
+)
+
+func TestAutoUInt64PrefixScan(t *testing.T) {
+	interfaceRegistry := types.NewInterfaceRegistry()
+	cdc := codec.NewProtoCodec(interfaceRegistry)
+
+	storeKey := sdk.NewKVStoreKey("test")
+	const (
+		testTablePrefix = iota
+		testTableSeqPrefix
+	)
+	tb := orm.NewAutoUInt64TableBuilder(testTablePrefix, testTableSeqPrefix, storeKey, &testdata.GroupInfo{}, cdc).Build()
+	ctx := orm.NewMockContext()
+
+	g1 := testdata.GroupInfo{
+		Description: "my test 1",
+		Admin:       sdk.AccAddress([]byte("admin-address")),
+	}
+	g2 := testdata.GroupInfo{
+		Description: "my test 2",
+		Admin:       sdk.AccAddress([]byte("admin-address")),
+	}
+	g3 := testdata.GroupInfo{
+		Description: "my test 3",
+		Admin:       sdk.AccAddress([]byte("admin-address")),
+	}
+	for _, g := range []testdata.GroupInfo{g1, g2, g3} {
+		_, err := tb.Create(ctx, &g)
+		require.NoError(t, err)
+	}
+
+	specs := map[string]struct {
+		start, end uint64
+		expResult  []testdata.GroupInfo
+		expRowIDs  []orm.RowID
+		expError   *errors.Error
+		method     func(ctx orm.HasKVStore, start uint64, end uint64) (orm.Iterator, error)
+	}{
+		"first element": {
+			start:     1,
+			end:       2,
+			method:    tb.PrefixScan,
+			expResult: []testdata.GroupInfo{g1},
+			expRowIDs:
[]orm.RowID{orm.EncodeSequence(1)}, + }, + "first 2 elements": { + start: 1, + end: 3, + method: tb.PrefixScan, + expResult: []testdata.GroupInfo{g1, g2}, + expRowIDs: []orm.RowID{orm.EncodeSequence(1), orm.EncodeSequence(2)}, + }, + "first 3 elements": { + start: 1, + end: 4, + method: tb.PrefixScan, + expResult: []testdata.GroupInfo{g1, g2, g3}, + expRowIDs: []orm.RowID{orm.EncodeSequence(1), orm.EncodeSequence(2), orm.EncodeSequence(3)}, + }, + "search with max end": { + start: 1, + end: math.MaxUint64, + method: tb.PrefixScan, + expResult: []testdata.GroupInfo{g1, g2, g3}, + expRowIDs: []orm.RowID{orm.EncodeSequence(1), orm.EncodeSequence(2), orm.EncodeSequence(3)}, + }, + "2 to end": { + start: 2, + end: 5, + method: tb.PrefixScan, + expResult: []testdata.GroupInfo{g2, g3}, + expRowIDs: []orm.RowID{orm.EncodeSequence(2), orm.EncodeSequence(3)}, + }, + "start before end should fail": { + start: 2, + end: 1, + method: tb.PrefixScan, + expError: orm.ErrArgument, + }, + "start equals end should fail": { + start: 1, + end: 1, + method: tb.PrefixScan, + expError: orm.ErrArgument, + }, + "reverse first element": { + start: 1, + end: 2, + method: tb.ReversePrefixScan, + expResult: []testdata.GroupInfo{g1}, + expRowIDs: []orm.RowID{orm.EncodeSequence(1)}, + }, + "reverse first 2 elements": { + start: 1, + end: 3, + method: tb.ReversePrefixScan, + expResult: []testdata.GroupInfo{g2, g1}, + expRowIDs: []orm.RowID{orm.EncodeSequence(2), orm.EncodeSequence(1)}, + }, + "reverse first 3 elements": { + start: 1, + end: 4, + method: tb.ReversePrefixScan, + expResult: []testdata.GroupInfo{g3, g2, g1}, + expRowIDs: []orm.RowID{orm.EncodeSequence(3), orm.EncodeSequence(2), orm.EncodeSequence(1)}, + }, + "reverse search with max end": { + start: 1, + end: math.MaxUint64, + method: tb.ReversePrefixScan, + expResult: []testdata.GroupInfo{g3, g2, g1}, + expRowIDs: []orm.RowID{orm.EncodeSequence(3), orm.EncodeSequence(2), orm.EncodeSequence(1)}, + }, + "reverse 2 to end": { + start: 2, + end: 5, + method: tb.ReversePrefixScan, + expResult: []testdata.GroupInfo{g3, g2}, + expRowIDs: []orm.RowID{orm.EncodeSequence(3), orm.EncodeSequence(2)}, + }, + "reverse start before end should fail": { + start: 2, + end: 1, + method: tb.ReversePrefixScan, + expError: orm.ErrArgument, + }, + "reverse start equals end should fail": { + start: 1, + end: 1, + method: tb.ReversePrefixScan, + expError: orm.ErrArgument, + }, + } + for msg, spec := range specs { + t.Run(msg, func(t *testing.T) { + it, err := spec.method(ctx, spec.start, spec.end) + require.True(t, spec.expError.Is(err), "expected #+v but got #+v", spec.expError, err) + if spec.expError != nil { + return + } + var loaded []testdata.GroupInfo + rowIDs, err := orm.ReadAll(it, &loaded) + require.NoError(t, err) + assert.Equal(t, spec.expResult, loaded) + assert.Equal(t, spec.expRowIDs, rowIDs) + }) + } +} diff --git a/orm/example_test.go b/orm/example_test.go new file mode 100644 index 000000000000..2821a803b2f6 --- /dev/null +++ b/orm/example_test.go @@ -0,0 +1,52 @@ +package orm_test + +import ( + "github.com/cosmos/cosmos-sdk/codec" + "github.com/cosmos/cosmos-sdk/orm" + "github.com/cosmos/cosmos-sdk/orm/testdata" + sdk "github.com/cosmos/cosmos-sdk/types" +) + +type GroupKeeper struct { + key sdk.StoreKey + groupTable orm.AutoUInt64Table + groupByAdminIndex orm.Index + groupMemberTable orm.PrimaryKeyTable + groupMemberByGroupIndex orm.Index + groupMemberByMemberIndex orm.Index +} + +var ( + GroupTablePrefix byte = 0x0 + GroupTableSeqPrefix byte = 0x1 + 
GroupByAdminIndexPrefix        byte = 0x2
+	GroupMemberTablePrefix         byte = 0x3
+	GroupMemberTableSeqPrefix      byte = 0x4
+	GroupMemberTableIndexPrefix    byte = 0x5
+	GroupMemberByGroupIndexPrefix  byte = 0x6
+	GroupMemberByMemberIndexPrefix byte = 0x7
+)
+
+func NewGroupKeeper(storeKey sdk.StoreKey, cdc codec.Codec) GroupKeeper {
+	k := GroupKeeper{key: storeKey}
+
+	groupTableBuilder := orm.NewAutoUInt64TableBuilder(GroupTablePrefix, GroupTableSeqPrefix, storeKey, &testdata.GroupInfo{}, cdc)
+	// note: it is easy to mix up index prefixes when they are managed outside the table; there is no fail-fast check for duplicates
+	k.groupByAdminIndex = orm.NewIndex(groupTableBuilder, GroupByAdminIndexPrefix, func(val interface{}) ([]orm.RowID, error) {
+		return []orm.RowID{[]byte(val.(*testdata.GroupInfo).Admin)}, nil
+	})
+	k.groupTable = groupTableBuilder.Build()
+
+	groupMemberTableBuilder := orm.NewPrimaryKeyTableBuilder(GroupMemberTablePrefix, storeKey, &testdata.GroupMember{}, orm.Max255DynamicLengthIndexKeyCodec{}, cdc)
+
+	k.groupMemberByGroupIndex = orm.NewIndex(groupMemberTableBuilder, GroupMemberByGroupIndexPrefix, func(val interface{}) ([]orm.RowID, error) {
+		group := val.(*testdata.GroupMember).Group
+		return []orm.RowID{[]byte(group)}, nil
+	})
+	k.groupMemberByMemberIndex = orm.NewIndex(groupMemberTableBuilder, GroupMemberByMemberIndexPrefix, func(val interface{}) ([]orm.RowID, error) {
+		return []orm.RowID{[]byte(val.(*testdata.GroupMember).Member)}, nil
+	})
+	k.groupMemberTable = groupMemberTableBuilder.Build()
+
+	return k
+}
diff --git a/orm/genesis.go b/orm/genesis.go
new file mode 100644
index 000000000000..b577c775ec6c
--- /dev/null
+++ b/orm/genesis.go
@@ -0,0 +1,88 @@
+package orm
+
+import (
+	"reflect"
+
+	"github.com/cosmos/cosmos-sdk/store/prefix"
+	"github.com/cosmos/cosmos-sdk/types/errors"
+)
+
+// TableExportable is implemented by tables whose data can be exported.
+type TableExportable interface {
+	// Table returns the table to export
+	Table() Table
+}
+
+// SequenceExportable is implemented by tables whose sequence state can be exported.
+type SequenceExportable interface {
+	// Sequence returns the sequence to export
+	Sequence() Sequence
+}
+
+// ExportTableData iterates over the given table entries and stores them at the passed ModelSlicePtr.
+// When the given table implements the `SequenceExportable` interface, its current value
+// is returned as well; otherwise it defaults to 0.
+func ExportTableData(ctx HasKVStore, t TableExportable, dest ModelSlicePtr) (uint64, error) {
+	it, err := t.Table().PrefixScan(ctx, nil, nil)
+	if err != nil {
+		return 0, errors.Wrap(err, "table PrefixScan failure when exporting table data")
+	}
+	_, err = ReadAll(it, dest)
+	if err != nil {
+		return 0, err
+	}
+	var seqValue uint64
+	if st, ok := t.(SequenceExportable); ok {
+		seqValue = st.Sequence().CurVal(ctx)
+	}
+	return seqValue, err
+}
+
+// ImportTableData initializes a table and attaches indexers from the given data interface{}.
+// data should be a slice of structs that implement PrimaryKeyed (e.g. []*GroupInfo).
+// The seqValue is optional and only used with tables that implement the `SequenceExportable` interface.
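+//
+// A genesis round-trip sketch (types and values mirror genesis_test.go):
+//
+//	var exported []*testdata.GroupInfo
+//	seq, err := ExportTableData(ctx, table, &exported) // export rows and sequence state
+//	...
+//	err = ImportTableData(ctx, table, exported, seq)   // re-import on the other side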
+func ImportTableData(ctx HasKVStore, t TableExportable, data interface{}, seqValue uint64) error { + table := t.Table() + if err := clearAllInTable(ctx, table); err != nil { + return errors.Wrap(err, "clear old entries") + } + + if st, ok := t.(SequenceExportable); ok { + if err := st.Sequence().InitVal(ctx, seqValue); err != nil { + return errors.Wrap(err, "sequence") + } + } + + // Provided data must be a slice + modelSlice := reflect.ValueOf(data) + if modelSlice.Kind() != reflect.Slice { + return errors.Wrap(ErrArgument, "data must be a slice") + } + + // Create table entries + for i := 0; i < modelSlice.Len(); i++ { + obj, ok := modelSlice.Index(i).Interface().(PrimaryKeyed) + if !ok { + return errors.Wrapf(ErrArgument, "unsupported type :%s", reflect.TypeOf(data).Elem().Elem()) + } + err := table.Create(ctx, obj.PrimaryKey(), obj) + if err != nil { + return err + } + } + + return nil +} + +// clearAllInTable deletes all entries in a table with delete interceptors called +func clearAllInTable(ctx HasKVStore, table Table) error { + store := prefix.NewStore(ctx.KVStore(table.storeKey), []byte{table.prefix}) + it := store.Iterator(nil, nil) + defer it.Close() + for ; it.Valid(); it.Next() { + if err := table.Delete(ctx, it.Key()); err != nil { + return err + } + } + return nil +} diff --git a/orm/genesis_test.go b/orm/genesis_test.go new file mode 100644 index 000000000000..817630b27e6e --- /dev/null +++ b/orm/genesis_test.go @@ -0,0 +1,55 @@ +package orm_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/cosmos/cosmos-sdk/codec" + "github.com/cosmos/cosmos-sdk/codec/types" + "github.com/cosmos/cosmos-sdk/orm" + "github.com/cosmos/cosmos-sdk/orm/testdata" + sdk "github.com/cosmos/cosmos-sdk/types" +) + +func TestImportExportTableData(t *testing.T) { + interfaceRegistry := types.NewInterfaceRegistry() + cdc := codec.NewProtoCodec(interfaceRegistry) + + storeKey := sdk.NewKVStoreKey("test") + const prefix = iota + table := orm.NewAutoUInt64TableBuilder(prefix, 0x1, storeKey, &testdata.GroupInfo{}, cdc).Build() + + ctx := orm.NewMockContext() + + groups := []*testdata.GroupInfo{ + { + GroupId: 1, + Admin: sdk.AccAddress([]byte("admin1-address")), + }, + { + GroupId: 2, + Admin: sdk.AccAddress([]byte("admin2-address")), + }, + } + + err := orm.ImportTableData(ctx, table, groups, 2) + require.NoError(t, err) + + for _, g := range groups { + var loaded testdata.GroupInfo + _, err := table.GetOne(ctx, g.GroupId, &loaded) + require.NoError(t, err) + + require.Equal(t, g, &loaded) + } + + var exported []*testdata.GroupInfo + seq, err := orm.ExportTableData(ctx, table, &exported) + require.NoError(t, err) + require.Equal(t, seq, uint64(2)) + + for i, g := range exported { + require.Equal(t, g, groups[i]) + } +} diff --git a/orm/index.go b/orm/index.go new file mode 100644 index 000000000000..c918f446b193 --- /dev/null +++ b/orm/index.go @@ -0,0 +1,219 @@ +package orm + +import ( + "bytes" + + "github.com/cosmos/cosmos-sdk/codec" + "github.com/cosmos/cosmos-sdk/store/prefix" + "github.com/cosmos/cosmos-sdk/store/types" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/errors" + "github.com/cosmos/cosmos-sdk/types/query" +) + +// indexer creates and modifies the second MultiKeyIndex based on the operations and changes on the primary object. 
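+// Indexers are wired in via NewIndex with an IndexerFunc; a sketch taken from
+// this PR's example_test.go, indexing groups by their admin address:
+//
+//	orm.NewIndex(groupTableBuilder, GroupByAdminIndexPrefix, func(val interface{}) ([]orm.RowID, error) {
+//		return []orm.RowID{[]byte(val.(*testdata.GroupInfo).Admin)}, nil
+//	})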
+type indexer interface {
+	OnCreate(store sdk.KVStore, rowID RowID, value interface{}) error
+	OnDelete(store sdk.KVStore, rowID RowID, value interface{}) error
+	OnUpdate(store sdk.KVStore, rowID RowID, newValue, oldValue interface{}) error
+}
+
+// MultiKeyIndex is an index where multiple entries can point to the same underlying object, as opposed to a unique index
+// where only one entry is allowed.
+type MultiKeyIndex struct {
+	storeKey      sdk.StoreKey
+	prefix        byte
+	rowGetter     RowGetter
+	indexer       indexer
+	indexKeyCodec IndexKeyCodec
+}
+
+// NewIndex builds a MultiKeyIndex.
+func NewIndex(builder Indexable, prefix byte, indexer IndexerFunc) MultiKeyIndex {
+	return newIndex(builder, prefix, NewIndexer(indexer, builder.IndexKeyCodec()))
+}
+
+func newIndex(builder Indexable, prefix byte, indexer *Indexer) MultiKeyIndex {
+	codec := builder.IndexKeyCodec()
+	if codec == nil {
+		panic("IndexKeyCodec must not be nil")
+	}
+	storeKey := builder.StoreKey()
+	if storeKey == nil {
+		panic("StoreKey must not be nil")
+	}
+	rowGetter := builder.RowGetter()
+	if rowGetter == nil {
+		panic("RowGetter must not be nil")
+	}
+
+	idx := MultiKeyIndex{
+		storeKey:      storeKey,
+		prefix:        prefix,
+		rowGetter:     rowGetter,
+		indexer:       indexer,
+		indexKeyCodec: codec,
+	}
+	builder.AddAfterSaveInterceptor(idx.onSave)
+	builder.AddAfterDeleteInterceptor(idx.onDelete)
+	return idx
+}
+
+// Has checks if a key exists. Panics on nil key.
+func (i MultiKeyIndex) Has(ctx HasKVStore, key []byte) bool {
+	store := prefix.NewStore(ctx.KVStore(i.storeKey), []byte{i.prefix})
+	it := store.Iterator(PrefixRange(key))
+	defer it.Close()
+	return it.Valid()
+}
+
+// Get returns a result iterator for the searchKey. Parameters must not be nil.
+func (i MultiKeyIndex) Get(ctx HasKVStore, searchKey []byte) (Iterator, error) {
+	store := prefix.NewStore(ctx.KVStore(i.storeKey), []byte{i.prefix})
+	it := store.Iterator(PrefixRange(searchKey))
+	return indexIterator{ctx: ctx, it: it, rowGetter: i.rowGetter, keyCodec: i.indexKeyCodec}, nil
+}
+
+// GetPaginated creates an iterator for the searchKey,
+// starting from pageRequest.Key if provided.
+// The pageRequest.Key is the rowID while searchKey is a MultiKeyIndex key.
+func (i MultiKeyIndex) GetPaginated(ctx HasKVStore, searchKey []byte, pageRequest *query.PageRequest) (Iterator, error) {
+	store := prefix.NewStore(ctx.KVStore(i.storeKey), []byte{i.prefix})
+	start, end := PrefixRange(searchKey)
+
+	if pageRequest != nil && len(pageRequest.Key) != 0 {
+		start = i.indexKeyCodec.BuildIndexKey(searchKey, RowID(pageRequest.Key))
+	}
+	it := store.Iterator(start, end)
+	return indexIterator{ctx: ctx, it: it, rowGetter: i.rowGetter, keyCodec: i.indexKeyCodec}, nil
+}
+
+// PrefixScan returns an Iterator over a domain of keys in ascending order. End is exclusive.
+// Start is a MultiKeyIndex key or prefix. It must be less than end, or the Iterator is invalid and an error is returned.
+// The Iterator must be closed by the caller.
+// To iterate over the entire domain, use PrefixScan(nil, nil).
+//
+// WARNING: The use of a PrefixScan can be very expensive in terms of Gas. Please make sure you do not expose
+// this as an endpoint to the public without further limits.
+// Example:
+//	it, err := idx.PrefixScan(ctx, start, end)
+//	if err != nil {
+//		return err
+//	}
+//	const defaultLimit = 20
+//	it = LimitIterator(it, defaultLimit)
+//
+// CONTRACT: No writes may happen within a domain while an iterator exists over it.
+func (i MultiKeyIndex) PrefixScan(ctx HasKVStore, start []byte, end []byte) (Iterator, error) { + if start != nil && end != nil && bytes.Compare(start, end) >= 0 { + return NewInvalidIterator(), errors.Wrap(ErrArgument, "start must be less than end") + } + store := prefix.NewStore(ctx.KVStore(i.storeKey), []byte{i.prefix}) + it := store.Iterator(start, end) + return indexIterator{ctx: ctx, it: it, rowGetter: i.rowGetter, keyCodec: i.indexKeyCodec}, nil +} + +// ReversePrefixScan returns an Iterator over a domain of keys in descending order. End is exclusive. +// Start is an MultiKeyIndex key or prefix. It must be less than end, or the Iterator is invalid and error is returned. +// Iterator must be closed by caller. +// To iterate over entire domain, use PrefixScan(nil, nil) +// +// WARNING: The use of a ReversePrefixScan can be very expensive in terms of Gas. Please make sure you do not expose +// this as an endpoint to the public without further limits. See `LimitIterator` +// +// CONTRACT: No writes may happen within a domain while an iterator exists over it. +func (i MultiKeyIndex) ReversePrefixScan(ctx HasKVStore, start []byte, end []byte) (Iterator, error) { + if start != nil && end != nil && bytes.Compare(start, end) >= 0 { + return NewInvalidIterator(), errors.Wrap(ErrArgument, "start must be less than end") + } + store := prefix.NewStore(ctx.KVStore(i.storeKey), []byte{i.prefix}) + it := store.ReverseIterator(start, end) + return indexIterator{ctx: ctx, it: it, rowGetter: i.rowGetter, keyCodec: i.indexKeyCodec}, nil +} + +func (i MultiKeyIndex) onSave(ctx HasKVStore, rowID RowID, newValue, oldValue codec.ProtoMarshaler) error { + store := prefix.NewStore(ctx.KVStore(i.storeKey), []byte{i.prefix}) + if oldValue == nil { + return i.indexer.OnCreate(store, rowID, newValue) + } + return i.indexer.OnUpdate(store, rowID, newValue, oldValue) +} + +func (i MultiKeyIndex) onDelete(ctx HasKVStore, rowID RowID, oldValue codec.ProtoMarshaler) error { + store := prefix.NewStore(ctx.KVStore(i.storeKey), []byte{i.prefix}) + return i.indexer.OnDelete(store, rowID, oldValue) +} + +type UniqueIndex struct { + MultiKeyIndex +} + +// NewUniqueIndex create a new Index object where duplicate keys are prohibited. +func NewUniqueIndex(builder Indexable, prefix byte, uniqueIndexerFunc UniqueIndexerFunc) UniqueIndex { + return UniqueIndex{ + MultiKeyIndex: newIndex(builder, prefix, NewUniqueIndexer(uniqueIndexerFunc, builder.IndexKeyCodec())), + } +} + +// indexIterator uses rowGetter to lazy load new model values on request. +type indexIterator struct { + ctx HasKVStore + rowGetter RowGetter + it types.Iterator + keyCodec IndexKeyCodec +} + +// LoadNext loads the next value in the sequence into the pointer passed as dest and returns the key. If there +// are no more items the ErrIteratorDone error is returned +// The key is the rowID and not any MultiKeyIndex key. +func (i indexIterator) LoadNext(dest codec.ProtoMarshaler) (RowID, error) { + if !i.it.Valid() { + return nil, ErrIteratorDone + } + indexPrefixKey := i.it.Key() + rowID := i.keyCodec.StripRowID(indexPrefixKey) + i.it.Next() + return rowID, i.rowGetter(i.ctx, rowID, dest) +} + +// Close releases the iterator and should be called at the end of iteration +func (i indexIterator) Close() error { + i.it.Close() + return nil +} + +// PrefixRange turns a prefix into a (start, end) range. The start is the given prefix value and +// the end is calculated by adding 1 bit to the start value. Nil is not allowed as prefix. 
+// Example: []byte{1, 3, 4} becomes []byte{1, 3, 5} +// []byte{15, 42, 255, 255} becomes []byte{15, 43, 0, 0} +// +// In case of an overflow the end is set to nil. +// Example: []byte{255, 255, 255, 255} becomes nil +// +func PrefixRange(prefix []byte) ([]byte, []byte) { + if prefix == nil { + panic("nil key not allowed") + } + // special case: no prefix is whole range + if len(prefix) == 0 { + return nil, nil + } + + // copy the prefix and update last byte + end := make([]byte, len(prefix)) + copy(end, prefix) + l := len(end) - 1 + end[l]++ + + // wait, what if that overflowed?.... + for end[l] == 0 && l > 0 { + l-- + end[l]++ + } + + // okay, funny guy, you gave us FFF, no end to this range... + if l == 0 && end[0] == 0 { + end = nil + } + return prefix, end +} diff --git a/orm/index_key_codec.go b/orm/index_key_codec.go new file mode 100644 index 000000000000..8ef01ab04181 --- /dev/null +++ b/orm/index_key_codec.go @@ -0,0 +1,68 @@ +package orm + +// Max255DynamicLengthIndexKeyCodec works with up to 255 byte dynamic size RowIDs. +// They are encoded as `concat(searchableKey, rowID, len(rowID)[0])` and can be used +// with PrimaryKey or external Key tables for example. +type Max255DynamicLengthIndexKeyCodec struct{} + +// BuildIndexKey builds the index key by appending searchableKey with rowID and length int. +// The RowID length must not be greater than 255. +func (Max255DynamicLengthIndexKeyCodec) BuildIndexKey(searchableKey []byte, rowID RowID) []byte { + rowIDLen := len(rowID) + switch { + case rowIDLen == 0: + panic("Empty RowID") + case rowIDLen > 255: + panic("RowID exceeds max size") + } + + searchableKeyLen := len(searchableKey) + res := make([]byte, searchableKeyLen+rowIDLen+1) + copy(res, searchableKey) + copy(res[searchableKeyLen:], rowID) + res[searchableKeyLen+rowIDLen] = byte(rowIDLen) + return res +} + +// StripRowID returns the RowID from the combined persistentIndexKey. It is the reverse operation to BuildIndexKey +// but with the searchableKey and length int dropped. +func (Max255DynamicLengthIndexKeyCodec) StripRowID(persistentIndexKey []byte) RowID { + n := len(persistentIndexKey) + searchableKeyLen := persistentIndexKey[n-1] + return persistentIndexKey[n-int(searchableKeyLen)-1 : n-1] +} + +// FixLengthIndexKeyCodec expects the RowID to always have the same length with all entries. +// They are encoded as `concat(searchableKey, rowID)` and can be used +// with AutoUint64Tables and length EncodedSeqLength for example. +type FixLengthIndexKeyCodec struct { + rowIDLength int +} + +// FixLengthIndexKeys is a constructor for FixLengthIndexKeyCodec. +func FixLengthIndexKeys(rowIDLength int) *FixLengthIndexKeyCodec { + return &FixLengthIndexKeyCodec{rowIDLength: rowIDLength} +} + +// BuildIndexKey builds the index key by appending searchableKey with rowID. +// The RowID length must not be greater than what is defined by rowIDLength in construction. +func (c FixLengthIndexKeyCodec) BuildIndexKey(searchableKey []byte, rowID RowID) []byte { + switch n := len(rowID); { + case n == 0: + panic("Empty RowID") + case n > c.rowIDLength: + panic("RowID exceeds max size") + } + n := len(searchableKey) + res := make([]byte, n+c.rowIDLength) + copy(res, searchableKey) + copy(res[n:], rowID) + return res +} + +// StripRowID returns the RowID from the combined persistentIndexKey. It is the reverse operation to BuildIndexKey +// but with the searchableKey dropped. 
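+//
+// A round-trip sketch for the fixed-length codec (rowIDLength = 8, as used
+// with AutoUInt64 tables):
+//
+//	c := FixLengthIndexKeys(8)
+//	key := c.BuildIndexKey([]byte("admin"), EncodeSequence(7)) // concat(searchableKey, rowID)
+//	rowID := c.StripRowID(key)                                 // == EncodeSequence(7)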
+func (c FixLengthIndexKeyCodec) StripRowID(persistentIndexKey []byte) RowID { + n := len(persistentIndexKey) + return persistentIndexKey[n-c.rowIDLength:] +} diff --git a/orm/index_key_codec_test.go b/orm/index_key_codec_test.go new file mode 100644 index 000000000000..d7f49df545f2 --- /dev/null +++ b/orm/index_key_codec_test.go @@ -0,0 +1,115 @@ +package orm + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEncodeIndexKey(t *testing.T) { + specs := map[string]struct { + srcKey []byte + srcRowID RowID + enc IndexKeyCodec + expKey []byte + expPanic bool + }{ + "dynamic length example 1": { + srcKey: []byte{0x0, 0x1, 0x2}, + srcRowID: []byte{0x3, 0x4}, + enc: Max255DynamicLengthIndexKeyCodec{}, + expKey: []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x2}, + }, + "dynamic length example 2": { + srcKey: []byte{0x0, 0x1}, + srcRowID: []byte{0x2, 0x3, 0x4}, + enc: Max255DynamicLengthIndexKeyCodec{}, + expKey: []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x3}, + }, + "dynamic length max row ID": { + srcKey: []byte{0x0, 0x1}, + srcRowID: []byte(strings.Repeat("a", 255)), + enc: Max255DynamicLengthIndexKeyCodec{}, + expKey: append(append([]byte{0x0, 0x1}, []byte(strings.Repeat("a", 255))...), 0xff), + }, + "dynamic length panics with empty rowID": { + srcKey: []byte{0x0, 0x1}, + srcRowID: []byte{}, + enc: Max255DynamicLengthIndexKeyCodec{}, + expPanic: true, + }, + "dynamic length exceeds max row ID": { + srcKey: []byte{0x0, 0x1}, + srcRowID: []byte(strings.Repeat("a", 256)), + enc: Max255DynamicLengthIndexKeyCodec{}, + expPanic: true, + }, + "uint64 example": { + srcKey: []byte{0x0, 0x1, 0x2}, + srcRowID: []byte{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8}, + enc: FixLengthIndexKeys(8), + expKey: []byte{0x0, 0x1, 0x2, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8}, + }, + "uint64 panics with empty rowID": { + srcKey: []byte{0x0, 0x1}, + srcRowID: []byte{}, + enc: FixLengthIndexKeys(8), + expPanic: true, + }, + "uint64 exceeds max bytes in rowID": { + srcKey: []byte{0x0, 0x1}, + srcRowID: []byte{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9}, + enc: FixLengthIndexKeys(8), + expPanic: true, + }, + } + for msg, spec := range specs { + t.Run(msg, func(t *testing.T) { + if spec.expPanic { + require.Panics(t, + func() { + _ = spec.enc.BuildIndexKey(spec.srcKey, spec.srcRowID) + }) + return + } + got := spec.enc.BuildIndexKey(spec.srcKey, spec.srcRowID) + assert.Equal(t, spec.expKey, got) + }) + } +} +func TestDecodeIndexKey(t *testing.T) { + specs := map[string]struct { + srcKey []byte + enc IndexKeyCodec + expRowID RowID + }{ + "dynamic length example 1": { + srcKey: []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x2}, + enc: Max255DynamicLengthIndexKeyCodec{}, + expRowID: []byte{0x3, 0x4}, + }, + "dynamic length example 2": { + srcKey: []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x3}, + enc: Max255DynamicLengthIndexKeyCodec{}, + expRowID: []byte{0x2, 0x3, 0x4}, + }, + "dynamic length max row ID": { + srcKey: append(append([]byte{0x0, 0x1}, []byte(strings.Repeat("a", 255))...), 0xff), + enc: Max255DynamicLengthIndexKeyCodec{}, + expRowID: []byte(strings.Repeat("a", 255)), + }, + "uint64 example": { + srcKey: []byte{0x0, 0x1, 0x2, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8}, + expRowID: []byte{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8}, + enc: FixLengthIndexKeys(8), + }, + } + for msg, spec := range specs { + t.Run(msg, func(t *testing.T) { + gotRow := spec.enc.StripRowID(spec.srcKey) + assert.Equal(t, spec.expRowID, gotRow) + }) + } +} diff --git a/orm/index_test.go b/orm/index_test.go new file 
mode 100644 index 000000000000..604d0c3637c4 --- /dev/null +++ b/orm/index_test.go @@ -0,0 +1,347 @@ +package orm_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cosmos/cosmos-sdk/codec" + "github.com/cosmos/cosmos-sdk/codec/types" + "github.com/cosmos/cosmos-sdk/orm" + "github.com/cosmos/cosmos-sdk/orm/testdata" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/errors" + "github.com/cosmos/cosmos-sdk/types/query" +) + +func TestIndexPrefixScan(t *testing.T) { + interfaceRegistry := types.NewInterfaceRegistry() + cdc := codec.NewProtoCodec(interfaceRegistry) + storeKey := sdk.NewKVStoreKey("test") + const ( + testTablePrefix = iota + testTableSeqPrefix + ) + tBuilder := orm.NewAutoUInt64TableBuilder(testTablePrefix, testTableSeqPrefix, storeKey, &testdata.GroupInfo{}, cdc) + idx := orm.NewIndex(tBuilder, GroupByAdminIndexPrefix, func(val interface{}) ([]orm.RowID, error) { + return []orm.RowID{[]byte(val.(*testdata.GroupInfo).Admin)}, nil + }) + tb := tBuilder.Build() + ctx := orm.NewMockContext() + + g1 := testdata.GroupInfo{ + Description: "my test 1", + Admin: sdk.AccAddress([]byte("admin-address-a")), + } + g2 := testdata.GroupInfo{ + Description: "my test 2", + Admin: sdk.AccAddress([]byte("admin-address-b")), + } + g3 := testdata.GroupInfo{ + Description: "my test 3", + Admin: sdk.AccAddress([]byte("admin-address-b")), + } + for _, g := range []testdata.GroupInfo{g1, g2, g3} { + _, err := tb.Create(ctx, &g) + require.NoError(t, err) + } + + specs := map[string]struct { + start, end []byte + expResult []testdata.GroupInfo + expRowIDs []orm.RowID + expError *errors.Error + method func(ctx orm.HasKVStore, start, end []byte) (orm.Iterator, error) + }{ + "exact match with a single result": { + start: []byte("admin-address-a"), + end: []byte("admin-address-b"), + method: idx.PrefixScan, + expResult: []testdata.GroupInfo{g1}, + expRowIDs: []orm.RowID{orm.EncodeSequence(1)}, + }, + "one result by prefix": { + start: []byte("admin-address"), + end: []byte("admin-address-b"), + method: idx.PrefixScan, + expResult: []testdata.GroupInfo{g1}, + expRowIDs: []orm.RowID{orm.EncodeSequence(1)}, + }, + "multi key elements by exact match": { + start: []byte("admin-address-b"), + end: []byte("admin-address-c"), + method: idx.PrefixScan, + expResult: []testdata.GroupInfo{g2, g3}, + expRowIDs: []orm.RowID{orm.EncodeSequence(2), orm.EncodeSequence(3)}, + }, + "open end query": { + start: []byte("admin-address-b"), + end: nil, + method: idx.PrefixScan, + expResult: []testdata.GroupInfo{g2, g3}, + expRowIDs: []orm.RowID{orm.EncodeSequence(2), orm.EncodeSequence(3)}, + }, + "open start query": { + start: nil, + end: []byte("admin-address-b"), + method: idx.PrefixScan, + expResult: []testdata.GroupInfo{g1}, + expRowIDs: []orm.RowID{orm.EncodeSequence(1)}, + }, + "open start and end query": { + start: nil, + end: nil, + method: idx.PrefixScan, + expResult: []testdata.GroupInfo{g1, g2, g3}, + expRowIDs: []orm.RowID{orm.EncodeSequence(1), orm.EncodeSequence(2), orm.EncodeSequence(3)}, + }, + "all matching prefix": { + start: []byte("admin"), + end: nil, + method: idx.PrefixScan, + expResult: []testdata.GroupInfo{g1, g2, g3}, + expRowIDs: []orm.RowID{orm.EncodeSequence(1), orm.EncodeSequence(2), orm.EncodeSequence(3)}, + }, + "non matching prefix": { + start: []byte("nobody"), + end: nil, + method: idx.PrefixScan, + expResult: []testdata.GroupInfo{}, + }, + "start equals end": { + start: []byte("any"), + 
end: []byte("any"), + method: idx.PrefixScan, + expError: orm.ErrArgument, + }, + "start after end": { + start: []byte("b"), + end: []byte("a"), + method: idx.PrefixScan, + expError: orm.ErrArgument, + }, + "reverse: exact match with a single result": { + start: []byte("admin-address-a"), + end: []byte("admin-address-b"), + method: idx.ReversePrefixScan, + expResult: []testdata.GroupInfo{g1}, + expRowIDs: []orm.RowID{orm.EncodeSequence(1)}, + }, + "reverse: one result by prefix": { + start: []byte("admin-address"), + end: []byte("admin-address-b"), + method: idx.ReversePrefixScan, + expResult: []testdata.GroupInfo{g1}, + expRowIDs: []orm.RowID{orm.EncodeSequence(1)}, + }, + "reverse: multi key elements by exact match": { + start: []byte("admin-address-b"), + end: []byte("admin-address-c"), + method: idx.ReversePrefixScan, + expResult: []testdata.GroupInfo{g3, g2}, + expRowIDs: []orm.RowID{orm.EncodeSequence(3), orm.EncodeSequence(2)}, + }, + "reverse: open end query": { + start: []byte("admin-address-b"), + end: nil, + method: idx.ReversePrefixScan, + expResult: []testdata.GroupInfo{g3, g2}, + expRowIDs: []orm.RowID{orm.EncodeSequence(3), orm.EncodeSequence(2)}, + }, + "reverse: open start query": { + start: nil, + end: []byte("admin-address-b"), + method: idx.ReversePrefixScan, + expResult: []testdata.GroupInfo{g1}, + expRowIDs: []orm.RowID{orm.EncodeSequence(1)}, + }, + "reverse: open start and end query": { + start: nil, + end: nil, + method: idx.ReversePrefixScan, + expResult: []testdata.GroupInfo{g3, g2, g1}, + expRowIDs: []orm.RowID{orm.EncodeSequence(3), orm.EncodeSequence(2), orm.EncodeSequence(1)}, + }, + "reverse: all matching prefix": { + start: []byte("admin"), + end: nil, + method: idx.ReversePrefixScan, + expResult: []testdata.GroupInfo{g3, g2, g1}, + expRowIDs: []orm.RowID{orm.EncodeSequence(3), orm.EncodeSequence(2), orm.EncodeSequence(1)}, + }, + "reverse: non matching prefix": { + start: []byte("nobody"), + end: nil, + method: idx.ReversePrefixScan, + expResult: []testdata.GroupInfo{}, + }, + "reverse: start equals end": { + start: []byte("any"), + end: []byte("any"), + method: idx.ReversePrefixScan, + expError: orm.ErrArgument, + }, + "reverse: start after end": { + start: []byte("b"), + end: []byte("a"), + method: idx.ReversePrefixScan, + expError: orm.ErrArgument, + }, + } + for msg, spec := range specs { + t.Run(msg, func(t *testing.T) { + it, err := spec.method(ctx, spec.start, spec.end) + require.True(t, spec.expError.Is(err), "expected #+v but got #+v", spec.expError, err) + if spec.expError != nil { + return + } + var loaded []testdata.GroupInfo + rowIDs, err := orm.ReadAll(it, &loaded) + require.NoError(t, err) + assert.Equal(t, spec.expResult, loaded) + assert.Equal(t, spec.expRowIDs, rowIDs) + }) + } +} + +func TestUniqueIndex(t *testing.T) { + interfaceRegistry := types.NewInterfaceRegistry() + cdc := codec.NewProtoCodec(interfaceRegistry) + + storeKey := sdk.NewKVStoreKey("test") + + tableBuilder := orm.NewPrimaryKeyTableBuilder(GroupMemberTablePrefix, storeKey, &testdata.GroupMember{}, orm.Max255DynamicLengthIndexKeyCodec{}, cdc) + uniqueIdx := orm.NewUniqueIndex(tableBuilder, 0x10, func(val interface{}) (orm.RowID, error) { + return []byte{val.(*testdata.GroupMember).Member[0]}, nil + }) + myTable := tableBuilder.Build() + + ctx := orm.NewMockContext() + + m := testdata.GroupMember{ + Group: sdk.AccAddress(orm.EncodeSequence(1)), + Member: sdk.AccAddress([]byte("member-address")), + Weight: 10, + } + err := myTable.Create(ctx, &m) + require.NoError(t, err) 
+ + indexedKey := []byte{byte('m')} + + // Has + assert.True(t, uniqueIdx.Has(ctx, indexedKey)) + + // Get + it, err := uniqueIdx.Get(ctx, indexedKey) + require.NoError(t, err) + var loaded testdata.GroupMember + rowID, err := it.LoadNext(&loaded) + require.NoError(t, err) + require.Equal(t, orm.RowID(m.PrimaryKey()), rowID) + require.Equal(t, m, loaded) + + // GetPaginated + cases := map[string]struct { + pageReq *query.PageRequest + expErr bool + }{ + "nil key": { + pageReq: &query.PageRequest{Key: nil}, + expErr: false, + }, + "after indexed key": { + pageReq: &query.PageRequest{Key: indexedKey}, + expErr: true, + }, + } + + for testName, tc := range cases { + t.Run(testName, func(t *testing.T) { + it, err := uniqueIdx.GetPaginated(ctx, indexedKey, tc.pageReq) + require.NoError(t, err) + rowID, err := it.LoadNext(&loaded) + if tc.expErr { // iterator done + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, orm.RowID(m.PrimaryKey()), rowID) + require.Equal(t, m, loaded) + } + }) + } + + // PrefixScan match + it, err = uniqueIdx.PrefixScan(ctx, []byte{byte('m')}, []byte{byte('n')}) + require.NoError(t, err) + rowID, err = it.LoadNext(&loaded) + require.NoError(t, err) + require.Equal(t, orm.RowID(m.PrimaryKey()), rowID) + require.Equal(t, m, loaded) + + // PrefixScan no match + it, err = uniqueIdx.PrefixScan(ctx, []byte{byte('n')}, nil) + require.NoError(t, err) + rowID, err = it.LoadNext(&loaded) + require.Error(t, orm.ErrIteratorDone, err) + + // ReversePrefixScan match + it, err = uniqueIdx.ReversePrefixScan(ctx, []byte{byte('a')}, []byte{byte('z')}) + require.NoError(t, err) + rowID, err = it.LoadNext(&loaded) + require.NoError(t, err) + require.Equal(t, orm.RowID(m.PrimaryKey()), rowID) + require.Equal(t, m, loaded) + + // ReversePrefixScan no match + it, err = uniqueIdx.ReversePrefixScan(ctx, []byte{byte('l')}, nil) + require.NoError(t, err) + rowID, err = it.LoadNext(&loaded) + require.Error(t, orm.ErrIteratorDone, err) + // create with same index key should fail + new := testdata.GroupMember{ + Group: sdk.AccAddress(orm.EncodeSequence(1)), + Member: sdk.AccAddress([]byte("my-other")), + Weight: 10, + } + err = myTable.Create(ctx, &new) + require.Error(t, orm.ErrUniqueConstraint, err) + + // and when delete + err = myTable.Delete(ctx, &m) + require.NoError(t, err) + + // then no persistent element + assert.False(t, uniqueIdx.Has(ctx, indexedKey)) +} + +func TestPrefixRange(t *testing.T) { + cases := map[string]struct { + src []byte + expStart []byte + expEnd []byte + expPanic bool + }{ + "normal": {src: []byte{1, 3, 4}, expStart: []byte{1, 3, 4}, expEnd: []byte{1, 3, 5}}, + "normal short": {src: []byte{79}, expStart: []byte{79}, expEnd: []byte{80}}, + "empty case": {src: []byte{}}, + "roll-over example 1": {src: []byte{17, 28, 255}, expStart: []byte{17, 28, 255}, expEnd: []byte{17, 29, 0}}, + "roll-over example 2": {src: []byte{15, 42, 255, 255}, expStart: []byte{15, 42, 255, 255}, expEnd: []byte{15, 43, 0, 0}}, + "pathological roll-over": {src: []byte{255, 255, 255, 255}, expStart: []byte{255, 255, 255, 255}}, + "nil prohibited": {expPanic: true}, + } + + for testName, tc := range cases { + t.Run(testName, func(t *testing.T) { + if tc.expPanic { + require.Panics(t, func() { + orm.PrefixRange(tc.src) + }) + return + } + start, end := orm.PrefixRange(tc.src) + assert.Equal(t, tc.expStart, start) + assert.Equal(t, tc.expEnd, end) + }) + } +} diff --git a/orm/indexer.go b/orm/indexer.go new file mode 100644 index 000000000000..b02307b353ed --- /dev/null +++ 
b/orm/indexer.go
@@ -0,0 +1,158 @@
+package orm
+
+import (
+	sdk "github.com/cosmos/cosmos-sdk/types"
+	"github.com/cosmos/cosmos-sdk/types/errors"
+)
+
+// IndexerFunc creates one or multiple index keys for the source object.
+type IndexerFunc func(value interface{}) ([]RowID, error)
+
+// UniqueIndexerFunc creates exactly one index key for the source object.
+type UniqueIndexerFunc func(value interface{}) (RowID, error)
+
+// Indexer manages the persistence for an Index based on searchable keys and operations.
+type Indexer struct {
+	indexerFunc   IndexerFunc
+	addFunc       func(store sdk.KVStore, codec IndexKeyCodec, secondaryIndexKey []byte, rowID RowID) error
+	indexKeyCodec IndexKeyCodec
+}
+
+// NewIndexer returns an indexer that supports multiple reference keys for an entity.
+func NewIndexer(indexerFunc IndexerFunc, codec IndexKeyCodec) *Indexer {
+	if indexerFunc == nil {
+		panic("Indexer func must not be nil")
+	}
+	if codec == nil {
+		panic("IndexKeyCodec must not be nil")
+	}
+	return &Indexer{
+		indexerFunc:   pruneEmptyKeys(indexerFunc),
+		addFunc:       multiKeyAddFunc,
+		indexKeyCodec: codec,
+	}
+}
+
+// NewUniqueIndexer returns an indexer that requires exactly one reference key for an entity.
+func NewUniqueIndexer(f UniqueIndexerFunc, codec IndexKeyCodec) *Indexer {
+	if f == nil {
+		panic("indexer func must not be nil")
+	}
+	adaptor := func(indexerFunc UniqueIndexerFunc) IndexerFunc {
+		return func(v interface{}) ([]RowID, error) {
+			k, err := indexerFunc(v)
+			return []RowID{k}, err
+		}
+	}
+	idx := NewIndexer(adaptor(f), codec)
+	idx.addFunc = uniqueKeysAddFunc
+	return idx
+}
+
+// OnCreate persists the secondary index entries for the new object.
+func (i Indexer) OnCreate(store sdk.KVStore, rowID RowID, value interface{}) error {
+	secondaryIndexKeys, err := i.indexerFunc(value)
+	if err != nil {
+		return err
+	}
+
+	for _, secondaryIndexKey := range secondaryIndexKeys {
+		if err := i.addFunc(store, i.indexKeyCodec, secondaryIndexKey, rowID); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// OnDelete removes the secondary index entries for the deleted object.
+func (i Indexer) OnDelete(store sdk.KVStore, rowID RowID, value interface{}) error {
+	secondaryIndexKeys, err := i.indexerFunc(value)
+	if err != nil {
+		return err
+	}
+
+	for _, secondaryIndexKey := range secondaryIndexKeys {
+		indexKey := i.indexKeyCodec.BuildIndexKey(secondaryIndexKey, rowID)
+		store.Delete(indexKey)
+	}
+	return nil
+}
+
+// OnUpdate rebuilds the secondary index entries for the updated object.
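+// Only the delta between the old and the new keys is written: index keys that
+// the old value produced but the new one does not are deleted, and keys new to
+// the updated value are added (see difference below).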
+func (i Indexer) OnUpdate(store sdk.KVStore, rowID RowID, newValue, oldValue interface{}) error { + oldSecIdxKeys, err := i.indexerFunc(oldValue) + if err != nil { + return err + } + newSecIdxKeys, err := i.indexerFunc(newValue) + if err != nil { + return err + } + for _, oldIdxKey := range difference(oldSecIdxKeys, newSecIdxKeys) { + store.Delete(i.indexKeyCodec.BuildIndexKey(oldIdxKey, rowID)) + } + for _, newIdxKey := range difference(newSecIdxKeys, oldSecIdxKeys) { + if err := i.addFunc(store, i.indexKeyCodec, newIdxKey, rowID); err != nil { + return err + } + } + return nil +} + +// uniqueKeysAddFunc enforces keys to be unique +func uniqueKeysAddFunc(store sdk.KVStore, codec IndexKeyCodec, secondaryIndexKey []byte, rowID RowID) error { + if len(secondaryIndexKey) == 0 { + return errors.Wrap(ErrArgument, "empty index key") + } + it := store.Iterator(PrefixRange(secondaryIndexKey)) + defer it.Close() + if it.Valid() { + return ErrUniqueConstraint + } + indexKey := codec.BuildIndexKey(secondaryIndexKey, rowID) + store.Set(indexKey, []byte{}) + return nil +} + +// multiKeyAddFunc allows multiple entries for a key +func multiKeyAddFunc(store sdk.KVStore, codec IndexKeyCodec, secondaryIndexKey []byte, rowID RowID) error { + if len(secondaryIndexKey) == 0 { + return errors.Wrap(ErrArgument, "empty index key") + } + + indexKey := codec.BuildIndexKey(secondaryIndexKey, rowID) + store.Set(indexKey, []byte{}) + return nil +} + +// difference returns the list of elements that are in a but not in b. +func difference(a []RowID, b []RowID) []RowID { + set := make(map[string]struct{}, len(b)) + for _, v := range b { + set[string(v)] = struct{}{} + } + var result []RowID + for _, v := range a { + if _, ok := set[string(v)]; !ok { + result = append(result, v) + } + } + return result +} + +// pruneEmptyKeys drops any empty key from IndexerFunc f returned +func pruneEmptyKeys(f IndexerFunc) IndexerFunc { + return func(v interface{}) ([]RowID, error) { + keys, err := f(v) + if err != nil || keys == nil { + return keys, err + } + r := make([]RowID, 0, len(keys)) + for i := range keys { + if len(keys[i]) != 0 { + r = append(r, keys[i]) + } + } + return r, nil + } +} diff --git a/orm/indexer_test.go b/orm/indexer_test.go new file mode 100644 index 000000000000..86c362673c06 --- /dev/null +++ b/orm/indexer_test.go @@ -0,0 +1,528 @@ +package orm + +import ( + stdErrors "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cosmos/cosmos-sdk/store/prefix" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/errors" +) + +func TestIndexerOnCreate(t *testing.T) { + var myRowID RowID = EncodeSequence(1) + + specs := map[string]struct { + srcFunc IndexerFunc + expIndexKeys []RowID + expRowIDs []RowID + expAddFuncCalled bool + expErr error + }{ + "single key": { + srcFunc: func(value interface{}) ([]RowID, error) { + return []RowID{{0, 0, 0, 0, 0, 0, 0, 1}}, nil + }, + expAddFuncCalled: true, + expIndexKeys: []RowID{{0, 0, 0, 0, 0, 0, 0, 1}}, + expRowIDs: []RowID{myRowID}, + }, + "multi key": { + srcFunc: func(value interface{}) ([]RowID, error) { + return []RowID{{0, 0, 0, 0, 0, 0, 0, 1}, {1, 0, 0, 0, 0, 0, 0, 0}}, nil + }, + expAddFuncCalled: true, + expIndexKeys: []RowID{{0, 0, 0, 0, 0, 0, 0, 1}, {1, 0, 0, 0, 0, 0, 0, 0}}, + expRowIDs: []RowID{myRowID, myRowID}, + }, + "empty key in slice": { + srcFunc: func(value interface{}) ([]RowID, error) { + return []RowID{{}}, nil + }, + expAddFuncCalled: false, + }, + "nil key in 
slice": { + srcFunc: func(value interface{}) ([]RowID, error) { + return []RowID{nil}, nil + }, + expAddFuncCalled: false, + }, + "empty key": { + srcFunc: func(value interface{}) ([]RowID, error) { + return []RowID{}, nil + }, + expAddFuncCalled: false, + }, + "nil key": { + srcFunc: func(value interface{}) ([]RowID, error) { + return nil, nil + }, + expAddFuncCalled: false, + }, + "error case": { + srcFunc: func(value interface{}) ([]RowID, error) { + return nil, stdErrors.New("test") + }, + expErr: stdErrors.New("test"), + expAddFuncCalled: false, + }, + } + for msg, spec := range specs { + t.Run(msg, func(t *testing.T) { + mockPolicy := &addFuncRecorder{} + idx := NewIndexer(spec.srcFunc, Max255DynamicLengthIndexKeyCodec{}) + idx.addFunc = mockPolicy.add + + err := idx.OnCreate(nil, myRowID, nil) + if spec.expErr != nil { + require.Equal(t, spec.expErr, err) + return + } + require.NoError(t, err) + assert.Equal(t, spec.expIndexKeys, mockPolicy.secondaryIndexKeys) + assert.Equal(t, spec.expRowIDs, mockPolicy.rowIDs) + assert.Equal(t, spec.expAddFuncCalled, mockPolicy.called) + }) + } +} + +func TestIndexerOnDelete(t *testing.T) { + myRowID := EncodeSequence(1) + + var multiKeyIndex MultiKeyIndex + ctx := NewMockContext() + storeKey := sdk.NewKVStoreKey("test") + store := prefix.NewStore(ctx.KVStore(storeKey), []byte{multiKeyIndex.prefix}) + + specs := map[string]struct { + srcFunc IndexerFunc + expIndexKeys []RowID + expErr error + }{ + "single key": { + srcFunc: func(value interface{}) ([]RowID, error) { + return []RowID{{0, 0, 0, 0, 0, 0, 0, 1}}, nil + }, + expIndexKeys: []RowID{append([]byte{0, 0, 0, 0, 0, 0, 0, 1}, myRowID...)}, + }, + "multi key": { + srcFunc: func(value interface{}) ([]RowID, error) { + return []RowID{{0, 0, 0, 0, 0, 0, 0, 1}, {1, 0, 0, 0, 0, 0, 0, 0}}, nil + }, + expIndexKeys: []RowID{ + append([]byte{0, 0, 0, 0, 0, 0, 0, 1}, myRowID...), + append([]byte{1, 0, 0, 0, 0, 0, 0, 0}, myRowID...), + }, + }, + "empty key": { + srcFunc: func(value interface{}) ([]RowID, error) { + return []RowID{}, nil + }, + }, + "nil key": { + srcFunc: func(value interface{}) ([]RowID, error) { + return nil, nil + }, + }, + "empty key in slice": { + srcFunc: func(value interface{}) ([]RowID, error) { + return []RowID{{}}, nil + }, + }, + "nil key in slice": { + srcFunc: func(value interface{}) ([]RowID, error) { + return []RowID{nil}, nil + }, + }, + "error case": { + srcFunc: func(value interface{}) ([]RowID, error) { + return nil, stdErrors.New("test") + }, + expErr: stdErrors.New("test"), + }, + } + for msg, spec := range specs { + t.Run(msg, func(t *testing.T) { + codec := FixLengthIndexKeys(EncodedSeqLength) + idx := NewIndexer(spec.srcFunc, codec) + err := idx.OnDelete(store, myRowID, nil) + if spec.expErr != nil { + require.Equal(t, spec.expErr, err) + return + } + require.NoError(t, err) + }) + } +} + +func TestIndexerOnUpdate(t *testing.T) { + myRowID := EncodeSequence(1) + codec := FixLengthIndexKeys(EncodedSeqLength) + + var multiKeyIndex MultiKeyIndex + ctx := NewMockContext() + storeKey := sdk.NewKVStoreKey("test") + store := prefix.NewStore(ctx.KVStore(storeKey), []byte{multiKeyIndex.prefix}) + + specs := map[string]struct { + srcFunc IndexerFunc + mockStore *updateKVStoreRecorder + expAddedKeys []RowID + expDeletedKeys []RowID + expErr error + addFunc func(sdk.KVStore, IndexKeyCodec, []byte, RowID) error + }{ + "single key - same key, no update": { + srcFunc: func(value interface{}) ([]RowID, error) { + return []RowID{EncodeSequence(1)}, nil + }, + }, + "single key - 
different key, replaced": { + srcFunc: func(value interface{}) ([]RowID, error) { + keys := []RowID{EncodeSequence(1), EncodeSequence(2)} + return []RowID{keys[value.(int)]}, nil + }, + expAddedKeys: []RowID{ + append(EncodeSequence(2), myRowID...), + }, + expDeletedKeys: []RowID{ + append(EncodeSequence(1), myRowID...), + }, + }, + "multi key - same key, no update": { + srcFunc: func(value interface{}) ([]RowID, error) { + return []RowID{EncodeSequence(1), EncodeSequence(2)}, nil + }, + }, + "multi key - replaced": { + srcFunc: func(value interface{}) ([]RowID, error) { + keys := []RowID{EncodeSequence(1), EncodeSequence(2), EncodeSequence(3), EncodeSequence(4)} + return []RowID{keys[value.(int)], keys[value.(int)+2]}, nil + }, + expAddedKeys: []RowID{ + append(EncodeSequence(2), myRowID...), + append(EncodeSequence(4), myRowID...), + }, + expDeletedKeys: []RowID{ + append(EncodeSequence(1), myRowID...), + append(EncodeSequence(3), myRowID...), + }, + }, + "empty key": { + srcFunc: func(value interface{}) ([]RowID, error) { + return []RowID{}, nil + }, + }, + "nil key": { + srcFunc: func(value interface{}) ([]RowID, error) { + return nil, nil + }, + }, + "empty key in slice": { + srcFunc: func(value interface{}) ([]RowID, error) { + return []RowID{{}}, nil + }, + }, + "nil key in slice": { + srcFunc: func(value interface{}) ([]RowID, error) { + return []RowID{nil}, nil + }, + }, + "error case with new value": { + srcFunc: func(value interface{}) ([]RowID, error) { + return nil, stdErrors.New("test") + }, + expErr: stdErrors.New("test"), + }, + "error case with old value": { + srcFunc: func(value interface{}) ([]RowID, error) { + var err error + if value.(int)%2 == 1 { + err = stdErrors.New("test") + } + return []RowID{myRowID}, err + }, + expErr: stdErrors.New("test"), + }, + "error case on persisting new keys": { + srcFunc: func(value interface{}) ([]RowID, error) { + keys := []RowID{EncodeSequence(1), EncodeSequence(2)} + return []RowID{keys[value.(int)]}, nil + }, + addFunc: func(_ sdk.KVStore, _ IndexKeyCodec, _ []byte, _ RowID) error { + return stdErrors.New("test") + }, + expErr: stdErrors.New("test"), + }, + } + for msg, spec := range specs { + t.Run(msg, func(t *testing.T) { + idx := NewIndexer(spec.srcFunc, codec) + if spec.addFunc != nil { + idx.addFunc = spec.addFunc + } + err := idx.OnUpdate(store, myRowID, 1, 0) + if spec.expErr != nil { + require.Equal(t, spec.expErr, err) + return + } + require.NoError(t, err) + + }) + } +} + +func TestUniqueKeyAddFunc(t *testing.T) { + myRowID := EncodeSequence(1) + myPresetKey := append([]byte("my-preset-key"), myRowID...) 
+ + specs := map[string]struct { + srcKey []byte + expErr *errors.Error + expExistingEntry []byte + }{ + + "create when not exists": { + srcKey: []byte("my-index-key"), + expExistingEntry: append([]byte("my-index-key"), myRowID...), + }, + "error when exists already": { + srcKey: []byte("my-preset-key"), + expErr: ErrUniqueConstraint, + }, + "nil key not allowed": { + srcKey: nil, + expErr: ErrArgument, + }, + "empty key not allowed": { + srcKey: []byte{}, + expErr: ErrArgument, + }, + } + for msg, spec := range specs { + t.Run(msg, func(t *testing.T) { + storeKey := sdk.NewKVStoreKey("test") + store := NewMockContext().KVStore(storeKey) + store.Set(myPresetKey, []byte{}) + + codec := FixLengthIndexKeys(EncodedSeqLength) + err := uniqueKeysAddFunc(store, codec, spec.srcKey, myRowID) + require.True(t, spec.expErr.Is(err)) + if spec.expErr != nil { + return + } + assert.True(t, store.Has(spec.expExistingEntry), "not found") + }) + } +} + +func TestMultiKeyAddFunc(t *testing.T) { + myRowID := EncodeSequence(1) + myPresetKey := append([]byte("my-preset-key"), myRowID...) + + specs := map[string]struct { + srcKey []byte + expErr *errors.Error + expExistingEntry []byte + }{ + + "create when not exists": { + srcKey: []byte("my-index-key"), + expExistingEntry: append([]byte("my-index-key"), myRowID...), + }, + "noop when exists already": { + srcKey: []byte("my-preset-key"), + expExistingEntry: myPresetKey, + }, + "nil key not allowed": { + srcKey: nil, + expErr: ErrArgument, + }, + "empty key not allowed": { + srcKey: []byte{}, + expErr: ErrArgument, + }, + } + for msg, spec := range specs { + t.Run(msg, func(t *testing.T) { + storeKey := sdk.NewKVStoreKey("test") + store := NewMockContext().KVStore(storeKey) + store.Set(myPresetKey, []byte{}) + + codec := FixLengthIndexKeys(EncodedSeqLength) + err := multiKeyAddFunc(store, codec, spec.srcKey, myRowID) + require.True(t, spec.expErr.Is(err)) + if spec.expErr != nil { + return + } + assert.True(t, store.Has(spec.expExistingEntry)) + }) + } +} + +func TestDifference(t *testing.T) { + asByte := func(s []string) []RowID { + r := make([]RowID, len(s)) + for i := 0; i < len(s); i++ { + r[i] = []byte(s[i]) + } + return r + } + + specs := map[string]struct { + srcA []string + srcB []string + expResult []RowID + }{ + "all of A": { + srcA: []string{"a", "b"}, + srcB: []string{"c"}, + expResult: []RowID{[]byte("a"), []byte("b")}, + }, + "A - B": { + srcA: []string{"a", "b"}, + srcB: []string{"b", "c", "d"}, + expResult: []RowID{[]byte("a")}, + }, + } + for msg, spec := range specs { + t.Run(msg, func(t *testing.T) { + got := difference(asByte(spec.srcA), asByte(spec.srcB)) + assert.Equal(t, spec.expResult, got) + }) + } +} + +func TestPruneEmptyKeys(t *testing.T) { + specs := map[string]struct { + srcFunc IndexerFunc + expResult []RowID + expError error + }{ + "non empty": { + srcFunc: func(v interface{}) ([]RowID, error) { + return []RowID{{0}, {1}}, nil + }, + expResult: []RowID{{0}, {1}}, + }, + "empty": { + srcFunc: func(v interface{}) ([]RowID, error) { + return []RowID{}, nil + }, + expResult: []RowID{}, + }, + "nil": { + srcFunc: func(v interface{}) ([]RowID, error) { + return nil, nil + }, + }, + "nil in the beginning": { + srcFunc: func(v interface{}) ([]RowID, error) { + return []RowID{nil, {0}, {1}}, nil + }, + expResult: []RowID{{0}, {1}}, + }, + "nil in the middle": { + srcFunc: func(v interface{}) ([]RowID, error) { + return []RowID{{0}, nil, {1}}, nil + }, + expResult: []RowID{{0}, {1}}, + }, + "nil at the end": { + srcFunc: func(v interface{}) 
([]RowID, error) {
+				return []RowID{{0}, {1}, nil}, nil
+			},
+			expResult: []RowID{{0}, {1}},
+		},
+		"empty in the beginning": {
+			srcFunc: func(v interface{}) ([]RowID, error) {
+				return []RowID{{}, {0}, {1}}, nil
+			},
+			expResult: []RowID{{0}, {1}},
+		},
+		"empty in the middle": {
+			srcFunc: func(v interface{}) ([]RowID, error) {
+				return []RowID{{0}, {}, {1}}, nil
+			},
+			expResult: []RowID{{0}, {1}},
+		},
+		"empty at the end": {
+			srcFunc: func(v interface{}) ([]RowID, error) {
+				return []RowID{{0}, {1}, {}}, nil
+			},
+			expResult: []RowID{{0}, {1}},
+		},
+		"error passed": {
+			srcFunc: func(v interface{}) ([]RowID, error) {
+				return nil, stdErrors.New("test")
+			},
+			expError: stdErrors.New("test"),
+		},
+	}
+	for msg, spec := range specs {
+		t.Run(msg, func(t *testing.T) {
+			r, err := pruneEmptyKeys(spec.srcFunc)(nil)
+			require.Equal(t, spec.expError, err)
+			if spec.expError != nil {
+				return
+			}
+			assert.Equal(t, spec.expResult, r)
+		})
+	}
+}
+
+type addFuncRecorder struct {
+	secondaryIndexKeys []RowID
+	rowIDs             []RowID
+	called             bool
+}
+
+func (c *addFuncRecorder) add(_ sdk.KVStore, _ IndexKeyCodec, key []byte, rowID RowID) error {
+	c.secondaryIndexKeys = append(c.secondaryIndexKeys, key)
+	c.rowIDs = append(c.rowIDs, rowID)
+	c.called = true
+	return nil
+}
+
+type deleteKVStoreRecorder struct {
+	AlwaysPanicKVStore
+	deletes []RowID
+}
+
+func (m *deleteKVStoreRecorder) Delete(key []byte) {
+	m.deletes = append(m.deletes, key)
+}
+
+type updateKVStoreRecorder struct {
+	deleteKVStoreRecorder
+	stored    tuples
+	hasResult bool
+}
+
+func (u *updateKVStoreRecorder) Set(key, value []byte) {
+	u.stored = append(u.stored, tuple{key, value})
+}
+
+func (u updateKVStoreRecorder) Has(key []byte) bool {
+	return u.hasResult
+}
+
+type tuple struct {
+	key, val []byte
+}
+
+type tuples []tuple
+
+func (t tuples) Keys() []RowID {
+	if t == nil {
+		return nil
+	}
+	r := make([]RowID, len(t))
+	for i, v := range t {
+		r[i] = v.key
+	}
+	return r
+}
diff --git a/orm/iterator.go b/orm/iterator.go
new file mode 100644
index 000000000000..234cb397914b
--- /dev/null
+++ b/orm/iterator.go
@@ -0,0 +1,304 @@
+package orm
+
+import (
+	"fmt"
+	"reflect"
+
+	"github.com/cosmos/cosmos-sdk/codec"
+	"github.com/cosmos/cosmos-sdk/types/errors"
+	"github.com/cosmos/cosmos-sdk/types/query"
+)
+
+// IteratorFunc is a function type that satisfies the Iterator interface.
+// The passed function is called on LoadNext operations.
+type IteratorFunc func(dest codec.ProtoMarshaler) (RowID, error)
+
+// LoadNext loads the next value in the sequence into the pointer passed as dest and returns the key. If there
+// are no more items the ErrIteratorDone error is returned.
+// The key is the rowID and not any MultiKeyIndex key.
+func (i IteratorFunc) LoadNext(dest codec.ProtoMarshaler) (RowID, error) {
+	return i(dest)
+}
+
+// Close always returns nil.
+func (i IteratorFunc) Close() error {
+	return nil
+}
+
+// NewSingleValueIterator creates an iterator that returns exactly one entry: the given rowID and value.
+func NewSingleValueIterator(rowID RowID, val []byte) Iterator {
+	var closed bool
+	return IteratorFunc(func(dest codec.ProtoMarshaler) (RowID, error) {
+		if dest == nil {
+			return nil, errors.Wrap(ErrArgument, "destination object must not be nil")
+		}
+		if closed || val == nil {
+			return nil, ErrIteratorDone
+		}
+		closed = true
+		return rowID, dest.Unmarshal(val)
+	})
+}
+
+// NewInvalidIterator returns an iterator that returns ErrIteratorInvalid only.
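+// A minimal usage sketch (illustrative only; the variable names are made up):
+//			it := NewInvalidIterator()
+//			var dest testdata.GroupInfo
+//			if _, err := it.LoadNext(&dest); ErrIteratorInvalid.Is(err) {
+//				// the query could not be built; handle the invalid iterator here
+//			}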
+func NewInvalidIterator() Iterator {
+	return IteratorFunc(func(dest codec.ProtoMarshaler) (RowID, error) {
+		return nil, ErrIteratorInvalid
+	})
+}
+
+// LimitedIterator returns up to a defined maximum number of elements.
+type LimitedIterator struct {
+	remainingCount int
+	parentIterator Iterator
+}
+
+// LimitIterator returns a new iterator that returns at most max elements.
+// The parent iterator must not be nil.
+// max can be 0 or any positive number.
+func LimitIterator(parent Iterator, max int) *LimitedIterator {
+	if max < 0 {
+		panic("max must not be negative")
+	}
+	if parent == nil {
+		panic("parent iterator must not be nil")
+	}
+	return &LimitedIterator{remainingCount: max, parentIterator: parent}
+}
+
+// LoadNext loads the next value in the sequence into the pointer passed as dest and returns the key. If there
+// are no more items or the defined max number of elements was returned the `ErrIteratorDone` error is returned.
+// The key is the rowID and not any MultiKeyIndex key.
+func (i *LimitedIterator) LoadNext(dest codec.ProtoMarshaler) (RowID, error) {
+	if i.remainingCount == 0 {
+		return nil, ErrIteratorDone
+	}
+	i.remainingCount--
+	return i.parentIterator.LoadNext(dest)
+}
+
+// Close releases the iterator and should be called at the end of iteration.
+func (i LimitedIterator) Close() error {
+	return i.parentIterator.Close()
+}
+
+// First loads the first element into the given destination type and closes the iterator.
+// When the iterator is closed or has no elements the corresponding error is returned.
+func First(it Iterator, dest codec.ProtoMarshaler) (RowID, error) {
+	if it == nil {
+		return nil, errors.Wrap(ErrArgument, "iterator must not be nil")
+	}
+	defer it.Close()
+	binKey, err := it.LoadNext(dest)
+	if err != nil {
+		return nil, err
+	}
+	return binKey, nil
+}
+
+// Paginate does pagination with a given Iterator based on the provided
+// PageRequest and unmarshals the results into the dest interface that must be
+// a non-nil pointer to a slice.
+//
+// If pageRequest is nil, then we will use these default values:
+//  - Offset: 0
+//  - Key: nil
+//  - Limit: 100
+//  - CountTotal: true
+//
+// If pageRequest.Key was provided, it must have been used beforehand to instantiate the Iterator,
+// for instance via the UInt64Index.GetPaginated method. Only one of pageRequest.Offset or
+// pageRequest.Key should be set. Using pageRequest.Key is more efficient for querying
+// the next page.
+//
+// If pageRequest.CountTotal is set, all of the iterator's elements are visited to compute the total.
+// pageRequest.CountTotal is only respected when offset is used.
+//
+// This function will call it.Close().
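+// A usage sketch, mirroring how the tests drive it (assuming an index `idx` built with NewIndex
+// and a GroupInfo model; illustrative, not a fixed contract):
+//			it, err := idx.GetPaginated(ctx, searchKey, pageRequest)
+//			if err != nil {
+//				return err
+//			}
+//			var loaded []testdata.GroupInfo
+//			res, err := Paginate(it, pageRequest, &loaded)
+//			// res.NextKey can be fed into the next PageRequest.Key to resume the iteration.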
+func Paginate( + it Iterator, + pageRequest *query.PageRequest, + dest ModelSlicePtr, +) (*query.PageResponse, error) { + // if the PageRequest is nil, use default PageRequest + if pageRequest == nil { + pageRequest = &query.PageRequest{} + } + + offset := pageRequest.Offset + key := pageRequest.Key + limit := pageRequest.Limit + countTotal := pageRequest.CountTotal + + if offset > 0 && key != nil { + return nil, fmt.Errorf("invalid request, either offset or key is expected, got both") + } + + if limit == 0 { + limit = 100 + + // count total results when the limit is zero/not supplied + countTotal = true + } + + if it == nil { + return nil, errors.Wrap(ErrArgument, "iterator must not be nil") + } + defer it.Close() + + var destRef, tmpSlice reflect.Value + elemType, err := assertDest(dest, &destRef, &tmpSlice) + if err != nil { + return nil, err + } + + var end = offset + limit + var count uint64 + var nextKey []byte + for { + obj := reflect.New(elemType) + val := obj.Elem() + model := obj + if elemType.Kind() == reflect.Ptr { + val.Set(reflect.New(elemType.Elem())) + // if elemType is already a pointer (e.g. dest being some pointer to a slice of pointers, + // like []*GroupMember), then obj is a pointer to a pointer which might cause issues + // if we try to do obj.Interface().(codec.ProtoMarshaler). + // For that reason, we copy obj into model if we have a simple pointer + // but in case elemType.Kind() == reflect.Ptr, we overwrite it with model = val + // so we can safely call model.Interface().(codec.ProtoMarshaler) afterwards. + model = val + } + + modelProto, ok := model.Interface().(codec.ProtoMarshaler) + if !ok { + return nil, errors.Wrapf(ErrArgument, "%s should implement codec.ProtoMarshaler", elemType) + } + binKey, err := it.LoadNext(modelProto) + if err != nil { + if ErrIteratorDone.Is(err) { + break + } + return nil, err + } + + count++ + + // During the first loop, count value at this point will be 1, + // so if offset is >= 1, it will continue to load the next value until count > offset + // else (offset = 0, key might be set or not), + // it will start to append values to tmpSlice. + if count <= offset { + continue + } + + if count <= end { + tmpSlice = reflect.Append(tmpSlice, val) + } else if count == end+1 { + nextKey = binKey + + // countTotal is set to true to indicate that the result set should include + // a count of the total number of items available for pagination in UIs. + // countTotal is only respected when offset is used. It is ignored when key + // is set. + if !countTotal || len(key) != 0 { + break + } + } + } + destRef.Set(tmpSlice) + + res := &query.PageResponse{NextKey: nextKey} + if countTotal && len(key) == 0 { + res.Total = count + } + + return res, nil +} + +// ModelSlicePtr represents a pointer to a slice of models. Think of it as +// *[]Model Because of Go's type system, using []Model type would not work for us. +// Instead we use a placeholder type and the validation is done during the +// runtime. +type ModelSlicePtr interface{} + +// ReadAll consumes all values for the iterator and stores them in a new slice at the passed ModelSlicePtr. +// The slice can be empty when the iterator does not return any values but not nil. The iterator +// is closed afterwards. 
+// Example:
+//			var loaded []testdata.GroupInfo
+//			rowIDs, err := ReadAll(it, &loaded)
+//			require.NoError(t, err)
+//
+func ReadAll(it Iterator, dest ModelSlicePtr) ([]RowID, error) {
+	if it == nil {
+		return nil, errors.Wrap(ErrArgument, "iterator must not be nil")
+	}
+	defer it.Close()
+
+	var destRef, tmpSlice reflect.Value
+	elemType, err := assertDest(dest, &destRef, &tmpSlice)
+	if err != nil {
+		return nil, err
+	}
+
+	var rowIDs []RowID
+	for {
+		obj := reflect.New(elemType)
+		val := obj.Elem()
+		model := obj
+		if elemType.Kind() == reflect.Ptr {
+			val.Set(reflect.New(elemType.Elem()))
+			model = val
+		}
+
+		binKey, err := it.LoadNext(model.Interface().(codec.ProtoMarshaler))
+		switch {
+		case err == nil:
+			tmpSlice = reflect.Append(tmpSlice, val)
+		case ErrIteratorDone.Is(err):
+			destRef.Set(tmpSlice)
+			return rowIDs, nil
+		default:
+			return nil, err
+		}
+		rowIDs = append(rowIDs, binKey)
+	}
+}
+
+// assertDest checks that the provided dest is not nil and a pointer to a slice.
+// It also verifies that the slice elements, or pointers to them, implement codec.ProtoMarshaler.
+// It overwrites destRef and tmpSlice using reflection.
+func assertDest(dest ModelSlicePtr, destRef *reflect.Value, tmpSlice *reflect.Value) (reflect.Type, error) {
+	if dest == nil {
+		return nil, errors.Wrap(ErrArgument, "destination must not be nil")
+	}
+	tp := reflect.ValueOf(dest)
+	if tp.Kind() != reflect.Ptr {
+		return nil, errors.Wrap(ErrArgument, "destination must be a pointer to a slice")
+	}
+	if tp.Elem().Kind() != reflect.Slice {
+		return nil, errors.Wrap(ErrArgument, "destination must point to a slice")
+	}
+
+	// Since dest is just an interface{}, we overwrite destRef using reflection
+	// to have an assignable copy of it.
+	*destRef = tp.Elem()
+	// We need to verify that we can call Set() on destRef.
+	if !destRef.CanSet() {
+		return nil, errors.Wrap(ErrArgument, "destination not assignable")
+	}
+
+	elemType := reflect.TypeOf(dest).Elem().Elem()
+
+	protoMarshaler := reflect.TypeOf((*codec.ProtoMarshaler)(nil)).Elem()
+	if !elemType.Implements(protoMarshaler) &&
+		!reflect.PtrTo(elemType).Implements(protoMarshaler) {
+		return nil, errors.Wrapf(ErrArgument, "unsupported type: %s", elemType)
+	}
+
+	// tmpSlice is a slice value for the specified type
+	// that we'll use for appending new elements.
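+	// Reviewer note: reflect.MakeSlice(reflect.SliceOf(elemType), 0, 0) is the reflection
+	// equivalent of `make([]T, 0)`, e.g. an empty but non-nil []testdata.GroupInfo.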
+ *tmpSlice = reflect.MakeSlice(reflect.SliceOf(elemType), 0, 0) + + return elemType, nil +} diff --git a/orm/iterator_test.go b/orm/iterator_test.go new file mode 100644 index 000000000000..b7bb514511b7 --- /dev/null +++ b/orm/iterator_test.go @@ -0,0 +1,263 @@ +package orm_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cosmos/cosmos-sdk/codec" + "github.com/cosmos/cosmos-sdk/codec/types" + "github.com/cosmos/cosmos-sdk/orm" + "github.com/cosmos/cosmos-sdk/orm/testdata" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/errors" + "github.com/cosmos/cosmos-sdk/types/query" +) + +func TestReadAll(t *testing.T) { + specs := map[string]struct { + srcIT orm.Iterator + destSlice func() orm.ModelSlicePtr + expErr *errors.Error + expIDs []orm.RowID + expResult orm.ModelSlicePtr + }{ + "all good with object slice": { + srcIT: mockIter(orm.EncodeSequence(1), &testdata.GroupInfo{Description: "test"}), + destSlice: func() orm.ModelSlicePtr { + x := make([]testdata.GroupInfo, 1) + return &x + }, + expIDs: []orm.RowID{orm.EncodeSequence(1)}, + expResult: &[]testdata.GroupInfo{{Description: "test"}}, + }, + "all good with pointer slice": { + srcIT: mockIter(orm.EncodeSequence(1), &testdata.GroupInfo{Description: "test"}), + destSlice: func() orm.ModelSlicePtr { + x := make([]*testdata.GroupInfo, 1) + return &x + }, + expIDs: []orm.RowID{orm.EncodeSequence(1)}, + expResult: &[]*testdata.GroupInfo{{Description: "test"}}, + }, + "dest slice empty": { + srcIT: mockIter(orm.EncodeSequence(1), &testdata.GroupInfo{}), + destSlice: func() orm.ModelSlicePtr { + x := make([]testdata.GroupInfo, 0) + return &x + }, + expIDs: []orm.RowID{orm.EncodeSequence(1)}, + expResult: &[]testdata.GroupInfo{{}}, + }, + "dest pointer with nil value": { + srcIT: mockIter(orm.EncodeSequence(1), &testdata.GroupInfo{}), + destSlice: func() orm.ModelSlicePtr { + return (*[]testdata.GroupInfo)(nil) + }, + expErr: orm.ErrArgument, + }, + "iterator is nil": { + srcIT: nil, + destSlice: func() orm.ModelSlicePtr { return new([]testdata.GroupInfo) }, + expErr: orm.ErrArgument, + }, + "dest slice is nil": { + srcIT: noopIter(), + destSlice: func() orm.ModelSlicePtr { return nil }, + expErr: orm.ErrArgument, + }, + "dest slice is not a pointer": { + srcIT: orm.IteratorFunc(nil), + destSlice: func() orm.ModelSlicePtr { return make([]testdata.GroupInfo, 1) }, + expErr: orm.ErrArgument, + }, + "error on loadNext is returned": { + srcIT: orm.NewInvalidIterator(), + destSlice: func() orm.ModelSlicePtr { + x := make([]testdata.GroupInfo, 1) + return &x + }, + expErr: orm.ErrIteratorInvalid, + }, + } + for msg, spec := range specs { + t.Run(msg, func(t *testing.T) { + loaded := spec.destSlice() + ids, err := orm.ReadAll(spec.srcIT, loaded) + require.True(t, spec.expErr.Is(err), "expected %s but got %s", spec.expErr, err) + assert.Equal(t, spec.expIDs, ids) + if err == nil { + assert.Equal(t, spec.expResult, loaded) + } + }) + } +} + +func TestLimitedIterator(t *testing.T) { + specs := map[string]struct { + src orm.Iterator + exp []testdata.GroupInfo + }{ + "all from range with max > length": { + src: orm.LimitIterator(mockIter(orm.EncodeSequence(1), &testdata.GroupInfo{Description: "test"}), 2), + exp: []testdata.GroupInfo{testdata.GroupInfo{Description: "test"}}, + }, + "up to max": { + src: orm.LimitIterator(mockIter(orm.EncodeSequence(1), &testdata.GroupInfo{Description: "test"}), 1), + exp: 
[]testdata.GroupInfo{testdata.GroupInfo{Description: "test"}}, + }, + "none when max = 0": { + src: orm.LimitIterator(mockIter(orm.EncodeSequence(1), &testdata.GroupInfo{Description: "test"}), 0), + exp: []testdata.GroupInfo{}, + }, + } + for msg, spec := range specs { + t.Run(msg, func(t *testing.T) { + var loaded []testdata.GroupInfo + _, err := orm.ReadAll(spec.src, &loaded) + require.NoError(t, err) + assert.EqualValues(t, spec.exp, loaded) + }) + } +} + +func TestPaginate(t *testing.T) { + interfaceRegistry := types.NewInterfaceRegistry() + cdc := codec.NewProtoCodec(interfaceRegistry) + + storeKey := sdk.NewKVStoreKey("test") + const ( + testTablePrefix = iota + testTableSeqPrefix + ) + tBuilder := orm.NewAutoUInt64TableBuilder(testTablePrefix, testTableSeqPrefix, storeKey, &testdata.GroupInfo{}, cdc) + idx := orm.NewIndex(tBuilder, GroupByAdminIndexPrefix, func(val interface{}) ([]orm.RowID, error) { + return []orm.RowID{[]byte(val.(*testdata.GroupInfo).Admin)}, nil + }) + tb := tBuilder.Build() + ctx := orm.NewMockContext() + + admin := sdk.AccAddress([]byte("admin-address")) + g1 := testdata.GroupInfo{ + Description: "my test 1", + Admin: admin, + } + g2 := testdata.GroupInfo{ + Description: "my test 2", + Admin: admin, + } + g3 := testdata.GroupInfo{ + Description: "my test 3", + Admin: sdk.AccAddress([]byte("other-admin-address")), + } + g4 := testdata.GroupInfo{ + Description: "my test 4", + Admin: admin, + } + g5 := testdata.GroupInfo{ + Description: "my test 5", + Admin: sdk.AccAddress([]byte("other-admin-address")), + } + + for _, g := range []testdata.GroupInfo{g1, g2, g3, g4, g5} { + _, err := tb.Create(ctx, &g) + require.NoError(t, err) + } + + specs := map[string]struct { + pageReq *query.PageRequest + expPageRes *query.PageResponse + exp []testdata.GroupInfo + key []byte + expErr bool + }{ + "one item": { + pageReq: &query.PageRequest{Key: nil, Limit: 1}, + exp: []testdata.GroupInfo{g1}, + expPageRes: &query.PageResponse{Total: 0, NextKey: orm.EncodeSequence(2)}, + key: admin, + }, + "with both key and offset": { + pageReq: &query.PageRequest{Key: orm.EncodeSequence(2), Offset: 1}, + expErr: true, + key: admin, + }, + "up to max": { + pageReq: &query.PageRequest{Key: nil, Limit: 3, CountTotal: true}, + exp: []testdata.GroupInfo{g1, g2, g4}, + expPageRes: &query.PageResponse{Total: 3, NextKey: nil}, + key: admin, + }, + "no results": { + pageReq: &query.PageRequest{Key: nil, Limit: 2, CountTotal: true}, + exp: []testdata.GroupInfo{}, + expPageRes: &query.PageResponse{Total: 0, NextKey: nil}, + key: sdk.AccAddress([]byte("no-group-address")), + }, + "with offset and count total": { + pageReq: &query.PageRequest{Key: nil, Offset: 1, Limit: 2, CountTotal: true}, + exp: []testdata.GroupInfo{g2, g4}, + expPageRes: &query.PageResponse{Total: 3, NextKey: nil}, + key: admin, + }, + "nil/default page req (limit = 100 > number of items)": { + pageReq: nil, + exp: []testdata.GroupInfo{g1, g2, g4}, + expPageRes: &query.PageResponse{Total: 3, NextKey: nil}, + key: admin, + }, + "with key and limit < number of elem (count total is ignored in this case)": { + pageReq: &query.PageRequest{Key: orm.EncodeSequence(2), Limit: 1, CountTotal: true}, + exp: []testdata.GroupInfo{g2}, + expPageRes: &query.PageResponse{Total: 0, NextKey: orm.EncodeSequence(4)}, + key: admin, + }, + "with key and limit >= number of elem": { + pageReq: &query.PageRequest{Key: orm.EncodeSequence(2), Limit: 2}, + exp: []testdata.GroupInfo{g2, g4}, + expPageRes: &query.PageResponse{Total: 0, NextKey: nil}, + key: 
admin,
+		},
+		"with nothing left to iterate from key": {
+			pageReq:    &query.PageRequest{Key: orm.EncodeSequence(5)},
+			exp:        []testdata.GroupInfo{},
+			expPageRes: &query.PageResponse{Total: 0, NextKey: nil},
+			key:        admin,
+		},
+	}
+	for msg, spec := range specs {
+		t.Run(msg, func(t *testing.T) {
+			var loaded []testdata.GroupInfo
+
+			it, err := idx.GetPaginated(ctx, spec.key, spec.pageReq)
+			require.NoError(t, err)
+
+			res, err := orm.Paginate(it, spec.pageReq, &loaded)
+			if spec.expErr {
+				require.Error(t, err)
+			} else {
+				require.NoError(t, err)
+				assert.EqualValues(t, spec.exp, loaded)
+				assert.EqualValues(t, spec.expPageRes.Total, res.Total)
+				assert.EqualValues(t, spec.expPageRes.NextKey, res.NextKey)
+			}
+		})
+	}
+}
+
+// mockIter encodes + decodes the value object with protobuf.
+func mockIter(rowID orm.RowID, val codec.ProtoMarshaler) orm.Iterator {
+	b, err := val.Marshal()
+	if err != nil {
+		panic(err)
+	}
+	return orm.NewSingleValueIterator(rowID, b)
+}
+
+func noopIter() orm.Iterator {
+	return orm.IteratorFunc(func(dest codec.ProtoMarshaler) (orm.RowID, error) {
+		return nil, nil
+	})
+}
diff --git a/orm/orm.go b/orm/orm.go
new file mode 100644
index 000000000000..79a66d00ed6e
--- /dev/null
+++ b/orm/orm.go
@@ -0,0 +1,176 @@
+/*
+Package orm is a convenient object-to-data-store mapper.
+*/
+package orm
+
+import (
+	"io"
+	"reflect"
+
+	"github.com/cosmos/cosmos-sdk/codec"
+	"github.com/cosmos/cosmos-sdk/store/prefix"
+	sdk "github.com/cosmos/cosmos-sdk/types"
+	"github.com/cosmos/cosmos-sdk/types/errors"
+	"github.com/cosmos/cosmos-sdk/types/query"
+)
+
+const ormCodespace = "orm"
+
+var (
+	ErrNotFound          = errors.Register(ormCodespace, 100, "not found")
+	ErrIteratorDone      = errors.Register(ormCodespace, 101, "iterator done")
+	ErrIteratorInvalid   = errors.Register(ormCodespace, 102, "iterator invalid")
+	ErrType              = errors.Register(ormCodespace, 110, "invalid type")
+	ErrUniqueConstraint  = errors.Register(ormCodespace, 111, "unique constraint violation")
+	ErrArgument          = errors.Register(ormCodespace, 112, "invalid argument")
+	ErrIndexKeyMaxLength = errors.Register(ormCodespace, 113, "index key exceeds max length")
+)
+
+// HasKVStore is a subset of the cosmos-sdk context defined for loose coupling and simpler test setups.
+type HasKVStore interface {
+	KVStore(key sdk.StoreKey) sdk.KVStore
+}
+
+// RowID is the unique identifier of a persistent object within a table.
+type RowID []byte
+
+// Bytes returns raw bytes.
+func (r RowID) Bytes() []byte {
+	return r
+}
+
+// Validateable is an interface that Persistent types can implement and is called on any orm save or update operation.
type Validateable interface {
+	// ValidateBasic is a sanity check on the data. Any error returned prevents create or updates.
+	ValidateBasic() error
+}
+
+// Persistent supports Marshal and Unmarshal
+//
+// This is separated from Marshal, as this almost always requires
+// a pointer, and functions that only need to marshal bytes can
+// use the Marshaller interface to access non-pointers.
+//
+// As with Marshaller, this may do internal validation on the data
+// and errors should be expected.
+type Persistent interface {
+	// Marshal serializes object into binary representation
+	Marshal() ([]byte, error)
+	// Unmarshal deserializes the object from the binary representation
+	Unmarshal([]byte) error
+}
+
+// Index entries are stored as key = concat(indexKeyBytes, rowIDBytes) with an empty value. This layout allows
+// efficient prefix scans over the index keys. Because the appended RowID is a fixed-width 8 byte integer,
+// the MultiKeyIndex key bytes may have variable length and the RowID can still be stripped off deterministically.
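+// For illustration, with the 8 byte sequence encoding used in this package (a sketch, not normative):
+//			searchKey := []byte("admin-address")
+//			rowID := EncodeSequence(42) // 0x00 00 00 00 00 00 00 2a
+//			persisted: key = append(searchKey, rowID...), value = []byte{}
+// StripRowID can then cut the fixed-width suffix off any key found by a prefix scan over searchKey.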
+type Index interface {
+	// Has checks if a key exists. Panics on nil key.
+	Has(ctx HasKVStore, key []byte) bool
+
+	// Get returns a result iterator for the searchKey.
+	// searchKey must not be nil.
+	Get(ctx HasKVStore, searchKey []byte) (Iterator, error)
+
+	// GetPaginated returns a result iterator for the searchKey and optional pageRequest.
+	// searchKey must not be nil.
+	GetPaginated(ctx HasKVStore, searchKey []byte, pageRequest *query.PageRequest) (Iterator, error)
+
+	// PrefixScan returns an Iterator over a domain of keys in ascending order. End is exclusive.
+	// Start is a MultiKeyIndex key or prefix. It must be less than end, or the Iterator is invalid and an error is returned.
+	// Iterator must be closed by caller.
+	// To iterate over entire domain, use PrefixScan(nil, nil)
+	//
+	// WARNING: The use of a PrefixScan can be very expensive in terms of Gas. Please make sure you do not expose
+	// this as an endpoint to the public without further limits.
+	// Example:
+	//			it, err := idx.PrefixScan(ctx, start, end)
+	//			if err != nil {
+	//				return err
+	//			}
+	//			const defaultLimit = 20
+	//			it = LimitIterator(it, defaultLimit)
+	//
+	// CONTRACT: No writes may happen within a domain while an iterator exists over it.
+	PrefixScan(ctx HasKVStore, start []byte, end []byte) (Iterator, error)
+
+	// ReversePrefixScan returns an Iterator over a domain of keys in descending order. End is exclusive.
+	// Start is a MultiKeyIndex key or prefix. It must be less than end, or the Iterator is invalid and an error is returned.
+	// Iterator must be closed by caller.
+	// To iterate over entire domain, use PrefixScan(nil, nil)
+	//
+	// WARNING: The use of a ReversePrefixScan can be very expensive in terms of Gas. Please make sure you do not expose
+	// this as an endpoint to the public without further limits. See `LimitIterator`
+	//
+	// CONTRACT: No writes may happen within a domain while an iterator exists over it.
+	ReversePrefixScan(ctx HasKVStore, start []byte, end []byte) (Iterator, error)
+}
+
+// Iterator allows iteration through a sequence of key value pairs.
+type Iterator interface {
+	// LoadNext loads the next value in the sequence into the pointer passed as dest and returns the key. If there
+	// are no more items the ErrIteratorDone error is returned.
+	// The key is the rowID and not any MultiKeyIndex key.
+	LoadNext(dest codec.ProtoMarshaler) (RowID, error)
+	// Close releases the iterator and should be called at the end of iteration
+	io.Closer
+}
+
+// IndexKeyCodec defines the encoding/decoding methods for building/splitting index keys.
+type IndexKeyCodec interface {
+	// BuildIndexKey encodes a searchable key and the target RowID.
+	BuildIndexKey(searchableKey []byte, rowID RowID) []byte
+	// StripRowID returns the RowID from the combined persistentIndexKey. It is the reverse operation to BuildIndexKey
+	// but with the searchableKey dropped.
+	StripRowID(persistentIndexKey []byte) RowID
+}
+
+// Indexable types are used to set up new tables.
+// This interface provides a set of functions that can be called by indexes to register and interact with the tables.
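+// A sketch of how an index hooks into a table via this interface (illustrative; compare the NewIndex
+// usage in the tests):
+//			builder.AddAfterSaveInterceptor(func(ctx HasKVStore, rowID RowID, newValue, oldValue codec.ProtoMarshaler) error {
+//				// derive the secondary index keys for newValue and persist them
+//				return nil
+//			})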
+type Indexable interface { + StoreKey() sdk.StoreKey + RowGetter() RowGetter + IndexKeyCodec() IndexKeyCodec + AddAfterSaveInterceptor(interceptor AfterSaveInterceptor) + AddAfterDeleteInterceptor(interceptor AfterDeleteInterceptor) +} + +// AfterSaveInterceptor defines a callback function to be called on Create + Update. +type AfterSaveInterceptor func(ctx HasKVStore, rowID RowID, newValue, oldValue codec.ProtoMarshaler) error + +// AfterDeleteInterceptor defines a callback function to be called on Delete operations. +type AfterDeleteInterceptor func(ctx HasKVStore, rowID RowID, value codec.ProtoMarshaler) error + +// RowGetter loads a persistent object by row ID into the destination object. The dest parameter must therefore be a pointer. +// Any implementation must return `ErrNotFound` when no object for the rowID exists +type RowGetter func(ctx HasKVStore, rowID RowID, dest codec.ProtoMarshaler) error + +// NewTypeSafeRowGetter returns a `RowGetter` with type check on the dest parameter. +func NewTypeSafeRowGetter(storeKey sdk.StoreKey, prefixKey byte, model reflect.Type, cdc codec.Codec) RowGetter { + return func(ctx HasKVStore, rowID RowID, dest codec.ProtoMarshaler) error { + if len(rowID) == 0 { + return errors.Wrap(ErrArgument, "key must not be nil") + } + if err := assertCorrectType(model, dest); err != nil { + return err + } + + store := prefix.NewStore(ctx.KVStore(storeKey), []byte{prefixKey}) + it := store.Iterator(PrefixRange(rowID)) + defer it.Close() + if !it.Valid() { + return ErrNotFound + } + return cdc.Unmarshal(it.Value(), dest) + } +} + +func assertCorrectType(model reflect.Type, obj codec.ProtoMarshaler) error { + tp := reflect.TypeOf(obj) + if tp.Kind() != reflect.Ptr { + return errors.Wrap(ErrType, "model destination must be a pointer") + } + if model != tp.Elem() { + return errors.Wrapf(ErrType, "can not use %T with this bucket", obj) + } + return nil +} diff --git a/orm/orm_scenario_test.go b/orm/orm_scenario_test.go new file mode 100644 index 000000000000..971611f647ba --- /dev/null +++ b/orm/orm_scenario_test.go @@ -0,0 +1,405 @@ +package orm_test + +import ( + "bytes" + "encoding/binary" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cosmos/cosmos-sdk/codec" + "github.com/cosmos/cosmos-sdk/codec/types" + "github.com/cosmos/cosmos-sdk/orm" + "github.com/cosmos/cosmos-sdk/orm/testdata" + sdk "github.com/cosmos/cosmos-sdk/types" +) + +// Testing ORM with arbitrary address length +const addrLen = 10 + +func TestKeeperEndToEndWithAutoUInt64Table(t *testing.T) { + interfaceRegistry := types.NewInterfaceRegistry() + cdc := codec.NewProtoCodec(interfaceRegistry) + + storeKey := sdk.NewKVStoreKey("test") + ctx := orm.NewMockContext() + + k := NewGroupKeeper(storeKey, cdc) + + g := testdata.GroupInfo{ + GroupId: 1, + Description: "my test", + Admin: sdk.AccAddress([]byte("admin-address")), + } + // when stored + rowID, err := k.groupTable.Create(ctx, &g) + require.NoError(t, err) + // then we should find it + exists := k.groupTable.Has(ctx, rowID) + require.True(t, exists) + + // and load it + var loaded testdata.GroupInfo + + binKey, err := k.groupTable.GetOne(ctx, rowID, &loaded) + require.NoError(t, err) + + assert.Equal(t, rowID, binary.BigEndian.Uint64(binKey)) + assert.Equal(t, "my test", loaded.Description) + assert.Equal(t, sdk.AccAddress([]byte("admin-address")), loaded.Admin) + + // and exists in MultiKeyIndex + exists = k.groupByAdminIndex.Has(ctx, []byte("admin-address")) + require.True(t, 
exists) + + // and when loaded + it, err := k.groupByAdminIndex.Get(ctx, []byte("admin-address")) + require.NoError(t, err) + + // then + binKey, loaded = first(t, it) + assert.Equal(t, rowID, binary.BigEndian.Uint64(binKey)) + assert.Equal(t, g, loaded) + + // when updated + g.Admin = []byte("new-admin-address") + err = k.groupTable.Save(ctx, rowID, &g) + require.NoError(t, err) + + // then indexes are updated, too + exists = k.groupByAdminIndex.Has(ctx, []byte("new-admin-address")) + require.True(t, exists) + + exists = k.groupByAdminIndex.Has(ctx, []byte("admin-address")) + require.False(t, exists) + + // when deleted + err = k.groupTable.Delete(ctx, rowID) + require.NoError(t, err) + + // then removed from primary MultiKeyIndex + exists = k.groupTable.Has(ctx, rowID) + require.False(t, exists) + + // and also removed from secondary MultiKeyIndex + exists = k.groupByAdminIndex.Has(ctx, []byte("new-admin-address")) + require.False(t, exists) +} + +func TestKeeperEndToEndWithPrimaryKeyTable(t *testing.T) { + interfaceRegistry := types.NewInterfaceRegistry() + cdc := codec.NewProtoCodec(interfaceRegistry) + + storeKey := sdk.NewKVStoreKey("test") + ctx := orm.NewMockContext() + + k := NewGroupKeeper(storeKey, cdc) + + g := testdata.GroupInfo{ + GroupId: 1, + Description: "my test", + Admin: sdk.AccAddress([]byte("admin-address")), + } + + m := testdata.GroupMember{ + Group: sdk.AccAddress(orm.EncodeSequence(1)), + Member: sdk.AccAddress([]byte("member-address")), + Weight: 10, + } + groupRowID, err := k.groupTable.Create(ctx, &g) + require.NoError(t, err) + require.Equal(t, uint64(1), groupRowID) + // when stored + err = k.groupMemberTable.Create(ctx, &m) + require.NoError(t, err) + + // then we should find it by primary key + primaryKey := m.PrimaryKey() + exists := k.groupMemberTable.Has(ctx, primaryKey) + require.True(t, exists) + // and load it by primary key + var loaded testdata.GroupMember + err = k.groupMemberTable.GetOne(ctx, primaryKey, &loaded) + require.NoError(t, err) + + // then values should match expectations + require.Equal(t, m, loaded) + + // and then the data should exists in MultiKeyIndex + exists = k.groupMemberByGroupIndex.Has(ctx, orm.EncodeSequence(groupRowID)) + require.True(t, exists) + + // and when loaded from MultiKeyIndex + it, err := k.groupMemberByGroupIndex.Get(ctx, orm.EncodeSequence(groupRowID)) + require.NoError(t, err) + + // then values should match as before + _, err = orm.First(it, &loaded) + require.NoError(t, err) + + assert.Equal(t, m, loaded) + // and when we create another entry with the same primary key + err = k.groupMemberTable.Create(ctx, &m) + // then it should fail as the primary key must be unique + require.True(t, orm.ErrUniqueConstraint.Is(err), err) + + // and when entity updated with new primary key + updatedMember := &testdata.GroupMember{ + Group: m.Group, + Member: []byte("new-member-address"), + Weight: m.Weight, + } + // then it should fail as the primary key is immutable + err = k.groupMemberTable.Save(ctx, updatedMember) + require.Error(t, err) + + // and when entity updated with non primary key attribute modified + updatedMember = &testdata.GroupMember{ + Group: m.Group, + Member: m.Member, + Weight: 99, + } + // then it should not fail + err = k.groupMemberTable.Save(ctx, updatedMember) + require.NoError(t, err) + + // and when entity deleted + err = k.groupMemberTable.Delete(ctx, &m) + require.NoError(t, err) + + // then it is removed from primary key MultiKeyIndex + exists = k.groupMemberTable.Has(ctx, primaryKey) + 
require.False(t, exists) + + // and removed from secondary MultiKeyIndex + exists = k.groupMemberByGroupIndex.Has(ctx, orm.EncodeSequence(groupRowID)) + require.False(t, exists) +} + +func TestGasCostsPrimaryKeyTable(t *testing.T) { + interfaceRegistry := types.NewInterfaceRegistry() + cdc := codec.NewProtoCodec(interfaceRegistry) + + storeKey := sdk.NewKVStoreKey("test") + ctx := orm.NewMockContext() + + k := NewGroupKeeper(storeKey, cdc) + + g := testdata.GroupInfo{ + GroupId: 1, + Description: "my test", + Admin: sdk.AccAddress([]byte("admin-address")), + } + + m := testdata.GroupMember{ + Group: sdk.AccAddress(orm.EncodeSequence(1)), + Member: sdk.AccAddress([]byte("member-address")), + Weight: 10, + } + groupRowID, err := k.groupTable.Create(ctx, &g) + require.NoError(t, err) + require.Equal(t, uint64(1), groupRowID) + gCtx := orm.NewGasCountingMockContext(ctx) + err = k.groupMemberTable.Create(gCtx, &m) + require.NoError(t, err) + t.Logf("gas consumed on create: %d", gCtx.GasConsumed()) + + // get by primary key + gCtx.ResetGasMeter() + var loaded testdata.GroupMember + err = k.groupMemberTable.GetOne(gCtx, m.PrimaryKey(), &loaded) + require.NoError(t, err) + t.Logf("gas consumed on get by primary key: %d", gCtx.GasConsumed()) + + // get by secondary index + gCtx.ResetGasMeter() + // and when loaded from MultiKeyIndex + it, err := k.groupMemberByGroupIndex.Get(gCtx, orm.EncodeSequence(groupRowID)) + require.NoError(t, err) + var loadedSlice []testdata.GroupMember + _, err = orm.ReadAll(it, &loadedSlice) + require.NoError(t, err) + + t.Logf("gas consumed on get by multi index key: %d", gCtx.GasConsumed()) + + // delete + gCtx.ResetGasMeter() + err = k.groupMemberTable.Delete(gCtx, &m) + require.NoError(t, err) + t.Logf("gas consumed on delete by primary key: %d", gCtx.GasConsumed()) + + // with 3 elements + for i := 1; i < 4; i++ { + gCtx.ResetGasMeter() + m := testdata.GroupMember{ + Group: sdk.AccAddress(orm.EncodeSequence(1)), + Member: sdk.AccAddress([]byte(fmt.Sprintf("member-address%d", i))), + Weight: 10, + } + err = k.groupMemberTable.Create(gCtx, &m) + require.NoError(t, err) + t.Logf("%d: gas consumed on create: %d", i, gCtx.GasConsumed()) + } + + for i := 1; i < 4; i++ { + gCtx.ResetGasMeter() + m := testdata.GroupMember{ + Group: sdk.AccAddress(orm.EncodeSequence(1)), + Member: sdk.AccAddress([]byte(fmt.Sprintf("member-address%d", i))), + Weight: 10, + } + err = k.groupMemberTable.GetOne(gCtx, m.PrimaryKey(), &loaded) + require.NoError(t, err) + t.Logf("%d: gas consumed on get by primary key: %d", i, gCtx.GasConsumed()) + } + + // get by secondary index + gCtx.ResetGasMeter() + // and when loaded from MultiKeyIndex + it, err = k.groupMemberByGroupIndex.Get(gCtx, orm.EncodeSequence(groupRowID)) + require.NoError(t, err) + _, err = orm.ReadAll(it, &loadedSlice) + require.NoError(t, err) + require.Len(t, loadedSlice, 3) + t.Logf("gas consumed on get by multi index key: %d", gCtx.GasConsumed()) + + // delete + for i, m := range loadedSlice { + gCtx.ResetGasMeter() + + err = k.groupMemberTable.Delete(gCtx, &m) + require.NoError(t, err) + t.Logf("%d: gas consumed on delete: %d", i, gCtx.GasConsumed()) + } +} + +func TestExportImportStateAutoUInt64Table(t *testing.T) { + interfaceRegistry := types.NewInterfaceRegistry() + cdc := codec.NewProtoCodec(interfaceRegistry) + + storeKey := sdk.NewKVStoreKey("test") + ctx := orm.NewMockContext() + + k := NewGroupKeeper(storeKey, cdc) + + testRecords := 10 + for i := 1; i <= testRecords; i++ { + myAddr := 
sdk.AccAddress(bytes.Repeat([]byte{byte(i)}, addrLen)) + g := testdata.GroupInfo{ + GroupId: uint64(i), + Description: fmt.Sprintf("my test %d", i), + Admin: myAddr, + } + + groupRowID, err := k.groupTable.Create(ctx, &g) + require.NoError(t, err) + require.Equal(t, uint64(i), groupRowID) + } + var groups []*testdata.GroupInfo + seqVal, err := orm.ExportTableData(ctx, k.groupTable, &groups) + require.NoError(t, err) + + // when a new db seeded + ctx = orm.NewMockContext() + + err = orm.ImportTableData(ctx, k.groupTable, groups, seqVal) + require.NoError(t, err) + // then all data is set again + + for i := 1; i <= testRecords; i++ { + require.True(t, k.groupTable.Has(ctx, uint64(i))) + var loaded testdata.GroupInfo + groupRowID, err := k.groupTable.GetOne(ctx, uint64(i), &loaded) + require.NoError(t, err) + + require.Equal(t, orm.RowID(orm.EncodeSequence(uint64(i))), groupRowID) + assert.Equal(t, fmt.Sprintf("my test %d", i), loaded.Description) + exp := sdk.AccAddress(bytes.Repeat([]byte{byte(i)}, addrLen)) + assert.Equal(t, exp, loaded.Admin) + + // and also the indexes + require.True(t, k.groupByAdminIndex.Has(ctx, exp)) + it, err := k.groupByAdminIndex.Get(ctx, exp) + require.NoError(t, err) + var all []testdata.GroupInfo + orm.ReadAll(it, &all) + require.Len(t, all, 1) + assert.Equal(t, loaded, all[0]) + } + require.Equal(t, uint64(testRecords), k.groupTable.Sequence().CurVal(ctx)) +} + +func TestExportImportStatePrimaryKeyTable(t *testing.T) { + interfaceRegistry := types.NewInterfaceRegistry() + cdc := codec.NewProtoCodec(interfaceRegistry) + + storeKey := sdk.NewKVStoreKey("test") + ctx := orm.NewMockContext() + + k := NewGroupKeeper(storeKey, cdc) + myGroupAddr := sdk.AccAddress(bytes.Repeat([]byte{byte('a')}, addrLen)) + testRecordsNum := 10 + testRecords := make([]testdata.GroupMember, testRecordsNum) + for i := 1; i <= testRecordsNum; i++ { + myAddr := sdk.AccAddress(bytes.Repeat([]byte{byte(i)}, addrLen)) + g := testdata.GroupMember{ + Group: myGroupAddr, + Member: myAddr, + Weight: uint64(i), + } + err := k.groupMemberTable.Create(ctx, &g) + require.NoError(t, err) + testRecords[i-1] = g + } + var groupMembers []*testdata.GroupMember + _, err := orm.ExportTableData(ctx, k.groupMemberTable, &groupMembers) + require.NoError(t, err) + + // when a new db seeded + ctx = orm.NewMockContext() + + err = orm.ImportTableData(ctx, k.groupMemberTable, groupMembers, 0) + require.NoError(t, err) + + // then all data is set again + it, err := k.groupMemberTable.PrefixScan(ctx, nil, nil) + require.NoError(t, err) + var loaded []testdata.GroupMember + keys, err := orm.ReadAll(it, &loaded) + require.NoError(t, err) + for i := range keys { + assert.Equal(t, testRecords[i].PrimaryKey(), keys[i].Bytes()) + } + assert.Equal(t, testRecords, loaded) + + // and first index setup + it, err = k.groupMemberByGroupIndex.Get(ctx, myGroupAddr) + require.NoError(t, err) + loaded = nil + keys, err = orm.ReadAll(it, &loaded) + require.NoError(t, err) + for i := range keys { + assert.Equal(t, testRecords[i].PrimaryKey(), keys[i].Bytes()) + } + assert.Equal(t, testRecords, loaded) + + // and second index setup + for _, v := range testRecords { + it, err = k.groupMemberByMemberIndex.Get(ctx, v.Member) + require.NoError(t, err) + loaded = nil + keys, err = orm.ReadAll(it, &loaded) + require.NoError(t, err) + assert.Equal(t, []orm.RowID{v.PrimaryKey()}, keys) + assert.Equal(t, []testdata.GroupMember{v}, loaded) + } +} + +func first(t *testing.T, it orm.Iterator) ([]byte, testdata.GroupInfo) { + var loaded 
testdata.GroupInfo + key, err := orm.First(it, &loaded) + require.NoError(t, err) + return key, loaded +} diff --git a/orm/orm_test.go b/orm/orm_test.go new file mode 100644 index 000000000000..360bc02a8e39 --- /dev/null +++ b/orm/orm_test.go @@ -0,0 +1,77 @@ +package orm_test + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cosmos/cosmos-sdk/codec" + "github.com/cosmos/cosmos-sdk/codec/types" + "github.com/cosmos/cosmos-sdk/orm" + "github.com/cosmos/cosmos-sdk/orm/testdata" + "github.com/cosmos/cosmos-sdk/store/prefix" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/errors" +) + +func TestTypeSafeRowGetter(t *testing.T) { + storeKey := sdk.NewKVStoreKey("test") + ctx := orm.NewMockContext() + const prefixKey = 0x2 + store := prefix.NewStore(ctx.KVStore(storeKey), []byte{prefixKey}) + md := testdata.GroupInfo{Description: "foo"} + bz, err := md.Marshal() + require.NoError(t, err) + store.Set(orm.EncodeSequence(1), bz) + + specs := map[string]struct { + srcRowID orm.RowID + srcModelType reflect.Type + expObj interface{} + expErr *errors.Error + }{ + "happy path": { + srcRowID: orm.EncodeSequence(1), + srcModelType: reflect.TypeOf(testdata.GroupInfo{}), + expObj: testdata.GroupInfo{Description: "foo"}, + }, + "unknown rowID should return ErrNotFound": { + srcRowID: orm.EncodeSequence(999), + srcModelType: reflect.TypeOf(testdata.GroupInfo{}), + expErr: orm.ErrNotFound, + }, + "wrong type should cause ErrType": { + srcRowID: orm.EncodeSequence(1), + srcModelType: reflect.TypeOf(testdata.GroupMember{}), + expErr: orm.ErrType, + }, + "empty rowID not allowed": { + srcRowID: []byte{}, + srcModelType: reflect.TypeOf(testdata.GroupInfo{}), + expErr: orm.ErrArgument, + }, + "nil rowID not allowed": { + srcModelType: reflect.TypeOf(testdata.GroupInfo{}), + expErr: orm.ErrArgument, + }, + } + for msg, spec := range specs { + t.Run(msg, func(t *testing.T) { + interfaceRegistry := types.NewInterfaceRegistry() + cdc := codec.NewProtoCodec(interfaceRegistry) + + getter := orm.NewTypeSafeRowGetter(storeKey, prefixKey, spec.srcModelType, cdc) + var loadedObj testdata.GroupInfo + + err := getter(ctx, spec.srcRowID, &loadedObj) + if spec.expErr != nil { + require.True(t, spec.expErr.Is(err), err) + return + } + require.NoError(t, err) + assert.Equal(t, spec.expObj, loadedObj) + }) + } +} diff --git a/orm/primary_key.go b/orm/primary_key.go new file mode 100644 index 000000000000..76833e445a10 --- /dev/null +++ b/orm/primary_key.go @@ -0,0 +1,129 @@ +package orm + +import ( + "github.com/cosmos/cosmos-sdk/codec" + sdk "github.com/cosmos/cosmos-sdk/types" +) + +var _ Indexable = &PrimaryKeyTableBuilder{} + +// NewPrimaryKeyTableBuilder creates a builder to setup a PrimaryKeyTable object. +func NewPrimaryKeyTableBuilder(prefixData byte, storeKey sdk.StoreKey, model PrimaryKeyed, codec IndexKeyCodec, cdc codec.Codec) *PrimaryKeyTableBuilder { + return &PrimaryKeyTableBuilder{ + TableBuilder: NewTableBuilder(prefixData, storeKey, model, codec, cdc), + } +} + +type PrimaryKeyTableBuilder struct { + *TableBuilder +} + +func (a PrimaryKeyTableBuilder) Build() PrimaryKeyTable { + return PrimaryKeyTable{table: a.TableBuilder.Build()} + +} + +// PrimaryKeyed defines an object type that is aware of it's immutable primary key. +type PrimaryKeyed interface { + // PrimaryKey returns the immutable and serialized primary key of this object. 
The primary key has to be unique within
+// its domain so that no two objects with the same primary key can exist in the same table.
+//
+// The `IndexKeyCodec` used with the `PrimaryKeyTable` may add certain constraints to the byte representation, for
+// example max length = 255 in `Max255DynamicLengthIndexKeyCodec` or a fixed length in `FixLengthIndexKeyCodec`.
+	PrimaryKey() []byte
+	codec.ProtoMarshaler
+}
+
+var _ TableExportable = &PrimaryKeyTable{}
+
+// PrimaryKeyTable provides simpler object style orm methods without passing database RowIDs.
+// Entries are persisted and loaded with a reference to their unique primary key.
+type PrimaryKeyTable struct {
+	table Table
+}
+
+// Create persists the given object under its primary key. It checks if the
+// key already exists and may return an `ErrUniqueConstraint`.
+// Create iterates through the registered callbacks and may add secondary index keys by them.
+func (a PrimaryKeyTable) Create(ctx HasKVStore, obj PrimaryKeyed) error {
+	rowID := obj.PrimaryKey()
+	if a.table.Has(ctx, rowID) {
+		return ErrUniqueConstraint
+	}
+	return a.table.Create(ctx, rowID, obj)
+}
+
+// Save updates the given object under the primary key. It expects the key to exist already
+// and fails with an `ErrNotFound` otherwise. Any caller must therefore make sure that this contract
+// is fulfilled. Parameters must not be nil.
+//
+// Save iterates through the registered callbacks and may add or remove secondary index keys by them.
+func (a PrimaryKeyTable) Save(ctx HasKVStore, newValue PrimaryKeyed) error {
+	return a.table.Save(ctx, newValue.PrimaryKey(), newValue)
+}
+
+// Delete removes the object. It expects the primary key to exist already
+// and fails with an `ErrNotFound` otherwise. Any caller must therefore make sure that this contract
+// is fulfilled.
+//
+// Delete iterates through the registered callbacks and removes secondary index keys by them.
+func (a PrimaryKeyTable) Delete(ctx HasKVStore, obj PrimaryKeyed) error {
+	return a.table.Delete(ctx, obj.PrimaryKey())
+}
+
+// Has checks if a key exists. Panics on nil key.
+func (a PrimaryKeyTable) Has(ctx HasKVStore, primaryKey RowID) bool {
+	return a.table.Has(ctx, primaryKey)
+}
+
+// Contains returns true when an object with same type and primary key is persisted in this table.
+func (a PrimaryKeyTable) Contains(ctx HasKVStore, obj PrimaryKeyed) bool {
+	if err := assertCorrectType(a.table.model, obj); err != nil {
+		return false
+	}
+	return a.table.Has(ctx, obj.PrimaryKey())
+}
+
+// GetOne loads the object persisted for the given primary key into the dest parameter.
+// If none exists `ErrNotFound` is returned instead. Parameters must not be nil.
+func (a PrimaryKeyTable) GetOne(ctx HasKVStore, primKey RowID, dest codec.ProtoMarshaler) error {
+	return a.table.GetOne(ctx, primKey, dest)
+}
+
+// PrefixScan returns an Iterator over a domain of keys in ascending order. End is exclusive.
+// Start is a MultiKeyIndex key or prefix. It must be less than end, or the Iterator is invalid and an error is returned.
+// Iterator must be closed by caller.
+// To iterate over entire domain, use PrefixScan(nil, nil)
+//
+// WARNING: The use of a PrefixScan can be very expensive in terms of Gas. Please make sure you do not expose
+// this as an endpoint to the public without further limits.
+// Example:
+//			it, err := idx.PrefixScan(ctx, start, end)
+//			if err != nil {
+//				return err
+//			}
+//			const defaultLimit = 20
+//			it = LimitIterator(it, defaultLimit)
+//
+// CONTRACT: No writes may happen within a domain while an iterator exists over it.
+func (a PrimaryKeyTable) PrefixScan(ctx HasKVStore, start, end []byte) (Iterator, error) {
+	return a.table.PrefixScan(ctx, start, end)
+}
+
+// ReversePrefixScan returns an Iterator over a domain of keys in descending order. End is exclusive.
+// Start is a MultiKeyIndex key or prefix. It must be less than end, or the Iterator is invalid and an error is returned.
+// Iterator must be closed by caller.
+// To iterate over entire domain, use PrefixScan(nil, nil)
+//
+// WARNING: The use of a ReversePrefixScan can be very expensive in terms of Gas. Please make sure you do not expose
+// this as an endpoint to the public without further limits. See `LimitIterator`
+//
+// CONTRACT: No writes may happen within a domain while an iterator exists over it.
+func (a PrimaryKeyTable) ReversePrefixScan(ctx HasKVStore, start, end []byte) (Iterator, error) {
+	return a.table.ReversePrefixScan(ctx, start, end)
+}
+
+// Table satisfies the TableExportable interface and must not be used otherwise.
+func (a PrimaryKeyTable) Table() Table {
+	return a.table
+}
diff --git a/orm/primary_key_test.go b/orm/primary_key_test.go
new file mode 100644
index 000000000000..92924e082de1
--- /dev/null
+++ b/orm/primary_key_test.go
@@ -0,0 +1,280 @@
+package orm_test
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/cosmos/cosmos-sdk/codec"
+	"github.com/cosmos/cosmos-sdk/codec/types"
+	"github.com/cosmos/cosmos-sdk/orm"
+	"github.com/cosmos/cosmos-sdk/orm/testdata"
+	sdk "github.com/cosmos/cosmos-sdk/types"
+	"github.com/cosmos/cosmos-sdk/types/errors"
+)
+
+func TestPrimaryKeyTablePrefixScan(t *testing.T) {
+	interfaceRegistry := types.NewInterfaceRegistry()
+	cdc := codec.NewProtoCodec(interfaceRegistry)
+
+	storeKey := sdk.NewKVStoreKey("test")
+	const (
+		testTablePrefix = iota
+	)
+
+	tb := orm.NewPrimaryKeyTableBuilder(testTablePrefix, storeKey, &testdata.GroupMember{}, orm.Max255DynamicLengthIndexKeyCodec{}, cdc).
+ Build() + + ctx := orm.NewMockContext() + + const anyWeight = 1 + m1 := testdata.GroupMember{ + Group: []byte("group-a"), + Member: []byte("member-one"), + Weight: anyWeight, + } + m2 := testdata.GroupMember{ + Group: []byte("group-a"), + Member: []byte("member-two"), + Weight: anyWeight, + } + m3 := testdata.GroupMember{ + Group: []byte("group-b"), + Member: []byte("member-two"), + Weight: anyWeight, + } + for _, g := range []testdata.GroupMember{m1, m2, m3} { + require.NoError(t, tb.Create(ctx, &g)) + } + + specs := map[string]struct { + start, end []byte + expResult []testdata.GroupMember + expRowIDs []orm.RowID + expError *errors.Error + method func(ctx orm.HasKVStore, start, end []byte) (orm.Iterator, error) + }{ + "exact match with a single result": { + start: []byte("group-amember-one"), // == m1.PrimaryKey() + end: []byte("group-amember-two"), // == m2.PrimaryKey() + method: tb.PrefixScan, + expResult: []testdata.GroupMember{m1}, + expRowIDs: []orm.RowID{m1.PrimaryKey()}, + }, + "one result by prefix": { + start: []byte("group-a"), + end: []byte("group-amember-two"), // == m2.PrimaryKey() + method: tb.PrefixScan, + expResult: []testdata.GroupMember{m1}, + expRowIDs: []orm.RowID{m1.PrimaryKey()}, + }, + "multi key elements by group prefix": { + start: []byte("group-a"), + end: []byte("group-b"), + method: tb.PrefixScan, + expResult: []testdata.GroupMember{m1, m2}, + expRowIDs: []orm.RowID{m1.PrimaryKey(), m2.PrimaryKey()}, + }, + "open end query with second group": { + start: []byte("group-b"), + end: nil, + method: tb.PrefixScan, + expResult: []testdata.GroupMember{m3}, + expRowIDs: []orm.RowID{m3.PrimaryKey()}, + }, + "open end query with all": { + start: []byte("group-a"), + end: nil, + method: tb.PrefixScan, + expResult: []testdata.GroupMember{m1, m2, m3}, + expRowIDs: []orm.RowID{m1.PrimaryKey(), m2.PrimaryKey(), m3.PrimaryKey()}, + }, + "open start query": { + start: nil, + end: []byte("group-b"), + method: tb.PrefixScan, + expResult: []testdata.GroupMember{m1, m2}, + expRowIDs: []orm.RowID{m1.PrimaryKey(), m2.PrimaryKey()}, + }, + "open start and end query": { + start: nil, + end: nil, + method: tb.PrefixScan, + expResult: []testdata.GroupMember{m1, m2, m3}, + expRowIDs: []orm.RowID{m1.PrimaryKey(), m2.PrimaryKey(), m3.PrimaryKey()}, + }, + "all matching prefix": { + start: []byte("group"), + end: nil, + method: tb.PrefixScan, + expResult: []testdata.GroupMember{m1, m2, m3}, + expRowIDs: []orm.RowID{m1.PrimaryKey(), m2.PrimaryKey(), m3.PrimaryKey()}, + }, + "non matching prefix": { + start: []byte("nobody"), + end: nil, + method: tb.PrefixScan, + expResult: []testdata.GroupMember{}, + }, + "start equals end": { + start: []byte("any"), + end: []byte("any"), + method: tb.PrefixScan, + expError: orm.ErrArgument, + }, + "start after end": { + start: []byte("b"), + end: []byte("a"), + method: tb.PrefixScan, + expError: orm.ErrArgument, + }, + "reverse: exact match with a single result": { + start: []byte("group-amember-one"), // == m1.PrimaryKey() + end: []byte("group-amember-two"), // == m2.PrimaryKey() + method: tb.ReversePrefixScan, + expResult: []testdata.GroupMember{m1}, + expRowIDs: []orm.RowID{m1.PrimaryKey()}, + }, + "reverse: one result by prefix": { + start: []byte("group-a"), + end: []byte("group-amember-two"), // == m2.PrimaryKey() + method: tb.ReversePrefixScan, + expResult: []testdata.GroupMember{m1}, + expRowIDs: []orm.RowID{m1.PrimaryKey()}, + }, + "reverse: multi key elements by group prefix": { + start: []byte("group-a"), + end: []byte("group-b"), + method: 
tb.ReversePrefixScan,
+			expResult: []testdata.GroupMember{m2, m1},
+			expRowIDs: []orm.RowID{m2.PrimaryKey(), m1.PrimaryKey()},
+		},
+		"reverse: open end query with second group": {
+			start:     []byte("group-b"),
+			end:       nil,
+			method:    tb.ReversePrefixScan,
+			expResult: []testdata.GroupMember{m3},
+			expRowIDs: []orm.RowID{m3.PrimaryKey()},
+		},
+		"reverse: open end query with all": {
+			start:     []byte("group-a"),
+			end:       nil,
+			method:    tb.ReversePrefixScan,
+			expResult: []testdata.GroupMember{m3, m2, m1},
+			expRowIDs: []orm.RowID{m3.PrimaryKey(), m2.PrimaryKey(), m1.PrimaryKey()},
+		},
+		"reverse: open start query": {
+			start:     nil,
+			end:       []byte("group-b"),
+			method:    tb.ReversePrefixScan,
+			expResult: []testdata.GroupMember{m2, m1},
+			expRowIDs: []orm.RowID{m2.PrimaryKey(), m1.PrimaryKey()},
+		},
+		"reverse: open start and end query": {
+			start:     nil,
+			end:       nil,
+			method:    tb.ReversePrefixScan,
+			expResult: []testdata.GroupMember{m3, m2, m1},
+			expRowIDs: []orm.RowID{m3.PrimaryKey(), m2.PrimaryKey(), m1.PrimaryKey()},
+		},
+		"reverse: all matching prefix": {
+			start:     []byte("group"),
+			end:       nil,
+			method:    tb.ReversePrefixScan,
+			expResult: []testdata.GroupMember{m3, m2, m1},
+			expRowIDs: []orm.RowID{m3.PrimaryKey(), m2.PrimaryKey(), m1.PrimaryKey()},
+		},
+		"reverse: non matching prefix": {
+			start:     []byte("nobody"),
+			end:       nil,
+			method:    tb.ReversePrefixScan,
+			expResult: []testdata.GroupMember{},
+		},
+		"reverse: start equals end": {
+			start:    []byte("any"),
+			end:      []byte("any"),
+			method:   tb.ReversePrefixScan,
+			expError: orm.ErrArgument,
+		},
+		"reverse: start after end": {
+			start:    []byte("b"),
+			end:      []byte("a"),
+			method:   tb.ReversePrefixScan,
+			expError: orm.ErrArgument,
+		},
+	}
+	for msg, spec := range specs {
+		t.Run(msg, func(t *testing.T) {
+			it, err := spec.method(ctx, spec.start, spec.end)
+			require.True(t, spec.expError.Is(err), "expected %+v but got %+v", spec.expError, err)
+			if spec.expError != nil {
+				return
+			}
+			var loaded []testdata.GroupMember
+			rowIDs, err := orm.ReadAll(it, &loaded)
+			require.NoError(t, err)
+			assert.Equal(t, spec.expResult, loaded)
+			assert.Equal(t, spec.expRowIDs, rowIDs)
+		})
+	}
+}
+
+func TestContains(t *testing.T) {
+	interfaceRegistry := types.NewInterfaceRegistry()
+	cdc := codec.NewProtoCodec(interfaceRegistry)
+
+	storeKey := sdk.NewKVStoreKey("test")
+	const testTablePrefix = iota
+
+	tb := orm.NewPrimaryKeyTableBuilder(testTablePrefix, storeKey, &testdata.GroupMember{}, orm.Max255DynamicLengthIndexKeyCodec{}, cdc).
+ Build() + + ctx := orm.NewMockContext() + + myPersistentObj := testdata.GroupMember{ + Group: []byte("group-a"), + Member: []byte("member-one"), + Weight: 1, + } + err := tb.Create(ctx, &myPersistentObj) + require.NoError(t, err) + + specs := map[string]struct { + src orm.PrimaryKeyed + exp bool + }{ + + "same object": {src: &myPersistentObj, exp: true}, + "clone": { + src: &testdata.GroupMember{ + Group: []byte("group-a"), + Member: []byte("member-one"), + Weight: 1, + }, + exp: true, + }, + "different primary key": { + src: &testdata.GroupMember{ + Group: []byte("another group"), + Member: []byte("member-one"), + Weight: 1, + }, + exp: false, + }, + "different type, same key": { + src: mockPrimaryKeyed{&myPersistentObj}, + exp: false, + }, + } + for msg, spec := range specs { + t.Run(msg, func(t *testing.T) { + got := tb.Contains(ctx, spec.src) + assert.Equal(t, spec.exp, got) + }) + } +} + +type mockPrimaryKeyed struct { + *testdata.GroupMember +} diff --git a/orm/sequence.go b/orm/sequence.go new file mode 100644 index 000000000000..11b1b186c78a --- /dev/null +++ b/orm/sequence.go @@ -0,0 +1,83 @@ +package orm + +import ( + "encoding/binary" + + "github.com/cosmos/cosmos-sdk/store/prefix" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/errors" +) + +// sequenceStorageKey is a fix key to read/ write data on the storage layer +var sequenceStorageKey = []byte{0x1} + +// sequence is a persistent unique key generator based on a counter. +type Sequence struct { + storeKey sdk.StoreKey + prefix byte +} + +func NewSequence(storeKey sdk.StoreKey, prefix byte) Sequence { + return Sequence{ + prefix: prefix, + storeKey: storeKey, + } +} + +// NextVal increments and persists the counter by one and returns the value. +func (s Sequence) NextVal(ctx HasKVStore) uint64 { + store := prefix.NewStore(ctx.KVStore(s.storeKey), []byte{s.prefix}) + v := store.Get(sequenceStorageKey) + seq := DecodeSequence(v) + seq++ + store.Set(sequenceStorageKey, EncodeSequence(seq)) + return seq +} + +// CurVal returns the last value used. 0 if none. +func (s Sequence) CurVal(ctx HasKVStore) uint64 { + store := prefix.NewStore(ctx.KVStore(s.storeKey), []byte{s.prefix}) + v := store.Get(sequenceStorageKey) + return DecodeSequence(v) +} + +// PeekNextVal returns the CurVal + increment step. Not persistent. +func (s Sequence) PeekNextVal(ctx HasKVStore) uint64 { + store := prefix.NewStore(ctx.KVStore(s.storeKey), []byte{s.prefix}) + v := store.Get(sequenceStorageKey) + return DecodeSequence(v) + 1 +} + +// InitVal sets the start value for the sequence. It must be called only once on an empty DB. +// Otherwise an error is returned when the key exists. The given start value is stored as current +// value. +// +// It is recommended to call this method only for a sequence start value other than `1` as the +// method consumes unnecessary gas otherwise. A scenario would be an import from genesis. +func (s Sequence) InitVal(ctx HasKVStore, seq uint64) error { + store := prefix.NewStore(ctx.KVStore(s.storeKey), []byte{s.prefix}) + if store.Has(sequenceStorageKey) { + return errors.Wrap(ErrUniqueConstraint, "already initialized") + } + store.Set(sequenceStorageKey, EncodeSequence(seq)) + return nil +} + +// DecodeSequence converts the binary representation into an Uint64 value. +func DecodeSequence(bz []byte) uint64 { + if bz == nil { + return 0 + } + val := binary.BigEndian.Uint64(bz) + return val +} + +// EncodedSeqLength number of bytes used for the binary representation of a sequence value. 
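+// For example, EncodeSequence(1) returns the 8 bytes 0x00 00 00 00 00 00 00 01, and DecodeSequence
+// of those bytes yields 1 again. The big-endian layout keeps the lexicographic ordering of the raw
+// keys identical to the numeric ordering, which the PrefixScan methods rely on.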
+
+// EncodedSeqLength is the number of bytes used for the binary representation of a sequence value.
+const EncodedSeqLength = 8
+
+// EncodeSequence converts the sequence value into the binary representation.
+func EncodeSequence(val uint64) []byte {
+	bz := make([]byte, EncodedSeqLength)
+	binary.BigEndian.PutUint64(bz, val)
+	return bz
+}
diff --git a/orm/sequence_test.go b/orm/sequence_test.go
new file mode 100644
index 000000000000..a487115c8312
--- /dev/null
+++ b/orm/sequence_test.go
@@ -0,0 +1,25 @@
+package orm
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+
+	sdk "github.com/cosmos/cosmos-sdk/types"
+)
+
+func TestSequenceIncrements(t *testing.T) {
+	storeKey := sdk.NewKVStoreKey("test")
+	ctx := NewMockContext()
+
+	seq := NewSequence(storeKey, 0x1)
+	var i uint64
+	for i = 1; i < 10; i++ {
+		autoID := seq.NextVal(ctx)
+		assert.Equal(t, i, autoID)
+		assert.Equal(t, i, seq.CurVal(ctx))
+	}
+	// and persisted
+	seq = NewSequence(storeKey, 0x1)
+	assert.Equal(t, uint64(9), seq.CurVal(ctx))
+}
diff --git a/orm/table.go b/orm/table.go
new file mode 100644
index 000000000000..3a05cf81a196
--- /dev/null
+++ b/orm/table.go
@@ -0,0 +1,278 @@
+package orm
+
+import (
+	"bytes"
+	"reflect"
+
+	"github.com/cosmos/cosmos-sdk/codec"
+	"github.com/cosmos/cosmos-sdk/store/prefix"
+	"github.com/cosmos/cosmos-sdk/store/types"
+	sdk "github.com/cosmos/cosmos-sdk/types"
+	"github.com/cosmos/cosmos-sdk/types/errors"
+)
+
+var _ Indexable = &TableBuilder{}
+
+type TableBuilder struct {
+	model         reflect.Type
+	prefixData    byte
+	storeKey      sdk.StoreKey
+	indexKeyCodec IndexKeyCodec
+	afterSave     []AfterSaveInterceptor
+	afterDelete   []AfterDeleteInterceptor
+	cdc           codec.Codec
+}
+
+// NewTableBuilder creates a builder to setup a Table object.
+func NewTableBuilder(prefixData byte, storeKey sdk.StoreKey, model codec.ProtoMarshaler, idxKeyCodec IndexKeyCodec, cdc codec.Codec) *TableBuilder {
+	if model == nil {
+		panic("Model must not be nil")
+	}
+	if storeKey == nil {
+		panic("StoreKey must not be nil")
+	}
+	if idxKeyCodec == nil {
+		panic("IndexKeyCodec must not be nil")
+	}
+	tp := reflect.TypeOf(model)
+	if tp.Kind() == reflect.Ptr {
+		tp = tp.Elem()
+	}
+	return &TableBuilder{
+		prefixData:    prefixData,
+		storeKey:      storeKey,
+		model:         tp,
+		indexKeyCodec: idxKeyCodec,
+		cdc:           cdc,
+	}
+}
+
+func (a TableBuilder) IndexKeyCodec() IndexKeyCodec {
+	return a.indexKeyCodec
+}
+
+// RowGetter returns a type safe RowGetter.
+func (a TableBuilder) RowGetter() RowGetter {
+	return NewTypeSafeRowGetter(a.storeKey, a.prefixData, a.model, a.cdc)
+}
+
+func (a TableBuilder) StoreKey() sdk.StoreKey {
+	return a.storeKey
+}
+
+// Build creates a new Table object.
+func (a TableBuilder) Build() Table {
+	return Table{
+		model:       a.model,
+		prefix:      a.prefixData,
+		storeKey:    a.storeKey,
+		afterSave:   a.afterSave,
+		afterDelete: a.afterDelete,
+		cdc:         a.cdc,
+	}
+}
+
+// AddAfterSaveInterceptor can be used to register a callback function that is executed after an object is created and/or updated.
+func (a *TableBuilder) AddAfterSaveInterceptor(interceptor AfterSaveInterceptor) {
+	a.afterSave = append(a.afterSave, interceptor)
+}
+
+// AddAfterDeleteInterceptor can be used to register a callback function that is executed after an object is deleted.
+func (a *TableBuilder) AddAfterDeleteInterceptor(interceptor AfterDeleteInterceptor) {
+	a.afterDelete = append(a.afterDelete, interceptor)
+}
+
+var _ TableExportable = &Table{}
+
+// Table is the high level object for storage mapper functionality. Persistent entities are stored
+// under a unique identifier called `RowID`.
+// The Table struct does not enforce uniqueness of the `RowID` but expects callers to satisfy this
+// contract; the check is omitted to optimize Gas usage.
+type Table struct {
+	model       reflect.Type
+	prefix      byte
+	storeKey    sdk.StoreKey
+	afterSave   []AfterSaveInterceptor
+	afterDelete []AfterDeleteInterceptor
+	cdc         codec.Codec
+}
+
+// Create persists the given object under the rowID key. It does not check if the
+// key already exists. Any caller must either make sure that this contract is fulfilled
+// by providing a universally unique ID or a sequence that is guaranteed to not exist yet, or
+// by checking the state via the `Has` function before.
+//
+// Create iterates through the registered callbacks, which may add secondary index keys.
+func (a Table) Create(ctx HasKVStore, rowID RowID, obj codec.ProtoMarshaler) error {
+	if err := assertCorrectType(a.model, obj); err != nil {
+		return err
+	}
+	if err := assertValid(obj); err != nil {
+		return err
+	}
+	store := prefix.NewStore(ctx.KVStore(a.storeKey), []byte{a.prefix})
+	v, err := a.cdc.Marshal(obj)
+	if err != nil {
+		return errors.Wrapf(err, "failed to serialize %T", obj)
+	}
+	store.Set(rowID, v)
+	for i, itc := range a.afterSave {
+		if err := itc(ctx, rowID, obj, nil); err != nil {
+			return errors.Wrapf(err, "interceptor %d failed", i)
+		}
+	}
+	return nil
+}
+
+// Save updates the given object under the rowID key. It expects the key to exist already
+// and fails with an `ErrNotFound` otherwise. Any caller must therefore make sure that this contract
+// is fulfilled. Parameters must not be nil.
+//
+// Save iterates through the registered callbacks, which may add or remove secondary index keys.
+func (a Table) Save(ctx HasKVStore, rowID RowID, newValue codec.ProtoMarshaler) error {
+	if err := assertCorrectType(a.model, newValue); err != nil {
+		return err
+	}
+	if err := assertValid(newValue); err != nil {
+		return err
+	}
+
+	store := prefix.NewStore(ctx.KVStore(a.storeKey), []byte{a.prefix})
+	var oldValue = reflect.New(a.model).Interface().(codec.ProtoMarshaler)
+
+	if err := a.GetOne(ctx, rowID, oldValue); err != nil {
+		return errors.Wrap(err, "load old value")
+	}
+	newValueEncoded, err := a.cdc.Marshal(newValue)
+	if err != nil {
+		return errors.Wrapf(err, "failed to serialize %T", newValue)
+	}
+
+	store.Set(rowID, newValueEncoded)
+	for i, itc := range a.afterSave {
+		if err := itc(ctx, rowID, newValue, oldValue); err != nil {
+			return errors.Wrapf(err, "interceptor %d failed", i)
+		}
+	}
+	return nil
+}
+
+func assertValid(obj codec.ProtoMarshaler) error {
+	if v, ok := obj.(Validateable); ok {
+		if err := v.ValidateBasic(); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// Delete removes the object under the rowID key. It expects the key to exist already
+// and fails with an `ErrNotFound` otherwise. Any caller must therefore make sure that this contract
+// is fulfilled.
+//
+// Delete iterates through the registered callbacks, which remove any secondary index keys.
+func (a Table) Delete(ctx HasKVStore, rowID RowID) error {
+	store := prefix.NewStore(ctx.KVStore(a.storeKey), []byte{a.prefix})
+
+	var oldValue = reflect.New(a.model).Interface().(codec.ProtoMarshaler)
+	if err := a.GetOne(ctx, rowID, oldValue); err != nil {
+		return errors.Wrap(err, "load old value")
+	}
+	store.Delete(rowID)
+
+	for i, itc := range a.afterDelete {
+		if err := itc(ctx, rowID, oldValue); err != nil {
+			return errors.Wrapf(err, "delete interceptor %d failed", i)
+		}
+	}
+	return nil
+}
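+
+// Example: a create-or-update flow that honors the contracts above (an illustrative
+// sketch only; `tbl`, `ctx`, `rowID` and `obj` are placeholders):
+//  if tbl.Has(ctx, rowID) {
+//   err = tbl.Save(ctx, rowID, obj)
+//  } else {
+//   err = tbl.Create(ctx, rowID, obj)
+//  }
+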
+// Has checks if a key exists. Panics on nil key.
+func (a Table) Has(ctx HasKVStore, rowID RowID) bool {
+	store := prefix.NewStore(ctx.KVStore(a.storeKey), []byte{a.prefix})
+	it := store.Iterator(PrefixRange(rowID))
+	defer it.Close()
+	return it.Valid()
+}
+
+// GetOne loads the object persisted for the given RowID into the dest parameter.
+// If none exists, `ErrNotFound` is returned instead. Parameters must not be nil.
+func (a Table) GetOne(ctx HasKVStore, rowID RowID, dest codec.ProtoMarshaler) error {
+	x := NewTypeSafeRowGetter(a.storeKey, a.prefix, a.model, a.cdc)
+	return x(ctx, rowID, dest)
+}
+
+// PrefixScan returns an Iterator over a domain of keys in ascending order. End is exclusive.
+// Start is a MultiKeyIndex key or prefix. It must be less than end, or the Iterator is invalid and an error is returned.
+// The Iterator must be closed by the caller.
+// To iterate over the entire domain, use PrefixScan(nil, nil).
+//
+// WARNING: The use of a PrefixScan can be very expensive in terms of Gas. Please make sure you do not expose
+// this as an endpoint to the public without further limits.
+// Example:
+//  it, err := idx.PrefixScan(ctx, start, end)
+//  if err != nil {
+//   return err
+//  }
+//  const defaultLimit = 20
+//  it = LimitIterator(it, defaultLimit)
+//
+// CONTRACT: No writes may happen within a domain while an iterator exists over it.
+func (a Table) PrefixScan(ctx HasKVStore, start, end RowID) (Iterator, error) {
+	if start != nil && end != nil && bytes.Compare(start, end) >= 0 {
+		return NewInvalidIterator(), errors.Wrap(ErrArgument, "start must be before end")
+	}
+	store := prefix.NewStore(ctx.KVStore(a.storeKey), []byte{a.prefix})
+	return &typeSafeIterator{
+		ctx:       ctx,
+		rowGetter: NewTypeSafeRowGetter(a.storeKey, a.prefix, a.model, a.cdc),
+		it:        store.Iterator(start, end),
+	}, nil
+}
+
+// ReversePrefixScan returns an Iterator over a domain of keys in descending order. End is exclusive.
+// Start is a MultiKeyIndex key or prefix. It must be less than end, or the Iterator is invalid and an error is returned.
+// The Iterator must be closed by the caller.
+// To iterate over the entire domain, use ReversePrefixScan(nil, nil).
+//
+// WARNING: The use of a ReversePrefixScan can be very expensive in terms of Gas. Please make sure you do not expose
+// this as an endpoint to the public without further limits. See `LimitIterator`.
+//
+// CONTRACT: No writes may happen within a domain while an iterator exists over it.
+func (a Table) ReversePrefixScan(ctx HasKVStore, start, end RowID) (Iterator, error) {
+	if start != nil && end != nil && bytes.Compare(start, end) >= 0 {
+		return NewInvalidIterator(), errors.Wrap(ErrArgument, "start must be before end")
+	}
+	store := prefix.NewStore(ctx.KVStore(a.storeKey), []byte{a.prefix})
+	return &typeSafeIterator{
+		ctx:       ctx,
+		rowGetter: NewTypeSafeRowGetter(a.storeKey, a.prefix, a.model, a.cdc),
+		it:        store.ReverseIterator(start, end),
+	}, nil
+}
+
+// Table satisfies the TableExportable interface and must not be used otherwise.
+func (a Table) Table() Table {
+	return a
+}
+
+// typeSafeIterator is initialized with a type safe RowGetter only.
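+// Rows are consumed via LoadNext until ErrIteratorDone is returned; a minimal drain
+// loop could look like this (illustrative sketch, `dest` is a pointer to the table's
+// model type and other errors should be handled as usual):
+//  for {
+//   rowID, err := it.LoadNext(dest)
+//   if ErrIteratorDone.Is(err) {
+//    break
+//   }
+//   // ... use rowID and dest ...
+//  }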
+type typeSafeIterator struct { + ctx HasKVStore + rowGetter RowGetter + it types.Iterator +} + +func (i typeSafeIterator) LoadNext(dest codec.ProtoMarshaler) (RowID, error) { + if !i.it.Valid() { + return nil, ErrIteratorDone + } + rowID := i.it.Key() + i.it.Next() + return rowID, i.rowGetter(i.ctx, rowID, dest) +} + +func (i typeSafeIterator) Close() error { + i.it.Close() + return nil +} diff --git a/orm/table_test.go b/orm/table_test.go new file mode 100644 index 000000000000..fed00ae30dee --- /dev/null +++ b/orm/table_test.go @@ -0,0 +1,129 @@ +package orm_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cosmos/cosmos-sdk/codec" + "github.com/cosmos/cosmos-sdk/codec/types" + "github.com/cosmos/cosmos-sdk/orm" + "github.com/cosmos/cosmos-sdk/orm/testdata" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/errors" +) + +func TestCreate(t *testing.T) { + specs := map[string]struct { + src codec.ProtoMarshaler + expErr *errors.Error + }{ + "happy path": { + src: &testdata.GroupInfo{ + Description: "my group", + Admin: sdk.AccAddress([]byte("my-admin-address")), + }, + }, + "wrong type": { + src: &testdata.GroupMember{ + Group: sdk.AccAddress(orm.EncodeSequence(1)), + Member: sdk.AccAddress([]byte("member-address")), + Weight: 10, + }, + expErr: orm.ErrType, + }, + "model validation fails": { + src: &testdata.GroupInfo{Description: "invalid"}, + expErr: testdata.ErrTest, + }, + } + for msg, spec := range specs { + t.Run(msg, func(t *testing.T) { + interfaceRegistry := types.NewInterfaceRegistry() + cdc := codec.NewProtoCodec(interfaceRegistry) + + storeKey := sdk.NewKVStoreKey("test") + const anyPrefix = 0x10 + tableBuilder := orm.NewTableBuilder(anyPrefix, storeKey, &testdata.GroupInfo{}, orm.Max255DynamicLengthIndexKeyCodec{}, cdc) + myTable := tableBuilder.Build() + + ctx := orm.NewMockContext() + err := myTable.Create(ctx, []byte("my-id"), spec.src) + + require.True(t, spec.expErr.Is(err), err) + shouldExists := spec.expErr == nil + assert.Equal(t, shouldExists, myTable.Has(ctx, []byte("my-id")), fmt.Sprintf("expected %v", shouldExists)) + + // then + var loaded testdata.GroupInfo + err = myTable.GetOne(ctx, []byte("my-id"), &loaded) + if spec.expErr != nil { + require.True(t, orm.ErrNotFound.Is(err)) + return + } + require.NoError(t, err) + assert.Equal(t, spec.src, &loaded) + }) + } + +} +func TestUpdate(t *testing.T) { + specs := map[string]struct { + src codec.ProtoMarshaler + expErr *errors.Error + }{ + "happy path": { + src: &testdata.GroupInfo{ + Description: "my group", + Admin: sdk.AccAddress([]byte("my-admin-address")), + }, + }, + "wrong type": { + src: &testdata.GroupMember{ + Group: sdk.AccAddress(orm.EncodeSequence(1)), + Member: sdk.AccAddress([]byte("member-address")), + Weight: 9999, + }, + expErr: orm.ErrType, + }, + "model validation fails": { + src: &testdata.GroupInfo{Description: "invalid"}, + expErr: testdata.ErrTest, + }, + } + for msg, spec := range specs { + t.Run(msg, func(t *testing.T) { + interfaceRegistry := types.NewInterfaceRegistry() + cdc := codec.NewProtoCodec(interfaceRegistry) + + storeKey := sdk.NewKVStoreKey("test") + const anyPrefix = 0x10 + tableBuilder := orm.NewTableBuilder(anyPrefix, storeKey, &testdata.GroupInfo{}, orm.Max255DynamicLengthIndexKeyCodec{}, cdc) + myTable := tableBuilder.Build() + + initValue := testdata.GroupInfo{ + Description: "my old group description", + Admin: sdk.AccAddress([]byte("my-old-admin-address")), + 
} + ctx := orm.NewMockContext() + err := myTable.Create(ctx, []byte("my-id"), &initValue) + require.NoError(t, err) + + // when + err = myTable.Save(ctx, []byte("my-id"), spec.src) + require.True(t, spec.expErr.Is(err), "got ", err) + + // then + var loaded testdata.GroupInfo + require.NoError(t, myTable.GetOne(ctx, []byte("my-id"), &loaded)) + if spec.expErr == nil { + assert.Equal(t, spec.src, &loaded) + } else { + assert.Equal(t, initValue, loaded) + } + }) + } + +} diff --git a/orm/testdata/codec.pb.go b/orm/testdata/codec.pb.go new file mode 100644 index 000000000000..237ff7159292 --- /dev/null +++ b/orm/testdata/codec.pb.go @@ -0,0 +1,679 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: codec.proto + +package testdata + +import ( + fmt "fmt" + github_com_cosmos_cosmos_sdk_types "github.com/cosmos/cosmos-sdk/types" + _ "github.com/gogo/protobuf/gogoproto" + proto "github.com/gogo/protobuf/proto" + io "io" + math "math" + math_bits "math/bits" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package + +type GroupInfo struct { + GroupId uint64 `protobuf:"varint,1,opt,name=group_id,json=groupId,proto3" json:"group_id,omitempty"` + Description string `protobuf:"bytes,2,opt,name=description,proto3" json:"description,omitempty"` + Admin github_com_cosmos_cosmos_sdk_types.AccAddress `protobuf:"bytes,3,opt,name=admin,proto3,casttype=github.com/cosmos/cosmos-sdk/types.AccAddress" json:"admin,omitempty"` +} + +func (m *GroupInfo) Reset() { *m = GroupInfo{} } +func (m *GroupInfo) String() string { return proto.CompactTextString(m) } +func (*GroupInfo) ProtoMessage() {} +func (*GroupInfo) Descriptor() ([]byte, []int) { + return fileDescriptor_9610d574777ab505, []int{0} +} +func (m *GroupInfo) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *GroupInfo) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_GroupInfo.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *GroupInfo) XXX_Merge(src proto.Message) { + xxx_messageInfo_GroupInfo.Merge(m, src) +} +func (m *GroupInfo) XXX_Size() int { + return m.Size() +} +func (m *GroupInfo) XXX_DiscardUnknown() { + xxx_messageInfo_GroupInfo.DiscardUnknown(m) +} + +var xxx_messageInfo_GroupInfo proto.InternalMessageInfo + +func (m *GroupInfo) GetGroupId() uint64 { + if m != nil { + return m.GroupId + } + return 0 +} + +func (m *GroupInfo) GetDescription() string { + if m != nil { + return m.Description + } + return "" +} + +func (m *GroupInfo) GetAdmin() github_com_cosmos_cosmos_sdk_types.AccAddress { + if m != nil { + return m.Admin + } + return nil +} + +type GroupMember struct { + Group github_com_cosmos_cosmos_sdk_types.AccAddress `protobuf:"bytes,1,opt,name=group,proto3,casttype=github.com/cosmos/cosmos-sdk/types.AccAddress" json:"group,omitempty"` + Member github_com_cosmos_cosmos_sdk_types.AccAddress `protobuf:"bytes,2,opt,name=member,proto3,casttype=github.com/cosmos/cosmos-sdk/types.AccAddress" json:"member,omitempty"` + Weight uint64 
`protobuf:"varint,3,opt,name=weight,proto3" json:"weight,omitempty"` +} + +func (m *GroupMember) Reset() { *m = GroupMember{} } +func (m *GroupMember) String() string { return proto.CompactTextString(m) } +func (*GroupMember) ProtoMessage() {} +func (*GroupMember) Descriptor() ([]byte, []int) { + return fileDescriptor_9610d574777ab505, []int{1} +} +func (m *GroupMember) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *GroupMember) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_GroupMember.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *GroupMember) XXX_Merge(src proto.Message) { + xxx_messageInfo_GroupMember.Merge(m, src) +} +func (m *GroupMember) XXX_Size() int { + return m.Size() +} +func (m *GroupMember) XXX_DiscardUnknown() { + xxx_messageInfo_GroupMember.DiscardUnknown(m) +} + +var xxx_messageInfo_GroupMember proto.InternalMessageInfo + +func (m *GroupMember) GetGroup() github_com_cosmos_cosmos_sdk_types.AccAddress { + if m != nil { + return m.Group + } + return nil +} + +func (m *GroupMember) GetMember() github_com_cosmos_cosmos_sdk_types.AccAddress { + if m != nil { + return m.Member + } + return nil +} + +func (m *GroupMember) GetWeight() uint64 { + if m != nil { + return m.Weight + } + return 0 +} + +func init() { + proto.RegisterType((*GroupInfo)(nil), "testdata.GroupInfo") + proto.RegisterType((*GroupMember)(nil), "testdata.GroupMember") +} + +func init() { proto.RegisterFile("codec.proto", fileDescriptor_9610d574777ab505) } + +var fileDescriptor_9610d574777ab505 = []byte{ + // 296 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x9c, 0x91, 0xb1, 0x4e, 0xc3, 0x30, + 0x10, 0x86, 0x6b, 0x28, 0xa5, 0x75, 0x99, 0x2c, 0x84, 0x0a, 0x83, 0x89, 0x3a, 0x75, 0x69, 0x23, + 0xe0, 0x09, 0xda, 0xa5, 0xaa, 0x04, 0x4b, 0x46, 0x16, 0x94, 0xd8, 0x87, 0x6b, 0x15, 0xe7, 0x22, + 0xdb, 0x55, 0xc5, 0x5b, 0xc0, 0xe3, 0xf0, 0x06, 0x8c, 0x1d, 0x99, 0x10, 0x4a, 0xde, 0x82, 0x09, + 0xc5, 0x0d, 0x52, 0xe7, 0x4e, 0xf6, 0x7f, 0xd6, 0x7d, 0xfe, 0x4e, 0x47, 0xfb, 0x02, 0x25, 0x88, + 0x49, 0x61, 0xd1, 0x23, 0xeb, 0x7a, 0x70, 0x5e, 0xa6, 0x3e, 0xbd, 0x3a, 0x57, 0xa8, 0x30, 0x14, + 0xe3, 0xfa, 0xb6, 0x7b, 0x1f, 0xbe, 0x13, 0xda, 0x9b, 0x5b, 0x5c, 0x17, 0x8b, 0xfc, 0x19, 0xd9, + 0x25, 0xed, 0xaa, 0x3a, 0x3c, 0x69, 0x39, 0x20, 0x11, 0x19, 0xb5, 0x93, 0xd3, 0x90, 0x17, 0x92, + 0x45, 0xb4, 0x2f, 0xc1, 0x09, 0xab, 0x0b, 0xaf, 0x31, 0x1f, 0x1c, 0x45, 0x64, 0xd4, 0x4b, 0xf6, + 0x4b, 0x6c, 0x4e, 0x4f, 0x52, 0x69, 0x74, 0x3e, 0x38, 0x8e, 0xc8, 0xe8, 0x6c, 0x76, 0xf3, 0xfb, + 0x7d, 0x3d, 0x56, 0xda, 0x2f, 0xd7, 0xd9, 0x44, 0xa0, 0x89, 0x05, 0x3a, 0x83, 0xae, 0x39, 0xc6, + 0x4e, 0xae, 0x62, 0xff, 0x5a, 0x80, 0x9b, 0x4c, 0x85, 0x98, 0x4a, 0x69, 0xc1, 0xb9, 0x64, 0xd7, + 0x3f, 0xfc, 0x20, 0xb4, 0x1f, 0x9c, 0x1e, 0xc0, 0x64, 0x60, 0x6b, 0x70, 0xb0, 0x08, 0x4a, 0x87, + 0x81, 0x43, 0x3f, 0x5b, 0xd0, 0x8e, 0x09, 0xc8, 0xa0, 0x7f, 0x10, 0xa9, 0x01, 0xb0, 0x0b, 0xda, + 0xd9, 0x80, 0x56, 0x4b, 0x1f, 0xa6, 0x6d, 0x27, 0x4d, 0x9a, 0xdd, 0x7f, 0x96, 0x9c, 0x6c, 0x4b, + 0x4e, 0x7e, 0x4a, 0x4e, 0xde, 0x2a, 0xde, 0xda, 0x56, 0xbc, 0xf5, 0x55, 0xf1, 0xd6, 0xe3, 0xed, + 0xde, 0x47, 0x16, 0x14, 0xe4, 0xe3, 0x1c, 0xfc, 0x06, 0xed, 0xaa, 0x49, 0x2f, 0x20, 0x15, 0xd8, + 0x18, 0xad, 0x89, 0xff, 0x77, 0x96, 0x75, 0xc2, 0x92, 0xee, 0xfe, 0x02, 0x00, 0x00, 0xff, 0xff, + 0xb9, 0xf0, 0xdf, 0xf7, 
0xd3, 0x01, 0x00, 0x00, +} + +func (m *GroupInfo) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *GroupInfo) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *GroupInfo) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if len(m.Admin) > 0 { + i -= len(m.Admin) + copy(dAtA[i:], m.Admin) + i = encodeVarintCodec(dAtA, i, uint64(len(m.Admin))) + i-- + dAtA[i] = 0x1a + } + if len(m.Description) > 0 { + i -= len(m.Description) + copy(dAtA[i:], m.Description) + i = encodeVarintCodec(dAtA, i, uint64(len(m.Description))) + i-- + dAtA[i] = 0x12 + } + if m.GroupId != 0 { + i = encodeVarintCodec(dAtA, i, uint64(m.GroupId)) + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + +func (m *GroupMember) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *GroupMember) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *GroupMember) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.Weight != 0 { + i = encodeVarintCodec(dAtA, i, uint64(m.Weight)) + i-- + dAtA[i] = 0x18 + } + if len(m.Member) > 0 { + i -= len(m.Member) + copy(dAtA[i:], m.Member) + i = encodeVarintCodec(dAtA, i, uint64(len(m.Member))) + i-- + dAtA[i] = 0x12 + } + if len(m.Group) > 0 { + i -= len(m.Group) + copy(dAtA[i:], m.Group) + i = encodeVarintCodec(dAtA, i, uint64(len(m.Group))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func encodeVarintCodec(dAtA []byte, offset int, v uint64) int { + offset -= sovCodec(v) + base := offset + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return base +} +func (m *GroupInfo) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.GroupId != 0 { + n += 1 + sovCodec(uint64(m.GroupId)) + } + l = len(m.Description) + if l > 0 { + n += 1 + l + sovCodec(uint64(l)) + } + l = len(m.Admin) + if l > 0 { + n += 1 + l + sovCodec(uint64(l)) + } + return n +} + +func (m *GroupMember) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.Group) + if l > 0 { + n += 1 + l + sovCodec(uint64(l)) + } + l = len(m.Member) + if l > 0 { + n += 1 + l + sovCodec(uint64(l)) + } + if m.Weight != 0 { + n += 1 + sovCodec(uint64(m.Weight)) + } + return n +} + +func sovCodec(x uint64) (n int) { + return (math_bits.Len64(x|1) + 6) / 7 +} +func sozCodec(x uint64) (n int) { + return sovCodec(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (m *GroupInfo) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowCodec + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: GroupInfo: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: GroupInfo: illegal tag %d (wire type %d)", fieldNum, wire) + } + 
switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field GroupId", wireType) + } + m.GroupId = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowCodec + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.GroupId |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Description", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowCodec + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthCodec + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthCodec + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Description = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Admin", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowCodec + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthCodec + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthCodec + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Admin = append(m.Admin[:0], dAtA[iNdEx:postIndex]...) + if m.Admin == nil { + m.Admin = []byte{} + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipCodec(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthCodec + } + if (iNdEx + skippy) < 0 { + return ErrInvalidLengthCodec + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *GroupMember) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowCodec + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: GroupMember: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: GroupMember: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Group", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowCodec + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthCodec + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthCodec + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Group = append(m.Group[:0], dAtA[iNdEx:postIndex]...) 
+ if m.Group == nil { + m.Group = []byte{} + } + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Member", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowCodec + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthCodec + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthCodec + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Member = append(m.Member[:0], dAtA[iNdEx:postIndex]...) + if m.Member == nil { + m.Member = []byte{} + } + iNdEx = postIndex + case 3: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Weight", wireType) + } + m.Weight = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowCodec + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Weight |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipCodec(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthCodec + } + if (iNdEx + skippy) < 0 { + return ErrInvalidLengthCodec + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipCodec(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + depth := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowCodec + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowCodec + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + case 1: + iNdEx += 8 + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowCodec + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if length < 0 { + return 0, ErrInvalidLengthCodec + } + iNdEx += length + case 3: + depth++ + case 4: + if depth == 0 { + return 0, ErrUnexpectedEndOfGroupCodec + } + depth-- + case 5: + iNdEx += 4 + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + if iNdEx < 0 { + return 0, ErrInvalidLengthCodec + } + if depth == 0 { + return iNdEx, nil + } + } + return 0, io.ErrUnexpectedEOF +} + +var ( + ErrInvalidLengthCodec = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowCodec = fmt.Errorf("proto: integer overflow") + ErrUnexpectedEndOfGroupCodec = fmt.Errorf("proto: unexpected end of group") +) diff --git a/orm/testdata/codec.proto b/orm/testdata/codec.proto new file mode 100644 index 000000000000..2f3a679fc853 --- /dev/null +++ b/orm/testdata/codec.proto @@ -0,0 +1,22 @@ +syntax = "proto3"; + +package cosmos.orm.testdata; + +import "gogoproto/gogo.proto"; + +option go_package = "github.com/cosmos/cosmos-sdk/orm/testdata"; + +message GroupInfo { + uint64 group_id = 1; + string description = 2; + bytes admin = 3 [ (gogoproto.casttype) = + 
"github.com/cosmos/cosmos-sdk/types.AccAddress" ]; +} + +message GroupMember { + bytes group = 1 [ (gogoproto.casttype) = + "github.com/cosmos/cosmos-sdk/types.AccAddress" ]; + bytes member = 2 [ (gogoproto.casttype) = + "github.com/cosmos/cosmos-sdk/types.AccAddress" ]; + uint64 weight = 3; +} diff --git a/orm/testdata/model.go b/orm/testdata/model.go new file mode 100644 index 000000000000..4936c9c07254 --- /dev/null +++ b/orm/testdata/model.go @@ -0,0 +1,32 @@ +package testdata + +import ( + "github.com/cosmos/cosmos-sdk/orm" + "github.com/cosmos/cosmos-sdk/types/errors" +) + +var ( + ErrTest = errors.Register("orm_testdata", 9999, "test") +) + +func (g GroupMember) PrimaryKey() []byte { + result := make([]byte, 0, len(g.Group)+len(g.Member)) + result = append(result, g.Group...) + result = append(result, g.Member...) + return result +} + +func (g GroupInfo) PrimaryKey() []byte { + return orm.EncodeSequence(g.GroupId) +} + +func (g GroupInfo) ValidateBasic() error { + if g.Description == "invalid" { + return errors.Wrap(ErrTest, "description") + } + return nil +} + +func (g GroupMember) ValidateBasic() error { + return nil +} diff --git a/orm/testsupport.go b/orm/testsupport.go new file mode 100644 index 000000000000..ec1352a45903 --- /dev/null +++ b/orm/testsupport.go @@ -0,0 +1,137 @@ +package orm + +import ( + "fmt" + "io" + + dbm "github.com/tendermint/tm-db" + + "github.com/cosmos/cosmos-sdk/store" + "github.com/cosmos/cosmos-sdk/store/gaskv" + "github.com/cosmos/cosmos-sdk/store/types" + sdk "github.com/cosmos/cosmos-sdk/types" +) + +type MockContext struct { + db *dbm.MemDB + store types.CommitMultiStore +} + +func NewMockContext() *MockContext { + db := dbm.NewMemDB() + return &MockContext{ + db: dbm.NewMemDB(), + store: store.NewCommitMultiStore(db), + } +} + +func (m MockContext) KVStore(key sdk.StoreKey) sdk.KVStore { + if s := m.store.GetCommitKVStore(key); s != nil { + return s + } + m.store.MountStoreWithDB(key, sdk.StoreTypeIAVL, m.db) + if err := m.store.LoadLatestVersion(); err != nil { + panic(err) + } + return m.store.GetCommitKVStore(key) +} + +type debuggingGasMeter struct { + g types.GasMeter +} + +func (d debuggingGasMeter) GasConsumed() types.Gas { + return d.g.GasConsumed() +} + +func (d debuggingGasMeter) GasConsumedToLimit() types.Gas { + return d.g.GasConsumedToLimit() +} + +func (d debuggingGasMeter) Limit() types.Gas { + return d.g.Limit() +} + +func (d debuggingGasMeter) ConsumeGas(amount types.Gas, descriptor string) { + fmt.Printf("++ Consuming gas: %q :%d\n", descriptor, amount) + d.g.ConsumeGas(amount, descriptor) +} + +func (d debuggingGasMeter) RefundGas(amount types.Gas, descriptor string) { + fmt.Printf("-- Refunding gas: %q :%d\n", descriptor, amount) + d.g.RefundGas(amount, descriptor) +} + +func (d debuggingGasMeter) IsPastLimit() bool { + return d.g.IsPastLimit() +} + +func (d debuggingGasMeter) IsOutOfGas() bool { + return d.g.IsOutOfGas() +} + +func (d debuggingGasMeter) String() string { + return d.g.String() +} + +type GasCountingMockContext struct { + parent HasKVStore + GasMeter sdk.GasMeter +} + +func NewGasCountingMockContext(parent HasKVStore) *GasCountingMockContext { + return &GasCountingMockContext{ + parent: parent, + GasMeter: &debuggingGasMeter{sdk.NewInfiniteGasMeter()}, + } +} + +func (g GasCountingMockContext) KVStore(key sdk.StoreKey) sdk.KVStore { + return gaskv.NewStore(g.parent.KVStore(key), g.GasMeter, types.KVGasConfig()) +} + +func (g GasCountingMockContext) GasConsumed() types.Gas { + return g.GasMeter.GasConsumed() +} 
+ +func (g *GasCountingMockContext) ResetGasMeter() { + g.GasMeter = sdk.NewInfiniteGasMeter() +} + +type AlwaysPanicKVStore struct{} + +func (a AlwaysPanicKVStore) GetStoreType() types.StoreType { + panic("Not implemented") +} + +func (a AlwaysPanicKVStore) CacheWrap() types.CacheWrap { + panic("Not implemented") +} + +func (a AlwaysPanicKVStore) CacheWrapWithTrace(w io.Writer, tc types.TraceContext) types.CacheWrap { + panic("Not implemented") +} + +func (a AlwaysPanicKVStore) Get(key []byte) []byte { + panic("Not implemented") +} + +func (a AlwaysPanicKVStore) Has(key []byte) bool { + panic("Not implemented") +} + +func (a AlwaysPanicKVStore) Set(key, value []byte) { + panic("Not implemented") +} + +func (a AlwaysPanicKVStore) Delete(key []byte) { + panic("Not implemented") +} + +func (a AlwaysPanicKVStore) Iterator(start, end []byte) types.Iterator { + panic("Not implemented") +} + +func (a AlwaysPanicKVStore) ReverseIterator(start, end []byte) types.Iterator { + panic("Not implemented") +} diff --git a/orm/uint64_index.go b/orm/uint64_index.go new file mode 100644 index 000000000000..98e9b8cf6349 --- /dev/null +++ b/orm/uint64_index.go @@ -0,0 +1,85 @@ +package orm + +import ( + "github.com/cosmos/cosmos-sdk/types/query" +) + +// UInt64IndexerFunc creates one or multiple multiKeyIndex keys of type uint64 for the source object. +type UInt64IndexerFunc func(value interface{}) ([]uint64, error) + +// UInt64MultiKeyAdapter converts UInt64IndexerFunc to IndexerFunc +func UInt64MultiKeyAdapter(indexer UInt64IndexerFunc) IndexerFunc { + return func(value interface{}) ([]RowID, error) { + d, err := indexer(value) + if err != nil { + return nil, err + } + r := make([]RowID, len(d)) + for i, v := range d { + r[i] = EncodeSequence(v) + } + return r, nil + } +} + +// UInt64Index is a typed index. +type UInt64Index struct { + multiKeyIndex MultiKeyIndex +} + +// NewUInt64Index creates a typed secondary index +func NewUInt64Index(builder Indexable, prefix byte, indexer UInt64IndexerFunc) UInt64Index { + return UInt64Index{ + multiKeyIndex: NewIndex(builder, prefix, UInt64MultiKeyAdapter(indexer)), + } +} + +// Has checks if a key exists. Panics on nil key. +func (i UInt64Index) Has(ctx HasKVStore, key uint64) bool { + return i.multiKeyIndex.Has(ctx, EncodeSequence(key)) +} + +// Get returns a result iterator for the searchKey. Parameters must not be nil. +func (i UInt64Index) Get(ctx HasKVStore, searchKey uint64) (Iterator, error) { + return i.multiKeyIndex.Get(ctx, EncodeSequence(searchKey)) +} + +// GetPaginated creates an iterator for the searchKey +// starting from pageRequest.Key if provided. +// The pageRequest.Key is the rowID while searchKey is a MultiKeyIndex key. +func (i UInt64Index) GetPaginated(ctx HasKVStore, searchKey uint64, pageRequest *query.PageRequest) (Iterator, error) { + return i.multiKeyIndex.GetPaginated(ctx, EncodeSequence(searchKey), pageRequest) +} + +// PrefixScan returns an Iterator over a domain of keys in ascending order. End is exclusive. +// Start is an MultiKeyIndex key or prefix. It must be less than end, or the Iterator is invalid and error is returned. +// Iterator must be closed by caller. +// To iterate over entire domain, use PrefixScan(nil, nil) +// +// WARNING: The use of a PrefixScan can be very expensive in terms of Gas. Please make sure you do not expose +// this as an endpoint to the public without further limits. 
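+// See `LimitIterator` as one way to cap the number of results.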
+// Example: +// it, err := idx.PrefixScan(ctx, start, end) +// if err !=nil { +// return err +// } +// const defaultLimit = 20 +// it = LimitIterator(it, defaultLimit) +// +// CONTRACT: No writes may happen within a domain while an iterator exists over it. +func (i UInt64Index) PrefixScan(ctx HasKVStore, start, end uint64) (Iterator, error) { + return i.multiKeyIndex.PrefixScan(ctx, EncodeSequence(start), EncodeSequence(end)) +} + +// ReversePrefixScan returns an Iterator over a domain of keys in descending order. End is exclusive. +// Start is an MultiKeyIndex key or prefix. It must be less than end, or the Iterator is invalid and error is returned. +// Iterator must be closed by caller. +// To iterate over entire domain, use PrefixScan(nil, nil) +// +// WARNING: The use of a ReversePrefixScan can be very expensive in terms of Gas. Please make sure you do not expose +// this as an endpoint to the public without further limits. See `LimitIterator` +// +// CONTRACT: No writes may happen within a domain while an iterator exists over it. +func (i UInt64Index) ReversePrefixScan(ctx HasKVStore, start, end uint64) (Iterator, error) { + return i.multiKeyIndex.ReversePrefixScan(ctx, EncodeSequence(start), EncodeSequence(end)) +} diff --git a/orm/uint64_index_test.go b/orm/uint64_index_test.go new file mode 100644 index 000000000000..7539a08c111c --- /dev/null +++ b/orm/uint64_index_test.go @@ -0,0 +1,163 @@ +package orm_test + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cosmos/cosmos-sdk/codec" + "github.com/cosmos/cosmos-sdk/codec/types" + "github.com/cosmos/cosmos-sdk/orm" + "github.com/cosmos/cosmos-sdk/orm/testdata" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/query" +) + +func TestUInt64Index(t *testing.T) { + interfaceRegistry := types.NewInterfaceRegistry() + cdc := codec.NewProtoCodec(interfaceRegistry) + + storeKey := sdk.NewKVStoreKey("test") + + const anyPrefix = 0x10 + tableBuilder := orm.NewPrimaryKeyTableBuilder(anyPrefix, storeKey, &testdata.GroupMember{}, orm.Max255DynamicLengthIndexKeyCodec{}, cdc) + myIndex := orm.NewUInt64Index(tableBuilder, GroupMemberByMemberIndexPrefix, func(val interface{}) ([]uint64, error) { + return []uint64{uint64(val.(*testdata.GroupMember).Member[0])}, nil + }) + myTable := tableBuilder.Build() + + ctx := orm.NewMockContext() + + m := testdata.GroupMember{ + Group: sdk.AccAddress(orm.EncodeSequence(1)), + Member: sdk.AccAddress([]byte("member-address")), + Weight: 10, + } + err := myTable.Create(ctx, &m) + require.NoError(t, err) + + indexedKey := uint64('m') + + // Has + assert.True(t, myIndex.Has(ctx, indexedKey)) + + // Get + it, err := myIndex.Get(ctx, indexedKey) + require.NoError(t, err) + var loaded testdata.GroupMember + rowID, err := it.LoadNext(&loaded) + require.NoError(t, err) + require.Equal(t, uint64(1), orm.DecodeSequence(rowID)) + require.Equal(t, m, loaded) + + // GetPaginated + cases := map[string]struct { + pageReq *query.PageRequest + expErr bool + }{ + "nil key": { + pageReq: &query.PageRequest{Key: nil}, + expErr: false, + }, + "after indexed key": { + pageReq: &query.PageRequest{Key: []byte{byte('m')}}, + expErr: true, + }, + } + + for testName, tc := range cases { + t.Run(testName, func(t *testing.T) { + it, err := myIndex.GetPaginated(ctx, indexedKey, tc.pageReq) + require.NoError(t, err) + rowID, err := it.LoadNext(&loaded) + if tc.expErr { // iterator done + require.Error(t, err) + } else { + 
require.NoError(t, err)
+				require.Equal(t, orm.RowID(m.PrimaryKey()), rowID)
+				require.Equal(t, m, loaded)
+			}
+		})
+	}
+
+	// PrefixScan match
+	it, err = myIndex.PrefixScan(ctx, 0, 255)
+	require.NoError(t, err)
+	rowID, err = it.LoadNext(&loaded)
+	require.NoError(t, err)
+	require.Equal(t, uint64(1), orm.DecodeSequence(rowID))
+	require.Equal(t, m, loaded)
+
+	// PrefixScan no match
+	it, err = myIndex.PrefixScan(ctx, indexedKey+1, 255)
+	require.NoError(t, err)
+	_, err = it.LoadNext(&loaded)
+	require.True(t, orm.ErrIteratorDone.Is(err))
+
+	// ReversePrefixScan match
+	it, err = myIndex.ReversePrefixScan(ctx, 0, 255)
+	require.NoError(t, err)
+	rowID, err = it.LoadNext(&loaded)
+	require.NoError(t, err)
+	require.Equal(t, uint64(1), orm.DecodeSequence(rowID))
+	require.Equal(t, m, loaded)
+
+	// ReversePrefixScan no match
+	it, err = myIndex.ReversePrefixScan(ctx, indexedKey+1, 255)
+	require.NoError(t, err)
+	_, err = it.LoadNext(&loaded)
+	require.True(t, orm.ErrIteratorDone.Is(err))
+}
+
+func TestUInt64MultiKeyAdapter(t *testing.T) {
+	specs := map[string]struct {
+		srcFunc orm.UInt64IndexerFunc
+		exp     []orm.RowID
+		expErr  error
+	}{
+		"single key": {
+			srcFunc: func(value interface{}) ([]uint64, error) {
+				return []uint64{1}, nil
+			},
+			exp: []orm.RowID{{0, 0, 0, 0, 0, 0, 0, 1}},
+		},
+		"multi key": {
+			srcFunc: func(value interface{}) ([]uint64, error) {
+				return []uint64{1, 1 << 56}, nil
+			},
+			exp: []orm.RowID{{0, 0, 0, 0, 0, 0, 0, 1}, {1, 0, 0, 0, 0, 0, 0, 0}},
+		},
+		"empty key": {
+			srcFunc: func(value interface{}) ([]uint64, error) {
+				return []uint64{}, nil
+			},
+			exp: []orm.RowID{},
+		},
+		"nil key": {
+			srcFunc: func(value interface{}) ([]uint64, error) {
+				return nil, nil
+			},
+			exp: []orm.RowID{},
+		},
+		"error case": {
+			srcFunc: func(value interface{}) ([]uint64, error) {
+				return nil, errors.New("test")
+			},
+			expErr: errors.New("test"),
+		},
+	}
+	for msg, spec := range specs {
+		t.Run(msg, func(t *testing.T) {
+			fn := orm.UInt64MultiKeyAdapter(spec.srcFunc)
+			r, err := fn(nil)
+			if spec.expErr != nil {
+				require.Equal(t, spec.expErr, err)
+				return
+			}
+			require.NoError(t, err)
+			assert.Equal(t, spec.exp, r)
+		})
+	}
+}
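+
+// Note: UInt64MultiKeyAdapter encodes every key via EncodeSequence, i.e. as 8
+// big-endian bytes, which is why 1 maps to {0, 0, 0, 0, 0, 0, 0, 1} and 1 << 56
+// maps to {1, 0, 0, 0, 0, 0, 0, 0} in the expectations above.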