Skip to content

Commit

Permalink
Memory pool for slices and messages
Browse files Browse the repository at this point in the history
  • Loading branch information
ronenh committed Dec 8, 2024
1 parent 05080ab commit b2198ea
Show file tree
Hide file tree
Showing 12 changed files with 212 additions and 154 deletions.
8 changes: 4 additions & 4 deletions cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"sync"

"github.com/aserto-dev/azm/graph"
"github.com/aserto-dev/azm/internal/mempool"
"github.com/aserto-dev/azm/mempool"
"github.com/aserto-dev/azm/model"
"github.com/aserto-dev/azm/model/diff"
stts "github.com/aserto-dev/azm/stats"
Expand All @@ -29,7 +29,7 @@ func New(m *model.Model) *Cache {
return &Cache{
model: m,
mtx: sync.RWMutex{},
relsPool: mempool.NewSlicePool[*dsc.Relation](),
relsPool: mempool.NewCollectionPool[dsc.Relation, *dsc.Relation](),
}
}

Expand All @@ -42,8 +42,8 @@ func (c *Cache) UpdateModel(m *model.Model) error {
}

func (c *Cache) CanUpdate(other *model.Model, stats *stts.Stats) error {
c.mtx.Lock()
defer c.mtx.Unlock()
c.mtx.RLock()
defer c.mtx.RUnlock()
return diff.CanUpdateModel(c.model, other, stats)
}

Expand Down
6 changes: 0 additions & 6 deletions cache/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@ import (
)

