Skip to content

Commit

Permalink
Use sync.Pool for Relation slices. (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
ronenh authored Dec 6, 2024
1 parent 830edfe commit 3a4bb96
Show file tree
Hide file tree
Showing 11 changed files with 156 additions and 55 deletions.
12 changes: 8 additions & 4 deletions cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package cache
import (
"sync"

"github.com/aserto-dev/azm/graph"
"github.com/aserto-dev/azm/internal/mempool"
"github.com/aserto-dev/azm/model"
"github.com/aserto-dev/azm/model/diff"
stts "github.com/aserto-dev/azm/stats"
Expand All @@ -17,15 +19,17 @@ type (
)

type Cache struct {
model *model.Model
mtx sync.RWMutex
model *model.Model
mtx sync.RWMutex
relsPool *graph.RelationsPool
}

// New, create new model cache instance.
func New(m *model.Model) *Cache {
return &Cache{
model: m,
mtx: sync.RWMutex{},
model: m,
mtx: sync.RWMutex{},
relsPool: mempool.NewSlicePool[*dsc.Relation](),
}
}

Expand Down
12 changes: 9 additions & 3 deletions cache/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ import (
)

func (c *Cache) Check(req *dsr.CheckRequest, relReader graph.RelationReader) (*dsr.CheckResponse, error) {
checker := graph.NewCheck(c.model, req, relReader)
c.mtx.RLock()
defer c.mtx.RUnlock()

checker := graph.NewCheck(c.model, req, relReader, c.relsPool)

ctx := pb.NewStruct()

Expand All @@ -26,15 +29,18 @@ type graphSearch interface {
}

func (c *Cache) GetGraph(req *dsr.GetGraphRequest, relReader graph.RelationReader) (*dsr.GetGraphResponse, error) {
c.mtx.RLock()
defer c.mtx.RUnlock()

var (
search graphSearch
err error
)

if req.ObjectId == "" {
search, err = graph.NewObjectSearch(c.model, req, relReader)
search, err = graph.NewObjectSearch(c.model, req, relReader, c.relsPool)
} else {
search, err = graph.NewSubjectSearch(c.model, req, relReader)
search, err = graph.NewSubjectSearch(c.model, req, relReader, c.relsPool)
}

if err != nil {
Expand Down
37 changes: 27 additions & 10 deletions graph/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ type Checker struct {
getRels RelationReader

memo *checkMemo
pool *RelationsPool
}

func NewCheck(m *model.Model, req *dsr.CheckRequest, reader RelationReader) *Checker {
func NewCheck(m *model.Model, req *dsr.CheckRequest, reader RelationReader, pool *RelationsPool) *Checker {
return &Checker{
m: m,
params: &relation{
Expand All @@ -29,6 +30,7 @@ func NewCheck(m *model.Model, req *dsr.CheckRequest, reader RelationReader) *Che
},
getRels: reader,
memo: newCheckMemo(req.Trace),
pool: pool,
}
}

Expand Down Expand Up @@ -88,7 +90,16 @@ func (c *Checker) checkRelation(params *relation) (checkStatus, error) {
r := c.m.Objects[params.ot].Relations[params.rel]
steps := c.m.StepRelation(r, params.st)

// Reuse the same slice in all steps.
relsPtr := c.pool.Get()
defer func() {
*relsPtr = (*relsPtr)[:0]
c.pool.Put(relsPtr)
}()

for _, step := range steps {
*relsPtr = (*relsPtr)[:0]

req := &dsc.Relation{
ObjectType: params.ot.String(),
ObjectId: params.oid.String(),
Expand All @@ -103,27 +114,26 @@ func (c *Checker) checkRelation(params *relation) (checkStatus, error) {
req.SubjectRelation = step.Relation.String()
}

rels, err := c.getRels(req)
if err != nil {
if err := c.getRels(req, relsPtr); err != nil {
return checkStatusFalse, err
}

switch {
case step.IsDirect():
for _, rel := range rels {
for _, rel := range *relsPtr {
if rel.SubjectId == params.sid.String() {
return checkStatusTrue, nil
}
}

case step.IsWildcard():
if len(rels) > 0 {
if len(*relsPtr) > 0 {
// We have a wildcard match.
return checkStatusTrue, nil
}

case step.IsSubject():
for _, rel := range rels {
for _, rel := range *relsPtr {
if status, err := c.check(&relation{
ot: step.Object,
oid: ObjectID(rel.SubjectId),
Expand Down Expand Up @@ -190,17 +200,21 @@ func (c *Checker) checkPermission(params *relation) (checkStatus, error) {

func (c *Checker) expandTerm(pt *model.PermissionTerm, params *relation) (relations, error) {
if pt.IsArrow() {
// Resolve the base of the arrow.
rels, err := c.getRels(&dsc.Relation{
query := &dsc.Relation{
ObjectType: params.ot.String(),
ObjectId: params.oid.String(),
Relation: pt.Base.String(),
})
}

relsPtr := c.pool.Get()

// Resolve the base of the arrow.
err := c.getRels(query, relsPtr)
if err != nil {
return relations{}, err
}

expanded := lo.Map(rels, func(rel *dsc.Relation, _ int) *relation {
expanded := lo.Map(*relsPtr, func(rel *dsc.Relation, _ int) *relation {
return &relation{
ot: model.ObjectName(rel.SubjectType),
oid: ObjectID(rel.SubjectId),
Expand All @@ -210,6 +224,9 @@ func (c *Checker) expandTerm(pt *model.PermissionTerm, params *relation) (relati
}
})

*relsPtr = (*relsPtr)[:0]
c.pool.Put(relsPtr)

return expanded, nil
}

Expand Down
6 changes: 5 additions & 1 deletion graph/check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import (
"testing"

azmgraph "github.com/aserto-dev/azm/graph"
"github.com/aserto-dev/azm/internal/mempool"
v3 "github.com/aserto-dev/azm/v3"
dsc "github.com/aserto-dev/go-directory/aserto/directory/common/v3"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -67,11 +69,13 @@ func TestCheck(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, m)

pool := mempool.NewSlicePool[*dsc.Relation]()

for _, test := range tests {
t.Run(test.check, func(tt *testing.T) {
assert := assert.New(tt)

checker := azmgraph.NewCheck(m, checkReq(test.check), rels.GetRelations)
checker := azmgraph.NewCheck(m, checkReq(test.check), rels.GetRelations, pool)

res, err := checker.Check()
assert.NoError(err)
Expand Down
20 changes: 12 additions & 8 deletions graph/objects.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type ObjectSearch struct {
wildcardSearch *SubjectSearch
}

func NewObjectSearch(m *model.Model, req *dsr.GetGraphRequest, reader RelationReader) (*ObjectSearch, error) {
func NewObjectSearch(m *model.Model, req *dsr.GetGraphRequest, reader RelationReader, pool *RelationsPool) (*ObjectSearch, error) {
params := searchParams(req)
if err := validate(m, params); err != nil {
return nil, err
Expand All @@ -40,13 +40,15 @@ func NewObjectSearch(m *model.Model, req *dsr.GetGraphRequest, reader RelationRe
getRels: invertedRelationReader(im, reader),
memo: newSearchMemo(req.Trace),
explain: req.Explain,
pool: pool,
}},
wildcardSearch: &SubjectSearch{graphSearch{
m: im,
params: wildcardParams(iParams),
getRels: invertedRelationReader(im, reader),
memo: newSearchMemo(req.Trace),
explain: req.Explain,
pool: pool,
}},
}, nil
}
Expand Down Expand Up @@ -125,22 +127,24 @@ func wildcardParams(params *relation) *relation {
}

func invertedRelationReader(m *model.Model, reader RelationReader) RelationReader {
return func(r *dsc.Relation) ([]*dsc.Relation, error) {
return func(r *dsc.Relation, out *Relations) error {
ir := uninvertRelation(m, relationFromProto(r))
res, err := reader(ir.asProto())
if err != nil {
return nil, err
if err := reader(ir.asProto(), out); err != nil {
return err
}

return lo.Map(res, func(r *dsc.Relation, _ int) *dsc.Relation {
return &dsc.Relation{
res := *out
for i, r := range res {
res[i] = &dsc.Relation{
ObjectType: r.SubjectType,
ObjectId: r.SubjectId,
Relation: r.Relation,
SubjectType: r.ObjectType,
SubjectId: r.ObjectId,
}
}), nil
}

return nil
}
}

Expand Down
5 changes: 4 additions & 1 deletion graph/objects_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

"github.com/aserto-dev/azm/graph"
"github.com/aserto-dev/azm/internal/mempool"
"github.com/aserto-dev/azm/model"
v3 "github.com/aserto-dev/azm/v3"
dsc "github.com/aserto-dev/go-directory/aserto/directory/common/v3"
Expand Down Expand Up @@ -35,11 +36,13 @@ func TestSearchObjects(t *testing.T) {
im.Validate(model.SkipNameValidation, model.AllowPermissionInArrowBase),
)

pool := mempool.NewSlicePool[*dsc.Relation]()

for _, test := range searchObjectsTests {
t.Run(test.search, func(tt *testing.T) {
assert := assert.New(tt)

objSearch, err := graph.NewObjectSearch(m, graphReq(test.search), rels.GetRelations)
objSearch, err := graph.NewObjectSearch(m, graphReq(test.search), rels.GetRelations, pool)
assert.NoError(err)

res, err := objSearch.Search()
Expand Down
30 changes: 19 additions & 11 deletions graph/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"strings"

"github.com/aserto-dev/azm/internal/mempool"
"github.com/aserto-dev/azm/model"
dsc "github.com/aserto-dev/go-directory/aserto/directory/common/v3"
dsr "github.com/aserto-dev/go-directory/aserto/directory/reader/v3"
Expand All @@ -12,21 +13,27 @@ import (
"google.golang.org/protobuf/types/known/structpb"
)

type ObjectID = model.ObjectID
type (
ObjectID = model.ObjectID

// RelationReader retrieves relations that match the given filter.
type RelationReader func(*dsc.Relation) ([]*dsc.Relation, error)
Relations = []*dsc.Relation

type searchPath relations
// RelationReader retrieves relations that match the given filter.
RelationReader func(*dsc.Relation, *Relations) error

type object struct {
Type model.ObjectName
ID ObjectID
}
RelationsPool = mempool.Pool[*Relations]

searchPath relations

// The results of a search is a map where the key is a matching relations
// and the value is a list of paths that connect the search object and subject.
type searchResults map[object][]searchPath
object struct {
Type model.ObjectName
ID ObjectID
}

// The results of a search is a map where the key is a matching relations
// and the value is a list of paths that connect the search object and subject.
searchResults map[object][]searchPath
)

// Objects returns the objects from the search results.
func (r searchResults) Objects() []*dsc.ObjectIdentifier {
Expand Down Expand Up @@ -92,6 +99,7 @@ type graphSearch struct {

memo *searchMemo
explain bool
pool *RelationsPool
}

func validate(m *model.Model, params *relation) error {
Expand Down
Loading

0 comments on commit 3a4bb96

Please sign in to comment.