Skip to content

Commit

Permalink
Merge pull request #7 from aserto-dev/implement_graph
Browse files Browse the repository at this point in the history
Add model graph
  • Loading branch information
oanatmaria authored Oct 27, 2023
2 parents 2ef8edd + 809fbf8 commit aaf6a30
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 0 deletions.
102 changes: 102 additions & 0 deletions graph/graph.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package graph

type Relations []string

type Graph struct {
nodes []string
adjMatrix [][]Relations
}

func NewGraph() *Graph {
return &Graph{
nodes: make([]string, 0),
adjMatrix: make([][]Relations, 0),
}
}

func (g *Graph) AddNode(nodeName string) {
for _, vertexName := range g.nodes {
if nodeName == vertexName {
return
}
}
g.nodes = append(g.nodes, nodeName)
g.adjMatrix = append(g.adjMatrix, make([]Relations, 0))
}

func (g *Graph) AddEdge(source, dest, edgeName string) {
sourceIndex := g.findVertexIndexByName(source)
destIndex := g.findVertexIndexByName(dest)
if destIndex == -1 || sourceIndex == -1 {
return
}

if len(g.adjMatrix[sourceIndex]) == 0 {
g.adjMatrix[sourceIndex] = make([]Relations, len(g.nodes))
}

g.adjMatrix[sourceIndex][destIndex] = append(g.adjMatrix[sourceIndex][destIndex], edgeName)
}

func (g *Graph) dFS(source, dest int, visited []map[int]bool, traversal [][]string, index int) [][]string {
if len(traversal) < index+1 {
traversal = append(traversal, make([]string, 0))
}
traversal[index] = append(traversal[index], g.nodes[source])

if source == dest {
return traversal
} else {
for i := 0; i < len(g.nodes); i++ {
if len(g.adjMatrix[source]) > i && len(g.adjMatrix[source][i]) != 0 {
if len(visited[i]) == 0 {
visited[i] = make(map[int]bool, len(g.adjMatrix[source][i]))
}
for j := 0; j < len(g.adjMatrix[source][i]); j++ {
if visited[i][j] {
continue
}
visited[i][j] = true
if len(traversal) < index+1 {
traversal = append(traversal, make([]string, 0))
}
traversal[index] = append(traversal[index], g.adjMatrix[source][i][j])
traversal = g.dFS(i, dest, visited, traversal, index)
index++
}
visited[i] = make(map[int]bool, 0)
}
}
}

return traversal
}

func (g *Graph) FindPaths(startVertex, destVertx string) [][]string {
visited := make([]map[int]bool, len(g.nodes))
traversal := make([][]string, 0)
startIndex := g.findVertexIndexByName(startVertex)
destIndex := g.findVertexIndexByName(destVertx)
if startIndex == -1 || destIndex == -1 {
return traversal
}

traversal = g.dFS(startIndex, destIndex, visited, traversal, 0)
for i, paths := range traversal {
if paths[0] != startVertex {
traversal[i] = append([]string{startVertex}, paths...)
}

}

return traversal
}

func (g *Graph) findVertexIndexByName(name string) int {
for index, vertexName := range g.nodes {
if vertexName == name {
return index
}
}
return -1
}
22 changes: 22 additions & 0 deletions model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"encoding/json"
"io"
"time"

"github.com/aserto-dev/azm/graph"
)

const ModelVersion int = 1
Expand Down Expand Up @@ -77,6 +79,26 @@ func New(r io.Reader) (*Model, error) {
return &m, nil
}

func (m *Model) GetGraph() *graph.Graph {
grph := graph.NewGraph()
for objectName := range m.Objects {
grph.AddNode(string(objectName))
}
for objectName, obj := range m.Objects {
for relName, rel := range obj.Relations {
for _, rl := range rel {
if string(rl.Direct) != "" {
grph.AddEdge(string(objectName), string(rl.Direct), string(relName))
} else if rl.Subject != nil {
grph.AddEdge(string(objectName), string(rl.Subject.Object), string(relName))
}
}
}
}

return grph
}

func (m *Model) Reader() (io.Reader, error) {
b := bytes.Buffer{}
enc := json.NewEncoder(&b)
Expand Down
85 changes: 85 additions & 0 deletions model/model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,88 @@ func TestDiff(t *testing.T) {
stretch.Equal(t, len(diffM1WithM3.Removed.Relations), 1)
stretch.Equal(t, diffM1WithM3.Removed.Relations["document"], []model.RelationName{"parent_folder"})
}

func TestGraph(t *testing.T) {
m := model.Model{
Version: 1,
Objects: map[model.ObjectName]*model.Object{
model.ObjectName("user"): {
Relations: map[model.RelationName][]*model.Relation{
model.RelationName("rel_name"): {
&model.Relation{Direct: model.ObjectName("ext_obj")},
},
},
},
model.ObjectName("ext_obj"): {},
model.ObjectName("group"): {
Relations: map[model.RelationName][]*model.Relation{
model.RelationName("member"): {
&model.Relation{Direct: model.ObjectName("user")},
&model.Relation{Subject: &model.SubjectRelation{
Object: model.ObjectName("group"),
Relation: model.RelationName("member"),
}},
},
},
},
model.ObjectName("folder"): {
Relations: map[model.RelationName][]*model.Relation{
model.RelationName("owner"): {
&model.Relation{Direct: model.ObjectName("user")},
},
},
},
model.ObjectName("document"): {
Relations: map[model.RelationName][]*model.Relation{
model.RelationName("parent_folder"): {
{Direct: model.ObjectName("folder")},
},
model.RelationName("writer"): {
{Direct: model.ObjectName("user")},
},
model.RelationName("reader"): {
{Direct: model.ObjectName("user")},
{Wildcard: model.ObjectName("user")},
},
},
},
},
}

docExtObjResults := [][]string{
{"document", "writer", "user", "rel_name", "ext_obj"},
{"document", "reader", "user", "rel_name", "ext_obj"},
{"document", "parent_folder", "folder", "owner", "user", "rel_name", "ext_obj"},
}

docUserResults := [][]string{
{"document", "writer", "user"},
{"document", "reader", "user"},
{"document", "parent_folder", "folder", "owner", "user"},
}

groupExtObjResults := [][]string{
{"group", "member", "group", "member", "user", "rel_name", "ext_obj"},
{"group", "member", "user", "rel_name", "ext_obj"},
}

graph := m.GetGraph()

search := graph.FindPaths("document", "ext_obj")
stretch.Equal(t, len(search), 3)
for _, expected := range docExtObjResults {
stretch.Contains(t, search, expected)
}

search = graph.FindPaths("document", "user")
stretch.Equal(t, len(search), 3)
for _, expected := range docUserResults {
stretch.Contains(t, search, expected)
}

search = graph.FindPaths("group", "ext_obj")
stretch.Equal(t, len(search), 2)
for _, expected := range groupExtObjResults {
stretch.Contains(t, search, expected)
}
}

0 comments on commit aaf6a30

Please sign in to comment.