Skip to content

Commit

Permalink
edge methods in the Respository interface are implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
caffix committed Dec 21, 2024
1 parent 8b2a3a3 commit cb13bcd
Show file tree
Hide file tree
Showing 6 changed files with 405 additions and 118 deletions.
247 changes: 138 additions & 109 deletions repository/neo4j/edge.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"time"

neo4jdb "github.com/neo4j/neo4j-go-driver/v5/neo4j"
"github.com/owasp-amass/asset-db/types"
oam "github.com/owasp-amass/open-asset-model"
"gorm.io/gorm"
)

// CreateEdge creates an edge between two entities in the database.
Expand All @@ -32,50 +32,57 @@ func (neo *neoRepository) CreateEdge(edge *types.Edge) (*types.Edge, error) {
edge.FromEntity.Asset.AssetType(), edge.Relation.Label(), edge.ToEntity.Asset.AssetType())
}

var updated time.Time
if edge.LastSeen.IsZero() {
updated = time.Now().UTC()
} else {
updated = edge.LastSeen.UTC()
edge.LastSeen = time.Now()
}
// ensure that duplicate relationships are not entered into the database
if e, found := neo.isDuplicateEdge(edge, updated); found {
if e, found := neo.isDuplicateEdge(edge, edge.LastSeen); found {
return e, nil
}

fromEntityId, err := strconv.ParseUint(edge.FromEntity.ID, 10, 64)
if err != nil {
return nil, err
if edge.CreatedAt.IsZero() {
edge.CreatedAt = time.Now()
}

toEntityId, err := strconv.ParseUint(edge.ToEntity.ID, 10, 64)
props, err := edgePropsMap(edge)
if err != nil {
return nil, err
}

jsonContent, err := edge.Relation.JSON()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

from := fmt.Sprintf("MATCH (from:Entity {entity_id: '%s'})", edge.FromEntity.ID)
to := fmt.Sprintf("MATCH (to:Entity {entity_id: '%s'})", edge.ToEntity.ID)
query := fmt.Sprintf("%s %s CREATE (from)-[r:%s $props]->(to) RETURN r", from, to, strings.ToUpper(edge.Relation.Label()))
result, err := neo4jdb.ExecuteQuery(ctx, neo.db, query,
map[string]interface{}{"props": props},
neo4jdb.EagerResultTransformer,
neo4jdb.ExecuteQueryWithDatabase(neo.dbname),
)
if err != nil {
return nil, err
}
if len(result.Records) == 0 {
return nil, errors.New("no records returned from the query")
}

r := Edge{
Type: string(edge.Relation.RelationType()),
Content: jsonContent,
FromEntityID: fromEntityId,
ToEntityID: toEntityId,
UpdatedAt: updated,
rel, isnil, err := neo4jdb.GetRecordValue[neo4jdb.Relationship](result.Records[0], "r")
if err != nil {
return nil, err
}
if edge.CreatedAt.IsZero() {
r.CreatedAt = time.Now().UTC()
} else {
r.CreatedAt = edge.CreatedAt.UTC()
if isnil {
return nil, errors.New("the record value for the relationship is nil")
}

result := sql.db.Create(&r)
if err := result.Error; err != nil {
r, err := relationshipToEdge(rel)
if err != nil {
return nil, err
}
return toEdge(r), nil
r.FromEntity = edge.FromEntity
r.ToEntity = edge.ToEntity

return r, nil
}

// isDuplicateEdge checks if the relationship between source and dest already exists.
Expand Down Expand Up @@ -104,41 +111,15 @@ func (neo *neoRepository) isDuplicateEdge(edge *types.Edge, updated time.Time) (

// edgeSeen updates the updated_at timestamp for the specified edge.
func (neo *neoRepository) edgeSeen(rel *types.Edge, updated time.Time) error {
id, err := strconv.ParseUint(rel.ID, 10, 64)
if err != nil {
return err
}

jsonContent, err := rel.Relation.JSON()
if err != nil {
return err
}

fromEntityId, err := strconv.ParseUint(rel.FromEntity.ID, 10, 64)
if err != nil {
return err
}

toEntityId, err := strconv.ParseUint(rel.ToEntity.ID, 10, 64)
if err != nil {
return err
}

r := Edge{
ID: id,
Type: string(rel.Relation.RelationType()),
Content: jsonContent,
FromEntityID: fromEntityId,
ToEntityID: toEntityId,
CreatedAt: rel.CreatedAt,
UpdatedAt: updated,
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

result := sql.db.Save(&r)
if err := result.Error; err != nil {
return err
}
return nil
query := fmt.Sprintf("MATCH ()-[r]->() WHERE r.elementId = '%s' SET r.updated_at = localDateTime('%s')", rel.ID, timeToNeo4jTime(updated))
_, err := neo4jdb.ExecuteQuery(ctx, neo.db, query, nil,
neo4jdb.EagerResultTransformer,
neo4jdb.ExecuteQueryWithDatabase(neo.dbname),
)
return err
}

func (neo *neoRepository) FindEdgeById(id string) (*types.Edge, error) {
Expand Down Expand Up @@ -198,88 +179,136 @@ func (neo *neoRepository) FindEdgeById(id string) (*types.Edge, error) {
// If since.IsZero(), the parameter will be ignored.
// If no labels are specified, all incoming eges are returned.
func (neo *neoRepository) IncomingEdges(entity *types.Entity, since time.Time, labels ...string) ([]*types.Edge, error) {
entityId, err := strconv.ParseInt(entity.ID, 10, 64)
if err != nil {
return nil, err
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

var edges []Edge
var result *gorm.DB
if since.IsZero() {
result = sql.db.Where("to_entity_id = ?", entityId).Find(&edges)
} else {
result = sql.db.Where("to_entity_id = ? AND updated_at >= ?", entityId, since.UTC()).Find(&edges)
query := fmt.Sprintf("MATCH (:Entity {entity_id: '%s'})<-[r]-(from:Entity) RETURN r, from.entity_id AS fid", entity.ID)
if !since.IsZero() {
query = fmt.Sprintf("MATCH (:Entity {entity_id: '%s'})<-[r]-(from:Entity) WHERE r.updated_at >= localDateTime('%s') RETURN r, from.entity_id AS fid", entity.ID, timeToNeo4jTime(since))
}
if err := result.Error; err != nil {

result, err := neo4jdb.ExecuteQuery(ctx, neo.db, query, nil,
neo4jdb.EagerResultTransformer,
neo4jdb.ExecuteQueryWithDatabase(neo.dbname),
)
if err != nil {
return nil, err
}

var results []Edge
if len(labels) > 0 {
for _, edge := range edges {
e := &edge

if rel, err := e.Parse(); err == nil {
for _, label := range labels {
if label == rel.Label() {
results = append(results, edge)
break
}
var results []*types.Edge
for _, record := range result.Records {
r, isnil, err := neo4jdb.GetRecordValue[neo4jdb.Relationship](record, "r")
if err != nil {
continue
}
if isnil {
continue
}

if len(labels) > 0 {
var found bool

for _, label := range labels {
if strings.EqualFold(label, r.Type) {
found = true
break
}
}

if !found {
continue
}
}

fid, isnil, err := neo4jdb.GetRecordValue[string](record, "fid")
if err != nil {
continue
}
if isnil {
continue
}
} else {
results = edges

edge, err := relationshipToEdge(r)
if err != nil {
continue
}
edge.FromEntity = &types.Entity{ID: fid}
edge.ToEntity = entity
results = append(results, edge)
}

if len(results) == 0 {
return nil, errors.New("zero edges found")
}
return toEdges(results), nil
return results, nil
}

// OutgoingEdges finds all edges from the entity of the specified labels and last seen after the since parameter.
// If since.IsZero(), the parameter will be ignored.
// If no labels are specified, all outgoing edges are returned.
func (neo *neoRepository) OutgoingEdges(entity *types.Entity, since time.Time, labels ...string) ([]*types.Edge, error) {
entityId, err := strconv.ParseInt(entity.ID, 10, 64)
if err != nil {
return nil, err
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

var edges []Edge
var result *gorm.DB
if since.IsZero() {
result = sql.db.Where("from_entity_id = ?", entityId).Find(&edges)
} else {
result = sql.db.Where("from_entity_id = ? AND updated_at >= ?", entityId, since.UTC()).Find(&edges)
query := fmt.Sprintf("MATCH (:Entity {entity_id: '%s'})-[r]->(to:Entity) RETURN r, to.entity_id AS tid", entity.ID)
if !since.IsZero() {
query = fmt.Sprintf("MATCH (:Entity {entity_id: '%s'})-[r]->(to:Entity) WHERE r.updated_at >= localDateTime('%s') RETURN r, to.entity_id AS tid", entity.ID, timeToNeo4jTime(since))
}
if err := result.Error; err != nil {

result, err := neo4jdb.ExecuteQuery(ctx, neo.db, query, nil,
neo4jdb.EagerResultTransformer,
neo4jdb.ExecuteQueryWithDatabase(neo.dbname),
)
if err != nil {
return nil, err
}

var results []Edge
if len(labels) > 0 {
for _, edge := range edges {
e := &edge

if rel, err := e.Parse(); err == nil {
for _, label := range labels {
if label == rel.Label() {
results = append(results, edge)
break
}
var results []*types.Edge
for _, record := range result.Records {
r, isnil, err := neo4jdb.GetRecordValue[neo4jdb.Relationship](record, "r")
if err != nil {
continue
}
if isnil {
continue
}

if len(labels) > 0 {
var found bool

for _, label := range labels {
if strings.EqualFold(label, r.Type) {
found = true
break
}
}

if !found {
continue
}
}

tid, isnil, err := neo4jdb.GetRecordValue[string](record, "tid")
if err != nil {
continue
}
if isnil {
continue
}

edge, err := relationshipToEdge(r)
if err != nil {
continue
}
} else {
results = edges
edge.FromEntity = entity
edge.ToEntity = &types.Entity{ID: tid}
results = append(results, edge)
}

if len(results) == 0 {
return nil, errors.New("zero edges found")
}
return toEdges(results), nil
return results, nil
}

// DeleteEdge removes an edge in the database by its ID.
Expand Down
4 changes: 3 additions & 1 deletion repository/neo4j/edge_tag.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (neo *neoRepository) CreateEdgeTag(edge *types.Edge, input *types.EdgeTag)
return nil, errors.New("the input edge tag is nil")
}
// ensure that duplicate entities are not entered into the database
if tags, err := neo.FindEdgeTagsByContent(input.Property, time.Time{}); err == nil && len(tags) == 1 {
if tags, err := neo.FindEdgeTagsByContent(input.Property, time.Time{}); err == nil && len(tags) > 0 {
t := tags[0]

if input.Property.PropertyType() != t.Property.PropertyType() {
Expand All @@ -39,6 +39,7 @@ func (neo *neoRepository) CreateEdgeTag(edge *types.Edge, input *types.EdgeTag)
return nil, err
}

t.Edge = edge
t.LastSeen = time.Now()
props, err := edgeTagPropsMap(t)
if err != nil {
Expand Down Expand Up @@ -83,6 +84,7 @@ func (neo *neoRepository) CreateEdgeTag(edge *types.Edge, input *types.EdgeTag)
input.LastSeen = time.Now()
}

input.Edge = edge
props, err := edgeTagPropsMap(input)
if err != nil {
return nil, err
Expand Down
4 changes: 3 additions & 1 deletion repository/neo4j/entity_tag.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (neo *neoRepository) CreateEntityTag(entity *types.Entity, input *types.Ent
return nil, errors.New("the input entity tag is nil")
}
// ensure that duplicate entities are not entered into the database
if tags, err := neo.FindEntityTagsByContent(input.Property, time.Time{}); err == nil && len(tags) == 1 {
if tags, err := neo.FindEntityTagsByContent(input.Property, time.Time{}); err == nil && len(tags) > 0 {
t := tags[0]

if input.Property.PropertyType() != t.Property.PropertyType() {
Expand All @@ -39,6 +39,7 @@ func (neo *neoRepository) CreateEntityTag(entity *types.Entity, input *types.Ent
return nil, err
}

t.Entity = entity
t.LastSeen = time.Now()
props, err := entityTagPropsMap(t)
if err != nil {
Expand Down Expand Up @@ -83,6 +84,7 @@ func (neo *neoRepository) CreateEntityTag(entity *types.Entity, input *types.Ent
input.LastSeen = time.Now()
}

input.Entity = entity
props, err := entityTagPropsMap(input)
if err != nil {
return nil, err
Expand Down
Loading

0 comments on commit cb13bcd

Please sign in to comment.