Skip to content

Commit

Permalink
All RelationTerm fields are RelationRef
Browse files Browse the repository at this point in the history
  • Loading branch information
ronenh committed Dec 6, 2023
1 parent f110717 commit c567399
Show file tree
Hide file tree
Showing 7 changed files with 166 additions and 141 deletions.
10 changes: 5 additions & 5 deletions cache/expand.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ func (c *Cache) ExpandRelation(on model.ObjectName, rn model.RelationName) []mod
switch {
case rt.Subject != nil && rt.Subject.Object == on:
results = append(results, rt.Subject.Relation)
case rt.Direct != "":
results = append(results, c.ExpandRelation(on, model.RelationName(rt.Direct))...)
case rt.Direct != nil:
results = append(results, c.ExpandRelation(on, model.RelationName(rt.Direct.Object))...)
}
}

Expand Down Expand Up @@ -80,11 +80,11 @@ func (c *Cache) expandUnion(o *model.Object, u ...*model.PermissionRef) []model.
result = append(result, rn)

exp := lo.FilterMap(o.Relations[rn].Union, func(r *model.RelationTerm, _ int) (*model.PermissionRef, bool) {
if r.Direct == "" {
if r.Direct == nil {
return &model.PermissionRef{}, false
}
_, ok := o.Relations[model.RelationName(r.Direct)]
return &model.PermissionRef{RelOrPerm: string(r.Direct)}, ok
_, ok := o.Relations[model.RelationName(r.Direct.Object)]
return &model.PermissionRef{RelOrPerm: string(r.Direct.Object)}, ok

})
result = append(result, c.expandUnion(o, exp...)...)
Expand Down
80 changes: 32 additions & 48 deletions cache/path_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,30 @@ import (
"github.com/stretchr/testify/require"
)

type PathMap map[model.ObjectName]map[model.RelationName][]*model.ObjectRelation
type PathMap map[model.ObjectName]map[model.RelationName][]*model.RelationRef

func (pm PathMap) GetPath(or *model.ObjectRelation) []*model.ObjectRelation {
func (pm PathMap) GetPath(or *model.RelationRef) []*model.RelationRef {
if or == nil {
return []*model.ObjectRelation{}
return []*model.RelationRef{}
}

p1, ok := pm[or.Object]
if !ok {
return []*model.ObjectRelation{}
return []*model.RelationRef{}
}

p2, ok := p1[or.Relation]
if !ok {
return []*model.ObjectRelation{}
return []*model.RelationRef{}
}

return p2
}

// func walkPath(m *model.Model, rr *model.RelationRef, path []string) []string {

// }

func TestPathMap(t *testing.T) {
r, err := os.Open("./path_test.yaml")
require.NoError(t, err)
Expand All @@ -49,20 +53,15 @@ func TestPathMap(t *testing.T) {
require.NotNil(t, pm)

// plot all paths for all roots.
roots := []*model.ObjectRelation{}
for on, rns := range *pm {
for rn := range rns {
roots = append(roots, model.NewObjectRelation(on, rn))
path := pm.WalkPath(model.NewRelationRef(on, rn), []string{})
fmt.Printf("%s:%s: %s\n", on, rn, strings.Join(path, " -> "))
}
}

for i := 0; i < len(roots); i++ {
path := pm.WalkPath(roots[i], []string{})
fmt.Println(strings.Join(path, " -> "))
}
}

func (pm PathMap) WalkPath(or *model.ObjectRelation, path []string) []string {
func (pm PathMap) WalkPath(or *model.RelationRef, path []string) []string {
paths := pm.GetPath(or)
for i := 0; i < len(paths); i++ {
path = append(path, paths[i].String())
Expand All @@ -71,7 +70,7 @@ func (pm PathMap) WalkPath(or *model.ObjectRelation, path []string) []string {
return path
}

func (pm PathMap) plotPaths(w io.Writer, or *model.ObjectRelation) {
func (pm PathMap) plotPaths(w io.Writer, or *model.RelationRef) {

Check failure on line 73 in cache/path_test.go

View workflow job for this annotation

GitHub Actions / test

func `PathMap.plotPaths` is unused (unused)
paths := pm.GetPath(or)

for _, p := range paths {
Expand All @@ -93,20 +92,20 @@ func createPathMap(m *model.Model) *PathMap {
// create roots
for on, o := range m.Objects {
if _, ok := pm[on]; !ok {
pm[on] = map[model.RelationName][]*model.ObjectRelation{}
pm[on] = map[model.RelationName][]*model.RelationRef{}
}

p1 := pm[on]

for pn := range o.Permissions {
if _, ok := p1[pn.RN()]; !ok {
p1[pn.RN()] = []*model.ObjectRelation{}
p1[pn.RN()] = []*model.RelationRef{}
}
}

for rn := range o.Relations {
if _, ok := p1[rn]; !ok {
p1[rn] = []*model.ObjectRelation{}
p1[rn] = []*model.RelationRef{}
}
}
}
Expand All @@ -127,8 +126,8 @@ func createPathMap(m *model.Model) *PathMap {
return &pm
}

func expandPerm(m *model.Model, on model.ObjectName, pn model.PermissionName) []*model.ObjectRelation {
result := []*model.ObjectRelation{}
func expandPerm(m *model.Model, on model.ObjectName, pn model.PermissionName) []*model.RelationRef {
result := []*model.RelationRef{}

p, ok := m.Objects[on].Permissions[pn]
if !ok {
Expand All @@ -150,73 +149,58 @@ func expandPerm(m *model.Model, on model.ObjectName, pn model.PermissionName) []
return result
}

func expandRel(m *model.Model, on model.ObjectName, rn model.RelationName) []*model.ObjectRelation {
result := []*model.ObjectRelation{}
func expandRel(m *model.Model, on model.ObjectName, rn model.RelationName) []*model.RelationRef {
result := []*model.RelationRef{}

relation, ok := m.Objects[on].Relations[rn]
if !ok {
return result
}

for _, r := range relation.Union {
if r.Direct != "" {
result = append(result, &model.ObjectRelation{
Object: r.Direct,
Relation: "",
})
if r.Direct != nil {
result = append(result, r.Direct)
}

if r.Subject != nil {
result = append(result, &model.ObjectRelation{
result = append(result, &model.RelationRef{
Object: r.Subject.Object,
Relation: r.Subject.Relation,
})
}

if r.Wildcard != "" {
result = append(result, &model.ObjectRelation{
Object: r.Wildcard,
Relation: "*",
})
if r.Wildcard != nil {
result = append(result, r.Wildcard)
}
}

return result
}

func resolve(m *model.Model, on model.ObjectName, rn model.RelationName) *model.ObjectRelation {
func resolve(m *model.Model, on model.ObjectName, rn model.RelationName) *model.RelationRef {
if strings.Contains(rn.String(), v3.ArrowIdentifier) {
parts := strings.Split(rn.String(), v3.ArrowIdentifier)

rn = model.RelationName(parts[0])

if _, ok := m.Objects[on].Relations[rn]; ok { // if c.RelationExists(on, rn) {
for _, rel := range m.Objects[on].Relations[rn].Union {
if rel.Direct != "" {
return &model.ObjectRelation{
Object: rel.Direct,
Relation: model.RelationName(parts[1]),
}
if rel.Direct != nil {
return rel.Direct
}

if rel.Subject != nil {
return &model.ObjectRelation{
Object: rel.Subject.Object,
Relation: rel.Subject.Relation,
}
return rel.Subject.RelationRef
}

if rel.Wildcard != "" {
return &model.ObjectRelation{
Object: rel.Wildcard,
Relation: "*",
}
if rel.Wildcard != nil {
return rel.Wildcard
}
}
}
}

return &model.ObjectRelation{
return &model.RelationRef{
Object: on,
Relation: rn,
}
Expand Down
97 changes: 48 additions & 49 deletions model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,40 @@ type Relation struct {
}

type RelationTerm struct {
Direct ObjectName `json:"direct,omitempty"`
Direct *RelationRef `json:"direct,omitempty"`
Subject *SubjectRelation `json:"subject,omitempty"`
Wildcard ObjectName `json:"wildcard,omitempty"`
Wildcard *RelationRef `json:"wildcard,omitempty"`
}

func (rt *RelationTerm) Ref() *RelationRef {
switch {
case rt.Direct != nil:
return rt.Direct
case rt.Subject != nil:
return rt.Subject.RelationRef
case rt.Wildcard != nil:
return rt.Wildcard
default:
return nil
}
}

type RelationRef struct {
Object ObjectName `json:"object,omitempty"`
Relation RelationName `json:"relation,omitempty"`
}

func NewRelationRef(on ObjectName, rn RelationName) *RelationRef {
return &RelationRef{Object: on, Relation: rn}
}

func (rr *RelationRef) String() string {
if rr.Relation == "" {
return string(rr.Object)
}
return fmt.Sprintf("%s:%s", rr.Object, rr.Relation)
}

type SubjectRelation struct {
*RelationRef
SubjectTypes []ObjectName `json:"subject_types,omitempty"`
Expand All @@ -104,7 +128,6 @@ func (p *Permission) Refs() []*PermissionRef {
}

return refs

}

type PermissionRef struct {
Expand All @@ -123,19 +146,6 @@ type ArrowPermission struct {
Permission string `json:"permission,omitempty"`
}

type ObjectRelation struct {
Object ObjectName `json:"object"`
Relation RelationName `json:"relation,omitempty"`
}

func NewObjectRelation(on ObjectName, rn RelationName) *ObjectRelation {
return &ObjectRelation{Object: on, Relation: rn}
}

func (or ObjectRelation) String() string {
return fmt.Sprintf("%s:%s", or.Object, or.Relation)
}

func New(r io.Reader) (*Model, error) {
m := Model{}
dec := json.NewDecoder(r)
Expand All @@ -154,8 +164,8 @@ func (m *Model) GetGraph() *graph.Graph {
for objectName, obj := range m.Objects {
for relName, rel := range obj.Relations {
for _, rl := range rel.Union {
if string(rl.Direct) != "" {
grph.AddEdge(string(objectName), string(rl.Direct), string(relName))
if rl.Direct != nil {
grph.AddEdge(string(objectName), string(rl.Direct.Object), string(relName))
} else if rl.Subject != nil {
grph.AddEdge(string(objectName), string(rl.Subject.Object), string(relName))
}
Expand Down Expand Up @@ -263,36 +273,28 @@ func (m *Model) validateObjectRels(on ObjectName, o *Object) error {
var errs error
for rn, rs := range o.Relations {
for _, r := range rs.Union {
switch {
case r.Direct != "":
if _, ok := m.Objects[r.Direct]; !ok {
errs = multierror.Append(errs, derr.ErrInvalidRelation.Msgf(
"relation '%s:%s' references undefined object type '%s'", on, rn, r.Direct),
)
}
case r.Wildcard != "":
if _, ok := m.Objects[r.Wildcard]; !ok {
errs = multierror.Append(errs, derr.ErrInvalidRelation.Msgf(
"relation '%s:%s' references undefined object type '%s'", on, rn, r.Wildcard),
)
}
case r.Subject != nil:
if _, ok := m.Objects[r.Subject.Object]; !ok {
errs = multierror.Append(errs, derr.ErrInvalidRelation.Msgf(
"relation '%s:%s' references undefined object type '%s'", on, rn, r.Subject.Object),
)
break
}
ref := r.Ref()
if ref == nil {
errs = multierror.Append(errs, derr.ErrInvalidRelation.Msgf(
"relation '%s:%s' has no definition", on, rn),
)
continue
}

if _, ok := m.Objects[r.Subject.Object].Relations[r.Subject.Relation]; !ok {
o := m.Objects[ref.Object]
if o == nil {
errs = multierror.Append(errs, derr.ErrInvalidRelation.Msgf(
"relation '%s:%s' references undefined object type '%s'", on, rn, ref.Object),
)
continue
}

if r.Subject != nil {
if _, ok := o.Relations[ref.Relation]; !ok {
errs = multierror.Append(errs, derr.ErrInvalidRelation.Msgf(
"relation '%s:%s' references undefined relation type '%s#%s'", on, rn, r.Subject.Object, r.Subject.Relation),
"relation '%s:%s' references undefined relation type '%s#%s'", on, rn, ref.Object, ref.Relation),
)
}
default:
errs = multierror.Append(errs, derr.ErrInvalidRelation.Msgf(
"relation '%s:%s' has no definition", on, rn),
)
}
}
}
Expand Down Expand Up @@ -358,10 +360,6 @@ func (m *Model) resolveRelation(r *Relation, seen RelSet) []ObjectName {
subjectTypes := []ObjectName{}
for _, rt := range r.Union {
switch {
case rt.Direct != "":
subjectTypes = append(subjectTypes, rt.Direct)
case rt.Wildcard != "":
subjectTypes = append(subjectTypes, rt.Wildcard)
case rt.Subject != nil:
if !seen.Contains(*rt.Subject.RelationRef) {
seen.Add(*rt.Subject.RelationRef)
Expand All @@ -370,7 +368,8 @@ func (m *Model) resolveRelation(r *Relation, seen RelSet) []ObjectName {
seen)...,
)
}

default:
subjectTypes = append(subjectTypes, rt.Ref().Object)
}
}
return subjectTypes
Expand Down
Loading

0 comments on commit c567399

Please sign in to comment.