From be9e0e38658df91859fd3d486900db378c9ea3c4 Mon Sep 17 00:00:00 2001 From: oanatmaria Date: Wed, 11 Oct 2023 18:05:47 +0300 Subject: [PATCH 1/4] Add model graph --- graph/graph.go | 69 +++++++++++++++++++++++++++++++++++++++++++++ model/model.go | 22 +++++++++++++++ model/model_test.go | 10 +++++++ 3 files changed, 101 insertions(+) create mode 100644 graph/graph.go diff --git a/graph/graph.go b/graph/graph.go new file mode 100644 index 0000000..b00c9b8 --- /dev/null +++ b/graph/graph.go @@ -0,0 +1,69 @@ +package graph + +type Graph struct { + nodes []string + adjMatrix [][]string +} + +func NewGraph() *Graph { + return &Graph{ + nodes: make([]string, 0), + adjMatrix: make([][]string, 0), + } +} + +func (g *Graph) AddNode(nodeName string) { + for _, vertexName := range g.nodes { + if nodeName == vertexName { + return + } + } + g.nodes = append(g.nodes, nodeName) + g.adjMatrix = append(g.adjMatrix, make([]string, 0)) +} + +func (g *Graph) AddEdge(source, dest, edgeName string) { + sourceIndex := g.findVertexIndexByName(source) + destIndex := g.findVertexIndexByName(dest) + if destIndex == -1 || sourceIndex == -1 { + return + } + + if len(g.adjMatrix[sourceIndex]) == 0 { + g.adjMatrix[sourceIndex] = make([]string, len(g.nodes)) + } + + g.adjMatrix[sourceIndex][destIndex] = edgeName +} + +func (g *Graph) dFS(source int, visited []bool, traversal []string) []string { + visited[source] = true + traversal = append(traversal, g.nodes[source]) + for i := 0; i < len(g.nodes); i++ { + if len(g.adjMatrix[source]) != 0 && g.adjMatrix[source][i] != "" && !visited[i] { + traversal = append(traversal, g.adjMatrix[source][i]) + traversal = g.dFS(i, visited, traversal) + } + } + return traversal +} + +func (g *Graph) TraverseGraph(startVertex string) []string { + visited := make([]bool, len(g.nodes)) + traversal := make([]string, 0) + startIndex := g.findVertexIndexByName(startVertex) + if startIndex == -1 { + return traversal + } + + return g.dFS(startIndex, visited, traversal) +} + +func (g *Graph) findVertexIndexByName(name string) int { + for index, vertexName := range g.nodes { + if vertexName == name { + return index + } + } + return -1 +} diff --git a/model/model.go b/model/model.go index a51443c..9aef76d 100644 --- a/model/model.go +++ b/model/model.go @@ -5,6 +5,8 @@ import ( "encoding/json" "io" "time" + + "github.com/aserto-dev/azm/graph" ) const ModelVersion int = 1 @@ -77,6 +79,26 @@ func New(r io.Reader) (*Model, error) { return &m, nil } +func (m *Model) GetGraph() *graph.Graph { + grph := graph.NewGraph() + for objectName := range m.Objects { + grph.AddNode(string(objectName)) + } + for objectName, obj := range m.Objects { + for relName, rel := range obj.Relations { + for _, rl := range rel { + if string(rl.Direct) != "" { + grph.AddEdge(string(objectName), string(rl.Direct), string(relName)) + } else if rl.Subject != nil { + grph.AddEdge(string(objectName), string(rl.Subject.Object), string(relName)) + } + } + } + } + + return grph +} + func (m *Model) Reader() (io.Reader, error) { b := bytes.Buffer{} enc := json.NewEncoder(&b) diff --git a/model/model_test.go b/model/model_test.go index a2e1604..afb893f 100644 --- a/model/model_test.go +++ b/model/model_test.go @@ -226,3 +226,13 @@ func TestDiff(t *testing.T) { stretch.Equal(t, len(diffM1WithM3.Removed.Relations), 1) stretch.Equal(t, diffM1WithM3.Removed.Relations["document"], []model.RelationName{"parent_folder"}) } + +func TestGraph(t *testing.T) { + graph := m1.GetGraph() + + traversal := graph.TraverseGraph("document") + require.Equal(t, len(traversal), 5) + traversal = graph.TraverseGraph("group") + require.Equal(t, len(traversal), 3) + require.Equal(t, traversal, []string{"group", "member", "user"}) +} From 744a550ed6c24d77218347bed5792e603b725c09 Mon Sep 17 00:00:00 2001 From: oanatmaria Date: Wed, 25 Oct 2023 14:53:17 +0300 Subject: [PATCH 2/4] Add search and return array of paths --- graph/graph.go | 87 ++++++++++++++++++++++++++++++++++++--------- model/model_test.go | 59 +++++++++++++++++++++++++++--- 2 files changed, 125 insertions(+), 21 deletions(-) diff --git a/graph/graph.go b/graph/graph.go index b00c9b8..5fe534f 100644 --- a/graph/graph.go +++ b/graph/graph.go @@ -1,14 +1,18 @@ package graph +import "github.com/samber/lo" + +type Relations []string + type Graph struct { nodes []string - adjMatrix [][]string + adjMatrix [][]Relations } func NewGraph() *Graph { return &Graph{ nodes: make([]string, 0), - adjMatrix: make([][]string, 0), + adjMatrix: make([][]Relations, 0), } } @@ -19,44 +23,93 @@ func (g *Graph) AddNode(nodeName string) { } } g.nodes = append(g.nodes, nodeName) - g.adjMatrix = append(g.adjMatrix, make([]string, 0)) + g.adjMatrix = append(g.adjMatrix, make([]Relations, 0)) } func (g *Graph) AddEdge(source, dest, edgeName string) { sourceIndex := g.findVertexIndexByName(source) destIndex := g.findVertexIndexByName(dest) - if destIndex == -1 || sourceIndex == -1 { - return + if destIndex == -1 { + g.AddNode(dest) + destIndex = g.findVertexIndexByName(dest) + } + + if sourceIndex == -1 { + g.AddNode(source) + sourceIndex = g.findVertexIndexByName(source) } if len(g.adjMatrix[sourceIndex]) == 0 { - g.adjMatrix[sourceIndex] = make([]string, len(g.nodes)) + g.adjMatrix[sourceIndex] = make([]Relations, len(g.nodes)) } - g.adjMatrix[sourceIndex][destIndex] = edgeName + g.adjMatrix[sourceIndex][destIndex] = append(g.adjMatrix[sourceIndex][destIndex], edgeName) } -func (g *Graph) dFS(source int, visited []bool, traversal []string) []string { - visited[source] = true - traversal = append(traversal, g.nodes[source]) +func (g *Graph) dFS(source int, visited []map[int]bool, traversal [][]string, index int) [][]string { + if len(traversal) < index+1 { + traversal = append(traversal, make([]string, 0)) + } + traversal[index] = append(traversal[index], g.nodes[source]) for i := 0; i < len(g.nodes); i++ { - if len(g.adjMatrix[source]) != 0 && g.adjMatrix[source][i] != "" && !visited[i] { - traversal = append(traversal, g.adjMatrix[source][i]) - traversal = g.dFS(i, visited, traversal) + if len(g.adjMatrix[source]) > i && len(g.adjMatrix[source][i]) != 0 { + if len(visited[i]) == 0 { + visited[i] = make(map[int]bool, len(g.adjMatrix[source][i])) + } + for j := 0; j < len(g.adjMatrix[source][i]); j++ { + if !visited[i][j] { + visited[i][j] = true + if len(traversal) < index+1 { + traversal = append(traversal, make([]string, 0)) + } + traversal[index] = append(traversal[index], g.adjMatrix[source][i][j]) + traversal = g.dFS(i, visited, traversal, index) + index = index + 1 + } + } } } return traversal } -func (g *Graph) TraverseGraph(startVertex string) []string { - visited := make([]bool, len(g.nodes)) - traversal := make([]string, 0) +func (g *Graph) TraverseGraph(startVertex string) [][]string { + visited := make([]map[int]bool, len(g.nodes)) + traversal := make([][]string, 0) startIndex := g.findVertexIndexByName(startVertex) if startIndex == -1 { return traversal } + traversal = g.dFS(startIndex, visited, traversal, 0) + for i, paths := range traversal { + if paths[0] != startVertex { + traversal[i] = append([]string{startVertex}, paths...) + } + } + + return traversal +} + +func (g *Graph) SearchGraph(startVertex, destVertx string) [][]string { + visited := make([]map[int]bool, len(g.nodes)) + traversal := make([][]string, 0) + startIndex := g.findVertexIndexByName(startVertex) + destIndex := g.findVertexIndexByName(destVertx) + if startIndex == -1 || destIndex == -1 { + return traversal + } + result := make([][]string, 0) + traversal = g.dFS(startIndex, visited, traversal, 0) + for i, paths := range traversal { + _, found := lo.Find(paths, func(elem string) bool { return elem == destVertx }) + if found { + if paths[0] != startVertex { + traversal[i] = append([]string{startVertex}, paths...) + } + result = append(result, traversal[i]) + } + } - return g.dFS(startIndex, visited, traversal) + return result } func (g *Graph) findVertexIndexByName(name string) int { diff --git a/model/model_test.go b/model/model_test.go index afb893f..1e95ed0 100644 --- a/model/model_test.go +++ b/model/model_test.go @@ -228,11 +228,62 @@ func TestDiff(t *testing.T) { } func TestGraph(t *testing.T) { - graph := m1.GetGraph() + m := model.Model{ + Version: 1, + Objects: map[model.ObjectName]*model.Object{ + model.ObjectName("user"): { + Relations: map[model.RelationName][]*model.Relation{ + model.RelationName("rel_name"): { + &model.Relation{Direct: model.ObjectName("ext_obj")}, + }, + }, + }, + model.ObjectName("group"): { + Relations: map[model.RelationName][]*model.Relation{ + model.RelationName("member"): { + &model.Relation{Direct: model.ObjectName("user")}, + &model.Relation{Subject: &model.SubjectRelation{ + Object: model.ObjectName("group"), + Relation: model.RelationName("member"), + }}, + }, + }, + }, + model.ObjectName("folder"): { + Relations: map[model.RelationName][]*model.Relation{ + model.RelationName("owner"): { + &model.Relation{Direct: model.ObjectName("user")}, + }, + }, + }, + model.ObjectName("document"): { + Relations: map[model.RelationName][]*model.Relation{ + model.RelationName("parent_folder"): { + {Direct: model.ObjectName("folder")}, + }, + model.RelationName("writer"): { + {Direct: model.ObjectName("user")}, + }, + model.RelationName("reader"): { + {Direct: model.ObjectName("user")}, + {Wildcard: model.ObjectName("user")}, + }, + }, + }, + }, + } + graph := m.GetGraph() + + search := graph.SearchGraph("document", "ext_obj") + stretch.Equal(t, len(search), 1) + + search = graph.SearchGraph("document", "user") + stretch.Equal(t, len(search), 2) traversal := graph.TraverseGraph("document") - require.Equal(t, len(traversal), 5) + stretch.Equal(t, len(traversal), 3) traversal = graph.TraverseGraph("group") - require.Equal(t, len(traversal), 3) - require.Equal(t, traversal, []string{"group", "member", "user"}) + stretch.Equal(t, len(traversal), 2) + stretch.Contains(t, traversal, []string{"group", "member", "user", "rel_name", "ext_obj"}) + stretch.Contains(t, traversal, []string{"group", "member", "group"}) } From b1007c8673b11492345154a97b9941f7617130e6 Mon Sep 17 00:00:00 2001 From: oanatmaria Date: Wed, 25 Oct 2023 15:11:05 +0300 Subject: [PATCH 3/4] Fix lint --- graph/graph.go | 18 ++++++++++-------- model/model_test.go | 2 +- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/graph/graph.go b/graph/graph.go index 5fe534f..c9b4a6b 100644 --- a/graph/graph.go +++ b/graph/graph.go @@ -57,15 +57,17 @@ func (g *Graph) dFS(source int, visited []map[int]bool, traversal [][]string, in visited[i] = make(map[int]bool, len(g.adjMatrix[source][i])) } for j := 0; j < len(g.adjMatrix[source][i]); j++ { - if !visited[i][j] { - visited[i][j] = true - if len(traversal) < index+1 { - traversal = append(traversal, make([]string, 0)) - } - traversal[index] = append(traversal[index], g.adjMatrix[source][i][j]) - traversal = g.dFS(i, visited, traversal, index) - index = index + 1 + if visited[i][j] { + continue } + visited[i][j] = true + if len(traversal) < index+1 { + traversal = append(traversal, make([]string, 0)) + } + traversal[index] = append(traversal[index], g.adjMatrix[source][i][j]) + traversal = g.dFS(i, visited, traversal, index) + index++ + } } } diff --git a/model/model_test.go b/model/model_test.go index 1e95ed0..ff4fcae 100644 --- a/model/model_test.go +++ b/model/model_test.go @@ -275,7 +275,7 @@ func TestGraph(t *testing.T) { graph := m.GetGraph() search := graph.SearchGraph("document", "ext_obj") - stretch.Equal(t, len(search), 1) + // stretch.Equal(t, len(search), 1) search = graph.SearchGraph("document", "user") stretch.Equal(t, len(search), 2) From 809fbf81d5832a87590b32431b7c38662c3563be Mon Sep 17 00:00:00 2001 From: oanatmaria Date: Fri, 27 Oct 2023 15:05:59 +0300 Subject: [PATCH 4/4] Only implement findPaths --- graph/graph.go | 82 +++++++++++++++++---------------------------- model/model_test.go | 44 ++++++++++++++++++------ 2 files changed, 64 insertions(+), 62 deletions(-) diff --git a/graph/graph.go b/graph/graph.go index c9b4a6b..adb3e93 100644 --- a/graph/graph.go +++ b/graph/graph.go @@ -1,7 +1,5 @@ package graph -import "github.com/samber/lo" - type Relations []string type Graph struct { @@ -29,14 +27,8 @@ func (g *Graph) AddNode(nodeName string) { func (g *Graph) AddEdge(source, dest, edgeName string) { sourceIndex := g.findVertexIndexByName(source) destIndex := g.findVertexIndexByName(dest) - if destIndex == -1 { - g.AddNode(dest) - destIndex = g.findVertexIndexByName(dest) - } - - if sourceIndex == -1 { - g.AddNode(source) - sourceIndex = g.findVertexIndexByName(source) + if destIndex == -1 || sourceIndex == -1 { + return } if len(g.adjMatrix[sourceIndex]) == 0 { @@ -46,52 +38,41 @@ func (g *Graph) AddEdge(source, dest, edgeName string) { g.adjMatrix[sourceIndex][destIndex] = append(g.adjMatrix[sourceIndex][destIndex], edgeName) } -func (g *Graph) dFS(source int, visited []map[int]bool, traversal [][]string, index int) [][]string { +func (g *Graph) dFS(source, dest int, visited []map[int]bool, traversal [][]string, index int) [][]string { if len(traversal) < index+1 { traversal = append(traversal, make([]string, 0)) } traversal[index] = append(traversal[index], g.nodes[source]) - for i := 0; i < len(g.nodes); i++ { - if len(g.adjMatrix[source]) > i && len(g.adjMatrix[source][i]) != 0 { - if len(visited[i]) == 0 { - visited[i] = make(map[int]bool, len(g.adjMatrix[source][i])) - } - for j := 0; j < len(g.adjMatrix[source][i]); j++ { - if visited[i][j] { - continue + + if source == dest { + return traversal + } else { + for i := 0; i < len(g.nodes); i++ { + if len(g.adjMatrix[source]) > i && len(g.adjMatrix[source][i]) != 0 { + if len(visited[i]) == 0 { + visited[i] = make(map[int]bool, len(g.adjMatrix[source][i])) } - visited[i][j] = true - if len(traversal) < index+1 { - traversal = append(traversal, make([]string, 0)) + for j := 0; j < len(g.adjMatrix[source][i]); j++ { + if visited[i][j] { + continue + } + visited[i][j] = true + if len(traversal) < index+1 { + traversal = append(traversal, make([]string, 0)) + } + traversal[index] = append(traversal[index], g.adjMatrix[source][i][j]) + traversal = g.dFS(i, dest, visited, traversal, index) + index++ } - traversal[index] = append(traversal[index], g.adjMatrix[source][i][j]) - traversal = g.dFS(i, visited, traversal, index) - index++ - + visited[i] = make(map[int]bool, 0) } } } - return traversal -} - -func (g *Graph) TraverseGraph(startVertex string) [][]string { - visited := make([]map[int]bool, len(g.nodes)) - traversal := make([][]string, 0) - startIndex := g.findVertexIndexByName(startVertex) - if startIndex == -1 { - return traversal - } - traversal = g.dFS(startIndex, visited, traversal, 0) - for i, paths := range traversal { - if paths[0] != startVertex { - traversal[i] = append([]string{startVertex}, paths...) - } - } return traversal } -func (g *Graph) SearchGraph(startVertex, destVertx string) [][]string { +func (g *Graph) FindPaths(startVertex, destVertx string) [][]string { visited := make([]map[int]bool, len(g.nodes)) traversal := make([][]string, 0) startIndex := g.findVertexIndexByName(startVertex) @@ -99,19 +80,16 @@ func (g *Graph) SearchGraph(startVertex, destVertx string) [][]string { if startIndex == -1 || destIndex == -1 { return traversal } - result := make([][]string, 0) - traversal = g.dFS(startIndex, visited, traversal, 0) + + traversal = g.dFS(startIndex, destIndex, visited, traversal, 0) for i, paths := range traversal { - _, found := lo.Find(paths, func(elem string) bool { return elem == destVertx }) - if found { - if paths[0] != startVertex { - traversal[i] = append([]string{startVertex}, paths...) - } - result = append(result, traversal[i]) + if paths[0] != startVertex { + traversal[i] = append([]string{startVertex}, paths...) } + } - return result + return traversal } func (g *Graph) findVertexIndexByName(name string) int { diff --git a/model/model_test.go b/model/model_test.go index ff4fcae..45ae065 100644 --- a/model/model_test.go +++ b/model/model_test.go @@ -238,6 +238,7 @@ func TestGraph(t *testing.T) { }, }, }, + model.ObjectName("ext_obj"): {}, model.ObjectName("group"): { Relations: map[model.RelationName][]*model.Relation{ model.RelationName("member"): { @@ -272,18 +273,41 @@ func TestGraph(t *testing.T) { }, }, } + + docExtObjResults := [][]string{ + {"document", "writer", "user", "rel_name", "ext_obj"}, + {"document", "reader", "user", "rel_name", "ext_obj"}, + {"document", "parent_folder", "folder", "owner", "user", "rel_name", "ext_obj"}, + } + + docUserResults := [][]string{ + {"document", "writer", "user"}, + {"document", "reader", "user"}, + {"document", "parent_folder", "folder", "owner", "user"}, + } + + groupExtObjResults := [][]string{ + {"group", "member", "group", "member", "user", "rel_name", "ext_obj"}, + {"group", "member", "user", "rel_name", "ext_obj"}, + } + graph := m.GetGraph() - search := graph.SearchGraph("document", "ext_obj") - // stretch.Equal(t, len(search), 1) + search := graph.FindPaths("document", "ext_obj") + stretch.Equal(t, len(search), 3) + for _, expected := range docExtObjResults { + stretch.Contains(t, search, expected) + } - search = graph.SearchGraph("document", "user") - stretch.Equal(t, len(search), 2) + search = graph.FindPaths("document", "user") + stretch.Equal(t, len(search), 3) + for _, expected := range docUserResults { + stretch.Contains(t, search, expected) + } - traversal := graph.TraverseGraph("document") - stretch.Equal(t, len(traversal), 3) - traversal = graph.TraverseGraph("group") - stretch.Equal(t, len(traversal), 2) - stretch.Contains(t, traversal, []string{"group", "member", "user", "rel_name", "ext_obj"}) - stretch.Contains(t, traversal, []string{"group", "member", "group"}) + search = graph.FindPaths("group", "ext_obj") + stretch.Equal(t, len(search), 2) + for _, expected := range groupExtObjResults { + stretch.Contains(t, search, expected) + } }