diff --git a/graph/graph.go b/graph/graph.go new file mode 100644 index 0000000..adb3e93 --- /dev/null +++ b/graph/graph.go @@ -0,0 +1,102 @@ +package graph + +type Relations []string + +type Graph struct { + nodes []string + adjMatrix [][]Relations +} + +func NewGraph() *Graph { + return &Graph{ + nodes: make([]string, 0), + adjMatrix: make([][]Relations, 0), + } +} + +func (g *Graph) AddNode(nodeName string) { + for _, vertexName := range g.nodes { + if nodeName == vertexName { + return + } + } + g.nodes = append(g.nodes, nodeName) + g.adjMatrix = append(g.adjMatrix, make([]Relations, 0)) +} + +func (g *Graph) AddEdge(source, dest, edgeName string) { + sourceIndex := g.findVertexIndexByName(source) + destIndex := g.findVertexIndexByName(dest) + if destIndex == -1 || sourceIndex == -1 { + return + } + + if len(g.adjMatrix[sourceIndex]) == 0 { + g.adjMatrix[sourceIndex] = make([]Relations, len(g.nodes)) + } + + g.adjMatrix[sourceIndex][destIndex] = append(g.adjMatrix[sourceIndex][destIndex], edgeName) +} + +func (g *Graph) dFS(source, dest int, visited []map[int]bool, traversal [][]string, index int) [][]string { + if len(traversal) < index+1 { + traversal = append(traversal, make([]string, 0)) + } + traversal[index] = append(traversal[index], g.nodes[source]) + + if source == dest { + return traversal + } else { + for i := 0; i < len(g.nodes); i++ { + if len(g.adjMatrix[source]) > i && len(g.adjMatrix[source][i]) != 0 { + if len(visited[i]) == 0 { + visited[i] = make(map[int]bool, len(g.adjMatrix[source][i])) + } + for j := 0; j < len(g.adjMatrix[source][i]); j++ { + if visited[i][j] { + continue + } + visited[i][j] = true + if len(traversal) < index+1 { + traversal = append(traversal, make([]string, 0)) + } + traversal[index] = append(traversal[index], g.adjMatrix[source][i][j]) + traversal = g.dFS(i, dest, visited, traversal, index) + index++ + } + visited[i] = make(map[int]bool, 0) + } + } + } + + return traversal +} + +func (g *Graph) FindPaths(startVertex, destVertx string) [][]string { + visited := make([]map[int]bool, len(g.nodes)) + traversal := make([][]string, 0) + startIndex := g.findVertexIndexByName(startVertex) + destIndex := g.findVertexIndexByName(destVertx) + if startIndex == -1 || destIndex == -1 { + return traversal + } + + traversal = g.dFS(startIndex, destIndex, visited, traversal, 0) + for i, paths := range traversal { + if paths[0] != startVertex { + traversal[i] = append([]string{startVertex}, paths...) + } + + } + + return traversal +} + +func (g *Graph) findVertexIndexByName(name string) int { + for index, vertexName := range g.nodes { + if vertexName == name { + return index + } + } + return -1 +} diff --git a/model/model.go b/model/model.go index a51443c..9aef76d 100644 --- a/model/model.go +++ b/model/model.go @@ -5,6 +5,8 @@ import ( "encoding/json" "io" "time" + + "github.com/aserto-dev/azm/graph" ) const ModelVersion int = 1 @@ -77,6 +79,26 @@ func New(r io.Reader) (*Model, error) { return &m, nil } +func (m *Model) GetGraph() *graph.Graph { + grph := graph.NewGraph() + for objectName := range m.Objects { + grph.AddNode(string(objectName)) + } + for objectName, obj := range m.Objects { + for relName, rel := range obj.Relations { + for _, rl := range rel { + if string(rl.Direct) != "" { + grph.AddEdge(string(objectName), string(rl.Direct), string(relName)) + } else if rl.Subject != nil { + grph.AddEdge(string(objectName), string(rl.Subject.Object), string(relName)) + } + } + } + } + + return grph +} + func (m *Model) Reader() (io.Reader, error) { b := bytes.Buffer{} enc := json.NewEncoder(&b) diff --git a/model/model_test.go b/model/model_test.go index a2e1604..45ae065 100644 --- a/model/model_test.go +++ b/model/model_test.go @@ -226,3 +226,88 @@ func TestDiff(t *testing.T) { stretch.Equal(t, len(diffM1WithM3.Removed.Relations), 1) stretch.Equal(t, diffM1WithM3.Removed.Relations["document"], []model.RelationName{"parent_folder"}) } + +func TestGraph(t *testing.T) { + m := model.Model{ + Version: 1, + Objects: map[model.ObjectName]*model.Object{ + model.ObjectName("user"): { + Relations: map[model.RelationName][]*model.Relation{ + model.RelationName("rel_name"): { + &model.Relation{Direct: model.ObjectName("ext_obj")}, + }, + }, + }, + model.ObjectName("ext_obj"): {}, + model.ObjectName("group"): { + Relations: map[model.RelationName][]*model.Relation{ + model.RelationName("member"): { + &model.Relation{Direct: model.ObjectName("user")}, + &model.Relation{Subject: &model.SubjectRelation{ + Object: model.ObjectName("group"), + Relation: model.RelationName("member"), + }}, + }, + }, + }, + model.ObjectName("folder"): { + Relations: map[model.RelationName][]*model.Relation{ + model.RelationName("owner"): { + &model.Relation{Direct: model.ObjectName("user")}, + }, + }, + }, + model.ObjectName("document"): { + Relations: map[model.RelationName][]*model.Relation{ + model.RelationName("parent_folder"): { + {Direct: model.ObjectName("folder")}, + }, + model.RelationName("writer"): { + {Direct: model.ObjectName("user")}, + }, + model.RelationName("reader"): { + {Direct: model.ObjectName("user")}, + {Wildcard: model.ObjectName("user")}, + }, + }, + }, + }, + } + + docExtObjResults := [][]string{ + {"document", "writer", "user", "rel_name", "ext_obj"}, + {"document", "reader", "user", "rel_name", "ext_obj"}, + {"document", "parent_folder", "folder", "owner", "user", "rel_name", "ext_obj"}, + } + + docUserResults := [][]string{ + {"document", "writer", "user"}, + {"document", "reader", "user"}, + {"document", "parent_folder", "folder", "owner", "user"}, + } + + groupExtObjResults := [][]string{ + {"group", "member", "group", "member", "user", "rel_name", "ext_obj"}, + {"group", "member", "user", "rel_name", "ext_obj"}, + } + + graph := m.GetGraph() + + search := graph.FindPaths("document", "ext_obj") + stretch.Equal(t, len(search), 3) + for _, expected := range docExtObjResults { + stretch.Contains(t, search, expected) + } + + search = graph.FindPaths("document", "user") + stretch.Equal(t, len(search), 3) + for _, expected := range docUserResults { + stretch.Contains(t, search, expected) + } + + search = graph.FindPaths("group", "ext_obj") + stretch.Equal(t, len(search), 2) + for _, expected := range groupExtObjResults { + stretch.Contains(t, search, expected) + } +}