diff --git a/pkg/directory/v3/model.go b/pkg/directory/v3/model.go index 75be189..1112a63 100644 --- a/pkg/directory/v3/model.go +++ b/pkg/directory/v3/model.go @@ -81,6 +81,7 @@ func (s *Model) GetManifest(req *dsm3.GetManifestRequest, stream dsm3.Model_GetM return err } + // optimistic concurrency check inMD, _ := metadata.FromIncomingContext(stream.Context()) if lo.Contains(inMD.Get(headers.IfNoneMatch), manifest.Metadata.Etag) { return nil @@ -145,6 +146,7 @@ func (s *Model) SetManifest(stream dsm3.Model_SetManifestServer) error { logger := s.logger.With().Str("method", "SetManifest").Logger() logger.Trace().Send() + // optimistic concurrency check etag := metautils.ExtractIncoming(stream.Context()).Get(headers.IfMatch) if etag != "" && etag != s.store.MC().Metadata().ETag { return derr.ErrHashMismatch @@ -236,6 +238,20 @@ func (s *Model) DeleteManifest(ctx context.Context, req *dsm3.DeleteManifestRequ } if err := s.store.DB().Update(func(tx *bolt.Tx) error { + // optimistic concurrency check + ifMatchHeader := metautils.ExtractIncoming(ctx).Get(headers.IfMatch) + if ifMatchHeader != "" { + dbMd := &dsm3.Metadata{UpdatedAt: timestamppb.Now(), Etag: ""} + manifest, err := ds.Manifest(dbMd).Get(ctx, tx) + if err != nil { + return nil + } + + if ifMatchHeader != manifest.Metadata.Etag { + return derr.ErrHashMismatch + } + } + if err := ds.Manifest(&dsm3.Metadata{}).Delete(ctx, tx); err != nil { return derr.ErrUnknown.Msgf("failed to delete manifest: %s", err.Error()) } diff --git a/pkg/directory/v3/reader.go b/pkg/directory/v3/reader.go index d6b07b7..4ad9f69 100644 --- a/pkg/directory/v3/reader.go +++ b/pkg/directory/v3/reader.go @@ -8,6 +8,10 @@ import ( "github.com/aserto-dev/go-directory/pkg/derr" "github.com/aserto-dev/go-edge-ds/pkg/bdb" "github.com/aserto-dev/go-edge-ds/pkg/ds" + "github.com/go-http-utils/headers" + "github.com/samber/lo" + "google.golang.org/grpc" + grpcmd "google.golang.org/grpc/metadata" "github.com/bufbuild/protovalidate-go" "github.com/rs/zerolog" @@ -45,6 +49,14 @@ func (s *Reader) GetObject(ctx context.Context, req *dsr3.GetObjectRequest) (*ds return err } + inMD, _ := grpcmd.FromIncomingContext(ctx) + // optimistic concurrency check + if lo.Contains(inMD.Get(headers.IfNoneMatch), obj.Etag) { + _ = grpc.SetHeader(ctx, grpcmd.Pairs("x-http-code", "304")) + + return nil + } + if req.GetWithRelations() { // incoming object relations of object instance (result.type == incoming.subject.type && result.key == incoming.subject.key) incoming, err := bdb.Scan[dsc3.Relation](ctx, tx, bdb.RelationsSubPath, ds.Object(obj).Key()) @@ -170,12 +182,20 @@ func (s *Reader) GetRelation(ctx context.Context, req *dsr3.GetRelationRequest) return bdb.ErrMultipleResults } - result := relations[0] - resp.Result = result + dbRel := relations[0] + resp.Result = dbRel + + inMD, _ := grpcmd.FromIncomingContext(ctx) + // optimistic concurrency check + if lo.Contains(inMD.Get(headers.IfNoneMatch), dbRel.Etag) { + _ = grpc.SetHeader(ctx, grpcmd.Pairs("x-http-code", "304")) + + return nil + } if req.GetWithObjects() { objects := map[string]*dsc3.Object{} - rel := ds.Relation(result) + rel := ds.Relation(dbRel) sub, err := bdb.Get[dsc3.Object](ctx, tx, bdb.ObjectsPath, ds.ObjectIdentifier(rel.Subject()).Key()) if err != nil { diff --git a/pkg/directory/v3/writer.go b/pkg/directory/v3/writer.go index eb71bd4..2110a21 100644 --- a/pkg/directory/v3/writer.go +++ b/pkg/directory/v3/writer.go @@ -8,6 +8,8 @@ import ( "github.com/aserto-dev/go-directory/pkg/derr" "github.com/aserto-dev/go-edge-ds/pkg/bdb" "github.com/aserto-dev/go-edge-ds/pkg/ds" + "github.com/go-http-utils/headers" + "github.com/grpc-ecosystem/go-grpc-middleware/util/metautils" "google.golang.org/protobuf/types/known/emptypb" "github.com/bufbuild/protovalidate-go" @@ -41,20 +43,27 @@ func (s *Writer) SetObject(ctx context.Context, req *dsw3.SetObjectRequest) (*ds etag := ds.Object(req.Object).Hash() err := s.store.DB().Update(func(tx *bolt.Tx) error { - updReq, err := bdb.UpdateMetadata(ctx, tx, bdb.ObjectsPath, ds.Object(req.Object).Key(), req.Object) + updObj, err := bdb.UpdateMetadata(ctx, tx, bdb.ObjectsPath, ds.Object(req.Object).Key(), req.Object) if err != nil { return err } - if etag == updReq.Etag { + // optimistic concurrency check + ifMatchHeader := metautils.ExtractIncoming(ctx).Get(headers.IfMatch) + // if the updReq.Etag == "" this means the this is an insert + if ifMatchHeader != "" && updObj.Etag != "" && ifMatchHeader != updObj.Etag { + return derr.ErrHashMismatch.Msgf("for object with type [%s] and id [%s]", updObj.Type, updObj.Id) + } + + if etag == updObj.Etag { s.logger.Trace().Str("key", ds.Object(req.Object).Key()).Str("etag-equal", etag).Msg("set_object") - resp.Result = updReq + resp.Result = updObj return nil } - updReq.Etag = etag + updObj.Etag = etag - objType, err := bdb.Set(ctx, tx, bdb.ObjectsPath, ds.Object(req.Object).Key(), updReq) + objType, err := bdb.Set(ctx, tx, bdb.ObjectsPath, ds.Object(req.Object).Key(), updObj) if err != nil { return err } @@ -74,15 +83,30 @@ func (s *Writer) DeleteObject(ctx context.Context, req *dsw3.DeleteObjectRequest } err := s.store.DB().Update(func(tx *bolt.Tx) error { - objIdent := &dsc3.ObjectIdentifier{ObjectType: req.GetObjectType(), ObjectId: req.GetObjectId()} - if err := bdb.Delete(ctx, tx, bdb.ObjectsPath, ds.ObjectIdentifier(objIdent).Key()); err != nil { + objIdent := ds.ObjectIdentifier(&dsc3.ObjectIdentifier{ObjectType: req.ObjectType, ObjectId: req.ObjectId}) + + // optimistic concurrency check + ifMatchHeader := metautils.ExtractIncoming(ctx).Get(headers.IfMatch) + if ifMatchHeader != "" { + obj := &dsc3.Object{Type: req.ObjectType, Id: req.ObjectId} + updObj, err := bdb.UpdateMetadata(ctx, tx, bdb.ObjectsPath, ds.Object(obj).Key(), obj) + if err != nil { + return err + } + + if ifMatchHeader != updObj.Etag { + return derr.ErrHashMismatch.Msgf("for object with type [%s] and id [%s]", updObj.Type, updObj.Id) + } + } + + if err := bdb.Delete(ctx, tx, bdb.ObjectsPath, objIdent.Key()); err != nil { return err } if req.GetWithRelations() { { // incoming object relations of object instance (result.type == incoming.subject.type && result.key == incoming.subject.key) - iter, err := bdb.NewScanIterator[dsc3.Relation](ctx, tx, bdb.RelationsSubPath, bdb.WithKeyFilter(ds.ObjectIdentifier(objIdent).Key()+ds.InstanceSeparator)) + iter, err := bdb.NewScanIterator[dsc3.Relation](ctx, tx, bdb.RelationsSubPath, bdb.WithKeyFilter(objIdent.Key()+ds.InstanceSeparator)) if err != nil { return err } @@ -100,7 +124,7 @@ func (s *Writer) DeleteObject(ctx context.Context, req *dsw3.DeleteObjectRequest } { // outgoing object relations of object instance (result.type == outgoing.object.type && result.key == outgoing.object.key) - iter, err := bdb.NewScanIterator[dsc3.Relation](ctx, tx, bdb.RelationsObjPath, bdb.WithKeyFilter(ds.ObjectIdentifier(objIdent).Key()+ds.InstanceSeparator)) + iter, err := bdb.NewScanIterator[dsc3.Relation](ctx, tx, bdb.RelationsObjPath, bdb.WithKeyFilter(objIdent.Key()+ds.InstanceSeparator)) if err != nil { return err } @@ -137,25 +161,32 @@ func (s *Writer) SetRelation(ctx context.Context, req *dsw3.SetRelationRequest) etag := ds.Relation(req.Relation).Hash() err := s.store.DB().Update(func(tx *bolt.Tx) error { - updReq, err := bdb.UpdateMetadata(ctx, tx, bdb.RelationsObjPath, ds.Relation(req.Relation).ObjKey(), req.Relation) + updRel, err := bdb.UpdateMetadata(ctx, tx, bdb.RelationsObjPath, ds.Relation(req.Relation).ObjKey(), req.Relation) if err != nil { return err } - if etag == updReq.Etag { + // optimistic concurrency check + ifMatchHeader := metautils.ExtractIncoming(ctx).Get(headers.IfMatch) + // if the updReq.Etag == "" this means the this is an insert + if ifMatchHeader != "" && updRel.Etag != "" && ifMatchHeader != updRel.Etag { + return derr.ErrHashMismatch.Msgf("for relation with objectType [%s], objectId [%s], relation [%s], subjectType [%s], SubjectId [%s]", updRel.ObjectType, updRel.ObjectId, updRel.Relation, updRel.SubjectType, updRel.SubjectId) + } + + if etag == updRel.Etag { s.logger.Trace().Str("key", ds.Relation(req.Relation).ObjKey()).Str("etag-equal", etag).Msg("set_relation") - resp.Result = updReq + resp.Result = updRel return nil } - updReq.Etag = etag + updRel.Etag = etag - objRel, err := bdb.Set(ctx, tx, bdb.RelationsObjPath, ds.Relation(req.Relation).ObjKey(), updReq) + objRel, err := bdb.Set(ctx, tx, bdb.RelationsObjPath, ds.Relation(req.Relation).ObjKey(), updRel) if err != nil { return err } - if _, err := bdb.Set(ctx, tx, bdb.RelationsSubPath, ds.Relation(req.Relation).SubKey(), updReq); err != nil { + if _, err := bdb.Set(ctx, tx, bdb.RelationsSubPath, ds.Relation(req.Relation).SubKey(), updRel); err != nil { return err } @@ -175,20 +206,35 @@ func (s *Writer) DeleteRelation(ctx context.Context, req *dsw3.DeleteRelationReq } err := s.store.DB().Update(func(tx *bolt.Tx) error { - rel := ds.Relation(&dsc3.Relation{ + rel := &dsc3.Relation{ ObjectType: req.ObjectType, ObjectId: req.ObjectId, Relation: req.Relation, SubjectType: req.SubjectType, SubjectId: req.SubjectId, SubjectRelation: req.SubjectRelation, - }) + } + + dsRel := ds.Relation(rel) + + // optimistic concurrency check + ifMatchHeader := metautils.ExtractIncoming(ctx).Get(headers.IfMatch) + if ifMatchHeader != "" { + updRel, err := bdb.UpdateMetadata(ctx, tx, bdb.RelationsObjPath, dsRel.ObjKey(), rel) + if err != nil { + return err + } + + if ifMatchHeader != updRel.Etag { + return derr.ErrHashMismatch.Msgf("for relation with objectType [%s], objectId [%s], relation [%s], subjectType [%s], SubjectId [%s]", rel.ObjectType, rel.ObjectId, rel.Relation, rel.SubjectType, rel.SubjectId) + } + } - if err := bdb.Delete(ctx, tx, bdb.RelationsObjPath, rel.ObjKey()); err != nil { + if err := bdb.Delete(ctx, tx, bdb.RelationsObjPath, dsRel.ObjKey()); err != nil { return err } - if err := bdb.Delete(ctx, tx, bdb.RelationsSubPath, rel.SubKey()); err != nil { + if err := bdb.Delete(ctx, tx, bdb.RelationsSubPath, dsRel.SubKey()); err != nil { return err }