From 9933df78557d2cdc1eac3ced62f91003a1569693 Mon Sep 17 00:00:00 2001 From: oanatmaria Date: Wed, 11 Oct 2023 18:05:47 +0300 Subject: [PATCH] Add model graph --- graph/graph.go | 69 +++++++++++++++++++++++++++++++++++++++++++++ model/model.go | 22 +++++++++++++++ model/model_test.go | 10 +++++++ 3 files changed, 101 insertions(+) create mode 100644 graph/graph.go diff --git a/graph/graph.go b/graph/graph.go new file mode 100644 index 0000000..b00c9b8 --- /dev/null +++ b/graph/graph.go @@ -0,0 +1,69 @@ +package graph + +type Graph struct { + nodes []string + adjMatrix [][]string +} + +func NewGraph() *Graph { + return &Graph{ + nodes: make([]string, 0), + adjMatrix: make([][]string, 0), + } +} + +func (g *Graph) AddNode(nodeName string) { + for _, vertexName := range g.nodes { + if nodeName == vertexName { + return + } + } + g.nodes = append(g.nodes, nodeName) + g.adjMatrix = append(g.adjMatrix, make([]string, 0)) +} + +func (g *Graph) AddEdge(source, dest, edgeName string) { + sourceIndex := g.findVertexIndexByName(source) + destIndex := g.findVertexIndexByName(dest) + if destIndex == -1 || sourceIndex == -1 { + return + } + + if len(g.adjMatrix[sourceIndex]) == 0 { + g.adjMatrix[sourceIndex] = make([]string, len(g.nodes)) + } + + g.adjMatrix[sourceIndex][destIndex] = edgeName +} + +func (g *Graph) dFS(source int, visited []bool, traversal []string) []string { + visited[source] = true + traversal = append(traversal, g.nodes[source]) + for i := 0; i < len(g.nodes); i++ { + if len(g.adjMatrix[source]) != 0 && g.adjMatrix[source][i] != "" && !visited[i] { + traversal = append(traversal, g.adjMatrix[source][i]) + traversal = g.dFS(i, visited, traversal) + } + } + return traversal +} + +func (g *Graph) TraverseGraph(startVertex string) []string { + visited := make([]bool, len(g.nodes)) + traversal := make([]string, 0) + startIndex := g.findVertexIndexByName(startVertex) + if startIndex == -1 { + return traversal + } + + return g.dFS(startIndex, visited, traversal) +} + +func (g *Graph) findVertexIndexByName(name string) int { + for index, vertexName := range g.nodes { + if vertexName == name { + return index + } + } + return -1 +} diff --git a/model/model.go b/model/model.go index 6e0d6f3..d3fadb8 100644 --- a/model/model.go +++ b/model/model.go @@ -5,6 +5,8 @@ import ( "encoding/json" "io" "time" + + "github.com/aserto-dev/azm/graph" ) const ModelVersion int = 1 @@ -77,6 +79,26 @@ func New(r io.Reader) (*Model, error) { return &m, nil } +func (m *Model) GetGraph() *graph.Graph { + grph := graph.NewGraph() + for objectName := range m.Objects { + grph.AddNode(string(objectName)) + } + for objectName, obj := range m.Objects { + for relName, rel := range obj.Relations { + for _, rl := range rel { + if string(rl.Direct) != "" { + grph.AddEdge(string(objectName), string(rl.Direct), string(relName)) + } else if rl.Subject != nil { + grph.AddEdge(string(objectName), string(rl.Subject.Object), string(relName)) + } + } + } + } + + return grph +} + func (m *Model) Reader() (io.Reader, error) { b := bytes.Buffer{} enc := json.NewEncoder(&b) diff --git a/model/model_test.go b/model/model_test.go index 61df45c..57f921d 100644 --- a/model/model_test.go +++ b/model/model_test.go @@ -226,3 +226,13 @@ func TestDiff(t *testing.T) { require.Equal(t, len(diffm1m3.Removed.Relations), 1) require.Equal(t, diffm1m3.Removed.Relations["document"], []model.RelationName{"parent_folder"}) } + +func TestGraph(t *testing.T) { + graph := m1.GetGraph() + + traversal := graph.TraverseGraph("document") + require.Equal(t, len(traversal), 5) + traversal = graph.TraverseGraph("group") + require.Equal(t, len(traversal), 3) + require.Equal(t, traversal, []string{"group", "member", "user"}) +}