func (c *Cache) Check(req *dsr.CheckRequest, relReader graph.RelationReader) (*dsr.CheckResponse, error) {
c.mtx.RLock()
defer c.mtx.RUnlock()

checker := graph.NewCheck(c.model, req, relReader, c.relsPool)

ctx := pb.NewStruct()
Expand All @@ -29,9 +26,6 @@ type graphSearch interface {
}

func (c *Cache) GetGraph(req *dsr.GetGraphRequest, relReader graph.RelationReader) (*dsr.GetGraphResponse, error) {
c.mtx.RLock()
defer c.mtx.RUnlock()

var (
search graphSearch
err error
Expand Down
16 changes: 6 additions & 10 deletions graph/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,8 @@ func (c *Checker) checkRelation(params *relation) (checkStatus, error) {
steps := c.m.StepRelation(r, params.st)

// Reuse the same slice in all steps.
relsPtr := c.pool.Get()
defer func() {
*relsPtr = (*relsPtr)[:0]
c.pool.Put(relsPtr)
}()
relsPtr := c.pool.GetSlice()
defer c.pool.PutSlice(relsPtr)

for _, step := range steps {
*relsPtr = (*relsPtr)[:0]
Expand All @@ -114,7 +111,7 @@ func (c *Checker) checkRelation(params *relation) (checkStatus, error) {
req.SubjectRelation = step.Relation.String()
}

if err := c.getRels(req, relsPtr); err != nil {
if err := c.getRels(req, c.pool, relsPtr); err != nil {
return checkStatusFalse, err
}

Expand Down Expand Up @@ -206,10 +203,10 @@ func (c *Checker) expandTerm(pt *model.PermissionTerm, params *relation) (relati
Relation: pt.Base.String(),
}

relsPtr := c.pool.Get()
relsPtr := c.pool.GetSlice()

// Resolve the base of the arrow.
err := c.getRels(query, relsPtr)
err := c.getRels(query, c.pool, relsPtr)
if err != nil {
return relations{}, err
}
Expand All @@ -224,8 +221,7 @@ func (c *Checker) expandTerm(pt *model.PermissionTerm, params *relation) (relati
}
})

*relsPtr = (*relsPtr)[:0]
c.pool.Put(relsPtr)
c.pool.PutSlice(relsPtr)

return expanded, nil
}
Expand Down
137 changes: 81 additions & 56 deletions graph/check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,84 +5,109 @@ import (
"testing"

azmgraph "github.com/aserto-dev/azm/graph"
"github.com/aserto-dev/azm/internal/mempool"
"github.com/aserto-dev/azm/mempool"
v3 "github.com/aserto-dev/azm/v3"
dsc "github.com/aserto-dev/go-directory/aserto/directory/common/v3"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
)

func TestCheck(t *testing.T) {
tests := []struct {
check string
expected bool
}{
// Relations
{"doc:doc1#owner@user:user1", false},
{"doc:doc1#viewer@user:user1", true},
{"doc:doc2#viewer@user:user1", true},
{"doc:doc2#viewer@user:userX", true},
{"doc:doc1#viewer@user:user2", true},
{"doc:doc1#viewer@user:user3", true},
{"doc:doc1#viewer@group:d1_viewers", false},

{"group:yin#member@user:yin_user", true},
{"group:yin#member@user:yang_user", true},
{"group:yang#member@user:yin_user", true},
{"group:yang#member@user:yang_user", true},

{"group:alpha#member@user:user1", false},

// Permissions
{"doc:doc1#can_change_owner@user:d1_owner", true},
{"doc:doc1#can_change_owner@user:user1", false},
{"doc:doc1#can_change_owner@user:userX", false},

{"doc:doc1#can_read@user:d1_owner", true},
{"doc:doc1#can_read@user:f1_owner", true},
{"doc:doc1#can_read@user:user1", true},
{"doc:doc1#can_read@user:f1_viewer", true},
{"doc:doc1#can_read@user:userX", false},

{"doc:doc1#can_write@user:d1_owner", true},
{"doc:doc1#can_write@user:f1_owner", true},
{"doc:doc1#can_write@user:user2", false},

{"folder:folder1#owner@user:f1_owner", true},
{"folder:folder1#can_create_file@user:f1_owner", true},
{"folder:folder1#can_share@user:f1_owner", true},

// intersection
{"doc:doc1#can_share@user:d1_owner", false},
{"doc:doc1#can_share@user:f1_owner", true},

// negation
{"folder:folder1#can_read@user:f1_owner", true},
{"doc:doc1#viewer@user:f1_owner", false},
{"doc:doc1#can_invite@user:f1_owner", true},

// cycles
{"cycle:loop#can_delete@user:loop_owner", true},
{"cycle:loop#can_delete@user:user1", false},
}
var tests = []struct {
check string
expected bool
}{
// Relations
{"doc:doc1#owner@user:user1", false},
{"doc:doc1#viewer@user:user1", true},
{"doc:doc2#viewer@user:user1", true},
{"doc:doc2#viewer@user:userX", true},
{"doc:doc1#viewer@user:user2", true},
{"doc:doc1#viewer@user:user3", true},
{"doc:doc1#viewer@group:d1_viewers", false},

{"group:yin#member@user:yin_user", true},
{"group:yin#member@user:yang_user", true},
{"group:yang#member@user:yin_user", true},
{"group:yang#member@user:yang_user", true},

{"group:alpha#member@user:user1", false},

// Permissions
{"doc:doc1#can_change_owner@user:d1_owner", true},
{"doc:doc1#can_change_owner@user:user1", false},
{"doc:doc1#can_change_owner@user:userX", false},

{"doc:doc1#can_read@user:d1_owner", true},
{"doc:doc1#can_read@user:f1_owner", true},
{"doc:doc1#can_read@user:user1", true},
{"doc:doc1#can_read@user:f1_viewer", true},
{"doc:doc1#can_read@user:userX", false},

{"doc:doc1#can_write@user:d1_owner", true},
{"doc:doc1#can_write@user:f1_owner", true},
{"doc:doc1#can_write@user:user2", false},

{"folder:folder1#owner@user:f1_owner", true},
{"folder:folder1#can_create_file@user:f1_owner", true},
{"folder:folder1#can_share@user:f1_owner", true},

// intersection
{"doc:doc1#can_share@user:d1_owner", false},
{"doc:doc1#can_share@user:f1_owner", true},

// negation
{"folder:folder1#can_read@user:f1_owner", true},
{"doc:doc1#viewer@user:f1_owner", false},
{"doc:doc1#can_invite@user:f1_owner", true},

// cycles
{"cycle:loop#can_delete@user:loop_owner", true},
{"cycle:loop#can_delete@user:user1", false},
}

func TestCheck(t *testing.T) {
m, err := v3.LoadFile("./check_test.yaml")
assert.NoError(t, err)
assert.NotNil(t, m)

pool := mempool.NewSlicePool[*dsc.Relation]()
pool := mempool.NewCollectionPool[dsc.Relation]()

for _, test := range tests {
t.Run(test.check, func(tt *testing.T) {
assert := assert.New(tt)

checker := azmgraph.NewCheck(m, checkReq(test.check), rels.GetRelations, pool)
checker := azmgraph.NewCheck(m, checkReq(test.check, true), rels.GetRelations, pool)

res, err := checker.Check()
assert.NoError(err)
tt.Log("trace:\n", strings.Join(checker.Trace(), "\n"))
assert.Equal(test.expected, res)
})
}
}

func BenchmarkCheck(b *testing.B) {
zerolog.SetGlobalLevel(zerolog.InfoLevel)

m, err := v3.LoadFile("./check_test.yaml")
if err != nil {
b.Fatalf("failed to load model: %s", err)
}

pool := mempool.NewCollectionPool[dsc.Relation]()

b.ResetTimer()
for _, test := range tests {
assert := assert.New(b)

b.StopTimer()
checker := azmgraph.NewCheck(m, checkReq(test.check, false), rels.GetRelations, pool)
b.StartTimer()

res, err := checker.Check()
assert.NoError(err)
assert.Equal(test.expected, res)
}

}

Expand Down
4 changes: 2 additions & 2 deletions graph/objects.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,9 @@ func wildcardParams(params *relation) *relation {
}

func invertedRelationReader(m *model.Model, reader RelationReader) RelationReader {
return func(r *dsc.Relation, out *Relations) error {
return func(r *dsc.Relation, relPool MessagePool[dsc.Relation, *dsc.Relation], out *Relations) error {
ir := uninvertRelation(m, relationFromProto(r))
if err := reader(ir.asProto(), out); err != nil {
if err := reader(ir.asProto(), relPool, out); err != nil {
return err
}

Expand Down
4 changes: 2 additions & 2 deletions graph/objects_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"testing"

"github.com/aserto-dev/azm/graph"
"github.com/aserto-dev/azm/internal/mempool"
"github.com/aserto-dev/azm/mempool"
"github.com/aserto-dev/azm/model"
v3 "github.com/aserto-dev/azm/v3"
dsc "github.com/aserto-dev/go-directory/aserto/directory/common/v3"
Expand Down Expand Up @@ -36,7 +36,7 @@ func TestSearchObjects(t *testing.T) {
im.Validate(model.SkipNameValidation, model.AllowPermissionInArrowBase),
)

pool := mempool.NewSlicePool[*dsc.Relation]()
pool := mempool.NewCollectionPool[dsc.Relation]()

for _, test := range searchObjectsTests {
t.Run(test.search, func(tt *testing.T) {
Expand Down
23 changes: 14 additions & 9 deletions graph/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"fmt"
"strings"

"github.com/aserto-dev/azm/internal/mempool"
"github.com/aserto-dev/azm/mempool"
"github.com/aserto-dev/azm/model"
dsc "github.com/aserto-dev/go-directory/aserto/directory/common/v3"
dsr "github.com/aserto-dev/go-directory/aserto/directory/reader/v3"
Expand All @@ -14,14 +14,9 @@ import (
)

type (
ObjectID = model.ObjectID

Relations = []*dsc.Relation

// RelationReader retrieves relations that match the given filter.
RelationReader func(*dsc.Relation, *Relations) error

RelationsPool = mempool.Pool[*Relations]
ObjectID = model.ObjectID
Relations = []*dsc.Relation
RelationsPool = mempool.CollectionPool[dsc.Relation, *dsc.Relation]

searchPath relations

Expand All @@ -35,6 +30,16 @@ type (
searchResults map[object][]searchPath
)

type MessagePool[M any, T mempool.Resetable[M]] interface {
Get() T
Put(T)
}

type RelationPool = MessagePool[dsc.Relation, *dsc.Relation]

// RelationReader retrieves relations that match the given filter.
type RelationReader func(*dsc.Relation, RelationPool, *Relations) error

// Objects returns the objects from the search results.
func (r searchResults) Objects() []*dsc.ObjectIdentifier {
return lo.MapToSlice(r, func(o object, _ []searchPath) *dsc.ObjectIdentifier {
Expand Down
Loading

0 comments on commit b2198ea

Please sign in to comment.