diff --git a/README.md b/README.md index 4fb363f..f3494fd 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,15 @@ -## dynamo [![GoDoc](https://godoc.org/github.com/guregu/dynamo?status.svg)](https://godoc.org/github.com/guregu/dynamo) -`import "github.com/guregu/dynamo"` +## dynamo [![GoDoc](https://godoc.org/github.com/guregu/dynamo/v2?status.svg)](https://godoc.org/github.com/guregu/dynamo/v2) +`import "github.com/guregu/dynamo/v2"` -dynamo is an expressive [DynamoDB](https://aws.amazon.com/dynamodb/) client for Go, with an easy but powerful API. dynamo integrates with the official [AWS SDK](https://github.com/aws/aws-sdk-go/). +dynamo is an expressive [DynamoDB](https://aws.amazon.com/dynamodb/) client for Go, with an easy but powerful API. dynamo integrates with the official [AWS SDK v2](https://github.com/aws/aws-sdk-go-v2/). This library is stable and versioned with Go modules. +> [!TIP] +> dynamo v2 is finally released! See [**v2 Migration**](#migrating-from-v1) for tips on migrating from dynamo v1. +> +> For dynamo v1, which uses [aws-sdk-go v1](https://github.com/aws/aws-sdk-go/), see: [**dynamo v1 documentation**](https://pkg.go.dev/github.com/guregu/dynamo). + ### Example ```go @@ -12,10 +17,12 @@ package dynamo import ( "time" + "context" + "log" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/guregu/dynamo" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/guregu/dynamo/v2" ) // Use struct tags much like the standard JSON library, @@ -34,27 +41,30 @@ type widget struct { func main() { - sess := session.Must(session.NewSession()) - db := dynamo.New(sess, &aws.Config{Region: aws.String("us-west-2")}) + cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion("us-east-1")) + if err != nil { + log.Fatalf("unable to load SDK config, %v", err) + } + db := dynamo.New(cfg) table := db.Table("Widgets") // put item w := widget{UserID: 613, Time: time.Now(), Msg: "hello"} - err := table.Put(w).Run() + err = table.Put(w).Run(ctx) // get the same item var result widget err = table.Get("UserID", w.UserID). Range("Time", dynamo.Equal, w.Time). - One(&result) + One(ctx, &result) // get all items var results []widget - err = table.Scan().All(&results) + err = table.Scan().All(ctx, &results) // use placeholders in filter expressions (see Expressions section below) var filtered []widget - err = table.Scan().Filter("'Count' > ?", 10).All(&filtered) + err = table.Scan().Filter("'Count' > ?", 10).All(ctx, &filtered) } ``` @@ -71,14 +81,14 @@ Please see the [DynamoDB reference on expressions](http://docs.aws.amazon.com/am ```go // Using single quotes to escape a reserved word, and a question mark as a value placeholder. // Finds all items whose date is greater than or equal to lastUpdate. -table.Scan().Filter("'Date' >= ?", lastUpdate).All(&results) +table.Scan().Filter("'Date' >= ?", lastUpdate).All(ctx, &results) // Using dollar signs as a placeholder for attribute names. // Deletes the item with an ID of 42 if its score is at or below the cutoff, and its name starts with G. -table.Delete("ID", 42).If("Score <= ? AND begins_with($, ?)", cutoff, "Name", "G").Run() +table.Delete("ID", 42).If("Score <= ? AND begins_with($, ?)", cutoff, "Name", "G").Run(ctx) // Put a new item, only if it doesn't already exist. -table.Put(item{ID: 42}).If("attribute_not_exists(ID)").Run() +table.Put(item{ID: 42}).If("attribute_not_exists(ID)").Run(ctx) ``` ### Encoding support @@ -177,42 +187,38 @@ This creates a table with the primary hash key ID and range key Time. It creates ### Retrying -Requests that fail with certain errors (e.g. `ThrottlingException`) are [automatically retried](https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Programming.Errors.html#Programming.Errors.RetryAndBackoff). -Methods that take a `context.Context` will retry until the context is canceled. -Methods without a context will use the `RetryTimeout` global variable, which can be changed; using context is recommended instead. - -#### Limiting or disabling retrying +As of v2, dynamo relies on the AWS SDK for retrying. See: [**Retries and Timeouts documentation**](https://aws.github.io/aws-sdk-go-v2/docs/configuring-sdk/retries-timeouts/) for information about how to configure its behavior. -The maximum number of retries can be configured via the `MaxRetries` field in the `*aws.Config` passed to `dynamo.New()`. A value of `0` will disable retrying. A value of `-1` means unlimited and is the default (however, context or `RetryTimeout` will still apply). +By default, canceled transactions (i.e. errors from conflicting transactions) will not be retried. To get automatic retrying behavior like in v1, use [`dynamo.RetryTxConflicts`](https://godoc.org/github.com/guregu/dynamo/v2#RetryTxConflicts). ```go -db := dynamo.New(session, &aws.Config{ - MaxRetries: aws.Int(0), // disables automatic retrying -}) -``` - -#### Custom retrying logic - -If a custom [`request.Retryer`](https://pkg.go.dev/github.com/aws/aws-sdk-go/aws/request#Retryer) is set via the `Retryer` field in `*aws.Config`, dynamo will delegate retrying entirely to it, taking precedence over other retrying settings. This allows you to have full control over all aspects of retrying. +import ( + "context" + "log" -Example using [`client.DefaultRetryer`](https://pkg.go.dev/github.com/aws/aws-sdk-go/aws/client#DefaultRetryer): + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/retry" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/guregu/dynamo/v2" +) -```go -retryer := client.DefaultRetryer{ - NumMaxRetries: 10, - MinThrottleDelay: 500 * time.Millisecond, - MaxThrottleDelay: 30 * time.Second, +func main() { + cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRetryer(func() aws.Retryer { + return retry.NewStandard(dynamo.RetryTxConflicts) + })) + if err != nil { + log.Fatal(err) + } + db := dynamo.New(cfg) + // use db } -db := dynamo.New(session, &aws.Config{ - Retryer: retryer, -}) ``` ### Compatibility with the official AWS library -dynamo has been in development before the official AWS libraries were stable. We use a different encoder and decoder than the [dynamodbattribute](https://godoc.org/github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute) package. dynamo uses the `dynamo` struct tag instead of the `dynamodbav` struct tag, and we also prefer to automatically omit invalid values such as empty strings, whereas the dynamodbattribute package substitutes null values for them. Items that satisfy the [`dynamodbattribute.(Un)marshaler`](https://godoc.org/github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute#Marshaler) interfaces are compatibile with both libraries. +dynamo has been in development before the official AWS libraries were stable. We use a different encoder and decoder than the [dynamodbattribute](https://pkg.go.dev/github.com/jviney/aws-sdk-go-v2/service/dynamodb/dynamodbattribute) package. dynamo uses the `dynamo` struct tag instead of the `dynamodbav` struct tag, and we also prefer to automatically omit invalid values such as empty strings, whereas the dynamodbattribute package substitutes null values for them. Items that satisfy the [`dynamodbattribute.(Un)marshaler`](https://pkg.go.dev/github.com/jviney/aws-sdk-go-v2/service/dynamodb/dynamodbattribute#Marshaler) interfaces are compatibile with both libraries. -In order to use dynamodbattribute's encoding facilities, you must wrap objects passed to dynamo with [`dynamo.AWSEncoding`](https://godoc.org/github.com/guregu/dynamo#AWSEncoding). Here is a quick example: +In order to use dynamodbattribute's encoding facilities, you must wrap objects passed to dynamo with [`dynamo.AWSEncoding`](https://godoc.org/github.com/guregu/dynamo/v2#AWSEncoding). Here is a quick example: ```go // Notice the use of the dynamodbav struct tag @@ -224,12 +230,23 @@ type book struct { err := db.Table("Books").Put(dynamo.AWSEncoding(book{ ID: 42, Title: "Principia Discordia", -})).Run() +})).Run(ctx) // When getting an item you MUST pass a pointer to AWSEncoding! var someBook book -err := db.Table("Books").Get("ID", 555).One(dynamo.AWSEncoding(&someBook)) +err := db.Table("Books").Get("ID", 555).One(ctx, dynamo.AWSEncoding(&someBook)) ``` +### Migrating from v1 + +The API hasn't changed much from v1 to v2. Here are some migration tips: + +- All request methods now take a [context](https://go.dev/blog/context) as their first argument. +- Retrying relies on the AWS SDK configuration, see: [Retrying](#retrying). + - Transactions won't retry TransactionCanceled responses by default anymore, make sure you configure that if you need it. +- Arguments that took `int64` (such as in `Query.Limit`) now take `int` instead. +- [Compatibility with the official AWS library](#compatibility-with-the-official-aws-library) uses v2 interfaces instead of v1. +- `KMSMasterKeyArn` renamed to `KMSMasterKeyARN`. + ### Integration tests By default, tests are run in offline mode. In order to run the integration tests, some environment variables need to be set. diff --git a/attr.go b/attr.go index 50de3b6..269af89 100644 --- a/attr.go +++ b/attr.go @@ -4,11 +4,11 @@ import ( "fmt" "strconv" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) // Item is a type alias for the raw DynamoDB item type. -type Item = map[string]*dynamodb.AttributeValue +type Item = map[string]types.AttributeValue type shapeKey byte @@ -31,51 +31,51 @@ const ( shapeInvalid shapeKey = 0 ) -func shapeOf(av *dynamodb.AttributeValue) shapeKey { +func shapeOf(av types.AttributeValue) shapeKey { if av == nil { return shapeInvalid } - switch { - case av.B != nil: + switch av.(type) { + case *types.AttributeValueMemberB: return shapeB - case av.BS != nil: + case *types.AttributeValueMemberBS: return shapeBS - case av.BOOL != nil: + case *types.AttributeValueMemberBOOL: return shapeBOOL - case av.N != nil: + case *types.AttributeValueMemberN: return shapeN - case av.S != nil: + case *types.AttributeValueMemberS: return shapeS - case av.L != nil: + case *types.AttributeValueMemberL: return shapeL - case av.NS != nil: + case *types.AttributeValueMemberNS: return shapeNS - case av.SS != nil: + case *types.AttributeValueMemberSS: return shapeSS - case av.M != nil: + case *types.AttributeValueMemberM: return shapeM - case av.NULL != nil: + case *types.AttributeValueMemberNULL: return shapeNULL } return shapeAny } // av2iface converts an av into interface{}. -func av2iface(av *dynamodb.AttributeValue) (interface{}, error) { - switch { - case av.B != nil: - return av.B, nil - case av.BS != nil: - return av.BS, nil - case av.BOOL != nil: - return *av.BOOL, nil - case av.N != nil: - return strconv.ParseFloat(*av.N, 64) - case av.S != nil: - return *av.S, nil - case av.L != nil: - list := make([]interface{}, 0, len(av.L)) - for _, item := range av.L { +func av2iface(av types.AttributeValue) (interface{}, error) { + switch v := av.(type) { + case *types.AttributeValueMemberB: + return v.Value, nil + case *types.AttributeValueMemberBS: + return v.Value, nil + case *types.AttributeValueMemberBOOL: + return v.Value, nil + case *types.AttributeValueMemberN: + return strconv.ParseFloat(v.Value, 64) + case *types.AttributeValueMemberS: + return v.Value, nil + case *types.AttributeValueMemberL: + list := make([]interface{}, 0, len(v.Value)) + for _, item := range v.Value { iface, err := av2iface(item) if err != nil { return nil, err @@ -83,25 +83,21 @@ func av2iface(av *dynamodb.AttributeValue) (interface{}, error) { list = append(list, iface) } return list, nil - case av.NS != nil: - set := make([]float64, 0, len(av.NS)) - for _, n := range av.NS { - f, err := strconv.ParseFloat(*n, 64) + case *types.AttributeValueMemberNS: + set := make([]float64, 0, len(v.Value)) + for _, n := range v.Value { + f, err := strconv.ParseFloat(n, 64) if err != nil { return nil, err } set = append(set, f) } return set, nil - case av.SS != nil: - set := make([]string, 0, len(av.SS)) - for _, s := range av.SS { - set = append(set, *s) - } - return set, nil - case av.M != nil: - m := make(map[string]interface{}, len(av.M)) - for k, v := range av.M { + case *types.AttributeValueMemberSS: + return v.Value, nil + case *types.AttributeValueMemberM: + m := make(map[string]interface{}, len(v.Value)) + for k, v := range v.Value { iface, err := av2iface(v) if err != nil { return nil, err @@ -109,36 +105,36 @@ func av2iface(av *dynamodb.AttributeValue) (interface{}, error) { m[k] = iface } return m, nil - case av.NULL != nil: + case *types.AttributeValueMemberNULL: return nil, nil } - return nil, fmt.Errorf("dynamo: unsupported AV: %#v", *av) + return nil, fmt.Errorf("dynamo: unsupported AV: %#v", av) } -func avTypeName(av *dynamodb.AttributeValue) string { +func avTypeName(av types.AttributeValue) string { if av == nil { return "" } - switch { - case av.B != nil: + switch av.(type) { + case *types.AttributeValueMemberB: return "binary" - case av.BS != nil: + case *types.AttributeValueMemberBS: return "binary set" - case av.BOOL != nil: + case *types.AttributeValueMemberBOOL: return "boolean" - case av.N != nil: + case *types.AttributeValueMemberN: return "number" - case av.S != nil: + case *types.AttributeValueMemberS: return "string" - case av.L != nil: + case *types.AttributeValueMemberL: return "list" - case av.NS != nil: + case *types.AttributeValueMemberNS: return "number set" - case av.SS != nil: + case *types.AttributeValueMemberSS: return "string set" - case av.M != nil: + case *types.AttributeValueMemberM: return "map" - case av.NULL != nil: + case *types.AttributeValueMemberNULL: return "null" } return "" diff --git a/batch_test.go b/batch_test.go index aec7c38..c6ad405 100644 --- a/batch_test.go +++ b/batch_test.go @@ -1,6 +1,7 @@ package dynamo import ( + "context" "testing" "time" ) @@ -15,6 +16,7 @@ func TestBatchGetWrite(t *testing.T) { table2 := testDB.Table(testTableSprockets) tables := []Table{table1, table2} totalBatchSize := batchSize * len(tables) + ctx := context.TODO() items := make([]interface{}, batchSize) widgets := make(map[int]widget) @@ -39,7 +41,7 @@ func TestBatchGetWrite(t *testing.T) { batch1 := batches[0] batch1.Merge(batches[1:]...) var wcc ConsumedCapacity - wrote, err := batch1.ConsumedCapacity(&wcc).Run() + wrote, err := batch1.ConsumedCapacity(&wcc).Run(ctx) if wrote != totalBatchSize { t.Error("unexpected wrote:", wrote, "≠", totalBatchSize) } @@ -65,7 +67,7 @@ func TestBatchGetWrite(t *testing.T) { get1.Merge(gets[1:]...) var results []widget - err = get1.All(&results) + err = get1.All(ctx, &results) if err != nil { t.Error("unexpected error:", err) } @@ -92,7 +94,7 @@ func TestBatchGetWrite(t *testing.T) { wrote, err = table1.Batch("UserID", "Time").Write(). Delete(keys...). DeleteInRange(table2, "UserID", "Time", keys...). - Run() + Run(ctx) if wrote != totalBatchSize { t.Error("unexpected wrote:", wrote, "≠", totalBatchSize) } @@ -107,7 +109,7 @@ func TestBatchGetWrite(t *testing.T) { Get(keys...). FromRange(table2, "UserID", "Time", keys...). Consistent(true). - All(&results) + All(ctx, &results) if err != ErrNotFound { t.Error("expected ErrNotFound, got", err) } @@ -122,15 +124,16 @@ func TestBatchGetEmptySets(t *testing.T) { t.Skip(offlineSkipMsg) } table := testDB.Table(testTableWidgets) + ctx := context.TODO() now := time.Now().UnixNano() / 1000000000 id := int(now) entry := widget{UserID: id, Time: time.Now()} - if err := table.Put(entry).Run(); err != nil { + if err := table.Put(entry).Run(ctx); err != nil { panic(err) } entry2 := widget{UserID: id + batchSize*2, Time: entry.Time} - if err := table.Put(entry2).Run(); err != nil { + if err := table.Put(entry2).Run(ctx); err != nil { panic(err) } @@ -140,7 +143,7 @@ func TestBatchGetEmptySets(t *testing.T) { } results := []widget{} - err := table.Batch("UserID", "Time").Get(keysToCheck...).Consistent(true).All(&results) + err := table.Batch("UserID", "Time").Get(keysToCheck...).Consistent(true).All(ctx, &results) if err != nil { t.Error(err) } @@ -148,12 +151,12 @@ func TestBatchGetEmptySets(t *testing.T) { t.Error("batch get empty set, unexpected length:", len(results), "want:", 2) } - if err := table.Delete("UserID", entry.UserID).Range("Time", entry.Time).Run(); err != nil { + if err := table.Delete("UserID", entry.UserID).Range("Time", entry.Time).Run(ctx); err != nil { panic(err) } results = []widget{} - err = table.Batch("UserID", "Time").Get(keysToCheck...).Consistent(true).All(&results) + err = table.Batch("UserID", "Time").Get(keysToCheck...).Consistent(true).All(ctx, &results) if err != nil { t.Error(err) } @@ -162,7 +165,7 @@ func TestBatchGetEmptySets(t *testing.T) { } results = []widget{} - err = table.Batch("UserID", "Time").Get(keysToCheck[:len(keysToCheck)-1]...).Consistent(true).All(&results) + err = table.Batch("UserID", "Time").Get(keysToCheck[:len(keysToCheck)-1]...).Consistent(true).All(ctx, &results) if err != ErrNotFound { t.Error(err) } @@ -173,13 +176,14 @@ func TestBatchGetEmptySets(t *testing.T) { func TestBatchEmptyInput(t *testing.T) { table := testDB.Table(testTableWidgets) + ctx := context.TODO() var out []any - err := table.Batch("UserID", "Time").Get().All(&out) + err := table.Batch("UserID", "Time").Get().All(ctx, &out) if err != ErrNoInput { t.Error("unexpected error", err) } - _, err = table.Batch("UserID", "Time").Write().Run() + _, err = table.Batch("UserID", "Time").Write().Run(ctx) if err != ErrNoInput { t.Error("unexpected error", err) } diff --git a/batchget.go b/batchget.go index ef8e55f..09a035c 100644 --- a/batchget.go +++ b/batchget.go @@ -6,8 +6,9 @@ import ( "fmt" "slices" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" + "github.com/aws/smithy-go/time" "github.com/cenkalti/backoff/v4" ) @@ -178,17 +179,9 @@ func (bg *BatchGet) ConsumedCapacity(cc *ConsumedCapacity) *BatchGet { } // All executes this request and unmarshals all results to out, which must be a pointer to a slice. -func (bg *BatchGet) All(out interface{}) error { +func (bg *BatchGet) All(ctx context.Context, out interface{}) error { iter := newBGIter(bg, unmarshalAppendTo(out), nil, bg.err) - for iter.Next(out) { - } - return iter.Err() -} - -// AllWithContext executes this request and unmarshals all results to out, which must be a pointer to a slice. -func (bg *BatchGet) AllWithContext(ctx context.Context, out interface{}) error { - iter := newBGIter(bg, unmarshalAppendTo(out), nil, bg.err) - for iter.NextWithContext(ctx, out) { + for iter.Next(ctx, out) { } return iter.Err() } @@ -214,7 +207,7 @@ func (bg *BatchGet) input(start int) *dynamodb.BatchGetItemInput { } in := &dynamodb.BatchGetItemInput{ - RequestItems: make(map[string]*dynamodb.KeysAndAttributes), + RequestItems: make(map[string]types.KeysAndAttributes), } for _, get := range bg.reqs[start:end] { @@ -224,13 +217,13 @@ func (bg *BatchGet) input(start int) *dynamodb.BatchGetItemInput { } } if bg.cc != nil { - in.ReturnConsumedCapacity = aws.String(dynamodb.ReturnConsumedCapacityIndexes) + in.ReturnConsumedCapacity = types.ReturnConsumedCapacityIndexes } for _, get := range bg.reqs[start:end] { table := get.table.Name() - kas := in.RequestItems[table] - if kas == nil { + kas, ok := in.RequestItems[table] + if !ok { kas = get.keysAndAttribs() if bg.consistent { kas.ConsistentRead = &bg.consistent @@ -287,13 +280,7 @@ func newBGIter(bg *BatchGet, fn unmarshalFunc, track *string, err error) *bgIter // Next tries to unmarshal the next result into out. // Returns false when it is complete or if it runs into an error. -func (itr *bgIter) Next(out interface{}) bool { - ctx, cancel := defaultContext() - defer cancel() - return itr.NextWithContext(ctx, out) -} - -func (itr *bgIter) NextWithContext(ctx context.Context, out interface{}) bool { +func (itr *bgIter) Next(ctx context.Context, out interface{}) bool { // stop if we have an error if ctx.Err() != nil { itr.err = ctx.Err() @@ -341,7 +328,7 @@ redo: // no, prepare a new request with the remaining keys itr.input.RequestItems = itr.output.UnprocessedKeys // we need to sleep here a bit as per the official docs - if err := aws.SleepWithContext(ctx, itr.backoff.NextBackOff()); err != nil { + if err := time.SleepWithContext(ctx, itr.backoff.NextBackOff()); err != nil { // timed out itr.err = err return false @@ -352,7 +339,7 @@ redo: itr.err = itr.bg.batch.table.db.retry(ctx, func() error { var err error - itr.output, err = itr.bg.batch.table.db.client.BatchGetItemWithContext(ctx, itr.input) + itr.output, err = itr.bg.batch.table.db.client.BatchGetItem(ctx, itr.input) return err }) if itr.err != nil { @@ -360,7 +347,7 @@ redo: } if itr.bg.cc != nil { for _, cc := range itr.output.ConsumedCapacity { - addConsumedCapacity(itr.bg.cc, cc) + addConsumedCapacity(itr.bg.cc, &cc) } } diff --git a/batchwrite.go b/batchwrite.go index 41b14fe..adf5021 100644 --- a/batchwrite.go +++ b/batchwrite.go @@ -4,8 +4,9 @@ import ( "context" "math" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" + "github.com/aws/smithy-go/time" "github.com/cenkalti/backoff/v4" ) @@ -22,7 +23,7 @@ type BatchWrite struct { type batchWrite struct { table string - op *dynamodb.WriteRequest + op types.WriteRequest } // Write creates a new batch write request, to which @@ -48,7 +49,7 @@ func (bw *BatchWrite) PutIn(table Table, items ...interface{}) *BatchWrite { bw.setError(err) bw.ops = append(bw.ops, batchWrite{ table: name, - op: &dynamodb.WriteRequest{PutRequest: &dynamodb.PutRequest{ + op: types.WriteRequest{PutRequest: &types.PutRequest{ Item: encoded, }}, }) @@ -87,7 +88,7 @@ func (bw *BatchWrite) deleteIn(table Table, hashKey, rangeKey string, keys ...Ke } bw.ops = append(bw.ops, batchWrite{ table: name, - op: &dynamodb.WriteRequest{DeleteRequest: &dynamodb.DeleteRequest{ + op: types.WriteRequest{DeleteRequest: &types.DeleteRequest{ Key: del.key(), }}, }) @@ -113,17 +114,7 @@ func (bw *BatchWrite) ConsumedCapacity(cc *ConsumedCapacity) *BatchWrite { // For batches with more than 25 operations, an error could indicate that // some records have been written and some have not. Consult the wrote // return amount to figure out which operations have succeeded. -func (bw *BatchWrite) Run() (wrote int, err error) { - ctx, cancel := defaultContext() - defer cancel() - return bw.RunWithContext(ctx) -} - -// RunWithContext executes this batch. -// For batches with more than 25 operations, an error could indicate that -// some records have been written and some have not. Consult the wrote -// return amount to figure out which operations have succeeded. -func (bw *BatchWrite) RunWithContext(ctx context.Context) (wrote int, err error) { +func (bw *BatchWrite) Run(ctx context.Context) (wrote int, err error) { if bw.err != nil { return 0, bw.err } @@ -147,7 +138,7 @@ func (bw *BatchWrite) RunWithContext(ctx context.Context) (wrote int, err error) req := bw.input(ops) err := bw.batch.table.db.retry(ctx, func() error { var err error - res, err = bw.batch.table.db.client.BatchWriteItemWithContext(ctx, req) + res, err = bw.batch.table.db.client.BatchWriteItem(ctx, req) return err }) if err != nil { @@ -155,7 +146,7 @@ func (bw *BatchWrite) RunWithContext(ctx context.Context) (wrote int, err error) } if bw.cc != nil { for _, cc := range res.ConsumedCapacity { - addConsumedCapacity(bw.cc, cc) + addConsumedCapacity(bw.cc, &cc) } } @@ -176,7 +167,7 @@ func (bw *BatchWrite) RunWithContext(ctx context.Context) (wrote int, err error) } // need to sleep when re-requesting, per spec - if err := aws.SleepWithContext(ctx, boff.NextBackOff()); err != nil { + if err := time.SleepWithContext(ctx, boff.NextBackOff()); err != nil { // timed out return wrote, err } @@ -187,7 +178,7 @@ func (bw *BatchWrite) RunWithContext(ctx context.Context) (wrote int, err error) } func (bw *BatchWrite) input(ops []batchWrite) *dynamodb.BatchWriteItemInput { - items := make(map[string][]*dynamodb.WriteRequest) + items := make(map[string][]types.WriteRequest) for _, op := range ops { items[op.table] = append(items[op.table], op.op) } @@ -195,7 +186,7 @@ func (bw *BatchWrite) input(ops []batchWrite) *dynamodb.BatchWriteItemInput { RequestItems: items, } if bw.cc != nil { - input.ReturnConsumedCapacity = aws.String(dynamodb.ReturnConsumedCapacityIndexes) + input.ReturnConsumedCapacity = types.ReturnConsumedCapacityIndexes } return input } diff --git a/bench_test.go b/bench_test.go index 577633b..9ae5fa7 100644 --- a/bench_test.go +++ b/bench_test.go @@ -6,8 +6,7 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) var ( @@ -146,8 +145,8 @@ func BenchmarkUnmarshalText(b *testing.B) { // x := newRecipe(rv) for i := 0; i < b.N; i++ { r, _ := typedefOf(rv.Type()) - if err := r.decodeItem(map[string]*dynamodb.AttributeValue{ - "Foo": {S: aws.String("true")}, + if err := r.decodeItem(map[string]types.AttributeValue{ + "Foo": &types.AttributeValueMemberS{Value: "true"}, }, rv); err != nil { b.Fatal(err) } @@ -162,7 +161,7 @@ func BenchmarkUnmarshalAppend(b *testing.B) { items := make([]Item, 10_000) for i := range items { items[i] = Item{ - "Hello": &dynamodb.AttributeValue{S: aws.String("world")}, + "Hello": &types.AttributeValueMemberS{Value: "world"}, } } b.ResetTimer() @@ -187,7 +186,7 @@ func BenchmarkUnmarshalAppend2(b *testing.B) { items := make([]Item, 10_000) for i := range items { items[i] = Item{ - "Hello": &dynamodb.AttributeValue{S: aws.String("world")}, + "Hello": &types.AttributeValueMemberS{Value: "world"}, } } b.ResetTimer() diff --git a/conditioncheck.go b/conditioncheck.go index fedc639..3c59fd1 100644 --- a/conditioncheck.go +++ b/conditioncheck.go @@ -3,8 +3,8 @@ package dynamo import ( "fmt" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) // ConditionCheck represents a condition for a write transaction to succeed. @@ -12,9 +12,9 @@ import ( type ConditionCheck struct { table Table hashKey string - hashValue *dynamodb.AttributeValue + hashValue types.AttributeValue rangeKey string - rangeValue *dynamodb.AttributeValue + rangeValue types.AttributeValue condition string subber @@ -74,11 +74,11 @@ func (check *ConditionCheck) IfNotExists() *ConditionCheck { return check.If("attribute_not_exists($)", check.hashKey) } -func (check *ConditionCheck) writeTxItem() (*dynamodb.TransactWriteItem, error) { +func (check *ConditionCheck) writeTxItem() (*types.TransactWriteItem, error) { if check.err != nil { return nil, check.err } - item := &dynamodb.ConditionCheck{ + item := &types.ConditionCheck{ TableName: aws.String(check.table.name), Key: check.keys(), ExpressionAttributeNames: check.nameExpr, @@ -87,13 +87,13 @@ func (check *ConditionCheck) writeTxItem() (*dynamodb.TransactWriteItem, error) if check.condition != "" { item.ConditionExpression = aws.String(check.condition) } - return &dynamodb.TransactWriteItem{ + return &types.TransactWriteItem{ ConditionCheck: item, }, nil } -func (check *ConditionCheck) keys() map[string]*dynamodb.AttributeValue { - keys := map[string]*dynamodb.AttributeValue{check.hashKey: check.hashValue} +func (check *ConditionCheck) keys() Item { + keys := Item{check.hashKey: check.hashValue} if check.rangeKey != "" { keys[check.rangeKey] = check.rangeValue } diff --git a/createtable.go b/createtable.go index 2176c80..168798d 100644 --- a/createtable.go +++ b/createtable.go @@ -8,9 +8,10 @@ import ( "strconv" "strings" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) // StreamView determines what information is written to a table's stream. @@ -18,13 +19,13 @@ type StreamView string var ( // Only the key attributes of the modified item are written to the stream. - KeysOnlyView StreamView = dynamodb.StreamViewTypeKeysOnly + KeysOnlyView = StreamView(types.StreamViewTypeKeysOnly) // The entire item, as it appears after it was modified, is written to the stream. - NewImageView StreamView = dynamodb.StreamViewTypeNewImage + NewImageView = StreamView(types.StreamViewTypeNewImage) // The entire item, as it appeared before it was modified, is written to the stream. - OldImageView StreamView = dynamodb.StreamViewTypeOldImage + OldImageView = StreamView(types.StreamViewTypeOldImage) // Both the new and the old item images of the item are written to the stream. - NewAndOldImagesView StreamView = dynamodb.StreamViewTypeNewAndOldImages + NewAndOldImagesView = StreamView(types.StreamViewTypeNewAndOldImages) ) // IndexProjection determines which attributes are mirrored into indices. @@ -32,11 +33,11 @@ type IndexProjection string var ( // Only the key attributes of the modified item are written to the stream. - KeysOnlyProjection IndexProjection = dynamodb.ProjectionTypeKeysOnly + KeysOnlyProjection = IndexProjection(types.ProjectionTypeKeysOnly) // All of the table attributes are projected into the index. - AllProjection IndexProjection = dynamodb.ProjectionTypeAll + AllProjection = IndexProjection(types.ProjectionTypeAll) // Only the specified table attributes are projected into the index. - IncludeProjection IndexProjection = dynamodb.ProjectionTypeInclude + IncludeProjection = IndexProjection(types.ProjectionTypeInclude) ) // CreateTable is a request to create a new table. @@ -44,16 +45,16 @@ var ( type CreateTable struct { db *DB tableName string - attribs []*dynamodb.AttributeDefinition - schema []*dynamodb.KeySchemaElement - globalIndices map[string]dynamodb.GlobalSecondaryIndex - localIndices map[string]dynamodb.LocalSecondaryIndex + attribs []types.AttributeDefinition + schema []types.KeySchemaElement + globalIndices map[string]types.GlobalSecondaryIndex + localIndices map[string]types.LocalSecondaryIndex readUnits int64 writeUnits int64 streamView StreamView ondemand bool - tags []*dynamodb.Tag - encryptionSpecification *dynamodb.SSESpecification + tags []types.Tag + encryptionSpecification *types.SSESpecification err error } @@ -77,12 +78,12 @@ func (db *DB) CreateTable(name string, from interface{}) *CreateTable { ct := &CreateTable{ db: db, tableName: name, - schema: []*dynamodb.KeySchemaElement{}, - globalIndices: make(map[string]dynamodb.GlobalSecondaryIndex), - localIndices: make(map[string]dynamodb.LocalSecondaryIndex), + schema: []types.KeySchemaElement{}, + globalIndices: make(map[string]types.GlobalSecondaryIndex), + localIndices: make(map[string]types.LocalSecondaryIndex), readUnits: 1, writeUnits: 1, - tags: []*dynamodb.Tag{}, + tags: []types.Tag{}, } rv := reflect.ValueOf(from) ct.setError(ct.from(rv)) @@ -107,7 +108,7 @@ func (ct *CreateTable) Provision(readUnits, writeUnits int64) *CreateTable { // global secondary index. Local secondary indices share their capacity with the table. func (ct *CreateTable) ProvisionIndex(index string, readUnits, writeUnits int64) *CreateTable { idx := ct.globalIndices[index] - idx.ProvisionedThroughput = &dynamodb.ProvisionedThroughput{ + idx.ProvisionedThroughput = &types.ProvisionedThroughput{ ReadCapacityUnits: &readUnits, WriteCapacityUnits: &writeUnits, } @@ -125,20 +126,20 @@ func (ct *CreateTable) Stream(view StreamView) *CreateTable { // Project specifies the projection type for the given table. // When using IncludeProjection, you must specify the additional attributes to include via includeAttribs. func (ct *CreateTable) Project(index string, projection IndexProjection, includeAttribs ...string) *CreateTable { - projectionStr := string(projection) - proj := &dynamodb.Projection{ - ProjectionType: &projectionStr, + projectionStr := types.ProjectionType(projection) + proj := &types.Projection{ + ProjectionType: projectionStr, } if projection == IncludeProjection { attribs: for _, attr := range includeAttribs { attr := attr for _, a := range proj.NonKeyAttributes { - if attr == *a { + if attr == a { continue attribs } } - proj.NonKeyAttributes = append(proj.NonKeyAttributes, &attr) + proj.NonKeyAttributes = append(proj.NonKeyAttributes, attr) } } if idx, global := ct.globalIndices[index]; global { @@ -156,27 +157,27 @@ func (ct *CreateTable) Project(index string, projection IndexProjection, include // Index specifies an index to add to this table. func (ct *CreateTable) Index(index Index) *CreateTable { ct.add(index.HashKey, string(index.HashKeyType)) - ks := []*dynamodb.KeySchemaElement{ + ks := []types.KeySchemaElement{ { AttributeName: &index.HashKey, - KeyType: aws.String(dynamodb.KeyTypeHash), + KeyType: types.KeyTypeHash, }, } if index.RangeKey != "" { ct.add(index.RangeKey, string(index.RangeKeyType)) - ks = append(ks, &dynamodb.KeySchemaElement{ + ks = append(ks, types.KeySchemaElement{ AttributeName: &index.RangeKey, - KeyType: aws.String(dynamodb.KeyTypeRange), + KeyType: types.KeyTypeRange, }) } - var proj *dynamodb.Projection + var proj *types.Projection if index.ProjectionType != "" { - proj = &dynamodb.Projection{ - ProjectionType: aws.String((string)(index.ProjectionType)), + proj = &types.Projection{ + ProjectionType: types.ProjectionType(index.ProjectionType), } if index.ProjectionType == IncludeProjection { - proj.NonKeyAttributes = aws.StringSlice(index.ProjectionAttribs) + proj.NonKeyAttributes = index.ProjectionAttribs } } @@ -193,7 +194,7 @@ func (ct *CreateTable) Index(index Index) *CreateTable { idx := ct.globalIndices[index.Name] idx.KeySchema = ks if index.Throughput.Read != 0 || index.Throughput.Write != 0 { - idx.ProvisionedThroughput = &dynamodb.ProvisionedThroughput{ + idx.ProvisionedThroughput = &types.ProvisionedThroughput{ ReadCapacityUnits: &index.Throughput.Read, WriteCapacityUnits: &index.Throughput.Write, } @@ -213,7 +214,7 @@ func (ct *CreateTable) Tag(key, value string) *CreateTable { return ct } } - tag := &dynamodb.Tag{ + tag := types.Tag{ Key: aws.String(key), Value: aws.String(value), } @@ -224,48 +225,34 @@ func (ct *CreateTable) Tag(key, value string) *CreateTable { // SSEEncryption specifies the server side encryption for this table. // Encryption is disabled by default. func (ct *CreateTable) SSEEncryption(enabled bool, keyID string, sseType SSEType) *CreateTable { - encryption := &dynamodb.SSESpecification{ + encryption := types.SSESpecification{ Enabled: aws.Bool(enabled), KMSMasterKeyId: aws.String(keyID), - SSEType: aws.String(string(sseType)), + SSEType: types.SSEType(string(sseType)), } - ct.encryptionSpecification = encryption + ct.encryptionSpecification = &encryption return ct } // Run creates this table or returns an error. -func (ct *CreateTable) Run() error { - ctx, cancel := defaultContext() - defer cancel() - return ct.RunWithContext(ctx) -} - -// RunWithContext creates this table or returns an error. -func (ct *CreateTable) RunWithContext(ctx context.Context) error { +func (ct *CreateTable) Run(ctx context.Context) error { if ct.err != nil { return ct.err } input := ct.input() return ct.db.retry(ctx, func() error { - _, err := ct.db.client.CreateTableWithContext(ctx, input) + _, err := ct.db.client.CreateTable(ctx, input) return err }) } // Wait creates this table and blocks until it exists and is ready to use. -func (ct *CreateTable) Wait() error { - ctx, cancel := defaultContext() - defer cancel() - return ct.WaitWithContext(ctx) -} - -// WaitWithContext creates this table and blocks until it exists and is ready to use. -func (ct *CreateTable) WaitWithContext(ctx context.Context) error { - if err := ct.RunWithContext(ctx); err != nil { +func (ct *CreateTable) Wait(ctx context.Context) error { + if err := ct.Run(ctx); err != nil { return err } - return ct.db.Table(ct.tableName).WaitWithContext(ctx) + return ct.db.Table(ct.tableName).Wait(ctx) } func (ct *CreateTable) from(rv reflect.Value) error { @@ -297,9 +284,9 @@ func (ct *CreateTable) from(rv reflect.Value) error { // primary keys if keyType := keyTypeFromTag(field.Tag.Get("dynamo")); keyType != "" { ct.add(name, typeOf(fv, field.Tag.Get("dynamo"))) - ct.schema = append(ct.schema, &dynamodb.KeySchemaElement{ + ct.schema = append(ct.schema, types.KeySchemaElement{ AttributeName: &name, - KeyType: &keyType, + KeyType: types.KeyType(keyType), }) } @@ -310,9 +297,9 @@ func (ct *CreateTable) from(rv reflect.Value) error { keyType := keyTypeFromTag(index) indexName := index[:len(index)-len(keyType)-1] idx := ct.globalIndices[indexName] - idx.KeySchema = append(idx.KeySchema, &dynamodb.KeySchemaElement{ + idx.KeySchema = append(idx.KeySchema, types.KeySchemaElement{ AttributeName: &name, - KeyType: &keyType, + KeyType: types.KeyType(keyType), }) ct.globalIndices[indexName] = idx } @@ -325,9 +312,9 @@ func (ct *CreateTable) from(rv reflect.Value) error { keyType := keyTypeFromTag(localIndex) indexName := localIndex[:len(localIndex)-len(keyType)-1] idx := ct.localIndices[indexName] - idx.KeySchema = append(idx.KeySchema, &dynamodb.KeySchemaElement{ + idx.KeySchema = append(idx.KeySchema, types.KeySchemaElement{ AttributeName: &name, - KeyType: &keyType, + KeyType: types.KeyType(keyType), }) ct.localIndices[indexName] = idx } @@ -346,9 +333,9 @@ func (ct *CreateTable) input() *dynamodb.CreateTableInput { SSESpecification: ct.encryptionSpecification, } if ct.ondemand { - input.BillingMode = aws.String(dynamodb.BillingModePayPerRequest) + input.BillingMode = types.BillingModePayPerRequest } else { - input.ProvisionedThroughput = &dynamodb.ProvisionedThroughput{ + input.ProvisionedThroughput = &types.ProvisionedThroughput{ ReadCapacityUnits: &ct.readUnits, WriteCapacityUnits: &ct.writeUnits, } @@ -356,9 +343,9 @@ func (ct *CreateTable) input() *dynamodb.CreateTableInput { if ct.streamView != "" { enabled := true view := string(ct.streamView) - input.StreamSpecification = &dynamodb.StreamSpecification{ + input.StreamSpecification = &types.StreamSpecification{ StreamEnabled: &enabled, - StreamViewType: &view, + StreamViewType: types.StreamViewType(view), } } for name, idx := range ct.localIndices { @@ -366,40 +353,40 @@ func (ct *CreateTable) input() *dynamodb.CreateTableInput { idx.IndexName = &name if idx.Projection == nil { all := string(AllProjection) - idx.Projection = &dynamodb.Projection{ - ProjectionType: &all, + idx.Projection = &types.Projection{ + ProjectionType: types.ProjectionType(all), } } // add the primary hash key if len(idx.KeySchema) == 1 { - idx.KeySchema = []*dynamodb.KeySchemaElement{ + idx.KeySchema = []types.KeySchemaElement{ ct.schema[0], idx.KeySchema[0], } } sortKeySchemas(idx.KeySchema) - input.LocalSecondaryIndexes = append(input.LocalSecondaryIndexes, &idx) + input.LocalSecondaryIndexes = append(input.LocalSecondaryIndexes, idx) } for name, idx := range ct.globalIndices { name, idx := name, idx idx.IndexName = &name if idx.Projection == nil { all := string(AllProjection) - idx.Projection = &dynamodb.Projection{ - ProjectionType: &all, + idx.Projection = &types.Projection{ + ProjectionType: types.ProjectionType(all), } } if ct.ondemand { idx.ProvisionedThroughput = nil } else if idx.ProvisionedThroughput == nil { units := int64(1) - idx.ProvisionedThroughput = &dynamodb.ProvisionedThroughput{ + idx.ProvisionedThroughput = &types.ProvisionedThroughput{ ReadCapacityUnits: &units, WriteCapacityUnits: &units, } } sortKeySchemas(idx.KeySchema) - input.GlobalSecondaryIndexes = append(input.GlobalSecondaryIndexes, &idx) + input.GlobalSecondaryIndexes = append(input.GlobalSecondaryIndexes, idx) } if len(ct.tags) > 0 { input.Tags = ct.tags @@ -419,9 +406,9 @@ func (ct *CreateTable) add(name string, typ string) { } } - ct.attribs = append(ct.attribs, &dynamodb.AttributeDefinition{ + ct.attribs = append(ct.attribs, types.AttributeDefinition{ AttributeName: &name, - AttributeType: &typ, + AttributeType: types.ScalarAttributeType(typ), }) } @@ -448,9 +435,9 @@ func typeOf(rv reflect.Value, tag string) string { return typeOf(reflect.ValueOf(iface), tag) } } - case dynamodbattribute.Marshaler: - av := &dynamodb.AttributeValue{} - if err := x.MarshalDynamoDBAttributeValue(av); err == nil { + case attributevalue.Marshaler: + + if av, err := x.MarshalDynamoDBAttributeValue(); err == nil { if iface, err := av2iface(av); err == nil { return typeOf(reflect.ValueOf(iface), tag) } @@ -481,7 +468,7 @@ check: return "" } -func keyTypeFromTag(tag string) string { +func keyTypeFromTag(tag string) types.KeyType { split := strings.Split(tag, ",") if len(split) <= 1 { return "" @@ -489,16 +476,16 @@ func keyTypeFromTag(tag string) string { for _, v := range split[1:] { switch v { case "hash", "partition": - return dynamodb.KeyTypeHash + return types.KeyTypeHash case "range", "sort": - return dynamodb.KeyTypeRange + return types.KeyTypeRange } } return "" } -func sortKeySchemas(schemas []*dynamodb.KeySchemaElement) { - if *schemas[0].KeyType == dynamodb.KeyTypeRange { +func sortKeySchemas(schemas []types.KeySchemaElement) { + if schemas[0].KeyType == types.KeyTypeRange { schemas[0], schemas[1] = schemas[1], schemas[0] } } diff --git a/createtable_test.go b/createtable_test.go index e5e8ad9..23d467c 100644 --- a/createtable_test.go +++ b/createtable_test.go @@ -5,9 +5,10 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) type UserAction struct { @@ -24,8 +25,8 @@ type embeddedWithKeys struct { } type Metric struct { - ID uint64 `dynamo:"ID,hash"` - Time dynamodbattribute.UnixTime `dynamo:",range"` + ID uint64 `dynamo:"ID,hash"` + Time attributevalue.UnixTime `dynamo:",range"` Value uint64 } @@ -51,73 +52,73 @@ func TestCreateTable(t *testing.T) { input() expected := &dynamodb.CreateTableInput{ - AttributeDefinitions: []*dynamodb.AttributeDefinition{ + AttributeDefinitions: []types.AttributeDefinition{ { AttributeName: aws.String("ID"), - AttributeType: aws.String("S"), + AttributeType: types.ScalarAttributeTypeS, }, { AttributeName: aws.String("Time"), - AttributeType: aws.String("S"), + AttributeType: types.ScalarAttributeTypeS, }, { AttributeName: aws.String("Seq"), - AttributeType: aws.String("N"), + AttributeType: types.ScalarAttributeTypeN, }, { AttributeName: aws.String("Embedded"), - AttributeType: aws.String("B"), + AttributeType: types.ScalarAttributeTypeB, }, }, - GlobalSecondaryIndexes: []*dynamodb.GlobalSecondaryIndex{{ + GlobalSecondaryIndexes: []types.GlobalSecondaryIndex{{ IndexName: aws.String("Embedded-index"), - KeySchema: []*dynamodb.KeySchemaElement{{ + KeySchema: []types.KeySchemaElement{{ AttributeName: aws.String("Embedded"), - KeyType: aws.String("HASH"), + KeyType: types.KeyTypeHash, }}, - Projection: &dynamodb.Projection{ - ProjectionType: aws.String("ALL"), + Projection: &types.Projection{ + ProjectionType: types.ProjectionTypeAll, }, - ProvisionedThroughput: &dynamodb.ProvisionedThroughput{ + ProvisionedThroughput: &types.ProvisionedThroughput{ ReadCapacityUnits: aws.Int64(1), WriteCapacityUnits: aws.Int64(2), }, }}, - KeySchema: []*dynamodb.KeySchemaElement{{ + KeySchema: []types.KeySchemaElement{{ AttributeName: aws.String("ID"), - KeyType: aws.String("HASH"), + KeyType: types.KeyTypeHash, }, { AttributeName: aws.String("Time"), - KeyType: aws.String("RANGE"), + KeyType: types.KeyTypeRange, }}, - LocalSecondaryIndexes: []*dynamodb.LocalSecondaryIndex{{ + LocalSecondaryIndexes: []types.LocalSecondaryIndex{{ IndexName: aws.String("ID-Seq-index"), - KeySchema: []*dynamodb.KeySchemaElement{{ + KeySchema: []types.KeySchemaElement{{ AttributeName: aws.String("ID"), - KeyType: aws.String("HASH"), + KeyType: types.KeyTypeHash, }, { AttributeName: aws.String("Seq"), - KeyType: aws.String("RANGE"), + KeyType: types.KeyTypeRange, }}, - Projection: &dynamodb.Projection{ - ProjectionType: aws.String("INCLUDE"), - NonKeyAttributes: []*string{aws.String("UUID"), aws.String("Name")}, + Projection: &types.Projection{ + ProjectionType: types.ProjectionTypeInclude, + NonKeyAttributes: []string{"UUID", "Name"}, }, }}, - ProvisionedThroughput: &dynamodb.ProvisionedThroughput{ + ProvisionedThroughput: &types.ProvisionedThroughput{ ReadCapacityUnits: aws.Int64(4), WriteCapacityUnits: aws.Int64(2), }, - Tags: []*dynamodb.Tag{ + Tags: []types.Tag{ { Key: aws.String("Tag-Key"), Value: aws.String("Tag-Value"), }, }, - SSESpecification: &dynamodb.SSESpecification{ + SSESpecification: &types.SSESpecification{ Enabled: aws.Bool(true), KMSMasterKeyId: aws.String("alias/key"), - SSEType: aws.String("KMS"), + SSEType: types.SSEType("KMS"), }, TableName: aws.String("UserActions"), } @@ -135,24 +136,24 @@ func TestCreateTableUintUnixTime(t *testing.T) { OnDemand(true). input() expected := &dynamodb.CreateTableInput{ - AttributeDefinitions: []*dynamodb.AttributeDefinition{ + AttributeDefinitions: []types.AttributeDefinition{ { AttributeName: aws.String("ID"), - AttributeType: aws.String("N"), + AttributeType: types.ScalarAttributeTypeN, }, { AttributeName: aws.String("Time"), - AttributeType: aws.String("N"), + AttributeType: types.ScalarAttributeTypeN, }, }, - KeySchema: []*dynamodb.KeySchemaElement{{ + KeySchema: []types.KeySchemaElement{{ AttributeName: aws.String("ID"), - KeyType: aws.String("HASH"), + KeyType: types.KeyTypeHash, }, { AttributeName: aws.String("Time"), - KeyType: aws.String("RANGE"), + KeyType: types.KeyTypeRange, }}, - BillingMode: aws.String(dynamodb.BillingModePayPerRequest), + BillingMode: types.BillingModePayPerRequest, TableName: aws.String("Metrics"), } if !reflect.DeepEqual(input, expected) { diff --git a/db.go b/db.go index d04b28a..511e297 100644 --- a/db.go +++ b/db.go @@ -5,58 +5,38 @@ import ( "context" "errors" "fmt" + "sync" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/client" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" + "github.com/aws/smithy-go" + + "github.com/guregu/dynamo/v2/dynamodbiface" ) // DB is a DynamoDB client. type DB struct { - client dynamodbiface.DynamoDBAPI - logger aws.Logger - retryer request.Retryer - retryMax int + client dynamodbiface.DynamoDBAPI + // table description cache for LEK inference + descs *sync.Map // table name → Description } // New creates a new client with the given configuration. // If Retryer is configured, retrying responsibility will be delegated to it. // If MaxRetries is configured, the maximum number of retry attempts will be limited to the specified value // (0 for no retrying, -1 for default behavior of unlimited retries). -// MaxRetries is ignored if Retryer is set. -func New(p client.ConfigProvider, cfgs ...*aws.Config) *DB { - cfg := p.ClientConfig(dynamodb.EndpointsID, cfgs...) - return newDB(dynamodb.New(p, cfgs...), cfg.Config) +func New(cfg aws.Config, options ...func(*dynamodb.Options)) *DB { + client := dynamodb.NewFromConfig(cfg, options...) + return NewFromIface(client) } // NewFromIface creates a new client with the given interface. func NewFromIface(client dynamodbiface.DynamoDBAPI) *DB { - if c, ok := client.(*dynamodb.DynamoDB); ok { - return newDB(c, &c.Config) - } - return newDB(client, &aws.Config{}) -} - -func newDB(client dynamodbiface.DynamoDBAPI, cfg *aws.Config) *DB { db := &DB{ - client: client, - logger: cfg.Logger, - retryMax: -1, + client: client, + descs: new(sync.Map), } - - if db.logger == nil { - db.logger = aws.NewDefaultLogger() - } - - if retryer, ok := cfg.Retryer.(request.Retryer); ok { - db.retryer = retryer - } else if cfg.MaxRetries != nil { - db.retryMax = *cfg.MaxRetries - } - return db } @@ -65,31 +45,15 @@ func (db *DB) Client() dynamodbiface.DynamoDBAPI { return db.client } -// TODO: should we expose these, or come up with a better interface? -// They could be useful in conjunction with NewFromIface, but SetRetryer would be misleading; -// dynamo expects it to be called from within the dynamodbapi interface. -// Probably best to create a forward-compatible (v2-friendly) configuration API instead. - -// func (db *DB) SetRetryer(retryer request.Retryer) { -// db.retryer = retryer -// } - -// func (db *DB) SetMaxRetries(max int) *DB { -// db.retryMax = max -// return db -// } - -// func (db *DB) SetLogger(logger aws.Logger) *DB { -// if logger == nil { -// db.logger = noopLogger{} -// return db -// } -// db.logger = logger -// return db -// } - -func (db *DB) log(v ...interface{}) { - db.logger.Log(v...) +func (db *DB) loadDesc(name string) (desc Description, ok bool) { + if descv, exists := db.descs.Load(name); exists { + desc, ok = descv.(Description) + } + return +} + +func (db *DB) storeDesc(desc Description) { + db.descs.Store(desc.Name, desc) } // ListTables is a request to list tables. @@ -104,18 +68,11 @@ func (db *DB) ListTables() *ListTables { } // All returns every table or an error. -func (lt *ListTables) All() ([]string, error) { - ctx, cancel := defaultContext() - defer cancel() - return lt.AllWithContext(ctx) -} - -// AllWithContext returns every table or an error. -func (lt *ListTables) AllWithContext(ctx context.Context) ([]string, error) { +func (lt *ListTables) All(ctx context.Context) ([]string, error) { var tables []string itr := lt.Iter() var name string - for itr.NextWithContext(ctx, &name) { + for itr.Next(ctx, &name) { tables = append(tables, name) } return tables, itr.Err() @@ -134,13 +91,7 @@ func (lt *ListTables) Iter() Iter { return <Iter{lt: lt} } -func (itr *ltIter) Next(out interface{}) bool { - ctx, cancel := defaultContext() - defer cancel() - return itr.NextWithContext(ctx, out) -} - -func (itr *ltIter) NextWithContext(ctx context.Context, out interface{}) bool { +func (itr *ltIter) Next(ctx context.Context, out interface{}) bool { if ctx.Err() != nil { itr.err = ctx.Err() } @@ -155,7 +106,7 @@ func (itr *ltIter) NextWithContext(ctx context.Context, out interface{}) bool { if itr.result != nil { if itr.idx < len(itr.result.TableNames) { - *out.(*string) = *itr.result.TableNames[itr.idx] + *out.(*string) = itr.result.TableNames[itr.idx] itr.idx++ return true } @@ -167,7 +118,7 @@ func (itr *ltIter) NextWithContext(ctx context.Context, out interface{}) bool { } itr.err = itr.lt.db.retry(ctx, func() error { - res, err := itr.lt.db.client.ListTablesWithContext(ctx, itr.input()) + res, err := itr.lt.db.client.ListTables(ctx, itr.input()) if err != nil { return err } @@ -182,7 +133,7 @@ func (itr *ltIter) NextWithContext(ctx context.Context, out interface{}) bool { return false } - *out.(*string) = *itr.result.TableNames[0] + *out.(*string) = itr.result.TableNames[0] itr.idx = 1 return true } @@ -203,10 +154,7 @@ func (itr *ltIter) input() *dynamodb.ListTablesInput { type Iter interface { // Next tries to unmarshal the next result into out. // Returns false when it is complete or if it runs into an error. - Next(out interface{}) bool - // NextWithContext tries to unmarshal the next result into out. - // Returns false when it is complete or if it runs into an error. - NextWithContext(ctx context.Context, out interface{}) bool + Next(ctx context.Context, out interface{}) bool // Err returns the error encountered, if any. // You should check this after Next is finished. Err() error @@ -218,7 +166,7 @@ type PagingIter interface { Iter // LastEvaluatedKey returns a key that can be passed to StartFrom in Query or Scan. // Combined with SearchLimit, it is useful for paginating partial results. - LastEvaluatedKey() PagingKey + LastEvaluatedKey(context.Context) (PagingKey, error) } // PagingIter is an iterator of combined request results from multiple iterators running in parallel. @@ -226,18 +174,18 @@ type ParallelIter interface { Iter // LastEvaluatedKeys returns each parallel segment's last evaluated key in order of segment number. // The slice will be the same size as the number of segments, and the keys can be nil. - LastEvaluatedKeys() []PagingKey + LastEvaluatedKeys(context.Context) ([]PagingKey, error) } // PagingKey is a key used for splitting up partial results. // Get a PagingKey from a PagingIter and pass it to StartFrom in Query or Scan. -type PagingKey map[string]*dynamodb.AttributeValue +type PagingKey Item // IsCondCheckFailed returns true if the given error is a "conditional check failed" error. // This corresponds with a ConditionalCheckFailedException in most APIs, // or a TransactionCanceledException with a ConditionalCheckFailed cancellation reason in transactions. func IsCondCheckFailed(err error) bool { - var txe *dynamodb.TransactionCanceledException + var txe *types.TransactionCanceledException if errors.As(err, &txe) { for _, cr := range txe.CancellationReasons { if cr.Code != nil && *cr.Code == "ConditionalCheckFailed" { @@ -247,8 +195,8 @@ func IsCondCheckFailed(err error) bool { return false } - var ae awserr.Error - if errors.As(err, &ae) && ae.Code() == "ConditionalCheckFailedException" { + var ae smithy.APIError + if errors.As(err, &ae) && ae.ErrorCode() == "ConditionalCheckFailedException" { return true } diff --git a/db_test.go b/db_test.go index e9cd8c2..e793842 100644 --- a/db_test.go +++ b/db_test.go @@ -1,6 +1,7 @@ package dynamo import ( + "context" "errors" "fmt" "log" @@ -10,10 +11,11 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/retry" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/smithy-go" ) var ( @@ -22,7 +24,7 @@ var ( testTableSprockets = "TestDB-Sprockets" ) -var dummyCreds = credentials.NewStaticCredentials("dummy", "dummy", "") +var dummyCreds = credentials.NewStaticCredentialsProvider("dummy", "dummy", "") const offlineSkipMsg = "DYNAMO_TEST_REGION not set" @@ -49,11 +51,28 @@ func TestMain(m *testing.M) { region = &dtr } if region != nil { - testDB = New(session.Must(session.NewSession()), &aws.Config{ - Region: region, - Endpoint: endpoint, - // LogLevel: aws.LogLevel(aws.LogDebugWithHTTPBody), - }) + var resolv aws.EndpointResolverWithOptions + if endpoint != nil { + resolv = aws.EndpointResolverWithOptionsFunc( + func(service, region string, options ...interface{}) (aws.Endpoint, error) { + return aws.Endpoint{URL: *endpoint}, nil + }, + ) + } + // TransactionCanceledException + + cfg, err := config.LoadDefaultConfig( + context.Background(), + config.WithRegion(*region), + config.WithEndpointResolverWithOptions(resolv), + config.WithRetryer(func() aws.Retryer { + return retry.NewStandard(RetryTxConflicts) + }), + ) + if err != nil { + log.Fatal(err) + } + testDB = New(cfg) } timestamp := strconv.FormatInt(time.Now().UnixMilli(), 10) @@ -85,17 +104,17 @@ func TestMain(m *testing.M) { default: shouldCreate = endpoint != nil } - + ctx := context.Background() var created []Table if testDB != nil { for _, name := range []string{testTableWidgets, testTableSprockets} { table := testDB.Table(name) log.Println("Checking test table:", name) - _, err := table.Describe().Run() + _, err := table.Describe().Run(ctx) switch { case isTableNotExistsErr(err) && shouldCreate: log.Println("Creating test table:", name) - if err := testDB.CreateTable(name, widget{}).Run(); err != nil { + if err := testDB.CreateTable(name, widget{}).Run(ctx); err != nil { panic(err) } created = append(created, testDB.Table(name)) @@ -110,15 +129,18 @@ func TestMain(m *testing.M) { for _, table := range created { log.Println("Deleting test table:", table.Name()) - if err := table.DeleteTable().Run(); err != nil { + if err := table.DeleteTable().Run(ctx); err != nil { log.Println("Error deleting test table:", table.Name(), err) } } } func isTableNotExistsErr(err error) bool { - var ae awserr.Error - return errors.As(err, &ae) && ae.Code() == "ResourceNotFoundException" + var aerr smithy.APIError + if errors.As(err, &aerr) { + return aerr.ErrorCode() == "ResourceNotFoundException" + } + return false } func TestListTables(t *testing.T) { @@ -126,7 +148,7 @@ func TestListTables(t *testing.T) { t.Skip(offlineSkipMsg) } - tables, err := testDB.ListTables().All() + tables, err := testDB.ListTables().All(context.TODO()) if err != nil { t.Error(err) return diff --git a/decode.go b/decode.go index 864016a..7586c69 100644 --- a/decode.go +++ b/decode.go @@ -4,32 +4,32 @@ import ( "fmt" "reflect" - "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) // Unmarshaler is the interface implemented by objects that can unmarshal // an AttributeValue into themselves. type Unmarshaler interface { - UnmarshalDynamo(av *dynamodb.AttributeValue) error + UnmarshalDynamo(av types.AttributeValue) error } // ItemUnmarshaler is the interface implemented by objects that can unmarshal // an Item (a map of strings to AttributeValues) into themselves. type ItemUnmarshaler interface { - UnmarshalDynamoItem(item map[string]*dynamodb.AttributeValue) error + UnmarshalDynamoItem(item Item) error } // Unmarshal decodes a DynamoDB item into out, which must be a pointer. -func UnmarshalItem(item map[string]*dynamodb.AttributeValue, out interface{}) error { +func UnmarshalItem(item Item, out interface{}) error { return unmarshalItem(item, out) } // Unmarshal decodes a DynamoDB value into out, which must be a pointer. -func Unmarshal(av *dynamodb.AttributeValue, out interface{}) error { +func Unmarshal(av types.AttributeValue, out interface{}) error { switch out := out.(type) { case awsEncoder: - return dynamodbattribute.Unmarshal(av, out.iface) + return attributevalue.Unmarshal(av, out.iface) } rv := reflect.ValueOf(out) @@ -41,9 +41,9 @@ func Unmarshal(av *dynamodb.AttributeValue, out interface{}) error { } // used in iterators for unmarshaling one item -type unmarshalFunc func(map[string]*dynamodb.AttributeValue, interface{}) error +type unmarshalFunc func(Item, interface{}) error -func unmarshalItem(item map[string]*dynamodb.AttributeValue, out interface{}) error { +func unmarshalItem(item Item, out interface{}) error { rv := reflect.ValueOf(out) plan, err := typedefOf(rv.Type()) if err != nil { @@ -52,7 +52,7 @@ func unmarshalItem(item map[string]*dynamodb.AttributeValue, out interface{}) er return plan.decodeItem(item, rv) } -func unmarshalAppend(item map[string]*dynamodb.AttributeValue, out interface{}) error { +func unmarshalAppend(item Item, out interface{}) error { if awsenc, ok := out.(awsEncoder); ok { return unmarshalAppendAWS(item, awsenc.iface) } @@ -73,9 +73,9 @@ func unmarshalAppend(item map[string]*dynamodb.AttributeValue, out interface{}) return nil } -func unmarshalAppendTo(out interface{}) func(item map[string]*dynamodb.AttributeValue, out interface{}) error { +func unmarshalAppendTo(out interface{}) func(item Item, out interface{}) error { if awsenc, ok := out.(awsEncoder); ok { - return func(item map[string]*dynamodb.AttributeValue, _ any) error { + return func(item Item, _ any) error { return unmarshalAppendAWS(item, awsenc.iface) } } @@ -84,14 +84,14 @@ func unmarshalAppendTo(out interface{}) func(item map[string]*dynamodb.Attribute slicet := ptr.Type().Elem() membert := slicet.Elem() if ptr.Kind() != reflect.Ptr || slicet.Kind() != reflect.Slice { - return func(item map[string]*dynamodb.AttributeValue, _ any) error { + return func(item Item, _ any) error { return fmt.Errorf("dynamo: unmarshal append: result argument must be a slice pointer") } } plan, err := typedefOf(membert) if err != nil { - return func(item map[string]*dynamodb.AttributeValue, _ any) error { + return func(item Item, _ any) error { return err } } @@ -104,7 +104,7 @@ func unmarshalAppendTo(out interface{}) func(item map[string]*dynamodb.Attribute *slice = append(*slice, *member) } */ - return func(item map[string]*dynamodb.AttributeValue, _ any) error { + return func(item map[string]types.AttributeValue, _ any) error { member := reflect.New(membert) // *T of *[]T if err := plan.decodeItem(item, member); err != nil { return err diff --git a/decode_aux_test.go b/decode_aux_test.go index 19dd36e..5e7c570 100644 --- a/decode_aux_test.go +++ b/decode_aux_test.go @@ -5,11 +5,9 @@ import ( "reflect" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" - - "github.com/guregu/dynamo" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" + "github.com/guregu/dynamo/v2" ) type Coffee struct { @@ -21,9 +19,9 @@ func TestEncodingAux(t *testing.T) { // using the "aux" unmarshaling trick. // See: https://github.com/guregu/dynamo/issues/181 - in := map[string]*dynamodb.AttributeValue{ - "ID": {S: aws.String("intenso")}, - "Name": {S: aws.String("Intenso 12")}, + in := dynamo.Item{ + "ID": &types.AttributeValueMemberS{Value: "intenso"}, + "Name": &types.AttributeValueMemberS{Value: "Intenso 12"}, } type coffeeItemDefault struct { @@ -62,7 +60,7 @@ type coffeeItemFlat struct { Name string } -func (c *coffeeItemFlat) UnmarshalDynamoItem(item map[string]*dynamodb.AttributeValue) error { +func (c *coffeeItemFlat) UnmarshalDynamoItem(item dynamo.Item) error { type alias coffeeItemFlat aux := struct { *alias @@ -80,7 +78,7 @@ type coffeeItemInvalid struct { Name string } -func (c *coffeeItemInvalid) UnmarshalDynamoItem(item map[string]*dynamodb.AttributeValue) error { +func (c *coffeeItemInvalid) UnmarshalDynamoItem(item dynamo.Item) error { type alias coffeeItemInvalid aux := struct { *alias @@ -98,7 +96,7 @@ type coffeeItemEmbedded struct { Coffee } -func (c *coffeeItemEmbedded) UnmarshalDynamoItem(item map[string]*dynamodb.AttributeValue) error { +func (c *coffeeItemEmbedded) UnmarshalDynamoItem(item dynamo.Item) error { type alias coffeeItemEmbedded aux := struct { *alias @@ -116,7 +114,7 @@ type coffeeItemEmbeddedPointer struct { *Coffee } -func (c *coffeeItemEmbeddedPointer) UnmarshalDynamoItem(item map[string]*dynamodb.AttributeValue) error { +func (c *coffeeItemEmbeddedPointer) UnmarshalDynamoItem(item dynamo.Item) error { type alias coffeeItemEmbeddedPointer aux := struct { *alias @@ -147,14 +145,14 @@ type coffeeItemSDKEmbeddedPointer struct { *Coffee } -func (c *coffeeItemSDKEmbeddedPointer) UnmarshalDynamoItem(item map[string]*dynamodb.AttributeValue) error { +func (c *coffeeItemSDKEmbeddedPointer) UnmarshalDynamoItem(item dynamo.Item) error { type alias coffeeItemEmbeddedPointer aux := struct { *alias }{ alias: (*alias)(c), } - if err := dynamodbattribute.UnmarshalMap(item, &aux); err != nil { + if err := attributevalue.UnmarshalMap(item, &aux); err != nil { return err } return nil diff --git a/decode_test.go b/decode_test.go index 1311ccd..7e5b7b0 100644 --- a/decode_test.go +++ b/decode_test.go @@ -7,20 +7,19 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) var itemDecodeOnlyTests = []struct { name string - given map[string]*dynamodb.AttributeValue + given Item expect interface{} }{ { // unexported embedded pointers should be ignored name: "embedded unexported pointer", - given: map[string]*dynamodb.AttributeValue{ - "Embedded": {BOOL: aws.Bool(true)}, + given: Item{ + "Embedded": &types.AttributeValueMemberBOOL{Value: true}, }, expect: struct { *embedded @@ -29,8 +28,8 @@ var itemDecodeOnlyTests = []struct { { // unexported fields should be ignored name: "unexported fields", - given: map[string]*dynamodb.AttributeValue{ - "a": {BOOL: aws.Bool(true)}, + given: Item{ + "a": &types.AttributeValueMemberBOOL{Value: true}, }, expect: struct { a bool @@ -39,8 +38,8 @@ var itemDecodeOnlyTests = []struct { { // embedded pointers shouldn't clobber existing fields name: "exported pointer embedded struct clobber", - given: map[string]*dynamodb.AttributeValue{ - "Embedded": {S: aws.String("OK")}, + given: Item{ + "Embedded": &types.AttributeValueMemberS{Value: "OK"}, }, expect: struct { Embedded string @@ -79,11 +78,11 @@ func TestUnmarshalAppend(t *testing.T) { page := "5" limit := "20" null := true - item := map[string]*dynamodb.AttributeValue{ - "UserID": {N: &id}, - "Page": {N: &page}, - "Limit": {N: &limit}, - "Null": {NULL: &null}, + item := Item{ + "UserID": &types.AttributeValueMemberN{Value: id}, + "Page": &types.AttributeValueMemberN{Value: page}, + "Limit": &types.AttributeValueMemberN{Value: limit}, + "Null": &types.AttributeValueMemberNULL{Value: null}, } do := unmarshalAppendTo(&results) @@ -92,7 +91,7 @@ func TestUnmarshalAppend(t *testing.T) { item2 := maps.Clone(item) id := 12345 + i idstr := strconv.Itoa(id) - item2["UserID"] = &dynamodb.AttributeValue{N: &idstr} + item2["UserID"] = &types.AttributeValueMemberN{Value: idstr} err := do(item2, &results) if err != nil { t.Fatal(err) @@ -157,52 +156,6 @@ func TestUnmarshalItem(t *testing.T) { } } -func TestUnmarshalNULL(t *testing.T) { - tru := true - arbitrary := "hello world" - double := new(*int) - item := map[string]*dynamodb.AttributeValue{ - "String": {NULL: &tru}, - "Slice": {NULL: &tru}, - "Array": {NULL: &tru}, - "StringPtr": {NULL: &tru}, - "DoublePtr": {NULL: &tru}, - "Map": {NULL: &tru}, - "Interface": {NULL: &tru}, - } - - type resultType struct { - String string - Slice []string - Array [2]byte - StringPtr *string - DoublePtr **int - Map map[string]int - Interface interface{} - } - - // dirty result, we want this to be reset - result := resultType{ - String: "ABC", - Slice: []string{"A", "B"}, - Array: [2]byte{'A', 'B'}, - StringPtr: &arbitrary, - DoublePtr: double, - Map: map[string]int{ - "A": 1, - }, - Interface: "interface{}", - } - - if err := UnmarshalItem(item, &result); err != nil { - t.Error(err) - } - - if (!reflect.DeepEqual(result, resultType{})) { - t.Error("unmarshal null: bad result:", result, "≠", resultType{}) - } -} - func TestUnmarshalMissing(t *testing.T) { // This test makes sure we're zeroing out fields of structs even if the given data doesn't contain them @@ -232,8 +185,8 @@ func TestUnmarshalMissing(t *testing.T) { }, } - replace := map[string]*dynamodb.AttributeValue{ - "UserID": {N: aws.String("112")}, + replace := Item{ + "UserID": &types.AttributeValueMemberN{Value: "112"}, } if err := UnmarshalItem(replace, &w); err != nil { @@ -244,11 +197,13 @@ func TestUnmarshalMissing(t *testing.T) { t.Error("bad unmarshal missing. want:", want, "got:", w) } - replace2 := map[string]*dynamodb.AttributeValue{ - "UserID": {N: aws.String("113")}, - "Foo": {M: map[string]*dynamodb.AttributeValue{ - "Bar": {N: aws.String("1338")}, - }}, + replace2 := Item{ + "UserID": &types.AttributeValueMemberN{Value: "113"}, + "Foo": &types.AttributeValueMemberM{ + Value: Item{ + "Bar": &types.AttributeValueMemberN{Value: "1338"}, + }, + }, } want = widget2{ @@ -322,12 +277,12 @@ func TestDecode3(t *testing.T) { // t.Fail() } -var exampleItem = map[string]*dynamodb.AttributeValue{ - "UserID": {N: aws.String("555")}, - "Msg": {S: aws.String("fux")}, - "Count": {N: aws.String("1337")}, - "Meta": {M: map[string]*dynamodb.AttributeValue{ - "Foo": {S: aws.String("1336")}, +var exampleItem = map[string]types.AttributeValue{ + "UserID": &types.AttributeValueMemberN{Value: "555"}, + "Msg": &types.AttributeValueMemberS{Value: "fux"}, + "Count": &types.AttributeValueMemberN{Value: "1337"}, + "Meta": &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "Foo": &types.AttributeValueMemberS{Value: "1336"}, }}, } var exampleWant = widget{ diff --git a/decodefunc.go b/decodefunc.go index 140fb63..95dba40 100644 --- a/decodefunc.go +++ b/decodefunc.go @@ -6,12 +6,12 @@ import ( "strconv" "time" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) -type decodeFunc func(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, rv reflect.Value) error +type decodeFunc func(plan *typedef, flags encodeFlags, av types.AttributeValue, rv reflect.Value) error -func decodePtr(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, rv reflect.Value) error { +func decodePtr(plan *typedef, flags encodeFlags, av types.AttributeValue, rv reflect.Value) error { var elem reflect.Value if rv.IsNil() { if rv.CanSet() { @@ -29,7 +29,7 @@ func decodePtr(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, rv return nil } -func decodeNull(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, rv reflect.Value) error { +func decodeNull(plan *typedef, flags encodeFlags, av types.AttributeValue, rv reflect.Value) error { if !rv.IsValid() { return nil } @@ -40,13 +40,13 @@ func decodeNull(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, r return nil } -func decodeString(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, v reflect.Value) error { - v.SetString(*av.S) +func decodeString(plan *typedef, flags encodeFlags, av types.AttributeValue, v reflect.Value) error { + v.SetString(av.(*types.AttributeValueMemberS).Value) return nil } -func decodeInt(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, v reflect.Value) error { - n, err := strconv.ParseInt(*av.N, 10, 64) +func decodeInt(plan *typedef, flags encodeFlags, av types.AttributeValue, v reflect.Value) error { + n, err := strconv.ParseInt(av.(*types.AttributeValueMemberN).Value, 10, 64) if err != nil { return err } @@ -54,8 +54,8 @@ func decodeInt(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, v return nil } -func decodeUint(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, v reflect.Value) error { - n, err := strconv.ParseUint(*av.N, 10, 64) +func decodeUint(plan *typedef, flags encodeFlags, av types.AttributeValue, v reflect.Value) error { + n, err := strconv.ParseUint(av.(*types.AttributeValueMemberN).Value, 10, 64) if err != nil { return err } @@ -63,8 +63,8 @@ func decodeUint(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, v return nil } -func decodeFloat(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, v reflect.Value) error { - f, err := strconv.ParseFloat(*av.N, 64) +func decodeFloat(plan *typedef, flags encodeFlags, av types.AttributeValue, v reflect.Value) error { + f, err := strconv.ParseFloat(av.(*types.AttributeValueMemberN).Value, 64) if err != nil { return err } @@ -72,19 +72,20 @@ func decodeFloat(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, return nil } -func decodeBool(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, v reflect.Value) error { - v.SetBool(*av.BOOL) +func decodeBool(plan *typedef, flags encodeFlags, av types.AttributeValue, v reflect.Value) error { + v.SetBool(av.(*types.AttributeValueMemberBOOL).Value) return nil } -func decodeBytes(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, v reflect.Value) error { - v.SetBytes(av.B) +func decodeBytes(plan *typedef, flags encodeFlags, av types.AttributeValue, v reflect.Value) error { + v.SetBytes(av.(*types.AttributeValueMemberB).Value) return nil } -func decodeSliceL(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, v reflect.Value) error { - reallocSlice(v, len(av.L)) - for i, innerAV := range av.L { +func decodeSliceL(plan *typedef, flags encodeFlags, av types.AttributeValue, v reflect.Value) error { + list := av.(*types.AttributeValueMemberL).Value + reallocSlice(v, len(list)) + for i, innerAV := range list { innerRV := v.Index(i).Addr() if err := plan.decodeAttr(flags, innerAV, innerRV); err != nil { return err @@ -94,7 +95,7 @@ func decodeSliceL(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, return nil } -// func decodeSliceB(plan *decodePlan, flags encodeFlags, av *dynamodb.AttributeValue, v reflect.Value) error { +// func decodeSliceB(plan *decodePlan, flags encodeFlags, av types.AttributeValue, v reflect.Value) error { // reallocSlice(v, len(av.B)) // for i, b := range av.B { // innerB := reflect.ValueOf(b).Convert(v.Type().Elem()) @@ -104,54 +105,59 @@ func decodeSliceL(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, // return nil // } -func decodeSliceBS(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, v reflect.Value) error { - reallocSlice(v, len(av.BS)) - for i, b := range av.BS { +func decodeSliceBS(plan *typedef, flags encodeFlags, av types.AttributeValue, v reflect.Value) error { + set := av.(*types.AttributeValueMemberBS).Value + reallocSlice(v, len(set)) + for i, b := range set { innerRV := v.Index(i).Addr() - if err := plan.decodeAttr(flags, &dynamodb.AttributeValue{B: b}, innerRV); err != nil { + if err := plan.decodeAttr(flags, &types.AttributeValueMemberB{Value: b}, innerRV); err != nil { return err } } return nil } -func decodeSliceSS(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, v reflect.Value) error { - reallocSlice(v, len(av.SS)) - for i, s := range av.SS { +func decodeSliceSS(plan *typedef, flags encodeFlags, av types.AttributeValue, v reflect.Value) error { + set := av.(*types.AttributeValueMemberSS).Value + reallocSlice(v, len(set)) + for i, s := range set { innerRV := v.Index(i).Addr() - if err := plan.decodeAttr(flags, &dynamodb.AttributeValue{S: s}, innerRV); err != nil { + if err := plan.decodeAttr(flags, &types.AttributeValueMemberS{Value: s}, innerRV); err != nil { return err } } return nil } -func decodeSliceNS(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, v reflect.Value) error { - reallocSlice(v, len(av.NS)) - for i, n := range av.NS { +func decodeSliceNS(plan *typedef, flags encodeFlags, av types.AttributeValue, v reflect.Value) error { + set := av.(*types.AttributeValueMemberNS).Value + reallocSlice(v, len(set)) + for i, n := range set { innerRV := v.Index(i).Addr() - if err := plan.decodeAttr(flags, &dynamodb.AttributeValue{N: n}, innerRV); err != nil { + if err := plan.decodeAttr(flags, &types.AttributeValueMemberN{Value: n}, innerRV); err != nil { return err } } return nil } -func decodeArrayB(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, v reflect.Value) error { - if len(av.B) > v.Len() { - return fmt.Errorf("dynamo: cannot marshal %s into %s; too small (dst len: %d, src len: %d)", avTypeName(av), v.Type().String(), v.Len(), len(av.B)) +func decodeArrayB(plan *typedef, flags encodeFlags, av types.AttributeValue, v reflect.Value) error { + bs := av.(*types.AttributeValueMemberB).Value + if len(bs) > v.Len() { + return fmt.Errorf("dynamo: cannot marshal %s into %s; too small (dst len: %d, src len: %d)", avTypeName(av), v.Type().String(), v.Len(), len(bs)) } vt := v.Type() - array := reflect.ValueOf(av.B) + array := reflect.ValueOf(bs) reflect.Copy(v, array.Convert(vt)) return nil } -func decodeArrayL(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, v reflect.Value) error { - if len(av.L) > v.Len() { - return fmt.Errorf("dynamo: cannot marshal %s into %s; too small (dst len: %d, src len: %d)", avTypeName(av), v.Type().String(), v.Len(), len(av.L)) +func decodeArrayL(plan *typedef, flags encodeFlags, av types.AttributeValue, v reflect.Value) error { + list := av.(*types.AttributeValueMemberL).Value + if len(list) > v.Len() { + return fmt.Errorf("dynamo: cannot marshal %s into %s; too small (dst len: %d, src len: %d)", avTypeName(av), v.Type().String(), v.Len(), len(list)) } - for i, innerAV := range av.L { + for i, innerAV := range list { if err := plan.decodeAttr(flags, innerAV, v.Index(i)); err != nil { return err } @@ -159,8 +165,9 @@ func decodeArrayL(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, return nil } -func decodeStruct(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, rv reflect.Value) error { - return visitFields(av.M, rv, nil, func(av *dynamodb.AttributeValue, flags encodeFlags, v reflect.Value) error { +func decodeStruct(plan *typedef, flags encodeFlags, av types.AttributeValue, rv reflect.Value) error { + m := av.(*types.AttributeValueMemberM).Value + return visitFields(m, rv, nil, func(av types.AttributeValue, flags encodeFlags, v reflect.Value) error { if av == nil { if v.CanSet() && !nullish(v) { v.SetZero() @@ -171,7 +178,7 @@ func decodeStruct(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, }) } -func decodeMap(decodeKey func(reflect.Value, string) error) func(plan *typedef, _ encodeFlags, av *dynamodb.AttributeValue, v reflect.Value) error { +func decodeMap(decodeKey func(reflect.Value, string) error) func(plan *typedef, _ encodeFlags, av types.AttributeValue, v reflect.Value) error { /* Something like: @@ -186,10 +193,11 @@ func decodeMap(decodeKey func(reflect.Value, string) error) func(plan *typedef, out[*kp] = *vp } */ - return func(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, rv reflect.Value) error { - reallocMap(rv, len(av.M)) + return func(plan *typedef, flags encodeFlags, av types.AttributeValue, rv reflect.Value) error { + m := av.(*types.AttributeValueMemberM).Value + reallocMap(rv, len(m)) kp := reflect.New(rv.Type().Key()) - for name, v := range av.M { + for name, v := range m { if err := decodeKey(kp, name); err != nil { return fmt.Errorf("error decoding key %q into %v", name, kp.Type().Elem()) } @@ -203,12 +211,13 @@ func decodeMap(decodeKey func(reflect.Value, string) error) func(plan *typedef, } } -func decodeMapSS(decodeKey decodeKeyFunc, truthy reflect.Value) func(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, rv reflect.Value) error { - return func(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, rv reflect.Value) error { - reallocMap(rv, len(av.SS)) +func decodeMapSS(decodeKey decodeKeyFunc, truthy reflect.Value) func(plan *typedef, flags encodeFlags, av types.AttributeValue, rv reflect.Value) error { + return func(plan *typedef, flags encodeFlags, av types.AttributeValue, rv reflect.Value) error { + set := av.(*types.AttributeValueMemberSS).Value + reallocMap(rv, len(set)) kp := reflect.New(rv.Type().Key()) - for _, s := range av.SS { - if err := decodeKey(kp, *s); err != nil { + for _, s := range set { + if err := decodeKey(kp, s); err != nil { return err } rv.SetMapIndex(kp.Elem(), truthy) @@ -217,12 +226,13 @@ func decodeMapSS(decodeKey decodeKeyFunc, truthy reflect.Value) func(plan *typed } } -func decodeMapNS(decodeKey decodeKeyFunc, truthy reflect.Value) func(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, rv reflect.Value) error { - return func(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, rv reflect.Value) error { - reallocMap(rv, len(av.NS)) +func decodeMapNS(decodeKey decodeKeyFunc, truthy reflect.Value) func(plan *typedef, flags encodeFlags, av types.AttributeValue, rv reflect.Value) error { + return func(plan *typedef, flags encodeFlags, av types.AttributeValue, rv reflect.Value) error { + set := av.(*types.AttributeValueMemberNS).Value + reallocMap(rv, len(set)) kv := reflect.New(rv.Type().Key()).Elem() - for _, n := range av.NS { - if err := plan.decodeAttr(flagNone, &dynamodb.AttributeValue{N: n}, kv); err != nil { + for _, n := range set { + if err := plan.decodeAttr(flagNone, &types.AttributeValueMemberN{Value: n}, kv); err != nil { return err } rv.SetMapIndex(kv, truthy) @@ -230,11 +240,12 @@ func decodeMapNS(decodeKey decodeKeyFunc, truthy reflect.Value) func(plan *typed return nil } } -func decodeMapBS(decodeKey decodeKeyFunc, truthy reflect.Value) func(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, rv reflect.Value) error { - return func(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, rv reflect.Value) error { - reallocMap(rv, len(av.BS)) +func decodeMapBS(decodeKey decodeKeyFunc, truthy reflect.Value) func(plan *typedef, flags encodeFlags, av types.AttributeValue, rv reflect.Value) error { + return func(plan *typedef, flags encodeFlags, av types.AttributeValue, rv reflect.Value) error { + set := av.(*types.AttributeValueMemberBS).Value + reallocMap(rv, len(set)) kv := reflect.New(rv.Type().Key()).Elem() - for _, bb := range av.BS { + for _, bb := range set { reflect.Copy(kv, reflect.ValueOf(bb)) rv.SetMapIndex(kv, truthy) } @@ -242,8 +253,8 @@ func decodeMapBS(decodeKey decodeKeyFunc, truthy reflect.Value) func(plan *typed } } -func decode2[T any](fn func(t T, av *dynamodb.AttributeValue) error) func(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, v reflect.Value) error { - return func(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, rv reflect.Value) error { +func decode2[T any](fn func(t T, av types.AttributeValue) error) func(plan *typedef, flags encodeFlags, av types.AttributeValue, v reflect.Value) error { + return func(plan *typedef, flags encodeFlags, av types.AttributeValue, rv reflect.Value) error { if !rv.CanInterface() { return nil } @@ -264,7 +275,7 @@ func decode2[T any](fn func(t T, av *dynamodb.AttributeValue) error) func(plan * } } -func decodeAny(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, v reflect.Value) error { +func decodeAny(plan *typedef, flags encodeFlags, av types.AttributeValue, v reflect.Value) error { iface, err := av2iface(av) if err != nil { return err @@ -277,10 +288,10 @@ func decodeAny(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, v return nil } -func decodeUnixTime(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, rv reflect.Value) error { +func decodeUnixTime(plan *typedef, flags encodeFlags, av types.AttributeValue, rv reflect.Value) error { rv = indirect(rv) - ts, err := strconv.ParseInt(*av.N, 10, 64) + ts, err := strconv.ParseInt(av.(*types.AttributeValueMemberN).Value, 10, 64) if err != nil { return err } diff --git a/delete.go b/delete.go index 7b10837..85ca119 100644 --- a/delete.go +++ b/delete.go @@ -4,8 +4,8 @@ import ( "context" "fmt" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) // Delete is a request to delete an item. @@ -15,10 +15,10 @@ type Delete struct { returnType string hashKey string - hashValue *dynamodb.AttributeValue + hashValue types.AttributeValue rangeKey string - rangeValue *dynamodb.AttributeValue + rangeValue types.AttributeValue subber condition string @@ -78,13 +78,7 @@ func (d *Delete) ConsumedCapacity(cc *ConsumedCapacity) *Delete { } // Run executes this delete request. -func (d *Delete) Run() error { - ctx, cancel := defaultContext() - defer cancel() - return d.RunWithContext(ctx) -} - -func (d *Delete) RunWithContext(ctx context.Context) error { +func (d *Delete) Run(ctx context.Context) error { d.returnType = "NONE" _, err := d.run(ctx) return err @@ -92,13 +86,7 @@ func (d *Delete) RunWithContext(ctx context.Context) error { // OldValue executes this delete request, unmarshaling the previous value to out. // Returns ErrNotFound is there was no previous value. -func (d *Delete) OldValue(out interface{}) error { - ctx, cancel := defaultContext() - defer cancel() - return d.OldValueWithContext(ctx, out) -} - -func (d *Delete) OldValueWithContext(ctx context.Context, out interface{}) error { +func (d *Delete) OldValue(ctx context.Context, out interface{}) error { d.returnType = "ALL_OLD" output, err := d.run(ctx) switch { @@ -119,7 +107,7 @@ func (d *Delete) run(ctx context.Context) (*dynamodb.DeleteItemOutput, error) { var output *dynamodb.DeleteItemOutput err := d.table.db.retry(ctx, func() error { var err error - output, err = d.table.db.client.DeleteItemWithContext(ctx, input) + output, err = d.table.db.client.DeleteItem(ctx, input) return err }) if d.cc != nil { @@ -132,7 +120,7 @@ func (d *Delete) deleteInput() *dynamodb.DeleteItemInput { input := &dynamodb.DeleteItemInput{ TableName: &d.table.name, Key: d.key(), - ReturnValues: &d.returnType, + ReturnValues: types.ReturnValue(d.returnType), ExpressionAttributeNames: d.nameExpr, ExpressionAttributeValues: d.valueExpr, } @@ -140,18 +128,18 @@ func (d *Delete) deleteInput() *dynamodb.DeleteItemInput { input.ConditionExpression = &d.condition } if d.cc != nil { - input.ReturnConsumedCapacity = aws.String(dynamodb.ReturnConsumedCapacityIndexes) + input.ReturnConsumedCapacity = types.ReturnConsumedCapacityIndexes } return input } -func (d *Delete) writeTxItem() (*dynamodb.TransactWriteItem, error) { +func (d *Delete) writeTxItem() (*types.TransactWriteItem, error) { if d.err != nil { return nil, d.err } input := d.deleteInput() - item := &dynamodb.TransactWriteItem{ - Delete: &dynamodb.Delete{ + item := &types.TransactWriteItem{ + Delete: &types.Delete{ TableName: input.TableName, Key: input.Key, ExpressionAttributeNames: input.ExpressionAttributeNames, @@ -162,8 +150,8 @@ func (d *Delete) writeTxItem() (*dynamodb.TransactWriteItem, error) { return item, nil } -func (d *Delete) key() map[string]*dynamodb.AttributeValue { - key := map[string]*dynamodb.AttributeValue{ +func (d *Delete) key() Item { + key := Item{ d.hashKey: d.hashValue, } if d.rangeKey != "" { diff --git a/delete_test.go b/delete_test.go index 9d11dd7..feae4c2 100644 --- a/delete_test.go +++ b/delete_test.go @@ -1,6 +1,7 @@ package dynamo import ( + "context" "reflect" "testing" "time" @@ -10,6 +11,7 @@ func TestDelete(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } + ctx := context.TODO() table := testDB.Table(testTableWidgets) // first, add an item to delete later @@ -21,7 +23,7 @@ func TestDelete(t *testing.T) { "color": "octarine", }, } - err := table.Put(item).Run() + err := table.Put(item).Run(ctx) if err != nil { t.Error("unexpected error:", err) } @@ -31,7 +33,7 @@ func TestDelete(t *testing.T) { Range("Time", item.Time). If("Meta.'color' = ?", "octarine"). If("Msg = ?", "wrong msg"). - Run() + Run(ctx) if !IsCondCheckFailed(err) { t.Error("expected ConditionalCheckFailedException, not", err) } @@ -39,7 +41,7 @@ func TestDelete(t *testing.T) { // delete it var old widget var cc ConsumedCapacity - err = table.Delete("UserID", item.UserID).Range("Time", item.Time).ConsumedCapacity(&cc).OldValue(&old) + err = table.Delete("UserID", item.UserID).Range("Time", item.Time).ConsumedCapacity(&cc).OldValue(ctx, &old) if err != nil { t.Error("unexpected error:", err) } diff --git a/describetable.go b/describetable.go index 0722c58..afbdacf 100644 --- a/describetable.go +++ b/describetable.go @@ -4,8 +4,8 @@ import ( "context" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) // Description contains information about a table. @@ -90,7 +90,7 @@ type Index struct { ProjectionAttribs []string } -func newDescription(table *dynamodb.TableDescription) Description { +func newDescription(table *types.TableDescription) Description { desc := Description{ Name: *table.TableName, } @@ -98,8 +98,8 @@ func newDescription(table *dynamodb.TableDescription) Description { if table.TableArn != nil { desc.ARN = *table.TableArn } - if table.TableStatus != nil { - desc.Status = Status(*table.TableStatus) + if table.TableStatus != "" { + desc.Status = Status(table.TableStatus) } if table.CreationDateTime != nil { desc.Created = *table.CreationDateTime @@ -109,8 +109,8 @@ func newDescription(table *dynamodb.TableDescription) Description { desc.HashKeyType = lookupADType(table.AttributeDefinitions, desc.HashKey) desc.RangeKeyType = lookupADType(table.AttributeDefinitions, desc.RangeKey) - if table.BillingModeSummary != nil && table.BillingModeSummary.BillingMode != nil { - desc.OnDemand = *table.BillingModeSummary.BillingMode == dynamodb.BillingModePayPerRequest + if table.BillingModeSummary != nil && table.BillingModeSummary.BillingMode != "" { + desc.OnDemand = table.BillingModeSummary.BillingMode == types.BillingModePayPerRequest } if table.ProvisionedThroughput != nil { @@ -126,20 +126,14 @@ func newDescription(table *dynamodb.TableDescription) Description { for _, index := range table.GlobalSecondaryIndexes { idx := Index{ + Name: *index.IndexName, + ARN: *index.IndexArn, + Status: Status(index.IndexStatus), Throughput: newThroughput(index.ProvisionedThroughput), } - if index.IndexName != nil { - idx.Name = *index.IndexName - } - if index.IndexArn != nil { - idx.ARN = *index.IndexArn - } - if index.IndexStatus != nil { - idx.Status = Status(*index.IndexStatus) - } - if index.Projection != nil && index.Projection.ProjectionType != nil { - idx.ProjectionType = IndexProjection(*index.Projection.ProjectionType) - idx.ProjectionAttribs = aws.StringValueSlice(index.Projection.NonKeyAttributes) + if index.Projection != nil && index.Projection.ProjectionType != "" { + idx.ProjectionType = IndexProjection(index.Projection.ProjectionType) + idx.ProjectionAttribs = index.Projection.NonKeyAttributes } if index.Backfilling != nil { idx.Backfilling = *index.Backfilling @@ -167,9 +161,9 @@ func newDescription(table *dynamodb.TableDescription) Description { if index.IndexArn != nil { idx.ARN = *index.IndexArn } - if index.Projection != nil && index.Projection.ProjectionType != nil { - idx.ProjectionType = IndexProjection(*index.Projection.ProjectionType) - idx.ProjectionAttribs = aws.StringValueSlice(index.Projection.NonKeyAttributes) + if index.Projection != nil && index.Projection.ProjectionType != "" { + idx.ProjectionType = IndexProjection(index.Projection.ProjectionType) + idx.ProjectionAttribs = index.Projection.NonKeyAttributes } idx.HashKey, idx.RangeKey = schemaKeys(index.KeySchema) idx.HashKeyType = lookupADType(table.AttributeDefinitions, idx.HashKey) @@ -187,8 +181,8 @@ func newDescription(table *dynamodb.TableDescription) Description { if table.StreamSpecification.StreamEnabled != nil { desc.StreamEnabled = *table.StreamSpecification.StreamEnabled } - if table.StreamSpecification.StreamViewType != nil { - desc.StreamView = StreamView(*table.StreamSpecification.StreamViewType) + if table.StreamSpecification.StreamViewType != "" { + desc.StreamView = StreamView(table.StreamSpecification.StreamViewType) } } if table.LatestStreamArn != nil { @@ -204,13 +198,13 @@ func newDescription(table *dynamodb.TableDescription) Description { sseDesc.InaccessibleEncryptionDateTime = *table.SSEDescription.InaccessibleEncryptionDateTime } if table.SSEDescription.KMSMasterKeyArn != nil { - sseDesc.KMSMasterKeyArn = *table.SSEDescription.KMSMasterKeyArn + sseDesc.KMSMasterKeyARN = *table.SSEDescription.KMSMasterKeyArn } - if table.SSEDescription.SSEType != nil { - sseDesc.SSEType = lookupSSEType(*table.SSEDescription.SSEType) + if table.SSEDescription.SSEType != "" { + sseDesc.SSEType = table.SSEDescription.SSEType } - if table.SSEDescription.Status != nil { - sseDesc.Status = *table.SSEDescription.Status + if table.SSEDescription.Status != "" { + sseDesc.Status = table.SSEDescription.Status } desc.SSEDescription = sseDesc } @@ -260,19 +254,13 @@ func (table Table) Describe() *DescribeTable { } // Run executes this request and describe the table. -func (dt *DescribeTable) Run() (Description, error) { - ctx, cancel := defaultContext() - defer cancel() - return dt.RunWithContext(ctx) -} - -func (dt *DescribeTable) RunWithContext(ctx context.Context) (Description, error) { +func (dt *DescribeTable) Run(ctx context.Context) (Description, error) { input := dt.input() var result *dynamodb.DescribeTableOutput err := dt.table.db.retry(ctx, func() error { var err error - result, err = dt.table.db.client.DescribeTableWithContext(ctx, input) + result, err = dt.table.db.client.DescribeTable(ctx, input) return err }) if err != nil { @@ -280,7 +268,7 @@ func (dt *DescribeTable) RunWithContext(ctx context.Context) (Description, error } desc := newDescription(result.Table) - dt.table.desc.Store(desc) + dt.table.db.storeDesc(desc) return desc, nil } @@ -291,7 +279,7 @@ func (dt *DescribeTable) input() *dynamodb.DescribeTableInput { } } -func newThroughput(td *dynamodb.ProvisionedThroughputDescription) Throughput { +func newThroughput(td *types.ProvisionedThroughputDescription) Throughput { if td == nil { return Throughput{} } @@ -312,25 +300,25 @@ func newThroughput(td *dynamodb.ProvisionedThroughputDescription) Throughput { return thru } -func schemaKeys(schema []*dynamodb.KeySchemaElement) (hashKey, rangeKey string) { +func schemaKeys(schema []types.KeySchemaElement) (hashKey, rangeKey string) { for _, ks := range schema { - switch *ks.KeyType { - case dynamodb.KeyTypeHash: + switch ks.KeyType { + case types.KeyTypeHash: hashKey = *ks.AttributeName - case dynamodb.KeyTypeRange: + case types.KeyTypeRange: rangeKey = *ks.AttributeName } } return } -func lookupADType(ads []*dynamodb.AttributeDefinition, name string) KeyType { +func lookupADType(ads []types.AttributeDefinition, name string) KeyType { if name == "" { return "" } for _, ad := range ads { if *ad.AttributeName == name { - return KeyType(*ad.AttributeType) + return KeyType(ad.AttributeType) } } return "" diff --git a/describetable_test.go b/describetable_test.go index 9e22173..798a94b 100644 --- a/describetable_test.go +++ b/describetable_test.go @@ -1,6 +1,7 @@ package dynamo import ( + "context" "testing" ) @@ -10,7 +11,7 @@ func TestDescribeTable(t *testing.T) { } table := testDB.Table(testTableWidgets) - desc, err := table.Describe().Run() + desc, err := table.Describe().Run(context.TODO()) if err != nil { t.Error(err) return diff --git a/dynamodbiface/interface.go b/dynamodbiface/interface.go new file mode 100644 index 0000000..c214cbb --- /dev/null +++ b/dynamodbiface/interface.go @@ -0,0 +1,83 @@ +// Package dynamodbiface provides an interface to enable mocking the Amazon DynamoDB service client +// for testing your code. +// +// It is important to note that this interface will have breaking changes +// when the service model is updated and adds new API operations, paginators, +// and waiters. +package dynamodbiface + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb" +) + +// DynamoDBAPI provides an interface to enable mocking the +// dynamodb.DynamoDB service client's API operation, +// paginators, and waiters. This make unit testing your code that calls out +// to the SDK's service client's calls easier. +// +// The best way to use this interface is so the SDK's service client's calls +// can be stubbed out for unit testing your code with the SDK without needing +// to inject custom request handlers into the SDK's request pipeline. +// +// // myFunc uses an SDK service client to make a request to +// // Amazon DynamoDB. +// func myFunc(svc dynamodbiface.DynamoDBAPI) bool { +// // Make svc.BatchExecuteStatement request +// } +// +// func main() { +// cfg := config.LoadConfig() +// svc := dynamodb.New(cfg) +// +// myFunc(svc) +// } +// +// In your _test.go file: +// +// // Define a mock struct to be used in your unit tests of myFunc. +// type mockDynamoDBClient struct { +// dynamodbiface.DynamoDBAPI +// } +// func (m *mockDynamoDBClient) BatchExecuteStatement(input *dynamodb.BatchExecuteStatementInput) (*dynamodb.BatchExecuteStatementOutput, error) { +// // mock response/functionality +// } +// +// func TestMyFunc(t *testing.T) { +// // Setup Test +// mockSvc := &mockDynamoDBClient{} +// +// myfunc(mockSvc) +// +// // Verify myFunc's functionality +// } +// +// It is important to note that this interface will have breaking changes +// when the service model is updated and adds new API operations, paginators, +// and waiters. Its suggested to use the pattern above for testing, or using +// tooling to generate mocks to satisfy the interfaces. +type DynamoDBAPI interface { + CreateTable(ctx context.Context, params *dynamodb.CreateTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.CreateTableOutput, error) + ListTables(ctx context.Context, params *dynamodb.ListTablesInput, optFns ...func(*dynamodb.Options)) (*dynamodb.ListTablesOutput, error) + ListGlobalTables(ctx context.Context, params *dynamodb.ListGlobalTablesInput, optFns ...func(*dynamodb.Options)) (*dynamodb.ListGlobalTablesOutput, error) + DescribeTable(ctx context.Context, params *dynamodb.DescribeTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DescribeTableOutput, error) + UpdateTable(ctx context.Context, params *dynamodb.UpdateTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.UpdateTableOutput, error) + + TransactGetItems(ctx context.Context, params *dynamodb.TransactGetItemsInput, optFns ...func(*dynamodb.Options)) (*dynamodb.TransactGetItemsOutput, error) + BatchGetItem(ctx context.Context, params *dynamodb.BatchGetItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.BatchGetItemOutput, error) + BatchWriteItem(ctx context.Context, params *dynamodb.BatchWriteItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.BatchWriteItemOutput, error) + + GetItem(ctx context.Context, params *dynamodb.GetItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.GetItemOutput, error) + DeleteItem(ctx context.Context, params *dynamodb.DeleteItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DeleteItemOutput, error) + PutItem(ctx context.Context, params *dynamodb.PutItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.PutItemOutput, error) + UpdateItem(ctx context.Context, params *dynamodb.UpdateItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.UpdateItemOutput, error) + + UpdateTimeToLive(ctx context.Context, params *dynamodb.UpdateTimeToLiveInput, optFns ...func(*dynamodb.Options)) (*dynamodb.UpdateTimeToLiveOutput, error) + DescribeTimeToLive(ctx context.Context, params *dynamodb.DescribeTimeToLiveInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DescribeTimeToLiveOutput, error) + + Query(ctx context.Context, params *dynamodb.QueryInput, optFns ...func(*dynamodb.Options)) (*dynamodb.QueryOutput, error) + Scan(ctx context.Context, params *dynamodb.ScanInput, optFns ...func(*dynamodb.Options)) (*dynamodb.ScanOutput, error) + DeleteTable(ctx context.Context, params *dynamodb.DeleteTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DeleteTableOutput, error) + TransactWriteItems(ctx context.Context, params *dynamodb.TransactWriteItemsInput, optFns ...func(*dynamodb.Options)) (*dynamodb.TransactWriteItemsOutput, error) +} diff --git a/encode.go b/encode.go index 6e6f6b5..8ace28b 100644 --- a/encode.go +++ b/encode.go @@ -5,27 +5,27 @@ import ( "reflect" "strconv" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) // Marshaler is the interface implemented by objects that can marshal themselves into // an AttributeValue. type Marshaler interface { - MarshalDynamo() (*dynamodb.AttributeValue, error) + MarshalDynamo() (types.AttributeValue, error) } // ItemMarshaler is the interface implemented by objects that can marshal themselves // into an Item (a map of strings to AttributeValues). type ItemMarshaler interface { - MarshalDynamoItem() (map[string]*dynamodb.AttributeValue, error) + MarshalDynamoItem() (Item, error) } // MarshalItem converts the given struct into a DynamoDB item. -func MarshalItem(v interface{}) (map[string]*dynamodb.AttributeValue, error) { +func MarshalItem(v interface{}) (Item, error) { return marshalItem(v) } -func marshalItem(v interface{}) (map[string]*dynamodb.AttributeValue, error) { +func marshalItem(v interface{}) (Item, error) { rv := reflect.ValueOf(v) rt := rv.Type() plan, err := typedefOf(rt) @@ -37,11 +37,11 @@ func marshalItem(v interface{}) (map[string]*dynamodb.AttributeValue, error) { } // Marshal converts the given value into a DynamoDB attribute value. -func Marshal(v interface{}) (*dynamodb.AttributeValue, error) { +func Marshal(v interface{}) (types.AttributeValue, error) { return marshal(v, flagNone) } -func marshal(v interface{}, flags encodeFlags) (*dynamodb.AttributeValue, error) { +func marshal(v interface{}, flags encodeFlags) (types.AttributeValue, error) { rv := reflect.ValueOf(v) if !rv.IsValid() { return nil, nil @@ -64,8 +64,8 @@ func marshal(v interface{}, flags encodeFlags) (*dynamodb.AttributeValue, error) return enc(rv, flags) } -func marshalSliceNoOmit(values []interface{}) ([]*dynamodb.AttributeValue, error) { - avs := make([]*dynamodb.AttributeValue, 0, len(values)) +func marshalSliceNoOmit(values []interface{}) ([]types.AttributeValue, error) { + avs := make([]types.AttributeValue, 0, len(values)) for _, v := range values { av, err := marshal(v, flagAllowEmpty) if err != nil { diff --git a/encode_test.go b/encode_test.go index c96426b..b570f71 100644 --- a/encode_test.go +++ b/encode_test.go @@ -4,14 +4,14 @@ import ( "reflect" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) var itemEncodeOnlyTests = []struct { name string in interface{} - out map[string]*dynamodb.AttributeValue + out Item }{ { name: "omitemptyelem", @@ -26,10 +26,10 @@ var itemEncodeOnlyTests = []struct { M: map[string]string{"test": ""}, Other: true, }, - out: map[string]*dynamodb.AttributeValue{ - "L": {L: []*dynamodb.AttributeValue{}}, - "M": {M: map[string]*dynamodb.AttributeValue{}}, - "Other": {BOOL: aws.Bool(true)}, + out: Item{ + "L": &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, + "M": &types.AttributeValueMemberM{Value: Item{}}, + "Other": &types.AttributeValueMemberBOOL{Value: true}, }, }, { @@ -43,8 +43,8 @@ var itemEncodeOnlyTests = []struct { M: map[string]string{"test": ""}, Other: true, }, - out: map[string]*dynamodb.AttributeValue{ - "Other": {BOOL: aws.Bool(true)}, + out: Item{ + "Other": &types.AttributeValueMemberBOOL{Value: (true)}, }, }, { @@ -62,14 +62,20 @@ var itemEncodeOnlyTests = []struct { }, }, }, - out: map[string]*dynamodb.AttributeValue{ - "M": {M: map[string]*dynamodb.AttributeValue{ - "struct": {M: map[string]*dynamodb.AttributeValue{ - "InnerMap": {M: map[string]*dynamodb.AttributeValue{ - // expected empty inside - }}, - }}, - }}, + out: Item{ + "M": &types.AttributeValueMemberM{ + Value: Item{ + "struct": &types.AttributeValueMemberM{ + Value: Item{ + "InnerMap": &types.AttributeValueMemberM{ + Value: Item{ + // expected empty inside + }, + }, + }, + }, + }, + }, }, }, { @@ -83,8 +89,8 @@ var itemEncodeOnlyTests = []struct { private: 1337, private2: new(int), }, - out: map[string]*dynamodb.AttributeValue{ - "Public": {N: aws.String("555")}, + out: Item{ + "Public": &types.AttributeValueMemberN{Value: ("555")}, }, }, { @@ -95,8 +101,8 @@ var itemEncodeOnlyTests = []struct { }{ ID: "abc", }, - out: map[string]*dynamodb.AttributeValue{ - "ID": {S: aws.String("abc")}, + out: Item{ + "ID": &types.AttributeValueMemberS{Value: "abc"}, }, }, } @@ -155,19 +161,19 @@ type myStruct struct { Value isValue_Kind } -func (ms *myStruct) MarshalDynamoItem() (map[string]*dynamodb.AttributeValue, error) { +func (ms *myStruct) MarshalDynamoItem() (map[string]types.AttributeValue, error) { world := "world" - return map[string]*dynamodb.AttributeValue{ - "hello": {S: &world}, + return map[string]types.AttributeValue{ + "hello": &types.AttributeValueMemberS{Value: world}, }, nil } -func (ms *myStruct) UnmarshalDynamoItem(item map[string]*dynamodb.AttributeValue) error { +func (ms *myStruct) UnmarshalDynamoItem(item map[string]types.AttributeValue) error { hello := item["hello"] - if hello == nil || hello.S == nil || *hello.S != "world" { - ms.OK = false - } else { + if h, ok := hello.(*types.AttributeValueMemberS); ok && h.Value == "world" { ms.OK = true + } else { + ms.OK = false } return nil } @@ -183,8 +189,8 @@ func TestMarshalItemBypass(t *testing.T) { } world := "world" - expect := map[string]*dynamodb.AttributeValue{ - "hello": {S: &world}, + expect := map[string]types.AttributeValue{ + "hello": &types.AttributeValueMemberS{Value: world}, } if !reflect.DeepEqual(got, expect) { t.Error("bad marshal. want:", expect, "got:", got) diff --git a/encodefunc.go b/encodefunc.go index c37acb0..a1a793c 100644 --- a/encodefunc.go +++ b/encodefunc.go @@ -7,12 +7,11 @@ import ( "strconv" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) -type encodeFunc func(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) +type encodeFunc func(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) func (def *typedef) encodeType(rt reflect.Type, flags encodeFlags, info *structInfo) (encodeFunc, error) { encKey := encodeKey{rt, flags} @@ -23,8 +22,79 @@ func (def *typedef) encodeType(rt reflect.Type, flags encodeFlags, info *structI try := rt for { switch try { + case rtypeAttrB: + return encode2(func(av types.AttributeValue, _ encodeFlags) (types.AttributeValue, error) { + if av == nil { + return nil, nil + } + return av, nil + }), nil + case rtypeAttrBS: + return encode2(func(av types.AttributeValue, _ encodeFlags) (types.AttributeValue, error) { + if av == nil { + return nil, nil + } + return av, nil + }), nil + case rtypeAttrBOOL: + return encode2(func(av types.AttributeValue, _ encodeFlags) (types.AttributeValue, error) { + if av == nil { + return nil, nil + } + return av, nil + }), nil + case rtypeAttrN: + return encode2(func(av types.AttributeValue, _ encodeFlags) (types.AttributeValue, error) { + if av == nil { + return nil, nil + } + return av, nil + }), nil + case rtypeAttrS: + return encode2(func(av types.AttributeValue, _ encodeFlags) (types.AttributeValue, error) { + if av == nil { + return nil, nil + } + return av, nil + }), nil + case rtypeAttrL: + return encode2(func(av types.AttributeValue, _ encodeFlags) (types.AttributeValue, error) { + if av == nil { + return nil, nil + } + return av, nil + }), nil + case rtypeAttrNS: + return encode2(func(av types.AttributeValue, _ encodeFlags) (types.AttributeValue, error) { + if av == nil { + return nil, nil + } + return av, nil + }), nil + case rtypeAttrSS: + return encode2(func(av types.AttributeValue, _ encodeFlags) (types.AttributeValue, error) { + if av == nil { + return nil, nil + } + return av, nil + }), nil + case rtypeAttrM: + return encode2(func(av types.AttributeValue, _ encodeFlags) (types.AttributeValue, error) { + if av == nil { + return nil, nil + } + return av, nil + }), nil + case rtypeAttrNULL: + return encode2(func(av types.AttributeValue, _ encodeFlags) (types.AttributeValue, error) { + if av == nil { + return nil, nil + } + return av, nil + }), nil + case rtypeAttr: - return encode2(func(av *dynamodb.AttributeValue, _ encodeFlags) (*dynamodb.AttributeValue, error) { + return encode2(func(av types.AttributeValue, _ encodeFlags) (types.AttributeValue, error) { if av == nil { return nil, nil } @@ -37,14 +107,13 @@ func (def *typedef) encodeType(rt reflect.Type, flags encodeFlags, info *structI } switch { case try.Implements(rtypeMarshaler): - return encode2(func(x Marshaler, _ encodeFlags) (*dynamodb.AttributeValue, error) { + return encode2(func(x Marshaler, _ encodeFlags) (types.AttributeValue, error) { return x.MarshalDynamo() }), nil case try.Implements(rtypeAWSMarshaler): - return encode2(func(x dynamodbattribute.Marshaler, _ encodeFlags) (*dynamodb.AttributeValue, error) { - var av dynamodb.AttributeValue - err := x.MarshalDynamoDBAttributeValue(&av) - return &av, err + return encode2(func(x attributevalue.Marshaler, _ encodeFlags) (types.AttributeValue, error) { + av, err := x.MarshalDynamoDBAttributeValue() + return av, err }), nil case try.Implements(rtypeTextMarshaler): return encodeTextMarshaler, nil @@ -62,8 +131,8 @@ func (def *typedef) encodeType(rt reflect.Type, flags encodeFlags, info *structI // BOOL case reflect.Bool: - return func(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { - return &dynamodb.AttributeValue{BOOL: aws.Bool(rv.Bool())}, nil + return func(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) { + return &types.AttributeValueMemberBOOL{Value: rv.Bool()}, nil }, nil // N @@ -115,7 +184,7 @@ func (def *typedef) encodePtr(rt reflect.Type, flags encodeFlags, info *structIn if err != nil { return nil, err } - return func(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { + return func(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) { if rv.IsNil() { if flags&flagNull != 0 { return nullAV, nil @@ -126,10 +195,10 @@ func (def *typedef) encodePtr(rt reflect.Type, flags encodeFlags, info *structIn }, nil } -func encode2[T any](fn func(T, encodeFlags) (*dynamodb.AttributeValue, error)) func(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { +func encode2[T any](fn func(T, encodeFlags) (types.AttributeValue, error)) func(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) { target := reflect.TypeOf((*T)(nil)).Elem() interfacing := target.Kind() == reflect.Interface - return func(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { + return func(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) { if !rv.IsValid() || !rv.CanInterface() { return nil, nil } @@ -149,7 +218,7 @@ func encode2[T any](fn func(T, encodeFlags) (*dynamodb.AttributeValue, error)) f } } -func encodeString(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { +func encodeString(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) { s := rv.String() if len(s) == 0 { if flags&flagAllowEmpty != 0 { @@ -160,28 +229,28 @@ func encodeString(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue } return nil, nil } - return &dynamodb.AttributeValue{S: &s}, nil + return &types.AttributeValueMemberS{Value: s}, nil } -var encodeTextMarshaler = encode2[encoding.TextMarshaler](func(x encoding.TextMarshaler, flags encodeFlags) (*dynamodb.AttributeValue, error) { +var encodeTextMarshaler = encode2[encoding.TextMarshaler](func(x encoding.TextMarshaler, flags encodeFlags) (types.AttributeValue, error) { text, err := x.MarshalText() switch { case err != nil: return nil, err case len(text) == 0: if flags&flagAllowEmpty != 0 { - return &dynamodb.AttributeValue{S: new(string)}, nil + return emptyS, nil } return nil, nil } str := string(text) - return &dynamodb.AttributeValue{S: &str}, nil + return &types.AttributeValueMemberS{Value: str}, nil }) func encodeBytes(rt reflect.Type, flags encodeFlags) encodeFunc { if rt.Kind() == reflect.Array { size := rt.Len() - return func(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { + return func(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) { if rv.IsZero() { switch { case flags&flagNull != 0: @@ -193,11 +262,11 @@ func encodeBytes(rt reflect.Type, flags encodeFlags) encodeFunc { } data := make([]byte, size) reflect.Copy(reflect.ValueOf(data), rv) - return &dynamodb.AttributeValue{B: data}, nil + return &types.AttributeValueMemberB{Value: data}, nil } } - return func(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { + return func(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) { if rv.IsNil() { if flags&flagNull != 0 { return nullAV, nil @@ -210,7 +279,7 @@ func encodeBytes(rt reflect.Type, flags encodeFlags) encodeFunc { } return nil, nil } - return &dynamodb.AttributeValue{B: rv.Bytes()}, nil + return &types.AttributeValueMemberB{Value: rv.Bytes()}, nil } } @@ -225,12 +294,12 @@ func (def *typedef) encodeStruct(rt reflect.Type, flags encodeFlags, info *struc fields = append(fields, *field) } - return func(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { + return func(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) { item, err := encodeItem(fields, rv) if err != nil { return nil, err } - return &dynamodb.AttributeValue{M: item}, nil + return &types.AttributeValueMemberM{Value: item}, nil }, nil } @@ -264,8 +333,8 @@ func encodeSliceSet(rt /* []T */ reflect.Type, flags encodeFlags) (encodeFunc, e return nil, fmt.Errorf("dynamo: invalid type for set: %v", rt) } -func encodeSliceTMSS(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { - ss := make([]*string, 0, rv.Len()) +func encodeSliceTMSS(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) { + ss := make([]string, 0, rv.Len()) for i := 0; i < rv.Len(); i++ { tm := rv.Index(i).Interface().(encoding.TextMarshaler) text, err := tm.MarshalText() @@ -275,30 +344,30 @@ func encodeSliceTMSS(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeVa if flags&flagOmitEmptyElem != 0 && len(text) == 0 { continue } - ss = append(ss, aws.String(string(text))) + ss = append(ss, string(text)) } if len(ss) == 0 { return nil, nil } - return &dynamodb.AttributeValue{SS: ss}, nil + return &types.AttributeValueMemberSS{Value: ss}, nil } -func encodeSliceSS(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { - ss := make([]*string, 0, rv.Len()) +func encodeSliceSS(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) { + ss := make([]string, 0, rv.Len()) for i := 0; i < rv.Len(); i++ { s := rv.Index(i).String() if flags&flagOmitEmptyElem != 0 && s == "" { continue } - ss = append(ss, aws.String(s)) + ss = append(ss, s) } if len(ss) == 0 { return nil, nil } - return &dynamodb.AttributeValue{SS: ss}, nil + return &types.AttributeValueMemberSS{Value: ss}, nil } -func encodeSliceBS(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { +func encodeSliceBS(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) { bs := make([][]byte, 0, rv.Len()) for i := 0; i < rv.Len(); i++ { b := rv.Index(i).Bytes() @@ -310,7 +379,7 @@ func encodeSliceBS(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValu if len(bs) == 0 { return nil, nil } - return &dynamodb.AttributeValue{BS: bs}, nil + return &types.AttributeValueMemberBS{Value: bs}, nil } func (def *typedef) encodeMapM(rt reflect.Type, flags encodeFlags, info *structInfo) (encodeFunc, error) { @@ -334,10 +403,10 @@ func (def *typedef) encodeMapM(rt reflect.Type, flags encodeFlags, info *structI return nil, err } - return func(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { + return func(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) { if rv.IsNil() { if flags&flagAllowEmpty != 0 { - return &dynamodb.AttributeValue{M: map[string]*dynamodb.AttributeValue{}}, nil + return &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{}}, nil } if flags&flagNull != 0 { return nullAV, nil @@ -345,7 +414,7 @@ func (def *typedef) encodeMapM(rt reflect.Type, flags encodeFlags, info *structI return nil, nil } - avs := make(map[string]*dynamodb.AttributeValue, rv.Len()) + avs := make(map[string]types.AttributeValue, rv.Len()) iter := rv.MapRange() for iter.Next() { @@ -369,7 +438,7 @@ func (def *typedef) encodeMapM(rt reflect.Type, flags encodeFlags, info *structI return nil, nil } - return &dynamodb.AttributeValue{M: avs}, nil + return &types.AttributeValueMemberM{Value: avs}, nil }, nil } @@ -381,9 +450,9 @@ func encodeMapSet(rt /* map[T]bool | map[T]struct{} */ reflect.Type, flags encod } if rt.Key().Implements(rtypeTextMarshaler) { - return func(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { + return func(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) { length := rv.Len() - ss := make([]*string, 0, length) + ss := make([]string, 0, length) iter := rv.MapRange() for iter.Next() { if useBool && !iter.Value().Equal(truthy) { @@ -397,12 +466,12 @@ func encodeMapSet(rt /* map[T]bool | map[T]struct{} */ reflect.Type, flags encod continue } str := string(text) - ss = append(ss, &str) + ss = append(ss, str) } if len(ss) == 0 { return nil, nil } - return &dynamodb.AttributeValue{SS: ss}, nil + return &types.AttributeValueMemberSS{Value: ss}, nil }, nil } @@ -417,8 +486,8 @@ func encodeMapSet(rt /* map[T]bool | map[T]struct{} */ reflect.Type, flags encod // SS case reflect.String: - return func(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { - ss := make([]*string, 0, rv.Len()) + return func(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) { + ss := make([]string, 0, rv.Len()) iter := rv.MapRange() for iter.Next() { if useBool && !iter.Value().Equal(truthy) { @@ -428,19 +497,19 @@ func encodeMapSet(rt /* map[T]bool | map[T]struct{} */ reflect.Type, flags encod if flags&flagOmitEmptyElem != 0 && s == "" { continue } - ss = append(ss, aws.String(s)) + ss = append(ss, s) } if len(ss) == 0 { return nil, nil } - return &dynamodb.AttributeValue{SS: ss}, nil + return &types.AttributeValueMemberSS{Value: ss}, nil }, nil // BS case reflect.Array: if rt.Key().Elem().Kind() == reflect.Uint8 { size := rt.Key().Len() - return func(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { + return func(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) { bs := make([][]byte, 0, rv.Len()) key := make([]byte, size) keyv := reflect.ValueOf(key) @@ -455,7 +524,7 @@ func encodeMapSet(rt /* map[T]bool | map[T]struct{} */ reflect.Type, flags encod if len(bs) == 0 { return nil, nil } - return &dynamodb.AttributeValue{BS: bs}, nil + return &types.AttributeValueMemberBS{Value: bs}, nil }, nil } } @@ -468,34 +537,34 @@ type numberType interface { } func encodeN[T numberType](get func(reflect.Value) T, format func(T, int) string) encodeFunc { - return func(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { + return func(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) { str := format(get(rv), 10) - return &dynamodb.AttributeValue{N: &str}, nil + return &types.AttributeValueMemberN{Value: str}, nil } } func encodeSliceNS[T numberType](get func(reflect.Value) T, format func(T, int) string) encodeFunc { - return func(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { - ns := make([]*string, 0, rv.Len()) + return func(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) { + ns := make([]string, 0, rv.Len()) for i := 0; i < rv.Len(); i++ { n := get(rv.Index(i)) if flags&flagOmitEmptyElem != 0 && n == 0 { continue } str := format(n, 10) - ns = append(ns, &str) + ns = append(ns, str) } if len(ns) == 0 { return nil, nil } - return &dynamodb.AttributeValue{NS: ns}, nil + return &types.AttributeValueMemberNS{Value: ns}, nil } } func encodeMapNS[T numberType](truthy reflect.Value, get func(reflect.Value) T, format func(T, int) string) encodeFunc { useBool := truthy.Kind() == reflect.Bool - return func(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { - ns := make([]*string, 0, rv.Len()) + return func(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) { + ns := make([]string, 0, rv.Len()) iter := rv.MapRange() for iter.Next() { if useBool && !iter.Value().Equal(truthy) { @@ -506,12 +575,12 @@ func encodeMapNS[T numberType](truthy reflect.Value, get func(reflect.Value) T, continue } str := format(n, 10) - ns = append(ns, &str) + ns = append(ns, str) } if len(ns) == 0 { return nil, nil } - return &dynamodb.AttributeValue{NS: ns}, nil + return &types.AttributeValueMemberNS{Value: ns}, nil } } @@ -545,8 +614,8 @@ func (def *typedef) encodeList(rt reflect.Type, flags encodeFlags, info *structI return nil, err } - return func(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { - avs := make([]*dynamodb.AttributeValue, 0, rv.Len()) + return func(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) { + avs := make([]types.AttributeValue, 0, rv.Len()) for i := 0; i < rv.Len(); i++ { innerVal := rv.Index(i) av, err := valueEnc(innerVal, flags|subflags) @@ -566,11 +635,11 @@ func (def *typedef) encodeList(rt reflect.Type, flags encodeFlags, info *structI if flags&flagOmitEmpty != 0 && len(avs) == 0 { return nil, nil } - return &dynamodb.AttributeValue{L: avs}, nil + return &types.AttributeValueMemberL{Value: avs}, nil }, nil } -func (def *typedef) encodeAny(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { +func (def *typedef) encodeAny(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) { if !rv.CanInterface() || rv.IsNil() { if flags&flagNull != 0 { return nullAV, nil @@ -587,20 +656,20 @@ func (def *typedef) encodeAny(rv reflect.Value, flags encodeFlags) (*dynamodb.At func encodeUnixTime(rt reflect.Type) encodeFunc { switch rt { case rtypeTimePtr: - return encode2[*time.Time](func(t *time.Time, flags encodeFlags) (*dynamodb.AttributeValue, error) { + return encode2[*time.Time](func(t *time.Time, flags encodeFlags) (types.AttributeValue, error) { if t == nil || t.IsZero() { return nil, nil } str := strconv.FormatInt(t.Unix(), 10) - return &dynamodb.AttributeValue{N: &str}, nil + return &types.AttributeValueMemberN{Value: str}, nil }) case rtypeTime: - return encode2[time.Time](func(t time.Time, flags encodeFlags) (*dynamodb.AttributeValue, error) { + return encode2[time.Time](func(t time.Time, flags encodeFlags) (types.AttributeValue, error) { if t.IsZero() { return nil, nil } str := strconv.FormatInt(t.Unix(), 10) - return &dynamodb.AttributeValue{N: &str}, nil + return &types.AttributeValueMemberN{Value: str}, nil }) } panic(fmt.Errorf("not time type: %v", rt)) diff --git a/encoding.go b/encoding.go index f619e4a..42e0fce 100644 --- a/encoding.go +++ b/encoding.go @@ -6,9 +6,8 @@ import ( "reflect" "sync" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) var typeCache sync.Map // unmarshalKey → *typedef @@ -110,29 +109,29 @@ func (def *typedef) encodeItem(rv reflect.Value) (Item, error) { if err != nil { return nil, err } - return av.M, err + return av.(*types.AttributeValueMemberM).Value, err } return encodeItem(def.fields, rv) } -func (def *typedef) encodeItemBypass(in any) (item map[string]*dynamodb.AttributeValue, err error) { +func (def *typedef) encodeItemBypass(in any) (item map[string]types.AttributeValue, err error) { switch x := in.(type) { - case map[string]*dynamodb.AttributeValue: + case map[string]types.AttributeValue: item = x - case *map[string]*dynamodb.AttributeValue: + case *map[string]types.AttributeValue: if x == nil { return nil, fmt.Errorf("item to encode is nil") } item = *x case awsEncoder: - item, err = dynamodbattribute.MarshalMap(x.iface) + item, err = attributevalue.MarshalMap(x.iface) case ItemMarshaler: item, err = x.MarshalDynamoItem() } return } -func (def *typedef) decodeItem(item map[string]*dynamodb.AttributeValue, outv reflect.Value) error { +func (def *typedef) decodeItem(item map[string]types.AttributeValue, outv reflect.Value) error { out := outv outv = indirectPtr(outv) if shouldBypassDecodeItem(outv.Type()) { @@ -150,36 +149,36 @@ func (def *typedef) decodeItem(item map[string]*dynamodb.AttributeValue, outv re // debugf("decode item: %v -> %T(%v)", item, out, out) switch outv.Kind() { case reflect.Struct: - return decodeStruct(def, flagNone, &dynamodb.AttributeValue{M: item}, outv) + return decodeStruct(def, flagNone, &types.AttributeValueMemberM{Value: item}, outv) case reflect.Map: - return def.decodeAttr(flagNone, &dynamodb.AttributeValue{M: item}, outv) + return def.decodeAttr(flagNone, &types.AttributeValueMemberM{Value: item}, outv) } bad: return fmt.Errorf("dynamo: cannot unmarshal item into type %v (must be a pointer to a map or struct, or a supported interface)", out.Type()) } -func (def *typedef) decodeItemBypass(item map[string]*dynamodb.AttributeValue, out any) error { +func (def *typedef) decodeItemBypass(item Item, out any) error { switch x := out.(type) { - case *map[string]*dynamodb.AttributeValue: + case *Item: *x = item return nil case awsEncoder: - return dynamodbattribute.UnmarshalMap(item, x.iface) + return attributevalue.UnmarshalMap(item, x.iface) case ItemUnmarshaler: return x.UnmarshalDynamoItem(item) } return nil } -func (def *typedef) decodeAttr(flags encodeFlags, av *dynamodb.AttributeValue, rv reflect.Value) error { +func (def *typedef) decodeAttr(flags encodeFlags, av types.AttributeValue, rv reflect.Value) error { if !rv.IsValid() || av == nil { return nil } // debugf("decodeAttr: %v(%v) <- %v", rv.Type(), rv, av) - if av.NULL != nil { + if _, isNull := av.(*types.AttributeValueMemberNULL); isNull { return decodeNull(def, flags, av, rv) } @@ -213,7 +212,7 @@ retry: return fmt.Errorf("dynamo: cannot unmarshal %s attribute value into type %s", avTypeName(av), rv.Type().String()) } -func (def *typedef) decodeType(key unmarshalKey, flags encodeFlags, av *dynamodb.AttributeValue, rv reflect.Value) (bool, error) { +func (def *typedef) decodeType(key unmarshalKey, flags encodeFlags, av types.AttributeValue, rv reflect.Value) (bool, error) { do, ok := def.decoders[key] if !ok { return false, nil @@ -244,33 +243,83 @@ func (def *typedef) learn(rt reflect.Type) { } for { switch try { - case rtypeAttr: - def.handle(this(shapeAny), decode2(func(dst *dynamodb.AttributeValue, src *dynamodb.AttributeValue) error { - *dst = *src + case rtypeAttrB: + def.handle(this(shapeB), decode2(func(dst *types.AttributeValueMemberB, src types.AttributeValue) error { + *dst = *src.(*types.AttributeValueMemberB) return nil })) - return + case rtypeAttrBS: + def.handle(this(shapeBS), decode2(func(dst *types.AttributeValueMemberBS, src types.AttributeValue) error { + *dst = *src.(*types.AttributeValueMemberBS) + return nil + })) + case rtypeAttrBOOL: + def.handle(this(shapeBOOL), decode2(func(dst *types.AttributeValueMemberBOOL, src types.AttributeValue) error { + *dst = *src.(*types.AttributeValueMemberBOOL) + return nil + })) + case rtypeAttrN: + def.handle(this(shapeN), decode2(func(dst *types.AttributeValueMemberN, src types.AttributeValue) error { + *dst = *src.(*types.AttributeValueMemberN) + return nil + })) + case rtypeAttrS: + def.handle(this(shapeS), decode2(func(dst *types.AttributeValueMemberS, src types.AttributeValue) error { + *dst = *src.(*types.AttributeValueMemberS) + return nil + })) + case rtypeAttrL: + def.handle(this(shapeL), decode2(func(dst *types.AttributeValueMemberL, src types.AttributeValue) error { + *dst = *src.(*types.AttributeValueMemberL) + return nil + })) + case rtypeAttrNS: + def.handle(this(shapeNS), decode2(func(dst *types.AttributeValueMemberNS, src types.AttributeValue) error { + *dst = *src.(*types.AttributeValueMemberNS) + return nil + })) + case rtypeAttrSS: + def.handle(this(shapeSS), decode2(func(dst *types.AttributeValueMemberSS, src types.AttributeValue) error { + *dst = *src.(*types.AttributeValueMemberSS) + return nil + })) + case rtypeAttrM: + def.handle(this(shapeM), decode2(func(dst *types.AttributeValueMemberM, src types.AttributeValue) error { + *dst = *src.(*types.AttributeValueMemberM) + return nil + })) + case rtypeAttrNULL: + def.handle(this(shapeNULL), decode2(func(dst *types.AttributeValueMemberNULL, src types.AttributeValue) error { + *dst = *src.(*types.AttributeValueMemberNULL) + return nil + })) + case rtypeTimePtr, rtypeTime: def.handle(this(shapeN), decodeUnixTime) - def.handle(this(shapeS), decode2(func(t encoding.TextUnmarshaler, av *dynamodb.AttributeValue) error { - return t.UnmarshalText([]byte(*av.S)) + def.handle(this(shapeS), decode2(func(t encoding.TextUnmarshaler, av types.AttributeValue) error { + return t.UnmarshalText([]byte(av.(*types.AttributeValueMemberS).Value)) })) return } switch { + // case try.Implements(rtypeAttr): + // def.handle(this(shapeAny), decode2(func(dst types.AttributeValue, src types.AttributeValue) error { + // *dst = src + // return nil + // })) case try.Implements(rtypeUnmarshaler): - def.handle(this(shapeAny), decode2(func(t Unmarshaler, av *dynamodb.AttributeValue) error { + def.handle(this(shapeAny), decode2(func(t Unmarshaler, av types.AttributeValue) error { return t.UnmarshalDynamo(av) })) return case try.Implements(rtypeAWSUnmarshaler): - def.handle(this(shapeAny), decode2(func(t dynamodbattribute.Unmarshaler, av *dynamodb.AttributeValue) error { + def.handle(this(shapeAny), decode2(func(t attributevalue.Unmarshaler, av types.AttributeValue) error { return t.UnmarshalDynamoDBAttributeValue(av) })) return case try.Implements(rtypeTextUnmarshaler): - def.handle(this(shapeS), decode2(func(t encoding.TextUnmarshaler, av *dynamodb.AttributeValue) error { - return t.UnmarshalText([]byte(*av.S)) + def.handle(this(shapeS), decode2(func(t encoding.TextUnmarshaler, av types.AttributeValue) error { + return t.UnmarshalText([]byte(av.(*types.AttributeValueMemberS).Value)) })) return } @@ -318,7 +367,7 @@ func (def *typedef) learn(rt reflect.Type) { truthy := truthy(rt) if !truthy.IsValid() { - bad := func(_ *typedef, _ encodeFlags, _ *dynamodb.AttributeValue, _ reflect.Value) error { + bad := func(_ *typedef, _ encodeFlags, _ types.AttributeValue, _ reflect.Value) error { return fmt.Errorf("dynamo: unmarshal map set: value type must be struct{} or bool, got %v", rt) } def.handle(this(shapeSS), bad) @@ -403,7 +452,7 @@ type structField struct { } var ( - nullAV = &dynamodb.AttributeValue{NULL: aws.Bool(true)} - emptyB = &dynamodb.AttributeValue{B: []byte("")} - emptyS = &dynamodb.AttributeValue{S: new(string)} + nullAV = &types.AttributeValueMemberNULL{Value: true} + emptyB = &types.AttributeValueMemberB{Value: []byte("")} + emptyS = &types.AttributeValueMemberS{Value: ""} ) diff --git a/encoding_aws.go b/encoding_aws.go index 5d49ead..003136d 100644 --- a/encoding_aws.go +++ b/encoding_aws.go @@ -4,8 +4,8 @@ import ( "fmt" "reflect" - "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) type Coder interface { @@ -17,12 +17,12 @@ type awsEncoder struct { iface interface{} } -func (w awsEncoder) MarshalDynamo() (*dynamodb.AttributeValue, error) { - return dynamodbattribute.Marshal(w.iface) +func (w awsEncoder) MarshalDynamo() (types.AttributeValue, error) { + return attributevalue.Marshal(w.iface) } -func (w awsEncoder) UnmarshalDynamo(av *dynamodb.AttributeValue) error { - return dynamodbattribute.Unmarshal(av, w.iface) +func (w awsEncoder) UnmarshalDynamo(av types.AttributeValue) error { + return attributevalue.Unmarshal(av, w.iface) } // AWSEncoding wraps an object, forcing it to use AWS's official dynamodbattribute package @@ -32,7 +32,7 @@ func AWSEncoding(v interface{}) Coder { return awsEncoder{v} } -func unmarshalAppendAWS(item map[string]*dynamodb.AttributeValue, out interface{}) error { +func unmarshalAppendAWS(item Item, out interface{}) error { rv := reflect.ValueOf(out) if rv.Kind() != reflect.Ptr || rv.Elem().Kind() != reflect.Slice { return fmt.Errorf("dynamo: unmarshal append AWS: result argument must be a slice pointer") @@ -40,7 +40,7 @@ func unmarshalAppendAWS(item map[string]*dynamodb.AttributeValue, out interface{ slicev := rv.Elem() innerRV := reflect.New(slicev.Type().Elem()) - if err := dynamodbattribute.UnmarshalMap(item, innerRV.Interface()); err != nil { + if err := attributevalue.UnmarshalMap(item, innerRV.Interface()); err != nil { return err } slicev = reflect.Append(slicev, innerRV.Elem()) diff --git a/encoding_aws_test.go b/encoding_aws_test.go index 20921a3..01c7ccc 100644 --- a/encoding_aws_test.go +++ b/encoding_aws_test.go @@ -5,9 +5,8 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) type awsTestWidget struct { @@ -33,7 +32,7 @@ func TestAWSEncoding(t *testing.T) { if err != nil { t.Error(err) } - official, err := dynamodbattribute.Marshal(w) + official, err := attributevalue.Marshal(w) if err != nil { t.Error(err) } @@ -56,12 +55,12 @@ func TestAWSEncoding(t *testing.T) { } func TestAWSIfaces(t *testing.T) { - unix := dynamodbattribute.UnixTime(time.Now()) + unix := attributevalue.UnixTime(time.Now()) av, err := Marshal(unix) if err != nil { t.Error(err) } - official, err := dynamodbattribute.Marshal(unix) + official, err := attributevalue.Marshal(unix) if err != nil { t.Error(err) } @@ -69,12 +68,12 @@ func TestAWSIfaces(t *testing.T) { t.Error("marshal not equal.", av, "≠", official) } - var result, officialResult dynamodbattribute.UnixTime + var result, officialResult attributevalue.UnixTime err = Unmarshal(official, &result) if err != nil { t.Error(err) } - err = dynamodbattribute.Unmarshal(official, &officialResult) + err = attributevalue.Unmarshal(official, &officialResult) if err != nil { t.Error(err) } @@ -96,7 +95,7 @@ func TestAWSItems(t *testing.T) { if err != nil { t.Error(err) } - official, err := dynamodbattribute.MarshalMap(item) + official, err := attributevalue.MarshalMap(item) if err != nil { t.Error(err) } @@ -109,7 +108,7 @@ func TestAWSItems(t *testing.T) { if err != nil { t.Error(err) } - err = dynamodbattribute.UnmarshalMap(official, &unmarshaledOfficial) + err = attributevalue.UnmarshalMap(official, &unmarshaledOfficial) if err != nil { t.Error(err) } @@ -132,20 +131,20 @@ func TestAWSUnmarshalAppend(t *testing.T) { A: "two", B: 222, } - err := unmarshalAppend(map[string]*dynamodb.AttributeValue{ - "one": {S: aws.String("test")}, - "two": {N: aws.String("555")}, - }, AWSEncoding(&list)) + err := unmarshalAppend(Item{ + "one": &types.AttributeValueMemberS{Value: "test"}, + "two": &types.AttributeValueMemberN{Value: "555"}, + }, &list) if err != nil { t.Error(err) } if len(list) != 1 && reflect.DeepEqual(list, []foo{expect1}) { t.Error("bad AWS unmarshal append:", list) } - err = unmarshalAppend(map[string]*dynamodb.AttributeValue{ - "one": {S: aws.String("two")}, - "two": {N: aws.String("222")}, - }, AWSEncoding(&list)) + err = unmarshalAppend(Item{ + "one": &types.AttributeValueMemberS{Value: ("two")}, + "two": &types.AttributeValueMemberN{Value: ("222")}, + }, &list) if err != nil { t.Error(err) } diff --git a/encoding_test.go b/encoding_test.go index 8c63567..d07c512 100644 --- a/encoding_test.go +++ b/encoding_test.go @@ -6,9 +6,9 @@ import ( "strconv" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) const ( @@ -21,51 +21,55 @@ var ( maxUintStr = strconv.FormatUint(uint64(maxUint), 10) ) +func init() { + time.Local = time.UTC +} + type customString string type customEmpty struct{} var encodingTests = []struct { name string in interface{} - out *dynamodb.AttributeValue + out types.AttributeValue }{ { name: "strings", in: "hello", - out: &dynamodb.AttributeValue{S: aws.String("hello")}, + out: &types.AttributeValueMemberS{Value: "hello"}, }, { name: "bools", in: true, - out: &dynamodb.AttributeValue{BOOL: aws.Bool(true)}, + out: &types.AttributeValueMemberBOOL{Value: true}, }, { name: "ints", in: 123, - out: &dynamodb.AttributeValue{N: aws.String("123")}, + out: &types.AttributeValueMemberN{Value: "123"}, }, { name: "uints", in: uint(123), - out: &dynamodb.AttributeValue{N: aws.String("123")}, + out: &types.AttributeValueMemberN{Value: "123"}, }, { name: "floats", in: 1.2, - out: &dynamodb.AttributeValue{N: aws.String("1.2")}, + out: &types.AttributeValueMemberN{Value: "1.2"}, }, { name: "pointer (int)", in: new(int), - out: &dynamodb.AttributeValue{N: aws.String("0")}, + out: &types.AttributeValueMemberN{Value: "0"}, }, { name: "maps", in: map[string]bool{ "OK": true, }, - out: &dynamodb.AttributeValue{M: map[string]*dynamodb.AttributeValue{ - "OK": {BOOL: aws.Bool(true)}, + out: &types.AttributeValueMemberM{Value: Item{ + "OK": &types.AttributeValueMemberBOOL{Value: true}, }}, }, { @@ -76,8 +80,8 @@ var encodingTests = []struct { }{ Empty: map[string]bool{}, }, - out: &dynamodb.AttributeValue{M: map[string]*dynamodb.AttributeValue{ - "Empty": {M: map[string]*dynamodb.AttributeValue{}}, + out: &types.AttributeValueMemberM{Value: Item{ + "Empty": &types.AttributeValueMemberM{Value: Item{}}, }}, }, { @@ -87,9 +91,9 @@ var encodingTests = []struct { }{ M1: map[textMarshaler]bool{textMarshaler(true): true}, }, - out: &dynamodb.AttributeValue{M: map[string]*dynamodb.AttributeValue{ - "M1": {M: map[string]*dynamodb.AttributeValue{ - "true": {BOOL: aws.Bool(true)}, + out: &types.AttributeValueMemberM{Value: Item{ + "M1": &types.AttributeValueMemberM{Value: Item{ + "true": &types.AttributeValueMemberBOOL{Value: true}, }}, }}, }, @@ -98,147 +102,147 @@ var encodingTests = []struct { in: struct { OK bool }{OK: true}, - out: &dynamodb.AttributeValue{M: map[string]*dynamodb.AttributeValue{ - "OK": {BOOL: aws.Bool(true)}, + out: &types.AttributeValueMemberM{Value: Item{ + "OK": &types.AttributeValueMemberBOOL{Value: true}, }}, }, { name: "[]byte", in: []byte{'O', 'K'}, - out: &dynamodb.AttributeValue{B: []byte{'O', 'K'}}, + out: &types.AttributeValueMemberB{Value: []byte{'O', 'K'}}, }, { name: "slice", in: []int{1, 2, 3}, - out: &dynamodb.AttributeValue{L: []*dynamodb.AttributeValue{ - {N: aws.String("1")}, - {N: aws.String("2")}, - {N: aws.String("3")}, + out: &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberN{Value: "1"}, + &types.AttributeValueMemberN{Value: "2"}, + &types.AttributeValueMemberN{Value: "3"}, }}, }, { name: "array", in: [3]int{1, 2, 3}, - out: &dynamodb.AttributeValue{L: []*dynamodb.AttributeValue{ - {N: aws.String("1")}, - {N: aws.String("2")}, - {N: aws.String("3")}, + out: &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberN{Value: "1"}, + &types.AttributeValueMemberN{Value: "2"}, + &types.AttributeValueMemberN{Value: "3"}, }}, }, { name: "byte array", in: [4]byte{'a', 'b', 'c', 'd'}, - out: &dynamodb.AttributeValue{B: []byte{'a', 'b', 'c', 'd'}}, + out: &types.AttributeValueMemberB{Value: []byte{'a', 'b', 'c', 'd'}}, }, { name: "dynamo.Marshaler", in: customMarshaler(1), - out: &dynamodb.AttributeValue{BOOL: aws.Bool(true)}, + out: &types.AttributeValueMemberBOOL{Value: true}, }, { name: "encoding.TextMarshaler", in: textMarshaler(true), - out: &dynamodb.AttributeValue{S: aws.String("true")}, + out: &types.AttributeValueMemberS{Value: "true"}, }, { name: "dynamodb.AttributeValue", - in: &dynamodb.AttributeValue{L: []*dynamodb.AttributeValue{ - {N: aws.String("1")}, - {N: aws.String("2")}, - {N: aws.String("3")}, + in: &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberN{Value: "1"}, + &types.AttributeValueMemberN{Value: "2"}, + &types.AttributeValueMemberN{Value: "3"}, }}, - out: &dynamodb.AttributeValue{L: []*dynamodb.AttributeValue{ - {N: aws.String("1")}, - {N: aws.String("2")}, - {N: aws.String("3")}, + out: &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberN{Value: "1"}, + &types.AttributeValueMemberN{Value: "2"}, + &types.AttributeValueMemberN{Value: "3"}, }}, }, { name: "slice with nil", in: []*int64{nil, aws.Int64(0), nil, aws.Int64(1337), nil}, - out: &dynamodb.AttributeValue{L: []*dynamodb.AttributeValue{ - {NULL: aws.Bool(true)}, - {N: aws.String("0")}, - {NULL: aws.Bool(true)}, - {N: aws.String("1337")}, - {NULL: aws.Bool(true)}, + out: &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberNULL{Value: true}, + &types.AttributeValueMemberN{Value: "0"}, + &types.AttributeValueMemberNULL{Value: true}, + &types.AttributeValueMemberN{Value: "1337"}, + &types.AttributeValueMemberNULL{Value: true}, }}, }, { name: "array with nil", in: [...]*int64{nil, aws.Int64(0), nil, aws.Int64(1337), nil}, - out: &dynamodb.AttributeValue{L: []*dynamodb.AttributeValue{ - {NULL: aws.Bool(true)}, - {N: aws.String("0")}, - {NULL: aws.Bool(true)}, - {N: aws.String("1337")}, - {NULL: aws.Bool(true)}, + out: &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberNULL{Value: true}, + &types.AttributeValueMemberN{Value: "0"}, + &types.AttributeValueMemberNULL{Value: true}, + &types.AttributeValueMemberN{Value: "1337"}, + &types.AttributeValueMemberNULL{Value: true}, }}, }, { name: "slice with empty string", in: []string{"", "hello", "", "world", ""}, - out: &dynamodb.AttributeValue{L: []*dynamodb.AttributeValue{ - {S: aws.String("")}, - {S: aws.String("hello")}, - {S: aws.String("")}, - {S: aws.String("world")}, - {S: aws.String("")}, + out: &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberS{Value: ""}, + &types.AttributeValueMemberS{Value: "hello"}, + &types.AttributeValueMemberS{Value: ""}, + &types.AttributeValueMemberS{Value: "world"}, + &types.AttributeValueMemberS{Value: ""}, }}, }, { name: "array with empty string", in: [...]string{"", "hello", "", "world", ""}, - out: &dynamodb.AttributeValue{L: []*dynamodb.AttributeValue{ - {S: aws.String("")}, - {S: aws.String("hello")}, - {S: aws.String("")}, - {S: aws.String("world")}, - {S: aws.String("")}, + out: &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberS{Value: ""}, + &types.AttributeValueMemberS{Value: "hello"}, + &types.AttributeValueMemberS{Value: ""}, + &types.AttributeValueMemberS{Value: "world"}, + &types.AttributeValueMemberS{Value: ""}, }}, }, { name: "slice of string pointers", in: []*string{nil, aws.String("hello"), aws.String(""), aws.String("world"), nil}, - out: &dynamodb.AttributeValue{L: []*dynamodb.AttributeValue{ - {NULL: aws.Bool(true)}, - {S: aws.String("hello")}, - {S: aws.String("")}, - {S: aws.String("world")}, - {NULL: aws.Bool(true)}, + out: &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberNULL{Value: true}, + &types.AttributeValueMemberS{Value: "hello"}, + &types.AttributeValueMemberS{Value: ""}, + &types.AttributeValueMemberS{Value: "world"}, + &types.AttributeValueMemberNULL{Value: true}, }}, }, { name: "slice with empty binary", in: [][]byte{{}, []byte("hello"), {}, []byte("world"), {}}, - out: &dynamodb.AttributeValue{L: []*dynamodb.AttributeValue{ - {B: []byte{}}, - {B: []byte{'h', 'e', 'l', 'l', 'o'}}, - {B: []byte{}}, - {B: []byte{'w', 'o', 'r', 'l', 'd'}}, - {B: []byte{}}, + out: &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberB{Value: []byte{}}, + &types.AttributeValueMemberB{Value: []byte{'h', 'e', 'l', 'l', 'o'}}, + &types.AttributeValueMemberB{Value: []byte{}}, + &types.AttributeValueMemberB{Value: []byte{'w', 'o', 'r', 'l', 'd'}}, + &types.AttributeValueMemberB{Value: []byte{}}, }}, }, { name: "array with empty binary", in: [...][]byte{{}, []byte("hello"), {}, []byte("world"), {}}, - out: &dynamodb.AttributeValue{L: []*dynamodb.AttributeValue{ - {B: []byte{}}, - {B: []byte{'h', 'e', 'l', 'l', 'o'}}, - {B: []byte{}}, - {B: []byte{'w', 'o', 'r', 'l', 'd'}}, - {B: []byte{}}, + out: &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberB{Value: []byte{}}, + &types.AttributeValueMemberB{Value: []byte{'h', 'e', 'l', 'l', 'o'}}, + &types.AttributeValueMemberB{Value: []byte{}}, + &types.AttributeValueMemberB{Value: []byte{'w', 'o', 'r', 'l', 'd'}}, + &types.AttributeValueMemberB{Value: []byte{}}, }}, }, { name: "array with empty binary ptrs", in: [...]*[]byte{byteSlicePtr([]byte{}), byteSlicePtr([]byte("hello")), nil, byteSlicePtr([]byte("world")), byteSlicePtr([]byte{})}, - out: &dynamodb.AttributeValue{L: []*dynamodb.AttributeValue{ - {B: []byte{}}, - {B: []byte{'h', 'e', 'l', 'l', 'o'}}, - {NULL: aws.Bool(true)}, - {B: []byte{'w', 'o', 'r', 'l', 'd'}}, - {B: []byte{}}, + out: &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberB{Value: []byte{}}, + &types.AttributeValueMemberB{Value: []byte{'h', 'e', 'l', 'l', 'o'}}, + &types.AttributeValueMemberNULL{Value: true}, + &types.AttributeValueMemberB{Value: []byte{'w', 'o', 'r', 'l', 'd'}}, + &types.AttributeValueMemberB{Value: []byte{}}, }}, }, } @@ -246,7 +250,7 @@ var encodingTests = []struct { var itemEncodingTests = []struct { name string in interface{} - out map[string]*dynamodb.AttributeValue + out Item }{ { name: "strings", @@ -255,8 +259,8 @@ var itemEncodingTests = []struct { }{ A: "hello", }, - out: map[string]*dynamodb.AttributeValue{ - "A": {S: aws.String("hello")}, + out: Item{ + "A": &types.AttributeValueMemberS{Value: "hello"}, }, }, { @@ -266,8 +270,8 @@ var itemEncodingTests = []struct { }{ A: "hello", }, - out: map[string]*dynamodb.AttributeValue{ - "A": {S: aws.String("hello")}, + out: Item{ + "A": &types.AttributeValueMemberS{Value: "hello"}, }, }, { @@ -277,8 +281,8 @@ var itemEncodingTests = []struct { }{ A: new(textMarshaler), }, - out: map[string]*dynamodb.AttributeValue{ - "A": {S: aws.String("false")}, + out: Item{ + "A": &types.AttributeValueMemberS{Value: "false"}, }, }, { @@ -288,8 +292,8 @@ var itemEncodingTests = []struct { }{ A: "hello", }, - out: map[string]*dynamodb.AttributeValue{ - "renamed": {S: aws.String("hello")}, + out: Item{ + "renamed": &types.AttributeValueMemberS{Value: "hello"}, }, }, { @@ -301,8 +305,8 @@ var itemEncodingTests = []struct { A: "", Other: true, }, - out: map[string]*dynamodb.AttributeValue{ - "Other": {BOOL: aws.Bool(true)}, + out: Item{ + "Other": &types.AttributeValueMemberBOOL{Value: true}, }, }, { @@ -316,8 +320,8 @@ var itemEncodingTests = []struct { }{ Other: true, }, - out: map[string]*dynamodb.AttributeValue{ - "Other": {BOOL: aws.Bool(true)}, + out: Item{ + "Other": &types.AttributeValueMemberBOOL{Value: true}, }, }, { @@ -334,14 +338,14 @@ var itemEncodingTests = []struct { NilTime *time.Time NilCustom *customMarshaler NilText *textMarshaler - NilAWS *dynamodbattribute.UnixTime + NilAWS *attributevalue.UnixTime }{ OK: "OK", EmptyL: []int{}, }, - out: map[string]*dynamodb.AttributeValue{ - "OK": {S: aws.String("OK")}, - "EmptyL": {L: []*dynamodb.AttributeValue{}}, + out: Item{ + "OK": &types.AttributeValueMemberS{Value: "OK"}, + "EmptyL": &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, }, }, { @@ -352,9 +356,9 @@ var itemEncodingTests = []struct { }{ B: []byte{}, }, - out: map[string]*dynamodb.AttributeValue{ - "S": {S: aws.String("")}, - "B": {B: []byte{}}, + out: Item{ + "S": &types.AttributeValueMemberS{Value: ""}, + "B": &types.AttributeValueMemberB{Value: []byte{}}, }, }, { @@ -364,11 +368,11 @@ var itemEncodingTests = []struct { }{ M: map[string]*string{"null": nil, "empty": aws.String(""), "normal": aws.String("hello")}, }, - out: map[string]*dynamodb.AttributeValue{ - "M": {M: map[string]*dynamodb.AttributeValue{ - "null": {NULL: aws.Bool(true)}, - "empty": {S: aws.String("")}, - "normal": {S: aws.String("hello")}, + out: Item{ + "M": &types.AttributeValueMemberM{Value: Item{ + "null": &types.AttributeValueMemberNULL{Value: true}, + "empty": &types.AttributeValueMemberS{Value: ""}, + "normal": &types.AttributeValueMemberS{Value: "hello"}, }}, }, }, @@ -383,12 +387,16 @@ var itemEncodingTests = []struct { }, }, }, - out: map[string]*dynamodb.AttributeValue{ - "M": {M: map[string]*dynamodb.AttributeValue{ - "nestedmap": {M: map[string]*dynamodb.AttributeValue{ - "empty": {S: aws.String("")}, - }}, - }}, + out: Item{ + "M": &types.AttributeValueMemberM{ + Value: Item{ + "nestedmap": &types.AttributeValueMemberM{ + Value: Item{ + "empty": &types.AttributeValueMemberS{Value: ""}, + }, + }, + }, + }, }, }, { @@ -402,14 +410,20 @@ var itemEncodingTests = []struct { }, }, }, - out: map[string]*dynamodb.AttributeValue{ - "M": {M: map[string]*dynamodb.AttributeValue{ - "slice": {L: []*dynamodb.AttributeValue{ - {M: map[string]*dynamodb.AttributeValue{ - "empty": {S: aws.String("")}, - }}, - }}, - }}, + out: Item{ + "M": &types.AttributeValueMemberM{ + Value: Item{ + "slice": &types.AttributeValueMemberL{ + Value: []types.AttributeValue{ + &types.AttributeValueMemberM{ + Value: Item{ + "empty": &types.AttributeValueMemberS{Value: ""}, + }, + }, + }, + }, + }, + }, }, }, { @@ -423,15 +437,16 @@ var itemEncodingTests = []struct { }, }, }, - out: map[string]*dynamodb.AttributeValue{ - "L": {L: []*dynamodb.AttributeValue{ - { - M: map[string]*dynamodb.AttributeValue{ - "empty": {S: aws.String("")}, + out: Item{ + "L": &types.AttributeValueMemberL{ + Value: []types.AttributeValue{ + &types.AttributeValueMemberM{ + Value: Item{ + "empty": &types.AttributeValueMemberS{Value: ""}, + }, }, }, }, - }, }, }, { @@ -443,12 +458,12 @@ var itemEncodingTests = []struct { M map[string]*string `dynamo:",null"` SS []string `dynamo:",null,set"` }{}, - out: map[string]*dynamodb.AttributeValue{ - "S": {NULL: aws.Bool(true)}, - "B": {NULL: aws.Bool(true)}, - "NilTime": {NULL: aws.Bool(true)}, - "M": {NULL: aws.Bool(true)}, - "SS": {NULL: aws.Bool(true)}, + out: Item{ + "S": &types.AttributeValueMemberNULL{Value: true}, + "B": &types.AttributeValueMemberNULL{Value: true}, + "NilTime": &types.AttributeValueMemberNULL{Value: true}, + "M": &types.AttributeValueMemberNULL{Value: true}, + "SS": &types.AttributeValueMemberNULL{Value: true}, }, }, { @@ -460,8 +475,8 @@ var itemEncodingTests = []struct { Embedded: true, }, }, - out: map[string]*dynamodb.AttributeValue{ - "Embedded": {BOOL: aws.Bool(true)}, + out: Item{ + "Embedded": &types.AttributeValueMemberBOOL{Value: true}, }, }, { @@ -473,8 +488,8 @@ var itemEncodingTests = []struct { Embedded: true, }, }, - out: map[string]*dynamodb.AttributeValue{ - "Embedded": {BOOL: aws.Bool(true)}, + out: Item{ + "Embedded": &types.AttributeValueMemberBOOL{Value: true}, }, }, { @@ -486,8 +501,8 @@ var itemEncodingTests = []struct { Embedded: true, }, }, - out: map[string]*dynamodb.AttributeValue{ - "Embedded": {BOOL: aws.Bool(true)}, + out: Item{ + "Embedded": &types.AttributeValueMemberBOOL{Value: true}, }, }, { @@ -498,8 +513,8 @@ var itemEncodingTests = []struct { }{ Embedded: "OK", }, - out: map[string]*dynamodb.AttributeValue{ - "Embedded": {S: aws.String("OK")}, + out: Item{ + "Embedded": &types.AttributeValueMemberS{Value: "OK"}, }, }, { @@ -510,8 +525,8 @@ var itemEncodingTests = []struct { }{ Embedded: "OK", }, - out: map[string]*dynamodb.AttributeValue{ - "Embedded": {S: aws.String("OK")}, + out: Item{ + "Embedded": &types.AttributeValueMemberS{Value: "OK"}, }, }, { @@ -522,8 +537,8 @@ var itemEncodingTests = []struct { }{ Embedded: "OK", }, - out: map[string]*dynamodb.AttributeValue{ - "Embedded": {S: aws.String("OK")}, + out: Item{ + "Embedded": &types.AttributeValueMemberS{Value: "OK"}, }, }, { @@ -569,26 +584,26 @@ var itemEncodingTests = []struct { NS4: map[int]struct{}{maxInt: {}}, NS5: map[uint]bool{maxUint: true}, }, - out: map[string]*dynamodb.AttributeValue{ - "SS1": {SS: []*string{aws.String("A"), aws.String("B")}}, - "SS2": {SS: []*string{aws.String("true"), aws.String("false")}}, - "SS3": {SS: []*string{aws.String("A")}}, - "SS4": {SS: []*string{aws.String("A")}}, - "SS5": {SS: []*string{aws.String("A")}}, - "SS6": {SS: []*string{aws.String("A"), aws.String("B")}}, - "SS7": {SS: []*string{aws.String("true")}}, - "SS8": {SS: []*string{aws.String("false")}}, - "SS9": {SS: []*string{aws.String("A"), aws.String("B"), aws.String("")}}, - "SS10": {SS: []*string{aws.String("A")}}, - "BS1": {BS: [][]byte{{'A'}, {'B'}}}, - "BS2": {BS: [][]byte{{'A'}}}, - "BS3": {BS: [][]byte{{'A'}}}, - "BS4": {BS: [][]byte{{'A'}, {'B'}, {}}}, - "NS1": {NS: []*string{aws.String("1"), aws.String("2")}}, - "NS2": {NS: []*string{aws.String("1"), aws.String("2")}}, - "NS3": {NS: []*string{aws.String("1"), aws.String("2")}}, - "NS4": {NS: []*string{aws.String(maxIntStr)}}, - "NS5": {NS: []*string{aws.String(maxUintStr)}}, + out: Item{ + "SS1": &types.AttributeValueMemberSS{Value: []string{"A", "B"}}, + "SS2": &types.AttributeValueMemberSS{Value: []string{"true", "false"}}, + "SS3": &types.AttributeValueMemberSS{Value: []string{"A"}}, + "SS4": &types.AttributeValueMemberSS{Value: []string{"A"}}, + "SS5": &types.AttributeValueMemberSS{Value: []string{"A"}}, + "SS6": &types.AttributeValueMemberSS{Value: []string{"A", "B"}}, + "SS7": &types.AttributeValueMemberSS{Value: []string{"true"}}, + "SS8": &types.AttributeValueMemberSS{Value: []string{"false"}}, + "SS9": &types.AttributeValueMemberSS{Value: []string{"A", "B", ""}}, + "SS10": &types.AttributeValueMemberSS{Value: []string{"A"}}, + "BS1": &types.AttributeValueMemberBS{Value: [][]byte{{'A'}, {'B'}}}, + "BS2": &types.AttributeValueMemberBS{Value: [][]byte{{'A'}}}, + "BS3": &types.AttributeValueMemberBS{Value: [][]byte{{'A'}}}, + "BS4": &types.AttributeValueMemberBS{Value: [][]byte{{'A'}, {'B'}, {}}}, + "NS1": &types.AttributeValueMemberNS{Value: []string{"1", "2"}}, + "NS2": &types.AttributeValueMemberNS{Value: []string{"1", "2"}}, + "NS3": &types.AttributeValueMemberNS{Value: []string{"1", "2"}}, + "NS4": &types.AttributeValueMemberNS{Value: []string{maxIntStr}}, + "NS5": &types.AttributeValueMemberNS{Value: []string{maxUintStr}}, }, }, { @@ -602,17 +617,17 @@ var itemEncodingTests = []struct { "OK": true, }, }, - out: map[string]*dynamodb.AttributeValue{ - "S": {S: aws.String("Hello")}, - "B": {B: []byte{'A', 'B'}}, - "N": {N: aws.String("1.2")}, - "L": {L: []*dynamodb.AttributeValue{ - {S: aws.String("A")}, - {S: aws.String("B")}, - {N: aws.String("1.2")}, + out: Item{ + "S": &types.AttributeValueMemberS{Value: "Hello"}, + "B": &types.AttributeValueMemberB{Value: []byte{'A', 'B'}}, + "N": &types.AttributeValueMemberN{Value: "1.2"}, + "L": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberS{Value: "A"}, + &types.AttributeValueMemberS{Value: "B"}, + &types.AttributeValueMemberN{Value: "1.2"}, }}, - "M": {M: map[string]*dynamodb.AttributeValue{ - "OK": {BOOL: aws.Bool(true)}, + "M": &types.AttributeValueMemberM{Value: Item{ + "OK": &types.AttributeValueMemberBOOL{Value: true}, }}, }, }, @@ -625,22 +640,9 @@ var itemEncodingTests = []struct { "Hello": "world", }, }, - out: map[string]*dynamodb.AttributeValue{ - "M": {M: map[string]*dynamodb.AttributeValue{ - "Hello": {S: aws.String("world")}, - }}, - }, - }, - { - name: "map string attributevalue", - in: map[string]*dynamodb.AttributeValue{ - "M": {M: map[string]*dynamodb.AttributeValue{ - "Hello": {S: aws.String("world")}, - }}, - }, - out: map[string]*dynamodb.AttributeValue{ - "M": {M: map[string]*dynamodb.AttributeValue{ - "Hello": {S: aws.String("world")}, + out: Item{ + "M": &types.AttributeValueMemberM{Value: Item{ + "Hello": &types.AttributeValueMemberS{Value: "world"}, }}, }, }, @@ -651,8 +653,8 @@ var itemEncodingTests = []struct { }{ TTL: time.Date(2019, 1, 1, 0, 0, 0, 0, time.UTC), }, - out: map[string]*dynamodb.AttributeValue{ - "TTL": {S: aws.String("2019-01-01T00:00:00Z")}, + out: Item{ + "TTL": &types.AttributeValueMemberS{Value: "2019-01-01T00:00:00Z"}, }, }, { @@ -662,8 +664,8 @@ var itemEncodingTests = []struct { }{ TTL: time.Date(2019, 1, 1, 0, 0, 0, 0, time.UTC), }, - out: map[string]*dynamodb.AttributeValue{ - "TTL": {N: aws.String("1546300800")}, + out: Item{ + "TTL": &types.AttributeValueMemberN{Value: "1546300800"}, }, }, { @@ -673,7 +675,7 @@ var itemEncodingTests = []struct { }{ TTL: time.Time{}, }, - out: map[string]*dynamodb.AttributeValue{}, + out: Item{}, }, { name: "*time.Time (unixtime encoding)", @@ -682,8 +684,8 @@ var itemEncodingTests = []struct { }{ TTL: aws.Time(time.Date(2019, 1, 1, 0, 0, 0, 0, time.UTC)), }, - out: map[string]*dynamodb.AttributeValue{ - "TTL": {N: aws.String("1546300800")}, + out: Item{ + "TTL": &types.AttributeValueMemberN{Value: "1546300800"}, }, }, { @@ -693,20 +695,20 @@ var itemEncodingTests = []struct { }{ TTL: nil, }, - out: map[string]*dynamodb.AttributeValue{}, + out: Item{}, }, { name: "dynamodb.ItemUnmarshaler", in: customItemMarshaler{Thing: 52}, - out: map[string]*dynamodb.AttributeValue{ - "thing": {N: aws.String("52")}, + out: Item{ + "thing": &types.AttributeValueMemberN{Value: "52"}, }, }, { name: "*dynamodb.ItemUnmarshaler", in: &customItemMarshaler{Thing: 52}, - out: map[string]*dynamodb.AttributeValue{ - "thing": {N: aws.String("52")}, + out: Item{ + "thing": &types.AttributeValueMemberN{Value: "52"}, }, }, { @@ -719,22 +721,22 @@ var itemEncodingTests = []struct { Children: []Person{{Name: "Bobby", Children: []Person{}}}, Name: "Hank", }, - out: map[string]*dynamodb.AttributeValue{ - "Name": {S: aws.String("Hank")}, - "Spouse": {M: map[string]*dynamodb.AttributeValue{ - "Name": {S: aws.String("Peggy")}, - "Children": {L: []*dynamodb.AttributeValue{ - {M: map[string]*dynamodb.AttributeValue{ - "Name": {S: aws.String("Bobby")}, - "Children": {L: []*dynamodb.AttributeValue{}}, + out: map[string]types.AttributeValue{ + "Name": &types.AttributeValueMemberS{Value: "Hank"}, + "Spouse": &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "Name": &types.AttributeValueMemberS{Value: "Peggy"}, + "Children": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "Name": &types.AttributeValueMemberS{Value: "Bobby"}, + "Children": &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, }}, }, }, }}, - "Children": {L: []*dynamodb.AttributeValue{ - {M: map[string]*dynamodb.AttributeValue{ - "Name": {S: aws.String("Bobby")}, - "Children": {L: []*dynamodb.AttributeValue{}}, + "Children": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "Name": &types.AttributeValueMemberS{Value: "Bobby"}, + "Children": &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, }}, }}, }, @@ -753,29 +755,28 @@ var itemEncodingTests = []struct { }, Nickname: "H-Dawg", }, - out: map[string]*dynamodb.AttributeValue{ - "ID": {N: aws.String("555")}, - "Nickname": {S: aws.String("H-Dawg")}, - "Person": {M: map[string]*dynamodb.AttributeValue{ - "Name": {S: aws.String("Hank")}, - "Spouse": {M: map[string]*dynamodb.AttributeValue{ - "Name": {S: aws.String("Peggy")}, - "Children": {L: []*dynamodb.AttributeValue{ - {M: map[string]*dynamodb.AttributeValue{ - "Name": {S: aws.String("Bobby")}, - "Children": {L: []*dynamodb.AttributeValue{}}, + out: map[string]types.AttributeValue{ + "ID": &types.AttributeValueMemberN{Value: "555"}, + "Nickname": &types.AttributeValueMemberS{Value: "H-Dawg"}, + "Person": &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "Name": &types.AttributeValueMemberS{Value: "Hank"}, + "Spouse": &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "Name": &types.AttributeValueMemberS{Value: "Peggy"}, + "Children": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "Name": &types.AttributeValueMemberS{Value: "Bobby"}, + "Children": &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, }}, }, }, }}, - "Children": {L: []*dynamodb.AttributeValue{ - {M: map[string]*dynamodb.AttributeValue{ - "Name": {S: aws.String("Bobby")}, - "Children": {L: []*dynamodb.AttributeValue{}}, + "Children": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "Name": &types.AttributeValueMemberS{Value: "Bobby"}, + "Children": &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, }}, }}, - }, - }, + }}, }, }, { @@ -805,46 +806,46 @@ var itemEncodingTests = []struct { {ID: "recursion", Blah: 30}, }, }, - out: map[string]*dynamodb.AttributeValue{ - "ID": {N: aws.String("123")}, - "Text": {S: aws.String("hello")}, - "Friends": {L: []*dynamodb.AttributeValue{ - {M: map[string]*dynamodb.AttributeValue{ - "ID": {N: aws.String("1")}, - "Text": {S: aws.String("suffering")}, - "Child": {M: map[string]*dynamodb.AttributeValue{ - "ID": {S: aws.String("pain")}, + out: Item{ + "ID": &types.AttributeValueMemberN{Value: "123"}, + "Text": &types.AttributeValueMemberS{Value: "hello"}, + "Friends": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberM{Value: Item{ + "ID": &types.AttributeValueMemberN{Value: "1"}, + "Text": &types.AttributeValueMemberS{Value: "suffering"}, + "Child": &types.AttributeValueMemberM{Value: Item{ + "ID": &types.AttributeValueMemberS{Value: "pain"}, }}, - "Friends": {L: []*dynamodb.AttributeValue{}}, - "Enemies": {L: []*dynamodb.AttributeValue{}}, + "Friends": &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, + "Enemies": &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, }}, - {M: map[string]*dynamodb.AttributeValue{ - "ID": {N: aws.String("2")}, - "Text": {S: aws.String("love")}, - "Child": {M: map[string]*dynamodb.AttributeValue{ - "ID": {S: aws.String("understanding")}, + &types.AttributeValueMemberM{Value: Item{ + "ID": &types.AttributeValueMemberN{Value: "2"}, + "Text": &types.AttributeValueMemberS{Value: "love"}, + "Child": &types.AttributeValueMemberM{Value: Item{ + "ID": &types.AttributeValueMemberS{Value: "understanding"}, }}, - "Friends": {L: []*dynamodb.AttributeValue{}}, - "Enemies": {L: []*dynamodb.AttributeValue{}}, + "Friends": &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, + "Enemies": &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, }}, }}, - "Enemies": {L: []*dynamodb.AttributeValue{ - {M: map[string]*dynamodb.AttributeValue{ - "ID": {S: aws.String("recursion")}, - "Blah": {N: aws.String("30")}, + "Enemies": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberM{Value: Item{ + "ID": &types.AttributeValueMemberS{Value: "recursion"}, + "Blah": &types.AttributeValueMemberN{Value: "30"}, }}, }}, - "Child": {M: map[string]*dynamodb.AttributeValue{ - "ID": {S: aws.String("test")}, - "Blah": {N: aws.String("555")}, - "Child": {M: map[string]*dynamodb.AttributeValue{ - "ID": {N: aws.String("222")}, - "Text": {S: aws.String("help")}, - "Friends": {L: []*dynamodb.AttributeValue{}}, - "Enemies": {L: []*dynamodb.AttributeValue{}}, - "Child": {M: map[string]*dynamodb.AttributeValue{ - "ID": {S: aws.String("why")}, - "Blah": {N: aws.String("1337")}, + "Child": &types.AttributeValueMemberM{Value: Item{ + "ID": &types.AttributeValueMemberS{Value: "test"}, + "Blah": &types.AttributeValueMemberN{Value: "555"}, + "Child": &types.AttributeValueMemberM{Value: Item{ + "ID": &types.AttributeValueMemberN{Value: "222"}, + "Text": &types.AttributeValueMemberS{Value: "help"}, + "Friends": &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, + "Enemies": &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, + "Child": &types.AttributeValueMemberM{Value: Item{ + "ID": &types.AttributeValueMemberS{Value: "why"}, + "Blah": &types.AttributeValueMemberN{Value: "1337"}, }}, }}, }}, @@ -862,14 +863,13 @@ type ExportedEmbedded struct { type customMarshaler int -func (cm customMarshaler) MarshalDynamo() (*dynamodb.AttributeValue, error) { - return &dynamodb.AttributeValue{ - BOOL: aws.Bool(cm != 0), - }, nil +func (cm customMarshaler) MarshalDynamo() (types.AttributeValue, error) { + return &types.AttributeValueMemberBOOL{Value: cm != 0}, nil } -func (cm *customMarshaler) UnmarshalDynamo(av *dynamodb.AttributeValue) error { - if *av.BOOL == true { +func (cm *customMarshaler) UnmarshalDynamo(av types.AttributeValue) error { + + if res, ok := av.(*types.AttributeValueMemberBOOL); ok && res.Value == true { *cm = 1 } return nil @@ -913,30 +913,29 @@ type customItemMarshaler struct { Thing interface{} `dynamo:"thing"` } -func (cim *customItemMarshaler) MarshalDynamoItem() (map[string]*dynamodb.AttributeValue, error) { +func (cim *customItemMarshaler) MarshalDynamoItem() (Item, error) { thing := strconv.Itoa(cim.Thing.(int)) - attrs := map[string]*dynamodb.AttributeValue{ - "thing": { - N: &thing, - }, + attrs := Item{ + "thing": &types.AttributeValueMemberN{Value: thing}, } return attrs, nil } -func (cim *customItemMarshaler) UnmarshalDynamoItem(item map[string]*dynamodb.AttributeValue) error { +func (cim *customItemMarshaler) UnmarshalDynamoItem(item Item) error { thingAttr := item["thing"] - if thingAttr == nil || thingAttr.N == nil { + if res, ok := thingAttr.(*types.AttributeValueMemberN); !ok { return errors.New("Missing or not a number") - } + } else { - thing, err := strconv.Atoi(*thingAttr.N) - if err != nil { - return errors.New("Invalid number") - } + thing, err := strconv.Atoi(res.Value) + if err != nil { + return errors.New("Invalid number") + } - cim.Thing = thing + cim.Thing = thing + } return nil } diff --git a/example_test.go b/example_test.go new file mode 100644 index 0000000..9fc803f --- /dev/null +++ b/example_test.go @@ -0,0 +1,67 @@ +package dynamo_test + +import ( + "context" + "log" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/retry" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/guregu/dynamo/v2" +) + +func ExampleNew() { + // Basic setup example. + // See: https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/config for more on configuration options. + const region = "us-west-2" + cfg, err := config.LoadDefaultConfig( + context.Background(), + config.WithRegion(region), + ) + if err != nil { + log.Fatal(err) + } + db := dynamo.New(cfg) + // use the db + _ = db +} + +func ExampleNew_local_endpoint() { + // Example of connecting to a DynamoDB local instance. + // See: https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DynamoDBLocal.DownloadingAndRunning.html + const endpoint = "http://localhost:8000" + resolver := aws.EndpointResolverWithOptionsFunc( + func(service, region string, options ...interface{}) (aws.Endpoint, error) { + return aws.Endpoint{URL: endpoint}, nil + }, + ) + // credentials can be anything, but must be set + creds := credentials.NewStaticCredentialsProvider("dummy", "dummy", "") + cfg, err := config.LoadDefaultConfig( + context.Background(), + config.WithRegion("local"), // region can also be anything + config.WithEndpointResolverWithOptions(resolver), + config.WithCredentialsProvider(creds), + ) + if err != nil { + log.Fatal(err) + } + db := dynamo.New(cfg) + // use the db + _ = db +} + +func ExampleRetryTxConflicts() { + // `dynamo.RetryTxConflicts` is an option you can pass to retry.NewStandard. + // It will automatically retry canceled transactions. + cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRetryer(func() aws.Retryer { + return retry.NewStandard(dynamo.RetryTxConflicts) + })) + if err != nil { + log.Fatal(err) + } + db := dynamo.New(cfg) + // use the db + _ = db +} diff --git a/go.mod b/go.mod index 2a5ecf9..5f81167 100644 --- a/go.mod +++ b/go.mod @@ -1,15 +1,28 @@ -module github.com/guregu/dynamo +module github.com/guregu/dynamo/v2 require ( - github.com/aws/aws-sdk-go v1.48.10 - github.com/cenkalti/backoff/v4 v4.2.1 - golang.org/x/sync v0.5.0 + github.com/aws/aws-sdk-go-v2 v1.27.2 + github.com/aws/aws-sdk-go-v2/config v1.11.0 + github.com/aws/aws-sdk-go-v2/credentials v1.6.4 + github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.14.1 + github.com/aws/aws-sdk-go-v2/service/dynamodb v1.32.8 + github.com/aws/smithy-go v1.20.2 + github.com/cenkalti/backoff/v4 v4.3.0 + golang.org/x/sync v0.7.0 ) -require github.com/jmespath/go-jmespath v0.4.0 // indirect - -go 1.20 +require ( + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.8.2 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.9 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.9 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.3.2 // indirect + github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.20.10 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.9.10 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.5.2 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.6.2 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.11.1 // indirect + github.com/jmespath/go-jmespath v0.4.0 // indirect +) -retract ( - v1.22.0 // See issues: #228, #230 -) \ No newline at end of file +go 1.21 diff --git a/go.sum b/go.sum index 3988a60..971c824 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,69 @@ -github.com/aws/aws-sdk-go v1.48.10 h1:0LIFG3wp2Dt6PsxKWCg1Y1xRrn2vZnW5/gWdgaBalKg= -github.com/aws/aws-sdk-go v1.48.10/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= -github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM= -github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= +github.com/aws/aws-sdk-go-v2 v1.11.2/go.mod h1:SQfA+m2ltnu1cA0soUkj4dRSsmITiVQUJvBIZjzfPyQ= +github.com/aws/aws-sdk-go-v2 v1.24.1 h1:xAojnj+ktS95YZlDf0zxWBkbFtymPeDP+rvUQIH3uAU= +github.com/aws/aws-sdk-go-v2 v1.24.1/go.mod h1:LNh45Br1YAkEKaAqvmE1m8FUx6a5b/V0oAKV7of29b4= +github.com/aws/aws-sdk-go-v2 v1.27.2 h1:pLsTXqX93rimAOZG2FIYraDQstZaaGVVN4tNw65v0h8= +github.com/aws/aws-sdk-go-v2 v1.27.2/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= +github.com/aws/aws-sdk-go-v2/config v1.11.0 h1:Czlld5zBB61A3/aoegA9/buZulwL9mHHfizh/Oq+Kqs= +github.com/aws/aws-sdk-go-v2/config v1.11.0/go.mod h1:VrQDJGFBM5yZe+IOeenNZ/DWoErdny+k2MHEIpwDsEY= +github.com/aws/aws-sdk-go-v2/credentials v1.6.4 h1:2hvbUoHufns0lDIsaK8FVCMukT1WngtZPavN+W2FkSw= +github.com/aws/aws-sdk-go-v2/credentials v1.6.4/go.mod h1:tTrhvBPHyPde4pdIPSba4Nv7RYr4wP9jxXEDa1bKn/8= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.4.4 h1:9WteVf5jmManG9HlxTFsk1+MT1IZ8S/8rvR+3A3OKng= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.4.4/go.mod h1:MWyvQ5I9fEsoV+Im6IgpILXlAaypjlRqUkyS5GP5pIo= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.14.1 h1:Uhn/kOwwHAL4vI6LdgvV0cfaQbaLyvJbCCyrSZLNBm8= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.14.1/go.mod h1:fEjI/gFP0DXxz5c4tRWyYEQpcNCVvMzjh62t0uKFk8U= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.8.2 h1:KiN5TPOLrEjbGCvdTQR4t0U4T87vVwALZ5Bg3jpMqPY= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.8.2/go.mod h1:dF2F6tXEOgmW5X1ZFO/EPtWrcm7XkW07KNcJUGNtt4s= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.2/go.mod h1:SgKKNBIoDC/E1ZCDhhMW3yalWjwuLjMcpLzsM/QQnWo= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10 h1:vF+Zgd9s+H4vOXd5BMaPWykta2a6Ih0AKLq/X6NYKn4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10/go.mod h1:6BkRjejp/GR4411UGqkX8+wFMbFbqsUIimfK4XjOKR4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.9 h1:cy8ahBJuhtM8GTTSyOkfy6WVPV1IE+SS5/wfXUYuulw= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.9/go.mod h1:CZBXGLaJnEZI6EVNcPd7a6B5IC5cA/GkRWtu9fp3S6Y= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.0.2/go.mod h1:xT4XX6w5Sa3dhg50JrYyy3e4WPYo/+WjY/BXtqXVunU= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10 h1:nYPe006ktcqUji8S2mqXf9c/7NdiKriOwMvWQHgYztw= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10/go.mod h1:6UV4SZkVvmODfXKql4LCbaZUpF7HO2BX38FgBf9ZOLw= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.9 h1:A4SYk07ef04+vxZToz9LWvAXl9LW0NClpPpMsi31cz0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.9/go.mod h1:5jJcHuwDagxN+ErjQ3PU3ocf6Ylc/p9x+BLO/+X4iXw= +github.com/aws/aws-sdk-go-v2/internal/ini v1.3.2 h1:IQup8Q6lorXeiA/rK72PeToWoWK8h7VAPgHNWdSrtgE= +github.com/aws/aws-sdk-go-v2/internal/ini v1.3.2/go.mod h1:VITe/MdW6EMXPb0o0txu/fsonXbMHUU2OC2Qp7ivU4o= +github.com/aws/aws-sdk-go-v2/service/dynamodb v1.10.0/go.mod h1:ELltfl9ri0n4sZ/VjPZBgemNMd9mYIpCAuZhc7NP7l4= +github.com/aws/aws-sdk-go-v2/service/dynamodb v1.26.8 h1:XKO0BswTDeZMLDBd/b5pCEZGttNXrzRUVtFvp2Ak/Vo= +github.com/aws/aws-sdk-go-v2/service/dynamodb v1.26.8/go.mod h1:N5tqZcYMM0N1PN7UQYJNWuGyO886OfnMhf/3MAbqMcI= +github.com/aws/aws-sdk-go-v2/service/dynamodb v1.32.8 h1:yOosUCdI/P+gfBd8uXk6lvZmrp7z2Xs8s1caIDP33lo= +github.com/aws/aws-sdk-go-v2/service/dynamodb v1.32.8/go.mod h1:4sYs0Krug9vn4cfDly4ExdbXJRqqZZBVDJNtBHGxCpQ= +github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.8.1 h1:AQurjazY9KPUxvq4EBN9Q3iWGaDrcqfpfSWtkP0Qy+g= +github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.8.1/go.mod h1:RiesWyLiePOOwyT5ySDupQosvbG+OTMv9pws/EhDu4U= +github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.20.10 h1:aK9uyT3Ua6UOmTMBYEM3sJHlnSO994eNZGagFlfLiOs= +github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.20.10/go.mod h1:S541uoWn3nWvo28EE8DnMbqZ5sZRAipVUPuL11V08Xw= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.5.0/go.mod h1:80NaCIH9YU3rzTTs/J/ECATjXuRqzo/wB6ukO6MZ0XY= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 h1:/b31bi3YVNlkzkBrm9LfpaKoaYZUxIAj4sHfOTmLfqw= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4/go.mod h1:2aGXHFmbInwgP9ZfpmdIfOELL79zhdNYNmReK8qDfdQ= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 h1:Ji0DY1xUsUr3I8cHps0G+XM3WWU16lP6yG8qu1GAZAs= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2/go.mod h1:5CsjAbs3NlGQyZNFACh+zztPDI7fU6eW9QsxjfnuBKg= +github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.3.3/go.mod h1:zOyLMYyg60yyZpOCniAUuibWVqTU4TuLmMa/Wh4P+HA= +github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.8.11 h1:e9AVb17H4x5FTE5KWIP5M1Du+9M86pS+Hw0lBUdN8EY= +github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.8.11/go.mod h1:B90ZQJa36xo0ph9HsoteI1+r8owgQH/U1QNfqZQkj1Q= +github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.9.10 h1:+ijk29Q2FlKCinEzG6GE3IcOyBsmPNUmFq/L82pSyhI= +github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.9.10/go.mod h1:D9WZXFWtJD76gmV2ZciWcY8BJBFdCblqdfF9OmkrwVU= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.5.2 h1:CKdUNKmuilw/KNmO2Q53Av8u+ZyXMC2M9aX8Z+c/gzg= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.5.2/go.mod h1:FgR1tCsn8C6+Hf+N5qkfrE4IXvUL1RgW87sunJ+5J4I= +github.com/aws/aws-sdk-go-v2/service/sso v1.6.2 h1:2IDmvSb86KT44lSg1uU4ONpzgWLOuApRl6Tg54mZ6Dk= +github.com/aws/aws-sdk-go-v2/service/sso v1.6.2/go.mod h1:KnIpszaIdwI33tmc/W/GGXyn22c1USYxA/2KyvoeDY0= +github.com/aws/aws-sdk-go-v2/service/sts v1.11.1 h1:QKR7wy5e650q70PFKMfGF9sTo0rZgUevSSJ4wxmyWXk= +github.com/aws/aws-sdk-go-v2/service/sts v1.11.1/go.mod h1:UV2N5HaPfdbDpkgkz4sRzWCvQswZjdO1FfqCWl0t7RA= +github.com/aws/smithy-go v1.9.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E= +github.com/aws/smithy-go v1.19.0 h1:KWFKQV80DpP3vJrrA9sVAHQ5gc2z8i4EzrLhLlWXcBM= +github.com/aws/smithy-go v1.19.0/go.mod h1:NukqUGpCZIILqqiV0NIjeFh24kd/FAa4beRb6nbIUPE= +github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q= +github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= +github.com/cenkalti/backoff/v4 v4.1.2 h1:6Yo7N8UP2K6LWZnW94DLVSSrbobcWdVzAYOisuDPIFo= +github.com/cenkalti/backoff/v4 v4.1.2/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw= +github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= +github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= @@ -11,8 +71,11 @@ github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfC github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= -golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/put.go b/put.go index 206728b..4d43b71 100644 --- a/put.go +++ b/put.go @@ -3,8 +3,8 @@ package dynamo import ( "context" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) // Put is a request to create or replace an item. @@ -13,7 +13,7 @@ type Put struct { table Table returnType string - item map[string]*dynamodb.AttributeValue + item Item subber condition string @@ -53,14 +53,7 @@ func (p *Put) ConsumedCapacity(cc *ConsumedCapacity) *Put { } // Run executes this put. -func (p *Put) Run() error { - ctx, cancel := defaultContext() - defer cancel() - return p.RunWithContext(ctx) -} - -// Run executes this put. -func (p *Put) RunWithContext(ctx context.Context) error { +func (p *Put) Run(ctx context.Context) error { p.returnType = "NONE" _, err := p.run(ctx) return err @@ -68,15 +61,7 @@ func (p *Put) RunWithContext(ctx context.Context) error { // OldValue executes this put, unmarshaling the previous value into out. // Returns ErrNotFound is there was no previous value. -func (p *Put) OldValue(out interface{}) error { - ctx, cancel := defaultContext() - defer cancel() - return p.OldValueWithContext(ctx, out) -} - -// OldValueWithContext executes this put, unmarshaling the previous value into out. -// Returns ErrNotFound is there was no previous value. -func (p *Put) OldValueWithContext(ctx context.Context, out interface{}) error { +func (p *Put) OldValue(ctx context.Context, out interface{}) error { p.returnType = "ALL_OLD" output, err := p.run(ctx) switch { @@ -95,7 +80,7 @@ func (p *Put) run(ctx context.Context) (output *dynamodb.PutItemOutput, err erro req := p.input() p.table.db.retry(ctx, func() error { - output, err = p.table.db.client.PutItemWithContext(ctx, req) + output, err = p.table.db.client.PutItem(ctx, req) return err }) if p.cc != nil { @@ -108,7 +93,7 @@ func (p *Put) input() *dynamodb.PutItemInput { input := &dynamodb.PutItemInput{ TableName: &p.table.name, Item: p.item, - ReturnValues: &p.returnType, + ReturnValues: types.ReturnValue(p.returnType), ExpressionAttributeNames: p.nameExpr, ExpressionAttributeValues: p.valueExpr, } @@ -116,18 +101,18 @@ func (p *Put) input() *dynamodb.PutItemInput { input.ConditionExpression = &p.condition } if p.cc != nil { - input.ReturnConsumedCapacity = aws.String(dynamodb.ReturnConsumedCapacityIndexes) + input.ReturnConsumedCapacity = types.ReturnConsumedCapacityIndexes } return input } -func (p *Put) writeTxItem() (*dynamodb.TransactWriteItem, error) { +func (p *Put) writeTxItem() (*types.TransactWriteItem, error) { if p.err != nil { return nil, p.err } input := p.input() - item := &dynamodb.TransactWriteItem{ - Put: &dynamodb.Put{ + item := &types.TransactWriteItem{ + Put: &types.Put{ TableName: input.TableName, Item: input.Item, ExpressionAttributeNames: input.ExpressionAttributeNames, diff --git a/put_test.go b/put_test.go index fe2ad63..cbdd400 100644 --- a/put_test.go +++ b/put_test.go @@ -1,11 +1,12 @@ package dynamo import ( + "context" "reflect" "testing" "time" - "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go-v2/aws" ) func TestPut(t *testing.T) { @@ -13,6 +14,7 @@ func TestPut(t *testing.T) { t.Skip(offlineSkipMsg) } table := testDB.Table(testTableWidgets) + ctx := context.TODO() type widget2 struct { widget @@ -34,7 +36,7 @@ func TestPut(t *testing.T) { List: []*string{}, } - err := table.Put(item).Run() + err := table.Put(item).Run(ctx) if err != nil { t.Error("unexpected error:", err) } @@ -53,7 +55,7 @@ func TestPut(t *testing.T) { } var oldValue widget2 var cc ConsumedCapacity - err = table.Put(newItem).ConsumedCapacity(&cc).OldValue(&oldValue) + err = table.Put(newItem).ConsumedCapacity(&cc).OldValue(ctx, &oldValue) if err != nil { t.Error("unexpected error:", err) } @@ -67,7 +69,7 @@ func TestPut(t *testing.T) { } // putting the same item: this should fail - err = table.Put(newItem).If("attribute_not_exists(UserID)").If("attribute_not_exists('Time')").Run() + err = table.Put(newItem).If("attribute_not_exists(UserID)").If("attribute_not_exists('Time')").Run(ctx) if !IsCondCheckFailed(err) { t.Error("expected ConditionalCheckFailedException, not", err) } @@ -78,6 +80,7 @@ func TestPutAndQueryAWSEncoding(t *testing.T) { t.Skip(offlineSkipMsg) } table := testDB.Table(testTableWidgets) + ctx := context.TODO() type awsWidget struct { XUserID int `dynamodbav:"UserID"` @@ -98,13 +101,13 @@ func TestPutAndQueryAWSEncoding(t *testing.T) { XMsg: "hello world", } - err = table.Put(AWSEncoding(item)).Run() + err = table.Put(AWSEncoding(item)).Run(ctx) if err != nil { t.Error(err) } var result awsWidget - err = table.Get("UserID", item.XUserID).Range("Time", Equal, item.XTime).Consistent(true).One(AWSEncoding(&result)) + err = table.Get("UserID", item.XUserID).Range("Time", Equal, item.XTime).Consistent(true).One(ctx, AWSEncoding(&result)) if err != nil { t.Error(err) } @@ -113,7 +116,7 @@ func TestPutAndQueryAWSEncoding(t *testing.T) { } var list []awsWidget - err = table.Get("UserID", item.XUserID).Consistent(true).All(AWSEncoding(&list)) + err = table.Get("UserID", item.XUserID).Consistent(true).All(ctx, AWSEncoding(&list)) if err != nil { t.Error(err) } diff --git a/query.go b/query.go index 19f7b4e..8e49d5c 100644 --- a/query.go +++ b/query.go @@ -4,10 +4,11 @@ import ( "context" "errors" "fmt" + "math" "strings" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) // Query is a request to get one or more items in a table. @@ -16,21 +17,21 @@ import ( // and http://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_GetItem.html type Query struct { table Table - startKey map[string]*dynamodb.AttributeValue + startKey Item index string hashKey string - hashValue *dynamodb.AttributeValue + hashValue types.AttributeValue rangeKey string - rangeValues []*dynamodb.AttributeValue + rangeValues []types.AttributeValue rangeOp Operator projection string filters []string consistent bool - limit int64 - searchLimit int64 + limit int + searchLimit int32 reqLimit int order *Order @@ -71,7 +72,7 @@ const ( Descending = false // ScanIndexForward = false ) -var selectCount = aws.String("COUNT") +var selectCount types.Select = "COUNT" // Get creates a new request to get an item. // Name is the name of the hash key (a.k.a. partition key). @@ -170,7 +171,7 @@ func (q *Query) Consistent(on bool) *Query { } // Limit specifies the maximum amount of results to return. -func (q *Query) Limit(limit int64) *Query { +func (q *Query) Limit(limit int) *Query { q.limit = limit return q } @@ -179,8 +180,9 @@ func (q *Query) Limit(limit int64) *Query { // If a filter is not specified, the number of results will be limited. // If a filter is specified, the number of results to consider for filtering will be limited. // SearchLimit > 0 implies RequestLimit(1). -func (q *Query) SearchLimit(limit int64) *Query { - q.searchLimit = limit +// Note: limit will be capped to MaxInt32 as that is the maximum number the DynamoDB API will accept. +func (q *Query) SearchLimit(limit int) *Query { + q.searchLimit = int32(min(limit, math.MaxInt32)) return q } @@ -206,13 +208,7 @@ func (q *Query) ConsumedCapacity(cc *ConsumedCapacity) *Query { // One executes this query and retrieves a single result, // unmarshaling the result to out. -func (q *Query) One(out interface{}) error { - ctx, cancel := defaultContext() - defer cancel() - return q.OneWithContext(ctx, out) -} - -func (q *Query) OneWithContext(ctx context.Context, out interface{}) error { +func (q *Query) One(ctx context.Context, out interface{}) error { if q.err != nil { return q.err } @@ -224,7 +220,7 @@ func (q *Query) OneWithContext(ctx context.Context, out interface{}) error { var res *dynamodb.GetItemOutput err := q.table.db.retry(ctx, func() error { var err error - res, err = q.table.db.client.GetItemWithContext(ctx, req) + res, err = q.table.db.client.GetItem(ctx, req) if err != nil { return err } @@ -249,7 +245,7 @@ func (q *Query) OneWithContext(ctx context.Context, out interface{}) error { var res *dynamodb.QueryOutput err := q.table.db.retry(ctx, func() error { var err error - res, err = q.table.db.client.QueryWithContext(ctx, req) + res, err = q.table.db.client.Query(ctx, req) if err != nil { return err } @@ -276,18 +272,13 @@ func (q *Query) OneWithContext(ctx context.Context, out interface{}) error { } // Count executes this request, returning the number of results. -func (q *Query) Count() (int64, error) { - ctx, cancel := defaultContext() - defer cancel() - return q.CountWithContext(ctx) -} - -func (q *Query) CountWithContext(ctx context.Context) (int64, error) { +func (q *Query) Count(ctx context.Context) (int, error) { if q.err != nil { return 0, q.err } - var count, scanned int64 + var count int + var scanned int32 var reqs int var res *dynamodb.QueryOutput for { @@ -296,19 +287,14 @@ func (q *Query) CountWithContext(ctx context.Context) (int64, error) { err := q.table.db.retry(ctx, func() error { var err error - res, err = q.table.db.client.QueryWithContext(ctx, input) + res, err = q.table.db.client.Query(ctx, input) if err != nil { return err } reqs++ - if res.Count == nil { - return errors.New("malformed DynamoDB response: count is nil") - } - count += *res.Count - if res.ScannedCount != nil { - scanned += *res.ScannedCount - } + count += int(res.Count) + scanned += res.ScannedCount return nil }) @@ -338,16 +324,16 @@ type queryIter struct { output *dynamodb.QueryOutput err error idx int - n int64 + n int reqs int // last item evaluated - last map[string]*dynamodb.AttributeValue + last Item // cache of primary keys, used for generating LEKs keys map[string]struct{} // example LastEvaluatedKey and ExclusiveStartKey, used to lazily evaluate the primary keys if possible - exLEK map[string]*dynamodb.AttributeValue - exESK map[string]*dynamodb.AttributeValue + exLEK Item + exESK Item keyErr error unmarshal unmarshalFunc @@ -355,13 +341,7 @@ type queryIter struct { // Next tries to unmarshal the next result into out. // Returns false when it is complete or if it runs into an error. -func (itr *queryIter) Next(out interface{}) bool { - ctx, cancel := defaultContext() - defer cancel() - return itr.NextWithContext(ctx, out) -} - -func (itr *queryIter) NextWithContext(ctx context.Context, out interface{}) bool { +func (itr *queryIter) Next(ctx context.Context, out interface{}) bool { // stop if we have an error if ctx.Err() != nil { itr.err = ctx.Err() @@ -411,7 +391,7 @@ func (itr *queryIter) NextWithContext(ctx context.Context, out interface{}) bool itr.err = itr.query.table.db.retry(ctx, func() error { var err error - itr.output, err = itr.query.table.db.client.QueryWithContext(ctx, itr.input) + itr.output, err = itr.query.table.db.client.Query(ctx, itr.input) return err }) @@ -432,7 +412,7 @@ func (itr *queryIter) NextWithContext(ctx context.Context, out interface{}) bool } if itr.output.LastEvaluatedKey != nil { // we need to retry until we get some data - return itr.NextWithContext(ctx, out) + return itr.Next(ctx, out) } // we're done return false @@ -452,76 +432,58 @@ func (itr *queryIter) Err() error { return itr.err } -func (itr *queryIter) LastEvaluatedKey() PagingKey { +func (itr *queryIter) LastEvaluatedKey(ctx context.Context) (PagingKey, error) { if itr.output != nil { // if we've hit the end of our results, we can use the real LEK if itr.idx == len(itr.output.Items) { - return itr.output.LastEvaluatedKey + return itr.output.LastEvaluatedKey, nil } // figure out the primary keys if needed if itr.keys == nil && itr.keyErr == nil { - ctx, _ := defaultContext() // TODO(v2): take context instead of using the default itr.keys, itr.keyErr = itr.query.table.primaryKeys(ctx, itr.exLEK, itr.exESK, itr.query.index) } if itr.keyErr != nil { // primaryKeys can fail if the credentials lack DescribeTable permissions // in order to preserve backwards compatibility, we fall back to the old behavior and warn // see: https://github.com/guregu/dynamo/pull/187#issuecomment-1045183901 - // TODO(v2): rejigger this API. - itr.query.table.db.log("dynamo: Warning:", itr.keyErr, "Returning a later LastEvaluatedKey.") - return itr.output.LastEvaluatedKey + return itr.output.LastEvaluatedKey, fmt.Errorf("dynamo: failed to determine LastEvaluatedKey in query: %w", itr.keyErr) } // we can't use the real LEK, so we need to infer the LEK from the last item we saw lek, err := lekify(itr.last, itr.keys) - // unfortunately, this API can't return an error so a warning is the best we can do... - // this matches old behavior before the LEK was automatically generated - // TODO(v2): fix this. if err != nil { - itr.query.table.db.log("dynamo: Warning:", err, "Returning a later LastEvaluatedKey.") - return itr.output.LastEvaluatedKey + return itr.output.LastEvaluatedKey, fmt.Errorf("dynamo: failed to infer LastEvaluatedKey in query: %w", err) } - return lek + return lek, nil } - return nil + return nil, nil } // All executes this request and unmarshals all results to out, which must be a pointer to a slice. -func (q *Query) All(out interface{}) error { - ctx, cancel := defaultContext() - defer cancel() - return q.AllWithContext(ctx, out) -} - -func (q *Query) AllWithContext(ctx context.Context, out interface{}) error { +func (q *Query) All(ctx context.Context, out interface{}) error { iter := &queryIter{ query: q, unmarshal: unmarshalAppendTo(out), err: q.err, } - for iter.NextWithContext(ctx, out) { + for iter.Next(ctx, out) { } return iter.Err() } // AllWithLastEvaluatedKey executes this request and unmarshals all results to out, which must be a pointer to a slice. // This returns a PagingKey you can use with StartFrom to split up results. -func (q *Query) AllWithLastEvaluatedKey(out interface{}) (PagingKey, error) { - ctx, cancel := defaultContext() - defer cancel() - return q.AllWithLastEvaluatedKeyContext(ctx, out) -} - -func (q *Query) AllWithLastEvaluatedKeyContext(ctx context.Context, out interface{}) (PagingKey, error) { +func (q *Query) AllWithLastEvaluatedKey(ctx context.Context, out interface{}) (PagingKey, error) { iter := &queryIter{ query: q, unmarshal: unmarshalAppendTo(out), err: q.err, } - for iter.NextWithContext(ctx, out) { + for iter.Next(ctx, out) { } - return iter.LastEvaluatedKey(), iter.Err() + lek, err := iter.LastEvaluatedKey(ctx) + return lek, errors.Join(iter.Err(), err) } // Iter returns a results iterator for this request. @@ -531,7 +493,6 @@ func (q *Query) Iter() PagingIter { unmarshal: unmarshalItem, err: q.err, } - return iter } @@ -563,7 +524,8 @@ func (q *Query) queryInput() *dynamodb.QueryInput { } if q.limit > 0 { if len(q.filters) == 0 { - req.Limit = &q.limit + limit := int32(min(math.MaxInt32, q.limit)) + req.Limit = &limit } } if q.searchLimit > 0 { @@ -583,22 +545,22 @@ func (q *Query) queryInput() *dynamodb.QueryInput { req.ScanIndexForward = (*bool)(q.order) } if q.cc != nil { - req.ReturnConsumedCapacity = aws.String(dynamodb.ReturnConsumedCapacityIndexes) + req.ReturnConsumedCapacity = types.ReturnConsumedCapacityIndexes } return req } -func (q *Query) keyConditions() map[string]*dynamodb.Condition { - conds := map[string]*dynamodb.Condition{ +func (q *Query) keyConditions() map[string]types.Condition { + conds := map[string]types.Condition{ q.hashKey: { - AttributeValueList: []*dynamodb.AttributeValue{q.hashValue}, - ComparisonOperator: aws.String(string(Equal)), + AttributeValueList: []types.AttributeValue{q.hashValue}, + ComparisonOperator: types.ComparisonOperatorEq, }, } if q.rangeKey != "" && q.rangeOp != "" { - conds[q.rangeKey] = &dynamodb.Condition{ + conds[q.rangeKey] = types.Condition{ AttributeValueList: q.rangeValues, - ComparisonOperator: aws.String(string(q.rangeOp)), + ComparisonOperator: types.ComparisonOperator(q.rangeOp), } } return conds @@ -617,18 +579,18 @@ func (q *Query) getItemInput() *dynamodb.GetItemInput { req.ProjectionExpression = &q.projection } if q.cc != nil { - req.ReturnConsumedCapacity = aws.String(dynamodb.ReturnConsumedCapacityIndexes) + req.ReturnConsumedCapacity = types.ReturnConsumedCapacityIndexes } return req } -func (q *Query) getTxItem() (*dynamodb.TransactGetItem, error) { +func (q *Query) getTxItem() (types.TransactGetItem, error) { if !q.canGetItem() { - return nil, errors.New("dynamo: transaction Query is too complex; no indexes or filters are allowed") + return types.TransactGetItem{}, errors.New("dynamo: transaction Query is too complex; no indexes or filters are allowed") } input := q.getItemInput() - return &dynamodb.TransactGetItem{ - Get: &dynamodb.Get{ + return types.TransactGetItem{ + Get: &types.Get{ TableName: input.TableName, Key: input.Key, ExpressionAttributeNames: input.ExpressionAttributeNames, @@ -637,8 +599,8 @@ func (q *Query) getTxItem() (*dynamodb.TransactGetItem, error) { }, nil } -func (q *Query) keys() map[string]*dynamodb.AttributeValue { - keys := map[string]*dynamodb.AttributeValue{ +func (q *Query) keys() Item { + keys := Item{ q.hashKey: q.hashValue, } if q.rangeKey != "" && len(q.rangeValues) > 0 { @@ -647,9 +609,9 @@ func (q *Query) keys() map[string]*dynamodb.AttributeValue { return keys } -func (q *Query) keysAndAttribs() *dynamodb.KeysAndAttributes { - kas := &dynamodb.KeysAndAttributes{ - Keys: []map[string]*dynamodb.AttributeValue{q.keys()}, +func (q *Query) keysAndAttribs() types.KeysAndAttributes { + kas := types.KeysAndAttributes{ + Keys: []Item{q.keys()}, ExpressionAttributeNames: q.nameExpr, ConsistentRead: &q.consistent, } diff --git a/query_test.go b/query_test.go index f1d7992..d4ac428 100644 --- a/query_test.go +++ b/query_test.go @@ -1,18 +1,20 @@ package dynamo import ( + "context" "reflect" "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) func TestGetAllCount(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } + ctx := context.TODO() table := testDB.Table(testTableWidgets) // first, add an item to make sure there is at least one @@ -26,7 +28,7 @@ func TestGetAllCount(t *testing.T) { }, StrPtr: new(string), } - err := table.Put(item).Run() + err := table.Put(item).Run(ctx) if err != nil { t.Error("unexpected error:", err) } @@ -37,8 +39,8 @@ func TestGetAllCount(t *testing.T) { "#meta": "Meta", "#foo": "foo", }), - AttributeValues: map[string]*dynamodb.AttributeValue{ - ":bar": {S: aws.String("bar")}, + AttributeValues: Item{ + ":bar": &types.AttributeValueMemberS{Value: "bar"}, }, } @@ -51,7 +53,7 @@ func TestGetAllCount(t *testing.T) { Filter("StrPtr = ?", ""). Filter("?", lit). ConsumedCapacity(&cc1). - All(&result) + All(ctx, &result) if err != nil { t.Error("unexpected error:", err) } @@ -62,7 +64,7 @@ func TestGetAllCount(t *testing.T) { Filter("StrPtr = ?", ""). Filter("$", lit). // both $ and ? are OK for literals ConsumedCapacity(&cc2). - Count() + Count(ctx) if err != nil { t.Error("unexpected error:", err) } @@ -91,7 +93,7 @@ func TestGetAllCount(t *testing.T) { // query specifically against the inserted item (using GetItem) var one widget - err = table.Get("UserID", 42).Range("Time", Equal, item.Time).Consistent(true).One(&one) + err = table.Get("UserID", 42).Range("Time", Equal, item.Time).Consistent(true).One(ctx, &one) if err != nil { t.Error("unexpected error:", err) } @@ -101,7 +103,7 @@ func TestGetAllCount(t *testing.T) { // query specifically against the inserted item (using Query) one = widget{} - err = table.Get("UserID", 42).Range("Time", Equal, item.Time).Filter("Msg = ?", item.Msg).Filter("StrPtr = ?", "").Consistent(true).One(&one) + err = table.Get("UserID", 42).Range("Time", Equal, item.Time).Filter("Msg = ?", item.Msg).Filter("StrPtr = ?", "").Consistent(true).One(ctx, &one) if err != nil { t.Error("unexpected error:", err) } @@ -115,7 +117,7 @@ func TestGetAllCount(t *testing.T) { UserID: item.UserID, Time: item.Time, } - err = table.Get("UserID", 42).Range("Time", Equal, item.Time).Project("UserID", "Time").Consistent(true).One(&one) + err = table.Get("UserID", 42).Range("Time", Equal, item.Time).Project("UserID", "Time").Consistent(true).One(ctx, &one) if err != nil { t.Error("unexpected error:", err) } @@ -133,7 +135,7 @@ func TestGetAllCount(t *testing.T) { "animal.cow": "moo", }, } - err = table.Get("UserID", 42).Range("Time", Equal, item.Time).ProjectExpr("UserID, $, Meta.foo, Meta.$", "Time", "animal.cow").Consistent(true).One(&one) + err = table.Get("UserID", 42).Range("Time", Equal, item.Time).ProjectExpr("UserID, $, Meta.foo, Meta.$", "Time", "animal.cow").Consistent(true).One(ctx, &one) if err != nil { t.Error("unexpected error:", err) } @@ -146,6 +148,7 @@ func TestQueryPaging(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } + ctx := context.TODO() table := testDB.Table(testTableWidgets) widgets := []interface{}{ @@ -166,7 +169,7 @@ func TestQueryPaging(t *testing.T) { }, } - if _, err := table.Batch().Write().Put(widgets...).Run(); err != nil { + if _, err := table.Batch().Write().Put(widgets...).Run(ctx); err != nil { t.Error("couldn't write paging prep data", err) return } @@ -174,18 +177,22 @@ func TestQueryPaging(t *testing.T) { itr := table.Get("UserID", 1969).SearchLimit(1).Iter() for i := 0; i < len(widgets); i++ { var w widget - itr.Next(&w) + itr.Next(ctx, &w) if !reflect.DeepEqual(w, widgets[i]) { t.Error("bad result:", w, "≠", widgets[i]) } if itr.Err() != nil { t.Error("unexpected error", itr.Err()) } - more := itr.Next(&w) + more := itr.Next(ctx, &w) if more { t.Error("unexpected more", more) } - itr = table.Get("UserID", 1969).StartFrom(itr.LastEvaluatedKey()).SearchLimit(1).Iter() + lek, err := itr.LastEvaluatedKey(context.Background()) + if err != nil { + t.Error("LEK error", err) + } + itr = table.Get("UserID", 1969).StartFrom(lek).SearchLimit(1).Iter() } } @@ -193,6 +200,7 @@ func TestQueryMagicLEK(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } + ctx := context.Background() table := testDB.Table(testTableWidgets) widgets := []interface{}{ @@ -214,7 +222,7 @@ func TestQueryMagicLEK(t *testing.T) { } t.Run("prepare data", func(t *testing.T) { - if _, err := table.Batch().Write().Put(widgets...).Run(); err != nil { + if _, err := table.Batch().Write().Put(widgets...).Run(ctx); err != nil { t.Fatal(err) } }) @@ -223,23 +231,27 @@ func TestQueryMagicLEK(t *testing.T) { itr := table.Get("UserID", 1970).Filter("attribute_exists('Msg')").Limit(1).Iter() for i := 0; i < len(widgets); i++ { var w widget - itr.Next(&w) + itr.Next(ctx, &w) if !reflect.DeepEqual(w, widgets[i]) { t.Error("bad result:", w, "≠", widgets[i]) } if itr.Err() != nil { t.Error("unexpected error", itr.Err()) } - more := itr.Next(&w) + more := itr.Next(ctx, &w) if more { t.Error("unexpected more", more) } - itr = table.Get("UserID", 1970).StartFrom(itr.LastEvaluatedKey()).Limit(1).Iter() + lek, err := itr.LastEvaluatedKey(context.Background()) + if err != nil { + t.Error("LEK error", lek) + } + itr = table.Get("UserID", 1970).StartFrom(lek).Limit(1).Iter() } }) t.Run("table cache", func(t *testing.T) { - pk, err := table.primaryKeys(aws.BackgroundContext(), nil, nil, "") + pk, err := table.primaryKeys(context.Background(), nil, nil, "") if err != nil { t.Fatal(err) } @@ -256,18 +268,22 @@ func TestQueryMagicLEK(t *testing.T) { itr := table.Get("Msg", "TestQueryMagicLEK").Index("Msg-Time-index").Filter("UserID = ?", 1970).Limit(1).Iter() for i := 0; i < len(widgets); i++ { var w widget - itr.Next(&w) + itr.Next(ctx, &w) if !reflect.DeepEqual(w, widgets[i]) { t.Error("bad result:", w, "≠", widgets[i]) } if itr.Err() != nil { t.Error("unexpected error", itr.Err()) } - more := itr.Next(&w) + more := itr.Next(ctx, &w) if more { t.Error("unexpected more", more) } - itr = table.Get("Msg", "TestQueryMagicLEK").Index("Msg-Time-index").Filter("UserID = ?", 1970).StartFrom(itr.LastEvaluatedKey()).Limit(1).Iter() + lek, err := itr.LastEvaluatedKey(context.Background()) + if err != nil { + t.Error("LEK error", err) + } + itr = table.Get("Msg", "TestQueryMagicLEK").Index("Msg-Time-index").Filter("UserID = ?", 1970).StartFrom(lek).Limit(1).Iter() } }) } @@ -277,10 +293,11 @@ func TestQueryBadKeys(t *testing.T) { t.Skip(offlineSkipMsg) } table := testDB.Table(testTableWidgets) + ctx := context.Background() t.Run("hash key", func(t *testing.T) { var v interface{} - err := table.Get("UserID", "").Range("Time", Equal, "123").One(&v) + err := table.Get("UserID", "").Range("Time", Equal, "123").One(ctx, &v) if err == nil { t.Error("want error, got", err) } @@ -288,7 +305,7 @@ func TestQueryBadKeys(t *testing.T) { t.Run("range key", func(t *testing.T) { var v interface{} - err := table.Get("UserID", 123).Range("Time", Equal, "").One(&v) + err := table.Get("UserID", 123).Range("Time", Equal, "").One(ctx, &v) if err == nil { t.Error("want error, got", err) } diff --git a/reflect.go b/reflect.go index 32b5874..9da805c 100644 --- a/reflect.go +++ b/reflect.go @@ -6,14 +6,25 @@ import ( "reflect" "time" - "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) // special attribute encoders var ( - // *dynamodb.AttributeValue - rtypeAttr = reflect.TypeOf((*dynamodb.AttributeValue)(nil)) + // types.AttributeValue + rtypeAttr = reflect.TypeOf((*types.AttributeValue)(nil)).Elem() + rtypeAttrB = reflect.TypeOf((*types.AttributeValueMemberB)(nil)) + rtypeAttrBS = reflect.TypeOf((*types.AttributeValueMemberBS)(nil)) + rtypeAttrBOOL = reflect.TypeOf((*types.AttributeValueMemberBOOL)(nil)) + rtypeAttrN = reflect.TypeOf((*types.AttributeValueMemberN)(nil)) + rtypeAttrS = reflect.TypeOf((*types.AttributeValueMemberS)(nil)) + rtypeAttrL = reflect.TypeOf((*types.AttributeValueMemberL)(nil)) + rtypeAttrNS = reflect.TypeOf((*types.AttributeValueMemberNS)(nil)) + rtypeAttrSS = reflect.TypeOf((*types.AttributeValueMemberSS)(nil)) + rtypeAttrM = reflect.TypeOf((*types.AttributeValueMemberM)(nil)) + rtypeAttrNULL = reflect.TypeOf((*types.AttributeValueMemberNULL)(nil)) + // *time.Time rtypeTimePtr = reflect.TypeOf((*time.Time)(nil)) // time.Time @@ -22,14 +33,14 @@ var ( // Unmarshaler rtypeUnmarshaler = reflect.TypeOf((*Unmarshaler)(nil)).Elem() // dynamodbattribute.Unmarshaler - rtypeAWSUnmarshaler = reflect.TypeOf((*dynamodbattribute.Unmarshaler)(nil)).Elem() + rtypeAWSUnmarshaler = reflect.TypeOf((*attributevalue.Unmarshaler)(nil)).Elem() // encoding.TextUnmarshaler rtypeTextUnmarshaler = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() // Marshaler rtypeMarshaler = reflect.TypeOf((*Marshaler)(nil)).Elem() - // dynamodbattribute.Marshaler - rtypeAWSMarshaler = reflect.TypeOf((*dynamodbattribute.Marshaler)(nil)).Elem() + // attributevalue.Marshaler + rtypeAWSMarshaler = reflect.TypeOf((*attributevalue.Marshaler)(nil)).Elem() // encoding.TextMarshaler rtypeTextMarshaler = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() @@ -41,7 +52,7 @@ var ( // special item encoders var ( - rtypeItemPtr = reflect.TypeOf((*map[string]*dynamodb.AttributeValue)(nil)) + rtypeItemPtr = reflect.TypeOf((*map[string]types.AttributeValue)(nil)) rtypeItem = rtypeItemPtr.Elem() rtypeItemUnmarshaler = reflect.TypeOf((*ItemUnmarshaler)(nil)).Elem() rtypeItemMarshaler = reflect.TypeOf((*ItemMarshaler)(nil)).Elem() @@ -118,7 +129,7 @@ func dig(rv reflect.Value, index []int) reflect.Value { return rv } -func visitFields(item map[string]*dynamodb.AttributeValue, rv reflect.Value, seen map[string]struct{}, fn func(av *dynamodb.AttributeValue, flags encodeFlags, v reflect.Value) error) error { +func visitFields(item map[string]types.AttributeValue, rv reflect.Value, seen map[string]struct{}, fn func(av types.AttributeValue, flags encodeFlags, v reflect.Value) error) error { for rv.Kind() == reflect.Pointer { if rv.IsNil() { if !rv.CanSet() { @@ -203,7 +214,7 @@ type structInfo struct { queue []encodeKey } -func (info *structInfo) encode(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { +func (info *structInfo) encode(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) { item := make(Item, len(info.fields)) for _, field := range info.fields { fv := dig(rv, field.index) @@ -230,7 +241,7 @@ func (info *structInfo) encode(rv reflect.Value, flags encodeFlags) (*dynamodb.A } item[field.name] = av } - return &dynamodb.AttributeValue{M: item}, nil + return &types.AttributeValueMemberM{Value: item}, nil } func (info *structInfo) isZero(rv reflect.Value) bool { diff --git a/retry.go b/retry.go index adf0f6b..3be6e8a 100644 --- a/retry.go +++ b/retry.go @@ -3,92 +3,41 @@ package dynamo import ( "context" "errors" - "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/cenkalti/backoff/v4" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/retry" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) -// RetryTimeout defines the maximum amount of time that requests will -// attempt to automatically retry for. In other words, this is the maximum -// amount of time that dynamo operations will block. -// RetryTimeout is only considered by methods that do not take a context. -// Higher values are better when using tables with lower throughput. -var RetryTimeout = 1 * time.Minute +// TODO: delete this -func defaultContext() (context.Context, context.CancelFunc) { - if RetryTimeout == 0 { - return aws.BackgroundContext(), (func() {}) - } - return context.WithDeadline(aws.BackgroundContext(), time.Now().Add(RetryTimeout)) +func (db *DB) retry(_ context.Context, f func() error) error { + return f() } -func (db *DB) retry(ctx context.Context, f func() error) error { - // if a custom retryer has been set, the SDK will retry for us - if db.retryer != nil { - return f() - } - - var err error - var next time.Duration - b := backoff.WithContext(backoff.NewExponentialBackOff(), ctx) - for i := 0; db.retryMax < 0 || i <= db.retryMax; i++ { - if err = f(); err == nil { - return nil - } - - if !canRetry(err) { - return err - } - - if next = b.NextBackOff(); next == backoff.Stop { - return err - } - - if err := aws.SleepWithContext(ctx, next); err != nil { - return err - } - } - return err +// RetryTxConflicts is an option for [github.com/aws/aws-sdk-go-v2/aws/retry.NewStandard] +// that adds retrying behavior for TransactionConflict within TransactionCanceledException errors. +// See also: [github.com/aws/aws-sdk-go-v2/config.WithRetryer]. +func RetryTxConflicts(opts *retry.StandardOptions) { + opts.Retryables = append(opts.Retryables, retry.IsErrorRetryableFunc(shouldRetryTx)) } -// errRetry is a sentinel error to retry, should never be returned to user -var errRetry = errors.New("dynamo: retry") - -func canRetry(err error) bool { - if errors.Is(err, errRetry) { - return true - } - - if txe, ok := err.(*dynamodb.TransactionCanceledException); ok && txe.StatusCode() == 400 { - retry := false +func shouldRetryTx(err error) aws.Ternary { + var txe *types.TransactionCanceledException + if errors.As(err, &txe) { + retry := aws.FalseTernary for _, reason := range txe.CancellationReasons { if reason.Code == nil { continue } switch *reason.Code { case "ValidationError", "ConditionalCheckFailed", "ItemCollectionSizeLimitExceeded": - return false + return aws.FalseTernary case "ThrottlingError", "ProvisionedThroughputExceeded", "TransactionConflict": - retry = true + retry = aws.TrueTernary } } return retry } - - if ae, ok := err.(awserr.RequestFailure); ok { - switch ae.StatusCode() { - case 500, 503: - return true - case 400: - switch ae.Code() { - case "ProvisionedThroughputExceededException", - "ThrottlingException": - return true - } - } - } - return false + return aws.UnknownTernary } diff --git a/retry_test.go b/retry_test.go index 3075ebb..48e90f9 100644 --- a/retry_test.go +++ b/retry_test.go @@ -2,77 +2,29 @@ package dynamo import ( "context" - "fmt" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/client" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/retry" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) -func TestRetryMax(t *testing.T) { - if testing.Short() { - t.SkipNow() - } - - test := func(max int) (string, func(t *testing.T)) { - name := fmt.Sprintf("max(%d)", max) - return name, func(t *testing.T) { - t.Parallel() - t.Helper() - sesh, err := session.NewSession(&aws.Config{ - MaxRetries: aws.Int(max), - Credentials: dummyCreds, - }) - if err != nil { - t.Fatal(err) - } - db := New(sesh) - - var runs int - err = db.retry(context.Background(), func() error { - runs++ - return awserr.NewRequestFailure( - awserr.New(dynamodb.ErrCodeProvisionedThroughputExceededException, "dummy error", nil), - 400, - fmt.Sprintf("try-%d", runs), - ) - }) - if err == nil { - t.Fatal("expected error, got nil") - } - if want := max + 1; runs != want { - t.Error("wrong number of runs. want:", want, "got:", runs) - } - } - } - t.Run(test(0)) - t.Run(test(1)) - t.Run(test(3)) -} - func TestRetryCustom(t *testing.T) { t.Parallel() - sesh, err := session.NewSession(&aws.Config{ - Retryer: client.NoOpRetryer{}, - MaxRetries: aws.Int(10), // should be ignored (superseded by Retryer) + retryer := func() aws.Retryer { + return retry.NewStandard(func(so *retry.StandardOptions) { + so.MaxAttempts = 1 + }) + } + db := New(aws.Config{ + Retryer: retryer, Credentials: dummyCreds, }) - if err != nil { - t.Fatal(err) - } - db := New(sesh) var runs int - err = db.retry(context.Background(), func() error { + err := db.retry(context.Background(), func() error { runs++ - return awserr.NewRequestFailure( - awserr.New(dynamodb.ErrCodeProvisionedThroughputExceededException, "dummy error", nil), - 400, - fmt.Sprintf("try-%d", runs), - ) + return &types.ProvisionedThroughputExceededException{} }) if err == nil { t.Fatal("expected error, got nil") diff --git a/scan.go b/scan.go index 2facf32..f649a7f 100644 --- a/scan.go +++ b/scan.go @@ -3,11 +3,13 @@ package dynamo import ( "context" "errors" + "fmt" + "math" "strings" "sync" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "golang.org/x/sync/errgroup" ) @@ -15,18 +17,18 @@ import ( // See: http://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_Scan.html type Scan struct { table Table - startKey map[string]*dynamodb.AttributeValue + startKey Item index string projection string filters []string consistent bool - limit int64 - searchLimit int64 + limit int + searchLimit int32 reqLimit int - segment int64 - totalSegments int64 + segment int32 + totalSegments int32 subber @@ -58,16 +60,20 @@ func (s *Scan) Index(name string) *Scan { // Segment specifies the Segment and Total Segments to operate on in a manual parallel scan. // This is useful if you want to control the parallel scans by yourself instead of using ParallelIter. // Ignored by ParallelIter and friends. -func (s *Scan) Segment(segment int64, totalSegments int64) *Scan { - s.segment = segment - s.totalSegments = totalSegments +// totalSegments must be less than MaxInt32 due to API limits. +func (s *Scan) Segment(segment int, totalSegments int) *Scan { + s.segment = int32(segment) + s.totalSegments = int32(totalSegments) + if totalSegments > math.MaxInt32 { + s.setError(fmt.Errorf("dynamo: total segments in Scan must be less than or equal to %d (got %d)", math.MaxInt32, totalSegments)) + } return s } -func (s *Scan) newSegments(segments int64, leks []PagingKey) []*scanIter { +func (s *Scan) newSegments(segments int, leks []PagingKey) []*scanIter { iters := make([]*scanIter, segments) - lekLen := int64(len(leks)) - for i := int64(0); i < segments; i++ { + lekLen := len(leks) + for i := int(0); i < segments; i++ { seg := *s var cc *ConsumedCapacity if s.cc != nil { @@ -117,7 +123,7 @@ func (s *Scan) Consistent(on bool) *Scan { } // Limit specifies the maximum amount of results to return. -func (s *Scan) Limit(limit int64) *Scan { +func (s *Scan) Limit(limit int) *Scan { s.limit = limit return s } @@ -126,8 +132,8 @@ func (s *Scan) Limit(limit int64) *Scan { // Use this along with StartFrom and Iter's LastEvaluatedKey to split up results. // Note that DynamoDB limits result sets to 1MB. // SearchLimit > 0 implies RequestLimit(1). -func (s *Scan) SearchLimit(limit int64) *Scan { - s.searchLimit = limit +func (s *Scan) SearchLimit(limit int) *Scan { + s.searchLimit = int32(min(limit, math.MaxInt32)) return s } @@ -155,7 +161,7 @@ func (s *Scan) Iter() PagingIter { // IterParallel returns a results iterator for this request, running the given number of segments in parallel. // Canceling the context given here will cancel the processing of all segments. -func (s *Scan) IterParallel(ctx context.Context, segments int64) ParallelIter { +func (s *Scan) IterParallel(ctx context.Context, segments int) ParallelIter { iters := s.newSegments(segments, nil) ps := newParallelScan(iters, s.cc, false, unmarshalItem) go ps.run(ctx) @@ -165,109 +171,89 @@ func (s *Scan) IterParallel(ctx context.Context, segments int64) ParallelIter { // IterParallelFrom returns a results iterator continued from a previous ParallelIter's LastEvaluatedKeys. // Canceling the context given here will cancel the processing of all segments. func (s *Scan) IterParallelStartFrom(ctx context.Context, keys []PagingKey) ParallelIter { - iters := s.newSegments(int64(len(keys)), keys) + iters := s.newSegments(len(keys), keys) ps := newParallelScan(iters, s.cc, false, unmarshalItem) go ps.run(ctx) return ps } // All executes this request and unmarshals all results to out, which must be a pointer to a slice. -func (s *Scan) All(out interface{}) error { - ctx, cancel := defaultContext() - defer cancel() - return s.AllWithContext(ctx, out) -} - -// AllWithContext executes this request and unmarshals all results to out, which must be a pointer to a slice. -func (s *Scan) AllWithContext(ctx context.Context, out interface{}) error { +func (s *Scan) All(ctx context.Context, out interface{}) error { itr := &scanIter{ scan: s, unmarshal: unmarshalAppendTo(out), err: s.err, } - for itr.NextWithContext(ctx, out) { + for itr.Next(ctx, out) { } return itr.Err() } // AllWithLastEvaluatedKey executes this request and unmarshals all results to out, which must be a pointer to a slice. // It returns a key you can use with StartWith to continue this query. -func (s *Scan) AllWithLastEvaluatedKey(out interface{}) (PagingKey, error) { - ctx, cancel := defaultContext() - defer cancel() - return s.AllWithLastEvaluatedKeyContext(ctx, out) -} - -// AllWithLastEvaluatedKeyContext executes this request and unmarshals all results to out, which must be a pointer to a slice. -// It returns a key you can use with StartWith to continue this query. -func (s *Scan) AllWithLastEvaluatedKeyContext(ctx context.Context, out interface{}) (PagingKey, error) { +func (s *Scan) AllWithLastEvaluatedKey(ctx context.Context, out interface{}) (PagingKey, error) { itr := &scanIter{ scan: s, unmarshal: unmarshalAppendTo(out), err: s.err, } - for itr.NextWithContext(ctx, out) { + for itr.Next(ctx, out) { } - return itr.LastEvaluatedKey(), itr.Err() + lek, err := itr.LastEvaluatedKey(ctx) + return lek, errors.Join(itr.Err(), err) } // AllParallel executes this request by running the given number of segments in parallel, then unmarshaling all results to out, which must be a pointer to a slice. -func (s *Scan) AllParallel(ctx context.Context, segments int64, out interface{}) error { +func (s *Scan) AllParallel(ctx context.Context, segments int, out interface{}) error { iters := s.newSegments(segments, nil) ps := newParallelScan(iters, s.cc, true, unmarshalAppendTo(out)) go ps.run(ctx) - for ps.NextWithContext(ctx, out) { + for ps.Next(ctx, out) { } return ps.Err() } // AllParallelWithLastEvaluatedKeys executes this request by running the given number of segments in parallel, then unmarshaling all results to out, which must be a pointer to a slice. // Returns a slice of LastEvalutedKeys that can be used to continue the query later. -func (s *Scan) AllParallelWithLastEvaluatedKeys(ctx context.Context, segments int64, out interface{}) ([]PagingKey, error) { +func (s *Scan) AllParallelWithLastEvaluatedKeys(ctx context.Context, segments int, out interface{}) ([]PagingKey, error) { iters := s.newSegments(segments, nil) ps := newParallelScan(iters, s.cc, false, unmarshalAppendTo(out)) go ps.run(ctx) - for ps.NextWithContext(ctx, out) { + for ps.Next(ctx, out) { } - return ps.LastEvaluatedKeys(), ps.Err() + leks, err := ps.LastEvaluatedKeys(ctx) + return leks, errors.Join(ps.Err(), err) } // AllParallelStartFrom executes this request by continuing parallel scans from the given LastEvaluatedKeys, then unmarshaling all results to out, which must be a pointer to a slice. // Returns a new slice of LastEvaluatedKeys after the scan finishes. func (s *Scan) AllParallelStartFrom(ctx context.Context, keys []PagingKey, out interface{}) ([]PagingKey, error) { - iters := s.newSegments(int64(len(keys)), keys) + iters := s.newSegments(len(keys), keys) ps := newParallelScan(iters, s.cc, false, unmarshalAppendTo(out)) go ps.run(ctx) - for ps.NextWithContext(ctx, out) { + for ps.Next(ctx, out) { } - return ps.LastEvaluatedKeys(), ps.Err() + leks, err := ps.LastEvaluatedKeys(ctx) + return leks, errors.Join(ps.Err(), err) } // Count executes this request and returns the number of items matching the scan. // It takes into account the filter, limit, search limit, and all other parameters given. // It may return a higher count than the limits. -func (s *Scan) Count() (int64, error) { - ctx, cancel := defaultContext() - defer cancel() - return s.CountWithContext(ctx) -} - -// CountWithContext executes this request and returns the number of items matching the scan. -// It takes into account the filter, limit, search limit, and all other parameters given. -// It may return a higher count than the limits. -func (s *Scan) CountWithContext(ctx context.Context) (int64, error) { +func (s *Scan) Count(ctx context.Context) (int, error) { if s.err != nil { return 0, s.err } - var count, scanned int64 + var count int + var scanned int32 input := s.scanInput() - input.Select = aws.String(dynamodb.SelectCount) + input.Select = types.SelectCount var reqs int for { var out *dynamodb.ScanOutput err := s.table.db.retry(ctx, func() error { var err error - out, err = s.table.db.client.ScanWithContext(ctx, input) + out, err = s.table.db.client.Scan(ctx, input) return err }) if err != nil { @@ -275,13 +261,8 @@ func (s *Scan) CountWithContext(ctx context.Context) (int64, error) { } reqs++ - if out.Count == nil { - return count, errors.New("malformed DynamoDB outponse: count is nil") - } - count += *out.Count - if out.ScannedCount != nil { - scanned += *out.ScannedCount - } + count += int(out.Count) + scanned += out.ScannedCount if s.cc != nil { addConsumedCapacity(s.cc, out.ConsumedCapacity) @@ -313,7 +294,8 @@ func (s *Scan) scanInput() *dynamodb.ScanInput { } if s.limit > 0 { if len(s.filters) == 0 { - input.Limit = &s.limit + limit := int32(min(s.limit, math.MaxInt32)) + input.Limit = &limit } } if s.searchLimit > 0 { @@ -330,7 +312,7 @@ func (s *Scan) scanInput() *dynamodb.ScanInput { input.FilterExpression = &filter } if s.cc != nil { - input.ReturnConsumedCapacity = aws.String(dynamodb.ReturnConsumedCapacityIndexes) + input.ReturnConsumedCapacity = types.ReturnConsumedCapacityIndexes } return input } @@ -348,16 +330,16 @@ type scanIter struct { output *dynamodb.ScanOutput err error idx int - n int64 + n int reqs int // last item evaluated - last map[string]*dynamodb.AttributeValue + last Item // cache of primary keys, used for generating LEKs keys map[string]struct{} // example LastEvaluatedKey and ExclusiveStartKey, used to lazily evaluate the primary keys if possible - exLEK map[string]*dynamodb.AttributeValue - exESK map[string]*dynamodb.AttributeValue + exLEK Item + exESK Item keyErr error unmarshal unmarshalFunc @@ -365,13 +347,7 @@ type scanIter struct { // Next tries to unmarshal the next result into out. // Returns false when it is complete or if it runs into an error. -func (itr *scanIter) Next(out interface{}) bool { - ctx, cancel := defaultContext() - defer cancel() - return itr.NextWithContext(ctx, out) -} - -func (itr *scanIter) NextWithContext(ctx context.Context, out interface{}) bool { +func (itr *scanIter) Next(ctx context.Context, out interface{}) bool { redo: // stop if we have an error if ctx.Err() != nil { @@ -422,7 +398,7 @@ redo: itr.err = itr.scan.table.db.retry(ctx, func() error { var err error - itr.output, err = itr.scan.table.db.client.ScanWithContext(ctx, itr.input) + itr.output, err = itr.scan.table.db.client.Scan(ctx, itr.input) return err }) @@ -463,49 +439,44 @@ func (itr *scanIter) Err() error { // LastEvaluatedKey returns a key that can be used to continue this scan. // Use with SearchLimit for best results. -func (itr *scanIter) LastEvaluatedKey() PagingKey { +func (itr *scanIter) LastEvaluatedKey(ctx context.Context) (PagingKey, error) { if itr.output != nil { // if we've hit the end of our results, we can use the real LEK if itr.idx == len(itr.output.Items) { - return itr.output.LastEvaluatedKey + return itr.output.LastEvaluatedKey, nil } // figure out the primary keys if needed if itr.keys == nil && itr.keyErr == nil { - ctx, _ := defaultContext() // TODO(v2): take context instead of using the default itr.keys, itr.keyErr = itr.scan.table.primaryKeys(ctx, itr.exLEK, itr.exESK, itr.scan.index) } if itr.keyErr != nil { // primaryKeys can fail if the credentials lack DescribeTable permissions // in order to preserve backwards compatibility, we fall back to the old behavior and warn // see: https://github.com/guregu/dynamo/pull/187#issuecomment-1045183901 - // TODO(v2): rejigger this API. - itr.scan.table.db.log("dynamo: Warning:", itr.keyErr, "Returning a later LastEvaluatedKey.") - return itr.output.LastEvaluatedKey + return itr.output.LastEvaluatedKey, fmt.Errorf("dynamo: failed to determine LastEvaluatedKey in scan: %w", itr.keyErr) } // we can't use the real LEK, so we need to infer the LEK from the last item we saw lek, err := lekify(itr.last, itr.keys) - // unfortunately, this API can't return an error so a warning is the best we can do... - // this matches old behavior before the LEK was automatically generated - // TODO(v2): fix this. if err != nil { - itr.scan.table.db.log("dynamo: Warning:", err, "Returning a later LastEvaluatedKey.") - return itr.output.LastEvaluatedKey + return itr.output.LastEvaluatedKey, fmt.Errorf("dynamo: failed to infer LastEvaluatedKey in scan: %w", err) } - return lek + return lek, nil } - return nil + return nil, nil } type parallelScan struct { iters []*scanIter - items chan map[string]*dynamodb.AttributeValue + items chan Item - leks []PagingKey - cc *ConsumedCapacity - err error - mu *sync.Mutex + leks []PagingKey + lekErr error + + cc *ConsumedCapacity + err error + mu *sync.Mutex unmarshal unmarshalFunc } @@ -513,7 +484,7 @@ type parallelScan struct { func newParallelScan(iters []*scanIter, cc *ConsumedCapacity, skipLEK bool, unmarshal unmarshalFunc) *parallelScan { ps := ¶llelScan{ iters: iters, - items: make(chan map[string]*dynamodb.AttributeValue), + items: make(chan Item), cc: cc, mu: new(sync.Mutex), unmarshal: unmarshal, @@ -532,8 +503,8 @@ func (ps *parallelScan) run(ctx context.Context) { continue } grp.Go(func() error { - var item map[string]*dynamodb.AttributeValue - for iter.NextWithContext(ctx, &item) { + var item Item + for iter.Next(ctx, &item) { select { case <-ctx.Done(): return ctx.Err() @@ -543,9 +514,12 @@ func (ps *parallelScan) run(ctx context.Context) { } if ps.leks != nil { - lek := iter.LastEvaluatedKey() + lek, err := iter.LastEvaluatedKey(ctx) ps.mu.Lock() ps.leks[i] = lek + if err != nil && ps.lekErr == nil { + ps.lekErr = err + } ps.mu.Unlock() } } @@ -566,13 +540,7 @@ func (ps *parallelScan) run(ctx context.Context) { close(ps.items) } -func (ps *parallelScan) Next(out interface{}) bool { - ctx, cancel := defaultContext() - defer cancel() - return ps.NextWithContext(ctx, out) -} - -func (ps *parallelScan) NextWithContext(ctx context.Context, out interface{}) bool { +func (ps *parallelScan) Next(ctx context.Context, out interface{}) bool { select { case <-ctx.Done(): ps.setError(ctx.Err()) @@ -603,10 +571,10 @@ func (ps *parallelScan) Err() error { return ps.err } -func (ps *parallelScan) LastEvaluatedKeys() []PagingKey { +func (ps *parallelScan) LastEvaluatedKeys(_ context.Context) ([]PagingKey, error) { keys := make([]PagingKey, len(ps.leks)) ps.mu.Lock() defer ps.mu.Unlock() copy(keys, ps.leks) - return keys + return keys, ps.lekErr } diff --git a/scan_test.go b/scan_test.go index 9941b6a..c25622b 100644 --- a/scan_test.go +++ b/scan_test.go @@ -3,10 +3,9 @@ package dynamo import ( "context" "reflect" + "sync" "testing" "time" - - "github.com/aws/aws-sdk-go/aws" ) func TestScan(t *testing.T) { @@ -14,6 +13,7 @@ func TestScan(t *testing.T) { t.Skip(offlineSkipMsg) } table := testDB.Table(testTableWidgets) + ctx := context.TODO() // first, add an item to make sure there is at least one item := widget{ @@ -21,13 +21,13 @@ func TestScan(t *testing.T) { Time: time.Now().UTC(), Msg: "hello", } - err := table.Put(item).Run() + err := table.Put(item).Run(ctx) if err != nil { t.Error("unexpected error:", err) } // count items via Query - ct, err := table.Get("UserID", 42).Consistent(true).Count() + ct, err := table.Get("UserID", 42).Consistent(true).Count(ctx) if err != nil { t.Error("unexpected error:", err) } @@ -36,7 +36,7 @@ func TestScan(t *testing.T) { t.Run("All", func(t *testing.T) { var result []widget var cc ConsumedCapacity - err = table.Scan().Filter("UserID = ?", 42).Consistent(true).ConsumedCapacity(&cc).All(&result) + err = table.Scan().Filter("UserID = ?", 42).Consistent(true).ConsumedCapacity(&cc).All(ctx, &result) if err != nil { t.Error("unexpected error:", err) } @@ -63,7 +63,7 @@ func TestScan(t *testing.T) { // check this against Scan's count, too t.Run("Count", func(t *testing.T) { var cc2 ConsumedCapacity - scanCt, err := table.Scan().Filter("UserID = ?", 42).Consistent(true).ConsumedCapacity(&cc2).Count() + scanCt, err := table.Scan().Filter("UserID = ?", 42).Consistent(true).ConsumedCapacity(&cc2).Count(ctx) if err != nil { t.Error("unexpected error:", err) } @@ -108,6 +108,7 @@ func TestScanPaging(t *testing.T) { t.Skip(offlineSkipMsg) } table := testDB.Table(testTableWidgets) + ctx := context.TODO() // prepare data insert := make([]interface{}, 10) @@ -118,7 +119,7 @@ func TestScanPaging(t *testing.T) { Msg: "garbage", } } - if _, err := table.Batch().Write().Put(insert...).Run(); err != nil { + if _, err := table.Batch().Write().Put(insert...).Run(ctx); err != nil { t.Fatal(err) } @@ -126,12 +127,16 @@ func TestScanPaging(t *testing.T) { widgets := [10]widget{} itr := table.Scan().Consistent(true).SearchLimit(1).Iter() for i := 0; i < len(widgets); i++ { - itr.Next(&widgets[i]) + itr.Next(ctx, &widgets[i]) if itr.Err() != nil { t.Error("unexpected error", itr.Err()) break } - itr = table.Scan().StartFrom(itr.LastEvaluatedKey()).SearchLimit(1).Iter() + lek, err := itr.LastEvaluatedKey(context.Background()) + if err != nil { + t.Error("LEK error", err) + } + itr = table.Scan().StartFrom(lek).SearchLimit(1).Iter() } for i, w := range widgets { if w.UserID == 0 && w.Time.IsZero() { @@ -144,17 +149,21 @@ func TestScanPaging(t *testing.T) { const segments = 2 ctx := context.Background() widgets := [10]widget{} - limit := int64(len(widgets) / segments) + limit := int(len(widgets) / segments) itr := table.Scan().Consistent(true).SearchLimit(limit).IterParallel(ctx, segments) for i := 0; i < len(widgets); { - for ; i < len(widgets) && itr.Next(&widgets[i]); i++ { + for ; i < len(widgets) && itr.Next(ctx, &widgets[i]); i++ { } if itr.Err() != nil { t.Error("unexpected error", itr.Err()) break } t.Logf("parallel chunk: %d", i) - itr = table.Scan().SearchLimit(limit).IterParallelStartFrom(ctx, itr.LastEvaluatedKeys()) + lek, err := itr.LastEvaluatedKeys(ctx) + if err != nil { + t.Fatal("lek error", err) + } + itr = table.Scan().SearchLimit(limit).IterParallelStartFrom(ctx, lek) } for i, w := range widgets { if w.UserID == 0 && w.Time.IsZero() { @@ -168,7 +177,13 @@ func TestScanMagicLEK(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTableWidgets) + + testDB0 := *testDB + testDB0.descs = new(sync.Map) + freshTestDB := &testDB0 + + table := freshTestDB.Table(testTableWidgets) + ctx := context.Background() widgets := []interface{}{ widget{ @@ -188,7 +203,7 @@ func TestScanMagicLEK(t *testing.T) { }, } // prepare data - if _, err := table.Batch().Write().Put(widgets...).Run(); err != nil { + if _, err := table.Batch().Write().Put(widgets...).Run(ctx); err != nil { t.Fatal(err) } @@ -196,11 +211,15 @@ func TestScanMagicLEK(t *testing.T) { itr := table.Scan().Filter("'Msg' = ?", "TestScanMagicLEK").Limit(2).Iter() for i := 0; i < len(widgets); i++ { var w widget - itr.Next(&w) + itr.Next(ctx, &w) if itr.Err() != nil { t.Error("unexpected error", itr.Err()) } - itr = table.Scan().Filter("'Msg' = ?", "TestScanMagicLEK").StartFrom(itr.LastEvaluatedKey()).Limit(2).Iter() + lek, err := itr.LastEvaluatedKey(context.Background()) + if err != nil { + t.Error("LEK error", err) + } + itr = table.Scan().Filter("'Msg' = ?", "TestScanMagicLEK").StartFrom(lek).Limit(2).Iter() } }) @@ -208,16 +227,20 @@ func TestScanMagicLEK(t *testing.T) { itr := table.Scan().Index("Msg-Time-index").Filter("UserID = ?", 2069).Limit(2).Iter() for i := 0; i < len(widgets); i++ { var w widget - itr.Next(&w) + itr.Next(ctx, &w) if itr.Err() != nil { t.Error("unexpected error", itr.Err()) } - itr = table.Scan().Index("Msg-Time-index").Filter("UserID = ?", 2069).StartFrom(itr.LastEvaluatedKey()).Limit(2).Iter() + lek, err := itr.LastEvaluatedKey(context.Background()) + if err != nil { + t.Error("LEK error", err) + } + itr = table.Scan().Index("Msg-Time-index").Filter("UserID = ?", 2069).StartFrom(lek).Limit(2).Iter() } }) t.Run("table cache", func(t *testing.T) { - pk, err := table.primaryKeys(aws.BackgroundContext(), nil, nil, "") + pk, err := table.primaryKeys(context.Background(), nil, nil, "") if err != nil { t.Fatal(err) } diff --git a/sse.go b/sse.go index b678f1e..9495455 100644 --- a/sse.go +++ b/sse.go @@ -1,6 +1,10 @@ package dynamo -import "time" +import ( + "time" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) // SSEType is used to specify the type of server side encryption // to use on a table @@ -14,9 +18,9 @@ const ( type SSEDescription struct { InaccessibleEncryptionDateTime time.Time - KMSMasterKeyArn string - SSEType SSEType - Status string + KMSMasterKeyARN string + SSEType types.SSEType + Status types.SSEStatus } func lookupSSEType(sseType string) SSEType { diff --git a/substitute.go b/substitute.go index a503716..65c2631 100644 --- a/substitute.go +++ b/substitute.go @@ -8,31 +8,28 @@ import ( "strconv" "strings" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" - - "github.com/guregu/dynamo/internal/exprs" + "github.com/guregu/dynamo/v2/internal/exprs" ) // subber is a "mixin" for operators for keep track of subtituted keys and values type subber struct { - nameExpr map[string]*string - valueExpr map[string]*dynamodb.AttributeValue + nameExpr map[string]string + valueExpr Item } func (s *subber) subName(name string) string { if s.nameExpr == nil { - s.nameExpr = make(map[string]*string) + s.nameExpr = make(map[string]string) } sub := "#s" + encodeName(name) - s.nameExpr[sub] = aws.String(name) + s.nameExpr[sub] = name return sub } func (s *subber) subValue(value interface{}, flags encodeFlags) (string, error) { if s.valueExpr == nil { - s.valueExpr = make(map[string]*dynamodb.AttributeValue) + s.valueExpr = make(Item) } if lit, ok := value.(ExpressionLiteral); ok { @@ -144,7 +141,7 @@ type ExpressionLiteral struct { // AttributeNames is a map of placeholders (such as #foo) to attribute names. AttributeNames map[string]*string // AttributeValues is a map of placeholders (such as :bar) to attribute values. - AttributeValues map[string]*dynamodb.AttributeValue + AttributeValues Item } // we don't want people to accidentally refer to our placeholders, so just slap an x_ in front of theirs @@ -158,15 +155,15 @@ func (s *subber) merge(lit ExpressionLiteral) string { } if len(lit.AttributeNames) > 0 && s.nameExpr == nil { - s.nameExpr = make(map[string]*string) + s.nameExpr = make(map[string]string) } for k, v := range lit.AttributeNames { safe := prefix(k) - s.nameExpr[safe] = v + s.nameExpr[safe] = *v } if len(lit.AttributeValues) > 0 && s.valueExpr == nil { - s.valueExpr = make(map[string]*dynamodb.AttributeValue) + s.valueExpr = make(Item) } for k, v := range lit.AttributeValues { safe := prefix(k) diff --git a/substitute_test.go b/substitute_test.go index 7ee930e..3fd73c5 100644 --- a/substitute_test.go +++ b/substitute_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) func TestSubExpr(t *testing.T) { @@ -64,9 +64,9 @@ func TestSubMerge(t *testing.T) { "#abc": aws.String("custom"), "#abcdef": aws.String("model"), }, - AttributeValues: map[string]*dynamodb.AttributeValue{ - ":v": {S: aws.String("abc")}, - ":v0": {N: aws.String("555")}, + AttributeValues: Item{ + ":v": &types.AttributeValueMemberS{Value: "abc"}, + ":v0": &types.AttributeValueMemberN{Value: "555"}, }, } rewrite, err := s.subExpr("?", lit) @@ -84,8 +84,8 @@ func TestSubMerge(t *testing.T) { if !ok { t.Error("missing merged name:", k, foreign) } - if !reflect.DeepEqual(v, got) { - t.Error("merged name mismatch. want:", v, "got:", got) + if !reflect.DeepEqual(*v, got) { + t.Error("merged name mismatch. want:", *v, "got:", got) } } diff --git a/table.go b/table.go index 75f7722..056e991 100644 --- a/table.go +++ b/table.go @@ -2,12 +2,11 @@ package dynamo import ( "context" - "errors" "fmt" - "sync/atomic" + "time" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) // Status is an enumeration of table and index statuses. @@ -33,8 +32,6 @@ const ( type Table struct { name string db *DB - // desc is this table's cached description, used for inferring keys - desc *atomic.Value // Description } // Table returns a Table handle specified by name. @@ -42,7 +39,6 @@ func (db *DB) Table(name string) Table { return Table{ name: name, db: db, - desc: new(atomic.Value), } } @@ -53,15 +49,7 @@ func (table Table) Name() string { // Wait blocks until this table's status matches any status provided by want. // If no statuses are specified, the active status is used. -func (table Table) Wait(want ...Status) error { - ctx, cancel := defaultContext() - defer cancel() - return table.WaitWithContext(ctx, want...) -} - -// Wait blocks until this table's status matches any status provided by want. -// If no statuses are specified, the active status is used. -func (table Table) WaitWithContext(ctx context.Context, want ...Status) error { +func (table Table) Wait(ctx context.Context, want ...Status) error { if len(want) == 0 { want = []Status{ActiveStatus} } @@ -72,38 +60,43 @@ func (table Table) WaitWithContext(ctx context.Context, want ...Status) error { } } - err := table.db.retry(ctx, func() error { - desc, err := table.Describe().RunWithContext(ctx) - var aerr awserr.RequestFailure - if errors.As(err, &aerr) { - if aerr.Code() == "ResourceNotFoundException" { - if wantGone { - return nil - } - return errRetry - } - } - if err != nil { - return err - } + // I don't know why AWS wants a context _and_ a duration param. + // Infer it from context; if it's indefinite then set it to something really high (1 day) + deadline, ok := ctx.Deadline() + if !ok { + deadline = time.Now().Add(24 * time.Hour) + } + maxDur := time.Until(deadline) - for _, status := range want { - if status == desc.Status { - return nil + if wantGone { + waiter := dynamodb.NewTableNotExistsWaiter(table.db.client) + return waiter.Wait(ctx, table.Describe().input(), maxDur) + } + + waiter := dynamodb.NewTableExistsWaiter(table.db.client, func(opts *dynamodb.TableExistsWaiterOptions) { + fallback := opts.Retryable + opts.Retryable = func(ctx context.Context, in *dynamodb.DescribeTableInput, out *dynamodb.DescribeTableOutput, err error) (bool, error) { + if err == nil && out != nil && out.Table != nil { + status := string(out.Table.TableStatus) + for _, wantStatus := range want { + if status == string(wantStatus) { + return false, nil + } + } } + return fallback(ctx, in, out, err) } - return errRetry }) - return err + return waiter.Wait(ctx, table.Describe().input(), maxDur) } // primaryKeys attempts to determine this table's primary keys. // It will try: -// - output LastEvaluatedKey -// - input ExclusiveStartKey -// - DescribeTable as a last resort (cached inside table) -func (table Table) primaryKeys(ctx context.Context, lek, esk map[string]*dynamodb.AttributeValue, index string) (map[string]struct{}, error) { - extract := func(item map[string]*dynamodb.AttributeValue) map[string]struct{} { +// - output LastEvaluatedKey +// - input ExclusiveStartKey +// - DescribeTable as a last resort (cached inside table) +func (table Table) primaryKeys(ctx context.Context, lek, esk Item, index string) (map[string]struct{}, error) { + extract := func(item Item) map[string]struct{} { keys := make(map[string]struct{}, len(item)) for k := range item { keys[k] = struct{}{} @@ -122,7 +115,8 @@ func (table Table) primaryKeys(ctx context.Context, lek, esk map[string]*dynamod // now we're forced to call DescribeTable // do we have a description cached? - if desc, ok := table.desc.Load().(Description); ok { + + if desc, ok := table.db.loadDesc(table.name); ok { keys := desc.keys(index) if keys != nil { return keys, nil @@ -133,7 +127,7 @@ func (table Table) primaryKeys(ctx context.Context, lek, esk map[string]*dynamod keys := make(map[string]struct{}) err := table.db.retry(ctx, func() error { - desc, err := table.Describe().RunWithContext(ctx) + desc, err := table.Describe().Run(ctx) if err != nil { return err } @@ -149,7 +143,7 @@ func (table Table) primaryKeys(ctx context.Context, lek, esk map[string]*dynamod return keys, nil } -func lekify(item map[string]*dynamodb.AttributeValue, keys map[string]struct{}) (map[string]*dynamodb.AttributeValue, error) { +func lekify(item Item, keys map[string]struct{}) (Item, error) { if item == nil { // this shouldn't happen because in queries without results, a LastEvaluatedKey should be given to us by AWS return nil, fmt.Errorf("dynamo: can't determine LastEvaluatedKey: no keys or results") @@ -157,7 +151,7 @@ func lekify(item map[string]*dynamodb.AttributeValue, keys map[string]struct{}) if keys == nil { return nil, fmt.Errorf("dynamo: can't determine LastEvaluatedKey: failed to infer primary keys") } - lek := make(map[string]*dynamodb.AttributeValue, len(keys)) + lek := make(Item, len(keys)) for k := range keys { v, ok := item[k] if !ok { @@ -180,34 +174,20 @@ func (table Table) DeleteTable() *DeleteTable { } // Run executes this request and deletes the table. -func (dt *DeleteTable) Run() error { - ctx, cancel := defaultContext() - defer cancel() - return dt.RunWithContext(ctx) -} - -// RunWithContext executes this request and deletes the table. -func (dt *DeleteTable) RunWithContext(ctx context.Context) error { +func (dt *DeleteTable) Run(ctx context.Context) error { input := dt.input() return dt.table.db.retry(ctx, func() error { - _, err := dt.table.db.client.DeleteTableWithContext(ctx, input) + _, err := dt.table.db.client.DeleteTable(ctx, input) return err }) } // Wait executes this request and blocks until the table is finished deleting. -func (dt *DeleteTable) Wait() error { - ctx, cancel := defaultContext() - defer cancel() - return dt.WaitWithContext(ctx) -} - -// WaitWithContext executes this request and blocks until the table is finished deleting. -func (dt *DeleteTable) WaitWithContext(ctx context.Context) error { - if err := dt.RunWithContext(ctx); err != nil { +func (dt *DeleteTable) Wait(ctx context.Context) error { + if err := dt.Run(ctx); err != nil { return err } - return dt.table.WaitWithContext(ctx, NotExistsStatus) + return dt.table.Wait(ctx, NotExistsStatus) } func (dt *DeleteTable) input() *dynamodb.DeleteTableInput { @@ -255,7 +235,7 @@ type ConsumedCapacity struct { TableName string } -func addConsumedCapacity(cc *ConsumedCapacity, raw *dynamodb.ConsumedCapacity) { +func addConsumedCapacity(cc *ConsumedCapacity, raw *types.ConsumedCapacity) { if cc == nil || raw == nil { return } diff --git a/table_test.go b/table_test.go index 11b08b3..ca356cc 100644 --- a/table_test.go +++ b/table_test.go @@ -1,14 +1,15 @@ package dynamo import ( + "context" "fmt" "reflect" "sort" "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) func TestTableLifecycle(t *testing.T) { @@ -19,6 +20,8 @@ func TestTableLifecycle(t *testing.T) { t.SkipNow() } + ctx := context.TODO() + now := time.Now().UTC() name := fmt.Sprintf("TestDB-%d", now.UnixNano()) @@ -37,11 +40,11 @@ func TestTableLifecycle(t *testing.T) { HashKeyType: StringType, RangeKey: "Bar", RangeKeyType: NumberType, - }).Wait(); err != nil { + }).Wait(ctx); err != nil { t.Fatal(err) } - desc, err := testDB.Table(name).Describe().Run() + desc, err := testDB.Table(name).Describe().Run(ctx) if err != nil { t.Fatal(err) } @@ -65,7 +68,7 @@ func TestTableLifecycle(t *testing.T) { RangeKeyType: NumberType, Throughput: Throughput{Read: 1, Write: 1}, ProjectionType: AllProjection, - ProjectionAttribs: []string{}, + ProjectionAttribs: []string(nil), }, { Name: "Seq-ID-index", @@ -77,7 +80,7 @@ func TestTableLifecycle(t *testing.T) { RangeKeyType: StringType, Throughput: Throughput{Read: 1, Write: 1}, ProjectionType: AllProjection, - ProjectionAttribs: []string{}, + ProjectionAttribs: []string(nil), }, { Name: "UUID-index", @@ -87,7 +90,7 @@ func TestTableLifecycle(t *testing.T) { HashKeyType: StringType, Throughput: Throughput{Read: 1, Write: 1}, ProjectionType: AllProjection, - ProjectionAttribs: []string{}, + ProjectionAttribs: []string(nil), }, }, LSI: []Index{ @@ -103,7 +106,7 @@ func TestTableLifecycle(t *testing.T) { RangeKeyType: NumberType, Throughput: Throughput{Read: 1, Write: 1}, ProjectionType: AllProjection, - ProjectionAttribs: []string{}, + ProjectionAttribs: []string(nil), }, }, } @@ -114,32 +117,32 @@ func TestTableLifecycle(t *testing.T) { // make sure it really works table := testDB.Table(name) - if err := table.Put(UserAction{UserID: "test", Time: now, Seq: 1, UUID: "42"}).Run(); err != nil { + if err := table.Put(UserAction{UserID: "test", Time: now, Seq: 1, UUID: "42"}).Run(ctx); err != nil { t.Fatal(err) } // delete & wait - if err := testDB.Table(name).DeleteTable().Wait(); err != nil { + if err := testDB.Table(name).DeleteTable().Wait(ctx); err != nil { t.Fatal(err) } } func TestAddConsumedCapacity(t *testing.T) { - raw := &dynamodb.ConsumedCapacity{ + raw := &types.ConsumedCapacity{ TableName: aws.String("TestTable"), - Table: &dynamodb.Capacity{ + Table: &types.Capacity{ CapacityUnits: aws.Float64(9), ReadCapacityUnits: aws.Float64(4), WriteCapacityUnits: aws.Float64(5), }, - GlobalSecondaryIndexes: map[string]*dynamodb.Capacity{ + GlobalSecondaryIndexes: map[string]types.Capacity{ "TestGSI": { CapacityUnits: aws.Float64(3), ReadCapacityUnits: aws.Float64(1), WriteCapacityUnits: aws.Float64(2), }, }, - LocalSecondaryIndexes: map[string]*dynamodb.Capacity{ + LocalSecondaryIndexes: map[string]types.Capacity{ "TestLSI": { CapacityUnits: aws.Float64(30), ReadCapacityUnits: aws.Float64(10), @@ -150,7 +153,7 @@ func TestAddConsumedCapacity(t *testing.T) { ReadCapacityUnits: aws.Float64(15), WriteCapacityUnits: aws.Float64(27), } - expected := ConsumedCapacity{ + expected := &ConsumedCapacity{ TableName: *raw.TableName, Table: *raw.Table.CapacityUnits, TableRead: *raw.Table.ReadCapacityUnits, @@ -166,8 +169,8 @@ func TestAddConsumedCapacity(t *testing.T) { Write: *raw.WriteCapacityUnits, } - var cc ConsumedCapacity - addConsumedCapacity(&cc, raw) + var cc = new(ConsumedCapacity) + addConsumedCapacity(cc, raw) if !reflect.DeepEqual(cc, expected) { t.Error("bad ConsumedCapacity:", cc, "≠", expected) diff --git a/ttl.go b/ttl.go index aaaebdc..10f9e65 100644 --- a/ttl.go +++ b/ttl.go @@ -3,8 +3,9 @@ package dynamo import ( "context" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) // UpdateTTL is a request to enable or disable a table's time to live functionality. @@ -31,18 +32,11 @@ func (table Table) UpdateTTL(attribute string, enabled bool) *UpdateTTL { } // Run executes this request. -func (ttl *UpdateTTL) Run() error { - ctx, cancel := defaultContext() - defer cancel() - return ttl.RunWithContext(ctx) -} - -// RunWithContext executes this request. -func (ttl *UpdateTTL) RunWithContext(ctx context.Context) error { +func (ttl *UpdateTTL) Run(ctx context.Context) error { input := ttl.input() err := ttl.table.db.retry(ctx, func() error { - _, err := ttl.table.db.client.UpdateTimeToLiveWithContext(ctx, input) + _, err := ttl.table.db.client.UpdateTimeToLive(ctx, input) return err }) return err @@ -51,7 +45,7 @@ func (ttl *UpdateTTL) RunWithContext(ctx context.Context) error { func (ttl *UpdateTTL) input() *dynamodb.UpdateTimeToLiveInput { return &dynamodb.UpdateTimeToLiveInput{ TableName: aws.String(ttl.table.Name()), - TimeToLiveSpecification: &dynamodb.TimeToLiveSpecification{ + TimeToLiveSpecification: &types.TimeToLiveSpecification{ Enabled: aws.Bool(ttl.enabled), AttributeName: aws.String(ttl.attrib), }, @@ -69,20 +63,13 @@ func (table Table) DescribeTTL() *DescribeTTL { } // Run executes this request and returns details about time to live, or an error. -func (d *DescribeTTL) Run() (TTLDescription, error) { - ctx, cancel := defaultContext() - defer cancel() - return d.RunWithContext(ctx) -} - -// RunWithContext executes this request and returns details about time to live, or an error. -func (d *DescribeTTL) RunWithContext(ctx context.Context) (TTLDescription, error) { +func (d *DescribeTTL) Run(ctx context.Context) (TTLDescription, error) { input := d.input() var result *dynamodb.DescribeTimeToLiveOutput err := d.table.db.retry(ctx, func() error { var err error - result, err = d.table.db.client.DescribeTimeToLiveWithContext(ctx, input) + result, err = d.table.db.client.DescribeTimeToLive(ctx, input) return err }) if err != nil { @@ -92,8 +79,8 @@ func (d *DescribeTTL) RunWithContext(ctx context.Context) (TTLDescription, error desc := TTLDescription{ Status: TTLDisabled, } - if result.TimeToLiveDescription.TimeToLiveStatus != nil { - desc.Status = TTLStatus(*result.TimeToLiveDescription.TimeToLiveStatus) + if result.TimeToLiveDescription.TimeToLiveStatus != "" { + desc.Status = TTLStatus(result.TimeToLiveDescription.TimeToLiveStatus) } if result.TimeToLiveDescription.AttributeName != nil { desc.Attribute = *result.TimeToLiveDescription.AttributeName diff --git a/ttl_test.go b/ttl_test.go index 9ffcafc..ed78111 100644 --- a/ttl_test.go +++ b/ttl_test.go @@ -1,6 +1,7 @@ package dynamo import ( + "context" "testing" ) @@ -9,8 +10,9 @@ func TestDescribeTTL(t *testing.T) { t.Skip(offlineSkipMsg) } table := testDB.Table(testTableWidgets) + ctx := context.TODO() - desc, err := table.DescribeTTL().Run() + desc, err := table.DescribeTTL().Run(ctx) if err != nil { t.Error(err) return diff --git a/tx.go b/tx.go index c0e5316..1caecd5 100644 --- a/tx.go +++ b/tx.go @@ -6,8 +6,9 @@ import ( "encoding/hex" "errors" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) // ErrNoInput is returned when APIs that can take multiple inputs are run with zero inputs. @@ -15,7 +16,7 @@ import ( var ErrNoInput = errors.New("dynamo: no input items") type getTxOp interface { - getTxItem() (*dynamodb.TransactGetItem, error) + getTxItem() (types.TransactGetItem, error) } // GetTx is a transaction to retrieve items. @@ -60,14 +61,7 @@ func (tx *GetTx) ConsumedCapacity(cc *ConsumedCapacity) *GetTx { } // Run executes this transaction and unmarshals everything specified by GetOne. -func (tx *GetTx) Run() error { - ctx, cancel := defaultContext() - defer cancel() - return tx.RunWithContext(ctx) -} - -// RunWithContext executes this transaction and unmarshals everything specified by GetOne. -func (tx *GetTx) RunWithContext(ctx context.Context) error { +func (tx *GetTx) Run(ctx context.Context) error { input, err := tx.input() if err != nil { return err @@ -75,10 +69,10 @@ func (tx *GetTx) RunWithContext(ctx context.Context) error { var resp *dynamodb.TransactGetItemsOutput err = tx.db.retry(ctx, func() error { var err error - resp, err = tx.db.client.TransactGetItemsWithContext(ctx, input) + resp, err = tx.db.client.TransactGetItems(ctx, input) if tx.cc != nil && resp != nil { for _, cc := range resp.ConsumedCapacity { - addConsumedCapacity(tx.cc, cc) + addConsumedCapacity(tx.cc, &cc) } } return err @@ -107,14 +101,7 @@ func (tx *GetTx) unmarshal(resp *dynamodb.TransactGetItemsOutput) error { } // All executes this transaction and unmarshals every value to out, which must be a pointer to a slice. -func (tx *GetTx) All(out interface{}) error { - ctx, cancel := defaultContext() - defer cancel() - return tx.AllWithContext(ctx, out) -} - -// AllWithContext executes this transaction and unmarshals every value to out, which must be a pointer to a slice. -func (tx *GetTx) AllWithContext(ctx context.Context, out interface{}) error { +func (tx *GetTx) All(ctx context.Context, out interface{}) error { input, err := tx.input() if err != nil { return err @@ -122,10 +109,10 @@ func (tx *GetTx) AllWithContext(ctx context.Context, out interface{}) error { var resp *dynamodb.TransactGetItemsOutput err = tx.db.retry(ctx, func() error { var err error - resp, err = tx.db.client.TransactGetItemsWithContext(ctx, input) + resp, err = tx.db.client.TransactGetItems(ctx, input) if tx.cc != nil && resp != nil { for _, cc := range resp.ConsumedCapacity { - addConsumedCapacity(tx.cc, cc) + addConsumedCapacity(tx.cc, &cc) } } return err @@ -164,13 +151,13 @@ func (tx *GetTx) input() (*dynamodb.TransactGetItemsInput, error) { input.TransactItems = append(input.TransactItems, tgi) } if tx.cc != nil { - input.ReturnConsumedCapacity = aws.String(dynamodb.ReturnConsumedCapacityIndexes) + input.ReturnConsumedCapacity = types.ReturnConsumedCapacityIndexes } return input, nil } type writeTxOp interface { - writeTxItem() (*dynamodb.TransactWriteItem, error) + writeTxItem() (*types.TransactWriteItem, error) } // WriteTx is a transaction to delete, put, update, and check items. @@ -259,14 +246,7 @@ func (tx *WriteTx) ConsumedCapacity(cc *ConsumedCapacity) *WriteTx { } // Run executes this transaction. -func (tx *WriteTx) Run() error { - ctx, cancel := defaultContext() - defer cancel() - return tx.RunWithContext(ctx) -} - -// RunWithContext executes this transaction. -func (tx *WriteTx) RunWithContext(ctx context.Context) error { +func (tx *WriteTx) Run(ctx context.Context) error { if tx.err != nil { return tx.err } @@ -275,10 +255,10 @@ func (tx *WriteTx) RunWithContext(ctx context.Context) error { return err } err = tx.db.retry(ctx, func() error { - out, err := tx.db.client.TransactWriteItemsWithContext(ctx, input) + out, err := tx.db.client.TransactWriteItems(ctx, input) if tx.cc != nil && out != nil { for _, cc := range out.ConsumedCapacity { - addConsumedCapacity(tx.cc, cc) + addConsumedCapacity(tx.cc, &cc) } } return err @@ -296,13 +276,13 @@ func (tx *WriteTx) input() (*dynamodb.TransactWriteItemsInput, error) { if err != nil { return nil, err } - input.TransactItems = append(input.TransactItems, wti) + input.TransactItems = append(input.TransactItems, *wti) } if tx.token != "" { input.ClientRequestToken = aws.String(tx.token) } if tx.cc != nil { - input.ReturnConsumedCapacity = aws.String(dynamodb.ReturnConsumedCapacityIndexes) + input.ReturnConsumedCapacity = types.ReturnConsumedCapacityIndexes } return input, nil } @@ -313,7 +293,7 @@ func (tx *WriteTx) setError(err error) { } } -func isResponsesEmpty(resps []*dynamodb.ItemResponse) bool { +func isResponsesEmpty(resps []types.ItemResponse) bool { for _, resp := range resps { if resp.Item != nil { return false diff --git a/tx_test.go b/tx_test.go index 0fd28e0..d44d709 100644 --- a/tx_test.go +++ b/tx_test.go @@ -1,12 +1,14 @@ package dynamo import ( + "context" + "errors" "reflect" "sync" "testing" "time" - "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/smithy-go" ) func TestTx(t *testing.T) { @@ -14,6 +16,8 @@ func TestTx(t *testing.T) { t.Skip(offlineSkipMsg) } + ctx := context.TODO() + date1 := time.Date(1969, 1, 1, 1, 1, 1, 0, time.UTC) date2 := time.Date(1969, 2, 2, 2, 2, 2, 0, time.UTC) date3 := time.Date(1969, 3, 3, 3, 3, 3, 0, time.UTC) @@ -29,7 +33,7 @@ func TestTx(t *testing.T) { tx.Put(table.Put(widget2)) tx.Check(table.Check("UserID", 69).Range("Time", date3).IfNotExists()) tx.ConsumedCapacity(&cc) - err := tx.Run() + err := tx.Run(ctx) if err != nil { t.Error(err) } @@ -38,7 +42,7 @@ func TestTx(t *testing.T) { } ccold = cc - err = tx.Run() + err = tx.Run(ctx) if err != nil { t.Error(err) } @@ -63,7 +67,7 @@ func TestTx(t *testing.T) { tx.Put(table.Put(widget1)) tx.Put(table.Put(widget2)) tx.ConsumedCapacity(&cc) - err = tx.Run() + err = tx.Run(ctx) if err != nil { t.Error(err) } @@ -72,7 +76,7 @@ func TestTx(t *testing.T) { } ccold = cc - err = tx.Run() + err = tx.Run(ctx) if err != nil { t.Error(err) } @@ -94,7 +98,7 @@ func TestTx(t *testing.T) { getTx.GetOne(table.Get("UserID", 69).Range("Time", Equal, date2), &record2) getTx.GetOne(table.Get("UserID", 69).Range("Time", Equal, date3), &record3) getTx.ConsumedCapacity(&cc2) - err = getTx.Run() + err = getTx.Run(ctx) if err != nil { t.Error(err) } @@ -114,7 +118,7 @@ func TestTx(t *testing.T) { // All oldCC2 := cc2 var records []widget - err = getTx.All(&records) + err = getTx.All(ctx, &records) if err != nil { t.Error(err) } @@ -130,7 +134,7 @@ func TestTx(t *testing.T) { tx = testDB.WriteTx() tx.Check(table.Check("UserID", widget1.UserID).Range("Time", widget1.Time).If("Msg = ?", widget1.Msg)) tx.Update(table.Update("UserID", widget2.UserID).Range("Time", widget2.Time).Set("Msg", widget2.Msg)) - if err = tx.Run(); err != nil { + if err = tx.Run(ctx); err != nil { t.Error(err) } @@ -138,12 +142,12 @@ func TestTx(t *testing.T) { tx = testDB.WriteTx() tx.Delete(table.Delete("UserID", widget1.UserID).Range("Time", widget1.Time).If("Msg = ?", widget1.Msg)) tx.Delete(table.Delete("UserID", widget2.UserID).Range("Time", widget2.Time).If("Msg = ?", widget2.Msg)) - if err = tx.Run(); err != nil { + if err = tx.Run(ctx); err != nil { t.Error(err) } // zero results - if err = getTx.Run(); err != ErrNotFound { + if err = getTx.Run(ctx); err != ErrNotFound { t.Error("expected ErrNotFound, got:", err) } @@ -152,11 +156,12 @@ func TestTx(t *testing.T) { tx.Put(table.Put(widget{UserID: 69, Time: date1}).If("'Msg' = ?", "should not exist")) tx.Put(table.Put(widget{UserID: 69, Time: date2})) tx.Check(table.Check("UserID", 69).Range("Time", date3).IfExists().If("Msg = ?", "don't exist foo")) - err = tx.Run() + err = tx.Run(ctx) if err == nil { t.Error("expected error") } else { - if err.(awserr.Error).Code() != "TransactionCanceledException" { + var apiErr smithy.APIError + if errors.As(err, &apiErr) && apiErr.ErrorCode() != "TransactionCanceledException" { t.Error("unexpected error:", err) } } @@ -165,12 +170,12 @@ func TestTx(t *testing.T) { t.Logf("All: %+v (len: %d)", records, len(records)) // no input - err = testDB.GetTx().All(nil) + err = testDB.GetTx().All(ctx, nil) if err != ErrNoInput { t.Error("unexpected error", err) } - err = testDB.WriteTx().Run() + err = testDB.WriteTx().Run(ctx) if err != ErrNoInput { t.Error("unexpected error", err) } @@ -180,12 +185,13 @@ func TestTxRetry(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } + ctx := context.TODO() date1 := time.Date(1999, 1, 1, 1, 1, 1, 0, time.UTC) widget1 := widget{UserID: 69, Time: date1, Msg: "dog", Count: 0} table := testDB.Table(testTableWidgets) - if err := table.Put(widget1).Run(); err != nil { + if err := table.Put(widget1).Run(ctx); err != nil { t.Fatal(err) } @@ -202,7 +208,8 @@ func TestTxRetry(t *testing.T) { tx.Update(table.Update("UserID", widget1.UserID). Range("Time", widget1.Time). Add("Count", 1)) - if err := tx.Run(); err != nil { + if err := tx.Run(ctx); err != nil { + // spew.Dump(err) panic(err) } }() @@ -216,7 +223,7 @@ func TestTxRetry(t *testing.T) { tx.Update(table.Update("UserID", widget1.UserID). Range("Time", widget1.Time).Add("Count", 1). If("'Count' = ?", -1)) - if err := tx.Run(); err != nil && !IsCondCheckFailed(err) { + if err := tx.Run(ctx); err != nil && !IsCondCheckFailed(err) { panic(err) } }() @@ -227,13 +234,13 @@ func TestTxRetry(t *testing.T) { defer wg.Done() tx := testDB.WriteTx() tx.Update(table.Update("UserID", "\u0002").Set("Foo", "")) - _ = tx.Run() + _ = tx.Run(ctx) }() wg.Wait() var got widget - if err := table.Get("UserID", widget1.UserID).Range("Time", Equal, widget1.Time).One(&got); err != nil { + if err := table.Get("UserID", widget1.UserID).Range("Time", Equal, widget1.Time).One(ctx, &got); err != nil { t.Fatal(err) } diff --git a/update.go b/update.go index 10a12c7..f841769 100644 --- a/update.go +++ b/update.go @@ -5,8 +5,8 @@ import ( "fmt" "strings" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) // Update represents changes to an existing item. @@ -17,10 +17,10 @@ type Update struct { returnType string hashKey string - hashValue *dynamodb.AttributeValue + hashValue types.AttributeValue rangeKey string - rangeValue *dynamodb.AttributeValue + rangeValue types.AttributeValue set []string add map[string]string @@ -210,19 +210,17 @@ func (u *Update) DeleteFromSet(path string, value interface{}) *Update { u.setError(err) return u } - switch { + switch t := v.(type) { // ok: - case v.NS != nil: - case v.SS != nil: - case v.BS != nil: + case *types.AttributeValueMemberNS, *types.AttributeValueMemberSS, *types.AttributeValueMemberBS: // need to box: - case v.N != nil: - v = &dynamodb.AttributeValue{NS: []*string{v.N}} - case v.S != nil: - v = &dynamodb.AttributeValue{SS: []*string{v.S}} - case v.B != nil: - v = &dynamodb.AttributeValue{BS: [][]byte{v.B}} + case *types.AttributeValueMemberN: + v = &types.AttributeValueMemberNS{Value: []string{t.Value}} + case *types.AttributeValueMemberS: + v = &types.AttributeValueMemberSS{Value: []string{t.Value}} + case *types.AttributeValueMemberB: + v = &types.AttributeValueMemberBS{Value: [][]byte{t.Value}} default: u.setError(fmt.Errorf("dynamo: Update.DeleteFromSet given unsupported value: %v (%T: %s)", value, value, avTypeName(v))) @@ -289,14 +287,7 @@ func (u *Update) ConsumedCapacity(cc *ConsumedCapacity) *Update { } // Run executes this update. -func (u *Update) Run() error { - ctx, cancel := defaultContext() - defer cancel() - return u.RunWithContext(ctx) -} - -// RunWithContext executes this update. -func (u *Update) RunWithContext(ctx context.Context) error { +func (u *Update) Run(ctx context.Context) error { u.returnType = "NONE" _, err := u.run(ctx) return err @@ -304,15 +295,7 @@ func (u *Update) RunWithContext(ctx context.Context) error { // Value executes this update, encoding out with the new value after the update. // This is equivalent to ReturnValues = ALL_NEW in the DynamoDB API. -func (u *Update) Value(out interface{}) error { - ctx, cancel := defaultContext() - defer cancel() - return u.ValueWithContext(ctx, out) -} - -// ValueWithContext executes this update, encoding out with the new value after the update. -// This is equivalent to ReturnValues = ALL_NEW in the DynamoDB API. -func (u *Update) ValueWithContext(ctx context.Context, out interface{}) error { +func (u *Update) Value(ctx context.Context, out interface{}) error { u.returnType = "ALL_NEW" output, err := u.run(ctx) if err != nil { @@ -323,15 +306,7 @@ func (u *Update) ValueWithContext(ctx context.Context, out interface{}) error { // OldValue executes this update, encoding out with the old value before the update. // This is equivalent to ReturnValues = ALL_OLD in the DynamoDB API. -func (u *Update) OldValue(out interface{}) error { - ctx, cancel := defaultContext() - defer cancel() - return u.OldValueWithContext(ctx, out) -} - -// OldValueWithContext executes this update, encoding out with the old value before the update. -// This is equivalent to ReturnValues = ALL_OLD in the DynamoDB API. -func (u *Update) OldValueWithContext(ctx context.Context, out interface{}) error { +func (u *Update) OldValue(ctx context.Context, out interface{}) error { u.returnType = "ALL_OLD" output, err := u.run(ctx) if err != nil { @@ -342,15 +317,7 @@ func (u *Update) OldValueWithContext(ctx context.Context, out interface{}) error // OnlyUpdatedValue executes this update, encoding out with only with new values of the attributes that were changed. // This is equivalent to ReturnValues = UPDATED_NEW in the DynamoDB API. -func (u *Update) OnlyUpdatedValue(out interface{}) error { - ctx, cancel := defaultContext() - defer cancel() - return u.OnlyUpdatedValueWithContext(ctx, out) -} - -// OnlyUpdatedValueWithContext executes this update, encoding out with only with new values of the attributes that were changed. -// This is equivalent to ReturnValues = UPDATED_NEW in the DynamoDB API. -func (u *Update) OnlyUpdatedValueWithContext(ctx context.Context, out interface{}) error { +func (u *Update) OnlyUpdatedValue(ctx context.Context, out interface{}) error { u.returnType = "UPDATED_NEW" output, err := u.run(ctx) if err != nil { @@ -361,15 +328,7 @@ func (u *Update) OnlyUpdatedValueWithContext(ctx context.Context, out interface{ // OnlyUpdatedOldValue executes this update, encoding out with only with old values of the attributes that were changed. // This is equivalent to ReturnValues = UPDATED_OLD in the DynamoDB API. -func (u *Update) OnlyUpdatedOldValue(out interface{}) error { - ctx, cancel := defaultContext() - defer cancel() - return u.OnlyUpdatedOldValueWithContext(ctx, out) -} - -// OnlyUpdatedOldValueWithContext executes this update, encoding out with only with old values of the attributes that were changed. -// This is equivalent to ReturnValues = UPDATED_OLD in the DynamoDB API. -func (u *Update) OnlyUpdatedOldValueWithContext(ctx context.Context, out interface{}) error { +func (u *Update) OnlyUpdatedOldValue(ctx context.Context, out interface{}) error { u.returnType = "UPDATED_OLD" output, err := u.run(ctx) if err != nil { @@ -387,7 +346,7 @@ func (u *Update) run(ctx context.Context) (*dynamodb.UpdateItemOutput, error) { var output *dynamodb.UpdateItemOutput err := u.table.db.retry(ctx, func() error { var err error - output, err = u.table.db.client.UpdateItemWithContext(ctx, input) + output, err = u.table.db.client.UpdateItem(ctx, input) return err }) if u.cc != nil { @@ -403,24 +362,24 @@ func (u *Update) updateInput() *dynamodb.UpdateItemInput { UpdateExpression: u.updateExpr(), ExpressionAttributeNames: u.nameExpr, ExpressionAttributeValues: u.valueExpr, - ReturnValues: &u.returnType, + ReturnValues: types.ReturnValue(u.returnType), } if u.condition != "" { input.ConditionExpression = &u.condition } if u.cc != nil { - input.ReturnConsumedCapacity = aws.String(dynamodb.ReturnConsumedCapacityIndexes) + input.ReturnConsumedCapacity = types.ReturnConsumedCapacityIndexes } return input } -func (u *Update) writeTxItem() (*dynamodb.TransactWriteItem, error) { +func (u *Update) writeTxItem() (*types.TransactWriteItem, error) { if u.err != nil { return nil, u.err } input := u.updateInput() - item := &dynamodb.TransactWriteItem{ - Update: &dynamodb.Update{ + item := &types.TransactWriteItem{ + Update: &types.Update{ TableName: input.TableName, Key: input.Key, UpdateExpression: input.UpdateExpression, @@ -433,8 +392,8 @@ func (u *Update) writeTxItem() (*dynamodb.TransactWriteItem, error) { return item, nil } -func (u *Update) key() map[string]*dynamodb.AttributeValue { - key := map[string]*dynamodb.AttributeValue{ +func (u *Update) key() Item { + key := Item{ u.hashKey: u.hashValue, } if u.rangeKey != "" { diff --git a/update_test.go b/update_test.go index ae22824..fce154e 100644 --- a/update_test.go +++ b/update_test.go @@ -1,12 +1,13 @@ package dynamo import ( + "context" "reflect" "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) func TestUpdate(t *testing.T) { @@ -14,6 +15,7 @@ func TestUpdate(t *testing.T) { t.Skip(offlineSkipMsg) } table := testDB.Table(testTableWidgets) + ctx := context.TODO() type widget2 struct { widget @@ -40,7 +42,7 @@ func TestUpdate(t *testing.T) { MySet2: map[string]struct{}{"a": {}, "b": {}, "bad1": {}, "c": {}, "bad2": {}}, MySet3: map[int64]struct{}{1: {}, 999: {}, 2: {}, 3: {}, 555: {}}, } - err := table.Put(item).Run() + err := table.Put(item).Run(ctx) if err != nil { t.Error("unexpected error:", err) } @@ -51,8 +53,8 @@ func TestUpdate(t *testing.T) { "#meta": "Meta", "#pet": "pet", }), - AttributeValues: map[string]*dynamodb.AttributeValue{ - ":cat": {S: aws.String("猫")}, + AttributeValues: Item{ + ":cat": &types.AttributeValueMemberS{Value: "猫"}, }, } rmLit := ExpressionLiteral{ @@ -67,8 +69,8 @@ func TestUpdate(t *testing.T) { AttributeNames: aws.StringMap(map[string]string{ "#msg": "Msg", }), - AttributeValues: map[string]*dynamodb.AttributeValue{ - ":hi": {S: aws.String("hello")}, + AttributeValues: Item{ + ":hi": &types.AttributeValueMemberS{Value: "hello"}, }, } @@ -90,7 +92,7 @@ func TestUpdate(t *testing.T) { DeleteFromSet("MySet2", []string{"bad1", "bad2"}). DeleteFromSet("MySet3", map[int64]struct{}{999: {}, 555: {}}). ConsumedCapacity(&cc). - Value(&result) + Value(ctx, &result) expected := widget2{ widget: widget{ @@ -130,7 +132,7 @@ func TestUpdate(t *testing.T) { Range("Time", item.Time). Set("Msg", expected2.Msg). Add("Count", 1). - OnlyUpdatedValue(&updated) + OnlyUpdatedValue(ctx, &updated) if err != nil { t.Error("unexpected error:", err) } @@ -143,7 +145,7 @@ func TestUpdate(t *testing.T) { Range("Time", item.Time). Set("Msg", "this shouldn't be seen"). Add("Count", 100). - OnlyUpdatedOldValue(&updatedOld) + OnlyUpdatedOldValue(ctx, &updatedOld) if err != nil { t.Error("unexpected error:", err) } @@ -158,7 +160,7 @@ func TestUpdate(t *testing.T) { Add("Count", 1). If("'Count' > ?", 100). If("(MeaningOfLife = ?)", 42). - Value(&result) + Value(ctx, &result) if !IsCondCheckFailed(err) { t.Error("expected ConditionalCheckFailedException, not", err) } @@ -169,6 +171,7 @@ func TestUpdateNil(t *testing.T) { t.Skip(offlineSkipMsg) } table := testDB.Table(testTableWidgets) + ctx := context.TODO() // first, add an item to make sure there is at least one item := widget{ @@ -180,7 +183,7 @@ func TestUpdateNil(t *testing.T) { }, Count: 100, } - err := table.Put(item).Run() + err := table.Put(item).Run(ctx) if err != nil { t.Error("unexpected error:", err) t.FailNow() @@ -199,7 +202,7 @@ func TestUpdateNil(t *testing.T) { Set("Meta.'ok'", (*ptrTextMarshaler)(nil)). SetExpr("'Count' = ?", (*textMarshaler)(nil)). SetExpr("MsgPtr = ?", ""). - Value(&result) + Value(ctx, &result) if err != nil { t.Error("unexpected error:", err) } @@ -224,6 +227,7 @@ func TestUpdateSetAutoOmit(t *testing.T) { t.Skip(offlineSkipMsg) } table := testDB.Table(testTableWidgets) + ctx := context.TODO() type widget2 struct { widget @@ -241,7 +245,7 @@ func TestUpdateSetAutoOmit(t *testing.T) { CStr: customString("delete me"), SPtr: &str, } - err := table.Put(item).Run() + err := table.Put(item).Run(ctx) if err != nil { t.Error("unexpected error:", err) t.FailNow() @@ -252,7 +256,7 @@ func TestUpdateSetAutoOmit(t *testing.T) { err = table.Update("UserID", item.UserID).Range("Time", item.Time). Set("CStr", customString("")). Set("SPtr", nil). - Value(&result) + Value(ctx, &result) if err != nil { t.Error("unexpected error:", err) } diff --git a/updatetable.go b/updatetable.go index 9563659..1b6eae3 100644 --- a/updatetable.go +++ b/updatetable.go @@ -4,8 +4,9 @@ import ( "context" "errors" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) // UpdateTable is a request to change a table's settings. @@ -14,7 +15,7 @@ type UpdateTable struct { table Table r, w int64 // throughput - billingMode *string + billingMode types.BillingMode disableStream bool streamView StreamView @@ -22,7 +23,7 @@ type UpdateTable struct { updateIdx map[string]Throughput createIdx []Index deleteIdx []string - ads []*dynamodb.AttributeDefinition + ads []types.AttributeDefinition err error } @@ -39,9 +40,9 @@ func (table Table) UpdateTable() *UpdateTable { // If enabled is false, this table will be changed to provisioned billing mode. func (ut *UpdateTable) OnDemand(enabled bool) *UpdateTable { if enabled { - ut.billingMode = aws.String(dynamodb.BillingModePayPerRequest) + ut.billingMode = types.BillingModePayPerRequest } else { - ut.billingMode = aws.String(dynamodb.BillingModeProvisioned) + ut.billingMode = types.BillingModeProvisioned } return ut } @@ -106,13 +107,7 @@ func (ut *UpdateTable) DisableStream() *UpdateTable { } // Run executes this request and describes the table. -func (ut *UpdateTable) Run() (Description, error) { - ctx, cancel := defaultContext() - defer cancel() - return ut.RunWithContext(ctx) -} - -func (ut *UpdateTable) RunWithContext(ctx context.Context) (Description, error) { +func (ut *UpdateTable) Run(ctx context.Context) (Description, error) { if ut.err != nil { return Description{}, ut.err } @@ -122,7 +117,7 @@ func (ut *UpdateTable) RunWithContext(ctx context.Context) (Description, error) var result *dynamodb.UpdateTableOutput err := ut.table.db.retry(ctx, func() error { var err error - result, err = ut.table.db.client.UpdateTableWithContext(ctx, input) + result, err = ut.table.db.client.UpdateTable(ctx, input) return err }) if err != nil { @@ -140,27 +135,27 @@ func (ut *UpdateTable) input() *dynamodb.UpdateTableInput { } if ut.r != 0 || ut.w != 0 { - input.ProvisionedThroughput = &dynamodb.ProvisionedThroughput{ + input.ProvisionedThroughput = &types.ProvisionedThroughput{ ReadCapacityUnits: &ut.r, WriteCapacityUnits: &ut.w, } } if ut.disableStream { - input.StreamSpecification = &dynamodb.StreamSpecification{ + input.StreamSpecification = &types.StreamSpecification{ StreamEnabled: aws.Bool(false), } } else if ut.streamView != "" { - input.StreamSpecification = &dynamodb.StreamSpecification{ + input.StreamSpecification = &types.StreamSpecification{ StreamEnabled: aws.Bool(true), - StreamViewType: aws.String((string)(ut.streamView)), + StreamViewType: types.StreamViewType(ut.streamView), } } for index, thru := range ut.updateIdx { - up := &dynamodb.GlobalSecondaryIndexUpdate{Update: &dynamodb.UpdateGlobalSecondaryIndexAction{ + up := types.GlobalSecondaryIndexUpdate{Update: &types.UpdateGlobalSecondaryIndexAction{ IndexName: aws.String(index), - ProvisionedThroughput: &dynamodb.ProvisionedThroughput{ + ProvisionedThroughput: &types.ProvisionedThroughput{ ReadCapacityUnits: aws.Int64(thru.Read), WriteCapacityUnits: aws.Int64(thru.Write), }, @@ -168,11 +163,11 @@ func (ut *UpdateTable) input() *dynamodb.UpdateTableInput { input.GlobalSecondaryIndexUpdates = append(input.GlobalSecondaryIndexUpdates, up) } for _, index := range ut.createIdx { - up := &dynamodb.GlobalSecondaryIndexUpdate{Create: createIndexAction(index)} + up := types.GlobalSecondaryIndexUpdate{Create: createIndexAction(index)} input.GlobalSecondaryIndexUpdates = append(input.GlobalSecondaryIndexUpdates, up) } for _, del := range ut.deleteIdx { - up := &dynamodb.GlobalSecondaryIndexUpdate{Delete: &dynamodb.DeleteGlobalSecondaryIndexAction{ + up := types.GlobalSecondaryIndexUpdate{Delete: &types.DeleteGlobalSecondaryIndexAction{ IndexName: aws.String(del), }} input.GlobalSecondaryIndexUpdates = append(input.GlobalSecondaryIndexUpdates, up) @@ -187,40 +182,40 @@ func (ut *UpdateTable) addAD(name string, typ KeyType) { } } - ut.ads = append(ut.ads, &dynamodb.AttributeDefinition{ + ut.ads = append(ut.ads, types.AttributeDefinition{ AttributeName: &name, - AttributeType: aws.String((string)(typ)), + AttributeType: types.ScalarAttributeType(typ), }) } -func createIndexAction(index Index) *dynamodb.CreateGlobalSecondaryIndexAction { - ks := []*dynamodb.KeySchemaElement{ +func createIndexAction(index Index) *types.CreateGlobalSecondaryIndexAction { + ks := []types.KeySchemaElement{ { AttributeName: &index.HashKey, - KeyType: aws.String(dynamodb.KeyTypeHash), + KeyType: types.KeyTypeHash, }, } if index.RangeKey != "" { - ks = append(ks, &dynamodb.KeySchemaElement{ + ks = append(ks, types.KeySchemaElement{ AttributeName: &index.RangeKey, - KeyType: aws.String(dynamodb.KeyTypeRange), + KeyType: types.KeyTypeRange, }) } - add := &dynamodb.CreateGlobalSecondaryIndexAction{ + add := &types.CreateGlobalSecondaryIndexAction{ IndexName: &index.Name, KeySchema: ks, - Projection: &dynamodb.Projection{ - ProjectionType: aws.String((string)(index.ProjectionType)), + Projection: &types.Projection{ + ProjectionType: types.ProjectionType(index.ProjectionType), }, } if index.Throughput.Read > 0 && index.Throughput.Write > 0 { - add.ProvisionedThroughput = &dynamodb.ProvisionedThroughput{ + add.ProvisionedThroughput = &types.ProvisionedThroughput{ ReadCapacityUnits: aws.Int64(index.Throughput.Read), WriteCapacityUnits: aws.Int64(index.Throughput.Write), } } if index.ProjectionType == IncludeProjection { - add.Projection.NonKeyAttributes = aws.StringSlice(index.ProjectionAttribs) + add.Projection.NonKeyAttributes = index.ProjectionAttribs } return add } diff --git a/updatetable_test.go b/updatetable_test.go index 472b641..1fa5913 100644 --- a/updatetable_test.go +++ b/updatetable_test.go @@ -1,6 +1,7 @@ package dynamo import ( + "context" "testing" ) @@ -10,6 +11,7 @@ func _TestUpdateTable(t *testing.T) { t.Skip(offlineSkipMsg) } table := testDB.Table(testTableWidgets) + ctx := context.TODO() desc, err := table.UpdateTable().CreateIndex(Index{ Name: "test123", @@ -23,7 +25,7 @@ func _TestUpdateTable(t *testing.T) { Read: 1, Write: 1, }, - }).Run() + }).Run(ctx) // desc, err := table.UpdateTable().DeleteIndex("test123").Run()